Merge pull request #4151 from SparkiDev/sp_math_all_base10

SP math all: fix read radix 10
This commit is contained in:
David Garske
2021-06-25 09:37:05 -07:00
committed by GitHub
2 changed files with 95 additions and 56 deletions

View File

@ -2178,6 +2178,7 @@ static int _sp_mont_red(sp_int* a, sp_int* m, sp_int_digit mp);
static void _sp_zero(sp_int* a)
{
a->used = 0;
a->dp[0] = 0;
#ifdef WOLFSSL_SP_INT_NEGATIVE
a->sign = MP_ZPOS;
#endif
@ -3236,9 +3237,12 @@ int sp_sub_d(sp_int* a, sp_int_digit d, sp_int* r)
* @param [in] n Number (SP digit) to multiply by.
* @param [out] r SP integer result.
* @param [in] o Number of digits to move result up by.
* @return MP_OKAY on success.
* @return MP_VAL when result is too large for sp_int.
*/
static void _sp_mul_d(sp_int* a, sp_int_digit n, sp_int* r, int o)
static int _sp_mul_d(sp_int* a, sp_int_digit n, sp_int* r, int o)
{
int err = MP_OKAY;
int i;
sp_int_word t = 0;
@ -3257,9 +3261,18 @@ static void _sp_mul_d(sp_int* a, sp_int_digit n, sp_int* r, int o)
t >>= SP_WORD_SIZE;
}
r->dp[o++] = (sp_int_digit)t;
if (t > 0) {
if (o == r->size) {
err = MP_VAL;
}
else {
r->dp[o++] = (sp_int_digit)t;
}
}
r->used = o;
sp_clamp(r);
return err;
}
#endif /* (WOLFSSL_SP_MATH_ALL && !WOLFSSL_RSA_VERIFY_ONLY) ||
* WOLFSSL_SP_SMALL || (WOLFSSL_KEY_GEN && !NO_RSA) */
@ -3287,7 +3300,7 @@ int sp_mul_d(sp_int* a, sp_int_digit d, sp_int* r)
}
if (err == MP_OKAY) {
_sp_mul_d(a, d, r, 0);
err = _sp_mul_d(a, d, r, 0);
#ifdef WOLFSSL_SP_INT_NEGATIVE
if (d == 0) {
r->sign = MP_ZPOS;
@ -3705,6 +3718,9 @@ int sp_div_2_mod_ct(sp_int* a, sp_int* m, sp_int* r)
if ((a == NULL) || (m == NULL) || (r == NULL)) {
err = MP_VAL;
}
if ((err == MP_OKAY) && (r->size < m->used + 1)) {
err = MP_VAL;
}
if (err == MP_OKAY) {
sp_int_word w = 0;
@ -4208,37 +4224,44 @@ static int sp_cmp_mag_ct(sp_int* a, sp_int* b, int len)
*/
int sp_addmod_ct(sp_int* a, sp_int* b, sp_int* m, sp_int* r)
{
int err = MP_OKAY;
sp_int_word w = 0;
sp_int_digit mask;
int i;
if (0) {
sp_print(a, "a");
sp_print(b, "b");
sp_print(m, "m");
if ((r->size < m->used + 1) || (m->used == m->size)) {
err = MP_VAL;
}
_sp_add_off(a, b, r, 0);
mask = 0 - (sp_cmp_mag_ct(r, m, m->used + 1) != MP_LT);
for (i = 0; i < m->used; i++) {
sp_int_digit mask_r = 0 - (i < r->used);
w += m->dp[i] & mask;
w = (r->dp[i] & mask_r) - w;
r->dp[i] = (sp_int_digit)w;
w = (w >> DIGIT_BIT) & 1;
}
r->dp[i] = 0;
r->used = i;
#ifdef WOLFSSL_SP_INT_NEGATIVE
r->sign = a->sign;
#endif /* WOLFSSL_SP_INT_NEGATIVE */
sp_clamp(r);
if (err == MP_OKAY) {
if (0) {
sp_print(a, "a");
sp_print(b, "b");
sp_print(m, "m");
}
if (0) {
sp_print(r, "rma");
_sp_add_off(a, b, r, 0);
mask = 0 - (sp_cmp_mag_ct(r, m, m->used + 1) != MP_LT);
for (i = 0; i < m->used; i++) {
sp_int_digit mask_r = 0 - (i < r->used);
w += m->dp[i] & mask;
w = (r->dp[i] & mask_r) - w;
r->dp[i] = (sp_int_digit)w;
w = (w >> DIGIT_BIT) & 1;
}
r->dp[i] = 0;
r->used = i;
#ifdef WOLFSSL_SP_INT_NEGATIVE
r->sign = a->sign;
#endif /* WOLFSSL_SP_INT_NEGATIVE */
sp_clamp(r);
if (0) {
sp_print(r, "rma");
}
}
return MP_OKAY;
return err;
}
#endif /* WOLFSSL_SP_MATH_ALL && HAVE_ECC */
@ -4259,39 +4282,46 @@ int sp_addmod_ct(sp_int* a, sp_int* b, sp_int* m, sp_int* r)
*/
int sp_submod_ct(sp_int* a, sp_int* b, sp_int* m, sp_int* r)
{
int err = MP_OKAY;
sp_int_word w = 0;
sp_int_digit mask;
int i;
if (0) {
sp_print(a, "a");
sp_print(b, "b");
sp_print(m, "m");
if (r->size < m->used + 1) {
err = MP_VAL;
}
mask = 0 - (sp_cmp_mag_ct(a, b, m->used + 1) == MP_LT);
for (i = 0; i < m->used + 1; i++) {
sp_int_digit mask_a = 0 - (i < a->used);
sp_int_digit mask_m = 0 - (i < m->used);
if (err == MP_OKAY) {
if (0) {
sp_print(a, "a");
sp_print(b, "b");
sp_print(m, "m");
}
w += m->dp[i] & mask_m & mask;
w += a->dp[i] & mask_a;
r->dp[i] = (sp_int_digit)w;
w >>= DIGIT_BIT;
}
r->dp[i] = (sp_int_digit)w;
r->used = i + 1;
#ifdef WOLFSSL_SP_INT_NEGATIVE
r->sign = MP_ZPOS;
#endif /* WOLFSSL_SP_INT_NEGATIVE */
sp_clamp(r);
_sp_sub_off(r, b, r, 0);
mask = 0 - (sp_cmp_mag_ct(a, b, m->used) == MP_LT);
for (i = 0; i < m->used; i++) {
sp_int_digit mask_a = 0 - (i < a->used);
sp_int_digit mask_m = 0 - (i < m->used);
if (0) {
sp_print(r, "rms");
w += m->dp[i] & mask_m & mask;
w += a->dp[i] & mask_a;
r->dp[i] = (sp_int_digit)w;
w >>= DIGIT_BIT;
}
r->dp[i] = (sp_int_digit)w;
r->used = i + 1;
#ifdef WOLFSSL_SP_INT_NEGATIVE
r->sign = MP_ZPOS;
#endif /* WOLFSSL_SP_INT_NEGATIVE */
sp_clamp(r);
_sp_sub_off(r, b, r, 0);
if (0) {
sp_print(r, "rms");
}
}
return MP_OKAY;
return err;
}
#endif /* WOLFSSL_SP_MATH_ALL && HAVE_ECC */
@ -4628,7 +4658,10 @@ int sp_div(sp_int* a, sp_int* d, sp_int* r, sp_int* rem)
#ifdef WOLFSSL_SP_SMALL
do {
_sp_mul_d(d, t, trial, i - d->used);
err = _sp_mul_d(d, t, trial, i - d->used);
if (err != MP_OKAY) {
break;
}
c = _sp_cmp_abs(trial, sa);
if (c == MP_GT) {
t--;
@ -4636,6 +4669,10 @@ int sp_div(sp_int* a, sp_int* d, sp_int* r, sp_int* rem)
}
while (c == MP_GT);
if (err != MP_OKAY) {
break;
}
_sp_sub_off(sa, trial, sa, 0);
tr->dp[i - d->used] += t;
if (tr->dp[i - d->used] < t) {
@ -4676,7 +4713,7 @@ int sp_div(sp_int* a, sp_int* d, sp_int* r, sp_int* rem)
}
sa->used = i + 1;
if (rem != NULL) {
if ((err == MP_OKAY) && (rem != NULL)) {
#ifdef WOLFSSL_SP_INT_NEGATIVE
sa->sign = (sa->used == 0) ? MP_ZPOS : aSign;
#endif /* WOLFSSL_SP_INT_NEGATIVE */
@ -4691,7 +4728,7 @@ int sp_div(sp_int* a, sp_int* d, sp_int* r, sp_int* rem)
}
#endif
}
if (r != NULL) {
if ((err == MP_OKAY) && (r != NULL)) {
sp_copy(tr, r);
sp_clamp(r);
#ifdef WOLFSSL_SP_INT_NEGATIVE
@ -12467,11 +12504,10 @@ static int _sp_read_radix_10(sp_int* a, const char* in)
err = MP_VAL;
break;
}
if (a->used + 1 > a->size) {
err = MP_VAL;
err = _sp_mul_d(a, 10, a, 0);
if (err != MP_OKAY) {
break;
}
_sp_mul_d(a, 10, a, 0);
(void)_sp_add_d(a, ch, a);
}
#ifdef WOLFSSL_SP_INT_NEGATIVE

View File

@ -648,9 +648,12 @@ typedef struct sp_ecc_ctx {
#define CheckFastMathSettings() (SP_WORD_SIZE == CheckRunTimeFastMath())
/* The number of bytes to a sp_int with 'cnt' digits. */
/* The number of bytes to a sp_int with 'cnt' digits.
* Must have at least one digit.
*/
#define MP_INT_SIZEOF(cnt) \
(sizeof(sp_int) - (SP_INT_DIGITS - (cnt)) * sizeof(sp_int_digit))
(sizeof(sp_int) - (SP_INT_DIGITS - (((cnt) == 0) ? 1 : (cnt))) * \
sizeof(sp_int_digit))
/* The address of the next sp_int after one with 'cnt' digits. */
#define MP_INT_NEXT(t, cnt) \
(sp_int*)(((byte*)(t)) + MP_INT_SIZEOF(cnt))