diff --git a/src/core/include/mp-units/math.h b/src/core/include/mp-units/math.h index fb458f9e..352a577f 100644 --- a/src/core/include/mp-units/math.h +++ b/src/core/include/mp-units/math.h @@ -333,32 +333,17 @@ template */ template [[nodiscard]] constexpr quantity(R), Rep> floor(const quantity& q) noexcept - requires((!treat_as_floating_point) || requires(Rep v) { floor(v); } || requires(Rep v) { std::floor(v); }) && - (equivalent(To, get_unit(R)) || requires { - q.force_in(To); - representation_values::one(); - }) + requires requires { q.force_in(To); } && + (treat_as_floating_point && (requires(Rep v) { floor(v); } || requires(Rep v) { std::floor(v); })) || + (!treat_as_floating_point && requires { representation_values::one(); }) { - const auto handle_signed_results = [&](const T& res) { - if (res > q) { - return res - representation_values::one() * T::reference; - } - return res; - }; + const quantity res = q.force_in(To); if constexpr (treat_as_floating_point) { using std::floor; - if constexpr (equivalent(To, get_unit(R))) { - return {static_cast(floor(q.numerical_value_ref_in(q.unit))), detail::clone_reference_with(R)}; - } else { - return handle_signed_results( - quantity{static_cast(floor(q.force_numerical_value_in(To))), detail::clone_reference_with(R)}); - } + return {static_cast(floor(res.numerical_value_ref_in(res.unit))), res.reference}; } else { - if constexpr (equivalent(To, get_unit(R))) { - return q.force_in(To); - } else { - return handle_signed_results(q.force_in(To)); - } + if (res > q) return res - representation_values::one() * res.reference; + return res; } } @@ -370,74 +355,48 @@ template */ template [[nodiscard]] constexpr quantity(R), Rep> ceil(const quantity& q) noexcept - requires((!treat_as_floating_point) || requires(Rep v) { ceil(v); } || requires(Rep v) { std::ceil(v); }) && - (equivalent(To, get_unit(R)) || requires { - q.force_in(To); - representation_values::one(); - }) + requires requires { q.force_in(To); } && + (treat_as_floating_point && (requires(Rep v) { ceil(v); } || requires(Rep v) { std::ceil(v); })) || + (!treat_as_floating_point && requires { representation_values::one(); }) { - const auto handle_signed_results = [&](const T& res) { - if (res < q) { - return res + representation_values::one() * T::reference; - } - return res; - }; + const quantity res = q.force_in(To); if constexpr (treat_as_floating_point) { using std::ceil; - if constexpr (equivalent(To, get_unit(R))) { - return {static_cast(ceil(q.numerical_value_ref_in(q.unit))), detail::clone_reference_with(R)}; - } else { - return handle_signed_results( - quantity{static_cast(ceil(q.force_numerical_value_in(To))), detail::clone_reference_with(R)}); - } + return {static_cast(ceil(res.numerical_value_ref_in(res.unit))), res.reference}; } else { - if constexpr (equivalent(To, get_unit(R))) { - return q.force_in(To); - } else { - return handle_signed_results(q.force_in(To)); - } + if (res < q) return res + representation_values::one() * res.reference; + return res; } } /** - * @brief Computes the nearest quantity with integer representation and unit type To to q + * @brief Computes the nearest quantity with integer representation and unit type `To` to `q` * - * Rounding halfway cases away from zero, regardless of the current rounding mode. + * Returns the value `res` representable in `To` unit that is the closest to `q`. If there are two + * such values, returns the even value (that is, the value `res` such that `res % 2 == 0`). * * @tparam q Quantity being the base of the operation - * @return Quantity The rounded quantity with unit type To + * @return Quantity The quantity rounded to the nearest unit `To`, rounding to even in halfway + * cases. */ template [[nodiscard]] constexpr quantity(R), Rep> round(const quantity& q) noexcept - requires((!treat_as_floating_point) || requires(Rep v) { round(v); } || requires(Rep v) { std::round(v); }) && - (equivalent(To, get_unit(R)) || requires { - ::mp_units::floor(q); - representation_values::one(); - }) + requires requires { + mp_units::floor(q); + representation_values::one(); + } && std::constructible_from { - if constexpr (equivalent(To, get_unit(R))) { - if constexpr (treat_as_floating_point) { - using std::round; - return {static_cast(round(q.numerical_value_ref_in(q.unit))), detail::clone_reference_with(R)}; - } else { - return q.force_in(To); - } - } else { - const auto res_low = mp_units::floor(q); - const auto res_high = res_low + representation_values::one() * res_low.reference; - const auto diff0 = q - res_low; - const auto diff1 = res_high - q; - if (diff0 == diff1) { - if (static_cast(res_low.numerical_value_ref_in(To)) & 1) { - return res_high; - } - return res_low; - } - if (diff0 < diff1) { - return res_low; - } - return res_high; - } + const auto res_low = mp_units::floor(q); + const auto res_high = res_low + representation_values::one() * res_low.reference; + const auto diff0 = q - res_low; + const auto diff1 = res_high - q; + if (diff0 == diff1) { + // TODO How to extend this to custom representation types? + if (static_cast(res_low.numerical_value_ref_in(To)) & 1) return res_high; + return res_low; + } else if (diff0 < diff1) + return res_low; + return res_high; } /** diff --git a/test/static/math_test.cpp b/test/static/math_test.cpp index 48682ee8..a033514a 100644 --- a/test/static/math_test.cpp +++ b/test/static/math_test.cpp @@ -189,11 +189,15 @@ static_assert(compare(round(1001 * ms), 1 * s)); static_assert(compare(round(1499 * ms), 1 * s)); static_assert(compare(round(1500 * ms), 2 * s)); static_assert(compare(round(1999 * ms), 2 * s)); +static_assert(compare(round(2500 * ms), 2 * s)); +static_assert(compare(round(3500 * ms), 4 * s)); static_assert(compare(round(-1000 * ms), -1 * s)); static_assert(compare(round(-1001 * ms), -1 * s)); static_assert(compare(round(-1499 * ms), -1 * s)); static_assert(compare(round(-1500 * ms), -2 * s)); static_assert(compare(round(-1999 * ms), -2 * s)); +static_assert(compare(round(-2500 * ms), -2 * s)); +static_assert(compare(round(-3500 * ms), -4 * s)); static_assert(compare(round(1 * isq::time[s]), 1 * isq::time[s])); static_assert(compare(round(1000 * isq::time[ms]), 1 * isq::time[s])); @@ -201,38 +205,62 @@ static_assert(compare(round(1001 * isq::time[ms]), 1 * isq::time[s]) static_assert(compare(round(1499 * isq::time[ms]), 1 * isq::time[s])); static_assert(compare(round(1500 * isq::time[ms]), 2 * isq::time[s])); static_assert(compare(round(1999 * isq::time[ms]), 2 * isq::time[s])); +static_assert(compare(round(2500 * isq::time[ms]), 2 * isq::time[s])); +static_assert(compare(round(3500 * isq::time[ms]), 4 * isq::time[s])); static_assert(compare(round(-1000 * isq::time[ms]), -1 * isq::time[s])); static_assert(compare(round(-1001 * isq::time[ms]), -1 * isq::time[s])); static_assert(compare(round(-1499 * isq::time[ms]), -1 * isq::time[s])); static_assert(compare(round(-1500 * isq::time[ms]), -2 * isq::time[s])); static_assert(compare(round(-1999 * isq::time[ms]), -2 * isq::time[s])); +static_assert(compare(round(-2500 * isq::time[ms]), -2 * isq::time[s])); +static_assert(compare(round(-3500 * isq::time[ms]), -4 * isq::time[s])); // floating-point static_assert(compare(round(1.3 * s), 1. * s)); +static_assert(compare(round(1.5 * s), 2. * s)); +static_assert(compare(round(2.5 * s), 2. * s)); +static_assert(compare(round(3.5 * s), 4. * s)); static_assert(compare(round(-1.3 * s), -1. * s)); +static_assert(compare(round(-1.5 * s), -2. * s)); +static_assert(compare(round(-2.5 * s), -2. * s)); +static_assert(compare(round(-3.5 * s), -4. * s)); static_assert(compare(round(1000. * ms), 1. * s)); static_assert(compare(round(1001. * ms), 1. * s)); static_assert(compare(round(1499. * ms), 1. * s)); static_assert(compare(round(1500. * ms), 2. * s)); static_assert(compare(round(1999. * ms), 2. * s)); +static_assert(compare(round(2500. * ms), 2. * s)); +static_assert(compare(round(3500. * ms), 4. * s)); static_assert(compare(round(-1000. * ms), -1. * s)); static_assert(compare(round(-1001. * ms), -1. * s)); static_assert(compare(round(-1499. * ms), -1. * s)); static_assert(compare(round(-1500. * ms), -2. * s)); static_assert(compare(round(-1999. * ms), -2. * s)); +static_assert(compare(round(-2500. * ms), -2. * s)); +static_assert(compare(round(-3500. * ms), -4. * s)); static_assert(compare(round(1.3 * isq::time[s]), 1. * isq::time[s])); +static_assert(compare(round(1.5 * isq::time[s]), 2. * isq::time[s])); +static_assert(compare(round(2.5 * isq::time[s]), 2. * isq::time[s])); +static_assert(compare(round(3.5 * isq::time[s]), 4. * isq::time[s])); static_assert(compare(round(-1.3 * isq::time[s]), -1. * isq::time[s])); +static_assert(compare(round(-1.5 * isq::time[s]), -2. * isq::time[s])); +static_assert(compare(round(-2.5 * isq::time[s]), -2. * isq::time[s])); +static_assert(compare(round(-3.5 * isq::time[s]), -4. * isq::time[s])); static_assert(compare(round(1000. * isq::time[ms]), 1. * isq::time[s])); static_assert(compare(round(1001. * isq::time[ms]), 1. * isq::time[s])); static_assert(compare(round(1499. * isq::time[ms]), 1. * isq::time[s])); static_assert(compare(round(1500. * isq::time[ms]), 2. * isq::time[s])); static_assert(compare(round(1999. * isq::time[ms]), 2. * isq::time[s])); +static_assert(compare(round(2500. * isq::time[ms]), 2. * isq::time[s])); +static_assert(compare(round(3500. * isq::time[ms]), 4. * isq::time[s])); static_assert(compare(round(-1000. * isq::time[ms]), -1. * isq::time[s])); static_assert(compare(round(-1001. * isq::time[ms]), -1. * isq::time[s])); static_assert(compare(round(-1499. * isq::time[ms]), -1. * isq::time[s])); static_assert(compare(round(-1500. * isq::time[ms]), -2. * isq::time[s])); static_assert(compare(round(-1999. * isq::time[ms]), -2. * isq::time[s])); +static_assert(compare(round(-2500. * isq::time[ms]), -2. * isq::time[s])); +static_assert(compare(round(-3500. * isq::time[ms]), -4. * isq::time[s])); #endif