Fix race when write suspends

This commit is contained in:
Vinnie Falco
2017-02-24 16:02:59 -05:00
parent 9554bd105d
commit a22e7056d5
2 changed files with 65 additions and 32 deletions

View File

@@ -4,6 +4,7 @@ WebSocket
* Fix race in pings during reads * Fix race in pings during reads
* Fix race in close frames during reads * Fix race in close frames during reads
* Fix race when write suspends
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------

View File

@@ -42,7 +42,7 @@ class stream<NextLayer>::write_frame_op
detail::prepared_key key; detail::prepared_key key;
std::uint64_t remain; std::uint64_t remain;
int state = 0; int state = 0;
int entry; int entry_state;
data(Handler& handler_, stream<NextLayer>& ws_, data(Handler& handler_, stream<NextLayer>& ws_,
bool fin_, Buffers const& bs) bool fin_, Buffers const& bs)
@@ -179,40 +179,44 @@ operator()(error_code ec,
d.fh.mask = d.fh.mask =
d.ws.role_ == detail::role_type::client; d.ws.role_ == detail::role_type::client;
// entry_state determines which algorithm
// we will use to send. If we suspend, we
// will transition to entry_state + 1 on
// the resume.
if(d.ws.wr_.compress) if(d.ws.wr_.compress)
{ {
d.entry = do_deflate; d.entry_state = do_deflate;
} }
else if(! d.fh.mask) else if(! d.fh.mask)
{ {
if(! d.ws.wr_.autofrag) if(! d.ws.wr_.autofrag)
{ {
d.entry = do_nomask_nofrag; d.entry_state = do_nomask_nofrag;
} }
else else
{ {
BOOST_ASSERT(d.ws.wr_.buf_size != 0); BOOST_ASSERT(d.ws.wr_.buf_size != 0);
d.remain = buffer_size(d.cb); d.remain = buffer_size(d.cb);
if(d.remain > d.ws.wr_.buf_size) if(d.remain > d.ws.wr_.buf_size)
d.entry = do_nomask_frag; d.entry_state = do_nomask_frag;
else else
d.entry = do_nomask_nofrag; d.entry_state = do_nomask_nofrag;
} }
} }
else else
{ {
if(! d.ws.wr_.autofrag) if(! d.ws.wr_.autofrag)
{ {
d.entry = do_mask_nofrag; d.entry_state = do_mask_nofrag;
} }
else else
{ {
BOOST_ASSERT(d.ws.wr_.buf_size != 0); BOOST_ASSERT(d.ws.wr_.buf_size != 0);
d.remain = buffer_size(d.cb); d.remain = buffer_size(d.cb);
if(d.remain > d.ws.wr_.buf_size) if(d.remain > d.ws.wr_.buf_size)
d.entry = do_mask_frag; d.entry_state = do_mask_frag;
else else
d.entry = do_mask_nofrag; d.entry_state = do_mask_nofrag;
} }
} }
d.state = do_maybe_suspend; d.state = do_maybe_suspend;
@@ -221,7 +225,13 @@ operator()(error_code ec,
//---------------------------------------------------------------------- //----------------------------------------------------------------------
case do_nomask_nofrag: case do_nomask_nofrag:
BOOST_ASSERT(! d.ws.wr_block_);
d.ws.wr_block_ = &d;
// [[fallthrough]]
case do_nomask_nofrag + 1:
{ {
BOOST_ASSERT(d.ws.wr_block_ == &d);
d.fh.fin = d.fin; d.fh.fin = d.fin;
d.fh.len = buffer_size(d.cb); d.fh.len = buffer_size(d.cb);
detail::write<static_streambuf>( detail::write<static_streambuf>(
@@ -229,8 +239,6 @@ operator()(error_code ec,
d.ws.wr_.cont = ! d.fin; d.ws.wr_.cont = ! d.fin;
// Send frame // Send frame
d.state = do_upcall; d.state = do_upcall;
BOOST_ASSERT(! d.ws.wr_block_);
d.ws.wr_block_ = &d;
boost::asio::async_write(d.ws.stream_, boost::asio::async_write(d.ws.stream_,
buffer_cat(d.fh_buf.data(), d.cb), buffer_cat(d.fh_buf.data(), d.cb),
std::move(*this)); std::move(*this));
@@ -240,7 +248,13 @@ operator()(error_code ec,
//---------------------------------------------------------------------- //----------------------------------------------------------------------
case do_nomask_frag: case do_nomask_frag:
BOOST_ASSERT(! d.ws.wr_block_);
d.ws.wr_block_ = &d;
// [[fallthrough]]
case do_nomask_frag + 1:
{ {
BOOST_ASSERT(d.ws.wr_block_ == &d);
auto const n = clamp( auto const n = clamp(
d.remain, d.ws.wr_.buf_size); d.remain, d.ws.wr_.buf_size);
d.remain -= n; d.remain -= n;
@@ -251,9 +265,7 @@ operator()(error_code ec,
d.ws.wr_.cont = ! d.fin; d.ws.wr_.cont = ! d.fin;
// Send frame // Send frame
d.state = d.remain == 0 ? d.state = d.remain == 0 ?
do_upcall : do_nomask_frag + 1; do_upcall : do_nomask_frag + 2;
BOOST_ASSERT(! d.ws.wr_block_);
d.ws.wr_block_ = &d;
boost::asio::async_write(d.ws.stream_, boost::asio::async_write(d.ws.stream_,
buffer_cat(d.fh_buf.data(), buffer_cat(d.fh_buf.data(),
prepare_buffers(n, d.cb)), prepare_buffers(n, d.cb)),
@@ -261,7 +273,7 @@ operator()(error_code ec,
return; return;
} }
case do_nomask_frag + 1: case do_nomask_frag + 2:
d.cb.consume( d.cb.consume(
bytes_transferred - d.fh_buf.size()); bytes_transferred - d.fh_buf.size());
d.fh_buf.reset(); d.fh_buf.reset();
@@ -275,13 +287,19 @@ operator()(error_code ec,
std::move(*this)); std::move(*this));
return; return;
} }
d.state = d.entry; d.state = d.entry_state;
break; break;
//---------------------------------------------------------------------- //----------------------------------------------------------------------
case do_mask_nofrag: case do_mask_nofrag:
BOOST_ASSERT(! d.ws.wr_block_);
d.ws.wr_block_ = &d;
// [[fallthrough]]
case do_mask_nofrag + 1:
{ {
BOOST_ASSERT(d.ws.wr_block_ == &d);
d.remain = buffer_size(d.cb); d.remain = buffer_size(d.cb);
d.fh.fin = d.fin; d.fh.fin = d.fin;
d.fh.len = d.remain; d.fh.len = d.remain;
@@ -299,16 +317,14 @@ operator()(error_code ec,
d.ws.wr_.cont = ! d.fin; d.ws.wr_.cont = ! d.fin;
// Send frame header and partial payload // Send frame header and partial payload
d.state = d.remain == 0 ? d.state = d.remain == 0 ?
do_upcall : do_mask_nofrag + 1; do_upcall : do_mask_nofrag + 2;
BOOST_ASSERT(! d.ws.wr_block_);
d.ws.wr_block_ = &d;
boost::asio::async_write(d.ws.stream_, boost::asio::async_write(d.ws.stream_,
buffer_cat(d.fh_buf.data(), b), buffer_cat(d.fh_buf.data(), b),
std::move(*this)); std::move(*this));
return; return;
} }
case do_mask_nofrag + 1: case do_mask_nofrag + 2:
{ {
d.cb.consume(d.ws.wr_.buf_size); d.cb.consume(d.ws.wr_.buf_size);
auto const n = auto const n =
@@ -329,7 +345,13 @@ operator()(error_code ec,
//---------------------------------------------------------------------- //----------------------------------------------------------------------
case do_mask_frag: case do_mask_frag:
BOOST_ASSERT(! d.ws.wr_block_);
d.ws.wr_block_ = &d;
// [[fallthrough]]
case do_mask_frag + 1:
{ {
BOOST_ASSERT(d.ws.wr_block_ == &d);
auto const n = clamp( auto const n = clamp(
d.remain, d.ws.wr_.buf_size); d.remain, d.ws.wr_.buf_size);
d.remain -= n; d.remain -= n;
@@ -346,16 +368,14 @@ operator()(error_code ec,
d.ws.wr_.cont = ! d.fin; d.ws.wr_.cont = ! d.fin;
// Send frame // Send frame
d.state = d.remain == 0 ? d.state = d.remain == 0 ?
do_upcall : do_mask_frag + 1; do_upcall : do_mask_frag + 2;
BOOST_ASSERT(! d.ws.wr_block_);
d.ws.wr_block_ = &d;
boost::asio::async_write(d.ws.stream_, boost::asio::async_write(d.ws.stream_,
buffer_cat(d.fh_buf.data(), b), buffer_cat(d.fh_buf.data(), b),
std::move(*this)); std::move(*this));
return; return;
} }
case do_mask_frag + 1: case do_mask_frag + 2:
d.cb.consume( d.cb.consume(
bytes_transferred - d.fh_buf.size()); bytes_transferred - d.fh_buf.size());
d.fh_buf.reset(); d.fh_buf.reset();
@@ -369,13 +389,19 @@ operator()(error_code ec,
std::move(*this)); std::move(*this));
return; return;
} }
d.state = d.entry; d.state = d.entry_state;
break; break;
//---------------------------------------------------------------------- //----------------------------------------------------------------------
case do_deflate: case do_deflate:
BOOST_ASSERT(! d.ws.wr_block_);
d.ws.wr_block_ = &d;
// [[fallthrough]]
case do_deflate + 1:
{ {
BOOST_ASSERT(d.ws.wr_block_ == &d);
auto b = buffer(d.ws.wr_.buf.get(), auto b = buffer(d.ws.wr_.buf.get(),
d.ws.wr_.buf_size); d.ws.wr_.buf_size);
auto const more = detail::deflate( auto const more = detail::deflate(
@@ -414,16 +440,14 @@ operator()(error_code ec,
d.ws.wr_.cont = ! d.fin; d.ws.wr_.cont = ! d.fin;
// Send frame // Send frame
d.state = more ? d.state = more ?
do_deflate + 1 : do_deflate + 2; do_deflate + 2 : do_deflate + 3;
BOOST_ASSERT(! d.ws.wr_block_);
d.ws.wr_block_ = &d;
boost::asio::async_write(d.ws.stream_, boost::asio::async_write(d.ws.stream_,
buffer_cat(fh_buf.data(), b), buffer_cat(fh_buf.data(), b),
std::move(*this)); std::move(*this));
return; return;
} }
case do_deflate + 1: case do_deflate + 2:
d.fh.op = opcode::cont; d.fh.op = opcode::cont;
d.fh.rsv1 = false; d.fh.rsv1 = false;
BOOST_ASSERT(d.ws.wr_block_ == &d); BOOST_ASSERT(d.ws.wr_block_ == &d);
@@ -435,10 +459,10 @@ operator()(error_code ec,
std::move(*this)); std::move(*this));
return; return;
} }
d.state = d.entry; d.state = d.entry_state;
break; break;
case do_deflate + 2: case do_deflate + 3:
if(d.fh.fin && ( if(d.fh.fin && (
(d.ws.role_ == detail::role_type::client && (d.ws.role_ == detail::role_type::client &&
d.ws.pmd_config_.client_no_context_takeover) || d.ws.pmd_config_.client_no_context_takeover) ||
@@ -468,24 +492,32 @@ operator()(error_code ec,
boost::asio::error::operation_aborted)); boost::asio::error::operation_aborted));
return; return;
} }
d.state = d.entry; d.state = d.entry_state;
break; break;
} }
case do_maybe_suspend + 1: case do_maybe_suspend + 1:
BOOST_ASSERT(! d.ws.wr_block_);
d.ws.wr_block_ = &d;
d.state = do_maybe_suspend + 2; d.state = do_maybe_suspend + 2;
// The current context is safe but might not be
// the same as the one for this operation (since
// we are being called from a write operation).
// Call post to make sure we are invoked the same
// way as the final handler for this operation.
d.ws.get_io_service().post(bind_handler( d.ws.get_io_service().post(bind_handler(
std::move(*this), ec)); std::move(*this), ec));
return; return;
case do_maybe_suspend + 2: case do_maybe_suspend + 2:
BOOST_ASSERT(d.ws.wr_block_ == &d);
if(d.ws.failed_ || d.ws.wr_close_) if(d.ws.failed_ || d.ws.wr_close_)
{ {
// call handler // call handler
ec = boost::asio::error::operation_aborted; ec = boost::asio::error::operation_aborted;
goto upcall; goto upcall;
} }
d.state = d.entry; d.state = d.entry_state + 1;
break; break;
//---------------------------------------------------------------------- //----------------------------------------------------------------------