Moved implementations in cpp files

This commit is contained in:
Mathieu Carbou
2024-09-29 21:29:59 +02:00
parent 546f9ed1c4
commit b473625d1d
16 changed files with 553 additions and 483 deletions

View File

@ -61,6 +61,8 @@ board = rpipicow
lib_deps =
bblanchon/ArduinoJson @ 7.2.0
khoih-prog/AsyncTCP_RP2040W @ 1.2.0
build_flags = ${env.build_flags}
-Wno-missing-field-initializers
; CI
@ -90,3 +92,5 @@ board = ${sysenv.PIO_BOARD}
lib_deps =
bblanchon/ArduinoJson @ 7.2.0
khoih-prog/AsyncTCP_RP2040W @ 1.2.0
build_flags = ${env.build_flags}
-Wno-missing-field-initializers

View File

@ -22,7 +22,6 @@
#include <rom/ets_sys.h>
#endif
#include "AsyncEventSource.h"
#include "literals.h"
using namespace asyncsrv;
@ -288,6 +287,12 @@ void AsyncEventSource::onConnect(ArEventHandlerFunction cb) {
_connectcb = cb;
}
void AsyncEventSource::authorizeConnect(ArAuthorizeConnectHandler cb) {
AuthorizationMiddleware* m = new AuthorizationMiddleware(401, cb);
m->_freeOnRemoval = true;
addMiddleware(m);
}
void AsyncEventSource::_addClient(AsyncEventSourceClient* client) {
if (!client)
return;

View File

@ -124,11 +124,7 @@ class AsyncEventSource : public AsyncWebHandler {
const char* url() const { return _url.c_str(); }
void close();
void onConnect(ArEventHandlerFunction cb);
void authorizeConnect(ArAuthorizeConnectHandler cb) {
AuthorizationMiddleware* m = new AuthorizationMiddleware(401, cb);
m->_freeOnRemoval = true;
addMiddleware(m);
}
void authorizeConnect(ArAuthorizeConnectHandler cb);
void send(const String& message, const String& event, uint32_t id = 0, uint32_t reconnect = 0) { send(message.c_str(), event.c_str(), id, reconnect); }
void send(const String& message, const char* event, uint32_t id = 0, uint32_t reconnect = 0) { send(message.c_str(), event, id, reconnect); }
void send(const char* message, const char* event = NULL, uint32_t id = 0, uint32_t reconnect = 0);

151
src/AsyncJson.cpp Normal file
View File

@ -0,0 +1,151 @@
#include "AsyncJson.h"
#if ARDUINOJSON_VERSION_MAJOR == 5
AsyncJsonResponse::AsyncJsonResponse(bool isArray) : _isValid{false} {
_code = 200;
_contentType = JSON_MIMETYPE;
if (isArray)
_root = _jsonBuffer.createArray();
else
_root = _jsonBuffer.createObject();
}
#elif ARDUINOJSON_VERSION_MAJOR == 6
AsyncJsonResponse::AsyncJsonResponse(bool isArray, size_t maxJsonBufferSize) : _jsonBuffer(maxJsonBufferSize), _isValid{false} {
_code = 200;
_contentType = JSON_MIMETYPE;
if (isArray)
_root = _jsonBuffer.createNestedArray();
else
_root = _jsonBuffer.createNestedObject();
}
#else
AsyncJsonResponse::AsyncJsonResponse(bool isArray) : _isValid{false} {
_code = 200;
_contentType = JSON_MIMETYPE;
if (isArray)
_root = _jsonBuffer.add<JsonArray>();
else
_root = _jsonBuffer.add<JsonObject>();
}
#endif
size_t AsyncJsonResponse::setLength() {
#if ARDUINOJSON_VERSION_MAJOR == 5
_contentLength = _root.measureLength();
#else
_contentLength = measureJson(_root);
#endif
if (_contentLength) {
_isValid = true;
}
return _contentLength;
}
size_t AsyncJsonResponse::_fillBuffer(uint8_t* data, size_t len) {
ChunkPrint dest(data, _sentLength, len);
#if ARDUINOJSON_VERSION_MAJOR == 5
_root.printTo(dest);
#else
serializeJson(_root, dest);
#endif
return len;
}
#if ARDUINOJSON_VERSION_MAJOR == 6
PrettyAsyncJsonResponse::PrettyAsyncJsonResponse(bool isArray, size_t maxJsonBufferSize) : AsyncJsonResponse{isArray, maxJsonBufferSize} {}
#else
PrettyAsyncJsonResponse::PrettyAsyncJsonResponse(bool isArray) : AsyncJsonResponse{isArray} {}
#endif
size_t PrettyAsyncJsonResponse::setLength() {
#if ARDUINOJSON_VERSION_MAJOR == 5
_contentLength = _root.measurePrettyLength();
#else
_contentLength = measureJsonPretty(_root);
#endif
if (_contentLength) {
_isValid = true;
}
return _contentLength;
}
size_t PrettyAsyncJsonResponse::_fillBuffer(uint8_t* data, size_t len) {
ChunkPrint dest(data, _sentLength, len);
#if ARDUINOJSON_VERSION_MAJOR == 5
_root.prettyPrintTo(dest);
#else
serializeJsonPretty(_root, dest);
#endif
return len;
}
#if ARDUINOJSON_VERSION_MAJOR == 6
AsyncCallbackJsonWebHandler::AsyncCallbackJsonWebHandler(const String& uri, ArJsonRequestHandlerFunction onRequest, size_t maxJsonBufferSize)
: _uri(uri), _method(HTTP_GET | HTTP_POST | HTTP_PUT | HTTP_PATCH), _onRequest(onRequest), maxJsonBufferSize(maxJsonBufferSize), _maxContentLength(16384) {}
#else
AsyncCallbackJsonWebHandler::AsyncCallbackJsonWebHandler(const String& uri, ArJsonRequestHandlerFunction onRequest)
: _uri(uri), _method(HTTP_GET | HTTP_POST | HTTP_PUT | HTTP_PATCH), _onRequest(onRequest), _maxContentLength(16384) {}
#endif
bool AsyncCallbackJsonWebHandler::canHandle(AsyncWebServerRequest* request) {
if (!_onRequest)
return false;
WebRequestMethodComposite request_method = request->method();
if (!(_method & request_method))
return false;
if (_uri.length() && (_uri != request->url() && !request->url().startsWith(_uri + "/")))
return false;
if (request_method != HTTP_GET && !request->contentType().equalsIgnoreCase(JSON_MIMETYPE))
return false;
return true;
}
void AsyncCallbackJsonWebHandler::handleRequest(AsyncWebServerRequest* request) {
if (_onRequest) {
if (request->method() == HTTP_GET) {
JsonVariant json;
_onRequest(request, json);
return;
} else if (request->_tempObject != NULL) {
#if ARDUINOJSON_VERSION_MAJOR == 5
DynamicJsonBuffer jsonBuffer;
JsonVariant json = jsonBuffer.parse((uint8_t*)(request->_tempObject));
if (json.success()) {
#elif ARDUINOJSON_VERSION_MAJOR == 6
DynamicJsonDocument jsonBuffer(this->maxJsonBufferSize);
DeserializationError error = deserializeJson(jsonBuffer, (uint8_t*)(request->_tempObject));
if (!error) {
JsonVariant json = jsonBuffer.as<JsonVariant>();
#else
JsonDocument jsonBuffer;
DeserializationError error = deserializeJson(jsonBuffer, (uint8_t*)(request->_tempObject));
if (!error) {
JsonVariant json = jsonBuffer.as<JsonVariant>();
#endif
_onRequest(request, json);
return;
}
}
request->send(_contentLength > _maxContentLength ? 413 : 400);
} else {
request->send(500);
}
}
void AsyncCallbackJsonWebHandler::handleBody(AsyncWebServerRequest* request, uint8_t* data, size_t len, size_t index, size_t total) {
if (_onRequest) {
_contentLength = total;
if (total > 0 && request->_tempObject == NULL && total < _maxContentLength) {
request->_tempObject = malloc(total);
}
if (request->_tempObject != NULL) {
memcpy((uint8_t*)(request->_tempObject) + index, data, len);
}
}
}

View File

@ -65,102 +65,35 @@ class AsyncJsonResponse : public AsyncAbstractResponse {
bool _isValid;
public:
#if ARDUINOJSON_VERSION_MAJOR == 5
AsyncJsonResponse(bool isArray = false) : _isValid{false} {
_code = 200;
_contentType = JSON_MIMETYPE;
if (isArray)
_root = _jsonBuffer.createArray();
else
_root = _jsonBuffer.createObject();
}
#elif ARDUINOJSON_VERSION_MAJOR == 6
AsyncJsonResponse(bool isArray = false, size_t maxJsonBufferSize = DYNAMIC_JSON_DOCUMENT_SIZE) : _jsonBuffer(maxJsonBufferSize), _isValid{false} {
_code = 200;
_contentType = JSON_MIMETYPE;
if (isArray)
_root = _jsonBuffer.createNestedArray();
else
_root = _jsonBuffer.createNestedObject();
}
#if ARDUINOJSON_VERSION_MAJOR == 6
AsyncJsonResponse(bool isArray = false, size_t maxJsonBufferSize = DYNAMIC_JSON_DOCUMENT_SIZE);
#else
AsyncJsonResponse(bool isArray = false) : _isValid{false} {
_code = 200;
_contentType = JSON_MIMETYPE;
if (isArray)
_root = _jsonBuffer.add<JsonArray>();
else
_root = _jsonBuffer.add<JsonObject>();
}
AsyncJsonResponse(bool isArray = false);
#endif
JsonVariant& getRoot() { return _root; }
bool _sourceValid() const { return _isValid; }
size_t setLength() {
#if ARDUINOJSON_VERSION_MAJOR == 5
_contentLength = _root.measureLength();
#else
_contentLength = measureJson(_root);
#endif
if (_contentLength) {
_isValid = true;
}
return _contentLength;
}
size_t setLength();
size_t getSize() const { return _jsonBuffer.size(); }
size_t _fillBuffer(uint8_t* data, size_t len);
#if ARDUINOJSON_VERSION_MAJOR >= 6
bool overflowed() const { return _jsonBuffer.overflowed(); }
#endif
size_t _fillBuffer(uint8_t* data, size_t len) {
ChunkPrint dest(data, _sentLength, len);
#if ARDUINOJSON_VERSION_MAJOR == 5
_root.printTo(dest);
#else
serializeJson(_root, dest);
#endif
return len;
}
};
class PrettyAsyncJsonResponse : public AsyncJsonResponse {
public:
#if ARDUINOJSON_VERSION_MAJOR == 6
PrettyAsyncJsonResponse(bool isArray = false, size_t maxJsonBufferSize = DYNAMIC_JSON_DOCUMENT_SIZE) : AsyncJsonResponse{isArray, maxJsonBufferSize} {}
PrettyAsyncJsonResponse(bool isArray = false, size_t maxJsonBufferSize = DYNAMIC_JSON_DOCUMENT_SIZE);
#else
PrettyAsyncJsonResponse(bool isArray = false) : AsyncJsonResponse{isArray} {}
PrettyAsyncJsonResponse(bool isArray = false);
#endif
size_t setLength() {
#if ARDUINOJSON_VERSION_MAJOR == 5
_contentLength = _root.measurePrettyLength();
#else
_contentLength = measureJsonPretty(_root);
#endif
if (_contentLength) {
_isValid = true;
}
return _contentLength;
}
size_t _fillBuffer(uint8_t* data, size_t len) {
ChunkPrint dest(data, _sentLength, len);
#if ARDUINOJSON_VERSION_MAJOR == 5
_root.prettyPrintTo(dest);
#else
serializeJsonPretty(_root, dest);
#endif
return len;
}
size_t setLength();
size_t _fillBuffer(uint8_t* data, size_t len);
};
typedef std::function<void(AsyncWebServerRequest* request, JsonVariant& json)> ArJsonRequestHandlerFunction;
class AsyncCallbackJsonWebHandler : public AsyncWebHandler {
private:
protected:
const String _uri;
WebRequestMethodComposite _method;
@ -173,80 +106,19 @@ class AsyncCallbackJsonWebHandler : public AsyncWebHandler {
public:
#if ARDUINOJSON_VERSION_MAJOR == 6
AsyncCallbackJsonWebHandler(const String& uri, ArJsonRequestHandlerFunction onRequest = nullptr, size_t maxJsonBufferSize = DYNAMIC_JSON_DOCUMENT_SIZE)
: _uri(uri), _method(HTTP_GET | HTTP_POST | HTTP_PUT | HTTP_PATCH), _onRequest(onRequest), maxJsonBufferSize(maxJsonBufferSize), _maxContentLength(16384) {}
AsyncCallbackJsonWebHandler(const String& uri, ArJsonRequestHandlerFunction onRequest = nullptr, size_t maxJsonBufferSize = DYNAMIC_JSON_DOCUMENT_SIZE);
#else
AsyncCallbackJsonWebHandler(const String& uri, ArJsonRequestHandlerFunction onRequest = nullptr)
: _uri(uri), _method(HTTP_GET | HTTP_POST | HTTP_PUT | HTTP_PATCH), _onRequest(onRequest), _maxContentLength(16384) {}
AsyncCallbackJsonWebHandler(const String& uri, ArJsonRequestHandlerFunction onRequest = nullptr);
#endif
void setMethod(WebRequestMethodComposite method) { _method = method; }
void setMaxContentLength(int maxContentLength) { _maxContentLength = maxContentLength; }
void onRequest(ArJsonRequestHandlerFunction fn) { _onRequest = fn; }
virtual bool canHandle(AsyncWebServerRequest* request) override final {
if (!_onRequest)
return false;
WebRequestMethodComposite request_method = request->method();
if (!(_method & request_method))
return false;
if (_uri.length() && (_uri != request->url() && !request->url().startsWith(_uri + "/")))
return false;
if (request_method != HTTP_GET && !request->contentType().equalsIgnoreCase(JSON_MIMETYPE))
return false;
return true;
}
virtual void handleRequest(AsyncWebServerRequest* request) override final {
if (_onRequest) {
if (request->method() == HTTP_GET) {
JsonVariant json;
_onRequest(request, json);
return;
} else if (request->_tempObject != NULL) {
#if ARDUINOJSON_VERSION_MAJOR == 5
DynamicJsonBuffer jsonBuffer;
JsonVariant json = jsonBuffer.parse((uint8_t*)(request->_tempObject));
if (json.success()) {
#elif ARDUINOJSON_VERSION_MAJOR == 6
DynamicJsonDocument jsonBuffer(this->maxJsonBufferSize);
DeserializationError error = deserializeJson(jsonBuffer, (uint8_t*)(request->_tempObject));
if (!error) {
JsonVariant json = jsonBuffer.as<JsonVariant>();
#else
JsonDocument jsonBuffer;
DeserializationError error = deserializeJson(jsonBuffer, (uint8_t*)(request->_tempObject));
if (!error) {
JsonVariant json = jsonBuffer.as<JsonVariant>();
#endif
_onRequest(request, json);
return;
}
}
request->send(_contentLength > _maxContentLength ? 413 : 400);
} else {
request->send(500);
}
}
virtual void handleUpload(__unused AsyncWebServerRequest* request, __unused const String& filename, __unused size_t index, __unused uint8_t* data, __unused size_t len, __unused bool final) override final {
}
virtual void handleBody(AsyncWebServerRequest* request, uint8_t* data, size_t len, size_t index, size_t total) override final {
if (_onRequest) {
_contentLength = total;
if (total > 0 && request->_tempObject == NULL && total < _maxContentLength) {
request->_tempObject = malloc(total);
}
if (request->_tempObject != NULL) {
memcpy((uint8_t*)(request->_tempObject) + index, data, len);
}
}
}
virtual bool isRequestHandlerTrivial() override final { return _onRequest ? false : true; }
virtual bool canHandle(AsyncWebServerRequest* request) override final;
virtual void handleRequest(AsyncWebServerRequest* request) override final;
virtual void handleUpload(__unused AsyncWebServerRequest* request, __unused const String& filename, __unused size_t index, __unused uint8_t* data, __unused size_t len, __unused bool final) override final {}
virtual void handleBody(AsyncWebServerRequest* request, uint8_t* data, size_t len, size_t index, size_t total) override final;
virtual bool isRequestHandlerTrivial() override final { return !_onRequest; }
};
#endif

79
src/AsyncMessagePack.cpp Normal file
View File

@ -0,0 +1,79 @@
#include "AsyncMessagePack.h"
AsyncMessagePackResponse::AsyncMessagePackResponse(bool isArray) : _isValid{false} {
_code = 200;
_contentType = asyncsrv::T_application_msgpack;
if (isArray)
_root = _jsonBuffer.add<JsonArray>();
else
_root = _jsonBuffer.add<JsonObject>();
}
size_t AsyncMessagePackResponse::setLength() {
_contentLength = measureMsgPack(_root);
if (_contentLength) {
_isValid = true;
}
return _contentLength;
}
size_t AsyncMessagePackResponse::_fillBuffer(uint8_t* data, size_t len) {
ChunkPrint dest(data, _sentLength, len);
serializeMsgPack(_root, dest);
return len;
}
AsyncCallbackMessagePackWebHandler::AsyncCallbackMessagePackWebHandler(const String& uri, ArJsonRequestHandlerFunction onRequest)
: _uri(uri), _method(HTTP_GET | HTTP_POST | HTTP_PUT | HTTP_PATCH), _onRequest(onRequest), _maxContentLength(16384) {}
bool AsyncCallbackMessagePackWebHandler::canHandle(AsyncWebServerRequest* request) {
if (!_onRequest)
return false;
WebRequestMethodComposite request_method = request->method();
if (!(_method & request_method))
return false;
if (_uri.length() && (_uri != request->url() && !request->url().startsWith(_uri + "/")))
return false;
if (request_method != HTTP_GET && !request->contentType().equalsIgnoreCase(asyncsrv::T_application_msgpack))
return false;
return true;
}
void AsyncCallbackMessagePackWebHandler::handleRequest(AsyncWebServerRequest* request) {
if (_onRequest) {
if (request->method() == HTTP_GET) {
JsonVariant json;
_onRequest(request, json);
return;
} else if (request->_tempObject != NULL) {
JsonDocument jsonBuffer;
DeserializationError error = deserializeMsgPack(jsonBuffer, (uint8_t*)(request->_tempObject));
if (!error) {
JsonVariant json = jsonBuffer.as<JsonVariant>();
_onRequest(request, json);
return;
}
}
request->send(_contentLength > _maxContentLength ? 413 : 400);
} else {
request->send(500);
}
}
void AsyncCallbackMessagePackWebHandler::handleBody(AsyncWebServerRequest* request, uint8_t* data, size_t len, size_t index, size_t total) {
if (_onRequest) {
_contentLength = total;
if (total > 0 && request->_tempObject == NULL && total < _maxContentLength) {
request->_tempObject = malloc(total);
}
if (request->_tempObject != NULL) {
memcpy((uint8_t*)(request->_tempObject) + index, data, len);
}
}
}

View File

@ -25,7 +25,6 @@
#include <ESPAsyncWebServer.h>
#include "ChunkPrint.h"
#include "literals.h"
class AsyncMessagePackResponse : public AsyncAbstractResponse {
protected:
@ -34,34 +33,13 @@ class AsyncMessagePackResponse : public AsyncAbstractResponse {
bool _isValid;
public:
AsyncMessagePackResponse(bool isArray = false) : _isValid{false} {
_code = 200;
_contentType = asyncsrv::T_application_msgpack;
if (isArray)
_root = _jsonBuffer.add<JsonArray>();
else
_root = _jsonBuffer.add<JsonObject>();
}
AsyncMessagePackResponse(bool isArray = false);
JsonVariant& getRoot() { return _root; }
bool _sourceValid() const { return _isValid; }
size_t setLength() {
_contentLength = measureMsgPack(_root);
if (_contentLength) {
_isValid = true;
}
return _contentLength;
}
size_t setLength();
size_t getSize() const { return _jsonBuffer.size(); }
size_t _fillBuffer(uint8_t* data, size_t len) {
ChunkPrint dest(data, _sentLength, len);
serializeMsgPack(_root, dest);
return len;
}
size_t _fillBuffer(uint8_t* data, size_t len);
};
class AsyncCallbackMessagePackWebHandler : public AsyncWebHandler {
@ -76,66 +54,14 @@ class AsyncCallbackMessagePackWebHandler : public AsyncWebHandler {
size_t _maxContentLength;
public:
AsyncCallbackMessagePackWebHandler(const String& uri, ArJsonRequestHandlerFunction onRequest = nullptr)
: _uri(uri), _method(HTTP_GET | HTTP_POST | HTTP_PUT | HTTP_PATCH), _onRequest(onRequest), _maxContentLength(16384) {}
AsyncCallbackMessagePackWebHandler(const String& uri, ArJsonRequestHandlerFunction onRequest = nullptr);
void setMethod(WebRequestMethodComposite method) { _method = method; }
void setMaxContentLength(int maxContentLength) { _maxContentLength = maxContentLength; }
void onRequest(ArJsonRequestHandlerFunction fn) { _onRequest = fn; }
virtual bool canHandle(AsyncWebServerRequest* request) override final {
if (!_onRequest)
return false;
WebRequestMethodComposite request_method = request->method();
if (!(_method & request_method))
return false;
if (_uri.length() && (_uri != request->url() && !request->url().startsWith(_uri + "/")))
return false;
if (request_method != HTTP_GET && !request->contentType().equalsIgnoreCase(asyncsrv::T_application_msgpack))
return false;
return true;
}
virtual void handleRequest(AsyncWebServerRequest* request) override final {
if (_onRequest) {
if (request->method() == HTTP_GET) {
JsonVariant json;
_onRequest(request, json);
return;
} else if (request->_tempObject != NULL) {
JsonDocument jsonBuffer;
DeserializationError error = deserializeMsgPack(jsonBuffer, (uint8_t*)(request->_tempObject));
if (!error) {
JsonVariant json = jsonBuffer.as<JsonVariant>();
_onRequest(request, json);
return;
}
}
request->send(_contentLength > _maxContentLength ? 413 : 400);
} else {
request->send(500);
}
}
virtual bool canHandle(AsyncWebServerRequest* request) override final;
virtual void handleRequest(AsyncWebServerRequest* request) override final;
virtual void handleUpload(__unused AsyncWebServerRequest* request, __unused const String& filename, __unused size_t index, __unused uint8_t* data, __unused size_t len, __unused bool final) override final {}
virtual void handleBody(AsyncWebServerRequest* request, uint8_t* data, size_t len, size_t index, size_t total) override final {
if (_onRequest) {
_contentLength = total;
if (total > 0 && request->_tempObject == NULL && total < _maxContentLength) {
request->_tempObject = malloc(total);
}
if (request->_tempObject != NULL) {
memcpy((uint8_t*)(request->_tempObject) + index, data, len);
}
}
}
virtual void handleBody(AsyncWebServerRequest* request, uint8_t* data, size_t len, size_t index, size_t total) override final;
virtual bool isRequestHandlerTrivial() override final { return _onRequest ? false : true; }
};

22
src/AsyncWebHeader.cpp Normal file
View File

@ -0,0 +1,22 @@
#include <ESPAsyncWebServer.h>
AsyncWebHeader::AsyncWebHeader(const String& data) {
if (!data)
return;
int index = data.indexOf(':');
if (index < 0)
return;
_name = data.substring(0, index);
_value = data.substring(index + 2);
}
String AsyncWebHeader::toString() const {
String str;
str.reserve(_name.length() + _value.length() + 2);
str.concat(_name);
str.concat((char)0x3a);
str.concat((char)0x20);
str.concat(_value);
str.concat(asyncsrv::T_rn);
return str;
}

View File

@ -281,7 +281,7 @@ class AsyncWebSocket : public AsyncWebHandler {
public:
explicit AsyncWebSocket(const char* url) : _url(url), _cNextId(1), _enabled(true) {}
AsyncWebSocket(const String& url) : _url(url), _cNextId(1), _enabled(true) {}
~AsyncWebSocket(){};
~AsyncWebSocket() {};
const char* url() const { return _url.c_str(); }
void enable(bool e) { _enabled = e; }
bool enabled() const { return _enabled; }
@ -339,15 +339,8 @@ class AsyncWebSocket : public AsyncWebHandler {
size_t printfAll_P(PGM_P formatP, ...) __attribute__((format(printf, 2, 3)));
#endif
// event listener
void onEvent(AwsEventHandler handler) {
_eventHandler = handler;
}
// Handshake Handler
void handleHandshake(AwsHandshakeHandler handler) {
_handshakeHandler = handler;
}
void onEvent(AwsEventHandler handler) { _eventHandler = handler; }
void handleHandshake(AwsHandshakeHandler handler) { _handshakeHandler = handler; }
// system callbacks (do not call)
uint32_t _getNextId() { return _cNextId++; }

16
src/ChunkPrint.cpp Normal file
View File

@ -0,0 +1,16 @@
#include <ChunkPrint.h>
ChunkPrint::ChunkPrint(uint8_t* destination, size_t from, size_t len)
: _destination(destination), _to_skip(from), _to_write(len), _pos{0} {}
size_t ChunkPrint::write(uint8_t c) {
if (_to_skip > 0) {
_to_skip--;
return 1;
} else if (_to_write > 0) {
_to_write--;
_destination[_pos++] = c;
return 1;
}
return 0;
}

View File

@ -11,22 +11,9 @@ class ChunkPrint : public Print {
size_t _pos;
public:
ChunkPrint(uint8_t* destination, size_t from, size_t len)
: _destination(destination), _to_skip(from), _to_write(len), _pos{0} {}
ChunkPrint(uint8_t* destination, size_t from, size_t len);
virtual ~ChunkPrint() {}
size_t write(uint8_t c) {
if (_to_skip > 0) {
_to_skip--;
return 1;
} else if (_to_write > 0) {
_to_write--;
_destination[_pos++] = c;
return 1;
}
return 0;
}
size_t write(const uint8_t* buffer, size_t size) {
return this->Print::write(buffer, size);
}
size_t write(uint8_t c);
size_t write(const uint8_t* buffer, size_t size) { return this->Print::write(buffer, size); }
};
#endif

View File

@ -24,6 +24,7 @@
#include "Arduino.h"
#include "FS.h"
#include <algorithm>
#include <deque>
#include <functional>
#include <list>
@ -141,33 +142,15 @@ class AsyncWebHeader {
public:
AsyncWebHeader() = default;
AsyncWebHeader(const AsyncWebHeader&) = default;
AsyncWebHeader(const char* name, const char* value) : _name(name), _value(value) {}
AsyncWebHeader(const String& name, const String& value) : _name(name), _value(value) {}
AsyncWebHeader(const String& data) {
if (!data)
return;
int index = data.indexOf(':');
if (index < 0)
return;
_name = data.substring(0, index);
_value = data.substring(index + 2);
}
AsyncWebHeader(const String& data);
AsyncWebHeader& operator=(const AsyncWebHeader&) = default;
const String& name() const { return _name; }
const String& value() const { return _value; }
String toString() const {
String str;
str.reserve(_name.length() + _value.length() + 2);
str.concat(_name);
str.concat((char)0x3a);
str.concat((char)0x20);
str.concat(_value);
str.concat(asyncsrv::T_rn);
return str;
}
String toString() const;
};
/*
@ -472,23 +455,11 @@ class AsyncWebServerRequest {
const std::list<AsyncWebHeader>& getHeaders() const { return _headers; }
size_t getHeaderNames(std::vector<const char*>& names) const {
names.clear();
const size_t size = _headers.size();
names.reserve(size);
for (const auto& h : _headers) {
names.push_back(h.name().c_str());
}
return size;
}
size_t getHeaderNames(std::vector<const char*>& names) const;
// Remove a header from the request.
// It will free the memory and prevent the header to be seen during request processing.
bool removeHeader(const char* name) {
const size_t size = _headers.size();
_headers.remove_if([name](const AsyncWebHeader& header) { return header.name().equalsIgnoreCase(name); });
return size != _headers.size();
}
bool removeHeader(const char* name);
// Remove all request headers.
void removeHeaders() { _headers.clear(); }
@ -509,26 +480,11 @@ class AsyncWebServerRequest {
bool hasAttribute(const char* name) const { return _attributes.find(name) != _attributes.end(); }
const String& getAttribute(const char* name, const String& defaultValue = emptyString) const {
auto it = _attributes.find(name);
return it != _attributes.end() ? it->second : defaultValue;
}
bool getAttribute(const char* name, bool defaultValue) const {
auto it = _attributes.find(name);
return it != _attributes.end() ? it->second == "1" : defaultValue;
}
long getAttribute(const char* name, long defaultValue) const {
auto it = _attributes.find(name);
return it != _attributes.end() ? it->second.toInt() : defaultValue;
}
float getAttribute(const char* name, float defaultValue) const {
auto it = _attributes.find(name);
return it != _attributes.end() ? it->second.toFloat() : defaultValue;
}
double getAttribute(const char* name, double defaultValue) const {
auto it = _attributes.find(name);
return it != _attributes.end() ? it->second.toDouble() : defaultValue;
}
const String& getAttribute(const char* name, const String& defaultValue = emptyString) const;
bool getAttribute(const char* name, bool defaultValue) const;
long getAttribute(const char* name, long defaultValue) const;
float getAttribute(const char* name, float defaultValue) const;
double getAttribute(const char* name, double defaultValue) const;
String urlDecode(const String& text) const;
};
@ -580,53 +536,15 @@ class AsyncMiddlewareFunction : public AsyncMiddleware {
// For internal use only: super class to add/remove middleware to server or handlers
class AsyncMiddlewareChain {
public:
virtual ~AsyncMiddlewareChain() {
for (AsyncMiddleware* m : _middlewares)
if (m->_freeOnRemoval)
delete m;
}
void addMiddleware(ArMiddlewareCallback fn) {
AsyncMiddlewareFunction* m = new AsyncMiddlewareFunction(fn);
m->_freeOnRemoval = true;
_middlewares.emplace_back(m);
}
void addMiddleware(AsyncMiddleware* middleware) {
if (middleware)
_middlewares.emplace_back(middleware);
}
void addMiddlewares(std::vector<AsyncMiddleware*> middlewares) {
for (AsyncMiddleware* m : middlewares)
addMiddleware(m);
}
bool removeMiddleware(AsyncMiddleware* middleware) {
// remove all middlewares from _middlewares vector being equal to middleware, delete them having _freeOnRemoval flag to true and resize the vector.
const size_t size = _middlewares.size();
_middlewares.erase(std::remove_if(_middlewares.begin(), _middlewares.end(), [middleware](AsyncMiddleware* m) {
if (m == middleware) {
if (m->_freeOnRemoval)
delete m;
return true;
}
return false;
}),
_middlewares.end());
return size != _middlewares.size();
}
virtual ~AsyncMiddlewareChain();
void addMiddleware(ArMiddlewareCallback fn);
void addMiddleware(AsyncMiddleware* middleware);
void addMiddlewares(std::vector<AsyncMiddleware*> middlewares);
bool removeMiddleware(AsyncMiddleware* middleware);
// For internal use only
void _runChain(AsyncWebServerRequest* request, ArMiddlewareNext finalizer) {
if (!_middlewares.size())
return finalizer();
ArMiddlewareNext next;
std::list<AsyncMiddleware*>::iterator it = _middlewares.begin();
next = [this, &next, &it, request, finalizer]() {
if (it == _middlewares.end())
return finalizer();
AsyncMiddleware* m = *it;
it++;
return m->run(request, next);
};
return next();
}
void _runChain(AsyncWebServerRequest* request, ArMiddlewareNext finalizer);
protected:
std::list<AsyncMiddleware*> _middlewares;
@ -647,13 +565,9 @@ class AuthenticationMiddleware : public AsyncMiddleware {
void setPasswordIsHash(bool passwordIsHash) { _hash = passwordIsHash; }
void setAuthType(AuthType authType) { _authType = authType; }
bool allowed(AsyncWebServerRequest* request) {
return _authType == AUTH_NONE || !_username.length() || !_password.length() || request->authenticate(_username.c_str(), _password.c_str(), _realm, _hash);
}
bool allowed(AsyncWebServerRequest* request) { return _authType == AUTH_NONE || !_username.length() || !_password.length() || request->authenticate(_username.c_str(), _password.c_str(), _realm, _hash); }
void run(AsyncWebServerRequest* request, ArMiddlewareNext next) {
return allowed(request) ? next() : request->requestAuthentication(_realm, _authType == AUTH_DIGEST);
}
void run(AsyncWebServerRequest* request, ArMiddlewareNext next) { return allowed(request) ? next() : request->requestAuthentication(_realm, _authType == AUTH_DIGEST); }
private:
String _username;
@ -669,11 +583,8 @@ class AuthorizationMiddleware : public AsyncMiddleware {
public:
AuthorizationMiddleware(ArAuthorizeFunction authorizeConnectHandler) : _code(403), _authz(authorizeConnectHandler) {}
AuthorizationMiddleware(int code, ArAuthorizeFunction authorizeConnectHandler) : _code(code), _authz(authorizeConnectHandler) {}
void run(AsyncWebServerRequest* request, ArMiddlewareNext next) {
if (_authz && !_authz(request))
return request->send(_code);
return next();
}
void run(AsyncWebServerRequest* request, ArMiddlewareNext next) { return _authz && !_authz(request) ? request->send(_code) : next(); }
private:
int _code;
@ -685,23 +596,8 @@ class HeaderFreeMiddleware : public AsyncMiddleware {
public:
void keep(const char* name) { _toKeep.push_back(name); }
void unKeep(const char* name) { _toKeep.erase(std::remove(_toKeep.begin(), _toKeep.end(), name), _toKeep.end()); }
void run(AsyncWebServerRequest* request, ArMiddlewareNext next) {
std::vector<const char*> reqHeaders;
request->getHeaderNames(reqHeaders);
for (const char* h : reqHeaders) {
bool keep = false;
for (const char* k : _toKeep) {
if (strcasecmp(h, k) == 0) {
keep = true;
break;
}
}
if (!keep) {
request->removeHeader(h);
}
}
next();
}
void run(AsyncWebServerRequest* request, ArMiddlewareNext next);
private:
std::vector<const char*> _toKeep;
@ -712,11 +608,8 @@ class HeaderFilterMiddleware : public AsyncMiddleware {
public:
void filter(const char* name) { _toRemove.push_back(name); }
void unFilter(const char* name) { _toRemove.erase(std::remove(_toRemove.begin(), _toRemove.end(), name), _toRemove.end()); }
void run(AsyncWebServerRequest* request, ArMiddlewareNext next) {
for (auto it = _toRemove.begin(); it != _toRemove.end(); ++it)
request->removeHeader(*it);
next();
}
void run(AsyncWebServerRequest* request, ArMiddlewareNext next);
private:
std::vector<const char*> _toRemove;
@ -763,23 +656,7 @@ class RateLimitMiddleware : public AsyncMiddleware {
void setMaxRequests(size_t maxRequests) { _maxRequests = maxRequests; }
void setWindowSize(uint32_t seconds) { _windowSizeMillis = seconds * 1000; }
bool isRequestAllowed(uint32_t& retryAfterSeconds) {
uint32_t now = millis();
while (!_requestTimes.empty() && _requestTimes.front() <= now - _windowSizeMillis)
_requestTimes.pop_front();
_requestTimes.push_back(now);
if (_requestTimes.size() > _maxRequests) {
_requestTimes.pop_front();
retryAfterSeconds = (_windowSizeMillis - (now - _requestTimes.front())) / 1000 + 1;
return false;
}
retryAfterSeconds = 0;
return true;
}
bool isRequestAllowed(uint32_t& retryAfterSeconds);
void run(AsyncWebServerRequest* request, ArMiddlewareNext next);
@ -830,20 +707,8 @@ class AsyncWebHandler : public AsyncMiddlewareChain {
public:
AsyncWebHandler() {}
AsyncWebHandler& setFilter(ArRequestFilterFunction fn) {
_filter = fn;
return *this;
}
AsyncWebHandler& setAuthentication(const char* username, const char* password) {
if (username == nullptr || password == nullptr || strlen(username) == 0 || strlen(password) == 0)
return *this;
AuthenticationMiddleware* m = new AuthenticationMiddleware();
m->setUsername(username);
m->setPassword(password);
m->_freeOnRemoval = true;
addMiddleware(m);
return *this;
};
AsyncWebHandler& setFilter(ArRequestFilterFunction fn);
AsyncWebHandler& setAuthentication(const char* username, const char* password);
AsyncWebHandler& setAuthentication(const String& username, const String& password) { return setAuthentication(username.c_str(), password.c_str()); };
bool filter(AsyncWebServerRequest* request) { return _filter == NULL || _filter(request); }
virtual ~AsyncWebHandler() {}

View File

@ -1,4 +1,80 @@
#include "ESPAsyncWebServer.h"
#include <ESPAsyncWebServer.h>
AsyncMiddlewareChain::~AsyncMiddlewareChain() {
for (AsyncMiddleware* m : _middlewares)
if (m->_freeOnRemoval)
delete m;
}
void AsyncMiddlewareChain::addMiddleware(ArMiddlewareCallback fn) {
AsyncMiddlewareFunction* m = new AsyncMiddlewareFunction(fn);
m->_freeOnRemoval = true;
_middlewares.emplace_back(m);
}
void AsyncMiddlewareChain::addMiddleware(AsyncMiddleware* middleware) {
if (middleware)
_middlewares.emplace_back(middleware);
}
void AsyncMiddlewareChain::addMiddlewares(std::vector<AsyncMiddleware*> middlewares) {
for (AsyncMiddleware* m : middlewares)
addMiddleware(m);
}
bool AsyncMiddlewareChain::removeMiddleware(AsyncMiddleware* middleware) {
// remove all middlewares from _middlewares vector being equal to middleware, delete them having _freeOnRemoval flag to true and resize the vector.
const size_t size = _middlewares.size();
_middlewares.erase(std::remove_if(_middlewares.begin(), _middlewares.end(), [middleware](AsyncMiddleware* m) {
if (m == middleware) {
if (m->_freeOnRemoval)
delete m;
return true;
}
return false;
}),
_middlewares.end());
return size != _middlewares.size();
}
void AsyncMiddlewareChain::_runChain(AsyncWebServerRequest* request, ArMiddlewareNext finalizer) {
if (!_middlewares.size())
return finalizer();
ArMiddlewareNext next;
std::list<AsyncMiddleware*>::iterator it = _middlewares.begin();
next = [this, &next, &it, request, finalizer]() {
if (it == _middlewares.end())
return finalizer();
AsyncMiddleware* m = *it;
it++;
return m->run(request, next);
};
return next();
}
void HeaderFreeMiddleware::run(AsyncWebServerRequest* request, ArMiddlewareNext next) {
std::vector<const char*> reqHeaders;
request->getHeaderNames(reqHeaders);
for (const char* h : reqHeaders) {
bool keep = false;
for (const char* k : _toKeep) {
if (strcasecmp(h, k) == 0) {
keep = true;
break;
}
}
if (!keep) {
request->removeHeader(h);
}
}
next();
}
void HeaderFilterMiddleware::run(AsyncWebServerRequest* request, ArMiddlewareNext next) {
for (auto it = _toRemove.begin(); it != _toRemove.end(); ++it)
request->removeHeader(*it);
next();
}
void LoggingMiddleware::run(AsyncWebServerRequest* request, ArMiddlewareNext next) {
if (!isEnabled()) {
@ -90,6 +166,24 @@ void CorsMiddleware::run(AsyncWebServerRequest* request, ArMiddlewareNext next)
}
}
bool RateLimitMiddleware::isRequestAllowed(uint32_t& retryAfterSeconds) {
uint32_t now = millis();
while (!_requestTimes.empty() && _requestTimes.front() <= now - _windowSizeMillis)
_requestTimes.pop_front();
_requestTimes.push_back(now);
if (_requestTimes.size() > _maxRequests) {
_requestTimes.pop_front();
retryAfterSeconds = (_windowSizeMillis - (now - _requestTimes.front())) / 1000 + 1;
return false;
}
retryAfterSeconds = 0;
return true;
}
void RateLimitMiddleware::run(AsyncWebServerRequest* request, ArMiddlewareNext next) {
uint32_t retryAfterSeconds;
if (isRequestAllowed(retryAfterSeconds)) {

View File

@ -63,10 +63,7 @@ class AsyncStaticWebHandler : public AsyncWebHandler {
AsyncStaticWebHandler& setLastModified(time_t last_modified);
AsyncStaticWebHandler& setLastModified(); // sets to current time. Make sure sntp is runing and time is updated
#endif
AsyncStaticWebHandler& setTemplateProcessor(AwsTemplateProcessor newCallback) {
_callback = newCallback;
return *this;
}
AsyncStaticWebHandler& setTemplateProcessor(AwsTemplateProcessor newCallback);
};
class AsyncCallbackWebHandler : public AsyncWebHandler {
@ -81,68 +78,17 @@ class AsyncCallbackWebHandler : public AsyncWebHandler {
public:
AsyncCallbackWebHandler() : _uri(), _method(HTTP_ANY), _onRequest(NULL), _onUpload(NULL), _onBody(NULL), _isRegex(false) {}
void setUri(const String& uri) {
_uri = uri;
_isRegex = uri.startsWith("^") && uri.endsWith("$");
}
void setUri(const String& uri);
void setMethod(WebRequestMethodComposite method) { _method = method; }
void onRequest(ArRequestHandlerFunction fn) { _onRequest = fn; }
void onUpload(ArUploadHandlerFunction fn) { _onUpload = fn; }
void onBody(ArBodyHandlerFunction fn) { _onBody = fn; }
virtual bool canHandle(AsyncWebServerRequest* request) override final {
if (!_onRequest)
return false;
if (!(_method & request->method()))
return false;
#ifdef ASYNCWEBSERVER_REGEX
if (_isRegex) {
std::regex pattern(_uri.c_str());
std::smatch matches;
std::string s(request->url().c_str());
if (std::regex_search(s, matches, pattern)) {
for (size_t i = 1; i < matches.size(); ++i) { // start from 1
request->_addPathParam(matches[i].str().c_str());
}
} else {
return false;
}
} else
#endif
if (_uri.length() && _uri.startsWith("/*.")) {
String uriTemplate = String(_uri);
uriTemplate = uriTemplate.substring(uriTemplate.lastIndexOf("."));
if (!request->url().endsWith(uriTemplate))
return false;
} else if (_uri.length() && _uri.endsWith("*")) {
String uriTemplate = String(_uri);
uriTemplate = uriTemplate.substring(0, uriTemplate.length() - 1);
if (!request->url().startsWith(uriTemplate))
return false;
} else if (_uri.length() && (_uri != request->url() && !request->url().startsWith(_uri + "/")))
return false;
return true;
}
virtual void handleRequest(AsyncWebServerRequest* request) override final {
if (_onRequest)
_onRequest(request);
else
request->send(500);
}
virtual void handleUpload(AsyncWebServerRequest* request, const String& filename, size_t index, uint8_t* data, size_t len, bool final) override final {
if (_onUpload)
_onUpload(request, filename, index, data, len, final);
}
virtual void handleBody(AsyncWebServerRequest* request, uint8_t* data, size_t len, size_t index, size_t total) override final {
if (_onBody)
_onBody(request, data, len, index, total);
}
virtual bool isRequestHandlerTrivial() override final { return _onRequest ? false : true; }
virtual bool canHandle(AsyncWebServerRequest* request) override final;
virtual void handleRequest(AsyncWebServerRequest* request) override final;
virtual void handleUpload(AsyncWebServerRequest* request, const String& filename, size_t index, uint8_t* data, size_t len, bool final) override final;
virtual void handleBody(AsyncWebServerRequest* request, uint8_t* data, size_t len, size_t index, size_t total) override final;
virtual bool isRequestHandlerTrivial() override final { return !_onRequest; }
};
#endif /* ASYNCWEBSERVERHANDLERIMPL_H_ */

View File

@ -23,6 +23,21 @@
using namespace asyncsrv;
AsyncWebHandler& AsyncWebHandler::setFilter(ArRequestFilterFunction fn) {
_filter = fn;
return *this;
}
AsyncWebHandler& AsyncWebHandler::setAuthentication(const char* username, const char* password) {
if (username == nullptr || password == nullptr || strlen(username) == 0 || strlen(password) == 0)
return *this;
AuthenticationMiddleware* m = new AuthenticationMiddleware();
m->setUsername(username);
m->setPassword(password);
m->_freeOnRemoval = true;
addMiddleware(m);
return *this;
};
AsyncStaticWebHandler::AsyncStaticWebHandler(const char* uri, FS& fs, const char* path, const char* cache_control)
: _fs(fs), _uri(uri), _path(path), _default_file(F("index.htm")), _cache_control(cache_control), _last_modified(), _callback(nullptr) {
// Ensure leading '/'
@ -234,3 +249,65 @@ void AsyncStaticWebHandler::handleRequest(AsyncWebServerRequest* request) {
request->send(404);
}
}
AsyncStaticWebHandler& AsyncStaticWebHandler::setTemplateProcessor(AwsTemplateProcessor newCallback) {
_callback = newCallback;
return *this;
}
void AsyncCallbackWebHandler::setUri(const String& uri) {
_uri = uri;
_isRegex = uri.startsWith("^") && uri.endsWith("$");
}
bool AsyncCallbackWebHandler::canHandle(AsyncWebServerRequest* request) {
if (!_onRequest)
return false;
if (!(_method & request->method()))
return false;
#ifdef ASYNCWEBSERVER_REGEX
if (_isRegex) {
std::regex pattern(_uri.c_str());
std::smatch matches;
std::string s(request->url().c_str());
if (std::regex_search(s, matches, pattern)) {
for (size_t i = 1; i < matches.size(); ++i) { // start from 1
request->_addPathParam(matches[i].str().c_str());
}
} else {
return false;
}
} else
#endif
if (_uri.length() && _uri.startsWith("/*.")) {
String uriTemplate = String(_uri);
uriTemplate = uriTemplate.substring(uriTemplate.lastIndexOf("."));
if (!request->url().endsWith(uriTemplate))
return false;
} else if (_uri.length() && _uri.endsWith("*")) {
String uriTemplate = String(_uri);
uriTemplate = uriTemplate.substring(0, uriTemplate.length() - 1);
if (!request->url().startsWith(uriTemplate))
return false;
} else if (_uri.length() && (_uri != request->url() && !request->url().startsWith(_uri + "/")))
return false;
return true;
}
void AsyncCallbackWebHandler::handleRequest(AsyncWebServerRequest* request) {
if (_onRequest)
_onRequest(request);
else
request->send(500);
}
void AsyncCallbackWebHandler::handleUpload(AsyncWebServerRequest* request, const String& filename, size_t index, uint8_t* data, size_t len, bool final) {
if (_onUpload)
_onUpload(request, filename, index, data, len, final);
}
void AsyncCallbackWebHandler::handleBody(AsyncWebServerRequest* request, uint8_t* data, size_t len, size_t index, size_t total) {
if (_onBody)
_onBody(request, data, len, index, total);
}

View File

@ -636,6 +636,22 @@ const AsyncWebHeader* AsyncWebServerRequest::getHeader(size_t num) const {
return &(*std::next(_headers.cbegin(), num));
}
size_t AsyncWebServerRequest::getHeaderNames(std::vector<const char*>& names) const {
names.clear();
const size_t size = _headers.size();
names.reserve(size);
for (const auto& h : _headers) {
names.push_back(h.name().c_str());
}
return size;
}
bool AsyncWebServerRequest::removeHeader(const char* name) {
const size_t size = _headers.size();
_headers.remove_if([name](const AsyncWebHeader& header) { return header.name().equalsIgnoreCase(name); });
return size != _headers.size();
}
size_t AsyncWebServerRequest::params() const {
return _params.size();
}
@ -670,6 +686,27 @@ const AsyncWebParameter* AsyncWebServerRequest::getParam(size_t num) const {
return &(*std::next(_params.cbegin(), num));
}
const String& AsyncWebServerRequest::getAttribute(const char* name, const String& defaultValue) const {
auto it = _attributes.find(name);
return it != _attributes.end() ? it->second : defaultValue;
}
bool AsyncWebServerRequest::getAttribute(const char* name, bool defaultValue) const {
auto it = _attributes.find(name);
return it != _attributes.end() ? it->second == "1" : defaultValue;
}
long AsyncWebServerRequest::getAttribute(const char* name, long defaultValue) const {
auto it = _attributes.find(name);
return it != _attributes.end() ? it->second.toInt() : defaultValue;
}
float AsyncWebServerRequest::getAttribute(const char* name, float defaultValue) const {
auto it = _attributes.find(name);
return it != _attributes.end() ? it->second.toFloat() : defaultValue;
}
double AsyncWebServerRequest::getAttribute(const char* name, double defaultValue) const {
auto it = _attributes.find(name);
return it != _attributes.end() ? it->second.toDouble() : defaultValue;
}
AsyncWebServerResponse* AsyncWebServerRequest::beginResponse(int code, const char* contentType, const char* content, AwsTemplateProcessor callback) {
if (callback)
return new AsyncProgmemResponse(code, contentType, (const uint8_t*)content, strlen(content), callback);