Add intermediate is_cancelled checks in connect_op

Summary: related to T13706

Reviewers: ivica

Reviewed By: ivica

Subscribers: miljen, iljazovic

Differential Revision: https://repo.mireo.local/D27918
This commit is contained in:
Korina Šimičević
2024-02-14 14:02:10 +01:00
parent e5de307723
commit 2b686dd6cb
2 changed files with 50 additions and 2 deletions

View File

@@ -2,6 +2,7 @@
#define ASYNC_MQTT5_CONNECT_OP_HPP #define ASYNC_MQTT5_CONNECT_OP_HPP
#include <boost/asio/append.hpp> #include <boost/asio/append.hpp>
#include <boost/asio/cancellation_state.hpp>
#include <boost/asio/consign.hpp> #include <boost/asio/consign.hpp>
#include <boost/asio/dispatch.hpp> #include <boost/asio/dispatch.hpp>
#include <boost/asio/prepend.hpp> #include <boost/asio/prepend.hpp>
@@ -48,6 +49,7 @@ class connect_op {
handler_type _handler; handler_type _handler;
std::unique_ptr<std::string> _buffer_ptr; std::unique_ptr<std::string> _buffer_ptr;
asio::cancellation_state _cancellation_state;
using endpoint = asio::ip::tcp::endpoint; using endpoint = asio::ip::tcp::endpoint;
using epoints = asio::ip::tcp::resolver::results_type; using epoints = asio::ip::tcp::resolver::results_type;
@@ -58,7 +60,12 @@ public:
Stream& stream, Handler&& handler, mqtt_ctx& ctx Stream& stream, Handler&& handler, mqtt_ctx& ctx
) : ) :
_stream(stream), _ctx(ctx), _stream(stream), _ctx(ctx),
_handler(std::forward<Handler>(handler)) _handler(std::forward<Handler>(handler)),
_cancellation_state(
asio::get_associated_cancellation_slot(_handler),
asio::enable_total_cancellation{},
asio::enable_total_cancellation{}
)
{} {}
connect_op(connect_op&&) noexcept = default; connect_op(connect_op&&) noexcept = default;
@@ -77,7 +84,7 @@ public:
using cancellation_slot_type = using cancellation_slot_type =
asio::associated_cancellation_slot_t<handler_type>; asio::associated_cancellation_slot_t<handler_type>;
cancellation_slot_type get_cancellation_slot() const noexcept { cancellation_slot_type get_cancellation_slot() const noexcept {
return asio::get_associated_cancellation_slot(_handler); return _cancellation_state.slot();
} }
void perform( void perform(
@@ -95,6 +102,9 @@ public:
void operator()( void operator()(
on_connect, error_code ec, endpoint ep, authority_path ap on_connect, error_code ec, endpoint ep, authority_path ap
) { ) {
if (is_cancelled())
return complete(asio::error::operation_aborted);
if (ec) if (ec)
return complete(ec); return complete(ec);
@@ -130,6 +140,9 @@ public:
on_tls_handshake, error_code ec, on_tls_handshake, error_code ec,
endpoint ep, authority_path ap endpoint ep, authority_path ap
) { ) {
if (is_cancelled())
return complete(asio::error::operation_aborted);
if (ec) if (ec)
return complete(ec); return complete(ec);
@@ -168,6 +181,9 @@ public:
} }
void operator()(on_ws_handshake, error_code ec) { void operator()(on_ws_handshake, error_code ec) {
if (is_cancelled())
return complete(asio::error::operation_aborted);
if (ec) if (ec)
return complete(ec); return complete(ec);
@@ -184,6 +200,9 @@ public:
} }
void operator()(on_init_auth_data, error_code ec, std::string data) { void operator()(on_init_auth_data, error_code ec, std::string data) {
if (is_cancelled())
return complete(asio::error::operation_aborted);
if (ec) if (ec)
return complete(asio::error::try_again); return complete(asio::error::try_again);
@@ -212,6 +231,9 @@ public:
} }
void operator()(on_send_connect, error_code ec, size_t) { void operator()(on_send_connect, error_code ec, size_t) {
if (is_cancelled())
return complete(asio::error::operation_aborted);
if (ec) if (ec)
return complete(ec); return complete(ec);
@@ -227,6 +249,9 @@ public:
void operator()( void operator()(
on_fixed_header, error_code ec, size_t num_read on_fixed_header, error_code ec, size_t num_read
) { ) {
if (is_cancelled())
return complete(asio::error::operation_aborted);
if (ec) if (ec)
return complete(ec); return complete(ec);
@@ -267,6 +292,9 @@ public:
on_read_packet, error_code ec, size_t, control_code_e code, on_read_packet, error_code ec, size_t, control_code_e code,
byte_citer first, byte_citer last byte_citer first, byte_citer last
) { ) {
if (is_cancelled())
return complete(asio::error::operation_aborted);
if (ec) if (ec)
return complete(ec); return complete(ec);
@@ -336,6 +364,9 @@ public:
} }
void operator()(on_auth_data, error_code ec, std::string data) { void operator()(on_auth_data, error_code ec, std::string data) {
if (is_cancelled())
return complete(asio::error::operation_aborted);
if (ec) if (ec)
return complete(asio::error::try_again); return complete(asio::error::try_again);
@@ -362,6 +393,9 @@ public:
} }
void operator()(on_send_auth, error_code ec, size_t) { void operator()(on_send_auth, error_code ec, size_t) {
if (is_cancelled())
return complete(asio::error::operation_aborted);
if (ec) if (ec)
return complete(ec); return complete(ec);
@@ -373,6 +407,9 @@ public:
} }
void operator()(on_complete_auth, error_code ec, std::string) { void operator()(on_complete_auth, error_code ec, std::string) {
if (is_cancelled())
return complete(asio::error::operation_aborted);
if (ec) if (ec)
return complete(asio::error::try_again); return complete(asio::error::try_again);
@@ -380,7 +417,12 @@ public:
} }
private: private:
bool is_cancelled() const {
return _cancellation_state.cancelled() != asio::cancellation_type::none;
}
void complete(error_code ec) { void complete(error_code ec) {
_cancellation_state.slot().clear();
std::move(_handler)(ec); std::move(_handler)(ec);
} }

View File

@@ -70,6 +70,12 @@ public:
template <typename CompletionHandler> template <typename CompletionHandler>
void operator()(on_timer, CompletionHandler&& h, error_code ec) { void operator()(on_timer, CompletionHandler&& h, error_code ec) {
// The timer places a handler into the cancellation slot
// and does not clear it. Therefore, we need to clear it explicitly
// to properly remove the corresponding cancellation signal
// in the test_broker.
get_cancellation_slot().clear();
auto bh = std::apply( auto bh = std::apply(
[h = std::move(h)](auto&&... args) mutable { [h = std::move(h)](auto&&... args) mutable {
return asio::append(std::move(h), std::move(args)...); return asio::append(std::move(h), std::move(args)...);