diff --git a/IDE/STM32Cube/wolfssl_example.c b/IDE/STM32Cube/wolfssl_example.c index eead6fa24..168cc848c 100644 --- a/IDE/STM32Cube/wolfssl_example.c +++ b/IDE/STM32Cube/wolfssl_example.c @@ -76,6 +76,9 @@ /* use non-blocking mode for read/write IO */ #define BENCH_USE_NONBLOCK #endif +#ifndef RECV_WAIT_TIMEOUT + #define RECV_WAIT_TIMEOUT 4000 +#endif /***************************************************************************** * Private types/enumerations/variables @@ -539,10 +542,14 @@ static int ServerMemRecv(info_t* info, char* buf, int sz) !info->client.done) { osSemaphoreRelease(info->server.mutex); #ifdef CMSIS_OS2_H_ - osThreadFlagsWait(1, osFlagsWaitAny, osWaitForever); + if (osThreadFlagsWait(1, osFlagsWaitAny, RECV_WAIT_TIMEOUT) == osFlagsErrorTimeout) { + return WOLFSSL_CBIO_ERR_TIMEOUT; + } osSemaphoreAcquire(info->server.mutex, osWaitForever); #else - osSignalWait(1, osWaitForever); + if (osSignalWait(1, RECV_WAIT_TIMEOUT) == osEventTimeout) { + return WOLFSSL_CBIO_ERR_TIMEOUT; + } osSemaphoreWait(info->server.mutex, osWaitForever); #endif } @@ -624,10 +631,14 @@ static int ClientMemRecv(info_t* info, char* buf, int sz) !info->server.done) { osSemaphoreRelease(info->client.mutex); #ifdef CMSIS_OS2_H_ - osThreadFlagsWait(1, osFlagsWaitAny, osWaitForever); + if (osThreadFlagsWait(1, osFlagsWaitAny, RECV_WAIT_TIMEOUT) == osFlagsErrorTimeout) { + return WOLFSSL_CBIO_ERR_TIMEOUT; + } osSemaphoreAcquire(info->client.mutex, osWaitForever); #else - osSignalWait(1, osWaitForever); + if (osSignalWait(1, RECV_WAIT_TIMEOUT) == osEventTimeout) { + return WOLFSSL_CBIO_ERR_TIMEOUT; + } osSemaphoreWait(info->client.mutex, osWaitForever); #endif } @@ -936,6 +947,9 @@ static void client_thread(const void* args) int ret; info_t* info = (info_t*)args; +#ifdef CMSIS_OS2_H_ + info->client.threadId = osThreadGetId(); +#endif do { ret = bench_tls_client(info); @@ -949,13 +963,14 @@ static void client_thread(const void* args) } info->client.ret = ret; info->client.done = 1; - osThreadSuspend(NULL); + osThreadSuspend(info->client.threadId); if (info->doShutdown) info->client.done = 1; } while (!info->doShutdown); osThreadTerminate(info->client.threadId); + info->client.threadId = NULL; } static int bench_tls_server(info_t* info) @@ -1207,6 +1222,9 @@ static void server_thread(const void* args) int ret; info_t* info = (info_t*)args; +#ifdef CMSIS_OS2_H_ + info->server.threadId = osThreadGetId(); +#endif do { ret = bench_tls_server(info); @@ -1220,13 +1238,14 @@ static void server_thread(const void* args) } info->server.ret = ret; info->server.done = 1; - osThreadSuspend(NULL); + osThreadSuspend(info->server.threadId); if (info->doShutdown) info->server.done = 1; } while (!info->doShutdown); osThreadTerminate(info->server.threadId); + info->server.threadId = NULL; } #ifdef CMSIS_OS2_H_ @@ -1343,7 +1362,7 @@ int bench_tls(void* args) /* start threads */ if (info->server.threadId == 0) { #ifdef CMSIS_OS2_H_ - info->server.threadId = osThreadNew(server_thread, info, &server_thread_attributes); + osThreadNew(server_thread, info, &server_thread_attributes); #else info->server.threadId = osThreadCreate(&info->server.threadDef, info); #endif @@ -1353,7 +1372,7 @@ int bench_tls(void* args) } if (info->client.threadId == 0) { #ifdef CMSIS_OS2_H_ - info->client.threadId = osThreadNew(client_thread, info, &client_thread_attributes); + osThreadNew(client_thread, info, &client_thread_attributes); #else info->client.threadId = osThreadCreate(&info->client.threadDef, info); #endif