More nullptr derefence crash fixes

This commit is contained in:
2021-01-02 17:55:30 +01:00
parent 6ac642b019
commit c9e4d424d5
2 changed files with 153 additions and 122 deletions

View File

@ -311,7 +311,7 @@ void AsyncWebSocketClient::_onAck(size_t len, uint32_t time)
_controlQueue.pop(); _controlQueue.pop();
_status = WS_DISCONNECTED; _status = WS_DISCONNECTED;
l.unlock(); l.unlock();
_client->close(true); if (_client) _client->close(true);
return; return;
} }
_controlQueue.pop(); _controlQueue.pop();
@ -329,13 +329,16 @@ void AsyncWebSocketClient::_onPoll()
{ {
if (asyncWebSocketDebug) Serial.printf("AsyncWebSocketClient::_onPoll this=0x%llx task=0x%llx %s\r\n", uint64_t(this), uint64_t(xTaskGetCurrentTaskHandle()), pcTaskGetTaskName(xTaskGetCurrentTaskHandle())); if (asyncWebSocketDebug) Serial.printf("AsyncWebSocketClient::_onPoll this=0x%llx task=0x%llx %s\r\n", uint64_t(this), uint64_t(xTaskGetCurrentTaskHandle()), pcTaskGetTaskName(xTaskGetCurrentTaskHandle()));
if (!_client)
return;
AsyncWebLockGuard l(_lock, "AsyncWebSocketClient::_onPoll"); AsyncWebLockGuard l(_lock, "AsyncWebSocketClient::_onPoll");
if(_client->canSend() && (!_controlQueue.empty() || !_messageQueue.empty())) if (_client->canSend() && (!_controlQueue.empty() || !_messageQueue.empty()))
{ {
l.unlock(); l.unlock();
_runQueue(); _runQueue();
} }
else if(_keepAlivePeriod > 0 && (millis() - _lastMessageTime) >= _keepAlivePeriod && (_controlQueue.empty() && _messageQueue.empty())) else if (_keepAlivePeriod > 0 && (millis() - _lastMessageTime) >= _keepAlivePeriod && (_controlQueue.empty() && _messageQueue.empty()))
{ {
l.unlock(); l.unlock();
ping((uint8_t *)AWSC_PING_PAYLOAD, AWSC_PING_PAYLOAD_LEN); ping((uint8_t *)AWSC_PING_PAYLOAD, AWSC_PING_PAYLOAD_LEN);
@ -346,6 +349,9 @@ void AsyncWebSocketClient::_runQueue()
{ {
if (asyncWebSocketDebug) Serial.printf("AsyncWebSocketClient::_runQueue this=0x%llx task=0x%llx %s\r\n", uint64_t(this), uint64_t(xTaskGetCurrentTaskHandle()), pcTaskGetTaskName(xTaskGetCurrentTaskHandle())); if (asyncWebSocketDebug) Serial.printf("AsyncWebSocketClient::_runQueue this=0x%llx task=0x%llx %s\r\n", uint64_t(this), uint64_t(xTaskGetCurrentTaskHandle()), pcTaskGetTaskName(xTaskGetCurrentTaskHandle()));
if (!_client)
return;
AsyncWebLockGuard l(_lock, "AsyncWebSocketClient::_runQueue()"); AsyncWebLockGuard l(_lock, "AsyncWebSocketClient::_runQueue()");
while (!_messageQueue.empty() && _messageQueue.front().finished()) while (!_messageQueue.empty() && _messageQueue.front().finished())
@ -391,12 +397,15 @@ void AsyncWebSocketClient::_queueControl(uint8_t opcode, const uint8_t *data, si
{ {
if (asyncWebSocketDebug) Serial.printf("AsyncWebSocketClient::_queueControl this=0x%llx task=0x%llx %s\r\n", uint64_t(this), uint64_t(xTaskGetCurrentTaskHandle()), pcTaskGetTaskName(xTaskGetCurrentTaskHandle())); if (asyncWebSocketDebug) Serial.printf("AsyncWebSocketClient::_queueControl this=0x%llx task=0x%llx %s\r\n", uint64_t(this), uint64_t(xTaskGetCurrentTaskHandle()), pcTaskGetTaskName(xTaskGetCurrentTaskHandle()));
if (!_client)
return;
{ {
AsyncWebLockGuard l(_lock, "AsyncWebSocketClient::_queueControl"); AsyncWebLockGuard l(_lock, "AsyncWebSocketClient::_queueControl");
_controlQueue.emplace(opcode, data, len, mask); _controlQueue.emplace(opcode, data, len, mask);
} }
if (_client->canSend()) if (_client && _client->canSend())
_runQueue(); _runQueue();
} }
@ -407,12 +416,18 @@ void AsyncWebSocketClient::_queueMessage(std::shared_ptr<std::vector<uint8_t>> b
if(_status != WS_CONNECTED) if(_status != WS_CONNECTED)
return; return;
if (!_client)
return;
{ {
AsyncWebLockGuard l(_lock, "AsyncWebSocketClient::_queueMessage"); AsyncWebLockGuard l(_lock, "AsyncWebSocketClient::_queueMessage");
if (_messageQueue.size() >= WS_MAX_QUEUED_MESSAGES) if (_messageQueue.size() >= WS_MAX_QUEUED_MESSAGES)
{ {
l.unlock(); l.unlock();
ets_printf("ERROR: Too many messages queued\n"); ets_printf("AsyncWebSocketClient::_queueMessage: Too many messages queued, closing connection\n");
_status = WS_DISCONNECTED;
if (_client) _client->close(true);
return;
} }
else else
{ {
@ -420,7 +435,7 @@ void AsyncWebSocketClient::_queueMessage(std::shared_ptr<std::vector<uint8_t>> b
} }
} }
if(_client->canSend()) if (_client && _client->canSend())
_runQueue(); _runQueue();
} }
@ -460,18 +475,22 @@ void AsyncWebSocketClient::ping(const uint8_t *data, size_t len)
_queueControl(WS_PING, data, len); _queueControl(WS_PING, data, len);
} }
void AsyncWebSocketClient::_onError(int8_t){} void AsyncWebSocketClient::_onError(int8_t err)
{
Serial.printf("AsyncWebSocketClient::_onError() %i %i\r\n", _clientId, err);
}
void AsyncWebSocketClient::_onTimeout(uint32_t time) void AsyncWebSocketClient::_onTimeout(uint32_t time)
{ {
Serial.printf("AsyncWebSocketClient::_onTimeout() %i\r\n", _clientId);
(void)time; (void)time;
_client->close(true); _client->close(true);
} }
void AsyncWebSocketClient::_onDisconnect() void AsyncWebSocketClient::_onDisconnect()
{ {
Serial.printf("AsyncWebSocketClient::_onDisconnect() %i\r\n", _clientId);
_client = NULL; _client = NULL;
_server->_handleDisconnect(this);
} }
void AsyncWebSocketClient::_onData(void *pbuf, size_t plen) void AsyncWebSocketClient::_onData(void *pbuf, size_t plen)
@ -725,17 +744,17 @@ void AsyncWebSocketClient::binary(const __FlashStringHelper *data, size_t len)
IPAddress AsyncWebSocketClient::remoteIP() const IPAddress AsyncWebSocketClient::remoteIP() const
{ {
if (!_client) { if (!_client)
return IPAddress(0U); return IPAddress(0U);
}
return _client->remoteIP(); return _client->remoteIP();
} }
uint16_t AsyncWebSocketClient::remotePort() const uint16_t AsyncWebSocketClient::remotePort() const
{ {
if(!_client) { if(!_client)
return 0; return 0;
}
return _client->remotePort(); return _client->remotePort();
} }
@ -768,36 +787,30 @@ AsyncWebSocketClient *AsyncWebSocket::_newClient(AsyncWebServerRequest *request)
return &_clients.back(); return &_clients.back();
} }
//void AsyncWebSocket::_addClient(AsyncWebSocketClient * client){ bool AsyncWebSocket::availableForWriteAll()
// _clients.add(client); {
//}
void AsyncWebSocket::_handleDisconnect(AsyncWebSocketClient * client){
const auto iter = std::find_if(std::begin(_clients), std::end(_clients),
[id=client->id()](const auto &c){ return c.id() == id; });
if (iter != std::end(_clients))
_clients.erase(iter);
}
bool AsyncWebSocket::availableForWriteAll(){
return std::none_of(std::begin(_clients), std::end(_clients), return std::none_of(std::begin(_clients), std::end(_clients),
[](const auto &c){ return c.queueIsFull(); }); [](const auto &c){ return c.queueIsFull(); });
} }
bool AsyncWebSocket::availableForWrite(uint32_t id){ bool AsyncWebSocket::availableForWrite(uint32_t id)
{
const auto iter = std::find_if(std::begin(_clients), std::end(_clients), const auto iter = std::find_if(std::begin(_clients), std::end(_clients),
[id](const auto &c){ return c.id() == id; }); [id](const auto &c){ return c.id() == id; });
if (iter == std::end(_clients)) if (iter == std::end(_clients))
return true; // don't know why me-no-dev decided like this? return true;
return !iter->queueIsFull(); return !iter->queueIsFull();
} }
size_t AsyncWebSocket::count() const { size_t AsyncWebSocket::count() const
{
return std::count_if(std::begin(_clients), std::end(_clients), return std::count_if(std::begin(_clients), std::end(_clients),
[](const auto &c){ return c.status() == WS_CONNECTED; }); [](const auto &c){ return c.status() == WS_CONNECTED; });
} }
AsyncWebSocketClient * AsyncWebSocket::client(uint32_t id){ AsyncWebSocketClient * AsyncWebSocket::client(uint32_t id)
{
const auto iter = std::find_if(std::begin(_clients), std::end(_clients), const auto iter = std::find_if(std::begin(_clients), std::end(_clients),
[id](const auto &c){ return c.id() == id && c.status() == WS_CONNECTED; }); [id](const auto &c){ return c.id() == id && c.status() == WS_CONNECTED; });
if (iter == std::end(_clients)) if (iter == std::end(_clients))
@ -816,16 +829,21 @@ void AsyncWebSocket::close(uint32_t id, uint16_t code, const char * message)
void AsyncWebSocket::closeAll(uint16_t code, const char * message) void AsyncWebSocket::closeAll(uint16_t code, const char * message)
{ {
for (auto &c : _clients) for (auto &c : _clients)
{
if (c.status() == WS_CONNECTED) if (c.status() == WS_CONNECTED)
c.close(code, message); c.close(code, message);
}
} }
void AsyncWebSocket::cleanupClients(uint16_t maxClients) void AsyncWebSocket::cleanupClients(uint16_t maxClients)
{ {
if (count() > maxClients) { if (count() > maxClients)
_clients.front().close(); _clients.front().close();
for (auto iter = std::begin(_clients); iter != std::end(_clients);)
{
if (iter->shouldBeDeleted())
iter = _clients.erase(iter);
else
iter++;
} }
} }
@ -837,10 +855,9 @@ void AsyncWebSocket::ping(uint32_t id, const uint8_t *data, size_t len)
void AsyncWebSocket::pingAll(const uint8_t *data, size_t len) void AsyncWebSocket::pingAll(const uint8_t *data, size_t len)
{ {
for (auto &c : _clients) { for (auto &c : _clients)
if (c.status() == WS_CONNECTED) if (c.status() == WS_CONNECTED)
c.ping(data, len); c.ping(data, len);
}
} }
void AsyncWebSocket::text(uint32_t id, const uint8_t *message, size_t len) void AsyncWebSocket::text(uint32_t id, const uint8_t *message, size_t len)
@ -909,16 +926,17 @@ void AsyncWebSocket::textAll(const __FlashStringHelper *data)
PGM_P p = reinterpret_cast<PGM_P>(data); PGM_P p = reinterpret_cast<PGM_P>(data);
size_t n = 0; size_t n = 0;
while (1) while (true)
{ {
if (pgm_read_byte(p+n) == 0) break; if (pgm_read_byte(p+n) == 0)
break;
n += 1; n += 1;
} }
char *message = (char*)malloc(n+1); char *message = (char*)malloc(n+1);
if(message) if (message)
{ {
for(size_t b=0; b<n; b++) for (size_t b=0; b<n; b++)
message[b] = pgm_read_byte(p++); message[b] = pgm_read_byte(p++);
message[n] = 0; message[n] = 0;
textAll(message, n); textAll(message, n);
@ -928,7 +946,7 @@ void AsyncWebSocket::textAll(const __FlashStringHelper *data)
void AsyncWebSocket::binary(uint32_t id, const uint8_t *message, size_t len) void AsyncWebSocket::binary(uint32_t id, const uint8_t *message, size_t len)
{ {
if (AsyncWebSocketClient * c = client(id)) if (AsyncWebSocketClient *c = client(id))
c->binary(makeBuffer(message, len)); c->binary(makeBuffer(message, len));
} }
void AsyncWebSocket::binary(uint32_t id, const char * message, size_t len) void AsyncWebSocket::binary(uint32_t id, const char * message, size_t len)
@ -947,7 +965,8 @@ void AsyncWebSocket::binary(uint32_t id, const __FlashStringHelper *data, size_t
{ {
PGM_P p = reinterpret_cast<PGM_P>(data); PGM_P p = reinterpret_cast<PGM_P>(data);
char *message = (char*) malloc(len); char *message = (char*) malloc(len);
if (message) { if (message)
{
for (size_t b=0; b<len; b++) for (size_t b=0; b<len; b++)
message[b] = pgm_read_byte(p++); message[b] = pgm_read_byte(p++);
binary(id, message, len); binary(id, message, len);
@ -983,8 +1002,9 @@ void AsyncWebSocket::binaryAll(const __FlashStringHelper *data, size_t len)
{ {
PGM_P p = reinterpret_cast<PGM_P>(data); PGM_P p = reinterpret_cast<PGM_P>(data);
char * message = (char*) malloc(len); char * message = (char*) malloc(len);
if(message) { if(message)
for(size_t b=0; b<len; b++) {
for (size_t b=0; b<len; b++)
message[b] = pgm_read_byte(p++); message[b] = pgm_read_byte(p++);
binaryAll(message, len); binaryAll(message, len);
free(message); free(message);
@ -992,8 +1012,9 @@ void AsyncWebSocket::binaryAll(const __FlashStringHelper *data, size_t len)
} }
size_t AsyncWebSocket::printf(uint32_t id, const char *format, ...){ size_t AsyncWebSocket::printf(uint32_t id, const char *format, ...){
AsyncWebSocketClient * c = client(id); AsyncWebSocketClient *c = client(id);
if(c){ if (c)
{
va_list arg; va_list arg;
va_start(arg, format); va_start(arg, format);
size_t len = c->printf(format, arg); size_t len = c->printf(format, arg);
@ -1006,10 +1027,10 @@ size_t AsyncWebSocket::printf(uint32_t id, const char *format, ...){
size_t AsyncWebSocket::printfAll(const char *format, ...) size_t AsyncWebSocket::printfAll(const char *format, ...)
{ {
va_list arg; va_list arg;
char* temp = new char[MAX_PRINTF_LEN]; char *temp = new char[MAX_PRINTF_LEN];
if (!temp) { if (!temp)
return 0; return 0;
}
va_start(arg, format); va_start(arg, format);
size_t len = vsnprintf(temp, MAX_PRINTF_LEN, format, arg); size_t len = vsnprintf(temp, MAX_PRINTF_LEN, format, arg);
va_end(arg); va_end(arg);
@ -1042,10 +1063,10 @@ size_t AsyncWebSocket::printf_P(uint32_t id, PGM_P formatP, ...){
size_t AsyncWebSocket::printfAll_P(PGM_P formatP, ...) size_t AsyncWebSocket::printfAll_P(PGM_P formatP, ...)
{ {
va_list arg; va_list arg;
char* temp = new char[MAX_PRINTF_LEN]; char *temp = new char[MAX_PRINTF_LEN];
if (!temp) { if (!temp)
return 0; return 0;
}
va_start(arg, formatP); va_start(arg, formatP);
size_t len = vsnprintf_P(temp, MAX_PRINTF_LEN, formatP, arg); size_t len = vsnprintf_P(temp, MAX_PRINTF_LEN, formatP, arg);
va_end(arg); va_end(arg);
@ -1089,15 +1110,18 @@ bool AsyncWebSocket::canHandle(AsyncWebServerRequest *request)
void AsyncWebSocket::handleRequest(AsyncWebServerRequest *request) void AsyncWebSocket::handleRequest(AsyncWebServerRequest *request)
{ {
if (!request->hasHeader(WS_STR_VERSION) || !request->hasHeader(WS_STR_KEY)){ if (!request->hasHeader(WS_STR_VERSION) || !request->hasHeader(WS_STR_KEY))
{
request->send(400); request->send(400);
return; return;
} }
if ((_username != "" && _password != "") && !request->authenticate(_username.c_str(), _password.c_str())){ if ((_username != "" && _password != "") && !request->authenticate(_username.c_str(), _password.c_str()))
{
return request->requestAuthentication(); return request->requestAuthentication();
} }
AsyncWebHeader* version = request->getHeader(WS_STR_VERSION); AsyncWebHeader* version = request->getHeader(WS_STR_VERSION);
if (version->value().toInt() != 13){ if (version->value().toInt() != 13)
{
AsyncWebServerResponse *response = request->beginResponse(400); AsyncWebServerResponse *response = request->beginResponse(400);
response->addHeader(WS_STR_VERSION,"13"); response->addHeader(WS_STR_VERSION,"13");
request->send(response); request->send(response);
@ -1105,7 +1129,8 @@ void AsyncWebSocket::handleRequest(AsyncWebServerRequest *request)
} }
AsyncWebHeader* key = request->getHeader(WS_STR_KEY); AsyncWebHeader* key = request->getHeader(WS_STR_KEY);
AsyncWebServerResponse *response = new AsyncWebSocketResponse(key->value(), this); AsyncWebServerResponse *response = new AsyncWebSocketResponse(key->value(), this);
if (request->hasHeader(WS_STR_PROTOCOL)){ if (request->hasHeader(WS_STR_PROTOCOL))
{
AsyncWebHeader* protocol = request->getHeader(WS_STR_PROTOCOL); AsyncWebHeader* protocol = request->getHeader(WS_STR_PROTOCOL);
//ToDo: check protocol //ToDo: check protocol
response->addHeader(WS_STR_PROTOCOL, protocol->value()); response->addHeader(WS_STR_PROTOCOL, protocol->value());
@ -1118,18 +1143,21 @@ void AsyncWebSocket::handleRequest(AsyncWebServerRequest *request)
* Authentication code from https://github.com/Links2004/arduinoWebSockets/blob/master/src/WebSockets.cpp#L480 * Authentication code from https://github.com/Links2004/arduinoWebSockets/blob/master/src/WebSockets.cpp#L480
*/ */
AsyncWebSocketResponse::AsyncWebSocketResponse(const String& key, AsyncWebSocket *server){ AsyncWebSocketResponse::AsyncWebSocketResponse(const String& key, AsyncWebSocket *server)
{
_server = server; _server = server;
_code = 101; _code = 101;
_sendContentLength = false; _sendContentLength = false;
uint8_t * hash = (uint8_t*)malloc(20); uint8_t * hash = (uint8_t*)malloc(20);
if(hash == NULL){ if(hash == NULL)
{
_state = RESPONSE_FAILED; _state = RESPONSE_FAILED;
return; return;
} }
char * buffer = (char *) malloc(33); char * buffer = (char *) malloc(33);
if(buffer == NULL){ if(buffer == NULL)
{
free(hash); free(hash);
_state = RESPONSE_FAILED; _state = RESPONSE_FAILED;
return; return;
@ -1154,8 +1182,10 @@ AsyncWebSocketResponse::AsyncWebSocketResponse(const String& key, AsyncWebSocket
free(hash); free(hash);
} }
void AsyncWebSocketResponse::_respond(AsyncWebServerRequest *request){ void AsyncWebSocketResponse::_respond(AsyncWebServerRequest *request)
if(_state == RESPONSE_FAILED){ {
if(_state == RESPONSE_FAILED)
{
request->client()->close(true); request->client()->close(true);
return; return;
} }
@ -1164,11 +1194,12 @@ void AsyncWebSocketResponse::_respond(AsyncWebServerRequest *request){
_state = RESPONSE_WAIT_ACK; _state = RESPONSE_WAIT_ACK;
} }
size_t AsyncWebSocketResponse::_ack(AsyncWebServerRequest *request, size_t len, uint32_t time){ size_t AsyncWebSocketResponse::_ack(AsyncWebServerRequest *request, size_t len, uint32_t time)
{
(void)time; (void)time;
if(len){
//new AsyncWebSocketClient(request, _server); if(len)
_server->_newClient(request); _server->_newClient(request);
}
return 0; return 0;
} }

View File

@ -24,7 +24,7 @@
#include <Arduino.h> #include <Arduino.h>
#ifdef ESP32 #ifdef ESP32
#include <AsyncTCP.h> #include <AsyncTCP.h>
#define WS_MAX_QUEUED_MESSAGES 32 #define WS_MAX_QUEUED_MESSAGES 16
#else #else
#include <ESPAsyncTCP.h> #include <ESPAsyncTCP.h>
#define WS_MAX_QUEUED_MESSAGES 8 #define WS_MAX_QUEUED_MESSAGES 8
@ -145,6 +145,8 @@ class AsyncWebSocketClient {
IPAddress remoteIP() const; IPAddress remoteIP() const;
uint16_t remotePort() const; uint16_t remotePort() const;
bool shouldBeDeleted() const { return !_client; }
//control frames //control frames
void close(uint16_t code=0, const char * message=NULL); void close(uint16_t code=0, const char * message=NULL);
void ping(const uint8_t *data=NULL, size_t len=0); void ping(const uint8_t *data=NULL, size_t len=0);
@ -264,8 +266,6 @@ class AsyncWebSocket: public AsyncWebHandler {
//system callbacks (do not call) //system callbacks (do not call)
uint32_t _getNextId(){ return _cNextId++; } uint32_t _getNextId(){ return _cNextId++; }
AsyncWebSocketClient *_newClient(AsyncWebServerRequest *request); AsyncWebSocketClient *_newClient(AsyncWebServerRequest *request);
//void _addClient(AsyncWebSocketClient * client);
void _handleDisconnect(AsyncWebSocketClient * client);
void _handleEvent(AsyncWebSocketClient * client, AwsEventType type, void * arg, uint8_t *data, size_t len); void _handleEvent(AsyncWebSocketClient * client, AwsEventType type, void * arg, uint8_t *data, size_t len);
virtual bool canHandle(AsyncWebServerRequest *request) override final; virtual bool canHandle(AsyncWebServerRequest *request) override final;
virtual void handleRequest(AsyncWebServerRequest *request) override final; virtual void handleRequest(AsyncWebServerRequest *request) override final;