Merge branch 'feature/ws_client_close_frame' into 'master'

ws_client: Added support for close frame, closing connection gracefully

Closes IDF-1915

See merge request espressif/esp-idf!9677
This commit is contained in:
David Čermák
2020-08-21 14:36:18 +08:00
12 changed files with 363 additions and 203 deletions

View File

@@ -52,6 +52,8 @@ static const char *TAG = "WEBSOCKET_CLIENT";
} }
const static int STOPPED_BIT = BIT0; const static int STOPPED_BIT = BIT0;
const static int CLOSE_FRAME_SENT_BIT = BIT1; // Indicates that a close frame was sent by the client
// and we are waiting for the server to continue with clean close
ESP_EVENT_DEFINE_BASE(WEBSOCKET_EVENTS); ESP_EVENT_DEFINE_BASE(WEBSOCKET_EVENTS);
@@ -80,6 +82,7 @@ typedef enum {
WEBSOCKET_STATE_INIT, WEBSOCKET_STATE_INIT,
WEBSOCKET_STATE_CONNECTED, WEBSOCKET_STATE_CONNECTED,
WEBSOCKET_STATE_WAIT_TIMEOUT, WEBSOCKET_STATE_WAIT_TIMEOUT,
WEBSOCKET_STATE_CLOSING,
} websocket_client_state_t; } websocket_client_state_t;
struct esp_websocket_client { struct esp_websocket_client {
@@ -493,14 +496,20 @@ static esp_err_t esp_websocket_client_recv(esp_websocket_client_handle_t client)
const char *data = (client->payload_len == 0) ? NULL : client->rx_buffer; const char *data = (client->payload_len == 0) ? NULL : client->rx_buffer;
esp_transport_ws_send_raw(client->transport, WS_TRANSPORT_OPCODES_PONG | WS_TRANSPORT_OPCODES_FIN, data, client->payload_len, esp_transport_ws_send_raw(client->transport, WS_TRANSPORT_OPCODES_PONG | WS_TRANSPORT_OPCODES_FIN, data, client->payload_len,
client->config->network_timeout_ms); client->config->network_timeout_ms);
} } else if (client->last_opcode == WS_TRANSPORT_OPCODES_PONG) {
else if (client->last_opcode == WS_TRANSPORT_OPCODES_PONG) {
client->wait_for_pong_resp = false; client->wait_for_pong_resp = false;
} else if (client->last_opcode == WS_TRANSPORT_OPCODES_CLOSE) {
ESP_LOGD(TAG, "Received close frame");
client->state = WEBSOCKET_STATE_CLOSING;
} }
return ESP_OK; return ESP_OK;
} }
static int esp_websocket_client_send_with_opcode(esp_websocket_client_handle_t client, ws_transport_opcodes_t opcode, const uint8_t *data, int len, TickType_t timeout);
static int esp_websocket_client_send_close(esp_websocket_client_handle_t client, int code, const char *additional_data, int total_len, TickType_t timeout);
static void esp_websocket_client_task(void *pv) static void esp_websocket_client_task(void *pv)
{ {
const int lock_timeout = portMAX_DELAY; const int lock_timeout = portMAX_DELAY;
@@ -520,7 +529,7 @@ static void esp_websocket_client_task(void *pv)
} }
client->state = WEBSOCKET_STATE_INIT; client->state = WEBSOCKET_STATE_INIT;
xEventGroupClearBits(client->status_bits, STOPPED_BIT); xEventGroupClearBits(client->status_bits, STOPPED_BIT | CLOSE_FRAME_SENT_BIT);
int read_select = 0; int read_select = 0;
while (client->run) { while (client->run) {
if (xSemaphoreTakeRecursive(client->lock, lock_timeout) != pdPASS) { if (xSemaphoreTakeRecursive(client->lock, lock_timeout) != pdPASS) {
@@ -550,22 +559,25 @@ static void esp_websocket_client_task(void *pv)
break; break;
case WEBSOCKET_STATE_CONNECTED: case WEBSOCKET_STATE_CONNECTED:
if (_tick_get_ms() - client->ping_tick_ms > WEBSOCKET_PING_TIMEOUT_MS) { if ((CLOSE_FRAME_SENT_BIT & xEventGroupGetBits(client->status_bits)) == 0) { // only send and check for PING
client->ping_tick_ms = _tick_get_ms(); // if closing hasn't been initiated
ESP_LOGD(TAG, "Sending PING..."); if (_tick_get_ms() - client->ping_tick_ms > WEBSOCKET_PING_TIMEOUT_MS) {
esp_transport_ws_send_raw(client->transport, WS_TRANSPORT_OPCODES_PING | WS_TRANSPORT_OPCODES_FIN, NULL, 0, client->config->network_timeout_ms); client->ping_tick_ms = _tick_get_ms();
ESP_LOGD(TAG, "Sending PING...");
esp_transport_ws_send_raw(client->transport, WS_TRANSPORT_OPCODES_PING | WS_TRANSPORT_OPCODES_FIN, NULL, 0, client->config->network_timeout_ms);
if (!client->wait_for_pong_resp && client->config->pingpong_timeout_sec) { if (!client->wait_for_pong_resp && client->config->pingpong_timeout_sec) {
client->pingpong_tick_ms = _tick_get_ms(); client->pingpong_tick_ms = _tick_get_ms();
client->wait_for_pong_resp = true; client->wait_for_pong_resp = true;
}
} }
}
if ( _tick_get_ms() - client->pingpong_tick_ms > client->config->pingpong_timeout_sec*1000 ) { if ( _tick_get_ms() - client->pingpong_tick_ms > client->config->pingpong_timeout_sec*1000 ) {
if (client->wait_for_pong_resp) { if (client->wait_for_pong_resp) {
ESP_LOGE(TAG, "Error, no PONG received for more than %d seconds after PING", client->config->pingpong_timeout_sec); ESP_LOGE(TAG, "Error, no PONG received for more than %d seconds after PING", client->config->pingpong_timeout_sec);
esp_websocket_client_abort_connection(client); esp_websocket_client_abort_connection(client);
break; break;
}
} }
} }
@@ -593,6 +605,17 @@ static void esp_websocket_client_task(void *pv)
ESP_LOGD(TAG, "Reconnecting..."); ESP_LOGD(TAG, "Reconnecting...");
} }
break; break;
case WEBSOCKET_STATE_CLOSING:
// if closing not initiated by the client echo the close message back
if ((CLOSE_FRAME_SENT_BIT & xEventGroupGetBits(client->status_bits)) == 0) {
ESP_LOGD(TAG, "Closing initiated by the server, sending close frame");
esp_transport_ws_send_raw(client->transport, WS_TRANSPORT_OPCODES_CLOSE | WS_TRANSPORT_OPCODES_FIN, NULL, 0, client->config->network_timeout_ms);
xEventGroupSetBits(client->status_bits, CLOSE_FRAME_SENT_BIT);
}
break;
default:
ESP_LOGD(TAG, "Client run iteration in a default state: %d", client->state);
break;
} }
xSemaphoreGiveRecursive(client->lock); xSemaphoreGiveRecursive(client->lock);
if (WEBSOCKET_STATE_CONNECTED == client->state) { if (WEBSOCKET_STATE_CONNECTED == client->state) {
@@ -604,6 +627,21 @@ static void esp_websocket_client_task(void *pv)
} else if (WEBSOCKET_STATE_WAIT_TIMEOUT == client->state) { } else if (WEBSOCKET_STATE_WAIT_TIMEOUT == client->state) {
// waiting for reconnecting... // waiting for reconnecting...
vTaskDelay(client->wait_timeout_ms / 2 / portTICK_RATE_MS); vTaskDelay(client->wait_timeout_ms / 2 / portTICK_RATE_MS);
} else if (WEBSOCKET_STATE_CLOSING == client->state &&
(CLOSE_FRAME_SENT_BIT & xEventGroupGetBits(client->status_bits))) {
ESP_LOGD(TAG, " Waiting for TCP connection to be closed by the server");
int ret = esp_transport_ws_poll_connection_closed(client->transport, 1000);
if (ret == 0) {
// still waiting
break;
}
if (ret < 0) {
ESP_LOGW(TAG, "Connection terminated while waiting for clean TCP close");
}
client->run = false;
client->state = WEBSOCKET_STATE_UNKNOW;
esp_websocket_client_dispatch_event(client, WEBSOCKET_EVENT_CLOSED, NULL, 0);
break;
} }
} }
@@ -626,7 +664,7 @@ esp_err_t esp_websocket_client_start(esp_websocket_client_handle_t client)
ESP_LOGE(TAG, "Error create websocket task"); ESP_LOGE(TAG, "Error create websocket task");
return ESP_FAIL; return ESP_FAIL;
} }
xEventGroupClearBits(client->status_bits, STOPPED_BIT); xEventGroupClearBits(client->status_bits, STOPPED_BIT | CLOSE_FRAME_SENT_BIT);
return ESP_OK; return ESP_OK;
} }
@@ -645,30 +683,87 @@ esp_err_t esp_websocket_client_stop(esp_websocket_client_handle_t client)
return ESP_OK; return ESP_OK;
} }
static int esp_websocket_client_send_with_opcode(esp_websocket_client_handle_t client, ws_transport_opcodes_t opcode, const char *data, int len, TickType_t timeout); static int esp_websocket_client_send_close(esp_websocket_client_handle_t client, int code, const char *additional_data, int total_len, TickType_t timeout)
{
uint8_t *close_status_data = NULL;
// RFC6455#section-5.5.1: The Close frame MAY contain a body (indicated by total_len >= 2)
if (total_len >= 2) {
close_status_data = calloc(1, total_len);
ESP_WS_CLIENT_MEM_CHECK(TAG, close_status_data, return -1);
// RFC6455#section-5.5.1: The first two bytes of the body MUST be a 2-byte representing a status
uint16_t *code_network_order = (uint16_t *) close_status_data;
*code_network_order = htons(code);
memcpy(close_status_data + 2, additional_data, total_len - 2);
}
int ret = esp_websocket_client_send_with_opcode(client, WS_TRANSPORT_OPCODES_CLOSE, close_status_data, total_len, timeout);
free(close_status_data);
return ret;
}
static esp_err_t esp_websocket_client_close_with_optional_body(esp_websocket_client_handle_t client, bool send_body, int code, const char *data, int len, TickType_t timeout)
{
if (client == NULL) {
return ESP_ERR_INVALID_ARG;
}
if (!client->run) {
ESP_LOGW(TAG, "Client was not started");
return ESP_FAIL;
}
if (send_body) {
esp_websocket_client_send_close(client, code, data, len + 2, portMAX_DELAY); // len + 2 -> always sending the code
} else {
esp_websocket_client_send_close(client, 0, NULL, 0, portMAX_DELAY); // only opcode frame
}
// Set closing bit to prevent from sending PING frames while connected
xEventGroupSetBits(client->status_bits, CLOSE_FRAME_SENT_BIT);
if (STOPPED_BIT & xEventGroupWaitBits(client->status_bits, STOPPED_BIT, false, true, timeout)) {
return ESP_OK;
}
// If could not close gracefully within timeout, stop the client and disconnect
client->run = false;
xEventGroupWaitBits(client->status_bits, STOPPED_BIT, false, true, portMAX_DELAY);
client->state = WEBSOCKET_STATE_UNKNOW;
return ESP_OK;
}
esp_err_t esp_websocket_client_close_with_code(esp_websocket_client_handle_t client, int code, const char *data, int len, TickType_t timeout)
{
return esp_websocket_client_close_with_optional_body(client, true, code, data, len, timeout);
}
esp_err_t esp_websocket_client_close(esp_websocket_client_handle_t client, TickType_t timeout)
{
return esp_websocket_client_close_with_optional_body(client, false, 0, NULL, 0, timeout);
}
int esp_websocket_client_send_text(esp_websocket_client_handle_t client, const char *data, int len, TickType_t timeout) int esp_websocket_client_send_text(esp_websocket_client_handle_t client, const char *data, int len, TickType_t timeout)
{ {
return esp_websocket_client_send_with_opcode(client, WS_TRANSPORT_OPCODES_TEXT, data, len, timeout); return esp_websocket_client_send_with_opcode(client, WS_TRANSPORT_OPCODES_TEXT, (const uint8_t *)data, len, timeout);
} }
int esp_websocket_client_send(esp_websocket_client_handle_t client, const char *data, int len, TickType_t timeout) int esp_websocket_client_send(esp_websocket_client_handle_t client, const char *data, int len, TickType_t timeout)
{ {
return esp_websocket_client_send_with_opcode(client, WS_TRANSPORT_OPCODES_BINARY, data, len, timeout); return esp_websocket_client_send_with_opcode(client, WS_TRANSPORT_OPCODES_BINARY, (const uint8_t *)data, len, timeout);
} }
int esp_websocket_client_send_bin(esp_websocket_client_handle_t client, const char *data, int len, TickType_t timeout) int esp_websocket_client_send_bin(esp_websocket_client_handle_t client, const char *data, int len, TickType_t timeout)
{ {
return esp_websocket_client_send_with_opcode(client, WS_TRANSPORT_OPCODES_BINARY, data, len, timeout); return esp_websocket_client_send_with_opcode(client, WS_TRANSPORT_OPCODES_BINARY, (const uint8_t *)data, len, timeout);
} }
static int esp_websocket_client_send_with_opcode(esp_websocket_client_handle_t client, ws_transport_opcodes_t opcode, const char *data, int len, TickType_t timeout) static int esp_websocket_client_send_with_opcode(esp_websocket_client_handle_t client, ws_transport_opcodes_t opcode, const uint8_t *data, int len, TickType_t timeout)
{ {
int need_write = len; int need_write = len;
int wlen = 0, widx = 0; int wlen = 0, widx = 0;
int ret = ESP_FAIL; int ret = ESP_FAIL;
if (client == NULL || data == NULL || len <= 0) { if (client == NULL || len < 0 ||
(opcode != WS_TRANSPORT_OPCODES_CLOSE && (data == NULL || len <= 0))) {
ESP_LOGE(TAG, "Invalid arguments"); ESP_LOGE(TAG, "Invalid arguments");
return ESP_FAIL; return ESP_FAIL;
} }
@@ -688,7 +783,7 @@ static int esp_websocket_client_send_with_opcode(esp_websocket_client_handle_t c
goto unlock_and_return; goto unlock_and_return;
} }
uint32_t current_opcode = opcode; uint32_t current_opcode = opcode;
while (widx < len) { while (widx < len || current_opcode) { // allow for sending "current_opcode" only message with len==0
if (need_write > client->buffer_size) { if (need_write > client->buffer_size) {
need_write = client->buffer_size; need_write = client->buffer_size;
} else { } else {
@@ -698,7 +793,7 @@ static int esp_websocket_client_send_with_opcode(esp_websocket_client_handle_t c
// send with ws specific way and specific opcode // send with ws specific way and specific opcode
wlen = esp_transport_ws_send_raw(client->transport, current_opcode, (char *)client->tx_buffer, need_write, wlen = esp_transport_ws_send_raw(client->transport, current_opcode, (char *)client->tx_buffer, need_write,
(timeout==portMAX_DELAY)? -1 : timeout * portTICK_PERIOD_MS); (timeout==portMAX_DELAY)? -1 : timeout * portTICK_PERIOD_MS);
if (wlen <= 0) { if (wlen < 0 || (wlen == 0 && need_write != 0)) {
ret = wlen; ret = wlen;
ESP_LOGE(TAG, "Network error: esp_transport_write() returned %d, errno=%d", ret, errno); ESP_LOGE(TAG, "Network error: esp_transport_write() returned %d, errno=%d", ret, errno);
esp_websocket_client_abort_connection(client); esp_websocket_client_abort_connection(client);

View File

@@ -40,6 +40,7 @@ typedef enum {
WEBSOCKET_EVENT_CONNECTED, /*!< Once the Websocket has been connected to the server, no data exchange has been performed */ WEBSOCKET_EVENT_CONNECTED, /*!< Once the Websocket has been connected to the server, no data exchange has been performed */
WEBSOCKET_EVENT_DISCONNECTED, /*!< The connection has been disconnected */ WEBSOCKET_EVENT_DISCONNECTED, /*!< The connection has been disconnected */
WEBSOCKET_EVENT_DATA, /*!< When receiving data from the server, possibly multiple portions of the packet */ WEBSOCKET_EVENT_DATA, /*!< When receiving data from the server, possibly multiple portions of the packet */
WEBSOCKET_EVENT_CLOSED, /*!< The connection has been closed cleanly */
WEBSOCKET_EVENT_MAX WEBSOCKET_EVENT_MAX
} esp_websocket_event_id_t; } esp_websocket_event_id_t;
@@ -125,7 +126,11 @@ esp_err_t esp_websocket_client_set_uri(esp_websocket_client_handle_t client, con
esp_err_t esp_websocket_client_start(esp_websocket_client_handle_t client); esp_err_t esp_websocket_client_start(esp_websocket_client_handle_t client);
/** /**
* @brief Close the WebSocket connection * @brief Stops the WebSocket connection without websocket closing handshake
*
* This API stops ws client and closes TCP connection directly without sending
* close frames. It is a good practice to close the connection in a clean way
* using esp_websocket_client_close().
* *
* @param[in] client The client * @param[in] client The client
* *
@@ -187,6 +192,36 @@ int esp_websocket_client_send_bin(esp_websocket_client_handle_t client, const ch
*/ */
int esp_websocket_client_send_text(esp_websocket_client_handle_t client, const char *data, int len, TickType_t timeout); int esp_websocket_client_send_text(esp_websocket_client_handle_t client, const char *data, int len, TickType_t timeout);
/**
* @brief Close the WebSocket connection in a clean way
*
* Sequence of clean close initiated by client:
* * Client sends CLOSE frame
* * Client waits until server echos the CLOSE frame
* * Client waits until server closes the connection
* * Client is stopped the same way as by the `esp_websocket_client_stop()`
*
* @param[in] client The client
* @param[in] timeout Timeout in RTOS ticks for waiting
*
* @return esp_err_t
*/
esp_err_t esp_websocket_client_close(esp_websocket_client_handle_t client, TickType_t timeout);
/**
* @brief Close the WebSocket connection in a clean way with custom code/data
* Closing sequence is the same as for esp_websocket_client_close()
*
* @param[in] client The client
* @param[in] code Close status code as defined in RFC6455 section-7.4
* @param[in] data Additional data to closing message
* @param[in] len The length of the additional data
* @param[in] timeout Timeout in RTOS ticks for waiting
*
* @return esp_err_t
*/
esp_err_t esp_websocket_client_close_with_code(esp_websocket_client_handle_t client, int code, const char *data, int len, TickType_t timeout);
/** /**
* @brief Check the WebSocket client connection state * @brief Check the WebSocket client connection state
* *

View File

@@ -310,7 +310,7 @@ esp_err_t esp_transport_set_parent_transport_func(esp_transport_handle_t t, payl
* @return * @return
* - valid pointer of esp_error_handle_t * - valid pointer of esp_error_handle_t
* - NULL if invalid transport handle * - NULL if invalid transport handle
*/ */
esp_tls_error_handle_t esp_transport_get_error_handle(esp_transport_handle_t t); esp_tls_error_handle_t esp_transport_get_error_handle(esp_transport_handle_t t);

View File

@@ -117,6 +117,21 @@ ws_transport_opcodes_t esp_transport_ws_get_read_opcode(esp_transport_handle_t t
*/ */
int esp_transport_ws_get_read_payload_len(esp_transport_handle_t t); int esp_transport_ws_get_read_payload_len(esp_transport_handle_t t);
/**
* @brief Polls the active connection for termination
*
* This API is typically used by the client to wait for clean connection closure
* by websocket server
*
* @param t Websocket transport handle
* @param[in] timeout_ms The timeout milliseconds
*
* @return
* 0 - no activity on read and error socket descriptor within timeout
* 1 - Success: either connection terminated by FIN or the most common RST err codes
* -1 - Failure: Unexpected error code or socket is normally readable
*/
int esp_transport_ws_poll_connection_closed(esp_transport_handle_t t, int timeout_ms);
#ifdef __cplusplus #ifdef __cplusplus
} }

View File

@@ -0,0 +1,56 @@
// Copyright 2020 Espressif Systems (Shanghai) PTE LTD
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef _ESP_TRANSPORT_INTERNAL_H_
#define _ESP_TRANSPORT_INTERNAL_H_
#include "esp_transport.h"
#include "sys/queue.h"
typedef int (*get_socket_func)(esp_transport_handle_t t);
/**
* Transport layer structure, which will provide functions, basic properties for transport types
*/
struct esp_transport_item_t {
int port;
char *scheme; /*!< Tag name */
void *data; /*!< Additional transport data */
connect_func _connect; /*!< Connect function of this transport */
io_read_func _read; /*!< Read */
io_func _write; /*!< Write */
trans_func _close; /*!< Close */
poll_func _poll_read; /*!< Poll and read */
poll_func _poll_write; /*!< Poll and write */
trans_func _destroy; /*!< Destroy and free transport */
connect_async_func _connect_async; /*!< non-blocking connect function of this transport */
payload_transfer_func _parent_transfer; /*!< Function returning underlying transport layer */
get_socket_func _get_socket;
esp_tls_error_handle_t error_handle; /*!< Pointer to esp-tls error handle */
STAILQ_ENTRY(esp_transport_item_t) next;
};
/**
* @brief Returns underlying socket for the supplied transport handle
*
* @param t Transport handle
*
* @return Socket file descriptor in case of success
* -1 in case of error
*/
int esp_transport_get_socket(esp_transport_handle_t t);
#endif //_ESP_TRANSPORT_INTERNAL_H_

View File

@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#ifndef _ESP_TRANSPORT_INTERNAL_H_ #ifndef _ESP_TRANSPORT_SSL_INTERNAL_H_
#define _ESP_TRANSPORT_INTERNAL_H_ #define _ESP_TRANSPORT_SSL_INTERNAL_H_
/** /**
* @brief Sets error to common transport handle * @brief Sets error to common transport handle
@@ -27,4 +27,4 @@
void esp_transport_set_errors(esp_transport_handle_t t, const esp_tls_error_handle_t error_handle); void esp_transport_set_errors(esp_transport_handle_t t, const esp_tls_error_handle_t error_handle);
#endif /* _ESP_TRANSPORT_INTERNAL_H_ */ #endif /* _ESP_TRANSPORT_SSL_INTERNAL_H_ */

View File

@@ -21,32 +21,11 @@
#include "esp_log.h" #include "esp_log.h"
#include "esp_transport.h" #include "esp_transport.h"
#include "esp_transport_internal.h"
#include "esp_transport_utils.h" #include "esp_transport_utils.h"
static const char *TAG = "TRANSPORT"; static const char *TAG = "TRANSPORT";
/**
* Transport layer structure, which will provide functions, basic properties for transport types
*/
struct esp_transport_item_t {
int port;
int socket; /*!< Socket to use in this transport */
char *scheme; /*!< Tag name */
void *context; /*!< Context data */
void *data; /*!< Additional transport data */
connect_func _connect; /*!< Connect function of this transport */
io_read_func _read; /*!< Read */
io_func _write; /*!< Write */
trans_func _close; /*!< Close */
poll_func _poll_read; /*!< Poll and read */
poll_func _poll_write; /*!< Poll and write */
trans_func _destroy; /*!< Destroy and free transport */
connect_async_func _connect_async; /*!< non-blocking connect function of this transport */
payload_transfer_func _parent_transfer; /*!< Function returning underlying transport layer */
esp_tls_error_handle_t error_handle; /*!< Pointer to esp-tls error handle */
STAILQ_ENTRY(esp_transport_item_t) next;
};
/** /**
@@ -305,4 +284,12 @@ void esp_transport_set_errors(esp_transport_handle_t t, const esp_tls_error_hand
if (t) { if (t) {
memcpy(t->error_handle, error_handle, sizeof(esp_tls_last_error_t)); memcpy(t->error_handle, error_handle, sizeof(esp_tls_last_error_t));
} }
} }
int esp_transport_get_socket(esp_transport_handle_t t)
{
if (t && t->_get_socket) {
return t->_get_socket(t);
}
return -1;
}

View File

@@ -25,6 +25,7 @@
#include "esp_transport_ssl.h" #include "esp_transport_ssl.h"
#include "esp_transport_utils.h" #include "esp_transport_utils.h"
#include "esp_transport_ssl_internal.h" #include "esp_transport_ssl_internal.h"
#include "esp_transport_internal.h"
static const char *TAG = "TRANS_SSL"; static const char *TAG = "TRANS_SSL";
@@ -288,6 +289,17 @@ void esp_transport_ssl_use_secure_element(esp_transport_handle_t t)
} }
} }
static int ssl_get_socket(esp_transport_handle_t t)
{
if (t) {
transport_ssl_t *ssl = t->data;
if (ssl && ssl->tls) {
return ssl->tls->sockfd;
}
}
return -1;
}
esp_transport_handle_t esp_transport_ssl_init(void) esp_transport_handle_t esp_transport_ssl_init(void)
{ {
esp_transport_handle_t t = esp_transport_init(); esp_transport_handle_t t = esp_transport_init();
@@ -296,6 +308,7 @@ esp_transport_handle_t esp_transport_ssl_init(void)
esp_transport_set_context_data(t, ssl); esp_transport_set_context_data(t, ssl);
esp_transport_set_func(t, ssl_connect, ssl_read, ssl_write, ssl_close, ssl_poll_read, ssl_poll_write, ssl_destroy); esp_transport_set_func(t, ssl_connect, ssl_read, ssl_write, ssl_close, ssl_poll_read, ssl_poll_write, ssl_destroy);
esp_transport_set_async_connect_func(t, ssl_connect_async); esp_transport_set_async_connect_func(t, ssl_connect_async);
t->_get_socket = ssl_get_socket;
return t; return t;
} }

View File

@@ -25,6 +25,7 @@
#include "esp_transport_utils.h" #include "esp_transport_utils.h"
#include "esp_transport.h" #include "esp_transport.h"
#include "esp_transport_internal.h"
static const char *TAG = "TRANS_TCP"; static const char *TAG = "TRANS_TCP";
@@ -234,6 +235,17 @@ static esp_err_t tcp_destroy(esp_transport_handle_t t)
return 0; return 0;
} }
static int tcp_get_socket(esp_transport_handle_t t)
{
if (t) {
transport_tcp_t *tcp = t->data;
if (tcp) {
return tcp->sock;
}
}
return -1;
}
esp_transport_handle_t esp_transport_tcp_init(void) esp_transport_handle_t esp_transport_tcp_init(void)
{ {
esp_transport_handle_t t = esp_transport_init(); esp_transport_handle_t t = esp_transport_init();
@@ -242,6 +254,7 @@ esp_transport_handle_t esp_transport_tcp_init(void)
tcp->sock = -1; tcp->sock = -1;
esp_transport_set_func(t, tcp_connect, tcp_read, tcp_write, tcp_close, tcp_poll_read, tcp_poll_write, tcp_destroy); esp_transport_set_func(t, tcp_connect, tcp_read, tcp_write, tcp_close, tcp_poll_read, tcp_poll_write, tcp_destroy);
esp_transport_set_context_data(t, tcp); esp_transport_set_context_data(t, tcp);
t->_get_socket = tcp_get_socket;
return t; return t;
} }

View File

@@ -2,6 +2,7 @@
#include <string.h> #include <string.h>
#include <ctype.h> #include <ctype.h>
#include <sys/random.h> #include <sys/random.h>
#include <sys/socket.h>
#include "esp_log.h" #include "esp_log.h"
#include "esp_transport.h" #include "esp_transport.h"
#include "esp_transport_tcp.h" #include "esp_transport_tcp.h"
@@ -9,6 +10,8 @@
#include "esp_transport_utils.h" #include "esp_transport_utils.h"
#include "mbedtls/base64.h" #include "mbedtls/base64.h"
#include "mbedtls/sha1.h" #include "mbedtls/sha1.h"
#include "esp_transport_internal.h"
#include "errno.h"
static const char *TAG = "TRANSPORT_WS"; static const char *TAG = "TRANSPORT_WS";
@@ -449,6 +452,17 @@ void esp_transport_ws_set_path(esp_transport_handle_t t, const char *path)
strcpy(ws->path, path); strcpy(ws->path, path);
} }
static int ws_get_socket(esp_transport_handle_t t)
{
if (t) {
transport_ws_t *ws = t->data;
if (ws && ws->parent && ws->parent->_get_socket) {
return ws->parent->_get_socket(ws->parent);
}
}
return -1;
}
esp_transport_handle_t esp_transport_ws_init(esp_transport_handle_t parent_handle) esp_transport_handle_t esp_transport_ws_init(esp_transport_handle_t parent_handle)
{ {
esp_transport_handle_t t = esp_transport_init(); esp_transport_handle_t t = esp_transport_init();
@@ -473,6 +487,7 @@ esp_transport_handle_t esp_transport_ws_init(esp_transport_handle_t parent_handl
esp_transport_set_parent_transport_func(t, ws_get_payload_transport_handle); esp_transport_set_parent_transport_func(t, ws_get_payload_transport_handle);
esp_transport_set_context_data(t, ws); esp_transport_set_context_data(t, ws);
t->_get_socket = ws_get_socket;
return t; return t;
} }
@@ -548,4 +563,41 @@ int esp_transport_ws_get_read_payload_len(esp_transport_handle_t t)
return ws->frame_state.payload_len; return ws->frame_state.payload_len;
} }
int esp_transport_ws_poll_connection_closed(esp_transport_handle_t t, int timeout_ms)
{
struct timeval timeout;
int sock = esp_transport_get_socket(t);
fd_set readset;
fd_set errset;
FD_ZERO(&readset);
FD_ZERO(&errset);
FD_SET(sock, &readset);
FD_SET(sock, &errset);
int ret = select(sock + 1, &readset, NULL, &errset, esp_transport_utils_ms_to_timeval(timeout_ms, &timeout));
if (ret > 0) {
if (FD_ISSET(sock, &readset)) {
uint8_t buffer;
if (recv(sock, &buffer, 1, MSG_PEEK) <= 0) {
// socket is readable, but reads zero bytes -- connection cleanly closed by FIN flag
return 1;
}
ESP_LOGW(TAG, "esp_transport_ws_poll_connection_closed: unexpected data readable on socket=%d", sock);
} else if (FD_ISSET(sock, &errset)) {
int sock_errno = 0;
uint32_t optlen = sizeof(sock_errno);
getsockopt(sock, SOL_SOCKET, SO_ERROR, &sock_errno, &optlen);
ESP_LOGD(TAG, "esp_transport_ws_poll_connection_closed select error %d, errno = %s, fd = %d", sock_errno, strerror(sock_errno), sock);
if (sock_errno == ENOTCONN || sock_errno == ECONNRESET || sock_errno == ECONNABORTED) {
// the three err codes above might be caused by connection termination by RTS flag
// which we still assume as expected closing sequence of ws-transport connection
return 1;
}
ESP_LOGE(TAG, "esp_transport_ws_poll_connection_closed: unexpected errno=%d on socket=%d", sock_errno, sock);
}
return -1; // indicates error: socket unexpectedly reads an actual data, or unexpected errno code
}
return ret;
}

View File

@@ -3,12 +3,10 @@ from __future__ import unicode_literals
import re import re
import os import os
import socket import socket
import select
import hashlib
import base64
import queue
import random import random
import string import string
from SimpleWebSocketServer import SimpleWebSocketServer, WebSocket
from tiny_test_fw import Utility
from threading import Thread, Event from threading import Thread, Event
import ttfw_idf import ttfw_idf
@@ -26,159 +24,45 @@ def get_my_ip():
return IP return IP
class TestEcho(WebSocket):
def handleMessage(self):
self.sendMessage(self.data)
print('Server sent: {}'.format(self.data))
def handleConnected(self):
print('Connection from: {}'.format(self.address))
def handleClose(self):
print('{} closed the connection'.format(self.address))
# Simple Websocket server for testing purposes # Simple Websocket server for testing purposes
class Websocket: class Websocket(object):
HEADER_LEN = 6
def send_data(self, data):
for nr, conn in self.server.connections.items():
conn.sendMessage(data)
def run(self):
self.server = SimpleWebSocketServer('', self.port, TestEcho)
while not self.exit_event.is_set():
self.server.serveonce()
def __init__(self, port): def __init__(self, port):
self.port = port self.port = port
self.socket = socket.socket() self.exit_event = Event()
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) self.thread = Thread(target=self.run)
self.socket.settimeout(10.0) self.thread.start()
self.send_q = queue.Queue()
self.shutdown = Event()
def __enter__(self): def __enter__(self):
try:
self.socket.bind(('', self.port))
except socket.error as e:
print("Bind failed:{}".format(e))
raise
self.socket.listen(1)
self.server_thread = Thread(target=self.run_server)
self.server_thread.start()
return self return self
def __exit__(self, exc_type, exc_value, traceback): def __exit__(self, exc_type, exc_value, traceback):
self.shutdown.set() self.exit_event.set()
self.server_thread.join() self.thread.join(10)
self.socket.close() if self.thread.is_alive():
self.conn.close() Utility.console_log('Thread cannot be joined', 'orange')
def run_server(self):
self.conn, address = self.socket.accept() # accept new connection
self.socket.settimeout(10.0)
print("Connection from: {}".format(address))
self.establish_connection()
print("WS established")
# Handle connection until client closes it, will echo any data received and send data from send_q queue
self.handle_conn()
def establish_connection(self):
while not self.shutdown.is_set():
try:
# receive data stream. it won't accept data packet greater than 1024 bytes
data = self.conn.recv(1024).decode()
if not data:
# exit if data is not received
raise
if "Upgrade: websocket" in data and "Connection: Upgrade" in data:
self.handshake(data)
return
except socket.error as err:
print("Unable to establish a websocket connection: {}".format(err))
raise
def handshake(self, data):
# Magic string from RFC
MAGIC_STRING = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
headers = data.split("\r\n")
for header in headers:
if "Sec-WebSocket-Key" in header:
client_key = header.split()[1]
if client_key:
resp_key = client_key + MAGIC_STRING
resp_key = base64.standard_b64encode(hashlib.sha1(resp_key.encode()).digest())
resp = "HTTP/1.1 101 Switching Protocols\r\n" + \
"Upgrade: websocket\r\n" + \
"Connection: Upgrade\r\n" + \
"Sec-WebSocket-Accept: {}\r\n\r\n".format(resp_key.decode())
self.conn.send(resp.encode())
def handle_conn(self):
while not self.shutdown.is_set():
r,w,e = select.select([self.conn], [], [], 1)
try:
if self.conn in r:
self.echo_data()
if not self.send_q.empty():
self._send_data_(self.send_q.get())
except socket.error as err:
print("Stopped echoing data: {}".format(err))
raise
def echo_data(self):
header = bytearray(self.conn.recv(self.HEADER_LEN, socket.MSG_WAITALL))
if not header:
# exit if socket closed by peer
return
# Remove mask bit
payload_len = ~(1 << 7) & header[1]
payload = bytearray(self.conn.recv(payload_len, socket.MSG_WAITALL))
if not payload:
# exit if socket closed by peer
return
frame = header + payload
decoded_payload = self.decode_frame(frame)
print("Sending echo...")
self._send_data_(decoded_payload)
def _send_data_(self, data):
frame = self.encode_frame(data)
self.conn.send(frame)
def send_data(self, data):
self.send_q.put(data.encode())
def decode_frame(self, frame):
# Mask out MASK bit from payload length, this len is only valid for short messages (<126)
payload_len = ~(1 << 7) & frame[1]
mask = frame[2:self.HEADER_LEN]
encrypted_payload = frame[self.HEADER_LEN:self.HEADER_LEN + payload_len]
payload = bytearray()
for i in range(payload_len):
payload.append(encrypted_payload[i] ^ mask[i % 4])
return payload
def encode_frame(self, payload):
# Set FIN = 1 and OP_CODE = 1 (text)
header = (1 << 7) | (1 << 0)
frame = bytearray([header])
payload_len = len(payload)
# If payload len is longer than 125 then the next 16 bits are used to encode length
if payload_len > 125:
frame.append(126)
frame.append(payload_len >> 8)
frame.append(0xFF & payload_len)
else:
frame.append(payload_len)
frame += payload
return frame
def test_echo(dut): def test_echo(dut):
@@ -188,6 +72,11 @@ def test_echo(dut):
print("All echos received") print("All echos received")
def test_close(dut):
code = dut.expect(re.compile(r"WEBSOCKET: Received closed message with code=(\d*)"), timeout=60)[0]
print("Received close frame with code {}".format(code))
def test_recv_long_msg(dut, websocket, msg_len, repeats): def test_recv_long_msg(dut, websocket, msg_len, repeats):
send_msg = ''.join(random.choice(string.ascii_uppercase + string.ascii_lowercase + string.digits) for _ in range(msg_len)) send_msg = ''.join(random.choice(string.ascii_uppercase + string.ascii_lowercase + string.digits) for _ in range(msg_len))
@@ -246,6 +135,7 @@ def test_examples_protocol_websocket(env, extra_data):
test_echo(dut1) test_echo(dut1)
# Message length should exceed DUT's buffer size to test fragmentation, default is 1024 byte # Message length should exceed DUT's buffer size to test fragmentation, default is 1024 byte
test_recv_long_msg(dut1, ws, 2000, 3) test_recv_long_msg(dut1, ws, 2000, 3)
test_close(dut1)
else: else:
print("DUT connecting to {}".format(uri)) print("DUT connecting to {}".format(uri))

View File

@@ -69,7 +69,11 @@ static void websocket_event_handler(void *handler_args, esp_event_base_t base, i
case WEBSOCKET_EVENT_DATA: case WEBSOCKET_EVENT_DATA:
ESP_LOGI(TAG, "WEBSOCKET_EVENT_DATA"); ESP_LOGI(TAG, "WEBSOCKET_EVENT_DATA");
ESP_LOGI(TAG, "Received opcode=%d", data->op_code); ESP_LOGI(TAG, "Received opcode=%d", data->op_code);
ESP_LOGW(TAG, "Received=%.*s", data->data_len, (char *)data->data_ptr); if (data->op_code == 0x08 && data->data_len == 2) {
ESP_LOGW(TAG, "Received closed message with code=%d", 256*data->data_ptr[0] + data->data_ptr[1]);
} else {
ESP_LOGW(TAG, "Received=%.*s", data->data_len, (char *)data->data_ptr);
}
ESP_LOGW(TAG, "Total payload length=%d, data_len=%d, current payload offset=%d\r\n", data->payload_len, data->data_len, data->payload_offset); ESP_LOGW(TAG, "Total payload length=%d, data_len=%d, current payload offset=%d\r\n", data->payload_len, data->data_len, data->payload_offset);
xTimerReset(shutdown_signal_timer, portMAX_DELAY); xTimerReset(shutdown_signal_timer, portMAX_DELAY);
@@ -121,7 +125,7 @@ static void websocket_app_start(void)
} }
xSemaphoreTake(shutdown_sema, portMAX_DELAY); xSemaphoreTake(shutdown_sema, portMAX_DELAY);
esp_websocket_client_stop(client); esp_websocket_client_close(client, portMAX_DELAY);
ESP_LOGI(TAG, "Websocket Stopped"); ESP_LOGI(TAG, "Websocket Stopped");
esp_websocket_client_destroy(client); esp_websocket_client_destroy(client);
} }