ML-KEM: FIPS 203 modulus check - reject non-reduced private key vector on decode

This commit is contained in:
aidan garske
2026-06-04 15:24:10 -07:00
parent 63fd322382
commit 477754024d
4 changed files with 38 additions and 10 deletions
+16 -5
View File
@@ -2043,7 +2043,9 @@ static void mlkemkey_decode_public(sword16* pub, byte* pubSeed, const byte* p,
* @return BAD_FUNC_ARG when key or in is NULL.
* @return NOT_COMPILED_IN when key type is not supported.
* @return BUFFER_E when len is not the correct size.
* @return PUBLIC_KEY_E when public key data doesn't match parameters.
* @return PUBLIC_KEY_E when the private or public vector has a coefficient
* that is not reduced modulo q, or public key data doesn't match
* parameters.
* @return MLKEM_PUB_HASH_E when public key hash doesn't match stored hash.
* @return MEMORY_E when dynamic memory allocation failed.
*/
@@ -2130,15 +2132,24 @@ int wc_MlKemKey_DecodePrivateKey(MlKemKey* key, const unsigned char* in,
}
#endif
if (ret == 0) {
/* Clear the key-set flags first so any failure below (size, reduction
* check, or hash) leaves a reused key object consistently unusable
* rather than flagged-set with zeroed material. */
key->flags &= ~(MLKEM_FLAG_BOTH_SET | MLKEM_FLAG_H_SET);
/* Decode private key that is vector of polynomials.
* Alg 18 Step 1: dk_PKE <- dk[0 : 384k]
* Alg 15 Step 5: s_hat <- ByteDecode_12(dk_PKE) */
mlkem_from_bytes(key->priv, p, (int)k);
p += k * WC_ML_KEM_POLY_SIZE;
/* Decode the public key that is after the private key. */
mlkemkey_decode_public(key->pub, key->pubSeed, p, k);
ret = mlkem_check_public(key->pub, (int)k);
/* Both vectors must decode to coefficients reduced modulo q. */
ret = mlkem_check_reduced(key->priv, (int)k);
if (ret == 0) {
/* Decode the public key that is after the private key. */
mlkemkey_decode_public(key->pub, key->pubSeed, p, k);
ret = mlkem_check_reduced(key->pub, (int)k);
}
if (ret != 0) {
ForceZero(key->priv, k * MLKEM_N * sizeof(sword16));
}
@@ -2263,7 +2274,7 @@ int wc_MlKemKey_DecodePublicKey(MlKemKey* key, const unsigned char* in,
if (ret == 0) {
/* Decode public key and check public key matches parameters. */
mlkemkey_decode_public(key->pub, key->pubSeed, p, k);
ret = mlkem_check_public(key->pub, (int)k);
ret = mlkem_check_reduced(key->pub, (int)k);
}
if (ret == 0) {
/* Calculate public hash. */
+6 -3
View File
@@ -6096,14 +6096,17 @@ void mlkem_to_bytes(byte* b, sword16* p, int k)
}
/**
* Check the public key values are smaller than the modulus.
* Check the vector coefficients are reduced modulo q.
*
* @param [in] p Public key - vector.
* FIPS 203, Sections 7.2 and 7.3: encapsulation and decapsulation keys must
* decode to coefficients in Z_q; reject any that are not reduced.
*
* @param [in] p Key - vector of polynomials.
* @param [in] k Number of polynomials in vector.
* @return 0 when all values are in range.
* @return PUBLIC_KEY_E when at least one value is out of range.
*/
int mlkem_check_public(const sword16* p, int k)
int mlkem_check_reduced(const sword16* p, int k)
{
int ret = 0;
int i;
+15 -1
View File
@@ -52299,9 +52299,23 @@ WOLFSSL_TEST_SUBROUTINE wc_test_ret_t mlkem_test(void)
if (ret != 0)
ERROR_OUT(WC_TEST_RET_ENC_I(i), out);
if (XMEMCMP(priv, priv2, testData[i][2]) != 0)
if (XMEMCMP(priv, priv2, testData[i][1]) != 0)
ERROR_OUT(WC_TEST_RET_ENC_I(i), out);
/* FIPS 203 modulus check: a private key whose first coefficient is
* not reduced (>= q) must be rejected on decode. Free first so the
* reinit does not leak the decoded dynamic priv/pub buffers. */
wc_MlKemKey_Free(key);
ret = wc_MlKemKey_Init(key, testData[i][0], HEAP_HINT, devId);
if (ret != 0)
ERROR_OUT(WC_TEST_RET_ENC_I(i), out);
priv[0] = 0xff;
priv[1] |= 0x0f;
ret = wc_MlKemKey_DecodePrivateKey(key, priv, testData[i][1]);
if (ret != PUBLIC_KEY_E)
ERROR_OUT(WC_TEST_RET_ENC_I(i), out);
ret = 0;
#if !defined(WOLFSSL_NO_MALLOC) && !defined(WC_NO_CONSTRUCTORS)
tmpKey = wc_MlKemKey_New(testData[i][0], HEAP_HINT, devId);
if (tmpKey == NULL)
+1 -1
View File
@@ -576,7 +576,7 @@ void mlkem_from_bytes(sword16* p, const byte* b, int k);
WOLFSSL_LOCAL
void mlkem_to_bytes(byte* b, sword16* p, int k);
WOLFSSL_LOCAL
int mlkem_check_public(const sword16* p, int k);
int mlkem_check_reduced(const sword16* p, int k);
#ifdef USE_INTEL_SPEEDUP
WOLFSSL_LOCAL