#ifndef ASYNC_MQTT5_PUBLISH_SEND_OP_HPP #define ASYNC_MQTT5_PUBLISH_SEND_OP_HPP #include #include #include #include #include #include #include #include #include #include #include #include namespace async_mqtt5::detail { namespace asio = boost::asio; template using on_publish_signature = std::conditional_t< qos_type == qos_e::at_most_once, void (error_code), std::conditional_t< qos_type == qos_e::at_least_once, void (error_code, reason_code, puback_props), void (error_code, reason_code, pubcomp_props) > >; template using on_publish_props_type = std::conditional_t< qos_type == qos_e::at_most_once, void, std::conditional_t< qos_type == qos_e::at_least_once, puback_props, pubcomp_props > >; template class publish_send_op { using client_service = ClientService; struct on_publish {}; struct on_puback {}; struct on_pubrec {}; struct on_pubrel {}; struct on_pubcomp {}; std::shared_ptr _svc_ptr; using handler_type = cancellable_handler< Handler, typename client_service::executor_type >; handler_type _handler; serial_num_t _serial_num; public: publish_send_op( const std::shared_ptr& svc_ptr, Handler&& handler ) : _svc_ptr(svc_ptr), _handler(std::move(handler), _svc_ptr->get_executor()) { auto slot = asio::get_associated_cancellation_slot(_handler); if (slot.is_connected()) slot.assign([&svc = *_svc_ptr](asio::cancellation_type_t) { svc.cancel(); }); } publish_send_op(publish_send_op&&) noexcept = default; publish_send_op(const publish_send_op&) = delete; using executor_type = asio::associated_executor_t; executor_type get_executor() const noexcept { return asio::get_associated_executor(_handler); } using allocator_type = asio::associated_allocator_t; allocator_type get_allocator() const noexcept { return asio::get_associated_allocator(_handler); } void perform( std::string topic, std::string payload, retain_e retain, const publish_props& props ) { uint16_t packet_id = 0; if constexpr (qos_type != qos_e::at_most_once) { packet_id = _svc_ptr->allocate_pid(); if (packet_id == 0) return complete_post(client::error::pid_overrun, packet_id); } auto ec = validate_publish(topic, payload, retain, props); if (ec) return complete_post(ec, packet_id); _serial_num = _svc_ptr->next_serial_num(); auto publish = control_packet::of( with_pid, get_allocator(), encoders::encode_publish, packet_id, std::move(topic), std::move(payload), qos_type, retain, dup_e::no, props ); auto max_packet_size = static_cast( _svc_ptr->connack_property(prop::maximum_packet_size) .value_or(default_max_send_size) ); if (publish.size() > max_packet_size) return complete_post(client::error::packet_too_large, packet_id); send_publish(std::move(publish)); } void send_publish(control_packet publish) { auto wire_data = publish.wire_data(); _svc_ptr->async_send( wire_data, _serial_num, send_flag::throttled * (qos_type != qos_e::at_most_once), asio::prepend(std::move(*this), on_publish {}, std::move(publish)) ); } void resend_publish(control_packet publish) { if (_handler.cancelled() != asio::cancellation_type_t::none) return complete( asio::error::operation_aborted, publish.packet_id() ); send_publish(std::move(publish)); } void operator()( on_publish, control_packet publish, error_code ec ) { if (ec == asio::error::try_again) return resend_publish(std::move(publish)); if constexpr (qos_type == qos_e::at_most_once) return complete(ec); else { auto packet_id = publish.packet_id(); if (ec) return complete(ec, packet_id); if constexpr (qos_type == qos_e::at_least_once) _svc_ptr->async_wait_reply( control_code_e::puback, packet_id, asio::prepend( std::move(*this), on_puback {}, std::move(publish) ) ); else if constexpr (qos_type == qos_e::exactly_once) _svc_ptr->async_wait_reply( control_code_e::pubrec, packet_id, asio::prepend( std::move(*this), on_pubrec {}, std::move(publish) ) ); } } template < qos_e q = qos_type, std::enable_if_t = true > void operator()( on_puback, control_packet publish, error_code ec, byte_citer first, byte_citer last ) { if (ec == asio::error::try_again) // "resend unanswered" return resend_publish(std::move(publish.set_dup())); uint16_t packet_id = publish.packet_id(); if (ec) return complete(ec, packet_id); auto puback = decoders::decode_puback( static_cast(std::distance(first, last)), first ); if (!puback.has_value()) { on_malformed_packet("Malformed PUBACK: cannot decode"); return resend_publish(std::move(publish.set_dup())); } auto& [reason_code, props] = *puback; auto rc = to_reason_code(reason_code); if (!rc) { on_malformed_packet("Malformed PUBACK: invalid Reason Code"); return resend_publish(std::move(publish.set_dup())); } complete(ec, packet_id, *rc, std::move(props)); } template < qos_e q = qos_type, std::enable_if_t = true > void operator()( on_pubrec, control_packet publish, error_code ec, byte_citer first, byte_citer last ) { if (ec == asio::error::try_again) // "resend unanswered" return resend_publish(std::move(publish.set_dup())); uint16_t packet_id = publish.packet_id(); if (ec) return complete(ec, packet_id); auto pubrec = decoders::decode_pubrec( static_cast(std::distance(first, last)), first ); if (!pubrec.has_value()) { on_malformed_packet("Malformed PUBREC: cannot decode"); return resend_publish(std::move(publish.set_dup())); } auto& [reason_code, props] = *pubrec; auto rc = to_reason_code(reason_code); if (!rc) { on_malformed_packet("Malformed PUBREC: invalid Reason Code"); return resend_publish(std::move(publish.set_dup())); } if (*rc) return complete(ec, packet_id, *rc); auto pubrel = control_packet::of( with_pid, get_allocator(), encoders::encode_pubrel, packet_id, 0, pubrel_props {} ); send_pubrel(std::move(pubrel), false); } void send_pubrel(control_packet pubrel, bool throttled) { auto wire_data = pubrel.wire_data(); _svc_ptr->async_send( wire_data, _serial_num, (send_flag::throttled * throttled) | send_flag::prioritized, asio::prepend(std::move(*this), on_pubrel {}, std::move(pubrel)) ); } template < qos_e q = qos_type, std::enable_if_t = true > void operator()( on_pubrel, control_packet pubrel, error_code ec ) { if (ec == asio::error::try_again) return send_pubrel(std::move(pubrel), true); uint16_t packet_id = pubrel.packet_id(); if (ec) return complete(ec, packet_id); _svc_ptr->async_wait_reply( control_code_e::pubcomp, packet_id, asio::prepend(std::move(*this), on_pubcomp {}, std::move(pubrel)) ); } template < qos_e q = qos_type, std::enable_if_t = true > void operator()( on_pubcomp, control_packet pubrel, error_code ec, byte_citer first, byte_citer last ) { if (ec == asio::error::try_again) // "resend unanswered" return send_pubrel(std::move(pubrel), true); uint16_t packet_id = pubrel.packet_id(); if (ec) return complete(ec, packet_id); auto pubcomp = decoders::decode_pubcomp( static_cast(std::distance(first, last)), first ); if (!pubcomp.has_value()) { on_malformed_packet("Malformed PUBCOMP: cannot decode"); return send_pubrel(std::move(pubrel), true); } auto& [reason_code, props] = *pubcomp; auto rc = to_reason_code(reason_code); if (!rc) { on_malformed_packet("Malformed PUBCOMP: invalid Reason Code"); return send_pubrel(std::move(pubrel), true); } return complete(ec, pubrel.packet_id(), *rc); } private: error_code validate_publish( const std::string& topic, const std::string& payload, retain_e retain, const publish_props& props ) const { constexpr uint8_t default_retain_available = 1; constexpr uint8_t default_maximum_qos = 2; constexpr uint8_t default_payload_format_ind = 0; if (validate_topic_name(topic) != validation_result::valid) return client::error::invalid_topic; auto max_qos = _svc_ptr->connack_property(prop::maximum_qos) .value_or(default_maximum_qos); auto retain_available = _svc_ptr->connack_property(prop::retain_available) .value_or(default_retain_available); if (uint8_t(qos_type) > max_qos) return client::error::qos_not_supported; if (retain_available == 0 && retain == retain_e::yes) return client::error::retain_not_available; auto payload_format_ind = props[prop::payload_format_indicator] .value_or(default_payload_format_ind); if ( payload_format_ind == 1 && validate_mqtt_utf8(payload) != validation_result::valid ) return client::error::malformed_packet; return validate_props(props); } error_code validate_props(const publish_props& props) const { constexpr uint16_t default_topic_alias_max = 0; auto topic_alias = props[prop::topic_alias]; if (topic_alias) { auto topic_alias_max = _svc_ptr->connack_property(prop::topic_alias_maximum) .value_or(default_topic_alias_max); if (topic_alias_max == 0 || *topic_alias > topic_alias_max) return client::error::topic_alias_maximum_reached; if (*topic_alias == 0 ) return client::error::malformed_packet; } auto response_topic = props[prop::response_topic]; if ( response_topic && validate_topic_name(*response_topic) != validation_result::valid ) return client::error::malformed_packet; auto user_properties = props[prop::user_property]; for (const auto& user_prop: user_properties) if (validate_mqtt_utf8(user_prop) != validation_result::valid) return client::error::malformed_packet; auto subscription_identifier = props[prop::subscription_identifier]; if ( subscription_identifier && (*subscription_identifier < min_subscription_identifier || *subscription_identifier > max_subscription_identifier) ) return client::error::malformed_packet; auto content_type = props[prop::content_type]; if ( content_type && validate_mqtt_utf8(*content_type) != validation_result::valid ) return client::error::malformed_packet; return error_code {}; } void on_malformed_packet(const std::string& reason) { auto props = disconnect_props {}; props[prop::reason_string] = reason; async_disconnect( disconnect_rc_e::malformed_packet, props, false, _svc_ptr, asio::detached ); } template < qos_e q = qos_type, std::enable_if_t = true > void complete(error_code ec, uint16_t = 0) { _handler.complete(ec); } template < qos_e q = qos_type, std::enable_if_t = true > void complete_post(error_code ec, uint16_t) { _handler.complete_post(ec); } template < typename Props = on_publish_props_type, std::enable_if_t< std::is_same_v || std::is_same_v, bool > = true > void complete( error_code ec, uint16_t packet_id, reason_code rc = reason_codes::empty, Props&& props = Props {} ) { _svc_ptr->free_pid(packet_id, true); _handler.complete(ec, rc, std::forward(props)); } template < typename Props = on_publish_props_type, std::enable_if_t< std::is_same_v || std::is_same_v, bool > = true > void complete_post(error_code ec, uint16_t packet_id) { if (packet_id != 0) _svc_ptr->free_pid(packet_id, false); _handler.complete_post(ec, reason_codes::empty, Props {}); } }; } // end namespace async_mqtt5::detail #endif // !ASYNC_MQTT5_PUBLISH_SEND_OP_HPP