diff --git a/CMakeLists.txt b/CMakeLists.txt index 5d3f9fc..ff74f58 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -3,6 +3,7 @@ set(headers src/asio_web/responsehandler.h src/asio_web/webserver.h src/asio_web/websocketclientconnection.h + src/asio_web/websocketstream.h ) set(sources @@ -10,6 +11,7 @@ set(sources src/asio_web/responsehandler.cpp src/asio_web/webserver.cpp src/asio_web/websocketclientconnection.cpp + src/asio_web/websocketstream.cpp ) set(dependencies diff --git a/asio_web_src.pri b/asio_web_src.pri index b0bfe9e..0c64de3 100644 --- a/asio_web_src.pri +++ b/asio_web_src.pri @@ -2,10 +2,12 @@ HEADERS += \ $$PWD/src/asio_web/clientconnection.h \ $$PWD/src/asio_web/responsehandler.h \ $$PWD/src/asio_web/webserver.h \ - $$PWD/src/asio_web/websocketclientconnection.h + $$PWD/src/asio_web/websocketclientconnection.h \ + $$PWD/src/asio_web/websocketstream.h SOURCES += \ $$PWD/src/asio_web/clientconnection.cpp \ $$PWD/src/asio_web/responsehandler.cpp \ $$PWD/src/asio_web/webserver.cpp \ - $$PWD/src/asio_web/websocketclientconnection.cpp + $$PWD/src/asio_web/websocketclientconnection.cpp \ + $$PWD/src/asio_web/websocketstream.cpp diff --git a/src/asio_web/clientconnection.cpp b/src/asio_web/clientconnection.cpp index 6264259..b86f8eb 100644 --- a/src/asio_web/clientconnection.cpp +++ b/src/asio_web/clientconnection.cpp @@ -70,7 +70,7 @@ void ClientConnection::upgradeWebsocket() // ESP_LOGD(TAG, "state changed to RequestLine"); m_state = State::WebSocket; - std::make_shared(m_webserver, std::move(m_socket), std::move(m_responseHandler))->start(); + std::make_shared(m_webserver, std::move(m_socket), std::move(m_parsingBuffer), std::move(m_responseHandler))->start(); } void ClientConnection::doRead() diff --git a/src/asio_web/clientconnection.h b/src/asio_web/clientconnection.h index 56a92dc..c76e474 100644 --- a/src/asio_web/clientconnection.h +++ b/src/asio_web/clientconnection.h @@ -32,8 +32,8 @@ public: private: void doRead(); void readyRead(std::error_code ec, std::size_t length); - bool parseRequestLine(std::string_view line); bool readyReadLine(std::string_view line); + bool parseRequestLine(std::string_view line); bool parseRequestHeader(std::string_view line); Webserver &m_webserver; diff --git a/src/asio_web/websocketclientconnection.cpp b/src/asio_web/websocketclientconnection.cpp index 3dd7e29..d173176 100644 --- a/src/asio_web/websocketclientconnection.cpp +++ b/src/asio_web/websocketclientconnection.cpp @@ -7,19 +7,23 @@ // 3rdparty lib includes #include +#include // local includes #include "webserver.h" #include "responsehandler.h" +#include "websocketstream.h" namespace { constexpr const char * const TAG = "ASIO_WEBSERVER"; } // namespace -WebsocketClientConnection::WebsocketClientConnection(Webserver &webserver, asio::ip::tcp::socket socket, std::unique_ptr &&responseHandler) : +WebsocketClientConnection::WebsocketClientConnection(Webserver &webserver, asio::ip::tcp::socket socket, + std::string &&parsingBuffer, std::unique_ptr &&responseHandler) : m_webserver{webserver}, m_socket{std::move(socket)}, m_remote_endpoint{m_socket.remote_endpoint()}, + m_parsingBuffer{std::move(parsingBuffer)}, m_responseHandler{std::move(responseHandler)} { ESP_LOGI(TAG, "new client (%s:%hi)", @@ -62,7 +66,7 @@ void WebsocketClientConnection::readyReadWebSocket(std::error_code ec, std::size m_parsingBuffer.append({m_receiveBuffer, length}); again: - ESP_LOGI(TAG, "m_parsingBuffer: %s", cpputils::toHexString(m_parsingBuffer).c_str()); +// ESP_LOGV(TAG, "m_parsingBuffer: %s", cpputils::toHexString(m_parsingBuffer).c_str()); if (m_parsingBuffer.empty()) { @@ -70,13 +74,7 @@ again: return; } - struct WebsocketHeader { - uint8_t opcode:4; - uint8_t reserved:3; - uint8_t fin:1; - uint8_t payloadLength:7; - uint8_t mask:1; - }; + static_assert(sizeof(WebsocketHeader) == 2); if (m_parsingBuffer.size() < sizeof(WebsocketHeader)) { @@ -85,43 +83,131 @@ again: return; } - ESP_LOGI(TAG, "%s%s%s%s %s%s%s%s %s%s%s%s %s%s%s%s", - m_parsingBuffer.data()[0]&128?"1":".", m_parsingBuffer.data()[0]&64?"1":".", m_parsingBuffer.data()[0]&32?"1":".", m_parsingBuffer.data()[0]&16?"1":".", - m_parsingBuffer.data()[0]&8?"1":".", m_parsingBuffer.data()[0]&4?"1":".", m_parsingBuffer.data()[0]&2?"1":".", m_parsingBuffer.data()[0]&1?"1":".", - m_parsingBuffer.data()[1]&128?"1":".", m_parsingBuffer.data()[1]&64?"1":".", m_parsingBuffer.data()[1]&32?"1":".", m_parsingBuffer.data()[1]&16?"1":".", - m_parsingBuffer.data()[1]&8?"1":".", m_parsingBuffer.data()[1]&4?"1":".", m_parsingBuffer.data()[1]&2?"1":".", m_parsingBuffer.data()[1]&1?"1":"."); +// ESP_LOGV(TAG, "%s%s%s%s %s%s%s%s %s%s%s%s %s%s%s%s", +// m_parsingBuffer.data()[0]&128?"1":".", m_parsingBuffer.data()[0]&64?"1":".", m_parsingBuffer.data()[0]&32?"1":".", m_parsingBuffer.data()[0]&16?"1":".", +// m_parsingBuffer.data()[0]&8?"1":".", m_parsingBuffer.data()[0]&4?"1":".", m_parsingBuffer.data()[0]&2?"1":".", m_parsingBuffer.data()[0]&1?"1":".", +// m_parsingBuffer.data()[1]&128?"1":".", m_parsingBuffer.data()[1]&64?"1":".", m_parsingBuffer.data()[1]&32?"1":".", m_parsingBuffer.data()[1]&16?"1":".", +// m_parsingBuffer.data()[1]&8?"1":".", m_parsingBuffer.data()[1]&4?"1":".", m_parsingBuffer.data()[1]&2?"1":".", m_parsingBuffer.data()[1]&1?"1":"."); - const WebsocketHeader *hdr = (const WebsocketHeader *)m_parsingBuffer.data(); + auto iter = std::begin(m_parsingBuffer); - ESP_LOGI(TAG, "fin=%i reserved=%i opcode=%i mask=%i payloadLength=%i", hdr->fin, hdr->reserved, hdr->opcode, hdr->mask, hdr->payloadLength); + const WebsocketHeader &hdr = *(const WebsocketHeader *)(&*iter); + std::advance(iter, sizeof(WebsocketHeader)); - if (hdr->mask) +// ESP_LOGV(TAG, "fin=%i reserved=%i opcode=%i mask=%i payloadLength=%i", hdr.fin, hdr.reserved, hdr.opcode, hdr.mask, hdr.payloadLength); + + uint64_t payloadLength = hdr.payloadLength; + + if (hdr.payloadLength == 126) { - uint32_t mask; - - if (m_parsingBuffer.size() < sizeof(WebsocketHeader) + sizeof(mask) + hdr->payloadLength) + if (std::distance(iter, std::end(m_parsingBuffer)) < sizeof(uint16_t)) { - ESP_LOGW(TAG, "buffer smaller than payload %zd vs %zd", m_parsingBuffer.size(), sizeof(WebsocketHeader) + hdr->payloadLength); + ESP_LOGW(TAG, "buffer smaller than uint32_t payloadLength"); doReadWebSocket(); return; } - mask = *(const uint32_t *)(m_parsingBuffer.data() + sizeof(WebsocketHeader)); - ESP_LOGI(TAG, "mask=%s", cpputils::toHexString({(const char *)&mask, sizeof(mask)}).c_str()); + payloadLength = __builtin_bswap16(*(const uint16_t *)(&*iter)); + std::advance(iter, sizeof(uint16_t)); - m_parsingBuffer.erase(std::begin(m_parsingBuffer), std::next(std::begin(m_parsingBuffer), sizeof(WebsocketHeader) + sizeof(mask) + hdr->payloadLength)); +// ESP_LOGV(TAG, "16bit payloadLength: %u", payloadLength); } - else + else if (hdr.payloadLength == 127) { - if (m_parsingBuffer.size() < sizeof(WebsocketHeader) + hdr->payloadLength) + if (std::distance(iter, std::end(m_parsingBuffer)) < sizeof(uint64_t)) { - ESP_LOGW(TAG, "buffer smaller than payload %zd vs %zd", m_parsingBuffer.size(), sizeof(WebsocketHeader) + hdr->payloadLength); + ESP_LOGW(TAG, "buffer smaller than uint64_t payloadLength"); doReadWebSocket(); return; } - m_parsingBuffer.erase(std::begin(m_parsingBuffer), std::next(std::begin(m_parsingBuffer), sizeof(WebsocketHeader) + hdr->payloadLength)); + payloadLength = *(const uint64_t *)(&*iter); + std::advance(iter, sizeof(uint64_t)); + + ESP_LOGI(TAG, "64bit payloadLength: %u", payloadLength); } + if (hdr.mask) + { + if (std::distance(iter, std::end(m_parsingBuffer)) < sizeof(uint32_t)) + { + ESP_LOGW(TAG, "buffer smaller than uint32_t mask"); + doReadWebSocket(); + return; + } + + union { + uint32_t mask; + uint8_t maskArr[4]; + }; + mask = *(const uint32_t *)(&*iter); + std::advance(iter, sizeof(uint32_t)); + + if (std::distance(iter, std::end(m_parsingBuffer)) < payloadLength) + { + ESP_LOGW(TAG, "masked buffer smaller payloadLength"); + doReadWebSocket(); + return; + } + + auto iter2 = std::begin(maskArr); + for (auto iter3 = iter; + iter3 != std::end(m_parsingBuffer) && iter3 != std::next(iter, payloadLength); + iter3++) + { + *iter3 ^= *(iter2++); + if (iter2 == std::end(maskArr)) + iter2 = std::begin(maskArr); + } + } + else if (std::distance(iter, std::end(m_parsingBuffer)) < payloadLength) + { + ESP_LOGW(TAG, "buffer smaller payloadLength"); + doReadWebSocket(); + return; + } + + ESP_LOGI(TAG, "remaining: %zd %lu", std::distance(iter, std::end(m_parsingBuffer)), payloadLength); + + ESP_LOGI(TAG, "payload: %.*s", payloadLength, &*iter); + + std::advance(iter, payloadLength); + m_parsingBuffer.erase(std::begin(m_parsingBuffer), iter); + + sendMessage(true, 0, 1, false, fmt::format("received {}", payloadLength)); + goto again; } + +void WebsocketClientConnection::sendMessage(bool fin, uint8_t reserved, uint8_t opcode, bool mask, std::string_view payload) +{ + m_sendBuffer.clear(); + m_sendBuffer.resize(2); + { + auto iter = std::begin(m_sendBuffer); + WebsocketHeader &hdr = *(WebsocketHeader *)(&*iter); + hdr.fin = fin; + hdr.reserved = reserved; + hdr.opcode = opcode; + hdr.mask = mask; + hdr.payloadLength = payload.size(); + } + m_sendBuffer.append(payload); + + asio::async_write(m_socket, + asio::buffer(m_sendBuffer.data(), m_sendBuffer.size()), + [this](std::error_code ec, std::size_t length) + { onMessageSent(ec, length); }); +} + +void WebsocketClientConnection::onMessageSent(std::error_code ec, std::size_t length) +{ + if (ec) + { + ESP_LOGW(TAG, "error: %i (%s:%hi)", ec.value(), + m_remote_endpoint.address().to_string().c_str(), m_remote_endpoint.port()); + return; + } + + ESP_LOGI(TAG, "length=%zd", length); +} diff --git a/src/asio_web/websocketclientconnection.h b/src/asio_web/websocketclientconnection.h index 439a8de..3477eee 100644 --- a/src/asio_web/websocketclientconnection.h +++ b/src/asio_web/websocketclientconnection.h @@ -11,7 +11,7 @@ class ResponseHandler; class WebsocketClientConnection : public std::enable_shared_from_this { public: - WebsocketClientConnection(Webserver &webserver, asio::ip::tcp::socket socket, std::unique_ptr &&responseHandler); + WebsocketClientConnection(Webserver &webserver, asio::ip::tcp::socket socket, std::string &&parsingBuffer, std::unique_ptr &&responseHandler); ~WebsocketClientConnection(); Webserver &webserver() { return m_webserver; } @@ -28,14 +28,19 @@ private: void doReadWebSocket(); void readyReadWebSocket(std::error_code ec, std::size_t length); + void sendMessage(bool fin, uint8_t reserved, uint8_t opcode, bool mask, std::string_view payload); + void onMessageSent(std::error_code ec, std::size_t length); + Webserver &m_webserver; asio::ip::tcp::socket m_socket; const asio::ip::tcp::endpoint m_remote_endpoint; - std::unique_ptr m_responseHandler; - - static constexpr const std::size_t max_length = 4; + static constexpr const std::size_t max_length = 1024; char m_receiveBuffer[max_length]; std::string m_parsingBuffer; + + std::unique_ptr m_responseHandler; + + std::string m_sendBuffer; }; diff --git a/src/asio_web/websocketstream.cpp b/src/asio_web/websocketstream.cpp new file mode 100644 index 0000000..65e5dc6 --- /dev/null +++ b/src/asio_web/websocketstream.cpp @@ -0,0 +1 @@ +#include "websocketstream.h" diff --git a/src/asio_web/websocketstream.h b/src/asio_web/websocketstream.h new file mode 100644 index 0000000..0a88941 --- /dev/null +++ b/src/asio_web/websocketstream.h @@ -0,0 +1,13 @@ +#pragma once + +#include + +#pragma pack(push,1) +struct WebsocketHeader { + uint8_t opcode:4; + uint8_t reserved:3; + bool fin:1; + uint8_t payloadLength:7; + bool mask:1; +}; +#pragma pack(pop) diff --git a/test/asio_web_tests.pro b/test/asio_web_tests.pro index 96363e9..20788e5 100644 --- a/test/asio_web_tests.pro +++ b/test/asio_web_tests.pro @@ -72,7 +72,10 @@ equals(CLONE_FMT, 1) { SUBDIRS += \ asio_web.pro \ - webserver_example + webserver_example \ + websocket_client_example sub-webserver_example.depends += sub-asio_web-pro webserver_example.depends += sub-asio_web-pro +sub-websocket_client_example.depends += sub-asio_web-pro +websocket_client_example.depends += sub-asio_web-pro diff --git a/test/webserver_example/main.cpp b/test/webserver_example/main.cpp index 1341cfd..e8f3c0f 100644 --- a/test/webserver_example/main.cpp +++ b/test/webserver_example/main.cpp @@ -5,7 +5,6 @@ #include // 3rdparty lib includes -#include #include // local includes diff --git a/test/webserver_example/websocketresponsehandler.cpp b/test/webserver_example/websocketresponsehandler.cpp index a71f329..632af38 100644 --- a/test/webserver_example/websocketresponsehandler.cpp +++ b/test/webserver_example/websocketresponsehandler.cpp @@ -267,7 +267,7 @@ void WebsocketResponseHandler::writtenHtml(std::error_code ec, std::size_t lengt return; } - ESP_LOGI(TAG, "expected=%zd actual=%zd for (%s:%hi)", m_response.size(), length, + ESP_LOGI(TAG, "expected=%zd actual=%zd for (%s:%hi)", html.size(), length, m_clientConnection.remote_endpoint().address().to_string().c_str(), m_clientConnection.remote_endpoint().port()); m_clientConnection.responseFinished(ec); diff --git a/test/websocket_client_example/main.cpp b/test/websocket_client_example/main.cpp new file mode 100644 index 0000000..978e907 --- /dev/null +++ b/test/websocket_client_example/main.cpp @@ -0,0 +1,39 @@ +#include + +// esp-idf includes +#include +#include +#include +#include + +// 3rdparty lib includes +#include + +// local includes +#include "websocketclient.h" + +int main(int argc, char *argv[]) +{ + CPP_UNUSED(argc) + CPP_UNUSED(argv) + + qSetMessagePattern(QStringLiteral("%{time dd.MM.yyyy HH:mm:ss.zzz} " + "[" + "%{if-debug}D%{endif}" + "%{if-info}I%{endif}" + "%{if-warning}W%{endif}" + "%{if-critical}C%{endif}" + "%{if-fatal}F%{endif}" + "] " + "%{function}(): " + "%{message}")); + + asio::io_context io_context; + + WebsocketClient c{io_context}; + c.start(); + + ESP_LOGI(TAG, "running mainloop"); + + io_context.run(); +} diff --git a/test/websocket_client_example/websocket_client_example.pro b/test/websocket_client_example/websocket_client_example.pro new file mode 100644 index 0000000..d2900af --- /dev/null +++ b/test/websocket_client_example/websocket_client_example.pro @@ -0,0 +1,28 @@ +TEMPLATE = app + +QT += core + +CONFIG += c++latest + +HEADERS += \ + websocketclient.h + +SOURCES += \ + main.cpp \ + websocketclient.cpp + +unix: TARGET=websocket_client_example.bin +DESTDIR=$${OUT_PWD}/.. +INCLUDEPATH += $$PWD/.. + +include(../paths.pri) + +include(../dependencies.pri) + +unix: { + LIBS += -Wl,-rpath=\\\$$ORIGIN +} +LIBS += -L$${OUT_PWD}/.. +LIBS += -lasio_web + +LIBS += -lssl -lcrypto diff --git a/test/websocket_client_example/websocketclient.cpp b/test/websocket_client_example/websocketclient.cpp new file mode 100644 index 0000000..99f0142 --- /dev/null +++ b/test/websocket_client_example/websocketclient.cpp @@ -0,0 +1,502 @@ +#include "websocketclient.h" + +// esp-idf includes +#include + +// 3rdparty lib includes +#include +#include + +// local includes +#include "asio_web/websocketstream.h" + +namespace { + +const std::string_view request = "GET /charger/99999999 HTTP/1.1\r\n" + "Host: localhost\r\n" + "Connection: Upgrade\r\n" + "Upgrade: websocket\r\n" + "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" + "Sec-WebSocket-Version: 13\r\n" + "\r\n"; + +} // namespace + +WebsocketClient::WebsocketClient(asio::io_context &io_context) : + m_resolver{io_context}, + //m_socket{io_context, m_sslCtx}, + m_socket{io_context} +{ + //m_socket.set_verify_mode(asio::ssl::verify_none); +} + +void WebsocketClient::start() +{ + resolve(); +} + +void WebsocketClient::resolve() +{ + ESP_LOGI(TAG, "called"); + +// m_resolver.async_resolve("backend.com", "8086", +// [this](const std::error_code &error, asio::ip::tcp::resolver::iterator iterator){ +// onResolved(error, iterator); +// }); + m_resolver.async_resolve("localhost", "1234", + [this](const std::error_code &error, asio::ip::tcp::resolver::iterator iterator){ + onResolved(error, iterator); + }); +} + +void WebsocketClient::onResolved(const std::error_code &error, asio::ip::tcp::resolver::iterator iterator) +{ + if (error) + { + ESP_LOGW(TAG, "Resolving failed: %s", error.message().c_str()); + return; + } + + connect(iterator); +} + +void WebsocketClient::connect(const asio::ip::tcp::resolver::iterator &endpoints) +{ + ESP_LOGI(TAG, "called"); + + asio::async_connect(m_socket.lowest_layer(), endpoints, + [this](const std::error_code & error, const asio::ip::tcp::resolver::iterator &) { + onConnected(error); + }); +} + +void WebsocketClient::onConnected(const std::error_code &error) +{ + if (error) + { + ESP_LOGW(TAG, "Connect failed: %s", error.message().c_str()); + return; + } + +// handshake(); + send_request(); +} + +void WebsocketClient::handshake() +{ + ESP_LOGI(TAG, "called"); + +// m_socket.async_handshake(asio::ssl::stream_base::client, +// [this](const std::error_code &error) { +// onHandshaked(error); +// }); +} + +void WebsocketClient::onHandshaked(const std::error_code &error) +{ + if (error) + { + ESP_LOGW(TAG, "Handshake failed: %s", error.message().c_str()); + return; + } + + send_request(); +} + +void WebsocketClient::send_request() +{ + ESP_LOGI(TAG, "called %.*s", request.size(), request.data()); + + m_state = State::Request; + + asio::async_write(m_socket, + asio::buffer(request.data(), request.size()), + [this](const std::error_code &error, std::size_t length) { + onSentRequest(error, length); + }); +} + +void WebsocketClient::onSentRequest(const std::error_code &error, std::size_t length) +{ + if (error) + { + ESP_LOGW(TAG, "Write failed: %s", error.message().c_str()); + return; + } + + ESP_LOGI(TAG, "called %zd (%zd)", length, request.size()); + + m_state = State::ResponseLine; + + receive_response(); +} + +void WebsocketClient::receive_response() +{ + ESP_LOGI(TAG, "called"); + + m_socket.async_read_some(asio::buffer(m_receiveBuffer, std::size(m_receiveBuffer)), + [this](const std::error_code &error, std::size_t length) { + onReceivedResponse(error, length); + }); +} + +void WebsocketClient::onReceivedResponse(const std::error_code &error, std::size_t length) +{ + if (error) + { + ESP_LOGI(TAG, "Read failed: %s", error.message().c_str()); + return; + } + + ESP_LOGI(TAG, "received %.*s", length, m_receiveBuffer); + m_parsingBuffer.append(m_receiveBuffer, length); + + bool shouldDoRead{true}; + + while (true) + { + constexpr std::string_view newLine{"\r\n"}; + const auto index = m_parsingBuffer.find(newLine.data(), 0, newLine.size()); + if (index == std::string::npos) + break; + + std::string line{m_parsingBuffer.data(), index}; + +// ESP_LOGD(TAG, "line: %zd \"%.*s\"", line.size(), line.size(), line.data()); + + m_parsingBuffer.erase(std::begin(m_parsingBuffer), std::next(std::begin(m_parsingBuffer), line.size() + newLine.size())); + + if (!readyReadLine(line)) + shouldDoRead = false; + if (m_state == State::WebSocket) + break; + } + + if (shouldDoRead) + { + if (m_state == State::WebSocket) + doReadWebSocket(); + else + receive_response(); + } +} + +bool WebsocketClient::readyReadLine(std::string_view line) +{ + switch (m_state) + { + case State::Request: +// ESP_LOGV(TAG, "case State::Request:"); + ESP_LOGW(TAG, "unexpected state=Request"); + return true; + case State::ResponseLine: +// ESP_LOGV(TAG, "case State::StatusLine:"); + return parseResponseLine(line); + case State::ResponseHeaders: +// ESP_LOGV(TAG, "case State::ResponseHeaders:"); + return parseResponseHeader(line); + case State::ResponseBody: +// ESP_LOGV(TAG, "case State::RequestBody:"); + ESP_LOGW(TAG, "unexpected state=ResponseBody"); + return true; + default: + ESP_LOGW(TAG, "unknown state %i", std::to_underlying(m_state)); + return true; + } +} + +bool WebsocketClient::parseResponseLine(std::string_view line) +{ +// ESP_LOGV(TAG, "%.*s", line.size(), line.data()); + + if (const auto index = line.find(' '); index == std::string::npos) + { + ESP_LOGW(TAG, "invalid response line (1): \"%.*s\"", line.size(), line.data()); + //m_socket.close(); + return false; + } + else + { + const std::string_view protocol { line.data(), index }; + ESP_LOGV(TAG, "response protocol: %zd \"%.*s\"", protocol.size(), protocol.size(), protocol.data()); + + if (const auto index2 = line.find(' ', index + 1); index2 == std::string::npos) + { + ESP_LOGW(TAG, "invalid request line (2): \"%.*s\"", line.size(), line.data()); + //m_socket.close(); + return false; + } + else + { + const std::string_view status { line.data() + index + 1, line.data() + index2 }; + ESP_LOGV(TAG, "response status: %zd \"%.*s\"", status.size(), status.size(), status.data()); + + const std::string_view message { line.cbegin() + index2 + 1, line.cend() }; + ESP_LOGV(TAG, "response message: %zd \"%.*s\"", message.size(), message.size(), message.data()); + +// ESP_LOGV(TAG, "state changed to ResponseHeaders"); + m_state = State::ResponseHeaders; + + return true; + } + } +} + +bool WebsocketClient::parseResponseHeader(std::string_view line) +{ +// ESP_LOGV(TAG, "%.*s", line.size(), line.data()); + + if (!line.empty()) + { + constexpr std::string_view sep{": "}; + if (const auto index = line.find(sep.data(), 0, sep.size()); index == std::string_view::npos) + { + ESP_LOGW(TAG, "invalid request header: %zd \"%.*s\"", line.size(), line.size(), line.data()); + //m_socket.close(); + return false; + } + else + { + std::string_view key{line.data(), index}; + std::string_view value{std::begin(line) + index + sep.size(), std::end(line)}; + + ESP_LOGD(TAG, "header key=\"%.*s\" value=\"%.*s\"", key.size(), key.data(), value.size(), value.data()); + + if (cpputils::stringEqualsIgnoreCase(key, "Content-Length")) + { + if (const auto parsed = cpputils::fromString(value); !parsed) + { + ESP_LOGW(TAG, "invalid Content-Length %.*s %.*s", value.size(), value.data(), + parsed.error().size(), parsed.error().data()); + //m_socket.close(); + return false; + } + else + m_responseBodySize = *parsed; + } + + return true; + } + } + else + { + if (m_responseBodySize) + { +// ESP_LOGV(TAG, "state changed to ResponseBody"); + m_state = State::ResponseBody; + + if (!m_parsingBuffer.empty()) + { + if (m_parsingBuffer.size() <= m_responseBodySize) + { +// m_responseHandler->requestBodyReceived(m_parsingBuffer); + m_responseBodySize -= m_parsingBuffer.size(); + m_parsingBuffer.clear(); + + if (!m_responseBodySize) + goto requestFinished; + + return true; + } + else + { +// m_responseHandler->requestBodyReceived({m_parsingBuffer.data(), m_responseBodySize}); + m_parsingBuffer.erase(std::begin(m_parsingBuffer), std::next(std::begin(m_parsingBuffer), m_responseBodySize)); + m_responseBodySize = 0; + goto requestFinished; + } + } + else + return true; + } + else + { + requestFinished: + ESP_LOGI(TAG, "finished"); + +// ESP_LOGV(TAG, "state changed to WebSocket"); + m_state = State::WebSocket; + +// m_responseHandler->sendResponse(); + + return true; + } + } +} + +void WebsocketClient::doReadWebSocket() +{ + ESP_LOGI(TAG, "called"); + + m_socket.async_read_some(asio::buffer(m_receiveBuffer, std::size(m_receiveBuffer)), + [this](const std::error_code &error, std::size_t length) { + onReceiveWebsocket(error, length); + }); +} + +void WebsocketClient::onReceiveWebsocket(const std::error_code &error, std::size_t length) +{ + if (error) + { + ESP_LOGI(TAG, "error: %i %s", error.value(), error.message().c_str()); + return; + } + +// ESP_LOGV(TAG, "received: %zd \"%.*s\"", length, length, m_receiveBuffer); + + m_parsingBuffer.append({m_receiveBuffer, length}); + +again: + // ESP_LOGV(TAG, "m_parsingBuffer: %s", cpputils::toHexString(m_parsingBuffer).c_str()); + + if (m_parsingBuffer.empty()) + { + doReadWebSocket(); + return; + } + + static_assert(sizeof(WebsocketHeader) == 2); + + if (m_parsingBuffer.size() < sizeof(WebsocketHeader)) + { + ESP_LOGW(TAG, "buffer smaller than a websocket header"); + doReadWebSocket(); + return; + } + + // ESP_LOGV(TAG, "%s%s%s%s %s%s%s%s %s%s%s%s %s%s%s%s", + // m_parsingBuffer.data()[0]&128?"1":".", m_parsingBuffer.data()[0]&64?"1":".", m_parsingBuffer.data()[0]&32?"1":".", m_parsingBuffer.data()[0]&16?"1":".", + // m_parsingBuffer.data()[0]&8?"1":".", m_parsingBuffer.data()[0]&4?"1":".", m_parsingBuffer.data()[0]&2?"1":".", m_parsingBuffer.data()[0]&1?"1":".", + // m_parsingBuffer.data()[1]&128?"1":".", m_parsingBuffer.data()[1]&64?"1":".", m_parsingBuffer.data()[1]&32?"1":".", m_parsingBuffer.data()[1]&16?"1":".", + // m_parsingBuffer.data()[1]&8?"1":".", m_parsingBuffer.data()[1]&4?"1":".", m_parsingBuffer.data()[1]&2?"1":".", m_parsingBuffer.data()[1]&1?"1":"."); + + auto iter = std::begin(m_parsingBuffer); + + const WebsocketHeader &hdr = *(const WebsocketHeader *)(&*iter); + std::advance(iter, sizeof(WebsocketHeader)); + + // ESP_LOGV(TAG, "fin=%i reserved=%i opcode=%i mask=%i payloadLength=%i", hdr.fin, hdr.reserved, hdr.opcode, hdr.mask, hdr.payloadLength); + + uint64_t payloadLength = hdr.payloadLength; + + if (hdr.payloadLength == 126) + { + if (std::distance(iter, std::end(m_parsingBuffer)) < sizeof(uint16_t)) + { + ESP_LOGW(TAG, "buffer smaller than uint32_t payloadLength"); + doReadWebSocket(); + return; + } + + payloadLength = __builtin_bswap16(*(const uint16_t *)(&*iter)); + std::advance(iter, sizeof(uint16_t)); + + // ESP_LOGV(TAG, "16bit payloadLength: %u", payloadLength); + } + else if (hdr.payloadLength == 127) + { + if (std::distance(iter, std::end(m_parsingBuffer)) < sizeof(uint64_t)) + { + ESP_LOGW(TAG, "buffer smaller than uint64_t payloadLength"); + doReadWebSocket(); + return; + } + + payloadLength = *(const uint64_t *)(&*iter); + std::advance(iter, sizeof(uint64_t)); + + ESP_LOGI(TAG, "64bit payloadLength: %u", payloadLength); + } + + if (hdr.mask) + { + if (std::distance(iter, std::end(m_parsingBuffer)) < sizeof(uint32_t)) + { + ESP_LOGW(TAG, "buffer smaller than uint32_t mask"); + doReadWebSocket(); + return; + } + + union { + uint32_t mask; + uint8_t maskArr[4]; + }; + mask = *(const uint32_t *)(&*iter); + std::advance(iter, sizeof(uint32_t)); + + if (std::distance(iter, std::end(m_parsingBuffer)) < payloadLength) + { + ESP_LOGW(TAG, "masked buffer smaller payloadLength"); + doReadWebSocket(); + return; + } + + auto iter2 = std::begin(maskArr); + for (auto iter3 = iter; + iter3 != std::end(m_parsingBuffer) && iter3 != std::next(iter, payloadLength); + iter3++) + { + *iter3 ^= *(iter2++); + if (iter2 == std::end(maskArr)) + iter2 = std::begin(maskArr); + } + } + else if (std::distance(iter, std::end(m_parsingBuffer)) < payloadLength) + { + ESP_LOGW(TAG, "buffer smaller payloadLength"); + doReadWebSocket(); + return; + } + + ESP_LOGI(TAG, "remaining: %zd %lu", std::distance(iter, std::end(m_parsingBuffer)), payloadLength); + + ESP_LOGI(TAG, "payload: %.*s", payloadLength, &*iter); + + std::advance(iter, payloadLength); + m_parsingBuffer.erase(std::begin(m_parsingBuffer), iter); + + sendMessage(true, 0, 1, true, "{\"type\":\"hello\"}"); + + goto again; +} + +void WebsocketClient::sendMessage(bool fin, uint8_t reserved, uint8_t opcode, bool mask, std::string_view payload) +{ + ESP_LOGI(TAG, "%.*s", payload.size(), payload.data()); + + auto sendBuffer = std::make_shared(); + sendBuffer->resize(2); + { + auto iter = std::begin(*sendBuffer); + WebsocketHeader &hdr = *(WebsocketHeader *)(&*iter); + hdr.fin = fin; + hdr.reserved = reserved; + hdr.opcode = opcode; + hdr.mask = mask; + hdr.payloadLength = payload.size(); + } + if (mask) + { + sendBuffer->append(4, '\0'); + ESP_LOGI(TAG, "sendBuffer size %zd", sendBuffer->size()); + assert(sendBuffer->size() == 6); + } + sendBuffer->append(payload); + + asio::async_write(m_socket, + asio::buffer(sendBuffer->data(), sendBuffer->size()), + [this, sendBuffer](std::error_code ec, std::size_t length) + { onMessageSent(ec, length, sendBuffer->size()); }); +} + +void WebsocketClient::onMessageSent(std::error_code ec, std::size_t length, std::size_t expectedLength) +{ + if (ec) + { + ESP_LOGW(TAG, "error: %i", ec.value()); + return; + } + + ESP_LOGI(TAG, "length=%zd expected=%zd", length, expectedLength); +} diff --git a/test/websocket_client_example/websocketclient.h b/test/websocket_client_example/websocketclient.h new file mode 100644 index 0000000..bf7ce45 --- /dev/null +++ b/test/websocket_client_example/websocketclient.h @@ -0,0 +1,45 @@ +#pragma once + +#include +#include + +class WebsocketClient +{ +public: + WebsocketClient(asio::io_context &io_context); + + void start(); + +private: + void resolve(); + void onResolved(const std::error_code &error, asio::ip::tcp::resolver::iterator iterator); + void connect(const asio::ip::tcp::resolver::iterator &endpoints); + void onConnected(const std::error_code &error); + void handshake(); + void onHandshaked(const std::error_code & error); + void send_request(); + void onSentRequest(const std::error_code &error, std::size_t length); + void receive_response(); + void onReceivedResponse(const std::error_code &error, std::size_t length); + bool readyReadLine(std::string_view line); + bool parseResponseLine(std::string_view line); + bool parseResponseHeader(std::string_view line); + void doReadWebSocket(); + void onReceiveWebsocket(const std::error_code &error, std::size_t length); + + void sendMessage(bool fin, uint8_t reserved, uint8_t opcode, bool mask, std::string_view payload); + void onMessageSent(std::error_code ec, std::size_t length, std::size_t expectedLength); + + asio::ip::tcp::resolver m_resolver; + //asio::ssl::context m_sslCtx{asio::ssl::context::tls_client}; + //asio::ssl::stream m_socket; + asio::ip::tcp::socket m_socket; + char m_receiveBuffer[1024]; + + enum class State { Request, ResponseLine, ResponseHeaders, ResponseBody, WebSocket }; + State m_state { State::Request }; + + std::string m_parsingBuffer; + + std::size_t m_responseBodySize{}; +};