From 23b73bb298dde235a478df9207032a52ff8f337c Mon Sep 17 00:00:00 2001 From: Marco Oliverio Date: Tue, 13 May 2025 16:40:56 +0200 Subject: [PATCH] test_memio: preserve write boundaries in reads --- tests/api.c | 202 ++++++++++++++++++++++------------- tests/api/test_dtls.c | 40 +++---- tests/unit.h | 12 +++ tests/utils.c | 241 +++++++++++++++++++++++++++++++++++++++++- tests/utils.h | 15 +++ 5 files changed, 410 insertions(+), 100 deletions(-) diff --git a/tests/api.c b/tests/api.c index 8509f7989..e276f7d46 100644 --- a/tests/api.c +++ b/tests/api.c @@ -7475,22 +7475,33 @@ static WC_INLINE int test_ssl_memio_write_cb(WOLFSSL *ssl, char *data, int sz, struct test_ssl_memio_ctx *test_ctx; byte *buf; int *len; + int *msg_sizes; + int *msg_count; test_ctx = (struct test_ssl_memio_ctx*)ctx; if (wolfSSL_GetSide(ssl) == WOLFSSL_SERVER_END) { buf = test_ctx->c_buff; len = &test_ctx->c_len; + msg_sizes = test_ctx->c_msg_sizes; + msg_count = &test_ctx->c_msg_count; } else { buf = test_ctx->s_buff; len = &test_ctx->s_len; + msg_sizes = test_ctx->s_msg_sizes; + msg_count = &test_ctx->s_msg_count; } if ((unsigned)(*len + sz) > TEST_SSL_MEMIO_BUF_SZ) return WOLFSSL_CBIO_ERR_WANT_WRITE; + if (*msg_count >= TEST_MEMIO_MAX_MSGS) + return WOLFSSL_CBIO_ERR_WANT_WRITE; + XMEMCPY(buf + *len, data, sz); + msg_sizes[*msg_count] = sz; + (*msg_count)++; *len += sz; #ifdef WOLFSSL_DUMP_MEMIO_STREAM @@ -7521,27 +7532,63 @@ static WC_INLINE int test_ssl_memio_read_cb(WOLFSSL *ssl, char *data, int sz, int read_sz; byte *buf; int *len; + int *msg_sizes; + int *msg_count; + int *msg_pos; + int is_dtls; test_ctx = (struct test_ssl_memio_ctx*)ctx; + is_dtls = wolfSSL_dtls(ssl); if (wolfSSL_GetSide(ssl) == WOLFSSL_SERVER_END) { buf = test_ctx->s_buff; len = &test_ctx->s_len; + msg_sizes = test_ctx->s_msg_sizes; + msg_count = &test_ctx->s_msg_count; + msg_pos = &test_ctx->s_msg_pos; } else { buf = test_ctx->c_buff; len = &test_ctx->c_len; + msg_sizes = test_ctx->c_msg_sizes; + msg_count = &test_ctx->c_msg_count; + msg_pos = &test_ctx->c_msg_pos; } - if (*len == 0) + if (*len == 0 || *msg_pos >= *msg_count) return WOLFSSL_CBIO_ERR_WANT_READ; - read_sz = sz < *len ? sz : *len; + /* Calculate how much we can read from current message */ + read_sz = msg_sizes[*msg_pos]; + if (read_sz > sz) + read_sz = sz; - XMEMCPY(data, buf, read_sz); - XMEMMOVE(buf, buf + read_sz, *len - read_sz); + if (read_sz > *len) + return WOLFSSL_CBIO_ERR_GENERAL; + /* Copy data from current message */ + XMEMCPY(data, buf, (size_t)read_sz); + /* remove the read data from the buffer */ + XMEMMOVE(buf, buf + read_sz, (size_t)(*len - read_sz)); *len -= read_sz; + msg_sizes[*msg_pos] -= read_sz; + + /* if we are on dtls, discard the rest of the message */ + if (is_dtls && msg_sizes[*msg_pos] > 0) { + XMEMMOVE(buf, buf + msg_sizes[*msg_pos], (size_t)(*len - msg_sizes[*msg_pos])); + *len -= msg_sizes[*msg_pos]; + msg_sizes[*msg_pos] = 0; + } + + /* If we've read the entire message */ + if (msg_sizes[*msg_pos] == 0) { + /* Move to next message */ + (*msg_pos)++; + if (*msg_pos >= *msg_count) { + *msg_pos = 0; + *msg_count = 0; + } + } return read_sz; } @@ -60904,7 +60951,7 @@ static int test_wolfSSL_dtls_stateless2(void) ExpectFalse(wolfSSL_is_stateful(ssl_s)); ExpectIntNE(test_ctx.c_len, 0); /* consume HRR */ - test_ctx.c_len = 0; + test_memio_clear_buffer(&test_ctx, 1); /* send CH1 */ ExpectIntEQ(wolfSSL_connect(ssl_c), WOLFSSL_FATAL_ERROR); ExpectIntEQ(wolfSSL_get_error(ssl_c, WOLFSSL_FATAL_ERROR), @@ -60962,7 +61009,7 @@ static int test_wolfSSL_dtls_stateless_maxfrag(void) ExpectIntNE(test_ctx.c_len, 0); /* consume HRR from buffer */ - test_ctx.c_len = 0; + test_memio_clear_buffer(&test_ctx, 1); ExpectIntEQ(test_memio_do_handshake(ssl_c, ssl_s, 10, NULL), 0); wolfSSL_free(ssl_c2); @@ -61017,7 +61064,8 @@ static int _test_wolfSSL_dtls_stateless_resume(byte useticket, byte bad) wolfSSL_free(ssl_s); ssl_s = NULL; - test_ctx.c_len = test_ctx.s_len = 0; + test_memio_clear_buffer(&test_ctx, 1); + test_memio_clear_buffer(&test_ctx, 0); /* make resumption invalid */ if (bad && (sess != NULL)) { if (useticket) { @@ -61100,7 +61148,7 @@ static int test_wolfSSL_dtls_stateless_downgrade(void) (ssl_s->error == WC_NO_ERR_TRACE(WANT_READ))); ExpectIntNE(test_ctx.c_len, 0); /* consume HRR */ - test_ctx.c_len = 0; + test_memio_clear_buffer(&test_ctx, 1); ExpectIntEQ(test_memio_do_handshake(ssl_c, ssl_s, 10, NULL), 0); wolfSSL_free(ssl_c2); @@ -61850,11 +61898,11 @@ static int test_extra_alerts_wrong_cs(void) WOLFSSL_ERROR_WANT_READ); /* consume CH */ - test_ctx.s_len = 0; + test_memio_clear_buffer(&test_ctx, 0); /* inject SH */ - XMEMCPY(test_ctx.c_buff, test_extra_alerts_wrong_cs_sh, - sizeof(test_extra_alerts_wrong_cs_sh)); - test_ctx.c_len = sizeof(test_extra_alerts_wrong_cs_sh); + ExpectIntEQ(test_memio_inject_message(&test_ctx, 1, + (const char *)test_extra_alerts_wrong_cs_sh, + sizeof(test_extra_alerts_wrong_cs_sh)), 0); ExpectIntNE(wolfSSL_connect(ssl_c), WOLFSSL_SUCCESS); ExpectIntNE(wolfSSL_get_error(ssl_c, WC_NO_ERR_TRACE(WOLFSSL_FATAL_ERROR)), @@ -61914,11 +61962,11 @@ static int test_wrong_cs_downgrade(void) WOLFSSL_ERROR_WANT_READ); /* consume CH */ - test_ctx.s_len = 0; + test_memio_clear_buffer(&test_ctx, 0); /* inject SH */ - XMEMCPY(test_ctx.c_buff, test_wrong_cs_downgrade_sh, - sizeof(test_wrong_cs_downgrade_sh)); - test_ctx.c_len = sizeof(test_wrong_cs_downgrade_sh); + ExpectIntEQ(test_memio_inject_message(&test_ctx, 1, + (const char *)test_wrong_cs_downgrade_sh, + sizeof(test_wrong_cs_downgrade_sh)), 0); ExpectIntNE(wolfSSL_connect(ssl_c), WOLFSSL_SUCCESS); #ifdef OPENSSL_EXTRA @@ -61945,14 +61993,7 @@ static int test_wrong_cs_downgrade(void) #if !defined(WOLFSSL_NO_TLS12) && defined(WOLFSSL_EXTRA_ALERTS) && \ defined(HAVE_MANUAL_MEMIO_TESTS_DEPENDENCIES) && !defined(WOLFSSL_SP_MATH) -static void test_remove_msg(byte *msg, int tail_len, int *len, int msg_length) -{ - tail_len -= msg_length; - XMEMMOVE(msg, msg + msg_length, tail_len); - *len = *len - msg_length; -} - -static int test_remove_hs_msg_from_buffer(byte *buf, int *len, byte type, +static int test_remove_hs_msg_from_buffer(struct test_memio_ctx *test_ctx, byte type, byte *found) { const unsigned int _HANDSHAKE_HEADER_SZ = 4; @@ -61961,16 +62002,17 @@ static int test_remove_hs_msg_from_buffer(byte *buf, int *len, byte type, const int _change_cipher = 20; const int _handshake = 22; unsigned int tail_len; - byte *idx, *curr; + byte *idx; + int curr; word8 currType; word16 rLength; word32 hLength; - idx = buf; - tail_len = (unsigned int)*len; + idx = test_ctx->c_buff; + tail_len = (unsigned int)test_ctx->c_len; *found = 0; while (tail_len > _RECORD_HEADER_SZ) { - curr = idx; + curr = (int)(idx - test_ctx->c_buff); currType = *idx; ato16(idx + 3, &rLength); idx += _RECORD_HEADER_SZ; @@ -61983,8 +62025,8 @@ static int test_remove_hs_msg_from_buffer(byte *buf, int *len, byte type, if (rLength != 1) return -1; /* match */ - test_remove_msg(curr, *len - (int)(curr - buf), - len, _RECORD_HEADER_SZ + 1); + test_memio_remove_from_buffer(test_ctx, 1, curr, + _RECORD_HEADER_SZ + rLength); *found = 1; return 0; } @@ -62009,7 +62051,7 @@ static int test_remove_hs_msg_from_buffer(byte *buf, int *len, byte type, } /* match */ - test_remove_msg(curr, *len - (int)(curr - buf), len, + test_memio_remove_from_buffer(test_ctx, 1, curr, hLength + _RECORD_HEADER_SZ); *found = 1; return 0; @@ -62052,8 +62094,7 @@ static int test_remove_hs_message(byte hs_message_type, ExpectIntEQ(wolfSSL_accept(ssl_s), WOLFSSL_SUCCESS); } - ExpectIntEQ(test_remove_hs_msg_from_buffer(test_ctx.c_buff, - &test_ctx.c_len, hs_message_type, &found), 0); + ExpectIntEQ(test_remove_hs_msg_from_buffer(&test_ctx, hs_message_type, &found), 0); if (!found) { wolfSSL_free(ssl_c); @@ -64066,8 +64107,11 @@ static int test_dtls_no_extensions(void) ExpectIntEQ(test_memio_setup(&test_ctx, NULL, &ctx_s, NULL, &ssl_s, NULL, wolfDTLS_server_method), 0); - XMEMCPY(test_ctx.s_buff, chNoExtensions, sizeof(chNoExtensions)); - test_ctx.s_len = sizeof(chNoExtensions); + test_memio_clear_buffer(&test_ctx, 0); + ExpectIntEQ( + test_memio_inject_message(&test_ctx, 1, + (const char *)chNoExtensions, sizeof(chNoExtensions)), 0); + #ifdef OPENSSL_EXTRA if (i > 0) { @@ -64106,8 +64150,9 @@ static int test_tls_alert_no_server_hello(void) ExpectIntEQ(test_memio_setup(&test_ctx, &ctx_c, NULL, &ssl_c, NULL, wolfTLSv1_2_client_method, NULL), 0); - XMEMCPY(test_ctx.c_buff, alert_msg, sizeof(alert_msg)); - test_ctx.c_len = sizeof(alert_msg); + test_memio_clear_buffer(&test_ctx, 1); + ExpectIntEQ(test_memio_inject_message(&test_ctx, 1, + (const char *)alert_msg, sizeof(alert_msg)), 0); ExpectIntEQ(wolfSSL_connect(ssl_c), -1); ExpectIntEQ(wolfSSL_get_error(ssl_c, -1), WC_NO_ERR_TRACE(FATAL_ERROR)); @@ -64176,14 +64221,15 @@ static int test_TLSX_CA_NAMES_bad_extension(void) switch (i) { case 0: - XMEMCPY(test_ctx.c_buff, shBadCaNamesExt, - sizeof(shBadCaNamesExt)); - test_ctx.c_len = sizeof(shBadCaNamesExt); + test_memio_clear_buffer(&test_ctx, 0); + ExpectIntEQ(test_memio_inject_message(&test_ctx, 1, + (const char *)shBadCaNamesExt, sizeof(shBadCaNamesExt)), 0); break; case 1: - XMEMCPY(test_ctx.c_buff, shBadCaNamesExt2, - sizeof(shBadCaNamesExt2)); - test_ctx.c_len = sizeof(shBadCaNamesExt2); + test_memio_clear_buffer(&test_ctx, 0); + ExpectIntEQ(test_memio_inject_message(&test_ctx, 1, + (const char *)shBadCaNamesExt2, + sizeof(shBadCaNamesExt2)), 0); break; } @@ -64583,14 +64629,9 @@ static int test_dtls_client_hello_timeout_downgrade(void) ExpectIntEQ(wolfSSL_negotiate(ssl_s), -1); ExpectIntEQ(wolfSSL_get_error(ssl_s, -1), WOLFSSL_ERROR_WANT_READ); /* Drop the SH */ - dtlsRH = (DtlsRecordLayerHeader*)(test_ctx.c_buff); - len = (size_t)((dtlsRH->length[0] << 8) | dtlsRH->length[1]); if (EXPECT_SUCCESS()) { - XMEMMOVE(test_ctx.c_buff, test_ctx.c_buff + - sizeof(DtlsRecordLayerHeader) + len, test_ctx.c_len - - (sizeof(DtlsRecordLayerHeader) + len)); + ExpectIntEQ(test_memio_drop_message(&test_ctx, 1, 0), 0); } - test_ctx.c_len -= sizeof(DtlsRecordLayerHeader) + len; /* Read the remainder of the flight */ ExpectIntEQ(wolfSSL_negotiate(ssl_c), -1); ExpectIntEQ(wolfSSL_get_error(ssl_c, -1), WOLFSSL_ERROR_WANT_READ); @@ -64616,14 +64657,9 @@ static int test_dtls_client_hello_timeout_downgrade(void) ExpectIntEQ(wolfSSL_negotiate(ssl_s), -1); ExpectIntEQ(wolfSSL_get_error(ssl_s, -1), WOLFSSL_ERROR_WANT_READ); /* Drop the SH */ - dtlsRH = (DtlsRecordLayerHeader*)(test_ctx.c_buff); - len = (size_t)((dtlsRH->length[0] << 8) | dtlsRH->length[1]); if (EXPECT_SUCCESS()) { - XMEMMOVE(test_ctx.c_buff, test_ctx.c_buff + - sizeof(DtlsRecordLayerHeader) + len, test_ctx.c_len - - (sizeof(DtlsRecordLayerHeader) + len)); + ExpectIntEQ(test_memio_drop_message(&test_ctx, 1, 0), 0); } - test_ctx.c_len -= sizeof(DtlsRecordLayerHeader) + len; /* Read the remainder of the flight */ ExpectIntEQ(wolfSSL_negotiate(ssl_c), -1); ExpectIntEQ(wolfSSL_get_error(ssl_c, -1), WOLFSSL_ERROR_WANT_READ); @@ -64805,11 +64841,8 @@ static int test_dtls_dropped_ccs(void) ExpectIntEQ(len, 1); ExpectIntEQ(dtlsRH->type, change_cipher_spec); if (EXPECT_SUCCESS()) { - XMEMMOVE(test_ctx.c_buff, test_ctx.c_buff + - sizeof(DtlsRecordLayerHeader) + len, test_ctx.c_len - - (sizeof(DtlsRecordLayerHeader) + len)); + ExpectIntEQ(test_memio_drop_message(&test_ctx, 1, 0), 0); } - test_ctx.c_len -= sizeof(DtlsRecordLayerHeader) + len; /* Client rtx flight */ ExpectIntEQ(wolfSSL_negotiate(ssl_c), -1); @@ -65202,6 +65235,7 @@ static int test_dtls_frag_ch(void) WOLFSSL *ssl_s = NULL; struct test_memio_ctx test_ctx; static unsigned int DUMMY_MTU = 256; + unsigned int len; unsigned char four_frag_CH[] = { 0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xda, 0x01, 0x00, 0x02, 0xdc, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, @@ -65308,8 +65342,16 @@ static int test_dtls_frag_ch(void) /* Reject fragmented first CH */ ExpectIntEQ(test_dtls_frag_ch_count_records(four_frag_CH, sizeof(four_frag_CH)), 4); - XMEMCPY(test_ctx.s_buff, four_frag_CH, sizeof(four_frag_CH)); - test_ctx.s_len = sizeof(four_frag_CH); + len = sizeof(four_frag_CH); + test_memio_clear_buffer(&test_ctx, 0); + while (len > 0 && EXPECT_SUCCESS()) { + unsigned int inj_len = len > DUMMY_MTU ? DUMMY_MTU : len; + unsigned char *idx = four_frag_CH + sizeof(four_frag_CH) - len; + ExpectIntEQ(test_memio_inject_message(&test_ctx, 0, (const char *)idx, + inj_len), 0); + len -= inj_len; + } + ExpectIntEQ(test_ctx.s_len, sizeof(four_frag_CH)); while (test_ctx.s_len > 0 && EXPECT_SUCCESS()) { int s_len = test_ctx.s_len; ExpectIntEQ(wolfSSL_negotiate(ssl_s), -1); @@ -65403,11 +65445,11 @@ static int test_dtls_empty_keyshare_with_cookie(void) XMEMSET(&sequence_number, 0, sizeof(sequence_number)); XMEMSET(&test_ctx, 0, sizeof(test_ctx)); - XMEMCPY(test_ctx.s_buff, ch_empty_keyshare_with_cookie, - sizeof(ch_empty_keyshare_with_cookie)); - test_ctx.s_len = sizeof(ch_empty_keyshare_with_cookie); ExpectIntEQ(test_memio_setup(&test_ctx, NULL, &ctx_s, NULL, &ssl_s, NULL, wolfDTLSv1_3_server_method), 0); + ExpectIntEQ(test_memio_inject_message(&test_ctx, 0, + (const char *)ch_empty_keyshare_with_cookie, + sizeof(ch_empty_keyshare_with_cookie)), 0); /* CH1 */ ExpectIntEQ(wolfSSL_negotiate(ssl_s), -1); @@ -65522,7 +65564,7 @@ static int test_dtls12_missing_finished(void) /* Server second flight with finished */ ExpectIntEQ(wolfSSL_negotiate(ssl_s), 1); /* Let's clear the output */ - test_ctx.c_len = 0; + test_memio_clear_buffer(&test_ctx, 1); /* Let's send some app data */ ExpectIntEQ(wolfSSL_write(ssl_s, test_str, sizeof(test_str)), sizeof(test_str)); @@ -65578,7 +65620,7 @@ static int test_dtls13_missing_finished_client(void) ExpectIntEQ(wolfSSL_negotiate(ssl_s), -1); ExpectIntEQ(wolfSSL_get_error(ssl_s, -1), WOLFSSL_ERROR_WANT_READ); /* Let's clear the output */ - test_ctx.c_len = 0; + test_memio_clear_buffer(&test_ctx, 1); /* Let's send some app data */ ExpectIntEQ(wolfSSL_write(ssl_s, test_str, sizeof(test_str)), sizeof(test_str)); @@ -65642,7 +65684,7 @@ static int test_dtls13_missing_finished_server(void) ExpectIntEQ(wolfSSL_negotiate(ssl_c), -1); ExpectIntEQ(wolfSSL_get_error(ssl_c, -1), WOLFSSL_ERROR_WANT_READ); /* Let's clear the output */ - test_ctx.s_len = 0; + test_memio_clear_buffer(&test_ctx, 0); /* We should signal that the handshake is done */ ExpectTrue(wolfSSL_is_init_finished(ssl_c)); /* Let's send some app data */ @@ -65997,8 +66039,10 @@ static int test_tls_multi_handshakes_one_record(void) WOLFSSL *ssl_c = NULL, *ssl_s = NULL; RecordLayerHeader* rh = NULL; byte *len ; - int newRecIdx = RECORD_HEADER_SZ; - int idx = 0; + int newRecIdx; + int idx; + byte buff[64 * 1024]; + word16 recLen; XMEMSET(&test_ctx, 0, sizeof(test_ctx)); @@ -66010,9 +66054,14 @@ static int test_tls_multi_handshakes_one_record(void) ExpectIntEQ(wolfSSL_accept(ssl_s), -1); ExpectIntEQ(wolfSSL_get_error(ssl_s, -1), WOLFSSL_ERROR_WANT_READ); + XMEMSET(buff, 0, sizeof(buff)); + rh = (RecordLayerHeader*)(test_ctx.c_buff); + len = &rh->length[0]; + ato16((const byte*)len, &recLen); + XMEMCPY(buff, test_ctx.c_buff, RECORD_HEADER_SZ + recLen); + newRecIdx = idx = RECORD_HEADER_SZ + recLen; /* Combine server handshake msgs into one record */ while (idx < test_ctx.c_len) { - word16 recLen; rh = (RecordLayerHeader*)(test_ctx.c_buff + idx); len = &rh->length[0]; @@ -66020,20 +66069,23 @@ static int test_tls_multi_handshakes_one_record(void) ato16((const byte*)len, &recLen); idx += RECORD_HEADER_SZ; - XMEMMOVE(test_ctx.c_buff + newRecIdx, test_ctx.c_buff + idx, + XMEMCPY(buff + newRecIdx, test_ctx.c_buff + idx, (size_t)recLen); newRecIdx += recLen; idx += recLen; } - rh = (RecordLayerHeader*)(test_ctx.c_buff); + rh = (RecordLayerHeader*)(buff); len = &rh->length[0]; c16toa((word16)newRecIdx - RECORD_HEADER_SZ, len); - test_ctx.c_len = newRecIdx; + test_memio_clear_buffer(&test_ctx, 1); + test_memio_inject_message(&test_ctx, 1, (const char*)buff, newRecIdx); ExpectIntEQ(wolfSSL_connect(ssl_c), -1); ExpectIntEQ(wolfSSL_get_error(ssl_c, -1), WOLFSSL_ERROR_WANT_READ); + ExpectIntEQ(test_memio_do_handshake(ssl_c, ssl_s, 10, NULL), 0); + wolfSSL_free(ssl_c); wolfSSL_free(ssl_s); wolfSSL_CTX_free(ctx_c); @@ -66837,7 +66889,7 @@ static int test_wolfSSL_inject(void) if (test_ctx.s_len > 0) { ExpectIntEQ(wolfSSL_inject(ssl_s, test_ctx.s_buff, test_ctx.s_len), 1); - test_ctx.s_len = 0; + test_memio_clear_buffer(&test_ctx, 0); } if (wolfSSL_negotiate(ssl_s) != 1) { ExpectIntEQ(wolfSSL_get_error(ssl_s, -1), @@ -66847,7 +66899,7 @@ static int test_wolfSSL_inject(void) if (test_ctx.c_len > 0) { ExpectIntEQ(wolfSSL_inject(ssl_c, test_ctx.c_buff, test_ctx.c_len), 1); - test_ctx.c_len = 0; + test_memio_clear_buffer(&test_ctx, 1); } wolfSSL_SetLoggingPrefix(NULL); } diff --git a/tests/api/test_dtls.c b/tests/api/test_dtls.c index 16a2b6856..cfa7b36fa 100644 --- a/tests/api/test_dtls.c +++ b/tests/api/test_dtls.c @@ -140,7 +140,8 @@ int test_dtls12_basic_connection_id(void) ExpectIntEQ(wolfSSL_get_error(ssl_c, -1), WOLFSSL_ERROR_WANT_READ); ExpectNull(CLIENT_CID()); if (run_params[j].drop) { - test_ctx.c_len = test_ctx.s_len = 0; + test_memio_clear_buffer(&test_ctx, 0); + test_memio_clear_buffer(&test_ctx, 1); ExpectIntEQ(wolfSSL_dtls_got_timeout(ssl_c), 1); ExpectNull(CLIENT_CID()); } @@ -156,7 +157,8 @@ int test_dtls12_basic_connection_id(void) ExpectIntEQ(wolfSSL_get_error(ssl_c, -1), WOLFSSL_ERROR_WANT_READ); ExpectNull(CLIENT_CID()); if (run_params[j].drop) { - test_ctx.c_len = test_ctx.s_len = 0; + test_memio_clear_buffer(&test_ctx, 0); + test_memio_clear_buffer(&test_ctx, 1); ExpectIntEQ(wolfSSL_dtls_got_timeout(ssl_c), 1); ExpectNull(CLIENT_CID()); } @@ -166,7 +168,8 @@ int test_dtls12_basic_connection_id(void) ExpectIntEQ(wolfSSL_get_error(ssl_s, -1), WOLFSSL_ERROR_WANT_READ); ExpectNull(SERVER_CID()); if (run_params[j].drop) { - test_ctx.c_len = test_ctx.s_len = 0; + test_memio_clear_buffer(&test_ctx, 0); + test_memio_clear_buffer(&test_ctx, 1); ExpectIntEQ(wolfSSL_dtls_got_timeout(ssl_s), 1); ExpectNull(SERVER_CID()); } @@ -176,7 +179,8 @@ int test_dtls12_basic_connection_id(void) ExpectIntEQ(wolfSSL_get_error(ssl_c, -1), WOLFSSL_ERROR_WANT_READ); ExpectNotNull(CLIENT_CID()); if (run_params[j].drop) { - test_ctx.c_len = test_ctx.s_len = 0; + test_memio_clear_buffer(&test_ctx, 0); + test_memio_clear_buffer(&test_ctx, 1); ExpectIntEQ(wolfSSL_dtls_got_timeout(ssl_c), 1); ExpectNotNull(CLIENT_CID()); } @@ -185,7 +189,8 @@ int test_dtls12_basic_connection_id(void) ExpectIntEQ(wolfSSL_negotiate(ssl_s), 1); ExpectNotNull(SERVER_CID()); if (run_params[j].drop) { - test_ctx.c_len = test_ctx.s_len = 0; + test_memio_clear_buffer(&test_ctx, 0); + test_memio_clear_buffer(&test_ctx, 1); ExpectIntEQ(wolfSSL_dtls_got_timeout(ssl_s), 1); ExpectNotNull(SERVER_CID()); } @@ -296,7 +301,8 @@ int test_dtls12_basic_connection_id(void) ExpectNotNull(SERVER_CID()); ExpectIntEQ(wolfSSL_SSL_renegotiate_pending(ssl_s), 1); if (run_params[j].drop) { - test_ctx.c_len = test_ctx.s_len = 0; + test_memio_clear_buffer(&test_ctx, 0); + test_memio_clear_buffer(&test_ctx, 1); ExpectIntEQ(wolfSSL_dtls_got_timeout(ssl_s), 1); ExpectNotNull(SERVER_CID()); } @@ -309,7 +315,8 @@ int test_dtls12_basic_connection_id(void) ExpectNotNull(CLIENT_CID()); ExpectIntEQ(wolfSSL_SSL_renegotiate_pending(ssl_c), 1); if (run_params[j].drop) { - test_ctx.c_len = test_ctx.s_len = 0; + test_memio_clear_buffer(&test_ctx, 0); + test_memio_clear_buffer(&test_ctx, 1); ExpectIntEQ(wolfSSL_dtls_got_timeout(ssl_c), 1); ExpectNotNull(CLIENT_CID()); } @@ -319,7 +326,8 @@ int test_dtls12_basic_connection_id(void) ExpectIntEQ(wolfSSL_get_error(ssl_s, -1), WOLFSSL_ERROR_WANT_READ); ExpectNotNull(SERVER_CID()); if (run_params[j].drop) { - test_ctx.c_len = test_ctx.s_len = 0; + test_memio_clear_buffer(&test_ctx, 0); + test_memio_clear_buffer(&test_ctx, 1); ExpectIntEQ(wolfSSL_dtls_got_timeout(ssl_s), 1); ExpectNotNull(SERVER_CID()); } @@ -329,7 +337,8 @@ int test_dtls12_basic_connection_id(void) ExpectIntEQ(wolfSSL_get_error(ssl_c, -1), WOLFSSL_ERROR_WANT_READ); ExpectNotNull(CLIENT_CID()); if (run_params[j].drop) { - test_ctx.c_len = test_ctx.s_len = 0; + test_memio_clear_buffer(&test_ctx, 0); + test_memio_clear_buffer(&test_ctx, 1); ExpectIntEQ(wolfSSL_dtls_got_timeout(ssl_c), 1); ExpectNotNull(CLIENT_CID()); } @@ -337,8 +346,7 @@ int test_dtls12_basic_connection_id(void) (int)XSTRLEN(params[i])), XSTRLEN(params[i])); /* Server second flight */ wolfSSL_SetLoggingPrefix("server"); - ExpectIntEQ(wolfSSL_negotiate(ssl_s), -1); - ExpectIntEQ(wolfSSL_get_error(ssl_s, -1), APP_DATA_READY); + ExpectIntEQ(wolfSSL_negotiate(ssl_s), 1); XMEMSET(readBuf, 0, sizeof(readBuf)); ExpectIntEQ(wolfSSL_read(ssl_s, readBuf, sizeof(readBuf)), XSTRLEN(params[i])); @@ -347,19 +355,11 @@ int test_dtls12_basic_connection_id(void) ExpectIntEQ(wolfSSL_write(ssl_s, params[i], (int)XSTRLEN(params[i])), XSTRLEN(params[i])); } - ExpectIntEQ(wolfSSL_negotiate(ssl_s), 1); - ExpectNotNull(SERVER_CID()); - if (run_params[j].drop) { - test_ctx.c_len = test_ctx.s_len = 0; - ExpectIntEQ(wolfSSL_dtls_got_timeout(ssl_s), 1); - ExpectNotNull(SERVER_CID()); - } /* Test loading old epoch */ /* Client complete connection */ wolfSSL_SetLoggingPrefix("client"); if (!run_params[j].drop) { - ExpectIntEQ(wolfSSL_negotiate(ssl_c), -1); - ExpectIntEQ(wolfSSL_get_error(ssl_c, -1), APP_DATA_READY); + ExpectIntEQ(wolfSSL_negotiate(ssl_c), 1); XMEMSET(readBuf, 0, sizeof(readBuf)); ExpectIntEQ(wolfSSL_read(ssl_c, readBuf, sizeof(readBuf)), XSTRLEN(params[i])); diff --git a/tests/unit.h b/tests/unit.h index 4edd20a22..67c555558 100644 --- a/tests/unit.h +++ b/tests/unit.h @@ -385,6 +385,8 @@ typedef struct test_ssl_cbf { } test_ssl_cbf; #define TEST_SSL_MEMIO_BUF_SZ (64 * 1024) +#define TEST_MEMIO_MAX_MSGS 32 + typedef struct test_ssl_memio_ctx { WOLFSSL_CTX* s_ctx; WOLFSSL_CTX* c_ctx; @@ -406,6 +408,16 @@ typedef struct test_ssl_memio_ctx { int c_len; byte s_buff[TEST_SSL_MEMIO_BUF_SZ]; int s_len; + + int c_msg_sizes[TEST_MEMIO_MAX_MSGS]; + int c_msg_count; + int c_msg_pos; + int c_msg_offset; + + int s_msg_sizes[TEST_MEMIO_MAX_MSGS]; + int s_msg_count; + int s_msg_pos; + int s_msg_offset; } test_ssl_memio_ctx; int test_ssl_memio_setup(test_ssl_memio_ctx *ctx); diff --git a/tests/utils.c b/tests/utils.c index 2b78b3a67..e0af4c923 100644 --- a/tests/utils.c +++ b/tests/utils.c @@ -43,21 +43,30 @@ int test_memio_write_cb(WOLFSSL *ssl, char *data, int sz, void *ctx) struct test_memio_ctx *test_ctx; byte *buf; int *len; + int *msg_sizes; + int *msg_count; test_ctx = (struct test_memio_ctx*)ctx; if (wolfSSL_GetSide(ssl) == WOLFSSL_SERVER_END) { buf = test_ctx->c_buff; len = &test_ctx->c_len; + msg_sizes = test_ctx->c_msg_sizes; + msg_count = &test_ctx->c_msg_count; } else { buf = test_ctx->s_buff; len = &test_ctx->s_len; + msg_sizes = test_ctx->s_msg_sizes; + msg_count = &test_ctx->s_msg_count; } if ((unsigned)(*len + sz) > TEST_MEMIO_BUF_SZ) return WOLFSSL_CBIO_ERR_WANT_WRITE; + if (*msg_count >= TEST_MEMIO_MAX_MSGS) + return WOLFSSL_CBIO_ERR_WANT_WRITE; + #ifdef WOLFSSL_DUMP_MEMIO_STREAM { char dump_file_name[64]; @@ -71,6 +80,8 @@ int test_memio_write_cb(WOLFSSL *ssl, char *data, int sz, void *ctx) } #endif XMEMCPY(buf + *len, data, (size_t)sz); + msg_sizes[*msg_count] = sz; + (*msg_count)++; *len += sz; return sz; @@ -82,27 +93,64 @@ int test_memio_read_cb(WOLFSSL *ssl, char *data, int sz, void *ctx) int read_sz; byte *buf; int *len; + int *msg_sizes; + int *msg_count; + int *msg_pos; + int is_dtls; test_ctx = (struct test_memio_ctx*)ctx; + is_dtls = wolfSSL_dtls(ssl); if (wolfSSL_GetSide(ssl) == WOLFSSL_SERVER_END) { buf = test_ctx->s_buff; len = &test_ctx->s_len; + msg_sizes = test_ctx->s_msg_sizes; + msg_count = &test_ctx->s_msg_count; + msg_pos = &test_ctx->s_msg_pos; } else { buf = test_ctx->c_buff; len = &test_ctx->c_len; + msg_sizes = test_ctx->c_msg_sizes; + msg_count = &test_ctx->c_msg_count; + msg_pos = &test_ctx->c_msg_pos; } - if (*len == 0) + if (*len == 0 || *msg_pos >= *msg_count) return WOLFSSL_CBIO_ERR_WANT_READ; - read_sz = sz < *len ? sz : *len; + /* Calculate how much we can read from current message */ + read_sz = msg_sizes[*msg_pos]; + if (read_sz > sz) + read_sz = sz; + if (read_sz > *len) { + return WOLFSSL_CBIO_ERR_GENERAL; + } + + /* Copy data from current message */ XMEMCPY(data, buf, (size_t)read_sz); - XMEMMOVE(buf, buf + read_sz,(size_t) (*len - read_sz)); - + /* remove the read data from the buffer */ + XMEMMOVE(buf, buf + read_sz, (size_t)(*len - read_sz)); *len -= read_sz; + msg_sizes[*msg_pos] -= read_sz; + + /* if we are on dtls, discard the rest of the message */ + if (is_dtls && msg_sizes[*msg_pos] > 0) { + XMEMMOVE(buf, buf + msg_sizes[*msg_pos], (size_t)(*len - msg_sizes[*msg_pos])); + *len -= msg_sizes[*msg_pos]; + msg_sizes[*msg_pos] = 0; + } + + /* If we've read the entire message */ + if (msg_sizes[*msg_pos] == 0) { + /* Move to next message */ + (*msg_pos)++; + if (*msg_pos >= *msg_count) { + *msg_pos = 0; + *msg_count = 0; + } + } return read_sz; } @@ -251,6 +299,190 @@ int test_memio_setup_ex(struct test_memio_ctx *ctx, return 0; } +void test_memio_clear_buffer(struct test_memio_ctx *ctx, int is_client) +{ + if (is_client) { + ctx->c_len = 0; + ctx->c_msg_pos = 0; + ctx->c_msg_count = 0; + } else { + ctx->s_len = 0; + ctx->s_msg_pos = 0; + ctx->s_msg_count = 0; + } +} + +int test_memio_inject_message(struct test_memio_ctx* ctx, int client, + const char* data, int sz) +{ + int* len; + int* msg_count; + int* msg_sizes; + byte* buff; + + if (client) { + buff = ctx->c_buff; + len = &ctx->c_len; + msg_count = &ctx->c_msg_count; + msg_sizes = ctx->c_msg_sizes; + } + else { + buff = ctx->s_buff; + len = &ctx->s_len; + msg_count = &ctx->s_msg_count; + msg_sizes = ctx->s_msg_sizes; + } + if (*len + sz > TEST_MEMIO_BUF_SZ) { + return -1; + } + if (*msg_count >= TEST_MEMIO_MAX_MSGS) { + return -1; + } + XMEMCPY(buff + *len, data, (size_t)sz); + msg_sizes[*msg_count] = sz; + (*msg_count)++; + *len += sz; + return 0; +} + +int test_memio_drop_message(struct test_memio_ctx *ctx, int client, int msg_pos) +{ + int *len; + int *msg_count; + int *msg_sizes; + int msg_off, msg_sz; + int i; + byte *buff; + if (client) { + buff = ctx->c_buff; + len = &ctx->c_len; + msg_count = &ctx->c_msg_count; + msg_sizes = ctx->c_msg_sizes; + } else { + buff = ctx->s_buff; + len = &ctx->s_len; + msg_count = &ctx->s_msg_count; + msg_sizes = ctx->s_msg_sizes; + } + if (*msg_count == 0) { + return -1; + } + msg_off = 0; + if (msg_pos >= *msg_count) { + return -1; + } + msg_sz = msg_sizes[msg_pos]; + for (i = 0; i < msg_pos; i++) { + msg_off += msg_sizes[i]; + } + XMEMMOVE(buff + msg_off, buff + msg_off + msg_sz, *len - msg_off - msg_sz); + for (i = msg_pos; i < *msg_count - 1; i++) { + msg_sizes[i] = msg_sizes[i + 1]; + } + *len -= msg_sz; + (*msg_count)--; + return 0; +} + +int test_memio_remove_from_buffer(struct test_memio_ctx* ctx, int client, + int off, int sz) +{ + int* len; + int* msg_count; + int* msg_sizes; + int msg_off; + int i; + byte* buff; + + if (client) { + buff = ctx->c_buff; + len = &ctx->c_len; + msg_count = &ctx->c_msg_count; + msg_sizes = ctx->c_msg_sizes; + } + else { + buff = ctx->s_buff; + len = &ctx->s_len; + msg_count = &ctx->s_msg_count; + msg_sizes = ctx->s_msg_sizes; + } + if (*len == 0) { + return -1; + } + if (off >= *len) { + return -1; + } + if (off + sz > *len) { + return -1; + } + /* find which message the offset is in */ + msg_off = 0; + for (i = 0; i < *msg_count; i++) { + if (off >= msg_off && off < msg_off + msg_sizes[i]) { + break; + } + msg_off += msg_sizes[i]; + } + /* don't support records that are split across messages */ + if (off + sz > msg_off + msg_sizes[i]) { + return -1; + } + if (i == *msg_count) { + return -1; + } + if (sz == msg_sizes[i]) { + return test_memio_drop_message(ctx, client, i); + } + XMEMMOVE(buff + off, buff + off + sz, *len - off - sz); + msg_sizes[i] -= sz; + *len -= sz; + return 0; +} + +int test_memio_modify_message_len(struct test_memio_ctx* ctx, int client, + int msg_pos, int new_len) +{ + int* len; + int* msg_count; + int* msg_sizes; + int msg_off, msg_sz; + int i; + byte* buff; + if (client) { + buff = ctx->c_buff; + len = &ctx->c_len; + msg_count = &ctx->c_msg_count; + msg_sizes = ctx->c_msg_sizes; + } + else { + buff = ctx->s_buff; + len = &ctx->s_len; + msg_count = &ctx->s_msg_count; + msg_sizes = ctx->s_msg_sizes; + } + if (*msg_count == 0) { + return -1; + } + if (msg_pos >= *msg_count) { + return -1; + } + msg_off = 0; + for (i = 0; i < msg_pos; i++) { + msg_off += msg_sizes[i]; + } + msg_sz = msg_sizes[msg_pos]; + if (new_len > msg_sz) { + if (*len + (new_len - msg_sz) > TEST_MEMIO_BUF_SZ) { + return -1; + } + } + XMEMMOVE(buff + msg_off + new_len, buff + msg_off + msg_sz, + *len - msg_off - msg_sz); + msg_sizes[msg_pos] = new_len; + *len = *len - msg_sz + new_len; + return 0; +} + int test_memio_setup(struct test_memio_ctx *ctx, WOLFSSL_CTX **ctx_c, WOLFSSL_CTX **ctx_s, WOLFSSL **ssl_c, WOLFSSL **ssl_s, method_provider method_c, method_provider method_s) @@ -260,4 +492,3 @@ int test_memio_setup(struct test_memio_ctx *ctx, } #endif /* HAVE_MANUAL_MEMIO_TESTS_DEPENDENCIES */ - diff --git a/tests/utils.h b/tests/utils.h index 75bae2cb0..ce410f86f 100644 --- a/tests/utils.h +++ b/tests/utils.h @@ -32,6 +32,8 @@ (!defined(WOLFSSL_NO_TLS12) || defined(WOLFSSL_TLS13)) #define HAVE_MANUAL_MEMIO_TESTS_DEPENDENCIES #define TEST_MEMIO_BUF_SZ (64 * 1024) +#define TEST_MEMIO_MAX_MSGS 32 + struct test_memio_ctx { byte c_buff[TEST_MEMIO_BUF_SZ]; @@ -40,6 +42,14 @@ struct test_memio_ctx byte s_buff[TEST_MEMIO_BUF_SZ]; int s_len; const char* s_ciphers; + + int c_msg_sizes[TEST_MEMIO_MAX_MSGS]; + int c_msg_count; + int c_msg_pos; + + int s_msg_sizes[TEST_MEMIO_MAX_MSGS]; + int s_msg_count; + int s_msg_pos; }; int test_memio_write_cb(WOLFSSL *ssl, char *data, int sz, void *ctx); int test_memio_read_cb(WOLFSSL *ssl, char *data, int sz, void *ctx); @@ -53,6 +63,11 @@ int test_memio_setup_ex(struct test_memio_ctx *ctx, method_provider method_c, method_provider method_s, byte *caCert, int caCertSz, byte *serverCert, int serverCertSz, byte *serverKey, int serverKeySz); +void test_memio_clear_buffer(struct test_memio_ctx *ctx, int is_client); +int test_memio_inject_message(struct test_memio_ctx *ctx, int client, const char *data, int sz); +int test_memio_drop_message(struct test_memio_ctx *ctx, int client, int msg_pos); +int test_memio_modify_message_len(struct test_memio_ctx *ctx, int client, int msg_pos, int new_len); +int test_memio_remove_from_buffer(struct test_memio_ctx *ctx, int client, int off, int sz); #endif #endif /* TESTS_UTILS_H */