Merge pull request #5353 from rizlik/dtls13_async_fixes

Dtls13 async fixes
This commit is contained in:
David Garske
2022-07-21 13:24:35 -07:00
committed by GitHub
4 changed files with 77 additions and 46 deletions

View File

@ -351,7 +351,7 @@ static void Dtls13MsgWasProcessed(WOLFSSL* ssl, enum HandShakeType hs)
ssl->dtls13Rtx.sendAcks = Dtls13RtxMsgNeedsAck(ssl, hs);
}
static int Dtls13ProcessBufferedMessages(WOLFSSL* ssl)
int Dtls13ProcessBufferedMessages(WOLFSSL* ssl)
{
DtlsMsg* msg = ssl->dtls_rx_msg_list;
word32 idx = 0;
@ -372,15 +372,21 @@ static int Dtls13ProcessBufferedMessages(WOLFSSL* ssl)
ret = DoTls13HandShakeMsgType(ssl, msg->msg, &idx, msg->type, msg->sz,
msg->sz);
/* processing certificate_request triggers a connect. The error came
* from there, the message can be considered processed successfully */
if (ret == 0 || (msg->type == certificate_request &&
ssl->options.handShakeDone && ret == WC_PENDING_E)) {
Dtls13MsgWasProcessed(ssl, (enum HandShakeType)msg->type);
ssl->dtls_rx_msg_list = msg->next;
DtlsMsgDelete(msg, ssl->heap);
msg = ssl->dtls_rx_msg_list;
ssl->dtls_rx_msg_list_sz--;
}
if (ret != 0)
break;
Dtls13MsgWasProcessed(ssl, (enum HandShakeType)msg->type);
ssl->dtls_rx_msg_list = msg->next;
DtlsMsgDelete(msg, ssl->heap);
msg = ssl->dtls_rx_msg_list;
ssl->dtls_rx_msg_list_sz--;
}
WOLFSSL_LEAVE("dtls13_process_buffered_messages()", ret);
@ -432,7 +438,9 @@ static int Dtls13SendNow(WOLFSSL* ssl, enum HandShakeType handshakeType)
if (handshakeType == client_hello || handshakeType == hello_retry_request ||
handshakeType == finished || handshakeType == session_ticket ||
handshakeType == session_ticket || handshakeType == key_update)
handshakeType == session_ticket || handshakeType == key_update ||
(handshakeType == certificate_request &&
ssl->options.handShakeState == HANDSHAKE_DONE))
return 1;
return 0;
@ -523,7 +531,7 @@ static int Dtls13SendFragment(WOLFSSL* ssl, byte* output, word16 output_size,
static void Dtls13FreeFragmentsBuffer(WOLFSSL* ssl)
{
XFREE(ssl->dtls13FragmentsBuffer.buffer, ssl->heap,
DYNAMIC_TYPE_TEMP_BUFFER);
DYNAMIC_TYPE_TMP_BUFFER);
ssl->dtls13FragmentsBuffer.buffer = NULL;
ssl->dtls13SendingFragments = 0;
ssl->dtls13MessageLength = ssl->dtls13FragOffset = 0;
@ -664,7 +672,7 @@ static void Dtls13RtxFlushAcks(WOLFSSL* ssl)
while (list != NULL) {
rn = list;
list = rn->next;
XFREE(rn, ssl->heap, DYNAMIC_TYEP_DTLS_MSG);
XFREE(rn, ssl->heap, DYNAMIC_TYPE_DTLS_MSG);
}
ssl->dtls13Rtx.seenRecords = NULL;
@ -710,7 +718,7 @@ static void Dtls13RtxRemoveCurAck(WOLFSSL* ssl)
if (w64Equal(rn->epoch, ssl->keys.curEpoch64) &&
w64Equal(rn->seq, ssl->keys.curSeq)) {
*prevNext = rn->next;
XFREE(rn, ssl->heap, DYNAMIC_TYEP_DTLS_MSG);
XFREE(rn, ssl->heap, DYNAMIC_TYPE_DTLS_MSG);
return;
}
@ -796,7 +804,8 @@ static int Dtls13RtxMsgRecvd(WOLFSSL* ssl, enum HandShakeType hs,
Dtls13RtxRemoveCurAck(ssl);
}
if (ssl->options.dtls13SendMoreAcks && Dtls13DetectDisruption(ssl, fragOffset)) {
if (ssl->options.dtls13SendMoreAcks &&
Dtls13DetectDisruption(ssl, fragOffset)) {
WOLFSSL_MSG("Disruption detected");
ssl->dtls13Rtx.sendAcks = 1;
}
@ -1416,40 +1425,41 @@ static int Dtls13RtxSendBuffered(WOLFSSL* ssl)
static int _Dtls13HandshakeRecv(WOLFSSL* ssl, byte* input, word32 size,
word32* processedSize)
{
word32 frag_off, frag_length;
word32 fragOff, fragLength;
byte isComplete, isFirst;
word32 message_length;
byte handshake_type;
byte usingAsyncCrypto;
word32 messageLength;
byte handshakeType;
word32 idx;
int ret;
idx = 0;
ret = GetDtlsHandShakeHeader(ssl, input, &idx, &handshake_type,
&message_length, &frag_off, &frag_length, size);
ret = GetDtlsHandShakeHeader(ssl, input, &idx, &handshakeType,
&messageLength, &fragOff, &fragLength, size);
if (ret != 0)
return PARSE_ERROR;
if (idx + frag_length > size) {
if (idx + fragLength > size) {
WOLFSSL_ERROR(INCOMPLETE_DATA);
return INCOMPLETE_DATA;
}
if (frag_off + frag_length > message_length)
if (fragOff + fragLength > messageLength)
return BUFFER_ERROR;
if (handshake_type == client_hello &&
/* Only when receiving an unverified ClientHello */
ssl->options.serverState < SERVER_HELLO_COMPLETE) {
if (handshakeType == client_hello &&
/* Only when receiving an unverified ClientHello */
ssl->options.serverState < SERVER_HELLO_COMPLETE) {
/* To be able to operate in stateless mode, we assume the ClientHello
* is in order and we use its Handshake Message number and Sequence
* Number for our Tx. */
ssl->keys.dtls_expected_peer_handshake_number =
ssl->keys.dtls_handshake_number =
ssl->keys.dtls_peer_handshake_number;
ssl->keys.dtls_handshake_number =
ssl->keys.dtls_peer_handshake_number;
ssl->dtls13Epochs[0].nextSeqNumber = ssl->keys.curSeq;
}
ret = Dtls13RtxMsgRecvd(ssl, (enum HandShakeType)handshake_type, frag_off);
ret = Dtls13RtxMsgRecvd(ssl, (enum HandShakeType)handshakeType, fragOff);
if (ret != 0)
return ret;
@ -1462,40 +1472,42 @@ static int _Dtls13HandshakeRecv(WOLFSSL* ssl, byte* input, word32 size,
#endif /* WOLFSSL_DEBUG_TLS */
/* ignore the message */
*processedSize = idx + frag_length;
*processedSize += ssl->keys.padSz;
*processedSize = idx + fragLength + ssl->keys.padSz;
return 0;
}
isFirst = frag_off == 0;
isComplete = isFirst && frag_length == message_length;
isFirst = fragOff == 0;
isComplete = isFirst && fragLength == messageLength;
usingAsyncCrypto = ssl->devId != INVALID_DEVID;
if (!isComplete || ssl->keys.dtls_peer_handshake_number >
ssl->keys.dtls_expected_peer_handshake_number) {
/* store the message if any of the following: (a) incomplete message, (b)
* out of order message or (c) if using async crypto. In (c) the processing
* of the message can return WC_PENDING_E, it's easier to handle this error
* if the message is stored in the buffer.
*/
if (!isComplete ||
ssl->keys.dtls_peer_handshake_number >
ssl->keys.dtls_expected_peer_handshake_number ||
usingAsyncCrypto) {
DtlsMsgStore(ssl, w64GetLow32(ssl->keys.curEpoch64),
ssl->keys.dtls_peer_handshake_number,
input + DTLS_HANDSHAKE_HEADER_SZ, message_length, handshake_type,
frag_off, frag_length, ssl->heap);
*processedSize = idx + frag_length;
*processedSize += ssl->keys.padSz;
input + DTLS_HANDSHAKE_HEADER_SZ, messageLength, handshakeType,
fragOff, fragLength, ssl->heap);
*processedSize = idx + fragLength + ssl->keys.padSz;
if (Dtls13NextMessageComplete(ssl))
return Dtls13ProcessBufferedMessages(ssl);
return 0;
}
ret = DoTls13HandShakeMsgType(ssl, input, &idx, handshake_type,
message_length, size);
ret = DoTls13HandShakeMsgType(ssl, input, &idx, handshakeType,
messageLength, size);
if (ret != 0)
return ret;
Dtls13MsgWasProcessed(ssl, (enum HandShakeType)handshake_type);
Dtls13MsgWasProcessed(ssl, (enum HandShakeType)handshakeType);
*processedSize = idx;
/* check if we have buffered some message */

View File

@ -15848,7 +15848,8 @@ static int DoDtlsHandShakeMsg(WOLFSSL* ssl, byte* input, word32* inOutIdx,
/* In async mode always store the message and process it with
* DtlsMsgDrain because in case of a WC_PENDING_E it will be
* easier this way. */
if (ssl->dtls_rx_msg_list_sz < DTLS_POOL_SZ) {
if (ssl->devId != INVALID_DEVID &&
ssl->dtls_rx_msg_list_sz < DTLS_POOL_SZ) {
DtlsMsgStore(ssl, ssl->keys.curEpoch,
ssl->keys.dtls_peer_handshake_number,
input + idx, size, type,
@ -18113,11 +18114,25 @@ int ProcessReplyEx(WOLFSSL* ssl, int allowSocketErr)
#if defined(WOLFSSL_DTLS) && defined(WOLFSSL_ASYNC_CRYPT)
/* process any pending DTLS messages - this flow can happen with async */
if (ssl->dtls_rx_msg_list != NULL) {
ret = DtlsMsgDrain(ssl);
word32 pendingMsg = ssl->dtls_rx_msg_list_sz;
if(IsAtLeastTLSv1_3(ssl->version)) {
#ifdef WOLFSSL_DTLS13
ret = Dtls13ProcessBufferedMessages(ssl);
#elif
ret = NOT_COMPILED_IN;
#endif /* WOLFSSL_DTLS13 */
}
else {
ret = DtlsMsgDrain(ssl);
}
if (ret != 0) {
WOLFSSL_ERROR(ret);
return ret;
}
/* we processed some messages, return so connect/accept can make
progress */
if (ssl->dtls_rx_msg_list_sz != pendingMsg)
return ret;
}
#endif

View File

@ -9475,8 +9475,11 @@ int DoTls13HandShakeMsgType(WOLFSSL* ssl, byte* input, word32* inOutIdx,
#if defined(WOLFSSL_ASYNC_CRYPT) || defined(WOLFSSL_ASYNC_IO)
/* if async, offset index so this msg will be processed again */
/* NOTE: check this now before other calls can overwirte ret */
/* NOTE: check this now before other calls can overwrite ret */
if ((ret == WC_PENDING_E || ret == OCSP_WANT_READ) && *inOutIdx > 0) {
/* DTLS always stores a message in a buffer when async is enable, so we
* don't need to adjust for the extra bytes here (*inOutIdx is always
* == 0) */
*inOutIdx -= HANDSHAKE_HEADER_SZ;
}
#endif

View File

@ -5479,6 +5479,7 @@ WOLFSSL_LOCAL int Dtls13HashHandshake(WOLFSSL* ssl, const byte* output,
word16 length);
WOLFSSL_LOCAL void Dtls13FreeFsmResources(WOLFSSL* ssl);
WOLFSSL_LOCAL int Dtls13RtxTimeout(WOLFSSL* ssl);
WOLFSSL_LOCAL int Dtls13ProcessBufferedMessages(WOLFSSL* ssl);
#endif /* WOLFSSL_DTLS13 */
#ifdef WOLFSSL_STATIC_EPHEMERAL