More nullptr derefence crash fixes

This commit is contained in:
2021-01-02 17:55:30 +01:00
parent 6ac642b019
commit c9e4d424d5
2 changed files with 153 additions and 122 deletions

View File

@ -311,7 +311,7 @@ void AsyncWebSocketClient::_onAck(size_t len, uint32_t time)
_controlQueue.pop();
_status = WS_DISCONNECTED;
l.unlock();
_client->close(true);
if (_client) _client->close(true);
return;
}
_controlQueue.pop();
@ -329,13 +329,16 @@ void AsyncWebSocketClient::_onPoll()
{
if (asyncWebSocketDebug) Serial.printf("AsyncWebSocketClient::_onPoll this=0x%llx task=0x%llx %s\r\n", uint64_t(this), uint64_t(xTaskGetCurrentTaskHandle()), pcTaskGetTaskName(xTaskGetCurrentTaskHandle()));
if (!_client)
return;
AsyncWebLockGuard l(_lock, "AsyncWebSocketClient::_onPoll");
if(_client->canSend() && (!_controlQueue.empty() || !_messageQueue.empty()))
if (_client->canSend() && (!_controlQueue.empty() || !_messageQueue.empty()))
{
l.unlock();
_runQueue();
}
else if(_keepAlivePeriod > 0 && (millis() - _lastMessageTime) >= _keepAlivePeriod && (_controlQueue.empty() && _messageQueue.empty()))
else if (_keepAlivePeriod > 0 && (millis() - _lastMessageTime) >= _keepAlivePeriod && (_controlQueue.empty() && _messageQueue.empty()))
{
l.unlock();
ping((uint8_t *)AWSC_PING_PAYLOAD, AWSC_PING_PAYLOAD_LEN);
@ -346,6 +349,9 @@ void AsyncWebSocketClient::_runQueue()
{
if (asyncWebSocketDebug) Serial.printf("AsyncWebSocketClient::_runQueue this=0x%llx task=0x%llx %s\r\n", uint64_t(this), uint64_t(xTaskGetCurrentTaskHandle()), pcTaskGetTaskName(xTaskGetCurrentTaskHandle()));
if (!_client)
return;
AsyncWebLockGuard l(_lock, "AsyncWebSocketClient::_runQueue()");
while (!_messageQueue.empty() && _messageQueue.front().finished())
@ -391,12 +397,15 @@ void AsyncWebSocketClient::_queueControl(uint8_t opcode, const uint8_t *data, si
{
if (asyncWebSocketDebug) Serial.printf("AsyncWebSocketClient::_queueControl this=0x%llx task=0x%llx %s\r\n", uint64_t(this), uint64_t(xTaskGetCurrentTaskHandle()), pcTaskGetTaskName(xTaskGetCurrentTaskHandle()));
if (!_client)
return;
{
AsyncWebLockGuard l(_lock, "AsyncWebSocketClient::_queueControl");
_controlQueue.emplace(opcode, data, len, mask);
}
if (_client->canSend())
if (_client && _client->canSend())
_runQueue();
}
@ -407,12 +416,18 @@ void AsyncWebSocketClient::_queueMessage(std::shared_ptr<std::vector<uint8_t>> b
if(_status != WS_CONNECTED)
return;
if (!_client)
return;
{
AsyncWebLockGuard l(_lock, "AsyncWebSocketClient::_queueMessage");
if (_messageQueue.size() >= WS_MAX_QUEUED_MESSAGES)
{
l.unlock();
ets_printf("ERROR: Too many messages queued\n");
ets_printf("AsyncWebSocketClient::_queueMessage: Too many messages queued, closing connection\n");
_status = WS_DISCONNECTED;
if (_client) _client->close(true);
return;
}
else
{
@ -420,7 +435,7 @@ void AsyncWebSocketClient::_queueMessage(std::shared_ptr<std::vector<uint8_t>> b
}
}
if(_client->canSend())
if (_client && _client->canSend())
_runQueue();
}
@ -460,18 +475,22 @@ void AsyncWebSocketClient::ping(const uint8_t *data, size_t len)
_queueControl(WS_PING, data, len);
}
void AsyncWebSocketClient::_onError(int8_t){}
void AsyncWebSocketClient::_onError(int8_t err)
{
Serial.printf("AsyncWebSocketClient::_onError() %i %i\r\n", _clientId, err);
}
void AsyncWebSocketClient::_onTimeout(uint32_t time)
{
Serial.printf("AsyncWebSocketClient::_onTimeout() %i\r\n", _clientId);
(void)time;
_client->close(true);
}
void AsyncWebSocketClient::_onDisconnect()
{
Serial.printf("AsyncWebSocketClient::_onDisconnect() %i\r\n", _clientId);
_client = NULL;
_server->_handleDisconnect(this);
}
void AsyncWebSocketClient::_onData(void *pbuf, size_t plen)
@ -725,17 +744,17 @@ void AsyncWebSocketClient::binary(const __FlashStringHelper *data, size_t len)
IPAddress AsyncWebSocketClient::remoteIP() const
{
if (!_client) {
if (!_client)
return IPAddress(0U);
}
return _client->remoteIP();
}
uint16_t AsyncWebSocketClient::remotePort() const
{
if(!_client) {
if(!_client)
return 0;
}
return _client->remotePort();
}
@ -768,42 +787,36 @@ AsyncWebSocketClient *AsyncWebSocket::_newClient(AsyncWebServerRequest *request)
return &_clients.back();
}
//void AsyncWebSocket::_addClient(AsyncWebSocketClient * client){
// _clients.add(client);
//}
void AsyncWebSocket::_handleDisconnect(AsyncWebSocketClient * client){
const auto iter = std::find_if(std::begin(_clients), std::end(_clients),
[id=client->id()](const auto &c){ return c.id() == id; });
if (iter != std::end(_clients))
_clients.erase(iter);
bool AsyncWebSocket::availableForWriteAll()
{
return std::none_of(std::begin(_clients), std::end(_clients),
[](const auto &c){ return c.queueIsFull(); });
}
bool AsyncWebSocket::availableForWriteAll(){
return std::none_of(std::begin(_clients), std::end(_clients),
[](const auto &c){ return c.queueIsFull(); });
bool AsyncWebSocket::availableForWrite(uint32_t id)
{
const auto iter = std::find_if(std::begin(_clients), std::end(_clients),
[id](const auto &c){ return c.id() == id; });
if (iter == std::end(_clients))
return true;
return !iter->queueIsFull();
}
bool AsyncWebSocket::availableForWrite(uint32_t id){
const auto iter = std::find_if(std::begin(_clients), std::end(_clients),
[id](const auto &c){ return c.id() == id; });
if (iter == std::end(_clients))
return true; // don't know why me-no-dev decided like this?
return !iter->queueIsFull();
size_t AsyncWebSocket::count() const
{
return std::count_if(std::begin(_clients), std::end(_clients),
[](const auto &c){ return c.status() == WS_CONNECTED; });
}
size_t AsyncWebSocket::count() const {
return std::count_if(std::begin(_clients), std::end(_clients),
[](const auto &c){ return c.status() == WS_CONNECTED; });
}
AsyncWebSocketClient * AsyncWebSocket::client(uint32_t id)
{
const auto iter = std::find_if(std::begin(_clients), std::end(_clients),
[id](const auto &c){ return c.id() == id && c.status() == WS_CONNECTED; });
if (iter == std::end(_clients))
return nullptr;
AsyncWebSocketClient * AsyncWebSocket::client(uint32_t id){
const auto iter = std::find_if(std::begin(_clients), std::end(_clients),
[id](const auto &c){ return c.id() == id && c.status() == WS_CONNECTED; });
if (iter == std::end(_clients))
return nullptr;
return &(*iter);
return &(*iter);
}
@ -816,16 +829,21 @@ void AsyncWebSocket::close(uint32_t id, uint16_t code, const char * message)
void AsyncWebSocket::closeAll(uint16_t code, const char * message)
{
for (auto &c : _clients)
{
if (c.status() == WS_CONNECTED)
c.close(code, message);
}
}
void AsyncWebSocket::cleanupClients(uint16_t maxClients)
{
if (count() > maxClients) {
if (count() > maxClients)
_clients.front().close();
for (auto iter = std::begin(_clients); iter != std::end(_clients);)
{
if (iter->shouldBeDeleted())
iter = _clients.erase(iter);
else
iter++;
}
}
@ -837,10 +855,9 @@ void AsyncWebSocket::ping(uint32_t id, const uint8_t *data, size_t len)
void AsyncWebSocket::pingAll(const uint8_t *data, size_t len)
{
for (auto &c : _clients) {
for (auto &c : _clients)
if (c.status() == WS_CONNECTED)
c.ping(data, len);
}
}
void AsyncWebSocket::text(uint32_t id, const uint8_t *message, size_t len)
@ -909,16 +926,17 @@ void AsyncWebSocket::textAll(const __FlashStringHelper *data)
PGM_P p = reinterpret_cast<PGM_P>(data);
size_t n = 0;
while (1)
while (true)
{
if (pgm_read_byte(p+n) == 0) break;
n += 1;
if (pgm_read_byte(p+n) == 0)
break;
n += 1;
}
char *message = (char*)malloc(n+1);
if(message)
if (message)
{
for(size_t b=0; b<n; b++)
for (size_t b=0; b<n; b++)
message[b] = pgm_read_byte(p++);
message[n] = 0;
textAll(message, n);
@ -928,7 +946,7 @@ void AsyncWebSocket::textAll(const __FlashStringHelper *data)
void AsyncWebSocket::binary(uint32_t id, const uint8_t *message, size_t len)
{
if (AsyncWebSocketClient * c = client(id))
if (AsyncWebSocketClient *c = client(id))
c->binary(makeBuffer(message, len));
}
void AsyncWebSocket::binary(uint32_t id, const char * message, size_t len)
@ -947,7 +965,8 @@ void AsyncWebSocket::binary(uint32_t id, const __FlashStringHelper *data, size_t
{
PGM_P p = reinterpret_cast<PGM_P>(data);
char *message = (char*) malloc(len);
if (message) {
if (message)
{
for (size_t b=0; b<len; b++)
message[b] = pgm_read_byte(p++);
binary(id, message, len);
@ -983,8 +1002,9 @@ void AsyncWebSocket::binaryAll(const __FlashStringHelper *data, size_t len)
{
PGM_P p = reinterpret_cast<PGM_P>(data);
char * message = (char*) malloc(len);
if(message) {
for(size_t b=0; b<len; b++)
if(message)
{
for (size_t b=0; b<len; b++)
message[b] = pgm_read_byte(p++);
binaryAll(message, len);
free(message);
@ -992,24 +1012,25 @@ void AsyncWebSocket::binaryAll(const __FlashStringHelper *data, size_t len)
}
size_t AsyncWebSocket::printf(uint32_t id, const char *format, ...){
AsyncWebSocketClient * c = client(id);
if(c){
va_list arg;
va_start(arg, format);
size_t len = c->printf(format, arg);
va_end(arg);
return len;
}
return 0;
AsyncWebSocketClient *c = client(id);
if (c)
{
va_list arg;
va_start(arg, format);
size_t len = c->printf(format, arg);
va_end(arg);
return len;
}
return 0;
}
size_t AsyncWebSocket::printfAll(const char *format, ...)
{
va_list arg;
char* temp = new char[MAX_PRINTF_LEN];
if (!temp) {
char *temp = new char[MAX_PRINTF_LEN];
if (!temp)
return 0;
}
va_start(arg, format);
size_t len = vsnprintf(temp, MAX_PRINTF_LEN, format, arg);
va_end(arg);
@ -1042,10 +1063,10 @@ size_t AsyncWebSocket::printf_P(uint32_t id, PGM_P formatP, ...){
size_t AsyncWebSocket::printfAll_P(PGM_P formatP, ...)
{
va_list arg;
char* temp = new char[MAX_PRINTF_LEN];
if (!temp) {
char *temp = new char[MAX_PRINTF_LEN];
if (!temp)
return 0;
}
va_start(arg, formatP);
size_t len = vsnprintf_P(temp, MAX_PRINTF_LEN, formatP, arg);
va_end(arg);
@ -1089,15 +1110,18 @@ bool AsyncWebSocket::canHandle(AsyncWebServerRequest *request)
void AsyncWebSocket::handleRequest(AsyncWebServerRequest *request)
{
if (!request->hasHeader(WS_STR_VERSION) || !request->hasHeader(WS_STR_KEY)){
if (!request->hasHeader(WS_STR_VERSION) || !request->hasHeader(WS_STR_KEY))
{
request->send(400);
return;
}
if ((_username != "" && _password != "") && !request->authenticate(_username.c_str(), _password.c_str())){
if ((_username != "" && _password != "") && !request->authenticate(_username.c_str(), _password.c_str()))
{
return request->requestAuthentication();
}
AsyncWebHeader* version = request->getHeader(WS_STR_VERSION);
if (version->value().toInt() != 13){
if (version->value().toInt() != 13)
{
AsyncWebServerResponse *response = request->beginResponse(400);
response->addHeader(WS_STR_VERSION,"13");
request->send(response);
@ -1105,7 +1129,8 @@ void AsyncWebSocket::handleRequest(AsyncWebServerRequest *request)
}
AsyncWebHeader* key = request->getHeader(WS_STR_KEY);
AsyncWebServerResponse *response = new AsyncWebSocketResponse(key->value(), this);
if (request->hasHeader(WS_STR_PROTOCOL)){
if (request->hasHeader(WS_STR_PROTOCOL))
{
AsyncWebHeader* protocol = request->getHeader(WS_STR_PROTOCOL);
//ToDo: check protocol
response->addHeader(WS_STR_PROTOCOL, protocol->value());
@ -1118,57 +1143,63 @@ void AsyncWebSocket::handleRequest(AsyncWebServerRequest *request)
* Authentication code from https://github.com/Links2004/arduinoWebSockets/blob/master/src/WebSockets.cpp#L480
*/
AsyncWebSocketResponse::AsyncWebSocketResponse(const String& key, AsyncWebSocket *server){
_server = server;
_code = 101;
_sendContentLength = false;
AsyncWebSocketResponse::AsyncWebSocketResponse(const String& key, AsyncWebSocket *server)
{
_server = server;
_code = 101;
_sendContentLength = false;
uint8_t * hash = (uint8_t*)malloc(20);
if(hash == NULL){
_state = RESPONSE_FAILED;
return;
}
char * buffer = (char *) malloc(33);
if(buffer == NULL){
free(hash);
_state = RESPONSE_FAILED;
return;
}
uint8_t * hash = (uint8_t*)malloc(20);
if(hash == NULL)
{
_state = RESPONSE_FAILED;
return;
}
char * buffer = (char *) malloc(33);
if(buffer == NULL)
{
free(hash);
_state = RESPONSE_FAILED;
return;
}
#ifdef ESP8266
sha1(key + WS_STR_UUID, hash);
sha1(key + WS_STR_UUID, hash);
#else
(String&)key += WS_STR_UUID;
SHA1_CTX ctx;
SHA1Init(&ctx);
SHA1Update(&ctx, (const unsigned char*)key.c_str(), key.length());
SHA1Final(hash, &ctx);
(String&)key += WS_STR_UUID;
SHA1_CTX ctx;
SHA1Init(&ctx);
SHA1Update(&ctx, (const unsigned char*)key.c_str(), key.length());
SHA1Final(hash, &ctx);
#endif
base64_encodestate _state;
base64_init_encodestate(&_state);
int len = base64_encode_block((const char *) hash, 20, buffer, &_state);
len = base64_encode_blockend((buffer + len), &_state);
addHeader(WS_STR_CONNECTION, WS_STR_UPGRADE);
addHeader(WS_STR_UPGRADE, "websocket");
addHeader(WS_STR_ACCEPT,buffer);
free(buffer);
free(hash);
base64_encodestate _state;
base64_init_encodestate(&_state);
int len = base64_encode_block((const char *) hash, 20, buffer, &_state);
len = base64_encode_blockend((buffer + len), &_state);
addHeader(WS_STR_CONNECTION, WS_STR_UPGRADE);
addHeader(WS_STR_UPGRADE, "websocket");
addHeader(WS_STR_ACCEPT,buffer);
free(buffer);
free(hash);
}
void AsyncWebSocketResponse::_respond(AsyncWebServerRequest *request){
if(_state == RESPONSE_FAILED){
request->client()->close(true);
return;
}
String out = _assembleHead(request->version());
request->client()->write(out.c_str(), _headLength);
_state = RESPONSE_WAIT_ACK;
void AsyncWebSocketResponse::_respond(AsyncWebServerRequest *request)
{
if(_state == RESPONSE_FAILED)
{
request->client()->close(true);
return;
}
String out = _assembleHead(request->version());
request->client()->write(out.c_str(), _headLength);
_state = RESPONSE_WAIT_ACK;
}
size_t AsyncWebSocketResponse::_ack(AsyncWebServerRequest *request, size_t len, uint32_t time){
(void)time;
if(len){
//new AsyncWebSocketClient(request, _server);
_server->_newClient(request);
}
return 0;
size_t AsyncWebSocketResponse::_ack(AsyncWebServerRequest *request, size_t len, uint32_t time)
{
(void)time;
if(len)
_server->_newClient(request);
return 0;
}

View File

@ -24,7 +24,7 @@
#include <Arduino.h>
#ifdef ESP32
#include <AsyncTCP.h>
#define WS_MAX_QUEUED_MESSAGES 32
#define WS_MAX_QUEUED_MESSAGES 16
#else
#include <ESPAsyncTCP.h>
#define WS_MAX_QUEUED_MESSAGES 8
@ -145,6 +145,8 @@ class AsyncWebSocketClient {
IPAddress remoteIP() const;
uint16_t remotePort() const;
bool shouldBeDeleted() const { return !_client; }
//control frames
void close(uint16_t code=0, const char * message=NULL);
void ping(const uint8_t *data=NULL, size_t len=0);
@ -264,8 +266,6 @@ class AsyncWebSocket: public AsyncWebHandler {
//system callbacks (do not call)
uint32_t _getNextId(){ return _cNextId++; }
AsyncWebSocketClient *_newClient(AsyncWebServerRequest *request);
//void _addClient(AsyncWebSocketClient * client);
void _handleDisconnect(AsyncWebSocketClient * client);
void _handleEvent(AsyncWebSocketClient * client, AwsEventType type, void * arg, uint8_t *data, size_t len);
virtual bool canHandle(AsyncWebServerRequest *request) override final;
virtual void handleRequest(AsyncWebServerRequest *request) override final;