From 4963ce9da983e433d938e8b978327f0eac2ddd8c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Villac=C3=ADs=20Lasso?= Date: Sat, 2 Jan 2021 20:07:28 -0500 Subject: [PATCH] Replace entire AsyncWebSocketBuffer infrastructure with std::shared_ptr Based on commit 9172736ac28e7c2f3edd0c56434bc27f578a89b1 of dumbfixes branch of 0xFEEDC0DE64 fork of ESPAsyncWebServer. The entire purpose of having an AsyncWebSocketMessageBuffer is to maintain a shared copy of a buffer sent to multiple clients until all of them finish sending it. This can be replaced with a single message object that contains a shared_ptr to the data being sent. For the single client case, this is equivalent to non-shared buffer. For multiple clients, the shared_ptr will maintain the buffer live until all messages referencing it are destroyed. This simplifies the websocket architecture. --- src/AsyncWebSocket.cpp | 883 ++++++++++++++++------------------------- src/AsyncWebSocket.h | 201 ++-------- 2 files changed, 377 insertions(+), 707 deletions(-) diff --git a/src/AsyncWebSocket.cpp b/src/AsyncWebSocket.cpp index 5a62b48..58bc457 100644 --- a/src/AsyncWebSocket.cpp +++ b/src/AsyncWebSocket.cpp @@ -21,6 +21,8 @@ #include "Arduino.h" #include "AsyncWebSocket.h" +#include + #include #ifndef ESP8266 @@ -131,308 +133,129 @@ size_t webSocketSendFrame(AsyncClient *client, bool final, uint8_t opcode, bool } -/* - * AsyncWebSocketMessageBuffer - */ - - - -AsyncWebSocketMessageBuffer::AsyncWebSocketMessageBuffer() - :_data(nullptr) - ,_len(0) -{ - -} - -AsyncWebSocketMessageBuffer::AsyncWebSocketMessageBuffer(uint8_t * data, size_t size) - :_data(nullptr) - ,_len(size) -{ - if (!data) - return; - - _data = std::unique_ptr(new uint8_t[_len + 1]); //std::make_unique(_len + 1); - memcpy(_data.get(), data, _len); - _data[_len] = 0; -} - - -AsyncWebSocketMessageBuffer::AsyncWebSocketMessageBuffer(size_t size) - :_data(nullptr) - ,_len(size) -{ - _data = std::unique_ptr(new uint8_t[_len + 1]); //std::make_unique(_len + 1); - _data[_len] = 0; -} - -AsyncWebSocketMessageBuffer::~AsyncWebSocketMessageBuffer() -{ -} - -bool AsyncWebSocketMessageBuffer::reserve(size_t size) -{ - _len = size; - - _data = std::unique_ptr(new uint8_t[_len + 1]); //std::make_unique(_len + 1); - _data[_len] = 0; - - return true; -} - - /* * Control Frame */ class AsyncWebSocketControl { - private: +private: uint8_t _opcode; uint8_t *_data; size_t _len; bool _mask; bool _finished; - public: - AsyncWebSocketControl(uint8_t opcode, uint8_t *data=NULL, size_t len=0, bool mask=false) + +public: + AsyncWebSocketControl(uint8_t opcode, const uint8_t *data=NULL, size_t len=0, bool mask=false) :_opcode(opcode) ,_len(len) ,_mask(len && mask) ,_finished(false) - { - if(data == NULL) - _len = 0; - if(_len){ - if(_len > 125) - _len = 125; - _data = (uint8_t*)malloc(_len); - if(_data == NULL) - _len = 0; - else memcpy(_data, data, len); - } else _data = NULL; + { + if (data == NULL) + _len = 0; + if (_len) + { + if (_len > 125) + _len = 125; + + _data = (uint8_t*)malloc(_len); + + if(_data == NULL) + _len = 0; + else + memcpy(_data, data, len); + } + else + _data = NULL; } - virtual ~AsyncWebSocketControl(){ - if(_data != NULL) - free(_data); + + virtual ~AsyncWebSocketControl() + { + if (_data != NULL) + free(_data); } + virtual bool finished() const { return _finished; } uint8_t opcode(){ return _opcode; } uint8_t len(){ return _len + 2; } size_t send(AsyncClient *client){ - _finished = true; - return webSocketSendFrame(client, true, _opcode & 0x0F, _mask, _data, _len); + _finished = true; + return webSocketSendFrame(client, true, _opcode & 0x0F, _mask, _data, _len); } }; + /* - * Basic Buffered Message + * AsyncWebSocketMessage Message */ -AsyncWebSocketBasicMessage::AsyncWebSocketBasicMessage(const char * data, size_t len, uint8_t opcode, bool mask) - :_len(len) - ,_sent(0) - ,_ack(0) - ,_acked(0) +AsyncWebSocketMessage::AsyncWebSocketMessage(std::shared_ptr> buffer, uint8_t opcode, bool mask) : + _WSbuffer{buffer}, + _opcode(opcode & 0x07), + _mask{mask}, + _status{_WSbuffer?WS_MSG_SENDING:WS_MSG_ERROR} { - _opcode = opcode & 0x07; - _mask = mask; - _data = (uint8_t*)malloc(_len+1); - // Serial.println("MSG alloc"); - if(_data == NULL){ - _len = 0; - _status = WS_MSG_ERROR; - } else { - _status = WS_MSG_SENDING; - memcpy(_data, data, _len); - _data[_len] = 0; - } -} -AsyncWebSocketBasicMessage::AsyncWebSocketBasicMessage(uint8_t opcode, bool mask) - :_len(0) - ,_sent(0) - ,_ack(0) - ,_acked(0) - ,_data(NULL) +} + +void AsyncWebSocketMessage::ack(size_t len, uint32_t time) { - _opcode = opcode & 0x07; - _mask = mask; - + (void)time; + _acked += len; + if (_sent >= _WSbuffer->size() && _acked >= _ack) + { + _status = WS_MSG_SENT; + } + //ets_printf("A: %u\n", len); } - -AsyncWebSocketBasicMessage::~AsyncWebSocketBasicMessage() { - if(_data != NULL) { - // Serial.println("MSG free"); - free(_data); - } -} - - void AsyncWebSocketBasicMessage::ack(size_t len, uint32_t time) { - (void)time; - _acked += len; - // Serial.printf("ACK %u = %u | %u = %u\n", _sent, _len, _acked, _ack); - if(_sent == _len && _acked == _ack){ - // Serial.println("ACK end"); - _status = WS_MSG_SENT; - } -} - size_t AsyncWebSocketBasicMessage::send(AsyncClient *client) { - if(_status != WS_MSG_SENDING){ - // Serial.println("MS 1"); - return 0; - } - if(_acked < _ack){ - // Serial.println("MS 2"); - return 0; - } - if(_sent == _len){ - // Serial.println("MS 3"); - _status = WS_MSG_SENT; - return 0; - } - if(_sent > _len){ - // Serial.println("MS 4"); - _status = WS_MSG_ERROR; - return 0; - } - size_t toSend = _len - _sent; - size_t window = webSocketSendFrameWindow(client); - // Serial.printf("Send %u %u %u\n", _len, _sent, toSend); - - if(window < toSend) { - toSend = window; - } - - _sent += toSend; - _ack += toSend + ((toSend < 126)?2:4) + (_mask * 4); - - bool final = (_sent == _len); - uint8_t* dPtr = (uint8_t*)(_data + (_sent - toSend)); - uint8_t opCode = (toSend && _sent == toSend)?_opcode:(uint8_t)WS_CONTINUATION; - - size_t sent = webSocketSendFrame(client, final, opCode, _mask, dPtr, toSend); - _status = WS_MSG_SENDING; - if(toSend && sent != toSend){ - size_t delta = (toSend - sent); - // Serial.printf("\ns:%u a:%u d:%u\n", _sent, _ack, delta); - _sent -= delta; - _ack -= delta + ((delta < 126)?2:4) + (_mask * 4); - // Serial.printf("s:%u a:%u\n", _sent, _ack); - if (!sent) { +size_t AsyncWebSocketMessage::send(AsyncClient *client) +{ + if (_status != WS_MSG_SENDING) + return 0; + if (_acked < _ack){ + return 0; + } + if (_sent == _WSbuffer->size()) + { + if(_acked == _ack) + _status = WS_MSG_SENT; + return 0; + } + if (_sent > _WSbuffer->size()) + { _status = WS_MSG_ERROR; - } - } - return sent; -} + //ets_printf("E: %u > %u\n", _sent, _WSbuffer->length()); + return 0; + } -// bool AsyncWebSocketBasicMessage::reserve(size_t size) { -// if (size) { -// _data = (uint8_t*)malloc(size +1); -// if (_data) { -// memset(_data, 0, size); -// _len = size; -// _status = WS_MSG_SENDING; -// return true; -// } -// } -// return false; -// } + size_t toSend = _WSbuffer->size() - _sent; + size_t window = webSocketSendFrameWindow(client); + if (window < toSend) { + toSend = window; + } -/* - * AsyncWebSocketMultiMessage Message - */ + _sent += toSend; + _ack += toSend + ((toSend < 126)?2:4) + (_mask * 4); + //ets_printf("W: %u %u\n", _sent - toSend, toSend); -AsyncWebSocketMultiMessage::AsyncWebSocketMultiMessage(std::shared_ptr buffer, uint8_t opcode, bool mask) - :_len(0) - ,_sent(0) - ,_ack(0) - ,_acked(0) - ,_WSbuffer(nullptr) -{ - _opcode = opcode & 0x07; - _mask = mask; + bool final = (_sent == _WSbuffer->size()); + uint8_t* dPtr = (uint8_t*)(_WSbuffer->data() + (_sent - toSend)); + uint8_t opCode = (toSend && _sent == toSend)?_opcode:(uint8_t)WS_CONTINUATION; - if (buffer) { - _WSbuffer = buffer; - _data = buffer->get(); - _len = buffer->length(); + size_t sent = webSocketSendFrame(client, final, opCode, _mask, dPtr, toSend); _status = WS_MSG_SENDING; - //ets_printf("M: %u\n", _len); - } else { - // Serial.println("BUFF ERROR"); - _status = WS_MSG_ERROR; - } - -} - - -AsyncWebSocketMultiMessage::~AsyncWebSocketMultiMessage() { -} - - void AsyncWebSocketMultiMessage::ack(size_t len, uint32_t time) { - (void)time; - _acked += len; - // Serial.printf("ACK %u = %u | %u = %u\n", _sent, _len, _acked, _ack); - if(_sent >= _len && _acked >= _ack){ - // Serial.println("ACK end"); - _status = WS_MSG_SENT; - } - //ets_printf("A: %u\n", len); -} - size_t AsyncWebSocketMultiMessage::send(AsyncClient *client) { - if(_status != WS_MSG_SENDING) { - // Serial.println("MS 1"); - return 0; - } - if(_acked < _ack){ - // Serial.println("MS 2"); - return 0; - } - if(_sent == _len){ - // Serial.println("MS 3"); - _status = WS_MSG_SENT; - return 0; - } - if(_sent > _len){ - // Serial.println("MS 4"); - _status = WS_MSG_ERROR; - //ets_printf("E: %u > %u\n", _sent, _len); - return 0; - } - size_t toSend = _len - _sent; - size_t window = webSocketSendFrameWindow(client); - // Serial.printf("Send %u %u %u\n", _len, _sent, toSend); - - if(window < toSend) { - toSend = window; - } - // Serial.printf("s:%u a:%u t:%u\n", _sent, _ack, toSend); - _sent += toSend; - _ack += toSend + ((toSend < 126)?2:4) + (_mask * 4); - - //ets_printf("W: %u %u\n", _sent - toSend, toSend); - - bool final = (_sent == _len); - uint8_t* dPtr = (uint8_t*)(_data + (_sent - toSend)); - uint8_t opCode = (toSend && _sent == toSend)?_opcode:(uint8_t)WS_CONTINUATION; - - size_t sent = webSocketSendFrame(client, final, opCode, _mask, dPtr, toSend); - _status = WS_MSG_SENDING; - if(toSend && sent != toSend){ - //ets_printf("E: %u != %u\n", toSend, sent); - size_t delta = (toSend - sent); - // Serial.printf("\ns:%u a:%u d:%u\n", _sent, _ack, delta); - _sent -= delta; - _ack -= delta + ((delta < 126)?2:4) + (_mask * 4); - // Serial.printf("s:%u a:%u\n", _sent, _ack); - if (!sent) { - _status = WS_MSG_ERROR; - } - } - //ets_printf("S: %u %u\n", _sent, sent); - return sent; + if (toSend && sent != toSend){ + //ets_printf("E: %u != %u\n", toSend, sent); + _sent -= (toSend - sent); + _ack -= (toSend - sent); + } + //ets_printf("S: %u %u\n", _sent, sent); + return sent; } @@ -472,14 +295,13 @@ AsyncWebSocketClient::~AsyncWebSocketClient() _messageQueue.clear(); _controlQueue.clear(); } - _server->_cleanBuffers(); _server->_handleEvent(this, WS_EVT_DISCONNECT, NULL, NULL, 0); } -void AsyncWebSocketClient::_clearQueue(){ - while (!_messageQueue.empty() && _messageQueue.front().get().finished()){ +void AsyncWebSocketClient::_clearQueue() +{ + while (!_messageQueue.empty() && _messageQueue.front().finished()) _messageQueue.pop_front(); - } } void AsyncWebSocketClient::_onAck(size_t len, uint32_t time){ @@ -488,29 +310,28 @@ void AsyncWebSocketClient::_onAck(size_t len, uint32_t time){ { AsyncWebLockGuard l(_lock); - if (!_controlQueue.empty()){ + if (!_controlQueue.empty()) { auto &head = _controlQueue.front(); - if(head.finished()){ + if (head.finished()){ len -= head.len(); - if(_status == WS_DISCONNECTING && head.opcode() == WS_DISCONNECT){ - _controlQueue.pop_front(); - _status = WS_DISCONNECTED; - l.unlock(); - _client->close(true); - return; + if (_status == WS_DISCONNECTING && head.opcode() == WS_DISCONNECT){ + _controlQueue.pop_front(); + _status = WS_DISCONNECTED; + l.unlock(); + _client->close(true); + return; } _controlQueue.pop_front(); } } if(len && !_messageQueue.empty()){ - _messageQueue.front().get().ack(len, time); + _messageQueue.front().ack(len, time); } _clearQueue(); } - _server->_cleanBuffers(); _runQueue(); } @@ -521,7 +342,8 @@ void AsyncWebSocketClient::_onPoll() { l.unlock(); _runQueue(); - } else if(_keepAlivePeriod > 0 && (millis() - _lastMessageTime) >= _keepAlivePeriod && (_controlQueue.empty() && _messageQueue.empty())) + } + else if(_keepAlivePeriod > 0 && (millis() - _lastMessageTime) >= _keepAlivePeriod && (_controlQueue.empty() && _messageQueue.empty())) { l.unlock(); ping((uint8_t *)AWSC_PING_PAYLOAD, AWSC_PING_PAYLOAD_LEN); @@ -534,15 +356,15 @@ void AsyncWebSocketClient::_runQueue() _clearQueue(); - if(!_controlQueue.empty() && (_messageQueue.empty() || _messageQueue.front().get().betweenFrames()) && webSocketSendFrameWindow(_client) > (size_t)(_controlQueue.front().len() - 1)) + if (!_controlQueue.empty() && (_messageQueue.empty() || _messageQueue.front().betweenFrames()) && webSocketSendFrameWindow(_client) > (size_t)(_controlQueue.front().len() - 1)) { //l.unlock(); _controlQueue.front().send(_client); } - else if (!_messageQueue.empty() && _messageQueue.front().get().betweenFrames() && webSocketSendFrameWindow(_client)) + else if (!_messageQueue.empty() && _messageQueue.front().betweenFrames() && webSocketSendFrameWindow(_client)) { //l.unlock(); - _messageQueue.front().get().send(_client); + _messageQueue.front().send(_client); } } @@ -573,7 +395,7 @@ bool AsyncWebSocketClient::canSend() const return size < WS_MAX_QUEUED_MESSAGES; } -void AsyncWebSocketClient::_queueControl(uint8_t opcode, uint8_t *data, size_t len, bool mask) +void AsyncWebSocketClient::_queueControl(uint8_t opcode, const uint8_t *data, size_t len, bool mask) { { AsyncWebLockGuard l(_lock); @@ -584,29 +406,7 @@ void AsyncWebSocketClient::_queueControl(uint8_t opcode, uint8_t *data, size_t l _runQueue(); } -void AsyncWebSocketClient::_queueMessage(const char *data, size_t len, uint8_t opcode, bool mask) -{ - if(_status != WS_CONNECTED) - return; - - { - AsyncWebLockGuard l(_lock); - if (_messageQueue.size() >= WS_MAX_QUEUED_MESSAGES) - { - l.unlock(); - ets_printf("ERROR: Too many messages queued\n"); - } - else - { - _messageQueue.emplace_back(data, len, opcode, mask); - } - } - - if(_client->canSend()) - _runQueue(); -} - -void AsyncWebSocketClient::_queueMessage(std::shared_ptr buffer, uint8_t opcode, bool mask) +void AsyncWebSocketClient::_queueMessage(std::shared_ptr> buffer, uint8_t opcode, bool mask) { if(_status != WS_CONNECTED) return; @@ -653,28 +453,32 @@ void AsyncWebSocketClient::close(uint16_t code, const char * message){ _queueControl(WS_DISCONNECT); } -void AsyncWebSocketClient::ping(uint8_t *data, size_t len){ - if(_status == WS_CONNECTED) - _queueControl(WS_PING, data, len); +void AsyncWebSocketClient::ping(const uint8_t *data, size_t len) +{ + if (_status == WS_CONNECTED) + _queueControl(WS_PING, data, len); } void AsyncWebSocketClient::_onError(int8_t){ //Serial.println("onErr"); } -void AsyncWebSocketClient::_onTimeout(uint32_t time){ - // Serial.println("onTime"); - (void)time; - _client->close(true); +void AsyncWebSocketClient::_onTimeout(uint32_t time) +{ + // Serial.println("onTime"); + (void)time; + _client->close(true); } -void AsyncWebSocketClient::_onDisconnect(){ - // Serial.println("onDis"); - _client = NULL; - _server->_handleDisconnect(this); +void AsyncWebSocketClient::_onDisconnect() +{ + // Serial.println("onDis"); + _client = NULL; + _server->_handleDisconnect(this); } -void AsyncWebSocketClient::_onData(void *pbuf, size_t plen){ +void AsyncWebSocketClient::_onData(void *pbuf, size_t plen) +{ // Serial.println("onData"); _lastMessageTime = millis(); uint8_t *data = (uint8_t*)pbuf; @@ -768,7 +572,8 @@ void AsyncWebSocketClient::_onData(void *pbuf, size_t plen){ } } -size_t AsyncWebSocketClient::printf(const char *format, ...) { +size_t AsyncWebSocketClient::printf(const char *format, ...) +{ va_list arg; va_start(arg, format); char* temp = new char[MAX_PRINTF_LEN]; @@ -799,7 +604,8 @@ size_t AsyncWebSocketClient::printf(const char *format, ...) { } #ifndef ESP32 -size_t AsyncWebSocketClient::printf_P(PGM_P formatP, ...) { +size_t AsyncWebSocketClient::printf_P(PGM_P formatP, ...) +{ va_list arg; va_start(arg, formatP); char* temp = new char[MAX_PRINTF_LEN]; @@ -830,63 +636,96 @@ size_t AsyncWebSocketClient::printf_P(PGM_P formatP, ...) { } #endif -void AsyncWebSocketClient::text(const char * message, size_t len){ - _queueMessage(message, len); +namespace { +std::shared_ptr> makeBuffer(const uint8_t *message, size_t len) +{ + auto buffer = std::make_shared>(len); + std::memcpy(buffer->data(), message, len); + return buffer; } -void AsyncWebSocketClient::text(const char * message){ - text(message, strlen(message)); } -void AsyncWebSocketClient::text(uint8_t * message, size_t len){ - text((const char *)message, len); -} -void AsyncWebSocketClient::text(char * message){ - text(message, strlen(message)); -} -void AsyncWebSocketClient::text(const String &message){ - text(message.c_str(), message.length()); -} -void AsyncWebSocketClient::text(const __FlashStringHelper *data){ - text(String(data)); -} -void AsyncWebSocketClient::text(std::shared_ptr buffer) + +void AsyncWebSocketClient::text(std::shared_ptr> buffer) { _queueMessage(buffer); } -void AsyncWebSocketClient::binary(const char * message, size_t len) +void AsyncWebSocketClient::text(const uint8_t *message, size_t len) { - _queueMessage(message, len, WS_BINARY); + text(makeBuffer(message, len)); } -void AsyncWebSocketClient::binary(const char * message) -{ - binary(message, strlen(message)); -} -void AsyncWebSocketClient::binary(uint8_t * message, size_t len) -{ - binary((const char *)message, len); -} -void AsyncWebSocketClient::binary(char * message) + +void AsyncWebSocketClient::text(const char *message, size_t len) +{ + text((const uint8_t *)message, len); +} + +void AsyncWebSocketClient::text(const char *message) +{ + text(message, strlen(message)); +} + +void AsyncWebSocketClient::text(const String &message) +{ + text(message.c_str(), message.length()); +} + +void AsyncWebSocketClient::text(const __FlashStringHelper *data) +{ + PGM_P p = reinterpret_cast(data); + + size_t n = 0; + while (1) + { + if (pgm_read_byte(p+n) == 0) break; + n += 1; + } + + char * message = (char*) malloc(n+1); + if(message) + { + memcpy_P(message, p, n); + message[n] = 0; + text(message, n); + free(message); + } +} + +void AsyncWebSocketClient::binary(std::shared_ptr> buffer) +{ + _queueMessage(buffer, WS_BINARY); +} + +void AsyncWebSocketClient::binary(const uint8_t *message, size_t len) +{ + binary(makeBuffer(message, len)); +} + +void AsyncWebSocketClient::binary(const char *message, size_t len) +{ + binary((const uint8_t *)message, len); +} + +void AsyncWebSocketClient::binary(const char *message) { binary(message, strlen(message)); } + void AsyncWebSocketClient::binary(const String &message) { binary(message.c_str(), message.length()); } + void AsyncWebSocketClient::binary(const __FlashStringHelper *data, size_t len) { PGM_P p = reinterpret_cast(data); - char * message = (char*) malloc(len); - if(message) { + char *message = (char*) malloc(len); + if (message) { memcpy_P(message, p, len); binary(message, len); free(message); } } -void AsyncWebSocketClient::binary(std::shared_ptr buffer) -{ - _queueMessage(buffer, WS_BINARY); -} IPAddress AsyncWebSocketClient::remoteIP() const { @@ -968,114 +807,184 @@ AsyncWebSocketClient * AsyncWebSocket::client(uint32_t id){ } -void AsyncWebSocket::close(uint32_t id, uint16_t code, const char * message){ - AsyncWebSocketClient *c = client(id); - if (c) - c->close(code, message); +void AsyncWebSocket::close(uint32_t id, uint16_t code, const char * message) +{ + if (AsyncWebSocketClient *c = client(id)) + c->close(code, message); } void AsyncWebSocket::closeAll(uint16_t code, const char * message) { - for(auto& c: _clients) - { - if(c.status() == WS_CONNECTED) - c.close(code, message); - } + for (auto &c : _clients) + { + if (c.status() == WS_CONNECTED) + c.close(code, message); + } } void AsyncWebSocket::cleanupClients(uint16_t maxClients) { - if (count() > maxClients){ - _clients.front().close(); - } + if (count() > maxClients) { + _clients.front().close(); + } } -void AsyncWebSocket::ping(uint32_t id, uint8_t *data, size_t len) +void AsyncWebSocket::ping(uint32_t id, const uint8_t *data, size_t len) { - AsyncWebSocketClient * c = client(id); - if(c) - c->ping(data, len); + if (AsyncWebSocketClient * c = client(id)) + c->ping(data, len); } -void AsyncWebSocket::pingAll(uint8_t *data, size_t len) +void AsyncWebSocket::pingAll(const uint8_t *data, size_t len) { - for(auto& c: _clients){ - if(c.status() == WS_CONNECTED) - c.ping(data, len); - } + for (auto &c : _clients) { + if (c.status() == WS_CONNECTED) + c.ping(data, len); + } } -void AsyncWebSocket::text(uint32_t id, const char * message, size_t len) +void AsyncWebSocket::text(uint32_t id, const uint8_t *message, size_t len) { - AsyncWebSocketClient *c = client(id); - if(c) - c->text(message, len); + if (AsyncWebSocketClient * c = client(id)) + c->text(makeBuffer(message, len)); +} +void AsyncWebSocket::text(uint32_t id, const char *message, size_t len) +{ + text(id, (const uint8_t *)message, len); +} +void AsyncWebSocket::text(uint32_t id, const char * message) +{ + text(id, message, strlen(message)); +} +void AsyncWebSocket::text(uint32_t id, const String &message) +{ + text(id, message.c_str(), message.length()); +} +void AsyncWebSocket::text(uint32_t id, const __FlashStringHelper *data) +{ + PGM_P p = reinterpret_cast(data); + + size_t n = 0; + while (1) + { + if (pgm_read_byte(p+n) == 0) break; + n += 1; + } + + char * message = (char*) malloc(n+1); + if(message) + { + memcpy_P(message, p, n); + message[n] = 0; + text(id, message, n); + free(message); + } } -void AsyncWebSocket::textAll(std::shared_ptr buffer) +void AsyncWebSocket::textAll(std::shared_ptr> buffer) { - if (!buffer) - return; - - for(auto &c : _clients) + for (auto &c : _clients) if (c.status() == WS_CONNECTED) c.text(buffer); - - _cleanBuffers(); } - - +void AsyncWebSocket::textAll(const uint8_t *message, size_t len) +{ + textAll(makeBuffer(message, len)); +} void AsyncWebSocket::textAll(const char * message, size_t len) { - std::shared_ptr WSBuffer = makeBuffer((uint8_t *)message, len); - textAll(WSBuffer); + textAll((const uint8_t *)message, len); +} +void AsyncWebSocket::textAll(const char *message) +{ + textAll(message, strlen(message)); +} +void AsyncWebSocket::textAll(const String &message) +{ + textAll(message.c_str(), message.length()); +} +void AsyncWebSocket::textAll(const __FlashStringHelper *data) +{ + PGM_P p = reinterpret_cast(data); + + size_t n = 0; + while (1) + { + if (pgm_read_byte(p+n) == 0) break; + n += 1; + } + + char *message = (char*)malloc(n+1); + if(message) + { + memcpy_P(message, p, n); + message[n] = 0; + textAll(message, n); + free(message); + } } +void AsyncWebSocket::binary(uint32_t id, const uint8_t *message, size_t len) +{ + if (AsyncWebSocketClient * c = client(id)) + c->binary(makeBuffer(message, len)); +} void AsyncWebSocket::binary(uint32_t id, const char * message, size_t len) { - AsyncWebSocketClient * c = client(id); - if(c) - c->binary(message, len); + binary(id, (const uint8_t *)message, len); } - -void AsyncWebSocket::binaryAll(const char * message, size_t len){ - std::shared_ptr buffer = makeBuffer((uint8_t *)message, len); - binaryAll(buffer); -} - -void AsyncWebSocket::binaryAll(std::shared_ptr buffer) +void AsyncWebSocket::binary(uint32_t id, const char * message) { - if (!buffer) - return; + binary(id, message, strlen(message)); +} +void AsyncWebSocket::binary(uint32_t id, const String &message) +{ + binary(id, message.c_str(), message.length()); +} +void AsyncWebSocket::binary(uint32_t id, const __FlashStringHelper *data, size_t len) +{ + PGM_P p = reinterpret_cast(data); + char *message = (char*) malloc(len); + if (message) { + memcpy_P(message, p, len); + binary(id, message, len); + free(message); + } +} +void AsyncWebSocket::binaryAll(std::shared_ptr> buffer) +{ for (auto &c : _clients) if (c.status() == WS_CONNECTED) c.binary(buffer); - - _cleanBuffers(); } -void AsyncWebSocket::message(uint32_t id, const char *data, size_t len, uint8_t opcode, bool mask) +void AsyncWebSocket::binaryAll(const uint8_t *message, size_t len) { - AsyncWebSocketClient *c = client(id); - if (c) - c->message(data, len, opcode, mask); + binaryAll(makeBuffer(message, len)); } -void AsyncWebSocket::message(uint32_t id, std::shared_ptr buffer, uint8_t opcode, bool mask) +void AsyncWebSocket::binaryAll(const char *message, size_t len) { - AsyncWebSocketClient *c = client(id); - if (c) - c->message(buffer, opcode, mask); + binaryAll((const uint8_t *)message, len); } - -void AsyncWebSocket::messageAll(std::shared_ptr buffer, uint8_t opcode, bool mask) +void AsyncWebSocket::binaryAll(const char *message) { - for (auto &c : _clients) - if (c.status() == WS_CONNECTED) - c.message(buffer, opcode, mask); - - _cleanBuffers(); + binaryAll(message, strlen(message)); +} +void AsyncWebSocket::binaryAll(const String &message) +{ + binaryAll(message.c_str(), message.length()); +} +void AsyncWebSocket::binaryAll(const __FlashStringHelper *data, size_t len) +{ + PGM_P p = reinterpret_cast(data); + char * message = (char*) malloc(len); + if(message) { + memcpy_P(message, p, len); + binaryAll(message, len); + free(message); + } } size_t AsyncWebSocket::printf(uint32_t id, const char *format, ...){ @@ -1094,7 +1003,7 @@ size_t AsyncWebSocket::printfAll(const char *format, ...) { va_list arg; char* temp = new char[MAX_PRINTF_LEN]; - if(!temp){ + if (!temp) { return 0; } va_start(arg, format); @@ -1102,10 +1011,10 @@ size_t AsyncWebSocket::printfAll(const char *format, ...) va_end(arg); delete[] temp; - std::shared_ptr buffer = makeBuffer(len); + std::shared_ptr> buffer = std::make_shared>(len); va_start(arg, format); - vsnprintf( (char *)buffer->get(), len + 1, format, arg); + vsnprintf( (char *)buffer->data(), len + 1, format, arg); va_end(arg); textAll(buffer); @@ -1126,10 +1035,11 @@ size_t AsyncWebSocket::printf_P(uint32_t id, PGM_P formatP, ...){ } #endif -size_t AsyncWebSocket::printfAll_P(PGM_P formatP, ...) { +size_t AsyncWebSocket::printfAll_P(PGM_P formatP, ...) +{ va_list arg; char* temp = new char[MAX_PRINTF_LEN]; - if(!temp){ + if (!temp) { return 0; } va_start(arg, formatP); @@ -1137,106 +1047,16 @@ size_t AsyncWebSocket::printfAll_P(PGM_P formatP, ...) { va_end(arg); delete[] temp; - std::shared_ptr buffer = makeBuffer(len + 1); + std::shared_ptr> buffer = std::make_shared>(len + 1); va_start(arg, formatP); - vsnprintf_P((char *)buffer->get(), len + 1, formatP, arg); + vsnprintf_P((char *)buffer->data(), len + 1, formatP, arg); va_end(arg); textAll(buffer); return len; } -void AsyncWebSocket::text(uint32_t id, const char * message) -{ - text(id, message, strlen(message)); -} -void AsyncWebSocket::text(uint32_t id, uint8_t * message, size_t len) -{ - text(id, (const char *)message, len); -} -void AsyncWebSocket::text(uint32_t id, char * message) -{ - text(id, message, strlen(message)); -} -void AsyncWebSocket::text(uint32_t id, const String &message) -{ - text(id, message.c_str(), message.length()); -} -void AsyncWebSocket::text(uint32_t id, const __FlashStringHelper *message) -{ - AsyncWebSocketClient * c = client(id); - if(c != NULL) - c->text(message); -} -void AsyncWebSocket::textAll(const char * message) -{ - textAll(message, strlen(message)); -} -void AsyncWebSocket::textAll(uint8_t * message, size_t len) -{ - textAll((const char *)message, len); -} -void AsyncWebSocket::textAll(char * message) -{ - textAll(message, strlen(message)); -} -void AsyncWebSocket::textAll(const String &message) -{ - textAll(message.c_str(), message.length()); -} -void AsyncWebSocket::textAll(const __FlashStringHelper *message) -{ - for (auto& c : _clients) - if (c.status() == WS_CONNECTED) - c.text(message); -} -void AsyncWebSocket::binary(uint32_t id, const char * message) -{ - binary(id, message, strlen(message)); -} -void AsyncWebSocket::binary(uint32_t id, uint8_t * message, size_t len) -{ - binary(id, (const char *)message, len); -} -void AsyncWebSocket::binary(uint32_t id, char * message) -{ - binary(id, message, strlen(message)); -} -void AsyncWebSocket::binary(uint32_t id, const String &message) -{ - binary(id, message.c_str(), message.length()); -} -void AsyncWebSocket::binary(uint32_t id, const __FlashStringHelper *message, size_t len) -{ - AsyncWebSocketClient * c = client(id); - if (c != NULL) - c-> binary(message, len); -} -void AsyncWebSocket::binaryAll(const char * message) -{ - binaryAll(message, strlen(message)); -} -void AsyncWebSocket::binaryAll(uint8_t * message, size_t len) -{ - binaryAll((const char *)message, len); -} -void AsyncWebSocket::binaryAll(char * message) -{ - binaryAll(message, strlen(message)); -} -void AsyncWebSocket::binaryAll(const String &message) -{ - binaryAll(message.c_str(), message.length()); -} -void AsyncWebSocket::binaryAll(const __FlashStringHelper *message, size_t len) -{ - for (auto& c : _clients) { - if (c.status() == WS_CONNECTED) - c.binary(message, len); - } -} - const char __WS_STR_CONNECTION[] PROGMEM = { "Connection" }; const char __WS_STR_UPGRADE[] PROGMEM = { "Upgrade" }; const char __WS_STR_ORIGIN[] PROGMEM = { "Origin" }; @@ -1306,41 +1126,6 @@ void AsyncWebSocket::handleRequest(AsyncWebServerRequest *request) request->send(response); } -std::shared_ptr AsyncWebSocket::makeBuffer(size_t size) -{ - std::shared_ptr buffer = std::make_shared(size); - - { - AsyncWebLockGuard l(_lock); - _buffers.emplace_back(buffer); - } - - return buffer; -} - -std::shared_ptr AsyncWebSocket::makeBuffer(uint8_t * data, size_t size) -{ - std::shared_ptr buffer = std::make_shared(data, size); - - { - AsyncWebLockGuard l(_lock); - _buffers.emplace_back(buffer); - } - - return buffer; -} - -void AsyncWebSocket::_cleanBuffers() -{ - AsyncWebLockGuard l(_lock); - for (auto iter = std::begin(_buffers); iter != std::end(_buffers);){ - if(iter->lock()) - iter++; - else - iter = _buffers.erase(iter); - } -} - /* * Response to Web Socket request - sends the authorization and detaches the TCP Client from the web server * Authentication code from https://github.com/Links2004/arduinoWebSockets/blob/master/src/WebSockets.cpp#L480 diff --git a/src/AsyncWebSocket.h b/src/AsyncWebSocket.h index 3d24e23..3b4e0f7 100644 --- a/src/AsyncWebSocket.h +++ b/src/AsyncWebSocket.h @@ -84,122 +84,25 @@ typedef enum { WS_CONTINUATION, WS_TEXT, WS_BINARY, WS_DISCONNECT = 0x08, WS_PIN typedef enum { WS_MSG_SENDING, WS_MSG_SENT, WS_MSG_ERROR } AwsMessageStatus; typedef enum { WS_EVT_CONNECT, WS_EVT_DISCONNECT, WS_EVT_PONG, WS_EVT_ERROR, WS_EVT_DATA } AwsEventType; -class AsyncWebSocketMessageBuffer { - private: - std::unique_ptr _data; - size_t _len{}; - - public: - AsyncWebSocketMessageBuffer(); - AsyncWebSocketMessageBuffer(size_t size); - AsyncWebSocketMessageBuffer(uint8_t *data, size_t size); - ~AsyncWebSocketMessageBuffer(); - bool reserve(size_t size); - uint8_t *get() { return _data.get(); } - size_t length() { return _len; } -}; - -class AsyncWebSocketMessage { - protected: - uint8_t _opcode; - bool _mask; - AwsMessageStatus _status; - public: - AsyncWebSocketMessage():_opcode(WS_TEXT),_mask(false),_status(WS_MSG_ERROR){} - virtual ~AsyncWebSocketMessage(){} - virtual void ack(size_t len __attribute__((unused)), uint32_t time __attribute__((unused))){} - virtual size_t send(AsyncClient *client __attribute__((unused))){ return 0; } - virtual bool finished(){ return _status != WS_MSG_SENDING; } - virtual bool betweenFrames() const { return false; } -}; - -class AsyncWebSocketBasicMessage: public AsyncWebSocketMessage { - private: - size_t _len; - size_t _sent; - size_t _ack; - size_t _acked; - uint8_t * _data; -public: - AsyncWebSocketBasicMessage(const char * data, size_t len, uint8_t opcode=WS_TEXT, bool mask=false); - AsyncWebSocketBasicMessage(uint8_t opcode=WS_TEXT, bool mask=false); - virtual ~AsyncWebSocketBasicMessage() override; - virtual bool betweenFrames() const override { return _acked == _ack; } - virtual void ack(size_t len, uint32_t time) override ; - virtual size_t send(AsyncClient *client) override ; -}; - -class AsyncWebSocketMultiMessage: public AsyncWebSocketMessage { - private: - uint8_t * _data; - size_t _len; - size_t _sent; - size_t _ack; - size_t _acked; - std::shared_ptr _WSbuffer; -public: - AsyncWebSocketMultiMessage(std::shared_ptr buffer, uint8_t opcode=WS_TEXT, bool mask=false); - virtual ~AsyncWebSocketMultiMessage() override; - virtual bool betweenFrames() const override { return _acked == _ack; } - virtual void ack(size_t len, uint32_t time) override ; - virtual size_t send(AsyncClient *client) override ; -}; - -class PolymorphMessageContainer +class AsyncWebSocketMessage { - union { - AsyncWebSocketBasicMessage basicMessage; - AsyncWebSocketMultiMessage multiMessage; - }; - - enum class Type : uint8_t { Basic, Multi }; - const Type type; +private: + std::shared_ptr> _WSbuffer; + uint8_t _opcode{WS_TEXT}; + bool _mask{false}; + AwsMessageStatus _status{WS_MSG_ERROR}; + size_t _sent{}; + size_t _ack{}; + size_t _acked{}; public: - PolymorphMessageContainer() = delete; - PolymorphMessageContainer(const PolymorphMessageContainer &) = delete; - PolymorphMessageContainer &operator=(const PolymorphMessageContainer &) = delete; + AsyncWebSocketMessage(std::shared_ptr> buffer, uint8_t opcode=WS_TEXT, bool mask=false); - PolymorphMessageContainer(const char *data, size_t len, uint8_t opcode=WS_TEXT, bool mask=false) : - type{Type::Basic} - { - new (&basicMessage) AsyncWebSocketBasicMessage{data, len, opcode, mask}; - } + bool finished() const { return _status != WS_MSG_SENDING; } + bool betweenFrames() const { return _acked == _ack; } - PolymorphMessageContainer(std::shared_ptr buffer, uint8_t opcode=WS_TEXT, bool mask=false) : - type{Type::Multi} - { - new (&multiMessage) AsyncWebSocketMultiMessage{buffer, opcode, mask}; - } - - ~PolymorphMessageContainer() - { - switch (type) - { - case Type::Basic: basicMessage.~AsyncWebSocketBasicMessage(); break; - case Type::Multi: multiMessage.~AsyncWebSocketMultiMessage(); break; - } - } - - AsyncWebSocketMessage &get() - { - switch (type) - { - case Type::Basic: return basicMessage; - case Type::Multi: return multiMessage; - } - __builtin_unreachable(); - } - - const AsyncWebSocketMessage &get() const - { - switch (type) - { - case Type::Basic: return basicMessage; - case Type::Multi: return multiMessage; - } - __builtin_unreachable(); - } + void ack(size_t len, uint32_t time); + size_t send(AsyncClient *client); }; class AsyncWebSocketClient { @@ -212,7 +115,7 @@ class AsyncWebSocketClient { AsyncWebLock _lock; std::deque _controlQueue; - std::deque _messageQueue; + std::deque _messageQueue; uint8_t _pstate; AwsFrameInfo _pinfo; @@ -220,9 +123,8 @@ class AsyncWebSocketClient { uint32_t _lastMessageTime; uint32_t _keepAlivePeriod; - void _queueControl(uint8_t opcode, uint8_t *data=NULL, size_t len=0, bool mask=false); - void _queueMessage(const char *data, size_t len, uint8_t opcode=WS_TEXT, bool mask=false); - void _queueMessage(std::shared_ptr buffer, uint8_t opcode=WS_TEXT, bool mask=false); + void _queueControl(uint8_t opcode, const uint8_t *data=NULL, size_t len=0, bool mask=false); + void _queueMessage(std::shared_ptr> buffer, uint8_t opcode=WS_TEXT, bool mask=false); void _runQueue(); void _clearQueue(); @@ -246,7 +148,7 @@ class AsyncWebSocketClient { //control frames void close(uint16_t code=0, const char * message=NULL); - void ping(uint8_t *data=NULL, size_t len=0); + void ping(const uint8_t *data=NULL, size_t len=0); //set auto-ping period in seconds. disabled if zero (default) void keepAlivePeriod(uint16_t seconds){ @@ -257,8 +159,7 @@ class AsyncWebSocketClient { } //data packets - void message(const char *data, size_t len, uint8_t opcode=WS_TEXT, bool mask=false) { _queueMessage(data, len, opcode, mask); } - void message(std::shared_ptr buffer, uint8_t opcode=WS_TEXT, bool mask=false) { _queueMessage(buffer, opcode, mask); } + void message(std::shared_ptr> buffer, uint8_t opcode=WS_TEXT, bool mask=false) { _queueMessage(buffer, opcode, mask); } bool queueIsFull() const; size_t queueLen() const; @@ -266,21 +167,20 @@ class AsyncWebSocketClient { #ifndef ESP32 size_t printf_P(PGM_P formatP, ...) __attribute__ ((format (printf, 2, 3))); #endif - void text(const char * message, size_t len); - void text(const char * message); - void text(uint8_t * message, size_t len); - void text(char * message); - void text(const String &message); - void text(const __FlashStringHelper *data); - void text(std::shared_ptr buffer); + void text(std::shared_ptr> buffer); + void text(const uint8_t *message, size_t len); + void text(const char *message, size_t len); + void text(const char *message); + void text(const String &message); + void text(const __FlashStringHelper *message); + + void binary(std::shared_ptr> buffer); + void binary(const uint8_t *message, size_t len); void binary(const char * message, size_t len); void binary(const char * message); - void binary(uint8_t * message, size_t len); - void binary(char * message); void binary(const String &message); - void binary(const __FlashStringHelper *data, size_t len); - void binary(std::shared_ptr buffer); + void binary(const __FlashStringHelper *message, size_t len); bool canSend() const; @@ -324,42 +224,34 @@ class AsyncWebSocket: public AsyncWebHandler { void closeAll(uint16_t code=0, const char * message=NULL); void cleanupClients(uint16_t maxClients = DEFAULT_MAX_WS_CLIENTS); - void ping(uint32_t id, uint8_t *data=NULL, size_t len=0); - void pingAll(uint8_t *data=NULL, size_t len=0); // done + void ping(uint32_t id, const uint8_t *data=NULL, size_t len=0); + void pingAll(const uint8_t *data=NULL, size_t len=0); // done - void text(uint32_t id, const char * message, size_t len); - void text(uint32_t id, const char * message); - void text(uint32_t id, uint8_t * message, size_t len); - void text(uint32_t id, char * message); + void text(uint32_t id, const uint8_t * message, size_t len); + void text(uint32_t id, const char *message, size_t len); + void text(uint32_t id, const char *message); void text(uint32_t id, const String &message); void text(uint32_t id, const __FlashStringHelper *message); + void textAll(std::shared_ptr> buffer); + void textAll(const uint8_t *message, size_t len); void textAll(const char * message, size_t len); void textAll(const char * message); - void textAll(uint8_t * message, size_t len); - void textAll(char * message); void textAll(const String &message); void textAll(const __FlashStringHelper *message); // need to convert - void textAll(std::shared_ptr buffer); - void binary(uint32_t id, const char * message, size_t len); - void binary(uint32_t id, const char * message); - void binary(uint32_t id, uint8_t * message, size_t len); - void binary(uint32_t id, char * message); + void binary(uint32_t id, const uint8_t *message, size_t len); + void binary(uint32_t id, const char *message, size_t len); + void binary(uint32_t id, const char *message); void binary(uint32_t id, const String &message); void binary(uint32_t id, const __FlashStringHelper *message, size_t len); - void binaryAll(const char * message, size_t len); - void binaryAll(const char * message); - void binaryAll(uint8_t * message, size_t len); - void binaryAll(char * message); + void binaryAll(std::shared_ptr> buffer); + void binaryAll(const uint8_t *message, size_t len); + void binaryAll(const char *message, size_t len); + void binaryAll(const char *message); void binaryAll(const String &message); void binaryAll(const __FlashStringHelper *message, size_t len); - void binaryAll(std::shared_ptr buffer); - - void message(uint32_t id, const char *data, size_t len, uint8_t opcode=WS_TEXT, bool mask=false); - void message(uint32_t id, std::shared_ptr buffer, uint8_t opcode=WS_TEXT, bool mask=false); - void messageAll(std::shared_ptr buffer, uint8_t opcode=WS_TEXT, bool mask=false); size_t printf(uint32_t id, const char *format, ...) __attribute__ ((format (printf, 3, 4))); size_t printfAll(const char *format, ...) __attribute__ ((format (printf, 2, 3))); @@ -386,13 +278,6 @@ class AsyncWebSocket: public AsyncWebHandler { virtual bool canHandle(AsyncWebServerRequest *request) override final; virtual void handleRequest(AsyncWebServerRequest *request) override final; - - // messagebuffer functions/objects. - std::shared_ptr makeBuffer(size_t size = 0); - std::shared_ptr makeBuffer(uint8_t * data, size_t size); - std::list> _buffers; - void _cleanBuffers(); - const std::list &getClients() const { return _clients; } };