WebSocket close will automatically drain (API Change):

fix #642

* Calls to stream::close and stream::async_close will
  automatically perform the required read operations

Actions Required:

* Remove calling code which drains the connection after
  calling stream::close or stream::async_close
This commit is contained in:
Vinnie Falco
2017-08-01 20:15:07 -07:00
parent dc6a08d10a
commit 64327739f0
14 changed files with 1233 additions and 898 deletions

View File

@ -16,6 +16,16 @@ WebSocket:
* eof on accept returns error::closed
* Fix stream::read_size_hint calculation
API Changes:
* Calls to stream::close and stream::async_close will
automatically perform the required read operations
Actions Required:
* Remove calling code which drains the connection after
calling stream::close or stream::async_close
--------------------------------------------------------------------------------
Version 99:

View File

@ -70,6 +70,7 @@ variant ubasan
:
<cxxflags>"-msse4.2 -funsigned-char -fno-omit-frame-pointer -fsanitize=address,undefined -fsanitize-blacklist=libs/beast/build/blacklist.supp"
<linkflags>"-fsanitize=address,undefined"
<define>BOOST_USE_ASAN=1
;
#cxx11_hdr_type_traits

View File

@ -90,7 +90,7 @@ int main()
if(ec)
return fail("read", ec);
// Send a "close" frame to the other end, this is a websocket thing
// Close the WebSocket connection
ws.close(websocket::close_code::normal, ec);
if(ec)
return fail("close", ec);
@ -98,24 +98,5 @@ int main()
// The buffers() function helps print a ConstBufferSequence
std::cout << boost::beast::buffers(b.data()) << std::endl;
// WebSocket says that to close a connection you have
// to keep reading messages until you receive a close frame.
// Beast delivers the close frame as an error from read.
//
boost::beast::drain_buffer drain; // Throws everything away efficiently
for(;;)
{
// Keep reading messages...
ws.read(drain, ec);
// ...until we get the special error code
if(ec == websocket::error::closed)
break;
// Some other error occurred, report it and exit.
if(ec)
return fail("close", ec);
}
return EXIT_SUCCESS;
}

View File

@ -69,7 +69,7 @@ int main()
if(ec)
return fail("read", ec);
// Send a "close" frame to the other end, this is a websocket thing
// Close the WebSocket connection
ws.close(websocket::close_code::normal, ec);
if(ec)
return fail("close", ec);
@ -77,25 +77,6 @@ int main()
// The buffers() function helps print a ConstBufferSequence
std::cout << boost::beast::buffers(b.data()) << std::endl;
// WebSocket says that to close a connection you have
// to keep reading messages until you receive a close frame.
// Beast delivers the close frame as an error from read.
//
boost::beast::drain_buffer drain; // Throws everything away efficiently
for(;;)
{
// Keep reading messages...
ws.read(drain, ec);
// ...until we get the special error code
if(ec == websocket::error::closed)
break;
// Some other error occurred, report it and exit.
if(ec)
return fail("close", ec);
}
// If we get here the connection was cleanly closed
return EXIT_SUCCESS;
}

View File

@ -64,8 +64,7 @@ class server
boost::asio::basic_waitable_timer<
clock_type> timer_; // Needed for timeouts
boost::asio::io_service::strand strand_;// Needed when threads > 1
boost::beast::multi_buffer buffer_; // Stores the current message
boost::beast::drain_buffer drain_; // Helps discard data on close
boost::beast::multi_buffer buffer_; // Stores the current message
std::size_t id_; // A small unique id
public:

View File

@ -209,6 +209,7 @@ close(close_reason const& cr, error_code& ec)
{
static_assert(is_sync_stream<next_layer_type>::value,
"SyncStream requirements not met");
using beast::detail::clamp;
// If rd_close_ is set then we already sent a close
BOOST_ASSERT(! rd_close_);
if(wr_close_)
@ -219,12 +220,82 @@ close(close_reason const& cr, error_code& ec)
return;
}
wr_close_ = true;
detail::frame_streambuf fb;
write_close<flat_static_buffer_base>(fb, cr);
boost::asio::write(stream_, fb.data(), ec);
{
detail::frame_streambuf fb;
write_close<flat_static_buffer_base>(fb, cr);
boost::asio::write(stream_, fb.data(), ec);
}
failed_ = !!ec;
if(failed_)
return;
// Drain the connection
close_code code{};
if(rd_.remain > 0)
goto read_payload;
for(;;)
{
// Read frame header
while(! parse_fh(rd_.fh, rd_.buf, code))
{
if(code != close_code::none)
return do_fail(close_code::none,
error::failed, ec);
auto const bytes_transferred =
stream_.read_some(
rd_.buf.prepare(read_size(rd_.buf,
rd_.buf.max_size())), ec);
failed_ = !!ec;
if(failed_)
return;
rd_.buf.commit(bytes_transferred);
}
if(detail::is_control(rd_.fh.op))
{
// Process control frame
if(rd_.fh.op == detail::opcode::close)
{
BOOST_ASSERT(! rd_close_);
rd_close_ = true;
auto const mb = buffer_prefix(
clamp(rd_.fh.len),
rd_.buf.mutable_data());
if(rd_.fh.len > 0 && rd_.fh.mask)
detail::mask_inplace(mb, rd_.key);
detail::read_close(cr_, mb, code);
if(code != close_code::none)
// Protocol error
return do_fail(close_code::none,
error::failed, ec);
rd_.buf.consume(clamp(rd_.fh.len));
break;
}
rd_.buf.consume(clamp(rd_.fh.len));
}
else
{
read_payload:
while(rd_.buf.size() < rd_.remain)
{
rd_.remain -= rd_.buf.size();
rd_.buf.consume(rd_.buf.size());
auto const bytes_transferred =
stream_.read_some(
rd_.buf.prepare(read_size(rd_.buf,
rd_.buf.max_size())), ec);
failed_ = !!ec;
if(failed_)
return;
rd_.buf.commit(bytes_transferred);
}
BOOST_ASSERT(rd_.buf.size() >= rd_.remain);
rd_.buf.consume(clamp(rd_.remain));
rd_.remain = 0;
}
}
// _Close the WebSocket Connection_
do_fail(close_code::none, error::closed, ec);
if(ec == error::closed)
ec.assign(0, ec.category());
}
template<class NextLayer>

View File

@ -39,7 +39,7 @@ class stream<NextLayer>::fail_op
stream<NextLayer>& ws_;
int step_ = 0;
bool dispatched_ = false;
close_code code_;
std::uint16_t code_;
error_code ev_;
token tok_;
@ -47,12 +47,11 @@ public:
fail_op(fail_op&&) = default;
fail_op(fail_op const&) = default;
// send close code, then teardown
template<class DeducedHandler>
fail_op(
DeducedHandler&& h,
stream<NextLayer>& ws,
close_code code,
std::uint16_t code,
error_code ev)
: h_(std::forward<DeducedHandler>(h))
, ws_(ws)
@ -220,7 +219,7 @@ template<class NextLayer>
void
stream<NextLayer>::
do_fail(
close_code code, // if set, send a close frame first
std::uint16_t code, // if set, send a close frame first
error_code ev, // error code to use upon success
error_code& ec) // set to the error, else set to ev
{
@ -256,7 +255,7 @@ template<class Handler>
void
stream<NextLayer>::
do_async_fail(
close_code code, // if set, send a close frame first
std::uint16_t code, // if set, send a close frame first
error_code ev, // error code to use upon success
Handler&& handler)
{

View File

@ -196,34 +196,27 @@ operator()(
{
BOOST_ASSERT(! ws_.rd_close_);
ws_.rd_close_ = true;
detail::read_close(ws_.cr_, cb, code);
close_reason cr;
detail::read_close(cr, cb, code);
if(code != close_code::none)
{
// _Fail the WebSocket Connection_
return ws_.do_async_fail(
code, error::failed, std::move(h_));
}
ws_.cr_ = cr;
ws_.rd_.buf.consume(len);
if(ws_.ctrl_cb_)
ws_.ctrl_cb_(frame_type::close,
ws_.cr_.reason);
if(ws_.wr_close_)
// _Close the WebSocket Connection_
return ws_.do_async_fail(close_code::none,
if(! ws_.wr_close_)
// _Start the WebSocket Closing Handshake_
return ws_.do_async_fail(
cr.code == close_code::none ?
close_code::normal : cr.code,
error::closed, std::move(h_));
auto cr = ws_.cr_;
if(cr.code == close_code::none)
cr.code = close_code::normal;
cr.reason = "";
ws_.rd_.fb.consume(ws_.rd_.fb.size());
ws_.template write_close<
flat_static_buffer_base>(
ws_.rd_.fb, cr);
// _Start the WebSocket Closing Handshake_
return ws_.do_async_fail(
cr.code == close_code::none ?
close_code::normal :
static_cast<close_code>(cr.code),
// _Close the WebSocket Connection_
return ws_.do_async_fail(close_code::none,
error::closed, std::move(h_));
}
}
@ -434,11 +427,11 @@ operator()(
goto go_maybe_fill;
case do_maybe_fill:
dispatched_ = true;
if(ec)
break;
if(ws_.rd_.done)
break;
dispatched_ = true;
go_maybe_fill:
if(ws_.pmd_ && ws_.pmd_->rd_set)
@ -760,8 +753,9 @@ operator()(
do_read:
using buffers_type = typename
DynamicBuffer::mutable_buffers_type;
auto const rsh = ws_.read_size_hint(b_);
auto const size = clamp(
ws_.read_size_hint(b_), limit_);
rsh, limit_);
boost::optional<buffers_type> mb;
try
{
@ -996,6 +990,7 @@ loop:
{
if(code != close_code::none)
{
// _Fail the WebSocket Connection_
do_fail(code, error::failed, ec);
return bytes_written;
}
@ -1057,25 +1052,28 @@ loop:
{
BOOST_ASSERT(! rd_close_);
rd_close_ = true;
detail::read_close(cr_, cb, code);
close_reason cr;
detail::read_close(cr, cb, code);
if(code != close_code::none)
{
// _Fail the WebSocket Connection_
do_fail(code, error::failed, ec);
return bytes_written;
}
cr_ = cr;
rd_.buf.consume(len);
if(ctrl_cb_)
ctrl_cb_(frame_type::close, cr_.reason);
if(! wr_close_)
{
// _Start the WebSocket Closing Handshake_
do_fail(
cr_.code == close_code::none ?
close_code::normal :
static_cast<close_code>(cr_.code),
error::closed,
ec);
cr.code == close_code::none ?
close_code::normal : cr.code,
error::closed, ec);
return bytes_written;
}
// _Close the WebSocket Connection_
do_fail(close_code::none, error::closed, ec);
return bytes_written;
}
@ -1129,6 +1127,7 @@ loop:
(rd_.remain == 0 && rd_.fh.fin &&
! rd_.utf8.finish()))
{
// _Fail the WebSocket Connection_
do_fail(
close_code::bad_payload,
error::failed,
@ -1163,6 +1162,7 @@ loop:
(rd_.remain == 0 && rd_.fh.fin &&
! rd_.utf8.finish()))
{
// _Fail the WebSocket Connection_
do_fail(
close_code::bad_payload,
error::failed,
@ -1292,6 +1292,7 @@ loop:
rd_.remain == 0 && rd_.fh.fin &&
! rd_.utf8.finish()))
{
// _Fail the WebSocket Connection_
do_fail(
close_code::bad_payload,
error::failed,

View File

@ -57,39 +57,68 @@ read_size_hint(
std::size_t initial_size) const
{
using beast::detail::clamp;
// no permessage-deflate
std::size_t result;
BOOST_ASSERT(initial_size > 0);
if(! pmd_ || (! rd_.done && ! pmd_->rd_set))
{
// fresh message
if(rd_.done)
return initial_size;
// current message is uncompressed
if(rd_.fh.fin)
return clamp(rd_.remain);
if(rd_.done)
{
// first message frame
result = initial_size;
goto done;
}
else if(rd_.fh.fin)
{
// last message frame
BOOST_ASSERT(rd_.remain > 0);
result = clamp(rd_.remain);
goto done;
}
}
return (std::max)(
result = (std::max)(
initial_size, clamp(rd_.remain));
done:
BOOST_ASSERT(result != 0);
return result;
}
template<class NextLayer>
template<class DynamicBuffer, class>
std::size_t
stream<NextLayer>::
read_size_hint(
DynamicBuffer& buffer) const
read_size_hint(DynamicBuffer& buffer) const
{
static_assert(is_dynamic_buffer<DynamicBuffer>::value,
"DynamicBuffer requirements not met");
#if 1
auto const initial_size = (std::min)(
+tcp_frame_size,
buffer.max_size() - buffer.size());
if(initial_size == 0)
return 1; // buffer is full
return read_size_hint(initial_size);
#else
using beast::detail::clamp;
// no permessage-deflate
std::size_t result;
if(! pmd_ || (! rd_.done && ! pmd_->rd_set))
{
// fresh message
// current message is uncompressed
if(rd_.done)
return (std::min)(
{
// first message frame
auto const n = (std::min)(
buffer.max_size(),
(std::max)(+tcp_frame_size,
(std::max)(
+tcp_frame_size,
buffer.capacity() - buffer.size()));
if(n > 0)
return n;
return 1;
}
if(rd_.fh.fin)
{
@ -104,6 +133,10 @@ read_size_hint(
+tcp_frame_size,
clamp(rd_.remain)),
buffer.capacity() - buffer.size()));
done:
BOOST_ASSERT(result != 0);
return result;
#endif
}
template<class NextLayer>

View File

@ -446,10 +446,10 @@ public:
frame and whether or not the permessage-deflate extension is
enabled.
@param initial_size A size representing the caller's desired
buffer size for when there is no information which may be used
to calculate a more specific value. For example, when reading
the first frame header of a message.
@param initial_size A non-zero size representing the caller's
desired buffer size for when there is no information which may
be used to calculate a more specific value. For example, when
reading the first frame header of a message.
*/
std::size_t
read_size_hint(
@ -3816,14 +3816,14 @@ private:
void
do_fail(
close_code code,
std::uint16_t code,
error_code ev,
error_code& ec);
template<class Handler>
void
do_async_fail(
close_code code,
std::uint16_t code,
error_code ev,
Handler&& handler);
};

View File

@ -12,7 +12,9 @@
#include <boost/beast/core/drain_buffer.hpp>
#include <boost/beast/core/flat_buffer.hpp>
#include <boost/beast/core/flat_static_buffer.hpp>
#include <boost/beast/core/multi_buffer.hpp>
#include <boost/beast/core/static_buffer.hpp>
#include <boost/beast/unit_test/suite.hpp>
#include <boost/asio/streambuf.hpp>
@ -37,7 +39,9 @@ public:
{
check<drain_buffer>();
check<flat_buffer>();
check<flat_static_buffer<1024>>();
check<multi_buffer>();
check<static_buffer<1024>>();
check<boost::asio::streambuf>();
}
};

File diff suppressed because it is too large Load Diff

View File

@ -216,27 +216,7 @@ private:
{
if(ec)
return fail("on_close", ec);
do_drain();
}
void
do_drain()
{
ws_.async_read(buffer_,
alloc_.wrap(std::bind(
&connection::on_drain,
shared_from_this(),
ph::_1)));
}
void
on_drain(error_code ec)
{
if(ec)
return fail("on_drain", ec);
do_drain();
}
};
class timer

View File

@ -50,14 +50,23 @@ class stream_impl
template<class Handler, class Buffers>
class read_op_impl;
enum class status
{
ok,
eof,
reset
};
struct state
{
friend class stream;
std::mutex m;
buffer_type b;
std::condition_variable cv;
std::unique_ptr<read_op> op;
boost::asio::io_service& ios;
bool eof = false;
status code = status::ok;
fail_counter* fc = nullptr;
std::size_t nread = 0;
std::size_t nwrite = 0;
@ -75,7 +84,24 @@ class stream_impl
{
}
friend class stream;
~state()
{
BOOST_ASSERT(! op);
}
void
on_write()
{
if(op)
{
std::unique_ptr<read_op> op_ = std::move(op);
op_->operator()();
}
else
{
cv.notify_all();
}
}
};
state s0_;
@ -86,52 +112,57 @@ public:
boost::asio::io_service& ios,
fail_counter* fc)
: s0_(ios, fc)
, s1_(ios, fc)
, s1_(ios, nullptr)
{
}
~stream_impl()
{
BOOST_ASSERT(! s0_.op);
BOOST_ASSERT(! s1_.op);
}
};
template<class Handler, class Buffers>
class stream_impl::read_op_impl : public stream_impl::read_op
{
state& s_;
Buffers b_;
Handler h_;
public:
read_op_impl(state& s,
Buffers const& b, Handler&& h)
: s_(s)
, b_(b)
, h_(std::move(h))
class lambda
{
}
state& s_;
Buffers b_;
Handler h_;
read_op_impl(state& s,
Buffers const& b, Handler const& h)
: s_(s)
, b_(b)
, h_(h)
{
}
public:
lambda(lambda&&) = default;
lambda(lambda const&) = default;
void
operator()() override;
};
template<class Handler, class Buffers>
void
stream_impl::
read_op_impl<Handler, Buffers>::
operator()()
{
using boost::asio::buffer_copy;
using boost::asio::buffer_size;
s_.ios.post(
[&]()
lambda(state& s, Buffers const& b, Handler&& h)
: s_(s)
, b_(b)
, h_(std::move(h))
{
BOOST_ASSERT(s_.op);
}
lambda(state& s, Buffers const& b, Handler const& h)
: s_(s)
, b_(b)
, h_(h)
{
}
void
post()
{
s_.ios.post(std::move(*this));
}
void
operator()()
{
using boost::asio::buffer_copy;
using boost::asio::buffer_size;
std::unique_lock<std::mutex> lock{s_.m};
BOOST_ASSERT(! s_.op);
if(s_.b.size() > 0)
{
auto const bytes_transferred = buffer_copy(
@ -139,7 +170,6 @@ operator()()
s_.b.consume(bytes_transferred);
auto& s = s_;
Handler h{std::move(h_)};
s.op.reset(nullptr);
lock.unlock();
++s.nread;
s.ios.post(bind_handler(std::move(h),
@ -147,17 +177,40 @@ operator()()
}
else
{
BOOST_ASSERT(s_.eof);
BOOST_ASSERT(s_.code != status::ok);
auto& s = s_;
Handler h{std::move(h_)};
s.op.reset(nullptr);
lock.unlock();
++s.nread;
s.ios.post(bind_handler(std::move(h),
boost::asio::error::eof, 0));
error_code ec;
if(s.code == status::eof)
ec = boost::asio::error::eof;
else if(s.code == status::reset)
ec = boost::asio::error::connection_reset;
s.ios.post(bind_handler(std::move(h), ec, 0));
}
});
}
}
};
lambda fn_;
public:
read_op_impl(state& s, Buffers const& b, Handler&& h)
: fn_(s, b, std::move(h))
{
}
read_op_impl(state& s, Buffers const& b, Handler const& h)
: fn_(s, b, h)
{
}
void
operator()() override
{
fn_.post();
}
};
} // detail
@ -175,6 +228,8 @@ operator()()
*/
class stream
{
using status = detail::stream_impl::status;
std::shared_ptr<detail::stream_impl> impl_;
detail::stream_impl::state& in_;
detail::stream_impl::state& out_;
@ -191,10 +246,30 @@ class stream
public:
using buffer_type = flat_buffer;
~stream() = default;
stream(stream&&) = default;
stream& operator=(stream const&) = delete;
/// Destructor
~stream()
{
if(! impl_)
return;
BOOST_ASSERT(! in_.op);
std::unique_lock<std::mutex> lock{out_.m};
if(out_.code == status::ok)
{
out_.code = status::reset;
out_.on_write();
}
lock.unlock();
}
stream(stream&& other)
: impl_(std::move(other.impl_))
, in_(other.in_)
, out_(other.out_)
{
}
/// Constructor
explicit
stream(
@ -298,15 +373,17 @@ public:
buffer_size(*in_.b.data().begin())};
}
/// Clear the buffer holding the input data
/*
/// Appends a string to the pending input data
void
clear()
str(string_view s)
{
in_.b.consume((std::numeric_limits<
std::size_t>::max)());
using boost::asio::buffer;
using boost::asio::buffer_copy;
std::unique_lock<std::mutex> lock{in_.m};
in_.b.commit(buffer_copy(
in_.b.prepare(s.size()),
buffer(s.data(), s.size())));
}
*/
/// Return the number of reads
std::size_t
@ -409,7 +486,9 @@ read_some(MutableBufferSequence const& buffers,
in_.cv.wait(lock,
[&]()
{
return in_.b.size() > 0 || in_.eof;
return
in_.b.size() > 0 ||
in_.code != status::ok;
});
std::size_t bytes_transferred;
if(in_.b.size() > 0)
@ -421,9 +500,12 @@ read_some(MutableBufferSequence const& buffers,
}
else
{
BOOST_ASSERT(in_.eof);
BOOST_ASSERT(in_.code != status::ok);
bytes_transferred = 0;
ec = boost::asio::error::eof;
if(in_.code == status::eof)
ec = boost::asio::error::eof;
else if(in_.code == status::reset)
ec = boost::asio::error::connection_reset;
}
++in_.nread;
return bytes_transferred;
@ -433,7 +515,8 @@ template<class MutableBufferSequence, class ReadHandler>
async_return_type<
ReadHandler, void(error_code, std::size_t)>
stream::
async_read_some(MutableBufferSequence const& buffers,
async_read_some(
MutableBufferSequence const& buffers,
ReadHandler&& handler)
{
static_assert(is_mutable_buffer_sequence<
@ -454,14 +537,7 @@ async_read_some(MutableBufferSequence const& buffers,
}
{
std::unique_lock<std::mutex> lock{in_.m};
if(in_.eof)
{
lock.unlock();
++in_.nread;
in_.ios.post(bind_handler(init.completion_handler,
boost::asio::error::eof, 0));
}
else if(buffer_size(buffers) == 0 ||
if(buffer_size(buffers) == 0 ||
buffer_size(in_.b.data()) > 0)
{
auto const bytes_transferred = buffer_copy(
@ -472,6 +548,18 @@ async_read_some(MutableBufferSequence const& buffers,
in_.ios.post(bind_handler(init.completion_handler,
error_code{}, bytes_transferred));
}
else if(in_.code != status::ok)
{
lock.unlock();
++in_.nread;
error_code ec;
if(in_.code == status::eof)
ec = boost::asio::error::eof;
else if(in_.code == status::reset)
ec = boost::asio::error::connection_reset;
in_.ios.post(bind_handler(
init.completion_handler, ec, 0));
}
else
{
in_.op.reset(new
@ -492,7 +580,7 @@ write_some(ConstBufferSequence const& buffers)
static_assert(is_const_buffer_sequence<
ConstBufferSequence>::value,
"ConstBufferSequence requirements not met");
BOOST_ASSERT(! out_.eof);
BOOST_ASSERT(out_.code == status::ok);
error_code ec;
auto const bytes_transferred =
write_some(buffers, ec);
@ -512,7 +600,7 @@ write_some(
"ConstBufferSequence requirements not met");
using boost::asio::buffer_copy;
using boost::asio::buffer_size;
BOOST_ASSERT(! out_.eof);
BOOST_ASSERT(out_.code == status::ok);
if(in_.fc && in_.fc->fail(ec))
return 0;
auto const n = (std::min)(
@ -521,10 +609,7 @@ write_some(
auto const bytes_transferred =
buffer_copy(out_.b.prepare(n), buffers);
out_.b.commit(bytes_transferred);
if(out_.op)
out_.op.get()->operator()();
else
out_.cv.notify_all();
out_.on_write();
lock.unlock();
++out_.nwrite;
ec.assign(0, ec.category());
@ -543,7 +628,7 @@ async_write_some(ConstBufferSequence const& buffers,
"ConstBufferSequence requirements not met");
using boost::asio::buffer_copy;
using boost::asio::buffer_size;
BOOST_ASSERT(! out_.eof);
BOOST_ASSERT(out_.code == status::ok);
async_completion<WriteHandler,
void(error_code, std::size_t)> init{handler};
if(in_.fc)
@ -559,10 +644,7 @@ async_write_some(ConstBufferSequence const& buffers,
auto const bytes_transferred =
buffer_copy(out_.b.prepare(n), buffers);
out_.b.commit(bytes_transferred);
if(out_.op)
out_.op.get()->operator()();
else
out_.cv.notify_all();
out_.on_write();
lock.unlock();
++out_.nwrite;
in_.ios.post(bind_handler(init.completion_handler,
@ -607,14 +689,12 @@ void
stream::
close()
{
BOOST_ASSERT(! in_.op);
std::lock_guard<std::mutex> lock{out_.m};
if(! out_.eof)
if(out_.code == status::ok)
{
out_.eof = true;
if(out_.op)
out_.op.get()->operator()();
else
out_.cv.notify_all();
out_.code = status::eof;
out_.on_write();
}
}