diff --git a/tests/api/test_tls13.c b/tests/api/test_tls13.c index 59ad15abe..ee9cbcc65 100644 --- a/tests/api/test_tls13.c +++ b/tests/api/test_tls13.c @@ -1992,6 +1992,101 @@ int test_tls13_pq_groups(void) return EXPECT_RESULT(); } +#if defined(HAVE_MANUAL_MEMIO_TESTS_DEPENDENCIES) && \ + defined(WOLFSSL_EARLY_DATA) && defined(HAVE_SESSION_TICKET) +static int test_tls13_read_until_write_ok(WOLFSSL* ssl, void* buf, int bufLen) +{ + int ret, err; + int tries = 5; + + err = 0; + do { + ret = wolfSSL_read(ssl, buf, bufLen); + if (ret == WC_NO_ERR_TRACE(WOLFSSL_FATAL_ERROR)) { + err = wolfSSL_get_error(ssl, ret); + } + } while (tries-- && ret == WC_NO_ERR_TRACE(WOLFSSL_FATAL_ERROR) && + err == WC_NO_ERR_TRACE(WOLFSSL_ERROR_WANT_WRITE)); + return ret; +} +static int test_tls13_connect_until_write_ok(WOLFSSL* ssl) +{ + int ret, err; + int tries = 5; + + err = 0; + do { + ret = wolfSSL_connect(ssl); + if (ret == WC_NO_ERR_TRACE(WOLFSSL_FATAL_ERROR)) { + err = wolfSSL_get_error(ssl, ret); + } + } while (tries-- && ret == WC_NO_ERR_TRACE(WOLFSSL_FATAL_ERROR) && + err == WC_NO_ERR_TRACE(WOLFSSL_ERROR_WANT_WRITE)); + return ret; +} +static int test_tls13_write_until_write_ok(WOLFSSL* ssl, const void* msg, + int msgLen) +{ + int ret, err; + int tries = 5; + + err = 0; + do { + ret = wolfSSL_write(ssl, msg, msgLen); + if (ret == WC_NO_ERR_TRACE(WOLFSSL_FATAL_ERROR)) { + err = wolfSSL_get_error(ssl, ret); + } + } while (tries-- && ret == WC_NO_ERR_TRACE(WOLFSSL_FATAL_ERROR) && + err == WC_NO_ERR_TRACE(WOLFSSL_ERROR_WANT_WRITE)); + return ret; +} +static int test_tls13_early_data_read_until_write_ok(WOLFSSL* ssl, void* buf, + int bufLen, int* read) +{ + int ret, err; + int tries = 5; + + err = 0; + do { + ret = wolfSSL_read_early_data(ssl, buf, bufLen, read); + if (ret == WC_NO_ERR_TRACE(WOLFSSL_FATAL_ERROR)) { + err = wolfSSL_get_error(ssl, ret); + } + } while (tries-- && ret == WC_NO_ERR_TRACE(WOLFSSL_FATAL_ERROR) && + err == WC_NO_ERR_TRACE(WOLFSSL_ERROR_WANT_WRITE)); + return ret; +} +static int test_tls13_early_data_write_until_write_ok(WOLFSSL* ssl, + const void* msg, int msgLen, int* written) +{ + int ret, err; + int tries = 5; + + err = 0; + do { + ret = wolfSSL_write_early_data(ssl, msg, msgLen, written); + if (ret == WC_NO_ERR_TRACE(WOLFSSL_FATAL_ERROR)) { + err = wolfSSL_get_error(ssl, ret); + } + } while (tries-- && ret == WC_NO_ERR_TRACE(WOLFSSL_FATAL_ERROR) && + err == WC_NO_ERR_TRACE(WOLFSSL_ERROR_WANT_WRITE)); + return ret; +} +struct test_tls13_wwrite_ctx { + int want_write; + struct test_memio_ctx *text_ctx; +}; +static int test_tls13_mock_wantwrite_cb(WOLFSSL* ssl, char* data, int sz, + void* ctx) +{ + struct test_tls13_wwrite_ctx *wwctx = (struct test_tls13_wwrite_ctx *)ctx; + wwctx->want_write = !wwctx->want_write; + if (wwctx->want_write) { + return WOLFSSL_CBIO_ERR_WANT_WRITE; + } + return test_memio_write_cb(ssl, data, sz, wwctx->text_ctx); +} +#endif /* HAVE_MANUAL_MEMIO_TESTS_DEPENDENCIES && WOLFSSL_EARLY_DATA */ int test_tls13_early_data(void) { EXPECT_DECLS; @@ -2000,7 +2095,6 @@ int test_tls13_early_data(void) int written = 0; int read = 0; size_t i; - int splitEarlyData; char msg[] = "This is early data"; char msg2[] = "This is client data"; char msg3[] = "This is server data"; @@ -2012,18 +2106,27 @@ int test_tls13_early_data(void) const char* tls_version; int isUdp; int splitEarlyData; + int everyWriteWantWrite; } params[] = { #ifdef WOLFSSL_TLS13 { wolfTLSv1_3_client_method, wolfTLSv1_3_server_method, - "TLS 1.3", 0, 0 }, + "TLS 1.3", 0, 0, 0 }, { wolfTLSv1_3_client_method, wolfTLSv1_3_server_method, - "TLS 1.3", 0, 1 }, + "TLS 1.3", 0, 1, 0 }, + { wolfTLSv1_3_client_method, wolfTLSv1_3_server_method, + "TLS 1.3", 0, 0, 1 }, + { wolfTLSv1_3_client_method, wolfTLSv1_3_server_method, + "TLS 1.3", 0, 1, 1 }, #endif #ifdef WOLFSSL_DTLS13 { wolfDTLSv1_3_client_method, wolfDTLSv1_3_server_method, - "DTLS 1.3", 1, 0 }, + "DTLS 1.3", 1, 0, 0 }, { wolfDTLSv1_3_client_method, wolfDTLSv1_3_server_method, - "DTLS 1.3", 1, 1 }, + "DTLS 1.3", 1, 1, 0 }, + { wolfDTLSv1_3_client_method, wolfDTLSv1_3_server_method, + "DTLS 1.3", 1, 0, 1 }, + { wolfDTLSv1_3_client_method, wolfDTLSv1_3_server_method, + "DTLS 1.3", 1, 1, 1 }, #endif }; @@ -2033,10 +2136,14 @@ int test_tls13_early_data(void) WOLFSSL *ssl_c = NULL, *ssl_s = NULL; WOLFSSL_SESSION *sess = NULL; int splitEarlyData = params[i].splitEarlyData; + int everyWriteWantWrite = params[i].everyWriteWantWrite; + struct test_tls13_wwrite_ctx wwrite_ctx_s, wwrite_ctx_c; XMEMSET(&test_ctx, 0, sizeof(test_ctx)); - fprintf(stderr, "\tEarly data with %s\n", params[i].tls_version); + fprintf(stderr, "\tEarly data with %s%s%s\n", params[i].tls_version, + splitEarlyData ? " (split early data)" : "", + everyWriteWantWrite ? " (every write WANT_WRITE)" : ""); ExpectIntEQ(test_memio_setup(&test_ctx, &ctx_c, &ctx_s, &ssl_c, &ssl_s, params[i].client_meth, params[i].server_meth), 0); @@ -2071,49 +2178,66 @@ int test_tls13_early_data(void) } #endif + if (everyWriteWantWrite) { + XMEMSET(&wwrite_ctx_c, 0, sizeof(wwrite_ctx_c)); + XMEMSET(&wwrite_ctx_s, 0, sizeof(wwrite_ctx_s)); + wwrite_ctx_c.text_ctx = &test_ctx; + wwrite_ctx_s.text_ctx = &test_ctx; + wolfSSL_SetIOWriteCtx(ssl_c, &wwrite_ctx_c); + wolfSSL_SSLSetIOSend(ssl_c, test_tls13_mock_wantwrite_cb); + wolfSSL_SetIOWriteCtx(ssl_s, &wwrite_ctx_s); + wolfSSL_SSLSetIOSend(ssl_s, test_tls13_mock_wantwrite_cb); + } /* Test 0-RTT data */ wolfSSL_SetLoggingPrefix("client"); - ExpectIntEQ(wolfSSL_write_early_data(ssl_c, msg, sizeof(msg), - &written), sizeof(msg)); + + ExpectIntEQ(test_tls13_early_data_write_until_write_ok(ssl_c, msg, + sizeof(msg), &written), + sizeof(msg)); ExpectIntEQ(written, sizeof(msg)); if (splitEarlyData) { - ExpectIntEQ(wolfSSL_write_early_data(ssl_c, msg, sizeof(msg), - &written), sizeof(msg)); + ExpectIntEQ(test_tls13_early_data_write_until_write_ok(ssl_c, msg, + sizeof(msg), &written), + sizeof(msg)); ExpectIntEQ(written, sizeof(msg)); } /* Read first 0-RTT data (if split otherwise entire data) */ wolfSSL_SetLoggingPrefix("server"); - ExpectIntEQ(wolfSSL_read_early_data(ssl_s, msgBuf, sizeof(msgBuf), - &read), sizeof(msg)); + ExpectIntEQ(test_tls13_early_data_read_until_write_ok(ssl_s, msgBuf, + sizeof(msgBuf), &read), + sizeof(msg)); ExpectIntEQ(read, sizeof(msg)); ExpectStrEQ(msg, msgBuf); /* Test 0.5-RTT data */ - ExpectIntEQ(wolfSSL_write(ssl_s, msg4, sizeof(msg4)), sizeof(msg4)); + ExpectIntEQ(test_tls13_write_until_write_ok(ssl_s, msg4, sizeof(msg4)), + sizeof(msg4)); if (splitEarlyData) { /* Read second 0-RTT data */ - ExpectIntEQ(wolfSSL_read_early_data(ssl_s, msgBuf, - sizeof(msgBuf), &read), sizeof(msg)); + ExpectIntEQ(test_tls13_early_data_read_until_write_ok(ssl_s, msgBuf, + sizeof(msgBuf), &read), + sizeof(msg)); ExpectIntEQ(read, sizeof(msg)); ExpectStrEQ(msg, msgBuf); } if (params[i].isUdp) { wolfSSL_SetLoggingPrefix("client"); - ExpectIntEQ(wolfSSL_connect(ssl_c), -1); + ExpectIntEQ(test_tls13_connect_until_write_ok(ssl_c), -1); ExpectIntEQ(wolfSSL_get_error(ssl_c, -1), WC_NO_ERR_TRACE(APP_DATA_READY)); /* Read server 0.5-RTT data */ - ExpectIntEQ(wolfSSL_read(ssl_c, msgBuf, sizeof(msgBuf)), + ExpectIntEQ( + test_tls13_read_until_write_ok(ssl_c, msgBuf, sizeof(msgBuf)), sizeof(msg4)); ExpectStrEQ(msg4, msgBuf); /* Complete handshake */ - ExpectIntEQ(wolfSSL_connect(ssl_c), -1); + ExpectIntEQ(test_tls13_connect_until_write_ok(ssl_c), -1); ExpectIntEQ(wolfSSL_get_error(ssl_c, -1), WOLFSSL_ERROR_WANT_READ); /* Use wolfSSL_is_init_finished to check if handshake is @@ -2125,42 +2249,51 @@ int test_tls13_early_data(void) * early data parsing logic. */ wolfSSL_SetLoggingPrefix("server"); ExpectFalse(wolfSSL_is_init_finished(ssl_s)); - ExpectIntEQ(wolfSSL_read_early_data(ssl_s, msgBuf, - sizeof(msgBuf), &read), 0); + ExpectIntEQ(test_tls13_early_data_read_until_write_ok(ssl_s, msgBuf, + sizeof(msgBuf), &read), + 0); ExpectIntEQ(read, 0); ExpectTrue(wolfSSL_is_init_finished(ssl_s)); wolfSSL_SetLoggingPrefix("client"); - ExpectIntEQ(wolfSSL_connect(ssl_c), WOLFSSL_SUCCESS); + ExpectIntEQ(test_tls13_connect_until_write_ok(ssl_c), + WOLFSSL_SUCCESS); } else { wolfSSL_SetLoggingPrefix("client"); - ExpectIntEQ(wolfSSL_connect(ssl_c), WOLFSSL_SUCCESS); + ExpectIntEQ(test_tls13_connect_until_write_ok(ssl_c), + WOLFSSL_SUCCESS); wolfSSL_SetLoggingPrefix("server"); ExpectFalse(wolfSSL_is_init_finished(ssl_s)); - ExpectIntEQ(wolfSSL_read_early_data(ssl_s, msgBuf, - sizeof(msgBuf), &read), 0); + ExpectIntEQ(test_tls13_early_data_read_until_write_ok(ssl_s, msgBuf, + sizeof(msgBuf), &read), + 0); ExpectIntEQ(read, 0); ExpectTrue(wolfSSL_is_init_finished(ssl_s)); /* Read server 0.5-RTT data */ wolfSSL_SetLoggingPrefix("client"); - ExpectIntEQ(wolfSSL_read(ssl_c, msgBuf, sizeof(msgBuf)), + ExpectIntEQ( + test_tls13_read_until_write_ok(ssl_c, msgBuf, sizeof(msgBuf)), sizeof(msg4)); ExpectStrEQ(msg4, msgBuf); } /* Test bi-directional write */ wolfSSL_SetLoggingPrefix("client"); - ExpectIntEQ(wolfSSL_write(ssl_c, msg2, sizeof(msg2)), sizeof(msg2)); + ExpectIntEQ(test_tls13_write_until_write_ok(ssl_c, msg2, sizeof(msg2)), + sizeof(msg2)); wolfSSL_SetLoggingPrefix("server"); - ExpectIntEQ(wolfSSL_read(ssl_s, msgBuf, sizeof(msgBuf)), + ExpectIntEQ( + test_tls13_read_until_write_ok(ssl_s, msgBuf, sizeof(msgBuf)), sizeof(msg2)); ExpectStrEQ(msg2, msgBuf); - ExpectIntEQ(wolfSSL_write(ssl_s, msg3, sizeof(msg3)), sizeof(msg3)); + ExpectIntEQ(test_tls13_write_until_write_ok(ssl_s, msg3, sizeof(msg3)), + sizeof(msg3)); wolfSSL_SetLoggingPrefix("client"); - ExpectIntEQ(wolfSSL_read(ssl_c, msgBuf, sizeof(msgBuf)), + ExpectIntEQ( + test_tls13_read_until_write_ok(ssl_c, msgBuf, sizeof(msgBuf)), sizeof(msg3)); ExpectStrEQ(msg3, msgBuf);