From d6c4884d535517cdb18a7a8ddda66ee102458ed2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Korina=20=C5=A0imi=C4=8Devi=C4=87?= Date: Fri, 26 Jan 2024 09:39:01 +0100 Subject: [PATCH] (Un)subscribe validates the number of topics and reason codes Test Plan: related to T12015 - (un)subscribe reason codes will always contain as many reason codes as there are topic filters - if, by some odd chance, the client receives the wrong number of rcs or some are invalid, it will treat it as malformed - both subscribe_op and unsubscribe_op should be 100% covered by tests now Reviewers: ivica Reviewed By: ivica Subscribers: miljen, iljazovic Differential Revision: https://repo.mireo.local/D27592 --- include/async_mqtt5/impl/subscribe_op.hpp | 42 +-- include/async_mqtt5/impl/unsubscribe_op.hpp | 37 ++- test/include/test_common/packet_util.hpp | 5 + test/integration/cancellation.cpp | 14 +- test/integration/re_authentication.cpp | 16 +- test/integration/receive_publish.cpp | 12 +- test/integration/send_publish.cpp | 13 +- test/integration/sub_unsub.cpp | 288 +++++++++++++++++++- test/unit/subscribe_op.cpp | 34 ++- test/unit/unsubscribe_op.cpp | 25 +- 10 files changed, 405 insertions(+), 81 deletions(-) diff --git a/include/async_mqtt5/impl/subscribe_op.hpp b/include/async_mqtt5/impl/subscribe_op.hpp index 4d1d369..45cc38b 100644 --- a/include/async_mqtt5/impl/subscribe_op.hpp +++ b/include/async_mqtt5/impl/subscribe_op.hpp @@ -38,6 +38,8 @@ class subscribe_op { >; handler_type _handler; + size_t _num_topics { 0 }; + public: subscribe_op( const std::shared_ptr& svc_ptr, @@ -70,15 +72,18 @@ public: const std::vector& topics, const subscribe_props& props ) { + _num_topics = topics.size(); + uint16_t packet_id = _svc_ptr->allocate_pid(); if (packet_id == 0) - return complete_post( - client::error::pid_overrun, packet_id, topics.size() - ); + return complete_post(client::error::pid_overrun, packet_id); + + if (_num_topics == 0) + return complete_post(client::error::invalid_topic, packet_id); auto ec = validate_subscribe(topics, props); if (ec) - return complete_post(ec, packet_id, topics.size()); + return complete_post(ec, packet_id); auto subscribe = control_packet::of( with_pid, get_allocator(), @@ -91,9 +96,7 @@ public: .value_or(default_max_send_size) ); if (subscribe.size() > max_packet_size) - return complete_post( - client::error::packet_too_large, packet_id, topics.size() - ); + return complete_post(client::error::packet_too_large, packet_id); send_subscribe(std::move(subscribe)); } @@ -155,11 +158,18 @@ public: return resend_subscribe(std::move(packet)); } - auto& [props, reason_codes] = *suback; + auto& [props, rcs] = *suback; + auto reason_codes = to_reason_codes(std::move(rcs)); + if (reason_codes.size() != _num_topics) { + on_malformed_packet( + "Malformed SUBACK: does not contain a " + "valid Reason Code for every Topic Filter" + ); + return resend_subscribe(std::move(packet)); + } complete( - ec, packet_id, - to_reason_codes(std::move(reason_codes)), std::move(props) + ec, packet_id, std::move(reason_codes), std::move(props) ); } @@ -232,9 +242,7 @@ private: client::error::malformed_packet; } - static std::vector to_reason_codes( - std::vector codes - ) { + static std::vector to_reason_codes(std::vector codes) { std::vector ret; for (uint8_t code : codes) { auto rc = to_reason_code(code); @@ -253,12 +261,11 @@ private: ); } - - void complete_post(error_code ec, uint16_t packet_id, size_t num_topics) { + void complete_post(error_code ec, uint16_t packet_id) { if (packet_id != 0) _svc_ptr->free_pid(packet_id); _handler.complete_post( - ec, std::vector(num_topics, reason_codes::empty), + ec, std::vector(_num_topics, reason_codes::empty), suback_props {} ); } @@ -267,6 +274,9 @@ private: error_code ec, uint16_t packet_id, std::vector reason_codes = {}, suback_props props = {} ) { + if (reason_codes.empty() && _num_topics) + reason_codes = std::vector(_num_topics, reason_codes::empty); + if (!_svc_ptr->subscriptions_present()) { bool has_success_rc = std::any_of( reason_codes.cbegin(), reason_codes.cend(), diff --git a/include/async_mqtt5/impl/unsubscribe_op.hpp b/include/async_mqtt5/impl/unsubscribe_op.hpp index 354f4bc..5edf7cd 100644 --- a/include/async_mqtt5/impl/unsubscribe_op.hpp +++ b/include/async_mqtt5/impl/unsubscribe_op.hpp @@ -33,6 +33,8 @@ class unsubscribe_op { >; handler_type _handler; + size_t _num_topics { 0 }; + public: unsubscribe_op( const std::shared_ptr& svc_ptr, @@ -65,15 +67,18 @@ public: const std::vector& topics, const unsubscribe_props& props ) { + _num_topics = topics.size(); + uint16_t packet_id = _svc_ptr->allocate_pid(); if (packet_id == 0) - return complete_post( - client::error::pid_overrun, packet_id, topics.size() - ); + return complete_post(client::error::pid_overrun, packet_id); + + if (_num_topics == 0) + return complete_post(client::error::invalid_topic, packet_id); auto ec = validate_unsubscribe(topics, props); if (ec) - return complete_post(ec, packet_id, topics.size()); + return complete_post(ec, packet_id); auto unsubscribe = control_packet::of( with_pid, get_allocator(), @@ -86,9 +91,7 @@ public: .value_or(default_max_send_size) ); if (unsubscribe.size() > max_packet_size) - return complete_post( - client::error::packet_too_large, packet_id, topics.size() - ); + return complete_post(client::error::packet_too_large, packet_id); send_unsubscribe(std::move(unsubscribe)); } @@ -150,11 +153,18 @@ public: return resend_unsubscribe(std::move(packet)); } - auto& [props, reason_codes] = *unsuback; + auto& [props, rcs] = *unsuback; + auto reason_codes = to_reason_codes(std::move(rcs)); + if (reason_codes.size() != _num_topics) { + on_malformed_packet( + "Malformed UNSUBACK: does not contain a " + "valid Reason Code for every Topic Filter" + ); + return resend_unsubscribe(std::move(packet)); + } complete( - ec, packet_id, - to_reason_codes(std::move(reason_codes)), std::move(props) + ec, packet_id, std::move(reason_codes), std::move(props) ); } @@ -196,11 +206,11 @@ private: ); } - void complete_post(error_code ec, uint16_t packet_id, size_t num_topics) { + void complete_post(error_code ec, uint16_t packet_id) { if (packet_id != 0) _svc_ptr->free_pid(packet_id); _handler.complete_post( - ec, std::vector(num_topics, reason_codes::empty), + ec, std::vector(_num_topics, reason_codes::empty), unsuback_props {} ); } @@ -209,6 +219,9 @@ private: error_code ec, uint16_t packet_id, std::vector reason_codes = {}, unsuback_props props = {} ) { + if (reason_codes.empty() && _num_topics) + reason_codes = std::vector(_num_topics, reason_codes::empty); + _svc_ptr->free_pid(packet_id); _handler.complete(ec, std::move(reason_codes), std::move(props)); } diff --git a/test/include/test_common/packet_util.hpp b/test/include/test_common/packet_util.hpp index e51c7d4..4edd6c1 100644 --- a/test/include/test_common/packet_util.hpp +++ b/test/include/test_common/packet_util.hpp @@ -361,6 +361,11 @@ std::vector to_readable_packets(const ConstBufferSequence& buffers) return content; } +inline disconnect_props dprops_with_reason_string(const std::string& reason_string) { + disconnect_props dprops; + dprops[prop::reason_string] = reason_string; + return dprops; +} } // end namespace async_mqtt5::test diff --git a/test/integration/cancellation.cpp b/test/integration/cancellation.cpp index dce5c11..8208dac 100644 --- a/test/integration/cancellation.cpp +++ b/test/integration/cancellation.cpp @@ -100,7 +100,7 @@ void setup_cancel_op_test_case( template < test::operation_type op_type, - std::enable_if_t = true + std::enable_if_t = true > void setup_cancel_op_test_case( client_type& c, asio::cancellation_signal& signal, int& handlers_called @@ -115,9 +115,8 @@ void setup_cancel_op_test_case( ) { handlers_called++; BOOST_CHECK_EQUAL(ec, asio::error::operation_aborted); - // TODO: be consistent with complete_post - //BOOST_ASSERT(rcs.size() == 1); - //BOOST_CHECK(rcs[0] == reason_codes::empty); + BOOST_ASSERT(rcs.size() == 1); + BOOST_CHECK(rcs[0] == reason_codes::empty); } ) ); @@ -125,7 +124,7 @@ void setup_cancel_op_test_case( template < test::operation_type op_type, - std::enable_if_t = true + std::enable_if_t = true > void setup_cancel_op_test_case( client_type& c, asio::cancellation_signal& signal, int& handlers_called @@ -140,9 +139,8 @@ void setup_cancel_op_test_case( ) { handlers_called++; BOOST_CHECK_EQUAL(ec, asio::error::operation_aborted); - // TODO: be consistent with complete_post - //BOOST_ASSERT(rcs.size() == 1); - //BOOST_CHECK(rcs[0] == reason_codes::empty); + BOOST_ASSERT(rcs.size() == 1); + BOOST_CHECK(rcs[0] == reason_codes::empty); } ) ); diff --git a/test/integration/re_authentication.cpp b/test/integration/re_authentication.cpp index ab9744a..79a4705 100644 --- a/test/integration/re_authentication.cpp +++ b/test/integration/re_authentication.cpp @@ -6,7 +6,9 @@ #include #include "test_common/message_exchange.hpp" +#include "test_common/packet_util.hpp" #include "test_common/test_authenticators.hpp" +#include "test_common/test_broker.hpp" #include "test_common/test_stream.hpp" using namespace async_mqtt5; @@ -85,12 +87,6 @@ void run_test( BOOST_CHECK(broker.received_all_expected()); } -disconnect_props dprops_with_reason_string(std::string_view reason_string) { - disconnect_props dprops; - dprops[prop::reason_string] = reason_string; - return dprops; -} - BOOST_FIXTURE_TEST_CASE(successful_re_auth, shared_test_data) { test::msg_exchange broker_side; broker_side @@ -123,7 +119,7 @@ BOOST_FIXTURE_TEST_CASE(successful_re_auth_multi_step, shared_test_data) { BOOST_FIXTURE_TEST_CASE(malformed_auth_rc, shared_test_data) { auto disconnect = encoders::encode_disconnect( reason_codes::malformed_packet.value(), - dprops_with_reason_string("Malformed AUTH received: bad reason code") + test::dprops_with_reason_string("Malformed AUTH received: bad reason code") ); auto malformed_auth = encoders::encode_auth( reason_codes::administrative_action.value(), init_auth_props() @@ -152,7 +148,7 @@ BOOST_FIXTURE_TEST_CASE(mismatched_auth_method, shared_test_data) { auto disconnect = encoders::encode_disconnect( reason_codes::protocol_error.value(), - dprops_with_reason_string("Malformed AUTH received: wrong authentication method") + test::dprops_with_reason_string("Malformed AUTH received: wrong authentication method") ); test::msg_exchange broker_side; @@ -171,7 +167,7 @@ BOOST_FIXTURE_TEST_CASE(mismatched_auth_method, shared_test_data) { BOOST_FIXTURE_TEST_CASE(async_auth_fail, shared_test_data) { auto disconnect = encoders::encode_disconnect( reason_codes::unspecified_error.value(), - dprops_with_reason_string("Re-authentication: authentication fail") + test::dprops_with_reason_string("Re-authentication: authentication fail") ); test::msg_exchange broker_side; @@ -196,7 +192,7 @@ BOOST_FIXTURE_TEST_CASE(unexpected_auth, shared_test_data) { ); auto disconnect = encoders::encode_disconnect( reason_codes::protocol_error.value(), - dprops_with_reason_string("Unexpected AUTH received") + test::dprops_with_reason_string("Unexpected AUTH received") ); test::msg_exchange broker_side; diff --git a/test/integration/receive_publish.cpp b/test/integration/receive_publish.cpp index 528d301..0304596 100644 --- a/test/integration/receive_publish.cpp +++ b/test/integration/receive_publish.cpp @@ -6,6 +6,8 @@ #include #include "test_common/message_exchange.hpp" +#include "test_common/packet_util.hpp" +#include "test_common/test_broker.hpp" #include "test_common/test_service.hpp" #include "test_common/test_stream.hpp" @@ -122,12 +124,6 @@ BOOST_FIXTURE_TEST_CASE(receive_publish_qos2, shared_test_data) { run_test(std::move(broker_side)); } -disconnect_props dprops_with_reason_string(const std::string& reason_string) { - disconnect_props dprops; - dprops[prop::reason_string] = reason_string; - return dprops; -} - BOOST_FIXTURE_TEST_CASE(receive_malformed_publish, shared_test_data) { // packets auto malformed_publish = encoders::encode_publish( @@ -137,7 +133,7 @@ BOOST_FIXTURE_TEST_CASE(receive_malformed_publish, shared_test_data) { auto disconnect = encoders::encode_disconnect( reason_codes::malformed_packet.value(), - dprops_with_reason_string("Malformed PUBLISH received: QoS bits set to 0b11") + test::dprops_with_reason_string("Malformed PUBLISH received: QoS bits set to 0b11") ); test::msg_exchange broker_side; @@ -162,7 +158,7 @@ BOOST_FIXTURE_TEST_CASE(receive_malformed_pubrel, shared_test_data) { auto disconnect = encoders::encode_disconnect( reason_codes::malformed_packet.value(), - dprops_with_reason_string("Malformed PUBREL received: invalid Reason Code") + test::dprops_with_reason_string("Malformed PUBREL received: invalid Reason Code") ); test::msg_exchange broker_side; diff --git a/test/integration/send_publish.cpp b/test/integration/send_publish.cpp index 7147fe0..0bc2e9c 100644 --- a/test/integration/send_publish.cpp +++ b/test/integration/send_publish.cpp @@ -5,6 +5,7 @@ #include #include "test_common/message_exchange.hpp" +#include "test_common/packet_util.hpp" #include "test_common/test_service.hpp" #include "test_common/test_stream.hpp" @@ -200,12 +201,6 @@ BOOST_FIXTURE_TEST_CASE(fail_to_send_pubrel, shared_test_data) { ); } -disconnect_props dprops_with_reason_string(const std::string& reason_string) { - disconnect_props dprops; - dprops[prop::reason_string] = reason_string; - return dprops; -} - BOOST_FIXTURE_TEST_CASE(receive_malformed_puback, shared_test_data) { // packets auto publish_qos1_dup = encoders::encode_publish( @@ -215,7 +210,7 @@ BOOST_FIXTURE_TEST_CASE(receive_malformed_puback, shared_test_data) { auto disconnect = encoders::encode_disconnect( reason_codes::malformed_packet.value(), - dprops_with_reason_string("Malformed PUBACK: invalid Reason Code") + test::dprops_with_reason_string("Malformed PUBACK: invalid Reason Code") ); test::msg_exchange broker_side; @@ -254,7 +249,7 @@ BOOST_FIXTURE_TEST_CASE(receive_malformed_pubrec, shared_test_data) { auto disconnect = encoders::encode_disconnect( reason_codes::malformed_packet.value(), - dprops_with_reason_string("Malformed PUBREC: invalid Reason Code") + test::dprops_with_reason_string("Malformed PUBREC: invalid Reason Code") ); test::msg_exchange broker_side; @@ -292,7 +287,7 @@ BOOST_FIXTURE_TEST_CASE(receive_malformed_pubcomp, shared_test_data) { auto disconnect = encoders::encode_disconnect( reason_codes::malformed_packet.value(), - dprops_with_reason_string("Malformed PUBCOMP: invalid Reason Code") + test::dprops_with_reason_string("Malformed PUBCOMP: invalid Reason Code") ); test::msg_exchange broker_side; diff --git a/test/integration/sub_unsub.cpp b/test/integration/sub_unsub.cpp index 5bc716a..b1c7353 100644 --- a/test/integration/sub_unsub.cpp +++ b/test/integration/sub_unsub.cpp @@ -1,10 +1,14 @@ #include +#include +#include +#include #include #include #include "test_common/message_exchange.hpp" +#include "test_common/packet_util.hpp" #include "test_common/test_service.hpp" #include "test_common/test_stream.hpp" @@ -36,17 +40,17 @@ struct shared_test_data { subscribe_topic { "topic", subscribe_options {} } }; std::vector unsub_topics = { "topic" }; - std::vector rcs = { uint8_t(0x00) }; + std::vector reason_codes = { uint8_t(0x00) }; const std::string subscribe = encoders::encode_subscribe( 1, sub_topics, subscribe_props {} ); - const std::string suback = encoders::encode_suback(1, rcs, suback_props {}); + const std::string suback = encoders::encode_suback(1, reason_codes, suback_props {}); const std::string unsubscribe = encoders::encode_unsubscribe( 1, unsub_topics, unsubscribe_props {} ); - const std::string unsuback = encoders::encode_unsuback(1, rcs, unsuback_props {}); + const std::string unsuback = encoders::encode_unsuback(1, reason_codes, unsuback_props {}); }; using test::after; @@ -140,6 +144,106 @@ BOOST_FIXTURE_TEST_CASE(fail_to_receive_suback, shared_test_data) { run_test(std::move(broker_side)); } +BOOST_FIXTURE_TEST_CASE(receive_malformed_suback, shared_test_data) { + // packets + const char malformed_bytes[] = { + -112, 7, 0, 1, 4, 31, 0, 2, 32 + }; + std::string malformed_suback { malformed_bytes, sizeof(malformed_bytes) / sizeof(char) }; + + auto disconnect = encoders::encode_disconnect( + reason_codes::malformed_packet.value(), + test::dprops_with_reason_string("Malformed SUBACK: cannot decode") + ); + + test::msg_exchange broker_side; + broker_side + .expect(connect) + .complete_with(success, after(1ms)) + .reply_with(connack, after(2ms)) + .expect(subscribe) + .complete_with(success, after(1ms)) + .reply_with(malformed_suback, after(2ms)) + .expect(disconnect) + .complete_with(success, after(1ms)) + .expect(connect) + .complete_with(success, after(1ms)) + .reply_with(connack, after(2ms)) + .expect(subscribe) + .complete_with(success, after(1ms)) + .reply_with(suback, after(2ms)); + + run_test(std::move(broker_side)); +} + +BOOST_FIXTURE_TEST_CASE(receive_invalid_rc_in_suback, shared_test_data) { + // packets + auto malformed_suback = encoders::encode_suback( + 1, { uint8_t(0x04) }, suback_props {} + ); + + auto disconnect = encoders::encode_disconnect( + reason_codes::malformed_packet.value(), + test::dprops_with_reason_string( + "Malformed SUBACK: does not contain a " + "valid Reason Code for every Topic Filter" + ) + ); + + test::msg_exchange broker_side; + broker_side + .expect(connect) + .complete_with(success, after(1ms)) + .reply_with(connack, after(2ms)) + .expect(subscribe) + .complete_with(success, after(1ms)) + .reply_with(malformed_suback, after(2ms)) + .expect(disconnect) + .complete_with(success, after(1ms)) + .expect(connect) + .complete_with(success, after(1ms)) + .reply_with(connack, after(2ms)) + .expect(subscribe) + .complete_with(success, after(1ms)) + .reply_with(suback, after(2ms)); + + run_test(std::move(broker_side)); +} + +BOOST_FIXTURE_TEST_CASE(mismatched_num_of_suback_rcs, shared_test_data) { + // packets + auto malformed_suback = encoders::encode_suback( + 1, { uint8_t(0x00), uint8_t(0x00) }, suback_props {} + ); + + auto disconnect = encoders::encode_disconnect( + reason_codes::malformed_packet.value(), + test::dprops_with_reason_string( + "Malformed SUBACK: does not contain a " + "valid Reason Code for every Topic Filter" + ) + ); + + test::msg_exchange broker_side; + broker_side + .expect(connect) + .complete_with(success, after(1ms)) + .reply_with(connack, after(2ms)) + .expect(subscribe) + .complete_with(success, after(1ms)) + .reply_with(malformed_suback, after(2ms)) + .expect(disconnect) + .complete_with(success, after(1ms)) + .expect(connect) + .complete_with(success, after(1ms)) + .reply_with(connack, after(2ms)) + .expect(subscribe) + .complete_with(success, after(1ms)) + .reply_with(suback, after(2ms)); + + run_test(std::move(broker_side)); +} + // unsubscribe BOOST_FIXTURE_TEST_CASE(fail_to_send_unsubscribe, shared_test_data) { @@ -179,4 +283,182 @@ BOOST_FIXTURE_TEST_CASE(fail_to_receive_unsuback, shared_test_data) { run_test(std::move(broker_side)); } +BOOST_FIXTURE_TEST_CASE(receive_malformed_unsuback, shared_test_data) { + // packets + const char malformed_bytes[] = { + -80, 7, 0, 1, 4, 31, 0, 2, 32 + }; + std::string malformed_unsuback { malformed_bytes, sizeof(malformed_bytes) / sizeof(char) }; + + auto disconnect = encoders::encode_disconnect( + reason_codes::malformed_packet.value(), + test::dprops_with_reason_string("Malformed UNSUBACK: cannot decode") + ); + + test::msg_exchange broker_side; + broker_side + .expect(connect) + .complete_with(success, after(1ms)) + .reply_with(connack, after(2ms)) + .expect(unsubscribe) + .complete_with(success, after(1ms)) + .reply_with(malformed_unsuback, after(2ms)) + .expect(disconnect) + .complete_with(success, after(1ms)) + .expect(connect) + .complete_with(success, after(1ms)) + .reply_with(connack, after(2ms)) + .expect(unsubscribe) + .complete_with(success, after(1ms)) + .reply_with(unsuback, after(2ms)); + + run_test(std::move(broker_side)); +} + +BOOST_FIXTURE_TEST_CASE(receive_invalid_rc_in_unsuback, shared_test_data) { + // packets + auto malformed_unsuback = encoders::encode_unsuback( + 1, { uint8_t(0x04) }, unsuback_props {} + ); + + auto disconnect = encoders::encode_disconnect( + reason_codes::malformed_packet.value(), + test::dprops_with_reason_string( + "Malformed UNSUBACK: does not contain a " + "valid Reason Code for every Topic Filter" + ) + ); + + test::msg_exchange broker_side; + broker_side + .expect(connect) + .complete_with(success, after(1ms)) + .reply_with(connack, after(2ms)) + .expect(unsubscribe) + .complete_with(success, after(1ms)) + .reply_with(malformed_unsuback, after(2ms)) + .expect(disconnect) + .complete_with(success, after(1ms)) + .expect(connect) + .complete_with(success, after(1ms)) + .reply_with(connack, after(2ms)) + .expect(unsubscribe) + .complete_with(success, after(1ms)) + .reply_with(unsuback, after(2ms)); + + run_test(std::move(broker_side)); +} + +BOOST_FIXTURE_TEST_CASE(mismatched_num_of_unsuback_rcs, shared_test_data) { + // packets + auto malformed_unsuback = encoders::encode_unsuback( + 1, { uint8_t(0x00), uint8_t(0x00)}, unsuback_props {} + ); + + auto disconnect = encoders::encode_disconnect( + reason_codes::malformed_packet.value(), + test::dprops_with_reason_string( + "Malformed UNSUBACK: does not contain a " + "valid Reason Code for every Topic Filter" + ) + ); + + test::msg_exchange broker_side; + broker_side + .expect(connect) + .complete_with(success, after(1ms)) + .reply_with(connack, after(2ms)) + .expect(unsubscribe) + .complete_with(success, after(1ms)) + .reply_with(malformed_unsuback, after(2ms)) + .expect(disconnect) + .complete_with(success, after(1ms)) + .expect(connect) + .complete_with(success, after(1ms)) + .reply_with(connack, after(2ms)) + .expect(unsubscribe) + .complete_with(success, after(1ms)) + .reply_with(unsuback, after(2ms)); + + run_test(std::move(broker_side)); +} + +template +void run_cancellation_test(test::msg_exchange broker_side) { + constexpr int expected_handlers_called = 1; + int handlers_called = 0; + + asio::io_context ioc; + auto executor = ioc.get_executor(); + auto& broker = asio::make_service( + ioc, executor, std::move(broker_side) + ); + + using client_type = mqtt_client; + client_type c(executor, ""); + c.brokers("127.0.0.1,127.0.0.1") // to avoid reconnect backoff + .async_run(asio::detached); + + asio::cancellation_signal cancel_signal; + auto data = shared_test_data(); + if constexpr (op_type == test::operation_type::subscribe) + c.async_subscribe( + data.sub_topics, subscribe_props {}, + asio::bind_cancellation_slot( + cancel_signal.slot(), + [&handlers_called, &c](error_code ec, std::vector rcs, suback_props) { + ++handlers_called; + + BOOST_CHECK(ec == asio::error::operation_aborted); + BOOST_ASSERT(rcs.size() == 1); + BOOST_CHECK_EQUAL(rcs[0], reason_codes::empty); + + c.cancel(); + } + ) + ); + else + c.async_unsubscribe( + data.unsub_topics, unsubscribe_props {}, + asio::bind_cancellation_slot( + cancel_signal.slot(), + [&handlers_called, &c](error_code ec, std::vector rcs, unsuback_props) { + ++handlers_called; + + BOOST_CHECK(ec == asio::error::operation_aborted); + BOOST_ASSERT(rcs.size() == 1); + BOOST_CHECK_EQUAL(rcs[0], reason_codes::empty); + + c.cancel(); + } + ) + ); + + cancel_signal.emit(asio::cancellation_type::total); + + ioc.run_for(2s); + BOOST_CHECK_EQUAL(handlers_called, expected_handlers_called); + BOOST_CHECK(broker.received_all_expected()); +} + +BOOST_FIXTURE_TEST_CASE(cancel_resending_subscribe, shared_test_data) { + test::msg_exchange broker_side; + broker_side + .expect(connect) + .complete_with(success, after(1ms)) + .reply_with(connack, after(2ms)); + + run_cancellation_test(std::move(broker_side)); +} + +BOOST_FIXTURE_TEST_CASE(cancel_resending_unsubscribe, shared_test_data) { + test::msg_exchange broker_side; + broker_side + .expect(connect) + .complete_with(success, after(1ms)) + .reply_with(connack, after(2ms)); + + run_cancellation_test(std::move(broker_side)); +} + BOOST_AUTO_TEST_SUITE_END(); diff --git a/test/unit/subscribe_op.cpp b/test/unit/subscribe_op.cpp index 9978143..4e058b1 100644 --- a/test/unit/subscribe_op.cpp +++ b/test/unit/subscribe_op.cpp @@ -37,7 +37,7 @@ BOOST_AUTO_TEST_CASE(pid_overrun) { } void run_test( - error_code expected_ec, const std::string& topic_filter, + error_code expected_ec, const std::vector& topics, const subscribe_props& sprops = {}, const connack_props& cprops = {} ) { constexpr int expected_handlers_called = 1; @@ -47,26 +47,38 @@ void run_test( using client_service_type = test::test_service; auto svc_ptr = std::make_shared(ioc.get_executor(), cprops); - auto handler = [&handlers_called, expected_ec] + auto handler = [&handlers_called, expected_ec, num_tp = topics.size()] (error_code ec, std::vector rcs, suback_props) { ++handlers_called; BOOST_CHECK(ec == expected_ec); - BOOST_ASSERT(rcs.size() == 1); - BOOST_CHECK_EQUAL(rcs[0], reason_codes::empty); + BOOST_ASSERT(rcs.size() == num_tp); + + for (size_t i = 0; i < rcs.size(); ++i) + BOOST_CHECK_EQUAL(rcs[i], reason_codes::empty); }; detail::subscribe_op< client_service_type, decltype(handler) > { svc_ptr, std::move(handler) } - .perform( - { { topic_filter, { qos_e::exactly_once } } }, sprops - ); + .perform(topics, sprops); ioc.run_for(std::chrono::milliseconds(500)); BOOST_CHECK_EQUAL(handlers_called, expected_handlers_called); } +void run_test( + error_code expected_ec, const std::string& topic, + const subscribe_props& sprops = {}, const connack_props& cprops = {} +) { + auto sub_topic = subscribe_topic(topic, subscribe_options()); + return run_test( + expected_ec, + std::vector { std::move(sub_topic) }, + sprops, cprops + ); +} + BOOST_AUTO_TEST_CASE(invalid_topic_filter_1) { run_test(client::error::invalid_topic, ""); } @@ -154,9 +166,7 @@ BOOST_AUTO_TEST_CASE(large_subscription_id) { subscribe_props sprops; sprops[prop::subscription_identifier] = std::numeric_limits::max(); - run_test( - client::error::malformed_packet, "topic", sprops, cprops - ); + run_test(client::error::malformed_packet, "topic", sprops, cprops); } BOOST_AUTO_TEST_CASE(packet_too_large) { @@ -168,4 +178,8 @@ BOOST_AUTO_TEST_CASE(packet_too_large) { ); } +BOOST_AUTO_TEST_CASE(zero_topic_filters) { + run_test(client::error::invalid_topic, std::vector {}); +} + BOOST_AUTO_TEST_SUITE_END() diff --git a/test/unit/unsubscribe_op.cpp b/test/unit/unsubscribe_op.cpp index 626c562..9135e5b 100644 --- a/test/unit/unsubscribe_op.cpp +++ b/test/unit/unsubscribe_op.cpp @@ -35,7 +35,7 @@ BOOST_AUTO_TEST_CASE(pid_overrun) { } void run_test( - error_code expected_ec, const std::string& topic_filter, + error_code expected_ec, const std::vector& topics, const unsubscribe_props& uprops = {}, const connack_props& cprops = {} ) { constexpr int expected_handlers_called = 1; @@ -45,24 +45,35 @@ void run_test( using client_service_type = test::test_service; auto svc_ptr = std::make_shared(ioc.get_executor(), cprops); - auto handler = [&handlers_called, expected_ec] + auto handler = [&handlers_called, expected_ec, num_tp = topics.size()] (error_code ec, std::vector rcs, unsuback_props) { ++handlers_called; BOOST_CHECK(ec == expected_ec); - BOOST_ASSERT(rcs.size() == 1); - BOOST_CHECK_EQUAL(rcs[0], reason_codes::empty); + BOOST_ASSERT(rcs.size() == num_tp); + + for (size_t i = 0; i < rcs.size(); ++i) + BOOST_CHECK_EQUAL(rcs[i], reason_codes::empty); }; detail::unsubscribe_op< client_service_type, decltype(handler) > { svc_ptr, std::move(handler) } - .perform({ topic_filter }, uprops); + .perform(topics, uprops); ioc.run_for(std::chrono::milliseconds(500)); BOOST_CHECK_EQUAL(handlers_called, expected_handlers_called); } +void run_test( + error_code expected_ec, const std::string& topic, + const unsubscribe_props& uprops = {}, const connack_props& cprops = {} +) { + return run_test( + expected_ec, std::vector { topic }, uprops, cprops + ); +} + BOOST_AUTO_TEST_CASE(invalid_topic_filter_1) { run_test(client::error::invalid_topic, ""); } @@ -111,4 +122,8 @@ BOOST_AUTO_TEST_CASE(packet_too_large) { ); } +BOOST_AUTO_TEST_CASE(zero_topic_filters) { + run_test(client::error::invalid_topic, std::vector {}); +} + BOOST_AUTO_TEST_SUITE_END()