diff --git a/src/dtls13.c b/src/dtls13.c index 66ef38eb2..c0b825b19 100644 --- a/src/dtls13.c +++ b/src/dtls13.c @@ -159,6 +159,230 @@ static int Dtls13InitChaChaCipher(RecordNumberCiphers* c, byte* key, } #endif /* HAVE_CHACHA */ +struct Dtls13Epoch* Dtls13GetEpoch(WOLFSSL* ssl, w64wrapper epochNumber) +{ + Dtls13Epoch* e; + int i; + + for (i = 0; i < DTLS13_EPOCH_SIZE; ++i) { + e = &ssl->dtls13Epochs[i]; + if (w64Equal(e->epochNumber, epochNumber) && e->isValid) + return e; + } + + return NULL; +} + +static void Dtls13EpochCopyKeys(WOLFSSL* ssl, Dtls13Epoch* e, Keys* k, int side) +{ + byte clientWrite, serverWrite; + byte enc, dec; + + WOLFSSL_ENTER("Dtls13SetEpochKeys"); + + clientWrite = serverWrite = 0; + enc = dec = 0; + switch (side) { + + case ENCRYPT_SIDE_ONLY: + if (ssl->options.side == WOLFSSL_CLIENT_END) + clientWrite = 1; + if (ssl->options.side == WOLFSSL_SERVER_END) + serverWrite = 1; + enc = 1; + break; + + case DECRYPT_SIDE_ONLY: + if (ssl->options.side == WOLFSSL_CLIENT_END) + serverWrite = 1; + if (ssl->options.side == WOLFSSL_SERVER_END) + clientWrite = 1; + dec = 1; + break; + + case ENCRYPT_AND_DECRYPT_SIDE: + clientWrite = serverWrite = 1; + enc = dec = 1; + break; + } + + if (clientWrite) { + XMEMCPY(e->client_write_key, k->client_write_key, + sizeof(e->client_write_key)); + + XMEMCPY(e->client_write_IV, k->client_write_IV, + sizeof(e->client_write_IV)); + + XMEMCPY(e->client_sn_key, k->client_sn_key, sizeof(e->client_sn_key)); + } + + if (serverWrite) { + XMEMCPY(e->server_write_key, k->server_write_key, + sizeof(e->server_write_key)); + XMEMCPY(e->server_write_IV, k->server_write_IV, + sizeof(e->server_write_IV)); + XMEMCPY(e->server_sn_key, k->server_sn_key, sizeof(e->server_sn_key)); + } + + if (enc) + XMEMCPY(e->aead_enc_imp_IV, k->aead_enc_imp_IV, + sizeof(e->aead_enc_imp_IV)); + + if (dec) + XMEMCPY(e->aead_dec_imp_IV, k->aead_dec_imp_IV, + sizeof(e->aead_dec_imp_IV)); +} + +static Dtls13Epoch* Dtls13NewEpochSlot(WOLFSSL* ssl) +{ + Dtls13Epoch *e, *oldest = NULL; + w64wrapper oldestNumber; + int i; + + /* FIXME: add max function */ + oldestNumber = w64From32((word32)-1, (word32)-1); + oldest = NULL; + + for (i = 0; i < DTLS13_EPOCH_SIZE; ++i) { + e = &ssl->dtls13Epochs[i]; + if (!e->isValid) + return e; + + if (!w64Equal(e->epochNumber, ssl->dtls13Epoch) && + !w64Equal(e->epochNumber, ssl->dtls13PeerEpoch) && + w64LT(e->epochNumber, oldestNumber)) + oldest = e; + } + + if (oldest == NULL) + return NULL; + + e = oldest; + +#ifdef WOLFSSL_DEBUG_TLS + WOLFSSL_MSG_EX("Delete epoch: %d", e->epochNumber); +#endif /* WOLFSSL_DEBUG_TLS */ + + XMEMSET(e, 0, sizeof(*e)); + + return e; +} + +int Dtls13NewEpoch(WOLFSSL* ssl, w64wrapper epochNumber, int side) +{ + Dtls13Epoch* e; + +#ifdef WOLFSSL_DEBUG_TLS + WOLFSSL_MSG_EX("New epoch: %d", w64GetLow32(epochNumber)); +#endif /* WOLFSSL_DEBUG_TLS */ + + e = Dtls13GetEpoch(ssl, epochNumber); + if (e == NULL) { + e = Dtls13NewEpochSlot(ssl); + if (e == NULL) + return BAD_STATE_E; + } + + Dtls13EpochCopyKeys(ssl, e, &ssl->keys, side); + + if (!e->isValid) { + /* fresh epoch, initialize fields */ + e->epochNumber = epochNumber; + e->isValid = 1; + e->side = side; + } + else if (e->side != side) { + /* epoch used for the other side already. update side */ + e->side = ENCRYPT_AND_DECRYPT_SIDE; + } + + return 0; +} + +int Dtls13SetEpochKeys(WOLFSSL* ssl, w64wrapper epochNumber, + enum encrypt_side side) +{ + byte clientWrite, serverWrite; + Dtls13Epoch* e; + byte enc, dec; + + WOLFSSL_ENTER("Dtls13SetEpochKeys"); + + clientWrite = serverWrite = 0; + enc = dec = 0; + switch (side) { + + case ENCRYPT_SIDE_ONLY: + if (ssl->options.side == WOLFSSL_CLIENT_END) + clientWrite = 1; + if (ssl->options.side == WOLFSSL_SERVER_END) + serverWrite = 1; + enc = 1; + break; + + case DECRYPT_SIDE_ONLY: + if (ssl->options.side == WOLFSSL_CLIENT_END) + serverWrite = 1; + if (ssl->options.side == WOLFSSL_SERVER_END) + clientWrite = 1; + dec = 1; + break; + + case ENCRYPT_AND_DECRYPT_SIDE: + clientWrite = serverWrite = 1; + enc = dec = 1; + break; + } + + e = Dtls13GetEpoch(ssl, epochNumber); + /* we don't have the requested key */ + if (e == NULL) + return BAD_STATE_E; + + if (e->side != ENCRYPT_AND_DECRYPT_SIDE && e->side != side) + return BAD_STATE_E; + + if (enc) + ssl->dtls13EncryptEpoch = e; + if (dec) + ssl->dtls13DecryptEpoch = e; + + /* epoch 0 has no key to copy */ + if (w64IsZero(epochNumber)) + return 0; + + if (clientWrite) { + XMEMCPY(ssl->keys.client_write_key, e->client_write_key, + sizeof(ssl->keys.client_write_key)); + + XMEMCPY(ssl->keys.client_write_IV, e->client_write_IV, + sizeof(ssl->keys.client_write_IV)); + + XMEMCPY(ssl->keys.client_sn_key, e->client_sn_key, + sizeof(ssl->keys.client_sn_key)); + } + + if (serverWrite) { + XMEMCPY(ssl->keys.server_write_key, e->server_write_key, + sizeof(ssl->keys.server_write_key)); + + XMEMCPY(ssl->keys.server_write_IV, e->server_write_IV, + sizeof(ssl->keys.server_write_IV)); + + XMEMCPY(ssl->keys.server_sn_key, e->server_sn_key, + sizeof(ssl->keys.server_sn_key)); + } + + if (enc) + XMEMCPY(ssl->keys.aead_enc_imp_IV, e->aead_enc_imp_IV, + sizeof(ssl->keys.aead_enc_imp_IV)); + if (dec) + XMEMCPY(ssl->keys.aead_dec_imp_IV, e->aead_dec_imp_IV, + sizeof(ssl->keys.aead_dec_imp_IV)); + + return SetKeysSide(ssl, side); +} + int Dtls13SetRecordNumberKeys(WOLFSSL* ssl, enum encrypt_side side) { RecordNumberCiphers* enc = NULL; diff --git a/src/internal.c b/src/internal.c index ad3c099b5..692307a93 100644 --- a/src/internal.c +++ b/src/internal.c @@ -6762,6 +6762,15 @@ int InitSSL(WOLFSSL* ssl, WOLFSSL_CTX* ctx, int writeDup) } #endif /* HAVE_SECURE_RENEGOTIATION */ + +#ifdef WOLFSSL_DTLS13 + /* setup 0 (un-protected) epoch */ + ssl->dtls13Epochs[0].isValid = 1; + ssl->dtls13Epochs[0].side = ENCRYPT_AND_DECRYPT_SIDE; + ssl->dtls13EncryptEpoch = &ssl->dtls13Epochs[0]; + ssl->dtls13DecryptEpoch = &ssl->dtls13Epochs[0]; +#endif /* WOLFSSL_DTLS13 */ + return 0; } diff --git a/wolfssl/internal.h b/wolfssl/internal.h index c192ae257..9b238babc 100644 --- a/wolfssl/internal.h +++ b/wolfssl/internal.h @@ -4317,6 +4317,43 @@ typedef enum EarlyDataState { } EarlyDataState; #endif +#ifdef WOLFSSL_DTLS13 + +enum { + DTLS13_EPOCH_EARLYDATA = 1, + DTLS13_EPOCH_HANDSHAKE = 2, + DTLS13_EPOCH_TRAFFIC0 = 3 +}; + +typedef struct Dtls13Epoch { + w64wrapper epochNumber; + + w64wrapper nextSeqNumber; + w64wrapper nextPeerSeqNumber; + + word32 window[WOLFSSL_DTLS_WINDOW_WORDS]; + + /* key material for the epoch */ + byte client_write_key[MAX_SYM_KEY_SIZE]; + byte server_write_key[MAX_SYM_KEY_SIZE]; + byte client_write_IV[MAX_WRITE_IV_SZ]; + byte server_write_IV[MAX_WRITE_IV_SZ]; + + byte aead_exp_IV[AEAD_MAX_EXP_SZ]; + byte aead_enc_imp_IV[AEAD_MAX_IMP_SZ]; + byte aead_dec_imp_IV[AEAD_MAX_IMP_SZ]; + + byte client_sn_key[MAX_SYM_KEY_SIZE]; + byte server_sn_key[MAX_SYM_KEY_SIZE]; + + byte isValid; + byte side; +} Dtls13Epoch; + +#define DTLS13_EPOCH_SIZE 3 + +#endif /* WOLFSSL_DTLS13 */ + /* wolfSSL ssl type */ struct WOLFSSL { WOLFSSL_CTX* ctx; @@ -4514,6 +4551,12 @@ struct WOLFSSL { #ifdef WOLFSSL_DTLS13 RecordNumberCiphers dtlsRecordNumberEncrypt; RecordNumberCiphers dtlsRecordNumberDecrypt; + Dtls13Epoch dtls13Epochs[DTLS13_EPOCH_SIZE]; + Dtls13Epoch *dtls13EncryptEpoch; + Dtls13Epoch *dtls13DecryptEpoch; + w64wrapper dtls13Epoch; + w64wrapper dtls13PeerEpoch; + #endif /* WOLFSSL_DTLS13 */ #endif /* WOLFSSL_DTLS */ @@ -5260,6 +5303,12 @@ WOLFSSL_LOCAL word32 nid2oid(int nid, int grp); #ifdef WOLFSSL_DTLS13 +WOLFSSL_LOCAL struct Dtls13Epoch* Dtls13GetEpoch(WOLFSSL* ssl, + w64wrapper epochNumber); +WOLFSSL_LOCAL int Dtls13NewEpoch(WOLFSSL* ssl, w64wrapper epochNumber, + int side); +WOLFSSL_LOCAL int Dtls13SetEpochKeys(WOLFSSL* ssl, w64wrapper epochNumber, + enum encrypt_side side); WOLFSSL_LOCAL int Dtls13DeriveSnKeys(WOLFSSL* ssl, int provision); WOLFSSL_LOCAL int Dtls13SetRecordNumberKeys(WOLFSSL* ssl, enum encrypt_side side);