diff --git a/src/AsyncEventSource.cpp b/src/AsyncEventSource.cpp index 8ea5b1b..ecb6b01 100644 --- a/src/AsyncEventSource.cpp +++ b/src/AsyncEventSource.cpp @@ -307,7 +307,7 @@ size_t AsyncEventSource::count() const { } bool AsyncEventSource::canHandle(AsyncWebServerRequest *request){ - if(request->method() != HTTP_GET || !request->url().equals(_url)) + if(request->method() != HTTP_GET || !request->url().equals(_url) || !request->isExpectedRequestedConnType(RCT_EVENT)) return false; request->addInterestingHeader("Last-Event-ID"); return true; diff --git a/src/AsyncWebSocket.cpp b/src/AsyncWebSocket.cpp index 5f0c581..69f959c 100644 --- a/src/AsyncWebSocket.cpp +++ b/src/AsyncWebSocket.cpp @@ -40,6 +40,7 @@ void SHA1Final(unsigned char digest[20], SHA1_CTX* context); #include #endif +#define MAX_PRINTF_LEN 64 size_t webSocketSendFrameWindow(AsyncClient *client){ if(!client->canSend()) @@ -645,16 +646,18 @@ void AsyncWebSocketClient::_onData(void *buf, size_t plen){ size_t AsyncWebSocketClient::printf(const char *format, ...) { va_list arg; va_start(arg, format); - char* temp = new char[64]; + char* temp = new char[MAX_PRINTF_LEN]; if(!temp){ return 0; } char* buffer = temp; - size_t len = vsnprintf(temp, 64, format, arg); + size_t len = vsnprintf(temp, MAX_PRINTF_LEN, format, arg); va_end(arg); - if (len > 63) { + + if (len > (MAX_PRINTF_LEN - 1)) { buffer = new char[len + 1]; if (!buffer) { + delete[] temp; return 0; } va_start(arg, format); @@ -672,16 +675,18 @@ size_t AsyncWebSocketClient::printf(const char *format, ...) { size_t AsyncWebSocketClient::printf_P(PGM_P formatP, ...) { va_list arg; va_start(arg, formatP); - char* temp = new char[64]; + char* temp = new char[MAX_PRINTF_LEN]; if(!temp){ return 0; } char* buffer = temp; - size_t len = vsnprintf_P(temp, 64, formatP, arg); + size_t len = vsnprintf_P(temp, MAX_PRINTF_LEN, formatP, arg); va_end(arg); - if (len > 63) { + + if (len > (MAX_PRINTF_LEN - 1)) { buffer = new char[len + 1]; if (!buffer) { + delete[] temp; return 0; } va_start(arg, formatP); @@ -928,11 +933,15 @@ size_t AsyncWebSocket::printf(uint32_t id, const char *format, ...){ size_t AsyncWebSocket::printfAll(const char *format, ...) { va_list arg; + char* temp = new char[MAX_PRINTF_LEN]; + if(!temp){ + return 0; + } va_start(arg, format); - return 0; - size_t len = vsnprintf(nullptr, 0, format, arg); + size_t len = vsnprintf(temp, MAX_PRINTF_LEN, format, arg); va_end(arg); - + delete[] temp; + AsyncWebSocketMessageBuffer * buffer = makeBuffer(len + 1); if (!buffer) { return 0; @@ -960,17 +969,24 @@ size_t AsyncWebSocket::printf_P(uint32_t id, PGM_P formatP, ...){ size_t AsyncWebSocket::printfAll_P(PGM_P formatP, ...) { va_list arg; + char* temp = new char[MAX_PRINTF_LEN]; + if(!temp){ + return 0; + } va_start(arg, formatP); - size_t len = vsnprintf_P(nullptr, 0, formatP, arg); + size_t len = vsnprintf_P(temp, MAX_PRINTF_LEN, formatP, arg); va_end(arg); - + delete[] temp; + AsyncWebSocketMessageBuffer * buffer = makeBuffer(len + 1); if (!buffer) { return 0; } + va_start(arg, formatP); - vsnprintf_P( (char *)buffer->get(), len + 1, formatP, arg); + vsnprintf_P((char *)buffer->get(), len + 1, formatP, arg); va_end(arg); + textAll(buffer); return len; } @@ -1058,8 +1074,8 @@ const char * WS_STR_UUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; bool AsyncWebSocket::canHandle(AsyncWebServerRequest *request){ if(!_enabled) return false; - - if(request->method() != HTTP_GET || !request->url().equals(_url)) + + if(request->method() != HTTP_GET || !request->url().equals(_url) || !request->isExpectedRequestedConnType(RCT_WS)) return false; request->addInterestingHeader(WS_STR_CONNECTION); diff --git a/src/ESPAsyncWebServer.h b/src/ESPAsyncWebServer.h index 72644d0..82b1584 100644 --- a/src/ESPAsyncWebServer.h +++ b/src/ESPAsyncWebServer.h @@ -113,6 +113,8 @@ class AsyncWebHeader { * REQUEST :: Each incoming Client is wrapped inside a Request and both live together until disconnect * */ +typedef enum { RCT_NOT_USED = -1, RCT_DEFAULT = 0, RCT_HTTP, RCT_WS, RCT_EVENT, RCT_MAX } RequestedConnectionType; + typedef std::function AwsResponseFiller; class AsyncWebServerRequest { @@ -136,6 +138,8 @@ class AsyncWebServerRequest { String _contentType; String _boundary; String _authorization; + RequestedConnectionType _reqconntype; + void _removeNotInterestingHeaders(); bool _isDigest; bool _isMultipart; bool _isPlainPost; @@ -194,7 +198,9 @@ class AsyncWebServerRequest { size_t contentLength() const { return _contentLength; } bool multipart() const { return _isMultipart; } const char * methodToString() const; - + const char * requestedConnTypeToString() const; + RequestedConnectionType requestedConnType() const { return _reqconntype; } + bool isExpectedRequestedConnType(RequestedConnectionType erct1, RequestedConnectionType erct2 = RCT_NOT_USED, RequestedConnectionType erct3 = RCT_NOT_USED); //hash is the string representation of: // base64(user:pass) for basic or diff --git a/src/WebHandlers.cpp b/src/WebHandlers.cpp index 9276d4f..1a528c6 100644 --- a/src/WebHandlers.cpp +++ b/src/WebHandlers.cpp @@ -81,10 +81,13 @@ AsyncStaticWebHandler& AsyncStaticWebHandler::setLastModified(){ } #endif bool AsyncStaticWebHandler::canHandle(AsyncWebServerRequest *request){ - if (request->method() == HTTP_GET && - request->url().startsWith(_uri) && - _getFile(request)) { - + if(request->method() != HTTP_GET + || !request->url().startsWith(_uri) + || !request->isExpectedRequestedConnType(RCT_DEFAULT, RCT_HTTP) + ){ + return false; + } + if (_getFile(request)) { // We interested in "If-Modified-Since" header to check if file was modified if (_last_modified.length()) request->addInterestingHeader("If-Modified-Since"); diff --git a/src/WebRequest.cpp b/src/WebRequest.cpp index 37e4ec3..c5bfece 100644 --- a/src/WebRequest.cpp +++ b/src/WebRequest.cpp @@ -46,6 +46,7 @@ AsyncWebServerRequest::AsyncWebServerRequest(AsyncWebServer* s, AsyncClient* c) , _contentType() , _boundary() , _authorization() + , _reqconntype(RCT_HTTP) , _isDigest(false) , _isMultipart(false) , _isPlainPost(false) @@ -96,13 +97,17 @@ AsyncWebServerRequest::~AsyncWebServerRequest(){ } void AsyncWebServerRequest::_onData(void *buf, size_t len){ + int i = 0; while (true) { if(_parseState < PARSE_REQ_BODY){ // Find new line in buf char *str = (char*)buf; - size_t i = 0; - for (; i < len; i++) if (str[i] == '\n') break; + for (i = 0; i < len; i++) { + if (str[i] == '\n') { + break; + } + } if (i == len) { // No new line, just add the buffer in _temp char ch = str[len-1]; str[len-1] = 0; @@ -152,7 +157,6 @@ void AsyncWebServerRequest::_onData(void *buf, size_t len){ } } } - if(_parsedLength == _contentLength){ _parseState = PARSE_REQ_END; //check if authenticated before calling handleRequest and request auth instead @@ -160,11 +164,19 @@ void AsyncWebServerRequest::_onData(void *buf, size_t len){ else send(501); } } - break; } } +void AsyncWebServerRequest::_removeNotInterestingHeaders(){ + if (_interestingHeaders.containsIgnoreCase("ANY")) return; // nothing to do + for(const auto& header: _headers){ + if(!_interestingHeaders.containsIgnoreCase(header->name().c_str())){ + _headers.remove(header); + } + } +} + void AsyncWebServerRequest::_onPoll(){ //os_printf("p\n"); if(_response != NULL && _client != NULL && _client->canSend() && !_response->_finished()){ @@ -257,6 +269,24 @@ bool AsyncWebServerRequest::_parseReqHead(){ return true; } +bool strContains(String src, String find, bool mindcase = true) { + int pos=0, i=0; + const int slen = src.length(); + const int flen = find.length(); + + if (slen < flen) return false; + while (pos <= (slen - flen)) { + for (i=0; i < flen; i++) { + if (mindcase) { + if (src[pos+i] != find[i]) i = flen + 1; // no match + } else if (tolower(src[pos+i]) != tolower(find[i])) i = flen + 1; // no match + } + if (i == flen) return true; + pos++; + } + return false; +} + bool AsyncWebServerRequest::_parseReqHeader(){ int index = _temp.indexOf(':'); if(index){ @@ -264,8 +294,6 @@ bool AsyncWebServerRequest::_parseReqHeader(){ String value = _temp.substring(index + 2); if(name.equalsIgnoreCase("Host")){ _host = value; - _server->_rewriteRequest(this); - _server->_attachHandler(this); } else if(name.equalsIgnoreCase("Content-Type")){ if (value.startsWith("multipart/")){ _boundary = value.substring(value.indexOf('=')+1); @@ -287,10 +315,17 @@ bool AsyncWebServerRequest::_parseReqHeader(){ _authorization = value.substring(7); } } else { - if(_interestingHeaders.containsIgnoreCase(name) || _interestingHeaders.containsIgnoreCase("ANY")){ - _headers.add(new AsyncWebHeader(name, value)); + if(name.equalsIgnoreCase("Upgrade") && value.equalsIgnoreCase("websocket")){ + // WebSocket request can be uniquely identified by header: [Upgrade: websocket] + _reqconntype = RCT_WS; + } else { + if(name.equalsIgnoreCase("Accept") && strContains(value, "text/event-stream", false)){ + // WebEvent request can be uniquely identified by header: [Accept: text/event-stream] + _reqconntype = RCT_EVENT; + } } } + _headers.add(new AsyncWebHeader(name, value)); } _temp = String(); return true; @@ -511,6 +546,9 @@ void AsyncWebServerRequest::_parseLine(){ if(_parseState == PARSE_REQ_HEADERS){ if(!_temp.length()){ //end of headers + _server->_rewriteRequest(this); + _server->_attachHandler(this); + _removeNotInterestingHeaders(); if(_expectingContinue){ const char * response = "HTTP/1.1 100 Continue\r\n\r\n"; _client->write(response, os_strlen(response)); @@ -924,3 +962,22 @@ const char * AsyncWebServerRequest::methodToString() const { else if(_method & HTTP_OPTIONS) return "OPTIONS"; return "UNKNOWN"; } + +const char *AsyncWebServerRequest::requestedConnTypeToString() const { + switch (_reqconntype) { + case RCT_NOT_USED: return "RCT_NOT_USED"; + case RCT_DEFAULT: return "RCT_DEFAULT"; + case RCT_HTTP: return "RCT_HTTP"; + case RCT_WS: return "RCT_WS"; + case RCT_EVENT: return "RCT_EVENT"; + default: return "ERROR"; + } +} + +bool AsyncWebServerRequest::isExpectedRequestedConnType(RequestedConnectionType erct1, RequestedConnectionType erct2, RequestedConnectionType erct3) { + bool res = false; + if ((erct1 != RCT_NOT_USED) && (erct1 == _reqconntype)) res = true; + if ((erct2 != RCT_NOT_USED) && (erct2 == _reqconntype)) res = true; + if ((erct3 != RCT_NOT_USED) && (erct3 == _reqconntype)) res = true; + return res; +}