Fix sp_invmod to handle more input values

This commit is contained in:
Sean Parkinson
2019-11-27 09:56:33 +10:00
parent 245a2b7012
commit 204045223f

View File

@ -475,18 +475,22 @@ int sp_grow(sp_int* a, int l)
int sp_sub_d(sp_int* a, sp_int_digit d, sp_int* r)
{
int i = 0;
sp_int_digit t;
r->used = a->used;
r->dp[0] = a->dp[0] - d;
if (r->dp[0] > a->dp[0]) {
t = a->dp[0] - d;
if (t > a->dp[0]) {
for (++i; i < a->used; i++) {
r->dp[i] = a->dp[i] - 1;
if (r->dp[i] != (sp_int_digit)-1)
break;
}
}
for (++i; i < a->used; i++)
r->dp[i] = a->dp[i];
r->dp[0] = t;
if (r != a) {
for (++i; i < a->used; i++)
r->dp[i] = a->dp[i];
}
sp_clamp(r);
return MP_OKAY;
@ -578,7 +582,7 @@ int sp_sub(sp_int* a, sp_int* b, sp_int* r)
}
for (; i < a->used; i++) {
r->dp[i] = a->dp[i] - c;
c = (a->dp[i] == 0) && (r->dp[i] == (sp_int_digit)-1);
c &= (r->dp[i] == (sp_int_digit)-1);
}
r->used = i;
sp_clamp(r);
@ -725,9 +729,9 @@ static int sp_div(sp_int* a, sp_int* d, sp_int* r, sp_int* rem)
d = sd;
}
tr->used = sa->used - d->used;
tr->used = sa->used - d->used + 1;
sp_clear(tr);
tr->used = sa->used - d->used;
tr->used = sa->used - d->used + 1;
dt = d->dp[d->used-1];
for (i = sa->used - 1; i >= d->used; i--) {
w = ((sp_int_word)sa->dp[i] << SP_WORD_SIZE) | sa->dp[i-1];
@ -746,6 +750,8 @@ static int sp_div(sp_int* a, sp_int* d, sp_int* r, sp_int* rem)
}
sp_sub(sa, trial, sa);
tr->dp[i - d->used] += t;
if (tr->dp[i - d->used] < t)
tr->dp[i + 1 - d->used]++;
if (w > (sp_int_digit)-1) {
i++;
}
@ -809,6 +815,9 @@ int sp_add_d(sp_int* a, sp_int_digit d, sp_int* r)
int i = 0;
r->used = a->used;
if (a->used == 0) {
r->used = 1;
}
r->dp[0] = a->dp[0] + d;
if (r->dp[i] < a->dp[i]) {
for (; i < a->used; i++) {
@ -1294,9 +1303,10 @@ static int sp_div_2(sp_int* a, sp_int* r)
}
/* Divides a by 2 and stores in r: r = a >> 1
/* Calculates the multiplicative inverse in the field.
*
* a SP integer to divide.
* a SP integer to invert.
* m SP integer that is the modulus of the field.
* r SP integer result.
* returns MP_VAL when a or m is 0, MP_MEM when dynamic memory allocation fails
* and MP_OKAY otherwise.
@ -1307,16 +1317,47 @@ int sp_invmod(sp_int* a, sp_int* m, sp_int* r)
#ifdef WOLFSSL_SMALL_STACK
sp_int* u = NULL;
sp_int* v;
sp_int* t;
sp_int* b;
sp_int* c;
#else
sp_int u[1], v[1], t[1], b[1], c[1];
sp_int u[1], v[1], b[1], c[1];
#endif
if (sp_iszero(a) || sp_iszero(m)) {
#ifdef WOLFSSL_SMALL_STACK
u = (sp_int*)XMALLOC(sizeof(sp_int) * 4, NULL, DYNAMIC_TYPE_BIGINT);
if (u == NULL) {
err = MP_MEM;
}
else {
v = &u[1];
b = &u[2];
c = &u[3];
}
#endif
sp_init(v);
if ((err == MP_OKAY) && (sp_cmp(a, m) != MP_LT)) {
err = sp_mod(a, m, v);
a = v;
}
/* 0 != n*m + 1 (+ve m), r*a mod 0 is always 0 (never 1) */
if ((err == MP_OKAY) && (sp_iszero(a) || sp_iszero(m))) {
err = MP_VAL;
}
/* r*2*x != n*2*y + 1 */
if ((err == MP_OKAY) && sp_iseven(a) && sp_iseven(m)) {
err = MP_VAL;
}
/* 1*1 = 0*m + 1 */
if ((err == MP_OKAY) && sp_isone(a)) {
sp_set(r, 1);
}
else if (err != MP_OKAY) {
}
else if (sp_iseven(m)) {
/* a^-1 mod m = m + (1 - m*(m^-1 % a)) / a
* = m - (m*(m^-1 % a) - 1) / a
@ -1330,22 +1371,8 @@ int sp_invmod(sp_int* a, sp_int* m, sp_int* r)
}
}
else {
#ifdef WOLFSSL_SMALL_STACK
u = (sp_int*)XMALLOC(sizeof(sp_int) * 5, NULL, DYNAMIC_TYPE_BIGINT);
if (u == NULL)
err = MP_MEM;
else {
v = &u[1];
t = &u[2];
b = &u[3];
c = &u[4];
}
#endif
if (err == MP_OKAY) {
sp_init(u);
sp_init(v);
sp_init(t);
sp_init(b);
sp_init(c);
@ -1354,39 +1381,49 @@ int sp_invmod(sp_int* a, sp_int* m, sp_int* r)
sp_zero(b);
sp_set(c, 1);
while (!sp_isone(v)) {
while (!sp_isone(v) && !sp_iszero(u)) {
if (sp_iseven(u)) {
sp_div_2(u, u);
if (sp_isodd(b))
if (sp_isodd(b)) {
sp_add(b, m, b);
}
sp_div_2(b, b);
}
else if (sp_iseven(v)) {
sp_div_2(v, v);
if (sp_isodd(c))
if (sp_isodd(c)) {
sp_add(c, m, c);
}
sp_div_2(c, c);
}
else if (sp_cmp(u, v) != MP_LT) {
sp_sub(u, v, u);
if (sp_cmp(b, c) == MP_LT)
if (sp_cmp(b, c) == MP_LT) {
sp_add(b, m, b);
}
sp_sub(b, c, b);
}
else {
sp_sub(v, u, v);
if (sp_cmp(c, b) == MP_LT)
if (sp_cmp(c, b) == MP_LT) {
sp_add(c, m, c);
}
sp_sub(c, b, c);
}
}
sp_copy(c, r);
if (sp_iszero(u)) {
err = MP_VAL;
}
else {
sp_copy(c, r);
}
}
}
#ifdef WOLFSSL_SMALL_STACK
if (u != NULL)
if (u != NULL) {
XFREE(u, NULL, DYNAMIC_TYPE_BIGINT);
}
#endif
return err;