Fix for missing heap hint with RSA PSS and WOLFSSL_PSS_LONG_SALT (#4363)

* Fix for missing heap hint with RSA PSS and `WOLFSSL_PSS_LONG_SALT`. This fix will only allocate buffer if it exceeds the local buffer. Added `wc_RsaPSS_CheckPadding_ex2` to support heap hint if required. Fixed asn.c build issue with `NO_CERTS`. Fixed several spelling errors in asn.c. ZD12855.

* Improve the dynamic memory NULL checking in `wc_RsaPSS_CheckPadding_ex2` with `WOLFSSL_PSS_LONG_SALT` defined.
This commit is contained in:
David Garske
2021-09-02 22:42:31 -07:00
committed by GitHub
parent a3ee84bf6d
commit 35cef831bf
4 changed files with 71 additions and 53 deletions

View File

@ -2145,8 +2145,8 @@ int GetASNTag(const byte* input, word32* inOutIdx, byte* tag, word32 maxIdx)
*
* @param [in] input Buffer holding DER/BER encoded data.
* @param [in] tag ASN.1 tag value expected in header.
* @param [in, out] inOutIdx On in, starting index of hedear.
* On out, end of parsed hedear.
* @param [in, out] inOutIdx On in, starting index of header.
* On out, end of parsed header.
* @param [out] len Number of bytes in the ASN.1 data.
* @param [in] maxIdx Length of data in buffer.
* @param [in] check Whether to check the buffer has at least the
@ -2195,8 +2195,8 @@ static int GetASNHeader_ex(const byte* input, byte tag, word32* inOutIdx,
*
* @param [in] input Buffer holding DER/BER encoded data.
* @param [in] tag ASN.1 tag value expected in header.
* @param [in, out] inOutIdx On in, starting index of hedear.
* On out, end of parsed hedear.
* @param [in, out] inOutIdx On in, starting index of header.
* On out, end of parsed header.
* @param [out] len Number of bytes in the ASN.1 data.
* @param [in] maxIdx Length of data in buffer.
* @return Number of bytes in the ASN.1 data on success.
@ -2233,8 +2233,8 @@ static int GetHeader(const byte* input, byte* tag, word32* inOutIdx, int* len,
/* Decode the header of a BER/DER encoded SEQUENCE.
*
* @param [in] input Buffer holding DER/BER encoded data.
* @param [in, out] inOutIdx On in, starting index of hedear.
* On out, end of parsed hedear.
* @param [in, out] inOutIdx On in, starting index of header.
* On out, end of parsed header.
* @param [out] len Number of bytes in the ASN.1 data.
* @param [in] maxIdx Length of data in buffer.
* @return Number of bytes in the ASN.1 data on success.
@ -2251,8 +2251,8 @@ int GetSequence(const byte* input, word32* inOutIdx, int* len,
/* Decode the header of a BER/DER encoded SEQUENCE.
*
* @param [in] input Buffer holding DER/BER encoded data.
* @param [in, out] inOutIdx On in, starting index of hedear.
* On out, end of parsed hedear.
* @param [in, out] inOutIdx On in, starting index of header.
* On out, end of parsed header.
* @param [out] len Number of bytes in the ASN.1 data.
* @param [in] maxIdx Length of data in buffer.
* @param [in] check Whether to check the buffer has at least the
@ -2271,8 +2271,8 @@ int GetSequence_ex(const byte* input, word32* inOutIdx, int* len,
/* Decode the header of a BER/DER encoded SET.
*
* @param [in] input Buffer holding DER/BER encoded data.
* @param [in, out] inOutIdx On in, starting index of hedear.
* On out, end of parsed hedear.
* @param [in, out] inOutIdx On in, starting index of header.
* On out, end of parsed header.
* @param [out] len Number of bytes in the ASN.1 data.
* @param [in] maxIdx Length of data in buffer.
* @return Number of bytes in the ASN.1 data on success.
@ -2289,8 +2289,8 @@ int GetSet(const byte* input, word32* inOutIdx, int* len,
/* Decode the header of a BER/DER encoded SET.
*
* @param [in] input Buffer holding DER/BER encoded data.
* @param [in, out] inOutIdx On in, starting index of hedear.
* On out, end of parsed hedear.
* @param [in, out] inOutIdx On in, starting index of header.
* On out, end of parsed header.
* @param [out] len Number of bytes in the ASN.1 data.
* @param [in] maxIdx Length of data in buffer.
* @param [in] check Whether to check the buffer has at least the
@ -2400,9 +2400,9 @@ static int GetBoolean(const byte* input, word32* inOutIdx, word32 maxIdx)
/* Decode the header of a BER/DER encoded OCTET STRING.
*
* @param [in] input Buffer holding DER/BER encoded data.
* @param [in, out] inOutIdx On in, starting index of hedear.
* On out, end of parsed hedear.
* @param [out] len Nnumber of bytes in the ASN.1 data.
* @param [in, out] inOutIdx On in, starting index of header.
* On out, end of parsed header.
* @param [out] len Number of bytes in the ASN.1 data.
* @param [in] maxIdx Length of data in buffer.
* @return Number of bytes in the ASN.1 data on success.
* @return BUFFER_E when there is not enough data to parse.
@ -2688,7 +2688,8 @@ const char* GetSigName(int oid) {
* When output is NULL, calculate the header length only.
*
* @param [in] len Length of INTEGER data in bytes.
* @param [in] firstByte First byte of data, most significant byte of integer, * to encode.
* @param [in] firstByte First byte of data, most significant byte of integer,
* to encode.
* @param [out] output Buffer to write into.
* @return Number of bytes added to the buffer.
*/
@ -2705,7 +2706,7 @@ int SetASNInt(int len, byte firstByte, byte* output)
/* Check if first byte has top bit set in which case a 0 is needed to
* maintain positive value. */
if (firstByte & 0x80) {
/* Add pre-prended byte to length of data in INTEGER. */
/* Add pre-prepended byte to length of data in INTEGER. */
len++;
}
/* Encode length - passing NULL for output will not encode. */
@ -4935,8 +4936,8 @@ int DecodeObjectId(const byte* in, word32 inSz, word16* out, word32* outSz)
/* Decode the header of a BER/DER encoded OBJECT ID.
*
* @param [in] input Buffer holding DER/BER encoded data.
* @param [in, out] inOutIdx On in, starting index of hedear.
* On out, end of parsed hedear.
* @param [in, out] inOutIdx On in, starting index of header.
* On out, end of parsed header.
* @param [out] len Number of bytes in the ASN.1 data.
* @param [in] maxIdx Length of data in buffer.
* @return 0 on success.
@ -5195,7 +5196,7 @@ static const ASNItem algoIdASN[] = {
* NULL tag is skipped if present.
*
* @param [in] input Buffer holding BER encoded data.
* @param [in, out] inOutIdx On in, start of algorithm identfier.
* @param [in, out] inOutIdx On in, start of algorithm identifier.
* On out, start of ASN.1 item after algorithm id.
* @param [out] oid Id of OID in algorithm identifier data.
* @param [in] oidType Type of OID to expect.
@ -5325,7 +5326,7 @@ static const ASNItem rsaKeyASN[] = {
*
* PKCS #1: RFC 8017, A.1.2 - RSAPrivateKey
*
* Compiling with WOLFSSL_RSA_PUBLIC_ONLY will result in only the public fuields
* Compiling with WOLFSSL_RSA_PUBLIC_ONLY will result in only the public fields
* being extracted.
*
* @param [in] input Buffer holding BER encoded data.
@ -5502,7 +5503,7 @@ static const ASNItem pkcs8KeyASN[] = {
* @param [in, out] inOutIdx On in, start of PKCS #8 encoding.
* On out, start of encoded key.
* @param [in] sz Size of data in buffer.
* @param [out] algId Key's algorithm id fronm PKCS #8 header.
* @param [out] algId Key's algorithm id from PKCS #8 header.
* @return Length of key data on success.
* @return ASN_PARSE_E when BER encoded data does not match ASN.1 items or
* is invalid.
@ -6375,7 +6376,8 @@ int wc_GetKeyOID(byte* key, word32 keySz, const byte** curveOID, word32* oidSz,
#endif /* HAVE_ECC && !NO_ASN_CRYPT */
#if defined(HAVE_ED25519) && defined(HAVE_ED25519_KEY_IMPORT) && !defined(NO_ASN_CRYPT)
if (*algoID == 0) {
ed25519_key *ed25519 = (ed25519_key *)XMALLOC(sizeof *ed25519, heap, DYNAMIC_TYPE_TMP_BUFFER);
ed25519_key *ed25519 = (ed25519_key *)XMALLOC(sizeof *ed25519, heap,
DYNAMIC_TYPE_TMP_BUFFER);
if (ed25519 == NULL)
return MEMORY_E;
@ -6397,7 +6399,8 @@ int wc_GetKeyOID(byte* key, word32 keySz, const byte** curveOID, word32* oidSz,
#endif /* HAVE_ED25519 && HAVE_ED25519_KEY_IMPORT && !NO_ASN_CRYPT */
#if defined(HAVE_ED448) && defined(HAVE_ED448_KEY_IMPORT) && !defined(NO_ASN_CRYPT)
if (*algoID == 0) {
ed448_key *ed448 = (ed448_key *)XMALLOC(sizeof *ed448, heap, DYNAMIC_TYPE_TMP_BUFFER);
ed448_key *ed448 = (ed448_key *)XMALLOC(sizeof *ed448, heap,
DYNAMIC_TYPE_TMP_BUFFER);
if (ed448 == NULL)
return MEMORY_E;
@ -7214,7 +7217,7 @@ static const ASNItem p8EncPbes1ASN[] = {
* @param [in] saltSz Length of salt in bytes.
* @param [in] itt Number of iterations to use in KDF.
* @param [in] rng Random number generator to use to generate salt.
* @param [in] heap Dynamic memory alloctor hint.
* @param [in] heap Dynamic memory allocator hint.
* @return The size of encrypted data on success
* @return LENGTH_ONLY_E when out is NULL and able to encode.
* @return ASN_PARSE_E when the salt size is too large.
@ -7991,7 +7994,7 @@ static const ASNItem dhKeyPkcs8ASN[] = {
* On out, end of DH key data.
* @param [in, out] key DH key object.
* @param [in] inSz Size of data in bytes.
* @return 0 on suceess.
* @return 0 on success.
* @return BAD_FUNC_ARG when input, inOutIDx or key is NULL.
* @return MEMORY_E when dynamic memory allocation fails.
* @return ASN_PARSE_E when BER encoded data does not match ASN.1 items or
@ -10152,6 +10155,7 @@ int CalcHashId(const byte* data, word32 len, byte* hash)
return ret;
}
#ifndef NO_CERTS
/* Get the hash of the id using the SHA-1 or SHA-256.
*
* If the id is not the length of the hash, then hash it.
@ -10176,6 +10180,7 @@ static int GetHashId(const byte* id, int length, byte* hash)
return ret;
}
#endif /* !NO_CERTS */
#ifdef WOLFSSL_ASN_TEMPLATE
/* Id for street address - not used. */
@ -12173,8 +12178,7 @@ int DecodeToKey(DecodedCert* cert, int verify)
#endif /* WOLFSSL_ASN_TEMPLATE */
}
#ifndef NO_CERTS
#ifndef WOLFSSL_ASN_TEMPLATE
#if !defined(NO_CERTS) && !defined(WOLFSSL_ASN_TEMPLATE)
static int GetSignature(DecodedCert* cert)
{
int length;
@ -12194,14 +12198,15 @@ static int GetSignature(DecodedCert* cert)
return 0;
}
#endif /* !NO_CERTS && !WOLFSSL_ASN_TEMPLATE */
#ifndef WOLFSSL_ASN_TEMPLATE
static word32 SetOctetString8Bit(word32 len, byte* output)
{
output[0] = ASN_OCTET_STRING;
output[1] = (byte)len;
return 2;
}
static word32 SetDigest(const byte* digest, word32 digSz, byte* output)
{
word32 idx = SetOctetString8Bit(digSz, output);
@ -12210,7 +12215,7 @@ static word32 SetDigest(const byte* digest, word32 digSz, byte* output)
return idx + digSz;
}
#endif
#endif /* NO_CERTS */
/* Encode a length for DER.
*
@ -18752,7 +18757,8 @@ int PemToDer(const unsigned char* buff, long longSz, int type,
word32 algId = 0;
word32 idx;
#if defined(WOLFSSL_ENCRYPTED_KEYS)
#if defined(WOLFSSL_ENCRYPTED_KEYS) && !defined(NO_DES3) && !defined(NO_WOLFSSL_SKIP_TRAILING_PAD)
#if defined(WOLFSSL_ENCRYPTED_KEYS) && !defined(NO_DES3) && \
!defined(NO_WOLFSSL_SKIP_TRAILING_PAD)
int padVal = 0;
#endif
#endif

View File

@ -3551,16 +3551,13 @@ int wc_RsaPSS_CheckPadding(const byte* in, word32 inSz, byte* sig,
* NULL is passed in to in or sig or inSz is not the same as the hash
* algorithm length and 0 on success.
*/
int wc_RsaPSS_CheckPadding_ex(const byte* in, word32 inSz, byte* sig,
int wc_RsaPSS_CheckPadding_ex2(const byte* in, word32 inSz, byte* sig,
word32 sigSz, enum wc_HashType hashType,
int saltLen, int bits)
int saltLen, int bits, void* heap)
{
int ret = 0;
#ifndef WOLFSSL_PSS_LONG_SALT
byte sigCheck[WC_MAX_DIGEST_SIZE*2 + RSA_PSS_PAD_SZ];
#else
byte *sigCheck = NULL;
#endif
byte sigCheckBuf[WC_MAX_DIGEST_SIZE*2 + RSA_PSS_PAD_SZ];
byte *sigCheck = sigCheckBuf;
(void)bits;
@ -3609,8 +3606,9 @@ int wc_RsaPSS_CheckPadding_ex(const byte* in, word32 inSz, byte* sig,
}
#ifdef WOLFSSL_PSS_LONG_SALT
if (ret == 0) {
sigCheck = (byte*)XMALLOC(RSA_PSS_PAD_SZ + inSz + saltLen, NULL,
/* if long salt is larger then default maximum buffer then allocate a buffer */
if (ret == 0 && sizeof(sigCheckBuf) < (RSA_PSS_PAD_SZ + inSz + saltLen)) {
sigCheck = (byte*)XMALLOC(RSA_PSS_PAD_SZ + inSz + saltLen, heap,
DYNAMIC_TYPE_RSA_BUFFER);
if (sigCheck == NULL) {
ret = MEMORY_E;
@ -3634,12 +3632,22 @@ int wc_RsaPSS_CheckPadding_ex(const byte* in, word32 inSz, byte* sig,
}
#ifdef WOLFSSL_PSS_LONG_SALT
if (sigCheck != NULL) {
XFREE(sigCheck, NULL, DYNAMIC_TYPE_RSA_BUFFER);
if (sigCheck != NULL && sigCheck != sigCheckBuf) {
XFREE(sigCheck, heap, DYNAMIC_TYPE_RSA_BUFFER);
}
#else
(void)heap;
#endif
return ret;
}
int wc_RsaPSS_CheckPadding_ex(const byte* in, word32 inSz, byte* sig,
word32 sigSz, enum wc_HashType hashType,
int saltLen, int bits)
{
return wc_RsaPSS_CheckPadding_ex2(in, inSz, sig, sigSz, hashType, saltLen,
bits, NULL);
}
/* Verify the message signed with RSA-PSS.

View File

@ -13551,8 +13551,8 @@ static int rsa_pss_test(WC_RNG* rng, RsaKey* key)
ret = wc_RsaPSS_CheckPadding_ex(digest, digestSz, plain, plainSz,
hash[j], -1);
#else
ret = wc_RsaPSS_CheckPadding_ex(digest, digestSz, plain, plainSz,
hash[j], -1, wc_RsaEncryptSize(key)*8);
ret = wc_RsaPSS_CheckPadding_ex2(digest, digestSz, plain, plainSz,
hash[j], -1, wc_RsaEncryptSize(key)*8, HEAP_HINT);
#endif
if (ret != 0)
ERROR_OUT(-7733, exit_rsa_pss);
@ -13627,8 +13627,8 @@ static int rsa_pss_test(WC_RNG* rng, RsaKey* key)
ret = wc_RsaPSS_CheckPadding_ex(digest, digestSz, sig, plainSz,
hash[0], 0);
#else
ret = wc_RsaPSS_CheckPadding_ex(digest, digestSz, sig, plainSz,
hash[0], 0, 0);
ret = wc_RsaPSS_CheckPadding_ex2(digest, digestSz, sig, plainSz,
hash[0], 0, 0, HEAP_HINT);
#endif
}
} while (ret == WC_PENDING_E);
@ -13657,8 +13657,8 @@ static int rsa_pss_test(WC_RNG* rng, RsaKey* key)
ret = wc_RsaPSS_CheckPadding_ex(digest, digestSz, plain, plainSz, hash[0],
0);
#else
ret = wc_RsaPSS_CheckPadding_ex(digest, digestSz, plain, plainSz, hash[0],
0, 0);
ret = wc_RsaPSS_CheckPadding_ex2(digest, digestSz, plain, plainSz, hash[0],
0, 0, HEAP_HINT);
#endif
if (ret != 0)
ERROR_OUT(-7739, exit_rsa_pss);
@ -13736,8 +13736,8 @@ static int rsa_pss_test(WC_RNG* rng, RsaKey* key)
ret = wc_RsaPSS_CheckPadding_ex(digest, digestSz, plain, plainSz, hash[0],
len);
#else
ret = wc_RsaPSS_CheckPadding_ex(digest, digestSz, plain, plainSz, hash[0],
len, 0);
ret = wc_RsaPSS_CheckPadding_ex2(digest, digestSz, plain, plainSz, hash[0],
len, 0, HEAP_HINT);
#endif
if (ret != PSS_SALTLEN_E)
ERROR_OUT(-7744, exit_rsa_pss);
@ -13751,8 +13751,8 @@ static int rsa_pss_test(WC_RNG* rng, RsaKey* key)
ret = wc_RsaPSS_CheckPadding_ex(digest, digestSz, plain, plainSz, hash[0],
len);
#else
ret = wc_RsaPSS_CheckPadding_ex(digest, digestSz, plain, plainSz, hash[0],
len, 0);
ret = wc_RsaPSS_CheckPadding_ex2(digest, digestSz, plain, plainSz, hash[0],
len, 0, HEAP_HINT);
#endif
if (ret != PSS_SALTLEN_E)
ERROR_OUT(-7745, exit_rsa_pss);

View File

@ -276,6 +276,10 @@ WOLFSSL_API int wc_RsaPSS_CheckPadding_ex(const byte* in, word32 inLen,
byte* sig, word32 sigSz,
enum wc_HashType hashType,
int saltLen, int bits);
WOLFSSL_API int wc_RsaPSS_CheckPadding_ex2(const byte* in, word32 inLen,
byte* sig, word32 sigSz,
enum wc_HashType hashType,
int saltLen, int bits, void* heap);
WOLFSSL_API int wc_RsaPSS_VerifyCheckInline(byte* in, word32 inLen, byte** out,
const byte* digest, word32 digentLen,
enum wc_HashType hash, int mgf,