From 9253d1d3acdd51f6e7479b530002f255315f4f4b Mon Sep 17 00:00:00 2001 From: Sean Parkinson Date: Tue, 11 Feb 2025 09:19:16 +1000 Subject: [PATCH] ML-KEM/Kyber: cache A from key generation for decapsulation Matrix A is expensive to calculate. Typical usage of ML-KEM/Kyber is: 1. The first peer generates a key pair and sends the public key to the second peer. 2. The second peer encapsulates a secret with the public key and sends the ciphertext to the first peer. 3. The first peer decapsulates with the key from key generation, which includes re-encapsulating to check that the result matches the ciphertext received. Caching A at key generation keeps the matrix available for the encapsulation step of decapsulation. The matrix needs to be transposed for encapsulation. --- configure.ac | 3 ++ wolfcrypt/benchmark/benchmark.c | 40 ++++++++++++++---- wolfcrypt/src/wc_kyber.c | 74 +++++++++++++++++++++++++-------- wolfssl/wolfcrypt/wc_kyber.h | 5 +++ 4 files changed, 96 insertions(+), 26 deletions(-) diff --git a/configure.ac b/configure.ac index 56293fb69..ca095d480 100644 --- a/configure.ac +++ b/configure.ac @@ -1399,6 +1399,9 @@ do small) AM_CFLAGS="$AM_CFLAGS -DWOLFSSL_KYBER_SMALL" ;; + cache-a) + AM_CFLAGS="$AM_CFLAGS -DWOLFSSL_MLKEM_CACHE_A" + ;; 512) ENABLED_KYBER512=yes ;; diff --git a/wolfcrypt/benchmark/benchmark.c b/wolfcrypt/benchmark/benchmark.c index 5a8306f90..0d6f05522 100644 --- a/wolfcrypt/benchmark/benchmark.c +++ b/wolfcrypt/benchmark/benchmark.c @@ -9630,17 +9630,37 @@ exit: #endif } -static void bench_kyber_encap(const char* name, int keySize, KyberKey* key) +static void bench_kyber_encap(int type, const char* name, int keySize, + KyberKey* key1, KyberKey* key2) { int ret = 0, times, count, pending = 0; double start; const char**desc = bench_desc_words[lng_index]; byte ct[KYBER_MAX_CIPHER_TEXT_SIZE]; byte ss[KYBER_SS_SZ]; + byte pub[KYBER_MAX_PUBLIC_KEY_SIZE]; + word32 pubLen; word32 ctSz; DECLARE_MULTI_VALUE_STATS_VARS() - ret = wc_KyberKey_CipherTextSize(key, &ctSz); + ret = wc_KyberKey_PublicKeySize(key1, &pubLen); + if (ret != 0) { + return; + } + ret = 
wc_KyberKey_EncodePublicKey(key1, pub, pubLen); + if (ret != 0) { + return; + } + ret = wc_KyberKey_Init(type, key2, HEAP_HINT, INVALID_DEVID); + if (ret != 0) { + return; + } + ret = wc_KyberKey_DecodePublicKey(key2, pub, pubLen); + if (ret != 0) { + return; + } + + ret = wc_KyberKey_CipherTextSize(key2, &ctSz); if (ret != 0) { return; } @@ -9651,10 +9671,10 @@ static void bench_kyber_encap(const char* name, int keySize, KyberKey* key) /* while free pending slots in queue, submit ops */ for (times = 0; times < agreeTimes || pending > 0; times++) { #ifdef KYBER_NONDETERMINISTIC - ret = wc_KyberKey_Encapsulate(key, ct, ss, &gRng); + ret = wc_KyberKey_Encapsulate(key2, ct, ss, &gRng); #else unsigned char rand[KYBER_ENC_RAND_SZ] = {0,}; - ret = wc_KyberKey_EncapsulateWithRandom(key, ct, ss, rand, + ret = wc_KyberKey_EncapsulateWithRandom(key2, ct, ss, rand, sizeof(rand)); #endif if (ret != 0) @@ -9681,7 +9701,7 @@ exit_encap: do { /* while free pending slots in queue, submit ops */ for (times = 0; times < agreeTimes || pending > 0; times++) { - ret = wc_KyberKey_Decapsulate(key, ss, ct, ctSz); + ret = wc_KyberKey_Decapsulate(key1, ss, ct, ctSz); if (ret != 0) goto exit_decap; RECORD_MULTI_VALUE_STATS(); @@ -9702,7 +9722,8 @@ exit_decap: void bench_kyber(int type) { - KyberKey key; + KyberKey key1; + KyberKey key2; const char* name = NULL; int keySize = 0; @@ -9749,10 +9770,11 @@ void bench_kyber(int type) #endif } - bench_kyber_keygen(type, name, keySize, &key); - bench_kyber_encap(name, keySize, &key); + bench_kyber_keygen(type, name, keySize, &key1); + bench_kyber_encap(type, name, keySize, &key1, &key2); - wc_KyberKey_Free(&key); + wc_KyberKey_Free(&key2); + wc_KyberKey_Free(&key1); } #endif diff --git a/wolfcrypt/src/wc_kyber.c b/wolfcrypt/src/wc_kyber.c index 658d73a8d..d09fc7422 100644 --- a/wolfcrypt/src/wc_kyber.c +++ b/wolfcrypt/src/wc_kyber.c @@ -63,6 +63,12 @@ #error "Can't use small memory with assembly optimized code" #endif #endif +#if 
defined(WOLFSSL_MLKEM_CACHE_A) + #if defined(WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM) || \ + defined(WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM) + #error "Can't cache A with small memory code" + #endif +#endif #ifdef WOLFSSL_WC_KYBER @@ -265,10 +271,14 @@ int wc_KyberKey_MakeKeyWithRandom(KyberKey* key, const unsigned char* rand, sword16* e = NULL; #else #ifndef WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM +#ifndef WOLFSSL_MLKEM_CACHE_A sword16 e[(KYBER_MAX_K + 1) * KYBER_MAX_K * KYBER_N]; #else sword16 e[KYBER_MAX_K * KYBER_N]; #endif +#else + sword16 e[KYBER_MAX_K * KYBER_N]; +#endif #endif #ifndef WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM sword16* a = NULL; @@ -285,6 +295,8 @@ int wc_KyberKey_MakeKeyWithRandom(KyberKey* key, const unsigned char* rand, } if (ret == 0) { + key->flags = 0; + /* Establish parameters based on key type. */ switch (key->type) { #ifndef WOLFSSL_NO_ML_KEM @@ -332,9 +344,17 @@ int wc_KyberKey_MakeKeyWithRandom(KyberKey* key, const unsigned char* rand, if (ret == 0) { /* Allocate dynamic memory for matrix and error vector. */ #ifndef WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM +#ifndef WOLFSSL_MLKEM_CACHE_A + /* e (v) | a (m) */ e = (sword16*)XMALLOC((kp + 1) * kp * KYBER_N * sizeof(sword16), key->heap, DYNAMIC_TYPE_TMP_BUFFER); #else + /* e (v) */ + e = (sword16*)XMALLOC(kp * KYBER_N * sizeof(sword16), + key->heap, DYNAMIC_TYPE_TMP_BUFFER); +#endif +#else + /* e (v) */ e = (sword16*)XMALLOC(kp * KYBER_N * sizeof(sword16), key->heap, DYNAMIC_TYPE_TMP_BUFFER); #endif @@ -346,8 +366,10 @@ int wc_KyberKey_MakeKeyWithRandom(KyberKey* key, const unsigned char* rand, if (ret == 0) { const byte* d = rand; -#ifndef WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM - /* Error vector allocated at end of a. */ +#ifdef WOLFSSL_MLKEM_CACHE_A + a = key->a; +#elif !defined(WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM) + /* Matrix A allocated at end of error vector. 
*/ a = e + (kp * KYBER_N); #endif @@ -391,6 +413,9 @@ int wc_KyberKey_MakeKeyWithRandom(KyberKey* key, const unsigned char* rand, ret = kyber_gen_matrix(&key->prf, a, kp, pubSeed, 0); } if (ret == 0) { +#ifdef WOLFSSL_MLKEM_CACHE_A + key->flags |= KYBER_FLAG_A_SET; +#endif /* Generate key pair from random data. */ kyber_keygen(key->priv, key->pub, e, a, kp); #else @@ -514,7 +539,7 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins, unsigned char* ct) { int ret = 0; - sword16* sp = NULL; + sword16* at = NULL; #ifndef WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM sword16* k = NULL; sword16* ep = NULL; @@ -523,12 +548,12 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins, unsigned int kp = 0; unsigned int compVecSz = 0; #ifndef WOLFSSL_NO_MALLOC - sword16* at = NULL; + sword16* sp = NULL; #else #ifndef WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM - sword16 at[((KYBER_MAX_K + 3) * KYBER_MAX_K + 3) * KYBER_N]; + sword16 sp[((KYBER_MAX_K + 3) * KYBER_MAX_K + 3) * KYBER_N]; #else - sword16 at[3 * KYBER_MAX_K * KYBER_N]; + sword16 sp[3 * KYBER_MAX_K * KYBER_N]; #endif #endif #ifdef WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM @@ -588,13 +613,13 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins, if (ret == 0) { /* Allocate dynamic memory for all matrices, vectors and polynomials. 
*/ #ifndef WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM - at = (sword16*)XMALLOC(((kp + 3) * kp + 3) * KYBER_N * sizeof(sword16), + sp = (sword16*)XMALLOC(((kp + 3) * kp + 3) * KYBER_N * sizeof(sword16), key->heap, DYNAMIC_TYPE_TMP_BUFFER); #else - at = (sword16*)XMALLOC(3 * kp * KYBER_N * sizeof(sword16), key->heap, + sp = (sword16*)XMALLOC(3 * kp * KYBER_N * sizeof(sword16), key->heap, DYNAMIC_TYPE_TMP_BUFFER); #endif - if (at == NULL) { + if (sp == NULL) { ret = MEMORY_E; } } @@ -603,15 +628,15 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins, if (ret == 0) { #ifndef WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM /* Assign allocated dynamic memory to pointers. - * at (m) | k (p) | sp (v) | ep (p) | epp (v) | bp (v) | v (p) */ + * sp (v) | at (m) | k (p) | ep (v) | epp (p) | bp (v) | v (p) */ + at = sp + KYBER_N * kp; k = at + KYBER_N * kp * kp; - sp = k + KYBER_N; - ep = sp + KYBER_N * kp; + ep = k + KYBER_N; epp = ep + KYBER_N * kp; #else /* Assign allocated dynamic memory to pointers. - * at (v) | sp (v) | bp (v) */ - sp = at + KYBER_N * kp; + * sp (v) | at (v) | bp (v) */ + at = sp + KYBER_N * kp; #endif /* Initialize the PRF for use in the noise generation. */ @@ -623,6 +648,21 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins, /* Generate noise using PRF. */ ret = kyber_get_noise(&key->prf, kp, sp, ep, epp, coins); } +#ifdef WOLFSSL_MLKEM_CACHE_A + if ((ret == 0) && ((key->flags & KYBER_FLAG_A_SET) != 0)) { + unsigned int i; + /* Transpose matrix. */ + for (i = 0; i < kp; i++) { + unsigned int j; + for (j = 0; j < kp; j++) { + XMEMCPY(&at[(i * kp + j) * KYBER_N], + &key->a[(j * kp + i) * KYBER_N], + KYBER_N * 2); + } + } + } + else +#endif if (ret == 0) { /* Generate the transposed matrix. */ ret = kyber_gen_matrix(&key->prf, at, kp, key->pubSeed, 1); @@ -632,7 +672,7 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins, sword16* v; /* Assign remaining allocated dynamic memory to pointers. 
- * at (m) | k (p) | sp (v) | ep (p) | epp (v) | bp (v) | v (p)*/ + * sp (v) | at (m) | k (p) | ep (v) | epp (p) | bp (v) | v (p)*/ bp = epp + KYBER_N; v = bp + KYBER_N * kp; @@ -644,7 +684,7 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins, } if (ret == 0) { /* Assign remaining allocated dynamic memory to pointers. - * at (v) | sp (v) | bp (v) */ + * sp (v) | at (v) | bp (v) */ bp = sp + KYBER_N * kp; v = at; @@ -676,7 +716,7 @@ static int kyberkey_encapsulate(KyberKey* key, const byte* msg, byte* coins, #ifndef WOLFSSL_NO_MALLOC /* Dispose of dynamic memory allocated in function. */ - XFREE(at, key->heap, DYNAMIC_TYPE_TMP_BUFFER); + XFREE(sp, key->heap, DYNAMIC_TYPE_TMP_BUFFER); #endif return ret; diff --git a/wolfssl/wolfcrypt/wc_kyber.h b/wolfssl/wolfcrypt/wc_kyber.h index 4909f7b73..9b9163d0e 100644 --- a/wolfssl/wolfcrypt/wc_kyber.h +++ b/wolfssl/wolfcrypt/wc_kyber.h @@ -62,6 +62,7 @@ enum { KYBER_FLAG_PUB_SET = 0x0002, KYBER_FLAG_BOTH_SET = 0x0003, KYBER_FLAG_H_SET = 0x0004, + KYBER_FLAG_A_SET = 0x0008, /* 2 bits of random used to create noise value. */ KYBER_CBD_ETA2 = 2, @@ -137,6 +138,10 @@ struct KyberKey { byte h[KYBER_SYM_SZ]; /* Randomizer for decapsulation. */ byte z[KYBER_SYM_SZ]; +#ifdef WOLFSSL_MLKEM_CACHE_A + /* A matrix from key generation. */ + sword16 a[KYBER_MAX_K * KYBER_MAX_K * KYBER_N]; +#endif }; #ifdef __cplusplus