tcp_transport/ws_client: websockets now correctly handle messages longer than buffer

transport_ws can now be read multiple times in a row to read frames larger than the buffer.

Added reporting of total payload length and offset to the user in websocket_client.

Added local example test for long messages.

Closes IDF-1083


* Original commit: espressif/esp-idf@ffeda3003c
This commit is contained in:
Marius Vikhammer
2019-11-14 18:09:38 +08:00
committed by gabsuren
parent a6be8e2e3d
commit aec6a75d40
4 changed files with 177 additions and 73 deletions

View File

@ -93,6 +93,8 @@ struct esp_websocket_client {
char *tx_buffer; char *tx_buffer;
int buffer_size; int buffer_size;
ws_transport_opcodes_t last_opcode; ws_transport_opcodes_t last_opcode;
int payload_len;
int payload_offset;
}; };
static uint64_t _tick_get_ms(void) static uint64_t _tick_get_ms(void)
@ -101,19 +103,20 @@ static uint64_t _tick_get_ms(void)
} }
static esp_err_t esp_websocket_client_dispatch_event(esp_websocket_client_handle_t client, static esp_err_t esp_websocket_client_dispatch_event(esp_websocket_client_handle_t client,
esp_websocket_event_id_t event, esp_websocket_event_id_t event,
const char *data, const char *data,
int data_len) int data_len)
{ {
esp_err_t err; esp_err_t err;
esp_websocket_event_data_t event_data; esp_websocket_event_data_t event_data;
event_data.client = client; event_data.client = client;
event_data.user_context = client->config->user_context; event_data.user_context = client->config->user_context;
event_data.data_ptr = data; event_data.data_ptr = data;
event_data.data_len = data_len; event_data.data_len = data_len;
event_data.op_code = client->last_opcode; event_data.op_code = client->last_opcode;
event_data.payload_len = client->payload_len;
event_data.payload_offset = client->payload_offset;
if ((err = esp_event_post_to(client->event_handle, if ((err = esp_event_post_to(client->event_handle,
WEBSOCKET_EVENTS, event, WEBSOCKET_EVENTS, event,
@ -446,10 +449,38 @@ esp_err_t esp_websocket_client_set_uri(esp_websocket_client_handle_t client, con
return ESP_OK; return ESP_OK;
} }
static esp_err_t esp_websocket_client_recv(esp_websocket_client_handle_t client)
{
int rlen;
client->payload_offset = 0;
do {
rlen = esp_transport_read(client->transport, client->rx_buffer, client->buffer_size, client->config->network_timeout_ms);
if (rlen < 0) {
ESP_LOGE(TAG, "Error read data");
esp_websocket_client_abort_connection(client);
return ESP_FAIL;
}
client->payload_len = esp_transport_ws_get_read_payload_len(client->transport);
client->last_opcode = esp_transport_ws_get_read_opcode(client->transport);
esp_websocket_client_dispatch_event(client, WEBSOCKET_EVENT_DATA, client->rx_buffer, rlen);
client->payload_offset += rlen;
} while (client->payload_offset < client->payload_len);
// if a PING message received -> send out the PONG, this will not work for PING messages with payload longer than buffer len
if (client->last_opcode == WS_TRANSPORT_OPCODES_PING) {
const char *data = (client->payload_len == 0) ? NULL : client->rx_buffer;
esp_transport_ws_send_raw(client->transport, WS_TRANSPORT_OPCODES_PONG, data, client->payload_len,
client->config->network_timeout_ms);
}
return ESP_OK;
}
static void esp_websocket_client_task(void *pv) static void esp_websocket_client_task(void *pv)
{ {
const int lock_timeout = portMAX_DELAY; const int lock_timeout = portMAX_DELAY;
int rlen;
esp_websocket_client_handle_t client = (esp_websocket_client_handle_t) pv; esp_websocket_client_handle_t client = (esp_websocket_client_handle_t) pv;
client->run = true; client->run = true;
@ -506,22 +537,11 @@ static void esp_websocket_client_task(void *pv)
} }
client->ping_tick_ms = _tick_get_ms(); client->ping_tick_ms = _tick_get_ms();
rlen = esp_transport_read(client->transport, client->rx_buffer, client->buffer_size, client->config->network_timeout_ms); if (esp_websocket_client_recv(client) == ESP_FAIL) {
if (rlen < 0) { ESP_LOGE(TAG, "Error receive data");
ESP_LOGE(TAG, "Error read data");
esp_websocket_client_abort_connection(client); esp_websocket_client_abort_connection(client);
break; break;
} }
if (rlen >= 0) {
client->last_opcode = esp_transport_ws_get_read_opcode(client->transport);
esp_websocket_client_dispatch_event(client, WEBSOCKET_EVENT_DATA, client->rx_buffer, rlen);
// if a PING message received -> send out the PONG
if (client->last_opcode == WS_TRANSPORT_OPCODES_PING) {
const char *data = (rlen == 0) ? NULL : client->rx_buffer;
esp_transport_ws_send_raw(client->transport, WS_TRANSPORT_OPCODES_PONG, data, rlen,
client->config->network_timeout_ms);
}
}
break; break;
case WEBSOCKET_STATE_WAIT_TIMEOUT: case WEBSOCKET_STATE_WAIT_TIMEOUT:
@ -663,7 +683,8 @@ bool esp_websocket_client_is_connected(esp_websocket_client_handle_t client)
esp_err_t esp_websocket_register_events(esp_websocket_client_handle_t client, esp_err_t esp_websocket_register_events(esp_websocket_client_handle_t client,
esp_websocket_event_id_t event, esp_websocket_event_id_t event,
esp_event_handler_t event_handler, esp_event_handler_t event_handler,
void* event_handler_arg) { void *event_handler_arg)
{
if (client == NULL) { if (client == NULL) {
return ESP_ERR_INVALID_ARG; return ESP_ERR_INVALID_ARG;
} }

View File

@ -27,7 +27,7 @@
extern "C" { extern "C" {
#endif #endif
typedef struct esp_websocket_client* esp_websocket_client_handle_t; typedef struct esp_websocket_client *esp_websocket_client_handle_t;
ESP_EVENT_DECLARE_BASE(WEBSOCKET_EVENTS); // declaration of the task events family ESP_EVENT_DECLARE_BASE(WEBSOCKET_EVENTS); // declaration of the task events family
@ -47,11 +47,13 @@ typedef enum {
* @brief Websocket event data * @brief Websocket event data
*/ */
typedef struct { typedef struct {
const char *data_ptr; /*!< Data pointer */ const char *data_ptr; /*!< Data pointer */
int data_len; /*!< Data length */ int data_len; /*!< Data length */
uint8_t op_code; /*!< Received opcode */ uint8_t op_code; /*!< Received opcode */
esp_websocket_client_handle_t client; /*!< esp_websocket_client_handle_t context */ esp_websocket_client_handle_t client; /*!< esp_websocket_client_handle_t context */
void *user_context; /*!< user_data context, from esp_websocket_client_config_t user_data */ void *user_context; /*!< user_data context, from esp_websocket_client_config_t user_data */
int payload_len; /*!< Total payload length, payloads exceeding buffer will be posted through multiple events */
int payload_offset; /*!< Actual offset for the data associated with this event */
} esp_websocket_event_data_t; } esp_websocket_event_data_t;
/** /**
@ -205,7 +207,7 @@ bool esp_websocket_client_is_connected(esp_websocket_client_handle_t client);
esp_err_t esp_websocket_register_events(esp_websocket_client_handle_t client, esp_err_t esp_websocket_register_events(esp_websocket_client_handle_t client,
esp_websocket_event_id_t event, esp_websocket_event_id_t event,
esp_event_handler_t event_handler, esp_event_handler_t event_handler,
void* event_handler_arg); void *event_handler_arg);
#ifdef __cplusplus #ifdef __cplusplus
} }

View File

@ -3,10 +3,13 @@ from __future__ import unicode_literals
import re import re
import os import os
import socket import socket
import select
import hashlib import hashlib
import base64 import base64
from threading import Thread import queue
import random
import string
from threading import Thread, Event
import ttfw_idf import ttfw_idf
@ -30,7 +33,10 @@ class Websocket:
def __init__(self, port): def __init__(self, port):
self.port = port self.port = port
self.socket = socket.socket() self.socket = socket.socket()
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.socket.settimeout(10.0) self.socket.settimeout(10.0)
self.send_q = queue.Queue()
self.shutdown = Event()
def __enter__(self): def __enter__(self):
try: try:
@ -43,23 +49,27 @@ class Websocket:
self.server_thread = Thread(target=self.run_server) self.server_thread = Thread(target=self.run_server)
self.server_thread.start() self.server_thread.start()
return self
def __exit__(self, exc_type, exc_value, traceback): def __exit__(self, exc_type, exc_value, traceback):
self.shutdown.set()
self.server_thread.join() self.server_thread.join()
self.socket.close() self.socket.close()
self.conn.close() self.conn.close()
def run_server(self): def run_server(self):
self.conn, address = self.socket.accept() # accept new connection self.conn, address = self.socket.accept() # accept new connection
self.conn.settimeout(10.0) self.socket.settimeout(10.0)
print("Connection from: {}".format(address)) print("Connection from: {}".format(address))
self.establish_connection() self.establish_connection()
print("WS established")
# Echo data until client closes connection # Handle connection until client closes it, will echo any data received and send data from send_q queue
self.echo_data() self.handle_conn()
def establish_connection(self): def establish_connection(self):
while True: while not self.shutdown.is_set():
try: try:
# receive data stream. it won't accept data packet greater than 1024 bytes # receive data stream. it won't accept data packet greater than 1024 bytes
data = self.conn.recv(1024).decode() data = self.conn.recv(1024).decode()
@ -70,6 +80,7 @@ class Websocket:
if "Upgrade: websocket" in data and "Connection: Upgrade" in data: if "Upgrade: websocket" in data and "Connection: Upgrade" in data:
self.handshake(data) self.handshake(data)
return return
except socket.error as err: except socket.error as err:
print("Unable to establish a websocket connection: {}, {}".format(err)) print("Unable to establish a websocket connection: {}, {}".format(err))
raise raise
@ -94,26 +105,46 @@ class Websocket:
self.conn.send(resp.encode()) self.conn.send(resp.encode())
def echo_data(self): def handle_conn(self):
while(True): while not self.shutdown.is_set():
r,w,e = select.select([self.conn], [], [], 1)
try: try:
header = bytearray(self.conn.recv(self.HEADER_LEN, socket.MSG_WAITALL)) if self.conn in r:
if not header: self.echo_data()
# exit if data is not received
return
# Remove mask bit if not self.send_q.empty():
payload_len = ~(1 << 7) & header[1] self._send_data_(self.send_q.get())
payload = bytearray(self.conn.recv(payload_len, socket.MSG_WAITALL))
frame = header + payload
decoded_payload = self.decode_frame(frame)
echo_frame = self.encode_frame(decoded_payload)
self.conn.send(echo_frame)
except socket.error as err: except socket.error as err:
print("Stopped echoing data: {}".format(err)) print("Stopped echoing data: {}".format(err))
raise
def echo_data(self):
header = bytearray(self.conn.recv(self.HEADER_LEN, socket.MSG_WAITALL))
if not header:
# exit if socket closed by peer
return
# Remove mask bit
payload_len = ~(1 << 7) & header[1]
payload = bytearray(self.conn.recv(payload_len, socket.MSG_WAITALL))
if not payload:
# exit if socket closed by peer
return
frame = header + payload
decoded_payload = self.decode_frame(frame)
print("Sending echo...")
self._send_data_(decoded_payload)
def _send_data_(self, data):
frame = self.encode_frame(data)
self.conn.send(frame)
def send_data(self, data):
self.send_q.put(data.encode())
def decode_frame(self, frame): def decode_frame(self, frame):
# Mask out MASK bit from payload length, this len is only valid for short messages (<126) # Mask out MASK bit from payload length, this len is only valid for short messages (<126)
@ -133,8 +164,18 @@ class Websocket:
# Set FIN = 1 and OP_CODE = 1 (text) # Set FIN = 1 and OP_CODE = 1 (text)
header = (1 << 7) | (1 << 0) header = (1 << 7) | (1 << 0)
frame = bytearray(header) frame = bytearray([header])
frame.append(len(payload)) payload_len = len(payload)
# If payload len is longer than 125 then the next 16 bits are used to encode length
if payload_len > 125:
frame.append(126)
frame.append(payload_len >> 8)
frame.append(0xFF & payload_len)
else:
frame.append(payload_len)
frame += payload frame += payload
return frame return frame
@ -143,8 +184,27 @@ class Websocket:
def test_echo(dut): def test_echo(dut):
dut.expect("WEBSOCKET_EVENT_CONNECTED") dut.expect("WEBSOCKET_EVENT_CONNECTED")
for i in range(0, 10): for i in range(0, 10):
dut.expect(re.compile(r"Received=hello (\d)")) dut.expect(re.compile(r"Received=hello (\d)"), timeout=30)
dut.expect("Websocket Stopped") print("All echos received")
def test_recv_long_msg(dut, websocket, msg_len, repeats):
send_msg = ''.join(random.choice(string.ascii_uppercase + string.ascii_lowercase + string.digits) for _ in range(msg_len))
for _ in range(repeats):
websocket.send_data(send_msg)
recv_msg = ''
while len(recv_msg) < msg_len:
# Filter out color encoding
match = dut.expect(re.compile(r"Received=([a-zA-Z0-9]*).*\n"), timeout=30)[0]
recv_msg += match
if recv_msg == send_msg:
print("Sent message and received message are equal")
else:
raise ValueError("DUT received string do not match sent string, \nexpected: {}\nwith length {}\
\nreceived: {}\nwith length {}".format(send_msg, len(send_msg), recv_msg, len(recv_msg)))
@ttfw_idf.idf_example_test(env_tag="Example_WIFI") @ttfw_idf.idf_example_test(env_tag="Example_WIFI")
@ -178,12 +238,14 @@ def test_examples_protocol_websocket(env, extra_data):
if uri_from_stdin: if uri_from_stdin:
server_port = 4455 server_port = 4455
with Websocket(server_port): with Websocket(server_port) as ws:
uri = "ws://{}:{}".format(get_my_ip(), server_port) uri = "ws://{}:{}".format(get_my_ip(), server_port)
print("DUT connecting to {}".format(uri)) print("DUT connecting to {}".format(uri))
dut1.expect("Please enter uri of websocket endpoint", timeout=30) dut1.expect("Please enter uri of websocket endpoint", timeout=30)
dut1.write(uri) dut1.write(uri)
test_echo(dut1) test_echo(dut1)
# Message length should exceed DUT's buffer size to test fragmentation, default is 1024 byte
test_recv_long_msg(dut1, ws, 2000, 3)
else: else:
print("DUT connecting to {}".format(uri)) print("DUT connecting to {}".format(uri))

View File

@ -17,15 +17,26 @@
#include "freertos/FreeRTOS.h" #include "freertos/FreeRTOS.h"
#include "freertos/task.h" #include "freertos/task.h"
#include "freertos/semphr.h"
#include "freertos/event_groups.h" #include "freertos/event_groups.h"
#include "esp_log.h" #include "esp_log.h"
#include "esp_websocket_client.h" #include "esp_websocket_client.h"
#include "esp_event.h" #include "esp_event.h"
#define NO_DATA_TIMEOUT_SEC 10
static const char *TAG = "WEBSOCKET"; static const char *TAG = "WEBSOCKET";
static TimerHandle_t shutdown_signal_timer;
static SemaphoreHandle_t shutdown_sema;
static void shutdown_signaler(TimerHandle_t xTimer)
{
ESP_LOGI(TAG, "No data received for %d seconds, signaling shutdown", NO_DATA_TIMEOUT_SEC);
xSemaphoreGive(shutdown_sema);
}
#if CONFIG_WEBSOCKET_URI_FROM_STDIN #if CONFIG_WEBSOCKET_URI_FROM_STDIN
static void get_string(char *line, size_t size) static void get_string(char *line, size_t size)
{ {
@ -49,20 +60,23 @@ static void websocket_event_handler(void *handler_args, esp_event_base_t base, i
{ {
esp_websocket_event_data_t *data = (esp_websocket_event_data_t *)event_data; esp_websocket_event_data_t *data = (esp_websocket_event_data_t *)event_data;
switch (event_id) { switch (event_id) {
case WEBSOCKET_EVENT_CONNECTED: case WEBSOCKET_EVENT_CONNECTED:
ESP_LOGI(TAG, "WEBSOCKET_EVENT_CONNECTED"); ESP_LOGI(TAG, "WEBSOCKET_EVENT_CONNECTED");
break; break;
case WEBSOCKET_EVENT_DISCONNECTED: case WEBSOCKET_EVENT_DISCONNECTED:
ESP_LOGI(TAG, "WEBSOCKET_EVENT_DISCONNECTED"); ESP_LOGI(TAG, "WEBSOCKET_EVENT_DISCONNECTED");
break; break;
case WEBSOCKET_EVENT_DATA: case WEBSOCKET_EVENT_DATA:
ESP_LOGI(TAG, "WEBSOCKET_EVENT_DATA"); ESP_LOGI(TAG, "WEBSOCKET_EVENT_DATA");
ESP_LOGI(TAG, "Received opcode=%d", data->op_code); ESP_LOGI(TAG, "Received opcode=%d", data->op_code);
ESP_LOGW(TAG, "Received=%.*s\r\n", data->data_len, (char*)data->data_ptr); ESP_LOGW(TAG, "Received=%.*s", data->data_len, (char *)data->data_ptr);
break; ESP_LOGW(TAG, "Total payload length=%d, data_len=%d, current payload offset=%d\r\n", data->payload_len, data->data_len, data->payload_offset);
case WEBSOCKET_EVENT_ERROR:
ESP_LOGI(TAG, "WEBSOCKET_EVENT_ERROR"); xTimerReset(shutdown_signal_timer, portMAX_DELAY);
break; break;
case WEBSOCKET_EVENT_ERROR:
ESP_LOGI(TAG, "WEBSOCKET_EVENT_ERROR");
break;
} }
} }
@ -70,7 +84,11 @@ static void websocket_app_start(void)
{ {
esp_websocket_client_config_t websocket_cfg = {}; esp_websocket_client_config_t websocket_cfg = {};
#if CONFIG_WEBSOCKET_URI_FROM_STDIN shutdown_signal_timer = xTimerCreate("Websocket shutdown timer", NO_DATA_TIMEOUT_SEC * 1000 / portTICK_PERIOD_MS,
pdFALSE, NULL, shutdown_signaler);
shutdown_sema = xSemaphoreCreateBinary();
#if CONFIG_WEBSOCKET_URI_FROM_STDIN
char line[128]; char line[128];
ESP_LOGI(TAG, "Please enter uri of websocket endpoint"); ESP_LOGI(TAG, "Please enter uri of websocket endpoint");
@ -79,10 +97,10 @@ static void websocket_app_start(void)
websocket_cfg.uri = line; websocket_cfg.uri = line;
ESP_LOGI(TAG, "Endpoint uri: %s\n", line); ESP_LOGI(TAG, "Endpoint uri: %s\n", line);
#else #else
websocket_cfg.uri = CONFIG_WEBSOCKET_URI; websocket_cfg.uri = CONFIG_WEBSOCKET_URI;
#endif /* CONFIG_WEBSOCKET_URI_FROM_STDIN */ #endif /* CONFIG_WEBSOCKET_URI_FROM_STDIN */
ESP_LOGI(TAG, "Connecting to %s...", websocket_cfg.uri); ESP_LOGI(TAG, "Connecting to %s...", websocket_cfg.uri);
@ -90,6 +108,7 @@ static void websocket_app_start(void)
esp_websocket_register_events(client, WEBSOCKET_EVENT_ANY, websocket_event_handler, (void *)client); esp_websocket_register_events(client, WEBSOCKET_EVENT_ANY, websocket_event_handler, (void *)client);
esp_websocket_client_start(client); esp_websocket_client_start(client);
xTimerStart(shutdown_signal_timer, portMAX_DELAY);
char data[32]; char data[32];
int i = 0; int i = 0;
while (i < 10) { while (i < 10) {
@ -100,8 +119,8 @@ static void websocket_app_start(void)
} }
vTaskDelay(1000 / portTICK_RATE_MS); vTaskDelay(1000 / portTICK_RATE_MS);
} }
// Give server some time to respond before closing
vTaskDelay(3000 / portTICK_RATE_MS); xSemaphoreTake(shutdown_sema, portMAX_DELAY);
esp_websocket_client_stop(client); esp_websocket_client_stop(client);
ESP_LOGI(TAG, "Websocket Stopped"); ESP_LOGI(TAG, "Websocket Stopped");
esp_websocket_client_destroy(client); esp_websocket_client_destroy(client);