Merge pull request #4789 from SparkiDev/sp_invmod_oob

SP int: sp_modinv fixes for sizes
This commit is contained in:
David Garske
2022-01-24 09:08:08 -08:00
committed by GitHub
2 changed files with 4498 additions and 4466 deletions

View File

@@ -6722,8 +6722,8 @@ int sp_mod(sp_int* a, sp_int* m, sp_int* r)
#endif #endif
#if defined(WOLFSSL_SMALL_STACK) && !defined(WOLFSSL_SP_NO_MALLOC) #if defined(WOLFSSL_SMALL_STACK) && !defined(WOLFSSL_SP_NO_MALLOC)
t = (sp_int_digit*)XMALLOC(sizeof(sp_int_digit) * (a->used + b->used), t = (sp_int_digit*)XMALLOC(sizeof(sp_int_digit) * (a->used + b->used), NULL,
NULL, DYNAMIC_TYPE_BIGINT); DYNAMIC_TYPE_BIGINT);
if (t == NULL) { if (t == NULL) {
err = MP_MEM; err = MP_MEM;
} }
@@ -6800,8 +6800,8 @@ int sp_mod(sp_int* a, sp_int* m, sp_int* r)
#endif #endif
#if defined(WOLFSSL_SMALL_STACK) && !defined(WOLFSSL_SP_NO_MALLOC) #if defined(WOLFSSL_SMALL_STACK) && !defined(WOLFSSL_SP_NO_MALLOC)
t = (sp_int_digit*)XMALLOC(sizeof(sp_int_digit) * (a->used + b->used), t = (sp_int_digit*)XMALLOC(sizeof(sp_int_digit) * (a->used + b->used), NULL,
NULL, DYNAMIC_TYPE_BIGINT); DYNAMIC_TYPE_BIGINT);
if (t == NULL) { if (t == NULL) {
err = MP_MEM; err = MP_MEM;
} }
@@ -9637,14 +9637,17 @@ int sp_invmod(sp_int* a, sp_int* m, sp_int* r)
sp_int* u = NULL; sp_int* u = NULL;
sp_int* v = NULL; sp_int* v = NULL;
sp_int* b = NULL; sp_int* b = NULL;
sp_int* c = NULL;
sp_int* mm; sp_int* mm;
int evenMod = 0; int evenMod = 0;
DECL_SP_INT_ARRAY(t, (m == NULL) ? 1 : (m->used + 1), 4); DECL_SP_INT_ARRAY(t, (m == NULL) ? 1 : (m->used + 1), 3);
DECL_SP_INT(c, (m == NULL) ? 1 : (2 * m->used + 1));
if ((a == NULL) || (m == NULL) || (r == NULL) || (r == m)) { if ((a == NULL) || (m == NULL) || (r == NULL) || (r == m)) {
err = MP_VAL; err = MP_VAL;
} }
if ((err == MP_OKAY) && (m->used * 2 > r->size)) {
err = MP_VAL;
}
#ifdef WOLFSSL_SP_INT_NEGATIVE #ifdef WOLFSSL_SP_INT_NEGATIVE
if ((err == MP_OKAY) && (m->sign == MP_NEG)) { if ((err == MP_OKAY) && (m->sign == MP_NEG)) {
@@ -9652,12 +9655,13 @@ int sp_invmod(sp_int* a, sp_int* m, sp_int* r)
} }
#endif #endif
ALLOC_SP_INT_ARRAY(t, m->used + 1, 4, err, NULL); ALLOC_SP_INT_ARRAY(t, m->used + 1, 3, err, NULL);
ALLOC_SP_INT(c, 2 * m->used + 1, err, NULL);
if (err == MP_OKAY) { if (err == MP_OKAY) {
u = t[0]; u = t[0];
v = t[1]; v = t[1];
b = t[2]; b = t[2];
c = t[3]; /* c allocated separately and larger for even mod case. */
if (_sp_cmp_abs(a, m) != MP_LT) { if (_sp_cmp_abs(a, m) != MP_LT) {
err = sp_mod(a, m, r); err = sp_mod(a, m, r);
@@ -9690,9 +9694,9 @@ int sp_invmod(sp_int* a, sp_int* m, sp_int* r)
} }
else { else {
sp_init_size(u, m->used + 1); sp_init_size(u, m->used + 1);
sp_init_size(v, 2*m->used + 1); sp_init_size(v, m->used + 1);
sp_init_size(b, m->used + 1); sp_init_size(b, m->used + 1);
sp_init_size(c, m->used + 1); sp_init_size(c, 2 * m->used + 1);
if (sp_iseven(m)) { if (sp_iseven(m)) {
/* a^-1 mod m = m + ((1 - m*(m^-1 % a)) / a) */ /* a^-1 mod m = m + ((1 - m*(m^-1 % a)) / a) */
@@ -9752,13 +9756,13 @@ int sp_invmod(sp_int* a, sp_int* m, sp_int* r)
* a^-1 mod m = m + ((1 - m*c) / a) * a^-1 mod m = m + ((1 - m*c) / a)
* => a^-1 mod m = m - ((m*c - 1) / a) * => a^-1 mod m = m - ((m*c - 1) / a)
*/ */
err = sp_mul(c, m, v); err = sp_mul(c, m, c);
if (err == MP_OKAY) { if (err == MP_OKAY) {
_sp_sub_d(v, 1, v); _sp_sub_d(c, 1, c);
err = sp_div(v, a, v, NULL); err = sp_div(c, a, c, NULL);
} }
if (err == MP_OKAY) { if (err == MP_OKAY) {
sp_sub(m, v, r); sp_sub(m, c, r);
} }
} }
else { else {
@@ -9766,6 +9770,7 @@ int sp_invmod(sp_int* a, sp_int* m, sp_int* r)
} }
} }
FREE_SP_INT(c, NULL);
FREE_SP_INT_ARRAY(t, NULL); FREE_SP_INT_ARRAY(t, NULL);
return err; return err;
} }
@@ -11542,6 +11547,10 @@ int sp_mul_2d(sp_int* a, int e, sp_int* r)
sp_int_digit o = 0; sp_int_digit o = 0;
sp_int_digit t[4]; sp_int_digit t[4];
#if defined(WOLFSSL_SP_ARM_THUMB) && SP_WORD_SIZE == 32
to = 0;
#endif
SP_ASM_SQR(h, l, a->dp[0]); SP_ASM_SQR(h, l, a->dp[0]);
t[0] = h; t[0] = h;
h = 0; h = 0;
@@ -11603,6 +11612,10 @@ int sp_mul_2d(sp_int* a, int e, sp_int* r)
sp_int_digit to; sp_int_digit to;
sp_int_digit t[6]; sp_int_digit t[6];
#if defined(WOLFSSL_SP_ARM_THUMB) && SP_WORD_SIZE == 32
to = 0;
#endif
SP_ASM_SQR(h, l, a->dp[0]); SP_ASM_SQR(h, l, a->dp[0]);
t[0] = h; t[0] = h;
h = 0; h = 0;

View File

@@ -36889,11 +36889,30 @@ static int mp_test_invmod(mp_int* a, mp_int* m, mp_int* r)
if (ret != MP_OKAY) if (ret != MP_OKAY)
return -13177; return -13177;
#if defined(WOLFSSL_SP_MATH) || defined(WOLFSSL_SP_MATH_ALL)
/* Maximum 'a' */
mp_set(a, 0);
mp_set_bit(a, (r->size / 2)* SP_WORD_SIZE - 1);
mp_sub_d(a, 1, a);
/* Modulus too big. */
mp_set(m, 0);
mp_set_bit(m, (r->size / 2) * SP_WORD_SIZE);
ret = mp_invmod(a, m, r);
if (ret != MP_VAL)
return -13178;
/* Maximum modulus - even. */
mp_set(m, 0);
mp_set_bit(m, (r->size / 2) * SP_WORD_SIZE - 1);
ret = mp_invmod(a, m, r);
if (ret != MP_OKAY)
return -13179;
#endif
#if !defined(WOLFSSL_SP_MATH) || defined(WOLFSSL_SP_INT_NEGATIVE) #if !defined(WOLFSSL_SP_MATH) || defined(WOLFSSL_SP_INT_NEGATIVE)
mp_read_radix(a, "-3", 16); mp_read_radix(a, "-3", 16);
ret = mp_invmod(a, m, r); ret = mp_invmod(a, m, r);
if (ret != MP_OKAY) if (ret != MP_OKAY)
return -13178; return -13180;
#endif #endif
#if defined(WOLFSSL_SP_MATH_ALL) && defined(HAVE_ECC) #if defined(WOLFSSL_SP_MATH_ALL) && defined(HAVE_ECC)
@@ -36901,28 +36920,28 @@ static int mp_test_invmod(mp_int* a, mp_int* m, mp_int* r)
mp_set(m, 3); mp_set(m, 3);
ret = mp_invmod_mont_ct(a, m, r, 1); ret = mp_invmod_mont_ct(a, m, r, 1);
if (ret != MP_VAL) if (ret != MP_VAL)
return -13179; return -13190;
mp_set(a, 1); mp_set(a, 1);
mp_set(m, 0); mp_set(m, 0);
ret = mp_invmod_mont_ct(a, m, r, 1); ret = mp_invmod_mont_ct(a, m, r, 1);
if (ret != MP_VAL) if (ret != MP_VAL)
return -13180; return -13191;
mp_set(a, 1); mp_set(a, 1);
mp_set(m, 1); mp_set(m, 1);
ret = mp_invmod_mont_ct(a, m, r, 1); ret = mp_invmod_mont_ct(a, m, r, 1);
if (ret != MP_VAL) if (ret != MP_VAL)
return -13181; return -13192;
mp_set(a, 1); mp_set(a, 1);
mp_set(m, 2); mp_set(m, 2);
ret = mp_invmod_mont_ct(a, m, r, 1); ret = mp_invmod_mont_ct(a, m, r, 1);
if (ret != MP_VAL) if (ret != MP_VAL)
return -13182; return -13193;
mp_set(a, 1); mp_set(a, 1);
mp_set(m, 3); mp_set(m, 3);
ret = mp_invmod_mont_ct(a, m, r, 1); ret = mp_invmod_mont_ct(a, m, r, 1);
if (ret != MP_OKAY) if (ret != MP_OKAY)
return -13183; return -13194;
#endif #endif
return 0; return 0;