diff --git a/src/core/include/mp-units/ext/prime.h b/src/core/include/mp-units/ext/prime.h index bb48956b..c734c3d3 100644 --- a/src/core/include/mp-units/ext/prime.h +++ b/src/core/include/mp-units/ext/prime.h @@ -28,6 +28,7 @@ #include #ifndef MP_UNITS_IN_MODULE_INTERFACE +#include #ifdef MP_UNITS_IMPORT_STD import std; #else @@ -47,8 +48,12 @@ namespace mp_units::detail { // Precondition: (a < n). // Precondition: (b < n). // Precondition: (n > 0). -[[nodiscard]] consteval uint64_t add_mod(uint64_t a, uint64_t b, uint64_t n) +[[nodiscard]] consteval std::uint64_t add_mod(std::uint64_t a, std::uint64_t b, std::uint64_t n) { + MP_UNITS_EXPECTS_DEBUG(a < n); + MP_UNITS_EXPECTS_DEBUG(b < n); + MP_UNITS_EXPECTS_DEBUG(n > 0u); + if (a >= n - b) { return a - (n - b); } else { @@ -61,8 +66,12 @@ namespace mp_units::detail { // Precondition: (a < n). // Precondition: (b < n). // Precondition: (n > 0). -[[nodiscard]] consteval uint64_t sub_mod(uint64_t a, uint64_t b, uint64_t n) +[[nodiscard]] consteval std::uint64_t sub_mod(std::uint64_t a, std::uint64_t b, std::uint64_t n) { + MP_UNITS_EXPECTS_DEBUG(a < n); + MP_UNITS_EXPECTS_DEBUG(b < n); + MP_UNITS_EXPECTS_DEBUG(n > 0u); + if (a >= b) { return a - b; } else { @@ -75,14 +84,18 @@ namespace mp_units::detail { // Precondition: (a < n). // Precondition: (b < n). // Precondition: (n > 0). -[[nodiscard]] consteval uint64_t mul_mod(uint64_t a, uint64_t b, uint64_t n) +[[nodiscard]] consteval std::uint64_t mul_mod(std::uint64_t a, std::uint64_t b, std::uint64_t n) { - if (b == 0u || a < std::numeric_limits::max() / b) { + MP_UNITS_EXPECTS_DEBUG(a < n); + MP_UNITS_EXPECTS_DEBUG(b < n); + MP_UNITS_EXPECTS_DEBUG(n > 0u); + + if (b == 0u || a < std::numeric_limits::max() / b) { return (a * b) % n; } - const uint64_t batch_size = n / a; - const uint64_t num_batches = b / batch_size; + const std::uint64_t batch_size = n / a; + const std::uint64_t num_batches = b / batch_size; return add_mod( // Transform into "negative space" to make the first parameter as small as possible; @@ -99,15 +112,18 @@ namespace mp_units::detail { // // Precondition: (a < n). // Precondition: (n % 2 == 1). -[[nodiscard]] consteval uint64_t half_mod_odd(uint64_t a, uint64_t n) +[[nodiscard]] consteval std::uint64_t half_mod_odd(std::uint64_t a, std::uint64_t n) { + MP_UNITS_EXPECTS_DEBUG(a < n); + MP_UNITS_EXPECTS_DEBUG(n % 2 == 1); + return (a / 2u) + ((a % 2u == 0u) ? 0u : (n / 2u + 1u)); } // (base ^ exp) % n. -[[nodiscard]] consteval uint64_t pow_mod(uint64_t base, uint64_t exp, uint64_t n) +[[nodiscard]] consteval std::uint64_t pow_mod(std::uint64_t base, std::uint64_t exp, std::uint64_t n) { - uint64_t result = 1u; + std::uint64_t result = 1u; base %= n; while (exp > 0u) {