diff --git a/src/internal.c b/src/internal.c index ae8d40bd3..416a86c6b 100755 --- a/src/internal.c +++ b/src/internal.c @@ -121,8 +121,8 @@ WOLFSSL_CALLBACKS needs LARGE_STATIC_BUFFERS, please add LARGE_STATIC_BUFFERS #ifdef WOLFSSL_DTLS - static INLINE int DtlsCheckWindow(DtlsState* state); - static INLINE int DtlsUpdateWindow(DtlsState* state); + static INLINE int DtlsCheckWindow(WOLFSSL* ssl); + static INLINE int DtlsUpdateWindow(WOLFSSL* ssl); #endif @@ -187,7 +187,7 @@ static INLINE int IsEncryptionOn(WOLFSSL* ssl, int isSend) #ifdef WOLFSSL_DTLS /* For DTLS, epoch 0 is always not encrypted. */ - if (ssl->options.dtls && !isSend && ssl->keys.dtls_state.curEpoch == 0) + if (ssl->options.dtls && !isSend && ssl->keys.curEpoch == 0) return 0; #endif /* WOLFSSL_DTLS */ @@ -336,21 +336,6 @@ void c32to24(word32 in, word24 out) } -#ifdef WOLFSSL_DTLS - -static INLINE void c32to48(word32 in, byte out[6]) -{ - out[0] = 0; - out[1] = 0; - out[2] = (in >> 24) & 0xff; - out[3] = (in >> 16) & 0xff; - out[4] = (in >> 8) & 0xff; - out[5] = in & 0xff; -} - -#endif /* WOLFSSL_DTLS */ - - /* convert 16 bit integer to opaque */ static INLINE void c16toa(word16 u16, byte* c) { @@ -581,37 +566,43 @@ static int ExportKeyState(WOLFSSL* ssl, byte* exp, word32 len, byte ver) XMEMSET(exp, 0, DTLS_EXPORT_KEY_SZ); - c32toa(keys->peer_sequence_number, exp + idx); idx += OPAQUE32_LEN; - c32toa(keys->peer_sequence_number, exp + idx); idx += OPAQUE32_LEN; - c32toa(keys->sequence_number, exp + idx); idx += OPAQUE32_LEN; + c32toa(keys->peer_sequence_number_hi, exp + idx); idx += OPAQUE32_LEN; + c32toa(keys->peer_sequence_number_lo, exp + idx); idx += OPAQUE32_LEN; + c32toa(keys->sequence_number_hi, exp + idx); idx += OPAQUE32_LEN; + c32toa(keys->sequence_number_lo, exp + idx); idx += OPAQUE32_LEN; - c16toa(keys->dtls_state.nextEpoch, exp + idx); idx += OPAQUE16_LEN; - c32toa(keys->dtls_state.nextSeq, exp + idx); idx += OPAQUE32_LEN; - c16toa(keys->dtls_state.curEpoch, exp + idx); idx += OPAQUE16_LEN; - c32toa(keys->dtls_state.curSeq, exp + idx); idx += OPAQUE32_LEN; - c32toa(keys->dtls_state.prevSeq, exp + idx); idx += OPAQUE32_LEN; + c16toa(keys->nextEpoch, exp + idx); idx += OPAQUE16_LEN; + c16toa(keys->nextSeq_hi, exp + idx); idx += OPAQUE16_LEN; + c32toa(keys->nextSeq_lo, exp + idx); idx += OPAQUE32_LEN; + c16toa(keys->curEpoch, exp + idx); idx += OPAQUE16_LEN; + c16toa(keys->curSeq_hi, exp + idx); idx += OPAQUE16_LEN; + c32toa(keys->curSeq_lo, exp + idx); idx += OPAQUE32_LEN; + c16toa(keys->prevSeq_hi, exp + idx); idx += OPAQUE16_LEN; + c32toa(keys->prevSeq_lo, exp + idx); idx += OPAQUE32_LEN; c16toa(keys->dtls_peer_handshake_number, exp + idx); idx += OPAQUE16_LEN; c16toa(keys->dtls_expected_peer_handshake_number, exp + idx); idx += OPAQUE16_LEN; - c32toa(keys->dtls_sequence_number, exp + idx); idx += OPAQUE32_LEN; - c32toa(keys->dtls_prev_sequence_number, exp + idx); idx += OPAQUE32_LEN; - c16toa(keys->dtls_epoch, exp + idx); idx += OPAQUE16_LEN; - c16toa(keys->dtls_handshake_number, exp + idx); idx += OPAQUE16_LEN; - c32toa(keys->encryptSz, exp + idx); idx += OPAQUE32_LEN; - c32toa(keys->padSz, exp + idx); idx += OPAQUE32_LEN; + c16toa(keys->dtls_sequence_number_hi, exp + idx); idx += OPAQUE16_LEN; + c32toa(keys->dtls_sequence_number_lo, exp + idx); idx += OPAQUE32_LEN; + c16toa(keys->dtls_prev_sequence_number_hi, exp + idx); idx += OPAQUE16_LEN; + c32toa(keys->dtls_prev_sequence_number_lo, exp + idx); idx += OPAQUE32_LEN; + c16toa(keys->dtls_epoch, exp + idx); idx += OPAQUE16_LEN; + c16toa(keys->dtls_handshake_number, exp + idx); idx += OPAQUE16_LEN; + c32toa(keys->encryptSz, exp + idx); idx += OPAQUE32_LEN; + c32toa(keys->padSz, exp + idx); idx += OPAQUE32_LEN; exp[idx++] = keys->encryptionOn; exp[idx++] = keys->decryptedCur; #ifdef WORD64_AVAILABLE - c64toa(keys->dtls_state.window, exp + idx); idx += OPAQUE64_LEN; - c64toa(keys->dtls_state.prevWindow, exp + idx); idx += OPAQUE64_LEN; + c64toa(keys->window, exp + idx); idx += OPAQUE64_LEN; + c64toa(keys->prevWindow, exp + idx); idx += OPAQUE64_LEN; #else - c32toa(keys->dtls_state.window, exp + idx); idx += OPAQUE32_LEN; - c32toa(0, exp + idx); idx += OPAQUE32_LEN; - c32toa(keys->dtls_state.prevWindow, exp + idx); idx += OPAQUE32_LEN; - c32toa(0, exp + idx); idx += OPAQUE32_LEN; + c32toa(keys->window, exp + idx); idx += OPAQUE32_LEN; + c32toa(0, exp + idx); idx += OPAQUE32_LEN; + c32toa(keys->prevWindow, exp + idx); idx += OPAQUE32_LEN; + c32toa(0, exp + idx); idx += OPAQUE32_LEN; #endif #ifdef HAVE_TRUNCATED_HMAC @@ -706,37 +697,43 @@ static int ImportKeyState(WOLFSSL* ssl, byte* exp, word32 len, byte ver) if (len < DTLS_EXPORT_MIN_KEY_SZ) { return BUFFER_E; } - ato32(exp + idx, &keys->peer_sequence_number); idx += OPAQUE32_LEN; - ato32(exp + idx, &keys->peer_sequence_number); idx += OPAQUE32_LEN; - ato32(exp + idx, &keys->sequence_number); idx += OPAQUE32_LEN; + ato32(exp + idx, &keys->peer_sequence_number_hi); idx += OPAQUE32_LEN; + ato32(exp + idx, &keys->peer_sequence_number_lo); idx += OPAQUE32_LEN; + ato32(exp + idx, &keys->sequence_number_hi); idx += OPAQUE32_LEN; + ato32(exp + idx, &keys->sequence_number_lo); idx += OPAQUE32_LEN; - ato16(exp + idx, &keys->dtls_state.nextEpoch); idx += OPAQUE16_LEN; - ato32(exp + idx, &keys->dtls_state.nextSeq); idx += OPAQUE32_LEN; - ato16(exp + idx, &keys->dtls_state.curEpoch); idx += OPAQUE16_LEN; - ato32(exp + idx, &keys->dtls_state.curSeq); idx += OPAQUE32_LEN; - ato32(exp + idx, &keys->dtls_state.prevSeq); idx += OPAQUE32_LEN; + ato16(exp + idx, &keys->nextEpoch); idx += OPAQUE16_LEN; + ato16(exp + idx, &keys->nextSeq_hi); idx += OPAQUE16_LEN; + ato32(exp + idx, &keys->nextSeq_lo); idx += OPAQUE32_LEN; + ato16(exp + idx, &keys->curEpoch); idx += OPAQUE16_LEN; + ato16(exp + idx, &keys->curSeq_hi); idx += OPAQUE16_LEN; + ato32(exp + idx, &keys->curSeq_lo); idx += OPAQUE32_LEN; + ato16(exp + idx, &keys->prevSeq_hi); idx += OPAQUE16_LEN; + ato32(exp + idx, &keys->prevSeq_lo); idx += OPAQUE32_LEN; ato16(exp + idx, &keys->dtls_peer_handshake_number); idx += OPAQUE16_LEN; ato16(exp + idx, &keys->dtls_expected_peer_handshake_number); idx += OPAQUE16_LEN; - ato32(exp + idx, &keys->dtls_sequence_number); idx += OPAQUE32_LEN; - ato32(exp + idx, &keys->dtls_prev_sequence_number); idx += OPAQUE32_LEN; - ato16(exp + idx, &keys->dtls_epoch); idx += OPAQUE16_LEN; - ato16(exp + idx, &keys->dtls_handshake_number); idx += OPAQUE16_LEN; - ato32(exp + idx, &keys->encryptSz); idx += OPAQUE32_LEN; - ato32(exp + idx, &keys->padSz); idx += OPAQUE32_LEN; + ato16(exp + idx, &keys->dtls_sequence_number_hi); idx += OPAQUE16_LEN; + ato32(exp + idx, &keys->dtls_sequence_number_lo); idx += OPAQUE32_LEN; + ato16(exp + idx, &keys->dtls_prev_sequence_number_hi); idx += OPAQUE16_LEN; + ato32(exp + idx, &keys->dtls_prev_sequence_number_lo); idx += OPAQUE32_LEN; + ato16(exp + idx, &keys->dtls_epoch); idx += OPAQUE16_LEN; + ato16(exp + idx, &keys->dtls_handshake_number); idx += OPAQUE16_LEN; + ato32(exp + idx, &keys->encryptSz); idx += OPAQUE32_LEN; + ato32(exp + idx, &keys->padSz); idx += OPAQUE32_LEN; keys->encryptionOn = exp[idx++]; keys->decryptedCur = exp[idx++]; #ifdef WORD64_AVAILABLE - ato64(exp + idx, &keys->dtls_state.window); idx += OPAQUE64_LEN; - ato64(exp + idx, &keys->dtls_state.prevWindow); idx += OPAQUE64_LEN; + ato64(exp + idx, &keys->window); idx += OPAQUE64_LEN; + ato64(exp + idx, &keys->prevWindow); idx += OPAQUE64_LEN; #else - ato32(exp + idx, &keys->dtls_state.window); idx += OPAQUE32_LEN; - ato32(exp + idx, 0); idx += OPAQUE32_LEN; - ato32(exp + idx, &keys->dtls_state.prevWindow); idx += OPAQUE32_LEN; - ato32(exp + idx, 0); idx += OPAQUE32_LEN; + ato32(exp + idx, &keys->window); idx += OPAQUE32_LEN; + ato32(exp + idx, 0); idx += OPAQUE32_LEN; + ato32(exp + idx, &keys->prevWindow); idx += OPAQUE32_LEN; + ato32(exp + idx, 0); idx += OPAQUE32_LEN; #endif #ifdef HAVE_TRUNCATED_HMAC @@ -1044,7 +1041,7 @@ static int ExportPeerInfo(WOLFSSL* ssl, byte* exp, word32 len, byte ver) return SOCKET_ERROR_E; } - c16toa((word16)fam, exp + idx); idx += DTLS_EXPORT_LEN; + c16toa((word16)fam, exp + idx); idx += DTLS_EXPORT_LEN; c16toa((word16)ipSz, exp + idx); idx += DTLS_EXPORT_LEN; XMEMCPY(exp + idx, ip, ipSz); idx += ipSz; c16toa(port, exp + idx); idx += DTLS_EXPORT_LEN; @@ -3939,6 +3936,98 @@ void FreeSSL(WOLFSSL* ssl, void* heap) } +#if !defined(NO_OLD_TLS) || defined(HAVE_CHACHA) || defined(HAVE_AESCCM) \ + || defined(HAVE_AESGCM) +static INLINE void GetSEQIncrement(WOLFSSL* ssl, int verify, word32 seq[2]) +{ + if (verify) { + seq[0] = ssl->keys.peer_sequence_number_hi; + seq[1] = ssl->keys.peer_sequence_number_lo++; + if (seq[1] > ssl->keys.peer_sequence_number_lo) { + /* handle rollover */ + ssl->keys.peer_sequence_number_hi++; + } + } + else { + seq[0] = ssl->keys.sequence_number_hi; + seq[1] = ssl->keys.sequence_number_lo++; + if (seq[1] > ssl->keys.sequence_number_lo) { + /* handle rollover */ + ssl->keys.sequence_number_hi++; + } + } +} + + +#ifdef WOLFSSL_DTLS +static INLINE void DtlsGetSEQ(WOLFSSL* ssl, int order, word32 seq[2]) +{ + if (order == PREV_ORDER) { + /* Previous epoch case */ + seq[0] = ((ssl->keys.dtls_epoch - 1) << 16) | + (ssl->keys.dtls_prev_sequence_number_hi & 0xFFFF); + seq[1] = ssl->keys.dtls_prev_sequence_number_lo; + } + else if (order == PEER_ORDER) { + seq[0] = (ssl->keys.curEpoch << 16) | + (ssl->keys.curSeq_hi & 0xFFFF); + seq[1] = ssl->keys.curSeq_lo; /* explicit from peer */ + } + else { + seq[0] = (ssl->keys.dtls_epoch << 16) | + (ssl->keys.dtls_sequence_number_hi & 0xFFFF); + seq[1] = ssl->keys.dtls_sequence_number_lo; + } +} + +static INLINE void DtlsSEQIncrement(WOLFSSL* ssl, int order) +{ + word32 seq; + + if (order == PREV_ORDER) { + seq = ssl->keys.dtls_prev_sequence_number_lo++; + if (seq > ssl->keys.dtls_prev_sequence_number_lo) { + /* handle rollover */ + ssl->keys.dtls_prev_sequence_number_hi++; + } + } + else if (order == PEER_ORDER) { + seq = ssl->keys.peer_sequence_number_lo++; + if (seq > ssl->keys.peer_sequence_number_lo) { + /* handle rollover */ + ssl->keys.peer_sequence_number_hi++; + } + } + else { + seq = ssl->keys.dtls_sequence_number_lo++; + if (seq > ssl->keys.dtls_sequence_number_lo) { + /* handle rollover */ + ssl->keys.dtls_sequence_number_hi++; + } + } +} +#endif /* WOLFSSL_DTLS */ + + +static INLINE void WriteSEQ(WOLFSSL* ssl, int verifyOrder, byte* out) +{ + word32 seq[2] = {0, 0}; + + if (!ssl->options.dtls) { + GetSEQIncrement(ssl, verifyOrder, seq); + } + else { +#ifdef WOLFSSL_DTLS + DtlsGetSEQ(ssl, verifyOrder, seq); +#endif + } + + c32toa(seq[0], out); + c32toa(seq[1], out + OPAQUE32_LEN); +} +#endif + + #ifdef WOLFSSL_DTLS int DtlsPoolInit(WOLFSSL* ssl) @@ -4039,13 +4128,19 @@ int DtlsPoolSend(WOLFSSL* ssl) for (i = 0, buf = pool->buf; i < pool->used; i++, buf++) { if (pool->epoch[i] == 0) { DtlsRecordLayerHeader* dtls; - word32* seqNumber; + int epochOrder; dtls = (DtlsRecordLayerHeader*)buf->buffer; - seqNumber = (ssl->keys.dtls_epoch == 0) ? - &ssl->keys.dtls_sequence_number : - &ssl->keys.dtls_prev_sequence_number; - c32to48((*seqNumber)++, dtls->sequence_number); + /* If the stored record's epoch is 0, and the currently set + * epoch is 0, use the "current order" sequence number. + * If the stored record's epoch is 0 and the currently set + * epoch is not 0, the stored record is considered a "previous + * order" sequence number. */ + epochOrder = (ssl->keys.dtls_epoch == 0) ? + CUR_ORDER : PREV_ORDER; + + WriteSEQ(ssl, epochOrder, dtls->sequence_number); + DtlsSEQIncrement(ssl, epochOrder); if ((ret = CheckAvailableSize(ssl, buf->length)) != 0) return ret; @@ -4711,8 +4806,7 @@ static void AddRecordHeader(byte* output, word32 length, byte type, WOLFSSL* ssl /* dtls record layer header extensions */ dtls = (DtlsRecordLayerHeader*)output; - c16toa(ssl->keys.dtls_epoch, dtls->epoch); - c32to48(ssl->keys.dtls_sequence_number++, dtls->sequence_number); + WriteSEQ(ssl, 0, dtls->sequence_number); c16toa((word16)length, dtls->length); #endif } @@ -5102,10 +5196,12 @@ static int GetRecordHeader(WOLFSSL* ssl, const byte* input, word32* inOutIdx, /* type and version in same sport */ XMEMCPY(rh, input + *inOutIdx, ENUM_LEN + VERSION_SZ); *inOutIdx += ENUM_LEN + VERSION_SZ; - ato16(input + *inOutIdx, &ssl->keys.dtls_state.curEpoch); - *inOutIdx += 4; /* advance past epoch, skip first 2 seq bytes for now */ - ato32(input + *inOutIdx, &ssl->keys.dtls_state.curSeq); - *inOutIdx += 4; /* advance past rest of seq */ + ato16(input + *inOutIdx, &ssl->keys.curEpoch); + *inOutIdx += OPAQUE16_LEN; + ato16(input + *inOutIdx, &ssl->keys.curSeq_hi); + *inOutIdx += OPAQUE16_LEN; + ato32(input + *inOutIdx, &ssl->keys.curSeq_lo); + *inOutIdx += OPAQUE32_LEN; /* advance past rest of seq */ ato16(input + *inOutIdx, size); *inOutIdx += LENGTH_SZ; #endif @@ -5113,8 +5209,8 @@ static int GetRecordHeader(WOLFSSL* ssl, const byte* input, word32* inOutIdx, #ifdef WOLFSSL_DTLS if (IsDtlsNotSctpMode(ssl) && - (!DtlsCheckWindow(&ssl->keys.dtls_state) || - (ssl->options.handShakeDone && ssl->keys.dtls_state.curEpoch == 0))) { + (!DtlsCheckWindow(ssl) || + (ssl->options.handShakeDone && ssl->keys.curEpoch == 0))) { return SEQUENCE_ERROR; } #endif @@ -7697,33 +7793,61 @@ static int DoHandShakeMsg(WOLFSSL* ssl, byte* input, word32* inOutIdx, #ifdef WOLFSSL_DTLS -static INLINE int DtlsCheckWindow(DtlsState* state) +static INLINE int DtlsCheckWindow(WOLFSSL* ssl) { - word32 cur; - word32 next; DtlsSeq window; + word16 cur_hi, next_hi; + word32 cur_lo, next_lo, diff; + int curLT; - if (state->curEpoch == state->nextEpoch) { - next = state->nextSeq; - window = state->window; + if (ssl->keys.curEpoch == ssl->keys.nextEpoch) { + next_hi = ssl->keys.nextSeq_hi; + next_lo = ssl->keys.nextSeq_lo; + window = ssl->keys.window; } - else if (state->curEpoch == state->nextEpoch - 1) { - next = state->prevSeq; - window = state->prevWindow; + else if (ssl->keys.curEpoch == ssl->keys.nextEpoch - 1) { + next_hi = ssl->keys.prevSeq_hi; + next_lo = ssl->keys.prevSeq_lo; + window = ssl->keys.prevWindow; } else { return 0; } - cur = state->curSeq; + cur_hi = ssl->keys.curSeq_hi; + cur_lo = ssl->keys.curSeq_lo; - if ((next > DTLS_SEQ_BITS) && (cur < next - DTLS_SEQ_BITS)) { + /* If the difference between next and cur is > 2^32, way outside window. */ + if ((cur_hi > next_hi + 1) || (next_hi > cur_hi + 1)) { + WOLFSSL_MSG("Current record from way too far in the future."); return 0; } - else if ((cur < next) && (window & ((DtlsSeq)1 << (next - cur - 1)))) { + + if (cur_hi == next_hi) { + curLT = cur_lo < next_lo; + diff = curLT ? next_lo - cur_lo : cur_lo - next_lo; + } + else { + curLT = cur_hi < next_hi; + diff = curLT ? cur_lo - next_lo : next_lo - cur_lo; + } + + /* Check to see that the next value is greater than the number of messages + * trackable in the window (32 or 64), and that the difference between the + * next expected sequence number and the received sequence number is + * inside the window. */ + if ((next_hi || next_lo > DTLS_SEQ_BITS) && + curLT && (diff > DTLS_SEQ_BITS)) { + + WOLFSSL_MSG("Current record sequence number from the past."); return 0; } - else if (cur > next + DTLS_SEQ_BITS) { + else if (curLT && (window & ((DtlsSeq)1 << (diff - 1)))) { + WOLFSSL_MSG("Current record sequence number already received."); + return 0; + } + else if (!curLT && (diff > DTLS_SEQ_BITS)) { + WOLFSSL_MSG("Rejecting message too far into the future."); return 0; } @@ -7731,30 +7855,50 @@ static INLINE int DtlsCheckWindow(DtlsState* state) } -static INLINE int DtlsUpdateWindow(DtlsState* state) +static INLINE int DtlsUpdateWindow(WOLFSSL* ssl) { - word32 cur; - word32* next; DtlsSeq* window; + word32* next_lo; + word16* next_hi; + int curLT; + word32 cur_lo, diff; + word16 cur_hi; - if (state->curEpoch == state->nextEpoch) { - next = &state->nextSeq; - window = &state->window; + if (ssl->keys.curEpoch == ssl->keys.nextEpoch) { + next_hi = &ssl->keys.nextSeq_hi; + next_lo = &ssl->keys.nextSeq_lo; + window = &ssl->keys.window; } else { - next = &state->prevSeq; - window = &state->prevWindow; + next_hi = &ssl->keys.prevSeq_hi; + next_lo = &ssl->keys.prevSeq_lo; + window = &ssl->keys.prevWindow; } - cur = state->curSeq; + cur_hi = ssl->keys.curSeq_hi; + cur_lo = ssl->keys.curSeq_lo; - if (cur < *next) { - *window |= ((DtlsSeq)1 << (*next - cur - 1)); + if (cur_hi == *next_hi) { + curLT = cur_lo < *next_lo; + diff = curLT ? *next_lo - cur_lo : cur_lo - *next_lo; } else { - *window <<= (1 + cur - *next); + curLT = cur_hi < *next_hi; + diff = curLT ? cur_lo - *next_lo : *next_lo - cur_lo; + } + + if (curLT) { + *window |= ((DtlsSeq)1 << (diff - 1)); + } + else { + if (diff >= DTLS_SEQ_BITS) + *window = 0; + else + *window <<= (1 + diff); *window |= 1; - *next = cur + 1; + *next_lo = cur_lo + 1; + if (*next_lo < cur_lo) + (*next_hi)++; } return 1; @@ -7881,26 +8025,6 @@ static int DoDtlsHandShakeMsg(WOLFSSL* ssl, byte* input, word32* inOutIdx, #endif -#if !defined(NO_OLD_TLS) || defined(HAVE_CHACHA) || defined(HAVE_AESCCM) \ - || defined(HAVE_AESGCM) -static INLINE word32 GetSEQIncrement(WOLFSSL* ssl, int verify) -{ -#ifdef WOLFSSL_DTLS - if (ssl->options.dtls) { - if (verify) - return ssl->keys.dtls_state.curSeq; /* explicit from peer */ - else - return ssl->keys.dtls_sequence_number - 1; /* already incremented */ - } -#endif - if (verify) - return ssl->keys.peer_sequence_number++; - else - return ssl->keys.sequence_number++; -} -#endif - - #ifdef HAVE_AEAD static INLINE void AeadIncrementExpIV(WOLFSSL* ssl) { @@ -7986,18 +8110,18 @@ static int ChachaAEADEncrypt(WOLFSSL* ssl, byte* out, const byte* input, if (ssl->options.oldPoly != 0) { /* get nonce */ - c32toa(ssl->keys.sequence_number, nonce + CHACHA20_OLD_OFFSET); + WriteSEQ(ssl, CUR_ORDER, nonce + CHACHA20_OLD_OFFSET); } /* opaque SEQ number stored for AD */ - c32toa(GetSEQIncrement(ssl, 0), add + AEAD_SEQ_OFFSET); + WriteSEQ(ssl, CUR_ORDER, add); /* Store the type, version. Unfortunately, they are in * the input buffer ahead of the plaintext. */ #ifdef WOLFSSL_DTLS if (ssl->options.dtls) { - c16toa(ssl->keys.dtls_epoch, add); additionalSrc -= DTLS_HANDSHAKE_EXTRA; + DtlsSEQIncrement(ssl, CUR_ORDER); } #endif @@ -8130,23 +8254,18 @@ static int ChachaAEADDecrypt(WOLFSSL* ssl, byte* plain, const byte* input, if (ssl->options.oldPoly != 0) { /* get nonce */ - c32toa(ssl->keys.peer_sequence_number, nonce + CHACHA20_OLD_OFFSET); + WriteSEQ(ssl, PEER_ORDER, nonce + CHACHA20_OLD_OFFSET); } - /* sequence number field is 64-bits, we only use 32-bits */ - c32toa(GetSEQIncrement(ssl, 1), add + AEAD_SEQ_OFFSET); + /* sequence number field is 64-bits */ + WriteSEQ(ssl, PEER_ORDER, add); /* get AD info */ + /* Store the type, version. */ add[AEAD_TYPE_OFFSET] = ssl->curRL.type; add[AEAD_VMAJ_OFFSET] = ssl->curRL.pvMajor; add[AEAD_VMIN_OFFSET] = ssl->curRL.pvMinor; - /* Store the type, version. */ - #ifdef WOLFSSL_DTLS - if (ssl->options.dtls) - c16toa(ssl->keys.dtls_state.curEpoch, add); - #endif - /* add TLS message size to additional data */ add[AEAD_AUTH_DATA_SZ - 2] = (msgLen >> 8) & 0xff; add[AEAD_AUTH_DATA_SZ - 1] = msgLen & 0xff; @@ -8281,15 +8400,13 @@ static INLINE int Encrypt(WOLFSSL* ssl, byte* out, const byte* input, word16 sz) XMEMSET(additional, 0, AEAD_AUTH_DATA_SZ); - /* sequence number field is 64-bits, we only use 32-bits */ - c32toa(GetSEQIncrement(ssl, 0), - additional + AEAD_SEQ_OFFSET); + /* sequence number field is 64-bits */ + WriteSEQ(ssl, CUR_ORDER, additional); /* Store the type, version. Unfortunately, they are in * the input buffer ahead of the plaintext. */ #ifdef WOLFSSL_DTLS if (ssl->options.dtls) { - c16toa(ssl->keys.dtls_epoch, additional); additionalSrc -= DTLS_HANDSHAKE_EXTRA; } #endif @@ -8312,6 +8429,10 @@ static INLINE int Encrypt(WOLFSSL* ssl, byte* out, const byte* input, word16 sz) additional, AEAD_AUTH_DATA_SZ); AeadIncrementExpIV(ssl); ForceZero(nonce, AESGCM_NONCE_SZ); + #ifdef WOLFSSL_DTLS + if (ssl->options.dtls) + DtlsSEQIncrement(ssl, CUR_ORDER); + #endif } break; #endif @@ -8326,15 +8447,13 @@ static INLINE int Encrypt(WOLFSSL* ssl, byte* out, const byte* input, word16 sz) XMEMSET(additional, 0, AEAD_AUTH_DATA_SZ); - /* sequence number field is 64-bits, we only use 32-bits */ - c32toa(GetSEQIncrement(ssl, 0), - additional + AEAD_SEQ_OFFSET); + /* sequence number field is 64-bits */ + WriteSEQ(ssl, CUR_ORDER, additional); /* Store the type, version. Unfortunately, they are in * the input buffer ahead of the plaintext. */ #ifdef WOLFSSL_DTLS if (ssl->options.dtls) { - c16toa(ssl->keys.dtls_epoch, additional); additionalSrc -= DTLS_HANDSHAKE_EXTRA; } #endif @@ -8357,6 +8476,10 @@ static INLINE int Encrypt(WOLFSSL* ssl, byte* out, const byte* input, word16 sz) additional, AEAD_AUTH_DATA_SZ); AeadIncrementExpIV(ssl); ForceZero(nonce, AESGCM_NONCE_SZ); + #ifdef WOLFSSL_DTLS + if (ssl->options.dtls) + DtlsSEQIncrement(ssl, CUR_ORDER); + #endif } break; #endif @@ -8450,13 +8573,8 @@ static INLINE int Decrypt(WOLFSSL* ssl, byte* plain, const byte* input, XMEMSET(additional, 0, AEAD_AUTH_DATA_SZ); - /* sequence number field is 64-bits, we only use 32-bits */ - c32toa(GetSEQIncrement(ssl, 1), additional + AEAD_SEQ_OFFSET); - - #ifdef WOLFSSL_DTLS - if (ssl->options.dtls) - c16toa(ssl->keys.dtls_state.curEpoch, additional); - #endif + /* sequence number field is 64-bits */ + WriteSEQ(ssl, PEER_ORDER, additional); additional[AEAD_TYPE_OFFSET] = ssl->curRL.type; additional[AEAD_VMAJ_OFFSET] = ssl->curRL.pvMajor; @@ -8492,13 +8610,8 @@ static INLINE int Decrypt(WOLFSSL* ssl, byte* plain, const byte* input, XMEMSET(additional, 0, AEAD_AUTH_DATA_SZ); - /* sequence number field is 64-bits, we only use 32-bits */ - c32toa(GetSEQIncrement(ssl, 1), additional + AEAD_SEQ_OFFSET); - - #ifdef WOLFSSL_DTLS - if (ssl->options.dtls) - c16toa(ssl->keys.dtls_state.curEpoch, additional); - #endif + /* sequence number field is 64-bits */ + WriteSEQ(ssl, PEER_ORDER, additional); additional[AEAD_TYPE_OFFSET] = ssl->curRL.type; additional[AEAD_VMAJ_OFFSET] = ssl->curRL.pvMajor; @@ -9343,7 +9456,7 @@ int ProcessReply(WOLFSSL* ssl) #ifdef WOLFSSL_DTLS if (IsDtlsNotSctpMode(ssl)) { - DtlsUpdateWindow(&ssl->keys.dtls_state); + DtlsUpdateWindow(ssl); } #endif /* WOLFSSL_DTLS */ @@ -9453,8 +9566,10 @@ int ProcessReply(WOLFSSL* ssl) #ifdef WOLFSSL_DTLS if (ssl->options.dtls) { DtlsPoolReset(ssl); - ssl->keys.dtls_state.nextEpoch++; - ssl->keys.dtls_state.nextSeq = 0; + ssl->keys.nextEpoch++; + ssl->keys.nextSeq_lo = 0; + ssl->keys.prevWindow = ssl->keys.window; + ssl->keys.window = 0; } #endif @@ -9645,7 +9760,7 @@ static int SSL_hmac(WOLFSSL* ssl, byte* digest, const byte* in, word32 sz, XMEMSET(seq, 0, SEQ_SZ); conLen[0] = (byte)content; c16toa((word16)sz, &conLen[ENUM_LEN]); - c32toa(GetSEQIncrement(ssl, verify), &seq[sizeof(word32)]); + WriteSEQ(ssl, verify, seq); if (ssl->specs.mac_algorithm == md5_mac) { wc_InitMd5(&md5); @@ -9966,6 +10081,10 @@ int BuildMessage(WOLFSSL* ssl, byte* output, int outSz, const byte* input, #endif ret = ssl->hmac(ssl, output+idx, output + headerSz + ivSz, inSz, type, 0); + #ifdef WOLFSSL_DTLS + if (ssl->options.dtls) + DtlsSEQIncrement(ssl, CUR_ORDER); + #endif } if (ret != 0) return ret; @@ -10003,9 +10122,12 @@ int SendFinished(WOLFSSL* ssl) if (ssl->options.dtls) { headerSz += DTLS_HANDSHAKE_EXTRA; ssl->keys.dtls_epoch++; - ssl->keys.dtls_prev_sequence_number = - ssl->keys.dtls_sequence_number; - ssl->keys.dtls_sequence_number = 0; + ssl->keys.dtls_prev_sequence_number_hi = + ssl->keys.dtls_sequence_number_hi; + ssl->keys.dtls_prev_sequence_number_lo = + ssl->keys.dtls_sequence_number_lo; + ssl->keys.dtls_sequence_number_hi = 0; + ssl->keys.dtls_sequence_number_lo = 0; } #endif @@ -10190,7 +10312,6 @@ int SendCertificate(WOLFSSL* ssl) HANDSHAKE_HEADER_SZ + DTLS_HANDSHAKE_EXTRA); /* Adding the headers increments these, decrement them for * actual message header. */ - ssl->keys.dtls_sequence_number--; ssl->keys.dtls_handshake_number--; AddFragHeaders(output, fragSz, 0, payloadSz, certificate, ssl); ssl->keys.dtls_handshake_number--; @@ -10277,6 +10398,12 @@ int SendCertificate(WOLFSSL* ssl) if (sendSz < 0) return sendSz; } + else { + #ifdef WOLFSSL_DTLS + if (ssl->options.dtls) + DtlsSEQIncrement(ssl, CUR_ORDER); + #endif + } #ifdef WOLFSSL_DTLS if (IsDtlsNotSctpMode(ssl)) { @@ -10302,8 +10429,8 @@ int SendCertificate(WOLFSSL* ssl) /* Clean up the fragment offset. */ ssl->fragOffset = 0; #ifdef WOLFSSL_DTLS - if (ssl->options.dtls) - ssl->keys.dtls_handshake_number++; + if (ssl->options.dtls) + ssl->keys.dtls_handshake_number++; #endif if (ssl->options.side == WOLFSSL_SERVER_END) ssl->options.serverState = SERVER_CERT_COMPLETE; @@ -10378,6 +10505,8 @@ int SendCertificateRequest(WOLFSSL* ssl) if ((ret = DtlsPoolSave(ssl, output, sendSz)) != 0) return ret; } + if (ssl->options.dtls) + DtlsSEQIncrement(ssl, CUR_ORDER); #endif ret = HashOutput(ssl, output, sendSz, 0); @@ -10469,8 +10598,13 @@ static int BuildCertificateStatus(WOLFSSL* ssl, byte type, buffer* status, if (sendSz < 0) ret = sendSz; } - else + else { + #ifdef WOLFSSL_DTLS + if (ssl->options.dtls) + DtlsSEQIncrement(ssl, CUR_ORDER); + #endif ret = HashOutput(ssl, output, sendSz, 0); + } #ifdef WOLFSSL_DTLS if (ret == 0 && IsDtlsNotSctpMode(ssl)) @@ -12902,6 +13036,10 @@ static void PickHashSigAlgo(WOLFSSL* ssl, if (sendSz < 0) return sendSz; } else { + #ifdef WOLFSSL_DTLS + if (ssl->options.dtls) + DtlsSEQIncrement(ssl, CUR_ORDER); + #endif ret = HashOutput(ssl, output, sendSz, 0); if (ret != 0) return ret; @@ -15423,6 +15561,10 @@ int SendClientKeyExchange(WOLFSSL* ssl) } } else { + #ifdef WOLFSSL_DTLS + if (ssl->options.dtls) + DtlsSEQIncrement(ssl, CUR_ORDER); + #endif ret = HashOutput(ssl, output, sendSz, 0); if (ret != 0) { goto exit_scke; @@ -15903,6 +16045,10 @@ int SendCertificateVerify(WOLFSSL* ssl) } } else { + #ifdef WOLFSSL_DTLS + if (ssl->options.dtls) + DtlsSEQIncrement(ssl, CUR_ORDER); + #endif ret = HashOutput(ssl, output, sendSz, 0); } @@ -16126,7 +16272,8 @@ int DoSessionTicket(WOLFSSL* ssl, const byte* input, word32* inOutIdx, if (ssl->options.dtls) { /* Server Hello should use the same sequence number as the * Client Hello. */ - ssl->keys.dtls_sequence_number = ssl->keys.dtls_state.curSeq; + ssl->keys.dtls_sequence_number_hi = ssl->keys.curSeq_hi; + ssl->keys.dtls_sequence_number_lo = ssl->keys.curSeq_lo; idx += DTLS_RECORD_EXTRA + DTLS_HANDSHAKE_EXTRA; sendSz += DTLS_RECORD_EXTRA + DTLS_HANDSHAKE_EXTRA; } @@ -16206,6 +16353,10 @@ int DoSessionTicket(WOLFSSL* ssl, const byte* input, word32* inOutIdx, if ((ret = DtlsPoolSave(ssl, output, sendSz)) != 0) return ret; } + + if (ssl->options.dtls) { + DtlsSEQIncrement(ssl, CUR_ORDER); + } #endif ret = HashOutput(ssl, output, sendSz, 0); @@ -17531,6 +17682,9 @@ int DoSessionTicket(WOLFSSL* ssl, const byte* input, word32* inOutIdx, goto exit_sske; } } + + if (ssl->options.dtls) + DtlsSEQIncrement(ssl, CUR_ORDER); #endif ret = HashOutput(ssl, output, sendSz, 0); @@ -18749,6 +18903,9 @@ int DoSessionTicket(WOLFSSL* ssl, const byte* input, word32* inOutIdx, if ((ret = DtlsPoolSave(ssl, output, sendSz)) != 0) return 0; } + + if (ssl->options.dtls) + DtlsSEQIncrement(ssl, CUR_ORDER); #endif ret = HashOutput(ssl, output, sendSz, 0); @@ -18967,6 +19124,8 @@ int DoSessionTicket(WOLFSSL* ssl, const byte* input, word32* inOutIdx, if (ssl->options.dtls) { if ((ret = DtlsPoolSave(ssl, output, sendSz)) != 0) return ret; + + DtlsSEQIncrement(ssl, CUR_ORDER); } #endif @@ -19000,7 +19159,8 @@ int DoSessionTicket(WOLFSSL* ssl, const byte* input, word32* inOutIdx, /* Hello Verify Request should use the same sequence number as the * Client Hello. */ - ssl->keys.dtls_sequence_number = ssl->keys.dtls_state.curSeq; + ssl->keys.dtls_sequence_number_hi = ssl->keys.curSeq_hi; + ssl->keys.dtls_sequence_number_lo = ssl->keys.curSeq_lo; AddHeaders(output, length, hello_verify_request, ssl); #ifdef OPENSSL_EXTRA diff --git a/src/keys.c b/src/keys.c index 50e888521..e630072a8 100644 --- a/src/keys.c +++ b/src/keys.c @@ -2628,10 +2628,14 @@ static int SetKeys(Ciphers* enc, Ciphers* dec, Keys* keys, CipherSpecs* specs, } #endif - if (enc) - keys->sequence_number = 0; - if (dec) - keys->peer_sequence_number = 0; + if (enc) { + keys->sequence_number_hi = 0; + keys->sequence_number_lo = 0; + } + if (dec) { + keys->peer_sequence_number_hi = 0; + keys->peer_sequence_number_lo = 0; + } (void)side; (void)heap; (void)enc; @@ -2747,7 +2751,8 @@ int SetKeysSide(WOLFSSL* ssl, enum encrypt_side side) keys->server_write_IV, MAX_WRITE_IV_SZ); } if (wc_encrypt) { - ssl->keys.sequence_number = keys->sequence_number; + ssl->keys.sequence_number_hi = keys->sequence_number_hi; + ssl->keys.sequence_number_lo = keys->sequence_number_lo; #ifdef HAVE_AEAD if (ssl->specs.cipher_type == aead) { /* Initialize the AES-GCM/CCM explicit IV to a zero. */ @@ -2766,7 +2771,8 @@ int SetKeysSide(WOLFSSL* ssl, enum encrypt_side side) #endif } if (wc_decrypt) { - ssl->keys.peer_sequence_number = keys->peer_sequence_number; + ssl->keys.peer_sequence_number_hi = keys->peer_sequence_number_hi; + ssl->keys.peer_sequence_number_lo = keys->peer_sequence_number_lo; #ifdef HAVE_AEAD if (ssl->specs.cipher_type == aead) { /* Initialize decrypt implicit IV by decrypt side */ diff --git a/src/tls.c b/src/tls.c index a6ad8dc2e..c81017e13 100644 --- a/src/tls.c +++ b/src/tls.c @@ -642,36 +642,68 @@ static INLINE void c32toa(word32 u32, byte* c) } -static INLINE word32 GetSEQIncrement(WOLFSSL* ssl, int verify) +static INLINE void GetSEQIncrement(WOLFSSL* ssl, int verify, word32 seq[2]) { -#ifdef WOLFSSL_DTLS - if (ssl->options.dtls) { - if (verify) - return ssl->keys.dtls_state.curSeq; /* explicit from peer */ - else - return ssl->keys.dtls_sequence_number - 1; /* already incremented */ + if (verify) { + seq[0] = ssl->keys.peer_sequence_number_hi; + seq[1] = ssl->keys.peer_sequence_number_lo++; + if (seq[1] > ssl->keys.peer_sequence_number_lo) { + /* handle rollover */ + ssl->keys.peer_sequence_number_hi++; + } + } + else { + seq[0] = ssl->keys.sequence_number_hi; + seq[1] = ssl->keys.sequence_number_lo++; + if (seq[1] > ssl->keys.sequence_number_lo) { + /* handle rollover */ + ssl->keys.sequence_number_hi++; + } } -#endif - if (verify) - return ssl->keys.peer_sequence_number++; - else - return ssl->keys.sequence_number++; } #ifdef WOLFSSL_DTLS - -static INLINE word32 GetEpoch(WOLFSSL* ssl, int verify) +static INLINE void DtlsGetSEQ(WOLFSSL* ssl, int order, word32 seq[2]) { - if (verify) - return ssl->keys.dtls_state.curEpoch; - else - return ssl->keys.dtls_epoch; + if (order == PREV_ORDER) { + /* Previous epoch case */ + seq[0] = ((ssl->keys.dtls_epoch - 1) << 16) | + (ssl->keys.dtls_prev_sequence_number_hi & 0xFFFF); + seq[1] = ssl->keys.dtls_prev_sequence_number_lo; + } + else if (order == PEER_ORDER) { + seq[0] = (ssl->keys.curEpoch << 16) | + (ssl->keys.curSeq_hi & 0xFFFF); + seq[1] = ssl->keys.curSeq_lo; /* explicit from peer */ + } + else { + seq[0] = (ssl->keys.dtls_epoch << 16) | + (ssl->keys.dtls_sequence_number_hi & 0xFFFF); + seq[1] = ssl->keys.dtls_sequence_number_lo; + } } - #endif /* WOLFSSL_DTLS */ +static INLINE void WriteSEQ(WOLFSSL* ssl, int verifyOrder, byte* out) +{ + word32 seq[2] = {0, 0}; + + if (!ssl->options.dtls) { + GetSEQIncrement(ssl, verifyOrder, seq); + } + else { +#ifdef WOLFSSL_DTLS + DtlsGetSEQ(ssl, verifyOrder, seq); +#endif + } + + c32toa(seq[0], out); + c32toa(seq[1], out + OPAQUE32_LEN); +} + + /*** end copy ***/ @@ -729,11 +761,7 @@ int wolfSSL_SetTlsHmacInner(WOLFSSL* ssl, byte* inner, word32 sz, int content, XMEMSET(inner, 0, WOLFSSL_TLS_HMAC_INNER_SZ); -#ifdef WOLFSSL_DTLS - if (ssl->options.dtls) - c16toa((word16)GetEpoch(ssl, verify), inner); -#endif - c32toa(GetSEQIncrement(ssl, verify), &inner[sizeof(word32)]); + WriteSEQ(ssl, verify, inner); inner[SEQ_SZ] = (byte)content; inner[SEQ_SZ + ENUM_LEN] = ssl->version.major; inner[SEQ_SZ + ENUM_LEN + ENUM_LEN] = ssl->version.minor; diff --git a/wolfssl/internal.h b/wolfssl/internal.h index a9cacda45..842ace428 100644 --- a/wolfssl/internal.h +++ b/wolfssl/internal.h @@ -994,7 +994,7 @@ enum Misc { CHACHA20_IMP_IV_SZ = 12, /* Size of ChaCha20 AEAD implicit IV */ CHACHA20_NONCE_SZ = 12, /* Size of ChacCha20 nonce */ - CHACHA20_OLD_OFFSET = 8, /* Offset for seq # in old poly1305 */ + CHACHA20_OLD_OFFSET = 4, /* Offset for seq # in old poly1305 */ /* For any new implicit/explicit IV size adjust AEAD_MAX_***_SZ */ @@ -1062,7 +1062,11 @@ enum Misc { HASH_SIG_SIZE = 2, /* default SHA1 RSA */ NO_COPY = 0, /* should we copy static buffer for write */ - COPY = 1 /* should we copy static buffer for write */ + COPY = 1, /* should we copy static buffer for write */ + + PREV_ORDER = -1, /* Sequence number is in previous epoch. */ + PEER_ORDER = 1, /* Peer sequence number for verify. */ + CUR_ORDER = 0 /* Current sequence number. */ }; @@ -1581,18 +1585,6 @@ typedef struct WOLFSSL_DTLS_CTX { #endif #define DTLS_SEQ_BITS (sizeof(DtlsSeq) * CHAR_BIT) - typedef struct DtlsState { - DtlsSeq window; /* Sliding window for current epoch */ - word16 nextEpoch; /* Expected epoch in next record */ - word32 nextSeq; /* Expected sequence in next record */ - - word16 curEpoch; /* Received epoch in current record */ - word32 curSeq; /* Received sequence in current record */ - - DtlsSeq prevWindow; /* Sliding window for old epoch */ - word32 prevSeq; /* Next sequence in allowed old epoch */ - } DtlsState; - #endif /* WOLFSSL_DTLS */ @@ -1613,17 +1605,33 @@ typedef struct Keys { byte aead_dec_imp_IV[AEAD_MAX_IMP_SZ]; #endif - word32 peer_sequence_number; - word32 sequence_number; + word32 peer_sequence_number_hi; + word32 peer_sequence_number_lo; + word32 sequence_number_hi; + word32 sequence_number_lo; #ifdef WOLFSSL_DTLS - DtlsState dtls_state; /* Peer's state */ + DtlsSeq window; /* Sliding window for current epoch */ + word16 nextEpoch; /* Expected epoch in next record */ + word16 nextSeq_hi; /* Expected sequence in next record */ + word32 nextSeq_lo; + + word16 curEpoch; /* Received epoch in current record */ + word16 curSeq_hi; /* Received sequence in current record */ + word32 curSeq_lo; + + DtlsSeq prevWindow; /* Sliding window for old epoch */ + word16 prevSeq_hi; /* Next sequence in allowed old epoch */ + word32 prevSeq_lo; + word16 dtls_peer_handshake_number; word16 dtls_expected_peer_handshake_number; - word32 dtls_sequence_number; /* Current tx sequence */ - word32 dtls_prev_sequence_number; /* Previous epoch's seq number*/ - word16 dtls_epoch; /* Current tx epoch */ + word16 dtls_epoch; /* Current epoch */ + word16 dtls_sequence_number_hi; /* Current epoch */ + word32 dtls_sequence_number_lo; + word16 dtls_prev_sequence_number_hi; /* Previous epoch */ + word32 dtls_prev_sequence_number_lo; word16 dtls_handshake_number; /* Current tx handshake seq */ #endif @@ -2571,8 +2579,7 @@ typedef struct DtlsRecordLayerHeader { byte type; byte pvMajor; byte pvMinor; - byte epoch[2]; /* increment on cipher state change */ - byte sequence_number[6]; /* per record */ + byte sequence_number[8]; /* per record */ byte length[2]; } DtlsRecordLayerHeader;