diff --git a/CHANGELOG.md b/CHANGELOG.md index f84aab49..1bf9407d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,6 @@ * Documentation tidying * is_invocable works with move-only types +* websocket::stream tidying -------------------------------------------------------------------------------- diff --git a/include/beast/websocket/impl/stream.ipp b/include/beast/websocket/impl/stream.ipp index 62bfe595..1b064912 100644 --- a/include/beast/websocket/impl/stream.ipp +++ b/include/beast/websocket/impl/stream.ipp @@ -70,252 +70,6 @@ set_option(permessage_deflate const& o) //------------------------------------------------------------------------------ -template -void -stream:: -reset() -{ - failed_ = false; - rd_.cont = false; - wr_close_ = false; - wr_.cont = false; - wr_block_ = nullptr; // should be nullptr on close anyway - ping_data_ = nullptr; // should be nullptr on close anyway - - stream_.buffer().consume( - stream_.buffer().size()); -} - -template -template -void -stream:: -do_accept( - Decorator const& decorator, error_code& ec) -{ - http::request_parser p; - http::read_header(next_layer(), - stream_.buffer(), p, ec); - if(ec) - return; - do_accept(p.get(), decorator, ec); -} - -template -template -void -stream:: -do_accept(http::header> const& req, - Decorator const& decorator, error_code& ec) -{ - auto const res = build_response(req, decorator); - http::write(stream_, res, ec); - if(ec) - return; - if(res.result() != http::status::switching_protocols) - { - ec = error::handshake_failed; - // VFALCO TODO Respect keep alive setting, perform - // teardown if Connection: close. - return; - } - pmd_read(pmd_config_, req); - open(role_type::server); -} - -template -template -void -stream:: -do_handshake(response_type* res_p, - string_view host, - string_view target, - RequestDecorator const& decorator, - error_code& ec) -{ - response_type res; - reset(); - detail::sec_ws_key_type key; - { - auto const req = build_request( - key, host, target, decorator); - pmd_read(pmd_config_, req); - http::write(stream_, req, ec); - } - if(ec) - return; - http::read(next_layer(), stream_.buffer(), res, ec); - if(ec) - return; - do_response(res, key, ec); - if(res_p) - *res_p = std::move(res); -} - -template -template -request_type -stream:: -build_request(detail::sec_ws_key_type& key, - string_view host, - string_view target, - Decorator const& decorator) -{ - request_type req; - req.target(target); - req.version = 11; - req.method(http::verb::get); - req.set(http::field::host, host); - req.set(http::field::upgrade, "websocket"); - req.set(http::field::connection, "upgrade"); - detail::make_sec_ws_key(key, maskgen_); - req.set(http::field::sec_websocket_key, key); - req.set(http::field::sec_websocket_version, "13"); - if(pmd_opts_.client_enable) - { - detail::pmd_offer config; - config.accept = true; - config.server_max_window_bits = - pmd_opts_.server_max_window_bits; - config.client_max_window_bits = - pmd_opts_.client_max_window_bits; - config.server_no_context_takeover = - pmd_opts_.server_no_context_takeover; - config.client_no_context_takeover = - pmd_opts_.client_no_context_takeover; - detail::pmd_write(req, config); - } - decorator(req); - if(! req.count(http::field::user_agent)) - req.set(http::field::user_agent, - BEAST_VERSION_STRING); - return req; -} - -template -template -response_type -stream:: -build_response(http::header> const& req, - Decorator const& decorator) -{ - auto const decorate = - [&decorator](response_type& res) - { - decorator(res); - if(! res.count(http::field::server)) - { - BOOST_STATIC_ASSERT(sizeof(BEAST_VERSION_STRING) < 20); - static_string<20> s(BEAST_VERSION_STRING); - res.set(http::field::server, s); - } - }; - auto err = - [&](std::string const& text) - { - response_type res; - res.version = req.version; - res.result(http::status::bad_request); - res.body = text; - res.prepare_payload(); - decorate(res); - return res; - }; - if(req.version < 11) - return err("HTTP version 1.1 required"); - if(req.method() != http::verb::get) - return err("Wrong method"); - if(! is_upgrade(req)) - return err("Expected Upgrade request"); - if(! req.count(http::field::host)) - return err("Missing Host"); - if(! req.count(http::field::sec_websocket_key)) - return err("Missing Sec-WebSocket-Key"); - if(! http::token_list{req[http::field::upgrade]}.exists("websocket")) - return err("Missing websocket Upgrade token"); - auto const key = req[http::field::sec_websocket_key]; - if(key.size() > detail::sec_ws_key_type::max_size_n) - return err("Invalid Sec-WebSocket-Key"); - { - auto const version = - req[http::field::sec_websocket_version]; - if(version.empty()) - return err("Missing Sec-WebSocket-Version"); - if(version != "13") - { - response_type res; - res.result(http::status::upgrade_required); - res.version = req.version; - res.set(http::field::sec_websocket_version, "13"); - res.prepare_payload(); - decorate(res); - return res; - } - } - - response_type res; - { - detail::pmd_offer offer; - detail::pmd_offer unused; - pmd_read(offer, req); - pmd_negotiate(res, unused, offer, pmd_opts_); - } - res.result(http::status::switching_protocols); - res.version = req.version; - res.set(http::field::upgrade, "websocket"); - res.set(http::field::connection, "upgrade"); - { - detail::sec_ws_accept_type acc; - detail::make_sec_ws_accept(acc, key); - res.set(http::field::sec_websocket_accept, acc); - } - decorate(res); - return res; -} - -template -void -stream:: -do_response(http::header const& res, - detail::sec_ws_key_type const& key, error_code& ec) -{ - bool const success = [&]() - { - if(res.version < 11) - return false; - if(res.result() != http::status::switching_protocols) - return false; - if(! http::token_list{res[http::field::connection]}.exists("upgrade")) - return false; - if(! http::token_list{res[http::field::upgrade]}.exists("websocket")) - return false; - if(res.count(http::field::sec_websocket_accept) != 1) - return false; - detail::sec_ws_accept_type acc; - detail::make_sec_ws_accept(acc, key); - if(acc.compare( - res[http::field::sec_websocket_accept]) != 0) - return false; - return true; - }(); - if(! success) - { - ec = error::handshake_failed; - return; - } - ec.assign(0, ec.category()); - detail::pmd_offer offer; - pmd_read(offer, res); - // VFALCO see if offer satisfies pmd_config_, - // return an error if not. - pmd_config_ = offer; // overwrite for now - open(role_type::client); -} - -//------------------------------------------------------------------------------ - template void stream:: @@ -371,6 +125,69 @@ close() pmd_.reset(); } +template +void +stream:: +reset() +{ + failed_ = false; + rd_.cont = false; + wr_close_ = false; + wr_.cont = false; + wr_block_ = nullptr; // should be nullptr on close anyway + ping_data_ = nullptr; // should be nullptr on close anyway + + stream_.buffer().consume( + stream_.buffer().size()); +} + +// Called before each read frame +template +void +stream:: +rd_begin() +{ + // Maintain the read buffer + if(pmd_) + { + if(! rd_.buf || rd_.buf_size != rd_buf_size_) + { + rd_.buf_size = rd_buf_size_; + rd_.buf = boost::make_unique_noinit< + std::uint8_t[]>(rd_.buf_size); + } + } +} + +// Called before each write frame +template +void +stream:: +wr_begin() +{ + wr_.autofrag = wr_autofrag_; + wr_.compress = static_cast(pmd_); + + // Maintain the write buffer + if( wr_.compress || + role_ == role_type::client) + { + if(! wr_.buf || wr_.buf_size != wr_buf_size_) + { + wr_.buf_size = wr_buf_size_; + wr_.buf = boost::make_unique_noinit< + std::uint8_t[]>(wr_.buf_size); + } + } + else + { + wr_.buf_size = wr_buf_size_; + wr_.buf.reset(); + } +} + +//------------------------------------------------------------------------------ + // Read fixed frame header from buffer // Requires at least 2 bytes // @@ -559,49 +376,6 @@ read_fh2(detail::frame_header& fh, code = close_code::none; } -template -void -stream:: -rd_begin() -{ - // Maintain the read buffer - if(pmd_) - { - if(! rd_.buf || rd_.buf_size != rd_buf_size_) - { - rd_.buf_size = rd_buf_size_; - rd_.buf = boost::make_unique_noinit< - std::uint8_t[]>(rd_.buf_size); - } - } -} - -template -void -stream:: -wr_begin() -{ - wr_.autofrag = wr_autofrag_; - wr_.compress = static_cast(pmd_); - - // Maintain the write buffer - if( wr_.compress || - role_ == role_type::client) - { - if(! wr_.buf || wr_.buf_size != wr_buf_size_) - { - wr_.buf_size = wr_buf_size_; - wr_.buf = boost::make_unique_noinit< - std::uint8_t[]>(wr_.buf_size); - } - } - else - { - wr_.buf_size = wr_buf_size_; - wr_.buf.reset(); - } -} - template template void @@ -682,6 +456,238 @@ write_ping(DynamicBuffer& db, db.commit(data.size()); } +//------------------------------------------------------------------------------ + +template +template +request_type +stream:: +build_request(detail::sec_ws_key_type& key, + string_view host, + string_view target, + Decorator const& decorator) +{ + request_type req; + req.target(target); + req.version = 11; + req.method(http::verb::get); + req.set(http::field::host, host); + req.set(http::field::upgrade, "websocket"); + req.set(http::field::connection, "upgrade"); + detail::make_sec_ws_key(key, maskgen_); + req.set(http::field::sec_websocket_key, key); + req.set(http::field::sec_websocket_version, "13"); + if(pmd_opts_.client_enable) + { + detail::pmd_offer config; + config.accept = true; + config.server_max_window_bits = + pmd_opts_.server_max_window_bits; + config.client_max_window_bits = + pmd_opts_.client_max_window_bits; + config.server_no_context_takeover = + pmd_opts_.server_no_context_takeover; + config.client_no_context_takeover = + pmd_opts_.client_no_context_takeover; + detail::pmd_write(req, config); + } + decorator(req); + if(! req.count(http::field::user_agent)) + req.set(http::field::user_agent, + BEAST_VERSION_STRING); + return req; +} + +template +template +response_type +stream:: +build_response(http::header> const& req, + Decorator const& decorator) +{ + auto const decorate = + [&decorator](response_type& res) + { + decorator(res); + if(! res.count(http::field::server)) + { + BOOST_STATIC_ASSERT(sizeof(BEAST_VERSION_STRING) < 20); + static_string<20> s(BEAST_VERSION_STRING); + res.set(http::field::server, s); + } + }; + auto err = + [&](std::string const& text) + { + response_type res; + res.version = req.version; + res.result(http::status::bad_request); + res.body = text; + res.prepare_payload(); + decorate(res); + return res; + }; + if(req.version < 11) + return err("HTTP version 1.1 required"); + if(req.method() != http::verb::get) + return err("Wrong method"); + if(! is_upgrade(req)) + return err("Expected Upgrade request"); + if(! req.count(http::field::host)) + return err("Missing Host"); + if(! req.count(http::field::sec_websocket_key)) + return err("Missing Sec-WebSocket-Key"); + if(! http::token_list{req[http::field::upgrade]}.exists("websocket")) + return err("Missing websocket Upgrade token"); + auto const key = req[http::field::sec_websocket_key]; + if(key.size() > detail::sec_ws_key_type::max_size_n) + return err("Invalid Sec-WebSocket-Key"); + { + auto const version = + req[http::field::sec_websocket_version]; + if(version.empty()) + return err("Missing Sec-WebSocket-Version"); + if(version != "13") + { + response_type res; + res.result(http::status::upgrade_required); + res.version = req.version; + res.set(http::field::sec_websocket_version, "13"); + res.prepare_payload(); + decorate(res); + return res; + } + } + + response_type res; + { + detail::pmd_offer offer; + detail::pmd_offer unused; + pmd_read(offer, req); + pmd_negotiate(res, unused, offer, pmd_opts_); + } + res.result(http::status::switching_protocols); + res.version = req.version; + res.set(http::field::upgrade, "websocket"); + res.set(http::field::connection, "upgrade"); + { + detail::sec_ws_accept_type acc; + detail::make_sec_ws_accept(acc, key); + res.set(http::field::sec_websocket_accept, acc); + } + decorate(res); + return res; +} + +template +template +void +stream:: +do_accept( + Decorator const& decorator, error_code& ec) +{ + http::request_parser p; + http::read_header(next_layer(), + stream_.buffer(), p, ec); + if(ec) + return; + do_accept(p.get(), decorator, ec); +} + +template +template +void +stream:: +do_accept(http::header> const& req, + Decorator const& decorator, error_code& ec) +{ + auto const res = build_response(req, decorator); + http::write(stream_, res, ec); + if(ec) + return; + if(res.result() != http::status::switching_protocols) + { + ec = error::handshake_failed; + // VFALCO TODO Respect keep alive setting, perform + // teardown if Connection: close. + return; + } + pmd_read(pmd_config_, req); + open(role_type::server); +} + +template +template +void +stream:: +do_handshake(response_type* res_p, + string_view host, + string_view target, + RequestDecorator const& decorator, + error_code& ec) +{ + response_type res; + reset(); + detail::sec_ws_key_type key; + { + auto const req = build_request( + key, host, target, decorator); + pmd_read(pmd_config_, req); + http::write(stream_, req, ec); + } + if(ec) + return; + http::read(next_layer(), stream_.buffer(), res, ec); + if(ec) + return; + do_response(res, key, ec); + if(res_p) + *res_p = std::move(res); +} + +template +void +stream:: +do_response(http::header const& res, + detail::sec_ws_key_type const& key, error_code& ec) +{ + bool const success = [&]() + { + if(res.version < 11) + return false; + if(res.result() != http::status::switching_protocols) + return false; + if(! http::token_list{res[http::field::connection]}.exists("upgrade")) + return false; + if(! http::token_list{res[http::field::upgrade]}.exists("websocket")) + return false; + if(res.count(http::field::sec_websocket_accept) != 1) + return false; + detail::sec_ws_accept_type acc; + detail::make_sec_ws_accept(acc, key); + if(acc.compare( + res[http::field::sec_websocket_accept]) != 0) + return false; + return true; + }(); + if(! success) + { + ec = error::handshake_failed; + return; + } + ec.assign(0, ec.category()); + detail::pmd_offer offer; + pmd_read(offer, res); + // VFALCO see if offer satisfies pmd_config_, + // return an error if not. + pmd_config_ = offer; // overwrite for now + open(role_type::client); +} + +//------------------------------------------------------------------------------ + } // websocket } // beast diff --git a/include/beast/websocket/stream.hpp b/include/beast/websocket/stream.hpp index 79e1dd98..7d59ca49 100644 --- a/include/beast/websocket/stream.hpp +++ b/include/beast/websocket/stream.hpp @@ -114,8 +114,12 @@ class stream { friend class detail::frame_test; friend class stream_test; + friend class frame_test; - buffered_read_stream stream_; + struct op {}; + + using control_cb_type = + std::function; /// Identifies the role of a WebSockets stream. enum class role_type @@ -127,35 +131,6 @@ class stream server }; - friend class frame_test; - - using control_cb_type = - std::function; - - struct op {}; - - detail::maskgen maskgen_; // source of mask keys - std::size_t rd_msg_max_ = - 16 * 1024 * 1024; // max message size - bool wr_autofrag_ = true; // auto fragment - std::size_t wr_buf_size_ = 4096; // write buffer size - std::size_t rd_buf_size_ = 4096; // read buffer size - detail::opcode wr_opcode_ = - detail::opcode::text; // outgoing message type - control_cb_type ctrl_cb_; // control callback - role_type role_; // server or client - bool failed_; // the connection failed - - bool wr_close_; // sent close frame - op* wr_block_; // op currenly writing - - ping_data* ping_data_; // where to put the payload - detail::pausation rd_op_; // paused read op - detail::pausation wr_op_; // paused write op - detail::pausation ping_op_; // paused ping op - detail::pausation close_op_; // paused close op - close_reason cr_; // set from received close frame - // State information for the message being received // struct rd_t @@ -182,8 +157,6 @@ class stream std::unique_ptr buf; }; - rd_t rd_; - // State information for the message being sent // struct wr_t @@ -216,8 +189,6 @@ class stream std::unique_ptr buf; }; - wr_t wr_; - // State information for the permessage-deflate extension struct pmd_t { @@ -228,6 +199,32 @@ class stream zlib::inflate_stream zi; }; + buffered_read_stream< + NextLayer, multi_buffer> stream_; // the wrapped stream + detail::maskgen maskgen_; // source of mask keys + std::size_t rd_msg_max_ = + 16 * 1024 * 1024; // max message size + bool wr_autofrag_ = true; // auto fragment + std::size_t wr_buf_size_ = 4096; // write buffer size + std::size_t rd_buf_size_ = 4096; // read buffer size + detail::opcode wr_opcode_ = + detail::opcode::text; // outgoing message type + control_cb_type ctrl_cb_; // control callback + role_type role_; // server or client + bool failed_; // the connection failed + + bool wr_close_; // sent close frame + op* wr_block_; // op currenly writing + + ping_data* ping_data_; // where to put the payload + detail::pausation rd_op_; // paused read op + detail::pausation wr_op_; // paused write op + detail::pausation ping_op_; // paused ping op + detail::pausation close_op_; // paused close op + close_reason cr_; // set from received close frame + rd_t rd_; // read state + wr_t wr_; // write state + // If not engaged, then permessage-deflate is not // enabled for the currently active session. std::unique_ptr pmd_; @@ -238,40 +235,6 @@ class stream // Offer for clients, negotiated result for servers detail::pmd_offer pmd_config_; - void - open(role_type role); - - void - close(); - - template - std::size_t - read_fh1(detail::frame_header& fh, - DynamicBuffer& db, close_code& code); - - template - void - read_fh2(detail::frame_header& fh, - DynamicBuffer& db, close_code& code); - - // Called before receiving the first frame of each message - void - rd_begin(); - - // Called before sending the first frame of each message - // - void - wr_begin(); - - template - void - write_close(DynamicBuffer& db, close_reason const& rc); - - template - void - write_ping(DynamicBuffer& db, - detail::opcode op, ping_data const& data); - public: /// The type of the next layer. using next_layer_type = @@ -281,9 +244,18 @@ public: using lowest_layer_type = typename get_lowest_layer::type; + /** Destructor + + Destroys the stream and all associated resources. + + @note A stream object must not be destroyed while there + are pending asynchronous operations associated with it. + */ + ~stream() = default; + /** Constructor - If @c NextLayer is move constructible, this function + If `NextLayer` is move constructible, this function will move-construct a new stream from the existing stream. @note The behavior of move assignment on or from streams @@ -316,13 +288,6 @@ public: explicit stream(Args&&... args); - /** Destructor - - @note A stream object must not be destroyed while there - are pending asynchronous operations associated with it. - */ - ~stream() = default; - /** Return the `io_service` associated with the stream This function may be used to obtain the `io_service` object @@ -3309,30 +3274,61 @@ public: ConstBufferSequence const& buffers, WriteHandler&& handler); private: - template class accept_op; - template class close_op; - template class handshake_op; - template class ping_op; - template class response_op; - template class write_op; - template class write_frame_op; - template class read_op; - template class read_frame_op; + template class accept_op; + template class close_op; + template class handshake_op; + template class ping_op; + template class read_op; + template class read_frame_op; + template class response_op; + template class write_frame_op; + template class write_op; - static - void - default_decorate_req(request_type&) - { - } + static void default_decorate_req(request_type&) {} + static void default_decorate_res(response_type&) {} - static - void - default_decorate_res(response_type&) - { - } + void open(role_type role); + void close(); + void reset(); + void rd_begin(); + void wr_begin(); + template + std::size_t + read_fh1(detail::frame_header& fh, + DynamicBuffer& db, close_code& code); + + template void - reset(); + read_fh2(detail::frame_header& fh, + DynamicBuffer& db, close_code& code); + + template + void + write_close(DynamicBuffer& db, close_reason const& rc); + + template + void + write_ping(DynamicBuffer& db, + detail::opcode op, ping_data const& data); + + template + request_type + build_request(detail::sec_ws_key_type& key, + string_view host, + string_view target, + Decorator const& decorator); + + template + response_type + build_response(http::header> const& req, + Decorator const& decorator); template void @@ -3353,19 +3349,6 @@ private: RequestDecorator const& decorator, error_code& ec); - template - request_type - build_request(detail::sec_ws_key_type& key, - string_view host, - string_view target, - Decorator const& decorator); - - template - response_type - build_response(http::header> const& req, - Decorator const& decorator); - void do_response(http::header const& resp, detail::sec_ws_key_type const& key, error_code& ec);