diff --git a/src/internal.c b/src/internal.c index 07e1ea094..9f6693882 100644 --- a/src/internal.c +++ b/src/internal.c @@ -193,8 +193,8 @@ WOLFSSL_CALLBACKS needs LARGE_STATIC_BUFFERS, please add LARGE_STATIC_BUFFERS #ifdef WOLFSSL_DTLS - static WC_INLINE int DtlsCheckWindow(WOLFSSL* ssl); - static WC_INLINE int DtlsUpdateWindow(WOLFSSL* ssl); + static int _DtlsCheckWindow(WOLFSSL* ssl); + static int _DtlsUpdateWindow(WOLFSSL* ssl); #endif #ifdef WOLFSSL_DTLS13 @@ -9878,7 +9878,7 @@ static int GetRecordHeader(WOLFSSL* ssl, const byte* input, word32* inOutIdx, /* DTLSv1.3 MUST check window after deprotecting to avoid timing channel (RFC9147 Section 4.5.1) */ if (IsDtlsNotSctpMode(ssl) && !IsAtLeastTLSv1_3(ssl->version)) { - if (!DtlsCheckWindow(ssl) || + if (!_DtlsCheckWindow(ssl) || (rh->type == application_data && ssl->keys.curEpoch == 0) || (rh->type == alert && ssl->options.handShakeDone && ssl->keys.curEpoch == 0 && ssl->keys.dtls_epoch != 0)) { @@ -15189,7 +15189,7 @@ static int DoHandShakeMsg(WOLFSSL* ssl, byte* input, word32* inOutIdx, #ifdef WOLFSSL_DTLS -static WC_INLINE int DtlsCheckWindow(WOLFSSL* ssl) +static int _DtlsCheckWindow(WOLFSSL* ssl) { word32* window; word16 cur_hi, next_hi; @@ -15358,18 +15358,19 @@ static WC_INLINE word32 UpdateHighwaterMark(word32 cur, word32 first, } #endif /* WOLFSSL_MULTICAST */ -/* diff must be already incremented by one */ -static void DtlsUpdateWindowGTSeq(word32 diff, word32* window) +/* diff is the difference between the message sequence and the + * expected sequence number. 0 is special where it is an overflow. */ +static void _DtlsUpdateWindowGTSeq(word32 diff, word32* window) { - word32 idx, newDiff, temp, i; + word32 idx, temp, i; word32 oldWindow[WOLFSSL_DTLS_WINDOW_WORDS]; - if (diff >= DTLS_SEQ_BITS) + if (diff == 0 || diff >= DTLS_SEQ_BITS) XMEMSET(window, 0, DTLS_SEQ_SZ); else { temp = 0; idx = diff / DTLS_WORD_BITS; - newDiff = diff % DTLS_WORD_BITS; + diff %= DTLS_WORD_BITS; XMEMCPY(oldWindow, window, sizeof(oldWindow)); @@ -15377,52 +15378,97 @@ static void DtlsUpdateWindowGTSeq(word32 diff, word32* window) if (i < idx) window[i] = 0; else { - temp |= (oldWindow[i-idx] << newDiff); + temp |= (oldWindow[i-idx] << diff); window[i] = temp; - temp = oldWindow[i-idx] >> (DTLS_WORD_BITS - newDiff - 1); + temp = oldWindow[i-idx] >> (DTLS_WORD_BITS - diff); } } } window[0] |= 1; } -static WC_INLINE int _DtlsUpdateWindow(WOLFSSL* ssl, word16* next_hi, - word32* next_lo, word32 *window) +int wolfSSL_DtlsUpdateWindow(word16 cur_hi, word32 cur_lo, + word16* next_hi, word32* next_lo, word32 *window) { - word32 cur_lo, diff; + word32 diff; int curLT; - word16 cur_hi; - - cur_hi = ssl->keys.curSeq_hi; - cur_lo = ssl->keys.curSeq_lo; if (cur_hi == *next_hi) { curLT = cur_lo < *next_lo; - diff = curLT ? *next_lo - cur_lo - 1 : cur_lo - *next_lo + 1; + diff = curLT ? *next_lo - cur_lo : cur_lo - *next_lo; } else { - curLT = cur_hi < *next_hi; - diff = curLT ? cur_lo - *next_lo - 1 : *next_lo - cur_lo + 1; + if (cur_hi > *next_hi + 1) { + /* reset window */ + _DtlsUpdateWindowGTSeq(0, window); + *next_lo = cur_lo + 1; + if (*next_lo == 0) + *next_hi = cur_hi + 1; + else + *next_hi = cur_hi; + return 1; + } + else if (*next_hi > cur_hi + 1) { + return 1; + } + else { + curLT = cur_hi < *next_hi; + if (curLT) { + if (cur_lo > (word32)(0 - DTLS_SEQ_BITS) && + *next_lo < DTLS_SEQ_BITS) { + diff = *next_lo - cur_lo; + } + else { + _DtlsUpdateWindowGTSeq(0, window); + *next_lo = cur_lo + 1; + if (*next_lo == 0) + *next_hi = cur_hi + 1; + else + *next_hi = cur_hi; + return 1; + } + } + else { + if (*next_lo > (word32)(0 - DTLS_SEQ_BITS) && + cur_lo < DTLS_SEQ_BITS) { + diff = cur_lo - *next_lo; + } + else { + _DtlsUpdateWindowGTSeq(0, window); + *next_lo = cur_lo + 1; + if (*next_lo == 0) + *next_hi = cur_hi + 1; + else + *next_hi = cur_hi; + return 1; + } + } + } } if (curLT) { - word32 idx = diff / DTLS_WORD_BITS; - word32 newDiff = diff % DTLS_WORD_BITS; + word32 idx; + + diff--; + idx = diff / DTLS_WORD_BITS; + diff %= DTLS_WORD_BITS; if (idx < WOLFSSL_DTLS_WINDOW_WORDS) - window[idx] |= (1 << newDiff); + window[idx] |= (1 << diff); } else { - DtlsUpdateWindowGTSeq(diff, window); + _DtlsUpdateWindowGTSeq(diff + 1, window); *next_lo = cur_lo + 1; - if (*next_lo < cur_lo) - (*next_hi)++; + if (*next_lo == 0) + *next_hi = cur_hi + 1; + else + *next_hi = cur_hi; } return 1; } -static WC_INLINE int DtlsUpdateWindow(WOLFSSL* ssl) +static int _DtlsUpdateWindow(WOLFSSL* ssl) { WOLFSSL_DTLS_PEERSEQ* peerSeq = ssl->keys.peerSeq; word16 *next_hi; @@ -15483,7 +15529,8 @@ static WC_INLINE int DtlsUpdateWindow(WOLFSSL* ssl) window = peerSeq->prevWindow; } - return _DtlsUpdateWindow(ssl, next_hi, next_lo, window); + return wolfSSL_DtlsUpdateWindow(ssl->keys.curSeq_hi, ssl->keys.curSeq_lo, + next_hi, next_lo, window); } #ifdef WOLFSSL_DTLS13 @@ -15531,7 +15578,7 @@ static WC_INLINE int Dtls13UpdateWindow(WOLFSSL* ssl) /* as we are considering nextSeq inside the window, we should add + 1 */ w64Increment(&diff64); - DtlsUpdateWindowGTSeq(w64GetLow32(diff64), window); + _DtlsUpdateWindowGTSeq(w64GetLow32(diff64), window); w64Increment(&seq); ssl->dtls13DecryptEpoch->nextPeerSeqNumber = seq; @@ -18656,7 +18703,7 @@ int ProcessReplyEx(WOLFSSL* ssl, int allowSocketErr) #ifdef WOLFSSL_DTLS if (IsDtlsNotSctpMode(ssl) && !IsAtLeastTLSv1_3(ssl->version)) { - DtlsUpdateWindow(ssl); + _DtlsUpdateWindow(ssl); } #endif /* WOLFSSL_DTLS */ diff --git a/tests/api.c b/tests/api.c index 4fd108285..d54fd3cc4 100644 --- a/tests/api.c +++ b/tests/api.c @@ -338,9 +338,10 @@ #if (defined(SESSION_CERTS) && defined(TEST_PEER_CERT_CHAIN)) || \ defined(HAVE_SESSION_TICKET) || (defined(OPENSSL_EXTRA) && \ defined(WOLFSSL_CERT_EXT) && defined(WOLFSSL_CERT_GEN)) || \ - defined(WOLFSSL_TEST_STATIC_BUILD) + defined(WOLFSSL_TEST_STATIC_BUILD) || defined(WOLFSSL_DTLS) /* for testing SSL_get_peer_cert_chain, or SESSION_TICKET_HINT_DEFAULT, - * or for setting authKeyIdSrc in WOLFSSL_X509 */ + * for setting authKeyIdSrc in WOLFSSL_X509, or testing DTLS sequence + * number tracking */ #include "wolfssl/internal.h" #endif @@ -55541,6 +55542,94 @@ static void test_wolfSSL_FIPS_mode(void) #endif } +#ifdef WOLFSSL_DTLS + +/* Prints out the current window */ +static void DUW_TEST_print_window_binary(word32 h, word32 l, word32* w) { +#ifdef WOLFSSL_DEBUG_DTLS_WINDOW + int i; + for (i = WOLFSSL_DTLS_WINDOW_WORDS - 1; i >= 0; i--) { + word32 b = w[i]; + int j; + /* Prints out a 32 bit binary number in big endian order */ + for (j = 0; j < 32; j++, b <<= 1) { + if (b & (((word32)1) << 31)) + printf("1"); + else + printf("0"); + } + printf(" "); + } + printf("cur_hi %u cur_lo %u\n", h, l); +#else + (void)h; + (void)l; + (void)w; +#endif +} + +/* a - cur_hi + * b - cur_lo + * c - next_hi + * d - next_lo + * e - window + * f - expected next_hi + * g - expected next_lo + * h - expected window[1] + * i - expected window[0] + */ +#define DUW_TEST(a,b,c,d,e,f,g,h,i) do { \ + wolfSSL_DtlsUpdateWindow((a), (b), &(c), &(d), (e)); \ + DUW_TEST_print_window_binary((a), (b), (e)); \ + AssertIntEQ((c), (f)); \ + AssertIntEQ((d), (g)); \ + AssertIntEQ((e[1]), (h)); \ + AssertIntEQ((e[0]), (i)); \ +} while (0) + +static void test_wolfSSL_DtlsUpdateWindow(void) +{ + word32 window[WOLFSSL_DTLS_WINDOW_WORDS]; + word32 next_lo = 0; + word16 next_hi = 0; + + printf(testingFmt, "wolfSSL_DtlsUpdateWindow()"); +#ifdef WOLFSSL_DEBUG_DTLS_WINDOW + printf("\n"); +#endif + + XMEMSET(window, 0, sizeof window); + DUW_TEST(0, 0, next_hi, next_lo, window, 0, 1, 0, 0x01); + DUW_TEST(0, 1, next_hi, next_lo, window, 0, 2, 0, 0x03); + DUW_TEST(0, 5, next_hi, next_lo, window, 0, 6, 0, 0x31); + DUW_TEST(0, 4, next_hi, next_lo, window, 0, 6, 0, 0x33); + DUW_TEST(0, 100, next_hi, next_lo, window, 0, 101, 0, 0x01); + DUW_TEST(0, 101, next_hi, next_lo, window, 0, 102, 0, 0x03); + DUW_TEST(0, 133, next_hi, next_lo, window, 0, 134, 0x03, 0x01); + DUW_TEST(0, 200, next_hi, next_lo, window, 0, 201, 0, 0x01); + DUW_TEST(0, 264, next_hi, next_lo, window, 0, 265, 0, 0x01); + DUW_TEST(0, 0xFFFFFFFF, next_hi, next_lo, window, 1, 0, 0, 0x01); + DUW_TEST(0, 0xFFFFFFFD, next_hi, next_lo, window, 1, 0, 0, 0x05); + DUW_TEST(0, 0xFFFFFFFE, next_hi, next_lo, window, 1, 0, 0, 0x07); + DUW_TEST(1, 3, next_hi, next_lo, window, 1, 4, 0, 0x71); + DUW_TEST(1, 0, next_hi, next_lo, window, 1, 4, 0, 0x79); + DUW_TEST(1, 0xFFFFFFFF, next_hi, next_lo, window, 2, 0, 0, 0x01); + DUW_TEST(2, 3, next_hi, next_lo, window, 2, 4, 0, 0x11); + DUW_TEST(2, 0, next_hi, next_lo, window, 2, 4, 0, 0x19); + DUW_TEST(2, 25, next_hi, next_lo, window, 2, 26, 0, 0x6400001); + DUW_TEST(2, 27, next_hi, next_lo, window, 2, 28, 0, 0x19000005); + DUW_TEST(2, 29, next_hi, next_lo, window, 2, 30, 0, 0x64000015); + DUW_TEST(2, 33, next_hi, next_lo, window, 2, 34, 6, 0x40000151); + DUW_TEST(2, 60, next_hi, next_lo, window, 2, 61, 0x3200000A, 0x88000001); + DUW_TEST(2, 0xFFFFFFFD, next_hi, next_lo, window, 2, 0xFFFFFFFE, 0, 0x01); + DUW_TEST(3, 1, next_hi, next_lo, window, 3, 2, 0, 0x11); + DUW_TEST(99, 66, next_hi, next_lo, window, 99, 67, 0, 0x01); + DUW_TEST(100, 68, next_hi, next_lo, window, 100, 69, 0, 0x01); + + printf(resultFmt, passed); +} +#endif /* WOLFSSL_DTLS */ + /*----------------------------------------------------------------------------* | Main *----------------------------------------------------------------------------*/ @@ -56430,6 +56519,9 @@ void ApiTest(void) test_wc_CryptoCb(); test_wolfSSL_CTX_StaticMemory(); test_wolfSSL_FIPS_mode(); +#ifdef WOLFSSL_DTLS + test_wolfSSL_DtlsUpdateWindow(); +#endif AssertIntEQ(test_ForceZero(), 0); diff --git a/wolfssl/internal.h b/wolfssl/internal.h index 34be9e14e..d2833f72d 100644 --- a/wolfssl/internal.h +++ b/wolfssl/internal.h @@ -5422,6 +5422,11 @@ WOLFSSL_LOCAL int oid2nid(word32 oid, int grp); WOLFSSL_LOCAL word32 nid2oid(int nid, int grp); #endif +#ifdef WOLFSSL_DTLS +WOLFSSL_API int wolfSSL_DtlsUpdateWindow(word16 cur_hi, word32 cur_lo, + word16* next_hi, word32* next_lo, word32 *window); +#endif + #ifdef WOLFSSL_DTLS13 WOLFSSL_LOCAL struct Dtls13Epoch* Dtls13GetEpoch(WOLFSSL* ssl,