Add helpers for modular arithmetic

The prime-testing techniques we will use (Miller-Rabin, Strong Lucas)
all make heavy usage of modular arithmetic.  Therefore, we lay those
foundations here, adding utilities to perform the basic arithmetic
operations robustly.

Since these are internal-only helper functions, we don't bother checking
the preconditions, although we state them clearly in the contract
comment for each utility.  After C++26, we could add contracts for
these.

Helps #509.
This commit is contained in:
Chip Hogg
2024-11-11 12:55:41 -05:00
parent 727a898141
commit 8a7483f7ea
2 changed files with 106 additions and 0 deletions

View File

@@ -42,6 +42,86 @@ import std;
namespace mp_units::detail {
// (a + b) % n.
//
// Precondition: (a < n).
// Precondition: (b < n).
// Precondition: (n > 0).
[[nodiscard]] consteval uint64_t add_mod(uint64_t a, uint64_t b, uint64_t n)
{
if (a >= n - b) {
return a - (n - b);
} else {
return a + b;
}
}
// (a - b) % n.
//
// Precondition: (a < n).
// Precondition: (b < n).
// Precondition: (n > 0).
[[nodiscard]] consteval uint64_t sub_mod(uint64_t a, uint64_t b, uint64_t n)
{
if (a >= b) {
return a - b;
} else {
return n - (b - a);
}
}
// (a * b) % n.
//
// Precondition: (a < n).
// Precondition: (b < n).
// Precondition: (n > 0).
[[nodiscard]] consteval uint64_t mul_mod(uint64_t a, uint64_t b, uint64_t n)
{
if (b == 0u || a < std::numeric_limits<uint64_t>::max() / b) {
return (a * b) % n;
}
const uint64_t batch_size = n / a;
const uint64_t num_batches = b / batch_size;
return add_mod(
// Transform into "negative space" to make the first parameter as small as possible;
// then, transform back.
n - mul_mod(n % a, num_batches, n),
// Handle the leftover product (which is guaranteed to fit in the integer type).
(a * (b % batch_size)) % n,
n);
}
// (a / 2) % n.
//
// Precondition: (a < n).
// Precondition: (n % 2 == 1).
[[nodiscard]] consteval uint64_t half_mod_odd(uint64_t a, uint64_t n)
{
return (a / 2u) + ((a % 2u == 0u) ? 0u : (n / 2u + 1u));
}
// (base ^ exp) % n.
[[nodiscard]] consteval uint64_t pow_mod(uint64_t base, uint64_t exp, uint64_t n)
{
uint64_t result = 1u;
base %= n;
while (exp > 0u) {
if (exp % 2u == 1u) {
result = mul_mod(result, base, n);
}
exp /= 2u;
base = mul_mod(base, base, n);
}
return result;
}
[[nodiscard]] consteval bool is_prime_by_trial_division(std::uintmax_t n)
{
for (std::uintmax_t f = 2; f * f <= n; f += 1 + (f % 2)) {

View File

@@ -33,6 +33,8 @@ using namespace mp_units::detail;
namespace {
inline constexpr auto MAX_U64 = std::numeric_limits<std::uint64_t>::max();
template<std::size_t BasisSize, std::size_t... Is>
constexpr bool check_primes(std::index_sequence<Is...>)
{
@@ -78,4 +80,28 @@ static_assert(!wheel_factorizer<3>::is_prime(0));
static_assert(!wheel_factorizer<3>::is_prime(1));
static_assert(wheel_factorizer<3>::is_prime(2));
// Modular arithmetic.
static_assert(add_mod(1u, 2u, 5u) == 3u);
static_assert(add_mod(4u, 4u, 5u) == 3u);
static_assert(add_mod(MAX_U64 - 1u, MAX_U64 - 2u, MAX_U64) == MAX_U64 - 3u);
static_assert(sub_mod(2u, 1u, 5u) == 1u);
static_assert(sub_mod(1u, 2u, 5u) == 4u);
static_assert(sub_mod(MAX_U64 - 2u, MAX_U64 - 1u, MAX_U64) == MAX_U64 - 1u);
static_assert(sub_mod(1u, MAX_U64 - 1u, MAX_U64) == 2u);
static_assert(mul_mod(6u, 7u, 10u) == 2u);
static_assert(mul_mod(13u, 11u, 50u) == 43u);
static_assert(mul_mod(MAX_U64 / 2u, 10u, MAX_U64) == MAX_U64 - 5u);
static_assert(half_mod_odd(0u, 11u) == 0u);
static_assert(half_mod_odd(10u, 11u) == 5u);
static_assert(half_mod_odd(1u, 11u) == 6u);
static_assert(half_mod_odd(9u, 11u) == 10u);
static_assert(half_mod_odd(MAX_U64 - 1u, MAX_U64) == (MAX_U64 - 1u) / 2u);
static_assert(half_mod_odd(MAX_U64 - 2u, MAX_U64) == MAX_U64 - 1u);
static_assert(pow_mod(5u, 8u, 9u) == ((5u * 5u * 5u * 5u) * (5u * 5u * 5u * 5u)) % 9u);
static_assert(pow_mod(2u, 64u, MAX_U64) == 1u);
} // namespace