websocket::stream tidying

This commit is contained in:
Vinnie Falco
2017-07-14 12:11:44 -07:00
parent 27d070c724
commit 4d15fc455a
3 changed files with 388 additions and 398 deletions

View File

@@ -1,5 +1,6 @@
* Documentation tidying
* is_invocable works with move-only types
* websocket::stream tidying
--------------------------------------------------------------------------------

View File

@@ -70,252 +70,6 @@ set_option(permessage_deflate const& o)
//------------------------------------------------------------------------------
template<class NextLayer>
void
stream<NextLayer>::
reset()
{
failed_ = false;
rd_.cont = false;
wr_close_ = false;
wr_.cont = false;
wr_block_ = nullptr; // should be nullptr on close anyway
ping_data_ = nullptr; // should be nullptr on close anyway
stream_.buffer().consume(
stream_.buffer().size());
}
template<class NextLayer>
template<class Decorator>
void
stream<NextLayer>::
do_accept(
Decorator const& decorator, error_code& ec)
{
http::request_parser<http::empty_body> p;
http::read_header(next_layer(),
stream_.buffer(), p, ec);
if(ec)
return;
do_accept(p.get(), decorator, ec);
}
template<class NextLayer>
template<class Allocator, class Decorator>
void
stream<NextLayer>::
do_accept(http::header<true,
http::basic_fields<Allocator>> const& req,
Decorator const& decorator, error_code& ec)
{
auto const res = build_response(req, decorator);
http::write(stream_, res, ec);
if(ec)
return;
if(res.result() != http::status::switching_protocols)
{
ec = error::handshake_failed;
// VFALCO TODO Respect keep alive setting, perform
// teardown if Connection: close.
return;
}
pmd_read(pmd_config_, req);
open(role_type::server);
}
template<class NextLayer>
template<class RequestDecorator>
void
stream<NextLayer>::
do_handshake(response_type* res_p,
string_view host,
string_view target,
RequestDecorator const& decorator,
error_code& ec)
{
response_type res;
reset();
detail::sec_ws_key_type key;
{
auto const req = build_request(
key, host, target, decorator);
pmd_read(pmd_config_, req);
http::write(stream_, req, ec);
}
if(ec)
return;
http::read(next_layer(), stream_.buffer(), res, ec);
if(ec)
return;
do_response(res, key, ec);
if(res_p)
*res_p = std::move(res);
}
template<class NextLayer>
template<class Decorator>
request_type
stream<NextLayer>::
build_request(detail::sec_ws_key_type& key,
string_view host,
string_view target,
Decorator const& decorator)
{
request_type req;
req.target(target);
req.version = 11;
req.method(http::verb::get);
req.set(http::field::host, host);
req.set(http::field::upgrade, "websocket");
req.set(http::field::connection, "upgrade");
detail::make_sec_ws_key(key, maskgen_);
req.set(http::field::sec_websocket_key, key);
req.set(http::field::sec_websocket_version, "13");
if(pmd_opts_.client_enable)
{
detail::pmd_offer config;
config.accept = true;
config.server_max_window_bits =
pmd_opts_.server_max_window_bits;
config.client_max_window_bits =
pmd_opts_.client_max_window_bits;
config.server_no_context_takeover =
pmd_opts_.server_no_context_takeover;
config.client_no_context_takeover =
pmd_opts_.client_no_context_takeover;
detail::pmd_write(req, config);
}
decorator(req);
if(! req.count(http::field::user_agent))
req.set(http::field::user_agent,
BEAST_VERSION_STRING);
return req;
}
template<class NextLayer>
template<class Allocator, class Decorator>
response_type
stream<NextLayer>::
build_response(http::header<true,
http::basic_fields<Allocator>> const& req,
Decorator const& decorator)
{
auto const decorate =
[&decorator](response_type& res)
{
decorator(res);
if(! res.count(http::field::server))
{
BOOST_STATIC_ASSERT(sizeof(BEAST_VERSION_STRING) < 20);
static_string<20> s(BEAST_VERSION_STRING);
res.set(http::field::server, s);
}
};
auto err =
[&](std::string const& text)
{
response_type res;
res.version = req.version;
res.result(http::status::bad_request);
res.body = text;
res.prepare_payload();
decorate(res);
return res;
};
if(req.version < 11)
return err("HTTP version 1.1 required");
if(req.method() != http::verb::get)
return err("Wrong method");
if(! is_upgrade(req))
return err("Expected Upgrade request");
if(! req.count(http::field::host))
return err("Missing Host");
if(! req.count(http::field::sec_websocket_key))
return err("Missing Sec-WebSocket-Key");
if(! http::token_list{req[http::field::upgrade]}.exists("websocket"))
return err("Missing websocket Upgrade token");
auto const key = req[http::field::sec_websocket_key];
if(key.size() > detail::sec_ws_key_type::max_size_n)
return err("Invalid Sec-WebSocket-Key");
{
auto const version =
req[http::field::sec_websocket_version];
if(version.empty())
return err("Missing Sec-WebSocket-Version");
if(version != "13")
{
response_type res;
res.result(http::status::upgrade_required);
res.version = req.version;
res.set(http::field::sec_websocket_version, "13");
res.prepare_payload();
decorate(res);
return res;
}
}
response_type res;
{
detail::pmd_offer offer;
detail::pmd_offer unused;
pmd_read(offer, req);
pmd_negotiate(res, unused, offer, pmd_opts_);
}
res.result(http::status::switching_protocols);
res.version = req.version;
res.set(http::field::upgrade, "websocket");
res.set(http::field::connection, "upgrade");
{
detail::sec_ws_accept_type acc;
detail::make_sec_ws_accept(acc, key);
res.set(http::field::sec_websocket_accept, acc);
}
decorate(res);
return res;
}
template<class NextLayer>
void
stream<NextLayer>::
do_response(http::header<false> const& res,
detail::sec_ws_key_type const& key, error_code& ec)
{
bool const success = [&]()
{
if(res.version < 11)
return false;
if(res.result() != http::status::switching_protocols)
return false;
if(! http::token_list{res[http::field::connection]}.exists("upgrade"))
return false;
if(! http::token_list{res[http::field::upgrade]}.exists("websocket"))
return false;
if(res.count(http::field::sec_websocket_accept) != 1)
return false;
detail::sec_ws_accept_type acc;
detail::make_sec_ws_accept(acc, key);
if(acc.compare(
res[http::field::sec_websocket_accept]) != 0)
return false;
return true;
}();
if(! success)
{
ec = error::handshake_failed;
return;
}
ec.assign(0, ec.category());
detail::pmd_offer offer;
pmd_read(offer, res);
// VFALCO see if offer satisfies pmd_config_,
// return an error if not.
pmd_config_ = offer; // overwrite for now
open(role_type::client);
}
//------------------------------------------------------------------------------
template<class NextLayer>
void
stream<NextLayer>::
@@ -371,6 +125,69 @@ close()
pmd_.reset();
}
template<class NextLayer>
void
stream<NextLayer>::
reset()
{
failed_ = false;
rd_.cont = false;
wr_close_ = false;
wr_.cont = false;
wr_block_ = nullptr; // should be nullptr on close anyway
ping_data_ = nullptr; // should be nullptr on close anyway
stream_.buffer().consume(
stream_.buffer().size());
}
// Called before each read frame
template<class NextLayer>
void
stream<NextLayer>::
rd_begin()
{
// Maintain the read buffer
if(pmd_)
{
if(! rd_.buf || rd_.buf_size != rd_buf_size_)
{
rd_.buf_size = rd_buf_size_;
rd_.buf = boost::make_unique_noinit<
std::uint8_t[]>(rd_.buf_size);
}
}
}
// Called before each write frame
template<class NextLayer>
void
stream<NextLayer>::
wr_begin()
{
wr_.autofrag = wr_autofrag_;
wr_.compress = static_cast<bool>(pmd_);
// Maintain the write buffer
if( wr_.compress ||
role_ == role_type::client)
{
if(! wr_.buf || wr_.buf_size != wr_buf_size_)
{
wr_.buf_size = wr_buf_size_;
wr_.buf = boost::make_unique_noinit<
std::uint8_t[]>(wr_.buf_size);
}
}
else
{
wr_.buf_size = wr_buf_size_;
wr_.buf.reset();
}
}
//------------------------------------------------------------------------------
// Read fixed frame header from buffer
// Requires at least 2 bytes
//
@@ -559,49 +376,6 @@ read_fh2(detail::frame_header& fh,
code = close_code::none;
}
template<class NextLayer>
void
stream<NextLayer>::
rd_begin()
{
// Maintain the read buffer
if(pmd_)
{
if(! rd_.buf || rd_.buf_size != rd_buf_size_)
{
rd_.buf_size = rd_buf_size_;
rd_.buf = boost::make_unique_noinit<
std::uint8_t[]>(rd_.buf_size);
}
}
}
template<class NextLayer>
void
stream<NextLayer>::
wr_begin()
{
wr_.autofrag = wr_autofrag_;
wr_.compress = static_cast<bool>(pmd_);
// Maintain the write buffer
if( wr_.compress ||
role_ == role_type::client)
{
if(! wr_.buf || wr_.buf_size != wr_buf_size_)
{
wr_.buf_size = wr_buf_size_;
wr_.buf = boost::make_unique_noinit<
std::uint8_t[]>(wr_.buf_size);
}
}
else
{
wr_.buf_size = wr_buf_size_;
wr_.buf.reset();
}
}
template<class NextLayer>
template<class DynamicBuffer>
void
@@ -682,6 +456,238 @@ write_ping(DynamicBuffer& db,
db.commit(data.size());
}
//------------------------------------------------------------------------------
template<class NextLayer>
template<class Decorator>
request_type
stream<NextLayer>::
build_request(detail::sec_ws_key_type& key,
string_view host,
string_view target,
Decorator const& decorator)
{
request_type req;
req.target(target);
req.version = 11;
req.method(http::verb::get);
req.set(http::field::host, host);
req.set(http::field::upgrade, "websocket");
req.set(http::field::connection, "upgrade");
detail::make_sec_ws_key(key, maskgen_);
req.set(http::field::sec_websocket_key, key);
req.set(http::field::sec_websocket_version, "13");
if(pmd_opts_.client_enable)
{
detail::pmd_offer config;
config.accept = true;
config.server_max_window_bits =
pmd_opts_.server_max_window_bits;
config.client_max_window_bits =
pmd_opts_.client_max_window_bits;
config.server_no_context_takeover =
pmd_opts_.server_no_context_takeover;
config.client_no_context_takeover =
pmd_opts_.client_no_context_takeover;
detail::pmd_write(req, config);
}
decorator(req);
if(! req.count(http::field::user_agent))
req.set(http::field::user_agent,
BEAST_VERSION_STRING);
return req;
}
template<class NextLayer>
template<class Allocator, class Decorator>
response_type
stream<NextLayer>::
build_response(http::header<true,
http::basic_fields<Allocator>> const& req,
Decorator const& decorator)
{
auto const decorate =
[&decorator](response_type& res)
{
decorator(res);
if(! res.count(http::field::server))
{
BOOST_STATIC_ASSERT(sizeof(BEAST_VERSION_STRING) < 20);
static_string<20> s(BEAST_VERSION_STRING);
res.set(http::field::server, s);
}
};
auto err =
[&](std::string const& text)
{
response_type res;
res.version = req.version;
res.result(http::status::bad_request);
res.body = text;
res.prepare_payload();
decorate(res);
return res;
};
if(req.version < 11)
return err("HTTP version 1.1 required");
if(req.method() != http::verb::get)
return err("Wrong method");
if(! is_upgrade(req))
return err("Expected Upgrade request");
if(! req.count(http::field::host))
return err("Missing Host");
if(! req.count(http::field::sec_websocket_key))
return err("Missing Sec-WebSocket-Key");
if(! http::token_list{req[http::field::upgrade]}.exists("websocket"))
return err("Missing websocket Upgrade token");
auto const key = req[http::field::sec_websocket_key];
if(key.size() > detail::sec_ws_key_type::max_size_n)
return err("Invalid Sec-WebSocket-Key");
{
auto const version =
req[http::field::sec_websocket_version];
if(version.empty())
return err("Missing Sec-WebSocket-Version");
if(version != "13")
{
response_type res;
res.result(http::status::upgrade_required);
res.version = req.version;
res.set(http::field::sec_websocket_version, "13");
res.prepare_payload();
decorate(res);
return res;
}
}
response_type res;
{
detail::pmd_offer offer;
detail::pmd_offer unused;
pmd_read(offer, req);
pmd_negotiate(res, unused, offer, pmd_opts_);
}
res.result(http::status::switching_protocols);
res.version = req.version;
res.set(http::field::upgrade, "websocket");
res.set(http::field::connection, "upgrade");
{
detail::sec_ws_accept_type acc;
detail::make_sec_ws_accept(acc, key);
res.set(http::field::sec_websocket_accept, acc);
}
decorate(res);
return res;
}
template<class NextLayer>
template<class Decorator>
void
stream<NextLayer>::
do_accept(
Decorator const& decorator, error_code& ec)
{
http::request_parser<http::empty_body> p;
http::read_header(next_layer(),
stream_.buffer(), p, ec);
if(ec)
return;
do_accept(p.get(), decorator, ec);
}
template<class NextLayer>
template<class Allocator, class Decorator>
void
stream<NextLayer>::
do_accept(http::header<true,
http::basic_fields<Allocator>> const& req,
Decorator const& decorator, error_code& ec)
{
auto const res = build_response(req, decorator);
http::write(stream_, res, ec);
if(ec)
return;
if(res.result() != http::status::switching_protocols)
{
ec = error::handshake_failed;
// VFALCO TODO Respect keep alive setting, perform
// teardown if Connection: close.
return;
}
pmd_read(pmd_config_, req);
open(role_type::server);
}
template<class NextLayer>
template<class RequestDecorator>
void
stream<NextLayer>::
do_handshake(response_type* res_p,
string_view host,
string_view target,
RequestDecorator const& decorator,
error_code& ec)
{
response_type res;
reset();
detail::sec_ws_key_type key;
{
auto const req = build_request(
key, host, target, decorator);
pmd_read(pmd_config_, req);
http::write(stream_, req, ec);
}
if(ec)
return;
http::read(next_layer(), stream_.buffer(), res, ec);
if(ec)
return;
do_response(res, key, ec);
if(res_p)
*res_p = std::move(res);
}
template<class NextLayer>
void
stream<NextLayer>::
do_response(http::header<false> const& res,
detail::sec_ws_key_type const& key, error_code& ec)
{
bool const success = [&]()
{
if(res.version < 11)
return false;
if(res.result() != http::status::switching_protocols)
return false;
if(! http::token_list{res[http::field::connection]}.exists("upgrade"))
return false;
if(! http::token_list{res[http::field::upgrade]}.exists("websocket"))
return false;
if(res.count(http::field::sec_websocket_accept) != 1)
return false;
detail::sec_ws_accept_type acc;
detail::make_sec_ws_accept(acc, key);
if(acc.compare(
res[http::field::sec_websocket_accept]) != 0)
return false;
return true;
}();
if(! success)
{
ec = error::handshake_failed;
return;
}
ec.assign(0, ec.category());
detail::pmd_offer offer;
pmd_read(offer, res);
// VFALCO see if offer satisfies pmd_config_,
// return an error if not.
pmd_config_ = offer; // overwrite for now
open(role_type::client);
}
//------------------------------------------------------------------------------
} // websocket
} // beast

View File

@@ -114,8 +114,12 @@ class stream
{
friend class detail::frame_test;
friend class stream_test;
friend class frame_test;
buffered_read_stream<NextLayer, multi_buffer> stream_;
struct op {};
using control_cb_type =
std::function<void(frame_type, string_view)>;
/// Identifies the role of a WebSockets stream.
enum class role_type
@@ -127,35 +131,6 @@ class stream
server
};
friend class frame_test;
using control_cb_type =
std::function<void(frame_type, string_view)>;
struct op {};
detail::maskgen maskgen_; // source of mask keys
std::size_t rd_msg_max_ =
16 * 1024 * 1024; // max message size
bool wr_autofrag_ = true; // auto fragment
std::size_t wr_buf_size_ = 4096; // write buffer size
std::size_t rd_buf_size_ = 4096; // read buffer size
detail::opcode wr_opcode_ =
detail::opcode::text; // outgoing message type
control_cb_type ctrl_cb_; // control callback
role_type role_; // server or client
bool failed_; // the connection failed
bool wr_close_; // sent close frame
op* wr_block_; // op currenly writing
ping_data* ping_data_; // where to put the payload
detail::pausation rd_op_; // paused read op
detail::pausation wr_op_; // paused write op
detail::pausation ping_op_; // paused ping op
detail::pausation close_op_; // paused close op
close_reason cr_; // set from received close frame
// State information for the message being received
//
struct rd_t
@@ -182,8 +157,6 @@ class stream
std::unique_ptr<std::uint8_t[]> buf;
};
rd_t rd_;
// State information for the message being sent
//
struct wr_t
@@ -216,8 +189,6 @@ class stream
std::unique_ptr<std::uint8_t[]> buf;
};
wr_t wr_;
// State information for the permessage-deflate extension
struct pmd_t
{
@@ -228,6 +199,32 @@ class stream
zlib::inflate_stream zi;
};
buffered_read_stream<
NextLayer, multi_buffer> stream_; // the wrapped stream
detail::maskgen maskgen_; // source of mask keys
std::size_t rd_msg_max_ =
16 * 1024 * 1024; // max message size
bool wr_autofrag_ = true; // auto fragment
std::size_t wr_buf_size_ = 4096; // write buffer size
std::size_t rd_buf_size_ = 4096; // read buffer size
detail::opcode wr_opcode_ =
detail::opcode::text; // outgoing message type
control_cb_type ctrl_cb_; // control callback
role_type role_; // server or client
bool failed_; // the connection failed
bool wr_close_; // sent close frame
op* wr_block_; // op currenly writing
ping_data* ping_data_; // where to put the payload
detail::pausation rd_op_; // paused read op
detail::pausation wr_op_; // paused write op
detail::pausation ping_op_; // paused ping op
detail::pausation close_op_; // paused close op
close_reason cr_; // set from received close frame
rd_t rd_; // read state
wr_t wr_; // write state
// If not engaged, then permessage-deflate is not
// enabled for the currently active session.
std::unique_ptr<pmd_t> pmd_;
@@ -238,40 +235,6 @@ class stream
// Offer for clients, negotiated result for servers
detail::pmd_offer pmd_config_;
void
open(role_type role);
void
close();
template<class DynamicBuffer>
std::size_t
read_fh1(detail::frame_header& fh,
DynamicBuffer& db, close_code& code);
template<class DynamicBuffer>
void
read_fh2(detail::frame_header& fh,
DynamicBuffer& db, close_code& code);
// Called before receiving the first frame of each message
void
rd_begin();
// Called before sending the first frame of each message
//
void
wr_begin();
template<class DynamicBuffer>
void
write_close(DynamicBuffer& db, close_reason const& rc);
template<class DynamicBuffer>
void
write_ping(DynamicBuffer& db,
detail::opcode op, ping_data const& data);
public:
/// The type of the next layer.
using next_layer_type =
@@ -281,9 +244,18 @@ public:
using lowest_layer_type =
typename get_lowest_layer<next_layer_type>::type;
/** Destructor
Destroys the stream and all associated resources.
@note A stream object must not be destroyed while there
are pending asynchronous operations associated with it.
*/
~stream() = default;
/** Constructor
If @c NextLayer is move constructible, this function
If `NextLayer` is move constructible, this function
will move-construct a new stream from the existing stream.
@note The behavior of move assignment on or from streams
@@ -316,13 +288,6 @@ public:
explicit
stream(Args&&... args);
/** Destructor
@note A stream object must not be destroyed while there
are pending asynchronous operations associated with it.
*/
~stream() = default;
/** Return the `io_service` associated with the stream
This function may be used to obtain the `io_service` object
@@ -3309,30 +3274,61 @@ public:
ConstBufferSequence const& buffers, WriteHandler&& handler);
private:
template<class Decorator, class Handler> class accept_op;
template<class Handler> class close_op;
template<class Handler> class handshake_op;
template<class Handler> class ping_op;
template<class Handler> class response_op;
template<class Buffers, class Handler> class write_op;
template<class Buffers, class Handler> class write_frame_op;
template<class DynamicBuffer, class Handler> class read_op;
template<class DynamicBuffer, class Handler> class read_frame_op;
template<class Decorator,
class Handler> class accept_op;
template<class Handler> class close_op;
template<class Handler> class handshake_op;
template<class Handler> class ping_op;
template<class DynamicBuffer,
class Handler> class read_op;
template<class DynamicBuffer,
class Handler> class read_frame_op;
template<class Handler> class response_op;
template<class Buffers,
class Handler> class write_frame_op;
template<class Buffers,
class Handler> class write_op;
static
void
default_decorate_req(request_type&)
{
}
static void default_decorate_req(request_type&) {}
static void default_decorate_res(response_type&) {}
static
void
default_decorate_res(response_type&)
{
}
void open(role_type role);
void close();
void reset();
void rd_begin();
void wr_begin();
template<class DynamicBuffer>
std::size_t
read_fh1(detail::frame_header& fh,
DynamicBuffer& db, close_code& code);
template<class DynamicBuffer>
void
reset();
read_fh2(detail::frame_header& fh,
DynamicBuffer& db, close_code& code);
template<class DynamicBuffer>
void
write_close(DynamicBuffer& db, close_reason const& rc);
template<class DynamicBuffer>
void
write_ping(DynamicBuffer& db,
detail::opcode op, ping_data const& data);
template<class Decorator>
request_type
build_request(detail::sec_ws_key_type& key,
string_view host,
string_view target,
Decorator const& decorator);
template<class Allocator, class Decorator>
response_type
build_response(http::header<true,
http::basic_fields<Allocator>> const& req,
Decorator const& decorator);
template<class Decorator>
void
@@ -3353,19 +3349,6 @@ private:
RequestDecorator const& decorator,
error_code& ec);
template<class Decorator>
request_type
build_request(detail::sec_ws_key_type& key,
string_view host,
string_view target,
Decorator const& decorator);
template<class Allocator, class Decorator>
response_type
build_response(http::header<true,
http::basic_fields<Allocator>> const& req,
Decorator const& decorator);
void
do_response(http::header<false> const& resp,
detail::sec_ws_key_type const& key, error_code& ec);