From a3862f0e5989db0e977677c91168e00c16850dc0 Mon Sep 17 00:00:00 2001 From: Koji Takeda Date: Thu, 10 Apr 2025 14:17:56 +0900 Subject: [PATCH] Improve ML-DSA private key import --- src/ssl_load.c | 140 ++++++++++++++-------------------- tests/api.c | 114 +++++++++++++++++++++++++++ tests/api/test_mldsa.c | 2 +- wolfcrypt/src/dilithium.c | 50 ++++++++++-- wolfssl/wolfcrypt/dilithium.h | 4 + 5 files changed, 219 insertions(+), 91 deletions(-) diff --git a/src/ssl_load.c b/src/ssl_load.c index 542289f72..72bf81719 100644 --- a/src/ssl_load.c +++ b/src/ssl_load.c @@ -946,6 +946,9 @@ static int ProcessBufferTryDecodeDilithium(WOLFSSL_CTX* ctx, WOLFSSL* ssl, int ret; word32 idx; dilithium_key* key; + int keyFormatTemp = 0; + int keyTypeTemp; + int keySizeTemp; /* Allocate a Dilithium key to parse into. */ key = (dilithium_key*)XMALLOC(sizeof(dilithium_key), heap, @@ -955,106 +958,75 @@ static int ProcessBufferTryDecodeDilithium(WOLFSSL_CTX* ctx, WOLFSSL* ssl, } /* Initialize Dilithium key. */ - ret = wc_dilithium_init(key); - if (ret == 0) { - /* Set up key to parse the format specified. */ - if ((*keyFormat == ML_DSA_LEVEL2k) || ((*keyFormat == 0) && - ((der->length == ML_DSA_LEVEL2_KEY_SIZE) || - (der->length == ML_DSA_LEVEL2_PRV_KEY_SIZE)))) { - ret = wc_dilithium_set_level(key, WC_ML_DSA_44); - } - else if ((*keyFormat == ML_DSA_LEVEL3k) || ((*keyFormat == 0) && - ((der->length == ML_DSA_LEVEL3_KEY_SIZE) || - (der->length == ML_DSA_LEVEL3_PRV_KEY_SIZE)))) { - ret = wc_dilithium_set_level(key, WC_ML_DSA_65); - } - else if ((*keyFormat == ML_DSA_LEVEL5k) || ((*keyFormat == 0) && - ((der->length == ML_DSA_LEVEL5_KEY_SIZE) || - (der->length == ML_DSA_LEVEL5_PRV_KEY_SIZE)))) { - ret = wc_dilithium_set_level(key, WC_ML_DSA_87); - } - #ifdef WOLFSSL_DILITHIUM_FIPS204_DRAFT - else if ((*keyFormat == DILITHIUM_LEVEL2k) || ((*keyFormat == 0) && - ((der->length == DILITHIUM_LEVEL2_KEY_SIZE) || - (der->length == DILITHIUM_LEVEL2_PRV_KEY_SIZE)))) { - ret = wc_dilithium_set_level(key, WC_ML_DSA_44_DRAFT); - } - else if ((*keyFormat == DILITHIUM_LEVEL3k) || ((*keyFormat == 0) && - ((der->length == DILITHIUM_LEVEL3_KEY_SIZE) || - (der->length == DILITHIUM_LEVEL3_PRV_KEY_SIZE)))) { - ret = wc_dilithium_set_level(key, WC_ML_DSA_65_DRAFT); - } - else if ((*keyFormat == DILITHIUM_LEVEL5k) || ((*keyFormat == 0) && - ((der->length == DILITHIUM_LEVEL5_KEY_SIZE) || - (der->length == DILITHIUM_LEVEL5_PRV_KEY_SIZE)))) { - ret = wc_dilithium_set_level(key, WC_ML_DSA_87_DRAFT); - } - #endif /* WOLFSSL_DILITHIUM_FIPS204_DRAFT */ - else { - wc_dilithium_free(key); - ret = ALGO_ID_E; - } - } - + ret = wc_dilithium_init(key); if (ret == 0) { /* Decode as a Dilithium private key. */ idx = 0; ret = wc_Dilithium_PrivateKeyDecode(der->buffer, &idx, key, der->length); if (ret == 0) { - /* Get the minimum Dilithium key size from SSL or SSL context - * object. */ - int minKeySz = ssl ? ssl->options.minDilithiumKeySz : - ctx->minDilithiumKeySz; + ret = dilithium_get_oid_sum(key, &keyFormatTemp); + if(ret == 0) { + /* Format is known. */ + #if defined(WOLFSSL_DILITHIUM_FIPS204_DRAFT) + if (keyFormatTemp == DILITHIUM_LEVEL2k) { + keyTypeTemp = dilithium_level2_sa_algo; + keySizeTemp = DILITHIUM_LEVEL2_KEY_SIZE; + } + else if (keyFormatTemp == DILITHIUM_LEVEL3k) { + keyTypeTemp = dilithium_level3_sa_algo; + keySizeTemp = DILITHIUM_LEVEL3_KEY_SIZE; + } + else if (keyFormatTemp == DILITHIUM_LEVEL5k) { + keyTypeTemp = dilithium_level5_sa_algo; + keySizeTemp = DILITHIUM_LEVEL5_KEY_SIZE; + } + else + #endif /* WOLFSSL_DILITHIUM_FIPS204_DRAFT */ + if (keyFormatTemp == ML_DSA_LEVEL2k) { + keyTypeTemp = dilithium_level2_sa_algo; + keySizeTemp = ML_DSA_LEVEL2_KEY_SIZE; + } + else if (keyFormatTemp == ML_DSA_LEVEL3k) { + keyTypeTemp = dilithium_level3_sa_algo; + keySizeTemp = ML_DSA_LEVEL3_KEY_SIZE; + } + else if (keyFormatTemp == ML_DSA_LEVEL5k) { + keyTypeTemp = dilithium_level5_sa_algo; + keySizeTemp = ML_DSA_LEVEL5_KEY_SIZE; + } + else { + ret = ALGO_ID_E; + } + } - /* Format is known. */ - if (*keyFormat == ML_DSA_LEVEL2k) { - *keyType = dilithium_level2_sa_algo; - *keySize = ML_DSA_LEVEL2_KEY_SIZE; - } - else if (*keyFormat == ML_DSA_LEVEL3k) { - *keyType = dilithium_level3_sa_algo; - *keySize = ML_DSA_LEVEL3_KEY_SIZE; - } - else if (*keyFormat == ML_DSA_LEVEL5k) { - *keyType = dilithium_level5_sa_algo; - *keySize = ML_DSA_LEVEL5_KEY_SIZE; - } - #ifdef WOLFSSL_DILITHIUM_FIPS204_DRAFT - else if (*keyFormat == DILITHIUM_LEVEL2k) { - *keyType = dilithium_level2_sa_algo; - *keySize = DILITHIUM_LEVEL2_KEY_SIZE; - } - else if (*keyFormat == DILITHIUM_LEVEL3k) { - *keyType = dilithium_level3_sa_algo; - *keySize = DILITHIUM_LEVEL3_KEY_SIZE; - } - else if (*keyFormat == DILITHIUM_LEVEL5k) { - *keyType = dilithium_level5_sa_algo; - *keySize = DILITHIUM_LEVEL5_KEY_SIZE; - } - #endif /* WOLFSSL_DILITHIUM_FIPS204_DRAFT */ + if(ret == 0) { + /* Get the minimum Dilithium key size from SSL or SSL context + * object. */ + int minKeySz = ssl ? ssl->options.minDilithiumKeySz : + ctx->minDilithiumKeySz; - /* Check that the size of the Dilithium key is enough. */ - if (*keySize < minKeySz) { - WOLFSSL_MSG("Dilithium private key too small"); - ret = DILITHIUM_KEY_SIZE_E; + /* Check that the size of the Dilithium key is enough. */ + if (keySizeTemp < minKeySz) { + WOLFSSL_MSG("Dilithium private key too small"); + ret = DILITHIUM_KEY_SIZE_E; + } + } + + if(ret == 0) { + *keyFormat = keyFormatTemp; + *keyType = keyTypeTemp; + *keySize = keySizeTemp; } } - /* Not a Dilithium key but check whether we know what it is. */ else if (*keyFormat == 0) { WOLFSSL_MSG("Not a Dilithium key"); - /* Format unknown so keep trying. */ + /* Unknowun format was not dilithium, so keep trying other formats. */ ret = 0; } - + /* Free dynamically allocated data in key. */ wc_dilithium_free(key); } - else if ((ret == WC_NO_ERR_TRACE(ALGO_ID_E)) && (*keyFormat == 0)) { - WOLFSSL_MSG("Not a Dilithium key"); - /* Format unknown so keep trying. */ - ret = 0; - } /* Dispose of allocated key. */ XFREE(key, heap, DYNAMIC_TYPE_DILITHIUM); diff --git a/tests/api.c b/tests/api.c index 0a0913586..8b9770c4b 100644 --- a/tests/api.c +++ b/tests/api.c @@ -13933,6 +13933,119 @@ static int test_wolfSSL_PKCS8_ED448(void) return EXPECT_RESULT(); } +static int test_wolfSSL_PKCS8_MLDSA(void) +{ + EXPECT_DECLS; +#if !defined(NO_ASN) && defined(HAVE_PKCS8) && \ + defined(HAVE_DILITHIUM) && !defined(NO_TLS) && \ + (!defined(NO_WOLFSSL_CLIENT) || !defined(NO_WOLFSSL_SERVER)) + + WOLFSSL_CTX* ctx = NULL; + size_t i; + const int derMaxSz = 8192; /* Largest size will be 7520 of separated format, WC_ML_DSA_87, DER */ + const int tempMaxSz = 10240; /* Largest size will be 10239 of separated format, WC_MLS_DSA_87, PEM */ + byte* der = NULL; + byte* temp = NULL; /* Store PEM or intermediate key */ + word32 derSz = 0; + word32 pemSz = 0; + word32 keySz = 0; + dilithium_key mldsa_key; + WC_RNG rng; + word32 size; + + struct { + int wcId; + int oidSum; + int keySz; + } test_variant[] = {{WC_ML_DSA_44, ML_DSA_LEVEL2k, ML_DSA_LEVEL2_PRV_KEY_SIZE}, + {WC_ML_DSA_65, ML_DSA_LEVEL3k, ML_DSA_LEVEL3_PRV_KEY_SIZE}, + {WC_ML_DSA_87, ML_DSA_LEVEL5k, ML_DSA_LEVEL5_PRV_KEY_SIZE}}; + + (void) pemSz; + + ExpectNotNull(der = (byte*) XMALLOC(derMaxSz, NULL, DYNAMIC_TYPE_TMP_BUFFER)); + ExpectNotNull(temp = (byte*) XMALLOC(tempMaxSz, NULL, DYNAMIC_TYPE_TMP_BUFFER)); + +#ifndef NO_WOLFSSL_SERVER + ExpectNotNull(ctx = wolfSSL_CTX_new(wolfSSLv23_server_method())); +#else + ExpectNotNull(ctx = wolfSSL_CTX_new(wolfSSLv23_client_method())); +#endif /* NO_WOLFSSL_SERVER */ + + ExpectIntEQ(wc_InitRng(&rng), 0); + ExpectIntEQ(wc_dilithium_init(&mldsa_key), 0); + + /* Test private + public key (separated format) */ + for(i = 0; i < sizeof(test_variant) / sizeof(test_variant[0]); ++i) { + ExpectIntEQ(wc_dilithium_set_level(&mldsa_key, test_variant[i].wcId), 0); + ExpectIntEQ(wc_dilithium_make_key(&mldsa_key, &rng), 0); + + ExpectIntGT(derSz = wc_Dilithium_KeyToDer(&mldsa_key, der, derMaxSz), 0); + ExpectIntEQ(wolfSSL_CTX_use_PrivateKey_buffer(ctx, der, derSz, + WOLFSSL_FILETYPE_ASN1), WOLFSSL_SUCCESS); + +#ifdef WOLFSSL_DER_TO_PEM + ExpectIntGT(pemSz = wc_DerToPem(der, derSz, temp, tempMaxSz, PKCS8_PRIVATEKEY_TYPE), 0); + ExpectIntEQ(wolfSSL_CTX_use_PrivateKey_buffer(ctx, temp, pemSz, + WOLFSSL_FILETYPE_PEM), WOLFSSL_SUCCESS); +#endif /* WOLFSSL_DER_TO_PEM */ + } + + /* Test private key only */ + for(i = 0; i < sizeof(test_variant) / sizeof(test_variant[0]); ++i) { + ExpectIntEQ(wc_dilithium_set_level(&mldsa_key, test_variant[i].wcId), 0); + ExpectIntEQ(wc_dilithium_make_key(&mldsa_key, &rng), 0); + + ExpectIntGT(derSz = wc_Dilithium_PrivateKeyToDer(&mldsa_key, der, derMaxSz), 0); + ExpectIntEQ(wolfSSL_CTX_use_PrivateKey_buffer(ctx, der, derSz, + WOLFSSL_FILETYPE_ASN1), WOLFSSL_SUCCESS); + +#ifdef WOLFSSL_DER_TO_PEM + ExpectIntGT(pemSz = wc_DerToPem(der, derSz, temp, tempMaxSz, PKCS8_PRIVATEKEY_TYPE), 0); + ExpectIntEQ(wolfSSL_CTX_use_PrivateKey_buffer(ctx, temp, pemSz, + WOLFSSL_FILETYPE_PEM), WOLFSSL_SUCCESS); +#endif /* WOLFSSL_DER_TO_PEM */ + } + + /* Test private + public key (integrated format) */ + for(i = 0; i < sizeof(test_variant) / sizeof(test_variant[0]); ++i) { + ExpectIntEQ(wc_dilithium_set_level(&mldsa_key, test_variant[i].wcId), 0); + ExpectIntEQ(wc_dilithium_make_key(&mldsa_key, &rng), 0); + + keySz = 0; + temp[0] = 0x04; /* ASN.1 OCTET STRING */ + temp[1] = 0x82; /* 2 bytes length field */ + temp[2] = (test_variant[i].keySz >> 8) & 0xff; /* MSB of the length */ + temp[3] = test_variant[i].keySz & 0xff; /* LSB of the length */ + keySz += 4; + size = tempMaxSz - keySz; + ExpectIntEQ(wc_dilithium_export_private(&mldsa_key, temp + keySz, &size), 0); + keySz += size; + size = tempMaxSz - keySz; + ExpectIntEQ(wc_dilithium_export_public(&mldsa_key, temp + keySz, &size), 0); + keySz += size; + derSz = derMaxSz; + ExpectIntGT(wc_CreatePKCS8Key(der, &derSz, temp, keySz, test_variant[i].oidSum, NULL, 0), 0); + ExpectIntEQ(wolfSSL_CTX_use_PrivateKey_buffer(ctx, der, derSz, + WOLFSSL_FILETYPE_ASN1), WOLFSSL_SUCCESS); + +#ifdef WOLFSSL_DER_TO_PEM + ExpectIntGT(pemSz = wc_DerToPem(der, derSz, temp, tempMaxSz, PKCS8_PRIVATEKEY_TYPE), 0); + ExpectIntEQ(wolfSSL_CTX_use_PrivateKey_buffer(ctx, temp, pemSz, + WOLFSSL_FILETYPE_PEM), WOLFSSL_SUCCESS); +#endif /* WOLFSSL_DER_TO_PEM */ + } + + wc_dilithium_free(&mldsa_key); + ExpectIntEQ(wc_FreeRng(&rng), 0); + wolfSSL_CTX_free(ctx); + XFREE(temp, NULL, DYNAMIC_TYPE_TMP_BUFFER); + XFREE(der, NULL, DYNAMIC_TYPE_TMP_BUFFER); + +#endif + return EXPECT_RESULT(); +} + /* Testing functions dealing with PKCS5 */ static int test_wolfSSL_PKCS5(void) { @@ -67519,6 +67632,7 @@ TEST_CASE testCases[] = { TEST_DECL(test_wolfSSL_PKCS8), TEST_DECL(test_wolfSSL_PKCS8_ED25519), TEST_DECL(test_wolfSSL_PKCS8_ED448), + TEST_DECL(test_wolfSSL_PKCS8_MLDSA), #ifdef HAVE_IO_TESTS_DEPENDENCIES TEST_DECL(test_wolfSSL_get_finished), diff --git a/tests/api/test_mldsa.c b/tests/api/test_mldsa.c index 2229c5f77..5991d0b16 100644 --- a/tests/api/test_mldsa.c +++ b/tests/api/test_mldsa.c @@ -2959,7 +2959,7 @@ int test_wc_dilithium_der(void) idx = 0; #ifdef WOLFSSL_DILITHIUM_FIPS204_DRAFT ExpectIntEQ(wc_Dilithium_PrivateKeyDecode(der, &idx, key, privDerLen), - WC_NO_ERR_TRACE(BAD_FUNC_ARG)); + WC_NO_ERR_TRACE(ASN_PARSE_E)); #else ExpectIntEQ(wc_Dilithium_PrivateKeyDecode(der, &idx, key, privDerLen), WC_NO_ERR_TRACE(ASN_PARSE_E)); diff --git a/wolfcrypt/src/dilithium.c b/wolfcrypt/src/dilithium.c index e4498edca..ec60e5156 100644 --- a/wolfcrypt/src/dilithium.c +++ b/wolfcrypt/src/dilithium.c @@ -9589,6 +9589,42 @@ static int mapOidToSecLevel(word32 oid) } } +/* Get OID sum from dilithium key */ +int dilithium_get_oid_sum(dilithium_key* key, int* keyFormat) { + int ret = 0; + + #if defined(WOLFSSL_DILITHIUM_FIPS204_DRAFT) + if (key->params == NULL) { + ret = BAD_FUNC_ARG; + } + else if (key->params->level == WC_ML_DSA_44_DRAFT) { + *keyFormat = DILITHIUM_LEVEL2k; + } + else if (key->params->level == WC_ML_DSA_65_DRAFT) { + *keyFormat = DILITHIUM_LEVEL3k; + } + else if (key->params->level == WC_ML_DSA_87_DRAFT) { + *keyFormat = DILITHIUM_LEVEL5k; + } + else + #endif /* WOLFSSL_DILITHIUM_FIPS204_DRAFT */ + if (key->level == WC_ML_DSA_44) { + *keyFormat = ML_DSA_LEVEL2k; + } + else if (key->level == WC_ML_DSA_65) { + *keyFormat = ML_DSA_LEVEL3k; + } + else if (key->level == WC_ML_DSA_87) { + *keyFormat = ML_DSA_LEVEL5k; + } + else { + /* Level is not set */ + ret = ALGO_ID_E; + } + + return ret; +} + #if defined(WOLFSSL_DILITHIUM_PRIVATE_KEY) /* Decode the DER encoded Dilithium key. @@ -9627,9 +9663,13 @@ int wc_Dilithium_PrivateKeyDecode(const byte* input, word32* inOutIdx, } if (ret == 0) { - /* Get OID sum for level. */ + /* Get OID sum for level. */ + if(key->level == 0) { /* Check first, because key->params will be NULL when key->level = 0 */ + /* Level not set by caller, decode from DER */ + keytype = ANONk; + } #if defined(WOLFSSL_DILITHIUM_FIPS204_DRAFT) - if (key->params == NULL) { + else if (key->params == NULL) { ret = BAD_FUNC_ARG; } else if (key->params->level == WC_ML_DSA_44_DRAFT) { @@ -9641,9 +9681,8 @@ int wc_Dilithium_PrivateKeyDecode(const byte* input, word32* inOutIdx, else if (key->params->level == WC_ML_DSA_87_DRAFT) { keytype = DILITHIUM_LEVEL5k; } - else #endif - if (key->level == WC_ML_DSA_44) { + else if (key->level == WC_ML_DSA_44) { keytype = ML_DSA_LEVEL2k; } else if (key->level == WC_ML_DSA_65) { @@ -9653,8 +9692,7 @@ int wc_Dilithium_PrivateKeyDecode(const byte* input, word32* inOutIdx, keytype = ML_DSA_LEVEL5k; } else { - /* Level not set by caller, decode from DER */ - keytype = ANONk; /* 0, not a valid key type in this situation*/ + ret = BAD_FUNC_ARG; } } diff --git a/wolfssl/wolfcrypt/dilithium.h b/wolfssl/wolfcrypt/dilithium.h index 25fd1587f..c7655884f 100644 --- a/wolfssl/wolfcrypt/dilithium.h +++ b/wolfssl/wolfcrypt/dilithium.h @@ -813,6 +813,10 @@ int wc_dilithium_export_key(dilithium_key* key, byte* priv, word32 *privSz, byte* pub, word32 *pubSz); #endif +#ifndef WOLFSSL_DILITHIUM_NO_ASN1 +WOLFSSL_LOCAL int dilithium_get_oid_sum(dilithium_key* key, int* keyFormat); +#endif /* WOLFSSL_DILITHIUM_NO_ASN1 */ + #ifndef WOLFSSL_DILITHIUM_NO_ASN1 #if defined(WOLFSSL_DILITHIUM_PRIVATE_KEY) WOLFSSL_API int wc_Dilithium_PrivateKeyDecode(const byte* input,