From 7fd6913b732d6226ff50f86a7a9269213465b611 Mon Sep 17 00:00:00 2001 From: Chip Hogg Date: Thu, 7 Jul 2022 19:19:08 +0000 Subject: [PATCH] Replace `exp` with `0` everywhere, and remove it This lets us remove a ton of special-casing throughout the codebase, and just generally makes things a lot simpler. We also remove the ability to take rational powers of `ratio`, including `sqrt` and `cbrt` helpers, because these are intrinsically ill-defined. Fixes #369. --- example/custom_systems.cpp | 5 +- src/core/include/units/bits/ratio_maths.h | 112 ++------------------- src/core/include/units/chrono.h | 7 +- src/core/include/units/magnitude.h | 8 +- src/core/include/units/ratio.h | 117 ++++------------------ test/unit_test/runtime/magnitude_test.cpp | 7 -- test/unit_test/static/ratio_test.cpp | 42 -------- 7 files changed, 28 insertions(+), 270 deletions(-) diff --git a/example/custom_systems.cpp b/example/custom_systems.cpp index 4604be9e..90a63dae 100644 --- a/example/custom_systems.cpp +++ b/example/custom_systems.cpp @@ -94,10 +94,7 @@ void unknown_dimensions() std::cout << si_fps_area << "\n"; } -std::ostream& operator<<(std::ostream& os, const ratio& r) -{ - return os << "ratio{" << r.num << ", " << r.den << ", " << r.exp << "}"; -} +std::ostream& operator<<(std::ostream& os, const ratio& r) { return os << "ratio{" << r.num << ", " << r.den << "}"; } template std::ostream& operator<<(std::ostream& os, const U& u) diff --git a/src/core/include/units/bits/ratio_maths.h b/src/core/include/units/bits/ratio_maths.h index d9a6494a..a6cbea4c 100644 --- a/src/core/include/units/bits/ratio_maths.h +++ b/src/core/include/units/bits/ratio_maths.h @@ -39,108 +39,20 @@ template return v < 0 ? -v : v; } -// the following functions enable gcd and related computations on ratios -// with exponents. They avoid overflow. Further information here: -// https://github.com/mpusz/units/issues/62#issuecomment-588152833 - -// Computes (a * b) mod m relies on unsigned integer arithmetic, should not -// overflow -[[nodiscard]] constexpr std::uint64_t mulmod(std::uint64_t a, std::uint64_t b, std::uint64_t m) -{ - std::uint64_t res = 0; - - if (b >= m) { - if (m > UINT64_MAX / 2u) { - b -= m; - } else { - b %= m; - } - } - - while (a != 0) { - if (a & 1) { - if (b >= m - res) { - res -= m; - } - res += b; - } - a >>= 1; - - std::uint64_t temp_b = b; - if (b >= m - b) { - temp_b -= m; - } - b += temp_b; - } - - return res; -} - -// Calculates (a ^ e) mod m , should not overflow. -[[nodiscard]] constexpr std::uint64_t modpow(std::uint64_t a, std::uint64_t e, std::uint64_t m) -{ - a %= m; - std::uint64_t result = 1; - - while (e > 0) { - if (e & 1) { - result = mulmod(result, a, m); - } - a = mulmod(a, a, m); - e >>= 1; - } - return result; -} - -// gcd(a * 10 ^ e, b), should not overflow -[[nodiscard]] constexpr std::intmax_t gcdpow(std::intmax_t a, std::intmax_t e, std::intmax_t b) noexcept -{ - assert(a > 0); - assert(e >= 0); - assert(b > 0); - - // gcd(i, j) = gcd(j, i mod j) for j != 0 Euclid; - // - // gcd(a 10^e, b) = gcd(b, a 10^e mod b) - // - // (a 10^e) mod b -> [ (a mod b) (10^e mod b) ] mod b - - return std::gcd( - b, static_cast(mulmod(static_cast(a % b), - modpow(10, static_cast(e), static_cast(b)), - static_cast(b)))); -} - -constexpr void cwap(std::intmax_t& lhs, std::intmax_t& rhs) -{ - std::intmax_t tmp = lhs; - lhs = rhs; - rhs = tmp; -} - -// Computes the rational gcd of n1/d1 x 10^e1 and n2/d2 x 10^e2 -[[nodiscard]] constexpr auto gcd_frac(std::intmax_t n1, std::intmax_t d1, std::intmax_t e1, std::intmax_t n2, - std::intmax_t d2, std::intmax_t e2) noexcept +// Computes the rational gcd of n1/d1 and n2/d2 +[[nodiscard]] constexpr auto gcd_frac(std::intmax_t n1, std::intmax_t d1, std::intmax_t n2, std::intmax_t d2) noexcept { // Short cut for equal ratios - if (n1 == n2 && d1 == d2 && e1 == e2) { - return std::array{n1, d1, e1}; + if (n1 == n2 && d1 == d2) { + return std::array{n1, d1}; } - if (e2 > e1) { - detail::cwap(n1, n2); - detail::cwap(d1, d2); - detail::cwap(e1, e2); - } - - std::intmax_t exp = e2; // minimum - // gcd(a/b,c/d) = gcd(a⋅d, c⋅b) / b⋅d assert(std::numeric_limits::max() / n1 > d2); assert(std::numeric_limits::max() / n2 > d1); - std::intmax_t num = detail::gcdpow(n1 * d2, e1 - e2, n2 * d1); + std::intmax_t num = std::gcd(n1 * d2, n2 * d1); assert(std::numeric_limits::max() / d1 > d2); @@ -148,29 +60,19 @@ constexpr void cwap(std::intmax_t& lhs, std::intmax_t& rhs) std::intmax_t gcd = std::gcd(num, den); - return std::array{num / gcd, den / gcd, exp}; + return std::array{num / gcd, den / gcd}; } -constexpr void normalize(std::intmax_t& num, std::intmax_t& den, std::intmax_t& exp) +constexpr void normalize(std::intmax_t& num, std::intmax_t& den) { if (num == 0) { den = 1; - exp = 0; return; } std::intmax_t gcd = std::gcd(num, den); num = num * (den < 0 ? -1 : 1) / gcd; den = detail::abs(den) / gcd; - - while (num % 10 == 0) { - num /= 10; - ++exp; - } - while (den % 10 == 0) { - den /= 10; - --exp; - } } [[nodiscard]] constexpr std::intmax_t safe_multiply(std::intmax_t lhs, std::intmax_t rhs) diff --git a/src/core/include/units/chrono.h b/src/core/include/units/chrono.h index 91b4ab6b..d7c9fc91 100644 --- a/src/core/include/units/chrono.h +++ b/src/core/include/units/chrono.h @@ -75,12 +75,7 @@ constexpr std::intmax_t pow_10(std::intmax_t v) template constexpr auto to_std_ratio_impl() { - if constexpr (R.exp == 0) - return std::ratio{}; - else if constexpr (R.exp > 0) - return std::ratio{}; - else - return std::ratio{}; + return std::ratio{}; } } // namespace detail diff --git a/src/core/include/units/magnitude.h b/src/core/include/units/magnitude.h index 1257e132..47bb1aee 100644 --- a/src/core/include/units/magnitude.h +++ b/src/core/include/units/magnitude.h @@ -189,9 +189,6 @@ constexpr widen_t compute_base_power(BasePower auto bp) if (bp.power.den != 1) { throw std::invalid_argument{"Rational powers not yet supported"}; } - if (bp.power.exp < 0) { - throw std::invalid_argument{"Unsupported exp value"}; - } if (bp.power.num < 0) { if constexpr (std::is_integral_v) { @@ -344,7 +341,7 @@ inline constexpr bool is_base_power_pack_valid = all_base_powers_valid & constexpr bool is_rational(BasePower auto bp) { - return std::is_integral_v && (bp.power.den == 1) && (bp.power.exp >= 0); + return std::is_integral_v && (bp.power.den == 1); } constexpr bool is_integral(BasePower auto bp) { return is_rational(bp) && bp.power.num > 0; } @@ -652,8 +649,7 @@ template requires(R.num > 0) constexpr Magnitude auto as_magnitude() { - return pow(detail::prime_factorization_v<10>) * detail::prime_factorization_v / - detail::prime_factorization_v; + return detail::prime_factorization_v / detail::prime_factorization_v; } namespace detail { diff --git a/src/core/include/units/ratio.h b/src/core/include/units/ratio.h index b57ce06d..a469d4c2 100644 --- a/src/core/include/units/ratio.h +++ b/src/core/include/units/ratio.h @@ -43,42 +43,28 @@ constexpr ratio inverse(const ratio& r); /** * @brief Provides compile-time rational arithmetic support. * - * This class is really similar to @c std::ratio but gets an additional `Exp` - * template parameter that defines the exponent of the ratio. Another important - * difference is the fact that the objects of that class are used as class NTTPs - * rather then a type template parameter kind. + * This class is really similar to @c std::ratio. An important difference is the fact that the objects of that class + * are used as class NTTPs rather then a type template parameter kind. */ struct ratio { std::intmax_t num; std::intmax_t den; - std::intmax_t exp; - constexpr explicit(false) ratio(std::intmax_t n, std::intmax_t d = 1, std::intmax_t e = 0) : num(n), den(d), exp(e) + constexpr explicit(false) ratio(std::intmax_t n, std::intmax_t d = 1) : num(n), den(d) { gsl_Expects(den != 0); - detail::normalize(num, den, exp); + detail::normalize(num, den); } [[nodiscard]] friend constexpr bool operator==(const ratio&, const ratio&) = default; [[nodiscard]] friend constexpr auto operator<=>(const ratio& lhs, const ratio& rhs) { return (lhs - rhs).num <=> 0; } - [[nodiscard]] friend constexpr ratio operator-(const ratio& r) { return ratio(-r.num, r.den, r.exp); } + [[nodiscard]] friend constexpr ratio operator-(const ratio& r) { return ratio(-r.num, r.den); } [[nodiscard]] friend constexpr ratio operator+(ratio lhs, ratio rhs) { - // First, get the inputs into a common exponent. - const auto common_exp = std::min(lhs.exp, rhs.exp); - auto commonify = [common_exp](ratio& r) { - while (r.exp > common_exp) { - r.num *= 10; - --r.exp; - } - }; - commonify(lhs); - commonify(rhs); - - return ratio{lhs.num * rhs.den + lhs.den * rhs.num, lhs.den * rhs.den, common_exp}; + return ratio{lhs.num * rhs.den + lhs.den * rhs.num, lhs.den * rhs.den}; } [[nodiscard]] friend constexpr ratio operator-(const ratio& lhs, const ratio& rhs) { return lhs + (-rhs); } @@ -88,96 +74,31 @@ struct ratio { const std::intmax_t gcd1 = std::gcd(lhs.num, rhs.den); const std::intmax_t gcd2 = std::gcd(rhs.num, lhs.den); return ratio(detail::safe_multiply(lhs.num / gcd1, rhs.num / gcd2), - detail::safe_multiply(lhs.den / gcd2, rhs.den / gcd1), lhs.exp + rhs.exp); + detail::safe_multiply(lhs.den / gcd2, rhs.den / gcd1)); } [[nodiscard]] friend constexpr ratio operator/(const ratio& lhs, const ratio& rhs) { return lhs * inverse(rhs); } - [[nodiscard]] friend constexpr std::intmax_t numerator(const ratio& r) - { - std::intmax_t true_num = r.num; - for (auto i = r.exp; i > 0; --i) { - true_num *= 10; - } - return true_num; - } + [[nodiscard]] friend constexpr std::intmax_t numerator(const ratio& r) { return r.num; } - [[nodiscard]] friend constexpr std::intmax_t denominator(const ratio& r) - { - std::intmax_t true_den = r.den; - for (auto i = r.exp; i < 0; ++i) { - true_den *= 10; - } - return true_den; - } + [[nodiscard]] friend constexpr std::intmax_t denominator(const ratio& r) { return r.den; } }; -[[nodiscard]] constexpr ratio inverse(const ratio& r) { return ratio(r.den, r.num, -r.exp); } +[[nodiscard]] constexpr ratio inverse(const ratio& r) { return ratio(r.den, r.num); } -[[nodiscard]] constexpr bool is_integral(const ratio& r) -{ - if (r.exp < 0) { - return false; - } else { - return detail::gcdpow(r.num, r.exp, r.den) == r.den; - } -} +[[nodiscard]] constexpr bool is_integral(const ratio& r) { return r.num % r.den == 0; } -namespace detail { - -[[nodiscard]] constexpr auto make_exp_align(const ratio& r, std::intmax_t alignment) -{ - gsl_Expects(alignment > 0); - const std::intmax_t rem = r.exp % alignment; - - if (rem == 0) { // already aligned - return std::array{r.num, r.den, r.exp}; - } - - if (r.exp > 0) { // remainder is positive - return std::array{r.num * ipow10(rem), r.den, r.exp - rem}; - } - - // remainder is negative - return std::array{r.num, r.den * ipow10(-rem), r.exp - rem}; -} - -template - requires gt_zero -[[nodiscard]] constexpr ratio root(const ratio& r) -{ - if constexpr (N == 1) { - return r; - } else { - if (r.num == 0) { - return ratio(0); - } - - const auto aligned = make_exp_align(r, N); - return ratio(iroot(aligned[0]), iroot(aligned[1]), aligned[2] / N); - } -} - -} // namespace detail - -template - requires detail::non_zero +template [[nodiscard]] constexpr ratio pow(const ratio& r) { if constexpr (Num == 0) { return ratio(1); - } else if constexpr (Num == Den) { + } else if constexpr (Num == 1) { return r; } else { - // simplify factors first and compute power for positive exponent - constexpr std::intmax_t gcd = std::gcd(Num, Den); - constexpr std::intmax_t num = detail::abs(Num / gcd); - constexpr std::intmax_t den = detail::abs(Den / gcd); + const ratio result = detail::pow_impl(r); - // integer root loses precision so do pow first - const ratio result = detail::root(detail::pow_impl(r)); - - if constexpr (Num * Den < 0) { // account for negative exponent + if constexpr (Num < 0) { // account for negative exponent return inverse(result); } else { return result; @@ -185,15 +106,11 @@ template } } -[[nodiscard]] constexpr ratio sqrt(const ratio& r) { return pow<1, 2>(r); } - -[[nodiscard]] constexpr ratio cbrt(const ratio& r) { return pow<1, 3>(r); } - // common_ratio [[nodiscard]] constexpr ratio common_ratio(const ratio& r1, const ratio& r2) { - const auto res = detail::gcd_frac(r1.num, r1.den, r1.exp, r2.num, r2.den, r2.exp); - return ratio(res[0], res[1], res[2]); + const auto res = detail::gcd_frac(r1.num, r1.den, r2.num, r2.den); + return ratio(res[0], res[1]); } } // namespace units diff --git a/test/unit_test/runtime/magnitude_test.cpp b/test/unit_test/runtime/magnitude_test.cpp index 9fa0aa98..8b80a60d 100644 --- a/test/unit_test/runtime/magnitude_test.cpp +++ b/test/unit_test/runtime/magnitude_test.cpp @@ -154,13 +154,6 @@ TEST_CASE("make_ratio performs prime factorization correctly") SECTION("Supports fractions") { CHECK(as_magnitude() == magnitude{}); } - SECTION("Supports nonzero exp") - { - constexpr ratio r{3, 1, 2}; - REQUIRE(r.exp == 2); - CHECK(as_magnitude() == as_magnitude<300>()); - } - SECTION("Can handle prime factor which would be large enough to overflow int") { // This was taken from a case which failed when we used `int` for our base to store prime numbers. diff --git a/test/unit_test/static/ratio_test.cpp b/test/unit_test/static/ratio_test.cpp index ad966950..87d36443 100644 --- a/test/unit_test/static/ratio_test.cpp +++ b/test/unit_test/static/ratio_test.cpp @@ -28,11 +28,6 @@ using namespace units; static_assert(ratio(2, 4) == ratio(1, 2)); -// basic exponents tests -static_assert(ratio(2, 40, 1) == ratio(1, 20, 1)); -static_assert(ratio(20, 4, -1) == ratio(10, 2, -1)); -static_assert(ratio(200, 5) == ratio(20'000, 50, -1)); - static_assert(ratio(1) * ratio(3, 8) == ratio(3, 8)); static_assert(ratio(3, 8) * ratio(1) == ratio(3, 8)); static_assert(ratio(4) * ratio(1, 8) == ratio(1, 2)); @@ -45,21 +40,12 @@ static_assert(-ratio(3, 8) == ratio(-3, 8)); // ratio addition static_assert(ratio(1, 2) + ratio(1, 3) == ratio(5, 6)); -static_assert(ratio(1, 3, 2) + ratio(11, 6) == ratio(211, 6)); // 100/3 + 11/6 - -// multiply with exponents -static_assert(ratio(1, 8, 2) * ratio(2, 1, 4) == ratio(1, 4, 6)); -static_assert(ratio(1, 2, -4) * ratio(8, 1, 3) == ratio(4, 1, -1)); static_assert(ratio(4) / ratio(2) == ratio(2)); static_assert(ratio(2) / ratio(8) == ratio(1, 4)); static_assert(ratio(1, 8) / ratio(2) == ratio(1, 16)); static_assert(ratio(6) / ratio(3) == ratio(2)); -// divide with exponents -static_assert(ratio(1, 8, -6) / ratio(2, 1, -8) == ratio(1, 16, 2)); -static_assert(ratio(6, 1, 4) / ratio(3) == ratio(2, 1, 4)); - static_assert(pow<0>(ratio(2)) == ratio(1)); static_assert(pow<1>(ratio(2)) == ratio(2)); static_assert(pow<2>(ratio(2)) == ratio(4)); @@ -69,27 +55,6 @@ static_assert(pow<1>(ratio(1, 2)) == ratio(1, 2)); static_assert(pow<2>(ratio(1, 2)) == ratio(1, 4)); static_assert(pow<3>(ratio(1, 2)) == ratio(1, 8)); -// pow with exponents -static_assert(pow<2>(ratio(1, 2, 3)) == ratio(1, 4, 6)); -static_assert(pow<4, 2>(ratio(1, 2, 3)) == ratio(1, 4, 6)); -static_assert(pow<3>(ratio(1, 2, -6)) == ratio(1, 8, -18)); - -static_assert(sqrt(ratio(9)) == ratio(3)); -static_assert(cbrt(ratio(27)) == ratio(3)); -static_assert(sqrt(ratio(4)) == ratio(2)); -static_assert(cbrt(ratio(8)) == ratio(2)); -static_assert(sqrt(ratio(1)) == ratio(1)); -static_assert(cbrt(ratio(1)) == ratio(1)); -static_assert(sqrt(ratio(0)) == ratio(0)); -static_assert(cbrt(ratio(0)) == ratio(0)); -static_assert(sqrt(ratio(1, 4)) == ratio(1, 2)); -static_assert(cbrt(ratio(1, 8)) == ratio(1, 2)); - -// sqrt with exponents -static_assert(sqrt(ratio(9, 1, 2)) == ratio(3, 1, 1)); -static_assert(cbrt(ratio(27, 1, 3)) == ratio(3, 1, 1)); -static_assert(cbrt(ratio(27, 1, 2)) == ratio(13, 1, 0)); - // common_ratio static_assert(common_ratio(ratio(1), ratio(1000)) == ratio(1)); static_assert(common_ratio(ratio(1000), ratio(1)) == ratio(1)); @@ -98,20 +63,13 @@ static_assert(common_ratio(ratio(1, 1000), ratio(1)) == ratio(1, 1000)); static_assert(common_ratio(ratio(100, 1), ratio(10, 1)) == ratio(10, 1)); static_assert(common_ratio(ratio(100, 1), ratio(1, 10)) == ratio(1, 10)); -// common ratio with exponents -static_assert(common_ratio(ratio(1), ratio(1, 1, 3)) == ratio(1)); -static_assert(common_ratio(ratio(10, 1, -1), ratio(1, 1, -3)) == ratio(1, 1, -3)); - // numerator and denominator static_assert(numerator(ratio(3, 4)) == 3); -static_assert(numerator(ratio(3, 7, 2)) == 300); static_assert(denominator(ratio(3, 4)) == 4); -static_assert(denominator(ratio(3, 7, -2)) == 700); // comparison static_assert((ratio(3, 4) <=> ratio(6, 8)) == (0 <=> 0)); static_assert((ratio(3, 4) <=> ratio(-3, 4)) == (0 <=> -1)); static_assert((ratio(-3, 4) <=> ratio(3, -4)) == (0 <=> 0)); -static_assert((ratio(1, 1, 1) <=> ratio(10)) == (0 <=> 0)); } // namespace