diff --git a/src/AsyncWebSocket.cpp b/src/AsyncWebSocket.cpp index 35a4003..4b14c25 100644 --- a/src/AsyncWebSocket.cpp +++ b/src/AsyncWebSocket.cpp @@ -416,14 +416,13 @@ AsyncWebSocketBasicMessage::~AsyncWebSocketBasicMessage() { */ -AsyncWebSocketMultiMessage::AsyncWebSocketMultiMessage(AsyncWebSocketMessageBuffer * buffer, uint8_t opcode, bool mask) +AsyncWebSocketMultiMessage::AsyncWebSocketMultiMessage(AsyncWebSocketMessageBuffer *buffer, uint8_t opcode, bool mask) :_len(0) ,_sent(0) ,_ack(0) ,_acked(0) ,_WSbuffer(nullptr) { - _opcode = opcode & 0x07; _mask = mask; @@ -522,134 +521,185 @@ AsyncWebSocketMultiMessage::~AsyncWebSocketMultiMessage() { const size_t AWSC_PING_PAYLOAD_LEN = 22; AsyncWebSocketClient::AsyncWebSocketClient(AsyncWebServerRequest *request, AsyncWebSocket *server) - : _messageQueue(LinkedList([](AsyncWebSocketMessage *m){ delete m; })) - , _tempObject(NULL) + : _tempObject(NULL) { - _client = request->client(); - _server = server; - _clientId = _server->_getNextId(); - _status = WS_CONNECTED; - _pstate = 0; - _lastMessageTime = millis(); - _keepAlivePeriod = 0; - _client->setRxTimeout(0); - _client->onError([](void *r, AsyncClient* c, int8_t error){ (void)c; ((AsyncWebSocketClient*)(r))->_onError(error); }, this); - _client->onAck([](void *r, AsyncClient* c, size_t len, uint32_t time){ (void)c; ((AsyncWebSocketClient*)(r))->_onAck(len, time); }, this); - _client->onDisconnect([](void *r, AsyncClient* c){ ((AsyncWebSocketClient*)(r))->_onDisconnect(); delete c; }, this); - _client->onTimeout([](void *r, AsyncClient* c, uint32_t time){ (void)c; ((AsyncWebSocketClient*)(r))->_onTimeout(time); }, this); - _client->onData([](void *r, AsyncClient* c, void *buf, size_t len){ (void)c; ((AsyncWebSocketClient*)(r))->_onData(buf, len); }, this); - _client->onPoll([](void *r, AsyncClient* c){ (void)c; ((AsyncWebSocketClient*)(r))->_onPoll(); }, this); - _server->_handleEvent(this, WS_EVT_CONNECT, request, NULL, 0); - delete request; - memset(&_pinfo,0,sizeof(_pinfo)); + _client = request->client(); + _server = server; + _clientId = _server->_getNextId(); + _status = WS_CONNECTED; + _pstate = 0; + _lastMessageTime = millis(); + _keepAlivePeriod = 0; + _client->setRxTimeout(0); + _client->onError([](void *r, AsyncClient* c, int8_t error){ (void)c; ((AsyncWebSocketClient*)(r))->_onError(error); }, this); + _client->onAck([](void *r, AsyncClient* c, size_t len, uint32_t time){ (void)c; ((AsyncWebSocketClient*)(r))->_onAck(len, time); }, this); + _client->onDisconnect([](void *r, AsyncClient* c){ ((AsyncWebSocketClient*)(r))->_onDisconnect(); delete c; }, this); + _client->onTimeout([](void *r, AsyncClient* c, uint32_t time){ (void)c; ((AsyncWebSocketClient*)(r))->_onTimeout(time); }, this); + _client->onData([](void *r, AsyncClient* c, void *buf, size_t len){ (void)c; ((AsyncWebSocketClient*)(r))->_onData(buf, len); }, this); + _client->onPoll([](void *r, AsyncClient* c){ (void)c; ((AsyncWebSocketClient*)(r))->_onPoll(); }, this); + _server->_handleEvent(this, WS_EVT_CONNECT, request, NULL, 0); + delete request; + memset(&_pinfo,0,sizeof(_pinfo)); } -AsyncWebSocketClient::~AsyncWebSocketClient(){ - // Serial.printf("%u FREE Q\n", id()); - _messageQueue.free(); - _controlQueue.clear(); - _server->_cleanBuffers(); - _server->_handleEvent(this, WS_EVT_DISCONNECT, NULL, NULL, 0); +AsyncWebSocketClient::~AsyncWebSocketClient() +{ + { + AsyncWebLockGuard l(_lock); + + _messageQueue.clear(); + _controlQueue.clear(); + } + _server->_cleanBuffers(); + _server->_handleEvent(this, WS_EVT_DISCONNECT, NULL, NULL, 0); } void AsyncWebSocketClient::_clearQueue(){ - while(!_messageQueue.isEmpty() && _messageQueue.front()->finished()){ - _messageQueue.remove(_messageQueue.front()); - } + while (!_messageQueue.empty() && _messageQueue.front().get().finished()){ + _messageQueue.pop_front(); + } } void AsyncWebSocketClient::_onAck(size_t len, uint32_t time){ - // Serial.printf("%u onAck\n", id()); - _lastMessageTime = millis(); - if(!_controlQueue.empty()){ - auto &head = _controlQueue.front(); - if(head.finished()){ - len -= head.len(); - if(_status == WS_DISCONNECTING && head.opcode() == WS_DISCONNECT){ - _controlQueue.pop_front(); - _status = WS_DISCONNECTED; - _client->close(true); - return; - } - _controlQueue.pop_front(); + _lastMessageTime = millis(); + + { + AsyncWebLockGuard l(_lock); + + if (!_controlQueue.empty()){ + auto &head = _controlQueue.front(); + if(head.finished()){ + len -= head.len(); + if(_status == WS_DISCONNECTING && head.opcode() == WS_DISCONNECT){ + _controlQueue.pop_front(); + _status = WS_DISCONNECTED; + _client->close(true); + return; + } + _controlQueue.pop_front(); + } + } + + if(len && !_messageQueue.empty()){ + _messageQueue.front().get().ack(len, time); + } + + _clearQueue(); } - } - if(len && !_messageQueue.isEmpty()){ - _messageQueue.front()->ack(len, time); - } - - _clearQueue(); - - _server->_cleanBuffers(); - // Serial.println("RUN 1"); - _runQueue(); + _server->_cleanBuffers(); + _runQueue(); } void AsyncWebSocketClient::_onPoll(){ - if(_client->canSend() && (!_controlQueue.empty() || !_messageQueue.isEmpty())){ - _runQueue(); - } else if(_keepAlivePeriod > 0 && _controlQueue.empty() && _messageQueue.isEmpty() && (millis() - _lastMessageTime) >= _keepAlivePeriod){ - ping((uint8_t *)AWSC_PING_PAYLOAD, AWSC_PING_PAYLOAD_LEN); - } + if(_client->canSend() && [this](){ AsyncWebLockGuard l(_lock); return !_controlQueue.empty() || !_messageQueue.empty(); }()) + { + _runQueue(); + } else if(_keepAlivePeriod > 0 && (millis() - _lastMessageTime) >= _keepAlivePeriod && + [this](){ AsyncWebLockGuard l(_lock); return _controlQueue.empty() && _messageQueue.empty(); }()) + { + ping((uint8_t *)AWSC_PING_PAYLOAD, AWSC_PING_PAYLOAD_LEN); + } } -void AsyncWebSocketClient::_runQueue(){ - _clearQueue(); +void AsyncWebSocketClient::_runQueue() +{ + AsyncWebLockGuard l(_lock); - //size_t m0 = _messageQueue.isEmpty()? 0 : _messageQueue.length(); - //size_t m1 = _messageQueue.isEmpty()? 0 : _messageQueue.front()->betweenFrames(); - // Serial.printf("%u R C = %u %u\n", _clientId, m0, m1); - if(!_controlQueue.empty() && (_messageQueue.isEmpty() || _messageQueue.front()->betweenFrames()) && webSocketSendFrameWindow(_client) > (size_t)(_controlQueue.front().len() - 1)){ - // Serial.printf("%u R S C\n", _clientId); - _controlQueue.front().send(_client); - } else if(!_messageQueue.isEmpty() && _messageQueue.front()->betweenFrames() && webSocketSendFrameWindow(_client)){ - // Serial.printf("%u R S M = ", _clientId); - _messageQueue.front()->send(_client); - } + _clearQueue(); - _clearQueue(); + if(!_controlQueue.empty() && (_messageQueue.empty() || _messageQueue.front().get().betweenFrames()) && webSocketSendFrameWindow(_client) > (size_t)(_controlQueue.front().len() - 1)) + { + l.unlock(); + _controlQueue.front().send(_client); + } + else if (!_messageQueue.empty() && _messageQueue.front().get().betweenFrames() && webSocketSendFrameWindow(_client)) + { + l.unlock(); + _messageQueue.front().get().send(_client); + } } -bool AsyncWebSocketClient::queueIsFull() const { - return (_messageQueue.length() >= WS_MAX_QUEUED_MESSAGES) || (_status != WS_CONNECTED); +bool AsyncWebSocketClient::queueIsFull() const +{ + size_t size; + { + AsyncWebLockGuard l(_lock); + size = _messageQueue.size(); + } + return (size >= WS_MAX_QUEUED_MESSAGES) || (_status != WS_CONNECTED); } -size_t AsyncWebSocketClient::queueLen() const { - return _messageQueue.length() + _controlQueue.size(); +size_t AsyncWebSocketClient::queueLen() const +{ + AsyncWebLockGuard l(_lock); + + return _messageQueue.size() + _controlQueue.size(); } -void AsyncWebSocketClient::_queueMessage(AsyncWebSocketMessage *dataMessage){ - if(dataMessage == NULL){ - // Serial.printf("%u Q1\n", _clientId); - return; - } - if(_status != WS_CONNECTED){ - // Serial.printf("%u Q2\n", _clientId); - delete dataMessage; - return; - } - if(_messageQueue.length() >= WS_MAX_QUEUED_MESSAGES){ - ets_printf(String(F("ERROR: Too many messages queued\n")).c_str()); - // Serial.printf("%u Q3\n", _clientId); - delete dataMessage; - } else { - _messageQueue.add(dataMessage); - // Serial.printf("%u Q A %u\n", _clientId, _messageQueue.length()); - } - if(_client->canSend()) { - // Serial.printf("%u Q S\n", _clientId); - // Serial.println("RUN 3"); - _runQueue(); - } +bool AsyncWebSocketClient::canSend() const +{ + size_t size; + { + AsyncWebLockGuard l(_lock); + size = _messageQueue.size(); + } + return size < WS_MAX_QUEUED_MESSAGES; } -void AsyncWebSocketClient::_queueControl(uint8_t opcode, uint8_t *data, size_t len, bool mask){ - _controlQueue.emplace_back(opcode, data, len, mask); - if (_client->canSend()) { - // Serial.println("RUN 4"); - _runQueue(); - } +void AsyncWebSocketClient::_queueControl(uint8_t opcode, uint8_t *data, size_t len, bool mask) +{ + { + AsyncWebLockGuard l(_lock); + _controlQueue.emplace_back(opcode, data, len, mask); + } + + if (_client->canSend()) + _runQueue(); +} + +void AsyncWebSocketClient::_queueMessage(const char *data, size_t len, uint8_t opcode, bool mask) +{ + if(_status != WS_CONNECTED) + return; + + { + AsyncWebLockGuard l(_lock); + if (_messageQueue.size() >= WS_MAX_QUEUED_MESSAGES) + { + l.unlock(); + ets_printf("ERROR: Too many messages queued\n"); + } + else + { + _messageQueue.emplace_back(data, len, opcode, mask); + } + } + + if(_client->canSend()) + _runQueue(); +} + +void AsyncWebSocketClient::_queueMessage(AsyncWebSocketMessageBuffer *buffer, uint8_t opcode, bool mask) +{ + if(_status != WS_CONNECTED) + return; + + { + AsyncWebLockGuard l(_lock); + if (_messageQueue.size() >= WS_MAX_QUEUED_MESSAGES) + { + l.unlock(); + ets_printf("ERROR: Too many messages queued\n"); + } + else + { + _messageQueue.emplace_back(buffer, opcode, mask); + } + } + + if(_client->canSend()) + _runQueue(); } void AsyncWebSocketClient::close(uint16_t code, const char * message){ @@ -855,7 +905,7 @@ size_t AsyncWebSocketClient::printf_P(PGM_P formatP, ...) { #endif void AsyncWebSocketClient::text(const char * message, size_t len){ - _queueMessage(new AsyncWebSocketBasicMessage(message, len)); + _queueMessage(message, len); } void AsyncWebSocketClient::text(const char * message){ text(message, strlen(message)); @@ -874,11 +924,11 @@ void AsyncWebSocketClient::text(const __FlashStringHelper *data){ } void AsyncWebSocketClient::text(AsyncWebSocketMessageBuffer * buffer) { - _queueMessage(new AsyncWebSocketMultiMessage(buffer)); + _queueMessage(buffer); } void AsyncWebSocketClient::binary(const char * message, size_t len){ - _queueMessage(new AsyncWebSocketBasicMessage(message, len, WS_BINARY)); + _queueMessage(message, len, WS_BINARY); } void AsyncWebSocketClient::binary(const char * message){ binary(message, strlen(message)); @@ -904,7 +954,7 @@ void AsyncWebSocketClient::binary(const __FlashStringHelper *data, size_t len){ } void AsyncWebSocketClient::binary(AsyncWebSocketMessageBuffer * buffer) { - _queueMessage(new AsyncWebSocketMultiMessage(buffer, WS_BINARY)); + _queueMessage(buffer, WS_BINARY); } IPAddress AsyncWebSocketClient::remoteIP() const { @@ -1024,16 +1074,20 @@ void AsyncWebSocket::text(uint32_t id, const char * message, size_t len){ c->text(message, len); } -void AsyncWebSocket::textAll(AsyncWebSocketMessageBuffer * buffer){ - if (!buffer) return; - buffer->lock(); - for(auto& c: _clients){ - if(c.status() == WS_CONNECTED){ - c.text(buffer); - } - } - buffer->unlock(); - _cleanBuffers(); +void AsyncWebSocket::textAll(AsyncWebSocketMessageBuffer * buffer) +{ + if (!buffer) + return; + + buffer->lock(); + + for(auto &c : _clients) + if (c.status() == WS_CONNECTED) + c.text(buffer); + + buffer->unlock(); + + _cleanBuffers(); } @@ -1056,28 +1110,40 @@ void AsyncWebSocket::binaryAll(const char * message, size_t len){ void AsyncWebSocket::binaryAll(AsyncWebSocketMessageBuffer * buffer) { - if (!buffer) return; - buffer->lock(); - for(auto& c: _clients){ - if(c.status() == WS_CONNECTED) - c.binary(buffer); - } - buffer->unlock(); - _cleanBuffers(); + if (!buffer) + return; + + buffer->lock(); + + for (auto &c : _clients) + if (c.status() == WS_CONNECTED) + c.binary(buffer); + + buffer->unlock(); + + _cleanBuffers(); } -void AsyncWebSocket::message(uint32_t id, AsyncWebSocketMessage *message){ - AsyncWebSocketClient *c = client(id); - if (c) - c->message(message); +void AsyncWebSocket::message(uint32_t id, const char *data, size_t len, uint8_t opcode, bool mask) +{ + AsyncWebSocketClient *c = client(id); + if (c) + c->message(data, len, opcode, mask); } -void AsyncWebSocket::messageAll(AsyncWebSocketMultiMessage *message){ - for(auto& c: _clients){ - if(c.status() == WS_CONNECTED) - c.message(message); - } - _cleanBuffers(); +void AsyncWebSocket::message(uint32_t id, AsyncWebSocketMessageBuffer *buffer, uint8_t opcode, bool mask) +{ + AsyncWebSocketClient *c = client(id); + if (c) + c->message(buffer, opcode, mask); +} + +void AsyncWebSocket::messageAll(AsyncWebSocketMessageBuffer *buffer, uint8_t opcode, bool mask){ + for (auto &c : _clients){ + if (c.status() == WS_CONNECTED) + c.message(buffer, opcode, mask); + } + _cleanBuffers(); } size_t AsyncWebSocket::printf(uint32_t id, const char *format, ...){ @@ -1295,40 +1361,42 @@ void AsyncWebSocket::handleRequest(AsyncWebServerRequest *request){ request->send(response); } -AsyncWebSocketMessageBuffer * AsyncWebSocket::makeBuffer(size_t size) +AsyncWebSocketMessageBuffer *AsyncWebSocket::makeBuffer(size_t size) { - AsyncWebSocketMessageBuffer * buffer{}; - { - AsyncWebLockGuard l(_lock); - _buffers.emplace_back(size); - buffer = &_buffers.back(); - } - return buffer; + AsyncWebSocketMessageBuffer * buffer{}; + + { + AsyncWebLockGuard l(_lock); + _buffers.emplace_back(size); + buffer = &_buffers.back(); + } + + return buffer; } -AsyncWebSocketMessageBuffer * AsyncWebSocket::makeBuffer(uint8_t * data, size_t size) +AsyncWebSocketMessageBuffer *AsyncWebSocket::makeBuffer(uint8_t * data, size_t size) { - AsyncWebSocketMessageBuffer * buffer{}; + AsyncWebSocketMessageBuffer * buffer{}; - { - AsyncWebLockGuard l(_lock); - _buffers.emplace_back(data, size); - buffer = &_buffers.back(); - } + { + AsyncWebLockGuard l(_lock); + _buffers.emplace_back(data, size); + buffer = &_buffers.back(); + } - return buffer; + return buffer; } void AsyncWebSocket::_cleanBuffers() { - AsyncWebLockGuard l(_lock); + AsyncWebLockGuard l(_lock); - for (auto iter = std::begin(_buffers); iter != std::end(_buffers);){ - if(iter->canDelete()){ - iter = _buffers.erase(iter); - } else - iter++; - } + for (auto iter = std::begin(_buffers); iter != std::end(_buffers);){ + if(iter->canDelete()){ + iter = _buffers.erase(iter); + } else + iter++; + } } /* diff --git a/src/AsyncWebSocket.h b/src/AsyncWebSocket.h index d23b144..30a8cf6 100644 --- a/src/AsyncWebSocket.h +++ b/src/AsyncWebSocket.h @@ -148,7 +148,7 @@ class AsyncWebSocketMultiMessage: public AsyncWebSocketMessage { size_t _sent; size_t _ack; size_t _acked; - AsyncWebSocketMessageBuffer * _WSbuffer; + AsyncWebSocketMessageBuffer *_WSbuffer; public: AsyncWebSocketMultiMessage(AsyncWebSocketMessageBuffer * buffer, uint8_t opcode=WS_TEXT, bool mask=false); virtual ~AsyncWebSocketMultiMessage() override; @@ -157,6 +157,63 @@ public: virtual size_t send(AsyncClient *client) override ; }; +class PolymorphMessageContainer +{ + union { + AsyncWebSocketBasicMessage basicMessage; + AsyncWebSocketMultiMessage multiMessage; + }; + + enum class Type : uint8_t { Basic, Multi }; + const Type type; + +public: + PolymorphMessageContainer() = delete; + PolymorphMessageContainer(const PolymorphMessageContainer &) = delete; + PolymorphMessageContainer &operator=(const PolymorphMessageContainer &) = delete; + + PolymorphMessageContainer(const char *data, size_t len, uint8_t opcode=WS_TEXT, bool mask=false) : + type{Type::Basic} + { + new (&basicMessage) AsyncWebSocketBasicMessage{data, len, opcode, mask}; + } + + PolymorphMessageContainer(AsyncWebSocketMessageBuffer *buffer, uint8_t opcode=WS_TEXT, bool mask=false) : + type{Type::Multi} + { + new (&multiMessage) AsyncWebSocketMultiMessage{buffer, opcode, mask}; + } + + ~PolymorphMessageContainer() + { + switch (type) + { + case Type::Basic: basicMessage.~AsyncWebSocketBasicMessage(); break; + case Type::Multi: multiMessage.~AsyncWebSocketMultiMessage(); break; + } + } + + AsyncWebSocketMessage &get() + { + switch (type) + { + case Type::Basic: return basicMessage; + case Type::Multi: return multiMessage; + } + __builtin_unreachable(); + } + + const AsyncWebSocketMessage &get() const + { + switch (type) + { + case Type::Basic: return basicMessage; + case Type::Multi: return multiMessage; + } + __builtin_unreachable(); + } +}; + class AsyncWebSocketClient { private: AsyncClient *_client; @@ -164,8 +221,10 @@ class AsyncWebSocketClient { uint32_t _clientId; AwsClientStatus _status; + AsyncWebLock _lock; + std::deque _controlQueue; - LinkedList _messageQueue; + std::deque _messageQueue; uint8_t _pstate; AwsFrameInfo _pinfo; @@ -173,8 +232,9 @@ class AsyncWebSocketClient { uint32_t _lastMessageTime; uint32_t _keepAlivePeriod; - void _queueMessage(AsyncWebSocketMessage *dataMessage); void _queueControl(uint8_t opcode, uint8_t *data=NULL, size_t len=0, bool mask=false); + void _queueMessage(const char *data, size_t len, uint8_t opcode=WS_TEXT, bool mask=false); + void _queueMessage(AsyncWebSocketMessageBuffer *buffer, uint8_t opcode=WS_TEXT, bool mask=false); void _runQueue(); void _clearQueue(); @@ -209,7 +269,8 @@ class AsyncWebSocketClient { } //data packets - void message(AsyncWebSocketMessage *message){ _queueMessage(message); } + void message(const char *data, size_t len, uint8_t opcode=WS_TEXT, bool mask=false) { _queueMessage(data, len, opcode, mask); } + void message(AsyncWebSocketMessageBuffer *buffer, uint8_t opcode=WS_TEXT, bool mask=false) { _queueMessage(buffer, opcode, mask); } bool queueIsFull() const; size_t queueLen() const; @@ -233,7 +294,7 @@ class AsyncWebSocketClient { void binary(const __FlashStringHelper *data, size_t len); void binary(AsyncWebSocketMessageBuffer *buffer); - bool canSend() { return _messageQueue.length() < WS_MAX_QUEUED_MESSAGES; } + bool canSend() const; //system callbacks (do not call) void _onAck(size_t len, uint32_t time); @@ -291,7 +352,7 @@ class AsyncWebSocket: public AsyncWebHandler { void textAll(char * message); void textAll(const String &message); void textAll(const __FlashStringHelper *message); // need to convert - void textAll(AsyncWebSocketMessageBuffer * buffer); + void textAll(AsyncWebSocketMessageBuffer *buffer); void binary(uint32_t id, const char * message, size_t len); void binary(uint32_t id, const char * message); @@ -306,10 +367,11 @@ class AsyncWebSocket: public AsyncWebHandler { void binaryAll(char * message); void binaryAll(const String &message); void binaryAll(const __FlashStringHelper *message, size_t len); - void binaryAll(AsyncWebSocketMessageBuffer * buffer); + void binaryAll(AsyncWebSocketMessageBuffer *buffer); - void message(uint32_t id, AsyncWebSocketMessage *message); - void messageAll(AsyncWebSocketMultiMessage *message); + void message(uint32_t id, const char *data, size_t len, uint8_t opcode=WS_TEXT, bool mask=false); + void message(uint32_t id, AsyncWebSocketMessageBuffer *buffer, uint8_t opcode=WS_TEXT, bool mask=false); + void messageAll(AsyncWebSocketMessageBuffer *buffer, uint8_t opcode=WS_TEXT, bool mask=false); size_t printf(uint32_t id, const char *format, ...) __attribute__ ((format (printf, 3, 4))); size_t printfAll(const char *format, ...) __attribute__ ((format (printf, 2, 3))); diff --git a/src/AsyncWebSynchronization.h b/src/AsyncWebSynchronization.h index 02ad2dc..0f76815 100644 --- a/src/AsyncWebSynchronization.h +++ b/src/AsyncWebSynchronization.h @@ -121,6 +121,13 @@ public: _lock->unlock(); } } + + void unlock() { + if (_lock) { + _lock->unlock(); + _lock = NULL; + } + } }; #endif // ASYNCWEBSYNCHRONIZATION_H_