fix(modem): Per review comments

This commit is contained in:
David Cermak
2023-06-28 21:42:27 +02:00
parent ae629ed3a9
commit d4925f2bd6
5 changed files with 51 additions and 34 deletions

View File

@ -1,7 +1,6 @@
# The following lines of boilerplate have to be in your project's CMakeLists
# in this exact order for cmake to work correctly
cmake_minimum_required(VERSION 3.8)
set(CMAKE_CXX_STANDARD 17)
set(EXTRA_COMPONENT_DIRS "../..")

View File

@ -6,28 +6,32 @@
#pragma once
#include <utility>
#include <span>
#include "mbedtls/ssl.h"
#include "mbedtls/entropy.h"
#include "mbedtls/ctr_drbg.h"
#include "mbedtls/error.h"
using const_buf = std::pair<const unsigned char *, std::size_t>;
using buf = std::pair<unsigned char *, std::size_t>;
using const_buf = std::span<const unsigned char>;
class Tls {
public:
enum class is_server : bool {};
enum class do_verify : bool {};
Tls();
bool init(bool is_server, bool verify);
virtual ~Tls();
bool init(is_server server, do_verify verify);
int handshake();
int write(const unsigned char *buf, size_t len);
int read(unsigned char *buf, size_t len);
bool set_own_cert(const_buf crt, const_buf key);
bool set_ca_cert(const_buf crt);
[[nodiscard]] bool set_own_cert(const_buf crt, const_buf key);
[[nodiscard]] bool set_ca_cert(const_buf crt);
virtual int send(const unsigned char *buf, size_t len) = 0;
virtual int recv(unsigned char *buf, size_t len) = 0;
size_t get_available_bytes();
private:
protected:
mbedtls_ssl_context ssl_{};
mbedtls_x509_crt public_cert_{};
mbedtls_pk_context pk_key_{};
@ -35,7 +39,9 @@ private:
mbedtls_ssl_config conf_{};
mbedtls_ctr_drbg_context ctr_drbg_{};
mbedtls_entropy_context entropy_{};
virtual void delay() {}
private:
static void print_error(const char *function, int error_code);
static int bio_write(void *ctx, const unsigned char *buf, size_t len);
static int bio_read(void *ctx, unsigned char *buf, size_t len);

View File

@ -7,24 +7,24 @@
#include "mbedtls/ssl.h"
#include "mbedtls_wrap.hpp"
bool Tls::init(bool is_server, bool verify)
bool Tls::init(is_server server, do_verify verify)
{
const char pers[] = "mbedtls_wrapper";
mbedtls_entropy_init(&entropy_);
mbedtls_ctr_drbg_seed(&ctr_drbg_, mbedtls_entropy_func, &entropy_, (const unsigned char *)pers, sizeof(pers));
int ret = mbedtls_ssl_config_defaults(&conf_, is_server ? MBEDTLS_SSL_IS_SERVER : MBEDTLS_SSL_IS_CLIENT, MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT);
int ret = mbedtls_ssl_config_defaults(&conf_, server == is_server{true} ? MBEDTLS_SSL_IS_SERVER : MBEDTLS_SSL_IS_CLIENT, MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT);
if (ret) {
print_error("mbedtls_ssl_config_defaults", ret);
return false;
}
mbedtls_ssl_conf_rng(&conf_, mbedtls_ctr_drbg_random, &ctr_drbg_);
mbedtls_ssl_conf_authmode(&conf_, verify ? MBEDTLS_SSL_VERIFY_REQUIRED : MBEDTLS_SSL_VERIFY_NONE);
mbedtls_ssl_conf_authmode(&conf_, verify == do_verify{true} ? MBEDTLS_SSL_VERIFY_REQUIRED : MBEDTLS_SSL_VERIFY_NONE);
ret = mbedtls_ssl_conf_own_cert(&conf_, &public_cert_, &pk_key_);
if (ret) {
print_error("mbedtls_ssl_conf_own_cert", ret);
return false;
}
if (verify) {
if (verify == do_verify{true}) {
mbedtls_ssl_conf_ca_chain(&conf_, &ca_cert_, nullptr);
}
ret = mbedtls_ssl_setup(&ssl_, &conf_);
@ -43,12 +43,9 @@ void Tls::print_error(const char *function, int error_code)
printf("%s() returned -0x%04X\n", function, -error_code);
printf("-0x%04X: %s\n", -error_code, error_buf);
}
#include "freertos/FreeRTOS.h"
#include "freertos/task.h"
#include "esp_log.h"
int Tls::handshake()
{
ESP_LOGI("TLS", "handshake");
int ret = 0;
mbedtls_ssl_set_bio(&ssl_, this, bio_write, bio_read, nullptr);
@ -57,9 +54,8 @@ int Tls::handshake()
print_error( "mbedtls_ssl_handshake returned", ret );
return -1;
}
vTaskDelay(pdMS_TO_TICKS(500));
delay();
}
ESP_LOGI("TLS", "handshake done with %d", ret);
return ret;
}
@ -87,12 +83,12 @@ int Tls::read(unsigned char *buf, size_t len)
bool Tls::set_own_cert(const_buf crt, const_buf key)
{
int ret = mbedtls_x509_crt_parse(&public_cert_, crt.first, crt.second);
int ret = mbedtls_x509_crt_parse(&public_cert_, crt.data(), crt.size());
if (ret < 0) {
print_error("mbedtls_x509_crt_parse", ret);
return false;
}
ret = mbedtls_pk_parse_key(&pk_key_, key.first, key.second, nullptr, 0);
ret = mbedtls_pk_parse_key(&pk_key_, key.data(), key.size(), nullptr, 0);
if (ret < 0) {
print_error("mbedtls_pk_parse_keyfile", ret);
return false;
@ -102,7 +98,7 @@ bool Tls::set_own_cert(const_buf crt, const_buf key)
bool Tls::set_ca_cert(const_buf crt)
{
int ret = mbedtls_x509_crt_parse(&ca_cert_, crt.first, crt.second);
int ret = mbedtls_x509_crt_parse(&ca_cert_, crt.data(), crt.size());
if (ret < 0) {
print_error("mbedtls_x509_crt_parse", ret);
return false;
@ -127,3 +123,12 @@ size_t Tls::get_available_bytes()
{
return ::mbedtls_ssl_get_bytes_avail(&ssl_);
}
Tls::~Tls()
{
::mbedtls_ssl_config_free(&conf_);
::mbedtls_ssl_free(&ssl_);
::mbedtls_pk_free(&pk_key_);
::mbedtls_x509_crt_free(&public_cert_);
::mbedtls_x509_crt_free(&ca_cert_);
}

View File

@ -3,6 +3,8 @@
*
* SPDX-License-Identifier: Apache-2.0
*/
#include "freertos/FreeRTOS.h"
#include "freertos/task.h"
#include "esp_log.h"
#include "esp_transport.h"
#include "mbedtls_wrap.hpp"
@ -19,8 +21,9 @@ public:
private:
esp_transport_handle_t transport_{};
int connect(const char *host, int port, int timeout_ms);
void delay() override;
struct priv {
struct transport {
static int connect(esp_transport_handle_t t, const char *host, int port, int timeout_ms);
static int read(esp_transport_handle_t t, char *buffer, int len, int timeout_ms);
static int write(esp_transport_handle_t t, const char *buffer, int len, int timeout_ms);
@ -57,7 +60,7 @@ int TlsTransport::recv(unsigned char *buf, size_t len)
bool TlsTransport::set_func(esp_transport_handle_t tls_transport)
{
return esp_transport_set_func(tls_transport, TlsTransport::priv::connect, TlsTransport::priv::read, TlsTransport::priv::write, TlsTransport::priv::close, TlsTransport::priv::poll_read, TlsTransport::priv::poll_write, TlsTransport::priv::destroy) == ESP_OK;
return esp_transport_set_func(tls_transport, TlsTransport::transport::connect, TlsTransport::transport::read, TlsTransport::transport::write, TlsTransport::transport::close, TlsTransport::transport::poll_read, TlsTransport::transport::poll_write, TlsTransport::transport::destroy) == ESP_OK;
}
int TlsTransport::connect(const char *host, int port, int timeout_ms)
@ -65,22 +68,24 @@ int TlsTransport::connect(const char *host, int port, int timeout_ms)
return esp_transport_connect(transport_, host, port, timeout_ms);
}
int TlsTransport::priv::connect(esp_transport_handle_t t, const char *host, int port, int timeout_ms)
void TlsTransport::delay()
{
ESP_LOGI("tag", "SSL connect!");
auto tls = static_cast<TlsTransport *>(esp_transport_get_context_data(t));
tls->init(false, false);
vTaskDelay(pdMS_TO_TICKS(500));
}
int TlsTransport::transport::connect(esp_transport_handle_t t, const char *host, int port, int timeout_ms)
{
auto tls = static_cast<TlsTransport *>(esp_transport_get_context_data(t));
tls->init(is_server{false}, do_verify{false});
ESP_LOGI("tag", "TCP connect!");
auto ret = tls->connect(host, port, timeout_ms);
if (ret < 0) {
ESP_LOGI("tag", "TCP connect fail!");
return ret;
}
return tls->handshake();
}
int TlsTransport::priv::read(esp_transport_handle_t t, char *buffer, int len, int timeout_ms)
int TlsTransport::transport::read(esp_transport_handle_t t, char *buffer, int len, int timeout_ms)
{
auto tls = static_cast<TlsTransport *>(esp_transport_get_context_data(t));
if (tls->get_available_bytes() <= 0) {
@ -95,7 +100,7 @@ int TlsTransport::priv::read(esp_transport_handle_t t, char *buffer, int len, in
return tls->read(reinterpret_cast<unsigned char *>(buffer), len);
}
int TlsTransport::priv::write(esp_transport_handle_t t, const char *buffer, int len, int timeout_ms)
int TlsTransport::transport::write(esp_transport_handle_t t, const char *buffer, int len, int timeout_ms)
{
int poll;
if ((poll = esp_transport_poll_write(t, timeout_ms)) <= 0) {
@ -107,25 +112,25 @@ int TlsTransport::priv::write(esp_transport_handle_t t, const char *buffer, int
return tls->write(reinterpret_cast<const unsigned char *>(buffer), len);
}
int TlsTransport::priv::close(esp_transport_handle_t t)
int TlsTransport::transport::close(esp_transport_handle_t t)
{
auto tls = static_cast<TlsTransport *>(esp_transport_get_context_data(t));
return esp_transport_close(tls->transport_);
}
int TlsTransport::priv::poll_read(esp_transport_handle_t t, int timeout_ms)
int TlsTransport::transport::poll_read(esp_transport_handle_t t, int timeout_ms)
{
auto tls = static_cast<TlsTransport *>(esp_transport_get_context_data(t));
return esp_transport_poll_read(tls->transport_, timeout_ms);
}
int TlsTransport::priv::poll_write(esp_transport_handle_t t, int timeout_ms)
int TlsTransport::transport::poll_write(esp_transport_handle_t t, int timeout_ms)
{
auto tls = static_cast<TlsTransport *>(esp_transport_get_context_data(t));
return esp_transport_poll_write(tls->transport_, timeout_ms);
}
int TlsTransport::priv::destroy(esp_transport_handle_t t)
int TlsTransport::transport::destroy(esp_transport_handle_t t)
{
auto tls = static_cast<TlsTransport *>(esp_transport_get_context_data(t));
return esp_transport_destroy(tls->transport_);