diff --git a/src/tls13.c b/src/tls13.c index 866d8b2f8..c37e5922f 100644 --- a/src/tls13.c +++ b/src/tls13.c @@ -4145,7 +4145,7 @@ static int CreateECCEncodedSig(byte* sigData, int sigDataSz, int hashAlgo) * based on the digest of the signature data. * * ssl The SSL/TLS object. - * hashAlgo The signature algorithm used to generate signature. + * sigAlgo The signature algorithm used to generate signature. * hashAlgo The hash algorithm used to generate signature. * decSig The decrypted signature. * decSigSz The size of the decrypted signature. @@ -4170,7 +4170,7 @@ static int CheckRSASignature(WOLFSSL* ssl, int sigAlgo, int hashAlgo, if (ret < 0) return ret; - /* PSS signature can be done in-pace */ + /* PSS signature can be done in-place */ ret = CreateRSAEncodedSig(sigData, sigData, sigDataSz, sigAlgo, hashAlgo); if (ret < 0) diff --git a/wolfcrypt/src/error.c b/wolfcrypt/src/error.c index 50f529e10..280f14d90 100644 --- a/wolfcrypt/src/error.c +++ b/wolfcrypt/src/error.c @@ -440,6 +440,9 @@ const char* wc_GetErrorString(int error) case WC_HW_WAIT_E: return "Hardware waiting on resource"; + case PSS_SALTLEN_E: + return "PSS - Length of salt is too big for hash algorithm"; + default: return "unknown error number"; diff --git a/wolfcrypt/src/rsa.c b/wolfcrypt/src/rsa.c index 0e803138e..178b41e95 100644 --- a/wolfcrypt/src/rsa.c +++ b/wolfcrypt/src/rsa.c @@ -702,10 +702,23 @@ static int RsaPad_OAEP(const byte* input, word32 inputLen, byte* pkcsBlock, /* 0x00 .. 0x00 0x01 | Salt | Gen Hash | 0xbc * XOR MGF over all bytes down to end of Salt * Gen Hash = HASH(8 * 0x00 | Message Hash | Salt) + * + * input Digest of the message. + * inputLen Length of digest. + * pkcsBlock Buffer to write to. + * pkcsBlockLen Length of buffer to write to. + * rng Random number generator (for salt). + * htype Hash function to use. + * mgf Mask generation function. + * saltLen Length of salt to put in padding. + * bits Length of key in bits. + * heap Used for dynamic memory allocation. + * returns 0 on success, PSS_SALTLEN_E when the salt length is invalid + * and other negative values on error. */ static int RsaPad_PSS(const byte* input, word32 inputLen, byte* pkcsBlock, word32 pkcsBlockLen, WC_RNG* rng, enum wc_HashType hType, int mgf, - int bits, void* heap) + int saltLen, int bits, void* heap) { int ret; int hLen, i; @@ -718,15 +731,22 @@ static int RsaPad_PSS(const byte* input, word32 inputLen, byte* pkcsBlock, if (hLen < 0) return hLen; + if (saltLen == -1) + saltLen = hLen; + else if (saltLen > hLen || saltLen < -1) + return PSS_SALTLEN_E; + if ((int)pkcsBlockLen - hLen - 1 < saltLen + 2) + return PSS_SALTLEN_E; + s = m = pkcsBlock; - XMEMSET(m, 0, 8); - m += 8; + XMEMSET(m, 0, RSA_PSS_PAD_SZ); + m += RSA_PSS_PAD_SZ; XMEMCPY(m, input, inputLen); m += inputLen; - if ((ret = wc_RNG_GenerateBlock(rng, salt, hLen)) != 0) + if ((ret = wc_RNG_GenerateBlock(rng, salt, saltLen)) != 0) return ret; - XMEMCPY(m, salt, hLen); - m += hLen; + XMEMCPY(m, salt, saltLen); + m += saltLen; h = pkcsBlock + pkcsBlockLen - 1 - hLen; if ((ret = wc_Hash(hType, s, (word32)(m - s), h, hLen)) != 0) @@ -738,9 +758,9 @@ static int RsaPad_PSS(const byte* input, word32 inputLen, byte* pkcsBlock, return ret; pkcsBlock[0] &= (1 << ((bits - 1) & 0x7)) - 1; - m = pkcsBlock + pkcsBlockLen - 1 - hLen - hLen - 1; + m = pkcsBlock + pkcsBlockLen - 1 - saltLen - hLen - 1; *(m++) ^= 0x01; - for (i = 0; i < hLen; i++) + for (i = 0; i < saltLen; i++) m[i] ^= salt[i]; return 0; @@ -799,8 +819,8 @@ static int RsaPad(const byte* input, word32 inputLen, byte* pkcsBlock, /* helper function to direct which padding is used */ static int wc_RsaPad_ex(const byte* input, word32 inputLen, byte* pkcsBlock, word32 pkcsBlockLen, byte padValue, WC_RNG* rng, int padType, - enum wc_HashType hType, int mgf, byte* optLabel, word32 labelLen, int bits, - void* heap) + enum wc_HashType hType, int mgf, byte* optLabel, word32 labelLen, + int saltLen, int bits, void* heap) { int ret; @@ -824,7 +844,7 @@ static int wc_RsaPad_ex(const byte* input, word32 inputLen, byte* pkcsBlock, case WC_RSA_PSS_PAD: WOLFSSL_MSG("wolfSSL Using RSA PSS padding"); ret = RsaPad_PSS(input, inputLen, pkcsBlock, pkcsBlockLen, rng, - hType, mgf, bits, heap); + hType, mgf, saltLen, bits, heap); break; #endif @@ -838,6 +858,7 @@ static int wc_RsaPad_ex(const byte* input, word32 inputLen, byte* pkcsBlock, (void)mgf; (void)optLabel; (void)labelLen; + (void)saltLen; (void)bits; (void)heap; @@ -934,9 +955,23 @@ static int RsaUnPad_OAEP(byte *pkcsBlock, unsigned int pkcsBlockLen, #endif /* WC_NO_RSA_OAEP */ #ifdef WC_RSA_PSS +/* 0x00 .. 0x00 0x01 | Salt | Gen Hash | 0xbc + * MGF over all bytes down to end of Salt + * + * pkcsBlock Buffer holding decrypted data. + * pkcsBlockLen Length of buffer. + * htype Hash function to use. + * mgf Mask generation function. + * saltLen Length of salt to put in padding. + * bits Length of key in bits. + * heap Used for dynamic memory allocation. + * returns 0 on success, PSS_SALTLEN_E when the salt length is invalid, + * BAD_PADDING_E when the padding is not valid, MEMORY_E when allocation fails + * and other negative values on error. + */ static int RsaUnPad_PSS(byte *pkcsBlock, unsigned int pkcsBlockLen, byte **output, enum wc_HashType hType, int mgf, - int bits, void* heap) + int saltLen, int bits, void* heap) { int ret; byte* tmp; @@ -946,15 +981,21 @@ static int RsaUnPad_PSS(byte *pkcsBlock, unsigned int pkcsBlockLen, if (hLen < 0) return hLen; + if (saltLen == -1) + saltLen = hLen; + else if (saltLen > hLen || saltLen < -1) + return PSS_SALTLEN_E; + if ((int)pkcsBlockLen - hLen - 1 < saltLen + 2) + return PSS_SALTLEN_E; + if (pkcsBlock[pkcsBlockLen - 1] != 0xbc) { WOLFSSL_MSG("RsaUnPad_PSS: Padding Error 0xBC"); return BAD_PADDING_E; } tmp = (byte*)XMALLOC(pkcsBlockLen, heap, DYNAMIC_TYPE_RSA_BUFFER); - if (tmp == NULL) { + if (tmp == NULL) return MEMORY_E; - } if ((ret = RsaMGF(mgf, pkcsBlock + pkcsBlockLen - 1 - hLen, hLen, tmp, pkcsBlockLen - 1 - hLen, heap)) != 0) { @@ -963,7 +1004,7 @@ static int RsaUnPad_PSS(byte *pkcsBlock, unsigned int pkcsBlockLen, } tmp[0] &= (1 << ((bits - 1) & 0x7)) - 1; - for (i = 0; i < (int)(pkcsBlockLen - 1 - hLen - hLen - 1); i++) { + for (i = 0; i < (int)(pkcsBlockLen - 1 - saltLen - hLen - 1); i++) { if (tmp[i] != pkcsBlock[i]) { XFREE(tmp, heap, DYNAMIC_TYPE_RSA_BUFFER); WOLFSSL_MSG("RsaUnPad_PSS: Padding Error Match"); @@ -980,11 +1021,11 @@ static int RsaUnPad_PSS(byte *pkcsBlock, unsigned int pkcsBlockLen, XFREE(tmp, heap, DYNAMIC_TYPE_RSA_BUFFER); - i = pkcsBlockLen - (RSA_PSS_PAD_SZ + 3 * hLen + 1); + i = pkcsBlockLen - (RSA_PSS_PAD_SZ + saltLen + 2 * hLen + 1); XMEMSET(pkcsBlock + i, 0, RSA_PSS_PAD_SZ); *output = pkcsBlock + i; - return RSA_PSS_PAD_SZ + 3 * hLen; + return RSA_PSS_PAD_SZ + saltLen + 2 * hLen; } #endif @@ -1038,8 +1079,8 @@ static int RsaUnPad(const byte *pkcsBlock, unsigned int pkcsBlockLen, /* helper function to direct unpadding */ static int wc_RsaUnPad_ex(byte* pkcsBlock, word32 pkcsBlockLen, byte** out, byte padValue, int padType, enum wc_HashType hType, - int mgf, byte* optLabel, word32 labelLen, int bits, - void* heap) + int mgf, byte* optLabel, word32 labelLen, int saltLen, + int bits, void* heap) { int ret; @@ -1061,7 +1102,7 @@ static int wc_RsaUnPad_ex(byte* pkcsBlock, word32 pkcsBlockLen, byte** out, case WC_RSA_PSS_PAD: WOLFSSL_MSG("wolfSSL Using RSA PSS un-padding"); ret = RsaUnPad_PSS((byte*)pkcsBlock, pkcsBlockLen, out, hType, mgf, - bits, heap); + saltLen, bits, heap); break; #endif @@ -1075,6 +1116,7 @@ static int wc_RsaUnPad_ex(byte* pkcsBlock, word32 pkcsBlockLen, byte** out, (void)mgf; (void)optLabel; (void)labelLen; + (void)saltLen; (void)bits; (void)heap; @@ -1451,12 +1493,15 @@ int wc_RsaFunction(const byte* in, word32 inLen, byte* out, hash : type of hash algorithm to use found in wolfssl/wolfcrypt/hash.h mgf : type of mask generation function to use label : optional label - labelSz : size of optional label buffer */ + labelSz : size of optional label buffer + saltLen : Length of salt used in PSS + rng : random number generator */ static int RsaPublicEncryptEx(const byte* in, word32 inLen, byte* out, word32 outLen, RsaKey* key, int rsa_type, byte pad_value, int pad_type, enum wc_HashType hash, int mgf, - byte* label, word32 labelSz, WC_RNG* rng) + byte* label, word32 labelSz, int saltLen, + WC_RNG* rng) { int ret, sz; @@ -1502,7 +1547,7 @@ static int RsaPublicEncryptEx(const byte* in, word32 inLen, byte* out, #endif ret = wc_RsaPad_ex(in, inLen, out, sz, pad_value, rng, pad_type, hash, - mgf, label, labelSz, mp_count_bits(&key->n), + mgf, label, labelSz, saltLen, mp_count_bits(&key->n), key->heap); if (ret < 0) { break; @@ -1561,12 +1606,15 @@ static int RsaPublicEncryptEx(const byte* in, word32 inLen, byte* out, hash : type of hash algorithm to use found in wolfssl/wolfcrypt/hash.h mgf : type of mask generation function to use label : optional label - labelSz : size of optional label buffer */ + labelSz : size of optional label buffer + saltLen : Length of salt used in PSS + rng : random number generator */ static int RsaPrivateDecryptEx(byte* in, word32 inLen, byte* out, word32 outLen, byte** outPtr, RsaKey* key, int rsa_type, byte pad_value, int pad_type, enum wc_HashType hash, int mgf, - byte* label, word32 labelSz, WC_RNG* rng) + byte* label, word32 labelSz, int saltLen, + WC_RNG* rng) { int ret = RSA_WRONG_TYPE_E; @@ -1636,8 +1684,8 @@ static int RsaPrivateDecryptEx(byte* in, word32 inLen, byte* out, { byte* pad = NULL; ret = wc_RsaUnPad_ex(key->data, key->dataLen, &pad, pad_value, pad_type, - hash, mgf, label, labelSz, mp_count_bits(&key->n), - key->heap); + hash, mgf, label, labelSz, saltLen, + mp_count_bits(&key->n), key->heap); if (ret > 0 && ret <= (int)outLen && pad != NULL) { /* only copy output if not inline */ if (outPtr == NULL) { @@ -1696,7 +1744,7 @@ int wc_RsaPublicEncrypt(const byte* in, word32 inLen, byte* out, word32 outLen, { return RsaPublicEncryptEx(in, inLen, out, outLen, key, RSA_PUBLIC_ENCRYPT, RSA_BLOCK_TYPE_2, WC_RSA_PKCSV15_PAD, - WC_HASH_TYPE_NONE, WC_MGF1NONE, NULL, 0, rng); + WC_HASH_TYPE_NONE, WC_MGF1NONE, NULL, 0, 0, rng); } @@ -1707,7 +1755,7 @@ int wc_RsaPublicEncrypt_ex(const byte* in, word32 inLen, byte* out, word32 labelSz) { return RsaPublicEncryptEx(in, inLen, out, outLen, key, RSA_PUBLIC_ENCRYPT, - RSA_BLOCK_TYPE_2, type, hash, mgf, label, labelSz, rng); + RSA_BLOCK_TYPE_2, type, hash, mgf, label, labelSz, 0, rng); } #endif /* WC_NO_RSA_OAEP */ @@ -1720,7 +1768,7 @@ int wc_RsaPrivateDecryptInline(byte* in, word32 inLen, byte** out, RsaKey* key) #endif return RsaPrivateDecryptEx(in, inLen, in, inLen, out, key, RSA_PRIVATE_DECRYPT, RSA_BLOCK_TYPE_2, WC_RSA_PKCSV15_PAD, - WC_HASH_TYPE_NONE, WC_MGF1NONE, NULL, 0, rng); + WC_HASH_TYPE_NONE, WC_MGF1NONE, NULL, 0, 0, rng); } @@ -1735,7 +1783,7 @@ int wc_RsaPrivateDecryptInline_ex(byte* in, word32 inLen, byte** out, #endif return RsaPrivateDecryptEx(in, inLen, in, inLen, out, key, RSA_PRIVATE_DECRYPT, RSA_BLOCK_TYPE_2, type, hash, - mgf, label, labelSz, rng); + mgf, label, labelSz, 0, rng); } #endif /* WC_NO_RSA_OAEP */ @@ -1749,7 +1797,7 @@ int wc_RsaPrivateDecrypt(const byte* in, word32 inLen, byte* out, #endif return RsaPrivateDecryptEx((byte*)in, inLen, out, outLen, NULL, key, RSA_PRIVATE_DECRYPT, RSA_BLOCK_TYPE_2, WC_RSA_PKCSV15_PAD, - WC_HASH_TYPE_NONE, WC_MGF1NONE, NULL, 0, rng); + WC_HASH_TYPE_NONE, WC_MGF1NONE, NULL, 0, 0, rng); } #ifndef WC_NO_RSA_OAEP @@ -1764,7 +1812,7 @@ int wc_RsaPrivateDecrypt_ex(const byte* in, word32 inLen, byte* out, #endif return RsaPrivateDecryptEx((byte*)in, inLen, out, outLen, NULL, key, RSA_PRIVATE_DECRYPT, RSA_BLOCK_TYPE_2, type, hash, mgf, label, - labelSz, rng); + labelSz, 0, rng); } #endif /* WC_NO_RSA_OAEP */ @@ -1777,7 +1825,7 @@ int wc_RsaSSL_VerifyInline(byte* in, word32 inLen, byte** out, RsaKey* key) #endif return RsaPrivateDecryptEx(in, inLen, in, inLen, out, key, RSA_PUBLIC_DECRYPT, RSA_BLOCK_TYPE_1, WC_RSA_PKCSV15_PAD, - WC_HASH_TYPE_NONE, WC_MGF1NONE, NULL, 0, rng); + WC_HASH_TYPE_NONE, WC_MGF1NONE, NULL, 0, 0, rng); } int wc_RsaSSL_Verify(const byte* in, word32 inLen, byte* out, word32 outLen, @@ -1795,12 +1843,44 @@ int wc_RsaSSL_Verify(const byte* in, word32 inLen, byte* out, word32 outLen, #endif return RsaPrivateDecryptEx((byte*)in, inLen, out, outLen, NULL, key, RSA_PUBLIC_DECRYPT, RSA_BLOCK_TYPE_1, WC_RSA_PKCSV15_PAD, - WC_HASH_TYPE_NONE, WC_MGF1NONE, NULL, 0, rng); + WC_HASH_TYPE_NONE, WC_MGF1NONE, NULL, 0, 0, rng); } #ifdef WC_RSA_PSS +/* Verify the message signed with RSA-PSS. + * The input buffer is reused for the ouput buffer. + * Salt length is equal to hash length. + * + * in Buffer holding encrypted data. + * inLen Length of data in buffer. + * out Pointer to address containing the PSS data. + * hash Hash algorithm. + * mgf Mask generation function. + * key Public RSA key. + * returns the length of the PSS data on success and negative indicates failure. + */ int wc_RsaPSS_VerifyInline(byte* in, word32 inLen, byte** out, enum wc_HashType hash, int mgf, RsaKey* key) +{ + return wc_RsaPSS_VerifyInline_ex(in, inLen, out, hash, mgf, -1, key); +} + +/* Verify the message signed with RSA-PSS. + * The input buffer is reused for the ouput buffer. + * + * in Buffer holding encrypted data. + * inLen Length of data in buffer. + * out Pointer to address containing the PSS data. + * hash Hash algorithm. + * mgf Mask generation function. + * key Public RSA key. + * saltLen Length of salt used. -1 indicates salt length is the same as the + * hash length. + * returns the length of the PSS data on success and negative indicates failure. + */ +int wc_RsaPSS_VerifyInline_ex(byte* in, word32 inLen, byte** out, + enum wc_HashType hash, int mgf, int saltLen, + RsaKey* key) { WC_RNG* rng = NULL; #ifdef WC_RSA_BLINDING @@ -1808,32 +1888,115 @@ int wc_RsaPSS_VerifyInline(byte* in, word32 inLen, byte** out, #endif return RsaPrivateDecryptEx(in, inLen, in, inLen, out, key, RSA_PUBLIC_DECRYPT, RSA_BLOCK_TYPE_1, WC_RSA_PSS_PAD, - hash, mgf, NULL, 0, rng); + hash, mgf, NULL, 0, saltLen, rng); } -/* Sig = 8 * 0x00 | Space for Message Hash | Salt | Exp Hash - * Exp Hash = HASH(8 * 0x00 | Message Hash | Salt) +/* Verify the message signed with RSA-PSS. + * Salt length is equal to hash length. + * + * in Buffer holding encrypted data. + * inLen Length of data in buffer. + * out Pointer to address containing the PSS data. + * hash Hash algorithm. + * mgf Mask generation function. + * key Public RSA key. + * returns the length of the PSS data on success and negative indicates failure. + */ +int wc_RsaPSS_Verify(byte* in, word32 inLen, byte* out, word32 outLen, + enum wc_HashType hash, int mgf, RsaKey* key) +{ + return wc_RsaPSS_Verify_ex(in, inLen, out, outLen, hash, mgf, -1, key); +} + +/* Verify the message signed with RSA-PSS. + * + * in Buffer holding encrypted data. + * inLen Length of data in buffer. + * out Pointer to address containing the PSS data. + * hash Hash algorithm. + * mgf Mask generation function. + * key Public RSA key. + * saltLen Length of salt used. -1 indicates salt length is the same as the + * hash length. + * returns the length of the PSS data on success and negative indicates failure. + */ +int wc_RsaPSS_Verify_ex(byte* in, word32 inLen, byte* out, word32 outLen, + enum wc_HashType hash, int mgf, int saltLen, + RsaKey* key) +{ + WC_RNG* rng = NULL; +#ifdef WC_RSA_BLINDING + rng = key->rng; +#endif + return RsaPrivateDecryptEx(in, inLen, out, outLen, NULL, key, + RSA_PUBLIC_DECRYPT, RSA_BLOCK_TYPE_1, WC_RSA_PSS_PAD, + hash, mgf, NULL, 0, saltLen, rng); +} + + +/* Checks the PSS data to ensure that the signature matches. + * Salt length is equal to hash length. + * + * in Hash of the data that is being verified. + * inSz Length of hash. + * sig Buffer holding PSS data. + * sigSz Size of PSS data. + * hashType Hash algorithm. + * returns BAD_PADDING_E when the PSS data is invalid, BAD_FUNC_ARG when + * NULL is passed in to in or sig or inSz is not the same as the hash + * algorithm length and 0 on success. */ int wc_RsaPSS_CheckPadding(const byte* in, word32 inSz, byte* sig, word32 sigSz, enum wc_HashType hashType) { - int ret; + return wc_RsaPSS_CheckPadding_ex(in, inSz, sig, sigSz, hashType, inSz); +} + +/* Checks the PSS data to ensure that the signature matches. + * + * in Hash of the data that is being verified. + * inSz Length of hash. + * sig Buffer holding PSS data. + * sigSz Size of PSS data. + * hashType Hash algorithm. + * saltLen Length of salt used. -1 indicates salt length is the same as the + * hash length. + * returns BAD_PADDING_E when the PSS data is invalid, BAD_FUNC_ARG when + * NULL is passed in to in or sig or inSz is not the same as the hash + * algorithm length and 0 on success. + */ +int wc_RsaPSS_CheckPadding_ex(const byte* in, word32 inSz, byte* sig, + word32 sigSz, enum wc_HashType hashType, + int saltLen) +{ + int ret = 0; if (in == NULL || sig == NULL || - inSz != (word32)wc_HashGetDigestSize(hashType) || - sigSz != RSA_PSS_PAD_SZ + inSz * 3) + inSz != (word32)wc_HashGetDigestSize(hashType)) ret = BAD_FUNC_ARG; - else { + + if (ret == 0) { + if (saltLen == -1) + saltLen = inSz; + else if (saltLen < -1 || (word32)saltLen > inSz) + ret = PSS_SALTLEN_E; + } + /* Sig = 8 * 0x00 | Space for Message Hash | Salt | Exp Hash */ + if (ret == 0) { + if (sigSz != RSA_PSS_PAD_SZ + inSz + (word32)saltLen + inSz) + ret = BAD_PADDING_E; + } + /* Exp Hash = HASH(8 * 0x00 | Message Hash | Salt) */ + if (ret == 0) { XMEMCPY(sig + RSA_PSS_PAD_SZ, in, inSz); - ret = wc_Hash(hashType, sig, RSA_PSS_PAD_SZ + inSz * 2, sig, inSz); - if (ret != 0) - return ret; - if (XMEMCMP(sig, sig + RSA_PSS_PAD_SZ + inSz * 2, inSz) != 0) { + ret = wc_Hash(hashType, sig, RSA_PSS_PAD_SZ + inSz + saltLen, sig, + inSz); + } + if (ret == 0) { + if (XMEMCMP(sig, sig + RSA_PSS_PAD_SZ + inSz + saltLen, inSz) != 0) { WOLFSSL_MSG("RsaPSS_CheckPadding: Padding Error"); ret = BAD_PADDING_E; } - else - ret = 0; } return ret; @@ -1845,16 +2008,52 @@ int wc_RsaSSL_Sign(const byte* in, word32 inLen, byte* out, word32 outLen, { return RsaPublicEncryptEx(in, inLen, out, outLen, key, RSA_PRIVATE_ENCRYPT, RSA_BLOCK_TYPE_1, WC_RSA_PKCSV15_PAD, - WC_HASH_TYPE_NONE, WC_MGF1NONE, NULL, 0, rng); + WC_HASH_TYPE_NONE, WC_MGF1NONE, NULL, 0, 0, rng); } #ifdef WC_RSA_PSS +/* Sign the hash of a message using RSA-PSS. + * Salt length is equal to hash length. + * + * in Buffer holding hash of message. + * inLen Length of data in buffer (hash length). + * out Buffer to write encrypted signature into. + * outLen Size of buffer to write to. + * hash Hash algorithm. + * mgf Mask generation function. + * key Public RSA key. + * rng Random number generator. + * returns the length of the encrypted signature on success, a negative value + * indicates failure. + */ int wc_RsaPSS_Sign(const byte* in, word32 inLen, byte* out, word32 outLen, enum wc_HashType hash, int mgf, RsaKey* key, WC_RNG* rng) +{ + return wc_RsaPSS_Sign_ex(in, inLen, out, outLen, hash, mgf, -1, key, rng); +} + +/* Sign the hash of a message using RSA-PSS. + * + * in Buffer holding hash of message. + * inLen Length of data in buffer (hash length). + * out Buffer to write encrypted signature into. + * outLen Size of buffer to write to. + * hash Hash algorithm. + * mgf Mask generation function. + * saltLen Length of salt used. -1 indicates salt length is the same as the + * hash length. + * key Public RSA key. + * rng Random number generator. + * returns the length of the encrypted signature on success, a negative value + * indicates failure. + */ +int wc_RsaPSS_Sign_ex(const byte* in, word32 inLen, byte* out, word32 outLen, + enum wc_HashType hash, int mgf, int saltLen, RsaKey* key, + WC_RNG* rng) { return RsaPublicEncryptEx(in, inLen, out, outLen, key, RSA_PRIVATE_ENCRYPT, RSA_BLOCK_TYPE_1, WC_RSA_PSS_PAD, - hash, mgf, NULL, 0, rng); + hash, mgf, NULL, 0, saltLen, rng); } #endif diff --git a/wolfcrypt/test/test.c b/wolfcrypt/test/test.c index bcf9dcd2d..c1eb272ab 100644 --- a/wolfcrypt/test/test.c +++ b/wolfcrypt/test/test.c @@ -7653,6 +7653,175 @@ done: #endif #define RSA_TEST_BYTES 256 + +#ifdef WC_RSA_PSS +static int rsa_pss_test(WC_RNG* rng, RsaKey* key) +{ + byte digest[WC_MAX_DIGEST_SIZE]; + int ret = 0; + const char* inStr = "Everyone gets Friday off."; + word32 inLen = (word32)XSTRLEN((char*)inStr); + word32 outSz; + word32 plainSz; + word32 digestSz; + int i, j; +#ifdef RSA_PSS_TEST_WRONG_PARAMS + int k, l; +#endif + byte* plain; + int mgf[] = { +#ifndef NO_SHA + WC_MGF1SHA1, +#endif +#ifdef WOLFSSL_SHA224 + WC_MGF1SHA224, +#endif + WC_MGF1SHA256, +#ifdef WOLFSSL_SHA384 + WC_MGF1SHA384, +#endif +#ifdef WOLFSSL_SHA512 + WC_MGF1SHA512 +#endif + }; + enum wc_HashType hash[] = { +#ifndef NO_SHA + WC_HASH_TYPE_SHA, +#endif +#ifdef WOLFSSL_SHA224 + WC_HASH_TYPE_SHA224, +#endif + WC_HASH_TYPE_SHA256, +#ifdef WOLFSSL_SHA384 + WC_HASH_TYPE_SHA384, +#endif +#ifdef WOLFSSL_SHA512 + WC_HASH_TYPE_SHA512, +#endif + }; + + DECLARE_VAR_INIT(in, byte, inLen, inStr, HEAP_HINT); + DECLARE_VAR(out, byte, RSA_TEST_BYTES, HEAP_HINT); + DECLARE_VAR(sig, byte, RSA_TEST_BYTES, HEAP_HINT); + + /* Test all combinations of hash and MGF. */ + for (j = 0; j < (int)(sizeof(hash)/sizeof(*hash)); j++) { + /* Calculate hash of message. */ + ret = wc_Hash(hash[j], in, inLen, digest, sizeof(digest)); + if (ret != 0) + ERROR_OUT(-5450, exit_rsa_pss); + digestSz = wc_HashGetDigestSize(hash[j]); + + for (i = 0; i < (int)(sizeof(mgf)/sizeof(*mgf)); i++) { + outSz = RSA_TEST_BYTES; + ret = wc_RsaPSS_Sign_ex(digest, digestSz, out, outSz, hash[j], + mgf[i], -1, key, rng); + if (ret <= 0) + ERROR_OUT(-5451, exit_rsa_pss); + outSz = ret; + + XMEMCPY(sig, out, outSz); + plain = NULL; + ret = wc_RsaPSS_VerifyInline_ex(sig, outSz, &plain, hash[j], + mgf[i], -1, key); + if (ret <= 0) + ERROR_OUT(-5452, exit_rsa_pss); + plainSz = ret; + + ret = wc_RsaPSS_CheckPadding(digest, digestSz, plain, plainSz, + hash[j]); + if (ret != 0) + ERROR_OUT(-5453, exit_rsa_pss); + +#ifdef RSA_PSS_TEST_WRONG_PARAMS + for (k = 0; k < (int)(sizeof(mgf)/sizeof(*mgf)); k++) { + for (l = 0; l < (int)(sizeof(hash)/sizeof(*hash)); l++) { + if (i == k && j == l) + continue; + + XMEMCPY(sig, out, outSz); + ret = wc_RsaPSS_VerifyInline_ex(sig, outSz, (byte**)&plain, + hash[l], mgf[k], -1, key); + if (ret >= 0) + ERROR_OUT(-5454, exit_rsa_pss); + } + } +#endif + } + } + + /* Test that a salt length of zero works. */ + digestSz = wc_HashGetDigestSize(hash[0]); + outSz = RSA_TEST_BYTES; + ret = wc_RsaPSS_Sign_ex(digest, digestSz, out, outSz, hash[0], mgf[0], 0, + key, rng); + if (ret <= 0) + ERROR_OUT(-5460, exit_rsa_pss); + outSz = ret; + + ret = wc_RsaPSS_Verify_ex(out, outSz, sig, outSz, hash[0], mgf[0], 0, + key); + if (ret <= 0) + ERROR_OUT(-5461, exit_rsa_pss); + plainSz = ret; + + ret = wc_RsaPSS_CheckPadding_ex(digest, digestSz, sig, plainSz, hash[0], + 0); + if (ret != 0) + ERROR_OUT(-5462, exit_rsa_pss); + + XMEMCPY(sig, out, outSz); + plain = NULL; + ret = wc_RsaPSS_VerifyInline_ex(sig, outSz, &plain, hash[0], mgf[0], 0, + key); + if (ret <= 0) + ERROR_OUT(-5463, exit_rsa_pss); + plainSz = ret; + + ret = wc_RsaPSS_CheckPadding_ex(digest, digestSz, plain, plainSz, hash[0], + 0); + if (ret != 0) + ERROR_OUT(-5464, exit_rsa_pss); + + /* Test bad salt lengths in various APIs. */ + digestSz = wc_HashGetDigestSize(hash[0]); + outSz = RSA_TEST_BYTES; + ret = wc_RsaPSS_Sign_ex(digest, digestSz, out, outSz, hash[0], mgf[0], -2, + key, rng); + if (ret != PSS_SALTLEN_E) + ERROR_OUT(-5470, exit_rsa_pss); + ret = wc_RsaPSS_Sign_ex(digest, digestSz, out, outSz, hash[0], mgf[0], + digestSz + 1, key, rng); + if (ret != PSS_SALTLEN_E) + ERROR_OUT(-5471, exit_rsa_pss); + + ret = wc_RsaPSS_VerifyInline_ex(sig, outSz, &plain, hash[0], mgf[0], -2, + key); + if (ret != PSS_SALTLEN_E) + ERROR_OUT(-5472, exit_rsa_pss); + ret = wc_RsaPSS_VerifyInline_ex(sig, outSz, &plain, hash[0], mgf[0], + digestSz + 1, key); + if (ret != PSS_SALTLEN_E) + ERROR_OUT(-5473, exit_rsa_pss); + + ret = wc_RsaPSS_CheckPadding_ex(digest, digestSz, plain, plainSz, hash[0], + -2); + if (ret != PSS_SALTLEN_E) + ERROR_OUT(-5474, exit_rsa_pss); + ret = wc_RsaPSS_CheckPadding_ex(digest, digestSz, plain, plainSz, hash[0], + digestSz + 1); + if (ret != PSS_SALTLEN_E) + ERROR_OUT(-5475, exit_rsa_pss); + + ret = 0; +exit_rsa_pss: + FREE_VAR(in, HEAP_HINT); + FREE_VAR(out, HEAP_HINT); + + return ret; +} +#endif + int rsa_test(void) { int ret; @@ -8967,6 +9136,10 @@ int rsa_test(void) #endif /* WOLFSSL_CERT_REQ */ #endif /* WOLFSSL_CERT_GEN */ +#ifdef WC_RSA_PSS + ret = rsa_pss_test(&rng, &key); +#endif + exit_rsa: wc_FreeRsaKey(&key); #ifdef WOLFSSL_CERT_EXT diff --git a/wolfssl/wolfcrypt/error-crypt.h b/wolfssl/wolfcrypt/error-crypt.h index 8204dd13d..4f6066272 100644 --- a/wolfssl/wolfcrypt/error-crypt.h +++ b/wolfssl/wolfcrypt/error-crypt.h @@ -194,7 +194,9 @@ enum { WC_HW_E = -248, /* Error with hardware crypto use */ WC_HW_WAIT_E = -249, /* Hardware waiting on resource */ - WC_LAST_E = -249, /* Update this to indicate last error */ + PSS_SALTLEN_E = -250, /* PSS length of salt is to long for hash */ + + WC_LAST_E = -250, /* Update this to indicate last error */ MIN_CODE_E = -300 /* errors -101 - -299 */ /* add new companion error id strings for any new error codes diff --git a/wolfssl/wolfcrypt/rsa.h b/wolfssl/wolfcrypt/rsa.h index 3dafb8a34..ee5a70366 100644 --- a/wolfssl/wolfcrypt/rsa.h +++ b/wolfssl/wolfcrypt/rsa.h @@ -150,6 +150,10 @@ WOLFSSL_API int wc_RsaSSL_Sign(const byte* in, word32 inLen, byte* out, WOLFSSL_API int wc_RsaPSS_Sign(const byte* in, word32 inLen, byte* out, word32 outLen, enum wc_HashType hash, int mgf, RsaKey* key, WC_RNG* rng); +WOLFSSL_API int wc_RsaPSS_Sign_ex(const byte* in, word32 inLen, byte* out, + word32 outLen, enum wc_HashType hash, + int mgf, int saltLen, RsaKey* key, + WC_RNG* rng); WOLFSSL_API int wc_RsaSSL_VerifyInline(byte* in, word32 inLen, byte** out, RsaKey* key); WOLFSSL_API int wc_RsaSSL_Verify(const byte* in, word32 inLen, byte* out, @@ -157,9 +161,22 @@ WOLFSSL_API int wc_RsaSSL_Verify(const byte* in, word32 inLen, byte* out, WOLFSSL_API int wc_RsaPSS_VerifyInline(byte* in, word32 inLen, byte** out, enum wc_HashType hash, int mgf, RsaKey* key); +WOLFSSL_API int wc_RsaPSS_VerifyInline_ex(byte* in, word32 inLen, byte** out, + enum wc_HashType hash, int mgf, + int saltLen, RsaKey* key); +WOLFSSL_API int wc_RsaPSS_Verify(byte* in, word32 inLen, byte* out, + word32 outLen, enum wc_HashType hash, int mgf, + RsaKey* key); +WOLFSSL_API int wc_RsaPSS_Verify_ex(byte* in, word32 inLen, byte* out, + word32 outLen, enum wc_HashType hash, + int mgf, int saltLen, RsaKey* key); WOLFSSL_API int wc_RsaPSS_CheckPadding(const byte* in, word32 inLen, byte* sig, word32 sigSz, enum wc_HashType hashType); +WOLFSSL_API int wc_RsaPSS_CheckPadding_ex(const byte* in, word32 inLen, + byte* sig, word32 sigSz, + enum wc_HashType hashType, + int saltLen); WOLFSSL_API int wc_RsaEncryptSize(RsaKey* key);