esp_prov: Compatibility changes and refactoring

- Removed python 2 compatibility
- Removed dependencies on redundant external modules
- Interactive provisioning input for security scheme 2
- Style changes:
  Updated print statements to format strings
  Colored verbose logging
  Raised exceptions on errors instead of clean exits
This commit is contained in:
Laukik Hase
2022-06-22 15:14:19 +05:30
parent 2c4e5c2963
commit 9aefcb12f5
20 changed files with 212 additions and 341 deletions

View File

@@ -2192,14 +2192,6 @@ tools/ci/python_packages/ttfw_idf/unity_test_parser.py
tools/ci/python_packages/wifi_tools.py tools/ci/python_packages/wifi_tools.py
tools/ci/test_autocomplete.py tools/ci/test_autocomplete.py
tools/esp_app_trace/test/sysview/blink.c tools/esp_app_trace/test/sysview/blink.c
tools/esp_prov/__init__.py
tools/esp_prov/prov/__init__.py
tools/esp_prov/prov/wifi_prov.py
tools/esp_prov/security/security.py
tools/esp_prov/security/security0.py
tools/esp_prov/security/security1.py
tools/esp_prov/transport/__init__.py
tools/esp_prov/utils/__init__.py
tools/find_apps.py tools/find_apps.py
tools/find_build_apps/__init__.py tools/find_build_apps/__init__.py
tools/find_build_apps/cmake.py tools/find_build_apps/cmake.py

View File

@@ -206,7 +206,6 @@ tools/esp_prov/transport/transport.py
tools/esp_prov/transport/transport_ble.py tools/esp_prov/transport/transport_ble.py
tools/esp_prov/transport/transport_console.py tools/esp_prov/transport/transport_console.py
tools/esp_prov/transport/transport_http.py tools/esp_prov/transport/transport_http.py
tools/esp_prov/utils/convenience.py
tools/find_apps.py tools/find_apps.py
tools/find_build_apps/common.py tools/find_build_apps/common.py
tools/gen_esp_err_to_name.py tools/gen_esp_err_to_name.py

View File

@@ -1 +1,5 @@
# SPDX-FileCopyrightText: 2018-2022 Espressif Systems (Shanghai) CO LTD
# SPDX-License-Identifier: Apache-2.0
#
from .esp_prov import * # noqa: export esp_prov module to users from .esp_prov import * # noqa: export esp_prov module to users

View File

@@ -4,8 +4,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# #
from __future__ import print_function
import argparse import argparse
import asyncio import asyncio
import json import json
@@ -13,7 +11,6 @@ import os
import sys import sys
import textwrap import textwrap
import time import time
from builtins import input as binput
from getpass import getpass from getpass import getpass
try: try:
@@ -289,7 +286,7 @@ async def wait_wifi_connected(tp, sec):
retry -= 1 retry -= 1
print('Waiting to poll status again (status %s, %d tries left)...' % (ret, retry)) print('Waiting to poll status again (status %s, %d tries left)...' % (ret, retry))
else: else:
print('---- Provisioning failed ----') print('---- Provisioning failed! ----')
return False return False
@@ -381,68 +378,69 @@ async def main():
if args.secver == 2 and args.sec2_gen_cred: if args.secver == 2 and args.sec2_gen_cred:
if not args.sec2_usr or not args.sec2_pwd: if not args.sec2_usr or not args.sec2_pwd:
print('---- Username/password cannot be empty for security scheme 2 (SRP6a) ----') raise ValueError('Username/password cannot be empty for security scheme 2 (SRP6a)')
exit(1)
print('==== Salt-verifier for security scheme 2 (SRP6a) ====') print('==== Salt-verifier for security scheme 2 (SRP6a) ====')
security.sec2_gen_salt_verifier(args.sec2_usr, args.sec2_pwd, args.sec2_salt_len) security.sec2_gen_salt_verifier(args.sec2_usr, args.sec2_pwd, args.sec2_salt_len)
exit(0) sys.exit()
obj_transport = await get_transport(args.mode.lower(), args.name) obj_transport = await get_transport(args.mode.lower(), args.name)
if obj_transport is None: if obj_transport is None:
print('---- Failed to establish connection ----') raise RuntimeError('Failed to establish connection')
exit(1)
try: try:
# If security version not specified check in capabilities # If security version not specified check in capabilities
if args.secver is None: if args.secver is None:
# First check if capabilities are supported or not # First check if capabilities are supported or not
if not await has_capability(obj_transport): if not await has_capability(obj_transport):
print('Security capabilities could not be determined. Please specify "--sec_ver" explicitly') print('Security capabilities could not be determined, please specify "--sec_ver" explicitly')
print('---- Invalid Security Version ----') raise ValueError('Invalid Security Version')
exit(2)
# When no_sec is present, use security 0, else security 1 # When no_sec is present, use security 0, else security 1
args.secver = int(not await has_capability(obj_transport, 'no_sec')) args.secver = int(not await has_capability(obj_transport, 'no_sec'))
print('Security scheme determined to be :', args.secver) print(f'==== Security Scheme: {args.secver} ====')
if (args.secver != 0) and not await has_capability(obj_transport, 'no_pop'): if (args.secver == 1):
if not await has_capability(obj_transport, 'no_pop'):
if len(args.sec1_pop) == 0: if len(args.sec1_pop) == 0:
args.sec1_pop = binput('Proof of Possession required : ') prompt_str = 'Proof of Possession required: '
args.sec1_pop = getpass(prompt_str)
elif len(args.sec1_pop) != 0: elif len(args.sec1_pop) != 0:
print('---- Proof of Possession will be ignored ----') print('Proof of Possession will be ignored')
args.sec1_pop = '' args.sec1_pop = ''
if (args.secver == 2):
if len(args.sec2_usr) == 0:
args.sec2_usr = input('Security Scheme 2 - SRP6a Username required: ')
if len(args.sec2_pwd) == 0:
prompt_str = 'Security Scheme 2 - SRP6a Password required: '
args.sec2_pwd = getpass(prompt_str)
obj_security = get_security(args.secver, args.sec2_usr, args.sec2_pwd, args.sec1_pop, args.verbose) obj_security = get_security(args.secver, args.sec2_usr, args.sec2_pwd, args.sec1_pop, args.verbose)
if obj_security is None: if obj_security is None:
print('---- Invalid Security Version ----') raise ValueError('Invalid Security Version')
exit(2)
if args.version != '': if args.version != '':
print('\n==== Verifying protocol version ====') print('\n==== Verifying protocol version ====')
if not await version_match(obj_transport, args.version, args.verbose): if not await version_match(obj_transport, args.version, args.verbose):
print('---- Error in protocol version matching ----') raise RuntimeError('Error in protocol version matching')
exit(3)
print('==== Verified protocol version successfully ====') print('==== Verified protocol version successfully ====')
print('\n==== Starting Session ====') print('\n==== Starting Session ====')
if not await establish_session(obj_transport, obj_security): if not await establish_session(obj_transport, obj_security):
print('Failed to establish session. Ensure that security scheme and proof of possession are correct') print('Failed to establish session. Ensure that security scheme and proof of possession are correct')
print('---- Error in establishing session ----') raise RuntimeError('Error in establishing session')
exit(4)
print('==== Session Established ====') print('==== Session Established ====')
if args.custom_data != '': if args.custom_data != '':
print('\n==== Sending Custom data to esp32 ====') print('\n==== Sending Custom data to Target ====')
if not await custom_data(obj_transport, obj_security, args.custom_data): if not await custom_data(obj_transport, obj_security, args.custom_data):
print('---- Error in custom data ----') raise RuntimeError('Error in custom data')
exit(5)
print('==== Custom data sent successfully ====') print('==== Custom data sent successfully ====')
if args.ssid == '': if args.ssid == '':
if not await has_capability(obj_transport, 'wifi_scan'): if not await has_capability(obj_transport, 'wifi_scan'):
print('---- Wi-Fi Scan List is not supported by provisioning service ----') raise RuntimeError('Wi-Fi Scan List is not supported by provisioning service')
print('---- Rerun esp_prov with SSID and Passphrase as argument ----')
exit(3)
while True: while True:
print('\n==== Scanning Wi-Fi APs ====') print('\n==== Scanning Wi-Fi APs ====')
@@ -451,12 +449,11 @@ async def main():
end_time = time.time() end_time = time.time()
print('\n++++ Scan finished in ' + str(end_time - start_time) + ' sec') print('\n++++ Scan finished in ' + str(end_time - start_time) + ' sec')
if APs is None: if APs is None:
print('---- Error in scanning Wi-Fi APs ----') raise RuntimeError('Error in scanning Wi-Fi APs')
exit(8)
if len(APs) == 0: if len(APs) == 0:
print('No APs found!') print('No APs found!')
exit(9) sys.exit()
print('==== Wi-Fi Scan results ====') print('==== Wi-Fi Scan results ====')
print('{0: >4} {1: <33} {2: <12} {3: >4} {4: <4} {5: <16}'.format( print('{0: >4} {1: <33} {2: <12} {3: >4} {4: <4} {5: <16}'.format(
@@ -467,7 +464,7 @@ async def main():
while True: while True:
try: try:
select = int(binput('Select AP by number (0 to rescan) : ')) select = int(input('Select AP by number (0 to rescan) : '))
if select < 0 or select > len(APs): if select < 0 or select > len(APs):
raise ValueError raise ValueError
break break
@@ -483,16 +480,14 @@ async def main():
prompt_str = 'Enter passphrase for {0} : '.format(args.ssid) prompt_str = 'Enter passphrase for {0} : '.format(args.ssid)
args.passphrase = getpass(prompt_str) args.passphrase = getpass(prompt_str)
print('\n==== Sending Wi-Fi credential to esp32 ====') print('\n==== Sending Wi-Fi Credentials to Target ====')
if not await send_wifi_config(obj_transport, obj_security, args.ssid, args.passphrase): if not await send_wifi_config(obj_transport, obj_security, args.ssid, args.passphrase):
print('---- Error in send Wi-Fi config ----') raise RuntimeError('Error in send Wi-Fi config')
exit(6)
print('==== Wi-Fi Credentials sent successfully ====') print('==== Wi-Fi Credentials sent successfully ====')
print('\n==== Applying config to esp32 ====') print('\n==== Applying Wi-Fi Config to Target ====')
if not await apply_wifi_config(obj_transport, obj_security): if not await apply_wifi_config(obj_transport, obj_security):
print('---- Error in apply Wi-Fi config ----') raise RuntimeError('Error in apply Wi-Fi config')
exit(7)
print('==== Apply config sent successfully ====') print('==== Apply config sent successfully ====')
await wait_wifi_connected(obj_transport, obj_security) await wait_wifi_connected(obj_transport, obj_security)

View File

@@ -1,16 +1,5 @@
# Copyright 2018 Espressif Systems (Shanghai) PTE LTD # SPDX-FileCopyrightText: 2018-2022 Espressif Systems (Shanghai) CO LTD
# # SPDX-License-Identifier: Apache-2.0
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# #
from .custom_prov import * # noqa F403 from .custom_prov import * # noqa F403

View File

@@ -4,26 +4,23 @@
# APIs for interpreting and creating protobuf packets for `custom-config` protocomm endpoint # APIs for interpreting and creating protobuf packets for `custom-config` protocomm endpoint
from __future__ import print_function from utils import str_to_bytes
import utils
from future.utils import tobytes
def print_verbose(security_ctx, data): def print_verbose(security_ctx, data):
if (security_ctx.verbose): if (security_ctx.verbose):
print('++++ ' + data + ' ++++') print(f'\x1b[32;20m++++ {data} ++++\x1b[0m')
def custom_data_request(security_ctx, data): def custom_data_request(security_ctx, data):
# Encrypt the custom data # Encrypt the custom data
enc_cmd = security_ctx.encrypt_data(tobytes(data)) enc_cmd = security_ctx.encrypt_data(str_to_bytes(data))
print_verbose(security_ctx, 'Client -> Device (CustomData cmd) ' + utils.str_to_hexstr(enc_cmd)) print_verbose(security_ctx, f'Client -> Device (CustomData cmd): 0x{enc_cmd.hex()}')
return enc_cmd.decode('latin-1') return enc_cmd.decode('latin-1')
def custom_data_response(security_ctx, response_data): def custom_data_response(security_ctx, response_data):
# Decrypt response packet # Decrypt response packet
decrypt = security_ctx.decrypt_data(tobytes(response_data)) decrypt = security_ctx.decrypt_data(str_to_bytes(response_data))
print('CustomData response: ' + str(decrypt)) print(f'++++ CustomData response: {str(decrypt)}++++')
return 0 return 0

View File

@@ -1,30 +1,16 @@
# Copyright 2018 Espressif Systems (Shanghai) PTE LTD # SPDX-FileCopyrightText: 2018-2022 Espressif Systems (Shanghai) CO LTD
# # SPDX-License-Identifier: Apache-2.0
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# #
# APIs for interpreting and creating protobuf packets for Wi-Fi provisioning # APIs for interpreting and creating protobuf packets for Wi-Fi provisioning
from __future__ import print_function
import proto import proto
import utils from utils import str_to_bytes
from future.utils import tobytes
def print_verbose(security_ctx, data): def print_verbose(security_ctx, data):
if (security_ctx.verbose): if (security_ctx.verbose):
print('++++ ' + data + ' ++++') print(f'\x1b[32;20m++++ {data} ++++\x1b[0m')
def config_get_status_request(security_ctx): def config_get_status_request(security_ctx):
@@ -33,34 +19,34 @@ def config_get_status_request(security_ctx):
cfg1.msg = proto.wifi_config_pb2.TypeCmdGetStatus cfg1.msg = proto.wifi_config_pb2.TypeCmdGetStatus
cmd_get_status = proto.wifi_config_pb2.CmdGetStatus() cmd_get_status = proto.wifi_config_pb2.CmdGetStatus()
cfg1.cmd_get_status.MergeFrom(cmd_get_status) cfg1.cmd_get_status.MergeFrom(cmd_get_status)
encrypted_cfg = security_ctx.encrypt_data(cfg1.SerializeToString()).decode('latin-1') encrypted_cfg = security_ctx.encrypt_data(cfg1.SerializeToString())
print_verbose(security_ctx, 'Client -> Device (Encrypted CmdGetStatus) ' + utils.str_to_hexstr(encrypted_cfg)) print_verbose(security_ctx, f'Client -> Device (Encrypted CmdGetStatus): 0x{encrypted_cfg.hex()}')
return encrypted_cfg return encrypted_cfg.decode('latin-1')
def config_get_status_response(security_ctx, response_data): def config_get_status_response(security_ctx, response_data):
# Interpret protobuf response packet from GetStatus command # Interpret protobuf response packet from GetStatus command
decrypted_message = security_ctx.decrypt_data(tobytes(response_data)) decrypted_message = security_ctx.decrypt_data(str_to_bytes(response_data))
cmd_resp1 = proto.wifi_config_pb2.WiFiConfigPayload() cmd_resp1 = proto.wifi_config_pb2.WiFiConfigPayload()
cmd_resp1.ParseFromString(decrypted_message) cmd_resp1.ParseFromString(decrypted_message)
print_verbose(security_ctx, 'Response type ' + str(cmd_resp1.msg)) print_verbose(security_ctx, f'CmdGetStatus type: {str(cmd_resp1.msg)}')
print_verbose(security_ctx, 'Response status ' + str(cmd_resp1.resp_get_status.status)) print_verbose(security_ctx, f'CmdGetStatus status: {str(cmd_resp1.resp_get_status.status)}')
if cmd_resp1.resp_get_status.sta_state == 0: if cmd_resp1.resp_get_status.sta_state == 0:
print('++++ WiFi state: ' + 'connected ++++') print('==== WiFi state: Connected ====')
return 'connected' return 'connected'
elif cmd_resp1.resp_get_status.sta_state == 1: elif cmd_resp1.resp_get_status.sta_state == 1:
print('++++ WiFi state: ' + 'connecting... ++++') print('++++ WiFi state: Connecting... ++++')
return 'connecting' return 'connecting'
elif cmd_resp1.resp_get_status.sta_state == 2: elif cmd_resp1.resp_get_status.sta_state == 2:
print('++++ WiFi state: ' + 'disconnected ++++') print('---- WiFi state: Disconnected ----')
return 'disconnected' return 'disconnected'
elif cmd_resp1.resp_get_status.sta_state == 3: elif cmd_resp1.resp_get_status.sta_state == 3:
print('++++ WiFi state: ' + 'connection failed ++++') print('---- WiFi state: Connection Failed ----')
if cmd_resp1.resp_get_status.fail_reason == 0: if cmd_resp1.resp_get_status.fail_reason == 0:
print('++++ Failure reason: ' + 'Incorrect Password ++++') print('---- Failure reason: Incorrect Password ----')
elif cmd_resp1.resp_get_status.fail_reason == 1: elif cmd_resp1.resp_get_status.fail_reason == 1:
print('++++ Failure reason: ' + 'Incorrect SSID ++++') print('---- Failure reason: Incorrect SSID ----')
return 'failed' return 'failed'
return 'unknown' return 'unknown'
@@ -69,19 +55,19 @@ def config_set_config_request(security_ctx, ssid, passphrase):
# Form protobuf request packet for SetConfig command # Form protobuf request packet for SetConfig command
cmd = proto.wifi_config_pb2.WiFiConfigPayload() cmd = proto.wifi_config_pb2.WiFiConfigPayload()
cmd.msg = proto.wifi_config_pb2.TypeCmdSetConfig cmd.msg = proto.wifi_config_pb2.TypeCmdSetConfig
cmd.cmd_set_config.ssid = tobytes(ssid) cmd.cmd_set_config.ssid = str_to_bytes(ssid)
cmd.cmd_set_config.passphrase = tobytes(passphrase) cmd.cmd_set_config.passphrase = str_to_bytes(passphrase)
enc_cmd = security_ctx.encrypt_data(cmd.SerializeToString()).decode('latin-1') enc_cmd = security_ctx.encrypt_data(cmd.SerializeToString())
print_verbose(security_ctx, 'Client -> Device (SetConfig cmd) ' + utils.str_to_hexstr(enc_cmd)) print_verbose(security_ctx, f'Client -> Device (SetConfig cmd): 0x{enc_cmd.hex()}')
return enc_cmd return enc_cmd.decode('latin-1')
def config_set_config_response(security_ctx, response_data): def config_set_config_response(security_ctx, response_data):
# Interpret protobuf response packet from SetConfig command # Interpret protobuf response packet from SetConfig command
decrypt = security_ctx.decrypt_data(tobytes(response_data)) decrypt = security_ctx.decrypt_data(str_to_bytes(response_data))
cmd_resp4 = proto.wifi_config_pb2.WiFiConfigPayload() cmd_resp4 = proto.wifi_config_pb2.WiFiConfigPayload()
cmd_resp4.ParseFromString(decrypt) cmd_resp4.ParseFromString(decrypt)
print_verbose(security_ctx, 'SetConfig status ' + str(cmd_resp4.resp_set_config.status)) print_verbose(security_ctx, f'SetConfig status: 0x{str(cmd_resp4.resp_set_config.status)}')
return cmd_resp4.resp_set_config.status return cmd_resp4.resp_set_config.status
@@ -89,15 +75,15 @@ def config_apply_config_request(security_ctx):
# Form protobuf request packet for ApplyConfig command # Form protobuf request packet for ApplyConfig command
cmd = proto.wifi_config_pb2.WiFiConfigPayload() cmd = proto.wifi_config_pb2.WiFiConfigPayload()
cmd.msg = proto.wifi_config_pb2.TypeCmdApplyConfig cmd.msg = proto.wifi_config_pb2.TypeCmdApplyConfig
enc_cmd = security_ctx.encrypt_data(cmd.SerializeToString()).decode('latin-1') enc_cmd = security_ctx.encrypt_data(cmd.SerializeToString())
print_verbose(security_ctx, 'Client -> Device (ApplyConfig cmd) ' + utils.str_to_hexstr(enc_cmd)) print_verbose(security_ctx, f'Client -> Device (ApplyConfig cmd): 0x{enc_cmd.hex()}')
return enc_cmd return enc_cmd.decode('latin-1')
def config_apply_config_response(security_ctx, response_data): def config_apply_config_response(security_ctx, response_data):
# Interpret protobuf response packet from ApplyConfig command # Interpret protobuf response packet from ApplyConfig command
decrypt = security_ctx.decrypt_data(tobytes(response_data)) decrypt = security_ctx.decrypt_data(str_to_bytes(response_data))
cmd_resp5 = proto.wifi_config_pb2.WiFiConfigPayload() cmd_resp5 = proto.wifi_config_pb2.WiFiConfigPayload()
cmd_resp5.ParseFromString(decrypt) cmd_resp5.ParseFromString(decrypt)
print_verbose(security_ctx, 'ApplyConfig status ' + str(cmd_resp5.resp_apply_config.status)) print_verbose(security_ctx, f'ApplyConfig status: 0x{str(cmd_resp5.resp_apply_config.status)}')
return cmd_resp5.resp_apply_config.status return cmd_resp5.resp_apply_config.status

View File

@@ -3,17 +3,13 @@
# #
# APIs for interpreting and creating protobuf packets for Wi-Fi Scanning # APIs for interpreting and creating protobuf packets for Wi-Fi Scanning
from __future__ import print_function
import proto import proto
import utils from utils import str_to_bytes
from future.utils import tobytes
def print_verbose(security_ctx, data): def print_verbose(security_ctx, data):
if (security_ctx.verbose): if (security_ctx.verbose):
print('++++ ' + data + ' ++++') print(f'\x1b[32;20m++++ {data} ++++\x1b[0m')
def scan_start_request(security_ctx, blocking=True, passive=False, group_channels=5, period_ms=120): def scan_start_request(security_ctx, blocking=True, passive=False, group_channels=5, period_ms=120):
@@ -24,17 +20,17 @@ def scan_start_request(security_ctx, blocking=True, passive=False, group_channel
cmd.cmd_scan_start.passive = passive cmd.cmd_scan_start.passive = passive
cmd.cmd_scan_start.group_channels = group_channels cmd.cmd_scan_start.group_channels = group_channels
cmd.cmd_scan_start.period_ms = period_ms cmd.cmd_scan_start.period_ms = period_ms
enc_cmd = security_ctx.encrypt_data(cmd.SerializeToString()).decode('latin-1') enc_cmd = security_ctx.encrypt_data(cmd.SerializeToString())
print_verbose(security_ctx, 'Client -> Device (Encrypted CmdScanStart) ' + utils.str_to_hexstr(enc_cmd)) print_verbose(security_ctx, f'Client -> Device (Encrypted CmdScanStart): 0x{enc_cmd.hex()}')
return enc_cmd return enc_cmd.decode('latin-1')
def scan_start_response(security_ctx, response_data): def scan_start_response(security_ctx, response_data):
# Interpret protobuf response packet from ScanStart command # Interpret protobuf response packet from ScanStart command
dec_resp = security_ctx.decrypt_data(tobytes(response_data)) dec_resp = security_ctx.decrypt_data(str_to_bytes(response_data))
resp = proto.wifi_scan_pb2.WiFiScanPayload() resp = proto.wifi_scan_pb2.WiFiScanPayload()
resp.ParseFromString(dec_resp) resp.ParseFromString(dec_resp)
print_verbose(security_ctx, 'ScanStart status ' + str(resp.status)) print_verbose(security_ctx, f'ScanStart status: 0x{str(resp.status)}')
if resp.status != 0: if resp.status != 0:
raise RuntimeError raise RuntimeError
@@ -43,17 +39,17 @@ def scan_status_request(security_ctx):
# Form protobuf request packet for ScanStatus command # Form protobuf request packet for ScanStatus command
cmd = proto.wifi_scan_pb2.WiFiScanPayload() cmd = proto.wifi_scan_pb2.WiFiScanPayload()
cmd.msg = proto.wifi_scan_pb2.TypeCmdScanStatus cmd.msg = proto.wifi_scan_pb2.TypeCmdScanStatus
enc_cmd = security_ctx.encrypt_data(cmd.SerializeToString()).decode('latin-1') enc_cmd = security_ctx.encrypt_data(cmd.SerializeToString())
print_verbose(security_ctx, 'Client -> Device (Encrypted CmdScanStatus) ' + utils.str_to_hexstr(enc_cmd)) print_verbose(security_ctx, f'Client -> Device (Encrypted CmdScanStatus): 0x{enc_cmd.hex()}')
return enc_cmd return enc_cmd.decode('latin-1')
def scan_status_response(security_ctx, response_data): def scan_status_response(security_ctx, response_data):
# Interpret protobuf response packet from ScanStatus command # Interpret protobuf response packet from ScanStatus command
dec_resp = security_ctx.decrypt_data(tobytes(response_data)) dec_resp = security_ctx.decrypt_data(str_to_bytes(response_data))
resp = proto.wifi_scan_pb2.WiFiScanPayload() resp = proto.wifi_scan_pb2.WiFiScanPayload()
resp.ParseFromString(dec_resp) resp.ParseFromString(dec_resp)
print_verbose(security_ctx, 'ScanStatus status ' + str(resp.status)) print_verbose(security_ctx, f'ScanStatus status: 0x{str(resp.status)}')
if resp.status != 0: if resp.status != 0:
raise RuntimeError raise RuntimeError
return {'finished': resp.resp_scan_status.scan_finished, 'count': resp.resp_scan_status.result_count} return {'finished': resp.resp_scan_status.scan_finished, 'count': resp.resp_scan_status.result_count}
@@ -65,17 +61,17 @@ def scan_result_request(security_ctx, index, count):
cmd.msg = proto.wifi_scan_pb2.TypeCmdScanResult cmd.msg = proto.wifi_scan_pb2.TypeCmdScanResult
cmd.cmd_scan_result.start_index = index cmd.cmd_scan_result.start_index = index
cmd.cmd_scan_result.count = count cmd.cmd_scan_result.count = count
enc_cmd = security_ctx.encrypt_data(cmd.SerializeToString()).decode('latin-1') enc_cmd = security_ctx.encrypt_data(cmd.SerializeToString())
print_verbose(security_ctx, 'Client -> Device (Encrypted CmdScanResult) ' + utils.str_to_hexstr(enc_cmd)) print_verbose(security_ctx, f'Client -> Device (Encrypted CmdScanResult): 0x{enc_cmd.hex()}')
return enc_cmd return enc_cmd.decode('latin-1')
def scan_result_response(security_ctx, response_data): def scan_result_response(security_ctx, response_data):
# Interpret protobuf response packet from ScanResult command # Interpret protobuf response packet from ScanResult command
dec_resp = security_ctx.decrypt_data(tobytes(response_data)) dec_resp = security_ctx.decrypt_data(str_to_bytes(response_data))
resp = proto.wifi_scan_pb2.WiFiScanPayload() resp = proto.wifi_scan_pb2.WiFiScanPayload()
resp.ParseFromString(dec_resp) resp.ParseFromString(dec_resp)
print_verbose(security_ctx, 'ScanResult status ' + str(resp.status)) print_verbose(security_ctx, f'ScanResult status: 0x{str(resp.status)}')
if resp.status != 0: if resp.status != 0:
raise RuntimeError raise RuntimeError
authmode_str = ['Open', 'WEP', 'WPA_PSK', 'WPA2_PSK', 'WPA_WPA2_PSK', authmode_str = ['Open', 'WEP', 'WPA_PSK', 'WPA2_PSK', 'WPA_WPA2_PSK',
@@ -83,13 +79,13 @@ def scan_result_response(security_ctx, response_data):
results = [] results = []
for entry in resp.resp_scan_result.entries: for entry in resp.resp_scan_result.entries:
results += [{'ssid': entry.ssid.decode('latin-1').rstrip('\x00'), results += [{'ssid': entry.ssid.decode('latin-1').rstrip('\x00'),
'bssid': utils.str_to_hexstr(entry.bssid.decode('latin-1')), 'bssid': entry.bssid.hex(),
'channel': entry.channel, 'channel': entry.channel,
'rssi': entry.rssi, 'rssi': entry.rssi,
'auth': authmode_str[entry.auth]}] 'auth': authmode_str[entry.auth]}]
print_verbose(security_ctx, 'ScanResult SSID : ' + str(results[-1]['ssid'])) print_verbose(security_ctx, f"ScanResult SSID : {str(results[-1]['ssid'])}")
print_verbose(security_ctx, 'ScanResult BSSID : ' + str(results[-1]['bssid'])) print_verbose(security_ctx, f"ScanResult BSSID : {str(results[-1]['bssid'])}")
print_verbose(security_ctx, 'ScanResult Channel : ' + str(results[-1]['channel'])) print_verbose(security_ctx, f"ScanResult Channel : {str(results[-1]['channel'])}")
print_verbose(security_ctx, 'ScanResult RSSI : ' + str(results[-1]['rssi'])) print_verbose(security_ctx, f"ScanResult RSSI : {str(results[-1]['rssi'])}")
print_verbose(security_ctx, 'ScanResult AUTH : ' + str(results[-1]['auth'])) print_verbose(security_ctx, f"ScanResult AUTH : {str(results[-1]['auth'])}")
return results return results

View File

@@ -1,16 +1,5 @@
# Copyright 2018 Espressif Systems (Shanghai) PTE LTD # SPDX-FileCopyrightText: 2018-2022 Espressif Systems (Shanghai) CO LTD
# # SPDX-License-Identifier: Apache-2.0
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# #
# Base class for protocomm security # Base class for protocomm security

View File

@@ -1,25 +1,12 @@
# Copyright 2018 Espressif Systems (Shanghai) PTE LTD # SPDX-FileCopyrightText: 2018-2022 Espressif Systems (Shanghai) CO LTD
# # SPDX-License-Identifier: Apache-2.0
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# #
# APIs for interpreting and creating protobuf packets for # APIs for interpreting and creating protobuf packets for
# protocomm endpoint with security type protocomm_security0 # protocomm endpoint with security type protocomm_security0
from __future__ import print_function
import proto import proto
from future.utils import tobytes from utils import str_to_bytes
from .security import Security from .security import Security
@@ -52,10 +39,10 @@ class Security0(Security):
def setup0_response(self, response_data): def setup0_response(self, response_data):
# Interpret protocomm security0 response packet # Interpret protocomm security0 response packet
setup_resp = proto.session_pb2.SessionData() setup_resp = proto.session_pb2.SessionData()
setup_resp.ParseFromString(tobytes(response_data)) setup_resp.ParseFromString(str_to_bytes(response_data))
# Check if security scheme matches # Check if security scheme matches
if setup_resp.sec_ver != proto.session_pb2.SecScheme0: if setup_resp.sec_ver != proto.session_pb2.SecScheme0:
print('Incorrect sec scheme') raise RuntimeError('Incorrect security scheme')
def encrypt_data(self, data): def encrypt_data(self, data):
# Passive. No encryption when security0 used # Passive. No encryption when security0 used

View File

@@ -1,35 +1,24 @@
# Copyright 2018 Espressif Systems (Shanghai) PTE LTD # SPDX-FileCopyrightText: 2018-2022 Espressif Systems (Shanghai) CO LTD
# # SPDX-License-Identifier: Apache-2.0
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# #
# APIs for interpreting and creating protobuf packets for # APIs for interpreting and creating protobuf packets for
# protocomm endpoint with security type protocomm_security1 # protocomm endpoint with security type protocomm_security1
from __future__ import print_function
import proto import proto
import session_pb2
import utils
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey, X25519PublicKey from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey, X25519PublicKey
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from future.utils import tobytes from utils import long_to_bytes, str_to_bytes
from .security import Security from .security import Security
def a_xor_b(a: bytes, b: bytes) -> bytes:
return b''.join(long_to_bytes(a[i] ^ b[i]) for i in range(0, len(b)))
# Enum for state of protocomm_security1 FSM # Enum for state of protocomm_security1 FSM
class security_state: class security_state:
REQUEST1 = 0 REQUEST1 = 0
@@ -38,25 +27,11 @@ class security_state:
FINISHED = 3 FINISHED = 3
def xor(a, b):
# XOR two inputs of type `bytes`
ret = bytearray()
# Decode the input bytes to strings
a = a.decode('latin-1')
b = b.decode('latin-1')
for i in range(max(len(a), len(b))):
# Convert the characters to corresponding 8-bit ASCII codes
# then XOR them and store in bytearray
ret.append(([0, ord(a[i])][i < len(a)]) ^ ([0, ord(b[i])][i < len(b)]))
# Convert bytearray to bytes
return bytes(ret)
class Security1(Security): class Security1(Security):
def __init__(self, pop, verbose): def __init__(self, pop, verbose):
# Initialize state of the security1 FSM # Initialize state of the security1 FSM
self.session_state = security_state.REQUEST1 self.session_state = security_state.REQUEST1
self.pop = tobytes(pop) self.pop = str_to_bytes(pop)
self.verbose = verbose self.verbose = verbose
Security.__init__(self, self.security1_session) Security.__init__(self, self.security1_session)
@@ -66,59 +41,55 @@ class Security1(Security):
if (self.session_state == security_state.REQUEST1): if (self.session_state == security_state.REQUEST1):
self.session_state = security_state.RESPONSE1_REQUEST2 self.session_state = security_state.RESPONSE1_REQUEST2
return self.setup0_request() return self.setup0_request()
if (self.session_state == security_state.RESPONSE1_REQUEST2): elif (self.session_state == security_state.RESPONSE1_REQUEST2):
self.session_state = security_state.RESPONSE2 self.session_state = security_state.RESPONSE2
self.setup0_response(response_data) self.setup0_response(response_data)
return self.setup1_request() return self.setup1_request()
if (self.session_state == security_state.RESPONSE2): elif (self.session_state == security_state.RESPONSE2):
self.session_state = security_state.FINISHED self.session_state = security_state.FINISHED
self.setup1_response(response_data) self.setup1_response(response_data)
return None return None
else:
print('Unexpected state') print('Unexpected state')
return None return None
def __generate_key(self): def __generate_key(self):
# Generate private and public key pair for client # Generate private and public key pair for client
self.client_private_key = X25519PrivateKey.generate() self.client_private_key = X25519PrivateKey.generate()
try: self.client_public_key = self.client_private_key.public_key().public_bytes(
self.client_public_key = self.client_private_key.public_key().public_bytes( encoding=serialization.Encoding.Raw,
encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw)
format=serialization.PublicFormat.Raw)
except TypeError:
# backward compatible call for older cryptography library
self.client_public_key = self.client_private_key.public_key().public_bytes()
def _print_verbose(self, data): def _print_verbose(self, data):
if (self.verbose): if (self.verbose):
print('++++ ' + data + ' ++++') print(f'\x1b[32;20m++++ {data} ++++\x1b[0m')
def setup0_request(self): def setup0_request(self):
# Form SessionCmd0 request packet using client public key # Form SessionCmd0 request packet using client public key
setup_req = session_pb2.SessionData() setup_req = proto.session_pb2.SessionData()
setup_req.sec_ver = session_pb2.SecScheme1 setup_req.sec_ver = proto.session_pb2.SecScheme1
self.__generate_key() self.__generate_key()
setup_req.sec1.sc0.client_pubkey = self.client_public_key setup_req.sec1.sc0.client_pubkey = self.client_public_key
self._print_verbose('Client Public Key:\t' + utils.str_to_hexstr(self.client_public_key.decode('latin-1'))) self._print_verbose(f'Client Public Key:\t0x{self.client_public_key.hex()}')
return setup_req.SerializeToString().decode('latin-1') return setup_req.SerializeToString().decode('latin-1')
def setup0_response(self, response_data): def setup0_response(self, response_data):
# Interpret SessionResp0 response packet # Interpret SessionResp0 response packet
setup_resp = proto.session_pb2.SessionData() setup_resp = proto.session_pb2.SessionData()
setup_resp.ParseFromString(tobytes(response_data)) setup_resp.ParseFromString(str_to_bytes(response_data))
self._print_verbose('Security version:\t' + str(setup_resp.sec_ver)) self._print_verbose('Security version:\t' + str(setup_resp.sec_ver))
if setup_resp.sec_ver != session_pb2.SecScheme1: if setup_resp.sec_ver != proto.session_pb2.SecScheme1:
print('Incorrect sec scheme') raise RuntimeError('Incorrect security scheme')
exit(1)
self.device_public_key = setup_resp.sec1.sr0.device_pubkey self.device_public_key = setup_resp.sec1.sr0.device_pubkey
# Device random is the initialization vector # Device random is the initialization vector
device_random = setup_resp.sec1.sr0.device_random device_random = setup_resp.sec1.sr0.device_random
self._print_verbose('Device Public Key:\t' + utils.str_to_hexstr(self.device_public_key.decode('latin-1'))) self._print_verbose(f'Device Public Key:\t0x{self.device_public_key.hex()}')
self._print_verbose('Device Random:\t' + utils.str_to_hexstr(device_random.decode('latin-1'))) self._print_verbose(f'Device Random:\t0x{device_random.hex()}')
# Calculate Curve25519 shared key using Client private key and Device public key # Calculate Curve25519 shared key using Client private key and Device public key
sharedK = self.client_private_key.exchange(X25519PublicKey.from_public_bytes(self.device_public_key)) sharedK = self.client_private_key.exchange(X25519PublicKey.from_public_bytes(self.device_public_key))
self._print_verbose('Shared Key:\t' + utils.str_to_hexstr(sharedK.decode('latin-1'))) self._print_verbose(f'Shared Key:\t0x{sharedK.hex()}')
# If PoP is provided, XOR SHA256 of PoP with the previously # If PoP is provided, XOR SHA256 of PoP with the previously
# calculated Shared Key to form the actual Shared Key # calculated Shared Key to form the actual Shared Key
@@ -128,8 +99,8 @@ class Security1(Security):
h.update(self.pop) h.update(self.pop)
digest = h.finalize() digest = h.finalize()
# XOR with and update Shared Key # XOR with and update Shared Key
sharedK = xor(sharedK, digest) sharedK = a_xor_b(sharedK, digest)
self._print_verbose('New Shared Key XORed with PoP:\t' + utils.str_to_hexstr(sharedK.decode('latin-1'))) self._print_verbose(f'Updated Shared Key (Shared key XORed with PoP):\t0x{sharedK.hex()}')
# Initialize the encryption engine with Shared Key and initialization vector # Initialize the encryption engine with Shared Key and initialization vector
cipher = Cipher(algorithms.AES(sharedK), modes.CTR(device_random), backend=default_backend()) cipher = Cipher(algorithms.AES(sharedK), modes.CTR(device_random), backend=default_backend())
self.cipher = cipher.encryptor() self.cipher = cipher.encryptor()
@@ -137,36 +108,33 @@ class Security1(Security):
def setup1_request(self): def setup1_request(self):
# Form SessionCmd1 request packet using encrypted device public key # Form SessionCmd1 request packet using encrypted device public key
setup_req = proto.session_pb2.SessionData() setup_req = proto.session_pb2.SessionData()
setup_req.sec_ver = session_pb2.SecScheme1 setup_req.sec_ver = proto.session_pb2.SecScheme1
setup_req.sec1.msg = proto.sec1_pb2.Session_Command1 setup_req.sec1.msg = proto.sec1_pb2.Session_Command1
# Encrypt device public key and attach to the request packet # Encrypt device public key and attach to the request packet
client_verify = self.cipher.update(self.device_public_key) client_verify = self.cipher.update(self.device_public_key)
self._print_verbose('Client Verify:\t' + utils.str_to_hexstr(client_verify.decode('latin-1'))) self._print_verbose(f'Client Proof:\t0x{client_verify.hex()}')
setup_req.sec1.sc1.client_verify_data = client_verify setup_req.sec1.sc1.client_verify_data = client_verify
return setup_req.SerializeToString().decode('latin-1') return setup_req.SerializeToString().decode('latin-1')
def setup1_response(self, response_data): def setup1_response(self, response_data):
# Interpret SessionResp1 response packet # Interpret SessionResp1 response packet
setup_resp = proto.session_pb2.SessionData() setup_resp = proto.session_pb2.SessionData()
setup_resp.ParseFromString(tobytes(response_data)) setup_resp.ParseFromString(str_to_bytes(response_data))
# Ensure security scheme matches # Ensure security scheme matches
if setup_resp.sec_ver == session_pb2.SecScheme1: if setup_resp.sec_ver == proto.session_pb2.SecScheme1:
# Read encrypyed device verify string # Read encrypyed device verify string
device_verify = setup_resp.sec1.sr1.device_verify_data device_verify = setup_resp.sec1.sr1.device_verify_data
self._print_verbose('Device verify:\t' + utils.str_to_hexstr(device_verify.decode('latin-1'))) self._print_verbose(f'Device Proof:\t0x{device_verify.hex()}')
# Decrypt the device verify string # Decrypt the device verify string
enc_client_pubkey = self.cipher.update(setup_resp.sec1.sr1.device_verify_data) enc_client_pubkey = self.cipher.update(setup_resp.sec1.sr1.device_verify_data)
self._print_verbose('Enc client pubkey:\t ' + utils.str_to_hexstr(enc_client_pubkey.decode('latin-1')))
# Match decryped string with client public key # Match decryped string with client public key
if enc_client_pubkey != self.client_public_key: if enc_client_pubkey != self.client_public_key:
print('Mismatch in device verify') raise RuntimeError('Failed to verify device!')
return -2
else: else:
print('Unsupported security protocol') raise RuntimeError('Unsupported security protocol')
return -1
def encrypt_data(self, data): def encrypt_data(self, data):
return self.cipher.update(tobytes(data)) return self.cipher.update(data)
def decrypt_data(self, data): def decrypt_data(self, data):
return self.cipher.update(tobytes(data)) return self.cipher.update(data)

View File

@@ -9,10 +9,10 @@ from typing import Any, Type
import proto import proto
from cryptography.hazmat.primitives.ciphers.aead import AESGCM from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from future.utils import tobytes from utils import long_to_bytes, str_to_bytes
from .security import Security from .security import Security
from .srp6a import Srp6a, bytes_to_long, generate_salt_and_verifier, long_to_bytes from .srp6a import Srp6a, generate_salt_and_verifier
AES_KEY_LEN = 256 // 8 AES_KEY_LEN = 256 // 8
@@ -70,7 +70,7 @@ class Security2(Security):
self.setup1_response(response_data) self.setup1_response(response_data)
return None return None
print('Unexpected state') print('---- Unexpected state! ----')
return None return None
def _print_verbose(self, data: str) -> None: def _print_verbose(self, data: str) -> None:
@@ -83,34 +83,30 @@ class Security2(Security):
setup_req.sec_ver = proto.session_pb2.SecScheme2 setup_req.sec_ver = proto.session_pb2.SecScheme2
setup_req.sec2.msg = proto.sec2_pb2.S2Session_Command0 setup_req.sec2.msg = proto.sec2_pb2.S2Session_Command0
setup_req.sec2.sc0.client_username = tobytes(self.username) setup_req.sec2.sc0.client_username = str_to_bytes(self.username)
self.srp6a_ctx = Srp6a(self.username, self.password) self.srp6a_ctx = Srp6a(self.username, self.password)
if self.srp6a_ctx is None: if self.srp6a_ctx is None:
print('Failed to initialize SRP6a instance!') raise RuntimeError('Failed to initialize SRP6a instance!')
exit(1)
client_pubkey = long_to_bytes(self.srp6a_ctx.A) client_pubkey = long_to_bytes(self.srp6a_ctx.A)
setup_req.sec2.sc0.client_pubkey = client_pubkey setup_req.sec2.sc0.client_pubkey = client_pubkey
self._print_verbose('Client Public Key:\t' + hex(bytes_to_long(client_pubkey))) self._print_verbose(f'Client Public Key:\t0x{client_pubkey.hex()}')
return setup_req.SerializeToString().decode('latin-1') return setup_req.SerializeToString().decode('latin-1')
def setup0_response(self, response_data: bytes) -> None: def setup0_response(self, response_data: bytes) -> None:
# Interpret SessionResp0 response packet # Interpret SessionResp0 response packet
setup_resp = proto.session_pb2.SessionData() setup_resp = proto.session_pb2.SessionData()
setup_resp.ParseFromString(tobytes(response_data)) setup_resp.ParseFromString(str_to_bytes(response_data))
self._print_verbose('Security version:\t' + str(setup_resp.sec_ver)) self._print_verbose(f'Security version:\t{str(setup_resp.sec_ver)}')
if setup_resp.sec_ver != proto.session_pb2.SecScheme2: if setup_resp.sec_ver != proto.session_pb2.SecScheme2:
print('Incorrect sec scheme') raise RuntimeError('Incorrect security scheme')
exit(1)
# Device public key, random salt and password verifier # Device public key, random salt and password verifier
device_pubkey = setup_resp.sec2.sr0.device_pubkey device_pubkey = setup_resp.sec2.sr0.device_pubkey
device_salt = setup_resp.sec2.sr0.device_salt device_salt = setup_resp.sec2.sr0.device_salt
self._print_verbose('Device Public Key:\t' + hex(bytes_to_long(device_pubkey))) self._print_verbose(f'Device Public Key:\t0x{device_pubkey.hex()}')
self._print_verbose('Device Salt:\t' + hex(bytes_to_long(device_salt)))
self.client_pop_key = self.srp6a_ctx.process_challenge(device_salt, device_pubkey) self.client_pop_key = self.srp6a_ctx.process_challenge(device_salt, device_pubkey)
def setup1_request(self) -> Any: def setup1_request(self) -> Any:
@@ -120,7 +116,10 @@ class Security2(Security):
setup_req.sec2.msg = proto.sec2_pb2.S2Session_Command1 setup_req.sec2.msg = proto.sec2_pb2.S2Session_Command1
# Encrypt device public key and attach to the request packet # Encrypt device public key and attach to the request packet
self._print_verbose('Client Proof:\t' + hex(bytes_to_long(self.client_pop_key))) if self.client_pop_key is None:
raise RuntimeError('Failed to generate client proof!')
self._print_verbose(f'Client Proof:\t0x{self.client_pop_key.hex()}')
setup_req.sec2.sc1.client_proof = self.client_pop_key setup_req.sec2.sc1.client_proof = self.client_pop_key
return setup_req.SerializeToString().decode('latin-1') return setup_req.SerializeToString().decode('latin-1')
@@ -128,37 +127,36 @@ class Security2(Security):
def setup1_response(self, response_data: bytes) -> Any: def setup1_response(self, response_data: bytes) -> Any:
# Interpret SessionResp1 response packet # Interpret SessionResp1 response packet
setup_resp = proto.session_pb2.SessionData() setup_resp = proto.session_pb2.SessionData()
setup_resp.ParseFromString(tobytes(response_data)) setup_resp.ParseFromString(str_to_bytes(response_data))
# Ensure security scheme matches # Ensure security scheme matches
if setup_resp.sec_ver == proto.session_pb2.SecScheme2: if setup_resp.sec_ver == proto.session_pb2.SecScheme2:
# Read encrypyed device proof string # Read encrypyed device proof string
device_proof = setup_resp.sec2.sr1.device_proof device_proof = setup_resp.sec2.sr1.device_proof
self._print_verbose('Device Proof:\t' + hex(bytes_to_long(device_proof))) self._print_verbose(f'Device Proof:\t0x{device_proof.hex()}')
self.srp6a_ctx.verify_session(device_proof) self.srp6a_ctx.verify_session(device_proof)
if not self.srp6a_ctx.authenticated(): if not self.srp6a_ctx.authenticated():
print('Failed to verify device proof') raise RuntimeError('Failed to verify device proof')
exit(1)
else: else:
print('Unsupported security protocol') raise RuntimeError('Unsupported security protocol')
exit(1)
# Getting the shared secret # Getting the shared secret
shared_secret = self.srp6a_ctx.get_session_key() shared_secret = self.srp6a_ctx.get_session_key()
self._print_verbose('Shared Secret:\t' + hex(bytes_to_long(shared_secret))) self._print_verbose(f'Shared Secret:\t0x{shared_secret.hex()}')
# Using the first 256 bits of a 512 bit key # Using the first 256 bits of a 512 bit key
session_key = shared_secret[:AES_KEY_LEN] session_key = shared_secret[:AES_KEY_LEN]
self._print_verbose('Session Key:\t' + hex(bytes_to_long(session_key))) self._print_verbose(f'Session Key:\t0x{session_key.hex()}')
# 96-bit nonce # 96-bit nonce
self.nonce = setup_resp.sec2.sr1.device_nonce self.nonce = setup_resp.sec2.sr1.device_nonce
self._print_verbose('Nonce:\t' + hex(bytes_to_long(self.nonce))) if self.nonce is None:
raise RuntimeError('Received invalid nonce from device!')
self._print_verbose(f'Nonce:\t0x{self.nonce.hex()}')
# Initialize the encryption engine with Shared Key and initialization vector # Initialize the encryption engine with Shared Key and initialization vector
self.cipher = AESGCM(session_key) self.cipher = AESGCM(session_key)
if self.cipher is None: if self.cipher is None:
print('Failed to initialize AES-GCM cryptographic engine!') raise RuntimeError('Failed to initialize AES-GCM cryptographic engine!')
exit(1)
def encrypt_data(self, data: bytes) -> Any: def encrypt_data(self, data: bytes) -> Any:
return self.cipher.encrypt(self.nonce, data, None) return self.cipher.encrypt(self.nonce, data, None)

View File

@@ -1,5 +1,6 @@
# SPDX-FileCopyrightText: 2022 Espressif Systems (Shanghai) CO LTD # SPDX-FileCopyrightText: 2022 Espressif Systems (Shanghai) CO LTD
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
#
# N A large safe prime (N = 2q+1, where q is prime) [All arithmetic is done modulo N] # N A large safe prime (N = 2q+1, where q is prime) [All arithmetic is done modulo N]
# g A generator modulo N # g A generator modulo N
@@ -19,6 +20,8 @@ import hashlib
import os import os
from typing import Any, Callable, Optional, Tuple from typing import Any, Callable, Optional, Tuple
from utils import bytes_to_long, long_to_bytes
SHA1 = 0 SHA1 = 0
SHA224 = 1 SHA224 = 1
SHA256 = 2 SHA256 = 2
@@ -143,21 +146,11 @@ def get_ng(ng_type: int) -> Tuple[int, int]:
return int(n_hex, 16), int(g_hex, 16) return int(n_hex, 16), int(g_hex, 16)
def bytes_to_long(s: bytes) -> int: def get_random(nbytes: int) -> Any:
return int.from_bytes(s, 'big')
def long_to_bytes(n: int) -> bytes:
if n == 0:
return b'\x00'
return n.to_bytes((n.bit_length() + 7) // 8, 'big')
def get_random(nbytes: int) -> int:
return bytes_to_long(os.urandom(nbytes)) return bytes_to_long(os.urandom(nbytes))
def get_random_of_length(nbytes: int) -> int: def get_random_of_length(nbytes: int) -> Any:
offset = (nbytes * 8) - 1 offset = (nbytes * 8) - 1
return get_random(nbytes) | (1 << offset) return get_random(nbytes) | (1 << offset)
@@ -255,7 +248,7 @@ class Srp6a (object):
def get_username(self) -> str: def get_username(self) -> str:
return self.Iu return self.Iu
def get_ephemeral_secret(self) -> bytes: def get_ephemeral_secret(self) -> Any:
return long_to_bytes(self.a) return long_to_bytes(self.a)
def get_session_key(self) -> Any: def get_session_key(self) -> Any:

View File

@@ -1,16 +1,5 @@
# Copyright 2018 Espressif Systems (Shanghai) PTE LTD # SPDX-FileCopyrightText: 2022 Espressif Systems (Shanghai) CO LTD
# # SPDX-License-Identifier: Apache-2.0
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# #
from .transport_ble import * # noqa: F403, F401 from .transport_ble import * # noqa: F403, F401

View File

@@ -2,12 +2,9 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# #
from __future__ import print_function
import platform import platform
from builtins import input
import utils from utils import hex_str_to_bytes, str_to_bytes
fallback = True fallback = True
@@ -29,18 +26,23 @@ def device_sort(device):
class BLE_Bleak_Client: class BLE_Bleak_Client:
def __init__(self): def __init__(self):
self.adapter = None
self.adapter_props = None self.adapter_props = None
self.characteristics = dict()
self.chrc_names = None
self.device = None
self.devname = None
self.iface = None
self.nu_lookup = None
self.services = None
self.srv_uuid_adv = None
self.srv_uuid_fallback = None
async def connect(self, devname, iface, chrc_names, fallback_srv_uuid): async def connect(self, devname, iface, chrc_names, fallback_srv_uuid):
self.devname = devname self.devname = devname
self.srv_uuid_fallback = fallback_srv_uuid self.srv_uuid_fallback = fallback_srv_uuid
self.chrc_names = [name.lower() for name in chrc_names] self.chrc_names = [name.lower() for name in chrc_names]
self.device = None self.iface = iface
self.adapter = None
self.services = None
self.nu_lookup = None
self.characteristics = dict()
self.srv_uuid_adv = None
print('Discovering...') print('Discovering...')
try: try:
@@ -62,7 +64,7 @@ class BLE_Bleak_Client:
print('==== BLE Discovery results ====') print('==== BLE Discovery results ====')
print('{0: >4} {1: <33} {2: <12}'.format( print('{0: >4} {1: <33} {2: <12}'.format(
'S.N.', 'Name', 'Address')) 'S.N.', 'Name', 'Address'))
for i in range(len(devices)): for i, _ in enumerate(devices):
print('[{0: >2}] {1: <33} {2: <12}'.format(i + 1, devices[i].name or 'Unknown', devices[i].address)) print('[{0: >2}] {1: <33} {2: <12}'.format(i + 1, devices[i].name or 'Unknown', devices[i].address))
while True: while True:
@@ -193,10 +195,10 @@ class BLE_Console_Client:
async def send_data(self, characteristic_uuid, data): async def send_data(self, characteristic_uuid, data):
print("BLECLI >> Write following data to characteristic with UUID '" + characteristic_uuid + "' :") print("BLECLI >> Write following data to characteristic with UUID '" + characteristic_uuid + "' :")
print('\t>> ' + utils.str_to_hexstr(data)) print('\t>> ' + str_to_bytes(data).hex())
print('BLECLI >> Enter data read from characteristic (in hex) :') print('BLECLI >> Enter data read from characteristic (in hex) :')
resp = input('\t<< ') resp = input('\t<< ')
return utils.hexstr_to_str(resp) return hex_str_to_bytes(resp)
# -------------------------------------------------------------------- # --------------------------------------------------------------------

View File

@@ -12,6 +12,7 @@ class Transport_BLE(Transport):
def __init__(self, service_uuid, nu_lookup): def __init__(self, service_uuid, nu_lookup):
self.nu_lookup = nu_lookup self.nu_lookup = nu_lookup
self.service_uuid = service_uuid self.service_uuid = service_uuid
self.name_uuid_lookup = None
# Expect service UUID like '0000ffff-0000-1000-8000-00805f9b34fb' # Expect service UUID like '0000ffff-0000-1000-8000-00805f9b34fb'
for name in nu_lookup.keys(): for name in nu_lookup.keys():
# Calculate characteristic UUID for each endpoint # Calculate characteristic UUID for each endpoint
@@ -39,7 +40,7 @@ class Transport_BLE(Transport):
# Check if expected characteristics are provided by the service # Check if expected characteristics are provided by the service
for name in self.name_uuid_lookup.keys(): for name in self.name_uuid_lookup.keys():
if not self.cli.has_characteristic(self.name_uuid_lookup[name]): if not self.cli.has_characteristic(self.name_uuid_lookup[name]):
raise RuntimeError("'" + name + "' endpoint not found") raise RuntimeError(f"'{name}' endpoint not found")
async def disconnect(self): async def disconnect(self):
await self.cli.disconnect() await self.cli.disconnect()
@@ -47,5 +48,5 @@ class Transport_BLE(Transport):
async def send_data(self, ep_name, data): async def send_data(self, ep_name, data):
# Write (and read) data to characteristic corresponding to the endpoint # Write (and read) data to characteristic corresponding to the endpoint
if ep_name not in self.name_uuid_lookup.keys(): if ep_name not in self.name_uuid_lookup.keys():
raise RuntimeError('Invalid endpoint : ' + ep_name) raise RuntimeError(f'Invalid endpoint: {ep_name}')
return await self.cli.send_data(self.name_uuid_lookup[ep_name], data) return await self.cli.send_data(self.name_uuid_lookup[ep_name], data)

View File

@@ -2,11 +2,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# #
from __future__ import print_function from utils import hex_str_to_bytes, str_to_bytes
from builtins import input
import utils
from .transport import Transport from .transport import Transport
@@ -14,10 +10,10 @@ from .transport import Transport
class Transport_Console(Transport): class Transport_Console(Transport):
async def send_data(self, path, data, session_id=0): async def send_data(self, path, data, session_id=0):
print('Client->Device msg :', path, session_id, utils.str_to_hexstr(data)) print('Client->Device msg :', path, session_id, str_to_bytes(data).hex())
try: try:
resp = input('Enter device->client msg : ') resp = input('Enter device->client msg : ')
except Exception as err: except Exception as err:
print('error:', err) print('error:', err)
return None return None
return utils.hexstr_to_str(resp) return hex_str_to_bytes(resp)

View File

@@ -22,14 +22,14 @@ class Transport_HTTP(Transport):
try: try:
socket.gethostbyname(hostname.split(':')[0]) socket.gethostbyname(hostname.split(':')[0])
except socket.gaierror: except socket.gaierror:
raise RuntimeError('Unable to resolve hostname :' + hostname) raise RuntimeError(f'Unable to resolve hostname: {hostname}')
if ssl_context is None: if ssl_context is None:
self.conn = HTTPConnection(hostname, timeout=60) self.conn = HTTPConnection(hostname, timeout=60)
else: else:
self.conn = HTTPSConnection(hostname, context=ssl_context, timeout=60) self.conn = HTTPSConnection(hostname, context=ssl_context, timeout=60)
try: try:
print('Connecting to ' + hostname) print(f'++++ Connecting to {hostname}++++')
self.conn.connect() self.conn.connect()
except Exception as err: except Exception as err:
raise RuntimeError('Connection Failure : ' + str(err)) raise RuntimeError('Connection Failure : ' + str(err))

View File

@@ -1,16 +1,5 @@
# Copyright 2018 Espressif Systems (Shanghai) PTE LTD # SPDX-FileCopyrightText: 2022 Espressif Systems (Shanghai) CO LTD
# # SPDX-License-Identifier: Apache-2.0
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# #
from .convenience import * # noqa: F403, F401 from .convenience import * # noqa: F403, F401

View File

@@ -3,21 +3,22 @@
# #
# Convenience functions for commonly used data type conversions # Convenience functions for commonly used data type conversions
import binascii
from future.utils import tobytes def bytes_to_long(s: bytes) -> int:
return int.from_bytes(s, 'big')
def str_to_hexstr(string): def long_to_bytes(n: int) -> bytes:
# Form hexstr by appending ASCII codes (in hex) corresponding to if n == 0:
# each character in the input string return b'\x00'
return binascii.hexlify(tobytes(string)).decode('latin-1') return n.to_bytes((n.bit_length() + 7) // 8, 'big')
def hexstr_to_str(hexstr): # 'deadbeef' -> b'deadbeef'
# Prepend 0 (if needed) to make the hexstr length an even number def str_to_bytes(s: str) -> bytes:
if len(hexstr) % 2 == 1: return bytes(s, encoding='latin-1')
hexstr = '0' + hexstr
# Interpret consecutive pairs of hex characters as 8 bit ASCII codes
# and append characters corresponding to each code to form the string # 'deadbeef' -> b'\xde\xad\xbe\xef'
return binascii.unhexlify(tobytes(hexstr)).decode('latin-1') def hex_str_to_bytes(s: str) -> bytes:
return bytes.fromhex(s)