Improve ML-DSA private key import

This commit is contained in:
Koji Takeda
2025-04-10 14:17:56 +09:00
parent 16a6818614
commit a3862f0e59
5 changed files with 219 additions and 91 deletions

View File

@@ -946,6 +946,9 @@ static int ProcessBufferTryDecodeDilithium(WOLFSSL_CTX* ctx, WOLFSSL* ssl,
int ret;
word32 idx;
dilithium_key* key;
int keyFormatTemp = 0;
int keyTypeTemp;
int keySizeTemp;
/* Allocate a Dilithium key to parse into. */
key = (dilithium_key*)XMALLOC(sizeof(dilithium_key), heap,
@@ -955,106 +958,75 @@ static int ProcessBufferTryDecodeDilithium(WOLFSSL_CTX* ctx, WOLFSSL* ssl,
}
/* Initialize Dilithium key. */
ret = wc_dilithium_init(key);
if (ret == 0) {
/* Set up key to parse the format specified. */
if ((*keyFormat == ML_DSA_LEVEL2k) || ((*keyFormat == 0) &&
((der->length == ML_DSA_LEVEL2_KEY_SIZE) ||
(der->length == ML_DSA_LEVEL2_PRV_KEY_SIZE)))) {
ret = wc_dilithium_set_level(key, WC_ML_DSA_44);
}
else if ((*keyFormat == ML_DSA_LEVEL3k) || ((*keyFormat == 0) &&
((der->length == ML_DSA_LEVEL3_KEY_SIZE) ||
(der->length == ML_DSA_LEVEL3_PRV_KEY_SIZE)))) {
ret = wc_dilithium_set_level(key, WC_ML_DSA_65);
}
else if ((*keyFormat == ML_DSA_LEVEL5k) || ((*keyFormat == 0) &&
((der->length == ML_DSA_LEVEL5_KEY_SIZE) ||
(der->length == ML_DSA_LEVEL5_PRV_KEY_SIZE)))) {
ret = wc_dilithium_set_level(key, WC_ML_DSA_87);
}
#ifdef WOLFSSL_DILITHIUM_FIPS204_DRAFT
else if ((*keyFormat == DILITHIUM_LEVEL2k) || ((*keyFormat == 0) &&
((der->length == DILITHIUM_LEVEL2_KEY_SIZE) ||
(der->length == DILITHIUM_LEVEL2_PRV_KEY_SIZE)))) {
ret = wc_dilithium_set_level(key, WC_ML_DSA_44_DRAFT);
}
else if ((*keyFormat == DILITHIUM_LEVEL3k) || ((*keyFormat == 0) &&
((der->length == DILITHIUM_LEVEL3_KEY_SIZE) ||
(der->length == DILITHIUM_LEVEL3_PRV_KEY_SIZE)))) {
ret = wc_dilithium_set_level(key, WC_ML_DSA_65_DRAFT);
}
else if ((*keyFormat == DILITHIUM_LEVEL5k) || ((*keyFormat == 0) &&
((der->length == DILITHIUM_LEVEL5_KEY_SIZE) ||
(der->length == DILITHIUM_LEVEL5_PRV_KEY_SIZE)))) {
ret = wc_dilithium_set_level(key, WC_ML_DSA_87_DRAFT);
}
#endif /* WOLFSSL_DILITHIUM_FIPS204_DRAFT */
else {
wc_dilithium_free(key);
ret = ALGO_ID_E;
}
}
ret = wc_dilithium_init(key);
if (ret == 0) {
/* Decode as a Dilithium private key. */
idx = 0;
ret = wc_Dilithium_PrivateKeyDecode(der->buffer, &idx, key, der->length);
if (ret == 0) {
/* Get the minimum Dilithium key size from SSL or SSL context
* object. */
int minKeySz = ssl ? ssl->options.minDilithiumKeySz :
ctx->minDilithiumKeySz;
ret = dilithium_get_oid_sum(key, &keyFormatTemp);
if(ret == 0) {
/* Format is known. */
#if defined(WOLFSSL_DILITHIUM_FIPS204_DRAFT)
if (keyFormatTemp == DILITHIUM_LEVEL2k) {
keyTypeTemp = dilithium_level2_sa_algo;
keySizeTemp = DILITHIUM_LEVEL2_KEY_SIZE;
}
else if (keyFormatTemp == DILITHIUM_LEVEL3k) {
keyTypeTemp = dilithium_level3_sa_algo;
keySizeTemp = DILITHIUM_LEVEL3_KEY_SIZE;
}
else if (keyFormatTemp == DILITHIUM_LEVEL5k) {
keyTypeTemp = dilithium_level5_sa_algo;
keySizeTemp = DILITHIUM_LEVEL5_KEY_SIZE;
}
else
#endif /* WOLFSSL_DILITHIUM_FIPS204_DRAFT */
if (keyFormatTemp == ML_DSA_LEVEL2k) {
keyTypeTemp = dilithium_level2_sa_algo;
keySizeTemp = ML_DSA_LEVEL2_KEY_SIZE;
}
else if (keyFormatTemp == ML_DSA_LEVEL3k) {
keyTypeTemp = dilithium_level3_sa_algo;
keySizeTemp = ML_DSA_LEVEL3_KEY_SIZE;
}
else if (keyFormatTemp == ML_DSA_LEVEL5k) {
keyTypeTemp = dilithium_level5_sa_algo;
keySizeTemp = ML_DSA_LEVEL5_KEY_SIZE;
}
else {
ret = ALGO_ID_E;
}
}
/* Format is known. */
if (*keyFormat == ML_DSA_LEVEL2k) {
*keyType = dilithium_level2_sa_algo;
*keySize = ML_DSA_LEVEL2_KEY_SIZE;
}
else if (*keyFormat == ML_DSA_LEVEL3k) {
*keyType = dilithium_level3_sa_algo;
*keySize = ML_DSA_LEVEL3_KEY_SIZE;
}
else if (*keyFormat == ML_DSA_LEVEL5k) {
*keyType = dilithium_level5_sa_algo;
*keySize = ML_DSA_LEVEL5_KEY_SIZE;
}
#ifdef WOLFSSL_DILITHIUM_FIPS204_DRAFT
else if (*keyFormat == DILITHIUM_LEVEL2k) {
*keyType = dilithium_level2_sa_algo;
*keySize = DILITHIUM_LEVEL2_KEY_SIZE;
}
else if (*keyFormat == DILITHIUM_LEVEL3k) {
*keyType = dilithium_level3_sa_algo;
*keySize = DILITHIUM_LEVEL3_KEY_SIZE;
}
else if (*keyFormat == DILITHIUM_LEVEL5k) {
*keyType = dilithium_level5_sa_algo;
*keySize = DILITHIUM_LEVEL5_KEY_SIZE;
}
#endif /* WOLFSSL_DILITHIUM_FIPS204_DRAFT */
if(ret == 0) {
/* Get the minimum Dilithium key size from SSL or SSL context
* object. */
int minKeySz = ssl ? ssl->options.minDilithiumKeySz :
ctx->minDilithiumKeySz;
/* Check that the size of the Dilithium key is enough. */
if (*keySize < minKeySz) {
WOLFSSL_MSG("Dilithium private key too small");
ret = DILITHIUM_KEY_SIZE_E;
/* Check that the size of the Dilithium key is enough. */
if (keySizeTemp < minKeySz) {
WOLFSSL_MSG("Dilithium private key too small");
ret = DILITHIUM_KEY_SIZE_E;
}
}
if(ret == 0) {
*keyFormat = keyFormatTemp;
*keyType = keyTypeTemp;
*keySize = keySizeTemp;
}
}
/* Not a Dilithium key but check whether we know what it is. */
else if (*keyFormat == 0) {
WOLFSSL_MSG("Not a Dilithium key");
/* Format unknown so keep trying. */
/* Unknowun format was not dilithium, so keep trying other formats. */
ret = 0;
}
/* Free dynamically allocated data in key. */
wc_dilithium_free(key);
}
else if ((ret == WC_NO_ERR_TRACE(ALGO_ID_E)) && (*keyFormat == 0)) {
WOLFSSL_MSG("Not a Dilithium key");
/* Format unknown so keep trying. */
ret = 0;
}
/* Dispose of allocated key. */
XFREE(key, heap, DYNAMIC_TYPE_DILITHIUM);

View File

@@ -13933,6 +13933,119 @@ static int test_wolfSSL_PKCS8_ED448(void)
return EXPECT_RESULT();
}
static int test_wolfSSL_PKCS8_MLDSA(void)
{
EXPECT_DECLS;
#if !defined(NO_ASN) && defined(HAVE_PKCS8) && \
defined(HAVE_DILITHIUM) && !defined(NO_TLS) && \
(!defined(NO_WOLFSSL_CLIENT) || !defined(NO_WOLFSSL_SERVER))
WOLFSSL_CTX* ctx = NULL;
size_t i;
const int derMaxSz = 8192; /* Largest size will be 7520 of separated format, WC_ML_DSA_87, DER */
const int tempMaxSz = 10240; /* Largest size will be 10239 of separated format, WC_MLS_DSA_87, PEM */
byte* der = NULL;
byte* temp = NULL; /* Store PEM or intermediate key */
word32 derSz = 0;
word32 pemSz = 0;
word32 keySz = 0;
dilithium_key mldsa_key;
WC_RNG rng;
word32 size;
struct {
int wcId;
int oidSum;
int keySz;
} test_variant[] = {{WC_ML_DSA_44, ML_DSA_LEVEL2k, ML_DSA_LEVEL2_PRV_KEY_SIZE},
{WC_ML_DSA_65, ML_DSA_LEVEL3k, ML_DSA_LEVEL3_PRV_KEY_SIZE},
{WC_ML_DSA_87, ML_DSA_LEVEL5k, ML_DSA_LEVEL5_PRV_KEY_SIZE}};
(void) pemSz;
ExpectNotNull(der = (byte*) XMALLOC(derMaxSz, NULL, DYNAMIC_TYPE_TMP_BUFFER));
ExpectNotNull(temp = (byte*) XMALLOC(tempMaxSz, NULL, DYNAMIC_TYPE_TMP_BUFFER));
#ifndef NO_WOLFSSL_SERVER
ExpectNotNull(ctx = wolfSSL_CTX_new(wolfSSLv23_server_method()));
#else
ExpectNotNull(ctx = wolfSSL_CTX_new(wolfSSLv23_client_method()));
#endif /* NO_WOLFSSL_SERVER */
ExpectIntEQ(wc_InitRng(&rng), 0);
ExpectIntEQ(wc_dilithium_init(&mldsa_key), 0);
/* Test private + public key (separated format) */
for(i = 0; i < sizeof(test_variant) / sizeof(test_variant[0]); ++i) {
ExpectIntEQ(wc_dilithium_set_level(&mldsa_key, test_variant[i].wcId), 0);
ExpectIntEQ(wc_dilithium_make_key(&mldsa_key, &rng), 0);
ExpectIntGT(derSz = wc_Dilithium_KeyToDer(&mldsa_key, der, derMaxSz), 0);
ExpectIntEQ(wolfSSL_CTX_use_PrivateKey_buffer(ctx, der, derSz,
WOLFSSL_FILETYPE_ASN1), WOLFSSL_SUCCESS);
#ifdef WOLFSSL_DER_TO_PEM
ExpectIntGT(pemSz = wc_DerToPem(der, derSz, temp, tempMaxSz, PKCS8_PRIVATEKEY_TYPE), 0);
ExpectIntEQ(wolfSSL_CTX_use_PrivateKey_buffer(ctx, temp, pemSz,
WOLFSSL_FILETYPE_PEM), WOLFSSL_SUCCESS);
#endif /* WOLFSSL_DER_TO_PEM */
}
/* Test private key only */
for(i = 0; i < sizeof(test_variant) / sizeof(test_variant[0]); ++i) {
ExpectIntEQ(wc_dilithium_set_level(&mldsa_key, test_variant[i].wcId), 0);
ExpectIntEQ(wc_dilithium_make_key(&mldsa_key, &rng), 0);
ExpectIntGT(derSz = wc_Dilithium_PrivateKeyToDer(&mldsa_key, der, derMaxSz), 0);
ExpectIntEQ(wolfSSL_CTX_use_PrivateKey_buffer(ctx, der, derSz,
WOLFSSL_FILETYPE_ASN1), WOLFSSL_SUCCESS);
#ifdef WOLFSSL_DER_TO_PEM
ExpectIntGT(pemSz = wc_DerToPem(der, derSz, temp, tempMaxSz, PKCS8_PRIVATEKEY_TYPE), 0);
ExpectIntEQ(wolfSSL_CTX_use_PrivateKey_buffer(ctx, temp, pemSz,
WOLFSSL_FILETYPE_PEM), WOLFSSL_SUCCESS);
#endif /* WOLFSSL_DER_TO_PEM */
}
/* Test private + public key (integrated format) */
for(i = 0; i < sizeof(test_variant) / sizeof(test_variant[0]); ++i) {
ExpectIntEQ(wc_dilithium_set_level(&mldsa_key, test_variant[i].wcId), 0);
ExpectIntEQ(wc_dilithium_make_key(&mldsa_key, &rng), 0);
keySz = 0;
temp[0] = 0x04; /* ASN.1 OCTET STRING */
temp[1] = 0x82; /* 2 bytes length field */
temp[2] = (test_variant[i].keySz >> 8) & 0xff; /* MSB of the length */
temp[3] = test_variant[i].keySz & 0xff; /* LSB of the length */
keySz += 4;
size = tempMaxSz - keySz;
ExpectIntEQ(wc_dilithium_export_private(&mldsa_key, temp + keySz, &size), 0);
keySz += size;
size = tempMaxSz - keySz;
ExpectIntEQ(wc_dilithium_export_public(&mldsa_key, temp + keySz, &size), 0);
keySz += size;
derSz = derMaxSz;
ExpectIntGT(wc_CreatePKCS8Key(der, &derSz, temp, keySz, test_variant[i].oidSum, NULL, 0), 0);
ExpectIntEQ(wolfSSL_CTX_use_PrivateKey_buffer(ctx, der, derSz,
WOLFSSL_FILETYPE_ASN1), WOLFSSL_SUCCESS);
#ifdef WOLFSSL_DER_TO_PEM
ExpectIntGT(pemSz = wc_DerToPem(der, derSz, temp, tempMaxSz, PKCS8_PRIVATEKEY_TYPE), 0);
ExpectIntEQ(wolfSSL_CTX_use_PrivateKey_buffer(ctx, temp, pemSz,
WOLFSSL_FILETYPE_PEM), WOLFSSL_SUCCESS);
#endif /* WOLFSSL_DER_TO_PEM */
}
wc_dilithium_free(&mldsa_key);
ExpectIntEQ(wc_FreeRng(&rng), 0);
wolfSSL_CTX_free(ctx);
XFREE(temp, NULL, DYNAMIC_TYPE_TMP_BUFFER);
XFREE(der, NULL, DYNAMIC_TYPE_TMP_BUFFER);
#endif
return EXPECT_RESULT();
}
/* Testing functions dealing with PKCS5 */
static int test_wolfSSL_PKCS5(void)
{
@@ -67519,6 +67632,7 @@ TEST_CASE testCases[] = {
TEST_DECL(test_wolfSSL_PKCS8),
TEST_DECL(test_wolfSSL_PKCS8_ED25519),
TEST_DECL(test_wolfSSL_PKCS8_ED448),
TEST_DECL(test_wolfSSL_PKCS8_MLDSA),
#ifdef HAVE_IO_TESTS_DEPENDENCIES
TEST_DECL(test_wolfSSL_get_finished),

View File

@@ -2959,7 +2959,7 @@ int test_wc_dilithium_der(void)
idx = 0;
#ifdef WOLFSSL_DILITHIUM_FIPS204_DRAFT
ExpectIntEQ(wc_Dilithium_PrivateKeyDecode(der, &idx, key, privDerLen),
WC_NO_ERR_TRACE(BAD_FUNC_ARG));
WC_NO_ERR_TRACE(ASN_PARSE_E));
#else
ExpectIntEQ(wc_Dilithium_PrivateKeyDecode(der, &idx, key, privDerLen),
WC_NO_ERR_TRACE(ASN_PARSE_E));

View File

@@ -9589,6 +9589,42 @@ static int mapOidToSecLevel(word32 oid)
}
}
/* Get OID sum from dilithium key */
int dilithium_get_oid_sum(dilithium_key* key, int* keyFormat) {
int ret = 0;
#if defined(WOLFSSL_DILITHIUM_FIPS204_DRAFT)
if (key->params == NULL) {
ret = BAD_FUNC_ARG;
}
else if (key->params->level == WC_ML_DSA_44_DRAFT) {
*keyFormat = DILITHIUM_LEVEL2k;
}
else if (key->params->level == WC_ML_DSA_65_DRAFT) {
*keyFormat = DILITHIUM_LEVEL3k;
}
else if (key->params->level == WC_ML_DSA_87_DRAFT) {
*keyFormat = DILITHIUM_LEVEL5k;
}
else
#endif /* WOLFSSL_DILITHIUM_FIPS204_DRAFT */
if (key->level == WC_ML_DSA_44) {
*keyFormat = ML_DSA_LEVEL2k;
}
else if (key->level == WC_ML_DSA_65) {
*keyFormat = ML_DSA_LEVEL3k;
}
else if (key->level == WC_ML_DSA_87) {
*keyFormat = ML_DSA_LEVEL5k;
}
else {
/* Level is not set */
ret = ALGO_ID_E;
}
return ret;
}
#if defined(WOLFSSL_DILITHIUM_PRIVATE_KEY)
/* Decode the DER encoded Dilithium key.
@@ -9627,9 +9663,13 @@ int wc_Dilithium_PrivateKeyDecode(const byte* input, word32* inOutIdx,
}
if (ret == 0) {
/* Get OID sum for level. */
/* Get OID sum for level. */
if(key->level == 0) { /* Check first, because key->params will be NULL when key->level = 0 */
/* Level not set by caller, decode from DER */
keytype = ANONk;
}
#if defined(WOLFSSL_DILITHIUM_FIPS204_DRAFT)
if (key->params == NULL) {
else if (key->params == NULL) {
ret = BAD_FUNC_ARG;
}
else if (key->params->level == WC_ML_DSA_44_DRAFT) {
@@ -9641,9 +9681,8 @@ int wc_Dilithium_PrivateKeyDecode(const byte* input, word32* inOutIdx,
else if (key->params->level == WC_ML_DSA_87_DRAFT) {
keytype = DILITHIUM_LEVEL5k;
}
else
#endif
if (key->level == WC_ML_DSA_44) {
else if (key->level == WC_ML_DSA_44) {
keytype = ML_DSA_LEVEL2k;
}
else if (key->level == WC_ML_DSA_65) {
@@ -9653,8 +9692,7 @@ int wc_Dilithium_PrivateKeyDecode(const byte* input, word32* inOutIdx,
keytype = ML_DSA_LEVEL5k;
}
else {
/* Level not set by caller, decode from DER */
keytype = ANONk; /* 0, not a valid key type in this situation*/
ret = BAD_FUNC_ARG;
}
}

View File

@@ -813,6 +813,10 @@ int wc_dilithium_export_key(dilithium_key* key, byte* priv, word32 *privSz,
byte* pub, word32 *pubSz);
#endif
#ifndef WOLFSSL_DILITHIUM_NO_ASN1
WOLFSSL_LOCAL int dilithium_get_oid_sum(dilithium_key* key, int* keyFormat);
#endif /* WOLFSSL_DILITHIUM_NO_ASN1 */
#ifndef WOLFSSL_DILITHIUM_NO_ASN1
#if defined(WOLFSSL_DILITHIUM_PRIVATE_KEY)
WOLFSSL_API int wc_Dilithium_PrivateKeyDecode(const byte* input,