diff --git a/src/core/include/mp-units/framework/magnitude.h b/src/core/include/mp-units/framework/magnitude.h index 8fb22eb8..b94404db 100644 --- a/src/core/include/mp-units/framework/magnitude.h +++ b/src/core/include/mp-units/framework/magnitude.h @@ -307,6 +307,119 @@ template return checked_square(int_power(base, exp / 2)); } +template +[[nodiscard]] consteval std::optional checked_int_pow(T base, std::uintmax_t exp) { + T result = T{1}; + while (exp > 0u) { + if (exp % 2u == 1u) { + if (base > std::numeric_limits::max() / result) { + return std::nullopt; + } + result *= base; + } + + exp /= 2u; + + if (base > std::numeric_limits::max() / base) { + return (exp == 0u) + ? std::make_optional(result) + : std::nullopt; + } + base *= base; + } + return result; +} + +template +[[nodiscard]] consteval std::optional root(T x, std::uintmax_t n) { + // The "zeroth root" would be mathematically undefined. + if (n == 0) { + return std::nullopt; + } + + // The "first root" is trivial. + if (n == 1) { + return x; + } + + // We only support nontrivial roots of floating point types. + if (!std::is_floating_point::value) { + return std::nullopt; + } + + // Handle negative numbers: only odd roots are allowed. + if (x < 0) { + if (n % 2 == 0) { + return std::nullopt; + } else { + const auto negative_result = root(-x, n); + if (!negative_result.has_value()) { + return std::nullopt; + } + return static_cast(-negative_result.value()); + } + } + + // Handle special cases of zero and one. + if (x == 0 || x == 1) { + return x; + } + + // Handle numbers bewtween 0 and 1. + if (x < 1) { + const auto inverse_result = root(T{1} / x, n); + if (!inverse_result.has_value()) { + return std::nullopt; + } + return static_cast(T{1} / inverse_result.value()); + } + + // + // At this point, error conditions are finished, and we can proceed with the "core" algorithm. + // + + // Always use `long double` for intermediate computations. We don't ever expect people to be + // calling this at runtime, so we want maximum accuracy. + long double lo = 1.0; + long double hi = static_cast(x); + + // Do a binary search to find the closest value such that `checked_int_pow` recovers the input. + // + // Because we know `n > 1`, and `x > 1`, and x^n is monotonically increasing, we know that + // `checked_int_pow(lo, n) < x < checked_int_pow(hi, n)`. We will preserve this as an + // invariant. + while (lo < hi) { + long double mid = lo + (hi - lo) / 2; + + auto result = checked_int_pow(mid, n); + + if (!result.has_value()) { + return std::nullopt; + } + + // Early return if we get lucky with an exact answer. + if (result.value() == x) { + return static_cast(mid); + } + + // Check for stagnation. + if (mid == lo || mid == hi) { + break; + } + + // Preserve the invariant that `checked_int_pow(lo, n) < x < checked_int_pow(hi, n)`. + if (result.value() < x) { + lo = mid; + } else { + hi = mid; + } + } + + // Pick whichever one gets closer to the target. + const auto lo_diff = x - checked_int_pow(lo, n).value(); + const auto hi_diff = checked_int_pow(hi, n).value() - x; + return static_cast(lo_diff < hi_diff ? lo : hi); +} template [[nodiscard]] consteval widen_t compute_base_power(MagnitudeSpec auto el) @@ -317,9 +430,6 @@ template // Note that since this function should only be called at compile time, the point of these // terminations is to act as "static_assert substitutes", not to actually terminate at runtime. const auto exp = get_exponent(el); - if (exp.den != 1) { - std::abort(); // Rational powers not yet supported - } if (exp.num < 0) { if constexpr (std::is_integral_v) { @@ -329,8 +439,19 @@ template } } - auto power = exp.num; - return int_power(static_cast>(get_base_value(el)), power); + const auto pow_result = checked_int_pow(static_cast>(get_base_value(el)), static_cast(exp.num)); + if (pow_result.has_value()) { + const auto final_result = (exp.den > 1) ? root(pow_result.value(), static_cast(exp.den)) : pow_result; + if (final_result.has_value()) { + return final_result.value(); + } + else { + std::abort(); // Root computation failed. + } + } + else { + std::abort(); // Power computation failed. + } } // A converter for the value member variable of magnitude (below). diff --git a/test/static/quantity_test.cpp b/test/static/quantity_test.cpp index 9b3abbb0..744735df 100644 --- a/test/static/quantity_test.cpp +++ b/test/static/quantity_test.cpp @@ -24,6 +24,7 @@ #include #include #include +#include #include #include #include @@ -52,6 +53,24 @@ using namespace mp_units::si::unit_symbols; // quantity class invariants ////////////////////////////// +template +constexpr bool within_4_ulps(T a, T b) { + static_assert(std::is_floating_point_v); + auto walk_ulps = [](T x, int n) { + while (n > 0) { + x = std::nextafter(x, std::numeric_limits::infinity()); + --n; + } + while (n < 0) { + x = std::nextafter(x, -std::numeric_limits::infinity()); + ++n; + } + return x; + }; + + return (walk_ulps(a, -4) <= b) && (b <= walk_ulps(a, 4)); +} + static_assert(sizeof(quantity) == sizeof(double)); static_assert(sizeof(quantity) == sizeof(double)); static_assert(sizeof(quantity) == sizeof(short)); @@ -199,6 +218,16 @@ static_assert(std::convertible_to, quantity, quantity>); static_assert(std::convertible_to, quantity>); +// conversion requiring radical magnitudes +static_assert(within_4_ulps(sqrt((1.0 * m) * (1.0 * km)).numerical_value_in(m), sqrt(1000.0))); + +// Reproducing issue #494 exactly: +constexpr auto val_issue_494 = 8.0 * si::si2019::boltzmann_constant * 1000.0 * K / (std::numbers::pi * 10 * Da); +static_assert( + within_4_ulps( + sqrt(val_issue_494).numerical_value_in(m / s), + sqrt(val_issue_494.numerical_value_in(m * m / s / s)))); + /////////////////////// // obtaining a number