Merge pull request #5105 from tmael/rsa_oaep_nomalloc

Support RSA OAEP with nomalloc
This commit is contained in:
David Garske
2022-06-02 08:45:01 -07:00
committed by GitHub
2 changed files with 58 additions and 23 deletions

View File

@ -239,7 +239,7 @@ enum {
static void wc_RsaCleanup(RsaKey* key)
{
#ifndef WOLFSSL_RSA_VERIFY_INLINE
#if !defined(WOLFSSL_RSA_VERIFY_INLINE) && !defined(WOLFSSL_NO_MALLOC)
if (key && key->data) {
/* make sure any allocated memory is free'd */
if (key->dataIsAlloc) {
@ -273,7 +273,7 @@ int wc_InitRsaKey_ex(RsaKey* key, void* heap, int devId)
key->type = RSA_TYPE_UNKNOWN;
key->state = RSA_STATE_NONE;
key->heap = heap;
#ifndef WOLFSSL_RSA_VERIFY_INLINE
#if !defined(WOLFSSL_RSA_VERIFY_INLINE) && !defined(WOLFSSL_NO_MALLOC)
key->dataIsAlloc = 0;
key->data = NULL;
#endif
@ -853,11 +853,15 @@ int wc_CheckRsaKey(RsaKey* key)
static int RsaMGF1(enum wc_HashType hType, byte* seed, word32 seedSz,
byte* out, word32 outSz, void* heap)
{
byte* tmp;
#if defined(WOLFSSL_SMALL_STACK) && !defined(WOLFSSL_NO_MALLOC)
byte* tmp = NULL;
byte tmpF = 0; /* 1 if dynamic memory needs freed */
#else
byte tmp[RSA_MAX_SIZE/8];
#endif
/* needs to be large enough for seed size plus counter(4) */
byte tmpA[WC_MAX_DIGEST_SIZE + 4];
byte tmpF; /* 1 if dynamic memory needs freed */
word32 tmpSz;
word32 tmpSz = 0;
int hLen;
int ret;
word32 counter;
@ -871,6 +875,7 @@ static int RsaMGF1(enum wc_HashType hType, byte* seed, word32 seedSz,
(void)heap;
XMEMSET(tmpA, 0, sizeof(tmpA));
/* check error return of wc_HashGetDigestSize */
if (hLen < 0) {
return hLen;
@ -881,19 +886,26 @@ static int RsaMGF1(enum wc_HashType hType, byte* seed, word32 seedSz,
/* find largest amount of memory needed which will be the max of
* hLen and (seedSz + 4) since tmp is used to store the hash digest */
tmpSz = ((seedSz + 4) > (word32)hLen)? seedSz + 4: (word32)hLen;
#if defined(WOLFSSL_SMALL_STACK) && !defined(WOLFSSL_NO_MALLOC)
tmp = (byte*)XMALLOC(tmpSz, heap, DYNAMIC_TYPE_RSA_BUFFER);
if (tmp == NULL) {
return MEMORY_E;
}
tmpF = 1; /* make sure to free memory when done */
#else
if (tmpSz > RSA_MAX_SIZE/8)
return BAD_FUNC_ARG;
#endif
}
else {
/* use array on the stack */
#ifndef WOLFSSL_SMALL_STACK_CACHE
tmpSz = sizeof(tmpA);
#endif
#if defined(WOLFSSL_SMALL_STACK) && !defined(WOLFSSL_NO_MALLOC)
tmp = tmpA;
tmpF = 0; /* no need to free memory at end */
#endif
}
#ifdef WOLFSSL_SMALL_STACK_CACHE
@ -935,9 +947,11 @@ static int RsaMGF1(enum wc_HashType hType, byte* seed, word32 seedSz,
#endif
if (ret != 0) {
/* check for if dynamic memory was needed, then free */
#if defined(WOLFSSL_SMALL_STACK) && !defined(WOLFSSL_NO_MALLOC)
if (tmpF) {
XFREE(tmp, heap, DYNAMIC_TYPE_RSA_BUFFER);
}
#endif
return ret;
}
@ -946,11 +960,12 @@ static int RsaMGF1(enum wc_HashType hType, byte* seed, word32 seedSz,
}
counter++;
} while (idx < outSz);
#if defined(WOLFSSL_SMALL_STACK) && !defined(WOLFSSL_NO_MALLOC)
/* check for if dynamic memory was needed, then free */
if (tmpF) {
XFREE(tmp, heap, DYNAMIC_TYPE_RSA_BUFFER);
}
#endif
#ifdef WOLFSSL_SMALL_STACK_CACHE
wc_HashFree(hash, hType);
XFREE(hash, heap, DYNAMIC_TYPE_DIGEST);
@ -1038,15 +1053,15 @@ static int RsaPad_OAEP(const byte* input, word32 inputLen, byte* pkcsBlock,
int i;
word32 idx;
byte* dbMask;
#ifdef WOLFSSL_SMALL_STACK
#if defined(WOLFSSL_SMALL_STACK) && !defined(WOLFSSL_NO_MALLOC)
byte* dbMask = NULL;
byte* lHash = NULL;
byte* seed = NULL;
#else
byte dbMask[RSA_MAX_SIZE/8 + RSA_PSS_PAD_SZ];
/* must be large enough to contain largest hash */
byte lHash[WC_MAX_DIGEST_SIZE];
byte seed[ WC_MAX_DIGEST_SIZE];
byte seed[WC_MAX_DIGEST_SIZE];
#endif
/* no label is allowed, but catch if no label provided and length > 0 */
@ -1060,7 +1075,7 @@ static int RsaPad_OAEP(const byte* input, word32 inputLen, byte* pkcsBlock,
return hLen;
}
#ifdef WOLFSSL_SMALL_STACK
#if defined(WOLFSSL_SMALL_STACK) && !defined(WOLFSSL_NO_MALLOC)
lHash = (byte*)XMALLOC(hLen, heap, DYNAMIC_TYPE_RSA_BUFFER);
if (lHash == NULL) {
return MEMORY_E;
@ -1143,21 +1158,21 @@ static int RsaPad_OAEP(const byte* input, word32 inputLen, byte* pkcsBlock,
return ret;
}
#if defined(WOLFSSL_SMALL_STACK) && !defined(WOLFSSL_NO_MALLOC)
/* create maskedDB from dbMask */
dbMask = (byte*)XMALLOC(pkcsBlockLen - hLen - 1, heap, DYNAMIC_TYPE_RSA);
if (dbMask == NULL) {
#ifdef WOLFSSL_SMALL_STACK
XFREE(lHash, heap, DYNAMIC_TYPE_RSA_BUFFER);
XFREE(seed, heap, DYNAMIC_TYPE_RSA_BUFFER);
#endif
return MEMORY_E;
}
#endif
XMEMSET(dbMask, 0, pkcsBlockLen - hLen - 1); /* help static analyzer */
ret = RsaMGF(mgf, seed, hLen, dbMask, pkcsBlockLen - hLen - 1, heap);
if (ret != 0) {
XFREE(dbMask, heap, DYNAMIC_TYPE_RSA);
#ifdef WOLFSSL_SMALL_STACK
XFREE(dbMask, heap, DYNAMIC_TYPE_RSA);
XFREE(lHash, heap, DYNAMIC_TYPE_RSA_BUFFER);
XFREE(seed, heap, DYNAMIC_TYPE_RSA_BUFFER);
#endif
@ -1170,8 +1185,9 @@ static int RsaPad_OAEP(const byte* input, word32 inputLen, byte* pkcsBlock,
pkcsBlock[idx] = dbMask[i++] ^ pkcsBlock[idx];
idx++;
}
#ifdef WOLFSSL_SMALL_STACK
XFREE(dbMask, heap, DYNAMIC_TYPE_RSA);
#endif
/* create maskedSeed from seedMask */
idx = 0;
@ -1513,11 +1529,16 @@ static int RsaUnPad_OAEP(byte *pkcsBlock, unsigned int pkcsBlockLen,
int hLen;
int ret;
byte h[WC_MAX_DIGEST_SIZE]; /* max digest size */
byte* tmp;
word32 idx;
word32 i;
word32 inc;
#if defined(WOLFSSL_SMALL_STACK) && !defined(WOLFSSL_NO_MALLOC)
byte* tmp = NULL;
#else
byte tmp[RSA_MAX_SIZE/8 + RSA_PSS_PAD_SZ];
XMEMSET(tmp, 0, RSA_MAX_SIZE/8 + RSA_PSS_PAD_SZ);
#endif
/* no label is allowed, but catch if no label provided and length > 0 */
if (optLabel == NULL && labelLen > 0) {
return BUFFER_E;
@ -1528,16 +1549,20 @@ static int RsaUnPad_OAEP(byte *pkcsBlock, unsigned int pkcsBlockLen,
return BAD_FUNC_ARG;
}
#if defined(WOLFSSL_SMALL_STACK) && !defined(WOLFSSL_NO_MALLOC)
tmp = (byte*)XMALLOC(pkcsBlockLen, heap, DYNAMIC_TYPE_RSA_BUFFER);
if (tmp == NULL) {
return MEMORY_E;
}
#endif
XMEMSET(tmp, 0, pkcsBlockLen);
/* find seedMask value */
if ((ret = RsaMGF(mgf, (byte*)(pkcsBlock + (hLen + 1)),
pkcsBlockLen - hLen - 1, tmp, hLen, heap)) != 0) {
#ifdef WOLFSSL_SMALL_STACK
XFREE(tmp, heap, DYNAMIC_TYPE_RSA_BUFFER);
#endif
return ret;
}
@ -1549,7 +1574,9 @@ static int RsaUnPad_OAEP(byte *pkcsBlock, unsigned int pkcsBlockLen,
/* get dbMask value */
if ((ret = RsaMGF(mgf, tmp, hLen, tmp + hLen,
pkcsBlockLen - hLen - 1, heap)) != 0) {
#ifdef WOLFSSL_SMALL_STACK
XFREE(tmp, NULL, DYNAMIC_TYPE_RSA_BUFFER);
#endif
return ret;
}
@ -1558,8 +1585,10 @@ static int RsaUnPad_OAEP(byte *pkcsBlock, unsigned int pkcsBlockLen,
pkcsBlock[hLen + 1 + idx] = pkcsBlock[hLen + 1 + idx] ^ tmp[idx + hLen];
}
#ifdef WOLFSSL_SMALL_STACK
/* done with use of tmp buffer */
XFREE(tmp, heap, DYNAMIC_TYPE_RSA_BUFFER);
#endif
/* advance idx to index of PS and msg separator, account for PS size of 0*/
idx = hLen + 1 + hLen;
@ -3269,7 +3298,8 @@ static int RsaPrivateDecryptEx(const byte* in, word32 inLen, byte* out,
#endif /* WOLFSSL_CRYPTOCELL */
#if !defined(WOLFSSL_RSA_VERIFY_ONLY) && !defined(WOLFSSL_RSA_VERIFY_INLINE)
#if !defined(WOLFSSL_RSA_VERIFY_ONLY) && !defined(WOLFSSL_RSA_VERIFY_INLINE) && \
!defined(WOLFSSL_NO_MALLOC)
/* verify the tmp ptr is NULL, otherwise indicates bad state */
if (key->data != NULL) {
ret = BAD_STATE_E;
@ -3297,7 +3327,8 @@ static int RsaPrivateDecryptEx(const byte* in, word32 inLen, byte* out,
FALL_THROUGH;
case RSA_STATE_DECRYPT_EXPTMOD:
#if !defined(WOLFSSL_RSA_VERIFY_ONLY) && !defined(WOLFSSL_RSA_VERIFY_INLINE)
#if !defined(WOLFSSL_RSA_VERIFY_ONLY) && !defined(WOLFSSL_RSA_VERIFY_INLINE) && \
!defined(WOLFSSL_NO_MALLOC)
ret = wc_RsaFunction_ex(key->data, inLen, key->data, &key->dataLen,
rsa_type, key, rng,
pad_type != WC_RSA_OAEP_PAD);
@ -3316,7 +3347,8 @@ static int RsaPrivateDecryptEx(const byte* in, word32 inLen, byte* out,
FALL_THROUGH;
case RSA_STATE_DECRYPT_UNPAD:
#if !defined(WOLFSSL_RSA_VERIFY_ONLY) && !defined(WOLFSSL_RSA_VERIFY_INLINE)
#if !defined(WOLFSSL_RSA_VERIFY_ONLY) && !defined(WOLFSSL_RSA_VERIFY_INLINE) && \
!defined(WOLFSSL_NO_MALLOC)
ret = wc_RsaUnPad_ex(key->data, key->dataLen, &pad, pad_value, pad_type,
hash, mgf, label, labelSz, saltLen,
mp_count_bits(&key->n), key->heap);
@ -3328,13 +3360,15 @@ static int RsaPrivateDecryptEx(const byte* in, word32 inLen, byte* out,
if (rsa_type == RSA_PUBLIC_DECRYPT && ret > (int)outLen)
ret = RSA_BUFFER_E;
else if (ret >= 0 && pad != NULL) {
#if !defined(WOLFSSL_RSA_VERIFY_ONLY) && !defined(WOLFSSL_RSA_VERIFY_INLINE)
#if !defined(WOLFSSL_RSA_VERIFY_ONLY) && !defined(WOLFSSL_RSA_VERIFY_INLINE) && \
!defined(WOLFSSL_NO_MALLOC)
signed char c;
#endif
/* only copy output if not inline */
if (outPtr == NULL) {
#if !defined(WOLFSSL_RSA_VERIFY_ONLY) && !defined(WOLFSSL_RSA_VERIFY_INLINE)
#if !defined(WOLFSSL_RSA_VERIFY_ONLY) && !defined(WOLFSSL_RSA_VERIFY_INLINE) && \
!defined(WOLFSSL_NO_MALLOC)
if (rsa_type == RSA_PRIVATE_DECRYPT) {
word32 i = 0;
word32 j;

View File

@ -205,7 +205,8 @@ struct RsaKey {
char label[RSA_MAX_LABEL_LEN];
int labelLen;
#endif
#if defined(WOLFSSL_ASYNC_CRYPT) || !defined(WOLFSSL_RSA_VERIFY_INLINE)
#if defined(WOLFSSL_ASYNC_CRYPT) || !defined(WOLFSSL_RSA_VERIFY_INLINE) && \
!defined(WOLFSSL_NO_MALLOC)
byte dataIsAlloc;
#endif
#ifdef WC_RSA_NONBLOCK