Update window in one place only when stateful

This commit is contained in:
Juliusz Sosinowicz
2023-10-03 16:46:32 +02:00
parent 8ac72750bc
commit 275c0a0838
3 changed files with 45 additions and 96 deletions

View File

@@ -972,8 +972,16 @@ int DoClientHelloStateless(WOLFSSL* ssl, const byte* input, word32 helloSz,
#endif #endif
ret = SendStatelessReply((WOLFSSL*)ssl, &ch, isTls13); ret = SendStatelessReply((WOLFSSL*)ssl, &ch, isTls13);
} }
else else {
ssl->options.dtlsStateful = 1; ssl->options.dtlsStateful = 1;
/* Update the window now that we enter the stateful parsing */
#ifdef WOLFSSL_DTLS13
if (isTls13)
ret = Dtls13UpdateWindowRecordRecvd(ssl);
else
#endif
DtlsUpdateWindow(ssl);
}
} }
return ret; return ret;

View File

@@ -215,7 +215,6 @@ WOLFSSL_CALLBACKS needs LARGE_STATIC_BUFFERS, please add LARGE_STATIC_BUFFERS
#ifdef WOLFSSL_DTLS #ifdef WOLFSSL_DTLS
static int _DtlsCheckWindow(WOLFSSL* ssl); static int _DtlsCheckWindow(WOLFSSL* ssl);
static int _DtlsUpdateWindow(WOLFSSL* ssl);
#endif #endif
#ifdef WOLFSSL_DTLS13 #ifdef WOLFSSL_DTLS13
@@ -16975,7 +16974,7 @@ int wolfSSL_DtlsUpdateWindow(word16 cur_hi, word32 cur_lo,
return 1; return 1;
} }
static int _DtlsUpdateWindow(WOLFSSL* ssl) int DtlsUpdateWindow(WOLFSSL* ssl)
{ {
WOLFSSL_DTLS_PEERSEQ* peerSeq = ssl->keys.peerSeq; WOLFSSL_DTLS_PEERSEQ* peerSeq = ssl->keys.peerSeq;
word16 *next_hi; word16 *next_hi;
@@ -17040,20 +17039,6 @@ static int _DtlsUpdateWindow(WOLFSSL* ssl)
next_hi, next_lo, window); next_hi, next_lo, window);
} }
static WC_INLINE int DtlsShouldUpdateWindow(int ret)
{
switch (ret) {
case 0:
#ifdef WOLFSSL_ASYNC_CRYPT
case WC_PENDING_E:
#endif
case APP_DATA_READY:
return 1;
default:
return 0;
}
}
#ifdef WOLFSSL_DTLS13 #ifdef WOLFSSL_DTLS13
static int Dtls13UpdateWindow(WOLFSSL* ssl) static int Dtls13UpdateWindow(WOLFSSL* ssl)
@@ -17120,7 +17105,7 @@ static int Dtls13UpdateWindow(WOLFSSL* ssl)
return 0; return 0;
} }
static WC_INLINE int Dtls13UpdateWindowRecordRecvd(WOLFSSL* ssl) int Dtls13UpdateWindowRecordRecvd(WOLFSSL* ssl)
{ {
int ret = Dtls13UpdateWindow(ssl); int ret = Dtls13UpdateWindow(ssl);
if (ret != 0) if (ret != 0)
@@ -20751,10 +20736,12 @@ default:
/* the record layer is here */ /* the record layer is here */
case runProcessingOneRecord: case runProcessingOneRecord:
#ifdef WOLFSSL_DTLS13 #ifdef WOLFSSL_DTLS13
if (ssl->options.dtls && IsAtLeastTLSv1_3(ssl->version) && if (ssl->options.dtls) {
!Dtls13CheckWindow(ssl)) { if (IsAtLeastTLSv1_3(ssl->version)) {
if (!Dtls13CheckWindow(ssl)) {
/* drop packet */ /* drop packet */
WOLFSSL_MSG("Dropping DTLS record outside receiving window"); WOLFSSL_MSG("Dropping DTLS record outside receiving "
"window");
ssl->options.processReply = doProcessInit; ssl->options.processReply = doProcessInit;
ssl->buffers.inputBuffer.idx += ssl->curSize; ssl->buffers.inputBuffer.idx += ssl->curSize;
if (ssl->buffers.inputBuffer.idx > if (ssl->buffers.inputBuffer.idx >
@@ -20763,6 +20750,20 @@ default:
continue; continue;
} }
/* Only update the window once we enter stateful parsing */
if (ssl->options.dtlsStateful) {
ret = Dtls13UpdateWindowRecordRecvd(ssl);
if (ret != 0) {
WOLFSSL_ERROR(ret);
return ret;
}
}
}
else if (IsDtlsNotSctpMode(ssl)) {
DtlsUpdateWindow(ssl);
}
}
#endif /* WOLFSSL_DTLS13 */ #endif /* WOLFSSL_DTLS13 */
ssl->options.processReply = runProcessingOneMessage; ssl->options.processReply = runProcessingOneMessage;
FALL_THROUGH; FALL_THROUGH;
@@ -20828,15 +20829,12 @@ default:
ssl->buffers.inputBuffer.buffer, ssl->buffers.inputBuffer.buffer,
&ssl->buffers.inputBuffer.idx, &ssl->buffers.inputBuffer.idx,
ssl->buffers.inputBuffer.length); ssl->buffers.inputBuffer.length);
if (DtlsShouldUpdateWindow(ret) && if (ret == 0 || ret == WC_PENDING_E) {
ssl->options.dtlsStateful) {
if (IsDtlsNotSctpMode(ssl))
_DtlsUpdateWindow(ssl);
/* Reset timeout as we have received a valid /* Reset timeout as we have received a valid
* DTLS handshake message */ * DTLS handshake message */
ssl->dtls_timeout = ssl->dtls_timeout_init; ssl->dtls_timeout = ssl->dtls_timeout_init;
} }
if (ret != 0) { else {
if (SendFatalAlertOnly(ssl, ret) if (SendFatalAlertOnly(ssl, ret)
== SOCKET_ERROR_E) { == SOCKET_ERROR_E) {
ret = SOCKET_ERROR_E; ret = SOCKET_ERROR_E;
@@ -20850,15 +20848,6 @@ default:
ssl->buffers.inputBuffer.buffer, ssl->buffers.inputBuffer.buffer,
&ssl->buffers.inputBuffer.idx, &ssl->buffers.inputBuffer.idx,
ssl->buffers.inputBuffer.length); ssl->buffers.inputBuffer.length);
if (DtlsShouldUpdateWindow(ret) &&
ssl->options.dtlsStateful) {
int updateRet =
Dtls13UpdateWindowRecordRecvd(ssl);
if (updateRet != 0) {
WOLFSSL_ERROR(updateRet);
return updateRet;
}
}
#ifdef WOLFSSL_EARLY_DATA #ifdef WOLFSSL_EARLY_DATA
if (ret == 0 && if (ret == 0 &&
ssl->options.side == WOLFSSL_SERVER_END && ssl->options.side == WOLFSSL_SERVER_END &&
@@ -20979,15 +20968,6 @@ default:
WOLFSSL_ERROR_VERBOSE(UNKNOWN_RECORD_TYPE); WOLFSSL_ERROR_VERBOSE(UNKNOWN_RECORD_TYPE);
return UNKNOWN_RECORD_TYPE; return UNKNOWN_RECORD_TYPE;
} }
#ifdef WOLFSSL_DTLS13
if (ssl->options.dtls) {
ret = Dtls13UpdateWindowRecordRecvd(ssl);
if (ret != 0) {
WOLFSSL_ERROR(ret);
return ret;
}
}
#endif
break; break;
} }
#endif #endif
@@ -21075,8 +21055,6 @@ default:
#ifdef WOLFSSL_DTLS #ifdef WOLFSSL_DTLS
if (ssl->options.dtls) { if (ssl->options.dtls) {
WOLFSSL_DTLS_PEERSEQ* peerSeq = ssl->keys.peerSeq; WOLFSSL_DTLS_PEERSEQ* peerSeq = ssl->keys.peerSeq;
if (IsDtlsNotSctpMode(ssl))
_DtlsUpdateWindow(ssl);
#ifdef WOLFSSL_MULTICAST #ifdef WOLFSSL_MULTICAST
if (ssl->options.haveMcast) { if (ssl->options.haveMcast) {
peerSeq += ssl->keys.curPeerId; peerSeq += ssl->keys.curPeerId;
@@ -21138,26 +21116,10 @@ default:
return SANITY_MSG_E; return SANITY_MSG_E;
} }
#endif #endif
ret = DoApplicationData(ssl, if ((ret = DoApplicationData(ssl,
ssl->buffers.inputBuffer.buffer, ssl->buffers.inputBuffer.buffer,
&ssl->buffers.inputBuffer.idx, NO_SNIFF); &ssl->buffers.inputBuffer.idx,
#ifdef WOLFSSL_DTLS NO_SNIFF)) != 0) {
if (ssl->options.dtls && DtlsShouldUpdateWindow(ret)) {
#ifdef WOLFSSL_DTLS13
if (IsAtLeastTLSv1_3(ssl->version)) {
int updateRet = Dtls13UpdateWindowRecordRecvd(ssl);
if (updateRet != 0) {
WOLFSSL_ERROR(updateRet);
return updateRet;
}
}
else
#endif
if (IsDtlsNotSctpMode(ssl))
_DtlsUpdateWindow(ssl);
}
#endif
if (ret != 0) {
WOLFSSL_ERROR(ret); WOLFSSL_ERROR(ret);
return ret; return ret;
} }
@@ -21186,22 +21148,6 @@ default:
/* Reset error if we got an alert level in ret */ /* Reset error if we got an alert level in ret */
if (ret > 0) if (ret > 0)
ret = 0; ret = 0;
#ifdef WOLFSSL_DTLS
if (ssl->options.dtls) {
#ifdef WOLFSSL_DTLS13
if (IsAtLeastTLSv1_3(ssl->version)) {
ret = Dtls13UpdateWindowRecordRecvd(ssl);
if (ret != 0) {
WOLFSSL_ERROR(ret);
return ret;
}
}
else
#endif
if (IsDtlsNotSctpMode(ssl))
_DtlsUpdateWindow(ssl);
}
#endif
break; break;
#ifdef WOLFSSL_DTLS13 #ifdef WOLFSSL_DTLS13
@@ -21216,13 +21162,6 @@ default:
ssl->keys.padSz, &processedSize); ssl->keys.padSz, &processedSize);
ssl->buffers.inputBuffer.idx += processedSize; ssl->buffers.inputBuffer.idx += processedSize;
ssl->buffers.inputBuffer.idx += ssl->keys.padSz; ssl->buffers.inputBuffer.idx += ssl->keys.padSz;
if (DtlsShouldUpdateWindow(ret)) {
int updateRet = Dtls13UpdateWindowRecordRecvd(ssl);
if (updateRet != 0) {
WOLFSSL_ERROR(updateRet);
return updateRet;
}
}
if (ret != 0) if (ret != 0)
return ret; return ret;
break; break;

View File

@@ -6479,6 +6479,7 @@ WOLFSSL_LOCAL word32 nid2oid(int nid, int grp);
#ifdef WOLFSSL_DTLS #ifdef WOLFSSL_DTLS
WOLFSSL_API int wolfSSL_DtlsUpdateWindow(word16 cur_hi, word32 cur_lo, WOLFSSL_API int wolfSSL_DtlsUpdateWindow(word16 cur_hi, word32 cur_lo,
word16* next_hi, word32* next_lo, word32 *window); word16* next_hi, word32* next_lo, word32 *window);
WOLFSSL_LOCAL int DtlsUpdateWindow(WOLFSSL* ssl);
WOLFSSL_LOCAL void DtlsResetState(WOLFSSL *ssl); WOLFSSL_LOCAL void DtlsResetState(WOLFSSL *ssl);
WOLFSSL_LOCAL int DtlsIgnoreError(int err); WOLFSSL_LOCAL int DtlsIgnoreError(int err);
WOLFSSL_LOCAL void DtlsSetSeqNumForReply(WOLFSSL* ssl); WOLFSSL_LOCAL void DtlsSetSeqNumForReply(WOLFSSL* ssl);
@@ -6547,6 +6548,7 @@ WOLFSSL_LOCAL void Dtls13RtxFlushBuffered(WOLFSSL* ssl,
WOLFSSL_LOCAL int Dtls13RtxTimeout(WOLFSSL* ssl); WOLFSSL_LOCAL int Dtls13RtxTimeout(WOLFSSL* ssl);
WOLFSSL_LOCAL int Dtls13ProcessBufferedMessages(WOLFSSL* ssl); WOLFSSL_LOCAL int Dtls13ProcessBufferedMessages(WOLFSSL* ssl);
WOLFSSL_LOCAL int Dtls13CheckAEADFailLimit(WOLFSSL* ssl); WOLFSSL_LOCAL int Dtls13CheckAEADFailLimit(WOLFSSL* ssl);
WOLFSSL_LOCAL int Dtls13UpdateWindowRecordRecvd(WOLFSSL* ssl);
#endif /* WOLFSSL_DTLS13 */ #endif /* WOLFSSL_DTLS13 */
#ifdef WOLFSSL_STATIC_EPHEMERAL #ifdef WOLFSSL_STATIC_EPHEMERAL