From 64a003e2ed6a01ffb6005b00dd3814b65ce71665 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Korina=20=C5=A0imi=C4=8Devi=C4=87?= Date: Fri, 1 Dec 2023 15:46:53 +0100 Subject: [PATCH] Do not store ec=session_expired if the client never subscribed Summary: related to T13152 Reviewers: ivica Reviewed By: ivica Subscribers: miljen, iljazovic Differential Revision: https://repo.mireo.local/D26803 --- example/openssl-tls.cpp | 2 +- example/src/run_examples.cpp | 6 +-- example/tcp.cpp | 2 +- example/websocket-tcp.cpp | 2 +- example/websocket-tls.cpp | 2 +- include/async_mqtt5/detail/internal_types.hpp | 13 +++++- include/async_mqtt5/impl/client_service.hpp | 35 ++++++++++++-- include/async_mqtt5/impl/subscribe_op.hpp | 11 +++++ test/unit/test/session.cpp | 46 +++---------------- 9 files changed, 69 insertions(+), 50 deletions(-) diff --git a/example/openssl-tls.cpp b/example/openssl-tls.cpp index 57036c8..880cc78 100644 --- a/example/openssl-tls.cpp +++ b/example/openssl-tls.cpp @@ -280,6 +280,6 @@ void run_openssl_tls_examples() { publish_qos0_openssl_tls(); publish_qos1_openssl_tls(); publish_qos2_openssl_tls(); - subscribe_and_receive_openssl_tls(2); + subscribe_and_receive_openssl_tls(1); test_coro(); } diff --git a/example/src/run_examples.cpp b/example/src/run_examples.cpp index c600354..f981058 100644 --- a/example/src/run_examples.cpp +++ b/example/src/run_examples.cpp @@ -7,9 +7,9 @@ void run_websocket_tls_examples(); int main(int argc, char* argv[]) { run_tcp_examples(); - // run_openssl_tls_examples(); - // run_websocket_tcp_examples(); - // run_websocket_tls_examples(); + run_openssl_tls_examples(); + run_websocket_tcp_examples(); + run_websocket_tls_examples(); return 0; } diff --git a/example/tcp.cpp b/example/tcp.cpp index b355bfd..9fbc64a 100644 --- a/example/tcp.cpp +++ b/example/tcp.cpp @@ -142,5 +142,5 @@ void run_tcp_examples() { publish_qos0_tcp(); publish_qos1_tcp(); publish_qos2_tcp(); - subscribe_and_receive_tcp(2); + subscribe_and_receive_tcp(1); } diff --git a/example/websocket-tcp.cpp b/example/websocket-tcp.cpp index e294d4a..8007e23 100644 --- a/example/websocket-tcp.cpp +++ b/example/websocket-tcp.cpp @@ -168,5 +168,5 @@ void run_websocket_tcp_examples() { publish_qos0_websocket_tcp(); publish_qos1_websocket_tcp(); publish_qos2_websocket_tcp(); - subscribe_and_receive_websocket_tcp(2); + subscribe_and_receive_websocket_tcp(1); } diff --git a/example/websocket-tls.cpp b/example/websocket-tls.cpp index 0e3bcdc..1dda51f 100644 --- a/example/websocket-tls.cpp +++ b/example/websocket-tls.cpp @@ -245,5 +245,5 @@ void run_websocket_tls_examples() { publish_qos0_websocket_tls(); publish_qos1_websocket_tls(); publish_qos2_websocket_tls(); - subscribe_and_receive_websocket_tls(2); + subscribe_and_receive_websocket_tls(1); } diff --git a/include/async_mqtt5/detail/internal_types.hpp b/include/async_mqtt5/detail/internal_types.hpp index f6127b6..799c677 100644 --- a/include/async_mqtt5/detail/internal_types.hpp +++ b/include/async_mqtt5/detail/internal_types.hpp @@ -39,12 +39,23 @@ class session_state { uint8_t _flags = 0b00; static constexpr uint8_t session_present_flag = 0b01; + static constexpr uint8_t subscriptions_present_flag = 0b10; public: void session_present(bool present) { return update_flag(present, session_present_flag); } - bool session_present() const { return _flags & session_present_flag; }; + bool session_present() const { + return _flags & session_present_flag; + } + + void subscriptions_present(bool present) { + return update_flag(present, subscriptions_present_flag); + } + + bool subscriptions_present() const { + return _flags & subscriptions_present_flag; + } private: void update_flag(bool set, uint8_t flag) { diff --git a/include/async_mqtt5/impl/client_service.hpp b/include/async_mqtt5/impl/client_service.hpp index 818cf38..20a70a7 100644 --- a/include/async_mqtt5/impl/client_service.hpp +++ b/include/async_mqtt5/impl/client_service.hpp @@ -3,8 +3,8 @@ #include -#include #include +#include #include #include @@ -40,6 +40,14 @@ public: return _tls_context; } + auto& session_state() { + return _mqtt_context.session_state; + } + + const auto& session_state() const { + return _mqtt_context.session_state; + } + void will(will will) { _mqtt_context.will = std::move(will); } @@ -78,6 +86,14 @@ public: return _mqtt_context; } + auto& session_state() { + return _mqtt_context.session_state; + } + + const auto& session_state() const { + return _mqtt_context.session_state; + } + void will(will will) { _mqtt_context.will = std::move(will); } @@ -289,12 +305,25 @@ public: ); } + bool subscriptions_present() const { + return _stream_context.session_state().subscriptions_present(); + } + + void subscriptions_present(bool present) { + _stream_context.session_state().subscriptions_present(present); + } + void update_session_state() { - auto& session_state = _stream_context.mqtt_context().session_state; + auto& session_state = _stream_context.session_state(); + if (!session_state.session_present()) { - channel_store_error(client::error::session_expired); _replies.clear_pending_pubrels(); session_state.session_present(true); + + if (session_state.subscriptions_present()) { + channel_store_error(client::error::session_expired); + session_state.subscriptions_present(false); + } } } diff --git a/include/async_mqtt5/impl/subscribe_op.hpp b/include/async_mqtt5/impl/subscribe_op.hpp index 3080bb3..653587a 100644 --- a/include/async_mqtt5/impl/subscribe_op.hpp +++ b/include/async_mqtt5/impl/subscribe_op.hpp @@ -1,6 +1,8 @@ #ifndef ASYNC_MQTT5_SUBSCRIBE_OP_HPP #define ASYNC_MQTT5_SUBSCRIBE_OP_HPP +#include + #include #include @@ -173,6 +175,15 @@ private: error_code ec, uint16_t packet_id, std::vector reason_codes, suback_props props ) { + if (!_svc_ptr->subscriptions_present()) { + bool has_success_rc = std::any_of( + reason_codes.cbegin(), reason_codes.cend(), + [](const reason_code& rc) { return !rc; } + ); + if (has_success_rc) + _svc_ptr->subscriptions_present(true); + } + _svc_ptr->free_pid(packet_id); _handler.complete(ec, std::move(reason_codes), std::move(props)); } diff --git a/test/unit/test/session.cpp b/test/unit/test/session.cpp index 7c8fa5e..89e5cdc 100644 --- a/test/unit/test/session.cpp +++ b/test/unit/test/session.cpp @@ -15,43 +15,21 @@ BOOST_AUTO_TEST_SUITE(session/*, *boost::unit_test::disabled()*/) BOOST_AUTO_TEST_CASE(session_state_session_present) { detail::session_state session_state; + BOOST_CHECK_EQUAL(session_state.session_present(), false); session_state.session_present(true); BOOST_CHECK_EQUAL(session_state.session_present(), true); session_state.session_present(false); BOOST_CHECK_EQUAL(session_state.session_present(), false); -} -BOOST_AUTO_TEST_CASE(session_expired_in_channel) { - asio::io_context ioc; - - using stream_type = asio::ip::tcp::socket; - using client_type = mqtt_client; - client_type c(ioc, ""); - - c.credentials("tester", "", "") - .brokers("mqtt.mireo.local", 1883) - .run(); - - co_spawn(ioc, - [&]() -> asio::awaitable { - auto [ec, topic, payload, props] = co_await c.async_receive(use_nothrow_awaitable); - BOOST_CHECK(ec == client::error::session_expired); - BOOST_CHECK_EQUAL(topic, std::string {}); - BOOST_CHECK_EQUAL(payload, std::string {}); - c.cancel(); - co_return; - }, - asio::detached - ); - - ioc.run(); + BOOST_CHECK_EQUAL(session_state.subscriptions_present(), false); + session_state.subscriptions_present(true); + BOOST_CHECK_EQUAL(session_state.subscriptions_present(), true); + session_state.subscriptions_present(false); + BOOST_CHECK_EQUAL(session_state.subscriptions_present(), false); } BOOST_AUTO_TEST_CASE(clear_waiting_on_pubrel) { - constexpr int expected_handlers_called = 1; - int handlers_called = 0; - asio::io_context ioc; using client_service_type = test::test_service; auto svc_ptr = std::make_shared(ioc.get_executor()); @@ -65,24 +43,14 @@ BOOST_AUTO_TEST_CASE(clear_waiting_on_pubrel) { // let publish_rec_op reach wait_on_pubrel stage asio::steady_timer timer(ioc.get_executor()); timer.expires_after(std::chrono::milliseconds(50)); - timer.async_wait([&svc_ptr, &handlers_called](error_code) { + timer.async_wait([&svc_ptr](error_code) { BOOST_CHECK_EQUAL(svc_ptr.use_count(), 2); svc_ptr->update_session_state(); // session_present = false // publish_rec_op should complete BOOST_CHECK_EQUAL(svc_ptr.use_count(), 1); - - svc_ptr->async_channel_receive( - [&svc_ptr, &handlers_called](error_code ec, std::string topic, std::string payload, publish_props props) { - handlers_called++; - BOOST_CHECK(ec == client::error::session_expired); - BOOST_CHECK_EQUAL(topic, std::string {}); - BOOST_CHECK_EQUAL(payload, std::string {}); - svc_ptr->cancel(); - }); }); ioc.run(); - BOOST_CHECK_EQUAL(handlers_called, expected_handlers_called); }