diff --git a/wolfcrypt/src/rsa.c b/wolfcrypt/src/rsa.c index 0e9db2fd7..535f10c4d 100644 --- a/wolfcrypt/src/rsa.c +++ b/wolfcrypt/src/rsa.c @@ -2495,6 +2495,7 @@ static int RsaFunctionPrivate(mp_int* tmp, RsaKey* key, WC_RNG* rng) { int ret = 0; #if defined(WC_RSA_BLINDING) && !defined(WC_NO_RNG) + mp_digit mp; DECL_MP_INT_SIZE_DYN(rnd, mp_bitsused(&key->n), RSA_MAX_SIZE); DECL_MP_INT_SIZE_DYN(rndi, mp_bitsused(&key->n), RSA_MAX_SIZE); #endif /* WC_RSA_BLINDING && !WC_NO_RNG */ @@ -2627,9 +2628,31 @@ static int RsaFunctionPrivate(mp_int* tmp, RsaKey* key, WC_RNG* rng) #endif /* RSA_LOW_MEM */ #if defined(WC_RSA_BLINDING) && !defined(WC_NO_RNG) - /* unblind */ - if (ret == 0 && mp_mulmod(tmp, rndi, &key->n, tmp) != MP_OKAY) + /* Multiply result (tmp) by bliding invertor (rndi). + * Use Montogemery form to make operation more constant time. + */ + if ((ret == 0) && (mp_montgomery_setup(&key->n, &mp) != MP_OKAY)) { ret = MP_MULMOD_E; + } + if ((ret == 0) && (mp_montgomery_calc_normalization(rnd, &key->n) != + MP_OKAY)) { + ret = MP_MULMOD_E; + } + /* Convert blinding invert to Montogmery form. */ + if ((ret == 0) && (mp_mul(rndi, rnd, rndi) != MP_OKAY)) { + ret = MP_MULMOD_E; + } + if ((ret == 0) && (mp_mod(rndi, &key->n, rndi) != MP_OKAY)) { + ret = MP_MULMOD_E; + } + /* Multiply result by blinding invert. */ + if ((ret == 0) && (mp_mul(tmp, rndi, tmp) != MP_OKAY)) { + ret = MP_MULMOD_E; + } + /* Reduce result. */ + if ((ret == 0) && (mp_montgomery_reduce_ct(tmp, &key->n, mp) != MP_OKAY)) { + ret = MP_MULMOD_E; + } mp_forcezero(rndi); mp_forcezero(rnd); @@ -3520,8 +3543,9 @@ static int RsaPrivateDecryptEx(const byte* in, word32 inLen, byte* out, mgf, label, labelSz, saltLen, mp_count_bits(&key->n), key->heap); #endif - if (rsa_type == RSA_PUBLIC_DECRYPT && ret > (int)outLen) + if (rsa_type == RSA_PUBLIC_DECRYPT && ret > (int)outLen) { ret = RSA_BUFFER_E; + } else if (ret >= 0 && pad != NULL) { /* only copy output if not inline */ if (outPtr == NULL) { @@ -3547,8 +3571,9 @@ static int RsaPrivateDecryptEx(const byte* in, word32 inLen, byte* out, XMEMCPY(out, pad, (size_t)ret); } } - else + else { *outPtr = pad; + } #if !defined(WOLFSSL_RSA_VERIFY_ONLY) ret = ctMaskSelInt(ctMaskLTE(ret, (int)outLen), ret, RSA_BUFFER_E); diff --git a/wolfcrypt/src/sp_int.c b/wolfcrypt/src/sp_int.c index 06c01ab00..1af72f84e 100644 --- a/wolfcrypt/src/sp_int.c +++ b/wolfcrypt/src/sp_int.c @@ -4770,7 +4770,7 @@ WOLFSSL_LOCAL int sp_ModExp_4096(sp_int* base, sp_int* exp, sp_int* mod, #if defined(WOLFSSL_SP_MATH_ALL) || defined(WOLFSSL_HAVE_SP_DH) || \ defined(OPENSSL_ALL) -static int _sp_mont_red(sp_int* a, const sp_int* m, sp_int_digit mp); +static int _sp_mont_red(sp_int* a, const sp_int* m, sp_int_digit mp, int ct); #endif #if defined(WOLFSSL_SP_MATH_ALL) || defined(WOLFSSL_HAVE_SP_DH) || \ defined(WOLFCRYPT_HAVE_ECCSI) || defined(WOLFCRYPT_HAVE_SAKKE) || \ @@ -7673,6 +7673,28 @@ int sp_submod(const sp_int* a, const sp_int* b, const sp_int* m, sp_int* r) } #endif /* WOLFSSL_SP_MATH_ALL */ +#if (defined(WOLFSSL_SP_MATH_ALL) && defined(HAVE_ECC)) || \ + (defined(WOLFSSL_SP_MATH_ALL) || defined(WOLFSSL_HAVE_SP_DH) || \ + defined(WOLFCRYPT_HAVE_ECCSI) || defined(WOLFCRYPT_HAVE_SAKKE) || \ + defined(OPENSSL_ALL)) +/* Constant time clamping/ + * + * @param [in, out] a SP integer to clamp. + */ +static void sp_clamp_ct(sp_int* a) +{ + int i; + unsigned int used = a->used; + unsigned int mask = (unsigned int)-1; + + for (i = a->used-1; i >= 0; i--) { + used -= ((unsigned int)(a->dp[i] == 0)) & mask; + mask &= (unsigned int)0 - (a->dp[i] == 0); + } + a->used = used; +} +#endif + #if defined(WOLFSSL_SP_MATH_ALL) && defined(HAVE_ECC) /* Add two value and reduce: r = (a + b) % m * @@ -7826,7 +7848,7 @@ int sp_addmod_ct(const sp_int* a, const sp_int* b, const sp_int* m, sp_int* r) r->sign = MP_ZPOS; #endif /* WOLFSSL_SP_INT_NEGATIVE */ /* Remove leading zeros. */ - sp_clamp(r); + sp_clamp_ct(r); #if 0 sp_print(r, "rma"); @@ -7837,8 +7859,121 @@ int sp_addmod_ct(const sp_int* a, const sp_int* b, const sp_int* m, sp_int* r) } #endif /* WOLFSSL_SP_MATH_ALL && HAVE_ECC */ +#if (defined(WOLFSSL_SP_MATH_ALL) && defined(HAVE_ECC)) || \ + (defined(WOLFSSL_SP_MATH_ALL) || defined(WOLFSSL_HAVE_SP_DH) || \ + defined(WOLFCRYPT_HAVE_ECCSI) || defined(WOLFCRYPT_HAVE_SAKKE) || \ + defined(OPENSSL_ALL)) +/* Sub b from a modulo m: r = (a - b) % m + * + * Result is always positive. + * + * Assumes a, b, m and r are not NULL. + * m and r must not be the same pointer. + * + * @param [in] a SP integer to subtract from + * @param [in] b SP integer to subtract. + * @param [in] m SP integer that is the modulus. + * @param [out] r SP integer to hold result. + * + * @return MP_OKAY on success. + */ +static void _sp_submod_ct(const sp_int* a, const sp_int* b, const sp_int* m, + unsigned int max, sp_int* r) +{ +#ifndef SQR_MUL_ASM + sp_int_sword w; +#else + sp_int_digit l; + sp_int_digit h; + sp_int_digit t; +#endif + sp_int_digit mask; + sp_int_digit mask_a = (sp_int_digit)-1; + sp_int_digit mask_b = (sp_int_digit)-1; + unsigned int i; + + /* In constant time, subtract b from a putting result in r. */ +#ifndef SQR_MUL_ASM + w = 0; +#else + l = 0; + h = 0; +#endif + for (i = 0; i < max; i++) { + /* Values past 'used' are not initialized. */ + mask_a += (i == a->used); + mask_b += (i == b->used); + + #ifndef SQR_MUL_ASM + /* Add a to and subtract b from current value. */ + w += a->dp[i] & mask_a; + w -= b->dp[i] & mask_b; + /* Store low digit in result. */ + r->dp[i] = (sp_int_digit)w; + /* Move high digit down. */ + w >>= DIGIT_BIT; + #else + /* Add a and subtract b from current value. */ + t = a->dp[i] & mask_a; + SP_ASM_ADDC_REG(l, h, t); + t = b->dp[i] & mask_b; + SP_ASM_SUBB_REG(l, h, t); + /* Store low digit in result. */ + r->dp[i] = l; + /* Move high digit down. */ + l = h; + /* High digit is 0 when positive or -1 on negative. */ + h = (sp_int_digit)0 - (l >> (SP_WORD_SIZE - 1)); + #endif + } + /* When w is negative then we need to add modulus to make result + * positive. */ +#ifndef SQR_MUL_ASM + mask = (sp_int_digit)0 - (w < 0); +#else + mask = h; +#endif + + /* Constant time, conditionally, add modulus to difference. */ +#ifndef SQR_MUL_ASM + w = 0; +#else + l = 0; +#endif + for (i = 0; i < m->used; i++) { + #ifndef SQR_MUL_ASM + /* Add result and conditionally modulus to current value. */ + w += r->dp[i]; + w += m->dp[i] & mask; + /* Store low digit in result. */ + r->dp[i] = (sp_int_digit)w; + /* Move high digit down. */ + w >>= DIGIT_BIT; + #else + h = 0; + /* Add result and conditionally modulus to current value. */ + SP_ASM_ADDC(l, h, r->dp[i]); + t = m->dp[i] & mask; + SP_ASM_ADDC_REG(l, h, t); + /* Store low digit in result. */ + r->dp[i] = l; + /* Move high digit down. */ + l = h; + #endif + } + /* Result will always have digits equal to or less than those in + * modulus. */ + r->used = i; +#ifdef WOLFSSL_SP_INT_NEGATIVE + r->sign = MP_ZPOS; +#endif /* WOLFSSL_SP_INT_NEGATIVE */ + /* Remove leading zeros. */ + sp_clamp_ct(r); +} +#endif + #if defined(WOLFSSL_SP_MATH_ALL) && defined(HAVE_ECC) -/* Sub b from a and reduce: r = (a - b) % m +/* Sub b from a modulo m: r = (a - b) % m * Result is always positive. * * r = a - b (mod m) - constant time (a < m and b < m, a, b and m are positive) @@ -7856,17 +7991,6 @@ int sp_addmod_ct(const sp_int* a, const sp_int* b, const sp_int* m, sp_int* r) int sp_submod_ct(const sp_int* a, const sp_int* b, const sp_int* m, sp_int* r) { int err = MP_OKAY; -#ifndef SQR_MUL_ASM - sp_int_sword w; -#else - sp_int_digit l; - sp_int_digit h; - sp_int_digit t; -#endif - sp_int_digit mask; - sp_int_digit mask_a = (sp_int_digit)-1; - sp_int_digit mask_b = (sp_int_digit)-1; - unsigned int i; /* Check result is as big as modulus plus one digit. */ if (m->used > r->size) { @@ -7884,82 +8008,7 @@ int sp_submod_ct(const sp_int* a, const sp_int* b, const sp_int* m, sp_int* r) sp_print(m, "m"); #endif - /* In constant time, subtract b from a putting result in r. */ - #ifndef SQR_MUL_ASM - w = 0; - #else - l = 0; - h = 0; - #endif - for (i = 0; i < m->used; i++) { - /* Values past 'used' are not initialized. */ - mask_a += (i == a->used); - mask_b += (i == b->used); - - #ifndef SQR_MUL_ASM - /* Add a to and subtract b from current value. */ - w += a->dp[i] & mask_a; - w -= b->dp[i] & mask_b; - /* Store low digit in result. */ - r->dp[i] = (sp_int_digit)w; - /* Move high digit down. */ - w >>= DIGIT_BIT; - #else - /* Add a and subtract b from current value. */ - t = a->dp[i] & mask_a; - SP_ASM_ADDC_REG(l, h, t); - t = b->dp[i] & mask_b; - SP_ASM_SUBB_REG(l, h, t); - /* Store low digit in result. */ - r->dp[i] = l; - /* Move high digit down. */ - l = h; - /* High digit is 0 when positive or -1 on negative. */ - h = (sp_int_digit)0 - (l >> (SP_WORD_SIZE - 1)); - #endif - } - /* When w is negative then we need to add modulus to make result - * positive. */ - #ifndef SQR_MUL_ASM - mask = (sp_int_digit)0 - (w < 0); - #else - mask = h; - #endif - /* Constant time, conditionally, add modulus to difference. */ - #ifndef SQR_MUL_ASM - w = 0; - #else - l = 0; - #endif - for (i = 0; i < m->used; i++) { - #ifndef SQR_MUL_ASM - /* Add result and conditionally modulus to current value. */ - w += r->dp[i]; - w += m->dp[i] & mask; - /* Store low digit in result. */ - r->dp[i] = (sp_int_digit)w; - /* Move high digit down. */ - w >>= DIGIT_BIT; - #else - h = 0; - /* Add result and conditionally modulus to current value. */ - SP_ASM_ADDC(l, h, r->dp[i]); - t = m->dp[i] & mask; - SP_ASM_ADDC_REG(l, h, t); - /* Store low digit in result. */ - r->dp[i] = l; - /* Move high digit down. */ - l = h; - #endif - } - /* Result will always have digits equal to or less than those in - * modulus. */ - r->used = i; - #ifdef WOLFSSL_SP_INT_NEGATIVE - r->sign = MP_ZPOS; - #endif /* WOLFSSL_SP_INT_NEGATIVE */ - /* Remove leading zeros. */ - sp_clamp(r); + _sp_submod_ct(a, b, m, m->used, r); #if 0 sp_print(r, "rms"); @@ -12377,14 +12426,14 @@ static int _sp_invmod_mont_ct(const sp_int* a, const sp_int* m, sp_int* r, _sp_init_size(pre[i], m->used * 2 + 1); err = sp_sqr(pre[i-1], pre[i]); if (err == MP_OKAY) { - err = _sp_mont_red(pre[i], m, mp); + err = _sp_mont_red(pre[i], m, mp, 0); } /* ..10 -> ..11 */ if (err == MP_OKAY) { err = sp_mul(pre[i], a, pre[i]); } if (err == MP_OKAY) { - err = _sp_mont_red(pre[i], m, mp); + err = _sp_mont_red(pre[i], m, mp, 0); } } } @@ -12438,7 +12487,7 @@ static int _sp_invmod_mont_ct(const sp_int* a, const sp_int* m, sp_int* r, /* 6.4.2.1. t = (t ^ 2) mod m */ err = sp_sqr(t, t); if (err == MP_OKAY) { - err = _sp_mont_red(t, m, mp); + err = _sp_mont_red(t, m, mp, 0); } } /* 6.4.3. s = 1 - bit */ @@ -12449,7 +12498,7 @@ static int _sp_invmod_mont_ct(const sp_int* a, const sp_int* m, sp_int* r, err = sp_mul(t, pre[j-1], t); } if (err == MP_OKAY) { - err = _sp_mont_red(t, m, mp); + err = _sp_mont_red(t, m, mp, 0); } /* 6.4.5. j = 0 * Reset number of 1 bits seen. @@ -12465,7 +12514,7 @@ static int _sp_invmod_mont_ct(const sp_int* a, const sp_int* m, sp_int* r, /* 7.1. t = (t ^ 2) mod m */ err = sp_sqr(t, t); if (err == MP_OKAY) { - err = _sp_mont_red(t, m, mp); + err = _sp_mont_red(t, m, mp, 0); } } } @@ -12474,7 +12523,7 @@ static int _sp_invmod_mont_ct(const sp_int* a, const sp_int* m, sp_int* r, if (j > 0) { err = sp_mul(t, pre[j-1], r); if (err == MP_OKAY) { - err = _sp_mont_red(r, m, mp); + err = _sp_mont_red(r, m, mp, 0); } } /* 9. Else r = t */ @@ -12887,7 +12936,7 @@ static int _sp_exptmod_mont_ex(const sp_int* b, const sp_int* e, int bits, t[3]); err = sp_sqr(t[3], t[3]); if (err == MP_OKAY) { - err = _sp_mont_red(t[3], m, mp); + err = _sp_mont_red(t[3], m, mp, 0); } _sp_copy(t[3], (sp_int*)(((size_t)t[0] & sp_off_on_addr[s^1]) + @@ -12907,7 +12956,7 @@ static int _sp_exptmod_mont_ex(const sp_int* b, const sp_int* e, int bits, t[3]); err = sp_mul(t[3], t[2], t[3]); if (err == MP_OKAY) { - err = _sp_mont_red(t[3], m, mp); + err = _sp_mont_red(t[3], m, mp, 0); } _sp_copy(t[3], (sp_int*)(((size_t)t[0] & sp_off_on_addr[j^1]) + @@ -12916,7 +12965,7 @@ static int _sp_exptmod_mont_ex(const sp_int* b, const sp_int* e, int bits, } if (err == MP_OKAY) { /* 7. t[1] = FromMont(t[1]) */ - err = _sp_mont_red(t[1], m, mp); + err = _sp_mont_red(t[1], m, mp, 0); /* Reduction implementation returns number to range: 0..m-1. */ } } @@ -13017,7 +13066,7 @@ static int _sp_exptmod_mont_ex(const sp_int* b, const sp_int* e, int bits, /* 4.2. t[2] = t[0] * t[1] */ err = sp_mul(t[0], t[1], t[2]); if (err == MP_OKAY) { - err = _sp_mont_red(t[2], m, mp); + err = _sp_mont_red(t[2], m, mp, 0); } /* 4.3. t[3] = t[y] ^ 2 */ if (err == MP_OKAY) { @@ -13027,7 +13076,7 @@ static int _sp_exptmod_mont_ex(const sp_int* b, const sp_int* e, int bits, err = sp_sqr(t[3], t[3]); } if (err == MP_OKAY) { - err = _sp_mont_red(t[3], m, mp); + err = _sp_mont_red(t[3], m, mp, 0); } /* 4.4. t[y] = t[3], t[y^1] = t[2] */ if (err == MP_OKAY) { @@ -13037,7 +13086,7 @@ static int _sp_exptmod_mont_ex(const sp_int* b, const sp_int* e, int bits, if (err == MP_OKAY) { /* 5. t[0] = FromMont(t[0]) */ - err = _sp_mont_red(t[0], m, mp); + err = _sp_mont_red(t[0], m, mp, 0); /* Reduction implementation returns number to range: 0..m-1. */ } } @@ -13189,7 +13238,7 @@ static int _sp_exptmod_mont_ex(const sp_int* b, const sp_int* e, int bits, } /* Montgomery reduce square or multiplication result. */ if (err == MP_OKAY) { - err = _sp_mont_red(t[i], m, mp); + err = _sp_mont_red(t[i], m, mp, 0); } } @@ -13250,7 +13299,7 @@ static int _sp_exptmod_mont_ex(const sp_int* b, const sp_int* e, int bits, for (j = 0; (j < winBits) && (err == MP_OKAY); j++) { err = sp_sqr(tr, tr); if (err == MP_OKAY) { - err = _sp_mont_red(tr, m, mp); + err = _sp_mont_red(tr, m, mp, 0); } } @@ -13259,14 +13308,14 @@ static int _sp_exptmod_mont_ex(const sp_int* b, const sp_int* e, int bits, err = sp_mul(tr, t[y], tr); } if (err == MP_OKAY) { - err = _sp_mont_red(tr, m, mp); + err = _sp_mont_red(tr, m, mp, 0); } } } if (err == MP_OKAY) { /* 7. tr = FromMont(tr) */ - err = _sp_mont_red(tr, m, mp); + err = _sp_mont_red(tr, m, mp, 0); /* Reduction implementation returns number to range: 0..m-1. */ } } @@ -13475,7 +13524,7 @@ static int _sp_exptmod_base_2(const sp_int* e, int digits, const sp_int* m, err = sp_sqr(tr, tr); if (err == MP_OKAY) { if (useMont) { - err = _sp_mont_red(tr, m, mp); + err = _sp_mont_red(tr, m, mp, 0); } else { err = sp_mod(tr, m, tr); @@ -13501,7 +13550,7 @@ static int _sp_exptmod_base_2(const sp_int* e, int digits, const sp_int* m, /* 7. if Words(m) > 1 then tr = FromMont(tr) */ if ((err == MP_OKAY) && useMont) { - err = _sp_mont_red(tr, m, mp); + err = _sp_mont_red(tr, m, mp, 0); /* Reduction implementation returns number to range: 0..m-1. */ } if (err == MP_OKAY) { @@ -13880,7 +13929,7 @@ static int _sp_exptmod_nct(const sp_int* b, const sp_int* e, const sp_int* m, for (i = 1; (i < winBits) && (err == MP_OKAY); i++) { err = sp_sqr(t[0], t[0]); if (err == MP_OKAY) { - err = _sp_mont_red(t[0], m, mp); + err = _sp_mont_red(t[0], m, mp, 0); } } /* For each table entry after first. */ @@ -13888,7 +13937,7 @@ static int _sp_exptmod_nct(const sp_int* b, const sp_int* e, const sp_int* m, /* Multiply previous entry by the base in Mont form into table. */ err = sp_mul(t[i-1], bm, t[i]); if (err == MP_OKAY) { - err = _sp_mont_red(t[i], m, mp); + err = _sp_mont_red(t[i], m, mp, 0); } } @@ -13972,7 +14021,7 @@ static int _sp_exptmod_nct(const sp_int* b, const sp_int* e, const sp_int* m, for (; (err == MP_OKAY) && (sqrs > 0); sqrs--) { err = sp_sqr(tr, tr); if (err == MP_OKAY) { - err = _sp_mont_red(tr, m, mp); + err = _sp_mont_red(tr, m, mp, 0); } } @@ -14013,7 +14062,7 @@ static int _sp_exptmod_nct(const sp_int* b, const sp_int* e, const sp_int* m, err = sp_mul(tr, t[y], tr); } if (err == MP_OKAY) { - err = _sp_mont_red(tr, m, mp); + err = _sp_mont_red(tr, m, mp, 0); } } @@ -14027,7 +14076,7 @@ static int _sp_exptmod_nct(const sp_int* b, const sp_int* e, const sp_int* m, /* 5.1. Montogmery square result */ err = sp_sqr(tr, tr); if (err == MP_OKAY) { - err = _sp_mont_red(tr, m, mp); + err = _sp_mont_red(tr, m, mp, 0); } /* 5.2. If exponent bit set */ if ((err == MP_OKAY) && ((n >> c) & 1)) { @@ -14036,7 +14085,7 @@ static int _sp_exptmod_nct(const sp_int* b, const sp_int* e, const sp_int* m, */ err = sp_mul(tr, bm, tr); if (err == MP_OKAY) { - err = _sp_mont_red(tr, m, mp); + err = _sp_mont_red(tr, m, mp, 0); } } } @@ -14045,7 +14094,7 @@ static int _sp_exptmod_nct(const sp_int* b, const sp_int* e, const sp_int* m, if (err == MP_OKAY) { /* 6. Convert result back from Montgomery form. */ - err = _sp_mont_red(tr, m, mp); + err = _sp_mont_red(tr, m, mp, 0); /* Reduction implementation returns number to range: 0..m-1. */ } } @@ -14141,7 +14190,7 @@ static int _sp_exptmod_nct(const sp_int* b, const sp_int* e, const sp_int* m, /* 3.1. Montgomery square result. */ err = sp_sqr(t[0], t[0]); if (err == MP_OKAY) { - err = _sp_mont_red(t[0], m, mp); + err = _sp_mont_red(t[0], m, mp, 0); } if (err == MP_OKAY) { /* Get bit and index i. */ @@ -14151,14 +14200,14 @@ static int _sp_exptmod_nct(const sp_int* b, const sp_int* e, const sp_int* m, /* 3.2.1. Montgomery multiply result by Mont of base. */ err = sp_mul(t[0], t[1], t[0]); if (err == MP_OKAY) { - err = _sp_mont_red(t[0], m, mp); + err = _sp_mont_red(t[0], m, mp, 0); } } } } if (err == MP_OKAY) { /* 4. Convert from Montgomery form. */ - err = _sp_mont_red(t[0], m, mp); + err = _sp_mont_red(t[0], m, mp, 0); /* Reduction implementation returns number of range 0..m-1. */ } } @@ -16995,10 +17044,11 @@ int sp_sqrmod(const sp_int* a, const sp_int* m, sp_int* r) * @param [in,out] a SP integer to Montgomery reduce. * @param [in] m SP integer that is the modulus. * @param [in] mp SP integer digit that is the bottom digit of inv(-m). + * @param [in] ct Indicates operation must be constant time. * * @return MP_OKAY on success. */ -static int _sp_mont_red(sp_int* a, const sp_int* m, sp_int_digit mp) +static int _sp_mont_red(sp_int* a, const sp_int* m, sp_int_digit mp, int ct) { #if !defined(SQR_MUL_ASM) unsigned int i; @@ -17015,8 +17065,15 @@ static int _sp_mont_red(sp_int* a, const sp_int* m, sp_int_digit mp) bits = sp_count_bits(m); /* Adding numbers into m->used * 2 digits - zero out unused digits. */ - for (i = a->used; i < m->used * 2; i++) { - a->dp[i] = 0; + if (!ct) { + for (i = a->used; i < m->used * 2; i++) { + a->dp[i] = 0; + } + } + else { + for (i = 0; i < m->used * 2; i++) { + a->dp[i] &= (sp_int_digit)(sp_int_sdigit)ctMaskIntGTE(a->used-1, i); + } } /* Special case when modulus is 1 digit or less. */ @@ -17087,15 +17144,28 @@ static int _sp_mont_red(sp_int* a, const sp_int* m, sp_int_digit mp) a->used = m->used * 2 + 1; } - /* Remove leading zeros. */ - sp_clamp(a); - /* 3. a >>= NumBits(m) */ - (void)sp_rshb(a, bits, a); - - /* 4. a = a mod m */ - if (_sp_cmp_abs(a, m) != MP_LT) { - _sp_sub_off(a, m, a, 0); + if (!ct) { + /* Remove leading zeros. */ + sp_clamp(a); + /* 3. a >>= NumBits(m) */ + (void)sp_rshb(a, bits, a); + /* 4. a = a mod m */ + if (_sp_cmp_abs(a, m) != MP_LT) { + _sp_sub_off(a, m, a, 0); + } } + else { + /* 3. a >>= NumBits(m) */ + (void)sp_rshb(a, bits, a); + /* Constant time clamping. */ + sp_clamp_ct(a); + + /* 4. a = a mod m + * Always subtract but at a too high offset if a is less than m. + */ + _sp_submod_ct(a, m, m, m->used + 1, a); + } + #if 0 sp_print(a, "rr"); @@ -17118,8 +17188,15 @@ static int _sp_mont_red(sp_int* a, const sp_int* m, sp_int_digit mp) bits = sp_count_bits(m); mask = ((sp_int_digit)1 << (bits & (SP_WORD_SIZE - 1))) - 1; - for (i = a->used; i < m->used * 2; i++) { - a->dp[i] = 0; + if (!ct) { + for (i = a->used; i < m->used * 2; i++) { + a->dp[i] = 0; + } + } + else { + for (i = 0; i < m->used * 2; i++) { + a->dp[i] &= (sp_int_digit)(sp_int_sdigit)ctMaskIntGTE(a->used-1, i); + } } if (m->used <= 1) { @@ -17398,13 +17475,21 @@ static int _sp_mont_red(sp_int* a, const sp_int* m, sp_int_digit mp) a->used = m->used * 2 + 1; } - /* Remove leading zeros. */ - sp_clamp(a); - (void)sp_rshb(a, bits, a); + if (!ct) { + /* Remove leading zeros. */ + sp_clamp(a); + (void)sp_rshb(a, bits, a); + /* a = a mod m */ + if (_sp_cmp_abs(a, m) != MP_LT) { + _sp_sub_off(a, m, a, 0); + } + } + else { + (void)sp_rshb(a, bits, a); + /* Constant time clamping. */ + sp_clamp_ct(a); - /* a = a mod m */ - if (_sp_cmp_abs(a, m) != MP_LT) { - _sp_sub_off(a, m, a, 0); + _sp_submod_ct(a, m, m, m->used + 1, a); } #if 0 @@ -17422,11 +17507,12 @@ static int _sp_mont_red(sp_int* a, const sp_int* m, sp_int_digit mp) * @param [in,out] a SP integer to Montgomery reduce. * @param [in] m SP integer that is the modulus. * @param [in] mp SP integer digit that is the bottom digit of inv(-m). + * @param [in] ct Indicates operation must be constant time. * * @return MP_OKAY on success. * @return MP_VAL when a or m is NULL or m is zero. */ -int sp_mont_red(sp_int* a, const sp_int* m, sp_int_digit mp) +int sp_mont_red_ex(sp_int* a, const sp_int* m, sp_int_digit mp, int ct) { int err; @@ -17440,7 +17526,7 @@ int sp_mont_red(sp_int* a, const sp_int* m, sp_int_digit mp) } else { /* Perform Montogomery Reduction. */ - err = _sp_mont_red(a, m, mp); + err = _sp_mont_red(a, m, mp, ct); } return err; diff --git a/wolfcrypt/src/tfm.c b/wolfcrypt/src/tfm.c index c2937d2b1..6e3d2fe8b 100644 --- a/wolfcrypt/src/tfm.c +++ b/wolfcrypt/src/tfm.c @@ -6049,15 +6049,8 @@ int mp_read_radix(mp_int *a, const char *str, int radix) #endif /* !defined(NO_DSA) || defined(HAVE_ECC) */ -#ifdef HAVE_ECC +#if defined(HAVE_ECC) || (!defined(NO_RSA) && defined(WC_RSA_BLINDING)) -/* fast math conversion */ -int mp_sqr(fp_int *A, fp_int *B) -{ - return fp_sqr(A, B); -} - -/* fast math conversion */ int mp_montgomery_reduce(fp_int *a, fp_int *m, fp_digit mp) { return fp_montgomery_reduce(a, m, mp); @@ -6075,6 +6068,17 @@ int mp_montgomery_setup(fp_int *a, fp_digit *rho) return fp_montgomery_setup(a, rho); } +#endif /* HAVE_ECC || (!NO_RSA && WC_RSA_BLINDING) */ + +#ifdef HAVE_ECC + +/* fast math conversion */ +int mp_sqr(fp_int *A, fp_int *B) +{ + return fp_sqr(A, B); +} + +/* fast math conversion */ int mp_div_2(fp_int * a, fp_int * b) { fp_div_2(a, b); diff --git a/wolfssl/wolfcrypt/integer.h b/wolfssl/wolfcrypt/integer.h index a4e742d77..75dc61438 100644 --- a/wolfssl/wolfcrypt/integer.h +++ b/wolfssl/wolfcrypt/integer.h @@ -366,6 +366,7 @@ MP_API int mp_montgomery_setup (mp_int * n, mp_digit * rho); int fast_mp_montgomery_reduce (mp_int * x, mp_int * n, mp_digit rho); MP_API int mp_montgomery_reduce (mp_int * x, mp_int * n, mp_digit rho); #define mp_montgomery_reduce_ex(x, n, rho, ct) mp_montgomery_reduce (x, n, rho) +#define mp_montgomery_reduce_ct(x, n, rho) mp_montgomery_reduce (x, n, rho) MP_API void mp_dr_setup(mp_int *a, mp_digit *d); MP_API int mp_dr_reduce (mp_int * x, mp_int * n, mp_digit k); MP_API int mp_reduce_2k(mp_int *a, mp_int *n, mp_digit d); diff --git a/wolfssl/wolfcrypt/sp_int.h b/wolfssl/wolfcrypt/sp_int.h index 833a6cab6..f0b478945 100644 --- a/wolfssl/wolfcrypt/sp_int.h +++ b/wolfssl/wolfcrypt/sp_int.h @@ -1037,7 +1037,8 @@ MP_API int sp_mul_2d(const sp_int* a, int e, sp_int* r); MP_API int sp_sqr(const sp_int* a, sp_int* r); MP_API int sp_sqrmod(const sp_int* a, const sp_int* m, sp_int* r); -MP_API int sp_mont_red(sp_int* a, const sp_int* m, sp_int_digit mp); +MP_API int sp_mont_red_ex(sp_int* a, const sp_int* m, sp_int_digit mp, int ct); +#define sp_mont_red(a, m, mp) sp_mont_red_ex(a, m, mp, 0) MP_API int sp_mont_setup(const sp_int* m, sp_int_digit* rho); MP_API int sp_mont_norm(sp_int* norm, const sp_int* m); @@ -1085,7 +1086,8 @@ WOLFSSL_LOCAL void sp_memzero_check(sp_int* sp); #define mp_div_3(a, r, rem) sp_div_d(a, 3, r, rem) #define mp_rshb(A,x) sp_rshb(A,x,A) #define mp_is_bit_set(a,b) sp_is_bit_set(a,(unsigned int)(b)) -#define mp_montgomery_reduce sp_mont_red +#define mp_montgomery_reduce(a, m, mp) sp_mont_red_ex(a, m, mp, 0) +#define mp_montgomery_reduce_ct(a, m, mp) sp_mont_red_ex(a, m, mp, 1) #define mp_montgomery_setup sp_mont_setup #define mp_montgomery_calc_normalization sp_mont_norm diff --git a/wolfssl/wolfcrypt/tfm.h b/wolfssl/wolfcrypt/tfm.h index 7e850d3cb..8099a1700 100644 --- a/wolfssl/wolfcrypt/tfm.h +++ b/wolfssl/wolfcrypt/tfm.h @@ -871,12 +871,13 @@ MP_API int mp_radix_size (mp_int * a, int radix, int *size); MP_API int mp_read_radix(mp_int* a, const char* str, int radix); #endif +#define mp_montgomery_reduce_ct(a, m, mp) \ + mp_montgomery_reduce_ex(a, m, mp, 1) +MP_API int mp_montgomery_reduce(fp_int *a, fp_int *m, fp_digit mp); +MP_API int mp_montgomery_reduce_ex(fp_int *a, fp_int *m, fp_digit mp, int ct); +MP_API int mp_montgomery_setup(fp_int *a, fp_digit *rho); #ifdef HAVE_ECC MP_API int mp_sqr(fp_int *a, fp_int *b); - MP_API int mp_montgomery_reduce(fp_int *a, fp_int *m, fp_digit mp); - MP_API int mp_montgomery_reduce_ex(fp_int *a, fp_int *m, fp_digit mp, - int ct); - MP_API int mp_montgomery_setup(fp_int *a, fp_digit *rho); MP_API int mp_div_2(fp_int * a, fp_int * b); MP_API int mp_div_2_mod_ct(mp_int *a, mp_int *b, mp_int *c); #endif