From e0a0bafbf1782920a7f30edcc71fe3bc74849e58 Mon Sep 17 00:00:00 2001 From: Bruno Iljazovic Date: Thu, 7 Dec 2023 12:18:28 +0100 Subject: [PATCH] thread-safe client Reviewers: ivica Reviewed By: ivica Subscribers: korina Differential Revision: https://repo.mireo.local/D26864 --- include/async_mqtt5/detail/control_packet.hpp | 4 ---- include/async_mqtt5/detail/internal_types.hpp | 2 ++ include/async_mqtt5/impl/client_service.hpp | 23 +++++++++++++++++++ include/async_mqtt5/impl/connect_op.hpp | 6 ++++- include/async_mqtt5/impl/disconnect_op.hpp | 4 ++-- include/async_mqtt5/impl/publish_send_op.hpp | 21 ++++++++++++++--- include/async_mqtt5/impl/subscribe_op.hpp | 9 ++++++++ include/async_mqtt5/impl/unsubscribe_op.hpp | 9 ++++++++ 8 files changed, 68 insertions(+), 10 deletions(-) diff --git a/include/async_mqtt5/detail/control_packet.hpp b/include/async_mqtt5/detail/control_packet.hpp index b0c0440..fb00ea2 100644 --- a/include/async_mqtt5/detail/control_packet.hpp +++ b/include/async_mqtt5/detail/control_packet.hpp @@ -1,7 +1,6 @@ #ifndef ASYNC_MQTT5_CONTROL_PACKET_HPP #define ASYNC_MQTT5_CONTROL_PACKET_HPP -#include #include #include @@ -115,7 +114,6 @@ class packet_id_allocator { {} }; - std::mutex _mtx; std::vector _free_ids; static constexpr uint16_t MAX_PACKET_ID = 65535; @@ -125,7 +123,6 @@ public: } uint16_t allocate() { - std::lock_guard _(_mtx); if (_free_ids.empty()) return 0; auto& last = _free_ids.back(); if (last.start == ++last.end) { @@ -137,7 +134,6 @@ public: } void free(uint16_t pid) { - std::lock_guard _(_mtx); auto it = std::upper_bound( _free_ids.begin(), _free_ids.end(), pid, [](const uint16_t x, const interval& i) { return x > i.start; } diff --git a/include/async_mqtt5/detail/internal_types.hpp b/include/async_mqtt5/detail/internal_types.hpp index da2c903..ff351c2 100644 --- a/include/async_mqtt5/detail/internal_types.hpp +++ b/include/async_mqtt5/detail/internal_types.hpp @@ -3,6 +3,7 @@ #include #include +#include #include @@ -70,6 +71,7 @@ struct mqtt_ctx { credentials creds; std::optional will_msg; connect_props co_props; + std::shared_mutex ca_mtx; connack_props ca_props; session_state state; any_authenticator authenticator; diff --git a/include/async_mqtt5/impl/client_service.hpp b/include/async_mqtt5/impl/client_service.hpp index 349becc..6babd20 100644 --- a/include/async_mqtt5/impl/client_service.hpp +++ b/include/async_mqtt5/impl/client_service.hpp @@ -61,9 +61,18 @@ public: template decltype(auto) connack_prop(Prop p) { + std::shared_lock reader_lock(_mqtt_context.ca_mtx); return _mqtt_context.ca_props[p]; } + template + decltype(auto) connack_props(Prop0 p0, Props ...props) { + std::shared_lock reader_lock(_mqtt_context.ca_mtx); + return std::make_tuple( + _mqtt_context.ca_props[p0], _mqtt_context.ca_props[props]... + ); + } + void credentials( std::string client_id, std::string username = "", std::string password = "" @@ -109,9 +118,18 @@ public: template decltype(auto) connack_prop(Prop p) { + std::shared_lock reader_lock(_mqtt_context.ca_mtx); return _mqtt_context.ca_props[p]; } + template + decltype(auto) connack_props(Prop0 p0, Props ...props) { + std::shared_lock reader_lock(_mqtt_context.ca_mtx); + return std::make_tuple( + _mqtt_context.ca_props[p0], _mqtt_context.ca_props[props]... + ); + } + void credentials( std::string client_id, std::string username = "", std::string password = "" @@ -243,6 +261,11 @@ public: return _stream_context.connack_prop(p); } + template + decltype(auto) connack_props(Prop0 p0, Props ...props) { + return _stream_context.connack_props(p0, props...); + } + void run() { _stream.open(); _rec_channel.reset(); diff --git a/include/async_mqtt5/impl/connect_op.hpp b/include/async_mqtt5/impl/connect_op.hpp index ae8c81e..b79c5cf 100644 --- a/include/async_mqtt5/impl/connect_op.hpp +++ b/include/async_mqtt5/impl/connect_op.hpp @@ -283,7 +283,11 @@ public: return complete(client::error::malformed_packet); const auto& [session_present, reason_code, ca_props] = *rv; - _ctx.ca_props = ca_props; + { + std::unique_lock writer_lock(_ctx.ca_mtx); + _ctx.ca_props = ca_props; + } + _ctx.state.session_present(session_present); // Unexpected result handling: diff --git a/include/async_mqtt5/impl/disconnect_op.hpp b/include/async_mqtt5/impl/disconnect_op.hpp index 9acba15..183bd74 100644 --- a/include/async_mqtt5/impl/disconnect_op.hpp +++ b/include/async_mqtt5/impl/disconnect_op.hpp @@ -66,10 +66,10 @@ public: static_cast(_context.reason_code), _context.props ); - send_disconnect(std::move(disconnect)); + asio::dispatch(asio::prepend(std::move(*this), std::move(disconnect))); } - void send_disconnect(control_packet disconnect) { + void operator()(control_packet disconnect) { const auto& wire_data = disconnect.wire_data(); _svc_ptr->async_send( diff --git a/include/async_mqtt5/impl/publish_send_op.hpp b/include/async_mqtt5/impl/publish_send_op.hpp index 775161b..f83f35b 100644 --- a/include/async_mqtt5/impl/publish_send_op.hpp +++ b/include/async_mqtt5/impl/publish_send_op.hpp @@ -103,6 +103,18 @@ public: if (ec) return complete_post(ec); + asio::dispatch( + asio::prepend( + std::move(*this), std::move(topic), + std::move(payload), retain, props + ) + ); + } + + void operator()( + std::string topic, std::string payload, + retain_e retain, const publish_props& props + ) { uint16_t packet_id = 0; if constexpr (qos_type != qos_e::at_most_once) { packet_id = _svc_ptr->allocate_pid(); @@ -334,15 +346,18 @@ private: if (!is_valid_utf8_topic(topic)) return client::error::invalid_topic; - auto max_qos = _svc_ptr->connack_prop(prop::maximum_qos); + const auto& [max_qos, retain_avail, topic_alias_max] = + _svc_ptr->connack_props( + prop::maximum_qos, prop::retain_available, + prop::topic_alias_maximum + ); + if (max_qos && uint8_t(qos_type) > *max_qos) return client::error::qos_not_supported; - auto retain_avail = _svc_ptr->connack_prop(prop::retain_available); if (retain_avail && *retain_avail == 0 && retain == retain_e::yes) return client::error::retain_not_available; - auto topic_alias_max = _svc_ptr->connack_prop(prop::topic_alias_maximum); auto topic_alias = props[prop::topic_alias]; if ( (!topic_alias_max || topic_alias_max && *topic_alias_max == 0) && diff --git a/include/async_mqtt5/impl/subscribe_op.hpp b/include/async_mqtt5/impl/subscribe_op.hpp index 70858c4..62f8ba8 100644 --- a/include/async_mqtt5/impl/subscribe_op.hpp +++ b/include/async_mqtt5/impl/subscribe_op.hpp @@ -66,6 +66,15 @@ public: if (ec) return complete_post(ec, topics.size()); + asio::dispatch( + asio::prepend(std::move(*this), topics, props) + ); + } + + void operator()( + const std::vector& topics, + const subscribe_props& props + ) { uint16_t packet_id = _svc_ptr->allocate_pid(); if (packet_id == 0) return complete_post(client::error::pid_overrun, topics.size()); diff --git a/include/async_mqtt5/impl/unsubscribe_op.hpp b/include/async_mqtt5/impl/unsubscribe_op.hpp index 5c6ccff..8b9aa91 100644 --- a/include/async_mqtt5/impl/unsubscribe_op.hpp +++ b/include/async_mqtt5/impl/unsubscribe_op.hpp @@ -63,6 +63,15 @@ public: if (ec) return complete_post(ec, topics.size()); + asio::dispatch( + asio::prepend(std::move(*this), topics, props) + ); + } + + void operator()( + const std::vector& topics, + const unsubscribe_props& props + ) { uint16_t packet_id = _svc_ptr->allocate_pid(); if (packet_id == 0) return complete_post(client::error::pid_overrun, topics.size());