From a5ef145163ff202fd0519d7142da8d5e2656b7d6 Mon Sep 17 00:00:00 2001 From: Tuan PM Date: Tue, 27 Feb 2018 23:02:35 +0700 Subject: [PATCH] check transport before use & close the connection if it connects failed --- lib/transport.c | 35 ++++++++++++++++++++++++++--------- lib/transport_ssl.c | 7 ++++++- mqtt_client.c | 1 + 3 files changed, 33 insertions(+), 10 deletions(-) diff --git a/lib/transport.c b/lib/transport.c index b2efd40..f443fca 100644 --- a/lib/transport.c +++ b/lib/transport.c @@ -110,7 +110,7 @@ int transport_destroy(transport_handle_t t) int transport_connect(transport_handle_t t, const char *host, int port, int timeout_ms) { int ret = -1; - if (t->_connect) { + if (t && t->_connect) { return t->_connect(t, host, port, timeout_ms); } return ret; @@ -118,7 +118,7 @@ int transport_connect(transport_handle_t t, const char *host, int port, int time int transport_read(transport_handle_t t, char *buffer, int len, int timeout_ms) { - if (t->_read) { + if (t && t->_read) { return t->_read(t, buffer, len, timeout_ms); } return -1; @@ -126,7 +126,7 @@ int transport_read(transport_handle_t t, char *buffer, int len, int timeout_ms) int transport_write(transport_handle_t t, char *buffer, int len, int timeout_ms) { - if (t->_write) { + if (t && t->_write) { return t->_write(t, buffer, len, timeout_ms); } return -1; @@ -134,7 +134,7 @@ int transport_write(transport_handle_t t, char *buffer, int len, int timeout_ms) int transport_poll_read(transport_handle_t t, int timeout_ms) { - if (t->_poll_read) { + if (t && t->_poll_read) { return t->_poll_read(t, timeout_ms); } return -1; @@ -142,7 +142,7 @@ int transport_poll_read(transport_handle_t t, int timeout_ms) int transport_poll_write(transport_handle_t t, int timeout_ms) { - if (t->_poll_write) { + if (t && t->_poll_write) { return t->_poll_write(t, timeout_ms); } return -1; @@ -150,7 +150,7 @@ int transport_poll_write(transport_handle_t t, int timeout_ms) int transport_close(transport_handle_t t) { - if (t->_close) { + if (t && t->_close) { return t->_close(t); } return 0; @@ -158,13 +158,19 @@ int transport_close(transport_handle_t t) void *transport_get_data(transport_handle_t t) { - return t->data; + if (t) { + return t->data; + } + return NULL; } esp_err_t transport_set_data(transport_handle_t t, void *data) { - t->data = data; - return ESP_OK; + if (t) { + t->data = data; + return ESP_OK; + } + return ESP_FAIL; } esp_err_t transport_set_func(transport_handle_t t, @@ -176,6 +182,9 @@ esp_err_t transport_set_func(transport_handle_t t, poll_func _poll_write, trans_func _destroy) { + if (t == NULL) { + return ESP_FAIL; + } t->_connect = _connect; t->_read = _read; t->_write = _write; @@ -185,12 +194,20 @@ esp_err_t transport_set_func(transport_handle_t t, t->_destroy = _destroy; return ESP_OK; } + int transport_get_default_port(transport_handle_t t) { + if (t == NULL) { + return -1; + } return t->port; } + esp_err_t transport_set_default_port(transport_handle_t t, int port) { + if (t == NULL) { + return ESP_FAIL; + } t->port = port; return ESP_OK; } diff --git a/lib/transport_ssl.c b/lib/transport_ssl.c index 6137488..420b020 100644 --- a/lib/transport_ssl.c +++ b/lib/transport_ssl.c @@ -32,6 +32,7 @@ typedef struct { void *cert_pem_data; int cert_pem_len; bool ssl_initialized; + bool verify_server; } transport_ssl_t; static int ssl_connect(transport_handle_t t, const char *host, int port, int timeout_ms) @@ -64,6 +65,7 @@ static int ssl_connect(transport_handle_t t, const char *host, int port, int tim if (ssl->cert_pem_data) { mbedtls_x509_crt_init(&ssl->cacert); + ssl->verify_server = true; if ((ret = mbedtls_x509_crt_parse(&ssl->cacert, ssl->cert_pem_data, ssl->cert_pem_len + 1)) < 0) { ESP_LOGE(TAG, "mbedtls_x509_crt_parse returned -0x%x\n\nDATA=%s,len=%d", -ret, (char*)ssl->cert_pem_data, ssl->cert_pem_len); goto exit; @@ -204,11 +206,14 @@ static int ssl_close(transport_handle_t t) mbedtls_ssl_session_reset(&ssl->ctx); mbedtls_net_free(&ssl->client_fd); mbedtls_ssl_config_free(&ssl->conf); - mbedtls_x509_crt_free(&ssl->cacert); + if (ssl->verify_server) { + mbedtls_x509_crt_free(&ssl->cacert); + } mbedtls_ctr_drbg_free(&ssl->ctr_drbg); mbedtls_entropy_free(&ssl->entropy); mbedtls_ssl_free(&ssl->ctx); ssl->ssl_initialized = false; + ssl->verify_server = false; } return ret; } diff --git a/mqtt_client.c b/mqtt_client.c index f54a19c..f3aaaed 100644 --- a/mqtt_client.c +++ b/mqtt_client.c @@ -249,6 +249,7 @@ static esp_err_t esp_mqtt_connect(esp_mqtt_client_handle_t client, int timeout_m static esp_err_t esp_mqtt_abort_connection(esp_mqtt_client_handle_t client) { + transport_close(client->transport); client->wait_timeout_ms = MQTT_RECONNECT_TIMEOUT_MS; client->reconnect_tick = platform_tick_get_ms(); client->state = MQTT_STATE_WAIT_TIMEOUT;