Merge pull request #4943 from SparkiDev/sp_arm64_perf_1

SP ASM performance improvements
This commit is contained in:
David Garske
2022-03-14 18:40:51 -07:00
committed by GitHub
11 changed files with 39443 additions and 34794 deletions
+3 -3
View File
@@ -5964,7 +5964,7 @@ do
ENABLED_SP_FF_3072=yes ENABLED_SP_FF_3072=yes
ENABLED_SP_ECC=yes ENABLED_SP_ECC=yes
ENABLED_SP_EC_256=yes ENABLED_SP_EC_256=yes
if test "$host_cpu" = "x86_64"; then if test "$host_cpu" = "x86_64" || test "$host_cpu" = "aarch64"; then
ENABLED_SP_FF_4096=yes ENABLED_SP_FF_4096=yes
ENABLED_SP_EC_384=yes ENABLED_SP_EC_384=yes
ENABLED_SP_EC_521=yes ENABLED_SP_EC_521=yes
@@ -5979,7 +5979,7 @@ do
ENABLED_SP_FF_3072=yes ENABLED_SP_FF_3072=yes
ENABLED_SP_ECC=yes ENABLED_SP_ECC=yes
ENABLED_SP_EC_256=yes ENABLED_SP_EC_256=yes
if test "$host_cpu" = "x86_64"; then if test "$host_cpu" = "x86_64" || test "$host_cpu" = "aarch64"; then
ENABLED_SP_FF_4096=yes ENABLED_SP_FF_4096=yes
ENABLED_SP_EC_384=yes ENABLED_SP_EC_384=yes
ENABLED_SP_EC_521=yes ENABLED_SP_EC_521=yes
@@ -5994,7 +5994,7 @@ do
ENABLED_SP_FF_3072=yes ENABLED_SP_FF_3072=yes
ENABLED_SP_ECC=yes ENABLED_SP_ECC=yes
ENABLED_SP_EC_256=yes ENABLED_SP_EC_256=yes
if test "$host_cpu" = "x86_64"; then if test "$host_cpu" = "x86_64" || test "$host_cpu" = "aarch64"; then
ENABLED_SP_FF_4096=yes ENABLED_SP_FF_4096=yes
ENABLED_SP_EC_384=yes ENABLED_SP_EC_384=yes
ENABLED_SP_EC_521=yes ENABLED_SP_EC_521=yes
+1
View File
@@ -6172,6 +6172,7 @@ int wc_ecc_sign_hash_ex(const byte* in, word32 inlen, WC_RNG* rng,
#endif #endif
} }
#endif #endif
(void)sign_k;
} }
#else #else
(void)inlen; (void)inlen;
+2567 -1713
View File
File diff suppressed because it is too large Load Diff
+11186 -9086
View File
File diff suppressed because it is too large Load Diff
+8191 -7534
View File
File diff suppressed because it is too large Load Diff
+825 -755
View File
File diff suppressed because it is too large Load Diff
+504 -424
View File
File diff suppressed because it is too large Load Diff
+1515 -769
View File
File diff suppressed because it is too large Load Diff
+214 -73
View File
@@ -48,19 +48,19 @@
#include <wolfssl/wolfcrypt/sp.h> #include <wolfssl/wolfcrypt/sp.h>
#ifdef WOLFSSL_SP_X86_64_ASM #ifdef WOLFSSL_SP_X86_64_ASM
#define SP_PRINT_NUM(var, name, total, words, bits) \ #define SP_PRINT_NUM(var, name, total, words, bits) \
do { \ do { \
int ii; \ int ii; \
fprintf(stderr, name "=0x"); \ fprintf(stderr, name "=0x"); \
for (ii = words - 1; ii >= 0; ii--) \ for (ii = ((bits + 63) / 64) - 1; ii >= 0; ii--) \
fprintf(stderr, SP_PRINT_FMT, (var)[ii]); \ fprintf(stderr, SP_PRINT_FMT, (var)[ii]); \
fprintf(stderr, "\n"); \ fprintf(stderr, "\n"); \
} while (0) } while (0)
#define SP_PRINT_VAL(var, name) \ #define SP_PRINT_VAL(var, name) \
fprintf(stderr, name "=0x" SP_PRINT_FMT "\n", var) fprintf(stderr, name "=0x" SP_PRINT_FMT "\n", var)
#define SP_PRINT_INT(var, name) \ #define SP_PRINT_INT(var, name) \
fprintf(stderr, name "=%d\n", var) fprintf(stderr, name "=%d\n", var)
#if defined(WOLFSSL_HAVE_SP_RSA) || defined(WOLFSSL_HAVE_SP_DH) #if defined(WOLFSSL_HAVE_SP_RSA) || defined(WOLFSSL_HAVE_SP_DH)
@@ -212,19 +212,19 @@ static void sp_2048_to_bin_32(sp_digit* r, byte* a)
#define sp_2048_norm_32(a) #define sp_2048_norm_32(a)
extern void sp_2048_mul_16(sp_digit* r, const sp_digit* a, const sp_digit* b); extern void sp_2048_mul_16(sp_digit* r, const sp_digit* a, const sp_digit* b);
extern void sp_2048_sqr_16(sp_digit* r, const sp_digit* a);
extern void sp_2048_mul_avx2_16(sp_digit* r, const sp_digit* a, const sp_digit* b); extern void sp_2048_mul_avx2_16(sp_digit* r, const sp_digit* a, const sp_digit* b);
extern void sp_2048_sqr_avx2_16(sp_digit* r, const sp_digit* a);
extern sp_digit sp_2048_add_16(sp_digit* r, const sp_digit* a, const sp_digit* b); extern sp_digit sp_2048_add_16(sp_digit* r, const sp_digit* a, const sp_digit* b);
extern sp_digit sp_2048_sub_in_place_32(sp_digit* a, const sp_digit* b); extern sp_digit sp_2048_sub_in_place_32(sp_digit* a, const sp_digit* b);
extern sp_digit sp_2048_add_32(sp_digit* r, const sp_digit* a, const sp_digit* b); extern sp_digit sp_2048_add_32(sp_digit* r, const sp_digit* a, const sp_digit* b);
extern void sp_2048_mul_32(sp_digit* r, const sp_digit* a, const sp_digit* b); extern void sp_2048_mul_32(sp_digit* r, const sp_digit* a, const sp_digit* b);
extern void sp_2048_mul_avx2_32(sp_digit* r, const sp_digit* a, const sp_digit* b);
extern void sp_2048_sqr_16(sp_digit* r, const sp_digit* a);
extern void sp_2048_sqr_avx2_16(sp_digit* r, const sp_digit* a);
extern sp_digit sp_2048_dbl_16(sp_digit* r, const sp_digit* a); extern sp_digit sp_2048_dbl_16(sp_digit* r, const sp_digit* a);
extern void sp_2048_sqr_32(sp_digit* r, const sp_digit* a); extern void sp_2048_sqr_32(sp_digit* r, const sp_digit* a);
extern void sp_2048_mul_avx2_32(sp_digit* r, const sp_digit* a, const sp_digit* b);
extern void sp_2048_sqr_avx2_32(sp_digit* r, const sp_digit* a); extern void sp_2048_sqr_avx2_32(sp_digit* r, const sp_digit* a);
#if (defined(WOLFSSL_HAVE_SP_RSA) && !defined(WOLFSSL_RSA_PUBLIC_ONLY)) || defined(WOLFSSL_HAVE_SP_DH) #if (defined(WOLFSSL_HAVE_SP_RSA) && !defined(WOLFSSL_RSA_PUBLIC_ONLY)) || defined(WOLFSSL_HAVE_SP_DH)
@@ -281,7 +281,7 @@ extern void sp_2048_mont_reduce_16(sp_digit* a, const sp_digit* m, sp_digit mp);
* m Modulus (prime). * m Modulus (prime).
* mp Montgomery multiplier. * mp Montgomery multiplier.
*/ */
static void sp_2048_mont_mul_16(sp_digit* r, const sp_digit* a, SP_NOINLINE static void sp_2048_mont_mul_16(sp_digit* r, const sp_digit* a,
const sp_digit* b, const sp_digit* m, sp_digit mp) const sp_digit* b, const sp_digit* m, sp_digit mp)
{ {
sp_2048_mul_16(r, a, b); sp_2048_mul_16(r, a, b);
@@ -295,7 +295,7 @@ static void sp_2048_mont_mul_16(sp_digit* r, const sp_digit* a,
* m Modulus (prime). * m Modulus (prime).
* mp Montgomery multiplier. * mp Montgomery multiplier.
*/ */
static void sp_2048_mont_sqr_16(sp_digit* r, const sp_digit* a, SP_NOINLINE static void sp_2048_mont_sqr_16(sp_digit* r, const sp_digit* a,
const sp_digit* m, sp_digit mp) const sp_digit* m, sp_digit mp)
{ {
sp_2048_sqr_16(r, a); sp_2048_sqr_16(r, a);
@@ -413,7 +413,7 @@ static WC_INLINE int sp_2048_div_16(const sp_digit* a, const sp_digit* d, sp_dig
else else
#endif #endif
sp_2048_cond_sub_16(&t1[16], &t1[16], d, (sp_digit)0 - r1); sp_2048_cond_sub_16(&t1[16], &t1[16], d, (sp_digit)0 - r1);
for (i=15; i>=0; i--) { for (i = 15; i >= 0; i--) {
sp_digit hi = t1[16 + i] - (t1[16 + i] == div); sp_digit hi = t1[16 + i] - (t1[16 + i] == div);
r1 = div_2048_word_16(hi, t1[16 + i - 1], div); r1 = div_2048_word_16(hi, t1[16 + i - 1], div);
@@ -658,7 +658,7 @@ extern void sp_2048_mont_reduce_avx2_16(sp_digit* a, const sp_digit* m, sp_digit
* m Modulus (prime). * m Modulus (prime).
* mp Montgomery multiplier. * mp Montgomery multiplier.
*/ */
static void sp_2048_mont_mul_avx2_16(sp_digit* r, const sp_digit* a, SP_NOINLINE static void sp_2048_mont_mul_avx2_16(sp_digit* r, const sp_digit* a,
const sp_digit* b, const sp_digit* m, sp_digit mp) const sp_digit* b, const sp_digit* m, sp_digit mp)
{ {
sp_2048_mul_avx2_16(r, a, b); sp_2048_mul_avx2_16(r, a, b);
@@ -674,7 +674,7 @@ static void sp_2048_mont_mul_avx2_16(sp_digit* r, const sp_digit* a,
* m Modulus (prime). * m Modulus (prime).
* mp Montgomery multiplier. * mp Montgomery multiplier.
*/ */
static void sp_2048_mont_sqr_avx2_16(sp_digit* r, const sp_digit* a, SP_NOINLINE static void sp_2048_mont_sqr_avx2_16(sp_digit* r, const sp_digit* a,
const sp_digit* m, sp_digit mp) const sp_digit* m, sp_digit mp)
{ {
sp_2048_sqr_avx2_16(r, a); sp_2048_sqr_avx2_16(r, a);
@@ -906,7 +906,7 @@ extern void sp_2048_mont_reduce_32(sp_digit* a, const sp_digit* m, sp_digit mp);
* m Modulus (prime). * m Modulus (prime).
* mp Montgomery multiplier. * mp Montgomery multiplier.
*/ */
static void sp_2048_mont_mul_32(sp_digit* r, const sp_digit* a, SP_NOINLINE static void sp_2048_mont_mul_32(sp_digit* r, const sp_digit* a,
const sp_digit* b, const sp_digit* m, sp_digit mp) const sp_digit* b, const sp_digit* m, sp_digit mp)
{ {
sp_2048_mul_32(r, a, b); sp_2048_mul_32(r, a, b);
@@ -920,7 +920,7 @@ static void sp_2048_mont_mul_32(sp_digit* r, const sp_digit* a,
* m Modulus (prime). * m Modulus (prime).
* mp Montgomery multiplier. * mp Montgomery multiplier.
*/ */
static void sp_2048_mont_sqr_32(sp_digit* r, const sp_digit* a, SP_NOINLINE static void sp_2048_mont_sqr_32(sp_digit* r, const sp_digit* a,
const sp_digit* m, sp_digit mp) const sp_digit* m, sp_digit mp)
{ {
sp_2048_sqr_32(r, a); sp_2048_sqr_32(r, a);
@@ -1006,9 +1006,13 @@ static WC_INLINE int sp_2048_div_32_cond(const sp_digit* a, const sp_digit* d, s
if (t1[i + 32] >= d[i]) { if (t1[i + 32] >= d[i]) {
sp_2048_sub_in_place_32(&t1[32], d); sp_2048_sub_in_place_32(&t1[32], d);
} }
for (i=31; i>=0; i--) { for (i = 31; i >= 0; i--) {
sp_digit hi = t1[32 + i] - (t1[32 + i] == div); if (t1[32 + i] == div) {
r1 = div_2048_word_32(hi, t1[32 + i - 1], div); r1 = SP_DIGIT_MAX;
}
else {
r1 = div_2048_word_32(t1[32 + i], t1[32 + i - 1], div);
}
#ifdef HAVE_INTEL_AVX2 #ifdef HAVE_INTEL_AVX2
if (IS_INTEL_BMI2(cpuid_flags) && IS_INTEL_ADX(cpuid_flags)) if (IS_INTEL_BMI2(cpuid_flags) && IS_INTEL_ADX(cpuid_flags))
@@ -1120,7 +1124,7 @@ static WC_INLINE int sp_2048_div_32(const sp_digit* a, const sp_digit* d, sp_dig
else else
#endif #endif
sp_2048_cond_sub_32(&t1[32], &t1[32], d, (sp_digit)0 - r1); sp_2048_cond_sub_32(&t1[32], &t1[32], d, (sp_digit)0 - r1);
for (i=31; i>=0; i--) { for (i = 31; i >= 0; i--) {
sp_digit hi = t1[32 + i] - (t1[32 + i] == div); sp_digit hi = t1[32 + i] - (t1[32 + i] == div);
r1 = div_2048_word_32(hi, t1[32 + i - 1], div); r1 = div_2048_word_32(hi, t1[32 + i - 1], div);
@@ -1350,7 +1354,7 @@ extern void sp_2048_mont_reduce_avx2_32(sp_digit* a, const sp_digit* m, sp_digit
* m Modulus (prime). * m Modulus (prime).
* mp Montgomery multiplier. * mp Montgomery multiplier.
*/ */
static void sp_2048_mont_mul_avx2_32(sp_digit* r, const sp_digit* a, SP_NOINLINE static void sp_2048_mont_mul_avx2_32(sp_digit* r, const sp_digit* a,
const sp_digit* b, const sp_digit* m, sp_digit mp) const sp_digit* b, const sp_digit* m, sp_digit mp)
{ {
sp_2048_mul_avx2_32(r, a, b); sp_2048_mul_avx2_32(r, a, b);
@@ -1366,7 +1370,7 @@ static void sp_2048_mont_mul_avx2_32(sp_digit* r, const sp_digit* a,
* m Modulus (prime). * m Modulus (prime).
* mp Montgomery multiplier. * mp Montgomery multiplier.
*/ */
static void sp_2048_mont_sqr_avx2_32(sp_digit* r, const sp_digit* a, SP_NOINLINE static void sp_2048_mont_sqr_avx2_32(sp_digit* r, const sp_digit* a,
const sp_digit* m, sp_digit mp) const sp_digit* m, sp_digit mp)
{ {
sp_2048_sqr_avx2_32(r, a); sp_2048_sqr_avx2_32(r, a);
@@ -1622,7 +1626,50 @@ int sp_RsaPublic_2048(const byte* in, word32 inLen, const mp_int* em,
if (err == MP_OKAY) { if (err == MP_OKAY) {
sp_2048_from_mp(m, 32, mm); sp_2048_from_mp(m, 32, mm);
if (e == 0x3) { if (e == 0x10001) {
int i;
sp_digit mp;
sp_2048_mont_setup(m, &mp);
/* Convert to Montgomery form. */
XMEMSET(a, 0, sizeof(sp_digit) * 32);
err = sp_2048_mod_32_cond(r, a, m);
/* Montgomery form: r = a.R mod m */
if (err == MP_OKAY) {
/* r = a ^ 0x10000 => r = a squared 16 times */
#ifdef HAVE_INTEL_AVX2
if (IS_INTEL_BMI2(cpuid_flags) && IS_INTEL_ADX(cpuid_flags)) {
for (i = 15; i >= 0; i--) {
sp_2048_mont_sqr_avx2_32(r, r, m, mp);
}
/* mont_red(r.R.R) = (r.R.R / R) mod m = r.R mod m
* mont_red(r.R * a) = (r.R.a / R) mod m = r.a mod m
*/
sp_2048_mont_mul_avx2_32(r, r, ah, m, mp);
}
else
#endif
{
for (i = 15; i >= 0; i--) {
sp_2048_mont_sqr_32(r, r, m, mp);
}
/* mont_red(r.R.R) = (r.R.R / R) mod m = r.R mod m
* mont_red(r.R * a) = (r.R.a / R) mod m = r.a mod m
*/
sp_2048_mont_mul_32(r, r, ah, m, mp);
}
for (i = 31; i > 0; i--) {
if (r[i] != m[i])
break;
}
if (r[i] >= m[i])
sp_2048_sub_in_place_32(r, m);
}
}
else if (e == 0x3) {
#ifdef HAVE_INTEL_AVX2 #ifdef HAVE_INTEL_AVX2
if (IS_INTEL_BMI2(cpuid_flags) && IS_INTEL_ADX(cpuid_flags)) { if (IS_INTEL_BMI2(cpuid_flags) && IS_INTEL_ADX(cpuid_flags)) {
if (err == MP_OKAY) { if (err == MP_OKAY) {
@@ -2751,30 +2798,30 @@ static void sp_3072_to_bin_48(sp_digit* r, byte* a)
#define sp_3072_norm_48(a) #define sp_3072_norm_48(a)
extern void sp_3072_mul_12(sp_digit* r, const sp_digit* a, const sp_digit* b); extern void sp_3072_mul_12(sp_digit* r, const sp_digit* a, const sp_digit* b);
extern void sp_3072_sqr_12(sp_digit* r, const sp_digit* a);
extern void sp_3072_mul_avx2_12(sp_digit* r, const sp_digit* a, const sp_digit* b); extern void sp_3072_mul_avx2_12(sp_digit* r, const sp_digit* a, const sp_digit* b);
extern void sp_3072_sqr_avx2_12(sp_digit* r, const sp_digit* a);
extern sp_digit sp_3072_add_12(sp_digit* r, const sp_digit* a, const sp_digit* b); extern sp_digit sp_3072_add_12(sp_digit* r, const sp_digit* a, const sp_digit* b);
extern sp_digit sp_3072_sub_in_place_24(sp_digit* a, const sp_digit* b); extern sp_digit sp_3072_sub_in_place_24(sp_digit* a, const sp_digit* b);
extern sp_digit sp_3072_add_24(sp_digit* r, const sp_digit* a, const sp_digit* b); extern sp_digit sp_3072_add_24(sp_digit* r, const sp_digit* a, const sp_digit* b);
extern void sp_3072_mul_24(sp_digit* r, const sp_digit* a, const sp_digit* b); extern void sp_3072_mul_24(sp_digit* r, const sp_digit* a, const sp_digit* b);
extern sp_digit sp_3072_dbl_12(sp_digit* r, const sp_digit* a);
extern void sp_3072_sqr_24(sp_digit* r, const sp_digit* a);
extern void sp_3072_mul_avx2_24(sp_digit* r, const sp_digit* a, const sp_digit* b); extern void sp_3072_mul_avx2_24(sp_digit* r, const sp_digit* a, const sp_digit* b);
extern void sp_3072_sqr_avx2_24(sp_digit* r, const sp_digit* a);
extern sp_digit sp_3072_sub_in_place_48(sp_digit* a, const sp_digit* b); extern sp_digit sp_3072_sub_in_place_48(sp_digit* a, const sp_digit* b);
extern sp_digit sp_3072_add_48(sp_digit* r, const sp_digit* a, const sp_digit* b); extern sp_digit sp_3072_add_48(sp_digit* r, const sp_digit* a, const sp_digit* b);
extern void sp_3072_mul_48(sp_digit* r, const sp_digit* a, const sp_digit* b); extern void sp_3072_mul_48(sp_digit* r, const sp_digit* a, const sp_digit* b);
extern void sp_3072_mul_avx2_48(sp_digit* r, const sp_digit* a, const sp_digit* b);
extern void sp_3072_sqr_12(sp_digit* r, const sp_digit* a);
extern void sp_3072_sqr_avx2_12(sp_digit* r, const sp_digit* a);
extern sp_digit sp_3072_dbl_12(sp_digit* r, const sp_digit* a);
extern void sp_3072_sqr_24(sp_digit* r, const sp_digit* a);
extern void sp_3072_sqr_avx2_24(sp_digit* r, const sp_digit* a);
extern sp_digit sp_3072_dbl_24(sp_digit* r, const sp_digit* a); extern sp_digit sp_3072_dbl_24(sp_digit* r, const sp_digit* a);
extern void sp_3072_sqr_48(sp_digit* r, const sp_digit* a); extern void sp_3072_sqr_48(sp_digit* r, const sp_digit* a);
extern void sp_3072_mul_avx2_48(sp_digit* r, const sp_digit* a, const sp_digit* b);
extern void sp_3072_sqr_avx2_48(sp_digit* r, const sp_digit* a); extern void sp_3072_sqr_avx2_48(sp_digit* r, const sp_digit* a);
#if (defined(WOLFSSL_HAVE_SP_RSA) && !defined(WOLFSSL_RSA_PUBLIC_ONLY)) || defined(WOLFSSL_HAVE_SP_DH) #if (defined(WOLFSSL_HAVE_SP_RSA) && !defined(WOLFSSL_RSA_PUBLIC_ONLY)) || defined(WOLFSSL_HAVE_SP_DH)
@@ -2830,7 +2877,7 @@ extern void sp_3072_mont_reduce_24(sp_digit* a, const sp_digit* m, sp_digit mp);
* m Modulus (prime). * m Modulus (prime).
* mp Montgomery multiplier. * mp Montgomery multiplier.
*/ */
static void sp_3072_mont_mul_24(sp_digit* r, const sp_digit* a, SP_NOINLINE static void sp_3072_mont_mul_24(sp_digit* r, const sp_digit* a,
const sp_digit* b, const sp_digit* m, sp_digit mp) const sp_digit* b, const sp_digit* m, sp_digit mp)
{ {
sp_3072_mul_24(r, a, b); sp_3072_mul_24(r, a, b);
@@ -2844,7 +2891,7 @@ static void sp_3072_mont_mul_24(sp_digit* r, const sp_digit* a,
* m Modulus (prime). * m Modulus (prime).
* mp Montgomery multiplier. * mp Montgomery multiplier.
*/ */
static void sp_3072_mont_sqr_24(sp_digit* r, const sp_digit* a, SP_NOINLINE static void sp_3072_mont_sqr_24(sp_digit* r, const sp_digit* a,
const sp_digit* m, sp_digit mp) const sp_digit* m, sp_digit mp)
{ {
sp_3072_sqr_24(r, a); sp_3072_sqr_24(r, a);
@@ -2962,7 +3009,7 @@ static WC_INLINE int sp_3072_div_24(const sp_digit* a, const sp_digit* d, sp_dig
else else
#endif #endif
sp_3072_cond_sub_24(&t1[24], &t1[24], d, (sp_digit)0 - r1); sp_3072_cond_sub_24(&t1[24], &t1[24], d, (sp_digit)0 - r1);
for (i=23; i>=0; i--) { for (i = 23; i >= 0; i--) {
sp_digit hi = t1[24 + i] - (t1[24 + i] == div); sp_digit hi = t1[24 + i] - (t1[24 + i] == div);
r1 = div_3072_word_24(hi, t1[24 + i - 1], div); r1 = div_3072_word_24(hi, t1[24 + i - 1], div);
@@ -3207,7 +3254,7 @@ extern void sp_3072_mont_reduce_avx2_24(sp_digit* a, const sp_digit* m, sp_digit
* m Modulus (prime). * m Modulus (prime).
* mp Montgomery multiplier. * mp Montgomery multiplier.
*/ */
static void sp_3072_mont_mul_avx2_24(sp_digit* r, const sp_digit* a, SP_NOINLINE static void sp_3072_mont_mul_avx2_24(sp_digit* r, const sp_digit* a,
const sp_digit* b, const sp_digit* m, sp_digit mp) const sp_digit* b, const sp_digit* m, sp_digit mp)
{ {
sp_3072_mul_avx2_24(r, a, b); sp_3072_mul_avx2_24(r, a, b);
@@ -3223,7 +3270,7 @@ static void sp_3072_mont_mul_avx2_24(sp_digit* r, const sp_digit* a,
* m Modulus (prime). * m Modulus (prime).
* mp Montgomery multiplier. * mp Montgomery multiplier.
*/ */
static void sp_3072_mont_sqr_avx2_24(sp_digit* r, const sp_digit* a, SP_NOINLINE static void sp_3072_mont_sqr_avx2_24(sp_digit* r, const sp_digit* a,
const sp_digit* m, sp_digit mp) const sp_digit* m, sp_digit mp)
{ {
sp_3072_sqr_avx2_24(r, a); sp_3072_sqr_avx2_24(r, a);
@@ -3455,7 +3502,7 @@ extern void sp_3072_mont_reduce_48(sp_digit* a, const sp_digit* m, sp_digit mp);
* m Modulus (prime). * m Modulus (prime).
* mp Montgomery multiplier. * mp Montgomery multiplier.
*/ */
static void sp_3072_mont_mul_48(sp_digit* r, const sp_digit* a, SP_NOINLINE static void sp_3072_mont_mul_48(sp_digit* r, const sp_digit* a,
const sp_digit* b, const sp_digit* m, sp_digit mp) const sp_digit* b, const sp_digit* m, sp_digit mp)
{ {
sp_3072_mul_48(r, a, b); sp_3072_mul_48(r, a, b);
@@ -3469,7 +3516,7 @@ static void sp_3072_mont_mul_48(sp_digit* r, const sp_digit* a,
* m Modulus (prime). * m Modulus (prime).
* mp Montgomery multiplier. * mp Montgomery multiplier.
*/ */
static void sp_3072_mont_sqr_48(sp_digit* r, const sp_digit* a, SP_NOINLINE static void sp_3072_mont_sqr_48(sp_digit* r, const sp_digit* a,
const sp_digit* m, sp_digit mp) const sp_digit* m, sp_digit mp)
{ {
sp_3072_sqr_48(r, a); sp_3072_sqr_48(r, a);
@@ -3555,9 +3602,13 @@ static WC_INLINE int sp_3072_div_48_cond(const sp_digit* a, const sp_digit* d, s
if (t1[i + 48] >= d[i]) { if (t1[i + 48] >= d[i]) {
sp_3072_sub_in_place_48(&t1[48], d); sp_3072_sub_in_place_48(&t1[48], d);
} }
for (i=47; i>=0; i--) { for (i = 47; i >= 0; i--) {
sp_digit hi = t1[48 + i] - (t1[48 + i] == div); if (t1[48 + i] == div) {
r1 = div_3072_word_48(hi, t1[48 + i - 1], div); r1 = SP_DIGIT_MAX;
}
else {
r1 = div_3072_word_48(t1[48 + i], t1[48 + i - 1], div);
}
#ifdef HAVE_INTEL_AVX2 #ifdef HAVE_INTEL_AVX2
if (IS_INTEL_BMI2(cpuid_flags) && IS_INTEL_ADX(cpuid_flags)) if (IS_INTEL_BMI2(cpuid_flags) && IS_INTEL_ADX(cpuid_flags))
@@ -3669,7 +3720,7 @@ static WC_INLINE int sp_3072_div_48(const sp_digit* a, const sp_digit* d, sp_dig
else else
#endif #endif
sp_3072_cond_sub_48(&t1[48], &t1[48], d, (sp_digit)0 - r1); sp_3072_cond_sub_48(&t1[48], &t1[48], d, (sp_digit)0 - r1);
for (i=47; i>=0; i--) { for (i = 47; i >= 0; i--) {
sp_digit hi = t1[48 + i] - (t1[48 + i] == div); sp_digit hi = t1[48 + i] - (t1[48 + i] == div);
r1 = div_3072_word_48(hi, t1[48 + i - 1], div); r1 = div_3072_word_48(hi, t1[48 + i - 1], div);
@@ -3899,7 +3950,7 @@ extern void sp_3072_mont_reduce_avx2_48(sp_digit* a, const sp_digit* m, sp_digit
* m Modulus (prime). * m Modulus (prime).
* mp Montgomery multiplier. * mp Montgomery multiplier.
*/ */
static void sp_3072_mont_mul_avx2_48(sp_digit* r, const sp_digit* a, SP_NOINLINE static void sp_3072_mont_mul_avx2_48(sp_digit* r, const sp_digit* a,
const sp_digit* b, const sp_digit* m, sp_digit mp) const sp_digit* b, const sp_digit* m, sp_digit mp)
{ {
sp_3072_mul_avx2_48(r, a, b); sp_3072_mul_avx2_48(r, a, b);
@@ -3915,7 +3966,7 @@ static void sp_3072_mont_mul_avx2_48(sp_digit* r, const sp_digit* a,
* m Modulus (prime). * m Modulus (prime).
* mp Montgomery multiplier. * mp Montgomery multiplier.
*/ */
static void sp_3072_mont_sqr_avx2_48(sp_digit* r, const sp_digit* a, SP_NOINLINE static void sp_3072_mont_sqr_avx2_48(sp_digit* r, const sp_digit* a,
const sp_digit* m, sp_digit mp) const sp_digit* m, sp_digit mp)
{ {
sp_3072_sqr_avx2_48(r, a); sp_3072_sqr_avx2_48(r, a);
@@ -4171,7 +4222,50 @@ int sp_RsaPublic_3072(const byte* in, word32 inLen, const mp_int* em,
if (err == MP_OKAY) { if (err == MP_OKAY) {
sp_3072_from_mp(m, 48, mm); sp_3072_from_mp(m, 48, mm);
if (e == 0x3) { if (e == 0x10001) {
int i;
sp_digit mp;
sp_3072_mont_setup(m, &mp);
/* Convert to Montgomery form. */
XMEMSET(a, 0, sizeof(sp_digit) * 48);
err = sp_3072_mod_48_cond(r, a, m);
/* Montgomery form: r = a.R mod m */
if (err == MP_OKAY) {
/* r = a ^ 0x10000 => r = a squared 16 times */
#ifdef HAVE_INTEL_AVX2
if (IS_INTEL_BMI2(cpuid_flags) && IS_INTEL_ADX(cpuid_flags)) {
for (i = 15; i >= 0; i--) {
sp_3072_mont_sqr_avx2_48(r, r, m, mp);
}
/* mont_red(r.R.R) = (r.R.R / R) mod m = r.R mod m
* mont_red(r.R * a) = (r.R.a / R) mod m = r.a mod m
*/
sp_3072_mont_mul_avx2_48(r, r, ah, m, mp);
}
else
#endif
{
for (i = 15; i >= 0; i--) {
sp_3072_mont_sqr_48(r, r, m, mp);
}
/* mont_red(r.R.R) = (r.R.R / R) mod m = r.R mod m
* mont_red(r.R * a) = (r.R.a / R) mod m = r.a mod m
*/
sp_3072_mont_mul_48(r, r, ah, m, mp);
}
for (i = 47; i > 0; i--) {
if (r[i] != m[i])
break;
}
if (r[i] >= m[i])
sp_3072_sub_in_place_48(r, m);
}
}
else if (e == 0x3) {
#ifdef HAVE_INTEL_AVX2 #ifdef HAVE_INTEL_AVX2
if (IS_INTEL_BMI2(cpuid_flags) && IS_INTEL_ADX(cpuid_flags)) { if (IS_INTEL_BMI2(cpuid_flags) && IS_INTEL_ADX(cpuid_flags)) {
if (err == MP_OKAY) { if (err == MP_OKAY) {
@@ -5303,11 +5397,11 @@ extern sp_digit sp_4096_sub_in_place_64(sp_digit* a, const sp_digit* b);
extern sp_digit sp_4096_add_64(sp_digit* r, const sp_digit* a, const sp_digit* b); extern sp_digit sp_4096_add_64(sp_digit* r, const sp_digit* a, const sp_digit* b);
extern void sp_4096_mul_64(sp_digit* r, const sp_digit* a, const sp_digit* b); extern void sp_4096_mul_64(sp_digit* r, const sp_digit* a, const sp_digit* b);
extern void sp_4096_mul_avx2_64(sp_digit* r, const sp_digit* a, const sp_digit* b);
extern sp_digit sp_2048_dbl_32(sp_digit* r, const sp_digit* a); extern sp_digit sp_2048_dbl_32(sp_digit* r, const sp_digit* a);
extern void sp_4096_sqr_64(sp_digit* r, const sp_digit* a); extern void sp_4096_sqr_64(sp_digit* r, const sp_digit* a);
extern void sp_4096_mul_avx2_64(sp_digit* r, const sp_digit* a, const sp_digit* b);
extern void sp_4096_sqr_avx2_64(sp_digit* r, const sp_digit* a); extern void sp_4096_sqr_avx2_64(sp_digit* r, const sp_digit* a);
/* Calculate the bottom digit of -1/a mod 2^n. /* Calculate the bottom digit of -1/a mod 2^n.
@@ -5361,7 +5455,7 @@ extern void sp_4096_mont_reduce_64(sp_digit* a, const sp_digit* m, sp_digit mp);
* m Modulus (prime). * m Modulus (prime).
* mp Montgomery multiplier. * mp Montgomery multiplier.
*/ */
static void sp_4096_mont_mul_64(sp_digit* r, const sp_digit* a, SP_NOINLINE static void sp_4096_mont_mul_64(sp_digit* r, const sp_digit* a,
const sp_digit* b, const sp_digit* m, sp_digit mp) const sp_digit* b, const sp_digit* m, sp_digit mp)
{ {
sp_4096_mul_64(r, a, b); sp_4096_mul_64(r, a, b);
@@ -5375,7 +5469,7 @@ static void sp_4096_mont_mul_64(sp_digit* r, const sp_digit* a,
* m Modulus (prime). * m Modulus (prime).
* mp Montgomery multiplier. * mp Montgomery multiplier.
*/ */
static void sp_4096_mont_sqr_64(sp_digit* r, const sp_digit* a, SP_NOINLINE static void sp_4096_mont_sqr_64(sp_digit* r, const sp_digit* a,
const sp_digit* m, sp_digit mp) const sp_digit* m, sp_digit mp)
{ {
sp_4096_sqr_64(r, a); sp_4096_sqr_64(r, a);
@@ -5461,9 +5555,13 @@ static WC_INLINE int sp_4096_div_64_cond(const sp_digit* a, const sp_digit* d, s
if (t1[i + 64] >= d[i]) { if (t1[i + 64] >= d[i]) {
sp_4096_sub_in_place_64(&t1[64], d); sp_4096_sub_in_place_64(&t1[64], d);
} }
for (i=63; i>=0; i--) { for (i = 63; i >= 0; i--) {
sp_digit hi = t1[64 + i] - (t1[64 + i] == div); if (t1[64 + i] == div) {
r1 = div_4096_word_64(hi, t1[64 + i - 1], div); r1 = SP_DIGIT_MAX;
}
else {
r1 = div_4096_word_64(t1[64 + i], t1[64 + i - 1], div);
}
#ifdef HAVE_INTEL_AVX2 #ifdef HAVE_INTEL_AVX2
if (IS_INTEL_BMI2(cpuid_flags) && IS_INTEL_ADX(cpuid_flags)) if (IS_INTEL_BMI2(cpuid_flags) && IS_INTEL_ADX(cpuid_flags))
@@ -5575,7 +5673,7 @@ static WC_INLINE int sp_4096_div_64(const sp_digit* a, const sp_digit* d, sp_dig
else else
#endif #endif
sp_4096_cond_sub_64(&t1[64], &t1[64], d, (sp_digit)0 - r1); sp_4096_cond_sub_64(&t1[64], &t1[64], d, (sp_digit)0 - r1);
for (i=63; i>=0; i--) { for (i = 63; i >= 0; i--) {
sp_digit hi = t1[64 + i] - (t1[64 + i] == div); sp_digit hi = t1[64 + i] - (t1[64 + i] == div);
r1 = div_4096_word_64(hi, t1[64 + i - 1], div); r1 = div_4096_word_64(hi, t1[64 + i - 1], div);
@@ -5805,7 +5903,7 @@ extern void sp_4096_mont_reduce_avx2_64(sp_digit* a, const sp_digit* m, sp_digit
* m Modulus (prime). * m Modulus (prime).
* mp Montgomery multiplier. * mp Montgomery multiplier.
*/ */
static void sp_4096_mont_mul_avx2_64(sp_digit* r, const sp_digit* a, SP_NOINLINE static void sp_4096_mont_mul_avx2_64(sp_digit* r, const sp_digit* a,
const sp_digit* b, const sp_digit* m, sp_digit mp) const sp_digit* b, const sp_digit* m, sp_digit mp)
{ {
sp_4096_mul_avx2_64(r, a, b); sp_4096_mul_avx2_64(r, a, b);
@@ -5821,7 +5919,7 @@ static void sp_4096_mont_mul_avx2_64(sp_digit* r, const sp_digit* a,
* m Modulus (prime). * m Modulus (prime).
* mp Montgomery multiplier. * mp Montgomery multiplier.
*/ */
static void sp_4096_mont_sqr_avx2_64(sp_digit* r, const sp_digit* a, SP_NOINLINE static void sp_4096_mont_sqr_avx2_64(sp_digit* r, const sp_digit* a,
const sp_digit* m, sp_digit mp) const sp_digit* m, sp_digit mp)
{ {
sp_4096_sqr_avx2_64(r, a); sp_4096_sqr_avx2_64(r, a);
@@ -6077,7 +6175,50 @@ int sp_RsaPublic_4096(const byte* in, word32 inLen, const mp_int* em,
if (err == MP_OKAY) { if (err == MP_OKAY) {
sp_4096_from_mp(m, 64, mm); sp_4096_from_mp(m, 64, mm);
if (e == 0x3) { if (e == 0x10001) {
int i;
sp_digit mp;
sp_4096_mont_setup(m, &mp);
/* Convert to Montgomery form. */
XMEMSET(a, 0, sizeof(sp_digit) * 64);
err = sp_4096_mod_64_cond(r, a, m);
/* Montgomery form: r = a.R mod m */
if (err == MP_OKAY) {
/* r = a ^ 0x10000 => r = a squared 16 times */
#ifdef HAVE_INTEL_AVX2
if (IS_INTEL_BMI2(cpuid_flags) && IS_INTEL_ADX(cpuid_flags)) {
for (i = 15; i >= 0; i--) {
sp_4096_mont_sqr_avx2_64(r, r, m, mp);
}
/* mont_red(r.R.R) = (r.R.R / R) mod m = r.R mod m
* mont_red(r.R * a) = (r.R.a / R) mod m = r.a mod m
*/
sp_4096_mont_mul_avx2_64(r, r, ah, m, mp);
}
else
#endif
{
for (i = 15; i >= 0; i--) {
sp_4096_mont_sqr_64(r, r, m, mp);
}
/* mont_red(r.R.R) = (r.R.R / R) mod m = r.R mod m
* mont_red(r.R * a) = (r.R.a / R) mod m = r.a mod m
*/
sp_4096_mont_mul_64(r, r, ah, m, mp);
}
for (i = 63; i > 0; i--) {
if (r[i] != m[i])
break;
}
if (r[i] >= m[i])
sp_4096_sub_in_place_64(r, m);
}
}
else if (e == 0x3) {
#ifdef HAVE_INTEL_AVX2 #ifdef HAVE_INTEL_AVX2
if (IS_INTEL_BMI2(cpuid_flags) && IS_INTEL_ADX(cpuid_flags)) { if (IS_INTEL_BMI2(cpuid_flags) && IS_INTEL_ADX(cpuid_flags)) {
if (err == MP_OKAY) { if (err == MP_OKAY) {
@@ -23667,7 +23808,7 @@ static WC_INLINE int sp_256_div_4(const sp_digit* a, const sp_digit* d, sp_digit
else else
#endif #endif
sp_256_cond_sub_4(&t1[4], &t1[4], d, (sp_digit)0 - r1); sp_256_cond_sub_4(&t1[4], &t1[4], d, (sp_digit)0 - r1);
for (i=3; i>=0; i--) { for (i = 3; i >= 0; i--) {
sp_digit hi = t1[4 + i] - (t1[4 + i] == div); sp_digit hi = t1[4 + i] - (t1[4 + i] == div);
r1 = div_256_word_4(hi, t1[4 + i - 1], div); r1 = div_256_word_4(hi, t1[4 + i - 1], div);
@@ -25984,7 +26125,7 @@ extern void sp_384_mont_reduce_order_6(sp_digit* a, const sp_digit* m, sp_digit
* m Modulus (prime). * m Modulus (prime).
* mp Montgomery multiplier. * mp Montgomery multiplier.
*/ */
static void sp_384_mont_mul_6(sp_digit* r, const sp_digit* a, SP_NOINLINE static void sp_384_mont_mul_6(sp_digit* r, const sp_digit* a,
const sp_digit* b, const sp_digit* m, sp_digit mp) const sp_digit* b, const sp_digit* m, sp_digit mp)
{ {
sp_384_mul_6(r, a, b); sp_384_mul_6(r, a, b);
@@ -25998,7 +26139,7 @@ static void sp_384_mont_mul_6(sp_digit* r, const sp_digit* a,
* m Modulus (prime). * m Modulus (prime).
* mp Montgomery multiplier. * mp Montgomery multiplier.
*/ */
static void sp_384_mont_sqr_6(sp_digit* r, const sp_digit* a, SP_NOINLINE static void sp_384_mont_sqr_6(sp_digit* r, const sp_digit* a,
const sp_digit* m, sp_digit mp) const sp_digit* m, sp_digit mp)
{ {
sp_384_sqr_6(r, a); sp_384_sqr_6(r, a);
@@ -27218,7 +27359,7 @@ extern void sp_384_mont_reduce_order_avx2_6(sp_digit* a, const sp_digit* m, sp_d
* m Modulus (prime). * m Modulus (prime).
* mp Montgomery multiplier. * mp Montgomery multiplier.
*/ */
static void sp_384_mont_mul_avx2_6(sp_digit* r, const sp_digit* a, SP_NOINLINE static void sp_384_mont_mul_avx2_6(sp_digit* r, const sp_digit* a,
const sp_digit* b, const sp_digit* m, sp_digit mp) const sp_digit* b, const sp_digit* m, sp_digit mp)
{ {
sp_384_mul_avx2_6(r, a, b); sp_384_mul_avx2_6(r, a, b);
@@ -27234,7 +27375,7 @@ static void sp_384_mont_mul_avx2_6(sp_digit* r, const sp_digit* a,
* m Modulus (prime). * m Modulus (prime).
* mp Montgomery multiplier. * mp Montgomery multiplier.
*/ */
static void sp_384_mont_sqr_avx2_6(sp_digit* r, const sp_digit* a, SP_NOINLINE static void sp_384_mont_sqr_avx2_6(sp_digit* r, const sp_digit* a,
const sp_digit* m, sp_digit mp) const sp_digit* m, sp_digit mp)
{ {
sp_384_sqr_avx2_6(r, a); sp_384_sqr_avx2_6(r, a);
@@ -48285,7 +48426,7 @@ static WC_INLINE int sp_384_div_6(const sp_digit* a, const sp_digit* d, sp_digit
else else
#endif #endif
sp_384_cond_sub_6(&t1[6], &t1[6], d, (sp_digit)0 - r1); sp_384_cond_sub_6(&t1[6], &t1[6], d, (sp_digit)0 - r1);
for (i=5; i>=0; i--) { for (i = 5; i >= 0; i--) {
sp_digit hi = t1[6 + i] - (t1[6 + i] == div); sp_digit hi = t1[6 + i] - (t1[6 + i] == div);
r1 = div_384_word_6(hi, t1[6 + i - 1], div); r1 = div_384_word_6(hi, t1[6 + i - 1], div);
@@ -88947,7 +89088,7 @@ static WC_INLINE int sp_521_div_9(const sp_digit* a, const sp_digit* d, sp_digit
sp_521_lshift_9(sd, d, 55); sp_521_lshift_9(sd, d, 55);
sp_521_lshift_18(t1, t1, 55); sp_521_lshift_18(t1, t1, 55);
for (i=8; i>=0; i--) { for (i = 8; i >= 0; i--) {
sp_digit hi = t1[9 + i] - (t1[9 + i] == div); sp_digit hi = t1[9 + i] - (t1[9 + i] == div);
r1 = div_521_word_9(hi, t1[9 + i - 1], div); r1 = div_521_word_9(hi, t1[9 + i - 1], div);
@@ -91079,7 +91220,7 @@ static WC_INLINE int sp_1024_div_16(const sp_digit* a, const sp_digit* d, sp_dig
else else
#endif #endif
sp_1024_cond_sub_16(&t1[16], &t1[16], d, (sp_digit)0 - r1); sp_1024_cond_sub_16(&t1[16], &t1[16], d, (sp_digit)0 - r1);
for (i=15; i>=0; i--) { for (i = 15; i >= 0; i--) {
sp_digit hi = t1[16 + i] - (t1[16 + i] == div); sp_digit hi = t1[16 + i] - (t1[16 + i] == div);
r1 = div_1024_word_16(hi, t1[16 + i - 1], div); r1 = div_1024_word_16(hi, t1[16 + i - 1], div);
@@ -91414,7 +91555,7 @@ extern void sp_1024_mont_reduce_16(sp_digit* a, const sp_digit* m, sp_digit mp);
* m Modulus (prime). * m Modulus (prime).
* mp Montgomery multiplier. * mp Montgomery multiplier.
*/ */
static void sp_1024_mont_mul_16(sp_digit* r, const sp_digit* a, SP_NOINLINE static void sp_1024_mont_mul_16(sp_digit* r, const sp_digit* a,
const sp_digit* b, const sp_digit* m, sp_digit mp) const sp_digit* b, const sp_digit* m, sp_digit mp)
{ {
sp_1024_mul_16(r, a, b); sp_1024_mul_16(r, a, b);
@@ -91428,7 +91569,7 @@ static void sp_1024_mont_mul_16(sp_digit* r, const sp_digit* a,
* m Modulus (prime). * m Modulus (prime).
* mp Montgomery multiplier. * mp Montgomery multiplier.
*/ */
static void sp_1024_mont_sqr_16(sp_digit* r, const sp_digit* a, SP_NOINLINE static void sp_1024_mont_sqr_16(sp_digit* r, const sp_digit* a,
const sp_digit* m, sp_digit mp) const sp_digit* m, sp_digit mp)
{ {
sp_1024_sqr_16(r, a); sp_1024_sqr_16(r, a);
@@ -92538,7 +92679,7 @@ extern void sp_1024_mont_reduce_avx2_16(sp_digit* a, const sp_digit* m, sp_digit
* m Modulus (prime). * m Modulus (prime).
* mp Montgomery multiplier. * mp Montgomery multiplier.
*/ */
static void sp_1024_mont_mul_avx2_16(sp_digit* r, const sp_digit* a, SP_NOINLINE static void sp_1024_mont_mul_avx2_16(sp_digit* r, const sp_digit* a,
const sp_digit* b, const sp_digit* m, sp_digit mp) const sp_digit* b, const sp_digit* m, sp_digit mp)
{ {
sp_1024_mul_avx2_16(r, a, b); sp_1024_mul_avx2_16(r, a, b);
@@ -92554,7 +92695,7 @@ static void sp_1024_mont_mul_avx2_16(sp_digit* r, const sp_digit* a,
* m Modulus (prime). * m Modulus (prime).
* mp Montgomery multiplier. * mp Montgomery multiplier.
*/ */
static void sp_1024_mont_sqr_avx2_16(sp_digit* r, const sp_digit* a, SP_NOINLINE static void sp_1024_mont_sqr_avx2_16(sp_digit* r, const sp_digit* a,
const sp_digit* m, sp_digit mp) const sp_digit* m, sp_digit mp)
{ {
sp_1024_sqr_avx2_16(r, a); sp_1024_sqr_avx2_16(r, a);
+7321 -7321
View File
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff