diff --git a/components/tcp_transport/include/esp_transport_ws.h b/components/tcp_transport/include/esp_transport_ws.h index c8997378eb..f4be6b3db4 100644 --- a/components/tcp_transport/include/esp_transport_ws.h +++ b/components/tcp_transport/include/esp_transport_ws.h @@ -29,6 +29,8 @@ typedef enum ws_transport_opcodes { * from the API esp_transport_ws_get_read_opcode() */ } ws_transport_opcodes_t; +typedef void (*ws_header_hook)(void * userp, const char * line, int line_len); + /** * WS transport configuration structure */ @@ -37,6 +39,8 @@ typedef struct { const char *sub_protocol; /*!< WS subprotocol */ const char *user_agent; /*!< WS user agent */ const char *headers; /*!< WS additional headers */ + ws_header_hook header_hook; /*!< WS received header */ + void *header_userp; /*!< WS received header user-pointer */ const char *auth; /*!< HTTP authorization header */ char *response_headers; /*!< The buffer to copy the http response header */ size_t response_headers_len; /*!< The length of the http response header */ @@ -99,6 +103,31 @@ esp_err_t esp_transport_ws_set_user_agent(esp_transport_handle_t t, const char * */ esp_err_t esp_transport_ws_set_headers(esp_transport_handle_t t, const char *headers); +/** + * @brief Set websocket header callback + * + * @param t websocket transport handle + * @param hook call function on header received. NULL to disable. + * + * @return + * - ESP_OK on success + * - One of the error codes + */ +esp_err_t esp_transport_ws_set_header_hook(esp_transport_handle_t t, ws_header_hook hook); + + +/** + * @brief Set websocket header callback user-pointer + * + * @param t websocket transport handle + * @param userp caller-controlled argument to ws_header_hook + * + * @return + * - ESP_OK on success + * - One of the error codes + */ +esp_err_t esp_transport_ws_set_header_userp(esp_transport_handle_t t, void * userp); + /** * @brief Set websocket authorization headers * diff --git a/components/tcp_transport/transport_ws.c b/components/tcp_transport/transport_ws.c index 17ff4e19ee..c683befb61 100644 --- a/components/tcp_transport/transport_ws.c +++ b/components/tcp_transport/transport_ws.c @@ -63,6 +63,8 @@ typedef struct { char *sub_protocol; char *user_agent; char *headers; + ws_header_hook header_hook; + void * header_userp; char *auth; char *buffer; /*!< Initial HTTP connection buffer, which may include data beyond the handshake headers, such as the next WebSocket packet*/ size_t buffer_len; /*!< The buffer length */ @@ -144,31 +146,6 @@ static int esp_transport_read_internal(transport_ws_t *ws, char *buffer, int len return to_read; } -static char *trimwhitespace(char *str) -{ - char *end; - - // Trim leading space - while (isspace((unsigned char)*str)) { - str++; - } - - if (*str == 0) { - return str; - } - - // Trim trailing space - end = str + strlen(str) - 1; - while (end > str && isspace((unsigned char)*end)) { - end--; - } - - // Write new null terminator - *(end + 1) = '\0'; - - return str; -} - static int get_http_status_code(const char *buffer) { const char http[] = "HTTP/"; @@ -189,21 +166,6 @@ static int get_http_status_code(const char *buffer) return -1; } -static char *get_http_header(char *buffer, const char *key) -{ - char *found = strcasestr(buffer, key); - if (found) { - found += strlen(key); - char *found_end = strstr(found, "\r\n"); - if (found_end) { - *found_end = '\0'; // terminal string - - return trimwhitespace(found); - } - } - return NULL; -} - static int ws_connect(esp_transport_handle_t t, const char *host, int port, int timeout_ms) { transport_ws_t *ws = esp_transport_get_context_data(t); @@ -330,17 +292,67 @@ static int ws_connect(esp_transport_handle_t t, const char *host, int port, int if (ws->http_status_code == -1) { ESP_LOGE(TAG, "HTTP upgrade failed"); return -1; - } else if (WS_HTTP_TEMPORARY_REDIRECT(ws->http_status_code) || WS_HTTP_PERMANENT_REDIRECT(ws->http_status_code)) { - char * redir_host = get_http_header(ws->buffer, "Location:"); - if (redir_host == NULL) { + } + + const char *location = NULL; + int location_len = 0; + + const char *server_key = NULL; + int server_key_len = 0; + const char * header_cursor = strnstr(ws->buffer, "\r\n", header_len); + if (!header_cursor){ + ESP_LOGE(TAG, "HTTP Header locate failed"); + return -1; + } + header_cursor += strlen("\r\n"); + + while(header_cursor < delim_ptr){ + const char * end_of_line = strnstr(header_cursor, "\r\n", header_len - (header_cursor - ws->buffer)); + if(!end_of_line){ + ESP_LOGE(TAG, "HTTP Header walk failed"); + return -1; + } + else if(end_of_line == header_cursor){ + ESP_LOGD(TAG, "HTTP Header walk found end"); + break; + } + int line_len = end_of_line - header_cursor; + ESP_LOGD(TAG, "HTTP Header walk line:%.*s", line_len, header_cursor); + + // Check for Sec-WebSocket-Accept header + const char * header_sec_websocket_accept = "Sec-WebSocket-Accept: "; + size_t header_sec_websocket_accept_len = strlen(header_sec_websocket_accept); + if (line_len >= header_sec_websocket_accept_len && !strncasecmp(header_cursor, header_sec_websocket_accept, header_sec_websocket_accept_len)) { + ESP_LOGD(TAG, "found server-key"); + server_key = header_cursor + header_sec_websocket_accept_len; + server_key_len = line_len - header_sec_websocket_accept_len; + } + else if (ws->header_hook) { + ws->header_hook(ws->header_userp, header_cursor, line_len); + } + + // Check for Location: header + const char * header_location = "Location: "; + size_t header_location_len = strlen(header_location); + if (line_len >= header_location_len && !strncasecmp(header_cursor, header_location, header_location_len)) { + location = header_cursor + header_location_len; + location_len = line_len - header_location_len; + } + + // Adjust cursor to the start of the next line + header_cursor += line_len; + header_cursor += strlen("\r\n"); + } + + if (WS_HTTP_TEMPORARY_REDIRECT(ws->http_status_code) || WS_HTTP_PERMANENT_REDIRECT(ws->http_status_code)) { + if (location == NULL || location_len <= 0) { ESP_LOGE(TAG, "Location header not found"); return -1; } - ws->redir_host = strdup(redir_host); + ws->redir_host = strndup(location, location_len); return ws->http_status_code; } - char *server_key = get_http_header(ws->buffer, "Sec-WebSocket-Accept:"); if (server_key == NULL) { ESP_LOGE(TAG, "Sec-WebSocket-Accept not found"); return -1; @@ -361,7 +373,7 @@ static int ws_connect(esp_transport_handle_t t, const char *host, int port, int esp_crypto_base64_encode(expected_server_key, sizeof(expected_server_key), &outlen, expected_server_sha1, sizeof(expected_server_sha1)); expected_server_key[ (outlen < sizeof(expected_server_key)) ? outlen : (sizeof(expected_server_key) - 1) ] = 0; ESP_LOGD(TAG, "server key=%s, send_key=%s, expected_server_key=%s", (char *)server_key, (char *)client_key, expected_server_key); - if (strcmp((char *)expected_server_key, (char *)server_key) != 0) { + if (strncmp((char *)expected_server_key, (char *)server_key, server_key_len) != 0) { ESP_LOGE(TAG, "Invalid websocket key"); return -1; } @@ -866,6 +878,26 @@ esp_err_t esp_transport_ws_set_headers(esp_transport_handle_t t, const char *hea return ESP_OK; } +esp_err_t esp_transport_ws_set_header_hook(esp_transport_handle_t t, ws_header_hook hook) +{ + if (t == NULL) { + return ESP_ERR_INVALID_ARG; + } + transport_ws_t *ws = esp_transport_get_context_data(t); + ws->header_hook = hook; + return ESP_OK; +} + +esp_err_t esp_transport_ws_set_header_userp(esp_transport_handle_t t, void * userp) +{ + if (t == NULL) { + return ESP_ERR_INVALID_ARG; + } + transport_ws_t *ws = esp_transport_get_context_data(t); + ws->header_userp = userp; + return ESP_OK; +} + esp_err_t esp_transport_ws_set_auth(esp_transport_handle_t t, const char *auth) { if (t == NULL) { @@ -931,6 +963,14 @@ esp_err_t esp_transport_ws_set_config(esp_transport_handle_t t, const esp_transp err = esp_transport_ws_set_headers(t, config->headers); ESP_TRANSPORT_ERR_OK_CHECK(TAG, err, return err;) } + if (config->header_hook) { + err = esp_transport_ws_set_header_hook(t, config->header_hook); + ESP_TRANSPORT_ERR_OK_CHECK(TAG, err, return err;) + } + if (config->header_userp) { + err = esp_transport_ws_set_header_userp(t, config->header_userp); + ESP_TRANSPORT_ERR_OK_CHECK(TAG, err, return err;) + } if (config->auth) { err = esp_transport_ws_set_auth(t, config->auth); ESP_TRANSPORT_ERR_OK_CHECK(TAG, err, return err;)