From d4925f2bd6745095c0da7130cd46dd14ec7a9c11 Mon Sep 17 00:00:00 2001 From: David Cermak Date: Wed, 28 Jun 2023 21:42:27 +0200 Subject: [PATCH] fix(modem): Per review comments --- .github/workflows/modem__build-host-tests.yml | 2 ++ .../examples/modem_tcp_client/CMakeLists.txt | 1 - .../include/mbedtls_wrap.hpp | 18 ++++++---- .../extra_tcp_transports/mbedtls_wrap.cpp | 31 +++++++++-------- .../extra_tcp_transports/tls_transport.cpp | 33 +++++++++++-------- 5 files changed, 51 insertions(+), 34 deletions(-) diff --git a/.github/workflows/modem__build-host-tests.yml b/.github/workflows/modem__build-host-tests.yml index af8deaa3d..e3a9618db 100644 --- a/.github/workflows/modem__build-host-tests.yml +++ b/.github/workflows/modem__build-host-tests.yml @@ -22,6 +22,8 @@ jobs: example: modem_tcp_client - idf_ver: "release-v4.3" example: modem_tcp_client + - idf_ver: "release-v4.4" + example: modem_tcp_client include: - idf_ver: "release-v4.2" skip_config: usb diff --git a/components/esp_modem/examples/modem_tcp_client/CMakeLists.txt b/components/esp_modem/examples/modem_tcp_client/CMakeLists.txt index 2901da172..d3291e866 100644 --- a/components/esp_modem/examples/modem_tcp_client/CMakeLists.txt +++ b/components/esp_modem/examples/modem_tcp_client/CMakeLists.txt @@ -1,7 +1,6 @@ # The following lines of boilerplate have to be in your project's CMakeLists # in this exact order for cmake to work correctly cmake_minimum_required(VERSION 3.8) -set(CMAKE_CXX_STANDARD 17) set(EXTRA_COMPONENT_DIRS "../..") diff --git a/components/esp_modem/examples/modem_tcp_client/components/extra_tcp_transports/include/mbedtls_wrap.hpp b/components/esp_modem/examples/modem_tcp_client/components/extra_tcp_transports/include/mbedtls_wrap.hpp index b528c19e4..a6e368022 100644 --- a/components/esp_modem/examples/modem_tcp_client/components/extra_tcp_transports/include/mbedtls_wrap.hpp +++ b/components/esp_modem/examples/modem_tcp_client/components/extra_tcp_transports/include/mbedtls_wrap.hpp @@ -6,28 +6,32 @@ #pragma once #include +#include #include "mbedtls/ssl.h" #include "mbedtls/entropy.h" #include "mbedtls/ctr_drbg.h" #include "mbedtls/error.h" -using const_buf = std::pair; -using buf = std::pair; +using const_buf = std::span; class Tls { public: + enum class is_server : bool {}; + enum class do_verify : bool {}; + Tls(); - bool init(bool is_server, bool verify); + virtual ~Tls(); + bool init(is_server server, do_verify verify); int handshake(); int write(const unsigned char *buf, size_t len); int read(unsigned char *buf, size_t len); - bool set_own_cert(const_buf crt, const_buf key); - bool set_ca_cert(const_buf crt); + [[nodiscard]] bool set_own_cert(const_buf crt, const_buf key); + [[nodiscard]] bool set_ca_cert(const_buf crt); virtual int send(const unsigned char *buf, size_t len) = 0; virtual int recv(unsigned char *buf, size_t len) = 0; size_t get_available_bytes(); -private: +protected: mbedtls_ssl_context ssl_{}; mbedtls_x509_crt public_cert_{}; mbedtls_pk_context pk_key_{}; @@ -35,7 +39,9 @@ private: mbedtls_ssl_config conf_{}; mbedtls_ctr_drbg_context ctr_drbg_{}; mbedtls_entropy_context entropy_{}; + virtual void delay() {} +private: static void print_error(const char *function, int error_code); static int bio_write(void *ctx, const unsigned char *buf, size_t len); static int bio_read(void *ctx, unsigned char *buf, size_t len); diff --git a/components/esp_modem/examples/modem_tcp_client/components/extra_tcp_transports/mbedtls_wrap.cpp b/components/esp_modem/examples/modem_tcp_client/components/extra_tcp_transports/mbedtls_wrap.cpp index dd1cfa045..ff510880a 100644 --- a/components/esp_modem/examples/modem_tcp_client/components/extra_tcp_transports/mbedtls_wrap.cpp +++ b/components/esp_modem/examples/modem_tcp_client/components/extra_tcp_transports/mbedtls_wrap.cpp @@ -7,24 +7,24 @@ #include "mbedtls/ssl.h" #include "mbedtls_wrap.hpp" -bool Tls::init(bool is_server, bool verify) +bool Tls::init(is_server server, do_verify verify) { const char pers[] = "mbedtls_wrapper"; mbedtls_entropy_init(&entropy_); mbedtls_ctr_drbg_seed(&ctr_drbg_, mbedtls_entropy_func, &entropy_, (const unsigned char *)pers, sizeof(pers)); - int ret = mbedtls_ssl_config_defaults(&conf_, is_server ? MBEDTLS_SSL_IS_SERVER : MBEDTLS_SSL_IS_CLIENT, MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT); + int ret = mbedtls_ssl_config_defaults(&conf_, server == is_server{true} ? MBEDTLS_SSL_IS_SERVER : MBEDTLS_SSL_IS_CLIENT, MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT); if (ret) { print_error("mbedtls_ssl_config_defaults", ret); return false; } mbedtls_ssl_conf_rng(&conf_, mbedtls_ctr_drbg_random, &ctr_drbg_); - mbedtls_ssl_conf_authmode(&conf_, verify ? MBEDTLS_SSL_VERIFY_REQUIRED : MBEDTLS_SSL_VERIFY_NONE); + mbedtls_ssl_conf_authmode(&conf_, verify == do_verify{true} ? MBEDTLS_SSL_VERIFY_REQUIRED : MBEDTLS_SSL_VERIFY_NONE); ret = mbedtls_ssl_conf_own_cert(&conf_, &public_cert_, &pk_key_); if (ret) { print_error("mbedtls_ssl_conf_own_cert", ret); return false; } - if (verify) { + if (verify == do_verify{true}) { mbedtls_ssl_conf_ca_chain(&conf_, &ca_cert_, nullptr); } ret = mbedtls_ssl_setup(&ssl_, &conf_); @@ -43,12 +43,9 @@ void Tls::print_error(const char *function, int error_code) printf("%s() returned -0x%04X\n", function, -error_code); printf("-0x%04X: %s\n", -error_code, error_buf); } -#include "freertos/FreeRTOS.h" -#include "freertos/task.h" -#include "esp_log.h" + int Tls::handshake() { - ESP_LOGI("TLS", "handshake"); int ret = 0; mbedtls_ssl_set_bio(&ssl_, this, bio_write, bio_read, nullptr); @@ -57,9 +54,8 @@ int Tls::handshake() print_error( "mbedtls_ssl_handshake returned", ret ); return -1; } - vTaskDelay(pdMS_TO_TICKS(500)); + delay(); } - ESP_LOGI("TLS", "handshake done with %d", ret); return ret; } @@ -87,12 +83,12 @@ int Tls::read(unsigned char *buf, size_t len) bool Tls::set_own_cert(const_buf crt, const_buf key) { - int ret = mbedtls_x509_crt_parse(&public_cert_, crt.first, crt.second); + int ret = mbedtls_x509_crt_parse(&public_cert_, crt.data(), crt.size()); if (ret < 0) { print_error("mbedtls_x509_crt_parse", ret); return false; } - ret = mbedtls_pk_parse_key(&pk_key_, key.first, key.second, nullptr, 0); + ret = mbedtls_pk_parse_key(&pk_key_, key.data(), key.size(), nullptr, 0); if (ret < 0) { print_error("mbedtls_pk_parse_keyfile", ret); return false; @@ -102,7 +98,7 @@ bool Tls::set_own_cert(const_buf crt, const_buf key) bool Tls::set_ca_cert(const_buf crt) { - int ret = mbedtls_x509_crt_parse(&ca_cert_, crt.first, crt.second); + int ret = mbedtls_x509_crt_parse(&ca_cert_, crt.data(), crt.size()); if (ret < 0) { print_error("mbedtls_x509_crt_parse", ret); return false; @@ -127,3 +123,12 @@ size_t Tls::get_available_bytes() { return ::mbedtls_ssl_get_bytes_avail(&ssl_); } + +Tls::~Tls() +{ + ::mbedtls_ssl_config_free(&conf_); + ::mbedtls_ssl_free(&ssl_); + ::mbedtls_pk_free(&pk_key_); + ::mbedtls_x509_crt_free(&public_cert_); + ::mbedtls_x509_crt_free(&ca_cert_); +} diff --git a/components/esp_modem/examples/modem_tcp_client/components/extra_tcp_transports/tls_transport.cpp b/components/esp_modem/examples/modem_tcp_client/components/extra_tcp_transports/tls_transport.cpp index 8e315af99..fba3d96fa 100644 --- a/components/esp_modem/examples/modem_tcp_client/components/extra_tcp_transports/tls_transport.cpp +++ b/components/esp_modem/examples/modem_tcp_client/components/extra_tcp_transports/tls_transport.cpp @@ -3,6 +3,8 @@ * * SPDX-License-Identifier: Apache-2.0 */ +#include "freertos/FreeRTOS.h" +#include "freertos/task.h" #include "esp_log.h" #include "esp_transport.h" #include "mbedtls_wrap.hpp" @@ -19,8 +21,9 @@ public: private: esp_transport_handle_t transport_{}; int connect(const char *host, int port, int timeout_ms); + void delay() override; - struct priv { + struct transport { static int connect(esp_transport_handle_t t, const char *host, int port, int timeout_ms); static int read(esp_transport_handle_t t, char *buffer, int len, int timeout_ms); static int write(esp_transport_handle_t t, const char *buffer, int len, int timeout_ms); @@ -57,7 +60,7 @@ int TlsTransport::recv(unsigned char *buf, size_t len) bool TlsTransport::set_func(esp_transport_handle_t tls_transport) { - return esp_transport_set_func(tls_transport, TlsTransport::priv::connect, TlsTransport::priv::read, TlsTransport::priv::write, TlsTransport::priv::close, TlsTransport::priv::poll_read, TlsTransport::priv::poll_write, TlsTransport::priv::destroy) == ESP_OK; + return esp_transport_set_func(tls_transport, TlsTransport::transport::connect, TlsTransport::transport::read, TlsTransport::transport::write, TlsTransport::transport::close, TlsTransport::transport::poll_read, TlsTransport::transport::poll_write, TlsTransport::transport::destroy) == ESP_OK; } int TlsTransport::connect(const char *host, int port, int timeout_ms) @@ -65,22 +68,24 @@ int TlsTransport::connect(const char *host, int port, int timeout_ms) return esp_transport_connect(transport_, host, port, timeout_ms); } -int TlsTransport::priv::connect(esp_transport_handle_t t, const char *host, int port, int timeout_ms) +void TlsTransport::delay() { - ESP_LOGI("tag", "SSL connect!"); - auto tls = static_cast(esp_transport_get_context_data(t)); - tls->init(false, false); + vTaskDelay(pdMS_TO_TICKS(500)); +} + +int TlsTransport::transport::connect(esp_transport_handle_t t, const char *host, int port, int timeout_ms) +{ + auto tls = static_cast(esp_transport_get_context_data(t)); + tls->init(is_server{false}, do_verify{false}); - ESP_LOGI("tag", "TCP connect!"); auto ret = tls->connect(host, port, timeout_ms); if (ret < 0) { - ESP_LOGI("tag", "TCP connect fail!"); return ret; } return tls->handshake(); } -int TlsTransport::priv::read(esp_transport_handle_t t, char *buffer, int len, int timeout_ms) +int TlsTransport::transport::read(esp_transport_handle_t t, char *buffer, int len, int timeout_ms) { auto tls = static_cast(esp_transport_get_context_data(t)); if (tls->get_available_bytes() <= 0) { @@ -95,7 +100,7 @@ int TlsTransport::priv::read(esp_transport_handle_t t, char *buffer, int len, in return tls->read(reinterpret_cast(buffer), len); } -int TlsTransport::priv::write(esp_transport_handle_t t, const char *buffer, int len, int timeout_ms) +int TlsTransport::transport::write(esp_transport_handle_t t, const char *buffer, int len, int timeout_ms) { int poll; if ((poll = esp_transport_poll_write(t, timeout_ms)) <= 0) { @@ -107,25 +112,25 @@ int TlsTransport::priv::write(esp_transport_handle_t t, const char *buffer, int return tls->write(reinterpret_cast(buffer), len); } -int TlsTransport::priv::close(esp_transport_handle_t t) +int TlsTransport::transport::close(esp_transport_handle_t t) { auto tls = static_cast(esp_transport_get_context_data(t)); return esp_transport_close(tls->transport_); } -int TlsTransport::priv::poll_read(esp_transport_handle_t t, int timeout_ms) +int TlsTransport::transport::poll_read(esp_transport_handle_t t, int timeout_ms) { auto tls = static_cast(esp_transport_get_context_data(t)); return esp_transport_poll_read(tls->transport_, timeout_ms); } -int TlsTransport::priv::poll_write(esp_transport_handle_t t, int timeout_ms) +int TlsTransport::transport::poll_write(esp_transport_handle_t t, int timeout_ms) { auto tls = static_cast(esp_transport_get_context_data(t)); return esp_transport_poll_write(tls->transport_, timeout_ms); } -int TlsTransport::priv::destroy(esp_transport_handle_t t) +int TlsTransport::transport::destroy(esp_transport_handle_t t) { auto tls = static_cast(esp_transport_get_context_data(t)); return esp_transport_destroy(tls->transport_);