diff --git a/src/AsyncWebSocket.cpp b/src/AsyncWebSocket.cpp index c9c2863..233ca3a 100644 --- a/src/AsyncWebSocket.cpp +++ b/src/AsyncWebSocket.cpp @@ -311,7 +311,7 @@ void AsyncWebSocketClient::_onAck(size_t len, uint32_t time) _controlQueue.pop(); _status = WS_DISCONNECTED; l.unlock(); - _client->close(true); + if (_client) _client->close(true); return; } _controlQueue.pop(); @@ -329,13 +329,16 @@ void AsyncWebSocketClient::_onPoll() { if (asyncWebSocketDebug) Serial.printf("AsyncWebSocketClient::_onPoll this=0x%llx task=0x%llx %s\r\n", uint64_t(this), uint64_t(xTaskGetCurrentTaskHandle()), pcTaskGetTaskName(xTaskGetCurrentTaskHandle())); + if (!_client) + return; + AsyncWebLockGuard l(_lock, "AsyncWebSocketClient::_onPoll"); - if(_client->canSend() && (!_controlQueue.empty() || !_messageQueue.empty())) + if (_client->canSend() && (!_controlQueue.empty() || !_messageQueue.empty())) { l.unlock(); _runQueue(); } - else if(_keepAlivePeriod > 0 && (millis() - _lastMessageTime) >= _keepAlivePeriod && (_controlQueue.empty() && _messageQueue.empty())) + else if (_keepAlivePeriod > 0 && (millis() - _lastMessageTime) >= _keepAlivePeriod && (_controlQueue.empty() && _messageQueue.empty())) { l.unlock(); ping((uint8_t *)AWSC_PING_PAYLOAD, AWSC_PING_PAYLOAD_LEN); @@ -346,6 +349,9 @@ void AsyncWebSocketClient::_runQueue() { if (asyncWebSocketDebug) Serial.printf("AsyncWebSocketClient::_runQueue this=0x%llx task=0x%llx %s\r\n", uint64_t(this), uint64_t(xTaskGetCurrentTaskHandle()), pcTaskGetTaskName(xTaskGetCurrentTaskHandle())); + if (!_client) + return; + AsyncWebLockGuard l(_lock, "AsyncWebSocketClient::_runQueue()"); while (!_messageQueue.empty() && _messageQueue.front().finished()) @@ -391,12 +397,15 @@ void AsyncWebSocketClient::_queueControl(uint8_t opcode, const uint8_t *data, si { if (asyncWebSocketDebug) Serial.printf("AsyncWebSocketClient::_queueControl this=0x%llx task=0x%llx %s\r\n", uint64_t(this), uint64_t(xTaskGetCurrentTaskHandle()), pcTaskGetTaskName(xTaskGetCurrentTaskHandle())); + if (!_client) + return; + { AsyncWebLockGuard l(_lock, "AsyncWebSocketClient::_queueControl"); _controlQueue.emplace(opcode, data, len, mask); } - if (_client->canSend()) + if (_client && _client->canSend()) _runQueue(); } @@ -407,12 +416,18 @@ void AsyncWebSocketClient::_queueMessage(std::shared_ptr> b if(_status != WS_CONNECTED) return; + if (!_client) + return; + { AsyncWebLockGuard l(_lock, "AsyncWebSocketClient::_queueMessage"); if (_messageQueue.size() >= WS_MAX_QUEUED_MESSAGES) { l.unlock(); - ets_printf("ERROR: Too many messages queued\n"); + ets_printf("AsyncWebSocketClient::_queueMessage: Too many messages queued, closing connection\n"); + _status = WS_DISCONNECTED; + if (_client) _client->close(true); + return; } else { @@ -420,7 +435,7 @@ void AsyncWebSocketClient::_queueMessage(std::shared_ptr> b } } - if(_client->canSend()) + if (_client && _client->canSend()) _runQueue(); } @@ -460,18 +475,22 @@ void AsyncWebSocketClient::ping(const uint8_t *data, size_t len) _queueControl(WS_PING, data, len); } -void AsyncWebSocketClient::_onError(int8_t){} +void AsyncWebSocketClient::_onError(int8_t err) +{ + Serial.printf("AsyncWebSocketClient::_onError() %i %i\r\n", _clientId, err); +} void AsyncWebSocketClient::_onTimeout(uint32_t time) { + Serial.printf("AsyncWebSocketClient::_onTimeout() %i\r\n", _clientId); (void)time; _client->close(true); } void AsyncWebSocketClient::_onDisconnect() { + Serial.printf("AsyncWebSocketClient::_onDisconnect() %i\r\n", _clientId); _client = NULL; - _server->_handleDisconnect(this); } void AsyncWebSocketClient::_onData(void *pbuf, size_t plen) @@ -725,17 +744,17 @@ void AsyncWebSocketClient::binary(const __FlashStringHelper *data, size_t len) IPAddress AsyncWebSocketClient::remoteIP() const { - if (!_client) { + if (!_client) return IPAddress(0U); - } + return _client->remoteIP(); } uint16_t AsyncWebSocketClient::remotePort() const { - if(!_client) { + if(!_client) return 0; - } + return _client->remotePort(); } @@ -768,42 +787,36 @@ AsyncWebSocketClient *AsyncWebSocket::_newClient(AsyncWebServerRequest *request) return &_clients.back(); } -//void AsyncWebSocket::_addClient(AsyncWebSocketClient * client){ -// _clients.add(client); -//} - -void AsyncWebSocket::_handleDisconnect(AsyncWebSocketClient * client){ - const auto iter = std::find_if(std::begin(_clients), std::end(_clients), - [id=client->id()](const auto &c){ return c.id() == id; }); - if (iter != std::end(_clients)) - _clients.erase(iter); +bool AsyncWebSocket::availableForWriteAll() +{ + return std::none_of(std::begin(_clients), std::end(_clients), + [](const auto &c){ return c.queueIsFull(); }); } -bool AsyncWebSocket::availableForWriteAll(){ - return std::none_of(std::begin(_clients), std::end(_clients), - [](const auto &c){ return c.queueIsFull(); }); +bool AsyncWebSocket::availableForWrite(uint32_t id) +{ + const auto iter = std::find_if(std::begin(_clients), std::end(_clients), + [id](const auto &c){ return c.id() == id; }); + if (iter == std::end(_clients)) + return true; + + return !iter->queueIsFull(); } -bool AsyncWebSocket::availableForWrite(uint32_t id){ - const auto iter = std::find_if(std::begin(_clients), std::end(_clients), - [id](const auto &c){ return c.id() == id; }); - if (iter == std::end(_clients)) - return true; // don't know why me-no-dev decided like this? - return !iter->queueIsFull(); +size_t AsyncWebSocket::count() const +{ + return std::count_if(std::begin(_clients), std::end(_clients), + [](const auto &c){ return c.status() == WS_CONNECTED; }); } -size_t AsyncWebSocket::count() const { - return std::count_if(std::begin(_clients), std::end(_clients), - [](const auto &c){ return c.status() == WS_CONNECTED; }); -} +AsyncWebSocketClient * AsyncWebSocket::client(uint32_t id) +{ + const auto iter = std::find_if(std::begin(_clients), std::end(_clients), + [id](const auto &c){ return c.id() == id && c.status() == WS_CONNECTED; }); + if (iter == std::end(_clients)) + return nullptr; -AsyncWebSocketClient * AsyncWebSocket::client(uint32_t id){ - const auto iter = std::find_if(std::begin(_clients), std::end(_clients), - [id](const auto &c){ return c.id() == id && c.status() == WS_CONNECTED; }); - if (iter == std::end(_clients)) - return nullptr; - - return &(*iter); + return &(*iter); } @@ -816,16 +829,21 @@ void AsyncWebSocket::close(uint32_t id, uint16_t code, const char * message) void AsyncWebSocket::closeAll(uint16_t code, const char * message) { for (auto &c : _clients) - { if (c.status() == WS_CONNECTED) c.close(code, message); - } } void AsyncWebSocket::cleanupClients(uint16_t maxClients) { - if (count() > maxClients) { + if (count() > maxClients) _clients.front().close(); + + for (auto iter = std::begin(_clients); iter != std::end(_clients);) + { + if (iter->shouldBeDeleted()) + iter = _clients.erase(iter); + else + iter++; } } @@ -837,10 +855,9 @@ void AsyncWebSocket::ping(uint32_t id, const uint8_t *data, size_t len) void AsyncWebSocket::pingAll(const uint8_t *data, size_t len) { - for (auto &c : _clients) { + for (auto &c : _clients) if (c.status() == WS_CONNECTED) c.ping(data, len); - } } void AsyncWebSocket::text(uint32_t id, const uint8_t *message, size_t len) @@ -909,16 +926,17 @@ void AsyncWebSocket::textAll(const __FlashStringHelper *data) PGM_P p = reinterpret_cast(data); size_t n = 0; - while (1) + while (true) { - if (pgm_read_byte(p+n) == 0) break; - n += 1; + if (pgm_read_byte(p+n) == 0) + break; + n += 1; } char *message = (char*)malloc(n+1); - if(message) + if (message) { - for(size_t b=0; bbinary(makeBuffer(message, len)); } void AsyncWebSocket::binary(uint32_t id, const char * message, size_t len) @@ -947,7 +965,8 @@ void AsyncWebSocket::binary(uint32_t id, const __FlashStringHelper *data, size_t { PGM_P p = reinterpret_cast(data); char *message = (char*) malloc(len); - if (message) { + if (message) + { for (size_t b=0; b(data); char * message = (char*) malloc(len); - if(message) { - for(size_t b=0; bprintf(format, arg); - va_end(arg); - return len; - } - return 0; + AsyncWebSocketClient *c = client(id); + if (c) + { + va_list arg; + va_start(arg, format); + size_t len = c->printf(format, arg); + va_end(arg); + return len; + } + return 0; } size_t AsyncWebSocket::printfAll(const char *format, ...) { va_list arg; - char* temp = new char[MAX_PRINTF_LEN]; - if (!temp) { + char *temp = new char[MAX_PRINTF_LEN]; + if (!temp) return 0; - } + va_start(arg, format); size_t len = vsnprintf(temp, MAX_PRINTF_LEN, format, arg); va_end(arg); @@ -1042,10 +1063,10 @@ size_t AsyncWebSocket::printf_P(uint32_t id, PGM_P formatP, ...){ size_t AsyncWebSocket::printfAll_P(PGM_P formatP, ...) { va_list arg; - char* temp = new char[MAX_PRINTF_LEN]; - if (!temp) { + char *temp = new char[MAX_PRINTF_LEN]; + if (!temp) return 0; - } + va_start(arg, formatP); size_t len = vsnprintf_P(temp, MAX_PRINTF_LEN, formatP, arg); va_end(arg); @@ -1089,15 +1110,18 @@ bool AsyncWebSocket::canHandle(AsyncWebServerRequest *request) void AsyncWebSocket::handleRequest(AsyncWebServerRequest *request) { - if (!request->hasHeader(WS_STR_VERSION) || !request->hasHeader(WS_STR_KEY)){ + if (!request->hasHeader(WS_STR_VERSION) || !request->hasHeader(WS_STR_KEY)) + { request->send(400); return; } - if ((_username != "" && _password != "") && !request->authenticate(_username.c_str(), _password.c_str())){ + if ((_username != "" && _password != "") && !request->authenticate(_username.c_str(), _password.c_str())) + { return request->requestAuthentication(); } AsyncWebHeader* version = request->getHeader(WS_STR_VERSION); - if (version->value().toInt() != 13){ + if (version->value().toInt() != 13) + { AsyncWebServerResponse *response = request->beginResponse(400); response->addHeader(WS_STR_VERSION,"13"); request->send(response); @@ -1105,7 +1129,8 @@ void AsyncWebSocket::handleRequest(AsyncWebServerRequest *request) } AsyncWebHeader* key = request->getHeader(WS_STR_KEY); AsyncWebServerResponse *response = new AsyncWebSocketResponse(key->value(), this); - if (request->hasHeader(WS_STR_PROTOCOL)){ + if (request->hasHeader(WS_STR_PROTOCOL)) + { AsyncWebHeader* protocol = request->getHeader(WS_STR_PROTOCOL); //ToDo: check protocol response->addHeader(WS_STR_PROTOCOL, protocol->value()); @@ -1118,57 +1143,63 @@ void AsyncWebSocket::handleRequest(AsyncWebServerRequest *request) * Authentication code from https://github.com/Links2004/arduinoWebSockets/blob/master/src/WebSockets.cpp#L480 */ -AsyncWebSocketResponse::AsyncWebSocketResponse(const String& key, AsyncWebSocket *server){ - _server = server; - _code = 101; - _sendContentLength = false; +AsyncWebSocketResponse::AsyncWebSocketResponse(const String& key, AsyncWebSocket *server) +{ + _server = server; + _code = 101; + _sendContentLength = false; - uint8_t * hash = (uint8_t*)malloc(20); - if(hash == NULL){ - _state = RESPONSE_FAILED; - return; - } - char * buffer = (char *) malloc(33); - if(buffer == NULL){ - free(hash); - _state = RESPONSE_FAILED; - return; - } + uint8_t * hash = (uint8_t*)malloc(20); + if(hash == NULL) + { + _state = RESPONSE_FAILED; + return; + } + char * buffer = (char *) malloc(33); + if(buffer == NULL) + { + free(hash); + _state = RESPONSE_FAILED; + return; + } #ifdef ESP8266 - sha1(key + WS_STR_UUID, hash); + sha1(key + WS_STR_UUID, hash); #else - (String&)key += WS_STR_UUID; - SHA1_CTX ctx; - SHA1Init(&ctx); - SHA1Update(&ctx, (const unsigned char*)key.c_str(), key.length()); - SHA1Final(hash, &ctx); + (String&)key += WS_STR_UUID; + SHA1_CTX ctx; + SHA1Init(&ctx); + SHA1Update(&ctx, (const unsigned char*)key.c_str(), key.length()); + SHA1Final(hash, &ctx); #endif - base64_encodestate _state; - base64_init_encodestate(&_state); - int len = base64_encode_block((const char *) hash, 20, buffer, &_state); - len = base64_encode_blockend((buffer + len), &_state); - addHeader(WS_STR_CONNECTION, WS_STR_UPGRADE); - addHeader(WS_STR_UPGRADE, "websocket"); - addHeader(WS_STR_ACCEPT,buffer); - free(buffer); - free(hash); + base64_encodestate _state; + base64_init_encodestate(&_state); + int len = base64_encode_block((const char *) hash, 20, buffer, &_state); + len = base64_encode_blockend((buffer + len), &_state); + addHeader(WS_STR_CONNECTION, WS_STR_UPGRADE); + addHeader(WS_STR_UPGRADE, "websocket"); + addHeader(WS_STR_ACCEPT,buffer); + free(buffer); + free(hash); } -void AsyncWebSocketResponse::_respond(AsyncWebServerRequest *request){ - if(_state == RESPONSE_FAILED){ - request->client()->close(true); - return; - } - String out = _assembleHead(request->version()); - request->client()->write(out.c_str(), _headLength); - _state = RESPONSE_WAIT_ACK; +void AsyncWebSocketResponse::_respond(AsyncWebServerRequest *request) +{ + if(_state == RESPONSE_FAILED) + { + request->client()->close(true); + return; + } + String out = _assembleHead(request->version()); + request->client()->write(out.c_str(), _headLength); + _state = RESPONSE_WAIT_ACK; } -size_t AsyncWebSocketResponse::_ack(AsyncWebServerRequest *request, size_t len, uint32_t time){ - (void)time; - if(len){ - //new AsyncWebSocketClient(request, _server); - _server->_newClient(request); - } - return 0; +size_t AsyncWebSocketResponse::_ack(AsyncWebServerRequest *request, size_t len, uint32_t time) +{ + (void)time; + + if(len) + _server->_newClient(request); + + return 0; } diff --git a/src/AsyncWebSocket.h b/src/AsyncWebSocket.h index 61661ef..16a34a1 100644 --- a/src/AsyncWebSocket.h +++ b/src/AsyncWebSocket.h @@ -24,7 +24,7 @@ #include #ifdef ESP32 #include -#define WS_MAX_QUEUED_MESSAGES 32 +#define WS_MAX_QUEUED_MESSAGES 16 #else #include #define WS_MAX_QUEUED_MESSAGES 8 @@ -145,6 +145,8 @@ class AsyncWebSocketClient { IPAddress remoteIP() const; uint16_t remotePort() const; + bool shouldBeDeleted() const { return !_client; } + //control frames void close(uint16_t code=0, const char * message=NULL); void ping(const uint8_t *data=NULL, size_t len=0); @@ -264,8 +266,6 @@ class AsyncWebSocket: public AsyncWebHandler { //system callbacks (do not call) uint32_t _getNextId(){ return _cNextId++; } AsyncWebSocketClient *_newClient(AsyncWebServerRequest *request); - //void _addClient(AsyncWebSocketClient * client); - void _handleDisconnect(AsyncWebSocketClient * client); void _handleEvent(AsyncWebSocketClient * client, AwsEventType type, void * arg, uint8_t *data, size_t len); virtual bool canHandle(AsyncWebServerRequest *request) override final; virtual void handleRequest(AsyncWebServerRequest *request) override final;