diff --git a/doc/dox_comments/header_files/ssl.h b/doc/dox_comments/header_files/ssl.h index 681f2cee0..a1f498117 100644 --- a/doc/dox_comments/header_files/ssl.h +++ b/doc/dox_comments/header_files/ssl.h @@ -3326,6 +3326,47 @@ int wolfSSL_dtls_get_using_nonblock(WOLFSSL*); \sa wolfSSL_dtls_set_peer */ int wolfSSL_dtls_get_current_timeout(WOLFSSL* ssl); +/*! + \brief This function returns true if the application should setup a quicker + timeout. When using non-blocking sockets, something in the user code needs + to decide when to check for available data and how long it needs to wait. If + this function returns true, it means that the library already detected some + disruption in the communication, but it wants to wait for a little longer in + case some messages from the other peers are still in flight. Is up to the + application to fine tune the value of this timer, a good one may be + dtls_get_current_timeout() / 4. + + \return true if the application code should setup a quicker timeout + + \param ssl a pointer to a WOLFSSL structure, created using wolfSSL_new(). + + \sa wolfSSL_dtls + \sa wolfSSL_dtls_get_peer + \sa wolfSSL_dtls_got_timeout + \sa wolfSSL_dtls_set_peer + \sa wolfSSL_dtls13_set_send_more_acks +*/ +int wolfSSL_dtls13_use_quick_timeout(WOLFSSL *ssl); +/*! + \ingroup Setup + + \brief This function sets whether the library should send ACKs to the other + peer immediately when detecting disruption or not. Sending ACKs immediately + assures minimum latency but it may consume more bandwidth than necessary. If + the application manages the timer by itself and this option is set to 0 then + application code can use wolfSSL_dtls13_use_quick_timeout() to determine if + it should setup a quicker timeout to send those delayed ACKs. + + \param ssl a pointer to a WOLFSSL structure, created using wolfSSL_new(). + \param value 1 to set the option, 0 to disable the option + + \sa wolfSSL_dtls + \sa wolfSSL_dtls_get_peer + \sa wolfSSL_dtls_got_timeout + \sa wolfSSL_dtls_set_peer + \sa wolfSSL_dtls13_use_quick_timeout +*/ +void wolfSSL_dtls13_set_send_more_acks(WOLFSSL *ssl, int value); /*! \ingroup Setup diff --git a/src/dtls13.c b/src/dtls13.c index acc1a8686..abfe8aa9e 100644 --- a/src/dtls13.c +++ b/src/dtls13.c @@ -139,6 +139,7 @@ typedef struct Dtls13RecordCiphertextHeader { supported. */ #define DTLS13_UNIFIED_HEADER_SIZE 5 #define DTLS13_MIN_CIPHERTEXT 16 +#define DTLS13_MIN_RTX_INTERVAL 1 WOLFSSL_METHOD* wolfDTLSv1_3_client_method_ex(void* heap) { @@ -320,11 +321,31 @@ static int Dtls13EncryptDecryptRecordNumber(WOLFSSL* ssl, byte* seq, return 0; } +static byte Dtls13RtxMsgNeedsAck(WOLFSSL* ssl, enum HandShakeType hs) +{ + +#ifndef NO_WOLFSSL_SERVER + /* we send an ACK when processing the finished message. In this case either + we already sent an ACK for client's Certificate/CertificateVerify or they + are in our list of seen records and will be included in the ACK + message */ + if (ssl->options.side == WOLFSSL_SERVER_END && (hs == finished)) + return 1; +#endif /* NO_WOLFSSL_SERVER */ + + if (hs == session_ticket || hs == key_update) + return 1; + + return 0; +} + static void Dtls13MsgWasProcessed(WOLFSSL* ssl, enum HandShakeType hs) { - (void)hs; - ssl->keys.dtls_expected_peer_handshake_number++; + + /* we need to send ACKs on the last message of a flight that needs explicit + acknowledgment */ + ssl->dtls13Rtx.sendAcks = Dtls13RtxMsgNeedsAck(ssl, hs); } static int Dtls13ProcessBufferedMessages(WOLFSSL* ssl) @@ -505,6 +526,285 @@ static void Dtls13FreeFragmentsBuffer(WOLFSSL* ssl) ssl->dtls13MessageLength = ssl->dtls13FragOffset = 0; } +static WC_INLINE void Dtls13FreeRtxBufferRecord(WOLFSSL* ssl, + Dtls13RtxRecord* r) +{ + (void)ssl; + + XFREE(r->data, ssl->heap, DYNAMIC_TYPE_DTLS_MSG); + XFREE(r, ssl->heap, DYNAMIC_TYPE_DTLS_MSG); +} + +static Dtls13RtxRecord* Dtls13RtxNewRecord(WOLFSSL* ssl, byte* data, + word16 length, enum HandShakeType handshakeType, w64wrapper seq) +{ + w64wrapper epochNumber; + Dtls13RtxRecord* r; + + WOLFSSL_ENTER("Dtls13RtxNewRecord"); + + if (ssl->dtls13EncryptEpoch == NULL) + return NULL; + + epochNumber = ssl->dtls13EncryptEpoch->epochNumber; + + r = (Dtls13RtxRecord*)XMALLOC(sizeof(*r), ssl->heap, DYNAMIC_TYPE_DTLS_MSG); + if (r == NULL) + return NULL; + + r->data = (byte*)XMALLOC(length, ssl->heap, DYNAMIC_TYPE_DTLS_MSG); + if (r->data == NULL) { + XFREE(r, ssl->heap, DYNAMIC_TYPE_DTLS_MSG); + return NULL; + } + + XMEMCPY(r->data, data, length); + r->epoch = epochNumber; + r->length = length; + r->next = NULL; + r->handshakeType = handshakeType; + r->seq[0] = seq; + r->rnIdx = 1; + + return r; +} + +static void Dtls13RtxAddRecord(Dtls13Rtx* fsm, Dtls13RtxRecord* r) +{ + WOLFSSL_ENTER("Dtls13RtxAddRecord"); + + *fsm->rtxRecordTailPtr = r; + fsm->rtxRecordTailPtr = &r->next; + r->next = NULL; +} + +static void Dtls13RtxRecordUnlink(WOLFSSL* ssl, Dtls13RtxRecord** prevNext, + Dtls13RtxRecord* r) +{ + /* if r was at the tail of the list, update the tail pointer */ + if (r->next == NULL) + ssl->dtls13Rtx.rtxRecordTailPtr = prevNext; + + /* unlink */ + *prevNext = r->next; +} + +static void Dtls13RtxFlushBuffered(WOLFSSL* ssl, byte keepNewSessionTicket) +{ + Dtls13RtxRecord *r, **prevNext; + + WOLFSSL_ENTER("Dtls13RtxFlushBuffered"); + + prevNext = &ssl->dtls13Rtx.rtxRecords; + r = ssl->dtls13Rtx.rtxRecords; + + /* we process the head at the end */ + while (r != NULL) { + + if (keepNewSessionTicket && r->handshakeType == session_ticket) { + prevNext = &r->next; + r = r->next; + continue; + } + + *prevNext = r->next; + Dtls13FreeRtxBufferRecord(ssl, r); + r = *prevNext; + } + + ssl->dtls13Rtx.rtxRecordTailPtr = prevNext; +} + +static Dtls13RecordNumber* Dtls13NewRecordNumber(WOLFSSL* ssl, w64wrapper epoch, + w64wrapper seq) +{ + Dtls13RecordNumber* rn; + + rn = (Dtls13RecordNumber*)XMALLOC(sizeof(*rn), ssl->heap, + DYNAMIC_TYPE_DTLS_MSG); + if (rn == NULL) + return NULL; + + rn->next = NULL; + rn->epoch = epoch; + rn->seq = seq; + + return rn; +} + +static int Dtls13RtxAddAck(WOLFSSL* ssl, w64wrapper epoch, w64wrapper seq) +{ + Dtls13RecordNumber* rn; + + WOLFSSL_ENTER("Dtls13RtxAddAck"); + + rn = Dtls13NewRecordNumber(ssl, epoch, seq); + if (rn == NULL) + return MEMORY_E; + + rn->next = ssl->dtls13Rtx.seenRecords; + ssl->dtls13Rtx.seenRecords = rn; + + return 0; +} + +static void Dtls13RtxFlushAcks(WOLFSSL* ssl) +{ + Dtls13RecordNumber *list, *rn; + + (void)ssl; + + WOLFSSL_ENTER("Dtls13RtxFlushAcks"); + + list = ssl->dtls13Rtx.seenRecords; + + while (list != NULL) { + rn = list; + list = rn->next; + XFREE(rn, ssl->heap, DYNAMIC_TYEP_DTLS_MSG); + } + + ssl->dtls13Rtx.seenRecords = NULL; +} + +static int Dtls13DetectDisruption(WOLFSSL* ssl, word32 fragOffset) +{ + /* retransmission. The other peer may have lost our flight or our ACKs. We + don't account this as a disruption */ + if (ssl->keys.dtls_peer_handshake_number < + ssl->keys.dtls_expected_peer_handshake_number) + return 0; + + /* out of order message */ + if (ssl->keys.dtls_peer_handshake_number > + ssl->keys.dtls_expected_peer_handshake_number) { + return 1; + } + + /* first fragment of in-order message */ + if (fragOffset == 0) + return 0; + + /* is not the next fragment in the message (the check is not 100% perfect, + in the worst case, we don't detect the disruption and wait for the other + peer retransmission) */ + if (ssl->dtls_rx_msg_list == NULL || + ssl->dtls_rx_msg_list->fragSz != fragOffset) { + return 1; + } + + return 0; +} + +static void Dtls13RtxRemoveCurAck(WOLFSSL* ssl) +{ + Dtls13RecordNumber *rn, **prevNext; + + prevNext = &ssl->dtls13Rtx.seenRecords; + rn = ssl->dtls13Rtx.seenRecords; + + while (rn != NULL) { + if (w64Equal(rn->epoch, ssl->keys.curEpoch64) && + w64Equal(rn->seq, ssl->keys.curSeq)) { + *prevNext = rn->next; + XFREE(rn, ssl->heap, DYNAMIC_TYEP_DTLS_MSG); + return; + } + + prevNext = &rn->next; + rn = rn->next; + } +} + +static int Dtls13RtxMsgRecvd(WOLFSSL* ssl, enum HandShakeType hs, + word32 fragOffset) +{ + WOLFSSL_ENTER("Dtls13RtxMsgRecvd"); + + if (!ssl->options.handShakeDone && + ssl->keys.dtls_peer_handshake_number >= + ssl->keys.dtls_expected_peer_handshake_number) { + + /* In the handshake, receiving part of the next flight, acknowledge the + sent flight. The only exception is, on the server side, receiving the + last client flight does not ACK any sent new_session_ticket + messages. */ + Dtls13RtxFlushBuffered(ssl, 1); + } + + if (ssl->keys.dtls_peer_handshake_number < + ssl->keys.dtls_expected_peer_handshake_number) { + + /* retransmission detected. */ + ssl->dtls13Rtx.retransmit = 1; + + /* the other peer may have retransmitted because an ACK for a flight + that needs explicit ACK was lost.*/ + if (ssl->dtls13Rtx.seenRecords != NULL) + ssl->dtls13Rtx.sendAcks = (byte)ssl->options.dtls13SendMoreAcks; + } + + if (ssl->keys.dtls_peer_handshake_number == + ssl->keys.dtls_expected_peer_handshake_number && + ssl->options.handShakeDone && hs == certificate_request) { + + /* the current record, containing a post-handshake certificate request, + is implicitly acknowledged by the + certificate/certificate_verify/finished flight we are about to + send. Please note that if the certificate request came out-of-order + and we didn't send an ACK (sendMoreAcks == 0 and the missing + packet(s) arrive before that fast timeout expired), then we will send + both the ACK and the flight. While unnecessary this it's harmless, it + should be rare and simplifies the code. Otherwise, it would be + necessary to track which record number contained a CertificateRequest + with a particular context id */ + Dtls13RtxRemoveCurAck(ssl); + } + + if (ssl->options.dtls13SendMoreAcks && Dtls13DetectDisruption(ssl, fragOffset)) { + WOLFSSL_MSG("Disruption detected"); + ssl->dtls13Rtx.sendAcks = 1; + } + + return 0; +} + +void Dtls13FreeFsmResources(WOLFSSL* ssl) +{ + Dtls13RtxFlushAcks(ssl); + Dtls13RtxFlushBuffered(ssl, 0); +} + +static int Dtls13SendOneFragmentRtx(WOLFSSL* ssl, + enum HandShakeType handshakeType, word16 outputSize, byte* message, + word32 length, int hashOutput) +{ + Dtls13RtxRecord* rtxRecord; + word16 recordHeaderLength; + byte isProtected; + int ret; + + isProtected = Dtls13TypeIsEncrypted(handshakeType); + recordHeaderLength = Dtls13GetRlHeaderLength(isProtected); + + rtxRecord = Dtls13RtxNewRecord(ssl, message + recordHeaderLength, + (word16)(length - recordHeaderLength), handshakeType, + ssl->dtls13EncryptEpoch->nextSeqNumber); + + if (rtxRecord == NULL) + return MEMORY_E; + + ret = Dtls13SendFragment(ssl, message, outputSize, (word16)length, + handshakeType, hashOutput, Dtls13SendNow(ssl, handshakeType)); + + if (ret == 0 || ret == WANT_WRITE) + Dtls13RtxAddRecord(&ssl->dtls13Rtx, rtxRecord); + else + Dtls13FreeRtxBufferRecord(ssl, rtxRecord); + + return ret; +} + static int Dtls13SendFragmentedInternal(WOLFSSL* ssl) { int fragLength, rlHeaderLength; @@ -551,8 +851,8 @@ static int Dtls13SendFragmentedInternal(WOLFSSL* ssl) ssl->dtls13FragmentsBuffer.buffer + ssl->dtls13FragOffset, fragLength); - ret = Dtls13SendFragment(ssl, output, maxFragment, recordLength, - ssl->dtls13FragHandshakeType, 0, 1); + ret = Dtls13SendOneFragmentRtx(ssl, ssl->dtls13FragHandshakeType, + recordLength + MAX_MSG_EXTRA, output, recordLength, 0); if (ret == WANT_WRITE) { ssl->dtls13FragOffset += fragLength; return ret; @@ -940,7 +1240,112 @@ int Dtls13ParseUnifiedRecordLayer(WOLFSSL* ssl, const byte* input, int Dtls13RecordRecvd(WOLFSSL* ssl) { - (void)ssl; + int ret; + + if (ssl->curRL.type != handshake) + return 0; + + if (!ssl->options.dtls13SendMoreAcks) + ssl->dtls13FastTimeout = 1; + + ret = Dtls13RtxAddAck(ssl, ssl->keys.curEpoch64, ssl->keys.curSeq); + if (ret != 0) + WOLFSSL_MSG("can't save ack fragment"); + + return ret; +} + +static void Dtls13RtxMoveToEndOfList(WOLFSSL* ssl, Dtls13RtxRecord** prevNext, + Dtls13RtxRecord* r) +{ + /* already at the end */ + if (r->next == NULL) + return; + + Dtls13RtxRecordUnlink(ssl, prevNext, r); + /* add to the end */ + Dtls13RtxAddRecord(&ssl->dtls13Rtx, r); +} + +static int Dtls13RtxSendBuffered(WOLFSSL* ssl) +{ + word16 headerLength; + Dtls13RtxRecord *r, **prevNext; + w64wrapper seq; + byte* output; + int isLast; + int sendSz; + word32 now; + int ret; + + WOLFSSL_ENTER("Dtls13RtxSendBuffered"); + + now = LowResTimer(); + if (now - ssl->dtls13Rtx.lastRtx < DTLS13_MIN_RTX_INTERVAL) { +#ifdef WOLFSSL_DEBUG_TLS + WOLFSSL_MSG("Avoid too fast retransmission"); +#endif /* WOLFSSL_DEBUG_TLS */ + return 0; + } + + ssl->dtls13Rtx.lastRtx = now; + + r = ssl->dtls13Rtx.rtxRecords; + prevNext = &ssl->dtls13Rtx.rtxRecords; + while (r != NULL) { + isLast = r->next == NULL; + WOLFSSL_MSG("Dtls13Rtx One Record"); + + headerLength = Dtls13GetRlHeaderLength(!w64IsZero(r->epoch)); + + sendSz = r->length + headerLength; + + if (!w64IsZero(r->epoch)) + sendSz += MAX_MSG_EXTRA; + + ret = CheckAvailableSize(ssl, sendSz); + if (ret != 0) + return ret; + + output = + ssl->buffers.outputBuffer.buffer + ssl->buffers.outputBuffer.length; + + XMEMCPY(output + headerLength, r->data, r->length); + + if (!w64Equal(ssl->dtls13EncryptEpoch->epochNumber, r->epoch)) { + ret = Dtls13SetEpochKeys(ssl, r->epoch, ENCRYPT_SIDE_ONLY); + if (ret != 0) + return ret; + } + + seq = ssl->dtls13EncryptEpoch->nextSeqNumber; + + ret = Dtls13SendFragment(ssl, output, sendSz, r->length + headerLength, + r->handshakeType, 0, isLast || !ssl->options.groupMessages); + if (ret != 0 && ret != WANT_WRITE) + return ret; + + if (r->rnIdx >= DTLS13_RETRANS_RN_SIZE) + r->rnIdx = 0; + +#ifdef WOLFSSL_DEBUG_TLS + WOLFSSL_MSG_EX("tracking r hs: %d with seq: %ld", r->handshakeType, + seq); +#endif /* WOLFSSL_DEBUG_TLS */ + + r->seq[r->rnIdx] = seq; + r->rnIdx++; + + if (ret == WANT_WRITE) { + /* this fragment will be sent eventually. Move it to the end of the + list so next time we start with a new one. */ + Dtls13RtxMoveToEndOfList(ssl, prevNext, r); + return ret; + } + + prevNext = &r->next; + r = r->next; + } return 0; } @@ -956,8 +1361,8 @@ int Dtls13RecordRecvd(WOLFSSL* ssl) * * returns 0 on success */ -static int _Dtls13HandshakeRecv(WOLFSSL* ssl, byte *input, word32 size, - word32 *processedSize) +static int _Dtls13HandshakeRecv(WOLFSSL* ssl, byte* input, word32 size, + word32* processedSize) { word32 frag_off, frag_length; byte isComplete, isFirst; @@ -980,6 +1385,10 @@ static int _Dtls13HandshakeRecv(WOLFSSL* ssl, byte *input, word32 size, if (frag_off + frag_length > message_length) return BUFFER_ERROR; + ret = Dtls13RtxMsgRecvd(ssl, handshake_type, frag_off); + if (ret != 0) + return ret; + if (ssl->keys.dtls_peer_handshake_number < ssl->keys.dtls_expected_peer_handshake_number) { @@ -1117,6 +1526,21 @@ int Dtls13HandshakeSend(WOLFSSL* ssl, byte* message, word16 outputSize, /* if we are here, the message is built */ ssl->options.buildingMsg = 0; + if (!ssl->options.handShakeDone) { + + /* during the handshake, if we are sending a new flight, we can flush + our ACK list. When sending client + [certificate/certificate_verify]/finished flight, we may flush an ACK + for a newSessionticket message, sent by the server just after sending + its finished message. This should not be a problem. That message + arrived out-of-order (before the server finished) so likely an ACK + was already sent. In the worst case we will ACK the server + retranmission*/ + if (handshakeType == certificate || handshakeType == finished || + handshakeType == server_hello || handshakeType == client_hello) + Dtls13RtxFlushAcks(ssl); + } + /* we want to send always with the highest epoch */ if (!w64Equal(ssl->dtls13EncryptEpoch->epochNumber, ssl->dtls13Epoch)) { ret = Dtls13SetEpochKeys(ssl, ssl->dtls13Epoch, ENCRYPT_SIDE_ONLY); @@ -1128,9 +1552,8 @@ int Dtls13HandshakeSend(WOLFSSL* ssl, byte* message, word16 outputSize, maxLen = length; if (maxLen < maxFrag) { - ret = Dtls13SendFragment(ssl, message, outputSize, length, - handshakeType, hashOutput, Dtls13SendNow(ssl, handshakeType)); - + ret = Dtls13SendOneFragmentRtx(ssl, handshakeType, outputSize, message, + length, hashOutput); if (ret == 0 || ret == WANT_WRITE) ssl->keys.dtls_handshake_number++; } @@ -1588,4 +2011,365 @@ int Dtls13SetRecordNumberKeys(WOLFSSL* ssl, enum encrypt_side side) return NOT_COMPILED_IN; } +/* 64 bits epoch + 64 bits sequence */ +#define DTLS13_RN_SIZE 16 + +static int Dtls13GetAckListLength(Dtls13RecordNumber* list, word16* length) +{ + int numberElements; + + numberElements = 0; + + /* TODO: check that we don't exceed the maximum length */ + + while (list != NULL) { + list = list->next; + numberElements++; + } + + *length = DTLS13_RN_SIZE * numberElements; + return 0; +} + +static int Dtls13WriteAckMessage(WOLFSSL* ssl, + Dtls13RecordNumber* recordNumberList, word32* length) +{ + word16 msgSz, headerLength; + byte *output, *ackMessage; + word32 sendSz; + int ret; + + sendSz = 0; + + if (ssl->dtls13EncryptEpoch == NULL) + return BAD_STATE_E; + + if (w64IsZero(ssl->dtls13EncryptEpoch->epochNumber)) { + /* unprotected ACK */ + headerLength = DTLS_RECORD_HEADER_SZ; + ; + } + else { + headerLength = Dtls13GetRlHeaderLength(1); + sendSz += MAX_MSG_EXTRA; + } + + ret = Dtls13GetAckListLength(recordNumberList, &msgSz); + if (ret != 0) + return ret; + + sendSz += headerLength; + + /* ACK list 2 bytes length field */ + sendSz += OPAQUE16_LEN; + + /* ACK list */ + sendSz += msgSz; + + ret = CheckAvailableSize(ssl, sendSz); + if (ret != 0) + return ret; + + output = + ssl->buffers.outputBuffer.buffer + ssl->buffers.outputBuffer.length; + + ackMessage = output + headerLength; + + c16toa(msgSz, ackMessage); + ackMessage += OPAQUE16_LEN; + + while (recordNumberList != NULL) { + c64toa(&recordNumberList->epoch, ackMessage); + ackMessage += OPAQUE64_LEN; + c64toa(&recordNumberList->seq, ackMessage); + ackMessage += OPAQUE64_LEN; + recordNumberList = recordNumberList->next; + } + + *length = msgSz + OPAQUE16_LEN; + + return 0; +} + +static int Dtls13RtxIsTrackedByRn(const Dtls13RtxRecord* r, w64wrapper epoch, + w64wrapper seq) +{ + int i; + if (!w64Equal(r->epoch, epoch)) + return 0; + + for (i = 0; i < r->rnIdx; ++i) { + if (w64Equal(r->seq[i], seq)) + return 1; + } + + return 0; +} + +#ifdef WOLFSSL_DEBUG_TLS +static void Dtls13PrintRtxRecord(Dtls13RtxRecord* r) +{ + int i; + + WOLFSSL_MSG_EX("r: hs: %d epoch: %ld", r->handshakeType, r->epoch); + for (i = 0; i < r->rnIdx; i++) + WOLFSSL_MSG_EX("seq: %ld", r->seq[i]); +} +#endif /* WOLFSSL_DEBUG_TLS */ + +static void Dtls13RtxRemoveRecord(WOLFSSL* ssl, w64wrapper epoch, + w64wrapper seq) +{ + Dtls13RtxRecord *r, **prevNext; + + prevNext = &ssl->dtls13Rtx.rtxRecords; + r = ssl->dtls13Rtx.rtxRecords; + + while (r != NULL) { +#ifdef WOLFSSL_DEBUG_TLS + Dtls13PrintRtxRecord(r); +#endif /* WOLFSSL_DEBUG_TLS */ + + if (Dtls13RtxIsTrackedByRn(r, epoch, seq)) { +#ifdef WOLFSSL_DEBUG_TLS + WOLFSSL_MSG("removing record"); +#endif /* WOLFSSL_DEBUG_TLS */ + Dtls13RtxRecordUnlink(ssl, prevNext, r); + Dtls13FreeRtxBufferRecord(ssl, r); + return; + } + prevNext = &r->next; + r = r->next; + } + + return; +} + +int Dtls13DoScheduledWork(WOLFSSL* ssl) +{ + int ret; + + WOLFSSL_ENTER("Dtls13DoScheduledWork"); + + ssl->dtls13SendingAckOrRtx = 1; + + if (ssl->dtls13Rtx.sendAcks) { + ssl->dtls13Rtx.sendAcks = 0; + ret = SendDtls13Ack(ssl); + if (ret != 0) + return ret; + } + + if (ssl->dtls13Rtx.retransmit) { + ssl->dtls13Rtx.retransmit = 0; + ret = Dtls13RtxSendBuffered(ssl); + if (ret != 0) + return ret; + } + + ssl->dtls13SendingAckOrRtx = 0; + + return 0; +} + +/* Send ACKs when available after a timeout but only retransmit the last + * flight after a long timeout */ +int Dtls13RtxTimeout(WOLFSSL* ssl) +{ + int ret = 0; + + if (ssl->dtls13Rtx.seenRecords != NULL) { + ssl->dtls13Rtx.sendAcks = 0; + /* reset fast timeout as we are sending ACKs */ + ssl->dtls13FastTimeout = 0; + ret = SendDtls13Ack(ssl); + if (ret != 0) + return ret; + } + + /* we have two timeouts, a shorter (dtls13FastTimeout = 1) and a longer + one. When the shorter expires we only send ACKs, as it normally means + that some messages we are waiting for dont't arrive yet. But we + retransmit our buffered messages only if the longer timeout + expires. fastTimeout is 1/4 of the longer timeout */ + if (ssl->dtls13FastTimeout) { + ssl->dtls13FastTimeout = 0; + return 0; + } + + return Dtls13RtxSendBuffered(ssl); +} + +int DoDtls13Ack(WOLFSSL* ssl, const byte* input, word32 inputSize, + word32* processedSize) +{ + const byte* ackMessage; + w64wrapper epoch, seq; + word16 length; + int i; + + if (inputSize < OPAQUE16_LEN) + return BUFFER_ERROR; + + ato16(input, &length); + + if (inputSize < (word32)(OPAQUE16_LEN + length)) + return BUFFER_ERROR; + + if (length % (DTLS13_RN_SIZE) != 0) + return PARSE_ERROR; + + ackMessage = input + OPAQUE16_LEN; + for (i = 0; i < length; i += DTLS13_RN_SIZE) { + ato64(ackMessage + i, &epoch); + ato64(ackMessage + i + OPAQUE64_LEN, &seq); + Dtls13RtxRemoveRecord(ssl, epoch, seq); + } + + /* last client flight was completely acknowledged by the server. Handshake + is complete. */ + if (ssl->options.side == WOLFSSL_CLIENT_END && + ssl->options.connectState == WAIT_FINISHED_ACK && + ssl->dtls13Rtx.rtxRecords == NULL) { + ssl->options.serverState = SERVER_FINISHED_ACKED; + } + + *processedSize = length + OPAQUE16_LEN; + + /* After the handshake, not retransmitting here may incur in some extra time + in case a post-handshake authentication message is lost, because the ACK + mechanism does not shortcut the retransmission timer. If, on the other + hand, we retransmit we may do extra retransmissions of unrelated messages + in the queue. ex: we send KeyUpdate, CertificateRequest that are + unrelated between each other, receiving the ACK for the KeyUpdate will + trigger re-sending the CertificateRequest before the timeout.*/ + /* TODO: be more smart about when doing retransmission looking in the + retransmission queue or based on the type of message removed from the + seen record list */ + if (ssl->dtls13Rtx.rtxRecords != NULL) + ssl->dtls13Rtx.retransmit = 1; + + return 0; +} + +int SendDtls13Ack(WOLFSSL* ssl) +{ + word32 outputSize; + int headerSize; + word32 length; + byte* output; + int ret; + + if (ssl->dtls13EncryptEpoch == NULL) + return BAD_STATE_E; + + WOLFSSL_ENTER("SendDtls13Ack"); + + ret = 0; + + /* The handshake is not complete and the client didn't setup the TRAFFIC0 + epoch yet */ + if (ssl->options.side == WOLFSSL_SERVER_END && + !ssl->options.handShakeDone && + w64GTE(ssl->dtls13Epoch, w64From32(0, DTLS13_EPOCH_TRAFFIC0))) { + ret = Dtls13SetEpochKeys(ssl, w64From32(0, DTLS13_EPOCH_HANDSHAKE), + ENCRYPT_SIDE_ONLY); + } + else if (!w64Equal(ssl->dtls13Epoch, + ssl->dtls13EncryptEpoch->epochNumber)) { + ret = Dtls13SetEpochKeys(ssl, ssl->dtls13Epoch, ENCRYPT_SIDE_ONLY); + } + + if (ret != 0) + return ret; + + if (w64IsZero(ssl->dtls13EncryptEpoch->epochNumber)) { + + ret = Dtls13WriteAckMessage(ssl, ssl->dtls13Rtx.seenRecords, &length); + if (ret != 0) + return ret; + + output = + ssl->buffers.outputBuffer.buffer + ssl->buffers.outputBuffer.length; + + ret = Dtls13RlAddPlaintextHeader(ssl, output, ack, length); + if (ret != 0) + return ret; + + ssl->buffers.outputBuffer.length += length + DTLS_RECORD_HEADER_SZ; + } + else { + + ret = Dtls13WriteAckMessage(ssl, ssl->dtls13Rtx.seenRecords, &length); + if (ret != 0) + return ret; + + output = + ssl->buffers.outputBuffer.buffer + ssl->buffers.outputBuffer.length; + + outputSize = ssl->buffers.outputBuffer.bufferSize - + ssl->buffers.outputBuffer.length; + + headerSize = Dtls13GetRlHeaderLength(1); + + ret = BuildTls13Message(ssl, output, outputSize, output + headerSize, + length, ack, 0, 0, 0); + if (ret < 0) + return ret; + + ssl->buffers.outputBuffer.length += ret; + } + + Dtls13RtxFlushAcks(ssl); + + return SendBuffered(ssl); +} + +static int Dtls13RtxRecordMatchesReqCtx(Dtls13RtxRecord* r, byte* ctx, + byte ctxLen) +{ + if (r->handshakeType != certificate_request) + return 0; + if (r->length <= ctxLen + 1) + return 0; + return XMEMCMP(ctx, r->data + 1, ctxLen) == 0; +} + +int Dtls13RtxProcessingCertificate(WOLFSSL* ssl, byte* input, word32 inputSize) +{ + Dtls13RtxRecord* rtxRecord = ssl->dtls13Rtx.rtxRecords; + Dtls13RtxRecord** prevNext = &ssl->dtls13Rtx.rtxRecords; + byte ctxLength; + + WOLFSSL_ENTER("Dtls13RtxProcessingCertificate"); + + if (inputSize <= 1) { + WOLFSSL_MSG("Malformed Certificate"); + return BAD_FUNC_ARG; + } + + ctxLength = *input; + + if (inputSize < (word32)ctxLength + OPAQUE8_LEN) { + WOLFSSL_MSG("Malformed Certificate"); + return BAD_FUNC_ARG; + } + + while (rtxRecord != NULL) { + if (Dtls13RtxRecordMatchesReqCtx(rtxRecord, input + 1, ctxLength)) { + Dtls13RtxRecordUnlink(ssl, prevNext, rtxRecord); + Dtls13FreeRtxBufferRecord(ssl, rtxRecord); + return 0; + } + prevNext = &rtxRecord->next; + rtxRecord = rtxRecord->next; + } + + /* This isn't an error since we just can't find a Dtls13RtxRecord that + * matches the Request Context. Request Context validity is checked + * later. */ + WOLFSSL_MSG("Can't find any previous Certificate Request"); + return 0; +} + #endif /* WOLFSSL_DTLS13 */ diff --git a/src/internal.c b/src/internal.c index 3486d0be2..ea5e21ac3 100644 --- a/src/internal.c +++ b/src/internal.c @@ -195,6 +195,11 @@ WOLFSSL_CALLBACKS needs LARGE_STATIC_BUFFERS, please add LARGE_STATIC_BUFFERS static WC_INLINE int DtlsUpdateWindow(WOLFSSL* ssl); #endif +#ifdef WOLFSSL_DTLS13 +#ifndef WOLFSSL_DTLS13_SEND_MOREACK_DEFAULT +#define WOLFSSL_DTLS13_SEND_MOREACK_DEFAULT 0 +#endif +#endif /* WOLFSSL_DTLS13 */ enum processReply { doProcessInit = 0, @@ -6778,6 +6783,8 @@ int InitSSL(WOLFSSL* ssl, WOLFSSL_CTX* ctx, int writeDup) ssl->dtls13Epochs[0].side = ENCRYPT_AND_DECRYPT_SIDE; ssl->dtls13EncryptEpoch = &ssl->dtls13Epochs[0]; ssl->dtls13DecryptEpoch = &ssl->dtls13Epochs[0]; + ssl->options.dtls13SendMoreAcks = WOLFSSL_DTLS13_SEND_MOREACK_DEFAULT; + ssl->dtls13Rtx.rtxRecordTailPtr = &ssl->dtls13Rtx.rtxRecords; #endif /* WOLFSSL_DTLS13 */ return 0; @@ -7425,6 +7432,9 @@ void SSL_ResourceFree(WOLFSSL* ssl) wolfSSL_sk_X509_NAME_pop_free(ssl->ca_names, NULL); ssl->ca_names = NULL; #endif +#ifdef WOLFSSL_DTLS13 + Dtls13FreeFsmResources(ssl); +#endif /* WOLFSSL_DTLS13 */ } /* Free any handshake resources no longer needed */ @@ -9175,6 +9185,18 @@ retry: case WOLFSSL_CBIO_ERR_TIMEOUT: #ifdef WOLFSSL_DTLS +#ifdef WOLFSSL_DTLS13 + if (ssl->options.dtls && IsAtLeastTLSv1_3(ssl->version)) { + /* TODO: support WANT_WRITE here */ + if (Dtls13RtxTimeout(ssl) < 0) { + WOLFSSL_MSG( + "Error trying to retransmit DTLS buffered message"); + return -1; + } + goto retry; + } +#endif /* WOLFSSL_DTLS13 */ + if (IsDtlsNotSctpMode(ssl) && ssl->options.handShakeState != HANDSHAKE_DONE && DtlsMsgPoolTimeout(ssl) == 0 && @@ -9604,15 +9626,30 @@ static int GetDtlsRecordHeader(WOLFSSL* ssl, const byte* input, #ifdef WOLFSSL_DTLS13 word32 read_size; + int ret; read_size = ssl->buffers.inputBuffer.length - *inOutIdx; if (Dtls13IsUnifiedHeader(*(input + *inOutIdx))) { /* version 1.3 already negotiated */ - if (ssl->options.tls1_3) - return GetDtls13RecordHeader(ssl, input, inOutIdx, rh, size); + if (ssl->options.tls1_3) { + ret = GetDtls13RecordHeader(ssl, input, inOutIdx, rh, size); + if (ret == 0 || ret != SEQUENCE_ERROR) + return ret; + } +#ifndef NO_WOLFSSL_CLIENT + if (ssl->options.side == WOLFSSL_CLIENT_END + && ssl->options.serverState < SERVER_HELLO_COMPLETE + && IsAtLeastTLSv1_3(ssl->version) + && !ssl->options.handShakeDone) { + /* we may have lost ServerHello. Try to send a empty ACK to shortcut + Server retransmission timer */ + ssl->dtls13Rtx.sendAcks = 1; + } +#endif + return SEQUENCE_ERROR; } /* not a unified header, check that we have at least @@ -9735,6 +9772,15 @@ static int GetRecordHeader(WOLFSSL* ssl, const byte* input, word32* inOutIdx, else if (ssl->options.dtls && rh->type == handshake) /* Check the DTLS handshake message RH version later. */ WOLFSSL_MSG("DTLS handshake, skip RH version number check"); +#ifdef WOLFSSL_DTLS13 + else if (ssl->options.dtls && !ssl->options.handShakeDone) { + /* we may have lost the ServerHello and this is a unified record + before version been negotiated */ + if (Dtls13IsUnifiedHeader(*input)) { + return SEQUENCE_ERROR; + } + } +#endif /* WOLFSSL_DTLS13 */ else { WOLFSSL_MSG("SSL version error"); /* send alert per RFC5246 Appendix E. Backward Compatibility */ @@ -9771,6 +9817,9 @@ static int GetRecordHeader(WOLFSSL* ssl, const byte* input, word32* inOutIdx, case change_cipher_spec: case application_data: case alert: +#ifdef WOLFSSL_DTLS13 + case ack: +#endif /* WOLFSSL_DTLS13 */ break; case no_type: default: @@ -17858,6 +17907,13 @@ int ProcessReplyEx(WOLFSSL* ssl, int allowSocketErr) ssl->replayDropCount++; #endif /* WOLFSSL_DTLS_DROP_STATS */ +#ifdef WOLFSSL_DTLS13 + /* return to send ACKS and shortcut rtx timer */ + if (IsAtLeastTLSv1_3(ssl->version) + && ssl->dtls13Rtx.sendAcks) + return 0; +#endif /* WOLFSSL_DTLS13 */ + continue; } #endif @@ -18630,6 +18686,22 @@ int ProcessReplyEx(WOLFSSL* ssl, int allowSocketErr) ret = 0; break; +#ifdef WOLFSSL_DTLS13 + case ack: + WOLFSSL_MSG("got ACK"); + if (ssl->options.dtls && IsAtLeastTLSv1_3(ssl->version)) { + word32 processedSize = 0; + ret = DoDtls13Ack(ssl, ssl->buffers.inputBuffer.buffer + + ssl->buffers.inputBuffer.idx, + ssl->buffers.inputBuffer.length - + ssl->buffers.inputBuffer.idx - + ssl->keys.padSz, &processedSize); + ssl->buffers.inputBuffer.idx += processedSize; + ssl->buffers.inputBuffer.idx += ssl->keys.padSz; + break; + } + FALL_THROUGH; +#endif /* WOLFSSL_DTLS13 */ default: WOLFSSL_ERROR(UNKNOWN_RECORD_TYPE); return UNKNOWN_RECORD_TYPE; @@ -21088,6 +21160,16 @@ startScr: } return ssl->error; } + +#ifdef WOLFSSL_DTLS13 + if (ssl->options.dtls) { + /* Dtls13DoScheduledWork(ssl) may return WANT_WRITE */ + if ((ssl->error = Dtls13DoScheduledWork(ssl)) < 0) { + WOLFSSL_ERROR(ssl->error); + return ssl->error; + } + } +#endif /* WOLFSSL_DTLS13 */ #ifdef HAVE_SECURE_RENEGOTIATION if (ssl->secure_renegotiation && ssl->secure_renegotiation->startScr) { diff --git a/src/ssl.c b/src/ssl.c index 25f27d237..ea30642ff 100644 --- a/src/ssl.c +++ b/src/ssl.c @@ -11497,6 +11497,30 @@ int wolfSSL_dtls_get_current_timeout(WOLFSSL* ssl) return timeout; } +#ifdef WOLFSSL_DTLS13 + +/* + * This API returns 1 when the user should set a short timeout for receiving + * data. It is recommended that it is at most 1/4 the value returned by + * wolfSSL_dtls_get_current_timeout(). + */ +int wolfSSL_dtls13_use_quick_timeout(WOLFSSL* ssl) +{ + return ssl->dtls13FastTimeout; +} + +/* + * When this is set, a DTLS 1.3 connection will send acks immediately when a + * disruption is detected to shortcut timeouts. This results in potentially + * more traffic but may make the handshake quicker. + */ +void wolfSSL_dtls13_set_send_more_acks(WOLFSSL* ssl, int value) +{ + if (ssl != NULL) + ssl->options.dtls13SendMoreAcks = !!value; +} +#endif /* WOLFSSL_DTLS13 */ + int wolfSSL_DTLSv1_get_timeout(WOLFSSL* ssl, WOLFSSL_TIMEVAL* timeleft) { if (ssl && timeleft) { @@ -11567,6 +11591,21 @@ int wolfSSL_dtls_got_timeout(WOLFSSL* ssl) if (ssl == NULL) return WOLFSSL_FATAL_ERROR; +#ifdef WOLFSSL_DTLS13 + if (ssl->options.dtls && IsAtLeastTLSv1_3(ssl->version)) { + result = Dtls13RtxTimeout(ssl); + if (result < 0) { + if (result == WANT_WRITE) + ssl->dtls13SendingAckOrRtx = 1; + ssl->error = result; + WOLFSSL_ERROR(result); + return WOLFSSL_FATAL_ERROR; + } + + return WOLFSSL_SUCCESS; + } +#endif /* WOLFSSL_DTLS13 */ + if ((IsSCR(ssl) || !ssl->options.handShakeDone)) { if (DtlsMsgPoolTimeout(ssl) < 0){ ssl->error = SOCKET_ERROR_E; @@ -11789,6 +11828,7 @@ int wolfSSL_DTLS_SetCookieSecret(WOLFSSL* ssl, { #if !(defined(WOLFSSL_NO_TLS12) && defined(NO_OLD_TLS) && defined(WOLFSSL_TLS13)) int neededState; + byte advanceState; #endif int ret = 0; @@ -11856,6 +11896,21 @@ int wolfSSL_DTLS_SetCookieSecret(WOLFSSL* ssl, } #endif + /* fragOffset is non-zero when sending fragments. On the last + * fragment, fragOffset is zero again, and the state can be + * advanced. */ + advanceState = ssl->fragOffset == 0 && + (ssl->options.connectState == CONNECT_BEGIN || + ssl->options.connectState == HELLO_AGAIN || + (ssl->options.connectState >= FIRST_REPLY_DONE && + ssl->options.connectState <= FIRST_REPLY_FOURTH)); +; + +#ifdef WOLFSSL_DTLS13 + if (ssl->options.dtls && IsAtLeastTLSv1_3(ssl->version)) + advanceState = advanceState && !ssl->dtls13SendingAckOrRtx; +#endif /* WOLFSSL_DTLS13 */ + if (ssl->buffers.outputBuffer.length > 0 #ifdef WOLFSSL_ASYNC_CRYPT /* do not send buffered or advance state if last error was an @@ -11863,15 +11918,9 @@ int wolfSSL_DTLS_SetCookieSecret(WOLFSSL* ssl, && ssl->error != WC_PENDING_E #endif ) { - if ( (ret = SendBuffered(ssl)) == 0) { - /* fragOffset is non-zero when sending fragments. On the last - * fragment, fragOffset is zero again, and the state can be - * advanced. */ + if ( (ssl->error = SendBuffered(ssl)) == 0) { if (ssl->fragOffset == 0 && !ssl->options.buildingMsg) { - if (ssl->options.connectState == CONNECT_BEGIN || - ssl->options.connectState == HELLO_AGAIN || - (ssl->options.connectState >= FIRST_REPLY_DONE && - ssl->options.connectState <= FIRST_REPLY_FOURTH)) { + if (advanceState) { ssl->options.connectState++; WOLFSSL_MSG("connect state: " "Advanced from last buffered fragment send"); @@ -11942,6 +11991,27 @@ int wolfSSL_DTLS_SetCookieSecret(WOLFSSL* ssl, #endif neededState = SERVER_HELLODONE_COMPLETE; } +#ifdef WOLFSSL_DTLS13 + + if (ssl->options.dtls && IsAtLeastTLSv1_3(ssl->version) + && ssl->dtls13Rtx.sendAcks == 1) { + ssl->dtls13Rtx.sendAcks = 0; + /* we aren't negotiated the version yet, so we aren't sure + * the other end can speak v1.3. On the other side we have + * received a unified records, assuming that the + * ServerHello got lost, we will send an empty ACK. In case + * the server is a DTLS with version less than 1.3, it + * should just ignore the message */ + if ((ssl->error = SendDtls13Ack(ssl)) < 0) { + if (ssl->error == WANT_WRITE) + ssl->dtls13SendingAckOrRtx = 1; + WOLFSSL_ERROR(ssl->error); + return WOLFSSL_FATAL_ERROR; + } + } + + +#endif /* WOLFSSL_DTLS13 */ } ssl->options.connectState = HELLO_AGAIN; @@ -12095,6 +12165,12 @@ int wolfSSL_DTLS_SetCookieSecret(WOLFSSL* ssl, WOLFSSL_MSG("connect state: FINISHED_DONE"); FALL_THROUGH; +#ifdef WOLFSSL_DTLS13 + case WAIT_FINISHED_ACK: + ssl->options.connectState = FINISHED_DONE; + FALL_THROUGH; +#endif /* WOLFSSL_DTLS13 */ + case FINISHED_DONE : /* get response */ while (ssl->options.serverState < SERVER_FINISHED_COMPLETE) diff --git a/src/tls13.c b/src/tls13.c index e1522f19c..d44b25887 100644 --- a/src/tls13.c +++ b/src/tls13.c @@ -6991,12 +6991,21 @@ exit_scv: static int DoTls13Certificate(WOLFSSL* ssl, byte* input, word32* inOutIdx, word32 totalSz) { - int ret; + int ret = 0; WOLFSSL_START(WC_FUNC_CERTIFICATE_DO); WOLFSSL_ENTER("DoTls13Certificate"); - ret = ProcessPeerCerts(ssl, input, inOutIdx, totalSz); +#ifdef WOLFSSL_DTLS13 + if (ssl->options.dtls && ssl->options.handShakeDone) { + /* certificate needs some special care after the handshake */ + ret = Dtls13RtxProcessingCertificate( + ssl, input + *inOutIdx, totalSz); + } +#endif /* WOLFSSL_DTLS13 */ + + if (ret == 0) + ret = ProcessPeerCerts(ssl, input, inOutIdx, totalSz); if (ret == 0) { #if !defined(NO_WOLFSSL_CLIENT) if (ssl->options.side == WOLFSSL_CLIENT_END) @@ -9123,6 +9132,15 @@ int DoTls13HandShakeMsgType(WOLFSSL* ssl, byte* input, word32* inOutIdx, ssl->options.handShakeState = CLIENT_HELLO_COMPLETE; ssl->options.processReply = 0; /* doProcessInit */ + /* + DTLSv1.3 note: We can't reset serverState to + SERVER_FINISHED_COMPLETE with the goal that this connect + blocks until the cert/cert_verify/finished flight gets ACKed + by the server. The problem is that we will invoke + ProcessReplyEx() in that case, but we came here from + ProcessReplyEx() and it is not re-entrant safe (the input + buffer would still have the certificate_request message). */ + if (wolfSSL_connect_TLSv13(ssl) != WOLFSSL_SUCCESS) { ret = ssl->error; if (ret != WC_PENDING_E) @@ -9317,7 +9335,8 @@ int wolfSSL_connect_TLSv13(WOLFSSL* ssl) #ifdef WOLFSSL_DTLS13 if (ssl->options.dtls) - advanceState = advanceState && !ssl->dtls13SendingFragments; + advanceState = advanceState && !ssl->dtls13SendingFragments + && !ssl->dtls13SendingAckOrRtx; #endif /* WOLFSSL_DTLS13 */ if (ssl->buffers.outputBuffer.length > 0 @@ -9330,7 +9349,21 @@ int wolfSSL_connect_TLSv13(WOLFSSL* ssl) if ((ssl->error = SendBuffered(ssl)) == 0) { if (ssl->fragOffset == 0 && !ssl->options.buildingMsg) { if (advanceState) { - ssl->options.connectState++; +#ifdef WOLFSSL_DTLS13 + if (ssl->options.dtls && IsAtLeastTLSv1_3(ssl->version) && + ssl->options.connectState == FIRST_REPLY_FOURTH) { + /* WAIT_FINISHED_ACK is a state added afterwards, but it + can't follow FIRST_REPLY_FOURTH in the enum order. Indeed + the value of the enum ConnectState is stored in + serialized session. This would make importing serialized + session from other wolfSSL version incompatible */ + ssl->options.connectState = WAIT_FINISHED_ACK; + } + else +#endif /* WOLFSSL_DTLS13 */ + { + ssl->options.connectState++; + } WOLFSSL_MSG("connect state: " "Advanced from last buffered fragment send"); } @@ -9342,6 +9375,12 @@ int wolfSSL_connect_TLSv13(WOLFSSL* ssl) #ifdef WOLFSSL_ASYNC_IO FreeAsyncCtx(ssl, 0); #endif + +#ifdef WOLFSSL_DTLS13 + if (ssl->options.dtls) + ssl->dtls13SendingAckOrRtx =0; +#endif /* WOLFSSL_DTLS13 */ + } else { ssl->error = ret; @@ -9402,9 +9441,18 @@ int wolfSSL_connect_TLSv13(WOLFSSL* ssl) while (ssl->options.serverState < SERVER_HELLO_RETRY_REQUEST_COMPLETE) { if ((ssl->error = ProcessReply(ssl)) < 0) { - WOLFSSL_ERROR(ssl->error); - return WOLFSSL_FATAL_ERROR; + WOLFSSL_ERROR(ssl->error); + return WOLFSSL_FATAL_ERROR; } + +#ifdef WOLFSSL_DTLS13 + if (ssl->options.dtls) { + if ((ssl->error = Dtls13DoScheduledWork(ssl)) < 0) { + WOLFSSL_ERROR(ssl->error); + return WOLFSSL_FATAL_ERROR; + } + } +#endif /* WOLFSSL_DTLS13 */ } if (!ssl->options.tls1_3) { @@ -9450,9 +9498,18 @@ int wolfSSL_connect_TLSv13(WOLFSSL* ssl) /* Get the response/s from the server. */ while (ssl->options.serverState < SERVER_FINISHED_COMPLETE) { if ((ssl->error = ProcessReply(ssl)) < 0) { - WOLFSSL_ERROR(ssl->error); - return WOLFSSL_FATAL_ERROR; + WOLFSSL_ERROR(ssl->error); + return WOLFSSL_FATAL_ERROR; } + +#ifdef WOLFSSL_DTLS13 + if (ssl->options.dtls) { + if ((ssl->error = Dtls13DoScheduledWork(ssl)) < 0) { + WOLFSSL_ERROR(ssl->error); + return WOLFSSL_FATAL_ERROR; + } + } +#endif /* WOLFSSL_DTLS13 */ } ssl->options.connectState = FIRST_REPLY_DONE; @@ -9545,6 +9602,26 @@ int wolfSSL_connect_TLSv13(WOLFSSL* ssl) } WOLFSSL_MSG("sent: finished"); +#ifdef WOLFSSL_DTLS13 + ssl->options.connectState = WAIT_FINISHED_ACK; + WOLFSSL_MSG("connect state: WAIT_FINISHED_ACK"); + FALL_THROUGH; + + case WAIT_FINISHED_ACK: + if (ssl->options.dtls) { + while (ssl->options.serverState != SERVER_FINISHED_ACKED) { + if ((ssl->error = ProcessReply(ssl)) < 0) { + WOLFSSL_ERROR(ssl->error); + return WOLFSSL_FATAL_ERROR; + } + + if ((ssl->error = Dtls13DoScheduledWork(ssl)) < 0) { + WOLFSSL_ERROR(ssl->error); + return WOLFSSL_FATAL_ERROR; + } + } + } +#endif /* WOLFSSL_DTLS13 */ ssl->options.connectState = FINISHED_DONE; WOLFSSL_MSG("connect state: FINISHED_DONE"); FALL_THROUGH; @@ -10340,8 +10417,8 @@ int wolfSSL_accept_TLSv13(WOLFSSL* ssl) #ifdef WOLFSSL_DTLS13 if (ssl->options.dtls) - advanceState = advanceState && - !ssl->dtls13SendingFragments; + advanceState = advanceState && !ssl->dtls13SendingFragments + && !ssl->dtls13SendingAckOrRtx; #endif /* WOLFSSL_DTLS13 */ if ((ssl->error = SendBuffered(ssl)) == 0) { @@ -10359,6 +10436,12 @@ int wolfSSL_accept_TLSv13(WOLFSSL* ssl) WOLFSSL_MSG("accept state: " "Not advanced, more fragments to send"); } + +#ifdef WOLFSSL_DTLS13 + if (ssl->options.dtls) + ssl->dtls13SendingAckOrRtx = 0; +#endif /* WOLFSSL_DTLS13 */ + } else { ssl->error = ret; @@ -10397,6 +10480,16 @@ int wolfSSL_accept_TLSv13(WOLFSSL* ssl) WOLFSSL_ERROR(ssl->error); return WOLFSSL_FATAL_ERROR; } + +#ifdef WOLFSSL_DTLS13 + if (ssl->options.dtls) { + if ((ssl->error = Dtls13DoScheduledWork(ssl)) < 0) { + WOLFSSL_ERROR(ssl->error); + return WOLFSSL_FATAL_ERROR; + } + } +#endif /* WOLFSSL_DTLS13 */ + } ssl->options.acceptState = TLS13_ACCEPT_CLIENT_HELLO_DONE; @@ -10444,6 +10537,16 @@ int wolfSSL_accept_TLSv13(WOLFSSL* ssl) WOLFSSL_ERROR(ssl->error); return WOLFSSL_FATAL_ERROR; } + +#ifdef WOLFSSL_DTLS13 + if (ssl->options.dtls) { + if ((ssl->error = Dtls13DoScheduledWork(ssl)) < 0) { + WOLFSSL_ERROR(ssl->error); + return WOLFSSL_FATAL_ERROR; + } + } +#endif /* WOLFSSL_DTLS13 */ + } } @@ -10578,11 +10681,21 @@ int wolfSSL_accept_TLSv13(WOLFSSL* ssl) FALL_THROUGH; case TLS13_PRE_TICKET_SENT : - while (ssl->options.clientState < CLIENT_FINISHED_COMPLETE) + while (ssl->options.clientState < CLIENT_FINISHED_COMPLETE) { if ( (ssl->error = ProcessReply(ssl)) < 0) { - WOLFSSL_ERROR(ssl->error); - return WOLFSSL_FATAL_ERROR; + WOLFSSL_ERROR(ssl->error); + return WOLFSSL_FATAL_ERROR; + } + +#ifdef WOLFSSL_DTLS13 + if (ssl->options.dtls) { + if ((ssl->error = Dtls13DoScheduledWork(ssl)) < 0) { + WOLFSSL_ERROR(ssl->error); + return WOLFSSL_FATAL_ERROR; + } } +#endif /* WOLFSSL_DTLS13 */ + } ssl->options.acceptState = TLS13_ACCEPT_FINISHED_DONE; WOLFSSL_MSG("accept state ACCEPT_FINISHED_DONE"); @@ -10881,8 +10994,19 @@ int wolfSSL_read_early_data(WOLFSSL* ssl, void* data, int sz, int* outSz) ret = ReceiveData(ssl, (byte*)data, sz, FALSE); if (ret > 0) *outSz = ret; - if (ssl->error == ZERO_RETURN) + if (ssl->error == ZERO_RETURN) { ssl->error = WOLFSSL_ERROR_NONE; +#ifdef WOLFSSL_DTLS13 + if (ssl->options.dtls) { + ret = Dtls13DoScheduledWork(ssl); + if (ret < 0) { + ssl->error = ret; + WOLFSSL_ERROR(ssl->error); + return WOLFSSL_FATAL_ERROR; + } + } +#endif /* WOLFSSL_DTLS13 */ + } } else ret = 0; diff --git a/src/wolfio.c b/src/wolfio.c index 6685cb8e6..34a0742fc 100644 --- a/src/wolfio.c +++ b/src/wolfio.c @@ -381,6 +381,7 @@ int EmbedReceiveFrom(WOLFSSL *ssl, char *buf, int sz, void *ctx) int recvd; int sd = dtlsCtx->rfd; int dtls_timeout = wolfSSL_dtls_get_current_timeout(ssl); + byte doDtlsTimeout; SOCKADDR_S peer; XSOCKLENT peerSz = sizeof(peer); @@ -388,16 +389,41 @@ int EmbedReceiveFrom(WOLFSSL *ssl, char *buf, int sz, void *ctx) /* Don't use ssl->options.handShakeDone since it is true even if * we are in the process of renegotiation */ - if (ssl->options.handShakeState == HANDSHAKE_DONE) + doDtlsTimeout = ssl->options.handShakeState != HANDSHAKE_DONE; + +#ifdef WOLFSSL_DTLS13 + if (ssl->options.dtls && IsAtLeastTLSv1_3(ssl->version)) { + doDtlsTimeout = + doDtlsTimeout || ssl->dtls13Rtx.rtxRecords != NULL || + (ssl->dtls13FastTimeout && ssl->dtls13Rtx.seenRecords != NULL); + } +#endif /* WOLFSSL_DTLS13 */ + + if (!doDtlsTimeout) dtls_timeout = 0; if (!wolfSSL_get_using_nonblock(ssl)) { #ifdef USE_WINDOWS_API DWORD timeout = dtls_timeout * 1000; + #ifdef WOLFSSL_DTLS13 + if (wolfSSL_dtls13_use_quick_timeout(ssl) && + IsAtLeastTLSv1_3(ssl->version)) + timeout /= 4; + #endif /* WOLFSSL_DTLS13 */ #else struct timeval timeout; XMEMSET(&timeout, 0, sizeof(timeout)); - timeout.tv_sec = dtls_timeout; + #ifdef WOLFSSL_DTLS13 + if (wolfSSL_dtls13_use_quick_timeout(ssl) && + IsAtLeastTLSv1_3(ssl->version)) { + if (dtls_timeout >= 4) + timeout.tv_sec = dtls_timeout / 4; + else + timeout.tv_usec = dtls_timeout * 1000000 / 4; + } + else + #endif /* WOLFSSL_DTLS13 */ + timeout.tv_sec = dtls_timeout; #endif if (setsockopt(sd, SOL_SOCKET, SO_RCVTIMEO, (char*)&timeout, sizeof(timeout)) != 0) { diff --git a/wolfssl/internal.h b/wolfssl/internal.h index 1da42bae0..d40e5f120 100644 --- a/wolfssl/internal.h +++ b/wolfssl/internal.h @@ -1633,7 +1633,7 @@ enum Misc { typedef char _args_test[sizeof((x)) >= sizeof((y)) ? 1 : -1]; \ (void)sizeof(_args_test) -/* states */ +/* states. Adding state before HANDSHAKE_DONE will break session importing */ enum states { NULL_STATE = 0, @@ -1654,7 +1654,12 @@ enum states { CLIENT_CHANGECIPHERSPEC_COMPLETE, CLIENT_FINISHED_COMPLETE, - HANDSHAKE_DONE + HANDSHAKE_DONE, + +#ifdef WOLFSSL_DTLS13 + SERVER_FINISHED_ACKED, +#endif /* WOLFSSL_DTLS13 */ + }; /* SSL Version */ @@ -3556,7 +3561,12 @@ enum ConnectState { FIRST_REPLY_THIRD, FIRST_REPLY_FOURTH, FINISHED_DONE, - SECOND_REPLY_DONE + SECOND_REPLY_DONE, + +#ifdef WOLFSSL_DTLS13 + WAIT_FINISHED_ACK +#endif /* WOLFSSL_DTLS13 */ + }; @@ -3836,6 +3846,10 @@ typedef struct Options { #endif word16 buildingMsg:1; /* If set then we need to re-enter the * handshake logic. */ +#ifdef WOLFSSL_DTLS13 + word16 dtls13SendMoreAcks:1; /* Send more acks during the + * handshake process */ +#endif /* need full byte values for this section */ byte processReply; /* nonblocking resume */ @@ -4367,7 +4381,47 @@ typedef struct Dtls13Epoch { byte side; } Dtls13Epoch; -#define DTLS13_EPOCH_SIZE 3 +#ifndef DTLS13_EPOCH_SIZE +#define DTLS13_EPOCH_SIZE 4 +#endif + +#ifndef DTLS13_RETRANS_RN_SIZE +#define DTLS13_RETRANS_RN_SIZE 3 +#endif + +enum Dtls13RtxFsmState { + DTLS13_RTX_FSM_PREPARING = 0, + DTLS13_RTX_FSM_SENDING, + DTLS13_RTX_FSM_WAITING, + DTLS13_RTX_FSM_FINISHED +}; + +typedef struct Dtls13RtxRecord { + struct Dtls13RtxRecord *next; + word16 length; + byte *data; + w64wrapper epoch; + w64wrapper seq[DTLS13_RETRANS_RN_SIZE]; + byte rnIdx; + byte handshakeType; +} Dtls13RtxRecord; + +typedef struct Dtls13RecordNumber { + struct Dtls13RecordNumber *next; + w64wrapper epoch; + w64wrapper seq; +} Dtls13RecordNumber; + +typedef struct Dtls13Rtx { + enum Dtls13RtxFsmState state; + Dtls13RtxRecord *rtxRecords; + Dtls13RtxRecord **rtxRecordTailPtr; + Dtls13RecordNumber *seenRecords; + byte triggeredRtxs; + byte sendAcks:1; + byte retransmit:1; + word32 lastRtx; +} Dtls13Rtx; #endif /* WOLFSSL_DTLS13 */ @@ -4579,9 +4633,13 @@ struct WOLFSSL { /* used to store the message if it needs to be fragmented */ buffer dtls13FragmentsBuffer; byte dtls13SendingFragments:1; + byte dtls13SendingAckOrRtx:1; + byte dtls13FastTimeout:1; word32 dtls13MessageLength; word32 dtls13FragOffset; byte dtls13FragHandshakeType; + Dtls13Rtx dtls13Rtx; + #endif /* WOLFSSL_DTLS13 */ #endif /* WOLFSSL_DTLS */ #ifdef WOLFSSL_CALLBACKS @@ -5338,6 +5396,7 @@ WOLFSSL_LOCAL int Dtls13SetEpochKeys(WOLFSSL* ssl, w64wrapper epochNumber, enum encrypt_side side); WOLFSSL_LOCAL int Dtls13GetSeq(WOLFSSL* ssl, int order, word32* seq, byte increment); +WOLFSSL_LOCAL int Dtls13DoScheduledWork(WOLFSSL* ssl); WOLFSSL_LOCAL int Dtls13DeriveSnKeys(WOLFSSL* ssl, int provision); WOLFSSL_LOCAL int Dtls13SetRecordNumberKeys(WOLFSSL* ssl, enum encrypt_side side); @@ -5365,13 +5424,19 @@ WOLFSSL_LOCAL int Dtls13HandshakeAddHeader(WOLFSSL* ssl, byte* output, enum HandShakeType msg_type, word32 length); #define EE_MASK (0x3) WOLFSSL_LOCAL int Dtls13FragmentsContinue(WOLFSSL* ssl); - +WOLFSSL_LOCAL int DoDtls13Ack(WOLFSSL* ssl, const byte* input, word32 inputSize, + word32* processedSize); WOLFSSL_LOCAL int Dtls13ReconstructEpochNumber(WOLFSSL* ssl, byte epochBits, w64wrapper* epoch); WOLFSSL_LOCAL int Dtls13ReconstructSeqNumber(WOLFSSL* ssl, Dtls13UnifiedHdrInfo* hdrInfo, w64wrapper* out); +WOLFSSL_LOCAL int SendDtls13Ack(WOLFSSL* ssl); +WOLFSSL_LOCAL int Dtls13RtxProcessingCertificate(WOLFSSL* ssl, byte* input, + word32 inputSize); WOLFSSL_LOCAL int Dtls13HashHandshake(WOLFSSL* ssl, const byte* output, word16 length); +WOLFSSL_LOCAL void Dtls13FreeFsmResources(WOLFSSL* ssl); +WOLFSSL_LOCAL int Dtls13RtxTimeout(WOLFSSL* ssl); #endif /* WOLFSSL_DTLS13 */ #ifdef WOLFSSL_STATIC_EPHEMERAL diff --git a/wolfssl/ssl.h b/wolfssl/ssl.h index 1d02e0f87..d419339fa 100644 --- a/wolfssl/ssl.h +++ b/wolfssl/ssl.h @@ -1310,6 +1310,8 @@ WOLFSSL_API int wolfSSL_dtls_get_using_nonblock(WOLFSSL* ssl); #define wolfSSL_get_using_nonblock wolfSSL_dtls_get_using_nonblock /* The old names are deprecated. */ WOLFSSL_API int wolfSSL_dtls_get_current_timeout(WOLFSSL* ssl); +WOLFSSL_API int wolfSSL_dtls13_use_quick_timeout(WOLFSSL* ssl); +WOLFSSL_API void wolfSSL_dtls13_set_send_more_acks(WOLFSSL* ssl, int value); WOLFSSL_API int wolfSSL_DTLSv1_get_timeout(WOLFSSL* ssl, WOLFSSL_TIMEVAL* timeleft); WOLFSSL_API void wolfSSL_DTLSv1_set_initial_timeout_duration(WOLFSSL* ssl,