sp_int: Check size of numbers for overflow

This commit is contained in:
Sean Parkinson
2019-12-11 10:57:09 +10:00
parent dffb59ea52
commit adc14f7552
2 changed files with 91 additions and 50 deletions

View File

@@ -149,12 +149,19 @@ int sp_unsigned_bin_size(sp_int* a)
* a SP integer. * a SP integer.
* in Array of bytes. * in Array of bytes.
* inSz Number of data bytes in array. * inSz Number of data bytes in array.
* returns MP_OKAY always. * returns BAD_FUNC_ARG when the number is too big to fit in an SP and
MP_OKAY otherwise.
*/ */
int sp_read_unsigned_bin(sp_int* a, const byte* in, word32 inSz) int sp_read_unsigned_bin(sp_int* a, const byte* in, int inSz)
{ {
int err = MP_OKAY;
int i, j = 0, s = 0; int i, j = 0, s = 0;
if (inSz > SP_INT_DIGITS * (int)sizeof(a->dp[0])) {
err = MP_VAL;
}
if (err == MP_OKAY) {
a->dp[0] = 0; a->dp[0] = 0;
for (i = inSz-1; i >= 0; i--) { for (i = inSz-1; i >= 0; i--) {
a->dp[j] |= ((sp_int_digit)in[i]) << s; a->dp[j] |= ((sp_int_digit)in[i]) << s;
@@ -174,14 +181,12 @@ int sp_read_unsigned_bin(sp_int* a, const byte* in, word32 inSz)
} }
a->used = j + 1; a->used = j + 1;
if (a->dp[j] == 0) sp_clamp(a);
a->used--;
for (j++; j < a->size; j++) for (j++; j < a->size; j++)
a->dp[j] = 0; a->dp[j] = 0;
sp_clamp(a); }
return MP_OKAY; return err;
} }
#ifdef HAVE_ECC #ifdef HAVE_ECC
@@ -201,8 +206,9 @@ int sp_read_radix(sp_int* a, const char* in, int radix)
int i, j = 0, k = 0; int i, j = 0, k = 0;
char ch; char ch;
if ((radix != 16) || (*in == '-')) if ((radix != 16) || (*in == '-')) {
err = BAD_FUNC_ARG; err = BAD_FUNC_ARG;
}
if (err == MP_OKAY) { if (err == MP_OKAY) {
a->dp[0] = 0; a->dp[0] = 0;
@@ -221,7 +227,11 @@ int sp_read_radix(sp_int* a, const char* in, int radix)
a->dp[k] |= ((sp_int_digit)ch) << j; a->dp[k] |= ((sp_int_digit)ch) << j;
j += 4; j += 4;
if (j == DIGIT_BIT && k < SP_INT_DIGITS) if (k >= SP_INT_DIGITS - 1) {
err = MP_VAL;
break;
}
if (j == DIGIT_BIT)
a->dp[++k] = 0; a->dp[++k] = 0;
j &= DIGIT_BIT - 1; j &= DIGIT_BIT - 1;
} }
@@ -1082,12 +1092,17 @@ int sp_mul(sp_int* a, sp_int* b, sp_int* r)
sp_int tr[1]; sp_int tr[1];
#endif #endif
if (a->used + b->used > SP_INT_DIGITS)
err = MP_VAL;
#ifdef WOLFSSL_SMALL_STACK #ifdef WOLFSSL_SMALL_STACK
if (err == MP_OKAY) {
t = (sp_int*)XMALLOC(sizeof(sp_int) * 2, NULL, DYNAMIC_TYPE_BIGINT); t = (sp_int*)XMALLOC(sizeof(sp_int) * 2, NULL, DYNAMIC_TYPE_BIGINT);
if (t == NULL) if (t == NULL)
err = MP_MEM; err = MP_MEM;
else else
tr = &t[1]; tr = &t[1];
}
#endif #endif
if (err == MP_OKAY) { if (err == MP_OKAY) {
@@ -1114,13 +1129,17 @@ int sp_mul(sp_int* a, sp_int* b, sp_int* r)
* a SP integer to square. * a SP integer to square.
* m SP integer modulus. * m SP integer modulus.
* r SP integer result. * r SP integer result.
* returns MP_VAL when m is 0, MP_MEM when dynamic memory allocation fails and * returns MP_VAL when m is 0, MP_MEM when dynamic memory allocation fails,
* MP_OKAY otherwise. * BAD_FUNC_ARG when a is to big and MP_OKAY otherwise.
*/ */
static int sp_sqrmod(sp_int* a, sp_int* m, sp_int* r) static int sp_sqrmod(sp_int* a, sp_int* m, sp_int* r)
{ {
int err; int err = MP_OKAY;
if (a->used * 2 > SP_INT_DIGITS)
err = MP_VAL;
if (err == MP_OKAY)
err = sp_mul(a, a, r); err = sp_mul(a, a, r);
if (err == MP_OKAY) if (err == MP_OKAY)
err = sp_mod(r, m, r); err = sp_mod(r, m, r);
@@ -1147,11 +1166,16 @@ int sp_mulmod(sp_int* a, sp_int* b, sp_int* m, sp_int* r)
sp_int t[1]; sp_int t[1];
#endif #endif
if (a->used + b->used > SP_INT_DIGITS)
err = MP_VAL;
#ifdef WOLFSSL_SMALL_STACK #ifdef WOLFSSL_SMALL_STACK
if (err == MP_OKAY) {
t = (sp_int*)XMALLOC(sizeof(sp_int), NULL, DYNAMIC_TYPE_BIGINT); t = (sp_int*)XMALLOC(sizeof(sp_int), NULL, DYNAMIC_TYPE_BIGINT);
if (t == NULL) { if (t == NULL) {
err = MP_MEM; err = MP_MEM;
} }
}
#endif #endif
if (err == MP_OKAY) { if (err == MP_OKAY) {
err = sp_mul(a, b, t); err = sp_mul(a, b, t);
@@ -1364,7 +1388,9 @@ int sp_invmod(sp_int* a, sp_int* m, sp_int* r)
*/ */
err = sp_invmod(m, a, r); err = sp_invmod(m, a, r);
if (err == MP_OKAY) { if (err == MP_OKAY) {
sp_mul(r, m, r); err = sp_mul(r, m, r);
}
if (err == MP_OKAY) {
sp_sub_d(r, 1, r); sp_sub_d(r, 1, r);
sp_div(r, a, r, NULL); sp_div(r, a, r, NULL);
sp_sub(m, r, r); sp_sub(m, r, r);
@@ -1509,6 +1535,9 @@ int sp_exptmod(sp_int* b, sp_int* e, sp_int* m, sp_int* r)
sp_set(r, 0); sp_set(r, 0);
done = 1; done = 1;
} }
else if (m->used * 2 > SP_INT_DIGITS) {
err = BAD_FUNC_ARG;
}
if (!done && (err == MP_OKAY)) { if (!done && (err == MP_OKAY)) {
#ifndef WOLFSSL_SP_NO_2048 #ifndef WOLFSSL_SP_NO_2048
@@ -1517,8 +1546,8 @@ int sp_exptmod(sp_int* b, sp_int* e, sp_int* m, sp_int* r)
err = sp_ModExp_1024(b, e, m, r); err = sp_ModExp_1024(b, e, m, r);
done = 1; done = 1;
} }
else if ((mBits == 1024) && sp_isodd(m) && (bBits <= 1024) && else if ((mBits == 2048) && sp_isodd(m) && (bBits <= 2048) &&
(eBits <= 1024)) { (eBits <= 2048)) {
err = sp_ModExp_2048(b, e, m, r); err = sp_ModExp_2048(b, e, m, r);
done = 1; done = 1;
} }
@@ -1549,7 +1578,7 @@ int sp_exptmod(sp_int* b, sp_int* e, sp_int* m, sp_int* r)
#ifdef WOLFSSL_SMALL_STACK #ifdef WOLFSSL_SMALL_STACK
if (!done && (err == MP_OKAY)) { if (!done && (err == MP_OKAY)) {
t = (sp_int*)XMALLOC(sizeof(sp_int) * 2, NULL, DYNAMIC_TYPE_BIGINT); t = (sp_int*)XMALLOC(sizeof(sp_int), NULL, DYNAMIC_TYPE_BIGINT);
if (t == NULL) { if (t == NULL) {
err = MP_MEM; err = MP_MEM;
} }
@@ -1557,16 +1586,28 @@ int sp_exptmod(sp_int* b, sp_int* e, sp_int* m, sp_int* r)
#endif #endif
if (!done && (err == MP_OKAY)) { if (!done && (err == MP_OKAY)) {
sp_init(t); sp_init(t);
sp_copy(b, t);
if (sp_cmp(b, m) != MP_LT) {
err = sp_mod(b, m, t);
if (err == MP_OKAY && sp_iszero(t)) {
sp_set(r, 0);
done = 1;
}
}
else {
sp_copy(b, t);
}
if (!done && (err == MP_OKAY)) {
for (i = eBits-2; err == MP_OKAY && i >= 0; i--) { for (i = eBits-2; err == MP_OKAY && i >= 0; i--) {
err = sp_sqrmod(t, m, t); err = sp_sqrmod(t, m, t);
if (err == MP_OKAY && if (err == MP_OKAY && (e->dp[i / SP_WORD_SIZE] >>
(e->dp[i / SP_WORD_SIZE] >> (i % SP_WORD_SIZE)) & 1) { (i % SP_WORD_SIZE)) & 1) {
err = sp_mulmod(t, b, m, t); err = sp_mulmod(t, b, m, t);
} }
} }
} }
}
if (!done && (err == MP_OKAY)) { if (!done && (err == MP_OKAY)) {
sp_copy(t, r); sp_copy(t, r);
} }

View File

@@ -161,7 +161,7 @@ MP_API int sp_init_multi(sp_int* a, sp_int* b, sp_int* c, sp_int* d,
sp_int* e, sp_int* f); sp_int* e, sp_int* f);
MP_API void sp_clear(sp_int* a); MP_API void sp_clear(sp_int* a);
MP_API int sp_unsigned_bin_size(sp_int* a); MP_API int sp_unsigned_bin_size(sp_int* a);
MP_API int sp_read_unsigned_bin(sp_int* a, const byte* in, word32 inSz); MP_API int sp_read_unsigned_bin(sp_int* a, const byte* in, int inSz);
MP_API int sp_read_radix(sp_int* a, const char* in, int radix); MP_API int sp_read_radix(sp_int* a, const char* in, int radix);
MP_API int sp_cmp(sp_int* a, sp_int* b); MP_API int sp_cmp(sp_int* a, sp_int* b);
MP_API int sp_count_bits(sp_int* a); MP_API int sp_count_bits(sp_int* a);