Refactor DTLS Window Update (Fix #5211)

1. Rename _DtlsUpdateWindow() as wolfSSL_Dtls_UpdateWindow() and make
   it public so it may be tested.
2. Rename the internal functions DtlsWindowUpdate(), DtlsWindowCheck(),
   and DtlsUpdateWindowGTSeq() as _DtlsWindowUpdate() and
   _DtlsWindowCheck(), and _DtlsUpdateWindowGTSeq().
3. When updating the DTLS sequence window, and the next sequence
   number (lo) wraps to zero, increment the next sequence number (hi)
   by 1.
4. Fix an off-by-one error that wrapped around when saving the
   packet sequence number in the bit-field window.
5. Adding a test for wolfSSL_DtlsUpdateWindow() function. With many test
   cases. It is set up in a table format with running check values.
6. Change location of incrementing the difference when calculating the
   location for setting the bit.
7. Updated the check of the sequence difference in the GT scenario.
8. In the DTLS window update functions remove newDiff and just use diff.
9. Handle the cases where the DTLS window crosses the high order word
   sequence number change.
10. Add a debug option to print out the state of the DTLS sequence number
   window.
This commit is contained in:
John Safranek
2022-06-30 12:27:44 -07:00
parent 90c2f4ad00
commit 8f3449ffea
3 changed files with 177 additions and 33 deletions

View File

@ -193,8 +193,8 @@ WOLFSSL_CALLBACKS needs LARGE_STATIC_BUFFERS, please add LARGE_STATIC_BUFFERS
#ifdef WOLFSSL_DTLS
static WC_INLINE int DtlsCheckWindow(WOLFSSL* ssl);
static WC_INLINE int DtlsUpdateWindow(WOLFSSL* ssl);
static int _DtlsCheckWindow(WOLFSSL* ssl);
static int _DtlsUpdateWindow(WOLFSSL* ssl);
#endif
#ifdef WOLFSSL_DTLS13
@ -9878,7 +9878,7 @@ static int GetRecordHeader(WOLFSSL* ssl, const byte* input, word32* inOutIdx,
/* DTLSv1.3 MUST check window after deprotecting to avoid timing channel
(RFC9147 Section 4.5.1) */
if (IsDtlsNotSctpMode(ssl) && !IsAtLeastTLSv1_3(ssl->version)) {
if (!DtlsCheckWindow(ssl) ||
if (!_DtlsCheckWindow(ssl) ||
(rh->type == application_data && ssl->keys.curEpoch == 0) ||
(rh->type == alert && ssl->options.handShakeDone &&
ssl->keys.curEpoch == 0 && ssl->keys.dtls_epoch != 0)) {
@ -15189,7 +15189,7 @@ static int DoHandShakeMsg(WOLFSSL* ssl, byte* input, word32* inOutIdx,
#ifdef WOLFSSL_DTLS
static WC_INLINE int DtlsCheckWindow(WOLFSSL* ssl)
static int _DtlsCheckWindow(WOLFSSL* ssl)
{
word32* window;
word16 cur_hi, next_hi;
@ -15358,18 +15358,19 @@ static WC_INLINE word32 UpdateHighwaterMark(word32 cur, word32 first,
}
#endif /* WOLFSSL_MULTICAST */
/* diff must be already incremented by one */
static void DtlsUpdateWindowGTSeq(word32 diff, word32* window)
/* diff is the difference between the message sequence and the
* expected sequence number. 0 is special where it is an overflow. */
static void _DtlsUpdateWindowGTSeq(word32 diff, word32* window)
{
word32 idx, newDiff, temp, i;
word32 idx, temp, i;
word32 oldWindow[WOLFSSL_DTLS_WINDOW_WORDS];
if (diff >= DTLS_SEQ_BITS)
if (diff == 0 || diff >= DTLS_SEQ_BITS)
XMEMSET(window, 0, DTLS_SEQ_SZ);
else {
temp = 0;
idx = diff / DTLS_WORD_BITS;
newDiff = diff % DTLS_WORD_BITS;
diff %= DTLS_WORD_BITS;
XMEMCPY(oldWindow, window, sizeof(oldWindow));
@ -15377,52 +15378,97 @@ static void DtlsUpdateWindowGTSeq(word32 diff, word32* window)
if (i < idx)
window[i] = 0;
else {
temp |= (oldWindow[i-idx] << newDiff);
temp |= (oldWindow[i-idx] << diff);
window[i] = temp;
temp = oldWindow[i-idx] >> (DTLS_WORD_BITS - newDiff - 1);
temp = oldWindow[i-idx] >> (DTLS_WORD_BITS - diff);
}
}
}
window[0] |= 1;
}
static WC_INLINE int _DtlsUpdateWindow(WOLFSSL* ssl, word16* next_hi,
word32* next_lo, word32 *window)
int wolfSSL_DtlsUpdateWindow(word16 cur_hi, word32 cur_lo,
word16* next_hi, word32* next_lo, word32 *window)
{
word32 cur_lo, diff;
word32 diff;
int curLT;
word16 cur_hi;
cur_hi = ssl->keys.curSeq_hi;
cur_lo = ssl->keys.curSeq_lo;
if (cur_hi == *next_hi) {
curLT = cur_lo < *next_lo;
diff = curLT ? *next_lo - cur_lo - 1 : cur_lo - *next_lo + 1;
diff = curLT ? *next_lo - cur_lo : cur_lo - *next_lo;
}
else {
curLT = cur_hi < *next_hi;
diff = curLT ? cur_lo - *next_lo - 1 : *next_lo - cur_lo + 1;
if (cur_hi > *next_hi + 1) {
/* reset window */
_DtlsUpdateWindowGTSeq(0, window);
*next_lo = cur_lo + 1;
if (*next_lo == 0)
*next_hi = cur_hi + 1;
else
*next_hi = cur_hi;
return 1;
}
else if (*next_hi > cur_hi + 1) {
return 1;
}
else {
curLT = cur_hi < *next_hi;
if (curLT) {
if (cur_lo > (word32)(0 - DTLS_SEQ_BITS) &&
*next_lo < DTLS_SEQ_BITS) {
diff = *next_lo - cur_lo;
}
else {
_DtlsUpdateWindowGTSeq(0, window);
*next_lo = cur_lo + 1;
if (*next_lo == 0)
*next_hi = cur_hi + 1;
else
*next_hi = cur_hi;
return 1;
}
}
else {
if (*next_lo > (word32)(0 - DTLS_SEQ_BITS) &&
cur_lo < DTLS_SEQ_BITS) {
diff = cur_lo - *next_lo;
}
else {
_DtlsUpdateWindowGTSeq(0, window);
*next_lo = cur_lo + 1;
if (*next_lo == 0)
*next_hi = cur_hi + 1;
else
*next_hi = cur_hi;
return 1;
}
}
}
}
if (curLT) {
word32 idx = diff / DTLS_WORD_BITS;
word32 newDiff = diff % DTLS_WORD_BITS;
word32 idx;
diff--;
idx = diff / DTLS_WORD_BITS;
diff %= DTLS_WORD_BITS;
if (idx < WOLFSSL_DTLS_WINDOW_WORDS)
window[idx] |= (1 << newDiff);
window[idx] |= (1 << diff);
}
else {
DtlsUpdateWindowGTSeq(diff, window);
_DtlsUpdateWindowGTSeq(diff + 1, window);
*next_lo = cur_lo + 1;
if (*next_lo < cur_lo)
(*next_hi)++;
if (*next_lo == 0)
*next_hi = cur_hi + 1;
else
*next_hi = cur_hi;
}
return 1;
}
static WC_INLINE int DtlsUpdateWindow(WOLFSSL* ssl)
static int _DtlsUpdateWindow(WOLFSSL* ssl)
{
WOLFSSL_DTLS_PEERSEQ* peerSeq = ssl->keys.peerSeq;
word16 *next_hi;
@ -15483,7 +15529,8 @@ static WC_INLINE int DtlsUpdateWindow(WOLFSSL* ssl)
window = peerSeq->prevWindow;
}
return _DtlsUpdateWindow(ssl, next_hi, next_lo, window);
return wolfSSL_DtlsUpdateWindow(ssl->keys.curSeq_hi, ssl->keys.curSeq_lo,
next_hi, next_lo, window);
}
#ifdef WOLFSSL_DTLS13
@ -15531,7 +15578,7 @@ static WC_INLINE int Dtls13UpdateWindow(WOLFSSL* ssl)
/* as we are considering nextSeq inside the window, we should add + 1 */
w64Increment(&diff64);
DtlsUpdateWindowGTSeq(w64GetLow32(diff64), window);
_DtlsUpdateWindowGTSeq(w64GetLow32(diff64), window);
w64Increment(&seq);
ssl->dtls13DecryptEpoch->nextPeerSeqNumber = seq;
@ -18656,7 +18703,7 @@ int ProcessReplyEx(WOLFSSL* ssl, int allowSocketErr)
#ifdef WOLFSSL_DTLS
if (IsDtlsNotSctpMode(ssl) && !IsAtLeastTLSv1_3(ssl->version)) {
DtlsUpdateWindow(ssl);
_DtlsUpdateWindow(ssl);
}
#endif /* WOLFSSL_DTLS */

View File

@ -338,9 +338,10 @@
#if (defined(SESSION_CERTS) && defined(TEST_PEER_CERT_CHAIN)) || \
defined(HAVE_SESSION_TICKET) || (defined(OPENSSL_EXTRA) && \
defined(WOLFSSL_CERT_EXT) && defined(WOLFSSL_CERT_GEN)) || \
defined(WOLFSSL_TEST_STATIC_BUILD)
defined(WOLFSSL_TEST_STATIC_BUILD) || defined(WOLFSSL_DTLS)
/* for testing SSL_get_peer_cert_chain, or SESSION_TICKET_HINT_DEFAULT,
* or for setting authKeyIdSrc in WOLFSSL_X509 */
* for setting authKeyIdSrc in WOLFSSL_X509, or testing DTLS sequence
* number tracking */
#include "wolfssl/internal.h"
#endif
@ -55541,6 +55542,94 @@ static void test_wolfSSL_FIPS_mode(void)
#endif
}
#ifdef WOLFSSL_DTLS
/* Prints out the current window */
static void DUW_TEST_print_window_binary(word32 h, word32 l, word32* w) {
#ifdef WOLFSSL_DEBUG_DTLS_WINDOW
int i;
for (i = WOLFSSL_DTLS_WINDOW_WORDS - 1; i >= 0; i--) {
word32 b = w[i];
int j;
/* Prints out a 32 bit binary number in big endian order */
for (j = 0; j < 32; j++, b <<= 1) {
if (b & (((word32)1) << 31))
printf("1");
else
printf("0");
}
printf(" ");
}
printf("cur_hi %u cur_lo %u\n", h, l);
#else
(void)h;
(void)l;
(void)w;
#endif
}
/* a - cur_hi
* b - cur_lo
* c - next_hi
* d - next_lo
* e - window
* f - expected next_hi
* g - expected next_lo
* h - expected window[1]
* i - expected window[0]
*/
#define DUW_TEST(a,b,c,d,e,f,g,h,i) do { \
wolfSSL_DtlsUpdateWindow((a), (b), &(c), &(d), (e)); \
DUW_TEST_print_window_binary((a), (b), (e)); \
AssertIntEQ((c), (f)); \
AssertIntEQ((d), (g)); \
AssertIntEQ((e[1]), (h)); \
AssertIntEQ((e[0]), (i)); \
} while (0)
static void test_wolfSSL_DtlsUpdateWindow(void)
{
word32 window[WOLFSSL_DTLS_WINDOW_WORDS];
word32 next_lo = 0;
word16 next_hi = 0;
printf(testingFmt, "wolfSSL_DtlsUpdateWindow()");
#ifdef WOLFSSL_DEBUG_DTLS_WINDOW
printf("\n");
#endif
XMEMSET(window, 0, sizeof window);
DUW_TEST(0, 0, next_hi, next_lo, window, 0, 1, 0, 0x01);
DUW_TEST(0, 1, next_hi, next_lo, window, 0, 2, 0, 0x03);
DUW_TEST(0, 5, next_hi, next_lo, window, 0, 6, 0, 0x31);
DUW_TEST(0, 4, next_hi, next_lo, window, 0, 6, 0, 0x33);
DUW_TEST(0, 100, next_hi, next_lo, window, 0, 101, 0, 0x01);
DUW_TEST(0, 101, next_hi, next_lo, window, 0, 102, 0, 0x03);
DUW_TEST(0, 133, next_hi, next_lo, window, 0, 134, 0x03, 0x01);
DUW_TEST(0, 200, next_hi, next_lo, window, 0, 201, 0, 0x01);
DUW_TEST(0, 264, next_hi, next_lo, window, 0, 265, 0, 0x01);
DUW_TEST(0, 0xFFFFFFFF, next_hi, next_lo, window, 1, 0, 0, 0x01);
DUW_TEST(0, 0xFFFFFFFD, next_hi, next_lo, window, 1, 0, 0, 0x05);
DUW_TEST(0, 0xFFFFFFFE, next_hi, next_lo, window, 1, 0, 0, 0x07);
DUW_TEST(1, 3, next_hi, next_lo, window, 1, 4, 0, 0x71);
DUW_TEST(1, 0, next_hi, next_lo, window, 1, 4, 0, 0x79);
DUW_TEST(1, 0xFFFFFFFF, next_hi, next_lo, window, 2, 0, 0, 0x01);
DUW_TEST(2, 3, next_hi, next_lo, window, 2, 4, 0, 0x11);
DUW_TEST(2, 0, next_hi, next_lo, window, 2, 4, 0, 0x19);
DUW_TEST(2, 25, next_hi, next_lo, window, 2, 26, 0, 0x6400001);
DUW_TEST(2, 27, next_hi, next_lo, window, 2, 28, 0, 0x19000005);
DUW_TEST(2, 29, next_hi, next_lo, window, 2, 30, 0, 0x64000015);
DUW_TEST(2, 33, next_hi, next_lo, window, 2, 34, 6, 0x40000151);
DUW_TEST(2, 60, next_hi, next_lo, window, 2, 61, 0x3200000A, 0x88000001);
DUW_TEST(2, 0xFFFFFFFD, next_hi, next_lo, window, 2, 0xFFFFFFFE, 0, 0x01);
DUW_TEST(3, 1, next_hi, next_lo, window, 3, 2, 0, 0x11);
DUW_TEST(99, 66, next_hi, next_lo, window, 99, 67, 0, 0x01);
DUW_TEST(100, 68, next_hi, next_lo, window, 100, 69, 0, 0x01);
printf(resultFmt, passed);
}
#endif /* WOLFSSL_DTLS */
/*----------------------------------------------------------------------------*
| Main
*----------------------------------------------------------------------------*/
@ -56430,6 +56519,9 @@ void ApiTest(void)
test_wc_CryptoCb();
test_wolfSSL_CTX_StaticMemory();
test_wolfSSL_FIPS_mode();
#ifdef WOLFSSL_DTLS
test_wolfSSL_DtlsUpdateWindow();
#endif
AssertIntEQ(test_ForceZero(), 0);

View File

@ -5422,6 +5422,11 @@ WOLFSSL_LOCAL int oid2nid(word32 oid, int grp);
WOLFSSL_LOCAL word32 nid2oid(int nid, int grp);
#endif
#ifdef WOLFSSL_DTLS
WOLFSSL_API int wolfSSL_DtlsUpdateWindow(word16 cur_hi, word32 cur_lo,
word16* next_hi, word32* next_lo, word32 *window);
#endif
#ifdef WOLFSSL_DTLS13
WOLFSSL_LOCAL struct Dtls13Epoch* Dtls13GetEpoch(WOLFSSL* ssl,