From fba535756e796a09b1766c9fec3d933e5a791a41 Mon Sep 17 00:00:00 2001 From: Euripedes Rocha Date: Thu, 7 Sep 2023 11:35:04 +0200 Subject: [PATCH] ci(mqtt): Refactor publish connect test --- .../publish_connect_test/main/connect_test.c | 138 ++--- .../main/publish_connect_test.c | 312 +++++++++-- .../main/publish_connect_test.h | 55 ++ .../publish_connect_test/main/publish_test.c | 132 ++--- .../publish_connect_test/pytest_mqtt_app.py | 521 ++++++------------ .../pytest_mqtt_publish_app.py | 229 ++++++++ 6 files changed, 834 insertions(+), 553 deletions(-) create mode 100644 tools/test_apps/protocols/mqtt/publish_connect_test/main/publish_connect_test.h create mode 100644 tools/test_apps/protocols/mqtt/publish_connect_test/pytest_mqtt_publish_app.py diff --git a/tools/test_apps/protocols/mqtt/publish_connect_test/main/connect_test.c b/tools/test_apps/protocols/mqtt/publish_connect_test/main/connect_test.c index e8132efa31..88db721799 100644 --- a/tools/test_apps/protocols/mqtt/publish_connect_test/main/connect_test.c +++ b/tools/test_apps/protocols/mqtt/publish_connect_test/main/connect_test.c @@ -8,11 +8,11 @@ */ #include -#include "esp_netif.h" +#include "esp_console.h" #include "esp_log.h" #include "mqtt_client.h" -#include "esp_tls.h" +#include "publish_connect_test.h" #if (!defined(CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT)) || \ (!defined(CONFIG_EXAMPLE_CONNECT_CASE_SERVER_CERT)) || \ @@ -34,17 +34,23 @@ extern const uint8_t client_inv_crt[] asm("_binary_client_inv_crt_start"); extern const uint8_t client_no_pwd_key[] asm("_binary_client_no_pwd_key_start"); static const char *TAG = "connect_test"; -static esp_mqtt_client_handle_t mqtt_client = NULL; static int running_test_case = 0; static void mqtt_event_handler(void *handler_args, esp_event_base_t base, int32_t event_id, void *event_data) { + (void)handler_args; + (void)base; + (void)event_id; esp_mqtt_event_handle_t event = event_data; ESP_LOGD(TAG, "Event: %d, Test case: %d", event->event_id, running_test_case); switch (event->event_id) { + case MQTT_EVENT_BEFORE_CONNECT: + break; case MQTT_EVENT_CONNECTED: ESP_LOGI(TAG, "MQTT_EVENT_CONNECTED: Test=%d", running_test_case); break; + case MQTT_EVENT_DISCONNECTED: + break; case MQTT_EVENT_ERROR: ESP_LOGI(TAG, "MQTT_EVENT_ERROR: Test=%d", running_test_case); if (event->error_handle->error_type == MQTT_ERROR_TYPE_ESP_TLS) { @@ -61,44 +67,17 @@ static void mqtt_event_handler(void *handler_args, esp_event_base_t base, int32_ } } -static void create_client(void) +static void connect_no_certs(esp_mqtt_client_handle_t client, const char *uri) { - const esp_mqtt_client_config_t mqtt_cfg = { - .broker.address.uri = "mqtts://127.0.0.1:1234" - }; - esp_mqtt_client_handle_t client = esp_mqtt_client_init(&mqtt_cfg); - esp_mqtt_client_register_event(client, ESP_EVENT_ANY_ID, mqtt_event_handler, client); - mqtt_client = client; - esp_mqtt_client_start(client); - ESP_LOGI(TAG, "mqtt client created for connection tests"); -} - -static void destroy_client(void) -{ - if (mqtt_client) { - esp_mqtt_client_stop(mqtt_client); - esp_mqtt_client_destroy(mqtt_client); - mqtt_client = NULL; - ESP_LOGI(TAG, "mqtt client for connection tests destroyed"); - } -} - -static void connect_no_certs(const char *host, const int port) -{ - char uri[64]; - sprintf(uri, "mqtts://%s:%d", host, port); + ESP_LOGI(TAG, "Runnning :CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT"); const esp_mqtt_client_config_t mqtt_cfg = { .broker.address.uri = uri }; - esp_mqtt_set_config(mqtt_client, &mqtt_cfg); - esp_mqtt_client_disconnect(mqtt_client); - esp_mqtt_client_reconnect(mqtt_client); + esp_mqtt_set_config(client, &mqtt_cfg); } -static void connect_with_client_key_password(const char *host, const int port) +static void connect_with_client_key_password(esp_mqtt_client_handle_t client, const char *uri) { - char uri[64]; - sprintf(uri, "mqtts://%s:%d", host, port); const esp_mqtt_client_config_t mqtt_cfg = { .broker.address.uri = uri, .broker.verification.certificate = (const char *)ca_local_crt, @@ -107,15 +86,11 @@ static void connect_with_client_key_password(const char *host, const int port) .credentials.authentication.key_password = "esp32", .credentials.authentication.key_password_len = 5 }; - esp_mqtt_set_config(mqtt_client, &mqtt_cfg); - esp_mqtt_client_disconnect(mqtt_client); - esp_mqtt_client_reconnect(mqtt_client); + esp_mqtt_set_config(client, &mqtt_cfg); } -static void connect_with_server_der_cert(const char *host, const int port) +static void connect_with_server_der_cert(esp_mqtt_client_handle_t client, const char *uri) { - char uri[64]; - sprintf(uri, "mqtts://%s:%d", host, port); const esp_mqtt_client_config_t mqtt_cfg = { .broker.address.uri = uri, .broker.verification.certificate = (const char *)ca_der_start, @@ -123,123 +98,96 @@ static void connect_with_server_der_cert(const char *host, const int port) .credentials.authentication.certificate = "NULL", .credentials.authentication.key = "NULL" }; - esp_mqtt_set_config(mqtt_client, &mqtt_cfg); - esp_mqtt_client_disconnect(mqtt_client); - esp_mqtt_client_reconnect(mqtt_client); + esp_mqtt_set_config(client, &mqtt_cfg); } -static void connect_with_wrong_server_cert(const char *host, const int port) +static void connect_with_wrong_server_cert(esp_mqtt_client_handle_t client, const char *uri) { - char uri[64]; - sprintf(uri, "mqtts://%s:%d", host, port); const esp_mqtt_client_config_t mqtt_cfg = { .broker.address.uri = uri, .broker.verification.certificate = (const char *)client_pwd_crt, .credentials.authentication.certificate = "NULL", .credentials.authentication.key = "NULL" }; - esp_mqtt_set_config(mqtt_client, &mqtt_cfg); - esp_mqtt_client_disconnect(mqtt_client); - esp_mqtt_client_reconnect(mqtt_client); + esp_mqtt_set_config(client, &mqtt_cfg); } -static void connect_with_server_cert(const char *host, const int port) +static void connect_with_server_cert(esp_mqtt_client_handle_t client, const char *uri) { - char uri[64]; - sprintf(uri, "mqtts://%s:%d", host, port); const esp_mqtt_client_config_t mqtt_cfg = { .broker.address.uri = uri, .broker.verification.certificate = (const char *)ca_local_crt, }; - esp_mqtt_set_config(mqtt_client, &mqtt_cfg); - esp_mqtt_client_disconnect(mqtt_client); - esp_mqtt_client_reconnect(mqtt_client); + esp_mqtt_set_config(client, &mqtt_cfg); } -static void connect_with_server_client_certs(const char *host, const int port) +static void connect_with_server_client_certs(esp_mqtt_client_handle_t client, const char *uri) { - char uri[64]; - sprintf(uri, "mqtts://%s:%d", host, port); const esp_mqtt_client_config_t mqtt_cfg = { .broker.address.uri = uri, .broker.verification.certificate = (const char *)ca_local_crt, .credentials.authentication.certificate = (const char *)client_pwd_crt, .credentials.authentication.key = (const char *)client_no_pwd_key }; - esp_mqtt_set_config(mqtt_client, &mqtt_cfg); - esp_mqtt_client_disconnect(mqtt_client); - esp_mqtt_client_reconnect(mqtt_client); + esp_mqtt_set_config(client, &mqtt_cfg); } -static void connect_with_invalid_client_certs(const char *host, const int port) +static void connect_with_invalid_client_certs(esp_mqtt_client_handle_t client, const char *uri) { - char uri[64]; - sprintf(uri, "mqtts://%s:%d", host, port); const esp_mqtt_client_config_t mqtt_cfg = { .broker.address.uri = uri, .broker.verification.certificate = (const char *)ca_local_crt, .credentials.authentication.certificate = (const char *)client_inv_crt, .credentials.authentication.key = (const char *)client_no_pwd_key }; - esp_mqtt_set_config(mqtt_client, &mqtt_cfg); - esp_mqtt_client_disconnect(mqtt_client); - esp_mqtt_client_reconnect(mqtt_client); + esp_mqtt_set_config(client, &mqtt_cfg); } -static void connect_with_alpn(const char *host, const int port) +static void connect_with_alpn(esp_mqtt_client_handle_t client, const char *uri) { - char uri[64]; const char *alpn_protos[] = { "mymqtt", NULL }; - sprintf(uri, "mqtts://%s:%d", host, port); const esp_mqtt_client_config_t mqtt_cfg = { .broker.address.uri = uri, .broker.verification.alpn_protos = alpn_protos }; - esp_mqtt_set_config(mqtt_client, &mqtt_cfg); - esp_mqtt_client_disconnect(mqtt_client); - esp_mqtt_client_reconnect(mqtt_client); + esp_mqtt_set_config(client, &mqtt_cfg); } -void connection_test(const char *line) -{ - char test_type[32]; - char host[32]; - int port; - int test_case; +void connect_setup(command_context_t * ctx) { + esp_mqtt_client_register_event(ctx->mqtt_client, ESP_EVENT_ANY_ID, mqtt_event_handler, ctx->data); +} - sscanf(line, "%s %s %d %d", test_type, host, &port, &test_case); - if (mqtt_client == NULL) { - create_client(); - } - if (strcmp(host, "teardown") == 0) { - destroy_client();; - } - ESP_LOGI(TAG, "CASE:%d, connecting to mqtts://%s:%d ", test_case, host, port); +void connect_teardown(command_context_t * ctx) { + esp_mqtt_client_unregister_event(ctx->mqtt_client, ESP_EVENT_ANY_ID, mqtt_event_handler); +} +void connection_test(command_context_t * ctx, const char *uri, int test_case) +{ + ESP_LOGI(TAG, "CASE:%d, connecting to %s", test_case, uri); running_test_case = test_case; switch (test_case) { case CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT: - connect_no_certs(host, port); + connect_no_certs(ctx->mqtt_client, uri); break; case CONFIG_EXAMPLE_CONNECT_CASE_SERVER_CERT: - connect_with_server_cert(host, port); + connect_with_server_cert(ctx->mqtt_client, uri); break; case CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH: - connect_with_server_client_certs(host, port); + connect_with_server_client_certs(ctx->mqtt_client, uri); break; case CONFIG_EXAMPLE_CONNECT_CASE_INVALID_SERVER_CERT: - connect_with_wrong_server_cert(host, port); + connect_with_wrong_server_cert(ctx->mqtt_client, uri); break; case CONFIG_EXAMPLE_CONNECT_CASE_SERVER_DER_CERT: - connect_with_server_der_cert(host, port); + connect_with_server_der_cert(ctx->mqtt_client, uri); break; case CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_KEY_PWD: - connect_with_client_key_password(host, port); + connect_with_client_key_password(ctx->mqtt_client, uri); break; case CONFIG_EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_BAD_CRT: - connect_with_invalid_client_certs(host, port); + connect_with_invalid_client_certs(ctx->mqtt_client, uri); break; case CONFIG_EXAMPLE_CONNECT_CASE_NO_CERT_ALPN: - connect_with_alpn(host, port); + connect_with_alpn(ctx->mqtt_client, uri); break; default: ESP_LOGE(TAG, "Unknown test case %d ", test_case); diff --git a/tools/test_apps/protocols/mqtt/publish_connect_test/main/publish_connect_test.c b/tools/test_apps/protocols/mqtt/publish_connect_test/main/publish_connect_test.c index b793215f73..d92a1fa01b 100644 --- a/tools/test_apps/protocols/mqtt/publish_connect_test/main/publish_connect_test.c +++ b/tools/test_apps/protocols/mqtt/publish_connect_test/main/publish_connect_test.c @@ -8,68 +8,310 @@ */ #include #include +#include #include #include "esp_system.h" +#include "mqtt_client.h" #include "nvs_flash.h" #include "esp_event.h" #include "esp_netif.h" #include "protocol_examples_common.h" +#include "esp_console.h" +#include "argtable3/argtable3.h" #include "esp_log.h" +#include "publish_connect_test.h" static const char *TAG = "publish_connect_test"; -void connection_test(const char *line); -void publish_test(const char *line); +command_context_t command_context; +connection_args_t connection_args; +publish_setup_args_t publish_setup_args; +publish_args_t publish_args; -static void get_string(char *line, size_t size) -{ - int count = 0; - while (count < size) { - int c = fgetc(stdin); - if (c == '\n') { - line[count] = '\0'; - break; - } else if (c > 0 && c < 127) { - line[count] = c; - ++count; - } - vTaskDelay(10 / portTICK_PERIOD_MS); +#define RETURN_ON_PARSE_ERROR(args) do { \ + int nerrors = arg_parse(argc, argv, (void **) &(args)); \ + if (nerrors != 0) { \ + arg_print_errors(stderr, (args).end, argv[0]); \ + return 1; \ + }} while(0) + + +static int do_free_heap(int argc, char **argv) { + (void)argc; + (void)argv; + ESP_LOGI(TAG, "Note free memory: %d bytes", esp_get_free_heap_size()); + return 0; +} + +static int do_init(int argc, char **argv) { + (void)argc; + (void)argv; + const esp_mqtt_client_config_t mqtt_cfg = { + .broker.address.uri = "mqtts://127.0.0.1:1234", + .network.disable_auto_reconnect = true + }; + command_context.mqtt_client = esp_mqtt_client_init(&mqtt_cfg); + if(!command_context.mqtt_client) { + ESP_LOGE(TAG, "Failed to initialize client"); + return 1; } + publish_init_flags(); + ESP_LOGI(TAG, "Mqtt client initialized"); + return 0; +} + +static int do_start(int argc, char **argv) { + (void)argc; + (void)argv; + if(esp_mqtt_client_start(command_context.mqtt_client) != ESP_OK) { + ESP_LOGE(TAG, "Failed to start mqtt client task"); + return 1; + } + ESP_LOGI(TAG, "Mqtt client started"); + return 0; +} + +static int do_stop(int argc, char **argv) { + (void)argc; + (void)argv; + if(esp_mqtt_client_stop(command_context.mqtt_client) != ESP_OK) { + ESP_LOGE(TAG, "Failed to stop mqtt client task"); + return 1; + } + ESP_LOGI(TAG, "Mqtt client stoped"); + return 0; +} + +static int do_disconnect(int argc, char **argv) { + (void)argc; + (void)argv; + if(esp_mqtt_client_disconnect(command_context.mqtt_client) != ESP_OK) { + ESP_LOGE(TAG, "Failed to request disconnection"); + return 1; + } + ESP_LOGI(TAG, "Mqtt client disconnected"); + return 0; +} + +static int do_connect_setup(int argc, char **argv) { + (void)argc; + (void)argv; + connect_setup(&command_context); + return 0; +} + +static int do_connect_teardown(int argc, char **argv) { + (void)argc; + (void)argv; + connect_teardown(&command_context); + return 0; +} + +static int do_reconnect(int argc, char **argv) { + (void)argc; + (void)argv; + if(esp_mqtt_client_reconnect(command_context.mqtt_client) != ESP_OK) { + ESP_LOGE(TAG, "Failed to request reconnection"); + return 1; + } + ESP_LOGI(TAG, "Mqtt client will reconnect"); + return 0; + ; +} + +static int do_destroy(int argc, char **argv) { + (void)argc; + (void)argv; + esp_mqtt_client_destroy(command_context.mqtt_client); + command_context.mqtt_client = NULL; + ESP_LOGI(TAG, "mqtt client for tests destroyed"); + return 0; +} + +static int do_connect(int argc, char **argv) +{ + int nerrors = arg_parse(argc, argv, (void **) &connection_args); + if (nerrors != 0) { + arg_print_errors(stderr, connection_args.end, argv[0]); + return 1; + } + if(!command_context.mqtt_client) { + ESP_LOGE(TAG, "MQTT client not initialized, call init first"); + return 1; + } + connection_test(&command_context, *connection_args.uri->sval, *connection_args.test_case->ival); + return 0; +} + +static int do_publish_setup(int argc, char **argv) { + RETURN_ON_PARSE_ERROR(publish_setup_args); + if(command_context.data) { + free(command_context.data); + } + command_context.data = calloc(1, sizeof(publish_context_t)); + ((publish_context_t*)command_context.data)->pattern = strdup(*publish_setup_args.pattern->sval); + ((publish_context_t*)command_context.data)->pattern_repetitions = *publish_setup_args.pattern_repetitions->ival; + publish_setup(&command_context, *publish_setup_args.transport->sval); + return 0; +} + +static int do_publish(int argc, char **argv) { + RETURN_ON_PARSE_ERROR(publish_args); + publish_test(&command_context, publish_args.expected_to_publish->ival[0], publish_args.qos->ival[0], publish_args.enqueue->ival[0]); + return 0; +} + +static int do_publish_report(int argc, char **argv) { + (void)argc; + (void)argv; + publish_context_t * ctx = command_context.data; + ESP_LOGI(TAG,"Test Report : Messages received %d, %d expected", ctx->nr_of_msg_received, ctx->nr_of_msg_expected); + + return 0; +} +void register_common_commands(void) { + const esp_console_cmd_t init = { + .command = "init", + .help = "Run inition test\n", + .hint = NULL, + .func = &do_init, + }; + + const esp_console_cmd_t start = { + .command = "start", + .help = "Run startion test\n", + .hint = NULL, + .func = &do_start, + }; + const esp_console_cmd_t stop = { + .command = "stop", + .help = "Run stopion test\n", + .hint = NULL, + .func = &do_stop, + }; + const esp_console_cmd_t destroy = { + .command = "destroy", + .help = "Run destroyion test\n", + .hint = NULL, + .func = &do_destroy, + }; + const esp_console_cmd_t free_heap = { + .command = "free_heap", + .help = "Run destroyion test\n", + .hint = NULL, + .func = &do_free_heap, + }; + ESP_ERROR_CHECK(esp_console_cmd_register(&init)); + ESP_ERROR_CHECK(esp_console_cmd_register(&start)); + ESP_ERROR_CHECK(esp_console_cmd_register(&stop)); + ESP_ERROR_CHECK(esp_console_cmd_register(&destroy)); + ESP_ERROR_CHECK(esp_console_cmd_register(&free_heap)); +} +void register_publish_commands(void) { + publish_setup_args.transport = arg_str1(NULL,NULL,"", "Selected transport to test"); + publish_setup_args.pattern = arg_str1(NULL,NULL,"", "Message pattern repeated to build big messages"); + publish_setup_args.pattern_repetitions = arg_int1(NULL,NULL,"", "How many times the pattern is repeated"); + publish_setup_args.end = arg_end(1); + + publish_args.expected_to_publish = arg_int1(NULL,NULL,"", "How many times the pattern is repeated"); + publish_args.qos = arg_int1(NULL,NULL,"", "How many times the pattern is repeated"); + publish_args.enqueue = arg_int1(NULL,NULL,"", "How many times the pattern is repeated"); + publish_args.end = arg_end(1); + const esp_console_cmd_t publish_setup = { + .command = "publish_setup", + .help = "Run publish test\n", + .hint = NULL, + .func = &do_publish_setup, + .argtable = &publish_setup_args + }; + + const esp_console_cmd_t publish = { + .command = "publish", + .help = "Run publish test\n", + .hint = NULL, + .func = &do_publish, + .argtable = &publish_args + }; + const esp_console_cmd_t publish_report = { + .command = "publish_report", + .help = "Run destroyion test\n", + .hint = NULL, + .func = &do_publish_report, + }; + ESP_ERROR_CHECK(esp_console_cmd_register(&publish_setup)); + ESP_ERROR_CHECK(esp_console_cmd_register(&publish)); + ESP_ERROR_CHECK(esp_console_cmd_register(&publish_report)); +} +void register_connect_commands(void){ + connection_args.uri = arg_str1(NULL,NULL,"", "Broker address"); + connection_args.test_case = arg_int1(NULL, NULL, "","Selected test case"); + connection_args.end = arg_end(1); + + const esp_console_cmd_t connect = { + .command = "connect", + .help = "Run connection test\n", + .hint = NULL, + .func = &do_connect, + .argtable = &connection_args + }; + + const esp_console_cmd_t reconnect = { + .command = "reconnect", + .help = "Run reconnection test\n", + .hint = NULL, + .func = &do_reconnect, + }; + const esp_console_cmd_t connection_setup = { + .command = "connection_setup", + .help = "Run reconnection test\n", + .hint = NULL, + .func = &do_connect_setup, + }; + const esp_console_cmd_t connection_teardown = { + .command = "connection_teardown", + .help = "Run reconnection test\n", + .hint = NULL, + .func = &do_connect_teardown, + }; + const esp_console_cmd_t disconnect = { + .command = "disconnect", + .help = "Run disconnection test\n", + .hint = NULL, + .func = &do_disconnect, + }; + ESP_ERROR_CHECK(esp_console_cmd_register(&connect)); + ESP_ERROR_CHECK(esp_console_cmd_register(&disconnect)); + ESP_ERROR_CHECK(esp_console_cmd_register(&reconnect)); + ESP_ERROR_CHECK(esp_console_cmd_register(&connection_setup)); + ESP_ERROR_CHECK(esp_console_cmd_register(&connection_teardown)); } void app_main(void) { - char line[256]; + static const size_t max_line = 256; ESP_LOGI(TAG, "[APP] Free memory: %d bytes", esp_get_free_heap_size()); ESP_LOGI(TAG, "[APP] IDF version: %s", esp_get_idf_version()); esp_log_level_set("*", ESP_LOG_INFO); + esp_log_level_set("wifi", ESP_LOG_ERROR); esp_log_level_set("mqtt_client", ESP_LOG_VERBOSE); - esp_log_level_set("transport_base", ESP_LOG_VERBOSE); - esp_log_level_set("transport", ESP_LOG_VERBOSE); esp_log_level_set("outbox", ESP_LOG_VERBOSE); ESP_ERROR_CHECK(nvs_flash_init()); ESP_ERROR_CHECK(esp_netif_init()); ESP_ERROR_CHECK(esp_event_loop_create_default()); - - /* This helper function configures Wi-Fi or Ethernet, as selected in menuconfig. - * Read "Establishing Wi-Fi or Ethernet Connection" section in - * examples/protocols/README.md for more information about this function. - */ ESP_ERROR_CHECK(example_connect()); + esp_console_repl_t *repl = NULL; + esp_console_repl_config_t repl_config = ESP_CONSOLE_REPL_CONFIG_DEFAULT(); + repl_config.prompt = "mqtt>"; + repl_config.max_cmdline_length = max_line; + esp_console_register_help_command(); + register_common_commands(); + register_connect_commands(); + register_publish_commands(); - while (1) { - get_string(line, sizeof(line)); - if (memcmp(line, "conn", 4) == 0) { - // line starting with "conn" indicate connection tests - connection_test(line); - get_string(line, sizeof(line)); - continue; - } else { - publish_test(line); - } - } - + esp_console_dev_uart_config_t hw_config = ESP_CONSOLE_DEV_UART_CONFIG_DEFAULT(); + ESP_ERROR_CHECK(esp_console_new_repl_uart(&hw_config, &repl_config, &repl)); + ESP_ERROR_CHECK(esp_console_start_repl(repl)); } diff --git a/tools/test_apps/protocols/mqtt/publish_connect_test/main/publish_connect_test.h b/tools/test_apps/protocols/mqtt/publish_connect_test/main/publish_connect_test.h new file mode 100644 index 0000000000..454e82e53f --- /dev/null +++ b/tools/test_apps/protocols/mqtt/publish_connect_test/main/publish_connect_test.h @@ -0,0 +1,55 @@ +/* + * SPDX-FileCopyrightText: 2023 Espressif Systems (Shanghai) CO LTD + * + * SPDX-License-Identifier: Unlicense OR CC0-1.0 + */ +#pragma once + +#include "mqtt_client.h" + +typedef enum {NONE, TCP, SSL, WS, WSS} transport_t; + +typedef struct { + esp_mqtt_client_handle_t mqtt_client; + void * data; +} command_context_t; + +typedef struct { + transport_t selected_transport; + char *pattern; + int pattern_repetitions; + int qos; + char *expected; + size_t expected_size; + size_t nr_of_msg_received; + size_t nr_of_msg_expected; + char * received_data; +} publish_context_t ; + +typedef struct { + struct arg_str *uri; + struct arg_int *test_case; + struct arg_end *end; +} connection_args_t; + +typedef struct { + struct arg_int *expected_to_publish; + struct arg_int *qos; + struct arg_int *enqueue; + struct arg_end *end; +} publish_args_t; + +typedef struct { + struct arg_str *transport; + struct arg_str *pattern; + struct arg_int *pattern_repetitions; + struct arg_end *end; +} publish_setup_args_t; + +void publish_init_flags(void); +void publish_setup(command_context_t * ctx, char const * transport); +void publish_teardown(command_context_t * ctx); +void publish_test(command_context_t * ctx, int expect_to_publish, int qos, bool enqueue); +void connection_test(command_context_t * ctx, const char *uri, int test_case); +void connect_setup(command_context_t * ctx); +void connect_teardown(command_context_t * ctx); diff --git a/tools/test_apps/protocols/mqtt/publish_connect_test/main/publish_test.c b/tools/test_apps/protocols/mqtt/publish_connect_test/main/publish_test.c index d9ddf08e75..3bdefb61af 100644 --- a/tools/test_apps/protocols/mqtt/publish_connect_test/main/publish_test.c +++ b/tools/test_apps/protocols/mqtt/publish_connect_test/main/publish_test.c @@ -6,33 +6,25 @@ software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +#include #include #include #include #include -#include "esp_system.h" #include "freertos/FreeRTOS.h" -#include "freertos/task.h" -#include "freertos/event_groups.h" +#include +#include "esp_system.h" #include "esp_log.h" #include "mqtt_client.h" #include "sdkconfig.h" +#include "publish_connect_test.h" static const char *TAG = "publish_test"; static EventGroupHandle_t mqtt_event_group; const static int CONNECTED_BIT = BIT0; -static esp_mqtt_client_handle_t mqtt_client = NULL; - -static char *expected_data = NULL; -static char *actual_data = NULL; -static size_t expected_size = 0; -static size_t expected_published = 0; -static size_t actual_published = 0; -static int qos_test = 0; - #if CONFIG_EXAMPLE_BROKER_CERTIFICATE_OVERRIDDEN == 1 static const uint8_t mqtt_eclipseprojects_io_pem_start[] = "-----BEGIN CERTIFICATE-----\n" CONFIG_EXAMPLE_BROKER_CERTIFICATE_OVERRIDE "\n-----END CERTIFICATE-----"; #else @@ -42,6 +34,7 @@ extern const uint8_t mqtt_eclipseprojects_io_pem_end[] asm("_binary_mqtt_eclip static void mqtt_event_handler(void *handler_args, esp_event_base_t base, int32_t event_id, void *event_data) { + publish_context_t * test_data = handler_args; esp_mqtt_event_handle_t event = event_data; esp_mqtt_client_handle_t client = event->client; static int msg_id = 0; @@ -52,7 +45,7 @@ static void mqtt_event_handler(void *handler_args, esp_event_base_t base, int32_ case MQTT_EVENT_CONNECTED: ESP_LOGI(TAG, "MQTT_EVENT_CONNECTED"); xEventGroupSetBits(mqtt_event_group, CONNECTED_BIT); - msg_id = esp_mqtt_client_subscribe(client, CONFIG_EXAMPLE_SUBSCRIBE_TOPIC, qos_test); + msg_id = esp_mqtt_client_subscribe(client, CONFIG_EXAMPLE_SUBSCRIBE_TOPIC, test_data->qos); ESP_LOGI(TAG, "sent subscribe successful %s , msg_id=%d", CONFIG_EXAMPLE_SUBSCRIBE_TOPIC, msg_id); break; @@ -67,13 +60,12 @@ static void mqtt_event_handler(void *handler_args, esp_event_base_t base, int32_ ESP_LOGI(TAG, "MQTT_EVENT_UNSUBSCRIBED, msg_id=%d", event->msg_id); break; case MQTT_EVENT_PUBLISHED: - ESP_LOGI(TAG, "MQTT_EVENT_PUBLISHED, msg_id=%d", event->msg_id); + ESP_LOGD(TAG, "MQTT_EVENT_PUBLISHED, msg_id=%d", event->msg_id); break; case MQTT_EVENT_DATA: ESP_LOGI(TAG, "MQTT_EVENT_DATA"); - printf("TOPIC=%.*s\r\n", event->topic_len, event->topic); - printf("DATA=%.*s\r\n", event->data_len, event->data); - printf("ID=%d, total_len=%d, data_len=%d, current_data_offset=%d\n", event->msg_id, event->total_data_len, event->data_len, event->current_data_offset); + ESP_LOGI(TAG, "TOPIC=%.*s", event->topic_len, event->topic); + ESP_LOGI(TAG, "ID=%d, total_len=%d, data_len=%d, current_data_offset=%d", event->msg_id, event->total_data_len, event->data_len, event->current_data_offset); if (event->topic) { actual_len = event->data_len; msg_id = event->msg_id; @@ -85,24 +77,23 @@ static void mqtt_event_handler(void *handler_args, esp_event_base_t base, int32_ abort(); } } - memcpy(actual_data + event->current_data_offset, event->data, event->data_len); + memcpy(test_data->received_data + event->current_data_offset, event->data, event->data_len); if (actual_len == event->total_data_len) { - if (0 == memcmp(actual_data, expected_data, expected_size)) { - printf("OK!"); - memset(actual_data, 0, expected_size); - actual_published ++; - if (actual_published == expected_published) { - printf("Correct pattern received exactly x times\n"); + if (0 == memcmp(test_data->received_data, test_data->expected, test_data->expected_size)) { + memset(test_data->received_data, 0, test_data->expected_size); + test_data->nr_of_msg_received ++; + if (test_data->nr_of_msg_received == test_data->nr_of_msg_expected) { + ESP_LOGI(TAG, "Correct pattern received exactly x times"); ESP_LOGI(TAG, "Test finished correctly!"); } } else { - printf("FAILED!"); + ESP_LOGE(TAG, "FAILED!"); abort(); } } break; case MQTT_EVENT_ERROR: - ESP_LOGI(TAG, "MQTT_EVENT_ERROR"); + ESP_LOGE(TAG, "MQTT_EVENT_ERROR"); break; default: ESP_LOGI(TAG, "Other event id:%d", event->event_id); @@ -110,37 +101,31 @@ static void mqtt_event_handler(void *handler_args, esp_event_base_t base, int32_ } } -typedef enum {NONE, TCP, SSL, WS, WSS} transport_t; -static transport_t current_transport; void test_init(void) { - mqtt_event_group = xEventGroupCreate(); - esp_mqtt_client_config_t config = {0}; - mqtt_client = esp_mqtt_client_init(&config); - current_transport = NONE; - esp_mqtt_client_register_event(mqtt_client, ESP_EVENT_ANY_ID, mqtt_event_handler, NULL); ESP_LOGI(TAG, "[APP] Free memory: %d bytes", esp_get_free_heap_size()); } -void pattern_setup(char *pattern, int repeat) +void pattern_setup(publish_context_t * test_data) { - int pattern_size = strlen(pattern); - free(expected_data); - free(actual_data); - actual_published = 0; - expected_size = pattern_size * repeat; - expected_data = malloc(expected_size); - actual_data = malloc(expected_size); - for (int i = 0; i < repeat; i++) { - memcpy(expected_data + i * pattern_size, pattern, pattern_size); + int pattern_size = strlen(test_data->pattern); + free(test_data->expected); + free(test_data->received_data); + test_data->nr_of_msg_received = 0; + test_data->expected_size = (size_t)(pattern_size) * test_data->pattern_repetitions; + test_data->expected = malloc(test_data->expected_size); + test_data->received_data = malloc(test_data->expected_size); + for (int i = 0; i < test_data->pattern_repetitions; i++) { + memcpy(test_data->expected + (ptrdiff_t)(i * pattern_size), test_data->pattern, pattern_size); } - printf("EXPECTED STRING %.*s, SIZE:%d\n", expected_size, expected_data, expected_size); + ESP_LOGI(TAG, "EXPECTED STRING %.*s, SIZE:%d", test_data->expected_size, test_data->expected, test_data->expected_size); } -static void configure_client(char *transport) +static void configure_client(command_context_t * ctx, const char *transport) { + publish_context_t * test_data = ctx->data; ESP_LOGI(TAG, "Configuration"); transport_t selected_transport; if (0 == strcmp(transport, "tcp")) { @@ -157,7 +142,8 @@ static void configure_client(char *transport) } - if (selected_transport != current_transport) { + if (selected_transport != test_data->selected_transport) { + test_data->selected_transport = selected_transport; esp_mqtt_client_config_t config = {0}; switch (selected_transport) { case NONE: @@ -183,43 +169,45 @@ static void configure_client(char *transport) ESP_LOGI(TAG, "Set certificate"); config.broker.verification.certificate = (const char *)mqtt_eclipseprojects_io_pem_start; } - esp_mqtt_set_config(mqtt_client, &config); - + esp_mqtt_set_config(ctx->mqtt_client, &config); } - } -void publish_test(const char *line) -{ - char pattern[32]; - char transport[32]; - int repeat = 0; - int enqueue = 0; - static bool is_test_init = false; - if (!is_test_init) { - test_init(); - is_test_init = true; - } else { - esp_mqtt_client_stop(mqtt_client); - } +void publish_init_flags(void) { + mqtt_event_group = xEventGroupCreate(); +} - sscanf(line, "%s %s %d %d %d %d", transport, pattern, &repeat, &expected_published, &qos_test, &enqueue); - ESP_LOGI(TAG, "PATTERN:%s REPEATED:%d PUBLISHED:%d", pattern, repeat, expected_published); - pattern_setup(pattern, repeat); +void publish_setup(command_context_t * ctx, char const * const transport) { xEventGroupClearBits(mqtt_event_group, CONNECTED_BIT); - configure_client(transport); - esp_mqtt_client_start(mqtt_client); + publish_context_t * data = (publish_context_t*)ctx->data; + pattern_setup(data); + configure_client(ctx, transport); + esp_mqtt_client_register_event(ctx->mqtt_client, ESP_EVENT_ANY_ID, mqtt_event_handler, data); +} - ESP_LOGI(TAG, "Note free memory: %d bytes", esp_get_free_heap_size()); +void publish_teardown(command_context_t * ctx) +{ + esp_mqtt_client_unregister_event(ctx->mqtt_client, ESP_EVENT_ANY_ID, mqtt_event_handler); +} + +void publish_test(command_context_t * ctx, int expect_to_publish, int qos, bool enqueue) +{ + publish_context_t * data = (publish_context_t*)ctx->data; + data->nr_of_msg_expected = expect_to_publish; + ESP_LOGI(TAG, "PATTERN:%s REPEATED:%d PUBLISHED:%d", data->pattern, data->pattern_repetitions, data->nr_of_msg_expected); xEventGroupWaitBits(mqtt_event_group, CONNECTED_BIT, false, true, portMAX_DELAY); - for (int i = 0; i < expected_published; i++) { + for (int i = 0; i < data->nr_of_msg_expected; i++) { int msg_id; if (enqueue) { - msg_id = esp_mqtt_client_enqueue(mqtt_client, CONFIG_EXAMPLE_PUBLISH_TOPIC, expected_data, expected_size, qos_test, 0, true); + msg_id = esp_mqtt_client_enqueue(ctx->mqtt_client, CONFIG_EXAMPLE_PUBLISH_TOPIC, data->expected, data->expected_size, qos, 0, true); } else { - msg_id = esp_mqtt_client_publish(mqtt_client, CONFIG_EXAMPLE_PUBLISH_TOPIC, expected_data, expected_size, qos_test, 0); + msg_id = esp_mqtt_client_publish(ctx->mqtt_client, CONFIG_EXAMPLE_PUBLISH_TOPIC, data->expected, data->expected_size, qos, 0); + if(msg_id < 0) { + ESP_LOGE(TAG, "Failed to publish"); + break; + } } - ESP_LOGI(TAG, "[%d] Publishing...", msg_id); + ESP_LOGD(TAG, "Publishing msg_id=%d", msg_id); } } diff --git a/tools/test_apps/protocols/mqtt/publish_connect_test/pytest_mqtt_app.py b/tools/test_apps/protocols/mqtt/publish_connect_test/pytest_mqtt_app.py index 6b27adc449..237701894b 100644 --- a/tools/test_apps/protocols/mqtt/publish_connect_test/pytest_mqtt_app.py +++ b/tools/test_apps/protocols/mqtt/publish_connect_test/pytest_mqtt_app.py @@ -1,31 +1,21 @@ # SPDX-FileCopyrightText: 2022-2023 Espressif Systems (Shanghai) CO LTD # SPDX-License-Identifier: Unlicense OR CC0-1.0 -from __future__ import print_function, unicode_literals -import difflib +import contextlib import logging import os -import random import re -import select -import socket +import socketserver import ssl -import string import subprocess -import sys -import time -import typing -from itertools import count -from threading import Event, Lock, Thread -from typing import Any +from threading import Thread +from typing import Any, Callable, Dict, Optional -import paho.mqtt.client as mqtt import pytest from common_test_methods import get_host_ip4_by_dest_ip from pytest_embedded import Dut -from pytest_embedded_qemu.dut import QemuDut -DEFAULT_MSG_SIZE = 16 +SERVER_PORT = 2222 def _path(f): # type: (str) -> str @@ -42,313 +32,98 @@ def set_server_cert_cn(ip): # type: (str) -> None raise RuntimeError('openssl command {} failed'.format(args)) -# Publisher class creating a python client to send/receive published data from esp-mqtt client -class MqttPublisher: - event_client_connected = Event() - event_client_got_all = Event() - expected_data = '' - published = 0 - sample = '' +class MQTTHandler(socketserver.StreamRequestHandler): - def __init__(self, dut, transport, - qos, repeat, published, queue, publish_cfg, log_details=False): # type: (MqttPublisher, Dut, str, int, int, int, int, dict, bool) -> None - # instance variables used as parameters of the publish test - self.event_stop_client = Event() - self.sample_string = ''.join(random.choice(string.ascii_uppercase + string.ascii_lowercase + string.digits) for _ in range(DEFAULT_MSG_SIZE)) - self.client = None - self.dut = dut - self.log_details = log_details - self.repeat = repeat - self.publish_cfg = publish_cfg - self.publish_cfg['qos'] = qos - self.publish_cfg['queue'] = queue - self.publish_cfg['transport'] = transport - self.lock = Lock() - # static variables used to pass options to and from static callbacks of paho-mqtt client - MqttPublisher.event_client_connected = Event() - MqttPublisher.event_client_got_all = Event() - MqttPublisher.published = published - MqttPublisher.event_client_connected.clear() - MqttPublisher.event_client_got_all.clear() - MqttPublisher.expected_data = f'{self.sample_string * self.repeat}' - MqttPublisher.sample = self.sample_string - - def print_details(self, text): # type: (str) -> None - if self.log_details: - logging.info(text) - - def mqtt_client_task(self, client, lock): # type: (MqttPublisher, mqtt.Client, Lock) -> None - while not self.event_stop_client.is_set(): - with lock: - client.loop() - time.sleep(0.001) # yield to other threads - - # The callback for when the client receives a CONNACK response from the server (needs to be static) - @staticmethod - def on_connect(_client, _userdata, _flags, _rc): # type: (mqtt.Client, tuple, bool, str) -> None - MqttPublisher.event_client_connected.set() - - # The callback for when a PUBLISH message is received from the server (needs to be static) - @staticmethod - def on_message(client, userdata, msg): # type: (mqtt.Client, int, mqtt.client.MQTTMessage) -> None - payload = msg.payload.decode('utf-8') - if payload == MqttPublisher.expected_data: - userdata += 1 - client.user_data_set(userdata) - if userdata == MqttPublisher.published: - MqttPublisher.event_client_got_all.set() - else: - differences = len(list(filter(lambda data: data[0] != data[1], zip(payload, MqttPublisher.expected_data)))) - logging.error(f'Payload differ in {differences} positions from expected data. received size: {len(payload)} expected size:' - f'{len(MqttPublisher.expected_data)}') - logging.info(f'Repetitions: {payload.count(MqttPublisher.sample)}') - logging.info(f'Pattern: {MqttPublisher.sample}') - logging.info(f'First : {payload[:DEFAULT_MSG_SIZE]}') - logging.info(f'Last : {payload[-DEFAULT_MSG_SIZE:]}') - matcher = difflib.SequenceMatcher(a=payload, b=MqttPublisher.expected_data) - for match in matcher.get_matching_blocks(): - logging.info(f'Match: {match}') - - def __enter__(self): # type: (MqttPublisher) -> None - - qos = self.publish_cfg['qos'] - queue = self.publish_cfg['queue'] - transport = self.publish_cfg['transport'] - broker_host = self.publish_cfg['broker_host_' + transport] - broker_port = self.publish_cfg['broker_port_' + transport] - - # Start the test - self.print_details(f'PUBLISH TEST: transport:{transport}, qos:{qos}, sequence:{MqttPublisher.published},' - f"enqueue:{queue}, sample msg:'{MqttPublisher.expected_data}'") - - try: - if transport in ['ws', 'wss']: - self.client = mqtt.Client(transport='websockets') + def handle(self) -> None: + logging.info(' - connection from: {}'.format(self.client_address)) + data = bytearray(self.request.recv(1024)) + message = ''.join(format(x, '02x') for x in data) + if message[0:16] == '101800044d515454': + if self.server.refuse_connection is False: # type: ignore + logging.info(' - received mqtt connect, sending ACK') + self.request.send(bytearray.fromhex('20020000')) else: - self.client = mqtt.Client() - assert self.client is not None - self.client.on_connect = MqttPublisher.on_connect - self.client.on_message = MqttPublisher.on_message - self.client.user_data_set(0) - - if transport in ['ssl', 'wss']: - self.client.tls_set(None, None, None, cert_reqs=ssl.CERT_NONE, tls_version=ssl.PROTOCOL_TLSv1_2, ciphers=None) - self.client.tls_insecure_set(True) - self.print_details('Connecting...') - self.client.connect(broker_host, broker_port, 60) - except Exception: - self.print_details(f'ENV_TEST_FAILURE: Unexpected error while connecting to broker {broker_host}') - raise - # Starting a py-client in a separate thread - thread1 = Thread(target=self.mqtt_client_task, args=(self.client, self.lock)) - thread1.start() - self.print_details('Connecting py-client to broker {}:{}...'.format(broker_host, broker_port)) - if not MqttPublisher.event_client_connected.wait(timeout=30): - raise ValueError(f'ENV_TEST_FAILURE: Test script cannot connect to broker: {broker_host}') - with self.lock: - self.client.subscribe(self.publish_cfg['subscribe_topic'], qos) - self.dut.write(f'{transport} {self.sample_string} {self.repeat} {MqttPublisher.published} {qos} {queue}') - try: - # waiting till subscribed to defined topic - self.dut.expect(re.compile(rb'MQTT_EVENT_SUBSCRIBED'), timeout=60) - for _ in range(MqttPublisher.published): - with self.lock: - self.client.publish(self.publish_cfg['publish_topic'], self.sample_string * self.repeat, qos) - self.print_details('Publishing...') - self.print_details('Checking esp-client received msg published from py-client...') - self.dut.expect(re.compile(rb'Correct pattern received exactly x times'), timeout=60) - if not MqttPublisher.event_client_got_all.wait(timeout=60): - raise ValueError('Not all data received from ESP32: {}'.format(transport)) - logging.info(' - all data received from ESP32') - finally: - self.event_stop_client.set() - thread1.join() - - def __exit__(self, exc_type, exc_value, traceback): # type: (MqttPublisher, str, str, dict) -> None - assert self.client is not None - self.client.disconnect() - self.event_stop_client.clear() + # injecting connection not authorized error + logging.info(' - received mqtt connect, sending NAK') + self.request.send(bytearray.fromhex('20020005')) + else: + raise Exception(' - error process_mqtt_connect unexpected connect received: {}'.format(message)) # Simple server for mqtt over TLS connection -class TlsServer: +class TlsServer(socketserver.TCPServer): + timeout = 30.0 + allow_reuse_address = True + allow_reuse_port = True - def __init__(self, port, client_cert=False, refuse_connection=False, use_alpn=False): # type: (TlsServer, int, bool, bool, bool) -> None - self.port = port - self.socket = socket.socket() - self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - self.socket.settimeout(10.0) - self.shutdown = Event() - self.client_cert = client_cert + def __init__(self, + port:int = SERVER_PORT, + ServerHandler: Callable[[Any, Any, Any], socketserver.BaseRequestHandler] = MQTTHandler, + client_cert:bool=False, + refuse_connection:bool=False, + use_alpn:bool=False): self.refuse_connection = refuse_connection - self.use_alpn = use_alpn - self.conn = socket.socket() + self.context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) self.ssl_error = '' + self.alpn_protocol: Optional[str] = None + if client_cert: + self.context.verify_mode = ssl.CERT_REQUIRED + self.context.load_verify_locations(cafile=_path('ca.crt')) + self.context.load_cert_chain(certfile=_path('srv.crt'), keyfile=_path('server.key')) + if use_alpn: + self.context.set_alpn_protocols(['mymqtt', 'http/1.1']) + self.server_thread = Thread(target=self.serve_forever) + super().__init__(('',port), ServerHandler) - def __enter__(self): # type: (TlsServer) -> TlsServer - try: - self.socket.bind(('', self.port)) - except socket.error as e: - print('Bind failed:{}'.format(e)) - raise + def server_activate(self) -> None: + self.socket = self.context.wrap_socket(self.socket, server_side=True) + super().server_activate() - self.socket.listen(1) - self.server_thread = Thread(target=self.run_server) + def __enter__(self): # type: ignore self.server_thread.start() - return self - def __exit__(self, exc_type, exc_value, traceback): # type: (TlsServer, str, str, str) -> None - self.shutdown.set() - self.server_thread.join() - self.socket.close() - if (self.conn is not None): - self.conn.close() + def server_close(self) -> None: + try: + self.shutdown() + self.server_thread.join() + super().server_close() + except RuntimeError as e: + logging.exception(e) - def get_last_ssl_error(self): # type: (TlsServer) -> str + # We need to override it here to capture ssl.SSLError + # The implementation is a slightly modified version from cpython original code. + def _handle_request_noblock(self) -> None: + try: + request, client_address = self.get_request() + self.alpn_protocol = request.selected_alpn_protocol() # type: ignore + except ssl.SSLError as e: + self.ssl_error = e.reason + return + except OSError: + return + if self.verify_request(request, client_address): + try: + self.process_request(request, client_address) + except Exception: + self.handle_error(request, client_address) + self.shutdown_request(request) + except: # noqa: E722 + self.shutdown_request(request) + raise + else: + self.shutdown_request(request) + + def last_ssl_error(self): # type: (TlsServer) -> str return self.ssl_error - @typing.no_type_check - def get_negotiated_protocol(self): - return self.negotiated_protocol - - def run_server(self) -> None: - context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) - if self.client_cert: - context.verify_mode = ssl.CERT_REQUIRED - context.load_verify_locations(cafile=_path('ca.crt')) - context.load_cert_chain(certfile=_path('srv.crt'), keyfile=_path('server.key')) - if self.use_alpn: - context.set_alpn_protocols(['mymqtt', 'http/1.1']) - self.socket = context.wrap_socket(self.socket, server_side=True) - try: - self.conn, address = self.socket.accept() # accept new connection - self.socket.settimeout(10.0) - print(' - connection from: {}'.format(address)) - if self.use_alpn: - self.negotiated_protocol = self.conn.selected_alpn_protocol() - print(' - negotiated_protocol: {}'.format(self.negotiated_protocol)) - self.handle_conn() - except ssl.SSLError as e: - self.ssl_error = str(e) - print(' - SSLError: {}'.format(str(e))) - - def handle_conn(self) -> None: - while not self.shutdown.is_set(): - r,w,e = select.select([self.conn], [], [], 1) - try: - if self.conn in r: - self.process_mqtt_connect() - - except socket.error as err: - print(' - error: {}'.format(err)) - raise - - def process_mqtt_connect(self) -> None: - try: - data = bytearray(self.conn.recv(1024)) - message = ''.join(format(x, '02x') for x in data) - if message[0:16] == '101800044d515454': - if self.refuse_connection is False: - print(' - received mqtt connect, sending ACK') - self.conn.send(bytearray.fromhex('20020000')) - else: - # injecting connection not authorized error - print(' - received mqtt connect, sending NAK') - self.conn.send(bytearray.fromhex('20020005')) - else: - raise Exception(' - error process_mqtt_connect unexpected connect received: {}'.format(message)) - finally: - # stop the server after the connect message in happy flow, or if any exception occur - self.shutdown.set() + def get_negotiated_protocol(self) -> Optional[str]: + return self.alpn_protocol -def connection_tests(dut, cases, dut_ip): # type: (Dut, dict, str) -> None - ip = get_host_ip4_by_dest_ip(dut_ip) - set_server_cert_cn(ip) - server_port = 2222 - - def teardown_connection_suite() -> None: - dut.write('conn teardown 0 0\n') - - def start_connection_case(case, desc): # type: (str, str) -> Any - print('Starting {}: {}'.format(case, desc)) - case_id = cases[case] - dut.write('conn {} {} {}\n'.format(ip, server_port, case_id)) - dut.expect('Test case:{} started'.format(case_id)) - return case_id - - for case in ['EXAMPLE_CONNECT_CASE_NO_CERT', 'EXAMPLE_CONNECT_CASE_SERVER_CERT', 'EXAMPLE_CONNECT_CASE_SERVER_DER_CERT']: - # All these cases connect to the server with no server verification or with server only verification - with TlsServer(server_port): - test_nr = start_connection_case(case, 'default server - expect to connect normally') - dut.expect('MQTT_EVENT_CONNECTED: Test={}'.format(test_nr), timeout=30) - with TlsServer(server_port, refuse_connection=True): - test_nr = start_connection_case(case, 'ssl shall connect, but mqtt sends connect refusal') - dut.expect('MQTT_EVENT_ERROR: Test={}'.format(test_nr), timeout=30) - dut.expect('MQTT ERROR: 0x5') # expecting 0x5 ... connection not authorized error - with TlsServer(server_port, client_cert=True) as s: - test_nr = start_connection_case(case, 'server with client verification - handshake error since client presents no client certificate') - dut.expect('MQTT_EVENT_ERROR: Test={}'.format(test_nr), timeout=30) - dut.expect('ESP-TLS ERROR: ESP_ERR_MBEDTLS_SSL_HANDSHAKE_FAILED') # expect ... handshake error (PEER_DID_NOT_RETURN_A_CERTIFICATE) - if 'PEER_DID_NOT_RETURN_A_CERTIFICATE' not in s.get_last_ssl_error(): - raise RuntimeError('Unexpected ssl error from the server {}'.format(s.get_last_ssl_error())) - - for case in ['EXAMPLE_CONNECT_CASE_MUTUAL_AUTH', 'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_KEY_PWD']: - # These cases connect to server with both server and client verification (client key might be password protected) - with TlsServer(server_port, client_cert=True): - test_nr = start_connection_case(case, 'server with client verification - expect to connect normally') - dut.expect('MQTT_EVENT_CONNECTED: Test={}'.format(test_nr), timeout=30) - - case = 'EXAMPLE_CONNECT_CASE_INVALID_SERVER_CERT' - with TlsServer(server_port) as s: - test_nr = start_connection_case(case, 'invalid server certificate on default server - expect ssl handshake error') - dut.expect('MQTT_EVENT_ERROR: Test={}'.format(test_nr), timeout=30) - dut.expect('ESP-TLS ERROR: ESP_ERR_MBEDTLS_SSL_HANDSHAKE_FAILED') # expect ... handshake error (TLSV1_ALERT_UNKNOWN_CA) - if 'alert unknown ca' not in s.get_last_ssl_error(): - raise Exception('Unexpected ssl error from the server {}'.format(s.get_last_ssl_error())) - - case = 'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_BAD_CRT' - with TlsServer(server_port, client_cert=True) as s: - test_nr = start_connection_case(case, 'Invalid client certificate on server with client verification - expect ssl handshake error') - dut.expect('MQTT_EVENT_ERROR: Test={}'.format(test_nr), timeout=30) - dut.expect('ESP-TLS ERROR: ESP_ERR_MBEDTLS_SSL_HANDSHAKE_FAILED') # expect ... handshake error (CERTIFICATE_VERIFY_FAILED) - if 'CERTIFICATE_VERIFY_FAILED' not in s.get_last_ssl_error(): - raise Exception('Unexpected ssl error from the server {}'.format(s.get_last_ssl_error())) - - for case in ['EXAMPLE_CONNECT_CASE_NO_CERT', 'EXAMPLE_CONNECT_CASE_NO_CERT_ALPN']: - with TlsServer(server_port, use_alpn=True) as s: - test_nr = start_connection_case(case, 'server with alpn - expect connect, check resolved protocol') - dut.expect('MQTT_EVENT_CONNECTED: Test={}'.format(test_nr), timeout=30) - if case == 'EXAMPLE_CONNECT_CASE_NO_CERT' and s.get_negotiated_protocol() is None: - print(' - client with alpn off, no negotiated protocol: OK') - elif case == 'EXAMPLE_CONNECT_CASE_NO_CERT_ALPN' and s.get_negotiated_protocol() == 'mymqtt': - print(' - client with alpn on, negotiated protocol resolved: OK') - else: - raise Exception('Unexpected negotiated protocol {}'.format(s.get_negotiated_protocol())) - - teardown_connection_suite() - - -@pytest.mark.esp32 -@pytest.mark.ethernet -def test_app_protocol_mqtt_publish_connect(dut: Dut) -> None: - """ - steps: - 1. join AP - 2. connect to uri specified in the config - 3. send and receive data - """ - # check and log bin size - binary_file = os.path.join(dut.app.binary_path, 'mqtt_publish_connect_test.bin') - bin_size = os.path.getsize(binary_file) - logging.info('[Performance][mqtt_publish_connect_test_bin_size]: %s KB', bin_size // 1024) - - # Look for test case symbolic names and publish configs +def get_test_cases(dut: Dut) -> Any: cases = {} - publish_cfg = {} try: - # Get connection test cases configuration: symbolic names for test cases for case in ['EXAMPLE_CONNECT_CASE_NO_CERT', 'EXAMPLE_CONNECT_CASE_SERVER_CERT', @@ -360,63 +135,107 @@ def test_app_protocol_mqtt_publish_connect(dut: Dut) -> None: 'EXAMPLE_CONNECT_CASE_NO_CERT_ALPN']: cases[case] = dut.app.sdkconfig.get(case) except Exception: - print('ENV_TEST_FAILURE: Some mandatory CONNECTION test case not found in sdkconfig') + logging.error('ENV_TEST_FAILURE: Some mandatory CONNECTION test case not found in sdkconfig') raise + return cases - esp_ip = dut.expect(r'IPv4 address: (\d+\.\d+\.\d+\.\d+)[^\d]', timeout=30).group(1).decode() - print('Got IP={}'.format(esp_ip)) - connection_tests(dut,cases,esp_ip) +def get_dut_ip(dut: Dut) -> Any: + dut_ip = dut.expect(r'IPv4 address: (\d+\.\d+\.\d+\.\d+)[^\d]', timeout=30).group(1).decode() + logging.info('Got IP={}'.format(dut_ip)) + return get_host_ip4_by_dest_ip(dut_ip) - # Get publish test configuration + +@contextlib.contextmanager +def connect_dut(dut: Dut, uri:str, case_id:int) -> Any: + dut.write('connection_setup') + dut.write(f'connect {uri} {case_id}') + dut.expect(f'Test case:{case_id} started') + dut.write('reconnect') + yield + dut.write('connection_teardown') + dut.write('disconnect') + + +def run_cases(dut:Dut, uri:str, cases:Dict[str, int]) -> None: try: - @typing.no_type_check - def get_host_port_from_dut(dut, config_option): - value = re.search(r'\:\/\/([^:]+)\:([0-9]+)', dut.app.sdkconfig.get(config_option)) - if value is None: - return None, None - return value.group(1), int(value.group(2)) + dut.write('init') + dut.write(f'start') + dut.write(f'disconnect') + for case in ['EXAMPLE_CONNECT_CASE_NO_CERT', 'EXAMPLE_CONNECT_CASE_SERVER_CERT', 'EXAMPLE_CONNECT_CASE_SERVER_DER_CERT']: + # All these cases connect to the server with no server verification or with server only verification + with TlsServer(), connect_dut(dut, uri, cases[case]): + logging.info(f'Running {case}: default server - expect to connect normally') + dut.expect(f'MQTT_EVENT_CONNECTED: Test={cases[case]}', timeout=30) + with TlsServer(refuse_connection=True), connect_dut(dut, uri, cases[case]): + logging.info(f'Running {case}: ssl shall connect, but mqtt sends connect refusal') + dut.expect(f'MQTT_EVENT_ERROR: Test={cases[case]}', timeout=30) + dut.expect('MQTT ERROR: 0x5') # expecting 0x5 ... connection not authorized error + with TlsServer(client_cert=True) as server, connect_dut(dut, uri, cases[case]): + logging.info(f'Running {case}: server with client verification - handshake error since client presents no client certificate') + dut.expect(f'MQTT_EVENT_ERROR: Test={cases[case]}', timeout=30) + dut.expect('ESP-TLS ERROR: ESP_ERR_MBEDTLS_SSL_HANDSHAKE_FAILED') # expect ... handshake error (PEER_DID_NOT_RETURN_A_CERTIFICATE) + assert 'PEER_DID_NOT_RETURN_A_CERTIFICATE' in server.last_ssl_error() - publish_cfg['publish_topic'] = dut.app.sdkconfig.get('EXAMPLE_SUBSCRIBE_TOPIC').replace('"','') - publish_cfg['subscribe_topic'] = dut.app.sdkconfig.get('EXAMPLE_PUBLISH_TOPIC').replace('"','') - publish_cfg['broker_host_ssl'], publish_cfg['broker_port_ssl'] = get_host_port_from_dut(dut, 'EXAMPLE_BROKER_SSL_URI') - publish_cfg['broker_host_tcp'], publish_cfg['broker_port_tcp'] = get_host_port_from_dut(dut, 'EXAMPLE_BROKER_TCP_URI') - publish_cfg['broker_host_ws'], publish_cfg['broker_port_ws'] = get_host_port_from_dut(dut, 'EXAMPLE_BROKER_WS_URI') - publish_cfg['broker_host_wss'], publish_cfg['broker_port_wss'] = get_host_port_from_dut(dut, 'EXAMPLE_BROKER_WSS_URI') + for case in ['EXAMPLE_CONNECT_CASE_MUTUAL_AUTH', 'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_KEY_PWD']: + # These cases connect to server with both server and client verification (client key might be password protected) + with TlsServer(client_cert=True), connect_dut(dut, uri, cases[case]): + logging.info(f'Running {case}: server with client verification - expect to connect normally') + dut.expect(f'MQTT_EVENT_CONNECTED: Test={cases[case]}', timeout=30) - except Exception: - logging.error('ENV_TEST_FAILURE: Some mandatory PUBLISH test case not found in sdkconfig') - raise + case = 'EXAMPLE_CONNECT_CASE_INVALID_SERVER_CERT' + with TlsServer() as s, connect_dut(dut, uri, cases[case]): + logging.info(f'Running {case}: invalid server certificate on default server - expect ssl handshake error') + dut.expect(f'MQTT_EVENT_ERROR: Test={cases[case]}', timeout=30) + dut.expect('ESP-TLS ERROR: ESP_ERR_MBEDTLS_SSL_HANDSHAKE_FAILED') # expect ... handshake error (TLSV1_ALERT_UNKNOWN_CA) + if re.match('.*alert.*unknown.*ca',s.last_ssl_error(), flags=re.I) is None: + raise Exception(f'Unexpected ssl error from the server: {s.last_ssl_error()}') - # Initialize message sizes and repeat counts (if defined in the environment) - messages = [] - for i in count(0): - # Check env variable: MQTT_PUBLISH_MSG_{len|repeat}_{x} - env_dict = {var:'MQTT_PUBLISH_MSG_' + var + '_' + str(i) for var in ['len', 'repeat']} - if os.getenv(env_dict['len']) and os.getenv(env_dict['repeat']): - messages.append({var: int(os.getenv(env_dict[var])) for var in ['len', 'repeat']}) # type: ignore - continue - break - if not messages: # No message sizes present in the env - set defaults - messages = [{'len':0, 'repeat':5}, # zero-sized messages - {'len':2, 'repeat':10}, # short messages - {'len':200, 'repeat':3}, # long messages - {'len':20, 'repeat':50} # many medium sized - ] + case = 'EXAMPLE_CONNECT_CASE_MUTUAL_AUTH_BAD_CRT' + with TlsServer(client_cert=True) as s, connect_dut(dut, uri, cases[case]): + logging.info(f'Running {case}: Invalid client certificate on server with client verification - expect ssl handshake error') + dut.expect(f'MQTT_EVENT_ERROR: Test={cases[case]}', timeout=30) + dut.expect('ESP-TLS ERROR: ESP_ERR_MBEDTLS_SSL_HANDSHAKE_FAILED') # expect ... handshake error (CERTIFICATE_VERIFY_FAILED) + if 'CERTIFICATE_VERIFY_FAILED' not in s.last_ssl_error(): + raise Exception('Unexpected ssl error from the server {}'.format(s.last_ssl_error())) - # Iterate over all publish message properties - for transport in ['tcp', 'ssl', 'ws', 'wss']: - if publish_cfg['broker_host_' + transport] is None: - print('Skipping transport: {}...'.format(transport)) - continue - for enqueue in [0, 1]: - for qos in [0, 1, 2]: - for msg in messages: - logging.info(f'Starting Publish test: transport:{transport}, qos:{qos}, nr_of_msgs:{msg["repeat"]},' - f'msg_size:{msg["len"] * DEFAULT_MSG_SIZE}, enqueue:{enqueue}') - with MqttPublisher(dut, transport, qos, msg['len'], msg['repeat'], enqueue, publish_cfg): - pass + for case in ['EXAMPLE_CONNECT_CASE_NO_CERT', 'EXAMPLE_CONNECT_CASE_NO_CERT_ALPN']: + with TlsServer(use_alpn=True) as s, connect_dut(dut, uri, cases[case]): + logging.info(f'Running {case}: server with alpn - expect connect, check resolved protocol') + dut.expect(f'MQTT_EVENT_CONNECTED: Test={cases[case]}', timeout=30) + if case == 'EXAMPLE_CONNECT_CASE_NO_CERT': + assert s.get_negotiated_protocol() is None + elif case == 'EXAMPLE_CONNECT_CASE_NO_CERT_ALPN': + assert s.get_negotiated_protocol() == 'mymqtt' + else: + assert False, f'Unexpected negotiated protocol {s.get_negotiated_protocol()}' + finally: + dut.write('stop') + dut.write('destroy') -if __name__ == '__main__': - test_app_protocol_mqtt_publish_connect(dut=QemuDut if sys.argv[1:] == ['qemu'] else Dut) +@pytest.mark.esp32 +@pytest.mark.ethernet +def test_mqtt_connect( + dut: Dut, + log_performance: Callable[[str, object], None], +) -> None: + """ + steps: + 1. join AP + 2. connect to uri specified in the config + 3. send and receive data + """ + # check and log bin size + binary_file = os.path.join(dut.app.binary_path, 'mqtt_publish_connect_test.bin') + bin_size = os.path.getsize(binary_file) + log_performance('mqtt_publish_connect_test_bin_size', f'{bin_size // 1024} KB') + + ip = get_dut_ip(dut) + set_server_cert_cn(ip) + uri = f'mqtts://{ip}:{SERVER_PORT}' + + # Look for test case symbolic names and publish configs + cases = get_test_cases(dut) + dut.expect_exact('mqtt>', timeout=30) + run_cases(dut, uri, cases) diff --git a/tools/test_apps/protocols/mqtt/publish_connect_test/pytest_mqtt_publish_app.py b/tools/test_apps/protocols/mqtt/publish_connect_test/pytest_mqtt_publish_app.py new file mode 100644 index 0000000000..f6b77f7a26 --- /dev/null +++ b/tools/test_apps/protocols/mqtt/publish_connect_test/pytest_mqtt_publish_app.py @@ -0,0 +1,229 @@ +# SPDX-FileCopyrightText: 2023 Espressif Systems (Shanghai) CO LTD +# SPDX-License-Identifier: Unlicense OR CC0-1.0 + +import contextlib +import difflib +import logging +import os +import random +import re +import ssl +import string +from itertools import count, product +from threading import Event, Lock +from typing import Any, Dict, List, Tuple, no_type_check + +import paho.mqtt.client as mqtt +import pexpect +import pytest +from pytest_embedded import Dut + +DEFAULT_MSG_SIZE = 16 + + +# Publisher class creating a python client to send/receive published data from esp-mqtt client +class MqttPublisher(mqtt.Client): + + def __init__(self, repeat, published, publish_cfg, log_details=False): # type: (MqttPublisher, int, int, dict, bool) -> None + self.sample_string = ''.join(random.choice(string.ascii_uppercase + string.ascii_lowercase + string.digits) for _ in range(DEFAULT_MSG_SIZE)) + self.log_details = log_details + self.repeat = repeat + self.publish_cfg = publish_cfg + self.expected_data = f'{self.sample_string * self.repeat}' + self.published = published + self.received = 0 + self.lock = Lock() + self.event_client_connected = Event() + self.event_client_got_all = Event() + transport = 'websockets' if self.publish_cfg['transport'] in ['ws', 'wss'] else 'tcp' + super().__init__('MqttTestRunner', userdata=0, transport=transport) + + def print_details(self, text): # type: (str) -> None + if self.log_details: + logging.info(text) + + def on_connect(self, mqttc: Any, obj: Any, flags: Any, rc:int) -> None: + self.event_client_connected.set() + + def on_connect_fail(self, mqttc: Any, obj: Any) -> None: + logging.error('Connect failed') + + def on_message(self, mqttc: Any, userdata: Any, msg: mqtt.MQTTMessage) -> None: + payload = msg.payload.decode('utf-8') + if payload == self.expected_data: + userdata += 1 + self.user_data_set(userdata) + self.received = userdata + if userdata == self.published: + self.event_client_got_all.set() + else: + differences = len(list(filter(lambda data: data[0] != data[1], zip(payload, self.expected_data)))) + logging.error(f'Payload differ in {differences} positions from expected data. received size: {len(payload)} expected size:' + f'{len(self.expected_data)}') + logging.info(f'Repetitions: {payload.count(self.sample_string)}') + logging.info(f'Pattern: {self.sample_string}') + logging.info(f'First : {payload[:DEFAULT_MSG_SIZE]}') + logging.info(f'Last : {payload[-DEFAULT_MSG_SIZE:]}') + matcher = difflib.SequenceMatcher(a=payload, b=self.expected_data) + for match in matcher.get_matching_blocks(): + logging.info(f'Match: {match}') + + def __enter__(self) -> Any: + qos = self.publish_cfg['qos'] + broker_host = self.publish_cfg['broker_host_' + self.publish_cfg['transport']] + broker_port = self.publish_cfg['broker_port_' + self.publish_cfg['transport']] + + try: + self.print_details('Connecting...') + if self.publish_cfg['transport'] in ['ssl', 'wss']: + self.tls_set(None, None, None, cert_reqs=ssl.CERT_NONE, tls_version=ssl.PROTOCOL_TLSv1_2, ciphers=None) + self.tls_insecure_set(True) + self.event_client_connected.clear() + self.loop_start() + self.connect(broker_host, broker_port, 60) + except Exception: + self.print_details(f'ENV_TEST_FAILURE: Unexpected error while connecting to broker {broker_host}') + raise + self.print_details(f'Connecting py-client to broker {broker_host}:{broker_port}...') + + if not self.event_client_connected.wait(timeout=30): + raise ValueError(f'ENV_TEST_FAILURE: Test script cannot connect to broker: {broker_host}') + self.event_client_got_all.clear() + self.subscribe(self.publish_cfg['subscribe_topic'], qos) + return self + + def __exit__(self, exc_type, exc_value, traceback): # type: (MqttPublisher, str, str, dict) -> None + self.disconnect() + self.loop_stop() + + +def get_configurations(dut: Dut) -> Dict[str,Any]: + publish_cfg = {} + try: + @no_type_check + def get_broker_from_dut(dut, config_option): + # logging.info('Option:', config_option, dut.app.sdkconfig.get(config_option)) + value = re.search(r'\:\/\/([^:]+)\:([0-9]+)', dut.app.sdkconfig.get(config_option)) + if value is None: + return None, None + return value.group(1), int(value.group(2)) + # Get publish test configuration + publish_cfg['publish_topic'] = dut.app.sdkconfig.get('EXAMPLE_SUBSCRIBE_TOPIC').replace('"','') + publish_cfg['subscribe_topic'] = dut.app.sdkconfig.get('EXAMPLE_PUBLISH_TOPIC').replace('"','') + publish_cfg['broker_host_ssl'], publish_cfg['broker_port_ssl'] = get_broker_from_dut(dut, 'EXAMPLE_BROKER_SSL_URI') + publish_cfg['broker_host_tcp'], publish_cfg['broker_port_tcp'] = get_broker_from_dut(dut, 'EXAMPLE_BROKER_TCP_URI') + publish_cfg['broker_host_ws'], publish_cfg['broker_port_ws'] = get_broker_from_dut(dut, 'EXAMPLE_BROKER_WS_URI') + publish_cfg['broker_host_wss'], publish_cfg['broker_port_wss'] = get_broker_from_dut(dut, 'EXAMPLE_BROKER_WSS_URI') + + except Exception: + logging.info('ENV_TEST_FAILURE: Some mandatory PUBLISH test case not found in sdkconfig') + raise + logging.info(f'configuration: {publish_cfg}') + return publish_cfg + + +@contextlib.contextmanager +def connected_and_subscribed(dut:Dut, transport:str, pattern:str, pattern_repetitions:int) -> Any: + dut.write(f'publish_setup {transport} {pattern} {pattern_repetitions}') + dut.write(f'start') + dut.expect(re.compile(rb'MQTT_EVENT_SUBSCRIBED'), timeout=60) + yield + dut.write(f'stop') + + +def get_scenarios() -> List[Dict[str, int]]: + scenarios = [] + # Initialize message sizes and repeat counts (if defined in the environment) + for i in count(0): + # Check env variable: MQTT_PUBLISH_MSG_{len|repeat}_{x} + env_dict = {var:'MQTT_PUBLISH_MSG_' + var + '_' + str(i) for var in ['len', 'repeat']} + if os.getenv(env_dict['len']) and os.getenv(env_dict['repeat']): + scenarios.append({var: int(os.getenv(env_dict[var])) for var in ['len', 'repeat']}) # type: ignore + continue + break + if not scenarios: # No message sizes present in the env - set defaults + scenarios = [{'len':0, 'repeat':5}, # zero-sized messages + {'len':2, 'repeat':5}, # short messages + {'len':200, 'repeat':3}, # long messages + ] + return scenarios + + +def get_timeout(test_case: Any) -> int: + transport, qos, enqueue, scenario = test_case + if transport in ['ws', 'wss'] or qos == 2: + return 90 + return 60 + + +def run_publish_test_case(dut: Dut, test_case: Any, publish_cfg: Any) -> None: + transport, qos, enqueue, scenario = test_case + if publish_cfg['broker_host_' + transport] is None: + pytest.skip(f'Skipping transport: {transport}...') + repeat = scenario['len'] + published = scenario['repeat'] + publish_cfg['qos'] = qos + publish_cfg['queue'] = enqueue + publish_cfg['transport'] = transport + test_timeout = get_timeout(test_case) + logging.info(f'Starting Publish test: transport:{transport}, qos:{qos}, nr_of_msgs:{published},' + f' msg_size:{repeat*DEFAULT_MSG_SIZE}, enqueue:{enqueue}') + with MqttPublisher(repeat, published, publish_cfg) as publisher, connected_and_subscribed(dut, transport, publisher.sample_string, scenario['len']): + msgs_published: List[mqtt.MQTTMessageInfo] = [] + dut.write(f'publish {publisher.published} {qos} {enqueue}') + assert publisher.event_client_got_all.wait(timeout=test_timeout), (f'Not all data received from ESP32: {transport} ' + f'qos={qos} received: {publisher.received} ' + f'expected: {publisher.published}') + logging.info(' - all data received from ESP32') + payload = publisher.sample_string * publisher.repeat + for _ in range(publisher.published): + with publisher.lock: + msg = publisher.publish(topic=publisher.publish_cfg['publish_topic'], payload=payload, qos=qos) + if qos > 0: + msgs_published.append(msg) + logging.info(f'Published: {len(msgs_published)}') + while msgs_published: + msgs_published = [msg for msg in msgs_published if msg.is_published()] + + try: + dut.expect(re.compile(rb'Correct pattern received exactly x times'), timeout=test_timeout) + except pexpect.exceptions.ExceptionPexpect: + dut.write(f'publish_report') + dut.expect(re.compile(rb'Test Report'), timeout=30) + raise + logging.info('ESP32 received all data from runner') + + +stress_scenarios = [{'len':20, 'repeat':50}] # many medium sized +transport_cases = ['tcp', 'ws', 'wss', 'ssl'] +qos_cases = [0, 1, 2] +enqueue_cases = [0, 1] + + +def make_cases(scenarios: List[Dict[str, int]]) -> List[Tuple[str, int, int, Dict[str, int]]]: + return [test_case for test_case in product(transport_cases, qos_cases, enqueue_cases, scenarios)] + + +test_cases = make_cases(get_scenarios()) +stress_test_cases = make_cases(stress_scenarios) + + +@pytest.mark.esp32 +@pytest.mark.ethernet +@pytest.mark.parametrize('test_case', test_cases) +def test_mqtt_publish(dut: Dut, test_case: Any) -> None: + publish_cfg = get_configurations(dut) + dut.expect(re.compile(rb'mqtt>'), timeout=30) + dut.confirm_write('init', expect_pattern='init', timeout=30) + run_publish_test_case(dut, test_case, publish_cfg) + + +@pytest.mark.esp32 +@pytest.mark.ethernet +@pytest.mark.nightly_run +@pytest.mark.parametrize('test_case', stress_test_cases) +def test_mqtt_publish_stress(dut: Dut, test_case: Any) -> None: + publish_cfg = get_configurations(dut) + dut.expect(re.compile(rb'mqtt>'), timeout=30) + dut.write('init') + run_publish_test_case(dut, test_case, publish_cfg)