diff --git a/examples/benchmark/tls_bench.c b/examples/benchmark/tls_bench.c index 37b64b850..1806eea7a 100644 --- a/examples/benchmark/tls_bench.c +++ b/examples/benchmark/tls_bench.c @@ -345,10 +345,10 @@ static double gettime_secs(int reset) static void* client_thread(void* args) { info_t* info = (info_t*)args; - unsigned char buf[MEM_BUFFER_SZ]; + unsigned char* buf; unsigned char *writeBuf; double start; - int ret; + int ret, bufSize; WOLFSSL_CTX* cli_ctx; WOLFSSL* cli_ssl; int haveShownPeerInfo = 0; @@ -414,28 +414,45 @@ static void* client_thread(void* args) showPeer(cli_ssl); } - /* write test message to server */ - while (info->client_stats.rxTotal < info->numBytes) { - start = gettime_secs(1); - ret = wolfSSL_write(cli_ssl, writeBuf, info->packetSize); - info->client_stats.txTime += gettime_secs(0) - start; - if (ret > 0) { - info->client_stats.txTotal += ret; - } - - /* read echo of message */ - start = gettime_secs(1); - ret = wolfSSL_read(cli_ssl, buf, sizeof(buf)-1); - info->client_stats.rxTime += gettime_secs(0) - start; - if (ret > 0) { - info->client_stats.rxTotal += ret; - } - - /* validate echo */ - if (strncmp((char*)writeBuf, (char*)buf, info->packetSize) != 0) { - err_sys("echo check failed!\n"); - } + /* Allocate buf after handshake is complete */ + bufSize = wolfSSL_GetMaxOutputSize(cli_ssl); + if (bufSize > 0) { + buf = (unsigned char*)malloc(bufSize); } + else { + buf = NULL; + } + + if (buf != NULL) { + /* write test message to server */ + while (info->client_stats.rxTotal < info->numBytes) { + start = gettime_secs(1); + ret = wolfSSL_write(cli_ssl, writeBuf, info->packetSize); + info->client_stats.txTime += gettime_secs(0) - start; + if (ret > 0) { + info->client_stats.txTotal += ret; + } + + /* read echo of message */ + start = gettime_secs(1); + ret = wolfSSL_read(cli_ssl, buf, bufSize-1); + info->client_stats.rxTime += gettime_secs(0) - start; + if (ret > 0) { + info->client_stats.rxTotal += ret; + } + + /* validate echo */ + if (strncmp((char*)writeBuf, (char*)buf, info->packetSize) != 0) { + err_sys("echo check failed!\n"); + } + } + + free(buf); + } + else { + err_sys("failed to allocate memory"); + } + info->client_stats.connCount++; @@ -456,9 +473,9 @@ static void* client_thread(void* args) static void* server_thread(void* args) { info_t* info = (info_t*)args; - unsigned char buf[MEM_BUFFER_SZ]; + unsigned char *buf; double start; - int ret, len = 0; + int ret, len = 0, bufSize; WOLFSSL_CTX* srv_ctx; WOLFSSL* srv_ssl; @@ -521,24 +538,39 @@ static void* server_thread(void* args) info->server_stats.connTime += start; - while (info->server_stats.txTotal < info->numBytes) { - /* read msg post handshake from client */ - memset(buf, 0, sizeof(buf)); - start = gettime_secs(1); - ret = wolfSSL_read(srv_ssl, buf, sizeof(buf)-1); - info->server_stats.rxTime += gettime_secs(0) - start; - if (ret > 0) { - info->server_stats.rxTotal += ret; - len = ret; - } + /* Allocate buf after handshake is complete */ + bufSize = wolfSSL_GetMaxOutputSize(srv_ssl); + if (bufSize > 0) { + buf = (unsigned char*)malloc(bufSize); + } + else { + buf = NULL; + } - /* write message back to client */ - start = gettime_secs(1); - ret = wolfSSL_write(srv_ssl, buf, len); - info->server_stats.txTime += gettime_secs(0) - start; - if (ret > 0) { - info->server_stats.txTotal += ret; + if (buf != NULL) { + while (info->server_stats.txTotal < info->numBytes) { + /* read msg post handshake from client */ + memset(buf, 0, bufSize); + start = gettime_secs(1); + ret = wolfSSL_read(srv_ssl, buf, bufSize-1); + info->server_stats.rxTime += gettime_secs(0) - start; + if (ret > 0) { + info->server_stats.rxTotal += ret; + len = ret; + } + + /* write message back to client */ + start = gettime_secs(1); + ret = wolfSSL_write(srv_ssl, buf, len); + info->server_stats.txTime += gettime_secs(0) - start; + if (ret > 0) { + info->server_stats.txTotal += ret; + } } + free(buf); + } + else { + err_sys("failed to allocate memory"); } info->server_stats.connCount++; diff --git a/src/internal.c b/src/internal.c index 78bb88799..91ed63fb6 100644 --- a/src/internal.c +++ b/src/internal.c @@ -5208,7 +5208,6 @@ void FreeSSL(WOLFSSL* ssl, void* heap) (void)heap; } - #if !defined(NO_OLD_TLS) || defined(HAVE_CHACHA) || defined(HAVE_AESCCM) \ || defined(HAVE_AESGCM) || defined(WOLFSSL_DTLS) static INLINE void GetSEQIncrement(WOLFSSL* ssl, int verify, word32 seq[2]) @@ -13585,17 +13584,18 @@ int SendCertificate(WOLFSSL* ssl) length -= (ssl->fragOffset + headerSz); maxFragment = MAX_RECORD_SIZE; + if (ssl->options.dtls) { #ifdef WOLFSSL_DTLS + /* The 100 bytes is used to account for the UDP and IP headers. + It can also include the record padding and MAC if the + SendCertificate is called for a secure renegotiation. */ maxFragment = MAX_MTU - DTLS_RECORD_HEADER_SZ - DTLS_HANDSHAKE_HEADER_SZ - 100; #endif /* WOLFSSL_DTLS */ } - #ifdef HAVE_MAX_FRAGMENT - if (ssl->max_fragment != 0 && maxFragment >= ssl->max_fragment) - maxFragment = ssl->max_fragment; - #endif /* HAVE_MAX_FRAGMENT */ + maxFragment = wolfSSL_GetMaxRecordSize(ssl, maxFragment); while (length > 0 && ret == 0) { byte* output = NULL; @@ -14447,10 +14447,7 @@ int SendData(WOLFSSL* ssl, const void* data, int sz) if (sent == sz) break; - len = min(sz - sent, OUTPUT_RECORD_SIZE); -#ifdef HAVE_MAX_FRAGMENT - len = min(len, ssl->max_fragment); -#endif + len = wolfSSL_GetMaxRecordSize(ssl, sz - sent); #ifdef WOLFSSL_DTLS if (IsDtlsNotSctpMode(ssl)) { @@ -25641,6 +25638,30 @@ int wolfSSL_AsyncPush(WOLFSSL* ssl, WC_ASYNC_DEV* asyncDev) #endif /* WOLFSSL_ASYNC_CRYPT */ +/* return the max record size */ +int wolfSSL_GetMaxRecordSize(WOLFSSL* ssl, int maxFragment) +{ + (void) ssl; /* Avoid compiler warnings */ + + if (maxFragment > MAX_RECORD_SIZE) { + maxFragment = MAX_RECORD_SIZE; + } + +#ifdef HAVE_MAX_FRAGMENT + if ((ssl->max_fragment != 0) && (maxFragment > ssl->max_fragment)) { + maxFragment = ssl->max_fragment; + } +#endif /* HAVE_MAX_FRAGMENT */ +#ifdef WOLFSSL_DTLS + if ((ssl->options.dtls) && (maxFragment > MAX_UDP_SIZE)) { + maxFragment = MAX_UDP_SIZE; + } +#endif + + return maxFragment; +} + + #undef ERROR_OUT #endif /* WOLFCRYPT_ONLY */ diff --git a/src/ssl.c b/src/ssl.c index 1b0e803b9..82d5165d1 100644 --- a/src/ssl.c +++ b/src/ssl.c @@ -1386,8 +1386,6 @@ int wolfSSL_CTX_is_static_memory(WOLFSSL_CTX* ctx, WOLFSSL_MEM_STATS* mem_stats) /* return max record layer size plaintext input size */ int wolfSSL_GetMaxOutputSize(WOLFSSL* ssl) { - int maxSize = OUTPUT_RECORD_SIZE; - WOLFSSL_ENTER("wolfSSL_GetMaxOutputSize"); if (ssl == NULL) @@ -1398,17 +1396,7 @@ int wolfSSL_GetMaxOutputSize(WOLFSSL* ssl) return BAD_FUNC_ARG; } -#ifdef HAVE_MAX_FRAGMENT - maxSize = min(maxSize, ssl->max_fragment); -#endif - -#ifdef WOLFSSL_DTLS - if (ssl->options.dtls) { - maxSize = min(maxSize, MAX_UDP_SIZE); - } -#endif - - return maxSize; + return wolfSSL_GetMaxRecordSize(ssl, OUTPUT_RECORD_SIZE); } @@ -1717,10 +1705,8 @@ static int wolfSSL_read_internal(WOLFSSL* ssl, void* data, int sz, int peek) } #endif - sz = min(sz, OUTPUT_RECORD_SIZE); -#ifdef HAVE_MAX_FRAGMENT - sz = min(sz, ssl->max_fragment); -#endif + sz = wolfSSL_GetMaxRecordSize(ssl, sz); + ret = ReceiveData(ssl, (byte*)data, sz, peek); #ifdef HAVE_WRITE_DUP diff --git a/src/tls13.c b/src/tls13.c index 72afb1a95..73838b38c 100644 --- a/src/tls13.c +++ b/src/tls13.c @@ -4629,12 +4629,7 @@ static int SendTls13Certificate(WOLFSSL* ssl) if (ssl->fragOffset != 0) length -= (ssl->fragOffset + headerSz); - maxFragment = MAX_RECORD_SIZE; - - #ifdef HAVE_MAX_FRAGMENT - if (ssl->max_fragment != 0 && maxFragment >= ssl->max_fragment) - maxFragment = ssl->max_fragment; - #endif /* HAVE_MAX_FRAGMENT */ + maxFragment = wolfSSL_GetMaxRecordSize(ssl, MAX_RECORD_SIZE); while (length > 0 && ret == 0) { byte* output = NULL; diff --git a/wolfssl/internal.h b/wolfssl/internal.h index 8827ef0fc..598cef2df 100644 --- a/wolfssl/internal.h +++ b/wolfssl/internal.h @@ -3832,6 +3832,7 @@ WOLFSSL_LOCAL void ShrinkOutputBuffer(WOLFSSL* ssl); WOLFSSL_LOCAL int VerifyClientSuite(WOLFSSL* ssl); WOLFSSL_LOCAL int SetTicket(WOLFSSL*, const byte*, word32); +WOLFSSL_LOCAL int wolfSSL_GetMaxRecordSize(WOLFSSL* ssl, int maxFragment); #ifndef NO_CERTS #ifndef NO_RSA