diff --git a/components/esp_websocket_client/esp_websocket_client.c b/components/esp_websocket_client/esp_websocket_client.c index cc77267d0a..dc73334190 100644 --- a/components/esp_websocket_client/esp_websocket_client.c +++ b/components/esp_websocket_client/esp_websocket_client.c @@ -63,6 +63,7 @@ typedef struct { bool auto_reconnect; void *user_context; int network_timeout_ms; + char *subprotocol; } websocket_config_storage_t; typedef enum { @@ -172,6 +173,11 @@ static esp_err_t esp_websocket_client_set_config(esp_websocket_client_handle_t c cfg->path = strdup(config->path); ESP_WS_CLIENT_MEM_CHECK(TAG, cfg->path, return ESP_ERR_NO_MEM); } + if (config->subprotocol) { + free(cfg->subprotocol); + cfg->subprotocol = strdup(config->subprotocol); + ESP_WS_CLIENT_MEM_CHECK(TAG, cfg->subprotocol, return ESP_ERR_NO_MEM); + } cfg->network_timeout_ms = WEBSOCKET_NETWORK_TIMEOUT_MS; cfg->user_context = config->user_context; @@ -199,12 +205,23 @@ static esp_err_t esp_websocket_client_destroy_config(esp_websocket_client_handle free(cfg->scheme); free(cfg->username); free(cfg->password); + free(cfg->subprotocol); memset(cfg, 0, sizeof(websocket_config_storage_t)); free(client->config); client->config = NULL; return ESP_OK; } +static void set_websocket_transport_optional_settings(esp_websocket_client_handle_t client, esp_transport_handle_t trans) +{ + if (trans && client->config->path) { + esp_transport_ws_set_path(trans, client->config->path); + } + if (trans && client->config->subprotocol) { + esp_transport_ws_set_subprotocol(trans, client->config->subprotocol); + } +} + esp_websocket_client_handle_t esp_websocket_client_init(const esp_websocket_client_config_t *config) { esp_websocket_client_handle_t client = calloc(1, sizeof(struct esp_websocket_client)); @@ -224,6 +241,9 @@ esp_websocket_client_handle_t esp_websocket_client_init(const esp_websocket_clie client->lock = xSemaphoreCreateMutex(); ESP_WS_CLIENT_MEM_CHECK(TAG, client->lock, goto _websocket_init_fail); + client->config = calloc(1, sizeof(websocket_config_storage_t)); + ESP_WS_CLIENT_MEM_CHECK(TAG, client->config, goto _websocket_init_fail); + client->transport_list = esp_transport_list_init(); ESP_WS_CLIENT_MEM_CHECK(TAG, client->transport_list, goto _websocket_init_fail); @@ -259,14 +279,11 @@ esp_websocket_client_handle_t esp_websocket_client_init(const esp_websocket_clie esp_transport_set_default_port(wss, WEBSOCKET_SSL_DEFAULT_PORT); esp_transport_list_add(client->transport_list, wss, "wss"); - if (config->transport == WEBSOCKET_TRANSPORT_OVER_TCP) { + if (config->transport == WEBSOCKET_TRANSPORT_OVER_SSL) { asprintf(&client->config->scheme, "wss"); ESP_WS_CLIENT_MEM_CHECK(TAG, client->config->scheme, goto _websocket_init_fail); } - client->config = calloc(1, sizeof(websocket_config_storage_t)); - ESP_WS_CLIENT_MEM_CHECK(TAG, client->config, goto _websocket_init_fail); - if (config->uri) { if (esp_websocket_client_set_uri(client, config->uri) != ESP_OK) { ESP_LOGE(TAG, "Invalid uri"); @@ -284,6 +301,9 @@ esp_websocket_client_handle_t esp_websocket_client_init(const esp_websocket_clie ESP_WS_CLIENT_MEM_CHECK(TAG, client->config->scheme, goto _websocket_init_fail); } + set_websocket_transport_optional_settings(client, esp_transport_list_get_transport(client->transport_list, "ws")); + set_websocket_transport_optional_settings(client, esp_transport_list_get_transport(client->transport_list, "wss")); + client->keepalive_tick_ms = _tick_get_ms(); client->reconnect_tick_ms = _tick_get_ms(); client->ping_tick_ms = _tick_get_ms(); @@ -366,15 +386,6 @@ esp_err_t esp_websocket_client_set_uri(esp_websocket_client_handle_t client, con free(client->config->path); asprintf(&client->config->path, "%.*s", puri.field_data[UF_PATH].len, uri + puri.field_data[UF_PATH].off); ESP_WS_CLIENT_MEM_CHECK(TAG, client->config->path, return ESP_ERR_NO_MEM); - - esp_transport_handle_t trans = esp_transport_list_get_transport(client->transport_list, "ws"); - if (trans) { - esp_transport_ws_set_path(trans, client->config->path); - } - trans = esp_transport_list_get_transport(client->transport_list, "wss"); - if (trans) { - esp_transport_ws_set_path(trans, client->config->path); - } } if (puri.field_data[UF_PORT].off) { client->config->port = strtol((const char*)(uri + puri.field_data[UF_PORT].off), NULL, 10); diff --git a/components/esp_websocket_client/include/esp_websocket_client.h b/components/esp_websocket_client/include/esp_websocket_client.h index a8bcc5e2f4..4f1e0a3fc0 100644 --- a/components/esp_websocket_client/include/esp_websocket_client.h +++ b/components/esp_websocket_client/include/esp_websocket_client.h @@ -91,6 +91,7 @@ typedef struct { int buffer_size; /*!< Websocket buffer size */ const char *cert_pem; /*!< SSL Certification, PEM format as string, if the client requires to verify server */ esp_websocket_transport_t transport; /*!< Websocket transport type, see `esp_websocket_transport_t */ + char *subprotocol; /*!< Websocket subprotocol */ } esp_websocket_client_config_t; /** diff --git a/components/tcp_transport/include/esp_transport_ws.h b/components/tcp_transport/include/esp_transport_ws.h index f47fd049cf..bb4f0bc013 100644 --- a/components/tcp_transport/include/esp_transport_ws.h +++ b/components/tcp_transport/include/esp_transport_ws.h @@ -43,7 +43,6 @@ void esp_transport_ws_set_path(esp_transport_handle_t t, const char *path); */ esp_err_t esp_transport_ws_set_subprotocol(esp_transport_handle_t t, const char *sub_protocol); - #ifdef __cplusplus } #endif diff --git a/components/tcp_transport/transport_ws.c b/components/tcp_transport/transport_ws.c index 886f2a31e5..c8f3da31c6 100644 --- a/components/tcp_transport/transport_ws.c +++ b/components/tcp_transport/transport_ws.c @@ -188,18 +188,17 @@ static int _ws_write(esp_transport_handle_t t, int opcode, int mask_flag, const ws_header[header_len++] = (uint8_t)((len >> 8) & 0xFF); ws_header[header_len++] = (uint8_t)((len >> 0) & 0xFF); } - if (len) { - if (mask_flag) { - mask = &ws_header[header_len]; - getrandom(ws_header + header_len, 4, 0); - header_len += 4; - for (i = 0; i < len; ++i) { - buffer[i] = (buffer[i] ^ mask[i % 4]); - } + if (mask_flag) { + mask = &ws_header[header_len]; + getrandom(ws_header + header_len, 4, 0); + header_len += 4; + + for (i = 0; i < len; ++i) { + buffer[i] = (buffer[i] ^ mask[i % 4]); } - } + if (esp_transport_write(ws->parent, ws_header, header_len, timeout_ms) != header_len) { ESP_LOGE(TAG, "Error write header"); return -1; @@ -224,7 +223,7 @@ static int ws_write(esp_transport_handle_t t, const char *b, int len, int timeou { if (len == 0) { ESP_LOGD(TAG, "Write PING message"); - return _ws_write(t, WS_OPCODE_PING | WS_FIN, 0, NULL, 0, timeout_ms); + return _ws_write(t, WS_OPCODE_PING | WS_FIN, WS_MASK, NULL, 0, timeout_ms); } return _ws_write(t, WS_OPCODE_BINARY | WS_FIN, WS_MASK, b, len, timeout_ms); } @@ -282,7 +281,7 @@ static int ws_read(esp_transport_handle_t t, char *buffer, int len, int timeout_ } // Then receive and process payload - if ((rlen = esp_transport_read(ws->parent, buffer, payload_len, timeout_ms)) <= 0) { + if (payload_len != 0 && (rlen = esp_transport_read(ws->parent, buffer, payload_len, timeout_ms)) <= 0) { ESP_LOGE(TAG, "Error read data"); return rlen; }