diff --git a/src/WebSockets.h b/src/WebSockets.h index a3419bc..7bb4303 100644 --- a/src/WebSockets.h +++ b/src/WebSockets.h @@ -47,6 +47,8 @@ #define WEBSOCKETS_MAX_DATA_SIZE (15*1024) #define WEBSOCKETS_USE_BIG_MEM #define GET_FREE_HEAP ESP.getFreeHeap() +// moves all Header strings to Flash (~300 Byte) +//#define WEBSOCKETS_SAVE_RAM #else #ifdef STM32_DEVICE #define WEBSOCKETS_MAX_DATA_SIZE (15*1024) @@ -55,6 +57,8 @@ #else //atmega328p has only 2KB ram! #define WEBSOCKETS_MAX_DATA_SIZE (1024) +// moves all Header strings to Flash +#define WEBSOCKETS_SAVE_RAM #endif #endif @@ -134,6 +138,12 @@ #error "no network type selected!" #endif +// moves all Header strings to Flash (~300 Byte) +#ifdef WEBSOCKETS_SAVE_RAM +#define WEBSOCKETS_STRING(var) F(var) +#else +#define WEBSOCKETS_STRING(var) var +#endif typedef enum { WSC_NOT_CONNECTED, diff --git a/src/WebSocketsClient.cpp b/src/WebSocketsClient.cpp index 76e354c..c4a8206 100644 --- a/src/WebSocketsClient.cpp +++ b/src/WebSocketsClient.cpp @@ -404,12 +404,15 @@ void WebSocketsClient::handleClientData(void) { } #endif + /** * send the WebSocket header to Server * @param client WSclient_t * ptr to the client struct */ void WebSocketsClient::sendHeader(WSclient_t * client) { + static const char * NEW_LINE = "\r\n"; + DEBUG_WEBSOCKETS("[WS-Client][sendHeader] sending header...\n"); uint8_t randomKey[16] = { 0 }; @@ -424,45 +427,59 @@ void WebSocketsClient::sendHeader(WSclient_t * client) { unsigned long start = micros(); #endif - String transport; String handshake; - if(!client->isSocketIO || (client->isSocketIO && client->cSessionId.length() > 0)) { - if(client->isSocketIO) { - transport = "&transport=websocket&sid=" + client->cSessionId; - } - handshake = "GET " + client->cUrl + transport + " HTTP/1.1\r\n" - "Host: " + _host + ":" + _port + "\r\n" - "Connection: Upgrade\r\n" - "Upgrade: websocket\r\n" - "Sec-WebSocket-Version: 13\r\n" - "Sec-WebSocket-Key: " + client->cKey + "\r\n"; + bool ws_header = true; + String url = client->cUrl; - if(client->cProtocol.length() > 0) { - handshake += "Sec-WebSocket-Protocol: " + client->cProtocol + "\r\n"; - } - - if(client->cExtensions.length() > 0) { - handshake += "Sec-WebSocket-Extensions: " + client->cExtensions + "\r\n"; - } - - } else { - handshake = "GET " + client->cUrl + "&transport=polling HTTP/1.1\r\n" - "Host: " + _host + ":" + _port + "\r\n" - "Connection: keep-alive\r\n"; + if(client->isSocketIO) { + if(client->cSessionId.length() == 0) { + url += WEBSOCKETS_STRING("&transport=polling"); + ws_header = false; + } else { + url += WEBSOCKETS_STRING("&transport=websocket&sid="); + url += client->cSessionId; + } } - handshake += "Origin: file://\r\n" - "User-Agent: arduino-WebSocket-Client\r\n"; + handshake = WEBSOCKETS_STRING("GET "); + handshake += url + WEBSOCKETS_STRING(" HTTP/1.1\r\n" + "Host: "); + handshake += _host + ":" + _port + NEW_LINE; - if(client->base64Authorization.length() > 0) { - handshake += "Authorization: Basic " + client->base64Authorization + "\r\n"; - } + if(ws_header) { + handshake += WEBSOCKETS_STRING("Connection: Upgrade\r\n" + "Upgrade: websocket\r\n" + "Sec-WebSocket-Version: 13\r\n" + "Sec-WebSocket-Key: "); + handshake += client->cKey + NEW_LINE; - if(client->plainAuthorization.length() > 0) { - handshake += "Authorization: " + client->plainAuthorization + "\r\n"; - } + if(client->cProtocol.length() > 0) { + handshake += WEBSOCKETS_STRING("Sec-WebSocket-Protocol: "); + handshake +=client->cProtocol + NEW_LINE; + } - handshake += "\r\n"; + if(client->cExtensions.length() > 0) { + handshake += WEBSOCKETS_STRING("Sec-WebSocket-Extensions: "); + handshake +=client->cExtensions + NEW_LINE; + } + } else { + handshake += WEBSOCKETS_STRING("Connection: keep-alive\r\n"); + } + + handshake += WEBSOCKETS_STRING("Origin: file://\r\n" + "User-Agent: arduino-WebSocket-Client\r\n"); + + if(client->base64Authorization.length() > 0) { + handshake += WEBSOCKETS_STRING("Authorization: Basic "); + handshake += client->base64Authorization + NEW_LINE; + } + + if(client->plainAuthorization.length() > 0) { + handshake += WEBSOCKETS_STRING("Authorization: "); + handshake += client->plainAuthorization + NEW_LINE; + } + + handshake += NEW_LINE; DEBUG_WEBSOCKETS("[WS-Client][sendHeader] handshake %s", (uint8_t*)handshake.c_str()); client->tcp->write((uint8_t*)handshake.c_str(), handshake.length()); @@ -486,32 +503,32 @@ void WebSocketsClient::handleHeader(WSclient_t * client, String * headerLine) { if(headerLine->length() > 0) { DEBUG_WEBSOCKETS("[WS-Client][handleHeader] RX: %s\n", headerLine->c_str()); - if(headerLine->startsWith("HTTP/1.")) { + if(headerLine->startsWith(WEBSOCKETS_STRING("HTTP/1."))) { // "HTTP/1.1 101 Switching Protocols" client->cCode = headerLine->substring(9, headerLine->indexOf(' ', 9)).toInt(); } else if(headerLine->indexOf(':')) { String headerName = headerLine->substring(0, headerLine->indexOf(':')); String headerValue = headerLine->substring(headerLine->indexOf(':') + 2); - if(headerName.equalsIgnoreCase("Connection")) { - if(headerValue.equalsIgnoreCase("upgrade")) { + if(headerName.equalsIgnoreCase(WEBSOCKETS_STRING("Connection"))) { + if(headerValue.equalsIgnoreCase(WEBSOCKETS_STRING("upgrade"))) { client->cIsUpgrade = true; } - } else if(headerName.equalsIgnoreCase("Upgrade")) { - if(headerValue.equalsIgnoreCase("websocket")) { + } else if(headerName.equalsIgnoreCase(WEBSOCKETS_STRING("Upgrade"))) { + if(headerValue.equalsIgnoreCase(WEBSOCKETS_STRING("websocket"))) { client->cIsWebsocket = true; } - } else if(headerName.equalsIgnoreCase("Sec-WebSocket-Accept")) { + } else if(headerName.equalsIgnoreCase(WEBSOCKETS_STRING("Sec-WebSocket-Accept"))) { client->cAccept = headerValue; client->cAccept.trim(); // see rfc6455 - } else if(headerName.equalsIgnoreCase("Sec-WebSocket-Protocol")) { + } else if(headerName.equalsIgnoreCase(WEBSOCKETS_STRING("Sec-WebSocket-Protocol"))) { client->cProtocol = headerValue; - } else if(headerName.equalsIgnoreCase("Sec-WebSocket-Extensions")) { + } else if(headerName.equalsIgnoreCase(WEBSOCKETS_STRING("Sec-WebSocket-Extensions"))) { client->cExtensions = headerValue; - } else if(headerName.equalsIgnoreCase("Sec-WebSocket-Version")) { + } else if(headerName.equalsIgnoreCase(WEBSOCKETS_STRING("Sec-WebSocket-Version"))) { client->cVersion = headerValue.toInt(); - } else if(headerName.equalsIgnoreCase("Set-Cookie")) { - if (headerValue.indexOf("HttpOnly") > -1) { + } else if(headerName.equalsIgnoreCase(WEBSOCKETS_STRING("Set-Cookie"))) { + if (headerValue.indexOf(WEBSOCKETS_STRING("HttpOnly")) > -1) { client->cSessionId = headerValue.substring(headerValue.indexOf('=') + 1, headerValue.indexOf(";")); } else { client->cSessionId = headerValue.substring(headerValue.indexOf('=') + 1); diff --git a/src/WebSocketsServer.cpp b/src/WebSocketsServer.cpp index a7d3dc3..adc40ae 100644 --- a/src/WebSocketsServer.cpp +++ b/src/WebSocketsServer.cpp @@ -667,6 +667,7 @@ bool WebSocketsServer::hasMandatoryHeader(String headerName) { return false; } + /** * handles http header reading for WebSocket upgrade * @param client WSclient_t * ///< pointer to the client struct @@ -674,6 +675,8 @@ bool WebSocketsServer::hasMandatoryHeader(String headerName) { */ void WebSocketsServer::handleHeader(WSclient_t * client, String * headerLine) { + static const char * NEW_LINE = "\r\n"; + headerLine->trim(); // remove \r if(headerLine->length() > 0) { @@ -693,25 +696,25 @@ void WebSocketsServer::handleHeader(WSclient_t * client, String * headerLine) { String headerName = headerLine->substring(0, headerLine->indexOf(':')); String headerValue = headerLine->substring(headerLine->indexOf(':') + 2); - if(headerName.equalsIgnoreCase("Connection")) { + if(headerName.equalsIgnoreCase(WEBSOCKETS_STRING("Connection"))) { headerValue.toLowerCase(); - if(headerValue.indexOf("upgrade") >= 0) { + if(headerValue.indexOf(WEBSOCKETS_STRING("upgrade")) >= 0) { client->cIsUpgrade = true; } - } else if(headerName.equalsIgnoreCase("Upgrade")) { - if(headerValue.equalsIgnoreCase("websocket")) { + } else if(headerName.equalsIgnoreCase(WEBSOCKETS_STRING("Upgrade"))) { + if(headerValue.equalsIgnoreCase(WEBSOCKETS_STRING("websocket"))) { client->cIsWebsocket = true; } - } else if(headerName.equalsIgnoreCase("Sec-WebSocket-Version")) { + } else if(headerName.equalsIgnoreCase(WEBSOCKETS_STRING("Sec-WebSocket-Version"))) { client->cVersion = headerValue.toInt(); - } else if(headerName.equalsIgnoreCase("Sec-WebSocket-Key")) { + } else if(headerName.equalsIgnoreCase(WEBSOCKETS_STRING("Sec-WebSocket-Key"))) { client->cKey = headerValue; client->cKey.trim(); // see rfc6455 - } else if(headerName.equalsIgnoreCase("Sec-WebSocket-Protocol")) { + } else if(headerName.equalsIgnoreCase(WEBSOCKETS_STRING("Sec-WebSocket-Protocol"))) { client->cProtocol = headerValue; - } else if(headerName.equalsIgnoreCase("Sec-WebSocket-Extensions")) { + } else if(headerName.equalsIgnoreCase(WEBSOCKETS_STRING("Sec-WebSocket-Extensions"))) { client->cExtensions = headerValue; - } else if(headerName.equalsIgnoreCase("Authorization")) { + } else if(headerName.equalsIgnoreCase(WEBSOCKETS_STRING("Authorization"))) { client->base64Authorization = headerValue; } else { client->cHttpHeadersValid &= execHttpHeaderValidation(headerName, headerValue); @@ -764,7 +767,7 @@ void WebSocketsServer::handleHeader(WSclient_t * client, String * headerLine) { if(_base64Authorization.length() > 0) { if(client->base64Authorization.length() > 0) { - String auth = "Basic "; + String auth = WEBSOCKETS_STRING("Basic "); auth += _base64Authorization; if(auth != client->base64Authorization) { DEBUG_WEBSOCKETS("[WS-Server][%d][handleHeader] HTTP Authorization failed!\n", client->num); @@ -787,32 +790,30 @@ void WebSocketsServer::handleHeader(WSclient_t * client, String * headerLine) { client->status = WSC_CONNECTED; - client->tcp->write("HTTP/1.1 101 Switching Protocols\r\n" + String handshake = WEBSOCKETS_STRING("HTTP/1.1 101 Switching Protocols\r\n" "Server: arduino-WebSocketsServer\r\n" "Upgrade: websocket\r\n" "Connection: Upgrade\r\n" "Sec-WebSocket-Version: 13\r\n" "Sec-WebSocket-Accept: "); - client->tcp->write((uint8_t*)sKey.c_str(), sKey.length()); + handshake += sKey + NEW_LINE; if(_origin.length() > 0) { - String origin = "\r\nAccess-Control-Allow-Origin: "; - origin += _origin; - origin += "\r\n"; - client->tcp->write((uint8_t*)origin.c_str(), origin.length()); + handshake += WEBSOCKETS_STRING("Access-Control-Allow-Origin: "); + handshake +=_origin + NEW_LINE; } if(client->cProtocol.length() > 0) { - String protocol = "\r\nSec-WebSocket-Protocol: "; - protocol += _protocol; - protocol += "\r\n"; - client->tcp->write((uint8_t*)protocol.c_str(), protocol.length()); - } else { - client->tcp->write("\r\n"); + handshake += WEBSOCKETS_STRING("Sec-WebSocket-Protocol: "); + handshake +=_protocol + NEW_LINE; } // header end - client->tcp->write("\r\n"); + handshake += NEW_LINE; + + DEBUG_WEBSOCKETS("[WS-Server][%d][handleHeader] handshake %s", client->num, (uint8_t*)handshake.c_str()); + + client->tcp->write((uint8_t*)handshake.c_str(), handshake.length()); headerDone(client); diff --git a/src/WebSocketsServer.h b/src/WebSocketsServer.h index 6185e92..3550c6a 100644 --- a/src/WebSocketsServer.h +++ b/src/WebSocketsServer.h @@ -149,8 +149,7 @@ protected: * @param client WSclient_t * ptr to the client struct */ virtual void handleAuthorizationFailed(WSclient_t *client) { - - client->tcp->write("HTTP/1.1 401 Unauthorized\r\n" + client->tcp->write("HTTP/1.1 401 Unauthorized\r\n" "Server: arduino-WebSocket-Server\r\n" "Content-Type: text/plain\r\n" "Content-Length: 45\r\n"