dtls13: support stream-based medium

Don't assume that the underlying medium of DTLS provides the full message in a
single operation. This is usually true for message-based socket (eg. using UDP)
and false for stream-based socket (eg. using TCP).

Commit changes:

- Do not error out if we don't have the full message while parsing the header.
- Do not assume that the record header is still in the buffer when decrypting
  the message.
- Try to get more data if we didn't read the full DTLS header.
This commit is contained in:
Marco Oliverio
2022-07-19 14:56:52 +02:00
parent 9a3efb67b8
commit 6711756b03
3 changed files with 50 additions and 16 deletions

View File

@@ -1190,6 +1190,26 @@ int Dtls13ReconstructEpochNumber(WOLFSSL* ssl, byte epochBits,
return SEQUENCE_ERROR;
}
int Dtls13GetUnifiedHeaderSize(const byte input, word16* size)
{
if (size == NULL)
return BAD_FUNC_ARG;
if (input & DTLS13_CID_BIT) {
WOLFSSL_MSG("DTLS1.3 header with connection ID. Not supported");
return WOLFSSL_NOT_IMPLEMENTED;
}
/* flags (1) + seq 8bit (1) */
*size = OPAQUE8_LEN + OPAQUE8_LEN;
if (input & DTLS13_SEQ_LEN_BIT)
*size += OPAQUE8_LEN;
if (input & DTLS13_LEN_BIT)
*size += OPAQUE16_LEN;
return 0;
}
/**
* Dtls13ParseUnifiedRecordLayer() - parse DTLS unified header
* @ssl: [in] ssl object
@@ -1236,10 +1256,6 @@ int Dtls13ParseUnifiedRecordLayer(WOLFSSL* ssl, const byte* input,
ato16(input + idx, &hdrInfo->recordLength);
idx += DTLS13_LEN_SIZE;
/* DTLS message must fit inside a datagram */
if (inputSize < idx + hdrInfo->recordLength)
return LENGTH_ERROR;
}
else {
/* length not present. The size of the record is the all the remaining
@@ -1259,8 +1275,6 @@ int Dtls13ParseUnifiedRecordLayer(WOLFSSL* ssl, const byte* input,
if (ret != 0)
return ret;
hdrInfo->headerLength = idx;
if (seqLen == DTLS13_SEQ_16_LEN) {
hdrInfo->seqHiPresent = 1;
hdrInfo->seqHi = seqNum[0];

View File

@@ -9675,6 +9675,7 @@ int CheckAvailableSize(WOLFSSL *ssl, int size)
}
#ifdef WOLFSSL_DTLS13
static int GetInputData(WOLFSSL *ssl, word32 size);
static int GetDtls13RecordHeader(WOLFSSL* ssl, const byte* input,
word32* inOutIdx, RecordLayerHeader* rh, word16* size)
{
@@ -9687,6 +9688,9 @@ static int GetDtls13RecordHeader(WOLFSSL* ssl, const byte* input,
readSize = ssl->buffers.inputBuffer.length - *inOutIdx;
if (readSize < DTLS_UNIFIED_HEADER_MIN_SZ)
return BUFFER_ERROR;
epochBits = *input & EE_MASK;
ret = Dtls13ReconstructEpochNumber(ssl, epochBits, &epochNumber);
if (ret != 0)
@@ -9718,6 +9722,20 @@ static int GetDtls13RecordHeader(WOLFSSL* ssl, const byte* input,
return SEQUENCE_ERROR;
}
ret = Dtls13GetUnifiedHeaderSize(
*(input+*inOutIdx), &ssl->dtls13CurRlLength);
if (ret != 0)
return ret;
if (readSize < ssl->dtls13CurRlLength) {
/* when using DTLS over a medium that does not guarantee that a full
* message is received in a single read, we may end up without the full
* header */
ret = GetInputData(ssl, ssl->dtls13CurRlLength - readSize);
if (ret != 0)
return ret;
}
ret = Dtls13ParseUnifiedRecordLayer(ssl, input + *inOutIdx, readSize,
&hdrInfo);
@@ -9745,8 +9763,8 @@ static int GetDtls13RecordHeader(WOLFSSL* ssl, const byte* input,
ssl->keys.curSeq);
#endif /* WOLFSSL_DEBUG_TLS */
*inOutIdx += hdrInfo.headerLength;
ssl->dtls13CurRlLength = hdrInfo.headerLength;
XMEMCPY(ssl->dtls13CurRL, input + *inOutIdx, ssl->dtls13CurRlLength);
*inOutIdx += ssl->dtls13CurRlLength;
return 0;
}
@@ -9793,10 +9811,12 @@ static int GetDtlsRecordHeader(WOLFSSL* ssl, const byte* input,
}
/* not a unified header, check that we have at least
DTLS_RECORD_HEADER_SZ */
if (read_size < DTLS_RECORD_HEADER_SZ)
return LENGTH_ERROR;
* DTLS_RECORD_HEADER_SZ */
if (read_size < DTLS_RECORD_HEADER_SZ) {
ret = GetInputData(ssl, DTLS_RECORD_HEADER_SZ - read_size);
if (ret != 0)
return LENGTH_ERROR;
}
#endif /* WOLFSSL_DTLS13 */
/* type and version in same spot */
@@ -18466,8 +18486,7 @@ int ProcessReplyEx(WOLFSSL* ssl, int allowSocketErr)
#ifdef WOLFSSL_DTLS13
if (ssl->options.dtls) {
/* aad now points to the record header */
aad = in->buffer +
in->idx - ssl->dtls13CurRlLength;
aad = ssl->dtls13CurRL;
aad_size = ssl->dtls13CurRlLength;
}
#endif /* WOLFSSL_DTLS13 */

View File

@@ -1314,6 +1314,7 @@ enum Misc {
DTLS_HANDSHAKE_HEADER_SZ = 12, /* normal + seq(2) + offset(3) + length(3) */
DTLS_RECORD_HEADER_SZ = 13, /* normal + epoch(2) + seq_num(6) */
DTLS_UNIFIED_HEADER_MIN_SZ = 2,
DTLS_RECVD_RL_HEADER_MAX_SZ = 5, /* flags + seq_number(2) + length(20) */
DTLS_RECORD_HEADER_MAX_SZ = 13,
DTLS_HANDSHAKE_EXTRA = 8, /* diff from normal */
DTLS_RECORD_EXTRA = 8, /* diff from normal */
@@ -4368,7 +4369,6 @@ typedef enum EarlyDataState {
typedef struct Dtls13UnifiedHdrInfo {
word16 recordLength;
word16 headerLength;
byte seqLo;
byte seqHi;
byte seqHiPresent:1;
@@ -4658,7 +4658,7 @@ struct WOLFSSL {
Dtls13Epoch *dtls13DecryptEpoch;
w64wrapper dtls13Epoch;
w64wrapper dtls13PeerEpoch;
byte dtls13CurRL[DTLS_RECVD_RL_HEADER_MAX_SZ];
word16 dtls13CurRlLength;
/* used to store the message if it needs to be fragmented */
@@ -5453,6 +5453,7 @@ WOLFSSL_LOCAL int Dtls13RlAddPlaintextHeader(WOLFSSL* ssl, byte* out,
WOLFSSL_LOCAL int Dtls13EncryptRecordNumber(WOLFSSL* ssl, byte* hdr,
word16 recordLength);
WOLFSSL_LOCAL int Dtls13IsUnifiedHeader(byte header_flags);
WOLFSSL_LOCAL int Dtls13GetUnifiedHeaderSize(const byte input, word16* size);
WOLFSSL_LOCAL int Dtls13ParseUnifiedRecordLayer(WOLFSSL* ssl, const byte* input,
word16 input_size, Dtls13UnifiedHdrInfo* hdrInfo);
WOLFSSL_LOCAL int Dtls13HandshakeSend(WOLFSSL* ssl, byte* output,