diff --git a/src/internal.c b/src/internal.c index d88716f46..18285c112 100755 --- a/src/internal.c +++ b/src/internal.c @@ -2732,8 +2732,41 @@ void FreeX509(WOLFSSL_X509* x509) #ifndef NO_RSA +#if defined(WOLFSSL_TLS13) && defined(WC_RSA_PSS) +static int ConvertHashPss(int hashAlgo, enum wc_HashType* hashType, int* mgf) { + switch (hashAlgo) { + #ifdef WOLFSSL_SHA512 + case sha512_mac: + *hashType = WC_HASH_TYPE_SHA512; + if (mgf != NULL) + *mgf = WC_MGF1SHA512; + break; + #endif + #ifdef WOLFSSL_SHA384 + case sha384_mac: + *hashType = WC_HASH_TYPE_SHA384; + if (mgf != NULL) + *mgf = WC_MGF1SHA384; + break; + #endif + #ifndef NO_SHA256 + case sha256_mac: + *hashType = WC_HASH_TYPE_SHA256; + if (mgf != NULL) + *mgf = WC_MGF1SHA256; + break; + #endif + default: + return BAD_FUNC_ARG; + } + + return 0; +} +#endif + int RsaSign(WOLFSSL* ssl, const byte* in, word32 inSz, byte* out, - word32* outSz, RsaKey* key, const byte* keyBuf, word32 keySz, void* ctx) + word32* outSz, int sigAlgo, int hashAlgo, RsaKey* key, + const byte* keyBuf, word32 keySz, void* ctx) { int ret; @@ -2741,6 +2774,8 @@ int RsaSign(WOLFSSL* ssl, const byte* in, word32 inSz, byte* out, (void)keyBuf; (void)keySz; (void)ctx; + (void)sigAlgo; + (void)hashAlgo; WOLFSSL_ENTER("RsaSign"); @@ -2752,7 +2787,20 @@ int RsaSign(WOLFSSL* ssl, const byte* in, word32 inSz, byte* out, else #endif /*HAVE_PK_CALLBACKS */ { - ret = wc_RsaSSL_Sign(in, inSz, out, *outSz, key, ssl->rng); +#if defined(WOLFSSL_TLS13) && defined(WC_RSA_PSS) + if (sigAlgo == rsa_pss_sa_algo) { + enum wc_HashType hashType = WC_HASH_TYPE_NONE; + int mgf = 0; + + ret = ConvertHashPss(hashAlgo, &hashType, &mgf); + if (ret != 0) + return ret; + ret = wc_RsaPSS_Sign(in, inSz, out, *outSz, hashType, mgf, key, + ssl->rng); + } + else +#endif + ret = wc_RsaSSL_Sign(in, inSz, out, *outSz, key, ssl->rng); } /* Handle async pending response */ @@ -2795,35 +2843,17 @@ int RsaVerify(WOLFSSL* ssl, byte* in, word32 inSz, byte** out, int sigAlgo, else #endif /*HAVE_PK_CALLBACKS */ { -#ifdef WOLFSSL_TLS13 - #ifdef WC_RSA_PSS +#if defined(WOLFSSL_TLS13) && defined(WC_RSA_PSS) if (sigAlgo == rsa_pss_sa_algo) { enum wc_HashType hashType = WC_HASH_TYPE_NONE; int mgf = 0; - switch (hashAlgo) { - case sha512_mac: - #ifdef WOLFSSL_SHA512 - hashType = WC_HASH_TYPE_SHA512; - mgf = WC_MGF1SHA512; - #endif - break; - case sha384_mac: - #ifdef WOLFSSL_SHA384 - hashType = WC_HASH_TYPE_SHA384; - mgf = WC_MGF1SHA384; - #endif - break; - case sha256_mac: - #ifndef NO_SHA256 - hashType = WC_HASH_TYPE_SHA256; - mgf = WC_MGF1SHA256; - #endif - break; - } + + ret = ConvertHashPss(hashAlgo, &hashType, &mgf); + if (ret != 0) + return ret; ret = wc_RsaPSS_VerifyInline(in, inSz, out, hashType, mgf, key); } else - #endif #endif ret = wc_RsaSSL_VerifyInline(in, inSz, out, key); } @@ -2840,14 +2870,40 @@ int RsaVerify(WOLFSSL* ssl, byte* in, word32 inSz, byte** out, int sigAlgo, return ret; } +#if defined(WOLFSSL_TLS13) && defined(WC_RSA_PSS) +int CheckRsaPssPadding(const byte* plain, word32 plainSz, byte* out, + word32 sigSz, enum wc_HashType hashType) +{ + int ret; + + if (plainSz != sigSz || out == NULL) + ret = VERIFY_CERT_ERROR; + else { + out -= 2 * sigSz; + XMEMCPY(out, plain, plainSz); + out -= 8; + XMEMSET(out, 0, 8); + wc_Hash(hashType, out, 8 + plainSz * 2, out, plainSz); + if (XMEMCMP(out, out + 8 + plainSz * 2, plainSz) != 0) + ret = VERIFY_CERT_ERROR; + else + ret = 0; + } + + return ret; +} +#endif + /* Verify RSA signature, 0 on success */ int VerifyRsaSign(WOLFSSL* ssl, byte* verifySig, word32 sigSz, - const byte* plain, word32 plainSz, RsaKey* key) + const byte* plain, word32 plainSz, int sigAlgo, int hashAlgo, RsaKey* key) { byte* out = NULL; /* inline result */ int ret; (void)ssl; + (void)sigAlgo; + (void)hashAlgo; WOLFSSL_ENTER("VerifyRsaSign"); @@ -2860,15 +2916,31 @@ int VerifyRsaSign(WOLFSSL* ssl, byte* verifySig, word32 sigSz, return BUFFER_E; } - ret = wc_RsaSSL_VerifyInline(verifySig, sigSz, &out, key); +#if defined(WOLFSSL_TLS13) && defined(WC_RSA_PSS) + if (sigAlgo == rsa_pss_sa_algo) { + enum wc_HashType hashType = WC_HASH_TYPE_NONE; + int mgf = 0; - if (ret > 0) { - if (ret != (int)plainSz || !out || - XMEMCMP(plain, out, plainSz) != 0) { - WOLFSSL_MSG("RSA Signature verification failed"); - ret = RSA_SIGN_FAULT; - } else { - ret = 0; /* RSA reset */ + ret = ConvertHashPss(hashAlgo, &hashType, &mgf); + if (ret != 0) + return ret; + ret = wc_RsaPSS_VerifyInline(verifySig, sigSz, &out, hashType, mgf, + key); + if (ret > 0) + ret = CheckRsaPssPadding(plain, plainSz, out, ret, hashType); + } + else +#endif + { + ret = wc_RsaSSL_VerifyInline(verifySig, sigSz, &out, key); + if (ret > 0) { + if (ret != (int)plainSz || !out || + XMEMCMP(plain, out, plainSz) != 0) { + WOLFSSL_MSG("RSA Signature verification failed"); + ret = RSA_SIGN_FAULT; + } else { + ret = 0; /* RSA reset */ + } } } @@ -18216,7 +18288,7 @@ int SendCertificateVerify(WOLFSSL* ssl) ret = RsaSign(ssl, ssl->buffers.sig.buffer, ssl->buffers.sig.length, args->verify + args->extraSz + VERIFY_HEADER, &args->sigSz, - key, + rsa_sa_algo, no_mac, key, ssl->buffers.key->buffer, ssl->buffers.key->length, #ifdef HAVE_PK_CALLBACKS @@ -18271,7 +18343,7 @@ int SendCertificateVerify(WOLFSSL* ssl) ret = VerifyRsaSign(ssl, args->verifySig, args->sigSz, ssl->buffers.sig.buffer, ssl->buffers.sig.length, - key + rsa_sa_algo, no_mac, key ); } #endif /* !NO_RSA */ @@ -19816,7 +19888,7 @@ static int DoSessionTicket(WOLFSSL* ssl, const byte* input, word32* inOutIdx, ssl->buffers.sig.length, args->output + args->idx, &args->sigSz, - key, + rsa_sa_algo, no_mac, key, ssl->buffers.key->buffer, ssl->buffers.key->length, #ifdef HAVE_PK_CALLBACKS @@ -19872,7 +19944,7 @@ static int DoSessionTicket(WOLFSSL* ssl, const byte* input, word32* inOutIdx, ssl->buffers.sig.length, args->output + args->idx, &args->sigSz, - key, + rsa_sa_algo, no_mac, key, ssl->buffers.key->buffer, ssl->buffers.key->length, #ifdef HAVE_PK_CALLBACKS @@ -19955,7 +20027,7 @@ static int DoSessionTicket(WOLFSSL* ssl, const byte* input, word32* inOutIdx, args->verifySig, args->sigSz, ssl->buffers.sig.buffer, ssl->buffers.sig.length, - key + rsa_sa_algo, no_mac, key ); break; } @@ -20010,7 +20082,7 @@ static int DoSessionTicket(WOLFSSL* ssl, const byte* input, word32* inOutIdx, args->verifySig, args->sigSz, ssl->buffers.sig.buffer, ssl->buffers.sig.length, - key + rsa_sa_algo, no_mac, key ); break; } diff --git a/src/tls.c b/src/tls.c index df8f7c5b5..5b845b629 100755 --- a/src/tls.c +++ b/src/tls.c @@ -4566,7 +4566,7 @@ static word16 TLSX_SignatureAlgorithms_Write(byte* data, byte* output) static int TLSX_SignatureAlgorithms_Parse(WOLFSSL *ssl, byte* input, word16 length) { - int ret = 0; + int i; word16 len; (void)ssl; @@ -4581,9 +4581,13 @@ static int TLSX_SignatureAlgorithms_Parse(WOLFSSL *ssl, byte* input, if (length != OPAQUE16_LEN + len) return BUFFER_ERROR; - /* Ignore for now. */ + ssl->pssAlgo = 0; + for (i = 0; i < len; i += 2) { + if (input[i] == 0x08 && input[i + 1] <= 0x06) + ssl->pssAlgo |= 1 << input[i + 1]; + } - return ret; + return 0; } /* Sets a new SupportedVersions extension into the extension list. diff --git a/src/tls13.c b/src/tls13.c index 85a179162..df5efa840 100644 --- a/src/tls13.c +++ b/src/tls13.c @@ -2882,18 +2882,24 @@ static INLINE void EncodeSigAlg(byte hashAlgo, byte hsType, byte* output) { switch (hsType) { #ifdef HAVE_ECC - case DYNAMIC_TYPE_ECC: + case ecc_dsa_sa_algo: output[0] = hashAlgo; output[1] = ecc_dsa_sa_algo; break; #endif #ifndef NO_RSA - case DYNAMIC_TYPE_RSA: + case rsa_sa_algo: output[0] = hashAlgo; output[1] = rsa_sa_algo; break; -#endif + #ifdef WC_RSA_PSS /* PSS signatures: 0x080[4-6] */ + case rsa_pss_sa_algo: + output[0] = rsa_pss_sa_algo; + output[1] = hashAlgo; + break; + #endif +#endif /* ED25519: 0x0807 */ /* ED448: 0x0808 */ } @@ -2908,6 +2914,7 @@ static INLINE void EncodeSigAlg(byte hashAlgo, byte hsType, byte* output) static INLINE void DecodeSigAlg(byte* input, byte* hashAlgo, byte* hsType) { switch (input[0]) { + #ifdef WC_RSA_PSS case 0x08: /* PSS signatures: 0x080[4-6] */ if (input[1] <= 0x06) { @@ -2915,6 +2922,7 @@ static INLINE void DecodeSigAlg(byte* input, byte* hashAlgo, byte* hsType) *hashAlgo = input[1]; } break; + #endif /* ED25519: 0x0807 */ /* ED448: 0x0808 */ default: @@ -3014,12 +3022,22 @@ static void CreateSigData(WOLFSSL* ssl, byte* sigData, word16* sigDataSz, * returns the length of the encoded signature or negative on error. */ static int CreateRSAEncodedSig(byte* sig, byte* sigData, int sigDataSz, - int hashAlgo) + int sigAlgo, int hashAlgo) { Digest digest; int hashSz = 0; int hashOid = 0; int ret = BAD_FUNC_ARG; + byte* hash; + + (void)sigAlgo; + +#ifdef WC_RSA_PSS + if (sigAlgo == rsa_pss_sa_algo) + hash = sig; + else +#endif + hash = sigData; /* Digest the signature data. */ switch (hashAlgo) { @@ -3029,7 +3047,7 @@ static int CreateRSAEncodedSig(byte* sig, byte* sigData, int sigDataSz, if (ret == 0) { ret = wc_Sha256Update(&digest.sha256, sigData, sigDataSz); if (ret == 0) - ret = wc_Sha256Final(&digest.sha256, sigData); + ret = wc_Sha256Final(&digest.sha256, hash); wc_Sha256Free(&digest.sha256); } hashSz = SHA256_DIGEST_SIZE; @@ -3042,7 +3060,7 @@ static int CreateRSAEncodedSig(byte* sig, byte* sigData, int sigDataSz, if (ret == 0) { ret = wc_Sha384Update(&digest.sha384, sigData, sigDataSz); if (ret == 0) - ret = wc_Sha384Final(&digest.sha384, sigData); + ret = wc_Sha384Final(&digest.sha384, hash); wc_Sha384Free(&digest.sha384); } hashSz = SHA384_DIGEST_SIZE; @@ -3055,7 +3073,7 @@ static int CreateRSAEncodedSig(byte* sig, byte* sigData, int sigDataSz, if (ret == 0) { ret = wc_Sha512Update(&digest.sha512, sigData, sigDataSz); if (ret == 0) - ret = wc_Sha512Final(&digest.sha512, sigData); + ret = wc_Sha512Final(&digest.sha512, hash); wc_Sha512Free(&digest.sha512); } hashSz = SHA512_DIGEST_SIZE; @@ -3067,8 +3085,15 @@ static int CreateRSAEncodedSig(byte* sig, byte* sigData, int sigDataSz, if (ret != 0) return ret; - /* Encode the signature data as per PKCS #1.5 */ - return wc_EncodeSignature(sig, sigData, hashSz, hashOid); +#ifdef WC_RSA_PSS + if (sigAlgo == rsa_pss_sa_algo) + return hashSz; + else +#endif + { + /* Encode the signature data as per PKCS #1.5 */ + return wc_EncodeSignature(sig, hash, hashSz, hashOid); + } } #ifdef HAVE_ECC @@ -3156,7 +3181,39 @@ static int CheckRSASignature(WOLFSSL* ssl, int sigAlgo, int hashAlgo, #endif word32 sigSz; - if (sigAlgo == rsa_sa_algo) { + CreateSigData(ssl, sigData, &sigDataSz, 1); +#ifdef WC_RSA_PSS + if (sigAlgo == rsa_pss_sa_algo) { + int hashType = WC_HASH_TYPE_NONE; + + switch (hashAlgo) { + case sha512_mac: + #ifdef WOLFSSL_SHA512 + hashType = WC_HASH_TYPE_SHA512; + #endif + break; + case sha384_mac: + #ifdef WOLFSSL_SHA384 + hashType = WC_HASH_TYPE_SHA384; + #endif + break; + case sha256_mac: + #ifndef NO_SHA256 + hashType = WC_HASH_TYPE_SHA256; + #endif + break; + } + + ret = sigSz = CreateRSAEncodedSig(sigData, sigData, sigDataSz, + rsa_pss_sa_algo, hashAlgo); + if (ret < 0) + return ret; + + ret = CheckRsaPssPadding(sigData, sigSz, decSig, decSigSz, hashType); + } + else +#endif + { #ifdef WOLFSSL_SMALL_STACK encodedSig = (byte*)XMALLOC(MAX_ENCODED_SIG_SZ, ssl->heap, DYNAMIC_TYPE_TMP_BUFFER); @@ -3166,29 +3223,14 @@ static int CheckRSASignature(WOLFSSL* ssl, int sigAlgo, int hashAlgo, } #endif - CreateSigData(ssl, sigData, &sigDataSz, 1); - sigSz = CreateRSAEncodedSig(encodedSig, sigData, sigDataSz, hashAlgo); + sigSz = CreateRSAEncodedSig(encodedSig, sigData, sigDataSz, + DYNAMIC_TYPE_RSA, hashAlgo); /* Check the encoded and decrypted signature data match. */ if (decSigSz != sigSz || decSig == NULL || XMEMCMP(decSig, encodedSig, sigSz) != 0) { ret = VERIFY_CERT_ERROR; } } - else { - CreateSigData(ssl, sigData, &sigDataSz, 1); - sigSz = CreateECCEncodedSig(sigData, sigDataSz, hashAlgo); - if (decSigSz != sigSz || decSig == NULL) - ret = VERIFY_CERT_ERROR; - else { - decSig -= 2 * decSigSz; - XMEMCPY(decSig, sigData, decSigSz); - decSig -= 8; - XMEMSET(decSig, 0, 8); - CreateECCEncodedSig(decSig, 8 + decSigSz * 2, hashAlgo); - if (XMEMCMP(decSig, decSig + 8 + decSigSz * 2, decSigSz) != 0) - ret = VERIFY_CERT_ERROR; - } - } #ifdef WOLFSSL_SMALL_STACK end: @@ -3465,6 +3507,7 @@ typedef struct Scv13Args { int sendSz; word16 length; + int sigAlgo; byte* sigData; word16 sigDataSz; } Scv13Args; @@ -3570,7 +3613,17 @@ int SendTls13CertificateVerify(WOLFSSL* ssl) goto exit_scv; /* Add signature algorithm. */ - EncodeSigAlg(ssl->suites->hashAlgo, ssl->hsType, args->verify); + if (ssl->hsType == DYNAMIC_TYPE_RSA) { + #ifdef WC_RSA_PSS + if (ssl->pssAlgo | (1 << ssl->suites->hashAlgo)) + args->sigAlgo = rsa_pss_sa_algo; + else + #endif + args->sigAlgo = rsa_sa_algo; + } + else if (ssl->hsType == DYNAMIC_TYPE_ECC) + args->sigAlgo = ecc_dsa_sa_algo; + EncodeSigAlg(ssl->suites->hashAlgo, args->sigAlgo, args->verify); /* Create the data to be signed. */ args->sigData = (byte*)XMALLOC(MAX_SIG_DATA_SZ, ssl->heap, @@ -3591,9 +3644,8 @@ int SendTls13CertificateVerify(WOLFSSL* ssl) ERROR_OUT(MEMORY_E, exit_scv); } - /* Digest the signature data and encode. Used in verify too. */ ret = CreateRSAEncodedSig(sig->buffer, args->sigData, - args->sigDataSz, ssl->suites->hashAlgo); + args->sigDataSz, args->sigAlgo, ssl->suites->hashAlgo); if (ret < 0) goto exit_scv; sig->length = ret; @@ -3645,6 +3697,7 @@ int SendTls13CertificateVerify(WOLFSSL* ssl) ret = RsaSign(ssl, sig->buffer, sig->length, args->verify + HASH_SIG_SIZE + VERIFY_HEADER, &args->sigLen, + args->sigAlgo, ssl->suites->hashAlgo, (RsaKey*)ssl->hsKey, ssl->buffers.key->buffer, ssl->buffers.key->length, #ifdef HAVE_PK_CALLBACKS @@ -3690,7 +3743,8 @@ int SendTls13CertificateVerify(WOLFSSL* ssl) /* check for signature faults */ ret = VerifyRsaSign(ssl, args->verifySig, args->sigLen, - sig->buffer, sig->length, (RsaKey*)ssl->hsKey); + sig->buffer, sig->length, args->sigAlgo, + ssl->suites->hashAlgo, (RsaKey*)ssl->hsKey); } #endif /* !NO_RSA */ diff --git a/wolfcrypt/src/rsa.c b/wolfcrypt/src/rsa.c index 7f9da78b7..f38916427 100755 --- a/wolfcrypt/src/rsa.c +++ b/wolfcrypt/src/rsa.c @@ -1587,6 +1587,15 @@ int wc_RsaSSL_Sign(const byte* in, word32 inLen, byte* out, word32 outLen, WC_HASH_TYPE_NONE, WC_MGF1NONE, NULL, 0, rng); } +#ifdef WC_RSA_PSS +int wc_RsaPSS_Sign(const byte* in, word32 inLen, byte* out, word32 outLen, + enum wc_HashType hash, int mgf, RsaKey* key, WC_RNG* rng) +{ + return RsaPublicEncryptEx(in, inLen, out, outLen, key, + RSA_PRIVATE_ENCRYPT, RSA_BLOCK_TYPE_1, WC_RSA_PSS_PAD, + hash, mgf, NULL, 0, rng); +} +#endif int wc_RsaEncryptSize(RsaKey* key) { diff --git a/wolfssl/internal.h b/wolfssl/internal.h index ea56d93b6..078c2f98f 100755 --- a/wolfssl/internal.h +++ b/wolfssl/internal.h @@ -3098,7 +3098,8 @@ struct WOLFSSL { byte user_set_QSHSchemes; #endif #ifdef WOLFSSL_TLS13 - word16 namedGroup; + word16 namedGroup; + byte pssAlgo; #endif #ifdef HAVE_NTRU word16 peerNtruKeyLen; @@ -3411,12 +3412,17 @@ WOLFSSL_LOCAL void ShrinkOutputBuffer(WOLFSSL* ssl); WOLFSSL_LOCAL int VerifyClientSuite(WOLFSSL* ssl); #ifndef NO_CERTS #ifndef NO_RSA + WOLFSSL_LOCAL int CheckRsaPssPadding(const byte* plain, word32 plainSz, + byte* out, word32 sigSz, + enum wc_HashType hashType); WOLFSSL_LOCAL int VerifyRsaSign(WOLFSSL* ssl, byte* verifySig, word32 sigSz, const byte* plain, word32 plainSz, + int sigAlgo, int hashAlgo, RsaKey* key); - WOLFSSL_LOCAL int RsaSign(WOLFSSL* ssl, const byte* in, word32 inSz, byte* out, - word32* outSz, RsaKey* key, const byte* keyBuf, word32 keySz, void* ctx); + WOLFSSL_LOCAL int RsaSign(WOLFSSL* ssl, const byte* in, word32 inSz, + byte* out, word32* outSz, int sigAlgo, int hashAlgo, RsaKey* key, + const byte* keyBuf, word32 keySz, void* ctx); WOLFSSL_LOCAL int RsaVerify(WOLFSSL* ssl, byte* in, word32 inSz, byte** out, int sigAlgo, int hashAlgo, RsaKey* key, const byte* keyBuf, word32 keySz, void* ctx); diff --git a/wolfssl/wolfcrypt/rsa.h b/wolfssl/wolfcrypt/rsa.h index 6905d1dd2..c1b86a893 100644 --- a/wolfssl/wolfcrypt/rsa.h +++ b/wolfssl/wolfcrypt/rsa.h @@ -122,6 +122,9 @@ WOLFSSL_API int wc_RsaPrivateDecrypt(const byte* in, word32 inLen, byte* out, word32 outLen, RsaKey* key); WOLFSSL_API int wc_RsaSSL_Sign(const byte* in, word32 inLen, byte* out, word32 outLen, RsaKey* key, WC_RNG* rng); +WOLFSSL_API int wc_RsaPSS_Sign(const byte* in, word32 inLen, byte* out, + word32 outLen, enum wc_HashType hash, int mgf, + RsaKey* key, WC_RNG* rng); WOLFSSL_API int wc_RsaSSL_VerifyInline(byte* in, word32 inLen, byte** out, RsaKey* key); WOLFSSL_API int wc_RsaSSL_Verify(const byte* in, word32 inLen, byte* out,