diff --git a/components/tcp_transport/host_test/main/test_websocket_transport.cpp b/components/tcp_transport/host_test/main/test_websocket_transport.cpp index 2d00bb143f..8ea89cf9c0 100644 --- a/components/tcp_transport/host_test/main/test_websocket_transport.cpp +++ b/components/tcp_transport/host_test/main/test_websocket_transport.cpp @@ -102,17 +102,52 @@ int mock_write_callback(esp_transport_handle_t transport, const char *request_se return len; } -// Callback function for mock_read +// Callbacks for mocked poll_reed and read functions +int mock_poll_read_callback(esp_transport_handle_t t, int timeout_ms, int num_call) +{ + if (num_call) { + return 0; + } + return 1; +} + int mock_valid_read_callback(esp_transport_handle_t transport, char *buffer, int len, int timeout_ms, int num_call) { + if (num_call) { + return 0; + } std::string websocket_response = make_response(); std::memcpy(buffer, websocket_response.data(), websocket_response.size()); return websocket_response.size(); } +// Callback function for mock_read +int mock_valid_read_fragmented_callback(esp_transport_handle_t t, char *buffer, int len, int timeout_ms, int num_call) +{ + static int offset = 0; + std::string websocket_response = make_response(); + if (buffer == nullptr) { + return offset == websocket_response.size() ? 0 : 1; + } + int read_size = 1; + if (offset == websocket_response.size()) { + return 0; + } + std::memcpy(buffer, websocket_response.data() + offset, read_size); + offset += read_size; + return read_size; } -void test_ws_connect(bool expect_valid_connection, CMOCK_mock_read_CALLBACK read_callback) { +int mock_valid_poll_read_fragmented_callback(esp_transport_handle_t t, int timeout_ms, int num_call) +{ + return mock_valid_read_fragmented_callback(t, nullptr, 0, 0, 0); +} + +} + +void test_ws_connect(bool expect_valid_connection, + CMOCK_mock_read_CALLBACK read_callback, + CMOCK_mock_poll_read_CALLBACK poll_read_callback=mock_poll_read_callback) { constexpr static auto timeout = 50; constexpr static auto port = 8080; constexpr static auto host = "localhost"; @@ -128,7 +163,7 @@ void test_ws_connect(bool expect_valid_connection, CMOCK_mock_read_CALLBACK read SECTION("Successful connection and read data") { fmt::print("Attempting to connect to WebSocket\n"); - esp_crypto_sha1_ExpectAnyArgsAndReturn(0); + esp_crypto_sha1_ExpectAnyArgsAndReturn(0); esp_crypto_base64_encode_ExpectAnyArgsAndReturn(0); // Set the callback function for mock_write @@ -136,7 +171,7 @@ void test_ws_connect(bool expect_valid_connection, CMOCK_mock_read_CALLBACK read mock_connect_ExpectAndReturn(parent_handle.get(), host, port, timeout, ESP_OK); // Set the callback function for mock_read mock_read_Stub(read_callback); - mock_poll_read_ExpectAnyArgsAndReturn(1); + mock_poll_read_Stub(poll_read_callback); esp_crypto_base64_encode_ExpectAnyArgsAndReturn(0); mock_destroy_ExpectAnyArgsAndReturn(ESP_OK); @@ -150,7 +185,11 @@ void test_ws_connect(bool expect_valid_connection, CMOCK_mock_read_CALLBACK read REQUIRE(esp_transport_connect(websocket_transport.get(), host, port, timeout) == 0); char buffer[WS_BUFFER_SIZE]; - int read_len = esp_transport_read(websocket_transport.get(), buffer, WS_BUFFER_SIZE, timeout); + int read_len = 0; + int partial_read; + while ((partial_read = esp_transport_read(websocket_transport.get(), &buffer[read_len], WS_BUFFER_SIZE - read_len, timeout)) > 0 ) { + read_len+= partial_read; + } fmt::print("Read result: {}\n", read_len); REQUIRE(read_len > 0); // Ensure data is read @@ -166,6 +205,12 @@ TEST_CASE("WebSocket Transport Connection", "[websocket_transport]") test_ws_connect(true, mock_valid_read_callback); } +// Happy flow with fragmented reads byte by byte +TEST_CASE("ws connect and reads by fragments", "[websocket_transport]") +{ + test_ws_connect(true, mock_valid_read_fragmented_callback, mock_valid_poll_read_fragmented_callback); +} + // Some corner cases where we expect the ws connection to fail TEST_CASE("ws connect fails (0 len response)", "[websocket_transport]") diff --git a/components/tcp_transport/transport_ws.c b/components/tcp_transport/transport_ws.c index 0211347fba..6367b0d315 100644 --- a/components/tcp_transport/transport_ws.c +++ b/components/tcp_transport/transport_ws.c @@ -467,6 +467,34 @@ static int ws_read_payload(esp_transport_handle_t t, char *buffer, int len, int return rlen; } +static int esp_transport_read_exact_size(transport_ws_t *ws, char *buffer, int requested_len, int timeout_ms) +{ + int total_read = 0; + int len = requested_len; + + while (len > 0) { + int bytes_read = esp_transport_read_internal(ws, buffer, len, timeout_ms); + + if (bytes_read < 0) { + return bytes_read; // Return error from the underlying read + } + + if (bytes_read == 0) { + // If we read 0 bytes, we return an error, since reading exact number of bytes resulted in a timeout operation + ESP_LOGW(TAG, "Requested to read %d, actually read %d bytes", requested_len, total_read); + return -1; + } + + // Update buffer and remaining length + buffer += bytes_read; + len -= bytes_read; + total_read += bytes_read; + + ESP_LOGV(TAG, "Read fragment of %d bytes", bytes_read); + } + return total_read; +} + /* Read and parse the WS header, determine length of payload */ static int ws_read_header(esp_transport_handle_t t, char *buffer, int len, int timeout_ms) @@ -486,7 +514,7 @@ static int ws_read_header(esp_transport_handle_t t, char *buffer, int len, int t // Receive and process header first (based on header size) int header = 2; int mask_len = 4; - if ((rlen = esp_transport_read_internal(ws, data_ptr, header, timeout_ms)) <= 0) { + if ((rlen = esp_transport_read_exact_size(ws, data_ptr, header, timeout_ms)) <= 0) { ESP_LOGE(TAG, "Error read data"); return rlen; } @@ -500,7 +528,7 @@ static int ws_read_header(esp_transport_handle_t t, char *buffer, int len, int t ESP_LOGD(TAG, "Opcode: %d, mask: %d, len: %d", ws->frame_state.opcode, mask, payload_len); if (payload_len == 126) { // headerLen += 2; - if ((rlen = esp_transport_read_internal(ws, data_ptr, header, timeout_ms)) <= 0) { + if ((rlen = esp_transport_read_exact_size(ws, data_ptr, header, timeout_ms)) <= 0) { ESP_LOGE(TAG, "Error read data"); return rlen; } @@ -508,7 +536,7 @@ static int ws_read_header(esp_transport_handle_t t, char *buffer, int len, int t } else if (payload_len == 127) { // headerLen += 8; header = 8; - if ((rlen = esp_transport_read_internal(ws, data_ptr, header, timeout_ms)) <= 0) { + if ((rlen = esp_transport_read_exact_size(ws, data_ptr, header, timeout_ms)) <= 0) { ESP_LOGE(TAG, "Error read data"); return rlen; } @@ -523,7 +551,7 @@ static int ws_read_header(esp_transport_handle_t t, char *buffer, int len, int t if (mask) { // Read and store mask - if (payload_len != 0 && (rlen = esp_transport_read_internal(ws, buffer, mask_len, timeout_ms)) <= 0) { + if (payload_len != 0 && (rlen = esp_transport_read_exact_size(ws, buffer, mask_len, timeout_ms)) <= 0) { ESP_LOGE(TAG, "Error read data"); return rlen; }