Merge pull request #2859 from SparkiDev/tfm_ec_invmod_ct

Constant time EC map to affine for private operations
This commit is contained in:
David Garske
2020-03-23 19:16:45 -07:00
committed by GitHub
4 changed files with 124 additions and 8 deletions

View File

@ -2285,9 +2285,10 @@ int ecc_projective_dbl_point(ecc_point *P, ecc_point *R, mp_int* a,
P [in/out] The point to map P [in/out] The point to map
modulus The modulus of the field the ECC curve is in modulus The modulus of the field the ECC curve is in
mp The "b" value from montgomery_setup() mp The "b" value from montgomery_setup()
ct Operation should be constant time.
return MP_OKAY on success return MP_OKAY on success
*/ */
int ecc_map(ecc_point* P, mp_int* modulus, mp_digit mp) int ecc_map_ex(ecc_point* P, mp_int* modulus, mp_digit mp, int ct)
{ {
#ifndef WOLFSSL_SP_MATH #ifndef WOLFSSL_SP_MATH
#ifdef WOLFSSL_SMALL_STACK #ifdef WOLFSSL_SMALL_STACK
@ -2307,6 +2308,8 @@ int ecc_map(ecc_point* P, mp_int* modulus, mp_digit mp)
mp_int *x, *y, *z; mp_int *x, *y, *z;
int err; int err;
(void)ct;
if (P == NULL || modulus == NULL) if (P == NULL || modulus == NULL)
return ECC_BAD_ARG_E; return ECC_BAD_ARG_E;
@ -2402,12 +2405,23 @@ int ecc_map(ecc_point* P, mp_int* modulus, mp_digit mp)
z = P->z; z = P->z;
#endif #endif
/* first map z back to normal */
err = mp_montgomery_reduce(z, modulus, mp);
/* get 1/z */ /* get 1/z */
if (err == MP_OKAY) if (err == MP_OKAY) {
err = mp_invmod(z, modulus, t1); #if defined(ECC_TIMING_RESISTANT) && defined(USE_FAST_MATH)
if (ct) {
err = mp_invmod_mont_ct(z, modulus, t1, mp);
if (err == MP_OKAY)
err = mp_montgomery_reduce(t1, modulus, mp);
}
else
#endif
{
/* first map z back to normal */
err = mp_montgomery_reduce(z, modulus, mp);
if (err == MP_OKAY)
err = mp_invmod(z, modulus, t1);
}
}
/* get 1/z^2 and 1/z^3 */ /* get 1/z^2 and 1/z^3 */
if (err == MP_OKAY) if (err == MP_OKAY)
@ -2484,6 +2498,10 @@ done:
#endif #endif
} }
int ecc_map(ecc_point* P, mp_int* modulus, mp_digit mp)
{
return ecc_map_ex(P, modulus, mp, 0);
}
#endif /* !WOLFSSL_SP_MATH || WOLFSSL_PUBLIC_ECC_ADD_DBL */ #endif /* !WOLFSSL_SP_MATH || WOLFSSL_PUBLIC_ECC_ADD_DBL */
#if !defined(FREESCALE_LTC_ECC) && !defined(WOLFSSL_STM32_PKA) #if !defined(FREESCALE_LTC_ECC) && !defined(WOLFSSL_STM32_PKA)
@ -3639,6 +3657,8 @@ static int wc_ecc_shared_secret_gen_sync(ecc_key* private_key, ecc_point* point,
} }
#else #else
{ {
mp_digit mp = 0;
/* make new point */ /* make new point */
result = wc_ecc_new_point_h(private_key->heap); result = wc_ecc_new_point_h(private_key->heap);
if (result == NULL) { if (result == NULL) {
@ -3649,8 +3669,16 @@ static int wc_ecc_shared_secret_gen_sync(ecc_key* private_key, ecc_point* point,
return MEMORY_E; return MEMORY_E;
} }
err = wc_ecc_mulmod_ex(k, point, result, curve->Af, curve->prime, 1, /* Map in a separate call as this should be constant time */
err = wc_ecc_mulmod_ex(k, point, result, curve->Af, curve->prime, 0,
private_key->heap); private_key->heap);
if (err == MP_OKAY) {
err = mp_montgomery_setup(curve->prime, &mp);
}
if (err == MP_OKAY) {
/* Use constant time map if compiled in */
err = ecc_map_ex(result, curve->prime, mp, 1);
}
if (err == MP_OKAY) { if (err == MP_OKAY) {
x = mp_unsigned_bin_size(curve->prime); x = mp_unsigned_bin_size(curve->prime);
if (*outlen < x || (int)x < mp_unsigned_bin_size(result->x)) { if (*outlen < x || (int)x < mp_unsigned_bin_size(result->x)) {
@ -4008,6 +4036,8 @@ static int wc_ecc_make_pub_ex(ecc_key* key, ecc_curve_spec* curveIn,
err = WC_KEY_SIZE_E; err = WC_KEY_SIZE_E;
#else #else
{ {
mp_digit mp;
if (err == MP_OKAY) { if (err == MP_OKAY) {
base = wc_ecc_new_point_h(key->heap); base = wc_ecc_new_point_h(key->heap);
if (base == NULL) if (base == NULL)
@ -4023,12 +4053,20 @@ static int wc_ecc_make_pub_ex(ecc_key* key, ecc_curve_spec* curveIn,
/* make the public key */ /* make the public key */
if (err == MP_OKAY) { if (err == MP_OKAY) {
/* Map in a separate call as this should be constant time */
err = wc_ecc_mulmod_ex(&key->k, base, pub, curve->Af, curve->prime, err = wc_ecc_mulmod_ex(&key->k, base, pub, curve->Af, curve->prime,
1, key->heap); 0, key->heap);
if (err == MP_MEM) { if (err == MP_MEM) {
err = MEMORY_E; err = MEMORY_E;
} }
} }
if (err == MP_OKAY) {
err = mp_montgomery_setup(curve->prime, &mp);
}
if (err == MP_OKAY) {
/* Use constant time map if compiled in */
err = ecc_map_ex(pub, curve->prime, mp, 1);
}
wc_ecc_del_point_h(base, key->heap); wc_ecc_del_point_h(base, key->heap);
} }

View File

@ -1160,6 +1160,75 @@ top:
return FP_OKAY; return FP_OKAY;
} }
#define CT_INV_MOD_PRE_CNT 8
/* modulus (b) must be greater than 2 and a prime */
int fp_invmod_mont_ct(fp_int *a, fp_int *b, fp_int *c, fp_digit mp)
{
int i, j;
#ifndef WOLFSSL_SMALL_STACK
fp_int t[1], e[1];
fp_int pre[CT_INV_MOD_PRE_CNT];
#else
fp_int* t;
fp_int* e;
fp_int* pre;
#endif
#ifdef WOLFSSL_SMALL_STACK
t = (fp_int*)XMALLOC(sizeof(fp_int) * (2 + CT_INV_MOD_PRE_CNT), NULL,
DYNAMIC_TYPE_BIGINT);
if (t == NULL)
return FP_MEM;
e = t + 1;
pre = t + 2;
#endif
fp_init(t);
fp_init(e);
fp_init(&pre[0]);
fp_copy(a, &pre[0]);
for (i = 1; i < CT_INV_MOD_PRE_CNT; i++) {
fp_init(&pre[i]);
fp_sqr(&pre[i-1], &pre[i]);
fp_montgomery_reduce(&pre[i], b, mp);
fp_mul(&pre[i], a, &pre[i]);
fp_montgomery_reduce(&pre[i], b, mp);
}
fp_sub_d(b, 2, e);
/* Highest bit is always set. */
for (i = fp_count_bits(e)-2, j = 1; i >= 0; i--, j++) {
if (!fp_is_bit_set(e, i) || j == CT_INV_MOD_PRE_CNT)
break;
}
fp_copy(&pre[j-1], t);
for (j = 0; i >= 0; i--) {
int set = fp_is_bit_set(e, i);
if ((j == CT_INV_MOD_PRE_CNT) || (!set && j > 0)) {
fp_mul(t, &pre[j-1], t);
fp_montgomery_reduce(t, b, mp);
j = 0;
}
fp_sqr(t, t);
fp_montgomery_reduce(t, b, mp);
j += set;
}
if (j > 0) {
fp_mul(t, &pre[j-1], c);
fp_montgomery_reduce(c, b, mp);
}
else
fp_copy(t, c);
#ifdef WOLFSSL_SMALL_STACK
XFREE(t, NULL, DYNAMIC_TYPE_BIGINT);
#endif
return FP_OKAY;
}
/* d = a * b (mod c) */ /* d = a * b (mod c) */
int fp_mulmod(fp_int *a, fp_int *b, fp_int *c, fp_int *d) int fp_mulmod(fp_int *a, fp_int *b, fp_int *c, fp_int *d)
{ {
@ -3545,6 +3614,12 @@ int mp_invmod (mp_int * a, mp_int * b, mp_int * c)
return fp_invmod(a, b, c); return fp_invmod(a, b, c);
} }
/* hac 14.61, pp608 */
int mp_invmod_mont_ct (mp_int * a, mp_int * b, mp_int * c, mp_digit mp)
{
return fp_invmod_mont_ct(a, b, c, mp);
}
/* this is a shell function that calls either the normal or Montgomery /* this is a shell function that calls either the normal or Montgomery
* exptmod functions. Originally the call to the montgomery code was * exptmod functions. Originally the call to the montgomery code was
* embedded in the normal function but that wasted a lot of stack space * embedded in the normal function but that wasted a lot of stack space

View File

@ -439,6 +439,7 @@ ECC_API int ecc_mul2add(ecc_point* A, mp_int* kA,
ecc_point* C, mp_int* a, mp_int* modulus, void* heap); ecc_point* C, mp_int* a, mp_int* modulus, void* heap);
ECC_API int ecc_map(ecc_point*, mp_int*, mp_digit); ECC_API int ecc_map(ecc_point*, mp_int*, mp_digit);
ECC_API int ecc_map_ex(ecc_point*, mp_int*, mp_digit, int ct);
ECC_API int ecc_projective_add_point(ecc_point* P, ecc_point* Q, ecc_point* R, ECC_API int ecc_projective_add_point(ecc_point* P, ecc_point* Q, ecc_point* R,
mp_int* a, mp_int* modulus, mp_digit mp); mp_int* a, mp_int* modulus, mp_digit mp);
ECC_API int ecc_projective_dbl_point(ecc_point* P, ecc_point* R, mp_int* a, ECC_API int ecc_projective_dbl_point(ecc_point* P, ecc_point* R, mp_int* a,

View File

@ -535,6 +535,7 @@ int fp_sqrmod(fp_int *a, fp_int *b, fp_int *c);
/* c = 1/a (mod b) */ /* c = 1/a (mod b) */
int fp_invmod(fp_int *a, fp_int *b, fp_int *c); int fp_invmod(fp_int *a, fp_int *b, fp_int *c);
int fp_invmod_mont_ct(fp_int *a, fp_int *b, fp_int *c, fp_digit mp);
/* c = (a, b) */ /* c = (a, b) */
/*int fp_gcd(fp_int *a, fp_int *b, fp_int *c);*/ /*int fp_gcd(fp_int *a, fp_int *b, fp_int *c);*/
@ -743,6 +744,7 @@ MP_API int mp_submod (mp_int* a, mp_int* b, mp_int* c, mp_int* d);
MP_API int mp_addmod (mp_int* a, mp_int* b, mp_int* c, mp_int* d); MP_API int mp_addmod (mp_int* a, mp_int* b, mp_int* c, mp_int* d);
MP_API int mp_mod(mp_int *a, mp_int *b, mp_int *c); MP_API int mp_mod(mp_int *a, mp_int *b, mp_int *c);
MP_API int mp_invmod(mp_int *a, mp_int *b, mp_int *c); MP_API int mp_invmod(mp_int *a, mp_int *b, mp_int *c);
MP_API int mp_invmod_mont_ct(mp_int *a, mp_int *b, mp_int *c, fp_digit mp);
MP_API int mp_exptmod (mp_int * g, mp_int * x, mp_int * p, mp_int * y); MP_API int mp_exptmod (mp_int * g, mp_int * x, mp_int * p, mp_int * y);
MP_API int mp_exptmod_ex (mp_int * g, mp_int * x, int minDigits, mp_int * p, MP_API int mp_exptmod_ex (mp_int * g, mp_int * x, int minDigits, mp_int * p,
mp_int * y); mp_int * y);