ML-KEM: fix comments, API signatures, minor issues

More checks for public or private key not set.
wc_MlKemKey_Free clears key->flags
wc_MlKemKey_DecodePrivateKey now checks the public key is valid.
wc_MlKemKey_EncodePrivateKey doesn't need calculate hash of public key
as encoding the public key will do this.
EncodePrivateKey/EncodePublicKey now return BAD_STATE_E when flags not
set.
mlkem_kdf, mlkem_check_public, mlkem_xof_absorb pointer parameters are
now const.
Now all mlkem_redistribute_*_rand_avx2 functions are WOLFSSL_LOCAL.
Changed Kyber uses to MlKem.
This commit is contained in:
Sean Parkinson
2026-05-06 22:12:05 +10:00
parent 980fc51ea7
commit 15398c26d0
3 changed files with 273 additions and 195 deletions
+187 -114
View File
@@ -28,30 +28,30 @@
* post-quantum-cryptography-standardization/round-3-submissions
*/
/* Possible Kyber options:
/* Possible ML-KEM options:
*
* WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM Default: OFF
* WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM Default: OFF
* Uses less dynamic memory to perform key generation.
* Has a small performance trade-off.
* Only usable with C implementation.
*
* WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM Default: OFF
* WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM Default: OFF
* Uses less dynamic memory to perform encapsulation.
* Affects decapsulation too as encapsulation called.
* Has a small performance trade-off.
* Only usable with C implementation.
*
* WOLFSSL_MLKEM_NO_MAKE_KEY Default: OFF
* WOLFSSL_MLKEM_NO_MAKE_KEY Default: OFF
* Disable the make key or key generation API.
* Reduces the code size.
* Turn on when only doing encapsulation.
*
* WOLFSSL_MLKEM_NO_ENCAPSULATE Default: OFF
* WOLFSSL_MLKEM_NO_ENCAPSULATE Default: OFF
* Disable the encapsulation API.
* Reduces the code size.
* Turn on when doing make key/decapsulation.
*
* WOLFSSL_MLKEM_NO_DECAPSULATE Default: OFF
* WOLFSSL_MLKEM_NO_DECAPSULATE Default: OFF
* Disable the decapsulation API.
* Reduces the code size.
* Turn on when only doing encapsulation.
@@ -59,7 +59,7 @@
* WOLFSSL_MLKEM_CACHE_A Default: OFF
* Stores the matrix A during key generation for use in encapsulation when
* performing decapsulation.
* KyberKey is 8KB larger but decapsulation is significantly faster.
* MlKemKey is 8KB larger but decapsulation is significantly faster.
* Turn on when performing make key and decapsulation with same object.
*
* WOLFSSL_MLKEM_DYNAMIC_KEYS Default: OFF
@@ -282,6 +282,8 @@ static int mlkemkey_alloc_pub(MlKemKey* key, unsigned int k)
*/
static int mlkemkey_alloc_a(MlKemKey* key, unsigned int k)
{
int ret = 0;
if (key->a != NULL) {
XFREE(key->a, key->heap, DYNAMIC_TYPE_TMP_BUFFER);
key->a = NULL;
@@ -289,9 +291,10 @@ static int mlkemkey_alloc_a(MlKemKey* key, unsigned int k)
key->a = (sword16*)XMALLOC(k * k * MLKEM_N * sizeof(sword16), key->heap,
DYNAMIC_TYPE_TMP_BUFFER);
if (key->a == NULL) {
return MEMORY_E;
ret = MEMORY_E;
}
return 0;
return ret;
}
#endif /* WOLFSSL_MLKEM_CACHE_A */
#endif /* WOLFSSL_MLKEM_DYNAMIC_KEYS */
@@ -304,19 +307,20 @@ static int mlkemkey_alloc_a(MlKemKey* key, unsigned int k)
*
* Allocates and initializes a ML-KEM key object.
*
* @param [in] type Type of key:
* WC_ML_KEM_512, WC_ML_KEM_768, WC_ML_KEM_1024,
* KYBER512, KYBER768, KYBER1024.
* @param [in] heap Dynamic memory hint.
* @param [in] devId Device Id.
* @return Pointer to new MlKemKey object, or NULL on failure.
* @param [in] type Type of key:
* WC_ML_KEM_512, WC_ML_KEM_768, WC_ML_KEM_1024,
* KYBER512, KYBER768, KYBER1024.
* @param [in] heap Dynamic memory hint.
* @param [in] devId Device Id.
* @return Pointer to new MlKemKey object on success.
* @return NULL on failure.
*/
MlKemKey* wc_MlKemKey_New(int type, void* heap, int devId)
{
int ret;
MlKemKey* key = (MlKemKey*)XMALLOC(sizeof(MlKemKey), heap,
DYNAMIC_TYPE_TMP_BUFFER);
MlKemKey* key;
key = (MlKemKey*)XMALLOC(sizeof(MlKemKey), heap, DYNAMIC_TYPE_TMP_BUFFER);
if (key != NULL) {
ret = wc_MlKemKey_Init(key, type, heap, devId);
if (ret != 0) {
@@ -333,31 +337,36 @@ MlKemKey* wc_MlKemKey_New(int type, void* heap, int devId)
*
* Frees resources associated with a ML-KEM key object and sets pointer to NULL.
*
* @param [in] key ML-KEM key object to delete.
* @param [in, out] key_p Pointer to key pointer to set to NULL.
* @param [in] key ML-KEM key object to delete.
* @param [in, out] key_p Pointer to key pointer to set to NULL.
* @return 0 on success.
* @return BAD_FUNC_ARG when key is NULL.
*/
int wc_MlKemKey_Delete(MlKemKey* key, MlKemKey** key_p)
{
void* heap;
if (key == NULL)
return BAD_FUNC_ARG;
heap = key->heap;
wc_MlKemKey_Free(key);
XFREE(key, heap, DYNAMIC_TYPE_TMP_BUFFER);
if (key_p != NULL)
*key_p = NULL;
int ret = 0;
return 0;
if (key == NULL) {
ret = BAD_FUNC_ARG;
}
else {
void* heap = key->heap;
wc_MlKemKey_Free(key);
XFREE(key, heap, DYNAMIC_TYPE_TMP_BUFFER);
if (key_p != NULL) {
*key_p = NULL;
}
}
return ret;
}
#endif /* !WC_NO_CONSTRUCTORS */
/**
* Initialize the Kyber key.
* Initialize the ML-KEM key.
*
* @param [out] key Kyber key object to initialize.
* @param [out] key ML-KEM key object to initialize.
* @param [in] type Type of key:
* WC_ML_KEM_512, WC_ML_KEM_768, WC_ML_KEM_1024,
* KYBER512, KYBER768, KYBER1024.
@@ -381,19 +390,19 @@ int wc_MlKemKey_Init(MlKemKey* key, int type, void* heap, int devId)
#ifndef WOLFSSL_NO_ML_KEM
case WC_ML_KEM_512:
#ifndef WOLFSSL_WC_ML_KEM_512
/* Code not compiled in for Kyber-512. */
/* Code not compiled in for ML-KEM-512. */
ret = NOT_COMPILED_IN;
#endif
break;
case WC_ML_KEM_768:
#ifndef WOLFSSL_WC_ML_KEM_768
/* Code not compiled in for Kyber-768. */
/* Code not compiled in for ML-KEM-768. */
ret = NOT_COMPILED_IN;
#endif
break;
case WC_ML_KEM_1024:
#ifndef WOLFSSL_WC_ML_KEM_1024
/* Code not compiled in for Kyber-1024. */
/* Code not compiled in for ML-KEM-1024. */
ret = NOT_COMPILED_IN;
#endif
break;
@@ -468,22 +477,42 @@ int wc_MlKemKey_Init(MlKemKey* key, int type, void* heap, int devId)
}
#ifdef WOLF_PRIVATE_KEY_ID
/**
* Initialize the ML-KEM key with an id.
*
* @param [out] key ML-KEM key object to initialize.
* @param [in] type Type of key:
* WC_ML_KEM_512, WC_ML_KEM_768, WC_ML_KEM_1024,
* KYBER512, KYBER768, KYBER1024.
* @param [in] id Identifier of key.
* @param [in] len Length of key identifier in bytes.
* @param [in] heap Dynamic memory hint.
* @param [in] devId Device Id.
* @return 0 on success.
* @return BAD_FUNC_ARG when key is NULL, id is NULL but len is not zero, or
* type is unrecognized.
* @return BUFFER_E when len is out of range.
* @return NOT_COMPILED_IN when key type is not supported.
*/
int wc_MlKemKey_Init_Id(MlKemKey* key, int type, const unsigned char* id,
int len, void* heap, int devId)
{
int ret = 0;
if (key == NULL || (id == NULL && len != 0)) {
/* Validate parameters. */
if ((key == NULL) || (id == NULL && len != 0)) {
ret = BAD_FUNC_ARG;
}
if (ret == 0 && (len < 0 || len > MLKEM_MAX_ID_LEN)) {
if ((ret == 0) && ((len < 0) || (len > MLKEM_MAX_ID_LEN))) {
ret = BUFFER_E;
}
if (ret == 0) {
/* Initialize key. */
ret = wc_MlKemKey_Init(key, type, heap, devId);
}
if (ret == 0 && id != NULL && len != 0) {
if ((ret == 0) && (id != NULL) && (len != 0)) {
/* Store key identifier. */
XMEMCPY(key->id, id, (size_t)len);
key->idLen = len;
}
@@ -491,16 +520,33 @@ int wc_MlKemKey_Init_Id(MlKemKey* key, int type, const unsigned char* id,
return ret;
}
/**
* Initialize the ML-KEM key with a label.
*
* @param [out] key ML-KEM key object to initialize.
* @param [in] type Type of key:
* WC_ML_KEM_512, WC_ML_KEM_768, WC_ML_KEM_1024,
* KYBER512, KYBER768, KYBER1024.
* @param [in] label Label of key. Must be a null-terminated string.
* @param [in] heap Dynamic memory hint.
* @param [in] devId Device Id.
* @return 0 on success.
* @return BAD_FUNC_ARG when key or label is NULL, or type is unrecognized.
* @return BUFFER_E when label is too small or big.
* @return NOT_COMPILED_IN when key type is not supported.
*/
int wc_MlKemKey_Init_Label(MlKemKey* key, int type, const char* label,
void* heap, int devId)
{
int ret = 0;
int labelLen = 0;
if (key == NULL || label == NULL) {
/* Validate parameters. */
if ((key == NULL) || (label == NULL)) {
ret = BAD_FUNC_ARG;
}
if (ret == 0) {
/* Validate label length. */
labelLen = (int)XSTRLEN(label);
if ((labelLen == 0) || (labelLen > MLKEM_MAX_LABEL_LEN)) {
ret = BUFFER_E;
@@ -508,10 +554,11 @@ int wc_MlKemKey_Init_Label(MlKemKey* key, int type, const char* label,
}
if (ret == 0) {
/* Initialize key. */
ret = wc_MlKemKey_Init(key, type, heap, devId);
}
if (ret == 0) {
/* The string in key->label is not necessarily null-terminated.
/* Don't save string in key->label with null terminator.
* Use key->labelLen to get the length if required. */
XMEMCPY(key->label, label, (size_t)labelLen);
key->labelLen = labelLen;
@@ -522,9 +569,9 @@ int wc_MlKemKey_Init_Label(MlKemKey* key, int type, const char* label,
#endif
/**
* Free the Kyber key object.
* Free the ML-KEM key object.
*
* @param [in, out] key Kyber key object to dispose of.
* @param [in, out] key ML-KEM key object to dispose of.
* @return 0 on success.
*/
int wc_MlKemKey_Free(MlKemKey* key)
@@ -533,9 +580,7 @@ int wc_MlKemKey_Free(MlKemKey* key)
#if defined(WOLF_CRYPTO_CB) && defined(WOLF_CRYPTO_CB_FREE)
if (key->devId != INVALID_DEVID) {
(void)wc_CryptoCb_Free(key->devId, WC_ALGO_TYPE_PK,
WC_PK_TYPE_PQC_KEM_KEYGEN,
WC_PQC_KEM_TYPE_KYBER,
(void*)key);
WC_PK_TYPE_PQC_KEM_KEYGEN, WC_PQC_KEM_TYPE_KYBER, (void*)key);
/* always continue to software cleanup */
}
#endif
@@ -567,6 +612,9 @@ int wc_MlKemKey_Free(MlKemKey* key)
ForceZero(key->priv, sizeof(key->priv));
#endif
ForceZero(key->z, sizeof(key->z));
/* Clear flags as values are no longer set. */
key->flags = 0;
}
return 0;
@@ -576,7 +624,7 @@ int wc_MlKemKey_Free(MlKemKey* key)
#ifndef WOLFSSL_MLKEM_NO_MAKE_KEY
/**
* Make a Kyber key object using a random number generator.
* Make a ML-KEM key object using a random number generator.
*
* FIPS 203 - Algorithm 19: ML-KEM.KeyGen()
* Generates an encapsulation key and a corresponding decapsulation key.
@@ -590,13 +638,17 @@ int wc_MlKemKey_Free(MlKemKey* key)
* > run internal key generation algorithm
* 7: return (ek,dk)
*
* @param [in, out] key Kyber key object.
* @param [in, out] key ML-KEM key object.
* @param [in] rng Random number generator.
* @return 0 on success.
* @return BAD_FUNC_ARG when key or rng is NULL.
* @return MEMORY_E when dynamic memory allocation failed.
* @return RNG_FAILURE_E when generating random numbers failed.
* @return DRBG_CONT_FAILURE when random number generator health check fails.
* @return ML_KEM_PCT_E when pairwise consistency test fails. FIPS only.
* @return BAD_COND_E when fault attack detected.
* @return NOT_COMPILED_IN when no random number generator is compiled in or
* key type is not supported.
*/
int wc_MlKemKey_MakeKey(MlKemKey* key, WC_RNG* rng)
{
@@ -615,8 +667,8 @@ int wc_MlKemKey_MakeKey(MlKemKey* key, WC_RNG* rng)
#else
if (ret == 0) {
#endif
ret = wc_CryptoCb_MakePqcKemKey(rng, WC_PQC_KEM_TYPE_KYBER,
key->type, key);
ret = wc_CryptoCb_MakePqcKemKey(rng, WC_PQC_KEM_TYPE_KYBER, key->type,
key);
if (ret != WC_NO_ERR_TRACE(CRYPTOCB_UNAVAILABLE))
return ret;
/* fall-through when unavailable */
@@ -637,7 +689,7 @@ int wc_MlKemKey_MakeKey(MlKemKey* key, WC_RNG* rng)
* Step 6. run internal key generation algorithm
* Step 7. public and private key are stored in key
*/
ret = wc_KyberKey_MakeKeyWithRandom(key, rand, sizeof(rand));
ret = wc_MlKemKey_MakeKeyWithRandom(key, rand, sizeof(rand));
}
#ifdef HAVE_FIPS
@@ -697,7 +749,7 @@ int wc_MlKemKey_MakeKey(MlKemKey* key, WC_RNG* rng)
}
/**
* Make a Kyber key object using random data.
* Make a ML-KEM key object using random data.
*
* FIPS 203 - Algorithm 16: ML-KEM.KeyGen_internal(d,z)
* Uses randomness to generate an encapsulation key and a corresponding
@@ -717,7 +769,7 @@ int wc_MlKemKey_MakeKey(MlKemKey* key, WC_RNG* rng)
* 16-18: calculate t_hat from A_hat, s and e
* ...
*
* @param [in, out] key Kyber key object.
* @param [in, out] key ML-KEM key object.
* @param [in] rand Random data.
* @param [in] len Length of random data in bytes.
* @return 0 on success.
@@ -725,6 +777,7 @@ int wc_MlKemKey_MakeKey(MlKemKey* key, WC_RNG* rng)
* @return BUFFER_E when length is not WC_ML_KEM_MAKEKEY_RAND_SZ.
* @return NOT_COMPILED_IN when key type is not supported.
* @return MEMORY_E when dynamic memory allocation failed.
* @return BAD_COND_E when fault attack detected.
*/
int wc_MlKemKey_MakeKeyWithRandom(MlKemKey* key, const unsigned char* rand,
int len)
@@ -846,11 +899,12 @@ int wc_MlKemKey_MakeKeyWithRandom(MlKemKey* key, const unsigned char* rand,
#ifdef WC_MLKEM_FAULT_HARDEN
if (ret == 0) {
XMEMCPY(sigma, buf + WC_ML_KEM_SYM_SZ, WC_ML_KEM_SYM_SZ);
/* Check that correct data was copied and pointer not changed. */
/* Check that correct data was copied and pointer was not faulted. */
if (XMEMCMP(sigma, rho, WC_ML_KEM_SYM_SZ) == 0) {
ret = BAD_COND_E;
}
/* Check that rho is sigma - rho may have been modified. */
/* Check that sigma is after rho - rho pointer may have been modified.
*/
if (XMEMCMP(sigma, rho + WC_ML_KEM_SYM_SZ, WC_ML_KEM_SYM_SZ) != 0) {
ret = BAD_COND_E;
}
@@ -928,7 +982,7 @@ int wc_MlKemKey_MakeKeyWithRandom(MlKemKey* key, const unsigned char* rand,
/**
* Get the size in bytes of cipher text for key.
*
* @param [in] key Kyber key object.
* @param [in] key ML-KEM key object.
* @param [out] len Length of cipher text in bytes.
* @return 0 on success.
* @return BAD_FUNC_ARG when key or len is NULL.
@@ -991,10 +1045,10 @@ int wc_MlKemKey_CipherTextSize(MlKemKey* key, word32* len)
}
/**
* Size of a shared secret in bytes. Always KYBER_SS_SZ.
* Size of a shared secret in bytes. Always WC_ML_KEM_SS_SZ.
*
* @param [in] key Kyber key object. Not used.
* @param [out] len Size of the shared secret created with a Kyber key.
* @param [in] key ML-KEM key object. Not used.
* @param [out] len Size of the shared secret created with a ML-KEM key.
* @return 0 on success.
* @return BAD_FUNC_ARG when len is NULL.
*/
@@ -1037,7 +1091,7 @@ int wc_MlKemKey_SharedSecretSize(MlKemKey* key, word32* len)
* 23: c_2 <- ByteEncode_d_v(Compress_d_v(v))
* 24: return c <- (c_1||c_2)
*
* @param [in] key Kyber key object.
* @param [in] key ML-KEM key object.
* @param [in] m Random bytes.
* @param [in] r Seed to feed to PRF when generating y, e1 and e2.
* @param [out] c Calculated cipher text.
@@ -1270,7 +1324,7 @@ static int wc_mlkemkey_check_h(MlKemKey* key)
#endif
/* Determine how big an encoded public key will be. */
ret = wc_KyberKey_PublicKeySize(key, &pubKeyLen);
ret = wc_MlKemKey_PublicKeySize(key, &pubKeyLen);
if (ret == 0) {
#ifndef WOLFSSL_NO_MALLOC
/* Allocate dynamic memory for encoded public key. */
@@ -1283,15 +1337,15 @@ static int wc_mlkemkey_check_h(MlKemKey* key)
if (ret == 0) {
#endif
/* Encode public key - h is hash of encoded public key. */
ret = wc_KyberKey_EncodePublicKey(key, pubKey, pubKeyLen);
ret = wc_MlKemKey_EncodePublicKey(key, pubKey, pubKeyLen);
}
#ifndef WOLFSSL_NO_MALLOC
/* Dispose of encoded public key. */
XFREE(pubKey, key->heap, DYNAMIC_TYPE_TMP_BUFFER);
#endif
#endif
}
if ((ret == 0) && ((key->flags & MLKEM_FLAG_H_SET) == 0)) {
/* Implementation issue if h not cached and flag set. */
/* Implementation issue if h not cached and flag not set. */
ret = BAD_STATE_E;
}
@@ -1314,16 +1368,17 @@ static int wc_mlkemkey_check_h(MlKemKey* key)
* > run internal encapsulation algorithm
* 6: return (K,c)
*
* @param [in] key Kyber key object.
* @param [out] c Cipher text.
* @param [out] k Shared secret generated.
* @param [in] key ML-KEM key object.
* @param [out] ct Cipher text.
* @param [out] ss Shared secret generated.
* @param [in] rng Random number generator.
* @return 0 on success.
* @return BAD_FUNC_ARG when key, c, k or rng is NULL.
* @return BAD_FUNC_ARG when key, ct, ss or rng is NULL.
* @return BAD_STATE_E when public key not set.
* @return NOT_COMPILED_IN when key type is not supported.
* @return MEMORY_E when dynamic memory allocation failed.
*/
int wc_MlKemKey_Encapsulate(MlKemKey* key, unsigned char* c, unsigned char* k,
int wc_MlKemKey_Encapsulate(MlKemKey* key, unsigned char* ct, unsigned char* ss,
WC_RNG* rng)
{
#ifndef WC_NO_RNG
@@ -1334,9 +1389,13 @@ int wc_MlKemKey_Encapsulate(MlKemKey* key, unsigned char* c, unsigned char* k,
#endif
/* Validate parameters. */
if ((key == NULL) || (c == NULL) || (k == NULL) || (rng == NULL)) {
if ((key == NULL) || (ct == NULL) || (ss == NULL) || (rng == NULL)) {
ret = BAD_FUNC_ARG;
}
/* Check the public key has been set. */
else if ((key->flags & MLKEM_FLAG_PUB_SET) == 0) {
ret = BAD_STATE_E;
}
#ifdef WOLF_CRYPTO_CB
if (ret == 0) {
@@ -1347,8 +1406,8 @@ int wc_MlKemKey_Encapsulate(MlKemKey* key, unsigned char* c, unsigned char* k,
#else
if (ret == 0) {
#endif
ret = wc_CryptoCb_PqcEncapsulate(c, ctlen, k, KYBER_SS_SZ, rng,
WC_PQC_KEM_TYPE_KYBER, key);
ret = wc_CryptoCb_PqcEncapsulate(ct, ctlen, ss, WC_ML_KEM_SS_SZ, rng,
WC_PQC_KEM_TYPE_KYBER, key);
if (ret != WC_NO_ERR_TRACE(CRYPTOCB_UNAVAILABLE))
return ret;
/* fall-through when unavailable */
@@ -1367,15 +1426,15 @@ int wc_MlKemKey_Encapsulate(MlKemKey* key, unsigned char* c, unsigned char* k,
/* Encapsulate with the random.
* Step 5: run internal encapsulation algorithm
*/
ret = wc_KyberKey_EncapsulateWithRandom(key, c, k, m, sizeof(m));
ret = wc_MlKemKey_EncapsulateWithRandom(key, ct, ss, m, sizeof(m));
}
/* Step 3: return ret != 0 on falsum or internal key generation failure. */
return ret;
#else
(void)key;
(void)c;
(void)k;
(void)ct;
(void)ss;
(void)rng;
return NOT_COMPILED_IN;
#endif /* WC_NO_RNG */
@@ -1393,35 +1452,41 @@ int wc_MlKemKey_Encapsulate(MlKemKey* key, unsigned char* c, unsigned char* k,
* > encrypt m using K-PKE with randomness r
* Step 3: return (K,c)
*
* @param [out] c Cipher text.
* @param [out] k Shared secret generated.
* @param [in] m Random bytes.
* @param [in] len Length of random bytes.
* @param [in] key ML-KEM key object.
* @param [out] ct Cipher text.
* @param [out] ss Shared secret generated.
* @param [in] rand Random bytes.
* @param [in] len Length of random bytes.
* @return 0 on success.
* @return BAD_FUNC_ARG when key, c, k or m is NULL.
* @return BAD_FUNC_ARG when key, ct, ss or rand is NULL.
* @return BUFFER_E when len is not WC_ML_KEM_ENC_RAND_SZ.
* @return BAD_STATE_E when public key not set.
* @return NOT_COMPILED_IN when key type is not supported.
* @return MEMORY_E when dynamic memory allocation failed.
*/
int wc_MlKemKey_EncapsulateWithRandom(MlKemKey* key, unsigned char* c,
unsigned char* k, const unsigned char* m, int len)
int wc_MlKemKey_EncapsulateWithRandom(MlKemKey* key, unsigned char* ct,
unsigned char* ss, const unsigned char* rand, int len)
{
#ifdef WOLFSSL_MLKEM_KYBER
byte msg[KYBER_SYM_SZ];
byte msg[WC_ML_KEM_SYM_SZ];
#endif
byte kr[2 * KYBER_SYM_SZ + 1];
byte kr[2 * WC_ML_KEM_SYM_SZ + 1];
int ret = 0;
#ifdef WOLFSSL_MLKEM_KYBER
unsigned int cSz = 0;
#endif
/* Validate parameters. */
if ((key == NULL) || (c == NULL) || (k == NULL) || (m == NULL)) {
if ((key == NULL) || (ct == NULL) || (ss == NULL) || (rand == NULL)) {
ret = BAD_FUNC_ARG;
}
if ((ret == 0) && (len != WC_ML_KEM_ENC_RAND_SZ)) {
ret = BUFFER_E;
}
/* Check the public key has been set. */
if ((ret == 0) && ((key->flags & MLKEM_FLAG_PUB_SET) == 0)) {
ret = BAD_STATE_E;
}
#ifdef WOLFSSL_MLKEM_KYBER
if (ret == 0) {
@@ -1473,7 +1538,7 @@ int wc_MlKemKey_EncapsulateWithRandom(MlKemKey* key, unsigned char* c,
#endif
{
/* Hash random to anonymize as seed data. */
ret = MLKEM_HASH_H(&key->hash, m, WC_ML_KEM_SYM_SZ, msg);
ret = MLKEM_HASH_H(&key->hash, rand, WC_ML_KEM_SYM_SZ, msg);
}
}
#endif
@@ -1494,7 +1559,7 @@ int wc_MlKemKey_EncapsulateWithRandom(MlKemKey* key, unsigned char* c,
#ifndef WOLFSSL_NO_ML_KEM
{
/* Step 1: (K,r) <- G(m||H(ek)) */
ret = MLKEM_HASH_G(&key->hash, m, WC_ML_KEM_SYM_SZ, key->h,
ret = MLKEM_HASH_G(&key->hash, rand, WC_ML_KEM_SYM_SZ, key->h,
WC_ML_KEM_SYM_SZ, kr);
}
#endif
@@ -1507,7 +1572,7 @@ int wc_MlKemKey_EncapsulateWithRandom(MlKemKey* key, unsigned char* c,
#endif
#ifdef WOLFSSL_MLKEM_KYBER
{
ret = mlkemkey_encapsulate(key, msg, kr + WC_ML_KEM_SYM_SZ, c);
ret = mlkemkey_encapsulate(key, msg, kr + WC_ML_KEM_SYM_SZ, ct);
}
#endif
#if defined(WOLFSSL_MLKEM_KYBER) && !defined(WOLFSSL_NO_ML_KEM)
@@ -1516,7 +1581,7 @@ int wc_MlKemKey_EncapsulateWithRandom(MlKemKey* key, unsigned char* c,
#ifndef WOLFSSL_NO_ML_KEM
{
/* Step 2: c <- K-PKE.Encrypt(ek,m,r) */
ret = mlkemkey_encapsulate(key, m, kr + WC_ML_KEM_SYM_SZ, c);
ret = mlkemkey_encapsulate(key, rand, kr + WC_ML_KEM_SYM_SZ, ct);
}
#endif
}
@@ -1528,11 +1593,11 @@ int wc_MlKemKey_EncapsulateWithRandom(MlKemKey* key, unsigned char* c,
{
if (ret == 0) {
/* Hash the cipher text after the seed. */
ret = MLKEM_HASH_H(&key->hash, c, cSz, kr + WC_ML_KEM_SYM_SZ);
ret = MLKEM_HASH_H(&key->hash, ct, cSz, kr + WC_ML_KEM_SYM_SZ);
}
if (ret == 0) {
/* Derive the secret from the seed and hash of cipher text. */
ret = MLKEM_KDF(kr, 2 * WC_ML_KEM_SYM_SZ, k, WC_ML_KEM_SS_SZ);
ret = MLKEM_KDF(kr, 2 * WC_ML_KEM_SYM_SZ, ss, WC_ML_KEM_SS_SZ);
}
}
#endif
@@ -1543,7 +1608,7 @@ int wc_MlKemKey_EncapsulateWithRandom(MlKemKey* key, unsigned char* c,
{
if (ret == 0) {
/* return (K,c) */
XMEMCPY(k, kr, WC_ML_KEM_SS_SZ);
XMEMCPY(ss, kr, WC_ML_KEM_SS_SZ);
}
}
#endif
@@ -1570,7 +1635,7 @@ int wc_MlKemKey_EncapsulateWithRandom(MlKemKey* key, unsigned char* c,
* 7: m <- ByteEncode_1(Compress_1(w))
* 8: return m
*
* @param [in] key Kyber key object.
* @param [in] key ML-KEM key object.
* @param [out] m Message that was encapsulated.
* @param [in] c Cipher text.
* @return 0 on success.
@@ -1739,12 +1804,13 @@ static MLKEM_NOINLINE int mlkemkey_decapsulate(MlKemKey* key, byte* m,
* 11: end if
* 12: return K'
*
* @param [in] key Kyber key object.
* @param [in] key ML-KEM key object.
* @param [out] ss Shared secret.
* @param [in] ct Cipher text.
* @param [in] len Length of cipher text.
* @return 0 on success.
* @return BAD_FUNC_ARG when key, ss or ct are NULL.
* @return BAD_STATE_E when private key is not set.
* @return NOT_COMPILED_IN when key type is not supported.
* @return BUFFER_E when len is not the length of cipher text for the key type.
* @return MEMORY_E when dynamic memory allocation failed.
@@ -1827,8 +1893,8 @@ int wc_MlKemKey_Decapsulate(MlKemKey* key, unsigned char* ss,
#else
if (ret == 0) {
#endif
ret = wc_CryptoCb_PqcDecapsulate(ct, ctSz, ss, KYBER_SS_SZ,
WC_PQC_KEM_TYPE_KYBER, key);
ret = wc_CryptoCb_PqcDecapsulate(ct, ctSz, ss, WC_ML_KEM_SS_SZ,
WC_PQC_KEM_TYPE_KYBER, key);
if (ret != WC_NO_ERR_TRACE(CRYPTOCB_UNAVAILABLE))
return ret;
/* fall-through when unavailable */
@@ -1968,13 +2034,16 @@ static void mlkemkey_decode_public(sword16* pub, byte* pubSeed, const byte* p,
* 5: s_hat <- ByteDecode_12(dk_PKE)
* ...
*
* @param [in, out] key Kyber key object.
* @param [in, out] key ML-KEM key object.
* @param [in] in Buffer holding encoded key.
* @param [in] len Length of data in buffer.
* @return 0 on success.
* @return BAD_FUNC_ARG when key or in is NULL.
* @return NOT_COMPILED_IN when key type is not supported.
* @return BUFFER_E when len is not the correct size.
* @return PUBLIC_KEY_E when public key data doesn't match parameters.
* @return MLKEM_PUB_HASH_E when public key hash doesn't match stored hash.
* @return MEMORY_E when dynamic memory allocation failed.
*/
int wc_MlKemKey_DecodePrivateKey(MlKemKey* key, const unsigned char* in,
word32 len)
@@ -2067,6 +2136,12 @@ int wc_MlKemKey_DecodePrivateKey(MlKemKey* key, const unsigned char* in,
/* Decode the public key that is after the private key. */
mlkemkey_decode_public(key->pub, key->pubSeed, p, k);
ret = mlkem_check_public(key->pub, (int)k);
if (ret != 0) {
ForceZero(key->priv, k * MLKEM_N * sizeof(sword16));
}
}
if (ret == 0) {
/* Compute the hash of the public key. */
ret = MLKEM_HASH_H(&key->hash, p, pubLen, key->h);
if (ret != 0) {
@@ -2102,13 +2177,15 @@ int wc_MlKemKey_DecodePrivateKey(MlKemKey* key, const unsigned char* in,
*
* Public vector | Public Seed
*
* @param [in, out] key Kyber key object.
* @param [in, out] key ML-KEM key object.
* @param [in] in Buffer holding encoded key.
* @param [in] len Length of data in buffer.
* @return 0 on success.
* @return BAD_FUNC_ARG when key or in is NULL.
* @return NOT_COMPILED_IN when key type is not supported.
* @return BUFFER_E when len is not the correct size.
* @return PUBLIC_KEY_E when public key data doesn't match parameters.
* @return MEMORY_E when dynamic memory allocation failed.
*/
int wc_MlKemKey_DecodePublicKey(MlKemKey* key, const unsigned char* in,
word32 len)
@@ -2182,6 +2259,7 @@ int wc_MlKemKey_DecodePublicKey(MlKemKey* key, const unsigned char* in,
}
#endif
if (ret == 0) {
/* Decode public key and check public key matches parameters. */
mlkemkey_decode_public(key->pub, key->pubSeed, p, k);
ret = mlkem_check_public(key->pub, (int)k);
}
@@ -2200,7 +2278,7 @@ int wc_MlKemKey_DecodePublicKey(MlKemKey* key, const unsigned char* in,
/**
* Get the size in bytes of encoded private key for the key.
*
* @param [in] key Kyber key object.
* @param [in] key ML-KEM key object.
* @param [out] len Length of encoded private key in bytes.
* @return 0 on success.
* @return BAD_FUNC_ARG when key or len is NULL.
@@ -2266,7 +2344,7 @@ int wc_MlKemKey_PrivateKeySize(MlKemKey* key, word32* len)
/**
* Get the size in bytes of encoded public key for the key.
*
* @param [in] key Kyber key object.
* @param [in] key ML-KEM key object.
* @param [out] len Length of encoded public key in bytes.
* @return 0 on success.
* @return BAD_FUNC_ARG when key or len is NULL.
@@ -2343,12 +2421,12 @@ int wc_MlKemKey_PublicKeySize(MlKemKey* key, word32* len)
* 20: dk_PKE <- ByteEncode_12(s_hat)
* ...
*
* @param [in] key Kyber key object.
* @param [in] key ML-KEM key object.
* @param [out] out Buffer to hold data.
* @param [in] len Size of buffer in bytes.
* @return 0 on success.
* @return BAD_FUNC_ARG when key or out is NULL or private/public key not
* available.
* @return BAD_FUNC_ARG when key or out is NULL.
* @return BAD_STATE_E when private/public key not available.
* @return NOT_COMPILED_IN when key type is not supported.
*/
int wc_MlKemKey_EncodePrivateKey(MlKemKey* key, unsigned char* out, word32 len)
@@ -2364,7 +2442,7 @@ int wc_MlKemKey_EncodePrivateKey(MlKemKey* key, unsigned char* out, word32 len)
}
if ((ret == 0) &&
((key->flags & MLKEM_FLAG_BOTH_SET) != MLKEM_FLAG_BOTH_SET)) {
ret = BAD_FUNC_ARG;
ret = BAD_STATE_E;
}
if (ret == 0) {
@@ -2431,17 +2509,11 @@ int wc_MlKemKey_EncodePrivateKey(MlKemKey* key, unsigned char* out, word32 len)
mlkem_to_bytes(p, key->priv, (int)k);
p += WC_ML_KEM_POLY_SIZE * k;
/* Encode public key. */
ret = wc_KyberKey_EncodePublicKey(key, p, pubLen);
/* Encode public key - calculates hash of public key. */
ret = wc_MlKemKey_EncodePublicKey(key, p, pubLen);
p += pubLen;
}
/* Ensure hash of public key is available. */
if ((ret == 0) && ((key->flags & MLKEM_FLAG_H_SET) == 0)) {
ret = MLKEM_HASH_H(&key->hash, p - pubLen, pubLen, key->h);
}
if (ret == 0) {
/* Public hash is available. */
key->flags |= MLKEM_FLAG_H_SET;
/* Append public hash. */
XMEMCPY(p, key->h, sizeof(key->h));
p += WC_ML_KEM_SYM_SZ;
@@ -2466,11 +2538,12 @@ int wc_MlKemKey_EncodePrivateKey(MlKemKey* key, unsigned char* out, word32 len)
* 19: ek_PKE <- ByteEncode_12(t_hat)||rho
* ...
*
* @param [in] key Kyber key object.
* @param [in] key ML-KEM key object.
* @param [out] out Buffer to hold data.
* @param [in] len Size of buffer in bytes.
* @return 0 on success.
* @return BAD_FUNC_ARG when key or out is NULL or public key not available.
* @return BAD_FUNC_ARG when key or out is NULL.
* @return BAD_STATE_E when public key not available.
* @return NOT_COMPILED_IN when key type is not supported.
*/
int wc_MlKemKey_EncodePublicKey(MlKemKey* key, unsigned char* out, word32 len)
@@ -2485,7 +2558,7 @@ int wc_MlKemKey_EncodePublicKey(MlKemKey* key, unsigned char* out, word32 len)
}
if ((ret == 0) &&
((key->flags & MLKEM_FLAG_PUB_SET) != MLKEM_FLAG_PUB_SET)) {
ret = BAD_FUNC_ARG;
ret = BAD_STATE_E;
}
if (ret == 0) {
+80 -76
View File
@@ -32,9 +32,9 @@
* polynomials.
*/
/* Possible Kyber options:
/* Possible ML-KEM options:
*
* WOLFSSL_HAVE_MLKEM Default: OFF
* WOLFSSL_HAVE_MLKEM Default: OFF
* Enables this code, wolfSSL implementation, to be built.
*
* WOLFSSL_WC_ML_KEM_512 Default: OFF
@@ -112,7 +112,7 @@ static cpuid_flags_t cpuid_flags = WC_CPUID_INITIALIZER;
#define MLKEM_Q_HALF (MLKEM_Q / 2)
/* q^-1 mod 2^16 (inverse of 3329 mod 16384) */
/* q^-1 mod 2^16 (inverse of 3329 mod 65536) */
#define MLKEM_QINV 62209
/* Used in Barrett Reduction:
@@ -1062,7 +1062,7 @@ static void mlkem_basemul(sword16* r, const sword16* a, const sword16* b,
* 1: for (i <- 0; i < 128; i++)
* 2: (h_hat[2i],h_hat[2i+1]) <-
* BaseCaseMultiply(f_hat[2i],f_hat[2i+1],g_hat[2i],g_hat[2i+1],
* zetas^(BitRev_7(i)+1)
* zetas^(BitRev_7(i)+1))
* 3: end for
* 4: return h_hat
*
@@ -1115,7 +1115,7 @@ static void mlkem_basemul_mont(sword16* r, const sword16* a, const sword16* b)
* 1: for (i <- 0; i < 128; i++)
* 2: (h_hat[2i],h_hat[2i+1]) <-
* BaseCaseMultiply(f_hat[2i],f_hat[2i+1],g_hat[2i],g_hat[2i+1],
* zetas^(BitRev_7(i)+1)
* zetas^(BitRev_7(i)+1))
* 3: end for
* 4: return h_hat
* Add h_hat to r.
@@ -1237,7 +1237,7 @@ static void mlkem_pointwise_acc_mont(sword16* r, const sword16* a,
/******************************************************************************/
/* Initialize Kyber implementation.
/* Initialize ML-KEM implementation.
*/
void mlkem_init(void)
{
@@ -1285,7 +1285,7 @@ void mlkem_keygen(sword16* s, sword16* t, sword16* e, const sword16* a, int k)
/* Multiply a by private into public polynomial.
* Step 18: ... A_hat o s_hat ... */
mlkem_pointwise_acc_mont(t + i * MLKEM_N, a + i * k * MLKEM_N, s,
k);
(unsigned int)k);
/* Convert public polynomial to Montgomery form.
* Step 18: ... MontRed(A_hat o s_hat) ... */
mlkem_to_mont_sqrdmlsh(t + i * MLKEM_N);
@@ -1312,7 +1312,7 @@ void mlkem_keygen(sword16* s, sword16* t, sword16* e, const sword16* a, int k)
/* Multiply a by private into public polynomial.
* Step 18: ... A_hat o s_hat ... */
mlkem_pointwise_acc_mont(t + i * MLKEM_N, a + i * k * MLKEM_N, s,
k);
(unsigned int)k);
/* Convert public polynomial to Montgomery form.
* Step 18: ... MontRed(A_hat o s_hat) ... */
mlkem_to_mont(t + i * MLKEM_N);
@@ -1349,7 +1349,7 @@ void mlkem_keygen(sword16* s, sword16* t, sword16* e, const sword16* a, int k)
* @param [in] m Message polynomial.
* @param [in] k Number of polynomials in vector.
*/
void mlkem_encapsulate(const sword16* t, sword16* u , sword16* v,
void mlkem_encapsulate(const sword16* t, sword16* u, sword16* v,
const sword16* a, sword16* y, const sword16* e1, const sword16* e2,
const sword16* m, int k)
{
@@ -1364,25 +1364,25 @@ void mlkem_encapsulate(const sword16* t, sword16* u , sword16* v,
}
/* For each polynomial in the vectors.
* Step 19: u <- InvNTT(A_hat_trans o y_hat) + e_1) */
* Step 19: u <- InvNTT(A_hat_trans o y_hat) + e_1 */
for (i = 0; i < k; ++i) {
/* Multiply at by y into u polynomial.
* Step 19: ... A_hat_trans o y_hat ... */
mlkem_pointwise_acc_mont(u + i * MLKEM_N, a + i * k * MLKEM_N, y,
k);
/* Inverse transform u polynomial.
(unsigned int)k);
/* Inverse transform u polynomial.
* Step 19: ... InvNTT(A_hat_trans o y_hat) ... */
mlkem_invntt_sqrdmlsh(u + i * MLKEM_N);
/* Add errors to u and reduce.
* Step 19: u <- InvNTT(A_hat_trans o y_hat) + e_1) */
mlkem_add_reduce(u + i * MLKEM_N, e1 + i * MLKEM_N);
mlkem_invntt_sqrdmlsh(u + i * MLKEM_N);
/* Add errors to u and reduce.
* Step 19: u <- InvNTT(A_hat_trans o y_hat) + e_1 */
mlkem_add_reduce(u + i * MLKEM_N, e1 + i * MLKEM_N);
}
/* Multiply public key by y into v polynomial.
* Step 21: ... t_hat_trans o y_hat ... */
mlkem_pointwise_acc_mont(v, t, y, k);
mlkem_pointwise_acc_mont(v, t, y, (unsigned int)k);
/* Inverse transform v.
* Step 22: ... InvNTT(t_hat_trans o y_hat) ... */
* Step 21: ... InvNTT(t_hat_trans o y_hat) ... */
mlkem_invntt_sqrdmlsh(v);
}
else
@@ -1400,8 +1400,8 @@ void mlkem_encapsulate(const sword16* t, sword16* u , sword16* v,
/* Multiply at by y into u polynomial.
* Step 19: ... A_hat_trans o y_hat ... */
mlkem_pointwise_acc_mont(u + i * MLKEM_N, a + i * k * MLKEM_N, y,
k);
/* Inverse transform u polynomial.
(unsigned int)k);
/* Inverse transform u polynomial.
* Step 19: ... InvNTT(A_hat_trans o y_hat) ... */
mlkem_invntt(u + i * MLKEM_N);
/* Add errors to u and reduce.
@@ -1411,9 +1411,9 @@ void mlkem_encapsulate(const sword16* t, sword16* u , sword16* v,
/* Multiply public key by y into v polynomial.
* Step 21: ... t_hat_trans o y_hat ... */
mlkem_pointwise_acc_mont(v, t, y, k);
mlkem_pointwise_acc_mont(v, t, y, (unsigned int)k);
/* Inverse transform v.
* Step 22: ... InvNTT(t_hat_trans o y_hat) ... */
* Step 21: ... InvNTT(t_hat_trans o y_hat) ... */
mlkem_invntt(v);
}
/* Add errors and message to v and reduce.
@@ -1452,7 +1452,7 @@ void mlkem_decapsulate(const sword16* s, sword16* w, sword16* u,
/* Multiply private key by u into w polynomial.
* Step 6: ... s_hat_trans o NTT(u') */
mlkem_pointwise_acc_mont(w, s, u, k);
mlkem_pointwise_acc_mont(w, s, u, (unsigned int)k);
/* Inverse transform w.
* Step 6: ... InvNTT(s_hat_trans o NTT(u')) */
mlkem_invntt_sqrdmlsh(w);
@@ -1468,7 +1468,7 @@ void mlkem_decapsulate(const sword16* s, sword16* w, sword16* u,
/* Multiply private key by u into w polynomial.
* Step 6: ... s_hat_trans o NTT(u') */
mlkem_pointwise_acc_mont(w, s, u, k);
mlkem_pointwise_acc_mont(w, s, u, (unsigned int)k);
/* Inverse transform w.
* Step 6: ... InvNTT(s_hat_trans o NTT(u')) */
mlkem_invntt(w);
@@ -1863,7 +1863,7 @@ static void mlkem_keygen_c(sword16* s, sword16* t, sword16* e, const sword16* a,
void mlkem_keygen(sword16* s, sword16* t, sword16* e, const sword16* a, int k)
{
#ifdef USE_INTEL_SPEEDUP
if ((IS_INTEL_AVX2(cpuid_flags)) && (SAVE_VECTOR_REGISTERS2() == 0)) {
if (IS_INTEL_AVX2(cpuid_flags) && (SAVE_VECTOR_REGISTERS2() == 0)) {
/* Alg 13: Steps 16-18 */
mlkem_keygen_avx2(s, t, e, a, k);
RESTORE_VECTOR_REGISTERS();
@@ -1898,7 +1898,11 @@ void mlkem_keygen(sword16* s, sword16* t, sword16* e, const sword16* a, int k)
* @param [in] tv Temporary vector of polynomials.
* @param [in] k Number of polynomials in vector.
* @param [in] rho Random seed to generate matrix A from.
* @param [in] sigma Random seed to generate noise from.
* @param [in, out] sigma Random seed to generate noise from.
* @return 0 on success.
* @return MEMORY_E when dynamic memory allocation fails. Only possible when
* WOLFSSL_SMALL_STACK is defined.
* @return Other negative value when a hash error occurred.
*/
int mlkem_keygen_seeds(sword16* s, sword16* t, MLKEM_PRF_T* prf,
sword16* tv, int k, byte* rho, byte* sigma)
@@ -2087,7 +2091,11 @@ void mlkem_encapsulate(const sword16* pub, sword16* u, sword16* v,
* @param [in] k Number of polynomials in vector.
* @param [in] msg Message to encapsulate.
* @param [in] seed Random seed to generate matrix A from.
* @param [in] coins Random seed to generate noise from.
* @param [in, out] coins Random seed to generate noise from.
* @return 0 on success.
* @return MEMORY_E when dynamic memory allocation fails. Only possible when
* WOLFSSL_SMALL_STACK is defined.
* @return Other negative value when a hash error occurred.
*/
int mlkem_encapsulate_seeds(const sword16* pub, MLKEM_PRF_T* prf, sword16* u,
sword16* tp, sword16* y, int k, const byte* msg, byte* seed, byte* coins)
@@ -2283,7 +2291,7 @@ void mlkem_decapsulate(const sword16* s, sword16* w, sword16* u,
* @param [in] transposed Whether A or A^T is generated.
* @return 0 on success.
* @return MEMORY_E when dynamic memory allocation fails. Only possible when
* WOLFSSL_SMALL_STACK is defined.
* WOLFSSL_SMALL_STACK is defined.
*/
static int mlkem_gen_matrix_k2_avx2(sword16* a, byte* seed, int transposed)
{
@@ -2395,7 +2403,7 @@ static int mlkem_gen_matrix_k2_avx2(sword16* a, byte* seed, int transposed)
* @param [in] transposed Whether A or A^T is generated.
* @return 0 on success.
* @return MEMORY_E when dynamic memory allocation fails. Only possible when
* WOLFSSL_SMALL_STACK is defined.
* WOLFSSL_SMALL_STACK is defined.
*/
static int mlkem_gen_matrix_k3_avx2(sword16* a, byte* seed, int transposed)
{
@@ -2553,7 +2561,7 @@ static int mlkem_gen_matrix_k3_avx2(sword16* a, byte* seed, int transposed)
* @param [in] transposed Whether A or A^T is generated.
* @return 0 on success.
* @return MEMORY_E when dynamic memory allocation fails. Only possible when
* WOLFSSL_SMALL_STACK is defined.
* WOLFSSL_SMALL_STACK is defined.
*/
static int mlkem_gen_matrix_k4_avx2(sword16* a, byte* seed, int transposed)
{
@@ -2665,8 +2673,6 @@ static int mlkem_gen_matrix_k4_avx2(sword16* a, byte* seed, int transposed)
* @param [in] seed Bytes to seed XOF generation.
* @param [in] transposed Whether A or A^T is generated.
* @return 0 on success.
* @return MEMORY_E when dynamic memory allocation fails. Only possible when
* WOLFSSL_SMALL_STACK is defined.
*/
static int mlkem_gen_matrix_k2_aarch64(sword16* a, byte* seed, int transposed)
{
@@ -2739,8 +2745,6 @@ static int mlkem_gen_matrix_k2_aarch64(sword16* a, byte* seed, int transposed)
* @param [in] seed Bytes to seed XOF generation.
* @param [in] transposed Whether A or A^T is generated.
* @return 0 on success.
* @return MEMORY_E when dynamic memory allocation fails. Only possible when
* WOLFSSL_SMALL_STACK is defined.
*/
static int mlkem_gen_matrix_k3_aarch64(sword16* a, byte* seed, int transposed)
{
@@ -2805,8 +2809,6 @@ static int mlkem_gen_matrix_k3_aarch64(sword16* a, byte* seed, int transposed)
* @param [in] seed Bytes to seed XOF generation.
* @param [in] transposed Whether A or A^T is generated.
* @return 0 on success.
* @return MEMORY_E when dynamic memory allocation fails. Only possible when
* WOLFSSL_SMALL_STACK is defined.
*/
static int mlkem_gen_matrix_k4_aarch64(sword16* a, byte* seed, int transposed)
{
@@ -2891,7 +2893,7 @@ static int mlkem_gen_matrix_k4_aarch64(sword16* a, byte* seed, int transposed)
* @param [in] len Length of data to absorb in bytes.
* @return 0 on success always.
*/
static int mlkem_xof_absorb(wc_Shake* shake128, byte* seed, int len)
static int mlkem_xof_absorb(wc_Shake* shake128, const byte* seed, int len)
{
int ret;
@@ -2992,7 +2994,7 @@ int mlkem_hash512(wc_Sha3* hash, const byte* data1, word32 data1Len,
/* Process first block of data. */
ret = wc_Sha3_512_Update(hash, data1, data1Len);
/* Check if there is a second block of data. */
if ((ret == 0) && (data2Len > 0)) {
if ((ret == 0) && (data2 != NULL) && (data2Len > 0)) {
/* Process second block of data. */
ret = wc_Sha3_512_Update(hash, data2, data2Len);
}
@@ -3125,7 +3127,7 @@ static int mlkem_prf(wc_Shake* shake256, byte* out, unsigned int outLen,
* @param [in] outLen Number of bytes to derive.
* @return 0 on success always.
*/
int mlkem_kdf(byte* seed, int seedLen, byte* out, int outLen)
int mlkem_kdf(const byte* seed, int seedLen, byte* out, int outLen)
{
word64 state[25];
word32 len64 = seedLen / 8;
@@ -3163,7 +3165,7 @@ int mlkem_kdf(byte* seed, int seedLen, byte* out, int outLen)
* @param [in] outLen Number of bytes to derive.
* @return 0 on success always.
*/
int mlkem_kdf(byte* seed, int seedLen, byte* out, int outLen)
int mlkem_kdf(const byte* seed, int seedLen, byte* out, int outLen)
{
word64 state[25];
word32 len64 = seedLen / 8;
@@ -3184,41 +3186,41 @@ int mlkem_kdf(byte* seed, int seedLen, byte* out, int outLen)
#ifndef WOLFSSL_NO_ML_KEM
/* Derive the secret from z and cipher text.
*
* @param [in, out] shake256 SHAKE-256 object.
* @param [in] z Implicit rejection value.
* @param [in] ct Cipher text.
* @param [in] ctSz Length of cipher text in bytes.
* @param [out] ss Shared secret.
* @param [in, out] prf SHAKE-256 object.
* @param [in] z Implicit rejection value.
* @param [in] ct Cipher text.
* @param [in] ctSz Length of cipher text in bytes.
* @param [out] ss Shared secret.
* @return 0 on success.
* @return MEMORY_E when dynamic memory allocation failed.
* @return Other negative value when a hash error occurred.
*/
int mlkem_derive_secret(wc_Shake* shake256, const byte* z, const byte* ct,
int mlkem_derive_secret(wc_Shake* prf, const byte* z, const byte* ct,
word32 ctSz, byte* ss)
{
int ret;
#ifdef USE_INTEL_SPEEDUP
XMEMCPY(shake256->t, z, WC_ML_KEM_SYM_SZ);
XMEMCPY(shake256->t + WC_ML_KEM_SYM_SZ, ct,
XMEMCPY(prf->t, z, WC_ML_KEM_SYM_SZ);
XMEMCPY(prf->t + WC_ML_KEM_SYM_SZ, ct,
WC_SHA3_256_COUNT * 8 - WC_ML_KEM_SYM_SZ);
shake256->i = WC_ML_KEM_SYM_SZ + WC_SHA3_256_COUNT * 8 - WC_ML_KEM_SYM_SZ;
prf->i = WC_ML_KEM_SYM_SZ + WC_SHA3_256_COUNT * 8 - WC_ML_KEM_SYM_SZ;
ct += WC_SHA3_256_COUNT * 8 - WC_ML_KEM_SYM_SZ;
ctSz -= WC_SHA3_256_COUNT * 8 - WC_ML_KEM_SYM_SZ;
ret = wc_Shake256_Update(shake256, ct, ctSz);
ret = wc_Shake256_Update(prf, ct, ctSz);
if (ret == 0) {
ret = wc_Shake256_Final(shake256, ss, WC_ML_KEM_SS_SZ);
ret = wc_Shake256_Final(prf, ss, WC_ML_KEM_SS_SZ);
}
#else
ret = wc_InitShake256(shake256, NULL, INVALID_DEVID);
ret = wc_InitShake256(prf, NULL, INVALID_DEVID);
if (ret == 0) {
ret = wc_Shake256_Update(shake256, z, WC_ML_KEM_SYM_SZ);
ret = wc_Shake256_Update(prf, z, WC_ML_KEM_SYM_SZ);
}
if (ret == 0) {
ret = wc_Shake256_Update(shake256, ct, ctSz);
ret = wc_Shake256_Update(prf, ct, ctSz);
}
if (ret == 0) {
ret = wc_Shake256_Final(shake256, ss, WC_ML_KEM_SS_SZ);
ret = wc_Shake256_Final(prf, ss, WC_ML_KEM_SS_SZ);
}
#endif
@@ -3427,7 +3429,7 @@ static unsigned int mlkem_rej_uniform_c(sword16* p, unsigned int len,
* @param [in] transposed Whether A or A^T is generated.
* @return 0 on success.
* @return MEMORY_E when dynamic memory allocation fails. Only possible when
* WOLFSSL_SMALL_STACK is defined.
* WOLFSSL_SMALL_STACK is defined.
*/
static int mlkem_gen_matrix_c(MLKEM_PRF_T* prf, sword16* a, int k, byte* seed,
int transposed)
@@ -3530,7 +3532,7 @@ static int mlkem_gen_matrix_c(MLKEM_PRF_T* prf, sword16* a, int k, byte* seed,
* @param [in] transposed Whether A or A^T is generated.
* @return 0 on success.
* @return MEMORY_E when dynamic memory allocation fails. Only possible when
* WOLFSSL_SMALL_STACK is defined.
* WOLFSSL_SMALL_STACK is defined.
*/
int mlkem_gen_matrix(MLKEM_PRF_T* prf, sword16* a, int k, byte* seed,
int transposed)
@@ -3634,7 +3636,7 @@ int mlkem_gen_matrix(MLKEM_PRF_T* prf, sword16* a, int k, byte* seed,
* @param [in] transposed Whether A or A^T is generated.
* @return 0 on success.
* @return MEMORY_E when dynamic memory allocation fails. Only possible when
* WOLFSSL_SMALL_STACK is defined.
* WOLFSSL_SMALL_STACK is defined.
*/
static int mlkem_gen_matrix_i(MLKEM_PRF_T* prf, sword16* a, int k, byte* seed,
int i, int transposed)
@@ -3729,7 +3731,7 @@ static int mlkem_gen_matrix_i(MLKEM_PRF_T* prf, sword16* a, int k, byte* seed,
*
* @param [in] d Value containing sequential 2 bit values.
* @param [in] i Start index of the two values in 2 bits each.
* @return Difference of the two values with range 0..2.
* @return Difference of the two values with range -2..2.
*/
#define ETA2_SUB(d, i) \
(sword16)(((sword16)(((d) >> ((i) * 4 + 0)) & 0x3)) - \
@@ -3845,7 +3847,7 @@ static void mlkem_cbd_eta2(sword16* p, const byte* r)
*
* @param [in] d Value containing sequential 3 bit values.
* @param [in] i Start index of the two values in 3 bits each.
* @return Difference of the two values with range 0..3.
* @return Difference of the two values with range -3..3.
*/
#define ETA3_SUB(d, i) \
(sword16)(((sword16)(((d) >> ((i) * 6 + 0)) & 0x7)) - \
@@ -4220,6 +4222,8 @@ static void mlkem_get_noise_x4_eta3_avx2(byte* rand, byte* seed)
* @param [out] poly Polynomial.
* @param [in, out] seed Seed to use when calculating random.
* @return 0 on success.
* @return MEMORY_E when dynamic memory allocation fails. Only possible when
* WOLFSSL_SMALL_STACK is defined.
*/
static int mlkem_get_noise_k2_avx2(MLKEM_PRF_T* prf, sword16* vec1,
sword16* vec2, sword16* poly, byte* seed)
@@ -4559,7 +4563,7 @@ static int mlkem_get_noise_k4_aarch64(sword16* vec1, sword16* vec2,
* @param [out] vec2 Second Vector of polynomials.
* @param [in] eta2 Size of noise/error integers with second vector.
* @param [out] poly Polynomial.
* @param [in] seed Seed to use when calculating random.
* @param [in, out] seed Seed to use when calculating random.
* @return 0 on success.
*/
static int mlkem_get_noise_c(MLKEM_PRF_T* prf, int k, sword16* vec1, int eta1,
@@ -4598,7 +4602,7 @@ static int mlkem_get_noise_c(MLKEM_PRF_T* prf, int k, sword16* vec1, int eta1,
return ret;
}
#endif /* __aarch64__ && WOLFSSL_ARMASM */
#endif /* !(__aarch64__ && WOLFSSL_ARMASM) */
/* Get the noise/error by calculating random bytes and sampling to a binomial
* distribution.
@@ -4697,7 +4701,7 @@ int mlkem_get_noise(MLKEM_PRF_T* prf, int k, sword16* vec1, sword16* vec2,
* @param [in, out] prf Pseudo-random function object.
* @param [in] k Number of polynomials in vector.
* @param [out] vec2 Second Vector of polynomials.
* @param [in] seed Seed to use when calculating random.
* @param [in, out] seed Seed to use when calculating random.
* @param [in] i Index of vector to generate.
* @param [in] make Indicates generation is for making a key.
* @return 0 on success.
@@ -5147,8 +5151,8 @@ static void mlkem_vec_compress_11_c(byte* r, sword16* v)
*
* FIPS 203, Section 4.2.1, Compression and decompression
*
* @param [out] r Array of bytes.
* @param [in] v Vector of polynomials.
* @param [out] r Array of bytes.
* @param [in, out] v Vector of polynomials.
*/
void mlkem_vec_compress_11(byte* r, sword16* v)
{
@@ -5839,7 +5843,7 @@ void mlkem_from_msg(sword16* p, const byte* msg)
*
* Uses div operator that may be slow.
*
* FIPS 203, Algorithm 6: ByteEncode_d(F)
* FIPS 203, Algorithm 5: ByteEncode_d(F)
*
* @param [in, out] m Message.
* @param [in] p Polynomial.
@@ -5862,7 +5866,7 @@ void mlkem_from_msg(sword16* p, const byte* msg)
*
* Uses mul instead of div.
*
* FIPS 203, Algorithm 6: ByteEncode_d(F)
* FIPS 203, Algorithm 5: ByteEncode_d(F)
*
* @param [in, out] m Message.
* @param [in] p Polynomial.
@@ -5877,7 +5881,7 @@ void mlkem_from_msg(sword16* p, const byte* msg)
/* Convert polynomial to message.
*
* FIPS 203, Algorithm 6: ByteEncode_d(F)
* FIPS 203, Algorithm 5: ByteEncode_d(F)
*
* @param [out] msg Message as a byte array.
* @param [in, out] p Polynomial.
@@ -5913,7 +5917,7 @@ static void mlkem_to_msg_c(byte* msg, sword16* p)
/* Convert polynomial to message.
*
* FIPS 203, Algorithm 6: ByteEncode_d(F)
* FIPS 203, Algorithm 5: ByteEncode_d(F)
*
* @param [out] msg Message as a byte array.
* @param [in, out] p Polynomial.
@@ -5952,7 +5956,7 @@ void mlkem_from_msg(sword16* p, const byte* msg)
#ifndef WOLFSSL_MLKEM_NO_DECAPSULATE
/* Convert polynomial to message.
*
* FIPS 203, Algorithm 6: ByteEncode_d(F)
* FIPS 203, Algorithm 5: ByteEncode_d(F)
*
* @param [out] msg Message as a byte array.
* @param [in, out] p Polynomial.
@@ -6031,7 +6035,7 @@ void mlkem_from_bytes(sword16* p, const byte* b, int k)
* Consecutive 12 bits hold each coefficient of polynomial.
* Used in encoding private and public keys.
*
* FIPS 203, Algorithm 6: ByteEncode_d(F)
* FIPS 203, Algorithm 5: ByteEncode_d(F)
*
* @param [out] b Array of bytes.
* @param [in, out] p Polynomial.
@@ -6064,7 +6068,7 @@ static void mlkem_to_bytes_c(byte* b, sword16* p, int k)
* Consecutive 12 bits hold each coefficient of polynomial.
* Used in encoding private and public keys.
*
* FIPS 203, Algorithm 6: ByteEncode_d(F)
* FIPS 203, Algorithm 5: ByteEncode_d(F)
*
* @param [out] b Array of bytes.
* @param [in, out] p Polynomial.
@@ -6094,18 +6098,18 @@ void mlkem_to_bytes(byte* b, sword16* p, int k)
/**
* Check the public key values are smaller than the modulus.
*
* @param [in] pub Public key - vector.
* @param [in] k Number of polynomials in vector.
* @param [in] p Public key - vector.
* @param [in] k Number of polynomials in vector.
* @return 0 when all values are in range.
* @return PUBLIC_KEY_E when at least one value is out of range.
*/
int mlkem_check_public(sword16* pub, int k)
int mlkem_check_public(const sword16* p, int k)
{
int ret = 0;
int i;
for (i = 0; i < k * MLKEM_N; i++) {
if (pub[i] >= MLKEM_Q) {
if (p[i] >= MLKEM_Q) {
ret = PUBLIC_KEY_E;
break;
}
+6 -5
View File
@@ -422,7 +422,7 @@ typedef struct MlKemKey {
WOLFSSL_API MlKemKey* wc_MlKemKey_New(int type, void* heap, int devId);
WOLFSSL_API int wc_MlKemKey_Delete(MlKemKey* key, MlKemKey** key_p);
WOLFSSL_API int wc_MlKemKey_Delete(MlKemKey* key, MlKemKey** key_p);
WOLFSSL_API int wc_MlKemKey_Init(MlKemKey* key, int type, void* heap,
int devId);
@@ -522,11 +522,9 @@ int mlkem_get_noise(MLKEM_PRF_T* prf, int kp, sword16* vec1, sword16* vec2,
#if defined(USE_INTEL_SPEEDUP) || \
(defined(WOLFSSL_ARMASM) && defined(__aarch64__))
WOLFSSL_LOCAL
int mlkem_kdf(byte* seed, int seedLen, byte* out, int outLen);
int mlkem_kdf(const byte* seed, int seedLen, byte* out, int outLen);
#endif
WOLFSSL_LOCAL
void mlkem_hash_init(MLKEM_HASH_T* hash);
WOLFSSL_LOCAL
int mlkem_hash_new(MLKEM_HASH_T* hash, void* heap, int devId);
WOLFSSL_LOCAL
void mlkem_hash_free(MLKEM_HASH_T* hash);
@@ -578,7 +576,7 @@ void mlkem_from_bytes(sword16* p, const byte* b, int k);
WOLFSSL_LOCAL
void mlkem_to_bytes(byte* b, sword16* p, int k);
WOLFSSL_LOCAL
int mlkem_check_public(sword16* p, int k);
int mlkem_check_public(const sword16* p, int k);
#ifdef USE_INTEL_SPEEDUP
WOLFSSL_LOCAL
@@ -601,10 +599,13 @@ unsigned int mlkem_rej_uniform_avx2(sword16* p, unsigned int len, const byte* r,
WOLFSSL_LOCAL
void mlkem_redistribute_21_rand_avx2(const word64* s, byte* r0, byte* r1,
byte* r2, byte* r3);
WOLFSSL_LOCAL
void mlkem_redistribute_17_rand_avx2(const word64* s, byte* r0, byte* r1,
byte* r2, byte* r3);
WOLFSSL_LOCAL
void mlkem_redistribute_16_rand_avx2(const word64* s, byte* r0, byte* r1,
byte* r2, byte* r3);
WOLFSSL_LOCAL
void mlkem_redistribute_8_rand_avx2(const word64* s, byte* r0, byte* r1,
byte* r2, byte* r3);