diff --git a/include/boost/type_index/runtime_cast.hpp b/include/boost/type_index/runtime_cast.hpp index e1e8ba5..333753d 100644 --- a/include/boost/type_index/runtime_cast.hpp +++ b/include/boost/type_index/runtime_cast.hpp @@ -17,7 +17,9 @@ /// or undesirable at a global level. #include +#include #include +#include #ifdef BOOST_HAS_PRAGMA_ONCE # pragma once @@ -27,80 +29,112 @@ namespace boost { namespace typeindex { namespace detail { -template -struct find_type; - -template -struct find_type { - template - void* operator()(T* p, type_index const& idx) const BOOST_NOEXCEPT { - if(idx == boost::typeindex::type_id()) - return nullptr; - return nullptr; - } -}; - template struct find_type { - template - Current* check_current(T* p, type_index const& idx) const BOOST_NOEXCEPT{ - if(idx == boost::typeindex::type_id()) - return nullptr; - return nullptr; - } + template + Current* check_current(T* p, type_index const& idx) const BOOST_NOEXCEPT { + if(idx == boost::typeindex::type_id()) + return p; + return nullptr; + } - template - void* check_bases(T* p, type_index const& idx) const BOOST_NOEXCEPT { - if(void* result = p->FirstBase::boost_type_index_find_instance_(idx)) - return result; - return check_bases(p, idx); - } + template + void* check_bases(T* p, type_index const& idx) const BOOST_NOEXCEPT { + return nullptr; + } - template - void* operator()(T* p, type_index const& idx) const BOOST_NOEXCEPT { - if(Current* current = check_current(p, idx)) - return p; - return check_bases(p, idx); - } + template + void* check_bases(T* p, type_index const& idx) const BOOST_NOEXCEPT { + if(void* result = p->FirstBase::boost_type_index_find_instance_(idx)) + return result; + return check_bases(p, idx); + } + + template + void* operator()(T* p, type_index const& idx) const BOOST_NOEXCEPT { + if(Current* current = check_current(p, idx)) + return p; + return check_bases(p, idx); + } }; template -T* runtime_cast_impl(U* u) { - return static_cast( - u->boost_type_index_find_instance_(boost::typeindex::type_id()) - ); +T* runtime_cast_impl(U* u, std::true_type) { + return u; } template -T const* runtime_cast_impl(U const* u) { - return static_cast( - const_cast(u)->boost_type_index_find_instance_(boost::typeindex::type_id()) - ); +T const* runtime_cast_impl(U const* u, std::true_type) { + return u; +} + +template +T* runtime_cast_impl(U* u, std::false_type) { + return static_cast( + u->boost_type_index_find_instance_(boost::typeindex::type_id()) + ); +} + +template +T const* runtime_cast_impl(U const* u, std::false_type) { + return static_cast( + const_cast(u)->boost_type_index_find_instance_(boost::typeindex::type_id()) + ); } } // namespace detail -#define BOOST_TYPE_INDEX_REGISTER_CLASS_RTTI \ - virtual void* boost_type_index_find_instance_(boost::typeindex::type_index const& idx) BOOST_NOEXCEPT { \ - return boost::typeindex::detail::find_type::type>()(this, idx); \ - } +#define BOOST_TYPE_INDEX_REGISTER_CLASS_RTTI \ + virtual void* boost_type_index_find_instance_(boost::typeindex::type_index const& idx) BOOST_NOEXCEPT { \ + if(idx == boost::typeindex::type_id::type>()) \ + return this; \ + return nullptr; \ + } -#define BOOST_TYPE_INDEX_REGISTER_CLASS_RTTI_BASES(...) \ - virtual void* boost_type_index_find_instance_(boost::typeindex::type_index const& idx) BOOST_NOEXCEPT { \ - if(auto ret = boost::typeindex::detail::find_type::type>()(this, idx)) \ - return ret; \ - return boost::typeindex::detail::find_type::type, __VA_ARGS__>()(this, idx);\ - } +#define BOOST_TYPE_INDEX_REGISTER_CLASS_RTTI_BASES(...) \ + virtual void* boost_type_index_find_instance_(boost::typeindex::type_index const& idx) BOOST_NOEXCEPT { \ + return boost::typeindex::detail::find_type::type, __VA_ARGS__>()(this, idx);\ + } - template - T* runtime_cast(U* u) { - return detail::runtime_cast_impl(u); - } + template + T runtime_cast(U* u) BOOST_NOEXCEPT { + typedef typename std::remove_pointer::type impl_type; + return detail::runtime_cast_impl(u, std::is_same()); + } - template - T const* runtime_cast(U const* u) { - return detail::runtime_cast_impl(u); - } + template + T runtime_cast(U const* u) BOOST_NOEXCEPT { + typedef typename std::remove_pointer::type impl_type; + return detail::runtime_cast_impl(u, std::is_same()); + } + + template + T runtime_cast(U& u) { + typedef typename std::remove_reference::type impl_type; + impl_type* value = detail::runtime_cast_impl(&u, std::is_same()); + if(!value) + boost::throw_exception(std::bad_cast()); + return *value; + } + + template + T runtime_cast(U const& u) { + typedef typename std::remove_reference::type impl_type; + impl_type* value = detail::runtime_cast_impl(&u, std::is_same()); + if(!value) + boost::throw_exception(std::bad_cast()); + return *value; + } + + template + T* runtime_pointer_cast(U* u) BOOST_NOEXCEPT { + return detail::runtime_cast_impl(u, std::is_same()); + } + + template + T const* runtime_pointer_cast(U const* u) BOOST_NOEXCEPT { + return detail::runtime_cast_impl(u, std::is_same()); + } }} // namespace boost::typeindex diff --git a/test/runtime_cast_test.cpp b/test/runtime_cast_test.cpp index 77aadb0..970566d 100644 --- a/test/runtime_cast_test.cpp +++ b/test/runtime_cast_test.cpp @@ -59,13 +59,30 @@ struct multiple_virtual_derived : baseV1, baseV2 { IMPLEMENT_CLASS(multiple_virtual_derived) }; +struct unrelated { + BOOST_TYPE_INDEX_REGISTER_CLASS_RTTI +}; + +struct unrelated_with_base : base { + BOOST_TYPE_INDEX_REGISTER_CLASS_RTTI_BASES(base) +}; + +struct unrelatedV1 : virtual base { + BOOST_TYPE_INDEX_REGISTER_CLASS_RTTI_BASES(base) +}; + void no_base() { using namespace boost::typeindex; base b; - base* b2 = runtime_cast(&b); + base* b2 = runtime_pointer_cast(&b); BOOST_TEST_NE(b2, (base*)nullptr); BOOST_TEST_EQ(b2->name, "base"); + + BOOST_TEST_EQ(runtime_pointer_cast(&b), (unrelated*)nullptr); + BOOST_TEST_EQ(runtime_pointer_cast(&b), (single_derived*)nullptr); + BOOST_TEST_EQ(runtime_pointer_cast(&b), (unrelatedV1*)nullptr); + BOOST_TEST_EQ(runtime_pointer_cast(&b), (unrelated_with_base*)nullptr); } void single_base() @@ -73,9 +90,13 @@ void single_base() using namespace boost::typeindex; single_derived d; base* b = &d; - single_derived* d2 = runtime_cast(b); + single_derived* d2 = runtime_pointer_cast(b); BOOST_TEST_NE(d2, (single_derived*)nullptr); BOOST_TEST_EQ(d2->name, "single_derived"); + + BOOST_TEST_EQ(runtime_pointer_cast(&d), (unrelated*)nullptr); + BOOST_TEST_EQ(runtime_pointer_cast(b), (unrelated*)nullptr); + BOOST_TEST_EQ(runtime_pointer_cast(b), (unrelated_with_base*)nullptr); } void multiple_base() @@ -83,13 +104,17 @@ void multiple_base() using namespace boost::typeindex; multiple_derived d; base1* b1 = &d; - multiple_derived* d2 = runtime_cast(b1); + multiple_derived* d2 = runtime_pointer_cast(b1); BOOST_TEST_NE(d2, (multiple_derived*)nullptr); BOOST_TEST_EQ(d2->name, "multiple_derived"); - base2* b2 = runtime_cast(b1); + base2* b2 = runtime_pointer_cast(b1); BOOST_TEST_NE(b2, (base2*)nullptr); BOOST_TEST_EQ(b2->name, "base2"); + + BOOST_TEST_EQ(runtime_pointer_cast(&d), (unrelated*)nullptr); + BOOST_TEST_EQ(runtime_pointer_cast(b1), (unrelated*)nullptr); + BOOST_TEST_EQ(runtime_pointer_cast(b1), (unrelated_with_base*)nullptr); } void virtual_base() @@ -97,9 +122,9 @@ void virtual_base() using namespace boost::typeindex; multiple_virtual_derived d; base* b = &d; - multiple_virtual_derived* d2 = runtime_cast(b); - baseV1* bv1 = runtime_cast(b); - baseV2* bv2 = runtime_cast(b); + multiple_virtual_derived* d2 = runtime_pointer_cast(b); + baseV1* bv1 = runtime_pointer_cast(b); + baseV2* bv2 = runtime_pointer_cast(b); BOOST_TEST_NE(d2, (multiple_virtual_derived*)nullptr); BOOST_TEST_EQ(d2->name, "multiple_virtual_derived"); @@ -109,13 +134,77 @@ void virtual_base() BOOST_TEST_NE(bv2, (baseV2*)nullptr); BOOST_TEST_EQ(bv2->name, "baseV2"); + + BOOST_TEST_EQ(runtime_pointer_cast(b), (unrelated*)nullptr); + BOOST_TEST_EQ(runtime_pointer_cast(&d), (unrelated*)nullptr); + BOOST_TEST_EQ(runtime_pointer_cast(b), (unrelated_with_base*)nullptr); +} + +void pointer_interface() +{ + using namespace boost::typeindex; + single_derived d; + base* b = &d; + single_derived* d2 = runtime_cast(b); + BOOST_TEST_NE(d2, (single_derived*)nullptr); + BOOST_TEST_EQ(d2->name, "single_derived"); + BOOST_TEST_EQ(runtime_pointer_cast(b), (unrelated*)nullptr); +} + +void reference_interface() +{ + using namespace boost::typeindex; + single_derived d; + base& b = d; + single_derived& d2 = runtime_cast(b); + BOOST_TEST_EQ(d2.name, "single_derived"); + + try { + unrelated& u = runtime_cast(b); + (void)u; + BOOST_TEST(!"should throw bad_cast"); + } + catch(...) { + } +} + +void const_pointer_interface() +{ + using namespace boost::typeindex; + const single_derived d; + base const* b = &d; + single_derived const* d2 = runtime_cast(b); + BOOST_TEST_NE(d2, (single_derived*)nullptr); + BOOST_TEST_EQ(d2->name, "single_derived"); + BOOST_TEST_EQ(runtime_pointer_cast(b), (unrelated*)nullptr); +} + +void const_reference_interface() +{ + using namespace boost::typeindex; + const single_derived d; + base const& b = d; + single_derived const& d2 = runtime_cast(b); + BOOST_TEST_EQ(d2.name, "single_derived"); + + try { + unrelated const& u = runtime_cast(b); + (void)u; + BOOST_TEST(!"should throw bad_cast"); + } + catch(...) { + } } int main() { no_base(); single_derived(); - multiple_base(); - virtual_base(); + multiple_base(); + virtual_base(); + pointer_interface(); + reference_interface(); + const_pointer_interface(); + const_reference_interface(); return boost::report_errors(); }