mirror of
				https://github.com/espressif/esp-protocols.git
				synced 2025-10-30 22:21:39 +01:00 
			
		
		
		
	
		
			
				
	
	
		
			334 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			334 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
| /*
 | |
|  * SPDX-FileCopyrightText: 2021-2024 Espressif Systems (Shanghai) CO LTD
 | |
|  *
 | |
|  * SPDX-License-Identifier: Apache-2.0
 | |
|  */
 | |
| 
 | |
#include <algorithm>
#include <cinttypes>
#include <cstring>
#include <iterator>
#include <stdexcept>
#include <string>
#include <utility>
#include <variant>
 | |
| 
 | |
| #include "mqtt_client.h"
 | |
| #include "esp_log.h"
 | |
| 
 | |
| #include "esp_mqtt.hpp"
 | |
| 
 | |
| namespace {
 | |
| 
 | |
// Dependent "false": lets a static_assert inside a generic catch-all visitor
// fire only when that catch-all is actually instantiated for some type.
template<typename>
constexpr bool always_false = false;

// Combines several lambdas into one callable whose call-operator overload set
// is the union of theirs; used with std::visit.
template<typename... Visitors>
struct overloaded : Visitors... {
    using Visitors::operator()...;
};

// C++17 deduction guide so `overloaded{l1, l2, ...}` deduces the lambda types.
template<typename... Visitors>
overloaded(Visitors...) -> overloaded<Visitors...>;
 | |
| 
 | |
| using namespace idf::mqtt;
 | |
| 
 | |
| /*
 | |
|  *  This function is responsible for fill in the configurations for the broker related data
 | |
|  *  of mqtt_client_config_t
 | |
|  */
 | |
| void config_broker(esp_mqtt_client_config_t &mqtt_client_cfg, BrokerConfiguration const &broker)
 | |
| {
 | |
|     std::visit(overloaded{
 | |
|         [&mqtt_client_cfg](Host const & host)
 | |
|         {
 | |
|             mqtt_client_cfg.broker.address.hostname = host.address.c_str();
 | |
|             mqtt_client_cfg.broker.address.path = host.path.c_str();
 | |
|             mqtt_client_cfg.broker.address.transport = host.transport;
 | |
|         },
 | |
|         [&mqtt_client_cfg](URI const & uri)
 | |
|         {
 | |
|             mqtt_client_cfg.broker.address.uri = uri.address.c_str();
 | |
|         },
 | |
|         []([[maybe_unused ]]auto & unknown)
 | |
|         {
 | |
|             static_assert(always_false<decltype(unknown)>, "Missing type handler for variant handler");
 | |
|         }
 | |
|     },
 | |
|     broker.address.address);
 | |
| 
 | |
|     std::visit(overloaded{
 | |
|         []([[maybe_unused]]Insecure const & insecure) {},
 | |
|         [&mqtt_client_cfg](GlobalCAStore const & use_global_store)
 | |
|         {
 | |
|             mqtt_client_cfg.broker.verification.use_global_ca_store = true;
 | |
|         },
 | |
|         [&mqtt_client_cfg](CryptographicInformation const & certificates)
 | |
|         {
 | |
|             std::visit(overloaded{
 | |
|                 [&mqtt_client_cfg](PEM const & pem)
 | |
|                 {
 | |
|                     mqtt_client_cfg.broker.verification.certificate = pem.data;
 | |
|                 }, [&mqtt_client_cfg](DER const & der)
 | |
|                 {
 | |
|                     mqtt_client_cfg.broker.verification.certificate = der.data;
 | |
|                     mqtt_client_cfg.broker.verification.certificate_len = der.len;
 | |
|                 }}, certificates);
 | |
|         },
 | |
|         []([[maybe_unused]] PSK const & psk) {},
 | |
|         []([[maybe_unused]] auto & unknown)
 | |
|         {
 | |
|             static_assert(always_false<decltype(unknown)>, "Missing type handler for variant handler");
 | |
|         }
 | |
|     },
 | |
|     broker.security);
 | |
|     mqtt_client_cfg.broker.address.port = broker.address.port;
 | |
| }
 | |
| 
 | |
| /*
 | |
|  *  This function is responsible for fill in the configurations for the client credentials related data
 | |
|  *  of mqtt_client_config_t
 | |
|  */
 | |
| void config_client_credentials(esp_mqtt_client_config_t &mqtt_client_cfg, ClientCredentials const &credentials)
 | |
| {
 | |
|     mqtt_client_cfg.credentials.client_id = credentials.client_id.has_value() ?  credentials.client_id.value().c_str() : nullptr ;
 | |
|     mqtt_client_cfg.credentials.username = credentials.username.has_value() ?  credentials.username.value().c_str() : nullptr ;
 | |
|     std::visit(overloaded{
 | |
|         [&mqtt_client_cfg](Password const & password)
 | |
|         {
 | |
|             mqtt_client_cfg.credentials.authentication.password = password.data.c_str();
 | |
|         },
 | |
|         [&mqtt_client_cfg](ClientCertificate const & certificate)
 | |
|         {
 | |
|             std::visit(overloaded{
 | |
|                 [&mqtt_client_cfg](PEM const & pem)
 | |
|                 {
 | |
|                     mqtt_client_cfg.credentials.authentication.certificate = pem.data;
 | |
|                 }, [&mqtt_client_cfg](DER const & der)
 | |
|                 {
 | |
|                     mqtt_client_cfg.credentials.authentication.certificate = der.data;
 | |
|                     mqtt_client_cfg.credentials.authentication.certificate_len = der.len;
 | |
|                 }}, certificate.certificate);
 | |
|             std::visit(overloaded{
 | |
|                 [&mqtt_client_cfg](PEM const & pem)
 | |
|                 {
 | |
|                     mqtt_client_cfg.credentials.authentication.key = pem.data;
 | |
|                 }, [&mqtt_client_cfg](DER const & der)
 | |
|                 {
 | |
|                     mqtt_client_cfg.credentials.authentication.key = der.data;
 | |
|                     mqtt_client_cfg.credentials.authentication.key_len = der.len;
 | |
|                 }}, certificate.key);
 | |
|             if (certificate.key_password.has_value()) {
 | |
|                 mqtt_client_cfg.credentials.authentication.key_password = certificate.key_password.value().data.c_str();
 | |
|                 mqtt_client_cfg.credentials.authentication.key_password_len = static_cast<int>(certificate.key_password.value().data.size());
 | |
|             }
 | |
|         },
 | |
|         [&mqtt_client_cfg](SecureElement const & enable_secure_element)
 | |
|         {
 | |
|             mqtt_client_cfg.credentials.authentication.use_secure_element = true;
 | |
|         },
 | |
|         []([[maybe_unused ]]auto & unknown)
 | |
|         {
 | |
|             static_assert(always_false<decltype(unknown)>, "Missing type handler for variant handler");
 | |
|         }
 | |
|     }, credentials.authentication);
 | |
| }
 | |
| 
 | |
| esp_mqtt_client_config_t make_config(BrokerConfiguration const &broker, ClientCredentials const  &credentials, Configuration const &config)
 | |
| {
 | |
|     esp_mqtt_client_config_t mqtt_client_cfg{};
 | |
|     config_broker(mqtt_client_cfg, broker);
 | |
|     config_client_credentials(mqtt_client_cfg, credentials);
 | |
|     mqtt_client_cfg.session.keepalive = config.session.keepalive;
 | |
|     mqtt_client_cfg.session.last_will.msg = config.session.last_will.lwt_msg;
 | |
|     mqtt_client_cfg.session.last_will.topic = config.session.last_will.lwt_topic;
 | |
|     mqtt_client_cfg.session.last_will.msg_len = config.session.last_will.lwt_msg_len;
 | |
|     mqtt_client_cfg.session.last_will.qos = config.session.last_will.lwt_qos;
 | |
|     mqtt_client_cfg.session.last_will.retain = config.session.last_will.lwt_retain;
 | |
|     mqtt_client_cfg.session.protocol_ver = config.session.protocol_ver;
 | |
|     mqtt_client_cfg.session.disable_keepalive = config.session.disable_keepalive;
 | |
|     mqtt_client_cfg.network.reconnect_timeout_ms = config.connection.reconnect_timeout_ms;
 | |
|     mqtt_client_cfg.network.timeout_ms = config.connection.network_timeout_ms;
 | |
|     mqtt_client_cfg.network.disable_auto_reconnect = config.connection.disable_auto_reconnect;
 | |
|     mqtt_client_cfg.network.refresh_connection_after_ms = config.connection.refresh_connection_after_ms;
 | |
|     mqtt_client_cfg.task.priority = config.task.task_prio;
 | |
|     mqtt_client_cfg.task.stack_size = config.task.task_stack;
 | |
|     mqtt_client_cfg.buffer.size = config.buffer_size;
 | |
|     mqtt_client_cfg.buffer.out_size = config.out_buffer_size;
 | |
|     return mqtt_client_cfg;
 | |
| }
 | |
| }
 | |
| 
 | |
| namespace idf::mqtt {
 | |
| 
 | |
| Client::Client(BrokerConfiguration const &broker, ClientCredentials const  &credentials, Configuration const &config): Client(make_config(broker, credentials, config))  {}
 | |
| 
 | |
| Client::Client(esp_mqtt_client_config_t const &config) :  handler(esp_mqtt_client_init(&config))
 | |
| {
 | |
|     if (handler == nullptr) {
 | |
|         throw MQTTException(ESP_FAIL);
 | |
|     };
 | |
|     CHECK_THROW_SPECIFIC(esp_mqtt_client_register_event(handler.get(), MQTT_EVENT_ANY, mqtt_event_handler, this), mqtt::MQTTException);
 | |
|     CHECK_THROW_SPECIFIC(esp_mqtt_client_start(handler.get()), mqtt::MQTTException);
 | |
| }
 | |
| 
 | |
| void Client::mqtt_event_handler(void *handler_args, esp_event_base_t base, int32_t event_id, void *event_data) noexcept
 | |
| {
 | |
|     ESP_LOGD(TAG, "Event dispatched from event loop base=%s, event_id=%" PRIu32, base, event_id);
 | |
|     auto *event = static_cast<esp_mqtt_event_t *>(event_data);
 | |
|     auto &client = *static_cast<Client *>(handler_args);
 | |
|     switch (event->event_id) {
 | |
|     case MQTT_EVENT_CONNECTED:
 | |
|         ESP_LOGI(TAG, "MQTT_EVENT_CONNECTED");
 | |
|         client.on_connected(event);
 | |
|         break;
 | |
|     case MQTT_EVENT_DISCONNECTED:
 | |
|         ESP_LOGI(TAG, "MQTT_EVENT_DISCONNECTED");
 | |
|         client.on_disconnected(event);
 | |
|         break;
 | |
| 
 | |
|     case MQTT_EVENT_SUBSCRIBED:
 | |
|         ESP_LOGI(TAG, "MQTT_EVENT_SUBSCRIBED, msg_id=%d", event->msg_id);
 | |
|         client.on_subscribed(event);
 | |
|         break;
 | |
|     case MQTT_EVENT_UNSUBSCRIBED:
 | |
|         ESP_LOGI(TAG, "MQTT_EVENT_UNSUBSCRIBED, msg_id=%d", event->msg_id);
 | |
|         client.on_unsubscribed(event);
 | |
|         break;
 | |
|     case MQTT_EVENT_PUBLISHED:
 | |
|         ESP_LOGI(TAG, "MQTT_EVENT_PUBLISHED, msg_id=%d", event->msg_id);
 | |
|         client.on_published(event);
 | |
|         break;
 | |
|     case MQTT_EVENT_DATA:
 | |
|         ESP_LOGI(TAG, "MQTT_EVENT_DATA");
 | |
|         client.on_data(event);
 | |
|         break;
 | |
|     case MQTT_EVENT_ERROR:
 | |
|         ESP_LOGI(TAG, "MQTT_EVENT_ERROR");
 | |
|         client.on_error(event);
 | |
|         break;
 | |
|     case MQTT_EVENT_BEFORE_CONNECT:
 | |
|         ESP_LOGI(TAG, "MQTT_EVENT_BEFORE_CONNECT");
 | |
|         client.on_before_connect(event);
 | |
|         break;
 | |
|     default:
 | |
|         ESP_LOGI(TAG, "Other event id:%d", event->event_id);
 | |
|         break;
 | |
|     }
 | |
| }
 | |
| 
 | |
| void Client::on_error(esp_mqtt_event_handle_t const event)
 | |
| {
 | |
|     auto log_error_if_nonzero = [](const char *message, int error_code) {
 | |
|         if (error_code != 0) {
 | |
|             ESP_LOGE(TAG, "Last error %s: 0x%x", message, error_code);
 | |
|         }
 | |
|     };
 | |
|     if (event->error_handle->error_type == MQTT_ERROR_TYPE_TCP_TRANSPORT) {
 | |
|         log_error_if_nonzero("reported from esp-tls", event->error_handle->esp_tls_last_esp_err);
 | |
|         log_error_if_nonzero("reported from tls stack", event->error_handle->esp_tls_stack_err);
 | |
|         log_error_if_nonzero("captured as transport's socket errno",  event->error_handle->esp_transport_sock_errno);
 | |
|         ESP_LOGI(TAG, "Last errno string (%s)", strerror(event->error_handle->esp_transport_sock_errno));
 | |
|     }
 | |
| }
 | |
| void Client::on_disconnected(esp_mqtt_event_handle_t const event)
 | |
| {
 | |
| }
 | |
| void Client::on_subscribed(esp_mqtt_event_handle_t const event)
 | |
| {
 | |
| }
 | |
| void Client::on_unsubscribed(esp_mqtt_event_handle_t const event)
 | |
| {
 | |
| }
 | |
| void Client::on_published(esp_mqtt_event_handle_t const event)
 | |
| {
 | |
| }
 | |
| void Client::on_before_connect(esp_mqtt_event_handle_t const event)
 | |
| {
 | |
| }
 | |
| void Client::on_connected(esp_mqtt_event_handle_t const event)
 | |
| {
 | |
| }
 | |
| void Client::on_data(esp_mqtt_event_handle_t const event)
 | |
| {
 | |
| }
 | |
| 
 | |
| std::optional<MessageID> Client::subscribe(std::string const &topic, QoS qos)
 | |
| {
 | |
|     auto res = esp_mqtt_client_subscribe(handler.get(), topic.c_str(),
 | |
|                                          static_cast<int>(qos));
 | |
|     if (res < 0) {
 | |
|         return std::nullopt;
 | |
|     }
 | |
|     return MessageID{res};
 | |
| }
 | |
| 
 | |
/*
 * Checks whether [first, last) is a syntactically valid MQTT topic filter.
 *
 * Rules enforced:
 *  - the filter is not empty;
 *  - every wildcard ('+' or '#') starts a level: it is the first character or
 *    directly follows '/';
 *  - every '+' ends its level: it is the last character or directly precedes '/';
 *  - '#' is the last character of the filter.
 *
 * Unlike the previous version, this checks *every* wildcard (not only the first
 * '+') and never dereferences std::prev(first) when a filter starts with a wildcard.
 */
bool is_valid(std::string::const_iterator first, std::string::const_iterator last)
{
    if (first == last) {
        return false;
    }
    for (auto it = first; it != last; ++it) {
        if (*it != '#' && *it != '+') {
            continue;
        }
        // Check the bound before looking at the previous character (avoids reading before `first`).
        if (it != first && *std::prev(it) != '/') {
            return false;
        }
        auto const next = std::next(it);
        if (*it == '#') {
            // Multi-level wildcard is only allowed as the final character.
            if (next != last) {
                return false;
            }
        } else if (next != last && *next != '/') {
            // Single-level wildcard must occupy its whole level.
            return false;
        }
    }
    return true;
}
 | |
| 
 | |
| Filter::Filter(std::string user_filter) : filter(std::move(user_filter))
 | |
| {
 | |
|     if (!is_valid(filter.begin(), filter.end())) {
 | |
|         throw std::domain_error("Forbidden Filter string");
 | |
|     }
 | |
| }
 | |
| 
 | |
| [[nodiscard]] bool Filter::match(std::string::const_iterator topic_begin, std::string::const_iterator topic_end) const noexcept
 | |
| {
 | |
|     auto filter_begin = filter.begin();
 | |
|     auto filter_end = filter.end();
 | |
|     for (auto mismatch = std::mismatch(filter_begin, filter_end, topic_begin);
 | |
|             mismatch.first != filter.end() and mismatch.second != topic_end;
 | |
|             mismatch = std::mismatch(filter_begin, filter_end, topic_begin)) {
 | |
|         if (*mismatch.first != '#' and * mismatch.first != '+') {
 | |
|             return false;
 | |
|         }
 | |
|         if (*mismatch.first == '#') {
 | |
|             return true;
 | |
|         }
 | |
|         if (*mismatch.first == '+') {
 | |
|             filter_begin = advance(mismatch.first, filter_end);
 | |
|             topic_begin = advance(mismatch.second, topic_end);
 | |
|             if (filter_begin == filter_end and topic_begin != topic_end) {
 | |
|                 return false;
 | |
|             }
 | |
|         }
 | |
|     }
 | |
|     return true;
 | |
| }
 | |
| const std::string &Filter::get()
 | |
| {
 | |
|     return filter;
 | |
| }
 | |
| 
 | |
| [[nodiscard]] bool Filter::match(char *const begin, int size) const noexcept
 | |
| {
 | |
|     auto it = static_cast<std::string::const_iterator>(begin);
 | |
|     return match(it, it + size);
 | |
| }
 | |
| std::string::const_iterator Filter::advance(std::string::const_iterator begin, std::string::const_iterator end) const
 | |
| {
 | |
|     constexpr auto separator = '/';
 | |
|     return std::find(begin, end, separator);
 | |
| }
 | |
| 
 | |
| }
 |