mirror of
https://github.com/espressif/esp-idf.git
synced 2025-07-31 19:24:33 +02:00
fix(esp_local_ctrl): update for changes in protocomm security2 scheme
This commit is contained in:
@@ -1,5 +1,5 @@
|
|||||||
/*
|
/*
|
||||||
* SPDX-FileCopyrightText: 2019-2022 Espressif Systems (Shanghai) CO LTD
|
* SPDX-FileCopyrightText: 2019-2025 Espressif Systems (Shanghai) CO LTD
|
||||||
*
|
*
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
*/
|
*/
|
||||||
@@ -20,6 +20,8 @@
|
|||||||
#include "esp_local_ctrl.pb-c.h"
|
#include "esp_local_ctrl.pb-c.h"
|
||||||
|
|
||||||
#define ESP_LOCAL_CTRL_VERSION "v1.0"
|
#define ESP_LOCAL_CTRL_VERSION "v1.0"
|
||||||
|
/* JSON format string for version endpoint */
|
||||||
|
#define ESP_LOCAL_CTRL_VER_FMT_STR "{\"local_ctrl\":{\"ver\":\"%s\",\"sec_ver\":%d,\"sec_patch_ver\":%d}}"
|
||||||
|
|
||||||
struct inst_ctx {
|
struct inst_ctx {
|
||||||
protocomm_t *pc;
|
protocomm_t *pc;
|
||||||
@@ -136,14 +138,6 @@ esp_err_t esp_local_ctrl_start(const esp_local_ctrl_config_t *config)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ret = protocomm_set_version(local_ctrl_inst_ctx->pc, "esp_local_ctrl/version",
|
|
||||||
ESP_LOCAL_CTRL_VERSION);
|
|
||||||
if (ret != ESP_OK) {
|
|
||||||
ESP_LOGE(TAG, "Failed to set version endpoint");
|
|
||||||
esp_local_ctrl_stop();
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
protocomm_security_t *proto_sec_handle = NULL;
|
protocomm_security_t *proto_sec_handle = NULL;
|
||||||
switch (local_ctrl_inst_ctx->config.proto_sec.version) {
|
switch (local_ctrl_inst_ctx->config.proto_sec.version) {
|
||||||
case PROTOCOM_SEC_CUSTOM:
|
case PROTOCOM_SEC_CUSTOM:
|
||||||
@@ -183,6 +177,29 @@ esp_err_t esp_local_ctrl_start(const esp_local_ctrl_config_t *config)
|
|||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int sec_ver = 0;
|
||||||
|
uint8_t sec_patch_ver = 0;
|
||||||
|
protocomm_get_sec_version(local_ctrl_inst_ctx->pc, &sec_ver, &sec_patch_ver);
|
||||||
|
|
||||||
|
const int rsize = snprintf(NULL, 0, ESP_LOCAL_CTRL_VER_FMT_STR, ESP_LOCAL_CTRL_VERSION, sec_ver, sec_patch_ver) + 1;
|
||||||
|
char *ver_str = malloc(rsize);
|
||||||
|
if (!ver_str) {
|
||||||
|
ESP_LOGE(TAG, "Failed to allocate memory for version string");
|
||||||
|
esp_local_ctrl_stop();
|
||||||
|
return ESP_ERR_NO_MEM;
|
||||||
|
}
|
||||||
|
snprintf(ver_str, rsize, ESP_LOCAL_CTRL_VER_FMT_STR, ESP_LOCAL_CTRL_VERSION, sec_ver, sec_patch_ver);
|
||||||
|
|
||||||
|
ESP_LOGD(TAG, "ver_str: %s", ver_str);
|
||||||
|
ret = protocomm_set_version(local_ctrl_inst_ctx->pc, "esp_local_ctrl/version",
|
||||||
|
ver_str);
|
||||||
|
free(ver_str);
|
||||||
|
if (ret != ESP_OK) {
|
||||||
|
ESP_LOGE(TAG, "Failed to set version endpoint");
|
||||||
|
esp_local_ctrl_stop();
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
ret = protocomm_add_endpoint(local_ctrl_inst_ctx->pc, "esp_local_ctrl/control",
|
ret = protocomm_add_endpoint(local_ctrl_inst_ctx->pc, "esp_local_ctrl/control",
|
||||||
esp_local_ctrl_data_handler, NULL);
|
esp_local_ctrl_data_handler, NULL);
|
||||||
if (ret != ESP_OK) {
|
if (ret != ESP_OK) {
|
||||||
|
@@ -351,6 +351,8 @@ static esp_err_t wifi_prov_mgr_start_service(const char *service_name, const cha
|
|||||||
/* Set version information / capabilities of provisioning service and application */
|
/* Set version information / capabilities of provisioning service and application */
|
||||||
cJSON *version_json = wifi_prov_get_info_json();
|
cJSON *version_json = wifi_prov_get_info_json();
|
||||||
char *version_str = cJSON_Print(version_json);
|
char *version_str = cJSON_Print(version_json);
|
||||||
|
ESP_LOGD(TAG, "version_str :%s:", version_str);
|
||||||
|
|
||||||
ret = protocomm_set_version(prov_ctx->pc, "proto-ver", version_str);
|
ret = protocomm_set_version(prov_ctx->pc, "proto-ver", version_str);
|
||||||
free(version_str);
|
free(version_str);
|
||||||
cJSON_Delete(version_json);
|
cJSON_Delete(version_json);
|
||||||
|
@@ -1,9 +1,8 @@
|
|||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
#
|
#
|
||||||
# SPDX-FileCopyrightText: 2018-2022 Espressif Systems (Shanghai) CO LTD
|
# SPDX-FileCopyrightText: 2018-2025 Espressif Systems (Shanghai) CO LTD
|
||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
#
|
#
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
@@ -113,9 +112,9 @@ def on_except(err):
|
|||||||
print(err)
|
print(err)
|
||||||
|
|
||||||
|
|
||||||
def get_security(secver, username, password, pop='', verbose=False):
|
def get_security(secver, sec_patch_ver, username, password, pop='', verbose=False):
|
||||||
if secver == 2:
|
if secver == 2:
|
||||||
return security.Security2(username, password, verbose)
|
return security.Security2(sec_patch_ver, username, password, verbose)
|
||||||
if secver == 1:
|
if secver == 1:
|
||||||
return security.Security1(pop, verbose)
|
return security.Security1(pop, verbose)
|
||||||
if secver == 0:
|
if secver == 0:
|
||||||
@@ -148,6 +147,32 @@ async def get_transport(sel_transport, service_name, check_hostname):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def get_sec_patch_ver(tp, verbose=False):
|
||||||
|
try:
|
||||||
|
response = await tp.send_data('esp_local_ctrl/version', '---')
|
||||||
|
|
||||||
|
if verbose:
|
||||||
|
print('esp_local_ctrl/version response : ', response)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Interpret this as JSON structure containing
|
||||||
|
# information with security version information
|
||||||
|
info = json.loads(response)
|
||||||
|
try:
|
||||||
|
sec_patch_ver = info['local_ctrl']['sec_patch_ver']
|
||||||
|
except KeyError:
|
||||||
|
sec_patch_ver = 0
|
||||||
|
return sec_patch_ver
|
||||||
|
|
||||||
|
except ValueError:
|
||||||
|
# If decoding as JSON fails, we assume default patch level
|
||||||
|
return 0
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
on_except(e)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def version_match(tp, protover, verbose=False):
|
async def version_match(tp, protover, verbose=False):
|
||||||
try:
|
try:
|
||||||
response = await tp.send_data('esp_local_ctrl/version', protover)
|
response = await tp.send_data('esp_local_ctrl/version', protover)
|
||||||
@@ -164,7 +189,7 @@ async def version_match(tp, protover, verbose=False):
|
|||||||
# information with versions and capabilities of both
|
# information with versions and capabilities of both
|
||||||
# provisioning service and application
|
# provisioning service and application
|
||||||
info = json.loads(response)
|
info = json.loads(response)
|
||||||
if info['prov']['ver'].lower() == protover.lower():
|
if info['local_ctrl']['ver'].lower() == protover.lower():
|
||||||
return True
|
return True
|
||||||
|
|
||||||
except ValueError:
|
except ValueError:
|
||||||
@@ -191,14 +216,19 @@ async def has_capability(tp, capability='none', verbose=False):
|
|||||||
# information with versions and capabilities of both
|
# information with versions and capabilities of both
|
||||||
# provisioning service and application
|
# provisioning service and application
|
||||||
info = json.loads(response)
|
info = json.loads(response)
|
||||||
supported_capabilities = info['prov']['cap']
|
try:
|
||||||
if capability.lower() == 'none':
|
supported_capabilities = info['local_ctrl']['cap']
|
||||||
# No specific capability to check, but capabilities
|
if capability.lower() == 'none':
|
||||||
# feature is present so return True
|
# No specific capability to check, but capabilities
|
||||||
return True
|
# feature is present so return True
|
||||||
elif capability in supported_capabilities:
|
return True
|
||||||
return True
|
elif capability in supported_capabilities:
|
||||||
return False
|
return True
|
||||||
|
return False
|
||||||
|
except KeyError:
|
||||||
|
# If capabilities field is not present, it means
|
||||||
|
# that capabilities are not supported
|
||||||
|
return False
|
||||||
|
|
||||||
except ValueError:
|
except ValueError:
|
||||||
# If decoding as JSON fails, it means that capabilities
|
# If decoding as JSON fails, it means that capabilities
|
||||||
@@ -341,6 +371,7 @@ async def main():
|
|||||||
if obj_transport is None:
|
if obj_transport is None:
|
||||||
raise RuntimeError('Failed to establish connection')
|
raise RuntimeError('Failed to establish connection')
|
||||||
|
|
||||||
|
sec_patch_ver = 0
|
||||||
# If security version not specified check in capabilities
|
# If security version not specified check in capabilities
|
||||||
if args.secver is None:
|
if args.secver is None:
|
||||||
# First check if capabilities are supported or not
|
# First check if capabilities are supported or not
|
||||||
@@ -362,13 +393,14 @@ async def main():
|
|||||||
args.pop = ''
|
args.pop = ''
|
||||||
|
|
||||||
if (args.secver == 2):
|
if (args.secver == 2):
|
||||||
|
sec_patch_ver = await get_sec_patch_ver(obj_transport, args.verbose)
|
||||||
if len(args.sec2_usr) == 0:
|
if len(args.sec2_usr) == 0:
|
||||||
args.sec2_usr = input('Security Scheme 2 - SRP6a Username required: ')
|
args.sec2_usr = input('Security Scheme 2 - SRP6a Username required: ')
|
||||||
if len(args.sec2_pwd) == 0:
|
if len(args.sec2_pwd) == 0:
|
||||||
prompt_str = 'Security Scheme 2 - SRP6a Password required: '
|
prompt_str = 'Security Scheme 2 - SRP6a Password required: '
|
||||||
args.sec2_pwd = getpass(prompt_str)
|
args.sec2_pwd = getpass(prompt_str)
|
||||||
|
|
||||||
obj_security = get_security(args.secver, args.sec2_usr, args.sec2_pwd, args.pop, args.verbose)
|
obj_security = get_security(args.secver, sec_patch_ver, args.sec2_usr, args.sec2_pwd, args.pop, args.verbose)
|
||||||
if obj_security is None:
|
if obj_security is None:
|
||||||
raise ValueError('Invalid Security Version')
|
raise ValueError('Invalid Security Version')
|
||||||
|
|
||||||
|
@@ -457,7 +457,7 @@ async def main():
|
|||||||
args.sec1_pop = ''
|
args.sec1_pop = ''
|
||||||
|
|
||||||
if (args.secver == 2):
|
if (args.secver == 2):
|
||||||
sec_patch_ver = await get_sec_patch_ver(obj_transport)
|
sec_patch_ver = await get_sec_patch_ver(obj_transport, args.verbose)
|
||||||
if len(args.sec2_usr) == 0:
|
if len(args.sec2_usr) == 0:
|
||||||
args.sec2_usr = input('Security Scheme 2 - SRP6a Username required: ')
|
args.sec2_usr = input('Security Scheme 2 - SRP6a Username required: ')
|
||||||
if len(args.sec2_pwd) == 0:
|
if len(args.sec2_pwd) == 0:
|
||||||
|
Reference in New Issue
Block a user