From b213f2c6d3d78ba3a95005e3206d4ce370b8a649 Mon Sep 17 00:00:00 2001 From: David Cermak Date: Fri, 17 Jul 2020 17:59:05 +0200 Subject: [PATCH 1/3] ws_client: Added support for close frame, closing connection gracefully --- .../esp_websocket_client.c | 92 ++++++++++++++++--- .../include/esp_websocket_client.h | 30 ++++++ .../websocket/main/websocket_example.c | 8 +- 3 files changed, 116 insertions(+), 14 deletions(-) diff --git a/components/esp_websocket_client/esp_websocket_client.c b/components/esp_websocket_client/esp_websocket_client.c index 5059e36dd2..23bdd3c6a5 100644 --- a/components/esp_websocket_client/esp_websocket_client.c +++ b/components/esp_websocket_client/esp_websocket_client.c @@ -52,6 +52,8 @@ static const char *TAG = "WEBSOCKET_CLIENT"; } const static int STOPPED_BIT = BIT0; +const static int CLOSING_BIT = BIT1; // Indicates that a close frame received from server + // and we are waiting for the "Reset by Peer" from the server ESP_EVENT_DEFINE_BASE(WEBSOCKET_EVENTS); @@ -477,6 +479,11 @@ static esp_err_t esp_websocket_client_recv(esp_websocket_client_handle_t client) do { rlen = esp_transport_read(client->transport, client->rx_buffer, client->buffer_size, client->config->network_timeout_ms); if (rlen < 0) { + if (CLOSING_BIT & xEventGroupGetBits(client->status_bits)) { + client->run = false; + client->state = WEBSOCKET_STATE_UNKNOW; + return ESP_OK; + } ESP_LOGE(TAG, "Error read data"); return ESP_FAIL; } @@ -493,9 +500,10 @@ static esp_err_t esp_websocket_client_recv(esp_websocket_client_handle_t client) const char *data = (client->payload_len == 0) ? NULL : client->rx_buffer; esp_transport_ws_send_raw(client->transport, WS_TRANSPORT_OPCODES_PONG | WS_TRANSPORT_OPCODES_FIN, data, client->payload_len, client->config->network_timeout_ms); - } - else if (client->last_opcode == WS_TRANSPORT_OPCODES_PONG) { + } else if (client->last_opcode == WS_TRANSPORT_OPCODES_PONG) { client->wait_for_pong_resp = false; + } else if (client->last_opcode == WS_TRANSPORT_OPCODES_CLOSE) { + xEventGroupSetBits(client->status_bits, CLOSING_BIT); } return ESP_OK; @@ -520,7 +528,7 @@ static void esp_websocket_client_task(void *pv) } client->state = WEBSOCKET_STATE_INIT; - xEventGroupClearBits(client->status_bits, STOPPED_BIT); + xEventGroupClearBits(client->status_bits, STOPPED_BIT | CLOSING_BIT); int read_select = 0; while (client->run) { if (xSemaphoreTakeRecursive(client->lock, lock_timeout) != pdPASS) { @@ -598,6 +606,11 @@ static void esp_websocket_client_task(void *pv) if (WEBSOCKET_STATE_CONNECTED == client->state) { read_select = esp_transport_poll_read(client->transport, 1000); //Poll every 1000ms if (read_select < 0) { + if (CLOSING_BIT & xEventGroupGetBits(client->status_bits)) { + client->run = false; + client->state = WEBSOCKET_STATE_UNKNOW; + break; + } ESP_LOGE(TAG, "Network error: esp_transport_poll_read() returned %d, errno=%d", read_select, errno); esp_websocket_client_abort_connection(client); } @@ -626,7 +639,7 @@ esp_err_t esp_websocket_client_start(esp_websocket_client_handle_t client) ESP_LOGE(TAG, "Error create websocket task"); return ESP_FAIL; } - xEventGroupClearBits(client->status_bits, STOPPED_BIT); + xEventGroupClearBits(client->status_bits, STOPPED_BIT | CLOSING_BIT); return ESP_OK; } @@ -645,30 +658,85 @@ esp_err_t esp_websocket_client_stop(esp_websocket_client_handle_t client) return ESP_OK; } -static int esp_websocket_client_send_with_opcode(esp_websocket_client_handle_t client, ws_transport_opcodes_t opcode, const char *data, int len, TickType_t timeout); +static int esp_websocket_client_send_with_opcode(esp_websocket_client_handle_t client, ws_transport_opcodes_t opcode, uint8_t *data, int len, TickType_t timeout); + +int esp_websocket_client_send_close(esp_websocket_client_handle_t client, int code, const char *additional_data, int total_len, TickType_t timeout) +{ + uint8_t *close_status_data = NULL; + // RFC6455#section-5.5.1: The Close frame MAY contain a body (indicated by total_len >= 2) + if (total_len >= 2) { + close_status_data = calloc(1, total_len); + // RFC6455#section-5.5.1: The first two bytes of the body MUST be a 2-byte representing a status + uint16_t *code_network_order = (uint16_t *) close_status_data; + *code_network_order = htons(code); + memcpy(close_status_data + 2, additional_data, total_len - 2); + } + int ret = esp_websocket_client_send_with_opcode(client, WS_TRANSPORT_OPCODES_CLOSE, close_status_data, total_len, timeout); + free(close_status_data); + return ret; +} + + +static esp_err_t esp_websocket_client_close_with_optional_body(esp_websocket_client_handle_t client, bool send_body, int code, const char *data, int len, TickType_t timeout) +{ + if (client == NULL) { + return ESP_ERR_INVALID_ARG; + } + if (!client->run) { + ESP_LOGW(TAG, "Client was not started"); + return ESP_FAIL; + } + + if (send_body) { + esp_websocket_client_send_close(client, code, data, len + 2, portMAX_DELAY); // len + 2 -> always sending the code + } else { + esp_websocket_client_send_close(client, 0, NULL, 0, portMAX_DELAY); // only opcode frame + } + + if (STOPPED_BIT & xEventGroupWaitBits(client->status_bits, STOPPED_BIT, false, true, timeout)) { + return ESP_OK; + } + + // If could not close gracefully within timeout, stop the client and disconnect + client->run = false; + xEventGroupWaitBits(client->status_bits, STOPPED_BIT, false, true, portMAX_DELAY); + client->state = WEBSOCKET_STATE_UNKNOW; + return ESP_OK; +} + +esp_err_t esp_websocket_client_close_with_code(esp_websocket_client_handle_t client, int code, const char *data, int len, TickType_t timeout) +{ + return esp_websocket_client_close_with_optional_body(client, true, code, data, len, timeout); +} + +esp_err_t esp_websocket_client_close(esp_websocket_client_handle_t client, TickType_t timeout) +{ + return esp_websocket_client_close_with_optional_body(client, false, 0, NULL, 0, timeout); +} int esp_websocket_client_send_text(esp_websocket_client_handle_t client, const char *data, int len, TickType_t timeout) { - return esp_websocket_client_send_with_opcode(client, WS_TRANSPORT_OPCODES_TEXT, data, len, timeout); + return esp_websocket_client_send_with_opcode(client, WS_TRANSPORT_OPCODES_TEXT, (uint8_t *)data, len, timeout); } int esp_websocket_client_send(esp_websocket_client_handle_t client, const char *data, int len, TickType_t timeout) { - return esp_websocket_client_send_with_opcode(client, WS_TRANSPORT_OPCODES_BINARY, data, len, timeout); + return esp_websocket_client_send_with_opcode(client, WS_TRANSPORT_OPCODES_BINARY, (uint8_t *)data, len, timeout); } int esp_websocket_client_send_bin(esp_websocket_client_handle_t client, const char *data, int len, TickType_t timeout) { - return esp_websocket_client_send_with_opcode(client, WS_TRANSPORT_OPCODES_BINARY, data, len, timeout); + return esp_websocket_client_send_with_opcode(client, WS_TRANSPORT_OPCODES_BINARY, (uint8_t *)data, len, timeout); } -static int esp_websocket_client_send_with_opcode(esp_websocket_client_handle_t client, ws_transport_opcodes_t opcode, const char *data, int len, TickType_t timeout) +static int esp_websocket_client_send_with_opcode(esp_websocket_client_handle_t client, ws_transport_opcodes_t opcode, uint8_t *data, int len, TickType_t timeout) { int need_write = len; int wlen = 0, widx = 0; int ret = ESP_FAIL; - if (client == NULL || data == NULL || len <= 0) { + if (client == NULL || len < 0 || + (opcode != WS_TRANSPORT_OPCODES_CLOSE && (data == NULL || len <= 0))) { ESP_LOGE(TAG, "Invalid arguments"); return ESP_FAIL; } @@ -688,7 +756,7 @@ static int esp_websocket_client_send_with_opcode(esp_websocket_client_handle_t c goto unlock_and_return; } uint32_t current_opcode = opcode; - while (widx < len) { + while (widx < len || current_opcode) { // allow for sending "current_opcode" only massage with len==0 if (need_write > client->buffer_size) { need_write = client->buffer_size; } else { @@ -698,7 +766,7 @@ static int esp_websocket_client_send_with_opcode(esp_websocket_client_handle_t c // send with ws specific way and specific opcode wlen = esp_transport_ws_send_raw(client->transport, current_opcode, (char *)client->tx_buffer, need_write, (timeout==portMAX_DELAY)? -1 : timeout * portTICK_PERIOD_MS); - if (wlen <= 0) { + if (wlen < 0 || (wlen == 0 && need_write != 0)) { ret = wlen; ESP_LOGE(TAG, "Network error: esp_transport_write() returned %d, errno=%d", ret, errno); esp_websocket_client_abort_connection(client); diff --git a/components/esp_websocket_client/include/esp_websocket_client.h b/components/esp_websocket_client/include/esp_websocket_client.h index 5a0e52e0aa..68946c61ce 100644 --- a/components/esp_websocket_client/include/esp_websocket_client.h +++ b/components/esp_websocket_client/include/esp_websocket_client.h @@ -187,6 +187,36 @@ int esp_websocket_client_send_bin(esp_websocket_client_handle_t client, const ch */ int esp_websocket_client_send_text(esp_websocket_client_handle_t client, const char *data, int len, TickType_t timeout); +/** + * @brief Close the WebSocket connection in a clean way + * + * Sequence of clean close initiated by client: + * * Client sends CLOSE frame + * * Client waits until server echos the CLOSE frame + * * Client waits until server closes the connection + * * Client is stopped the same way as by the `esp_websocket_client_stop()` + * + * @param[in] client The client + * @param[in] timeout Timeout in RTOS ticks for waiting + * + * @return esp_err_t + */ +esp_err_t esp_websocket_client_close(esp_websocket_client_handle_t client, TickType_t timeout); + +/** + * @brief Close the WebSocket connection in a clean way with custom code/data + * Closing sequence is the same as for esp_websocket_client_close() + * + * @param[in] client The client + * @param[in] code Close status code as defined in RFC6455 section-7.4 + * @param[in] data Additional data to closing message + * @param[in] len The length of the additional data + * @param[in] timeout Timeout in RTOS ticks for waiting + * + * @return esp_err_t + */ +esp_err_t esp_websocket_client_close_with_code(esp_websocket_client_handle_t client, int code, const char *data, int len, TickType_t timeout); + /** * @brief Check the WebSocket client connection state * diff --git a/examples/protocols/websocket/main/websocket_example.c b/examples/protocols/websocket/main/websocket_example.c index e424066d27..eb3db80ce7 100644 --- a/examples/protocols/websocket/main/websocket_example.c +++ b/examples/protocols/websocket/main/websocket_example.c @@ -69,7 +69,11 @@ static void websocket_event_handler(void *handler_args, esp_event_base_t base, i case WEBSOCKET_EVENT_DATA: ESP_LOGI(TAG, "WEBSOCKET_EVENT_DATA"); ESP_LOGI(TAG, "Received opcode=%d", data->op_code); - ESP_LOGW(TAG, "Received=%.*s", data->data_len, (char *)data->data_ptr); + if (data->op_code == 0x08 && data->data_len == 2) { + ESP_LOGW(TAG, "Received closed message with code=%d", 256*data->data_ptr[0] + data->data_ptr[1]); + } else { + ESP_LOGW(TAG, "Received=%.*s", data->data_len, (char *)data->data_ptr); + } ESP_LOGW(TAG, "Total payload length=%d, data_len=%d, current payload offset=%d\r\n", data->payload_len, data->data_len, data->payload_offset); xTimerReset(shutdown_signal_timer, portMAX_DELAY); @@ -121,7 +125,7 @@ static void websocket_app_start(void) } xSemaphoreTake(shutdown_sema, portMAX_DELAY); - esp_websocket_client_stop(client); + esp_websocket_client_close(client, portMAX_DELAY); ESP_LOGI(TAG, "Websocket Stopped"); esp_websocket_client_destroy(client); } From 5e9f8b52e7a87371370205a387b2d94e5ac6cbf9 Mon Sep 17 00:00:00 2001 From: David Cermak Date: Fri, 17 Jul 2020 17:59:05 +0200 Subject: [PATCH 2/3] tcp_transport: Added internal API for underlying socket, used for custom select on connection end for WS Internal tcp_transport functions could now use custom socket operations. This is used for WebSocket transport, when we typically wait for clean connection closure, i.e. selecting for read/error with expected errno or recv size=0 while socket readable (=connection terminated by FIN flag) --- .../esp_websocket_client.c | 99 ++++++++++++------- .../include/esp_websocket_client.h | 7 +- .../tcp_transport/include/esp_transport.h | 2 +- .../tcp_transport/include/esp_transport_ws.h | 15 +++ .../private_include/esp_transport_internal.h | 56 +++++++++++ .../esp_transport_ssl_internal.h | 6 +- components/tcp_transport/transport.c | 33 ++----- components/tcp_transport/transport_ssl.c | 13 +++ components/tcp_transport/transport_tcp.c | 13 +++ components/tcp_transport/transport_ws.c | 52 ++++++++++ 10 files changed, 232 insertions(+), 64 deletions(-) create mode 100644 components/tcp_transport/private_include/esp_transport_internal.h diff --git a/components/esp_websocket_client/esp_websocket_client.c b/components/esp_websocket_client/esp_websocket_client.c index 23bdd3c6a5..bd670d93e9 100644 --- a/components/esp_websocket_client/esp_websocket_client.c +++ b/components/esp_websocket_client/esp_websocket_client.c @@ -52,8 +52,8 @@ static const char *TAG = "WEBSOCKET_CLIENT"; } const static int STOPPED_BIT = BIT0; -const static int CLOSING_BIT = BIT1; // Indicates that a close frame received from server - // and we are waiting for the "Reset by Peer" from the server +const static int CLOSE_FRAME_SENT_BIT = BIT1; // Indicates that a close frame was sent by the client + // and we are waiting for the server to continue with clean close ESP_EVENT_DEFINE_BASE(WEBSOCKET_EVENTS); @@ -82,6 +82,7 @@ typedef enum { WEBSOCKET_STATE_INIT, WEBSOCKET_STATE_CONNECTED, WEBSOCKET_STATE_WAIT_TIMEOUT, + WEBSOCKET_STATE_CLOSING, } websocket_client_state_t; struct esp_websocket_client { @@ -479,11 +480,6 @@ static esp_err_t esp_websocket_client_recv(esp_websocket_client_handle_t client) do { rlen = esp_transport_read(client->transport, client->rx_buffer, client->buffer_size, client->config->network_timeout_ms); if (rlen < 0) { - if (CLOSING_BIT & xEventGroupGetBits(client->status_bits)) { - client->run = false; - client->state = WEBSOCKET_STATE_UNKNOW; - return ESP_OK; - } ESP_LOGE(TAG, "Error read data"); return ESP_FAIL; } @@ -503,12 +499,17 @@ static esp_err_t esp_websocket_client_recv(esp_websocket_client_handle_t client) } else if (client->last_opcode == WS_TRANSPORT_OPCODES_PONG) { client->wait_for_pong_resp = false; } else if (client->last_opcode == WS_TRANSPORT_OPCODES_CLOSE) { - xEventGroupSetBits(client->status_bits, CLOSING_BIT); + ESP_LOGD(TAG, "Received close frame"); + client->state = WEBSOCKET_STATE_CLOSING; } return ESP_OK; } +static int esp_websocket_client_send_with_opcode(esp_websocket_client_handle_t client, ws_transport_opcodes_t opcode, const uint8_t *data, int len, TickType_t timeout); + +static int esp_websocket_client_send_close(esp_websocket_client_handle_t client, int code, const char *additional_data, int total_len, TickType_t timeout); + static void esp_websocket_client_task(void *pv) { const int lock_timeout = portMAX_DELAY; @@ -528,7 +529,7 @@ static void esp_websocket_client_task(void *pv) } client->state = WEBSOCKET_STATE_INIT; - xEventGroupClearBits(client->status_bits, STOPPED_BIT | CLOSING_BIT); + xEventGroupClearBits(client->status_bits, STOPPED_BIT | CLOSE_FRAME_SENT_BIT); int read_select = 0; while (client->run) { if (xSemaphoreTakeRecursive(client->lock, lock_timeout) != pdPASS) { @@ -558,22 +559,25 @@ static void esp_websocket_client_task(void *pv) break; case WEBSOCKET_STATE_CONNECTED: - if (_tick_get_ms() - client->ping_tick_ms > WEBSOCKET_PING_TIMEOUT_MS) { - client->ping_tick_ms = _tick_get_ms(); - ESP_LOGD(TAG, "Sending PING..."); - esp_transport_ws_send_raw(client->transport, WS_TRANSPORT_OPCODES_PING | WS_TRANSPORT_OPCODES_FIN, NULL, 0, client->config->network_timeout_ms); + if ((CLOSE_FRAME_SENT_BIT & xEventGroupGetBits(client->status_bits)) == 0) { // only send and check for PING + // if closing hasn't been initiated + if (_tick_get_ms() - client->ping_tick_ms > WEBSOCKET_PING_TIMEOUT_MS) { + client->ping_tick_ms = _tick_get_ms(); + ESP_LOGD(TAG, "Sending PING..."); + esp_transport_ws_send_raw(client->transport, WS_TRANSPORT_OPCODES_PING | WS_TRANSPORT_OPCODES_FIN, NULL, 0, client->config->network_timeout_ms); - if (!client->wait_for_pong_resp && client->config->pingpong_timeout_sec) { - client->pingpong_tick_ms = _tick_get_ms(); - client->wait_for_pong_resp = true; + if (!client->wait_for_pong_resp && client->config->pingpong_timeout_sec) { + client->pingpong_tick_ms = _tick_get_ms(); + client->wait_for_pong_resp = true; + } } - } - if ( _tick_get_ms() - client->pingpong_tick_ms > client->config->pingpong_timeout_sec*1000 ) { - if (client->wait_for_pong_resp) { - ESP_LOGE(TAG, "Error, no PONG received for more than %d seconds after PING", client->config->pingpong_timeout_sec); - esp_websocket_client_abort_connection(client); - break; + if ( _tick_get_ms() - client->pingpong_tick_ms > client->config->pingpong_timeout_sec*1000 ) { + if (client->wait_for_pong_resp) { + ESP_LOGE(TAG, "Error, no PONG received for more than %d seconds after PING", client->config->pingpong_timeout_sec); + esp_websocket_client_abort_connection(client); + break; + } } } @@ -601,22 +605,43 @@ static void esp_websocket_client_task(void *pv) ESP_LOGD(TAG, "Reconnecting..."); } break; + case WEBSOCKET_STATE_CLOSING: + // if closing not initiated by the client echo the close message back + if ((CLOSE_FRAME_SENT_BIT & xEventGroupGetBits(client->status_bits)) == 0) { + ESP_LOGD(TAG, "Closing initiated by the server, sending close frame"); + esp_transport_ws_send_raw(client->transport, WS_TRANSPORT_OPCODES_CLOSE | WS_TRANSPORT_OPCODES_FIN, NULL, 0, client->config->network_timeout_ms); + xEventGroupSetBits(client->status_bits, CLOSE_FRAME_SENT_BIT); + } + break; + default: + ESP_LOGD(TAG, "Client run iteration in a default state: %d", client->state); + break; } xSemaphoreGiveRecursive(client->lock); if (WEBSOCKET_STATE_CONNECTED == client->state) { read_select = esp_transport_poll_read(client->transport, 1000); //Poll every 1000ms if (read_select < 0) { - if (CLOSING_BIT & xEventGroupGetBits(client->status_bits)) { - client->run = false; - client->state = WEBSOCKET_STATE_UNKNOW; - break; - } ESP_LOGE(TAG, "Network error: esp_transport_poll_read() returned %d, errno=%d", read_select, errno); esp_websocket_client_abort_connection(client); } } else if (WEBSOCKET_STATE_WAIT_TIMEOUT == client->state) { // waiting for reconnecting... vTaskDelay(client->wait_timeout_ms / 2 / portTICK_RATE_MS); + } else if (WEBSOCKET_STATE_CLOSING == client->state && + (CLOSE_FRAME_SENT_BIT & xEventGroupGetBits(client->status_bits))) { + ESP_LOGD(TAG, " Waiting for TCP connection to be closed by the server"); + int ret = esp_transport_ws_poll_connection_closed(client->transport, 1000); + if (ret == 0) { + // still waiting + break; + } + if (ret < 0) { + ESP_LOGW(TAG, "Connection terminated while waiting for clean TCP close"); + } + client->run = false; + client->state = WEBSOCKET_STATE_UNKNOW; + esp_websocket_client_dispatch_event(client, WEBSOCKET_EVENT_CLOSED, NULL, 0); + break; } } @@ -639,7 +664,7 @@ esp_err_t esp_websocket_client_start(esp_websocket_client_handle_t client) ESP_LOGE(TAG, "Error create websocket task"); return ESP_FAIL; } - xEventGroupClearBits(client->status_bits, STOPPED_BIT | CLOSING_BIT); + xEventGroupClearBits(client->status_bits, STOPPED_BIT | CLOSE_FRAME_SENT_BIT); return ESP_OK; } @@ -658,14 +683,13 @@ esp_err_t esp_websocket_client_stop(esp_websocket_client_handle_t client) return ESP_OK; } -static int esp_websocket_client_send_with_opcode(esp_websocket_client_handle_t client, ws_transport_opcodes_t opcode, uint8_t *data, int len, TickType_t timeout); - -int esp_websocket_client_send_close(esp_websocket_client_handle_t client, int code, const char *additional_data, int total_len, TickType_t timeout) +static int esp_websocket_client_send_close(esp_websocket_client_handle_t client, int code, const char *additional_data, int total_len, TickType_t timeout) { uint8_t *close_status_data = NULL; // RFC6455#section-5.5.1: The Close frame MAY contain a body (indicated by total_len >= 2) if (total_len >= 2) { close_status_data = calloc(1, total_len); + ESP_WS_CLIENT_MEM_CHECK(TAG, close_status_data, return -1); // RFC6455#section-5.5.1: The first two bytes of the body MUST be a 2-byte representing a status uint16_t *code_network_order = (uint16_t *) close_status_data; *code_network_order = htons(code); @@ -693,6 +717,9 @@ static esp_err_t esp_websocket_client_close_with_optional_body(esp_websocket_cli esp_websocket_client_send_close(client, 0, NULL, 0, portMAX_DELAY); // only opcode frame } + // Set closing bit to prevent from sending PING frames while connected + xEventGroupSetBits(client->status_bits, CLOSE_FRAME_SENT_BIT); + if (STOPPED_BIT & xEventGroupWaitBits(client->status_bits, STOPPED_BIT, false, true, timeout)) { return ESP_OK; } @@ -716,20 +743,20 @@ esp_err_t esp_websocket_client_close(esp_websocket_client_handle_t client, TickT int esp_websocket_client_send_text(esp_websocket_client_handle_t client, const char *data, int len, TickType_t timeout) { - return esp_websocket_client_send_with_opcode(client, WS_TRANSPORT_OPCODES_TEXT, (uint8_t *)data, len, timeout); + return esp_websocket_client_send_with_opcode(client, WS_TRANSPORT_OPCODES_TEXT, (const uint8_t *)data, len, timeout); } int esp_websocket_client_send(esp_websocket_client_handle_t client, const char *data, int len, TickType_t timeout) { - return esp_websocket_client_send_with_opcode(client, WS_TRANSPORT_OPCODES_BINARY, (uint8_t *)data, len, timeout); + return esp_websocket_client_send_with_opcode(client, WS_TRANSPORT_OPCODES_BINARY, (const uint8_t *)data, len, timeout); } int esp_websocket_client_send_bin(esp_websocket_client_handle_t client, const char *data, int len, TickType_t timeout) { - return esp_websocket_client_send_with_opcode(client, WS_TRANSPORT_OPCODES_BINARY, (uint8_t *)data, len, timeout); + return esp_websocket_client_send_with_opcode(client, WS_TRANSPORT_OPCODES_BINARY, (const uint8_t *)data, len, timeout); } -static int esp_websocket_client_send_with_opcode(esp_websocket_client_handle_t client, ws_transport_opcodes_t opcode, uint8_t *data, int len, TickType_t timeout) +static int esp_websocket_client_send_with_opcode(esp_websocket_client_handle_t client, ws_transport_opcodes_t opcode, const uint8_t *data, int len, TickType_t timeout) { int need_write = len; int wlen = 0, widx = 0; @@ -756,7 +783,7 @@ static int esp_websocket_client_send_with_opcode(esp_websocket_client_handle_t c goto unlock_and_return; } uint32_t current_opcode = opcode; - while (widx < len || current_opcode) { // allow for sending "current_opcode" only massage with len==0 + while (widx < len || current_opcode) { // allow for sending "current_opcode" only message with len==0 if (need_write > client->buffer_size) { need_write = client->buffer_size; } else { diff --git a/components/esp_websocket_client/include/esp_websocket_client.h b/components/esp_websocket_client/include/esp_websocket_client.h index 68946c61ce..44488cdb0b 100644 --- a/components/esp_websocket_client/include/esp_websocket_client.h +++ b/components/esp_websocket_client/include/esp_websocket_client.h @@ -40,6 +40,7 @@ typedef enum { WEBSOCKET_EVENT_CONNECTED, /*!< Once the Websocket has been connected to the server, no data exchange has been performed */ WEBSOCKET_EVENT_DISCONNECTED, /*!< The connection has been disconnected */ WEBSOCKET_EVENT_DATA, /*!< When receiving data from the server, possibly multiple portions of the packet */ + WEBSOCKET_EVENT_CLOSED, /*!< The connection has been closed cleanly */ WEBSOCKET_EVENT_MAX } esp_websocket_event_id_t; @@ -125,7 +126,11 @@ esp_err_t esp_websocket_client_set_uri(esp_websocket_client_handle_t client, con esp_err_t esp_websocket_client_start(esp_websocket_client_handle_t client); /** - * @brief Close the WebSocket connection + * @brief Stops the WebSocket connection without websocket closing handshake + * + * This API stops ws client and closes TCP connection directly without sending + * close frames. It is a good practice to close the connection in a clean way + * using esp_websocket_client_close(). * * @param[in] client The client * diff --git a/components/tcp_transport/include/esp_transport.h b/components/tcp_transport/include/esp_transport.h index 4841725d0f..b13063691d 100644 --- a/components/tcp_transport/include/esp_transport.h +++ b/components/tcp_transport/include/esp_transport.h @@ -310,7 +310,7 @@ esp_err_t esp_transport_set_parent_transport_func(esp_transport_handle_t t, payl * @return * - valid pointer of esp_error_handle_t * - NULL if invalid transport handle - */ + */ esp_tls_error_handle_t esp_transport_get_error_handle(esp_transport_handle_t t); diff --git a/components/tcp_transport/include/esp_transport_ws.h b/components/tcp_transport/include/esp_transport_ws.h index 08ed2ef27e..febe1d0bb9 100644 --- a/components/tcp_transport/include/esp_transport_ws.h +++ b/components/tcp_transport/include/esp_transport_ws.h @@ -117,6 +117,21 @@ ws_transport_opcodes_t esp_transport_ws_get_read_opcode(esp_transport_handle_t t */ int esp_transport_ws_get_read_payload_len(esp_transport_handle_t t); +/** + * @brief Polls the active connection for termination + * + * This API is typically used by the client to wait for clean connection closure + * by websocket server + * + * @param t Websocket transport handle + * @param[in] timeout_ms The timeout milliseconds + * + * @return + * 0 - no activity on read and error socket descriptor within timeout + * 1 - Success: either connection terminated by FIN or the most common RST err codes + * -1 - Failure: Unexpected error code or socket is normally readable + */ +int esp_transport_ws_poll_connection_closed(esp_transport_handle_t t, int timeout_ms); #ifdef __cplusplus } diff --git a/components/tcp_transport/private_include/esp_transport_internal.h b/components/tcp_transport/private_include/esp_transport_internal.h new file mode 100644 index 0000000000..c17fad8427 --- /dev/null +++ b/components/tcp_transport/private_include/esp_transport_internal.h @@ -0,0 +1,56 @@ +// Copyright 2020 Espressif Systems (Shanghai) PTE LTD +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef _ESP_TRANSPORT_INTERNAL_H_ +#define _ESP_TRANSPORT_INTERNAL_H_ + +#include "esp_transport.h" +#include "sys/queue.h" + +typedef int (*get_socket_func)(esp_transport_handle_t t); + +/** + * Transport layer structure, which will provide functions, basic properties for transport types + */ +struct esp_transport_item_t { + int port; + char *scheme; /*!< Tag name */ + void *data; /*!< Additional transport data */ + connect_func _connect; /*!< Connect function of this transport */ + io_read_func _read; /*!< Read */ + io_func _write; /*!< Write */ + trans_func _close; /*!< Close */ + poll_func _poll_read; /*!< Poll and read */ + poll_func _poll_write; /*!< Poll and write */ + trans_func _destroy; /*!< Destroy and free transport */ + connect_async_func _connect_async; /*!< non-blocking connect function of this transport */ + payload_transfer_func _parent_transfer; /*!< Function returning underlying transport layer */ + get_socket_func _get_socket; + esp_tls_error_handle_t error_handle; /*!< Pointer to esp-tls error handle */ + + STAILQ_ENTRY(esp_transport_item_t) next; +}; + +/** + * @brief Returns underlying socket for the supplied transport handle + * + * @param t Transport handle + * + * @return Socket file descriptor in case of success + * -1 in case of error + */ +int esp_transport_get_socket(esp_transport_handle_t t); + + +#endif //_ESP_TRANSPORT_INTERNAL_H_ diff --git a/components/tcp_transport/private_include/esp_transport_ssl_internal.h b/components/tcp_transport/private_include/esp_transport_ssl_internal.h index 8b794e05b9..07b8c39435 100644 --- a/components/tcp_transport/private_include/esp_transport_ssl_internal.h +++ b/components/tcp_transport/private_include/esp_transport_ssl_internal.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef _ESP_TRANSPORT_INTERNAL_H_ -#define _ESP_TRANSPORT_INTERNAL_H_ +#ifndef _ESP_TRANSPORT_SSL_INTERNAL_H_ +#define _ESP_TRANSPORT_SSL_INTERNAL_H_ /** * @brief Sets error to common transport handle @@ -27,4 +27,4 @@ void esp_transport_set_errors(esp_transport_handle_t t, const esp_tls_error_handle_t error_handle); -#endif /* _ESP_TRANSPORT_INTERNAL_H_ */ +#endif /* _ESP_TRANSPORT_SSL_INTERNAL_H_ */ diff --git a/components/tcp_transport/transport.c b/components/tcp_transport/transport.c index d5bc57bb48..7ebc1fba25 100644 --- a/components/tcp_transport/transport.c +++ b/components/tcp_transport/transport.c @@ -21,32 +21,11 @@ #include "esp_log.h" #include "esp_transport.h" +#include "esp_transport_internal.h" #include "esp_transport_utils.h" static const char *TAG = "TRANSPORT"; -/** - * Transport layer structure, which will provide functions, basic properties for transport types - */ -struct esp_transport_item_t { - int port; - int socket; /*!< Socket to use in this transport */ - char *scheme; /*!< Tag name */ - void *context; /*!< Context data */ - void *data; /*!< Additional transport data */ - connect_func _connect; /*!< Connect function of this transport */ - io_read_func _read; /*!< Read */ - io_func _write; /*!< Write */ - trans_func _close; /*!< Close */ - poll_func _poll_read; /*!< Poll and read */ - poll_func _poll_write; /*!< Poll and write */ - trans_func _destroy; /*!< Destroy and free transport */ - connect_async_func _connect_async; /*!< non-blocking connect function of this transport */ - payload_transfer_func _parent_transfer; /*!< Function returning underlying transport layer */ - esp_tls_error_handle_t error_handle; /*!< Pointer to esp-tls error handle */ - - STAILQ_ENTRY(esp_transport_item_t) next; -}; /** @@ -305,4 +284,12 @@ void esp_transport_set_errors(esp_transport_handle_t t, const esp_tls_error_hand if (t) { memcpy(t->error_handle, error_handle, sizeof(esp_tls_last_error_t)); } -} \ No newline at end of file +} + +int esp_transport_get_socket(esp_transport_handle_t t) +{ + if (t && t->_get_socket) { + return t->_get_socket(t); + } + return -1; +} diff --git a/components/tcp_transport/transport_ssl.c b/components/tcp_transport/transport_ssl.c index a413c84617..a6863fd3a5 100644 --- a/components/tcp_transport/transport_ssl.c +++ b/components/tcp_transport/transport_ssl.c @@ -25,6 +25,7 @@ #include "esp_transport_ssl.h" #include "esp_transport_utils.h" #include "esp_transport_ssl_internal.h" +#include "esp_transport_internal.h" static const char *TAG = "TRANS_SSL"; @@ -288,6 +289,17 @@ void esp_transport_ssl_use_secure_element(esp_transport_handle_t t) } } +static int ssl_get_socket(esp_transport_handle_t t) +{ + if (t) { + transport_ssl_t *ssl = t->data; + if (ssl && ssl->tls) { + return ssl->tls->sockfd; + } + } + return -1; +} + esp_transport_handle_t esp_transport_ssl_init(void) { esp_transport_handle_t t = esp_transport_init(); @@ -296,6 +308,7 @@ esp_transport_handle_t esp_transport_ssl_init(void) esp_transport_set_context_data(t, ssl); esp_transport_set_func(t, ssl_connect, ssl_read, ssl_write, ssl_close, ssl_poll_read, ssl_poll_write, ssl_destroy); esp_transport_set_async_connect_func(t, ssl_connect_async); + t->_get_socket = ssl_get_socket; return t; } diff --git a/components/tcp_transport/transport_tcp.c b/components/tcp_transport/transport_tcp.c index 69137ffca6..87a979c9c8 100644 --- a/components/tcp_transport/transport_tcp.c +++ b/components/tcp_transport/transport_tcp.c @@ -25,6 +25,7 @@ #include "esp_transport_utils.h" #include "esp_transport.h" +#include "esp_transport_internal.h" static const char *TAG = "TRANS_TCP"; @@ -234,6 +235,17 @@ static esp_err_t tcp_destroy(esp_transport_handle_t t) return 0; } +static int tcp_get_socket(esp_transport_handle_t t) +{ + if (t) { + transport_tcp_t *tcp = t->data; + if (tcp) { + return tcp->sock; + } + } + return -1; +} + esp_transport_handle_t esp_transport_tcp_init(void) { esp_transport_handle_t t = esp_transport_init(); @@ -242,6 +254,7 @@ esp_transport_handle_t esp_transport_tcp_init(void) tcp->sock = -1; esp_transport_set_func(t, tcp_connect, tcp_read, tcp_write, tcp_close, tcp_poll_read, tcp_poll_write, tcp_destroy); esp_transport_set_context_data(t, tcp); + t->_get_socket = tcp_get_socket; return t; } diff --git a/components/tcp_transport/transport_ws.c b/components/tcp_transport/transport_ws.c index dc18adf09d..d2ea9648ba 100644 --- a/components/tcp_transport/transport_ws.c +++ b/components/tcp_transport/transport_ws.c @@ -2,6 +2,7 @@ #include #include #include +#include #include "esp_log.h" #include "esp_transport.h" #include "esp_transport_tcp.h" @@ -9,6 +10,8 @@ #include "esp_transport_utils.h" #include "mbedtls/base64.h" #include "mbedtls/sha1.h" +#include "esp_transport_internal.h" +#include "errno.h" static const char *TAG = "TRANSPORT_WS"; @@ -449,6 +452,17 @@ void esp_transport_ws_set_path(esp_transport_handle_t t, const char *path) strcpy(ws->path, path); } +static int ws_get_socket(esp_transport_handle_t t) +{ + if (t) { + transport_ws_t *ws = t->data; + if (ws && ws->parent && ws->parent->_get_socket) { + return ws->parent->_get_socket(ws->parent); + } + } + return -1; +} + esp_transport_handle_t esp_transport_ws_init(esp_transport_handle_t parent_handle) { esp_transport_handle_t t = esp_transport_init(); @@ -473,6 +487,7 @@ esp_transport_handle_t esp_transport_ws_init(esp_transport_handle_t parent_handl esp_transport_set_parent_transport_func(t, ws_get_payload_transport_handle); esp_transport_set_context_data(t, ws); + t->_get_socket = ws_get_socket; return t; } @@ -548,4 +563,41 @@ int esp_transport_ws_get_read_payload_len(esp_transport_handle_t t) return ws->frame_state.payload_len; } +int esp_transport_ws_poll_connection_closed(esp_transport_handle_t t, int timeout_ms) +{ + struct timeval timeout; + int sock = esp_transport_get_socket(t); + fd_set readset; + fd_set errset; + FD_ZERO(&readset); + FD_ZERO(&errset); + FD_SET(sock, &readset); + FD_SET(sock, &errset); + + int ret = select(sock + 1, &readset, NULL, &errset, esp_transport_utils_ms_to_timeval(timeout_ms, &timeout)); + if (ret > 0) { + if (FD_ISSET(sock, &readset)) { + uint8_t buffer; + if (recv(sock, &buffer, 1, MSG_PEEK) <= 0) { + // socket is readable, but reads zero bytes -- connection cleanly closed by FIN flag + return 1; + } + ESP_LOGW(TAG, "esp_transport_ws_poll_connection_closed: unexpected data readable on socket=%d", sock); + } else if (FD_ISSET(sock, &errset)) { + int sock_errno = 0; + uint32_t optlen = sizeof(sock_errno); + getsockopt(sock, SOL_SOCKET, SO_ERROR, &sock_errno, &optlen); + ESP_LOGD(TAG, "esp_transport_ws_poll_connection_closed select error %d, errno = %s, fd = %d", sock_errno, strerror(sock_errno), sock); + if (sock_errno == ENOTCONN || sock_errno == ECONNRESET || sock_errno == ECONNABORTED) { + // the three err codes above might be caused by connection termination by RTS flag + // which we still assume as expected closing sequence of ws-transport connection + return 1; + } + ESP_LOGE(TAG, "esp_transport_ws_poll_connection_closed: unexpected errno=%d on socket=%d", sock_errno, sock); + } + return -1; // indicates error: socket unexpectedly reads an actual data, or unexpected errno code + } + return ret; + +} From 44c553fd1479bf990cf0d96dfce894c50900324e Mon Sep 17 00:00:00 2001 From: David Cermak Date: Tue, 21 Jul 2020 16:04:25 +0200 Subject: [PATCH 3/3] ws_client tests: Updated example test to use WebsSocket package Added a new test for closing connection with close frames --- examples/protocols/websocket/example_test.py | 186 ++++--------------- 1 file changed, 38 insertions(+), 148 deletions(-) diff --git a/examples/protocols/websocket/example_test.py b/examples/protocols/websocket/example_test.py index 6dfdccf5d8..0bb085d265 100644 --- a/examples/protocols/websocket/example_test.py +++ b/examples/protocols/websocket/example_test.py @@ -3,12 +3,10 @@ from __future__ import unicode_literals import re import os import socket -import select -import hashlib -import base64 -import queue import random import string +from SimpleWebSocketServer import SimpleWebSocketServer, WebSocket +from tiny_test_fw import Utility from threading import Thread, Event import ttfw_idf @@ -26,159 +24,45 @@ def get_my_ip(): return IP +class TestEcho(WebSocket): + + def handleMessage(self): + self.sendMessage(self.data) + print('Server sent: {}'.format(self.data)) + + def handleConnected(self): + print('Connection from: {}'.format(self.address)) + + def handleClose(self): + print('{} closed the connection'.format(self.address)) + + # Simple Websocket server for testing purposes -class Websocket: - HEADER_LEN = 6 +class Websocket(object): + + def send_data(self, data): + for nr, conn in self.server.connections.items(): + conn.sendMessage(data) + + def run(self): + self.server = SimpleWebSocketServer('', self.port, TestEcho) + while not self.exit_event.is_set(): + self.server.serveonce() def __init__(self, port): self.port = port - self.socket = socket.socket() - self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - self.socket.settimeout(10.0) - self.send_q = queue.Queue() - self.shutdown = Event() + self.exit_event = Event() + self.thread = Thread(target=self.run) + self.thread.start() def __enter__(self): - try: - self.socket.bind(('', self.port)) - except socket.error as e: - print("Bind failed:{}".format(e)) - raise - - self.socket.listen(1) - self.server_thread = Thread(target=self.run_server) - self.server_thread.start() - return self def __exit__(self, exc_type, exc_value, traceback): - self.shutdown.set() - self.server_thread.join() - self.socket.close() - self.conn.close() - - def run_server(self): - self.conn, address = self.socket.accept() # accept new connection - self.socket.settimeout(10.0) - - print("Connection from: {}".format(address)) - - self.establish_connection() - print("WS established") - # Handle connection until client closes it, will echo any data received and send data from send_q queue - self.handle_conn() - - def establish_connection(self): - while not self.shutdown.is_set(): - try: - # receive data stream. it won't accept data packet greater than 1024 bytes - data = self.conn.recv(1024).decode() - if not data: - # exit if data is not received - raise - - if "Upgrade: websocket" in data and "Connection: Upgrade" in data: - self.handshake(data) - return - - except socket.error as err: - print("Unable to establish a websocket connection: {}".format(err)) - raise - - def handshake(self, data): - # Magic string from RFC - MAGIC_STRING = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" - headers = data.split("\r\n") - - for header in headers: - if "Sec-WebSocket-Key" in header: - client_key = header.split()[1] - - if client_key: - resp_key = client_key + MAGIC_STRING - resp_key = base64.standard_b64encode(hashlib.sha1(resp_key.encode()).digest()) - - resp = "HTTP/1.1 101 Switching Protocols\r\n" + \ - "Upgrade: websocket\r\n" + \ - "Connection: Upgrade\r\n" + \ - "Sec-WebSocket-Accept: {}\r\n\r\n".format(resp_key.decode()) - - self.conn.send(resp.encode()) - - def handle_conn(self): - while not self.shutdown.is_set(): - r,w,e = select.select([self.conn], [], [], 1) - try: - if self.conn in r: - self.echo_data() - - if not self.send_q.empty(): - self._send_data_(self.send_q.get()) - - except socket.error as err: - print("Stopped echoing data: {}".format(err)) - raise - - def echo_data(self): - header = bytearray(self.conn.recv(self.HEADER_LEN, socket.MSG_WAITALL)) - if not header: - # exit if socket closed by peer - return - - # Remove mask bit - payload_len = ~(1 << 7) & header[1] - - payload = bytearray(self.conn.recv(payload_len, socket.MSG_WAITALL)) - - if not payload: - # exit if socket closed by peer - return - frame = header + payload - - decoded_payload = self.decode_frame(frame) - print("Sending echo...") - self._send_data_(decoded_payload) - - def _send_data_(self, data): - frame = self.encode_frame(data) - self.conn.send(frame) - - def send_data(self, data): - self.send_q.put(data.encode()) - - def decode_frame(self, frame): - # Mask out MASK bit from payload length, this len is only valid for short messages (<126) - payload_len = ~(1 << 7) & frame[1] - - mask = frame[2:self.HEADER_LEN] - - encrypted_payload = frame[self.HEADER_LEN:self.HEADER_LEN + payload_len] - payload = bytearray() - - for i in range(payload_len): - payload.append(encrypted_payload[i] ^ mask[i % 4]) - - return payload - - def encode_frame(self, payload): - # Set FIN = 1 and OP_CODE = 1 (text) - header = (1 << 7) | (1 << 0) - - frame = bytearray([header]) - payload_len = len(payload) - - # If payload len is longer than 125 then the next 16 bits are used to encode length - if payload_len > 125: - frame.append(126) - frame.append(payload_len >> 8) - frame.append(0xFF & payload_len) - - else: - frame.append(payload_len) - - frame += payload - - return frame + self.exit_event.set() + self.thread.join(10) + if self.thread.is_alive(): + Utility.console_log('Thread cannot be joined', 'orange') def test_echo(dut): @@ -188,6 +72,11 @@ def test_echo(dut): print("All echos received") +def test_close(dut): + code = dut.expect(re.compile(r"WEBSOCKET: Received closed message with code=(\d*)"), timeout=60)[0] + print("Received close frame with code {}".format(code)) + + def test_recv_long_msg(dut, websocket, msg_len, repeats): send_msg = ''.join(random.choice(string.ascii_uppercase + string.ascii_lowercase + string.digits) for _ in range(msg_len)) @@ -246,6 +135,7 @@ def test_examples_protocol_websocket(env, extra_data): test_echo(dut1) # Message length should exceed DUT's buffer size to test fragmentation, default is 1024 byte test_recv_long_msg(dut1, ws, 2000, 3) + test_close(dut1) else: print("DUT connecting to {}".format(uri))