fix(modem): Per review comments

This commit is contained in:
David Cermak
2023-06-28 21:42:27 +02:00
parent ae629ed3a9
commit d4925f2bd6
5 changed files with 51 additions and 34 deletions

View File

@ -22,6 +22,8 @@ jobs:
example: modem_tcp_client example: modem_tcp_client
- idf_ver: "release-v4.3" - idf_ver: "release-v4.3"
example: modem_tcp_client example: modem_tcp_client
- idf_ver: "release-v4.4"
example: modem_tcp_client
include: include:
- idf_ver: "release-v4.2" - idf_ver: "release-v4.2"
skip_config: usb skip_config: usb

View File

@ -1,7 +1,6 @@
# The following lines of boilerplate have to be in your project's CMakeLists # The following lines of boilerplate have to be in your project's CMakeLists
# in this exact order for cmake to work correctly # in this exact order for cmake to work correctly
cmake_minimum_required(VERSION 3.8) cmake_minimum_required(VERSION 3.8)
set(CMAKE_CXX_STANDARD 17)
set(EXTRA_COMPONENT_DIRS "../..") set(EXTRA_COMPONENT_DIRS "../..")

View File

@ -6,28 +6,32 @@
#pragma once #pragma once
#include <utility> #include <utility>
#include <span>
#include "mbedtls/ssl.h" #include "mbedtls/ssl.h"
#include "mbedtls/entropy.h" #include "mbedtls/entropy.h"
#include "mbedtls/ctr_drbg.h" #include "mbedtls/ctr_drbg.h"
#include "mbedtls/error.h" #include "mbedtls/error.h"
using const_buf = std::pair<const unsigned char *, std::size_t>; using const_buf = std::span<const unsigned char>;
using buf = std::pair<unsigned char *, std::size_t>;
class Tls { class Tls {
public: public:
enum class is_server : bool {};
enum class do_verify : bool {};
Tls(); Tls();
bool init(bool is_server, bool verify); virtual ~Tls();
bool init(is_server server, do_verify verify);
int handshake(); int handshake();
int write(const unsigned char *buf, size_t len); int write(const unsigned char *buf, size_t len);
int read(unsigned char *buf, size_t len); int read(unsigned char *buf, size_t len);
bool set_own_cert(const_buf crt, const_buf key); [[nodiscard]] bool set_own_cert(const_buf crt, const_buf key);
bool set_ca_cert(const_buf crt); [[nodiscard]] bool set_ca_cert(const_buf crt);
virtual int send(const unsigned char *buf, size_t len) = 0; virtual int send(const unsigned char *buf, size_t len) = 0;
virtual int recv(unsigned char *buf, size_t len) = 0; virtual int recv(unsigned char *buf, size_t len) = 0;
size_t get_available_bytes(); size_t get_available_bytes();
private: protected:
mbedtls_ssl_context ssl_{}; mbedtls_ssl_context ssl_{};
mbedtls_x509_crt public_cert_{}; mbedtls_x509_crt public_cert_{};
mbedtls_pk_context pk_key_{}; mbedtls_pk_context pk_key_{};
@ -35,7 +39,9 @@ private:
mbedtls_ssl_config conf_{}; mbedtls_ssl_config conf_{};
mbedtls_ctr_drbg_context ctr_drbg_{}; mbedtls_ctr_drbg_context ctr_drbg_{};
mbedtls_entropy_context entropy_{}; mbedtls_entropy_context entropy_{};
virtual void delay() {}
private:
static void print_error(const char *function, int error_code); static void print_error(const char *function, int error_code);
static int bio_write(void *ctx, const unsigned char *buf, size_t len); static int bio_write(void *ctx, const unsigned char *buf, size_t len);
static int bio_read(void *ctx, unsigned char *buf, size_t len); static int bio_read(void *ctx, unsigned char *buf, size_t len);

View File

@ -7,24 +7,24 @@
#include "mbedtls/ssl.h" #include "mbedtls/ssl.h"
#include "mbedtls_wrap.hpp" #include "mbedtls_wrap.hpp"
bool Tls::init(bool is_server, bool verify) bool Tls::init(is_server server, do_verify verify)
{ {
const char pers[] = "mbedtls_wrapper"; const char pers[] = "mbedtls_wrapper";
mbedtls_entropy_init(&entropy_); mbedtls_entropy_init(&entropy_);
mbedtls_ctr_drbg_seed(&ctr_drbg_, mbedtls_entropy_func, &entropy_, (const unsigned char *)pers, sizeof(pers)); mbedtls_ctr_drbg_seed(&ctr_drbg_, mbedtls_entropy_func, &entropy_, (const unsigned char *)pers, sizeof(pers));
int ret = mbedtls_ssl_config_defaults(&conf_, is_server ? MBEDTLS_SSL_IS_SERVER : MBEDTLS_SSL_IS_CLIENT, MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT); int ret = mbedtls_ssl_config_defaults(&conf_, server == is_server{true} ? MBEDTLS_SSL_IS_SERVER : MBEDTLS_SSL_IS_CLIENT, MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT);
if (ret) { if (ret) {
print_error("mbedtls_ssl_config_defaults", ret); print_error("mbedtls_ssl_config_defaults", ret);
return false; return false;
} }
mbedtls_ssl_conf_rng(&conf_, mbedtls_ctr_drbg_random, &ctr_drbg_); mbedtls_ssl_conf_rng(&conf_, mbedtls_ctr_drbg_random, &ctr_drbg_);
mbedtls_ssl_conf_authmode(&conf_, verify ? MBEDTLS_SSL_VERIFY_REQUIRED : MBEDTLS_SSL_VERIFY_NONE); mbedtls_ssl_conf_authmode(&conf_, verify == do_verify{true} ? MBEDTLS_SSL_VERIFY_REQUIRED : MBEDTLS_SSL_VERIFY_NONE);
ret = mbedtls_ssl_conf_own_cert(&conf_, &public_cert_, &pk_key_); ret = mbedtls_ssl_conf_own_cert(&conf_, &public_cert_, &pk_key_);
if (ret) { if (ret) {
print_error("mbedtls_ssl_conf_own_cert", ret); print_error("mbedtls_ssl_conf_own_cert", ret);
return false; return false;
} }
if (verify) { if (verify == do_verify{true}) {
mbedtls_ssl_conf_ca_chain(&conf_, &ca_cert_, nullptr); mbedtls_ssl_conf_ca_chain(&conf_, &ca_cert_, nullptr);
} }
ret = mbedtls_ssl_setup(&ssl_, &conf_); ret = mbedtls_ssl_setup(&ssl_, &conf_);
@ -43,12 +43,9 @@ void Tls::print_error(const char *function, int error_code)
printf("%s() returned -0x%04X\n", function, -error_code); printf("%s() returned -0x%04X\n", function, -error_code);
printf("-0x%04X: %s\n", -error_code, error_buf); printf("-0x%04X: %s\n", -error_code, error_buf);
} }
#include "freertos/FreeRTOS.h"
#include "freertos/task.h"
#include "esp_log.h"
int Tls::handshake() int Tls::handshake()
{ {
ESP_LOGI("TLS", "handshake");
int ret = 0; int ret = 0;
mbedtls_ssl_set_bio(&ssl_, this, bio_write, bio_read, nullptr); mbedtls_ssl_set_bio(&ssl_, this, bio_write, bio_read, nullptr);
@ -57,9 +54,8 @@ int Tls::handshake()
print_error( "mbedtls_ssl_handshake returned", ret ); print_error( "mbedtls_ssl_handshake returned", ret );
return -1; return -1;
} }
vTaskDelay(pdMS_TO_TICKS(500)); delay();
} }
ESP_LOGI("TLS", "handshake done with %d", ret);
return ret; return ret;
} }
@ -87,12 +83,12 @@ int Tls::read(unsigned char *buf, size_t len)
bool Tls::set_own_cert(const_buf crt, const_buf key) bool Tls::set_own_cert(const_buf crt, const_buf key)
{ {
int ret = mbedtls_x509_crt_parse(&public_cert_, crt.first, crt.second); int ret = mbedtls_x509_crt_parse(&public_cert_, crt.data(), crt.size());
if (ret < 0) { if (ret < 0) {
print_error("mbedtls_x509_crt_parse", ret); print_error("mbedtls_x509_crt_parse", ret);
return false; return false;
} }
ret = mbedtls_pk_parse_key(&pk_key_, key.first, key.second, nullptr, 0); ret = mbedtls_pk_parse_key(&pk_key_, key.data(), key.size(), nullptr, 0);
if (ret < 0) { if (ret < 0) {
print_error("mbedtls_pk_parse_keyfile", ret); print_error("mbedtls_pk_parse_keyfile", ret);
return false; return false;
@ -102,7 +98,7 @@ bool Tls::set_own_cert(const_buf crt, const_buf key)
bool Tls::set_ca_cert(const_buf crt) bool Tls::set_ca_cert(const_buf crt)
{ {
int ret = mbedtls_x509_crt_parse(&ca_cert_, crt.first, crt.second); int ret = mbedtls_x509_crt_parse(&ca_cert_, crt.data(), crt.size());
if (ret < 0) { if (ret < 0) {
print_error("mbedtls_x509_crt_parse", ret); print_error("mbedtls_x509_crt_parse", ret);
return false; return false;
@ -127,3 +123,12 @@ size_t Tls::get_available_bytes()
{ {
return ::mbedtls_ssl_get_bytes_avail(&ssl_); return ::mbedtls_ssl_get_bytes_avail(&ssl_);
} }
Tls::~Tls()
{
::mbedtls_ssl_config_free(&conf_);
::mbedtls_ssl_free(&ssl_);
::mbedtls_pk_free(&pk_key_);
::mbedtls_x509_crt_free(&public_cert_);
::mbedtls_x509_crt_free(&ca_cert_);
}

View File

@ -3,6 +3,8 @@
* *
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
*/ */
#include "freertos/FreeRTOS.h"
#include "freertos/task.h"
#include "esp_log.h" #include "esp_log.h"
#include "esp_transport.h" #include "esp_transport.h"
#include "mbedtls_wrap.hpp" #include "mbedtls_wrap.hpp"
@ -19,8 +21,9 @@ public:
private: private:
esp_transport_handle_t transport_{}; esp_transport_handle_t transport_{};
int connect(const char *host, int port, int timeout_ms); int connect(const char *host, int port, int timeout_ms);
void delay() override;
struct priv { struct transport {
static int connect(esp_transport_handle_t t, const char *host, int port, int timeout_ms); static int connect(esp_transport_handle_t t, const char *host, int port, int timeout_ms);
static int read(esp_transport_handle_t t, char *buffer, int len, int timeout_ms); static int read(esp_transport_handle_t t, char *buffer, int len, int timeout_ms);
static int write(esp_transport_handle_t t, const char *buffer, int len, int timeout_ms); static int write(esp_transport_handle_t t, const char *buffer, int len, int timeout_ms);
@ -57,7 +60,7 @@ int TlsTransport::recv(unsigned char *buf, size_t len)
bool TlsTransport::set_func(esp_transport_handle_t tls_transport) bool TlsTransport::set_func(esp_transport_handle_t tls_transport)
{ {
return esp_transport_set_func(tls_transport, TlsTransport::priv::connect, TlsTransport::priv::read, TlsTransport::priv::write, TlsTransport::priv::close, TlsTransport::priv::poll_read, TlsTransport::priv::poll_write, TlsTransport::priv::destroy) == ESP_OK; return esp_transport_set_func(tls_transport, TlsTransport::transport::connect, TlsTransport::transport::read, TlsTransport::transport::write, TlsTransport::transport::close, TlsTransport::transport::poll_read, TlsTransport::transport::poll_write, TlsTransport::transport::destroy) == ESP_OK;
} }
int TlsTransport::connect(const char *host, int port, int timeout_ms) int TlsTransport::connect(const char *host, int port, int timeout_ms)
@ -65,22 +68,24 @@ int TlsTransport::connect(const char *host, int port, int timeout_ms)
return esp_transport_connect(transport_, host, port, timeout_ms); return esp_transport_connect(transport_, host, port, timeout_ms);
} }
int TlsTransport::priv::connect(esp_transport_handle_t t, const char *host, int port, int timeout_ms) void TlsTransport::delay()
{ {
ESP_LOGI("tag", "SSL connect!"); vTaskDelay(pdMS_TO_TICKS(500));
auto tls = static_cast<TlsTransport *>(esp_transport_get_context_data(t)); }
tls->init(false, false);
int TlsTransport::transport::connect(esp_transport_handle_t t, const char *host, int port, int timeout_ms)
{
auto tls = static_cast<TlsTransport *>(esp_transport_get_context_data(t));
tls->init(is_server{false}, do_verify{false});
ESP_LOGI("tag", "TCP connect!");
auto ret = tls->connect(host, port, timeout_ms); auto ret = tls->connect(host, port, timeout_ms);
if (ret < 0) { if (ret < 0) {
ESP_LOGI("tag", "TCP connect fail!");
return ret; return ret;
} }
return tls->handshake(); return tls->handshake();
} }
int TlsTransport::priv::read(esp_transport_handle_t t, char *buffer, int len, int timeout_ms) int TlsTransport::transport::read(esp_transport_handle_t t, char *buffer, int len, int timeout_ms)
{ {
auto tls = static_cast<TlsTransport *>(esp_transport_get_context_data(t)); auto tls = static_cast<TlsTransport *>(esp_transport_get_context_data(t));
if (tls->get_available_bytes() <= 0) { if (tls->get_available_bytes() <= 0) {
@ -95,7 +100,7 @@ int TlsTransport::priv::read(esp_transport_handle_t t, char *buffer, int len, in
return tls->read(reinterpret_cast<unsigned char *>(buffer), len); return tls->read(reinterpret_cast<unsigned char *>(buffer), len);
} }
int TlsTransport::priv::write(esp_transport_handle_t t, const char *buffer, int len, int timeout_ms) int TlsTransport::transport::write(esp_transport_handle_t t, const char *buffer, int len, int timeout_ms)
{ {
int poll; int poll;
if ((poll = esp_transport_poll_write(t, timeout_ms)) <= 0) { if ((poll = esp_transport_poll_write(t, timeout_ms)) <= 0) {
@ -107,25 +112,25 @@ int TlsTransport::priv::write(esp_transport_handle_t t, const char *buffer, int
return tls->write(reinterpret_cast<const unsigned char *>(buffer), len); return tls->write(reinterpret_cast<const unsigned char *>(buffer), len);
} }
int TlsTransport::priv::close(esp_transport_handle_t t) int TlsTransport::transport::close(esp_transport_handle_t t)
{ {
auto tls = static_cast<TlsTransport *>(esp_transport_get_context_data(t)); auto tls = static_cast<TlsTransport *>(esp_transport_get_context_data(t));
return esp_transport_close(tls->transport_); return esp_transport_close(tls->transport_);
} }
int TlsTransport::priv::poll_read(esp_transport_handle_t t, int timeout_ms) int TlsTransport::transport::poll_read(esp_transport_handle_t t, int timeout_ms)
{ {
auto tls = static_cast<TlsTransport *>(esp_transport_get_context_data(t)); auto tls = static_cast<TlsTransport *>(esp_transport_get_context_data(t));
return esp_transport_poll_read(tls->transport_, timeout_ms); return esp_transport_poll_read(tls->transport_, timeout_ms);
} }
int TlsTransport::priv::poll_write(esp_transport_handle_t t, int timeout_ms) int TlsTransport::transport::poll_write(esp_transport_handle_t t, int timeout_ms)
{ {
auto tls = static_cast<TlsTransport *>(esp_transport_get_context_data(t)); auto tls = static_cast<TlsTransport *>(esp_transport_get_context_data(t));
return esp_transport_poll_write(tls->transport_, timeout_ms); return esp_transport_poll_write(tls->transport_, timeout_ms);
} }
int TlsTransport::priv::destroy(esp_transport_handle_t t) int TlsTransport::transport::destroy(esp_transport_handle_t t)
{ {
auto tls = static_cast<TlsTransport *>(esp_transport_get_context_data(t)); auto tls = static_cast<TlsTransport *>(esp_transport_get_context_data(t));
return esp_transport_destroy(tls->transport_); return esp_transport_destroy(tls->transport_);