diff --git a/cyassl/internal.h b/cyassl/internal.h index 7b2e4e382..bcd8882cc 100644 --- a/cyassl/internal.h +++ b/cyassl/internal.h @@ -1671,6 +1671,7 @@ struct CYASSL { #endif #ifdef CYASSL_DTLS int dtls_timeout_init; /* starting timeout vaule */ + int dtls_timeout_max; /* maximum timeout value */ int dtls_timeout; /* current timeout value, changes */ DtlsPool* dtls_pool; DtlsMsg* dtls_msg_list; diff --git a/cyassl/ssl.h b/cyassl/ssl.h index ffec73b12..b49493622 100644 --- a/cyassl/ssl.h +++ b/cyassl/ssl.h @@ -255,6 +255,7 @@ CYASSL_API int CyaSSL_set_cipher_list(CYASSL*, const char*); /* Nonblocking DTLS helper functions */ CYASSL_API int CyaSSL_dtls_get_current_timeout(CYASSL* ssl); CYASSL_API int CyaSSL_dtls_set_timeout_init(CYASSL* ssl, int); +CYASSL_API int CyaSSL_dtls_set_timeout_max(CYASSL* ssl, int); CYASSL_API int CyaSSL_dtls_got_timeout(CYASSL* ssl); CYASSL_API int CyaSSL_dtls(CYASSL* ssl); diff --git a/examples/client/client.c b/examples/client/client.c index f1a528cfb..9b8faae54 100644 --- a/examples/client/client.c +++ b/examples/client/client.c @@ -56,16 +56,17 @@ static void NonBlockingSSL_Connect(CYASSL* ssl) while (ret != SSL_SUCCESS && (error == SSL_ERROR_WANT_READ || error == SSL_ERROR_WANT_WRITE)) { + int currTimeout = 1; + if (error == SSL_ERROR_WANT_READ) printf("... client would read block\n"); else printf("... client would write block\n"); - if (CyaSSL_dtls(ssl)) - select_ret = tcp_select(sockfd, - CyaSSL_dtls_get_current_timeout(ssl)); - else - select_ret = tcp_select(sockfd, 1); +#ifdef CYASSL_DTLS + currTimeout = CyaSSL_dtls_get_current_timeout(ssl); +#endif + select_ret = tcp_select(sockfd, currTimeout); if ((select_ret == TEST_RECV_READY) || (select_ret == TEST_ERROR_READY)) { @@ -76,11 +77,15 @@ static void NonBlockingSSL_Connect(CYASSL* ssl) #endif error = CyaSSL_get_error(ssl, 0); } - else if (select_ret == TEST_TIMEOUT && - (!CyaSSL_dtls(ssl) || - (CyaSSL_dtls_got_timeout(ssl) >= 0))) { + else if (select_ret == TEST_TIMEOUT && !CyaSSL_dtls(ssl)) { error = SSL_ERROR_WANT_READ; } +#ifdef CYASSL_DTLS + else if (select_ret == TEST_TIMEOUT && CyaSSL_dtls(ssl) && + CyaSSL_dtls_got_timeout(ssl) >= 0) { + error = SSL_ERROR_WANT_READ; + } +#endif else { error = SSL_FATAL_ERROR; } diff --git a/examples/server/server.c b/examples/server/server.c index d31e749ab..c08bdacda 100644 --- a/examples/server/server.c +++ b/examples/server/server.c @@ -55,16 +55,17 @@ static void NonBlockingSSL_Accept(SSL* ssl) while (ret != SSL_SUCCESS && (error == SSL_ERROR_WANT_READ || error == SSL_ERROR_WANT_WRITE)) { + int currTimeout = 1; + if (error == SSL_ERROR_WANT_READ) printf("... server would read block\n"); else printf("... server would write block\n"); - if (CyaSSL_dtls(ssl)) - select_ret = tcp_select(sockfd, - CyaSSL_dtls_get_current_timeout(ssl)); - else - select_ret = tcp_select(sockfd, 1); +#ifdef CYASSL_DTLS + currTimeout = CyaSSL_dtls_get_current_timeout(ssl); +#endif + select_ret = tcp_select(sockfd, currTimeout); if ((select_ret == TEST_RECV_READY) || (select_ret == TEST_ERROR_READY)) { @@ -76,11 +77,15 @@ static void NonBlockingSSL_Accept(SSL* ssl) #endif error = SSL_get_error(ssl, 0); } - else if (select_ret == TEST_TIMEOUT && - (!CyaSSL_dtls(ssl) || - (CyaSSL_dtls_got_timeout(ssl) >= 0))) { + else if (select_ret == TEST_TIMEOUT && !CyaSSL_dtls(ssl)) { error = SSL_ERROR_WANT_READ; } +#ifdef CYASSL_DTLS + else if (select_ret == TEST_TIMEOUT && CyaSSL_dtls(ssl) && + CyaSSL_dtls_got_timeout(ssl) >= 0) { + error = SSL_ERROR_WANT_READ; + } +#endif else { error = SSL_FATAL_ERROR; } diff --git a/src/internal.c b/src/internal.c index 64676bc4a..4fe4b2981 100644 --- a/src/internal.c +++ b/src/internal.c @@ -1347,6 +1347,7 @@ int InitSSL(CYASSL* ssl, CYASSL_CTX* ctx) ssl->keys.dtls_peer_epoch = 0; ssl->keys.dtls_expected_peer_epoch = 0; ssl->dtls_timeout_init = DTLS_TIMEOUT_INIT; + ssl->dtls_timeout_max = DTLS_TIMEOUT_MAX; ssl->dtls_timeout = ssl->dtls_timeout_init; ssl->dtls_pool = NULL; ssl->dtls_msg_list = NULL; @@ -1806,7 +1807,7 @@ void DtlsPoolReset(CYASSL* ssl) int DtlsPoolTimeout(CYASSL* ssl) { int result = -1; - if (ssl->dtls_timeout < DTLS_TIMEOUT_MAX) { + if (ssl->dtls_timeout < ssl->dtls_timeout_max) { ssl->dtls_timeout *= DTLS_TIMEOUT_MULTIPLIER; result = 0; } diff --git a/src/ssl.c b/src/ssl.c index 51fc423c9..1d4ce16d4 100644 --- a/src/ssl.c +++ b/src/ssl.c @@ -3532,16 +3532,13 @@ int CyaSSL_set_cipher_list(CYASSL* ssl, const char* list) #ifndef CYASSL_LEANPSK +#ifdef CYASSL_DTLS int CyaSSL_dtls_get_current_timeout(CYASSL* ssl) { (void)ssl; -#ifdef CYASSL_DTLS return ssl->dtls_timeout; -#else - return NOT_COMPILED_IN; -#endif } @@ -3551,33 +3548,43 @@ int CyaSSL_dtls_set_timeout_init(CYASSL* ssl, int timeout) if (ssl == NULL || timeout < 0) return BAD_FUNC_ARG; -#ifdef CYASSL_DTLS ssl->dtls_timeout_init = timeout; return SSL_SUCCESS; -#else - return NOT_COMPILED_IN; -#endif +} + + +/* user may need to alter max dtls recv timeout, SSL_SUCCESS on ok */ +int CyaSSL_dtls_set_timeout_max(CYASSL* ssl, int timeout) +{ + if (ssl == NULL || timeout < 0) + return BAD_FUNC_ARG; + + if (ssl->dtls_timeout_max < ssl->dtls_timeout_init) { + CYASSL_MSG("Can't set dtls timeout max less than dtls timeout init"); + return BAD_FUNC_ARG; + } + + ssl->dtls_timeout_max = timeout; + + return SSL_SUCCESS; } int CyaSSL_dtls_got_timeout(CYASSL* ssl) { -#ifdef CYASSL_DTLS int result = SSL_SUCCESS; + DtlsMsgListDelete(ssl->dtls_msg_list, ssl->heap); ssl->dtls_msg_list = NULL; if (DtlsPoolTimeout(ssl) < 0 || DtlsPoolSend(ssl) < 0) { result = SSL_FATAL_ERROR; } return result; -#else - (void)ssl; - return NOT_COMPILED_IN; -#endif } -#endif +#endif /* DTLS */ +#endif /* LEANPSK */ /* client only parts */