#include "websocketresponsehandler.h" // system includes #include // esp-idf includes #include #include // 3rdparty lib includes #include #include #include #include namespace { constexpr const char * const TAG = "ASIO_WEBSERVER"; constexpr std::string_view html{R"END( Websocket test

Websocket test

Connection Not connected
Send msg


        
    

)END"};
// TAG:  ESP-IDF log tag used by every ESP_LOG* call in this file.
// html: static page written (after a 200 header) to requests that are not
//       websocket upgrade requests; see sendResponse()/writtenHtmlHeader().
// NOTE(review): the first line of this file looks flattened by text extraction:
// the #include header names are missing, and as written everything after its
// first `//` (including the `namespace {` opener, TAG and the start of `html`)
// is a line comment. The page markup inside `html` also appears stripped.
// Restore this region from the upstream source before building.
} // namespace

/**
 * Binds the handler to the client connection whose request it answers.
 *
 * m_clientConnection is initialized from `clientConnection`; it is presumably a
 * reference member (TODO confirm in the header), in which case the connection
 * must outlive this handler.
 */
WebsocketResponseHandler::WebsocketResponseHandler(ClientConnection &clientConnection)
    : m_clientConnection(clientConnection)
{
    // Verbose construction trace, currently disabled:
    // ESP_LOGV(TAG, "constructed for (%s:%hu)",
    //          m_clientConnection.remote_endpoint().address().to_string().c_str(), m_clientConnection.remote_endpoint().port());
}

// Nothing to release explicitly; members clean up after themselves.
// A verbose destruction trace used to live here and is kept for reference:
// ESP_LOGV(TAG, "destructed for (%s:%hu)",
//          m_clientConnection.remote_endpoint().address().to_string().c_str(), m_clientConnection.remote_endpoint().port());
WebsocketResponseHandler::~WebsocketResponseHandler() = default;

namespace {
/**
 * Tests whether an HTTP header value that is a comma-separated token list
 * (e.g. "keep-alive, Upgrade") contains `token`.
 *
 * Each list element is trimmed of surrounding spaces/tabs and compared ASCII
 * case-insensitively, which is how RFC 6455 §4.2.1 requires the Connection and
 * Upgrade headers to be interpreted. Empty elements (e.g. from ",,") are skipped.
 *
 * @param list  raw header value
 * @param token token to look for
 * @return true if some element equals `token` ignoring case
 */
bool headerValueHasToken(std::string_view list, std::string_view token)
{
    while (!list.empty())
    {
        const auto comma = list.find(',');
        std::string_view element = list.substr(0, comma);
        list = (comma == std::string_view::npos) ? std::string_view{} : list.substr(comma + 1);

        const auto first = element.find_first_not_of(" \t");
        if (first == std::string_view::npos)
            continue; // empty or whitespace-only element
        const auto last = element.find_last_not_of(" \t");
        element = element.substr(first, last - first + 1);

        if (cpputils::stringEqualsIgnoreCase(element, token))
            return true;
    }
    return false;
}
} // namespace

/**
 * Records the request headers this handler needs to decide between serving the
 * html test page and performing a websocket handshake. Header names are matched
 * case-insensitively; unrelated headers are ignored.
 *
 * @param key   header name
 * @param value header value (only valid for the duration of this call; the
 *              Sec-WebSocket-* values are copied into members)
 */
void WebsocketResponseHandler::requestHeaderReceived(std::string_view key, std::string_view value)
{
//    ESP_LOGV(TAG, "key=\"%.*s\" value=\"%.*s\"", int(key.size()), key.data(), int(value.size()), value.data());

    if (cpputils::stringEqualsIgnoreCase(key, "Connection"))
    {
        // Connection is a token list; browsers commonly send e.g. "keep-alive, Upgrade"
        // (any letter case), so match the token rather than the whole value.
        m_connectionUpgrade = headerValueHasToken(value, "Upgrade");
    }
    else if (cpputils::stringEqualsIgnoreCase(key, "Upgrade"))
    {
        // Upgrade may also list several protocols; accept it if "websocket" is among them.
        m_upgradeWebsocket = headerValueHasToken(value, "websocket");
    }
    else if (cpputils::stringEqualsIgnoreCase(key, "Sec-WebSocket-Version"))
    {
        m_secWebsocketVersion = value;
    }
    else if (cpputils::stringEqualsIgnoreCase(key, "Sec-WebSocket-Key"))
    {
        m_secWebsocketKey = value;
    }
    else if (cpputils::stringEqualsIgnoreCase(key, "Sec-WebSocket-Extensions"))
    {
        m_secWebsocketExtensions = value;
    }
}

/**
 * Receives a chunk of the request body. This handler makes no use of the
 * request body, so the data is ignored.
 */
void WebsocketResponseHandler::requestBodyReceived(std::string_view /*body*/)
{
    // Intentionally empty.
}

/**
 * Builds and asynchronously writes the response for the current request.
 *
 * Outcomes:
 *  - Not a websocket upgrade (the Connection header lacks the "Upgrade" token
 *    or the Upgrade header lacks "websocket"): a 200 header is written, then
 *    writtenHtmlHeader() writes the static `html` page.
 *  - Upgrade request without Sec-WebSocket-Key: a 400 response whose body is
 *    a short error message; completes in writtenHtml().
 *  - Otherwise: the 101 Switching Protocols handshake response carrying
 *    Sec-WebSocket-Accept; completes in writtenWebsocket().
 *
 * m_response owns the bytes passed to asio::async_write, so it must not be
 * modified until the completion handler runs. Each handler captures the
 * connection's shared_ptr to keep the connection alive during the write.
 */
void WebsocketResponseHandler::sendResponse()
{
    // %hu: the port is an unsigned 16-bit value; %hi printed ports > 32767 as negative.
    ESP_LOGI(TAG, "sending response for (%s:%hu)",
             m_clientConnection.remote_endpoint().address().to_string().c_str(), m_clientConnection.remote_endpoint().port());

    if (!m_connectionUpgrade || !m_upgradeWebsocket)
    {
        // Plain HTTP request: send only the header here; the body follows in writtenHtmlHeader().
        m_response = fmt::format("HTTP/1.1 200 Ok\r\n"
                                 "Connection: {}\r\n"
                                 "Content-Type: text/html\r\n"
                                 "Content-Length: {}\r\n"
                                 "\r\n",
                                 m_clientConnection.webserver().connectionKeepAlive() ? "keep-alive" : "close",
                                 html.size());

        asio::async_write(m_clientConnection.socket(),
                          asio::buffer(m_response.data(), m_response.size()),
                          [this, self=m_clientConnection.shared_from_this()](std::error_code ec, std::size_t length)
                          { writtenHtmlHeader(ec, length); });

        return;
    }

    // Writes a complete 400 response with `msg` as body. Invoked synchronously below,
    // so capturing by reference is safe; `msg` is copied into m_response before returning.
    const auto showError = [&](std::string_view msg){
        m_response = fmt::format("HTTP/1.1 400 Bad Request\r\n"
                                 "Connection: {}\r\n"
                                 "Content-Type: text/html\r\n"
                                 "Content-Length: {}\r\n"
                                 "\r\n"
                                 "{}",
                                 m_clientConnection.webserver().connectionKeepAlive() ? "keep-alive" : "close",
                                 msg.size(), msg);

        asio::async_write(m_clientConnection.socket(),
                          asio::buffer(m_response.data(), m_response.size()),
                          [this, self=m_clientConnection.shared_from_this()](std::error_code ec, std::size_t length)
                          { writtenHtml(ec, length); });
    };

    if (m_secWebsocketKey.empty())
    {
        showError("Header Sec-WebSocket-Key empty or missing!");
        return;
    }

    // NOTE(review): m_secWebsocketVersion is recorded but not validated; RFC 6455 §4.2.2
    // requires aborting the handshake for versions other than 13.

    // RFC 6455 §4.2.2: Sec-WebSocket-Accept = base64(SHA-1(key + GUID)).
    // Built in a local so that m_secWebsocketKey keeps the value the client sent.
    constexpr std::string_view magic_uuid{"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"};
    const auto acceptSource = fmt::format("{}{}", m_secWebsocketKey, magic_uuid);

    unsigned char sha1[SHA_DIGEST_LENGTH]; // == 20
    SHA1(reinterpret_cast<const unsigned char *>(acceptSource.data()), acceptSource.size(), sha1);

    const auto base64Sha1 = cpputils::toBase64String({sha1, SHA_DIGEST_LENGTH});

    m_response = fmt::format("HTTP/1.1 101 Switching Protocols\r\n"
                             "Upgrade: websocket\r\n"
                             "Connection: Upgrade\r\n"
                             "Sec-WebSocket-Accept: {}\r\n"
                             "\r\n", base64Sha1);

    asio::async_write(m_clientConnection.socket(),
                      asio::buffer(m_response.data(), m_response.size()),
                      [this, self=m_clientConnection.shared_from_this()](std::error_code ec, std::size_t length)
                      { writtenWebsocket(ec, length); });
}

void WebsocketResponseHandler::writtenHtmlHeader(std::error_code ec, std::size_t length)
{
    if (ec)
    {
        ESP_LOGW(TAG, "error: %i (%s:%hi)", ec.value(),
                 m_clientConnection.remote_endpoint().address().to_string().c_str(), m_clientConnection.remote_endpoint().port());
        m_clientConnection.responseFinished(ec);
        return;
    }

    ESP_LOGI(TAG, "expected=%zd actual=%zd for (%s:%hi)", m_response.size(), length,
             m_clientConnection.remote_endpoint().address().to_string().c_str(), m_clientConnection.remote_endpoint().port());

    asio::async_write(m_clientConnection.socket(),
                      asio::buffer(html.data(), html.size()),
                      [this, self=m_clientConnection.shared_from_this()](std::error_code ec, std::size_t length)
                      { writtenHtml(ec, length); });
}

void WebsocketResponseHandler::writtenHtml(std::error_code ec, std::size_t length)
{
    if (ec)
    {
        ESP_LOGW(TAG, "error: %i (%s:%hi)", ec.value(),
                 m_clientConnection.remote_endpoint().address().to_string().c_str(), m_clientConnection.remote_endpoint().port());
        m_clientConnection.responseFinished(ec);
        return;
    }

    ESP_LOGI(TAG, "expected=%zd actual=%zd for (%s:%hi)", html.size(), length,
             m_clientConnection.remote_endpoint().address().to_string().c_str(), m_clientConnection.remote_endpoint().port());

    m_clientConnection.responseFinished(ec);
}

void WebsocketResponseHandler::writtenWebsocket(std::error_code ec, std::size_t length)
{
    if (ec)
    {
        ESP_LOGW(TAG, "error: %i (%s:%hi)", ec.value(),
                 m_clientConnection.remote_endpoint().address().to_string().c_str(), m_clientConnection.remote_endpoint().port());
        m_clientConnection.responseFinished(ec);
        return;
    }

    ESP_LOGI(TAG, "expected=%zd actual=%zd for (%s:%hi)", m_response.size(), length,
             m_clientConnection.remote_endpoint().address().to_string().c_str(), m_clientConnection.remote_endpoint().port());

    m_clientConnection.upgradeWebsocket();
}