Merge pull request #5927 from SparkiDev/sp_math_clz

SP math: use count leading zero instruction
This commit is contained in:
David Garske
2023-01-20 10:33:18 -08:00
committed by GitHub

View File

@@ -473,7 +473,15 @@ This library provides single precision (SP) integer math functions.
"adcq %[c], %[o] \n\t" \
: [l] "+r" (vl), [h] "+r" (vh), [o] "+r" (vo) \
: [a] "r" (va), [b] "r" (vb), [c] "r" (vc) \
: "%rax", "%rdx", "cc" \
: "cc" \
)
/* Index of highest bit set.
 *
 * Issues a single "bsr" instruction with va as the register input and vi as
 * the register output. The condition codes are declared as clobbered.
 *
 * NOTE(review): bsr presumably leaves the destination undefined when the
 * source is zero - confirm that callers only pass a non-zero value.
 *
 * @param [in]  va  Value to scan.
 * @param [out] vi  Variable that receives the index of the highest set bit.
 */
#define SP_ASM_HI_BIT_SET_IDX(va, vi) \
__asm__ __volatile__ ( \
"bsr %[a], %[i] \n\t" \
: [i] "=r" (vi) \
: [a] "r" (va) \
: "cc" \
)
#else
#include <intrin.h>
@@ -609,6 +617,9 @@ This library provides single precision (SP) integer math functions.
_addcarry_u64(c, vo, vc, &vo); \
} \
while (0)
/* Index of highest bit set.
 *
 * _BitScanReverse64() writes the index through its first (pointer) argument
 * and returns only a found/not-found flag, so the index is collected in a
 * temporary and then assigned to vi.
 *
 * @param [in]  va  Value to scan. Expected to be non-zero; when no bit is set
 *                  vi is set to 0.
 * @param [out] vi  Variable that receives the index of the highest set bit.
 */
#define SP_ASM_HI_BIT_SET_IDX(va, vi)                        \
    do {                                                     \
        unsigned long sp_hi_bit_idx_ = 0;                    \
        (void)_BitScanReverse64(&sp_hi_bit_idx_, (va));      \
        (vi) = sp_hi_bit_idx_;                               \
    }                                                        \
    while (0)
#endif
#if !defined(WOLFSSL_SP_DIV_WORD_HALF) && (!defined(_MSC_VER) || \
@@ -809,6 +820,14 @@ static WC_INLINE sp_int_digit sp_div_word(sp_int_digit hi, sp_int_digit lo,
: [a] "r" (va), [b] "r" (vb), [c] "r" (vc) \
: "cc" \
)
/* Index of highest bit set.
 *
 * Issues a single "bsr" instruction with va as the register input and vi as
 * the register output. The condition codes ("cc") are declared as clobbered.
 *
 * NOTE(review): bsr presumably leaves the destination undefined when the
 * source is zero - confirm that callers only pass a non-zero value.
 *
 * @param [in]  va  Value to scan.
 * @param [out] vi  Variable that receives the index of the highest set bit.
 */
#define SP_ASM_HI_BIT_SET_IDX(va, vi)                    \
    __asm__ __volatile__ (                               \
        "bsr %[a], %[i] \n\t"                            \
        : [i] "=r" (vi)                                  \
        : [a] "r" (va)                                   \
        : "cc"                                           \
    )
#ifndef WOLFSSL_SP_DIV_WORD_HALF
/* Divide a two digit number by a digit number and return. (hi | lo) / d
@@ -980,6 +999,14 @@ static WC_INLINE sp_int_digit sp_div_word(sp_int_digit hi, sp_int_digit lo,
: [a] "r" (va), [b] "r" (vb), [c] "r" (vc) \
: "cc" \
)
/* Count leading zeros.
 *
 * Issues a single "clz" instruction with va as the register input and vn as
 * the register output. No clobbers are declared.
 *
 * @param [in]  va  Value whose leading zero bits are counted.
 * @param [out] vn  Variable that receives the count.
 */
#define SP_ASM_LZCNT(va, vn) \
__asm__ __volatile__ ( \
"clz %[n], %[a] \n\t" \
: [n] "=r" (vn) \
: [a] "r" (va) \
: \
)
#ifndef WOLFSSL_SP_DIV_WORD_HALF
/* Divide a two digit number by a digit number and return. (hi | lo) / d
@@ -1190,6 +1217,16 @@ static WC_INLINE sp_int_digit sp_div_word(sp_int_digit hi, sp_int_digit lo,
: [a] "r" (va), [b] "r" (vb), [c] "r" (vc) \
: "cc" \
)
#if defined(WOLFSSL_SP_ARM_ARCH) && (WOLFSSL_SP_ARM_ARCH >= 7)
/* Count leading zeros - only defined when targeting ARMv7 and newer.
 *
 * Issues a single "clz" instruction with va as the register input and vn as
 * the register output. No clobbers are declared. When the architecture check
 * fails the macro is left undefined so that users of it can test for it with
 * #ifdef / defined().
 *
 * @param [in]  va  Value whose leading zero bits are counted.
 * @param [out] vn  Variable that receives the count.
 */
#define SP_ASM_LZCNT(va, vn)                             \
    __asm__ __volatile__ (                               \
        "clz %[n], %[a] \n\t"                            \
        : [n] "=r" (vn)                                  \
        : [a] "r" (va)                                   \
        :                                                \
    )
#endif
#ifndef WOLFSSL_SP_DIV_WORD_HALF
#ifndef WOLFSSL_SP_ARM32_UDIV
@@ -3376,6 +3413,14 @@ static WC_INLINE sp_int_digit sp_div_word(sp_int_digit hi, sp_int_digit lo,
: [a] "r" (va), [b] "r" (vb), [c] "r" (vc) \
: "cc" \
)
/* Count leading zeros.
 *
 * Issues a single "cntlzd" instruction with va as the register input and vn
 * as the register output. No clobbers are declared.
 *
 * @param [in]  va  Value whose leading zero bits are counted.
 * @param [out] vn  Variable that receives the count.
 */
#define SP_ASM_LZCNT(va, vn) \
__asm__ __volatile__ ( \
"cntlzd %[n], %[a] \n\t" \
: [n] "=r" (vn) \
: [a] "r" (va) \
: \
)
#define SP_INT_ASM_AVAILABLE
@@ -3523,6 +3568,14 @@ static WC_INLINE sp_int_digit sp_div_word(sp_int_digit hi, sp_int_digit lo,
: [a] "r" (va), [b] "r" (vb), [c] "r" (vc) \
: "cc" \
)
/* Count leading zeros.
 *
 * Issues a single "cntlzw" instruction with va as the register input and vn
 * as the register output. No clobbers are declared.
 *
 * @param [in]  va  Value whose leading zero bits are counted.
 * @param [out] vn  Variable that receives the count.
 */
#define SP_ASM_LZCNT(va, vn) \
__asm__ __volatile__ ( \
"cntlzw %[n], %[a] \n\t" \
: [n] "=r" (vn) \
: [a] "r" (va) \
: \
)
#define SP_INT_ASM_AVAILABLE
@@ -5225,6 +5278,35 @@ int sp_count_bits(const sp_int* a)
n = 0;
}
else {
#ifdef SP_ASM_HI_BIT_SET_IDX
sp_int_digit hi;
sp_int_digit d;
/* Get the most significant word. */
d = a->dp[n];
/* Count of bits up to last word. */
n *= SP_WORD_SIZE;
/* Get index of highest set bit. */
SP_ASM_HI_BIT_SET_IDX(d, hi);
/* Add bits up to and including index. */
n += (int)hi + 1;
#elif defined(SP_ASM_LZCNT)
sp_int_digit lz;
sp_int_digit d;
/* Get the most significant word. */
d = a->dp[n];
/* Count of bits up to last word. */
n *= SP_WORD_SIZE;
/* Count number of leading zeros in highest non-zero digit. */
SP_ASM_LZCNT(d, lz);
/* Add non-leading zero bits count. */
n += SP_WORD_SIZE - (int)lz;
#else
sp_int_digit d;
/* Get the most significant word. */
@@ -5249,6 +5331,7 @@ int sp_count_bits(const sp_int* a)
d >>= 1;
}
}
#endif
}
}
@@ -6086,6 +6169,8 @@ static void _sp_div_3(const sp_int* a, sp_int* r, sp_int_digit* rem)
sp_int_digit l = 0;
sp_int_digit tt = 0;
sp_int_digit t = SP_DIV_3_CONST;
sp_int_digit lm = 0;
sp_int_digit hm = 0;
#endif
sp_int_digit tr = 0;
/* Quotient fixup. */
@@ -6111,7 +6196,7 @@ static void _sp_div_3(const sp_int* a, sp_int* r, sp_int_digit* rem)
t = (t >> SP_WORD_SIZE) + (t & SP_MASK);
/* Get top digit after multiplying by (2^SP_WORD_SIZE) / 3. */
tt = (t * SP_DIV_3_CONST) >> SP_WORD_SIZE;
/* Subtract trail division. */
/* Subtract trial division. */
tr = (sp_int_digit)(t - (sp_int_word)tt * 3);
#else
/* Sum the digits. */
@@ -6121,11 +6206,11 @@ static void _sp_div_3(const sp_int* a, sp_int* r, sp_int_digit* rem)
/* Sum digits of sum - can get carry. */
SP_ASM_ADDC_REG(l, tt, tr);
/* Multiply digit by (2^SP_WORD_SIZE) / 3. */
SP_ASM_MUL(t, tr, l, t);
SP_ASM_MUL(lm, hm, l, t);
/* Add remainder multiplied by (2^SP_WORD_SIZE) / 3 to top digit. */
tr += tt * SP_DIV_3_CONST;
/* Subtract trail division from digit. */
tr = l - (tr * 3);
hm += tt * SP_DIV_3_CONST;
/* Subtract trial division from digit. */
tr = l - (hm * 3);
#endif
/* tr is 0..5 but need 0..2 */
/* Fix up remainder. */
@@ -6141,14 +6226,14 @@ static void _sp_div_3(const sp_int* a, sp_int* r, sp_int_digit* rem)
t = ((sp_int_word)tr << SP_WORD_SIZE) | a->dp[i];
/* Get top digit after multiplying by (2^SP_WORD_SIZE) / 3. */
tt = (t * SP_DIV_3_CONST) >> SP_WORD_SIZE;
/* Subtract trail division. */
/* Subtract trial division. */
tr = (sp_int_digit)(t - (sp_int_word)tt * 3);
#else
/* Multiply digit by (2^SP_WORD_SIZE) / 3. */
SP_ASM_MUL(l, tt, a->dp[i], t);
/* Add remainder multiplied by (2^SP_WORD_SIZE) / 3 to top digit. */
tt += tr * SP_DIV_3_CONST;
/* Subtract trail division from digit. */
/* Subtract trial division from digit. */
tr = a->dp[i] - (tt * 3);
#endif
/* tr is 0..5 but need 0..2 */
@@ -6202,7 +6287,7 @@ static void _sp_div_10(const sp_int* a, sp_int* r, sp_int_digit* rem)
t = ((sp_int_word)tr << SP_WORD_SIZE) | a->dp[i];
/* Get top digit after multiplying by (2^SP_WORD_SIZE) / 10. */
tt = (t * SP_DIV_10_CONST) >> SP_WORD_SIZE;
/* Subtract trail division. */
/* Subtract trial division. */
tr = (sp_int_digit)(t - (sp_int_word)tt * 10);
#else
/* Multiply digit by (2^SP_WORD_SIZE) / 10. */
@@ -6210,7 +6295,7 @@ static void _sp_div_10(const sp_int* a, sp_int* r, sp_int_digit* rem)
/* Add remainder multiplied by (2^SP_WORD_SIZE) / 10 to top digit.
*/
tt += tr * SP_DIV_10_CONST;
/* Subtract trail division from digit. */
/* Subtract trial division from digit. */
tr = a->dp[i] - (tt * 10);
#endif
/* tr is 0..99 but need 0..9 */
@@ -6228,7 +6313,7 @@ static void _sp_div_10(const sp_int* a, sp_int* r, sp_int_digit* rem)
t = ((sp_int_word)tr << SP_WORD_SIZE) | a->dp[i];
/* Get top digit after multiplying by (2^SP_WORD_SIZE) / 10. */
tt = (t * SP_DIV_10_CONST) >> SP_WORD_SIZE;
/* Subtract trail division. */
/* Subtract trial division. */
tr = (sp_int_digit)(t - (sp_int_word)tt * 10);
#else
/* Multiply digit by (2^SP_WORD_SIZE) / 10. */
@@ -6236,7 +6321,7 @@ static void _sp_div_10(const sp_int* a, sp_int* r, sp_int_digit* rem)
/* Add remainder multiplied by (2^SP_WORD_SIZE) / 10 to top digit.
*/
tt += tr * SP_DIV_10_CONST;
/* Subtract trail division from digit. */
/* Subtract trial division from digit. */
tr = a->dp[i] - (tt * 10);
#endif
/* tr is 0..99 but need 0..9 */
@@ -6292,14 +6377,14 @@ static void _sp_div_small(const sp_int* a, sp_int_digit d, sp_int* r,
t = ((sp_int_word)tr << SP_WORD_SIZE) | a->dp[i];
/* Get top digit after multiplying. */
tt = (t * m) >> SP_WORD_SIZE;
/* Subtract trail division. */
/* Subtract trial division. */
tr = (sp_int_digit)(t - tt * d);
#else
/* Multiply digit. */
SP_ASM_MUL(l, tt, a->dp[i], m);
/* Add multiplied remainder to top digit. */
tt += tr * m;
/* Subtract trail division from digit. */
/* Subtract trial division from digit. */
tr = a->dp[i] - (tt * d);
#endif
/* tr < d * d */
@@ -6319,14 +6404,14 @@ static void _sp_div_small(const sp_int* a, sp_int_digit d, sp_int* r,
t = ((sp_int_word)tr << SP_WORD_SIZE) | a->dp[i];
/* Get top digit after multiplying. */
tt = (t * m) >> SP_WORD_SIZE;
/* Subtract trail division. */
/* Subtract trial division. */
tr = (sp_int_digit)(t - tt * d);
#else
/* Multiply digit. */
SP_ASM_MUL(l, tt, a->dp[i], m);
/* Add multiplied remainder to top digit. */
tt += tr * m;
/* Subtract trail division from digit. */
/* Subtract trial division from digit. */
tr = a->dp[i] - (tt * d);
#endif
/* tr < d * d */
@@ -7875,7 +7960,7 @@ static int _sp_div(sp_int* a, const sp_int* d, sp_int* r, sp_int* trial)
t = SP_DIGIT_MAX;
}
else {
/* Calculate trail quotient by dividing top word of dividend by top
/* Calculate trial quotient by dividing top word of dividend by top
* digit of divisor.
* Some implementations segfault when quotient > SP_DIGIT_MAX.
* Implementations in assembly, using builtins or using