Guard websocket client against invalidation of TCP socket

Based on commit c9e4d424d5d007d9901b83d2c7215be7e77c9837 of dumbfixes
branch of 0xFEEDC0DE64 fork of ESPAsyncWebServer.

When the socket gets disconnected, its corresponding AsyncClient is set
to NULL. Several places in the code now have null pointer checks to
prevent crashes.
Message queue overflow is now an error that triggers socket
disconnection. Limit of queued messages is now 16 for ESP32.
Disconnected clients are no longer immediately removed from the client
queue. Instead the cleanupClients() function prunes all of the
disconnected clients.
Some code formatting fixes.
This commit is contained in:
Alex Villacís Lasso
2021-01-02 20:52:41 -05:00
parent 7a0d05849a
commit b949a6fcd0
2 changed files with 146 additions and 115 deletions

View File

@@ -317,7 +317,7 @@ void AsyncWebSocketClient::_onAck(size_t len, uint32_t time){
_controlQueue.pop_front(); _controlQueue.pop_front();
_status = WS_DISCONNECTED; _status = WS_DISCONNECTED;
l.unlock(); l.unlock();
_client->close(true); if (_client) _client->close(true);
return; return;
} }
_controlQueue.pop_front(); _controlQueue.pop_front();
@@ -335,13 +335,16 @@ void AsyncWebSocketClient::_onAck(size_t len, uint32_t time){
void AsyncWebSocketClient::_onPoll() void AsyncWebSocketClient::_onPoll()
{ {
if (!_client)
return;
AsyncWebLockGuard l(_lock); AsyncWebLockGuard l(_lock);
if(_client->canSend() && (!_controlQueue.empty() || !_messageQueue.empty())) if (_client->canSend() && (!_controlQueue.empty() || !_messageQueue.empty()))
{ {
l.unlock(); l.unlock();
_runQueue(); _runQueue();
} }
else if(_keepAlivePeriod > 0 && (millis() - _lastMessageTime) >= _keepAlivePeriod && (_controlQueue.empty() && _messageQueue.empty())) else if (_keepAlivePeriod > 0 && (millis() - _lastMessageTime) >= _keepAlivePeriod && (_controlQueue.empty() && _messageQueue.empty()))
{ {
l.unlock(); l.unlock();
ping((uint8_t *)AWSC_PING_PAYLOAD, AWSC_PING_PAYLOAD_LEN); ping((uint8_t *)AWSC_PING_PAYLOAD, AWSC_PING_PAYLOAD_LEN);
@@ -350,6 +353,9 @@ void AsyncWebSocketClient::_onPoll()
void AsyncWebSocketClient::_runQueue() void AsyncWebSocketClient::_runQueue()
{ {
if (!_client)
return;
AsyncWebLockGuard l(_lock); AsyncWebLockGuard l(_lock);
_clearQueue(); _clearQueue();
@@ -395,12 +401,15 @@ bool AsyncWebSocketClient::canSend() const
void AsyncWebSocketClient::_queueControl(uint8_t opcode, const uint8_t *data, size_t len, bool mask) void AsyncWebSocketClient::_queueControl(uint8_t opcode, const uint8_t *data, size_t len, bool mask)
{ {
if (!_client)
return;
{ {
AsyncWebLockGuard l(_lock); AsyncWebLockGuard l(_lock);
_controlQueue.emplace_back(opcode, data, len, mask); _controlQueue.emplace_back(opcode, data, len, mask);
} }
if (_client->canSend()) if (_client && _client->canSend())
_runQueue(); _runQueue();
} }
@@ -409,12 +418,18 @@ void AsyncWebSocketClient::_queueMessage(std::shared_ptr<std::vector<uint8_t>> b
if(_status != WS_CONNECTED) if(_status != WS_CONNECTED)
return; return;
if (!_client)
return;
{ {
AsyncWebLockGuard l(_lock); AsyncWebLockGuard l(_lock);
if (_messageQueue.size() >= WS_MAX_QUEUED_MESSAGES) if (_messageQueue.size() >= WS_MAX_QUEUED_MESSAGES)
{ {
l.unlock(); l.unlock();
ets_printf("ERROR: Too many messages queued\n"); ets_printf("AsyncWebSocketClient::_queueMessage: Too many messages queued, closing connection\n");
_status = WS_DISCONNECTED;
if (_client) _client->close(true);
return;
} }
else else
{ {
@@ -422,7 +437,7 @@ void AsyncWebSocketClient::_queueMessage(std::shared_ptr<std::vector<uint8_t>> b
} }
} }
if(_client->canSend()) if (_client && _client->canSend())
_runQueue(); _runQueue();
} }
@@ -462,8 +477,9 @@ void AsyncWebSocketClient::ping(const uint8_t *data, size_t len)
_queueControl(WS_PING, data, len); _queueControl(WS_PING, data, len);
} }
void AsyncWebSocketClient::_onError(int8_t){ void AsyncWebSocketClient::_onError(int8_t)
//Serial.println("onErr"); {
//Serial.println("onErr");
} }
void AsyncWebSocketClient::_onTimeout(uint32_t time) void AsyncWebSocketClient::_onTimeout(uint32_t time)
@@ -477,7 +493,6 @@ void AsyncWebSocketClient::_onDisconnect()
{ {
// Serial.println("onDis"); // Serial.println("onDis");
_client = NULL; _client = NULL;
_server->_handleDisconnect(this);
} }
void AsyncWebSocketClient::_onData(void *pbuf, size_t plen) void AsyncWebSocketClient::_onData(void *pbuf, size_t plen)
@@ -732,17 +747,17 @@ void AsyncWebSocketClient::binary(const __FlashStringHelper *data, size_t len)
IPAddress AsyncWebSocketClient::remoteIP() const IPAddress AsyncWebSocketClient::remoteIP() const
{ {
if (!_client) { if (!_client)
return IPAddress(0U); return IPAddress(0U);
}
return _client->remoteIP(); return _client->remoteIP();
} }
uint16_t AsyncWebSocketClient::remotePort() const uint16_t AsyncWebSocketClient::remotePort() const
{ {
if(!_client) { if(!_client)
return 0; return 0;
}
return _client->remotePort(); return _client->remotePort();
} }
@@ -774,39 +789,35 @@ AsyncWebSocketClient *AsyncWebSocket::_newClient(AsyncWebServerRequest *request)
return &_clients.back(); return &_clients.back();
} }
void AsyncWebSocket::_handleDisconnect(AsyncWebSocketClient * client){ bool AsyncWebSocket::availableForWriteAll()
const auto client_id = client->id(); {
const auto iter = std::find_if(std::begin(_clients), std::end(_clients), return std::none_of(std::begin(_clients), std::end(_clients),
[client_id](const AsyncWebSocketClient &c){ return c.id() == client_id; });
if (iter != std::end(_clients))
_clients.erase(iter);
}
bool AsyncWebSocket::availableForWriteAll(){
return std::none_of(std::begin(_clients), std::end(_clients),
[](const AsyncWebSocketClient &c){ return c.queueIsFull(); }); [](const AsyncWebSocketClient &c){ return c.queueIsFull(); });
} }
bool AsyncWebSocket::availableForWrite(uint32_t id){ bool AsyncWebSocket::availableForWrite(uint32_t id)
const auto iter = std::find_if(std::begin(_clients), std::end(_clients), {
[id](const AsyncWebSocketClient &c){ return c.id() == id; }); const auto iter = std::find_if(std::begin(_clients), std::end(_clients),
if (iter == std::end(_clients)) [id](const AsyncWebSocketClient &c){ return c.id() == id; });
return true; // don't know why me-no-dev decided like this? if (iter == std::end(_clients))
return !iter->queueIsFull(); return true;
return !iter->queueIsFull();
} }
size_t AsyncWebSocket::count() const { size_t AsyncWebSocket::count() const
return std::count_if(std::begin(_clients), std::end(_clients), {
[](const AsyncWebSocketClient &c){ return c.status() == WS_CONNECTED; }); return std::count_if(std::begin(_clients), std::end(_clients),
[](const AsyncWebSocketClient &c){ return c.status() == WS_CONNECTED; });
} }
AsyncWebSocketClient * AsyncWebSocket::client(uint32_t id){ AsyncWebSocketClient * AsyncWebSocket::client(uint32_t id)
const auto iter = std::find_if(std::begin(_clients), std::end(_clients), {
[id](const AsyncWebSocketClient &c){ return c.id() == id && c.status() == WS_CONNECTED; }); const auto iter = std::find_if(std::begin(_clients), std::end(_clients),
if (iter == std::end(_clients)) [id](const AsyncWebSocketClient &c){ return c.id() == id && c.status() == WS_CONNECTED; });
return nullptr; if (iter == std::end(_clients))
return nullptr;
return &(*iter); return &(*iter);
} }
@@ -819,16 +830,21 @@ void AsyncWebSocket::close(uint32_t id, uint16_t code, const char * message)
void AsyncWebSocket::closeAll(uint16_t code, const char * message) void AsyncWebSocket::closeAll(uint16_t code, const char * message)
{ {
for (auto &c : _clients) for (auto &c : _clients)
{
if (c.status() == WS_CONNECTED) if (c.status() == WS_CONNECTED)
c.close(code, message); c.close(code, message);
}
} }
void AsyncWebSocket::cleanupClients(uint16_t maxClients) void AsyncWebSocket::cleanupClients(uint16_t maxClients)
{ {
if (count() > maxClients) { if (count() > maxClients)
_clients.front().close(); _clients.front().close();
for (auto iter = std::begin(_clients); iter != std::end(_clients);)
{
if (iter->shouldBeDeleted())
iter = _clients.erase(iter);
else
iter++;
} }
} }
@@ -840,10 +856,9 @@ void AsyncWebSocket::ping(uint32_t id, const uint8_t *data, size_t len)
void AsyncWebSocket::pingAll(const uint8_t *data, size_t len) void AsyncWebSocket::pingAll(const uint8_t *data, size_t len)
{ {
for (auto &c : _clients) { for (auto &c : _clients)
if (c.status() == WS_CONNECTED) if (c.status() == WS_CONNECTED)
c.ping(data, len); c.ping(data, len);
}
} }
void AsyncWebSocket::text(uint32_t id, const uint8_t *message, size_t len) void AsyncWebSocket::text(uint32_t id, const uint8_t *message, size_t len)
@@ -868,14 +883,15 @@ void AsyncWebSocket::text(uint32_t id, const __FlashStringHelper *data)
PGM_P p = reinterpret_cast<PGM_P>(data); PGM_P p = reinterpret_cast<PGM_P>(data);
size_t n = 0; size_t n = 0;
while (1) while (true)
{ {
if (pgm_read_byte(p+n) == 0) break; if (pgm_read_byte(p+n) == 0)
n += 1; break;
n += 1;
} }
char * message = (char*) malloc(n+1); char * message = (char*) malloc(n+1);
if(message) if (message)
{ {
memcpy_P(message, p, n); memcpy_P(message, p, n);
message[n] = 0; message[n] = 0;
@@ -929,7 +945,7 @@ void AsyncWebSocket::textAll(const __FlashStringHelper *data)
void AsyncWebSocket::binary(uint32_t id, const uint8_t *message, size_t len) void AsyncWebSocket::binary(uint32_t id, const uint8_t *message, size_t len)
{ {
if (AsyncWebSocketClient * c = client(id)) if (AsyncWebSocketClient *c = client(id))
c->binary(makeBuffer(message, len)); c->binary(makeBuffer(message, len));
} }
void AsyncWebSocket::binary(uint32_t id, const char * message, size_t len) void AsyncWebSocket::binary(uint32_t id, const char * message, size_t len)
@@ -948,7 +964,8 @@ void AsyncWebSocket::binary(uint32_t id, const __FlashStringHelper *data, size_t
{ {
PGM_P p = reinterpret_cast<PGM_P>(data); PGM_P p = reinterpret_cast<PGM_P>(data);
char *message = (char*) malloc(len); char *message = (char*) malloc(len);
if (message) { if (message)
{
memcpy_P(message, p, len); memcpy_P(message, p, len);
binary(id, message, len); binary(id, message, len);
free(message); free(message);
@@ -983,7 +1000,8 @@ void AsyncWebSocket::binaryAll(const __FlashStringHelper *data, size_t len)
{ {
PGM_P p = reinterpret_cast<PGM_P>(data); PGM_P p = reinterpret_cast<PGM_P>(data);
char * message = (char*) malloc(len); char * message = (char*) malloc(len);
if(message) { if(message)
{
memcpy_P(message, p, len); memcpy_P(message, p, len);
binaryAll(message, len); binaryAll(message, len);
free(message); free(message);
@@ -991,24 +1009,25 @@ void AsyncWebSocket::binaryAll(const __FlashStringHelper *data, size_t len)
} }
size_t AsyncWebSocket::printf(uint32_t id, const char *format, ...){ size_t AsyncWebSocket::printf(uint32_t id, const char *format, ...){
AsyncWebSocketClient * c = client(id); AsyncWebSocketClient * c = client(id);
if(c){ if (c)
va_list arg; {
va_start(arg, format); va_list arg;
size_t len = c->printf(format, arg); va_start(arg, format);
va_end(arg); size_t len = c->printf(format, arg);
return len; va_end(arg);
} return len;
return 0; }
return 0;
} }
size_t AsyncWebSocket::printfAll(const char *format, ...) size_t AsyncWebSocket::printfAll(const char *format, ...)
{ {
va_list arg; va_list arg;
char* temp = new char[MAX_PRINTF_LEN]; char *temp = new char[MAX_PRINTF_LEN];
if (!temp) { if (!temp)
return 0; return 0;
}
va_start(arg, format); va_start(arg, format);
size_t len = vsnprintf(temp, MAX_PRINTF_LEN, format, arg); size_t len = vsnprintf(temp, MAX_PRINTF_LEN, format, arg);
va_end(arg); va_end(arg);
@@ -1041,10 +1060,10 @@ size_t AsyncWebSocket::printf_P(uint32_t id, PGM_P formatP, ...){
size_t AsyncWebSocket::printfAll_P(PGM_P formatP, ...) size_t AsyncWebSocket::printfAll_P(PGM_P formatP, ...)
{ {
va_list arg; va_list arg;
char* temp = new char[MAX_PRINTF_LEN]; char *temp = new char[MAX_PRINTF_LEN];
if (!temp) { if (!temp)
return 0; return 0;
}
va_start(arg, formatP); va_start(arg, formatP);
size_t len = vsnprintf_P(temp, MAX_PRINTF_LEN, formatP, arg); size_t len = vsnprintf_P(temp, MAX_PRINTF_LEN, formatP, arg);
va_end(arg); va_end(arg);
@@ -1099,11 +1118,13 @@ bool AsyncWebSocket::canHandle(AsyncWebServerRequest *request){
void AsyncWebSocket::handleRequest(AsyncWebServerRequest *request) void AsyncWebSocket::handleRequest(AsyncWebServerRequest *request)
{ {
if (!request->hasHeader(WS_STR_VERSION) || !request->hasHeader(WS_STR_KEY)){ if (!request->hasHeader(WS_STR_VERSION) || !request->hasHeader(WS_STR_KEY))
{
request->send(400); request->send(400);
return; return;
} }
if ((_username.length() && _password.length()) && !request->authenticate(_username.c_str(), _password.c_str())){ if ((_username.length() && _password.length()) && !request->authenticate(_username.c_str(), _password.c_str()))
{
return request->requestAuthentication(); return request->requestAuthentication();
} }
if (_handshakeHandler != nullptr){ if (_handshakeHandler != nullptr){
@@ -1113,7 +1134,8 @@ void AsyncWebSocket::handleRequest(AsyncWebServerRequest *request)
} }
} }
AsyncWebHeader* version = request->getHeader(WS_STR_VERSION); AsyncWebHeader* version = request->getHeader(WS_STR_VERSION);
if (version->value().toInt() != 13){ if (version->value().toInt() != 13)
{
AsyncWebServerResponse *response = request->beginResponse(400); AsyncWebServerResponse *response = request->beginResponse(400);
response->addHeader(WS_STR_VERSION, F("13")); response->addHeader(WS_STR_VERSION, F("13"));
request->send(response); request->send(response);
@@ -1121,7 +1143,8 @@ void AsyncWebSocket::handleRequest(AsyncWebServerRequest *request)
} }
AsyncWebHeader* key = request->getHeader(WS_STR_KEY); AsyncWebHeader* key = request->getHeader(WS_STR_KEY);
AsyncWebServerResponse *response = new AsyncWebSocketResponse(key->value(), this); AsyncWebServerResponse *response = new AsyncWebSocketResponse(key->value(), this);
if (request->hasHeader(WS_STR_PROTOCOL)){ if (request->hasHeader(WS_STR_PROTOCOL))
{
AsyncWebHeader* protocol = request->getHeader(WS_STR_PROTOCOL); AsyncWebHeader* protocol = request->getHeader(WS_STR_PROTOCOL);
//ToDo: check protocol //ToDo: check protocol
response->addHeader(WS_STR_PROTOCOL, protocol->value()); response->addHeader(WS_STR_PROTOCOL, protocol->value());
@@ -1134,56 +1157,63 @@ void AsyncWebSocket::handleRequest(AsyncWebServerRequest *request)
* Authentication code from https://github.com/Links2004/arduinoWebSockets/blob/master/src/WebSockets.cpp#L480 * Authentication code from https://github.com/Links2004/arduinoWebSockets/blob/master/src/WebSockets.cpp#L480
*/ */
AsyncWebSocketResponse::AsyncWebSocketResponse(const String& key, AsyncWebSocket *server){ AsyncWebSocketResponse::AsyncWebSocketResponse(const String& key, AsyncWebSocket *server)
_server = server; {
_code = 101; _server = server;
_sendContentLength = false; _code = 101;
_sendContentLength = false;
uint8_t * hash = (uint8_t*)malloc(20); uint8_t * hash = (uint8_t*)malloc(20);
if(hash == NULL){ if(hash == NULL)
_state = RESPONSE_FAILED; {
return; _state = RESPONSE_FAILED;
} return;
char * buffer = (char *) malloc(33); }
if(buffer == NULL){ char * buffer = (char *) malloc(33);
free(hash); if(buffer == NULL)
_state = RESPONSE_FAILED; {
return; free(hash);
} _state = RESPONSE_FAILED;
return;
}
#ifdef ESP8266 #ifdef ESP8266
sha1(key + WS_STR_UUID, hash); sha1(key + WS_STR_UUID, hash);
#else #else
(String&)key += WS_STR_UUID; (String&)key += WS_STR_UUID;
SHA1_CTX ctx; SHA1_CTX ctx;
SHA1Init(&ctx); SHA1Init(&ctx);
SHA1Update(&ctx, (const unsigned char*)key.c_str(), key.length()); SHA1Update(&ctx, (const unsigned char*)key.c_str(), key.length());
SHA1Final(hash, &ctx); SHA1Final(hash, &ctx);
#endif #endif
base64_encodestate _state; base64_encodestate _state;
base64_init_encodestate(&_state); base64_init_encodestate(&_state);
int len = base64_encode_block((const char *) hash, 20, buffer, &_state); int len = base64_encode_block((const char *) hash, 20, buffer, &_state);
len = base64_encode_blockend((buffer + len), &_state); len = base64_encode_blockend((buffer + len), &_state);
addHeader(WS_STR_CONNECTION, WS_STR_UPGRADE); addHeader(WS_STR_CONNECTION, WS_STR_UPGRADE);
addHeader(WS_STR_UPGRADE, F("websocket")); addHeader(WS_STR_UPGRADE, F("websocket"));
addHeader(WS_STR_ACCEPT,buffer); addHeader(WS_STR_ACCEPT,buffer);
free(buffer); free(buffer);
free(hash); free(hash);
} }
void AsyncWebSocketResponse::_respond(AsyncWebServerRequest *request){ void AsyncWebSocketResponse::_respond(AsyncWebServerRequest *request)
if(_state == RESPONSE_FAILED){ {
request->client()->close(true); if(_state == RESPONSE_FAILED)
return; {
} request->client()->close(true);
String out = _assembleHead(request->version()); return;
request->client()->write(out.c_str(), _headLength); }
_state = RESPONSE_WAIT_ACK; String out = _assembleHead(request->version());
request->client()->write(out.c_str(), _headLength);
_state = RESPONSE_WAIT_ACK;
} }
size_t AsyncWebSocketResponse::_ack(AsyncWebServerRequest *request, size_t len, uint32_t time){ size_t AsyncWebSocketResponse::_ack(AsyncWebServerRequest *request, size_t len, uint32_t time)
(void)time; {
if(len){ (void)time;
_server->_newClient(request);
} if(len)
return 0; _server->_newClient(request);
return 0;
} }

View File

@@ -24,7 +24,7 @@
#include <Arduino.h> #include <Arduino.h>
#ifdef ESP32 #ifdef ESP32
#include <AsyncTCP.h> #include <AsyncTCP.h>
#define WS_MAX_QUEUED_MESSAGES 32 #define WS_MAX_QUEUED_MESSAGES 16
#else #else
#include <ESPAsyncTCP.h> #include <ESPAsyncTCP.h>
#define WS_MAX_QUEUED_MESSAGES 8 #define WS_MAX_QUEUED_MESSAGES 8
@@ -146,6 +146,8 @@ class AsyncWebSocketClient {
IPAddress remoteIP() const; IPAddress remoteIP() const;
uint16_t remotePort() const; uint16_t remotePort() const;
bool shouldBeDeleted() const { return !_client; }
//control frames //control frames
void close(uint16_t code=0, const char * message=NULL); void close(uint16_t code=0, const char * message=NULL);
void ping(const uint8_t *data=NULL, size_t len=0); void ping(const uint8_t *data=NULL, size_t len=0);
@@ -273,7 +275,6 @@ class AsyncWebSocket: public AsyncWebHandler {
//system callbacks (do not call) //system callbacks (do not call)
uint32_t _getNextId(){ return _cNextId++; } uint32_t _getNextId(){ return _cNextId++; }
AsyncWebSocketClient *_newClient(AsyncWebServerRequest *request); AsyncWebSocketClient *_newClient(AsyncWebServerRequest *request);
void _handleDisconnect(AsyncWebSocketClient * client);
void _handleEvent(AsyncWebSocketClient * client, AwsEventType type, void * arg, uint8_t *data, size_t len); void _handleEvent(AsyncWebSocketClient * client, AwsEventType type, void * arg, uint8_t *data, size_t len);
virtual bool canHandle(AsyncWebServerRequest *request) override final; virtual bool canHandle(AsyncWebServerRequest *request) override final;
virtual void handleRequest(AsyncWebServerRequest *request) override final; virtual void handleRequest(AsyncWebServerRequest *request) override final;