diff --git a/wolfcrypt/src/sp_int.c b/wolfcrypt/src/sp_int.c index 4b760631d..8fe3a0a56 100644 --- a/wolfcrypt/src/sp_int.c +++ b/wolfcrypt/src/sp_int.c @@ -4021,7 +4021,11 @@ int sp_submod(sp_int* a, sp_int* b, sp_int* m, sp_int* r) { #ifndef WOLFSSL_SP_INT_NEGATIVE int err = MP_OKAY; - DECL_SP_INT_ARRAY(t, (m == NULL) ? 1 : m->used + 1, 2); + int used = ((a == NULL) || (b == NULL) || (m == NULL)) ? 1 : + ((a->used >= m->used) ? + ((a->used >= b->used) ? (a->used + 1) : (b->used + 1)) : + ((b->used >= m->used)) ? (b->used + 1) : (m->used + 1)); + DECL_SP_INT_ARRAY(t, used, 2); if ((a == NULL) || (b == NULL) || (m == NULL) || (r == NULL)) { err = MP_VAL; @@ -4033,7 +4037,7 @@ int sp_submod(sp_int* a, sp_int* b, sp_int* m, sp_int* r) sp_print(m, "m"); } - ALLOC_SP_INT_ARRAY(t, m->used + 1, 2, err, NULL); + ALLOC_SP_INT_ARRAY(t, used, 2, err, NULL); if (err == MP_OKAY) { if (_sp_cmp(a, m) == MP_GT) { err = sp_mod(a, m, t[0]); @@ -4443,6 +4447,12 @@ int sp_div(sp_int* a, sp_int* d, sp_int* r, sp_int* rem) if ((err == MP_OKAY) && sp_iszero(d)) { err = MP_VAL; } + if ((err == MP_OKAY) && (r != NULL) && (r->size < a->used - d->used + 2)) { + err = MP_VAL; + } + if ((err == MP_OKAY) && (rem != NULL) && (rem->size < a->used + 1)) { + err = MP_VAL; + } if (0 && (err == MP_OKAY)) { sp_print(a, "a"); @@ -7728,7 +7738,9 @@ int sp_invmod(sp_int* a, sp_int* m, sp_int* r) sp_int* v; sp_int* b; sp_int* c; - DECL_SP_INT_ARRAY(t, (m == NULL) ? 0 : m->used + 1, 4); + int used = ((m == NULL) || (a == NULL)) ? 1 : + ((m->used >= a->used) ? m->used + 1 : a->used + 1); + DECL_SP_INT_ARRAY(t, used, 4); if ((a == NULL) || (m == NULL) || (r == NULL)) { err = MP_VAL; @@ -7746,7 +7758,7 @@ int sp_invmod(sp_int* a, sp_int* m, sp_int* r) v = t[1]; b = t[2]; c = t[3]; - sp_init_size(v, m->used + 1); + sp_init_size(v, used + 1); if (_sp_cmp_abs(a, m) != MP_LT) { err = sp_mod(a, m, v); @@ -13252,7 +13264,7 @@ int sp_gcd(sp_int* a, sp_int* b, sp_int* r) sp_int* u = NULL; sp_int* v = NULL; sp_int* t = NULL; - int used = (a->used >= b->used) ? a->used : b->used; + int used = (a->used >= b->used) ? a->used + 1 : b->used + 1; DECL_SP_INT_ARRAY(d, used, 3); ALLOC_SP_INT_ARRAY(d, used, 3, err, NULL); @@ -13343,7 +13355,7 @@ int sp_lcm(sp_int* a, sp_int* b, sp_int* r) { int err = MP_OKAY; int used = ((a == NULL) || (b == NULL)) ? 1 : - (a->used >= b->used ? a->used : b->used); + (a->used >= b->used ? a->used + 1: b->used + 1); DECL_SP_INT_ARRAY(t, used, 2); if ((a == NULL) || (b == NULL) || (r == NULL)) {