diff --git a/examples/protocols/mqtt/ssl/mqtt_ssl_example_test.py b/examples/protocols/mqtt/ssl/pytest_mqtt_ssl.py similarity index 79% rename from examples/protocols/mqtt/ssl/mqtt_ssl_example_test.py rename to examples/protocols/mqtt/ssl/pytest_mqtt_ssl.py index 51ac4e9754..c16b89cbb1 100644 --- a/examples/protocols/mqtt/ssl/mqtt_ssl_example_test.py +++ b/examples/protocols/mqtt/ssl/pytest_mqtt_ssl.py @@ -1,3 +1,6 @@ +# SPDX-FileCopyrightText: 2022 Espressif Systems (Shanghai) CO LTD +# SPDX-License-Identifier: Unlicense OR CC0-1.0 +import logging import os import re import ssl @@ -5,8 +8,9 @@ import sys from threading import Event, Thread import paho.mqtt.client as mqtt -import ttfw_idf -from tiny_test_fw import DUT +import pexpect +import pytest +from pytest_embedded import Dut event_client_connected = Event() event_stop_client = Event() @@ -16,19 +20,20 @@ message_log = '' # The callback for when the client receives a CONNACK response from the server. -def on_connect(client, userdata, flags, rc): +def on_connect(client, userdata, flags, rc): # type: (mqtt.Client, str, bool, str) -> None + _ = (userdata, flags) print('Connected with result code ' + str(rc)) event_client_connected.set() client.subscribe('/topic/qos0') -def mqtt_client_task(client): +def mqtt_client_task(client): # type: (mqtt.Client) -> None while not event_stop_client.is_set(): client.loop() # The callback for when a PUBLISH message is received from the server. -def on_message(client, userdata, msg): +def on_message(client, userdata, msg): # type: (mqtt.Client, tuple, mqtt.client.MQTTMessage) -> None global message_log global event_client_received_correct global event_client_received_binary @@ -55,8 +60,9 @@ def on_message(client, userdata, msg): message_log += 'Received data:' + msg.topic + ' ' + payload + '\n' -@ttfw_idf.idf_example_test(env_tag='ethernet_router') -def test_examples_protocol_mqtt_ssl(env, extra_data): +@pytest.mark.esp32 +@pytest.mark.ethernet +def test_examples_protocol_mqtt_ssl(dut): # type: (Dut) -> None broker_url = '' broker_port = 0 """ @@ -67,18 +73,17 @@ def test_examples_protocol_mqtt_ssl(env, extra_data): 4. Test ESP32 client received correct qos0 message 5. Test python client receives binary data from running partition and compares it with the binary """ - dut1 = env.get_dut('mqtt_ssl', 'examples/protocols/mqtt/ssl', dut_class=ttfw_idf.ESP32DUT) - # check and log bin size - binary_file = os.path.join(dut1.app.binary_path, 'mqtt_ssl.bin') + binary_file = os.path.join(dut.app.binary_path, 'mqtt_ssl.bin') bin_size = os.path.getsize(binary_file) - ttfw_idf.log_performance('mqtt_ssl_bin_size', '{}KB' - .format(bin_size // 1024)) + logging.info('[Performance][mqtt_ssl_bin_size]: %s KB', bin_size // 1024) + # Look for host:port in sdkconfig try: - value = re.search(r'\:\/\/([^:]+)\:([0-9]+)', dut1.app.get_sdkconfig()['CONFIG_BROKER_URI']) + value = re.search(r'\:\/\/([^:]+)\:([0-9]+)', dut.app.sdkconfig.get('BROKER_URI')) + assert value is not None broker_url = value.group(1) broker_port = int(value.group(2)) - bin_size = min(int(dut1.app.get_sdkconfig()['CONFIG_BROKER_BIN_SIZE_TO_SEND']), bin_size) + bin_size = min(int(dut.app.sdkconfig.get('BROKER_BIN_SIZE_TO_SEND')), bin_size) except Exception: print('ENV_TEST_FAILURE: Cannot find broker url in sdkconfig') raise @@ -105,25 +110,20 @@ def test_examples_protocol_mqtt_ssl(env, extra_data): print('Connecting py-client to broker {}:{}...'.format(broker_url, broker_port)) if not event_client_connected.wait(timeout=30): raise ValueError('ENV_TEST_FAILURE: Test script cannot connect to broker: {}'.format(broker_url)) - dut1.start_app() try: - ip_address = dut1.expect(re.compile(r'IPv4 address: (\d+\.\d+\.\d+\.\d+)[^\d]'), timeout=30)[0] + ip_address = dut.expect(r'IPv4 address: (\d+\.\d+\.\d+\.\d+)[^\d]', timeout=30)[0] print('Connected to AP with IP: {}'.format(ip_address)) - except DUT.ExpectTimeout: + except pexpect.TIMEOUT: print('ENV_TEST_FAILURE: Cannot connect to AP') raise print('Checking py-client received msg published from esp...') if not event_client_received_correct.wait(timeout=30): raise ValueError('Wrong data received, msg log: {}'.format(message_log)) print('Checking esp-client received msg published from py-client...') - dut1.expect(re.compile(r'DATA=send binary please'), timeout=30) + dut.expect(r'DATA=send binary please', timeout=30) print('Receiving binary data from running partition...') if not event_client_received_binary.wait(timeout=30): raise ValueError('Binary not received within timeout') finally: event_stop_client.set() thread1.join() - - -if __name__ == '__main__': - test_examples_protocol_mqtt_ssl() diff --git a/examples/protocols/mqtt/tcp/mqtt_tcp_example_test.py b/examples/protocols/mqtt/tcp/pytest_mqtt_tcp.py similarity index 65% rename from examples/protocols/mqtt/tcp/mqtt_tcp_example_test.py rename to examples/protocols/mqtt/tcp/pytest_mqtt_tcp.py index 76dc4a4ca9..2786a76b97 100644 --- a/examples/protocols/mqtt/tcp/mqtt_tcp_example_test.py +++ b/examples/protocols/mqtt/tcp/pytest_mqtt_tcp.py @@ -1,19 +1,22 @@ +# SPDX-FileCopyrightText: 2022 Espressif Systems (Shanghai) CO LTD +# SPDX-License-Identifier: Unlicense OR CC0-1.0 +import logging import os -import re import socket import struct import sys import time from threading import Thread -import ttfw_idf +import pexpect +import pytest from common_test_methods import get_host_ip4_by_dest_ip -from tiny_test_fw import DUT +from pytest_embedded import Dut msgid = -1 -def mqqt_server_sketch(my_ip, port): +def mqqt_server_sketch(my_ip, port): # type: (str, str) -> None global msgid print('Starting the server on {}'.format(my_ip)) s = None @@ -32,13 +35,13 @@ def mqqt_server_sketch(my_ip, port): raise data = q.recv(1024) # check if received initial empty message - print('received from client {}'.format(data)) + print('received from client {!r}'.format(data)) data = bytearray([0x20, 0x02, 0x00, 0x00]) q.send(data) # try to receive qos1 data = q.recv(1024) msgid = struct.unpack('>H', data[15:17])[0] - print('received from client {}, msgid: {}'.format(data, msgid)) + print('received from client {!r}, msgid: {}'.format(data, msgid)) data = bytearray([0x40, 0x02, data[15], data[16]]) q.send(data) time.sleep(5) @@ -46,8 +49,9 @@ def mqqt_server_sketch(my_ip, port): print('server closed') -@ttfw_idf.idf_example_test(env_tag='ethernet_router') -def test_examples_protocol_mqtt_qos1(env, extra_data): +@pytest.mark.esp32 +@pytest.mark.ethernet +def test_examples_protocol_mqtt_qos1(dut: Dut) -> None: global msgid """ steps: (QoS1: Happy flow) @@ -56,18 +60,15 @@ def test_examples_protocol_mqtt_qos1(env, extra_data): 3. Test evaluates that qos1 message is queued and removed from queued after ACK received 4. Test the broker received the same message id evaluated in step 3 """ - dut1 = env.get_dut('mqtt_tcp', 'examples/protocols/mqtt/tcp', dut_class=ttfw_idf.ESP32DUT) # check and log bin size - binary_file = os.path.join(dut1.app.binary_path, 'mqtt_tcp.bin') + binary_file = os.path.join(dut.app.binary_path, 'mqtt_tcp.bin') bin_size = os.path.getsize(binary_file) - ttfw_idf.log_performance('mqtt_tcp_bin_size', '{}KB'.format(bin_size // 1024)) - # 1. start the dut test and wait till client gets IP address - dut1.start_app() + logging.info('[Performance][mqtt_tcp_bin_size]: %s KB', bin_size // 1024) # waiting for getting the IP address try: - ip_address = dut1.expect(re.compile(r'IPv4 address: (\d+\.\d+\.\d+\.\d+)[^\d]'), timeout=30)[0] + ip_address = dut.expect(r'IPv4 address: (\d+\.\d+\.\d+\.\d+)', timeout=30).group(1).decode() print('Connected to AP/Ethernet with IP: {}'.format(ip_address)) - except DUT.ExpectTimeout: + except pexpect.TIMEOUT: raise ValueError('ENV_TEST_FAILURE: Cannot connect to AP/Ethernet') # 2. start mqtt broker sketch @@ -75,20 +76,17 @@ def test_examples_protocol_mqtt_qos1(env, extra_data): thread1 = Thread(target=mqqt_server_sketch, args=(host_ip,1883)) thread1.start() - print('writing to device: {}'.format('mqtt://' + host_ip + '\n')) - dut1.write('mqtt://' + host_ip + '\n') + data_write = 'mqtt://' + host_ip + print('writing to device: {}'.format(data_write)) + dut.write(data_write) thread1.join() print('Message id received from server: {}'.format(msgid)) # 3. check the message id was enqueued and then deleted - msgid_enqueued = dut1.expect(re.compile(r'outbox: ENQUEUE msgid=([0-9]+)'), timeout=30) - msgid_deleted = dut1.expect(re.compile(r'outbox: DELETED msgid=([0-9]+)'), timeout=30) + msgid_enqueued = dut.expect(b'outbox: ENQUEUE msgid=([0-9]+)', timeout=30).group(1).decode() + msgid_deleted = dut.expect(b'outbox: DELETED msgid=([0-9]+)', timeout=30).group(1).decode() # 4. check the msgid of received data are the same as that of enqueued and deleted from outbox - if (msgid_enqueued[0] == str(msgid) and msgid_deleted[0] == str(msgid)): + if (msgid_enqueued == str(msgid) and msgid_deleted == str(msgid)): print('PASS: Received correct msg id') else: print('Failure!') raise ValueError('Mismatch of msgid: received: {}, enqueued {}, deleted {}'.format(msgid, msgid_enqueued, msgid_deleted)) - - -if __name__ == '__main__': - test_examples_protocol_mqtt_qos1() diff --git a/examples/protocols/mqtt/ws/mqtt_ws_example_test.py b/examples/protocols/mqtt/ws/pytest_mqtt_ws_example.py similarity index 72% rename from examples/protocols/mqtt/ws/mqtt_ws_example_test.py rename to examples/protocols/mqtt/ws/pytest_mqtt_ws_example.py index c24cfec336..392139c37e 100644 --- a/examples/protocols/mqtt/ws/mqtt_ws_example_test.py +++ b/examples/protocols/mqtt/ws/pytest_mqtt_ws_example.py @@ -1,11 +1,16 @@ +#!/usr/bin/env python +# +# SPDX-FileCopyrightText: 2022 Espressif Systems (Shanghai) CO LTD +# SPDX-License-Identifier: Unlicense OR CC0-1.0 +import logging import os import re import sys from threading import Event, Thread import paho.mqtt.client as mqtt -import ttfw_idf -from tiny_test_fw import DUT +import pytest +from pytest_embedded import Dut event_client_connected = Event() event_stop_client = Event() @@ -14,19 +19,21 @@ message_log = '' # The callback for when the client receives a CONNACK response from the server. -def on_connect(client, userdata, flags, rc): +def on_connect(client, userdata, flags, rc): # type: (mqtt.Client, tuple, bool, str) -> None + _ = (userdata, flags) print('Connected with result code ' + str(rc)) event_client_connected.set() client.subscribe('/topic/qos0') -def mqtt_client_task(client): +def mqtt_client_task(client): # type: (mqtt.Client) -> None while not event_stop_client.is_set(): client.loop() # The callback for when a PUBLISH message is received from the server. -def on_message(client, userdata, msg): +def on_message(client, userdata, msg): # type: (mqtt.Client, tuple, mqtt.client.MQTTMessage) -> None + _ = userdata global message_log payload = msg.payload.decode() if not event_client_received_correct.is_set() and payload == 'data': @@ -36,8 +43,9 @@ def on_message(client, userdata, msg): message_log += 'Received data:' + msg.topic + ' ' + payload + '\n' -@ttfw_idf.idf_example_test(env_tag='ethernet_router') -def test_examples_protocol_mqtt_ws(env, extra_data): +@pytest.mark.esp32 +@pytest.mark.ethernet +def test_examples_protocol_mqtt_ws(dut): # type: (Dut) -> None broker_url = '' broker_port = 0 """ @@ -47,14 +55,14 @@ def test_examples_protocol_mqtt_ws(env, extra_data): 3. Test evaluates it received correct qos0 message 4. Test ESP32 client received correct qos0 message """ - dut1 = env.get_dut('mqtt_websocket', 'examples/protocols/mqtt/ws', dut_class=ttfw_idf.ESP32DUT) # check and log bin size - binary_file = os.path.join(dut1.app.binary_path, 'mqtt_websocket.bin') + binary_file = os.path.join(dut.app.binary_path, 'mqtt_websocket.bin') bin_size = os.path.getsize(binary_file) - ttfw_idf.log_performance('mqtt_websocket_bin_size', '{}KB'.format(bin_size // 1024)) + logging.info('[Performance][mqtt_websocket_bin_size]: %s KB', bin_size // 1024) # Look for host:port in sdkconfig try: - value = re.search(r'\:\/\/([^:]+)\:([0-9]+)', dut1.app.get_sdkconfig()['CONFIG_BROKER_URI']) + value = re.search(r'\:\/\/([^:]+)\:([0-9]+)', dut.app.sdkconfig.get('BROKER_URI')) + assert value is not None broker_url = value.group(1) broker_port = int(value.group(2)) except Exception: @@ -78,22 +86,17 @@ def test_examples_protocol_mqtt_ws(env, extra_data): print('Connecting py-client to broker {}:{}...'.format(broker_url, broker_port)) if not event_client_connected.wait(timeout=30): raise ValueError('ENV_TEST_FAILURE: Test script cannot connect to broker: {}'.format(broker_url)) - dut1.start_app() try: - ip_address = dut1.expect(re.compile(r'IPv4 address: (\d+\.\d+\.\d+\.\d+)[^\d]'), timeout=30)[0] + ip_address = dut.expect(r'IPv4 address: (\d+\.\d+\.\d+\.\d+)[^\d]', timeout=30)[0] print('Connected to AP with IP: {}'.format(ip_address)) - except DUT.ExpectTimeout: + except Dut.ExpectTimeout: print('ENV_TEST_FAILURE: Cannot connect to AP') raise print('Checking py-client received msg published from esp...') if not event_client_received_correct.wait(timeout=30): raise ValueError('Wrong data received, msg log: {}'.format(message_log)) print('Checking esp-client received msg published from py-client...') - dut1.expect(re.compile(r'DATA=data_to_esp32'), timeout=30) + dut.expect(r'DATA=data_to_esp32', timeout=30) finally: event_stop_client.set() thread1.join() - - -if __name__ == '__main__': - test_examples_protocol_mqtt_ws() diff --git a/examples/protocols/mqtt/wss/mqtt_wss_example_test.py b/examples/protocols/mqtt/wss/pytest_mqtt_wss_example.py similarity index 73% rename from examples/protocols/mqtt/wss/mqtt_wss_example_test.py rename to examples/protocols/mqtt/wss/pytest_mqtt_wss_example.py index c0a1454dc5..05eff5e515 100644 --- a/examples/protocols/mqtt/wss/mqtt_wss_example_test.py +++ b/examples/protocols/mqtt/wss/pytest_mqtt_wss_example.py @@ -1,3 +1,8 @@ +#!/usr/bin/env python +# +# SPDX-FileCopyrightText: 2022 Espressif Systems (Shanghai) CO LTD +# SPDX-License-Identifier: Unlicense OR CC0-1.0 +import logging import os import re import ssl @@ -5,8 +10,9 @@ import sys from threading import Event, Thread import paho.mqtt.client as mqtt -import ttfw_idf -from tiny_test_fw import DUT +import pexpect +import pytest +from pytest_embedded import Dut event_client_connected = Event() event_stop_client = Event() @@ -15,19 +21,21 @@ message_log = '' # The callback for when the client receives a CONNACK response from the server. -def on_connect(client, userdata, flags, rc): +def on_connect(client, userdata, flags, rc): # type: (mqtt.Client, tuple, bool, str) -> None + _ = (userdata, flags) print('Connected with result code ' + str(rc)) event_client_connected.set() client.subscribe('/topic/qos0') -def mqtt_client_task(client): +def mqtt_client_task(client): # type: (mqtt.Client) -> None while not event_stop_client.is_set(): client.loop() # The callback for when a PUBLISH message is received from the server. -def on_message(client, userdata, msg): +def on_message(client, userdata, msg): # type: (mqtt.Client, tuple, mqtt.client.MQTTMessage) -> None + _ = userdata global message_log payload = msg.payload.decode() if not event_client_received_correct.is_set() and payload == 'data': @@ -37,8 +45,9 @@ def on_message(client, userdata, msg): message_log += 'Received data:' + msg.topic + ' ' + payload + '\n' -@ttfw_idf.idf_example_test(env_tag='ethernet_router') -def test_examples_protocol_mqtt_wss(env, extra_data): +@pytest.mark.esp32 +@pytest.mark.ethernet +def test_examples_protocol_mqtt_wss(dut): # type: (Dut) -> None broker_url = '' broker_port = 0 """ @@ -48,14 +57,14 @@ def test_examples_protocol_mqtt_wss(env, extra_data): 3. Test evaluates it received correct qos0 message 4. Test ESP32 client received correct qos0 message """ - dut1 = env.get_dut('mqtt_websocket_secure', 'examples/protocols/mqtt/wss', dut_class=ttfw_idf.ESP32DUT) # check and log bin size - binary_file = os.path.join(dut1.app.binary_path, 'mqtt_websocket_secure.bin') + binary_file = os.path.join(dut.app.binary_path, 'mqtt_websocket_secure.bin') bin_size = os.path.getsize(binary_file) - ttfw_idf.log_performance('mqtt_websocket_secure_bin_size', '{}KB'.format(bin_size // 1024)) + logging.info('[Performance][mqtt_websocket_secure_bin_size]: %s KB', bin_size // 1024) # Look for host:port in sdkconfig try: - value = re.search(r'\:\/\/([^:]+)\:([0-9]+)', dut1.app.get_sdkconfig()['CONFIG_BROKER_URI']) + value = re.search(r'\:\/\/([^:]+)\:([0-9]+)', dut.app.sdkconfig.get('BROKER_URI')) + assert value is not None broker_url = value.group(1) broker_port = int(value.group(2)) except Exception: @@ -82,22 +91,17 @@ def test_examples_protocol_mqtt_wss(env, extra_data): print('Connecting py-client to broker {}:{}...'.format(broker_url, broker_port)) if not event_client_connected.wait(timeout=30): raise ValueError('ENV_TEST_FAILURE: Test script cannot connect to broker: {}'.format(broker_url)) - dut1.start_app() try: - ip_address = dut1.expect(re.compile(r'IPv4 address: (\d+\.\d+\.\d+\.\d+)[^\d]'), timeout=30)[0] + ip_address = dut.expect(r'IPv4 address: (\d+\.\d+\.\d+\.\d+)[^\d]', timeout=30)[0] print('Connected to AP with IP: {}'.format(ip_address)) - except DUT.ExpectTimeout: + except pexpect.TIMEOUT: print('ENV_TEST_FAILURE: Cannot connect to AP') raise print('Checking py-client received msg published from esp...') if not event_client_received_correct.wait(timeout=30): raise ValueError('Wrong data received, msg log: {}'.format(message_log)) print('Checking esp-client received msg published from py-client...') - dut1.expect(re.compile(r'DATA=data_to_esp32'), timeout=30) + dut.expect(r'DATA=data_to_esp32', timeout=30) finally: event_stop_client.set() thread1.join() - - -if __name__ == '__main__': - test_examples_protocol_mqtt_wss() diff --git a/tools/requirements/requirements.pytest.txt b/tools/requirements/requirements.pytest.txt index fdd66462dc..63884b495a 100644 --- a/tools/requirements/requirements.pytest.txt +++ b/tools/requirements/requirements.pytest.txt @@ -18,6 +18,7 @@ netifaces rangehttpserver dbus-python; sys_platform == 'linux' protobuf +paho-mqtt # for twai tests, communicate with socket can device (e.g. Canable) python-can diff --git a/tools/requirements/requirements.ttfw.txt b/tools/requirements/requirements.ttfw.txt index 6203f1dfc2..b3d4d84c23 100644 --- a/tools/requirements/requirements.ttfw.txt +++ b/tools/requirements/requirements.ttfw.txt @@ -33,7 +33,3 @@ SimpleWebSocketServer # py_debug_backend debug_backend - -# examples/protocols/mqtt -# tools/test_apps/protocols/mqtt -paho-mqtt diff --git a/tools/test_apps/protocols/mqtt/publish_connect_test/app_test.py b/tools/test_apps/protocols/mqtt/publish_connect_test/pytest_mqtt_app.py similarity index 78% rename from tools/test_apps/protocols/mqtt/publish_connect_test/app_test.py rename to tools/test_apps/protocols/mqtt/publish_connect_test/pytest_mqtt_app.py index f87c376357..fd6133f5be 100644 --- a/tools/test_apps/protocols/mqtt/publish_connect_test/app_test.py +++ b/tools/test_apps/protocols/mqtt/publish_connect_test/pytest_mqtt_app.py @@ -1,5 +1,8 @@ +# SPDX-FileCopyrightText: 2022 Espressif Systems (Shanghai) CO LTD +# SPDX-License-Identifier: Unlicense OR CC0-1.0 from __future__ import print_function, unicode_literals +import logging import os import random import re @@ -10,21 +13,25 @@ import string import subprocess import sys import time +import typing from itertools import count from threading import Event, Lock, Thread +from typing import Any import paho.mqtt.client as mqtt -import ttfw_idf +import pytest from common_test_methods import get_host_ip4_by_dest_ip +from pytest_embedded import Dut +from pytest_embedded_qemu.dut import QemuDut DEFAULT_MSG_SIZE = 16 -def _path(f): +def _path(f): # type: (str) -> str return os.path.join(os.path.dirname(os.path.realpath(__file__)),f) -def set_server_cert_cn(ip): +def set_server_cert_cn(ip): # type: (str) -> None arg_list = [ ['openssl', 'req', '-out', _path('srv.csr'), '-key', _path('server.key'),'-subj', '/CN={}'.format(ip), '-new'], ['openssl', 'x509', '-req', '-in', _path('srv.csr'), '-CA', _path('ca.crt'), @@ -36,8 +43,13 @@ def set_server_cert_cn(ip): # Publisher class creating a python client to send/receive published data from esp-mqtt client class MqttPublisher: + event_client_connected = Event() + event_client_got_all = Event() + expected_data = '' + published = 0 - def __init__(self, dut, transport, qos, repeat, published, queue, publish_cfg, log_details=False): + def __init__(self, dut, transport, + qos, repeat, published, queue, publish_cfg, log_details=False): # type: (MqttPublisher, Dut, str, int, int, int, int, dict, bool) -> None # instance variables used as parameters of the publish test self.event_stop_client = Event() self.sample_string = ''.join(random.choice(string.ascii_uppercase + string.ascii_lowercase + string.digits) for _ in range(DEFAULT_MSG_SIZE)) @@ -58,11 +70,11 @@ class MqttPublisher: MqttPublisher.event_client_got_all.clear() MqttPublisher.expected_data = self.sample_string * self.repeat - def print_details(self, text): + def print_details(self, text): # type: (str) -> None if self.log_details: print(text) - def mqtt_client_task(self, client, lock): + def mqtt_client_task(self, client, lock): # type: (MqttPublisher, mqtt.Client, Lock) -> None while not self.event_stop_client.is_set(): with lock: client.loop() @@ -70,12 +82,12 @@ class MqttPublisher: # The callback for when the client receives a CONNACK response from the server (needs to be static) @staticmethod - def on_connect(_client, _userdata, _flags, _rc): + def on_connect(_client, _userdata, _flags, _rc): # type: (mqtt.Client, tuple, bool, str) -> None MqttPublisher.event_client_connected.set() # The callback for when a PUBLISH message is received from the server (needs to be static) @staticmethod - def on_message(client, userdata, msg): + def on_message(client, userdata, msg): # type: (mqtt.Client, int, mqtt.client.MQTTMessage) -> None payload = msg.payload.decode() if payload == MqttPublisher.expected_data: userdata += 1 @@ -83,7 +95,7 @@ class MqttPublisher: if userdata == MqttPublisher.published: MqttPublisher.event_client_got_all.set() - def __enter__(self): + def __enter__(self): # type: (MqttPublisher) -> None qos = self.publish_cfg['qos'] queue = self.publish_cfg['queue'] @@ -100,6 +112,7 @@ class MqttPublisher: self.client = mqtt.Client(transport='websockets') else: self.client = mqtt.Client() + assert self.client is not None self.client.on_connect = MqttPublisher.on_connect self.client.on_message = MqttPublisher.on_message self.client.user_data_set(0) @@ -137,7 +150,8 @@ class MqttPublisher: self.event_stop_client.set() thread1.join() - def __exit__(self, exc_type, exc_value, traceback): + def __exit__(self, exc_type, exc_value, traceback): # type: (MqttPublisher, str, str, dict) -> None + assert self.client is not None self.client.disconnect() self.event_stop_client.clear() @@ -145,7 +159,7 @@ class MqttPublisher: # Simple server for mqtt over TLS connection class TlsServer: - def __init__(self, port, client_cert=False, refuse_connection=False, use_alpn=False): + def __init__(self, port, client_cert=False, refuse_connection=False, use_alpn=False): # type: (TlsServer, int, bool, bool, bool) -> None self.port = port self.socket = socket.socket() self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) @@ -153,11 +167,9 @@ class TlsServer: self.shutdown = Event() self.client_cert = client_cert self.refuse_connection = refuse_connection - self.ssl_error = None self.use_alpn = use_alpn - self.negotiated_protocol = None - def __enter__(self): + def __enter__(self): # type: (TlsServer) -> TlsServer try: self.socket.bind(('', self.port)) except socket.error as e: @@ -170,20 +182,21 @@ class TlsServer: return self - def __exit__(self, exc_type, exc_value, traceback): + def __exit__(self, exc_type, exc_value, traceback): # type: (TlsServer, str, str, str) -> None self.shutdown.set() self.server_thread.join() self.socket.close() if (self.conn is not None): self.conn.close() - def get_last_ssl_error(self): + def get_last_ssl_error(self): # type: (TlsServer) -> str return self.ssl_error + @typing.no_type_check def get_negotiated_protocol(self): return self.negotiated_protocol - def run_server(self): + def run_server(self) -> None: context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) if self.client_cert: context.verify_mode = ssl.CERT_REQUIRED @@ -201,11 +214,10 @@ class TlsServer: print(' - negotiated_protocol: {}'.format(self.negotiated_protocol)) self.handle_conn() except ssl.SSLError as e: - self.conn = None self.ssl_error = str(e) print(' - SSLError: {}'.format(str(e))) - def handle_conn(self): + def handle_conn(self) -> None: while not self.shutdown.is_set(): r,w,e = select.select([self.conn], [], [], 1) try: @@ -216,7 +228,7 @@ class TlsServer: print(' - error: {}'.format(err)) raise - def process_mqtt_connect(self): + def process_mqtt_connect(self) -> None: try: data = bytearray(self.conn.recv(1024)) message = ''.join(format(x, '02x') for x in data) @@ -235,22 +247,22 @@ class TlsServer: self.shutdown.set() -def connection_tests(dut, cases, dut_ip): +def connection_tests(dut, cases, dut_ip): # type: (Dut, dict, str) -> None ip = get_host_ip4_by_dest_ip(dut_ip) set_server_cert_cn(ip) server_port = 2222 - def teardown_connection_suite(): + def teardown_connection_suite() -> None: dut.write('conn teardown 0 0') - def start_connection_case(case, desc): + def start_connection_case(case, desc): # type: (str, str) -> Any print('Starting {}: {}'.format(case, desc)) case_id = cases[case] dut.write('conn {} {} {}'.format(ip, server_port, case_id)) dut.expect('Test case:{} started'.format(case_id)) return case_id - for case in ['CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT', 'CONFIG_EXAMPLE_CONNECT_CASE_SERVER_CERT', 'CONFIG_EXAMPLE_CONNECT_CASE_SERVER_DER_CERT']: + for case in ['EXAMPLE_CONNECT_CASE_NO_CERT', 'EXAMPLE_CONNECT_CASE_SERVER_CERT', 'EXAMPLE_CONNECT_CASE_SERVER_DER_CERT']: # All these cases connect to the server with no server verification or with server only verification with TlsServer(server_port): test_nr = start_connection_case(case, 'default server - expect to connect normally') @@ -266,13 +278,13 @@ def connection_tests(dut, cases, dut_ip): if 'PEER_DID_NOT_RETURN_A_CERTIFICATE' not in s.get_last_ssl_error(): raise RuntimeError('Unexpected ssl error from the server {}'.format(s.get_last_ssl_error())) - for case in ['CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH', 'CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_KEY_PWD']: + for case in ['EXAMPLE_CONNECT_CASE_MUTUAL_AUTH', 'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_KEY_PWD']: # These cases connect to server with both server and client verification (client key might be password protected) with TlsServer(server_port, client_cert=True): test_nr = start_connection_case(case, 'server with client verification - expect to connect normally') dut.expect('MQTT_EVENT_CONNECTED: Test={}'.format(test_nr), timeout=30) - case = 'CONFIG_EXAMPLE_CONNECT_CASE_INVALID_SERVER_CERT' + case = 'EXAMPLE_CONNECT_CASE_INVALID_SERVER_CERT' with TlsServer(server_port) as s: test_nr = start_connection_case(case, 'invalid server certificate on default server - expect ssl handshake error') dut.expect('MQTT_EVENT_ERROR: Test={}'.format(test_nr), timeout=30) @@ -280,7 +292,7 @@ def connection_tests(dut, cases, dut_ip): if 'alert unknown ca' not in s.get_last_ssl_error(): raise Exception('Unexpected ssl error from the server {}'.format(s.get_last_ssl_error())) - case = 'CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_BAD_CRT' + case = 'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_BAD_CRT' with TlsServer(server_port, client_cert=True) as s: test_nr = start_connection_case(case, 'Invalid client certificate on server with client verification - expect ssl handshake error') dut.expect('MQTT_EVENT_ERROR: Test={}'.format(test_nr), timeout=30) @@ -288,13 +300,13 @@ def connection_tests(dut, cases, dut_ip): if 'CERTIFICATE_VERIFY_FAILED' not in s.get_last_ssl_error(): raise Exception('Unexpected ssl error from the server {}'.format(s.get_last_ssl_error())) - for case in ['CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT', 'CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT_ALPN']: + for case in ['EXAMPLE_CONNECT_CASE_NO_CERT', 'EXAMPLE_CONNECT_CASE_NO_CERT_ALPN']: with TlsServer(server_port, use_alpn=True) as s: test_nr = start_connection_case(case, 'server with alpn - expect connect, check resolved protocol') dut.expect('MQTT_EVENT_CONNECTED: Test={}'.format(test_nr), timeout=30) - if case == 'CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT' and s.get_negotiated_protocol() is None: + if case == 'EXAMPLE_CONNECT_CASE_NO_CERT' and s.get_negotiated_protocol() is None: print(' - client with alpn off, no negotiated protocol: OK') - elif case == 'CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT_ALPN' and s.get_negotiated_protocol() == 'mymqtt': + elif case == 'EXAMPLE_CONNECT_CASE_NO_CERT_ALPN' and s.get_negotiated_protocol() == 'mymqtt': print(' - client with alpn on, negotiated protocol resolved: OK') else: raise Exception('Unexpected negotiated protocol {}'.format(s.get_negotiated_protocol())) @@ -302,19 +314,19 @@ def connection_tests(dut, cases, dut_ip): teardown_connection_suite() -@ttfw_idf.idf_custom_test(env_tag='ethernet_router', group='test-apps') -def test_app_protocol_mqtt_publish_connect(env, extra_data): +@pytest.mark.esp32 +@pytest.mark.ethernet +def test_app_protocol_mqtt_publish_connect(dut: Dut) -> None: """ steps: 1. join AP 2. connect to uri specified in the config 3. send and receive data """ - dut1 = env.get_dut('mqtt_publish_connect_test', 'tools/test_apps/protocols/mqtt/publish_connect_test') # check and log bin size - binary_file = os.path.join(dut1.app.binary_path, 'mqtt_publish_connect_test.bin') + binary_file = os.path.join(dut.app.binary_path, 'mqtt_publish_connect_test.bin') bin_size = os.path.getsize(binary_file) - ttfw_idf.log_performance('mqtt_publish_connect_test_bin_size', '{}KB'.format(bin_size // 1024)) + logging.info('[Performance][mqtt_publish_connect_test_bin_size]: %s KB', bin_size // 1024) # Look for test case symbolic names and publish configs cases = {} @@ -322,25 +334,24 @@ def test_app_protocol_mqtt_publish_connect(env, extra_data): try: # Get connection test cases configuration: symbolic names for test cases - for case in ['CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT', - 'CONFIG_EXAMPLE_CONNECT_CASE_SERVER_CERT', - 'CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH', - 'CONFIG_EXAMPLE_CONNECT_CASE_INVALID_SERVER_CERT', - 'CONFIG_EXAMPLE_CONNECT_CASE_SERVER_DER_CERT', - 'CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_KEY_PWD', - 'CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_BAD_CRT', - 'CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT_ALPN']: - cases[case] = dut1.app.get_sdkconfig()[case] + for case in ['EXAMPLE_CONNECT_CASE_NO_CERT', + 'EXAMPLE_CONNECT_CASE_SERVER_CERT', + 'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH', + 'EXAMPLE_CONNECT_CASE_INVALID_SERVER_CERT', + 'EXAMPLE_CONNECT_CASE_SERVER_DER_CERT', + 'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_KEY_PWD', + 'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_BAD_CRT', + 'EXAMPLE_CONNECT_CASE_NO_CERT_ALPN']: + cases[case] = dut.app.sdkconfig.get(case) except Exception: print('ENV_TEST_FAILURE: Some mandatory CONNECTION test case not found in sdkconfig') raise - dut1.start_app() - esp_ip = dut1.expect(re.compile(r'IPv4 address: (\d+\.\d+\.\d+\.\d+)[^\d]'), timeout=30)[0] + esp_ip = dut.expect(r'IPv4 address: (\d+\.\d+\.\d+\.\d+)[^\d]', timeout=30).group(1).decode() print('Got IP={}'.format(esp_ip)) if not os.getenv('MQTT_SKIP_CONNECT_TEST'): - connection_tests(dut1,cases,esp_ip) + connection_tests(dut,cases,esp_ip) # # start publish tests only if enabled in the environment (for weekend tests only) @@ -349,27 +360,28 @@ def test_app_protocol_mqtt_publish_connect(env, extra_data): # Get publish test configuration try: - def get_host_port_from_dut(dut1, config_option): - value = re.search(r'\:\/\/([^:]+)\:([0-9]+)', dut1.app.get_sdkconfig()[config_option]) + @typing.no_type_check + def get_host_port_from_dut(dut, config_option): + value = re.search(r'\:\/\/([^:]+)\:([0-9]+)', dut.app.sdkconfig.get(config_option)) if value is None: return None, None return value.group(1), int(value.group(2)) - publish_cfg['publish_topic'] = dut1.app.get_sdkconfig()['CONFIG_EXAMPLE_SUBSCRIBE_TOPIC'].replace('"','') - publish_cfg['subscribe_topic'] = dut1.app.get_sdkconfig()['CONFIG_EXAMPLE_PUBLISH_TOPIC'].replace('"','') - publish_cfg['broker_host_ssl'], publish_cfg['broker_port_ssl'] = get_host_port_from_dut(dut1, 'CONFIG_EXAMPLE_BROKER_SSL_URI') - publish_cfg['broker_host_tcp'], publish_cfg['broker_port_tcp'] = get_host_port_from_dut(dut1, 'CONFIG_EXAMPLE_BROKER_TCP_URI') - publish_cfg['broker_host_ws'], publish_cfg['broker_port_ws'] = get_host_port_from_dut(dut1, 'CONFIG_EXAMPLE_BROKER_WS_URI') - publish_cfg['broker_host_wss'], publish_cfg['broker_port_wss'] = get_host_port_from_dut(dut1, 'CONFIG_EXAMPLE_BROKER_WSS_URI') + publish_cfg['publish_topic'] = dut.app.sdkconfig.get('EXAMPLE_SUBSCRIBE_TOPIC').replace('"','') + publish_cfg['subscribe_topic'] = dut.app.sdkconfig.get('EXAMPLE_PUBLISH_TOPIC').replace('"','') + publish_cfg['broker_host_ssl'], publish_cfg['broker_port_ssl'] = get_host_port_from_dut(dut, 'EXAMPLE_BROKER_SSL_URI') + publish_cfg['broker_host_tcp'], publish_cfg['broker_port_tcp'] = get_host_port_from_dut(dut, 'EXAMPLE_BROKER_TCP_URI') + publish_cfg['broker_host_ws'], publish_cfg['broker_port_ws'] = get_host_port_from_dut(dut, 'EXAMPLE_BROKER_WS_URI') + publish_cfg['broker_host_wss'], publish_cfg['broker_port_wss'] = get_host_port_from_dut(dut, 'EXAMPLE_BROKER_WSS_URI') except Exception: print('ENV_TEST_FAILURE: Some mandatory PUBLISH test case not found in sdkconfig') raise - def start_publish_case(transport, qos, repeat, published, queue): + def start_publish_case(transport, qos, repeat, published, queue): # type: (str, int, int, int, int) -> None print('Starting Publish test: transport:{}, qos:{}, nr_of_msgs:{}, msg_size:{}, enqueue:{}' .format(transport, qos, published, repeat * DEFAULT_MSG_SIZE, queue)) - with MqttPublisher(dut1, transport, qos, repeat, published, queue, publish_cfg): + with MqttPublisher(dut, transport, qos, repeat, published, queue, publish_cfg): pass # Initialize message sizes and repeat counts (if defined in the environment) @@ -378,7 +390,7 @@ def test_app_protocol_mqtt_publish_connect(env, extra_data): # Check env variable: MQTT_PUBLISH_MSG_{len|repeat}_{x} env_dict = {var:'MQTT_PUBLISH_MSG_' + var + '_' + str(i) for var in ['len', 'repeat']} if os.getenv(env_dict['len']) and os.getenv(env_dict['repeat']): - messages.append({var: int(os.getenv(env_dict[var])) for var in ['len', 'repeat']}) + messages.append({var: int(os.getenv(env_dict[var])) for var in ['len', 'repeat']}) # type: ignore continue break if not messages: # No message sizes present in the env - set defaults @@ -400,4 +412,4 @@ def test_app_protocol_mqtt_publish_connect(env, extra_data): if __name__ == '__main__': - test_app_protocol_mqtt_publish_connect(dut=ttfw_idf.ESP32QEMUDUT if sys.argv[1:] == ['qemu'] else ttfw_idf.ESP32DUT) + test_app_protocol_mqtt_publish_connect(dut=QemuDut if sys.argv[1:] == ['qemu'] else Dut)