Merge pull request #2612 from SparkiDev/sp_div_small_a

sp_div improved to handle when a has less digits than d
This commit is contained in:
toddouska
2019-12-05 16:14:05 -08:00
committed by GitHub

View File

@ -179,6 +179,7 @@ int sp_read_unsigned_bin(sp_int* a, const byte* in, word32 inSz)
for (j++; j < a->size; j++) for (j++; j < a->size; j++)
a->dp[j] = 0; a->dp[j] = 0;
sp_clamp(a);
return MP_OKAY; return MP_OKAY;
} }
@ -234,6 +235,7 @@ int sp_read_radix(sp_int* a, const char* in, int radix)
for (k++; k < a->size; k++) for (k++; k < a->size; k++)
a->dp[k] = 0; a->dp[k] = 0;
} }
sp_clamp(a);
return err; return err;
} }
@ -402,13 +404,13 @@ int sp_copy(sp_int* a, sp_int* r)
/* creates "a" then copies b into it */ /* creates "a" then copies b into it */
int sp_init_copy (sp_int * a, sp_int * b) int sp_init_copy (sp_int * a, sp_int * b)
{ {
int res; int err;
if ((res = sp_init(a)) == MP_OKAY) { if ((err = sp_init(a)) == MP_OKAY) {
if((res = sp_copy (b, a)) != MP_OKAY) { if((err = sp_copy (b, a)) != MP_OKAY) {
sp_clear(a); sp_clear(a);
} }
} }
return res; return err;
} }
#endif #endif
@ -420,8 +422,14 @@ int sp_init_copy (sp_int * a, sp_int * b)
*/ */
int sp_set(sp_int* a, sp_int_digit d) int sp_set(sp_int* a, sp_int_digit d)
{ {
a->dp[0] = d; if (d == 0) {
a->used = 1; a->dp[0] = d;
a->used = 0;
}
else {
a->dp[0] = d;
a->used = 1;
}
return MP_OKAY; return MP_OKAY;
} }
@ -479,6 +487,7 @@ int sp_sub_d(sp_int* a, sp_int_digit d, sp_int* r)
} }
for (++i; i < a->used; i++) for (++i; i < a->used; i++)
r->dp[i] = a->dp[i]; r->dp[i] = a->dp[i];
sp_clamp(r);
return MP_OKAY; return MP_OKAY;
} }
@ -635,6 +644,8 @@ static void _sp_mul_d(sp_int* a, sp_int_digit n, sp_int* r, int o)
static int sp_div(sp_int* a, sp_int* d, sp_int* r, sp_int* rem) static int sp_div(sp_int* a, sp_int* d, sp_int* r, sp_int* rem)
{ {
int err = MP_OKAY; int err = MP_OKAY;
int ret;
int done = 0;
int i; int i;
int s; int s;
sp_int_word w = 0; sp_int_word w = 0;
@ -655,8 +666,38 @@ static int sp_div(sp_int* a, sp_int* d, sp_int* r, sp_int* rem)
if (sp_iszero(d)) if (sp_iszero(d))
err = MP_VAL; err = MP_VAL;
ret = sp_cmp(a, d);
if (ret == MP_LT) {
if (rem != NULL) {
sp_copy(a, rem);
}
if (r != NULL) {
sp_set(r, 0);
}
done = 1;
}
else if (ret == MP_EQ) {
if (rem != NULL) {
sp_set(rem, 0);
}
if (r != NULL) {
sp_set(r, 1);
}
done = 1;
}
else if (sp_count_bits(a) == sp_count_bits(d)) {
/* a is greater than d but same bit length */
if (rem != NULL) {
sp_sub(a, d, rem);
}
if (r != NULL) {
sp_set(r, 1);
}
done = 1;
}
#ifdef WOLFSSL_SMALL_STACK #ifdef WOLFSSL_SMALL_STACK
if (err == MP_OKAY) { if (!done && err == MP_OKAY) {
sa = (sp_int*)XMALLOC(sizeof(sp_int) * 4, NULL, DYNAMIC_TYPE_BIGINT); sa = (sp_int*)XMALLOC(sizeof(sp_int) * 4, NULL, DYNAMIC_TYPE_BIGINT);
if (sa == NULL) if (sa == NULL)
err = MP_MEM; err = MP_MEM;
@ -668,7 +709,7 @@ static int sp_div(sp_int* a, sp_int* d, sp_int* r, sp_int* rem)
} }
#endif #endif
if (err == MP_OKAY) { if (!done && err == MP_OKAY) {
sp_init(sa); sp_init(sa);
sp_init(sd); sp_init(sd);
sp_init(tr); sp_init(tr);
@ -685,26 +726,30 @@ static int sp_div(sp_int* a, sp_int* d, sp_int* r, sp_int* rem)
} }
tr->used = sa->used - d->used; tr->used = sa->used - d->used;
if ((sa->dp[sa->used-1] >> (SP_WORD_SIZE - 1)) == 1) { sp_clear(tr);
_sp_mul_d(d, 1, trial, sa->used - d->used); tr->used = sa->used - d->used;
if (sp_cmp(sa, trial) != MP_LT) {
tr->used++;
sp_sub(sa, trial, sa);
tr->dp[sa->used - d->used] = 1;
}
}
dt = d->dp[d->used-1]; dt = d->dp[d->used-1];
for (i = sa->used - 1; i >= d->used; i--) { for (i = sa->used - 1; i >= d->used; i--) {
w = ((sp_int_word)sa->dp[i] << SP_WORD_SIZE) | sa->dp[i-1]; w = ((sp_int_word)sa->dp[i] << SP_WORD_SIZE) | sa->dp[i-1];
t = (sp_int_digit)(w / dt); w /= dt;
_sp_mul_d(d, t, trial, i - d->used); if (w > (sp_int_digit)-1) {
while (sp_cmp(trial, sa) == MP_GT) { t = (sp_int_digit)-1;
t--; }
_sp_mul_d(d, t, trial, i - d->used); else {
t = (sp_int_digit)w;
}
if (t > 0) {
_sp_mul_d(d, t, trial, i - d->used);
while (sp_cmp(trial, sa) == MP_GT) {
t--;
_sp_mul_d(d, t, trial, i - d->used);
}
sp_sub(sa, trial, sa);
tr->dp[i - d->used] += t;
if (w > (sp_int_digit)-1) {
i++;
}
} }
sp_sub(sa, trial, sa);
tr->dp[i - d->used] = t;
} }
sp_clamp(tr); sp_clamp(tr);
@ -800,6 +845,7 @@ int sp_lshd(sp_int* a, int s)
XMEMMOVE(a->dp + s, a->dp, a->used * sizeof(sp_int_digit)); XMEMMOVE(a->dp + s, a->dp, a->used * sizeof(sp_int_digit));
a->used += s; a->used += s;
XMEMSET(a->dp, 0, s * sizeof(sp_int_digit)); XMEMSET(a->dp, 0, s * sizeof(sp_int_digit));
sp_clamp(a);
return MP_OKAY; return MP_OKAY;
} }
@ -1268,8 +1314,9 @@ int sp_invmod(sp_int* a, sp_int* m, sp_int* r)
sp_int u[1], v[1], t[1], b[1], c[1]; sp_int u[1], v[1], t[1], b[1], c[1];
#endif #endif
if (sp_iszero(a) || sp_iszero(m)) if (sp_iszero(a) || sp_iszero(m)) {
err = MP_VAL; err = MP_VAL;
}
else if (sp_iseven(m)) { else if (sp_iseven(m)) {
/* a^-1 mod m = m + (1 - m*(m^-1 % a)) / a /* a^-1 mod m = m + (1 - m*(m^-1 % a)) / a
* = m - (m*(m^-1 % a) - 1) / a * = m - (m*(m^-1 % a) - 1) / a
@ -1882,7 +1929,7 @@ int sp_prime_is_prime_ex(sp_int* a, int t, int* result, WC_RNG* rng)
} }
#ifndef NO_DH #ifndef NO_DH
int sp_exch (sp_int* a, sp_int* b) int sp_exch(sp_int* a, sp_int* b)
{ {
int err = MP_OKAY; int err = MP_OKAY;
#ifndef WOLFSSL_SMALL_STACK #ifndef WOLFSSL_SMALL_STACK