Support RSA OAEP with no malloc

This commit is contained in:
Tesfa Mael
2022-05-03 22:57:47 -07:00
parent e722c15be8
commit 97f54e8e0a

View File

@ -853,11 +853,15 @@ int wc_CheckRsaKey(RsaKey* key)
static int RsaMGF1(enum wc_HashType hType, byte* seed, word32 seedSz, static int RsaMGF1(enum wc_HashType hType, byte* seed, word32 seedSz,
byte* out, word32 outSz, void* heap) byte* out, word32 outSz, void* heap)
{ {
byte* tmp; #if defined(WOLFSSL_SMALL_STACK) && !defined(WOLFSSL_NO_MALLOC)
byte* tmp = NULL;
#else
byte tmp[RSA_MAX_SIZE/8] = {0};
#endif
/* needs to be large enough for seed size plus counter(4) */ /* needs to be large enough for seed size plus counter(4) */
byte tmpA[WC_MAX_DIGEST_SIZE + 4]; byte tmpA[WC_MAX_DIGEST_SIZE + 4]= {0};
byte tmpF; /* 1 if dynamic memory needs freed */ byte tmpF = 0; /* 1 if dynamic memory needs freed */
word32 tmpSz; word32 tmpSz = 0;
int hLen; int hLen;
int ret; int ret;
word32 counter; word32 counter;
@ -881,18 +885,22 @@ static int RsaMGF1(enum wc_HashType hType, byte* seed, word32 seedSz,
/* find largest amount of memory needed which will be the max of /* find largest amount of memory needed which will be the max of
* hLen and (seedSz + 4) since tmp is used to store the hash digest */ * hLen and (seedSz + 4) since tmp is used to store the hash digest */
tmpSz = ((seedSz + 4) > (word32)hLen)? seedSz + 4: (word32)hLen; tmpSz = ((seedSz + 4) > (word32)hLen)? seedSz + 4: (word32)hLen;
#if defined(WOLFSSL_SMALL_STACK) && !defined(WOLFSSL_NO_MALLOC)
tmp = (byte*)XMALLOC(tmpSz, heap, DYNAMIC_TYPE_RSA_BUFFER); tmp = (byte*)XMALLOC(tmpSz, heap, DYNAMIC_TYPE_RSA_BUFFER);
if (tmp == NULL) { if (tmp == NULL) {
return MEMORY_E; return MEMORY_E;
} }
tmpF = 1; /* make sure to free memory when done */ tmpF = 1; /* make sure to free memory when done */
#endif
} }
else { else {
/* use array on the stack */ /* use array on the stack */
#ifndef WOLFSSL_SMALL_STACK_CACHE #ifndef WOLFSSL_SMALL_STACK_CACHE
tmpSz = sizeof(tmpA); tmpSz = sizeof(tmpA);
#endif #endif
#if defined(WOLFSSL_SMALL_STACK) && !defined(WOLFSSL_NO_MALLOC)
tmp = tmpA; tmp = tmpA;
#endif
tmpF = 0; /* no need to free memory at end */ tmpF = 0; /* no need to free memory at end */
} }
@ -935,9 +943,11 @@ static int RsaMGF1(enum wc_HashType hType, byte* seed, word32 seedSz,
#endif #endif
if (ret != 0) { if (ret != 0) {
/* check for if dynamic memory was needed, then free */ /* check for if dynamic memory was needed, then free */
#if defined(WOLFSSL_SMALL_STACK) && !defined(WOLFSSL_NO_MALLOC)
if (tmpF) { if (tmpF) {
XFREE(tmp, heap, DYNAMIC_TYPE_RSA_BUFFER); XFREE(tmp, heap, DYNAMIC_TYPE_RSA_BUFFER);
} }
#endif
return ret; return ret;
} }
@ -1038,15 +1048,15 @@ static int RsaPad_OAEP(const byte* input, word32 inputLen, byte* pkcsBlock,
int i; int i;
word32 idx; word32 idx;
byte* dbMask;
#ifdef WOLFSSL_SMALL_STACK #ifdef WOLFSSL_SMALL_STACK
byte* dbMask = NULL;
byte* lHash = NULL; byte* lHash = NULL;
byte* seed = NULL; byte* seed = NULL;
#else #else
/* must be large enough to contain largest hash */ /* must be large enough to contain largest hash */
byte lHash[WC_MAX_DIGEST_SIZE]; byte lHash[WC_MAX_DIGEST_SIZE] = {0};
byte seed[ WC_MAX_DIGEST_SIZE]; byte seed[WC_MAX_DIGEST_SIZE]= {0};
byte dbMask[RSA_MAX_SIZE/8 + RSA_PSS_PAD_SZ] = {0};
#endif #endif
/* no label is allowed, but catch if no label provided and length > 0 */ /* no label is allowed, but catch if no label provided and length > 0 */
@ -1143,21 +1153,21 @@ static int RsaPad_OAEP(const byte* input, word32 inputLen, byte* pkcsBlock,
return ret; return ret;
} }
#ifdef WOLFSSL_SMALL_STACK
/* create maskedDB from dbMask */ /* create maskedDB from dbMask */
dbMask = (byte*)XMALLOC(pkcsBlockLen - hLen - 1, heap, DYNAMIC_TYPE_RSA); dbMask = (byte*)XMALLOC(pkcsBlockLen - hLen - 1, heap, DYNAMIC_TYPE_RSA);
if (dbMask == NULL) { if (dbMask == NULL) {
#ifdef WOLFSSL_SMALL_STACK
XFREE(lHash, heap, DYNAMIC_TYPE_RSA_BUFFER); XFREE(lHash, heap, DYNAMIC_TYPE_RSA_BUFFER);
XFREE(seed, heap, DYNAMIC_TYPE_RSA_BUFFER); XFREE(seed, heap, DYNAMIC_TYPE_RSA_BUFFER);
#endif
return MEMORY_E; return MEMORY_E;
} }
#endif
XMEMSET(dbMask, 0, pkcsBlockLen - hLen - 1); /* help static analyzer */ XMEMSET(dbMask, 0, pkcsBlockLen - hLen - 1); /* help static analyzer */
ret = RsaMGF(mgf, seed, hLen, dbMask, pkcsBlockLen - hLen - 1, heap); ret = RsaMGF(mgf, seed, hLen, dbMask, pkcsBlockLen - hLen - 1, heap);
if (ret != 0) { if (ret != 0) {
XFREE(dbMask, heap, DYNAMIC_TYPE_RSA);
#ifdef WOLFSSL_SMALL_STACK #ifdef WOLFSSL_SMALL_STACK
XFREE(dbMask, heap, DYNAMIC_TYPE_RSA);
XFREE(lHash, heap, DYNAMIC_TYPE_RSA_BUFFER); XFREE(lHash, heap, DYNAMIC_TYPE_RSA_BUFFER);
XFREE(seed, heap, DYNAMIC_TYPE_RSA_BUFFER); XFREE(seed, heap, DYNAMIC_TYPE_RSA_BUFFER);
#endif #endif
@ -1170,8 +1180,9 @@ static int RsaPad_OAEP(const byte* input, word32 inputLen, byte* pkcsBlock,
pkcsBlock[idx] = dbMask[i++] ^ pkcsBlock[idx]; pkcsBlock[idx] = dbMask[i++] ^ pkcsBlock[idx];
idx++; idx++;
} }
#ifdef WOLFSSL_SMALL_STACK
XFREE(dbMask, heap, DYNAMIC_TYPE_RSA); XFREE(dbMask, heap, DYNAMIC_TYPE_RSA);
#endif
/* create maskedSeed from seedMask */ /* create maskedSeed from seedMask */
idx = 0; idx = 0;
@ -1513,9 +1524,13 @@ static int RsaUnPad_OAEP(byte *pkcsBlock, unsigned int pkcsBlockLen,
int hLen; int hLen;
int ret; int ret;
byte h[WC_MAX_DIGEST_SIZE]; /* max digest size */ byte h[WC_MAX_DIGEST_SIZE]; /* max digest size */
byte* tmp;
word32 idx; word32 idx;
#ifdef WOLFSSL_SMALL_STACK
byte* tmp = NULL;
#else
byte tmp[RSA_MAX_SIZE/8 + RSA_PSS_PAD_SZ] = {0};
#endif
/* no label is allowed, but catch if no label provided and length > 0 */ /* no label is allowed, but catch if no label provided and length > 0 */
if (optLabel == NULL && labelLen > 0) { if (optLabel == NULL && labelLen > 0) {
return BUFFER_E; return BUFFER_E;
@ -1526,16 +1541,20 @@ static int RsaUnPad_OAEP(byte *pkcsBlock, unsigned int pkcsBlockLen,
return BAD_FUNC_ARG; return BAD_FUNC_ARG;
} }
#ifdef WOLFSSL_SMALL_STACK
tmp = (byte*)XMALLOC(pkcsBlockLen, heap, DYNAMIC_TYPE_RSA_BUFFER); tmp = (byte*)XMALLOC(pkcsBlockLen, heap, DYNAMIC_TYPE_RSA_BUFFER);
if (tmp == NULL) { if (tmp == NULL) {
return MEMORY_E; return MEMORY_E;
} }
#endif
XMEMSET(tmp, 0, pkcsBlockLen); XMEMSET(tmp, 0, pkcsBlockLen);
/* find seedMask value */ /* find seedMask value */
if ((ret = RsaMGF(mgf, (byte*)(pkcsBlock + (hLen + 1)), if ((ret = RsaMGF(mgf, (byte*)(pkcsBlock + (hLen + 1)),
pkcsBlockLen - hLen - 1, tmp, hLen, heap)) != 0) { pkcsBlockLen - hLen - 1, tmp, hLen, heap)) != 0) {
#ifdef WOLFSSL_SMALL_STACK
XFREE(tmp, heap, DYNAMIC_TYPE_RSA_BUFFER); XFREE(tmp, heap, DYNAMIC_TYPE_RSA_BUFFER);
#endif
return ret; return ret;
} }
@ -1547,7 +1566,9 @@ static int RsaUnPad_OAEP(byte *pkcsBlock, unsigned int pkcsBlockLen,
/* get dbMask value */ /* get dbMask value */
if ((ret = RsaMGF(mgf, tmp, hLen, tmp + hLen, if ((ret = RsaMGF(mgf, tmp, hLen, tmp + hLen,
pkcsBlockLen - hLen - 1, heap)) != 0) { pkcsBlockLen - hLen - 1, heap)) != 0) {
#ifdef WOLFSSL_SMALL_STACK
XFREE(tmp, NULL, DYNAMIC_TYPE_RSA_BUFFER); XFREE(tmp, NULL, DYNAMIC_TYPE_RSA_BUFFER);
#endif
return ret; return ret;
} }
@ -1556,8 +1577,10 @@ static int RsaUnPad_OAEP(byte *pkcsBlock, unsigned int pkcsBlockLen,
pkcsBlock[hLen + 1 + idx] = pkcsBlock[hLen + 1 + idx] ^ tmp[idx + hLen]; pkcsBlock[hLen + 1 + idx] = pkcsBlock[hLen + 1 + idx] ^ tmp[idx + hLen];
} }
#ifdef WOLFSSL_SMALL_STACK
/* done with use of tmp buffer */ /* done with use of tmp buffer */
XFREE(tmp, heap, DYNAMIC_TYPE_RSA_BUFFER); XFREE(tmp, heap, DYNAMIC_TYPE_RSA_BUFFER);
#endif
/* advance idx to index of PS and msg separator, account for PS size of 0*/ /* advance idx to index of PS and msg separator, account for PS size of 0*/
idx = hLen + 1 + hLen; idx = hLen + 1 + hLen;