dtls13: change encryption keys dynamically based on the epoch

In DTLSv1.3, because of retransmission and reordering, we may need to encrypt or
decrypt records with older keys. As an example, if the server finished message
is lost, the server will need to retransmit that message using handshake traffic
keys, even if he already used the traffic0 ones (as, for example, to send
NewSessionTicket just after the finished message).

This commit implements a way to save the key bound to a DTLS epoch and setting
the right key/epoch when needed.
This commit is contained in:
Marco Oliverio
2022-05-20 09:59:33 +02:00
committed by David Garske
parent de04973051
commit 2696c3cdd3
3 changed files with 282 additions and 0 deletions

View File

@@ -159,6 +159,230 @@ static int Dtls13InitChaChaCipher(RecordNumberCiphers* c, byte* key,
}
#endif /* HAVE_CHACHA */
struct Dtls13Epoch* Dtls13GetEpoch(WOLFSSL* ssl, w64wrapper epochNumber)
{
Dtls13Epoch* e;
int i;
for (i = 0; i < DTLS13_EPOCH_SIZE; ++i) {
e = &ssl->dtls13Epochs[i];
if (w64Equal(e->epochNumber, epochNumber) && e->isValid)
return e;
}
return NULL;
}
static void Dtls13EpochCopyKeys(WOLFSSL* ssl, Dtls13Epoch* e, Keys* k, int side)
{
byte clientWrite, serverWrite;
byte enc, dec;
WOLFSSL_ENTER("Dtls13SetEpochKeys");
clientWrite = serverWrite = 0;
enc = dec = 0;
switch (side) {
case ENCRYPT_SIDE_ONLY:
if (ssl->options.side == WOLFSSL_CLIENT_END)
clientWrite = 1;
if (ssl->options.side == WOLFSSL_SERVER_END)
serverWrite = 1;
enc = 1;
break;
case DECRYPT_SIDE_ONLY:
if (ssl->options.side == WOLFSSL_CLIENT_END)
serverWrite = 1;
if (ssl->options.side == WOLFSSL_SERVER_END)
clientWrite = 1;
dec = 1;
break;
case ENCRYPT_AND_DECRYPT_SIDE:
clientWrite = serverWrite = 1;
enc = dec = 1;
break;
}
if (clientWrite) {
XMEMCPY(e->client_write_key, k->client_write_key,
sizeof(e->client_write_key));
XMEMCPY(e->client_write_IV, k->client_write_IV,
sizeof(e->client_write_IV));
XMEMCPY(e->client_sn_key, k->client_sn_key, sizeof(e->client_sn_key));
}
if (serverWrite) {
XMEMCPY(e->server_write_key, k->server_write_key,
sizeof(e->server_write_key));
XMEMCPY(e->server_write_IV, k->server_write_IV,
sizeof(e->server_write_IV));
XMEMCPY(e->server_sn_key, k->server_sn_key, sizeof(e->server_sn_key));
}
if (enc)
XMEMCPY(e->aead_enc_imp_IV, k->aead_enc_imp_IV,
sizeof(e->aead_enc_imp_IV));
if (dec)
XMEMCPY(e->aead_dec_imp_IV, k->aead_dec_imp_IV,
sizeof(e->aead_dec_imp_IV));
}
static Dtls13Epoch* Dtls13NewEpochSlot(WOLFSSL* ssl)
{
Dtls13Epoch *e, *oldest = NULL;
w64wrapper oldestNumber;
int i;
/* FIXME: add max function */
oldestNumber = w64From32((word32)-1, (word32)-1);
oldest = NULL;
for (i = 0; i < DTLS13_EPOCH_SIZE; ++i) {
e = &ssl->dtls13Epochs[i];
if (!e->isValid)
return e;
if (!w64Equal(e->epochNumber, ssl->dtls13Epoch) &&
!w64Equal(e->epochNumber, ssl->dtls13PeerEpoch) &&
w64LT(e->epochNumber, oldestNumber))
oldest = e;
}
if (oldest == NULL)
return NULL;
e = oldest;
#ifdef WOLFSSL_DEBUG_TLS
WOLFSSL_MSG_EX("Delete epoch: %d", e->epochNumber);
#endif /* WOLFSSL_DEBUG_TLS */
XMEMSET(e, 0, sizeof(*e));
return e;
}
int Dtls13NewEpoch(WOLFSSL* ssl, w64wrapper epochNumber, int side)
{
Dtls13Epoch* e;
#ifdef WOLFSSL_DEBUG_TLS
WOLFSSL_MSG_EX("New epoch: %d", w64GetLow32(epochNumber));
#endif /* WOLFSSL_DEBUG_TLS */
e = Dtls13GetEpoch(ssl, epochNumber);
if (e == NULL) {
e = Dtls13NewEpochSlot(ssl);
if (e == NULL)
return BAD_STATE_E;
}
Dtls13EpochCopyKeys(ssl, e, &ssl->keys, side);
if (!e->isValid) {
/* fresh epoch, initialize fields */
e->epochNumber = epochNumber;
e->isValid = 1;
e->side = side;
}
else if (e->side != side) {
/* epoch used for the other side already. update side */
e->side = ENCRYPT_AND_DECRYPT_SIDE;
}
return 0;
}
int Dtls13SetEpochKeys(WOLFSSL* ssl, w64wrapper epochNumber,
enum encrypt_side side)
{
byte clientWrite, serverWrite;
Dtls13Epoch* e;
byte enc, dec;
WOLFSSL_ENTER("Dtls13SetEpochKeys");
clientWrite = serverWrite = 0;
enc = dec = 0;
switch (side) {
case ENCRYPT_SIDE_ONLY:
if (ssl->options.side == WOLFSSL_CLIENT_END)
clientWrite = 1;
if (ssl->options.side == WOLFSSL_SERVER_END)
serverWrite = 1;
enc = 1;
break;
case DECRYPT_SIDE_ONLY:
if (ssl->options.side == WOLFSSL_CLIENT_END)
serverWrite = 1;
if (ssl->options.side == WOLFSSL_SERVER_END)
clientWrite = 1;
dec = 1;
break;
case ENCRYPT_AND_DECRYPT_SIDE:
clientWrite = serverWrite = 1;
enc = dec = 1;
break;
}
e = Dtls13GetEpoch(ssl, epochNumber);
/* we don't have the requested key */
if (e == NULL)
return BAD_STATE_E;
if (e->side != ENCRYPT_AND_DECRYPT_SIDE && e->side != side)
return BAD_STATE_E;
if (enc)
ssl->dtls13EncryptEpoch = e;
if (dec)
ssl->dtls13DecryptEpoch = e;
/* epoch 0 has no key to copy */
if (w64IsZero(epochNumber))
return 0;
if (clientWrite) {
XMEMCPY(ssl->keys.client_write_key, e->client_write_key,
sizeof(ssl->keys.client_write_key));
XMEMCPY(ssl->keys.client_write_IV, e->client_write_IV,
sizeof(ssl->keys.client_write_IV));
XMEMCPY(ssl->keys.client_sn_key, e->client_sn_key,
sizeof(ssl->keys.client_sn_key));
}
if (serverWrite) {
XMEMCPY(ssl->keys.server_write_key, e->server_write_key,
sizeof(ssl->keys.server_write_key));
XMEMCPY(ssl->keys.server_write_IV, e->server_write_IV,
sizeof(ssl->keys.server_write_IV));
XMEMCPY(ssl->keys.server_sn_key, e->server_sn_key,
sizeof(ssl->keys.server_sn_key));
}
if (enc)
XMEMCPY(ssl->keys.aead_enc_imp_IV, e->aead_enc_imp_IV,
sizeof(ssl->keys.aead_enc_imp_IV));
if (dec)
XMEMCPY(ssl->keys.aead_dec_imp_IV, e->aead_dec_imp_IV,
sizeof(ssl->keys.aead_dec_imp_IV));
return SetKeysSide(ssl, side);
}
int Dtls13SetRecordNumberKeys(WOLFSSL* ssl, enum encrypt_side side)
{
RecordNumberCiphers* enc = NULL;

View File

@@ -6762,6 +6762,15 @@ int InitSSL(WOLFSSL* ssl, WOLFSSL_CTX* ctx, int writeDup)
}
#endif /* HAVE_SECURE_RENEGOTIATION */
#ifdef WOLFSSL_DTLS13
/* setup 0 (un-protected) epoch */
ssl->dtls13Epochs[0].isValid = 1;
ssl->dtls13Epochs[0].side = ENCRYPT_AND_DECRYPT_SIDE;
ssl->dtls13EncryptEpoch = &ssl->dtls13Epochs[0];
ssl->dtls13DecryptEpoch = &ssl->dtls13Epochs[0];
#endif /* WOLFSSL_DTLS13 */
return 0;
}

View File

@@ -4317,6 +4317,43 @@ typedef enum EarlyDataState {
} EarlyDataState;
#endif
#ifdef WOLFSSL_DTLS13
enum {
DTLS13_EPOCH_EARLYDATA = 1,
DTLS13_EPOCH_HANDSHAKE = 2,
DTLS13_EPOCH_TRAFFIC0 = 3
};
typedef struct Dtls13Epoch {
w64wrapper epochNumber;
w64wrapper nextSeqNumber;
w64wrapper nextPeerSeqNumber;
word32 window[WOLFSSL_DTLS_WINDOW_WORDS];
/* key material for the epoch */
byte client_write_key[MAX_SYM_KEY_SIZE];
byte server_write_key[MAX_SYM_KEY_SIZE];
byte client_write_IV[MAX_WRITE_IV_SZ];
byte server_write_IV[MAX_WRITE_IV_SZ];
byte aead_exp_IV[AEAD_MAX_EXP_SZ];
byte aead_enc_imp_IV[AEAD_MAX_IMP_SZ];
byte aead_dec_imp_IV[AEAD_MAX_IMP_SZ];
byte client_sn_key[MAX_SYM_KEY_SIZE];
byte server_sn_key[MAX_SYM_KEY_SIZE];
byte isValid;
byte side;
} Dtls13Epoch;
#define DTLS13_EPOCH_SIZE 3
#endif /* WOLFSSL_DTLS13 */
/* wolfSSL ssl type */
struct WOLFSSL {
WOLFSSL_CTX* ctx;
@@ -4514,6 +4551,12 @@ struct WOLFSSL {
#ifdef WOLFSSL_DTLS13
RecordNumberCiphers dtlsRecordNumberEncrypt;
RecordNumberCiphers dtlsRecordNumberDecrypt;
Dtls13Epoch dtls13Epochs[DTLS13_EPOCH_SIZE];
Dtls13Epoch *dtls13EncryptEpoch;
Dtls13Epoch *dtls13DecryptEpoch;
w64wrapper dtls13Epoch;
w64wrapper dtls13PeerEpoch;
#endif /* WOLFSSL_DTLS13 */
#endif /* WOLFSSL_DTLS */
@@ -5260,6 +5303,12 @@ WOLFSSL_LOCAL word32 nid2oid(int nid, int grp);
#ifdef WOLFSSL_DTLS13
WOLFSSL_LOCAL struct Dtls13Epoch* Dtls13GetEpoch(WOLFSSL* ssl,
w64wrapper epochNumber);
WOLFSSL_LOCAL int Dtls13NewEpoch(WOLFSSL* ssl, w64wrapper epochNumber,
int side);
WOLFSSL_LOCAL int Dtls13SetEpochKeys(WOLFSSL* ssl, w64wrapper epochNumber,
enum encrypt_side side);
WOLFSSL_LOCAL int Dtls13DeriveSnKeys(WOLFSSL* ssl, int provision);
WOLFSSL_LOCAL int Dtls13SetRecordNumberKeys(WOLFSSL* ssl,
enum encrypt_side side);