From 107c6a1a97eebd7ac37e1ba5bcf935a541a2c849 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanislav=20Angelovi=C4=8D?= Date: Wed, 2 Oct 2024 21:22:02 +0200 Subject: [PATCH] refactor: invert dependency between Message and Connection (#457) This reorganizes the layers of abstraction in the sense how `Message` depends on `Connection` and vice versa. Now, `Message` has a link to the `Connection`. This replaces the shortcut link to the low-level `SdBus` interface that the `Message` kept. The interactions from `Message` now go through `Connection` which forwards them to `SdBus`. `Connection` is now a sole owner of the low-level `SdBus` interface. This allows for future changes around `SdBus` (e.g. a change from virtual functions back to non-virtual functions) without affecting the rest of the library. `Proxy`s and `Object`s can now send messages directly without having to go through `Connection`. The `Connection` no more depends on `Message` business-logic methods; it serves only as a factory for messages. The flow for creating messages: `Proxy`/`Object` -> `Connection` -> `SdBus` The flow for sending messages: (`Proxy`/`Object` ->) `Message` -> `Connection` -> `SdBus` This also better reflects how dependencies are managed in the underlying sd-bus library. Additionally, `getSdBusInterface()` methods are removed which was anyway planned, and improves the design by "Tell, don't ask" principle. This refactoring is the necessary enabler for other upcoming improvements (regarding sending long messages, or creds refactoring, for example). --- include/sdbus-c++/IObject.h | 4 +- include/sdbus-c++/IProxy.h | 4 +- include/sdbus-c++/Message.h | 12 +-- src/Connection.cpp | 143 +++++++++++++++++++++--------- src/Connection.h | 20 +++-- src/IConnection.h | 30 ++++--- src/ISdBus.h | 1 + src/Message.cpp | 139 +++++++++++++---------------- src/MessageUtils.h | 8 +- src/Object.cpp | 10 +-- src/Object.h | 4 +- src/Proxy.cpp | 23 ++--- src/Proxy.h | 4 +- src/SdBus.cpp | 7 ++ src/SdBus.h | 1 + tests/unittests/mocks/SdBusMock.h | 1 + 16 files changed, 234 insertions(+), 177 deletions(-) diff --git a/include/sdbus-c++/IObject.h b/include/sdbus-c++/IObject.h index d79e116..21953bc 100644 --- a/include/sdbus-c++/IObject.h +++ b/include/sdbus-c++/IObject.h @@ -403,7 +403,7 @@ namespace sdbus { * * @throws sdbus::Error in case of failure */ - [[nodiscard]] virtual Signal createSignal(const InterfaceName& interfaceName, const SignalName& signalName) = 0; + [[nodiscard]] virtual Signal createSignal(const InterfaceName& interfaceName, const SignalName& signalName) const = 0; /*! * @brief Emits signal for this object path @@ -419,7 +419,7 @@ namespace sdbus { protected: // Internal API for efficiency reasons used by high-level API helper classes friend SignalEmitter; - [[nodiscard]] virtual Signal createSignal(const char* interfaceName, const char* signalName) = 0; + [[nodiscard]] virtual Signal createSignal(const char* interfaceName, const char* signalName) const = 0; }; // Out-of-line member definitions diff --git a/include/sdbus-c++/IProxy.h b/include/sdbus-c++/IProxy.h index 029127f..9d2be23 100644 --- a/include/sdbus-c++/IProxy.h +++ b/include/sdbus-c++/IProxy.h @@ -364,7 +364,7 @@ namespace sdbus { * * @throws sdbus::Error in case of failure */ - [[nodiscard]] virtual MethodCall createMethodCall(const InterfaceName& interfaceName, const MethodName& methodName) = 0; + [[nodiscard]] virtual MethodCall createMethodCall(const InterfaceName& interfaceName, const MethodName& methodName) const = 0; /*! * @brief Calls method on the remote D-Bus object @@ -652,7 +652,7 @@ namespace sdbus { friend AsyncMethodInvoker; friend SignalSubscriber; - [[nodiscard]] virtual MethodCall createMethodCall(const char* interfaceName, const char* methodName) = 0; + [[nodiscard]] virtual MethodCall createMethodCall(const char* interfaceName, const char* methodName) const = 0; virtual void registerSignalHandler( const char* interfaceName , const char* signalName , signal_handler signalHandler ) = 0; diff --git a/include/sdbus-c++/Message.h b/include/sdbus-c++/Message.h index 4cf5446..40814c6 100644 --- a/include/sdbus-c++/Message.h +++ b/include/sdbus-c++/Message.h @@ -58,7 +58,7 @@ namespace sdbus { class UnixFd; class MethodReply; namespace internal { - class ISdBus; + class IConnection; } } @@ -245,15 +245,15 @@ namespace sdbus { protected: Message() = default; - explicit Message(internal::ISdBus* sdbus) noexcept; - Message(void *msg, internal::ISdBus* sdbus) noexcept; - Message(void *msg, internal::ISdBus* sdbus, adopt_message_t) noexcept; + explicit Message(internal::IConnection* connection) noexcept; + Message(void *msg, internal::IConnection* connection) noexcept; + Message(void *msg, internal::IConnection* connection, adopt_message_t) noexcept; friend Factory; protected: void* msg_{}; - internal::ISdBus* sdbus_{}; + internal::IConnection* connection_{}; mutable bool ok_{true}; }; @@ -275,7 +275,7 @@ namespace sdbus { bool doesntExpectReply() const; protected: - MethodCall(void *msg, internal::ISdBus* sdbus, adopt_message_t) noexcept; + MethodCall(void *msg, internal::IConnection* connection, adopt_message_t) noexcept; private: MethodReply sendWithReply(uint64_t timeout = 0) const; diff --git a/src/Connection.cpp b/src/Connection.cpp index ff1d325..8a5f0cf 100644 --- a/src/Connection.cpp +++ b/src/Connection.cpp @@ -179,16 +179,6 @@ Connection::PollData Connection::getEventLoopPollData() const return {pollData.fd, pollData.events, timeout, eventFd_.fd}; } -const ISdBus& Connection::getSdBusInterface() const -{ - return *sdbus_.get(); -} - -ISdBus& Connection::getSdBusInterface() -{ - return *sdbus_.get(); -} - void Connection::addObjectManager(const ObjectPath& objectPath) { auto r = sdbus_->sd_bus_add_object_manager(bus_.get(), nullptr, objectPath.c_str()); @@ -488,7 +478,7 @@ PlainMessage Connection::createPlainMessage() const SDBUS_THROW_ERROR_IF(r < 0, "Failed to create a plain message", -r); - return Message::Factory::create(sdbusMsg, sdbus_.get(), adopt_message); + return Message::Factory::create(sdbusMsg, const_cast(this), adopt_message); } MethodCall Connection::createMethodCall( const ServiceName& destination @@ -515,7 +505,7 @@ MethodCall Connection::createMethodCall( const char* destination SDBUS_THROW_ERROR_IF(r < 0, "Failed to create method call", -r); - return Message::Factory::create(sdbusMsg, sdbus_.get(), adopt_message); + return Message::Factory::create(sdbusMsg, const_cast(this), adopt_message); } Signal Connection::createSignal( const ObjectPath& objectPath @@ -535,34 +525,7 @@ Signal Connection::createSignal( const char* objectPath SDBUS_THROW_ERROR_IF(r < 0, "Failed to create signal", -r); - return Message::Factory::create(sdbusMsg, sdbus_.get(), adopt_message); -} - -MethodReply Connection::callMethod(const MethodCall& message, uint64_t timeout) -{ - // If the call expects reply, this call will block the bus connection from - // serving other messages until the reply arrives or the call times out. - auto reply = message.send(timeout); - - // Wake up event loop to process messages that may have arrived in the meantime... - wakeUpEventLoopIfMessagesInQueue(); - - return reply; -} - -Slot Connection::callMethod(const MethodCall& message, void* callback, void* userData, uint64_t timeout, return_slot_t) -{ - // TODO: Think of ways of optimizing these three locking/unlocking of sdbus mutex (merge into one call?) - auto timeoutBefore = getEventLoopPollData().timeout; - auto slot = message.send(callback, userData, timeout, return_slot); - auto timeoutAfter = getEventLoopPollData().timeout; - - // An event loop may wait in poll with timeout `t1', while in another thread an async call is made with - // timeout `t2'. If `t2' < `t1', then we have to wake up the event loop thread to update its poll timeout. - if (timeoutAfter < timeoutBefore) - notifyEventLoopToWakeUpFromPoll(); - - return slot; + return Message::Factory::create(sdbusMsg, const_cast(this), adopt_message); } void Connection::emitPropertiesChangedSignal( const ObjectPath& objectPath @@ -648,6 +611,100 @@ Slot Connection::registerSignalHandler( const char* sender return {slot, [this](void *slot){ sdbus_->sd_bus_slot_unref((sd_bus_slot*)slot); }}; } +sd_bus_message* Connection::incrementMessageRefCount(sd_bus_message* sdbusMsg) +{ + return sdbus_->sd_bus_message_ref(sdbusMsg); +} + +sd_bus_message* Connection::decrementMessageRefCount(sd_bus_message* sdbusMsg) +{ + return sdbus_->sd_bus_message_unref(sdbusMsg); +} + +int Connection::querySenderCredentials(sd_bus_message* sdbusMsg, uint64_t mask, sd_bus_creds **creds) +{ + return sdbus_->sd_bus_query_sender_creds(sdbusMsg, mask, creds); +} + +sd_bus_creds* Connection::incrementCredsRefCount(sd_bus_creds* creds) +{ + return sdbus_->sd_bus_creds_ref(creds); +} + +sd_bus_creds* Connection::decrementCredsRefCount(sd_bus_creds* creds) +{ + return sdbus_->sd_bus_creds_unref(creds); +} + +sd_bus_message* Connection::callMethod(sd_bus_message* sdbusMsg, uint64_t timeout) +{ + sd_bus_error sdbusError = SD_BUS_ERROR_NULL; + SCOPE_EXIT{ sd_bus_error_free(&sdbusError); }; + + // This call will block the bus connection from serving other messages + // until the reply arrives or the call times out. + sd_bus_message* sdbusReply{}; + auto r = sdbus_->sd_bus_call(nullptr, sdbusMsg, timeout, &sdbusError, &sdbusReply); + + if (sd_bus_error_is_set(&sdbusError)) + throw Error(Error::Name{sdbusError.name}, sdbusError.message); + + SDBUS_THROW_ERROR_IF(r < 0, "Failed to call method", -r); + + // Wake up event loop to process messages that may have arrived in the meantime... + wakeUpEventLoopIfMessagesInQueue(); + + return sdbusReply; +} + +Slot Connection::callMethodAsync(sd_bus_message* sdbusMsg, sd_bus_message_handler_t callback, void* userData, uint64_t timeout, return_slot_t) +{ + sd_bus_slot *slot{}; + + // TODO: Think of ways of optimizing these three locking/unlocking of sdbus mutex (merge into one call?) + auto timeoutBefore = getEventLoopPollData().timeout; + auto r = sdbus_->sd_bus_call_async(nullptr, &slot, sdbusMsg, (sd_bus_message_handler_t)callback, userData, timeout); + SDBUS_THROW_ERROR_IF(r < 0, "Failed to call method asynchronously", -r); + auto timeoutAfter = getEventLoopPollData().timeout; + + // An event loop may wait in poll with timeout `t1', while in another thread an async call is made with + // timeout `t2'. If `t2' < `t1', then we have to wake up the event loop thread to update its poll timeout. + if (timeoutAfter < timeoutBefore) + notifyEventLoopToWakeUpFromPoll(); + + return {slot, [this](void *slot){ sdbus_->sd_bus_slot_unref((sd_bus_slot*)slot); }}; +} + +void Connection::sendMessage(sd_bus_message* sdbusMsg) +{ + auto r = sdbus_->sd_bus_send(nullptr, sdbusMsg, nullptr); + + SDBUS_THROW_ERROR_IF(r < 0, "Failed to send D-Bus message", -r); +} + +sd_bus_message* Connection::createMethodReply(sd_bus_message* sdbusMsg) +{ + sd_bus_message* sdbusReply{}; + + auto r = sdbus_->sd_bus_message_new_method_return(sdbusMsg, &sdbusReply); + SDBUS_THROW_ERROR_IF(r < 0, "Failed to create method reply", -r); + + return sdbusReply; +} + +sd_bus_message* Connection::createErrorReplyMessage(sd_bus_message* sdbusMsg, const Error& error) +{ + sd_bus_error sdbusError = SD_BUS_ERROR_NULL; + SCOPE_EXIT{ sd_bus_error_free(&sdbusError); }; + sd_bus_error_set(&sdbusError, error.getName().c_str(), error.getMessage().c_str()); + + sd_bus_message* sdbusErrorReply{}; + auto r = sdbus_->sd_bus_message_new_method_error(sdbusMsg, &sdbusErrorReply, &sdbusError); + SDBUS_THROW_ERROR_IF(r < 0, "Failed to create method error reply", -r); + + return sdbusErrorReply; +} + Connection::BusPtr Connection::openBus(const BusFactory& busFactory) { sd_bus* bus{}; @@ -785,7 +842,7 @@ Message Connection::getCurrentlyProcessedMessage() const { auto* sdbusMsg = sdbus_->sd_bus_get_current_message(bus_.get()); - return Message::Factory::create(sdbusMsg, sdbus_.get()); + return Message::Factory::create(sdbusMsg, const_cast(this)); } template @@ -804,7 +861,7 @@ int Connection::sdbus_match_callback(sd_bus_message *sdbusMessage, void *userDat assert(matchInfo != nullptr); assert(matchInfo->callback); - auto message = Message::Factory::create(sdbusMessage, &matchInfo->connection.getSdBusInterface()); + auto message = Message::Factory::create(sdbusMessage, &matchInfo->connection); auto ok = invokeHandlerAndCatchErrors([&](){ matchInfo->callback(std::move(message)); }, retError); @@ -817,7 +874,7 @@ int Connection::sdbus_match_install_callback(sd_bus_message *sdbusMessage, void assert(matchInfo != nullptr); assert(matchInfo->installCallback); - auto message = Message::Factory::create(sdbusMessage, &matchInfo->connection.getSdBusInterface()); + auto message = Message::Factory::create(sdbusMessage, &matchInfo->connection); auto ok = invokeHandlerAndCatchErrors([&](){ matchInfo->installCallback(std::move(message)); }, retError); diff --git a/src/Connection.h b/src/Connection.h index 0bd8b99..b68986f 100644 --- a/src/Connection.h +++ b/src/Connection.h @@ -120,9 +120,6 @@ namespace sdbus::internal { void detachSdEventLoop() override; sd_event *getSdEventLoop() override; - [[nodiscard]] const ISdBus& getSdBusInterface() const override; - [[nodiscard]] ISdBus& getSdBusInterface() override; - Slot addObjectVTable( const ObjectPath& objectPath , const InterfaceName& interfaceName , const sd_bus_vtable* vtable @@ -145,9 +142,6 @@ namespace sdbus::internal { , const char* interfaceName , const char* signalName ) const override; - MethodReply callMethod(const MethodCall& message, uint64_t timeout) override; - Slot callMethod(const MethodCall& message, void* callback, void* userData, uint64_t timeout, return_slot_t) override; - void emitPropertiesChangedSignal( const ObjectPath& objectPath , const InterfaceName& interfaceName , const std::vector& propNames ) override; @@ -169,6 +163,20 @@ namespace sdbus::internal { , void* userData , return_slot_t ) override; + sd_bus_message* incrementMessageRefCount(sd_bus_message* sdbusMsg) override; + sd_bus_message* decrementMessageRefCount(sd_bus_message* sdbusMsg) override; + + int querySenderCredentials(sd_bus_message* sdbusMsg, uint64_t mask, sd_bus_creds **creds) override; + sd_bus_creds* incrementCredsRefCount(sd_bus_creds* creds) override; + sd_bus_creds* decrementCredsRefCount(sd_bus_creds* creds) override; + + sd_bus_message* callMethod(sd_bus_message* sdbusMsg, uint64_t timeout) override; + Slot callMethodAsync(sd_bus_message* sdbusMsg, sd_bus_message_handler_t callback, void* userData, uint64_t timeout, return_slot_t) override; + void sendMessage(sd_bus_message* sdbusMsg) override; + + sd_bus_message* createMethodReply(sd_bus_message* sdbusMsg) override; + sd_bus_message* createErrorReplyMessage(sd_bus_message* sdbusMsg, const Error& error) override; + private: using BusFactory = std::function; using BusPtr = std::unique_ptr>; diff --git a/src/IConnection.h b/src/IConnection.h index 412b17c..7e104a8 100644 --- a/src/IConnection.h +++ b/src/IConnection.h @@ -51,6 +51,7 @@ namespace sdbus { using MethodName = MemberName; using SignalName = MemberName; using PropertyName = MemberName; + class Error; namespace internal { class ISdBus; } @@ -64,9 +65,6 @@ namespace sdbus::internal { public: ~IConnection() override = default; - [[nodiscard]] virtual const ISdBus& getSdBusInterface() const = 0; - [[nodiscard]] virtual ISdBus& getSdBusInterface() = 0; - [[nodiscard]] virtual Slot addObjectVTable( const ObjectPath& objectPath , const InterfaceName& interfaceName , const sd_bus_vtable* vtable @@ -89,13 +87,6 @@ namespace sdbus::internal { , const char* interfaceName , const char* signalName ) const = 0; - virtual MethodReply callMethod(const MethodCall& message, uint64_t timeout) = 0; - [[nodiscard]] virtual Slot callMethod( const MethodCall& message - , void* callback - , void* userData - , uint64_t timeout - , return_slot_t ) = 0; - virtual void emitPropertiesChangedSignal( const ObjectPath& objectPath , const InterfaceName& interfaceName , const std::vector& propNames ) = 0; @@ -116,6 +107,25 @@ namespace sdbus::internal { , sd_bus_message_handler_t callback , void* userData , return_slot_t ) = 0; + + virtual sd_bus_message* incrementMessageRefCount(sd_bus_message* sdbusMsg) = 0; + virtual sd_bus_message* decrementMessageRefCount(sd_bus_message* sdbusMsg) = 0; + + // TODO: Refactor to higher level (Creds class will ownership handling and getters) + virtual int querySenderCredentials(sd_bus_message* sdbusMsg, uint64_t mask, sd_bus_creds **creds) = 0; + virtual sd_bus_creds* incrementCredsRefCount(sd_bus_creds* creds) = 0; + virtual sd_bus_creds* decrementCredsRefCount(sd_bus_creds* creds) = 0; + + virtual sd_bus_message* callMethod(sd_bus_message* sdbusMsg, uint64_t timeout) = 0; + [[nodiscard]] virtual Slot callMethodAsync( sd_bus_message* sdbusMsg + , sd_bus_message_handler_t callback + , void* userData + , uint64_t timeout + , return_slot_t ) = 0; + virtual void sendMessage(sd_bus_message* sdbusMsg) = 0; + + virtual sd_bus_message* createMethodReply(sd_bus_message* sdbusMsg) = 0; + virtual sd_bus_message* createErrorReplyMessage(sd_bus_message* sdbusMsg, const Error& error) = 0; }; [[nodiscard]] std::unique_ptr createPseudoConnection(); diff --git a/src/ISdBus.h b/src/ISdBus.h index 95eac4c..61d0ed9 100644 --- a/src/ISdBus.h +++ b/src/ISdBus.h @@ -98,6 +98,7 @@ namespace sdbus::internal { virtual int sd_bus_message_set_destination(sd_bus_message *m, const char *destination) = 0; virtual int sd_bus_query_sender_creds(sd_bus_message *m, uint64_t mask, sd_bus_creds **c) = 0; + virtual sd_bus_creds* sd_bus_creds_ref(sd_bus_creds *c) = 0; virtual sd_bus_creds* sd_bus_creds_unref(sd_bus_creds *c) = 0; virtual int sd_bus_creds_get_pid(sd_bus_creds *c, pid_t *pid) = 0; diff --git a/src/Message.cpp b/src/Message.cpp index 77eb0f4..63ae3fa 100644 --- a/src/Message.cpp +++ b/src/Message.cpp @@ -30,7 +30,6 @@ #include "sdbus-c++/Types.h" #include "IConnection.h" -#include "ISdBus.h" #include "MessageUtils.h" #include "ScopeGuard.h" @@ -40,27 +39,27 @@ namespace sdbus { -Message::Message(internal::ISdBus* sdbus) noexcept - : sdbus_(sdbus) +Message::Message(internal::IConnection* connection) noexcept + : connection_(connection) { - assert(sdbus_ != nullptr); + assert(connection_ != nullptr); } -Message::Message(void *msg, internal::ISdBus* sdbus) noexcept +Message::Message(void *msg, internal::IConnection* connection) noexcept : msg_(msg) - , sdbus_(sdbus) + , connection_(connection) { assert(msg_ != nullptr); - assert(sdbus_ != nullptr); - sdbus_->sd_bus_message_ref((sd_bus_message*)msg_); + assert(connection_ != nullptr); + connection_->incrementMessageRefCount((sd_bus_message*)msg_); } -Message::Message(void *msg, internal::ISdBus* sdbus, adopt_message_t) noexcept +Message::Message(void *msg, internal::IConnection* connection, adopt_message_t) noexcept : msg_(msg) - , sdbus_(sdbus) + , connection_(connection) { assert(msg_ != nullptr); - assert(sdbus_ != nullptr); + assert(connection_ != nullptr); } Message::Message(const Message& other) noexcept @@ -71,13 +70,13 @@ Message::Message(const Message& other) noexcept Message& Message::operator=(const Message& other) noexcept { if (msg_) - sdbus_->sd_bus_message_unref((sd_bus_message*)msg_); + connection_->decrementMessageRefCount((sd_bus_message*)msg_); msg_ = other.msg_; - sdbus_ = other.sdbus_; + connection_ = other.connection_; ok_ = other.ok_; - sdbus_->sd_bus_message_ref((sd_bus_message*)msg_); + connection_->incrementMessageRefCount((sd_bus_message*)msg_); return *this; } @@ -90,12 +89,12 @@ Message::Message(Message&& other) noexcept Message& Message::operator=(Message&& other) noexcept { if (msg_) - sdbus_->sd_bus_message_unref((sd_bus_message*)msg_); + connection_->decrementMessageRefCount((sd_bus_message*)msg_); msg_ = other.msg_; other.msg_ = nullptr; - sdbus_ = other.sdbus_; - other.sdbus_ = nullptr; + connection_ = other.connection_; + other.connection_ = nullptr; ok_ = other.ok_; other.ok_ = true; @@ -105,13 +104,17 @@ Message& Message::operator=(Message&& other) noexcept Message::~Message() { if (msg_) - sdbus_->sd_bus_message_unref((sd_bus_message*)msg_); + connection_->decrementMessageRefCount((sd_bus_message*)msg_); } Message& Message::operator<<(bool item) { int intItem = item; + // Direct sd-bus method, bypassing SdBus mutex, are called here, since Message serialization/deserialization, + // as well as getter/setter methods are not thread safe by design. Additionally, they are called frequently, + // so some overhead is spared. What is thread-safe in Message class is Message constructors, copy/move operations + // and the destructor, so the Message instance can be passed from one thread to another safely. auto r = sd_bus_message_append_basic((sd_bus_message*)msg_, SD_BUS_TYPE_BOOLEAN, &intItem); SDBUS_THROW_ERROR_IF(r < 0, "Failed to serialize a bool value", -r); @@ -648,7 +651,7 @@ std::pair Message::peekType() const bool Message::isValid() const { - return msg_ != nullptr && sdbus_ != nullptr; + return msg_ != nullptr && connection_ != nullptr; } bool Message::isEmpty() const @@ -661,17 +664,20 @@ bool Message::isAtEnd(bool complete) const return sd_bus_message_at_end((sd_bus_message*)msg_, complete) > 0; } +// TODO: Create a RAII ownership class for creds with copy&move semantics, doing ref()/unref() under the hood. +// Create a method Message::querySenderCreds() that will return an object of this class by value, through IConnection and SdBus mutex. +// The class will expose methods like getPid(), getUid(), etc. that will directly call sd_bus_creds_* functions, no need for mutex here. pid_t Message::getCredsPid() const { uint64_t mask = SD_BUS_CREDS_PID | SD_BUS_CREDS_AUGMENT; sd_bus_creds *creds = nullptr; - SCOPE_EXIT{ sdbus_->sd_bus_creds_unref(creds); }; + SCOPE_EXIT{ connection_->decrementCredsRefCount(creds); }; - int r = sdbus_->sd_bus_query_sender_creds((sd_bus_message*)msg_, mask, &creds); + int r = connection_->querySenderCredentials((sd_bus_message*)msg_, mask, &creds); SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus creds", -r); pid_t pid = 0; - r = sdbus_->sd_bus_creds_get_pid(creds, &pid); + r = sd_bus_creds_get_pid(creds, &pid); SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus cred pid", -r); return pid; } @@ -680,12 +686,12 @@ uid_t Message::getCredsUid() const { uint64_t mask = SD_BUS_CREDS_UID | SD_BUS_CREDS_AUGMENT; sd_bus_creds *creds = nullptr; - SCOPE_EXIT{ sdbus_->sd_bus_creds_unref(creds); }; - int r = sdbus_->sd_bus_query_sender_creds((sd_bus_message*)msg_, mask, &creds); + SCOPE_EXIT{ connection_->decrementCredsRefCount(creds); }; + int r = connection_->querySenderCredentials((sd_bus_message*)msg_, mask, &creds); SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus creds", -r); uid_t uid = (uid_t)-1; - r = sdbus_->sd_bus_creds_get_uid(creds, &uid); + r = sd_bus_creds_get_uid(creds, &uid); SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus cred uid", -r); return uid; } @@ -694,12 +700,12 @@ uid_t Message::getCredsEuid() const { uint64_t mask = SD_BUS_CREDS_EUID | SD_BUS_CREDS_AUGMENT; sd_bus_creds *creds = nullptr; - SCOPE_EXIT{ sdbus_->sd_bus_creds_unref(creds); }; - int r = sdbus_->sd_bus_query_sender_creds((sd_bus_message*)msg_, mask, &creds); + SCOPE_EXIT{ connection_->decrementCredsRefCount(creds); }; + int r = connection_->querySenderCredentials((sd_bus_message*)msg_, mask, &creds); SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus creds", -r); uid_t euid = (uid_t)-1; - r = sdbus_->sd_bus_creds_get_euid(creds, &euid); + r = sd_bus_creds_get_euid(creds, &euid); SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus cred euid", -r); return euid; } @@ -708,12 +714,12 @@ gid_t Message::getCredsGid() const { uint64_t mask = SD_BUS_CREDS_GID | SD_BUS_CREDS_AUGMENT; sd_bus_creds *creds = nullptr; - SCOPE_EXIT{ sdbus_->sd_bus_creds_unref(creds); }; - int r = sdbus_->sd_bus_query_sender_creds((sd_bus_message*)msg_, mask, &creds); + SCOPE_EXIT{ connection_->decrementCredsRefCount(creds); }; + int r = connection_->querySenderCredentials((sd_bus_message*)msg_, mask, &creds); SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus creds", -r); gid_t gid = (gid_t)-1; - r = sdbus_->sd_bus_creds_get_gid(creds, &gid); + r = sd_bus_creds_get_gid(creds, &gid); SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus cred gid", -r); return gid; } @@ -722,12 +728,12 @@ gid_t Message::getCredsEgid() const { uint64_t mask = SD_BUS_CREDS_EGID | SD_BUS_CREDS_AUGMENT; sd_bus_creds *creds = nullptr; - SCOPE_EXIT{ sdbus_->sd_bus_creds_unref(creds); }; - int r = sdbus_->sd_bus_query_sender_creds((sd_bus_message*)msg_, mask, &creds); + SCOPE_EXIT{ connection_->decrementCredsRefCount(creds); }; + int r = connection_->querySenderCredentials((sd_bus_message*)msg_, mask, &creds); SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus creds", -r); gid_t egid = (gid_t)-1; - r = sdbus_->sd_bus_creds_get_egid(creds, &egid); + r = sd_bus_creds_get_egid(creds, &egid); SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus cred egid", -r); return egid; } @@ -736,12 +742,12 @@ std::vector Message::getCredsSupplementaryGids() const { uint64_t mask = SD_BUS_CREDS_SUPPLEMENTARY_GIDS | SD_BUS_CREDS_AUGMENT; sd_bus_creds *creds = nullptr; - SCOPE_EXIT{ sdbus_->sd_bus_creds_unref(creds); }; - int r = sdbus_->sd_bus_query_sender_creds((sd_bus_message*)msg_, mask, &creds); + SCOPE_EXIT{ connection_->decrementCredsRefCount(creds); }; + int r = connection_->querySenderCredentials((sd_bus_message*)msg_, mask, &creds); SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus creds", -r); const gid_t *cGids = nullptr; - r = sdbus_->sd_bus_creds_get_supplementary_gids(creds, &cGids); + r = sd_bus_creds_get_supplementary_gids(creds, &cGids); SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus cred supplementary gids", -r); std::vector gids{}; @@ -758,21 +764,21 @@ std::string Message::getSELinuxContext() const { uint64_t mask = SD_BUS_CREDS_AUGMENT | SD_BUS_CREDS_SELINUX_CONTEXT; sd_bus_creds *creds = nullptr; - SCOPE_EXIT{ sdbus_->sd_bus_creds_unref(creds); }; - int r = sdbus_->sd_bus_query_sender_creds((sd_bus_message*)msg_, mask, &creds); + SCOPE_EXIT{ connection_->decrementCredsRefCount(creds); }; + int r = connection_->querySenderCredentials((sd_bus_message*)msg_, mask, &creds); SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus creds", -r); const char *cLabel = nullptr; - r = sdbus_->sd_bus_creds_get_selinux_context(creds, &cLabel); + r = sd_bus_creds_get_selinux_context(creds, &cLabel); SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus cred selinux context", -r); return cLabel; } MethodCall::MethodCall( void *msg - , internal::ISdBus *sdbus + , internal::IConnection *connection , adopt_message_t) noexcept - : Message(msg, sdbus, adopt_message) + : Message(msg, connection, adopt_message) { } @@ -799,70 +805,45 @@ MethodReply MethodCall::send(uint64_t timeout) const MethodReply MethodCall::sendWithReply(uint64_t timeout) const { - sd_bus_error sdbusError = SD_BUS_ERROR_NULL; - SCOPE_EXIT{ sd_bus_error_free(&sdbusError); }; + auto* sdbusReply = connection_->callMethod((sd_bus_message*)msg_, timeout); - sd_bus_message* sdbusReply{}; - auto r = sdbus_->sd_bus_call(nullptr, (sd_bus_message*)msg_, timeout, &sdbusError, &sdbusReply); - - if (sd_bus_error_is_set(&sdbusError)) - throw Error(Error::Name{sdbusError.name}, sdbusError.message); - - SDBUS_THROW_ERROR_IF(r < 0, "Failed to call method", -r); - - return Factory::create(sdbusReply, sdbus_, adopt_message); + return Factory::create(sdbusReply, connection_, adopt_message); } MethodReply MethodCall::sendWithNoReply() const { - auto r = sdbus_->sd_bus_send(nullptr, (sd_bus_message*)msg_, nullptr); - SDBUS_THROW_ERROR_IF(r < 0, "Failed to call method with no reply", -r); + connection_->sendMessage((sd_bus_message*)msg_); return Factory::create(); // No reply } Slot MethodCall::send(void* callback, void* userData, uint64_t timeout, return_slot_t) const { - sd_bus_slot* slot; - - auto r = sdbus_->sd_bus_call_async(nullptr, &slot, (sd_bus_message*)msg_, (sd_bus_message_handler_t)callback, userData, timeout); - SDBUS_THROW_ERROR_IF(r < 0, "Failed to call method asynchronously", -r); - - return {slot, [sdbus_ = sdbus_](void *slot){ sdbus_->sd_bus_slot_unref((sd_bus_slot*)slot); }}; + return connection_->callMethodAsync((sd_bus_message*)msg_, (sd_bus_message_handler_t)callback, userData, timeout, return_slot); } MethodReply MethodCall::createReply() const { - sd_bus_message* sdbusReply{}; - auto r = sdbus_->sd_bus_message_new_method_return((sd_bus_message*)msg_, &sdbusReply); - SDBUS_THROW_ERROR_IF(r < 0, "Failed to create method reply", -r); + auto* sdbusReply = connection_->createMethodReply((sd_bus_message*)msg_); - return Factory::create(sdbusReply, sdbus_, adopt_message); + return Factory::create(sdbusReply, connection_, adopt_message); } MethodReply MethodCall::createErrorReply(const Error& error) const { - sd_bus_error sdbusError = SD_BUS_ERROR_NULL; - SCOPE_EXIT{ sd_bus_error_free(&sdbusError); }; - sd_bus_error_set(&sdbusError, error.getName().c_str(), error.getMessage().c_str()); + sd_bus_message* sdbusErrorReply = connection_->createErrorReplyMessage((sd_bus_message*)msg_, error); - sd_bus_message* sdbusErrorReply{}; - auto r = sdbus_->sd_bus_message_new_method_error((sd_bus_message*)msg_, &sdbusErrorReply, &sdbusError); - SDBUS_THROW_ERROR_IF(r < 0, "Failed to create method error reply", -r); - - return Factory::create(sdbusErrorReply, sdbus_, adopt_message); + return Factory::create(sdbusErrorReply, connection_, adopt_message); } void MethodReply::send() const { - auto r = sdbus_->sd_bus_send(nullptr, (sd_bus_message*)msg_, nullptr); - SDBUS_THROW_ERROR_IF(r < 0, "Failed to send reply", -r); + connection_->sendMessage((sd_bus_message*)msg_); } void Signal::send() const { - auto r = sdbus_->sd_bus_send(nullptr, (sd_bus_message*)msg_, nullptr); - SDBUS_THROW_ERROR_IF(r < 0, "Failed to emit signal", -r); + connection_->sendMessage((sd_bus_message*)msg_); } void Signal::setDestination(const std::string& destination) @@ -872,7 +853,7 @@ void Signal::setDestination(const std::string& destination) void Signal::setDestination(const char* destination) { - auto r = sdbus_->sd_bus_message_set_destination((sd_bus_message*)msg_, destination); + auto r = sd_bus_message_set_destination((sd_bus_message*)msg_, destination); SDBUS_THROW_ERROR_IF(r < 0, "Failed to set signal destination", -r); } @@ -929,7 +910,7 @@ PlainMessage createPlainMessage() // This is a bit of a hack, but it enables use to work with D-Bus message locally without // the need of D-Bus daemon. This is especially useful in unit tests of both sdbus-c++ and client code. // Additionally, it's light-weight and fast solution. - auto& connection = getPseudoConnectionInstance(); + const auto& connection = getPseudoConnectionInstance(); return connection.createPlainMessage(); } diff --git a/src/MessageUtils.h b/src/MessageUtils.h index 8dd8af5..e6db992 100644 --- a/src/MessageUtils.h +++ b/src/MessageUtils.h @@ -47,15 +47,15 @@ namespace sdbus } template - static _Msg create(void *msg, internal::ISdBus* sdbus) + static _Msg create(void *msg, internal::IConnection* connection) { - return _Msg{msg, sdbus}; + return _Msg{msg, connection}; } template - static _Msg create(void *msg, internal::ISdBus* sdbus, adopt_message_t) + static _Msg create(void *msg, internal::IConnection* connection, adopt_message_t) { - return _Msg{msg, sdbus, adopt_message}; + return _Msg{msg, connection, adopt_message}; } }; } diff --git a/src/Object.cpp b/src/Object.cpp index 63bb9c7..f0fce0f 100644 --- a/src/Object.cpp +++ b/src/Object.cpp @@ -83,12 +83,12 @@ void Object::unregister() objectManagerSlot_.reset(); } -sdbus::Signal Object::createSignal(const InterfaceName& interfaceName, const SignalName& signalName) +Signal Object::createSignal(const InterfaceName& interfaceName, const SignalName& signalName) const { return connection_.createSignal(objectPath_, interfaceName, signalName); } -sdbus::Signal Object::createSignal(const char* interfaceName, const char* signalName) +Signal Object::createSignal(const char* interfaceName, const char* signalName) const { return connection_.createSignal(objectPath_.c_str(), interfaceName, signalName); } @@ -326,7 +326,7 @@ int Object::sdbus_method_callback(sd_bus_message *sdbusMessage, void *userData, assert(vtable != nullptr); assert(vtable->object != nullptr); - auto message = Message::Factory::create(sdbusMessage, &vtable->object->connection_.getSdBusInterface()); + auto message = Message::Factory::create(sdbusMessage, &vtable->object->connection_); const auto* methodItem = findMethod(*vtable, message.getMemberName()); assert(methodItem != nullptr); @@ -359,7 +359,7 @@ int Object::sdbus_property_get_callback( sd_bus */*bus*/ return 1; } - auto reply = Message::Factory::create(sdbusReply, &vtable->object->connection_.getSdBusInterface()); + auto reply = Message::Factory::create(sdbusReply, &vtable->object->connection_); auto ok = invokeHandlerAndCatchErrors([&](){ propertyItem->getCallback(reply); }, retError); @@ -382,7 +382,7 @@ int Object::sdbus_property_set_callback( sd_bus */*bus*/ assert(propertyItem != nullptr); assert(propertyItem->setCallback); - auto value = Message::Factory::create(sdbusValue, &vtable->object->connection_.getSdBusInterface()); + auto value = Message::Factory::create(sdbusValue, &vtable->object->connection_); auto ok = invokeHandlerAndCatchErrors([&](){ propertyItem->setCallback(std::move(value)); }, retError); diff --git a/src/Object.h b/src/Object.h index e7d2727..2bf5eb7 100644 --- a/src/Object.h +++ b/src/Object.h @@ -53,8 +53,8 @@ namespace sdbus::internal { Slot addVTable(InterfaceName interfaceName, std::vector vtable, return_slot_t) override; void unregister() override; - sdbus::Signal createSignal(const InterfaceName& interfaceName, const SignalName& signalName) override; - sdbus::Signal createSignal(const char* interfaceName, const char* signalName) override; + Signal createSignal(const InterfaceName& interfaceName, const SignalName& signalName) const override; + Signal createSignal(const char* interfaceName, const char* signalName) const override; void emitSignal(const sdbus::Signal& message) override; void emitPropertiesChangedSignal(const InterfaceName& interfaceName, const std::vector& propNames) override; void emitPropertiesChangedSignal(const char* interfaceName, const std::vector& propNames) override; diff --git a/src/Proxy.cpp b/src/Proxy.cpp index 0cab73f..6f67d7e 100644 --- a/src/Proxy.cpp +++ b/src/Proxy.cpp @@ -84,12 +84,12 @@ Proxy::Proxy( std::unique_ptr&& connection // This proxy is meant to be created, used for simple synchronous D-Bus call(s) and then dismissed. } -MethodCall Proxy::createMethodCall(const InterfaceName& interfaceName, const MethodName& methodName) +MethodCall Proxy::createMethodCall(const InterfaceName& interfaceName, const MethodName& methodName) const { return connection_->createMethodCall(destination_, objectPath_, interfaceName, methodName); } -MethodCall Proxy::createMethodCall(const char* interfaceName, const char* methodName) +MethodCall Proxy::createMethodCall(const char* interfaceName, const char* methodName) const { return connection_->createMethodCall(destination_.c_str(), objectPath_.c_str(), interfaceName, methodName); } @@ -103,7 +103,7 @@ MethodReply Proxy::callMethod(const MethodCall& message, uint64_t timeout) { SDBUS_THROW_ERROR_IF(!message.isValid(), "Invalid method call message provided", EINVAL); - return connection_->callMethod(message, timeout); + return message.send(timeout); } PendingAsyncCall Proxy::callMethodAsync(const MethodCall& message, async_reply_handler asyncReplyCallback) @@ -124,11 +124,7 @@ PendingAsyncCall Proxy::callMethodAsync(const MethodCall& message, async_reply_h , .proxy = *this , .floating = false }); - asyncCallInfo->slot = connection_->callMethod( message - , (void*)&Proxy::sdbus_async_reply_handler - , asyncCallInfo.get() - , timeout - , return_slot ); + asyncCallInfo->slot = message.send((void*)&Proxy::sdbus_async_reply_handler, asyncCallInfo.get(), timeout, return_slot); auto asyncCallInfoWeakPtr = std::weak_ptr{asyncCallInfo}; @@ -145,11 +141,7 @@ Slot Proxy::callMethodAsync(const MethodCall& message, async_reply_handler async , .proxy = *this , .floating = true }); - asyncCallInfo->slot = connection_->callMethod( message - , (void*)&Proxy::sdbus_async_reply_handler - , asyncCallInfo.get() - , timeout - , return_slot ); + asyncCallInfo->slot = message.send((void*)&Proxy::sdbus_async_reply_handler, asyncCallInfo.get(), timeout, return_slot); return {asyncCallInfo.release(), [](void *ptr){ delete static_cast(ptr); }}; } @@ -259,7 +251,7 @@ int Proxy::sdbus_async_reply_handler(sd_bus_message *sdbusMessage, void *userDat proxy.floatingAsyncCallSlots_.erase(asyncCallInfo); }; - auto message = Message::Factory::create(sdbusMessage, &proxy.connection_->getSdBusInterface()); + auto message = Message::Factory::create(sdbusMessage, proxy.connection_.get()); auto ok = invokeHandlerAndCatchErrors([&] { @@ -284,8 +276,7 @@ int Proxy::sdbus_signal_handler(sd_bus_message *sdbusMessage, void *userData, sd assert(signalInfo != nullptr); assert(signalInfo->callback); - // TODO: Hide Message factory invocation under Connection API (tell, don't ask principle), then we can remove getSdBusInterface() - auto message = Message::Factory::create(sdbusMessage, &signalInfo->proxy.connection_->getSdBusInterface()); + auto message = Message::Factory::create(sdbusMessage, signalInfo->proxy.connection_.get()); auto ok = invokeHandlerAndCatchErrors([&](){ signalInfo->callback(std::move(message)); }, retError); diff --git a/src/Proxy.h b/src/Proxy.h index 4ee2047..bd42dcf 100644 --- a/src/Proxy.h +++ b/src/Proxy.h @@ -56,8 +56,8 @@ namespace sdbus::internal { , ObjectPath objectPath , dont_run_event_loop_thread_t ); - MethodCall createMethodCall(const InterfaceName& interfaceName, const MethodName& methodName) override; - MethodCall createMethodCall(const char* interfaceName, const char* methodName) override; + MethodCall createMethodCall(const InterfaceName& interfaceName, const MethodName& methodName) const override; + MethodCall createMethodCall(const char* interfaceName, const char* methodName) const override; MethodReply callMethod(const MethodCall& message) override; MethodReply callMethod(const MethodCall& message, uint64_t timeout) override; PendingAsyncCall callMethodAsync(const MethodCall& message, async_reply_handler asyncReplyCallback) override; diff --git a/src/SdBus.cpp b/src/SdBus.cpp index 7828400..f42760b 100644 --- a/src/SdBus.cpp +++ b/src/SdBus.cpp @@ -454,6 +454,13 @@ int SdBus::sd_bus_query_sender_creds(sd_bus_message *m, uint64_t mask, sd_bus_cr return ::sd_bus_query_sender_creds(m, mask, c); } +sd_bus_creds* SdBus::sd_bus_creds_ref(sd_bus_creds *c) +{ + std::lock_guard lock(sdbusMutex_); + + return ::sd_bus_creds_ref(c); +} + sd_bus_creds* SdBus::sd_bus_creds_unref(sd_bus_creds *c) { std::lock_guard lock(sdbusMutex_); diff --git a/src/SdBus.h b/src/SdBus.h index ebe903d..87c459e 100644 --- a/src/SdBus.h +++ b/src/SdBus.h @@ -90,6 +90,7 @@ public: virtual int sd_bus_message_set_destination(sd_bus_message *m, const char *destination) override; virtual int sd_bus_query_sender_creds(sd_bus_message *m, uint64_t mask, sd_bus_creds **c) override; + virtual sd_bus_creds* sd_bus_creds_ref(sd_bus_creds *c) override; virtual sd_bus_creds* sd_bus_creds_unref(sd_bus_creds *c) override; virtual int sd_bus_creds_get_pid(sd_bus_creds *c, pid_t *pid) override; diff --git a/tests/unittests/mocks/SdBusMock.h b/tests/unittests/mocks/SdBusMock.h index 8216705..f20f4df 100644 --- a/tests/unittests/mocks/SdBusMock.h +++ b/tests/unittests/mocks/SdBusMock.h @@ -89,6 +89,7 @@ public: MOCK_METHOD2(sd_bus_message_set_destination, int(sd_bus_message *m, const char *destination)); MOCK_METHOD3(sd_bus_query_sender_creds, int(sd_bus_message *, uint64_t, sd_bus_creds **)); + MOCK_METHOD1(sd_bus_creds_ref, sd_bus_creds*(sd_bus_creds *)); MOCK_METHOD1(sd_bus_creds_unref, sd_bus_creds*(sd_bus_creds *)); MOCK_METHOD2(sd_bus_creds_get_pid, int(sd_bus_creds *, pid_t *));