fix(esp_local_ctrl): update for changes in protocomm security2 scheme

This commit is contained in:
Mahavir Jain
2025-03-04 22:51:27 +05:30
parent 1b319631b8
commit 0a329ab9a4
4 changed files with 75 additions and 24 deletions

View File

@@ -1,5 +1,5 @@
/* /*
* SPDX-FileCopyrightText: 2019-2022 Espressif Systems (Shanghai) CO LTD * SPDX-FileCopyrightText: 2019-2025 Espressif Systems (Shanghai) CO LTD
* *
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
*/ */
@@ -20,6 +20,8 @@
#include "esp_local_ctrl.pb-c.h" #include "esp_local_ctrl.pb-c.h"
#define ESP_LOCAL_CTRL_VERSION "v1.0" #define ESP_LOCAL_CTRL_VERSION "v1.0"
/* JSON format string for version endpoint */
#define ESP_LOCAL_CTRL_VER_FMT_STR "{\"local_ctrl\":{\"ver\":\"%s\",\"sec_ver\":%d,\"sec_patch_ver\":%d}}"
struct inst_ctx { struct inst_ctx {
protocomm_t *pc; protocomm_t *pc;
@@ -136,14 +138,6 @@ esp_err_t esp_local_ctrl_start(const esp_local_ctrl_config_t *config)
} }
} }
ret = protocomm_set_version(local_ctrl_inst_ctx->pc, "esp_local_ctrl/version",
ESP_LOCAL_CTRL_VERSION);
if (ret != ESP_OK) {
ESP_LOGE(TAG, "Failed to set version endpoint");
esp_local_ctrl_stop();
return ret;
}
protocomm_security_t *proto_sec_handle = NULL; protocomm_security_t *proto_sec_handle = NULL;
switch (local_ctrl_inst_ctx->config.proto_sec.version) { switch (local_ctrl_inst_ctx->config.proto_sec.version) {
case PROTOCOM_SEC_CUSTOM: case PROTOCOM_SEC_CUSTOM:
@@ -183,6 +177,29 @@ esp_err_t esp_local_ctrl_start(const esp_local_ctrl_config_t *config)
return ret; return ret;
} }
int sec_ver = 0;
uint8_t sec_patch_ver = 0;
protocomm_get_sec_version(local_ctrl_inst_ctx->pc, &sec_ver, &sec_patch_ver);
const int rsize = snprintf(NULL, 0, ESP_LOCAL_CTRL_VER_FMT_STR, ESP_LOCAL_CTRL_VERSION, sec_ver, sec_patch_ver) + 1;
char *ver_str = malloc(rsize);
if (!ver_str) {
ESP_LOGE(TAG, "Failed to allocate memory for version string");
esp_local_ctrl_stop();
return ESP_ERR_NO_MEM;
}
snprintf(ver_str, rsize, ESP_LOCAL_CTRL_VER_FMT_STR, ESP_LOCAL_CTRL_VERSION, sec_ver, sec_patch_ver);
ESP_LOGD(TAG, "ver_str: %s", ver_str);
ret = protocomm_set_version(local_ctrl_inst_ctx->pc, "esp_local_ctrl/version",
ver_str);
free(ver_str);
if (ret != ESP_OK) {
ESP_LOGE(TAG, "Failed to set version endpoint");
esp_local_ctrl_stop();
return ret;
}
ret = protocomm_add_endpoint(local_ctrl_inst_ctx->pc, "esp_local_ctrl/control", ret = protocomm_add_endpoint(local_ctrl_inst_ctx->pc, "esp_local_ctrl/control",
esp_local_ctrl_data_handler, NULL); esp_local_ctrl_data_handler, NULL);
if (ret != ESP_OK) { if (ret != ESP_OK) {

View File

@@ -351,6 +351,8 @@ static esp_err_t wifi_prov_mgr_start_service(const char *service_name, const cha
/* Set version information / capabilities of provisioning service and application */ /* Set version information / capabilities of provisioning service and application */
cJSON *version_json = wifi_prov_get_info_json(); cJSON *version_json = wifi_prov_get_info_json();
char *version_str = cJSON_Print(version_json); char *version_str = cJSON_Print(version_json);
ESP_LOGD(TAG, "version_str :%s:", version_str);
ret = protocomm_set_version(prov_ctx->pc, "proto-ver", version_str); ret = protocomm_set_version(prov_ctx->pc, "proto-ver", version_str);
free(version_str); free(version_str);
cJSON_Delete(version_json); cJSON_Delete(version_json);

View File

@@ -1,9 +1,8 @@
#!/usr/bin/env python #!/usr/bin/env python
# #
# SPDX-FileCopyrightText: 2018-2022 Espressif Systems (Shanghai) CO LTD # SPDX-FileCopyrightText: 2018-2025 Espressif Systems (Shanghai) CO LTD
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# #
import argparse import argparse
import asyncio import asyncio
import json import json
@@ -113,9 +112,9 @@ def on_except(err):
print(err) print(err)
def get_security(secver, username, password, pop='', verbose=False): def get_security(secver, sec_patch_ver, username, password, pop='', verbose=False):
if secver == 2: if secver == 2:
return security.Security2(username, password, verbose) return security.Security2(sec_patch_ver, username, password, verbose)
if secver == 1: if secver == 1:
return security.Security1(pop, verbose) return security.Security1(pop, verbose)
if secver == 0: if secver == 0:
@@ -148,6 +147,32 @@ async def get_transport(sel_transport, service_name, check_hostname):
return None return None
async def get_sec_patch_ver(tp, verbose=False):
try:
response = await tp.send_data('esp_local_ctrl/version', '---')
if verbose:
print('esp_local_ctrl/version response : ', response)
try:
# Interpret this as JSON structure containing
# information with security version information
info = json.loads(response)
try:
sec_patch_ver = info['local_ctrl']['sec_patch_ver']
except KeyError:
sec_patch_ver = 0
return sec_patch_ver
except ValueError:
# If decoding as JSON fails, we assume default patch level
return 0
except Exception as e:
on_except(e)
return None
async def version_match(tp, protover, verbose=False): async def version_match(tp, protover, verbose=False):
try: try:
response = await tp.send_data('esp_local_ctrl/version', protover) response = await tp.send_data('esp_local_ctrl/version', protover)
@@ -164,7 +189,7 @@ async def version_match(tp, protover, verbose=False):
# information with versions and capabilities of both # information with versions and capabilities of both
# provisioning service and application # provisioning service and application
info = json.loads(response) info = json.loads(response)
if info['prov']['ver'].lower() == protover.lower(): if info['local_ctrl']['ver'].lower() == protover.lower():
return True return True
except ValueError: except ValueError:
@@ -191,14 +216,19 @@ async def has_capability(tp, capability='none', verbose=False):
# information with versions and capabilities of both # information with versions and capabilities of both
# provisioning service and application # provisioning service and application
info = json.loads(response) info = json.loads(response)
supported_capabilities = info['prov']['cap'] try:
if capability.lower() == 'none': supported_capabilities = info['local_ctrl']['cap']
# No specific capability to check, but capabilities if capability.lower() == 'none':
# feature is present so return True # No specific capability to check, but capabilities
return True # feature is present so return True
elif capability in supported_capabilities: return True
return True elif capability in supported_capabilities:
return False return True
return False
except KeyError:
# If capabilities field is not present, it means
# that capabilities are not supported
return False
except ValueError: except ValueError:
# If decoding as JSON fails, it means that capabilities # If decoding as JSON fails, it means that capabilities
@@ -341,6 +371,7 @@ async def main():
if obj_transport is None: if obj_transport is None:
raise RuntimeError('Failed to establish connection') raise RuntimeError('Failed to establish connection')
sec_patch_ver = 0
# If security version not specified check in capabilities # If security version not specified check in capabilities
if args.secver is None: if args.secver is None:
# First check if capabilities are supported or not # First check if capabilities are supported or not
@@ -362,13 +393,14 @@ async def main():
args.pop = '' args.pop = ''
if (args.secver == 2): if (args.secver == 2):
sec_patch_ver = await get_sec_patch_ver(obj_transport, args.verbose)
if len(args.sec2_usr) == 0: if len(args.sec2_usr) == 0:
args.sec2_usr = input('Security Scheme 2 - SRP6a Username required: ') args.sec2_usr = input('Security Scheme 2 - SRP6a Username required: ')
if len(args.sec2_pwd) == 0: if len(args.sec2_pwd) == 0:
prompt_str = 'Security Scheme 2 - SRP6a Password required: ' prompt_str = 'Security Scheme 2 - SRP6a Password required: '
args.sec2_pwd = getpass(prompt_str) args.sec2_pwd = getpass(prompt_str)
obj_security = get_security(args.secver, args.sec2_usr, args.sec2_pwd, args.pop, args.verbose) obj_security = get_security(args.secver, sec_patch_ver, args.sec2_usr, args.sec2_pwd, args.pop, args.verbose)
if obj_security is None: if obj_security is None:
raise ValueError('Invalid Security Version') raise ValueError('Invalid Security Version')

View File

@@ -457,7 +457,7 @@ async def main():
args.sec1_pop = '' args.sec1_pop = ''
if (args.secver == 2): if (args.secver == 2):
sec_patch_ver = await get_sec_patch_ver(obj_transport) sec_patch_ver = await get_sec_patch_ver(obj_transport, args.verbose)
if len(args.sec2_usr) == 0: if len(args.sec2_usr) == 0:
args.sec2_usr = input('Security Scheme 2 - SRP6a Username required: ') args.sec2_usr = input('Security Scheme 2 - SRP6a Username required: ')
if len(args.sec2_pwd) == 0: if len(args.sec2_pwd) == 0: