From d78fdd32084aba798047e38befcc54d2b2e5c8c6 Mon Sep 17 00:00:00 2001 From: Bruno Iljazovic Date: Tue, 9 Jan 2024 15:18:58 +0100 Subject: [PATCH] use associated executors for intermediate handlers Summary: * per-operation cancellation changed: total/partial signals only prevent further resending, terminal signal cancels the whole client Reviewers: ivica Reviewed By: ivica Subscribers: korina Differential Revision: https://repo.mireo.local/D27246 --- .../detail/cancellable_handler.hpp | 147 ++++-------------- include/async_mqtt5/impl/assemble_op.hpp | 8 +- include/async_mqtt5/impl/async_sender.hpp | 30 +++- include/async_mqtt5/impl/client_service.hpp | 2 + include/async_mqtt5/impl/connect_op.hpp | 9 +- include/async_mqtt5/impl/disconnect_op.hpp | 6 +- include/async_mqtt5/impl/endpoints.hpp | 14 +- include/async_mqtt5/impl/publish_send_op.hpp | 109 ++++++------- include/async_mqtt5/impl/re_auth_op.hpp | 2 +- include/async_mqtt5/impl/read_op.hpp | 20 +-- include/async_mqtt5/impl/reconnect_op.hpp | 9 +- include/async_mqtt5/impl/replies.hpp | 73 +++++---- include/async_mqtt5/impl/sentry_op.hpp | 2 +- include/async_mqtt5/impl/subscribe_op.hpp | 49 +++--- include/async_mqtt5/impl/unsubscribe_op.hpp | 49 +++--- include/async_mqtt5/impl/write_op.hpp | 18 ++- include/async_mqtt5/types.hpp | 2 +- .../unit/include/test_common/test_service.hpp | 3 +- test/unit/test/cancellation.cpp | 143 ++++++++--------- test/unit/test/publish_send_op.cpp | 29 ---- test/unit/test/session.cpp | 1 + 21 files changed, 311 insertions(+), 414 deletions(-) diff --git a/include/async_mqtt5/detail/cancellable_handler.hpp b/include/async_mqtt5/detail/cancellable_handler.hpp index c854ffd..5f82507 100644 --- a/include/async_mqtt5/detail/cancellable_handler.hpp +++ b/include/async_mqtt5/detail/cancellable_handler.hpp @@ -1,13 +1,10 @@ #ifndef ASYNC_MQTT5_CANCELLABLE_HANDLER_HPP #define ASYNC_MQTT5_CANCELLABLE_HANDLER_HPP -#include - #include +#include #include -#include -#include -#include +#include #include #include @@ -15,141 +12,59 @@ namespace async_mqtt5::detail { -template < - typename Handler, typename Executor, - typename CancelArgs = std::tuple<> -> +template class cancellable_handler { - struct op_state { - Handler _handler; - tracking_type _handler_ex; - cancellable_handler* _owner; - - op_state( - Handler&& handler, const Executor& ex, - cancellable_handler* owner - ) : - _handler(std::move(handler)), - _handler_ex(tracking_executor(_handler, ex)), - _owner(owner) - {} - - void cancel_op() { - _owner->cancel(); - } - }; - - struct cancel_proxy { - std::weak_ptr _state_weak_ptr; - Executor _executor; - - cancel_proxy(std::shared_ptr state, const Executor& ex) : - _state_weak_ptr(std::move(state)), _executor(ex) - {} - - void operator()(asio::cancellation_type_t type) { - if ( - (type & asio::cancellation_type_t::terminal) == - asio::cancellation_type_t::none - ) - return; - - auto op = [wptr = _state_weak_ptr]() { - if (auto state = wptr.lock()) - state->cancel_op(); - }; - - asio::dispatch(_executor, std::move(op)); - } - }; - - std::shared_ptr _state; Executor _executor; + Handler _handler; + tracking_type _handler_ex; + asio::cancellation_state _cancellation_state; public: - cancellable_handler(Handler&& handler, const Executor& ex) { - auto alloc = asio::get_associated_allocator(handler); - _state = std::allocate_shared( - alloc, std::move(handler), ex, this - ); - - auto slot = asio::get_associated_cancellation_slot(_state->_handler); - if (slot.is_connected()) - slot.template emplace(_state, ex); - - _executor = ex; - } - - cancellable_handler(cancellable_handler&& other) noexcept : - _state(std::exchange(other._state, nullptr)), - _executor(std::move(other._executor)) - { - if (!empty()) - _state->_owner = this; - } + cancellable_handler(Handler&& handler, const Executor& ex) : + _executor(ex), + _handler(std::move(handler)), + _handler_ex(tracking_executor(_handler, ex)), + _cancellation_state( + asio::get_associated_cancellation_slot(_handler), + asio::enable_total_cancellation {}, + asio::enable_terminal_cancellation {} + ) + {} + cancellable_handler(cancellable_handler&& other) noexcept = default; cancellable_handler(const cancellable_handler&) = delete; - ~cancellable_handler() { - cancel(); - } - - bool empty() const noexcept { - return _state == nullptr; + using executor_type = tracking_type; + executor_type get_executor() const noexcept { + return _handler_ex; } using allocator_type = asio::associated_allocator_t; allocator_type get_allocator() const noexcept { - return asio::get_associated_allocator(_state->_handler); + return asio::get_associated_allocator(_handler); } - void cancel() { - if (empty()) return; + using cancellation_slot_type = asio::associated_cancellation_slot_t; + cancellation_slot_type get_cancellation_slot() const noexcept { + return _cancellation_state.slot(); + } - auto h = std::move(_state->_handler); - asio::get_associated_cancellation_slot(h).clear(); - auto handler_ex = std::move(_state->_handler_ex); - _state.reset(); - - auto op = std::apply( - [&h](auto... args) { - return asio::prepend( - std::move(h), asio::error::operation_aborted, args... - ); - }, - CancelArgs {} - ); - - asio::dispatch(handler_ex, std::move(op)); + asio::cancellation_type_t cancelled() const { + return _cancellation_state.cancelled(); } template void complete(Args&&... args) { - if (empty()) return; - - auto h = std::move(_state->_handler); - asio::get_associated_cancellation_slot(h).clear(); - auto handler_ex = std::move(_state->_handler_ex); - _state.reset(); - - asio::dispatch( - handler_ex, - asio::prepend(std::move(h), std::forward(args)...) - ); + asio::get_associated_cancellation_slot(_handler).clear(); + std::move(_handler)(std::forward(args)...); } template void complete_post(Args&&... args) { - if (empty()) return; - - auto h = std::move(_state->_handler); - asio::get_associated_cancellation_slot(h).clear(); - auto handler_ex = std::move(_state->_handler_ex); - _state.reset(); - + asio::get_associated_cancellation_slot(_handler).clear(); asio::post( _executor, - asio::prepend(std::move(h), std::forward(args)...) + asio::prepend(std::move(_handler), std::forward(args)...) ); } diff --git a/include/async_mqtt5/impl/assemble_op.hpp b/include/async_mqtt5/impl/assemble_op.hpp index 5c96930..d5ca02a 100644 --- a/include/async_mqtt5/impl/assemble_op.hpp +++ b/include/async_mqtt5/impl/assemble_op.hpp @@ -97,9 +97,6 @@ public: }; if (cc(error_code {}, 0) == 0 && _data_span.size()) { - /* TODO clear read buffer on reconnect - * OR use dispatch instead of post here - */ return asio::post( asio::prepend( std::move(*this), on_read {}, error_code {}, @@ -216,10 +213,7 @@ private: error_code ec, uint8_t control_code, byte_citer first, byte_citer last ) { - asio::dispatch( - get_executor(), - asio::prepend(std::move(_handler), ec, control_code, first, last) - ); + std::move(_handler)(ec, control_code, first, last); } }; diff --git a/include/async_mqtt5/impl/async_sender.hpp b/include/async_mqtt5/impl/async_sender.hpp index 3ffe913..33e1d95 100644 --- a/include/async_mqtt5/impl/async_sender.hpp +++ b/include/async_mqtt5/impl/async_sender.hpp @@ -2,6 +2,8 @@ #define ASYNC_MQTT5_ASYNC_SENDER_HPP #include +#include +#include #include #include #include @@ -44,6 +46,14 @@ public: std::move(_handler)(ec); } + auto get_executor() { + return asio::get_associated_executor(_handler); + } + + auto get_allocator() { + return asio::get_associated_allocator(_handler); + } + bool throttled() const { return _flags & send_flag::throttled; } @@ -94,11 +104,6 @@ class async_sender { public: explicit async_sender(ClientService& svc) : _svc(svc) {} - using executor_type = typename client_service::executor_type; - executor_type get_executor() const noexcept { - return _svc.get_executor(); - } - using allocator_type = queue_allocator_type; allocator_type get_allocator() const noexcept { return allocator_type {}; @@ -223,7 +228,9 @@ private: _write_queue.begin(), _write_queue.end(), [](const auto& op) { return !op.throttled(); } ); - uint16_t dist = static_cast(std::distance(throttled_ptr, _write_queue.end())); + uint16_t dist = static_cast( + std::distance(throttled_ptr, _write_queue.end()) + ); uint16_t throttled_num = std::min(dist, _quota); _quota -= throttled_num; throttled_ptr += throttled_num; @@ -249,8 +256,17 @@ private: _svc._replies.clear_fast_replies(); + auto ex = write_queue.front().get_executor(); + auto alloc = write_queue.front().get_allocator(); _svc._stream.async_write( - buffers, asio::prepend(std::ref(*this), std::move(write_queue)) + buffers, + asio::bind_executor( + ex, + asio::bind_allocator( + alloc, + asio::prepend(std::ref(*this), std::move(write_queue)) + ) + ) ); } diff --git a/include/async_mqtt5/impl/client_service.hpp b/include/async_mqtt5/impl/client_service.hpp index df3f811..c35043b 100644 --- a/include/async_mqtt5/impl/client_service.hpp +++ b/include/async_mqtt5/impl/client_service.hpp @@ -3,6 +3,7 @@ #include +#include #include #include @@ -227,6 +228,7 @@ public: ) : _stream_context(std::move(tls_context)), _stream(ex, _stream_context), + _replies(ex), _async_sender(*this), _active_span(_read_buff.cend(), _read_buff.cend()), _rec_channel(ex, std::numeric_limits::max()) diff --git a/include/async_mqtt5/impl/connect_op.hpp b/include/async_mqtt5/impl/connect_op.hpp index 9d60c2d..2eacaaf 100644 --- a/include/async_mqtt5/impl/connect_op.hpp +++ b/include/async_mqtt5/impl/connect_op.hpp @@ -63,9 +63,9 @@ public: connect_op(connect_op&&) noexcept = default; connect_op(const connect_op&) = delete; - using executor_type = typename Stream::executor_type; + using executor_type = asio::associated_executor_t; executor_type get_executor() const noexcept { - return _stream.get_executor(); + return asio::get_associated_executor(_handler); } using allocator_type = asio::associated_allocator_t; @@ -382,10 +382,7 @@ private: void complete(error_code ec) { get_cancellation_slot().clear(); - asio::dispatch( - get_executor(), - asio::prepend(std::move(_handler), ec) - ); + std::move(_handler)(ec); } static error_code to_asio_error(reason_code rc) { diff --git a/include/async_mqtt5/impl/disconnect_op.hpp b/include/async_mqtt5/impl/disconnect_op.hpp index f15d673..cfc4a32 100644 --- a/include/async_mqtt5/impl/disconnect_op.hpp +++ b/include/async_mqtt5/impl/disconnect_op.hpp @@ -44,15 +44,15 @@ public: ) : _svc_ptr(svc_ptr), _context(std::move(context)), - _handler(std::move(handler), get_executor()) + _handler(std::move(handler), _svc_ptr->get_executor()) {} disconnect_op(disconnect_op&&) noexcept = default; disconnect_op(const disconnect_op&) = delete; - using executor_type = typename client_service::executor_type; + using executor_type = asio::associated_executor_t; executor_type get_executor() const noexcept { - return _svc_ptr->get_executor(); + return asio::get_associated_executor(_handler); } using allocator_type = asio::associated_allocator_t; diff --git a/include/async_mqtt5/impl/endpoints.hpp b/include/async_mqtt5/impl/endpoints.hpp index eda480d..abea53c 100644 --- a/include/async_mqtt5/impl/endpoints.hpp +++ b/include/async_mqtt5/impl/endpoints.hpp @@ -36,9 +36,9 @@ public: resolve_op(resolve_op&&) noexcept = default; resolve_op(const resolve_op&) = delete; - using executor_type = typename Owner::executor_type; + using executor_type = asio::associated_executor_t; executor_type get_executor() const noexcept { - return _owner.get_executor(); + return asio::get_associated_executor(_handler); } using allocator_type = asio::associated_allocator_t; @@ -104,20 +104,14 @@ private: void complete(error_code ec, epoints eps, authority_path ap) { get_cancellation_slot().clear(); - asio::dispatch( - get_executor(), - asio::prepend( - std::move(_handler), ec, - std::move(eps), std::move(ap) - ) - ); + std::move(_handler)(ec, std::move(eps), std::move(ap)); } void complete_post(error_code ec, epoints eps, authority_path ap) { get_cancellation_slot().clear(); asio::post( - get_executor(), + _owner.get_executor(), asio::prepend( std::move(_handler), ec, std::move(eps), std::move(ap) diff --git a/include/async_mqtt5/impl/publish_send_op.hpp b/include/async_mqtt5/impl/publish_send_op.hpp index bb526bf..e176f32 100644 --- a/include/async_mqtt5/impl/publish_send_op.hpp +++ b/include/async_mqtt5/impl/publish_send_op.hpp @@ -43,17 +43,6 @@ using on_publish_props_type = std::conditional_t< > >; -template -using cancel_args = std::conditional_t< - qos_type == qos_e::at_most_once, - std::tuple<>, - std::conditional_t< - qos_type == qos_e::at_least_once, - std::tuple, - std::tuple - > ->; - template class publish_send_op { using client_service = ClientService; @@ -66,11 +55,11 @@ class publish_send_op { std::shared_ptr _svc_ptr; - cancellable_handler< + using handler_type = cancellable_handler< Handler, - typename client_service::executor_type, - cancel_args - > _handler; + typename client_service::executor_type + >; + handler_type _handler; serial_num_t _serial_num; @@ -80,18 +69,24 @@ public: Handler&& handler ) : _svc_ptr(svc_ptr), - _handler(std::move(handler), get_executor()) - {} + _handler(std::move(handler), _svc_ptr->get_executor()) + { + auto slot = asio::get_associated_cancellation_slot(_handler); + if (slot.is_connected()) + slot.assign([&svc = *_svc_ptr](asio::cancellation_type_t) { + svc.cancel(); + }); + } publish_send_op(publish_send_op&&) noexcept = default; publish_send_op(const publish_send_op&) = delete; - using executor_type = typename client_service::executor_type; + using executor_type = asio::associated_executor_t; executor_type get_executor() const noexcept { - return _svc_ptr->get_executor(); + return asio::get_associated_executor(_handler); } - using allocator_type = asio::associated_allocator_t; + using allocator_type = asio::associated_allocator_t; allocator_type get_allocator() const noexcept { return asio::get_associated_allocator(_handler); } @@ -130,14 +125,7 @@ public: send_publish(std::move(publish)); } - void send_publish(control_packet publish) { - if (_handler.empty()) { // already cancelled - if constexpr (qos_type != qos_e::at_most_once) - _svc_ptr->free_pid(publish.packet_id()); - return; - } - auto wire_data = publish.wire_data(); _svc_ptr->async_send( wire_data, @@ -147,12 +135,20 @@ public: ); } + void resend_publish(control_packet publish) { + if (_handler.cancelled() != asio::cancellation_type_t::none) + return complete( + asio::error::operation_aborted, publish.packet_id() + ); + send_publish(std::move(publish)); + } + void operator()( on_publish, control_packet publish, error_code ec ) { if (ec == asio::error::try_again) - return send_publish(std::move(publish)); + return resend_publish(std::move(publish)); if constexpr (qos_type == qos_e::at_most_once) return complete(ec); @@ -160,31 +156,24 @@ public: else { auto packet_id = publish.packet_id(); - if constexpr (qos_type == qos_e::at_least_once) { - if (ec) - return complete( - ec, reason_codes::empty, packet_id, puback_props {} - ); + if (ec) + return complete(ec, packet_id); + + if constexpr (qos_type == qos_e::at_least_once) _svc_ptr->async_wait_reply( control_code_e::puback, packet_id, asio::prepend( std::move(*this), on_puback {}, std::move(publish) ) ); - } - else if constexpr (qos_type == qos_e::exactly_once) { - if (ec) - return complete( - ec, reason_codes::empty, packet_id, pubcomp_props {} - ); + else if constexpr (qos_type == qos_e::exactly_once) _svc_ptr->async_wait_reply( control_code_e::pubrec, packet_id, asio::prepend( std::move(*this), on_pubrec {}, std::move(publish) ) ); - } } } @@ -198,31 +187,29 @@ public: error_code ec, byte_citer first, byte_citer last ) { if (ec == asio::error::try_again) // "resend unanswered" - return send_publish(std::move(publish.set_dup())); + return resend_publish(std::move(publish.set_dup())); uint16_t packet_id = publish.packet_id(); if (ec) - return complete( - ec, reason_codes::empty, packet_id, puback_props {} - ); + return complete(ec, packet_id); auto puback = decoders::decode_puback( static_cast(std::distance(first, last)), first ); if (!puback.has_value()) { on_malformed_packet("Malformed PUBACK: cannot decode"); - return send_publish(std::move(publish.set_dup())); + return resend_publish(std::move(publish.set_dup())); } auto& [reason_code, props] = *puback; auto rc = to_reason_code(reason_code); if (!rc) { on_malformed_packet("Malformed PUBACK: invalid Reason Code"); - return send_publish(std::move(publish.set_dup())); + return resend_publish(std::move(publish.set_dup())); } - complete(ec, *rc, packet_id, std::move(props)); + complete(ec, packet_id, *rc, std::move(props)); } template < @@ -234,21 +221,19 @@ public: error_code ec, byte_citer first, byte_citer last ) { if (ec == asio::error::try_again) // "resend unanswered" - return send_publish(std::move(publish.set_dup())); + return resend_publish(std::move(publish.set_dup())); uint16_t packet_id = publish.packet_id(); if (ec) - return complete( - ec, reason_codes::empty, packet_id, pubcomp_props {} - ); + return complete(ec, packet_id); auto pubrec = decoders::decode_pubrec( static_cast(std::distance(first, last)), first ); if (!pubrec.has_value()) { on_malformed_packet("Malformed PUBREC: cannot decode"); - return send_publish(std::move(publish.set_dup())); + return resend_publish(std::move(publish.set_dup())); } auto& [reason_code, props] = *pubrec; @@ -256,11 +241,11 @@ public: auto rc = to_reason_code(reason_code); if (!rc) { on_malformed_packet("Malformed PUBREC: invalid Reason Code"); - return send_publish(std::move(publish.set_dup())); + return resend_publish(std::move(publish.set_dup())); } if (*rc) - return complete(ec, *rc, packet_id, pubcomp_props {}); + return complete(ec, packet_id, *rc); auto pubrel = control_packet::of( with_pid, get_allocator(), @@ -294,9 +279,7 @@ public: uint16_t packet_id = pubrel.packet_id(); if (ec) - return complete( - ec, reason_codes::empty, packet_id, pubcomp_props {} - ); + return complete(ec, packet_id); _svc_ptr->async_wait_reply( control_code_e::pubcomp, packet_id, @@ -319,9 +302,7 @@ public: uint16_t packet_id = pubrel.packet_id(); if (ec) - return complete( - ec, reason_codes::empty, packet_id, pubcomp_props {} - ); + return complete(ec, packet_id); auto pubcomp = decoders::decode_pubcomp( static_cast(std::distance(first, last)), first @@ -339,7 +320,7 @@ public: return send_pubrel(std::move(pubrel), true); } - return complete(ec, *rc, pubrel.packet_id(), pubcomp_props {}); + return complete(ec, pubrel.packet_id(), *rc); } private: @@ -434,7 +415,7 @@ private: qos_e q = qos_type, std::enable_if_t = true > - void complete(error_code ec) { + void complete(error_code ec, uint16_t = 0) { _handler.complete(ec); } @@ -455,8 +436,8 @@ private: > = true > void complete( - error_code ec, reason_code rc, - uint16_t packet_id, Props&& props + error_code ec, uint16_t packet_id, + reason_code rc = reason_codes::empty, Props&& props = Props {} ) { _svc_ptr->free_pid(packet_id, true); _handler.complete(ec, rc, std::forward(props)); diff --git a/include/async_mqtt5/impl/re_auth_op.hpp b/include/async_mqtt5/impl/re_auth_op.hpp index 5ea1a8e..c2c9246 100644 --- a/include/async_mqtt5/impl/re_auth_op.hpp +++ b/include/async_mqtt5/impl/re_auth_op.hpp @@ -119,7 +119,7 @@ public: private: void on_auth_fail(std::string message, disconnect_rc_e reason) { - auto props = disconnect_props{}; + auto props = disconnect_props {}; props[prop::reason_string] = std::move(message); async_disconnect(reason, props, false, _svc_ptr, asio::detached); diff --git a/include/async_mqtt5/impl/read_op.hpp b/include/async_mqtt5/impl/read_op.hpp index 6ee4594..741a175 100644 --- a/include/async_mqtt5/impl/read_op.hpp +++ b/include/async_mqtt5/impl/read_op.hpp @@ -2,6 +2,7 @@ #define ASYNC_MQTT5_READ_OP_HPP #include +#include #include #include @@ -32,9 +33,9 @@ public: read_op(read_op&&) noexcept = default; read_op(const read_op&) = delete; - using executor_type = typename Owner::executor_type; + using executor_type = asio::associated_executor_t; executor_type get_executor() const noexcept { - return _owner.get_executor(); + return asio::get_associated_executor(_handler); } using allocator_type = asio::associated_allocator_t; @@ -62,9 +63,13 @@ public: ); } else - (*this)( - on_read {}, stream_ptr, - { 0, 1 }, asio::error::not_connected, 0, {} + asio::post( + _owner.get_executor(), + asio::prepend( + std::move(*this), on_read {}, stream_ptr, + std::array { 0, 1 }, + asio::error::not_connected, 0, error_code {} + ) ); } @@ -100,10 +105,7 @@ public: private: void complete(error_code ec, size_t bytes_read) { - asio::dispatch( - get_executor(), - asio::prepend(std::move(_handler), ec, bytes_read) - ); + std::move(_handler)(ec, bytes_read); } static bool should_reconnect(error_code ec) { diff --git a/include/async_mqtt5/impl/reconnect_op.hpp b/include/async_mqtt5/impl/reconnect_op.hpp index aa3a3b2..df85211 100644 --- a/include/async_mqtt5/impl/reconnect_op.hpp +++ b/include/async_mqtt5/impl/reconnect_op.hpp @@ -44,9 +44,9 @@ public: reconnect_op(reconnect_op&&) noexcept = default; reconnect_op(const reconnect_op&) = delete; - using executor_type = typename Owner::executor_type; + using executor_type = asio::associated_executor_t; executor_type get_executor() const noexcept { - return _owner.get_executor(); + return asio::get_associated_executor(_handler); } using allocator_type = asio::associated_allocator_t; @@ -183,10 +183,7 @@ private: get_cancellation_slot().clear(); _owner._conn_mtx.unlock(); - asio::dispatch( - get_executor(), - asio::prepend(std::move(_handler), ec) - ); + std::move(_handler)(ec); } }; diff --git a/include/async_mqtt5/impl/replies.hpp b/include/async_mqtt5/impl/replies.hpp index 7e44cec..163ef6b 100644 --- a/include/async_mqtt5/impl/replies.hpp +++ b/include/async_mqtt5/impl/replies.hpp @@ -2,10 +2,12 @@ #define ASYNC_MQTT5_REPLIES_HPP #include +#include #include -#include +#include #include #include +#include #include #include @@ -15,33 +17,39 @@ namespace async_mqtt5::detail { namespace asio = boost::asio; class replies { +public: + using executor_type = asio::any_io_executor; +private: using Signature = void (error_code, byte_citer, byte_citer); static constexpr auto max_reply_time = std::chrono::seconds(20); - class handler_type : public asio::any_completion_handler { - using base = asio::any_completion_handler; + class reply_handler { + asio::any_completion_handler _handler; control_code_e _code; uint16_t _packet_id; std::chrono::time_point _ts; public: template - handler_type(control_code_e code, uint16_t pid, H&& handler) : - base(std::forward(handler)), _code(code), _packet_id(pid), + reply_handler(control_code_e code, uint16_t pid, H&& handler) : + _handler(std::forward(handler)), _code(code), _packet_id(pid), _ts(std::chrono::system_clock::now()) {} - handler_type(handler_type&& other) noexcept : - base(static_cast(other)), - _code(other._code), _packet_id(other._packet_id), _ts(other._ts) - {} + void complete( + error_code ec, + byte_citer first = byte_citer {}, byte_citer last = byte_citer {} + ) { + asio::dispatch(asio::prepend(std::move(_handler), ec, first, last)); + } - handler_type& operator=(handler_type&& other) noexcept { - base::operator=(static_cast(other)); - _code = other._code; - _packet_id = other._packet_id; - _ts = other._ts; - return *this; + void complete_post(const executor_type& ex, error_code ec) { + asio::post( + ex, + asio::prepend( + std::move(_handler), ec, byte_citer {}, byte_citer {} + ) + ); } uint16_t packet_id() const noexcept { @@ -57,7 +65,9 @@ class replies { } }; - using handlers = std::vector; + executor_type _ex; + + using handlers = std::vector; handlers _handlers; struct fast_reply { @@ -69,15 +79,16 @@ class replies { fast_replies _fast_replies; public: + template + replies(const Executor& ex) : _ex(ex) {} + template decltype(auto) async_wait_reply( control_code_e code, uint16_t packet_id, CompletionToken&& token ) { auto dup_handler_ptr = find_handler(code, packet_id); if (dup_handler_ptr != _handlers.end()) { - std::move(*dup_handler_ptr)( - asio::error::operation_aborted, byte_citer {}, byte_citer {} - ); + dup_handler_ptr->complete_post(_ex, asio::error::operation_aborted); _handlers.erase(dup_handler_ptr); } @@ -101,23 +112,25 @@ public: _fast_replies.erase(freply); auto initiation = []( - auto handler, std::unique_ptr packet + auto handler, std::unique_ptr packet, + const executor_type& ex ) { - auto ex = asio::get_associated_executor(handler); byte_citer first = packet->cbegin(); byte_citer last = packet->cend(); asio::post( ex, asio::consign( - asio::append(std::move(handler), error_code{}, first, last), + asio::prepend( + std::move(handler), error_code {}, first, last + ), std::move(packet) ) ); }; return asio::async_initiate( - initiation, token, std::move(fdata.packet) + initiation, token, std::move(fdata.packet), _ex ); } @@ -137,22 +150,19 @@ public: auto handler = std::move(*handler_ptr); _handlers.erase(handler_ptr); - std::move(handler)(ec, first, last); + handler.complete(ec, first, last); } void resend_unanswered() { auto ua = std::move(_handlers); for (auto& h : ua) - std::move(h)(asio::error::try_again, byte_citer {}, byte_citer {}); + h.complete(asio::error::try_again); } void cancel_unanswered() { auto ua = std::move(_handlers); for (auto& h : ua) - std::move(h)( - asio::error::operation_aborted, - byte_citer {}, byte_citer {} - ); + h.complete(asio::error::operation_aborted); } bool any_expired() { @@ -172,10 +182,7 @@ public: void clear_pending_pubrels() { for (auto it = _handlers.begin(); it != _handlers.end();) { if (it->code() == control_code_e::pubrel) { - std::move(*it)( - asio::error::operation_aborted, - byte_citer {}, byte_citer {} - ); + it->complete(asio::error::operation_aborted); it = _handlers.erase(it); } else diff --git a/include/async_mqtt5/impl/sentry_op.hpp b/include/async_mqtt5/impl/sentry_op.hpp index 94a307f..93e7493 100644 --- a/include/async_mqtt5/impl/sentry_op.hpp +++ b/include/async_mqtt5/impl/sentry_op.hpp @@ -67,7 +67,7 @@ public: return; if (_svc_ptr->_replies.any_expired()) { - auto props = disconnect_props{}; + auto props = disconnect_props {}; // TODO add what packet was expected? props[prop::reason_string] = "No reply received within 20 seconds"; auto svc_ptr = _svc_ptr; diff --git a/include/async_mqtt5/impl/subscribe_op.hpp b/include/async_mqtt5/impl/subscribe_op.hpp index 3846a33..5103bc2 100644 --- a/include/async_mqtt5/impl/subscribe_op.hpp +++ b/include/async_mqtt5/impl/subscribe_op.hpp @@ -32,11 +32,11 @@ class subscribe_op { std::shared_ptr _svc_ptr; - cancellable_handler< + using handler_type = cancellable_handler< Handler, - typename client_service::executor_type, - std::tuple, suback_props> - > _handler; + typename client_service::executor_type + >; + handler_type _handler; public: subscribe_op( @@ -44,18 +44,24 @@ public: Handler&& handler ) : _svc_ptr(svc_ptr), - _handler(std::move(handler), get_executor()) - {} + _handler(std::move(handler), _svc_ptr->get_executor()) + { + auto slot = asio::get_associated_cancellation_slot(_handler); + if (slot.is_connected()) + slot.assign([&svc = *_svc_ptr](asio::cancellation_type_t) { + svc.cancel(); + }); + } subscribe_op(subscribe_op&&) noexcept = default; subscribe_op(const subscribe_op&) noexcept = delete; - using executor_type = typename client_service::executor_type; + using executor_type = asio::associated_executor_t; executor_type get_executor() const noexcept { - return _svc_ptr->get_executor(); + return asio::get_associated_executor(_handler); } - using allocator_type = asio::associated_allocator_t; + using allocator_type = asio::associated_allocator_t; allocator_type get_allocator() const noexcept { return asio::get_associated_allocator(_handler); } @@ -93,9 +99,6 @@ public: } void send_subscribe(control_packet subscribe) { - if (_handler.empty()) // already cancelled - return _svc_ptr->free_pid(subscribe.packet_id()); - auto wire_data = subscribe.wire_data(); _svc_ptr->async_send( wire_data, @@ -106,17 +109,25 @@ public: ); } + void resend_subscribe(control_packet subscribe) { + if (_handler.cancelled() != asio::cancellation_type_t::none) + return complete( + asio::error::operation_aborted, subscribe.packet_id() + ); + send_subscribe(std::move(subscribe)); + } + void operator()( on_subscribe, control_packet packet, error_code ec ) { if (ec == asio::error::try_again) - return send_subscribe(std::move(packet)); + return resend_subscribe(std::move(packet)); auto packet_id = packet.packet_id(); if (ec) - return complete(ec, packet_id, {}, {}); + return complete(ec, packet_id); _svc_ptr->async_wait_reply( control_code_e::suback, packet_id, @@ -129,19 +140,19 @@ public: error_code ec, byte_citer first, byte_citer last ) { if (ec == asio::error::try_again) // "resend unanswered" - return send_subscribe(std::move(packet)); + return resend_subscribe(std::move(packet)); uint16_t packet_id = packet.packet_id(); if (ec) - return complete(ec, packet_id, {}, {}); + return complete(ec, packet_id); auto suback = decoders::decode_suback( static_cast(std::distance(first, last)), first ); if (!suback.has_value()) { on_malformed_packet("Malformed SUBACK: cannot decode"); - return send_subscribe(std::move(packet)); + return resend_subscribe(std::move(packet)); } auto& [props, reason_codes] = *suback; @@ -247,14 +258,14 @@ private: if (packet_id != 0) _svc_ptr->free_pid(packet_id); _handler.complete_post( - ec, std::vector { num_topics, reason_codes::empty }, + ec, std::vector(num_topics, reason_codes::empty), suback_props {} ); } void complete( error_code ec, uint16_t packet_id, - std::vector reason_codes, suback_props props + std::vector reason_codes = {}, suback_props props = {} ) { if (!_svc_ptr->subscriptions_present()) { bool has_success_rc = std::any_of( diff --git a/include/async_mqtt5/impl/unsubscribe_op.hpp b/include/async_mqtt5/impl/unsubscribe_op.hpp index e9dd266..a38f5e9 100644 --- a/include/async_mqtt5/impl/unsubscribe_op.hpp +++ b/include/async_mqtt5/impl/unsubscribe_op.hpp @@ -27,11 +27,11 @@ class unsubscribe_op { std::shared_ptr _svc_ptr; - cancellable_handler< + using handler_type = cancellable_handler< Handler, - typename client_service::executor_type, - std::tuple, unsuback_props> - > _handler; + typename client_service::executor_type + >; + handler_type _handler; public: unsubscribe_op( @@ -39,18 +39,24 @@ public: Handler&& handler ) : _svc_ptr(svc_ptr), - _handler(std::move(handler), get_executor()) - {} + _handler(std::move(handler), _svc_ptr->get_executor()) + { + auto slot = asio::get_associated_cancellation_slot(_handler); + if (slot.is_connected()) + slot.assign([&svc = *_svc_ptr](asio::cancellation_type_t) { + svc.cancel(); + }); + } unsubscribe_op(unsubscribe_op&&) noexcept = default; unsubscribe_op(const unsubscribe_op&) noexcept = delete; - using executor_type = typename client_service::executor_type; + using executor_type = asio::associated_executor_t; executor_type get_executor() const noexcept { - return _svc_ptr->get_executor(); + return asio::get_associated_executor(_handler); } - using allocator_type = asio::associated_allocator_t; + using allocator_type = asio::associated_allocator_t; allocator_type get_allocator() const noexcept { return asio::get_associated_allocator(_handler); } @@ -88,9 +94,6 @@ public: } void send_unsubscribe(control_packet unsubscribe) { - if (_handler.empty()) // already cancelled - return _svc_ptr->free_pid(unsubscribe.packet_id()); - auto wire_data = unsubscribe.wire_data(); _svc_ptr->async_send( wire_data, @@ -101,17 +104,25 @@ public: ); } + void resend_unsubscribe(control_packet subscribe) { + if (_handler.cancelled() != asio::cancellation_type_t::none) + return complete( + asio::error::operation_aborted, subscribe.packet_id() + ); + send_unsubscribe(std::move(subscribe)); + } + void operator()( on_unsubscribe, control_packet packet, error_code ec ) { if (ec == asio::error::try_again) - return send_unsubscribe(std::move(packet)); + return resend_unsubscribe(std::move(packet)); auto packet_id = packet.packet_id(); if (ec) - return complete(ec, packet_id, {}, {}); + return complete(ec, packet_id); _svc_ptr->async_wait_reply( control_code_e::unsuback, packet_id, @@ -124,19 +135,19 @@ public: error_code ec, byte_citer first, byte_citer last ) { if (ec == asio::error::try_again) // "resend unanswered" - return send_unsubscribe(std::move(packet)); + return resend_unsubscribe(std::move(packet)); uint16_t packet_id = packet.packet_id(); if (ec) - return complete(ec, packet_id, {}, {}); + return complete(ec, packet_id); auto unsuback = decoders::decode_unsuback( static_cast(std::distance(first, last)), first ); if (!unsuback.has_value()) { on_malformed_packet("Malformed UNSUBACK: cannot decode"); - return send_unsubscribe(std::move(packet)); + return resend_unsubscribe(std::move(packet)); } auto& [props, reason_codes] = *unsuback; @@ -189,14 +200,14 @@ private: if (packet_id != 0) _svc_ptr->free_pid(packet_id); _handler.complete_post( - ec, std::vector { num_topics, reason_codes::empty }, + ec, std::vector(num_topics, reason_codes::empty), unsuback_props {} ); } void complete( error_code ec, uint16_t packet_id, - std::vector reason_codes, unsuback_props props + std::vector reason_codes = {}, unsuback_props props = {} ) { _svc_ptr->free_pid(packet_id); _handler.complete(ec, std::move(reason_codes), std::move(props)); diff --git a/include/async_mqtt5/impl/write_op.hpp b/include/async_mqtt5/impl/write_op.hpp index 1a95af7..690fae9 100644 --- a/include/async_mqtt5/impl/write_op.hpp +++ b/include/async_mqtt5/impl/write_op.hpp @@ -2,6 +2,7 @@ #define ASYNC_MQTT5_WRITE_OP_HPP #include +#include #include #include @@ -28,9 +29,9 @@ public: write_op(write_op&&) noexcept = default; write_op(const write_op&) = delete; - using executor_type = typename Owner::executor_type; + using executor_type = asio::associated_executor_t; executor_type get_executor() const noexcept { - return _owner.get_executor(); + return asio::get_associated_executor(_handler); } using allocator_type = asio::associated_allocator_t; @@ -48,7 +49,13 @@ public: asio::prepend(std::move(*this), on_write {}, stream_ptr) ); else - (*this)(on_write {}, stream_ptr, asio::error::not_connected, 0); + asio::post( + _owner.get_executor(), + asio::prepend( + std::move(*this), on_write {}, + stream_ptr, asio::error::not_connected, 0 + ) + ); } void operator()( @@ -79,10 +86,7 @@ public: private: void complete(error_code ec, size_t bytes_written) { - asio::dispatch( - get_executor(), - asio::prepend(std::move(_handler), ec, bytes_written) - ); + std::move(_handler)(ec, bytes_written); } static bool should_reconnect(error_code ec) { diff --git a/include/async_mqtt5/types.hpp b/include/async_mqtt5/types.hpp index b1c3b11..3eb273d 100644 --- a/include/async_mqtt5/types.hpp +++ b/include/async_mqtt5/types.hpp @@ -127,7 +127,7 @@ struct subscribe_options { new_subscription_only = 0b01, /** Do not send retained messages at the time of subscribe. */ - not_send = 0b100 + not_send = 0b10 }; diff --git a/test/unit/include/test_common/test_service.hpp b/test/unit/include/test_common/test_service.hpp index 5aa442c..2988e2f 100644 --- a/test/unit/include/test_common/test_service.hpp +++ b/test/unit/include/test_common/test_service.hpp @@ -39,8 +39,7 @@ public: CompletionToken&& token ) { auto initiation = [this](auto handler) { - auto ex = asio::get_associated_executor(handler, _ex); - asio::post(ex, + asio::post(_ex, asio::prepend(std::move(handler), error_code {}) ); }; diff --git a/test/unit/test/cancellation.cpp b/test/unit/test/cancellation.cpp index 0c86e2f..c3c7280 100644 --- a/test/unit/test/cancellation.cpp +++ b/test/unit/test/cancellation.cpp @@ -15,7 +15,8 @@ namespace async_mqtt5::test { enum cancel_type { ioc_stop = 1, - client_cancel + client_cancel, + signal_emit }; } // end namespace async_mqtt5::test @@ -34,27 +35,42 @@ void cancel_async_receive() { using client_type = mqtt_client; client_type c(ioc, ""); - c.brokers("mqtt.mireo.local", 1883) + c.brokers("127.0.0.1", 1883) + .credentials("test-cli", "", "") .run(); + auto handler = [&handlers_called]( + error_code ec, std::string, std::string, publish_props + ) { + BOOST_CHECK_EQUAL(ec, asio::error::operation_aborted); + handlers_called++; + }; + + std::vector signals(3); + for (auto i = 0; i < num_handlers; ++i) - c.async_receive([&handlers_called]( - error_code ec, std::string, std::string, publish_props - ) { - BOOST_CHECK_EQUAL(ec, asio::error::operation_aborted); - handlers_called++; - }); + c.async_receive(asio::bind_cancellation_slot( + signals[i].slot(), + std::move(handler) + )); asio::steady_timer timer(c.get_executor()); - timer.expires_after(std::chrono::seconds(1)); + timer.expires_after(std::chrono::milliseconds(10)); timer.async_wait([&](auto) { if constexpr (type == ioc_stop) ioc.stop(); - else + else if constexpr (type == client_cancel) c.cancel(); + else if constexpr (type == signal_emit) + std::for_each( + signals.begin(), signals.end(), + [](auto& signal) { + signal.emit(asio::cancellation_type_t::terminal); + } + ); }); - ioc.run(); + ioc.run_for(std::chrono::milliseconds(20)); BOOST_CHECK_EQUAL(handlers_called, expected_handlers_called); } @@ -71,87 +87,68 @@ void cancel_async_publish() { using client_type = mqtt_client; client_type c(ioc, ""); - c.brokers("mqtt.mireo.local", 1883) + c.brokers("127.0.0.1", 1883) + .credentials("test-cli", "", "") .run(); + std::vector signals(3); + c.async_publish( "topic", "payload", retain_e::yes, {}, - [&handlers_called](error_code ec) { - BOOST_CHECK_EQUAL(ec, asio::error::operation_aborted); - handlers_called++; - } + asio::bind_cancellation_slot( + signals[0].slot(), + [&handlers_called](error_code ec) { + BOOST_CHECK_EQUAL(ec, asio::error::operation_aborted); + handlers_called++; + } + ) ); c.async_publish( "topic", "payload", retain_e::yes, {}, - [&handlers_called](error_code ec, reason_code rc, puback_props) { - BOOST_CHECK_EQUAL(ec, asio::error::operation_aborted); - BOOST_CHECK_EQUAL(rc, reason_codes::empty); - handlers_called++; - } + asio::bind_cancellation_slot( + signals[1].slot(), + [&handlers_called](error_code ec, reason_code rc, puback_props) { + BOOST_CHECK_EQUAL(ec, asio::error::operation_aborted); + BOOST_CHECK_EQUAL(rc, reason_codes::empty); + handlers_called++; + } + ) ); c.async_publish( "topic", "payload", retain_e::yes, {}, - [&handlers_called](error_code ec, reason_code rc, pubcomp_props) { - BOOST_CHECK_EQUAL(ec, asio::error::operation_aborted); - BOOST_CHECK_EQUAL(rc, reason_codes::empty); - handlers_called++; - } + asio::bind_cancellation_slot( + signals[2].slot(), + [&handlers_called](error_code ec, reason_code rc, pubcomp_props) { + BOOST_CHECK_EQUAL(ec, asio::error::operation_aborted); + BOOST_CHECK_EQUAL(rc, reason_codes::empty); + handlers_called++; + } + ) ); asio::steady_timer timer(c.get_executor()); - timer.expires_after(std::chrono::seconds(1)); + timer.expires_after(std::chrono::milliseconds(10)); timer.async_wait([&](auto) { if constexpr (type == ioc_stop) ioc.stop(); - else + else if constexpr (type == client_cancel) c.cancel(); + else if constexpr (type == signal_emit) + std::for_each( + signals.begin(), signals.end(), + [](auto& signal) { + signal.emit(asio::cancellation_type_t::terminal); + } + ); }); ioc.run(); BOOST_CHECK_EQUAL(handlers_called, expected_handlers_called); } -template -void cancel_during_connecting() { - using namespace test; - - constexpr int expected_handlers_called = type == ioc_stop ? 0 : 1; - int handlers_called = 0; - - asio::io_context ioc; - - using stream_type = asio::ip::tcp::socket; - using client_type = mqtt_client; - client_type c(ioc, ""); - - c.brokers("127.0.0.1", 1883) - .run(); - - c.async_publish( - "topic", "payload", retain_e::yes, {}, - [&handlers_called](error_code ec) { - BOOST_CHECK_EQUAL(ec, asio::error::operation_aborted); - handlers_called++; - } - ); - - asio::steady_timer timer(c.get_executor()); - timer.expires_after(std::chrono::seconds(2)); - timer.async_wait([&](auto) { - if constexpr (type == ioc_stop) - ioc.stop(); - else - c.cancel(); - }); - - ioc.run(); - BOOST_CHECK_EQUAL(handlers_called, expected_handlers_called); -} - - BOOST_AUTO_TEST_SUITE(cancellation/*, *boost::unit_test::disabled()*/) @@ -159,11 +156,14 @@ BOOST_AUTO_TEST_CASE(ioc_stop_async_receive) { cancel_async_receive(); } - BOOST_AUTO_TEST_CASE(client_cancel_async_receive) { cancel_async_receive(); } +BOOST_AUTO_TEST_CASE(signal_emit_async_receive) { + cancel_async_receive(); +} + // passes on debian, hangs on windows in io_context destructor BOOST_AUTO_TEST_CASE(ioc_stop_async_publish, *boost::unit_test::disabled() ) { cancel_async_publish(); @@ -173,13 +173,8 @@ BOOST_AUTO_TEST_CASE(client_cancel_async_publish) { cancel_async_publish(); } -// passes on debian, hangs on windows -BOOST_AUTO_TEST_CASE(ioc_stop_cancel_during_connecting, *boost::unit_test::disabled() ) { - cancel_during_connecting(); -} - -BOOST_AUTO_TEST_CASE(client_cancel_during_connecting) { - cancel_during_connecting(); +BOOST_AUTO_TEST_CASE(signal_emit_async_publish) { + cancel_async_publish(); } #ifdef BOOST_ASIO_HAS_CO_AWAIT diff --git a/test/unit/test/publish_send_op.cpp b/test/unit/test/publish_send_op.cpp index 8920516..14ece50 100644 --- a/test/unit/test/publish_send_op.cpp +++ b/test/unit/test/publish_send_op.cpp @@ -260,35 +260,6 @@ BOOST_AUTO_TEST_CASE(test_topic_alias_maximum) { BOOST_CHECK_EQUAL(handlers_called, expected_handlers_called); } -BOOST_AUTO_TEST_CASE(test_publish_immediate_cancellation) { - constexpr int expected_handlers_called = 1; - int handlers_called = 0; - - asio::io_context ioc; - using client_service_type = test::test_service; - auto svc_ptr = std::make_shared(ioc.get_executor()); - asio::cancellation_signal cancel_signal; - - auto h = [&handlers_called](error_code ec, reason_code rc, puback_props) { - ++handlers_called; - BOOST_CHECK(ec == asio::error::operation_aborted); - BOOST_CHECK_EQUAL(rc, reason_codes::empty); - }; - - auto handler = asio::bind_cancellation_slot(cancel_signal.slot(), std::move(h)); - - detail::publish_send_op< - client_service_type, decltype(handler), qos_e::at_least_once - > { svc_ptr, std::move(handler) } - .perform( - "test", "payload", retain_e::no, {} - ); - - cancel_signal.emit(asio::cancellation_type::terminal); - ioc.run(); - BOOST_CHECK_EQUAL(handlers_called, expected_handlers_called); -} - BOOST_AUTO_TEST_CASE(test_publish_cancellation) { constexpr int expected_handlers_called = 1; int handlers_called = 0; diff --git a/test/unit/test/session.cpp b/test/unit/test/session.cpp index afffdb6..1789b89 100644 --- a/test/unit/test/session.cpp +++ b/test/unit/test/session.cpp @@ -27,6 +27,7 @@ BOOST_AUTO_TEST_CASE(clear_waiting_on_pubrel) { asio::io_context ioc; using client_service_type = test::test_service; auto svc_ptr = std::make_shared(ioc.get_executor()); + svc_ptr->open_stream(); decoders::publish_message pub_msg = std::make_tuple( "topic", int16_t(1), uint8_t(0b0100), publish_props {}, "payload"