diff --git a/components/tcp_transport/host_test/main/test_websocket_transport.cpp b/components/tcp_transport/host_test/main/test_websocket_transport.cpp index 15599e626d..a13782ee97 100644 --- a/components/tcp_transport/host_test/main/test_websocket_transport.cpp +++ b/components/tcp_transport/host_test/main/test_websocket_transport.cpp @@ -113,9 +113,6 @@ int mock_poll_read_callback(esp_transport_handle_t t, int timeout_ms, int num_ca int mock_valid_read_callback(esp_transport_handle_t transport, char *buffer, int len, int timeout_ms, int num_call) { - if (num_call) { - return 0; - } std::string websocket_response = make_response(); std::memcpy(buffer, websocket_response.data(), websocket_response.size()); return websocket_response.size(); @@ -160,6 +157,21 @@ TEST_CASE("WebSocket Transport Connection", "[success]") unique_transport websocket_transport{esp_transport_ws_init(parent_handle.get()), esp_transport_destroy}; REQUIRE(websocket_transport); + // Allocate buffer for response header + constexpr size_t response_header_len = 1024; + std::vector response_header_buffer(response_header_len); + esp_transport_ws_config_t ws_config = { + .ws_path = "/", + .sub_protocol = nullptr, + .user_agent = nullptr, + .headers = nullptr, + .auth = nullptr, + .response_headers = response_header_buffer.data(), + .response_headers_len = response_header_len, + .propagate_control_frames = false + }; + REQUIRE(esp_transport_ws_set_config(websocket_transport.get(), &ws_config) == ESP_OK); + fmt::print("Attempting to connect to WebSocket\n"); esp_crypto_sha1_ExpectAnyArgsAndReturn(0); esp_crypto_base64_encode_ExpectAnyArgsAndReturn(0); @@ -176,6 +188,11 @@ TEST_CASE("WebSocket Transport Connection", "[success]") mock_destroy_ExpectAnyArgsAndReturn(ESP_OK); REQUIRE(esp_transport_connect(websocket_transport.get(), host, port, timeout) == 0); + + // Verify the response header was stored correctly + std::string expected_header = make_response(); + REQUIRE(std::string(response_header_buffer.data()) == expected_header); + char buffer[WS_BUFFER_SIZE]; int read_len = 0; read_len = esp_transport_read(websocket_transport.get(), &buffer[read_len], WS_BUFFER_SIZE - read_len, timeout); @@ -196,6 +213,14 @@ TEST_CASE("WebSocket Transport Connection", "[success]") REQUIRE(esp_transport_connect(websocket_transport.get(), host, port, timeout) == 0); + // Verify the response header was stored correctly + std::string expected_header = "HTTP/1.1 101 Switching Protocols\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Accept:\r\n" + "\r\n"; + REQUIRE(std::string(response_header_buffer.data()) == expected_header); + char buffer[WS_BUFFER_SIZE]; int read_len = 0; int partial_read; @@ -208,6 +233,25 @@ TEST_CASE("WebSocket Transport Connection", "[success]") std::string response(buffer, read_len); REQUIRE(response == "Test"); } + + SECTION("Happy flow with smaller response header") { + // Set the response header length to 10 + ws_config.response_headers_len = 10; + REQUIRE(esp_transport_ws_set_config(websocket_transport.get(), &ws_config) == ESP_OK); + + // Set the callback function for mock_read + mock_read_Stub(mock_valid_read_callback); + mock_poll_read_Stub(mock_poll_read_callback); + esp_crypto_base64_encode_ExpectAnyArgsAndReturn(0); + mock_destroy_ExpectAnyArgsAndReturn(ESP_OK); + + REQUIRE(esp_transport_connect(websocket_transport.get(), host, port, timeout) == 0); + + // Verify the response header was stored correctly. it must contain only ten bytes and be null terminated + std::string expected_header = "HTTP/1.1 1\0"; + + REQUIRE(std::string(response_header_buffer.data()) == expected_header); + } } TEST_CASE("WebSocket Transport Connection", "[failure]") @@ -225,6 +269,21 @@ TEST_CASE("WebSocket Transport Connection", "[failure]") unique_transport websocket_transport{esp_transport_ws_init(parent_handle.get()), esp_transport_destroy}; REQUIRE(websocket_transport); + // Allocate buffer for response header + constexpr size_t response_header_len = 1024; + std::vector response_header_buffer(response_header_len); + esp_transport_ws_config_t ws_config = { + .ws_path = "/", + .sub_protocol = nullptr, + .user_agent = nullptr, + .headers = nullptr, + .auth = nullptr, + .response_headers = response_header_buffer.data(), + .response_headers_len = response_header_len, + .propagate_control_frames = false + }; + REQUIRE(esp_transport_ws_set_config(websocket_transport.get(), &ws_config) == ESP_OK); + fmt::print("Attempting to connect to WebSocket\n"); esp_crypto_sha1_ExpectAnyArgsAndReturn(0); esp_crypto_base64_encode_ExpectAnyArgsAndReturn(0); @@ -244,6 +303,9 @@ TEST_CASE("WebSocket Transport Connection", "[failure]") // check that the connect() function fails REQUIRE(esp_transport_connect(websocket_transport.get(), host, port, timeout) != 0); + + // Verify the response header is empty + REQUIRE(std::string(response_header_buffer.data()) == ""); } SECTION("ws connect fails (invalid response)") { @@ -259,6 +321,9 @@ TEST_CASE("WebSocket Transport Connection", "[failure]") // check that the connect() function fails REQUIRE(esp_transport_connect(websocket_transport.get(), host, port, timeout) != 0); + + // Verify the response header is empty + REQUIRE(std::string(response_header_buffer.data()) == ""); } SECTION("ws connect fails (big response)") { @@ -272,46 +337,8 @@ TEST_CASE("WebSocket Transport Connection", "[failure]") // check that the connect() function fails REQUIRE(esp_transport_connect(websocket_transport.get(), host, port, timeout) != 0); - } - SECTION("ws connect receives redirection response") { - // Set the callback function for mock_read - mock_read_Stub( [](esp_transport_handle_t h, char *buf, int len, int tout, int n) { - char response[WS_BUFFER_SIZE]; - int response_length = snprintf(response, WS_BUFFER_SIZE, - "HTTP/1.1 301 Moved Permanently\r\n" - "Location: ws://newhost:8080\r\n" - "\r\n"); - std::memcpy(buf, response, response_length); - return response_length; - }); - mock_poll_read_Stub(mock_poll_read_callback); - esp_crypto_base64_encode_ExpectAnyArgsAndReturn(0); - mock_destroy_ExpectAnyArgsAndReturn(ESP_OK); - - // check that the connect() function returns redir status - REQUIRE(esp_transport_connect(websocket_transport.get(), host, port, timeout) == 301); - // Assert the expected HTTP status code - REQUIRE((esp_transport_ws_get_upgrade_request_status(websocket_transport.get())) == 301); - } - - SECTION("ws connect receives redirection response without location uri") { - // Set the callback function for mock_read - mock_read_Stub( [](esp_transport_handle_t h, char *buf, int len, int tout, int n) { - char response[WS_BUFFER_SIZE]; - int response_length = snprintf(response, WS_BUFFER_SIZE, - "HTTP/1.1 301 Moved Permanently\r\n" - "\r\n"); - std::memcpy(buf, response, response_length); - return response_length; - }); - mock_poll_read_Stub(mock_poll_read_callback); - esp_crypto_base64_encode_ExpectAnyArgsAndReturn(0); - mock_destroy_ExpectAnyArgsAndReturn(ESP_OK); - - // check that the connect() function fails - REQUIRE(esp_transport_connect(websocket_transport.get(), host, port, timeout) == -1); - // Assert the expected HTTP status code - REQUIRE((esp_transport_ws_get_upgrade_request_status(websocket_transport.get())) == 301); + // Verify the response header is empty + REQUIRE(std::string(response_header_buffer.data()) == ""); } } diff --git a/components/tcp_transport/include/esp_transport_ws.h b/components/tcp_transport/include/esp_transport_ws.h index 053d9f3ccf..ba5eef8a11 100644 --- a/components/tcp_transport/include/esp_transport_ws.h +++ b/components/tcp_transport/include/esp_transport_ws.h @@ -14,6 +14,8 @@ extern "C" { #endif +// Features supported +#define ESP_TRANSPORT_WS_STORE_RESPONSE_HEADERS 1 typedef enum ws_transport_opcodes { WS_TRANSPORT_OPCODES_CONT = 0x00, @@ -36,6 +38,8 @@ typedef struct { const char *user_agent; /*!< WS user agent */ const char *headers; /*!< WS additional headers */ const char *auth; /*!< HTTP authorization header */ + char *response_headers; /*!< The buffer to copy the http response header */ + size_t response_headers_len; /*!< The length of the http response header */ bool propagate_control_frames; /*!< If true, control frames are passed to the reader * If false, only user frames are propagated, control frames are handled * automatically during read operations @@ -107,6 +111,19 @@ esp_err_t esp_transport_ws_set_headers(esp_transport_handle_t t, const char *hea */ esp_err_t esp_transport_ws_set_auth(esp_transport_handle_t t, const char *auth); +/** + * @brief Set the buffer to copy the http response header + * + * @param[in] t The transport handle + * @param[in] response_header The buffer to copy the http response header + * @param[in] response_header_len The length of the http response header + * + * @return + * - ESP_OK + * - ESP_FAIL + */ +esp_err_t esp_transport_ws_set_response_headers(esp_transport_handle_t t, char *response_header, size_t response_header_len); + /** * @brief Set websocket transport parameters * diff --git a/components/tcp_transport/transport_ws.c b/components/tcp_transport/transport_ws.c index 7a554771ca..255c5cac5b 100644 --- a/components/tcp_transport/transport_ws.c +++ b/components/tcp_transport/transport_ws.c @@ -71,6 +71,8 @@ typedef struct { ws_transport_frame_state_t frame_state; esp_transport_handle_t parent; char *redir_host; + char *response_header; + size_t response_header_len; } transport_ws_t; /** @@ -305,14 +307,24 @@ static int ws_connect(esp_transport_handle_t t, const char *host, int port, int } while (NULL == strstr(ws->buffer, delimiter) && header_len < WS_BUFFER_SIZE - 1); if (header_len >= WS_BUFFER_SIZE - 1) { - ESP_LOGE(TAG, "Header size exceeded buffer size"); + ESP_LOGE(TAG, "Header size exceeded buffer size (need=%d, max=%d)", header_len + 1, WS_BUFFER_SIZE); return -1; } + if(ws->response_header) { + if(ws->response_header_len < header_len) { + ESP_LOGW(TAG, "Received header length exceedes the allocated buffer size (need=%d, allocated=%d), truncating to allocated size", header_len, ws->response_header_len); + header_len = ws->response_header_len; + } + // Copy response header to the static array + strncpy(ws->response_header, ws->buffer, header_len); + ws->response_header[header_len] = '\0'; + } + char* delim_ptr = strstr(ws->buffer, delimiter); ws->http_status_code = get_http_status_code(ws->buffer); - if (ws->http_status_code == -1) { + if (ws->http_status_code == -1) { ESP_LOGE(TAG, "HTTP upgrade failed"); return -1; } else if (WS_HTTP_TEMPORARY_REDIRECT(ws->http_status_code) || WS_HTTP_PERMANENT_REDIRECT(ws->http_status_code)) { @@ -605,7 +617,7 @@ static int ws_handle_control_frame_internal(esp_transport_handle_t t, int timeou if (payload_len > WS_TRANSPORT_MAX_CONTROL_FRAME_BUFFER_LEN) { ESP_LOGE(TAG, "Not enough room for reading control frames (need=%d, max_allowed=%d)", - ws->frame_state.payload_len, WS_TRANSPORT_MAX_CONTROL_FRAME_BUFFER_LEN); + ws->frame_state.payload_len, WS_TRANSPORT_MAX_CONTROL_FRAME_BUFFER_LEN); return -1; } @@ -625,7 +637,7 @@ static int ws_handle_control_frame_internal(esp_transport_handle_t t, int timeou int actual_len = ws_read_payload(t, control_frame_buffer, control_frame_buffer_len, timeout_ms); if (actual_len != payload_len) { ESP_LOGE(TAG, "Control frame (opcode=%d) payload read failed (payload_len=%d, read_len=%d)", - ws->frame_state.opcode, payload_len, actual_len); + ws->frame_state.opcode, payload_len, actual_len); ret = -1; goto free_payload_buffer; } @@ -751,8 +763,8 @@ static int ws_get_socket(esp_transport_handle_t t) esp_transport_handle_t esp_transport_ws_init(esp_transport_handle_t parent_handle) { if (parent_handle == NULL) { - ESP_LOGE(TAG, "Invalid parent ptotocol"); - return NULL; + ESP_LOGE(TAG, "Invalid parent ptotocol"); + return NULL; } esp_transport_handle_t t = esp_transport_init(); if (t == NULL) { @@ -870,6 +882,28 @@ esp_err_t esp_transport_ws_set_auth(esp_transport_handle_t t, const char *auth) return ESP_OK; } +esp_err_t esp_transport_ws_set_response_headers(esp_transport_handle_t t, char *response_header, size_t response_header_len) +{ + if (t == NULL) { + return ESP_ERR_INVALID_ARG; + } + + if (response_header != NULL && response_header_len == 0) { + ESP_LOGE(TAG, "Invalid response header length"); + return ESP_ERR_INVALID_ARG; + } + + transport_ws_t *ws = esp_transport_get_context_data(t); + + if (ws == NULL) { + return ESP_ERR_INVALID_ARG; + } + + ws->response_header = response_header; + ws->response_header_len = response_header_len; + return ESP_OK; +} + esp_err_t esp_transport_ws_set_config(esp_transport_handle_t t, const esp_transport_ws_config_t *config) { if (t == NULL) { @@ -897,6 +931,11 @@ esp_err_t esp_transport_ws_set_config(esp_transport_handle_t t, const esp_transp err = esp_transport_ws_set_auth(t, config->auth); ESP_TRANSPORT_ERR_OK_CHECK(TAG, err, return err;) } + if(config->response_headers) { + err = esp_transport_ws_set_response_headers(t, config->response_headers, config->response_headers_len); + ESP_TRANSPORT_ERR_OK_CHECK(TAG, err, return err;) + } + ws->propagate_control_frames = config->propagate_control_frames; return err; @@ -904,8 +943,8 @@ esp_err_t esp_transport_ws_set_config(esp_transport_handle_t t, const esp_transp bool esp_transport_ws_get_fin_flag(esp_transport_handle_t t) { - transport_ws_t *ws = esp_transport_get_context_data(t); - return ws->frame_state.fin; +transport_ws_t *ws = esp_transport_get_context_data(t); +return ws->frame_state.fin; } int esp_transport_ws_get_upgrade_request_status(esp_transport_handle_t t) @@ -969,7 +1008,7 @@ static int esp_transport_ws_handle_control_frames(esp_transport_handle_t t, char if (ws->frame_state.opcode == WS_OPCODE_PING) { // handle PING frames internally: just send a PONG with the same payload actual_len = _ws_write(t, WS_OPCODE_PONG | WS_FIN, WS_MASK, buffer, - payload_len, timeout_ms); + payload_len, timeout_ms); if (actual_len != payload_len) { ESP_LOGE(TAG, "PONG send failed (payload_len=%d, written_len=%d)", payload_len, actual_len); return -1;