diff --git a/src/AsyncWebSocket.cpp b/src/AsyncWebSocket.cpp index de81f3e..ca76bfb 100644 --- a/src/AsyncWebSocket.cpp +++ b/src/AsyncWebSocket.cpp @@ -317,7 +317,7 @@ void AsyncWebSocketClient::_onAck(size_t len, uint32_t time){ _controlQueue.pop_front(); _status = WS_DISCONNECTED; l.unlock(); - _client->close(true); + if (_client) _client->close(true); return; } _controlQueue.pop_front(); @@ -335,13 +335,16 @@ void AsyncWebSocketClient::_onAck(size_t len, uint32_t time){ void AsyncWebSocketClient::_onPoll() { + if (!_client) + return; + AsyncWebLockGuard l(_lock); - if(_client->canSend() && (!_controlQueue.empty() || !_messageQueue.empty())) + if (_client->canSend() && (!_controlQueue.empty() || !_messageQueue.empty())) { l.unlock(); _runQueue(); } - else if(_keepAlivePeriod > 0 && (millis() - _lastMessageTime) >= _keepAlivePeriod && (_controlQueue.empty() && _messageQueue.empty())) + else if (_keepAlivePeriod > 0 && (millis() - _lastMessageTime) >= _keepAlivePeriod && (_controlQueue.empty() && _messageQueue.empty())) { l.unlock(); ping((uint8_t *)AWSC_PING_PAYLOAD, AWSC_PING_PAYLOAD_LEN); @@ -350,6 +353,9 @@ void AsyncWebSocketClient::_onPoll() void AsyncWebSocketClient::_runQueue() { + if (!_client) + return; + AsyncWebLockGuard l(_lock); _clearQueue(); @@ -395,12 +401,15 @@ bool AsyncWebSocketClient::canSend() const void AsyncWebSocketClient::_queueControl(uint8_t opcode, const uint8_t *data, size_t len, bool mask) { + if (!_client) + return; + { AsyncWebLockGuard l(_lock); _controlQueue.emplace_back(opcode, data, len, mask); } - if (_client->canSend()) + if (_client && _client->canSend()) _runQueue(); } @@ -409,12 +418,18 @@ void AsyncWebSocketClient::_queueMessage(std::shared_ptr> b if(_status != WS_CONNECTED) return; + if (!_client) + return; + { AsyncWebLockGuard l(_lock); if (_messageQueue.size() >= WS_MAX_QUEUED_MESSAGES) { l.unlock(); - ets_printf("ERROR: Too many messages queued\n"); + ets_printf("AsyncWebSocketClient::_queueMessage: Too many messages queued, closing connection\n"); + _status = WS_DISCONNECTED; + if (_client) _client->close(true); + return; } else { @@ -422,7 +437,7 @@ void AsyncWebSocketClient::_queueMessage(std::shared_ptr> b } } - if(_client->canSend()) + if (_client && _client->canSend()) _runQueue(); } @@ -462,8 +477,9 @@ void AsyncWebSocketClient::ping(const uint8_t *data, size_t len) _queueControl(WS_PING, data, len); } -void AsyncWebSocketClient::_onError(int8_t){ - //Serial.println("onErr"); +void AsyncWebSocketClient::_onError(int8_t) +{ + //Serial.println("onErr"); } void AsyncWebSocketClient::_onTimeout(uint32_t time) @@ -477,7 +493,6 @@ void AsyncWebSocketClient::_onDisconnect() { // Serial.println("onDis"); _client = NULL; - _server->_handleDisconnect(this); } void AsyncWebSocketClient::_onData(void *pbuf, size_t plen) @@ -732,17 +747,17 @@ void AsyncWebSocketClient::binary(const __FlashStringHelper *data, size_t len) IPAddress AsyncWebSocketClient::remoteIP() const { - if (!_client) { + if (!_client) return IPAddress(0U); - } + return _client->remoteIP(); } uint16_t AsyncWebSocketClient::remotePort() const { - if(!_client) { + if(!_client) return 0; - } + return _client->remotePort(); } @@ -774,39 +789,35 @@ AsyncWebSocketClient *AsyncWebSocket::_newClient(AsyncWebServerRequest *request) return &_clients.back(); } -void AsyncWebSocket::_handleDisconnect(AsyncWebSocketClient * client){ - const auto client_id = client->id(); - const auto iter = std::find_if(std::begin(_clients), std::end(_clients), - [client_id](const AsyncWebSocketClient &c){ return c.id() == client_id; }); - if (iter != std::end(_clients)) - _clients.erase(iter); -} - -bool AsyncWebSocket::availableForWriteAll(){ - return std::none_of(std::begin(_clients), std::end(_clients), +bool AsyncWebSocket::availableForWriteAll() +{ + return std::none_of(std::begin(_clients), std::end(_clients), [](const AsyncWebSocketClient &c){ return c.queueIsFull(); }); } -bool AsyncWebSocket::availableForWrite(uint32_t id){ - const auto iter = std::find_if(std::begin(_clients), std::end(_clients), - [id](const AsyncWebSocketClient &c){ return c.id() == id; }); - if (iter == std::end(_clients)) - return true; // don't know why me-no-dev decided like this? - return !iter->queueIsFull(); +bool AsyncWebSocket::availableForWrite(uint32_t id) +{ + const auto iter = std::find_if(std::begin(_clients), std::end(_clients), + [id](const AsyncWebSocketClient &c){ return c.id() == id; }); + if (iter == std::end(_clients)) + return true; + return !iter->queueIsFull(); } -size_t AsyncWebSocket::count() const { - return std::count_if(std::begin(_clients), std::end(_clients), - [](const AsyncWebSocketClient &c){ return c.status() == WS_CONNECTED; }); +size_t AsyncWebSocket::count() const +{ + return std::count_if(std::begin(_clients), std::end(_clients), + [](const AsyncWebSocketClient &c){ return c.status() == WS_CONNECTED; }); } -AsyncWebSocketClient * AsyncWebSocket::client(uint32_t id){ - const auto iter = std::find_if(std::begin(_clients), std::end(_clients), - [id](const AsyncWebSocketClient &c){ return c.id() == id && c.status() == WS_CONNECTED; }); - if (iter == std::end(_clients)) - return nullptr; +AsyncWebSocketClient * AsyncWebSocket::client(uint32_t id) +{ + const auto iter = std::find_if(std::begin(_clients), std::end(_clients), + [id](const AsyncWebSocketClient &c){ return c.id() == id && c.status() == WS_CONNECTED; }); + if (iter == std::end(_clients)) + return nullptr; - return &(*iter); + return &(*iter); } @@ -819,16 +830,21 @@ void AsyncWebSocket::close(uint32_t id, uint16_t code, const char * message) void AsyncWebSocket::closeAll(uint16_t code, const char * message) { for (auto &c : _clients) - { if (c.status() == WS_CONNECTED) c.close(code, message); - } } void AsyncWebSocket::cleanupClients(uint16_t maxClients) { - if (count() > maxClients) { + if (count() > maxClients) _clients.front().close(); + + for (auto iter = std::begin(_clients); iter != std::end(_clients);) + { + if (iter->shouldBeDeleted()) + iter = _clients.erase(iter); + else + iter++; } } @@ -840,10 +856,9 @@ void AsyncWebSocket::ping(uint32_t id, const uint8_t *data, size_t len) void AsyncWebSocket::pingAll(const uint8_t *data, size_t len) { - for (auto &c : _clients) { + for (auto &c : _clients) if (c.status() == WS_CONNECTED) c.ping(data, len); - } } void AsyncWebSocket::text(uint32_t id, const uint8_t *message, size_t len) @@ -868,14 +883,15 @@ void AsyncWebSocket::text(uint32_t id, const __FlashStringHelper *data) PGM_P p = reinterpret_cast(data); size_t n = 0; - while (1) + while (true) { - if (pgm_read_byte(p+n) == 0) break; - n += 1; + if (pgm_read_byte(p+n) == 0) + break; + n += 1; } char * message = (char*) malloc(n+1); - if(message) + if (message) { memcpy_P(message, p, n); message[n] = 0; @@ -929,7 +945,7 @@ void AsyncWebSocket::textAll(const __FlashStringHelper *data) void AsyncWebSocket::binary(uint32_t id, const uint8_t *message, size_t len) { - if (AsyncWebSocketClient * c = client(id)) + if (AsyncWebSocketClient *c = client(id)) c->binary(makeBuffer(message, len)); } void AsyncWebSocket::binary(uint32_t id, const char * message, size_t len) @@ -948,7 +964,8 @@ void AsyncWebSocket::binary(uint32_t id, const __FlashStringHelper *data, size_t { PGM_P p = reinterpret_cast(data); char *message = (char*) malloc(len); - if (message) { + if (message) + { memcpy_P(message, p, len); binary(id, message, len); free(message); @@ -983,7 +1000,8 @@ void AsyncWebSocket::binaryAll(const __FlashStringHelper *data, size_t len) { PGM_P p = reinterpret_cast(data); char * message = (char*) malloc(len); - if(message) { + if(message) + { memcpy_P(message, p, len); binaryAll(message, len); free(message); @@ -991,24 +1009,25 @@ void AsyncWebSocket::binaryAll(const __FlashStringHelper *data, size_t len) } size_t AsyncWebSocket::printf(uint32_t id, const char *format, ...){ - AsyncWebSocketClient * c = client(id); - if(c){ - va_list arg; - va_start(arg, format); - size_t len = c->printf(format, arg); - va_end(arg); - return len; - } - return 0; + AsyncWebSocketClient * c = client(id); + if (c) + { + va_list arg; + va_start(arg, format); + size_t len = c->printf(format, arg); + va_end(arg); + return len; + } + return 0; } size_t AsyncWebSocket::printfAll(const char *format, ...) { va_list arg; - char* temp = new char[MAX_PRINTF_LEN]; - if (!temp) { + char *temp = new char[MAX_PRINTF_LEN]; + if (!temp) return 0; - } + va_start(arg, format); size_t len = vsnprintf(temp, MAX_PRINTF_LEN, format, arg); va_end(arg); @@ -1041,10 +1060,10 @@ size_t AsyncWebSocket::printf_P(uint32_t id, PGM_P formatP, ...){ size_t AsyncWebSocket::printfAll_P(PGM_P formatP, ...) { va_list arg; - char* temp = new char[MAX_PRINTF_LEN]; - if (!temp) { + char *temp = new char[MAX_PRINTF_LEN]; + if (!temp) return 0; - } + va_start(arg, formatP); size_t len = vsnprintf_P(temp, MAX_PRINTF_LEN, formatP, arg); va_end(arg); @@ -1099,11 +1118,13 @@ bool AsyncWebSocket::canHandle(AsyncWebServerRequest *request){ void AsyncWebSocket::handleRequest(AsyncWebServerRequest *request) { - if (!request->hasHeader(WS_STR_VERSION) || !request->hasHeader(WS_STR_KEY)){ + if (!request->hasHeader(WS_STR_VERSION) || !request->hasHeader(WS_STR_KEY)) + { request->send(400); return; } - if ((_username.length() && _password.length()) && !request->authenticate(_username.c_str(), _password.c_str())){ + if ((_username.length() && _password.length()) && !request->authenticate(_username.c_str(), _password.c_str())) + { return request->requestAuthentication(); } if (_handshakeHandler != nullptr){ @@ -1113,7 +1134,8 @@ void AsyncWebSocket::handleRequest(AsyncWebServerRequest *request) } } AsyncWebHeader* version = request->getHeader(WS_STR_VERSION); - if (version->value().toInt() != 13){ + if (version->value().toInt() != 13) + { AsyncWebServerResponse *response = request->beginResponse(400); response->addHeader(WS_STR_VERSION, F("13")); request->send(response); @@ -1121,7 +1143,8 @@ void AsyncWebSocket::handleRequest(AsyncWebServerRequest *request) } AsyncWebHeader* key = request->getHeader(WS_STR_KEY); AsyncWebServerResponse *response = new AsyncWebSocketResponse(key->value(), this); - if (request->hasHeader(WS_STR_PROTOCOL)){ + if (request->hasHeader(WS_STR_PROTOCOL)) + { AsyncWebHeader* protocol = request->getHeader(WS_STR_PROTOCOL); //ToDo: check protocol response->addHeader(WS_STR_PROTOCOL, protocol->value()); @@ -1134,56 +1157,63 @@ void AsyncWebSocket::handleRequest(AsyncWebServerRequest *request) * Authentication code from https://github.com/Links2004/arduinoWebSockets/blob/master/src/WebSockets.cpp#L480 */ -AsyncWebSocketResponse::AsyncWebSocketResponse(const String& key, AsyncWebSocket *server){ - _server = server; - _code = 101; - _sendContentLength = false; +AsyncWebSocketResponse::AsyncWebSocketResponse(const String& key, AsyncWebSocket *server) +{ + _server = server; + _code = 101; + _sendContentLength = false; - uint8_t * hash = (uint8_t*)malloc(20); - if(hash == NULL){ - _state = RESPONSE_FAILED; - return; - } - char * buffer = (char *) malloc(33); - if(buffer == NULL){ - free(hash); - _state = RESPONSE_FAILED; - return; - } + uint8_t * hash = (uint8_t*)malloc(20); + if(hash == NULL) + { + _state = RESPONSE_FAILED; + return; + } + char * buffer = (char *) malloc(33); + if(buffer == NULL) + { + free(hash); + _state = RESPONSE_FAILED; + return; + } #ifdef ESP8266 - sha1(key + WS_STR_UUID, hash); + sha1(key + WS_STR_UUID, hash); #else - (String&)key += WS_STR_UUID; - SHA1_CTX ctx; - SHA1Init(&ctx); - SHA1Update(&ctx, (const unsigned char*)key.c_str(), key.length()); - SHA1Final(hash, &ctx); + (String&)key += WS_STR_UUID; + SHA1_CTX ctx; + SHA1Init(&ctx); + SHA1Update(&ctx, (const unsigned char*)key.c_str(), key.length()); + SHA1Final(hash, &ctx); #endif - base64_encodestate _state; - base64_init_encodestate(&_state); - int len = base64_encode_block((const char *) hash, 20, buffer, &_state); - len = base64_encode_blockend((buffer + len), &_state); - addHeader(WS_STR_CONNECTION, WS_STR_UPGRADE); - addHeader(WS_STR_UPGRADE, F("websocket")); - addHeader(WS_STR_ACCEPT,buffer); - free(buffer); - free(hash); + base64_encodestate _state; + base64_init_encodestate(&_state); + int len = base64_encode_block((const char *) hash, 20, buffer, &_state); + len = base64_encode_blockend((buffer + len), &_state); + addHeader(WS_STR_CONNECTION, WS_STR_UPGRADE); + addHeader(WS_STR_UPGRADE, F("websocket")); + addHeader(WS_STR_ACCEPT,buffer); + free(buffer); + free(hash); } -void AsyncWebSocketResponse::_respond(AsyncWebServerRequest *request){ - if(_state == RESPONSE_FAILED){ - request->client()->close(true); - return; - } - String out = _assembleHead(request->version()); - request->client()->write(out.c_str(), _headLength); - _state = RESPONSE_WAIT_ACK; +void AsyncWebSocketResponse::_respond(AsyncWebServerRequest *request) +{ + if(_state == RESPONSE_FAILED) + { + request->client()->close(true); + return; + } + String out = _assembleHead(request->version()); + request->client()->write(out.c_str(), _headLength); + _state = RESPONSE_WAIT_ACK; } -size_t AsyncWebSocketResponse::_ack(AsyncWebServerRequest *request, size_t len, uint32_t time){ - (void)time; - if(len){ - _server->_newClient(request); - } - return 0; +size_t AsyncWebSocketResponse::_ack(AsyncWebServerRequest *request, size_t len, uint32_t time) +{ + (void)time; + + if(len) + _server->_newClient(request); + + return 0; } diff --git a/src/AsyncWebSocket.h b/src/AsyncWebSocket.h index 3b4e0f7..9a0a3b4 100644 --- a/src/AsyncWebSocket.h +++ b/src/AsyncWebSocket.h @@ -24,7 +24,7 @@ #include #ifdef ESP32 #include -#define WS_MAX_QUEUED_MESSAGES 32 +#define WS_MAX_QUEUED_MESSAGES 16 #else #include #define WS_MAX_QUEUED_MESSAGES 8 @@ -146,6 +146,8 @@ class AsyncWebSocketClient { IPAddress remoteIP() const; uint16_t remotePort() const; + bool shouldBeDeleted() const { return !_client; } + //control frames void close(uint16_t code=0, const char * message=NULL); void ping(const uint8_t *data=NULL, size_t len=0); @@ -273,7 +275,6 @@ class AsyncWebSocket: public AsyncWebHandler { //system callbacks (do not call) uint32_t _getNextId(){ return _cNextId++; } AsyncWebSocketClient *_newClient(AsyncWebServerRequest *request); - void _handleDisconnect(AsyncWebSocketClient * client); void _handleEvent(AsyncWebSocketClient * client, AwsEventType type, void * arg, uint8_t *data, size_t len); virtual bool canHandle(AsyncWebServerRequest *request) override final; virtual void handleRequest(AsyncWebServerRequest *request) override final;