diff --git a/src/dtls13.c b/src/dtls13.c index 399fc7a61..9c729fa1e 100644 --- a/src/dtls13.c +++ b/src/dtls13.c @@ -1898,11 +1898,11 @@ static int _Dtls13HandshakeRecv(WOLFSSL* ssl, byte* input, word32 size, ret = DoTls13HandShakeMsgType(ssl, input, &idx, handshakeType, messageLength, size); + *processedSize = idx; if (ret != 0) return ret; Dtls13MsgWasProcessed(ssl, (enum HandShakeType)handshakeType); - *processedSize = idx; /* check if we have buffered some message */ if (Dtls13NextMessageComplete(ssl)) diff --git a/tests/api.c b/tests/api.c index f0090c183..0cecabfc0 100644 --- a/tests/api.c +++ b/tests/api.c @@ -47055,7 +47055,7 @@ static int test_multiple_shutdown_nonblocking(void) ExpectIntEQ(test_ctx.s_len, 0); ExpectIntEQ(ssl_c->buffers.outputBuffer.length, 0); - test_memio_simulate_want_write(&test_ctx, 0, 1); + test_memio_simulate_want_write(&test_ctx, 1, 1); /* * We call wolfSSL_shutdown multiple times to to check that it doesn't add @@ -47078,7 +47078,7 @@ static int test_multiple_shutdown_nonblocking(void) ExpectIntEQ(ssl_c->buffers.outputBuffer.length, size_of_last_packet); /* now send the CLOSE_NOTIFY to the server for real, expecting shutdown not done */ - test_memio_simulate_want_write(&test_ctx, 0, 0); + test_memio_simulate_want_write(&test_ctx, 1, 0); ExpectIntEQ(wolfSSL_shutdown(ssl_c), WOLFSSL_SHUTDOWN_NOT_DONE); /* output buffer should be empty and socket buffer should contain the message */ diff --git a/tests/api/test_dtls.c b/tests/api/test_dtls.c index baf94f325..ca3ab164c 100644 --- a/tests/api/test_dtls.c +++ b/tests/api/test_dtls.c @@ -568,6 +568,166 @@ int test_dtls13_basic_connection_id(void) return EXPECT_RESULT(); } +/** Test DTLS 1.3 behavior when server hits WANT_WRITE during HRR + * The test sets up a DTLS 1.3 connection where the server is forced to + * return WANT_WRITE when sending the HelloRetryRequest. After the handshake, + * application data is exchanged in both directions to verify the connection + * works as expected. + */ +int test_dtls13_hrr_want_write(void) +{ + EXPECT_DECLS; +#if defined(HAVE_MANUAL_MEMIO_TESTS_DEPENDENCIES) && defined(WOLFSSL_DTLS13) + WOLFSSL_CTX *ctx_c = NULL, *ctx_s = NULL; + WOLFSSL *ssl_c = NULL, *ssl_s = NULL; + const char msg[] = "hello"; + const int msgLen = sizeof(msg); + struct test_memio_ctx test_ctx; + char readBuf[sizeof(msg)]; + + XMEMSET(&test_ctx, 0, sizeof(test_ctx)); + + ExpectIntEQ(test_memio_setup(&test_ctx, &ctx_c, &ctx_s, &ssl_c, &ssl_s, + wolfDTLSv1_3_client_method, wolfDTLSv1_3_server_method), + 0); + + /* Client sends first ClientHello */ + ExpectIntEQ(wolfSSL_negotiate(ssl_c), -1); + ExpectIntEQ(wolfSSL_get_error(ssl_c, -1), WOLFSSL_ERROR_WANT_READ); + + /* Force server to hit WANT_WRITE when producing the HRR */ + test_memio_simulate_want_write(&test_ctx, 0, 1); + ExpectIntEQ(wolfSSL_negotiate(ssl_s), -1); + ExpectIntEQ(wolfSSL_get_error(ssl_s, -1), WOLFSSL_ERROR_WANT_WRITE); + + /* Allow the server to flush the HRR and proceed */ + test_memio_simulate_want_write(&test_ctx, 0, 0); + ExpectIntEQ(wolfSSL_negotiate(ssl_s), -1); + ExpectIntEQ(wolfSSL_get_error(ssl_s, -1), WOLFSSL_ERROR_WANT_READ); + + /* Resume the DTLS 1.3 handshake */ + ExpectIntEQ(test_memio_do_handshake(ssl_c, ssl_s, 10, NULL), 0); + + /* Verify post-handshake application data in both directions */ + XMEMSET(readBuf, 0, sizeof(readBuf)); + ExpectIntEQ(wolfSSL_write(ssl_c, msg, msgLen), msgLen); + ExpectIntEQ(wolfSSL_read(ssl_s, readBuf, sizeof(readBuf)), msgLen); + ExpectStrEQ(readBuf, msg); + + XMEMSET(readBuf, 0, sizeof(readBuf)); + ExpectIntEQ(wolfSSL_write(ssl_s, msg, msgLen), msgLen); + ExpectIntEQ(wolfSSL_read(ssl_c, readBuf, sizeof(readBuf)), msgLen); + ExpectStrEQ(readBuf, msg); + + wolfSSL_free(ssl_c); + wolfSSL_CTX_free(ctx_c); + wolfSSL_free(ssl_s); + wolfSSL_CTX_free(ctx_s); +#endif + return EXPECT_RESULT(); +} + +#if defined(HAVE_MANUAL_MEMIO_TESTS_DEPENDENCIES) && defined(WOLFSSL_DTLS13) +struct test_dtls13_wwrite_ctx { + int want_write; + struct test_memio_ctx *text_ctx; +}; +static int test_dtls13_want_write_send_cb(WOLFSSL *ssl, char *data, int sz, void *ctx) +{ + struct test_dtls13_wwrite_ctx *wwctx = (struct test_dtls13_wwrite_ctx *)ctx; + wwctx->want_write = !wwctx->want_write; + if (wwctx->want_write) { + return WOLFSSL_CBIO_ERR_WANT_WRITE; + } + return test_memio_write_cb(ssl, data, sz, wwctx->text_ctx); +} +#endif +/** Test DTLS 1.3 behavior when every other write returns WANT_WRITE + * The test sets up a DTLS 1.3 connection where both client and server + * alternate between WANT_WRITE and successful writes. After the handshake, + * application data is exchanged in both directions to verify the connection + * works as expected. + * + * Data exchanged after the handshake is also tested with simulated WANT_WRITE + * conditions to ensure the connection remains functional. + */ +int test_dtls13_every_write_want_write(void) +{ + EXPECT_DECLS; +#if defined(HAVE_MANUAL_MEMIO_TESTS_DEPENDENCIES) && defined(WOLFSSL_DTLS13) + WOLFSSL_CTX *ctx_c = NULL, *ctx_s = NULL; + WOLFSSL *ssl_c = NULL, *ssl_s = NULL; + struct test_memio_ctx test_ctx; + const char msg[] = "want-write"; + const int msgLen = sizeof(msg); + char readBuf[sizeof(msg)]; + struct test_dtls13_wwrite_ctx wwctx_c; + struct test_dtls13_wwrite_ctx wwctx_s; + + XMEMSET(&test_ctx, 0, sizeof(test_ctx)); + + ExpectIntEQ(test_memio_setup(&test_ctx, &ctx_c, &ctx_s, &ssl_c, &ssl_s, + wolfDTLSv1_3_client_method, wolfDTLSv1_3_server_method), + 0); + + wwctx_c.want_write = 0; + wwctx_c.text_ctx = &test_ctx; + wolfSSL_SetIOWriteCtx(ssl_c, &wwctx_c); + wolfSSL_SSLSetIOSend(ssl_c, test_dtls13_want_write_send_cb); + wwctx_s.want_write = 0; + wwctx_s.text_ctx = &test_ctx; + wolfSSL_SetIOWriteCtx(ssl_s, &wwctx_s); + wolfSSL_SSLSetIOSend(ssl_s, test_dtls13_want_write_send_cb); + + ExpectIntEQ(test_memio_do_handshake(ssl_c, ssl_s, 20, NULL), 0); + + ExpectTrue(wolfSSL_is_init_finished(ssl_c)); + ExpectTrue(wolfSSL_is_init_finished(ssl_s)); + + test_memio_simulate_want_write(&test_ctx, 0, 0); + test_memio_simulate_want_write(&test_ctx, 1, 0); + + wolfSSL_SetIOWriteCtx(ssl_c, &test_ctx); + wolfSSL_SSLSetIOSend(ssl_c, test_memio_write_cb); + wolfSSL_SetIOWriteCtx(ssl_s, &test_ctx); + wolfSSL_SSLSetIOSend(ssl_s, test_memio_write_cb); + + XMEMSET(readBuf, 0, sizeof(readBuf)); + ExpectIntEQ(wolfSSL_write(ssl_c, msg, msgLen), msgLen); + ExpectIntEQ(wolfSSL_read(ssl_s, readBuf, sizeof(readBuf)), msgLen); + ExpectStrEQ(readBuf, msg); + + XMEMSET(readBuf, 0, sizeof(readBuf)); + ExpectIntEQ(wolfSSL_write(ssl_s, msg, msgLen), msgLen); + ExpectIntEQ(wolfSSL_read(ssl_c, readBuf, sizeof(readBuf)), msgLen); + ExpectStrEQ(readBuf, msg); + + test_memio_simulate_want_write(&test_ctx, 0, 1); + XMEMSET(readBuf, 0, sizeof(readBuf)); + ExpectIntEQ(wolfSSL_write(ssl_s, msg, msgLen), -1); + ExpectIntEQ(wolfSSL_get_error(ssl_s, -1), WOLFSSL_ERROR_WANT_WRITE); + test_memio_simulate_want_write(&test_ctx, 0, 0); + ExpectIntEQ(wolfSSL_write(ssl_s, msg, msgLen), msgLen); + ExpectIntEQ(wolfSSL_read(ssl_c, readBuf, sizeof(readBuf)), msgLen); + ExpectStrEQ(readBuf, msg); + + XMEMSET(readBuf, 0, sizeof(readBuf)); + test_memio_simulate_want_write(&test_ctx, 1, 1); + ExpectIntEQ(wolfSSL_write(ssl_c, msg, msgLen), -1); + ExpectIntEQ(wolfSSL_get_error(ssl_c, -1), WOLFSSL_ERROR_WANT_WRITE); + test_memio_simulate_want_write(&test_ctx, 1, 0); + ExpectIntEQ(wolfSSL_write(ssl_c, msg, msgLen), msgLen); + ExpectIntEQ(wolfSSL_read(ssl_s, readBuf, sizeof(readBuf)), msgLen); + ExpectStrEQ(readBuf, msg); + + wolfSSL_free(ssl_c); + wolfSSL_CTX_free(ctx_c); + wolfSSL_free(ssl_s); + wolfSSL_CTX_free(ctx_s); +#endif + return EXPECT_RESULT(); +} + int test_wolfSSL_dtls_cid_parse(void) { EXPECT_DECLS; diff --git a/tests/api/test_dtls.h b/tests/api/test_dtls.h index 86d6cc473..397e21b0c 100644 --- a/tests/api/test_dtls.h +++ b/tests/api/test_dtls.h @@ -24,6 +24,8 @@ int test_dtls12_basic_connection_id(void); int test_dtls13_basic_connection_id(void); +int test_dtls13_hrr_want_write(void); +int test_dtls13_every_write_want_write(void); int test_wolfSSL_dtls_cid_parse(void); int test_wolfSSL_dtls_set_pending_peer(void); int test_dtls13_epochs(void); @@ -47,6 +49,8 @@ int test_dtls_certreq_order(void); #define TEST_DTLS_DECLS \ TEST_DECL_GROUP("dtls", test_dtls12_basic_connection_id), \ TEST_DECL_GROUP("dtls", test_dtls13_basic_connection_id), \ + TEST_DECL_GROUP("dtls", test_dtls13_hrr_want_write), \ + TEST_DECL_GROUP("dtls", test_dtls13_every_write_want_write), \ TEST_DECL_GROUP("dtls", test_wolfSSL_dtls_cid_parse), \ TEST_DECL_GROUP("dtls", test_wolfSSL_dtls_set_pending_peer), \ TEST_DECL_GROUP("dtls", test_dtls13_epochs), \ diff --git a/tests/utils.c b/tests/utils.c index 9d8a95cda..08150c3e3 100644 --- a/tests/utils.c +++ b/tests/utils.c @@ -57,14 +57,14 @@ int test_memio_write_cb(WOLFSSL *ssl, char *data, int sz, void *ctx) len = &test_ctx->c_len; msg_sizes = test_ctx->c_msg_sizes; msg_count = &test_ctx->c_msg_count; - forceWantWrite = &test_ctx->c_force_want_write; + forceWantWrite = &test_ctx->s_force_want_write; } else { buf = test_ctx->s_buff; len = &test_ctx->s_len; msg_sizes = test_ctx->s_msg_sizes; msg_count = &test_ctx->s_msg_count; - forceWantWrite = &test_ctx->s_force_want_write; + forceWantWrite = &test_ctx->c_force_want_write; } if (*forceWantWrite)