diff --git a/components/wpa_supplicant/esp_supplicant/src/esp_wpa_main.c b/components/wpa_supplicant/esp_supplicant/src/esp_wpa_main.c index 31ddb87928..af93891300 100644 --- a/components/wpa_supplicant/esp_supplicant/src/esp_wpa_main.c +++ b/components/wpa_supplicant/esp_supplicant/src/esp_wpa_main.c @@ -269,7 +269,7 @@ static int check_n_add_wps_sta(struct hostapd_data *hapd, struct sta_info *sta_i } #endif -static bool hostap_sta_join(void **sm, u8 *bssid, u8 *wpa_ie, u8 wpa_ie_len, bool *pmf_enable) +static bool hostap_sta_join(void **sta, u8 *bssid, u8 *wpa_ie, u8 wpa_ie_len, bool *pmf_enable) { struct sta_info *sta_info; struct hostapd_data *hapd = hostapd_get_hapd_data(); @@ -277,6 +277,10 @@ static bool hostap_sta_join(void **sm, u8 *bssid, u8 *wpa_ie, u8 wpa_ie_len, boo if (!hapd) { return 0; } + + if (*sta) { + ap_free_sta(hapd, *sta); + } sta_info = ap_sta_add(hapd, bssid); if (!sta_info) { wpa_printf(MSG_ERROR, "failed to add station " MACSTR, MAC2STR(bssid)); @@ -284,13 +288,12 @@ static bool hostap_sta_join(void **sm, u8 *bssid, u8 *wpa_ie, u8 wpa_ie_len, boo } #ifdef CONFIG_WPS_REGISTRAR if (check_n_add_wps_sta(hapd, sta_info, wpa_ie, wpa_ie_len, pmf_enable)) { - *sm = sta_info; + *sta = sta_info; return true; } #endif - if (wpa_ap_join(sm, bssid, wpa_ie, wpa_ie_len, pmf_enable)) { - sta_info->wpa_sm = *sm; - *sm = sta_info; + if (wpa_ap_join(sta_info, bssid, wpa_ie, wpa_ie_len, pmf_enable)) { + *sta = sta_info; return true; } diff --git a/components/wpa_supplicant/src/ap/ap_config.h b/components/wpa_supplicant/src/ap/ap_config.h index 62b4c423aa..0b6fc253f9 100644 --- a/components/wpa_supplicant/src/ap/ap_config.h +++ b/components/wpa_supplicant/src/ap/ap_config.h @@ -372,7 +372,8 @@ int hostapd_wep_key_cmp(struct hostapd_wep_keys *a, const u8 * hostapd_get_psk(const struct hostapd_bss_config *conf, const u8 *addr, const u8 *prev_psk); int hostapd_setup_wpa_psk(struct hostapd_bss_config *conf); -bool wpa_ap_join(void** sm, uint8_t *bssid, uint8_t *wpa_ie, uint8_t wpa_ie_len, bool *pmf_enable); -bool wpa_ap_remove(void* sm); +struct sta_info; +bool wpa_ap_join(struct sta_info *sta, uint8_t *bssid, uint8_t *wpa_ie, uint8_t wpa_ie_len, bool *pmf_enable); +bool wpa_ap_remove(void* sta_info); #endif /* HOSTAPD_CONFIG_H */ diff --git a/components/wpa_supplicant/src/ap/wpa_auth.c b/components/wpa_supplicant/src/ap/wpa_auth.c index 6c0e902876..f13cc15a2a 100644 --- a/components/wpa_supplicant/src/ap/wpa_auth.c +++ b/components/wpa_supplicant/src/ap/wpa_auth.c @@ -2340,53 +2340,49 @@ static int wpa_sm_step(struct wpa_state_machine *sm) return 0; } -bool wpa_ap_join(void** sm, uint8_t *bssid, uint8_t *wpa_ie, uint8_t wpa_ie_len, bool *pmf_enable) +bool wpa_ap_join(struct sta_info *sta, uint8_t *bssid, uint8_t *wpa_ie, uint8_t wpa_ie_len, bool *pmf_enable) { struct hostapd_data *hapd = (struct hostapd_data*)esp_wifi_get_hostap_private_internal(); - struct wpa_state_machine **wpa_sm; - if (!sm || !bssid || !wpa_ie){ + if (!sta || !bssid || !wpa_ie){ return false; } - - wpa_sm = (struct wpa_state_machine **)sm; - if (hapd) { if (hapd->wpa_auth->conf.wpa) { - if (*wpa_sm){ - wpa_auth_sta_deinit(*wpa_sm); + if (sta->wpa_sm){ + wpa_auth_sta_deinit(sta->wpa_sm); } - *wpa_sm = wpa_auth_sta_init(hapd->wpa_auth, bssid); - wpa_printf( MSG_DEBUG, "init wpa sm=%p\n", *wpa_sm); + sta->wpa_sm = wpa_auth_sta_init(hapd->wpa_auth, bssid); + wpa_printf( MSG_DEBUG, "init wpa sm=%p\n", sta->wpa_sm); - if (*wpa_sm == NULL) { + if (sta->wpa_sm == NULL) { return false; } - if (wpa_validate_wpa_ie(hapd->wpa_auth, *wpa_sm, wpa_ie, wpa_ie_len)) { + if (wpa_validate_wpa_ie(hapd->wpa_auth, sta->wpa_sm, wpa_ie, wpa_ie_len)) { return false; } //Check whether AP uses Management Frame Protection for this connection - *pmf_enable = wpa_auth_uses_mfp(*wpa_sm); + *pmf_enable = wpa_auth_uses_mfp(sta->wpa_sm); } - wpa_auth_sta_associated(hapd->wpa_auth, *wpa_sm); + wpa_auth_sta_associated(hapd->wpa_auth, sta->wpa_sm); } return true; } -bool wpa_ap_remove(void* sm) +bool wpa_ap_remove(void* sta_info) { struct hostapd_data *hapd = hostapd_get_hapd_data(); - if (!sm || !hapd) { + if (!sta_info || !hapd) { return false; } - ap_free_sta(hapd, sm); + ap_free_sta(hapd, sta_info); return true; }