diff --git a/src/core/include/mp-units/ext/prime.h b/src/core/include/mp-units/ext/prime.h index 5b372fe9..bb48956b 100644 --- a/src/core/include/mp-units/ext/prime.h +++ b/src/core/include/mp-units/ext/prime.h @@ -42,6 +42,86 @@ import std; namespace mp_units::detail { +// (a + b) % n. +// +// Precondition: (a < n). +// Precondition: (b < n). +// Precondition: (n > 0). +[[nodiscard]] consteval uint64_t add_mod(uint64_t a, uint64_t b, uint64_t n) +{ + if (a >= n - b) { + return a - (n - b); + } else { + return a + b; + } +} + +// (a - b) % n. +// +// Precondition: (a < n). +// Precondition: (b < n). +// Precondition: (n > 0). +[[nodiscard]] consteval uint64_t sub_mod(uint64_t a, uint64_t b, uint64_t n) +{ + if (a >= b) { + return a - b; + } else { + return n - (b - a); + } +} + +// (a * b) % n. +// +// Precondition: (a < n). +// Precondition: (b < n). +// Precondition: (n > 0). +[[nodiscard]] consteval uint64_t mul_mod(uint64_t a, uint64_t b, uint64_t n) +{ + if (b == 0u || a < std::numeric_limits::max() / b) { + return (a * b) % n; + } + + const uint64_t batch_size = n / a; + const uint64_t num_batches = b / batch_size; + + return add_mod( + // Transform into "negative space" to make the first parameter as small as possible; + // then, transform back. + n - mul_mod(n % a, num_batches, n), + + // Handle the leftover product (which is guaranteed to fit in the integer type). + (a * (b % batch_size)) % n, + + n); +} + +// (a / 2) % n. +// +// Precondition: (a < n). +// Precondition: (n % 2 == 1). +[[nodiscard]] consteval uint64_t half_mod_odd(uint64_t a, uint64_t n) +{ + return (a / 2u) + ((a % 2u == 0u) ? 0u : (n / 2u + 1u)); +} + +// (base ^ exp) % n. +[[nodiscard]] consteval uint64_t pow_mod(uint64_t base, uint64_t exp, uint64_t n) +{ + uint64_t result = 1u; + base %= n; + + while (exp > 0u) { + if (exp % 2u == 1u) { + result = mul_mod(result, base, n); + } + + exp /= 2u; + base = mul_mod(base, base, n); + } + + return result; +} + [[nodiscard]] consteval bool is_prime_by_trial_division(std::uintmax_t n) { for (std::uintmax_t f = 2; f * f <= n; f += 1 + (f % 2)) { diff --git a/test/static/prime_test.cpp b/test/static/prime_test.cpp index d3c50f09..bf1fd922 100644 --- a/test/static/prime_test.cpp +++ b/test/static/prime_test.cpp @@ -33,6 +33,8 @@ using namespace mp_units::detail; namespace { +inline constexpr auto MAX_U64 = std::numeric_limits::max(); + template constexpr bool check_primes(std::index_sequence) { @@ -78,4 +80,28 @@ static_assert(!wheel_factorizer<3>::is_prime(0)); static_assert(!wheel_factorizer<3>::is_prime(1)); static_assert(wheel_factorizer<3>::is_prime(2)); +// Modular arithmetic. +static_assert(add_mod(1u, 2u, 5u) == 3u); +static_assert(add_mod(4u, 4u, 5u) == 3u); +static_assert(add_mod(MAX_U64 - 1u, MAX_U64 - 2u, MAX_U64) == MAX_U64 - 3u); + +static_assert(sub_mod(2u, 1u, 5u) == 1u); +static_assert(sub_mod(1u, 2u, 5u) == 4u); +static_assert(sub_mod(MAX_U64 - 2u, MAX_U64 - 1u, MAX_U64) == MAX_U64 - 1u); +static_assert(sub_mod(1u, MAX_U64 - 1u, MAX_U64) == 2u); + +static_assert(mul_mod(6u, 7u, 10u) == 2u); +static_assert(mul_mod(13u, 11u, 50u) == 43u); +static_assert(mul_mod(MAX_U64 / 2u, 10u, MAX_U64) == MAX_U64 - 5u); + +static_assert(half_mod_odd(0u, 11u) == 0u); +static_assert(half_mod_odd(10u, 11u) == 5u); +static_assert(half_mod_odd(1u, 11u) == 6u); +static_assert(half_mod_odd(9u, 11u) == 10u); +static_assert(half_mod_odd(MAX_U64 - 1u, MAX_U64) == (MAX_U64 - 1u) / 2u); +static_assert(half_mod_odd(MAX_U64 - 2u, MAX_U64) == MAX_U64 - 1u); + +static_assert(pow_mod(5u, 8u, 9u) == ((5u * 5u * 5u * 5u) * (5u * 5u * 5u * 5u)) % 9u); +static_assert(pow_mod(2u, 64u, MAX_U64) == 1u); + } // namespace