From b6b43d35fa4b74432f36e3e0500b81eede34d4c1 Mon Sep 17 00:00:00 2001 From: Chris Byrne Date: Tue, 24 Jul 2018 11:22:54 -0700 Subject: [PATCH] Handle multiple WebSocket frames within a TCP packet (#338) --- src/AsyncWebSocket.cpp | 142 ++++++++++++++++++++++------------------- src/AsyncWebSocket.h | 2 +- 2 files changed, 77 insertions(+), 67 deletions(-) diff --git a/src/AsyncWebSocket.cpp b/src/AsyncWebSocket.cpp index e4f8efb..8641a48 100644 --- a/src/AsyncWebSocket.cpp +++ b/src/AsyncWebSocket.cpp @@ -565,83 +565,93 @@ void AsyncWebSocketClient::_onDisconnect(){ _server->_handleDisconnect(this); } -void AsyncWebSocketClient::_onData(void *buf, size_t plen){ +void AsyncWebSocketClient::_onData(void *pbuf, size_t plen){ _lastMessageTime = millis(); - uint8_t *fdata = (uint8_t*)buf; - uint8_t * data = fdata; - if(!_pstate){ - _pinfo.index = 0; - _pinfo.final = (fdata[0] & 0x80) != 0; - _pinfo.opcode = fdata[0] & 0x0F; - _pinfo.masked = (fdata[1] & 0x80) != 0; - _pinfo.len = fdata[1] & 0x7F; - data += 2; - plen = plen - 2; - if(_pinfo.len == 126){ - _pinfo.len = fdata[3] | (uint16_t)(fdata[2]) << 8; + uint8_t *data = (uint8_t*)pbuf; + while(plen > 0){ + if(!_pstate){ + const uint8_t *fdata = data; + _pinfo.index = 0; + _pinfo.final = (fdata[0] & 0x80) != 0; + _pinfo.opcode = fdata[0] & 0x0F; + _pinfo.masked = (fdata[1] & 0x80) != 0; + _pinfo.len = fdata[1] & 0x7F; data += 2; - plen = plen - 2; - } else if(_pinfo.len == 127){ - _pinfo.len = fdata[9] | (uint16_t)(fdata[8]) << 8 | (uint32_t)(fdata[7]) << 16 | (uint32_t)(fdata[6]) << 24 | (uint64_t)(fdata[5]) << 32 | (uint64_t)(fdata[4]) << 40 | (uint64_t)(fdata[3]) << 48 | (uint64_t)(fdata[2]) << 56; - data += 8; - plen = plen - 8; + plen -= 2; + if(_pinfo.len == 126){ + _pinfo.len = fdata[3] | (uint16_t)(fdata[2]) << 8; + data += 2; + plen -= 2; + } else if(_pinfo.len == 127){ + _pinfo.len = fdata[9] | (uint16_t)(fdata[8]) << 8 | (uint32_t)(fdata[7]) << 16 | (uint32_t)(fdata[6]) << 24 | (uint64_t)(fdata[5]) << 32 | (uint64_t)(fdata[4]) << 40 | (uint64_t)(fdata[3]) << 48 | (uint64_t)(fdata[2]) << 56; + data += 8; + plen -= 8; + } + + if(_pinfo.masked){ + memcpy(_pinfo.mask, data, 4); + data += 4; + plen -= 4; + } } + const size_t datalen = std::min((size_t)(_pinfo.len - _pinfo.index), plen); + const auto datalast = data[datalen]; + if(_pinfo.masked){ - memcpy(_pinfo.mask, data, 4); - data += 4; - plen = plen - 4; - size_t i; - for(i=0;i_handleEvent(this, WS_EVT_DATA, (void *)&_pinfo, (uint8_t*)data, plen); + if((datalen + _pinfo.index) < _pinfo.len){ + _pstate = 1; - _pinfo.index += plen; - } else if((plen + _pinfo.index) == _pinfo.len){ - _pstate = 0; - if(_pinfo.opcode == WS_DISCONNECT){ - if(plen){ - uint16_t reasonCode = (uint16_t)(data[0] << 8) + data[1]; - char * reasonString = (char*)(data+2); - if(reasonCode > 1001){ - _server->_handleEvent(this, WS_EVT_ERROR, (void *)&reasonCode, (uint8_t*)reasonString, strlen(reasonString)); + if(_pinfo.index == 0){ + if(_pinfo.opcode){ + _pinfo.message_opcode = _pinfo.opcode; + _pinfo.num = 0; + } else _pinfo.num += 1; + } + _server->_handleEvent(this, WS_EVT_DATA, (void *)&_pinfo, (uint8_t*)data, datalen); + + _pinfo.index += datalen; + } else if((datalen + _pinfo.index) == _pinfo.len){ + _pstate = 0; + if(_pinfo.opcode == WS_DISCONNECT){ + if(datalen){ + uint16_t reasonCode = (uint16_t)(data[0] << 8) + data[1]; + char * reasonString = (char*)(data+2); + if(reasonCode > 1001){ + _server->_handleEvent(this, WS_EVT_ERROR, (void *)&reasonCode, (uint8_t*)reasonString, strlen(reasonString)); + } } + if(_status == WS_DISCONNECTING){ + _status = WS_DISCONNECTED; + _client->close(true); + } else { + _status = WS_DISCONNECTING; + _queueControl(new AsyncWebSocketControl(WS_DISCONNECT, data, datalen)); + } + } else if(_pinfo.opcode == WS_PING){ + _queueControl(new AsyncWebSocketControl(WS_PONG, data, datalen)); + } else if(_pinfo.opcode == WS_PONG){ + if(datalen != AWSC_PING_PAYLOAD_LEN || memcmp(AWSC_PING_PAYLOAD, data, AWSC_PING_PAYLOAD_LEN) != 0) + _server->_handleEvent(this, WS_EVT_PONG, NULL, data, datalen); + } else if(_pinfo.opcode < 8){//continuation or text/binary frame + _server->_handleEvent(this, WS_EVT_DATA, (void *)&_pinfo, data, datalen); } - if(_status == WS_DISCONNECTING){ - _status = WS_DISCONNECTED; - _client->close(true); - } else { - _status = WS_DISCONNECTING; - _queueControl(new AsyncWebSocketControl(WS_DISCONNECT, data, plen)); - } - } else if(_pinfo.opcode == WS_PING){ - _queueControl(new AsyncWebSocketControl(WS_PONG, data, plen)); - } else if(_pinfo.opcode == WS_PONG){ - if(plen != AWSC_PING_PAYLOAD_LEN || memcmp(AWSC_PING_PAYLOAD, data, AWSC_PING_PAYLOAD_LEN) != 0) - _server->_handleEvent(this, WS_EVT_PONG, NULL, (uint8_t*)data, plen); - } else if(_pinfo.opcode < 8){//continuation or text/binary frame - _server->_handleEvent(this, WS_EVT_DATA, (void *)&_pinfo, (uint8_t*)data, plen); + } else { + //os_printf("frame error: len: %u, index: %llu, total: %llu\n", datalen, _pinfo.index, _pinfo.len); + //what should we do? + break; } - } else { - //os_printf("frame error: len: %u, index: %llu, total: %llu\n", plen, _pinfo.index, _pinfo.len); - //what should we do? + + // restore byte as _handleEvent may have added a null terminator i.e., data[len] = 0; + if (datalen > 0) + data[datalen] = datalast; + + data += datalen; + plen -= datalen; } } diff --git a/src/AsyncWebSocket.h b/src/AsyncWebSocket.h index e2291b1..37c8d11 100644 --- a/src/AsyncWebSocket.h +++ b/src/AsyncWebSocket.h @@ -214,7 +214,7 @@ class AsyncWebSocketClient { void _onPoll(); void _onTimeout(uint32_t time); void _onDisconnect(); - void _onData(void *buf, size_t plen); + void _onData(void *pbuf, size_t plen); }; typedef std::function AwsEventHandler;