diff --git a/CHANGELOG.md b/CHANGELOG.md index 600c5c60..b0b66b6f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ * Refactor static_string * Refactor base64 +* Use static_string for WebSocket handshakes -------------------------------------------------------------------------------- diff --git a/include/beast/websocket/detail/hybi13.hpp b/include/beast/websocket/detail/hybi13.hpp index ec09918b..d2dc3fb3 100644 --- a/include/beast/websocket/detail/hybi13.hpp +++ b/include/beast/websocket/detail/hybi13.hpp @@ -8,8 +8,10 @@ #ifndef BEAST_WEBSOCKET_DETAIL_HYBI13_HPP #define BEAST_WEBSOCKET_DETAIL_HYBI13_HPP +#include #include #include +#include #include #include #include @@ -20,11 +22,17 @@ namespace beast { namespace websocket { namespace detail { +using sec_ws_key_type = static_string< + beast::detail::base64::encoded_size(16)>; + +using sec_ws_accept_type = static_string< + beast::detail::base64::encoded_size(20)>; + template -std::string -make_sec_ws_key(Gen& g) +void +make_sec_ws_key(sec_ws_key_type& key, Gen& g) { - std::array a; + char a[16]; for(int i = 0; i < 16; i += 4) { auto const v = g(); @@ -33,24 +41,27 @@ make_sec_ws_key(Gen& g) a[i+2] = (v >> 16) & 0xff; a[i+3] = (v >> 24) & 0xff; } - return beast::detail::base64_encode( - a.data(), a.size()); + key.resize(key.max_size()); + key.resize(beast::detail::base64::encode( + key.data(), &a[0], 16)); } template -std::string -make_sec_ws_accept(boost::string_ref const& key) +void +make_sec_ws_accept(sec_ws_accept_type& accept, + boost::string_ref key) { - std::string s(key.data(), key.size()); - s += "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + BOOST_ASSERT(key.size() <= sec_ws_key_type::max_size_n); + static_string m(key); + m.append("258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); beast::detail::sha1_context ctx; beast::detail::init(ctx); - beast::detail::update(ctx, s.data(), s.size()); - std::array digest; - beast::detail::finish(ctx, digest.data()); - return beast::detail::base64_encode( - digest.data(), digest.size()); + beast::detail::update(ctx, m.data(), m.size()); + char digest[beast::detail::sha1_context::digest_size]; + beast::detail::finish(ctx, &digest[0]); + accept.resize(accept.max_size()); + accept.resize(beast::detail::base64::encode( + accept.data(), &digest[0], sizeof(digest))); } } // detail diff --git a/include/beast/websocket/impl/handshake.ipp b/include/beast/websocket/impl/handshake.ipp index 534ca2e1..c4c5e142 100644 --- a/include/beast/websocket/impl/handshake.ipp +++ b/include/beast/websocket/impl/handshake.ipp @@ -34,7 +34,7 @@ class stream::handshake_op bool cont; stream& ws; response_type* res_p; - std::string key; + detail::sec_ws_key_type key; request_type req; response_type res; int state = 0; diff --git a/include/beast/websocket/impl/stream.ipp b/include/beast/websocket/impl/stream.ipp index 717e9aef..618a9a4a 100644 --- a/include/beast/websocket/impl/stream.ipp +++ b/include/beast/websocket/impl/stream.ipp @@ -134,7 +134,7 @@ do_handshake(response_type* res_p, { response_type res; reset(); - std::string key; + detail::sec_ws_key_type key; { auto const req = build_request( key, host, resource, decorator); @@ -155,7 +155,7 @@ template template request_type stream:: -build_request(std::string& key, +build_request(detail::sec_ws_key_type& key, boost::string_ref const& host, boost::string_ref const& resource, Decorator const& decorator) @@ -167,7 +167,7 @@ build_request(std::string& key, req.fields.insert("Host", host); req.fields.insert("Upgrade", "websocket"); req.fields.insert("Connection", "upgrade"); - key = detail::make_sec_ws_key(maskgen_); + detail::make_sec_ws_key(key, maskgen_); req.fields.insert("Sec-WebSocket-Key", key); req.fields.insert("Sec-WebSocket-Version", "13"); if(pmd_opts_.client_enable) @@ -186,6 +186,7 @@ build_request(std::string& key, req.fields, config); } decorator(req); + // VFALCO Use static_string here if(! req.fields.exists("User-Agent")) req.fields.insert("User-Agent", std::string("Beast/") + BEAST_VERSION_STRING); @@ -203,6 +204,7 @@ build_response(request_type const& req, [&decorator](response_type& res) { decorator(res); + // VFALCO Use static_string here if(! res.fields.exists("Server")) res.fields.insert("Server", std::string("Beast/") + @@ -235,6 +237,9 @@ build_response(request_type const& req, return err("Missing Sec-WebSocket-Key"); if(! http::token_list{req.fields["Upgrade"]}.exists("websocket")) return err("Missing websocket Upgrade token"); + auto const key = req.fields["Sec-WebSocket-Key"]; + if(key.size() > detail::sec_ws_key_type::max_size_n) + return err("Invalid Sec-WebSocket-Key"); { auto const version = req.fields["Sec-WebSocket-Version"]; @@ -255,6 +260,7 @@ build_response(request_type const& req, return res; } } + response_type res; { detail::pmd_offer offer; @@ -269,10 +275,9 @@ build_response(request_type const& req, res.fields.insert("Upgrade", "websocket"); res.fields.insert("Connection", "upgrade"); { - auto const key = - req.fields["Sec-WebSocket-Key"]; - res.fields.insert("Sec-WebSocket-Accept", - detail::make_sec_ws_accept(key)); + detail::sec_ws_accept_type accept; + detail::make_sec_ws_accept(accept, key); + res.fields.insert("Sec-WebSocket-Accept", accept); } decorate(res); return res; @@ -282,7 +287,7 @@ template void stream:: do_response(http::response_header const& res, - boost::string_ref const& key, error_code& ec) + detail::sec_ws_key_type const& key, error_code& ec) { // VFALCO Review these error codes auto fail = [&]{ ec = error::response_failed; }; @@ -296,8 +301,10 @@ do_response(http::response_header const& res, return fail(); if(! res.fields.exists("Sec-WebSocket-Accept")) return fail(); - if(res.fields["Sec-WebSocket-Accept"] != - detail::make_sec_ws_accept(key)) + detail::sec_ws_accept_type accept; + detail::make_sec_ws_accept(accept, key); + if(accept.compare( + res.fields["Sec-WebSocket-Accept"]) != 0) return fail(); detail::pmd_offer offer; pmd_read(offer, res.fields); diff --git a/include/beast/websocket/stream.hpp b/include/beast/websocket/stream.hpp index ae9b5ddb..c415563f 100644 --- a/include/beast/websocket/stream.hpp +++ b/include/beast/websocket/stream.hpp @@ -10,6 +10,7 @@ #include #include +#include #include #include #include @@ -2981,7 +2982,7 @@ private: template request_type - build_request(std::string& key, + build_request(detail::sec_ws_key_type& key, boost::string_ref const& host, boost::string_ref const& resource, Decorator const& decorator); @@ -2993,7 +2994,7 @@ private: void do_response(http::response_header const& resp, - boost::string_ref const& key, error_code& ec); + detail::sec_ws_key_type const& key, error_code& ec); }; } // websocket