diff --git a/src/dtls13.c b/src/dtls13.c index 7cb02bd28..e2c3518d4 100644 --- a/src/dtls13.c +++ b/src/dtls13.c @@ -345,11 +345,11 @@ int Dtls13ProcessBufferedMessages(WOLFSSL* ssl) break; /* message not complete */ - if (msg->fragSz != msg->sz) + if (!msg->ready) break; - ret = DoTls13HandShakeMsgType(ssl, msg->msg, &idx, msg->type, msg->sz, - msg->sz); + ret = DoTls13HandShakeMsgType(ssl, msg->fullMsg, &idx, msg->type, + msg->sz, msg->sz); /* processing certificate_request triggers a connect. The error came * from there, the message can be considered processed successfully */ @@ -375,7 +375,7 @@ int Dtls13ProcessBufferedMessages(WOLFSSL* ssl) static int Dtls13NextMessageComplete(WOLFSSL* ssl) { return ssl->dtls_rx_msg_list != NULL && - ssl->dtls_rx_msg_list->fragSz == ssl->dtls_rx_msg_list->sz && + ssl->dtls_rx_msg_list->ready && ssl->dtls_rx_msg_list->seq == ssl->keys.dtls_expected_peer_handshake_number; } @@ -677,8 +677,18 @@ static int Dtls13DetectDisruption(WOLFSSL* ssl, word32 fragOffset) /* is not the next fragment in the message (the check is not 100% perfect, in the worst case, we don't detect the disruption and wait for the other peer retransmission) */ - if (ssl->dtls_rx_msg_list == NULL || - ssl->dtls_rx_msg_list->fragSz != fragOffset) { + if (ssl->dtls_rx_msg_list != NULL) { + DtlsFragBucket* last = ssl->dtls_rx_msg_list->fragBucketList; + while (last != NULL && last->next != NULL) + last = last->next; + /* Does this fragment start right after the last fragment we + * have stored? */ + if (last != NULL && (last->offset + last->sz) != fragOffset) + return 1; + } + else { + /* ssl->dtls_rx_msg_list is NULL and fragOffset != 0 so this is not in + * order */ return 1; } diff --git a/src/internal.c b/src/internal.c index 35b4c0864..6b1c437ca 100644 --- a/src/internal.c +++ b/src/internal.c @@ -8109,7 +8109,7 @@ void WriteSEQ(WOLFSSL* ssl, int verifyOrder, byte* out) * has the headers, and will include those headers in the hash. The store * routines need to take that into account as well. New will allocate * extra space for the headers. */ -DtlsMsg* DtlsMsgNew(word32 sz, void* heap) +DtlsMsg* DtlsMsgNew(word32 sz, byte tx, void* heap) { DtlsMsg* msg; WOLFSSL_ENTER("DtlsMsgNew()"); @@ -8119,16 +8119,17 @@ DtlsMsg* DtlsMsgNew(word32 sz, void* heap) if (msg != NULL) { XMEMSET(msg, 0, sizeof(DtlsMsg)); - msg->buf = (byte*)XMALLOC(sz + DTLS_HANDSHAKE_HEADER_SZ, - heap, DYNAMIC_TYPE_DTLS_BUFFER); - if (msg->buf != NULL) { - msg->sz = sz; - msg->type = no_shake; - msg->msg = msg->buf + DTLS_HANDSHAKE_HEADER_SZ; - } - else { - XFREE(msg, heap, DYNAMIC_TYPE_DTLS_MSG); - msg = NULL; + msg->sz = sz; + msg->type = no_shake; + if (tx) { + msg->raw = msg->fullMsg = + (byte*)XMALLOC(sz + DTLS_HANDSHAKE_HEADER_SZ, heap, + DYNAMIC_TYPE_DTLS_FRAG); + msg->ready = 1; + if (msg->raw == NULL) { + DtlsMsgDelete(msg, heap); + msg = NULL; + } } } @@ -8141,14 +8142,13 @@ void DtlsMsgDelete(DtlsMsg* item, void* heap) WOLFSSL_ENTER("DtlsMsgDelete()"); if (item != NULL) { - DtlsFrag* cur = item->fragList; - while (cur != NULL) { - DtlsFrag* next = cur->next; - XFREE(cur, heap, DYNAMIC_TYPE_DTLS_FRAG); - cur = next; + while (item->fragBucketList != NULL) { + DtlsFragBucket* next = item->fragBucketList->next; + DtlsMsgDestroyFragBucket(item->fragBucketList, heap); + item->fragBucketList = next; } - if (item->buf != NULL) - XFREE(item->buf, heap, DYNAMIC_TYPE_DTLS_BUFFER); + if (item->raw != NULL) + XFREE(item->raw, heap, DYNAMIC_TYPE_DTLS_FRAG); XFREE(item, heap, DYNAMIC_TYPE_DTLS_MSG); } } @@ -8187,131 +8187,279 @@ void DtlsTxMsgListClean(WOLFSSL* ssl) ssl->dtls_tx_msg_list = head; } -/* Create a DTLS Fragment from *begin - end, adjust new *begin and bytesLeft */ -static DtlsFrag* CreateFragment(word32* begin, word32 end, const byte* data, - byte* buf, word32* bytesLeft, void* heap) +static DtlsFragBucket* DtlsMsgCreateFragBucket(word32 offset, const byte* data, + word32 dataSz, void* heap) { - DtlsFrag* newFrag; - word32 added = end - *begin + 1; - - WOLFSSL_ENTER("CreateFragment()"); - - (void)heap; - newFrag = (DtlsFrag*)XMALLOC(sizeof(DtlsFrag), heap, - DYNAMIC_TYPE_DTLS_FRAG); - if (newFrag != NULL) { - newFrag->next = NULL; - newFrag->begin = *begin; - newFrag->end = end; - - XMEMCPY(buf + *begin, data, added); - *bytesLeft -= added; - *begin = newFrag->end + 1; + DtlsFragBucket* bucket = + (DtlsFragBucket*)XMALLOC(sizeof(DtlsFragBucket) + dataSz, heap, + DYNAMIC_TYPE_DTLS_FRAG); + if (bucket != NULL) { + XMEMSET(bucket, 0, sizeof(*bucket)); + bucket->next = NULL; + bucket->offset = offset; + bucket->sz = dataSz; + if (data != NULL) + XMEMCPY(bucket->buf, data, dataSz); } - - return newFrag; + return bucket; } - -int DtlsMsgSet(DtlsMsg* msg, word32 seq, word16 epoch, const byte* data, byte type, - word32 fragOffset, word32 fragSz, void* heap) +void DtlsMsgDestroyFragBucket(DtlsFragBucket* fragBucket, void* heap) { - WOLFSSL_ENTER("DtlsMsgSet()"); - if (msg != NULL && data != NULL && msg->fragSz <= msg->sz && - fragSz <= msg->sz && fragOffset <= msg->sz && - (fragOffset + fragSz) <= msg->sz) { - DtlsFrag* cur = msg->fragList; - DtlsFrag* prev = cur; - DtlsFrag* newFrag; - word32 bytesLeft = fragSz; /* could be overlapping fragment */ - word32 startOffset = fragOffset; - word32 added; + (void)heap; + XFREE(fragBucket, heap, DYNAMIC_TYPE_DTLS_FRAG); +} - msg->seq = seq; - msg->epoch = epoch; - msg->type = type; +/* + * data overlaps with cur but is before next. + * data + dataSz has to end before or inside next. next can be NULL. + */ +static DtlsFragBucket* DtlsMsgCombineFragBuckets(DtlsMsg* msg, + DtlsFragBucket* cur, DtlsFragBucket* next, word32 offset, + const byte* data, word32 dataSz, void* heap) +{ + word32 offsetEnd = offset + dataSz; + word32 newOffset = min(cur->offset, offset); + word32 newOffsetEnd; + word32 newSz; + word32 overlapSz = cur->sz; + DtlsFragBucket** chosenBucket; + DtlsFragBucket* newBucket; + DtlsFragBucket* otherBucket; + byte combineNext = FALSE; - if (fragOffset == 0) { - XMEMCPY(msg->buf, data - DTLS_HANDSHAKE_HEADER_SZ, - DTLS_HANDSHAKE_HEADER_SZ); - c32to24(msg->sz, msg->msg - DTLS_HANDSHAKE_FRAG_SZ); + if (next != NULL && offsetEnd >= next->offset) + combineNext = TRUE; + + if (combineNext) + newOffsetEnd = next->offset + next->sz; + else + newOffsetEnd = max(cur->offset + cur->sz, offsetEnd); + + newSz = newOffsetEnd - newOffset; + + /* Expand the larger bucket if data bridges the gap between cur and next */ + if (!combineNext || cur->sz >= next->sz) { + chosenBucket = &cur; + otherBucket = next; + } + else { + chosenBucket = &next; + otherBucket = cur; + } + + { + DtlsFragBucket* tmp = (DtlsFragBucket*)XREALLOC(*chosenBucket, + sizeof(DtlsFragBucket) + newSz, heap, DYNAMIC_TYPE_DTLS_FRAG); + if (tmp == NULL) + return NULL; + if (chosenBucket == &next) { + /* Update the link */ + DtlsFragBucket* beforeNext = cur; + while (beforeNext->next != next) + beforeNext = beforeNext->next; + beforeNext->next = tmp; } + newBucket = *chosenBucket = tmp; + } - /* if no message data, just return */ - if (fragSz == 0) - return 0; + if (combineNext) { + /* Put next first since it will always be at the end. Use memmove since + * newBucket may be next. */ + XMEMMOVE(newBucket->buf + (next->offset - newOffset), next->buf, + next->sz); + /* memory after newOffsetEnd is already copied. Don't do extra work. */ + newOffsetEnd = next->offset; + } - /* if list is empty add full fragment to front */ - if (cur == NULL) { - newFrag = CreateFragment(&fragOffset, fragOffset + fragSz - 1, data, - msg->msg, &bytesLeft, heap); - if (newFrag == NULL) - return MEMORY_E; - - msg->fragSz = fragSz; - msg->fragList = newFrag; - - return 0; + if (newOffset == offset) { + /* data comes first */ + if (newOffsetEnd <= offsetEnd) { + /* data encompasses cur. only copy data */ + XMEMCPY(newBucket->buf, data, + min(dataSz, newOffsetEnd - newOffset)); } - - /* add to front if before current front, up to next->begin */ - if (fragOffset < cur->begin) { - word32 end = fragOffset + fragSz - 1; - - if (end >= cur->begin) - end = cur->begin - 1; - - added = end - fragOffset + 1; - newFrag = CreateFragment(&fragOffset, end, data, msg->msg, - &bytesLeft, heap); - if (newFrag == NULL) - return MEMORY_E; - - msg->fragSz += added; - - newFrag->next = cur; - msg->fragList = newFrag; - } - - /* while we have bytes left, try to find a gap to fill */ - while (bytesLeft > 0) { - /* get previous packet in list */ - while (cur && (fragOffset >= cur->begin)) { - prev = cur; - cur = cur->next; - } - - /* don't add duplicate data */ - if (prev->end >= fragOffset) { - if ( (fragOffset + bytesLeft - 1) <= prev->end) - return 0; - fragOffset = prev->end + 1; - bytesLeft = startOffset + fragSz - fragOffset; - } - - if (cur == NULL) - /* we're at the end */ - added = bytesLeft; - else - /* we're in between two frames */ - added = min(bytesLeft, cur->begin - fragOffset); - - /* data already there */ - if (added == 0) - continue; - - newFrag = CreateFragment(&fragOffset, fragOffset + added - 1, - data + fragOffset - startOffset, - msg->msg, &bytesLeft, heap); - if (newFrag == NULL) - return MEMORY_E; - - msg->fragSz += added; - - newFrag->next = prev->next; - prev->next = newFrag; + else { + /* data -> cur. memcpy as much possible as its faster. */ + XMEMMOVE(newBucket->buf + dataSz, cur->buf, + cur->sz - (offsetEnd - cur->offset)); + XMEMCPY(newBucket->buf, data, dataSz); } } + else { + /* cur -> data */ + word32 curOffsetEnd = cur->offset + cur->sz; + if (newBucket != cur) + XMEMCPY(newBucket->buf, cur->buf, cur->sz); + XMEMCPY(newBucket->buf + cur->sz, + data + (curOffsetEnd - offset), + newOffsetEnd - curOffsetEnd); + } + /* FINALLY the newBucket is populated correctly */ + + /* All buckets up to and including next (if combining) have to be free'd */ + { + DtlsFragBucket* toFree = cur->next; + while (toFree != next) { + DtlsFragBucket* n = toFree->next; + overlapSz += toFree->sz; + DtlsMsgDestroyFragBucket(toFree, heap); + msg->fragBucketListCount--; + toFree = n; + } + if (combineNext) { + newBucket->next = next->next; + overlapSz += next->sz; + DtlsMsgDestroyFragBucket(otherBucket, heap); + msg->fragBucketListCount--; + } + else { + newBucket->next = next; + } + } + /* Adjust size in msg */ + msg->bytesReceived += newSz - overlapSz; + newBucket->offset = newOffset; + newBucket->sz = newSz; + return newBucket; +} + +static void DtlsMsgAssembleCompleteMessage(DtlsMsg* msg) +{ + /* We have received all necessary fragments. Reconstruct the header. */ + if (msg->fragBucketListCount != 1 || msg->fragBucketList->offset != 0 || + msg->fragBucketList->sz != msg->sz) { + WOLFSSL_MSG("Major error in fragment assembly logic"); + return; + } + + /* Re-cycle the DtlsFragBucket as the buffer that holds the complete + * handshake message and the header. */ + msg->raw = (byte*)msg->fragBucketList; + msg->fullMsg = msg->fragBucketList->buf; + msg->ready = 1; + + /* frag->padding makes sure we can fit the entire DTLS handshake header + * before frag->buf */ + DtlsHandShakeHeader* dtls = + (DtlsHandShakeHeader*)(msg->fragBucketList->buf - + DTLS_HANDSHAKE_HEADER_SZ); + + msg->fragBucketList = NULL; + msg->fragBucketListCount = 0; + + dtls->type = msg->type; + c32to24(msg->sz, dtls->length); + c16toa(msg->seq, dtls->message_seq); + c32to24(0, dtls->fragment_offset); + c32to24(msg->sz, dtls->fragment_length); +} + +int DtlsMsgSet(DtlsMsg* msg, word32 seq, word16 epoch, const byte* data, byte type, + word32 fragOffset, word32 fragSz, void* heap, word32 totalLen) +{ + word32 fragOffsetEnd = fragOffset + fragSz; + + WOLFSSL_ENTER("DtlsMsgSet()"); + + if (msg == NULL || data == NULL || msg->sz != totalLen || + fragOffsetEnd > totalLen) { + WOLFSSL_ERROR_VERBOSE(BAD_FUNC_ARG); + return BAD_FUNC_ARG; + } + + if (msg->ready) + return 0; /* msg is already complete */ + + if (msg->type != no_shake) { + /* msg is already populated with the correct seq, epoch, and type */ + if (msg->type != type || msg->epoch != epoch || msg->seq != seq) { + WOLFSSL_ERROR_VERBOSE(SEQUENCE_ERROR); + return SEQUENCE_ERROR; + } + } + else { + msg->type = type; + msg->epoch = epoch; + msg->seq = seq; + } + + if (msg->fragBucketList == NULL) { + /* Clean list. Create first fragment. */ + msg->fragBucketList = DtlsMsgCreateFragBucket(fragOffset, data, fragSz, heap); + msg->bytesReceived = fragSz; + msg->fragBucketListCount++; + } + else { + /* See if we can expand any existing bucket to fit this new data into */ + DtlsFragBucket* prev = NULL; + DtlsFragBucket* cur = msg->fragBucketList; + byte done = 0; + for (; cur != NULL; prev = cur, cur = cur->next) { + word32 curOffset = cur->offset; + word32 curEnd = cur->offset + cur->sz; + + if (fragOffset >= curOffset && fragOffsetEnd <= curEnd) { + /* We already have this fragment */ + done = 1; + break; + } + else if (fragOffset <= curEnd) { + /* found place to store fragment */ + break; + } + } + if (!done) { + if (cur == NULL) { + /* We reached the end of the list. data is after and disjointed + * from anything we have received so far. */ + if (msg->fragBucketListCount >= DTLS_FRAG_POOL_SZ) { + WOLFSSL_ERROR_VERBOSE(DTLS_TOO_MANY_FRAGMENTS_E); + return DTLS_TOO_MANY_FRAGMENTS_E; + } + prev->next = DtlsMsgCreateFragBucket(fragOffset, data, fragSz, heap); + if (prev->next != NULL) { + msg->bytesReceived += fragSz; + msg->fragBucketListCount++; + } + } + else if (prev == NULL && fragOffsetEnd < cur->offset) { + /* This is the new first fragment we have received */ + if (msg->fragBucketListCount >= DTLS_FRAG_POOL_SZ) { + WOLFSSL_ERROR_VERBOSE(DTLS_TOO_MANY_FRAGMENTS_E); + return DTLS_TOO_MANY_FRAGMENTS_E; + } + msg->fragBucketList = DtlsMsgCreateFragBucket(fragOffset, data, + fragSz, heap); + if (msg->fragBucketList != NULL) { + msg->fragBucketList->next = cur; + msg->bytesReceived += fragSz; + msg->fragBucketListCount++; + } + else { + /* reset on error */ + msg->fragBucketList = cur; + } + } + else { + /* Find if this fragment overlaps with any more */ + DtlsFragBucket* next = cur->next; + DtlsFragBucket** prev_next = prev != NULL + ? &prev->next : &msg->fragBucketList; + while (next != NULL && + (next->offset + next->sz) <= fragOffsetEnd) + next = next->next; + /* We can combine the buckets */ + *prev_next = DtlsMsgCombineFragBuckets(msg, cur, next, + fragOffset, data, fragSz, heap); + if (*prev_next == NULL) /* reset on error */ + *prev_next = cur; + } + } + } + + if (msg->bytesReceived == msg->sz) + DtlsMsgAssembleCompleteMessage(msg); return 0; } @@ -8353,10 +8501,10 @@ void DtlsMsgStore(WOLFSSL* ssl, word16 epoch, word32 seq, const byte* data, if (head != NULL) { DtlsMsg* cur = DtlsMsgFind(head, epoch, seq); if (cur == NULL) { - cur = DtlsMsgNew(dataSz, heap); + cur = DtlsMsgNew(dataSz, 0, heap); if (cur != NULL) { if (DtlsMsgSet(cur, seq, epoch, data, type, - fragOffset, fragSz, heap) < 0) { + fragOffset, fragSz, heap, dataSz) < 0) { DtlsMsgDelete(cur, heap); } else { @@ -8368,13 +8516,13 @@ void DtlsMsgStore(WOLFSSL* ssl, word16 epoch, word32 seq, const byte* data, else { /* If this fails, the data is just dropped. */ DtlsMsgSet(cur, seq, epoch, data, type, fragOffset, - fragSz, heap); + fragSz, heap, dataSz); } } else { - head = DtlsMsgNew(dataSz, heap); + head = DtlsMsgNew(dataSz, 0, heap); if (DtlsMsgSet(head, seq, epoch, data, type, fragOffset, - fragSz, heap) < 0) { + fragSz, heap, dataSz) < 0) { DtlsMsgDelete(head, heap); head = NULL; } @@ -8439,12 +8587,12 @@ int DtlsMsgPoolSave(WOLFSSL* ssl, const byte* data, word32 dataSz, return DTLS_POOL_SZ_E; } - item = DtlsMsgNew(dataSz, ssl->heap); + item = DtlsMsgNew(dataSz, 1, ssl->heap); if (item != NULL) { DtlsMsg* cur = ssl->dtls_tx_msg_list; - XMEMCPY(item->buf, data, dataSz); + XMEMCPY(item->raw, data, dataSz); item->sz = dataSz; item->epoch = ssl->keys.dtls_epoch; item->seq = ssl->keys.dtls_handshake_number; @@ -8580,7 +8728,7 @@ int DtlsMsgPoolSend(WOLFSSL* ssl, int sendOnlyFirstPacket) if (pool->epoch == 0) { DtlsRecordLayerHeader* dtls; - dtls = (DtlsRecordLayerHeader*)pool->buf; + dtls = (DtlsRecordLayerHeader*)pool->raw; /* If the stored record's epoch is 0, and the currently set * epoch is 0, use the "current order" sequence number. * If the stored record's epoch is 0 and the currently set @@ -8599,7 +8747,7 @@ int DtlsMsgPoolSend(WOLFSSL* ssl, int sendOnlyFirstPacket) XMEMCPY(ssl->buffers.outputBuffer.buffer + ssl->buffers.outputBuffer.idx + ssl->buffers.outputBuffer.length, - pool->buf, pool->sz); + pool->raw, pool->sz); ssl->buffers.outputBuffer.length += pool->sz; } else { @@ -8608,7 +8756,7 @@ int DtlsMsgPoolSend(WOLFSSL* ssl, int sendOnlyFirstPacket) byte* output; int inputSz, sendSz; - input = pool->buf; + input = pool->raw; inputSz = pool->sz; sendSz = inputSz + cipherExtraData(ssl); @@ -9380,8 +9528,8 @@ static int SendHandshakeMsg(WOLFSSL* ssl, byte* input, word32 inputSz, if (ssl->options.dtls) { data -= DTLS_HANDSHAKE_HEADER_SZ; dataSz += DTLS_HANDSHAKE_HEADER_SZ; - AddHandShakeHeader(data, - inputSz, ssl->fragOffset, fragSz, type, ssl); + AddHandShakeHeader(data, inputSz, ssl->fragOffset, fragSz, + type, ssl); ssl->keys.dtls_handshake_number--; } if (IsDtlsNotSctpMode(ssl) && @@ -16112,15 +16260,14 @@ int DtlsMsgDrain(WOLFSSL* ssl) * last message... */ while (item != NULL && ssl->keys.dtls_expected_peer_handshake_number == item->seq && - item->fragSz == item->sz && - ret == 0) { + item->ready && ret == 0) { word32 idx = 0; #ifdef WOLFSSL_NO_TLS12 - ret = DoTls13HandShakeMsgType(ssl, item->msg, &idx, item->type, + ret = DoTls13HandShakeMsgType(ssl, item->fullMsg, &idx, item->type, item->sz, item->sz); #else - ret = DoHandShakeMsgType(ssl, item->msg, &idx, item->type, + ret = DoHandShakeMsgType(ssl, item->fullMsg, &idx, item->type, item->sz, item->sz); #endif if (ret == 0) { @@ -16356,8 +16503,7 @@ static int DoDtlsHandShakeMsg(WOLFSSL* ssl, byte* input, word32* inOutIdx, } #endif ret = 0; - if (ssl->dtls_rx_msg_list != NULL && - ssl->dtls_rx_msg_list->fragSz >= ssl->dtls_rx_msg_list->sz) + if (ssl->dtls_rx_msg_list != NULL && ssl->dtls_rx_msg_list->ready) ret = DtlsMsgDrain(ssl); } else { diff --git a/tests/api.c b/tests/api.c index add8b76a5..415323970 100644 --- a/tests/api.c +++ b/tests/api.c @@ -57640,6 +57640,139 @@ static int test_wolfSSL_DtlsUpdateWindow(void) } #endif /* WOLFSSL_DTLS */ +#ifdef WOLFSSL_DTLS +static int DFB_TEST(WOLFSSL* ssl, word32 seq, word32 len, word32 f_offset, + word32 f_len, word32 f_count, byte ready, word32 bytesReceived) +{ + DtlsMsg* cur; + static byte msg[100]; + static byte msgInit = 0; + + if (!msgInit) { + int i; + for (i = 0; i < 100; i++) + msg[i] = i + 1; + msgInit = 1; + } + + /* Sanitize test parameters */ + if (len > sizeof(msg)) + return -1; + if (f_offset + f_len > sizeof(msg)) + return -1; + + DtlsMsgStore(ssl, 0, seq, msg + f_offset, len, certificate, f_offset, f_len, NULL); + + if (ssl->dtls_rx_msg_list == NULL) + return -100; + + if ((cur = DtlsMsgFind(ssl->dtls_rx_msg_list, 0, seq)) == NULL) + return -200; + if (cur->fragBucketListCount != f_count) + return -300; + if (cur->ready != ready) + return -400; + if (cur->bytesReceived != bytesReceived) + return -500; + if (ready) { + if (cur->fragBucketList != NULL) + return -600; + if (XMEMCMP(cur->fullMsg, msg, cur->sz) != 0) + return -700; + } + else { + DtlsFragBucket* fb; + if (cur->fragBucketList == NULL) + return -800; + for (fb = cur->fragBucketList; fb != NULL; fb = fb->next) { + if (XMEMCMP(fb->buf, msg + fb->offset, fb->sz) != 0) + return -900; + } + } + return 0; +} + +static void DFB_TEST_RESET(WOLFSSL* ssl) +{ + DtlsMsgListDelete(ssl->dtls_rx_msg_list, ssl->heap); + ssl->dtls_rx_msg_list = NULL; + ssl->dtls_rx_msg_list_sz = 0; +} + +static int test_wolfSSL_DTLS_fragment_buckets(void) +{ + WOLFSSL ssl[1]; + + printf(testingFmt, "wolfSSL_DTLS_fragment_buckets()"); + + XMEMSET(ssl, 0, sizeof(*ssl)); + + AssertIntEQ(DFB_TEST(ssl, 0, 100, 0, 100, 0, 1, 100), 0); /* 0-100 */ + + AssertIntEQ(DFB_TEST(ssl, 1, 100, 0, 20, 1, 0, 20), 0); /* 0-20 */ + AssertIntEQ(DFB_TEST(ssl, 1, 100, 20, 20, 1, 0, 40), 0); /* 20-40 */ + AssertIntEQ(DFB_TEST(ssl, 1, 100, 40, 20, 1, 0, 60), 0); /* 40-60 */ + AssertIntEQ(DFB_TEST(ssl, 1, 100, 60, 20, 1, 0, 80), 0); /* 60-80 */ + AssertIntEQ(DFB_TEST(ssl, 1, 100, 80, 20, 0, 1, 100), 0); /* 80-100 */ + + /* Test all permutations of 3 regions */ + /* 1 2 3 */ + AssertIntEQ(DFB_TEST(ssl, 2, 100, 0, 30, 1, 0, 30), 0); /* 0-30 */ + AssertIntEQ(DFB_TEST(ssl, 2, 100, 30, 30, 1, 0, 60), 0); /* 30-60 */ + AssertIntEQ(DFB_TEST(ssl, 2, 100, 60, 40, 0, 1, 100), 0); /* 60-100 */ + /* 1 3 2 */ + AssertIntEQ(DFB_TEST(ssl, 3, 100, 0, 30, 1, 0, 30), 0); /* 0-30 */ + AssertIntEQ(DFB_TEST(ssl, 3, 100, 60, 40, 2, 0, 70), 0); /* 60-100 */ + AssertIntEQ(DFB_TEST(ssl, 3, 100, 30, 30, 0, 1, 100), 0); /* 30-60 */ + /* 2 1 3 */ + AssertIntEQ(DFB_TEST(ssl, 4, 100, 30, 30, 1, 0, 30), 0); /* 30-60 */ + AssertIntEQ(DFB_TEST(ssl, 4, 100, 0, 30, 1, 0, 60), 0); /* 0-30 */ + AssertIntEQ(DFB_TEST(ssl, 4, 100, 60, 40, 0, 1, 100), 0); /* 60-100 */ + /* 2 3 1 */ + AssertIntEQ(DFB_TEST(ssl, 5, 100, 30, 30, 1, 0, 30), 0); /* 30-60 */ + AssertIntEQ(DFB_TEST(ssl, 5, 100, 60, 40, 1, 0, 70), 0); /* 60-100 */ + AssertIntEQ(DFB_TEST(ssl, 5, 100, 0, 30, 0, 1, 100), 0); /* 0-30 */ + /* 3 1 2 */ + AssertIntEQ(DFB_TEST(ssl, 6, 100, 60, 40, 1, 0, 40), 0); /* 60-100 */ + AssertIntEQ(DFB_TEST(ssl, 6, 100, 0, 30, 2, 0, 70), 0); /* 0-30 */ + AssertIntEQ(DFB_TEST(ssl, 6, 100, 30, 30, 0, 1, 100), 0); /* 30-60 */ + /* 3 2 1 */ + AssertIntEQ(DFB_TEST(ssl, 7, 100, 60, 40, 1, 0, 40), 0); /* 60-100 */ + AssertIntEQ(DFB_TEST(ssl, 7, 100, 30, 30, 1, 0, 70), 0); /* 30-60 */ + AssertIntEQ(DFB_TEST(ssl, 7, 100, 0, 30, 0, 1, 100), 0); /* 0-30 */ + + /* Test overlapping regions */ + AssertIntEQ(DFB_TEST(ssl, 8, 100, 0, 30, 1, 0, 30), 0); /* 0-30 */ + AssertIntEQ(DFB_TEST(ssl, 8, 100, 20, 10, 1, 0, 30), 0); /* 20-30 */ + AssertIntEQ(DFB_TEST(ssl, 8, 100, 70, 10, 2, 0, 40), 0); /* 70-80 */ + AssertIntEQ(DFB_TEST(ssl, 8, 100, 20, 30, 2, 0, 60), 0); /* 20-50 */ + AssertIntEQ(DFB_TEST(ssl, 8, 100, 40, 60, 0, 1, 100), 0); /* 40-100 */ + + /* Test overlapping multiple regions */ + AssertIntEQ(DFB_TEST(ssl, 9, 100, 0, 20, 1, 0, 20), 0); /* 0-20 */ + AssertIntEQ(DFB_TEST(ssl, 9, 100, 30, 5, 2, 0, 25), 0); /* 30-35 */ + AssertIntEQ(DFB_TEST(ssl, 9, 100, 40, 5, 3, 0, 30), 0); /* 40-45 */ + AssertIntEQ(DFB_TEST(ssl, 9, 100, 50, 5, 4, 0, 35), 0); /* 50-55 */ + AssertIntEQ(DFB_TEST(ssl, 9, 100, 60, 5, 5, 0, 40), 0); /* 60-65 */ + AssertIntEQ(DFB_TEST(ssl, 9, 100, 70, 5, 6, 0, 45), 0); /* 70-75 */ + AssertIntEQ(DFB_TEST(ssl, 9, 100, 30, 25, 4, 0, 55), 0); /* 30-55 */ + AssertIntEQ(DFB_TEST(ssl, 9, 100, 55, 15, 2, 0, 65), 0); /* 55-70 */ + AssertIntEQ(DFB_TEST(ssl, 9, 100, 75, 25, 2, 0, 90), 0); /* 75-100 */ + AssertIntEQ(DFB_TEST(ssl, 9, 100, 10, 25, 0, 1, 100), 0); /* 10-35 */ + + AssertIntEQ(DFB_TEST(ssl, 10, 100, 0, 20, 1, 0, 20), 0); /* 0-20 */ + AssertIntEQ(DFB_TEST(ssl, 10, 100, 30, 20, 2, 0, 40), 0); /* 30-50 */ + AssertIntEQ(DFB_TEST(ssl, 10, 100, 0, 40, 1, 0, 50), 0); /* 0-40 */ + AssertIntEQ(DFB_TEST(ssl, 10, 100, 50, 50, 0, 1, 100), 0); /* 10-35 */ + + DFB_TEST_RESET(ssl); + + printf(resultFmt, passed); + + return 0; +} +#endif + /*----------------------------------------------------------------------------* | Main *----------------------------------------------------------------------------*/ @@ -58535,6 +58668,7 @@ TEST_CASE testCases[] = { TEST_DECL(test_wolfSSL_FIPS_mode), #ifdef WOLFSSL_DTLS TEST_DECL(test_wolfSSL_DtlsUpdateWindow), + TEST_DECL(test_wolfSSL_DTLS_fragment_buckets), #endif TEST_DECL(test_ForceZero), diff --git a/wolfssl/internal.h b/wolfssl/internal.h index a2f95daf7..ed9bc5eae 100644 --- a/wolfssl/internal.h +++ b/wolfssl/internal.h @@ -1115,6 +1115,9 @@ enum { #define WOLFSSL_DTLS_MTU_ADDITIONAL_READ_BUFFER 500 #endif /* WOLFSSL_DTLS_MTU_ADDITIONAL_READ_BUFFER */ +#ifndef WOLFSSL_DTLS_FRAG_POOL_SZ + #define WOLFSSL_DTLS_FRAG_POOL_SZ 10 +#endif /* set minimum DH key size allowed */ #ifndef WOLFSSL_MIN_DHKEY_BITS @@ -1398,6 +1401,8 @@ enum Misc { DTLS_HANDSHAKE_FRAG_SZ = 3, /* fragment offset and length are 24 bit */ DTLS_POOL_SZ = 20, /* allowed number of list items in TX and * RX pool */ + DTLS_FRAG_POOL_SZ = WOLFSSL_DTLS_FRAG_POOL_SZ, + /* allowed number of fragments per msg */ DTLS_EXPORT_PRO = 165,/* wolfSSL protocol for serialized session */ DTLS_EXPORT_STATE_PRO = 166,/* wolfSSL protocol for serialized state */ TLS_EXPORT_PRO = 167,/* wolfSSL protocol for serialized TLS */ @@ -4435,23 +4440,36 @@ typedef struct DtlsRecordLayerHeader { } DtlsRecordLayerHeader; -typedef struct DtlsFrag { - word32 begin; - word32 end; - struct DtlsFrag* next; -} DtlsFrag; +/* Padding necessary to fit DTLS_HANDSHAKE_HEADER_SZ bytes before the buf member + * of the DtlsFragBucket struct. */ +#define WOLFSSL_DTLS_FRAG_BUCKET_PADDING \ + ((DTLS_HANDSHAKE_HEADER_SZ > (sizeof(struct DtlsFragBucket*) + \ + sizeof(word32) + sizeof(word32))) ? \ + (DTLS_HANDSHAKE_HEADER_SZ - sizeof(struct DtlsFragBucket*) - \ + sizeof(word32) - sizeof(word32)) : 0) +typedef struct DtlsFragBucket { + struct DtlsFragBucket* next; + word32 offset; + word32 sz; + byte padding[WOLFSSL_DTLS_FRAG_BUCKET_PADDING]; + byte buf[]; + /* Add new member initialization to CreateFragBucket */ +} DtlsFragBucket; typedef struct DtlsMsg { struct DtlsMsg* next; - byte* buf; - byte* msg; - DtlsFrag* fragList; - word32 fragSz; /* Length of fragments received */ + byte* raw; + byte* fullMsg; /* for TX fullMsg == raw. For RX this points to + * the start of the message after headers. */ + DtlsFragBucket* fragBucketList; + word32 bytesReceived; word16 epoch; /* Epoch that this message belongs to */ word32 seq; /* Handshake sequence number */ word32 sz; /* Length of whole message */ byte type; + byte fragBucketListCount; + byte ready:1; } DtlsMsg; @@ -5462,13 +5480,14 @@ WOLFSSL_LOCAL int cipherExtraData(WOLFSSL* ssl); #endif /* NO_WOLFSSL_SERVER */ #ifdef WOLFSSL_DTLS - WOLFSSL_LOCAL DtlsMsg* DtlsMsgNew(word32 sz, void* heap); + WOLFSSL_LOCAL DtlsMsg* DtlsMsgNew(word32 sz, byte tx, void* heap); WOLFSSL_LOCAL void DtlsMsgDelete(DtlsMsg* item, void* heap); WOLFSSL_LOCAL void DtlsMsgListDelete(DtlsMsg* head, void* heap); WOLFSSL_LOCAL void DtlsTxMsgListClean(WOLFSSL* ssl); WOLFSSL_LOCAL int DtlsMsgSet(DtlsMsg* msg, word32 seq, word16 epoch, const byte* data, byte type, - word32 fragOffset, word32 fragSz, void* heap); + word32 fragOffset, word32 fragSz, void* heap, + word32 totalLen); WOLFSSL_LOCAL DtlsMsg* DtlsMsgFind(DtlsMsg* head, word16 epoch, word32 seq); WOLFSSL_LOCAL void DtlsMsgStore(WOLFSSL* ssl, word16 epoch, word32 seq, @@ -5485,6 +5504,7 @@ WOLFSSL_LOCAL int cipherExtraData(WOLFSSL* ssl); WOLFSSL_LOCAL int VerifyForTxDtlsMsgDelete(WOLFSSL* ssl, DtlsMsg* item); WOLFSSL_LOCAL void DtlsMsgPoolReset(WOLFSSL* ssl); WOLFSSL_LOCAL int DtlsMsgPoolSend(WOLFSSL* ssl, int sendOnlyFirstPacket); + WOLFSSL_LOCAL void DtlsMsgDestroyFragBucket(DtlsFragBucket* fragBucket, void* heap); WOLFSSL_LOCAL int GetDtlsHandShakeHeader(WOLFSSL *ssl, const byte *input, word32 *inOutIdx, byte *type, word32 *size, word32 *fragOffset, word32 *fragSz, word32 totalSz);