sp_int: Check size of numbers for overflow

This commit is contained in:
Sean Parkinson
2019-12-11 10:57:09 +10:00
parent dffb59ea52
commit adc14f7552
2 changed files with 91 additions and 50 deletions

View File

@@ -149,12 +149,19 @@ int sp_unsigned_bin_size(sp_int* a)
* a SP integer.
* in Array of bytes.
* inSz Number of data bytes in array.
* returns MP_OKAY always.
* returns BAD_FUNC_ARG when the number is too big to fit in an SP and
MP_OKAY otherwise.
*/
int sp_read_unsigned_bin(sp_int* a, const byte* in, word32 inSz)
int sp_read_unsigned_bin(sp_int* a, const byte* in, int inSz)
{
int err = MP_OKAY;
int i, j = 0, s = 0;
if (inSz > SP_INT_DIGITS * (int)sizeof(a->dp[0])) {
err = MP_VAL;
}
if (err == MP_OKAY) {
a->dp[0] = 0;
for (i = inSz-1; i >= 0; i--) {
a->dp[j] |= ((sp_int_digit)in[i]) << s;
@@ -174,14 +181,12 @@ int sp_read_unsigned_bin(sp_int* a, const byte* in, word32 inSz)
}
a->used = j + 1;
if (a->dp[j] == 0)
a->used--;
sp_clamp(a);
for (j++; j < a->size; j++)
a->dp[j] = 0;
sp_clamp(a);
}
return MP_OKAY;
return err;
}
#ifdef HAVE_ECC
@@ -201,8 +206,9 @@ int sp_read_radix(sp_int* a, const char* in, int radix)
int i, j = 0, k = 0;
char ch;
if ((radix != 16) || (*in == '-'))
if ((radix != 16) || (*in == '-')) {
err = BAD_FUNC_ARG;
}
if (err == MP_OKAY) {
a->dp[0] = 0;
@@ -221,7 +227,11 @@ int sp_read_radix(sp_int* a, const char* in, int radix)
a->dp[k] |= ((sp_int_digit)ch) << j;
j += 4;
if (j == DIGIT_BIT && k < SP_INT_DIGITS)
if (k >= SP_INT_DIGITS - 1) {
err = MP_VAL;
break;
}
if (j == DIGIT_BIT)
a->dp[++k] = 0;
j &= DIGIT_BIT - 1;
}
@@ -1082,12 +1092,17 @@ int sp_mul(sp_int* a, sp_int* b, sp_int* r)
sp_int tr[1];
#endif
if (a->used + b->used > SP_INT_DIGITS)
err = MP_VAL;
#ifdef WOLFSSL_SMALL_STACK
if (err == MP_OKAY) {
t = (sp_int*)XMALLOC(sizeof(sp_int) * 2, NULL, DYNAMIC_TYPE_BIGINT);
if (t == NULL)
err = MP_MEM;
else
tr = &t[1];
}
#endif
if (err == MP_OKAY) {
@@ -1114,13 +1129,17 @@ int sp_mul(sp_int* a, sp_int* b, sp_int* r)
* a SP integer to square.
* m SP integer modulus.
* r SP integer result.
* returns MP_VAL when m is 0, MP_MEM when dynamic memory allocation fails and
* MP_OKAY otherwise.
* returns MP_VAL when m is 0, MP_MEM when dynamic memory allocation fails,
* BAD_FUNC_ARG when a is to big and MP_OKAY otherwise.
*/
static int sp_sqrmod(sp_int* a, sp_int* m, sp_int* r)
{
int err;
int err = MP_OKAY;
if (a->used * 2 > SP_INT_DIGITS)
err = MP_VAL;
if (err == MP_OKAY)
err = sp_mul(a, a, r);
if (err == MP_OKAY)
err = sp_mod(r, m, r);
@@ -1147,11 +1166,16 @@ int sp_mulmod(sp_int* a, sp_int* b, sp_int* m, sp_int* r)
sp_int t[1];
#endif
if (a->used + b->used > SP_INT_DIGITS)
err = MP_VAL;
#ifdef WOLFSSL_SMALL_STACK
if (err == MP_OKAY) {
t = (sp_int*)XMALLOC(sizeof(sp_int), NULL, DYNAMIC_TYPE_BIGINT);
if (t == NULL) {
err = MP_MEM;
}
}
#endif
if (err == MP_OKAY) {
err = sp_mul(a, b, t);
@@ -1364,7 +1388,9 @@ int sp_invmod(sp_int* a, sp_int* m, sp_int* r)
*/
err = sp_invmod(m, a, r);
if (err == MP_OKAY) {
sp_mul(r, m, r);
err = sp_mul(r, m, r);
}
if (err == MP_OKAY) {
sp_sub_d(r, 1, r);
sp_div(r, a, r, NULL);
sp_sub(m, r, r);
@@ -1509,6 +1535,9 @@ int sp_exptmod(sp_int* b, sp_int* e, sp_int* m, sp_int* r)
sp_set(r, 0);
done = 1;
}
else if (m->used * 2 > SP_INT_DIGITS) {
err = BAD_FUNC_ARG;
}
if (!done && (err == MP_OKAY)) {
#ifndef WOLFSSL_SP_NO_2048
@@ -1517,8 +1546,8 @@ int sp_exptmod(sp_int* b, sp_int* e, sp_int* m, sp_int* r)
err = sp_ModExp_1024(b, e, m, r);
done = 1;
}
else if ((mBits == 1024) && sp_isodd(m) && (bBits <= 1024) &&
(eBits <= 1024)) {
else if ((mBits == 2048) && sp_isodd(m) && (bBits <= 2048) &&
(eBits <= 2048)) {
err = sp_ModExp_2048(b, e, m, r);
done = 1;
}
@@ -1549,7 +1578,7 @@ int sp_exptmod(sp_int* b, sp_int* e, sp_int* m, sp_int* r)
#ifdef WOLFSSL_SMALL_STACK
if (!done && (err == MP_OKAY)) {
t = (sp_int*)XMALLOC(sizeof(sp_int) * 2, NULL, DYNAMIC_TYPE_BIGINT);
t = (sp_int*)XMALLOC(sizeof(sp_int), NULL, DYNAMIC_TYPE_BIGINT);
if (t == NULL) {
err = MP_MEM;
}
@@ -1557,16 +1586,28 @@ int sp_exptmod(sp_int* b, sp_int* e, sp_int* m, sp_int* r)
#endif
if (!done && (err == MP_OKAY)) {
sp_init(t);
sp_copy(b, t);
if (sp_cmp(b, m) != MP_LT) {
err = sp_mod(b, m, t);
if (err == MP_OKAY && sp_iszero(t)) {
sp_set(r, 0);
done = 1;
}
}
else {
sp_copy(b, t);
}
if (!done && (err == MP_OKAY)) {
for (i = eBits-2; err == MP_OKAY && i >= 0; i--) {
err = sp_sqrmod(t, m, t);
if (err == MP_OKAY &&
(e->dp[i / SP_WORD_SIZE] >> (i % SP_WORD_SIZE)) & 1) {
if (err == MP_OKAY && (e->dp[i / SP_WORD_SIZE] >>
(i % SP_WORD_SIZE)) & 1) {
err = sp_mulmod(t, b, m, t);
}
}
}
}
if (!done && (err == MP_OKAY)) {
sp_copy(t, r);
}

View File

@@ -161,7 +161,7 @@ MP_API int sp_init_multi(sp_int* a, sp_int* b, sp_int* c, sp_int* d,
sp_int* e, sp_int* f);
MP_API void sp_clear(sp_int* a);
MP_API int sp_unsigned_bin_size(sp_int* a);
MP_API int sp_read_unsigned_bin(sp_int* a, const byte* in, word32 inSz);
MP_API int sp_read_unsigned_bin(sp_int* a, const byte* in, int inSz);
MP_API int sp_read_radix(sp_int* a, const char* in, int radix);
MP_API int sp_cmp(sp_int* a, sp_int* b);
MP_API int sp_count_bits(sp_int* a);