diff --git a/README.md b/README.md index f083ab0..77e552a 100644 --- a/README.md +++ b/README.md @@ -96,6 +96,8 @@ const esp_mqtt_client_config_t mqtt_cfg = { - `task_prio, task_stack` for MQTT task, default priority is 5, and task_stack = 6144 bytes (or default task stack can be set via `make menucofig`). - `buffer_size` for MQTT send/receive buffer, default is 1024 - `cert_pem` pointer to CERT file for server verify (with SSL), default is NULL, not required to verify the server +- `client_cert_pem` pointer to CERT file for SSL mutual authentication, default is NULL, not required if mutual authentication is not needed. If it is not NULL, also `client_key_pem` has to be provided. +- `client_key_pem` pointer to PEM private key file for SSL mutual authentication, default is NULL, not required if mutual authentication is not needed. If it is not NULL, also `client_cert_pem` has to be provided. - `transport`: override URI transport + `MQTT_TRANSPORT_OVER_TCP`: MQTT over TCP, using scheme: `mqtt` + `MQTT_TRANSPORT_OVER_SSL`: MQTT over SSL, using scheme: `mqtts` diff --git a/examples/mqtt_ssl_mutual_auth/CMakeLists.txt b/examples/mqtt_ssl_mutual_auth/CMakeLists.txt new file mode 100644 index 0000000..106b117 --- /dev/null +++ b/examples/mqtt_ssl_mutual_auth/CMakeLists.txt @@ -0,0 +1,19 @@ +cmake_minimum_required(VERSION 3.5) + +get_filename_component(DEV_ROOT "${CMAKE_CURRENT_SOURCE_DIR}" ABSOLUTE) + +set(PROJECT_ROOT "${DEV_ROOT}/") + +set(SUBMODULE_ROOT "${DEV_ROOT}/../../../") + +set(PROJECT_NAME "mqtt_ssl_mutual_auth") + +include($ENV{IDF_PATH}/tools/cmake/project.cmake) + +set(MAIN_SRCS ${PROJECT_ROOT}/main/app_main.c) + +set(EXTRA_COMPONENT_DIRS "${EXTRA_COMPONENT_DIRS} ${SUBMODULE_ROOT}") +set(BUILD_COMPONENTS "${BUILD_COMPONENTS} espmqtt") + +project(${PROJECT_NAME}) + diff --git a/examples/mqtt_ssl_mutual_auth/Makefile b/examples/mqtt_ssl_mutual_auth/Makefile new file mode 100644 index 0000000..c22f41d --- /dev/null +++ b/examples/mqtt_ssl_mutual_auth/Makefile @@ -0,0 +1,13 @@ +# +# This is a project Makefile. It is assumed the directory this Makefile resides in is a +# project subdirectory. +# +# +# This is a project Makefile. It is assumed the directory this Makefile resides in is a +# project subdirectory. +# +PROJECT_NAME := mqtt_ssl_mutual_auth +EXTRA_COMPONENT_DIRS += $(PROJECT_PATH)/../../../ + +include $(IDF_PATH)/make/project.mk + diff --git a/examples/mqtt_ssl_mutual_auth/README.md b/examples/mqtt_ssl_mutual_auth/README.md new file mode 100644 index 0000000..c415cdf --- /dev/null +++ b/examples/mqtt_ssl_mutual_auth/README.md @@ -0,0 +1,16 @@ +# ESPMQTT SSL Sample application + +Navigate to the main directory + +``` +cd main +``` + +Generate a client key and a CSR. When you are generating the CSR, do not use the default values. At a minimum, the CSR must include the Country, Organisation and Common Name fields. + +``` +openssl genrsa -out client.key +openssl req -out client.csr -key client.key -new +``` + +Paste the generated CSR in the [Mosquitto test certificate signer](https://test.mosquitto.org/ssl/index.php), click Submit and copy the downloaded `client.crt` in the `main` directory. diff --git a/examples/mqtt_ssl_mutual_auth/main/Kconfig.projbuild b/examples/mqtt_ssl_mutual_auth/main/Kconfig.projbuild new file mode 100644 index 0000000..1c9c2e6 --- /dev/null +++ b/examples/mqtt_ssl_mutual_auth/main/Kconfig.projbuild @@ -0,0 +1,15 @@ +menu "MQTT Application sample" + +config WIFI_SSID + string "WiFi SSID" + default "myssid" + help + SSID (network name) for the example to connect to. + +config WIFI_PASSWORD + string "WiFi Password" + default "mypassword" + help + WiFi password (WPA or WPA2) for the example to use. + +endmenu diff --git a/examples/mqtt_ssl_mutual_auth/main/app_main.c b/examples/mqtt_ssl_mutual_auth/main/app_main.c new file mode 100755 index 0000000..f290d15 --- /dev/null +++ b/examples/mqtt_ssl_mutual_auth/main/app_main.c @@ -0,0 +1,152 @@ +#include +#include +#include +#include +#include "esp_wifi.h" +#include "esp_system.h" +#include "nvs_flash.h" +#include "esp_event_loop.h" + +#include "freertos/FreeRTOS.h" +#include "freertos/task.h" +#include "freertos/semphr.h" +#include "freertos/queue.h" +#include "freertos/event_groups.h" + +#include "lwip/sockets.h" +#include "lwip/dns.h" +#include "lwip/netdb.h" + +#include "esp_log.h" +#include "mqtt_client.h" + +static const char *TAG = "MQTTS_SAMPLE"; + +static EventGroupHandle_t wifi_event_group; +const static int CONNECTED_BIT = BIT0; + + + +static esp_err_t wifi_event_handler(void *ctx, system_event_t *event) +{ + switch (event->event_id) { + case SYSTEM_EVENT_STA_START: + esp_wifi_connect(); + break; + case SYSTEM_EVENT_STA_GOT_IP: + xEventGroupSetBits(wifi_event_group, CONNECTED_BIT); + + break; + case SYSTEM_EVENT_STA_DISCONNECTED: + esp_wifi_connect(); + xEventGroupClearBits(wifi_event_group, CONNECTED_BIT); + break; + default: + break; + } + return ESP_OK; +} + +static void wifi_init(void) +{ + tcpip_adapter_init(); + wifi_event_group = xEventGroupCreate(); + ESP_ERROR_CHECK(esp_event_loop_init(wifi_event_handler, NULL)); + wifi_init_config_t cfg = WIFI_INIT_CONFIG_DEFAULT(); + ESP_ERROR_CHECK(esp_wifi_init(&cfg)); + ESP_ERROR_CHECK(esp_wifi_set_storage(WIFI_STORAGE_RAM)); + wifi_config_t wifi_config = { + .sta = { + .ssid = CONFIG_WIFI_SSID, + .password = CONFIG_WIFI_PASSWORD, + }, + }; + ESP_ERROR_CHECK(esp_wifi_set_mode(WIFI_MODE_STA)); + ESP_ERROR_CHECK(esp_wifi_set_config(ESP_IF_WIFI_STA, &wifi_config)); + ESP_LOGI(TAG, "start the WIFI SSID:[%s] password:[%s]", CONFIG_WIFI_SSID, "******"); + ESP_ERROR_CHECK(esp_wifi_start()); + ESP_LOGI(TAG, "Waiting for wifi"); + xEventGroupWaitBits(wifi_event_group, CONNECTED_BIT, false, true, portMAX_DELAY); +} + +extern const uint8_t client_cert_pem_start[] asm("_binary_client_crt_start"); +extern const uint8_t client_cert_pem_end[] asm("_binary_client_crt_end"); +extern const uint8_t client_key_pem_start[] asm("_binary_client_key_start"); +extern const uint8_t client_key_pem_end[] asm("_binary_client_key_end"); + +static esp_err_t mqtt_event_handler(esp_mqtt_event_handle_t event) +{ + esp_mqtt_client_handle_t client = event->client; + int msg_id; + // your_context_t *context = event->context; + switch (event->event_id) { + case MQTT_EVENT_CONNECTED: + ESP_LOGI(TAG, "MQTT_EVENT_CONNECTED"); + msg_id = esp_mqtt_client_subscribe(client, "/topic/qos0", 0); + ESP_LOGI(TAG, "sent subscribe successful, msg_id=%d", msg_id); + + msg_id = esp_mqtt_client_subscribe(client, "/topic/qos1", 1); + ESP_LOGI(TAG, "sent subscribe successful, msg_id=%d", msg_id); + + msg_id = esp_mqtt_client_unsubscribe(client, "/topic/qos1"); + ESP_LOGI(TAG, "sent unsubscribe successful, msg_id=%d", msg_id); + break; + case MQTT_EVENT_DISCONNECTED: + ESP_LOGI(TAG, "MQTT_EVENT_DISCONNECTED"); + break; + + case MQTT_EVENT_SUBSCRIBED: + ESP_LOGI(TAG, "MQTT_EVENT_SUBSCRIBED, msg_id=%d", event->msg_id); + msg_id = esp_mqtt_client_publish(client, "/topic/qos0", "data", 0, 0, 0); + ESP_LOGI(TAG, "sent publish successful, msg_id=%d", msg_id); + break; + case MQTT_EVENT_UNSUBSCRIBED: + ESP_LOGI(TAG, "MQTT_EVENT_UNSUBSCRIBED, msg_id=%d", event->msg_id); + break; + case MQTT_EVENT_PUBLISHED: + ESP_LOGI(TAG, "MQTT_EVENT_PUBLISHED, msg_id=%d", event->msg_id); + break; + case MQTT_EVENT_DATA: + ESP_LOGI(TAG, "MQTT_EVENT_DATA"); + printf("TOPIC=%.*s\r\n", event->topic_len, event->topic); + printf("DATA=%.*s\r\n", event->data_len, event->data); + break; + case MQTT_EVENT_ERROR: + ESP_LOGI(TAG, "MQTT_EVENT_ERROR"); + break; + } + return ESP_OK; +} + +static void mqtt_app_start(void) +{ + const esp_mqtt_client_config_t mqtt_cfg = { + .uri = "mqtts://test.mosquitto.org:8884", + .event_handle = mqtt_event_handler, + .client_cert_pem = (const char *)client_cert_pem_start, + .client_key_pem = (const char *)client_key_pem_start, + }; + + ESP_LOGI(TAG, "[APP] Free memory: %d bytes", esp_get_free_heap_size()); + esp_mqtt_client_handle_t client = esp_mqtt_client_init(&mqtt_cfg); + esp_mqtt_client_start(client); +} + +void app_main() +{ + ESP_LOGI(TAG, "[APP] Startup.."); + ESP_LOGI(TAG, "[APP] Free memory: %d bytes", esp_get_free_heap_size()); + ESP_LOGI(TAG, "[APP] IDF version: %s", esp_get_idf_version()); + + esp_log_level_set("*", ESP_LOG_INFO); + esp_log_level_set("MQTT_CLIENT", ESP_LOG_VERBOSE); + esp_log_level_set("TRANSPORT_TCP", ESP_LOG_VERBOSE); + esp_log_level_set("TRANSPORT_SSL", ESP_LOG_VERBOSE); + esp_log_level_set("TRANSPORT", ESP_LOG_VERBOSE); + esp_log_level_set("OUTBOX", ESP_LOG_VERBOSE); + + nvs_flash_init(); + wifi_init(); + mqtt_app_start(); + +} diff --git a/examples/mqtt_ssl_mutual_auth/main/component.mk b/examples/mqtt_ssl_mutual_auth/main/component.mk new file mode 100644 index 0000000..01adda5 --- /dev/null +++ b/examples/mqtt_ssl_mutual_auth/main/component.mk @@ -0,0 +1 @@ +COMPONENT_EMBED_TXTFILES := client.crt client.key diff --git a/include/mqtt_client.h b/include/mqtt_client.h index 01a2fe1..766705d 100755 --- a/include/mqtt_client.h +++ b/include/mqtt_client.h @@ -77,6 +77,8 @@ typedef struct { int task_stack; int buffer_size; const char *cert_pem; + const char *client_cert_pem; + const char *client_key_pem; esp_mqtt_transport_t transport; } esp_mqtt_client_config_t; diff --git a/lib/include/transport_ssl.h b/lib/include/transport_ssl.h index 2469aa5..fb09c3c 100644 --- a/lib/include/transport_ssl.h +++ b/lib/include/transport_ssl.h @@ -30,6 +30,8 @@ transport_handle_t transport_ssl_init(); * @param[in] len The length */ void transport_ssl_set_cert_data(transport_handle_t t, const char *data, int len); +void transport_ssl_set_client_cert_data(transport_handle_t t, const char *data, int len); +void transport_ssl_set_client_key_data(transport_handle_t t, const char *data, int len); #ifdef __cplusplus diff --git a/lib/transport_ssl.c b/lib/transport_ssl.c index a2169c3..7528e2e 100644 --- a/lib/transport_ssl.c +++ b/lib/transport_ssl.c @@ -35,10 +35,17 @@ typedef struct { mbedtls_ctr_drbg_context ctr_drbg; mbedtls_ssl_context ctx; mbedtls_x509_crt cacert; + mbedtls_x509_crt client_cert; + mbedtls_pk_context client_key; mbedtls_ssl_config conf; mbedtls_net_context client_fd; void *cert_pem_data; int cert_pem_len; + void *client_cert_pem_data; + int client_cert_pem_len; + void *client_key_pem_data; + int client_key_pem_len; + bool mutual_authentication; bool ssl_initialized; bool verify_server; } transport_ssl_t; @@ -91,6 +98,27 @@ static int ssl_connect(transport_handle_t t, const char *host, int port, int tim mbedtls_ssl_conf_authmode(&ssl->conf, MBEDTLS_SSL_VERIFY_NONE); } + if (ssl->client_cert_pem_data && ssl->client_key_pem_data) { + mbedtls_x509_crt_init(&ssl->client_cert); + mbedtls_pk_init(&ssl->client_key); + ssl->mutual_authentication = true; + if ((ret = mbedtls_x509_crt_parse(&ssl->client_cert, ssl->client_cert_pem_data, ssl->client_cert_pem_len + 1)) < 0) { + ESP_LOGE(TAG, "mbedtls_x509_crt_parse returned -0x%x\n\nDATA=%s,len=%d", -ret, (char*)ssl->client_cert_pem_data, ssl->client_cert_pem_len); + goto exit; + } + if ((ret = mbedtls_pk_parse_key(&ssl->client_key, ssl->client_key_pem_data, ssl->client_key_pem_len + 1, NULL, 0)) < 0) { + ESP_LOGE(TAG, "mbedtls_pk_parse_keyfile returned -0x%x\n\nDATA=%s,len=%d", -ret, (char*)ssl->client_key_pem_data, ssl->client_key_pem_len); + goto exit; + } + + if ((ret = mbedtls_ssl_conf_own_cert(&ssl->conf, &ssl->client_cert, &ssl->client_key)) < 0) { + ESP_LOGE(TAG, "mbedtls_ssl_conf_own_cert returned -0x%x\n", -ret); + goto exit; + } + } else if (ssl->client_cert_pem_data || ssl->client_key_pem_data) { + ESP_LOGE(TAG, "You have to provide both client_cert_pem and client_key_pem for mutual authentication"); + goto exit; + } mbedtls_ssl_conf_rng(&ssl->conf, mbedtls_ctr_drbg_random, &ssl->ctr_drbg); @@ -220,9 +248,14 @@ static int ssl_close(transport_handle_t t) if (ssl->verify_server) { mbedtls_x509_crt_free(&ssl->cacert); } + if (ssl->mutual_authentication) { + mbedtls_x509_crt_free(&ssl->client_cert); + mbedtls_pk_free(&ssl->client_key); + } mbedtls_ctr_drbg_free(&ssl->ctr_drbg); mbedtls_entropy_free(&ssl->entropy); mbedtls_ssl_free(&ssl->ctx); + ssl->mutual_authentication = false; ssl->ssl_initialized = false; ssl->verify_server = false; } @@ -246,6 +279,24 @@ void transport_ssl_set_cert_data(transport_handle_t t, const char *data, int len } } +void transport_ssl_set_client_cert_data(transport_handle_t t, const char *data, int len) +{ + transport_ssl_t *ssl = transport_get_context_data(t); + if (t && ssl) { + ssl->client_cert_pem_data = (void *)data; + ssl->client_cert_pem_len = len; + } +} + +void transport_ssl_set_client_key_data(transport_handle_t t, const char *data, int len) +{ + transport_ssl_t *ssl = transport_get_context_data(t); + if (t && ssl) { + ssl->client_key_pem_data = (void *)data; + ssl->client_key_pem_len = len; + } +} + transport_handle_t transport_ssl_init() { transport_handle_t t = transport_init(); diff --git a/mqtt_client.c b/mqtt_client.c index f9f9a44..8f3c050 100644 --- a/mqtt_client.c +++ b/mqtt_client.c @@ -299,6 +299,12 @@ esp_mqtt_client_handle_t esp_mqtt_client_init(const esp_mqtt_client_config_t *co if (config->cert_pem) { transport_ssl_set_cert_data(ssl, config->cert_pem, strlen(config->cert_pem)); } + if (config->client_cert_pem) { + transport_ssl_set_client_cert_data(ssl, config->client_cert_pem, strlen(config->client_cert_pem)); + } + if (config->client_key_pem) { + transport_ssl_set_client_key_data(ssl, config->client_key_pem, strlen(config->client_key_pem)); + } transport_list_add(client->transport_list, ssl, "mqtts"); if (config->transport == MQTT_TRANSPORT_OVER_SSL) { client->config->scheme = create_string("mqtts", 5);