diff --git a/CMakeLists.txt b/CMakeLists.txt index ff74f58..b0f3b15 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,17 +1,21 @@ set(headers src/asio_web/clientconnection.h src/asio_web/responsehandler.h + src/asio_web/sslwebsocketclient.h src/asio_web/webserver.h src/asio_web/websocketclientconnection.h src/asio_web/websocketstream.h + src/asio_web/websocketclient.h ) set(sources src/asio_web/clientconnection.cpp src/asio_web/responsehandler.cpp + src/asio_web/sslwebsocketclient.cpp src/asio_web/webserver.cpp src/asio_web/websocketclientconnection.cpp src/asio_web/websocketstream.cpp + src/asio_web/websocketclient.cpp ) set(dependencies diff --git a/asio_web_src.pri b/asio_web_src.pri index 0c64de3..72e2e85 100644 --- a/asio_web_src.pri +++ b/asio_web_src.pri @@ -1,13 +1,17 @@ HEADERS += \ $$PWD/src/asio_web/clientconnection.h \ $$PWD/src/asio_web/responsehandler.h \ + $$PWD/src/asio_web/sslwebsocketclient.h \ $$PWD/src/asio_web/webserver.h \ $$PWD/src/asio_web/websocketclientconnection.h \ - $$PWD/src/asio_web/websocketstream.h + $$PWD/src/asio_web/websocketstream.h \ + $$PWD/src/asio_web/websocketclient.h SOURCES += \ $$PWD/src/asio_web/clientconnection.cpp \ $$PWD/src/asio_web/responsehandler.cpp \ + $$PWD/src/asio_web/sslwebsocketclient.cpp \ $$PWD/src/asio_web/webserver.cpp \ $$PWD/src/asio_web/websocketclientconnection.cpp \ - $$PWD/src/asio_web/websocketstream.cpp + $$PWD/src/asio_web/websocketstream.cpp \ + $$PWD/src/asio_web/websocketclient.cpp diff --git a/src/asio_web/clientconnection.cpp b/src/asio_web/clientconnection.cpp index b86f8eb..1ea46e5 100644 --- a/src/asio_web/clientconnection.cpp +++ b/src/asio_web/clientconnection.cpp @@ -17,7 +17,7 @@ #include "websocketclientconnection.h" namespace { -constexpr const char * const TAG = "ASIO_WEBSERVER"; +constexpr const char * const TAG = "ASIO_WEB"; } // namespace ClientConnection::ClientConnection(Webserver &webserver, asio::ip::tcp::socket socket) : diff --git a/src/asio_web/sslwebsocketclient.cpp b/src/asio_web/sslwebsocketclient.cpp new file mode 100644 index 0000000..6ed4744 --- /dev/null +++ b/src/asio_web/sslwebsocketclient.cpp @@ -0,0 +1,564 @@ +#include "sslwebsocketclient.h" + +// esp-idf includes +#include + +// 3rdparty lib includes +#include +#include +#include + +// local includes +#include "websocketstream.h" + +namespace { +constexpr const char * const TAG = "ASIO_WEB"; +} // namespace + +SslWebsocketClient::SslWebsocketClient(asio::io_context &io_context, std::string &&host, std::string &&path) : + m_host(std::move(host)), + m_path{std::move(path)}, + m_resolver{io_context}, + m_socket{io_context, m_sslCtx} + //m_socket{io_context} +{ + m_socket.set_verify_mode(asio::ssl::verify_none); +} + +SslWebsocketClient::SslWebsocketClient(asio::io_context &io_context, const std::string &host, const std::string &path) : + m_host{host}, + m_path{path}, + m_resolver{io_context}, + m_socket{io_context, m_sslCtx} + //m_socket{io_context} +{ + m_socket.set_verify_mode(asio::ssl::verify_none); +} + +void SslWebsocketClient::start() +{ + ESP_LOGI(TAG, "called"); + resolve(); +} + +void SslWebsocketClient::resolve() +{ + ESP_LOGI(TAG, "called"); + + m_resolver.async_resolve(m_host, "8086", + [this](const std::error_code &error, asio::ip::tcp::resolver::iterator iterator){ + onResolved(error, iterator); + }); +// m_resolver.async_resolve("ruezn.local", "1234", +// [this](const std::error_code &error, asio::ip::tcp::resolver::iterator iterator){ +// onResolved(error, iterator); +// }); +} + +void SslWebsocketClient::onResolved(const std::error_code &error, asio::ip::tcp::resolver::iterator iterator) +{ + if (error) + { + ESP_LOGW(TAG, "Resolving failed: %s", error.message().c_str()); + return; + } + + ESP_LOGI(TAG, "called"); + + connect(iterator); +} + +void SslWebsocketClient::connect(const asio::ip::tcp::resolver::iterator &endpoints) +{ + ESP_LOGI(TAG, "called"); + + asio::async_connect(m_socket.lowest_layer(), endpoints, + [this](const std::error_code & error, const asio::ip::tcp::resolver::iterator &) { + onConnected(error); + }); +} + +void SslWebsocketClient::onConnected(const std::error_code &error) +{ + if (error) + { + ESP_LOGW(TAG, "Connect failed: %s", error.message().c_str()); + return; + } + + ESP_LOGI(TAG, "called"); + + handshake(); +// send_request(); +} + +void SslWebsocketClient::handshake() +{ + ESP_LOGI(TAG, "called"); + + m_socket.async_handshake(asio::ssl::stream_base::client, + [this](const std::error_code &error) { + onHandshaked(error); + }); +} + +void SslWebsocketClient::onHandshaked(const std::error_code &error) +{ + if (error) + { + ESP_LOGW(TAG, "Handshake failed: %s", error.message().c_str()); + return; + } + + ESP_LOGI(TAG, "called"); + + send_request(); +} + +void SslWebsocketClient::send_request() +{ + m_sending = fmt::format("GET {} HTTP/1.1\r\n" + "Host: {}\r\n" + "Connection: Upgrade\r\n" + "Upgrade: websocket\r\n" + "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" + "Sec-WebSocket-Version: 13\r\n" + "\r\n", m_path, m_host); + ESP_LOGI(TAG, "called %.*s", m_sending->size(), m_sending->data()); + + m_state = State::Request; + + asio::async_write(m_socket, + asio::buffer(m_sending->data(), m_sending->size()), + [this](const std::error_code &error, std::size_t length) { + onSentRequest(error, length); + }); +} + +void SslWebsocketClient::onSentRequest(const std::error_code &error, std::size_t length) +{ + if (error) + { + ESP_LOGW(TAG, "Write failed: %s", error.message().c_str()); + m_sending = std::nullopt; + return; + } + + ESP_LOGI(TAG, "called %zd (%zd)", length, m_sending->size()); + + m_sending = std::nullopt; + m_state = State::ResponseLine; + + receive_response(); +} + +void SslWebsocketClient::receive_response() +{ + ESP_LOGI(TAG, "called"); + + m_socket.async_read_some(asio::buffer(m_receiveBuffer, std::size(m_receiveBuffer)), + [this](const std::error_code &error, std::size_t length) { + onReceivedResponse(error, length); + }); +} + +void SslWebsocketClient::onReceivedResponse(const std::error_code &error, std::size_t length) +{ + if (error) + { + ESP_LOGI(TAG, "Read failed: %s", error.message().c_str()); + return; + } + + ESP_LOGI(TAG, "received %.*s", length, m_receiveBuffer); + m_parsingBuffer.append(m_receiveBuffer, length); + + bool shouldDoRead{true}; + + while (true) + { + constexpr std::string_view newLine{"\r\n"}; + const auto index = m_parsingBuffer.find(newLine.data(), 0, newLine.size()); + if (index == std::string::npos) + break; + + std::string line{m_parsingBuffer.data(), index}; + + // ESP_LOGD(TAG, "line: %zd \"%.*s\"", line.size(), line.size(), line.data()); + + m_parsingBuffer.erase(std::begin(m_parsingBuffer), std::next(std::begin(m_parsingBuffer), line.size() + newLine.size())); + + if (!readyReadLine(line)) + shouldDoRead = false; + if (m_state == State::WebSocket) + break; + } + + if (shouldDoRead) + { + if (m_state == State::WebSocket) + doReadWebSocket(); + else + receive_response(); + } +} + +bool SslWebsocketClient::readyReadLine(std::string_view line) +{ + switch (m_state) + { + case State::Request: +// ESP_LOGV(TAG, "case State::Request:"); + ESP_LOGW(TAG, "unexpected state=Request"); + return true; + case State::ResponseLine: +// ESP_LOGV(TAG, "case State::StatusLine:"); + return parseResponseLine(line); + case State::ResponseHeaders: +// ESP_LOGV(TAG, "case State::ResponseHeaders:"); + return parseResponseHeader(line); + case State::ResponseBody: +// ESP_LOGV(TAG, "case State::RequestBody:"); + ESP_LOGW(TAG, "unexpected state=ResponseBody"); + return true; + default: + ESP_LOGW(TAG, "unknown state %i", std::to_underlying(m_state)); + return true; + } +} + +bool SslWebsocketClient::parseResponseLine(std::string_view line) +{ +// ESP_LOGV(TAG, "%.*s", line.size(), line.data()); + + if (const auto index = line.find(' '); index == std::string::npos) + { + ESP_LOGW(TAG, "invalid response line (1): \"%.*s\"", line.size(), line.data()); + //m_socket.close(); + return false; + } + else + { + const std::string_view protocol { line.data(), index }; + ESP_LOGV(TAG, "response protocol: %zd \"%.*s\"", protocol.size(), protocol.size(), protocol.data()); + + if (const auto index2 = line.find(' ', index + 1); index2 == std::string::npos) + { + ESP_LOGW(TAG, "invalid request line (2): \"%.*s\"", line.size(), line.data()); + //m_socket.close(); + return false; + } + else + { + const std::string_view status { line.data() + index + 1, line.data() + index2 }; + ESP_LOGV(TAG, "response status: %zd \"%.*s\"", status.size(), status.size(), status.data()); + + const std::string_view message { line.cbegin() + index2 + 1, line.cend() }; + ESP_LOGV(TAG, "response message: %zd \"%.*s\"", message.size(), message.size(), message.data()); + +// ESP_LOGV(TAG, "state changed to ResponseHeaders"); + m_state = State::ResponseHeaders; + + return true; + } + } +} + +bool SslWebsocketClient::parseResponseHeader(std::string_view line) +{ +// ESP_LOGV(TAG, "%.*s", line.size(), line.data()); + + if (!line.empty()) + { + constexpr std::string_view sep{": "}; + if (const auto index = line.find(sep.data(), 0, sep.size()); index == std::string_view::npos) + { + ESP_LOGW(TAG, "invalid request header: %zd \"%.*s\"", line.size(), line.size(), line.data()); + //m_socket.close(); + return false; + } + else + { + std::string_view key{line.data(), index}; + std::string_view value{std::begin(line) + index + sep.size(), std::end(line)}; + + ESP_LOGD(TAG, "header key=\"%.*s\" value=\"%.*s\"", key.size(), key.data(), value.size(), value.data()); + + if (cpputils::stringEqualsIgnoreCase(key, "Content-Length")) + { + if (const auto parsed = cpputils::fromString(value); !parsed) + { + ESP_LOGW(TAG, "invalid Content-Length %.*s %.*s", value.size(), value.data(), + parsed.error().size(), parsed.error().data()); + //m_socket.close(); + return false; + } + else + m_responseBodySize = *parsed; + } + + return true; + } + } + else + { + if (m_responseBodySize) + { +// ESP_LOGV(TAG, "state changed to ResponseBody"); + m_state = State::ResponseBody; + + if (!m_parsingBuffer.empty()) + { + if (m_parsingBuffer.size() <= m_responseBodySize) + { +// m_responseHandler->requestBodyReceived(m_parsingBuffer); + m_responseBodySize -= m_parsingBuffer.size(); + m_parsingBuffer.clear(); + + if (!m_responseBodySize) + goto requestFinished; + + return true; + } + else + { +// m_responseHandler->requestBodyReceived({m_parsingBuffer.data(), m_responseBodySize}); + m_parsingBuffer.erase(std::begin(m_parsingBuffer), std::next(std::begin(m_parsingBuffer), m_responseBodySize)); + m_responseBodySize = 0; + goto requestFinished; + } + } + else + return true; + } + else + { + requestFinished: + ESP_LOGI(TAG, "finished"); + + handleConnected(); + +// ESP_LOGV(TAG, "state changed to WebSocket"); + m_state = State::WebSocket; + +// m_responseHandler->sendResponse(); + + return true; + } + } +} + +void SslWebsocketClient::doReadWebSocket() +{ + ESP_LOGI(TAG, "called"); + + m_socket.async_read_some(asio::buffer(m_receiveBuffer, std::size(m_receiveBuffer)), + [this](const std::error_code &error, std::size_t length) { + onReceiveWebsocket(error, length); + }); +} + +void SslWebsocketClient::onReceiveWebsocket(const std::error_code &error, std::size_t length) +{ + if (error) + { + ESP_LOGI(TAG, "error: %i %s", error.value(), error.message().c_str()); + return; + } + +// ESP_LOGV(TAG, "received: %zd \"%.*s\"", length, length, m_receiveBuffer); + + m_parsingBuffer.append({m_receiveBuffer, length}); + +again: +// ESP_LOGV(TAG, "m_parsingBuffer: %s", cpputils::toHexString(m_parsingBuffer).c_str()); + + if (m_parsingBuffer.empty()) + { + doReadWebSocket(); + return; + } + + static_assert(sizeof(WebsocketHeader) == 2); + + if (m_parsingBuffer.size() < sizeof(WebsocketHeader)) + { + ESP_LOGW(TAG, "buffer smaller than a websocket header"); + doReadWebSocket(); + return; + } + +// ESP_LOGV(TAG, "%s%s%s%s %s%s%s%s %s%s%s%s %s%s%s%s", +// m_parsingBuffer.data()[0]&128?"1":".", m_parsingBuffer.data()[0]&64?"1":".", m_parsingBuffer.data()[0]&32?"1":".", m_parsingBuffer.data()[0]&16?"1":".", +// m_parsingBuffer.data()[0]&8?"1":".", m_parsingBuffer.data()[0]&4?"1":".", m_parsingBuffer.data()[0]&2?"1":".", m_parsingBuffer.data()[0]&1?"1":".", +// m_parsingBuffer.data()[1]&128?"1":".", m_parsingBuffer.data()[1]&64?"1":".", m_parsingBuffer.data()[1]&32?"1":".", m_parsingBuffer.data()[1]&16?"1":".", +// m_parsingBuffer.data()[1]&8?"1":".", m_parsingBuffer.data()[1]&4?"1":".", m_parsingBuffer.data()[1]&2?"1":".", m_parsingBuffer.data()[1]&1?"1":"."); + + auto iter = std::begin(m_parsingBuffer); + + const WebsocketHeader &hdr = *(const WebsocketHeader *)(&*iter); + std::advance(iter, sizeof(WebsocketHeader)); + + ESP_LOGI(TAG, "fin=%i reserved=%i opcode=%i mask=%i payloadLength=%i", hdr.fin, hdr.reserved, hdr.opcode, hdr.mask, hdr.payloadLength); + + uint64_t payloadLength = hdr.payloadLength; + + if (hdr.payloadLength == 126) + { + if (std::distance(iter, std::end(m_parsingBuffer)) < sizeof(uint16_t)) + { + ESP_LOGW(TAG, "buffer smaller than uint32_t payloadLength"); + doReadWebSocket(); + return; + } + + payloadLength = __builtin_bswap16(*(const uint16_t *)(&*iter)); + std::advance(iter, sizeof(uint16_t)); + + ESP_LOGI(TAG, "16bit payloadLength: %u", payloadLength); + } + else if (hdr.payloadLength == 127) + { + if (std::distance(iter, std::end(m_parsingBuffer)) < sizeof(uint64_t)) + { + ESP_LOGW(TAG, "buffer smaller than uint64_t payloadLength"); + doReadWebSocket(); + return; + } + + payloadLength = *(const uint64_t *)(&*iter); + std::advance(iter, sizeof(uint64_t)); + + ESP_LOGI(TAG, "64bit payloadLength: %u", payloadLength); + } + + if (hdr.mask) + { + if (std::distance(iter, std::end(m_parsingBuffer)) < sizeof(uint32_t)) + { + ESP_LOGW(TAG, "buffer smaller than uint32_t mask"); + doReadWebSocket(); + return; + } + + union { + uint32_t mask; + uint8_t maskArr[4]; + }; + mask = *(const uint32_t *)(&*iter); + std::advance(iter, sizeof(uint32_t)); + + if (std::distance(iter, std::end(m_parsingBuffer)) < payloadLength) + { + ESP_LOGW(TAG, "masked buffer smaller payloadLength"); + doReadWebSocket(); + return; + } + + auto iter2 = std::begin(maskArr); + for (auto iter3 = iter; + iter3 != std::end(m_parsingBuffer) && iter3 != std::next(iter, payloadLength); + iter3++) + { + *iter3 ^= *(iter2++); + if (iter2 == std::end(maskArr)) + iter2 = std::begin(maskArr); + } + } + else if (std::distance(iter, std::end(m_parsingBuffer)) < payloadLength) + { + ESP_LOGW(TAG, "buffer smaller payloadLength"); + doReadWebSocket(); + return; + } + + ESP_LOGI(TAG, "remaining: std::distance=%zd payloadLength=%llu", std::distance(iter, std::end(m_parsingBuffer)), payloadLength); + + std::string_view payload{&*iter, (unsigned int)(payloadLength)}; + + ESP_LOGI(TAG, "payload: %.*s", payload.size(), payload.data()); + + handleMessage(hdr.fin, hdr.reserved, hdr.opcode, hdr.mask, payload); + + std::advance(iter, payloadLength); + m_parsingBuffer.erase(std::begin(m_parsingBuffer), iter); + + goto again; +} + +void SslWebsocketClient::sendMessage(bool fin, uint8_t reserved, uint8_t opcode, bool mask, std::string_view payload) +{ + //ESP_LOGI(TAG, "%.*s", payload.size(), payload.data()); + + std::string sendBuffer; + sendBuffer.resize(2); + { + auto iter = std::begin(sendBuffer); + WebsocketHeader &hdr = *(WebsocketHeader *)(&*iter); + hdr.fin = fin; + hdr.reserved = reserved; + hdr.opcode = opcode; + hdr.mask = mask; + hdr.payloadLength = payload.size() < 126 ? payload.size() : 126; + } + if (payload.size() > 125) + { + union { + char buf[2]; + uint16_t length; + }; + length = __builtin_bswap16((uint16_t)payload.size()); + sendBuffer.append(std::string_view{buf, 2}); + } + if (mask) + { + sendBuffer.append(4, '\0'); + } + sendBuffer.append(payload); + + if (!m_sending) + { + m_sending = std::move(sendBuffer); + + asio::async_write(m_socket, + asio::buffer(m_sending->data(), m_sending->size()), + [this](std::error_code ec, std::size_t length) + { onMessageSent(ec, length); }); + } + else + { + m_sendingQueue.push(std::move(sendBuffer)); + + ESP_LOGI(TAG, "enqueueing %zd", m_sendingQueue.size()); + } +} + +void SslWebsocketClient::onMessageSent(std::error_code ec, std::size_t length) +{ + if (ec) + { + ESP_LOGW(TAG, "error: %i %s", ec.value(), ec.message().c_str()); + m_sending = std::nullopt; + m_sendingQueue = {}; + return; + } + +// ESP_LOGI(TAG, "length=%zd expected=%zd", length, m_sending->size()); + + if (m_sendingQueue.empty()) + { + m_sending = std::nullopt; + } + else + { + m_sending = m_sendingQueue.front(); + m_sendingQueue.pop(); + +// ESP_LOGI(TAG, "asio send %zd %.*s", m_sending->size(), (int)m_sending->size(), m_sending->data()); + + asio::async_write(m_socket, + asio::buffer(m_sending->data(), m_sending->size()), + [this](std::error_code ec, std::size_t length) + { onMessageSent(ec, length); }); + } +} diff --git a/src/asio_web/sslwebsocketclient.h b/src/asio_web/sslwebsocketclient.h new file mode 100644 index 0000000..8eaa9de --- /dev/null +++ b/src/asio_web/sslwebsocketclient.h @@ -0,0 +1,63 @@ +#pragma once + +// system include +#include +#include + +// esp-idf includes +#include +#include + +class SslWebsocketClient +{ +public: + SslWebsocketClient(asio::io_context &io_context, std::string &&host, std::string &&path); + SslWebsocketClient(asio::io_context &io_context, const std::string &host, const std::string &path); + + void start(); + + virtual void handleConnected() = 0; + virtual void handleMessage(bool fin, uint8_t reserved, uint8_t opcode, bool mask, std::string_view payload) = 0; + +private: + void resolve(); + void onResolved(const std::error_code &error, asio::ip::tcp::resolver::iterator iterator); + void connect(const asio::ip::tcp::resolver::iterator &endpoints); + void onConnected(const std::error_code &error); + void handshake(); + void onHandshaked(const std::error_code & error); + void send_request(); + void onSentRequest(const std::error_code &error, std::size_t length); + void receive_response(); + void onReceivedResponse(const std::error_code &error, std::size_t length); + bool readyReadLine(std::string_view line); + bool parseResponseLine(std::string_view line); + bool parseResponseHeader(std::string_view line); + void doReadWebSocket(); + void onReceiveWebsocket(const std::error_code &error, std::size_t length); + +public: + void sendMessage(bool fin, uint8_t reserved, uint8_t opcode, bool mask, std::string_view payload); + +private: + void onMessageSent(std::error_code ec, std::size_t length); + + std::string m_host; + std::string m_path; + + asio::ip::tcp::resolver m_resolver; + asio::ssl::context m_sslCtx{asio::ssl::context::tls_client}; + asio::ssl::stream m_socket; + //asio::ip::tcp::socket m_socket; + char m_receiveBuffer[1024]; + + enum class State { Request, ResponseLine, ResponseHeaders, ResponseBody, WebSocket }; + State m_state { State::Request }; + + std::string m_parsingBuffer; + + std::size_t m_responseBodySize{}; + + std::optional m_sending; + std::queue m_sendingQueue; +}; diff --git a/src/asio_web/webserver.cpp b/src/asio_web/webserver.cpp index b48c314..70c7809 100644 --- a/src/asio_web/webserver.cpp +++ b/src/asio_web/webserver.cpp @@ -7,7 +7,7 @@ #include "clientconnection.h" namespace { -constexpr const char * const TAG = "ASIO_WEBSERVER"; +constexpr const char * const TAG = "ASIO_WEB"; } // namespace Webserver::Webserver(asio::io_context &io_context, unsigned short port) : diff --git a/src/asio_web/websocketclient.cpp b/src/asio_web/websocketclient.cpp new file mode 100644 index 0000000..bbada65 --- /dev/null +++ b/src/asio_web/websocketclient.cpp @@ -0,0 +1,5 @@ +#include "websocketclient.h" + +namespace { +constexpr const char * const TAG = "ASIO_WEB"; +} // namespace diff --git a/src/asio_web/websocketclient.h b/src/asio_web/websocketclient.h new file mode 100644 index 0000000..3f59c93 --- /dev/null +++ b/src/asio_web/websocketclient.h @@ -0,0 +1,2 @@ +#pragma once + diff --git a/src/asio_web/websocketclientconnection.cpp b/src/asio_web/websocketclientconnection.cpp index d173176..cf8b32c 100644 --- a/src/asio_web/websocketclientconnection.cpp +++ b/src/asio_web/websocketclientconnection.cpp @@ -15,7 +15,7 @@ #include "websocketstream.h" namespace { -constexpr const char * const TAG = "ASIO_WEBSERVER"; +constexpr const char * const TAG = "ASIO_WEB"; } // namespace WebsocketClientConnection::WebsocketClientConnection(Webserver &webserver, asio::ip::tcp::socket socket, diff --git a/src/asio_web/websocketstream.cpp b/src/asio_web/websocketstream.cpp index 65e5dc6..820816f 100644 --- a/src/asio_web/websocketstream.cpp +++ b/src/asio_web/websocketstream.cpp @@ -1 +1,5 @@ #include "websocketstream.h" + +namespace { +constexpr const char * const TAG = "ASIO_WEB"; +} // namespace