mirror of
https://github.com/espressif/esp-protocols.git
synced 2025-07-14 19:16:32 +02:00
tcp_transport/ws_client: websockets now correctly handle messages longer than buffer
transport_ws can now be read multiple times in a row to read frames larger than the buffer. Added reporting of total payload length and offset to the user in websocket_client. Added local example test for long messages. Closes IDF-1083 * Original commit: espressif/esp-idf@ffeda3003c
This commit is contained in:
committed by
gabsuren
parent
a6be8e2e3d
commit
aec6a75d40
@ -93,6 +93,8 @@ struct esp_websocket_client {
|
||||
char *tx_buffer;
|
||||
int buffer_size;
|
||||
ws_transport_opcodes_t last_opcode;
|
||||
int payload_len;
|
||||
int payload_offset;
|
||||
};
|
||||
|
||||
static uint64_t _tick_get_ms(void)
|
||||
@ -110,10 +112,11 @@ static esp_err_t esp_websocket_client_dispatch_event(esp_websocket_client_handle
|
||||
|
||||
event_data.client = client;
|
||||
event_data.user_context = client->config->user_context;
|
||||
|
||||
event_data.data_ptr = data;
|
||||
event_data.data_len = data_len;
|
||||
event_data.op_code = client->last_opcode;
|
||||
event_data.payload_len = client->payload_len;
|
||||
event_data.payload_offset = client->payload_offset;
|
||||
|
||||
if ((err = esp_event_post_to(client->event_handle,
|
||||
WEBSOCKET_EVENTS, event,
|
||||
@ -446,10 +449,38 @@ esp_err_t esp_websocket_client_set_uri(esp_websocket_client_handle_t client, con
|
||||
return ESP_OK;
|
||||
}
|
||||
|
||||
static esp_err_t esp_websocket_client_recv(esp_websocket_client_handle_t client)
|
||||
{
|
||||
int rlen;
|
||||
client->payload_offset = 0;
|
||||
do {
|
||||
rlen = esp_transport_read(client->transport, client->rx_buffer, client->buffer_size, client->config->network_timeout_ms);
|
||||
if (rlen < 0) {
|
||||
ESP_LOGE(TAG, "Error read data");
|
||||
esp_websocket_client_abort_connection(client);
|
||||
return ESP_FAIL;
|
||||
}
|
||||
client->payload_len = esp_transport_ws_get_read_payload_len(client->transport);
|
||||
client->last_opcode = esp_transport_ws_get_read_opcode(client->transport);
|
||||
|
||||
esp_websocket_client_dispatch_event(client, WEBSOCKET_EVENT_DATA, client->rx_buffer, rlen);
|
||||
|
||||
client->payload_offset += rlen;
|
||||
} while (client->payload_offset < client->payload_len);
|
||||
|
||||
// if a PING message received -> send out the PONG, this will not work for PING messages with payload longer than buffer len
|
||||
if (client->last_opcode == WS_TRANSPORT_OPCODES_PING) {
|
||||
const char *data = (client->payload_len == 0) ? NULL : client->rx_buffer;
|
||||
esp_transport_ws_send_raw(client->transport, WS_TRANSPORT_OPCODES_PONG, data, client->payload_len,
|
||||
client->config->network_timeout_ms);
|
||||
}
|
||||
|
||||
return ESP_OK;
|
||||
}
|
||||
|
||||
static void esp_websocket_client_task(void *pv)
|
||||
{
|
||||
const int lock_timeout = portMAX_DELAY;
|
||||
int rlen;
|
||||
esp_websocket_client_handle_t client = (esp_websocket_client_handle_t) pv;
|
||||
client->run = true;
|
||||
|
||||
@ -506,22 +537,11 @@ static void esp_websocket_client_task(void *pv)
|
||||
}
|
||||
client->ping_tick_ms = _tick_get_ms();
|
||||
|
||||
rlen = esp_transport_read(client->transport, client->rx_buffer, client->buffer_size, client->config->network_timeout_ms);
|
||||
if (rlen < 0) {
|
||||
ESP_LOGE(TAG, "Error read data");
|
||||
if (esp_websocket_client_recv(client) == ESP_FAIL) {
|
||||
ESP_LOGE(TAG, "Error receive data");
|
||||
esp_websocket_client_abort_connection(client);
|
||||
break;
|
||||
}
|
||||
if (rlen >= 0) {
|
||||
client->last_opcode = esp_transport_ws_get_read_opcode(client->transport);
|
||||
esp_websocket_client_dispatch_event(client, WEBSOCKET_EVENT_DATA, client->rx_buffer, rlen);
|
||||
// if a PING message received -> send out the PONG
|
||||
if (client->last_opcode == WS_TRANSPORT_OPCODES_PING) {
|
||||
const char *data = (rlen == 0) ? NULL : client->rx_buffer;
|
||||
esp_transport_ws_send_raw(client->transport, WS_TRANSPORT_OPCODES_PONG, data, rlen,
|
||||
client->config->network_timeout_ms);
|
||||
}
|
||||
}
|
||||
break;
|
||||
case WEBSOCKET_STATE_WAIT_TIMEOUT:
|
||||
|
||||
@ -663,7 +683,8 @@ bool esp_websocket_client_is_connected(esp_websocket_client_handle_t client)
|
||||
esp_err_t esp_websocket_register_events(esp_websocket_client_handle_t client,
|
||||
esp_websocket_event_id_t event,
|
||||
esp_event_handler_t event_handler,
|
||||
void* event_handler_arg) {
|
||||
void *event_handler_arg)
|
||||
{
|
||||
if (client == NULL) {
|
||||
return ESP_ERR_INVALID_ARG;
|
||||
}
|
||||
|
@ -27,7 +27,7 @@
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
typedef struct esp_websocket_client* esp_websocket_client_handle_t;
|
||||
typedef struct esp_websocket_client *esp_websocket_client_handle_t;
|
||||
|
||||
ESP_EVENT_DECLARE_BASE(WEBSOCKET_EVENTS); // declaration of the task events family
|
||||
|
||||
@ -52,6 +52,8 @@ typedef struct {
|
||||
uint8_t op_code; /*!< Received opcode */
|
||||
esp_websocket_client_handle_t client; /*!< esp_websocket_client_handle_t context */
|
||||
void *user_context; /*!< user_data context, from esp_websocket_client_config_t user_data */
|
||||
int payload_len; /*!< Total payload length, payloads exceeding buffer will be posted through multiple events */
|
||||
int payload_offset; /*!< Actual offset for the data associated with this event */
|
||||
} esp_websocket_event_data_t;
|
||||
|
||||
/**
|
||||
@ -205,7 +207,7 @@ bool esp_websocket_client_is_connected(esp_websocket_client_handle_t client);
|
||||
esp_err_t esp_websocket_register_events(esp_websocket_client_handle_t client,
|
||||
esp_websocket_event_id_t event,
|
||||
esp_event_handler_t event_handler,
|
||||
void* event_handler_arg);
|
||||
void *event_handler_arg);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
@ -3,10 +3,13 @@ from __future__ import unicode_literals
|
||||
import re
|
||||
import os
|
||||
import socket
|
||||
import select
|
||||
import hashlib
|
||||
import base64
|
||||
from threading import Thread
|
||||
|
||||
import queue
|
||||
import random
|
||||
import string
|
||||
from threading import Thread, Event
|
||||
import ttfw_idf
|
||||
|
||||
|
||||
@ -30,7 +33,10 @@ class Websocket:
|
||||
def __init__(self, port):
|
||||
self.port = port
|
||||
self.socket = socket.socket()
|
||||
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
self.socket.settimeout(10.0)
|
||||
self.send_q = queue.Queue()
|
||||
self.shutdown = Event()
|
||||
|
||||
def __enter__(self):
|
||||
try:
|
||||
@ -43,23 +49,27 @@ class Websocket:
|
||||
self.server_thread = Thread(target=self.run_server)
|
||||
self.server_thread.start()
|
||||
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self.shutdown.set()
|
||||
self.server_thread.join()
|
||||
self.socket.close()
|
||||
self.conn.close()
|
||||
|
||||
def run_server(self):
|
||||
self.conn, address = self.socket.accept() # accept new connection
|
||||
self.conn.settimeout(10.0)
|
||||
self.socket.settimeout(10.0)
|
||||
|
||||
print("Connection from: {}".format(address))
|
||||
|
||||
self.establish_connection()
|
||||
|
||||
# Echo data until client closes connection
|
||||
self.echo_data()
|
||||
print("WS established")
|
||||
# Handle connection until client closes it, will echo any data received and send data from send_q queue
|
||||
self.handle_conn()
|
||||
|
||||
def establish_connection(self):
|
||||
while True:
|
||||
while not self.shutdown.is_set():
|
||||
try:
|
||||
# receive data stream. it won't accept data packet greater than 1024 bytes
|
||||
data = self.conn.recv(1024).decode()
|
||||
@ -70,6 +80,7 @@ class Websocket:
|
||||
if "Upgrade: websocket" in data and "Connection: Upgrade" in data:
|
||||
self.handshake(data)
|
||||
return
|
||||
|
||||
except socket.error as err:
|
||||
print("Unable to establish a websocket connection: {}, {}".format(err))
|
||||
raise
|
||||
@ -94,26 +105,46 @@ class Websocket:
|
||||
|
||||
self.conn.send(resp.encode())
|
||||
|
||||
def echo_data(self):
|
||||
while(True):
|
||||
def handle_conn(self):
|
||||
while not self.shutdown.is_set():
|
||||
r,w,e = select.select([self.conn], [], [], 1)
|
||||
try:
|
||||
if self.conn in r:
|
||||
self.echo_data()
|
||||
|
||||
if not self.send_q.empty():
|
||||
self._send_data_(self.send_q.get())
|
||||
|
||||
except socket.error as err:
|
||||
print("Stopped echoing data: {}".format(err))
|
||||
raise
|
||||
|
||||
def echo_data(self):
|
||||
header = bytearray(self.conn.recv(self.HEADER_LEN, socket.MSG_WAITALL))
|
||||
if not header:
|
||||
# exit if data is not received
|
||||
# exit if socket closed by peer
|
||||
return
|
||||
|
||||
# Remove mask bit
|
||||
payload_len = ~(1 << 7) & header[1]
|
||||
|
||||
payload = bytearray(self.conn.recv(payload_len, socket.MSG_WAITALL))
|
||||
|
||||
if not payload:
|
||||
# exit if socket closed by peer
|
||||
return
|
||||
frame = header + payload
|
||||
|
||||
decoded_payload = self.decode_frame(frame)
|
||||
print("Sending echo...")
|
||||
self._send_data_(decoded_payload)
|
||||
|
||||
echo_frame = self.encode_frame(decoded_payload)
|
||||
self.conn.send(echo_frame)
|
||||
except socket.error as err:
|
||||
print("Stopped echoing data: {}".format(err))
|
||||
def _send_data_(self, data):
|
||||
frame = self.encode_frame(data)
|
||||
self.conn.send(frame)
|
||||
|
||||
def send_data(self, data):
|
||||
self.send_q.put(data.encode())
|
||||
|
||||
def decode_frame(self, frame):
|
||||
# Mask out MASK bit from payload length, this len is only valid for short messages (<126)
|
||||
@ -133,8 +164,18 @@ class Websocket:
|
||||
# Set FIN = 1 and OP_CODE = 1 (text)
|
||||
header = (1 << 7) | (1 << 0)
|
||||
|
||||
frame = bytearray(header)
|
||||
frame.append(len(payload))
|
||||
frame = bytearray([header])
|
||||
payload_len = len(payload)
|
||||
|
||||
# If payload len is longer than 125 then the next 16 bits are used to encode length
|
||||
if payload_len > 125:
|
||||
frame.append(126)
|
||||
frame.append(payload_len >> 8)
|
||||
frame.append(0xFF & payload_len)
|
||||
|
||||
else:
|
||||
frame.append(payload_len)
|
||||
|
||||
frame += payload
|
||||
|
||||
return frame
|
||||
@ -143,8 +184,27 @@ class Websocket:
|
||||
def test_echo(dut):
|
||||
dut.expect("WEBSOCKET_EVENT_CONNECTED")
|
||||
for i in range(0, 10):
|
||||
dut.expect(re.compile(r"Received=hello (\d)"))
|
||||
dut.expect("Websocket Stopped")
|
||||
dut.expect(re.compile(r"Received=hello (\d)"), timeout=30)
|
||||
print("All echos received")
|
||||
|
||||
|
||||
def test_recv_long_msg(dut, websocket, msg_len, repeats):
|
||||
send_msg = ''.join(random.choice(string.ascii_uppercase + string.ascii_lowercase + string.digits) for _ in range(msg_len))
|
||||
|
||||
for _ in range(repeats):
|
||||
websocket.send_data(send_msg)
|
||||
|
||||
recv_msg = ''
|
||||
while len(recv_msg) < msg_len:
|
||||
# Filter out color encoding
|
||||
match = dut.expect(re.compile(r"Received=([a-zA-Z0-9]*).*\n"), timeout=30)[0]
|
||||
recv_msg += match
|
||||
|
||||
if recv_msg == send_msg:
|
||||
print("Sent message and received message are equal")
|
||||
else:
|
||||
raise ValueError("DUT received string do not match sent string, \nexpected: {}\nwith length {}\
|
||||
\nreceived: {}\nwith length {}".format(send_msg, len(send_msg), recv_msg, len(recv_msg)))
|
||||
|
||||
|
||||
@ttfw_idf.idf_example_test(env_tag="Example_WIFI")
|
||||
@ -178,12 +238,14 @@ def test_examples_protocol_websocket(env, extra_data):
|
||||
|
||||
if uri_from_stdin:
|
||||
server_port = 4455
|
||||
with Websocket(server_port):
|
||||
with Websocket(server_port) as ws:
|
||||
uri = "ws://{}:{}".format(get_my_ip(), server_port)
|
||||
print("DUT connecting to {}".format(uri))
|
||||
dut1.expect("Please enter uri of websocket endpoint", timeout=30)
|
||||
dut1.write(uri)
|
||||
test_echo(dut1)
|
||||
# Message length should exceed DUT's buffer size to test fragmentation, default is 1024 byte
|
||||
test_recv_long_msg(dut1, ws, 2000, 3)
|
||||
|
||||
else:
|
||||
print("DUT connecting to {}".format(uri))
|
||||
|
@ -17,15 +17,26 @@
|
||||
|
||||
#include "freertos/FreeRTOS.h"
|
||||
#include "freertos/task.h"
|
||||
#include "freertos/semphr.h"
|
||||
#include "freertos/event_groups.h"
|
||||
|
||||
|
||||
#include "esp_log.h"
|
||||
#include "esp_websocket_client.h"
|
||||
#include "esp_event.h"
|
||||
|
||||
#define NO_DATA_TIMEOUT_SEC 10
|
||||
|
||||
static const char *TAG = "WEBSOCKET";
|
||||
|
||||
static TimerHandle_t shutdown_signal_timer;
|
||||
static SemaphoreHandle_t shutdown_sema;
|
||||
|
||||
static void shutdown_signaler(TimerHandle_t xTimer)
|
||||
{
|
||||
ESP_LOGI(TAG, "No data received for %d seconds, signaling shutdown", NO_DATA_TIMEOUT_SEC);
|
||||
xSemaphoreGive(shutdown_sema);
|
||||
}
|
||||
|
||||
#if CONFIG_WEBSOCKET_URI_FROM_STDIN
|
||||
static void get_string(char *line, size_t size)
|
||||
{
|
||||
@ -58,7 +69,10 @@ static void websocket_event_handler(void *handler_args, esp_event_base_t base, i
|
||||
case WEBSOCKET_EVENT_DATA:
|
||||
ESP_LOGI(TAG, "WEBSOCKET_EVENT_DATA");
|
||||
ESP_LOGI(TAG, "Received opcode=%d", data->op_code);
|
||||
ESP_LOGW(TAG, "Received=%.*s\r\n", data->data_len, (char*)data->data_ptr);
|
||||
ESP_LOGW(TAG, "Received=%.*s", data->data_len, (char *)data->data_ptr);
|
||||
ESP_LOGW(TAG, "Total payload length=%d, data_len=%d, current payload offset=%d\r\n", data->payload_len, data->data_len, data->payload_offset);
|
||||
|
||||
xTimerReset(shutdown_signal_timer, portMAX_DELAY);
|
||||
break;
|
||||
case WEBSOCKET_EVENT_ERROR:
|
||||
ESP_LOGI(TAG, "WEBSOCKET_EVENT_ERROR");
|
||||
@ -70,7 +84,11 @@ static void websocket_app_start(void)
|
||||
{
|
||||
esp_websocket_client_config_t websocket_cfg = {};
|
||||
|
||||
#if CONFIG_WEBSOCKET_URI_FROM_STDIN
|
||||
shutdown_signal_timer = xTimerCreate("Websocket shutdown timer", NO_DATA_TIMEOUT_SEC * 1000 / portTICK_PERIOD_MS,
|
||||
pdFALSE, NULL, shutdown_signaler);
|
||||
shutdown_sema = xSemaphoreCreateBinary();
|
||||
|
||||
#if CONFIG_WEBSOCKET_URI_FROM_STDIN
|
||||
char line[128];
|
||||
|
||||
ESP_LOGI(TAG, "Please enter uri of websocket endpoint");
|
||||
@ -79,10 +97,10 @@ static void websocket_app_start(void)
|
||||
websocket_cfg.uri = line;
|
||||
ESP_LOGI(TAG, "Endpoint uri: %s\n", line);
|
||||
|
||||
#else
|
||||
#else
|
||||
websocket_cfg.uri = CONFIG_WEBSOCKET_URI;
|
||||
|
||||
#endif /* CONFIG_WEBSOCKET_URI_FROM_STDIN */
|
||||
#endif /* CONFIG_WEBSOCKET_URI_FROM_STDIN */
|
||||
|
||||
ESP_LOGI(TAG, "Connecting to %s...", websocket_cfg.uri);
|
||||
|
||||
@ -90,6 +108,7 @@ static void websocket_app_start(void)
|
||||
esp_websocket_register_events(client, WEBSOCKET_EVENT_ANY, websocket_event_handler, (void *)client);
|
||||
|
||||
esp_websocket_client_start(client);
|
||||
xTimerStart(shutdown_signal_timer, portMAX_DELAY);
|
||||
char data[32];
|
||||
int i = 0;
|
||||
while (i < 10) {
|
||||
@ -100,8 +119,8 @@ static void websocket_app_start(void)
|
||||
}
|
||||
vTaskDelay(1000 / portTICK_RATE_MS);
|
||||
}
|
||||
// Give server some time to respond before closing
|
||||
vTaskDelay(3000 / portTICK_RATE_MS);
|
||||
|
||||
xSemaphoreTake(shutdown_sema, portMAX_DELAY);
|
||||
esp_websocket_client_stop(client);
|
||||
ESP_LOGI(TAG, "Websocket Stopped");
|
||||
esp_websocket_client_destroy(client);
|
||||
|
Reference in New Issue
Block a user