Use PSK callback to get the ciphersuite to use

- Allocate additional byte in TLSX_PreSharedKey_New for null terminator
This commit is contained in:
Juliusz Sosinowicz
2023-01-03 17:05:12 +01:00
parent 6160f93f94
commit a999909969
4 changed files with 255 additions and 185 deletions

View File

@@ -435,6 +435,38 @@ static int CopyExtensions(TLSX* src, TLSX** dst, void* heap)
} }
#endif #endif
#if defined(WOLFSSL_DTLS13) && !defined(NO_PSK)
/* Very simplified version of CheckPreSharedKeys to find the current suite */
static void FindPskSuiteFromExt(const WOLFSSL* ssl, TLSX* extensions,
PskInfo* pskInfo, Suites* suites)
{
TLSX* pskExt = TLSX_Find(extensions, TLSX_PRE_SHARED_KEY);
int found = 0;
PreSharedKey* current;
byte psk_key[MAX_PSK_KEY_LEN];
word32 psk_keySz;
int i;
if (pskExt == NULL)
return;
for (i = 0; i < suites->suiteSz; i += 2) {
for (current = (PreSharedKey*)pskExt->data; current != NULL;
current = current->next) {
if (FindPskSuite(ssl, current, psk_key, &psk_keySz,
suites->suites + i, &found) == 0) {
if (found) {
pskInfo->cipherSuite0 = suites->suites[i];
pskInfo->cipherSuite = suites->suites[i + 1];
pskInfo->isValid = 1;
return;
}
}
}
}
}
#endif
static int SendStatelessReply(const WOLFSSL* ssl, WolfSSL_CH* ch, byte isTls13, static int SendStatelessReply(const WOLFSSL* ssl, WolfSSL_CH* ch, byte isTls13,
PskInfo* pskInfo) PskInfo* pskInfo)
{ {
@@ -550,6 +582,18 @@ static int SendStatelessReply(const WOLFSSL* ssl, WolfSSL_CH* ch, byte isTls13,
* and if they don't match we will error out there anyway. */ * and if they don't match we will error out there anyway. */
byte modes; byte modes;
#ifndef NO_PSK
/* When we didn't find a valid ticket ask the user for the
* ciphersuite matching this identity */
if (!pskInfo->isValid) {
if (TLSX_PreSharedKey_Parse_ClientHello(&parsedExts,
tlsx.elements, tlsx.size, ssl->heap) == 0)
FindPskSuiteFromExt(ssl, parsedExts, pskInfo, &suites);
/* Revert to full handshake if PSK parsing failed */
}
#endif
if (pskInfo->isValid) {
ret = TlsxFindByType(&tlsx, TLSX_PSK_KEY_EXCHANGE_MODES, ret = TlsxFindByType(&tlsx, TLSX_PSK_KEY_EXCHANGE_MODES,
ch->extension); ch->extension);
if (ret != 0) if (ret != 0)
@@ -560,7 +604,8 @@ static int SendStatelessReply(const WOLFSSL* ssl, WolfSSL_CH* ch, byte isTls13,
client_hello, &modes); client_hello, &modes);
if (ret != 0) if (ret != 0)
goto dtls13_cleanup; goto dtls13_cleanup;
if ((modes & (1 << PSK_DHE_KE)) && !ssl->options.noPskDheKe) { if ((modes & (1 << PSK_DHE_KE)) &&
!ssl->options.noPskDheKe) {
if (!haveKS) if (!haveKS)
ERROR_OUT(MISSING_HANDSHAKE_DATA, dtls13_cleanup); ERROR_OUT(MISSING_HANDSHAKE_DATA, dtls13_cleanup);
doKE = 1; doKE = 1;
@@ -570,19 +615,13 @@ static int SendStatelessReply(const WOLFSSL* ssl, WolfSSL_CH* ch, byte isTls13,
} }
usePSK = 1; usePSK = 1;
} }
}
#endif #endif
#if defined(HAVE_SESSION_TICKET) || !defined(NO_PSK) #if defined(HAVE_SESSION_TICKET) || !defined(NO_PSK)
if (usePSK) { if (usePSK && pskInfo->isValid) {
if (pskInfo->isValid) {
cs.cipherSuite0 = pskInfo->cipherSuite0; cs.cipherSuite0 = pskInfo->cipherSuite0;
cs.cipherSuite = pskInfo->cipherSuite; cs.cipherSuite = pskInfo->cipherSuite;
}
else {
/* Only support the default ciphersuite for PSK */
cs.cipherSuite0 = TLS13_BYTE;
cs.cipherSuite = WOLFSSL_DEF_PSK_CIPHER;
}
if (doKE) { if (doKE) {
byte searched = 0; byte searched = 0;
@@ -609,15 +648,9 @@ static int SendStatelessReply(const WOLFSSL* ssl, WolfSSL_CH* ch, byte isTls13,
if (ret != 0) if (ret != 0)
goto dtls13_cleanup; goto dtls13_cleanup;
} }
else {
/* Need to remove the keyshare ext if we are not doing PSK and we /* Need to remove the keyshare ext if we found a common group
* found a common group. */ * and are not doing curve negotiation. */
if (cs.clientKSE != NULL
#if defined(HAVE_SESSION_TICKET) || !defined(NO_PSK)
&& !usePSK
#endif
)
{
TLSX_Remove(&parsedExts, TLSX_KEY_SHARE, ssl->heap); TLSX_Remove(&parsedExts, TLSX_KEY_SHARE, ssl->heap);
} }

View File

@@ -9503,27 +9503,17 @@ static int TLSX_PreSharedKey_Write(PreSharedKey* list, byte* output,
return 0; return 0;
} }
/* Parse the pre-shared key extension. int TLSX_PreSharedKey_Parse_ClientHello(TLSX** extensions, const byte* input,
* Different formats in different messages. word16 length, void* heap)
*
* ssl The SSL/TLS object.
* input The extension data.
* length The length of the extension data.
* msgType The type of the message this extension is being parsed from.
* returns 0 on success and other values indicate failure.
*/
static int TLSX_PreSharedKey_Parse(WOLFSSL* ssl, const byte* input,
word16 length, byte msgType)
{ {
TLSX* extension;
PreSharedKey* list;
if (msgType == client_hello) {
int ret; int ret;
word16 len; word16 len;
word16 idx = 0; word16 idx = 0;
TLSX* extension;
PreSharedKey* list;
TLSX_Remove(&ssl->extensions, TLSX_PRE_SHARED_KEY, ssl->heap); TLSX_Remove(extensions, TLSX_PRE_SHARED_KEY, heap);
/* Length of identities and of binders. */ /* Length of identities and of binders. */
if ((int)(length - idx) < OPAQUE16_LEN + OPAQUE16_LEN) if ((int)(length - idx) < OPAQUE16_LEN + OPAQUE16_LEN)
@@ -9557,8 +9547,8 @@ static int TLSX_PreSharedKey_Parse(WOLFSSL* ssl, const byte* input,
ato32(input + idx, &age); ato32(input + idx, &age);
idx += OPAQUE32_LEN; idx += OPAQUE32_LEN;
ret = TLSX_PreSharedKey_Use(ssl, identity, identityLen, age, no_mac, ret = TLSX_PreSharedKey_Use(extensions, identity, identityLen, age, no_mac,
0, 0, 1, NULL); 0, 0, 1, NULL, heap);
if (ret != 0) if (ret != 0)
return ret; return ret;
@@ -9567,7 +9557,7 @@ static int TLSX_PreSharedKey_Parse(WOLFSSL* ssl, const byte* input,
} }
/* Find the list of identities sent to server. */ /* Find the list of identities sent to server. */
extension = TLSX_Find(ssl->extensions, TLSX_PRE_SHARED_KEY); extension = TLSX_Find(*extensions, TLSX_PRE_SHARED_KEY);
if (extension == NULL) if (extension == NULL)
return PSK_KEY_ERROR; return PSK_KEY_ERROR;
list = (PreSharedKey*)extension->data; list = (PreSharedKey*)extension->data;
@@ -9604,6 +9594,27 @@ static int TLSX_PreSharedKey_Parse(WOLFSSL* ssl, const byte* input,
return BUFFER_E; return BUFFER_E;
return 0; return 0;
}
/* Parse the pre-shared key extension.
* Different formats in different messages.
*
* ssl The SSL/TLS object.
* input The extension data.
* length The length of the extension data.
* msgType The type of the message this extension is being parsed from.
* returns 0 on success and other values indicate failure.
*/
static int TLSX_PreSharedKey_Parse(WOLFSSL* ssl, const byte* input,
word16 length, byte msgType)
{
TLSX* extension;
PreSharedKey* list;
if (msgType == client_hello) {
return TLSX_PreSharedKey_Parse_ClientHello(&ssl->extensions, input,
length, ssl->heap);
} }
if (msgType == server_hello) { if (msgType == server_hello) {
@@ -9675,13 +9686,16 @@ static int TLSX_PreSharedKey_New(PreSharedKey** list, const byte* identity,
XMEMSET(psk, 0, sizeof(*psk)); XMEMSET(psk, 0, sizeof(*psk));
/* Make a copy of the identity data. */ /* Make a copy of the identity data. */
psk->identity = (byte*)XMALLOC(len, heap, DYNAMIC_TYPE_TLSX); psk->identity = (byte*)XMALLOC(len + NULL_TERM_LEN, heap,
DYNAMIC_TYPE_TLSX);
if (psk->identity == NULL) { if (psk->identity == NULL) {
XFREE(psk, heap, DYNAMIC_TYPE_TLSX); XFREE(psk, heap, DYNAMIC_TYPE_TLSX);
return MEMORY_E; return MEMORY_E;
} }
XMEMCPY(psk->identity, identity, len); XMEMCPY(psk->identity, identity, len);
psk->identityLen = len; psk->identityLen = len;
/* Use a NULL terminator in case it is a C string */
psk->identity[psk->identityLen] = '\0';
/* Add it to the end and maintain the links. */ /* Add it to the end and maintain the links. */
while (*list != NULL) { while (*list != NULL) {
@@ -9729,24 +9743,24 @@ static WC_INLINE byte GetHmacLength(int hmac)
* preSharedKey The new pre-shared key object. * preSharedKey The new pre-shared key object.
* returns 0 on success and other values indicate failure. * returns 0 on success and other values indicate failure.
*/ */
int TLSX_PreSharedKey_Use(WOLFSSL* ssl, const byte* identity, word16 len, int TLSX_PreSharedKey_Use(TLSX** extensions, const byte* identity, word16 len,
word32 age, byte hmac, byte cipherSuite0, word32 age, byte hmac, byte cipherSuite0,
byte cipherSuite, byte resumption, byte cipherSuite, byte resumption,
PreSharedKey **preSharedKey) PreSharedKey **preSharedKey, void* heap)
{ {
int ret = 0; int ret = 0;
TLSX* extension; TLSX* extension;
PreSharedKey* psk = NULL; PreSharedKey* psk = NULL;
/* Find the pre-shared key extension if it exists. */ /* Find the pre-shared key extension if it exists. */
extension = TLSX_Find(ssl->extensions, TLSX_PRE_SHARED_KEY); extension = TLSX_Find(*extensions, TLSX_PRE_SHARED_KEY);
if (extension == NULL) { if (extension == NULL) {
/* Push new pre-shared key extension. */ /* Push new pre-shared key extension. */
ret = TLSX_Push(&ssl->extensions, TLSX_PRE_SHARED_KEY, NULL, ssl->heap); ret = TLSX_Push(extensions, TLSX_PRE_SHARED_KEY, NULL, heap);
if (ret != 0) if (ret != 0)
return ret; return ret;
extension = TLSX_Find(ssl->extensions, TLSX_PRE_SHARED_KEY); extension = TLSX_Find(*extensions, TLSX_PRE_SHARED_KEY);
if (extension == NULL) if (extension == NULL)
return MEMORY_E; return MEMORY_E;
} }
@@ -9764,7 +9778,7 @@ int TLSX_PreSharedKey_Use(WOLFSSL* ssl, const byte* identity, word16 len,
/* Create a new pre-shared key object if not found. */ /* Create a new pre-shared key object if not found. */
if (psk == NULL) { if (psk == NULL) {
ret = TLSX_PreSharedKey_New((PreSharedKey**)&extension->data, identity, ret = TLSX_PreSharedKey_New((PreSharedKey**)&extension->data, identity,
len, ssl->heap, &psk); len, heap, &psk);
if (ret != 0) if (ret != 0)
return ret; return ret;
} }
@@ -12026,17 +12040,18 @@ int TLSX_PopulateExtensions(WOLFSSL* ssl, byte isServer)
milli += sess->ticketAdd; milli += sess->ticketAdd;
/* Pre-shared key is mandatory extension for resumption. */ /* Pre-shared key is mandatory extension for resumption. */
ret = TLSX_PreSharedKey_Use(ssl, sess->ticket, sess->ticketLen, ret = TLSX_PreSharedKey_Use(&ssl->extensions, sess->ticket,
milli, ssl->specs.mac_algorithm, ssl->options.cipherSuite0, sess->ticketLen, milli, ssl->specs.mac_algorithm,
ssl->options.cipherSuite, 1, NULL); ssl->options.cipherSuite0, ssl->options.cipherSuite, 1,
NULL, ssl->heap);
#else #else
milli = now - sess->ticketSeen + sess->ticketAdd; milli = now - sess->ticketSeen + sess->ticketAdd;
/* Pre-shared key is mandatory extension for resumption. */ /* Pre-shared key is mandatory extension for resumption. */
ret = TLSX_PreSharedKey_Use(ssl, sess->ticket, sess->ticketLen, ret = TLSX_PreSharedKey_Use(&ssl->extensions, sess->ticket,
(word32)milli, ssl->specs.mac_algorithm, sess->ticketLen, (word32)milli, ssl->specs.mac_algorithm,
ssl->options.cipherSuite0, ssl->options.cipherSuite, 1, ssl->options.cipherSuite0, ssl->options.cipherSuite, 1,
NULL); NULL, ssl->heap);
#endif #endif
if (ret != 0) if (ret != 0)
return ret; return ret;
@@ -12083,11 +12098,11 @@ int TLSX_PopulateExtensions(WOLFSSL* ssl, byte isServer)
GetCipherNameInternal(cipherSuite0, cipherSuite)); GetCipherNameInternal(cipherSuite0, cipherSuite));
if (keySz > 0) { if (keySz > 0) {
ssl->arrays->psk_keySz = keySz; ssl->arrays->psk_keySz = keySz;
ret = TLSX_PreSharedKey_Use(ssl, ret = TLSX_PreSharedKey_Use(&ssl->extensions,
(byte*)ssl->arrays->client_identity, (byte*)ssl->arrays->client_identity,
(word16)XSTRLEN(ssl->arrays->client_identity), (word16)XSTRLEN(ssl->arrays->client_identity),
0, SuiteMac(WOLFSSL_SUITES(ssl)->suites + i), 0, SuiteMac(WOLFSSL_SUITES(ssl)->suites + i),
cipherSuite0, cipherSuite, 0, NULL); cipherSuite0, cipherSuite, 0, NULL, ssl->heap);
if (ret != 0) if (ret != 0)
return ret; return ret;
#ifdef WOLFSSL_PSK_MULTI_ID_PER_CS #ifdef WOLFSSL_PSK_MULTI_ID_PER_CS
@@ -12150,12 +12165,12 @@ int TLSX_PopulateExtensions(WOLFSSL* ssl, byte isServer)
if (ret != 0) if (ret != 0)
return ret; return ret;
ret = TLSX_PreSharedKey_Use(ssl, ret = TLSX_PreSharedKey_Use(&ssl->extensions,
(byte*)ssl->arrays->client_identity, (byte*)ssl->arrays->client_identity,
(word16)XSTRLEN(ssl->arrays->client_identity), (word16)XSTRLEN(ssl->arrays->client_identity),
0, ssl->specs.mac_algorithm, 0, ssl->specs.mac_algorithm,
cipherSuite0, cipherSuite, 0, cipherSuite0, cipherSuite, 0,
NULL); NULL, ssl->heap);
if (ret != 0) if (ret != 0)
return ret; return ret;

View File

@@ -5496,6 +5496,60 @@ static void RefineSuites(WOLFSSL* ssl, Suites* peerSuites)
#ifndef NO_PSK #ifndef NO_PSK
int FindPskSuite(const WOLFSSL* ssl, PreSharedKey* psk, byte* psk_key,
word32* psk_keySz, byte* suite, int* found)
{
const char* cipherName = NULL;
byte cipherSuite0 = TLS13_BYTE;
byte cipherSuite = WOLFSSL_DEF_PSK_CIPHER;
int ret = 0;
*found = 0;
(void)suite;
if (ssl->options.server_psk_tls13_cb != NULL) {
*psk_keySz = ssl->options.server_psk_tls13_cb((WOLFSSL*)ssl,
(char*)psk->identity, psk_key, MAX_PSK_KEY_LEN, &cipherName);
if (*psk_keySz != 0) {
int cipherSuiteFlags = WOLFSSL_CIPHER_SUITE_FLAG_NONE;
*found = (GetCipherSuiteFromName(cipherName, &cipherSuite0,
&cipherSuite, &cipherSuiteFlags) == 0);
(void)cipherSuiteFlags;
}
}
if (*found == 0 && (ssl->options.server_psk_cb != NULL)) {
*psk_keySz = ssl->options.server_psk_cb((WOLFSSL*)ssl,
(char*)psk->identity, psk_key,
MAX_PSK_KEY_LEN);
*found = (*psk_keySz != 0);
}
if (*found) {
if (*psk_keySz > MAX_PSK_KEY_LEN) {
WOLFSSL_MSG("Key len too long in FindPsk()");
ret = PSK_KEY_ERROR;
WOLFSSL_ERROR_VERBOSE(ret);
}
if (ret == 0) {
#if !defined(WOLFSSL_PSK_ONE_ID) && !defined(WOLFSSL_PRIORITIZE_PSK)
/* Check whether PSK ciphersuite is in SSL. */
*found = (suite[0] == cipherSuite0) && (suite[1] == cipherSuite);
#else
(void)suite;
/* Check whether PSK ciphersuite is in SSL. */
{
byte s[2] = {
cipherSuite0,
cipherSuite,
};
*found = FindSuiteSSL(ssl, s);
}
#endif
}
}
return ret;
}
/* Attempt to find the PSK (not session ticket) that matches. /* Attempt to find the PSK (not session ticket) that matches.
* *
* @param [in, out] ssl The SSL/TLS object. * @param [in, out] ssl The SSL/TLS object.
@@ -5509,55 +5563,14 @@ static void RefineSuites(WOLFSSL* ssl, Suites* peerSuites)
* @return 1 when a match found - but check error code. * @return 1 when a match found - but check error code.
* @return 0 when no match found. * @return 0 when no match found.
*/ */
static int FindPsk(WOLFSSL* ssl, PreSharedKey* psk, const byte* suite, int* err) static int FindPsk(WOLFSSL* ssl, PreSharedKey* psk, byte* suite, int* err)
{ {
int ret = 0; int ret = 0;
int found = 0; int found = 0;
const char* cipherName = NULL;
byte cipherSuite0 = TLS13_BYTE;
byte cipherSuite = WOLFSSL_DEF_PSK_CIPHER;
Arrays* sa = ssl->arrays;
(void)suite; ret = FindPskSuite(ssl, psk, ssl->arrays->psk_key, &ssl->arrays->psk_keySz,
suite, &found);
if (ssl->options.server_psk_tls13_cb != NULL) { if (ret == 0 && found) {
sa->psk_keySz = ssl->options.server_psk_tls13_cb(ssl,
sa->client_identity, sa->psk_key, MAX_PSK_KEY_LEN, &cipherName);
if (sa->psk_keySz != 0) {
int cipherSuiteFlags = WOLFSSL_CIPHER_SUITE_FLAG_NONE;
found = (GetCipherSuiteFromName(cipherName, &cipherSuite0,
&cipherSuite, &cipherSuiteFlags) == 0);
(void)cipherSuiteFlags;
}
}
if (!found && (ssl->options.server_psk_cb != NULL)) {
sa->psk_keySz = ssl->options.server_psk_cb(ssl,
sa->client_identity, sa->psk_key,
MAX_PSK_KEY_LEN);
found = (sa->psk_keySz != 0);
}
if (found) {
if (sa->psk_keySz > MAX_PSK_KEY_LEN) {
WOLFSSL_MSG("Key len too long in FindPsk");
ret = PSK_KEY_ERROR;
WOLFSSL_ERROR_VERBOSE(ret);
}
if (ret == 0) {
#if !defined(WOLFSSL_PSK_ONE_ID) && !defined(WOLFSSL_PRIORITIZE_PSK)
/* Check whether PSK ciphersuite is in SSL. */
found = (suite[0] == cipherSuite0) && (suite[1] == cipherSuite);
#else
(void)suite;
/* Check whether PSK ciphersuite is in SSL. */
{
byte s[2] = {
cipherSuite0,
cipherSuite,
};
found = FindSuiteSSL(ssl, s);
}
#endif
}
if ((ret == 0) && found) { if ((ret == 0) && found) {
/* Default to ciphersuite if cb doesn't specify. */ /* Default to ciphersuite if cb doesn't specify. */
ssl->options.resuming = 0; ssl->options.resuming = 0;

View File

@@ -3115,11 +3115,14 @@ WOLFSSL_LOCAL int TLSX_PreSharedKey_WriteBinders(PreSharedKey* list,
word16* pSz); word16* pSz);
WOLFSSL_LOCAL int TLSX_PreSharedKey_GetSizeBinders(PreSharedKey* list, WOLFSSL_LOCAL int TLSX_PreSharedKey_GetSizeBinders(PreSharedKey* list,
byte msgType, word16* pSz); byte msgType, word16* pSz);
WOLFSSL_LOCAL int TLSX_PreSharedKey_Use(WOLFSSL* ssl, const byte* identity, WOLFSSL_LOCAL int TLSX_PreSharedKey_Use(TLSX** extensions, const byte* identity,
word16 len, word32 age, byte hmac, word16 len, word32 age, byte hmac,
byte cipherSuite0, byte cipherSuite, byte cipherSuite0, byte cipherSuite,
byte resumption, byte resumption,
PreSharedKey **preSharedKey); PreSharedKey **preSharedKey,
void* heap);
WOLFSSL_LOCAL int TLSX_PreSharedKey_Parse_ClientHello(TLSX** extensions,
const byte* input, word16 length, void* heap);
/* The possible Pre-Shared Key key exchange modes. */ /* The possible Pre-Shared Key key exchange modes. */
enum PskKeyExchangeMode { enum PskKeyExchangeMode {
@@ -6204,6 +6207,12 @@ WOLFSSL_LOCAL int wolfSSL_quic_keys_active(WOLFSSL* ssl, enum encrypt_side side)
#define WOLFSSL_IS_QUIC(s) 0 #define WOLFSSL_IS_QUIC(s) 0
#endif /* WOLFSSL_QUIC (else) */ #endif /* WOLFSSL_QUIC (else) */
#ifndef NO_PSK
WOLFSSL_LOCAL int FindPskSuite(const WOLFSSL* ssl, PreSharedKey* psk,
byte* psk_key, word32* psk_keySz, byte* suite, int* found);
#endif
#ifdef __cplusplus #ifdef __cplusplus
} /* extern "C" */ } /* extern "C" */
#endif #endif