diff --git a/cyassl/internal.h b/cyassl/internal.h index 517035cbc..9f48d2f1f 100644 --- a/cyassl/internal.h +++ b/cyassl/internal.h @@ -1390,6 +1390,7 @@ typedef struct DtlsMsg { word32 seq; /* Handshake sequence number */ word32 sz; /* Length of whole mesage */ word32 fragSz; /* Length of fragments received */ + byte type; byte* msg; } DtlsMsg; @@ -1690,10 +1691,11 @@ CYASSL_LOCAL int GrowInputBuffer(CYASSL* ssl, int size, int usedLength); CYASSL_LOCAL DtlsMsg* DtlsMsgNew(word32, void*); CYASSL_LOCAL void DtlsMsgDelete(DtlsMsg*, void*); CYASSL_LOCAL void DtlsMsgListDelete(DtlsMsg*, void*); - CYASSL_LOCAL void DtlsMsgSet(DtlsMsg*, word32, const byte*, word32, word32); + CYASSL_LOCAL void DtlsMsgSet(DtlsMsg*, word32, const byte*, byte, + word32, word32); CYASSL_LOCAL DtlsMsg* DtlsMsgFind(DtlsMsg*, word32); CYASSL_LOCAL DtlsMsg* DtlsMsgStore(DtlsMsg*, word32, const byte*, word32, - word32, word32, void*); + byte, word32, word32, void*); CYASSL_LOCAL DtlsMsg* DtlsMsgInsert(DtlsMsg*, DtlsMsg*); #endif /* CYASSL_DTLS */ diff --git a/src/internal.c b/src/internal.c index 0a9078863..bfd36042b 100644 --- a/src/internal.c +++ b/src/internal.c @@ -1699,11 +1699,12 @@ void DtlsMsgListDelete(DtlsMsg* head, void* heap) } -void DtlsMsgSet(DtlsMsg* msg, word32 seq, const byte* data, +void DtlsMsgSet(DtlsMsg* msg, word32 seq, const byte* data, byte type, word32 fragOffset, word32 fragSz) { if (msg != NULL && data != NULL && msg->fragSz <= msg->sz) { msg->seq = seq; + msg->type = type; msg->fragSz += fragSz; XMEMCPY(&msg->msg[fragOffset], data, fragSz); } @@ -1719,8 +1720,8 @@ DtlsMsg* DtlsMsgFind(DtlsMsg* head, word32 seq) } -DtlsMsg* DtlsMsgStore(DtlsMsg* head, word32 seq, const byte* data, word32 dataSz, - word32 fragOffset, word32 fragSz, void* heap) +DtlsMsg* DtlsMsgStore(DtlsMsg* head, word32 seq, const byte* data, + word32 dataSz, byte type, word32 fragOffset, word32 fragSz, void* heap) { /* See if seq exists in the list. If it isn't in the list, make @@ -1742,16 +1743,16 @@ DtlsMsg* DtlsMsgStore(DtlsMsg* head, word32 seq, const byte* data, word32 dataSz DtlsMsg* cur = DtlsMsgFind(head, seq); if (cur == NULL) { cur = DtlsMsgNew(dataSz, heap); - DtlsMsgSet(cur, seq, data, fragOffset, fragSz); + DtlsMsgSet(cur, seq, data, type, fragOffset, fragSz); head = DtlsMsgInsert(head, cur); } else { - DtlsMsgSet(cur, seq, data, fragOffset, fragSz); + DtlsMsgSet(cur, seq, data, type, fragOffset, fragSz); } } else { head = DtlsMsgNew(dataSz, heap); - DtlsMsgSet(head, seq, data, fragOffset, fragSz); + DtlsMsgSet(head, seq, data, type, fragOffset, fragSz); } return head; @@ -2349,7 +2350,6 @@ static int GetDtlsHandShakeHeader(CYASSL* ssl, const byte* input, c24to32(input + idx, fragOffset); idx += DTLS_HANDSHAKE_FRAG_SZ; c24to32(input + idx, fragSz); - idx += DTLS_HANDSHAKE_FRAG_SZ; return 0; } @@ -3094,66 +3094,55 @@ static int DoDtlsHandShakeMsg(CYASSL* ssl, byte* input, word32* inOutIdx, if (*inOutIdx + fragSz > totalSz) return INCOMPLETE_DATA; - if (fragSz < size) { - /* message is fragmented, knit back together */ - byte* buf = ssl->buffers.dtlsHandshake.buffer; - if (ssl->buffers.dtlsHandshake.length == 0) { - /* Need to add a header back into the data. The Hash is calculated - * as if this were a single message, not several fragments. */ - buf = (byte*)XMALLOC(size + DTLS_HANDSHAKE_HEADER_SZ, - ssl->heap, DYNAMIC_TYPE_NONE); - if (buf == NULL) - return MEMORY_ERROR; - - ssl->buffers.dtlsHandshake.length = size; - ssl->buffers.dtlsHandshake.buffer = buf; - ssl->buffers.dtlsUsed = 0; - ssl->buffers.dtlsType = type; - - /* Construct a new header for the reassembled message as if it - * were originally sent as one fragment for the hashing later. */ - XMEMCPY(buf, - input + *inOutIdx - DTLS_HANDSHAKE_HEADER_SZ, - DTLS_HANDSHAKE_HEADER_SZ - DTLS_HANDSHAKE_FRAG_SZ); - XMEMCPY(buf + DTLS_HANDSHAKE_HEADER_SZ - DTLS_HANDSHAKE_FRAG_SZ, - input + *inOutIdx - DTLS_HANDSHAKE_HEADER_SZ + ENUM_LEN, - DTLS_HANDSHAKE_FRAG_SZ); - } - /* readjust the buf pointer past the header */ - buf += DTLS_HANDSHAKE_HEADER_SZ; - - XMEMCPY(buf + fragOffset, input + *inOutIdx, fragSz); - ssl->buffers.dtlsUsed += fragSz; - *inOutIdx += fragSz; - - if (ssl->buffers.dtlsUsed != size) { - CYASSL_LEAVE("DoDtlsHandShakeMsg()", 0); - return 0; - } - else { - if (ssl->keys.dtls_peer_handshake_number == + /* Check the handshake sequence number first. If out of order, + * add the current message to the list. If the message is in order, + * but it is a fragment, add the current message to the list, then + * check the head of the list to see if it is complete, if so, pop + * it out as the current message. If the message is complete and in + * order, process it. Check the head of the list to see if it is in + * order, if so, process it. (Repeat until list exhausted.) If the + * head is out of order, return for more processing. + * NOTE: The hash is calculated on the data, not the header. In + * DoHandShakeMsgType(), HashInput starts with inOutIdx. + */ + if (ssl->keys.dtls_peer_handshake_number > ssl->keys.dtls_expected_peer_handshake_number) { - word32 idx = 0; - totalSz = size; - ssl->keys.dtls_expected_peer_handshake_number++; - ret = DoHandShakeMsgType(ssl, buf, &idx, type, size, totalSz); - } - else { - *inOutIdx += size; - ret = 0; - } + /* Current message is out of order. It will get stored in the list. + * Storing also takes care of defragmentation. */ + ssl->dtls_msg_list = DtlsMsgStore(ssl->dtls_msg_list, + ssl->keys.dtls_peer_handshake_number, input + *inOutIdx, + size, type, fragOffset, fragSz, ssl->heap); + *inOutIdx += fragSz; + ret = 0; + } + else if (ssl->keys.dtls_peer_handshake_number < + ssl->keys.dtls_expected_peer_handshake_number) { + /* Already saw this message and processed it. It can be ignored. */ + *inOutIdx += fragSz; + ret = 0; + } + else if (fragSz < size) { + /* Since this branch is in order, but fragmented, dtls_msg_list will be + * pointing the the message with this fragment in it. Check it to see + * if it is completed. */ + ssl->dtls_msg_list = DtlsMsgStore(ssl->dtls_msg_list, + ssl->keys.dtls_peer_handshake_number, input + *inOutIdx, + size, type, fragOffset, fragSz, ssl->heap); + *inOutIdx += fragSz; + if (ssl->dtls_msg_list->fragSz >= ssl->dtls_msg_list->sz) { + DtlsMsg* item = ssl->dtls_msg_list; + word32 idx = 0; + ssl->keys.dtls_expected_peer_handshake_number++; + ret = DoHandShakeMsgType(ssl, item->msg, &idx, + item->type, item->sz, item->sz); + ssl->dtls_msg_list = item->next; + DtlsMsgDelete(item, ssl->heap); } } else { - if (ssl->keys.dtls_peer_handshake_number == - ssl->keys.dtls_expected_peer_handshake_number) { - ssl->keys.dtls_expected_peer_handshake_number++; - ret = DoHandShakeMsgType(ssl, input, inOutIdx, type, size, totalSz); - } - else { - *inOutIdx += size; - ret = 0; - } + /* This branch is in order next, and a complete message. */ + ssl->keys.dtls_expected_peer_handshake_number++; + ret = DoHandShakeMsgType(ssl, input, inOutIdx, type, size, totalSz); } if (ssl->buffers.dtlsHandshake.buffer != NULL) { diff --git a/src/ssl.c b/src/ssl.c index 97ef9682f..e5a3c9050 100644 --- a/src/ssl.c +++ b/src/ssl.c @@ -2467,6 +2467,8 @@ int CyaSSL_dtls_got_timeout(CYASSL* ssl) { #ifdef CYASSL_DTLS int result = SSL_SUCCESS; + DtlsMsgListDelete(ssl->dtls_msg_list, ssl->heap); + ssl->dtls_msg_list = NULL; if (DtlsPoolTimeout(ssl) < 0 || DtlsPoolSend(ssl) < 0) { result = SSL_FATAL_ERROR; }