diff --git a/CHANGELOG.md b/CHANGELOG.md index 8e60874b..adac21b1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ WebSocket: * Refactor read_op + fail_op * Websocket close will automatically drain * Autobahn|Testsuite fixes +* Tidy up utf8_checker and tests -------------------------------------------------------------------------------- diff --git a/include/boost/beast/websocket/detail/utf8_checker.hpp b/include/boost/beast/websocket/detail/utf8_checker.hpp index 79bd5d84..3c8bb799 100644 --- a/include/boost/beast/websocket/detail/utf8_checker.hpp +++ b/include/boost/beast/websocket/detail/utf8_checker.hpp @@ -30,9 +30,9 @@ namespace detail { template class utf8_checker_t { - std::size_t need_ = 0; - std::uint8_t* p_ = have_; - std::uint8_t have_[4]; + std::size_t need_ = 0; // chars we need to finish the code point + std::uint8_t* p_ = cp_; // current position in temp buffer + std::uint8_t cp_[4]; // a temp buffer for the code point public: /** Prepare to process text as valid utf8 @@ -67,7 +67,7 @@ utf8_checker_t<_>:: reset() { need_ = 0; - p_ = have_; + p_ = cp_; } template @@ -105,21 +105,21 @@ write(std::uint8_t const* in, std::size_t size) auto const valid = [](std::uint8_t const*& p) { - if (p[0] < 128) + if(p[0] < 128) { ++p; return true; } - if ((p[0] & 0x60) == 0x40) + if((p[0] & 0x60) == 0x40) { - if ((p[1] & 0xc0) != 0x80) + if((p[1] & 0xc0) != 0x80) return false; p += 2; return true; } - if ((p[0] & 0xf0) == 0xe0) + if((p[0] & 0xf0) == 0xe0) { - if ((p[1] & 0xc0) != 0x80 || + if((p[1] & 0xc0) != 0x80 || (p[2] & 0xc0) != 0x80 || (p[0] == 224 && p[1] < 160) || (p[0] == 237 && p[1] > 159)) @@ -127,9 +127,9 @@ write(std::uint8_t const* in, std::size_t size) p += 3; return true; } - if ((p[0] & 0xf8) == 0xf0) + if((p[0] & 0xf8) == 0xf0) { - if (p[0] > 244 || + if(p[0] > 244 || (p[1] & 0xc0) != 0x80 || (p[2] & 0xc0) != 0x80 || (p[3] & 0xc0) != 0x80 || @@ -144,26 +144,26 @@ write(std::uint8_t const* in, std::size_t size) auto const valid_have = [&]() { - if ((have_[0] & 0x60) == 0x40) - return have_[0] <= 223; - if ((have_[0] & 0xf0) == 0xe0) + if((cp_[0] & 0x60) == 0x40) + return cp_[0] <= 223; + if((cp_[0] & 0xf0) == 0xe0) { - if (p_ - have_ > 1 && - ((have_[1] & 0xc0) != 0x80 || - (have_[0] == 224 && have_[1] < 160) || - (have_[0] == 237 && have_[1] > 159))) + if(p_ - cp_ > 1 && + ((cp_[1] & 0xc0) != 0x80 || + (cp_[0] == 224 && cp_[1] < 160) || + (cp_[0] == 237 && cp_[1] > 159))) return false; return true; } - if ((have_[0] & 0xf8) == 0xf0) + if((cp_[0] & 0xf8) == 0xf0) { - auto const n = p_ - have_; - if (n > 2 && (have_[2] & 0xc0) != 0x80) + auto const n = p_ - cp_; + if(n > 2 && (cp_[2] & 0xc0) != 0x80) return false; - if (n > 1 && - ((have_[1] & 0xc0) != 0x80 || - (have_[0] == 240 && have_[1] < 144) || - (have_[0] == 244 && have_[1] > 143))) + if(n > 1 && + ((cp_[1] & 0xc0) != 0x80 || + (cp_[0] == 240 && cp_[1] < 144) || + (cp_[0] == 244 && cp_[1] > 143))) return false; } return true; @@ -171,51 +171,69 @@ write(std::uint8_t const* in, std::size_t size) auto const needed = [](std::uint8_t const v) { - if (v < 128) + if(v < 128) return 1; - if (v < 194) + if(v < 194) return 0; - if (v < 224) + if(v < 224) return 2; - if (v < 240) + if(v < 240) return 3; - if (v < 245) + if(v < 245) return 4; return 0; }; auto const end = in + size; - if (need_ > 0) + + // Finish up any incomplete code point + if(need_ > 0) { + // Calculate what we have auto n = (std::min)(size, need_); size -= n; need_ -= n; + + // Add characters to the code point while(n--) *p_++ = *in++; + BOOST_ASSERT(p_ <= cp_ + 5); + + // Still incomplete? if(need_ > 0) { + // Incomplete code point BOOST_ASSERT(in == end); + + // Do partial validation on the incomplete + // code point, this is called "Fail fast" + // in Autobahn|Testsuite parlance. return valid_have(); } - std::uint8_t const* p = &have_[0]; - if (! valid(p)) + + // Complete code point, validate it + std::uint8_t const* p = &cp_[0]; + if(! valid(p)) return false; - p_ = have_; + p_ = cp_; } if(size <= sizeof(std::size_t)) goto slow; - // align in to sizeof(std::size_t) boundary + // Align `in` to sizeof(std::size_t) boundary { auto const in0 = in; auto last = reinterpret_cast( ((reinterpret_cast(in) + sizeof(std::size_t) - 1) / sizeof(std::size_t)) * sizeof(std::size_t)); + + // Check one character at a time for low-ASCII while(in < last) { if(*in & 0x80) { + // Not low-ASCII so switch to slow loop size = size - (in - in0); goto slow; } @@ -224,7 +242,7 @@ write(std::uint8_t const* in, std::size_t size) size = size - (in - in0); } - // fast loop + // Fast loop: Process 4 or 8 low-ASCII characters at a time { auto const in0 = in; auto last = in + size - 7; @@ -246,6 +264,7 @@ write(std::uint8_t const* in, std::size_t size) } in += sizeof(std::size_t); } + // There's at least one more full code point left last += 4; while(in < last) if(! valid(in)) @@ -253,8 +272,8 @@ write(std::uint8_t const* in, std::size_t size) goto tail; } - // slow loop: one code point at a time slow: + // Slow loop: Full validation on one code point at a time { auto last = in + size - 3; while(in < last) @@ -263,24 +282,45 @@ slow: } tail: + // Handle the remaining bytes. The last + // characters could split a code point so + // we save the partial code point for later. + // + // On entry to the loop, `in` points to the + // beginning of a code point. + // for(;;) { + // Number of chars left auto n = end - in; if(! n) break; + + // Chars we need to finish this code point auto const need = needed(*in); - if (need == 0) + if(need == 0) return false; if(need <= n) { + // Check a whole code point if(! valid(in)) return false; } else { + // Calculate how many chars we need + // to finish this partial code point need_ = need - n; + + // Save the partial code point while(n--) *p_++ = *in++; + BOOST_ASSERT(in == end); + BOOST_ASSERT(p_ <= cp_ + 5); + + // Do partial validation on the incomplete + // code point, this is called "Fail fast" + // in Autobahn|Testsuite parlance. return valid_have(); } } diff --git a/test/beast/websocket/utf8_checker.cpp b/test/beast/websocket/utf8_checker.cpp index 7d669d3a..952e3c42 100644 --- a/test/beast/websocket/utf8_checker.cpp +++ b/test/beast/websocket/utf8_checker.cpp @@ -42,14 +42,12 @@ public: BEAST_EXPECT(utf8.finish()); // Invalid range 128-193 - for(auto it = std::next(buf.begin(), 128); - it != std::next(buf.begin(), 194); ++it) - BEAST_EXPECT(! utf8.write(&(*it), 1)); + for(unsigned char c = 128; c < 194; ++c) + BEAST_EXPECT(! utf8.write(&c, 1)); // Invalid range 245-255 - for(auto it = std::next(buf.begin(), 245); - it != buf.end(); ++it) - BEAST_EXPECT(! utf8.write(&(*it), 1)); + for(unsigned char c = 245; c; ++c) + BEAST_EXPECT(! utf8.write(&c, 1)); // Invalid sequence std::fill(buf.begin(), buf.end(), '\xff'); @@ -79,6 +77,7 @@ public: // Second byte invalid range 0-127 buf[1] = static_cast(j); BEAST_EXPECT(! utf8.write(buf, 2)); + utf8.reset(); } for(auto j = 192; j <= 255; ++j) @@ -86,6 +85,7 @@ public: // Second byte invalid range 192-255 buf[1] = static_cast(j); BEAST_EXPECT(! utf8.write(buf, 2)); + utf8.reset(); } // Segmented sequence second byte invalid @@ -134,6 +134,7 @@ public: // Second byte invalid range 0-127 or 0-159 buf[1] = static_cast(l); BEAST_EXPECT(! utf8.write(buf, 3)); + utf8.reset(); if (l > 127) { // Segmented sequence second byte invalid @@ -149,7 +150,8 @@ public: { // Second byte invalid range 160-255 or 192-255 buf[1] = static_cast(l); - BEAST_EXPECT(!utf8.write(buf, 3)); + BEAST_EXPECT(! utf8.write(buf, 3)); + utf8.reset(); if (l > 159) { // Segmented sequence second byte invalid @@ -166,6 +168,7 @@ public: // Third byte invalid range 0-127 buf[2] = static_cast(k); BEAST_EXPECT(! utf8.write(buf, 3)); + utf8.reset(); } for(auto k = 192; k <= 255; ++k) @@ -173,6 +176,7 @@ public: // Third byte invalid range 192-255 buf[2] = static_cast(k); BEAST_EXPECT(! utf8.write(buf, 3)); + utf8.reset(); } // Segmented sequence third byte invalid @@ -186,6 +190,7 @@ public: // Second byte invalid range 0-127 or 0-159 buf[1] = static_cast(j); BEAST_EXPECT(! utf8.write(buf, 3)); + utf8.reset(); } for(auto j = e + 1; j <= 255; ++j) @@ -193,6 +198,7 @@ public: // Second byte invalid range 160-255 or 192-255 buf[1] = static_cast(j); BEAST_EXPECT(! utf8.write(buf, 3)); + utf8.reset(); } // Segmented sequence second byte invalid @@ -251,6 +257,7 @@ public: { buf[1] = static_cast(r); BEAST_EXPECT(! utf8.write(buf, 4)); + utf8.reset(); if (r > 127) { // Segmented sequence second byte invalid @@ -267,6 +274,7 @@ public: { buf[1] = static_cast(r); BEAST_EXPECT(! utf8.write(buf, 4)); + utf8.reset(); // Segmented sequence second byte invalid BEAST_EXPECT(! utf8.write(buf, 2)); utf8.reset(); @@ -280,6 +288,7 @@ public: { buf[3] = static_cast(r); BEAST_EXPECT(! utf8.write(buf, 4)); + utf8.reset(); } // Segmented sequence fourth byte invalid @@ -293,6 +302,7 @@ public: { buf[2] = static_cast(r); BEAST_EXPECT(! utf8.write(buf, 4)); + utf8.reset(); } // Segmented sequence third byte invalid @@ -306,6 +316,7 @@ public: { buf[1] = static_cast(r); BEAST_EXPECT(! utf8.write(buf, 4)); + utf8.reset(); } // Second byte invalid range 144-255 or 192-255 @@ -313,6 +324,7 @@ public: { buf[1] = static_cast(r); BEAST_EXPECT(! utf8.write(buf, 4)); + utf8.reset(); } // Segmented sequence second byte invalid @@ -326,6 +338,7 @@ public: { buf[0] = static_cast(r); BEAST_EXPECT(! utf8.write(buf, 4)); + utf8.reset(); } }