diff --git a/components/tcp_transport/host_test/main/test_websocket_transport.cpp b/components/tcp_transport/host_test/main/test_websocket_transport.cpp index 8ea89cf9c0..15599e626d 100644 --- a/components/tcp_transport/host_test/main/test_websocket_transport.cpp +++ b/components/tcp_transport/host_test/main/test_websocket_transport.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: 2024 Espressif Systems (Shanghai) CO LTD + * SPDX-FileCopyrightText: 2024-2025 Espressif Systems (Shanghai) CO LTD * * SPDX-License-Identifier: Apache-2.0 */ @@ -145,9 +145,8 @@ int mock_valid_poll_read_fragmented_callback(esp_transport_handle_t t, int timeo } -void test_ws_connect(bool expect_valid_connection, - CMOCK_mock_read_CALLBACK read_callback, - CMOCK_mock_poll_read_CALLBACK poll_read_callback=mock_poll_read_callback) { +TEST_CASE("WebSocket Transport Connection", "[success]") +{ constexpr static auto timeout = 50; constexpr static auto port = 8080; constexpr static auto host = "localhost"; @@ -161,26 +160,39 @@ void test_ws_connect(bool expect_valid_connection, unique_transport websocket_transport{esp_transport_ws_init(parent_handle.get()), esp_transport_destroy}; REQUIRE(websocket_transport); - SECTION("Successful connection and read data") { - fmt::print("Attempting to connect to WebSocket\n"); - esp_crypto_sha1_ExpectAnyArgsAndReturn(0); - esp_crypto_base64_encode_ExpectAnyArgsAndReturn(0); + fmt::print("Attempting to connect to WebSocket\n"); + esp_crypto_sha1_ExpectAnyArgsAndReturn(0); + esp_crypto_base64_encode_ExpectAnyArgsAndReturn(0); - // Set the callback function for mock_write - mock_write_Stub(mock_write_callback); - mock_connect_ExpectAndReturn(parent_handle.get(), host, port, timeout, ESP_OK); + // Set the callback function for mock_write + mock_write_Stub(mock_write_callback); + mock_connect_ExpectAndReturn(parent_handle.get(), host, port, timeout, ESP_OK); + + SECTION("Happy flow") { // Set the callback function for mock_read - mock_read_Stub(read_callback); - mock_poll_read_Stub(poll_read_callback); + mock_read_Stub(mock_valid_read_callback); + mock_poll_read_Stub(mock_poll_read_callback); esp_crypto_base64_encode_ExpectAnyArgsAndReturn(0); mock_destroy_ExpectAnyArgsAndReturn(ESP_OK); - if (!expect_valid_connection) { - // for invalid connections we only check that the connect() function fails - REQUIRE(esp_transport_connect(websocket_transport.get(), host, port, timeout) != 0); - // and we're done here - return; - } + REQUIRE(esp_transport_connect(websocket_transport.get(), host, port, timeout) == 0); + char buffer[WS_BUFFER_SIZE]; + int read_len = 0; + read_len = esp_transport_read(websocket_transport.get(), &buffer[read_len], WS_BUFFER_SIZE - read_len, timeout); + + fmt::print("Read result: {}\n", read_len); + REQUIRE(read_len > 0); // Ensure data is read + + std::string response(buffer, read_len); + REQUIRE(response == "Test"); + } + + SECTION("Happy flow with fragmented reads byte by byte") { + // Set the callback function for mock_read + mock_read_Stub(mock_valid_read_fragmented_callback); + mock_poll_read_Stub(mock_valid_poll_read_fragmented_callback); + esp_crypto_base64_encode_ExpectAnyArgsAndReturn(0); + mock_destroy_ExpectAnyArgsAndReturn(ESP_OK); REQUIRE(esp_transport_connect(websocket_transport.get(), host, port, timeout) == 0); @@ -195,43 +207,111 @@ void test_ws_connect(bool expect_valid_connection, std::string response(buffer, read_len); REQUIRE(response == "Test"); - } } -// Happy flow -TEST_CASE("WebSocket Transport Connection", "[websocket_transport]") +TEST_CASE("WebSocket Transport Connection", "[failure]") { - test_ws_connect(true, mock_valid_read_callback); -} + constexpr static auto timeout = 50; + constexpr static auto port = 8080; + constexpr static auto host = "localhost"; + // Initialize the parent handle + unique_transport parent_handle{esp_transport_init(), esp_transport_destroy}; + REQUIRE(parent_handle); -// Happy flow with fragmented reads byte by byte -TEST_CASE("ws connect and reads by fragments", "[websocket_transport]") -{ - test_ws_connect(true, mock_valid_read_fragmented_callback, mock_valid_poll_read_fragmented_callback); -} + // Set mock functions for parent handle + esp_transport_set_func(parent_handle.get(), mock_connect, mock_read, mock_write, mock_close, mock_poll_read, mock_poll_write, mock_destroy); -// Some corner cases where we expect the ws connection to fail + unique_transport websocket_transport{esp_transport_ws_init(parent_handle.get()), esp_transport_destroy}; + REQUIRE(websocket_transport); -TEST_CASE("ws connect fails (0 len response)", "[websocket_transport]") -{ - test_ws_connect(false, [](esp_transport_handle_t h, char *buf, int len, int tout, int n) { - return 0; - }); -} + fmt::print("Attempting to connect to WebSocket\n"); + esp_crypto_sha1_ExpectAnyArgsAndReturn(0); + esp_crypto_base64_encode_ExpectAnyArgsAndReturn(0); -TEST_CASE("ws connect fails (invalid response)", "[websocket_transport]") -{ - test_ws_connect(false, [](esp_transport_handle_t h, char *buf, int len, int tout, int n) { - int resp_len = len/2; - std::memset(buf, '?', resp_len); - return resp_len; - }); -} + // Set the callback function for mock_write + mock_write_Stub(mock_write_callback); + mock_connect_ExpectAndReturn(parent_handle.get(), host, port, timeout, ESP_OK); -TEST_CASE("ws connect fails (big response)", "[websocket_transport]") -{ - test_ws_connect(false, [](esp_transport_handle_t h, char *buf, int len, int tout, int n) { - return WS_BUFFER_SIZE; - }); + SECTION("ws connect fails (0 len response)") { + // Set the callback function for mock_read + mock_read_Stub([](esp_transport_handle_t h, char *buf, int len, int tout, int n) { + return 0; + }); + mock_poll_read_Stub(mock_poll_read_callback); + esp_crypto_base64_encode_ExpectAnyArgsAndReturn(0); + mock_destroy_ExpectAnyArgsAndReturn(ESP_OK); + + // check that the connect() function fails + REQUIRE(esp_transport_connect(websocket_transport.get(), host, port, timeout) != 0); + } + + SECTION("ws connect fails (invalid response)") { + // Set the callback function for mock_read + mock_read_Stub([](esp_transport_handle_t h, char *buf, int len, int tout, int n) { + int resp_len = len / 2; + std::memset(buf, '?', resp_len); + return resp_len; + }); + mock_poll_read_Stub(mock_poll_read_callback); + esp_crypto_base64_encode_ExpectAnyArgsAndReturn(0); + mock_destroy_ExpectAnyArgsAndReturn(ESP_OK); + + // check that the connect() function fails + REQUIRE(esp_transport_connect(websocket_transport.get(), host, port, timeout) != 0); + } + + SECTION("ws connect fails (big response)") { + // Set the callback function for mock_read + mock_read_Stub([](esp_transport_handle_t h, char *buf, int len, int tout, int n) { + return WS_BUFFER_SIZE; + }); + mock_poll_read_Stub(mock_poll_read_callback); + esp_crypto_base64_encode_ExpectAnyArgsAndReturn(0); + mock_destroy_ExpectAnyArgsAndReturn(ESP_OK); + + // check that the connect() function fails + REQUIRE(esp_transport_connect(websocket_transport.get(), host, port, timeout) != 0); + } + + SECTION("ws connect receives redirection response") { + // Set the callback function for mock_read + mock_read_Stub( [](esp_transport_handle_t h, char *buf, int len, int tout, int n) { + char response[WS_BUFFER_SIZE]; + int response_length = snprintf(response, WS_BUFFER_SIZE, + "HTTP/1.1 301 Moved Permanently\r\n" + "Location: ws://newhost:8080\r\n" + "\r\n"); + std::memcpy(buf, response, response_length); + return response_length; + }); + mock_poll_read_Stub(mock_poll_read_callback); + esp_crypto_base64_encode_ExpectAnyArgsAndReturn(0); + mock_destroy_ExpectAnyArgsAndReturn(ESP_OK); + + // check that the connect() function returns redir status + REQUIRE(esp_transport_connect(websocket_transport.get(), host, port, timeout) == 301); + // Assert the expected HTTP status code + REQUIRE((esp_transport_ws_get_upgrade_request_status(websocket_transport.get())) == 301); + } + + SECTION("ws connect receives redirection response without location uri") { + // Set the callback function for mock_read + mock_read_Stub( [](esp_transport_handle_t h, char *buf, int len, int tout, int n) { + char response[WS_BUFFER_SIZE]; + int response_length = snprintf(response, WS_BUFFER_SIZE, + "HTTP/1.1 301 Moved Permanently\r\n" + "\r\n"); + std::memcpy(buf, response, response_length); + return response_length; + }); + mock_poll_read_Stub(mock_poll_read_callback); + esp_crypto_base64_encode_ExpectAnyArgsAndReturn(0); + mock_destroy_ExpectAnyArgsAndReturn(ESP_OK); + + // check that the connect() function fails + REQUIRE(esp_transport_connect(websocket_transport.get(), host, port, timeout) == -1); + // Assert the expected HTTP status code + REQUIRE((esp_transport_ws_get_upgrade_request_status(websocket_transport.get())) == 301); + } } diff --git a/components/tcp_transport/include/esp_transport_ws.h b/components/tcp_transport/include/esp_transport_ws.h index fcf2230732..053d9f3ccf 100644 --- a/components/tcp_transport/include/esp_transport_ws.h +++ b/components/tcp_transport/include/esp_transport_ws.h @@ -14,6 +14,7 @@ extern "C" { #endif + typedef enum ws_transport_opcodes { WS_TRANSPORT_OPCODES_CONT = 0x00, WS_TRANSPORT_OPCODES_TEXT = 0x01, @@ -152,7 +153,7 @@ bool esp_transport_ws_get_fin_flag(esp_transport_handle_t t); /** * @brief Returns the HTTP status code of the websocket handshake * - * This API should be called after the connection atempt otherwise its result is meaningless + * This API should be called after the connection attempt otherwise its result is meaningless * * @param t websocket transport handle * @@ -162,6 +163,17 @@ bool esp_transport_ws_get_fin_flag(esp_transport_handle_t t); */ int esp_transport_ws_get_upgrade_request_status(esp_transport_handle_t t); +/** + * @brief Returns websocket redir host for the last connection attempt + * + * @param t websocket transport handle + * + * @return + * - URI of the redirection host + * - NULL if no redirection was attempted + */ +char* esp_transport_ws_get_redir_uri(esp_transport_handle_t t); + /** * @brief Returns websocket op-code for last received data * diff --git a/components/tcp_transport/transport_ws.c b/components/tcp_transport/transport_ws.c index 4a523b91bd..7a554771ca 100644 --- a/components/tcp_transport/transport_ws.c +++ b/components/tcp_transport/transport_ws.c @@ -37,9 +37,17 @@ static const char *TAG = "transport_ws"; #define WS_SIZE16 126 #define WS_SIZE64 127 #define MAX_WEBSOCKET_HEADER_SIZE 16 -#define WS_RESPONSE_OK 101 #define WS_TRANSPORT_MAX_CONTROL_FRAME_BUFFER_LEN 125 +// HTTP status codes for redirection as described in RFC 9110. +#define WS_HTTP_CODE_MOVED_PERMANENTLY 301 +#define WS_HTTP_CODE_FOUND 302 +#define WS_HTTP_CODE_SEE_OTHER 303 +#define WS_HTTP_CODE_TEMPORARY_REDIRECT 307 +#define WS_HTTP_CODE_PERMANENT_REDIRECT 308 +// Grouped HTTP status codes for redirection types. +#define WS_HTTP_PERMANENT_REDIRECT(code) ((code == WS_HTTP_CODE_MOVED_PERMANENTLY) || (code == WS_HTTP_CODE_PERMANENT_REDIRECT)) +#define WS_HTTP_TEMPORARY_REDIRECT(code) ((code == WS_HTTP_CODE_FOUND) || (code == WS_HTTP_CODE_SEE_OTHER) || (code == WS_HTTP_CODE_TEMPORARY_REDIRECT)) typedef struct { uint8_t opcode; @@ -62,6 +70,7 @@ typedef struct { bool propagate_control_frames; ws_transport_frame_state_t frame_state; esp_transport_handle_t parent; + char *redir_host; } transport_ws_t; /** @@ -306,6 +315,13 @@ static int ws_connect(esp_transport_handle_t t, const char *host, int port, int if (ws->http_status_code == -1) { ESP_LOGE(TAG, "HTTP upgrade failed"); return -1; + } else if (WS_HTTP_TEMPORARY_REDIRECT(ws->http_status_code) || WS_HTTP_PERMANENT_REDIRECT(ws->http_status_code)) { + ws->redir_host = get_http_header(ws->buffer, "Location:"); + if (ws->redir_host == NULL) { + ESP_LOGE(TAG, "Location header not found"); + return -1; + } + return ws->http_status_code; } char *server_key = get_http_header(ws->buffer, "Sec-WebSocket-Accept:"); @@ -343,6 +359,7 @@ static int ws_connect(esp_transport_handle_t t, const char *host, int port, int } else { #ifdef CONFIG_WS_DYNAMIC_BUFFER free(ws->buffer); + ws->redir_host = NULL; ws->buffer = NULL; #endif ws->buffer_len = 0; @@ -897,6 +914,22 @@ int esp_transport_ws_get_upgrade_request_status(esp_transport_handle_t t) return ws->http_status_code; } +char* esp_transport_ws_get_redir_uri(esp_transport_handle_t t) +{ + if (!t) { + ESP_LOGE(TAG, "Invalid Transport handle (null)"); + return NULL; + } + + transport_ws_t *ws = esp_transport_get_context_data(t); + if (!ws) { + ESP_LOGE(TAG, "Invalid ws context data (null)"); + return NULL; + } + + return ws->redir_host; +} + ws_transport_opcodes_t esp_transport_ws_get_read_opcode(esp_transport_handle_t t) { transport_ws_t *ws = esp_transport_get_context_data(t);