Merge pull request #5581 from SparkiDev/sp_int_size_fix

SP int: mp_init_size() fix
This commit is contained in:
David Garske
2022-09-16 08:29:06 -07:00
committed by GitHub
3 changed files with 101 additions and 43 deletions

View File

@ -4353,6 +4353,7 @@ static int _sp_mont_red(sp_int* a, sp_int* m, sp_int_digit mp);
static void _sp_zero(sp_int* a) static void _sp_zero(sp_int* a)
{ {
sp_int_minimal* am = (sp_int_minimal *)a; sp_int_minimal* am = (sp_int_minimal *)a;
am->used = 0; am->used = 0;
am->dp[0] = 0; am->dp[0] = 0;
#ifdef WOLFSSL_SP_INT_NEGATIVE #ifdef WOLFSSL_SP_INT_NEGATIVE
@ -4371,20 +4372,20 @@ static void _sp_zero(sp_int* a)
*/ */
int sp_init_size(sp_int* a, int size) int sp_init_size(sp_int* a, int size)
{ {
sp_int_minimal* am = (sp_int_minimal *)a;
int err = MP_OKAY; int err = MP_OKAY;
if (a == NULL) { if ((a == NULL) || ((size <= 0) || (size > SP_INT_DIGITS))) {
err = MP_VAL; err = MP_VAL;
} }
if (err == MP_OKAY) {
#ifdef HAVE_WOLF_BIGINT
wc_bigint_init(&am->raw);
#endif
_sp_zero(a);
}
if (err == MP_OKAY) { if (err == MP_OKAY) {
volatile sp_int_minimal* am = (sp_int_minimal *)a;
#ifdef HAVE_WOLF_BIGINT
wc_bigint_init((struct WC_BIGINT*)&am->raw);
#endif
_sp_zero(a);
am->size = size; am->size = size;
} }
@ -4602,13 +4603,18 @@ int sp_copy(const sp_int* a, sp_int* r)
err = MP_VAL; err = MP_VAL;
} }
else if (a != r) { else if (a != r) {
XMEMCPY(r->dp, a->dp, a->used * sizeof(sp_int_digit)); if (a->used > r->size) {
if (a->used == 0) err = MP_VAL;
r->dp[0] = 0; }
r->used = a->used; else {
#ifdef WOLFSSL_SP_INT_NEGATIVE XMEMCPY(r->dp, a->dp, a->used * sizeof(sp_int_digit));
r->sign = a->sign; if (a->used == 0)
#endif r->dp[0] = 0;
r->used = a->used;
#ifdef WOLFSSL_SP_INT_NEGATIVE
r->sign = a->sign;
#endif
}
} }
return err; return err;
@ -5385,12 +5391,23 @@ int sp_add_d(sp_int* a, sp_int_digit d, sp_int* r)
if ((a == NULL) || (r == NULL)) { if ((a == NULL) || (r == NULL)) {
err = MP_VAL; err = MP_VAL;
} }
else
{ #ifndef WOLFSSL_SP_INT_NEGATIVE
#ifndef WOLFSSL_SP_INT_NEGATIVE if ((err == MP_OKAY) && (a->used + 1 > r->size)) {
err = MP_VAL;
}
if (err == MP_OKAY) {
/* Positive only so just use internal function. */ /* Positive only so just use internal function. */
err = _sp_add_d(a, d, r); err = _sp_add_d(a, d, r);
#else }
#else
if ((err == MP_OKAY) && (a->sign == MP_ZPOS) && (a->used + 1 > r->size)) {
err = MP_VAL;
}
if ((err == MP_OKAY) && (a->sign == MP_NEG) && (a->used > r->size)) {
err = MP_VAL;
}
if (err == MP_OKAY) {
if (a->sign == MP_ZPOS) { if (a->sign == MP_ZPOS) {
/* Positive so use interal function. */ /* Positive so use interal function. */
r->sign = MP_ZPOS; r->sign = MP_ZPOS;
@ -5409,8 +5426,8 @@ int sp_add_d(sp_int* a, sp_int_digit d, sp_int* r)
/* Result is a digit equal to or greater than zero. */ /* Result is a digit equal to or greater than zero. */
r->used = ((r->dp[0] == 0) ? 0 : 1); r->used = ((r->dp[0] == 0) ? 0 : 1);
} }
#endif
} }
#endif
return err; return err;
} }
@ -5434,11 +5451,22 @@ int sp_sub_d(sp_int* a, sp_int_digit d, sp_int* r)
if ((a == NULL) || (r == NULL)) { if ((a == NULL) || (r == NULL)) {
err = MP_VAL; err = MP_VAL;
} }
else {
#ifndef WOLFSSL_SP_INT_NEGATIVE #ifndef WOLFSSL_SP_INT_NEGATIVE
if ((err == MP_OKAY) && (a->used > r->size)) {
err = MP_VAL;
}
if (err == MP_OKAY) {
/* Positive only so just use internal function. */ /* Positive only so just use internal function. */
_sp_sub_d(a, d, r); _sp_sub_d(a, d, r);
}
#else #else
if ((err == MP_OKAY) && (a->sign == MP_NEG) && (a->used + 1 > r->size)) {
err = MP_VAL;
}
if ((err == MP_OKAY) && (a->sign == MP_ZPOS) && (a->used > r->size)) {
err = MP_VAL;
}
if (err == MP_OKAY) {
if (a->sign == MP_NEG) { if (a->sign == MP_NEG) {
/* Subtracting from negative use interal add. */ /* Subtracting from negative use interal add. */
r->sign = MP_NEG; r->sign = MP_NEG;
@ -5457,8 +5485,8 @@ int sp_sub_d(sp_int* a, sp_int_digit d, sp_int* r)
/* Result is a digit equal to or greater than zero. */ /* Result is a digit equal to or greater than zero. */
r->used = 1; r->used = 1;
} }
#endif
} }
#endif
return err; return err;
} }
@ -5878,6 +5906,10 @@ int sp_div_d(sp_int* a, sp_int_digit d, sp_int* r, sp_int_digit* rem)
err = MP_VAL; err = MP_VAL;
} }
if ((err == MP_OKAY) && (r != NULL) && (a->used > r->size)) {
err = MP_VAL;
}
if (err == MP_OKAY) { if (err == MP_OKAY) {
#if !defined(WOLFSSL_SP_SMALL) #if !defined(WOLFSSL_SP_SMALL)
if (d == 3) { if (d == 3) {
@ -6135,6 +6167,10 @@ int sp_div_2(sp_int* a, sp_int* r)
if ((a == NULL) || (r == NULL)) { if ((a == NULL) || (r == NULL)) {
err = MP_VAL; err = MP_VAL;
} }
if ((err == MP_OKAY) && (a->used > r->size)) {
err = MP_VAL;
}
#endif #endif
if (err == MP_OKAY) { if (err == MP_OKAY) {
@ -6394,7 +6430,10 @@ int sp_sub(sp_int* a, sp_int* b, sp_int* r)
if ((a == NULL) || (b == NULL) || (r == NULL)) { if ((a == NULL) || (b == NULL) || (r == NULL)) {
err = MP_VAL; err = MP_VAL;
} }
else { if ((err == MP_OKAY) && ((a->used >= r->size) || (b->used >= r->size))) {
err = MP_VAL;
}
if (err == MP_OKAY) {
#ifndef WOLFSSL_SP_INT_NEGATIVE #ifndef WOLFSSL_SP_INT_NEGATIVE
err = _sp_sub_off(a, b, r, 0); err = _sp_sub_off(a, b, r, 0);
#else #else
@ -6970,13 +7009,18 @@ void sp_rshd(sp_int* a, int c)
* @param [in] n Number of bits to shift. * @param [in] n Number of bits to shift.
* @param [out] r SP integer to store result in. * @param [out] r SP integer to store result in.
*/ */
void sp_rshb(sp_int* a, int n, sp_int* r) int sp_rshb(sp_int* a, int n, sp_int* r)
{ {
int err = MP_OKAY;
int i = n >> SP_WORD_SHIFT; int i = n >> SP_WORD_SHIFT;
if (i >= a->used) { if (i >= a->used) {
_sp_zero(r); _sp_zero(r);
} }
/* Change callers when more error cases returned. */
else if (a->used - i > r->size) {
err = MP_VAL;
}
else { else {
int j; int j;
@ -7002,6 +7046,8 @@ void sp_rshb(sp_int* a, int n, sp_int* r)
} }
#endif #endif
} }
return err;
} }
#endif /* WOLFSSL_SP_MATH_ALL || !NO_DH || HAVE_ECC || #endif /* WOLFSSL_SP_MATH_ALL || !NO_DH || HAVE_ECC ||
* (!NO_RSA && !WOLFSSL_RSA_VERIFY_ONLY) || WOLFSSL_HAVE_SP_DH */ * (!NO_RSA && !WOLFSSL_RSA_VERIFY_ONLY) || WOLFSSL_HAVE_SP_DH */
@ -7343,7 +7389,7 @@ int sp_div(sp_int* a, sp_int* d, sp_int* r, sp_int* rem)
#endif /* WOLFSSL_SP_INT_NEGATIVE */ #endif /* WOLFSSL_SP_INT_NEGATIVE */
/* Move result back down if moved up for divisor value. */ /* Move result back down if moved up for divisor value. */
if (s != SP_WORD_SIZE) { if (s != SP_WORD_SIZE) {
sp_rshb(sa, s, sa); (void)sp_rshb(sa, s, sa);
} }
sp_copy(sa, rem); sp_copy(sa, rem);
sp_clamp(rem); sp_clamp(rem);
@ -11953,9 +11999,11 @@ int sp_div_2d(sp_int* a, int e, sp_int* r, sp_int* rem)
/* Copy a in to remainder. */ /* Copy a in to remainder. */
err = sp_copy(a, rem); err = sp_copy(a, rem);
} }
/* Shift a down by into result. */ if (err == MP_OKAY) {
sp_rshb(a, e, r); /* Shift a down by into result. */
if (rem != NULL) { err = sp_rshb(a, e, r);
}
if ((err == MP_OKAY) && (rem != NULL)) {
/* Set used and mask off top digit of remainder. */ /* Set used and mask off top digit of remainder. */
rem->used = (e + SP_WORD_SIZE - 1) >> SP_WORD_SHIFT; rem->used = (e + SP_WORD_SIZE - 1) >> SP_WORD_SHIFT;
e &= SP_WORD_MASK; e &= SP_WORD_MASK;
@ -11987,13 +12035,16 @@ int sp_div_2d(sp_int* a, int e, sp_int* r, sp_int* rem)
int sp_mod_2d(sp_int* a, int e, sp_int* r) int sp_mod_2d(sp_int* a, int e, sp_int* r)
{ {
int err = MP_OKAY; int err = MP_OKAY;
int digits = (e + SP_WORD_SIZE - 1) >> SP_WORD_SHIFT;
if ((a == NULL) || (r == NULL)) { if ((a == NULL) || (r == NULL)) {
err = MP_VAL; err = MP_VAL;
} }
if ((err == MP_OKAY) && (digits > r->size)) {
err = MP_VAL;
}
if (err == MP_OKAY) { if (err == MP_OKAY) {
int digits = (e + SP_WORD_SIZE - 1) >> SP_WORD_SHIFT;
if (a != r) { if (a != r) {
XMEMCPY(r->dp, a->dp, digits * sizeof(sp_int_digit)); XMEMCPY(r->dp, a->dp, digits * sizeof(sp_int_digit));
r->used = a->used; r->used = a->used;
@ -14633,7 +14684,7 @@ static int _sp_mont_red(sp_int* a, sp_int* m, sp_int_digit mp)
} }
sp_clamp(a); sp_clamp(a);
sp_rshb(a, bits, a); (void)sp_rshb(a, bits, a);
if (_sp_cmp_abs(a, m) != MP_LT) { if (_sp_cmp_abs(a, m) != MP_LT) {
_sp_sub_off(a, m, a, 0); _sp_sub_off(a, m, a, 0);
@ -14918,7 +14969,7 @@ static int _sp_mont_red(sp_int* a, sp_int* m, sp_int_digit mp)
} }
sp_clamp(a); sp_clamp(a);
sp_rshb(a, bits, a); (void)sp_rshb(a, bits, a);
if (_sp_cmp_abs(a, m) != MP_LT) { if (_sp_cmp_abs(a, m) != MP_LT) {
sp_sub(a, m, a); sp_sub(a, m, a);
@ -15031,7 +15082,9 @@ int sp_mont_norm(sp_int* norm, sp_int* m)
bits = SP_WORD_SIZE; bits = SP_WORD_SIZE;
} }
_sp_zero(norm); _sp_zero(norm);
sp_set_bit(norm, bits); err = sp_set_bit(norm, bits);
}
if (err == MP_OKAY) {
err = sp_sub(norm, m, norm); err = sp_sub(norm, m, norm);
} }
if ((err == MP_OKAY) && (bits == SP_WORD_SIZE)) { if ((err == MP_OKAY) && (bits == SP_WORD_SIZE)) {
@ -15687,14 +15740,11 @@ int sp_radix_size(sp_int* a, int radix, int* size)
ALLOC_SP_INT(t, a->used + 1, err, NULL); ALLOC_SP_INT(t, a->used + 1, err, NULL);
if (err == MP_OKAY) { if (err == MP_OKAY) {
#if defined(WOLFSSL_SMALL_STACK) && !defined(WOLFSSL_SP_NO_MALLOC)
t->size = a->used + 1; t->size = a->used + 1;
#endif /* WOLFSSL_SMALL_STACK && !WOLFSSL_SP_NO_MALLOC */
err = sp_copy(a, t); err = sp_copy(a, t);
} }
if (err == MP_OKAY) { if (err == MP_OKAY) {
for (i = 0; !sp_iszero(t); i++) { for (i = 0; !sp_iszero(t); i++) {
sp_div_d(t, 10, t, &d); sp_div_d(t, 10, t, &d);
} }
@ -15745,6 +15795,7 @@ int sp_rand_prime(sp_int* r, int len, WC_RNG* rng, void* heap)
#ifdef WOLFSSL_SP_MATH_ALL #ifdef WOLFSSL_SP_MATH_ALL
int bits = 0; int bits = 0;
#endif /* WOLFSSL_SP_MATH_ALL */ #endif /* WOLFSSL_SP_MATH_ALL */
int digits = 0;
(void)heap; (void)heap;
@ -15760,6 +15811,13 @@ int sp_rand_prime(sp_int* r, int len, WC_RNG* rng, void* heap)
len = -len; len = -len;
} }
digits = (len + SP_WORD_SIZEOF - 1) / SP_WORD_SIZEOF;
if (r->size < digits) {
err = MP_VAL;
}
}
if (err == MP_OKAY) {
#ifndef WOLFSSL_SP_MATH_ALL #ifndef WOLFSSL_SP_MATH_ALL
/* For minimal maths, support only what's in SP and needed for DH. */ /* For minimal maths, support only what's in SP and needed for DH. */
#if defined(WOLFSSL_HAVE_SP_DH) && defined(WOLFSSL_KEY_GEN) #if defined(WOLFSSL_HAVE_SP_DH) && defined(WOLFSSL_KEY_GEN)
@ -15781,7 +15839,7 @@ int sp_rand_prime(sp_int* r, int len, WC_RNG* rng, void* heap)
#ifdef WOLFSSL_SP_INT_NEGATIVE #ifdef WOLFSSL_SP_INT_NEGATIVE
r->sign = MP_ZPOS; r->sign = MP_ZPOS;
#endif /* WOLFSSL_SP_INT_NEGATIVE */ #endif /* WOLFSSL_SP_INT_NEGATIVE */
r->used = (len + SP_WORD_SIZEOF - 1) / SP_WORD_SIZEOF; r->used = digits;
#ifdef WOLFSSL_SP_MATH_ALL #ifdef WOLFSSL_SP_MATH_ALL
bits = (len * 8) & SP_WORD_MASK; bits = (len * 8) & SP_WORD_MASK;
#endif /* WOLFSSL_SP_MATH_ALL */ #endif /* WOLFSSL_SP_MATH_ALL */
@ -15875,7 +15933,7 @@ static int sp_prime_miller_rabin_ex(sp_int* a, sp_int* b, int* result,
s = sp_cnt_lsb(r); s = sp_cnt_lsb(r);
/* now divide n - 1 by 2**s */ /* now divide n - 1 by 2**s */
sp_rshb(r, s, r); (void)sp_rshb(r, s, r);
/* compute y = b**r mod a */ /* compute y = b**r mod a */
err = sp_exptmod(b, r, a, y); err = sp_exptmod(b, r, a, y);

View File

@ -147,7 +147,7 @@ int mp_rand(mp_int* a, int digits, WC_RNG* rng)
{ {
int ret = 0; int ret = 0;
int cnt = digits * sizeof(mp_digit); int cnt = digits * sizeof(mp_digit);
#if !defined(USE_FAST_MATH) && !defined(WOLFSSL_SP_MATH) #ifdef USE_INTEGER_HEAP_MATH
int i; int i;
#endif #endif
@ -158,14 +158,14 @@ int mp_rand(mp_int* a, int digits, WC_RNG* rng)
ret = BAD_FUNC_ARG; ret = BAD_FUNC_ARG;
} }
#if !defined(USE_FAST_MATH) && !defined(WOLFSSL_SP_MATH) #ifdef USE_INTEGER_HEAP_MATH
/* allocate space for digits */ /* allocate space for digits */
if (ret == MP_OKAY) { if (ret == MP_OKAY) {
ret = mp_set_bit(a, digits * DIGIT_BIT - 1); ret = mp_set_bit(a, digits * DIGIT_BIT - 1);
} }
#else #else
#if defined(WOLFSSL_SP_MATH) || defined(WOLFSSL_SP_MATH_ALL) #if defined(WOLFSSL_SP_MATH) || defined(WOLFSSL_SP_MATH_ALL)
if ((ret == MP_OKAY) && (digits > SP_INT_DIGITS)) if ((ret == MP_OKAY) && (digits > a->size))
#else #else
if ((ret == MP_OKAY) && (digits > FP_SIZE)) if ((ret == MP_OKAY) && (digits > FP_SIZE))
#endif #endif
@ -181,7 +181,7 @@ int mp_rand(mp_int* a, int digits, WC_RNG* rng)
ret = wc_RNG_GenerateBlock(rng, (byte*)a->dp, cnt); ret = wc_RNG_GenerateBlock(rng, (byte*)a->dp, cnt);
} }
if (ret == MP_OKAY) { if (ret == MP_OKAY) {
#if !defined(USE_FAST_MATH) && !defined(WOLFSSL_SP_MATH) #ifdef USE_INTEGER_HEAP_MATH
/* Mask down each digit to only bits used */ /* Mask down each digit to only bits used */
for (i = 0; i < a->used; i++) { for (i = 0; i < a->used; i++) {
a->dp[i] &= MP_MASK; a->dp[i] &= MP_MASK;
@ -190,7 +190,7 @@ int mp_rand(mp_int* a, int digits, WC_RNG* rng)
/* ensure top digit is not zero */ /* ensure top digit is not zero */
while ((ret == MP_OKAY) && (a->dp[a->used - 1] == 0)) { while ((ret == MP_OKAY) && (a->dp[a->used - 1] == 0)) {
ret = get_rand_digit(rng, &a->dp[a->used - 1]); ret = get_rand_digit(rng, &a->dp[a->used - 1]);
#if !defined(USE_FAST_MATH) && !defined(WOLFSSL_SP_MATH) #ifdef USE_INTEGER_HEAP_MATH
a->dp[a->used - 1] &= MP_MASK; a->dp[a->used - 1] &= MP_MASK;
#endif #endif
} }

View File

@ -877,7 +877,7 @@ MP_API int sp_addmod_ct (sp_int* a, sp_int* b, sp_int* c, sp_int* d);
MP_API int sp_lshd(sp_int* a, int s); MP_API int sp_lshd(sp_int* a, int s);
MP_API void sp_rshd(sp_int* a, int c); MP_API void sp_rshd(sp_int* a, int c);
MP_API void sp_rshb(sp_int* a, int n, sp_int* r); MP_API int sp_rshb(sp_int* a, int n, sp_int* r);
#ifdef WOLFSSL_SP_MATH_ALL #ifdef WOLFSSL_SP_MATH_ALL
MP_API int sp_div(sp_int* a, sp_int* d, sp_int* r, sp_int* rem); MP_API int sp_div(sp_int* a, sp_int* d, sp_int* r, sp_int* rem);