Performance improvement for fast math mp_clear to use fp_zero (memset) instead of fp_clear(ForceZero). Added new mp_forcezero function for clearing/free'ing sensitive private key data. Changed ECC and RSA to use mp_forcezero to explicitly handle private key clearing.

This commit is contained in:
David Garske
2016-05-12 11:50:34 -07:00
parent 440956f8d4
commit 2ad9d41641
7 changed files with 66 additions and 68 deletions

View File

@@ -56,8 +56,8 @@ void wc_InitDhKey(DhKey* key)
(void)key; (void)key;
/* TomsFastMath doesn't use memory allocation */ /* TomsFastMath doesn't use memory allocation */
#ifndef USE_FAST_MATH #ifndef USE_FAST_MATH
key->p.dp = 0; key->p.dp = NULL;
key->g.dp = 0; key->g.dp = NULL;
#endif #endif
} }
@@ -65,11 +65,8 @@ void wc_InitDhKey(DhKey* key)
void wc_FreeDhKey(DhKey* key) void wc_FreeDhKey(DhKey* key)
{ {
(void)key; (void)key;
/* TomsFastMath doesn't use memory allocation */
#ifndef USE_FAST_MATH
mp_clear(&key->p); mp_clear(&key->p);
mp_clear(&key->g); mp_clear(&key->g);
#endif
} }

View File

@@ -1572,7 +1572,7 @@ static int wc_ecc_make_key_ex(WC_RNG* rng, ecc_key* key, const ecc_set_type* dp)
mp_clear(key->pubkey.x); mp_clear(key->pubkey.x);
mp_clear(key->pubkey.y); mp_clear(key->pubkey.y);
mp_clear(key->pubkey.z); mp_clear(key->pubkey.z);
mp_clear(&key->k); mp_forcezero(&key->k);
} }
wc_ecc_del_point(base); wc_ecc_del_point(base);
if (po_init) { if (po_init) {
@@ -1803,7 +1803,7 @@ void wc_ecc_free(ecc_key* key)
mp_clear(key->pubkey.x); mp_clear(key->pubkey.x);
mp_clear(key->pubkey.y); mp_clear(key->pubkey.y);
mp_clear(key->pubkey.z); mp_clear(key->pubkey.z);
mp_clear(&key->k); mp_forcezero(&key->k);
} }

View File

@@ -34,6 +34,13 @@
/* in case user set USE_FAST_MATH there */ /* in case user set USE_FAST_MATH there */
#include <wolfssl/wolfcrypt/settings.h> #include <wolfssl/wolfcrypt/settings.h>
#ifdef NO_INLINE
#include <wolfssl/wolfcrypt/misc.h>
#else
#define WOLFSSL_MISC_INCLUDED
#include <wolfcrypt/src/misc.c>
#endif
#ifndef NO_BIG_INT #ifndef NO_BIG_INT
#ifndef USE_FAST_MATH #ifndef USE_FAST_MATH
@@ -157,8 +164,7 @@ int mp_init (mp_int * a)
/* clear one (frees) */ /* clear one (frees) */
void void mp_clear (mp_int * a)
mp_clear (mp_int * a)
{ {
int i; int i;
@@ -182,6 +188,29 @@ mp_clear (mp_int * a)
} }
} }
void mp_forcezero(mp_int * a)
{
if (a == NULL)
return;
/* only do anything if a hasn't been freed previously */
if (a->dp != NULL) {
/* force zero the used digits */
ForceZero(a->dp, a->used * sizeof(mp_digit));
/* free ram */
XFREE(a->dp, 0, DYNAMIC_TYPE_BIGINT);
/* reset members to make debugging easier */
a->dp = NULL;
a->alloc = a->used = 0;
a->sign = MP_ZPOS;
}
a->sign = MP_ZPOS;
a->used = 0;
}
/* get the size for an unsigned equivalent */ /* get the size for an unsigned equivalent */
int mp_unsigned_bin_size (mp_int * a) int mp_unsigned_bin_size (mp_int * a)
@@ -192,8 +221,7 @@ int mp_unsigned_bin_size (mp_int * a)
/* returns the number of bits in an int */ /* returns the number of bits in an int */
int int mp_count_bits (mp_int * a)
mp_count_bits (mp_int * a)
{ {
int r; int r;
mp_digit q; mp_digit q;
@@ -427,6 +455,9 @@ void mp_zero (mp_int * a)
int n; int n;
mp_digit *tmp; mp_digit *tmp;
if (a == NULL)
return;
a->sign = MP_ZPOS; a->sign = MP_ZPOS;
a->used = 0; a->used = 0;
@@ -444,8 +475,7 @@ void mp_zero (mp_int * a)
* Typically very fast. Also fixes the sign if there * Typically very fast. Also fixes the sign if there
* are no more leading digits * are no more leading digits
*/ */
void void mp_clamp (mp_int * a)
mp_clamp (mp_int * a)
{ {
/* decrease used while the most significant digit is /* decrease used while the most significant digit is
* zero. * zero.
@@ -464,8 +494,7 @@ mp_clamp (mp_int * a)
/* swap the elements of two integers, for cases where you can't simply swap the /* swap the elements of two integers, for cases where you can't simply swap the
* mp_int pointers around * mp_int pointers around
*/ */
void void mp_exch (mp_int * a, mp_int * b)
mp_exch (mp_int * a, mp_int * b)
{ {
mp_int t; mp_int t;
@@ -560,8 +589,7 @@ void mp_rshd (mp_int * a, int b)
/* calc a value mod 2**b */ /* calc a value mod 2**b */
int int mp_mod_2d (mp_int * a, int b, mp_int * c)
mp_mod_2d (mp_int * a, int b, mp_int * c)
{ {
int x, res; int x, res;
@@ -838,8 +866,7 @@ int mp_exptmod (mp_int * G, mp_int * X, mp_int * P, mp_int * Y)
* *
* Simple function copies the input and fixes the sign to positive * Simple function copies the input and fixes the sign to positive
*/ */
int int mp_abs (mp_int * a, mp_int * b)
mp_abs (mp_int * a, mp_int * b)
{ {
int res; int res;
@@ -1224,8 +1251,7 @@ int mp_cmp_mag (mp_int * a, mp_int * b)
/* compare two ints (signed)*/ /* compare two ints (signed)*/
int int mp_cmp (mp_int * a, mp_int * b)
mp_cmp (mp_int * a, mp_int * b)
{ {
/* compare based on sign */ /* compare based on sign */
if (a->sign != b->sign) { if (a->sign != b->sign) {
@@ -1288,8 +1314,7 @@ int mp_is_bit_set (mp_int *a, mp_digit b)
} }
/* c = a mod b, 0 <= c < b */ /* c = a mod b, 0 <= c < b */
int int mp_mod (mp_int * a, mp_int * b, mp_int * c)
mp_mod (mp_int * a, mp_int * b, mp_int * c)
{ {
mp_int t; mp_int t;
int res; int res;
@@ -1468,8 +1493,7 @@ int mp_add (mp_int * a, mp_int * b, mp_int * c)
/* low level addition, based on HAC pp.594, Algorithm 14.7 */ /* low level addition, based on HAC pp.594, Algorithm 14.7 */
int int s_mp_add (mp_int * a, mp_int * b, mp_int * c)
s_mp_add (mp_int * a, mp_int * b, mp_int * c)
{ {
mp_int *x; mp_int *x;
int olduse, res, min, max; int olduse, res, min, max;
@@ -1557,8 +1581,7 @@ s_mp_add (mp_int * a, mp_int * b, mp_int * c)
/* low level subtraction (assumes |a| > |b|), HAC pp.595 Algorithm 14.9 */ /* low level subtraction (assumes |a| > |b|), HAC pp.595 Algorithm 14.9 */
int int s_mp_sub (mp_int * a, mp_int * b, mp_int * c)
s_mp_sub (mp_int * a, mp_int * b, mp_int * c)
{ {
int olduse, res, min, max; int olduse, res, min, max;
@@ -1625,8 +1648,7 @@ s_mp_sub (mp_int * a, mp_int * b, mp_int * c)
/* high level subtraction (handles signs) */ /* high level subtraction (handles signs) */
int int mp_sub (mp_int * a, mp_int * b, mp_int * c)
mp_sub (mp_int * a, mp_int * b, mp_int * c)
{ {
int sa, sb, res; int sa, sb, res;
@@ -2068,8 +2090,7 @@ LBL_M:
/* setups the montgomery reduction stuff */ /* setups the montgomery reduction stuff */
int int mp_montgomery_setup (mp_int * n, mp_digit * rho)
mp_montgomery_setup (mp_int * n, mp_digit * rho)
{ {
mp_digit x, b; mp_digit x, b;
@@ -2274,8 +2295,7 @@ int fast_mp_montgomery_reduce (mp_int * x, mp_int * n, mp_digit rho)
/* computes xR**-1 == x (mod N) via Montgomery Reduction */ /* computes xR**-1 == x (mod N) via Montgomery Reduction */
int int mp_montgomery_reduce (mp_int * x, mp_int * n, mp_digit rho)
mp_montgomery_reduce (mp_int * x, mp_int * n, mp_digit rho)
{ {
int ix, res, digs; int ix, res, digs;
mp_digit mu; mp_digit mu;
@@ -2396,8 +2416,7 @@ void mp_dr_setup(mp_int *a, mp_digit *d)
* *
* Input x must be in the range 0 <= x <= (n-1)**2 * Input x must be in the range 0 <= x <= (n-1)**2
*/ */
int int mp_dr_reduce (mp_int * x, mp_int * n, mp_digit k)
mp_dr_reduce (mp_int * x, mp_int * n, mp_digit k)
{ {
int err, i, m; int err, i, m;
mp_word r; mp_word r;
@@ -2524,8 +2543,7 @@ int mp_reduce_2k_setup(mp_int *a, mp_digit *d)
/* set the b bit of a */ /* set the b bit of a */
int int mp_set_bit (mp_int * a, int b)
mp_set_bit (mp_int * a, int b)
{ {
int i = b / DIGIT_BIT, res; int i = b / DIGIT_BIT, res;
@@ -2549,8 +2567,7 @@ mp_set_bit (mp_int * a, int b)
* *
* Simple algorithm which zeros the int, set the required bit * Simple algorithm which zeros the int, set the required bit
*/ */
int int mp_2expt (mp_int * a, int b)
mp_2expt (mp_int * a, int b)
{ {
/* zero a as per default */ /* zero a as per default */
mp_zero (a); mp_zero (a);
@@ -2559,8 +2576,7 @@ mp_2expt (mp_int * a, int b)
} }
/* multiply by a digit */ /* multiply by a digit */
int int mp_mul_d (mp_int * a, mp_digit b, mp_int * c)
mp_mul_d (mp_int * a, mp_digit b, mp_int * c)
{ {
mp_digit u, *tmpa, *tmpc; mp_digit u, *tmpa, *tmpc;
mp_word r; mp_word r;
@@ -2638,8 +2654,7 @@ int mp_mulmod (mp_int * a, mp_int * b, mp_int * c, mp_int * d)
/* computes b = a*a */ /* computes b = a*a */
int int mp_sqr (mp_int * a, mp_int * b)
mp_sqr (mp_int * a, mp_int * b)
{ {
int res; int res;
@@ -2760,8 +2775,7 @@ int mp_mul_2(mp_int * a, mp_int * b)
/* divide by three (based on routine from MPI and the GMP manual) */ /* divide by three (based on routine from MPI and the GMP manual) */
int int mp_div_3 (mp_int * a, mp_int *c, mp_digit * d)
mp_div_3 (mp_int * a, mp_int *c, mp_digit * d)
{ {
mp_int q; mp_int q;
mp_word w, t; mp_word w, t;
@@ -3632,8 +3646,7 @@ ERR:
/* multiplies |a| * |b| and does not compute the lower digs digits /* multiplies |a| * |b| and does not compute the lower digs digits
* [meant to get the higher part of the product] * [meant to get the higher part of the product]
*/ */
int int s_mp_mul_high_digs (mp_int * a, mp_int * b, mp_int * c, int digs)
s_mp_mul_high_digs (mp_int * a, mp_int * b, mp_int * c, int digs)
{ {
mp_int t; mp_int t;
int res, pa, pb, ix, iy; int res, pa, pb, ix, iy;

View File

@@ -206,30 +206,16 @@ int wc_FreeRsaKey(RsaKey* key)
return FreeCaviumRsaKey(key); return FreeCaviumRsaKey(key);
#endif #endif
/* TomsFastMath doesn't use memory allocation */
#ifndef USE_FAST_MATH
if (key->type == RSA_PRIVATE) { if (key->type == RSA_PRIVATE) {
mp_clear(&key->u); mp_forcezero(&key->u);
mp_clear(&key->dQ); mp_forcezero(&key->dQ);
mp_clear(&key->dP); mp_forcezero(&key->dP);
mp_clear(&key->q); mp_forcezero(&key->q);
mp_clear(&key->p); mp_forcezero(&key->p);
mp_clear(&key->d); mp_forcezero(&key->d);
} }
mp_clear(&key->e); mp_clear(&key->e);
mp_clear(&key->n); mp_clear(&key->n);
#else
/* still clear private key memory information when free'd */
if (key->type == RSA_PRIVATE) {
mp_clear(&key->u);
mp_clear(&key->dQ);
mp_clear(&key->u);
mp_clear(&key->dP);
mp_clear(&key->q);
mp_clear(&key->p);
mp_clear(&key->d);
}
#endif
return 0; return 0;
} }

View File

@@ -2096,7 +2096,7 @@ void fp_clear(fp_int *a)
/* clear one (frees) */ /* clear one (frees) */
void mp_clear (mp_int * a) void mp_clear (mp_int * a)
{ {
fp_clear(a); fp_zero(a);
} }
/* handle up to 6 inits */ /* handle up to 6 inits */

View File

@@ -230,6 +230,7 @@ extern const char *mp_s_rmap;
/* 6 functions needed by Rsa */ /* 6 functions needed by Rsa */
int mp_init (mp_int * a); int mp_init (mp_int * a);
void mp_clear (mp_int * a); void mp_clear (mp_int * a);
void mp_forcezero(mp_int * a);
int mp_unsigned_bin_size(mp_int * a); int mp_unsigned_bin_size(mp_int * a);
int mp_read_unsigned_bin (mp_int * a, const unsigned char *b, int c); int mp_read_unsigned_bin (mp_int * a, const unsigned char *b, int c);
int mp_to_unsigned_bin (mp_int * a, unsigned char *b); int mp_to_unsigned_bin (mp_int * a, unsigned char *b);

View File

@@ -611,6 +611,7 @@ void fp_sqr_comba64(fp_int *a, fp_int *b);
#define mp_iseven(a) fp_iseven(a) #define mp_iseven(a) fp_iseven(a)
int mp_init (mp_int * a); int mp_init (mp_int * a);
void mp_clear (mp_int * a); void mp_clear (mp_int * a);
#define mp_forcezero(a) fp_clear(a)
int mp_init_multi(mp_int* a, mp_int* b, mp_int* c, mp_int* d, mp_int* e, mp_int* f); int mp_init_multi(mp_int* a, mp_int* b, mp_int* c, mp_int* d, mp_int* e, mp_int* f);
int mp_add (mp_int * a, mp_int * b, mp_int * c); int mp_add (mp_int * a, mp_int * b, mp_int * c);