Merge pull request #5829 from SparkiDev/sp_invmod_fixes

SP int: fix error checks when modulus even
This commit is contained in:
David Garske
2022-12-01 15:05:22 -08:00
committed by GitHub
2 changed files with 77 additions and 61 deletions

View File

@@ -11433,8 +11433,8 @@ static int _sp_invmod_div(sp_int* a, sp_int* m, sp_int* x, sp_int* y, sp_int* b,
sp_int* c, sp_int* inv) sp_int* c, sp_int* inv)
{ {
int err = MP_OKAY; int err = MP_OKAY;
sp_int* d; sp_int* d = NULL;
sp_int* r; sp_int* r = NULL;
sp_int* s; sp_int* s;
#ifndef WOLFSSL_SP_INT_NEGATIVE #ifndef WOLFSSL_SP_INT_NEGATIVE
int bneg = 0; int bneg = 0;
@@ -11459,6 +11459,7 @@ static int _sp_invmod_div(sp_int* a, sp_int* m, sp_int* x, sp_int* y, sp_int* b,
while ((err == MP_OKAY) && (!sp_isone(x)) && (!sp_iszero(x))) { while ((err == MP_OKAY) && (!sp_isone(x)) && (!sp_iszero(x))) {
/* 2.1. d = x / y, r = x mod y */ /* 2.1. d = x / y, r = x mod y */
err = sp_div(x, y, d, r); err = sp_div(x, y, d, r);
if (err == MP_OKAY) {
/* 2.2. c -= d * b */ /* 2.2. c -= d * b */
if (sp_isone(d)) { if (sp_isone(d)) {
/* c -= 1 * b */ /* c -= 1 * b */
@@ -11477,6 +11478,7 @@ static int _sp_invmod_div(sp_int* a, sp_int* m, sp_int* x, sp_int* y, sp_int* b,
/* 2.4. s = b, b = c, c = s */ /* 2.4. s = b, b = c, c = s */
s = b; b = c; c = s; s = b; b = c; c = s;
} }
}
/* 3. If y != 0 then NO_INVERSE */ /* 3. If y != 0 then NO_INVERSE */
if ((err == MP_OKAY) && (!sp_iszero(y))) { if ((err == MP_OKAY) && (!sp_iszero(y))) {
err = MP_VAL; err = MP_VAL;
@@ -11494,6 +11496,7 @@ static int _sp_invmod_div(sp_int* a, sp_int* m, sp_int* x, sp_int* y, sp_int* b,
while ((err == MP_OKAY) && (!sp_isone(x)) && (!sp_iszero(x))) { while ((err == MP_OKAY) && (!sp_isone(x)) && (!sp_iszero(x))) {
/* 2.1. d = x / y, r = x mod y */ /* 2.1. d = x / y, r = x mod y */
err = sp_div(x, y, d, r); err = sp_div(x, y, d, r);
if (err == MP_OKAY) {
if (sp_isone(d)) { if (sp_isone(d)) {
/* c -= 1 * b */ /* c -= 1 * b */
if ((bneg ^ cneg) == 1) { if ((bneg ^ cneg) == 1) {
@@ -11520,7 +11523,8 @@ static int _sp_invmod_div(sp_int* a, sp_int* m, sp_int* x, sp_int* y, sp_int* b,
_sp_add_off(c, d, c, 0); _sp_add_off(c, d, c, 0);
} }
else if (_sp_cmp_abs(c, d) == MP_LT) { else if (_sp_cmp_abs(c, d) == MP_LT) {
/* |c| < |d| and same sign, reverse subtract and negate. */ /* |c| < |d| and same sign, reverse subtract and negate.
*/
_sp_sub_off(d, c, c, 0); _sp_sub_off(d, c, c, 0);
cneg = !cneg; cneg = !cneg;
} }
@@ -11535,6 +11539,7 @@ static int _sp_invmod_div(sp_int* a, sp_int* m, sp_int* x, sp_int* y, sp_int* b,
s = b; b = c; c = s; s = b; b = c; c = s;
neg = bneg; bneg = cneg; cneg = neg; neg = bneg; bneg = cneg; cneg = neg;
} }
}
/* 3. If y != 0 then NO_INVERSE */ /* 3. If y != 0 then NO_INVERSE */
if ((err == MP_OKAY) && (!sp_iszero(y))) { if ((err == MP_OKAY) && (!sp_iszero(y))) {
err = MP_VAL; err = MP_VAL;
@@ -11654,12 +11659,12 @@ int sp_invmod(sp_int* a, sp_int* m, sp_int* r)
sp_mod(m, a, v); sp_mod(m, a, v);
/* v == 0 when a divides m evenly - no inverse. */ /* v == 0 when a divides m evenly - no inverse. */
if (sp_iszero(v)) { if (sp_iszero(v)) {
/* Force u to be the no inverse answer. */ err = MP_VAL;
sp_set(u, 0);
} }
evenMod = 1; evenMod = 1;
} }
if (err == MP_OKAY) {
/* Calculate inverse. */ /* Calculate inverse. */
#if !defined(WOLFSSL_SP_SMALL) && (!defined(NO_RSA) || !defined(NO_DH)) #if !defined(WOLFSSL_SP_SMALL) && (!defined(NO_RSA) || !defined(NO_DH))
if (sp_count_bits(mm) >= 1024) { if (sp_count_bits(mm) >= 1024) {
@@ -11670,6 +11675,7 @@ int sp_invmod(sp_int* a, sp_int* m, sp_int* r)
{ {
err = _sp_invmod(ma, mm, u, v, b, c); err = _sp_invmod(ma, mm, u, v, b, c);
} }
}
/* Fixup for even modulus. */ /* Fixup for even modulus. */
if ((err == MP_OKAY) && evenMod) { if ((err == MP_OKAY) && evenMod) {
@@ -11686,7 +11692,7 @@ int sp_invmod(sp_int* a, sp_int* m, sp_int* r)
sp_sub(m, c, r); sp_sub(m, c, r);
} }
} }
else { else if (err == MP_OKAY) {
err = sp_copy(c, r); err = sp_copy(c, r);
} }
} }

View File

@@ -43216,6 +43216,16 @@ static int mp_test_invmod(mp_int* a, mp_int* m, mp_int* r)
ret = mp_invmod(a, m, r); ret = mp_invmod(a, m, r);
if (ret != MP_VAL) if (ret != MP_VAL)
return -13172; return -13172;
mp_set(a, 3);
mp_set(m, 6);
ret = mp_invmod(a, m, r);
if (ret != MP_VAL)
return -13181;
mp_set(a, 5*9);
mp_set(m, 6*9);
ret = mp_invmod(a, m, r);
if (ret != MP_VAL)
return -13182;
mp_set(a, 1); mp_set(a, 1);
mp_set(m, 4); mp_set(m, 4);
ret = mp_invmod(a, m, r); ret = mp_invmod(a, m, r);