diff --git a/src/internal.c b/src/internal.c index 913dc1335..6d531b621 100644 --- a/src/internal.c +++ b/src/internal.c @@ -7831,7 +7831,7 @@ static int DoHandShakeMsg(WOLFSSL* ssl, byte* input, word32* inOutIdx, static INLINE int DtlsCheckWindow(WOLFSSL* ssl) { - DtlsSeq window; + word32* window; word16 cur_hi, next_hi; word32 cur_lo, next_lo, diff; int curLT; @@ -7869,23 +7869,28 @@ static INLINE int DtlsCheckWindow(WOLFSSL* ssl) } /* Check to see that the next value is greater than the number of messages - * trackable in the window (32 or 64), and that the difference between the - * next expected sequence number and the received sequence number is - * inside the window. */ + * trackable in the window, and that the difference between the next + * expected sequence number and the received sequence number is inside the + * window. */ if ((next_hi || next_lo > DTLS_SEQ_BITS) && curLT && (diff > DTLS_SEQ_BITS)) { WOLFSSL_MSG("Current record sequence number from the past."); return 0; } - else if (curLT && (window & ((DtlsSeq)1 << (diff - 1)))) { - WOLFSSL_MSG("Current record sequence number already received."); - return 0; - } else if (!curLT && (diff > DTLS_SEQ_BITS)) { WOLFSSL_MSG("Rejecting message too far into the future."); return 0; } + else if (curLT) { + word32 idx = diff / DTLS_WORD_BITS; + word32 newDiff = diff % DTLS_WORD_BITS; + + if (window[idx] & (1 << (newDiff - 1))) { + WOLFSSL_MSG("Current record sequence number already received."); + return 0; + } + } return 1; } @@ -7893,7 +7898,7 @@ static INLINE int DtlsCheckWindow(WOLFSSL* ssl) static INLINE int DtlsUpdateWindow(WOLFSSL* ssl) { - DtlsSeq* window; + word32* window; word32* next_lo; word16* next_hi; int curLT; @@ -7903,12 +7908,12 @@ static INLINE int DtlsUpdateWindow(WOLFSSL* ssl) if (ssl->keys.curEpoch == ssl->keys.nextEpoch) { next_hi = &ssl->keys.nextSeq_hi; next_lo = &ssl->keys.nextSeq_lo; - window = &ssl->keys.window; + window = ssl->keys.window; } else { next_hi = &ssl->keys.prevSeq_hi; next_lo = &ssl->keys.prevSeq_lo; - window = &ssl->keys.prevWindow; + window = ssl->keys.prevWindow; } cur_hi = ssl->keys.curSeq_hi; @@ -7924,14 +7929,36 @@ static INLINE int DtlsUpdateWindow(WOLFSSL* ssl) } if (curLT) { - *window |= ((DtlsSeq)1 << (diff - 1)); + word32 idx = diff / DTLS_WORD_BITS; + word32 newDiff = diff % DTLS_WORD_BITS; + + window[idx] |= (1 << (newDiff - 1)); } else { if (diff >= DTLS_SEQ_BITS) - *window = 0; - else - *window <<= (1 + diff); - *window |= 1; + XMEMSET(window, 0, DTLS_SEQ_SZ); + else { + word32 idx, newDiff, temp, i; + word32 oldWindow[WOLFSSL_DTLS_WINDOW_WORDS]; + + temp = 0; + diff++; + idx = diff / DTLS_WORD_BITS; + newDiff = diff % DTLS_WORD_BITS; + + XMEMCPY(oldWindow, window, sizeof(oldWindow)); + + for (i = 0; i < WOLFSSL_DTLS_WINDOW_WORDS; i++) { + if (i < idx) + window[i] = 0; + else { + temp |= (oldWindow[i-idx] << newDiff); + window[i] = temp; + temp = oldWindow[i-idx] >> (DTLS_WORD_BITS - newDiff); + } + } + } + window[0] |= 1; *next_lo = cur_lo + 1; if (*next_lo < cur_lo) (*next_hi)++; @@ -9597,8 +9624,9 @@ int ProcessReply(WOLFSSL* ssl) DtlsMsgPoolReset(ssl); ssl->keys.nextEpoch++; ssl->keys.nextSeq_lo = 0; - ssl->keys.prevWindow = ssl->keys.window; - ssl->keys.window = 0; + XMEMCPY(ssl->keys.prevWindow, ssl->keys.window, + DTLS_SEQ_SZ); + XMEMSET(ssl->keys.window, 0, DTLS_SEQ_SZ); } #endif diff --git a/wolfssl/internal.h b/wolfssl/internal.h index b29199fae..3f895df1f 100644 --- a/wolfssl/internal.h +++ b/wolfssl/internal.h @@ -1584,12 +1584,13 @@ typedef struct WOLFSSL_DTLS_CTX { #ifdef WOLFSSL_DTLS - #ifdef WORD64_AVAILABLE - typedef word64 DtlsSeq; - #else - typedef word32 DtlsSeq; - #endif - #define DTLS_SEQ_BITS (sizeof(DtlsSeq) * CHAR_BIT) + #ifndef WOLFSSL_DTLS_WINDOW_WORDS + #define WOLFSSL_DTLS_WINDOW_WORDS 2 + #endif /* WOLFSSL_DTLS_WINDOW_WORDS */ + + #define DTLS_WORD_BITS (sizeof(word32) * CHAR_BIT) + #define DTLS_SEQ_BITS (WOLFSSL_DTLS_WINDOW_WORDS * DTLS_WORD_BITS) + #define DTLS_SEQ_SZ (sizeof(word32) * WOLFSSL_DTLS_WINDOW_WORDS) #endif /* WOLFSSL_DTLS */ @@ -1617,7 +1618,8 @@ typedef struct Keys { word32 sequence_number_lo; #ifdef WOLFSSL_DTLS - DtlsSeq window; /* Sliding window for current epoch */ + word32 window[WOLFSSL_DTLS_WINDOW_WORDS]; + /* Sliding window for current epoch */ word16 nextEpoch; /* Expected epoch in next record */ word16 nextSeq_hi; /* Expected sequence in next record */ word32 nextSeq_lo; @@ -1626,7 +1628,8 @@ typedef struct Keys { word16 curSeq_hi; /* Received sequence in current record */ word32 curSeq_lo; - DtlsSeq prevWindow; /* Sliding window for old epoch */ + word32 prevWindow[WOLFSSL_DTLS_WINDOW_WORDS]; + /* Sliding window for old epoch */ word16 prevSeq_hi; /* Next sequence in allowed old epoch */ word32 prevSeq_lo;