diff --git a/src/core/include/units/math.h b/src/core/include/units/math.h index 61160e1d..d07bcb25 100644 --- a/src/core/include/units/math.h +++ b/src/core/include/units/math.h @@ -164,8 +164,9 @@ template }) { const auto handle_signed_results = [&](const T& res) { - if (res > q) + if (res > q) { return res - T::one(); + } return res; }; if constexpr(treat_as_floating_point) { @@ -193,13 +194,9 @@ template * @tparam q Quantity being the base of the operation * @return Quantity The rounded quantity with unit type of quantity To */ -template +template D, typename U, std::same_as Rep> [[nodiscard]] constexpr quantity floor(const quantity& q) noexcept - requires std::same_as && - std::same_as && - requires { - ::units::floor(q); - } + requires requires { ::units::floor(q); } { return ::units::floor(q); } @@ -221,8 +218,9 @@ template }) { const auto handle_signed_results = [&](const T& res) { - if (res < q) + if (res < q) { return res + T::one(); + } return res; }; if constexpr(treat_as_floating_point) { @@ -250,15 +248,69 @@ template * @tparam q Quantity being the base of the operation * @return Quantity The rounded quantity with unit type of quantity To */ -template +template D, typename U, std::same_as Rep> [[nodiscard]] constexpr quantity ceil(const quantity& q) noexcept - requires std::same_as && - std::same_as && - requires { - ::units::ceil(q); - } + requires requires { ::units::ceil(q); } { return ::units::ceil(q); } +/** + * @brief Computes the nearest quantity with integer representation and unit type To to q + * + * Rounding halfway cases away from zero, regardless of the current rounding mode. + * + * @tparam q Quantity being the base of the operation + * @return Quantity The rounded quantity with unit type To + */ +template +[[nodiscard]] constexpr quantity round(const quantity& q) noexcept + requires ((!treat_as_floating_point) || + requires { round(q.number()); } || + requires { std::round(q.number()); }) && + (std::same_as || requires { + ::units::floor(q); + quantity::one(); + }) +{ + if constexpr(std::is_same_v) { + if constexpr(treat_as_floating_point) { + using std::round; + return quantity(round(q.number())); + } + else { + return q; + } + } + else { + const auto res_low = units::floor(q); + const auto res_high = res_low + decltype(res_low)::one(); + const auto diff0 = q - res_low; + const auto diff1 = res_high - q; + if (diff0 == diff1) { + if (static_cast(res_low.number()) & 1) { + return res_high; + } + return res_low; + } + if (diff0 < diff1) { + return res_low; + } + return res_high; + } +} + +/** + * @brief Overload of @c ::units::round() using the unit type of To + * + * @tparam q Quantity being the base of the operation + * @return Quantity The rounded quantity with unit type of quantity To + */ +template D, typename U, std::same_as Rep> +[[nodiscard]] constexpr quantity round(const quantity& q) noexcept + requires requires { ::units::round(q); } +{ + return ::units::round(q); +} + } // namespace units diff --git a/test/unit_test/runtime/math_test.cpp b/test/unit_test/runtime/math_test.cpp index a3281026..ee56fefc 100644 --- a/test/unit_test/runtime/math_test.cpp +++ b/test/unit_test/runtime/math_test.cpp @@ -199,6 +199,77 @@ TEST_CASE("ceil functions", "[ceil]") } } +TEST_CASE("round functions", "[round]") +{ + SECTION ("round 1 second with target unit second should be 1 second") { + REQUIRE(round(1_q_s) == 1_q_s); + } + SECTION ("round 1000 milliseconds with target unit second should be 1 second") { + REQUIRE(round(1000_q_ms) == 1_q_s); + } + SECTION ("round 1001 milliseconds with target unit second should be 1 second") { + REQUIRE(round(1001_q_ms) == 1_q_s); + } + SECTION ("round 1499 milliseconds with target unit second should be 1 second") { + REQUIRE(round(1499_q_ms) == 1_q_s); + } + SECTION ("round 1500 milliseconds with target unit second should be 2 seconds") { + REQUIRE(round(1500_q_ms) == 2_q_s); + } + SECTION ("round 1999 milliseconds with target unit second should be 2 seconds") { + REQUIRE(round(1999_q_ms) == 2_q_s); + } + SECTION ("round -1000 milliseconds with target unit second should be -1 second") { + REQUIRE(round(-1000_q_ms) == -1_q_s); + } + SECTION ("round -1001 milliseconds with target unit second should be -1 second") { + REQUIRE(round(-1001_q_ms) == -1_q_s); + } + SECTION ("round -1499 milliseconds with target unit second should be -1 second") { + REQUIRE(round(-1499_q_ms) == -1_q_s); + } + SECTION ("round -1500 milliseconds with target unit second should be -2 seconds") { + REQUIRE(round(-1500_q_ms) == -2_q_s); + } + SECTION ("round -1999 milliseconds with target unit second should be -2 seconds") { + REQUIRE(round(-1999_q_ms) == -2_q_s); + } + SECTION ("round 1000. milliseconds with target unit second should be 1 second") { + REQUIRE(round(1000._q_ms) == 1_q_s); + } + SECTION ("round 1001. milliseconds with target unit second should be 1 second") { + REQUIRE(round(1001._q_ms) == 1_q_s); + } + SECTION ("round 1499. milliseconds with target unit second should be 1 second") { + REQUIRE(round(1499._q_ms) == 1_q_s); + } + SECTION ("round 1500. milliseconds with target unit second should be 2 seconds") { + REQUIRE(round(1500._q_ms) == 2_q_s); + } + SECTION ("round 1999. milliseconds with target unit second should be 2 seconds") { + REQUIRE(round(1999._q_ms) == 2_q_s); + } + SECTION ("round -1000. milliseconds with target unit second should be -1 second") { + REQUIRE(round(-1000._q_ms) == -1_q_s); + } + SECTION ("round -1001. milliseconds with target unit second should be -1 second") { + REQUIRE(round(-1001._q_ms) == -1_q_s); + } + SECTION ("round -1499. milliseconds with target unit second should be -1 second") { + REQUIRE(round(-1499._q_ms) == -1_q_s); + } + SECTION ("round -1500. milliseconds with target unit second should be -2 seconds") { + REQUIRE(round(-1500._q_ms) == -2_q_s); + } + SECTION ("round -1999. milliseconds with target unit second should be -2 seconds") { + REQUIRE(round(-1999._q_ms) == -2_q_s); + } + SECTION ("round 1 second with target quantity with unit type second should be 1 second") { + using showtime = si::time; + REQUIRE(round(showtime::one()) == showtime::one()); + } +} + TEMPLATE_TEST_CASE_SIG("pow() implementation exponentiates values to power N", "[math][pow][exp]", (std::intmax_t N, N), 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20, 25) { diff --git a/test/unit_test/static/math_test.cpp b/test/unit_test/static/math_test.cpp index 627ba60c..f2a5b683 100644 --- a/test/unit_test/static/math_test.cpp +++ b/test/unit_test/static/math_test.cpp @@ -57,7 +57,7 @@ static_assert(compare(4_q_m2)), decltype(sqrt(2_q_m))>); static_assert(compare(4_q_km2)), decltype(sqrt(2_q_km))>); static_assert(compare(4_q_ft2)), decltype(sqrt(2_q_ft))>); -#if __cpp_lib_constexpr_cmath // TODO remove once std::floor is constexpr for all compilers +#if __cpp_lib_constexpr_cmath // TODO remove once std::floor, std::ceil, and std::round is constexpr for all compilers // floor // integral types static_assert(compare(1_q_s)), decltype(1_q_s)>); @@ -103,6 +103,41 @@ static_assert(ceil(-999._q_ms) == 0_q_s); // ceil with quantity static_assert(compare>(1_q_s)), decltype(1_q_s)>); + +// round +// integral types +static_assert(compare(1_q_s)), decltype(1_q_s)>); + +static_assert(compare(1000_q_ms)), decltype(1_q_s)>); +static_assert(compare(1001_q_ms)), decltype(1_q_s)>); +static_assert(compare(1499_q_ms)), decltype(1_q_s)>); +static_assert(compare(1500_q_ms)), decltype(2_q_s)>); +static_assert(compare(1999_q_ms)), decltype(2_q_s)>); + +static_assert(compare(-1000_q_ms)), decltype(-1_q_s)>); +static_assert(compare(-1001_q_ms)), decltype(-1_q_s)>); +static_assert(compare(-1499_q_ms)), decltype(-1_q_s)>); +static_assert(compare(-1500_q_ms)), decltype(-2_q_s)>); +static_assert(compare(-1999_q_ms)), decltype(-2_q_s)>); + +// floating-point +static_assert(round(1.3_q_s) == 1_q_s); +static_assert(round(-1.3_q_s) == -1_q_s); + +static_assert(compare(1000._q_ms)), decltype(1_q_s)>); +static_assert(compare(1001._q_ms)), decltype(1_q_s)>); +static_assert(compare(1499._q_ms)), decltype(1_q_s)>); +static_assert(compare(1500._q_ms)), decltype(2_q_s)>); +static_assert(compare(1999._q_ms)), decltype(2_q_s)>); + +static_assert(compare(-1000._q_ms)), decltype(-1_q_s)>); +static_assert(compare(-1001._q_ms)), decltype(-1_q_s)>); +static_assert(compare(-1499._q_ms)), decltype(-1_q_s)>); +static_assert(compare(-1500._q_ms)), decltype(-2_q_s)>); +static_assert(compare(-1999._q_ms)), decltype(-2_q_s)>); + +// round with quantity +static_assert(compare>(1_q_s)), decltype(1_q_s)>); #endif } // namespace