Merge pull request #5633 from SparkiDev/sp_mod_fix

SP int all: sp_mod
This commit is contained in:
JacobBarthelmeh
2022-09-27 10:23:03 -06:00
committed by GitHub

View File

@ -4362,6 +4362,23 @@ static void _sp_zero(sp_int* a)
}
/* Initialize the multi-precision number to be zero with a given max size.
*
* @param [out] a SP integer.
* @param [in] size Number of words to say are available.
*/
static void _sp_init_size(sp_int* a, int size)
{
volatile sp_int_minimal* am = (sp_int_minimal *)a;
#ifdef HAVE_WOLF_BIGINT
wc_bigint_init((struct WC_BIGINT*)&am->raw);
#endif
_sp_zero((sp_int*)am);
am->size = size;
}
/* Initialize the multi-precision number to be zero with a given max size.
*
* @param [out] a SP integer.
@ -4379,14 +4396,7 @@ int sp_init_size(sp_int* a, int size)
}
if (err == MP_OKAY) {
volatile sp_int_minimal* am = (sp_int_minimal *)a;
#ifdef HAVE_WOLF_BIGINT
wc_bigint_init((struct WC_BIGINT*)&am->raw);
#endif
_sp_zero(a);
am->size = size;
_sp_init_size(a, size);
}
return err;
@ -4401,7 +4411,16 @@ int sp_init_size(sp_int* a, int size)
*/
int sp_init(sp_int* a)
{
return sp_init_size(a, SP_INT_DIGITS);
int err = MP_OKAY;
if (a == NULL) {
err = MP_VAL;
}
else {
_sp_init_size(a, SP_INT_DIGITS);
}
return err;
}
#if !defined(WOLFSSL_RSA_PUBLIC_ONLY) || !defined(NO_DH) || defined(HAVE_ECC)
@ -4420,70 +4439,22 @@ int sp_init_multi(sp_int* n1, sp_int* n2, sp_int* n3, sp_int* n4, sp_int* n5,
sp_int* n6)
{
if (n1 != NULL) {
#ifdef HAVE_WOLF_BIGINT
wc_bigint_init(&n1->raw);
#endif
_sp_zero(n1);
n1->dp[0] = 0;
n1->size = SP_INT_DIGITS;
#ifdef HAVE_WOLF_BIGINT
wc_bigint_init(&n1->raw);
#endif
_sp_init_size(n1, SP_INT_DIGITS);
}
if (n2 != NULL) {
#ifdef HAVE_WOLF_BIGINT
wc_bigint_init(&n2->raw);
#endif
_sp_zero(n2);
n2->dp[0] = 0;
n2->size = SP_INT_DIGITS;
#ifdef HAVE_WOLF_BIGINT
wc_bigint_init(&n2->raw);
#endif
_sp_init_size(n2, SP_INT_DIGITS);
}
if (n3 != NULL) {
#ifdef HAVE_WOLF_BIGINT
wc_bigint_init(&n3->raw);
#endif
_sp_zero(n3);
n3->dp[0] = 0;
n3->size = SP_INT_DIGITS;
#ifdef HAVE_WOLF_BIGINT
wc_bigint_init(&n3->raw);
#endif
_sp_init_size(n3, SP_INT_DIGITS);
}
if (n4 != NULL) {
#ifdef HAVE_WOLF_BIGINT
wc_bigint_init(&n4->raw);
#endif
_sp_zero(n4);
n4->dp[0] = 0;
n4->size = SP_INT_DIGITS;
#ifdef HAVE_WOLF_BIGINT
wc_bigint_init(&n4->raw);
#endif
_sp_init_size(n4, SP_INT_DIGITS);
}
if (n5 != NULL) {
#ifdef HAVE_WOLF_BIGINT
wc_bigint_init(&n5->raw);
#endif
_sp_zero(n5);
n5->dp[0] = 0;
n5->size = SP_INT_DIGITS;
#ifdef HAVE_WOLF_BIGINT
wc_bigint_init(&n5->raw);
#endif
_sp_init_size(n5, SP_INT_DIGITS);
}
if (n6 != NULL) {
#ifdef HAVE_WOLF_BIGINT
wc_bigint_init(&n6->raw);
#endif
_sp_zero(n6);
n6->dp[0] = 0;
n6->size = SP_INT_DIGITS;
#ifdef HAVE_WOLF_BIGINT
wc_bigint_init(&n6->raw);
#endif
_sp_init_size(n6, SP_INT_DIGITS);
}
return MP_OKAY;
@ -7354,19 +7325,19 @@ int sp_div(sp_int* a, sp_int* d, sp_int* r, sp_int* rem)
sd = td[0];
trial = td[1];
sp_init_size(sd, d->used + 1);
sp_init_size(trial, a->used + 1);
_sp_init_size(sd, d->used + 1);
_sp_init_size(trial, a->used + 1);
#if (defined(WOLFSSL_SMALL_STACK) || defined(SP_ALLOC)) && \
!defined(WOLFSSL_SP_NO_MALLOC)
if (sa != rem) {
sp_init_size(sa, a->used + 1);
_sp_init_size(sa, a->used + 1);
}
if (tr != r) {
sp_init_size(tr, a->used - d->used + 2);
_sp_init_size(tr, a->used - d->used + 2);
}
#else
sp_init_size(sa, a->used + 1);
sp_init_size(tr, a->used - d->used + 2);
_sp_init_size(sa, a->used + 1);
_sp_init_size(tr, a->used - d->used + 2);
#endif
/* Move divisor to top of word. Adjust dividend as well. */
@ -7457,6 +7428,9 @@ int sp_mod(sp_int* a, sp_int* m, sp_int* r)
if ((a == NULL) || (m == NULL) || (r == NULL)) {
err = MP_VAL;
}
else if (a->used >= SP_INT_DIGITS) {
err = MP_VAL;
}
#ifndef WOLFSSL_SP_INT_NEGATIVE
if (err == MP_OKAY) {
@ -7465,7 +7439,7 @@ int sp_mod(sp_int* a, sp_int* m, sp_int* r)
#else
ALLOC_SP_INT(t, a->used + 1, err, NULL);
if (err == MP_OKAY) {
sp_init_size(t, a->used + 1);
_sp_init_size(t, a->used + 1);
err = sp_div(a, m, NULL, t);
}
if (err == MP_OKAY) {
@ -10732,13 +10706,13 @@ int sp_invmod_mont_ct(sp_int* a, sp_int* m, sp_int* r, sp_int_digit mp)
if (err == MP_OKAY) {
t = pre[CT_INV_MOD_PRE_CNT + 0];
e = pre[CT_INV_MOD_PRE_CNT + 1];
sp_init_size(t, m->used * 2 + 1);
sp_init_size(e, m->used * 2 + 1);
_sp_init_size(t, m->used * 2 + 1);
_sp_init_size(e, m->used * 2 + 1);
sp_init_size(pre[0], m->used * 2 + 1);
_sp_init_size(pre[0], m->used * 2 + 1);
err = sp_copy(a, pre[0]);
for (i = 1; (err == MP_OKAY) && (i < CT_INV_MOD_PRE_CNT); i++) {
sp_init_size(pre[i], m->used * 2 + 1);
_sp_init_size(pre[i], m->used * 2 + 1);
err = sp_sqr(pre[i-1], pre[i]);
if (err == MP_OKAY) {
err = _sp_mont_red(pre[i], m, mp);
@ -10839,10 +10813,10 @@ static int _sp_exptmod_ex(sp_int* b, sp_int* e, int bits, sp_int* m, sp_int* r)
ALLOC_SP_INT_ARRAY(t, 2 * m->used + 1, 3, err, NULL);
#endif
if (err == MP_OKAY) {
sp_init_size(t[0], 2 * m->used + 1);
sp_init_size(t[1], 2 * m->used + 1);
_sp_init_size(t[0], 2 * m->used + 1);
_sp_init_size(t[1], 2 * m->used + 1);
#ifndef WC_NO_CACHE_RESISTANT
sp_init_size(t[2], 2 * m->used + 1);
_sp_init_size(t[2], 2 * m->used + 1);
#endif
/* Ensure base is less than exponent. */
@ -10940,10 +10914,10 @@ static int _sp_exptmod_mont_ex(sp_int* b, sp_int* e, int bits, sp_int* m,
ALLOC_SP_INT_ARRAY(t, m->used * 2 + 1, 4, err, NULL);
if (err == MP_OKAY) {
sp_init_size(t[0], m->used * 2 + 1);
sp_init_size(t[1], m->used * 2 + 1);
sp_init_size(t[2], m->used * 2 + 1);
sp_init_size(t[3], m->used * 2 + 1);
_sp_init_size(t[0], m->used * 2 + 1);
_sp_init_size(t[1], m->used * 2 + 1);
_sp_init_size(t[2], m->used * 2 + 1);
_sp_init_size(t[3], m->used * 2 + 1);
/* Ensure base is less than exponent. */
if (_sp_cmp_abs(b, m) != MP_LT) {
@ -11076,9 +11050,9 @@ static int _sp_exptmod_mont_ex(sp_int* b, sp_int* e, int bits, sp_int* m,
tr = t[preCnt];
for (i = 0; i < preCnt; i++) {
sp_init_size(t[i], m->used * 2 + 1);
_sp_init_size(t[i], m->used * 2 + 1);
}
sp_init_size(tr, m->used * 2 + 1);
_sp_init_size(tr, m->used * 2 + 1);
/* Ensure base is less than exponent. */
if (_sp_cmp_abs(b, m) != MP_LT) {
@ -11240,8 +11214,8 @@ static int _sp_exptmod_base_2(sp_int* e, int digits, sp_int* m, sp_int* r)
t = d[0];
tr = d[1];
sp_init_size(t, m->used * 2 + 1);
sp_init_size(tr, m->used * 2 + 1);
_sp_init_size(t, m->used * 2 + 1);
_sp_init_size(tr, m->used * 2 + 1);
if (m->used > 1) {
err = sp_mont_setup(m, &mp);
@ -11624,10 +11598,10 @@ static int _sp_exptmod_nct(sp_int* b, sp_int* e, sp_int* m, sp_int* r)
bm = t[preCnt + 1];
for (i = 0; i < preCnt; i++) {
sp_init_size(t[i], m->used * 2 + 1);
_sp_init_size(t[i], m->used * 2 + 1);
}
sp_init_size(tr, m->used * 2 + 1);
sp_init_size(bm, m->used * 2 + 1);
_sp_init_size(tr, m->used * 2 + 1);
_sp_init_size(bm, m->used * 2 + 1);
/* Ensure base is less than exponent. */
if (_sp_cmp_abs(b, m) != MP_LT) {
@ -11839,8 +11813,8 @@ static int _sp_exptmod_nct(sp_int* b, sp_int* e, sp_int* m, sp_int* r)
ALLOC_SP_INT_ARRAY(t, m->used * 2 + 1, 2, err, NULL);
if (err == MP_OKAY) {
sp_init_size(t[0], m->used * 2 + 1);
sp_init_size(t[1], m->used * 2 + 1);
_sp_init_size(t[0], m->used * 2 + 1);
_sp_init_size(t[1], m->used * 2 + 1);
/* Ensure base is less than exponent. */
if (_sp_cmp_abs(b, m) != MP_LT) {
@ -16014,9 +15988,9 @@ static int sp_prime_miller_rabin(sp_int* a, sp_int* b, int* result)
r = t[2];
/* Only 'y' needs to be twice as big. */
sp_init_size(n1, a->used * 2 + 1);
sp_init_size(y, a->used * 2 + 1);
sp_init_size(r, a->used * 2 + 1);
_sp_init_size(n1, a->used * 2 + 1);
_sp_init_size(y, a->used * 2 + 1);
_sp_init_size(r, a->used * 2 + 1);
err = sp_prime_miller_rabin_ex(a, b, result, n1, y, r);
@ -16153,7 +16127,7 @@ int sp_prime_is_prime(sp_int* a, int t, int* result)
ALLOC_SP_INT(b, 1, err, NULL);
if (err == MP_OKAY) {
/* now do 't' miller rabins */
sp_init_size(b, 1);
_sp_init_size(b, 1);
for (i = 0; i < t; i++) {
sp_set(b, sp_primes[i]);
err = sp_prime_miller_rabin(a, b, result);
@ -16259,11 +16233,11 @@ int sp_prime_is_prime_ex(sp_int* a, int t, int* result, WC_RNG* rng)
r = d[1];
/* Only 'y' needs to be twice as big. */
sp_init_size(b , a->used + 1);
sp_init_size(c , a->used + 1);
sp_init_size(n1, a->used + 1);
sp_init_size(y , a->used * 2 + 1);
sp_init_size(r , a->used * 2 + 1);
_sp_init_size(b , a->used + 1);
_sp_init_size(c , a->used + 1);
_sp_init_size(n1, a->used + 1);
_sp_init_size(y , a->used * 2 + 1);
_sp_init_size(r , a->used * 2 + 1);
_sp_sub_d(a, 2, c);
@ -16484,8 +16458,8 @@ int sp_lcm(sp_int* a, sp_int* b, sp_int* r)
ALLOC_SP_INT_ARRAY(t, used, 2, err, NULL);
if (err == MP_OKAY) {
sp_init_size(t[0], used);
sp_init_size(t[1], used);
_sp_init_size(t[0], used);
_sp_init_size(t[1], used);
SAVE_VECTOR_REGISTERS(err = _svr_ret;);