diff --git a/components/mqtt/test_apps/test_mqtt/pytest_mqtt_ut.py b/components/mqtt/test_apps/test_mqtt/pytest_mqtt_ut.py index 300d53e..429b857 100644 --- a/components/mqtt/test_apps/test_mqtt/pytest_mqtt_ut.py +++ b/components/mqtt/test_apps/test_mqtt/pytest_mqtt_ut.py @@ -1,10 +1,11 @@ -# SPDX-FileCopyrightText: 2023 Espressif Systems (Shanghai) CO LTD +# SPDX-FileCopyrightText: 2023-2025 Espressif Systems (Shanghai) CO LTD # SPDX-License-Identifier: Unlicense OR CC0-1.0 import pytest from pytest_embedded import Dut +from pytest_embedded_idf.utils import idf_parametrize -@pytest.mark.esp32 @pytest.mark.ethernet +@idf_parametrize('target', ['esp32'], indirect=['target']) def test_mqtt_client(dut: Dut) -> None: dut.expect_unity_test_output() diff --git a/components/mqtt/test_apps/test_mqtt5/pytest_mqtt5_ut.py b/components/mqtt/test_apps/test_mqtt5/pytest_mqtt5_ut.py index df15c39..7e38dfa 100644 --- a/components/mqtt/test_apps/test_mqtt5/pytest_mqtt5_ut.py +++ b/components/mqtt/test_apps/test_mqtt5/pytest_mqtt5_ut.py @@ -1,10 +1,11 @@ -# SPDX-FileCopyrightText: 2023 Espressif Systems (Shanghai) CO LTD +# SPDX-FileCopyrightText: 2023-2025 Espressif Systems (Shanghai) CO LTD # SPDX-License-Identifier: Unlicense OR CC0-1.0 import pytest from pytest_embedded import Dut +from pytest_embedded_idf.utils import idf_parametrize -@pytest.mark.esp32 @pytest.mark.ethernet +@idf_parametrize('target', ['esp32'], indirect=['target']) def test_mqtt5_client(dut: Dut) -> None: dut.expect_unity_test_output() diff --git a/examples/protocols/mqtt/ssl/pytest_mqtt_ssl.py b/examples/protocols/mqtt/ssl/pytest_mqtt_ssl.py index d27cc65..c34cea6 100644 --- a/examples/protocols/mqtt/ssl/pytest_mqtt_ssl.py +++ b/examples/protocols/mqtt/ssl/pytest_mqtt_ssl.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: 2022-2024 Espressif Systems (Shanghai) CO LTD +# SPDX-FileCopyrightText: 2022-2025 Espressif Systems (Shanghai) CO LTD # SPDX-License-Identifier: Unlicense OR CC0-1.0 import logging import os @@ -12,6 +12,7 @@ import paho.mqtt.client as mqtt import pexpect import pytest from pytest_embedded import Dut +from pytest_embedded_idf.utils import idf_parametrize event_client_connected = Event() event_stop_client = Event() @@ -50,7 +51,9 @@ def on_message(client, userdata, msg): # type: (mqtt.Client, tuple, mqtt.client recv_binary = binary + '.received' with open(recv_binary, 'w', encoding='utf-8') as fw: fw.write(msg.payload) - raise ValueError('Received binary (saved as: {}) does not match the original file: {}'.format(recv_binary, binary)) + raise ValueError( + 'Received binary (saved as: {}) does not match the original file: {}'.format(recv_binary, binary) + ) payload = msg.payload.decode() if not event_client_received_correct.is_set() and payload == 'data': @@ -61,8 +64,8 @@ def on_message(client, userdata, msg): # type: (mqtt.Client, tuple, mqtt.client message_log += 'Received data:' + msg.topic + ' ' + payload + '\n' -@pytest.mark.esp32 @pytest.mark.ethernet +@idf_parametrize('target', ['esp32'], indirect=['target']) def test_examples_protocol_mqtt_ssl(dut): # type: (Dut) -> None broker_url = '' broker_port = 0 @@ -95,14 +98,16 @@ def test_examples_protocol_mqtt_ssl(dut): # type: (Dut) -> None client.on_connect = on_connect client.on_message = on_message client.user_data_set((binary_file, bin_size)) - client.tls_set(None, - None, - None, cert_reqs=ssl.CERT_NONE, tls_version=ssl.PROTOCOL_TLSv1_2, ciphers=None) + client.tls_set(None, None, None, cert_reqs=ssl.CERT_NONE, tls_version=ssl.PROTOCOL_TLSv1_2, ciphers=None) client.tls_insecure_set(True) print('Connecting...') client.connect(broker_url, broker_port, 60) except Exception: - print('ENV_TEST_FAILURE: Unexpected error while connecting to broker {}: {}:'.format(broker_url, sys.exc_info()[0])) + print( + 'ENV_TEST_FAILURE: Unexpected error while connecting to broker {}: {}:'.format( + broker_url, sys.exc_info()[0] + ) + ) raise # Starting a py-client in a separate thread thread1 = Thread(target=mqtt_client_task, args=(client,)) diff --git a/examples/protocols/mqtt/tcp/pytest_mqtt_tcp.py b/examples/protocols/mqtt/tcp/pytest_mqtt_tcp.py index 2786a76..f8deb23 100644 --- a/examples/protocols/mqtt/tcp/pytest_mqtt_tcp.py +++ b/examples/protocols/mqtt/tcp/pytest_mqtt_tcp.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: 2022 Espressif Systems (Shanghai) CO LTD +# SPDX-FileCopyrightText: 2022-2025 Espressif Systems (Shanghai) CO LTD # SPDX-License-Identifier: Unlicense OR CC0-1.0 import logging import os @@ -12,6 +12,7 @@ import pexpect import pytest from common_test_methods import get_host_ip4_by_dest_ip from pytest_embedded import Dut +from pytest_embedded_idf.utils import idf_parametrize msgid = -1 @@ -25,13 +26,15 @@ def mqqt_server_sketch(my_ip, port): # type: (str, str) -> None s.settimeout(60) s.bind((my_ip, port)) s.listen(1) - q,addr = s.accept() + q, addr = s.accept() q.settimeout(30) print('connection accepted') except Exception: - print('Local server on {}:{} listening/accepting failure: {}' - 'Possibly check permissions or firewall settings' - 'to accept connections on this address'.format(my_ip, port, sys.exc_info()[0])) + print( + 'Local server on {}:{} listening/accepting failure: {}' + 'Possibly check permissions or firewall settings' + 'to accept connections on this address'.format(my_ip, port, sys.exc_info()[0]) + ) raise data = q.recv(1024) # check if received initial empty message @@ -49,8 +52,8 @@ def mqqt_server_sketch(my_ip, port): # type: (str, str) -> None print('server closed') -@pytest.mark.esp32 @pytest.mark.ethernet +@idf_parametrize('target', ['esp32'], indirect=['target']) def test_examples_protocol_mqtt_qos1(dut: Dut) -> None: global msgid """ @@ -73,7 +76,7 @@ def test_examples_protocol_mqtt_qos1(dut: Dut) -> None: # 2. start mqtt broker sketch host_ip = get_host_ip4_by_dest_ip(ip_address) - thread1 = Thread(target=mqqt_server_sketch, args=(host_ip,1883)) + thread1 = Thread(target=mqqt_server_sketch, args=(host_ip, 1883)) thread1.start() data_write = 'mqtt://' + host_ip @@ -85,8 +88,10 @@ def test_examples_protocol_mqtt_qos1(dut: Dut) -> None: msgid_enqueued = dut.expect(b'outbox: ENQUEUE msgid=([0-9]+)', timeout=30).group(1).decode() msgid_deleted = dut.expect(b'outbox: DELETED msgid=([0-9]+)', timeout=30).group(1).decode() # 4. check the msgid of received data are the same as that of enqueued and deleted from outbox - if (msgid_enqueued == str(msgid) and msgid_deleted == str(msgid)): + if msgid_enqueued == str(msgid) and msgid_deleted == str(msgid): print('PASS: Received correct msg id') else: print('Failure!') - raise ValueError('Mismatch of msgid: received: {}, enqueued {}, deleted {}'.format(msgid, msgid_enqueued, msgid_deleted)) + raise ValueError( + 'Mismatch of msgid: received: {}, enqueued {}, deleted {}'.format(msgid, msgid_enqueued, msgid_deleted) + ) diff --git a/examples/protocols/mqtt/ws/pytest_mqtt_ws_example.py b/examples/protocols/mqtt/ws/pytest_mqtt_ws_example.py index 392139c..3ad6bd7 100644 --- a/examples/protocols/mqtt/ws/pytest_mqtt_ws_example.py +++ b/examples/protocols/mqtt/ws/pytest_mqtt_ws_example.py @@ -1,16 +1,18 @@ #!/usr/bin/env python # -# SPDX-FileCopyrightText: 2022 Espressif Systems (Shanghai) CO LTD +# SPDX-FileCopyrightText: 2022-2025 Espressif Systems (Shanghai) CO LTD # SPDX-License-Identifier: Unlicense OR CC0-1.0 import logging import os import re import sys -from threading import Event, Thread +from threading import Event +from threading import Thread import paho.mqtt.client as mqtt import pytest from pytest_embedded import Dut +from pytest_embedded_idf.utils import idf_parametrize event_client_connected = Event() event_stop_client = Event() @@ -43,8 +45,8 @@ def on_message(client, userdata, msg): # type: (mqtt.Client, tuple, mqtt.client message_log += 'Received data:' + msg.topic + ' ' + payload + '\n' -@pytest.mark.esp32 @pytest.mark.ethernet +@idf_parametrize('target', ['esp32'], indirect=['target']) def test_examples_protocol_mqtt_ws(dut): # type: (Dut) -> None broker_url = '' broker_port = 0 @@ -77,7 +79,11 @@ def test_examples_protocol_mqtt_ws(dut): # type: (Dut) -> None print('Connecting...') client.connect(broker_url, broker_port, 60) except Exception: - print('ENV_TEST_FAILURE: Unexpected error while connecting to broker {}: {}:'.format(broker_url, sys.exc_info()[0])) + print( + 'ENV_TEST_FAILURE: Unexpected error while connecting to broker {}: {}:'.format( + broker_url, sys.exc_info()[0] + ) + ) raise # Starting a py-client in a separate thread thread1 = Thread(target=mqtt_client_task, args=(client,)) diff --git a/examples/protocols/mqtt/wss/pytest_mqtt_wss_example.py b/examples/protocols/mqtt/wss/pytest_mqtt_wss_example.py index 05eff5e..cb4a22e 100644 --- a/examples/protocols/mqtt/wss/pytest_mqtt_wss_example.py +++ b/examples/protocols/mqtt/wss/pytest_mqtt_wss_example.py @@ -1,18 +1,20 @@ #!/usr/bin/env python # -# SPDX-FileCopyrightText: 2022 Espressif Systems (Shanghai) CO LTD +# SPDX-FileCopyrightText: 2022-2025 Espressif Systems (Shanghai) CO LTD # SPDX-License-Identifier: Unlicense OR CC0-1.0 import logging import os import re import ssl import sys -from threading import Event, Thread +from threading import Event +from threading import Thread import paho.mqtt.client as mqtt import pexpect import pytest from pytest_embedded import Dut +from pytest_embedded_idf.utils import idf_parametrize event_client_connected = Event() event_stop_client = Event() @@ -45,8 +47,8 @@ def on_message(client, userdata, msg): # type: (mqtt.Client, tuple, mqtt.client message_log += 'Received data:' + msg.topic + ' ' + payload + '\n' -@pytest.mark.esp32 @pytest.mark.ethernet +@idf_parametrize('target', ['esp32'], indirect=['target']) def test_examples_protocol_mqtt_wss(dut): # type: (Dut) -> None broker_url = '' broker_port = 0 @@ -76,13 +78,15 @@ def test_examples_protocol_mqtt_wss(dut): # type: (Dut) -> None client = mqtt.Client(transport='websockets') client.on_connect = on_connect client.on_message = on_message - client.tls_set(None, - None, - None, cert_reqs=ssl.CERT_NONE, tls_version=ssl.PROTOCOL_TLSv1_2, ciphers=None) + client.tls_set(None, None, None, cert_reqs=ssl.CERT_NONE, tls_version=ssl.PROTOCOL_TLSv1_2, ciphers=None) print('Connecting...') client.connect(broker_url, broker_port, 60) except Exception: - print('ENV_TEST_FAILURE: Unexpected error while connecting to broker {}: {}:'.format(broker_url, sys.exc_info()[0])) + print( + 'ENV_TEST_FAILURE: Unexpected error while connecting to broker {}: {}:'.format( + broker_url, sys.exc_info()[0] + ) + ) raise # Starting a py-client in a separate thread thread1 = Thread(target=mqtt_client_task, args=(client,)) diff --git a/examples/protocols/mqtt5/pytest_mqtt5.py b/examples/protocols/mqtt5/pytest_mqtt5.py index 603be68..dc7063e 100644 --- a/examples/protocols/mqtt5/pytest_mqtt5.py +++ b/examples/protocols/mqtt5/pytest_mqtt5.py @@ -1,23 +1,23 @@ #!/usr/bin/env python # -# SPDX-FileCopyrightText: 2022 Espressif Systems (Shanghai) CO LTD +# SPDX-FileCopyrightText: 2022-2025 Espressif Systems (Shanghai) CO LTD # SPDX-License-Identifier: Apache-2.0 - import logging import os import pytest from pytest_embedded import Dut +from pytest_embedded_idf.utils import idf_parametrize -@pytest.mark.esp32 @pytest.mark.ethernet +@idf_parametrize('target', ['esp32'], indirect=['target']) def test_examples_protocol_mqtt5(dut: Dut) -> None: """ steps: | 1. join AP 2. connect to mqtt://mqtt.eclipseprojects.io - 3. check conneciton success + 3. check connection success """ # check and log bin size binary_file = os.path.join(dut.app.binary_path, 'mqtt5.bin') diff --git a/tools/test_apps/protocols/mqtt/publish_connect_test/pytest_mqtt_app.py b/tools/test_apps/protocols/mqtt/publish_connect_test/pytest_mqtt_app.py index 2377018..70da7eb 100644 --- a/tools/test_apps/protocols/mqtt/publish_connect_test/pytest_mqtt_app.py +++ b/tools/test_apps/protocols/mqtt/publish_connect_test/pytest_mqtt_app.py @@ -1,6 +1,5 @@ -# SPDX-FileCopyrightText: 2022-2023 Espressif Systems (Shanghai) CO LTD +# SPDX-FileCopyrightText: 2022-2025 Espressif Systems (Shanghai) CO LTD # SPDX-License-Identifier: Unlicense OR CC0-1.0 - import contextlib import logging import os @@ -9,31 +8,49 @@ import socketserver import ssl import subprocess from threading import Thread -from typing import Any, Callable, Dict, Optional +from typing import Any +from typing import Callable +from typing import Dict +from typing import Optional import pytest from common_test_methods import get_host_ip4_by_dest_ip from pytest_embedded import Dut +from pytest_embedded_idf.utils import idf_parametrize SERVER_PORT = 2222 def _path(f): # type: (str) -> str - return os.path.join(os.path.dirname(os.path.realpath(__file__)),f) + return os.path.join(os.path.dirname(os.path.realpath(__file__)), f) def set_server_cert_cn(ip): # type: (str) -> None arg_list = [ - ['openssl', 'req', '-out', _path('srv.csr'), '-key', _path('server.key'),'-subj', '/CN={}'.format(ip), '-new'], - ['openssl', 'x509', '-req', '-in', _path('srv.csr'), '-CA', _path('ca.crt'), - '-CAkey', _path('ca.key'), '-CAcreateserial', '-out', _path('srv.crt'), '-days', '360']] + ['openssl', 'req', '-out', _path('srv.csr'), '-key', _path('server.key'), '-subj', '/CN={}'.format(ip), '-new'], + [ + 'openssl', + 'x509', + '-req', + '-in', + _path('srv.csr'), + '-CA', + _path('ca.crt'), + '-CAkey', + _path('ca.key'), + '-CAcreateserial', + '-out', + _path('srv.crt'), + '-days', + '360', + ], + ] for args in arg_list: if subprocess.check_call(args) != 0: raise RuntimeError('openssl command {} failed'.format(args)) class MQTTHandler(socketserver.StreamRequestHandler): - def handle(self) -> None: logging.info(' - connection from: {}'.format(self.client_address)) data = bytearray(self.request.recv(1024)) @@ -56,12 +73,14 @@ class TlsServer(socketserver.TCPServer): allow_reuse_address = True allow_reuse_port = True - def __init__(self, - port:int = SERVER_PORT, - ServerHandler: Callable[[Any, Any, Any], socketserver.BaseRequestHandler] = MQTTHandler, - client_cert:bool=False, - refuse_connection:bool=False, - use_alpn:bool=False): + def __init__( + self, + port: int = SERVER_PORT, + ServerHandler: Callable[[Any, Any, Any], socketserver.BaseRequestHandler] = MQTTHandler, + client_cert: bool = False, + refuse_connection: bool = False, + use_alpn: bool = False, + ): self.refuse_connection = refuse_connection self.context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) self.ssl_error = '' @@ -73,7 +92,7 @@ class TlsServer(socketserver.TCPServer): if use_alpn: self.context.set_alpn_protocols(['mymqtt', 'http/1.1']) self.server_thread = Thread(target=self.serve_forever) - super().__init__(('',port), ServerHandler) + super().__init__(('', port), ServerHandler) def server_activate(self) -> None: self.socket = self.context.wrap_socket(self.socket, server_side=True) @@ -125,14 +144,16 @@ def get_test_cases(dut: Dut) -> Any: cases = {} try: # Get connection test cases configuration: symbolic names for test cases - for case in ['EXAMPLE_CONNECT_CASE_NO_CERT', - 'EXAMPLE_CONNECT_CASE_SERVER_CERT', - 'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH', - 'EXAMPLE_CONNECT_CASE_INVALID_SERVER_CERT', - 'EXAMPLE_CONNECT_CASE_SERVER_DER_CERT', - 'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_KEY_PWD', - 'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_BAD_CRT', - 'EXAMPLE_CONNECT_CASE_NO_CERT_ALPN']: + for case in [ + 'EXAMPLE_CONNECT_CASE_NO_CERT', + 'EXAMPLE_CONNECT_CASE_SERVER_CERT', + 'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH', + 'EXAMPLE_CONNECT_CASE_INVALID_SERVER_CERT', + 'EXAMPLE_CONNECT_CASE_SERVER_DER_CERT', + 'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_KEY_PWD', + 'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_BAD_CRT', + 'EXAMPLE_CONNECT_CASE_NO_CERT_ALPN', + ]: cases[case] = dut.app.sdkconfig.get(case) except Exception: logging.error('ENV_TEST_FAILURE: Some mandatory CONNECTION test case not found in sdkconfig') @@ -147,7 +168,7 @@ def get_dut_ip(dut: Dut) -> Any: @contextlib.contextmanager -def connect_dut(dut: Dut, uri:str, case_id:int) -> Any: +def connect_dut(dut: Dut, uri: str, case_id: int) -> Any: dut.write('connection_setup') dut.write(f'connect {uri} {case_id}') dut.expect(f'Test case:{case_id} started') @@ -157,12 +178,16 @@ def connect_dut(dut: Dut, uri:str, case_id:int) -> Any: dut.write('disconnect') -def run_cases(dut:Dut, uri:str, cases:Dict[str, int]) -> None: +def run_cases(dut: Dut, uri: str, cases: Dict[str, int]) -> None: try: dut.write('init') dut.write(f'start') dut.write(f'disconnect') - for case in ['EXAMPLE_CONNECT_CASE_NO_CERT', 'EXAMPLE_CONNECT_CASE_SERVER_CERT', 'EXAMPLE_CONNECT_CASE_SERVER_DER_CERT']: + for case in [ + 'EXAMPLE_CONNECT_CASE_NO_CERT', + 'EXAMPLE_CONNECT_CASE_SERVER_CERT', + 'EXAMPLE_CONNECT_CASE_SERVER_DER_CERT', + ]: # All these cases connect to the server with no server verification or with server only verification with TlsServer(), connect_dut(dut, uri, cases[case]): logging.info(f'Running {case}: default server - expect to connect normally') @@ -172,9 +197,13 @@ def run_cases(dut:Dut, uri:str, cases:Dict[str, int]) -> None: dut.expect(f'MQTT_EVENT_ERROR: Test={cases[case]}', timeout=30) dut.expect('MQTT ERROR: 0x5') # expecting 0x5 ... connection not authorized error with TlsServer(client_cert=True) as server, connect_dut(dut, uri, cases[case]): - logging.info(f'Running {case}: server with client verification - handshake error since client presents no client certificate') + logging.info( + f'Running {case}: server with client verification - handshake error since client presents no client certificate' + ) dut.expect(f'MQTT_EVENT_ERROR: Test={cases[case]}', timeout=30) - dut.expect('ESP-TLS ERROR: ESP_ERR_MBEDTLS_SSL_HANDSHAKE_FAILED') # expect ... handshake error (PEER_DID_NOT_RETURN_A_CERTIFICATE) + dut.expect( + 'ESP-TLS ERROR: ESP_ERR_MBEDTLS_SSL_HANDSHAKE_FAILED' + ) # expect ... handshake error (PEER_DID_NOT_RETURN_A_CERTIFICATE) assert 'PEER_DID_NOT_RETURN_A_CERTIFICATE' in server.last_ssl_error() for case in ['EXAMPLE_CONNECT_CASE_MUTUAL_AUTH', 'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_KEY_PWD']: @@ -187,15 +216,21 @@ def run_cases(dut:Dut, uri:str, cases:Dict[str, int]) -> None: with TlsServer() as s, connect_dut(dut, uri, cases[case]): logging.info(f'Running {case}: invalid server certificate on default server - expect ssl handshake error') dut.expect(f'MQTT_EVENT_ERROR: Test={cases[case]}', timeout=30) - dut.expect('ESP-TLS ERROR: ESP_ERR_MBEDTLS_SSL_HANDSHAKE_FAILED') # expect ... handshake error (TLSV1_ALERT_UNKNOWN_CA) - if re.match('.*alert.*unknown.*ca',s.last_ssl_error(), flags=re.I) is None: + dut.expect( + 'ESP-TLS ERROR: ESP_ERR_MBEDTLS_SSL_HANDSHAKE_FAILED' + ) # expect ... handshake error (TLSV1_ALERT_UNKNOWN_CA) + if re.match('.*alert.*unknown.*ca', s.last_ssl_error(), flags=re.I) is None: raise Exception(f'Unexpected ssl error from the server: {s.last_ssl_error()}') case = 'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_BAD_CRT' with TlsServer(client_cert=True) as s, connect_dut(dut, uri, cases[case]): - logging.info(f'Running {case}: Invalid client certificate on server with client verification - expect ssl handshake error') + logging.info( + f'Running {case}: Invalid client certificate on server with client verification - expect ssl handshake error' + ) dut.expect(f'MQTT_EVENT_ERROR: Test={cases[case]}', timeout=30) - dut.expect('ESP-TLS ERROR: ESP_ERR_MBEDTLS_SSL_HANDSHAKE_FAILED') # expect ... handshake error (CERTIFICATE_VERIFY_FAILED) + dut.expect( + 'ESP-TLS ERROR: ESP_ERR_MBEDTLS_SSL_HANDSHAKE_FAILED' + ) # expect ... handshake error (CERTIFICATE_VERIFY_FAILED) if 'CERTIFICATE_VERIFY_FAILED' not in s.last_ssl_error(): raise Exception('Unexpected ssl error from the server {}'.format(s.last_ssl_error())) @@ -214,8 +249,8 @@ def run_cases(dut:Dut, uri:str, cases:Dict[str, int]) -> None: dut.write('destroy') -@pytest.mark.esp32 @pytest.mark.ethernet +@idf_parametrize('target', ['esp32'], indirect=['target']) def test_mqtt_connect( dut: Dut, log_performance: Callable[[str, object], None], diff --git a/tools/test_apps/protocols/mqtt/publish_connect_test/pytest_mqtt_publish_app.py b/tools/test_apps/protocols/mqtt/publish_connect_test/pytest_mqtt_publish_app.py index b2df800..4eb25d1 100644 --- a/tools/test_apps/protocols/mqtt/publish_connect_test/pytest_mqtt_publish_app.py +++ b/tools/test_apps/protocols/mqtt/publish_connect_test/pytest_mqtt_publish_app.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: 2023-2024 Espressif Systems (Shanghai) CO LTD +# SPDX-FileCopyrightText: 2023-2025 Espressif Systems (Shanghai) CO LTD # SPDX-License-Identifier: Unlicense OR CC0-1.0 import contextlib import difflib @@ -22,13 +22,13 @@ import paho.mqtt.client as mqtt import pexpect import pytest from pytest_embedded import Dut +from pytest_embedded_idf.utils import idf_parametrize DEFAULT_MSG_SIZE = 16 # Publisher class creating a python client to send/receive published data from esp-mqtt client class MqttPublisher(mqtt.Client): - def __init__(self, config, log_details=False): # type: (MqttPublisher, dict, bool) -> None self.log_details = log_details self.config = config @@ -40,7 +40,9 @@ class MqttPublisher(mqtt.Client): self.event_client_subscribed = Event() self.event_client_got_all = Event() transport = 'websockets' if self.config['transport'] in ['ws', 'wss'] else 'tcp' - client_id = 'MqttTestRunner' + ''.join(random.choice(string.ascii_uppercase + string.ascii_lowercase) for _ in range(5)) + client_id = 'MqttTestRunner' + ''.join( + random.choice(string.ascii_uppercase + string.ascii_lowercase) for _ in range(5) + ) super().__init__(client_id, userdata=0, transport=transport) def print_details(self, text): # type: (str) -> None @@ -53,7 +55,7 @@ class MqttPublisher(mqtt.Client): logging.info(f'Subscribed to {self.config["subscribe_topic"]} successfully with QoS: {granted_qos}') self.event_client_subscribed.set() - def on_connect(self, mqttc: Any, obj: Any, flags: Any, rc:int) -> None: + def on_connect(self, mqttc: Any, obj: Any, flags: Any, rc: int) -> None: self.event_client_connected.set() def on_connect_fail(self, mqttc: Any, obj: Any) -> None: @@ -67,8 +69,10 @@ class MqttPublisher(mqtt.Client): self.event_client_got_all.set() else: differences = len(list(filter(lambda data: data[0] != data[1], zip(payload, self.expected_data)))) - logging.error(f'Payload differ in {differences} positions from expected data. received size: {len(payload)} expected size:' - f'{len(self.expected_data)}') + logging.error( + f'Payload differ in {differences} positions from expected data. received size: {len(payload)} expected size:' + f'{len(self.expected_data)}' + ) logging.info(f'Repetitions: {payload.count(self.config["pattern"])}') logging.info(f'Pattern: {self.config["pattern"]}') logging.info(f'First: {payload[:DEFAULT_MSG_SIZE]}') @@ -107,9 +111,10 @@ class MqttPublisher(mqtt.Client): self.loop_stop() -def get_configurations(dut: Dut, test_case: Any) -> Dict[str,Any]: +def get_configurations(dut: Dut, test_case: Any) -> Dict[str, Any]: publish_cfg = {} try: + @no_type_check def get_config_from_dut(dut, config_option): # logging.info('Option:', config_option, dut.app.sdkconfig.get(config_option)) @@ -117,11 +122,18 @@ def get_configurations(dut: Dut, test_case: Any) -> Dict[str,Any]: if value is None: return None, None return value.group(1), int(value.group(2)) + # Get publish test configuration - publish_cfg['broker_host_ssl'], publish_cfg['broker_port_ssl'] = get_config_from_dut(dut, 'EXAMPLE_BROKER_SSL_URI') - publish_cfg['broker_host_tcp'], publish_cfg['broker_port_tcp'] = get_config_from_dut(dut, 'EXAMPLE_BROKER_TCP_URI') + publish_cfg['broker_host_ssl'], publish_cfg['broker_port_ssl'] = get_config_from_dut( + dut, 'EXAMPLE_BROKER_SSL_URI' + ) + publish_cfg['broker_host_tcp'], publish_cfg['broker_port_tcp'] = get_config_from_dut( + dut, 'EXAMPLE_BROKER_TCP_URI' + ) publish_cfg['broker_host_ws'], publish_cfg['broker_port_ws'] = get_config_from_dut(dut, 'EXAMPLE_BROKER_WS_URI') - publish_cfg['broker_host_wss'], publish_cfg['broker_port_wss'] = get_config_from_dut(dut, 'EXAMPLE_BROKER_WSS_URI') + publish_cfg['broker_host_wss'], publish_cfg['broker_port_wss'] = get_config_from_dut( + dut, 'EXAMPLE_BROKER_WSS_URI' + ) except Exception: logging.info('ENV_TEST_FAILURE: Some mandatory PUBLISH test case not found in sdkconfig') @@ -133,9 +145,13 @@ def get_configurations(dut: Dut, test_case: Any) -> Dict[str,Any]: publish_cfg['qos'] = qos publish_cfg['enqueue'] = enqueue publish_cfg['transport'] = transport - publish_cfg['pattern'] = ''.join(random.choice(string.ascii_uppercase + string.ascii_lowercase + string.digits) for _ in range(DEFAULT_MSG_SIZE)) + publish_cfg['pattern'] = ''.join( + random.choice(string.ascii_uppercase + string.ascii_lowercase + string.digits) for _ in range(DEFAULT_MSG_SIZE) + ) publish_cfg['test_timeout'] = get_timeout(test_case) - unique_topic = ''.join(random.choice(string.ascii_uppercase + string.ascii_lowercase) for _ in range(DEFAULT_MSG_SIZE)) + unique_topic = ''.join( + random.choice(string.ascii_uppercase + string.ascii_lowercase) for _ in range(DEFAULT_MSG_SIZE) + ) publish_cfg['subscribe_topic'] = 'test/subscribe_to/' + unique_topic publish_cfg['publish_topic'] = 'test/subscribe_to/' + unique_topic logging.info(f'configuration: {publish_cfg}') @@ -143,7 +159,7 @@ def get_configurations(dut: Dut, test_case: Any) -> Dict[str,Any]: @contextlib.contextmanager -def connected_and_subscribed(dut:Dut) -> Any: +def connected_and_subscribed(dut: Dut) -> Any: dut.write('start') dut.expect(re.compile(rb'MQTT_EVENT_SUBSCRIBED'), timeout=60) yield @@ -155,16 +171,17 @@ def get_scenarios() -> List[Dict[str, int]]: # Initialize message sizes and repeat counts (if defined in the environment) for i in count(0): # Check env variable: MQTT_PUBLISH_MSG_{len|repeat}_{x} - env_dict = {var:'MQTT_PUBLISH_MSG_' + var + '_' + str(i) for var in ['len', 'repeat']} + env_dict = {var: 'MQTT_PUBLISH_MSG_' + var + '_' + str(i) for var in ['len', 'repeat']} if os.getenv(env_dict['len']) and os.getenv(env_dict['repeat']): scenarios.append({var: int(os.getenv(env_dict[var])) for var in ['len', 'repeat']}) # type: ignore continue break - if not scenarios: # No message sizes present in the env - set defaults - scenarios = [{'msg_len':0, 'nr_of_msgs':5}, # zero-sized messages - {'msg_len':2, 'nr_of_msgs':5}, # short messages - {'msg_len':200, 'nr_of_msgs':3}, # long messages - ] + if not scenarios: # No message sizes present in the env - set defaults + scenarios = [ + {'msg_len': 0, 'nr_of_msgs': 5}, # zero-sized messages + {'msg_len': 2, 'nr_of_msgs': 5}, # short messages + {'msg_len': 200, 'nr_of_msgs': 3}, # long messages + ] return scenarios @@ -181,17 +198,23 @@ def get_timeout(test_case: Any) -> int: def run_publish_test_case(dut: Dut, config: Any) -> None: - logging.info(f'Starting Publish test: transport:{config["transport"]}, qos:{config["qos"]},' - f'nr_of_msgs:{config["scenario"]["nr_of_msgs"]},' - f' msg_size:{config["scenario"]["msg_len"] * DEFAULT_MSG_SIZE}, enqueue:{config["enqueue"]}') - dut.write(f'publish_setup {config["transport"]} {config["publish_topic"]} {config["subscribe_topic"]} {config["pattern"]} {config["scenario"]["msg_len"]}') + logging.info( + f'Starting Publish test: transport:{config["transport"]}, qos:{config["qos"]},' + f'nr_of_msgs:{config["scenario"]["nr_of_msgs"]},' + f' msg_size:{config["scenario"]["msg_len"] * DEFAULT_MSG_SIZE}, enqueue:{config["enqueue"]}' + ) + dut.write( + f'publish_setup {config["transport"]} {config["publish_topic"]} {config["subscribe_topic"]} {config["pattern"]} {config["scenario"]["msg_len"]}' + ) with MqttPublisher(config) as publisher, connected_and_subscribed(dut): assert publisher.event_client_subscribed.wait(timeout=config['test_timeout']), 'Runner failed to subscribe' msgs_published: List[mqtt.MQTTMessageInfo] = [] dut.write(f'publish {config["scenario"]["nr_of_msgs"]} {config["qos"]} {config["enqueue"]}') - assert publisher.event_client_got_all.wait(timeout=config['test_timeout']), (f'Not all data received from ESP32: {config["transport"]} ' - f'qos={config["qos"]} received: {publisher.received} ' - f'expected: {config["scenario"]["nr_of_msgs"]}') + assert publisher.event_client_got_all.wait(timeout=config['test_timeout']), ( + f'Not all data received from ESP32: {config["transport"]} ' + f'qos={config["qos"]} received: {publisher.received} ' + f'expected: {config["scenario"]["nr_of_msgs"]}' + ) logging.info(' - all data received from ESP32') payload = config['pattern'] * config['scenario']['msg_len'] for _ in range(config['scenario']['nr_of_msgs']): @@ -214,15 +237,17 @@ def run_publish_test_case(dut: Dut, config: Any) -> None: logging.info('ESP32 received all data from runner') -stress_scenarios = [{'msg_len':20, 'nr_of_msgs':30}] # many medium sized +stress_scenarios = [{'msg_len': 20, 'nr_of_msgs': 30}] # many medium sized transport_cases = ['tcp', 'ws', 'wss', 'ssl'] qos_cases = [0, 1, 2] enqueue_cases = [0, 1] local_broker_supported_transports = ['tcp'] -local_broker_scenarios = [{'msg_len':0, 'nr_of_msgs':5}, # zero-sized messages - {'msg_len':5, 'nr_of_msgs':20}, # short messages - {'msg_len':500, 'nr_of_msgs':10}, # long messages - {'msg_len':20, 'nr_of_msgs':20}] # many medium sized +local_broker_scenarios = [ + {'msg_len': 0, 'nr_of_msgs': 5}, # zero-sized messages + {'msg_len': 5, 'nr_of_msgs': 20}, # short messages + {'msg_len': 500, 'nr_of_msgs': 10}, # long messages + {'msg_len': 20, 'nr_of_msgs': 20}, +] # many medium sized def make_cases(transport: Any, scenarios: List[Dict[str, int]]) -> List[Tuple[str, int, int, Dict[str, int]]]: @@ -233,10 +258,10 @@ test_cases = make_cases(transport_cases, get_scenarios()) stress_test_cases = make_cases(transport_cases, stress_scenarios) -@pytest.mark.esp32 @pytest.mark.ethernet @pytest.mark.parametrize('test_case', test_cases) @pytest.mark.parametrize('config', ['default'], indirect=True) +@idf_parametrize('target', ['esp32'], indirect=['target']) def test_mqtt_publish(dut: Dut, test_case: Any) -> None: publish_cfg = get_configurations(dut, test_case) dut.expect(re.compile(rb'mqtt>'), timeout=30) @@ -244,11 +269,11 @@ def test_mqtt_publish(dut: Dut, test_case: Any) -> None: run_publish_test_case(dut, publish_cfg) -@pytest.mark.esp32 @pytest.mark.ethernet_stress @pytest.mark.nightly_run @pytest.mark.parametrize('test_case', stress_test_cases) @pytest.mark.parametrize('config', ['default'], indirect=True) +@idf_parametrize('target', ['esp32'], indirect=['target']) def test_mqtt_publish_stress(dut: Dut, test_case: Any) -> None: publish_cfg = get_configurations(dut, test_case) dut.expect(re.compile(rb'mqtt>'), timeout=30) @@ -256,10 +281,10 @@ def test_mqtt_publish_stress(dut: Dut, test_case: Any) -> None: run_publish_test_case(dut, publish_cfg) -@pytest.mark.esp32 @pytest.mark.ethernet @pytest.mark.parametrize('test_case', make_cases(local_broker_supported_transports, local_broker_scenarios)) @pytest.mark.parametrize('config', ['local_broker'], indirect=True) +@idf_parametrize('target', ['esp32'], indirect=['target']) def test_mqtt_publish_local(dut: Dut, test_case: Any) -> None: if test_case[0] not in local_broker_supported_transports: pytest.skip(f'Skipping transport: {test_case[0]}...')