diff --git a/src/core/include/mp-units/ext/prime.h b/src/core/include/mp-units/ext/prime.h index 51050f41..3f127eb4 100644 --- a/src/core/include/mp-units/ext/prime.h +++ b/src/core/include/mp-units/ext/prime.h @@ -183,6 +183,83 @@ struct NumberDecomposition { return false; } +// The Jacobi symbol, notated as `(a/n)`, is defined for odd positive `n` and any integer `a`, taking values +// in the set `{-1, 0, 1}`. Besides being a completely multiplicative function (so that, for example, both +// (a*b/n) = (a/n) * (b/n), and (a/n*m) = (a/n) * (a/m)), it obeys the following symmetry rules, which enable +// its calculation: +// +// 1. (a/1) = 1, and (1/n) = 1, for all a and n. +// +// 2. (a/n) = 0 whenever a and n have a nontrivial common factor. +// +// 3. (a/n) = (b/n) whenever (a % n) = (b % n). +// +// 4. (2a/n) = (a/n) if n % 8 = 1 or 7, and -(a/n) if n % 8 = 3 or 5. +// +// 5. (a/n) = (n/a) * x if a and n are both odd, positive, and coprime. Here, x is 1 if either (a % 4) = 1 +// or (n % 4) = 1, and -1 otherwise. +// +// 6. (-1/n) = 1 if n % 4 = 1, and -1 if n % 4 = 3. +[[nodiscard]] consteval int jacobi_symbol(int64_t raw_a, uint64_t n) +{ + // Rule 1: n=1 case. + if (n == 1u) { + return 1; + } + + // Starting conditions: transform `a` to strictly non-negative values, setting `result` to the sign that we + // pick up (if any) from following these rules (i.e., rules 3 and 6). + int result = ((raw_a >= 0) || (n % 4u == 1u)) ? 1 : -1; + auto a = static_cast(raw_a < 0 ? -raw_a : raw_a) % n; + + while (a != 0u) { + // Rule 4. + const int sign_for_even = (n % 8u == 1u || n % 8u == 7u) ? 1 : -1; + while (a % 2u == 0u) { + a /= 2u; + result *= sign_for_even; + } + + // Rule 1: a=1 case. + if (a == 1u) { + return result; + } + + // Rule 2. + if (std::gcd(a, n) != 1u) { + return 0; + } + + // Note that at this point, we know that `a` and `n` are coprime, and are both odd and positive. + // Therefore, we meet the preconditions for rule 5 (the "flip-and-reduce" rule). + result *= (n % 4u == 1u || a % 4u == 1u) ? 1 : -1; + const uint64_t new_a = n % a; + n = a; + a = new_a; + } + + return 0; +} + +[[nodiscard]] consteval bool is_perfect_square(uint64_t n) +{ + if (n < 2u) { + return true; + } + + uint64_t prev = n / 2u; + while (true) { + const uint64_t curr = (prev + n / prev) / 2u; + if (curr * curr == n) { + return true; + } + if (curr >= prev) { + return false; + } + prev = curr; + } +} + [[nodiscard]] consteval bool is_prime_by_trial_division(std::uintmax_t n) { for (std::uintmax_t f = 2; f * f <= n; f += 1 + (f % 2)) { diff --git a/test/static/prime_test.cpp b/test/static/prime_test.cpp index bfd7aae5..e669ab4e 100644 --- a/test/static/prime_test.cpp +++ b/test/static/prime_test.cpp @@ -122,4 +122,47 @@ static_assert(miller_rabin_probable_prime(2u, 9'007'199'254'740'881u), "Large kn static_assert(miller_rabin_probable_prime(2u, 18'446'744'073'709'551'557u), "Largest 64-bit prime"); +// Jacobi symbols --- a building block for the Strong Lucas probable prime test, needed for Baillie-PSW. +static_assert(jacobi_symbol(1, 1u) == 1, "Jacobi symbol always 1 when 'numerator' is 1"); +static_assert(jacobi_symbol(1, 3u) == 1, "Jacobi symbol always 1 when 'numerator' is 1"); +static_assert(jacobi_symbol(1, 5u) == 1, "Jacobi symbol always 1 when 'numerator' is 1"); +static_assert(jacobi_symbol(1, 987654321u) == 1, "Jacobi symbol always 1 when 'numerator' is 1"); + +static_assert(jacobi_symbol(3, 1u) == 1, "Jacobi symbol always 1 when 'denominator' is 1"); +static_assert(jacobi_symbol(5, 1u) == 1, "Jacobi symbol always 1 when 'denominator' is 1"); +static_assert(jacobi_symbol(-1234567890, 1u) == 1, "Jacobi symbol always 1 when 'denominator' is 1"); + +static_assert(jacobi_symbol(10, 5u) == 0, "Jacobi symbol always 0 when there's a common factor"); +static_assert(jacobi_symbol(25, 15u) == 0, "Jacobi symbol always 0 when there's a common factor"); +static_assert(jacobi_symbol(-24, 9u) == 0, "Jacobi symbol always 0 when there's a common factor"); + +static_assert(jacobi_symbol(14, 9u) == +jacobi_symbol(7, 9u), + "Divide numerator by 2: positive when (denom % 8) in {1, 7}"); +static_assert(jacobi_symbol(14, 15u) == +jacobi_symbol(7, 15u), + "Divide numerator by 2: positive when (denom % 8) in {1, 7}"); +static_assert(jacobi_symbol(14, 11u) == -jacobi_symbol(7, 11u), + "Divide numerator by 2: negative when (denom % 8) in {3, 5}"); +static_assert(jacobi_symbol(14, 13u) == -jacobi_symbol(7, 13u), + "Divide numerator by 2: negative when (denom % 8) in {3, 5}"); + +static_assert(jacobi_symbol(19, 9u) == +jacobi_symbol(9, 19u), "Flip is identity when (n % 4) = 1"); +static_assert(jacobi_symbol(17, 7u) == +jacobi_symbol(7, 17u), "Flip is identity when (a % 4) = 1"); +static_assert(jacobi_symbol(19, 7u) == -jacobi_symbol(9, 7u), "Flip changes sign when (n % 4) = 3 and (a % 4) = 3"); + +static_assert(jacobi_symbol(1001, 9907u) == -1, "Example from Wikipedia page"); +static_assert(jacobi_symbol(19, 45u) == 1, "Example from Wikipedia page"); +static_assert(jacobi_symbol(8, 21u) == -1, "Example from Wikipedia page"); +static_assert(jacobi_symbol(5, 21u) == 1, "Example from Wikipedia page"); + +// Tests for perfect square finder +static_assert(is_perfect_square(0u)); +static_assert(is_perfect_square(1u)); +static_assert(!is_perfect_square(2u)); +static_assert(is_perfect_square(4u)); + +constexpr uint64_t BIG_SQUARE = [](auto x) { return x * x; }((uint64_t{1u} << 32) - 1u); +static_assert(!is_perfect_square(BIG_SQUARE - 1u)); +static_assert(is_perfect_square(BIG_SQUARE)); +static_assert(!is_perfect_square(BIG_SQUARE + 1u)); + } // namespace