Merge branch 'feature/transpost_methods_renamed_no_collision' into 'idf'

transport: renamed transport headers to avoid compilation collisions

See merge request idf/esp-mqtt!5
This commit is contained in:
David Čermák
2018-10-03 14:12:58 +08:00
3 changed files with 38 additions and 347 deletions

View File

@@ -1,46 +0,0 @@
/*
* This file is subject to the terms and conditions defined in
* file 'LICENSE', which is part of this source code package.
* Tuan PM <tuanpm at live dot com>
*/
#ifndef _TRANSPORT_WS_H_
#define _TRANSPORT_WS_H_
#include "transport.h"
#ifdef __cplusplus
extern "C" {
#endif
#define WS_FIN 0x80
#define WS_OPCODE_TEXT 0x01
#define WS_OPCODE_BINARY 0x02
#define WS_OPCODE_CLOSE 0x08
#define WS_OPCODE_PING 0x09
#define WS_OPCODE_PONG 0x0a
// Second byte
#define WS_MASK 0x80
#define WS_SIZE16 126
#define WS_SIZE64 127
#define MAX_WEBSOCKET_HEADER_SIZE 10
#define WS_RESPONSE_OK 101
/**
* @brief Create TCP transport
*
* @return
* - transport
* - NULL
*/
transport_handle_t transport_ws_init(transport_handle_t parent_handle);
void transport_ws_set_path(transport_handle_t t, const char *path);
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -1,263 +0,0 @@
#include <stdlib.h>
#include <string.h>
#include <ctype.h>
#include "platform.h"
#include "transport.h"
#include "transport_tcp.h"
#include "transport_ws.h"
#include "mbedtls/base64.h"
#include "mbedtls/sha1.h"
static const char *TAG = "TRANSPORT_WS";
#define DEFAULT_WS_BUFFER (1024)
typedef struct {
char *path;
char *buffer;
transport_handle_t parent;
} transport_ws_t;
transport_handle_t ws_transport_get_payload_transport_handle(transport_handle_t t)
{
transport_ws_t *ws = transport_get_context_data(t);
return ws->parent;
}
static char *trimwhitespace(const char *str)
{
char *end;
// Trim leading space
while (isspace((unsigned char)*str)) str++;
if (*str == 0) {
return (char *)str;
}
// Trim trailing space
end = (char *)(str + strlen(str) - 1);
while (end > str && isspace((unsigned char)*end)) end--;
// Write new null terminator
*(end + 1) = 0;
return (char *)str;
}
static char *get_http_header(const char *buffer, const char *key)
{
char *found = strstr(buffer, key);
if (found) {
found += strlen(key);
char *found_end = strstr(found, "\r\n");
if (found_end) {
found_end[0] = 0;//terminal string
return trimwhitespace(found);
}
}
return NULL;
}
static int ws_connect(transport_handle_t t, const char *host, int port, int timeout_ms)
{
transport_ws_t *ws = transport_get_context_data(t);
if (transport_connect(ws->parent, host, port, timeout_ms) < 0) {
ESP_LOGE(TAG, "Error connect to ther server");
}
unsigned char random_key[16] = { 0 }, client_key[32] = {0};
int i;
for (i = 0; i < sizeof(random_key); i++) {
random_key[i] = rand() & 0xFF;
}
size_t outlen = 0;
mbedtls_base64_encode(client_key, 32, &outlen, random_key, 16);
int len = snprintf(ws->buffer, DEFAULT_WS_BUFFER,
"GET %s HTTP/1.1\r\n"
"Connection: Upgrade\r\n"
"Host: %s:%d\r\n"
"Upgrade: websocket\r\n"
"Sec-WebSocket-Version: 13\r\n"
"Sec-WebSocket-Protocol: mqtt\r\n"
"Sec-WebSocket-Key: %s\r\n"
"User-Agent: ESP32 MQTT Client\r\n\r\n",
ws->path,
host, port,
client_key);
ESP_LOGD(TAG, "Write upgrate request\r\n%s", ws->buffer);
if (transport_write(ws->parent, ws->buffer, len, timeout_ms) <= 0) {
ESP_LOGE(TAG, "Error write Upgrade header %s", ws->buffer);
return -1;
}
if ((len = transport_read(ws->parent, ws->buffer, DEFAULT_WS_BUFFER, timeout_ms)) <= 0) {
ESP_LOGE(TAG, "Error read response for Upgrade header %s", ws->buffer);
return -1;
}
char *server_key = get_http_header(ws->buffer, "Sec-WebSocket-Accept:");
if (server_key == NULL) {
ESP_LOGE(TAG, "Sec-WebSocket-Accept not found");
return -1;
}
unsigned char client_key_b64[64], valid_client_key[20], accept_key[32] = {0};
int key_len = sprintf((char*)client_key_b64, "%s258EAFA5-E914-47DA-95CA-C5AB0DC85B11", (char*)client_key);
mbedtls_sha1_ret(client_key_b64, (size_t)key_len, valid_client_key);
mbedtls_base64_encode(accept_key, 32, &outlen, valid_client_key, 20);
accept_key[outlen] = 0;
ESP_LOGD(TAG, "server key=%s, send_key=%s, accept_key=%s", (char *)server_key, (char*)client_key, accept_key);
if (strcmp((char*)accept_key, (char*)server_key) != 0) {
ESP_LOGE(TAG, "Invalid websocket key");
return -1;
}
return 0;
}
static int ws_write(transport_handle_t t, const char *buff, int len, int timeout_ms)
{
transport_ws_t *ws = transport_get_context_data(t);
char ws_header[MAX_WEBSOCKET_HEADER_SIZE];
char *mask;
int header_len = 0, i;
char *buffer = (char *)buff;
int poll_write;
if ((poll_write = transport_poll_write(ws->parent, timeout_ms)) <= 0) {
return poll_write;
}
ws_header[header_len++] = WS_OPCODE_BINARY | WS_FIN;
// NOTE: no support for > 16-bit sized messages
if (len > 125) {
ws_header[header_len++] = WS_SIZE16 | WS_MASK;
ws_header[header_len++] = (uint8_t)(len >> 8);
ws_header[header_len++] = (uint8_t)(len & 0xFF);
} else {
ws_header[header_len++] = (uint8_t)(len | WS_MASK);
}
mask = &ws_header[header_len];
ws_header[header_len++] = rand() & 0xFF;
ws_header[header_len++] = rand() & 0xFF;
ws_header[header_len++] = rand() & 0xFF;
ws_header[header_len++] = rand() & 0xFF;
for (i = 0; i < len; ++i) {
buffer[i] = (buffer[i] ^ mask[i % 4]);
}
if (transport_write(ws->parent, ws_header, header_len, timeout_ms) != header_len) {
ESP_LOGE(TAG, "Error write header");
return -1;
}
return transport_write(ws->parent, buffer, len, timeout_ms);
}
static int ws_read(transport_handle_t t, char *buffer, int len, int timeout_ms)
{
transport_ws_t *ws = transport_get_context_data(t);
int payload_len;
int payload_len_buff = len;
char *data_ptr = buffer, opcode, mask, *mask_key = NULL;
int rlen;
int poll_read;
if ((poll_read = transport_poll_read(ws->parent, timeout_ms)) <= 0) {
return poll_read;
}
if ((rlen = transport_read(ws->parent, buffer, len, timeout_ms)) <= 0) {
ESP_LOGE(TAG, "Error read data");
return rlen;
}
opcode = (*data_ptr & 0x0F);
data_ptr ++;
mask = ((*data_ptr >> 7) & 0x01);
payload_len = (*data_ptr & 0x7F);
data_ptr++;
ESP_LOGD(TAG, "Opcode: %d, mask: %d, len: %d\r\n", opcode, mask, payload_len);
if (payload_len == 126) {
// headerLen += 2;
payload_len = data_ptr[0] << 8 | data_ptr[1];
payload_len_buff = len - 4;
data_ptr += 2;
} else if (payload_len == 127) {
// headerLen += 8;
if (data_ptr[0] != 0 || data_ptr[1] != 0 || data_ptr[2] != 0 || data_ptr[3] != 0) {
// really too big!
payload_len = 0xFFFFFFFF;
} else {
payload_len = data_ptr[4] << 24 | data_ptr[5] << 16 | data_ptr[6] << 8 | data_ptr[7];
}
data_ptr += 8;
payload_len_buff = len - 10;
}
if (payload_len > payload_len_buff) {
ESP_LOGD(TAG, "Actual data received (%d) are longer than mqtt buffer (%d)", payload_len, payload_len_buff);
payload_len = payload_len_buff;
}
if (mask) {
mask_key = data_ptr;
data_ptr += 4;
for (int i = 0; i < payload_len; i++) {
buffer[i] = (data_ptr[i] ^ mask_key[i % 4]);
}
} else {
memmove(buffer, data_ptr, payload_len);
}
return payload_len;
}
static int ws_poll_read(transport_handle_t t, int timeout_ms)
{
transport_ws_t *ws = transport_get_context_data(t);
return transport_poll_read(ws->parent, timeout_ms);
}
static int ws_poll_write(transport_handle_t t, int timeout_ms)
{
transport_ws_t *ws = transport_get_context_data(t);
return transport_poll_write(ws->parent, timeout_ms);;
}
static int ws_close(transport_handle_t t)
{
transport_ws_t *ws = transport_get_context_data(t);
return transport_close(ws->parent);
}
static esp_err_t ws_destroy(transport_handle_t t)
{
transport_ws_t *ws = transport_get_context_data(t);
free(ws->buffer);
free(ws->path);
free(ws);
return 0;
}
void transport_ws_set_path(transport_handle_t t, const char *path)
{
transport_ws_t *ws = transport_get_context_data(t);
ws->path = realloc(ws->path, strlen(path) + 1);
strcpy(ws->path, path);
}
transport_handle_t transport_ws_init(transport_handle_t parent_handle)
{
transport_handle_t t = transport_init();
transport_ws_t *ws = calloc(1, sizeof(transport_ws_t));
ESP_MEM_CHECK(TAG, ws, return NULL);
ws->parent = parent_handle;
ws->path = strdup("/");
ESP_MEM_CHECK(TAG, ws->path, return NULL);
ws->buffer = malloc(DEFAULT_WS_BUFFER);
ESP_MEM_CHECK(TAG, ws->buffer, {
free(ws->path);
free(ws);
return NULL;
});
transport_set_func(t, ws_connect, ws_read, ws_write, ws_close, ws_poll_read, ws_poll_write, ws_destroy, ws_transport_get_payload_transport_handle);
transport_set_context_data(t, ws);
return t;
}

View File

@@ -3,10 +3,10 @@
#include "mqtt_client.h"
#include "mqtt_msg.h"
#include "transport.h"
#include "transport_tcp.h"
#include "transport_ssl.h"
#include "transport_ws.h"
#include "esp_transport.h"
#include "esp_transport_tcp.h"
#include "esp_transport_ssl.h"
#include "esp_transport_ws.h"
#include "platform.h"
#include "mqtt_outbox.h"
@@ -55,8 +55,8 @@ typedef enum {
} mqtt_client_state_t;
struct esp_mqtt_client {
transport_list_handle_t transport_list;
transport_handle_t transport;
esp_transport_list_handle_t transport_list;
esp_transport_handle_t transport;
mqtt_config_storage_t *config;
mqtt_state_t mqtt_state;
mqtt_connect_info_t connect_info;
@@ -204,7 +204,7 @@ static esp_err_t esp_mqtt_connect(esp_mqtt_client_handle_t client, int timeout_m
client->mqtt_state.pending_msg_type,
client->mqtt_state.pending_msg_id);
write_len = transport_write(client->transport,
write_len = esp_transport_write(client->transport,
(char *)client->mqtt_state.outbound_message->data,
client->mqtt_state.outbound_message->length,
client->config->network_timeout_ms);
@@ -212,7 +212,7 @@ static esp_err_t esp_mqtt_connect(esp_mqtt_client_handle_t client, int timeout_m
ESP_LOGE(TAG, "Writing failed, errno= %d", errno);
return ESP_FAIL;
}
read_len = transport_read(client->transport,
read_len = esp_transport_read(client->transport,
(char *)client->mqtt_state.in_buffer,
client->mqtt_state.outbound_message->length,
client->config->network_timeout_ms);
@@ -251,7 +251,7 @@ static esp_err_t esp_mqtt_connect(esp_mqtt_client_handle_t client, int timeout_m
static esp_err_t esp_mqtt_abort_connection(esp_mqtt_client_handle_t client)
{
transport_close(client->transport);
esp_transport_close(client->transport);
client->wait_timeout_ms = MQTT_RECONNECT_TIMEOUT_MS;
client->reconnect_tick = platform_tick_get_ms();
client->state = MQTT_STATE_WAIT_TIMEOUT;
@@ -269,23 +269,23 @@ esp_mqtt_client_handle_t esp_mqtt_client_init(const esp_mqtt_client_config_t *co
esp_mqtt_set_config(client, config);
client->transport_list = transport_list_init();
client->transport_list = esp_transport_list_init();
ESP_MEM_CHECK(TAG, client->transport_list, goto _mqtt_init_failed);
transport_handle_t tcp = transport_tcp_init();
esp_transport_handle_t tcp = esp_transport_tcp_init();
ESP_MEM_CHECK(TAG, tcp, goto _mqtt_init_failed);
transport_set_default_port(tcp, MQTT_TCP_DEFAULT_PORT);
transport_list_add(client->transport_list, tcp, "mqtt");
esp_transport_set_default_port(tcp, MQTT_TCP_DEFAULT_PORT);
esp_transport_list_add(client->transport_list, tcp, "mqtt");
if (config->transport == MQTT_TRANSPORT_OVER_TCP) {
client->config->scheme = create_string("mqtt", 4);
ESP_MEM_CHECK(TAG, client->config->scheme, goto _mqtt_init_failed);
}
#if MQTT_ENABLE_WS
transport_handle_t ws = transport_ws_init(tcp);
esp_transport_handle_t ws = esp_transport_ws_init(tcp);
ESP_MEM_CHECK(TAG, ws, goto _mqtt_init_failed);
transport_set_default_port(ws, MQTT_WS_DEFAULT_PORT);
transport_list_add(client->transport_list, ws, "ws");
esp_transport_set_default_port(ws, MQTT_WS_DEFAULT_PORT);
esp_transport_list_add(client->transport_list, ws, "ws");
if (config->transport == MQTT_TRANSPORT_OVER_WS) {
client->config->scheme = create_string("ws", 2);
ESP_MEM_CHECK(TAG, client->config->scheme, goto _mqtt_init_failed);
@@ -293,13 +293,13 @@ esp_mqtt_client_handle_t esp_mqtt_client_init(const esp_mqtt_client_config_t *co
#endif
#if MQTT_ENABLE_SSL
transport_handle_t ssl = transport_ssl_init();
esp_transport_handle_t ssl = esp_transport_ssl_init();
ESP_MEM_CHECK(TAG, ssl, goto _mqtt_init_failed);
transport_set_default_port(ssl, MQTT_SSL_DEFAULT_PORT);
esp_transport_set_default_port(ssl, MQTT_SSL_DEFAULT_PORT);
if (config->cert_pem) {
transport_ssl_set_cert_data(ssl, config->cert_pem, strlen(config->cert_pem));
esp_transport_ssl_set_cert_data(ssl, config->cert_pem, strlen(config->cert_pem));
}
transport_list_add(client->transport_list, ssl, "mqtts");
esp_transport_list_add(client->transport_list, ssl, "mqtts");
if (config->transport == MQTT_TRANSPORT_OVER_SSL) {
client->config->scheme = create_string("mqtts", 5);
ESP_MEM_CHECK(TAG, client->config->scheme, goto _mqtt_init_failed);
@@ -307,10 +307,10 @@ esp_mqtt_client_handle_t esp_mqtt_client_init(const esp_mqtt_client_config_t *co
#endif
#if MQTT_ENABLE_WSS
transport_handle_t wss = transport_ws_init(ssl);
esp_transport_handle_t wss = esp_transport_ws_init(ssl);
ESP_MEM_CHECK(TAG, wss, goto _mqtt_init_failed);
transport_set_default_port(wss, MQTT_WSS_DEFAULT_PORT);
transport_list_add(client->transport_list, wss, "wss");
esp_transport_set_default_port(wss, MQTT_WSS_DEFAULT_PORT);
esp_transport_list_add(client->transport_list, wss, "wss");
if (config->transport == MQTT_TRANSPORT_OVER_WSS) {
client->config->scheme = create_string("wss", 3);
ESP_MEM_CHECK(TAG, client->config->scheme, goto _mqtt_init_failed);
@@ -357,7 +357,7 @@ esp_err_t esp_mqtt_client_destroy(esp_mqtt_client_handle_t client)
{
esp_mqtt_client_stop(client);
esp_mqtt_destroy_config(client);
transport_list_destroy(client->transport_list);
esp_transport_list_destroy(client->transport_list);
outbox_destroy(client->outbox);
vEventGroupDelete(client->status_bits);
free(client->mqtt_state.in_buffer);
@@ -400,13 +400,13 @@ esp_err_t esp_mqtt_client_set_uri(esp_mqtt_client_handle_t client, const char *u
client->config->path = create_string(uri + puri.field_data[UF_PATH].off, puri.field_data[UF_PATH].len);
}
if (client->config->path) {
transport_handle_t trans = transport_list_get_transport(client->transport_list, "ws");
esp_transport_handle_t trans = esp_transport_list_get_transport(client->transport_list, "ws");
if (trans) {
transport_ws_set_path(trans, client->config->path);
esp_transport_ws_set_path(trans, client->config->path);
}
trans = transport_list_get_transport(client->transport_list, "wss");
trans = esp_transport_list_get_transport(client->transport_list, "wss");
if (trans) {
transport_ws_set_path(trans, client->config->path);
esp_transport_ws_set_path(trans, client->config->path);
}
}
@@ -432,7 +432,7 @@ esp_err_t esp_mqtt_client_set_uri(esp_mqtt_client_handle_t client, const char *u
static esp_err_t mqtt_write_data(esp_mqtt_client_handle_t client)
{
int write_len = transport_write(client->transport,
int write_len = esp_transport_write(client->transport,
(char *)client->mqtt_state.outbound_message->data,
client->mqtt_state.outbound_message->length,
client->config->network_timeout_ms);
@@ -464,7 +464,7 @@ static esp_err_t esp_mqtt_dispatch_event(esp_mqtt_client_handle_t client)
typedef struct {
char *path;
char *buffer;
transport_handle_t parent;
esp_transport_handle_t parent;
} transport_ws_t;
static void deliver_publish(esp_mqtt_client_handle_t client, uint8_t *message, int length)
@@ -473,7 +473,7 @@ static void deliver_publish(esp_mqtt_client_handle_t client, uint8_t *message, i
uint32_t mqtt_topic_length, mqtt_data_length;
uint32_t mqtt_len, mqtt_offset = 0, total_mqtt_len = 0;
int len_read;
transport_handle_t transport = client->transport;
esp_transport_handle_t transport = client->transport;
do
{
@@ -485,7 +485,7 @@ static void deliver_publish(esp_mqtt_client_handle_t client, uint8_t *message, i
total_mqtt_len = client->mqtt_state.message_length - client->mqtt_state.message_length_read + mqtt_data_length;
mqtt_len = mqtt_data_length;
/* any further reading only the underlying payload */
transport = transport_get_payload_transport_handle(transport);
transport = esp_transport_get_payload_transport_handle(transport);
} else {
mqtt_len = len_read;
mqtt_data = (const char*)client->mqtt_state.in_buffer;
@@ -508,7 +508,7 @@ static void deliver_publish(esp_mqtt_client_handle_t client, uint8_t *message, i
break;
}
len_read = transport_read(transport,
len_read = esp_transport_read(transport,
(char *)client->mqtt_state.in_buffer,
client->mqtt_state.message_length - client->mqtt_state.message_length_read > client->mqtt_state.in_buffer_length ?
client->mqtt_state.in_buffer_length : client->mqtt_state.message_length - client->mqtt_state.message_length_read,
@@ -565,7 +565,7 @@ static esp_err_t mqtt_process_receive(esp_mqtt_client_handle_t client)
uint8_t msg_qos;
uint16_t msg_id;
read_len = transport_read(client->transport, (char *)client->mqtt_state.in_buffer, client->mqtt_state.in_buffer_length, 1000);
read_len = esp_transport_read(client->transport, (char *)client->mqtt_state.in_buffer, client->mqtt_state.in_buffer_length, 1000);
if (read_len < 0) {
ESP_LOGE(TAG, "Read error or end of stream");
@@ -662,7 +662,7 @@ static void esp_mqtt_task(void *pv)
client->run = true;
//get transport by scheme
client->transport = transport_list_get_transport(client->transport_list, client->config->scheme);
client->transport = esp_transport_list_get_transport(client->transport_list, client->config->scheme);
if (client->transport == NULL) {
ESP_LOGE(TAG, "There are no transports valid, stop mqtt client, config scheme = %s", client->config->scheme);
@@ -670,7 +670,7 @@ static void esp_mqtt_task(void *pv)
}
//default port
if (client->config->port == 0) {
client->config->port = transport_get_default_port(client->transport);
client->config->port = esp_transport_get_default_port(client->transport);
}
client->state = MQTT_STATE_INIT;
@@ -684,7 +684,7 @@ static void esp_mqtt_task(void *pv)
client->run = false;
}
if (transport_connect(client->transport,
if (esp_transport_connect(client->transport,
client->config->host,
client->config->port,
client->config->network_timeout_ms) < 0) {
@@ -748,7 +748,7 @@ static void esp_mqtt_task(void *pv)
break;
}
}
transport_close(client->transport);
esp_transport_close(client->transport);
xEventGroupSetBits(client->status_bits, STOPPED_BIT);
vTaskDelete(NULL);