diff --git a/src/dtls13.c b/src/dtls13.c index abfe8aa9e..2e7947481 100644 --- a/src/dtls13.c +++ b/src/dtls13.c @@ -1551,6 +1551,9 @@ int Dtls13HandshakeSend(WOLFSSL* ssl, byte* message, word16 outputSize, maxFrag = wolfSSL_GetMaxFragSize(ssl, MAX_RECORD_SIZE); maxLen = length; + if (handshakeType == key_update) + ssl->dtls13WaitKeyUpdateAck = 1; + if (maxLen < maxFrag) { ret = Dtls13SendOneFragmentRtx(ssl, handshakeType, outputSize, message, length, hashOutput); @@ -2106,6 +2109,26 @@ static int Dtls13RtxIsTrackedByRn(const Dtls13RtxRecord* r, w64wrapper epoch, return 0; } +static int Dtls13KeyUpdateAckReceived(WOLFSSL* ssl) +{ + int ret; + w64Increment(&ssl->dtls13Epoch); + + /* Epoch wrapped up */ + if (w64IsZero(ssl->dtls13Epoch)) + return BAD_STATE_E; + + ret = DeriveTls13Keys(ssl, update_traffic_key, ENCRYPT_SIDE_ONLY, 1); + if (ret != 0) + return ret; + + ret = Dtls13NewEpoch(ssl, ssl->dtls13Epoch, ENCRYPT_SIDE_ONLY); + if (ret != 0) + return ret; + + return Dtls13SetEpochKeys(ssl, ssl->dtls13Epoch, ENCRYPT_SIDE_ONLY); +} + #ifdef WOLFSSL_DEBUG_TLS static void Dtls13PrintRtxRecord(Dtls13RtxRecord* r) { @@ -2200,12 +2223,27 @@ int Dtls13RtxTimeout(WOLFSSL* ssl) return Dtls13RtxSendBuffered(ssl); } +static int Dtls13RtxHasKeyUpdateBuffered(WOLFSSL* ssl) +{ + Dtls13RtxRecord* r = ssl->dtls13Rtx.rtxRecords; + + while (r != NULL) { + if (r->handshakeType == key_update) + return 1; + + r = r->next; + } + + return 0; +} + int DoDtls13Ack(WOLFSSL* ssl, const byte* input, word32 inputSize, word32* processedSize) { const byte* ackMessage; w64wrapper epoch, seq; word16 length; + int ret; int i; if (inputSize < OPAQUE16_LEN) @@ -2234,6 +2272,16 @@ int DoDtls13Ack(WOLFSSL* ssl, const byte* input, word32 inputSize, ssl->options.serverState = SERVER_FINISHED_ACKED; } + if (ssl->dtls13WaitKeyUpdateAck) { + if (!Dtls13RtxHasKeyUpdateBuffered(ssl)) { + /* we removed the KeyUpdate message because it was ACKed */ + ssl->dtls13WaitKeyUpdateAck = 0; + ret = Dtls13KeyUpdateAckReceived(ssl); + if (ret != 0) + return ret; + } + } + *processedSize = length + OPAQUE16_LEN; /* After the handshake, not retransmitting here may incur in some extra time diff --git a/src/tls13.c b/src/tls13.c index d44b25887..e2732cd5e 100644 --- a/src/tls13.c +++ b/src/tls13.c @@ -7896,6 +7896,11 @@ static int SendTls13KeyUpdate(WOLFSSL* ssl) WOLFSSL_START(WC_FUNC_KEY_UPDATE_SEND); WOLFSSL_ENTER("SendTls13KeyUpdate"); +#ifdef WOLFSSL_DTLS13 + if (ssl->options.dtls) + i = Dtls13GetRlHeaderLength(1) + DTLS_HANDSHAKE_HEADER_SZ; +#endif /* WOLFSSL_DTLS13 */ + outputSz = OPAQUE8_LEN + MAX_MSG_EXTRA; /* Check buffers are big enough and grow if needed. */ if ((ret = CheckAvailableSize(ssl, outputSz)) != 0) @@ -7906,6 +7911,11 @@ static int SendTls13KeyUpdate(WOLFSSL* ssl) ssl->buffers.outputBuffer.length; input = output + RECORD_HEADER_SZ; +#ifdef WOLFSSL_DTLS13 + if (ssl->options.dtls) + input = output + Dtls13GetRlHeaderLength(1); +#endif /* WOLFSSL_DTLS13 */ + AddTls13Headers(output, OPAQUE8_LEN, key_update, ssl); /* If: @@ -7918,6 +7928,15 @@ static int SendTls13KeyUpdate(WOLFSSL* ssl) /* Sent response, no longer need to respond. */ ssl->keys.keyUpdateRespond = 0; +#ifdef WOLFSSL_DTLS13 + if (ssl->options.dtls) { + ret = Dtls13HandshakeSend(ssl, output, outputSz, + OPAQUE8_LEN + Dtls13GetRlHeaderLength(1) + DTLS_HANDSHAKE_HEADER_SZ, + key_update, 0); + } + else { +#endif /* WOLFSSL_DTLS13 */ + /* This message is always encrypted. */ sendSz = BuildTls13Message(ssl, output, outputSz, input, headerSz + OPAQUE8_LEN, handshake, 0, 0, 0); @@ -7935,15 +7954,26 @@ static int SendTls13KeyUpdate(WOLFSSL* ssl) ssl->buffers.outputBuffer.length += sendSz; ret = SendBuffered(ssl); + + if (ret != 0 && ret != WANT_WRITE) return ret; +#ifdef WOLFSSL_DTLS13 + } +#endif /* WOLFSSL_DTLS13 */ + + /* In DTLS we must wait for the ack before setting up the new keys */ + if (!ssl->options.dtls) { + + /* Future traffic uses new encryption keys. */ + if ((ret = DeriveTls13Keys( + ssl, update_traffic_key, ENCRYPT_SIDE_ONLY, 1)) + != 0) + return ret; + if ((ret = SetKeysSide(ssl, ENCRYPT_SIDE_ONLY)) != 0) + return ret; + } - /* Future traffic uses new encryption keys. */ - if ((ret = DeriveTls13Keys(ssl, update_traffic_key, ENCRYPT_SIDE_ONLY, 1)) - != 0) - return ret; - if ((ret = SetKeysSide(ssl, ENCRYPT_SIDE_ONLY)) != 0) - return ret; WOLFSSL_LEAVE("SendTls13KeyUpdate", ret); WOLFSSL_END(WC_FUNC_KEY_UPDATE_SEND); @@ -8001,8 +8031,37 @@ static int DoTls13KeyUpdate(WOLFSSL* ssl, const byte* input, word32* inOutIdx, if ((ret = SetKeysSide(ssl, DECRYPT_SIDE_ONLY)) != 0) return ret; - if (ssl->keys.keyUpdateRespond) +#ifdef WOLFSSL_DTLS13 + if (ssl->options.dtls) { + w64Increment(&ssl->dtls13PeerEpoch); + + ret = Dtls13NewEpoch(ssl, ssl->dtls13PeerEpoch, DECRYPT_SIDE_ONLY); + if (ret != 0) + return ret; + + ret = Dtls13SetEpochKeys(ssl, ssl->dtls13PeerEpoch, DECRYPT_SIDE_ONLY); + if (ret != 0) + return ret; + } +#endif /* WOLFSSL_DTLS13 */ + + if (ssl->keys.keyUpdateRespond) { + +#ifdef WOLFSSL_DTLS13 + /* we already sent a keyUpdate (either in response to a previous + KeyUpdate or initiated by the application) and we are waiting for the + ack. We can't send a new KeyUpdate right away but to honor the RFC we + should send another KeyUpdate after the one in-flight is acked. We + don't do that as it looks redundant, it will make the code more + complex and I don't see a good use case for that. */ + if (ssl->options.dtls && ssl->dtls13WaitKeyUpdateAck) { + ssl->keys.keyUpdateRespond = 0; + return 0; + } +#endif /* WOLFSSL_DTLS13 */ + return SendTls13KeyUpdate(ssl); + } WOLFSSL_LEAVE("DoTls13KeyUpdate", ret); WOLFSSL_END(WC_FUNC_KEY_UPDATE_DO); @@ -9029,7 +9088,7 @@ int DoTls13HandShakeMsgType(WOLFSSL* ssl, byte* input, word32* inOutIdx, break; case key_update: - WOLFSSL_MSG("processing finished"); + WOLFSSL_MSG("processing key update"); ret = DoTls13KeyUpdate(ssl, input, inOutIdx, size); break; @@ -9894,6 +9953,17 @@ int wolfSSL_update_keys(WOLFSSL* ssl) if (ssl == NULL || !IsAtLeastTLSv1_3(ssl->version)) return BAD_FUNC_ARG; +#ifdef WOLFSSL_DTLS13 + /* we are already waiting for the ack of a sent key update message. We can't + send another one before receiving its ack. Either wolfSSL_update_keys() + was invoked multiple times over a short period of time or we replied to a + KeyUpdate with update request. We'll just ignore sending this + KeyUpdate. */ + /* TODO: add WOLFSSL_ERROR_ALREADY_IN_PROGRESS type of error here */ + if (ssl->options.dtls && ssl->dtls13WaitKeyUpdateAck) + return WOLFSSL_SUCCESS; +#endif /* WOLFSSL_DTLS13 */ + ret = SendTls13KeyUpdate(ssl); if (ret == WANT_WRITE) ret = WOLFSSL_ERROR_WANT_WRITE; diff --git a/wolfssl/internal.h b/wolfssl/internal.h index d40e5f120..8103dbc04 100644 --- a/wolfssl/internal.h +++ b/wolfssl/internal.h @@ -4635,6 +4635,7 @@ struct WOLFSSL { byte dtls13SendingFragments:1; byte dtls13SendingAckOrRtx:1; byte dtls13FastTimeout:1; + byte dtls13WaitKeyUpdateAck:1; word32 dtls13MessageLength; word32 dtls13FragOffset; byte dtls13FragHandshakeType;