diff --git a/tests/api/test_mldsa.c b/tests/api/test_mldsa.c index a9d2e1980..69a2668f3 100644 --- a/tests/api/test_mldsa.c +++ b/tests/api/test_mldsa.c @@ -16785,3 +16785,165 @@ int test_mldsa_pkcs8(void) #endif return EXPECT_RESULT(); } + +int test_mldsa_pkcs12(void) +{ + EXPECT_DECLS; +#if !defined(NO_ASN) && defined(HAVE_PKCS12) && \ + defined(HAVE_DILITHIUM) && !defined(NO_TLS) && \ + !defined(NO_PWDBASED) && !defined(NO_HMAC) && \ + !defined(NO_CERTS) && !defined(NO_DES3) && \ + (!defined(NO_WOLFSSL_CLIENT) || !defined(NO_WOLFSSL_SERVER)) + + WOLFSSL_CTX* ctx = NULL; + word32 i; + byte* inKey = NULL; + byte* inCert = NULL; + const word32 inKeyHeaderSz = 4; + const word32 inKeyMaxSz = inKeyHeaderSz + DILITHIUM_MAX_PRV_KEY_SIZE; + const word32 certConstSz = 412; + const word32 inCertMaxSz = + certConstSz + DILITHIUM_MAX_SIG_SIZE + DILITHIUM_MAX_PUB_KEY_SIZE; + const word32 pkcs8HeaderSz = 24; + WC_RNG rng; + dilithium_key mldsa_key; + char pkcs12Passwd[] = "mldsa"; + + struct { + int enc; + int wcId; + int oidSum; + int keySz; + int sigType; + int keyType; + } test_variant[] = { + {PBE_SHA1_DES3, WC_ML_DSA_44, ML_DSA_LEVEL2k, + ML_DSA_LEVEL2_PRV_KEY_SIZE, CTC_ML_DSA_LEVEL2, ML_DSA_LEVEL2_TYPE}, + {PBE_SHA1_DES3, WC_ML_DSA_65, ML_DSA_LEVEL3k, + ML_DSA_LEVEL3_PRV_KEY_SIZE, CTC_ML_DSA_LEVEL3, ML_DSA_LEVEL3_TYPE}, + {PBE_SHA1_DES3, WC_ML_DSA_87, ML_DSA_LEVEL5k, + ML_DSA_LEVEL5_PRV_KEY_SIZE, CTC_ML_DSA_LEVEL5, ML_DSA_LEVEL5_TYPE}, + {-1, WC_ML_DSA_44, ML_DSA_LEVEL2k, + ML_DSA_LEVEL2_PRV_KEY_SIZE, CTC_ML_DSA_LEVEL2, ML_DSA_LEVEL2_TYPE}, + {-1, WC_ML_DSA_65, ML_DSA_LEVEL3k, + ML_DSA_LEVEL3_PRV_KEY_SIZE, CTC_ML_DSA_LEVEL3, ML_DSA_LEVEL3_TYPE}, + {-1, WC_ML_DSA_87, ML_DSA_LEVEL5k, + ML_DSA_LEVEL5_PRV_KEY_SIZE, CTC_ML_DSA_LEVEL5, ML_DSA_LEVEL5_TYPE}, + }; + + ExpectNotNull(inKey = (byte*) XMALLOC(inKeyMaxSz, NULL, + DYNAMIC_TYPE_TMP_BUFFER)); + ExpectNotNull(inCert = (byte*) XMALLOC(inCertMaxSz, NULL, + DYNAMIC_TYPE_TMP_BUFFER)); + +#ifndef NO_WOLFSSL_SERVER + ExpectNotNull(ctx = wolfSSL_CTX_new(wolfSSLv23_server_method())); +#else + ExpectNotNull(ctx = wolfSSL_CTX_new(wolfSSLv23_client_method())); +#endif /* NO_WOLFSSL_SERVER */ + + ExpectIntEQ(wc_InitRng(&rng), 0); + ExpectIntEQ(wc_dilithium_init(&mldsa_key), 0); + + for (i = 0; i < sizeof(test_variant) / sizeof(test_variant[0]); ++i) { + WC_PKCS12* pkcs12Export = NULL; + WC_PKCS12* pkcs12Import = NULL; + byte* pkcs12Der = NULL; + byte* outKey = NULL; + byte* outCert = NULL; + word32 inKeySz = 0; + word32 inCertSz = 0; + word32 pkcs12DerSz = 0; + word32 outKeySz = 0; + word32 outCertSz = 0; + Cert cert; + word32 size; + + if (EXPECT_FAIL()) + break; + + /* Create a key for wc_PKCS12_create() */ + inKeySz = 0; + inKey[0] = 0x04; /* ASN.1 OCTET STRING */ + inKey[1] = 0x82; /* 2 bytes length field */ + inKey[2] = (test_variant[i].keySz >> 8) & 0xff; /* MSB of the length */ + inKey[3] = test_variant[i].keySz & 0xff; /* LSB of the length */ + inKeySz += inKeyHeaderSz; + ExpectIntEQ(wc_dilithium_set_level(&mldsa_key, test_variant[i].wcId), + 0); + ExpectIntEQ(wc_dilithium_make_key(&mldsa_key, &rng), 0); + size = inKeyMaxSz - inKeySz; + ExpectIntEQ(wc_dilithium_export_private(&mldsa_key, inKey + inKeySz, + &size), 0); + inKeySz += size; + size = inKeyMaxSz - inKeySz; + ExpectIntEQ(wc_dilithium_export_public(&mldsa_key, inKey + inKeySz, + &size), 0); + inKeySz += size; + + /* Create a certificate for wc_PKCS12_create() */ + ExpectIntEQ(wc_InitCert(&cert), 0); + XSTRNCPY(cert.subject.country, "US", CTC_NAME_SIZE); + XSTRNCPY(cert.subject.state, "MT", CTC_NAME_SIZE); + XSTRNCPY(cert.subject.locality, "Bozeman", CTC_NAME_SIZE); + XSTRNCPY(cert.subject.org, "wolfSSL", CTC_NAME_SIZE); + XSTRNCPY(cert.subject.unit, "Engineering", CTC_NAME_SIZE); + XSTRNCPY(cert.subject.commonName, "www.wolfssl.com", CTC_NAME_SIZE); + XSTRNCPY(cert.subject.email, "root@wolfssl.com", CTC_NAME_SIZE); + XSTRNCPY((char*)cert.beforeDate, "\x18\x0f""20250101000000Z", + CTC_DATE_SIZE); + cert.beforeDateSz = 17; + XSTRNCPY((char*)cert.afterDate, "\x18\x0f""20493112115959Z", + CTC_DATE_SIZE); + cert.afterDateSz = 17; + cert.selfSigned = 1; + cert.sigType = test_variant[i].sigType; + cert.isCA = 0; + ExpectIntGE(inCertSz = wc_MakeCert_ex(&cert, inCert, inCertMaxSz, + test_variant[i].keyType, &mldsa_key, &rng), 0); + ExpectIntGE(inCertSz = wc_SignCert_ex(cert.bodySz, cert.sigType, inCert, + inCertMaxSz, test_variant[i].keyType, &mldsa_key, &rng), 0); + + ExpectNotNull(pkcs12Export = wc_PKCS12_create(pkcs12Passwd, + sizeof(pkcs12Passwd) - 1, + (char*) "friendlyName" /* not used currently */, + (byte*) inKey, inKeySz, (byte*) inCert, inCertSz, + NULL, test_variant[i].enc, test_variant[i].enc, 100, 100, + 0 /* not used currently */, NULL)); + pkcs12Der = NULL; + ExpectIntGE((pkcs12DerSz = wc_i2d_PKCS12(pkcs12Export, &pkcs12Der, + NULL)), 0); + + ExpectNotNull(pkcs12Import = wc_PKCS12_new_ex(NULL)); + ExpectIntGE(wc_d2i_PKCS12(pkcs12Der, pkcs12DerSz, pkcs12Import), 0); + ExpectIntEQ(wc_PKCS12_parse_ex(pkcs12Import, pkcs12Passwd, &outKey, + &outKeySz, + &outCert, &outCertSz, NULL, 1), 0); + ExpectIntGT(outKeySz, 0); + ExpectIntGT(outCertSz, 0); + ExpectIntEQ(wolfSSL_CTX_use_PrivateKey_buffer(ctx, outKey, outKeySz, + WOLFSSL_FILETYPE_ASN1), WOLFSSL_SUCCESS); + ExpectIntEQ(wolfSSL_CTX_use_certificate_buffer(ctx, outCert, outCertSz, + WOLFSSL_FILETYPE_ASN1), WOLFSSL_SUCCESS); + + ExpectIntEQ(inKeySz, outKeySz - pkcs8HeaderSz); + ExpectIntEQ(XMEMCMP(inKey, outKey + pkcs8HeaderSz, inKeySz), 0); + ExpectIntEQ(inCertSz, outCertSz); + ExpectIntEQ(XMEMCMP(inCert, outCert, inCertSz), 0); + + XFREE(outKey, NULL, DYNAMIC_TYPE_PUBLIC_KEY); + XFREE(outCert, NULL, DYNAMIC_TYPE_PKCS); + wc_PKCS12_free(pkcs12Import); + XFREE(pkcs12Der, NULL, DYNAMIC_TYPE_PKCS); + wc_PKCS12_free(pkcs12Export); + } + + wc_dilithium_free(&mldsa_key); + ExpectIntEQ(wc_FreeRng(&rng), 0); + wolfSSL_CTX_free(ctx); + XFREE(inCert, NULL, DYNAMIC_TYPE_TMP_BUFFER); + XFREE(inKey, NULL, DYNAMIC_TYPE_TMP_BUFFER); + +#endif + return EXPECT_RESULT(); +} diff --git a/tests/api/test_mldsa.h b/tests/api/test_mldsa.h index 38568d275..f444e5695 100644 --- a/tests/api/test_mldsa.h +++ b/tests/api/test_mldsa.h @@ -36,6 +36,7 @@ int test_wc_dilithium_make_key_from_seed(void); int test_wc_dilithium_sig_kats(void); int test_wc_dilithium_verify_kats(void); int test_mldsa_pkcs8(void); +int test_mldsa_pkcs12(void); #define TEST_MLDSA_DECLS \ TEST_DECL_GROUP("mldsa", test_wc_dilithium), \ @@ -49,6 +50,7 @@ int test_mldsa_pkcs8(void); TEST_DECL_GROUP("mldsa", test_wc_dilithium_make_key_from_seed), \ TEST_DECL_GROUP("mldsa", test_wc_dilithium_sig_kats), \ TEST_DECL_GROUP("mldsa", test_wc_dilithium_verify_kats), \ - TEST_DECL_GROUP("mldsa", test_mldsa_pkcs8) + TEST_DECL_GROUP("mldsa", test_mldsa_pkcs8), \ + TEST_DECL_GROUP("mldsa", test_mldsa_pkcs12) #endif /* WOLFCRYPT_TEST_MLDSA_H */ diff --git a/wolfcrypt/src/asn.c b/wolfcrypt/src/asn.c index 8148a2cba..a541b737d 100644 --- a/wolfcrypt/src/asn.c +++ b/wolfcrypt/src/asn.c @@ -9155,9 +9155,12 @@ int wc_GetKeyOID(byte* key, word32 keySz, const byte** curveOID, word32* oidSz, if (dilithium == NULL) return MEMORY_E; - if (wc_dilithium_init(dilithium) != 0) { - tmpIdx = 0; - if (wc_dilithium_set_level(dilithium, WC_ML_DSA_44) == 0) { + /* wc_dilithium_init() returns 0 on success and a non-zero value on + * failure. */ + if (wc_dilithium_init(dilithium) == 0) { + if ((*algoID == 0) && + (wc_dilithium_set_level(dilithium, WC_ML_DSA_44) == 0)) { + tmpIdx = 0; if (wc_Dilithium_PrivateKeyDecode(key, &tmpIdx, dilithium, keySz) == 0) { *algoID = ML_DSA_LEVEL2k; @@ -9166,7 +9169,9 @@ int wc_GetKeyOID(byte* key, word32 keySz, const byte** curveOID, word32* oidSz, WOLFSSL_MSG("Not Dilithium Level 2 DER key"); } } - else if (wc_dilithium_set_level(dilithium, WC_ML_DSA_65) == 0) { + if ((*algoID == 0) && + (wc_dilithium_set_level(dilithium, WC_ML_DSA_65) == 0)) { + tmpIdx = 0; if (wc_Dilithium_PrivateKeyDecode(key, &tmpIdx, dilithium, keySz) == 0) { *algoID = ML_DSA_LEVEL3k; @@ -9175,7 +9180,9 @@ int wc_GetKeyOID(byte* key, word32 keySz, const byte** curveOID, word32* oidSz, WOLFSSL_MSG("Not Dilithium Level 3 DER key"); } } - else if (wc_dilithium_set_level(dilithium, WC_ML_DSA_87) == 0) { + if ((*algoID == 0) && + (wc_dilithium_set_level(dilithium, WC_ML_DSA_87) == 0)) { + tmpIdx = 0; if (wc_Dilithium_PrivateKeyDecode(key, &tmpIdx, dilithium, keySz) == 0) { *algoID = ML_DSA_LEVEL5k; diff --git a/wolfcrypt/src/pkcs12.c b/wolfcrypt/src/pkcs12.c index 9d026792b..2342cd632 100644 --- a/wolfcrypt/src/pkcs12.c +++ b/wolfcrypt/src/pkcs12.c @@ -1297,6 +1297,27 @@ static int PKCS12_CoalesceOctetStrings(WC_PKCS12* pkcs12, byte* data, int wc_PKCS12_parse(WC_PKCS12* pkcs12, const char* psw, byte** pkey, word32* pkeySz, byte** cert, word32* certSz, WC_DerCertList** ca) +{ + return wc_PKCS12_parse_ex(pkcs12, psw, pkey, pkeySz, cert, certSz, ca, 0); +} + +/* return 0 on success and negative on failure. + * By side effect returns private key, cert, and optionally ca. + * Parses and decodes the parts of PKCS12 + * + * NOTE: can parse with USER RSA enabled but may return cert that is not the + * pair for the key when using RSA key pairs. + * + * pkcs12 : non-null WC_PKCS12 struct + * psw : password to use for PKCS12 decode + * pkey : Private key returned + * cert : x509 cert returned + * ca : optional ca returned + * keepKeyHeader : 0 removes PKCS8 header, other than 0 keeps PKCS8 header + */ +int wc_PKCS12_parse_ex(WC_PKCS12* pkcs12, const char* psw, + byte** pkey, word32* pkeySz, byte** cert, word32* certSz, + WC_DerCertList** ca, int keepKeyHeader) { ContentInfo* ci = NULL; WC_DerCertList* certList = NULL; @@ -1492,7 +1513,13 @@ int wc_PKCS12_parse(WC_PKCS12* pkcs12, const char* psw, ERROR_OUT(MEMORY_E, exit_pk12par); } XMEMCPY(*pkey, data + idx, (size_t)size); - *pkeySz = (word32)ToTraditional_ex(*pkey, (word32)size, &algId); + if (keepKeyHeader) { + *pkeySz = (word32)size; + } + else { + *pkeySz = (word32)ToTraditional_ex(*pkey, + (word32)size, &algId); + } } #ifdef WOLFSSL_DEBUG_PKCS12 @@ -1531,10 +1558,19 @@ int wc_PKCS12_parse(WC_PKCS12* pkcs12, const char* psw, XMEMCPY(k, data + idx, (size_t)size); /* overwrites input, be warned */ - if ((ret = ToTraditionalEnc(k, (word32)size, psw, pswSz, - &algId)) < 0) { - XFREE(k, pkcs12->heap, DYNAMIC_TYPE_PUBLIC_KEY); - goto exit_pk12par; + if (keepKeyHeader) { + if ((ret = wc_DecryptPKCS8Key(k, (word32)size, psw, + pswSz)) < 0) { + XFREE(k, pkcs12->heap, DYNAMIC_TYPE_PUBLIC_KEY); + goto exit_pk12par; + } + } + else { + if ((ret = ToTraditionalEnc(k, (word32)size, psw, + pswSz, &algId)) < 0) { + XFREE(k, pkcs12->heap, DYNAMIC_TYPE_PUBLIC_KEY); + goto exit_pk12par; + } } if (ret < size) { diff --git a/wolfssl/wolfcrypt/pkcs12.h b/wolfssl/wolfcrypt/pkcs12.h index da74fe2e4..6dc6e9df8 100644 --- a/wolfssl/wolfcrypt/pkcs12.h +++ b/wolfssl/wolfcrypt/pkcs12.h @@ -55,6 +55,9 @@ WOLFSSL_API int wc_i2d_PKCS12(WC_PKCS12* pkcs12, byte** der, int* derSz); WOLFSSL_API int wc_PKCS12_parse(WC_PKCS12* pkcs12, const char* psw, byte** pkey, word32* pkeySz, byte** cert, word32* certSz, WC_DerCertList** ca); +WOLFSSL_API int wc_PKCS12_parse_ex(WC_PKCS12* pkcs12, const char* psw, + byte** pkey, word32* pkeySz, byte** cert, word32* certSz, + WC_DerCertList** ca, int keepKeyHeader); WOLFSSL_LOCAL int wc_PKCS12_verify_ex(WC_PKCS12* pkcs12, const byte* psw, word32 pswSz); WOLFSSL_API WC_PKCS12* wc_PKCS12_create(char* pass, word32 passSz,