diff --git a/src/dtls13.c b/src/dtls13.c index 2e0f68ad6..b5fecafd7 100644 --- a/src/dtls13.c +++ b/src/dtls13.c @@ -419,7 +419,7 @@ static int Dtls13SendFragFromBuffer(WOLFSSL* ssl, byte* output, word16 length) if (ret != 0) return ret; - buf = ssl->buffers.outputBuffer.buffer + ssl->buffers.outputBuffer.length; + buf = GetOutputBuffer(ssl); XMEMCPY(buf, output, length); @@ -924,8 +924,7 @@ static int Dtls13SendFragmentedInternal(WOLFSSL* ssl) if (ret != 0) return ret; - output = - ssl->buffers.outputBuffer.buffer + ssl->buffers.outputBuffer.length; + output = GetOutputBuffer(ssl); ret = Dtls13HandshakeAddHeaderFrag(ssl, output + rlHeaderLength, (enum HandShakeType)ssl->dtls13FragHandshakeType, @@ -1509,8 +1508,7 @@ static int Dtls13RtxSendBuffered(WOLFSSL* ssl) if (ret != 0) return ret; - output = - ssl->buffers.outputBuffer.buffer + ssl->buffers.outputBuffer.length; + output = GetOutputBuffer(ssl); XMEMCPY(output + headerLength, r->data, r->length); @@ -2342,8 +2340,7 @@ static int Dtls13WriteAckMessage(WOLFSSL* ssl, if (ret != 0) return ret; - output = - ssl->buffers.outputBuffer.buffer + ssl->buffers.outputBuffer.length; + output = GetOutputBuffer(ssl); ackMessage = output + headerLength; @@ -2617,8 +2614,7 @@ int SendDtls13Ack(WOLFSSL* ssl) if (ret != 0) return ret; - output = - ssl->buffers.outputBuffer.buffer + ssl->buffers.outputBuffer.length; + output = GetOutputBuffer(ssl); ret = Dtls13RlAddPlaintextHeader(ssl, output, ack, (word16)length); if (ret != 0) @@ -2632,10 +2628,10 @@ int SendDtls13Ack(WOLFSSL* ssl) if (ret != 0) return ret; - output = - ssl->buffers.outputBuffer.buffer + ssl->buffers.outputBuffer.length; + output = GetOutputBuffer(ssl); outputSize = ssl->buffers.outputBuffer.bufferSize - + ssl->buffers.outputBuffer.idx - ssl->buffers.outputBuffer.length; headerSize = Dtls13GetRlHeaderLength(ssl, 1); diff --git a/src/internal.c b/src/internal.c index 45124fd64..ca166e8d9 100644 --- a/src/internal.c +++ b/src/internal.c @@ -8967,10 +8967,7 @@ int DtlsMsgPoolSend(WOLFSSL* ssl, int sendOnlyFirstPacket) return ret; } - XMEMCPY(ssl->buffers.outputBuffer.buffer + - ssl->buffers.outputBuffer.idx + - ssl->buffers.outputBuffer.length, - pool->raw, pool->sz); + XMEMCPY(GetOutputBuffer(ssl), pool->raw, pool->sz); ssl->buffers.outputBuffer.length += pool->sz; } else { @@ -9011,8 +9008,7 @@ int DtlsMsgPoolSend(WOLFSSL* ssl, int sendOnlyFirstPacket) return ret; } - output = ssl->buffers.outputBuffer.buffer + - ssl->buffers.outputBuffer.length; + output = GetOutputBuffer(ssl); if (inputSz != ENUM_LEN) sendSz = BuildMessage(ssl, output, sendSz, input, inputSz, handshake, 0, 0, 0, epochOrder); @@ -9743,8 +9739,7 @@ static int SendHandshakeMsg(WOLFSSL* ssl, byte* input, word32 inputSz, return ret; if (ssl->buffers.outputBuffer.buffer == NULL) return MEMORY_E; - output = ssl->buffers.outputBuffer.buffer + - ssl->buffers.outputBuffer.length; + output = GetOutputBuffer(ssl); if (IsEncryptionOn(ssl, 1)) { /* First we need to add the fragment header ourselves. @@ -9952,6 +9947,7 @@ void ShrinkOutputBuffer(WOLFSSL* ssl) ssl->buffers.outputBuffer.bufferSize = STATIC_BUFFER_LEN; ssl->buffers.outputBuffer.dynamicFlag = 0; ssl->buffers.outputBuffer.offset = 0; + /* idx and length are assumed to be 0. */ } @@ -10074,6 +10070,14 @@ int SendBuffered(WOLFSSL* ssl) } +/* returns the current location in the output buffer to start writing to */ +byte* GetOutputBuffer(WOLFSSL* ssl) +{ + return ssl->buffers.outputBuffer.buffer + ssl->buffers.outputBuffer.idx + + ssl->buffers.outputBuffer.length; +} + + /* Grow the output buffer */ static WC_INLINE int GrowOutputBuffer(WOLFSSL* ssl, int size) { @@ -10085,6 +10089,8 @@ static WC_INLINE int GrowOutputBuffer(WOLFSSL* ssl, int size) #else const byte align = WOLFSSL_GENERAL_ALIGNMENT; #endif + int newSz = size + ssl->buffers.outputBuffer.idx + + ssl->buffers.outputBuffer.length; #if WOLFSSL_GENERAL_ALIGNMENT > 0 /* the encrypted data will be offset from the front of the buffer by @@ -10095,8 +10101,7 @@ static WC_INLINE int GrowOutputBuffer(WOLFSSL* ssl, int size) align *= 2; #endif - tmp = (byte*)XMALLOC(size + ssl->buffers.outputBuffer.length + align, - ssl->heap, DYNAMIC_TYPE_OUT_BUFFER); + tmp = (byte*)XMALLOC(newSz + align, ssl->heap, DYNAMIC_TYPE_OUT_BUFFER); WOLFSSL_MSG("growing output buffer"); if (tmp == NULL) @@ -10111,14 +10116,14 @@ static WC_INLINE int GrowOutputBuffer(WOLFSSL* ssl, int size) /* can be from IO memory pool which does not need copy if same buffer */ if (ssl->buffers.outputBuffer.length && tmp == ssl->buffers.outputBuffer.buffer) { - ssl->buffers.outputBuffer.bufferSize = - size + ssl->buffers.outputBuffer.length; + ssl->buffers.outputBuffer.bufferSize = newSz; return 0; } #endif if (ssl->buffers.outputBuffer.length) XMEMCPY(tmp, ssl->buffers.outputBuffer.buffer, + ssl->buffers.outputBuffer.idx + ssl->buffers.outputBuffer.length); if (ssl->buffers.outputBuffer.dynamicFlag) { @@ -10136,8 +10141,7 @@ static WC_INLINE int GrowOutputBuffer(WOLFSSL* ssl, int size) ssl->buffers.outputBuffer.offset = 0; ssl->buffers.outputBuffer.buffer = tmp; - ssl->buffers.outputBuffer.bufferSize = size + - ssl->buffers.outputBuffer.length; + ssl->buffers.outputBuffer.bufferSize = newSz; return 0; } @@ -10235,8 +10239,7 @@ int CheckAvailableSize(WOLFSSL *ssl, int size) #ifdef WOLFSSL_DTLS if (ssl->options.dtls) { - if (size + ssl->buffers.outputBuffer.length - - ssl->buffers.outputBuffer.idx > + if (size + ssl->buffers.outputBuffer.length > #if defined(WOLFSSL_SCTP) || defined(WOLFSSL_DTLS_MTU) ssl->dtlsMtuSz #else @@ -10268,8 +10271,9 @@ int CheckAvailableSize(WOLFSSL *ssl, int size) } #endif - if (ssl->buffers.outputBuffer.bufferSize - ssl->buffers.outputBuffer.length - < (word32)size) { + if ((ssl->buffers.outputBuffer.bufferSize - + ssl->buffers.outputBuffer.length - + ssl->buffers.outputBuffer.idx) < (word32)size) { if (GrowOutputBuffer(ssl, size) < 0) return MEMORY_E; } @@ -16065,13 +16069,13 @@ static int DoHandShakeMsg(WOLFSSL* ssl, byte* input, word32* inOutIdx, #endif /* !WOLFSSL_NO_TLS12 */ #ifdef WOLFSSL_EXTRA_ALERTS -void SendFatalAlertOnly(WOLFSSL *ssl, int error) +int SendFatalAlertOnly(WOLFSSL *ssl, int error) { int why; /* already sent a more specific fatal alert */ if (ssl->alert_history.last_tx.level == alert_fatal) - return; + return 0; switch (error) { /* not fatal errors */ @@ -16081,12 +16085,12 @@ void SendFatalAlertOnly(WOLFSSL *ssl, int error) #ifdef WOLFSSL_ASYNC_CRYPT case WC_PENDING_E: #endif - return; + return 0; /* peer already disconnected and ssl is possibly in bad state * don't try to send an alert */ case SOCKET_ERROR_E: - return; + return error; case BUFFER_ERROR: case ASN_PARSE_E: @@ -16114,14 +16118,15 @@ void SendFatalAlertOnly(WOLFSSL *ssl, int error) break; } - SendAlert(ssl, alert_fatal, why); + return SendAlert(ssl, alert_fatal, why); } #else -void SendFatalAlertOnly(WOLFSSL *ssl, int error) +int SendFatalAlertOnly(WOLFSSL *ssl, int error) { (void)ssl; (void)error; /* no op */ + return 0; } #endif /* WOLFSSL_EXTRA_ALERTS */ @@ -16555,7 +16560,9 @@ int DtlsMsgDrain(WOLFSSL* ssl) DtlsTxMsgListClean(ssl); } else if (!IsAtLeastTLSv1_3(ssl->version)) { - SendFatalAlertOnly(ssl, ret); + if (SendFatalAlertOnly(ssl, ret) == SOCKET_ERROR_E) { + ret = SOCKET_ERROR_E; + } } #ifdef WOLFSSL_ASYNC_CRYPT if (ret == WC_PENDING_E) { @@ -19874,8 +19881,12 @@ default: ssl->buffers.inputBuffer.buffer, &ssl->buffers.inputBuffer.idx, ssl->buffers.inputBuffer.length); - if (ret != 0) - SendFatalAlertOnly(ssl, ret); + if (ret != 0) { + if (SendFatalAlertOnly(ssl, ret) + == SOCKET_ERROR_E) { + ret = SOCKET_ERROR_E; + } + } } #endif #ifdef WOLFSSL_DTLS13 @@ -19912,8 +19923,10 @@ default: ssl->buffers.inputBuffer.buffer, &ssl->buffers.inputBuffer.idx, ssl->buffers.inputBuffer.length); - if (ret != 0) - SendFatalAlertOnly(ssl, ret); + if (ret != 0) { + if (SendFatalAlertOnly(ssl, ret) == SOCKET_ERROR_E) + ret = SOCKET_ERROR_E; + } #else ret = BUFFER_ERROR; #endif @@ -20328,8 +20341,7 @@ int SendChangeCipher(WOLFSSL* ssl) return ret; /* get output buffer */ - output = ssl->buffers.outputBuffer.buffer + - ssl->buffers.outputBuffer.length; + output = GetOutputBuffer(ssl); AddRecordHeader(output, 1, change_cipher_spec, ssl, CUR_ORDER); @@ -21256,9 +21268,7 @@ int SendFinished(WOLFSSL* ssl) #endif /* get output buffer */ - output = ssl->buffers.outputBuffer.buffer + - ssl->buffers.outputBuffer.length; - + output = GetOutputBuffer(ssl); AddHandShakeHeader(input, finishedSz, 0, finishedSz, finished, ssl); /* make finished hashes */ @@ -21636,8 +21646,7 @@ int SendCertificate(WOLFSSL* ssl) return ret; /* get output buffer */ - output = ssl->buffers.outputBuffer.buffer + - ssl->buffers.outputBuffer.length; + output = GetOutputBuffer(ssl); /* Safe to use ssl->fragOffset since it will be incremented immediately * after this block. This block needs to be entered only once to not @@ -21879,8 +21888,7 @@ int SendCertificateRequest(WOLFSSL* ssl) return ret; /* get output buffer */ - output = ssl->buffers.outputBuffer.buffer + - ssl->buffers.outputBuffer.length; + output = GetOutputBuffer(ssl); AddHeaders(output, reqSz, certificate_request, ssl); @@ -22038,8 +22046,7 @@ static int BuildCertificateStatus(WOLFSSL* ssl, byte type, buffer* status, ssl->options.buildingMsg = 1; if ((ret = CheckAvailableSize(ssl, sendSz)) == 0) { - output = ssl->buffers.outputBuffer.buffer + - ssl->buffers.outputBuffer.length; + output = GetOutputBuffer(ssl); AddHeaders(output, length, certificate_status, ssl); @@ -22635,8 +22642,7 @@ int SendData(WOLFSSL* ssl, const void* data, int sz) return ssl->error = ret; /* get output buffer */ - out = ssl->buffers.outputBuffer.buffer + - ssl->buffers.outputBuffer.length; + out = GetOutputBuffer(ssl); #ifdef HAVE_LIBZ if (ssl->options.usingCompression) { @@ -22965,9 +22971,7 @@ static int SendAlert_ex(WOLFSSL* ssl, int severity, int type) return BUFFER_E; /* get output buffer */ - output = ssl->buffers.outputBuffer.buffer + - ssl->buffers.outputBuffer.length; - + output = GetOutputBuffer(ssl); input[0] = (byte)severity; input[1] = (byte)type; ssl->alert_history.last_tx.code = type; @@ -26420,8 +26424,7 @@ static int HashSkeData(WOLFSSL* ssl, enum wc_HashType hashType, return ret; /* get output buffer */ - output = ssl->buffers.outputBuffer.buffer + - ssl->buffers.outputBuffer.length; + output = GetOutputBuffer(ssl); AddHeaders(output, length, client_hello, ssl); @@ -29917,8 +29920,7 @@ int SendClientKeyExchange(WOLFSSL* ssl) goto exit_scke; /* get output buffer */ - args->output = ssl->buffers.outputBuffer.buffer + - ssl->buffers.outputBuffer.length; + args->output = GetOutputBuffer(ssl); AddHeaders(args->output, args->encSz + tlsSz, client_key_exchange, ssl); @@ -30923,9 +30925,7 @@ static int DoSessionTicket(WOLFSSL* ssl, const byte* input, word32* inOutIdx, return ret; /* get output buffer */ - output = ssl->buffers.outputBuffer.buffer + - ssl->buffers.outputBuffer.length; - + output = GetOutputBuffer(ssl); AddHeaders(output, length, server_hello, ssl); /* now write to output */ @@ -34406,9 +34406,7 @@ static int DoSessionTicket(WOLFSSL* ssl, const byte* input, word32* inOutIdx, return ret; /* get output buffer */ - output = ssl->buffers.outputBuffer.buffer + - ssl->buffers.outputBuffer.length; - + output = GetOutputBuffer(ssl); AddHeaders(output, 0, server_hello_done, ssl); if (IsEncryptionOn(ssl, 1)) { @@ -35256,9 +35254,7 @@ cleanup: return ret; /* get output buffer */ - output = ssl->buffers.outputBuffer.buffer + - ssl->buffers.outputBuffer.length; - + output = GetOutputBuffer(ssl); AddHeaders(output, length, session_ticket, ssl); /* hint */ @@ -35797,9 +35793,7 @@ static int DefTicketEncCb(WOLFSSL* ssl, byte key_name[WOLFSSL_TICKET_NAME_SZ], return ret; /* get output buffer */ - output = ssl->buffers.outputBuffer.buffer + - ssl->buffers.outputBuffer.length; - + output = GetOutputBuffer(ssl); AddHeaders(output, 0, hello_request, ssl); if (IsEncryptionOn(ssl, 1)) { @@ -35871,8 +35865,7 @@ static int DefTicketEncCb(WOLFSSL* ssl, byte key_name[WOLFSSL_TICKET_NAME_SZ], return ret; /* get output buffer */ - output = ssl->buffers.outputBuffer.buffer + - ssl->buffers.outputBuffer.length; + output = GetOutputBuffer(ssl); /* Hello Verify Request should use the same sequence number * as the Client Hello unless we are in renegotiation then diff --git a/src/tls13.c b/src/tls13.c index 2f5910a5b..4726c8322 100644 --- a/src/tls13.c +++ b/src/tls13.c @@ -4210,8 +4210,7 @@ int SendTls13ClientHello(WOLFSSL* ssl) return ret; /* Get position in output buffer to write new message to. */ - args->output = ssl->buffers.outputBuffer.buffer + - ssl->buffers.outputBuffer.length; + args->output = GetOutputBuffer(ssl); /* Put the record and handshake headers on. */ AddTls13Headers(args->output, args->length, client_hello, ssl); @@ -6935,8 +6934,7 @@ int SendTls13ServerHello(WOLFSSL* ssl, byte extMsgType) return ret; /* Get position in output buffer to write new message to. */ - output = ssl->buffers.outputBuffer.buffer + - ssl->buffers.outputBuffer.length; + output = GetOutputBuffer(ssl); /* Put the record and handshake headers on. */ AddTls13Headers(output, length, server_hello, ssl); @@ -7178,8 +7176,7 @@ static int SendTls13EncryptedExtensions(WOLFSSL* ssl) return ret; /* Get position in output buffer to write new message to. */ - output = ssl->buffers.outputBuffer.buffer + - ssl->buffers.outputBuffer.length; + output = GetOutputBuffer(ssl); /* Put the record and handshake headers on. */ AddTls13Headers(output, length, encrypted_extensions, ssl); @@ -7300,8 +7297,7 @@ static int SendTls13CertificateRequest(WOLFSSL* ssl, byte* reqCtx, return ret; /* Get position in output buffer to write new message to. */ - output = ssl->buffers.outputBuffer.buffer + - ssl->buffers.outputBuffer.length; + output = GetOutputBuffer(ssl); /* Put the record and handshake headers on. */ AddTls13Headers(output, reqSz, certificate_request, ssl); @@ -8024,8 +8020,7 @@ static int SendTls13Certificate(WOLFSSL* ssl) return ret; /* Get position in output buffer to write new message to. */ - output = ssl->buffers.outputBuffer.buffer + - ssl->buffers.outputBuffer.length; + output = GetOutputBuffer(ssl); if (ssl->fragOffset == 0) { AddTls13FragHeaders(output, fragSz, 0, payloadSz, certificate, ssl); @@ -8278,8 +8273,7 @@ static int SendTls13CertificateVerify(WOLFSSL* ssl) } /* get output buffer */ - args->output = ssl->buffers.outputBuffer.buffer + - ssl->buffers.outputBuffer.length; + args->output = GetOutputBuffer(ssl); /* Advance state and proceed */ ssl->options.asyncState = TLS_ASYNC_BUILD; @@ -9491,8 +9485,7 @@ static int SendTls13Finished(WOLFSSL* ssl) return ret; /* get output buffer */ - output = ssl->buffers.outputBuffer.buffer + - ssl->buffers.outputBuffer.length; + output = GetOutputBuffer(ssl); input = output + RECORD_HEADER_SZ; #ifdef WOLFSSL_DTLS13 @@ -9748,8 +9741,7 @@ static int SendTls13KeyUpdate(WOLFSSL* ssl) return ret; /* get output buffer */ - output = ssl->buffers.outputBuffer.buffer + - ssl->buffers.outputBuffer.length; + output = GetOutputBuffer(ssl); input = output + RECORD_HEADER_SZ; #ifdef WOLFSSL_DTLS13 @@ -9941,8 +9933,7 @@ static int SendTls13EndOfEarlyData(WOLFSSL* ssl) return ret; /* Get position in output buffer to write new message to. */ - output = ssl->buffers.outputBuffer.buffer + - ssl->buffers.outputBuffer.length; + output = GetOutputBuffer(ssl); /* Put the record and handshake headers on. */ AddTls13Headers(output, length, end_of_early_data, ssl); @@ -10364,8 +10355,7 @@ static int SendTls13NewSessionTicket(WOLFSSL* ssl) return ret; /* Get position in output buffer to write new message to. */ - output = ssl->buffers.outputBuffer.buffer + - ssl->buffers.outputBuffer.length; + output = GetOutputBuffer(ssl); /* Put the record and handshake headers on. */ AddTls13Headers(output, length, session_ticket, ssl); diff --git a/tests/api.c b/tests/api.c index 1511438da..73b0dd00a 100644 --- a/tests/api.c +++ b/tests/api.c @@ -10476,7 +10476,11 @@ static int test_tls_ext_duplicate(void) wolfSSL_SetIOReadCtx(ssl, &msg); ExpectIntNE(wolfSSL_accept(ssl), WOLFSSL_SUCCESS); - ExpectIntEQ(wolfSSL_get_error(ssl, 0), DUPLICATE_TLS_EXT_E); + /* can return duplicate ext error or socket error if the peer closed down + * while sending alert */ + if (wolfSSL_get_error(ssl, 0) != SOCKET_ERROR_E) { + ExpectIntEQ(wolfSSL_get_error(ssl, 0), DUPLICATE_TLS_EXT_E); + } wolfSSL_free(ssl); wolfSSL_CTX_free(ctx); diff --git a/wolfssl/internal.h b/wolfssl/internal.h index 58247046c..dbe64dda1 100644 --- a/wolfssl/internal.h +++ b/wolfssl/internal.h @@ -5834,7 +5834,7 @@ WOLFSSL_LOCAL int ReceiveData(WOLFSSL* ssl, byte* output, int sz, int peek); WOLFSSL_LOCAL int SendFinished(WOLFSSL* ssl); WOLFSSL_LOCAL int RetrySendAlert(WOLFSSL* ssl); WOLFSSL_LOCAL int SendAlert(WOLFSSL* ssl, int severity, int type); -WOLFSSL_LOCAL void SendFatalAlertOnly(WOLFSSL *ssl, int error); +WOLFSSL_LOCAL int SendFatalAlertOnly(WOLFSSL *ssl, int error); WOLFSSL_LOCAL int ProcessReply(WOLFSSL* ssl); WOLFSSL_LOCAL int ProcessReplyEx(WOLFSSL* ssl, int allowSocketErr); @@ -5857,6 +5857,7 @@ WOLFSSL_LOCAL int TLSv1_3_Capable(WOLFSSL* ssl); WOLFSSL_LOCAL void FreeHandshakeResources(WOLFSSL* ssl); WOLFSSL_LOCAL void ShrinkInputBuffer(WOLFSSL* ssl, int forcedFree); WOLFSSL_LOCAL void ShrinkOutputBuffer(WOLFSSL* ssl); +WOLFSSL_LOCAL byte* GetOutputBuffer(WOLFSSL* ssl); WOLFSSL_LOCAL int VerifyClientSuite(word16 havePSK, byte cipherSuite0, byte cipherSuite);