Merge pull request #4151 from SparkiDev/sp_math_all_base10

SP math all: fix read radix 10
This commit is contained in:
David Garske
2021-06-25 09:37:05 -07:00
committed by GitHub
2 changed files with 95 additions and 56 deletions

View File

@ -2178,6 +2178,7 @@ static int _sp_mont_red(sp_int* a, sp_int* m, sp_int_digit mp);
static void _sp_zero(sp_int* a) static void _sp_zero(sp_int* a)
{ {
a->used = 0; a->used = 0;
a->dp[0] = 0;
#ifdef WOLFSSL_SP_INT_NEGATIVE #ifdef WOLFSSL_SP_INT_NEGATIVE
a->sign = MP_ZPOS; a->sign = MP_ZPOS;
#endif #endif
@ -3236,9 +3237,12 @@ int sp_sub_d(sp_int* a, sp_int_digit d, sp_int* r)
* @param [in] n Number (SP digit) to multiply by. * @param [in] n Number (SP digit) to multiply by.
* @param [out] r SP integer result. * @param [out] r SP integer result.
* @param [in] o Number of digits to move result up by. * @param [in] o Number of digits to move result up by.
* @return MP_OKAY on success.
* @return MP_VAL when result is too large for sp_int.
*/ */
static void _sp_mul_d(sp_int* a, sp_int_digit n, sp_int* r, int o) static int _sp_mul_d(sp_int* a, sp_int_digit n, sp_int* r, int o)
{ {
int err = MP_OKAY;
int i; int i;
sp_int_word t = 0; sp_int_word t = 0;
@ -3257,9 +3261,18 @@ static void _sp_mul_d(sp_int* a, sp_int_digit n, sp_int* r, int o)
t >>= SP_WORD_SIZE; t >>= SP_WORD_SIZE;
} }
r->dp[o++] = (sp_int_digit)t; if (t > 0) {
if (o == r->size) {
err = MP_VAL;
}
else {
r->dp[o++] = (sp_int_digit)t;
}
}
r->used = o; r->used = o;
sp_clamp(r); sp_clamp(r);
return err;
} }
#endif /* (WOLFSSL_SP_MATH_ALL && !WOLFSSL_RSA_VERIFY_ONLY) || #endif /* (WOLFSSL_SP_MATH_ALL && !WOLFSSL_RSA_VERIFY_ONLY) ||
* WOLFSSL_SP_SMALL || (WOLFSSL_KEY_GEN && !NO_RSA) */ * WOLFSSL_SP_SMALL || (WOLFSSL_KEY_GEN && !NO_RSA) */
@ -3287,7 +3300,7 @@ int sp_mul_d(sp_int* a, sp_int_digit d, sp_int* r)
} }
if (err == MP_OKAY) { if (err == MP_OKAY) {
_sp_mul_d(a, d, r, 0); err = _sp_mul_d(a, d, r, 0);
#ifdef WOLFSSL_SP_INT_NEGATIVE #ifdef WOLFSSL_SP_INT_NEGATIVE
if (d == 0) { if (d == 0) {
r->sign = MP_ZPOS; r->sign = MP_ZPOS;
@ -3705,6 +3718,9 @@ int sp_div_2_mod_ct(sp_int* a, sp_int* m, sp_int* r)
if ((a == NULL) || (m == NULL) || (r == NULL)) { if ((a == NULL) || (m == NULL) || (r == NULL)) {
err = MP_VAL; err = MP_VAL;
} }
if ((err == MP_OKAY) && (r->size < m->used + 1)) {
err = MP_VAL;
}
if (err == MP_OKAY) { if (err == MP_OKAY) {
sp_int_word w = 0; sp_int_word w = 0;
@ -4208,37 +4224,44 @@ static int sp_cmp_mag_ct(sp_int* a, sp_int* b, int len)
*/ */
int sp_addmod_ct(sp_int* a, sp_int* b, sp_int* m, sp_int* r) int sp_addmod_ct(sp_int* a, sp_int* b, sp_int* m, sp_int* r)
{ {
int err = MP_OKAY;
sp_int_word w = 0; sp_int_word w = 0;
sp_int_digit mask; sp_int_digit mask;
int i; int i;
if (0) { if ((r->size < m->used + 1) || (m->used == m->size)) {
sp_print(a, "a"); err = MP_VAL;
sp_print(b, "b");
sp_print(m, "m");
} }
_sp_add_off(a, b, r, 0); if (err == MP_OKAY) {
mask = 0 - (sp_cmp_mag_ct(r, m, m->used + 1) != MP_LT); if (0) {
for (i = 0; i < m->used; i++) { sp_print(a, "a");
sp_int_digit mask_r = 0 - (i < r->used); sp_print(b, "b");
w += m->dp[i] & mask; sp_print(m, "m");
w = (r->dp[i] & mask_r) - w; }
r->dp[i] = (sp_int_digit)w;
w = (w >> DIGIT_BIT) & 1;
}
r->dp[i] = 0;
r->used = i;
#ifdef WOLFSSL_SP_INT_NEGATIVE
r->sign = a->sign;
#endif /* WOLFSSL_SP_INT_NEGATIVE */
sp_clamp(r);
if (0) { _sp_add_off(a, b, r, 0);
sp_print(r, "rma"); mask = 0 - (sp_cmp_mag_ct(r, m, m->used + 1) != MP_LT);
for (i = 0; i < m->used; i++) {
sp_int_digit mask_r = 0 - (i < r->used);
w += m->dp[i] & mask;
w = (r->dp[i] & mask_r) - w;
r->dp[i] = (sp_int_digit)w;
w = (w >> DIGIT_BIT) & 1;
}
r->dp[i] = 0;
r->used = i;
#ifdef WOLFSSL_SP_INT_NEGATIVE
r->sign = a->sign;
#endif /* WOLFSSL_SP_INT_NEGATIVE */
sp_clamp(r);
if (0) {
sp_print(r, "rma");
}
} }
return MP_OKAY; return err;
} }
#endif /* WOLFSSL_SP_MATH_ALL && HAVE_ECC */ #endif /* WOLFSSL_SP_MATH_ALL && HAVE_ECC */
@ -4259,39 +4282,46 @@ int sp_addmod_ct(sp_int* a, sp_int* b, sp_int* m, sp_int* r)
*/ */
int sp_submod_ct(sp_int* a, sp_int* b, sp_int* m, sp_int* r) int sp_submod_ct(sp_int* a, sp_int* b, sp_int* m, sp_int* r)
{ {
int err = MP_OKAY;
sp_int_word w = 0; sp_int_word w = 0;
sp_int_digit mask; sp_int_digit mask;
int i; int i;
if (0) { if (r->size < m->used + 1) {
sp_print(a, "a"); err = MP_VAL;
sp_print(b, "b");
sp_print(m, "m");
} }
mask = 0 - (sp_cmp_mag_ct(a, b, m->used + 1) == MP_LT); if (err == MP_OKAY) {
for (i = 0; i < m->used + 1; i++) { if (0) {
sp_int_digit mask_a = 0 - (i < a->used); sp_print(a, "a");
sp_int_digit mask_m = 0 - (i < m->used); sp_print(b, "b");
sp_print(m, "m");
}
w += m->dp[i] & mask_m & mask; mask = 0 - (sp_cmp_mag_ct(a, b, m->used) == MP_LT);
w += a->dp[i] & mask_a; for (i = 0; i < m->used; i++) {
r->dp[i] = (sp_int_digit)w; sp_int_digit mask_a = 0 - (i < a->used);
w >>= DIGIT_BIT; sp_int_digit mask_m = 0 - (i < m->used);
}
r->dp[i] = (sp_int_digit)w;
r->used = i + 1;
#ifdef WOLFSSL_SP_INT_NEGATIVE
r->sign = MP_ZPOS;
#endif /* WOLFSSL_SP_INT_NEGATIVE */
sp_clamp(r);
_sp_sub_off(r, b, r, 0);
if (0) { w += m->dp[i] & mask_m & mask;
sp_print(r, "rms"); w += a->dp[i] & mask_a;
r->dp[i] = (sp_int_digit)w;
w >>= DIGIT_BIT;
}
r->dp[i] = (sp_int_digit)w;
r->used = i + 1;
#ifdef WOLFSSL_SP_INT_NEGATIVE
r->sign = MP_ZPOS;
#endif /* WOLFSSL_SP_INT_NEGATIVE */
sp_clamp(r);
_sp_sub_off(r, b, r, 0);
if (0) {
sp_print(r, "rms");
}
} }
return MP_OKAY; return err;
} }
#endif /* WOLFSSL_SP_MATH_ALL && HAVE_ECC */ #endif /* WOLFSSL_SP_MATH_ALL && HAVE_ECC */
@ -4628,7 +4658,10 @@ int sp_div(sp_int* a, sp_int* d, sp_int* r, sp_int* rem)
#ifdef WOLFSSL_SP_SMALL #ifdef WOLFSSL_SP_SMALL
do { do {
_sp_mul_d(d, t, trial, i - d->used); err = _sp_mul_d(d, t, trial, i - d->used);
if (err != MP_OKAY) {
break;
}
c = _sp_cmp_abs(trial, sa); c = _sp_cmp_abs(trial, sa);
if (c == MP_GT) { if (c == MP_GT) {
t--; t--;
@ -4636,6 +4669,10 @@ int sp_div(sp_int* a, sp_int* d, sp_int* r, sp_int* rem)
} }
while (c == MP_GT); while (c == MP_GT);
if (err != MP_OKAY) {
break;
}
_sp_sub_off(sa, trial, sa, 0); _sp_sub_off(sa, trial, sa, 0);
tr->dp[i - d->used] += t; tr->dp[i - d->used] += t;
if (tr->dp[i - d->used] < t) { if (tr->dp[i - d->used] < t) {
@ -4676,7 +4713,7 @@ int sp_div(sp_int* a, sp_int* d, sp_int* r, sp_int* rem)
} }
sa->used = i + 1; sa->used = i + 1;
if (rem != NULL) { if ((err == MP_OKAY) && (rem != NULL)) {
#ifdef WOLFSSL_SP_INT_NEGATIVE #ifdef WOLFSSL_SP_INT_NEGATIVE
sa->sign = (sa->used == 0) ? MP_ZPOS : aSign; sa->sign = (sa->used == 0) ? MP_ZPOS : aSign;
#endif /* WOLFSSL_SP_INT_NEGATIVE */ #endif /* WOLFSSL_SP_INT_NEGATIVE */
@ -4691,7 +4728,7 @@ int sp_div(sp_int* a, sp_int* d, sp_int* r, sp_int* rem)
} }
#endif #endif
} }
if (r != NULL) { if ((err == MP_OKAY) && (r != NULL)) {
sp_copy(tr, r); sp_copy(tr, r);
sp_clamp(r); sp_clamp(r);
#ifdef WOLFSSL_SP_INT_NEGATIVE #ifdef WOLFSSL_SP_INT_NEGATIVE
@ -12467,11 +12504,10 @@ static int _sp_read_radix_10(sp_int* a, const char* in)
err = MP_VAL; err = MP_VAL;
break; break;
} }
if (a->used + 1 > a->size) { err = _sp_mul_d(a, 10, a, 0);
err = MP_VAL; if (err != MP_OKAY) {
break; break;
} }
_sp_mul_d(a, 10, a, 0);
(void)_sp_add_d(a, ch, a); (void)_sp_add_d(a, ch, a);
} }
#ifdef WOLFSSL_SP_INT_NEGATIVE #ifdef WOLFSSL_SP_INT_NEGATIVE

View File

@ -648,9 +648,12 @@ typedef struct sp_ecc_ctx {
#define CheckFastMathSettings() (SP_WORD_SIZE == CheckRunTimeFastMath()) #define CheckFastMathSettings() (SP_WORD_SIZE == CheckRunTimeFastMath())
/* The number of bytes to a sp_int with 'cnt' digits. */ /* The number of bytes to a sp_int with 'cnt' digits.
* Must have at least one digit.
*/
#define MP_INT_SIZEOF(cnt) \ #define MP_INT_SIZEOF(cnt) \
(sizeof(sp_int) - (SP_INT_DIGITS - (cnt)) * sizeof(sp_int_digit)) (sizeof(sp_int) - (SP_INT_DIGITS - (((cnt) == 0) ? 1 : (cnt))) * \
sizeof(sp_int_digit))
/* The address of the next sp_int after one with 'cnt' digits. */ /* The address of the next sp_int after one with 'cnt' digits. */
#define MP_INT_NEXT(t, cnt) \ #define MP_INT_NEXT(t, cnt) \
(sp_int*)(((byte*)(t)) + MP_INT_SIZEOF(cnt)) (sp_int*)(((byte*)(t)) + MP_INT_SIZEOF(cnt))