diff --git a/src/WebSockets.h b/src/WebSockets.h index 39a1e0b..5951023 100644 --- a/src/WebSockets.h +++ b/src/WebSockets.h @@ -57,6 +57,7 @@ #if defined(ESP8266) || defined(ESP32) +#define HAS_SSL #define WEBSOCKETS_MAX_DATA_SIZE (15*1024) #define WEBSOCKETS_USE_BIG_MEM #define GET_FREE_HEAP ESP.getFreeHeap() diff --git a/src/WebSocketsClient.cpp b/src/WebSocketsClient.cpp index 2d96358..eb46c12 100644 --- a/src/WebSocketsClient.cpp +++ b/src/WebSocketsClient.cpp @@ -42,14 +42,15 @@ WebSocketsClient::~WebSocketsClient() { void WebSocketsClient::begin(const char *host, uint16_t port, const char * url, const char * protocol) { _host = host; _port = port; -#if (WEBSOCKETS_NETWORK_TYPE == NETWORK_ESP8266) || (WEBSOCKETS_NETWORK_TYPE == NETWORK_ESP32) +#if defined(HAS_SSL) _fingerprint = ""; + _CA_cert = NULL; #endif _client.num = 0; _client.status = WSC_NOT_CONNECTED; _client.tcp = NULL; -#if (WEBSOCKETS_NETWORK_TYPE == NETWORK_ESP8266) || (WEBSOCKETS_NETWORK_TYPE == NETWORK_ESP32) +#if defined(HAS_SSL) _client.isSSL = false; _client.ssl = NULL; #endif @@ -92,16 +93,24 @@ void WebSocketsClient::begin(IPAddress host, uint16_t port, const char * url, co return begin(host.toString().c_str(), port, url, protocol); } -#if (WEBSOCKETS_NETWORK_TYPE == NETWORK_ESP8266) || (WEBSOCKETS_NETWORK_TYPE == NETWORK_ESP32) +#if defined(HAS_SSL) void WebSocketsClient::beginSSL(const char *host, uint16_t port, const char * url, const char * fingerprint, const char * protocol) { begin(host, port, url, protocol); _client.isSSL = true; _fingerprint = fingerprint; + _CA_cert = NULL; } void WebSocketsClient::beginSSL(String host, uint16_t port, String url, String fingerprint, String protocol) { beginSSL(host.c_str(), port, url.c_str(), fingerprint.c_str(), protocol.c_str()); } + +void WebSocketsClient::beginSslWithCA(const char *host, uint16_t port, const char * url, const char * CA_cert, const char * protocol) { + begin(host, port, url, protocol); + _client.isSSL = true; + _fingerprint = ""; + _CA_cert = CA_cert; +} #endif void WebSocketsClient::beginSocketIO(const char *host, uint16_t port, const char * url, const char * protocol) { @@ -113,7 +122,7 @@ void WebSocketsClient::beginSocketIO(String host, uint16_t port, String url, Str beginSocketIO(host.c_str(), port, url.c_str(), protocol.c_str()); } -#if (WEBSOCKETS_NETWORK_TYPE == NETWORK_ESP8266) || (WEBSOCKETS_NETWORK_TYPE == NETWORK_ESP32) +#if defined(HAS_SSL) void WebSocketsClient::beginSocketIOSSL(const char *host, uint16_t port, const char * url, const char * protocol) { begin(host, port, url, protocol); _client.isSocketIO = true; @@ -124,6 +133,14 @@ void WebSocketsClient::beginSocketIOSSL(const char *host, uint16_t port, const c void WebSocketsClient::beginSocketIOSSL(String host, uint16_t port, String url, String protocol) { beginSocketIOSSL(host.c_str(), port, url.c_str(), protocol.c_str()); } + +void WebSocketsClient::beginSocketIOSSLWithCA(const char *host, uint16_t port, const char * url, const char * CA_cert, const char * protocol) { + begin(host, port, url, protocol); + _client.isSocketIO = true; + _client.isSSL = true; + _fingerprint = ""; + _CA_cert = CA_cert; +} #endif #if (WEBSOCKETS_NETWORK_TYPE != NETWORK_ESP8266_ASYNC) @@ -147,6 +164,16 @@ void WebSocketsClient::loop(void) { } _client.ssl = new WiFiClientSecure(); _client.tcp = _client.ssl; + if(_CA_cert) { + DEBUG_WEBSOCKETS("[WS-Client] setting CA certificate"); +#if defined(ESP32) + _client.ssl->setCACert(_CA_cert); +#elif defined(ESP8266) + _client.ssl->setCACert((const uint8_t *)_CA_cert, strlen(_CA_cert) + 1); +#else +#error setCACert not implemented +#endif + } } else { DEBUG_WEBSOCKETS("[WS-Client] connect ws...\n"); if(_client.tcp) { @@ -710,9 +737,11 @@ void WebSocketsClient::connectedCb() { _client.tcp->setTimeout(WEBSOCKETS_TCP_TIMEOUT); #endif -#if (WEBSOCKETS_NETWORK_TYPE == NETWORK_ESP8266) +#if (WEBSOCKETS_NETWORK_TYPE == NETWORK_ESP8266) || WEBSOCKETS_NETWORK_TYPE == NETWORK_ESP32 _client.tcp->setNoDelay(true); +#endif +#if defined(HAS_SSL) if(_client.isSSL && _fingerprint.length()) { if(!_client.ssl->verify(_fingerprint.c_str(), _host.c_str())) { DEBUG_WEBSOCKETS("[WS-Client] certificate mismatch\n"); @@ -806,4 +835,4 @@ void WebSocketsClient::enableHeartbeat(uint32_t pingInterval, uint32_t pongTimeo */ void WebSocketsClient::disableHeartbeat(){ _client.pingInterval = 0; -} \ No newline at end of file +} diff --git a/src/WebSocketsClient.h b/src/WebSocketsClient.h index 47fe8c8..07c82d1 100644 --- a/src/WebSocketsClient.h +++ b/src/WebSocketsClient.h @@ -43,17 +43,19 @@ class WebSocketsClient: private WebSockets { void begin(String host, uint16_t port, String url = "/", String protocol = "arduino"); void begin(IPAddress host, uint16_t port, const char * url = "/", const char * protocol = "arduino"); -#if (WEBSOCKETS_NETWORK_TYPE == NETWORK_ESP8266) || (WEBSOCKETS_NETWORK_TYPE == NETWORK_ESP32) +#if defined(HAS_SSL) void beginSSL(const char *host, uint16_t port, const char * url = "/", const char * = "", const char * protocol = "arduino"); void beginSSL(String host, uint16_t port, String url = "/", String fingerprint = "", String protocol = "arduino"); + void beginSslWithCA(const char *host, uint16_t port, const char * url = "/", const char * CA_cert = NULL, const char * protocol = "arduino"); #endif void beginSocketIO(const char *host, uint16_t port, const char * url = "/socket.io/?EIO=3", const char * protocol = "arduino"); void beginSocketIO(String host, uint16_t port, String url = "/socket.io/?EIO=3", String protocol = "arduino"); -#if (WEBSOCKETS_NETWORK_TYPE == NETWORK_ESP8266) || (WEBSOCKETS_NETWORK_TYPE == NETWORK_ESP32) +#if defined(HAS_SSL) void beginSocketIOSSL(const char *host, uint16_t port, const char * url = "/socket.io/?EIO=3", const char * protocol = "arduino"); void beginSocketIOSSL(String host, uint16_t port, String url = "/socket.io/?EIO=3", String protocol = "arduino"); + void beginSocketIOSSLWithCA(const char *host, uint16_t port, const char * url = "/socket.io/?EIO=3", const char * CA_cert = NULL, const char * protocol = "arduino"); #endif #if (WEBSOCKETS_NETWORK_TYPE != NETWORK_ESP8266_ASYNC) @@ -93,8 +95,9 @@ class WebSocketsClient: private WebSockets { String _host; uint16_t _port; -#if (WEBSOCKETS_NETWORK_TYPE == NETWORK_ESP8266) || (WEBSOCKETS_NETWORK_TYPE == NETWORK_ESP32) +#if defined(HAS_SSL) String _fingerprint; + const char *_CA_cert; #endif WSclient_t _client;