allow sniffer to handle bundled record layer messages

This commit is contained in:
toddouska
2014-12-03 11:58:50 -08:00
parent f1c6e901a4
commit 264e180147

View File

@ -1775,7 +1775,8 @@ static int DoHandShake(const byte* input, int* sslBytes,
byte type; byte type;
int size; int size;
int ret = 0; int ret = 0;
int startBytes;
if (*sslBytes < HANDSHAKE_HEADER_SZ) { if (*sslBytes < HANDSHAKE_HEADER_SZ) {
SetError(HANDSHAKE_INPUT_STR, error, session, FATAL_ERROR_STATE); SetError(HANDSHAKE_INPUT_STR, error, session, FATAL_ERROR_STATE);
return -1; return -1;
@ -1785,7 +1786,8 @@ static int DoHandShake(const byte* input, int* sslBytes,
input += HANDSHAKE_HEADER_SZ; input += HANDSHAKE_HEADER_SZ;
*sslBytes -= HANDSHAKE_HEADER_SZ; *sslBytes -= HANDSHAKE_HEADER_SZ;
startBytes = *sslBytes;
if (*sslBytes < size) { if (*sslBytes < size) {
SetError(HANDSHAKE_INPUT_STR, error, session, FATAL_ERROR_STATE); SetError(HANDSHAKE_INPUT_STR, error, session, FATAL_ERROR_STATE);
return -1; return -1;
@ -1839,7 +1841,9 @@ static int DoHandShake(const byte* input, int* sslBytes,
default: default:
SetError(GOT_UNKNOWN_HANDSHAKE_STR, error, session, 0); SetError(GOT_UNKNOWN_HANDSHAKE_STR, error, session, 0);
return -1; return -1;
} }
*sslBytes = startBytes - size; /* actual bytes of full process */
return ret; return ret;
} }
@ -2629,7 +2633,7 @@ static int ProcessMessage(const byte* sslFrame, SnifferSession* session,
int sslBytes, byte* data, const byte* end,char* error) int sslBytes, byte* data, const byte* end,char* error)
{ {
const byte* sslBegin = sslFrame; const byte* sslBegin = sslFrame;
const byte* tmp; const byte* recordEnd;
RecordLayerHeader rh; RecordLayerHeader rh;
int rhSize = 0; int rhSize = 0;
int ret; int ret;
@ -2674,7 +2678,7 @@ doMessage:
} }
sslFrame += RECORD_HEADER_SZ; sslFrame += RECORD_HEADER_SZ;
sslBytes -= RECORD_HEADER_SZ; sslBytes -= RECORD_HEADER_SZ;
tmp = sslFrame + rhSize; /* may have more than one record to process */ recordEnd = sslFrame + rhSize; /* may have more than one record */
/* decrypt if needed */ /* decrypt if needed */
if ((session->flags.side == CYASSL_SERVER_END && if ((session->flags.side == CYASSL_SERVER_END &&
@ -2696,15 +2700,27 @@ doMessage:
return -1; return -1;
} }
} }
doPart:
switch ((enum ContentType)rh.type) { switch ((enum ContentType)rh.type) {
case handshake: case handshake:
Trace(GOT_HANDSHAKE_STR); {
ret = DoHandShake(sslFrame, &sslBytes, session, error); int startIdx = sslBytes;
if (ret != 0) { int used;
if (session->flags.fatalError == 0)
SetError(BAD_HANDSHAKE_STR,error,session,FATAL_ERROR_STATE); Trace(GOT_HANDSHAKE_STR);
return -1; ret = DoHandShake(sslFrame, &sslBytes, session, error);
if (ret != 0) {
if (session->flags.fatalError == 0)
SetError(BAD_HANDSHAKE_STR, error, session,
FATAL_ERROR_STATE);
return -1;
}
/* DoHandShake now fully decrements sslBytes to remaining */
used = startIdx - sslBytes;
sslFrame += used;
} }
break; break;
case change_cipher_spec: case change_cipher_spec:
@ -2715,6 +2731,10 @@ doMessage:
Trace(GOT_CHANGE_CIPHER_STR); Trace(GOT_CHANGE_CIPHER_STR);
ssl->options.handShakeState = HANDSHAKE_DONE; ssl->options.handShakeState = HANDSHAKE_DONE;
ssl->options.handShakeDone = 1; ssl->options.handShakeDone = 1;
sslFrame += 1;
sslBytes -= 1;
break; break;
case application_data: case application_data:
Trace(GOT_APP_DATA_STR); Trace(GOT_APP_DATA_STR);
@ -2739,21 +2759,33 @@ doMessage:
} }
if (ssl->buffers.outputBuffer.dynamicFlag) if (ssl->buffers.outputBuffer.dynamicFlag)
ShrinkOutputBuffer(ssl); ShrinkOutputBuffer(ssl);
sslFrame += inOutIdx;
sslBytes -= inOutIdx;
} }
break; break;
case alert: case alert:
Trace(GOT_ALERT_STR); Trace(GOT_ALERT_STR);
sslFrame += rhSize;
sslBytes -= rhSize;
break; break;
case no_type: case no_type:
default: default:
SetError(GOT_UNKNOWN_RECORD_STR, error, session, FATAL_ERROR_STATE); SetError(GOT_UNKNOWN_RECORD_STR, error, session, FATAL_ERROR_STATE);
return -1; return -1;
} }
if (tmp < end) { /* do we have another msg in record ? */
if (sslFrame < recordEnd) {
Trace(ANOTHER_MSG_STR); Trace(ANOTHER_MSG_STR);
sslFrame = tmp; goto doPart;
sslBytes = (int)(end - tmp); }
/* do we have more records ? */
if (recordEnd < end) {
Trace(ANOTHER_MSG_STR);
sslFrame = recordEnd;
sslBytes = (int)(end - recordEnd);
goto doMessage; goto doMessage;
} }