diff --git a/src/internal.c b/src/internal.c index c8e74fe46..95e45da28 100644 --- a/src/internal.c +++ b/src/internal.c @@ -3101,7 +3101,8 @@ int RsaVerify(WOLFSSL* ssl, byte* in, word32 inSz, byte** out, int sigAlgo, /* Verify RSA signature, 0 on success */ int VerifyRsaSign(WOLFSSL* ssl, byte* verifySig, word32 sigSz, - const byte* plain, word32 plainSz, int sigAlgo, int hashAlgo, RsaKey* key) + const byte* plain, word32 plainSz, int sigAlgo, int hashAlgo, RsaKey* key, + const byte* keyBuf, word32 keySz, void* ctx) { byte* out = NULL; /* inline result */ int ret; @@ -3136,8 +3137,19 @@ int VerifyRsaSign(WOLFSSL* ssl, byte* verifySig, word32 sigSz, ret = ConvertHashPss(hashAlgo, &hashType, &mgf); if (ret != 0) return ret; - ret = wc_RsaPSS_VerifyInline(verifySig, sigSz, &out, hashType, mgf, - key); + #ifdef HAVE_PK_CALLBACKS + if (ssl->ctx->RsaPssVerifyCb) { + ret = ssl->ctx->RsaPssVerifyCb(ssl, verifySig, sigSz, &out, + TypeHash(hashAlgo), mgf, + keyBuf, keySz, ctx); + } + else + #endif /* HAVE_PK_CALLBACKS */ + { + ret = wc_RsaPSS_VerifyInline(verifySig, sigSz, &out, hashType, mgf, + key); + } + if (ret > 0) { ret = wc_RsaPSS_CheckPadding(plain, plainSz, out, ret, hashType); if (ret != 0) @@ -3145,9 +3157,19 @@ int VerifyRsaSign(WOLFSSL* ssl, byte* verifySig, word32 sigSz, } } else -#endif +#endif /* WC_RSA_PSS */ { - ret = wc_RsaSSL_VerifyInline(verifySig, sigSz, &out, key); + #ifdef HAVE_PK_CALLBACKS + if (ssl->ctx->RsaVerifyCb) { + ret = ssl->ctx->RsaVerifyCb(ssl, verifySig, sigSz, &out, + keyBuf, keySz, ctx); + } + else + #endif /* HAVE_PK_CALLBACKS */ + { + ret = wc_RsaSSL_VerifyInline(verifySig, sigSz, &out, key); + } + if (ret > 0) { if (ret != (int)plainSz || !out || XMEMCMP(plain, out, plainSz) != 0) { @@ -20549,7 +20571,13 @@ int SendCertificateVerify(WOLFSSL* ssl) ret = VerifyRsaSign(ssl, args->verifySig, args->sigSz, ssl->buffers.sig.buffer, ssl->buffers.sig.length, - args->sigAlgo, ssl->suites->hashAlgo, key + args->sigAlgo, ssl->suites->hashAlgo, key, + ssl->buffers.key->buffer, ssl->buffers.key->length, + #ifdef HAVE_PK_CALLBACKS + ssl->RsaVerifyCtx + #else + NULL + #endif ); } #endif /* !NO_RSA */ @@ -22304,7 +22332,13 @@ static int DoSessionTicket(WOLFSSL* ssl, const byte* input, word32* inOutIdx, ssl->buffers.sig.buffer, ssl->buffers.sig.length, ssl->suites->sigAlgo, ssl->suites->hashAlgo, - key + key, ssl->buffers.key->buffer, + ssl->buffers.key->length, + #ifdef HAVE_PK_CALLBACKS + ssl->RsaVerifyCtx + #else + NULL + #endif ); break; } @@ -22376,7 +22410,13 @@ static int DoSessionTicket(WOLFSSL* ssl, const byte* input, word32* inOutIdx, ssl->buffers.sig.buffer, ssl->buffers.sig.length, ssl->suites->sigAlgo, ssl->suites->hashAlgo, - key + key, ssl->buffers.key->buffer, + ssl->buffers.key->length, + #ifdef HAVE_PK_CALLBACKS + ssl->RsaVerifyCtx + #else + NULL + #endif ); break; } diff --git a/src/tls13.c b/src/tls13.c index 242e1a2c1..be75cd9f2 100644 --- a/src/tls13.c +++ b/src/tls13.c @@ -4997,7 +4997,14 @@ static int SendTls13CertificateVerify(WOLFSSL* ssl) /* check for signature faults */ ret = VerifyRsaSign(ssl, args->verifySig, args->sigLen, sig->buffer, sig->length, args->sigAlgo, - ssl->suites->hashAlgo, (RsaKey*)ssl->hsKey); + ssl->suites->hashAlgo, (RsaKey*)ssl->hsKey, + ssl->buffers.key->buffer, ssl->buffers.key->length, + #ifdef HAVE_PK_CALLBACKS + ssl->RsaVerifyCtx + #else + NULL + #endif + ); } #endif /* !NO_RSA */ diff --git a/wolfssl/internal.h b/wolfssl/internal.h index 4e3bc0d80..b65ca7cfa 100644 --- a/wolfssl/internal.h +++ b/wolfssl/internal.h @@ -3831,11 +3831,9 @@ WOLFSSL_LOCAL int SetTicket(WOLFSSL*, const byte*, word32); enum wc_HashType hashType); WOLFSSL_LOCAL int ConvertHashPss(int hashAlgo, enum wc_HashType* hashType, int* mgf); #endif - WOLFSSL_LOCAL int VerifyRsaSign(WOLFSSL* ssl, - byte* verifySig, word32 sigSz, - const byte* plain, word32 plainSz, - int sigAlgo, int hashAlgo, - RsaKey* key); + WOLFSSL_LOCAL int VerifyRsaSign(WOLFSSL* ssl, byte* verifySig, + word32 sigSz, const byte* plain, word32 plainSz, int sigAlgo, + int hashAlgo, RsaKey* key, const byte* keyBuf, word32 keySz, void* ctx); WOLFSSL_LOCAL int RsaSign(WOLFSSL* ssl, const byte* in, word32 inSz, byte* out, word32* outSz, int sigAlgo, int hashAlgo, RsaKey* key, const byte* keyBuf, word32 keySz, void* ctx);