Merge pull request #5399 from SparkiDev/sp_exptmod_reduce

SP int: exptmod ensure base is less than modulus
This commit is contained in:
David Garske
2022-07-27 15:43:16 -07:00
committed by GitHub

View File

@ -7125,7 +7125,7 @@ int sp_div(sp_int* a, sp_int* d, sp_int* r, sp_int* rem)
#if (defined(WOLFSSL_SMALL_STACK) || defined(SP_ALLOC)) && \ #if (defined(WOLFSSL_SMALL_STACK) || defined(SP_ALLOC)) && \
!defined(WOLFSSL_SP_NO_MALLOC) !defined(WOLFSSL_SP_NO_MALLOC)
int cnt = 4; int cnt = 4;
if ((rem != NULL) && (rem != d)) { if ((rem != NULL) && (rem != d) && (rem->size > a->used)) {
cnt--; cnt--;
} }
if ((r != NULL) && (r != d)) { if ((r != NULL) && (r != d)) {
@ -7144,8 +7144,9 @@ int sp_div(sp_int* a, sp_int* d, sp_int* r, sp_int* rem)
#if (defined(WOLFSSL_SMALL_STACK) || defined(SP_ALLOC)) && \ #if (defined(WOLFSSL_SMALL_STACK) || defined(SP_ALLOC)) && \
!defined(WOLFSSL_SP_NO_MALLOC) !defined(WOLFSSL_SP_NO_MALLOC)
i = 2; i = 2;
sa = ((rem != NULL) && (rem != d)) ? rem : td[i++]; sa = ((rem != NULL) && (rem != d) && (rem->size > a->used)) ? rem :
tr = ((r != NULL) && (r != d)) ? r : td[i]; td[i++];
tr = ((r != NULL) && (r != d)) ? r : td[i];
#else #else
sa = td[2]; sa = td[2];
tr = td[3]; tr = td[3];
@ -7155,10 +7156,10 @@ int sp_div(sp_int* a, sp_int* d, sp_int* r, sp_int* rem)
sp_init_size(trial, a->used + 1); sp_init_size(trial, a->used + 1);
#if (defined(WOLFSSL_SMALL_STACK) || defined(SP_ALLOC)) && \ #if (defined(WOLFSSL_SMALL_STACK) || defined(SP_ALLOC)) && \
!defined(WOLFSSL_SP_NO_MALLOC) !defined(WOLFSSL_SP_NO_MALLOC)
if ((rem == NULL) || (rem == d)) { if (sa != rem) {
sp_init_size(sa, a->used + 1); sp_init_size(sa, a->used + 1);
} }
if ((r == NULL) || (r == d)) { if (tr != r) {
sp_init_size(tr, a->used - d->used + 2); sp_init_size(tr, a->used - d->used + 2);
} }
#else #else
@ -11309,31 +11310,47 @@ int sp_exptmod_ex(sp_int* b, sp_int* e, int digits, sp_int* m, sp_int* r)
} }
#endif #endif
if (err != MP_OKAY) { /* Check for invalid modulus. */
} if ((err == MP_OKAY) && sp_iszero(m)) {
/* Handle special cases. */
else if (sp_iszero(m)) {
err = MP_VAL; err = MP_VAL;
} }
#ifdef WOLFSSL_SP_INT_NEGATIVE #ifdef WOLFSSL_SP_INT_NEGATIVE
else if ((e->sign == MP_NEG) || (m->sign == MP_NEG)) { /* Check for unsupported negative values of exponent and modulus. */
if ((err == MP_OKAY) && ((e->sign == MP_NEG) || (m->sign == MP_NEG))) {
err = MP_VAL; err = MP_VAL;
} }
#endif #endif
else if (sp_isone(m)) {
/* Check for degenerate cases. */
if ((err == MP_OKAY) && sp_isone(m)) {
sp_set(r, 0); sp_set(r, 0);
done = 1; done = 1;
} }
else if (sp_iszero(e)) { if ((!done) && (err == MP_OKAY) && sp_iszero(e)) {
sp_set(r, 1); sp_set(r, 1);
done = 1; done = 1;
} }
else if (sp_iszero(b)) {
/* Check whether base needs to be reduced. */
if ((!done) && (err == MP_OKAY) && (_sp_cmp_abs(b, m) != MP_LT)) {
if ((r == e) || (r == m)) {
err = MP_VAL;
}
if (err == MP_OKAY) {
err = sp_mod(b, m, r);
}
if (err == MP_OKAY) {
b = r;
}
}
/* Check for degenerate case of base. */
if ((!done) && (err == MP_OKAY) && sp_iszero(b)) {
sp_set(r, 0); sp_set(r, 0);
done = 1; done = 1;
} }
/* Ensure SP integers have space for intermediate values. */ /* Ensure SP integers have space for intermediate values. */
else if (m->used * 2 >= r->size) { if ((!done) && (err == MP_OKAY) && (m->used * 2 >= r->size)) {
err = MP_VAL; err = MP_VAL;
} }