From a22e7056d5a29062dee7ceb19abf73ed5ff92b8a Mon Sep 17 00:00:00 2001 From: Vinnie Falco Date: Fri, 24 Feb 2017 16:02:59 -0500 Subject: [PATCH] Fix race when write suspends --- CHANGELOG.md | 1 + include/beast/websocket/impl/write.ipp | 96 +++++++++++++++++--------- 2 files changed, 65 insertions(+), 32 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index cee21d97..39e7d487 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ WebSocket * Fix race in pings during reads * Fix race in close frames during reads +* Fix race when write suspends -------------------------------------------------------------------------------- diff --git a/include/beast/websocket/impl/write.ipp b/include/beast/websocket/impl/write.ipp index a278bbd5..5f6a46e2 100644 --- a/include/beast/websocket/impl/write.ipp +++ b/include/beast/websocket/impl/write.ipp @@ -42,7 +42,7 @@ class stream::write_frame_op detail::prepared_key key; std::uint64_t remain; int state = 0; - int entry; + int entry_state; data(Handler& handler_, stream& ws_, bool fin_, Buffers const& bs) @@ -179,40 +179,44 @@ operator()(error_code ec, d.fh.mask = d.ws.role_ == detail::role_type::client; + // entry_state determines which algorithm + // we will use to send. If we suspend, we + // will transition to entry_state + 1 on + // the resume. if(d.ws.wr_.compress) { - d.entry = do_deflate; + d.entry_state = do_deflate; } else if(! d.fh.mask) { if(! d.ws.wr_.autofrag) { - d.entry = do_nomask_nofrag; + d.entry_state = do_nomask_nofrag; } else { BOOST_ASSERT(d.ws.wr_.buf_size != 0); d.remain = buffer_size(d.cb); if(d.remain > d.ws.wr_.buf_size) - d.entry = do_nomask_frag; + d.entry_state = do_nomask_frag; else - d.entry = do_nomask_nofrag; + d.entry_state = do_nomask_nofrag; } } else { if(! d.ws.wr_.autofrag) { - d.entry = do_mask_nofrag; + d.entry_state = do_mask_nofrag; } else { BOOST_ASSERT(d.ws.wr_.buf_size != 0); d.remain = buffer_size(d.cb); if(d.remain > d.ws.wr_.buf_size) - d.entry = do_mask_frag; + d.entry_state = do_mask_frag; else - d.entry = do_mask_nofrag; + d.entry_state = do_mask_nofrag; } } d.state = do_maybe_suspend; @@ -221,7 +225,13 @@ operator()(error_code ec, //---------------------------------------------------------------------- case do_nomask_nofrag: + BOOST_ASSERT(! d.ws.wr_block_); + d.ws.wr_block_ = &d; + // [[fallthrough]] + + case do_nomask_nofrag + 1: { + BOOST_ASSERT(d.ws.wr_block_ == &d); d.fh.fin = d.fin; d.fh.len = buffer_size(d.cb); detail::write( @@ -229,8 +239,6 @@ operator()(error_code ec, d.ws.wr_.cont = ! d.fin; // Send frame d.state = do_upcall; - BOOST_ASSERT(! d.ws.wr_block_); - d.ws.wr_block_ = &d; boost::asio::async_write(d.ws.stream_, buffer_cat(d.fh_buf.data(), d.cb), std::move(*this)); @@ -240,7 +248,13 @@ operator()(error_code ec, //---------------------------------------------------------------------- case do_nomask_frag: + BOOST_ASSERT(! d.ws.wr_block_); + d.ws.wr_block_ = &d; + // [[fallthrough]] + + case do_nomask_frag + 1: { + BOOST_ASSERT(d.ws.wr_block_ == &d); auto const n = clamp( d.remain, d.ws.wr_.buf_size); d.remain -= n; @@ -251,9 +265,7 @@ operator()(error_code ec, d.ws.wr_.cont = ! d.fin; // Send frame d.state = d.remain == 0 ? - do_upcall : do_nomask_frag + 1; - BOOST_ASSERT(! d.ws.wr_block_); - d.ws.wr_block_ = &d; + do_upcall : do_nomask_frag + 2; boost::asio::async_write(d.ws.stream_, buffer_cat(d.fh_buf.data(), prepare_buffers(n, d.cb)), @@ -261,7 +273,7 @@ operator()(error_code ec, return; } - case do_nomask_frag + 1: + case do_nomask_frag + 2: d.cb.consume( bytes_transferred - d.fh_buf.size()); d.fh_buf.reset(); @@ -275,13 +287,19 @@ operator()(error_code ec, std::move(*this)); return; } - d.state = d.entry; + d.state = d.entry_state; break; //---------------------------------------------------------------------- case do_mask_nofrag: + BOOST_ASSERT(! d.ws.wr_block_); + d.ws.wr_block_ = &d; + // [[fallthrough]] + + case do_mask_nofrag + 1: { + BOOST_ASSERT(d.ws.wr_block_ == &d); d.remain = buffer_size(d.cb); d.fh.fin = d.fin; d.fh.len = d.remain; @@ -299,16 +317,14 @@ operator()(error_code ec, d.ws.wr_.cont = ! d.fin; // Send frame header and partial payload d.state = d.remain == 0 ? - do_upcall : do_mask_nofrag + 1; - BOOST_ASSERT(! d.ws.wr_block_); - d.ws.wr_block_ = &d; + do_upcall : do_mask_nofrag + 2; boost::asio::async_write(d.ws.stream_, buffer_cat(d.fh_buf.data(), b), std::move(*this)); return; } - case do_mask_nofrag + 1: + case do_mask_nofrag + 2: { d.cb.consume(d.ws.wr_.buf_size); auto const n = @@ -329,7 +345,13 @@ operator()(error_code ec, //---------------------------------------------------------------------- case do_mask_frag: + BOOST_ASSERT(! d.ws.wr_block_); + d.ws.wr_block_ = &d; + // [[fallthrough]] + + case do_mask_frag + 1: { + BOOST_ASSERT(d.ws.wr_block_ == &d); auto const n = clamp( d.remain, d.ws.wr_.buf_size); d.remain -= n; @@ -346,16 +368,14 @@ operator()(error_code ec, d.ws.wr_.cont = ! d.fin; // Send frame d.state = d.remain == 0 ? - do_upcall : do_mask_frag + 1; - BOOST_ASSERT(! d.ws.wr_block_); - d.ws.wr_block_ = &d; + do_upcall : do_mask_frag + 2; boost::asio::async_write(d.ws.stream_, buffer_cat(d.fh_buf.data(), b), std::move(*this)); return; } - case do_mask_frag + 1: + case do_mask_frag + 2: d.cb.consume( bytes_transferred - d.fh_buf.size()); d.fh_buf.reset(); @@ -369,13 +389,19 @@ operator()(error_code ec, std::move(*this)); return; } - d.state = d.entry; + d.state = d.entry_state; break; //---------------------------------------------------------------------- case do_deflate: + BOOST_ASSERT(! d.ws.wr_block_); + d.ws.wr_block_ = &d; + // [[fallthrough]] + + case do_deflate + 1: { + BOOST_ASSERT(d.ws.wr_block_ == &d); auto b = buffer(d.ws.wr_.buf.get(), d.ws.wr_.buf_size); auto const more = detail::deflate( @@ -414,16 +440,14 @@ operator()(error_code ec, d.ws.wr_.cont = ! d.fin; // Send frame d.state = more ? - do_deflate + 1 : do_deflate + 2; - BOOST_ASSERT(! d.ws.wr_block_); - d.ws.wr_block_ = &d; + do_deflate + 2 : do_deflate + 3; boost::asio::async_write(d.ws.stream_, buffer_cat(fh_buf.data(), b), std::move(*this)); return; } - case do_deflate + 1: + case do_deflate + 2: d.fh.op = opcode::cont; d.fh.rsv1 = false; BOOST_ASSERT(d.ws.wr_block_ == &d); @@ -435,10 +459,10 @@ operator()(error_code ec, std::move(*this)); return; } - d.state = d.entry; + d.state = d.entry_state; break; - case do_deflate + 2: + case do_deflate + 3: if(d.fh.fin && ( (d.ws.role_ == detail::role_type::client && d.ws.pmd_config_.client_no_context_takeover) || @@ -468,24 +492,32 @@ operator()(error_code ec, boost::asio::error::operation_aborted)); return; } - d.state = d.entry; + d.state = d.entry_state; break; } case do_maybe_suspend + 1: + BOOST_ASSERT(! d.ws.wr_block_); + d.ws.wr_block_ = &d; d.state = do_maybe_suspend + 2; + // The current context is safe but might not be + // the same as the one for this operation (since + // we are being called from a write operation). + // Call post to make sure we are invoked the same + // way as the final handler for this operation. d.ws.get_io_service().post(bind_handler( std::move(*this), ec)); return; case do_maybe_suspend + 2: + BOOST_ASSERT(d.ws.wr_block_ == &d); if(d.ws.failed_ || d.ws.wr_close_) { // call handler ec = boost::asio::error::operation_aborted; goto upcall; } - d.state = d.entry; + d.state = d.entry_state + 1; break; //----------------------------------------------------------------------