diff --git a/platformio.ini b/platformio.ini index d263276..661e4a6 100644 --- a/platformio.ini +++ b/platformio.ini @@ -61,6 +61,8 @@ board = rpipicow lib_deps = bblanchon/ArduinoJson @ 7.2.0 khoih-prog/AsyncTCP_RP2040W @ 1.2.0 +build_flags = ${env.build_flags} + -Wno-missing-field-initializers ; CI @@ -90,3 +92,5 @@ board = ${sysenv.PIO_BOARD} lib_deps = bblanchon/ArduinoJson @ 7.2.0 khoih-prog/AsyncTCP_RP2040W @ 1.2.0 +build_flags = ${env.build_flags} + -Wno-missing-field-initializers diff --git a/src/AsyncEventSource.cpp b/src/AsyncEventSource.cpp index 09839f4..2d6ab7b 100644 --- a/src/AsyncEventSource.cpp +++ b/src/AsyncEventSource.cpp @@ -22,7 +22,6 @@ #include #endif #include "AsyncEventSource.h" -#include "literals.h" using namespace asyncsrv; @@ -288,6 +287,12 @@ void AsyncEventSource::onConnect(ArEventHandlerFunction cb) { _connectcb = cb; } +void AsyncEventSource::authorizeConnect(ArAuthorizeConnectHandler cb) { + AuthorizationMiddleware* m = new AuthorizationMiddleware(401, cb); + m->_freeOnRemoval = true; + addMiddleware(m); +} + void AsyncEventSource::_addClient(AsyncEventSourceClient* client) { if (!client) return; diff --git a/src/AsyncEventSource.h b/src/AsyncEventSource.h index 302f869..1c688b0 100644 --- a/src/AsyncEventSource.h +++ b/src/AsyncEventSource.h @@ -124,11 +124,7 @@ class AsyncEventSource : public AsyncWebHandler { const char* url() const { return _url.c_str(); } void close(); void onConnect(ArEventHandlerFunction cb); - void authorizeConnect(ArAuthorizeConnectHandler cb) { - AuthorizationMiddleware* m = new AuthorizationMiddleware(401, cb); - m->_freeOnRemoval = true; - addMiddleware(m); - } + void authorizeConnect(ArAuthorizeConnectHandler cb); void send(const String& message, const String& event, uint32_t id = 0, uint32_t reconnect = 0) { send(message.c_str(), event.c_str(), id, reconnect); } void send(const String& message, const char* event, uint32_t id = 0, uint32_t reconnect = 0) { send(message.c_str(), event, id, reconnect); } void send(const char* message, const char* event = NULL, uint32_t id = 0, uint32_t reconnect = 0); diff --git a/src/AsyncJson.cpp b/src/AsyncJson.cpp new file mode 100644 index 0000000..511f2ec --- /dev/null +++ b/src/AsyncJson.cpp @@ -0,0 +1,151 @@ +#include "AsyncJson.h" + +#if ARDUINOJSON_VERSION_MAJOR == 5 +AsyncJsonResponse::AsyncJsonResponse(bool isArray) : _isValid{false} { + _code = 200; + _contentType = JSON_MIMETYPE; + if (isArray) + _root = _jsonBuffer.createArray(); + else + _root = _jsonBuffer.createObject(); +} +#elif ARDUINOJSON_VERSION_MAJOR == 6 +AsyncJsonResponse::AsyncJsonResponse(bool isArray, size_t maxJsonBufferSize) : _jsonBuffer(maxJsonBufferSize), _isValid{false} { + _code = 200; + _contentType = JSON_MIMETYPE; + if (isArray) + _root = _jsonBuffer.createNestedArray(); + else + _root = _jsonBuffer.createNestedObject(); +} +#else +AsyncJsonResponse::AsyncJsonResponse(bool isArray) : _isValid{false} { + _code = 200; + _contentType = JSON_MIMETYPE; + if (isArray) + _root = _jsonBuffer.add(); + else + _root = _jsonBuffer.add(); +} +#endif + +size_t AsyncJsonResponse::setLength() { +#if ARDUINOJSON_VERSION_MAJOR == 5 + _contentLength = _root.measureLength(); +#else + _contentLength = measureJson(_root); +#endif + if (_contentLength) { + _isValid = true; + } + return _contentLength; +} + +size_t AsyncJsonResponse::_fillBuffer(uint8_t* data, size_t len) { + ChunkPrint dest(data, _sentLength, len); +#if ARDUINOJSON_VERSION_MAJOR == 5 + _root.printTo(dest); +#else + serializeJson(_root, dest); +#endif + return len; +} + +#if ARDUINOJSON_VERSION_MAJOR == 6 +PrettyAsyncJsonResponse::PrettyAsyncJsonResponse(bool isArray, size_t maxJsonBufferSize) : AsyncJsonResponse{isArray, maxJsonBufferSize} {} +#else +PrettyAsyncJsonResponse::PrettyAsyncJsonResponse(bool isArray) : AsyncJsonResponse{isArray} {} +#endif + +size_t PrettyAsyncJsonResponse::setLength() { +#if ARDUINOJSON_VERSION_MAJOR == 5 + _contentLength = _root.measurePrettyLength(); +#else + _contentLength = measureJsonPretty(_root); +#endif + if (_contentLength) { + _isValid = true; + } + return _contentLength; +} + +size_t PrettyAsyncJsonResponse::_fillBuffer(uint8_t* data, size_t len) { + ChunkPrint dest(data, _sentLength, len); +#if ARDUINOJSON_VERSION_MAJOR == 5 + _root.prettyPrintTo(dest); +#else + serializeJsonPretty(_root, dest); +#endif + return len; +} + +#if ARDUINOJSON_VERSION_MAJOR == 6 +AsyncCallbackJsonWebHandler::AsyncCallbackJsonWebHandler(const String& uri, ArJsonRequestHandlerFunction onRequest, size_t maxJsonBufferSize) + : _uri(uri), _method(HTTP_GET | HTTP_POST | HTTP_PUT | HTTP_PATCH), _onRequest(onRequest), maxJsonBufferSize(maxJsonBufferSize), _maxContentLength(16384) {} +#else +AsyncCallbackJsonWebHandler::AsyncCallbackJsonWebHandler(const String& uri, ArJsonRequestHandlerFunction onRequest) + : _uri(uri), _method(HTTP_GET | HTTP_POST | HTTP_PUT | HTTP_PATCH), _onRequest(onRequest), _maxContentLength(16384) {} +#endif + +bool AsyncCallbackJsonWebHandler::canHandle(AsyncWebServerRequest* request) { + if (!_onRequest) + return false; + + WebRequestMethodComposite request_method = request->method(); + if (!(_method & request_method)) + return false; + + if (_uri.length() && (_uri != request->url() && !request->url().startsWith(_uri + "/"))) + return false; + + if (request_method != HTTP_GET && !request->contentType().equalsIgnoreCase(JSON_MIMETYPE)) + return false; + + return true; +} + +void AsyncCallbackJsonWebHandler::handleRequest(AsyncWebServerRequest* request) { + if (_onRequest) { + if (request->method() == HTTP_GET) { + JsonVariant json; + _onRequest(request, json); + return; + } else if (request->_tempObject != NULL) { + +#if ARDUINOJSON_VERSION_MAJOR == 5 + DynamicJsonBuffer jsonBuffer; + JsonVariant json = jsonBuffer.parse((uint8_t*)(request->_tempObject)); + if (json.success()) { +#elif ARDUINOJSON_VERSION_MAJOR == 6 + DynamicJsonDocument jsonBuffer(this->maxJsonBufferSize); + DeserializationError error = deserializeJson(jsonBuffer, (uint8_t*)(request->_tempObject)); + if (!error) { + JsonVariant json = jsonBuffer.as(); +#else + JsonDocument jsonBuffer; + DeserializationError error = deserializeJson(jsonBuffer, (uint8_t*)(request->_tempObject)); + if (!error) { + JsonVariant json = jsonBuffer.as(); +#endif + + _onRequest(request, json); + return; + } + } + request->send(_contentLength > _maxContentLength ? 413 : 400); + } else { + request->send(500); + } +} + +void AsyncCallbackJsonWebHandler::handleBody(AsyncWebServerRequest* request, uint8_t* data, size_t len, size_t index, size_t total) { + if (_onRequest) { + _contentLength = total; + if (total > 0 && request->_tempObject == NULL && total < _maxContentLength) { + request->_tempObject = malloc(total); + } + if (request->_tempObject != NULL) { + memcpy((uint8_t*)(request->_tempObject) + index, data, len); + } + } +} \ No newline at end of file diff --git a/src/AsyncJson.h b/src/AsyncJson.h index 5426f41..0e7e61d 100644 --- a/src/AsyncJson.h +++ b/src/AsyncJson.h @@ -65,102 +65,35 @@ class AsyncJsonResponse : public AsyncAbstractResponse { bool _isValid; public: -#if ARDUINOJSON_VERSION_MAJOR == 5 - AsyncJsonResponse(bool isArray = false) : _isValid{false} { - _code = 200; - _contentType = JSON_MIMETYPE; - if (isArray) - _root = _jsonBuffer.createArray(); - else - _root = _jsonBuffer.createObject(); - } -#elif ARDUINOJSON_VERSION_MAJOR == 6 - AsyncJsonResponse(bool isArray = false, size_t maxJsonBufferSize = DYNAMIC_JSON_DOCUMENT_SIZE) : _jsonBuffer(maxJsonBufferSize), _isValid{false} { - _code = 200; - _contentType = JSON_MIMETYPE; - if (isArray) - _root = _jsonBuffer.createNestedArray(); - else - _root = _jsonBuffer.createNestedObject(); - } +#if ARDUINOJSON_VERSION_MAJOR == 6 + AsyncJsonResponse(bool isArray = false, size_t maxJsonBufferSize = DYNAMIC_JSON_DOCUMENT_SIZE); #else - AsyncJsonResponse(bool isArray = false) : _isValid{false} { - _code = 200; - _contentType = JSON_MIMETYPE; - if (isArray) - _root = _jsonBuffer.add(); - else - _root = _jsonBuffer.add(); - } + AsyncJsonResponse(bool isArray = false); #endif - JsonVariant& getRoot() { return _root; } bool _sourceValid() const { return _isValid; } - size_t setLength() { - -#if ARDUINOJSON_VERSION_MAJOR == 5 - _contentLength = _root.measureLength(); -#else - _contentLength = measureJson(_root); -#endif - - if (_contentLength) { - _isValid = true; - } - return _contentLength; - } - + size_t setLength(); size_t getSize() const { return _jsonBuffer.size(); } - + size_t _fillBuffer(uint8_t* data, size_t len); #if ARDUINOJSON_VERSION_MAJOR >= 6 bool overflowed() const { return _jsonBuffer.overflowed(); } #endif - - size_t _fillBuffer(uint8_t* data, size_t len) { - ChunkPrint dest(data, _sentLength, len); - -#if ARDUINOJSON_VERSION_MAJOR == 5 - _root.printTo(dest); -#else - serializeJson(_root, dest); -#endif - return len; - } }; class PrettyAsyncJsonResponse : public AsyncJsonResponse { public: #if ARDUINOJSON_VERSION_MAJOR == 6 - PrettyAsyncJsonResponse(bool isArray = false, size_t maxJsonBufferSize = DYNAMIC_JSON_DOCUMENT_SIZE) : AsyncJsonResponse{isArray, maxJsonBufferSize} {} + PrettyAsyncJsonResponse(bool isArray = false, size_t maxJsonBufferSize = DYNAMIC_JSON_DOCUMENT_SIZE); #else - PrettyAsyncJsonResponse(bool isArray = false) : AsyncJsonResponse{isArray} {} + PrettyAsyncJsonResponse(bool isArray = false); #endif - size_t setLength() { -#if ARDUINOJSON_VERSION_MAJOR == 5 - _contentLength = _root.measurePrettyLength(); -#else - _contentLength = measureJsonPretty(_root); -#endif - if (_contentLength) { - _isValid = true; - } - return _contentLength; - } - size_t _fillBuffer(uint8_t* data, size_t len) { - ChunkPrint dest(data, _sentLength, len); -#if ARDUINOJSON_VERSION_MAJOR == 5 - _root.prettyPrintTo(dest); -#else - serializeJsonPretty(_root, dest); -#endif - return len; - } + size_t setLength(); + size_t _fillBuffer(uint8_t* data, size_t len); }; typedef std::function ArJsonRequestHandlerFunction; class AsyncCallbackJsonWebHandler : public AsyncWebHandler { - private: protected: const String _uri; WebRequestMethodComposite _method; @@ -173,80 +106,19 @@ class AsyncCallbackJsonWebHandler : public AsyncWebHandler { public: #if ARDUINOJSON_VERSION_MAJOR == 6 - AsyncCallbackJsonWebHandler(const String& uri, ArJsonRequestHandlerFunction onRequest = nullptr, size_t maxJsonBufferSize = DYNAMIC_JSON_DOCUMENT_SIZE) - : _uri(uri), _method(HTTP_GET | HTTP_POST | HTTP_PUT | HTTP_PATCH), _onRequest(onRequest), maxJsonBufferSize(maxJsonBufferSize), _maxContentLength(16384) {} + AsyncCallbackJsonWebHandler(const String& uri, ArJsonRequestHandlerFunction onRequest = nullptr, size_t maxJsonBufferSize = DYNAMIC_JSON_DOCUMENT_SIZE); #else - AsyncCallbackJsonWebHandler(const String& uri, ArJsonRequestHandlerFunction onRequest = nullptr) - : _uri(uri), _method(HTTP_GET | HTTP_POST | HTTP_PUT | HTTP_PATCH), _onRequest(onRequest), _maxContentLength(16384) {} + AsyncCallbackJsonWebHandler(const String& uri, ArJsonRequestHandlerFunction onRequest = nullptr); #endif void setMethod(WebRequestMethodComposite method) { _method = method; } void setMaxContentLength(int maxContentLength) { _maxContentLength = maxContentLength; } void onRequest(ArJsonRequestHandlerFunction fn) { _onRequest = fn; } - virtual bool canHandle(AsyncWebServerRequest* request) override final { - if (!_onRequest) - return false; - - WebRequestMethodComposite request_method = request->method(); - if (!(_method & request_method)) - return false; - - if (_uri.length() && (_uri != request->url() && !request->url().startsWith(_uri + "/"))) - return false; - - if (request_method != HTTP_GET && !request->contentType().equalsIgnoreCase(JSON_MIMETYPE)) - return false; - - return true; - } - - virtual void handleRequest(AsyncWebServerRequest* request) override final { - if (_onRequest) { - if (request->method() == HTTP_GET) { - JsonVariant json; - _onRequest(request, json); - return; - } else if (request->_tempObject != NULL) { - -#if ARDUINOJSON_VERSION_MAJOR == 5 - DynamicJsonBuffer jsonBuffer; - JsonVariant json = jsonBuffer.parse((uint8_t*)(request->_tempObject)); - if (json.success()) { -#elif ARDUINOJSON_VERSION_MAJOR == 6 - DynamicJsonDocument jsonBuffer(this->maxJsonBufferSize); - DeserializationError error = deserializeJson(jsonBuffer, (uint8_t*)(request->_tempObject)); - if (!error) { - JsonVariant json = jsonBuffer.as(); -#else - JsonDocument jsonBuffer; - DeserializationError error = deserializeJson(jsonBuffer, (uint8_t*)(request->_tempObject)); - if (!error) { - JsonVariant json = jsonBuffer.as(); -#endif - - _onRequest(request, json); - return; - } - } - request->send(_contentLength > _maxContentLength ? 413 : 400); - } else { - request->send(500); - } - } - virtual void handleUpload(__unused AsyncWebServerRequest* request, __unused const String& filename, __unused size_t index, __unused uint8_t* data, __unused size_t len, __unused bool final) override final { - } - virtual void handleBody(AsyncWebServerRequest* request, uint8_t* data, size_t len, size_t index, size_t total) override final { - if (_onRequest) { - _contentLength = total; - if (total > 0 && request->_tempObject == NULL && total < _maxContentLength) { - request->_tempObject = malloc(total); - } - if (request->_tempObject != NULL) { - memcpy((uint8_t*)(request->_tempObject) + index, data, len); - } - } - } - virtual bool isRequestHandlerTrivial() override final { return _onRequest ? false : true; } + virtual bool canHandle(AsyncWebServerRequest* request) override final; + virtual void handleRequest(AsyncWebServerRequest* request) override final; + virtual void handleUpload(__unused AsyncWebServerRequest* request, __unused const String& filename, __unused size_t index, __unused uint8_t* data, __unused size_t len, __unused bool final) override final {} + virtual void handleBody(AsyncWebServerRequest* request, uint8_t* data, size_t len, size_t index, size_t total) override final; + virtual bool isRequestHandlerTrivial() override final { return !_onRequest; } }; #endif diff --git a/src/AsyncMessagePack.cpp b/src/AsyncMessagePack.cpp new file mode 100644 index 0000000..65827d7 --- /dev/null +++ b/src/AsyncMessagePack.cpp @@ -0,0 +1,79 @@ +#include "AsyncMessagePack.h" + +AsyncMessagePackResponse::AsyncMessagePackResponse(bool isArray) : _isValid{false} { + _code = 200; + _contentType = asyncsrv::T_application_msgpack; + if (isArray) + _root = _jsonBuffer.add(); + else + _root = _jsonBuffer.add(); +} + +size_t AsyncMessagePackResponse::setLength() { + _contentLength = measureMsgPack(_root); + if (_contentLength) { + _isValid = true; + } + return _contentLength; +} + +size_t AsyncMessagePackResponse::_fillBuffer(uint8_t* data, size_t len) { + ChunkPrint dest(data, _sentLength, len); + serializeMsgPack(_root, dest); + return len; +} + +AsyncCallbackMessagePackWebHandler::AsyncCallbackMessagePackWebHandler(const String& uri, ArJsonRequestHandlerFunction onRequest) + : _uri(uri), _method(HTTP_GET | HTTP_POST | HTTP_PUT | HTTP_PATCH), _onRequest(onRequest), _maxContentLength(16384) {} + +bool AsyncCallbackMessagePackWebHandler::canHandle(AsyncWebServerRequest* request) { + if (!_onRequest) + return false; + + WebRequestMethodComposite request_method = request->method(); + if (!(_method & request_method)) + return false; + + if (_uri.length() && (_uri != request->url() && !request->url().startsWith(_uri + "/"))) + return false; + + if (request_method != HTTP_GET && !request->contentType().equalsIgnoreCase(asyncsrv::T_application_msgpack)) + return false; + + return true; +} + +void AsyncCallbackMessagePackWebHandler::handleRequest(AsyncWebServerRequest* request) { + if (_onRequest) { + if (request->method() == HTTP_GET) { + JsonVariant json; + _onRequest(request, json); + return; + + } else if (request->_tempObject != NULL) { + JsonDocument jsonBuffer; + DeserializationError error = deserializeMsgPack(jsonBuffer, (uint8_t*)(request->_tempObject)); + + if (!error) { + JsonVariant json = jsonBuffer.as(); + _onRequest(request, json); + return; + } + } + request->send(_contentLength > _maxContentLength ? 413 : 400); + } else { + request->send(500); + } +} + +void AsyncCallbackMessagePackWebHandler::handleBody(AsyncWebServerRequest* request, uint8_t* data, size_t len, size_t index, size_t total) { + if (_onRequest) { + _contentLength = total; + if (total > 0 && request->_tempObject == NULL && total < _maxContentLength) { + request->_tempObject = malloc(total); + } + if (request->_tempObject != NULL) { + memcpy((uint8_t*)(request->_tempObject) + index, data, len); + } + } +} diff --git a/src/AsyncMessagePack.h b/src/AsyncMessagePack.h index 9a6b6ba..16104ef 100644 --- a/src/AsyncMessagePack.h +++ b/src/AsyncMessagePack.h @@ -25,7 +25,6 @@ #include #include "ChunkPrint.h" -#include "literals.h" class AsyncMessagePackResponse : public AsyncAbstractResponse { protected: @@ -34,34 +33,13 @@ class AsyncMessagePackResponse : public AsyncAbstractResponse { bool _isValid; public: - AsyncMessagePackResponse(bool isArray = false) : _isValid{false} { - _code = 200; - _contentType = asyncsrv::T_application_msgpack; - if (isArray) - _root = _jsonBuffer.add(); - else - _root = _jsonBuffer.add(); - } + AsyncMessagePackResponse(bool isArray = false); JsonVariant& getRoot() { return _root; } - bool _sourceValid() const { return _isValid; } - - size_t setLength() { - _contentLength = measureMsgPack(_root); - if (_contentLength) { - _isValid = true; - } - return _contentLength; - } - + size_t setLength(); size_t getSize() const { return _jsonBuffer.size(); } - - size_t _fillBuffer(uint8_t* data, size_t len) { - ChunkPrint dest(data, _sentLength, len); - serializeMsgPack(_root, dest); - return len; - } + size_t _fillBuffer(uint8_t* data, size_t len); }; class AsyncCallbackMessagePackWebHandler : public AsyncWebHandler { @@ -76,66 +54,14 @@ class AsyncCallbackMessagePackWebHandler : public AsyncWebHandler { size_t _maxContentLength; public: - AsyncCallbackMessagePackWebHandler(const String& uri, ArJsonRequestHandlerFunction onRequest = nullptr) - : _uri(uri), _method(HTTP_GET | HTTP_POST | HTTP_PUT | HTTP_PATCH), _onRequest(onRequest), _maxContentLength(16384) {} + AsyncCallbackMessagePackWebHandler(const String& uri, ArJsonRequestHandlerFunction onRequest = nullptr); void setMethod(WebRequestMethodComposite method) { _method = method; } void setMaxContentLength(int maxContentLength) { _maxContentLength = maxContentLength; } void onRequest(ArJsonRequestHandlerFunction fn) { _onRequest = fn; } - - virtual bool canHandle(AsyncWebServerRequest* request) override final { - if (!_onRequest) - return false; - - WebRequestMethodComposite request_method = request->method(); - if (!(_method & request_method)) - return false; - - if (_uri.length() && (_uri != request->url() && !request->url().startsWith(_uri + "/"))) - return false; - - if (request_method != HTTP_GET && !request->contentType().equalsIgnoreCase(asyncsrv::T_application_msgpack)) - return false; - - return true; - } - - virtual void handleRequest(AsyncWebServerRequest* request) override final { - if (_onRequest) { - if (request->method() == HTTP_GET) { - JsonVariant json; - _onRequest(request, json); - return; - - } else if (request->_tempObject != NULL) { - JsonDocument jsonBuffer; - DeserializationError error = deserializeMsgPack(jsonBuffer, (uint8_t*)(request->_tempObject)); - - if (!error) { - JsonVariant json = jsonBuffer.as(); - _onRequest(request, json); - return; - } - } - request->send(_contentLength > _maxContentLength ? 413 : 400); - } else { - request->send(500); - } - } - + virtual bool canHandle(AsyncWebServerRequest* request) override final; + virtual void handleRequest(AsyncWebServerRequest* request) override final; virtual void handleUpload(__unused AsyncWebServerRequest* request, __unused const String& filename, __unused size_t index, __unused uint8_t* data, __unused size_t len, __unused bool final) override final {} - - virtual void handleBody(AsyncWebServerRequest* request, uint8_t* data, size_t len, size_t index, size_t total) override final { - if (_onRequest) { - _contentLength = total; - if (total > 0 && request->_tempObject == NULL && total < _maxContentLength) { - request->_tempObject = malloc(total); - } - if (request->_tempObject != NULL) { - memcpy((uint8_t*)(request->_tempObject) + index, data, len); - } - } - } - + virtual void handleBody(AsyncWebServerRequest* request, uint8_t* data, size_t len, size_t index, size_t total) override final; virtual bool isRequestHandlerTrivial() override final { return _onRequest ? false : true; } }; diff --git a/src/AsyncWebHeader.cpp b/src/AsyncWebHeader.cpp new file mode 100644 index 0000000..ba271a3 --- /dev/null +++ b/src/AsyncWebHeader.cpp @@ -0,0 +1,22 @@ +#include + +AsyncWebHeader::AsyncWebHeader(const String& data) { + if (!data) + return; + int index = data.indexOf(':'); + if (index < 0) + return; + _name = data.substring(0, index); + _value = data.substring(index + 2); +} + +String AsyncWebHeader::toString() const { + String str; + str.reserve(_name.length() + _value.length() + 2); + str.concat(_name); + str.concat((char)0x3a); + str.concat((char)0x20); + str.concat(_value); + str.concat(asyncsrv::T_rn); + return str; +} diff --git a/src/AsyncWebSocket.h b/src/AsyncWebSocket.h index 34256a7..a832a2d 100644 --- a/src/AsyncWebSocket.h +++ b/src/AsyncWebSocket.h @@ -281,7 +281,7 @@ class AsyncWebSocket : public AsyncWebHandler { public: explicit AsyncWebSocket(const char* url) : _url(url), _cNextId(1), _enabled(true) {} AsyncWebSocket(const String& url) : _url(url), _cNextId(1), _enabled(true) {} - ~AsyncWebSocket(){}; + ~AsyncWebSocket() {}; const char* url() const { return _url.c_str(); } void enable(bool e) { _enabled = e; } bool enabled() const { return _enabled; } @@ -339,15 +339,8 @@ class AsyncWebSocket : public AsyncWebHandler { size_t printfAll_P(PGM_P formatP, ...) __attribute__((format(printf, 2, 3))); #endif - // event listener - void onEvent(AwsEventHandler handler) { - _eventHandler = handler; - } - - // Handshake Handler - void handleHandshake(AwsHandshakeHandler handler) { - _handshakeHandler = handler; - } + void onEvent(AwsEventHandler handler) { _eventHandler = handler; } + void handleHandshake(AwsHandshakeHandler handler) { _handshakeHandler = handler; } // system callbacks (do not call) uint32_t _getNextId() { return _cNextId++; } diff --git a/src/ChunkPrint.cpp b/src/ChunkPrint.cpp new file mode 100644 index 0000000..8c9717a --- /dev/null +++ b/src/ChunkPrint.cpp @@ -0,0 +1,16 @@ +#include + +ChunkPrint::ChunkPrint(uint8_t* destination, size_t from, size_t len) + : _destination(destination), _to_skip(from), _to_write(len), _pos{0} {} + +size_t ChunkPrint::write(uint8_t c) { + if (_to_skip > 0) { + _to_skip--; + return 1; + } else if (_to_write > 0) { + _to_write--; + _destination[_pos++] = c; + return 1; + } + return 0; +} \ No newline at end of file diff --git a/src/ChunkPrint.h b/src/ChunkPrint.h index 2f40741..103d21e 100644 --- a/src/ChunkPrint.h +++ b/src/ChunkPrint.h @@ -11,22 +11,9 @@ class ChunkPrint : public Print { size_t _pos; public: - ChunkPrint(uint8_t* destination, size_t from, size_t len) - : _destination(destination), _to_skip(from), _to_write(len), _pos{0} {} + ChunkPrint(uint8_t* destination, size_t from, size_t len); virtual ~ChunkPrint() {} - size_t write(uint8_t c) { - if (_to_skip > 0) { - _to_skip--; - return 1; - } else if (_to_write > 0) { - _to_write--; - _destination[_pos++] = c; - return 1; - } - return 0; - } - size_t write(const uint8_t* buffer, size_t size) { - return this->Print::write(buffer, size); - } + size_t write(uint8_t c); + size_t write(const uint8_t* buffer, size_t size) { return this->Print::write(buffer, size); } }; #endif diff --git a/src/ESPAsyncWebServer.h b/src/ESPAsyncWebServer.h index 8922ea7..621a9c8 100644 --- a/src/ESPAsyncWebServer.h +++ b/src/ESPAsyncWebServer.h @@ -24,6 +24,7 @@ #include "Arduino.h" #include "FS.h" +#include #include #include #include @@ -141,33 +142,15 @@ class AsyncWebHeader { public: AsyncWebHeader() = default; AsyncWebHeader(const AsyncWebHeader&) = default; - AsyncWebHeader(const char* name, const char* value) : _name(name), _value(value) {} AsyncWebHeader(const String& name, const String& value) : _name(name), _value(value) {} - AsyncWebHeader(const String& data) { - if (!data) - return; - int index = data.indexOf(':'); - if (index < 0) - return; - _name = data.substring(0, index); - _value = data.substring(index + 2); - } + AsyncWebHeader(const String& data); AsyncWebHeader& operator=(const AsyncWebHeader&) = default; const String& name() const { return _name; } const String& value() const { return _value; } - String toString() const { - String str; - str.reserve(_name.length() + _value.length() + 2); - str.concat(_name); - str.concat((char)0x3a); - str.concat((char)0x20); - str.concat(_value); - str.concat(asyncsrv::T_rn); - return str; - } + String toString() const; }; /* @@ -472,23 +455,11 @@ class AsyncWebServerRequest { const std::list& getHeaders() const { return _headers; } - size_t getHeaderNames(std::vector& names) const { - names.clear(); - const size_t size = _headers.size(); - names.reserve(size); - for (const auto& h : _headers) { - names.push_back(h.name().c_str()); - } - return size; - } + size_t getHeaderNames(std::vector& names) const; // Remove a header from the request. // It will free the memory and prevent the header to be seen during request processing. - bool removeHeader(const char* name) { - const size_t size = _headers.size(); - _headers.remove_if([name](const AsyncWebHeader& header) { return header.name().equalsIgnoreCase(name); }); - return size != _headers.size(); - } + bool removeHeader(const char* name); // Remove all request headers. void removeHeaders() { _headers.clear(); } @@ -509,26 +480,11 @@ class AsyncWebServerRequest { bool hasAttribute(const char* name) const { return _attributes.find(name) != _attributes.end(); } - const String& getAttribute(const char* name, const String& defaultValue = emptyString) const { - auto it = _attributes.find(name); - return it != _attributes.end() ? it->second : defaultValue; - } - bool getAttribute(const char* name, bool defaultValue) const { - auto it = _attributes.find(name); - return it != _attributes.end() ? it->second == "1" : defaultValue; - } - long getAttribute(const char* name, long defaultValue) const { - auto it = _attributes.find(name); - return it != _attributes.end() ? it->second.toInt() : defaultValue; - } - float getAttribute(const char* name, float defaultValue) const { - auto it = _attributes.find(name); - return it != _attributes.end() ? it->second.toFloat() : defaultValue; - } - double getAttribute(const char* name, double defaultValue) const { - auto it = _attributes.find(name); - return it != _attributes.end() ? it->second.toDouble() : defaultValue; - } + const String& getAttribute(const char* name, const String& defaultValue = emptyString) const; + bool getAttribute(const char* name, bool defaultValue) const; + long getAttribute(const char* name, long defaultValue) const; + float getAttribute(const char* name, float defaultValue) const; + double getAttribute(const char* name, double defaultValue) const; String urlDecode(const String& text) const; }; @@ -580,53 +536,15 @@ class AsyncMiddlewareFunction : public AsyncMiddleware { // For internal use only: super class to add/remove middleware to server or handlers class AsyncMiddlewareChain { public: - virtual ~AsyncMiddlewareChain() { - for (AsyncMiddleware* m : _middlewares) - if (m->_freeOnRemoval) - delete m; - } - void addMiddleware(ArMiddlewareCallback fn) { - AsyncMiddlewareFunction* m = new AsyncMiddlewareFunction(fn); - m->_freeOnRemoval = true; - _middlewares.emplace_back(m); - } - void addMiddleware(AsyncMiddleware* middleware) { - if (middleware) - _middlewares.emplace_back(middleware); - } - void addMiddlewares(std::vector middlewares) { - for (AsyncMiddleware* m : middlewares) - addMiddleware(m); - } - bool removeMiddleware(AsyncMiddleware* middleware) { - // remove all middlewares from _middlewares vector being equal to middleware, delete them having _freeOnRemoval flag to true and resize the vector. - const size_t size = _middlewares.size(); - _middlewares.erase(std::remove_if(_middlewares.begin(), _middlewares.end(), [middleware](AsyncMiddleware* m) { - if (m == middleware) { - if (m->_freeOnRemoval) - delete m; - return true; - } - return false; - }), - _middlewares.end()); - return size != _middlewares.size(); - } + virtual ~AsyncMiddlewareChain(); + + void addMiddleware(ArMiddlewareCallback fn); + void addMiddleware(AsyncMiddleware* middleware); + void addMiddlewares(std::vector middlewares); + bool removeMiddleware(AsyncMiddleware* middleware); + // For internal use only - void _runChain(AsyncWebServerRequest* request, ArMiddlewareNext finalizer) { - if (!_middlewares.size()) - return finalizer(); - ArMiddlewareNext next; - std::list::iterator it = _middlewares.begin(); - next = [this, &next, &it, request, finalizer]() { - if (it == _middlewares.end()) - return finalizer(); - AsyncMiddleware* m = *it; - it++; - return m->run(request, next); - }; - return next(); - } + void _runChain(AsyncWebServerRequest* request, ArMiddlewareNext finalizer); protected: std::list _middlewares; @@ -647,13 +565,9 @@ class AuthenticationMiddleware : public AsyncMiddleware { void setPasswordIsHash(bool passwordIsHash) { _hash = passwordIsHash; } void setAuthType(AuthType authType) { _authType = authType; } - bool allowed(AsyncWebServerRequest* request) { - return _authType == AUTH_NONE || !_username.length() || !_password.length() || request->authenticate(_username.c_str(), _password.c_str(), _realm, _hash); - } + bool allowed(AsyncWebServerRequest* request) { return _authType == AUTH_NONE || !_username.length() || !_password.length() || request->authenticate(_username.c_str(), _password.c_str(), _realm, _hash); } - void run(AsyncWebServerRequest* request, ArMiddlewareNext next) { - return allowed(request) ? next() : request->requestAuthentication(_realm, _authType == AUTH_DIGEST); - } + void run(AsyncWebServerRequest* request, ArMiddlewareNext next) { return allowed(request) ? next() : request->requestAuthentication(_realm, _authType == AUTH_DIGEST); } private: String _username; @@ -669,11 +583,8 @@ class AuthorizationMiddleware : public AsyncMiddleware { public: AuthorizationMiddleware(ArAuthorizeFunction authorizeConnectHandler) : _code(403), _authz(authorizeConnectHandler) {} AuthorizationMiddleware(int code, ArAuthorizeFunction authorizeConnectHandler) : _code(code), _authz(authorizeConnectHandler) {} - void run(AsyncWebServerRequest* request, ArMiddlewareNext next) { - if (_authz && !_authz(request)) - return request->send(_code); - return next(); - } + + void run(AsyncWebServerRequest* request, ArMiddlewareNext next) { return _authz && !_authz(request) ? request->send(_code) : next(); } private: int _code; @@ -685,23 +596,8 @@ class HeaderFreeMiddleware : public AsyncMiddleware { public: void keep(const char* name) { _toKeep.push_back(name); } void unKeep(const char* name) { _toKeep.erase(std::remove(_toKeep.begin(), _toKeep.end(), name), _toKeep.end()); } - void run(AsyncWebServerRequest* request, ArMiddlewareNext next) { - std::vector reqHeaders; - request->getHeaderNames(reqHeaders); - for (const char* h : reqHeaders) { - bool keep = false; - for (const char* k : _toKeep) { - if (strcasecmp(h, k) == 0) { - keep = true; - break; - } - } - if (!keep) { - request->removeHeader(h); - } - } - next(); - } + + void run(AsyncWebServerRequest* request, ArMiddlewareNext next); private: std::vector _toKeep; @@ -712,11 +608,8 @@ class HeaderFilterMiddleware : public AsyncMiddleware { public: void filter(const char* name) { _toRemove.push_back(name); } void unFilter(const char* name) { _toRemove.erase(std::remove(_toRemove.begin(), _toRemove.end(), name), _toRemove.end()); } - void run(AsyncWebServerRequest* request, ArMiddlewareNext next) { - for (auto it = _toRemove.begin(); it != _toRemove.end(); ++it) - request->removeHeader(*it); - next(); - } + + void run(AsyncWebServerRequest* request, ArMiddlewareNext next); private: std::vector _toRemove; @@ -763,23 +656,7 @@ class RateLimitMiddleware : public AsyncMiddleware { void setMaxRequests(size_t maxRequests) { _maxRequests = maxRequests; } void setWindowSize(uint32_t seconds) { _windowSizeMillis = seconds * 1000; } - bool isRequestAllowed(uint32_t& retryAfterSeconds) { - uint32_t now = millis(); - - while (!_requestTimes.empty() && _requestTimes.front() <= now - _windowSizeMillis) - _requestTimes.pop_front(); - - _requestTimes.push_back(now); - - if (_requestTimes.size() > _maxRequests) { - _requestTimes.pop_front(); - retryAfterSeconds = (_windowSizeMillis - (now - _requestTimes.front())) / 1000 + 1; - return false; - } - - retryAfterSeconds = 0; - return true; - } + bool isRequestAllowed(uint32_t& retryAfterSeconds); void run(AsyncWebServerRequest* request, ArMiddlewareNext next); @@ -830,20 +707,8 @@ class AsyncWebHandler : public AsyncMiddlewareChain { public: AsyncWebHandler() {} - AsyncWebHandler& setFilter(ArRequestFilterFunction fn) { - _filter = fn; - return *this; - } - AsyncWebHandler& setAuthentication(const char* username, const char* password) { - if (username == nullptr || password == nullptr || strlen(username) == 0 || strlen(password) == 0) - return *this; - AuthenticationMiddleware* m = new AuthenticationMiddleware(); - m->setUsername(username); - m->setPassword(password); - m->_freeOnRemoval = true; - addMiddleware(m); - return *this; - }; + AsyncWebHandler& setFilter(ArRequestFilterFunction fn); + AsyncWebHandler& setAuthentication(const char* username, const char* password); AsyncWebHandler& setAuthentication(const String& username, const String& password) { return setAuthentication(username.c_str(), password.c_str()); }; bool filter(AsyncWebServerRequest* request) { return _filter == NULL || _filter(request); } virtual ~AsyncWebHandler() {} diff --git a/src/Middlewares.cpp b/src/Middleware.cpp similarity index 52% rename from src/Middlewares.cpp rename to src/Middleware.cpp index 1f0252a..5dc3d05 100644 --- a/src/Middlewares.cpp +++ b/src/Middleware.cpp @@ -1,4 +1,80 @@ -#include "ESPAsyncWebServer.h" +#include + +AsyncMiddlewareChain::~AsyncMiddlewareChain() { + for (AsyncMiddleware* m : _middlewares) + if (m->_freeOnRemoval) + delete m; +} + +void AsyncMiddlewareChain::addMiddleware(ArMiddlewareCallback fn) { + AsyncMiddlewareFunction* m = new AsyncMiddlewareFunction(fn); + m->_freeOnRemoval = true; + _middlewares.emplace_back(m); +} + +void AsyncMiddlewareChain::addMiddleware(AsyncMiddleware* middleware) { + if (middleware) + _middlewares.emplace_back(middleware); +} + +void AsyncMiddlewareChain::addMiddlewares(std::vector middlewares) { + for (AsyncMiddleware* m : middlewares) + addMiddleware(m); +} + +bool AsyncMiddlewareChain::removeMiddleware(AsyncMiddleware* middleware) { + // remove all middlewares from _middlewares vector being equal to middleware, delete them having _freeOnRemoval flag to true and resize the vector. + const size_t size = _middlewares.size(); + _middlewares.erase(std::remove_if(_middlewares.begin(), _middlewares.end(), [middleware](AsyncMiddleware* m) { + if (m == middleware) { + if (m->_freeOnRemoval) + delete m; + return true; + } + return false; + }), + _middlewares.end()); + return size != _middlewares.size(); +} + +void AsyncMiddlewareChain::_runChain(AsyncWebServerRequest* request, ArMiddlewareNext finalizer) { + if (!_middlewares.size()) + return finalizer(); + ArMiddlewareNext next; + std::list::iterator it = _middlewares.begin(); + next = [this, &next, &it, request, finalizer]() { + if (it == _middlewares.end()) + return finalizer(); + AsyncMiddleware* m = *it; + it++; + return m->run(request, next); + }; + return next(); +} + +void HeaderFreeMiddleware::run(AsyncWebServerRequest* request, ArMiddlewareNext next) { + std::vector reqHeaders; + request->getHeaderNames(reqHeaders); + for (const char* h : reqHeaders) { + bool keep = false; + for (const char* k : _toKeep) { + if (strcasecmp(h, k) == 0) { + keep = true; + break; + } + } + if (!keep) { + request->removeHeader(h); + } + } + next(); +} + +void HeaderFilterMiddleware::run(AsyncWebServerRequest* request, ArMiddlewareNext next) { + for (auto it = _toRemove.begin(); it != _toRemove.end(); ++it) + request->removeHeader(*it); + next(); +} void LoggingMiddleware::run(AsyncWebServerRequest* request, ArMiddlewareNext next) { if (!isEnabled()) { @@ -90,6 +166,24 @@ void CorsMiddleware::run(AsyncWebServerRequest* request, ArMiddlewareNext next) } } +bool RateLimitMiddleware::isRequestAllowed(uint32_t& retryAfterSeconds) { + uint32_t now = millis(); + + while (!_requestTimes.empty() && _requestTimes.front() <= now - _windowSizeMillis) + _requestTimes.pop_front(); + + _requestTimes.push_back(now); + + if (_requestTimes.size() > _maxRequests) { + _requestTimes.pop_front(); + retryAfterSeconds = (_windowSizeMillis - (now - _requestTimes.front())) / 1000 + 1; + return false; + } + + retryAfterSeconds = 0; + return true; +} + void RateLimitMiddleware::run(AsyncWebServerRequest* request, ArMiddlewareNext next) { uint32_t retryAfterSeconds; if (isRequestAllowed(retryAfterSeconds)) { diff --git a/src/WebHandlerImpl.h b/src/WebHandlerImpl.h index 2838a8e..cb43f80 100644 --- a/src/WebHandlerImpl.h +++ b/src/WebHandlerImpl.h @@ -63,10 +63,7 @@ class AsyncStaticWebHandler : public AsyncWebHandler { AsyncStaticWebHandler& setLastModified(time_t last_modified); AsyncStaticWebHandler& setLastModified(); // sets to current time. Make sure sntp is runing and time is updated #endif - AsyncStaticWebHandler& setTemplateProcessor(AwsTemplateProcessor newCallback) { - _callback = newCallback; - return *this; - } + AsyncStaticWebHandler& setTemplateProcessor(AwsTemplateProcessor newCallback); }; class AsyncCallbackWebHandler : public AsyncWebHandler { @@ -81,68 +78,17 @@ class AsyncCallbackWebHandler : public AsyncWebHandler { public: AsyncCallbackWebHandler() : _uri(), _method(HTTP_ANY), _onRequest(NULL), _onUpload(NULL), _onBody(NULL), _isRegex(false) {} - void setUri(const String& uri) { - _uri = uri; - _isRegex = uri.startsWith("^") && uri.endsWith("$"); - } + void setUri(const String& uri); void setMethod(WebRequestMethodComposite method) { _method = method; } void onRequest(ArRequestHandlerFunction fn) { _onRequest = fn; } void onUpload(ArUploadHandlerFunction fn) { _onUpload = fn; } void onBody(ArBodyHandlerFunction fn) { _onBody = fn; } - virtual bool canHandle(AsyncWebServerRequest* request) override final { - - if (!_onRequest) - return false; - - if (!(_method & request->method())) - return false; - -#ifdef ASYNCWEBSERVER_REGEX - if (_isRegex) { - std::regex pattern(_uri.c_str()); - std::smatch matches; - std::string s(request->url().c_str()); - if (std::regex_search(s, matches, pattern)) { - for (size_t i = 1; i < matches.size(); ++i) { // start from 1 - request->_addPathParam(matches[i].str().c_str()); - } - } else { - return false; - } - } else -#endif - if (_uri.length() && _uri.startsWith("/*.")) { - String uriTemplate = String(_uri); - uriTemplate = uriTemplate.substring(uriTemplate.lastIndexOf(".")); - if (!request->url().endsWith(uriTemplate)) - return false; - } else if (_uri.length() && _uri.endsWith("*")) { - String uriTemplate = String(_uri); - uriTemplate = uriTemplate.substring(0, uriTemplate.length() - 1); - if (!request->url().startsWith(uriTemplate)) - return false; - } else if (_uri.length() && (_uri != request->url() && !request->url().startsWith(_uri + "/"))) - return false; - - return true; - } - - virtual void handleRequest(AsyncWebServerRequest* request) override final { - if (_onRequest) - _onRequest(request); - else - request->send(500); - } - virtual void handleUpload(AsyncWebServerRequest* request, const String& filename, size_t index, uint8_t* data, size_t len, bool final) override final { - if (_onUpload) - _onUpload(request, filename, index, data, len, final); - } - virtual void handleBody(AsyncWebServerRequest* request, uint8_t* data, size_t len, size_t index, size_t total) override final { - if (_onBody) - _onBody(request, data, len, index, total); - } - virtual bool isRequestHandlerTrivial() override final { return _onRequest ? false : true; } + virtual bool canHandle(AsyncWebServerRequest* request) override final; + virtual void handleRequest(AsyncWebServerRequest* request) override final; + virtual void handleUpload(AsyncWebServerRequest* request, const String& filename, size_t index, uint8_t* data, size_t len, bool final) override final; + virtual void handleBody(AsyncWebServerRequest* request, uint8_t* data, size_t len, size_t index, size_t total) override final; + virtual bool isRequestHandlerTrivial() override final { return !_onRequest; } }; #endif /* ASYNCWEBSERVERHANDLERIMPL_H_ */ diff --git a/src/WebHandlers.cpp b/src/WebHandlers.cpp index 37ce2fa..ef976e5 100644 --- a/src/WebHandlers.cpp +++ b/src/WebHandlers.cpp @@ -23,6 +23,21 @@ using namespace asyncsrv; +AsyncWebHandler& AsyncWebHandler::setFilter(ArRequestFilterFunction fn) { + _filter = fn; + return *this; +} +AsyncWebHandler& AsyncWebHandler::setAuthentication(const char* username, const char* password) { + if (username == nullptr || password == nullptr || strlen(username) == 0 || strlen(password) == 0) + return *this; + AuthenticationMiddleware* m = new AuthenticationMiddleware(); + m->setUsername(username); + m->setPassword(password); + m->_freeOnRemoval = true; + addMiddleware(m); + return *this; +}; + AsyncStaticWebHandler::AsyncStaticWebHandler(const char* uri, FS& fs, const char* path, const char* cache_control) : _fs(fs), _uri(uri), _path(path), _default_file(F("index.htm")), _cache_control(cache_control), _last_modified(), _callback(nullptr) { // Ensure leading '/' @@ -234,3 +249,65 @@ void AsyncStaticWebHandler::handleRequest(AsyncWebServerRequest* request) { request->send(404); } } + +AsyncStaticWebHandler& AsyncStaticWebHandler::setTemplateProcessor(AwsTemplateProcessor newCallback) { + _callback = newCallback; + return *this; +} + +void AsyncCallbackWebHandler::setUri(const String& uri) { + _uri = uri; + _isRegex = uri.startsWith("^") && uri.endsWith("$"); +} + +bool AsyncCallbackWebHandler::canHandle(AsyncWebServerRequest* request) { + if (!_onRequest) + return false; + + if (!(_method & request->method())) + return false; + +#ifdef ASYNCWEBSERVER_REGEX + if (_isRegex) { + std::regex pattern(_uri.c_str()); + std::smatch matches; + std::string s(request->url().c_str()); + if (std::regex_search(s, matches, pattern)) { + for (size_t i = 1; i < matches.size(); ++i) { // start from 1 + request->_addPathParam(matches[i].str().c_str()); + } + } else { + return false; + } + } else +#endif + if (_uri.length() && _uri.startsWith("/*.")) { + String uriTemplate = String(_uri); + uriTemplate = uriTemplate.substring(uriTemplate.lastIndexOf(".")); + if (!request->url().endsWith(uriTemplate)) + return false; + } else if (_uri.length() && _uri.endsWith("*")) { + String uriTemplate = String(_uri); + uriTemplate = uriTemplate.substring(0, uriTemplate.length() - 1); + if (!request->url().startsWith(uriTemplate)) + return false; + } else if (_uri.length() && (_uri != request->url() && !request->url().startsWith(_uri + "/"))) + return false; + + return true; +} + +void AsyncCallbackWebHandler::handleRequest(AsyncWebServerRequest* request) { + if (_onRequest) + _onRequest(request); + else + request->send(500); +} +void AsyncCallbackWebHandler::handleUpload(AsyncWebServerRequest* request, const String& filename, size_t index, uint8_t* data, size_t len, bool final) { + if (_onUpload) + _onUpload(request, filename, index, data, len, final); +} +void AsyncCallbackWebHandler::handleBody(AsyncWebServerRequest* request, uint8_t* data, size_t len, size_t index, size_t total) { + if (_onBody) + _onBody(request, data, len, index, total); +} \ No newline at end of file diff --git a/src/WebRequest.cpp b/src/WebRequest.cpp index 67453e1..ff369b2 100644 --- a/src/WebRequest.cpp +++ b/src/WebRequest.cpp @@ -636,6 +636,22 @@ const AsyncWebHeader* AsyncWebServerRequest::getHeader(size_t num) const { return &(*std::next(_headers.cbegin(), num)); } +size_t AsyncWebServerRequest::getHeaderNames(std::vector& names) const { + names.clear(); + const size_t size = _headers.size(); + names.reserve(size); + for (const auto& h : _headers) { + names.push_back(h.name().c_str()); + } + return size; +} + +bool AsyncWebServerRequest::removeHeader(const char* name) { + const size_t size = _headers.size(); + _headers.remove_if([name](const AsyncWebHeader& header) { return header.name().equalsIgnoreCase(name); }); + return size != _headers.size(); +} + size_t AsyncWebServerRequest::params() const { return _params.size(); } @@ -670,6 +686,27 @@ const AsyncWebParameter* AsyncWebServerRequest::getParam(size_t num) const { return &(*std::next(_params.cbegin(), num)); } +const String& AsyncWebServerRequest::getAttribute(const char* name, const String& defaultValue) const { + auto it = _attributes.find(name); + return it != _attributes.end() ? it->second : defaultValue; +} +bool AsyncWebServerRequest::getAttribute(const char* name, bool defaultValue) const { + auto it = _attributes.find(name); + return it != _attributes.end() ? it->second == "1" : defaultValue; +} +long AsyncWebServerRequest::getAttribute(const char* name, long defaultValue) const { + auto it = _attributes.find(name); + return it != _attributes.end() ? it->second.toInt() : defaultValue; +} +float AsyncWebServerRequest::getAttribute(const char* name, float defaultValue) const { + auto it = _attributes.find(name); + return it != _attributes.end() ? it->second.toFloat() : defaultValue; +} +double AsyncWebServerRequest::getAttribute(const char* name, double defaultValue) const { + auto it = _attributes.find(name); + return it != _attributes.end() ? it->second.toDouble() : defaultValue; +} + AsyncWebServerResponse* AsyncWebServerRequest::beginResponse(int code, const char* contentType, const char* content, AwsTemplateProcessor callback) { if (callback) return new AsyncProgmemResponse(code, contentType, (const uint8_t*)content, strlen(content), callback);