diff --git a/src/dtls13.c b/src/dtls13.c index 91eb2b20d..11d7a018f 100644 --- a/src/dtls13.c +++ b/src/dtls13.c @@ -1637,6 +1637,102 @@ static int Dtls13AcceptFragmented(WOLFSSL *ssl, enum HandShakeType type) #endif return 0; } + +int Dtls13CheckEpoch(WOLFSSL* ssl, enum HandShakeType type) +{ + w64wrapper plainEpoch = w64From32(0x0, 0x0); + w64wrapper hsEpoch = w64From32(0x0, DTLS13_EPOCH_HANDSHAKE); + w64wrapper t0Epoch = w64From32(0x0, DTLS13_EPOCH_TRAFFIC0); + + if (IsAtLeastTLSv1_3(ssl->version)) { + switch (type) { + case client_hello: + case server_hello: + case hello_verify_request: + case hello_retry_request: + case hello_request: + if (!w64Equal(ssl->keys.curEpoch64, plainEpoch)) { + WOLFSSL_MSG("Msg should be epoch 0"); + WOLFSSL_ERROR_VERBOSE(SANITY_MSG_E); + return SANITY_MSG_E; + } + break; + case encrypted_extensions: + case server_key_exchange: + case server_hello_done: + case client_key_exchange: + if (!w64Equal(ssl->keys.curEpoch64, hsEpoch)) { + if (ssl->options.side == WOLFSSL_CLIENT_END && + ssl->options.serverState < SERVER_HELLO_COMPLETE) { + /* before processing SH we don't know which version + * will be negotiated. */ + if (!w64Equal(ssl->keys.curEpoch64, plainEpoch)) { + WOLFSSL_MSG("Msg should be epoch 2 or 0"); + WOLFSSL_ERROR_VERBOSE(SANITY_MSG_E); + return SANITY_MSG_E; + } + } + else { + WOLFSSL_MSG("Msg should be epoch 2"); + WOLFSSL_ERROR_VERBOSE(SANITY_MSG_E); + return SANITY_MSG_E; + } + } + break; + case certificate_request: + case certificate: + case certificate_verify: + case finished: + if (!ssl->options.handShakeDone) { + if (!w64Equal(ssl->keys.curEpoch64, hsEpoch)) { + if (ssl->options.side == WOLFSSL_CLIENT_END && + ssl->options.serverState < SERVER_HELLO_COMPLETE) { + /* before processing SH we don't know which version + * will be negotiated. */ + if (!w64Equal(ssl->keys.curEpoch64, plainEpoch)) { + WOLFSSL_MSG("Msg should be epoch 2 or 0"); + WOLFSSL_ERROR_VERBOSE(SANITY_MSG_E); + return SANITY_MSG_E; + } + } + else { + WOLFSSL_MSG("Msg should be epoch 2"); + WOLFSSL_ERROR_VERBOSE(SANITY_MSG_E); + return SANITY_MSG_E; + } + } + } + else { + /* Allow epoch 2 in case of rtx */ + if (!w64GTE(ssl->keys.curEpoch64, hsEpoch)) { + WOLFSSL_MSG("Msg should be epoch 2+"); + WOLFSSL_ERROR_VERBOSE(SANITY_MSG_E); + return SANITY_MSG_E; + } + } + break; + case certificate_status: + case change_cipher_hs: + case key_update: + case session_ticket: + if (!w64GTE(ssl->keys.curEpoch64, t0Epoch)) { + WOLFSSL_MSG("Msg should be epoch 3+"); + WOLFSSL_ERROR_VERBOSE(SANITY_MSG_E); + return SANITY_MSG_E; + } + break; + case end_of_early_data: + case message_hash: + case no_shake: + default: + WOLFSSL_MSG("Unknown message type"); + WOLFSSL_ERROR_VERBOSE(SANITY_MSG_E); + return SANITY_MSG_E; + } + } + return 0; +} + /** * Dtls13HandshakeRecv() - process an handshake message. Deal with fragmentation if needed @@ -1672,6 +1768,12 @@ static int _Dtls13HandshakeRecv(WOLFSSL* ssl, byte* input, word32 size, return ret; } + ret = Dtls13CheckEpoch(ssl, (enum HandShakeType)handshakeType); + if (ret != 0) { + WOLFSSL_ERROR(ret); + return ret; + } + if (ssl->options.side == WOLFSSL_SERVER_END && ssl->options.acceptState < TLS13_ACCEPT_FIRST_REPLY_DONE) { if (handshakeType != client_hello) { diff --git a/src/internal.c b/src/internal.c index 2fce64fee..b52e01df8 100644 --- a/src/internal.c +++ b/src/internal.c @@ -21013,6 +21013,16 @@ int DoApplicationData(WOLFSSL* ssl, byte* input, word32* inOutIdx, int sniff) isEarlyData = isEarlyData && w64Equal(ssl->keys.curEpoch64, w64From32(0x0, DTLS13_EPOCH_EARLYDATA)); #endif +#ifdef WOLFSSL_DTLS13 + /* Application data should never appear in epoch 0 or 2 */ + if (ssl->options.tls1_3 && ssl->options.dtls && + (w64Equal(ssl->keys.curEpoch64, w64From32(0x0, DTLS13_EPOCH_HANDSHAKE)) + || w64Equal(ssl->keys.curEpoch64, w64From32(0x0, 0x0)))) + { + WOLFSSL_ERROR_VERBOSE(SANITY_MSG_E); + return SANITY_MSG_E; + } +#endif #ifdef WOLFSSL_EARLY_DATA if (isEarlyData && acceptEarlyData) { diff --git a/tests/api.c b/tests/api.c index d7840a46c..22d583af9 100644 --- a/tests/api.c +++ b/tests/api.c @@ -67811,6 +67811,7 @@ TEST_CASE testCases[] = { TEST_DECL(test_wolfSSL_SSLDisableRead), TEST_DECL(test_wolfSSL_inject), TEST_DECL(test_wolfSSL_dtls_cid_parse), + TEST_DECL(test_dtls13_epochs), TEST_DECL(test_ocsp_status_callback), TEST_DECL(test_ocsp_basic_verify), TEST_DECL(test_ocsp_response_parsing), diff --git a/tests/api/test_dtls.c b/tests/api/test_dtls.c index bb3cf51fd..6617c6ccc 100644 --- a/tests/api/test_dtls.c +++ b/tests/api/test_dtls.c @@ -595,3 +595,54 @@ int test_wolfSSL_dtls_cid_parse(void) #endif return EXPECT_RESULT(); } + +int test_dtls13_epochs(void) { + EXPECT_DECLS; +#if defined(WOLFSSL_DTLS13) && !defined(NO_WOLFSSL_CLIENT) + WOLFSSL_CTX* ctx = NULL; + WOLFSSL* ssl = NULL; + byte input[20]; + word32 inOutIdx = 0; + + XMEMSET(input, 0, sizeof(input)); + + ExpectNotNull(ctx = wolfSSL_CTX_new(wolfDTLSv1_3_client_method())); + ExpectNotNull(ssl = wolfSSL_new(ctx)); + /* Some manual setup to enter the epoch check */ + ExpectTrue(ssl->options.tls1_3 = 1); + + inOutIdx = 0; + if (ssl != NULL) ssl->keys.curEpoch64 = w64From32(0x0, 0x0); + ExpectIntEQ(DoApplicationData(ssl, input, &inOutIdx, 0), SANITY_MSG_E); + inOutIdx = 0; + if (ssl != NULL) ssl->keys.curEpoch64 = w64From32(0x0, 0x2); + ExpectIntEQ(DoApplicationData(ssl, input, &inOutIdx, 0), SANITY_MSG_E); + + if (ssl != NULL) ssl->keys.curEpoch64 = w64From32(0x0, 0x1); + ExpectIntEQ(Dtls13CheckEpoch(ssl, client_hello), SANITY_MSG_E); + ExpectIntEQ(Dtls13CheckEpoch(ssl, server_hello), SANITY_MSG_E); + ExpectIntEQ(Dtls13CheckEpoch(ssl, hello_verify_request), SANITY_MSG_E); + ExpectIntEQ(Dtls13CheckEpoch(ssl, hello_retry_request), SANITY_MSG_E); + ExpectIntEQ(Dtls13CheckEpoch(ssl, hello_request), SANITY_MSG_E); + ExpectIntEQ(Dtls13CheckEpoch(ssl, encrypted_extensions), SANITY_MSG_E); + ExpectIntEQ(Dtls13CheckEpoch(ssl, server_key_exchange), SANITY_MSG_E); + ExpectIntEQ(Dtls13CheckEpoch(ssl, server_hello_done), SANITY_MSG_E); + ExpectIntEQ(Dtls13CheckEpoch(ssl, client_key_exchange), SANITY_MSG_E); + ExpectIntEQ(Dtls13CheckEpoch(ssl, certificate_request), SANITY_MSG_E); + ExpectIntEQ(Dtls13CheckEpoch(ssl, certificate), SANITY_MSG_E); + ExpectIntEQ(Dtls13CheckEpoch(ssl, certificate_verify), SANITY_MSG_E); + ExpectIntEQ(Dtls13CheckEpoch(ssl, finished), SANITY_MSG_E); + ExpectIntEQ(Dtls13CheckEpoch(ssl, certificate_status), SANITY_MSG_E); + ExpectIntEQ(Dtls13CheckEpoch(ssl, change_cipher_hs), SANITY_MSG_E); + ExpectIntEQ(Dtls13CheckEpoch(ssl, key_update), SANITY_MSG_E); + ExpectIntEQ(Dtls13CheckEpoch(ssl, session_ticket), SANITY_MSG_E); + ExpectIntEQ(Dtls13CheckEpoch(ssl, end_of_early_data), SANITY_MSG_E); + ExpectIntEQ(Dtls13CheckEpoch(ssl, message_hash), SANITY_MSG_E); + ExpectIntEQ(Dtls13CheckEpoch(ssl, no_shake), SANITY_MSG_E); + + wolfSSL_CTX_free(ctx); + wolfSSL_free(ssl); +#endif + return EXPECT_RESULT(); +} + diff --git a/tests/api/test_dtls.h b/tests/api/test_dtls.h index 68980c3e5..a44b03676 100644 --- a/tests/api/test_dtls.h +++ b/tests/api/test_dtls.h @@ -25,5 +25,6 @@ int test_dtls12_basic_connection_id(void); int test_dtls13_basic_connection_id(void); int test_wolfSSL_dtls_cid_parse(void); +int test_dtls13_epochs(void); #endif /* TESTS_API_DTLS_H */ diff --git a/wolfssl/internal.h b/wolfssl/internal.h index a9838509f..fa211b732 100644 --- a/wolfssl/internal.h +++ b/wolfssl/internal.h @@ -2222,7 +2222,7 @@ WOLFSSL_LOCAL int DoFinished(WOLFSSL* ssl, const byte* input, word32* inOutIdx, WOLFSSL_LOCAL int DoTls13Finished(WOLFSSL* ssl, const byte* input, word32* inOutIdx, word32 size, word32 totalSz, int sniff); #endif -WOLFSSL_LOCAL int DoApplicationData(WOLFSSL* ssl, byte* input, word32* inOutIdx, +WOLFSSL_TEST_VIS int DoApplicationData(WOLFSSL* ssl, byte* input, word32* inOutIdx, int sniff); /* TLS v1.3 needs these */ WOLFSSL_LOCAL int HandleTlsResumption(WOLFSSL* ssl, Suites* clSuites); @@ -7052,6 +7052,7 @@ WOLFSSL_LOCAL int Dtls13HandshakeSend(WOLFSSL* ssl, byte* output, word16 output_size, word16 length, enum HandShakeType handshake_type, int hash_output); WOLFSSL_LOCAL int Dtls13RecordRecvd(WOLFSSL* ssl); +WOLFSSL_TEST_VIS int Dtls13CheckEpoch(WOLFSSL* ssl, enum HandShakeType type); WOLFSSL_LOCAL int Dtls13HandshakeRecv(WOLFSSL* ssl, byte* input, word32* inOutIdx, word32 totalSz); WOLFSSL_LOCAL int Dtls13HandshakeAddHeader(WOLFSSL* ssl, byte* output,