ML-KEM/Kyber: cache A from key generation for decapsulation

Matrix A is expensive to calculate.
Usage of ML-KEM/Kyber is
  1. First peer generates a key and sends public to second peer.
2. Second peer encapsulates secret with public key and sends to first
peer.
3. First peer decapsulates (including encapsulating to ensure same as
seen) with key from key generation.
Caching A keeps the matrix A for encapsulation part of decapsulation.
The matrix needs to be transposed for encapsulation.
This commit is contained in:
Sean Parkinson
2025-02-11 09:19:16 +10:00
parent 4373e551e7
commit 9253d1d3ac
4 changed files with 96 additions and 26 deletions

View File

@@ -1399,6 +1399,9 @@ do
small)
AM_CFLAGS="$AM_CFLAGS -DWOLFSSL_KYBER_SMALL"
;;
cache-a)
AM_CFLAGS="$AM_CFLAGS -DWOLFSSL_MLKEM_CACHE_A"
;;
512)
ENABLED_KYBER512=yes
;;

View File

@@ -9630,17 +9630,37 @@ exit:
#endif
}
static void bench_kyber_encap(const char* name, int keySize, KyberKey* key)
static void bench_kyber_encap(int type, const char* name, int keySize,
KyberKey* key1, KyberKey* key2)
{
int ret = 0, times, count, pending = 0;
double start;
const char**desc = bench_desc_words[lng_index];
byte ct[KYBER_MAX_CIPHER_TEXT_SIZE];
byte ss[KYBER_SS_SZ];
byte pub[KYBER_MAX_PUBLIC_KEY_SIZE];
word32 pubLen;
word32 ctSz;
DECLARE_MULTI_VALUE_STATS_VARS()
ret = wc_KyberKey_CipherTextSize(key, &ctSz);
ret = wc_KyberKey_PublicKeySize(key1, &pubLen);
if (ret != 0) {
return;
}
ret = wc_KyberKey_EncodePublicKey(key1, pub, pubLen);
if (ret != 0) {
return;
}
ret = wc_KyberKey_Init(type, key2, HEAP_HINT, INVALID_DEVID);
if (ret != 0) {
return;
}
ret = wc_KyberKey_DecodePublicKey(key2, pub, pubLen);
if (ret != 0) {
return;
}
ret = wc_KyberKey_CipherTextSize(key2, &ctSz);
if (ret != 0) {
return;
}
@@ -9651,10 +9671,10 @@ static void bench_kyber_encap(const char* name, int keySize, KyberKey* key)
/* while free pending slots in queue, submit ops */
for (times = 0; times < agreeTimes || pending > 0; times++) {
#ifdef KYBER_NONDETERMINISTIC
ret = wc_KyberKey_Encapsulate(key, ct, ss, &gRng);
ret = wc_KyberKey_Encapsulate(key2, ct, ss, &gRng);
#else
unsigned char rand[KYBER_ENC_RAND_SZ] = {0,};
ret = wc_KyberKey_EncapsulateWithRandom(key, ct, ss, rand,
ret = wc_KyberKey_EncapsulateWithRandom(key2, ct, ss, rand,
sizeof(rand));
#endif
if (ret != 0)
@@ -9681,7 +9701,7 @@ exit_encap:
do {
/* while free pending slots in queue, submit ops */
for (times = 0; times < agreeTimes || pending > 0; times++) {
ret = wc_KyberKey_Decapsulate(key, ss, ct, ctSz);
ret = wc_KyberKey_Decapsulate(key1, ss, ct, ctSz);
if (ret != 0)
goto exit_decap;
RECORD_MULTI_VALUE_STATS();
@@ -9702,7 +9722,8 @@ exit_decap:
void bench_kyber(int type)
{
KyberKey key;
KyberKey key1;
KyberKey key2;
const char* name = NULL;
int keySize = 0;
@@ -9749,10 +9770,11 @@ void bench_kyber(int type)
#endif
}
bench_kyber_keygen(type, name, keySize, &key);
bench_kyber_encap(name, keySize, &key);
bench_kyber_keygen(type, name, keySize, &key1);
bench_kyber_encap(type, name, keySize, &key1, &key2);
wc_KyberKey_Free(&key);
wc_KyberKey_Free(&key2);
wc_KyberKey_Free(&key1);
}
#endif

View File

@@ -63,6 +63,12 @@
#error "Can't use small memory with assembly optimized code"
#endif
#endif
#if defined(WOLFSSL_MLKEM_CACHE_A)
#if defined(WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM) || \
defined(WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM)
#error "Can't cache A with small memory code"
#endif
#endif
#ifdef WOLFSSL_WC_KYBER
@@ -265,10 +271,14 @@ int wc_KyberKey_MakeKeyWithRandom(KyberKey* key, const unsigned char* rand,
sword16* e = NULL;
#else
#ifndef WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM
#ifndef WOLFSSL_MLKEM_CACHE_A
sword16 e[(KYBER_MAX_K + 1) * KYBER_MAX_K * KYBER_N];
#else
sword16 e[KYBER_MAX_K * KYBER_N];
#endif
#else
sword16 e[KYBER_MAX_K * KYBER_N];
#endif
#endif
#ifndef WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM
sword16* a = NULL;
@@ -285,6 +295,8 @@ int wc_KyberKey_MakeKeyWithRandom(KyberKey* key, const unsigned char* rand,
}
if (ret == 0) {
key->flags = 0;
/* Establish parameters based on key type. */
switch (key->type) {
#ifndef WOLFSSL_NO_ML_KEM
@@ -332,9 +344,17 @@ int wc_KyberKey_MakeKeyWithRandom(KyberKey* key, const unsigned char* rand,
if (ret == 0) {
/* Allocate dynamic memory for matrix and error vector. */
#ifndef WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM
#ifndef WOLFSSL_MLKEM_CACHE_A
/* e (v) | a (m) */
e = (sword16*)XMALLOC((kp + 1) * kp * KYBER_N * sizeof(sword16),
key->heap, DYNAMIC_TYPE_TMP_BUFFER);
#else
/* e (v) */
e = (sword16*)XMALLOC(kp * KYBER_N * sizeof(sword16),
key->heap, DYNAMIC_TYPE_TMP_BUFFER);
#endif
#else
/* e (v) */
e = (sword16*)XMALLOC(kp * KYBER_N * sizeof(sword16),
key->heap, DYNAMIC_TYPE_TMP_BUFFER);
#endif
@@ -346,8 +366,10 @@ int wc_KyberKey_MakeKeyWithRandom(KyberKey* key, const unsigned char* rand,
if (ret == 0) {
const byte* d = rand;
#ifndef WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM
/* Error vector allocated at end of a. */
#ifdef WOLFSSL_MLKEM_CACHE_A
a = key->a;
#elif !defined(WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM)
/* Matrix A allocated at end of error vector. */
a = e + (kp * KYBER_N);
#endif
@@ -391,6 +413,9 @@ int wc_KyberKey_MakeKeyWithRandom(KyberKey* key, const unsigned char* rand,
ret = kyber_gen_matrix(&key->prf, a, kp, pubSeed, 0);
}
if (ret == 0) {
#ifdef WOLFSSL_MLKEM_CACHE_A
key->flags |= KYBER_FLAG_A_SET;
#endif
/* Generate key pair from random data. */
kyber_keygen(key->priv, key->pub, e, a, kp);
#else
@@ -514,7 +539,7 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins,
unsigned char* ct)
{
int ret = 0;
sword16* sp = NULL;
sword16* at = NULL;
#ifndef WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM
sword16* k = NULL;
sword16* ep = NULL;
@@ -523,12 +548,12 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins,
unsigned int kp = 0;
unsigned int compVecSz = 0;
#ifndef WOLFSSL_NO_MALLOC
sword16* at = NULL;
sword16* sp = NULL;
#else
#ifndef WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM
sword16 at[((KYBER_MAX_K + 3) * KYBER_MAX_K + 3) * KYBER_N];
sword16 sp[((KYBER_MAX_K + 3) * KYBER_MAX_K + 3) * KYBER_N];
#else
sword16 at[3 * KYBER_MAX_K * KYBER_N];
sword16 sp[3 * KYBER_MAX_K * KYBER_N];
#endif
#endif
#ifdef WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM
@@ -588,13 +613,13 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins,
if (ret == 0) {
/* Allocate dynamic memory for all matrices, vectors and polynomials. */
#ifndef WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM
at = (sword16*)XMALLOC(((kp + 3) * kp + 3) * KYBER_N * sizeof(sword16),
sp = (sword16*)XMALLOC(((kp + 3) * kp + 3) * KYBER_N * sizeof(sword16),
key->heap, DYNAMIC_TYPE_TMP_BUFFER);
#else
at = (sword16*)XMALLOC(3 * kp * KYBER_N * sizeof(sword16), key->heap,
sp = (sword16*)XMALLOC(3 * kp * KYBER_N * sizeof(sword16), key->heap,
DYNAMIC_TYPE_TMP_BUFFER);
#endif
if (at == NULL) {
if (sp == NULL) {
ret = MEMORY_E;
}
}
@@ -603,15 +628,15 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins,
if (ret == 0) {
#ifndef WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM
/* Assign allocated dynamic memory to pointers.
* at (m) | k (p) | sp (v) | ep (p) | epp (v) | bp (v) | v (p) */
* sp (b) | at (m) | k (p) | ep (p) | epp (v) | bp (v) | v (p) */
at = sp + KYBER_N * kp;
k = at + KYBER_N * kp * kp;
sp = k + KYBER_N;
ep = sp + KYBER_N * kp;
ep = k + KYBER_N;
epp = ep + KYBER_N * kp;
#else
/* Assign allocated dynamic memory to pointers.
* at (v) | sp (v) | bp (v) */
sp = at + KYBER_N * kp;
* sp (v) | at (v) | bp (v) */
at = sp + KYBER_N * kp;
#endif
/* Initialize the PRF for use in the noise generation. */
@@ -623,6 +648,21 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins,
/* Generate noise using PRF. */
ret = kyber_get_noise(&key->prf, kp, sp, ep, epp, coins);
}
#ifdef WOLFSSL_MLKEM_CACHE_A
if ((ret == 0) && ((key->flags & KYBER_FLAG_A_SET) != 0)) {
unsigned int i;
/* Transpose matrix. */
for (i = 0; i < kp; i++) {
unsigned int j;
for (j = 0; j < kp; j++) {
XMEMCPY(&at[(i * kp + j) * KYBER_N],
&key->a[(j * kp + i) * KYBER_N],
KYBER_N * 2);
}
}
}
else
#endif
if (ret == 0) {
/* Generate the transposed matrix. */
ret = kyber_gen_matrix(&key->prf, at, kp, key->pubSeed, 1);
@@ -632,7 +672,7 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins,
sword16* v;
/* Assign remaining allocated dynamic memory to pointers.
* at (m) | k (p) | sp (v) | ep (p) | epp (v) | bp (v) | v (p)*/
* sp (v) | at (m) | k (p) | ep (p) | epp (v) | bp (v) | v (p)*/
bp = epp + KYBER_N;
v = bp + KYBER_N * kp;
@@ -644,7 +684,7 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins,
}
if (ret == 0) {
/* Assign remaining allocated dynamic memory to pointers.
* at (v) | sp (v) | bp (v) */
* sp (v) | at (v) | bp (v) */
bp = sp + KYBER_N * kp;
v = at;
@@ -676,7 +716,7 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins,
#ifndef WOLFSSL_NO_MALLOC
/* Dispose of dynamic memory allocated in function. */
XFREE(at, key->heap, DYNAMIC_TYPE_TMP_BUFFER);
XFREE(sp, key->heap, DYNAMIC_TYPE_TMP_BUFFER);
#endif
return ret;

View File

@@ -62,6 +62,7 @@ enum {
KYBER_FLAG_PUB_SET = 0x0002,
KYBER_FLAG_BOTH_SET = 0x0003,
KYBER_FLAG_H_SET = 0x0004,
KYBER_FLAG_A_SET = 0x0008,
/* 2 bits of random used to create noise value. */
KYBER_CBD_ETA2 = 2,
@@ -137,6 +138,10 @@ struct KyberKey {
byte h[KYBER_SYM_SZ];
/* Randomizer for decapsulation. */
byte z[KYBER_SYM_SZ];
#ifdef WOLFSSL_MLKEM_CACHE_A
/* A matrix from key generation. */
sword16 a[KYBER_MAX_K * KYBER_MAX_K * KYBER_N];
#endif
};
#ifdef __cplusplus