diff --git a/src/internal.c b/src/internal.c index 1d19e9b36..a51ba9528 100644 --- a/src/internal.c +++ b/src/internal.c @@ -182,6 +182,20 @@ int IsAtLeastTLSv1_2(const WOLFSSL* ssl) } +static INLINE int IsEncryptionOn(WOLFSSL* ssl, int isSend) +{ + (void)isSend; + + #ifdef WOLFSSL_DTLS + /* For DTLS, epoch 0 is always not encrypted. */ + if (ssl->options.dtls && !isSend && ssl->keys.dtls_state.curEpoch == 0) + return 0; + #endif /* WOLFSSL_DTLS */ + + return ssl->keys.encryptionOn; +} + + #ifdef HAVE_QSH /* free all structs that where used with QSH */ static int QSH_FreeAll(WOLFSSL* ssl) @@ -3412,9 +3426,6 @@ static int GetRecordHeader(WOLFSSL* ssl, const byte* input, word32* inOutIdx, return UNKNOWN_RECORD_TYPE; } - /* haven't decrypted this record yet */ - ssl->keys.decryptedCur = 0; - return 0; } @@ -4546,7 +4557,7 @@ static int DoCertificate(WOLFSSL* ssl, byte* input, word32* inOutIdx, if (fatal == 0 && ssl->secure_renegotiation && ssl->secure_renegotiation->enabled) { - if (ssl->keys.encryptionOn) { + if (IsEncryptionOn(ssl, 0)) { /* compare against previous time */ if (XMEMCMP(dCert->subjectHash, ssl->secure_renegotiation->subject_hash, @@ -4895,7 +4906,7 @@ static int DoCertificate(WOLFSSL* ssl, byte* input, word32* inOutIdx, if (ret == 0 && ssl->options.side == WOLFSSL_CLIENT_END) ssl->options.serverState = SERVER_CERT_COMPLETE; - if (ssl->keys.encryptionOn) { + if (IsEncryptionOn(ssl, 0)) { *inOutIdx += ssl->keys.padSz; } @@ -5007,7 +5018,7 @@ static int DoHelloRequest(WOLFSSL* ssl, const byte* input, word32* inOutIdx, if (size) /* must be 0 */ return BUFFER_ERROR; - if (ssl->keys.encryptionOn) { + if (IsEncryptionOn(ssl, 0)) { /* access beyond input + size should be checked against totalSz */ if (*inOutIdx + ssl->keys.padSz > totalSz) return BUFFER_E; @@ -5493,7 +5504,7 @@ static int DoHandShakeMsgType(WOLFSSL* ssl, byte* input, word32* inOutIdx, AddLateName("ServerHelloDone", &ssl->timeoutInfo); #endif ssl->options.serverState = SERVER_HELLODONE_COMPLETE; - if (ssl->keys.encryptionOn) { + if (IsEncryptionOn(ssl, 0)) { *inOutIdx += ssl->keys.padSz; } if (ssl->options.resuming) { @@ -6854,7 +6865,7 @@ static int DoAlert(WOLFSSL* ssl, byte* input, word32* inOutIdx, int* type, ssl->options.closeNotify = 1; } WOLFSSL_ERROR(*type); - if (ssl->keys.encryptionOn) { + if (IsEncryptionOn(ssl, 0)) { if (*inOutIdx + ssl->keys.padSz > totalSz) return BUFFER_E; *inOutIdx += ssl->keys.padSz; @@ -7162,13 +7173,7 @@ int ProcessReply(WOLFSSL* ssl) /* the record layer is here */ case runProcessingOneMessage: - #ifdef WOLFSSL_DTLS - if (ssl->options.dtls && - ssl->keys.dtls_state.curEpoch < ssl->keys.dtls_state.nextEpoch) - ssl->keys.decryptedCur = 1; - #endif - - if (ssl->keys.encryptionOn && ssl->keys.decryptedCur == 0) + if (IsEncryptionOn(ssl, 0)) { ret = SanityCheckCipherText(ssl, ssl->curSize); if (ret < 0) @@ -7220,7 +7225,6 @@ int ProcessReply(WOLFSSL* ssl) return DECRYPT_ERROR; } ssl->keys.encryptSz = ssl->curSize; - ssl->keys.decryptedCur = 1; } if (ssl->options.dtls) { @@ -7295,7 +7299,7 @@ int ProcessReply(WOLFSSL* ssl) } #endif - if (ssl->keys.encryptionOn && ssl->options.handShakeDone) { + if (IsEncryptionOn(ssl, 0) && ssl->options.handShakeDone) { ssl->buffers.inputBuffer.idx += ssl->keys.padSz; ssl->curSize -= (word16) ssl->buffers.inputBuffer.idx; } @@ -7394,7 +7398,7 @@ int ProcessReply(WOLFSSL* ssl) #endif ssl->options.processReply = runProcessingOneMessage; - if (ssl->keys.encryptionOn) { + if (IsEncryptionOn(ssl, 0)) { WOLFSSL_MSG("Bundled encrypted messages, remove middle pad"); ssl->buffers.inputBuffer.idx -= ssl->keys.padSz; } @@ -7431,7 +7435,7 @@ int SendChangeCipher(WOLFSSL* ssl) #endif /* are we in scr */ - if (ssl->keys.encryptionOn && ssl->options.handShakeDone) { + if (IsEncryptionOn(ssl, 1) && ssl->options.handShakeDone) { sendSz += MAX_MSG_EXTRA; } @@ -7447,7 +7451,7 @@ int SendChangeCipher(WOLFSSL* ssl) output[idx] = 1; /* turn it on */ - if (ssl->keys.encryptionOn && ssl->options.handShakeDone) { + if (IsEncryptionOn(ssl, 1) && ssl->options.handShakeDone) { byte input[ENUM_LEN]; int inputSz = ENUM_LEN; @@ -8018,7 +8022,7 @@ int SendCertificate(WOLFSSL* ssl) sendSz += fragSz; } - if (ssl->keys.encryptionOn) + if (IsEncryptionOn(ssl, 1)) sendSz += MAX_MSG_EXTRA; } else { @@ -8042,14 +8046,14 @@ int SendCertificate(WOLFSSL* ssl) if (ssl->fragOffset == 0) { if (!ssl->options.dtls) { AddFragHeaders(output, fragSz, 0, payloadSz, certificate, ssl); - if (!ssl->keys.encryptionOn) + if (!IsEncryptionOn(ssl, 1)) HashOutputRaw(ssl, output + RECORD_HEADER_SZ, HANDSHAKE_HEADER_SZ); } else { #ifdef WOLFSSL_DTLS AddHeaders(output, payloadSz, certificate, ssl); - if (!ssl->keys.encryptionOn) + if (!IsEncryptionOn(ssl, 1)) HashOutputRaw(ssl, output + RECORD_HEADER_SZ + DTLS_RECORD_EXTRA, HANDSHAKE_HEADER_SZ + DTLS_HANDSHAKE_EXTRA); @@ -8064,20 +8068,20 @@ int SendCertificate(WOLFSSL* ssl) /* list total */ c32to24(listSz, output + i); - if (!ssl->keys.encryptionOn) + if (!IsEncryptionOn(ssl, 1)) HashOutputRaw(ssl, output + i, CERT_HEADER_SZ); i += CERT_HEADER_SZ; length -= CERT_HEADER_SZ; fragSz -= CERT_HEADER_SZ; if (certSz) { c32to24(certSz, output + i); - if (!ssl->keys.encryptionOn) + if (!IsEncryptionOn(ssl, 1)) HashOutputRaw(ssl, output + i, CERT_HEADER_SZ); i += CERT_HEADER_SZ; length -= CERT_HEADER_SZ; fragSz -= CERT_HEADER_SZ; - if (!ssl->keys.encryptionOn) { + if (!IsEncryptionOn(ssl, 1)) { HashOutputRaw(ssl, ssl->buffers.certificate.buffer, certSz); if (certChainSz) HashOutputRaw(ssl, ssl->buffers.certChain.buffer, @@ -8118,7 +8122,7 @@ int SendCertificate(WOLFSSL* ssl) length -= copySz; } - if (ssl->keys.encryptionOn) { + if (IsEncryptionOn(ssl, 1)) { byte* input; int inputSz = i - RECORD_HEADER_SZ; /* build msg adds rec hdr */ @@ -8492,7 +8496,7 @@ int SendAlert(WOLFSSL* ssl, int severity, int type) /* only send encrypted alert if handshake actually complete, otherwise other side may not be able to handle it */ - if (ssl->keys.encryptionOn && ssl->options.handShakeDone) + if (IsEncryptionOn(ssl, 1) && ssl->options.handShakeDone) sendSz = BuildMessage(ssl, output, outputSz, input, ALERT_SIZE, alert); else { @@ -10015,7 +10019,7 @@ static void PickHashSigAlgo(WOLFSSL* ssl, } #endif - if (ssl->keys.encryptionOn) + if (IsEncryptionOn(ssl, 1)) sendSz += MAX_MSG_EXTRA; /* check for available size */ @@ -10113,7 +10117,7 @@ static void PickHashSigAlgo(WOLFSSL* ssl, } #endif - if (ssl->keys.encryptionOn) { + if (IsEncryptionOn(ssl, 1)) { byte* input; int inputSz = idx - RECORD_HEADER_SZ; /* build msg adds rec hdr */ @@ -10376,7 +10380,7 @@ static void PickHashSigAlgo(WOLFSSL* ssl, ssl->options.serverState = SERVER_HELLO_COMPLETE; - if (ssl->keys.encryptionOn) { + if (IsEncryptionOn(ssl, 0)) { *inOutIdx += ssl->keys.padSz; } @@ -10530,7 +10534,7 @@ static void PickHashSigAlgo(WOLFSSL* ssl, else if (IsTLS(ssl)) ssl->options.sendVerify = SEND_BLANK_CERT; - if (ssl->keys.encryptionOn) + if (IsEncryptionOn(ssl, 0)) *inOutIdx += ssl->keys.padSz; return 0; @@ -11319,7 +11323,7 @@ static void PickHashSigAlgo(WOLFSSL* ssl, return ret; } - if (ssl->keys.encryptionOn) { + if (IsEncryptionOn(ssl, 0)) { *inOutIdx += ssl->keys.padSz; } @@ -12137,7 +12141,7 @@ static word32 QSH_KeyExchangeWrite(WOLFSSL* ssl, byte isServer) } #endif - if (ssl->keys.encryptionOn) + if (IsEncryptionOn(ssl, 1)) sendSz += MAX_MSG_EXTRA; #ifdef HAVE_QSH @@ -12193,7 +12197,7 @@ static word32 QSH_KeyExchangeWrite(WOLFSSL* ssl, byte isServer) XMEMCPY(output + idx, encSecret, encSz); idx += encSz; - if (ssl->keys.encryptionOn) { + if (IsEncryptionOn(ssl, 1)) { byte* input; int inputSz = idx-RECORD_HEADER_SZ; /* buildmsg adds rechdr */ @@ -12293,7 +12297,7 @@ static word32 QSH_KeyExchangeWrite(WOLFSSL* ssl, byte isServer) if (ssl->options.sendVerify == SEND_BLANK_CERT) return 0; /* sent blank cert, can't verify */ - if (ssl->keys.encryptionOn) + if (IsEncryptionOn(ssl, 1)) sendSz += MAX_MSG_EXTRA; /* check for available size */ @@ -12569,7 +12573,7 @@ static word32 QSH_KeyExchangeWrite(WOLFSSL* ssl, byte isServer) } #endif - if (ssl->keys.encryptionOn) { + if (IsEncryptionOn(ssl, 1)) { byte* input; int inputSz = sendSz - RECORD_HEADER_SZ; /* build msg adds rec hdr */ @@ -12683,7 +12687,7 @@ int DoSessionTicket(WOLFSSL* ssl, ssl->session.ticketLen = 0; } - if (ssl->keys.encryptionOn) { + if (IsEncryptionOn(ssl)) { *inOutIdx += ssl->keys.padSz; } diff --git a/wolfssl/internal.h b/wolfssl/internal.h index 7acd2a064..de8ef669f 100644 --- a/wolfssl/internal.h +++ b/wolfssl/internal.h @@ -1453,7 +1453,6 @@ typedef struct Keys { word32 encryptSz; /* last size of encrypted data */ word32 padSz; /* how much to advance after decrypt part */ byte encryptionOn; /* true after change cipher spec */ - byte decryptedCur; /* only decrypt current record once */ } Keys;