test: format all test scripts

This commit is contained in:
igor.udot
2025-02-24 10:18:03 +08:00
committed by Rocha Euripedes
parent b8d7b2bded
commit d30dfd5074
9 changed files with 184 additions and 102 deletions

View File

@@ -1,10 +1,11 @@
# SPDX-FileCopyrightText: 2023 Espressif Systems (Shanghai) CO LTD # SPDX-FileCopyrightText: 2023-2025 Espressif Systems (Shanghai) CO LTD
# SPDX-License-Identifier: Unlicense OR CC0-1.0 # SPDX-License-Identifier: Unlicense OR CC0-1.0
import pytest import pytest
from pytest_embedded import Dut from pytest_embedded import Dut
from pytest_embedded_idf.utils import idf_parametrize
@pytest.mark.esp32
@pytest.mark.ethernet @pytest.mark.ethernet
@idf_parametrize('target', ['esp32'], indirect=['target'])
def test_mqtt_client(dut: Dut) -> None: def test_mqtt_client(dut: Dut) -> None:
dut.expect_unity_test_output() dut.expect_unity_test_output()

View File

@@ -1,10 +1,11 @@
# SPDX-FileCopyrightText: 2023 Espressif Systems (Shanghai) CO LTD # SPDX-FileCopyrightText: 2023-2025 Espressif Systems (Shanghai) CO LTD
# SPDX-License-Identifier: Unlicense OR CC0-1.0 # SPDX-License-Identifier: Unlicense OR CC0-1.0
import pytest import pytest
from pytest_embedded import Dut from pytest_embedded import Dut
from pytest_embedded_idf.utils import idf_parametrize
@pytest.mark.esp32
@pytest.mark.ethernet @pytest.mark.ethernet
@idf_parametrize('target', ['esp32'], indirect=['target'])
def test_mqtt5_client(dut: Dut) -> None: def test_mqtt5_client(dut: Dut) -> None:
dut.expect_unity_test_output() dut.expect_unity_test_output()

View File

@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: 2022-2024 Espressif Systems (Shanghai) CO LTD # SPDX-FileCopyrightText: 2022-2025 Espressif Systems (Shanghai) CO LTD
# SPDX-License-Identifier: Unlicense OR CC0-1.0 # SPDX-License-Identifier: Unlicense OR CC0-1.0
import logging import logging
import os import os
@@ -12,6 +12,7 @@ import paho.mqtt.client as mqtt
import pexpect import pexpect
import pytest import pytest
from pytest_embedded import Dut from pytest_embedded import Dut
from pytest_embedded_idf.utils import idf_parametrize
event_client_connected = Event() event_client_connected = Event()
event_stop_client = Event() event_stop_client = Event()
@@ -50,7 +51,9 @@ def on_message(client, userdata, msg): # type: (mqtt.Client, tuple, mqtt.client
recv_binary = binary + '.received' recv_binary = binary + '.received'
with open(recv_binary, 'w', encoding='utf-8') as fw: with open(recv_binary, 'w', encoding='utf-8') as fw:
fw.write(msg.payload) fw.write(msg.payload)
raise ValueError('Received binary (saved as: {}) does not match the original file: {}'.format(recv_binary, binary)) raise ValueError(
'Received binary (saved as: {}) does not match the original file: {}'.format(recv_binary, binary)
)
payload = msg.payload.decode() payload = msg.payload.decode()
if not event_client_received_correct.is_set() and payload == 'data': if not event_client_received_correct.is_set() and payload == 'data':
@@ -61,8 +64,8 @@ def on_message(client, userdata, msg): # type: (mqtt.Client, tuple, mqtt.client
message_log += 'Received data:' + msg.topic + ' ' + payload + '\n' message_log += 'Received data:' + msg.topic + ' ' + payload + '\n'
@pytest.mark.esp32
@pytest.mark.ethernet @pytest.mark.ethernet
@idf_parametrize('target', ['esp32'], indirect=['target'])
def test_examples_protocol_mqtt_ssl(dut): # type: (Dut) -> None def test_examples_protocol_mqtt_ssl(dut): # type: (Dut) -> None
broker_url = '' broker_url = ''
broker_port = 0 broker_port = 0
@@ -95,14 +98,16 @@ def test_examples_protocol_mqtt_ssl(dut): # type: (Dut) -> None
client.on_connect = on_connect client.on_connect = on_connect
client.on_message = on_message client.on_message = on_message
client.user_data_set((binary_file, bin_size)) client.user_data_set((binary_file, bin_size))
client.tls_set(None, client.tls_set(None, None, None, cert_reqs=ssl.CERT_NONE, tls_version=ssl.PROTOCOL_TLSv1_2, ciphers=None)
None,
None, cert_reqs=ssl.CERT_NONE, tls_version=ssl.PROTOCOL_TLSv1_2, ciphers=None)
client.tls_insecure_set(True) client.tls_insecure_set(True)
print('Connecting...') print('Connecting...')
client.connect(broker_url, broker_port, 60) client.connect(broker_url, broker_port, 60)
except Exception: except Exception:
print('ENV_TEST_FAILURE: Unexpected error while connecting to broker {}: {}:'.format(broker_url, sys.exc_info()[0])) print(
'ENV_TEST_FAILURE: Unexpected error while connecting to broker {}: {}:'.format(
broker_url, sys.exc_info()[0]
)
)
raise raise
# Starting a py-client in a separate thread # Starting a py-client in a separate thread
thread1 = Thread(target=mqtt_client_task, args=(client,)) thread1 = Thread(target=mqtt_client_task, args=(client,))

View File

@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: 2022 Espressif Systems (Shanghai) CO LTD # SPDX-FileCopyrightText: 2022-2025 Espressif Systems (Shanghai) CO LTD
# SPDX-License-Identifier: Unlicense OR CC0-1.0 # SPDX-License-Identifier: Unlicense OR CC0-1.0
import logging import logging
import os import os
@@ -12,6 +12,7 @@ import pexpect
import pytest import pytest
from common_test_methods import get_host_ip4_by_dest_ip from common_test_methods import get_host_ip4_by_dest_ip
from pytest_embedded import Dut from pytest_embedded import Dut
from pytest_embedded_idf.utils import idf_parametrize
msgid = -1 msgid = -1
@@ -25,13 +26,15 @@ def mqqt_server_sketch(my_ip, port): # type: (str, str) -> None
s.settimeout(60) s.settimeout(60)
s.bind((my_ip, port)) s.bind((my_ip, port))
s.listen(1) s.listen(1)
q,addr = s.accept() q, addr = s.accept()
q.settimeout(30) q.settimeout(30)
print('connection accepted') print('connection accepted')
except Exception: except Exception:
print('Local server on {}:{} listening/accepting failure: {}' print(
'Local server on {}:{} listening/accepting failure: {}'
'Possibly check permissions or firewall settings' 'Possibly check permissions or firewall settings'
'to accept connections on this address'.format(my_ip, port, sys.exc_info()[0])) 'to accept connections on this address'.format(my_ip, port, sys.exc_info()[0])
)
raise raise
data = q.recv(1024) data = q.recv(1024)
# check if received initial empty message # check if received initial empty message
@@ -49,8 +52,8 @@ def mqqt_server_sketch(my_ip, port): # type: (str, str) -> None
print('server closed') print('server closed')
@pytest.mark.esp32
@pytest.mark.ethernet @pytest.mark.ethernet
@idf_parametrize('target', ['esp32'], indirect=['target'])
def test_examples_protocol_mqtt_qos1(dut: Dut) -> None: def test_examples_protocol_mqtt_qos1(dut: Dut) -> None:
global msgid global msgid
""" """
@@ -73,7 +76,7 @@ def test_examples_protocol_mqtt_qos1(dut: Dut) -> None:
# 2. start mqtt broker sketch # 2. start mqtt broker sketch
host_ip = get_host_ip4_by_dest_ip(ip_address) host_ip = get_host_ip4_by_dest_ip(ip_address)
thread1 = Thread(target=mqqt_server_sketch, args=(host_ip,1883)) thread1 = Thread(target=mqqt_server_sketch, args=(host_ip, 1883))
thread1.start() thread1.start()
data_write = 'mqtt://' + host_ip data_write = 'mqtt://' + host_ip
@@ -85,8 +88,10 @@ def test_examples_protocol_mqtt_qos1(dut: Dut) -> None:
msgid_enqueued = dut.expect(b'outbox: ENQUEUE msgid=([0-9]+)', timeout=30).group(1).decode() msgid_enqueued = dut.expect(b'outbox: ENQUEUE msgid=([0-9]+)', timeout=30).group(1).decode()
msgid_deleted = dut.expect(b'outbox: DELETED msgid=([0-9]+)', timeout=30).group(1).decode() msgid_deleted = dut.expect(b'outbox: DELETED msgid=([0-9]+)', timeout=30).group(1).decode()
# 4. check the msgid of received data are the same as that of enqueued and deleted from outbox # 4. check the msgid of received data are the same as that of enqueued and deleted from outbox
if (msgid_enqueued == str(msgid) and msgid_deleted == str(msgid)): if msgid_enqueued == str(msgid) and msgid_deleted == str(msgid):
print('PASS: Received correct msg id') print('PASS: Received correct msg id')
else: else:
print('Failure!') print('Failure!')
raise ValueError('Mismatch of msgid: received: {}, enqueued {}, deleted {}'.format(msgid, msgid_enqueued, msgid_deleted)) raise ValueError(
'Mismatch of msgid: received: {}, enqueued {}, deleted {}'.format(msgid, msgid_enqueued, msgid_deleted)
)

View File

@@ -1,16 +1,18 @@
#!/usr/bin/env python #!/usr/bin/env python
# #
# SPDX-FileCopyrightText: 2022 Espressif Systems (Shanghai) CO LTD # SPDX-FileCopyrightText: 2022-2025 Espressif Systems (Shanghai) CO LTD
# SPDX-License-Identifier: Unlicense OR CC0-1.0 # SPDX-License-Identifier: Unlicense OR CC0-1.0
import logging import logging
import os import os
import re import re
import sys import sys
from threading import Event, Thread from threading import Event
from threading import Thread
import paho.mqtt.client as mqtt import paho.mqtt.client as mqtt
import pytest import pytest
from pytest_embedded import Dut from pytest_embedded import Dut
from pytest_embedded_idf.utils import idf_parametrize
event_client_connected = Event() event_client_connected = Event()
event_stop_client = Event() event_stop_client = Event()
@@ -43,8 +45,8 @@ def on_message(client, userdata, msg): # type: (mqtt.Client, tuple, mqtt.client
message_log += 'Received data:' + msg.topic + ' ' + payload + '\n' message_log += 'Received data:' + msg.topic + ' ' + payload + '\n'
@pytest.mark.esp32
@pytest.mark.ethernet @pytest.mark.ethernet
@idf_parametrize('target', ['esp32'], indirect=['target'])
def test_examples_protocol_mqtt_ws(dut): # type: (Dut) -> None def test_examples_protocol_mqtt_ws(dut): # type: (Dut) -> None
broker_url = '' broker_url = ''
broker_port = 0 broker_port = 0
@@ -77,7 +79,11 @@ def test_examples_protocol_mqtt_ws(dut): # type: (Dut) -> None
print('Connecting...') print('Connecting...')
client.connect(broker_url, broker_port, 60) client.connect(broker_url, broker_port, 60)
except Exception: except Exception:
print('ENV_TEST_FAILURE: Unexpected error while connecting to broker {}: {}:'.format(broker_url, sys.exc_info()[0])) print(
'ENV_TEST_FAILURE: Unexpected error while connecting to broker {}: {}:'.format(
broker_url, sys.exc_info()[0]
)
)
raise raise
# Starting a py-client in a separate thread # Starting a py-client in a separate thread
thread1 = Thread(target=mqtt_client_task, args=(client,)) thread1 = Thread(target=mqtt_client_task, args=(client,))

View File

@@ -1,18 +1,20 @@
#!/usr/bin/env python #!/usr/bin/env python
# #
# SPDX-FileCopyrightText: 2022 Espressif Systems (Shanghai) CO LTD # SPDX-FileCopyrightText: 2022-2025 Espressif Systems (Shanghai) CO LTD
# SPDX-License-Identifier: Unlicense OR CC0-1.0 # SPDX-License-Identifier: Unlicense OR CC0-1.0
import logging import logging
import os import os
import re import re
import ssl import ssl
import sys import sys
from threading import Event, Thread from threading import Event
from threading import Thread
import paho.mqtt.client as mqtt import paho.mqtt.client as mqtt
import pexpect import pexpect
import pytest import pytest
from pytest_embedded import Dut from pytest_embedded import Dut
from pytest_embedded_idf.utils import idf_parametrize
event_client_connected = Event() event_client_connected = Event()
event_stop_client = Event() event_stop_client = Event()
@@ -45,8 +47,8 @@ def on_message(client, userdata, msg): # type: (mqtt.Client, tuple, mqtt.client
message_log += 'Received data:' + msg.topic + ' ' + payload + '\n' message_log += 'Received data:' + msg.topic + ' ' + payload + '\n'
@pytest.mark.esp32
@pytest.mark.ethernet @pytest.mark.ethernet
@idf_parametrize('target', ['esp32'], indirect=['target'])
def test_examples_protocol_mqtt_wss(dut): # type: (Dut) -> None def test_examples_protocol_mqtt_wss(dut): # type: (Dut) -> None
broker_url = '' broker_url = ''
broker_port = 0 broker_port = 0
@@ -76,13 +78,15 @@ def test_examples_protocol_mqtt_wss(dut): # type: (Dut) -> None
client = mqtt.Client(transport='websockets') client = mqtt.Client(transport='websockets')
client.on_connect = on_connect client.on_connect = on_connect
client.on_message = on_message client.on_message = on_message
client.tls_set(None, client.tls_set(None, None, None, cert_reqs=ssl.CERT_NONE, tls_version=ssl.PROTOCOL_TLSv1_2, ciphers=None)
None,
None, cert_reqs=ssl.CERT_NONE, tls_version=ssl.PROTOCOL_TLSv1_2, ciphers=None)
print('Connecting...') print('Connecting...')
client.connect(broker_url, broker_port, 60) client.connect(broker_url, broker_port, 60)
except Exception: except Exception:
print('ENV_TEST_FAILURE: Unexpected error while connecting to broker {}: {}:'.format(broker_url, sys.exc_info()[0])) print(
'ENV_TEST_FAILURE: Unexpected error while connecting to broker {}: {}:'.format(
broker_url, sys.exc_info()[0]
)
)
raise raise
# Starting a py-client in a separate thread # Starting a py-client in a separate thread
thread1 = Thread(target=mqtt_client_task, args=(client,)) thread1 = Thread(target=mqtt_client_task, args=(client,))

View File

@@ -1,23 +1,23 @@
#!/usr/bin/env python #!/usr/bin/env python
# #
# SPDX-FileCopyrightText: 2022 Espressif Systems (Shanghai) CO LTD # SPDX-FileCopyrightText: 2022-2025 Espressif Systems (Shanghai) CO LTD
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import logging import logging
import os import os
import pytest import pytest
from pytest_embedded import Dut from pytest_embedded import Dut
from pytest_embedded_idf.utils import idf_parametrize
@pytest.mark.esp32
@pytest.mark.ethernet @pytest.mark.ethernet
@idf_parametrize('target', ['esp32'], indirect=['target'])
def test_examples_protocol_mqtt5(dut: Dut) -> None: def test_examples_protocol_mqtt5(dut: Dut) -> None:
""" """
steps: | steps: |
1. join AP 1. join AP
2. connect to mqtt://mqtt.eclipseprojects.io 2. connect to mqtt://mqtt.eclipseprojects.io
3. check conneciton success 3. check connection success
""" """
# check and log bin size # check and log bin size
binary_file = os.path.join(dut.app.binary_path, 'mqtt5.bin') binary_file = os.path.join(dut.app.binary_path, 'mqtt5.bin')

View File

@@ -1,6 +1,5 @@
# SPDX-FileCopyrightText: 2022-2023 Espressif Systems (Shanghai) CO LTD # SPDX-FileCopyrightText: 2022-2025 Espressif Systems (Shanghai) CO LTD
# SPDX-License-Identifier: Unlicense OR CC0-1.0 # SPDX-License-Identifier: Unlicense OR CC0-1.0
import contextlib import contextlib
import logging import logging
import os import os
@@ -9,31 +8,49 @@ import socketserver
import ssl import ssl
import subprocess import subprocess
from threading import Thread from threading import Thread
from typing import Any, Callable, Dict, Optional from typing import Any
from typing import Callable
from typing import Dict
from typing import Optional
import pytest import pytest
from common_test_methods import get_host_ip4_by_dest_ip from common_test_methods import get_host_ip4_by_dest_ip
from pytest_embedded import Dut from pytest_embedded import Dut
from pytest_embedded_idf.utils import idf_parametrize
SERVER_PORT = 2222 SERVER_PORT = 2222
def _path(f): # type: (str) -> str def _path(f): # type: (str) -> str
return os.path.join(os.path.dirname(os.path.realpath(__file__)),f) return os.path.join(os.path.dirname(os.path.realpath(__file__)), f)
def set_server_cert_cn(ip): # type: (str) -> None def set_server_cert_cn(ip): # type: (str) -> None
arg_list = [ arg_list = [
['openssl', 'req', '-out', _path('srv.csr'), '-key', _path('server.key'),'-subj', '/CN={}'.format(ip), '-new'], ['openssl', 'req', '-out', _path('srv.csr'), '-key', _path('server.key'), '-subj', '/CN={}'.format(ip), '-new'],
['openssl', 'x509', '-req', '-in', _path('srv.csr'), '-CA', _path('ca.crt'), [
'-CAkey', _path('ca.key'), '-CAcreateserial', '-out', _path('srv.crt'), '-days', '360']] 'openssl',
'x509',
'-req',
'-in',
_path('srv.csr'),
'-CA',
_path('ca.crt'),
'-CAkey',
_path('ca.key'),
'-CAcreateserial',
'-out',
_path('srv.crt'),
'-days',
'360',
],
]
for args in arg_list: for args in arg_list:
if subprocess.check_call(args) != 0: if subprocess.check_call(args) != 0:
raise RuntimeError('openssl command {} failed'.format(args)) raise RuntimeError('openssl command {} failed'.format(args))
class MQTTHandler(socketserver.StreamRequestHandler): class MQTTHandler(socketserver.StreamRequestHandler):
def handle(self) -> None: def handle(self) -> None:
logging.info(' - connection from: {}'.format(self.client_address)) logging.info(' - connection from: {}'.format(self.client_address))
data = bytearray(self.request.recv(1024)) data = bytearray(self.request.recv(1024))
@@ -56,12 +73,14 @@ class TlsServer(socketserver.TCPServer):
allow_reuse_address = True allow_reuse_address = True
allow_reuse_port = True allow_reuse_port = True
def __init__(self, def __init__(
port:int = SERVER_PORT, self,
port: int = SERVER_PORT,
ServerHandler: Callable[[Any, Any, Any], socketserver.BaseRequestHandler] = MQTTHandler, ServerHandler: Callable[[Any, Any, Any], socketserver.BaseRequestHandler] = MQTTHandler,
client_cert:bool=False, client_cert: bool = False,
refuse_connection:bool=False, refuse_connection: bool = False,
use_alpn:bool=False): use_alpn: bool = False,
):
self.refuse_connection = refuse_connection self.refuse_connection = refuse_connection
self.context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) self.context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
self.ssl_error = '' self.ssl_error = ''
@@ -73,7 +92,7 @@ class TlsServer(socketserver.TCPServer):
if use_alpn: if use_alpn:
self.context.set_alpn_protocols(['mymqtt', 'http/1.1']) self.context.set_alpn_protocols(['mymqtt', 'http/1.1'])
self.server_thread = Thread(target=self.serve_forever) self.server_thread = Thread(target=self.serve_forever)
super().__init__(('',port), ServerHandler) super().__init__(('', port), ServerHandler)
def server_activate(self) -> None: def server_activate(self) -> None:
self.socket = self.context.wrap_socket(self.socket, server_side=True) self.socket = self.context.wrap_socket(self.socket, server_side=True)
@@ -125,14 +144,16 @@ def get_test_cases(dut: Dut) -> Any:
cases = {} cases = {}
try: try:
# Get connection test cases configuration: symbolic names for test cases # Get connection test cases configuration: symbolic names for test cases
for case in ['EXAMPLE_CONNECT_CASE_NO_CERT', for case in [
'EXAMPLE_CONNECT_CASE_NO_CERT',
'EXAMPLE_CONNECT_CASE_SERVER_CERT', 'EXAMPLE_CONNECT_CASE_SERVER_CERT',
'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH', 'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH',
'EXAMPLE_CONNECT_CASE_INVALID_SERVER_CERT', 'EXAMPLE_CONNECT_CASE_INVALID_SERVER_CERT',
'EXAMPLE_CONNECT_CASE_SERVER_DER_CERT', 'EXAMPLE_CONNECT_CASE_SERVER_DER_CERT',
'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_KEY_PWD', 'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_KEY_PWD',
'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_BAD_CRT', 'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_BAD_CRT',
'EXAMPLE_CONNECT_CASE_NO_CERT_ALPN']: 'EXAMPLE_CONNECT_CASE_NO_CERT_ALPN',
]:
cases[case] = dut.app.sdkconfig.get(case) cases[case] = dut.app.sdkconfig.get(case)
except Exception: except Exception:
logging.error('ENV_TEST_FAILURE: Some mandatory CONNECTION test case not found in sdkconfig') logging.error('ENV_TEST_FAILURE: Some mandatory CONNECTION test case not found in sdkconfig')
@@ -147,7 +168,7 @@ def get_dut_ip(dut: Dut) -> Any:
@contextlib.contextmanager @contextlib.contextmanager
def connect_dut(dut: Dut, uri:str, case_id:int) -> Any: def connect_dut(dut: Dut, uri: str, case_id: int) -> Any:
dut.write('connection_setup') dut.write('connection_setup')
dut.write(f'connect {uri} {case_id}') dut.write(f'connect {uri} {case_id}')
dut.expect(f'Test case:{case_id} started') dut.expect(f'Test case:{case_id} started')
@@ -157,12 +178,16 @@ def connect_dut(dut: Dut, uri:str, case_id:int) -> Any:
dut.write('disconnect') dut.write('disconnect')
def run_cases(dut:Dut, uri:str, cases:Dict[str, int]) -> None: def run_cases(dut: Dut, uri: str, cases: Dict[str, int]) -> None:
try: try:
dut.write('init') dut.write('init')
dut.write(f'start') dut.write(f'start')
dut.write(f'disconnect') dut.write(f'disconnect')
for case in ['EXAMPLE_CONNECT_CASE_NO_CERT', 'EXAMPLE_CONNECT_CASE_SERVER_CERT', 'EXAMPLE_CONNECT_CASE_SERVER_DER_CERT']: for case in [
'EXAMPLE_CONNECT_CASE_NO_CERT',
'EXAMPLE_CONNECT_CASE_SERVER_CERT',
'EXAMPLE_CONNECT_CASE_SERVER_DER_CERT',
]:
# All these cases connect to the server with no server verification or with server only verification # All these cases connect to the server with no server verification or with server only verification
with TlsServer(), connect_dut(dut, uri, cases[case]): with TlsServer(), connect_dut(dut, uri, cases[case]):
logging.info(f'Running {case}: default server - expect to connect normally') logging.info(f'Running {case}: default server - expect to connect normally')
@@ -172,9 +197,13 @@ def run_cases(dut:Dut, uri:str, cases:Dict[str, int]) -> None:
dut.expect(f'MQTT_EVENT_ERROR: Test={cases[case]}', timeout=30) dut.expect(f'MQTT_EVENT_ERROR: Test={cases[case]}', timeout=30)
dut.expect('MQTT ERROR: 0x5') # expecting 0x5 ... connection not authorized error dut.expect('MQTT ERROR: 0x5') # expecting 0x5 ... connection not authorized error
with TlsServer(client_cert=True) as server, connect_dut(dut, uri, cases[case]): with TlsServer(client_cert=True) as server, connect_dut(dut, uri, cases[case]):
logging.info(f'Running {case}: server with client verification - handshake error since client presents no client certificate') logging.info(
f'Running {case}: server with client verification - handshake error since client presents no client certificate'
)
dut.expect(f'MQTT_EVENT_ERROR: Test={cases[case]}', timeout=30) dut.expect(f'MQTT_EVENT_ERROR: Test={cases[case]}', timeout=30)
dut.expect('ESP-TLS ERROR: ESP_ERR_MBEDTLS_SSL_HANDSHAKE_FAILED') # expect ... handshake error (PEER_DID_NOT_RETURN_A_CERTIFICATE) dut.expect(
'ESP-TLS ERROR: ESP_ERR_MBEDTLS_SSL_HANDSHAKE_FAILED'
) # expect ... handshake error (PEER_DID_NOT_RETURN_A_CERTIFICATE)
assert 'PEER_DID_NOT_RETURN_A_CERTIFICATE' in server.last_ssl_error() assert 'PEER_DID_NOT_RETURN_A_CERTIFICATE' in server.last_ssl_error()
for case in ['EXAMPLE_CONNECT_CASE_MUTUAL_AUTH', 'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_KEY_PWD']: for case in ['EXAMPLE_CONNECT_CASE_MUTUAL_AUTH', 'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_KEY_PWD']:
@@ -187,15 +216,21 @@ def run_cases(dut:Dut, uri:str, cases:Dict[str, int]) -> None:
with TlsServer() as s, connect_dut(dut, uri, cases[case]): with TlsServer() as s, connect_dut(dut, uri, cases[case]):
logging.info(f'Running {case}: invalid server certificate on default server - expect ssl handshake error') logging.info(f'Running {case}: invalid server certificate on default server - expect ssl handshake error')
dut.expect(f'MQTT_EVENT_ERROR: Test={cases[case]}', timeout=30) dut.expect(f'MQTT_EVENT_ERROR: Test={cases[case]}', timeout=30)
dut.expect('ESP-TLS ERROR: ESP_ERR_MBEDTLS_SSL_HANDSHAKE_FAILED') # expect ... handshake error (TLSV1_ALERT_UNKNOWN_CA) dut.expect(
if re.match('.*alert.*unknown.*ca',s.last_ssl_error(), flags=re.I) is None: 'ESP-TLS ERROR: ESP_ERR_MBEDTLS_SSL_HANDSHAKE_FAILED'
) # expect ... handshake error (TLSV1_ALERT_UNKNOWN_CA)
if re.match('.*alert.*unknown.*ca', s.last_ssl_error(), flags=re.I) is None:
raise Exception(f'Unexpected ssl error from the server: {s.last_ssl_error()}') raise Exception(f'Unexpected ssl error from the server: {s.last_ssl_error()}')
case = 'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_BAD_CRT' case = 'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_BAD_CRT'
with TlsServer(client_cert=True) as s, connect_dut(dut, uri, cases[case]): with TlsServer(client_cert=True) as s, connect_dut(dut, uri, cases[case]):
logging.info(f'Running {case}: Invalid client certificate on server with client verification - expect ssl handshake error') logging.info(
f'Running {case}: Invalid client certificate on server with client verification - expect ssl handshake error'
)
dut.expect(f'MQTT_EVENT_ERROR: Test={cases[case]}', timeout=30) dut.expect(f'MQTT_EVENT_ERROR: Test={cases[case]}', timeout=30)
dut.expect('ESP-TLS ERROR: ESP_ERR_MBEDTLS_SSL_HANDSHAKE_FAILED') # expect ... handshake error (CERTIFICATE_VERIFY_FAILED) dut.expect(
'ESP-TLS ERROR: ESP_ERR_MBEDTLS_SSL_HANDSHAKE_FAILED'
) # expect ... handshake error (CERTIFICATE_VERIFY_FAILED)
if 'CERTIFICATE_VERIFY_FAILED' not in s.last_ssl_error(): if 'CERTIFICATE_VERIFY_FAILED' not in s.last_ssl_error():
raise Exception('Unexpected ssl error from the server {}'.format(s.last_ssl_error())) raise Exception('Unexpected ssl error from the server {}'.format(s.last_ssl_error()))
@@ -214,8 +249,8 @@ def run_cases(dut:Dut, uri:str, cases:Dict[str, int]) -> None:
dut.write('destroy') dut.write('destroy')
@pytest.mark.esp32
@pytest.mark.ethernet @pytest.mark.ethernet
@idf_parametrize('target', ['esp32'], indirect=['target'])
def test_mqtt_connect( def test_mqtt_connect(
dut: Dut, dut: Dut,
log_performance: Callable[[str, object], None], log_performance: Callable[[str, object], None],

View File

@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: 2023-2024 Espressif Systems (Shanghai) CO LTD # SPDX-FileCopyrightText: 2023-2025 Espressif Systems (Shanghai) CO LTD
# SPDX-License-Identifier: Unlicense OR CC0-1.0 # SPDX-License-Identifier: Unlicense OR CC0-1.0
import contextlib import contextlib
import difflib import difflib
@@ -22,13 +22,13 @@ import paho.mqtt.client as mqtt
import pexpect import pexpect
import pytest import pytest
from pytest_embedded import Dut from pytest_embedded import Dut
from pytest_embedded_idf.utils import idf_parametrize
DEFAULT_MSG_SIZE = 16 DEFAULT_MSG_SIZE = 16
# Publisher class creating a python client to send/receive published data from esp-mqtt client # Publisher class creating a python client to send/receive published data from esp-mqtt client
class MqttPublisher(mqtt.Client): class MqttPublisher(mqtt.Client):
def __init__(self, config, log_details=False): # type: (MqttPublisher, dict, bool) -> None def __init__(self, config, log_details=False): # type: (MqttPublisher, dict, bool) -> None
self.log_details = log_details self.log_details = log_details
self.config = config self.config = config
@@ -40,7 +40,9 @@ class MqttPublisher(mqtt.Client):
self.event_client_subscribed = Event() self.event_client_subscribed = Event()
self.event_client_got_all = Event() self.event_client_got_all = Event()
transport = 'websockets' if self.config['transport'] in ['ws', 'wss'] else 'tcp' transport = 'websockets' if self.config['transport'] in ['ws', 'wss'] else 'tcp'
client_id = 'MqttTestRunner' + ''.join(random.choice(string.ascii_uppercase + string.ascii_lowercase) for _ in range(5)) client_id = 'MqttTestRunner' + ''.join(
random.choice(string.ascii_uppercase + string.ascii_lowercase) for _ in range(5)
)
super().__init__(client_id, userdata=0, transport=transport) super().__init__(client_id, userdata=0, transport=transport)
def print_details(self, text): # type: (str) -> None def print_details(self, text): # type: (str) -> None
@@ -53,7 +55,7 @@ class MqttPublisher(mqtt.Client):
logging.info(f'Subscribed to {self.config["subscribe_topic"]} successfully with QoS: {granted_qos}') logging.info(f'Subscribed to {self.config["subscribe_topic"]} successfully with QoS: {granted_qos}')
self.event_client_subscribed.set() self.event_client_subscribed.set()
def on_connect(self, mqttc: Any, obj: Any, flags: Any, rc:int) -> None: def on_connect(self, mqttc: Any, obj: Any, flags: Any, rc: int) -> None:
self.event_client_connected.set() self.event_client_connected.set()
def on_connect_fail(self, mqttc: Any, obj: Any) -> None: def on_connect_fail(self, mqttc: Any, obj: Any) -> None:
@@ -67,8 +69,10 @@ class MqttPublisher(mqtt.Client):
self.event_client_got_all.set() self.event_client_got_all.set()
else: else:
differences = len(list(filter(lambda data: data[0] != data[1], zip(payload, self.expected_data)))) differences = len(list(filter(lambda data: data[0] != data[1], zip(payload, self.expected_data))))
logging.error(f'Payload differ in {differences} positions from expected data. received size: {len(payload)} expected size:' logging.error(
f'{len(self.expected_data)}') f'Payload differ in {differences} positions from expected data. received size: {len(payload)} expected size:'
f'{len(self.expected_data)}'
)
logging.info(f'Repetitions: {payload.count(self.config["pattern"])}') logging.info(f'Repetitions: {payload.count(self.config["pattern"])}')
logging.info(f'Pattern: {self.config["pattern"]}') logging.info(f'Pattern: {self.config["pattern"]}')
logging.info(f'First: {payload[:DEFAULT_MSG_SIZE]}') logging.info(f'First: {payload[:DEFAULT_MSG_SIZE]}')
@@ -107,9 +111,10 @@ class MqttPublisher(mqtt.Client):
self.loop_stop() self.loop_stop()
def get_configurations(dut: Dut, test_case: Any) -> Dict[str,Any]: def get_configurations(dut: Dut, test_case: Any) -> Dict[str, Any]:
publish_cfg = {} publish_cfg = {}
try: try:
@no_type_check @no_type_check
def get_config_from_dut(dut, config_option): def get_config_from_dut(dut, config_option):
# logging.info('Option:', config_option, dut.app.sdkconfig.get(config_option)) # logging.info('Option:', config_option, dut.app.sdkconfig.get(config_option))
@@ -117,11 +122,18 @@ def get_configurations(dut: Dut, test_case: Any) -> Dict[str,Any]:
if value is None: if value is None:
return None, None return None, None
return value.group(1), int(value.group(2)) return value.group(1), int(value.group(2))
# Get publish test configuration # Get publish test configuration
publish_cfg['broker_host_ssl'], publish_cfg['broker_port_ssl'] = get_config_from_dut(dut, 'EXAMPLE_BROKER_SSL_URI') publish_cfg['broker_host_ssl'], publish_cfg['broker_port_ssl'] = get_config_from_dut(
publish_cfg['broker_host_tcp'], publish_cfg['broker_port_tcp'] = get_config_from_dut(dut, 'EXAMPLE_BROKER_TCP_URI') dut, 'EXAMPLE_BROKER_SSL_URI'
)
publish_cfg['broker_host_tcp'], publish_cfg['broker_port_tcp'] = get_config_from_dut(
dut, 'EXAMPLE_BROKER_TCP_URI'
)
publish_cfg['broker_host_ws'], publish_cfg['broker_port_ws'] = get_config_from_dut(dut, 'EXAMPLE_BROKER_WS_URI') publish_cfg['broker_host_ws'], publish_cfg['broker_port_ws'] = get_config_from_dut(dut, 'EXAMPLE_BROKER_WS_URI')
publish_cfg['broker_host_wss'], publish_cfg['broker_port_wss'] = get_config_from_dut(dut, 'EXAMPLE_BROKER_WSS_URI') publish_cfg['broker_host_wss'], publish_cfg['broker_port_wss'] = get_config_from_dut(
dut, 'EXAMPLE_BROKER_WSS_URI'
)
except Exception: except Exception:
logging.info('ENV_TEST_FAILURE: Some mandatory PUBLISH test case not found in sdkconfig') logging.info('ENV_TEST_FAILURE: Some mandatory PUBLISH test case not found in sdkconfig')
@@ -133,9 +145,13 @@ def get_configurations(dut: Dut, test_case: Any) -> Dict[str,Any]:
publish_cfg['qos'] = qos publish_cfg['qos'] = qos
publish_cfg['enqueue'] = enqueue publish_cfg['enqueue'] = enqueue
publish_cfg['transport'] = transport publish_cfg['transport'] = transport
publish_cfg['pattern'] = ''.join(random.choice(string.ascii_uppercase + string.ascii_lowercase + string.digits) for _ in range(DEFAULT_MSG_SIZE)) publish_cfg['pattern'] = ''.join(
random.choice(string.ascii_uppercase + string.ascii_lowercase + string.digits) for _ in range(DEFAULT_MSG_SIZE)
)
publish_cfg['test_timeout'] = get_timeout(test_case) publish_cfg['test_timeout'] = get_timeout(test_case)
unique_topic = ''.join(random.choice(string.ascii_uppercase + string.ascii_lowercase) for _ in range(DEFAULT_MSG_SIZE)) unique_topic = ''.join(
random.choice(string.ascii_uppercase + string.ascii_lowercase) for _ in range(DEFAULT_MSG_SIZE)
)
publish_cfg['subscribe_topic'] = 'test/subscribe_to/' + unique_topic publish_cfg['subscribe_topic'] = 'test/subscribe_to/' + unique_topic
publish_cfg['publish_topic'] = 'test/subscribe_to/' + unique_topic publish_cfg['publish_topic'] = 'test/subscribe_to/' + unique_topic
logging.info(f'configuration: {publish_cfg}') logging.info(f'configuration: {publish_cfg}')
@@ -143,7 +159,7 @@ def get_configurations(dut: Dut, test_case: Any) -> Dict[str,Any]:
@contextlib.contextmanager @contextlib.contextmanager
def connected_and_subscribed(dut:Dut) -> Any: def connected_and_subscribed(dut: Dut) -> Any:
dut.write('start') dut.write('start')
dut.expect(re.compile(rb'MQTT_EVENT_SUBSCRIBED'), timeout=60) dut.expect(re.compile(rb'MQTT_EVENT_SUBSCRIBED'), timeout=60)
yield yield
@@ -155,15 +171,16 @@ def get_scenarios() -> List[Dict[str, int]]:
# Initialize message sizes and repeat counts (if defined in the environment) # Initialize message sizes and repeat counts (if defined in the environment)
for i in count(0): for i in count(0):
# Check env variable: MQTT_PUBLISH_MSG_{len|repeat}_{x} # Check env variable: MQTT_PUBLISH_MSG_{len|repeat}_{x}
env_dict = {var:'MQTT_PUBLISH_MSG_' + var + '_' + str(i) for var in ['len', 'repeat']} env_dict = {var: 'MQTT_PUBLISH_MSG_' + var + '_' + str(i) for var in ['len', 'repeat']}
if os.getenv(env_dict['len']) and os.getenv(env_dict['repeat']): if os.getenv(env_dict['len']) and os.getenv(env_dict['repeat']):
scenarios.append({var: int(os.getenv(env_dict[var])) for var in ['len', 'repeat']}) # type: ignore scenarios.append({var: int(os.getenv(env_dict[var])) for var in ['len', 'repeat']}) # type: ignore
continue continue
break break
if not scenarios: # No message sizes present in the env - set defaults if not scenarios: # No message sizes present in the env - set defaults
scenarios = [{'msg_len':0, 'nr_of_msgs':5}, # zero-sized messages scenarios = [
{'msg_len':2, 'nr_of_msgs':5}, # short messages {'msg_len': 0, 'nr_of_msgs': 5}, # zero-sized messages
{'msg_len':200, 'nr_of_msgs':3}, # long messages {'msg_len': 2, 'nr_of_msgs': 5}, # short messages
{'msg_len': 200, 'nr_of_msgs': 3}, # long messages
] ]
return scenarios return scenarios
@@ -181,17 +198,23 @@ def get_timeout(test_case: Any) -> int:
def run_publish_test_case(dut: Dut, config: Any) -> None: def run_publish_test_case(dut: Dut, config: Any) -> None:
logging.info(f'Starting Publish test: transport:{config["transport"]}, qos:{config["qos"]},' logging.info(
f'Starting Publish test: transport:{config["transport"]}, qos:{config["qos"]},'
f'nr_of_msgs:{config["scenario"]["nr_of_msgs"]},' f'nr_of_msgs:{config["scenario"]["nr_of_msgs"]},'
f' msg_size:{config["scenario"]["msg_len"] * DEFAULT_MSG_SIZE}, enqueue:{config["enqueue"]}') f' msg_size:{config["scenario"]["msg_len"] * DEFAULT_MSG_SIZE}, enqueue:{config["enqueue"]}'
dut.write(f'publish_setup {config["transport"]} {config["publish_topic"]} {config["subscribe_topic"]} {config["pattern"]} {config["scenario"]["msg_len"]}') )
dut.write(
f'publish_setup {config["transport"]} {config["publish_topic"]} {config["subscribe_topic"]} {config["pattern"]} {config["scenario"]["msg_len"]}'
)
with MqttPublisher(config) as publisher, connected_and_subscribed(dut): with MqttPublisher(config) as publisher, connected_and_subscribed(dut):
assert publisher.event_client_subscribed.wait(timeout=config['test_timeout']), 'Runner failed to subscribe' assert publisher.event_client_subscribed.wait(timeout=config['test_timeout']), 'Runner failed to subscribe'
msgs_published: List[mqtt.MQTTMessageInfo] = [] msgs_published: List[mqtt.MQTTMessageInfo] = []
dut.write(f'publish {config["scenario"]["nr_of_msgs"]} {config["qos"]} {config["enqueue"]}') dut.write(f'publish {config["scenario"]["nr_of_msgs"]} {config["qos"]} {config["enqueue"]}')
assert publisher.event_client_got_all.wait(timeout=config['test_timeout']), (f'Not all data received from ESP32: {config["transport"]} ' assert publisher.event_client_got_all.wait(timeout=config['test_timeout']), (
f'Not all data received from ESP32: {config["transport"]} '
f'qos={config["qos"]} received: {publisher.received} ' f'qos={config["qos"]} received: {publisher.received} '
f'expected: {config["scenario"]["nr_of_msgs"]}') f'expected: {config["scenario"]["nr_of_msgs"]}'
)
logging.info(' - all data received from ESP32') logging.info(' - all data received from ESP32')
payload = config['pattern'] * config['scenario']['msg_len'] payload = config['pattern'] * config['scenario']['msg_len']
for _ in range(config['scenario']['nr_of_msgs']): for _ in range(config['scenario']['nr_of_msgs']):
@@ -214,15 +237,17 @@ def run_publish_test_case(dut: Dut, config: Any) -> None:
logging.info('ESP32 received all data from runner') logging.info('ESP32 received all data from runner')
stress_scenarios = [{'msg_len':20, 'nr_of_msgs':30}] # many medium sized stress_scenarios = [{'msg_len': 20, 'nr_of_msgs': 30}] # many medium sized
transport_cases = ['tcp', 'ws', 'wss', 'ssl'] transport_cases = ['tcp', 'ws', 'wss', 'ssl']
qos_cases = [0, 1, 2] qos_cases = [0, 1, 2]
enqueue_cases = [0, 1] enqueue_cases = [0, 1]
local_broker_supported_transports = ['tcp'] local_broker_supported_transports = ['tcp']
local_broker_scenarios = [{'msg_len':0, 'nr_of_msgs':5}, # zero-sized messages local_broker_scenarios = [
{'msg_len':5, 'nr_of_msgs':20}, # short messages {'msg_len': 0, 'nr_of_msgs': 5}, # zero-sized messages
{'msg_len':500, 'nr_of_msgs':10}, # long messages {'msg_len': 5, 'nr_of_msgs': 20}, # short messages
{'msg_len':20, 'nr_of_msgs':20}] # many medium sized {'msg_len': 500, 'nr_of_msgs': 10}, # long messages
{'msg_len': 20, 'nr_of_msgs': 20},
] # many medium sized
def make_cases(transport: Any, scenarios: List[Dict[str, int]]) -> List[Tuple[str, int, int, Dict[str, int]]]: def make_cases(transport: Any, scenarios: List[Dict[str, int]]) -> List[Tuple[str, int, int, Dict[str, int]]]:
@@ -233,10 +258,10 @@ test_cases = make_cases(transport_cases, get_scenarios())
stress_test_cases = make_cases(transport_cases, stress_scenarios) stress_test_cases = make_cases(transport_cases, stress_scenarios)
@pytest.mark.esp32
@pytest.mark.ethernet @pytest.mark.ethernet
@pytest.mark.parametrize('test_case', test_cases) @pytest.mark.parametrize('test_case', test_cases)
@pytest.mark.parametrize('config', ['default'], indirect=True) @pytest.mark.parametrize('config', ['default'], indirect=True)
@idf_parametrize('target', ['esp32'], indirect=['target'])
def test_mqtt_publish(dut: Dut, test_case: Any) -> None: def test_mqtt_publish(dut: Dut, test_case: Any) -> None:
publish_cfg = get_configurations(dut, test_case) publish_cfg = get_configurations(dut, test_case)
dut.expect(re.compile(rb'mqtt>'), timeout=30) dut.expect(re.compile(rb'mqtt>'), timeout=30)
@@ -244,11 +269,11 @@ def test_mqtt_publish(dut: Dut, test_case: Any) -> None:
run_publish_test_case(dut, publish_cfg) run_publish_test_case(dut, publish_cfg)
@pytest.mark.esp32
@pytest.mark.ethernet_stress @pytest.mark.ethernet_stress
@pytest.mark.nightly_run @pytest.mark.nightly_run
@pytest.mark.parametrize('test_case', stress_test_cases) @pytest.mark.parametrize('test_case', stress_test_cases)
@pytest.mark.parametrize('config', ['default'], indirect=True) @pytest.mark.parametrize('config', ['default'], indirect=True)
@idf_parametrize('target', ['esp32'], indirect=['target'])
def test_mqtt_publish_stress(dut: Dut, test_case: Any) -> None: def test_mqtt_publish_stress(dut: Dut, test_case: Any) -> None:
publish_cfg = get_configurations(dut, test_case) publish_cfg = get_configurations(dut, test_case)
dut.expect(re.compile(rb'mqtt>'), timeout=30) dut.expect(re.compile(rb'mqtt>'), timeout=30)
@@ -256,10 +281,10 @@ def test_mqtt_publish_stress(dut: Dut, test_case: Any) -> None:
run_publish_test_case(dut, publish_cfg) run_publish_test_case(dut, publish_cfg)
@pytest.mark.esp32
@pytest.mark.ethernet @pytest.mark.ethernet
@pytest.mark.parametrize('test_case', make_cases(local_broker_supported_transports, local_broker_scenarios)) @pytest.mark.parametrize('test_case', make_cases(local_broker_supported_transports, local_broker_scenarios))
@pytest.mark.parametrize('config', ['local_broker'], indirect=True) @pytest.mark.parametrize('config', ['local_broker'], indirect=True)
@idf_parametrize('target', ['esp32'], indirect=['target'])
def test_mqtt_publish_local(dut: Dut, test_case: Any) -> None: def test_mqtt_publish_local(dut: Dut, test_case: Any) -> None:
if test_case[0] not in local_broker_supported_transports: if test_case[0] not in local_broker_supported_transports:
pytest.skip(f'Skipping transport: {test_case[0]}...') pytest.skip(f'Skipping transport: {test_case[0]}...')