diff --git a/src/internal.c b/src/internal.c index 113af79d8..12e129f08 100644 --- a/src/internal.c +++ b/src/internal.c @@ -3843,19 +3843,27 @@ int VerifyRsaSign(WOLFSSL* ssl, byte* verifySig, word32 sigSz, ret = ssl->ctx->RsaPssSignCheckCb(ssl, verifySig, sigSz, &out, TypeHash(hashAlgo), mgf, keyBuf, keySz, ctx); + if (ret > 0) { + ret = wc_RsaPSS_CheckPadding(plain, plainSz, out, ret, + hashType); + if (ret != 0) + ret = VERIFY_CERT_ERROR; + } } else #endif /* HAVE_PK_CALLBACKS */ { ret = wc_RsaPSS_VerifyInline(verifySig, sigSz, &out, hashType, mgf, key); + if (ret > 0) { + ret = wc_RsaPSS_CheckPadding_ex(plain, plainSz, out, ret, + hashType, -1, + mp_count_bits(&key->n)); + if (ret != 0) + ret = VERIFY_CERT_ERROR; + } } - if (ret > 0) { - ret = wc_RsaPSS_CheckPadding(plain, plainSz, out, ret, hashType); - if (ret != 0) - ret = VERIFY_CERT_ERROR; - } } else #endif /* WC_RSA_PSS */ @@ -19426,6 +19434,9 @@ typedef struct DskeArgs { word16 sigSz; byte sigAlgo; byte hashAlgo; +#if !defined(NO_RSA) && defined(WC_RSA_PSS) + int bits; +#endif } DskeArgs; static void FreeDskeArgs(WOLFSSL* ssl, void* pArgs) @@ -20180,6 +20191,9 @@ static int DoServerKeyExchange(WOLFSSL* ssl, const byte* input, if (ret >= 0) { args->sigSz = (word16)ret; + #ifdef WC_RSA_PSS + args->bits = mp_count_bits(&ssl->peerRsaKey->n); + #endif ret = 0; } #ifdef WOLFSSL_ASYNC_CRYPT @@ -20299,11 +20313,12 @@ static int DoServerKeyExchange(WOLFSSL* ssl, const byte* input, #ifndef NO_RSA #ifdef WC_RSA_PSS case rsa_pss_sa_algo: - ret = wc_RsaPSS_CheckPadding( + ret = wc_RsaPSS_CheckPadding_ex( ssl->buffers.digest.buffer, ssl->buffers.digest.length, args->output, args->sigSz, - HashAlgoToType(args->hashAlgo)); + HashAlgoToType(args->hashAlgo), + -1, args->bits); if (ret != 0) return ret; break; @@ -25689,11 +25704,12 @@ static int DoSessionTicket(WOLFSSL* ssl, const byte* input, word32* inOutIdx, if (args->sigAlgo == rsa_pss_sa_algo) { SetDigest(ssl, args->hashAlgo); - ret = wc_RsaPSS_CheckPadding( - ssl->buffers.digest.buffer, - ssl->buffers.digest.length, - args->output, args->sigSz, - HashAlgoToType(args->hashAlgo)); + ret = wc_RsaPSS_CheckPadding_ex( + ssl->buffers.digest.buffer, + ssl->buffers.digest.length, + args->output, args->sigSz, + HashAlgoToType(args->hashAlgo), -1, + mp_count_bits(&ssl->peerRsaKey->n)); if (ret != 0) { ret = SIG_VERIFY_E; goto exit_dcv;