diff --git a/components/esp-tls/esp_tls.c b/components/esp-tls/esp_tls.c index 9803da5f4b..a4839190b1 100644 --- a/components/esp-tls/esp_tls.c +++ b/components/esp-tls/esp_tls.c @@ -267,7 +267,7 @@ static esp_err_t esp_tls_set_socket_non_blocking(int fd, bool non_blocking) return ESP_OK; } -static esp_err_t esp_tcp_connect(const char *host, int hostlen, int port, int *sockfd, esp_tls_error_handle_t error_handle, const esp_tls_cfg_t *cfg) +esp_err_t esp_tls_tcp_connect(const char *host, int hostlen, int port, const esp_tls_cfg_t *cfg, esp_tls_error_handle_t error_handle, int *sockfd) { struct sockaddr_storage address; int fd; @@ -371,7 +371,7 @@ static int esp_tls_low_level_conn(const char *hostname, int hostlen, int port, c _esp_tls_net_init(tls); tls->is_tls = true; } - if ((esp_ret = esp_tcp_connect(hostname, hostlen, port, &tls->sockfd, tls->error_handle, cfg)) != ESP_OK) { + if ((esp_ret = esp_tls_tcp_connect(hostname, hostlen, port, cfg, tls->error_handle, &tls->sockfd)) != ESP_OK) { ESP_INT_EVENT_TRACKER_CAPTURE(tls->error_handle, ESP_TLS_ERR_TYPE_ESP, esp_ret); return -1; } diff --git a/components/esp-tls/esp_tls.h b/components/esp-tls/esp_tls.h index 72a78c5e7d..2ef4848c86 100644 --- a/components/esp-tls/esp_tls.h +++ b/components/esp-tls/esp_tls.h @@ -599,6 +599,20 @@ int esp_tls_server_session_create(esp_tls_cfg_server_t *cfg, int sockfd, esp_tls void esp_tls_server_session_delete(esp_tls_t *tls); #endif /* ! CONFIG_ESP_TLS_SERVER */ +/** + * @brief Creates a plain TCP connection, returning a valid socket fd on success or an error handle + * + * @param[in] host Hostname of the host. + * @param[in] hostlen Length of hostname. + * @param[in] port Port number of the host. + * @param[in] cfg ESP-TLS configuration as esp_tls_cfg_t. + * @param[out] error_handle ESP-TLS error handle holding potential errors occurred during connection + * @param[out] sockfd Socket descriptor if successfully connected on TCP layer + * @return ESP_OK on success + * ESP-TLS based error codes on failure + */ +esp_err_t esp_tls_tcp_connect(const char *host, int hostlen, int port, const esp_tls_cfg_t *cfg, esp_tls_error_handle_t error_handle, int *sockfd); + #ifdef __cplusplus } #endif diff --git a/components/tcp_transport/test/test_transport.c b/components/tcp_transport/test/test_transport.c index 51d48f7188..505ec29507 100644 --- a/components/tcp_transport/test/test_transport.c +++ b/components/tcp_transport/test/test_transport.c @@ -315,6 +315,7 @@ static void socket_operation_test(esp_transport_handle_t transport_under_test, close(params.listen_sock); close(params.accepted_sock); + xEventGroupWaitBits(params.tcp_connect_done, TCP_LISTENER_DONE, true, true, max_wait); // Cleanup TEST_ASSERT_EQUAL(false, params.tcp_listener_failed); vEventGroupDelete(params.tcp_connect_done); diff --git a/components/tcp_transport/transport_ssl.c b/components/tcp_transport/transport_ssl.c index 60fb6e3c51..312ac453fd 100644 --- a/components/tcp_transport/transport_ssl.c +++ b/components/tcp_transport/transport_ssl.c @@ -42,6 +42,7 @@ typedef struct transport_esp_tls { esp_tls_cfg_t cfg; bool ssl_initialized; transport_ssl_conn_state_t conn_state; + int sockfd; } transport_esp_tls_t; static inline struct transport_esp_tls * ssl_get_context_data(esp_transport_handle_t t) @@ -95,12 +96,12 @@ static inline int tcp_connect_async(esp_transport_handle_t t, const char *host, return esp_tls_connect_async(t, host, port, timeout_ms, true); } -static int esp_tls_connect(esp_transport_handle_t t, const char *host, int port, int timeout_ms, bool is_plain_tcp) +static int ssl_connect(esp_transport_handle_t t, const char *host, int port, int timeout_ms) { transport_esp_tls_t *ssl = ssl_get_context_data(t); ssl->cfg.timeout_ms = timeout_ms; - ssl->cfg.is_plain_tcp = is_plain_tcp; + ssl->cfg.is_plain_tcp = false; ssl->ssl_initialized = true; ssl->tls = esp_tls_init(); @@ -114,19 +115,27 @@ static int esp_tls_connect(esp_transport_handle_t t, const char *host, int port, esp_transport_set_errors(t, ssl->tls->error_handle); esp_tls_conn_destroy(ssl->tls); ssl->tls = NULL; + ssl->sockfd = -1; return -1; } + ssl->sockfd = ssl->tls->sockfd; return 0; } -static inline int ssl_connect(esp_transport_handle_t t, const char *host, int port, int timeout_ms) +static int tcp_connect(esp_transport_handle_t t, const char *host, int port, int timeout_ms) { - return esp_tls_connect(t, host, port, timeout_ms, false); -} + transport_esp_tls_t *ssl = ssl_get_context_data(t); + esp_tls_last_error_t *err_handle = esp_transport_get_error_handle(t); -static inline int tcp_connect(esp_transport_handle_t t, const char *host, int port, int timeout_ms) -{ - return esp_tls_connect(t, host, port, timeout_ms, true); + ssl->cfg.timeout_ms = timeout_ms; + esp_err_t err = esp_tls_tcp_connect(host, strlen(host), port, &ssl->cfg, err_handle, &ssl->sockfd); + if (err != ESP_OK) { + ESP_LOGE(TAG, "Failed to open a new connection: %d", err); + err_handle->last_error = err; + ssl->sockfd = -1; + return -1; + } + return 0; } static int ssl_poll_read(esp_transport_handle_t t, int timeout_ms) @@ -139,20 +148,20 @@ static int ssl_poll_read(esp_transport_handle_t t, int timeout_ms) fd_set errset; FD_ZERO(&readset); FD_ZERO(&errset); - FD_SET(ssl->tls->sockfd, &readset); - FD_SET(ssl->tls->sockfd, &errset); + FD_SET(ssl->sockfd, &readset); + FD_SET(ssl->sockfd, &errset); - if ((remain = esp_tls_get_bytes_avail(ssl->tls)) > 0) { + if (ssl->tls && (remain = esp_tls_get_bytes_avail(ssl->tls)) > 0) { ESP_LOGD(TAG, "remain data in cache, need to read again"); return remain; } - ret = select(ssl->tls->sockfd + 1, &readset, NULL, &errset, esp_transport_utils_ms_to_timeval(timeout_ms, &timeout)); - if (ret > 0 && FD_ISSET(ssl->tls->sockfd, &errset)) { + ret = select(ssl->sockfd + 1, &readset, NULL, &errset, esp_transport_utils_ms_to_timeval(timeout_ms, &timeout)); + if (ret > 0 && FD_ISSET(ssl->sockfd, &errset)) { int sock_errno = 0; uint32_t optlen = sizeof(sock_errno); - getsockopt(ssl->tls->sockfd, SOL_SOCKET, SO_ERROR, &sock_errno, &optlen); + getsockopt(ssl->sockfd, SOL_SOCKET, SO_ERROR, &sock_errno, &optlen); esp_transport_capture_errno(t, sock_errno); - ESP_LOGE(TAG, "ssl_poll_read select error %d, errno = %s, fd = %d", sock_errno, strerror(sock_errno), ssl->tls->sockfd); + ESP_LOGE(TAG, "poll_read select error %d, errno = %s, fd = %d", sock_errno, strerror(sock_errno), ssl->tls->sockfd); ret = -1; } return ret; @@ -167,15 +176,15 @@ static int ssl_poll_write(esp_transport_handle_t t, int timeout_ms) fd_set errset; FD_ZERO(&writeset); FD_ZERO(&errset); - FD_SET(ssl->tls->sockfd, &writeset); - FD_SET(ssl->tls->sockfd, &errset); - ret = select(ssl->tls->sockfd + 1, NULL, &writeset, &errset, esp_transport_utils_ms_to_timeval(timeout_ms, &timeout)); - if (ret > 0 && FD_ISSET(ssl->tls->sockfd, &errset)) { + FD_SET(ssl->sockfd, &writeset); + FD_SET(ssl->sockfd, &errset); + ret = select(ssl->sockfd + 1, NULL, &writeset, &errset, esp_transport_utils_ms_to_timeval(timeout_ms, &timeout)); + if (ret > 0 && FD_ISSET(ssl->sockfd, &errset)) { int sock_errno = 0; uint32_t optlen = sizeof(sock_errno); - getsockopt(ssl->tls->sockfd, SOL_SOCKET, SO_ERROR, &sock_errno, &optlen); + getsockopt(ssl->sockfd, SOL_SOCKET, SO_ERROR, &sock_errno, &optlen); esp_transport_capture_errno(t, sock_errno); - ESP_LOGE(TAG, "ssl_poll_write select error %d, errno = %s, fd = %d", sock_errno, strerror(sock_errno), ssl->tls->sockfd); + ESP_LOGE(TAG, "poll_write select error %d, errno = %s, fd = %d", sock_errno, strerror(sock_errno), ssl->tls->sockfd); ret = -1; } return ret; @@ -183,14 +192,14 @@ static int ssl_poll_write(esp_transport_handle_t t, int timeout_ms) static int ssl_write(esp_transport_handle_t t, const char *buffer, int len, int timeout_ms) { - int poll, ret; + int poll; transport_esp_tls_t *ssl = ssl_get_context_data(t); if ((poll = esp_transport_poll_write(t, timeout_ms)) <= 0) { ESP_LOGW(TAG, "Poll timeout or error, errno=%s, fd=%d, timeout_ms=%d", strerror(errno), ssl->tls->sockfd, timeout_ms); return poll; } - ret = esp_tls_conn_write(ssl->tls, (const unsigned char *) buffer, len); + int ret = esp_tls_conn_write(ssl->tls, (const unsigned char *) buffer, len); if (ret < 0) { ESP_LOGE(TAG, "esp_tls_conn_write error, errno=%s", strerror(errno)); esp_transport_set_errors(t, ssl->tls->error_handle); @@ -198,15 +207,32 @@ static int ssl_write(esp_transport_handle_t t, const char *buffer, int len, int return ret; } +static int tcp_write(esp_transport_handle_t t, const char *buffer, int len, int timeout_ms) +{ + int poll; + transport_esp_tls_t *ssl = ssl_get_context_data(t); + + if ((poll = esp_transport_poll_write(t, timeout_ms)) <= 0) { + ESP_LOGW(TAG, "Poll timeout or error, errno=%s, fd=%d, timeout_ms=%d", strerror(errno), ssl->tls->sockfd, timeout_ms); + return poll; + } + int ret = send(ssl->sockfd,(const unsigned char *) buffer, len, 0); + if (ret < 0) { + ESP_LOGE(TAG, "tcp_write error, errno=%s", strerror(errno)); + esp_transport_capture_errno(t, errno); + } + return ret; +} + static int ssl_read(esp_transport_handle_t t, char *buffer, int len, int timeout_ms) { - int poll, ret; + int poll; transport_esp_tls_t *ssl = ssl_get_context_data(t); if ((poll = esp_transport_poll_read(t, timeout_ms)) <= 0) { return poll; } - ret = esp_tls_conn_read(ssl->tls, (unsigned char *)buffer, len); + int ret = esp_tls_conn_read(ssl->tls, (unsigned char *)buffer, len); if (ret < 0) { ESP_LOGE(TAG, "esp_tls_conn_read error, errno=%s", strerror(errno)); esp_transport_set_errors(t, ssl->tls->error_handle); @@ -221,6 +247,29 @@ static int ssl_read(esp_transport_handle_t t, char *buffer, int len, int timeout return ret; } +static int tcp_read(esp_transport_handle_t t, char *buffer, int len, int timeout_ms) +{ + int poll; + transport_esp_tls_t *ssl = ssl_get_context_data(t); + + if ((poll = esp_transport_poll_read(t, timeout_ms)) <= 0) { + return poll; + } + int ret = recv(ssl->sockfd, (unsigned char *)buffer, len, 0); + if (ret < 0) { + ESP_LOGE(TAG, "tcp_read error, errno=%s", strerror(errno)); + esp_transport_capture_errno(t, errno); + } + if (ret == 0) { + if (poll > 0) { + // no error, socket reads 0 while previously detected as readable -> connection has been closed cleanly + capture_tcp_transport_error(t, ERR_TCP_TRANSPORT_CONNECTION_CLOSED_BY_FIN); + } + ret = -1; + } + return ret; +} + static int ssl_close(esp_transport_handle_t t) { int ret = -1; @@ -229,6 +278,10 @@ static int ssl_close(esp_transport_handle_t t) ret = esp_tls_conn_destroy(ssl->tls); ssl->conn_state = TRANS_SSL_INIT; ssl->ssl_initialized = false; + ssl->sockfd = -1; + } else if (ssl && ssl->sockfd >= 0) { + close(ssl->sockfd); + ssl->sockfd = -1; } return ret; } @@ -344,6 +397,15 @@ static int ssl_get_socket(esp_transport_handle_t t) return -1; } +static int tcp_get_socket(esp_transport_handle_t t) +{ + transport_esp_tls_t *ctx = ssl_get_context_data(t); + if (ctx) { + return ctx->sockfd; + } + return -1; +} + void esp_transport_ssl_set_ds_data(esp_transport_handle_t t, void *ds_data) { GET_SSL_FROM_TRANSPORT_OR_RETURN(ssl, t); @@ -378,6 +440,7 @@ esp_transport_handle_t esp_transport_ssl_init(void) struct transport_esp_tls* esp_transport_esp_tls_create(void) { transport_esp_tls_t *transport_esp_tls = calloc(1, sizeof(transport_esp_tls_t)); + transport_esp_tls->sockfd = -1; return transport_esp_tls; } @@ -392,9 +455,9 @@ esp_transport_handle_t esp_transport_tcp_init(void) if (t == NULL) { return NULL; } - esp_transport_set_func(t, tcp_connect, ssl_read, ssl_write, ssl_close, ssl_poll_read, ssl_poll_write, ssl_destroy); + esp_transport_set_func(t, tcp_connect, tcp_read, tcp_write, ssl_close, ssl_poll_read, ssl_poll_write, ssl_destroy); esp_transport_set_async_connect_func(t, tcp_connect_async); - t->_get_socket = ssl_get_socket; + t->_get_socket = tcp_get_socket; return t; }