diff --git a/components/mbedtls/port/dynamic/esp_mbedtls_dynamic_impl.c b/components/mbedtls/port/dynamic/esp_mbedtls_dynamic_impl.c index b26d092279..fcebfab1f9 100644 --- a/components/mbedtls/port/dynamic/esp_mbedtls_dynamic_impl.c +++ b/components/mbedtls/port/dynamic/esp_mbedtls_dynamic_impl.c @@ -23,6 +23,33 @@ static const char *TAG = "Dynamic Impl"; +static void esp_mbedtls_set_buf_state(unsigned char *buf, esp_mbedtls_ssl_buf_states state) +{ + struct esp_mbedtls_ssl_buf *temp = __containerof(buf, struct esp_mbedtls_ssl_buf, buf[0]); + temp->state = state; +} + +static esp_mbedtls_ssl_buf_states esp_mbedtls_get_buf_state(unsigned char *buf) +{ + struct esp_mbedtls_ssl_buf *temp = __containerof(buf, struct esp_mbedtls_ssl_buf, buf[0]); + return temp->state; +} + +void esp_mbedtls_free_buf(unsigned char *buf) +{ + struct esp_mbedtls_ssl_buf *temp = __containerof(buf, struct esp_mbedtls_ssl_buf, buf[0]); + ESP_LOGV(TAG, "free buffer @ %p", temp); + mbedtls_free(temp); +} + +static void esp_mbedtls_init_ssl_buf(struct esp_mbedtls_ssl_buf *buf, unsigned int len) +{ + if (buf) { + buf->state = ESP_MBEDTLS_SSL_BUF_CACHED; + buf->len = len; + } +} + static void esp_mbedtls_parse_record_header(mbedtls_ssl_context *ssl) { ssl->in_msgtype = ssl->in_hdr[0]; @@ -118,21 +145,22 @@ static void init_rx_buffer(mbedtls_ssl_context *ssl, unsigned char *buf) static int esp_mbedtls_alloc_tx_buf(mbedtls_ssl_context *ssl, int len) { - unsigned char *buf; + struct esp_mbedtls_ssl_buf *esp_buf; if (ssl->out_buf) { - mbedtls_free(ssl->out_buf); + esp_mbedtls_free_buf(ssl->out_buf); ssl->out_buf = NULL; } - buf = mbedtls_calloc(1, len); - if (!buf) { - ESP_LOGE(TAG, "alloc(%d bytes) failed", len); + esp_buf = mbedtls_calloc(1, SSL_BUF_HEAD_OFFSET_SIZE + len); + if (!esp_buf) { + ESP_LOGE(TAG, "alloc(%d bytes) failed", SSL_BUF_HEAD_OFFSET_SIZE + len); return MBEDTLS_ERR_SSL_ALLOC_FAILED; } - ESP_LOGV(TAG, "add out buffer %d bytes @ %p", len, buf); + ESP_LOGV(TAG, "add out buffer %d bytes @ %p", len, esp_buf->buf); + esp_mbedtls_init_ssl_buf(esp_buf, len); /** * Mark the out_msg offset from ssl->out_buf. * @@ -140,7 +168,7 @@ static int esp_mbedtls_alloc_tx_buf(mbedtls_ssl_context *ssl, int len) */ ssl->out_msg = (unsigned char *)MBEDTLS_SSL_HEADER_LEN; - init_tx_buffer(ssl, buf); + init_tx_buffer(ssl, esp_buf->buf); return 0; } @@ -150,7 +178,7 @@ int esp_mbedtls_setup_tx_buffer(mbedtls_ssl_context *ssl) CHECK_OK(esp_mbedtls_alloc_tx_buf(ssl, TX_IDLE_BUFFER_SIZE)); /* mark the out buffer has no data cached */ - ssl->out_iv = NULL; + esp_mbedtls_set_buf_state(ssl->out_buf, ESP_MBEDTLS_SSL_BUF_NO_CACHED); return 0; } @@ -168,10 +196,7 @@ int esp_mbedtls_reset_add_tx_buffer(mbedtls_ssl_context *ssl) int esp_mbedtls_reset_free_tx_buffer(mbedtls_ssl_context *ssl) { - ESP_LOGV(TAG, "free out buffer @ %p", ssl->out_buf); - - mbedtls_free(ssl->out_buf); - + esp_mbedtls_free_buf(ssl->out_buf); init_tx_buffer(ssl, NULL); CHECK_OK(esp_mbedtls_setup_tx_buffer(ssl)); @@ -181,21 +206,22 @@ int esp_mbedtls_reset_free_tx_buffer(mbedtls_ssl_context *ssl) int esp_mbedtls_reset_add_rx_buffer(mbedtls_ssl_context *ssl) { - unsigned char *buf; + struct esp_mbedtls_ssl_buf *esp_buf; if (ssl->in_buf) { - mbedtls_free(ssl->in_buf); + esp_mbedtls_free_buf(ssl->in_buf); ssl->in_buf = NULL; } - buf = mbedtls_calloc(1, MBEDTLS_SSL_IN_BUFFER_LEN); - if (!buf) { - ESP_LOGE(TAG, "alloc(%d bytes) failed", MBEDTLS_SSL_IN_BUFFER_LEN); + esp_buf = mbedtls_calloc(1, SSL_BUF_HEAD_OFFSET_SIZE + MBEDTLS_SSL_IN_BUFFER_LEN); + if (!esp_buf) { + ESP_LOGE(TAG, "alloc(%d bytes) failed", SSL_BUF_HEAD_OFFSET_SIZE + MBEDTLS_SSL_IN_BUFFER_LEN); return MBEDTLS_ERR_SSL_ALLOC_FAILED; } - ESP_LOGV(TAG, "add in buffer %d bytes @ %p", MBEDTLS_SSL_IN_BUFFER_LEN, buf); + ESP_LOGV(TAG, "add in buffer %d bytes @ %p", MBEDTLS_SSL_IN_BUFFER_LEN, esp_buf->buf); + esp_mbedtls_init_ssl_buf(esp_buf, MBEDTLS_SSL_IN_BUFFER_LEN); /** * Mark the in_msg offset from ssl->in_buf. * @@ -203,38 +229,34 @@ int esp_mbedtls_reset_add_rx_buffer(mbedtls_ssl_context *ssl) */ ssl->in_msg = (unsigned char *)MBEDTLS_SSL_HEADER_LEN; - init_rx_buffer(ssl, buf); + init_rx_buffer(ssl, esp_buf->buf); return 0; } void esp_mbedtls_reset_free_rx_buffer(mbedtls_ssl_context *ssl) { - ESP_LOGV(TAG, "free in buffer @ %p", ssl->in_buf); - - mbedtls_free(ssl->in_buf); - - init_rx_buffer(ssl, NULL); + esp_mbedtls_free_buf(ssl->in_buf); + init_rx_buffer(ssl, NULL); } int esp_mbedtls_add_tx_buffer(mbedtls_ssl_context *ssl, size_t buffer_len) { int ret = 0; int cached = 0; - unsigned char *buf; + struct esp_mbedtls_ssl_buf *esp_buf; unsigned char cache_buf[CACHE_BUFFER_SIZE]; ESP_LOGV(TAG, "--> add out"); if (ssl->out_buf) { - if (ssl->out_iv) { + if (esp_mbedtls_get_buf_state(ssl->out_buf) == ESP_MBEDTLS_SSL_BUF_CACHED) { ESP_LOGV(TAG, "out buffer is not empty"); ret = 0; goto exit; } else { memcpy(cache_buf, ssl->out_buf, CACHE_BUFFER_SIZE); - - mbedtls_free(ssl->out_buf); + esp_mbedtls_free_buf(ssl->out_buf); init_tx_buffer(ssl, NULL); cached = 1; } @@ -242,15 +264,17 @@ int esp_mbedtls_add_tx_buffer(mbedtls_ssl_context *ssl, size_t buffer_len) buffer_len = tx_buffer_len(ssl, buffer_len); - buf = mbedtls_calloc(1, buffer_len); - if (!buf) { - ESP_LOGE(TAG, "alloc(%d bytes) failed", buffer_len); + esp_buf = mbedtls_calloc(1, SSL_BUF_HEAD_OFFSET_SIZE + buffer_len); + if (!esp_buf) { + ESP_LOGE(TAG, "alloc(%d bytes) failed", SSL_BUF_HEAD_OFFSET_SIZE + buffer_len); ret = MBEDTLS_ERR_SSL_ALLOC_FAILED; goto exit; } - ESP_LOGV(TAG, "add out buffer %d bytes @ %p", buffer_len, buf); - init_tx_buffer(ssl, buf); + ESP_LOGV(TAG, "add out buffer %d bytes @ %p", buffer_len, esp_buf->buf); + + esp_mbedtls_init_ssl_buf(esp_buf, buffer_len); + init_tx_buffer(ssl, esp_buf->buf); if (cached) { memcpy(ssl->out_ctr, cache_buf, COUNTER_SIZE); @@ -270,11 +294,11 @@ int esp_mbedtls_free_tx_buffer(mbedtls_ssl_context *ssl) { int ret = 0; unsigned char buf[CACHE_BUFFER_SIZE]; - unsigned char *pdata; + struct esp_mbedtls_ssl_buf *esp_buf; ESP_LOGV(TAG, "--> free out"); - if (!ssl->out_buf || (ssl->out_buf && !ssl->out_iv)) { + if (!ssl->out_buf || (ssl->out_buf && (esp_mbedtls_get_buf_state(ssl->out_buf) == ESP_MBEDTLS_SSL_BUF_NO_CACHED))) { ret = 0; goto exit; } @@ -282,22 +306,19 @@ int esp_mbedtls_free_tx_buffer(mbedtls_ssl_context *ssl) memcpy(buf, ssl->out_ctr, COUNTER_SIZE); memcpy(buf + COUNTER_SIZE, ssl->out_iv, CACHE_IV_SIZE); - ESP_LOGV(TAG, "free out buffer @ %p", ssl->out_buf); - - mbedtls_free(ssl->out_buf); - + esp_mbedtls_free_buf(ssl->out_buf); init_tx_buffer(ssl, NULL); - pdata = mbedtls_calloc(1, TX_IDLE_BUFFER_SIZE); - if (!pdata) { - ESP_LOGE(TAG, "alloc(%d bytes) failed", TX_IDLE_BUFFER_SIZE); + esp_buf = mbedtls_calloc(1, SSL_BUF_HEAD_OFFSET_SIZE + TX_IDLE_BUFFER_SIZE); + if (!esp_buf) { + ESP_LOGE(TAG, "alloc(%d bytes) failed", SSL_BUF_HEAD_OFFSET_SIZE + TX_IDLE_BUFFER_SIZE); return MBEDTLS_ERR_SSL_ALLOC_FAILED; } - memcpy(pdata, buf, CACHE_BUFFER_SIZE); - init_tx_buffer(ssl, pdata); - ssl->out_iv = NULL; - + esp_mbedtls_init_ssl_buf(esp_buf, TX_IDLE_BUFFER_SIZE); + memcpy(esp_buf->buf, buf, CACHE_BUFFER_SIZE); + init_tx_buffer(ssl, esp_buf->buf); + esp_mbedtls_set_buf_state(ssl->out_buf, ESP_MBEDTLS_SSL_BUF_NO_CACHED); exit: ESP_LOGV(TAG, "<-- free out"); @@ -309,7 +330,7 @@ int esp_mbedtls_add_rx_buffer(mbedtls_ssl_context *ssl) int cached = 0; int ret = 0; int buffer_len; - unsigned char *buf; + struct esp_mbedtls_ssl_buf *esp_buf; unsigned char cache_buf[16]; unsigned char msg_head[5]; size_t in_msglen, in_left; @@ -317,9 +338,13 @@ int esp_mbedtls_add_rx_buffer(mbedtls_ssl_context *ssl) ESP_LOGV(TAG, "--> add rx"); if (ssl->in_buf) { - ESP_LOGV(TAG, "in buffer is not empty"); - ret = 0; - goto exit; + if (esp_mbedtls_get_buf_state(ssl->in_buf) == ESP_MBEDTLS_SSL_BUF_CACHED) { + ESP_LOGV(TAG, "in buffer is not empty"); + ret = 0; + goto exit; + } else { + cached = 1; + } } ssl->in_hdr = msg_head; @@ -346,22 +371,23 @@ int esp_mbedtls_add_rx_buffer(mbedtls_ssl_context *ssl) ESP_LOGV(TAG, "message length is %d RX buffer length should be %d left is %d", (int)in_msglen, (int)buffer_len, (int)ssl->in_left); - buf = mbedtls_calloc(1, buffer_len); - if (!buf) { - ESP_LOGE(TAG, "alloc(%d bytes) failed", buffer_len); + if (cached) { + memcpy(cache_buf, ssl->in_buf, 16); + esp_mbedtls_free_buf(ssl->in_buf); + init_rx_buffer(ssl, NULL); + } + + esp_buf = mbedtls_calloc(1, SSL_BUF_HEAD_OFFSET_SIZE + buffer_len); + if (!esp_buf) { + ESP_LOGE(TAG, "alloc(%d bytes) failed", SSL_BUF_HEAD_OFFSET_SIZE + buffer_len); ret = MBEDTLS_ERR_SSL_ALLOC_FAILED; goto exit; } - ESP_LOGV(TAG, "add in buffer %d bytes @ %p", buffer_len, buf); + ESP_LOGV(TAG, "add in buffer %d bytes @ %p", buffer_len, esp_buf->buf); - if (ssl->in_ctr) { - memcpy(cache_buf, ssl->in_ctr, 16); - mbedtls_free(ssl->in_ctr); - cached = 1; - } - - init_rx_buffer(ssl, buf); + esp_mbedtls_init_ssl_buf(esp_buf, buffer_len); + init_rx_buffer(ssl, esp_buf->buf); if (cached) { memcpy(ssl->in_ctr, cache_buf, 8); @@ -382,14 +408,15 @@ int esp_mbedtls_free_rx_buffer(mbedtls_ssl_context *ssl) { int ret = 0; unsigned char buf[16]; - unsigned char *pdata; + struct esp_mbedtls_ssl_buf *esp_buf; ESP_LOGV(TAG, "--> free rx"); /** * When have read multi messages once, can't free the input buffer directly. */ - if (!ssl->in_buf || (ssl->in_hslen && (ssl->in_hslen < ssl->in_msglen))) { + if (!ssl->in_buf || (ssl->in_hslen && (ssl->in_hslen < ssl->in_msglen)) || + (ssl->in_buf && (esp_mbedtls_get_buf_state(ssl->in_buf) == ESP_MBEDTLS_SSL_BUF_NO_CACHED))) { ret = 0; goto exit; } @@ -404,22 +431,20 @@ int esp_mbedtls_free_rx_buffer(mbedtls_ssl_context *ssl) memcpy(buf, ssl->in_ctr, 8); memcpy(buf + 8, ssl->in_iv, 8); - ESP_LOGV(TAG, "free in buffer @ %p", ssl->out_buf); - - mbedtls_free(ssl->in_buf); - + esp_mbedtls_free_buf(ssl->in_buf); init_rx_buffer(ssl, NULL); - pdata = mbedtls_calloc(1, 16); - if (!pdata) { - ESP_LOGE(TAG, "alloc(%d bytes) failed", 16); + esp_buf = mbedtls_calloc(1, SSL_BUF_HEAD_OFFSET_SIZE + 16); + if (!esp_buf) { + ESP_LOGE(TAG, "alloc(%d bytes) failed", SSL_BUF_HEAD_OFFSET_SIZE + 16); ret = MBEDTLS_ERR_SSL_ALLOC_FAILED; goto exit; } - memcpy(pdata, buf, 16); - ssl->in_ctr = pdata; - + esp_mbedtls_init_ssl_buf(esp_buf, 16); + memcpy(esp_buf->buf, buf, 16); + init_rx_buffer(ssl, esp_buf->buf); + esp_mbedtls_set_buf_state(ssl->in_buf, ESP_MBEDTLS_SSL_BUF_NO_CACHED); exit: ESP_LOGV(TAG, "<-- free rx"); @@ -516,4 +541,17 @@ void esp_mbedtls_free_peer_cert(mbedtls_ssl_context *ssl) ssl->session_negotiate->peer_cert = NULL; } } + +bool esp_mbedtls_ssl_is_rsa(mbedtls_ssl_context *ssl) +{ + const mbedtls_ssl_ciphersuite_t *ciphersuite_info = + ssl->transform_negotiate->ciphersuite_info; + + if (ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_RSA || + ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_RSA_PSK) { + return true; + } else { + return false; + } +} #endif diff --git a/components/mbedtls/port/dynamic/esp_mbedtls_dynamic_impl.h b/components/mbedtls/port/dynamic/esp_mbedtls_dynamic_impl.h index 2121cf3a0d..8f74097d5d 100644 --- a/components/mbedtls/port/dynamic/esp_mbedtls_dynamic_impl.h +++ b/components/mbedtls/port/dynamic/esp_mbedtls_dynamic_impl.h @@ -33,9 +33,6 @@ \ if ((_ret = _fn) != 0) { \ ESP_LOGV(TAG, "\"%s\" result is -0x%x", # _fn, -_ret); \ - if (_ret == MBEDTLS_ERR_SSL_CONN_EOF) {\ - return 0; \ - } \ TRACE_CHECK(_fn, "fail"); \ return _ret; \ } \ @@ -44,6 +41,21 @@ \ }) +typedef enum { + ESP_MBEDTLS_SSL_BUF_CACHED, + ESP_MBEDTLS_SSL_BUF_NO_CACHED, +} esp_mbedtls_ssl_buf_states; + +struct esp_mbedtls_ssl_buf { + esp_mbedtls_ssl_buf_states state; + unsigned int len; + unsigned char buf[]; +}; + +#define SSL_BUF_HEAD_OFFSET_SIZE offsetof(struct esp_mbedtls_ssl_buf, buf) + +void esp_mbedtls_free_buf(unsigned char *buf); + int esp_mbedtls_setup_tx_buffer(mbedtls_ssl_context *ssl); void esp_mbedtls_setup_rx_buffer(mbedtls_ssl_context *ssl); @@ -82,6 +94,8 @@ void esp_mbedtls_free_cacert(mbedtls_ssl_context *ssl); #ifdef CONFIG_MBEDTLS_DYNAMIC_FREE_PEER_CERT void esp_mbedtls_free_peer_cert(mbedtls_ssl_context *ssl); + +bool esp_mbedtls_ssl_is_rsa(mbedtls_ssl_context *ssl); #endif #endif /* _DYNAMIC_IMPL_H_ */ diff --git a/components/mbedtls/port/dynamic/esp_ssl_cli.c b/components/mbedtls/port/dynamic/esp_ssl_cli.c index a500eb85ea..94ed064d6d 100644 --- a/components/mbedtls/port/dynamic/esp_ssl_cli.c +++ b/components/mbedtls/port/dynamic/esp_ssl_cli.c @@ -73,7 +73,17 @@ static int manage_resource(mbedtls_ssl_context *ssl, bool add) CHECK_OK(esp_mbedtls_free_rx_buffer(ssl)); } #ifdef CONFIG_MBEDTLS_DYNAMIC_FREE_PEER_CERT - esp_mbedtls_free_peer_cert(ssl); + /** + * If current ciphersuite is RSA, we should free peer' + * certificate at step MBEDTLS_SSL_CLIENT_KEY_EXCHANGE. + * + * And if it is other kinds of ciphersuite, we can free + * peer certificate here. + */ + + if (esp_mbedtls_ssl_is_rsa(ssl) == false) { + esp_mbedtls_free_peer_cert(ssl); + } #endif } break; @@ -123,6 +133,12 @@ static int manage_resource(mbedtls_ssl_context *ssl, bool add) size_t buffer_len = MBEDTLS_SSL_OUT_BUFFER_LEN; CHECK_OK(esp_mbedtls_add_tx_buffer(ssl, buffer_len)); + } else { +#ifdef CONFIG_MBEDTLS_DYNAMIC_FREE_PEER_CERT + if (esp_mbedtls_ssl_is_rsa(ssl) == true) { + esp_mbedtls_free_peer_cert(ssl); + } +#endif } break; case MBEDTLS_SSL_CERTIFICATE_VERIFY: diff --git a/components/mbedtls/port/dynamic/esp_ssl_tls.c b/components/mbedtls/port/dynamic/esp_ssl_tls.c index 7d0372040b..d8b4506b58 100644 --- a/components/mbedtls/port/dynamic/esp_ssl_tls.c +++ b/components/mbedtls/port/dynamic/esp_ssl_tls.c @@ -85,7 +85,16 @@ int __wrap_mbedtls_ssl_read(mbedtls_ssl_context *ssl, unsigned char *buf, size_t { int ret; - CHECK_OK(esp_mbedtls_add_rx_buffer(ssl)); + ESP_LOGD(TAG, "add mbedtls RX buffer"); + ret = esp_mbedtls_add_rx_buffer(ssl); + if (ret == MBEDTLS_ERR_SSL_CONN_EOF) { + ESP_LOGD(TAG, "fail, the connection indicated an EOF"); + return 0; + } else if (ret < 0) { + ESP_LOGD(TAG, "fail, error=-0x%x", -ret); + return ret; + } + ESP_LOGD(TAG, "end"); ret = __real_mbedtls_ssl_read(ssl, buf, len); @@ -99,12 +108,12 @@ int __wrap_mbedtls_ssl_read(mbedtls_ssl_context *ssl, unsigned char *buf, size_t void __wrap_mbedtls_ssl_free(mbedtls_ssl_context *ssl) { if (ssl->out_buf) { - mbedtls_free(ssl->out_buf); + esp_mbedtls_free_buf(ssl->out_buf); ssl->out_buf = NULL; } if (ssl->in_buf) { - mbedtls_free(ssl->in_buf); + esp_mbedtls_free_buf(ssl->in_buf); ssl->in_buf = NULL; }