diff --git a/component.mk b/component.mk index b045d65..229049d 100644 --- a/component.mk +++ b/component.mk @@ -9,6 +9,5 @@ COMPONENT_ADD_INCLUDEDIRS := include #COMPONENT_PRIV_INCLUDEDIRS := -COMPONENT_SRCDIRS := . #EXTRA_CFLAGS := -DICACHE_RODATA_ATTR CFLAGS += -Wno-error=implicit-function-declaration -Wno-error=format= -DHAVE_CONFIG_H diff --git a/include/mqtt.h b/include/mqtt.h index 7e6c276..a54e147 100644 --- a/include/mqtt.h +++ b/include/mqtt.h @@ -7,29 +7,46 @@ #include "ringbuf.h" -#if defined(CONFIG_MQTT_SECURITY_ON) // ENABLE MQTT OVER SSL -#include "openssl/ssl.h" +typedef struct mqtt_client mqtt_client; +typedef struct mqtt_event_data_t mqtt_event_data_t; - #define ClientRead(buf,num) SSL_read(client->ssl, buf, num) - #define ClientWrite(buf,num) SSL_write(client->ssl, buf, num) +/** + * \return True on connect success, false on error + */ +typedef bool (* mqtt_connect_callback)(mqtt_client *client); +/** + */ +typedef void (* mqtt_disconnect_callback)(mqtt_client *client); +/** + * \param[out] buffer Pointer to buffer to fill + * \param[in] len Number of bytes to read + * \param[in] timeout_ms Time to wait for completion, or 0 for no timeout + * \return Number of bytes read, less than 0 on error + */ +typedef int (* mqtt_read_callback)(mqtt_client *client, void *buffer, int len, int timeout_ms); +/** + * \param[in] buffer Pointer to buffer to write + * \param[in] len Number of bytes to write + * \param[in] timeout_ms Time to wait for completion, or 0 for no timeout + * \return Number of bytes written, less than 0 on error + */ +typedef int (* mqtt_write_callback)(mqtt_client *client, const void *buffer, int len, int timeout_ms); +typedef void (* mqtt_event_callback)(mqtt_client *client, mqtt_event_data_t *event_data); -#else +typedef struct mqtt_settings { + mqtt_connect_callback connect_cb; + mqtt_disconnect_callback disconnect_cb; - #define ClientRead(buf,num) read(client->socket, buf, num) - #define ClientWrite(buf,num) write(client->socket, buf, num) -#endif + mqtt_read_callback read_cb; + mqtt_write_callback write_cb; + mqtt_event_callback connected_cb; + mqtt_event_callback disconnected_cb; // unused + mqtt_event_callback reconnect_cb; // unused -typedef void (* mqtt_callback)(void *, void *); - -typedef struct { - mqtt_callback connected_cb; - mqtt_callback disconnected_cb; - mqtt_callback reconnect_cb; - - mqtt_callback subscribe_cb; - mqtt_callback publish_cb; - mqtt_callback data_cb; + mqtt_event_callback subscribe_cb; + mqtt_event_callback publish_cb; + mqtt_event_callback data_cb; char host[CONFIG_MQTT_MAX_HOST_LEN]; uint32_t port; @@ -73,7 +90,7 @@ typedef struct mqtt_state_t int pending_publish_qos; } mqtt_state_t; -typedef struct { +typedef struct mqtt_client { int socket; #if defined(CONFIG_MQTT_SECURITY_ON) // ENABLE MQTT OVER SSL diff --git a/mqtt.c b/mqtt.c index e415a82..40b5347 100644 --- a/mqtt.c +++ b/mqtt.c @@ -13,6 +13,9 @@ #include "lwip/sockets.h" #include "lwip/dns.h" #include "lwip/netdb.h" +#if defined(CONFIG_MQTT_SECURITY_ON) +#include "openssl/ssl.h" +#endif #include "ringbuf.h" #include "mqtt.h" @@ -176,6 +179,65 @@ void closeclient(mqtt_client *client) #endif } + +int mqtt_read(mqtt_client *client, void *buffer, int len, int timeout_ms) +{ + int result; + struct timeval tv; + if (timeout_ms > 0) { + tv.tv_sec = 0; + tv.tv_usec = timeout_ms * 1000; + while (tv.tv_usec > 1000 * 1000) { + tv.tv_usec -= 1000 * 1000; + tv.tv_sec++; + } + setsockopt(client->socket, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)); + } + +#if defined(CONFIG_MQTT_SECURITY_ON) + result = SSL_read(client->ssl, buffer, len); +#else + result = read(client->socket, buffer, len); +#endif + + if (timeout_ms > 0) { + tv.tv_sec = 0; + tv.tv_usec = 0; + setsockopt(client->socket, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)); + } + + return result; +} + +int mqtt_write(mqtt_client *client, const void *buffer, int len, int timeout_ms) +{ + int result; + struct timeval tv; + if (timeout_ms > 0) { + tv.tv_sec = 0; + tv.tv_usec = timeout_ms * 1000; + while (tv.tv_usec > 1000 * 1000) { + tv.tv_usec -= 1000 * 1000; + tv.tv_sec++; + } + setsockopt(client->socket, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)); + } + +#if defined(CONFIG_MQTT_SECURITY_ON) + result = SSL_write(client->ssl, buffer, len) +#else + result = write(client->socket, buffer, len); +#endif + + if (timeout_ms > 0) { + tv.tv_sec = 0; + tv.tv_usec = 0; + setsockopt(client->socket, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)); + } + + return result; +} + /* * mqtt_connect * input - client @@ -184,12 +246,7 @@ void closeclient(mqtt_client *client) static bool mqtt_connect(mqtt_client *client) { int write_len, read_len, connect_rsp_code; - struct timeval tv; - tv.tv_sec = 10; /* 30 Secs Timeout */ - tv.tv_usec = 0; // Not init'ing this can cause strange errors - - setsockopt(client->socket, SOL_SOCKET, SO_RCVTIMEO, (char *)&tv, sizeof(struct timeval)); mqtt_msg_init(&client->mqtt_state.mqtt_connection, client->mqtt_state.out_buffer, @@ -203,16 +260,13 @@ static bool mqtt_connect(mqtt_client *client) client->mqtt_state.pending_msg_type, client->mqtt_state.pending_msg_id); - write_len = ClientWrite( + write_len = client->settings->write_cb(client, client->mqtt_state.outbound_message->data, - client->mqtt_state.outbound_message->length); + client->mqtt_state.outbound_message->length, 0); mqtt_info("Reading MQTT CONNECT response message"); - read_len = ClientRead(client->mqtt_state.in_buffer, CONFIG_MQTT_BUFFER_SIZE_BYTE); - - tv.tv_sec = 0; /* No timeout */ - setsockopt(client->socket, SOL_SOCKET, SO_RCVTIMEO, (char *)&tv, sizeof(struct timeval)); + read_len = client->settings->read_cb(client, client->mqtt_state.in_buffer, CONFIG_MQTT_BUFFER_SIZE_BYTE, 10 * 1000); if (read_len < 0) { mqtt_error("Error network response"); @@ -258,7 +312,7 @@ void mqtt_sending_task(void *pvParameters) rb_read(&client->send_rb, client->mqtt_state.out_buffer, send_len); client->mqtt_state.pending_msg_type = mqtt_get_type(client->mqtt_state.out_buffer); client->mqtt_state.pending_msg_id = mqtt_get_id(client->mqtt_state.out_buffer, send_len); - ClientWrite(client->mqtt_state.out_buffer, send_len); + client->settings->write_cb(client, client->mqtt_state.out_buffer, send_len, 0); //TODO: Check sending type, to callback publish message msg_len -= send_len; @@ -275,9 +329,9 @@ void mqtt_sending_task(void *pvParameters) client->mqtt_state.pending_msg_id = mqtt_get_id(client->mqtt_state.outbound_message->data, client->mqtt_state.outbound_message->length); mqtt_info("Sending pingreq"); - ClientWrite( + client->settings->write_cb(client, client->mqtt_state.outbound_message->data, - client->mqtt_state.outbound_message->length); + client->mqtt_state.outbound_message->length, 0); } } } @@ -315,7 +369,7 @@ void deliver_publish(mqtt_client *client, uint8_t *message, int length) if (client->mqtt_state.message_length_read >= client->mqtt_state.message_length) break; - len_read = ClientRead(client->mqtt_state.in_buffer, CONFIG_MQTT_BUFFER_SIZE_BYTE); + len_read = client->settings->read_cb(client, client->mqtt_state.in_buffer, CONFIG_MQTT_BUFFER_SIZE_BYTE, 0); client->mqtt_state.message_length_read += len_read; } while (1); @@ -329,7 +383,7 @@ void mqtt_start_receive_schedule(mqtt_client *client) while (1) { - read_len = ClientRead(client->mqtt_state.in_buffer, CONFIG_MQTT_BUFFER_SIZE_BYTE); + read_len = client->settings->read_cb(client, client->mqtt_state.in_buffer, CONFIG_MQTT_BUFFER_SIZE_BYTE, 0); mqtt_info("Read len %d", read_len); if (read_len == 0) @@ -419,11 +473,11 @@ void mqtt_task(void *pvParameters) mqtt_client *client = (mqtt_client *)pvParameters; while (1) { - client_connect(client); + client->settings->connect_cb(client); mqtt_info("Connected to server %s:%d", client->settings->host, client->settings->port); if (!mqtt_connect(client)) { - closeclient(client); + client->settings->disconnect_cb(client); continue; //return; } @@ -436,7 +490,7 @@ void mqtt_task(void *pvParameters) mqtt_info("mqtt_start_receive_schedule"); mqtt_start_receive_schedule(client); - closeclient(client); + client->settings->disconnect_cb(client); vTaskDelete(xMqttSendingTask); vTaskDelay(1000 / portTICK_RATE_MS); @@ -484,6 +538,15 @@ mqtt_client *mqtt_start(mqtt_settings *settings) client->socket = -1; + if (!client->settings->connect_cb) + client->settings->connect_cb = client_connect; + if (!client->settings->disconnect_cb) + client->settings->disconnect_cb = closeclient; + if (!client->settings->read_cb) + client->settings->read_cb = mqtt_read; + if (!client->settings->write_cb) + client->settings->write_cb = mqtt_write; + #if defined(CONFIG_MQTT_SECURITY_ON) // ENABLE MQTT OVER SSL client->ctx = NULL; client->ssl = NULL;