diff --git a/components/wpa_supplicant/src/common/sae.c b/components/wpa_supplicant/src/common/sae.c index ee5ff1b06e..37b918230c 100644 --- a/components/wpa_supplicant/src/common/sae.c +++ b/components/wpa_supplicant/src/common/sae.c @@ -10,6 +10,7 @@ #include "utils/includes.h" #include "utils/common.h" +#include "common/wpa_common.h" #include "utils/const_time.h" #include "crypto/crypto.h" #include "crypto/sha256.h" @@ -1466,10 +1467,11 @@ static int sae_derive_keys(struct sae_data *sae, const u8 *k) const u8 *salt; struct wpabuf *rejected_groups = NULL; u8 keyseed[SAE_MAX_HASH_LEN]; - u8 keys[2 * SAE_MAX_HASH_LEN + SAE_PMK_LEN]; + u8 keys[2 * SAE_MAX_HASH_LEN + SAE_PMK_LEN_MAX]; struct crypto_bignum *tmp; int ret = -1; size_t hash_len, salt_len, prime_len = sae->tmp->prime_len; + size_t pmk_len; const u8 *addr[1]; size_t len[1]; @@ -1491,6 +1493,10 @@ static int sae_derive_keys(struct sae_data *sae, const u8 *k) hash_len = sae_ffc_prime_len_2_hash_len(prime_len); else hash_len = sae_ecc_prime_len_2_hash_len(prime_len); + if (wpa_key_mgmt_sae_ext_key(sae->akmp)) + pmk_len = hash_len; + else + pmk_len = SAE_PMK_LEN; if (sae->h2e && (sae->tmp->own_rejected_groups || sae->tmp->peer_rejected_groups)) { @@ -1552,27 +1558,27 @@ static int sae_derive_keys(struct sae_data *sae, const u8 *k) if (sae->pk) { if (sae_kdf_hash(hash_len, keyseed, "SAE-PK keys", val, sae->tmp->order_len, - keys, 2 * hash_len + SAE_PMK_LEN) < 0) + keys, 2 * hash_len + pmk_len) < 0) goto fail; } else { if (sae_kdf_hash(hash_len, keyseed, "SAE KCK and PMK", val, sae->tmp->order_len, - keys, hash_len + SAE_PMK_LEN) < 0) + keys, hash_len + pmk_len) < 0) goto fail; } #else /* CONFIG_SAE_PK */ if (sae_kdf_hash(hash_len, keyseed, "SAE KCK and PMK", val, sae->tmp->order_len, - keys, hash_len + SAE_PMK_LEN) < 0) + keys, hash_len + pmk_len) < 0) goto fail; #endif /* !CONFIG_SAE_PK */ forced_memzero(keyseed, sizeof(keyseed)); os_memcpy(sae->tmp->kck, keys, hash_len); sae->tmp->kck_len = hash_len; - os_memcpy(sae->pmk, keys + hash_len, SAE_PMK_LEN); + os_memcpy(sae->pmk, keys + hash_len, pmk_len); + sae->pmk_len = pmk_len; os_memcpy(sae->pmkid, val, SAE_PMKID_LEN); - #ifdef CONFIG_SAE_PK if (sae->pk) { os_memcpy(sae->tmp->kek, keys + hash_len + SAE_PMK_LEN, @@ -1585,7 +1591,7 @@ static int sae_derive_keys(struct sae_data *sae, const u8 *k) forced_memzero(keys, sizeof(keys)); wpa_hexdump_key(MSG_DEBUG, "SAE: KCK", sae->tmp->kck, sae->tmp->kck_len); - wpa_hexdump_key(MSG_DEBUG, "SAE: PMK", sae->pmk, SAE_PMK_LEN); + wpa_hexdump_key(MSG_DEBUG, "SAE: PMK", sae->pmk, sae->pmk_len); ret = 0; fail: diff --git a/components/wpa_supplicant/src/common/sae.h b/components/wpa_supplicant/src/common/sae.h index 028330268a..ca4233159b 100644 --- a/components/wpa_supplicant/src/common/sae.h +++ b/components/wpa_supplicant/src/common/sae.h @@ -108,7 +108,7 @@ enum sae_state { struct sae_data { enum sae_state state; u16 send_confirm; - u8 pmk[SAE_PMK_LEN]; + u8 pmk[SAE_PMK_LEN_MAX]; size_t pmk_len; int akmp; /* WPA_KEY_MGMT_* used in key derivation */ u32 own_akm_suite_selector;