diff --git a/src/internal.c b/src/internal.c index 11f4df2a6..28a461652 100644 --- a/src/internal.c +++ b/src/internal.c @@ -1551,17 +1551,36 @@ static int GetHandShakeHeader(CYASSL* ssl, const byte* input, word32* inOutIdx, (void)ssl; *inOutIdx += HANDSHAKE_HEADER_SZ; -#ifdef CYASSL_DTLS - if (ssl->options.dtls) - *inOutIdx += DTLS_HANDSHAKE_EXTRA; -#endif - *type = ptr[0]; c24to32(&ptr[1], size); return 0; } +#ifdef CYASSL_DTLS +static int GetDtlsHandShakeHeader(CYASSL* ssl, const byte* input, + word32* inOutIdx, byte *type, word32 *size, + word32 *fragOffset, word32 *fragSz) +{ + word32 seq; + word32 idx = *inOutIdx; + + (void)ssl; + *inOutIdx += HANDSHAKE_HEADER_SZ + DTLS_HANDSHAKE_EXTRA; + + *type = input[idx++]; + c24to32(input, size); + idx += BYTE3_LEN; + + c24to32(input, fragOffset); + idx += BYTE3_LEN; + c24to32(input, fragSz); + idx += BYTE3_LEN; + + return 0; +} +#endif + /* fill with MD5 pad size since biggest required */ static const byte PAD1[PAD_MD5] = @@ -2088,21 +2107,11 @@ int DoFinished(CYASSL* ssl, const byte* input, word32* inOutIdx, int sniff) } -static int DoHandShakeMsg(CYASSL* ssl, byte* input, word32* inOutIdx, - word32 totalSz) +static int DoHandShakeMsgType(CYASSL* ssl, byte* input, word32* inOutIdx, + byte type, word32 size, word32 totalSz) { - byte type; - word32 size; int ret = 0; - CYASSL_ENTER("DoHandShakeMsg()"); - - if (GetHandShakeHeader(ssl, input, inOutIdx, &type, &size) != 0) - return PARSE_ERROR; - - if (*inOutIdx + size > totalSz) - return INCOMPLETE_DATA; - HashInput(ssl, input + *inOutIdx, size); #ifdef CYASSL_CALLBACKS /* add name later, add on record and handshake header part back on */ @@ -2187,6 +2196,54 @@ static int DoHandShakeMsg(CYASSL* ssl, byte* input, word32* inOutIdx, ret = UNKNOWN_HANDSHAKE_TYPE; } + return ret; +} + + +#ifdef CYASSL_DTLS +static int DoDtlsHandShakeMsg(CYASSL* ssl, byte* input, word32* inOutIdx, + word32 totalSz) +{ + byte type; + word32 size; + word32 fragOffset, fragSz; + int ret = 0; + + CYASSL_ENTER("DoDtlsHandShakeMsg()"); + if (GetDtlsHandShakeHeader(ssl, input, inOutIdx, &type, + &size, &fragOffset, &fragSz) != 0) + return PARSE_ERROR; + + if (*inOutIdx + size > totalSz) + return INCOMPLETE_DATA; + + /* XXX if fragmented, knit back together. */ + ret = DoHandShakeMsgType(ssl, input, inOutIdx, type, size, totalSz); + + CYASSL_LEAVE("DoDtlsHandShakeMsg()", ret); + return ret; +} +#endif + + +static int DoHandShakeMsg(CYASSL* ssl, byte* input, word32* inOutIdx, + word32 totalSz) +{ + byte type; + word32 size; + word32 fragOffset, fragSz; + int ret = 0; + + CYASSL_ENTER("DoHandShakeMsg()"); + + if (GetHandShakeHeader(ssl, input, inOutIdx, &type, &size) != 0) + return PARSE_ERROR; + + if (*inOutIdx + size > totalSz) + return INCOMPLETE_DATA; + + ret = DoHandShakeMsgType(ssl, input, inOutIdx, type, size, totalSz); + CYASSL_LEAVE("DoHandShakeMsg()", ret); return ret; } @@ -2703,11 +2760,21 @@ int ProcessReply(CYASSL* ssl) switch (ssl->curRL.type) { case handshake : /* debugging in DoHandShakeMsg */ - if ((ret = DoHandShakeMsg(ssl, - ssl->buffers.inputBuffer.buffer, - &ssl->buffers.inputBuffer.idx, - ssl->buffers.inputBuffer.length)) - != 0) + if (!ssl->options.dtls) { + ret = DoHandShakeMsg(ssl, + ssl->buffers.inputBuffer.buffer, + &ssl->buffers.inputBuffer.idx, + ssl->buffers.inputBuffer.length); + } + else { +#ifdef CYASSL_DTLS + ret = DoDtlsHandShakeMsg(ssl, + ssl->buffers.inputBuffer.buffer, + &ssl->buffers.inputBuffer.idx, + ssl->buffers.inputBuffer.length); +#endif + } + if (ret != 0) return ret; break;