Merge pull request #5278 from ejohnstown/dtls-seq

Refactor DTLS Window Update (Fix #5211)
This commit is contained in:
David Garske
2022-07-07 10:22:21 -07:00
committed by GitHub
3 changed files with 184 additions and 33 deletions

View File

@@ -193,8 +193,8 @@ WOLFSSL_CALLBACKS needs LARGE_STATIC_BUFFERS, please add LARGE_STATIC_BUFFERS
#ifdef WOLFSSL_DTLS
static WC_INLINE int DtlsCheckWindow(WOLFSSL* ssl);
static WC_INLINE int DtlsUpdateWindow(WOLFSSL* ssl);
static int _DtlsCheckWindow(WOLFSSL* ssl);
static int _DtlsUpdateWindow(WOLFSSL* ssl);
#endif
#ifdef WOLFSSL_DTLS13
@@ -9878,7 +9878,7 @@ static int GetRecordHeader(WOLFSSL* ssl, const byte* input, word32* inOutIdx,
/* DTLSv1.3 MUST check window after deprotecting to avoid timing channel
(RFC9147 Section 4.5.1) */
if (IsDtlsNotSctpMode(ssl) && !IsAtLeastTLSv1_3(ssl->version)) {
if (!DtlsCheckWindow(ssl) ||
if (!_DtlsCheckWindow(ssl) ||
(rh->type == application_data && ssl->keys.curEpoch == 0) ||
(rh->type == alert && ssl->options.handShakeDone &&
ssl->keys.curEpoch == 0 && ssl->keys.dtls_epoch != 0)) {
@@ -15189,7 +15189,7 @@ static int DoHandShakeMsg(WOLFSSL* ssl, byte* input, word32* inOutIdx,
#ifdef WOLFSSL_DTLS
static WC_INLINE int DtlsCheckWindow(WOLFSSL* ssl)
static int _DtlsCheckWindow(WOLFSSL* ssl)
{
word32* window;
word16 cur_hi, next_hi;
@@ -15358,18 +15358,19 @@ static WC_INLINE word32 UpdateHighwaterMark(word32 cur, word32 first,
}
#endif /* WOLFSSL_MULTICAST */
/* diff must be already incremented by one */
static void DtlsUpdateWindowGTSeq(word32 diff, word32* window)
/* diff is the difference between the message sequence and the
* expected sequence number. 0 is special where it is an overflow. */
static void _DtlsUpdateWindowGTSeq(word32 diff, word32* window)
{
word32 idx, newDiff, temp, i;
word32 idx, temp, i;
word32 oldWindow[WOLFSSL_DTLS_WINDOW_WORDS];
if (diff >= DTLS_SEQ_BITS)
if (diff == 0 || diff >= DTLS_SEQ_BITS)
XMEMSET(window, 0, DTLS_SEQ_SZ);
else {
temp = 0;
idx = diff / DTLS_WORD_BITS;
newDiff = diff % DTLS_WORD_BITS;
diff %= DTLS_WORD_BITS;
XMEMCPY(oldWindow, window, sizeof(oldWindow));
@@ -15377,52 +15378,98 @@ static void DtlsUpdateWindowGTSeq(word32 diff, word32* window)
if (i < idx)
window[i] = 0;
else {
temp |= (oldWindow[i-idx] << newDiff);
temp |= (oldWindow[i-idx] << diff);
window[i] = temp;
temp = oldWindow[i-idx] >> (DTLS_WORD_BITS - newDiff - 1);
temp = oldWindow[i-idx] >> (DTLS_WORD_BITS - diff);
}
}
}
window[0] |= 1;
}
static WC_INLINE int _DtlsUpdateWindow(WOLFSSL* ssl, word16* next_hi,
word32* next_lo, word32 *window)
int wolfSSL_DtlsUpdateWindow(word16 cur_hi, word32 cur_lo,
word16* next_hi, word32* next_lo, word32 *window)
{
word32 cur_lo, diff;
word32 diff;
int curLT;
word16 cur_hi;
cur_hi = ssl->keys.curSeq_hi;
cur_lo = ssl->keys.curSeq_lo;
if (cur_hi == *next_hi) {
curLT = cur_lo < *next_lo;
diff = curLT ? *next_lo - cur_lo - 1 : cur_lo - *next_lo + 1;
diff = curLT ? *next_lo - cur_lo : cur_lo - *next_lo;
}
else {
curLT = cur_hi < *next_hi;
diff = curLT ? cur_lo - *next_lo - 1 : *next_lo - cur_lo + 1;
if (cur_hi > *next_hi + 1) {
/* reset window */
_DtlsUpdateWindowGTSeq(0, window);
*next_lo = cur_lo + 1;
if (*next_lo == 0)
*next_hi = cur_hi + 1;
else
*next_hi = cur_hi;
return 1;
}
else if (*next_hi > cur_hi + 1) {
return 1;
}
else {
curLT = cur_hi < *next_hi;
if (curLT) {
if (*next_lo < DTLS_SEQ_BITS &&
cur_lo >= (((word32)0xFFFFFFFF) - DTLS_SEQ_BITS)) {
/* diff here can still result in a difference that can not
* be stored in the window. The index is checked against
* WOLFSSL_DTLS_WINDOW_WORDS later. */
diff = *next_lo + ((word32)0xFFFFFFFF - cur_lo) + 1;
}
else {
/* Too far back to update */
return 1;
}
}
else {
if (*next_lo >= (((word32)0xFFFFFFFF) - DTLS_SEQ_BITS) &&
cur_lo < DTLS_SEQ_BITS) {
/* diff here can still result in a difference that can not
* be stored in the window. The index is checked against
* WOLFSSL_DTLS_WINDOW_WORDS later. */
diff = cur_lo - *next_lo;
}
else {
_DtlsUpdateWindowGTSeq(0, window);
*next_lo = cur_lo + 1;
if (*next_lo == 0)
*next_hi = cur_hi + 1;
else
*next_hi = cur_hi;
return 1;
}
}
}
}
if (curLT) {
word32 idx = diff / DTLS_WORD_BITS;
word32 newDiff = diff % DTLS_WORD_BITS;
word32 idx;
diff--;
idx = diff / DTLS_WORD_BITS;
diff %= DTLS_WORD_BITS;
if (idx < WOLFSSL_DTLS_WINDOW_WORDS)
window[idx] |= (1 << newDiff);
window[idx] |= (1 << diff);
}
else {
DtlsUpdateWindowGTSeq(diff, window);
_DtlsUpdateWindowGTSeq(diff + 1, window);
*next_lo = cur_lo + 1;
if (*next_lo < cur_lo)
(*next_hi)++;
if (*next_lo == 0)
*next_hi = cur_hi + 1;
else
*next_hi = cur_hi;
}
return 1;
}
static WC_INLINE int DtlsUpdateWindow(WOLFSSL* ssl)
static int _DtlsUpdateWindow(WOLFSSL* ssl)
{
WOLFSSL_DTLS_PEERSEQ* peerSeq = ssl->keys.peerSeq;
word16 *next_hi;
@@ -15483,7 +15530,8 @@ static WC_INLINE int DtlsUpdateWindow(WOLFSSL* ssl)
window = peerSeq->prevWindow;
}
return _DtlsUpdateWindow(ssl, next_hi, next_lo, window);
return wolfSSL_DtlsUpdateWindow(ssl->keys.curSeq_hi, ssl->keys.curSeq_lo,
next_hi, next_lo, window);
}
#ifdef WOLFSSL_DTLS13
@@ -15531,7 +15579,7 @@ static WC_INLINE int Dtls13UpdateWindow(WOLFSSL* ssl)
/* as we are considering nextSeq inside the window, we should add + 1 */
w64Increment(&diff64);
DtlsUpdateWindowGTSeq(w64GetLow32(diff64), window);
_DtlsUpdateWindowGTSeq(w64GetLow32(diff64), window);
w64Increment(&seq);
ssl->dtls13DecryptEpoch->nextPeerSeqNumber = seq;
@@ -18656,7 +18704,7 @@ int ProcessReplyEx(WOLFSSL* ssl, int allowSocketErr)
#ifdef WOLFSSL_DTLS
if (IsDtlsNotSctpMode(ssl) && !IsAtLeastTLSv1_3(ssl->version)) {
DtlsUpdateWindow(ssl);
_DtlsUpdateWindow(ssl);
}
#endif /* WOLFSSL_DTLS */

View File

@@ -338,9 +338,10 @@
#if (defined(SESSION_CERTS) && defined(TEST_PEER_CERT_CHAIN)) || \
defined(HAVE_SESSION_TICKET) || (defined(OPENSSL_EXTRA) && \
defined(WOLFSSL_CERT_EXT) && defined(WOLFSSL_CERT_GEN)) || \
defined(WOLFSSL_TEST_STATIC_BUILD)
defined(WOLFSSL_TEST_STATIC_BUILD) || defined(WOLFSSL_DTLS)
/* for testing SSL_get_peer_cert_chain, or SESSION_TICKET_HINT_DEFAULT,
* or for setting authKeyIdSrc in WOLFSSL_X509 */
* for setting authKeyIdSrc in WOLFSSL_X509, or testing DTLS sequence
* number tracking */
#include "wolfssl/internal.h"
#endif
@@ -55546,6 +55547,100 @@ static void test_wolfSSL_FIPS_mode(void)
#endif
}
#ifdef WOLFSSL_DTLS
/* Prints out the current window */
static void DUW_TEST_print_window_binary(word32 h, word32 l, word32* w) {
#ifdef WOLFSSL_DEBUG_DTLS_WINDOW
int i;
for (i = WOLFSSL_DTLS_WINDOW_WORDS - 1; i >= 0; i--) {
word32 b = w[i];
int j;
/* Prints out a 32 bit binary number in big endian order */
for (j = 0; j < 32; j++, b <<= 1) {
if (b & (((word32)1) << 31))
printf("1");
else
printf("0");
}
printf(" ");
}
printf("cur_hi %u cur_lo %u\n", h, l);
#else
(void)h;
(void)l;
(void)w;
#endif
}
/* a - cur_hi
* b - cur_lo
* c - next_hi
* d - next_lo
* e - window
* f - expected next_hi
* g - expected next_lo
* h - expected window[1]
* i - expected window[0]
*/
#define DUW_TEST(a,b,c,d,e,f,g,h,i) do { \
wolfSSL_DtlsUpdateWindow((a), (b), &(c), &(d), (e)); \
DUW_TEST_print_window_binary((a), (b), (e)); \
AssertIntEQ((c), (f)); \
AssertIntEQ((d), (g)); \
AssertIntEQ((e[1]), (h)); \
AssertIntEQ((e[0]), (i)); \
} while (0)
static void test_wolfSSL_DtlsUpdateWindow(void)
{
word32 window[WOLFSSL_DTLS_WINDOW_WORDS];
word32 next_lo = 0;
word16 next_hi = 0;
printf(testingFmt, "wolfSSL_DtlsUpdateWindow()");
#ifdef WOLFSSL_DEBUG_DTLS_WINDOW
printf("\n");
#endif
XMEMSET(window, 0, sizeof window);
DUW_TEST(0, 0, next_hi, next_lo, window, 0, 1, 0, 0x01);
DUW_TEST(0, 1, next_hi, next_lo, window, 0, 2, 0, 0x03);
DUW_TEST(0, 5, next_hi, next_lo, window, 0, 6, 0, 0x31);
DUW_TEST(0, 4, next_hi, next_lo, window, 0, 6, 0, 0x33);
DUW_TEST(0, 100, next_hi, next_lo, window, 0, 101, 0, 0x01);
DUW_TEST(0, 101, next_hi, next_lo, window, 0, 102, 0, 0x03);
DUW_TEST(0, 133, next_hi, next_lo, window, 0, 134, 0x03, 0x01);
DUW_TEST(0, 200, next_hi, next_lo, window, 0, 201, 0, 0x01);
DUW_TEST(0, 264, next_hi, next_lo, window, 0, 265, 0, 0x01);
DUW_TEST(0, 0xFFFFFFFF, next_hi, next_lo, window, 1, 0, 0, 0x01);
DUW_TEST(0, 0xFFFFFFFD, next_hi, next_lo, window, 1, 0, 0, 0x05);
DUW_TEST(0, 0xFFFFFFFE, next_hi, next_lo, window, 1, 0, 0, 0x07);
DUW_TEST(1, 3, next_hi, next_lo, window, 1, 4, 0, 0x71);
DUW_TEST(1, 0, next_hi, next_lo, window, 1, 4, 0, 0x79);
DUW_TEST(1, 0xFFFFFFFF, next_hi, next_lo, window, 2, 0, 0, 0x01);
DUW_TEST(2, 3, next_hi, next_lo, window, 2, 4, 0, 0x11);
DUW_TEST(2, 0, next_hi, next_lo, window, 2, 4, 0, 0x19);
DUW_TEST(2, 25, next_hi, next_lo, window, 2, 26, 0, 0x6400001);
DUW_TEST(2, 27, next_hi, next_lo, window, 2, 28, 0, 0x19000005);
DUW_TEST(2, 29, next_hi, next_lo, window, 2, 30, 0, 0x64000015);
DUW_TEST(2, 33, next_hi, next_lo, window, 2, 34, 6, 0x40000151);
DUW_TEST(2, 60, next_hi, next_lo, window, 2, 61, 0x3200000A, 0x88000001);
DUW_TEST(1, 0xFFFFFFF0, next_hi, next_lo, window, 2, 61, 0x3200000A, 0x88000001);
DUW_TEST(2, 0xFFFFFFFD, next_hi, next_lo, window, 2, 0xFFFFFFFE, 0, 0x01);
DUW_TEST(3, 1, next_hi, next_lo, window, 3, 2, 0, 0x11);
DUW_TEST(99, 66, next_hi, next_lo, window, 99, 67, 0, 0x01);
DUW_TEST(50, 66, next_hi, next_lo, window, 99, 67, 0, 0x01);
DUW_TEST(100, 68, next_hi, next_lo, window, 100, 69, 0, 0x01);
DUW_TEST(99, 50, next_hi, next_lo, window, 100, 69, 0, 0x01);
DUW_TEST(99, 0xFFFFFFFF, next_hi, next_lo, window, 100, 69, 0, 0x01);
DUW_TEST(150, 0xFFFFFFFF, next_hi, next_lo, window, 151, 0, 0, 0x01);
DUW_TEST(152, 0xFFFFFFFF, next_hi, next_lo, window, 153, 0, 0, 0x01);
printf(resultFmt, passed);
}
#endif /* WOLFSSL_DTLS */
/*----------------------------------------------------------------------------*
| Main
*----------------------------------------------------------------------------*/
@@ -56435,6 +56530,9 @@ void ApiTest(void)
test_wc_CryptoCb();
test_wolfSSL_CTX_StaticMemory();
test_wolfSSL_FIPS_mode();
#ifdef WOLFSSL_DTLS
test_wolfSSL_DtlsUpdateWindow();
#endif
AssertIntEQ(test_ForceZero(), 0);

View File

@@ -5422,6 +5422,11 @@ WOLFSSL_LOCAL int oid2nid(word32 oid, int grp);
WOLFSSL_LOCAL word32 nid2oid(int nid, int grp);
#endif
#ifdef WOLFSSL_DTLS
WOLFSSL_API int wolfSSL_DtlsUpdateWindow(word16 cur_hi, word32 cur_lo,
word16* next_hi, word32* next_lo, word32 *window);
#endif
#ifdef WOLFSSL_DTLS13
WOLFSSL_LOCAL struct Dtls13Epoch* Dtls13GetEpoch(WOLFSSL* ssl,