diff --git a/src/ssl_ech.c b/src/ssl_ech.c index 6eccc3f871..a880b34d0c 100644 --- a/src/ssl_ech.c +++ b/src/ssl_ech.c @@ -34,8 +34,8 @@ int wolfSSL_CTX_GenerateEchConfig(WOLFSSL_CTX* ctx, const char* publicName, word16 kemId, word16 kdfId, word16 aeadId) { int ret = 0; - word16 encLen = DHKEM_X25519_ENC_LEN; WOLFSSL_EchConfig* newConfig; + word16 encLen = sizeof(newConfig->receiverPubkey); #ifdef WOLFSSL_SMALL_STACK Hpke* hpke = NULL; WC_RNG* rng; diff --git a/tests/api.c b/tests/api.c index 9e9689b9e8..10aea8d3fd 100644 --- a/tests/api.c +++ b/tests/api.c @@ -14438,6 +14438,9 @@ static byte echCbTestConfigs[512]; static word32 echCbTestConfigsLen; static const char* echCbTestPublicName = "ech-public-name.com"; static const char* echCbTestPrivateName = "ech-private-name.com"; +static word16 echCbTestKemID = 0; +static word16 echCbTestKdfID = 0; +static word16 echCbTestAeadID = 0; /* the arg is whether the client has ech enabled or not */ static int test_ech_server_sni_callback(WOLFSSL* ssl, int* ad, void* arg) @@ -14469,7 +14472,8 @@ static int test_ech_server_ctx_ready(WOLFSSL_CTX* ctx) { int ret; - ret = wolfSSL_CTX_GenerateEchConfig(ctx, echCbTestPublicName, 0, 0, 0); + ret = wolfSSL_CTX_GenerateEchConfig(ctx, echCbTestPublicName, + echCbTestKemID, echCbTestKdfID, echCbTestAeadID); if (ret != WOLFSSL_SUCCESS) return TEST_FAIL; @@ -14512,6 +14516,65 @@ static int test_ech_client_ssl_ready(WOLFSSL* ssl) return TEST_SUCCESS; } +static int test_wolfSSL_Tls13_ECH_all_algos_ex(void) +{ + EXPECT_DECLS; + struct test_ssl_memio_ctx test_ctx; + + XMEMSET(&test_ctx, 0, sizeof(test_ctx)); + + test_ctx.s_cb.method = wolfTLSv1_3_server_method; + test_ctx.c_cb.method = wolfTLSv1_3_client_method; + + test_ctx.s_cb.ctx_ready = test_ech_server_ctx_ready; + test_ctx.s_cb.ssl_ready = test_ech_server_ssl_ready; + test_ctx.c_cb.ssl_ready = test_ech_client_ssl_ready; + + ExpectIntEQ(test_ssl_memio_setup(&test_ctx), TEST_SUCCESS); + + ExpectIntEQ(test_ssl_memio_do_handshake(&test_ctx, 10, NULL), TEST_SUCCESS); + ExpectIntEQ(test_ctx.c_ssl->options.echAccepted, 1); + + test_ssl_memio_cleanup(&test_ctx); + + return EXPECT_RESULT(); +} + +static int test_wolfSSL_Tls13_ECH_all_algos(void) +{ + EXPECT_DECLS; + int i; + int j; + int k; + static const word16 kems[] = { + DHKEM_P256_HKDF_SHA256, + DHKEM_P384_HKDF_SHA384, + DHKEM_P521_HKDF_SHA512, + DHKEM_X25519_HKDF_SHA256, + }; + static const word16 kdfs[] = { HKDF_SHA256, HKDF_SHA384, HKDF_SHA512 }; + static const word16 aeads[] = { HPKE_AES_128_GCM, HPKE_AES_256_GCM }; + + /* test each KEM with default KDF and AEAD */ + for (i = 0; i < (int)(sizeof(kems) / sizeof(*kems)); i++) { + echCbTestKemID = kems[i]; + for (j = 0; j < (int)(sizeof(kdfs) / sizeof(*kdfs)); j++) { + echCbTestKdfID = kdfs[j]; + for (k = 0; k < (int)(sizeof(aeads) / sizeof(*aeads)); k++) { + echCbTestAeadID = aeads[k]; + ExpectIntEQ(test_wolfSSL_Tls13_ECH_all_algos_ex(), + WOLFSSL_SUCCESS); + } + } + } + + echCbTestKemID = 0; + echCbTestKdfID = 0; + echCbTestAeadID = 0; + + return EXPECT_RESULT(); +} + /* Test ECH when no private SNI is set */ static int test_wolfSSL_Tls13_ECH_no_private_name(void) { @@ -34426,6 +34489,7 @@ TEST_CASE testCases[] = { TEST_DECL(test_wolfSSL_SubTls13_ECH), #endif #if defined(HAVE_SSL_MEMIO_TESTS_DEPENDENCIES) + TEST_DECL(test_wolfSSL_Tls13_ECH_all_algos), TEST_DECL(test_wolfSSL_Tls13_ECH_no_private_name), TEST_DECL(test_wolfSSL_Tls13_ECH_bad_configs), TEST_DECL(test_wolfSSL_Tls13_ECH_new_config), diff --git a/wolfcrypt/src/hpke.c b/wolfcrypt/src/hpke.c index 86b740c9d0..cf13e91c8a 100644 --- a/wolfcrypt/src/hpke.c +++ b/wolfcrypt/src/hpke.c @@ -170,7 +170,7 @@ int wc_HpkeInit(Hpke* hpke, int kem, int kdf, int aead, void* heap) case DHKEM_P256_HKDF_SHA256: hpke->curve_id = ECC_SECP256R1; hpke->Nsecret = WC_SHA256_DIGEST_SIZE; - hpke->Nh = WC_SHA256_DIGEST_SIZE; + hpke->kem_digest = WC_SHA256; hpke->Ndh = (word32)wc_ecc_get_curve_size_from_id(hpke->curve_id); hpke->Npk = 1 + hpke->Ndh * 2; break; @@ -180,7 +180,7 @@ int wc_HpkeInit(Hpke* hpke, int kem, int kdf, int aead, void* heap) case DHKEM_P384_HKDF_SHA384: hpke->curve_id = ECC_SECP384R1; hpke->Nsecret = WC_SHA384_DIGEST_SIZE; - hpke->Nh = WC_SHA384_DIGEST_SIZE; + hpke->kem_digest = WC_SHA384; hpke->Ndh = (word32)wc_ecc_get_curve_size_from_id(hpke->curve_id); hpke->Npk = 1 + hpke->Ndh * 2; break; @@ -190,7 +190,7 @@ int wc_HpkeInit(Hpke* hpke, int kem, int kdf, int aead, void* heap) case DHKEM_P521_HKDF_SHA512: hpke->curve_id = ECC_SECP521R1; hpke->Nsecret = WC_SHA512_DIGEST_SIZE; - hpke->Nh = WC_SHA512_DIGEST_SIZE; + hpke->kem_digest = WC_SHA512; hpke->Ndh = (word32)wc_ecc_get_curve_size_from_id(hpke->curve_id); hpke->Npk = 1 + hpke->Ndh * 2; break; @@ -201,7 +201,7 @@ int wc_HpkeInit(Hpke* hpke, int kem, int kdf, int aead, void* heap) (defined(WOLFSSL_SHA224) || !defined(NO_SHA256)) case DHKEM_X25519_HKDF_SHA256: hpke->Nsecret = WC_SHA256_DIGEST_SIZE; - hpke->Nh = WC_SHA256_DIGEST_SIZE; + hpke->kem_digest = WC_SHA256; hpke->Ndh = CURVE25519_KEYSIZE; hpke->Npk = CURVE25519_PUB_KEY_SIZE; break; @@ -211,7 +211,7 @@ int wc_HpkeInit(Hpke* hpke, int kem, int kdf, int aead, void* heap) (defined(WOLFSSL_SHA384) || defined(WOLFSSL_SHA512)) case DHKEM_X448_HKDF_SHA512: hpke->Nsecret = WC_SHA512_DIGEST_SIZE; - hpke->Nh = WC_SHA512_DIGEST_SIZE; + hpke->kem_digest = WC_SHA512; /* size of x448 shared secret */ hpke->Ndh = 64; hpke->Npk = CURVE448_PUB_KEY_SIZE; @@ -228,14 +228,17 @@ int wc_HpkeInit(Hpke* hpke, int kem, int kdf, int aead, void* heap) if (ret == 0) { switch (kdf) { case HKDF_SHA256: + hpke->Nh = WC_SHA256_DIGEST_SIZE; hpke->kdf_digest = WC_SHA256; break; case HKDF_SHA384: + hpke->Nh = WC_SHA384_DIGEST_SIZE; hpke->kdf_digest = WC_SHA384; break; case HKDF_SHA512: + hpke->Nh = WC_SHA512_DIGEST_SIZE; hpke->kdf_digest = WC_SHA512; break; @@ -459,7 +462,7 @@ void wc_HpkeFreeKey(Hpke* hpke, word16 kem, void* keypair, void* heap) } static int wc_HpkeLabeledExtract(Hpke* hpke, byte* suite_id, - word32 suite_id_len, byte* salt, word32 salt_len, byte* label, + word32 suite_id_len, int digest, byte* salt, word32 salt_len, byte* label, word32 label_len, byte* ikm, word32 ikm_len, byte* out) { int ret; @@ -516,7 +519,7 @@ static int wc_HpkeLabeledExtract(Hpke* hpke, byte* suite_id, /* call extract */ PRIVATE_KEY_UNLOCK(); - ret = wc_HKDF_Extract(hpke->kdf_digest, salt, salt_len, labeled_ikm, + ret = wc_HKDF_Extract(digest, salt, salt_len, labeled_ikm, (word32)(size_t)(labeled_ikm_p - labeled_ikm), out); PRIVATE_KEY_LOCK(); @@ -528,8 +531,8 @@ static int wc_HpkeLabeledExtract(Hpke* hpke, byte* suite_id, /* do hkdf expand with the format specified in the hpke rfc, return 0 or * error */ static int wc_HpkeLabeledExpand(Hpke* hpke, byte* suite_id, word32 suite_id_len, - byte* prk, word32 prk_len, byte* label, word32 label_len, byte* info, - word32 infoSz, word32 L, byte* out) + int digest, byte* prk, word32 prk_len, byte* label, word32 label_len, + byte* info, word32 infoSz, word32 L, byte* out) { int ret; byte* labeled_info_p; @@ -592,10 +595,8 @@ static int wc_HpkeLabeledExpand(Hpke* hpke, byte* suite_id, word32 suite_id_len, /* call expand */ PRIVATE_KEY_UNLOCK(); - ret = wc_HKDF_Expand(hpke->kdf_digest, - prk, prk_len, - labeled_info, (word32)(size_t)(labeled_info_p - labeled_info), - out, L); + ret = wc_HKDF_Expand(digest, prk, prk_len, labeled_info, + (word32)(size_t)(labeled_info_p - labeled_info), out, L); PRIVATE_KEY_LOCK(); } @@ -643,15 +644,16 @@ static int wc_HpkeExtractAndExpand( Hpke* hpke, byte* dh, word32 dh_len, /* extract */ ret = wc_HpkeLabeledExtract(hpke, hpke->kem_suite_id, - sizeof( hpke->kem_suite_id ), NULL, 0, (byte*)EAE_PRK_LABEL_STR, - EAE_PRK_LABEL_STR_LEN, dh, dh_len, eae_prk); + sizeof( hpke->kem_suite_id ), hpke->kem_digest, NULL, 0, + (byte*)EAE_PRK_LABEL_STR, EAE_PRK_LABEL_STR_LEN, dh, dh_len, eae_prk); /* expand */ if ( ret == 0 ) { ret = wc_HpkeLabeledExpand(hpke, hpke->kem_suite_id, - sizeof( hpke->kem_suite_id ), eae_prk, hpke->Nh, - (byte*)SHARED_SECRET_LABEL_STR, SHARED_SECRET_LABEL_STR_LEN, - kemContext, kem_context_length, hpke->Nsecret, sharedSecret); + sizeof( hpke->kem_suite_id ), hpke->kem_digest, eae_prk, + hpke->Nsecret, (byte*)SHARED_SECRET_LABEL_STR, + SHARED_SECRET_LABEL_STR_LEN, kemContext, kem_context_length, + hpke->Nsecret, sharedSecret); } ForceZero(eae_prk, WC_MAX_DIGEST_SIZE); @@ -701,35 +703,37 @@ static int wc_HpkeKeyScheduleBase(Hpke* hpke, HpkeBaseContext* context, /* extract psk_id, which for base is null */ ret = wc_HpkeLabeledExtract(hpke, hpke->hpke_suite_id, - sizeof( hpke->hpke_suite_id ), NULL, 0, (byte*)PSK_ID_HASH_LABEL_STR, - PSK_ID_HASH_LABEL_STR_LEN, NULL, 0, key_schedule_context + 1); + sizeof( hpke->hpke_suite_id ), hpke->kdf_digest, NULL, 0, + (byte*)PSK_ID_HASH_LABEL_STR, PSK_ID_HASH_LABEL_STR_LEN, NULL, 0, + key_schedule_context + 1); /* extract info */ if (ret == 0) { ret = wc_HpkeLabeledExtract(hpke, hpke->hpke_suite_id, - sizeof( hpke->hpke_suite_id ), NULL, 0, (byte*)INFO_HASH_LABEL_STR, - INFO_HASH_LABEL_STR_LEN, info, infoSz, + sizeof( hpke->hpke_suite_id ), hpke->kdf_digest, NULL, 0, + (byte*)INFO_HASH_LABEL_STR, INFO_HASH_LABEL_STR_LEN, info, infoSz, key_schedule_context + 1 + hpke->Nh); } /* extract secret */ if (ret == 0) { ret = wc_HpkeLabeledExtract(hpke, hpke->hpke_suite_id, - sizeof( hpke->hpke_suite_id ), sharedSecret, hpke->Nsecret, - (byte*)SECRET_LABEL_STR, SECRET_LABEL_STR_LEN, NULL, 0, secret); + sizeof( hpke->hpke_suite_id ), hpke->kdf_digest, sharedSecret, + hpke->Nsecret, (byte*)SECRET_LABEL_STR, SECRET_LABEL_STR_LEN, + NULL, 0, secret); } /* expand key */ if (ret == 0) ret = wc_HpkeLabeledExpand(hpke, hpke->hpke_suite_id, - sizeof( hpke->hpke_suite_id ), secret, hpke->Nh, + sizeof( hpke->hpke_suite_id ), hpke->kdf_digest, secret, hpke->Nh, (byte*)KEY_LABEL_STR, KEY_LABEL_STR_LEN, key_schedule_context, 1 + 2 * hpke->Nh, hpke->Nk, context->key); /* expand nonce */ if (ret == 0) { ret = wc_HpkeLabeledExpand(hpke, hpke->hpke_suite_id, - sizeof( hpke->hpke_suite_id ), secret, hpke->Nh, + sizeof( hpke->hpke_suite_id ), hpke->kdf_digest, secret, hpke->Nh, (byte*)BASE_NONCE_LABEL_STR, BASE_NONCE_LABEL_STR_LEN, key_schedule_context, 1 + 2 * hpke->Nh, hpke->Nn, context->base_nonce); @@ -738,7 +742,7 @@ static int wc_HpkeKeyScheduleBase(Hpke* hpke, HpkeBaseContext* context, /* expand exporter_secret */ if (ret == 0) { ret = wc_HpkeLabeledExpand(hpke, hpke->hpke_suite_id, - sizeof( hpke->hpke_suite_id ), secret, hpke->Nh, + sizeof( hpke->hpke_suite_id ), hpke->kdf_digest, secret, hpke->Nh, (byte*)EXP_LABEL_STR, EXP_LABEL_STR_LEN, key_schedule_context, 1 + 2 * hpke->Nh, hpke->Nh, context->exporter_secret); } diff --git a/wolfcrypt/test/test.c b/wolfcrypt/test/test.c index cad604a6be..ade14b78f0 100644 --- a/wolfcrypt/test/test.c +++ b/wolfcrypt/test/test.c @@ -32102,9 +32102,24 @@ WOLFSSL_TEST_SUBROUTINE wc_test_ret_t hpke_test(void) ret = hpke_test_multi(hpke); if (ret != 0) return ret; - #endif + #if (defined(WOLFSSL_SHA224) || !defined(NO_SHA256)) && \ + (defined(WOLFSSL_SHA384) || defined(WOLFSSL_SHA512)) + /* p256 with sha512 kdf */ + ret = wc_HpkeInit(hpke, DHKEM_P256_HKDF_SHA256, HKDF_SHA512, + HPKE_AES_128_GCM, NULL); + if (ret != 0) + return WC_TEST_RET_ENC_EC(ret); + ret = hpke_test_single(hpke); + if (ret != 0) + return ret; + ret = hpke_test_multi(hpke); + if (ret != 0) + return ret; + #endif + + #if defined(WOLFSSL_SHA384) && \ (defined(HAVE_ECC384) || defined(HAVE_ALL_CURVES)) /* p384 */ @@ -32134,6 +32149,21 @@ WOLFSSL_TEST_SUBROUTINE wc_test_ret_t hpke_test(void) if (ret != 0) return ret; #endif + + #if defined(WOLFSSL_SHA384) && defined(WOLFSSL_SHA512) && \ + (defined(HAVE_ECC521) || defined(HAVE_ALL_CURVES)) + /* p521 with sha384 kdf */ + ret = wc_HpkeInit(hpke, DHKEM_P521_HKDF_SHA512, HKDF_SHA384, + HPKE_AES_128_GCM, NULL); + if (ret != 0) + return WC_TEST_RET_ENC_EC(ret); + ret = hpke_test_single(hpke); + if (ret != 0) + return ret; + ret = hpke_test_multi(hpke); + if (ret != 0) + return ret; + #endif #endif #if defined(HAVE_CURVE25519) diff --git a/wolfssl/wolfcrypt/hpke.h b/wolfssl/wolfcrypt/hpke.h index 307f46b6ea..d6d82aaf1f 100644 --- a/wolfssl/wolfcrypt/hpke.h +++ b/wolfssl/wolfcrypt/hpke.h @@ -99,6 +99,7 @@ typedef struct { word32 Npk; word32 Nsecret; int kdf_digest; + int kem_digest; int curve_id; byte kem_suite_id[KEM_SUITE_ID_LEN]; byte hpke_suite_id[HPKE_SUITE_ID_LEN];