diff --git a/src/dtls.c b/src/dtls.c index 6babe3116..82be280c0 100644 --- a/src/dtls.c +++ b/src/dtls.c @@ -111,7 +111,7 @@ typedef struct WolfSSL_CH { WolfSSL_ConstVector cipherSuite; WolfSSL_ConstVector compression; WolfSSL_ConstVector extension; - const byte* msg; + const byte* raw; word32 length; /* Store the DTLS 1.2 cookie since we can just compute it once in dtls.c */ byte dtls12cookie[DTLS_COOKIE_SZ]; @@ -221,7 +221,7 @@ static int ParseClientHello(const byte* input, word32 helloSz, WolfSSL_CH* ch) if (OPAQUE16_LEN + RAN_LEN + OPAQUE8_LEN > helloSz) return BUFFER_ERROR; - ch->msg = input - DTLS_HANDSHAKE_HEADER_SZ; + ch->raw = input; ch->pv = (ProtocolVersion*)(input + idx); idx += OPAQUE16_LEN; ch->random = (byte*)(input + idx); @@ -241,7 +241,7 @@ static int ParseClientHello(const byte* input, word32 helloSz, WolfSSL_CH* ch) idx += ReadVector16(input + idx, &ch->extension); if (idx > helloSz) return BUFFER_ERROR; - ch->length = idx + DTLS_HANDSHAKE_HEADER_SZ; + ch->length = idx; return 0; } @@ -502,8 +502,8 @@ static int SendStatelessReply(const WOLFSSL* ssl, WolfSSL_CH* ch, byte isTls13, } /* Hashes are reset in SendTls13ServerHello when sending a HRR */ - ret = Dtls13HashHandshake((WOLFSSL*)ssl, ch->msg, - (word16)ch->length); + ret = Dtls13HashHandshakeType((WOLFSSL*)ssl, ch->raw, ch->length, + client_hello); if (ret != 0) goto dtls13_cleanup; diff --git a/src/dtls13.c b/src/dtls13.c index a6b63b1a2..91bddcf8f 100644 --- a/src/dtls13.c +++ b/src/dtls13.c @@ -435,8 +435,27 @@ static int Dtls13SendNow(WOLFSSL* ssl, enum HandShakeType handshakeType) return 0; } -/* Handshake header DTLS only fields are not included in the transcript hash */ -int Dtls13HashHandshake(WOLFSSL* ssl, const byte* output, word16 length) +/* Handshake header DTLS only fields are not inlcuded in the transcript hash. + * body points to the body of the DTLSHandshake message. */ +int Dtls13HashHandshakeType(WOLFSSL* ssl, const byte* body, word32 length, + enum HandShakeType handshakeType) +{ + /* msg_type(1) + length (3) */ + byte header[OPAQUE32_LEN]; + int ret; + + header[0] = (byte)handshakeType; + c32to24(length, header + 1); + + ret = HashRaw(ssl, header, OPAQUE32_LEN); + if (ret != 0) + return ret; + + return HashRaw(ssl, body, length); +} + +/* Handshake header DTLS only fields are not inlcuded in the transcript hash */ +int Dtls13HashHandshake(WOLFSSL* ssl, const byte* input, word16 length) { int ret; @@ -444,18 +463,18 @@ int Dtls13HashHandshake(WOLFSSL* ssl, const byte* output, word16 length) return BAD_FUNC_ARG; /* msg_type(1) + length (3) */ - ret = HashRaw(ssl, output, OPAQUE32_LEN); + ret = HashRaw(ssl, input, OPAQUE32_LEN); if (ret != 0) return ret; - output += OPAQUE32_LEN; + input += OPAQUE32_LEN; length -= OPAQUE32_LEN; /* message_seq(2) + fragment_offset(3) + fragment_length(3) */ - output += OPAQUE64_LEN; + input += OPAQUE64_LEN; length -= OPAQUE64_LEN; - return HashRaw(ssl, output, length); + return HashRaw(ssl, input, length); } static int Dtls13SendFragment(WOLFSSL* ssl, byte* output, word16 output_size, diff --git a/wolfssl/internal.h b/wolfssl/internal.h index 6eb3872e0..add0d2c6f 100644 --- a/wolfssl/internal.h +++ b/wolfssl/internal.h @@ -6126,8 +6126,10 @@ WOLFSSL_LOCAL int Dtls13ReconstructSeqNumber(WOLFSSL* ssl, WOLFSSL_LOCAL int SendDtls13Ack(WOLFSSL* ssl); WOLFSSL_LOCAL int Dtls13RtxProcessingCertificate(WOLFSSL* ssl, byte* input, word32 inputSize); -WOLFSSL_LOCAL int Dtls13HashHandshake(WOLFSSL* ssl, const byte* output, +WOLFSSL_LOCAL int Dtls13HashHandshake(WOLFSSL* ssl, const byte* input, word16 length); +WOLFSSL_LOCAL int Dtls13HashHandshakeType(WOLFSSL* ssl, const byte* body, + word32 length, enum HandShakeType handshakeType); WOLFSSL_LOCAL void Dtls13FreeFsmResources(WOLFSSL* ssl); WOLFSSL_LOCAL int Dtls13RtxTimeout(WOLFSSL* ssl); WOLFSSL_LOCAL int Dtls13ProcessBufferedMessages(WOLFSSL* ssl);