From 8a7483f7eafb6f7665d2a498d05e69f7b94601fe Mon Sep 17 00:00:00 2001 From: Chip Hogg Date: Mon, 11 Nov 2024 12:55:41 -0500 Subject: [PATCH 1/4] Add helpers for modular arithmetic The prime-testing techniques we will use (Miller-Rabin, Strong Lucas) all make heavy usage of modular arithmetic. Therefore, we lay those foundations here, adding utilities to perform the basic arithmetic operations robustly. Since these are internal-only helper functions, we don't bother checking the preconditions, although we state them clearly in the contract comment for each utility. After C++26, we could add contracts for these. Helps #509. --- src/core/include/mp-units/ext/prime.h | 80 +++++++++++++++++++++++++++ test/static/prime_test.cpp | 26 +++++++++ 2 files changed, 106 insertions(+) diff --git a/src/core/include/mp-units/ext/prime.h b/src/core/include/mp-units/ext/prime.h index 5b372fe9..bb48956b 100644 --- a/src/core/include/mp-units/ext/prime.h +++ b/src/core/include/mp-units/ext/prime.h @@ -42,6 +42,86 @@ import std; namespace mp_units::detail { +// (a + b) % n. +// +// Precondition: (a < n). +// Precondition: (b < n). +// Precondition: (n > 0). +[[nodiscard]] consteval uint64_t add_mod(uint64_t a, uint64_t b, uint64_t n) +{ + if (a >= n - b) { + return a - (n - b); + } else { + return a + b; + } +} + +// (a - b) % n. +// +// Precondition: (a < n). +// Precondition: (b < n). +// Precondition: (n > 0). +[[nodiscard]] consteval uint64_t sub_mod(uint64_t a, uint64_t b, uint64_t n) +{ + if (a >= b) { + return a - b; + } else { + return n - (b - a); + } +} + +// (a * b) % n. +// +// Precondition: (a < n). +// Precondition: (b < n). +// Precondition: (n > 0). +[[nodiscard]] consteval uint64_t mul_mod(uint64_t a, uint64_t b, uint64_t n) +{ + if (b == 0u || a < std::numeric_limits::max() / b) { + return (a * b) % n; + } + + const uint64_t batch_size = n / a; + const uint64_t num_batches = b / batch_size; + + return add_mod( + // Transform into "negative space" to make the first parameter as small as possible; + // then, transform back. + n - mul_mod(n % a, num_batches, n), + + // Handle the leftover product (which is guaranteed to fit in the integer type). + (a * (b % batch_size)) % n, + + n); +} + +// (a / 2) % n. +// +// Precondition: (a < n). +// Precondition: (n % 2 == 1). +[[nodiscard]] consteval uint64_t half_mod_odd(uint64_t a, uint64_t n) +{ + return (a / 2u) + ((a % 2u == 0u) ? 0u : (n / 2u + 1u)); +} + +// (base ^ exp) % n. +[[nodiscard]] consteval uint64_t pow_mod(uint64_t base, uint64_t exp, uint64_t n) +{ + uint64_t result = 1u; + base %= n; + + while (exp > 0u) { + if (exp % 2u == 1u) { + result = mul_mod(result, base, n); + } + + exp /= 2u; + base = mul_mod(base, base, n); + } + + return result; +} + [[nodiscard]] consteval bool is_prime_by_trial_division(std::uintmax_t n) { for (std::uintmax_t f = 2; f * f <= n; f += 1 + (f % 2)) { diff --git a/test/static/prime_test.cpp b/test/static/prime_test.cpp index d3c50f09..bf1fd922 100644 --- a/test/static/prime_test.cpp +++ b/test/static/prime_test.cpp @@ -33,6 +33,8 @@ using namespace mp_units::detail; namespace { +inline constexpr auto MAX_U64 = std::numeric_limits::max(); + template constexpr bool check_primes(std::index_sequence) { @@ -78,4 +80,28 @@ static_assert(!wheel_factorizer<3>::is_prime(0)); static_assert(!wheel_factorizer<3>::is_prime(1)); static_assert(wheel_factorizer<3>::is_prime(2)); +// Modular arithmetic. +static_assert(add_mod(1u, 2u, 5u) == 3u); +static_assert(add_mod(4u, 4u, 5u) == 3u); +static_assert(add_mod(MAX_U64 - 1u, MAX_U64 - 2u, MAX_U64) == MAX_U64 - 3u); + +static_assert(sub_mod(2u, 1u, 5u) == 1u); +static_assert(sub_mod(1u, 2u, 5u) == 4u); +static_assert(sub_mod(MAX_U64 - 2u, MAX_U64 - 1u, MAX_U64) == MAX_U64 - 1u); +static_assert(sub_mod(1u, MAX_U64 - 1u, MAX_U64) == 2u); + +static_assert(mul_mod(6u, 7u, 10u) == 2u); +static_assert(mul_mod(13u, 11u, 50u) == 43u); +static_assert(mul_mod(MAX_U64 / 2u, 10u, MAX_U64) == MAX_U64 - 5u); + +static_assert(half_mod_odd(0u, 11u) == 0u); +static_assert(half_mod_odd(10u, 11u) == 5u); +static_assert(half_mod_odd(1u, 11u) == 6u); +static_assert(half_mod_odd(9u, 11u) == 10u); +static_assert(half_mod_odd(MAX_U64 - 1u, MAX_U64) == (MAX_U64 - 1u) / 2u); +static_assert(half_mod_odd(MAX_U64 - 2u, MAX_U64) == MAX_U64 - 1u); + +static_assert(pow_mod(5u, 8u, 9u) == ((5u * 5u * 5u * 5u) * (5u * 5u * 5u * 5u)) % 9u); +static_assert(pow_mod(2u, 64u, MAX_U64) == 1u); + } // namespace From a6d34b40a617ee29e08e2d78fb76de7c2283e0e6 Mon Sep 17 00:00:00 2001 From: Chip Hogg Date: Mon, 11 Nov 2024 13:35:37 -0500 Subject: [PATCH 2/4] Include Apparently some freestanding builds need this? --- src/core/include/mp-units/ext/prime.h | 1 + 1 file changed, 1 insertion(+) diff --git a/src/core/include/mp-units/ext/prime.h b/src/core/include/mp-units/ext/prime.h index bb48956b..641c8039 100644 --- a/src/core/include/mp-units/ext/prime.h +++ b/src/core/include/mp-units/ext/prime.h @@ -32,6 +32,7 @@ import std; #else #include +#include #include #include #include From 1110e53e38316fcffef0ba2f75005fb757aa920a Mon Sep 17 00:00:00 2001 From: Chip Hogg Date: Mon, 11 Nov 2024 14:34:47 -0500 Subject: [PATCH 3/4] Never mind Apparently, that's not the right approach. --- src/core/include/mp-units/ext/prime.h | 1 - 1 file changed, 1 deletion(-) diff --git a/src/core/include/mp-units/ext/prime.h b/src/core/include/mp-units/ext/prime.h index 641c8039..bb48956b 100644 --- a/src/core/include/mp-units/ext/prime.h +++ b/src/core/include/mp-units/ext/prime.h @@ -32,7 +32,6 @@ import std; #else #include -#include #include #include #include From 6c982d4202534cf99131e20820a6f6407c56e52c Mon Sep 17 00:00:00 2001 From: Chip Hogg Date: Mon, 11 Nov 2024 16:14:04 -0500 Subject: [PATCH 4/4] Use `std::` prefix and mpu's EXPECTS_DEBUG macro --- src/core/include/mp-units/ext/prime.h | 34 ++++++++++++++++++++------- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/src/core/include/mp-units/ext/prime.h b/src/core/include/mp-units/ext/prime.h index bb48956b..c734c3d3 100644 --- a/src/core/include/mp-units/ext/prime.h +++ b/src/core/include/mp-units/ext/prime.h @@ -28,6 +28,7 @@ #include #ifndef MP_UNITS_IN_MODULE_INTERFACE +#include #ifdef MP_UNITS_IMPORT_STD import std; #else @@ -47,8 +48,12 @@ namespace mp_units::detail { // Precondition: (a < n). // Precondition: (b < n). // Precondition: (n > 0). -[[nodiscard]] consteval uint64_t add_mod(uint64_t a, uint64_t b, uint64_t n) +[[nodiscard]] consteval std::uint64_t add_mod(std::uint64_t a, std::uint64_t b, std::uint64_t n) { + MP_UNITS_EXPECTS_DEBUG(a < n); + MP_UNITS_EXPECTS_DEBUG(b < n); + MP_UNITS_EXPECTS_DEBUG(n > 0u); + if (a >= n - b) { return a - (n - b); } else { @@ -61,8 +66,12 @@ namespace mp_units::detail { // Precondition: (a < n). // Precondition: (b < n). // Precondition: (n > 0). -[[nodiscard]] consteval uint64_t sub_mod(uint64_t a, uint64_t b, uint64_t n) +[[nodiscard]] consteval std::uint64_t sub_mod(std::uint64_t a, std::uint64_t b, std::uint64_t n) { + MP_UNITS_EXPECTS_DEBUG(a < n); + MP_UNITS_EXPECTS_DEBUG(b < n); + MP_UNITS_EXPECTS_DEBUG(n > 0u); + if (a >= b) { return a - b; } else { @@ -75,14 +84,18 @@ namespace mp_units::detail { // Precondition: (a < n). // Precondition: (b < n). // Precondition: (n > 0). -[[nodiscard]] consteval uint64_t mul_mod(uint64_t a, uint64_t b, uint64_t n) +[[nodiscard]] consteval std::uint64_t mul_mod(std::uint64_t a, std::uint64_t b, std::uint64_t n) { - if (b == 0u || a < std::numeric_limits::max() / b) { + MP_UNITS_EXPECTS_DEBUG(a < n); + MP_UNITS_EXPECTS_DEBUG(b < n); + MP_UNITS_EXPECTS_DEBUG(n > 0u); + + if (b == 0u || a < std::numeric_limits::max() / b) { return (a * b) % n; } - const uint64_t batch_size = n / a; - const uint64_t num_batches = b / batch_size; + const std::uint64_t batch_size = n / a; + const std::uint64_t num_batches = b / batch_size; return add_mod( // Transform into "negative space" to make the first parameter as small as possible; @@ -99,15 +112,18 @@ namespace mp_units::detail { // // Precondition: (a < n). // Precondition: (n % 2 == 1). -[[nodiscard]] consteval uint64_t half_mod_odd(uint64_t a, uint64_t n) +[[nodiscard]] consteval std::uint64_t half_mod_odd(std::uint64_t a, std::uint64_t n) { + MP_UNITS_EXPECTS_DEBUG(a < n); + MP_UNITS_EXPECTS_DEBUG(n % 2 == 1); + return (a / 2u) + ((a % 2u == 0u) ? 0u : (n / 2u + 1u)); } // (base ^ exp) % n. -[[nodiscard]] consteval uint64_t pow_mod(uint64_t base, uint64_t exp, uint64_t n) +[[nodiscard]] consteval std::uint64_t pow_mod(std::uint64_t base, std::uint64_t exp, std::uint64_t n) { - uint64_t result = 1u; + std::uint64_t result = 1u; base %= n; while (exp > 0u) {