diff --git a/include/boost/unordered/concurrent_flat_map.hpp b/include/boost/unordered/concurrent_flat_map.hpp index df3718dd..45e65284 100644 --- a/include/boost/unordered/concurrent_flat_map.hpp +++ b/include/boost/unordered/concurrent_flat_map.hpp @@ -140,6 +140,10 @@ namespace boost { class concurrent_flat_map { private: + template + friend class concurrent_flat_map; + using type_policy = detail::concurrent_map_types; detail::foa::concurrent_table table_; @@ -703,6 +707,19 @@ namespace boost { void clear() noexcept { table_.clear(); } + template + size_type merge(concurrent_flat_map& x) + { + BOOST_ASSERT(get_allocator() == x.get_allocator()); + return table_.merge(x.table_); + } + + template + size_type merge(concurrent_flat_map&& x) + { + return merge(x); + } + /// Hash Policy /// void rehash(size_type n) { table_.rehash(n); } diff --git a/include/boost/unordered/detail/foa/concurrent_table.hpp b/include/boost/unordered/detail/foa/concurrent_table.hpp index a19d28b9..f933e774 100644 --- a/include/boost/unordered/detail/foa/concurrent_table.hpp +++ b/include/boost/unordered/detail/foa/concurrent_table.hpp @@ -362,6 +362,10 @@ class concurrent_table: using super::N; using prober=typename super::prober; + template< + typename TypePolicy2,typename Hash2,typename Pred2,typename Allocator2> + friend class concurrent_table; + public: using key_type=typename super::key_type; using init_type=typename super::init_type; @@ -683,14 +687,19 @@ public: // TODO: should we accept different allocator too? template - void merge(concurrent_table& x) + size_type merge(concurrent_table& x) { - auto lck=exclusive_access(*this,x); - x.super::for_all_elements( /* super::for_all_elements -> unprotected */ + using merge_table_type=concurrent_table; + using super2=typename merge_table_type::super; + + auto lck=exclusive_access(*this,x); + size_type s=unprotected_size(); + static_cast(x).for_all_elements( /* super::for_all_elements -> unprotected */ [&,this](group_type* pg,unsigned int n,element_type* p){ - erase_on_exit e{x,pg,n,p}; + erase_on_exit e{x,pg,n,p}; if(!unprotected_emplace(type_policy::move(*p)))e.rollback(); }); + return size_type{unprotected_size()-s}; } template @@ -799,6 +808,14 @@ private: return {x.mutexes,y.mutexes}; } + template + inline exclusive_bilock_guard exclusive_access( + const concurrent_table& x, + const concurrent_table& y) + { + return {x.mutexes,y.mutexes}; + } + /* Tag-dispatched shared/exclusive group access */ using group_shared=std::false_type; @@ -835,21 +852,29 @@ private: >::type cast_for(group_exclusive,value_type& x){return x;} + template struct erase_on_exit { + using table_group_type=typename Table::group_type; + using table_element_type=typename Table::element_type; + using table_super_type=typename Table::super; + erase_on_exit( - concurrent_table& x_, - group_type* pg_,unsigned int pos_,element_type* p_): + Table& x_,table_group_type* pg_,unsigned int pos_,table_element_type* p_): x{x_},pg{pg_},pos{pos_},p{p_}{} - ~erase_on_exit(){if(!rollback_)x.super::erase(pg,pos,p);} + ~erase_on_exit() + { + if(!rollback_) + static_cast(x).erase(pg,pos,p); + } void rollback(){rollback_=true;} - concurrent_table &x; - group_type *pg; - unsigned int pos; - element_type *p; - bool rollback_=false; + Table &x; + table_group_type *pg; + unsigned int pos; + table_element_type *p; + bool rollback_=false; }; template diff --git a/test/Jamfile.v2 b/test/Jamfile.v2 index 209e73c4..57be13da 100644 --- a/test/Jamfile.v2 +++ b/test/Jamfile.v2 @@ -186,6 +186,7 @@ local CFOA_TESTS = assign_tests clear_tests swap_tests + merge_tests ; for local test in $(CFOA_TESTS) diff --git a/test/cfoa/helpers.hpp b/test/cfoa/helpers.hpp index aa5a86a4..59a35be7 100644 --- a/test/cfoa/helpers.hpp +++ b/test/cfoa/helpers.hpp @@ -258,6 +258,16 @@ std::size_t hash_value(raii const& r) noexcept return hasher(r.x_); } +namespace std { + template <> struct hash + { + std::size_t operator()(raii const& r) const noexcept + { + return hash_value(r); + } + }; +} // namespace std + template auto make_random_values(std::size_t count, F f) -> std::vector { diff --git a/test/cfoa/merge_tests.cpp b/test/cfoa/merge_tests.cpp new file mode 100644 index 00000000..9b6bbb0e --- /dev/null +++ b/test/cfoa/merge_tests.cpp @@ -0,0 +1,193 @@ +// Copyright (C) 2023 Christian Mazakas +// Distributed under the Boost Software License, Version 1.0. (See accompanying +// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) + +#include "helpers.hpp" + +#include + +test::seed_t initialize_seed{402031699}; + +using test::default_generator; +using test::limited_range; +using test::sequential; + +using hasher = stateful_hash; +using key_equal = stateful_key_equal; +using allocator_type = stateful_allocator >; + +using map_type = boost::unordered::concurrent_flat_map; + +using map_value_type = typename map_type::value_type; + +struct +{ + template + std::size_t operator()(X1& x1, X2& x2) const noexcept + { + return x1.merge(x2); + } +} lvalue_merge; + +struct +{ + template + std::size_t operator()(X1& x1, X2& x2) const noexcept + { + return x1.merge(std::move(x2)); + } +} rvalue_merge; + +namespace { + template + void merge_tests(F merger, G gen, test::random_generator rg) + { + auto values = make_random_values(1024 * 8, [&] { return gen(rg); }); + + auto ref_map = + boost::unordered_flat_map(values.begin(), values.end()); + + { + raii::reset_counts(); + + map_type x(values.size(), hasher(1), key_equal(2), allocator_type(3)); + + auto const old_cc = +raii::copy_constructor; + + std::atomic expected_copies{0}; + std::atomic num_merged{0}; + + thread_runner(values, [&x, &expected_copies, &num_merged, merger]( + boost::span s) { + using map2_type = boost::unordered::concurrent_flat_map, std::equal_to, allocator_type>; + + map2_type y(s.begin(), s.end(), s.size(), allocator_type(3)); + expected_copies += 2 * y.size(); + + BOOST_TEST(x.get_allocator() == y.get_allocator()); + num_merged += merger(x, y); + }); + + BOOST_TEST_EQ(raii::copy_constructor, old_cc + expected_copies); + BOOST_TEST_EQ(raii::move_constructor, 2 * ref_map.size()); + BOOST_TEST_EQ(+num_merged, ref_map.size()); + + test_fuzzy_matches_reference(x, ref_map, rg); + } + check_raii_counts(); + } + + template + void insert_and_merge_tests(G gen, test::random_generator rg) + { + using map2_type = boost::unordered::concurrent_flat_map, std::equal_to, allocator_type>; + + auto vals1 = make_random_values(1024 * 8, [&] { return gen(rg); }); + auto vals2 = make_random_values(1024 * 4, [&] { return gen(rg); }); + + auto ref_map = boost::unordered_flat_map(); + ref_map.insert(vals1.begin(), vals1.end()); + ref_map.insert(vals2.begin(), vals2.end()); + + { + raii::reset_counts(); + + map_type x1(2 * vals1.size(), hasher(1), key_equal(2), allocator_type(3)); + + map2_type x2(2 * vals1.size(), allocator_type(3)); + + std::thread t1, t2, t3; + boost::latch l(2); + + std::mutex m; + std::condition_variable cv; + std::atomic_bool done1{false}, done2{false}; + std::atomic num_merges{0}; + + auto const old_mc = +raii::move_constructor; + BOOST_TEST_EQ(old_mc, 0u); + + t1 = std::thread([&x1, &vals1, &l, &done1, &cv] { + l.arrive_and_wait(); + + for (std::size_t idx = 0; idx < vals1.size(); ++idx) { + auto const& val = vals1[idx]; + x1.insert(val); + if (idx % 100 == 0) { + cv.notify_all(); + std::this_thread::yield(); + } + } + + done1 = true; + }); + + t2 = std::thread([&x2, &vals2, &l, &done2] { + l.arrive_and_wait(); + + for (std::size_t idx = 0; idx < vals2.size(); ++idx) { + auto const& val = vals2[idx]; + x2.insert(val); + if (idx % 100 == 0) { + std::this_thread::yield(); + } + } + + done2 = true; + }); + + t3 = std::thread([&x1, &x2, &m, &cv, &done1, &done2, &num_merges] { + while (x1.empty() && x2.empty()) { + } + + do { + { + std::unique_lock lk(m); + cv.wait(lk, [] { return true; }); + } + num_merges += x1.merge(x2); + std::this_thread::yield(); + num_merges += x2.merge(x1); + + } while (!done1 || !done2); + + BOOST_TEST(done1); + BOOST_TEST(done2); + }); + + t1.join(); + t2.join(); + t3.join(); + + if (num_merges > 0) { + // num merges is 0 most commonly in the cast of the limited_range + // generator as both maps will contains keys from 0 to 99 + BOOST_TEST_EQ(+raii::move_constructor, 2 * num_merges); + } + + x1.merge(x2); + test_fuzzy_matches_reference(x1, ref_map, rg); + } + + check_raii_counts(); + } + +} // namespace + +// clang-format off +UNORDERED_TEST( + merge_tests, + ((lvalue_merge)(rvalue_merge)) + ((value_type_generator)) + ((default_generator)(sequential)(limited_range))) + +UNORDERED_TEST( + insert_and_merge_tests, + ((value_type_generator)) + ((default_generator)(sequential)(limited_range))) +// clang-format on + +RUN_TESTS()