diff --git a/wolfcrypt/src/rsa.c b/wolfcrypt/src/rsa.c index 723061ee5..fdf6a45c4 100755 --- a/wolfcrypt/src/rsa.c +++ b/wolfcrypt/src/rsa.c @@ -586,7 +586,7 @@ static int RsaPad_OAEP(const byte* input, word32 inputLen, byte* pkcsBlock, */ static int RsaPad_PSS(const byte* input, word32 inputLen, byte* pkcsBlock, word32 pkcsBlockLen, WC_RNG* rng, enum wc_HashType hType, int mgf, - void* heap) + int bits, void* heap) { int ret; int hLen, i; @@ -617,8 +617,7 @@ static int RsaPad_PSS(const byte* input, word32 inputLen, byte* pkcsBlock, ret = RsaMGF(mgf, h, hLen, pkcsBlock, pkcsBlockLen - hLen - 1, heap); if (ret != 0) return ret; - /* TODO: use the number of bits in the prime */ - pkcsBlock[0] &= (1 << 7) - 1; + pkcsBlock[0] &= (1 << ((bits - 1) & 0x7)) - 1; m = pkcsBlock + pkcsBlockLen - 1 - hLen - hLen - 1; *(m++) ^= 0x01; @@ -681,7 +680,7 @@ static int RsaPad(const byte* input, word32 inputLen, byte* pkcsBlock, /* helper function to direct which padding is used */ static int wc_RsaPad_ex(const byte* input, word32 inputLen, byte* pkcsBlock, word32 pkcsBlockLen, byte padValue, WC_RNG* rng, int padType, - enum wc_HashType hType, int mgf, byte* optLabel, word32 labelLen, + enum wc_HashType hType, int mgf, byte* optLabel, word32 labelLen, int bits, void* heap) { int ret; @@ -705,8 +704,8 @@ static int wc_RsaPad_ex(const byte* input, word32 inputLen, byte* pkcsBlock, #ifdef WC_RSA_PSS case WC_RSA_PSS_PAD: WOLFSSL_MSG("wolfSSL Using RSA PSS padding"); - ret = RsaPad_PSS(input, inputLen, pkcsBlock, pkcsBlockLen, - rng, hType, mgf, heap); + ret = RsaPad_PSS(input, inputLen, pkcsBlock, pkcsBlockLen, rng, + hType, mgf, bits, heap); break; #endif @@ -720,6 +719,7 @@ static int wc_RsaPad_ex(const byte* input, word32 inputLen, byte* pkcsBlock, (void)mgf; (void)optLabel; (void)labelLen; + (void)bits; (void)heap; return ret; @@ -816,7 +816,7 @@ static int RsaUnPad_OAEP(byte *pkcsBlock, unsigned int pkcsBlockLen, #ifdef WC_RSA_PSS static int RsaUnPad_PSS(byte *pkcsBlock, unsigned int pkcsBlockLen, byte **output, enum wc_HashType hType, int mgf, - void* heap) + int bits, void* heap) { int ret; byte* tmp; @@ -840,8 +840,7 @@ static int RsaUnPad_PSS(byte *pkcsBlock, unsigned int pkcsBlockLen, return ret; } - /* TODO: use the number of bits in the prime */ - tmp[0] &= (1 << 7) - 1; + tmp[0] &= (1 << ((bits - 1) & 0x7)) - 1; for (i = 0; i < (int)(pkcsBlockLen - 1 - hLen - hLen - 1); i++) { if (tmp[i] != pkcsBlock[i]) { XFREE(tmp, heap, DYNAMIC_TYPE_TMP_BUFFER); @@ -915,7 +914,8 @@ static int RsaUnPad(const byte *pkcsBlock, unsigned int pkcsBlockLen, /* helper function to direct unpadding */ static int wc_RsaUnPad_ex(byte* pkcsBlock, word32 pkcsBlockLen, byte** out, byte padValue, int padType, enum wc_HashType hType, - int mgf, byte* optLabel, word32 labelLen, void* heap) + int mgf, byte* optLabel, word32 labelLen, int bits, + void* heap) { int ret; @@ -937,7 +937,7 @@ static int wc_RsaUnPad_ex(byte* pkcsBlock, word32 pkcsBlockLen, byte** out, case WC_RSA_PSS_PAD: WOLFSSL_MSG("wolfSSL Using RSA PSS un-padding"); ret = RsaUnPad_PSS((byte*)pkcsBlock, pkcsBlockLen, out, hType, mgf, - heap); + bits, heap); break; #endif @@ -951,6 +951,7 @@ static int wc_RsaUnPad_ex(byte* pkcsBlock, word32 pkcsBlockLen, byte** out, (void)mgf; (void)optLabel; (void)labelLen; + (void)bits; (void)heap; return ret; @@ -1285,7 +1286,8 @@ static int RsaPublicEncryptEx(const byte* in, word32 inLen, byte* out, #endif ret = wc_RsaPad_ex(in, inLen, out, sz, pad_value, rng, pad_type, hash, - mgf, label, labelSz, key->heap); + mgf, label, labelSz, mp_count_bits(&key->n), + key->heap); if (ret < 0) { break; } @@ -1418,7 +1420,8 @@ static int RsaPrivateDecryptEx(byte* in, word32 inLen, byte* out, { byte* pad = NULL; ret = wc_RsaUnPad_ex(key->data, key->dataLen, &pad, pad_value, pad_type, - hash, mgf, label, labelSz, key->heap); + hash, mgf, label, labelSz, mp_count_bits(&key->n), + key->heap); if (ret > 0 && ret <= (int)outLen && pad != NULL) { /* only copy output if not inline */ if (outPtr == NULL) { @@ -1601,7 +1604,9 @@ int wc_RsaPSS_CheckPadding(const byte* in, word32 inSz, byte* sig, ret = BAD_FUNC_ARG; else { XMEMCPY(sig + RSA_PSS_PAD_SZ, in, inSz); - wc_Hash(hashType, sig, RSA_PSS_PAD_SZ + inSz * 2, sig, inSz); + ret = wc_Hash(hashType, sig, RSA_PSS_PAD_SZ + inSz * 2, sig, inSz); + if (ret != 0) + return ret; if (XMEMCMP(sig, sig + RSA_PSS_PAD_SZ + inSz * 2, inSz) != 0) ret = BAD_PADDING_E; else