mirror of
https://github.com/boostorg/mqtt5.git
synced 2025-10-04 04:40:55 +02:00
[mqtt-client] add support for enhanced authentication
Summary: - Relates to T12899 - TODO: support re-authentication Reviewers: ivica Reviewed By: ivica Subscribers: korina Maniphest Tasks: T12899 Differential Revision: https://repo.mireo.local/D26327
This commit is contained in:
@@ -95,6 +95,9 @@ public:
|
||||
};
|
||||
|
||||
if (cc(error_code {}, 0) == 0 && _data_span.size()) {
|
||||
/* TODO clear read buffer on reconnect
|
||||
* OR use dispatch instead of post here
|
||||
*/
|
||||
return asio::post(
|
||||
asio::prepend(
|
||||
std::move(*this), on_read {}, error_code {},
|
||||
@@ -128,16 +131,16 @@ public:
|
||||
}
|
||||
|
||||
if (ec)
|
||||
return complete(ec, 0, 0, {}, {});
|
||||
return complete(ec, 0, {}, {});
|
||||
|
||||
_data_span.expand_suffix(bytes_read);
|
||||
assert(_data_span.size());
|
||||
|
||||
auto control_code = uint8_t(*_data_span.first());
|
||||
auto control_byte = uint8_t(*_data_span.first());
|
||||
|
||||
if ((control_code & 0b11110000) == 0)
|
||||
if ((control_byte & 0b11110000) == 0)
|
||||
// close the connection, cancel
|
||||
return complete(client::error::malformed_packet, 0, 0, {}, {});
|
||||
return complete(client::error::malformed_packet, 0, {}, {});
|
||||
|
||||
auto first = _data_span.first() + 1;
|
||||
auto varlen = decoders::type_parse(
|
||||
@@ -147,12 +150,12 @@ public:
|
||||
if (!varlen) {
|
||||
if (_data_span.size() < 5)
|
||||
return perform(wait_for, asio::transfer_at_least(1));
|
||||
return complete(client::error::malformed_packet, 0, 0, {}, {});
|
||||
return complete(client::error::malformed_packet, 0, {}, {});
|
||||
}
|
||||
|
||||
// TODO: respect max packet size which could be dinamically set by the broker
|
||||
if (*varlen > max_packet_size - std::distance(_data_span.first(), first))
|
||||
return complete(client::error::malformed_packet, 0, 0, {}, {});
|
||||
return complete(client::error::malformed_packet, 0, {}, {});
|
||||
|
||||
if (std::distance(first, _data_span.last()) < *varlen)
|
||||
return perform(wait_for, asio::transfer_at_least(1));
|
||||
@@ -161,7 +164,7 @@ public:
|
||||
std::distance(_data_span.first(), first) + *varlen
|
||||
);
|
||||
|
||||
dispatch(wait_for, control_code, first, first + *varlen);
|
||||
dispatch(wait_for, control_byte, first, first + *varlen);
|
||||
}
|
||||
|
||||
private:
|
||||
@@ -179,51 +182,39 @@ private:
|
||||
return res == 0b00000000;
|
||||
}
|
||||
|
||||
static bool contains_packet_id(control_code_e code) {
|
||||
using enum control_code_e;
|
||||
|
||||
return code == puback || code == pubrec
|
||||
|| code == pubrel || code == pubcomp
|
||||
|| code == subscribe || code == suback
|
||||
|| code == unsubscribe || code == unsuback;
|
||||
}
|
||||
|
||||
void dispatch(
|
||||
duration wait_for,
|
||||
uint8_t control_code, byte_citer first, byte_citer last
|
||||
uint8_t control_byte, byte_citer first, byte_citer last
|
||||
) {
|
||||
using namespace decoders;
|
||||
using enum control_code_e;
|
||||
|
||||
if (!valid_header(control_code))
|
||||
return complete(client::error::malformed_packet, 0, 0, {}, {});
|
||||
if (!valid_header(control_byte))
|
||||
return complete(client::error::malformed_packet, 0, {}, {});
|
||||
|
||||
auto code = control_code_e(control_code & 0b11110000);
|
||||
auto code = control_code_e(control_byte & 0b11110000);
|
||||
|
||||
if (code == pingresp)
|
||||
return perform(wait_for, asio::transfer_at_least(0));
|
||||
|
||||
uint16_t packet_id = 0;
|
||||
if (contains_packet_id(code))
|
||||
packet_id = decoders::decode_packet_id(first).value();
|
||||
|
||||
bool is_reply = code != publish && code != auth && code != disconnect;
|
||||
if (is_reply) {
|
||||
auto packet_id = decoders::decode_packet_id(first).value();
|
||||
_svc._replies.dispatch(error_code {}, code, packet_id, first, last);
|
||||
return perform(wait_for, asio::transfer_at_least(0));
|
||||
}
|
||||
|
||||
complete(error_code {}, packet_id, control_code, first, last);
|
||||
complete(error_code {}, control_byte, first, last);
|
||||
}
|
||||
|
||||
void complete(
|
||||
error_code ec, uint16_t packet_id, uint8_t control_code,
|
||||
error_code ec, uint8_t control_code,
|
||||
byte_citer first, byte_citer last
|
||||
) {
|
||||
asio::dispatch(
|
||||
get_executor(),
|
||||
asio::prepend(
|
||||
std::move(_handler), ec, packet_id, control_code,
|
||||
std::move(_handler), ec, control_code,
|
||||
first, last
|
||||
)
|
||||
);
|
||||
|
@@ -30,7 +30,7 @@ public:
|
||||
_handler(std::move(handler)) {}
|
||||
|
||||
static serial_num_t next_serial_num(serial_num_t last) {
|
||||
return ++last;
|
||||
return last + 1;
|
||||
}
|
||||
|
||||
asio::const_buffer buffer() const { return _buffer; }
|
||||
|
@@ -57,6 +57,13 @@ public:
|
||||
std::move(username), std::move(password)
|
||||
};
|
||||
}
|
||||
|
||||
template <typename Authenticator>
|
||||
void authenticator(Authenticator&& authenticator) {
|
||||
_mqtt_context.authenticator = any_authenticator(
|
||||
std::forward<Authenticator>(authenticator)
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename StreamType>
|
||||
@@ -88,6 +95,13 @@ public:
|
||||
std::move(username), std::move(password)
|
||||
};
|
||||
}
|
||||
|
||||
template <typename Authenticator>
|
||||
void authenticator(Authenticator&& authenticator) {
|
||||
_mqtt_context.authenticator = any_authenticator(
|
||||
std::forward<Authenticator>(authenticator)
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
template <
|
||||
@@ -177,6 +191,13 @@ public:
|
||||
_stream.brokers(std::move(hosts), default_port);
|
||||
}
|
||||
|
||||
template <typename Authenticator>
|
||||
void authenticator(Authenticator&& authenticator) {
|
||||
_stream_context.authenticator(
|
||||
std::forward<Authenticator>(authenticator)
|
||||
);
|
||||
}
|
||||
|
||||
template <typename Prop>
|
||||
auto connack_prop(Prop p) {
|
||||
return _stream_context.connack_prop(p);
|
||||
@@ -243,10 +264,10 @@ public:
|
||||
}.perform(wait_for, asio::transfer_at_least(0));
|
||||
};
|
||||
|
||||
using signature = void (
|
||||
error_code, uint16_t, uint8_t, byte_citer, byte_citer
|
||||
using Signature = void (
|
||||
error_code, uint8_t, byte_citer, byte_citer
|
||||
);
|
||||
return asio::async_initiate<CompletionToken, signature> (
|
||||
return asio::async_initiate<CompletionToken, Signature> (
|
||||
std::move(initiation), token, wait_for
|
||||
);
|
||||
}
|
||||
|
@@ -28,12 +28,18 @@ template <
|
||||
typename Stream, typename Handler
|
||||
>
|
||||
class connect_op {
|
||||
static constexpr size_t min_packet_sz = 5;
|
||||
|
||||
struct on_connect {};
|
||||
struct on_tls_handshake {};
|
||||
struct on_ws_handshake {};
|
||||
struct on_send_connect {};
|
||||
struct on_fixed_header {};
|
||||
struct on_read_connack {};
|
||||
struct on_read_packet {};
|
||||
struct on_init_auth_data {};
|
||||
struct on_auth_data {};
|
||||
struct on_send_auth {};
|
||||
struct on_complete_auth {};
|
||||
|
||||
Stream& _stream;
|
||||
mqtt_context& _ctx;
|
||||
@@ -152,13 +158,30 @@ public:
|
||||
);
|
||||
}
|
||||
else
|
||||
send_connect();
|
||||
(*this)(on_ws_handshake {}, error_code {});
|
||||
}
|
||||
|
||||
void operator()(on_ws_handshake, error_code ec) {
|
||||
if (ec)
|
||||
return complete(ec);
|
||||
|
||||
auto auth_method = _ctx.authenticator.method();
|
||||
if (!auth_method.empty()) {
|
||||
_ctx.co_props[prop::authentication_method] = auth_method;
|
||||
return _ctx.authenticator.async_auth(
|
||||
auth_step_e::client_initial, "",
|
||||
asio::prepend(std::move(*this), on_init_auth_data {})
|
||||
);
|
||||
}
|
||||
|
||||
send_connect();
|
||||
}
|
||||
|
||||
void operator()(on_init_auth_data, error_code ec, std::string data) {
|
||||
if (ec)
|
||||
return complete(asio::error::try_again);
|
||||
|
||||
_ctx.co_props[prop::authentication_data] = std::move(data);
|
||||
send_connect();
|
||||
}
|
||||
|
||||
@@ -186,10 +209,9 @@ public:
|
||||
if (ec)
|
||||
return complete(ec);
|
||||
|
||||
constexpr size_t min_connack_sz = 5;
|
||||
_buffer_ptr = std::make_unique<std::string>(min_connack_sz, 0);
|
||||
_buffer_ptr = std::make_unique<std::string>(min_packet_sz, 0);
|
||||
|
||||
auto buff = asio::buffer(_buffer_ptr->data(), min_connack_sz);
|
||||
auto buff = asio::buffer(_buffer_ptr->data(), min_packet_sz);
|
||||
asio::async_read(
|
||||
_stream, buff,
|
||||
asio::prepend(std::move(*this), on_fixed_header {})
|
||||
@@ -202,8 +224,9 @@ public:
|
||||
if (ec)
|
||||
return complete(ec);
|
||||
|
||||
auto control_byte = (*_buffer_ptr)[0];
|
||||
if (control_byte != 0b00100000)
|
||||
auto code = control_code_e((*_buffer_ptr)[0] & 0b11110000);
|
||||
|
||||
if (code != control_code_e::auth && code != control_code_e::connack)
|
||||
return complete(asio::error::try_again);
|
||||
|
||||
auto varlen_ptr = _buffer_ptr->cbegin() + 1;
|
||||
@@ -217,7 +240,8 @@ public:
|
||||
auto remain_len = *varlen -
|
||||
std::distance(varlen_ptr, _buffer_ptr->cbegin() + num_read);
|
||||
|
||||
_buffer_ptr->resize(_buffer_ptr->size() + remain_len);
|
||||
if (num_read + remain_len > _buffer_ptr->size())
|
||||
_buffer_ptr->resize(num_read + remain_len);
|
||||
|
||||
auto buff = asio::buffer(_buffer_ptr->data() + num_read, remain_len);
|
||||
auto first = _buffer_ptr->cbegin() + varlen_sz + 1;
|
||||
@@ -227,21 +251,33 @@ public:
|
||||
_stream, buff,
|
||||
asio::prepend(
|
||||
asio::append(
|
||||
std::move(*this), uint8_t(control_byte), first, last
|
||||
), on_read_connack {}
|
||||
std::move(*this), code, first, last
|
||||
), on_read_packet {}
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
void operator()(
|
||||
on_read_connack, error_code ec, size_t, uint8_t control_code,
|
||||
on_read_packet, error_code ec, size_t, control_code_e code,
|
||||
byte_citer first, byte_citer last
|
||||
) {
|
||||
if (ec)
|
||||
return complete(ec);
|
||||
|
||||
if (code == control_code_e::connack)
|
||||
return on_connack(first, last);
|
||||
|
||||
if (!_ctx.co_props[prop::authentication_method].has_value())
|
||||
return complete(client::error::malformed_packet);
|
||||
|
||||
on_auth(first, last);
|
||||
}
|
||||
|
||||
void on_connack(byte_citer first, byte_citer last) {
|
||||
auto packet_length = std::distance(first, last);
|
||||
auto rv = decoders::decode_connack(packet_length, first);
|
||||
if (!rv.has_value())
|
||||
return complete(client::error::malformed_packet);
|
||||
const auto& [session_present, reason_code, ca_props] = *rv;
|
||||
|
||||
_ctx.ca_props = ca_props;
|
||||
@@ -257,7 +293,84 @@ public:
|
||||
if (!rc.has_value()) // reason code not allowed in CONNACK
|
||||
return complete(client::error::malformed_packet);
|
||||
|
||||
complete(to_asio_error(*rc));
|
||||
auto ec = to_asio_error(*rc);
|
||||
if (ec)
|
||||
return complete(ec);
|
||||
|
||||
if (_ctx.co_props[prop::authentication_method].has_value())
|
||||
return _ctx.authenticator.async_auth(
|
||||
auth_step_e::server_final,
|
||||
ca_props[prop::authentication_data].value_or(""),
|
||||
asio::prepend(std::move(*this), on_complete_auth {})
|
||||
);
|
||||
|
||||
complete(error_code {});
|
||||
}
|
||||
|
||||
void on_auth(byte_citer first, byte_citer last) {
|
||||
auto packet_length = std::distance(first, last);
|
||||
auto rv = decoders::decode_auth(packet_length, first);
|
||||
if (!rv.has_value())
|
||||
return complete(client::error::malformed_packet);
|
||||
const auto& [reason_code, auth_props] = *rv;
|
||||
|
||||
auto rc = to_reason_code<reason_codes::category::auth>(reason_code);
|
||||
if (
|
||||
!rc.has_value() ||
|
||||
auth_props[prop::authentication_method]
|
||||
!= _ctx.co_props[prop::authentication_method]
|
||||
)
|
||||
return complete(client::error::malformed_packet);
|
||||
|
||||
_ctx.authenticator.async_auth(
|
||||
auth_step_e::server_challenge,
|
||||
auth_props[prop::authentication_data].value_or(""),
|
||||
asio::prepend(std::move(*this), on_auth_data {})
|
||||
);
|
||||
}
|
||||
|
||||
void operator()(on_auth_data, error_code ec, std::string data) {
|
||||
if (ec)
|
||||
return complete(asio::error::try_again);
|
||||
|
||||
auth_props props;
|
||||
props[prop::authentication_method] =
|
||||
_ctx.co_props[prop::authentication_method];
|
||||
props[prop::authentication_data] = std::move(data);
|
||||
|
||||
auto packet = control_packet<allocator_type>::of(
|
||||
no_pid, get_allocator(),
|
||||
encoders::encode_auth,
|
||||
reason_codes::continue_authentication.value(), props
|
||||
);
|
||||
|
||||
const auto& wire_data = packet.wire_data();
|
||||
|
||||
async_mqtt5::detail::async_write(
|
||||
_stream, asio::buffer(wire_data),
|
||||
asio::consign(
|
||||
asio::prepend(std::move(*this), on_send_auth{}),
|
||||
std::move(packet)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
void operator()(on_send_auth, error_code ec, size_t) {
|
||||
if (ec)
|
||||
return complete(ec);
|
||||
|
||||
auto buff = asio::buffer(_buffer_ptr->data(), min_packet_sz);
|
||||
asio::async_read(
|
||||
_stream, buff,
|
||||
asio::prepend(std::move(*this), on_fixed_header {})
|
||||
);
|
||||
}
|
||||
|
||||
void operator()(on_complete_auth, error_code ec, std::string) {
|
||||
if (ec)
|
||||
return complete(asio::error::try_again);
|
||||
|
||||
complete(error_code {});
|
||||
}
|
||||
|
||||
private:
|
||||
|
@@ -55,7 +55,7 @@ public:
|
||||
|
||||
void operator()(
|
||||
on_message, error_code ec,
|
||||
uint16_t packet_id, uint8_t control_code,
|
||||
uint8_t control_code,
|
||||
byte_citer first, byte_citer last
|
||||
) {
|
||||
if (ec == client::error::malformed_packet)
|
||||
@@ -69,19 +69,17 @@ public:
|
||||
)
|
||||
return;
|
||||
|
||||
dispatch(ec, packet_id, control_code, first, last);
|
||||
dispatch(control_code, first, last);
|
||||
}
|
||||
|
||||
void operator()(on_disconnect, error_code ec) {
|
||||
if (!ec || ec == asio::error::try_again)
|
||||
if (!ec)
|
||||
perform();
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
// TODO: ec & packet_id are not used here
|
||||
void dispatch(
|
||||
error_code ec, uint16_t packet_id, uint8_t control_byte,
|
||||
uint8_t control_byte,
|
||||
byte_citer first, byte_citer last
|
||||
) {
|
||||
using enum control_code_e;
|
||||
|
@@ -83,7 +83,7 @@ public:
|
||||
void operator()(on_disconnect, error_code ec) {
|
||||
get_cancellation_slot().clear();
|
||||
|
||||
if (!ec || ec == asio::error::try_again)
|
||||
if (!ec)
|
||||
perform();
|
||||
}
|
||||
};
|
||||
|
Reference in New Issue
Block a user