From 344d31c1ffbffd845823e1f8c369fcbf2439e47d Mon Sep 17 00:00:00 2001 From: 0xFEEDC0DE64 Date: Sun, 20 Dec 2020 02:08:39 +0100 Subject: [PATCH] Multi-thread queue synchronization for AsyncWebSocketClient --- src/AsyncWebSocket.cpp | 348 +++++++++++++++++++++++++++-------------- src/AsyncWebSocket.h | 74 ++++++++- 2 files changed, 302 insertions(+), 120 deletions(-) diff --git a/src/AsyncWebSocket.cpp b/src/AsyncWebSocket.cpp index 280702a..1ecd764 100644 --- a/src/AsyncWebSocket.cpp +++ b/src/AsyncWebSocket.cpp @@ -132,7 +132,7 @@ AsyncWebSocketMessageBuffer::AsyncWebSocketMessageBuffer() ,_lock(false) ,_count(0) { - + Serial.printf("AsyncWebSocketMessageBuffer::AsyncWebSocketMessageBuffer() this=0x%llx\r\n", uint64_t(this)); } AsyncWebSocketMessageBuffer::AsyncWebSocketMessageBuffer(uint8_t * data, size_t size) @@ -141,6 +141,7 @@ AsyncWebSocketMessageBuffer::AsyncWebSocketMessageBuffer(uint8_t * data, size_t ,_lock(false) ,_count(0) { + Serial.printf("AsyncWebSocketMessageBuffer::AsyncWebSocketMessageBuffer() this=0x%llx\r\n", uint64_t(this)); if (!data) { return; @@ -161,12 +162,13 @@ AsyncWebSocketMessageBuffer::AsyncWebSocketMessageBuffer(size_t size) ,_lock(false) ,_count(0) { + Serial.printf("AsyncWebSocketMessageBuffer::AsyncWebSocketMessageBuffer() this=0x%llx\r\n", uint64_t(this)); + _data = new uint8_t[_len + 1]; if (_data) { _data[_len] = 0; - } - + } } AsyncWebSocketMessageBuffer::AsyncWebSocketMessageBuffer(const AsyncWebSocketMessageBuffer & copy) @@ -175,6 +177,8 @@ AsyncWebSocketMessageBuffer::AsyncWebSocketMessageBuffer(const AsyncWebSocketMes ,_lock(false) ,_count(0) { + Serial.printf("AsyncWebSocketMessageBuffer::AsyncWebSocketMessageBuffer() this=0x%llx\r\n", uint64_t(this)); + _len = copy._len; _lock = copy._lock; _count = 0; @@ -197,6 +201,8 @@ AsyncWebSocketMessageBuffer::AsyncWebSocketMessageBuffer(AsyncWebSocketMessageBu ,_lock(false) ,_count(0) { + Serial.printf("AsyncWebSocketMessageBuffer::AsyncWebSocketMessageBuffer() this=0x%llx\r\n", uint64_t(this)); + _len = copy._len; _lock = copy._lock; _count = 0; @@ -210,6 +216,8 @@ AsyncWebSocketMessageBuffer::AsyncWebSocketMessageBuffer(AsyncWebSocketMessageBu AsyncWebSocketMessageBuffer::~AsyncWebSocketMessageBuffer() { + Serial.printf("AsyncWebSocketMessageBuffer::~AsyncWebSocketMessageBuffer() this=0x%llx data=0x%llx\r\n", uint64_t(this), uint64_t(_data)); + if (_data) { delete[] _data; } @@ -385,14 +393,13 @@ AsyncWebSocketBasicMessage::~AsyncWebSocketBasicMessage() { */ -AsyncWebSocketMultiMessage::AsyncWebSocketMultiMessage(AsyncWebSocketMessageBuffer * buffer, uint8_t opcode, bool mask) +AsyncWebSocketMultiMessage::AsyncWebSocketMultiMessage(AsyncWebSocketMessageBuffer *buffer, uint8_t opcode, bool mask) :_len(0) ,_sent(0) ,_ack(0) ,_acked(0) ,_WSbuffer(nullptr) { - _opcode = opcode & 0x07; _mask = mask; @@ -475,8 +482,7 @@ AsyncWebSocketMultiMessage::~AsyncWebSocketMultiMessage() { const size_t AWSC_PING_PAYLOAD_LEN = 22; AsyncWebSocketClient::AsyncWebSocketClient(AsyncWebServerRequest *request, AsyncWebSocket *server) - : _messageQueue(LinkedList([](AsyncWebSocketMessage *m){ delete m; })) - , _tempObject(NULL) + : _tempObject(NULL) { _client = request->client(); _server = server; @@ -498,78 +504,170 @@ AsyncWebSocketClient::AsyncWebSocketClient(AsyncWebServerRequest *request, Async } AsyncWebSocketClient::~AsyncWebSocketClient(){ - _messageQueue.free(); + Serial.printf("AsyncWebSocketClient::~AsyncWebSocketClient task=0x%llx %s\r\n", uint64_t(xTaskGetCurrentTaskHandle()), pcTaskGetTaskName(xTaskGetCurrentTaskHandle())); + + AsyncWebLockGuard l(_lock); + _messageQueue = {}; _controlQueue = {}; _server->_handleEvent(this, WS_EVT_DISCONNECT, NULL, NULL, 0); } void AsyncWebSocketClient::_onAck(size_t len, uint32_t time){ - _lastMessageTime = millis(); - if (!_controlQueue.empty()){ - auto &head = _controlQueue.front(); - if (head.finished()){ - len -= head.len(); - if (_status == WS_DISCONNECTING && head.opcode() == WS_DISCONNECT){ - _controlQueue.pop(); - _status = WS_DISCONNECTED; - _client->close(true); - return; - } - _controlQueue.pop(); + Serial.printf("AsyncWebSocketClient::_onAck task=0x%llx %s\r\n", uint64_t(xTaskGetCurrentTaskHandle()), pcTaskGetTaskName(xTaskGetCurrentTaskHandle())); + + _lastMessageTime = millis(); + + { + AsyncWebLockGuard l(_lock); + + if (!_controlQueue.empty()) { + auto &head = _controlQueue.front(); + if (head.finished()){ + len -= head.len(); + if (_status == WS_DISCONNECTING && head.opcode() == WS_DISCONNECT){ + _controlQueue.pop(); + _status = WS_DISCONNECTED; + _client->close(true); + return; + } + _controlQueue.pop(); + } + } + + if(len && !_messageQueue.empty()){ + _messageQueue.front().get().ack(len, time); + } } - } - if(len && !_messageQueue.isEmpty()){ - _messageQueue.front()->ack(len, time); - } - _server->_cleanBuffers(); - _runQueue(); -} -void AsyncWebSocketClient::_onPoll(){ - if(_client->canSend() && (!_controlQueue.empty() || !_messageQueue.isEmpty())){ - _runQueue(); - } else if(_keepAlivePeriod > 0 && _controlQueue.empty() && _messageQueue.isEmpty() && (millis() - _lastMessageTime) >= _keepAlivePeriod){ - ping((uint8_t *)AWSC_PING_PAYLOAD, AWSC_PING_PAYLOAD_LEN); - } -} - -void AsyncWebSocketClient::_runQueue(){ - while(!_messageQueue.isEmpty() && _messageQueue.front()->finished()){ - _messageQueue.remove(_messageQueue.front()); - } - - if(!_controlQueue.empty() && (_messageQueue.isEmpty() || _messageQueue.front()->betweenFrames()) && webSocketSendFrameWindow(_client) > (size_t)(_controlQueue.front().len() - 1)){ - _controlQueue.front().send(_client); - } else if(!_messageQueue.isEmpty() && _messageQueue.front()->betweenFrames() && webSocketSendFrameWindow(_client)){ - _messageQueue.front()->send(_client); - } -} - -bool AsyncWebSocketClient::queueIsFull() const { - return (_messageQueue.length() >= WS_MAX_QUEUED_MESSAGES) || (_status != WS_CONNECTED); -} - -void AsyncWebSocketClient::_queueMessage(AsyncWebSocketMessage *dataMessage){ - if(dataMessage == NULL) - return; - if(_status != WS_CONNECTED){ - delete dataMessage; - return; - } - if(_messageQueue.length() >= WS_MAX_QUEUED_MESSAGES){ - ets_printf("ERROR: Too many messages queued\n"); - delete dataMessage; - } else { - _messageQueue.add(dataMessage); - } - if(_client->canSend()) + _server->_cleanBuffers(); _runQueue(); } -void AsyncWebSocketClient::_queueControl(uint8_t opcode, uint8_t *data, size_t len, bool mask){ - _controlQueue.emplace(opcode, data, len, mask); - if (_client->canSend()) - _runQueue(); +void AsyncWebSocketClient::_onPoll() +{ + Serial.printf("AsyncWebSocketClient::_onPoll task=0x%llx %s\r\n", uint64_t(xTaskGetCurrentTaskHandle()), pcTaskGetTaskName(xTaskGetCurrentTaskHandle())); + + AsyncWebLockGuard l(_lock); + if(_client->canSend() && (!_controlQueue.empty() || !_messageQueue.empty())) + { + _runQueue(); + } + else if(_keepAlivePeriod > 0 && _controlQueue.empty() && _messageQueue.empty() && (millis() - _lastMessageTime) >= _keepAlivePeriod) + { + ping((uint8_t *)AWSC_PING_PAYLOAD, AWSC_PING_PAYLOAD_LEN); + } +} + +void AsyncWebSocketClient::_runQueue() +{ + Serial.printf("AsyncWebSocketClient::_runQueue task=0x%llx %s\r\n", uint64_t(xTaskGetCurrentTaskHandle()), pcTaskGetTaskName(xTaskGetCurrentTaskHandle())); + + AsyncWebLockGuard l(_lock); + + while (!_messageQueue.empty() && _messageQueue.front().get().finished()) + { + _messageQueue.pop(); + } + + if (!_controlQueue.empty() && (_messageQueue.empty() || _messageQueue.front().get().betweenFrames()) && webSocketSendFrameWindow(_client) > (size_t)(_controlQueue.front().len() - 1)) + { + _controlQueue.front().send(_client); + } + else if (!_messageQueue.empty() && _messageQueue.front().get().betweenFrames() && webSocketSendFrameWindow(_client)) + { + _messageQueue.front().get().send(_client); + } +} + +bool AsyncWebSocketClient::queueIsFull() const +{ + Serial.printf("AsyncWebSocketClient::queueIsFull task=0x%llx %s\r\n", uint64_t(xTaskGetCurrentTaskHandle()), pcTaskGetTaskName(xTaskGetCurrentTaskHandle())); + + size_t size; + { + AsyncWebLockGuard l(_lock); + size = _messageQueue.size(); + } + return (size >= WS_MAX_QUEUED_MESSAGES) || (_status != WS_CONNECTED); +} + +bool AsyncWebSocketClient::canSend() const +{ + Serial.printf("AsyncWebSocketClient::canSend task=0x%llx %s\r\n", uint64_t(xTaskGetCurrentTaskHandle()), pcTaskGetTaskName(xTaskGetCurrentTaskHandle())); + + size_t size; + { + AsyncWebLockGuard l(_lock); + size = _messageQueue.size(); + } + return size < WS_MAX_QUEUED_MESSAGES; +} + +void AsyncWebSocketClient::_queueControl(uint8_t opcode, uint8_t *data, size_t len, bool mask) +{ + Serial.printf("AsyncWebSocketClient::_queueControl task=0x%llx %s\r\n", uint64_t(xTaskGetCurrentTaskHandle()), pcTaskGetTaskName(xTaskGetCurrentTaskHandle())); + + { + AsyncWebLockGuard l(_lock); + _controlQueue.emplace(opcode, data, len, mask); + } + + if (_client->canSend()) + _runQueue(); +} + +void AsyncWebSocketClient::_queueMessage(const char *data, size_t len, uint8_t opcode, bool mask) +{ + Serial.printf("AsyncWebSocketClient::_queueMessage task=0x%llx %s\r\n", uint64_t(xTaskGetCurrentTaskHandle()), pcTaskGetTaskName(xTaskGetCurrentTaskHandle())); + + if(_status != WS_CONNECTED) + { + PolymorphMessageContainer{data, len, opcode, mask}; + return; + } + + { + AsyncWebLockGuard l(_lock); + if (_messageQueue.size() >= WS_MAX_QUEUED_MESSAGES) + { + ets_printf("ERROR: Too many messages queued\n"); + PolymorphMessageContainer{data, len, opcode, mask}; + } + else + { + _messageQueue.emplace(data, len, opcode, mask); + } + } + + if(_client->canSend()) + _runQueue(); +} + +void AsyncWebSocketClient::_queueMessage(AsyncWebSocketMessageBuffer *buffer, uint8_t opcode, bool mask) +{ + Serial.printf("AsyncWebSocketClient::_queueMessage task=0x%llx %s\r\n", uint64_t(xTaskGetCurrentTaskHandle()), pcTaskGetTaskName(xTaskGetCurrentTaskHandle())); + + if(_status != WS_CONNECTED) + { + PolymorphMessageContainer{buffer, opcode, mask}; + return; + } + + { + AsyncWebLockGuard l(_lock); + if (_messageQueue.size() >= WS_MAX_QUEUED_MESSAGES) + { + ets_printf("ERROR: Too many messages queued\n"); + PolymorphMessageContainer{buffer, opcode, mask}; + } + else + { + _messageQueue.emplace(buffer, opcode, mask); + } + } + + if(_client->canSend()) + _runQueue(); } void AsyncWebSocketClient::close(uint16_t code, const char * message){ @@ -768,7 +866,7 @@ size_t AsyncWebSocketClient::printf_P(PGM_P formatP, ...) { #endif void AsyncWebSocketClient::text(const char * message, size_t len){ - _queueMessage(new AsyncWebSocketBasicMessage(message, len)); + _queueMessage(message, len); } void AsyncWebSocketClient::text(const char * message){ text(message, strlen(message)); @@ -800,11 +898,11 @@ void AsyncWebSocketClient::text(const __FlashStringHelper *data){ } void AsyncWebSocketClient::text(AsyncWebSocketMessageBuffer * buffer) { - _queueMessage(new AsyncWebSocketMultiMessage(buffer)); + _queueMessage(buffer); } void AsyncWebSocketClient::binary(const char * message, size_t len){ - _queueMessage(new AsyncWebSocketBasicMessage(message, len, WS_BINARY)); + _queueMessage(message, len, WS_BINARY); } void AsyncWebSocketClient::binary(const char * message){ binary(message, strlen(message)); @@ -831,7 +929,7 @@ void AsyncWebSocketClient::binary(const __FlashStringHelper *data, size_t len){ } void AsyncWebSocketClient::binary(AsyncWebSocketMessageBuffer * buffer) { - _queueMessage(new AsyncWebSocketMultiMessage(buffer, WS_BINARY)); + _queueMessage(buffer, WS_BINARY); } IPAddress AsyncWebSocketClient::remoteIP() const { @@ -983,30 +1081,43 @@ void AsyncWebSocket::binaryAll(const char * message, size_t len){ binaryAll(buffer); } -void AsyncWebSocket::binaryAll(AsyncWebSocketMessageBuffer * buffer) +void AsyncWebSocket::binaryAll(AsyncWebSocketMessageBuffer *buffer) { - if (!buffer) return; - buffer->lock(); - for (auto &c : _clients){ - if (c.status() == WS_CONNECTED) - c.binary(buffer); - } - buffer->unlock(); - _cleanBuffers(); + if (!buffer) + return; + + buffer->lock(); + + for (auto &c : _clients) { + if (c.status() == WS_CONNECTED) + c.binary(buffer); + } + + buffer->unlock(); + + _cleanBuffers(); } -void AsyncWebSocket::message(uint32_t id, AsyncWebSocketMessage *message){ - AsyncWebSocketClient *c = client(id); - if (c) - c->message(message); +void AsyncWebSocket::message(uint32_t id, const char *data, size_t len, uint8_t opcode, bool mask) +{ + AsyncWebSocketClient *c = client(id); + if (c) + c->message(data, len, opcode, mask); } -void AsyncWebSocket::messageAll(AsyncWebSocketMultiMessage *message){ - for (auto &c : _clients){ - if(c.status() == WS_CONNECTED) - c.message(message); - } - _cleanBuffers(); +void AsyncWebSocket::message(uint32_t id, AsyncWebSocketMessageBuffer *buffer, uint8_t opcode, bool mask) +{ + AsyncWebSocketClient *c = client(id); + if (c) + c->message(buffer, opcode, mask); +} + +void AsyncWebSocket::messageAll(AsyncWebSocketMessageBuffer *buffer, uint8_t opcode, bool mask){ + for (auto &c : _clients){ + if (c.status() == WS_CONNECTED) + c.message(buffer, opcode, mask); + } + _cleanBuffers(); } size_t AsyncWebSocket::printf(uint32_t id, const char *format, ...){ @@ -1206,38 +1317,47 @@ void AsyncWebSocket::handleRequest(AsyncWebServerRequest *request){ AsyncWebSocketMessageBuffer * AsyncWebSocket::makeBuffer(size_t size) { - AsyncWebSocketMessageBuffer *buffer{}; - { - AsyncWebLockGuard l(_lock); - _buffers.emplace_back(size); - buffer = &_buffers.back(); - } - return buffer; + Serial.printf("AsyncWebSocket::makeBuffer task=0x%llx %s\r\n", uint64_t(xTaskGetCurrentTaskHandle()), pcTaskGetTaskName(xTaskGetCurrentTaskHandle())); + + AsyncWebSocketMessageBuffer *buffer{}; + + { + AsyncWebLockGuard l(_lock); + _buffers.emplace_back(size); + buffer = &_buffers.back(); + } + + return buffer; } AsyncWebSocketMessageBuffer * AsyncWebSocket::makeBuffer(uint8_t * data, size_t size) { - AsyncWebSocketMessageBuffer *buffer{}; - - { - AsyncWebLockGuard l(_lock); - _buffers.emplace_back(data, size); - buffer = &_buffers.back(); - } + Serial.printf("AsyncWebSocket::makeBuffer task=0x%llx %s\r\n", uint64_t(xTaskGetCurrentTaskHandle()), pcTaskGetTaskName(xTaskGetCurrentTaskHandle())); - return buffer; + AsyncWebSocketMessageBuffer *buffer{}; + + { + AsyncWebLockGuard l(_lock); + _buffers.emplace_back(data, size); + buffer = &_buffers.back(); + } + + return buffer; } void AsyncWebSocket::_cleanBuffers() { - AsyncWebLockGuard l(_lock); + Serial.printf("AsyncWebSocket::_cleanBuffers task=0x%llx %s\r\n", uint64_t(xTaskGetCurrentTaskHandle()), pcTaskGetTaskName(xTaskGetCurrentTaskHandle())); - for (auto iter = std::begin(_buffers); iter != std::end(_buffers);){ - if(iter->canDelete()){ - iter = _buffers.erase(iter); - } else - iter++; - } + AsyncWebLockGuard l(_lock); + + for (auto iter = std::begin(_buffers); iter != std::end(_buffers);) + { + if(iter->canDelete()) + iter = _buffers.erase(iter); + else + iter++; + } } /* diff --git a/src/AsyncWebSocket.h b/src/AsyncWebSocket.h index d67f8bb..294916a 100644 --- a/src/AsyncWebSocket.h +++ b/src/AsyncWebSocket.h @@ -157,6 +157,63 @@ public: virtual size_t send(AsyncClient *client) override ; }; +class PolymorphMessageContainer +{ + union { + AsyncWebSocketBasicMessage basicMessage; + AsyncWebSocketMultiMessage multiMessage; + }; + + enum class Type : uint8_t { Basic, Multi }; + const Type type; + +public: + PolymorphMessageContainer() = delete; + PolymorphMessageContainer(const PolymorphMessageContainer &) = delete; + PolymorphMessageContainer &operator=(const PolymorphMessageContainer &) = delete; + + PolymorphMessageContainer(const char *data, size_t len, uint8_t opcode=WS_TEXT, bool mask=false) : + type{Type::Basic} + { + new (&basicMessage) AsyncWebSocketBasicMessage{data, len, opcode, mask}; + } + + PolymorphMessageContainer(AsyncWebSocketMessageBuffer *buffer, uint8_t opcode=WS_TEXT, bool mask=false) : + type{Type::Multi} + { + new (&multiMessage) AsyncWebSocketMultiMessage{buffer, opcode, mask}; + } + + ~PolymorphMessageContainer() + { + switch (type) + { + case Type::Basic: basicMessage.~AsyncWebSocketBasicMessage(); break; + case Type::Multi: multiMessage.~AsyncWebSocketMultiMessage(); break; + } + } + + AsyncWebSocketMessage &get() + { + switch (type) + { + case Type::Basic: return basicMessage; + case Type::Multi: return multiMessage; + } + __builtin_unreachable(); + } + + const AsyncWebSocketMessage &get() const + { + switch (type) + { + case Type::Basic: return basicMessage; + case Type::Multi: return multiMessage; + } + __builtin_unreachable(); + } +}; + class AsyncWebSocketClient { private: AsyncClient *_client; @@ -164,8 +221,10 @@ class AsyncWebSocketClient { uint32_t _clientId; AwsClientStatus _status; + AsyncWebLock _lock; + std::queue _controlQueue; - LinkedList _messageQueue; + std::queue _messageQueue; uint8_t _pstate; AwsFrameInfo _pinfo; @@ -173,8 +232,9 @@ class AsyncWebSocketClient { uint32_t _lastMessageTime; uint32_t _keepAlivePeriod; - void _queueMessage(AsyncWebSocketMessage *dataMessage); void _queueControl(uint8_t opcode, uint8_t *data=NULL, size_t len=0, bool mask=false); + void _queueMessage(const char *data, size_t len, uint8_t opcode=WS_TEXT, bool mask=false); + void _queueMessage(AsyncWebSocketMessageBuffer *buffer, uint8_t opcode=WS_TEXT, bool mask=false); void _runQueue(); public: @@ -208,7 +268,8 @@ class AsyncWebSocketClient { } //data packets - void message(AsyncWebSocketMessage *message){ _queueMessage(message); } + void message(const char *data, size_t len, uint8_t opcode=WS_TEXT, bool mask=false) { _queueMessage(data, len, opcode, mask); } + void message(AsyncWebSocketMessageBuffer *buffer, uint8_t opcode=WS_TEXT, bool mask=false) { _queueMessage(buffer, opcode, mask); } bool queueIsFull() const; size_t printf(const char *format, ...) __attribute__ ((format (printf, 2, 3))); @@ -231,7 +292,7 @@ class AsyncWebSocketClient { void binary(const __FlashStringHelper *data, size_t len); void binary(AsyncWebSocketMessageBuffer *buffer); - bool canSend() { return _messageQueue.length() < WS_MAX_QUEUED_MESSAGES; } + bool canSend() const; //system callbacks (do not call) void _onAck(size_t len, uint32_t time); @@ -304,8 +365,9 @@ class AsyncWebSocket: public AsyncWebHandler { void binaryAll(const __FlashStringHelper *message, size_t len); void binaryAll(AsyncWebSocketMessageBuffer * buffer); - void message(uint32_t id, AsyncWebSocketMessage *message); - void messageAll(AsyncWebSocketMultiMessage *message); + void message(uint32_t id, const char *data, size_t len, uint8_t opcode=WS_TEXT, bool mask=false); + void message(uint32_t id, AsyncWebSocketMessageBuffer *buffer, uint8_t opcode=WS_TEXT, bool mask=false); + void messageAll(AsyncWebSocketMessageBuffer *buffer, uint8_t opcode=WS_TEXT, bool mask=false); size_t printf(uint32_t id, const char *format, ...) __attribute__ ((format (printf, 3, 4))); size_t printfAll(const char *format, ...) __attribute__ ((format (printf, 2, 3)));