From e9c6a0fef5dab24a5690988604aaf4de77c1c416 Mon Sep 17 00:00:00 2001 From: Christian Mazakas Date: Mon, 24 Apr 2023 13:29:35 -0700 Subject: [PATCH] Add polyfill implementation of `std::latch` --- test/Jamfile.v2 | 1 + test/cfoa/helpers.hpp | 23 +++--- test/cfoa/latch.hpp | 87 +++++++++++++++++++++ test/cfoa/latch_tests.cpp | 154 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 251 insertions(+), 14 deletions(-) create mode 100644 test/cfoa/latch.hpp create mode 100644 test/cfoa/latch_tests.cpp diff --git a/test/Jamfile.v2 b/test/Jamfile.v2 index bb5ea89a..d6038bdf 100644 --- a/test/Jamfile.v2 +++ b/test/Jamfile.v2 @@ -176,6 +176,7 @@ alias foa_tests : ; local CFOA_TESTS = + latch_tests insert_tests erase_tests try_emplace_tests diff --git a/test/cfoa/helpers.hpp b/test/cfoa/helpers.hpp index b5fa1046..7ff48334 100644 --- a/test/cfoa/helpers.hpp +++ b/test/cfoa/helpers.hpp @@ -1,6 +1,12 @@ +// Copyright (C) 2023 Christian Mazakas +// Distributed under the Boost Software License, Version 1.0. (See accompanying +// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) + #ifndef BOOST_UNORDERED_TEST_CFOA_HELPERS_HPP #define BOOST_UNORDERED_TEST_CFOA_HELPERS_HPP +#include "latch.hpp" + #include "../helpers/generators.hpp" #include "../helpers/test.hpp" @@ -320,25 +326,14 @@ std::vector > split( template void thread_runner(std::vector& values, F f) { - std::mutex m; - std::condition_variable cv; - std::size_t c = 0; + boost::latch latch(num_threads); std::vector threads; auto subslices = split(values, num_threads); for (std::size_t i = 0; i < num_threads; ++i) { - threads.emplace_back([&f, &subslices, i, &m, &cv, &c] { - { - std::unique_lock lk(m); - ++c; - if (c == num_threads) { - lk.unlock(); - cv.notify_all(); - } else { - cv.wait(lk, [&] { return c == num_threads; }); - } - } + threads.emplace_back([&f, &subslices, i, &latch] { + latch.arrive_and_wait(); auto s = subslices[i]; f(s); diff --git a/test/cfoa/latch.hpp b/test/cfoa/latch.hpp new file mode 100644 index 00000000..bee42119 --- /dev/null +++ b/test/cfoa/latch.hpp @@ -0,0 +1,87 @@ +// Copyright (C) 2023 Christian Mazakas +// Distributed under the Boost Software License, Version 1.0. (See accompanying +// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) + +#ifndef BOOST_UNORDERED_TEST_CFOA_LATCH_HPP +#define BOOST_UNORDERED_TEST_CFOA_LATCH_HPP + +#include + +#include +#include +#include +#include + +namespace boost { + class latch + { + private: + std::ptrdiff_t n_; + mutable std::mutex m_; + mutable std::condition_variable cv_; + + public: + explicit latch(std::ptrdiff_t expected) : n_{expected}, m_{}, cv_{} + { + BOOST_ASSERT(n_ >= 0); + BOOST_ASSERT(n_ <= max()); + } + + latch(latch const&) = delete; + latch& operator=(latch const&) = delete; + + ~latch() = default; + + void count_down(std::ptrdiff_t n = 1) + { + std::unique_lock lk(m_); + count_down_and_notify(lk, n); + } + + bool try_wait() const noexcept + { + std::unique_lock lk(m_); + return is_ready(); + } + + void wait() const + { + std::unique_lock lk(m_); + wait_impl(lk); + } + + void arrive_and_wait(std::ptrdiff_t n = 1) + { + std::unique_lock lk(m_); + bool should_wait = count_down_and_notify(lk, n); + if (should_wait) { + wait_impl(lk); + } + } + + static constexpr std::ptrdiff_t max() noexcept { return INT_MAX; } + + private: + bool is_ready() const { return n_ == 0; } + + bool count_down_and_notify( + std::unique_lock& lk, std::ptrdiff_t n) + { + n_ -= n; + if (n_ == 0) { + lk.unlock(); + cv_.notify_all(); + return false; + } + + return true; + } + + void wait_impl(std::unique_lock& lk) const + { + cv_.wait(lk, [this] { return this->is_ready(); }); + } + }; +} // namespace boost + +#endif // BOOST_UNORDERED_TEST_CFOA_LATCH_HPP diff --git a/test/cfoa/latch_tests.cpp b/test/cfoa/latch_tests.cpp new file mode 100644 index 00000000..8bad302a --- /dev/null +++ b/test/cfoa/latch_tests.cpp @@ -0,0 +1,154 @@ +// Copyright (C) 2023 Christian Mazakas +// Distributed under the Boost Software License, Version 1.0. (See accompanying +// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) + +#define BOOST_ENABLE_ASSERT_HANDLER + +#include "latch.hpp" + +#include + +#include +#include + +struct exception +{ +}; + +namespace boost { + void assertion_failed( + char const* expr, char const* function, char const* file, long line) + { + (void)expr; + (void)function; + (void)file; + (void)line; + throw exception{}; + } +} // namespace boost + +namespace { + void test_max() { BOOST_TEST_EQ(boost::latch::max(), INT_MAX); } + + void test_constructor() + { + { + auto const f = [] { + boost::latch l(-1); + (void)l; + }; + BOOST_TEST_THROWS(f(), exception); + } + + { + std::ptrdiff_t n = 0; + + boost::latch l(n); + BOOST_TEST(l.try_wait()); + } + + { + std::ptrdiff_t n = 16; + + boost::latch l(n); + BOOST_TEST_NOT(l.try_wait()); + + l.count_down(16); + BOOST_TEST(l.try_wait()); + } + +#if PTRDIFF_MAX > INT_MAX + { + auto const f = [] { + std::ptrdiff_t n = INT_MAX; + n += 10; + boost::latch l(n); + (void)l; + }; + BOOST_TEST_THROWS(f(), exception); + } +#endif + } + + void test_count_down_and_wait() + { + constexpr std::ptrdiff_t n = 1024; + + boost::latch l(2 * n); + + bool bs[] = {false, false}; + + std::thread t1([&] { + l.wait(); + BOOST_TEST(bs[0]); + BOOST_TEST(bs[1]); + }); + + std::thread t2([&] { + for (int i = 0; i < n; ++i) { + if (i == (n - 1)) { + bs[0] = true; + } else { + BOOST_TEST_NOT(l.try_wait()); + } + + l.count_down(1); + } + }); + + for (int i = 0; i < n; ++i) { + if (i == (n - 1)) { + bs[1] = true; + } else { + BOOST_TEST_NOT(l.try_wait()); + } + + l.count_down(1); + } + + t1.join(); + t2.join(); + + BOOST_TEST(l.try_wait()); + } + + void test_arrive_and_wait() + { + constexpr std::ptrdiff_t n = 16; + + boost::latch l(2 * n); + + int xs[n] = {0}; + + std::vector threads; + for (int i = 0; i < n; ++i) { + threads.emplace_back([&l, &xs, i] { + for (int j = 0; j < n; ++j) { + BOOST_TEST_EQ(xs[j], 0); + } + + l.arrive_and_wait(2); + + xs[i] = 1; + }); + } + + for (auto& t : threads) { + t.join(); + } + + for (int i = 0; i < n; ++i) { + BOOST_TEST_EQ(xs[i], 1); + } + } +} // namespace + +int main() +{ + test_max(); + test_constructor(); + test_count_down_and_wait(); + test_arrive_and_wait(); + + return boost::report_errors(); +}