diff --git a/src/dtls.c b/src/dtls.c index 9475bd2ed..8b5234fee 100644 --- a/src/dtls.c +++ b/src/dtls.c @@ -972,8 +972,16 @@ int DoClientHelloStateless(WOLFSSL* ssl, const byte* input, word32 helloSz, #endif ret = SendStatelessReply((WOLFSSL*)ssl, &ch, isTls13); } - else + else { ssl->options.dtlsStateful = 1; + /* Update the window now that we enter the stateful parsing */ +#ifdef WOLFSSL_DTLS13 + if (isTls13) + ret = Dtls13UpdateWindowRecordRecvd(ssl); + else +#endif + DtlsUpdateWindow(ssl); + } } return ret; diff --git a/src/internal.c b/src/internal.c index 138d65667..f530f9683 100644 --- a/src/internal.c +++ b/src/internal.c @@ -215,7 +215,6 @@ WOLFSSL_CALLBACKS needs LARGE_STATIC_BUFFERS, please add LARGE_STATIC_BUFFERS #ifdef WOLFSSL_DTLS static int _DtlsCheckWindow(WOLFSSL* ssl); - static int _DtlsUpdateWindow(WOLFSSL* ssl); #endif #ifdef WOLFSSL_DTLS13 @@ -16975,7 +16974,7 @@ int wolfSSL_DtlsUpdateWindow(word16 cur_hi, word32 cur_lo, return 1; } -static int _DtlsUpdateWindow(WOLFSSL* ssl) +int DtlsUpdateWindow(WOLFSSL* ssl) { WOLFSSL_DTLS_PEERSEQ* peerSeq = ssl->keys.peerSeq; word16 *next_hi; @@ -17040,20 +17039,6 @@ static int _DtlsUpdateWindow(WOLFSSL* ssl) next_hi, next_lo, window); } -static WC_INLINE int DtlsShouldUpdateWindow(int ret) -{ - switch (ret) { - case 0: -#ifdef WOLFSSL_ASYNC_CRYPT - case WC_PENDING_E: -#endif - case APP_DATA_READY: - return 1; - default: - return 0; - } -} - #ifdef WOLFSSL_DTLS13 static int Dtls13UpdateWindow(WOLFSSL* ssl) @@ -17120,7 +17105,7 @@ static int Dtls13UpdateWindow(WOLFSSL* ssl) return 0; } -static WC_INLINE int Dtls13UpdateWindowRecordRecvd(WOLFSSL* ssl) +int Dtls13UpdateWindowRecordRecvd(WOLFSSL* ssl) { int ret = Dtls13UpdateWindow(ssl); if (ret != 0) @@ -20751,17 +20736,33 @@ default: /* the record layer is here */ case runProcessingOneRecord: #ifdef WOLFSSL_DTLS13 - if (ssl->options.dtls && IsAtLeastTLSv1_3(ssl->version) && - !Dtls13CheckWindow(ssl)) { - /* drop packet */ - WOLFSSL_MSG("Dropping DTLS record outside receiving window"); - ssl->options.processReply = doProcessInit; - ssl->buffers.inputBuffer.idx += ssl->curSize; - if (ssl->buffers.inputBuffer.idx > - ssl->buffers.inputBuffer.length) - return BUFFER_E; + if (ssl->options.dtls) { + if (IsAtLeastTLSv1_3(ssl->version)) { + if (!Dtls13CheckWindow(ssl)) { + /* drop packet */ + WOLFSSL_MSG("Dropping DTLS record outside receiving " + "window"); + ssl->options.processReply = doProcessInit; + ssl->buffers.inputBuffer.idx += ssl->curSize; + if (ssl->buffers.inputBuffer.idx > + ssl->buffers.inputBuffer.length) + return BUFFER_E; - continue; + continue; + } + + /* Only update the window once we enter stateful parsing */ + if (ssl->options.dtlsStateful) { + ret = Dtls13UpdateWindowRecordRecvd(ssl); + if (ret != 0) { + WOLFSSL_ERROR(ret); + return ret; + } + } + } + else if (IsDtlsNotSctpMode(ssl)) { + DtlsUpdateWindow(ssl); + } } #endif /* WOLFSSL_DTLS13 */ ssl->options.processReply = runProcessingOneMessage; @@ -20828,15 +20829,12 @@ default: ssl->buffers.inputBuffer.buffer, &ssl->buffers.inputBuffer.idx, ssl->buffers.inputBuffer.length); - if (DtlsShouldUpdateWindow(ret) && - ssl->options.dtlsStateful) { - if (IsDtlsNotSctpMode(ssl)) - _DtlsUpdateWindow(ssl); + if (ret == 0 || ret == WC_PENDING_E) { /* Reset timeout as we have received a valid * DTLS handshake message */ ssl->dtls_timeout = ssl->dtls_timeout_init; } - if (ret != 0) { + else { if (SendFatalAlertOnly(ssl, ret) == SOCKET_ERROR_E) { ret = SOCKET_ERROR_E; @@ -20850,15 +20848,6 @@ default: ssl->buffers.inputBuffer.buffer, &ssl->buffers.inputBuffer.idx, ssl->buffers.inputBuffer.length); - if (DtlsShouldUpdateWindow(ret) && - ssl->options.dtlsStateful) { - int updateRet = - Dtls13UpdateWindowRecordRecvd(ssl); - if (updateRet != 0) { - WOLFSSL_ERROR(updateRet); - return updateRet; - } - } #ifdef WOLFSSL_EARLY_DATA if (ret == 0 && ssl->options.side == WOLFSSL_SERVER_END && @@ -20979,15 +20968,6 @@ default: WOLFSSL_ERROR_VERBOSE(UNKNOWN_RECORD_TYPE); return UNKNOWN_RECORD_TYPE; } -#ifdef WOLFSSL_DTLS13 - if (ssl->options.dtls) { - ret = Dtls13UpdateWindowRecordRecvd(ssl); - if (ret != 0) { - WOLFSSL_ERROR(ret); - return ret; - } - } -#endif break; } #endif @@ -21075,8 +21055,6 @@ default: #ifdef WOLFSSL_DTLS if (ssl->options.dtls) { WOLFSSL_DTLS_PEERSEQ* peerSeq = ssl->keys.peerSeq; - if (IsDtlsNotSctpMode(ssl)) - _DtlsUpdateWindow(ssl); #ifdef WOLFSSL_MULTICAST if (ssl->options.haveMcast) { peerSeq += ssl->keys.curPeerId; @@ -21138,26 +21116,10 @@ default: return SANITY_MSG_E; } #endif - ret = DoApplicationData(ssl, - ssl->buffers.inputBuffer.buffer, - &ssl->buffers.inputBuffer.idx, NO_SNIFF); -#ifdef WOLFSSL_DTLS - if (ssl->options.dtls && DtlsShouldUpdateWindow(ret)) { -#ifdef WOLFSSL_DTLS13 - if (IsAtLeastTLSv1_3(ssl->version)) { - int updateRet = Dtls13UpdateWindowRecordRecvd(ssl); - if (updateRet != 0) { - WOLFSSL_ERROR(updateRet); - return updateRet; - } - } - else -#endif - if (IsDtlsNotSctpMode(ssl)) - _DtlsUpdateWindow(ssl); - } -#endif - if (ret != 0) { + if ((ret = DoApplicationData(ssl, + ssl->buffers.inputBuffer.buffer, + &ssl->buffers.inputBuffer.idx, + NO_SNIFF)) != 0) { WOLFSSL_ERROR(ret); return ret; } @@ -21186,22 +21148,6 @@ default: /* Reset error if we got an alert level in ret */ if (ret > 0) ret = 0; -#ifdef WOLFSSL_DTLS - if (ssl->options.dtls) { -#ifdef WOLFSSL_DTLS13 - if (IsAtLeastTLSv1_3(ssl->version)) { - ret = Dtls13UpdateWindowRecordRecvd(ssl); - if (ret != 0) { - WOLFSSL_ERROR(ret); - return ret; - } - } - else -#endif - if (IsDtlsNotSctpMode(ssl)) - _DtlsUpdateWindow(ssl); - } -#endif break; #ifdef WOLFSSL_DTLS13 @@ -21216,13 +21162,6 @@ default: ssl->keys.padSz, &processedSize); ssl->buffers.inputBuffer.idx += processedSize; ssl->buffers.inputBuffer.idx += ssl->keys.padSz; - if (DtlsShouldUpdateWindow(ret)) { - int updateRet = Dtls13UpdateWindowRecordRecvd(ssl); - if (updateRet != 0) { - WOLFSSL_ERROR(updateRet); - return updateRet; - } - } if (ret != 0) return ret; break; diff --git a/wolfssl/internal.h b/wolfssl/internal.h index ac7ca5a82..10c395551 100644 --- a/wolfssl/internal.h +++ b/wolfssl/internal.h @@ -6479,6 +6479,7 @@ WOLFSSL_LOCAL word32 nid2oid(int nid, int grp); #ifdef WOLFSSL_DTLS WOLFSSL_API int wolfSSL_DtlsUpdateWindow(word16 cur_hi, word32 cur_lo, word16* next_hi, word32* next_lo, word32 *window); +WOLFSSL_LOCAL int DtlsUpdateWindow(WOLFSSL* ssl); WOLFSSL_LOCAL void DtlsResetState(WOLFSSL *ssl); WOLFSSL_LOCAL int DtlsIgnoreError(int err); WOLFSSL_LOCAL void DtlsSetSeqNumForReply(WOLFSSL* ssl); @@ -6547,6 +6548,7 @@ WOLFSSL_LOCAL void Dtls13RtxFlushBuffered(WOLFSSL* ssl, WOLFSSL_LOCAL int Dtls13RtxTimeout(WOLFSSL* ssl); WOLFSSL_LOCAL int Dtls13ProcessBufferedMessages(WOLFSSL* ssl); WOLFSSL_LOCAL int Dtls13CheckAEADFailLimit(WOLFSSL* ssl); +WOLFSSL_LOCAL int Dtls13UpdateWindowRecordRecvd(WOLFSSL* ssl); #endif /* WOLFSSL_DTLS13 */ #ifdef WOLFSSL_STATIC_EPHEMERAL