diff --git a/src/internal.c b/src/internal.c index e1c9c6130..62e88aab3 100644 --- a/src/internal.c +++ b/src/internal.c @@ -3107,11 +3107,14 @@ static int GetRecordHeader(CYASSL* ssl, const byte* input, word32* inOutIdx, static int GetHandShakeHeader(CYASSL* ssl, const byte* input, word32* inOutIdx, - byte *type, word32 *size) + byte *type, word32 *size, word32 totalSz) { const byte *ptr = input + *inOutIdx; (void)ssl; + *inOutIdx += HANDSHAKE_HEADER_SZ; + if (*inOutIdx > totalSz) + return BUFFER_E; *type = ptr[0]; c24to32(&ptr[1], size); @@ -3122,12 +3125,15 @@ static int GetHandShakeHeader(CYASSL* ssl, const byte* input, word32* inOutIdx, #ifdef CYASSL_DTLS static int GetDtlsHandShakeHeader(CYASSL* ssl, const byte* input, - word32* inOutIdx, byte *type, word32 *size, - word32 *fragOffset, word32 *fragSz) + word32* inOutIdx, byte *type, word32 *size, + word32 *fragOffset, word32 *fragSz, + word32 totalSz) { word32 idx = *inOutIdx; *inOutIdx += HANDSHAKE_HEADER_SZ + DTLS_HANDSHAKE_EXTRA; + if (*inOutIdx > totalSz) + return BUFFER_E; *type = input[idx++]; c24to32(input + idx, size); @@ -5073,7 +5079,7 @@ static int DoHandShakeMsg(CYASSL* ssl, byte* input, word32* inOutIdx, CYASSL_ENTER("DoHandShakeMsg()"); - if (GetHandShakeHeader(ssl, input, inOutIdx, &type, &size) != 0) + if (GetHandShakeHeader(ssl, input, inOutIdx, &type, &size, totalSz) != 0) return PARSE_ERROR; ret = DoHandShakeMsgType(ssl, input, inOutIdx, type, size, totalSz); @@ -5181,7 +5187,7 @@ static int DoDtlsHandShakeMsg(CYASSL* ssl, byte* input, word32* inOutIdx, CYASSL_ENTER("DoDtlsHandShakeMsg()"); if (GetDtlsHandShakeHeader(ssl, input, inOutIdx, &type, - &size, &fragOffset, &fragSz) != 0) + &size, &fragOffset, &fragSz, totalSz) != 0) return PARSE_ERROR; if (*inOutIdx + fragSz > totalSz) @@ -9019,7 +9025,8 @@ static void PickHashSigAlgo(CYASSL* ssl, ssl->suites->sigAlgo = ssl->specs.sig_algo; ssl->suites->hashAlgo = sha_mac; - for (i = 0; i < hashSigAlgoSz; i += 2) { + /* i+1 since peek a byte ahead for type */ + for (i = 0; (i+1) < hashSigAlgoSz; i += 2) { if (hashSigAlgo[i+1] == ssl->specs.sig_algo) { if (hashSigAlgo[i] == sha_mac) { break;