rsa.c:wc_CheckProbablePrime(): WOLFSSL_SMALL_STACK refactor

This commit is contained in:
Daniel Pouzzner
2020-09-03 23:23:52 -05:00
parent af6bd1d163
commit 4f5bbbdca8

View File

@@ -3952,30 +3952,49 @@ int wc_CheckProbablePrime(const byte* pRaw, word32 pRawSz,
int wc_MakeRsaKey(RsaKey* key, int size, long e, WC_RNG* rng) int wc_MakeRsaKey(RsaKey* key, int size, long e, WC_RNG* rng)
{ {
#ifndef WC_NO_RNG #ifndef WC_NO_RNG
mp_int p, q, tmp1, tmp2, tmp3; #ifdef WOLFSSL_SMALL_STACK
mp_int *p = (mp_int *)XMALLOC(sizeof *p, key->heap, DYNAMIC_TYPE_RSA),
*q = (mp_int *)XMALLOC(sizeof *q, key->heap, DYNAMIC_TYPE_RSA),
*tmp1 = (mp_int *)XMALLOC(sizeof *tmp1, key->heap, DYNAMIC_TYPE_RSA),
*tmp2 = (mp_int *)XMALLOC(sizeof *tmp2, key->heap, DYNAMIC_TYPE_RSA),
*tmp3 = (mp_int *)XMALLOC(sizeof *tmp3, key->heap, DYNAMIC_TYPE_RSA);
#else
mp_int p_buf, *p = &p_buf,
q_buf, *q = &q_buf,
tmp1_buf, *tmp1 = &tmp1_buf,
tmp2_buf, *tmp2 = &tmp2_buf,
tmp3_buf, *tmp3 = &tmp3_buf;
#endif
int err, i, failCount, primeSz, isPrime = 0; int err, i, failCount, primeSz, isPrime = 0;
byte* buf = NULL; byte* buf = NULL;
if (key == NULL || rng == NULL) if (key == NULL || rng == NULL) {
return BAD_FUNC_ARG; err = BAD_FUNC_ARG;
goto out;
}
if (!RsaSizeCheck(size)) if (!RsaSizeCheck(size)) {
return BAD_FUNC_ARG; err = BAD_FUNC_ARG;
goto out;
}
if (e < 3 || (e & 1) == 0) if (e < 3 || (e & 1) == 0) {
return BAD_FUNC_ARG; err = BAD_FUNC_ARG;
goto out;
}
#if defined(WOLFSSL_CRYPTOCELL) #if defined(WOLFSSL_CRYPTOCELL)
return cc310_RSA_GenerateKeyPair(key, size, e); err = cc310_RSA_GenerateKeyPair(key, size, e);
goto out;
#endif /*WOLFSSL_CRYPTOCELL*/ #endif /*WOLFSSL_CRYPTOCELL*/
#ifdef WOLF_CRYPTO_CB #ifdef WOLF_CRYPTO_CB
if (key->devId != INVALID_DEVID) { if (key->devId != INVALID_DEVID) {
int ret = wc_CryptoCb_MakeRsaKey(key, size, e, rng); err = wc_CryptoCb_MakeRsaKey(key, size, e, rng);
if (ret != CRYPTOCB_UNAVAILABLE) if (err != CRYPTOCB_UNAVAILABLE)
return ret; goto out;
/* fall-through when unavailable */ /* fall-through when unavailable */
} }
#endif #endif
@@ -3986,7 +4005,8 @@ int wc_MakeRsaKey(RsaKey* key, int size, long e, WC_RNG* rng)
#ifdef HAVE_CAVIUM #ifdef HAVE_CAVIUM
/* TODO: Not implemented */ /* TODO: Not implemented */
#elif defined(HAVE_INTEL_QA) #elif defined(HAVE_INTEL_QA)
return IntelQaRsaKeyGen(&key->asyncDev, key, size, e, rng); err = IntelQaRsaKeyGen(&key->asyncDev, key, size, e, rng);
goto out;
#else #else
if (wc_AsyncTestInit(&key->asyncDev, ASYNC_TEST_RSA_MAKE)) { if (wc_AsyncTestInit(&key->asyncDev, ASYNC_TEST_RSA_MAKE)) {
WC_ASYNC_TEST* testDev = &key->asyncDev.test; WC_ASYNC_TEST* testDev = &key->asyncDev.test;
@@ -3994,16 +4014,17 @@ int wc_MakeRsaKey(RsaKey* key, int size, long e, WC_RNG* rng)
testDev->rsaMake.key = key; testDev->rsaMake.key = key;
testDev->rsaMake.size = size; testDev->rsaMake.size = size;
testDev->rsaMake.e = e; testDev->rsaMake.e = e;
return WC_PENDING_E; err = WC_PENDING_E;
goto out;
} }
#endif #endif
} }
#endif #endif
err = mp_init_multi(&p, &q, &tmp1, &tmp2, &tmp3, NULL); err = mp_init_multi(p, q, tmp1, tmp2, tmp3, NULL);
if (err == MP_OKAY) if (err == MP_OKAY)
err = mp_set_int(&tmp3, e); err = mp_set_int(tmp3, e);
/* The failCount value comes from NIST FIPS 186-4, section B.3.3, /* The failCount value comes from NIST FIPS 186-4, section B.3.3,
* process steps 4.7 and 5.8. */ * process steps 4.7 and 5.8. */
@@ -4035,11 +4056,11 @@ int wc_MakeRsaKey(RsaKey* key, int size, long e, WC_RNG* rng)
/* make candidate odd */ /* make candidate odd */
buf[primeSz-1] |= 0x01; buf[primeSz-1] |= 0x01;
/* load value */ /* load value */
err = mp_read_unsigned_bin(&p, buf, primeSz); err = mp_read_unsigned_bin(p, buf, primeSz);
} }
if (err == MP_OKAY) if (err == MP_OKAY)
err = _CheckProbablePrime(&p, NULL, &tmp3, size, &isPrime, rng); err = _CheckProbablePrime(p, NULL, tmp3, size, &isPrime, rng);
#ifdef HAVE_FIPS #ifdef HAVE_FIPS
i++; i++;
@@ -4070,11 +4091,11 @@ int wc_MakeRsaKey(RsaKey* key, int size, long e, WC_RNG* rng)
/* make candidate odd */ /* make candidate odd */
buf[primeSz-1] |= 0x01; buf[primeSz-1] |= 0x01;
/* load value */ /* load value */
err = mp_read_unsigned_bin(&q, buf, primeSz); err = mp_read_unsigned_bin(q, buf, primeSz);
} }
if (err == MP_OKAY) if (err == MP_OKAY)
err = _CheckProbablePrime(&p, &q, &tmp3, size, &isPrime, rng); err = _CheckProbablePrime(p, q, tmp3, size, &isPrime, rng);
#ifdef HAVE_FIPS #ifdef HAVE_FIPS
i++; i++;
@@ -4093,12 +4114,12 @@ int wc_MakeRsaKey(RsaKey* key, int size, long e, WC_RNG* rng)
XFREE(buf, key->heap, DYNAMIC_TYPE_RSA); XFREE(buf, key->heap, DYNAMIC_TYPE_RSA);
} }
if (err == MP_OKAY && mp_cmp(&p, &q) < 0) { if (err == MP_OKAY && mp_cmp(p, q) < 0) {
err = mp_copy(&p, &tmp1); err = mp_copy(p, tmp1);
if (err == MP_OKAY) if (err == MP_OKAY)
err = mp_copy(&q, &p); err = mp_copy(q, p);
if (err == MP_OKAY) if (err == MP_OKAY)
mp_copy(&tmp1, &q); mp_copy(tmp1, q);
} }
/* Setup RsaKey buffers */ /* Setup RsaKey buffers */
@@ -4109,15 +4130,15 @@ int wc_MakeRsaKey(RsaKey* key, int size, long e, WC_RNG* rng)
/* Software Key Calculation */ /* Software Key Calculation */
if (err == MP_OKAY) /* tmp1 = p-1 */ if (err == MP_OKAY) /* tmp1 = p-1 */
err = mp_sub_d(&p, 1, &tmp1); err = mp_sub_d(p, 1, tmp1);
if (err == MP_OKAY) /* tmp2 = q-1 */ if (err == MP_OKAY) /* tmp2 = q-1 */
err = mp_sub_d(&q, 1, &tmp2); err = mp_sub_d(q, 1, tmp2);
#ifdef WC_RSA_BLINDING #ifdef WC_RSA_BLINDING
if (err == MP_OKAY) /* tmp3 = order of n */ if (err == MP_OKAY) /* tmp3 = order of n */
err = mp_mul(&tmp1, &tmp2, &tmp3); err = mp_mul(tmp1, tmp2, tmp3);
#else #else
if (err == MP_OKAY) /* tmp3 = lcm(p-1, q-1), last loop */ if (err == MP_OKAY) /* tmp3 = lcm(p-1, q-1), last loop */
err = mp_lcm(&tmp1, &tmp2, &tmp3); err = mp_lcm(tmp1, tmp2, tmp3);
#endif #endif
/* make key */ /* make key */
if (err == MP_OKAY) /* key->e = e */ if (err == MP_OKAY) /* key->e = e */
@@ -4126,13 +4147,13 @@ int wc_MakeRsaKey(RsaKey* key, int size, long e, WC_RNG* rng)
/* Blind the inverse operation with a value that is invertable */ /* Blind the inverse operation with a value that is invertable */
if (err == MP_OKAY) { if (err == MP_OKAY) {
do { do {
err = mp_rand(&key->p, get_digit_count(&tmp3), rng); err = mp_rand(&key->p, get_digit_count(tmp3), rng);
if (err == MP_OKAY) if (err == MP_OKAY)
err = mp_set_bit(&key->p, 0); err = mp_set_bit(&key->p, 0);
if (err == MP_OKAY) if (err == MP_OKAY)
err = mp_set_bit(&key->p, size - 1); err = mp_set_bit(&key->p, size - 1);
if (err == MP_OKAY) if (err == MP_OKAY)
err = mp_gcd(&key->p, &tmp3, &key->q); err = mp_gcd(&key->p, tmp3, &key->q);
} }
while ((err == MP_OKAY) && !mp_isone(&key->q)); while ((err == MP_OKAY) && !mp_isone(&key->q));
} }
@@ -4140,33 +4161,33 @@ int wc_MakeRsaKey(RsaKey* key, int size, long e, WC_RNG* rng)
err = mp_mul_d(&key->p, (mp_digit)e, &key->e); err = mp_mul_d(&key->p, (mp_digit)e, &key->e);
#endif #endif
if (err == MP_OKAY) /* key->d = 1/e mod lcm(p-1, q-1) */ if (err == MP_OKAY) /* key->d = 1/e mod lcm(p-1, q-1) */
err = mp_invmod(&key->e, &tmp3, &key->d); err = mp_invmod(&key->e, tmp3, &key->d);
#ifdef WC_RSA_BLINDING #ifdef WC_RSA_BLINDING
/* Take off blinding from d and reset e */ /* Take off blinding from d and reset e */
if (err == MP_OKAY) if (err == MP_OKAY)
err = mp_mulmod(&key->d, &key->p, &tmp3, &key->d); err = mp_mulmod(&key->d, &key->p, tmp3, &key->d);
if (err == MP_OKAY) if (err == MP_OKAY)
err = mp_set_int(&key->e, (mp_digit)e); err = mp_set_int(&key->e, (mp_digit)e);
#endif #endif
if (err == MP_OKAY) /* key->n = pq */ if (err == MP_OKAY) /* key->n = pq */
err = mp_mul(&p, &q, &key->n); err = mp_mul(p, q, &key->n);
if (err == MP_OKAY) /* key->dP = d mod(p-1) */ if (err == MP_OKAY) /* key->dP = d mod(p-1) */
err = mp_mod(&key->d, &tmp1, &key->dP); err = mp_mod(&key->d, tmp1, &key->dP);
if (err == MP_OKAY) /* key->dQ = d mod(q-1) */ if (err == MP_OKAY) /* key->dQ = d mod(q-1) */
err = mp_mod(&key->d, &tmp2, &key->dQ); err = mp_mod(&key->d, tmp2, &key->dQ);
#ifdef WOLFSSL_MP_INVMOD_CONSTANT_TIME #ifdef WOLFSSL_MP_INVMOD_CONSTANT_TIME
if (err == MP_OKAY) /* key->u = 1/q mod p */ if (err == MP_OKAY) /* key->u = 1/q mod p */
err = mp_invmod(&q, &p, &key->u); err = mp_invmod(q, p, &key->u);
#else #else
if (err == MP_OKAY) if (err == MP_OKAY)
err = mp_sub_d(&p, 2, &tmp3); err = mp_sub_d(p, 2, tmp3);
if (err == MP_OKAY) /* key->u = 1/q mod p = q^p-2 mod p */ if (err == MP_OKAY) /* key->u = 1/q mod p = q^p-2 mod p */
err = mp_exptmod(&q, &tmp3 , &p, &key->u); err = mp_exptmod(q, tmp3 , p, &key->u);
#endif #endif
if (err == MP_OKAY) if (err == MP_OKAY)
err = mp_copy(&p, &key->p); err = mp_copy(p, &key->p);
if (err == MP_OKAY) if (err == MP_OKAY)
err = mp_copy(&q, &key->q); err = mp_copy(q, &key->q);
#ifdef HAVE_WOLF_BIGINT #ifdef HAVE_WOLF_BIGINT
/* make sure raw unsigned bin version is available */ /* make sure raw unsigned bin version is available */
@@ -4191,11 +4212,11 @@ int wc_MakeRsaKey(RsaKey* key, int size, long e, WC_RNG* rng)
if (err == MP_OKAY) if (err == MP_OKAY)
key->type = RSA_PRIVATE; key->type = RSA_PRIVATE;
mp_clear(&tmp1); mp_clear(tmp1);
mp_clear(&tmp2); mp_clear(tmp2);
mp_clear(&tmp3); mp_clear(tmp3);
mp_clear(&p); mp_clear(p);
mp_clear(&q); mp_clear(q);
#if defined(WOLFSSL_KEY_GEN) && !defined(WOLFSSL_NO_RSA_KEY_CHECK) #if defined(WOLFSSL_KEY_GEN) && !defined(WOLFSSL_NO_RSA_KEY_CHECK)
/* Perform the pair-wise consistency test on the new key. */ /* Perform the pair-wise consistency test on the new key. */
@@ -4205,7 +4226,7 @@ int wc_MakeRsaKey(RsaKey* key, int size, long e, WC_RNG* rng)
if (err != 0) { if (err != 0) {
wc_FreeRsaKey(key); wc_FreeRsaKey(key);
return err; goto out;
} }
#if defined(WOLFSSL_XILINX_CRYPT) || defined(WOLFSSL_CRYPTOCELL) #if defined(WOLFSSL_XILINX_CRYPT) || defined(WOLFSSL_CRYPTOCELL)
@@ -4213,7 +4234,25 @@ int wc_MakeRsaKey(RsaKey* key, int size, long e, WC_RNG* rng)
return BAD_STATE_E; return BAD_STATE_E;
} }
#endif #endif
return 0;
err = 0;
out:
#ifdef WOLFSSL_SMALL_STACK
if (p)
XFREE(p, key->heap, DYNAMIC_TYPE_RSA);
if (q)
XFREE(q, key->heap, DYNAMIC_TYPE_RSA);
if (tmp1)
XFREE(tmp1, key->heap, DYNAMIC_TYPE_RSA);
if (tmp2)
XFREE(tmp2, key->heap, DYNAMIC_TYPE_RSA);
if (tmp3)
XFREE(tmp3, key->heap, DYNAMIC_TYPE_RSA);
#endif
return err;
#else #else
return NOT_COMPILED_IN; return NOT_COMPILED_IN;
#endif #endif