Merge pull request #5581 from SparkiDev/sp_int_size_fix

SP int: mp_init_size() fix
This commit is contained in:
David Garske
2022-09-16 08:29:06 -07:00
committed by GitHub
3 changed files with 101 additions and 43 deletions

View File

@ -4353,6 +4353,7 @@ static int _sp_mont_red(sp_int* a, sp_int* m, sp_int_digit mp);
static void _sp_zero(sp_int* a)
{
sp_int_minimal* am = (sp_int_minimal *)a;
am->used = 0;
am->dp[0] = 0;
#ifdef WOLFSSL_SP_INT_NEGATIVE
@ -4371,20 +4372,20 @@ static void _sp_zero(sp_int* a)
*/
int sp_init_size(sp_int* a, int size)
{
sp_int_minimal* am = (sp_int_minimal *)a;
int err = MP_OKAY;
if (a == NULL) {
if ((a == NULL) || ((size <= 0) || (size > SP_INT_DIGITS))) {
err = MP_VAL;
}
if (err == MP_OKAY) {
#ifdef HAVE_WOLF_BIGINT
wc_bigint_init(&am->raw);
#endif
_sp_zero(a);
}
if (err == MP_OKAY) {
volatile sp_int_minimal* am = (sp_int_minimal *)a;
#ifdef HAVE_WOLF_BIGINT
wc_bigint_init((struct WC_BIGINT*)&am->raw);
#endif
_sp_zero(a);
am->size = size;
}
@ -4602,13 +4603,18 @@ int sp_copy(const sp_int* a, sp_int* r)
err = MP_VAL;
}
else if (a != r) {
XMEMCPY(r->dp, a->dp, a->used * sizeof(sp_int_digit));
if (a->used == 0)
r->dp[0] = 0;
r->used = a->used;
#ifdef WOLFSSL_SP_INT_NEGATIVE
r->sign = a->sign;
#endif
if (a->used > r->size) {
err = MP_VAL;
}
else {
XMEMCPY(r->dp, a->dp, a->used * sizeof(sp_int_digit));
if (a->used == 0)
r->dp[0] = 0;
r->used = a->used;
#ifdef WOLFSSL_SP_INT_NEGATIVE
r->sign = a->sign;
#endif
}
}
return err;
@ -5385,12 +5391,23 @@ int sp_add_d(sp_int* a, sp_int_digit d, sp_int* r)
if ((a == NULL) || (r == NULL)) {
err = MP_VAL;
}
else
{
#ifndef WOLFSSL_SP_INT_NEGATIVE
#ifndef WOLFSSL_SP_INT_NEGATIVE
if ((err == MP_OKAY) && (a->used + 1 > r->size)) {
err = MP_VAL;
}
if (err == MP_OKAY) {
/* Positive only so just use internal function. */
err = _sp_add_d(a, d, r);
#else
}
#else
if ((err == MP_OKAY) && (a->sign == MP_ZPOS) && (a->used + 1 > r->size)) {
err = MP_VAL;
}
if ((err == MP_OKAY) && (a->sign == MP_NEG) && (a->used > r->size)) {
err = MP_VAL;
}
if (err == MP_OKAY) {
if (a->sign == MP_ZPOS) {
/* Positive so use interal function. */
r->sign = MP_ZPOS;
@ -5409,8 +5426,8 @@ int sp_add_d(sp_int* a, sp_int_digit d, sp_int* r)
/* Result is a digit equal to or greater than zero. */
r->used = ((r->dp[0] == 0) ? 0 : 1);
}
#endif
}
#endif
return err;
}
@ -5434,11 +5451,22 @@ int sp_sub_d(sp_int* a, sp_int_digit d, sp_int* r)
if ((a == NULL) || (r == NULL)) {
err = MP_VAL;
}
else {
#ifndef WOLFSSL_SP_INT_NEGATIVE
if ((err == MP_OKAY) && (a->used > r->size)) {
err = MP_VAL;
}
if (err == MP_OKAY) {
/* Positive only so just use internal function. */
_sp_sub_d(a, d, r);
}
#else
if ((err == MP_OKAY) && (a->sign == MP_NEG) && (a->used + 1 > r->size)) {
err = MP_VAL;
}
if ((err == MP_OKAY) && (a->sign == MP_ZPOS) && (a->used > r->size)) {
err = MP_VAL;
}
if (err == MP_OKAY) {
if (a->sign == MP_NEG) {
/* Subtracting from negative use interal add. */
r->sign = MP_NEG;
@ -5457,8 +5485,8 @@ int sp_sub_d(sp_int* a, sp_int_digit d, sp_int* r)
/* Result is a digit equal to or greater than zero. */
r->used = 1;
}
#endif
}
#endif
return err;
}
@ -5878,6 +5906,10 @@ int sp_div_d(sp_int* a, sp_int_digit d, sp_int* r, sp_int_digit* rem)
err = MP_VAL;
}
if ((err == MP_OKAY) && (r != NULL) && (a->used > r->size)) {
err = MP_VAL;
}
if (err == MP_OKAY) {
#if !defined(WOLFSSL_SP_SMALL)
if (d == 3) {
@ -6135,6 +6167,10 @@ int sp_div_2(sp_int* a, sp_int* r)
if ((a == NULL) || (r == NULL)) {
err = MP_VAL;
}
if ((err == MP_OKAY) && (a->used > r->size)) {
err = MP_VAL;
}
#endif
if (err == MP_OKAY) {
@ -6394,7 +6430,10 @@ int sp_sub(sp_int* a, sp_int* b, sp_int* r)
if ((a == NULL) || (b == NULL) || (r == NULL)) {
err = MP_VAL;
}
else {
if ((err == MP_OKAY) && ((a->used >= r->size) || (b->used >= r->size))) {
err = MP_VAL;
}
if (err == MP_OKAY) {
#ifndef WOLFSSL_SP_INT_NEGATIVE
err = _sp_sub_off(a, b, r, 0);
#else
@ -6970,13 +7009,18 @@ void sp_rshd(sp_int* a, int c)
* @param [in] n Number of bits to shift.
* @param [out] r SP integer to store result in.
*/
void sp_rshb(sp_int* a, int n, sp_int* r)
int sp_rshb(sp_int* a, int n, sp_int* r)
{
int err = MP_OKAY;
int i = n >> SP_WORD_SHIFT;
if (i >= a->used) {
_sp_zero(r);
}
/* Change callers when more error cases returned. */
else if (a->used - i > r->size) {
err = MP_VAL;
}
else {
int j;
@ -7002,6 +7046,8 @@ void sp_rshb(sp_int* a, int n, sp_int* r)
}
#endif
}
return err;
}
#endif /* WOLFSSL_SP_MATH_ALL || !NO_DH || HAVE_ECC ||
* (!NO_RSA && !WOLFSSL_RSA_VERIFY_ONLY) || WOLFSSL_HAVE_SP_DH */
@ -7343,7 +7389,7 @@ int sp_div(sp_int* a, sp_int* d, sp_int* r, sp_int* rem)
#endif /* WOLFSSL_SP_INT_NEGATIVE */
/* Move result back down if moved up for divisor value. */
if (s != SP_WORD_SIZE) {
sp_rshb(sa, s, sa);
(void)sp_rshb(sa, s, sa);
}
sp_copy(sa, rem);
sp_clamp(rem);
@ -11953,9 +11999,11 @@ int sp_div_2d(sp_int* a, int e, sp_int* r, sp_int* rem)
/* Copy a in to remainder. */
err = sp_copy(a, rem);
}
/* Shift a down by into result. */
sp_rshb(a, e, r);
if (rem != NULL) {
if (err == MP_OKAY) {
/* Shift a down by into result. */
err = sp_rshb(a, e, r);
}
if ((err == MP_OKAY) && (rem != NULL)) {
/* Set used and mask off top digit of remainder. */
rem->used = (e + SP_WORD_SIZE - 1) >> SP_WORD_SHIFT;
e &= SP_WORD_MASK;
@ -11987,13 +12035,16 @@ int sp_div_2d(sp_int* a, int e, sp_int* r, sp_int* rem)
int sp_mod_2d(sp_int* a, int e, sp_int* r)
{
int err = MP_OKAY;
int digits = (e + SP_WORD_SIZE - 1) >> SP_WORD_SHIFT;
if ((a == NULL) || (r == NULL)) {
err = MP_VAL;
}
if ((err == MP_OKAY) && (digits > r->size)) {
err = MP_VAL;
}
if (err == MP_OKAY) {
int digits = (e + SP_WORD_SIZE - 1) >> SP_WORD_SHIFT;
if (a != r) {
XMEMCPY(r->dp, a->dp, digits * sizeof(sp_int_digit));
r->used = a->used;
@ -14633,7 +14684,7 @@ static int _sp_mont_red(sp_int* a, sp_int* m, sp_int_digit mp)
}
sp_clamp(a);
sp_rshb(a, bits, a);
(void)sp_rshb(a, bits, a);
if (_sp_cmp_abs(a, m) != MP_LT) {
_sp_sub_off(a, m, a, 0);
@ -14918,7 +14969,7 @@ static int _sp_mont_red(sp_int* a, sp_int* m, sp_int_digit mp)
}
sp_clamp(a);
sp_rshb(a, bits, a);
(void)sp_rshb(a, bits, a);
if (_sp_cmp_abs(a, m) != MP_LT) {
sp_sub(a, m, a);
@ -15031,7 +15082,9 @@ int sp_mont_norm(sp_int* norm, sp_int* m)
bits = SP_WORD_SIZE;
}
_sp_zero(norm);
sp_set_bit(norm, bits);
err = sp_set_bit(norm, bits);
}
if (err == MP_OKAY) {
err = sp_sub(norm, m, norm);
}
if ((err == MP_OKAY) && (bits == SP_WORD_SIZE)) {
@ -15687,14 +15740,11 @@ int sp_radix_size(sp_int* a, int radix, int* size)
ALLOC_SP_INT(t, a->used + 1, err, NULL);
if (err == MP_OKAY) {
#if defined(WOLFSSL_SMALL_STACK) && !defined(WOLFSSL_SP_NO_MALLOC)
t->size = a->used + 1;
#endif /* WOLFSSL_SMALL_STACK && !WOLFSSL_SP_NO_MALLOC */
err = sp_copy(a, t);
}
if (err == MP_OKAY) {
for (i = 0; !sp_iszero(t); i++) {
sp_div_d(t, 10, t, &d);
}
@ -15745,6 +15795,7 @@ int sp_rand_prime(sp_int* r, int len, WC_RNG* rng, void* heap)
#ifdef WOLFSSL_SP_MATH_ALL
int bits = 0;
#endif /* WOLFSSL_SP_MATH_ALL */
int digits = 0;
(void)heap;
@ -15760,6 +15811,13 @@ int sp_rand_prime(sp_int* r, int len, WC_RNG* rng, void* heap)
len = -len;
}
digits = (len + SP_WORD_SIZEOF - 1) / SP_WORD_SIZEOF;
if (r->size < digits) {
err = MP_VAL;
}
}
if (err == MP_OKAY) {
#ifndef WOLFSSL_SP_MATH_ALL
/* For minimal maths, support only what's in SP and needed for DH. */
#if defined(WOLFSSL_HAVE_SP_DH) && defined(WOLFSSL_KEY_GEN)
@ -15781,7 +15839,7 @@ int sp_rand_prime(sp_int* r, int len, WC_RNG* rng, void* heap)
#ifdef WOLFSSL_SP_INT_NEGATIVE
r->sign = MP_ZPOS;
#endif /* WOLFSSL_SP_INT_NEGATIVE */
r->used = (len + SP_WORD_SIZEOF - 1) / SP_WORD_SIZEOF;
r->used = digits;
#ifdef WOLFSSL_SP_MATH_ALL
bits = (len * 8) & SP_WORD_MASK;
#endif /* WOLFSSL_SP_MATH_ALL */
@ -15875,7 +15933,7 @@ static int sp_prime_miller_rabin_ex(sp_int* a, sp_int* b, int* result,
s = sp_cnt_lsb(r);
/* now divide n - 1 by 2**s */
sp_rshb(r, s, r);
(void)sp_rshb(r, s, r);
/* compute y = b**r mod a */
err = sp_exptmod(b, r, a, y);

View File

@ -147,7 +147,7 @@ int mp_rand(mp_int* a, int digits, WC_RNG* rng)
{
int ret = 0;
int cnt = digits * sizeof(mp_digit);
#if !defined(USE_FAST_MATH) && !defined(WOLFSSL_SP_MATH)
#ifdef USE_INTEGER_HEAP_MATH
int i;
#endif
@ -158,14 +158,14 @@ int mp_rand(mp_int* a, int digits, WC_RNG* rng)
ret = BAD_FUNC_ARG;
}
#if !defined(USE_FAST_MATH) && !defined(WOLFSSL_SP_MATH)
#ifdef USE_INTEGER_HEAP_MATH
/* allocate space for digits */
if (ret == MP_OKAY) {
ret = mp_set_bit(a, digits * DIGIT_BIT - 1);
}
#else
#if defined(WOLFSSL_SP_MATH) || defined(WOLFSSL_SP_MATH_ALL)
if ((ret == MP_OKAY) && (digits > SP_INT_DIGITS))
if ((ret == MP_OKAY) && (digits > a->size))
#else
if ((ret == MP_OKAY) && (digits > FP_SIZE))
#endif
@ -181,7 +181,7 @@ int mp_rand(mp_int* a, int digits, WC_RNG* rng)
ret = wc_RNG_GenerateBlock(rng, (byte*)a->dp, cnt);
}
if (ret == MP_OKAY) {
#if !defined(USE_FAST_MATH) && !defined(WOLFSSL_SP_MATH)
#ifdef USE_INTEGER_HEAP_MATH
/* Mask down each digit to only bits used */
for (i = 0; i < a->used; i++) {
a->dp[i] &= MP_MASK;
@ -190,7 +190,7 @@ int mp_rand(mp_int* a, int digits, WC_RNG* rng)
/* ensure top digit is not zero */
while ((ret == MP_OKAY) && (a->dp[a->used - 1] == 0)) {
ret = get_rand_digit(rng, &a->dp[a->used - 1]);
#if !defined(USE_FAST_MATH) && !defined(WOLFSSL_SP_MATH)
#ifdef USE_INTEGER_HEAP_MATH
a->dp[a->used - 1] &= MP_MASK;
#endif
}

View File

@ -877,7 +877,7 @@ MP_API int sp_addmod_ct (sp_int* a, sp_int* b, sp_int* c, sp_int* d);
MP_API int sp_lshd(sp_int* a, int s);
MP_API void sp_rshd(sp_int* a, int c);
MP_API void sp_rshb(sp_int* a, int n, sp_int* r);
MP_API int sp_rshb(sp_int* a, int n, sp_int* r);
#ifdef WOLFSSL_SP_MATH_ALL
MP_API int sp_div(sp_int* a, sp_int* d, sp_int* r, sp_int* rem);