diff --git a/include/mqtt.h b/include/mqtt.h old mode 100644 new mode 100755 index 0006ced..b1d6214 --- a/include/mqtt.h +++ b/include/mqtt.h @@ -7,6 +7,19 @@ #include "ringbuf.h" +#if defined(CONFIG_MQTT_SECURITY_ON) // ENABLE MQTT OVER SSL +#include "openssl/ssl.h" + + #define ClientRead(buf,num) SSL_read(client->ssl, buf, num) + #define ClientWrite(buf,num) SSL_write(client->ssl, buf, num) + +#else + + #define ClientRead(buf,num) read(client->socket, buf, num) + #define ClientWrite(buf,num) write(client->socket, buf, num) +#endif + + typedef void (* mqtt_callback)(void *, void *); typedef struct { @@ -62,6 +75,12 @@ typedef struct mqtt_state_t typedef struct { int socket; + +#if defined(CONFIG_MQTT_SECURITY_ON) // ENABLE MQTT OVER SSL + SSL_CTX *ctx; + SSL *ssl; +#endif + mqtt_settings *settings; mqtt_state_t mqtt_state; mqtt_connect_info_t connect_info; @@ -71,6 +90,7 @@ typedef struct { } mqtt_client; mqtt_client *mqtt_start(mqtt_settings *mqtt_info); +void mqtt_stop(); void mqtt_task(void *pvParameters); void mqtt_subscribe(mqtt_client *client, char *topic, uint8_t qos); void mqtt_publish(mqtt_client* client, char *topic, char *data, int len, int qos, int retain); diff --git a/mqtt.c b/mqtt.c old mode 100644 new mode 100755 index f1128c1..cefb36b --- a/mqtt.c +++ b/mqtt.c @@ -39,39 +39,138 @@ static void mqtt_queue(mqtt_client *client) client->mqtt_state.outbound_message->length); xQueueSend(client->xSendingQueue, &client->mqtt_state.outbound_message->length, 0); } -static int client_connect(const char *stream_host, int stream_port) + +static bool client_connect(mqtt_client *client) { - int sock; + int ret; struct sockaddr_in remote_ip; - while (1) { + + while (1) { + bzero(&remote_ip, sizeof(struct sockaddr_in)); remote_ip.sin_family = AF_INET; - //if stream_host is not ip address, resolve it - if (inet_aton(stream_host, &(remote_ip.sin_addr)) == 0) { - mqtt_info("Resolve dns for domain: %s", stream_host); - if (!resolve_dns(stream_host, &remote_ip)) { + remote_ip.sin_port = htons(client->settings->port); + + + //if host is not ip address, resolve it + if (inet_aton( client->settings->host, &(remote_ip.sin_addr)) == 0) { + mqtt_info("Resolve dns for domain: %s", client->settings->host); + + if (!resolve_dns(client->settings->host, &remote_ip)) { vTaskDelay(1000 / portTICK_RATE_MS); continue; } } - sock = socket(PF_INET, SOCK_STREAM, 0); - if (sock == -1) { - continue; + + +#if defined(CONFIG_MQTT_SECURITY_ON) // ENABLE MQTT OVER SSL + client->ctx = NULL; + client->ssl = NULL; + + client->ctx = SSL_CTX_new(TLSv1_2_client_method()); + if (!client->ctx) { + mqtt_error("Failed to create SSL CTX"); + goto failed1; } - remote_ip.sin_port = htons(stream_port); +#endif + + client->socket = socket(PF_INET, SOCK_STREAM, 0); + if (client->socket == -1) { + mqtt_error("Failed to create socket"); + goto failed2; + } + + + mqtt_info("Connecting to server %s:%d,%d", inet_ntoa((remote_ip.sin_addr)), - stream_port, + client->settings->port, remote_ip.sin_port); - if (connect(sock, (struct sockaddr *)(&remote_ip), sizeof(struct sockaddr)) != 00) { - close(sock); - mqtt_error("Conn err."); - vTaskDelay(1000 / portTICK_RATE_MS); - continue; + + if (connect(client->socket, (struct sockaddr *)(&remote_ip), sizeof(struct sockaddr)) != 00) { + mqtt_error("Connect failed"); + goto failed3; } - return sock; - } + +#if defined(CONFIG_MQTT_SECURITY_ON) // ENABLE MQTT OVER SSL + mqtt_info("Creating SSL object..."); + client->ssl = SSL_new(client->ctx); + if (!client->ssl) { + mqtt_error("Unable to creat new SSL"); + goto failed3; + } + + if (!SSL_set_fd(client->ssl, client->socket)) { + mqtt_error("SSL set_fd failed"); + goto failed3; + } + + mqtt_info("Start SSL connect.."); + ret = SSL_connect(client->ssl); + if (!ret) { + mqtt_error("SSL Connect FAILED"); + goto failed4; + } +#endif + mqtt_info("Connected!"); + + return true; + + //failed5: + // SSL_shutdown(client->ssl); + +#if defined(CONFIG_MQTT_SECURITY_ON) + failed4: + SSL_free(client->ssl); + client->ssl = NULL; +#endif + + failed3: + close(client->socket); + client->socket = -1; + + failed2: +#if defined(CONFIG_MQTT_SECURITY_ON) + SSL_CTX_free(client->ctx); + + failed1: + client->ctx = NULL; +#endif + vTaskDelay(1000 / portTICK_RATE_MS); + + } +} + + +// Close client socket +// including SSL objects if CNFIG_MQTT_SECURITY_ON is enabled +void closeclient(mqtt_client *client) +{ + +#if defined(CONFIG_MQTT_SECURITY_ON) + if (client->ssl != NULL) + { + SSL_shutdown(client->ssl); + + SSL_free(client->ssl); + client->ssl = NULL; + } +#endif + if (client->socket != -1) + { + close(client->socket); + client->socket = -1; + } + +#if defined(CONFIG_MQTT_SECURITY_ON) + if (client->ctx != NULL) + { + SSL_CTX_free(client->ctx); + client->ctx = NULL; + } +#endif + } /* * mqtt_connect @@ -99,11 +198,14 @@ static bool mqtt_connect(mqtt_client *client) mqtt_info("Sending MQTT CONNECT message, type: %d, id: %04X", client->mqtt_state.pending_msg_type, client->mqtt_state.pending_msg_id); - write_len = write(client->socket, + + write_len = ClientWrite( client->mqtt_state.outbound_message->data, client->mqtt_state.outbound_message->length); + mqtt_info("Reading MQTT CONNECT response message"); - read_len = read(client->socket, client->mqtt_state.in_buffer, CONFIG_MQTT_BUFFER_SIZE_BYTE); + + read_len = ClientRead(client->mqtt_state.in_buffer, CONFIG_MQTT_BUFFER_SIZE_BYTE); tv.tv_sec = 0; /* No timeout */ setsockopt(client->socket, SOL_SOCKET, SO_RCVTIMEO, (char *)&tv, sizeof(struct timeval)); @@ -152,7 +254,8 @@ void mqtt_sending_task(void *pvParameters) rb_read(&client->send_rb, client->mqtt_state.out_buffer, send_len); client->mqtt_state.pending_msg_type = mqtt_get_type(client->mqtt_state.out_buffer); client->mqtt_state.pending_msg_id = mqtt_get_id(client->mqtt_state.out_buffer, send_len); - write(client->socket, client->mqtt_state.out_buffer, send_len); + ClientWrite(client->mqtt_state.out_buffer, send_len); + //TODO: Check sending type, to callback publish message msg_len -= send_len; } @@ -168,7 +271,7 @@ void mqtt_sending_task(void *pvParameters) client->mqtt_state.pending_msg_id = mqtt_get_id(client->mqtt_state.outbound_message->data, client->mqtt_state.outbound_message->length); mqtt_info("Sending pingreq"); - write(client->socket, + ClientWrite( client->mqtt_state.outbound_message->data, client->mqtt_state.outbound_message->length); } @@ -207,7 +310,8 @@ void deliver_publish(mqtt_client *client, uint8_t *message, int length) mqtt_offset += mqtt_len; if (client->mqtt_state.message_length_read >= client->mqtt_state.message_length) break; - len_read = read(client->socket, client->mqtt_state.in_buffer, CONFIG_MQTT_BUFFER_SIZE_BYTE); + + len_read = ClientRead(client->mqtt_state.in_buffer, CONFIG_MQTT_BUFFER_SIZE_BYTE); client->mqtt_state.message_length_read += len_read; } while (1); @@ -220,7 +324,9 @@ void mqtt_start_receive_schedule(mqtt_client *client) uint16_t msg_id; while (1) { - read_len = read(client->socket, client->mqtt_state.in_buffer, CONFIG_MQTT_BUFFER_SIZE_BYTE); + + read_len = ClientRead(client->mqtt_state.in_buffer, CONFIG_MQTT_BUFFER_SIZE_BYTE); + mqtt_info("Read len %d", read_len); if (read_len == 0) break; @@ -309,9 +415,11 @@ void mqtt_task(void *pvParameters) mqtt_client *client = (mqtt_client *)pvParameters; while (1) { - client->socket = client_connect(client->settings->host, client->settings->port); + client_connect(client); + mqtt_info("Connected to server %s:%d", client->settings->host, client->settings->port); if (!mqtt_connect(client)) { + closeclient(client); continue; //return; } @@ -323,8 +431,8 @@ void mqtt_task(void *pvParameters) mqtt_info("mqtt_start_receive_schedule"); mqtt_start_receive_schedule(client); - - close(client->socket); + + closeclient(client); vTaskDelete(xMqttSendingTask); vTaskDelay(1000 / portTICK_RATE_MS); @@ -336,6 +444,8 @@ void mqtt_task(void *pvParameters) mqtt_client *mqtt_start(mqtt_settings *settings) { + int stackSize = 2048; + uint8_t *rb_buf; if (xMqttTask != NULL) return NULL; @@ -368,7 +478,13 @@ mqtt_client *mqtt_start(mqtt_settings *settings) client->mqtt_state.out_buffer_length = CONFIG_MQTT_BUFFER_SIZE_BYTE; client->mqtt_state.connect_info = &client->connect_info; + client->socket = -1; +#if defined(CONFIG_MQTT_SECURITY_ON) // ENABLE MQTT OVER SSL + client->ctx = NULL; + client->ssl = NULL; + stackSize = 10240; // Need more stack to handle SSL handshake +#endif /* Create a queue capable of containing 64 unsigned long values. */ client->xSendingQueue = xQueueCreate(64, sizeof( uint32_t )); @@ -385,7 +501,7 @@ mqtt_client *mqtt_start(mqtt_settings *settings) client->mqtt_state.out_buffer, client->mqtt_state.out_buffer_length); - xTaskCreate(&mqtt_task, "mqtt_task", 2048, client, CONFIG_MQTT_PRIORITY, &xMqttTask); + xTaskCreate(&mqtt_task, "mqtt_task", stackSize, client, CONFIG_MQTT_PRIORITY, &xMqttTask); return client; } @@ -411,6 +527,7 @@ void mqtt_publish(mqtt_client* client, char *topic, char *data, int len, int qos client->send_rb.fill_cnt, client->send_rb.size); } + void mqtt_stop() {