diff --git a/include/async_mqtt5/impl/connect_op.hpp b/include/async_mqtt5/impl/connect_op.hpp index 083e299..50013a9 100644 --- a/include/async_mqtt5/impl/connect_op.hpp +++ b/include/async_mqtt5/impl/connect_op.hpp @@ -2,6 +2,7 @@ #define ASYNC_MQTT5_CONNECT_OP_HPP #include +#include #include #include #include @@ -48,6 +49,7 @@ class connect_op { handler_type _handler; std::unique_ptr _buffer_ptr; + asio::cancellation_state _cancellation_state; using endpoint = asio::ip::tcp::endpoint; using epoints = asio::ip::tcp::resolver::results_type; @@ -58,7 +60,12 @@ public: Stream& stream, Handler&& handler, mqtt_ctx& ctx ) : _stream(stream), _ctx(ctx), - _handler(std::forward(handler)) + _handler(std::forward(handler)), + _cancellation_state( + asio::get_associated_cancellation_slot(_handler), + asio::enable_total_cancellation{}, + asio::enable_total_cancellation{} + ) {} connect_op(connect_op&&) noexcept = default; @@ -77,7 +84,7 @@ public: using cancellation_slot_type = asio::associated_cancellation_slot_t; cancellation_slot_type get_cancellation_slot() const noexcept { - return asio::get_associated_cancellation_slot(_handler); + return _cancellation_state.slot(); } void perform( @@ -95,6 +102,9 @@ public: void operator()( on_connect, error_code ec, endpoint ep, authority_path ap ) { + if (is_cancelled()) + return complete(asio::error::operation_aborted); + if (ec) return complete(ec); @@ -130,6 +140,9 @@ public: on_tls_handshake, error_code ec, endpoint ep, authority_path ap ) { + if (is_cancelled()) + return complete(asio::error::operation_aborted); + if (ec) return complete(ec); @@ -168,6 +181,9 @@ public: } void operator()(on_ws_handshake, error_code ec) { + if (is_cancelled()) + return complete(asio::error::operation_aborted); + if (ec) return complete(ec); @@ -184,6 +200,9 @@ public: } void operator()(on_init_auth_data, error_code ec, std::string data) { + if (is_cancelled()) + return complete(asio::error::operation_aborted); + if (ec) return complete(asio::error::try_again); @@ -212,6 +231,9 @@ public: } void operator()(on_send_connect, error_code ec, size_t) { + if (is_cancelled()) + return complete(asio::error::operation_aborted); + if (ec) return complete(ec); @@ -227,6 +249,9 @@ public: void operator()( on_fixed_header, error_code ec, size_t num_read ) { + if (is_cancelled()) + return complete(asio::error::operation_aborted); + if (ec) return complete(ec); @@ -267,6 +292,9 @@ public: on_read_packet, error_code ec, size_t, control_code_e code, byte_citer first, byte_citer last ) { + if (is_cancelled()) + return complete(asio::error::operation_aborted); + if (ec) return complete(ec); @@ -336,6 +364,9 @@ public: } void operator()(on_auth_data, error_code ec, std::string data) { + if (is_cancelled()) + return complete(asio::error::operation_aborted); + if (ec) return complete(asio::error::try_again); @@ -362,6 +393,9 @@ public: } void operator()(on_send_auth, error_code ec, size_t) { + if (is_cancelled()) + return complete(asio::error::operation_aborted); + if (ec) return complete(ec); @@ -373,6 +407,9 @@ public: } void operator()(on_complete_auth, error_code ec, std::string) { + if (is_cancelled()) + return complete(asio::error::operation_aborted); + if (ec) return complete(asio::error::try_again); @@ -380,7 +417,12 @@ public: } private: + bool is_cancelled() const { + return _cancellation_state.cancelled() != asio::cancellation_type::none; + } + void complete(error_code ec) { + _cancellation_state.slot().clear(); std::move(_handler)(ec); } diff --git a/test/include/test_common/delayed_op.hpp b/test/include/test_common/delayed_op.hpp index 1f15a84..bfb4bd4 100644 --- a/test/include/test_common/delayed_op.hpp +++ b/test/include/test_common/delayed_op.hpp @@ -70,6 +70,12 @@ public: template void operator()(on_timer, CompletionHandler&& h, error_code ec) { + // The timer places a handler into the cancellation slot + // and does not clear it. Therefore, we need to clear it explicitly + // to properly remove the corresponding cancellation signal + // in the test_broker. + get_cancellation_slot().clear(); + auto bh = std::apply( [h = std::move(h)](auto&&... args) mutable { return asio::append(std::move(h), std::move(args)...);