diff --git a/include/async_mqtt5/impl/async_sender.hpp b/include/async_mqtt5/impl/async_sender.hpp index 4fcf760..4ebc26a 100644 --- a/include/async_mqtt5/impl/async_sender.hpp +++ b/include/async_mqtt5/impl/async_sender.hpp @@ -40,7 +40,7 @@ public: } asio::const_buffer buffer() const { - return _buffer; + return _buffer; } void complete(error_code ec) { @@ -54,6 +54,10 @@ public: ); } + bool empty() const { + return !_handler; + } + auto get_executor() { return asio::get_associated_executor(_handler); } @@ -235,29 +239,24 @@ private: write_queue = std::move(_write_queue); } else { - auto throttled_ptr = std::stable_partition( - _write_queue.begin(), _write_queue.end(), - [](const auto& op) { return !op.throttled(); } - ); - uint16_t dist = static_cast( - std::distance(throttled_ptr, _write_queue.end()) - ); - uint16_t throttled_num = std::min(dist, _quota); - _quota -= throttled_num; - throttled_ptr += throttled_num; + for (write_req& req : _write_queue) + if (!req.throttled()) + write_queue.push_back(std::move(req)); + else if (_quota > 0) { + --_quota; + write_queue.push_back(std::move(req)); + } - if (throttled_ptr == _write_queue.begin()) { + if (write_queue.empty()) { _write_in_progress = false; return; } - write_queue.insert( - write_queue.end(), - std::make_move_iterator(_write_queue.begin()), - std::make_move_iterator(throttled_ptr) + auto it = std::remove_if( + _write_queue.begin(), _write_queue.end(), + [](const write_req& req) { return req.empty(); } ); - - _write_queue.erase(_write_queue.begin(), throttled_ptr); + _write_queue.erase(it, _write_queue.end()); } std::vector buffers; diff --git a/test/integration/async_sender.cpp b/test/integration/async_sender.cpp index 7fdbce9..d25447d 100644 --- a/test/integration/async_sender.cpp +++ b/test/integration/async_sender.cpp @@ -40,7 +40,7 @@ struct shared_test_data { using test::after; using namespace std::chrono; -BOOST_FIXTURE_TEST_CASE(ordering_after_reconnect, shared_test_data) { +BOOST_FIXTURE_TEST_CASE(publish_ordering_after_reconnect, shared_test_data) { constexpr int expected_handlers_called = 2; int handlers_called = 0; @@ -112,6 +112,85 @@ BOOST_FIXTURE_TEST_CASE(ordering_after_reconnect, shared_test_data) { BOOST_TEST(broker.received_all_expected()); } +BOOST_FIXTURE_TEST_CASE(sub_unsub_ordering_after_reconnect, shared_test_data) { + constexpr int expected_handlers_called = 2; + int handlers_called = 0; + + // data + std::vector sub_topics = { + subscribe_topic { "topic", subscribe_options {} } + }; + std::vector sub_reason_codes = { + reason_codes::granted_qos_2.value() + }; + std::vector unsub_topics = { "topic" }; + std::vector unsub_reason_codes = { reason_codes::success.value() }; + + // packets + auto subscribe = encoders::encode_subscribe( + 1, sub_topics, subscribe_props {} + ); + auto suback = encoders::encode_suback(1, sub_reason_codes, suback_props {}); + auto unsubscribe = encoders::encode_unsubscribe( + 2, unsub_topics, unsubscribe_props {} + ); + auto unsuback = encoders::encode_unsuback(2, unsub_reason_codes, unsuback_props {}); + auto disconnect = encoders::encode_disconnect(0x00, {}); + + test::msg_exchange broker_side; + broker_side + .expect(connect) + .complete_with(success, after(1ms)) + .reply_with(connack, after(2ms)) + .expect(subscribe, unsubscribe) + .complete_with(success, after(1ms)) + .send(disconnect, after(5ms)) + .expect(connect) + .complete_with(success, after(1ms)) + .reply_with(connack, after(2ms)) + .expect(subscribe, unsubscribe) + .complete_with(success, after(1ms)) + .reply_with(suback, unsuback, after(2ms)); + + asio::io_context ioc; + auto executor = ioc.get_executor(); + auto& broker = asio::make_service( + ioc, executor, std::move(broker_side) + ); + + using client_type = mqtt_client; + client_type c(executor); + c.brokers("127.0.0.1,127.0.0.1") // to avoid reconnect backoff + .async_run(asio::detached); + + c.async_subscribe( + sub_topics, subscribe_props {}, + [&](error_code ec, std::vector rcs, suback_props) { + ++handlers_called; + + BOOST_TEST(!ec); + BOOST_TEST_REQUIRE(rcs.size() == 1u); + BOOST_TEST(rcs[0] == reason_codes::granted_qos_2); + } + ); + c.async_unsubscribe( + unsub_topics, unsubscribe_props {}, + [&](error_code ec, std::vector rcs, unsuback_props) { + ++handlers_called; + + BOOST_TEST(!ec); + BOOST_TEST_REQUIRE(rcs.size() == 1u); + BOOST_TEST(rcs[0] == reason_codes::success); + + c.cancel(); + } + ); + + ioc.run_for(1s); + BOOST_TEST(handlers_called == expected_handlers_called); + BOOST_TEST(broker.received_all_expected()); +} + BOOST_FIXTURE_TEST_CASE(throttling, shared_test_data) { constexpr int expected_handlers_called = 3; int handlers_called = 0; @@ -184,6 +263,70 @@ BOOST_FIXTURE_TEST_CASE(throttling, shared_test_data) { BOOST_TEST(broker.received_all_expected()); } +BOOST_FIXTURE_TEST_CASE(throttling_ordering, shared_test_data) { + constexpr int expected_handlers_called = 2; + int handlers_called = 0; + + // packets + connack_props props; + props[prop::receive_maximum] = 2; + const std::string connack = encoders::encode_connack( + false, reason_codes::success.value(), props + ); + auto publish_qos0 = encoders::encode_publish( + 0, topic, payload, qos_e::at_most_once, retain_e::no, dup_e::no, {} + ); + + test::msg_exchange broker_side; + broker_side + .expect(connect) + .complete_with(success, after(1ms)) + .reply_with(connack, after(2ms)) + .expect(publish_qos1, publish_qos0) + .complete_with(success, after(1ms)) + .reply_with(puback, after(2ms)); + + asio::io_context ioc; + auto executor = ioc.get_executor(); + auto& broker = asio::make_service( + ioc, executor, std::move(broker_side) + ); + + using client_type = mqtt_client; + client_type c(executor); + c.brokers("127.0.0.1,127.0.0.1") // to avoid reconnect backoff + .async_run(asio::detached); + + c.async_publish( + topic, payload, retain_e::no, publish_props {}, + [&](error_code ec, reason_code rc, puback_props) { + ++handlers_called; + + BOOST_TEST(!ec); + BOOST_TEST(rc == reason_codes::success); + + if (handlers_called == expected_handlers_called) + c.cancel(); + } + ); + + c.async_publish( + topic, payload, retain_e::no, publish_props{}, + [&](error_code ec) { + ++handlers_called; + + BOOST_TEST(!ec); + + if (handlers_called == expected_handlers_called) + c.cancel(); + } + ); + + ioc.run_for(1s); + BOOST_TEST(handlers_called == expected_handlers_called); + BOOST_TEST(broker.received_all_expected()); +} + BOOST_FIXTURE_TEST_CASE(prioritize_disconnect, shared_test_data) { constexpr int expected_handlers_called = 3; int handlers_called = 0;