Use PSK callback to get the ciphersuite to use

- Allocate additional byte in TLSX_PreSharedKey_New for null terminator
This commit is contained in:
Juliusz Sosinowicz
2023-01-03 17:05:12 +01:00
parent 6160f93f94
commit a999909969
4 changed files with 255 additions and 185 deletions

View File

@ -435,6 +435,38 @@ static int CopyExtensions(TLSX* src, TLSX** dst, void* heap)
}
#endif
#if defined(WOLFSSL_DTLS13) && !defined(NO_PSK)
/* Very simplified version of CheckPreSharedKeys to find the current suite */
static void FindPskSuiteFromExt(const WOLFSSL* ssl, TLSX* extensions,
PskInfo* pskInfo, Suites* suites)
{
TLSX* pskExt = TLSX_Find(extensions, TLSX_PRE_SHARED_KEY);
int found = 0;
PreSharedKey* current;
byte psk_key[MAX_PSK_KEY_LEN];
word32 psk_keySz;
int i;
if (pskExt == NULL)
return;
for (i = 0; i < suites->suiteSz; i += 2) {
for (current = (PreSharedKey*)pskExt->data; current != NULL;
current = current->next) {
if (FindPskSuite(ssl, current, psk_key, &psk_keySz,
suites->suites + i, &found) == 0) {
if (found) {
pskInfo->cipherSuite0 = suites->suites[i];
pskInfo->cipherSuite = suites->suites[i + 1];
pskInfo->isValid = 1;
return;
}
}
}
}
}
#endif
static int SendStatelessReply(const WOLFSSL* ssl, WolfSSL_CH* ch, byte isTls13,
PskInfo* pskInfo)
{
@ -550,39 +582,46 @@ static int SendStatelessReply(const WOLFSSL* ssl, WolfSSL_CH* ch, byte isTls13,
* and if they don't match we will error out there anyway. */
byte modes;
ret = TlsxFindByType(&tlsx, TLSX_PSK_KEY_EXCHANGE_MODES,
ch->extension);
if (ret != 0)
goto dtls13_cleanup;
if (tlsx.size == 0)
ERROR_OUT(MISSING_HANDSHAKE_DATA, dtls13_cleanup);
ret = TLSX_PskKeyModes_Parse_Modes(tlsx.elements, tlsx.size,
client_hello, &modes);
if (ret != 0)
goto dtls13_cleanup;
if ((modes & (1 << PSK_DHE_KE)) && !ssl->options.noPskDheKe) {
if (!haveKS)
#ifndef NO_PSK
/* When we didn't find a valid ticket ask the user for the
* ciphersuite matching this identity */
if (!pskInfo->isValid) {
if (TLSX_PreSharedKey_Parse_ClientHello(&parsedExts,
tlsx.elements, tlsx.size, ssl->heap) == 0)
FindPskSuiteFromExt(ssl, parsedExts, pskInfo, &suites);
/* Revert to full handshake if PSK parsing failed */
}
#endif
if (pskInfo->isValid) {
ret = TlsxFindByType(&tlsx, TLSX_PSK_KEY_EXCHANGE_MODES,
ch->extension);
if (ret != 0)
goto dtls13_cleanup;
if (tlsx.size == 0)
ERROR_OUT(MISSING_HANDSHAKE_DATA, dtls13_cleanup);
doKE = 1;
ret = TLSX_PskKeyModes_Parse_Modes(tlsx.elements, tlsx.size,
client_hello, &modes);
if (ret != 0)
goto dtls13_cleanup;
if ((modes & (1 << PSK_DHE_KE)) &&
!ssl->options.noPskDheKe) {
if (!haveKS)
ERROR_OUT(MISSING_HANDSHAKE_DATA, dtls13_cleanup);
doKE = 1;
}
else if ((modes & (1 << PSK_KE)) == 0) {
ERROR_OUT(PSK_KEY_ERROR, dtls13_cleanup);
}
usePSK = 1;
}
else if ((modes & (1 << PSK_KE)) == 0) {
ERROR_OUT(PSK_KEY_ERROR, dtls13_cleanup);
}
usePSK = 1;
}
#endif
#if defined(HAVE_SESSION_TICKET) || !defined(NO_PSK)
if (usePSK) {
if (pskInfo->isValid) {
cs.cipherSuite0 = pskInfo->cipherSuite0;
cs.cipherSuite = pskInfo->cipherSuite;
}
else {
/* Only support the default ciphersuite for PSK */
cs.cipherSuite0 = TLS13_BYTE;
cs.cipherSuite = WOLFSSL_DEF_PSK_CIPHER;
}
if (usePSK && pskInfo->isValid) {
cs.cipherSuite0 = pskInfo->cipherSuite0;
cs.cipherSuite = pskInfo->cipherSuite;
if (doKE) {
byte searched = 0;
@ -609,15 +648,9 @@ static int SendStatelessReply(const WOLFSSL* ssl, WolfSSL_CH* ch, byte isTls13,
if (ret != 0)
goto dtls13_cleanup;
}
/* Need to remove the keyshare ext if we are not doing PSK and we
* found a common group. */
if (cs.clientKSE != NULL
#if defined(HAVE_SESSION_TICKET) || !defined(NO_PSK)
&& !usePSK
#endif
)
{
else {
/* Need to remove the keyshare ext if we found a common group
* and are not doing curve negotiation. */
TLSX_Remove(&parsedExts, TLSX_KEY_SHARE, ssl->heap);
}

219
src/tls.c
View File

@ -9503,6 +9503,100 @@ static int TLSX_PreSharedKey_Write(PreSharedKey* list, byte* output,
return 0;
}
int TLSX_PreSharedKey_Parse_ClientHello(TLSX** extensions, const byte* input,
word16 length, void* heap)
{
int ret;
word16 len;
word16 idx = 0;
TLSX* extension;
PreSharedKey* list;
TLSX_Remove(extensions, TLSX_PRE_SHARED_KEY, heap);
/* Length of identities and of binders. */
if ((int)(length - idx) < OPAQUE16_LEN + OPAQUE16_LEN)
return BUFFER_E;
/* Length of identities. */
ato16(input + idx, &len);
idx += OPAQUE16_LEN;
if (len < MIN_PSK_ID_LEN || length - idx < len)
return BUFFER_E;
/* Create a pre-shared key object for each identity. */
while (len > 0) {
const byte* identity;
word16 identityLen;
word32 age;
if (len < OPAQUE16_LEN)
return BUFFER_E;
/* Length of identity. */
ato16(input + idx, &identityLen);
idx += OPAQUE16_LEN;
if (len < OPAQUE16_LEN + identityLen + OPAQUE32_LEN ||
identityLen > MAX_PSK_ID_LEN)
return BUFFER_E;
/* Cache identity pointer. */
identity = input + idx;
idx += identityLen;
/* Ticket age. */
ato32(input + idx, &age);
idx += OPAQUE32_LEN;
ret = TLSX_PreSharedKey_Use(extensions, identity, identityLen, age, no_mac,
0, 0, 1, NULL, heap);
if (ret != 0)
return ret;
/* Done with this identity. */
len -= OPAQUE16_LEN + identityLen + OPAQUE32_LEN;
}
/* Find the list of identities sent to server. */
extension = TLSX_Find(*extensions, TLSX_PRE_SHARED_KEY);
if (extension == NULL)
return PSK_KEY_ERROR;
list = (PreSharedKey*)extension->data;
/* Length of binders. */
if (idx + OPAQUE16_LEN > length)
return BUFFER_E;
ato16(input + idx, &len);
idx += OPAQUE16_LEN;
if (len < MIN_PSK_BINDERS_LEN || length - idx < len)
return BUFFER_E;
/* Set binder for each identity. */
while (list != NULL && len > 0) {
/* Length of binder */
list->binderLen = input[idx++];
if (list->binderLen < WC_SHA256_DIGEST_SIZE ||
list->binderLen > WC_MAX_DIGEST_SIZE)
return BUFFER_E;
if (len < OPAQUE8_LEN + list->binderLen)
return BUFFER_E;
/* Copy binder into static buffer. */
XMEMCPY(list->binder, input + idx, list->binderLen);
idx += (word16)list->binderLen;
/* Done with binder entry. */
len -= OPAQUE8_LEN + (word16)list->binderLen;
/* Next identity. */
list = list->next;
}
if (list != NULL || len != 0)
return BUFFER_E;
return 0;
}
/* Parse the pre-shared key extension.
* Different formats in different messages.
*
@ -9519,91 +9613,8 @@ static int TLSX_PreSharedKey_Parse(WOLFSSL* ssl, const byte* input,
PreSharedKey* list;
if (msgType == client_hello) {
int ret;
word16 len;
word16 idx = 0;
TLSX_Remove(&ssl->extensions, TLSX_PRE_SHARED_KEY, ssl->heap);
/* Length of identities and of binders. */
if ((int)(length - idx) < OPAQUE16_LEN + OPAQUE16_LEN)
return BUFFER_E;
/* Length of identities. */
ato16(input + idx, &len);
idx += OPAQUE16_LEN;
if (len < MIN_PSK_ID_LEN || length - idx < len)
return BUFFER_E;
/* Create a pre-shared key object for each identity. */
while (len > 0) {
const byte* identity;
word16 identityLen;
word32 age;
if (len < OPAQUE16_LEN)
return BUFFER_E;
/* Length of identity. */
ato16(input + idx, &identityLen);
idx += OPAQUE16_LEN;
if (len < OPAQUE16_LEN + identityLen + OPAQUE32_LEN ||
identityLen > MAX_PSK_ID_LEN)
return BUFFER_E;
/* Cache identity pointer. */
identity = input + idx;
idx += identityLen;
/* Ticket age. */
ato32(input + idx, &age);
idx += OPAQUE32_LEN;
ret = TLSX_PreSharedKey_Use(ssl, identity, identityLen, age, no_mac,
0, 0, 1, NULL);
if (ret != 0)
return ret;
/* Done with this identity. */
len -= OPAQUE16_LEN + identityLen + OPAQUE32_LEN;
}
/* Find the list of identities sent to server. */
extension = TLSX_Find(ssl->extensions, TLSX_PRE_SHARED_KEY);
if (extension == NULL)
return PSK_KEY_ERROR;
list = (PreSharedKey*)extension->data;
/* Length of binders. */
if (idx + OPAQUE16_LEN > length)
return BUFFER_E;
ato16(input + idx, &len);
idx += OPAQUE16_LEN;
if (len < MIN_PSK_BINDERS_LEN || length - idx < len)
return BUFFER_E;
/* Set binder for each identity. */
while (list != NULL && len > 0) {
/* Length of binder */
list->binderLen = input[idx++];
if (list->binderLen < WC_SHA256_DIGEST_SIZE ||
list->binderLen > WC_MAX_DIGEST_SIZE)
return BUFFER_E;
if (len < OPAQUE8_LEN + list->binderLen)
return BUFFER_E;
/* Copy binder into static buffer. */
XMEMCPY(list->binder, input + idx, list->binderLen);
idx += (word16)list->binderLen;
/* Done with binder entry. */
len -= OPAQUE8_LEN + (word16)list->binderLen;
/* Next identity. */
list = list->next;
}
if (list != NULL || len != 0)
return BUFFER_E;
return 0;
return TLSX_PreSharedKey_Parse_ClientHello(&ssl->extensions, input,
length, ssl->heap);
}
if (msgType == server_hello) {
@ -9675,13 +9686,16 @@ static int TLSX_PreSharedKey_New(PreSharedKey** list, const byte* identity,
XMEMSET(psk, 0, sizeof(*psk));
/* Make a copy of the identity data. */
psk->identity = (byte*)XMALLOC(len, heap, DYNAMIC_TYPE_TLSX);
psk->identity = (byte*)XMALLOC(len + NULL_TERM_LEN, heap,
DYNAMIC_TYPE_TLSX);
if (psk->identity == NULL) {
XFREE(psk, heap, DYNAMIC_TYPE_TLSX);
return MEMORY_E;
}
XMEMCPY(psk->identity, identity, len);
psk->identityLen = len;
/* Use a NULL terminator in case it is a C string */
psk->identity[psk->identityLen] = '\0';
/* Add it to the end and maintain the links. */
while (*list != NULL) {
@ -9729,24 +9743,24 @@ static WC_INLINE byte GetHmacLength(int hmac)
* preSharedKey The new pre-shared key object.
* returns 0 on success and other values indicate failure.
*/
int TLSX_PreSharedKey_Use(WOLFSSL* ssl, const byte* identity, word16 len,
int TLSX_PreSharedKey_Use(TLSX** extensions, const byte* identity, word16 len,
word32 age, byte hmac, byte cipherSuite0,
byte cipherSuite, byte resumption,
PreSharedKey **preSharedKey)
PreSharedKey **preSharedKey, void* heap)
{
int ret = 0;
TLSX* extension;
PreSharedKey* psk = NULL;
/* Find the pre-shared key extension if it exists. */
extension = TLSX_Find(ssl->extensions, TLSX_PRE_SHARED_KEY);
extension = TLSX_Find(*extensions, TLSX_PRE_SHARED_KEY);
if (extension == NULL) {
/* Push new pre-shared key extension. */
ret = TLSX_Push(&ssl->extensions, TLSX_PRE_SHARED_KEY, NULL, ssl->heap);
ret = TLSX_Push(extensions, TLSX_PRE_SHARED_KEY, NULL, heap);
if (ret != 0)
return ret;
extension = TLSX_Find(ssl->extensions, TLSX_PRE_SHARED_KEY);
extension = TLSX_Find(*extensions, TLSX_PRE_SHARED_KEY);
if (extension == NULL)
return MEMORY_E;
}
@ -9764,7 +9778,7 @@ int TLSX_PreSharedKey_Use(WOLFSSL* ssl, const byte* identity, word16 len,
/* Create a new pre-shared key object if not found. */
if (psk == NULL) {
ret = TLSX_PreSharedKey_New((PreSharedKey**)&extension->data, identity,
len, ssl->heap, &psk);
len, heap, &psk);
if (ret != 0)
return ret;
}
@ -12026,17 +12040,18 @@ int TLSX_PopulateExtensions(WOLFSSL* ssl, byte isServer)
milli += sess->ticketAdd;
/* Pre-shared key is mandatory extension for resumption. */
ret = TLSX_PreSharedKey_Use(ssl, sess->ticket, sess->ticketLen,
milli, ssl->specs.mac_algorithm, ssl->options.cipherSuite0,
ssl->options.cipherSuite, 1, NULL);
ret = TLSX_PreSharedKey_Use(&ssl->extensions, sess->ticket,
sess->ticketLen, milli, ssl->specs.mac_algorithm,
ssl->options.cipherSuite0, ssl->options.cipherSuite, 1,
NULL, ssl->heap);
#else
milli = now - sess->ticketSeen + sess->ticketAdd;
/* Pre-shared key is mandatory extension for resumption. */
ret = TLSX_PreSharedKey_Use(ssl, sess->ticket, sess->ticketLen,
(word32)milli, ssl->specs.mac_algorithm,
ret = TLSX_PreSharedKey_Use(&ssl->extensions, sess->ticket,
sess->ticketLen, (word32)milli, ssl->specs.mac_algorithm,
ssl->options.cipherSuite0, ssl->options.cipherSuite, 1,
NULL);
NULL, ssl->heap);
#endif
if (ret != 0)
return ret;
@ -12083,11 +12098,11 @@ int TLSX_PopulateExtensions(WOLFSSL* ssl, byte isServer)
GetCipherNameInternal(cipherSuite0, cipherSuite));
if (keySz > 0) {
ssl->arrays->psk_keySz = keySz;
ret = TLSX_PreSharedKey_Use(ssl,
ret = TLSX_PreSharedKey_Use(&ssl->extensions,
(byte*)ssl->arrays->client_identity,
(word16)XSTRLEN(ssl->arrays->client_identity),
0, SuiteMac(WOLFSSL_SUITES(ssl)->suites + i),
cipherSuite0, cipherSuite, 0, NULL);
cipherSuite0, cipherSuite, 0, NULL, ssl->heap);
if (ret != 0)
return ret;
#ifdef WOLFSSL_PSK_MULTI_ID_PER_CS
@ -12150,12 +12165,12 @@ int TLSX_PopulateExtensions(WOLFSSL* ssl, byte isServer)
if (ret != 0)
return ret;
ret = TLSX_PreSharedKey_Use(ssl,
ret = TLSX_PreSharedKey_Use(&ssl->extensions,
(byte*)ssl->arrays->client_identity,
(word16)XSTRLEN(ssl->arrays->client_identity),
0, ssl->specs.mac_algorithm,
cipherSuite0, cipherSuite, 0,
NULL);
NULL, ssl->heap);
if (ret != 0)
return ret;

View File

@ -5496,6 +5496,60 @@ static void RefineSuites(WOLFSSL* ssl, Suites* peerSuites)
#ifndef NO_PSK
int FindPskSuite(const WOLFSSL* ssl, PreSharedKey* psk, byte* psk_key,
word32* psk_keySz, byte* suite, int* found)
{
const char* cipherName = NULL;
byte cipherSuite0 = TLS13_BYTE;
byte cipherSuite = WOLFSSL_DEF_PSK_CIPHER;
int ret = 0;
*found = 0;
(void)suite;
if (ssl->options.server_psk_tls13_cb != NULL) {
*psk_keySz = ssl->options.server_psk_tls13_cb((WOLFSSL*)ssl,
(char*)psk->identity, psk_key, MAX_PSK_KEY_LEN, &cipherName);
if (*psk_keySz != 0) {
int cipherSuiteFlags = WOLFSSL_CIPHER_SUITE_FLAG_NONE;
*found = (GetCipherSuiteFromName(cipherName, &cipherSuite0,
&cipherSuite, &cipherSuiteFlags) == 0);
(void)cipherSuiteFlags;
}
}
if (*found == 0 && (ssl->options.server_psk_cb != NULL)) {
*psk_keySz = ssl->options.server_psk_cb((WOLFSSL*)ssl,
(char*)psk->identity, psk_key,
MAX_PSK_KEY_LEN);
*found = (*psk_keySz != 0);
}
if (*found) {
if (*psk_keySz > MAX_PSK_KEY_LEN) {
WOLFSSL_MSG("Key len too long in FindPsk()");
ret = PSK_KEY_ERROR;
WOLFSSL_ERROR_VERBOSE(ret);
}
if (ret == 0) {
#if !defined(WOLFSSL_PSK_ONE_ID) && !defined(WOLFSSL_PRIORITIZE_PSK)
/* Check whether PSK ciphersuite is in SSL. */
*found = (suite[0] == cipherSuite0) && (suite[1] == cipherSuite);
#else
(void)suite;
/* Check whether PSK ciphersuite is in SSL. */
{
byte s[2] = {
cipherSuite0,
cipherSuite,
};
*found = FindSuiteSSL(ssl, s);
}
#endif
}
}
return ret;
}
/* Attempt to find the PSK (not session ticket) that matches.
*
* @param [in, out] ssl The SSL/TLS object.
@ -5509,55 +5563,14 @@ static void RefineSuites(WOLFSSL* ssl, Suites* peerSuites)
* @return 1 when a match found - but check error code.
* @return 0 when no match found.
*/
static int FindPsk(WOLFSSL* ssl, PreSharedKey* psk, const byte* suite, int* err)
static int FindPsk(WOLFSSL* ssl, PreSharedKey* psk, byte* suite, int* err)
{
int ret = 0;
int found = 0;
const char* cipherName = NULL;
byte cipherSuite0 = TLS13_BYTE;
byte cipherSuite = WOLFSSL_DEF_PSK_CIPHER;
Arrays* sa = ssl->arrays;
(void)suite;
if (ssl->options.server_psk_tls13_cb != NULL) {
sa->psk_keySz = ssl->options.server_psk_tls13_cb(ssl,
sa->client_identity, sa->psk_key, MAX_PSK_KEY_LEN, &cipherName);
if (sa->psk_keySz != 0) {
int cipherSuiteFlags = WOLFSSL_CIPHER_SUITE_FLAG_NONE;
found = (GetCipherSuiteFromName(cipherName, &cipherSuite0,
&cipherSuite, &cipherSuiteFlags) == 0);
(void)cipherSuiteFlags;
}
}
if (!found && (ssl->options.server_psk_cb != NULL)) {
sa->psk_keySz = ssl->options.server_psk_cb(ssl,
sa->client_identity, sa->psk_key,
MAX_PSK_KEY_LEN);
found = (sa->psk_keySz != 0);
}
if (found) {
if (sa->psk_keySz > MAX_PSK_KEY_LEN) {
WOLFSSL_MSG("Key len too long in FindPsk");
ret = PSK_KEY_ERROR;
WOLFSSL_ERROR_VERBOSE(ret);
}
if (ret == 0) {
#if !defined(WOLFSSL_PSK_ONE_ID) && !defined(WOLFSSL_PRIORITIZE_PSK)
/* Check whether PSK ciphersuite is in SSL. */
found = (suite[0] == cipherSuite0) && (suite[1] == cipherSuite);
#else
(void)suite;
/* Check whether PSK ciphersuite is in SSL. */
{
byte s[2] = {
cipherSuite0,
cipherSuite,
};
found = FindSuiteSSL(ssl, s);
}
#endif
}
ret = FindPskSuite(ssl, psk, ssl->arrays->psk_key, &ssl->arrays->psk_keySz,
suite, &found);
if (ret == 0 && found) {
if ((ret == 0) && found) {
/* Default to ciphersuite if cb doesn't specify. */
ssl->options.resuming = 0;

View File

@ -3115,11 +3115,14 @@ WOLFSSL_LOCAL int TLSX_PreSharedKey_WriteBinders(PreSharedKey* list,
word16* pSz);
WOLFSSL_LOCAL int TLSX_PreSharedKey_GetSizeBinders(PreSharedKey* list,
byte msgType, word16* pSz);
WOLFSSL_LOCAL int TLSX_PreSharedKey_Use(WOLFSSL* ssl, const byte* identity,
WOLFSSL_LOCAL int TLSX_PreSharedKey_Use(TLSX** extensions, const byte* identity,
word16 len, word32 age, byte hmac,
byte cipherSuite0, byte cipherSuite,
byte resumption,
PreSharedKey **preSharedKey);
PreSharedKey **preSharedKey,
void* heap);
WOLFSSL_LOCAL int TLSX_PreSharedKey_Parse_ClientHello(TLSX** extensions,
const byte* input, word16 length, void* heap);
/* The possible Pre-Shared Key key exchange modes. */
enum PskKeyExchangeMode {
@ -6204,6 +6207,12 @@ WOLFSSL_LOCAL int wolfSSL_quic_keys_active(WOLFSSL* ssl, enum encrypt_side side)
#define WOLFSSL_IS_QUIC(s) 0
#endif /* WOLFSSL_QUIC (else) */
#ifndef NO_PSK
WOLFSSL_LOCAL int FindPskSuite(const WOLFSSL* ssl, PreSharedKey* psk,
byte* psk_key, word32* psk_keySz, byte* suite, int* found);
#endif
#ifdef __cplusplus
} /* extern "C" */
#endif