forked from espressif/esp-idf
Merge branch 'fix/ws_transport_head_fragments_v5.4' into 'release/v5.4'
fix(ws_transport): Fix reading WS header in fragments (v5.4) See merge request espressif/esp-idf!35088
This commit is contained in:
@@ -102,17 +102,52 @@ int mock_write_callback(esp_transport_handle_t transport, const char *request_se
|
||||
return len;
|
||||
}
|
||||
|
||||
// Callback function for mock_read
|
||||
// Callbacks for mocked poll_reed and read functions
|
||||
int mock_poll_read_callback(esp_transport_handle_t t, int timeout_ms, int num_call)
|
||||
{
|
||||
if (num_call) {
|
||||
return 0;
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
|
||||
int mock_valid_read_callback(esp_transport_handle_t transport, char *buffer, int len, int timeout_ms, int num_call)
|
||||
{
|
||||
if (num_call) {
|
||||
return 0;
|
||||
}
|
||||
std::string websocket_response = make_response();
|
||||
std::memcpy(buffer, websocket_response.data(), websocket_response.size());
|
||||
return websocket_response.size();
|
||||
}
|
||||
|
||||
// Callback function for mock_read
|
||||
int mock_valid_read_fragmented_callback(esp_transport_handle_t t, char *buffer, int len, int timeout_ms, int num_call)
|
||||
{
|
||||
static int offset = 0;
|
||||
std::string websocket_response = make_response();
|
||||
if (buffer == nullptr) {
|
||||
return offset == websocket_response.size() ? 0 : 1;
|
||||
}
|
||||
int read_size = 1;
|
||||
if (offset == websocket_response.size()) {
|
||||
return 0;
|
||||
}
|
||||
std::memcpy(buffer, websocket_response.data() + offset, read_size);
|
||||
offset += read_size;
|
||||
return read_size;
|
||||
}
|
||||
|
||||
void test_ws_connect(bool expect_valid_connection, CMOCK_mock_read_CALLBACK read_callback) {
|
||||
int mock_valid_poll_read_fragmented_callback(esp_transport_handle_t t, int timeout_ms, int num_call)
|
||||
{
|
||||
return mock_valid_read_fragmented_callback(t, nullptr, 0, 0, 0);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void test_ws_connect(bool expect_valid_connection,
|
||||
CMOCK_mock_read_CALLBACK read_callback,
|
||||
CMOCK_mock_poll_read_CALLBACK poll_read_callback=mock_poll_read_callback) {
|
||||
constexpr static auto timeout = 50;
|
||||
constexpr static auto port = 8080;
|
||||
constexpr static auto host = "localhost";
|
||||
@@ -128,7 +163,7 @@ void test_ws_connect(bool expect_valid_connection, CMOCK_mock_read_CALLBACK read
|
||||
|
||||
SECTION("Successful connection and read data") {
|
||||
fmt::print("Attempting to connect to WebSocket\n");
|
||||
esp_crypto_sha1_ExpectAnyArgsAndReturn(0);
|
||||
esp_crypto_sha1_ExpectAnyArgsAndReturn(0);
|
||||
esp_crypto_base64_encode_ExpectAnyArgsAndReturn(0);
|
||||
|
||||
// Set the callback function for mock_write
|
||||
@@ -136,7 +171,7 @@ void test_ws_connect(bool expect_valid_connection, CMOCK_mock_read_CALLBACK read
|
||||
mock_connect_ExpectAndReturn(parent_handle.get(), host, port, timeout, ESP_OK);
|
||||
// Set the callback function for mock_read
|
||||
mock_read_Stub(read_callback);
|
||||
mock_poll_read_ExpectAnyArgsAndReturn(1);
|
||||
mock_poll_read_Stub(poll_read_callback);
|
||||
esp_crypto_base64_encode_ExpectAnyArgsAndReturn(0);
|
||||
mock_destroy_ExpectAnyArgsAndReturn(ESP_OK);
|
||||
|
||||
@@ -150,7 +185,11 @@ void test_ws_connect(bool expect_valid_connection, CMOCK_mock_read_CALLBACK read
|
||||
REQUIRE(esp_transport_connect(websocket_transport.get(), host, port, timeout) == 0);
|
||||
|
||||
char buffer[WS_BUFFER_SIZE];
|
||||
int read_len = esp_transport_read(websocket_transport.get(), buffer, WS_BUFFER_SIZE, timeout);
|
||||
int read_len = 0;
|
||||
int partial_read;
|
||||
while ((partial_read = esp_transport_read(websocket_transport.get(), &buffer[read_len], WS_BUFFER_SIZE - read_len, timeout)) > 0 ) {
|
||||
read_len+= partial_read;
|
||||
}
|
||||
fmt::print("Read result: {}\n", read_len);
|
||||
REQUIRE(read_len > 0); // Ensure data is read
|
||||
|
||||
@@ -166,6 +205,12 @@ TEST_CASE("WebSocket Transport Connection", "[websocket_transport]")
|
||||
test_ws_connect(true, mock_valid_read_callback);
|
||||
}
|
||||
|
||||
// Happy flow with fragmented reads byte by byte
|
||||
TEST_CASE("ws connect and reads by fragments", "[websocket_transport]")
|
||||
{
|
||||
test_ws_connect(true, mock_valid_read_fragmented_callback, mock_valid_poll_read_fragmented_callback);
|
||||
}
|
||||
|
||||
// Some corner cases where we expect the ws connection to fail
|
||||
|
||||
TEST_CASE("ws connect fails (0 len response)", "[websocket_transport]")
|
||||
|
@@ -467,6 +467,34 @@ static int ws_read_payload(esp_transport_handle_t t, char *buffer, int len, int
|
||||
return rlen;
|
||||
}
|
||||
|
||||
static int esp_transport_read_exact_size(transport_ws_t *ws, char *buffer, int requested_len, int timeout_ms)
|
||||
{
|
||||
int total_read = 0;
|
||||
int len = requested_len;
|
||||
|
||||
while (len > 0) {
|
||||
int bytes_read = esp_transport_read_internal(ws, buffer, len, timeout_ms);
|
||||
|
||||
if (bytes_read < 0) {
|
||||
return bytes_read; // Return error from the underlying read
|
||||
}
|
||||
|
||||
if (bytes_read == 0) {
|
||||
// If we read 0 bytes, we return an error, since reading exact number of bytes resulted in a timeout operation
|
||||
ESP_LOGW(TAG, "Requested to read %d, actually read %d bytes", requested_len, total_read);
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Update buffer and remaining length
|
||||
buffer += bytes_read;
|
||||
len -= bytes_read;
|
||||
total_read += bytes_read;
|
||||
|
||||
ESP_LOGV(TAG, "Read fragment of %d bytes", bytes_read);
|
||||
}
|
||||
return total_read;
|
||||
}
|
||||
|
||||
|
||||
/* Read and parse the WS header, determine length of payload */
|
||||
static int ws_read_header(esp_transport_handle_t t, char *buffer, int len, int timeout_ms)
|
||||
@@ -486,7 +514,7 @@ static int ws_read_header(esp_transport_handle_t t, char *buffer, int len, int t
|
||||
// Receive and process header first (based on header size)
|
||||
int header = 2;
|
||||
int mask_len = 4;
|
||||
if ((rlen = esp_transport_read_internal(ws, data_ptr, header, timeout_ms)) <= 0) {
|
||||
if ((rlen = esp_transport_read_exact_size(ws, data_ptr, header, timeout_ms)) <= 0) {
|
||||
ESP_LOGE(TAG, "Error read data");
|
||||
return rlen;
|
||||
}
|
||||
@@ -500,7 +528,7 @@ static int ws_read_header(esp_transport_handle_t t, char *buffer, int len, int t
|
||||
ESP_LOGD(TAG, "Opcode: %d, mask: %d, len: %d", ws->frame_state.opcode, mask, payload_len);
|
||||
if (payload_len == 126) {
|
||||
// headerLen += 2;
|
||||
if ((rlen = esp_transport_read_internal(ws, data_ptr, header, timeout_ms)) <= 0) {
|
||||
if ((rlen = esp_transport_read_exact_size(ws, data_ptr, header, timeout_ms)) <= 0) {
|
||||
ESP_LOGE(TAG, "Error read data");
|
||||
return rlen;
|
||||
}
|
||||
@@ -508,7 +536,7 @@ static int ws_read_header(esp_transport_handle_t t, char *buffer, int len, int t
|
||||
} else if (payload_len == 127) {
|
||||
// headerLen += 8;
|
||||
header = 8;
|
||||
if ((rlen = esp_transport_read_internal(ws, data_ptr, header, timeout_ms)) <= 0) {
|
||||
if ((rlen = esp_transport_read_exact_size(ws, data_ptr, header, timeout_ms)) <= 0) {
|
||||
ESP_LOGE(TAG, "Error read data");
|
||||
return rlen;
|
||||
}
|
||||
@@ -523,7 +551,7 @@ static int ws_read_header(esp_transport_handle_t t, char *buffer, int len, int t
|
||||
|
||||
if (mask) {
|
||||
// Read and store mask
|
||||
if (payload_len != 0 && (rlen = esp_transport_read_internal(ws, buffer, mask_len, timeout_ms)) <= 0) {
|
||||
if (payload_len != 0 && (rlen = esp_transport_read_exact_size(ws, buffer, mask_len, timeout_ms)) <= 0) {
|
||||
ESP_LOGE(TAG, "Error read data");
|
||||
return rlen;
|
||||
}
|
||||
|
Reference in New Issue
Block a user