Add websocket header tests

This commit is contained in:
2022-07-05 11:26:23 +02:00
parent 1fe9c38914
commit f12c21b41b
12 changed files with 205 additions and 49 deletions

View File

@@ -2,12 +2,14 @@ set(headers
src/asio_webserver/clientconnection.h src/asio_webserver/clientconnection.h
src/asio_webserver/responsehandler.h src/asio_webserver/responsehandler.h
src/asio_webserver/webserver.h src/asio_webserver/webserver.h
src/asio_webserver/websocketclientconnection.h
) )
set(sources set(sources
src/asio_webserver/clientconnection.cpp src/asio_webserver/clientconnection.cpp
src/asio_webserver/responsehandler.cpp src/asio_webserver/responsehandler.cpp
src/asio_webserver/webserver.cpp src/asio_webserver/webserver.cpp
src/asio_webserver/websocketclientconnection.cpp
) )
set(dependencies set(dependencies

View File

@@ -1,9 +1,11 @@
HEADERS += \ HEADERS += \
$$PWD/src/asio_webserver/clientconnection.h \ $$PWD/src/asio_webserver/clientconnection.h \
$$PWD/src/asio_webserver/responsehandler.h \ $$PWD/src/asio_webserver/responsehandler.h \
$$PWD/src/asio_webserver/webserver.h $$PWD/src/asio_webserver/webserver.h \
$$PWD/src/asio_webserver/websocketclientconnection.h
SOURCES += \ SOURCES += \
$$PWD/src/asio_webserver/clientconnection.cpp \ $$PWD/src/asio_webserver/clientconnection.cpp \
$$PWD/src/asio_webserver/responsehandler.cpp \ $$PWD/src/asio_webserver/responsehandler.cpp \
$$PWD/src/asio_webserver/webserver.cpp $$PWD/src/asio_webserver/webserver.cpp \
$$PWD/src/asio_webserver/websocketclientconnection.cpp

View File

@@ -14,6 +14,7 @@
// local includes // local includes
#include "webserver.h" #include "webserver.h"
#include "responsehandler.h" #include "responsehandler.h"
#include "websocketclientconnection.h"
namespace { namespace {
constexpr const char * const TAG = "ASIO_WEBSERVER"; constexpr const char * const TAG = "ASIO_WEBSERVER";
@@ -27,7 +28,7 @@ ClientConnection::ClientConnection(Webserver &webserver, asio::ip::tcp::socket s
ESP_LOGI(TAG, "new client (%s:%hi)", ESP_LOGI(TAG, "new client (%s:%hi)",
m_remote_endpoint.address().to_string().c_str(), m_remote_endpoint.port()); m_remote_endpoint.address().to_string().c_str(), m_remote_endpoint.port());
m_webserver.m_clients++; m_webserver.m_httpClients++;
} }
ClientConnection::~ClientConnection() ClientConnection::~ClientConnection()
@@ -35,7 +36,7 @@ ClientConnection::~ClientConnection()
ESP_LOGI(TAG, "client destroyed (%s:%hi)", ESP_LOGI(TAG, "client destroyed (%s:%hi)",
m_remote_endpoint.address().to_string().c_str(), m_remote_endpoint.port()); m_remote_endpoint.address().to_string().c_str(), m_remote_endpoint.port());
m_webserver.m_clients--; m_webserver.m_httpClients--;
} }
void ClientConnection::start() void ClientConnection::start()
@@ -69,7 +70,7 @@ void ClientConnection::upgradeWebsocket()
// ESP_LOGD(TAG, "state changed to RequestLine"); // ESP_LOGD(TAG, "state changed to RequestLine");
m_state = State::WebSocket; m_state = State::WebSocket;
doReadWebSocket(); std::make_shared<WebsocketClientConnection>(m_webserver, std::move(m_socket), std::move(m_responseHandler))->start();
} }
void ClientConnection::doRead() void ClientConnection::doRead()
@@ -79,13 +80,6 @@ void ClientConnection::doRead()
{ readyRead(ec, length); }); { readyRead(ec, length); });
} }
void ClientConnection::doReadWebSocket()
{
m_socket.async_read_some(asio::buffer(m_receiveBuffer, max_length),
[this, self=shared_from_this()](std::error_code ec, std::size_t length)
{ readyReadWebSocket(ec, length); });
}
void ClientConnection::readyRead(std::error_code ec, std::size_t length) void ClientConnection::readyRead(std::error_code ec, std::size_t length)
{ {
if (ec) if (ec)
@@ -160,20 +154,6 @@ requestFinished:
doRead(); doRead();
} }
void ClientConnection::readyReadWebSocket(std::error_code ec, std::size_t length)
{
if (ec)
{
ESP_LOGI(TAG, "error: %i %s (%s:%hi)", ec.value(), ec.message().c_str(),
m_remote_endpoint.address().to_string().c_str(), m_remote_endpoint.port());
return;
}
ESP_LOGV(TAG, "received: %zd \"%.*s\"", length, length, m_receiveBuffer);
doReadWebSocket();
}
bool ClientConnection::readyReadLine(std::string_view line) bool ClientConnection::readyReadLine(std::string_view line)
{ {
switch (m_state) switch (m_state)
@@ -203,6 +183,8 @@ bool ClientConnection::readyReadLine(std::string_view line)
bool ClientConnection::parseRequestLine(std::string_view line) bool ClientConnection::parseRequestLine(std::string_view line)
{ {
// ESP_LOGV(TAG, "%.*s", line.size(), line.data());
if (const auto index = line.find(' '); index == std::string::npos) if (const auto index = line.find(' '); index == std::string::npos)
{ {
ESP_LOGW(TAG, "invalid request line (1): \"%.*s\" (%s:%hi)", line.size(), line.data(), ESP_LOGW(TAG, "invalid request line (1): \"%.*s\" (%s:%hi)", line.size(), line.data(),
@@ -250,6 +232,8 @@ bool ClientConnection::parseRequestLine(std::string_view line)
bool ClientConnection::parseRequestHeader(std::string_view line) bool ClientConnection::parseRequestHeader(std::string_view line)
{ {
// ESP_LOGV(TAG, "%.*s", line.size(), line.data());
if (!line.empty()) if (!line.empty())
{ {
constexpr std::string_view sep{": "}; constexpr std::string_view sep{": "};

View File

@@ -31,9 +31,7 @@ public:
private: private:
void doRead(); void doRead();
void doReadWebSocket();
void readyRead(std::error_code ec, std::size_t length); void readyRead(std::error_code ec, std::size_t length);
void readyReadWebSocket(std::error_code ec, std::size_t length);
bool parseRequestLine(std::string_view line); bool parseRequestLine(std::string_view line);
bool readyReadLine(std::string_view line); bool readyReadLine(std::string_view line);
bool parseRequestHeader(std::string_view line); bool parseRequestHeader(std::string_view line);

View File

@@ -24,7 +24,9 @@ public:
protected: protected:
friend class ClientConnection; friend class ClientConnection;
std::atomic<int> m_clients; friend class WebsocketClientConnection;
std::atomic<int> m_httpClients;
std::atomic<int> m_websocketClients;
private: private:
void doAccept(); void doAccept();

View File

@@ -0,0 +1,127 @@
#include "websocketclientconnection.h"
// system includes
// esp-idf includes
#include <esp_log.h>
// 3rdparty lib includes
#include <strutils.h>
// local includes
#include "webserver.h"
#include "responsehandler.h"
namespace {
constexpr const char * const TAG = "ASIO_WEBSERVER";
} // namespace
WebsocketClientConnection::WebsocketClientConnection(Webserver &webserver, asio::ip::tcp::socket socket, std::unique_ptr<ResponseHandler> &&responseHandler) :
m_webserver{webserver},
m_socket{std::move(socket)},
m_remote_endpoint{m_socket.remote_endpoint()},
m_responseHandler{std::move(responseHandler)}
{
ESP_LOGI(TAG, "new client (%s:%hi)",
m_remote_endpoint.address().to_string().c_str(), m_remote_endpoint.port());
m_webserver.m_websocketClients++;
}
WebsocketClientConnection::~WebsocketClientConnection()
{
ESP_LOGI(TAG, "client destroyed (%s:%hi)",
m_remote_endpoint.address().to_string().c_str(), m_remote_endpoint.port());
m_webserver.m_websocketClients--;
}
void WebsocketClientConnection::start()
{
doReadWebSocket();
}
void WebsocketClientConnection::doReadWebSocket()
{
m_socket.async_read_some(asio::buffer(m_receiveBuffer, max_length),
[this, self=shared_from_this()](std::error_code ec, std::size_t length)
{ readyReadWebSocket(ec, length); });
}
void WebsocketClientConnection::readyReadWebSocket(std::error_code ec, std::size_t length)
{
if (ec)
{
ESP_LOGI(TAG, "error: %i %s (%s:%hi)", ec.value(), ec.message().c_str(),
m_remote_endpoint.address().to_string().c_str(), m_remote_endpoint.port());
return;
}
// ESP_LOGV(TAG, "received: %zd \"%.*s\"", length, length, m_receiveBuffer);
m_parsingBuffer.append({m_receiveBuffer, length});
again:
ESP_LOGI(TAG, "m_parsingBuffer: %s", cpputils::toHexString(m_parsingBuffer).c_str());
if (m_parsingBuffer.empty())
{
doReadWebSocket();
return;
}
struct WebsocketHeader {
uint8_t opcode:4;
uint8_t reserved:3;
uint8_t fin:1;
uint8_t payloadLength:7;
uint8_t mask:1;
};
if (m_parsingBuffer.size() < sizeof(WebsocketHeader))
{
ESP_LOGW(TAG, "buffer smaller than a websocket header");
doReadWebSocket();
return;
}
ESP_LOGI(TAG, "%s%s%s%s %s%s%s%s %s%s%s%s %s%s%s%s",
m_parsingBuffer.data()[0]&128?"1":".", m_parsingBuffer.data()[0]&64?"1":".", m_parsingBuffer.data()[0]&32?"1":".", m_parsingBuffer.data()[0]&16?"1":".",
m_parsingBuffer.data()[0]&8?"1":".", m_parsingBuffer.data()[0]&4?"1":".", m_parsingBuffer.data()[0]&2?"1":".", m_parsingBuffer.data()[0]&1?"1":".",
m_parsingBuffer.data()[1]&128?"1":".", m_parsingBuffer.data()[1]&64?"1":".", m_parsingBuffer.data()[1]&32?"1":".", m_parsingBuffer.data()[1]&16?"1":".",
m_parsingBuffer.data()[1]&8?"1":".", m_parsingBuffer.data()[1]&4?"1":".", m_parsingBuffer.data()[1]&2?"1":".", m_parsingBuffer.data()[1]&1?"1":".");
const WebsocketHeader *hdr = (const WebsocketHeader *)m_parsingBuffer.data();
ESP_LOGI(TAG, "fin=%i reserved=%i opcode=%i mask=%i payloadLength=%i", hdr->fin, hdr->reserved, hdr->opcode, hdr->mask, hdr->payloadLength);
if (hdr->mask)
{
uint32_t mask;
if (m_parsingBuffer.size() < sizeof(WebsocketHeader) + sizeof(mask) + hdr->payloadLength)
{
ESP_LOGW(TAG, "buffer smaller than payload %zd vs %zd", m_parsingBuffer.size(), sizeof(WebsocketHeader) + hdr->payloadLength);
doReadWebSocket();
return;
}
mask = *(const uint32_t *)(m_parsingBuffer.data() + sizeof(WebsocketHeader));
ESP_LOGI(TAG, "mask=%s", cpputils::toHexString({(const char *)&mask, sizeof(mask)}).c_str());
m_parsingBuffer.erase(std::begin(m_parsingBuffer), std::next(std::begin(m_parsingBuffer), sizeof(WebsocketHeader) + sizeof(mask) + hdr->payloadLength));
}
else
{
if (m_parsingBuffer.size() < sizeof(WebsocketHeader) + hdr->payloadLength)
{
ESP_LOGW(TAG, "buffer smaller than payload %zd vs %zd", m_parsingBuffer.size(), sizeof(WebsocketHeader) + hdr->payloadLength);
doReadWebSocket();
return;
}
m_parsingBuffer.erase(std::begin(m_parsingBuffer), std::next(std::begin(m_parsingBuffer), sizeof(WebsocketHeader) + hdr->payloadLength));
}
goto again;
}

View File

@@ -0,0 +1,41 @@
#pragma once
// system includes
// esp-idf includes
#include <asio.hpp>
class Webserver;
class ResponseHandler;
class WebsocketClientConnection : public std::enable_shared_from_this<WebsocketClientConnection>
{
public:
WebsocketClientConnection(Webserver &webserver, asio::ip::tcp::socket socket, std::unique_ptr<ResponseHandler> &&responseHandler);
~WebsocketClientConnection();
Webserver &webserver() { return m_webserver; }
const Webserver &webserver() const { return m_webserver; }
asio::ip::tcp::socket &socket() { return m_socket; }
const asio::ip::tcp::socket &socket() const { return m_socket; }
const asio::ip::tcp::endpoint &remote_endpoint() const { return m_remote_endpoint; }
void start();
private:
void doReadWebSocket();
void readyReadWebSocket(std::error_code ec, std::size_t length);
Webserver &m_webserver;
asio::ip::tcp::socket m_socket;
const asio::ip::tcp::endpoint m_remote_endpoint;
std::unique_ptr<ResponseHandler> m_responseHandler;
static constexpr const std::size_t max_length = 4;
char m_receiveBuffer[max_length];
std::string m_parsingBuffer;
};

View File

@@ -16,14 +16,14 @@ constexpr const char * const TAG = "ASIO_WEBSERVER";
ChunkedResponseHandler::ChunkedResponseHandler(ClientConnection &clientConnection) : ChunkedResponseHandler::ChunkedResponseHandler(ClientConnection &clientConnection) :
m_clientConnection{clientConnection} m_clientConnection{clientConnection}
{ {
ESP_LOGV(TAG, "constructed for (%s:%hi)", // ESP_LOGV(TAG, "constructed for (%s:%hi)",
m_clientConnection.remote_endpoint().address().to_string().c_str(), m_clientConnection.remote_endpoint().port()); // m_clientConnection.remote_endpoint().address().to_string().c_str(), m_clientConnection.remote_endpoint().port());
} }
ChunkedResponseHandler::~ChunkedResponseHandler() ChunkedResponseHandler::~ChunkedResponseHandler()
{ {
ESP_LOGV(TAG, "destructed for (%s:%hi)", // ESP_LOGV(TAG, "destructed for (%s:%hi)",
m_clientConnection.remote_endpoint().address().to_string().c_str(), m_clientConnection.remote_endpoint().port()); // m_clientConnection.remote_endpoint().address().to_string().c_str(), m_clientConnection.remote_endpoint().port());
} }
void ChunkedResponseHandler::requestHeaderReceived(std::string_view key, std::string_view value) void ChunkedResponseHandler::requestHeaderReceived(std::string_view key, std::string_view value)

View File

@@ -16,14 +16,14 @@ constexpr const char * const TAG = "ASIO_WEBSERVER";
DebugResponseHandler::DebugResponseHandler(ClientConnection &clientConnection, std::string_view method, std::string_view path, std::string_view protocol) : DebugResponseHandler::DebugResponseHandler(ClientConnection &clientConnection, std::string_view method, std::string_view path, std::string_view protocol) :
m_clientConnection{clientConnection}, m_method{method}, m_path{path}, m_protocol{protocol} m_clientConnection{clientConnection}, m_method{method}, m_path{path}, m_protocol{protocol}
{ {
ESP_LOGI(TAG, "constructed for %.*s %.*s (%s:%hi)", m_method.size(), m_method.data(), path.size(), path.data(), // ESP_LOGV(TAG, "constructed for %.*s %.*s (%s:%hi)", m_method.size(), m_method.data(), path.size(), path.data(),
m_clientConnection.remote_endpoint().address().to_string().c_str(), m_clientConnection.remote_endpoint().port()); // m_clientConnection.remote_endpoint().address().to_string().c_str(), m_clientConnection.remote_endpoint().port());
} }
DebugResponseHandler::~DebugResponseHandler() DebugResponseHandler::~DebugResponseHandler()
{ {
ESP_LOGI(TAG, "destructed for %.*s %.*s (%s:%hi)", m_method.size(), m_method.data(), m_path.size(), m_path.data(), // ESP_LOGV(TAG, "destructed for %.*s %.*s (%s:%hi)", m_method.size(), m_method.data(), m_path.size(), m_path.data(),
m_clientConnection.remote_endpoint().address().to_string().c_str(), m_clientConnection.remote_endpoint().port()); // m_clientConnection.remote_endpoint().address().to_string().c_str(), m_clientConnection.remote_endpoint().port());
} }
void DebugResponseHandler::requestHeaderReceived(std::string_view key, std::string_view value) void DebugResponseHandler::requestHeaderReceived(std::string_view key, std::string_view value)

View File

@@ -17,14 +17,14 @@ ErrorResponseHandler::ErrorResponseHandler(ClientConnection &clientConnection, s
m_clientConnection{clientConnection}, m_clientConnection{clientConnection},
m_path{path} m_path{path}
{ {
ESP_LOGI(TAG, "constructed for %.*s (%s:%hi)", path.size(), path.data(), // ESP_LOGV(TAG, "constructed for %.*s (%s:%hi)", path.size(), path.data(),
m_clientConnection.remote_endpoint().address().to_string().c_str(), m_clientConnection.remote_endpoint().port()); // m_clientConnection.remote_endpoint().address().to_string().c_str(), m_clientConnection.remote_endpoint().port());
} }
ErrorResponseHandler::~ErrorResponseHandler() ErrorResponseHandler::~ErrorResponseHandler()
{ {
ESP_LOGI(TAG, "destructed for %.*s (%s:%hi)", m_path.size(), m_path.data(), // ESP_LOGV(TAG, "destructed for %.*s (%s:%hi)", m_path.size(), m_path.data(),
m_clientConnection.remote_endpoint().address().to_string().c_str(), m_clientConnection.remote_endpoint().port()); // m_clientConnection.remote_endpoint().address().to_string().c_str(), m_clientConnection.remote_endpoint().port());
} }
void ErrorResponseHandler::requestHeaderReceived(std::string_view key, std::string_view value) void ErrorResponseHandler::requestHeaderReceived(std::string_view key, std::string_view value)

View File

@@ -16,14 +16,14 @@ constexpr const char * const TAG = "ASIO_WEBSERVER";
RootResponseHandler::RootResponseHandler(ClientConnection &clientConnection) : RootResponseHandler::RootResponseHandler(ClientConnection &clientConnection) :
m_clientConnection{clientConnection} m_clientConnection{clientConnection}
{ {
ESP_LOGI(TAG, "constructed for (%s:%hi)", // ESP_LOGV(TAG, "constructed for (%s:%hi)",
m_clientConnection.remote_endpoint().address().to_string().c_str(), m_clientConnection.remote_endpoint().port()); // m_clientConnection.remote_endpoint().address().to_string().c_str(), m_clientConnection.remote_endpoint().port());
} }
RootResponseHandler::~RootResponseHandler() RootResponseHandler::~RootResponseHandler()
{ {
ESP_LOGI(TAG, "destructed for (%s:%hi)", // ESP_LOGV(TAG, "destructed for (%s:%hi)",
m_clientConnection.remote_endpoint().address().to_string().c_str(), m_clientConnection.remote_endpoint().port()); // m_clientConnection.remote_endpoint().address().to_string().c_str(), m_clientConnection.remote_endpoint().port());
} }
void RootResponseHandler::requestHeaderReceived(std::string_view key, std::string_view value) void RootResponseHandler::requestHeaderReceived(std::string_view key, std::string_view value)

View File

@@ -132,14 +132,14 @@ constexpr std::string_view html{R"END(
WebsocketResponseHandler::WebsocketResponseHandler(ClientConnection &clientConnection) : WebsocketResponseHandler::WebsocketResponseHandler(ClientConnection &clientConnection) :
m_clientConnection{clientConnection} m_clientConnection{clientConnection}
{ {
ESP_LOGI(TAG, "constructed for (%s:%hi)", // ESP_LOGV(TAG, "constructed for (%s:%hi)",
m_clientConnection.remote_endpoint().address().to_string().c_str(), m_clientConnection.remote_endpoint().port()); // m_clientConnection.remote_endpoint().address().to_string().c_str(), m_clientConnection.remote_endpoint().port());
} }
WebsocketResponseHandler::~WebsocketResponseHandler() WebsocketResponseHandler::~WebsocketResponseHandler()
{ {
ESP_LOGI(TAG, "destructed for (%s:%hi)", // ESP_LOGV(TAG, "destructed for (%s:%hi)",
m_clientConnection.remote_endpoint().address().to_string().c_str(), m_clientConnection.remote_endpoint().port()); // m_clientConnection.remote_endpoint().address().to_string().c_str(), m_clientConnection.remote_endpoint().port());
} }
void WebsocketResponseHandler::requestHeaderReceived(std::string_view key, std::string_view value) void WebsocketResponseHandler::requestHeaderReceived(std::string_view key, std::string_view value)