forked from wolfSSL/wolfssl
Heap math: mp_add/submod_ct make work when c == d
mp_addmod_ct and mp_submod_ct expected c and d to be different pointers. Change code to support this use case. Fix whitespace.
This commit is contained in:
@ -3068,47 +3068,83 @@ int mp_submod(mp_int* a, mp_int* b, mp_int* c, mp_int* d)
|
|||||||
/* d = a + b (mod c) */
|
/* d = a + b (mod c) */
|
||||||
int mp_addmod(mp_int* a, mp_int* b, mp_int* c, mp_int* d)
|
int mp_addmod(mp_int* a, mp_int* b, mp_int* c, mp_int* d)
|
||||||
{
|
{
|
||||||
int res;
|
int res;
|
||||||
mp_int t;
|
mp_int t;
|
||||||
|
|
||||||
if ((res = mp_init (&t)) != MP_OKAY) {
|
if ((res = mp_init (&t)) != MP_OKAY) {
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
res = mp_add (a, b, &t);
|
res = mp_add (a, b, &t);
|
||||||
if (res == MP_OKAY) {
|
if (res == MP_OKAY) {
|
||||||
res = mp_mod (&t, c, d);
|
res = mp_mod (&t, c, d);
|
||||||
}
|
}
|
||||||
|
|
||||||
mp_clear (&t);
|
mp_clear (&t);
|
||||||
|
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* d = a - b (mod c) - a < c and b < c and positive */
|
/* d = a - b (mod c) - a < c and b < c and positive */
|
||||||
int mp_submod_ct(mp_int* a, mp_int* b, mp_int* c, mp_int* d)
|
int mp_submod_ct(mp_int* a, mp_int* b, mp_int* c, mp_int* d)
|
||||||
{
|
{
|
||||||
int res;
|
int res;
|
||||||
|
mp_int t;
|
||||||
|
mp_int* r = d;
|
||||||
|
|
||||||
res = mp_sub(a, b, d);
|
if (c == d) {
|
||||||
if (res == MP_OKAY && mp_isneg(d)) {
|
r = &t;
|
||||||
res = mp_add(d, c, d);
|
|
||||||
|
if ((res = mp_init (r)) != MP_OKAY) {
|
||||||
|
return res;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return res;
|
res = mp_sub (a, b, r);
|
||||||
|
if (res == MP_OKAY) {
|
||||||
|
if (mp_isneg (r)) {
|
||||||
|
res = mp_add (r, c, d);
|
||||||
|
} else if (c == d) {
|
||||||
|
res = mp_copy (r, d);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (c == d) {
|
||||||
|
mp_clear (r);
|
||||||
|
}
|
||||||
|
|
||||||
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* d = a + b (mod c) - a < c and b < c and positive */
|
/* d = a + b (mod c) - a < c and b < c and positive */
|
||||||
int mp_addmod_ct(mp_int* a, mp_int* b, mp_int* c, mp_int* d)
|
int mp_addmod_ct(mp_int* a, mp_int* b, mp_int* c, mp_int* d)
|
||||||
{
|
{
|
||||||
int res;
|
int res;
|
||||||
|
mp_int t;
|
||||||
|
mp_int* r = d;
|
||||||
|
|
||||||
res = mp_add(a, b, d);
|
if (c == d) {
|
||||||
if (res == MP_OKAY && mp_cmp(d, c) != MP_LT) {
|
r = &t;
|
||||||
res = mp_sub(d, c, d);
|
|
||||||
|
if ((res = mp_init (r)) != MP_OKAY) {
|
||||||
|
return res;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return res;
|
res = mp_add (a, b, r);
|
||||||
|
if (res == MP_OKAY) {
|
||||||
|
if (mp_cmp (r, c) != MP_LT) {
|
||||||
|
res = mp_sub (r, c, d);
|
||||||
|
} else if (c == d) {
|
||||||
|
res = mp_copy (r, d);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (c == d) {
|
||||||
|
mp_clear (r);
|
||||||
|
}
|
||||||
|
|
||||||
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* computes b = a*a */
|
/* computes b = a*a */
|
||||||
|
Reference in New Issue
Block a user