Fixes from review

This commit is contained in:
Sean Parkinson
2017-05-22 09:43:55 +10:00
parent 4390f4c711
commit 6c6069bed8

View File

@ -586,7 +586,7 @@ static int RsaPad_OAEP(const byte* input, word32 inputLen, byte* pkcsBlock,
*/ */
static int RsaPad_PSS(const byte* input, word32 inputLen, byte* pkcsBlock, static int RsaPad_PSS(const byte* input, word32 inputLen, byte* pkcsBlock,
word32 pkcsBlockLen, WC_RNG* rng, enum wc_HashType hType, int mgf, word32 pkcsBlockLen, WC_RNG* rng, enum wc_HashType hType, int mgf,
void* heap) int bits, void* heap)
{ {
int ret; int ret;
int hLen, i; int hLen, i;
@ -617,8 +617,7 @@ static int RsaPad_PSS(const byte* input, word32 inputLen, byte* pkcsBlock,
ret = RsaMGF(mgf, h, hLen, pkcsBlock, pkcsBlockLen - hLen - 1, heap); ret = RsaMGF(mgf, h, hLen, pkcsBlock, pkcsBlockLen - hLen - 1, heap);
if (ret != 0) if (ret != 0)
return ret; return ret;
/* TODO: use the number of bits in the prime */ pkcsBlock[0] &= (1 << ((bits - 1) & 0x7)) - 1;
pkcsBlock[0] &= (1 << 7) - 1;
m = pkcsBlock + pkcsBlockLen - 1 - hLen - hLen - 1; m = pkcsBlock + pkcsBlockLen - 1 - hLen - hLen - 1;
*(m++) ^= 0x01; *(m++) ^= 0x01;
@ -681,7 +680,7 @@ static int RsaPad(const byte* input, word32 inputLen, byte* pkcsBlock,
/* helper function to direct which padding is used */ /* helper function to direct which padding is used */
static int wc_RsaPad_ex(const byte* input, word32 inputLen, byte* pkcsBlock, static int wc_RsaPad_ex(const byte* input, word32 inputLen, byte* pkcsBlock,
word32 pkcsBlockLen, byte padValue, WC_RNG* rng, int padType, word32 pkcsBlockLen, byte padValue, WC_RNG* rng, int padType,
enum wc_HashType hType, int mgf, byte* optLabel, word32 labelLen, enum wc_HashType hType, int mgf, byte* optLabel, word32 labelLen, int bits,
void* heap) void* heap)
{ {
int ret; int ret;
@ -705,8 +704,8 @@ static int wc_RsaPad_ex(const byte* input, word32 inputLen, byte* pkcsBlock,
#ifdef WC_RSA_PSS #ifdef WC_RSA_PSS
case WC_RSA_PSS_PAD: case WC_RSA_PSS_PAD:
WOLFSSL_MSG("wolfSSL Using RSA PSS padding"); WOLFSSL_MSG("wolfSSL Using RSA PSS padding");
ret = RsaPad_PSS(input, inputLen, pkcsBlock, pkcsBlockLen, ret = RsaPad_PSS(input, inputLen, pkcsBlock, pkcsBlockLen, rng,
rng, hType, mgf, heap); hType, mgf, bits, heap);
break; break;
#endif #endif
@ -720,6 +719,7 @@ static int wc_RsaPad_ex(const byte* input, word32 inputLen, byte* pkcsBlock,
(void)mgf; (void)mgf;
(void)optLabel; (void)optLabel;
(void)labelLen; (void)labelLen;
(void)bits;
(void)heap; (void)heap;
return ret; return ret;
@ -816,7 +816,7 @@ static int RsaUnPad_OAEP(byte *pkcsBlock, unsigned int pkcsBlockLen,
#ifdef WC_RSA_PSS #ifdef WC_RSA_PSS
static int RsaUnPad_PSS(byte *pkcsBlock, unsigned int pkcsBlockLen, static int RsaUnPad_PSS(byte *pkcsBlock, unsigned int pkcsBlockLen,
byte **output, enum wc_HashType hType, int mgf, byte **output, enum wc_HashType hType, int mgf,
void* heap) int bits, void* heap)
{ {
int ret; int ret;
byte* tmp; byte* tmp;
@ -840,8 +840,7 @@ static int RsaUnPad_PSS(byte *pkcsBlock, unsigned int pkcsBlockLen,
return ret; return ret;
} }
/* TODO: use the number of bits in the prime */ tmp[0] &= (1 << ((bits - 1) & 0x7)) - 1;
tmp[0] &= (1 << 7) - 1;
for (i = 0; i < (int)(pkcsBlockLen - 1 - hLen - hLen - 1); i++) { for (i = 0; i < (int)(pkcsBlockLen - 1 - hLen - hLen - 1); i++) {
if (tmp[i] != pkcsBlock[i]) { if (tmp[i] != pkcsBlock[i]) {
XFREE(tmp, heap, DYNAMIC_TYPE_TMP_BUFFER); XFREE(tmp, heap, DYNAMIC_TYPE_TMP_BUFFER);
@ -915,7 +914,8 @@ static int RsaUnPad(const byte *pkcsBlock, unsigned int pkcsBlockLen,
/* helper function to direct unpadding */ /* helper function to direct unpadding */
static int wc_RsaUnPad_ex(byte* pkcsBlock, word32 pkcsBlockLen, byte** out, static int wc_RsaUnPad_ex(byte* pkcsBlock, word32 pkcsBlockLen, byte** out,
byte padValue, int padType, enum wc_HashType hType, byte padValue, int padType, enum wc_HashType hType,
int mgf, byte* optLabel, word32 labelLen, void* heap) int mgf, byte* optLabel, word32 labelLen, int bits,
void* heap)
{ {
int ret; int ret;
@ -937,7 +937,7 @@ static int wc_RsaUnPad_ex(byte* pkcsBlock, word32 pkcsBlockLen, byte** out,
case WC_RSA_PSS_PAD: case WC_RSA_PSS_PAD:
WOLFSSL_MSG("wolfSSL Using RSA PSS un-padding"); WOLFSSL_MSG("wolfSSL Using RSA PSS un-padding");
ret = RsaUnPad_PSS((byte*)pkcsBlock, pkcsBlockLen, out, hType, mgf, ret = RsaUnPad_PSS((byte*)pkcsBlock, pkcsBlockLen, out, hType, mgf,
heap); bits, heap);
break; break;
#endif #endif
@ -951,6 +951,7 @@ static int wc_RsaUnPad_ex(byte* pkcsBlock, word32 pkcsBlockLen, byte** out,
(void)mgf; (void)mgf;
(void)optLabel; (void)optLabel;
(void)labelLen; (void)labelLen;
(void)bits;
(void)heap; (void)heap;
return ret; return ret;
@ -1285,7 +1286,8 @@ static int RsaPublicEncryptEx(const byte* in, word32 inLen, byte* out,
#endif #endif
ret = wc_RsaPad_ex(in, inLen, out, sz, pad_value, rng, pad_type, hash, ret = wc_RsaPad_ex(in, inLen, out, sz, pad_value, rng, pad_type, hash,
mgf, label, labelSz, key->heap); mgf, label, labelSz, mp_count_bits(&key->n),
key->heap);
if (ret < 0) { if (ret < 0) {
break; break;
} }
@ -1418,7 +1420,8 @@ static int RsaPrivateDecryptEx(byte* in, word32 inLen, byte* out,
{ {
byte* pad = NULL; byte* pad = NULL;
ret = wc_RsaUnPad_ex(key->data, key->dataLen, &pad, pad_value, pad_type, ret = wc_RsaUnPad_ex(key->data, key->dataLen, &pad, pad_value, pad_type,
hash, mgf, label, labelSz, key->heap); hash, mgf, label, labelSz, mp_count_bits(&key->n),
key->heap);
if (ret > 0 && ret <= (int)outLen && pad != NULL) { if (ret > 0 && ret <= (int)outLen && pad != NULL) {
/* only copy output if not inline */ /* only copy output if not inline */
if (outPtr == NULL) { if (outPtr == NULL) {
@ -1601,7 +1604,9 @@ int wc_RsaPSS_CheckPadding(const byte* in, word32 inSz, byte* sig,
ret = BAD_FUNC_ARG; ret = BAD_FUNC_ARG;
else { else {
XMEMCPY(sig + RSA_PSS_PAD_SZ, in, inSz); XMEMCPY(sig + RSA_PSS_PAD_SZ, in, inSz);
wc_Hash(hashType, sig, RSA_PSS_PAD_SZ + inSz * 2, sig, inSz); ret = wc_Hash(hashType, sig, RSA_PSS_PAD_SZ + inSz * 2, sig, inSz);
if (ret != 0)
return ret;
if (XMEMCMP(sig, sig + RSA_PSS_PAD_SZ + inSz * 2, inSz) != 0) if (XMEMCMP(sig, sig + RSA_PSS_PAD_SZ + inSz * 2, inSz) != 0)
ret = BAD_PADDING_E; ret = BAD_PADDING_E;
else else