diff --git a/include/mqtt_config.h b/include/mqtt_config.h index c0b4ab7..972f2a1 100644 --- a/include/mqtt_config.h +++ b/include/mqtt_config.h @@ -10,6 +10,7 @@ #define MQTT_PROTOCOL_311 CONFIG_MQTT_PROTOCOL_311 #define MQTT_RECONNECT_TIMEOUT_MS (10*1000) +#define MQTT_POLL_READ_TIMEOUT_MS (1000) #if CONFIG_MQTT_BUFFER_SIZE #define MQTT_BUFFER_SIZE_BYTE CONFIG_MQTT_BUFFER_SIZE @@ -61,14 +62,18 @@ #define MQTT_CORE_SELECTION_ENABLED CONFIG_MQTT_TASK_CORE_SELECTION_ENABLED +#ifdef CONFIG_MQTT_DISABLE_API_LOCKS +#define MQTT_DISABLE_API_LOCKS CONFIG_MQTT_DISABLE_API_LOCKS +#endif + #ifdef CONFIG_MQTT_USE_CORE_0 - #define MQTT_TASK_CORE 0 + #define MQTT_TASK_CORE 0 #else - #ifdef CONFIG_MQTT_USE_CORE_1 - #define MQTT_TASK_CORE 1 - #else - #define MQTT_TASK_CORE 0 - #endif + #ifdef CONFIG_MQTT_USE_CORE_1 + #define MQTT_TASK_CORE 1 + #else + #define MQTT_TASK_CORE 0 + #endif #endif diff --git a/mqtt_client.c b/mqtt_client.c index 73d3624..8fe31cc 100644 --- a/mqtt_client.c +++ b/mqtt_client.c @@ -13,6 +13,18 @@ /* using uri parser */ #include "http_parser.h" +#ifdef MQTT_DISABLE_API_LOCKS +# define MQTT_API_LOCK(c) +# define MQTT_API_UNLOCK(c) +# define MQTT_API_LOCK_FROM_OTHER_TASK(c) +# define MQTT_API_UNLOCK_FROM_OTHER_TASK(c) +#else +# define MQTT_API_LOCK(c) xSemaphoreTake(c->api_lock, portMAX_DELAY) +# define MQTT_API_UNLOCK(c) xSemaphoreGive(c->api_lock) +# define MQTT_API_LOCK_FROM_OTHER_TASK(c) { if (c->task_handle != xTaskGetCurrentTaskHandle()) { xSemaphoreTake(c->api_lock, portMAX_DELAY); } } +# define MQTT_API_UNLOCK_FROM_OTHER_TASK(c) { if (c->task_handle != xTaskGetCurrentTaskHandle()) { xSemaphoreGive(c->api_lock); } } +#endif /* MQTT_USE_API_LOCKS */ + static const char *TAG = "MQTT_CLIENT"; typedef struct mqtt_state @@ -72,6 +84,8 @@ struct esp_mqtt_client { bool wait_for_ping_resp; outbox_handle_t outbox; EventGroupHandle_t status_bits; + SemaphoreHandle_t api_lock; + TaskHandle_t task_handle; }; const static int STOPPED_BIT = BIT0; @@ -87,6 +101,7 @@ static char *create_string(const char *ptr, int len); esp_err_t esp_mqtt_set_config(esp_mqtt_client_handle_t client, const esp_mqtt_client_config_t *config) { + MQTT_API_LOCK(client); //Copy user configurations to client context esp_err_t err = ESP_OK; mqtt_config_storage_t *cfg; @@ -94,7 +109,10 @@ esp_err_t esp_mqtt_set_config(esp_mqtt_client_handle_t client, const esp_mqtt_cl cfg = client->config; } else { cfg = calloc(1, sizeof(mqtt_config_storage_t)); - ESP_MEM_CHECK(TAG, cfg, return ESP_ERR_NO_MEM); + ESP_MEM_CHECK(TAG, cfg, { + MQTT_API_UNLOCK(client); + return ESP_ERR_NO_MEM; + }); client->config = cfg; } if (config->task_prio) { @@ -200,9 +218,11 @@ esp_err_t esp_mqtt_set_config(esp_mqtt_client_handle_t client, const esp_mqtt_cl if (config->disable_auto_reconnect == cfg->auto_reconnect) { cfg->auto_reconnect = !config->disable_auto_reconnect; } + MQTT_API_UNLOCK(client); return ESP_OK; _mqtt_set_config_failed: esp_mqtt_destroy_config(client); + MQTT_API_UNLOCK(client); return err; } @@ -302,8 +322,13 @@ esp_mqtt_client_handle_t esp_mqtt_client_init(const esp_mqtt_client_config_t *co { esp_mqtt_client_handle_t client = calloc(1, sizeof(struct esp_mqtt_client)); ESP_MEM_CHECK(TAG, client, return NULL); - + client->api_lock = xSemaphoreCreateMutex(); + if (!client->api_lock) { + free(client); + return NULL; + } esp_mqtt_set_config(client, config); + MQTT_API_LOCK(client); client->transport_list = esp_transport_list_init(); ESP_MEM_CHECK(TAG, client->transport_list, goto _mqtt_init_failed); @@ -384,9 +409,11 @@ esp_mqtt_client_handle_t esp_mqtt_client_init(const esp_mqtt_client_config_t *co ESP_MEM_CHECK(TAG, client->outbox, goto _mqtt_init_failed); client->status_bits = xEventGroupCreate(); ESP_MEM_CHECK(TAG, client->status_bits, goto _mqtt_init_failed); + MQTT_API_UNLOCK(client); return client; _mqtt_init_failed: esp_mqtt_client_destroy(client); + MQTT_API_UNLOCK(client); return NULL; } @@ -399,6 +426,7 @@ esp_err_t esp_mqtt_client_destroy(esp_mqtt_client_handle_t client) vEventGroupDelete(client->status_bits); free(client->mqtt_state.in_buffer); free(client->mqtt_state.out_buffer); + vSemaphoreDelete(client->api_lock); free(client); return ESP_OK; } @@ -638,7 +666,7 @@ static esp_err_t mqtt_process_receive(esp_mqtt_client_handle_t client) uint16_t msg_id; uint32_t transport_message_offset = 0 ; - read_len = esp_transport_read(client->transport, (char *)client->mqtt_state.in_buffer, client->mqtt_state.in_buffer_length, 1000); + read_len = esp_transport_read(client->transport, (char *)client->mqtt_state.in_buffer, client->mqtt_state.in_buffer_length, 0); if (read_len < 0) { ESP_LOGE(TAG, "Read error or end of stream"); @@ -776,7 +804,7 @@ static void esp_mqtt_task(void *pv) client->state = MQTT_STATE_INIT; xEventGroupClearBits(client->status_bits, STOPPED_BIT); while (client->run) { - + MQTT_API_LOCK(client); switch ((int)client->state) { case MQTT_STATE_INIT: xEventGroupClearBits(client->status_bits, RECONNECT_BIT); @@ -873,10 +901,20 @@ static void esp_mqtt_task(void *pv) ESP_LOGD(TAG, "Reconnecting..."); break; } + MQTT_API_UNLOCK(client); xEventGroupWaitBits(client->status_bits, RECONNECT_BIT, false, true, client->wait_timeout_ms / 2 / portTICK_RATE_MS); - break; + // continue the while loop insted of break, as the mutex is unlocked + continue; } + MQTT_API_UNLOCK(client); + if (MQTT_STATE_CONNECTED == client->state) { + if (esp_transport_poll_read(client->transport, MQTT_POLL_READ_TIMEOUT_MS) < 0) { + ESP_LOGE(TAG, "Poll read error: %d, aborting connection", errno); + esp_mqtt_abort_connection(client); + } + } + } esp_transport_close(client->transport); xEventGroupSetBits(client->status_bits, STOPPED_BIT); @@ -892,13 +930,13 @@ esp_err_t esp_mqtt_client_start(esp_mqtt_client_handle_t client) } #if MQTT_CORE_SELECTION_ENABLED ESP_LOGD(TAG, "Core selection enabled on %u", MQTT_TASK_CORE); - if (xTaskCreatePinnedToCore(esp_mqtt_task, "mqtt_task", client->config->task_stack, client, client->config->task_prio, NULL, MQTT_TASK_CORE) != pdTRUE) { + if (xTaskCreatePinnedToCore(esp_mqtt_task, "mqtt_task", client->config->task_stack, client, client->config->task_prio, &client->task_handle, MQTT_TASK_CORE) != pdTRUE) { ESP_LOGE(TAG, "Error create mqtt task"); return ESP_FAIL; } #else ESP_LOGD(TAG, "Core selection disabled"); - if (xTaskCreate(esp_mqtt_task, "mqtt_task", client->config->task_stack, client, client->config->task_prio, NULL) != pdTRUE) { + if (xTaskCreate(esp_mqtt_task, "mqtt_task", client->config->task_stack, client, client->config->task_prio, &client->task_handle) != pdTRUE) { ESP_LOGE(TAG, "Error create mqtt task"); return ESP_FAIL; } @@ -949,6 +987,7 @@ int esp_mqtt_client_subscribe(esp_mqtt_client_handle_t client, const char *topic ESP_LOGE(TAG, "Client has not connected"); return -1; } + MQTT_API_LOCK_FROM_OTHER_TASK(client); client->mqtt_state.outbound_message = mqtt_msg_subscribe(&client->mqtt_state.mqtt_connection, topic, qos, &client->mqtt_state.pending_msg_id); @@ -959,10 +998,12 @@ int esp_mqtt_client_subscribe(esp_mqtt_client_handle_t client, const char *topic if (mqtt_write_data(client) != ESP_OK) { ESP_LOGE(TAG, "Error to subscribe topic=%s, qos=%d", topic, qos); + MQTT_API_UNLOCK_FROM_OTHER_TASK(client); return -1; } ESP_LOGD(TAG, "Sent subscribe topic=%s, id: %d, type=%d successful", topic, client->mqtt_state.pending_msg_id, client->mqtt_state.pending_msg_type); + MQTT_API_UNLOCK_FROM_OTHER_TASK(client); return client->mqtt_state.pending_msg_id; } @@ -972,7 +1013,7 @@ int esp_mqtt_client_unsubscribe(esp_mqtt_client_handle_t client, const char *top ESP_LOGE(TAG, "Client has not connected"); return -1; } - mqtt_enqueue(client); + MQTT_API_LOCK_FROM_OTHER_TASK(client); client->mqtt_state.outbound_message = mqtt_msg_unsubscribe(&client->mqtt_state.mqtt_connection, topic, &client->mqtt_state.pending_msg_id); @@ -980,13 +1021,16 @@ int esp_mqtt_client_unsubscribe(esp_mqtt_client_handle_t client, const char *top client->mqtt_state.pending_msg_type = mqtt_get_type(client->mqtt_state.outbound_message->data); client->mqtt_state.pending_msg_count ++; + mqtt_enqueue(client); if (mqtt_write_data(client) != ESP_OK) { ESP_LOGE(TAG, "Error to unsubscribe topic=%s", topic); + MQTT_API_UNLOCK_FROM_OTHER_TASK(client); return -1; } ESP_LOGD(TAG, "Sent Unsubscribe topic=%s, id: %d, successful", topic, client->mqtt_state.pending_msg_id); + MQTT_API_UNLOCK_FROM_OTHER_TASK(client); return client->mqtt_state.pending_msg_id; } @@ -998,6 +1042,7 @@ int esp_mqtt_client_publish(esp_mqtt_client_handle_t client, const char *topic, len = strlen(data); } + MQTT_API_LOCK_FROM_OTHER_TASK(client); mqtt_message_t *publish_msg = mqtt_msg_publish(&client->mqtt_state.mqtt_connection, topic, data, len, qos, retain, @@ -1072,12 +1117,14 @@ int esp_mqtt_client_publish(esp_mqtt_client_handle_t client, const char *topic, if (qos > 0) { outbox_set_pending(client->outbox, pending_msg_id, TRANSMITTED); } + MQTT_API_UNLOCK_FROM_OTHER_TASK(client); return pending_msg_id; cannot_publish: if (qos == 0) { - ESP_LOGW(TAG, "Publishing qos0 data while client not connecting"); + ESP_LOGW(TAG, "Publish: Loosing qos0 data when client not connected"); } + MQTT_API_UNLOCK_FROM_OTHER_TASK(client); return 0; }