mirror of
				https://github.com/espressif/esp-protocols.git
				synced 2025-11-04 08:31:37 +01:00 
			
		
		
		
	
		
			
				
	
	
		
			298 lines
		
	
	
		
			9.0 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			298 lines
		
	
	
		
			9.0 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
//
 | 
						|
// SPDX-FileCopyrightText: 2021-2022 Espressif Systems (Shanghai) CO LTD
 | 
						|
//
 | 
						|
// SPDX-License-Identifier: BSL-1.0
 | 
						|
//
 | 
						|
#pragma once
 | 
						|
 | 
						|
#include "mbedtls/ssl.h"
 | 
						|
#include "mbedtls/entropy.h"
 | 
						|
#include "mbedtls/ctr_drbg.h"
 | 
						|
#include "mbedtls/error.h"
 | 
						|
#include "mbedtls/esp_debug.h"
 | 
						|
#include "esp_log.h"
 | 
						|
 | 
						|
namespace asio {
 | 
						|
namespace ssl {
 | 
						|
namespace mbedtls {
 | 
						|
 | 
						|
const char *error_message(int error_code)
 | 
						|
{
 | 
						|
    static char error_buf[100];
 | 
						|
    mbedtls_strerror(error_code, error_buf, sizeof(error_buf));
 | 
						|
    return error_buf;
 | 
						|
}
 | 
						|
 | 
						|
void throw_alloc_failure(const char *location)
 | 
						|
{
 | 
						|
    asio::error_code ec( MBEDTLS_ERR_SSL_ALLOC_FAILED, asio::error::get_mbedtls_category());
 | 
						|
    asio::detail::throw_error(ec, location);
 | 
						|
}
 | 
						|
 | 
						|
namespace error_codes {
 | 
						|
 | 
						|
bool is_error(int ret)
 | 
						|
{
 | 
						|
    return  ret < 0 && ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE;
 | 
						|
}
 | 
						|
 | 
						|
static bool want_write(int ret)
 | 
						|
{
 | 
						|
    return  ret == MBEDTLS_ERR_SSL_WANT_WRITE;
 | 
						|
}
 | 
						|
 | 
						|
static bool want_read(int ret)
 | 
						|
{
 | 
						|
    return  ret == MBEDTLS_ERR_SSL_WANT_READ;
 | 
						|
}
 | 
						|
 | 
						|
} // namespace error_codes
 | 
						|
 | 
						|
/**
 * Progress marker kept by the engine for its most recent operation.
 */
enum rw_state {
    IDLE,       ///< last read/write transferred exactly the requested length
    READING,    ///< handshake asked for more input, or a read returned something other than the requested length
    WRITING,    ///< handshake asked to send output, or a write returned something other than the requested length
    CLOSED      ///< shutdown() has been called
};
 | 
						|
 | 
						|
class engine {
 | 
						|
public:
 | 
						|
    explicit engine(std::shared_ptr<context> ctx): ctx_(std::move(ctx)),
 | 
						|
        bio_(bio::new_pair("mbedtls-engine")), state_(IDLE), verify_mode_(0) {}
 | 
						|
 | 
						|
    void set_verify_mode(asio::ssl::verify_mode mode)
 | 
						|
    {
 | 
						|
        verify_mode_ = mode;
 | 
						|
    }
 | 
						|
 | 
						|
    bio *ext_bio() const
 | 
						|
    {
 | 
						|
        return bio_.second.get();
 | 
						|
    }
 | 
						|
 | 
						|
    rw_state get_state() const
 | 
						|
    {
 | 
						|
        return state_;
 | 
						|
    }
 | 
						|
 | 
						|
    int shutdown()
 | 
						|
    {
 | 
						|
        int ret = mbedtls_ssl_close_notify(&impl_.ssl_);
 | 
						|
        if (ret) {
 | 
						|
            impl::print_error("mbedtls_ssl_close_notify", ret);
 | 
						|
        }
 | 
						|
        state_ = CLOSED;
 | 
						|
        return ret;
 | 
						|
    }
 | 
						|
 | 
						|
    int connect()
 | 
						|
    {
 | 
						|
        return handshake(true);
 | 
						|
    }
 | 
						|
 | 
						|
    int accept()
 | 
						|
    {
 | 
						|
        return handshake(false);
 | 
						|
    }
 | 
						|
 | 
						|
    int write(const void *buffer, int len)
 | 
						|
    {
 | 
						|
        int ret = impl_.write(buffer, len);
 | 
						|
        state_ = ret == len ? IDLE : WRITING;
 | 
						|
        return ret;
 | 
						|
    }
 | 
						|
 | 
						|
    int read(void *buffer, int len)
 | 
						|
    {
 | 
						|
        int ret = impl_.read(buffer, len);
 | 
						|
        state_ = ret == len ? IDLE : READING;
 | 
						|
        return ret;
 | 
						|
    }
 | 
						|
 | 
						|
private:
 | 
						|
    int handshake(bool is_client_not_server)
 | 
						|
    {
 | 
						|
        if (impl_.before_handshake()) {
 | 
						|
            impl_.configure(ctx_.get(), is_client_not_server, impl_verify_mode(is_client_not_server));
 | 
						|
        }
 | 
						|
        return do_handshake();
 | 
						|
    }
 | 
						|
 | 
						|
    static int bio_read(void *ctx, unsigned char *buf, size_t len)
 | 
						|
    {
 | 
						|
        auto bio = static_cast<BIO *>(ctx);
 | 
						|
        int read = bio->read(buf, len);
 | 
						|
        if (read <= 0 && bio->should_read()) {
 | 
						|
            return MBEDTLS_ERR_SSL_WANT_READ;
 | 
						|
        }
 | 
						|
        return read;
 | 
						|
    }
 | 
						|
 | 
						|
    static int bio_write(void *ctx, const unsigned char *buf, size_t len)
 | 
						|
    {
 | 
						|
        auto bio = static_cast<BIO *>(ctx);
 | 
						|
        int written = bio->write(buf, len);
 | 
						|
        if (written <= 0 && bio->should_write()) {
 | 
						|
            return MBEDTLS_ERR_SSL_WANT_WRITE;
 | 
						|
        }
 | 
						|
        return written;
 | 
						|
    }
 | 
						|
 | 
						|
    int do_handshake()
 | 
						|
    {
 | 
						|
        int ret = 0;
 | 
						|
        mbedtls_ssl_set_bio(&impl_.ssl_, bio_.first.get(), bio_write, bio_read, nullptr);
 | 
						|
 | 
						|
        while (impl_.ssl_.MBEDTLS_PRIVATE(state) != MBEDTLS_SSL_HANDSHAKE_OVER) {
 | 
						|
            ret = mbedtls_ssl_handshake_step(&impl_.ssl_);
 | 
						|
 | 
						|
            if (ret != 0) {
 | 
						|
                if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE) {
 | 
						|
                    impl::print_error("mbedtls_ssl_handshake_step", ret);
 | 
						|
                }
 | 
						|
                if (ret == MBEDTLS_ERR_SSL_WANT_READ) {
 | 
						|
                    state_ = READING;
 | 
						|
                } else if (ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
 | 
						|
                    state_ = WRITING;
 | 
						|
                }
 | 
						|
                break;
 | 
						|
            }
 | 
						|
        }
 | 
						|
        return ret;
 | 
						|
    }
 | 
						|
 | 
						|
    // Converts OpenSSL verification mode to mbedtls enum
 | 
						|
    int impl_verify_mode(bool is_client_not_server) const
 | 
						|
    {
 | 
						|
        int mode = MBEDTLS_SSL_VERIFY_UNSET;
 | 
						|
        if (is_client_not_server) {
 | 
						|
            if (verify_mode_ & SSL_VERIFY_PEER) {
 | 
						|
                mode = MBEDTLS_SSL_VERIFY_REQUIRED;
 | 
						|
            } else if (verify_mode_ == SSL_VERIFY_NONE) {
 | 
						|
                mode = MBEDTLS_SSL_VERIFY_NONE;
 | 
						|
            }
 | 
						|
        } else {
 | 
						|
            if (verify_mode_ & SSL_VERIFY_FAIL_IF_NO_PEER_CERT) {
 | 
						|
                mode = MBEDTLS_SSL_VERIFY_REQUIRED;
 | 
						|
            } else if (verify_mode_ & SSL_VERIFY_PEER) {
 | 
						|
                mode = MBEDTLS_SSL_VERIFY_OPTIONAL;
 | 
						|
            } else if (verify_mode_ == SSL_VERIFY_NONE) {
 | 
						|
                mode = MBEDTLS_SSL_VERIFY_NONE;
 | 
						|
            }
 | 
						|
        }
 | 
						|
        return mode;
 | 
						|
    }
 | 
						|
 | 
						|
    struct impl {
 | 
						|
        static void print_error(const char *function, int error_code)
 | 
						|
        {
 | 
						|
            constexpr const char *TAG = "mbedtls-engine-impl";
 | 
						|
            ESP_LOGE(TAG, "%s() returned -0x%04X", function, -error_code);
 | 
						|
            ESP_LOGI(TAG, "-0x%04X: %s", -error_code, error_message(error_code));
 | 
						|
        }
 | 
						|
 | 
						|
        bool before_handshake() const
 | 
						|
        {
 | 
						|
            return ssl_.MBEDTLS_PRIVATE(state) == 0;
 | 
						|
        }
 | 
						|
 | 
						|
        int write(const void *buffer, int len)
 | 
						|
        {
 | 
						|
            int ret = mbedtls_ssl_write(&ssl_, static_cast<const unsigned char *>(buffer), len);
 | 
						|
            if (ret < 0 && ret != MBEDTLS_ERR_SSL_WANT_WRITE) {
 | 
						|
                print_error("mbedtls_ssl_write", ret);
 | 
						|
            }
 | 
						|
            return ret;
 | 
						|
        }
 | 
						|
 | 
						|
        int read(void *buffer, int len)
 | 
						|
        {
 | 
						|
            int ret = mbedtls_ssl_read(&ssl_, static_cast<unsigned char *>(buffer), len);
 | 
						|
            if (ret < 0 && ret != MBEDTLS_ERR_SSL_WANT_READ) {
 | 
						|
                print_error("mbedtls_ssl_read", ret);
 | 
						|
            }
 | 
						|
            return ret;
 | 
						|
        }
 | 
						|
 | 
						|
        impl()
 | 
						|
        {
 | 
						|
            const unsigned char pers[] = "asio ssl";
 | 
						|
            mbedtls_ssl_init(&ssl_);
 | 
						|
            mbedtls_ssl_config_init(&conf_);
 | 
						|
            mbedtls_ctr_drbg_init(&ctr_drbg_);
 | 
						|
#ifdef CONFIG_MBEDTLS_DEBUG
 | 
						|
            mbedtls_esp_enable_debug_log(&conf_, CONFIG_MBEDTLS_DEBUG_LEVEL);
 | 
						|
#endif
 | 
						|
            mbedtls_entropy_init(&entropy_);
 | 
						|
            mbedtls_ctr_drbg_seed(&ctr_drbg_, mbedtls_entropy_func, &entropy_, pers, sizeof(pers));
 | 
						|
            mbedtls_x509_crt_init(&public_cert_);
 | 
						|
            mbedtls_pk_init(&pk_key_);
 | 
						|
            mbedtls_x509_crt_init(&ca_cert_);
 | 
						|
        }
 | 
						|
 | 
						|
        bool configure(context *ctx, bool is_client_not_server, int mbedtls_verify_mode)
 | 
						|
        {
 | 
						|
            mbedtls_x509_crt_init(&public_cert_);
 | 
						|
            mbedtls_pk_init(&pk_key_);
 | 
						|
            mbedtls_x509_crt_init(&ca_cert_);
 | 
						|
            int ret = mbedtls_ssl_config_defaults(&conf_, is_client_not_server ? MBEDTLS_SSL_IS_CLIENT : MBEDTLS_SSL_IS_SERVER,
 | 
						|
                                                  MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT);
 | 
						|
            if (ret) {
 | 
						|
                print_error("mbedtls_ssl_config_defaults", ret);
 | 
						|
                return false;
 | 
						|
            }
 | 
						|
            mbedtls_ssl_conf_rng(&conf_, mbedtls_ctr_drbg_random, &ctr_drbg_);
 | 
						|
            mbedtls_ssl_conf_authmode(&conf_, mbedtls_verify_mode);
 | 
						|
            if (ctx->cert_chain_.size() > 0 && ctx->private_key_.size() > 0) {
 | 
						|
                ret = mbedtls_x509_crt_parse(&public_cert_, ctx->data(container::CERT), ctx->size(container::CERT));
 | 
						|
                if (ret < 0) {
 | 
						|
                    print_error("mbedtls_x509_crt_parse", ret);
 | 
						|
                    return false;
 | 
						|
                }
 | 
						|
                ret = mbedtls_pk_parse_key(&pk_key_, ctx->data(container::PRIVKEY), ctx->size(container::PRIVKEY),
 | 
						|
                                           nullptr, 0, mbedtls_ctr_drbg_random, &ctr_drbg_);
 | 
						|
                if (ret < 0) {
 | 
						|
                    print_error("mbedtls_pk_parse_keyfile", ret);
 | 
						|
                    return false;
 | 
						|
                }
 | 
						|
                ret = mbedtls_ssl_conf_own_cert(&conf_, &public_cert_, &pk_key_);
 | 
						|
                if (ret) {
 | 
						|
                    print_error("mbedtls_ssl_conf_own_cert", ret);
 | 
						|
                    return false;
 | 
						|
                }
 | 
						|
            }
 | 
						|
 | 
						|
            if (ctx->ca_cert_.size() > 0) {
 | 
						|
                ret = mbedtls_x509_crt_parse(&ca_cert_, ctx->data(container::CA_CERT), ctx->size(container::CA_CERT));
 | 
						|
                if (ret < 0) {
 | 
						|
                    print_error("mbedtls_x509_crt_parse", ret);
 | 
						|
                    return false;
 | 
						|
                }
 | 
						|
                mbedtls_ssl_conf_ca_chain(&conf_, &ca_cert_, nullptr);
 | 
						|
            } else {
 | 
						|
                mbedtls_ssl_conf_ca_chain(&conf_, nullptr, nullptr);
 | 
						|
            }
 | 
						|
            ret = mbedtls_ssl_setup(&ssl_, &conf_);
 | 
						|
            if (ret) {
 | 
						|
                print_error("mbedtls_ssl_setup", ret);
 | 
						|
                return false;
 | 
						|
            }
 | 
						|
            return true;
 | 
						|
        }
 | 
						|
        mbedtls_ssl_context ssl_{};
 | 
						|
        mbedtls_entropy_context entropy_{};
 | 
						|
        mbedtls_ctr_drbg_context ctr_drbg_{};
 | 
						|
        mbedtls_ssl_config conf_{};
 | 
						|
        mbedtls_x509_crt public_cert_{};
 | 
						|
        mbedtls_pk_context pk_key_{};
 | 
						|
        mbedtls_x509_crt ca_cert_{};
 | 
						|
    };
 | 
						|
 | 
						|
    impl impl_{};
 | 
						|
    std::shared_ptr<context> ctx_;
 | 
						|
    std::pair<std::shared_ptr<bio>, std::shared_ptr<bio>> bio_;
 | 
						|
    enum rw_state state_;
 | 
						|
    asio::ssl::verify_mode verify_mode_;
 | 
						|
};
 | 
						|
 | 
						|
}
 | 
						|
}
 | 
						|
} // namespace asio::ssl::mbedtls
 |