diff --git a/src/dtls13.c b/src/dtls13.c index bd76ffe52..aeaaa3285 100644 --- a/src/dtls13.c +++ b/src/dtls13.c @@ -1190,6 +1190,26 @@ int Dtls13ReconstructEpochNumber(WOLFSSL* ssl, byte epochBits, return SEQUENCE_ERROR; } +int Dtls13GetUnifiedHeaderSize(const byte input, word16* size) +{ + if (size == NULL) + return BAD_FUNC_ARG; + + if (input & DTLS13_CID_BIT) { + WOLFSSL_MSG("DTLS1.3 header with connection ID. Not supported"); + return WOLFSSL_NOT_IMPLEMENTED; + } + + /* flags (1) + seq 8bit (1) */ + *size = OPAQUE8_LEN + OPAQUE8_LEN; + if (input & DTLS13_SEQ_LEN_BIT) + *size += OPAQUE8_LEN; + if (input & DTLS13_LEN_BIT) + *size += OPAQUE16_LEN; + + return 0; +} + /** * Dtls13ParseUnifiedRecordLayer() - parse DTLS unified header * @ssl: [in] ssl object @@ -1236,10 +1256,6 @@ int Dtls13ParseUnifiedRecordLayer(WOLFSSL* ssl, const byte* input, ato16(input + idx, &hdrInfo->recordLength); idx += DTLS13_LEN_SIZE; - - /* DTLS message must fit inside a datagram */ - if (inputSize < idx + hdrInfo->recordLength) - return LENGTH_ERROR; } else { /* length not present. The size of the record is the all the remaining @@ -1259,8 +1275,6 @@ int Dtls13ParseUnifiedRecordLayer(WOLFSSL* ssl, const byte* input, if (ret != 0) return ret; - hdrInfo->headerLength = idx; - if (seqLen == DTLS13_SEQ_16_LEN) { hdrInfo->seqHiPresent = 1; hdrInfo->seqHi = seqNum[0]; diff --git a/src/internal.c b/src/internal.c index 57c4fffc9..dd2af5d13 100644 --- a/src/internal.c +++ b/src/internal.c @@ -9675,6 +9675,7 @@ int CheckAvailableSize(WOLFSSL *ssl, int size) } #ifdef WOLFSSL_DTLS13 +static int GetInputData(WOLFSSL *ssl, word32 size); static int GetDtls13RecordHeader(WOLFSSL* ssl, const byte* input, word32* inOutIdx, RecordLayerHeader* rh, word16* size) { @@ -9687,6 +9688,9 @@ static int GetDtls13RecordHeader(WOLFSSL* ssl, const byte* input, readSize = ssl->buffers.inputBuffer.length - *inOutIdx; + if (readSize < DTLS_UNIFIED_HEADER_MIN_SZ) + return BUFFER_ERROR; + epochBits = *input & EE_MASK; ret = Dtls13ReconstructEpochNumber(ssl, epochBits, &epochNumber); if (ret != 0) @@ -9718,6 +9722,20 @@ static int GetDtls13RecordHeader(WOLFSSL* ssl, const byte* input, return SEQUENCE_ERROR; } + ret = Dtls13GetUnifiedHeaderSize( + *(input+*inOutIdx), &ssl->dtls13CurRlLength); + if (ret != 0) + return ret; + + if (readSize < ssl->dtls13CurRlLength) { + /* when using DTLS over a medium that does not guarantee that a full + * message is received in a single read, we may end up without the full + * header */ + ret = GetInputData(ssl, ssl->dtls13CurRlLength - readSize); + if (ret != 0) + return ret; + } + ret = Dtls13ParseUnifiedRecordLayer(ssl, input + *inOutIdx, readSize, &hdrInfo); @@ -9745,8 +9763,8 @@ static int GetDtls13RecordHeader(WOLFSSL* ssl, const byte* input, ssl->keys.curSeq); #endif /* WOLFSSL_DEBUG_TLS */ - *inOutIdx += hdrInfo.headerLength; - ssl->dtls13CurRlLength = hdrInfo.headerLength; + XMEMCPY(ssl->dtls13CurRL, input + *inOutIdx, ssl->dtls13CurRlLength); + *inOutIdx += ssl->dtls13CurRlLength; return 0; } @@ -9793,10 +9811,12 @@ static int GetDtlsRecordHeader(WOLFSSL* ssl, const byte* input, } /* not a unified header, check that we have at least - DTLS_RECORD_HEADER_SZ */ - if (read_size < DTLS_RECORD_HEADER_SZ) - return LENGTH_ERROR; - + * DTLS_RECORD_HEADER_SZ */ + if (read_size < DTLS_RECORD_HEADER_SZ) { + ret = GetInputData(ssl, DTLS_RECORD_HEADER_SZ - read_size); + if (ret != 0) + return LENGTH_ERROR; + } #endif /* WOLFSSL_DTLS13 */ /* type and version in same spot */ @@ -18466,8 +18486,7 @@ int ProcessReplyEx(WOLFSSL* ssl, int allowSocketErr) #ifdef WOLFSSL_DTLS13 if (ssl->options.dtls) { /* aad now points to the record header */ - aad = in->buffer + - in->idx - ssl->dtls13CurRlLength; + aad = ssl->dtls13CurRL; aad_size = ssl->dtls13CurRlLength; } #endif /* WOLFSSL_DTLS13 */ diff --git a/wolfssl/internal.h b/wolfssl/internal.h index 2a1a5f811..800352fd5 100644 --- a/wolfssl/internal.h +++ b/wolfssl/internal.h @@ -1314,6 +1314,7 @@ enum Misc { DTLS_HANDSHAKE_HEADER_SZ = 12, /* normal + seq(2) + offset(3) + length(3) */ DTLS_RECORD_HEADER_SZ = 13, /* normal + epoch(2) + seq_num(6) */ DTLS_UNIFIED_HEADER_MIN_SZ = 2, + DTLS_RECVD_RL_HEADER_MAX_SZ = 5, /* flags + seq_number(2) + length(20) */ DTLS_RECORD_HEADER_MAX_SZ = 13, DTLS_HANDSHAKE_EXTRA = 8, /* diff from normal */ DTLS_RECORD_EXTRA = 8, /* diff from normal */ @@ -4368,7 +4369,6 @@ typedef enum EarlyDataState { typedef struct Dtls13UnifiedHdrInfo { word16 recordLength; - word16 headerLength; byte seqLo; byte seqHi; byte seqHiPresent:1; @@ -4658,7 +4658,7 @@ struct WOLFSSL { Dtls13Epoch *dtls13DecryptEpoch; w64wrapper dtls13Epoch; w64wrapper dtls13PeerEpoch; - + byte dtls13CurRL[DTLS_RECVD_RL_HEADER_MAX_SZ]; word16 dtls13CurRlLength; /* used to store the message if it needs to be fragmented */ @@ -5453,6 +5453,7 @@ WOLFSSL_LOCAL int Dtls13RlAddPlaintextHeader(WOLFSSL* ssl, byte* out, WOLFSSL_LOCAL int Dtls13EncryptRecordNumber(WOLFSSL* ssl, byte* hdr, word16 recordLength); WOLFSSL_LOCAL int Dtls13IsUnifiedHeader(byte header_flags); +WOLFSSL_LOCAL int Dtls13GetUnifiedHeaderSize(const byte input, word16* size); WOLFSSL_LOCAL int Dtls13ParseUnifiedRecordLayer(WOLFSSL* ssl, const byte* input, word16 input_size, Dtls13UnifiedHdrInfo* hdrInfo); WOLFSSL_LOCAL int Dtls13HandshakeSend(WOLFSSL* ssl, byte* output,