ECC uses CT vers of addmod, submod and div_2_mod

The TFM implementations of mp_submod_ct, mp_addmod_ct,
mp_div_2_mod_t are more resilient to side-channels.
This commit is contained in:
Sean Parkinson
2020-05-19 16:30:11 +10:00
parent 4f30e37094
commit 9ef9671886
5 changed files with 175 additions and 125 deletions

View File

@ -1728,46 +1728,22 @@ int ecc_projective_add_point(ecc_point* P, ecc_point* Q, ecc_point* R,
/* Y = Y - T1 */
if (err == MP_OKAY)
err = mp_sub(y, t1, y);
if (err == MP_OKAY) {
if (mp_isneg(y))
err = mp_add(y, modulus, y);
}
err = mp_submod_ct(y, t1, modulus, y);
/* T1 = 2T1 */
if (err == MP_OKAY)
err = mp_add(t1, t1, t1);
if (err == MP_OKAY) {
if (mp_cmp(t1, modulus) != MP_LT)
err = mp_sub(t1, modulus, t1);
}
err = mp_addmod_ct(t1, t1, modulus, t1);
/* T1 = Y + T1 */
if (err == MP_OKAY)
err = mp_add(t1, y, t1);
if (err == MP_OKAY) {
if (mp_cmp(t1, modulus) != MP_LT)
err = mp_sub(t1, modulus, t1);
}
err = mp_addmod_ct(t1, y, modulus, t1);
/* X = X - T2 */
if (err == MP_OKAY)
err = mp_sub(x, t2, x);
if (err == MP_OKAY) {
if (mp_isneg(x))
err = mp_add(x, modulus, x);
}
err = mp_submod_ct(x, t2, modulus, x);
/* T2 = 2T2 */
if (err == MP_OKAY)
err = mp_add(t2, t2, t2);
if (err == MP_OKAY) {
if (mp_cmp(t2, modulus) != MP_LT)
err = mp_sub(t2, modulus, t2);
}
err = mp_addmod_ct(t2, t2, modulus, t2);
/* T2 = X + T2 */
if (err == MP_OKAY)
err = mp_add(t2, x, t2);
if (err == MP_OKAY) {
if (mp_cmp(t2, modulus) != MP_LT)
err = mp_sub(t2, modulus, t2);
}
err = mp_addmod_ct(t2, x, modulus, t2);
if (err == MP_OKAY) {
if (!mp_iszero(Q->z)) {
@ -1816,25 +1792,13 @@ int ecc_projective_add_point(ecc_point* P, ecc_point* Q, ecc_point* R,
/* X = X - T2 */
if (err == MP_OKAY)
err = mp_sub(x, t2, x);
if (err == MP_OKAY) {
if (mp_isneg(x))
err = mp_add(x, modulus, x);
}
err = mp_submod_ct(x, t2, modulus, x);
/* T2 = T2 - X */
if (err == MP_OKAY)
err = mp_sub(t2, x, t2);
if (err == MP_OKAY) {
if (mp_isneg(t2))
err = mp_add(t2, modulus, t2);
}
err = mp_submod_ct(t2, x, modulus, t2);
/* T2 = T2 - X */
if (err == MP_OKAY)
err = mp_sub(t2, x, t2);
if (err == MP_OKAY) {
if (mp_isneg(t2))
err = mp_add(t2, modulus, t2);
}
err = mp_submod_ct(t2, x, modulus, t2);
/* T2 = T2 * Y */
if (err == MP_OKAY)
err = mp_mul(t2, y, t2);
@ -1843,18 +1807,10 @@ int ecc_projective_add_point(ecc_point* P, ecc_point* Q, ecc_point* R,
/* Y = T2 - T1 */
if (err == MP_OKAY)
err = mp_sub(t2, t1, y);
if (err == MP_OKAY) {
if (mp_isneg(y))
err = mp_add(y, modulus, y);
}
err = mp_submod_ct(t2, t1, modulus, y);
/* Y = Y/2 */
if (err == MP_OKAY) {
if (mp_isodd(y) == MP_YES)
err = mp_add(y, modulus, y);
}
if (err == MP_OKAY)
err = mp_div_2(y, y);
err = mp_div_2_mod_ct(y, modulus, y);
#ifdef ALT_ECC_SIZE
if (err == MP_OKAY)
@ -2071,11 +2027,7 @@ int ecc_projective_dbl_point(ecc_point *P, ecc_point *R, mp_int* a,
/* Z = 2Z */
if (err == MP_OKAY)
err = mp_add(z, z, z);
if (err == MP_OKAY) {
if (mp_cmp(z, modulus) != MP_LT)
err = mp_sub(z, modulus, z);
}
err = mp_addmod_ct(z, z, modulus, z);
/* Determine if curve "a" should be used in calc */
#ifdef WOLFSSL_CUSTOM_CURVES
@ -2101,25 +2053,13 @@ int ecc_projective_dbl_point(ecc_point *P, ecc_point *R, mp_int* a,
err = mp_montgomery_reduce(t2, modulus, mp);
/* T1 = T2 + T1 */
if (err == MP_OKAY)
err = mp_add(t1, t2, t1);
if (err == MP_OKAY) {
if (mp_cmp(t1, modulus) != MP_LT)
err = mp_sub(t1, modulus, t1);
}
err = mp_addmod_ct(t1, t2, modulus, t1);
/* T1 = T2 + T1 */
if (err == MP_OKAY)
err = mp_add(t1, t2, t1);
if (err == MP_OKAY) {
if (mp_cmp(t1, modulus) != MP_LT)
err = mp_sub(t1, modulus, t1);
}
err = mp_addmod_ct(t1, t2, modulus, t1);
/* T1 = T2 + T1 */
if (err == MP_OKAY)
err = mp_add(t1, t2, t1);
if (err == MP_OKAY) {
if (mp_cmp(t1, modulus) != MP_LT)
err = mp_sub(t1, modulus, t1);
}
err = mp_addmod_ct(t1, t2, modulus, t1);
}
else
#endif /* WOLFSSL_CUSTOM_CURVES */
@ -2129,18 +2069,10 @@ int ecc_projective_dbl_point(ecc_point *P, ecc_point *R, mp_int* a,
/* T2 = X - T1 */
if (err == MP_OKAY)
err = mp_sub(x, t1, t2);
if (err == MP_OKAY) {
if (mp_isneg(t2))
err = mp_add(t2, modulus, t2);
}
err = mp_submod_ct(x, t1, modulus, t2);
/* T1 = X + T1 */
if (err == MP_OKAY)
err = mp_add(t1, x, t1);
if (err == MP_OKAY) {
if (mp_cmp(t1, modulus) != MP_LT)
err = mp_sub(t1, modulus, t1);
}
err = mp_addmod_ct(t1, x, modulus, t1);
/* T2 = T1 * T2 */
if (err == MP_OKAY)
err = mp_mul(t1, t2, t2);
@ -2149,27 +2081,15 @@ int ecc_projective_dbl_point(ecc_point *P, ecc_point *R, mp_int* a,
/* T1 = 2T2 */
if (err == MP_OKAY)
err = mp_add(t2, t2, t1);
if (err == MP_OKAY) {
if (mp_cmp(t1, modulus) != MP_LT)
err = mp_sub(t1, modulus, t1);
}
err = mp_addmod_ct(t2, t2, modulus, t1);
/* T1 = T1 + T2 */
if (err == MP_OKAY)
err = mp_add(t1, t2, t1);
if (err == MP_OKAY) {
if (mp_cmp(t1, modulus) != MP_LT)
err = mp_sub(t1, modulus, t1);
}
err = mp_addmod_ct(t1, t2, modulus, t1);
}
/* Y = 2Y */
if (err == MP_OKAY)
err = mp_add(y, y, y);
if (err == MP_OKAY) {
if (mp_cmp(y, modulus) != MP_LT)
err = mp_sub(y, modulus, y);
}
err = mp_addmod_ct(y, y, modulus, y);
/* Y = Y * Y */
if (err == MP_OKAY)
err = mp_sqr(y, y);
@ -2183,12 +2103,8 @@ int ecc_projective_dbl_point(ecc_point *P, ecc_point *R, mp_int* a,
err = mp_montgomery_reduce(t2, modulus, mp);
/* T2 = T2/2 */
if (err == MP_OKAY) {
if (mp_isodd(t2) == MP_YES)
err = mp_add(t2, modulus, t2);
}
if (err == MP_OKAY)
err = mp_div_2(t2, t2);
err = mp_div_2_mod_ct(t2, modulus, t2);
/* Y = Y * X */
if (err == MP_OKAY)
@ -2204,26 +2120,14 @@ int ecc_projective_dbl_point(ecc_point *P, ecc_point *R, mp_int* a,
/* X = X - Y */
if (err == MP_OKAY)
err = mp_sub(x, y, x);
if (err == MP_OKAY) {
if (mp_isneg(x))
err = mp_add(x, modulus, x);
}
err = mp_submod_ct(x, y, modulus, x);
/* X = X - Y */
if (err == MP_OKAY)
err = mp_sub(x, y, x);
if (err == MP_OKAY) {
if (mp_isneg(x))
err = mp_add(x, modulus, x);
}
err = mp_submod_ct(x, y, modulus, x);
/* Y = Y - X */
if (err == MP_OKAY)
err = mp_sub(y, x, y);
if (err == MP_OKAY) {
if (mp_isneg(y))
err = mp_add(y, modulus, y);
}
err = mp_submod_ct(y, x, modulus, y);
/* Y = Y * T1 */
if (err == MP_OKAY)
err = mp_mul(y, t1, y);
@ -2232,11 +2136,7 @@ int ecc_projective_dbl_point(ecc_point *P, ecc_point *R, mp_int* a,
/* Y = Y - T2 */
if (err == MP_OKAY)
err = mp_sub(y, t2, y);
if (err == MP_OKAY) {
if (mp_isneg(y))
err = mp_add(y, modulus, y);
}
err = mp_submod_ct(y, t2, modulus, y);
#ifdef ALT_ECC_SIZE
if (err == MP_OKAY)

View File

@ -1577,6 +1577,24 @@ int mp_div_2(mp_int * a, mp_int * b)
return MP_OKAY;
}
/* c = a / 2 (mod b) - constant time (a < b and positive) */
int mp_div_2_mod_ct(mp_int *a, mp_int *b, mp_int *c)
{
int res;
if (mp_isodd(a)) {
res = mp_add(a, b, c);
if (res == MP_OKAY) {
res = mp_div_2(c, c);
}
}
else {
res = mp_div_2(a, c);
}
return res;
}
/* high level addition (handles signs) */
int mp_add (mp_int * a, mp_int * b, mp_int * c)
@ -2994,6 +3012,32 @@ int mp_addmod(mp_int* a, mp_int* b, mp_int* c, mp_int* d)
return res;
}
/* d = a - b (mod c) - a < c and b < c and positive */
int mp_submod_ct(mp_int* a, mp_int* b, mp_int* c, mp_int* d)
{
int res;
res = mp_sub(a, b, d);
if (res == MP_OKAY && mp_isneg(d)) {
res = mp_add(d, c, d);
}
return res;
}
/* d = a + b (mod c) - a < c and b < c and positive */
int mp_addmod_ct(mp_int* a, mp_int* b, mp_int* c, mp_int* d)
{
int res;
res = mp_add(a, b, d);
if (res == MP_OKAY && mp_cmp(d, c) != MP_LT) {
res = mp_sub(d, c, d);
}
return res;
}
/* computes b = a*a */
int mp_sqr (mp_int * a, mp_int * b)
{

View File

@ -878,6 +878,31 @@ void fp_div_2(fp_int * a, fp_int * b)
fp_clamp (b);
}
/* c = a / 2 (mod b) - constant time (a < b and positive) */
int fp_div_2_mod_ct(fp_int *a, fp_int *b, fp_int *c)
{
fp_word w = 0;
fp_digit mask;
int i;
mask = 0 - (a->dp[0] & 1);
for (i = 0; i < b->used; i++) {
fp_digit mask_a = 0 - (i < a->used);
w += b->dp[i] & mask;
w += a->dp[i] & mask_a;
c->dp[i] = (fp_digit)w;
w >>= DIGIT_BIT;
}
c->dp[i] = (fp_digit)w;
c->used = i + 1;
c->sign = FP_ZPOS;
fp_clamp(c);
fp_div_2(c, c);
return FP_OKAY;
}
/* c = a / 2**b */
void fp_div_2d(fp_int *a, int b, fp_int *c, fp_int *d)
{
@ -1546,6 +1571,54 @@ int fp_addmod(fp_int *a, fp_int *b, fp_int *c, fp_int *d)
return err;
}
/* d = a - b (mod c) - constant time (a < c and b < c and positive) */
int fp_submod_ct(fp_int *a, fp_int *b, fp_int *c, fp_int *d)
{
fp_word w = 0;
fp_digit mask;
int i;
mask = 0 - (fp_cmp_mag(a, b) == FP_LT);
for (i = 0; i < c->used; i++) {
fp_digit mask_a = 0 - (i < a->used);
w += c->dp[i] & mask;
w += a->dp[i] & mask_a;
d->dp[i] = (fp_digit)w;
w >>= DIGIT_BIT;
}
d->dp[i] = (fp_digit)w;
d->used = i + 1;
d->sign = FP_ZPOS;
fp_clamp(d);
fp_sub(d, b, d);
return FP_OKAY;
}
/* d = a + b (mod c) - constant time (|a| < c and |b| < c and positive) */
int fp_addmod_ct(fp_int *a, fp_int *b, fp_int *c, fp_int *d)
{
fp_word w = 0;
fp_digit mask;
int i;
fp_add(a, b, d);
mask = 0 - (fp_cmp_mag(d, c) != FP_LT);
for (i = 0; i < c->used; i++) {
w += c->dp[i] & mask;
w = d->dp[i] - w;
d->dp[i] = (fp_digit)w;
w = (w >> DIGIT_BIT)&1;
}
d->dp[i] = 0;
d->used = i;
d->sign = a->sign;
fp_clamp(d);
return FP_OKAY;
}
#ifdef TFM_TIMING_RESISTANT
#ifdef WC_RSA_NONBLOCK
@ -4007,6 +4080,18 @@ int mp_addmod(mp_int *a, mp_int *b, mp_int *c, mp_int *d)
return fp_addmod(a, b, c, d);
}
/* d = a - b (mod c) - constant time (a < c and b < c) */
int mp_submod_ct(mp_int *a, mp_int *b, mp_int *c, mp_int *d)
{
return fp_submod_ct(a, b, c, d);
}
/* d = a + b (mod c) - constant time (a < c and b < c) */
int mp_addmod_ct(mp_int *a, mp_int *b, mp_int *c, mp_int *d)
{
return fp_addmod_ct(a, b, c, d);
}
/* c = a mod b, 0 <= c < b */
#if defined(FREESCALE_LTC_TFM)
int wolfcrypt_mp_mod (mp_int * a, mp_int * b, mp_int * c)
@ -5234,6 +5319,12 @@ int mp_div_2(fp_int * a, fp_int * b)
return MP_OKAY;
}
/* c = a / 2 (mod b) - constant time (a < b and positive) */
int mp_div_2_mod_ct(mp_int *a, mp_int *b, mp_int *c)
{
return fp_div_2_mod_ct(a, b, c);
}
int mp_init_copy(fp_int * a, fp_int * b)
{

View File

@ -318,6 +318,7 @@ MP_API int mp_is_bit_set (mp_int * a, mp_digit b);
MP_API int mp_mod (mp_int * a, mp_int * b, mp_int * c);
MP_API int mp_div(mp_int * a, mp_int * b, mp_int * c, mp_int * d);
MP_API int mp_div_2(mp_int * a, mp_int * b);
MP_API int mp_div_2_mod_ct (mp_int* a, mp_int* b, mp_int* c);
MP_API int mp_add (mp_int * a, mp_int * b, mp_int * c);
int s_mp_add (mp_int * a, mp_int * b, mp_int * c);
int s_mp_sub (mp_int * a, mp_int * b, mp_int * c);
@ -355,6 +356,8 @@ MP_API int mp_sqr (mp_int * a, mp_int * b);
MP_API int mp_mulmod (mp_int * a, mp_int * b, mp_int * c, mp_int * d);
MP_API int mp_submod (mp_int* a, mp_int* b, mp_int* c, mp_int* d);
MP_API int mp_addmod (mp_int* a, mp_int* b, mp_int* c, mp_int* d);
MP_API int mp_submod_ct (mp_int* a, mp_int* b, mp_int* c, mp_int* d);
MP_API int mp_addmod_ct (mp_int* a, mp_int* b, mp_int* c, mp_int* d);
MP_API int mp_mul_d (mp_int * a, mp_digit b, mp_int * c);
MP_API int mp_2expt (mp_int * a, int b);
MP_API int mp_set_bit (mp_int * a, int b);

View File

@ -474,6 +474,9 @@ int fp_mul_2d(fp_int *a, int b, fp_int *c);
void fp_2expt (fp_int *a, int b);
int fp_mul_2(fp_int *a, fp_int *c);
void fp_div_2(fp_int *a, fp_int *c);
/* c = a / 2 (mod b) - constant time (a < b and positive) */
int fp_div_2_mod_ct(fp_int *a, fp_int *b, fp_int *c);
/* Counts the number of lsbs which are zero before the first zero bit */
int fp_cnt_lsb(fp_int *a);
@ -530,6 +533,12 @@ int fp_submod(fp_int *a, fp_int *b, fp_int *c, fp_int *d);
/* d = a + b (mod c) */
int fp_addmod(fp_int *a, fp_int *b, fp_int *c, fp_int *d);
/* d = a - b (mod c) - constant time (a < c and b < c) */
int fp_submod_ct(fp_int *a, fp_int *b, fp_int *c, fp_int *d);
/* d = a + b (mod c) - constant time (a < c and b < c) */
int fp_addmod_ct(fp_int *a, fp_int *b, fp_int *c, fp_int *d);
/* c = a * a (mod b) */
int fp_sqrmod(fp_int *a, fp_int *b, fp_int *c);
@ -743,6 +752,8 @@ MP_API int mp_mul_d (mp_int * a, mp_digit b, mp_int * c);
MP_API int mp_mulmod (mp_int * a, mp_int * b, mp_int * c, mp_int * d);
MP_API int mp_submod (mp_int* a, mp_int* b, mp_int* c, mp_int* d);
MP_API int mp_addmod (mp_int* a, mp_int* b, mp_int* c, mp_int* d);
MP_API int mp_submod_ct (mp_int* a, mp_int* b, mp_int* c, mp_int* d);
MP_API int mp_addmod_ct (mp_int* a, mp_int* b, mp_int* c, mp_int* d);
MP_API int mp_mod(mp_int *a, mp_int *b, mp_int *c);
MP_API int mp_invmod(mp_int *a, mp_int *b, mp_int *c);
MP_API int mp_invmod_mont_ct(mp_int *a, mp_int *b, mp_int *c, fp_digit mp);
@ -793,6 +804,7 @@ MP_API int mp_radix_size (mp_int * a, int radix, int *size);
MP_API int mp_montgomery_reduce(fp_int *a, fp_int *m, fp_digit mp);
MP_API int mp_montgomery_setup(fp_int *a, fp_digit *rho);
MP_API int mp_div_2(fp_int * a, fp_int * b);
MP_API int mp_div_2_mod_ct(mp_int *a, mp_int *b, mp_int *c);
MP_API int mp_init_copy(fp_int * a, fp_int * b);
#endif