mirror of
https://github.com/espressif/esp-idf.git
synced 2025-07-30 18:57:19 +02:00
feat(tcp_transport): Add websocket HTTP redirect
- Add and expose URI parser from HTTP when received a 301 status
This commit is contained in:
@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
* SPDX-FileCopyrightText: 2024 Espressif Systems (Shanghai) CO LTD
|
* SPDX-FileCopyrightText: 2024-2025 Espressif Systems (Shanghai) CO LTD
|
||||||
*
|
*
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
*/
|
*/
|
||||||
@ -145,9 +145,8 @@ int mock_valid_poll_read_fragmented_callback(esp_transport_handle_t t, int timeo
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void test_ws_connect(bool expect_valid_connection,
|
TEST_CASE("WebSocket Transport Connection", "[success]")
|
||||||
CMOCK_mock_read_CALLBACK read_callback,
|
{
|
||||||
CMOCK_mock_poll_read_CALLBACK poll_read_callback=mock_poll_read_callback) {
|
|
||||||
constexpr static auto timeout = 50;
|
constexpr static auto timeout = 50;
|
||||||
constexpr static auto port = 8080;
|
constexpr static auto port = 8080;
|
||||||
constexpr static auto host = "localhost";
|
constexpr static auto host = "localhost";
|
||||||
@ -161,26 +160,39 @@ void test_ws_connect(bool expect_valid_connection,
|
|||||||
unique_transport websocket_transport{esp_transport_ws_init(parent_handle.get()), esp_transport_destroy};
|
unique_transport websocket_transport{esp_transport_ws_init(parent_handle.get()), esp_transport_destroy};
|
||||||
REQUIRE(websocket_transport);
|
REQUIRE(websocket_transport);
|
||||||
|
|
||||||
SECTION("Successful connection and read data") {
|
fmt::print("Attempting to connect to WebSocket\n");
|
||||||
fmt::print("Attempting to connect to WebSocket\n");
|
esp_crypto_sha1_ExpectAnyArgsAndReturn(0);
|
||||||
esp_crypto_sha1_ExpectAnyArgsAndReturn(0);
|
esp_crypto_base64_encode_ExpectAnyArgsAndReturn(0);
|
||||||
esp_crypto_base64_encode_ExpectAnyArgsAndReturn(0);
|
|
||||||
|
|
||||||
// Set the callback function for mock_write
|
// Set the callback function for mock_write
|
||||||
mock_write_Stub(mock_write_callback);
|
mock_write_Stub(mock_write_callback);
|
||||||
mock_connect_ExpectAndReturn(parent_handle.get(), host, port, timeout, ESP_OK);
|
mock_connect_ExpectAndReturn(parent_handle.get(), host, port, timeout, ESP_OK);
|
||||||
|
|
||||||
|
SECTION("Happy flow") {
|
||||||
// Set the callback function for mock_read
|
// Set the callback function for mock_read
|
||||||
mock_read_Stub(read_callback);
|
mock_read_Stub(mock_valid_read_callback);
|
||||||
mock_poll_read_Stub(poll_read_callback);
|
mock_poll_read_Stub(mock_poll_read_callback);
|
||||||
esp_crypto_base64_encode_ExpectAnyArgsAndReturn(0);
|
esp_crypto_base64_encode_ExpectAnyArgsAndReturn(0);
|
||||||
mock_destroy_ExpectAnyArgsAndReturn(ESP_OK);
|
mock_destroy_ExpectAnyArgsAndReturn(ESP_OK);
|
||||||
|
|
||||||
if (!expect_valid_connection) {
|
REQUIRE(esp_transport_connect(websocket_transport.get(), host, port, timeout) == 0);
|
||||||
// for invalid connections we only check that the connect() function fails
|
char buffer[WS_BUFFER_SIZE];
|
||||||
REQUIRE(esp_transport_connect(websocket_transport.get(), host, port, timeout) != 0);
|
int read_len = 0;
|
||||||
// and we're done here
|
read_len = esp_transport_read(websocket_transport.get(), &buffer[read_len], WS_BUFFER_SIZE - read_len, timeout);
|
||||||
return;
|
|
||||||
}
|
fmt::print("Read result: {}\n", read_len);
|
||||||
|
REQUIRE(read_len > 0); // Ensure data is read
|
||||||
|
|
||||||
|
std::string response(buffer, read_len);
|
||||||
|
REQUIRE(response == "Test");
|
||||||
|
}
|
||||||
|
|
||||||
|
SECTION("Happy flow with fragmented reads byte by byte") {
|
||||||
|
// Set the callback function for mock_read
|
||||||
|
mock_read_Stub(mock_valid_read_fragmented_callback);
|
||||||
|
mock_poll_read_Stub(mock_valid_poll_read_fragmented_callback);
|
||||||
|
esp_crypto_base64_encode_ExpectAnyArgsAndReturn(0);
|
||||||
|
mock_destroy_ExpectAnyArgsAndReturn(ESP_OK);
|
||||||
|
|
||||||
REQUIRE(esp_transport_connect(websocket_transport.get(), host, port, timeout) == 0);
|
REQUIRE(esp_transport_connect(websocket_transport.get(), host, port, timeout) == 0);
|
||||||
|
|
||||||
@ -195,43 +207,111 @@ void test_ws_connect(bool expect_valid_connection,
|
|||||||
|
|
||||||
std::string response(buffer, read_len);
|
std::string response(buffer, read_len);
|
||||||
REQUIRE(response == "Test");
|
REQUIRE(response == "Test");
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Happy flow
|
TEST_CASE("WebSocket Transport Connection", "[failure]")
|
||||||
TEST_CASE("WebSocket Transport Connection", "[websocket_transport]")
|
|
||||||
{
|
{
|
||||||
test_ws_connect(true, mock_valid_read_callback);
|
constexpr static auto timeout = 50;
|
||||||
}
|
constexpr static auto port = 8080;
|
||||||
|
constexpr static auto host = "localhost";
|
||||||
|
// Initialize the parent handle
|
||||||
|
unique_transport parent_handle{esp_transport_init(), esp_transport_destroy};
|
||||||
|
REQUIRE(parent_handle);
|
||||||
|
|
||||||
// Happy flow with fragmented reads byte by byte
|
// Set mock functions for parent handle
|
||||||
TEST_CASE("ws connect and reads by fragments", "[websocket_transport]")
|
esp_transport_set_func(parent_handle.get(), mock_connect, mock_read, mock_write, mock_close, mock_poll_read, mock_poll_write, mock_destroy);
|
||||||
{
|
|
||||||
test_ws_connect(true, mock_valid_read_fragmented_callback, mock_valid_poll_read_fragmented_callback);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Some corner cases where we expect the ws connection to fail
|
unique_transport websocket_transport{esp_transport_ws_init(parent_handle.get()), esp_transport_destroy};
|
||||||
|
REQUIRE(websocket_transport);
|
||||||
|
|
||||||
TEST_CASE("ws connect fails (0 len response)", "[websocket_transport]")
|
fmt::print("Attempting to connect to WebSocket\n");
|
||||||
{
|
esp_crypto_sha1_ExpectAnyArgsAndReturn(0);
|
||||||
test_ws_connect(false, [](esp_transport_handle_t h, char *buf, int len, int tout, int n) {
|
esp_crypto_base64_encode_ExpectAnyArgsAndReturn(0);
|
||||||
return 0;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_CASE("ws connect fails (invalid response)", "[websocket_transport]")
|
// Set the callback function for mock_write
|
||||||
{
|
mock_write_Stub(mock_write_callback);
|
||||||
test_ws_connect(false, [](esp_transport_handle_t h, char *buf, int len, int tout, int n) {
|
mock_connect_ExpectAndReturn(parent_handle.get(), host, port, timeout, ESP_OK);
|
||||||
int resp_len = len/2;
|
|
||||||
std::memset(buf, '?', resp_len);
|
|
||||||
return resp_len;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_CASE("ws connect fails (big response)", "[websocket_transport]")
|
SECTION("ws connect fails (0 len response)") {
|
||||||
{
|
// Set the callback function for mock_read
|
||||||
test_ws_connect(false, [](esp_transport_handle_t h, char *buf, int len, int tout, int n) {
|
mock_read_Stub([](esp_transport_handle_t h, char *buf, int len, int tout, int n) {
|
||||||
return WS_BUFFER_SIZE;
|
return 0;
|
||||||
});
|
});
|
||||||
|
mock_poll_read_Stub(mock_poll_read_callback);
|
||||||
|
esp_crypto_base64_encode_ExpectAnyArgsAndReturn(0);
|
||||||
|
mock_destroy_ExpectAnyArgsAndReturn(ESP_OK);
|
||||||
|
|
||||||
|
// check that the connect() function fails
|
||||||
|
REQUIRE(esp_transport_connect(websocket_transport.get(), host, port, timeout) != 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
SECTION("ws connect fails (invalid response)") {
|
||||||
|
// Set the callback function for mock_read
|
||||||
|
mock_read_Stub([](esp_transport_handle_t h, char *buf, int len, int tout, int n) {
|
||||||
|
int resp_len = len / 2;
|
||||||
|
std::memset(buf, '?', resp_len);
|
||||||
|
return resp_len;
|
||||||
|
});
|
||||||
|
mock_poll_read_Stub(mock_poll_read_callback);
|
||||||
|
esp_crypto_base64_encode_ExpectAnyArgsAndReturn(0);
|
||||||
|
mock_destroy_ExpectAnyArgsAndReturn(ESP_OK);
|
||||||
|
|
||||||
|
// check that the connect() function fails
|
||||||
|
REQUIRE(esp_transport_connect(websocket_transport.get(), host, port, timeout) != 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
SECTION("ws connect fails (big response)") {
|
||||||
|
// Set the callback function for mock_read
|
||||||
|
mock_read_Stub([](esp_transport_handle_t h, char *buf, int len, int tout, int n) {
|
||||||
|
return WS_BUFFER_SIZE;
|
||||||
|
});
|
||||||
|
mock_poll_read_Stub(mock_poll_read_callback);
|
||||||
|
esp_crypto_base64_encode_ExpectAnyArgsAndReturn(0);
|
||||||
|
mock_destroy_ExpectAnyArgsAndReturn(ESP_OK);
|
||||||
|
|
||||||
|
// check that the connect() function fails
|
||||||
|
REQUIRE(esp_transport_connect(websocket_transport.get(), host, port, timeout) != 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
SECTION("ws connect receives redirection response") {
|
||||||
|
// Set the callback function for mock_read
|
||||||
|
mock_read_Stub( [](esp_transport_handle_t h, char *buf, int len, int tout, int n) {
|
||||||
|
char response[WS_BUFFER_SIZE];
|
||||||
|
int response_length = snprintf(response, WS_BUFFER_SIZE,
|
||||||
|
"HTTP/1.1 301 Moved Permanently\r\n"
|
||||||
|
"Location: ws://newhost:8080\r\n"
|
||||||
|
"\r\n");
|
||||||
|
std::memcpy(buf, response, response_length);
|
||||||
|
return response_length;
|
||||||
|
});
|
||||||
|
mock_poll_read_Stub(mock_poll_read_callback);
|
||||||
|
esp_crypto_base64_encode_ExpectAnyArgsAndReturn(0);
|
||||||
|
mock_destroy_ExpectAnyArgsAndReturn(ESP_OK);
|
||||||
|
|
||||||
|
// check that the connect() function returns redir status
|
||||||
|
REQUIRE(esp_transport_connect(websocket_transport.get(), host, port, timeout) == 301);
|
||||||
|
// Assert the expected HTTP status code
|
||||||
|
REQUIRE((esp_transport_ws_get_upgrade_request_status(websocket_transport.get())) == 301);
|
||||||
|
}
|
||||||
|
|
||||||
|
SECTION("ws connect receives redirection response without location uri") {
|
||||||
|
// Set the callback function for mock_read
|
||||||
|
mock_read_Stub( [](esp_transport_handle_t h, char *buf, int len, int tout, int n) {
|
||||||
|
char response[WS_BUFFER_SIZE];
|
||||||
|
int response_length = snprintf(response, WS_BUFFER_SIZE,
|
||||||
|
"HTTP/1.1 301 Moved Permanently\r\n"
|
||||||
|
"\r\n");
|
||||||
|
std::memcpy(buf, response, response_length);
|
||||||
|
return response_length;
|
||||||
|
});
|
||||||
|
mock_poll_read_Stub(mock_poll_read_callback);
|
||||||
|
esp_crypto_base64_encode_ExpectAnyArgsAndReturn(0);
|
||||||
|
mock_destroy_ExpectAnyArgsAndReturn(ESP_OK);
|
||||||
|
|
||||||
|
// check that the connect() function fails
|
||||||
|
REQUIRE(esp_transport_connect(websocket_transport.get(), host, port, timeout) == -1);
|
||||||
|
// Assert the expected HTTP status code
|
||||||
|
REQUIRE((esp_transport_ws_get_upgrade_request_status(websocket_transport.get())) == 301);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -14,6 +14,7 @@
|
|||||||
extern "C" {
|
extern "C" {
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|
||||||
typedef enum ws_transport_opcodes {
|
typedef enum ws_transport_opcodes {
|
||||||
WS_TRANSPORT_OPCODES_CONT = 0x00,
|
WS_TRANSPORT_OPCODES_CONT = 0x00,
|
||||||
WS_TRANSPORT_OPCODES_TEXT = 0x01,
|
WS_TRANSPORT_OPCODES_TEXT = 0x01,
|
||||||
@ -152,7 +153,7 @@ bool esp_transport_ws_get_fin_flag(esp_transport_handle_t t);
|
|||||||
/**
|
/**
|
||||||
* @brief Returns the HTTP status code of the websocket handshake
|
* @brief Returns the HTTP status code of the websocket handshake
|
||||||
*
|
*
|
||||||
* This API should be called after the connection atempt otherwise its result is meaningless
|
* This API should be called after the connection attempt otherwise its result is meaningless
|
||||||
*
|
*
|
||||||
* @param t websocket transport handle
|
* @param t websocket transport handle
|
||||||
*
|
*
|
||||||
@ -162,6 +163,17 @@ bool esp_transport_ws_get_fin_flag(esp_transport_handle_t t);
|
|||||||
*/
|
*/
|
||||||
int esp_transport_ws_get_upgrade_request_status(esp_transport_handle_t t);
|
int esp_transport_ws_get_upgrade_request_status(esp_transport_handle_t t);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Returns websocket redir host for the last connection attempt
|
||||||
|
*
|
||||||
|
* @param t websocket transport handle
|
||||||
|
*
|
||||||
|
* @return
|
||||||
|
* - URI of the redirection host
|
||||||
|
* - NULL if no redirection was attempted
|
||||||
|
*/
|
||||||
|
char* esp_transport_ws_get_redir_uri(esp_transport_handle_t t);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Returns websocket op-code for last received data
|
* @brief Returns websocket op-code for last received data
|
||||||
*
|
*
|
||||||
|
@ -37,9 +37,17 @@ static const char *TAG = "transport_ws";
|
|||||||
#define WS_SIZE16 126
|
#define WS_SIZE16 126
|
||||||
#define WS_SIZE64 127
|
#define WS_SIZE64 127
|
||||||
#define MAX_WEBSOCKET_HEADER_SIZE 16
|
#define MAX_WEBSOCKET_HEADER_SIZE 16
|
||||||
#define WS_RESPONSE_OK 101
|
|
||||||
#define WS_TRANSPORT_MAX_CONTROL_FRAME_BUFFER_LEN 125
|
#define WS_TRANSPORT_MAX_CONTROL_FRAME_BUFFER_LEN 125
|
||||||
|
|
||||||
|
// HTTP status codes for redirection as described in RFC 9110.
|
||||||
|
#define WS_HTTP_CODE_MOVED_PERMANENTLY 301
|
||||||
|
#define WS_HTTP_CODE_FOUND 302
|
||||||
|
#define WS_HTTP_CODE_SEE_OTHER 303
|
||||||
|
#define WS_HTTP_CODE_TEMPORARY_REDIRECT 307
|
||||||
|
#define WS_HTTP_CODE_PERMANENT_REDIRECT 308
|
||||||
|
// Grouped HTTP status codes for redirection types.
|
||||||
|
#define WS_HTTP_PERMANENT_REDIRECT(code) ((code == WS_HTTP_CODE_MOVED_PERMANENTLY) || (code == WS_HTTP_CODE_PERMANENT_REDIRECT))
|
||||||
|
#define WS_HTTP_TEMPORARY_REDIRECT(code) ((code == WS_HTTP_CODE_FOUND) || (code == WS_HTTP_CODE_SEE_OTHER) || (code == WS_HTTP_CODE_TEMPORARY_REDIRECT))
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
uint8_t opcode;
|
uint8_t opcode;
|
||||||
@ -62,6 +70,7 @@ typedef struct {
|
|||||||
bool propagate_control_frames;
|
bool propagate_control_frames;
|
||||||
ws_transport_frame_state_t frame_state;
|
ws_transport_frame_state_t frame_state;
|
||||||
esp_transport_handle_t parent;
|
esp_transport_handle_t parent;
|
||||||
|
char *redir_host;
|
||||||
} transport_ws_t;
|
} transport_ws_t;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -306,6 +315,13 @@ static int ws_connect(esp_transport_handle_t t, const char *host, int port, int
|
|||||||
if (ws->http_status_code == -1) {
|
if (ws->http_status_code == -1) {
|
||||||
ESP_LOGE(TAG, "HTTP upgrade failed");
|
ESP_LOGE(TAG, "HTTP upgrade failed");
|
||||||
return -1;
|
return -1;
|
||||||
|
} else if (WS_HTTP_TEMPORARY_REDIRECT(ws->http_status_code) || WS_HTTP_PERMANENT_REDIRECT(ws->http_status_code)) {
|
||||||
|
ws->redir_host = get_http_header(ws->buffer, "Location:");
|
||||||
|
if (ws->redir_host == NULL) {
|
||||||
|
ESP_LOGE(TAG, "Location header not found");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
return ws->http_status_code;
|
||||||
}
|
}
|
||||||
|
|
||||||
char *server_key = get_http_header(ws->buffer, "Sec-WebSocket-Accept:");
|
char *server_key = get_http_header(ws->buffer, "Sec-WebSocket-Accept:");
|
||||||
@ -343,6 +359,7 @@ static int ws_connect(esp_transport_handle_t t, const char *host, int port, int
|
|||||||
} else {
|
} else {
|
||||||
#ifdef CONFIG_WS_DYNAMIC_BUFFER
|
#ifdef CONFIG_WS_DYNAMIC_BUFFER
|
||||||
free(ws->buffer);
|
free(ws->buffer);
|
||||||
|
ws->redir_host = NULL;
|
||||||
ws->buffer = NULL;
|
ws->buffer = NULL;
|
||||||
#endif
|
#endif
|
||||||
ws->buffer_len = 0;
|
ws->buffer_len = 0;
|
||||||
@ -897,6 +914,22 @@ int esp_transport_ws_get_upgrade_request_status(esp_transport_handle_t t)
|
|||||||
return ws->http_status_code;
|
return ws->http_status_code;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
char* esp_transport_ws_get_redir_uri(esp_transport_handle_t t)
|
||||||
|
{
|
||||||
|
if (!t) {
|
||||||
|
ESP_LOGE(TAG, "Invalid Transport handle (null)");
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
|
||||||
|
transport_ws_t *ws = esp_transport_get_context_data(t);
|
||||||
|
if (!ws) {
|
||||||
|
ESP_LOGE(TAG, "Invalid ws context data (null)");
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
|
||||||
|
return ws->redir_host;
|
||||||
|
}
|
||||||
|
|
||||||
ws_transport_opcodes_t esp_transport_ws_get_read_opcode(esp_transport_handle_t t)
|
ws_transport_opcodes_t esp_transport_ws_get_read_opcode(esp_transport_handle_t t)
|
||||||
{
|
{
|
||||||
transport_ws_t *ws = esp_transport_get_context_data(t);
|
transport_ws_t *ws = esp_transport_get_context_data(t);
|
||||||
|
Reference in New Issue
Block a user