From 248dd12993e127e4a3ae13f1364a64e79b377c78 Mon Sep 17 00:00:00 2001 From: Juliusz Sosinowicz Date: Thu, 24 Sep 2020 17:09:34 +0200 Subject: [PATCH] Enable RSA-PSS padding in EVP_Digest* API --- src/ssl.c | 255 +++++++++++++++++++++++++--------------- tests/api.c | 33 ++++++ wolfcrypt/src/evp.c | 8 +- wolfcrypt/src/rsa.c | 46 +++++++- wolfssl/openssl/rsa.h | 6 +- wolfssl/wolfcrypt/rsa.h | 5 +- 6 files changed, 251 insertions(+), 102 deletions(-) diff --git a/src/ssl.c b/src/ssl.c index e4aeb2b95..0ce466814 100644 --- a/src/ssl.c +++ b/src/ssl.c @@ -30805,39 +30805,7 @@ static void show(const char *title, const unsigned char *out, unsigned int outle #define show(a,b,c) #endif -/* return SSL_SUCCESS on ok, 0 otherwise */ -int wolfSSL_RSA_sign(int type, const unsigned char* m, - unsigned int mLen, unsigned char* sigRet, - unsigned int* sigLen, WOLFSSL_RSA* rsa) -{ - return wolfSSL_RSA_sign_ex(type, m, mLen, sigRet, sigLen, rsa, 1); -} - -int wolfSSL_RSA_sign_ex(int type, const unsigned char* m, - unsigned int mLen, unsigned char* sigRet, - unsigned int* sigLen, WOLFSSL_RSA* rsa, int flag) -{ - word32 outLen; - word32 signSz; - int initTmpRng = 0; - WC_RNG* rng = NULL; - int ret = 0; -#ifdef WOLFSSL_SMALL_STACK - WC_RNG* tmpRNG = NULL; - byte* encodedSig = NULL; -#else - WC_RNG tmpRNG[1]; - byte encodedSig[MAX_ENCODED_SIG_SZ]; -#endif - - WOLFSSL_ENTER("wolfSSL_RSA_sign"); - - if (m == NULL || sigRet == NULL || sigLen == NULL || rsa == NULL) { - WOLFSSL_MSG("Bad function arguments"); - return 0; - } - show("Message to Sign", m, mLen); - +static int nid2HashSum(int type) { switch (type) { #ifdef WOLFSSL_MD2 case NID_md2: type = MD2h; break; @@ -30873,6 +30841,43 @@ int wolfSSL_RSA_sign_ex(int type, const unsigned char* m, WOLFSSL_MSG("This NID (md type) not configured or not implemented"); return 0; } + return type; +} + +/* return SSL_SUCCESS on ok, 0 otherwise */ +int wolfSSL_RSA_sign(int type, const unsigned char* m, + unsigned int mLen, unsigned char* sigRet, + unsigned int* sigLen, WOLFSSL_RSA* rsa) +{ + return wolfSSL_RSA_sign_ex(type, m, mLen, sigRet, sigLen, rsa, 1, + RSA_PKCS1_PADDING); +} + +int wolfSSL_RSA_sign_ex(int type, const unsigned char* m, + unsigned int mLen, unsigned char* sigRet, + unsigned int* sigLen, WOLFSSL_RSA* rsa, int flag, + int padding) +{ + word32 outLen; + word32 signSz; + int initTmpRng = 0; + WC_RNG* rng = NULL; + int ret = 0; +#ifdef WOLFSSL_SMALL_STACK + WC_RNG* tmpRNG = NULL; + byte* encodedSig = NULL; +#else + WC_RNG tmpRNG[1]; + byte encodedSig[MAX_ENCODED_SIG_SZ]; +#endif + + WOLFSSL_ENTER("wolfSSL_RSA_sign"); + + if (m == NULL || sigRet == NULL || sigLen == NULL || rsa == NULL) { + WOLFSSL_MSG("Bad function arguments"); + return 0; + } + show("Message to Sign", m, mLen); if (rsa->inSet == 0) { @@ -30884,6 +30889,8 @@ int wolfSSL_RSA_sign_ex(int type, const unsigned char* m, } } + type = nid2HashSum(type); + outLen = (word32)wolfSSL_BN_num_bytes(rsa->n); #ifdef WOLFSSL_SMALL_STACK @@ -30915,32 +30922,71 @@ int wolfSSL_RSA_sign_ex(int type, const unsigned char* m, } if (rng) { - - signSz = wc_EncodeSignature(encodedSig, m, mLen, type); - if (signSz == 0) { - WOLFSSL_MSG("Bad Encode Signature"); - } - else { - show("Encoded Message", encodedSig, signSz); - if (flag != 0) { + if (flag != 0) { + switch (padding) { +#ifdef WC_RSA_NO_PADDING + case RSA_NO_PADDING: + WOLFSSL_MSG("RSA_NO_PADDING not supported for signing"); + ret = BAD_FUNC_ARG; + break; +#endif +#ifdef WC_RSA_PSS + case RSA_PKCS1_PSS_PADDING: + { + enum wc_HashType hType = wc_OidGetHash(type); + ret = wc_RsaPSS_Sign(m, mLen, sigRet, outLen, + hType, hash2mgf(hType), (RsaKey*)rsa->internal, rng); + break; + } +#endif +#ifndef WC_NO_RSA_OAEP + case RSA_PKCS1_OAEP_PADDING: + { + WOLFSSL_MSG("RSA_PKCS1_OAEP_PADDING not supported for signing"); + ret = BAD_FUNC_ARG; + break; + } +#endif + case RSA_PKCS1_PADDING: + default: + signSz = wc_EncodeSignature(encodedSig, m, mLen, type); + if (signSz == 0) { + WOLFSSL_MSG("Bad Encode Signature"); + } + show("Encoded Message", encodedSig, signSz); ret = wc_RsaSSL_Sign(encodedSig, signSz, sigRet, outLen, (RsaKey*)rsa->internal, rng); - if (ret <= 0) { - WOLFSSL_MSG("Bad Rsa Sign"); - ret = 0; + } + if (ret <= 0) { + WOLFSSL_MSG("Bad Rsa Sign"); + ret = 0; + } + else { + *sigLen = (unsigned int)ret; + ret = SSL_SUCCESS; + show("Signature", sigRet, *sigLen); + } + } else { + switch (padding) { + case RSA_NO_PADDING: + case RSA_PKCS1_PSS_PADDING: + case RSA_PKCS1_OAEP_PADDING: + ret = SSL_SUCCESS; + XMEMCPY(sigRet, m, mLen); + *sigLen = mLen; + break; + case RSA_PKCS1_PADDING: + default: + signSz = wc_EncodeSignature(encodedSig, m, mLen, type); + if (signSz == 0) { + WOLFSSL_MSG("Bad Encode Signature"); } - else { - *sigLen = (unsigned int)ret; - ret = SSL_SUCCESS; - show("Signature", sigRet, *sigLen); - } - } else { ret = SSL_SUCCESS; XMEMCPY(sigRet, encodedSig, signSz); *sigLen = signSz; + break; } } - } if (initTmpRng) @@ -30959,65 +31005,87 @@ int wolfSSL_RSA_sign_ex(int type, const unsigned char* m, return ret; } - /* returns WOLFSSL_SUCCESS on successful verify and WOLFSSL_FAILURE on fail */ int wolfSSL_RSA_verify(int type, const unsigned char* m, unsigned int mLen, const unsigned char* sig, unsigned int sigLen, WOLFSSL_RSA* rsa) { + return wolfSSL_RSA_verify_ex(type, m, mLen, sig, sigLen, rsa, RSA_PKCS1_PADDING); +} + +#define wolfSSL_RSA_verify_ex_return(msg, ret_code) { \ + WOLFSSL_MSG(msg); \ + if (sigRet) \ + XFREE(sigRet, NULL, DYNAMIC_TYPE_TMP_BUFFER); \ + if (sigDec) \ + XFREE(sigDec, NULL, DYNAMIC_TYPE_TMP_BUFFER); \ + return ret_code; \ +} + +/* returns WOLFSSL_SUCCESS on successful verify and WOLFSSL_FAILURE on fail */ +int wolfSSL_RSA_verify_ex(int type, const unsigned char* m, + unsigned int mLen, const unsigned char* sig, + unsigned int sigLen, WOLFSSL_RSA* rsa, + int padding) { + int ret; - unsigned char *sigRet ; - unsigned char *sigDec ; + unsigned char *sigRet = NULL; + unsigned char *sigDec = NULL; unsigned int len; + int hSum = nid2HashSum(type); + enum wc_HashType hType; WOLFSSL_ENTER("wolfSSL_RSA_verify"); if ((m == NULL) || (sig == NULL)) { WOLFSSL_MSG("Bad function arguments"); return WOLFSSL_FAILURE; } - - sigRet = (unsigned char *)XMALLOC(sigLen, NULL, DYNAMIC_TYPE_TMP_BUFFER); - if (sigRet == NULL) { - WOLFSSL_MSG("Memory failure"); - return WOLFSSL_FAILURE; - } sigDec = (unsigned char *)XMALLOC(sigLen, NULL, DYNAMIC_TYPE_TMP_BUFFER); if (sigDec == NULL) { - WOLFSSL_MSG("Memory failure"); - XFREE(sigRet, NULL, DYNAMIC_TYPE_TMP_BUFFER); - return WOLFSSL_FAILURE; + wolfSSL_RSA_verify_ex_return("Memory failure", WOLFSSL_FAILURE); } - /* get non-encrypted signature to be compared with decrypted signature */ - ret = wolfSSL_RSA_sign_ex(type, m, mLen, sigRet, &len, rsa, 0); - if (ret <= 0) { - WOLFSSL_MSG("Message Digest Error"); - XFREE(sigRet, NULL, DYNAMIC_TYPE_TMP_BUFFER); - XFREE(sigDec, NULL, DYNAMIC_TYPE_TMP_BUFFER); - return WOLFSSL_FAILURE; - } - show("Encoded Message", sigRet, len); - /* decrypt signature */ - ret = wc_RsaSSL_Verify(sig, sigLen, (unsigned char *)sigDec, sigLen, - (RsaKey*)rsa->internal); - if (ret <= 0) { - WOLFSSL_MSG("RSA Decrypt error"); - XFREE(sigRet, NULL, DYNAMIC_TYPE_TMP_BUFFER); - XFREE(sigDec, NULL, DYNAMIC_TYPE_TMP_BUFFER); - return WOLFSSL_FAILURE; - } - show("Decrypted Signature", sigDec, ret); - - if ((int)len == ret && XMEMCMP(sigRet, sigDec, ret) == 0) { - WOLFSSL_MSG("wolfSSL_RSA_verify success"); - XFREE(sigRet, NULL, DYNAMIC_TYPE_TMP_BUFFER); - XFREE(sigDec, NULL, DYNAMIC_TYPE_TMP_BUFFER); - return WOLFSSL_SUCCESS; + if (padding != RSA_PKCS1_PSS_PADDING) { + sigRet = (unsigned char *)XMALLOC(sigLen, NULL, DYNAMIC_TYPE_TMP_BUFFER); + if (sigRet == NULL) { + wolfSSL_RSA_verify_ex_return("Memory failure", WOLFSSL_FAILURE); + } + /* get non-encrypted signature to be compared with decrypted signature */ + ret = wolfSSL_RSA_sign_ex(type, m, mLen, sigRet, &len, rsa, 0, padding); + if (ret <= 0) { + wolfSSL_RSA_verify_ex_return("Message Digest Error", WOLFSSL_FAILURE); + } + show("Encoded Message", sigRet, len); } else { - WOLFSSL_MSG("wolfSSL_RSA_verify failed"); - XFREE(sigRet, NULL, DYNAMIC_TYPE_TMP_BUFFER); - XFREE(sigDec, NULL, DYNAMIC_TYPE_TMP_BUFFER); - return WOLFSSL_FAILURE; + show("Encoded Message", m, mLen); + } + /* decrypt signature */ + hType = wc_OidGetHash(hSum); + ret = wc_RsaSSL_Verify_ex(sig, sigLen, (unsigned char *)sigDec, sigLen, + (RsaKey*)rsa->internal, padding, hType); + if (ret <= 0) { + wolfSSL_RSA_verify_ex_return("RSA Decrypt error", WOLFSSL_FAILURE); + } + show("Decrypted Signature", sigDec, ret); + if (padding == RSA_PKCS1_PSS_PADDING) { + if ((ret = wc_RsaPSS_CheckPadding_ex(m, mLen, sigDec, ret, + hType, RSA_PSS_SALT_LEN_DEFAULT, + mp_count_bits(&((RsaKey*)rsa->internal)->n))) == 0) { + wolfSSL_RSA_verify_ex_return("wolfSSL_RSA_verify success", + WOLFSSL_SUCCESS); + } + else { + wolfSSL_RSA_verify_ex_return("wolfSSL_RSA_verify failed", + WOLFSSL_FAILURE); + } + } + else if ((int)len == ret && XMEMCMP(sigRet, sigDec, ret) == 0) { + wolfSSL_RSA_verify_ex_return("wolfSSL_RSA_verify success", + WOLFSSL_SUCCESS); + } + else { + wolfSSL_RSA_verify_ex_return("wolfSSL_RSA_verify failed", + WOLFSSL_FAILURE); } } @@ -45872,7 +45940,8 @@ int wolfSSL_RSA_public_decrypt(int flen, const unsigned char* from, /* size of 'to' buffer must be size of RSA key */ tlen = wc_RsaSSL_Verify_ex(from, flen, to, wolfSSL_RSA_size(rsa), - (RsaKey*)rsa->internal, pad_type); + (RsaKey*)rsa->internal, pad_type, + WC_HASH_TYPE_NONE); if (tlen <= 0) WOLFSSL_MSG("wolfSSL_RSA_public_decrypt failed"); else { diff --git a/tests/api.c b/tests/api.c index bd7bf1bfe..f986d0de8 100644 --- a/tests/api.c +++ b/tests/api.c @@ -26646,6 +26646,7 @@ static void test_wolfSSL_EVP_MD_rsa_signing(void) defined(USE_CERT_BUFFERS_2048) WOLFSSL_EVP_PKEY* privKey; WOLFSSL_EVP_PKEY* pubKey; + WOLFSSL_EVP_PKEY_CTX* keyCtx; const char testData[] = "Hi There"; WOLFSSL_EVP_MD_CTX mdCtx; size_t checkSz = -1; @@ -26653,6 +26654,12 @@ static void test_wolfSSL_EVP_MD_rsa_signing(void) const unsigned char* cp; const unsigned char* p; unsigned char check[2048/8]; + size_t i; + int paddings[] = { + RSA_PKCS1_PADDING, + RSA_PKCS1_PSS_PADDING, + }; + printf(testingFmt, "wolfSSL_EVP_MD_rsa_signing()"); @@ -26707,6 +26714,32 @@ static void test_wolfSSL_EVP_MD_rsa_signing(void) AssertIntEQ(wolfSSL_EVP_DigestVerifyFinal(&mdCtx, check, checkSz), 1); AssertIntEQ(wolfSSL_EVP_MD_CTX_cleanup(&mdCtx), 1); + /* Check all signing padding types */ + for (i = 0; i < sizeof(paddings)/sizeof(int); i++) { + wolfSSL_EVP_MD_CTX_init(&mdCtx); + AssertIntEQ(wolfSSL_EVP_DigestSignInit(&mdCtx, &keyCtx, + wolfSSL_EVP_sha256(), NULL, privKey), 1); + AssertIntEQ(wolfSSL_EVP_PKEY_CTX_set_rsa_padding(keyCtx, + paddings[i]), 1); + AssertIntEQ(wolfSSL_EVP_DigestSignUpdate(&mdCtx, testData, + (unsigned int)XSTRLEN(testData)), 1); + AssertIntEQ(wolfSSL_EVP_DigestSignFinal(&mdCtx, NULL, &checkSz), 1); + AssertIntEQ((int)checkSz, sz); + AssertIntEQ(wolfSSL_EVP_DigestSignFinal(&mdCtx, check, &checkSz), 1); + AssertIntEQ((int)checkSz,sz); + AssertIntEQ(wolfSSL_EVP_MD_CTX_cleanup(&mdCtx), 1); + + wolfSSL_EVP_MD_CTX_init(&mdCtx); + AssertIntEQ(wolfSSL_EVP_DigestVerifyInit(&mdCtx, &keyCtx, + wolfSSL_EVP_sha256(), NULL, pubKey), 1); + AssertIntEQ(wolfSSL_EVP_PKEY_CTX_set_rsa_padding(keyCtx, + paddings[i]), 1); + AssertIntEQ(wolfSSL_EVP_DigestVerifyUpdate(&mdCtx, testData, + (unsigned int)XSTRLEN(testData)), 1); + AssertIntEQ(wolfSSL_EVP_DigestVerifyFinal(&mdCtx, check, checkSz), 1); + AssertIntEQ(wolfSSL_EVP_MD_CTX_cleanup(&mdCtx), 1); + } + wolfSSL_EVP_PKEY_free(pubKey); wolfSSL_EVP_PKEY_free(privKey); diff --git a/wolfcrypt/src/evp.c b/wolfcrypt/src/evp.c index e84ccf3f9..c3678ab15 100644 --- a/wolfcrypt/src/evp.c +++ b/wolfcrypt/src/evp.c @@ -2519,8 +2519,8 @@ int wolfSSL_EVP_DigestSignFinal(WOLFSSL_EVP_MD_CTX *ctx, unsigned char *sig, int nid = wolfSSL_EVP_MD_type(wolfSSL_EVP_MD_CTX_md(ctx)); if (nid < 0) break; - ret = wolfSSL_RSA_sign(nid, digest, hashLen, sig, &sigSz, - ctx->pctx->pkey->rsa); + ret = wolfSSL_RSA_sign_ex(nid, digest, hashLen, sig, &sigSz, + ctx->pctx->pkey->rsa, 1, ctx->pctx->padding); if (ret >= 0) *siglen = sigSz; break; @@ -2614,9 +2614,9 @@ int wolfSSL_EVP_DigestVerifyFinal(WOLFSSL_EVP_MD_CTX *ctx, int nid = wolfSSL_EVP_MD_type(wolfSSL_EVP_MD_CTX_md(ctx)); if (nid < 0) return WOLFSSL_FAILURE; - return wolfSSL_RSA_verify(nid, digest, hashLen, sig, + return wolfSSL_RSA_verify_ex(nid, digest, hashLen, sig, (unsigned int)siglen, - ctx->pctx->pkey->rsa); + ctx->pctx->pkey->rsa, ctx->pctx->padding); } #endif /* NO_RSA */ diff --git a/wolfcrypt/src/rsa.c b/wolfcrypt/src/rsa.c index 742853c6c..f36dfbfe9 100644 --- a/wolfcrypt/src/rsa.c +++ b/wolfcrypt/src/rsa.c @@ -1748,6 +1748,45 @@ int wc_RsaUnPad_ex(byte* pkcsBlock, word32 pkcsBlockLen, byte** out, return ret; } +int hash2mgf(enum wc_HashType hType) +{ + switch (hType) { +#ifndef NO_SHA + case WC_HASH_TYPE_SHA: + return WC_MGF1SHA1; +#endif +#ifndef NO_SHA256 +#ifdef WOLFSSL_SHA224 + case WC_HASH_TYPE_SHA224: + return WC_MGF1SHA224; +#endif + case WC_HASH_TYPE_SHA256: + return WC_MGF1SHA256; +#endif +#ifdef WOLFSSL_SHA384 + case WC_HASH_TYPE_SHA384: + return WC_MGF1SHA384; +#endif +#ifdef WOLFSSL_SHA512 + case WC_HASH_TYPE_SHA512: + return WC_MGF1SHA512; +#endif + case WC_HASH_TYPE_NONE: + case WC_HASH_TYPE_MD2: + case WC_HASH_TYPE_MD4: + case WC_HASH_TYPE_MD5: + case WC_HASH_TYPE_MD5_SHA: + case WC_HASH_TYPE_SHA3_224: + case WC_HASH_TYPE_SHA3_256: + case WC_HASH_TYPE_SHA3_384: + case WC_HASH_TYPE_SHA3_512: + case WC_HASH_TYPE_BLAKE2B: + case WC_HASH_TYPE_BLAKE2S: + default: + WOLFSSL_MSG("Unrecognized or unsupported hash function"); + return WC_MGF1NONE; + } +} #ifdef WC_RSA_NONBLOCK static int wc_RsaFunctionNonBlock(const byte* in, word32 inLen, byte* out, @@ -3209,11 +3248,12 @@ int wc_RsaSSL_VerifyInline(byte* in, word32 inLen, byte** out, RsaKey* key) int wc_RsaSSL_Verify(const byte* in, word32 inLen, byte* out, word32 outLen, RsaKey* key) { - return wc_RsaSSL_Verify_ex(in, inLen, out, outLen, key , WC_RSA_PKCSV15_PAD); + return wc_RsaSSL_Verify_ex(in, inLen, out, outLen, key , WC_RSA_PKCSV15_PAD, + WC_HASH_TYPE_NONE); } int wc_RsaSSL_Verify_ex(const byte* in, word32 inLen, byte* out, word32 outLen, - RsaKey* key, int pad_type) + RsaKey* key, int pad_type, enum wc_HashType hash) { WC_RNG* rng; @@ -3229,7 +3269,7 @@ int wc_RsaSSL_Verify_ex(const byte* in, word32 inLen, byte* out, word32 outLen, return RsaPrivateDecryptEx((byte*)in, inLen, out, outLen, NULL, key, RSA_PUBLIC_DECRYPT, RSA_BLOCK_TYPE_1, pad_type, - WC_HASH_TYPE_NONE, WC_MGF1NONE, NULL, 0, 0, rng); + hash, hash2mgf(hash), NULL, 0, RSA_PSS_SALT_LEN_DEFAULT, rng); } #endif diff --git a/wolfssl/openssl/rsa.h b/wolfssl/openssl/rsa.h index 5445db196..a818007f1 100644 --- a/wolfssl/openssl/rsa.h +++ b/wolfssl/openssl/rsa.h @@ -116,10 +116,14 @@ WOLFSSL_API int wolfSSL_RSA_sign(int type, const unsigned char* m, unsigned int* sigLen, WOLFSSL_RSA*); WOLFSSL_API int wolfSSL_RSA_sign_ex(int type, const unsigned char* m, unsigned int mLen, unsigned char* sigRet, - unsigned int* sigLen, WOLFSSL_RSA*, int); + unsigned int* sigLen, WOLFSSL_RSA*, int, int); WOLFSSL_API int wolfSSL_RSA_verify(int type, const unsigned char* m, unsigned int mLen, const unsigned char* sig, unsigned int sigLen, WOLFSSL_RSA*); +WOLFSSL_API int wolfSSL_RSA_verify_ex(int type, const unsigned char* m, + unsigned int mLen, const unsigned char* sig, + unsigned int sigLen, WOLFSSL_RSA* rsa, + int padding); WOLFSSL_API int wolfSSL_RSA_public_decrypt(int flen, const unsigned char* from, unsigned char* to, WOLFSSL_RSA*, int padding); WOLFSSL_API int wolfSSL_RSA_GenAdd(WOLFSSL_RSA*); diff --git a/wolfssl/wolfcrypt/rsa.h b/wolfssl/wolfcrypt/rsa.h index 8feee70d4..f77c942d3 100644 --- a/wolfssl/wolfcrypt/rsa.h +++ b/wolfssl/wolfcrypt/rsa.h @@ -245,7 +245,8 @@ WOLFSSL_API int wc_RsaSSL_VerifyInline(byte* in, word32 inLen, byte** out, WOLFSSL_API int wc_RsaSSL_Verify(const byte* in, word32 inLen, byte* out, word32 outLen, RsaKey* key); WOLFSSL_API int wc_RsaSSL_Verify_ex(const byte* in, word32 inLen, byte* out, - word32 outLen, RsaKey* key, int pad_type); + word32 outLen, RsaKey* key, int pad_type, + enum wc_HashType hash); WOLFSSL_API int wc_RsaPSS_VerifyInline(byte* in, word32 inLen, byte** out, enum wc_HashType hash, int mgf, RsaKey* key); @@ -367,6 +368,8 @@ WOLFSSL_LOCAL int wc_RsaUnPad_ex(byte* pkcsBlock, word32 pkcsBlockLen, byte** ou int mgf, byte* optLabel, word32 labelLen, int saltLen, int bits, void* heap); +WOLFSSL_LOCAL int hash2mgf(enum wc_HashType hType); + #endif /* HAVE_USER_RSA */ #ifdef __cplusplus