diff --git a/components/wpa_supplicant/esp_supplicant/src/crypto/tls_mbedtls.c b/components/wpa_supplicant/esp_supplicant/src/crypto/tls_mbedtls.c index 274ed6af70..606af1a0e3 100644 --- a/components/wpa_supplicant/esp_supplicant/src/crypto/tls_mbedtls.c +++ b/components/wpa_supplicant/esp_supplicant/src/crypto/tls_mbedtls.c @@ -1001,19 +1001,30 @@ static int tls_connection_prf(void *tls_ctx, struct tls_connection *conn, size_t context_len, int server_random_first, u8 *out, size_t out_len) { - if (context) + int ret; + u8 *seed, *pos; + size_t seed_len = 2 * TLS_RANDOM_LEN; + mbedtls_ssl_context *ssl = &conn->tls->ssl; + + if (context_len > 65535) return -1; - int ret; - u8 seed[2 * TLS_RANDOM_LEN]; - mbedtls_ssl_context *ssl = &conn->tls->ssl; + if (context) + seed_len += 2 + context_len; + + seed = os_malloc(seed_len); + if (!seed) { + return -1; + } if (!ssl) { wpa_printf(MSG_ERROR, "TLS: %s, session ingo is null", __func__); + os_free(seed); return -1; } if (!mbedtls_ssl_is_handshake_over(ssl)) { wpa_printf(MSG_ERROR, "TLS: %s, incorrect tls state=%d", __func__, ssl->MBEDTLS_PRIVATE(state)); + os_free(seed); return -1; } @@ -1024,14 +1035,23 @@ static int tls_connection_prf(void *tls_ctx, struct tls_connection *conn, os_memcpy(seed, conn->randbytes, 2 * TLS_RANDOM_LEN); } + if (context) { + pos = seed + 2 * TLS_RANDOM_LEN; + WPA_PUT_BE16(pos, context_len); + pos += 2; + os_memcpy(pos, context, context_len); + } + wpa_hexdump_key(MSG_MSGDUMP, "random", seed, 2 * TLS_RANDOM_LEN); wpa_hexdump_key(MSG_MSGDUMP, "master", ssl->MBEDTLS_PRIVATE(session)->MBEDTLS_PRIVATE(master), TLS_MASTER_SECRET_LEN); ret = mbedtls_ssl_tls_prf(conn->tls_prf_type, conn->master_secret, TLS_MASTER_SECRET_LEN, label, seed, 2 * TLS_RANDOM_LEN, out, out_len); + os_free(seed); if (ret < 0) { wpa_printf(MSG_ERROR, "prf failed, ret=%d", ret); + return -1; } wpa_hexdump_key(MSG_MSGDUMP, "key", out, out_len);