From b8a7d96e98aada4a5fb19d902a237b4828c7a8c2 Mon Sep 17 00:00:00 2001 From: Richard Allen Date: Thu, 10 Oct 2024 17:31:21 -0500 Subject: [PATCH 1/2] fix(ws_transport): Fix reading WS header bytes Correct split header bytes When the underlying transport returns header, length, or mask bytes early, again call the underlying transport. This solves the WS parser getting offset when the server sends a burst of frames where the last WS header is split across packet boundaries, so fewer than the needed bytes may be available. Merges https://github.com/espressif/esp-idf/pull/14706 --- components/tcp_transport/transport_ws.c | 36 ++++++++++++++++++++++--- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/components/tcp_transport/transport_ws.c b/components/tcp_transport/transport_ws.c index 0211347fba..1e100a6269 100644 --- a/components/tcp_transport/transport_ws.c +++ b/components/tcp_transport/transport_ws.c @@ -133,6 +133,34 @@ static int esp_transport_read_internal(transport_ws_t *ws, char *buffer, int len return to_read; } +static int esp_transport_read_exact_size(transport_ws_t *ws, char *buffer, int requested_len, int timeout_ms) +{ + int total_read = 0; + int len = requested_len; + + while (len > 0) { + int bytes_read = esp_transport_read_internal(ws, buffer, len, timeout_ms); + + if (bytes_read < 0) { + return bytes_read; // Return error from the underlying read + } + + if (bytes_read == 0) { + // If we read 0 bytes, we return an error, since reading exact number of bytes resulted in a timeout operation + ESP_LOGW(TAG, "Requested to read %d, actually read %d bytes", requested_len, total_read); + return -1; + } + + // Update buffer and remaining length + buffer += bytes_read; + len -= bytes_read; + total_read += bytes_read; + + ESP_LOGV(TAG, "Read fragment of %d bytes", bytes_read); + } + return total_read; +} + static char *trimwhitespace(char *str) { char *end; @@ -486,7 +514,7 @@ static int ws_read_header(esp_transport_handle_t t, char *buffer, int len, int t // Receive and process header first (based on header size) int header = 2; int mask_len = 4; - if ((rlen = esp_transport_read_internal(ws, data_ptr, header, timeout_ms)) <= 0) { + if ((rlen = esp_transport_read_exact_size(ws, data_ptr, header, timeout_ms)) <= 0) { ESP_LOGE(TAG, "Error read data"); return rlen; } @@ -500,7 +528,7 @@ static int ws_read_header(esp_transport_handle_t t, char *buffer, int len, int t ESP_LOGD(TAG, "Opcode: %d, mask: %d, len: %d", ws->frame_state.opcode, mask, payload_len); if (payload_len == 126) { // headerLen += 2; - if ((rlen = esp_transport_read_internal(ws, data_ptr, header, timeout_ms)) <= 0) { + if ((rlen = esp_transport_read_exact_size(ws, data_ptr, header, timeout_ms)) <= 0) { ESP_LOGE(TAG, "Error read data"); return rlen; } @@ -508,7 +536,7 @@ static int ws_read_header(esp_transport_handle_t t, char *buffer, int len, int t } else if (payload_len == 127) { // headerLen += 8; header = 8; - if ((rlen = esp_transport_read_internal(ws, data_ptr, header, timeout_ms)) <= 0) { + if ((rlen = esp_transport_read_exact_size(ws, data_ptr, header, timeout_ms)) <= 0) { ESP_LOGE(TAG, "Error read data"); return rlen; } @@ -523,7 +551,7 @@ static int ws_read_header(esp_transport_handle_t t, char *buffer, int len, int t if (mask) { // Read and store mask - if (payload_len != 0 && (rlen = esp_transport_read_internal(ws, buffer, mask_len, timeout_ms)) <= 0) { + if (payload_len != 0 && (rlen = esp_transport_read_exact_size(ws, buffer, mask_len, timeout_ms)) <= 0) { ESP_LOGE(TAG, "Error read data"); return rlen; } From 15636a2f14baf808367e28f97b4f58205fcb28f1 Mon Sep 17 00:00:00 2001 From: David Cermak Date: Tue, 29 Oct 2024 15:47:19 +0100 Subject: [PATCH 2/2] fix(ws_transport): Unit test on reading WS data byte by byte Closes https://github.com/espressif/esp-idf/issues/14704 Closes https://github.com/espressif/esp-protocols/issues/679 --- .../main/test_websocket_transport.cpp | 55 ++++++++++++++++-- components/tcp_transport/transport_ws.c | 56 +++++++++---------- 2 files changed, 78 insertions(+), 33 deletions(-) diff --git a/components/tcp_transport/host_test/main/test_websocket_transport.cpp b/components/tcp_transport/host_test/main/test_websocket_transport.cpp index 2d00bb143f..8ea89cf9c0 100644 --- a/components/tcp_transport/host_test/main/test_websocket_transport.cpp +++ b/components/tcp_transport/host_test/main/test_websocket_transport.cpp @@ -102,17 +102,52 @@ int mock_write_callback(esp_transport_handle_t transport, const char *request_se return len; } -// Callback function for mock_read +// Callbacks for mocked poll_reed and read functions +int mock_poll_read_callback(esp_transport_handle_t t, int timeout_ms, int num_call) +{ + if (num_call) { + return 0; + } + return 1; +} + int mock_valid_read_callback(esp_transport_handle_t transport, char *buffer, int len, int timeout_ms, int num_call) { + if (num_call) { + return 0; + } std::string websocket_response = make_response(); std::memcpy(buffer, websocket_response.data(), websocket_response.size()); return websocket_response.size(); } +// Callback function for mock_read +int mock_valid_read_fragmented_callback(esp_transport_handle_t t, char *buffer, int len, int timeout_ms, int num_call) +{ + static int offset = 0; + std::string websocket_response = make_response(); + if (buffer == nullptr) { + return offset == websocket_response.size() ? 0 : 1; + } + int read_size = 1; + if (offset == websocket_response.size()) { + return 0; + } + std::memcpy(buffer, websocket_response.data() + offset, read_size); + offset += read_size; + return read_size; } -void test_ws_connect(bool expect_valid_connection, CMOCK_mock_read_CALLBACK read_callback) { +int mock_valid_poll_read_fragmented_callback(esp_transport_handle_t t, int timeout_ms, int num_call) +{ + return mock_valid_read_fragmented_callback(t, nullptr, 0, 0, 0); +} + +} + +void test_ws_connect(bool expect_valid_connection, + CMOCK_mock_read_CALLBACK read_callback, + CMOCK_mock_poll_read_CALLBACK poll_read_callback=mock_poll_read_callback) { constexpr static auto timeout = 50; constexpr static auto port = 8080; constexpr static auto host = "localhost"; @@ -128,7 +163,7 @@ void test_ws_connect(bool expect_valid_connection, CMOCK_mock_read_CALLBACK read SECTION("Successful connection and read data") { fmt::print("Attempting to connect to WebSocket\n"); - esp_crypto_sha1_ExpectAnyArgsAndReturn(0); + esp_crypto_sha1_ExpectAnyArgsAndReturn(0); esp_crypto_base64_encode_ExpectAnyArgsAndReturn(0); // Set the callback function for mock_write @@ -136,7 +171,7 @@ void test_ws_connect(bool expect_valid_connection, CMOCK_mock_read_CALLBACK read mock_connect_ExpectAndReturn(parent_handle.get(), host, port, timeout, ESP_OK); // Set the callback function for mock_read mock_read_Stub(read_callback); - mock_poll_read_ExpectAnyArgsAndReturn(1); + mock_poll_read_Stub(poll_read_callback); esp_crypto_base64_encode_ExpectAnyArgsAndReturn(0); mock_destroy_ExpectAnyArgsAndReturn(ESP_OK); @@ -150,7 +185,11 @@ void test_ws_connect(bool expect_valid_connection, CMOCK_mock_read_CALLBACK read REQUIRE(esp_transport_connect(websocket_transport.get(), host, port, timeout) == 0); char buffer[WS_BUFFER_SIZE]; - int read_len = esp_transport_read(websocket_transport.get(), buffer, WS_BUFFER_SIZE, timeout); + int read_len = 0; + int partial_read; + while ((partial_read = esp_transport_read(websocket_transport.get(), &buffer[read_len], WS_BUFFER_SIZE - read_len, timeout)) > 0 ) { + read_len+= partial_read; + } fmt::print("Read result: {}\n", read_len); REQUIRE(read_len > 0); // Ensure data is read @@ -166,6 +205,12 @@ TEST_CASE("WebSocket Transport Connection", "[websocket_transport]") test_ws_connect(true, mock_valid_read_callback); } +// Happy flow with fragmented reads byte by byte +TEST_CASE("ws connect and reads by fragments", "[websocket_transport]") +{ + test_ws_connect(true, mock_valid_read_fragmented_callback, mock_valid_poll_read_fragmented_callback); +} + // Some corner cases where we expect the ws connection to fail TEST_CASE("ws connect fails (0 len response)", "[websocket_transport]") diff --git a/components/tcp_transport/transport_ws.c b/components/tcp_transport/transport_ws.c index 1e100a6269..6367b0d315 100644 --- a/components/tcp_transport/transport_ws.c +++ b/components/tcp_transport/transport_ws.c @@ -133,34 +133,6 @@ static int esp_transport_read_internal(transport_ws_t *ws, char *buffer, int len return to_read; } -static int esp_transport_read_exact_size(transport_ws_t *ws, char *buffer, int requested_len, int timeout_ms) -{ - int total_read = 0; - int len = requested_len; - - while (len > 0) { - int bytes_read = esp_transport_read_internal(ws, buffer, len, timeout_ms); - - if (bytes_read < 0) { - return bytes_read; // Return error from the underlying read - } - - if (bytes_read == 0) { - // If we read 0 bytes, we return an error, since reading exact number of bytes resulted in a timeout operation - ESP_LOGW(TAG, "Requested to read %d, actually read %d bytes", requested_len, total_read); - return -1; - } - - // Update buffer and remaining length - buffer += bytes_read; - len -= bytes_read; - total_read += bytes_read; - - ESP_LOGV(TAG, "Read fragment of %d bytes", bytes_read); - } - return total_read; -} - static char *trimwhitespace(char *str) { char *end; @@ -495,6 +467,34 @@ static int ws_read_payload(esp_transport_handle_t t, char *buffer, int len, int return rlen; } +static int esp_transport_read_exact_size(transport_ws_t *ws, char *buffer, int requested_len, int timeout_ms) +{ + int total_read = 0; + int len = requested_len; + + while (len > 0) { + int bytes_read = esp_transport_read_internal(ws, buffer, len, timeout_ms); + + if (bytes_read < 0) { + return bytes_read; // Return error from the underlying read + } + + if (bytes_read == 0) { + // If we read 0 bytes, we return an error, since reading exact number of bytes resulted in a timeout operation + ESP_LOGW(TAG, "Requested to read %d, actually read %d bytes", requested_len, total_read); + return -1; + } + + // Update buffer and remaining length + buffer += bytes_read; + len -= bytes_read; + total_read += bytes_read; + + ESP_LOGV(TAG, "Read fragment of %d bytes", bytes_read); + } + return total_read; +} + /* Read and parse the WS header, determine length of payload */ static int ws_read_header(esp_transport_handle_t t, char *buffer, int len, int timeout_ms)