Merge pull request #634 from chiphogg/chiphogg/mod#509

Add helpers for modular arithmetic
This commit is contained in:
Mateusz Pusz
2024-11-13 20:32:28 +01:00
committed by GitHub
2 changed files with 122 additions and 0 deletions

View File

@@ -28,6 +28,7 @@
#include <mp-units/ext/algorithm.h>
#ifndef MP_UNITS_IN_MODULE_INTERFACE
#include <mp-units/ext/contracts.h>
#ifdef MP_UNITS_IMPORT_STD
import std;
#else
@@ -42,6 +43,101 @@ import std;
namespace mp_units::detail {
// (a + b) % n.
//
// Precondition: (a < n).
// Precondition: (b < n).
// Precondition: (n > 0).
[[nodiscard]] consteval std::uint64_t add_mod(std::uint64_t a, std::uint64_t b, std::uint64_t n)
{
MP_UNITS_EXPECTS_DEBUG(a < n);
MP_UNITS_EXPECTS_DEBUG(b < n);
MP_UNITS_EXPECTS_DEBUG(n > 0u);
if (a >= n - b) {
return a - (n - b);
} else {
return a + b;
}
}
// (a - b) % n.
//
// Precondition: (a < n).
// Precondition: (b < n).
// Precondition: (n > 0).
[[nodiscard]] consteval std::uint64_t sub_mod(std::uint64_t a, std::uint64_t b, std::uint64_t n)
{
MP_UNITS_EXPECTS_DEBUG(a < n);
MP_UNITS_EXPECTS_DEBUG(b < n);
MP_UNITS_EXPECTS_DEBUG(n > 0u);
if (a >= b) {
return a - b;
} else {
return n - (b - a);
}
}
// (a * b) % n.
//
// Precondition: (a < n).
// Precondition: (b < n).
// Precondition: (n > 0).
[[nodiscard]] consteval std::uint64_t mul_mod(std::uint64_t a, std::uint64_t b, std::uint64_t n)
{
MP_UNITS_EXPECTS_DEBUG(a < n);
MP_UNITS_EXPECTS_DEBUG(b < n);
MP_UNITS_EXPECTS_DEBUG(n > 0u);
if (b == 0u || a < std::numeric_limits<std::uint64_t>::max() / b) {
return (a * b) % n;
}
const std::uint64_t batch_size = n / a;
const std::uint64_t num_batches = b / batch_size;
return add_mod(
// Transform into "negative space" to make the first parameter as small as possible;
// then, transform back.
n - mul_mod(n % a, num_batches, n),
// Handle the leftover product (which is guaranteed to fit in the integer type).
(a * (b % batch_size)) % n,
n);
}
// (a / 2) % n.
//
// Precondition: (a < n).
// Precondition: (n % 2 == 1).
[[nodiscard]] consteval std::uint64_t half_mod_odd(std::uint64_t a, std::uint64_t n)
{
MP_UNITS_EXPECTS_DEBUG(a < n);
MP_UNITS_EXPECTS_DEBUG(n % 2 == 1);
return (a / 2u) + ((a % 2u == 0u) ? 0u : (n / 2u + 1u));
}
// (base ^ exp) % n.
[[nodiscard]] consteval std::uint64_t pow_mod(std::uint64_t base, std::uint64_t exp, std::uint64_t n)
{
std::uint64_t result = 1u;
base %= n;
while (exp > 0u) {
if (exp % 2u == 1u) {
result = mul_mod(result, base, n);
}
exp /= 2u;
base = mul_mod(base, base, n);
}
return result;
}
[[nodiscard]] consteval bool is_prime_by_trial_division(std::uintmax_t n)
{
for (std::uintmax_t f = 2; f * f <= n; f += 1 + (f % 2)) {

View File

@@ -33,6 +33,8 @@ using namespace mp_units::detail;
namespace {
inline constexpr auto MAX_U64 = std::numeric_limits<std::uint64_t>::max();
template<std::size_t BasisSize, std::size_t... Is>
constexpr bool check_primes(std::index_sequence<Is...>)
{
@@ -78,4 +80,28 @@ static_assert(!wheel_factorizer<3>::is_prime(0));
static_assert(!wheel_factorizer<3>::is_prime(1));
static_assert(wheel_factorizer<3>::is_prime(2));
// Modular arithmetic.
static_assert(add_mod(1u, 2u, 5u) == 3u);
static_assert(add_mod(4u, 4u, 5u) == 3u);
static_assert(add_mod(MAX_U64 - 1u, MAX_U64 - 2u, MAX_U64) == MAX_U64 - 3u);
static_assert(sub_mod(2u, 1u, 5u) == 1u);
static_assert(sub_mod(1u, 2u, 5u) == 4u);
static_assert(sub_mod(MAX_U64 - 2u, MAX_U64 - 1u, MAX_U64) == MAX_U64 - 1u);
static_assert(sub_mod(1u, MAX_U64 - 1u, MAX_U64) == 2u);
static_assert(mul_mod(6u, 7u, 10u) == 2u);
static_assert(mul_mod(13u, 11u, 50u) == 43u);
static_assert(mul_mod(MAX_U64 / 2u, 10u, MAX_U64) == MAX_U64 - 5u);
static_assert(half_mod_odd(0u, 11u) == 0u);
static_assert(half_mod_odd(10u, 11u) == 5u);
static_assert(half_mod_odd(1u, 11u) == 6u);
static_assert(half_mod_odd(9u, 11u) == 10u);
static_assert(half_mod_odd(MAX_U64 - 1u, MAX_U64) == (MAX_U64 - 1u) / 2u);
static_assert(half_mod_odd(MAX_U64 - 2u, MAX_U64) == MAX_U64 - 1u);
static_assert(pow_mod(5u, 8u, 9u) == ((5u * 5u * 5u * 5u) * (5u * 5u * 5u * 5u)) % 9u);
static_assert(pow_mod(2u, 64u, MAX_U64) == 1u);
} // namespace