diff --git a/wolfcrypt/src/wc_mlkem.c b/wolfcrypt/src/wc_mlkem.c index b7bc4a7c5..d9372f5f7 100644 --- a/wolfcrypt/src/wc_mlkem.c +++ b/wolfcrypt/src/wc_mlkem.c @@ -105,6 +105,42 @@ #ifdef WOLFSSL_WC_MLKEM +#ifdef DEBUG_MLKEM +void print_polys(const char* name, const sword16* a, int d1, int d2); +void print_polys(const char* name, const sword16* a, int d1, int d2) +{ + int i; + int j; + int k; + + fprintf(stderr, "%s: %d %d\n", name, d1, d2); + for (i = 0; i < d1; i++) { + for (j = 0; j < d2; j++) { + for (k = 0; k < 256; k++) { + fprintf(stderr, "%9d,", a[(i*d2*256) + (j*256) + k]); + if ((k % 8) == 7) fprintf(stderr, "\n"); + } + fprintf(stderr, "\n"); + } + } +} +#endif + +#ifdef DEBUG_MLKEM +void print_data(const char* name, const byte* d, int len); +void print_data(const char* name, const byte* d, int len) +{ + int i; + + fprintf(stderr, "%s\n", name); + for (i = 0; i < len; i++) { + fprintf(stderr, "0x%02x,", d[i]); + if ((i % 16) == 15) fprintf(stderr, "\n"); + } + fprintf(stderr, "\n"); +} +#endif + /******************************************************************************/ /* Use SHA3-256 to generate 32-bytes of hash. */ diff --git a/wolfcrypt/src/wc_mlkem_poly.c b/wolfcrypt/src/wc_mlkem_poly.c index 615f31cd1..b13d9305a 100644 --- a/wolfcrypt/src/wc_mlkem_poly.c +++ b/wolfcrypt/src/wc_mlkem_poly.c @@ -3184,8 +3184,9 @@ int mlkem_derive_secret(wc_Shake* shake256, const byte* z, const byte* ct, #ifdef USE_INTEL_SPEEDUP XMEMCPY(shake256->t, z, WC_ML_KEM_SYM_SZ); - XMEMCPY(shake256->t, ct, WC_SHA3_256_COUNT * 8 - WC_ML_KEM_SYM_SZ); - shake256->i = WC_ML_KEM_SYM_SZ; + XMEMCPY(shake256->t + WC_ML_KEM_SYM_SZ, ct, + WC_SHA3_256_COUNT * 8 - WC_ML_KEM_SYM_SZ); + shake256->i = WC_ML_KEM_SYM_SZ + WC_SHA3_256_COUNT * 8 - WC_ML_KEM_SYM_SZ; ct += WC_SHA3_256_COUNT * 8 - WC_ML_KEM_SYM_SZ; ctSz -= WC_SHA3_256_COUNT * 8 - WC_ML_KEM_SYM_SZ; ret = wc_Shake256_Update(shake256, ct, ctSz); @@ -3193,7 +3194,10 @@ int mlkem_derive_secret(wc_Shake* shake256, const byte* z, const byte* ct, ret = wc_Shake256_Final(shake256, ss, WC_ML_KEM_SS_SZ); } #else - ret = wc_Shake256_Update(shake256, z, WC_ML_KEM_SYM_SZ); + ret = wc_InitShake256(shake256, NULL, INVALID_DEVID); + if (ret == 0) { + ret = wc_Shake256_Update(shake256, z, WC_ML_KEM_SYM_SZ); + } if (ret == 0) { ret = wc_Shake256_Update(shake256, ct, ctSz); }