From 0a329ab9a44a29bcd49d4aecb8c0c45bfe752f01 Mon Sep 17 00:00:00 2001 From: Mahavir Jain Date: Tue, 4 Mar 2025 22:51:27 +0530 Subject: [PATCH] fix(esp_local_ctrl): update for changes in protocomm security2 scheme --- .../esp_local_ctrl/src/esp_local_ctrl.c | 35 ++++++++--- components/wifi_provisioning/src/manager.c | 2 + .../esp_local_ctrl/scripts/esp_local_ctrl.py | 60 ++++++++++++++----- tools/esp_prov/esp_prov.py | 2 +- 4 files changed, 75 insertions(+), 24 deletions(-) diff --git a/components/esp_local_ctrl/src/esp_local_ctrl.c b/components/esp_local_ctrl/src/esp_local_ctrl.c index 51de3df35b..77f19ec118 100644 --- a/components/esp_local_ctrl/src/esp_local_ctrl.c +++ b/components/esp_local_ctrl/src/esp_local_ctrl.c @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: 2019-2022 Espressif Systems (Shanghai) CO LTD + * SPDX-FileCopyrightText: 2019-2025 Espressif Systems (Shanghai) CO LTD * * SPDX-License-Identifier: Apache-2.0 */ @@ -20,6 +20,8 @@ #include "esp_local_ctrl.pb-c.h" #define ESP_LOCAL_CTRL_VERSION "v1.0" +/* JSON format string for version endpoint */ +#define ESP_LOCAL_CTRL_VER_FMT_STR "{\"local_ctrl\":{\"ver\":\"%s\",\"sec_ver\":%d,\"sec_patch_ver\":%d}}" struct inst_ctx { protocomm_t *pc; @@ -136,14 +138,6 @@ esp_err_t esp_local_ctrl_start(const esp_local_ctrl_config_t *config) } } - ret = protocomm_set_version(local_ctrl_inst_ctx->pc, "esp_local_ctrl/version", - ESP_LOCAL_CTRL_VERSION); - if (ret != ESP_OK) { - ESP_LOGE(TAG, "Failed to set version endpoint"); - esp_local_ctrl_stop(); - return ret; - } - protocomm_security_t *proto_sec_handle = NULL; switch (local_ctrl_inst_ctx->config.proto_sec.version) { case PROTOCOM_SEC_CUSTOM: @@ -183,6 +177,29 @@ esp_err_t esp_local_ctrl_start(const esp_local_ctrl_config_t *config) return ret; } + int sec_ver = 0; + uint8_t sec_patch_ver = 0; + protocomm_get_sec_version(local_ctrl_inst_ctx->pc, &sec_ver, &sec_patch_ver); + + const int rsize = snprintf(NULL, 0, ESP_LOCAL_CTRL_VER_FMT_STR, ESP_LOCAL_CTRL_VERSION, sec_ver, sec_patch_ver) + 1; + char *ver_str = malloc(rsize); + if (!ver_str) { + ESP_LOGE(TAG, "Failed to allocate memory for version string"); + esp_local_ctrl_stop(); + return ESP_ERR_NO_MEM; + } + snprintf(ver_str, rsize, ESP_LOCAL_CTRL_VER_FMT_STR, ESP_LOCAL_CTRL_VERSION, sec_ver, sec_patch_ver); + + ESP_LOGD(TAG, "ver_str: %s", ver_str); + ret = protocomm_set_version(local_ctrl_inst_ctx->pc, "esp_local_ctrl/version", + ver_str); + free(ver_str); + if (ret != ESP_OK) { + ESP_LOGE(TAG, "Failed to set version endpoint"); + esp_local_ctrl_stop(); + return ret; + } + ret = protocomm_add_endpoint(local_ctrl_inst_ctx->pc, "esp_local_ctrl/control", esp_local_ctrl_data_handler, NULL); if (ret != ESP_OK) { diff --git a/components/wifi_provisioning/src/manager.c b/components/wifi_provisioning/src/manager.c index 559832ee9e..57be1dd63e 100644 --- a/components/wifi_provisioning/src/manager.c +++ b/components/wifi_provisioning/src/manager.c @@ -351,6 +351,8 @@ static esp_err_t wifi_prov_mgr_start_service(const char *service_name, const cha /* Set version information / capabilities of provisioning service and application */ cJSON *version_json = wifi_prov_get_info_json(); char *version_str = cJSON_Print(version_json); + ESP_LOGD(TAG, "version_str :%s:", version_str); + ret = protocomm_set_version(prov_ctx->pc, "proto-ver", version_str); free(version_str); cJSON_Delete(version_json); diff --git a/examples/protocols/esp_local_ctrl/scripts/esp_local_ctrl.py b/examples/protocols/esp_local_ctrl/scripts/esp_local_ctrl.py index 597a977634..5fc8086223 100644 --- a/examples/protocols/esp_local_ctrl/scripts/esp_local_ctrl.py +++ b/examples/protocols/esp_local_ctrl/scripts/esp_local_ctrl.py @@ -1,9 +1,8 @@ #!/usr/bin/env python # -# SPDX-FileCopyrightText: 2018-2022 Espressif Systems (Shanghai) CO LTD +# SPDX-FileCopyrightText: 2018-2025 Espressif Systems (Shanghai) CO LTD # SPDX-License-Identifier: Apache-2.0 # - import argparse import asyncio import json @@ -113,9 +112,9 @@ def on_except(err): print(err) -def get_security(secver, username, password, pop='', verbose=False): +def get_security(secver, sec_patch_ver, username, password, pop='', verbose=False): if secver == 2: - return security.Security2(username, password, verbose) + return security.Security2(sec_patch_ver, username, password, verbose) if secver == 1: return security.Security1(pop, verbose) if secver == 0: @@ -148,6 +147,32 @@ async def get_transport(sel_transport, service_name, check_hostname): return None +async def get_sec_patch_ver(tp, verbose=False): + try: + response = await tp.send_data('esp_local_ctrl/version', '---') + + if verbose: + print('esp_local_ctrl/version response : ', response) + + try: + # Interpret this as JSON structure containing + # information with security version information + info = json.loads(response) + try: + sec_patch_ver = info['local_ctrl']['sec_patch_ver'] + except KeyError: + sec_patch_ver = 0 + return sec_patch_ver + + except ValueError: + # If decoding as JSON fails, we assume default patch level + return 0 + + except Exception as e: + on_except(e) + return None + + async def version_match(tp, protover, verbose=False): try: response = await tp.send_data('esp_local_ctrl/version', protover) @@ -164,7 +189,7 @@ async def version_match(tp, protover, verbose=False): # information with versions and capabilities of both # provisioning service and application info = json.loads(response) - if info['prov']['ver'].lower() == protover.lower(): + if info['local_ctrl']['ver'].lower() == protover.lower(): return True except ValueError: @@ -191,14 +216,19 @@ async def has_capability(tp, capability='none', verbose=False): # information with versions and capabilities of both # provisioning service and application info = json.loads(response) - supported_capabilities = info['prov']['cap'] - if capability.lower() == 'none': - # No specific capability to check, but capabilities - # feature is present so return True - return True - elif capability in supported_capabilities: - return True - return False + try: + supported_capabilities = info['local_ctrl']['cap'] + if capability.lower() == 'none': + # No specific capability to check, but capabilities + # feature is present so return True + return True + elif capability in supported_capabilities: + return True + return False + except KeyError: + # If capabilities field is not present, it means + # that capabilities are not supported + return False except ValueError: # If decoding as JSON fails, it means that capabilities @@ -341,6 +371,7 @@ async def main(): if obj_transport is None: raise RuntimeError('Failed to establish connection') + sec_patch_ver = 0 # If security version not specified check in capabilities if args.secver is None: # First check if capabilities are supported or not @@ -362,13 +393,14 @@ async def main(): args.pop = '' if (args.secver == 2): + sec_patch_ver = await get_sec_patch_ver(obj_transport, args.verbose) if len(args.sec2_usr) == 0: args.sec2_usr = input('Security Scheme 2 - SRP6a Username required: ') if len(args.sec2_pwd) == 0: prompt_str = 'Security Scheme 2 - SRP6a Password required: ' args.sec2_pwd = getpass(prompt_str) - obj_security = get_security(args.secver, args.sec2_usr, args.sec2_pwd, args.pop, args.verbose) + obj_security = get_security(args.secver, sec_patch_ver, args.sec2_usr, args.sec2_pwd, args.pop, args.verbose) if obj_security is None: raise ValueError('Invalid Security Version') diff --git a/tools/esp_prov/esp_prov.py b/tools/esp_prov/esp_prov.py index 0d869842d5..35fc7ce6f9 100644 --- a/tools/esp_prov/esp_prov.py +++ b/tools/esp_prov/esp_prov.py @@ -457,7 +457,7 @@ async def main(): args.sec1_pop = '' if (args.secver == 2): - sec_patch_ver = await get_sec_patch_ver(obj_transport) + sec_patch_ver = await get_sec_patch_ver(obj_transport, args.verbose) if len(args.sec2_usr) == 0: args.sec2_usr = input('Security Scheme 2 - SRP6a Username required: ') if len(args.sec2_pwd) == 0: