Complete rework of AuthenticationMiddleware...

- to align methods and enum with PsychicHttp and Arduino WebServer
- to support hash
- to pre-compute base64 / digest hash to speed up requests
Closes #111
This commit is contained in:
Mathieu Carbou
2024-10-01 01:48:17 +02:00
parent c295c7b676
commit 3e416accbd
9 changed files with 239 additions and 90 deletions

View File

@@ -193,10 +193,12 @@ AuthenticationMiddleware authMiddleware;
// [...] // [...]
authMiddleware.setAuthType(AuthenticationMiddleware::AuthType::AUTH_DIGEST); authMiddleware.setAuthType(AsyncAuthType::AUTH_DIGEST);
authMiddleware.setRealm("My app name"); authMiddleware.setRealm("My app name");
authMiddleware.setUsername("admin"); authMiddleware.setUsername("admin");
authMiddleware.setPassword("admin"); authMiddleware.setPassword("admin");
authMiddleware.setAuthFailureMessage("Authentication failed");
authMiddleware.generateHash(); // optimization to avoid generating the hash at each request
// [...] // [...]

View File

@@ -193,10 +193,12 @@ AuthenticationMiddleware authMiddleware;
// [...] // [...]
authMiddleware.setAuthType(AuthenticationMiddleware::AuthType::AUTH_DIGEST); authMiddleware.setAuthType(AsyncAuthType::AUTH_DIGEST);
authMiddleware.setRealm("My app name"); authMiddleware.setRealm("My app name");
authMiddleware.setUsername("admin"); authMiddleware.setUsername("admin");
authMiddleware.setPassword("admin"); authMiddleware.setPassword("admin");
authMiddleware.setAuthFailureMessage("Authentication failed");
authMiddleware.generateHash(); // optimization to avoid generating the hash at each request
// [...] // [...]

View File

@@ -50,8 +50,13 @@ HeaderFilterMiddleware headerFilter;
// remove all headers from the incoming request except the ones provided in the constructor // remove all headers from the incoming request except the ones provided in the constructor
HeaderFreeMiddleware headerFree; HeaderFreeMiddleware headerFree;
// basicAuth
AuthenticationMiddleware basicAuth;
AuthenticationMiddleware basicAuthHash;
// simple digest authentication // simple digest authentication
AuthenticationMiddleware simpleDigestAuth; AuthenticationMiddleware digestAuth;
AuthenticationMiddleware digestAuthHash;
// complex authentication which adds request attributes for the next middlewares and handler // complex authentication which adds request attributes for the next middlewares and handler
AsyncMiddlewareFunction complexAuth([](AsyncWebServerRequest* request, ArMiddlewareNext next) { AsyncMiddlewareFunction complexAuth([](AsyncWebServerRequest* request, ArMiddlewareNext next) {
@@ -177,9 +182,31 @@ void setup() {
requestLogger.setOutput(Serial); requestLogger.setOutput(Serial);
simpleDigestAuth.setUsername("admin"); basicAuth.setUsername("admin");
simpleDigestAuth.setPassword("admin"); basicAuth.setPassword("admin");
simpleDigestAuth.setRealm("MyApp"); basicAuth.setRealm("MyApp");
basicAuth.setAuthFailureMessage("Authentication failed");
basicAuth.setAuthType(AsyncAuthType::AUTH_BASIC);
basicAuth.generateHash();
basicAuthHash.setUsername("admin");
basicAuthHash.setPasswordHash("YWRtaW46YWRtaW4="); // BASE64(admin:admin)
basicAuthHash.setRealm("MyApp");
basicAuthHash.setAuthFailureMessage("Authentication failed");
basicAuthHash.setAuthType(AsyncAuthType::AUTH_BASIC);
digestAuth.setUsername("admin");
digestAuth.setPassword("admin");
digestAuth.setRealm("MyApp");
digestAuth.setAuthFailureMessage("Authentication failed");
digestAuth.setAuthType(AsyncAuthType::AUTH_DIGEST);
digestAuth.generateHash();
digestAuthHash.setUsername("admin");
digestAuthHash.setPasswordHash("f499b71f9a36d838b79268e145e132f7"); // MD5(user:realm:pass)
digestAuthHash.setRealm("MyApp");
digestAuthHash.setAuthFailureMessage("Authentication failed");
digestAuthHash.setAuthType(AsyncAuthType::AUTH_DIGEST);
rateLimit.setMaxRequests(5); rateLimit.setMaxRequests(5);
rateLimit.setWindowSize(10); rateLimit.setWindowSize(10);
@@ -225,15 +252,37 @@ void setup() {
}) })
.addMiddleware(&headerFree); .addMiddleware(&headerFree);
// simple digest authentication // basic authentication method
// curl -v -X GET -H "x-remove-me: value" --digest -u admin:admin http://192.168.4.1/middleware/auth-simple // curl -v -X GET -H "origin: http://192.168.4.1" -u admin:admin http://192.168.4.1/middleware/auth-basic
server.on("/middleware/auth-simple", HTTP_GET, [](AsyncWebServerRequest* request) { server.on("/middleware/auth-basic", HTTP_GET, [](AsyncWebServerRequest* request) {
request->send(200, "text/plain", "Hello, world!"); request->send(200, "text/plain", "Hello, world!");
}) })
.addMiddleware(&simpleDigestAuth); .addMiddleware(&basicAuth);
// curl -v -X GET -H "x-remove-me: value" --digest -u user:password http://192.168.4.1/middleware/auth-complex // basic authentication method with hash
server.on("/middleware/auth-complex", HTTP_GET, [](AsyncWebServerRequest* request) { // curl -v -X GET -H "origin: http://192.168.4.1" -u admin:admin http://192.168.4.1/middleware/auth-basic-hash
server.on("/middleware/auth-basic-hash", HTTP_GET, [](AsyncWebServerRequest* request) {
request->send(200, "text/plain", "Hello, world!");
})
.addMiddleware(&basicAuthHash);
// digest authentication
// curl -v -X GET -H "origin: http://192.168.4.1" -u admin:admin --digest http://192.168.4.1/middleware/auth-digest
server.on("/middleware/auth-digest", HTTP_GET, [](AsyncWebServerRequest* request) {
request->send(200, "text/plain", "Hello, world!");
})
.addMiddleware(&digestAuth);
// digest authentication with hash
// curl -v -X GET -H "origin: http://192.168.4.1" -u admin:admin --digest http://192.168.4.1/middleware/auth-digest-hash
server.on("/middleware/auth-digest-hash", HTTP_GET, [](AsyncWebServerRequest* request) {
request->send(200, "text/plain", "Hello, world!");
})
.addMiddleware(&digestAuthHash);
// test digest auth with cors
// curl -v -X GET -H "origin: http://192.168.4.1" --digest -u user:password http://192.168.4.1/middleware/auth-custom
server.on("/middleware/auth-custom", HTTP_GET, [](AsyncWebServerRequest* request) {
String buffer = "Hello "; String buffer = "Hello ";
buffer.concat(request->getAttribute("user")); buffer.concat(request->getAttribute("user"));
buffer.concat(" with role: "); buffer.concat(" with role: ");
@@ -244,6 +293,12 @@ void setup() {
/////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////
// curl -v -X GET -H "origin: http://192.168.4.1" http://192.168.4.1/redirect
// curl -v -X POST -H "origin: http://192.168.4.1" http://192.168.4.1/redirect
server.on("/redirect", HTTP_GET | HTTP_POST, [](AsyncWebServerRequest* request) {
request->redirect("/");
});
server.on("/", HTTP_GET, [](AsyncWebServerRequest* request) { server.on("/", HTTP_GET, [](AsyncWebServerRequest* request) {
request->send(200, "text/plain", "Hello, world"); request->send(200, "text/plain", "Hello, world");
}); });

View File

@@ -164,6 +164,15 @@ typedef enum { RCT_NOT_USED = -1,
RCT_EVENT, RCT_EVENT,
RCT_MAX } RequestedConnectionType; RCT_MAX } RequestedConnectionType;
// this enum is similar to Arduino WebServer's AsyncAuthType and PsychicHttp
typedef enum {
AUTH_NONE = 0,
AUTH_BASIC,
AUTH_DIGEST,
AUTH_BEARER,
AUTH_OTHER,
} AsyncAuthType;
typedef std::function<size_t(uint8_t*, size_t, size_t)> AwsResponseFiller; typedef std::function<size_t(uint8_t*, size_t, size_t)> AwsResponseFiller;
typedef std::function<String(const String&)> AwsTemplateProcessor; typedef std::function<String(const String&)> AwsTemplateProcessor;
@@ -194,7 +203,7 @@ class AsyncWebServerRequest {
String _boundary; String _boundary;
String _authorization; String _authorization;
RequestedConnectionType _reqconntype; RequestedConnectionType _reqconntype;
bool _isDigest; AsyncAuthType _authMethod = AsyncAuthType::AUTH_NONE;
bool _isMultipart; bool _isMultipart;
bool _isPlainPost; bool _isPlainPost;
bool _expectingContinue; bool _expectingContinue;
@@ -271,8 +280,9 @@ class AsyncWebServerRequest {
// base64(user:pass) for basic or // base64(user:pass) for basic or
// user:realm:md5(user:realm:pass) for digest // user:realm:md5(user:realm:pass) for digest
bool authenticate(const char* hash); bool authenticate(const char* hash);
bool authenticate(const char* username, const char* password, const char* realm = NULL, bool passwordIsHash = false); bool authenticate(const char* username, const char* credentials, const char* realm = NULL, bool isHash = false);
void requestAuthentication(const char* realm = NULL, bool isDigest = true); void requestAuthentication(const char* realm = nullptr, bool isDigest = true) { requestAuthentication(isDigest ? AsyncAuthType::AUTH_DIGEST : AsyncAuthType::AUTH_BASIC, realm); }
void requestAuthentication(AsyncAuthType method, const char* realm = nullptr, const char* _authFailMsg = nullptr);
void setHandler(AsyncWebHandler* handler) { _handler = handler; } void setHandler(AsyncWebHandler* handler) { _handler = handler; }
@@ -554,28 +564,31 @@ class AsyncMiddlewareChain {
// AuthenticationMiddleware is a middleware that checks if the request is authenticated // AuthenticationMiddleware is a middleware that checks if the request is authenticated
class AuthenticationMiddleware : public AsyncMiddleware { class AuthenticationMiddleware : public AsyncMiddleware {
public: public:
typedef enum { void setUsername(const char* username);
AUTH_NONE, void setPassword(const char* password);
AUTH_BASIC, void setPasswordHash(const char* hash);
AUTH_DIGEST
} AuthType;
void setUsername(const char* username) { _username = username; }
void setPassword(const char* password) { _password = password; }
void setRealm(const char* realm) { _realm = realm; } void setRealm(const char* realm) { _realm = realm; }
void setPasswordIsHash(bool passwordIsHash) { _hash = passwordIsHash; } void setAuthFailureMessage(const char* message) { _authFailMsg = message; }
void setAuthType(AuthType authType) { _authType = authType; } void setAuthType(AsyncAuthType authMethod) { _authMethod = authMethod; }
bool allowed(AsyncWebServerRequest* request) { return _authType == AUTH_NONE || !_username.length() || !_password.length() || request->authenticate(_username.c_str(), _password.c_str(), _realm, _hash); } // precompute and store the hash value based on the username, realm, and authMethod
// returns true if the hash was successfully generated and replaced
bool generateHash();
void run(AsyncWebServerRequest* request, ArMiddlewareNext next) { return allowed(request) ? next() : request->requestAuthentication(_realm, _authType == AUTH_DIGEST); } bool allowed(AsyncWebServerRequest* request);
void run(AsyncWebServerRequest* request, ArMiddlewareNext next);
private: private:
String _username; String _username;
String _password; String _credentials;
const char* _realm = nullptr;
bool _hash = false; bool _hash = false;
AuthType _authType = AUTH_DIGEST;
String _realm = asyncsrv::T_LOGIN_REQ;
AsyncAuthType _authMethod = AsyncAuthType::AUTH_NONE;
String _authFailMsg;
bool _hasCreds = false;
}; };
using ArAuthorizeFunction = std::function<bool(AsyncWebServerRequest* request)>; using ArAuthorizeFunction = std::function<bool(AsyncWebServerRequest* request)>;

View File

@@ -1,3 +1,4 @@
#include "WebAuthentication.h"
#include <ESPAsyncWebServer.h> #include <ESPAsyncWebServer.h>
AsyncMiddlewareChain::~AsyncMiddlewareChain() { AsyncMiddlewareChain::~AsyncMiddlewareChain() {
@@ -52,6 +53,62 @@ void AsyncMiddlewareChain::_runChain(AsyncWebServerRequest* request, ArMiddlewar
return next(); return next();
} }
void AuthenticationMiddleware::setUsername(const char* username) {
_username = username;
_hasCreds = _username.length() && _credentials.length();
}
void AuthenticationMiddleware::setPassword(const char* password) {
_credentials = password;
_hash = false;
_hasCreds = _username.length() && _credentials.length();
}
void AuthenticationMiddleware::setPasswordHash(const char* hash) {
_credentials = hash;
_hash = true;
_hasCreds = _username.length() && _credentials.length();
}
bool AuthenticationMiddleware::generateHash() {
// ensure we have all the necessary data
if (!_hasCreds)
return false;
// if we already have a hash, do nothing
if (_hash)
return false;
switch (_authMethod) {
case AsyncAuthType::AUTH_DIGEST:
_credentials = generateDigestHash(_username.c_str(), _credentials.c_str(), _realm.c_str());
_hash = true;
return true;
case AsyncAuthType::AUTH_BASIC:
_credentials = generateBasicHash(_username.c_str(), _credentials.c_str());
_hash = true;
return true;
default:
return false;
}
}
bool AuthenticationMiddleware::allowed(AsyncWebServerRequest* request) {
if (_authMethod == AsyncAuthType::AUTH_NONE)
return true;
if (!_hasCreds)
return false;
return request->authenticate(_username.c_str(), _credentials.c_str(), _realm.c_str(), _hash);
}
void AuthenticationMiddleware::run(AsyncWebServerRequest* request, ArMiddlewareNext next) {
return allowed(request) ? next() : request->requestAuthentication(_authMethod, _realm.c_str(), _authFailMsg.c_str());
}
void HeaderFreeMiddleware::run(AsyncWebServerRequest* request, ArMiddlewareNext next) { void HeaderFreeMiddleware::run(AsyncWebServerRequest* request, ArMiddlewareNext next) {
std::vector<const char*> reqHeaders; std::vector<const char*> reqHeaders;
request->getHeaderNames(reqHeaders); request->getHeaderNames(reqHeaders);

View File

@@ -34,36 +34,34 @@ using namespace asyncsrv;
bool checkBasicAuthentication(const char* hash, const char* username, const char* password) { bool checkBasicAuthentication(const char* hash, const char* username, const char* password) {
if (username == NULL || password == NULL || hash == NULL) if (username == NULL || password == NULL || hash == NULL)
return false; return false;
return generateBasicHash(username, password).equalsIgnoreCase(hash);
}
String generateBasicHash(const char* username, const char* password) {
if (username == NULL || password == NULL)
return emptyString;
size_t toencodeLen = strlen(username) + strlen(password) + 1; size_t toencodeLen = strlen(username) + strlen(password) + 1;
size_t encodedLen = base64_encode_expected_len(toencodeLen);
if (strlen(hash) != encodedLen)
// Fix from https://github.com/me-no-dev/ESPAsyncWebServer/issues/667
#ifdef ARDUINO_ARCH_ESP32
if (strlen(hash) != encodedLen)
#else
if (strlen(hash) != encodedLen - 1)
#endif
return false;
char* toencode = new char[toencodeLen + 1]; char* toencode = new char[toencodeLen + 1];
if (toencode == NULL) { if (toencode == NULL) {
return false; return emptyString;
} }
char* encoded = new char[base64_encode_expected_len(toencodeLen) + 1]; char* encoded = new char[base64_encode_expected_len(toencodeLen) + 1];
if (encoded == NULL) { if (encoded == NULL) {
delete[] toencode; delete[] toencode;
return false; return emptyString;
} }
sprintf_P(toencode, PSTR("%s:%s"), username, password); sprintf_P(toencode, PSTR("%s:%s"), username, password);
if (base64_encode_chars(toencode, toencodeLen, encoded) > 0 && memcmp(hash, encoded, encodedLen) == 0) { if (base64_encode_chars(toencode, toencodeLen, encoded) > 0) {
String res = String(encoded);
delete[] toencode; delete[] toencode;
delete[] encoded; delete[] encoded;
return true; return res;
} }
delete[] toencode; delete[] toencode;
delete[] encoded; delete[] encoded;
return false; return emptyString;
} }
static bool getMD5(uint8_t* data, uint16_t len, char* output) { // 33 bytes or more static bool getMD5(uint8_t* data, uint16_t len, char* output) { // 33 bytes or more
@@ -94,7 +92,7 @@ static bool getMD5(uint8_t* data, uint16_t len, char* output) { // 33 bytes or m
return true; return true;
} }
static String genRandomMD5() { String genRandomMD5() {
#ifdef ESP8266 #ifdef ESP8266
uint32_t r = RANDOM_REG32; uint32_t r = RANDOM_REG32;
#else #else
@@ -122,31 +120,21 @@ String generateDigestHash(const char* username, const char* password, const char
return emptyString; return emptyString;
} }
char* out = (char*)malloc(33); char* out = (char*)malloc(33);
String res = String(username);
res += ':'; String in;
res.concat(realm); in.reserve(strlen(username) + strlen(realm) + strlen(password) + 2);
res += ':'; in.concat(username);
String in = res; in.concat(':');
in.concat(realm);
in.concat(':');
in.concat(password); in.concat(password);
if (out == NULL || !getMD5((uint8_t*)(in.c_str()), in.length(), out)) if (out == NULL || !getMD5((uint8_t*)(in.c_str()), in.length(), out))
return emptyString; return emptyString;
res.concat(out);
free(out);
return res;
}
String requestDigestAuthentication(const char* realm) { in = String(out);
String header(T_realm__); free(out);
if (realm == NULL) return in;
header.concat(T_asyncesp);
else
header.concat(realm);
header.concat(T_auth_nonce);
header.concat(genRandomMD5());
header.concat(T__opaque);
header.concat(genRandomMD5());
header += (char)0x22; // '"'
return header;
} }
#ifndef ESP8266 #ifndef ESP8266
@@ -235,9 +223,9 @@ bool checkDigestAuthentication(const char* header, const __FlashStringHelper* me
} }
} while (nextBreak > 0); } while (nextBreak > 0);
String ha1 = (passwordIsHash) ? String(password) : stringMD5(myUsername + ':' + myRealm + ':' + password); String ha1 = passwordIsHash ? password : stringMD5(myUsername + ':' + myRealm + ':' + password).c_str();
String ha2 = String(method) + ':' + myUri; String ha2 = stringMD5(String(method) + ':' + myUri);
String response = ha1 + ':' + myNonce + ':' + myNc + ':' + myCnonce + ':' + myQop + ':' + stringMD5(ha2); String response = ha1 + ':' + myNonce + ':' + myNc + ':' + myCnonce + ':' + myQop + ':' + ha2;
if (myResponse.equals(stringMD5(response))) { if (myResponse.equals(stringMD5(response))) {
// os_printf("AUTH SUCCESS\n"); // os_printf("AUTH SUCCESS\n");

View File

@@ -25,7 +25,6 @@
#include "Arduino.h" #include "Arduino.h"
bool checkBasicAuthentication(const char* header, const char* username, const char* password); bool checkBasicAuthentication(const char* header, const char* username, const char* password);
String requestDigestAuthentication(const char* realm);
bool checkDigestAuthentication(const char* header, const char* method, const char* username, const char* password, const char* realm, bool passwordIsHash, const char* nonce, const char* opaque, const char* uri); bool checkDigestAuthentication(const char* header, const char* method, const char* username, const char* password, const char* realm, bool passwordIsHash, const char* nonce, const char* opaque, const char* uri);
@@ -36,4 +35,8 @@ bool checkDigestAuthentication(const char* header, const __FlashStringHelper* me
// for storing hashed versions on the device that can be authenticated against // for storing hashed versions on the device that can be authenticated against
String generateDigestHash(const char* username, const char* password, const char* realm); String generateDigestHash(const char* username, const char* password, const char* realm);
String generateBasicHash(const char* username, const char* password);
String genRandomMD5();
#endif #endif

View File

@@ -35,7 +35,7 @@ enum { PARSE_REQ_START,
PARSE_REQ_FAIL }; PARSE_REQ_FAIL };
AsyncWebServerRequest::AsyncWebServerRequest(AsyncWebServer* s, AsyncClient* c) AsyncWebServerRequest::AsyncWebServerRequest(AsyncWebServer* s, AsyncClient* c)
: _client(c), _server(s), _handler(NULL), _response(NULL), _temp(), _parseState(0), _version(0), _method(HTTP_ANY), _url(), _host(), _contentType(), _boundary(), _authorization(), _reqconntype(RCT_HTTP), _isDigest(false), _isMultipart(false), _isPlainPost(false), _expectingContinue(false), _contentLength(0), _parsedLength(0), _multiParseState(0), _boundaryPosition(0), _itemStartIndex(0), _itemSize(0), _itemName(), _itemFilename(), _itemType(), _itemValue(), _itemBuffer(0), _itemBufferIndex(0), _itemIsFile(false), _tempObject(NULL) { : _client(c), _server(s), _handler(NULL), _response(NULL), _temp(), _parseState(0), _version(0), _method(HTTP_ANY), _url(), _host(), _contentType(), _boundary(), _authorization(), _reqconntype(RCT_HTTP), _authMethod(AsyncAuthType::AUTH_NONE), _isMultipart(false), _isPlainPost(false), _expectingContinue(false), _contentLength(0), _parsedLength(0), _multiParseState(0), _boundaryPosition(0), _itemStartIndex(0), _itemSize(0), _itemName(), _itemFilename(), _itemType(), _itemValue(), _itemBuffer(0), _itemBufferIndex(0), _itemIsFile(false), _tempObject(NULL) {
c->onError([](void* r, AsyncClient* c, int8_t error) { (void)c; AsyncWebServerRequest *req = (AsyncWebServerRequest*)r; req->_onError(error); }, this); c->onError([](void* r, AsyncClient* c, int8_t error) { (void)c; AsyncWebServerRequest *req = (AsyncWebServerRequest*)r; req->_onError(error); }, this);
c->onAck([](void* r, AsyncClient* c, size_t len, uint32_t time) { (void)c; AsyncWebServerRequest *req = (AsyncWebServerRequest*)r; req->_onAck(len, time); }, this); c->onAck([](void* r, AsyncClient* c, size_t len, uint32_t time) { (void)c; AsyncWebServerRequest *req = (AsyncWebServerRequest*)r; req->_onAck(len, time); }, this);
c->onDisconnect([](void* r, AsyncClient* c) { AsyncWebServerRequest *req = (AsyncWebServerRequest*)r; req->_onDisconnect(); delete c; }, this); c->onDisconnect([](void* r, AsyncClient* c) { AsyncWebServerRequest *req = (AsyncWebServerRequest*)r; req->_onDisconnect(); delete c; }, this);
@@ -285,9 +285,16 @@ bool AsyncWebServerRequest::_parseReqHeader() {
} else if (name.equalsIgnoreCase(T_AUTH)) { } else if (name.equalsIgnoreCase(T_AUTH)) {
if (value.length() > 5 && value.substring(0, 5).equalsIgnoreCase(T_BASIC)) { if (value.length() > 5 && value.substring(0, 5).equalsIgnoreCase(T_BASIC)) {
_authorization = value.substring(6); _authorization = value.substring(6);
_authMethod = AsyncAuthType::AUTH_BASIC;
} else if (value.length() > 6 && value.substring(0, 6).equalsIgnoreCase(T_DIGEST)) { } else if (value.length() > 6 && value.substring(0, 6).equalsIgnoreCase(T_DIGEST)) {
_isDigest = true; _authMethod = AsyncAuthType::AUTH_DIGEST;
_authorization = value.substring(7); _authorization = value.substring(7);
} else if (value.length() > 6 && value.substring(0, 6).equalsIgnoreCase(T_BEARER)) {
_authMethod = AsyncAuthType::AUTH_BEARER;
_authorization = value.substring(7);
} else {
_authorization = value;
_authMethod = AsyncAuthType::AUTH_OTHER;
} }
} else { } else {
if (name.equalsIgnoreCase(T_UPGRADE) && value.equalsIgnoreCase(T_WS)) { if (name.equalsIgnoreCase(T_UPGRADE) && value.equalsIgnoreCase(T_WS)) {
@@ -774,7 +781,7 @@ void AsyncWebServerRequest::redirect(const char* url, int code) {
bool AsyncWebServerRequest::authenticate(const char* username, const char* password, const char* realm, bool passwordIsHash) { bool AsyncWebServerRequest::authenticate(const char* username, const char* password, const char* realm, bool passwordIsHash) {
if (_authorization.length()) { if (_authorization.length()) {
if (_isDigest) if (_authMethod == AsyncAuthType::AUTH_DIGEST)
return checkDigestAuthentication(_authorization.c_str(), methodToString(), username, password, realm, passwordIsHash, NULL, NULL, NULL); return checkDigestAuthentication(_authorization.c_str(), methodToString(), username, password, realm, passwordIsHash, NULL, NULL, NULL);
else if (!passwordIsHash) else if (!passwordIsHash)
return checkBasicAuthentication(_authorization.c_str(), username, password); return checkBasicAuthentication(_authorization.c_str(), username, password);
@@ -788,7 +795,7 @@ bool AsyncWebServerRequest::authenticate(const char* hash) {
if (!_authorization.length() || hash == NULL) if (!_authorization.length() || hash == NULL)
return false; return false;
if (_isDigest) { if (_authMethod == AsyncAuthType::AUTH_DIGEST) {
String hStr = String(hash); String hStr = String(hash);
int separator = hStr.indexOf(':'); int separator = hStr.indexOf(':');
if (separator <= 0) if (separator <= 0)
@@ -803,23 +810,45 @@ bool AsyncWebServerRequest::authenticate(const char* hash) {
return checkDigestAuthentication(_authorization.c_str(), methodToString(), username.c_str(), hStr.c_str(), realm.c_str(), true, NULL, NULL, NULL); return checkDigestAuthentication(_authorization.c_str(), methodToString(), username.c_str(), hStr.c_str(), realm.c_str(), true, NULL, NULL, NULL);
} }
// Basic Auth, Bearer Auth, or other
return (_authorization.equals(hash)); return (_authorization.equals(hash));
} }
void AsyncWebServerRequest::requestAuthentication(const char* realm, bool isDigest) { void AsyncWebServerRequest::requestAuthentication(AsyncAuthType method, const char* realm, const char* _authFailMsg) {
AsyncWebServerResponse* r = beginResponse(401); if (!realm)
if (!isDigest && realm == NULL) { realm = T_LOGIN_REQ;
r->addHeader(T_WWW_AUTH, T_BASIC_REALM_LOGIN_REQ);
} else if (!isDigest) { AsyncWebServerResponse* r = _authFailMsg ? beginResponse(401, T_text_html, _authFailMsg) : beginResponse(401);
String header(T_BASIC_REALM);
header.concat(realm); switch (method) {
header += '"'; case AsyncAuthType::AUTH_BASIC: {
r->addHeader(T_WWW_AUTH, header.c_str()); String header;
} else { header.reserve(strlen(T_BASIC_REALM) + strlen(realm) + 1);
String header(T_DIGEST_); header.concat(T_BASIC_REALM);
header.concat(requestDigestAuthentication(realm)); header.concat(realm);
r->addHeader(T_WWW_AUTH, header.c_str()); header.concat('"');
r->addHeader(T_WWW_AUTH, header.c_str());
break;
}
case AsyncAuthType::AUTH_DIGEST: {
constexpr size_t len = strlen(T_DIGEST_) + strlen(T_realm__) + strlen(T_auth_nonce) + 32 + strlen(T__opaque) + 32 + 1;
String header;
header.reserve(len + strlen(realm));
header.concat(T_DIGEST_);
header.concat(T_realm__);
header.concat(realm);
header.concat(T_auth_nonce);
header.concat(genRandomMD5());
header.concat(T__opaque);
header.concat(genRandomMD5());
header.concat((char)0x22); // '"'
r->addHeader(T_WWW_AUTH, header.c_str());
break;
}
default:
break;
} }
send(r); send(r);
} }

View File

@@ -12,7 +12,7 @@ static constexpr const char* T_app_xform_urlencoded = "application/x-www-form-ur
static constexpr const char* T_AUTH = "Authorization"; static constexpr const char* T_AUTH = "Authorization";
static constexpr const char* T_BASIC = "Basic"; static constexpr const char* T_BASIC = "Basic";
static constexpr const char* T_BASIC_REALM = "Basic realm=\""; static constexpr const char* T_BASIC_REALM = "Basic realm=\"";
static constexpr const char* T_BASIC_REALM_LOGIN_REQ = "Basic realm=\"Login Required\""; static constexpr const char* T_LOGIN_REQ = "Login Required";
static constexpr const char* T_BODY = "body"; static constexpr const char* T_BODY = "body";
static constexpr const char* T_Cache_Control = "Cache-Control"; static constexpr const char* T_Cache_Control = "Cache-Control";
static constexpr const char* T_chunked = "chunked"; static constexpr const char* T_chunked = "chunked";
@@ -25,6 +25,7 @@ static constexpr const char* T_Content_Type = "Content-Type";
static constexpr const char* T_Cookie = "Cookie"; static constexpr const char* T_Cookie = "Cookie";
static constexpr const char* T_DIGEST = "Digest"; static constexpr const char* T_DIGEST = "Digest";
static constexpr const char* T_DIGEST_ = "Digest "; static constexpr const char* T_DIGEST_ = "Digest ";
static constexpr const char* T_BEARER = "Bearer";
static constexpr const char* T_ETag = "ETag"; static constexpr const char* T_ETag = "ETag";
static constexpr const char* T_EXPECT = "Expect"; static constexpr const char* T_EXPECT = "Expect";
static constexpr const char* T_HTTP_1_0 = "HTTP/1.0"; static constexpr const char* T_HTTP_1_0 = "HTTP/1.0";
@@ -149,7 +150,6 @@ static constexpr const char* T_HTTP_CODE_ANY = "Unknown code";
// other // other
static constexpr const char* T__opaque = "\", opaque=\""; static constexpr const char* T__opaque = "\", opaque=\"";
static constexpr const char* T_13 = "13"; static constexpr const char* T_13 = "13";
static constexpr const char* T_asyncesp = "asyncesp";
static constexpr const char* T_auth_nonce = "\", qop=\"auth\", nonce=\""; static constexpr const char* T_auth_nonce = "\", qop=\"auth\", nonce=\"";
static constexpr const char* T_cnonce = "cnonce"; static constexpr const char* T_cnonce = "cnonce";
static constexpr const char* T_data_ = "data: "; static constexpr const char* T_data_ = "data: ";
@@ -182,7 +182,7 @@ static const char T_app_xform_urlencoded[] PROGMEM = "application/x-www-form-url
static const char T_AUTH[] PROGMEM = "Authorization"; static const char T_AUTH[] PROGMEM = "Authorization";
static const char T_BASIC[] PROGMEM = "Basic"; static const char T_BASIC[] PROGMEM = "Basic";
static const char T_BASIC_REALM[] PROGMEM = "Basic realm=\""; static const char T_BASIC_REALM[] PROGMEM = "Basic realm=\"";
static const char T_BASIC_REALM_LOGIN_REQ[] PROGMEM = "Basic realm=\"Login Required\""; static const char T_LOGIN_REQ[] PROGMEM = "Login Required";
static const char T_BODY[] PROGMEM = "body"; static const char T_BODY[] PROGMEM = "body";
static const char T_Cache_Control[] PROGMEM = "Cache-Control"; static const char T_Cache_Control[] PROGMEM = "Cache-Control";
static const char T_chunked[] PROGMEM = "chunked"; static const char T_chunked[] PROGMEM = "chunked";
@@ -195,6 +195,7 @@ static const char T_Content_Type[] PROGMEM = "Content-Type";
static const char T_Cookie[] PROGMEM = "Cookie"; static const char T_Cookie[] PROGMEM = "Cookie";
static const char T_DIGEST[] PROGMEM = "Digest"; static const char T_DIGEST[] PROGMEM = "Digest";
static const char T_DIGEST_[] PROGMEM = "Digest "; static const char T_DIGEST_[] PROGMEM = "Digest ";
static const char T_BEARER[] PROGMEM = "Bearer";
static const char T_ETag[] PROGMEM = "ETag"; static const char T_ETag[] PROGMEM = "ETag";
static const char T_EXPECT[] PROGMEM = "Expect"; static const char T_EXPECT[] PROGMEM = "Expect";
static const char T_HTTP_1_0[] PROGMEM = "HTTP/1.0"; static const char T_HTTP_1_0[] PROGMEM = "HTTP/1.0";
@@ -319,7 +320,6 @@ static const char T_HTTP_CODE_ANY[] PROGMEM = "Unknown code";
// other // other
static const char T__opaque[] PROGMEM = "\", opaque=\""; static const char T__opaque[] PROGMEM = "\", opaque=\"";
static const char T_13[] PROGMEM = "13"; static const char T_13[] PROGMEM = "13";
static const char T_asyncesp[] PROGMEM = "asyncesp";
static const char T_auth_nonce[] PROGMEM = "\", qop=\"auth\", nonce=\""; static const char T_auth_nonce[] PROGMEM = "\", qop=\"auth\", nonce=\"";
static const char T_cnonce[] PROGMEM = "cnonce"; static const char T_cnonce[] PROGMEM = "cnonce";
static const char T_data_[] PROGMEM = "data: "; static const char T_data_[] PROGMEM = "data: ";