From a999909969e794966ca8ad6641c0ec8ee291dbef Mon Sep 17 00:00:00 2001 From: Juliusz Sosinowicz Date: Tue, 3 Jan 2023 17:05:12 +0100 Subject: [PATCH] Use PSK callback to get the ciphersuite to use - Allocate additional byte in TLSX_PreSharedKey_New for null terminator --- src/dtls.c | 105 ++++++++++++++-------- src/tls.c | 219 ++++++++++++++++++++++++--------------------- src/tls13.c | 103 +++++++++++---------- wolfssl/internal.h | 13 ++- 4 files changed, 255 insertions(+), 185 deletions(-) diff --git a/src/dtls.c b/src/dtls.c index 6fa691bba..95ec7f517 100644 --- a/src/dtls.c +++ b/src/dtls.c @@ -435,6 +435,38 @@ static int CopyExtensions(TLSX* src, TLSX** dst, void* heap) } #endif +#if defined(WOLFSSL_DTLS13) && !defined(NO_PSK) +/* Very simplified version of CheckPreSharedKeys to find the current suite */ +static void FindPskSuiteFromExt(const WOLFSSL* ssl, TLSX* extensions, + PskInfo* pskInfo, Suites* suites) +{ + TLSX* pskExt = TLSX_Find(extensions, TLSX_PRE_SHARED_KEY); + int found = 0; + PreSharedKey* current; + byte psk_key[MAX_PSK_KEY_LEN]; + word32 psk_keySz; + int i; + + if (pskExt == NULL) + return; + + for (i = 0; i < suites->suiteSz; i += 2) { + for (current = (PreSharedKey*)pskExt->data; current != NULL; + current = current->next) { + if (FindPskSuite(ssl, current, psk_key, &psk_keySz, + suites->suites + i, &found) == 0) { + if (found) { + pskInfo->cipherSuite0 = suites->suites[i]; + pskInfo->cipherSuite = suites->suites[i + 1]; + pskInfo->isValid = 1; + return; + } + } + } + } +} +#endif + static int SendStatelessReply(const WOLFSSL* ssl, WolfSSL_CH* ch, byte isTls13, PskInfo* pskInfo) { @@ -550,39 +582,46 @@ static int SendStatelessReply(const WOLFSSL* ssl, WolfSSL_CH* ch, byte isTls13, * and if they don't match we will error out there anyway. */ byte modes; - ret = TlsxFindByType(&tlsx, TLSX_PSK_KEY_EXCHANGE_MODES, - ch->extension); - if (ret != 0) - goto dtls13_cleanup; - if (tlsx.size == 0) - ERROR_OUT(MISSING_HANDSHAKE_DATA, dtls13_cleanup); - ret = TLSX_PskKeyModes_Parse_Modes(tlsx.elements, tlsx.size, - client_hello, &modes); - if (ret != 0) - goto dtls13_cleanup; - if ((modes & (1 << PSK_DHE_KE)) && !ssl->options.noPskDheKe) { - if (!haveKS) +#ifndef NO_PSK + /* When we didn't find a valid ticket ask the user for the + * ciphersuite matching this identity */ + if (!pskInfo->isValid) { + if (TLSX_PreSharedKey_Parse_ClientHello(&parsedExts, + tlsx.elements, tlsx.size, ssl->heap) == 0) + FindPskSuiteFromExt(ssl, parsedExts, pskInfo, &suites); + /* Revert to full handshake if PSK parsing failed */ + } +#endif + + if (pskInfo->isValid) { + ret = TlsxFindByType(&tlsx, TLSX_PSK_KEY_EXCHANGE_MODES, + ch->extension); + if (ret != 0) + goto dtls13_cleanup; + if (tlsx.size == 0) ERROR_OUT(MISSING_HANDSHAKE_DATA, dtls13_cleanup); - doKE = 1; + ret = TLSX_PskKeyModes_Parse_Modes(tlsx.elements, tlsx.size, + client_hello, &modes); + if (ret != 0) + goto dtls13_cleanup; + if ((modes & (1 << PSK_DHE_KE)) && + !ssl->options.noPskDheKe) { + if (!haveKS) + ERROR_OUT(MISSING_HANDSHAKE_DATA, dtls13_cleanup); + doKE = 1; + } + else if ((modes & (1 << PSK_KE)) == 0) { + ERROR_OUT(PSK_KEY_ERROR, dtls13_cleanup); + } + usePSK = 1; } - else if ((modes & (1 << PSK_KE)) == 0) { - ERROR_OUT(PSK_KEY_ERROR, dtls13_cleanup); - } - usePSK = 1; } #endif #if defined(HAVE_SESSION_TICKET) || !defined(NO_PSK) - if (usePSK) { - if (pskInfo->isValid) { - cs.cipherSuite0 = pskInfo->cipherSuite0; - cs.cipherSuite = pskInfo->cipherSuite; - } - else { - /* Only support the default ciphersuite for PSK */ - cs.cipherSuite0 = TLS13_BYTE; - cs.cipherSuite = WOLFSSL_DEF_PSK_CIPHER; - } + if (usePSK && pskInfo->isValid) { + cs.cipherSuite0 = pskInfo->cipherSuite0; + cs.cipherSuite = pskInfo->cipherSuite; if (doKE) { byte searched = 0; @@ -609,15 +648,9 @@ static int SendStatelessReply(const WOLFSSL* ssl, WolfSSL_CH* ch, byte isTls13, if (ret != 0) goto dtls13_cleanup; } - - /* Need to remove the keyshare ext if we are not doing PSK and we - * found a common group. */ - if (cs.clientKSE != NULL -#if defined(HAVE_SESSION_TICKET) || !defined(NO_PSK) - && !usePSK -#endif - ) - { + else { + /* Need to remove the keyshare ext if we found a common group + * and are not doing curve negotiation. */ TLSX_Remove(&parsedExts, TLSX_KEY_SHARE, ssl->heap); } diff --git a/src/tls.c b/src/tls.c index 815da07cf..6f15edbf9 100644 --- a/src/tls.c +++ b/src/tls.c @@ -9503,6 +9503,100 @@ static int TLSX_PreSharedKey_Write(PreSharedKey* list, byte* output, return 0; } +int TLSX_PreSharedKey_Parse_ClientHello(TLSX** extensions, const byte* input, + word16 length, void* heap) +{ + + int ret; + word16 len; + word16 idx = 0; + TLSX* extension; + PreSharedKey* list; + + TLSX_Remove(extensions, TLSX_PRE_SHARED_KEY, heap); + + /* Length of identities and of binders. */ + if ((int)(length - idx) < OPAQUE16_LEN + OPAQUE16_LEN) + return BUFFER_E; + + /* Length of identities. */ + ato16(input + idx, &len); + idx += OPAQUE16_LEN; + if (len < MIN_PSK_ID_LEN || length - idx < len) + return BUFFER_E; + + /* Create a pre-shared key object for each identity. */ + while (len > 0) { + const byte* identity; + word16 identityLen; + word32 age; + + if (len < OPAQUE16_LEN) + return BUFFER_E; + + /* Length of identity. */ + ato16(input + idx, &identityLen); + idx += OPAQUE16_LEN; + if (len < OPAQUE16_LEN + identityLen + OPAQUE32_LEN || + identityLen > MAX_PSK_ID_LEN) + return BUFFER_E; + /* Cache identity pointer. */ + identity = input + idx; + idx += identityLen; + /* Ticket age. */ + ato32(input + idx, &age); + idx += OPAQUE32_LEN; + + ret = TLSX_PreSharedKey_Use(extensions, identity, identityLen, age, no_mac, + 0, 0, 1, NULL, heap); + if (ret != 0) + return ret; + + /* Done with this identity. */ + len -= OPAQUE16_LEN + identityLen + OPAQUE32_LEN; + } + + /* Find the list of identities sent to server. */ + extension = TLSX_Find(*extensions, TLSX_PRE_SHARED_KEY); + if (extension == NULL) + return PSK_KEY_ERROR; + list = (PreSharedKey*)extension->data; + + /* Length of binders. */ + if (idx + OPAQUE16_LEN > length) + return BUFFER_E; + ato16(input + idx, &len); + idx += OPAQUE16_LEN; + if (len < MIN_PSK_BINDERS_LEN || length - idx < len) + return BUFFER_E; + + /* Set binder for each identity. */ + while (list != NULL && len > 0) { + /* Length of binder */ + list->binderLen = input[idx++]; + if (list->binderLen < WC_SHA256_DIGEST_SIZE || + list->binderLen > WC_MAX_DIGEST_SIZE) + return BUFFER_E; + if (len < OPAQUE8_LEN + list->binderLen) + return BUFFER_E; + + /* Copy binder into static buffer. */ + XMEMCPY(list->binder, input + idx, list->binderLen); + idx += (word16)list->binderLen; + + /* Done with binder entry. */ + len -= OPAQUE8_LEN + (word16)list->binderLen; + + /* Next identity. */ + list = list->next; + } + if (list != NULL || len != 0) + return BUFFER_E; + + return 0; + +} + /* Parse the pre-shared key extension. * Different formats in different messages. * @@ -9519,91 +9613,8 @@ static int TLSX_PreSharedKey_Parse(WOLFSSL* ssl, const byte* input, PreSharedKey* list; if (msgType == client_hello) { - int ret; - word16 len; - word16 idx = 0; - - TLSX_Remove(&ssl->extensions, TLSX_PRE_SHARED_KEY, ssl->heap); - - /* Length of identities and of binders. */ - if ((int)(length - idx) < OPAQUE16_LEN + OPAQUE16_LEN) - return BUFFER_E; - - /* Length of identities. */ - ato16(input + idx, &len); - idx += OPAQUE16_LEN; - if (len < MIN_PSK_ID_LEN || length - idx < len) - return BUFFER_E; - - /* Create a pre-shared key object for each identity. */ - while (len > 0) { - const byte* identity; - word16 identityLen; - word32 age; - - if (len < OPAQUE16_LEN) - return BUFFER_E; - - /* Length of identity. */ - ato16(input + idx, &identityLen); - idx += OPAQUE16_LEN; - if (len < OPAQUE16_LEN + identityLen + OPAQUE32_LEN || - identityLen > MAX_PSK_ID_LEN) - return BUFFER_E; - /* Cache identity pointer. */ - identity = input + idx; - idx += identityLen; - /* Ticket age. */ - ato32(input + idx, &age); - idx += OPAQUE32_LEN; - - ret = TLSX_PreSharedKey_Use(ssl, identity, identityLen, age, no_mac, - 0, 0, 1, NULL); - if (ret != 0) - return ret; - - /* Done with this identity. */ - len -= OPAQUE16_LEN + identityLen + OPAQUE32_LEN; - } - - /* Find the list of identities sent to server. */ - extension = TLSX_Find(ssl->extensions, TLSX_PRE_SHARED_KEY); - if (extension == NULL) - return PSK_KEY_ERROR; - list = (PreSharedKey*)extension->data; - - /* Length of binders. */ - if (idx + OPAQUE16_LEN > length) - return BUFFER_E; - ato16(input + idx, &len); - idx += OPAQUE16_LEN; - if (len < MIN_PSK_BINDERS_LEN || length - idx < len) - return BUFFER_E; - - /* Set binder for each identity. */ - while (list != NULL && len > 0) { - /* Length of binder */ - list->binderLen = input[idx++]; - if (list->binderLen < WC_SHA256_DIGEST_SIZE || - list->binderLen > WC_MAX_DIGEST_SIZE) - return BUFFER_E; - if (len < OPAQUE8_LEN + list->binderLen) - return BUFFER_E; - - /* Copy binder into static buffer. */ - XMEMCPY(list->binder, input + idx, list->binderLen); - idx += (word16)list->binderLen; - - /* Done with binder entry. */ - len -= OPAQUE8_LEN + (word16)list->binderLen; - - /* Next identity. */ - list = list->next; - } - if (list != NULL || len != 0) - return BUFFER_E; - - return 0; + return TLSX_PreSharedKey_Parse_ClientHello(&ssl->extensions, input, + length, ssl->heap); } if (msgType == server_hello) { @@ -9675,13 +9686,16 @@ static int TLSX_PreSharedKey_New(PreSharedKey** list, const byte* identity, XMEMSET(psk, 0, sizeof(*psk)); /* Make a copy of the identity data. */ - psk->identity = (byte*)XMALLOC(len, heap, DYNAMIC_TYPE_TLSX); + psk->identity = (byte*)XMALLOC(len + NULL_TERM_LEN, heap, + DYNAMIC_TYPE_TLSX); if (psk->identity == NULL) { XFREE(psk, heap, DYNAMIC_TYPE_TLSX); return MEMORY_E; } XMEMCPY(psk->identity, identity, len); psk->identityLen = len; + /* Use a NULL terminator in case it is a C string */ + psk->identity[psk->identityLen] = '\0'; /* Add it to the end and maintain the links. */ while (*list != NULL) { @@ -9729,24 +9743,24 @@ static WC_INLINE byte GetHmacLength(int hmac) * preSharedKey The new pre-shared key object. * returns 0 on success and other values indicate failure. */ -int TLSX_PreSharedKey_Use(WOLFSSL* ssl, const byte* identity, word16 len, +int TLSX_PreSharedKey_Use(TLSX** extensions, const byte* identity, word16 len, word32 age, byte hmac, byte cipherSuite0, byte cipherSuite, byte resumption, - PreSharedKey **preSharedKey) + PreSharedKey **preSharedKey, void* heap) { int ret = 0; TLSX* extension; PreSharedKey* psk = NULL; /* Find the pre-shared key extension if it exists. */ - extension = TLSX_Find(ssl->extensions, TLSX_PRE_SHARED_KEY); + extension = TLSX_Find(*extensions, TLSX_PRE_SHARED_KEY); if (extension == NULL) { /* Push new pre-shared key extension. */ - ret = TLSX_Push(&ssl->extensions, TLSX_PRE_SHARED_KEY, NULL, ssl->heap); + ret = TLSX_Push(extensions, TLSX_PRE_SHARED_KEY, NULL, heap); if (ret != 0) return ret; - extension = TLSX_Find(ssl->extensions, TLSX_PRE_SHARED_KEY); + extension = TLSX_Find(*extensions, TLSX_PRE_SHARED_KEY); if (extension == NULL) return MEMORY_E; } @@ -9764,7 +9778,7 @@ int TLSX_PreSharedKey_Use(WOLFSSL* ssl, const byte* identity, word16 len, /* Create a new pre-shared key object if not found. */ if (psk == NULL) { ret = TLSX_PreSharedKey_New((PreSharedKey**)&extension->data, identity, - len, ssl->heap, &psk); + len, heap, &psk); if (ret != 0) return ret; } @@ -12026,17 +12040,18 @@ int TLSX_PopulateExtensions(WOLFSSL* ssl, byte isServer) milli += sess->ticketAdd; /* Pre-shared key is mandatory extension for resumption. */ - ret = TLSX_PreSharedKey_Use(ssl, sess->ticket, sess->ticketLen, - milli, ssl->specs.mac_algorithm, ssl->options.cipherSuite0, - ssl->options.cipherSuite, 1, NULL); + ret = TLSX_PreSharedKey_Use(&ssl->extensions, sess->ticket, + sess->ticketLen, milli, ssl->specs.mac_algorithm, + ssl->options.cipherSuite0, ssl->options.cipherSuite, 1, + NULL, ssl->heap); #else milli = now - sess->ticketSeen + sess->ticketAdd; /* Pre-shared key is mandatory extension for resumption. */ - ret = TLSX_PreSharedKey_Use(ssl, sess->ticket, sess->ticketLen, - (word32)milli, ssl->specs.mac_algorithm, + ret = TLSX_PreSharedKey_Use(&ssl->extensions, sess->ticket, + sess->ticketLen, (word32)milli, ssl->specs.mac_algorithm, ssl->options.cipherSuite0, ssl->options.cipherSuite, 1, - NULL); + NULL, ssl->heap); #endif if (ret != 0) return ret; @@ -12083,11 +12098,11 @@ int TLSX_PopulateExtensions(WOLFSSL* ssl, byte isServer) GetCipherNameInternal(cipherSuite0, cipherSuite)); if (keySz > 0) { ssl->arrays->psk_keySz = keySz; - ret = TLSX_PreSharedKey_Use(ssl, + ret = TLSX_PreSharedKey_Use(&ssl->extensions, (byte*)ssl->arrays->client_identity, (word16)XSTRLEN(ssl->arrays->client_identity), 0, SuiteMac(WOLFSSL_SUITES(ssl)->suites + i), - cipherSuite0, cipherSuite, 0, NULL); + cipherSuite0, cipherSuite, 0, NULL, ssl->heap); if (ret != 0) return ret; #ifdef WOLFSSL_PSK_MULTI_ID_PER_CS @@ -12150,12 +12165,12 @@ int TLSX_PopulateExtensions(WOLFSSL* ssl, byte isServer) if (ret != 0) return ret; - ret = TLSX_PreSharedKey_Use(ssl, + ret = TLSX_PreSharedKey_Use(&ssl->extensions, (byte*)ssl->arrays->client_identity, (word16)XSTRLEN(ssl->arrays->client_identity), 0, ssl->specs.mac_algorithm, cipherSuite0, cipherSuite, 0, - NULL); + NULL, ssl->heap); if (ret != 0) return ret; diff --git a/src/tls13.c b/src/tls13.c index fd418f7d1..c5c61f5f3 100644 --- a/src/tls13.c +++ b/src/tls13.c @@ -5496,6 +5496,60 @@ static void RefineSuites(WOLFSSL* ssl, Suites* peerSuites) #ifndef NO_PSK +int FindPskSuite(const WOLFSSL* ssl, PreSharedKey* psk, byte* psk_key, + word32* psk_keySz, byte* suite, int* found) +{ + const char* cipherName = NULL; + byte cipherSuite0 = TLS13_BYTE; + byte cipherSuite = WOLFSSL_DEF_PSK_CIPHER; + int ret = 0; + + *found = 0; + (void)suite; + + if (ssl->options.server_psk_tls13_cb != NULL) { + *psk_keySz = ssl->options.server_psk_tls13_cb((WOLFSSL*)ssl, + (char*)psk->identity, psk_key, MAX_PSK_KEY_LEN, &cipherName); + if (*psk_keySz != 0) { + int cipherSuiteFlags = WOLFSSL_CIPHER_SUITE_FLAG_NONE; + *found = (GetCipherSuiteFromName(cipherName, &cipherSuite0, + &cipherSuite, &cipherSuiteFlags) == 0); + (void)cipherSuiteFlags; + } + } + if (*found == 0 && (ssl->options.server_psk_cb != NULL)) { + *psk_keySz = ssl->options.server_psk_cb((WOLFSSL*)ssl, + (char*)psk->identity, psk_key, + MAX_PSK_KEY_LEN); + *found = (*psk_keySz != 0); + } + if (*found) { + if (*psk_keySz > MAX_PSK_KEY_LEN) { + WOLFSSL_MSG("Key len too long in FindPsk()"); + ret = PSK_KEY_ERROR; + WOLFSSL_ERROR_VERBOSE(ret); + } + if (ret == 0) { + #if !defined(WOLFSSL_PSK_ONE_ID) && !defined(WOLFSSL_PRIORITIZE_PSK) + /* Check whether PSK ciphersuite is in SSL. */ + *found = (suite[0] == cipherSuite0) && (suite[1] == cipherSuite); + #else + (void)suite; + /* Check whether PSK ciphersuite is in SSL. */ + { + byte s[2] = { + cipherSuite0, + cipherSuite, + }; + *found = FindSuiteSSL(ssl, s); + } + #endif + } + } + + return ret; +} + /* Attempt to find the PSK (not session ticket) that matches. * * @param [in, out] ssl The SSL/TLS object. @@ -5509,55 +5563,14 @@ static void RefineSuites(WOLFSSL* ssl, Suites* peerSuites) * @return 1 when a match found - but check error code. * @return 0 when no match found. */ -static int FindPsk(WOLFSSL* ssl, PreSharedKey* psk, const byte* suite, int* err) +static int FindPsk(WOLFSSL* ssl, PreSharedKey* psk, byte* suite, int* err) { int ret = 0; int found = 0; - const char* cipherName = NULL; - byte cipherSuite0 = TLS13_BYTE; - byte cipherSuite = WOLFSSL_DEF_PSK_CIPHER; - Arrays* sa = ssl->arrays; - (void)suite; - - if (ssl->options.server_psk_tls13_cb != NULL) { - sa->psk_keySz = ssl->options.server_psk_tls13_cb(ssl, - sa->client_identity, sa->psk_key, MAX_PSK_KEY_LEN, &cipherName); - if (sa->psk_keySz != 0) { - int cipherSuiteFlags = WOLFSSL_CIPHER_SUITE_FLAG_NONE; - found = (GetCipherSuiteFromName(cipherName, &cipherSuite0, - &cipherSuite, &cipherSuiteFlags) == 0); - (void)cipherSuiteFlags; - } - } - if (!found && (ssl->options.server_psk_cb != NULL)) { - sa->psk_keySz = ssl->options.server_psk_cb(ssl, - sa->client_identity, sa->psk_key, - MAX_PSK_KEY_LEN); - found = (sa->psk_keySz != 0); - } - if (found) { - if (sa->psk_keySz > MAX_PSK_KEY_LEN) { - WOLFSSL_MSG("Key len too long in FindPsk"); - ret = PSK_KEY_ERROR; - WOLFSSL_ERROR_VERBOSE(ret); - } - if (ret == 0) { - #if !defined(WOLFSSL_PSK_ONE_ID) && !defined(WOLFSSL_PRIORITIZE_PSK) - /* Check whether PSK ciphersuite is in SSL. */ - found = (suite[0] == cipherSuite0) && (suite[1] == cipherSuite); - #else - (void)suite; - /* Check whether PSK ciphersuite is in SSL. */ - { - byte s[2] = { - cipherSuite0, - cipherSuite, - }; - found = FindSuiteSSL(ssl, s); - } - #endif - } + ret = FindPskSuite(ssl, psk, ssl->arrays->psk_key, &ssl->arrays->psk_keySz, + suite, &found); + if (ret == 0 && found) { if ((ret == 0) && found) { /* Default to ciphersuite if cb doesn't specify. */ ssl->options.resuming = 0; diff --git a/wolfssl/internal.h b/wolfssl/internal.h index 5bb6c3c31..c7bc10c57 100644 --- a/wolfssl/internal.h +++ b/wolfssl/internal.h @@ -3115,11 +3115,14 @@ WOLFSSL_LOCAL int TLSX_PreSharedKey_WriteBinders(PreSharedKey* list, word16* pSz); WOLFSSL_LOCAL int TLSX_PreSharedKey_GetSizeBinders(PreSharedKey* list, byte msgType, word16* pSz); -WOLFSSL_LOCAL int TLSX_PreSharedKey_Use(WOLFSSL* ssl, const byte* identity, +WOLFSSL_LOCAL int TLSX_PreSharedKey_Use(TLSX** extensions, const byte* identity, word16 len, word32 age, byte hmac, byte cipherSuite0, byte cipherSuite, byte resumption, - PreSharedKey **preSharedKey); + PreSharedKey **preSharedKey, + void* heap); +WOLFSSL_LOCAL int TLSX_PreSharedKey_Parse_ClientHello(TLSX** extensions, + const byte* input, word16 length, void* heap); /* The possible Pre-Shared Key key exchange modes. */ enum PskKeyExchangeMode { @@ -6204,6 +6207,12 @@ WOLFSSL_LOCAL int wolfSSL_quic_keys_active(WOLFSSL* ssl, enum encrypt_side side) #define WOLFSSL_IS_QUIC(s) 0 #endif /* WOLFSSL_QUIC (else) */ + +#ifndef NO_PSK +WOLFSSL_LOCAL int FindPskSuite(const WOLFSSL* ssl, PreSharedKey* psk, + byte* psk_key, word32* psk_keySz, byte* suite, int* found); +#endif + #ifdef __cplusplus } /* extern "C" */ #endif