diff --git a/examples/esp8266/WebSocketClientSSLWithCA/WebSocketClientSSLWithCA.ino b/examples/esp8266/WebSocketClientSSLWithCA/WebSocketClientSSLWithCA.ino index 72ffae9..214f5e6 100644 --- a/examples/esp8266/WebSocketClientSSLWithCA/WebSocketClientSSLWithCA.ino +++ b/examples/esp8266/WebSocketClientSSLWithCA/WebSocketClientSSLWithCA.ino @@ -90,6 +90,10 @@ void setup() { delay(100); } + //When using BearSSL, client certificate and private key can be set: + //webSocket.setSSLClientCertKey(clientCert, clientPrivateKey); + //clientCert and clientPrivateKey can be of types (const char *, const char *) , or of types (BearSSL::X509List, BearSSL::PrivateKey) + webSocket.beginSslWithCA("echo.websocket.org", 443, "/", ENDPOINT_CA_CERT); webSocket.onEvent(webSocketEvent); } diff --git a/src/SocketIOclient.cpp b/src/SocketIOclient.cpp index 6b14319..4233e2c 100644 --- a/src/SocketIOclient.cpp +++ b/src/SocketIOclient.cpp @@ -24,7 +24,37 @@ void SocketIOclient::begin(String host, uint16_t port, String url, String protoc WebSocketsClient::beginSocketIO(host, port, url, protocol); WebSocketsClient::enableHeartbeat(60 * 1000, 90 * 1000, 5); } +#if defined(HAS_SSL) +void SocketIOclient::beginSSL(const char * host, uint16_t port, const char * url, const char * protocol) { + WebSocketsClient::beginSocketIOSSL(host, port, url, protocol); + WebSocketsClient::enableHeartbeat(60 * 1000, 90 * 1000, 5); +} +void SocketIOclient::beginSSL(String host, uint16_t port, String url, String protocol) { + WebSocketsClient::beginSocketIOSSL(host, port, url, protocol); + WebSocketsClient::enableHeartbeat(60 * 1000, 90 * 1000, 5); +} +#if !defined(SSL_AXTLS) +void SocketIOclient::beginSSLWithCA(const char * host, uint16_t port, const char * url, const char * CA_cert, const char * protocol) { + WebSocketsClient::beginSocketIOSSLWithCA(host, port, url, CA_cert, protocol); + WebSocketsClient::enableHeartbeat(60 * 1000, 90 * 1000, 5); +} + +void SocketIOclient::beginSSLWithCA(const char * host, uint16_t port, const char * url, BearSSL::X509List * CA_cert, const char * protocol) { + WebSocketsClient::beginSocketIOSSLWithCA(host, port, url, CA_cert, protocol); + WebSocketsClient::enableHeartbeat(60 * 1000, 90 * 1000, 5); +} + +void SocketIOclient::setSSLClientCertKey(const char * clientCert, const char * clientPrivateKey) { + WebSocketsClient::setSSLClientCertKey(clientCert, clientPrivateKey); +} + +void SocketIOclient::setSSLClientCertKey(BearSSL::X509List * clientCert, BearSSL::PrivateKey * clientPrivateKey) { + WebSocketsClient::setSSLClientCertKey(clientCert, clientPrivateKey); +} + +#endif +#endif /** * set callback function * @param cbEvent SocketIOclientEvent diff --git a/src/SocketIOclient.h b/src/SocketIOclient.h index ca0a69b..cf674cb 100644 --- a/src/SocketIOclient.h +++ b/src/SocketIOclient.h @@ -49,6 +49,16 @@ class SocketIOclient : protected WebSocketsClient { void begin(const char * host, uint16_t port, const char * url = "/socket.io/?EIO=3", const char * protocol = "arduino"); void begin(String host, uint16_t port, String url = "/socket.io/?EIO=3", String protocol = "arduino"); +#ifdef HAS_SSL + void beginSSL(const char * host, uint16_t port, const char * url = "/socket.io/?EIO=3", const char * protocol = "arduino"); + void beginSSL(String host, uint16_t port, String url = "/socket.io/?EIO=3", String protocol = "arduino"); +#ifndef SSL_AXTLS + void beginSSLWithCA(const char * host, uint16_t port, const char * url = "/socket.io/?EIO=3", const char * CA_cert = NULL, const char * protocol = "arduino"); + void beginSSLWithCA(const char * host, uint16_t port, const char * url = "/socket.io/?EIO=3", BearSSL::X509List * CA_cert = NULL, const char * protocol = "arduino"); + void setSSLClientCertKey(const char * clientCert = NULL, const char * clientPrivateKey = NULL); + void setSSLClientCertKey(BearSSL::X509List * clientCert = NULL, BearSSL::PrivateKey * clientPrivateKey = NULL); +#endif +#endif bool isConnected(void); void onEvent(SocketIOclientEvent cbEvent); diff --git a/src/WebSocketsClient.cpp b/src/WebSocketsClient.cpp index 038ff94..6ec0af1 100644 --- a/src/WebSocketsClient.cpp +++ b/src/WebSocketsClient.cpp @@ -122,12 +122,6 @@ void WebSocketsClient::beginSSL(const char * host, uint16_t port, const char * u _fingerprint = fingerprint; _CA_cert = NULL; } -void WebSocketsClient::beginSslWithCA(const char * host, uint16_t port, const char * url, const char * CA_cert, const char * protocol) { - begin(host, port, url, protocol); - _client.isSSL = true; - _fingerprint = SSL_FINGERPRINT_NULL; - _CA_cert = new BearSSL::X509List(CA_cert); -} void WebSocketsClient::beginSslWithCA(const char * host, uint16_t port, const char * url, BearSSL::X509List * CA_cert, const char * protocol) { begin(host, port, url, protocol); @@ -135,6 +129,20 @@ void WebSocketsClient::beginSslWithCA(const char * host, uint16_t port, const ch _fingerprint = SSL_FINGERPRINT_NULL; _CA_cert = CA_cert; } + +void WebSocketsClient::beginSslWithCA(const char * host, uint16_t port, const char * url, const char * CA_cert, const char * protocol) { + beginSslWithCA(host, port, url, new BearSSL::X509List(CA_cert), protocol); +} + +void WebSocketsClient::setSSLClientCertKey(BearSSL::X509List * clientCert, BearSSL::PrivateKey * clientPrivateKey) { + _client_cert = clientCert; + _client_key = clientPrivateKey; +} + +void WebSocketsClient::setSSLClientCertKey(const char * clientCert, const char * clientPrivateKey) { + setSSLClientCertKey(new BearSSL::X509List(clientCert), new BearSSL::PrivateKey(clientPrivateKey)); +} + #endif // SSL_AXTLS #endif // HAS_SSL @@ -148,7 +156,7 @@ void WebSocketsClient::beginSocketIO(String host, uint16_t port, String url, Str } #if defined(HAS_SSL) -void WebSocketsClient::beginSocketIOSSL(const char * host, uint16_t port, const char * url, const char * protocol) { +void WebSocketsClient::beginSocketIOSSL(const char * host, uint16_t port, const char * url, const char * protocol) { begin(host, port, url, protocol); _client.isSocketIO = true; _client.isSSL = true; @@ -159,17 +167,29 @@ void WebSocketsClient::beginSocketIOSSL(String host, uint16_t port, String url, beginSocketIOSSL(host.c_str(), port, url.c_str(), protocol.c_str()); } +#if defined(SSL_BARESSL) +void WebSocketsClient::beginSocketIOSSLWithCA(const char * host, uint16_t port, const char * url, BearSSL::X509List * CA_cert, const char * protocol) { + begin(host, port, url, protocol); + _client.isSocketIO = true; + _client.isSSL = true; + _fingerprint = SSL_FINGERPRINT_NULL; + _CA_cert = CA_cert; +} +#endif + void WebSocketsClient::beginSocketIOSSLWithCA(const char * host, uint16_t port, const char * url, const char * CA_cert, const char * protocol) { begin(host, port, url, protocol); _client.isSocketIO = true; _client.isSSL = true; _fingerprint = SSL_FINGERPRINT_NULL; -#if defined(SSL_AXTLS) - _CA_cert = CA_cert; -#else +#if defined(SSL_BARESSL) _CA_cert = new BearSSL::X509List(CA_cert); +#else + _CA_cert = CA_cert; #endif } + + #endif #if(WEBSOCKETS_NETWORK_TYPE != NETWORK_ESP8266_ASYNC) @@ -213,6 +233,10 @@ void WebSocketsClient::loop(void) { _client.ssl->setFingerprint(_fingerprint); } else { _client.ssl->setInsecure(); + } + if(_client_cert && _client_key) { + _client.ssl->setClientRSACert(_client_cert, _client_key); + DEBUG_WEBSOCKETS("[WS-Client] setting client certificate and key"); #endif } } else { diff --git a/src/WebSocketsClient.h b/src/WebSocketsClient.h index f99dde7..cc9a0c7 100644 --- a/src/WebSocketsClient.h +++ b/src/WebSocketsClient.h @@ -49,6 +49,8 @@ class WebSocketsClient : protected WebSockets { #else void beginSSL(const char * host, uint16_t port, const char * url = "/", const uint8_t * fingerprint = NULL, const char * protocol = "arduino"); void beginSslWithCA(const char * host, uint16_t port, const char * url = "/", BearSSL::X509List * CA_cert = NULL, const char * protocol = "arduino"); + void setSSLClientCertKey(BearSSL::X509List * clientCert = NULL, BearSSL::PrivateKey * clientPrivateKey = NULL); + void setSSLClientCertKey(const char * clientCert = NULL, const char * clientPrivateKey = NULL); #endif void beginSslWithCA(const char * host, uint16_t port, const char * url = "/", const char * CA_cert = NULL, const char * protocol = "arduino"); #endif @@ -59,9 +61,13 @@ class WebSocketsClient : protected WebSockets { #if defined(HAS_SSL) void beginSocketIOSSL(const char * host, uint16_t port, const char * url = "/socket.io/?EIO=3", const char * protocol = "arduino"); void beginSocketIOSSL(String host, uint16_t port, String url = "/socket.io/?EIO=3", String protocol = "arduino"); - void beginSocketIOSSLWithCA(const char * host, uint16_t port, const char * url = "/socket.io/?EIO=3", const char * CA_cert = NULL, const char * protocol = "arduino"); -#endif + void beginSocketIOSSLWithCA(const char * host, uint16_t port, const char * url = "/socket.io/?EIO=3", const char * CA_cert = NULL, const char * protocol = "arduino"); +#if defined(SSL_BARESSL) + void beginSocketIOSSLWithCA(const char * host, uint16_t port, const char * url = "/socket.io/?EIO=3", BearSSL::X509List * CA_cert = NULL, const char * protocol = "arduino"); +#endif +#endif + #if(WEBSOCKETS_NETWORK_TYPE != NETWORK_ESP8266_ASYNC) void loop(void); #else @@ -110,6 +116,8 @@ class WebSocketsClient : protected WebSockets { #else const uint8_t * _fingerprint; BearSSL::X509List * _CA_cert; + BearSSL::X509List * _client_cert; + BearSSL::PrivateKey * _client_key; #define SSL_FINGERPRINT_NULL NULL #endif