diff --git a/src/WebSockets.cpp b/src/WebSockets.cpp index 9f4358a..9497ef6 100644 --- a/src/WebSockets.cpp +++ b/src/WebSockets.cpp @@ -24,6 +24,12 @@ #include "WebSockets.h" +extern "C" { +#include "libb64/cencode.h" +} + +#include + /** * * @param client WSclient_t * ptr to the client struct @@ -195,8 +201,7 @@ void WebSockets::handleWebsocket(WSclient_t * client) { switch(opCode) { case WSop_text: - DEBUG_WEBSOCKETS("[WS-Server][%d][handleWebsocket] text: %s\n", client->num, payload) - ; + DEBUG_WEBSOCKETS("[WS-Server][%d][handleWebsocket] text: %s\n", client->num, payload); // no break here! case WSop_binary: messageRecived(client, opCode, payload, payloadLen); @@ -206,8 +211,7 @@ void WebSockets::handleWebsocket(WSclient_t * client) { sendFrame(client, WSop_pong, payload, payloadLen); break; case WSop_pong: - DEBUG_WEBSOCKETS("[WS-Server][%d][handleWebsocket] get pong from Client (%s)\n", client->num, payload) - ; + DEBUG_WEBSOCKETS("[WS-Server][%d][handleWebsocket] get pong from Client (%s)\n", client->num, payload); break; case WSop_close: { @@ -240,6 +244,43 @@ void WebSockets::handleWebsocket(WSclient_t * client) { } +/** + * generate the key for Sec-WebSocket-Accept + * @param clientKey String + * @return String Accept Key + */ +String WebSockets::acceptKey(String clientKey) { + uint8_t sha1HashBin[20] = { 0 }; + sha1(clientKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11", &sha1HashBin[0]); + + String key = base64_encode(sha1HashBin, 20); + key.trim(); + + return key; +} + +/** + * base64_encode + * @param data uint8_t * + * @param length size_t + * @return base64 encoded String + */ +String WebSockets::base64_encode(uint8_t * data, size_t length) { + + char * buffer = (char *) malloc((length*1.4)+1); + if(buffer) { + base64_encodestate _state; + base64_init_encodestate(&_state); + int len = base64_encode_block((const char *) &data[0], length, &buffer[0], &_state); + len = base64_encode_blockend((buffer + len), &_state); + + String base64 = String(buffer); + free(buffer); + return base64; + } + return "-FAIL-"; +} + /** * read x byte from tcp or get timeout * @param client WSclient_t * diff --git a/src/WebSockets.h b/src/WebSockets.h index ee6aa5a..e28478f 100644 --- a/src/WebSockets.h +++ b/src/WebSockets.h @@ -104,6 +104,9 @@ class WebSockets { void handleWebsocket(WSclient_t * client); bool readWait(WSclient_t * client, uint8_t *out, size_t n); + + String acceptKey(String clientKey); + String base64_encode(uint8_t * data, size_t length); }; #endif /* WEBSOCKETS_H_ */ diff --git a/src/WebSocketsServer.cpp b/src/WebSocketsServer.cpp index 0582350..ddf72df 100644 --- a/src/WebSocketsServer.cpp +++ b/src/WebSocketsServer.cpp @@ -25,12 +25,6 @@ #include "WebSockets.h" #include "WebSocketsServer.h" -extern "C" { -#include "libb64/cencode.h" -} - -#include - WebSocketsServer::WebSocketsServer(uint16_t port) { _port = port; _server = new WiFiServer(port); @@ -95,12 +89,27 @@ void WebSocketsServer::sendTXT(uint8_t num, uint8_t * payload, size_t length) { if(num >= WEBSOCKETS_SERVER_CLIENT_MAX) { return; } + if(length == 0) { + length = strlen((const char *) payload); + } WSclient_t * client = &_clients[num]; if(clientIsConnected(client)) { sendFrame(client, WSop_text, payload, length); } } +void WebSocketsServer::sendTXT(uint8_t num, const uint8_t * payload, size_t length) { + sendTXT(num, (uint8_t *) payload, length); +} + +void WebSocketsServer::sendTXT(uint8_t num, char * payload, size_t length) { + sendTXT(num, (uint8_t *) payload, length); +} + +void WebSocketsServer::sendTXT(uint8_t num, const char * payload, size_t length) { + sendTXT(num, (uint8_t *) payload, length); +} + void WebSocketsServer::sendTXT(uint8_t num, String payload) { sendTXT(num, (uint8_t *) payload.c_str(), payload.length()); } @@ -112,6 +121,9 @@ void WebSocketsServer::sendTXT(uint8_t num, String payload) { */ void WebSocketsServer::broadcastTXT(uint8_t * payload, size_t length) { WSclient_t * client; + if(length == 0) { + length = strlen((const char *) payload); + } for(uint8_t i = 0; i < WEBSOCKETS_SERVER_CLIENT_MAX; i++) { client = &_clients[i]; if(clientIsConnected(client)) { @@ -120,6 +132,18 @@ void WebSocketsServer::broadcastTXT(uint8_t * payload, size_t length) { } } +void WebSocketsServer::broadcastTXT(const uint8_t * payload, size_t length) { + broadcastTXT((uint8_t *) payload, length); +} + +void WebSocketsServer::broadcastTXT(char * payload, size_t length) { + broadcastTXT((uint8_t *) payload, length); +} + +void WebSocketsServer::broadcastTXT(const char * payload, size_t length) { + broadcastTXT((uint8_t *) payload, length); +} + void WebSocketsServer::broadcastTXT(String payload) { broadcastTXT((uint8_t *) payload.c_str(), payload.length()); } @@ -140,6 +164,10 @@ void WebSocketsServer::sendBIN(uint8_t num, uint8_t * payload, size_t length) { } } +void WebSocketsServer::sendBIN(uint8_t num, const uint8_t * payload, size_t length) { + sendBIN(num, (uint8_t *) payload, length); +} + /** * send binary data to client all * @param payload uint8_t * @@ -155,6 +183,10 @@ void WebSocketsServer::broadcastBIN(uint8_t * payload, size_t length) { } } +void WebSocketsServer::broadcastBIN(const uint8_t * payload, size_t length) { + broadcastBIN((uint8_t *) payload, length); +} + //################################################################################# //################################################################################# //################################################################################# @@ -289,7 +321,7 @@ void WebSocketsServer::handleClientData(void) { WebSockets::handleWebsocket(client); break; default: - clientDisconnect(client); + WebSockets::clientDisconnect(client, 1002); break; } } @@ -363,19 +395,7 @@ void WebSocketsServer::handleHeader(WSclient_t * client) { DEBUG_WEBSOCKETS("[WS-Server][%d][handleHeader] Websocket connection incomming.\n", client->num); // generate Sec-WebSocket-Accept key - uint8_t sha1HashBin[20] = { 0 }; - sha1(client->cKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11", &sha1HashBin[0]); - - char sha1Base64[64] = { 0 }; - int len = 0; - - base64_encodestate _state; - base64_init_encodestate(&_state); - len = base64_encode_block((const char *) &sha1HashBin[0], 20, &sha1Base64[0], &_state); - base64_encode_blockend((sha1Base64 + len), &_state); - - client->sKey = sha1Base64; - client->sKey.trim(); + client->sKey = acceptKey(client->cKey); DEBUG_WEBSOCKETS("[WS-Server][%d][handleHeader] - sKey: %s\n", client->num, client->sKey.c_str()); @@ -398,8 +418,11 @@ void WebSocketsServer::handleHeader(WSclient_t * client) { // header end client->tcp.write("\r\n"); + // send ping + WebSockets::sendFrame(client, WSop_ping); + if(_cbEvent) { - _cbEvent(client->num, WStype_CONNECTED, NULL, 0); + _cbEvent(client->num, WStype_CONNECTED, (uint8_t *)client->cUrl.c_str(), client->cUrl.length()); } } else { diff --git a/src/WebSocketsServer.h b/src/WebSocketsServer.h index b4dd226..5e1951b 100644 --- a/src/WebSocketsServer.h +++ b/src/WebSocketsServer.h @@ -65,14 +65,23 @@ public: void onEvent(WebSocketServerEvent cbEvent); - void sendTXT(uint8_t num, uint8_t * payload, size_t length); - void broadcastTXT(uint8_t * payload, size_t length); - + void sendTXT(uint8_t num, uint8_t * payload, size_t length = 0); + void sendTXT(uint8_t num, const uint8_t * payload, size_t length = 0); + void sendTXT(uint8_t num, char * payload, size_t length = 0); + void sendTXT(uint8_t num, const char * payload, size_t length = 0); void sendTXT(uint8_t num, String payload); + + void broadcastTXT(uint8_t * payload, size_t length = 0); + void broadcastTXT(const uint8_t * payload, size_t length = 0); + void broadcastTXT(char * payload, size_t length = 0); + void broadcastTXT(const char * payload, size_t length = 0); void broadcastTXT(String payload); void sendBIN(uint8_t num, uint8_t * payload, size_t length); + void sendBIN(uint8_t num, const uint8_t * payload, size_t length); + void broadcastBIN(uint8_t * payload, size_t length); + void broadcastBIN(const uint8_t * payload, size_t length); private: uint16_t _port; diff --git a/src/libb64/cencode.c b/src/libb64/cencode.c index a8c8fee..a15b5dc 100644 --- a/src/libb64/cencode.c +++ b/src/libb64/cencode.c @@ -102,7 +102,7 @@ int base64_encode_blockend(char* code_out, base64_encodestate* state_in) case step_A: break; } - *codechar++ = '\n'; + *codechar++ = 0x00; return codechar - code_out; }