From 1fd952d6d0bfc552bfed51f14d75974c377935c6 Mon Sep 17 00:00:00 2001 From: John Bland Date: Mon, 10 Mar 2025 09:12:13 -0400 Subject: [PATCH] fix bad ech transaction hash calculations --- src/tls.c | 51 ++++++++++++---------- src/tls13.c | 123 ++++++++++++++++++++++++++++------------------------ 2 files changed, 95 insertions(+), 79 deletions(-) diff --git a/src/tls.c b/src/tls.c index 94eb3f830..18ae2b53f 100644 --- a/src/tls.c +++ b/src/tls.c @@ -14983,7 +14983,9 @@ static int TLSX_GetSizeWithEch(WOLFSSL* ssl, byte* semaphore, byte msgType, echX = TLSX_Find(ssl->ctx->extensions, TLSX_ECH); /* if type is outer change sni to public name */ - if (echX != NULL && ((WOLFSSL_ECH*)echX->data)->type == ECH_TYPE_OUTER) { + if (echX != NULL && ((WOLFSSL_ECH*)echX->data)->type == ECH_TYPE_OUTER && + (ssl->options.echAccepted || + ((WOLFSSL_ECH*)echX->data)->innerCount == 0)) { if (ssl->extensions) { serverNameX = TLSX_Find(ssl->extensions, TLSX_SERVER_NAME); @@ -15190,7 +15192,9 @@ static int TLSX_WriteWithEch(WOLFSSL* ssl, byte* output, byte* semaphore, } /* if type is outer change sni to public name */ - if (echX != NULL && ((WOLFSSL_ECH*)echX->data)->type == ECH_TYPE_OUTER) { + if (echX != NULL && ((WOLFSSL_ECH*)echX->data)->type == ECH_TYPE_OUTER && + (ssl->options.echAccepted || + ((WOLFSSL_ECH*)echX->data)->innerCount == 0)) { if (ssl->extensions) { serverNameX = TLSX_Find(ssl->extensions, TLSX_SERVER_NAME); @@ -15250,31 +15254,34 @@ static int TLSX_WriteWithEch(WOLFSSL* ssl, byte* output, byte* semaphore, msgType, pOffset); } - if (echX != NULL) { - /* turn off and write it last */ - TURN_OFF(semaphore, TLSX_ToSemaphore(echX->type)); - } + /* only write if have a shot at acceptance */ + if (ssl->options.echAccepted || ((WOLFSSL_ECH*)echX->data)->innerCount == 0) { + if (echX != NULL) { + /* turn off and write it last */ + TURN_OFF(semaphore, TLSX_ToSemaphore(echX->type)); + } - if (ret == 0 && ssl->extensions) { - ret = TLSX_Write(ssl->extensions, output + *pOffset, semaphore, - msgType, pOffset); - } + if (ret == 0 && ssl->extensions) { + ret = TLSX_Write(ssl->extensions, output + *pOffset, semaphore, + msgType, pOffset); + } - if (ret == 0 && ssl->ctx && ssl->ctx->extensions) { - ret = TLSX_Write(ssl->ctx->extensions, output + *pOffset, semaphore, - msgType, pOffset); - } + if (ret == 0 && ssl->ctx && ssl->ctx->extensions) { + ret = TLSX_Write(ssl->ctx->extensions, output + *pOffset, semaphore, + msgType, pOffset); + } - if (serverNameX != NULL) { - /* remove the public name SNI */ - TLSX_Remove(extensions, TLSX_SERVER_NAME, ssl->heap); + if (serverNameX != NULL) { + /* remove the public name SNI */ + TLSX_Remove(extensions, TLSX_SERVER_NAME, ssl->heap); - ret = TLSX_UseSNI(extensions, WOLFSSL_SNI_HOST_NAME, tmpServerName, - XSTRLEN(tmpServerName), ssl->heap); + ret = TLSX_UseSNI(extensions, WOLFSSL_SNI_HOST_NAME, tmpServerName, + XSTRLEN(tmpServerName), ssl->heap); - /* restore the inner server name */ - if (ret == WOLFSSL_SUCCESS) - ret = 0; + /* restore the inner server name */ + if (ret == WOLFSSL_SUCCESS) + ret = 0; + } } #ifdef WOLFSSL_SMALL_STACK diff --git a/src/tls13.c b/src/tls13.c index f7c758105..3098f268c 100644 --- a/src/tls13.c +++ b/src/tls13.c @@ -4182,10 +4182,10 @@ static int EchHashHelloInner(WOLFSSL* ssl, WOLFSSL_ECH* ech) return BAD_FUNC_ARG; realSz = ech->innerClientHelloLen - ech->paddingLen - ech->hpke->Nt; tmpHashes = ssl->hsHashes; + ssl->hsHashes = NULL; /* init the ech hashes */ - InitHandshakeHashesAndCopy(ssl, ssl->hsHashes, &ssl->hsHashesEch); - /* swap hsHashes so the regular hash functions work */ - ssl->hsHashes = ssl->hsHashesEch; + InitHandshakeHashes(ssl); + ssl->hsHashesEch = ssl->hsHashes; if (ret == 0) { /* do the handshake header then the body */ AddTls13HandShakeHeader(falseHeader, realSz, 0, 0, client_hello, ssl); @@ -4200,7 +4200,7 @@ static int EchHashHelloInner(WOLFSSL* ssl, WOLFSSL_ECH* ech) ech->innerCount = 1; } else { - /* switch back to primary so we can copy it to inner */ + /* switch back to hsHashes so we have hrr -> echInner2 */ ssl->hsHashes = tmpHashes; InitHandshakeHashesAndCopy(ssl, ssl->hsHashes, &ssl->hsHashesEchInner); @@ -4464,23 +4464,26 @@ int SendTls13ClientHello(WOLFSSL* ssl) if (args->ech == NULL) return WOLFSSL_FATAL_ERROR; - /* set the type to inner */ - args->ech->type = ECH_TYPE_INNER; - args->preXLength = (int)args->length; + /* only prepare if we have a chance at acceptance */ + if (ssl->options.echAccepted || args->ech->innerCount == 0) { + /* set the type to inner */ + args->ech->type = ECH_TYPE_INNER; + args->preXLength = (int)args->length; - /* get size for inner */ - ret = TLSX_GetRequestSize(ssl, client_hello, &args->length); - if (ret != 0) - return ret; + /* get size for inner */ + ret = TLSX_GetRequestSize(ssl, client_hello, &args->length); + if (ret != 0) + return ret; - /* set the type to outer */ - args->ech->type = 0; - /* set innerClientHelloLen to ClientHelloInner + padding + tag */ - args->ech->paddingLen = 31 - ((args->length - 1) % 32); - args->ech->innerClientHelloLen = (word16)(args->length + - args->ech->paddingLen + args->ech->hpke->Nt); - /* set the length back to before we computed ClientHelloInner size */ - args->length = (word32)args->preXLength; + /* set the type to outer */ + args->ech->type = 0; + /* set innerClientHelloLen to ClientHelloInner + padding + tag */ + args->ech->paddingLen = 31 - ((args->length - 1) % 32); + args->ech->innerClientHelloLen = (word16)(args->length + + args->ech->paddingLen + args->ech->hpke->Nt); + /* set the length back to before we computed ClientHelloInner size */ + args->length = (word32)args->preXLength; + } } #endif @@ -4606,7 +4609,8 @@ int SendTls13ClientHello(WOLFSSL* ssl) #if defined(HAVE_ECH) /* write inner then outer */ - if (ssl->options.useEch == 1 && !ssl->options.disableECH) { + if (ssl->options.useEch == 1 && !ssl->options.disableECH && + (ssl->options.echAccepted || args->ech->innerCount == 0)) { /* set the type to inner */ args->ech->type = ECH_TYPE_INNER; /* innerClientHello may already exist from hrr, free if it does */ @@ -4663,7 +4667,8 @@ int SendTls13ClientHello(WOLFSSL* ssl) #if defined(HAVE_ECH) /* encrypt and pack the ech innerClientHello */ - if (ssl->options.useEch == 1 && !ssl->options.disableECH) { + if (ssl->options.useEch == 1 && !ssl->options.disableECH && + (ssl->options.echAccepted || args->ech->innerCount == 0)) { ret = TLSX_FinalizeEch(args->ech, args->output + RECORD_HEADER_SZ + HANDSHAKE_HEADER_SZ, (word32)(args->sendSz - (RECORD_HEADER_SZ + HANDSHAKE_HEADER_SZ))); @@ -4693,7 +4698,8 @@ int SendTls13ClientHello(WOLFSSL* ssl) { #if defined(HAVE_ECH) /* compute the inner hash */ - if (ssl->options.useEch == 1 && !ssl->options.disableECH) + if (ssl->options.useEch == 1 && !ssl->options.disableECH && + (ssl->options.echAccepted || args->ech->innerCount == 0)) ret = EchHashHelloInner(ssl, args->ech); #endif /* compute the outer hash */ @@ -4789,13 +4795,12 @@ static int Dtls13ClientDoDowngrade(WOLFSSL* ssl) /* check if the server accepted ech or not, must be run after an hsHashes * restart */ static int EchCheckAcceptance(WOLFSSL* ssl, byte* label, word16 labelSz, - const byte* input, int acceptOffset, int helloSz, byte msgType) + const byte* input, int acceptOffset, int helloSz) { int ret = 0; int digestType = 0; int digestSize = 0; HS_Hashes* tmpHashes; - HS_Hashes* acceptHashes = NULL; byte zeros[WC_MAX_DIGEST_SIZE]; byte transcriptEchConf[WC_MAX_DIGEST_SIZE]; byte expandLabelPrk[WC_MAX_DIGEST_SIZE]; @@ -4806,14 +4811,10 @@ static int EchCheckAcceptance(WOLFSSL* ssl, byte* label, word16 labelSz, XMEMSET(acceptConfirmation, 0, sizeof(acceptConfirmation)); /* store so we can restore regardless of the outcome */ tmpHashes = ssl->hsHashes; - /* copy ech hashes to accept */ - ret = InitHandshakeHashesAndCopy(ssl, ssl->hsHashesEch, &acceptHashes); - if (ret == 0) { - /* swap hsHashes to acceptHashes */ - ssl->hsHashes = acceptHashes; - /* hash up to the last 8 bytes */ - ret = HashRaw(ssl, input, acceptOffset); - } + /* swap hsHashes to hsHashesEch */ + ssl->hsHashes = ssl->hsHashesEch; + /* hash up to the last 8 bytes */ + ret = HashRaw(ssl, input, acceptOffset); /* hash 8 zeros */ if (ret == 0) ret = HashRaw(ssl, zeros, ECH_ACCEPT_CONFIRMATION_SZ); @@ -4903,17 +4904,17 @@ static int EchCheckAcceptance(WOLFSSL* ssl, byte* label, word16 labelSz, else { /* set echAccepted to 0, needed in case HRR */ ssl->options.echAccepted = 0; + /* free inner since we're continuing with outer */ + ssl->hsHashes = ssl->hsHashesEchInner; + FreeHandshakeHashes(ssl); + ssl->hsHashesEchInner = NULL; } /* continue with outer if we failed to verify ech was accepted */ ret = 0; } - /* free hsHashesEch */ - if (ssl->options.echAccepted == 0 || msgType != hello_retry_request) { - /* free hsHashesEch */ - FreeHandshakeHashes(ssl); - /* set hsHashesEch to NULL to avoid double free */ - ssl->hsHashesEch = NULL; - } + FreeHandshakeHashes(ssl); + /* set hsHashesEch to NULL to avoid double free */ + ssl->hsHashesEch = NULL; /* swap to tmp, will be inner if accepted, hsHashes if rejected */ ssl->hsHashes = tmpHashes; return ret; @@ -4928,7 +4929,6 @@ static int EchWriteAcceptance(WOLFSSL* ssl, byte* label, word16 labelSz, int digestType = 0; int digestSize = 0; HS_Hashes* tmpHashes = NULL; - HS_Hashes* acceptHashes = NULL; byte zeros[WC_MAX_DIGEST_SIZE]; byte transcriptEchConf[WC_MAX_DIGEST_SIZE]; byte expandLabelPrk[WC_MAX_DIGEST_SIZE]; @@ -4937,14 +4937,9 @@ static int EchWriteAcceptance(WOLFSSL* ssl, byte* label, word16 labelSz, XMEMSET(expandLabelPrk, 0, sizeof(expandLabelPrk)); /* store so we can restore regardless of the outcome */ tmpHashes = ssl->hsHashes; - /* copy ech hashes to accept */ - ret = InitHandshakeHashesAndCopy(ssl, ssl->hsHashesEch, &acceptHashes); - if (ret == 0) { - /* swap hsHashes to acceptHashes */ - ssl->hsHashes = acceptHashes; - /* hash up to the acceptOffset */ - ret = HashRaw(ssl, output, acceptOffset); - } + ssl->hsHashes = ssl->hsHashesEch; + /* hash up to the acceptOffset */ + ret = HashRaw(ssl, output, acceptOffset); /* hash 8 zeros */ if (ret == 0) ret = HashRaw(ssl, zeros, ECH_ACCEPT_CONFIRMATION_SZ); @@ -5015,12 +5010,12 @@ static int EchWriteAcceptance(WOLFSSL* ssl, byte* label, word16 labelSz, PRIVATE_KEY_LOCK(); } if (ret == 0) { - /* free hsHashesEch if this is the last ech involved message */ - if (msgType != hello_retry_request) { - FreeHandshakeHashes(ssl); - ssl->hsHashesEch = NULL; + /* free hsHashesEch, if this is an HRR we will start at client hello 2*/ + FreeHandshakeHashes(ssl); + ssl->hsHashesEch = NULL; + /* mark that ech was accepted */ + if (msgType != hello_retry_request) ssl->options.echAccepted = 1; - } } ssl->hsHashes = tmpHashes; return ret; @@ -5560,8 +5555,7 @@ int DoTls13ServerHello(WOLFSSL* ssl, const byte* input, word32* inOutIdx, /* check acceptance */ if (ret == 0) { ret = EchCheckAcceptance(ssl, args->acceptLabel, - args->acceptLabelSz, input, args->acceptOffset, helloSz, - args->extMsgType); + args->acceptLabelSz, input, args->acceptOffset, helloSz); } if (ret != 0) return ret; @@ -6741,6 +6735,7 @@ int DoTls13ClientHello(WOLFSSL* ssl, const byte* input, word32* inOutIdx, #endif #if defined(HAVE_ECH) TLSX* echX = NULL; + HS_Hashes* tmpHashes; #endif WOLFSSL_START(WC_FUNC_CLIENT_HELLO_DO); @@ -7064,6 +7059,22 @@ int DoTls13ClientHello(WOLFSSL* ssl, const byte* input, word32* inOutIdx, } #endif +#if defined(HAVE_ECH) + /* hash clientHelloInner to hsHashesEch independently since it can't include + * the HRR */ + if (!ssl->options.disableECH) { + tmpHashes = ssl->hsHashes; + ssl->hsHashes = NULL; + ret = InitHandshakeHashes(ssl); + if (ret != 0) + goto exit_dch; + if ((ret = HashInput(ssl, input + args->begin, (int)helloSz)) != 0) + goto exit_dch; + ssl->hsHashesEch = ssl->hsHashes; + ssl->hsHashes = tmpHashes; + } +#endif + #if (defined(HAVE_SESSION_TICKET) || !defined(NO_PSK)) && \ defined(HAVE_TLS_EXTENSIONS) ret = CheckPreSharedKeys(ssl, input + args->begin, helloSz, ssl->clSuites, @@ -7481,8 +7492,6 @@ int SendTls13ServerHello(WOLFSSL* ssl, byte extMsgType) } /* replace the last 8 bytes of server random with the accept */ if (((WOLFSSL_ECH*)echX->data)->state == ECH_PARSED_INTERNAL) { - ret = InitHandshakeHashesAndCopy(ssl, ssl->hsHashes, - &ssl->hsHashesEch); if (ret == 0) { ret = EchWriteAcceptance(ssl, acceptLabel, acceptLabelSz, output + RECORD_HEADER_SZ,