diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..14012d8 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,47 @@ +set(headers + src/SocketIOclient.h + src/WebSockets.h + src/WebSockets4WebServer.h + src/WebSocketsClient.h + src/WebSocketsServer.h + src/WebSocketsVersion.h + src/libb64/cdecode_inc.h + src/libb64/cencode_inc.h + src/libsha1/libsha1.h +) + +set(sources + src/SocketIOclient.cpp + src/WebSockets.cpp + src/WebSocketsClient.cpp + src/WebSocketsServer.cpp + src/libb64/cdecode.c + src/libb64/cencode.c + src/libsha1/libsha1.c +) + +set(dependencies +) + +idf_component_register( + INCLUDE_DIRS + src + SRCS + ${headers} + ${sources} + REQUIRES + ${dependencies} +) + +target_compile_options(${COMPONENT_TARGET} + PUBLIC + -DESP32 + -DWEBSOCKETS_NETWORK_TYPE=NETWORK_ESP32 + PRIVATE + -fstack-reuse=all + -fstack-protector-all + -Wno-unused-function + -Wno-deprecated-declarations + -Wno-missing-field-initializers + -Wno-parentheses +) diff --git a/src/SocketIOclient.cpp b/src/SocketIOclient.cpp index a06efa2..72a6dd4 100644 --- a/src/SocketIOclient.cpp +++ b/src/SocketIOclient.cpp @@ -5,9 +5,12 @@ * Author: links */ +#include "SocketIOclient.h" + +#include + #include "WebSockets.h" #include "WebSocketsClient.h" -#include "SocketIOclient.h" SocketIOclient::SocketIOclient() { } @@ -21,7 +24,7 @@ void SocketIOclient::begin(const char * host, uint16_t port, const char * url, c initClient(); } -void SocketIOclient::begin(String host, uint16_t port, String url, String protocol) { +void SocketIOclient::begin(std::string host, uint16_t port, std::string url, std::string protocol) { WebSocketsClient::beginSocketIO(host, port, url, protocol); WebSocketsClient::enableHeartbeat(60 * 1000, 90 * 1000, 5); initClient(); @@ -33,7 +36,7 @@ void SocketIOclient::beginSSL(const char * host, uint16_t port, const char * url initClient(); } -void SocketIOclient::beginSSL(String host, uint16_t port, String url, String protocol) { +void SocketIOclient::beginSSL(std::string host, uint16_t port, std::string url, std::string protocol) { WebSocketsClient::beginSocketIOSSL(host, port, url, protocol); WebSocketsClient::enableHeartbeat(60 * 1000, 90 * 1000, 5); initClient(); @@ -131,7 +134,7 @@ bool SocketIOclient::send(socketIOmessageType_t type, const char * payload, size return send(type, (uint8_t *)payload, length); } -bool SocketIOclient::send(socketIOmessageType_t type, String & payload) { +bool SocketIOclient::send(socketIOmessageType_t type, std::string & payload) { return send(type, (uint8_t *)payload.c_str(), payload.length()); } @@ -159,7 +162,7 @@ bool SocketIOclient::sendEVENT(const char * payload, size_t length) { return sendEVENT((uint8_t *)payload, length); } -bool SocketIOclient::sendEVENT(String & payload) { +bool SocketIOclient::sendEVENT(std::string & payload) { return sendEVENT((uint8_t *)payload.c_str(), payload.length()); } diff --git a/src/SocketIOclient.h b/src/SocketIOclient.h index 6deb168..5014f52 100644 --- a/src/SocketIOclient.h +++ b/src/SocketIOclient.h @@ -8,6 +8,8 @@ #ifndef SOCKETIOCLIENT_H_ #define SOCKETIOCLIENT_H_ +#include + #include "WebSockets.h" #define EIO_HEARTBEAT_INTERVAL 20000 @@ -47,11 +49,11 @@ class SocketIOclient : protected WebSocketsClient { virtual ~SocketIOclient(void); void begin(const char * host, uint16_t port, const char * url = "/socket.io/?EIO=3", const char * protocol = "arduino"); - void begin(String host, uint16_t port, String url = "/socket.io/?EIO=3", String protocol = "arduino"); + void begin(std::string host, uint16_t port, std::string url = "/socket.io/?EIO=3", std::string protocol = "arduino"); #ifdef HAS_SSL void beginSSL(const char * host, uint16_t port, const char * url = "/socket.io/?EIO=3", const char * protocol = "arduino"); - void beginSSL(String host, uint16_t port, String url = "/socket.io/?EIO=3", String protocol = "arduino"); + void beginSSL(std::string host, uint16_t port, std::string url = "/socket.io/?EIO=3", std::string protocol = "arduino"); #ifndef SSL_AXTLS void beginSSLWithCA(const char * host, uint16_t port, const char * url = "/socket.io/?EIO=3", const char * CA_cert = NULL, const char * protocol = "arduino"); void beginSSLWithCA(const char * host, uint16_t port, const char * url = "/socket.io/?EIO=3", BearSSL::X509List * CA_cert = NULL, const char * protocol = "arduino"); @@ -67,13 +69,13 @@ class SocketIOclient : protected WebSocketsClient { bool sendEVENT(const uint8_t * payload, size_t length = 0); bool sendEVENT(char * payload, size_t length = 0, bool headerToPayload = false); bool sendEVENT(const char * payload, size_t length = 0); - bool sendEVENT(String & payload); + bool sendEVENT(std::string & payload); bool send(socketIOmessageType_t type, uint8_t * payload, size_t length = 0, bool headerToPayload = false); bool send(socketIOmessageType_t type, const uint8_t * payload, size_t length = 0); bool send(socketIOmessageType_t type, char * payload, size_t length = 0, bool headerToPayload = false); bool send(socketIOmessageType_t type, const char * payload, size_t length = 0); - bool send(socketIOmessageType_t type, String & payload); + bool send(socketIOmessageType_t type, std::string & payload); void loop(void); diff --git a/src/WebSockets.cpp b/src/WebSockets.cpp index ef8224c..a237399 100644 --- a/src/WebSockets.cpp +++ b/src/WebSockets.cpp @@ -24,6 +24,8 @@ #include "WebSockets.h" +#include + #ifdef ESP8266 #include #endif @@ -535,15 +537,15 @@ void WebSockets::handleWebsocketPayloadCb(WSclient_t * client, bool ok, uint8_t /** * generate the key for Sec-WebSocket-Accept - * @param clientKey String - * @return String Accept Key + * @param clientKey std::string + * @return std::string Accept Key */ -String WebSockets::acceptKey(String & clientKey) { +std::string WebSockets::acceptKey(std::string & clientKey) { uint8_t sha1HashBin[20] = { 0 }; #ifdef ESP8266 sha1(clientKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11", &sha1HashBin[0]); #elif defined(ESP32) - String data = clientKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + std::string data = clientKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; esp_sha(SHA1, (unsigned char *)data.c_str(), data.length(), &sha1HashBin[0]); #else clientKey += "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; @@ -553,7 +555,7 @@ String WebSockets::acceptKey(String & clientKey) { SHA1Final(&sha1HashBin[0], &ctx); #endif - String key = base64_encode(sha1HashBin, 20); + std::string key = base64_encode(sha1HashBin, 20); key.trim(); return key; @@ -563,9 +565,9 @@ String WebSockets::acceptKey(String & clientKey) { * base64_encode * @param data uint8_t * * @param length size_t - * @return base64 encoded String + * @return base64 encoded std::string */ -String WebSockets::base64_encode(uint8_t * data, size_t length) { +std::string WebSockets::base64_encode(uint8_t * data, size_t length) { size_t size = ((length * 1.6f) + 1); char * buffer = (char *)malloc(size); if(buffer) { @@ -574,11 +576,11 @@ String WebSockets::base64_encode(uint8_t * data, size_t length) { int len = base64_encode_block((const char *)&data[0], length, &buffer[0], &_state); len = base64_encode_blockend((buffer + len), &_state); - String base64 = String(buffer); + std::string base64 = std::string(buffer); free(buffer); return base64; } - return String("-FAIL-"); + return std::string("-FAIL-"); } /** diff --git a/src/WebSockets.h b/src/WebSockets.h index b182bd5..800fa9e 100644 --- a/src/WebSockets.h +++ b/src/WebSockets.h @@ -25,12 +25,14 @@ #ifndef WEBSOCKETS_H_ #define WEBSOCKETS_H_ +#include + #ifdef STM32_DEVICE #include #define bit(b) (1UL << (b)) // Taken directly from Arduino.h #else -#include -#include +//#include +//#include #endif #ifdef ARDUINO_ARCH_AVR @@ -186,8 +188,8 @@ #elif(WEBSOCKETS_NETWORK_TYPE == NETWORK_ESP32) -#include -#include +//#include +//#include #define SSL_AXTLS #define WEBSOCKETS_NETWORK_CLASS WiFiClient #define WEBSOCKETS_NETWORK_SSL_CLASS WiFiClientSecure @@ -284,28 +286,28 @@ typedef struct { WEBSOCKETS_NETWORK_SSL_CLASS * ssl; #endif - String cUrl; ///< http url + std::string cUrl; ///< http url uint16_t cCode = 0; ///< http code bool cIsClient = false; ///< will be used for masking bool cIsUpgrade = false; ///< Connection == Upgrade bool cIsWebsocket = false; ///< Upgrade == websocket - String cSessionId; ///< client Set-Cookie (session id) - String cKey; ///< client Sec-WebSocket-Key - String cAccept; ///< client Sec-WebSocket-Accept - String cProtocol; ///< client Sec-WebSocket-Protocol - String cExtensions; ///< client Sec-WebSocket-Extensions + std::string cSessionId; ///< client Set-Cookie (session id) + std::string cKey; ///< client Sec-WebSocket-Key + std::string cAccept; ///< client Sec-WebSocket-Accept + std::string cProtocol; ///< client Sec-WebSocket-Protocol + std::string cExtensions; ///< client Sec-WebSocket-Extensions uint16_t cVersion = 0; ///< client Sec-WebSocket-Version uint8_t cWsRXsize = 0; ///< State of the RX uint8_t cWsHeader[WEBSOCKETS_MAX_HEADER_SIZE]; ///< RX WS Message buffer WSMessageHeader_t cWsHeaderDecode; - String base64Authorization; ///< Base64 encoded Auth request - String plainAuthorization; ///< Base64 encoded Auth request + std::string base64Authorization; ///< Base64 encoded Auth request + std::string plainAuthorization; ///< Base64 encoded Auth request - String extraHeaders; + std::string extraHeaders; bool cHttpHeadersValid = false; ///< non-websocket http header validity indicator size_t cMandatoryHeadersCount; ///< non-websocket mandatory http headers present count @@ -318,7 +320,7 @@ typedef struct { uint8_t pongTimeoutCount = 0; // current pong timeout count #if(WEBSOCKETS_NETWORK_TYPE == NETWORK_ESP8266_ASYNC) - String cHttpLine; ///< HTTP header lines + std::string cHttpLine; ///< HTTP header lines #endif } WSclient_t; @@ -350,8 +352,8 @@ class WebSockets { void handleWebsocketCb(WSclient_t * client); void handleWebsocketPayloadCb(WSclient_t * client, bool ok, uint8_t * payload); - String acceptKey(String & clientKey); - String base64_encode(uint8_t * data, size_t length); + std::string acceptKey(std::string & clientKey); + std::string base64_encode(uint8_t * data, size_t length); bool readCb(WSclient_t * client, uint8_t * out, size_t n, WSreadWaitCb cb); virtual size_t write(WSclient_t * client, uint8_t * out, size_t n); diff --git a/src/WebSockets4WebServer.h b/src/WebSockets4WebServer.h index ba2b020..6c962d9 100644 --- a/src/WebSockets4WebServer.h +++ b/src/WebSockets4WebServer.h @@ -25,6 +25,8 @@ #ifndef __WEBSOCKETS4WEBSERVER_H #define __WEBSOCKETS4WEBSERVER_H +#include + #include #include @@ -32,15 +34,15 @@ class WebSockets4WebServer : public WebSocketsServerCore { public: - WebSockets4WebServer(const String & origin = "", const String & protocol = "arduino") + WebSockets4WebServer(const std::string & origin = "", const std::string & protocol = "arduino") : WebSocketsServerCore(origin, protocol) { begin(); } - ESP8266WebServer::HookFunction hookForWebserver(const String & wsRootDir, WebSocketServerEvent event) { + ESP8266WebServer::HookFunction hookForWebserver(const std::string & wsRootDir, WebSocketServerEvent event) { onEvent(event); - return [&, wsRootDir](const String & method, const String & url, WiFiClient * tcpClient, ESP8266WebServer::ContentTypeFunction contentType) { + return [&, wsRootDir](const std::string & method, const std::string & url, WiFiClient * tcpClient, ESP8266WebServer::ContentTypeFunction contentType) { (void)contentType; if(!(method == "GET" && url.indexOf(wsRootDir) == 0)) { @@ -55,7 +57,7 @@ class WebSockets4WebServer : public WebSocketsServerCore { if(client) { // give "GET " - String headerLine; + std::string headerLine; headerLine.reserve(url.length() + 5); headerLine = "GET "; headerLine += url; diff --git a/src/WebSocketsClient.cpp b/src/WebSocketsClient.cpp index e3519c8..97334c0 100644 --- a/src/WebSocketsClient.cpp +++ b/src/WebSocketsClient.cpp @@ -22,9 +22,12 @@ * */ -#include "WebSockets.h" #include "WebSocketsClient.h" +#include + +#include "WebSockets.h" + WebSocketsClient::WebSocketsClient() { _cbEvent = NULL; _client.num = 0; @@ -90,7 +93,7 @@ void WebSocketsClient::begin(const char * host, uint16_t port, const char * url, DEBUG_WEBSOCKETS("[WS-Client] Websocket Version: " WEBSOCKETS_VERSION "\n"); } -void WebSocketsClient::begin(String host, uint16_t port, String url, String protocol) { +void WebSocketsClient::begin(std::string host, uint16_t port, std::string url, std::string protocol) { begin(host.c_str(), port, url.c_str(), protocol.c_str()); } @@ -107,7 +110,7 @@ void WebSocketsClient::beginSSL(const char * host, uint16_t port, const char * u _CA_cert = NULL; } -void WebSocketsClient::beginSSL(String host, uint16_t port, String url, String fingerprint, String protocol) { +void WebSocketsClient::beginSSL(std::string host, uint16_t port, std::string url, std::string fingerprint, std::string protocol) { beginSSL(host.c_str(), port, url.c_str(), fingerprint.c_str(), protocol.c_str()); } @@ -153,7 +156,7 @@ void WebSocketsClient::beginSocketIO(const char * host, uint16_t port, const cha _client.isSocketIO = true; } -void WebSocketsClient::beginSocketIO(String host, uint16_t port, String url, String protocol) { +void WebSocketsClient::beginSocketIO(std::string host, uint16_t port, std::string url, std::string protocol) { beginSocketIO(host.c_str(), port, url.c_str(), protocol.c_str()); } @@ -165,7 +168,7 @@ void WebSocketsClient::beginSocketIOSSL(const char * host, uint16_t port, const _fingerprint = SSL_FINGERPRINT_NULL; } -void WebSocketsClient::beginSocketIOSSL(String host, uint16_t port, String url, String protocol) { +void WebSocketsClient::beginSocketIOSSL(std::string host, uint16_t port, std::string url, std::string protocol) { beginSocketIOSSL(host.c_str(), port, url.c_str(), protocol.c_str()); } @@ -320,7 +323,7 @@ bool WebSocketsClient::sendTXT(const char * payload, size_t length) { return sendTXT((uint8_t *)payload, length); } -bool WebSocketsClient::sendTXT(String & payload) { +bool WebSocketsClient::sendTXT(std::string & payload) { return sendTXT((uint8_t *)payload.c_str(), payload.length()); } @@ -365,7 +368,7 @@ bool WebSocketsClient::sendPing(uint8_t * payload, size_t length) { return false; } -bool WebSocketsClient::sendPing(String & payload) { +bool WebSocketsClient::sendPing(std::string & payload) { return sendPing((uint8_t *)payload.c_str(), payload.length()); } @@ -386,7 +389,7 @@ void WebSocketsClient::disconnect(void) { */ void WebSocketsClient::setAuthorization(const char * user, const char * password) { if(user && password) { - String auth = user; + std::string auth = user; auth += ":"; auth += password; _client.base64Authorization = base64_encode((uint8_t *)auth.c_str(), auth.length()); @@ -565,13 +568,13 @@ void WebSocketsClient::handleClientData(void) { if(len > 0) { switch(_client.status) { case WSC_HEADER: { - String headerLine = _client.tcp->readStringUntil('\n'); + std::string headerLine = _client.tcp->readStringUntil('\n'); handleHeader(&_client, &headerLine); } break; case WSC_BODY: { char buf[256] = { 0 }; _client.tcp->readBytes(&buf[0], std::min((size_t)len, sizeof(buf))); - String bodyLine = buf; + std::string bodyLine = buf; handleHeader(&_client, &bodyLine); } break; case WSC_CONNECTED: @@ -607,9 +610,9 @@ void WebSocketsClient::sendHeader(WSclient_t * client) { unsigned long start = micros(); #endif - String handshake; + std::string handshake; bool ws_header = true; - String url = client->cUrl; + std::string url = client->cUrl; if(client->isSocketIO) { if(client->cSessionId.length() == 0) { @@ -682,13 +685,13 @@ void WebSocketsClient::sendHeader(WSclient_t * client) { * handle the WebSocket header reading * @param client WSclient_t * ptr to the client struct */ -void WebSocketsClient::handleHeader(WSclient_t * client, String * headerLine) { +void WebSocketsClient::handleHeader(WSclient_t * client, std::string * headerLine) { headerLine->trim(); // remove \r // this code handels the http body for Socket.IO V3 requests if(headerLine->length() > 0 && client->isSocketIO && client->status == WSC_BODY && client->cSessionId.length() == 0) { DEBUG_WEBSOCKETS("[WS-Client][handleHeader] socket.io json: %s\n", headerLine->c_str()); - String sid_begin = WEBSOCKETS_STRING("\"sid\":\""); + std::string sid_begin = WEBSOCKETS_STRING("\"sid\":\""); if(headerLine->indexOf(sid_begin) > -1) { int start = headerLine->indexOf(sid_begin) + sid_begin.length(); int end = headerLine->indexOf('"', start); @@ -708,8 +711,8 @@ void WebSocketsClient::handleHeader(WSclient_t * client, String * headerLine) { // "HTTP/1.1 101 Switching Protocols" client->cCode = headerLine->substring(9, headerLine->indexOf(' ', 9)).toInt(); } else if(headerLine->indexOf(':') >= 0) { - String headerName = headerLine->substring(0, headerLine->indexOf(':')); - String headerValue = headerLine->substring(headerLine->indexOf(':') + 1); + std::string headerName = headerLine->substring(0, headerLine->indexOf(':')); + std::string headerValue = headerLine->substring(headerLine->indexOf(':') + 1); // remove space in the beginning (RFC2616) if(headerValue[0] == ' ') { @@ -800,7 +803,7 @@ void WebSocketsClient::handleHeader(WSclient_t * client, String * headerLine) { ok = false; } else { // generate Sec-WebSocket-Accept key for check - String sKey = acceptKey(client->cKey); + std::string sKey = acceptKey(client->cKey); if(sKey != client->cAccept) { DEBUG_WEBSOCKETS("[WS-Client][handleHeader] Sec-WebSocket-Accept is wrong\n"); ok = false; diff --git a/src/WebSocketsClient.h b/src/WebSocketsClient.h index efa7631..d1b7827 100644 --- a/src/WebSocketsClient.h +++ b/src/WebSocketsClient.h @@ -25,6 +25,8 @@ #ifndef WEBSOCKETSCLIENT_H_ #define WEBSOCKETSCLIENT_H_ +#include + #include "WebSockets.h" class WebSocketsClient : protected WebSockets { @@ -39,13 +41,13 @@ class WebSocketsClient : protected WebSockets { virtual ~WebSocketsClient(void); void begin(const char * host, uint16_t port, const char * url = "/", const char * protocol = "arduino"); - void begin(String host, uint16_t port, String url = "/", String protocol = "arduino"); + void begin(std::string host, uint16_t port, std::string url = "/", std::string protocol = "arduino"); void begin(IPAddress host, uint16_t port, const char * url = "/", const char * protocol = "arduino"); #if defined(HAS_SSL) #ifdef SSL_AXTLS void beginSSL(const char * host, uint16_t port, const char * url = "/", const char * fingerprint = "", const char * protocol = "arduino"); - void beginSSL(String host, uint16_t port, String url = "/", String fingerprint = "", String protocol = "arduino"); + void beginSSL(std::string host, uint16_t port, std::string url = "/", std::string fingerprint = "", std::string protocol = "arduino"); #else void beginSSL(const char * host, uint16_t port, const char * url = "/", const uint8_t * fingerprint = NULL, const char * protocol = "arduino"); void beginSslWithCA(const char * host, uint16_t port, const char * url = "/", BearSSL::X509List * CA_cert = NULL, const char * protocol = "arduino"); @@ -56,11 +58,11 @@ class WebSocketsClient : protected WebSockets { #endif void beginSocketIO(const char * host, uint16_t port, const char * url = "/socket.io/?EIO=3", const char * protocol = "arduino"); - void beginSocketIO(String host, uint16_t port, String url = "/socket.io/?EIO=3", String protocol = "arduino"); + void beginSocketIO(std::string host, uint16_t port, std::string url = "/socket.io/?EIO=3", std::string protocol = "arduino"); #if defined(HAS_SSL) void beginSocketIOSSL(const char * host, uint16_t port, const char * url = "/socket.io/?EIO=3", const char * protocol = "arduino"); - void beginSocketIOSSL(String host, uint16_t port, String url = "/socket.io/?EIO=3", String protocol = "arduino"); + void beginSocketIOSSL(std::string host, uint16_t port, std::string url = "/socket.io/?EIO=3", std::string protocol = "arduino"); void beginSocketIOSSLWithCA(const char * host, uint16_t port, const char * url = "/socket.io/?EIO=3", const char * CA_cert = NULL, const char * protocol = "arduino"); #if defined(SSL_BARESSL) @@ -81,14 +83,14 @@ class WebSocketsClient : protected WebSockets { bool sendTXT(const uint8_t * payload, size_t length = 0); bool sendTXT(char * payload, size_t length = 0, bool headerToPayload = false); bool sendTXT(const char * payload, size_t length = 0); - bool sendTXT(String & payload); + bool sendTXT(std::string & payload); bool sendTXT(char payload); bool sendBIN(uint8_t * payload, size_t length, bool headerToPayload = false); bool sendBIN(const uint8_t * payload, size_t length); bool sendPing(uint8_t * payload = NULL, size_t length = 0); - bool sendPing(String & payload); + bool sendPing(std::string & payload); void disconnect(void); @@ -105,12 +107,12 @@ class WebSocketsClient : protected WebSockets { bool isConnected(void); protected: - String _host; + std::string _host; uint16_t _port; #if defined(HAS_SSL) #ifdef SSL_AXTLS - String _fingerprint; + std::string _fingerprint; const char * _CA_cert; #define SSL_FINGERPRINT_IS_SET (_fingerprint.length()) #define SSL_FINGERPRINT_NULL "" @@ -142,7 +144,7 @@ class WebSocketsClient : protected WebSockets { #endif void sendHeader(WSclient_t * client); - void handleHeader(WSclient_t * client, String * headerLine); + void handleHeader(WSclient_t * client, std::string * headerLine); void connectedCb(); void connectFailedCb(); diff --git a/src/WebSocketsServer.cpp b/src/WebSocketsServer.cpp index 495cb55..a9eea51 100644 --- a/src/WebSocketsServer.cpp +++ b/src/WebSocketsServer.cpp @@ -22,10 +22,13 @@ * */ -#include "WebSockets.h" #include "WebSocketsServer.h" -WebSocketsServerCore::WebSocketsServerCore(const String & origin, const String & protocol) { +#include + +#include "WebSockets.h" + +WebSocketsServerCore::WebSocketsServerCore(const std::string & origin, const std::string & protocol) { _origin = origin; _protocol = protocol; _runnning = false; @@ -40,7 +43,7 @@ WebSocketsServerCore::WebSocketsServerCore(const String & origin, const String & _mandatoryHttpHeaderCount = 0; } -WebSocketsServer::WebSocketsServer(uint16_t port, const String & origin, const String & protocol) +WebSocketsServer::WebSocketsServer(uint16_t port, const std::string & origin, const std::string & protocol) : WebSocketsServerCore(origin, protocol) { _port = port; @@ -130,7 +133,7 @@ void WebSocketsServerCore::onValidateHttpHeader( delete[] _mandatoryHttpHeaders; _mandatoryHttpHeaderCount = mandatoryHttpHeaderCount; - _mandatoryHttpHeaders = new String[_mandatoryHttpHeaderCount]; + _mandatoryHttpHeaders = new std::string[_mandatoryHttpHeaderCount]; for(size_t i = 0; i < _mandatoryHttpHeaderCount; i++) { _mandatoryHttpHeaders[i] = mandatoryHttpHeaders[i]; @@ -171,7 +174,7 @@ bool WebSocketsServerCore::sendTXT(uint8_t num, const char * payload, size_t len return sendTXT(num, (uint8_t *)payload, length); } -bool WebSocketsServerCore::sendTXT(uint8_t num, String & payload) { +bool WebSocketsServerCore::sendTXT(uint8_t num, std::string & payload) { return sendTXT(num, (uint8_t *)payload.c_str(), payload.length()); } @@ -213,7 +216,7 @@ bool WebSocketsServerCore::broadcastTXT(const char * payload, size_t length) { return broadcastTXT((uint8_t *)payload, length); } -bool WebSocketsServerCore::broadcastTXT(String & payload) { +bool WebSocketsServerCore::broadcastTXT(std::string & payload) { return broadcastTXT((uint8_t *)payload.c_str(), payload.length()); } @@ -284,7 +287,7 @@ bool WebSocketsServerCore::sendPing(uint8_t num, uint8_t * payload, size_t lengt return false; } -bool WebSocketsServerCore::sendPing(uint8_t num, String & payload) { +bool WebSocketsServerCore::sendPing(uint8_t num, std::string & payload) { return sendPing(num, (uint8_t *)payload.c_str(), payload.length()); } @@ -309,7 +312,7 @@ bool WebSocketsServerCore::broadcastPing(uint8_t * payload, size_t length) { return ret; } -bool WebSocketsServerCore::broadcastPing(String & payload) { +bool WebSocketsServerCore::broadcastPing(std::string & payload) { return broadcastPing((uint8_t *)payload.c_str(), payload.length()); } @@ -347,7 +350,7 @@ void WebSocketsServerCore::disconnect(uint8_t num) { */ void WebSocketsServerCore::setAuthorization(const char * user, const char * password) { if(user && password) { - String auth = user; + std::string auth = user; auth += ":"; auth += password; _base64Authorization = base64_encode((uint8_t *)auth.c_str(), auth.length()); @@ -662,7 +665,7 @@ void WebSocketsServerCore::handleClientData(void) { //DEBUG_WEBSOCKETS("[WS-Server][%d][handleClientData] len: %d\n", client->num, len); switch(client->status) { case WSC_HEADER: { - String headerLine = client->tcp->readStringUntil('\n'); + std::string headerLine = client->tcp->readStringUntil('\n'); handleHeader(client, &headerLine); } break; case WSC_CONNECTED: @@ -685,9 +688,9 @@ void WebSocketsServerCore::handleClientData(void) { /* * returns an indicator whether the given named header exists in the configured _mandatoryHttpHeaders collection - * @param headerName String ///< the name of the header being checked + * @param headerName std::string ///< the name of the header being checked */ -bool WebSocketsServerCore::hasMandatoryHeader(String headerName) { +bool WebSocketsServerCore::hasMandatoryHeader(std::string headerName) { for(size_t i = 0; i < _mandatoryHttpHeaderCount; i++) { if(_mandatoryHttpHeaders[i].equalsIgnoreCase(headerName)) return true; @@ -698,9 +701,9 @@ bool WebSocketsServerCore::hasMandatoryHeader(String headerName) { /** * handles http header reading for WebSocket upgrade * @param client WSclient_t * ///< pointer to the client struct - * @param headerLine String ///< the header being read / processed + * @param headerLine std::string ///< the header being read / processed */ -void WebSocketsServerCore::handleHeader(WSclient_t * client, String * headerLine) { +void WebSocketsServerCore::handleHeader(WSclient_t * client, std::string * headerLine) { static const char * NEW_LINE = "\r\n"; headerLine->trim(); // remove \r @@ -718,8 +721,8 @@ void WebSocketsServerCore::handleHeader(WSclient_t * client, String * headerLine client->cMandatoryHeadersCount = 0; } else if(headerLine->indexOf(':') >= 0) { - String headerName = headerLine->substring(0, headerLine->indexOf(':')); - String headerValue = headerLine->substring(headerLine->indexOf(':') + 1); + std::string headerName = headerLine->substring(0, headerLine->indexOf(':')); + std::string headerValue = headerLine->substring(headerLine->indexOf(':') + 1); // remove space in the beginning (RFC2616) if(headerValue[0] == ' ') { @@ -795,7 +798,7 @@ void WebSocketsServerCore::handleHeader(WSclient_t * client, String * headerLine } if(_base64Authorization.length() > 0) { - String auth = WEBSOCKETS_STRING("Basic "); + std::string auth = WEBSOCKETS_STRING("Basic "); auth += _base64Authorization; if(auth != client->base64Authorization) { DEBUG_WEBSOCKETS("[WS-Server][%d][handleHeader] HTTP Authorization failed!\n", client->num); @@ -808,13 +811,13 @@ void WebSocketsServerCore::handleHeader(WSclient_t * client, String * headerLine DEBUG_WEBSOCKETS("[WS-Server][%d][handleHeader] Websocket connection incoming.\n", client->num); // generate Sec-WebSocket-Accept key - String sKey = acceptKey(client->cKey); + std::string sKey = acceptKey(client->cKey); DEBUG_WEBSOCKETS("[WS-Server][%d][handleHeader] - sKey: %s\n", client->num, sKey.c_str()); client->status = WSC_CONNECTED; - String handshake = WEBSOCKETS_STRING( + std::string handshake = WEBSOCKETS_STRING( "HTTP/1.1 101 Switching Protocols\r\n" "Server: arduino-WebSocketsServer\r\n" "Upgrade: websocket\r\n" diff --git a/src/WebSocketsServer.h b/src/WebSocketsServer.h index 5bdcd04..538b7a4 100644 --- a/src/WebSocketsServer.h +++ b/src/WebSocketsServer.h @@ -25,6 +25,8 @@ #ifndef WEBSOCKETSSERVER_H_ #define WEBSOCKETSSERVER_H_ +#include + #include "WebSockets.h" #ifndef WEBSOCKETS_SERVER_CLIENT_MAX @@ -33,7 +35,7 @@ class WebSocketsServerCore : protected WebSockets { public: - WebSocketsServerCore(const String & origin = "", const String & protocol = "arduino"); + WebSocketsServerCore(const std::string & origin = "", const std::string & protocol = "arduino"); virtual ~WebSocketsServerCore(void); void begin(void); @@ -41,10 +43,10 @@ class WebSocketsServerCore : protected WebSockets { #ifdef __AVR__ typedef void (*WebSocketServerEvent)(uint8_t num, WStype_t type, uint8_t * payload, size_t length); - typedef bool (*WebSocketServerHttpHeaderValFunc)(String headerName, String headerValue); + typedef bool (*WebSocketServerHttpHeaderValFunc)(std::string headerName, std::string headerValue); #else typedef std::function WebSocketServerEvent; - typedef std::function WebSocketServerHttpHeaderValFunc; + typedef std::function WebSocketServerHttpHeaderValFunc; #endif void onEvent(WebSocketServerEvent cbEvent); @@ -57,13 +59,13 @@ class WebSocketsServerCore : protected WebSockets { bool sendTXT(uint8_t num, const uint8_t * payload, size_t length = 0); bool sendTXT(uint8_t num, char * payload, size_t length = 0, bool headerToPayload = false); bool sendTXT(uint8_t num, const char * payload, size_t length = 0); - bool sendTXT(uint8_t num, String & payload); + bool sendTXT(uint8_t num, std::string & payload); bool broadcastTXT(uint8_t * payload, size_t length = 0, bool headerToPayload = false); bool broadcastTXT(const uint8_t * payload, size_t length = 0); bool broadcastTXT(char * payload, size_t length = 0, bool headerToPayload = false); bool broadcastTXT(const char * payload, size_t length = 0); - bool broadcastTXT(String & payload); + bool broadcastTXT(std::string & payload); bool sendBIN(uint8_t num, uint8_t * payload, size_t length, bool headerToPayload = false); bool sendBIN(uint8_t num, const uint8_t * payload, size_t length); @@ -72,10 +74,10 @@ class WebSocketsServerCore : protected WebSockets { bool broadcastBIN(const uint8_t * payload, size_t length); bool sendPing(uint8_t num, uint8_t * payload = NULL, size_t length = 0); - bool sendPing(uint8_t num, String & payload); + bool sendPing(uint8_t num, std::string & payload); bool broadcastPing(uint8_t * payload = NULL, size_t length = 0); - bool broadcastPing(String & payload); + bool broadcastPing(std::string & payload); void disconnect(void); void disconnect(uint8_t num); @@ -101,10 +103,10 @@ class WebSocketsServerCore : protected WebSockets { WSclient_t * newClient(WEBSOCKETS_NETWORK_CLASS * TCPclient); protected: - String _origin; - String _protocol; - String _base64Authorization; ///< Base64 encoded Auth request - String * _mandatoryHttpHeaders; + std::string _origin; + std::string _protocol; + std::string _base64Authorization; ///< Base64 encoded Auth request + std::string * _mandatoryHttpHeaders; size_t _mandatoryHttpHeaderCount; WSclient_t _clients[WEBSOCKETS_SERVER_CLIENT_MAX]; @@ -127,7 +129,7 @@ class WebSocketsServerCore : protected WebSockets { void handleClientData(void); #endif - void handleHeader(WSclient_t * client, String * headerLine); + void handleHeader(WSclient_t * client, std::string * headerLine); void handleHBPing(WSclient_t * client); // send ping in specified intervals @@ -190,7 +192,7 @@ class WebSocketsServerCore : protected WebSockets { * This mechanism can be used to enable custom authentication schemes e.g. test the value * of a session cookie to determine if a user is logged on / authenticated */ - virtual bool execHttpHeaderValidation(String headerName, String headerValue) { + virtual bool execHttpHeaderValidation(std::string headerName, std::string headerValue) { if(_httpHeaderValidationFunc) { //return the value of the custom http header validation function return _httpHeaderValidationFunc(headerName, headerValue); @@ -211,14 +213,14 @@ class WebSocketsServerCore : protected WebSockets { private: /* * returns an indicator whether the given named header exists in the configured _mandatoryHttpHeaders collection - * @param headerName String ///< the name of the header being checked + * @param headerName std::string ///< the name of the header being checked */ - bool hasMandatoryHeader(String headerName); + bool hasMandatoryHeader(std::string headerName); }; class WebSocketsServer : public WebSocketsServerCore { public: - WebSocketsServer(uint16_t port, const String & origin = "", const String & protocol = "arduino"); + WebSocketsServer(uint16_t port, const std::string & origin = "", const std::string & protocol = "arduino"); virtual ~WebSocketsServer(void); void begin(void);