diff --git a/examples/WebSocketServer/WebSocketServerHttpHeaderValidation.ino b/examples/WebSocketServer/WebSocketServerHttpHeaderValidation.ino new file mode 100644 index 0000000..5323b84 --- /dev/null +++ b/examples/WebSocketServer/WebSocketServerHttpHeaderValidation.ino @@ -0,0 +1,85 @@ +/* + * WebSocketServer.ino + * + * Created on: 22.05.2015 + * + */ + +#include + +#include +#include +#include +#include + +ESP8266WiFiMulti WiFiMulti; + +WebSocketsServer webSocket = WebSocketsServer(81); + +#define USE_SERIAL Serial1 + +const unsigned long int validSessionId = 12345; //some arbitrary value to act as a valid sessionId + +/* + * Returns a bool value as an indicator to describe whether a user is allowed to initiate a websocket upgrade + * based on the value of a cookie. This function expects the rawCookieHeaderValue to look like this "sessionId=|" + */ +bool isCookieValid(String rawCookieHeaderValue) { + + if (rawCookieHeaderValue.indexOf("sessionId") != -1) { + String sessionIdStr = rawCookieHeaderValue.substring(rawCookieHeaderValue.indexOf("sessionId=") + 10, rawCookieHeaderValue.indexOf("|")); + unsigned long int sessionId = strtoul(sessionIdStr.c_str(), NULL, 10); + return sessionId == validSessionId; + } + return false; +} + +/* + * The WebSocketServerHttpHeaderValFunc delegate passed to webSocket.onValidateHttpHeader + */ +bool validateHttpHeader(String headerName, String headerValue) { + + //assume a true response for any headers not handled by this validator + bool valid = true; + + if(headerName.equalsIgnoreCase("Cookie")) { + //if the header passed is the Cookie header, validate it according to the rules in 'isCookieValid' function + valid = isCookieValid(headerValue); + } + + return valid; +} + +void setup() { + // USE_SERIAL.begin(921600); + USE_SERIAL.begin(115200); + + //Serial.setDebugOutput(true); + USE_SERIAL.setDebugOutput(true); + + USE_SERIAL.println(); + USE_SERIAL.println(); + USE_SERIAL.println(); + + for(uint8_t t = 4; t > 0; t--) { + USE_SERIAL.printf("[SETUP] BOOT WAIT %d...\n", t); + USE_SERIAL.flush(); + delay(1000); + } + + WiFiMulti.addAP("SSID", "passpasspass"); + + while(WiFiMulti.run() != WL_CONNECTED) { + delay(100); + } + + //connecting clients must supply a valid session cookie at websocket upgrade handshake negotiation time + const char * headerkeys[] = { "Cookie" }; + webSocket.onValidateHttpHeader(validateHttpHeader, headerkeys); + webSocket.begin(); +} + +void loop() { + webSocket.loop(); +} + diff --git a/src/WebSockets.cpp b/src/WebSockets.cpp index 3b18344..3c89b9d 100644 --- a/src/WebSockets.cpp +++ b/src/WebSockets.cpp @@ -426,7 +426,7 @@ void WebSockets::handleWebsocketPayloadCb(WSclient_t * client, bool ok, uint8_t DEBUG_WEBSOCKETS("[WS][%d][handleWebsocket] text: %s\n", client->num, payload); // no break here! case WSop_binary: - messageRecived(client, header->opCode, payload, header->payloadLen); + messageReceived(client, header->opCode, payload, header->payloadLen); break; case WSop_ping: // send pong back diff --git a/src/WebSockets.h b/src/WebSockets.h index fcd85a6..b78199f 100644 --- a/src/WebSockets.h +++ b/src/WebSockets.h @@ -183,6 +183,9 @@ typedef struct { String base64Authorization; ///< Base64 encoded Auth request String plainAuthorization; ///< Base64 encoded Auth request + bool cHttpHeadersValid; ///< non-websocket http header validity indicator + size_t cMandatoryHeadersCount; ///< non-websocket mandatory http headers present count + #if (WEBSOCKETS_NETWORK_TYPE == NETWORK_ESP8266_ASYNC) String cHttpLine; ///< HTTP header lines #endif @@ -202,7 +205,7 @@ class WebSockets { virtual void clientDisconnect(WSclient_t * client); virtual bool clientIsConnected(WSclient_t * client); - virtual void messageRecived(WSclient_t * client, WSopcode_t opcode, uint8_t * payload, size_t length); + virtual void messageReceived(WSclient_t * client, WSopcode_t opcode, uint8_t * payload, size_t length); void clientDisconnect(WSclient_t * client, uint16_t code, char * reason = NULL, size_t reasonLen = 0); bool sendFrame(WSclient_t * client, WSopcode_t opcode, uint8_t * payload = NULL, size_t length = 0, bool mask = false, bool fin = true, bool headerToPayload = false); diff --git a/src/WebSocketsServer.cpp b/src/WebSocketsServer.cpp index ccde0cd..eaaf731 100644 --- a/src/WebSocketsServer.cpp +++ b/src/WebSocketsServer.cpp @@ -40,6 +40,9 @@ WebSocketsServer::WebSocketsServer(uint16_t port, String origin, String protocol _cbEvent = NULL; + _httpHeaderValidationFunc = NULL; + _mandatoryHttpHeaders = NULL; + _mandatoryHttpHeaderCount = 0; } @@ -53,10 +56,14 @@ WebSocketsServer::~WebSocketsServer() { // TODO how to close server? #endif + if (_mandatoryHttpHeaders) + delete[] _mandatoryHttpHeaders; + + _mandatoryHttpHeaderCount = 0; } /** - * calles to init the Websockets server + * called to initialize the Websocket server */ void WebSocketsServer::begin(void) { WSclient_t * client; @@ -83,6 +90,7 @@ void WebSocketsServer::begin(void) { client->base64Authorization = ""; client->cWsRXsize = 0; + #if (WEBSOCKETS_NETWORK_TYPE == NETWORK_ESP8266_ASYNC) client->cHttpLine = ""; #endif @@ -118,7 +126,30 @@ void WebSocketsServer::onEvent(WebSocketServerEvent cbEvent) { _cbEvent = cbEvent; } -/** +/* + * Sets the custom http header validator function + * If this functionality is being used, call this function prior to calling WebSocketsServer::begin + * @param httpHeaderValidationFunc WebSocketServerHttpHeaderValFunc ///< pointer to the custom http header validation function + * @param mandatoryHttpHeaders const char* ///< the array of named http headers considered to be mandatory / must be present in order for websocket upgrade to succeed + */ +void WebSocketsServer::onValidateHttpHeader( + WebSocketServerHttpHeaderValFunc validationFunc, + const char* mandatoryHttpHeaders[]) +{ + _httpHeaderValidationFunc = validationFunc; + + if (_mandatoryHttpHeaders) + delete[] _mandatoryHttpHeaders; + + _mandatoryHttpHeaderCount = (sizeof(mandatoryHttpHeaders) / sizeof(char*)); + _mandatoryHttpHeaders = new String[_mandatoryHttpHeaderCount]; + + for (size_t i = 0; i < _mandatoryHttpHeaderCount; i++) { + _mandatoryHttpHeaders[i] = mandatoryHttpHeaders[i]; + } +} + +/* * send text data to client * @param num uint8_t client id * @param payload uint8_t * @@ -279,9 +310,8 @@ void WebSocketsServer::disconnect(uint8_t num) { } - -/** - * set the Authorizatio for the http request +/* + * set the Authorization for the http request * @param user const char * * @param password const char * */ @@ -388,7 +418,7 @@ bool WebSocketsServer::newClient(WEBSOCKETS_NETWORK_CLASS * TCPclient) { * @param payload uint8_t * * @param lenght size_t */ -void WebSocketsServer::messageRecived(WSclient_t * client, WSopcode_t opcode, uint8_t * payload, size_t lenght) { +void WebSocketsServer::messageReceived(WSclient_t * client, WSopcode_t opcode, uint8_t * payload, size_t lenght) { WStype_t type = WStype_ERROR; switch(opcode) { @@ -446,6 +476,7 @@ void WebSocketsServer::clientDisconnect(WSclient_t * client) { client->cIsWebsocket = false; client->cWsRXsize = 0; + #if (WEBSOCKETS_NETWORK_TYPE == NETWORK_ESP8266_ASYNC) client->cHttpLine = ""; #endif @@ -461,7 +492,7 @@ void WebSocketsServer::clientDisconnect(WSclient_t * client) { /** * get client state * @param client WSclient_t * ptr to the client struct - * @return true = conneted + * @return true = connected */ bool WebSocketsServer::clientIsConnected(WSclient_t * client) { @@ -492,7 +523,7 @@ bool WebSocketsServer::clientIsConnected(WSclient_t * client) { } #if (WEBSOCKETS_NETWORK_TYPE != NETWORK_ESP8266_ASYNC) /** - * Handle incomming Connection Request + * Handle incoming Connection Request */ void WebSocketsServer::handleNewClients(void) { @@ -569,10 +600,22 @@ void WebSocketsServer::handleClientData(void) { } #endif +/* + * returns an indicator whether the given named header exists in the configured _mandatoryHttpHeaders collection + * @param headerName String ///< the name of the header being checked + */ +bool WebSocketsServer::hasMandatoryHeader(String headerName) { + for (size_t i = 0; i < _mandatoryHttpHeaderCount; i++) { + if (_mandatoryHttpHeaders[i].equalsIgnoreCase(headerName)) + return true; + } + return false; +} /** - * handle the WebSocket header reading - * @param client WSclient_t * ptr to the client struct + * handles http header reading for WebSocket upgrade + * @param client WSclient_t * ///< pointer to the client struct + * @param headerLine String ///< the header being read / processed */ void WebSocketsServer::handleHeader(WSclient_t * client, String * headerLine) { @@ -581,10 +624,16 @@ void WebSocketsServer::handleHeader(WSclient_t * client, String * headerLine) { if(headerLine->length() > 0) { DEBUG_WEBSOCKETS("[WS-Server][%d][handleHeader] RX: %s\n", client->num, headerLine->c_str()); - // websocket request starts allways with GET see rfc6455 + // websocket requests always start with GET see rfc6455 if(headerLine->startsWith("GET ")) { + // cut URL out client->cUrl = headerLine->substring(4, headerLine->indexOf(' ', 4)); + + //reset non-websocket http header validation state for this client + client->cHttpHeadersValid = true; + client->cMandatoryHeadersCount = 0; + } else if(headerLine->indexOf(':')) { String headerName = headerLine->substring(0, headerLine->indexOf(':')); String headerValue = headerLine->substring(headerLine->indexOf(':') + 2); @@ -609,7 +658,13 @@ void WebSocketsServer::handleHeader(WSclient_t * client, String * headerLine) { client->cExtensions = headerValue; } else if(headerName.equalsIgnoreCase("Authorization")) { client->base64Authorization = headerValue; + } else { + client->cHttpHeadersValid &= execHttpHeaderValidation(headerName, headerValue); + if (_mandatoryHttpHeaderCount > 0 && hasMandatoryHeader(headerName)) { + client->cMandatoryHeadersCount++; + } } + } else { DEBUG_WEBSOCKETS("[WS-Client][handleHeader] Header error (%s)\n", headerLine->c_str()); } @@ -619,8 +674,8 @@ void WebSocketsServer::handleHeader(WSclient_t * client, String * headerLine) { client->tcp->readStringUntil('\n', &(client->cHttpLine), std::bind(&WebSocketsServer::handleHeader, this, client, &(client->cHttpLine))); #endif } else { - DEBUG_WEBSOCKETS("[WS-Server][%d][handleHeader] Header read fin.\n", client->num); + DEBUG_WEBSOCKETS("[WS-Server][%d][handleHeader] Header read fin.\n", client->num); DEBUG_WEBSOCKETS("[WS-Server][%d][handleHeader] - cURL: %s\n", client->num, client->cUrl.c_str()); DEBUG_WEBSOCKETS("[WS-Server][%d][handleHeader] - cIsUpgrade: %d\n", client->num, client->cIsUpgrade); DEBUG_WEBSOCKETS("[WS-Server][%d][handleHeader] - cIsWebsocket: %d\n", client->num, client->cIsWebsocket); @@ -629,6 +684,8 @@ void WebSocketsServer::handleHeader(WSclient_t * client, String * headerLine) { DEBUG_WEBSOCKETS("[WS-Server][%d][handleHeader] - cExtensions: %s\n", client->num, client->cExtensions.c_str()); DEBUG_WEBSOCKETS("[WS-Server][%d][handleHeader] - cVersion: %d\n", client->num, client->cVersion); DEBUG_WEBSOCKETS("[WS-Server][%d][handleHeader] - base64Authorization: %s\n", client->num, client->base64Authorization.c_str()); + DEBUG_WEBSOCKETS("[WS-Server][%d][handleHeader] - cHttpHeadersValid: %d\n", client->num, client->cHttpHeadersValid); + DEBUG_WEBSOCKETS("[WS-Server][%d][handleHeader] - cMandatoryHeadersCount: %d\n", client->num, client->cMandatoryHeadersCount); bool ok = (client->cIsUpgrade && client->cIsWebsocket); @@ -642,6 +699,12 @@ void WebSocketsServer::handleHeader(WSclient_t * client, String * headerLine) { if(client->cVersion != 13) { ok = false; } + if(!client->cHttpHeadersValid) { + ok = false; + } + if (client->cMandatoryHeadersCount != _mandatoryHttpHeaderCount) { + ok = false; + } } if(_base64Authorization.length() > 0) { diff --git a/src/WebSocketsServer.h b/src/WebSocketsServer.h index 8b75982..d03f840 100644 --- a/src/WebSocketsServer.h +++ b/src/WebSocketsServer.h @@ -38,8 +38,10 @@ public: #ifdef __AVR__ typedef void (*WebSocketServerEvent)(uint8_t num, WStype_t type, uint8_t * payload, size_t length); + typedef bool (*WebSocketServerHttpHeaderValFunc)(String headerName, String headerValue); #else typedef std::function WebSocketServerEvent; + typedef std::function WebSocketServerHttpHeaderValFunc; #endif WebSocketsServer(uint16_t port, String origin = "", String protocol = "arduino"); @@ -55,6 +57,7 @@ public: #endif void onEvent(WebSocketServerEvent cbEvent); + void onValidateHttpHeader(WebSocketServerHttpHeaderValFunc validationFunc, const char* mandatoryHttpHeaders[]); bool sendTXT(uint8_t num, uint8_t * payload, size_t length = 0, bool headerToPayload = false); @@ -90,16 +93,19 @@ protected: String _origin; String _protocol; String _base64Authorization; ///< Base64 encoded Auth request + String * _mandatoryHttpHeaders; + size_t _mandatoryHttpHeaderCount; WEBSOCKETS_NETWORK_SERVER_CLASS * _server; WSclient_t _clients[WEBSOCKETS_SERVER_CLIENT_MAX]; WebSocketServerEvent _cbEvent; + WebSocketServerHttpHeaderValFunc _httpHeaderValidationFunc; bool newClient(WEBSOCKETS_NETWORK_CLASS * TCPclient); - void messageRecived(WSclient_t * client, WSopcode_t opcode, uint8_t * payload, size_t length); + void messageReceived(WSclient_t * client, WSopcode_t opcode, uint8_t * payload, size_t length); void clientDisconnect(WSclient_t * client); bool clientIsConnected(WSclient_t * client); @@ -111,7 +117,6 @@ protected: void handleHeader(WSclient_t * client, String * headerLine); - /** * called if a non Websocket connection is coming in. * Note: can be override @@ -162,6 +167,30 @@ protected: } } + /* + * Called at client socket connect handshake negotiation time for each http header that is not + * a websocket specific http header (not Connection, Upgrade, Sec-WebSocket-*) + * If the custom httpHeaderValidationFunc returns false for any headerName / headerValue passed, the + * socket negotiation is considered invalid and the upgrade to websockets request is denied / rejected + * This mechanism can be used to enable custom authentication schemes e.g. test the value + * of a session cookie to determine if a user is logged on / authenticated + */ + virtual bool execHttpHeaderValidation(String headerName, String headerValue) { + if(_httpHeaderValidationFunc) { + //return the value of the custom http header validation function + return _httpHeaderValidationFunc(headerName, headerValue); + } + //no custom http header validation so just assume all is good + return true; + } + +private: + /* + * returns an indicator whether the given named header exists in the configured _mandatoryHttpHeaders collection + * @param headerName String ///< the name of the header being checked + */ + bool hasMandatoryHeader(String headerName); + };