diff --git a/examples/security/flash_encryption/conftest.py b/examples/security/flash_encryption/conftest.py new file mode 100644 index 0000000000..95d19428ed --- /dev/null +++ b/examples/security/flash_encryption/conftest.py @@ -0,0 +1,45 @@ +# SPDX-FileCopyrightText: 2022 Espressif Systems (Shanghai) CO LTD +# SPDX-License-Identifier: Apache-2.0 + +import logging + +import pytest +from _pytest.fixtures import FixtureRequest +from _pytest.monkeypatch import MonkeyPatch +from pytest_embedded_idf.serial import IdfSerial + + +# This is a custom Serial Class to add the erase_flash functionality +class FlashEncSerial(IdfSerial): + + @IdfSerial.use_esptool + def erase_partition(self, partition_name: str) -> None: + + if partition_name is None: + logging.error('Invalid arguments') + return + + if not self.app.partition_table: + logging.error('Partition table not parsed.') + return + + if partition_name in self.app.partition_table: + address = self.app.partition_table[partition_name]['offset'] + size = self.app.partition_table[partition_name]['size'] + logging.info('Erasing the partition {0} of size {1} at {2}'.format(partition_name, size, address)) + self.stub.erase_region(address, size) + else: + logging.error('partition name {0} not found in app partition table'.format(partition_name)) + return + + +@pytest.fixture(scope='module') +def monkeypatch_module(request: FixtureRequest) -> MonkeyPatch: + mp = MonkeyPatch() + request.addfinalizer(mp.undo) + return mp + + +@pytest.fixture(scope='module', autouse=True) +def replace_dut_class(monkeypatch_module: MonkeyPatch) -> None: + monkeypatch_module.setattr('pytest_embedded_idf.serial.IdfSerial', FlashEncSerial) diff --git a/examples/security/flash_encryption/pytest_flash_encryption.py b/examples/security/flash_encryption/pytest_flash_encryption.py index bebf82b74a..b8e9de2f08 100644 --- a/examples/security/flash_encryption/pytest_flash_encryption.py +++ b/examples/security/flash_encryption/pytest_flash_encryption.py @@ -4,23 +4,13 @@ from __future__ import print_function import binascii -import os -import sys from collections import namedtuple from io import BytesIO +import espsecure import pytest from pytest_embedded import Dut -try: - import espsecure -except ImportError: - idf_path = os.getenv('IDF_PATH') - if not idf_path or not os.path.exists(idf_path): - raise - sys.path.insert(0, os.path.join(idf_path, 'components', 'esptool_py', 'esptool')) - import espsecure - # To prepare a test runner for this example: # 1. Generate zero flash encryption key: @@ -33,6 +23,8 @@ except ImportError: @pytest.mark.esp32c3 @pytest.mark.flash_encryption def test_examples_security_flash_encryption(dut: Dut) -> None: + # Erase the nvs_key partition + dut.serial.erase_partition('nvs_key') # calculate the expected ciphertext flash_addr = dut.app.partition_table['storage']['offset'] plain_hex_str = '00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f' @@ -72,7 +64,3 @@ def test_examples_security_flash_encryption(dut: Dut) -> None: ] for line in lines: dut.expect(line, timeout=2) - - -if __name__ == '__main__': - test_examples_security_flash_encryption()