dtls13: keep a counter for seenRecords list

This commit is contained in:
Marco Oliverio
2026-03-25 08:35:41 +01:00
parent 025a7dcd16
commit 1c83e24a7a
3 changed files with 37 additions and 29 deletions
+29 -27
View File
@@ -734,6 +734,13 @@ int Dtls13RtxAddAck(WOLFSSL* ssl, w64wrapper epoch, w64wrapper seq)
Dtls13RecordNumber** prevNext = &ssl->dtls13Rtx.seenRecords;
Dtls13RecordNumber* cur = ssl->dtls13Rtx.seenRecords;
if (ssl->dtls13Rtx.seenRecordsCount >= DTLS13_ACK_MAX_RECORDS) {
#ifdef WOLFSSL_RW_THREADED
wc_UnLockMutex(&ssl->dtls13Rtx.mutex);
#endif
return 0; /* list full, silently drop */
}
for (; cur != NULL; prevNext = &cur->next, cur = cur->next) {
if (w64Equal(cur->epoch, epoch) && w64Equal(cur->seq, seq)) {
/* already in list. no duplicates. */
@@ -759,6 +766,7 @@ int Dtls13RtxAddAck(WOLFSSL* ssl, w64wrapper epoch, w64wrapper seq)
*prevNext = rn;
rn->next = cur;
ssl->dtls13Rtx.seenRecordsCount++;
#ifdef WOLFSSL_RW_THREADED
wc_UnLockMutex(&ssl->dtls13Rtx.mutex);
#endif
@@ -788,6 +796,7 @@ static void Dtls13RtxFlushAcks(WOLFSSL* ssl)
}
ssl->dtls13Rtx.seenRecords = NULL;
ssl->dtls13Rtx.seenRecordsCount = 0;
#ifdef WOLFSSL_RW_THREADED
wc_UnLockMutex(&ssl->dtls13Rtx.mutex);
#endif
@@ -850,6 +859,8 @@ static void Dtls13RtxRemoveCurAck(WOLFSSL* ssl)
w64Equal(rn->seq, ssl->keys.curSeq)) {
*prevNext = rn->next;
XFREE(rn, ssl->heap, DYNAMIC_TYPE_DTLS_MSG);
if (ssl->dtls13Rtx.seenRecordsCount > 0)
ssl->dtls13Rtx.seenRecordsCount--;
#ifdef WOLFSSL_RW_THREADED
wc_UnLockMutex(&ssl->dtls13Rtx.mutex);
#endif
@@ -2563,39 +2574,26 @@ int Dtls13SetRecordNumberKeys(WOLFSSL* ssl, enum encrypt_side side)
return NOT_COMPILED_IN;
}
/* 64 bits epoch + 64 bits sequence */
#define DTLS13_RN_SIZE 16
static int Dtls13GetAckListLength(Dtls13RecordNumber* list, word16* length)
{
int numberElements;
numberElements = 0;
/* TODO: check that we don't exceed the maximum length */
while (list != NULL) {
list = list->next;
numberElements++;
}
*length = (word16)(DTLS13_RN_SIZE * numberElements);
return 0;
}
int Dtls13WriteAckMessage(WOLFSSL* ssl,
Dtls13RecordNumber* recordNumberList, word32* length)
Dtls13RecordNumber* recordNumberList, word16 recordsCount, word32* length)
{
word16 msgSz, headerLength;
byte *output, *ackMessage;
word32 sendSz;
word32 written;
int ret;
sendSz = 0;
written = 0;
if (ssl->dtls13EncryptEpoch == NULL)
return BAD_STATE_E;
if (recordsCount > DTLS13_ACK_MAX_RECORDS)
return BUFFER_E;
msgSz = (word16)(DTLS13_RN_SIZE * recordsCount);
if (w64IsZero(ssl->dtls13EncryptEpoch->epochNumber)) {
/* unprotected ACK */
headerLength = DTLS_RECORD_HEADER_SZ;
@@ -2605,10 +2603,6 @@ int Dtls13WriteAckMessage(WOLFSSL* ssl,
sendSz += MAX_MSG_EXTRA;
}
ret = Dtls13GetAckListLength(recordNumberList, &msgSz);
if (ret != 0)
return ret;
sendSz += headerLength;
/* ACK list 2 bytes length field */
@@ -2631,6 +2625,8 @@ int Dtls13WriteAckMessage(WOLFSSL* ssl,
WOLFSSL_MSG("write ack records");
while (recordNumberList != NULL) {
if (written + DTLS13_RN_SIZE > msgSz)
return BUFFER_E;
WOLFSSL_MSG_EX("epoch %d seq %d", recordNumberList->epoch,
recordNumberList->seq);
c64toa(&recordNumberList->epoch, ackMessage);
@@ -2638,8 +2634,12 @@ int Dtls13WriteAckMessage(WOLFSSL* ssl,
c64toa(&recordNumberList->seq, ackMessage);
ackMessage += OPAQUE64_LEN;
recordNumberList = recordNumberList->next;
written += DTLS13_RN_SIZE;
}
if (written != msgSz)
return BUFFER_E;
*length = msgSz + OPAQUE16_LEN;
return 0;
@@ -2750,6 +2750,7 @@ int Dtls13DoScheduledWork(WOLFSSL* ssl)
tail = &(*tail)->next;
*tail = ssl->dtls13Rtx.seenRecords;
ssl->dtls13Rtx.seenRecords = NULL;
ssl->dtls13Rtx.seenRecordsCount = 0;
ssl->dupWrite->sendAcks = 1;
wc_UnLockMutex(&ssl->dupWrite->dupMutex);
}
@@ -2963,12 +2964,13 @@ int SendDtls13Ack(WOLFSSL* ssl)
if (ret < 0)
return ret;
#endif
ret = Dtls13WriteAckMessage(ssl, ssl->dtls13Rtx.seenRecords, &length);
ret = Dtls13WriteAckMessage(ssl, ssl->dtls13Rtx.seenRecords,
ssl->dtls13Rtx.seenRecordsCount, &length);
#ifdef WOLFSSL_RW_THREADED
wc_UnLockMutex(&ssl->dtls13Rtx.mutex);
#endif
if (ret != 0)
return ret;
if (ret != 0)
return ret;
output = GetOutputBuffer(ssl);
+1 -1
View File
@@ -918,7 +918,7 @@ int test_dtls13_ack_order(void)
ExpectIntEQ(Dtls13RtxAddAck(ssl_c, w64From32(0, 2), w64From32(0, 2)), 0);
ExpectIntEQ(Dtls13RtxAddAck(ssl_c, w64From32(0, 2), w64From32(0, 2)), 0);
ExpectIntEQ(Dtls13WriteAckMessage(ssl_c, ssl_c->dtls13Rtx.seenRecords,
&length), 0);
ssl_c->dtls13Rtx.seenRecordsCount, &length), 0);
/* must zero the span reserved for the header to avoid read of uninited
* data.
+7 -1
View File
@@ -5857,6 +5857,11 @@ enum {
DTLS13_EPOCH_TRAFFIC0 = 3
};
/* 64-bit epoch + 64-bit sequence number */
#define DTLS13_RN_SIZE (OPAQUE64_LEN + OPAQUE64_LEN)
/* Maximum number of ACK records encodable in the word16 length field */
#define DTLS13_ACK_MAX_RECORDS ((int)(WOLFSSL_MAX_16BIT / DTLS13_RN_SIZE))
typedef struct Dtls13Epoch {
w64wrapper epochNumber;
@@ -5925,6 +5930,7 @@ typedef struct Dtls13Rtx {
Dtls13RtxRecord *rtxRecords;
Dtls13RtxRecord **rtxRecordTailPtr;
Dtls13RecordNumber *seenRecords;
word16 seenRecordsCount;
#ifdef WOLFSSL_32BIT_MILLI_TIME
word32 lastRtx;
#else
@@ -7279,7 +7285,7 @@ WOLFSSL_LOCAL int Dtls13ReconstructEpochNumber(WOLFSSL* ssl, byte epochBits,
WOLFSSL_LOCAL int Dtls13ReconstructSeqNumber(WOLFSSL* ssl,
Dtls13UnifiedHdrInfo* hdrInfo, w64wrapper* out);
WOLFSSL_TEST_VIS int Dtls13WriteAckMessage(WOLFSSL* ssl,
Dtls13RecordNumber* recordNumberList, word32* length);
Dtls13RecordNumber* recordNumberList, word16 recordsCount, word32* length);
WOLFSSL_LOCAL int SendDtls13Ack(WOLFSSL* ssl);
WOLFSSL_TEST_VIS int Dtls13RtxAddAck(WOLFSSL* ssl, w64wrapper epoch, w64wrapper seq);
WOLFSSL_LOCAL int Dtls13RtxProcessingCertificate(WOLFSSL* ssl, byte* input,