mirror of
https://github.com/me-no-dev/ESPAsyncWebServer.git
synced 2025-09-27 06:40:56 +02:00
Guard websocket client against invalidation of TCP socket
Based on commit c9e4d424d5d007d9901b83d2c7215be7e77c9837 of dumbfixes branch of 0xFEEDC0DE64 fork of ESPAsyncWebServer. When the socket gets disconnected, its corresponding AsyncClient is set to NULL. Several places in the code now have null pointer checks to prevent crashes. Message queue overflow is now an error that triggers socket disconnection. Limit of queued messages is now 16 for ESP32. Disconnected clients are no longer immediately removed from the client queue. Instead the cleanupClients() function prunes all of the disconnected clients. Some code formatting fixes.
This commit is contained in:
@@ -317,7 +317,7 @@ void AsyncWebSocketClient::_onAck(size_t len, uint32_t time){
|
||||
_controlQueue.pop_front();
|
||||
_status = WS_DISCONNECTED;
|
||||
l.unlock();
|
||||
_client->close(true);
|
||||
if (_client) _client->close(true);
|
||||
return;
|
||||
}
|
||||
_controlQueue.pop_front();
|
||||
@@ -335,13 +335,16 @@ void AsyncWebSocketClient::_onAck(size_t len, uint32_t time){
|
||||
|
||||
void AsyncWebSocketClient::_onPoll()
|
||||
{
|
||||
if (!_client)
|
||||
return;
|
||||
|
||||
AsyncWebLockGuard l(_lock);
|
||||
if(_client->canSend() && (!_controlQueue.empty() || !_messageQueue.empty()))
|
||||
if (_client->canSend() && (!_controlQueue.empty() || !_messageQueue.empty()))
|
||||
{
|
||||
l.unlock();
|
||||
_runQueue();
|
||||
}
|
||||
else if(_keepAlivePeriod > 0 && (millis() - _lastMessageTime) >= _keepAlivePeriod && (_controlQueue.empty() && _messageQueue.empty()))
|
||||
else if (_keepAlivePeriod > 0 && (millis() - _lastMessageTime) >= _keepAlivePeriod && (_controlQueue.empty() && _messageQueue.empty()))
|
||||
{
|
||||
l.unlock();
|
||||
ping((uint8_t *)AWSC_PING_PAYLOAD, AWSC_PING_PAYLOAD_LEN);
|
||||
@@ -350,6 +353,9 @@ void AsyncWebSocketClient::_onPoll()
|
||||
|
||||
void AsyncWebSocketClient::_runQueue()
|
||||
{
|
||||
if (!_client)
|
||||
return;
|
||||
|
||||
AsyncWebLockGuard l(_lock);
|
||||
|
||||
_clearQueue();
|
||||
@@ -395,12 +401,15 @@ bool AsyncWebSocketClient::canSend() const
|
||||
|
||||
void AsyncWebSocketClient::_queueControl(uint8_t opcode, const uint8_t *data, size_t len, bool mask)
|
||||
{
|
||||
if (!_client)
|
||||
return;
|
||||
|
||||
{
|
||||
AsyncWebLockGuard l(_lock);
|
||||
_controlQueue.emplace_back(opcode, data, len, mask);
|
||||
}
|
||||
|
||||
if (_client->canSend())
|
||||
if (_client && _client->canSend())
|
||||
_runQueue();
|
||||
}
|
||||
|
||||
@@ -409,12 +418,18 @@ void AsyncWebSocketClient::_queueMessage(std::shared_ptr<std::vector<uint8_t>> b
|
||||
if(_status != WS_CONNECTED)
|
||||
return;
|
||||
|
||||
if (!_client)
|
||||
return;
|
||||
|
||||
{
|
||||
AsyncWebLockGuard l(_lock);
|
||||
if (_messageQueue.size() >= WS_MAX_QUEUED_MESSAGES)
|
||||
{
|
||||
l.unlock();
|
||||
ets_printf("ERROR: Too many messages queued\n");
|
||||
ets_printf("AsyncWebSocketClient::_queueMessage: Too many messages queued, closing connection\n");
|
||||
_status = WS_DISCONNECTED;
|
||||
if (_client) _client->close(true);
|
||||
return;
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -422,7 +437,7 @@ void AsyncWebSocketClient::_queueMessage(std::shared_ptr<std::vector<uint8_t>> b
|
||||
}
|
||||
}
|
||||
|
||||
if(_client->canSend())
|
||||
if (_client && _client->canSend())
|
||||
_runQueue();
|
||||
}
|
||||
|
||||
@@ -462,8 +477,9 @@ void AsyncWebSocketClient::ping(const uint8_t *data, size_t len)
|
||||
_queueControl(WS_PING, data, len);
|
||||
}
|
||||
|
||||
void AsyncWebSocketClient::_onError(int8_t){
|
||||
//Serial.println("onErr");
|
||||
void AsyncWebSocketClient::_onError(int8_t)
|
||||
{
|
||||
//Serial.println("onErr");
|
||||
}
|
||||
|
||||
void AsyncWebSocketClient::_onTimeout(uint32_t time)
|
||||
@@ -477,7 +493,6 @@ void AsyncWebSocketClient::_onDisconnect()
|
||||
{
|
||||
// Serial.println("onDis");
|
||||
_client = NULL;
|
||||
_server->_handleDisconnect(this);
|
||||
}
|
||||
|
||||
void AsyncWebSocketClient::_onData(void *pbuf, size_t plen)
|
||||
@@ -732,17 +747,17 @@ void AsyncWebSocketClient::binary(const __FlashStringHelper *data, size_t len)
|
||||
|
||||
IPAddress AsyncWebSocketClient::remoteIP() const
|
||||
{
|
||||
if (!_client) {
|
||||
if (!_client)
|
||||
return IPAddress(0U);
|
||||
}
|
||||
|
||||
return _client->remoteIP();
|
||||
}
|
||||
|
||||
uint16_t AsyncWebSocketClient::remotePort() const
|
||||
{
|
||||
if(!_client) {
|
||||
if(!_client)
|
||||
return 0;
|
||||
}
|
||||
|
||||
return _client->remotePort();
|
||||
}
|
||||
|
||||
@@ -774,39 +789,35 @@ AsyncWebSocketClient *AsyncWebSocket::_newClient(AsyncWebServerRequest *request)
|
||||
return &_clients.back();
|
||||
}
|
||||
|
||||
void AsyncWebSocket::_handleDisconnect(AsyncWebSocketClient * client){
|
||||
const auto client_id = client->id();
|
||||
const auto iter = std::find_if(std::begin(_clients), std::end(_clients),
|
||||
[client_id](const AsyncWebSocketClient &c){ return c.id() == client_id; });
|
||||
if (iter != std::end(_clients))
|
||||
_clients.erase(iter);
|
||||
}
|
||||
|
||||
bool AsyncWebSocket::availableForWriteAll(){
|
||||
return std::none_of(std::begin(_clients), std::end(_clients),
|
||||
bool AsyncWebSocket::availableForWriteAll()
|
||||
{
|
||||
return std::none_of(std::begin(_clients), std::end(_clients),
|
||||
[](const AsyncWebSocketClient &c){ return c.queueIsFull(); });
|
||||
}
|
||||
|
||||
bool AsyncWebSocket::availableForWrite(uint32_t id){
|
||||
const auto iter = std::find_if(std::begin(_clients), std::end(_clients),
|
||||
[id](const AsyncWebSocketClient &c){ return c.id() == id; });
|
||||
if (iter == std::end(_clients))
|
||||
return true; // don't know why me-no-dev decided like this?
|
||||
return !iter->queueIsFull();
|
||||
bool AsyncWebSocket::availableForWrite(uint32_t id)
|
||||
{
|
||||
const auto iter = std::find_if(std::begin(_clients), std::end(_clients),
|
||||
[id](const AsyncWebSocketClient &c){ return c.id() == id; });
|
||||
if (iter == std::end(_clients))
|
||||
return true;
|
||||
return !iter->queueIsFull();
|
||||
}
|
||||
|
||||
size_t AsyncWebSocket::count() const {
|
||||
return std::count_if(std::begin(_clients), std::end(_clients),
|
||||
[](const AsyncWebSocketClient &c){ return c.status() == WS_CONNECTED; });
|
||||
size_t AsyncWebSocket::count() const
|
||||
{
|
||||
return std::count_if(std::begin(_clients), std::end(_clients),
|
||||
[](const AsyncWebSocketClient &c){ return c.status() == WS_CONNECTED; });
|
||||
}
|
||||
|
||||
AsyncWebSocketClient * AsyncWebSocket::client(uint32_t id){
|
||||
const auto iter = std::find_if(std::begin(_clients), std::end(_clients),
|
||||
[id](const AsyncWebSocketClient &c){ return c.id() == id && c.status() == WS_CONNECTED; });
|
||||
if (iter == std::end(_clients))
|
||||
return nullptr;
|
||||
AsyncWebSocketClient * AsyncWebSocket::client(uint32_t id)
|
||||
{
|
||||
const auto iter = std::find_if(std::begin(_clients), std::end(_clients),
|
||||
[id](const AsyncWebSocketClient &c){ return c.id() == id && c.status() == WS_CONNECTED; });
|
||||
if (iter == std::end(_clients))
|
||||
return nullptr;
|
||||
|
||||
return &(*iter);
|
||||
return &(*iter);
|
||||
}
|
||||
|
||||
|
||||
@@ -819,16 +830,21 @@ void AsyncWebSocket::close(uint32_t id, uint16_t code, const char * message)
|
||||
void AsyncWebSocket::closeAll(uint16_t code, const char * message)
|
||||
{
|
||||
for (auto &c : _clients)
|
||||
{
|
||||
if (c.status() == WS_CONNECTED)
|
||||
c.close(code, message);
|
||||
}
|
||||
}
|
||||
|
||||
void AsyncWebSocket::cleanupClients(uint16_t maxClients)
|
||||
{
|
||||
if (count() > maxClients) {
|
||||
if (count() > maxClients)
|
||||
_clients.front().close();
|
||||
|
||||
for (auto iter = std::begin(_clients); iter != std::end(_clients);)
|
||||
{
|
||||
if (iter->shouldBeDeleted())
|
||||
iter = _clients.erase(iter);
|
||||
else
|
||||
iter++;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -840,10 +856,9 @@ void AsyncWebSocket::ping(uint32_t id, const uint8_t *data, size_t len)
|
||||
|
||||
void AsyncWebSocket::pingAll(const uint8_t *data, size_t len)
|
||||
{
|
||||
for (auto &c : _clients) {
|
||||
for (auto &c : _clients)
|
||||
if (c.status() == WS_CONNECTED)
|
||||
c.ping(data, len);
|
||||
}
|
||||
}
|
||||
|
||||
void AsyncWebSocket::text(uint32_t id, const uint8_t *message, size_t len)
|
||||
@@ -868,14 +883,15 @@ void AsyncWebSocket::text(uint32_t id, const __FlashStringHelper *data)
|
||||
PGM_P p = reinterpret_cast<PGM_P>(data);
|
||||
|
||||
size_t n = 0;
|
||||
while (1)
|
||||
while (true)
|
||||
{
|
||||
if (pgm_read_byte(p+n) == 0) break;
|
||||
n += 1;
|
||||
if (pgm_read_byte(p+n) == 0)
|
||||
break;
|
||||
n += 1;
|
||||
}
|
||||
|
||||
char * message = (char*) malloc(n+1);
|
||||
if(message)
|
||||
if (message)
|
||||
{
|
||||
memcpy_P(message, p, n);
|
||||
message[n] = 0;
|
||||
@@ -929,7 +945,7 @@ void AsyncWebSocket::textAll(const __FlashStringHelper *data)
|
||||
|
||||
void AsyncWebSocket::binary(uint32_t id, const uint8_t *message, size_t len)
|
||||
{
|
||||
if (AsyncWebSocketClient * c = client(id))
|
||||
if (AsyncWebSocketClient *c = client(id))
|
||||
c->binary(makeBuffer(message, len));
|
||||
}
|
||||
void AsyncWebSocket::binary(uint32_t id, const char * message, size_t len)
|
||||
@@ -948,7 +964,8 @@ void AsyncWebSocket::binary(uint32_t id, const __FlashStringHelper *data, size_t
|
||||
{
|
||||
PGM_P p = reinterpret_cast<PGM_P>(data);
|
||||
char *message = (char*) malloc(len);
|
||||
if (message) {
|
||||
if (message)
|
||||
{
|
||||
memcpy_P(message, p, len);
|
||||
binary(id, message, len);
|
||||
free(message);
|
||||
@@ -983,7 +1000,8 @@ void AsyncWebSocket::binaryAll(const __FlashStringHelper *data, size_t len)
|
||||
{
|
||||
PGM_P p = reinterpret_cast<PGM_P>(data);
|
||||
char * message = (char*) malloc(len);
|
||||
if(message) {
|
||||
if(message)
|
||||
{
|
||||
memcpy_P(message, p, len);
|
||||
binaryAll(message, len);
|
||||
free(message);
|
||||
@@ -991,24 +1009,25 @@ void AsyncWebSocket::binaryAll(const __FlashStringHelper *data, size_t len)
|
||||
}
|
||||
|
||||
size_t AsyncWebSocket::printf(uint32_t id, const char *format, ...){
|
||||
AsyncWebSocketClient * c = client(id);
|
||||
if(c){
|
||||
va_list arg;
|
||||
va_start(arg, format);
|
||||
size_t len = c->printf(format, arg);
|
||||
va_end(arg);
|
||||
return len;
|
||||
}
|
||||
return 0;
|
||||
AsyncWebSocketClient * c = client(id);
|
||||
if (c)
|
||||
{
|
||||
va_list arg;
|
||||
va_start(arg, format);
|
||||
size_t len = c->printf(format, arg);
|
||||
va_end(arg);
|
||||
return len;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
size_t AsyncWebSocket::printfAll(const char *format, ...)
|
||||
{
|
||||
va_list arg;
|
||||
char* temp = new char[MAX_PRINTF_LEN];
|
||||
if (!temp) {
|
||||
char *temp = new char[MAX_PRINTF_LEN];
|
||||
if (!temp)
|
||||
return 0;
|
||||
}
|
||||
|
||||
va_start(arg, format);
|
||||
size_t len = vsnprintf(temp, MAX_PRINTF_LEN, format, arg);
|
||||
va_end(arg);
|
||||
@@ -1041,10 +1060,10 @@ size_t AsyncWebSocket::printf_P(uint32_t id, PGM_P formatP, ...){
|
||||
size_t AsyncWebSocket::printfAll_P(PGM_P formatP, ...)
|
||||
{
|
||||
va_list arg;
|
||||
char* temp = new char[MAX_PRINTF_LEN];
|
||||
if (!temp) {
|
||||
char *temp = new char[MAX_PRINTF_LEN];
|
||||
if (!temp)
|
||||
return 0;
|
||||
}
|
||||
|
||||
va_start(arg, formatP);
|
||||
size_t len = vsnprintf_P(temp, MAX_PRINTF_LEN, formatP, arg);
|
||||
va_end(arg);
|
||||
@@ -1099,11 +1118,13 @@ bool AsyncWebSocket::canHandle(AsyncWebServerRequest *request){
|
||||
|
||||
void AsyncWebSocket::handleRequest(AsyncWebServerRequest *request)
|
||||
{
|
||||
if (!request->hasHeader(WS_STR_VERSION) || !request->hasHeader(WS_STR_KEY)){
|
||||
if (!request->hasHeader(WS_STR_VERSION) || !request->hasHeader(WS_STR_KEY))
|
||||
{
|
||||
request->send(400);
|
||||
return;
|
||||
}
|
||||
if ((_username.length() && _password.length()) && !request->authenticate(_username.c_str(), _password.c_str())){
|
||||
if ((_username.length() && _password.length()) && !request->authenticate(_username.c_str(), _password.c_str()))
|
||||
{
|
||||
return request->requestAuthentication();
|
||||
}
|
||||
if (_handshakeHandler != nullptr){
|
||||
@@ -1113,7 +1134,8 @@ void AsyncWebSocket::handleRequest(AsyncWebServerRequest *request)
|
||||
}
|
||||
}
|
||||
AsyncWebHeader* version = request->getHeader(WS_STR_VERSION);
|
||||
if (version->value().toInt() != 13){
|
||||
if (version->value().toInt() != 13)
|
||||
{
|
||||
AsyncWebServerResponse *response = request->beginResponse(400);
|
||||
response->addHeader(WS_STR_VERSION, F("13"));
|
||||
request->send(response);
|
||||
@@ -1121,7 +1143,8 @@ void AsyncWebSocket::handleRequest(AsyncWebServerRequest *request)
|
||||
}
|
||||
AsyncWebHeader* key = request->getHeader(WS_STR_KEY);
|
||||
AsyncWebServerResponse *response = new AsyncWebSocketResponse(key->value(), this);
|
||||
if (request->hasHeader(WS_STR_PROTOCOL)){
|
||||
if (request->hasHeader(WS_STR_PROTOCOL))
|
||||
{
|
||||
AsyncWebHeader* protocol = request->getHeader(WS_STR_PROTOCOL);
|
||||
//ToDo: check protocol
|
||||
response->addHeader(WS_STR_PROTOCOL, protocol->value());
|
||||
@@ -1134,56 +1157,63 @@ void AsyncWebSocket::handleRequest(AsyncWebServerRequest *request)
|
||||
* Authentication code from https://github.com/Links2004/arduinoWebSockets/blob/master/src/WebSockets.cpp#L480
|
||||
*/
|
||||
|
||||
AsyncWebSocketResponse::AsyncWebSocketResponse(const String& key, AsyncWebSocket *server){
|
||||
_server = server;
|
||||
_code = 101;
|
||||
_sendContentLength = false;
|
||||
AsyncWebSocketResponse::AsyncWebSocketResponse(const String& key, AsyncWebSocket *server)
|
||||
{
|
||||
_server = server;
|
||||
_code = 101;
|
||||
_sendContentLength = false;
|
||||
|
||||
uint8_t * hash = (uint8_t*)malloc(20);
|
||||
if(hash == NULL){
|
||||
_state = RESPONSE_FAILED;
|
||||
return;
|
||||
}
|
||||
char * buffer = (char *) malloc(33);
|
||||
if(buffer == NULL){
|
||||
free(hash);
|
||||
_state = RESPONSE_FAILED;
|
||||
return;
|
||||
}
|
||||
uint8_t * hash = (uint8_t*)malloc(20);
|
||||
if(hash == NULL)
|
||||
{
|
||||
_state = RESPONSE_FAILED;
|
||||
return;
|
||||
}
|
||||
char * buffer = (char *) malloc(33);
|
||||
if(buffer == NULL)
|
||||
{
|
||||
free(hash);
|
||||
_state = RESPONSE_FAILED;
|
||||
return;
|
||||
}
|
||||
#ifdef ESP8266
|
||||
sha1(key + WS_STR_UUID, hash);
|
||||
sha1(key + WS_STR_UUID, hash);
|
||||
#else
|
||||
(String&)key += WS_STR_UUID;
|
||||
SHA1_CTX ctx;
|
||||
SHA1Init(&ctx);
|
||||
SHA1Update(&ctx, (const unsigned char*)key.c_str(), key.length());
|
||||
SHA1Final(hash, &ctx);
|
||||
(String&)key += WS_STR_UUID;
|
||||
SHA1_CTX ctx;
|
||||
SHA1Init(&ctx);
|
||||
SHA1Update(&ctx, (const unsigned char*)key.c_str(), key.length());
|
||||
SHA1Final(hash, &ctx);
|
||||
#endif
|
||||
base64_encodestate _state;
|
||||
base64_init_encodestate(&_state);
|
||||
int len = base64_encode_block((const char *) hash, 20, buffer, &_state);
|
||||
len = base64_encode_blockend((buffer + len), &_state);
|
||||
addHeader(WS_STR_CONNECTION, WS_STR_UPGRADE);
|
||||
addHeader(WS_STR_UPGRADE, F("websocket"));
|
||||
addHeader(WS_STR_ACCEPT,buffer);
|
||||
free(buffer);
|
||||
free(hash);
|
||||
base64_encodestate _state;
|
||||
base64_init_encodestate(&_state);
|
||||
int len = base64_encode_block((const char *) hash, 20, buffer, &_state);
|
||||
len = base64_encode_blockend((buffer + len), &_state);
|
||||
addHeader(WS_STR_CONNECTION, WS_STR_UPGRADE);
|
||||
addHeader(WS_STR_UPGRADE, F("websocket"));
|
||||
addHeader(WS_STR_ACCEPT,buffer);
|
||||
free(buffer);
|
||||
free(hash);
|
||||
}
|
||||
|
||||
void AsyncWebSocketResponse::_respond(AsyncWebServerRequest *request){
|
||||
if(_state == RESPONSE_FAILED){
|
||||
request->client()->close(true);
|
||||
return;
|
||||
}
|
||||
String out = _assembleHead(request->version());
|
||||
request->client()->write(out.c_str(), _headLength);
|
||||
_state = RESPONSE_WAIT_ACK;
|
||||
void AsyncWebSocketResponse::_respond(AsyncWebServerRequest *request)
|
||||
{
|
||||
if(_state == RESPONSE_FAILED)
|
||||
{
|
||||
request->client()->close(true);
|
||||
return;
|
||||
}
|
||||
String out = _assembleHead(request->version());
|
||||
request->client()->write(out.c_str(), _headLength);
|
||||
_state = RESPONSE_WAIT_ACK;
|
||||
}
|
||||
|
||||
size_t AsyncWebSocketResponse::_ack(AsyncWebServerRequest *request, size_t len, uint32_t time){
|
||||
(void)time;
|
||||
if(len){
|
||||
_server->_newClient(request);
|
||||
}
|
||||
return 0;
|
||||
size_t AsyncWebSocketResponse::_ack(AsyncWebServerRequest *request, size_t len, uint32_t time)
|
||||
{
|
||||
(void)time;
|
||||
|
||||
if(len)
|
||||
_server->_newClient(request);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
@@ -24,7 +24,7 @@
|
||||
#include <Arduino.h>
|
||||
#ifdef ESP32
|
||||
#include <AsyncTCP.h>
|
||||
#define WS_MAX_QUEUED_MESSAGES 32
|
||||
#define WS_MAX_QUEUED_MESSAGES 16
|
||||
#else
|
||||
#include <ESPAsyncTCP.h>
|
||||
#define WS_MAX_QUEUED_MESSAGES 8
|
||||
@@ -146,6 +146,8 @@ class AsyncWebSocketClient {
|
||||
IPAddress remoteIP() const;
|
||||
uint16_t remotePort() const;
|
||||
|
||||
bool shouldBeDeleted() const { return !_client; }
|
||||
|
||||
//control frames
|
||||
void close(uint16_t code=0, const char * message=NULL);
|
||||
void ping(const uint8_t *data=NULL, size_t len=0);
|
||||
@@ -273,7 +275,6 @@ class AsyncWebSocket: public AsyncWebHandler {
|
||||
//system callbacks (do not call)
|
||||
uint32_t _getNextId(){ return _cNextId++; }
|
||||
AsyncWebSocketClient *_newClient(AsyncWebServerRequest *request);
|
||||
void _handleDisconnect(AsyncWebSocketClient * client);
|
||||
void _handleEvent(AsyncWebSocketClient * client, AwsEventType type, void * arg, uint8_t *data, size_t len);
|
||||
virtual bool canHandle(AsyncWebServerRequest *request) override final;
|
||||
virtual void handleRequest(AsyncWebServerRequest *request) override final;
|
||||
|
Reference in New Issue
Block a user