From 1c83e24a7aee2adb185c299e2de7bcf8eb635255 Mon Sep 17 00:00:00 2001 From: Marco Oliverio Date: Wed, 25 Mar 2026 08:35:41 +0100 Subject: [PATCH] dtls13: keep a counter for seenRecords list --- src/dtls13.c | 56 ++++++++++++++++++++++--------------------- tests/api/test_dtls.c | 2 +- wolfssl/internal.h | 8 ++++++- 3 files changed, 37 insertions(+), 29 deletions(-) diff --git a/src/dtls13.c b/src/dtls13.c index 90dd14ca44..ca2389944c 100644 --- a/src/dtls13.c +++ b/src/dtls13.c @@ -734,6 +734,13 @@ int Dtls13RtxAddAck(WOLFSSL* ssl, w64wrapper epoch, w64wrapper seq) Dtls13RecordNumber** prevNext = &ssl->dtls13Rtx.seenRecords; Dtls13RecordNumber* cur = ssl->dtls13Rtx.seenRecords; + if (ssl->dtls13Rtx.seenRecordsCount >= DTLS13_ACK_MAX_RECORDS) { + #ifdef WOLFSSL_RW_THREADED + wc_UnLockMutex(&ssl->dtls13Rtx.mutex); + #endif + return 0; /* list full, silently drop */ + } + for (; cur != NULL; prevNext = &cur->next, cur = cur->next) { if (w64Equal(cur->epoch, epoch) && w64Equal(cur->seq, seq)) { /* already in list. no duplicates. */ @@ -759,6 +766,7 @@ int Dtls13RtxAddAck(WOLFSSL* ssl, w64wrapper epoch, w64wrapper seq) *prevNext = rn; rn->next = cur; + ssl->dtls13Rtx.seenRecordsCount++; #ifdef WOLFSSL_RW_THREADED wc_UnLockMutex(&ssl->dtls13Rtx.mutex); #endif @@ -788,6 +796,7 @@ static void Dtls13RtxFlushAcks(WOLFSSL* ssl) } ssl->dtls13Rtx.seenRecords = NULL; + ssl->dtls13Rtx.seenRecordsCount = 0; #ifdef WOLFSSL_RW_THREADED wc_UnLockMutex(&ssl->dtls13Rtx.mutex); #endif @@ -850,6 +859,8 @@ static void Dtls13RtxRemoveCurAck(WOLFSSL* ssl) w64Equal(rn->seq, ssl->keys.curSeq)) { *prevNext = rn->next; XFREE(rn, ssl->heap, DYNAMIC_TYPE_DTLS_MSG); + if (ssl->dtls13Rtx.seenRecordsCount > 0) + ssl->dtls13Rtx.seenRecordsCount--; #ifdef WOLFSSL_RW_THREADED wc_UnLockMutex(&ssl->dtls13Rtx.mutex); #endif @@ -2563,39 +2574,26 @@ int Dtls13SetRecordNumberKeys(WOLFSSL* ssl, enum encrypt_side side) return NOT_COMPILED_IN; } -/* 64 bits epoch + 64 bits sequence */ -#define DTLS13_RN_SIZE 16 - -static int Dtls13GetAckListLength(Dtls13RecordNumber* list, word16* length) -{ - int numberElements; - - numberElements = 0; - - /* TODO: check that we don't exceed the maximum length */ - - while (list != NULL) { - list = list->next; - numberElements++; - } - - *length = (word16)(DTLS13_RN_SIZE * numberElements); - return 0; -} int Dtls13WriteAckMessage(WOLFSSL* ssl, - Dtls13RecordNumber* recordNumberList, word32* length) + Dtls13RecordNumber* recordNumberList, word16 recordsCount, word32* length) { word16 msgSz, headerLength; byte *output, *ackMessage; word32 sendSz; + word32 written; int ret; sendSz = 0; + written = 0; if (ssl->dtls13EncryptEpoch == NULL) return BAD_STATE_E; + if (recordsCount > DTLS13_ACK_MAX_RECORDS) + return BUFFER_E; + msgSz = (word16)(DTLS13_RN_SIZE * recordsCount); + if (w64IsZero(ssl->dtls13EncryptEpoch->epochNumber)) { /* unprotected ACK */ headerLength = DTLS_RECORD_HEADER_SZ; @@ -2605,10 +2603,6 @@ int Dtls13WriteAckMessage(WOLFSSL* ssl, sendSz += MAX_MSG_EXTRA; } - ret = Dtls13GetAckListLength(recordNumberList, &msgSz); - if (ret != 0) - return ret; - sendSz += headerLength; /* ACK list 2 bytes length field */ @@ -2631,6 +2625,8 @@ int Dtls13WriteAckMessage(WOLFSSL* ssl, WOLFSSL_MSG("write ack records"); while (recordNumberList != NULL) { + if (written + DTLS13_RN_SIZE > msgSz) + return BUFFER_E; WOLFSSL_MSG_EX("epoch %d seq %d", recordNumberList->epoch, recordNumberList->seq); c64toa(&recordNumberList->epoch, ackMessage); @@ -2638,8 +2634,12 @@ int Dtls13WriteAckMessage(WOLFSSL* ssl, c64toa(&recordNumberList->seq, ackMessage); ackMessage += OPAQUE64_LEN; recordNumberList = recordNumberList->next; + written += DTLS13_RN_SIZE; } + if (written != msgSz) + return BUFFER_E; + *length = msgSz + OPAQUE16_LEN; return 0; @@ -2750,6 +2750,7 @@ int Dtls13DoScheduledWork(WOLFSSL* ssl) tail = &(*tail)->next; *tail = ssl->dtls13Rtx.seenRecords; ssl->dtls13Rtx.seenRecords = NULL; + ssl->dtls13Rtx.seenRecordsCount = 0; ssl->dupWrite->sendAcks = 1; wc_UnLockMutex(&ssl->dupWrite->dupMutex); } @@ -2963,12 +2964,13 @@ int SendDtls13Ack(WOLFSSL* ssl) if (ret < 0) return ret; #endif - ret = Dtls13WriteAckMessage(ssl, ssl->dtls13Rtx.seenRecords, &length); + ret = Dtls13WriteAckMessage(ssl, ssl->dtls13Rtx.seenRecords, + ssl->dtls13Rtx.seenRecordsCount, &length); #ifdef WOLFSSL_RW_THREADED wc_UnLockMutex(&ssl->dtls13Rtx.mutex); #endif - if (ret != 0) - return ret; + if (ret != 0) + return ret; output = GetOutputBuffer(ssl); diff --git a/tests/api/test_dtls.c b/tests/api/test_dtls.c index bc277066dc..16e1c93642 100644 --- a/tests/api/test_dtls.c +++ b/tests/api/test_dtls.c @@ -918,7 +918,7 @@ int test_dtls13_ack_order(void) ExpectIntEQ(Dtls13RtxAddAck(ssl_c, w64From32(0, 2), w64From32(0, 2)), 0); ExpectIntEQ(Dtls13RtxAddAck(ssl_c, w64From32(0, 2), w64From32(0, 2)), 0); ExpectIntEQ(Dtls13WriteAckMessage(ssl_c, ssl_c->dtls13Rtx.seenRecords, - &length), 0); + ssl_c->dtls13Rtx.seenRecordsCount, &length), 0); /* must zero the span reserved for the header to avoid read of uninited * data. diff --git a/wolfssl/internal.h b/wolfssl/internal.h index e12bdc1239..0dc586a6e3 100644 --- a/wolfssl/internal.h +++ b/wolfssl/internal.h @@ -5857,6 +5857,11 @@ enum { DTLS13_EPOCH_TRAFFIC0 = 3 }; +/* 64-bit epoch + 64-bit sequence number */ +#define DTLS13_RN_SIZE (OPAQUE64_LEN + OPAQUE64_LEN) +/* Maximum number of ACK records encodable in the word16 length field */ +#define DTLS13_ACK_MAX_RECORDS ((int)(WOLFSSL_MAX_16BIT / DTLS13_RN_SIZE)) + typedef struct Dtls13Epoch { w64wrapper epochNumber; @@ -5925,6 +5930,7 @@ typedef struct Dtls13Rtx { Dtls13RtxRecord *rtxRecords; Dtls13RtxRecord **rtxRecordTailPtr; Dtls13RecordNumber *seenRecords; + word16 seenRecordsCount; #ifdef WOLFSSL_32BIT_MILLI_TIME word32 lastRtx; #else @@ -7279,7 +7285,7 @@ WOLFSSL_LOCAL int Dtls13ReconstructEpochNumber(WOLFSSL* ssl, byte epochBits, WOLFSSL_LOCAL int Dtls13ReconstructSeqNumber(WOLFSSL* ssl, Dtls13UnifiedHdrInfo* hdrInfo, w64wrapper* out); WOLFSSL_TEST_VIS int Dtls13WriteAckMessage(WOLFSSL* ssl, - Dtls13RecordNumber* recordNumberList, word32* length); + Dtls13RecordNumber* recordNumberList, word16 recordsCount, word32* length); WOLFSSL_LOCAL int SendDtls13Ack(WOLFSSL* ssl); WOLFSSL_TEST_VIS int Dtls13RtxAddAck(WOLFSSL* ssl, w64wrapper epoch, w64wrapper seq); WOLFSSL_LOCAL int Dtls13RtxProcessingCertificate(WOLFSSL* ssl, byte* input,