ML-KEM: Fixes for comments plus bug fixes

wc_MlKemKey_SharedSecretSize: Check len is not NULL before use.
wc_MlKemKey_DecodePrivateKey:
  Don't set flags when public key hash fails.
  ForceZero the private key on failure if copied.
This commit is contained in:
Sean Parkinson
2026-03-10 21:09:08 +10:00
parent b3f08f33b8
commit b180a279b0
2 changed files with 283 additions and 273 deletions
+62 -48
View File
@@ -60,7 +60,7 @@
* Stores the matrix A during key generation for use in encapsulation when
* performing decapsulation.
* KyberKey is 8KB larger but decapsulation is significantly faster.
* Turn on when performing make key and decapsualtion with same object.
* Turn on when performing make key and decapsulation with same object.
*/
#include <wolfssl/wolfcrypt/libwolfssl_sources.h>
@@ -219,10 +219,10 @@ int wc_MlKemKey_Delete(MlKemKey* key, MlKemKey** key_p)
/**
* Initialize the Kyber key.
*
* @param [out] key Kyber key object to initialize.
* @param [in] type Type of key:
* WC_ML_KEM_512, WC_ML_KEM_768, WC_ML_KEM_1024,
* KYBER512, KYBER768, KYBER1024.
* @param [out] key Kyber key object to initialize.
* @param [in] heap Dynamic memory hint.
* @param [in] devId Device Id.
* @return 0 on success.
@@ -292,7 +292,7 @@ int wc_MlKemKey_Init(MlKemKey* key, int type, void* heap, int devId)
/* Cache heap pointer. */
key->heap = heap;
#ifdef WOLF_CRYPTO_CB
/* Cache device id - not used in for this algorithm yet. */
/* Cache device id - not used in this algorithm yet. */
key->devId = devId;
#endif
key->flags = 0;
@@ -353,17 +353,16 @@ int wc_MlKemKey_Free(MlKemKey* key)
* 4: return falsum
* > return an error indication if random bit generation failed
* 5: end if
* 6: (ek,dk) <- ML-KEM.KeyGen_Interal(d, z)
* 6: (ek,dk) <- ML-KEM.KeyGen_Internal(d, z)
* > run internal key generation algorithm
* &: return (ek,dk)
* 7: return (ek,dk)
*
* @param [in, out] key Kyber key object.
* @param [in] rng Random number generator.
* @return 0 on success.
* @return BAD_FUNC_ARG when key or rng is NULL.
* @return MEMORY_E when dynamic memory allocation failed.
* @return MEMORY_E when dynamic memory allocation failed.
* @return RNG_FAILURE_E when generating random numbers failed.
* @return RNG_FAILURE_E when generating random numbers failed.
* @return DRBG_CONT_FAILURE when random number generator health check fails.
*/
int wc_MlKemKey_MakeKey(MlKemKey* key, WC_RNG* rng)
@@ -405,13 +404,13 @@ int wc_MlKemKey_MakeKey(MlKemKey* key, WC_RNG* rng)
* FIPS 203 - Algorithm 16: ML-KEM.KeyGen_internal(d,z)
* Uses randomness to generate an encapsulation key and a corresponding
* decapsulation key.
* 1: (ek_PKE,dk_PKE) < K-PKE.KeyGen(d) > run key generation for K-PKE
* 1: (ek_PKE,dk_PKE) <- K-PKE.KeyGen(d) > run key generation for K-PKE
* ...
*
* FIPS 203 - Algorithm 13: K-PKE.KeyGen(d)
* Uses randomness to generate an encryption key and a corresponding decryption
* key.
* 1: (rho,sigma) <- G(d||k)A
* 1: (rho,sigma) <- G(d||k)
* > expand 32+1 bytes to two pseudorandom 32-byte seeds
* 2: N <- 0
* 3-7: generate matrix A_hat
@@ -420,7 +419,7 @@ int wc_MlKemKey_MakeKey(MlKemKey* key, WC_RNG* rng)
* 16-18: calculate t_hat from A_hat, s and e
* ...
*
* @param [in, out] key Kyber key ovject.
* @param [in, out] key Kyber key object.
* @param [in] rand Random data.
* @param [in] len Length of random data in bytes.
* @return 0 on success.
@@ -552,7 +551,7 @@ int wc_MlKemKey_MakeKeyWithRandom(MlKemKey* key, const unsigned char* rand,
#endif
#ifdef WOLFSSL_MLKEM_KYBER
{
/* Expand 32 bytes of random to 32. */
/* Expand 32 bytes of random to 64. */
ret = MLKEM_HASH_G(&key->hash, d, WC_ML_KEM_SYM_SZ, NULL, 0, buf);
}
#endif
@@ -562,7 +561,7 @@ int wc_MlKemKey_MakeKeyWithRandom(MlKemKey* key, const unsigned char* rand,
#ifndef WOLFSSL_NO_ML_KEM
{
buf[0] = k;
/* Expand 33 bytes of random to 32.
/* Expand 33 bytes of random to 64.
* Alg 13: Step 1: (rho,sigma) <- G(d||k)
*/
ret = MLKEM_HASH_G(&key->hash, d, WC_ML_KEM_SYM_SZ, buf, 1, buf);
@@ -572,9 +571,11 @@ int wc_MlKemKey_MakeKeyWithRandom(MlKemKey* key, const unsigned char* rand,
#ifdef WC_MLKEM_FAULT_HARDEN
if (ret == 0) {
XMEMCPY(sigma, buf + WC_ML_KEM_SYM_SZ, WC_ML_KEM_SYM_SZ);
/* Check that correct data was copied and pointer not changed. */
if (XMEMCMP(sigma, rho, WC_ML_KEM_SYM_SZ) == 0) {
ret = BAD_COND_E;
}
/* Check that rho is sigma - rho may have been modified. */
if (XMEMCMP(sigma, rho + WC_ML_KEM_SYM_SZ, WC_ML_KEM_SYM_SZ) != 0) {
ret = BAD_COND_E;
}
@@ -619,8 +620,8 @@ int wc_MlKemKey_MakeKeyWithRandom(MlKemKey* key, const unsigned char* rand,
if (ret == 0) {
/* Generate key pair from private vector and seeds.
* Alg 13: Steps 3-7: generate matrix A_hat
* Alg 13: 12-15: generate e
* Alg 13: 16-18: calculate t_hat from A_hat, s and e
* Alg 13: Steps 12-15: generate e
* Alg 13: Steps 16-18: calculate t_hat from A_hat, s and e
*/
ret = mlkem_keygen_seeds(s, t, &key->prf, e, k, rho, sigma);
}
@@ -715,17 +716,23 @@ int wc_MlKemKey_CipherTextSize(MlKemKey* key, word32* len)
* Size of a shared secret in bytes. Always KYBER_SS_SZ.
*
* @param [in] key Kyber key object. Not used.
* @param [out] Size of the shared secret created with a Kyber key.
* @param [out] len Size of the shared secret created with a Kyber key.
* @return 0 on success.
* @return 0 to indicate success.
* @return BAD_FUNC_ARG when len is NULL.
*/
int wc_MlKemKey_SharedSecretSize(MlKemKey* key, word32* len)
{
int ret = 0;
if (len == NULL) {
ret = BAD_FUNC_ARG;
}
else {
*len = WC_ML_KEM_SS_SZ;
}
(void)key;
*len = WC_ML_KEM_SS_SZ;
return 0;
return ret;
}
#if !defined(WOLFSSL_MLKEM_NO_ENCAPSULATE) || \
@@ -738,7 +745,7 @@ int wc_MlKemKey_SharedSecretSize(MlKemKey* key, word32* len)
* 1: N <- 0
* 2: t_hat <- ByteDecode_12(ek_PKE[0:384k])
* > run ByteDecode_12 k times to decode t_hat
* 3: rho <- ek_PKE[384k : 384K + 32]
* 3: rho <- ek_PKE[384k : 384k + 32]
* > extract 32-byte seed from ek_PKE
* 4-8: generate matrix A_hat
* 9-12: generate y
@@ -889,7 +896,7 @@ static int mlkemkey_encapsulate(MlKemKey* key, const byte* m, byte* r, byte* c)
}
if (ret == 0) {
/* Assign remaining allocated dynamic memory to pointers.
* y (v) | a (m) | mu (p) | e1 (p) | r2 (v) | u (v) | v (p)*/
* y (b) | a (m) | mu (p) | e1 (p) | e2 (v) | u (v) | v (p) */
u = e2 + MLKEM_N;
v = u + MLKEM_N * k;
@@ -1034,7 +1041,7 @@ static int wc_mlkemkey_check_h(MlKemKey* key)
* @param [out] k Shared secret generated.
* @param [in] rng Random number generator.
* @return 0 on success.
* @return BAD_FUNC_ARG when key, ct, ss or RNG is NULL.
* @return BAD_FUNC_ARG when key, c, k or rng is NULL.
* @return NOT_COMPILED_IN when key type is not supported.
* @return MEMORY_E when dynamic memory allocation failed.
*/
@@ -1075,7 +1082,7 @@ int wc_MlKemKey_Encapsulate(MlKemKey* key, unsigned char* c, unsigned char* k,
* ciphertext.
* Step 1: (K,r) <- G(m||H(ek))
* > derive shared secret key K and randomness r
* Step 2: c <- K-PPKE.Encrypt(ek, m, r)
* Step 2: c <- K-PKE.Encrypt(ek, m, r)
* > encrypt m using K-PKE with randomness r
* Step 3: return (K,c)
*
@@ -1084,7 +1091,7 @@ int wc_MlKemKey_Encapsulate(MlKemKey* key, unsigned char* c, unsigned char* k,
* @param [in] m Random bytes.
* @param [in] len Length of random bytes.
* @return 0 on success.
* @return BAD_FUNC_ARG when key, c, k or RNG is NULL.
* @return BAD_FUNC_ARG when key, c, k or m is NULL.
* @return BUFFER_E when len is not WC_ML_KEM_ENC_RAND_SZ.
* @return NOT_COMPILED_IN when key type is not supported.
* @return MEMORY_E when dynamic memory allocation failed.
@@ -1248,16 +1255,16 @@ int wc_MlKemKey_EncapsulateWithRandom(MlKemKey* key, unsigned char* c,
* FIPS 203, Algorithm 15: K-PKE.Decrypt(dk_PKE,c)
* Uses the decryption key to decrypt a ciphertext.
* 1: c1 <- c[0 : 32.d_u.k]
* 2: c2 <= c[32.d_u.k : 32(d_u.k + d_v)]
* 3: u' <= Decompress_d_u(ByteDecode_d_u(c1))
* 4: v' <= Decompress_d_v(ByteDecode_d_v(c2))
* 2: c2 <- c[32.d_u.k : 32(d_u.k + d_v)]
* 3: u' <- Decompress_d_u(ByteDecode_d_u(c1))
* 4: v' <- Decompress_d_v(ByteDecode_d_v(c2))
* ...
* 6: w <- v' - InvNTT(s_hat_trans o NTT(u'))
* 7: m <- ByteEncode_1(Compress_1(w))
* 8: return m
*
* @param [in] key Kyber key object.
* @param [out] m Message than was encapsulated.
* @param [out] m Message that was encapsulated.
* @param [in] c Cipher text.
* @return 0 on success.
* @return NOT_COMPILED_IN when key type is not supported.
@@ -1340,7 +1347,7 @@ static MLKEM_NOINLINE int mlkemkey_decapsulate(MlKemKey* key, byte* m,
if (ret == 0) {
/* Step 1: c1 <- c[0 : 32.d_u.k] */
const byte* c1 = c;
/* Step 2: c2 <= c[32.d_u.k : 32(d_u.k + d_v)] */
/* Step 2: c2 <- c[32.d_u.k : 32(d_u.k + d_v)] */
const byte* c2 = c + compVecSz;
/* Assign allocated dynamic memory to pointers.
@@ -1350,25 +1357,25 @@ static MLKEM_NOINLINE int mlkemkey_decapsulate(MlKemKey* key, byte* m,
#if defined(WOLFSSL_KYBER512) || defined(WOLFSSL_WC_ML_KEM_512)
if (k == WC_ML_KEM_512_K) {
/* Step 3: u' <= Decompress_d_u(ByteDecode_d_u(c1)) */
/* Step 3: u' <- Decompress_d_u(ByteDecode_d_u(c1)) */
mlkem_vec_decompress_10(u, c1, k);
/* Step 4: v' <= Decompress_d_v(ByteDecode_d_v(c2)) */
/* Step 4: v' <- Decompress_d_v(ByteDecode_d_v(c2)) */
mlkem_decompress_4(v, c2);
}
#endif
#if defined(WOLFSSL_KYBER768) || defined(WOLFSSL_WC_ML_KEM_768)
if (k == WC_ML_KEM_768_K) {
/* Step 3: u' <= Decompress_d_u(ByteDecode_d_u(c1)) */
/* Step 3: u' <- Decompress_d_u(ByteDecode_d_u(c1)) */
mlkem_vec_decompress_10(u, c1, k);
/* Step 4: v' <= Decompress_d_v(ByteDecode_d_v(c2)) */
/* Step 4: v' <- Decompress_d_v(ByteDecode_d_v(c2)) */
mlkem_decompress_4(v, c2);
}
#endif
#if defined(WOLFSSL_KYBER1024) || defined(WOLFSSL_WC_ML_KEM_1024)
if (k == WC_ML_KEM_1024_K) {
/* Step 3: u' <= Decompress_d_u(ByteDecode_d_u(c1)) */
/* Step 3: u' <- Decompress_d_u(ByteDecode_d_u(c1)) */
mlkem_vec_decompress_11(u, c1);
/* Step 4: v' <= Decompress_d_v(ByteDecode_d_v(c2)) */
/* Step 4: v' <- Decompress_d_v(ByteDecode_d_v(c2)) */
mlkem_decompress_5(v, c2);
}
#endif
@@ -1408,11 +1415,11 @@ static MLKEM_NOINLINE int mlkemkey_decapsulate(MlKemKey* key, byte* m,
* ...
* 1: dk_PKE <- dk[0 : 384k]
* > extract (from KEM decaps key) the PKE decryption key
* 2: ek_PKE <- dk[384k : 768l + 32]
* 2: ek_PKE <- dk[384k : 768k + 32]
* > extract PKE encryption key
* 3: h <- dk[768K + 32 : 768k + 64]
* 3: h <- dk[768k + 32 : 768k + 64]
* > extract hash of PKE encryption key
* 4: z <- dk[768K + 64 : 768k + 96]
* 4: z <- dk[768k + 64 : 768k + 96]
* > extract implicit rejection value
* 5: m' <- K-PKE.Decrypt(dk_PKE, c) > decrypt ciphertext
* 6: (K', r') <- G(m'||h)
@@ -1420,7 +1427,7 @@ static MLKEM_NOINLINE int mlkemkey_decapsulate(MlKemKey* key, byte* m,
* 8: c' <- K-PKE.Encrypt(ek_PKE, m', r')
* > re-encrypt using the derived randomness r'
* 9: if c != c' then
* 10: K' <= K_bar
* 10: K' <- K_bar
* > if ciphertexts do not match, "implicitly reject"
* 11: end if
* 12: return K'
@@ -1430,7 +1437,7 @@ static MLKEM_NOINLINE int mlkemkey_decapsulate(MlKemKey* key, byte* m,
* @param [in] ct Cipher text.
* @param [in] len Length of cipher text.
* @return 0 on success.
* @return BAD_FUNC_ARG when key, ss or cr are NULL.
* @return BAD_FUNC_ARG when key, ss or ct are NULL.
* @return NOT_COMPILED_IN when key type is not supported.
* @return BUFFER_E when len is not the length of cipher text for the key type.
* @return MEMORY_E when dynamic memory allocation failed.
@@ -1588,7 +1595,7 @@ int wc_MlKemKey_Decapsulate(MlKemKey* key, unsigned char* ss,
/**
* Get the public key and public seed from bytes.
*
* FIPS 203, Algorithm 14 K-PKE.Encrypt(ek_PKE, m, r)
* FIPS 203, Algorithm 14: K-PKE.Encrypt(ek_PKE, m, r)
* ...
* 2: t <- ByteDecode_12(ek_PKE[0 : 384k])
* 3: rho <- ek_PKE[384k : 384k + 32]
@@ -1624,16 +1631,16 @@ static void mlkemkey_decode_public(sword16* pub, byte* pubSeed, const byte* p,
* FIPS 203, Algorithm 18: ML-KEM.Decaps_internal(dk, c)
* 1: dk_PKE <- dk[0 : 384k]
* > extract (from KEM decaps key) the PKE decryption key
* 2: ek_PKE <- dk[384k : 768l + 32]
* 2: ek_PKE <- dk[384k : 768k + 32]
* > extract PKE encryption key
* 3: h <- dk[768K + 32 : 768k + 64]
* 3: h <- dk[768k + 32 : 768k + 64]
* > extract hash of PKE encryption key
* 4: z <- dk[768K + 64 : 768k + 96]
* 4: z <- dk[768k + 64 : 768k + 96]
* > extract implicit rejection value
*
* FIPS 203, Algorithm 15: K-PKE.Decrypt(dk_PKE, c)
* ...
* 5: s_hat <= ByteDecode_12(dk_PKE)
* 5: s_hat <- ByteDecode_12(dk_PKE)
* ...
*
* @param [in, out] key Kyber key object.
@@ -1729,14 +1736,21 @@ int wc_MlKemKey_DecodePrivateKey(MlKemKey* key, const unsigned char* in,
mlkemkey_decode_public(key->pub, key->pubSeed, p, k);
/* Compute the hash of the public key. */
ret = MLKEM_HASH_H(&key->hash, p, pubLen, key->h);
p += pubLen;
if (ret != 0) {
ForceZero(key->priv, k * MLKEM_N);
}
}
if (ret == 0) {
p += pubLen;
/* Compare computed public key hash with stored hash */
if (XMEMCMP(key->h, p, WC_ML_KEM_SYM_SZ) != 0)
if (XMEMCMP(key->h, p, WC_ML_KEM_SYM_SZ) != 0) {
ForceZero(key->priv, k * MLKEM_N);
ret = MLKEM_PUB_HASH_E;
}
}
if (ret == 0) {
/* Copy the hash of the encoded public key that is after public key. */
XMEMCPY(key->h, p, sizeof(key->h));
p += WC_ML_KEM_SYM_SZ;
File diff suppressed because it is too large Load Diff