diff --git a/src/tls13.c b/src/tls13.c index e9defe6bd..cfc674ef0 100644 --- a/src/tls13.c +++ b/src/tls13.c @@ -8761,6 +8761,10 @@ typedef struct Scv13Args { byte sigAlgo; byte* sigData; word16 sigDataSz; +#ifndef NO_RSA + byte* toSign; /* not allocated */ + word32 toSignSz; +#endif #ifdef WOLFSSL_DUAL_ALG_CERTS byte altSigAlgo; word32 altSigLen; /* Only used in the case of both native and alt. */ @@ -9315,7 +9319,17 @@ static int SendTls13CertificateVerify(WOLFSSL* ssl) #endif /* HAVE_DILITHIUM */ #ifndef NO_RSA if (ssl->hsType == DYNAMIC_TYPE_RSA) { - ret = RsaSign(ssl, rsaSigBuf->buffer, (word32)rsaSigBuf->length, + args->toSign = rsaSigBuf->buffer; + args->toSignSz = (word32)rsaSigBuf->length; + #if defined(HAVE_PK_CALLBACKS) && \ + defined(TLS13_RSA_PSS_SIGN_CB_NO_PREHASH) + /* Pass full data to sign (args->sigData), not hash of */ + if (ssl->ctx->RsaPssSignCb) { + args->toSign = args->sigData; + args->toSignSz = args->sigDataSz; + } + #endif + ret = RsaSign(ssl, (const byte*)args->toSign, args->toSignSz, sigOut, &args->sigLen, args->sigAlgo, ssl->options.hashAlgo, (RsaKey*)ssl->hsKey, ssl->buffers.key); @@ -9359,10 +9373,20 @@ static int SendTls13CertificateVerify(WOLFSSL* ssl) #endif /* HAVE_ECC */ #ifndef NO_RSA if (ssl->hsAltType == DYNAMIC_TYPE_RSA) { - ret = RsaSign(ssl, rsaSigBuf->buffer, - (word32)rsaSigBuf->length, sigOut, - &args->altSigLen, args->altSigAlgo, - ssl->options.hashAlgo, (RsaKey*)ssl->hsAltKey, + args->toSign = rsaSigBuf->buffer; + args->toSignSz = (word32)rsaSigBuf->length; + #if defined(HAVE_PK_CALLBACKS) && \ + defined(TLS13_RSA_PSS_SIGN_CB_NO_PREHASH) + /* Pass full data to sign (args->altSigData), not hash of */ + if (ssl->ctx->RsaPssSignCb) { + args->toSign = args->altSigData; + args->toSignSz = (word32)args->altSigDataSz; + } + #endif + ret = RsaSign(ssl, (const byte*)args->toSign, + args->toSignSz, sigOut, &args->altSigLen, + args->altSigAlgo, ssl->options.hashAlgo, + (RsaKey*)ssl->hsAltKey, ssl->buffers.altKey); if (ret == 0) { diff --git a/wolfssl/test.h b/wolfssl/test.h index 0fb23c196..0efa20e03 100644 --- a/wolfssl/test.h +++ b/wolfssl/test.h @@ -3902,9 +3902,11 @@ static WC_INLINE int myRsaPssSign(WOLFSSL* ssl, const byte* in, word32 inSz, { enum wc_HashType hashType = WC_HASH_TYPE_NONE; WC_RNG rng; - int ret; + int ret = 0; word32 idx = 0; RsaKey myKey; + byte* inBuf = (byte*)in; + word32 inBufSz = inSz; byte* keyBuf = (byte*)key; PkCbInfo* cbInfo = (PkCbInfo*)ctx; @@ -3942,17 +3944,40 @@ static WC_INLINE int myRsaPssSign(WOLFSSL* ssl, const byte* in, word32 inSz, if (ret != 0) return ret; - ret = wc_InitRsaKey(&myKey, NULL); + #ifdef TLS13_RSA_PSS_SIGN_CB_NO_PREHASH + /* With this defined, RSA-PSS sign callback when used from TLS 1.3 + * does not hash data before giving to this callback. User must + * compute hash themselves. */ + if (wolfSSL_GetVersion(ssl) == WOLFSSL_TLSV1_3) { + inBufSz = wc_HashGetDigestSize(hashType); + inBuf = (byte*)XMALLOC(inBufSz, NULL, DYNAMIC_TYPE_TMP_BUFFER); + if (inBuf == NULL) { + ret = MEMORY_E; + } + if (ret == 0) { + ret = wc_Hash(hashType, in, inSz, inBuf, inBufSz); + } + } + #endif + + if (ret == 0) { + ret = wc_InitRsaKey(&myKey, NULL); + } if (ret == 0) { ret = wc_RsaPrivateKeyDecode(keyBuf, &idx, &myKey, keySz); if (ret == 0) { - ret = wc_RsaPSS_Sign(in, inSz, out, *outSz, hashType, mgf, &myKey, - &rng); + ret = wc_RsaPSS_Sign(inBuf, inBufSz, out, *outSz, hashType, mgf, + &myKey, &rng); } if (ret > 0) { /* save and convert to 0 success */ *outSz = (word32) ret; ret = 0; } + #ifdef TLS13_RSA_PSS_SIGN_CB_NO_PREHASH + if ((inBuf != NULL) && (wolfSSL_GetVersion(ssl) == WOLFSSL_TLSV1_3)) { + XFREE(inBuf, NULL, DYNAMIC_TYPE_TMP_BUFFER); + } + #endif wc_FreeRsaKey(&myKey); } wc_FreeRng(&rng);