From 477754024dbac6f472148e2f3b7948300b552c2b Mon Sep 17 00:00:00 2001 From: aidan garske Date: Thu, 4 Jun 2026 15:24:10 -0700 Subject: [PATCH] ML-KEM: FIPS 203 modulus check - reject non-reduced private key vector on decode --- wolfcrypt/src/wc_mlkem.c | 21 ++++++++++++++++----- wolfcrypt/src/wc_mlkem_poly.c | 9 ++++++--- wolfcrypt/test/test.c | 16 +++++++++++++++- wolfssl/wolfcrypt/wc_mlkem.h | 2 +- 4 files changed, 38 insertions(+), 10 deletions(-) diff --git a/wolfcrypt/src/wc_mlkem.c b/wolfcrypt/src/wc_mlkem.c index 89f647ebef..e217e5ebed 100644 --- a/wolfcrypt/src/wc_mlkem.c +++ b/wolfcrypt/src/wc_mlkem.c @@ -2043,7 +2043,9 @@ static void mlkemkey_decode_public(sword16* pub, byte* pubSeed, const byte* p, * @return BAD_FUNC_ARG when key or in is NULL. * @return NOT_COMPILED_IN when key type is not supported. * @return BUFFER_E when len is not the correct size. - * @return PUBLIC_KEY_E when public key data doesn't match parameters. + * @return PUBLIC_KEY_E when the private or public vector has a coefficient + * that is not reduced modulo q, or public key data doesn't match + * parameters. * @return MLKEM_PUB_HASH_E when public key hash doesn't match stored hash. * @return MEMORY_E when dynamic memory allocation failed. */ @@ -2130,15 +2132,24 @@ int wc_MlKemKey_DecodePrivateKey(MlKemKey* key, const unsigned char* in, } #endif if (ret == 0) { + /* Clear the key-set flags first so any failure below (size, reduction + * check, or hash) leaves a reused key object consistently unusable + * rather than flagged-set with zeroed material. */ + key->flags &= ~(MLKEM_FLAG_BOTH_SET | MLKEM_FLAG_H_SET); + /* Decode private key that is vector of polynomials. * Alg 18 Step 1: dk_PKE <- dk[0 : 384k] * Alg 15 Step 5: s_hat <- ByteDecode_12(dk_PKE) */ mlkem_from_bytes(key->priv, p, (int)k); p += k * WC_ML_KEM_POLY_SIZE; - /* Decode the public key that is after the private key. */ - mlkemkey_decode_public(key->pub, key->pubSeed, p, k); - ret = mlkem_check_public(key->pub, (int)k); + /* Both vectors must decode to coefficients reduced modulo q. */ + ret = mlkem_check_reduced(key->priv, (int)k); + if (ret == 0) { + /* Decode the public key that is after the private key. */ + mlkemkey_decode_public(key->pub, key->pubSeed, p, k); + ret = mlkem_check_reduced(key->pub, (int)k); + } if (ret != 0) { ForceZero(key->priv, k * MLKEM_N * sizeof(sword16)); } @@ -2263,7 +2274,7 @@ int wc_MlKemKey_DecodePublicKey(MlKemKey* key, const unsigned char* in, if (ret == 0) { /* Decode public key and check public key matches parameters. */ mlkemkey_decode_public(key->pub, key->pubSeed, p, k); - ret = mlkem_check_public(key->pub, (int)k); + ret = mlkem_check_reduced(key->pub, (int)k); } if (ret == 0) { /* Calculate public hash. */ diff --git a/wolfcrypt/src/wc_mlkem_poly.c b/wolfcrypt/src/wc_mlkem_poly.c index 1a0f3e2213..aa3d7835d5 100644 --- a/wolfcrypt/src/wc_mlkem_poly.c +++ b/wolfcrypt/src/wc_mlkem_poly.c @@ -6096,14 +6096,17 @@ void mlkem_to_bytes(byte* b, sword16* p, int k) } /** - * Check the public key values are smaller than the modulus. + * Check the vector coefficients are reduced modulo q. * - * @param [in] p Public key - vector. + * FIPS 203, Sections 7.2 and 7.3: encapsulation and decapsulation keys must + * decode to coefficients in Z_q; reject any that are not reduced. + * + * @param [in] p Key - vector of polynomials. * @param [in] k Number of polynomials in vector. * @return 0 when all values are in range. * @return PUBLIC_KEY_E when at least one value is out of range. */ -int mlkem_check_public(const sword16* p, int k) +int mlkem_check_reduced(const sword16* p, int k) { int ret = 0; int i; diff --git a/wolfcrypt/test/test.c b/wolfcrypt/test/test.c index d76dd112c9..b004d3a7b7 100644 --- a/wolfcrypt/test/test.c +++ b/wolfcrypt/test/test.c @@ -52299,9 +52299,23 @@ WOLFSSL_TEST_SUBROUTINE wc_test_ret_t mlkem_test(void) if (ret != 0) ERROR_OUT(WC_TEST_RET_ENC_I(i), out); - if (XMEMCMP(priv, priv2, testData[i][2]) != 0) + if (XMEMCMP(priv, priv2, testData[i][1]) != 0) ERROR_OUT(WC_TEST_RET_ENC_I(i), out); + /* FIPS 203 modulus check: a private key whose first coefficient is + * not reduced (>= q) must be rejected on decode. Free first so the + * reinit does not leak the decoded dynamic priv/pub buffers. */ + wc_MlKemKey_Free(key); + ret = wc_MlKemKey_Init(key, testData[i][0], HEAP_HINT, devId); + if (ret != 0) + ERROR_OUT(WC_TEST_RET_ENC_I(i), out); + priv[0] = 0xff; + priv[1] |= 0x0f; + ret = wc_MlKemKey_DecodePrivateKey(key, priv, testData[i][1]); + if (ret != PUBLIC_KEY_E) + ERROR_OUT(WC_TEST_RET_ENC_I(i), out); + ret = 0; + #if !defined(WOLFSSL_NO_MALLOC) && !defined(WC_NO_CONSTRUCTORS) tmpKey = wc_MlKemKey_New(testData[i][0], HEAP_HINT, devId); if (tmpKey == NULL) diff --git a/wolfssl/wolfcrypt/wc_mlkem.h b/wolfssl/wolfcrypt/wc_mlkem.h index 4d02a6252a..3dad3a5e94 100644 --- a/wolfssl/wolfcrypt/wc_mlkem.h +++ b/wolfssl/wolfcrypt/wc_mlkem.h @@ -576,7 +576,7 @@ void mlkem_from_bytes(sword16* p, const byte* b, int k); WOLFSSL_LOCAL void mlkem_to_bytes(byte* b, sword16* p, int k); WOLFSSL_LOCAL -int mlkem_check_public(const sword16* p, int k); +int mlkem_check_reduced(const sword16* p, int k); #ifdef USE_INTEL_SPEEDUP WOLFSSL_LOCAL