diff --git a/include/sdbus-c++/IObject.h b/include/sdbus-c++/IObject.h index d79e116..21953bc 100644 --- a/include/sdbus-c++/IObject.h +++ b/include/sdbus-c++/IObject.h @@ -403,7 +403,7 @@ namespace sdbus { * * @throws sdbus::Error in case of failure */ - [[nodiscard]] virtual Signal createSignal(const InterfaceName& interfaceName, const SignalName& signalName) = 0; + [[nodiscard]] virtual Signal createSignal(const InterfaceName& interfaceName, const SignalName& signalName) const = 0; /*! * @brief Emits signal for this object path @@ -419,7 +419,7 @@ namespace sdbus { protected: // Internal API for efficiency reasons used by high-level API helper classes friend SignalEmitter; - [[nodiscard]] virtual Signal createSignal(const char* interfaceName, const char* signalName) = 0; + [[nodiscard]] virtual Signal createSignal(const char* interfaceName, const char* signalName) const = 0; }; // Out-of-line member definitions diff --git a/include/sdbus-c++/IProxy.h b/include/sdbus-c++/IProxy.h index 029127f..9d2be23 100644 --- a/include/sdbus-c++/IProxy.h +++ b/include/sdbus-c++/IProxy.h @@ -364,7 +364,7 @@ namespace sdbus { * * @throws sdbus::Error in case of failure */ - [[nodiscard]] virtual MethodCall createMethodCall(const InterfaceName& interfaceName, const MethodName& methodName) = 0; + [[nodiscard]] virtual MethodCall createMethodCall(const InterfaceName& interfaceName, const MethodName& methodName) const = 0; /*! * @brief Calls method on the remote D-Bus object @@ -652,7 +652,7 @@ namespace sdbus { friend AsyncMethodInvoker; friend SignalSubscriber; - [[nodiscard]] virtual MethodCall createMethodCall(const char* interfaceName, const char* methodName) = 0; + [[nodiscard]] virtual MethodCall createMethodCall(const char* interfaceName, const char* methodName) const = 0; virtual void registerSignalHandler( const char* interfaceName , const char* signalName , signal_handler signalHandler ) = 0; diff --git a/include/sdbus-c++/Message.h b/include/sdbus-c++/Message.h index 4cf5446..40814c6 100644 --- a/include/sdbus-c++/Message.h +++ b/include/sdbus-c++/Message.h @@ -58,7 +58,7 @@ namespace sdbus { class UnixFd; class MethodReply; namespace internal { - class ISdBus; + class IConnection; } } @@ -245,15 +245,15 @@ namespace sdbus { protected: Message() = default; - explicit Message(internal::ISdBus* sdbus) noexcept; - Message(void *msg, internal::ISdBus* sdbus) noexcept; - Message(void *msg, internal::ISdBus* sdbus, adopt_message_t) noexcept; + explicit Message(internal::IConnection* connection) noexcept; + Message(void *msg, internal::IConnection* connection) noexcept; + Message(void *msg, internal::IConnection* connection, adopt_message_t) noexcept; friend Factory; protected: void* msg_{}; - internal::ISdBus* sdbus_{}; + internal::IConnection* connection_{}; mutable bool ok_{true}; }; @@ -275,7 +275,7 @@ namespace sdbus { bool doesntExpectReply() const; protected: - MethodCall(void *msg, internal::ISdBus* sdbus, adopt_message_t) noexcept; + MethodCall(void *msg, internal::IConnection* connection, adopt_message_t) noexcept; private: MethodReply sendWithReply(uint64_t timeout = 0) const; diff --git a/src/Connection.cpp b/src/Connection.cpp index ff1d325..8a5f0cf 100644 --- a/src/Connection.cpp +++ b/src/Connection.cpp @@ -179,16 +179,6 @@ Connection::PollData Connection::getEventLoopPollData() const return {pollData.fd, pollData.events, timeout, eventFd_.fd}; } -const ISdBus& Connection::getSdBusInterface() const -{ - return *sdbus_.get(); -} - -ISdBus& Connection::getSdBusInterface() -{ - return *sdbus_.get(); -} - void Connection::addObjectManager(const ObjectPath& objectPath) { auto r = sdbus_->sd_bus_add_object_manager(bus_.get(), nullptr, objectPath.c_str()); @@ -488,7 +478,7 @@ PlainMessage Connection::createPlainMessage() const SDBUS_THROW_ERROR_IF(r < 0, "Failed to create a plain message", -r); - return Message::Factory::create(sdbusMsg, sdbus_.get(), adopt_message); + return Message::Factory::create(sdbusMsg, const_cast(this), adopt_message); } MethodCall Connection::createMethodCall( const ServiceName& destination @@ -515,7 +505,7 @@ MethodCall Connection::createMethodCall( const char* destination SDBUS_THROW_ERROR_IF(r < 0, "Failed to create method call", -r); - return Message::Factory::create(sdbusMsg, sdbus_.get(), adopt_message); + return Message::Factory::create(sdbusMsg, const_cast(this), adopt_message); } Signal Connection::createSignal( const ObjectPath& objectPath @@ -535,34 +525,7 @@ Signal Connection::createSignal( const char* objectPath SDBUS_THROW_ERROR_IF(r < 0, "Failed to create signal", -r); - return Message::Factory::create(sdbusMsg, sdbus_.get(), adopt_message); -} - -MethodReply Connection::callMethod(const MethodCall& message, uint64_t timeout) -{ - // If the call expects reply, this call will block the bus connection from - // serving other messages until the reply arrives or the call times out. - auto reply = message.send(timeout); - - // Wake up event loop to process messages that may have arrived in the meantime... - wakeUpEventLoopIfMessagesInQueue(); - - return reply; -} - -Slot Connection::callMethod(const MethodCall& message, void* callback, void* userData, uint64_t timeout, return_slot_t) -{ - // TODO: Think of ways of optimizing these three locking/unlocking of sdbus mutex (merge into one call?) - auto timeoutBefore = getEventLoopPollData().timeout; - auto slot = message.send(callback, userData, timeout, return_slot); - auto timeoutAfter = getEventLoopPollData().timeout; - - // An event loop may wait in poll with timeout `t1', while in another thread an async call is made with - // timeout `t2'. If `t2' < `t1', then we have to wake up the event loop thread to update its poll timeout. - if (timeoutAfter < timeoutBefore) - notifyEventLoopToWakeUpFromPoll(); - - return slot; + return Message::Factory::create(sdbusMsg, const_cast(this), adopt_message); } void Connection::emitPropertiesChangedSignal( const ObjectPath& objectPath @@ -648,6 +611,100 @@ Slot Connection::registerSignalHandler( const char* sender return {slot, [this](void *slot){ sdbus_->sd_bus_slot_unref((sd_bus_slot*)slot); }}; } +sd_bus_message* Connection::incrementMessageRefCount(sd_bus_message* sdbusMsg) +{ + return sdbus_->sd_bus_message_ref(sdbusMsg); +} + +sd_bus_message* Connection::decrementMessageRefCount(sd_bus_message* sdbusMsg) +{ + return sdbus_->sd_bus_message_unref(sdbusMsg); +} + +int Connection::querySenderCredentials(sd_bus_message* sdbusMsg, uint64_t mask, sd_bus_creds **creds) +{ + return sdbus_->sd_bus_query_sender_creds(sdbusMsg, mask, creds); +} + +sd_bus_creds* Connection::incrementCredsRefCount(sd_bus_creds* creds) +{ + return sdbus_->sd_bus_creds_ref(creds); +} + +sd_bus_creds* Connection::decrementCredsRefCount(sd_bus_creds* creds) +{ + return sdbus_->sd_bus_creds_unref(creds); +} + +sd_bus_message* Connection::callMethod(sd_bus_message* sdbusMsg, uint64_t timeout) +{ + sd_bus_error sdbusError = SD_BUS_ERROR_NULL; + SCOPE_EXIT{ sd_bus_error_free(&sdbusError); }; + + // This call will block the bus connection from serving other messages + // until the reply arrives or the call times out. + sd_bus_message* sdbusReply{}; + auto r = sdbus_->sd_bus_call(nullptr, sdbusMsg, timeout, &sdbusError, &sdbusReply); + + if (sd_bus_error_is_set(&sdbusError)) + throw Error(Error::Name{sdbusError.name}, sdbusError.message); + + SDBUS_THROW_ERROR_IF(r < 0, "Failed to call method", -r); + + // Wake up event loop to process messages that may have arrived in the meantime... + wakeUpEventLoopIfMessagesInQueue(); + + return sdbusReply; +} + +Slot Connection::callMethodAsync(sd_bus_message* sdbusMsg, sd_bus_message_handler_t callback, void* userData, uint64_t timeout, return_slot_t) +{ + sd_bus_slot *slot{}; + + // TODO: Think of ways of optimizing these three locking/unlocking of sdbus mutex (merge into one call?) + auto timeoutBefore = getEventLoopPollData().timeout; + auto r = sdbus_->sd_bus_call_async(nullptr, &slot, sdbusMsg, (sd_bus_message_handler_t)callback, userData, timeout); + SDBUS_THROW_ERROR_IF(r < 0, "Failed to call method asynchronously", -r); + auto timeoutAfter = getEventLoopPollData().timeout; + + // An event loop may wait in poll with timeout `t1', while in another thread an async call is made with + // timeout `t2'. If `t2' < `t1', then we have to wake up the event loop thread to update its poll timeout. + if (timeoutAfter < timeoutBefore) + notifyEventLoopToWakeUpFromPoll(); + + return {slot, [this](void *slot){ sdbus_->sd_bus_slot_unref((sd_bus_slot*)slot); }}; +} + +void Connection::sendMessage(sd_bus_message* sdbusMsg) +{ + auto r = sdbus_->sd_bus_send(nullptr, sdbusMsg, nullptr); + + SDBUS_THROW_ERROR_IF(r < 0, "Failed to send D-Bus message", -r); +} + +sd_bus_message* Connection::createMethodReply(sd_bus_message* sdbusMsg) +{ + sd_bus_message* sdbusReply{}; + + auto r = sdbus_->sd_bus_message_new_method_return(sdbusMsg, &sdbusReply); + SDBUS_THROW_ERROR_IF(r < 0, "Failed to create method reply", -r); + + return sdbusReply; +} + +sd_bus_message* Connection::createErrorReplyMessage(sd_bus_message* sdbusMsg, const Error& error) +{ + sd_bus_error sdbusError = SD_BUS_ERROR_NULL; + SCOPE_EXIT{ sd_bus_error_free(&sdbusError); }; + sd_bus_error_set(&sdbusError, error.getName().c_str(), error.getMessage().c_str()); + + sd_bus_message* sdbusErrorReply{}; + auto r = sdbus_->sd_bus_message_new_method_error(sdbusMsg, &sdbusErrorReply, &sdbusError); + SDBUS_THROW_ERROR_IF(r < 0, "Failed to create method error reply", -r); + + return sdbusErrorReply; +} + Connection::BusPtr Connection::openBus(const BusFactory& busFactory) { sd_bus* bus{}; @@ -785,7 +842,7 @@ Message Connection::getCurrentlyProcessedMessage() const { auto* sdbusMsg = sdbus_->sd_bus_get_current_message(bus_.get()); - return Message::Factory::create(sdbusMsg, sdbus_.get()); + return Message::Factory::create(sdbusMsg, const_cast(this)); } template @@ -804,7 +861,7 @@ int Connection::sdbus_match_callback(sd_bus_message *sdbusMessage, void *userDat assert(matchInfo != nullptr); assert(matchInfo->callback); - auto message = Message::Factory::create(sdbusMessage, &matchInfo->connection.getSdBusInterface()); + auto message = Message::Factory::create(sdbusMessage, &matchInfo->connection); auto ok = invokeHandlerAndCatchErrors([&](){ matchInfo->callback(std::move(message)); }, retError); @@ -817,7 +874,7 @@ int Connection::sdbus_match_install_callback(sd_bus_message *sdbusMessage, void assert(matchInfo != nullptr); assert(matchInfo->installCallback); - auto message = Message::Factory::create(sdbusMessage, &matchInfo->connection.getSdBusInterface()); + auto message = Message::Factory::create(sdbusMessage, &matchInfo->connection); auto ok = invokeHandlerAndCatchErrors([&](){ matchInfo->installCallback(std::move(message)); }, retError); diff --git a/src/Connection.h b/src/Connection.h index 0bd8b99..b68986f 100644 --- a/src/Connection.h +++ b/src/Connection.h @@ -120,9 +120,6 @@ namespace sdbus::internal { void detachSdEventLoop() override; sd_event *getSdEventLoop() override; - [[nodiscard]] const ISdBus& getSdBusInterface() const override; - [[nodiscard]] ISdBus& getSdBusInterface() override; - Slot addObjectVTable( const ObjectPath& objectPath , const InterfaceName& interfaceName , const sd_bus_vtable* vtable @@ -145,9 +142,6 @@ namespace sdbus::internal { , const char* interfaceName , const char* signalName ) const override; - MethodReply callMethod(const MethodCall& message, uint64_t timeout) override; - Slot callMethod(const MethodCall& message, void* callback, void* userData, uint64_t timeout, return_slot_t) override; - void emitPropertiesChangedSignal( const ObjectPath& objectPath , const InterfaceName& interfaceName , const std::vector& propNames ) override; @@ -169,6 +163,20 @@ namespace sdbus::internal { , void* userData , return_slot_t ) override; + sd_bus_message* incrementMessageRefCount(sd_bus_message* sdbusMsg) override; + sd_bus_message* decrementMessageRefCount(sd_bus_message* sdbusMsg) override; + + int querySenderCredentials(sd_bus_message* sdbusMsg, uint64_t mask, sd_bus_creds **creds) override; + sd_bus_creds* incrementCredsRefCount(sd_bus_creds* creds) override; + sd_bus_creds* decrementCredsRefCount(sd_bus_creds* creds) override; + + sd_bus_message* callMethod(sd_bus_message* sdbusMsg, uint64_t timeout) override; + Slot callMethodAsync(sd_bus_message* sdbusMsg, sd_bus_message_handler_t callback, void* userData, uint64_t timeout, return_slot_t) override; + void sendMessage(sd_bus_message* sdbusMsg) override; + + sd_bus_message* createMethodReply(sd_bus_message* sdbusMsg) override; + sd_bus_message* createErrorReplyMessage(sd_bus_message* sdbusMsg, const Error& error) override; + private: using BusFactory = std::function; using BusPtr = std::unique_ptr>; diff --git a/src/IConnection.h b/src/IConnection.h index 412b17c..7e104a8 100644 --- a/src/IConnection.h +++ b/src/IConnection.h @@ -51,6 +51,7 @@ namespace sdbus { using MethodName = MemberName; using SignalName = MemberName; using PropertyName = MemberName; + class Error; namespace internal { class ISdBus; } @@ -64,9 +65,6 @@ namespace sdbus::internal { public: ~IConnection() override = default; - [[nodiscard]] virtual const ISdBus& getSdBusInterface() const = 0; - [[nodiscard]] virtual ISdBus& getSdBusInterface() = 0; - [[nodiscard]] virtual Slot addObjectVTable( const ObjectPath& objectPath , const InterfaceName& interfaceName , const sd_bus_vtable* vtable @@ -89,13 +87,6 @@ namespace sdbus::internal { , const char* interfaceName , const char* signalName ) const = 0; - virtual MethodReply callMethod(const MethodCall& message, uint64_t timeout) = 0; - [[nodiscard]] virtual Slot callMethod( const MethodCall& message - , void* callback - , void* userData - , uint64_t timeout - , return_slot_t ) = 0; - virtual void emitPropertiesChangedSignal( const ObjectPath& objectPath , const InterfaceName& interfaceName , const std::vector& propNames ) = 0; @@ -116,6 +107,25 @@ namespace sdbus::internal { , sd_bus_message_handler_t callback , void* userData , return_slot_t ) = 0; + + virtual sd_bus_message* incrementMessageRefCount(sd_bus_message* sdbusMsg) = 0; + virtual sd_bus_message* decrementMessageRefCount(sd_bus_message* sdbusMsg) = 0; + + // TODO: Refactor to higher level (Creds class will ownership handling and getters) + virtual int querySenderCredentials(sd_bus_message* sdbusMsg, uint64_t mask, sd_bus_creds **creds) = 0; + virtual sd_bus_creds* incrementCredsRefCount(sd_bus_creds* creds) = 0; + virtual sd_bus_creds* decrementCredsRefCount(sd_bus_creds* creds) = 0; + + virtual sd_bus_message* callMethod(sd_bus_message* sdbusMsg, uint64_t timeout) = 0; + [[nodiscard]] virtual Slot callMethodAsync( sd_bus_message* sdbusMsg + , sd_bus_message_handler_t callback + , void* userData + , uint64_t timeout + , return_slot_t ) = 0; + virtual void sendMessage(sd_bus_message* sdbusMsg) = 0; + + virtual sd_bus_message* createMethodReply(sd_bus_message* sdbusMsg) = 0; + virtual sd_bus_message* createErrorReplyMessage(sd_bus_message* sdbusMsg, const Error& error) = 0; }; [[nodiscard]] std::unique_ptr createPseudoConnection(); diff --git a/src/ISdBus.h b/src/ISdBus.h index 95eac4c..61d0ed9 100644 --- a/src/ISdBus.h +++ b/src/ISdBus.h @@ -98,6 +98,7 @@ namespace sdbus::internal { virtual int sd_bus_message_set_destination(sd_bus_message *m, const char *destination) = 0; virtual int sd_bus_query_sender_creds(sd_bus_message *m, uint64_t mask, sd_bus_creds **c) = 0; + virtual sd_bus_creds* sd_bus_creds_ref(sd_bus_creds *c) = 0; virtual sd_bus_creds* sd_bus_creds_unref(sd_bus_creds *c) = 0; virtual int sd_bus_creds_get_pid(sd_bus_creds *c, pid_t *pid) = 0; diff --git a/src/Message.cpp b/src/Message.cpp index 77eb0f4..63ae3fa 100644 --- a/src/Message.cpp +++ b/src/Message.cpp @@ -30,7 +30,6 @@ #include "sdbus-c++/Types.h" #include "IConnection.h" -#include "ISdBus.h" #include "MessageUtils.h" #include "ScopeGuard.h" @@ -40,27 +39,27 @@ namespace sdbus { -Message::Message(internal::ISdBus* sdbus) noexcept - : sdbus_(sdbus) +Message::Message(internal::IConnection* connection) noexcept + : connection_(connection) { - assert(sdbus_ != nullptr); + assert(connection_ != nullptr); } -Message::Message(void *msg, internal::ISdBus* sdbus) noexcept +Message::Message(void *msg, internal::IConnection* connection) noexcept : msg_(msg) - , sdbus_(sdbus) + , connection_(connection) { assert(msg_ != nullptr); - assert(sdbus_ != nullptr); - sdbus_->sd_bus_message_ref((sd_bus_message*)msg_); + assert(connection_ != nullptr); + connection_->incrementMessageRefCount((sd_bus_message*)msg_); } -Message::Message(void *msg, internal::ISdBus* sdbus, adopt_message_t) noexcept +Message::Message(void *msg, internal::IConnection* connection, adopt_message_t) noexcept : msg_(msg) - , sdbus_(sdbus) + , connection_(connection) { assert(msg_ != nullptr); - assert(sdbus_ != nullptr); + assert(connection_ != nullptr); } Message::Message(const Message& other) noexcept @@ -71,13 +70,13 @@ Message::Message(const Message& other) noexcept Message& Message::operator=(const Message& other) noexcept { if (msg_) - sdbus_->sd_bus_message_unref((sd_bus_message*)msg_); + connection_->decrementMessageRefCount((sd_bus_message*)msg_); msg_ = other.msg_; - sdbus_ = other.sdbus_; + connection_ = other.connection_; ok_ = other.ok_; - sdbus_->sd_bus_message_ref((sd_bus_message*)msg_); + connection_->incrementMessageRefCount((sd_bus_message*)msg_); return *this; } @@ -90,12 +89,12 @@ Message::Message(Message&& other) noexcept Message& Message::operator=(Message&& other) noexcept { if (msg_) - sdbus_->sd_bus_message_unref((sd_bus_message*)msg_); + connection_->decrementMessageRefCount((sd_bus_message*)msg_); msg_ = other.msg_; other.msg_ = nullptr; - sdbus_ = other.sdbus_; - other.sdbus_ = nullptr; + connection_ = other.connection_; + other.connection_ = nullptr; ok_ = other.ok_; other.ok_ = true; @@ -105,13 +104,17 @@ Message& Message::operator=(Message&& other) noexcept Message::~Message() { if (msg_) - sdbus_->sd_bus_message_unref((sd_bus_message*)msg_); + connection_->decrementMessageRefCount((sd_bus_message*)msg_); } Message& Message::operator<<(bool item) { int intItem = item; + // Direct sd-bus method, bypassing SdBus mutex, are called here, since Message serialization/deserialization, + // as well as getter/setter methods are not thread safe by design. Additionally, they are called frequently, + // so some overhead is spared. What is thread-safe in Message class is Message constructors, copy/move operations + // and the destructor, so the Message instance can be passed from one thread to another safely. auto r = sd_bus_message_append_basic((sd_bus_message*)msg_, SD_BUS_TYPE_BOOLEAN, &intItem); SDBUS_THROW_ERROR_IF(r < 0, "Failed to serialize a bool value", -r); @@ -648,7 +651,7 @@ std::pair Message::peekType() const bool Message::isValid() const { - return msg_ != nullptr && sdbus_ != nullptr; + return msg_ != nullptr && connection_ != nullptr; } bool Message::isEmpty() const @@ -661,17 +664,20 @@ bool Message::isAtEnd(bool complete) const return sd_bus_message_at_end((sd_bus_message*)msg_, complete) > 0; } +// TODO: Create a RAII ownership class for creds with copy&move semantics, doing ref()/unref() under the hood. +// Create a method Message::querySenderCreds() that will return an object of this class by value, through IConnection and SdBus mutex. +// The class will expose methods like getPid(), getUid(), etc. that will directly call sd_bus_creds_* functions, no need for mutex here. pid_t Message::getCredsPid() const { uint64_t mask = SD_BUS_CREDS_PID | SD_BUS_CREDS_AUGMENT; sd_bus_creds *creds = nullptr; - SCOPE_EXIT{ sdbus_->sd_bus_creds_unref(creds); }; + SCOPE_EXIT{ connection_->decrementCredsRefCount(creds); }; - int r = sdbus_->sd_bus_query_sender_creds((sd_bus_message*)msg_, mask, &creds); + int r = connection_->querySenderCredentials((sd_bus_message*)msg_, mask, &creds); SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus creds", -r); pid_t pid = 0; - r = sdbus_->sd_bus_creds_get_pid(creds, &pid); + r = sd_bus_creds_get_pid(creds, &pid); SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus cred pid", -r); return pid; } @@ -680,12 +686,12 @@ uid_t Message::getCredsUid() const { uint64_t mask = SD_BUS_CREDS_UID | SD_BUS_CREDS_AUGMENT; sd_bus_creds *creds = nullptr; - SCOPE_EXIT{ sdbus_->sd_bus_creds_unref(creds); }; - int r = sdbus_->sd_bus_query_sender_creds((sd_bus_message*)msg_, mask, &creds); + SCOPE_EXIT{ connection_->decrementCredsRefCount(creds); }; + int r = connection_->querySenderCredentials((sd_bus_message*)msg_, mask, &creds); SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus creds", -r); uid_t uid = (uid_t)-1; - r = sdbus_->sd_bus_creds_get_uid(creds, &uid); + r = sd_bus_creds_get_uid(creds, &uid); SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus cred uid", -r); return uid; } @@ -694,12 +700,12 @@ uid_t Message::getCredsEuid() const { uint64_t mask = SD_BUS_CREDS_EUID | SD_BUS_CREDS_AUGMENT; sd_bus_creds *creds = nullptr; - SCOPE_EXIT{ sdbus_->sd_bus_creds_unref(creds); }; - int r = sdbus_->sd_bus_query_sender_creds((sd_bus_message*)msg_, mask, &creds); + SCOPE_EXIT{ connection_->decrementCredsRefCount(creds); }; + int r = connection_->querySenderCredentials((sd_bus_message*)msg_, mask, &creds); SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus creds", -r); uid_t euid = (uid_t)-1; - r = sdbus_->sd_bus_creds_get_euid(creds, &euid); + r = sd_bus_creds_get_euid(creds, &euid); SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus cred euid", -r); return euid; } @@ -708,12 +714,12 @@ gid_t Message::getCredsGid() const { uint64_t mask = SD_BUS_CREDS_GID | SD_BUS_CREDS_AUGMENT; sd_bus_creds *creds = nullptr; - SCOPE_EXIT{ sdbus_->sd_bus_creds_unref(creds); }; - int r = sdbus_->sd_bus_query_sender_creds((sd_bus_message*)msg_, mask, &creds); + SCOPE_EXIT{ connection_->decrementCredsRefCount(creds); }; + int r = connection_->querySenderCredentials((sd_bus_message*)msg_, mask, &creds); SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus creds", -r); gid_t gid = (gid_t)-1; - r = sdbus_->sd_bus_creds_get_gid(creds, &gid); + r = sd_bus_creds_get_gid(creds, &gid); SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus cred gid", -r); return gid; } @@ -722,12 +728,12 @@ gid_t Message::getCredsEgid() const { uint64_t mask = SD_BUS_CREDS_EGID | SD_BUS_CREDS_AUGMENT; sd_bus_creds *creds = nullptr; - SCOPE_EXIT{ sdbus_->sd_bus_creds_unref(creds); }; - int r = sdbus_->sd_bus_query_sender_creds((sd_bus_message*)msg_, mask, &creds); + SCOPE_EXIT{ connection_->decrementCredsRefCount(creds); }; + int r = connection_->querySenderCredentials((sd_bus_message*)msg_, mask, &creds); SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus creds", -r); gid_t egid = (gid_t)-1; - r = sdbus_->sd_bus_creds_get_egid(creds, &egid); + r = sd_bus_creds_get_egid(creds, &egid); SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus cred egid", -r); return egid; } @@ -736,12 +742,12 @@ std::vector Message::getCredsSupplementaryGids() const { uint64_t mask = SD_BUS_CREDS_SUPPLEMENTARY_GIDS | SD_BUS_CREDS_AUGMENT; sd_bus_creds *creds = nullptr; - SCOPE_EXIT{ sdbus_->sd_bus_creds_unref(creds); }; - int r = sdbus_->sd_bus_query_sender_creds((sd_bus_message*)msg_, mask, &creds); + SCOPE_EXIT{ connection_->decrementCredsRefCount(creds); }; + int r = connection_->querySenderCredentials((sd_bus_message*)msg_, mask, &creds); SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus creds", -r); const gid_t *cGids = nullptr; - r = sdbus_->sd_bus_creds_get_supplementary_gids(creds, &cGids); + r = sd_bus_creds_get_supplementary_gids(creds, &cGids); SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus cred supplementary gids", -r); std::vector gids{}; @@ -758,21 +764,21 @@ std::string Message::getSELinuxContext() const { uint64_t mask = SD_BUS_CREDS_AUGMENT | SD_BUS_CREDS_SELINUX_CONTEXT; sd_bus_creds *creds = nullptr; - SCOPE_EXIT{ sdbus_->sd_bus_creds_unref(creds); }; - int r = sdbus_->sd_bus_query_sender_creds((sd_bus_message*)msg_, mask, &creds); + SCOPE_EXIT{ connection_->decrementCredsRefCount(creds); }; + int r = connection_->querySenderCredentials((sd_bus_message*)msg_, mask, &creds); SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus creds", -r); const char *cLabel = nullptr; - r = sdbus_->sd_bus_creds_get_selinux_context(creds, &cLabel); + r = sd_bus_creds_get_selinux_context(creds, &cLabel); SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus cred selinux context", -r); return cLabel; } MethodCall::MethodCall( void *msg - , internal::ISdBus *sdbus + , internal::IConnection *connection , adopt_message_t) noexcept - : Message(msg, sdbus, adopt_message) + : Message(msg, connection, adopt_message) { } @@ -799,70 +805,45 @@ MethodReply MethodCall::send(uint64_t timeout) const MethodReply MethodCall::sendWithReply(uint64_t timeout) const { - sd_bus_error sdbusError = SD_BUS_ERROR_NULL; - SCOPE_EXIT{ sd_bus_error_free(&sdbusError); }; + auto* sdbusReply = connection_->callMethod((sd_bus_message*)msg_, timeout); - sd_bus_message* sdbusReply{}; - auto r = sdbus_->sd_bus_call(nullptr, (sd_bus_message*)msg_, timeout, &sdbusError, &sdbusReply); - - if (sd_bus_error_is_set(&sdbusError)) - throw Error(Error::Name{sdbusError.name}, sdbusError.message); - - SDBUS_THROW_ERROR_IF(r < 0, "Failed to call method", -r); - - return Factory::create(sdbusReply, sdbus_, adopt_message); + return Factory::create(sdbusReply, connection_, adopt_message); } MethodReply MethodCall::sendWithNoReply() const { - auto r = sdbus_->sd_bus_send(nullptr, (sd_bus_message*)msg_, nullptr); - SDBUS_THROW_ERROR_IF(r < 0, "Failed to call method with no reply", -r); + connection_->sendMessage((sd_bus_message*)msg_); return Factory::create(); // No reply } Slot MethodCall::send(void* callback, void* userData, uint64_t timeout, return_slot_t) const { - sd_bus_slot* slot; - - auto r = sdbus_->sd_bus_call_async(nullptr, &slot, (sd_bus_message*)msg_, (sd_bus_message_handler_t)callback, userData, timeout); - SDBUS_THROW_ERROR_IF(r < 0, "Failed to call method asynchronously", -r); - - return {slot, [sdbus_ = sdbus_](void *slot){ sdbus_->sd_bus_slot_unref((sd_bus_slot*)slot); }}; + return connection_->callMethodAsync((sd_bus_message*)msg_, (sd_bus_message_handler_t)callback, userData, timeout, return_slot); } MethodReply MethodCall::createReply() const { - sd_bus_message* sdbusReply{}; - auto r = sdbus_->sd_bus_message_new_method_return((sd_bus_message*)msg_, &sdbusReply); - SDBUS_THROW_ERROR_IF(r < 0, "Failed to create method reply", -r); + auto* sdbusReply = connection_->createMethodReply((sd_bus_message*)msg_); - return Factory::create(sdbusReply, sdbus_, adopt_message); + return Factory::create(sdbusReply, connection_, adopt_message); } MethodReply MethodCall::createErrorReply(const Error& error) const { - sd_bus_error sdbusError = SD_BUS_ERROR_NULL; - SCOPE_EXIT{ sd_bus_error_free(&sdbusError); }; - sd_bus_error_set(&sdbusError, error.getName().c_str(), error.getMessage().c_str()); + sd_bus_message* sdbusErrorReply = connection_->createErrorReplyMessage((sd_bus_message*)msg_, error); - sd_bus_message* sdbusErrorReply{}; - auto r = sdbus_->sd_bus_message_new_method_error((sd_bus_message*)msg_, &sdbusErrorReply, &sdbusError); - SDBUS_THROW_ERROR_IF(r < 0, "Failed to create method error reply", -r); - - return Factory::create(sdbusErrorReply, sdbus_, adopt_message); + return Factory::create(sdbusErrorReply, connection_, adopt_message); } void MethodReply::send() const { - auto r = sdbus_->sd_bus_send(nullptr, (sd_bus_message*)msg_, nullptr); - SDBUS_THROW_ERROR_IF(r < 0, "Failed to send reply", -r); + connection_->sendMessage((sd_bus_message*)msg_); } void Signal::send() const { - auto r = sdbus_->sd_bus_send(nullptr, (sd_bus_message*)msg_, nullptr); - SDBUS_THROW_ERROR_IF(r < 0, "Failed to emit signal", -r); + connection_->sendMessage((sd_bus_message*)msg_); } void Signal::setDestination(const std::string& destination) @@ -872,7 +853,7 @@ void Signal::setDestination(const std::string& destination) void Signal::setDestination(const char* destination) { - auto r = sdbus_->sd_bus_message_set_destination((sd_bus_message*)msg_, destination); + auto r = sd_bus_message_set_destination((sd_bus_message*)msg_, destination); SDBUS_THROW_ERROR_IF(r < 0, "Failed to set signal destination", -r); } @@ -929,7 +910,7 @@ PlainMessage createPlainMessage() // This is a bit of a hack, but it enables use to work with D-Bus message locally without // the need of D-Bus daemon. This is especially useful in unit tests of both sdbus-c++ and client code. // Additionally, it's light-weight and fast solution. - auto& connection = getPseudoConnectionInstance(); + const auto& connection = getPseudoConnectionInstance(); return connection.createPlainMessage(); } diff --git a/src/MessageUtils.h b/src/MessageUtils.h index 8dd8af5..e6db992 100644 --- a/src/MessageUtils.h +++ b/src/MessageUtils.h @@ -47,15 +47,15 @@ namespace sdbus } template - static _Msg create(void *msg, internal::ISdBus* sdbus) + static _Msg create(void *msg, internal::IConnection* connection) { - return _Msg{msg, sdbus}; + return _Msg{msg, connection}; } template - static _Msg create(void *msg, internal::ISdBus* sdbus, adopt_message_t) + static _Msg create(void *msg, internal::IConnection* connection, adopt_message_t) { - return _Msg{msg, sdbus, adopt_message}; + return _Msg{msg, connection, adopt_message}; } }; } diff --git a/src/Object.cpp b/src/Object.cpp index 63bb9c7..f0fce0f 100644 --- a/src/Object.cpp +++ b/src/Object.cpp @@ -83,12 +83,12 @@ void Object::unregister() objectManagerSlot_.reset(); } -sdbus::Signal Object::createSignal(const InterfaceName& interfaceName, const SignalName& signalName) +Signal Object::createSignal(const InterfaceName& interfaceName, const SignalName& signalName) const { return connection_.createSignal(objectPath_, interfaceName, signalName); } -sdbus::Signal Object::createSignal(const char* interfaceName, const char* signalName) +Signal Object::createSignal(const char* interfaceName, const char* signalName) const { return connection_.createSignal(objectPath_.c_str(), interfaceName, signalName); } @@ -326,7 +326,7 @@ int Object::sdbus_method_callback(sd_bus_message *sdbusMessage, void *userData, assert(vtable != nullptr); assert(vtable->object != nullptr); - auto message = Message::Factory::create(sdbusMessage, &vtable->object->connection_.getSdBusInterface()); + auto message = Message::Factory::create(sdbusMessage, &vtable->object->connection_); const auto* methodItem = findMethod(*vtable, message.getMemberName()); assert(methodItem != nullptr); @@ -359,7 +359,7 @@ int Object::sdbus_property_get_callback( sd_bus */*bus*/ return 1; } - auto reply = Message::Factory::create(sdbusReply, &vtable->object->connection_.getSdBusInterface()); + auto reply = Message::Factory::create(sdbusReply, &vtable->object->connection_); auto ok = invokeHandlerAndCatchErrors([&](){ propertyItem->getCallback(reply); }, retError); @@ -382,7 +382,7 @@ int Object::sdbus_property_set_callback( sd_bus */*bus*/ assert(propertyItem != nullptr); assert(propertyItem->setCallback); - auto value = Message::Factory::create(sdbusValue, &vtable->object->connection_.getSdBusInterface()); + auto value = Message::Factory::create(sdbusValue, &vtable->object->connection_); auto ok = invokeHandlerAndCatchErrors([&](){ propertyItem->setCallback(std::move(value)); }, retError); diff --git a/src/Object.h b/src/Object.h index e7d2727..2bf5eb7 100644 --- a/src/Object.h +++ b/src/Object.h @@ -53,8 +53,8 @@ namespace sdbus::internal { Slot addVTable(InterfaceName interfaceName, std::vector vtable, return_slot_t) override; void unregister() override; - sdbus::Signal createSignal(const InterfaceName& interfaceName, const SignalName& signalName) override; - sdbus::Signal createSignal(const char* interfaceName, const char* signalName) override; + Signal createSignal(const InterfaceName& interfaceName, const SignalName& signalName) const override; + Signal createSignal(const char* interfaceName, const char* signalName) const override; void emitSignal(const sdbus::Signal& message) override; void emitPropertiesChangedSignal(const InterfaceName& interfaceName, const std::vector& propNames) override; void emitPropertiesChangedSignal(const char* interfaceName, const std::vector& propNames) override; diff --git a/src/Proxy.cpp b/src/Proxy.cpp index 0cab73f..6f67d7e 100644 --- a/src/Proxy.cpp +++ b/src/Proxy.cpp @@ -84,12 +84,12 @@ Proxy::Proxy( std::unique_ptr&& connection // This proxy is meant to be created, used for simple synchronous D-Bus call(s) and then dismissed. } -MethodCall Proxy::createMethodCall(const InterfaceName& interfaceName, const MethodName& methodName) +MethodCall Proxy::createMethodCall(const InterfaceName& interfaceName, const MethodName& methodName) const { return connection_->createMethodCall(destination_, objectPath_, interfaceName, methodName); } -MethodCall Proxy::createMethodCall(const char* interfaceName, const char* methodName) +MethodCall Proxy::createMethodCall(const char* interfaceName, const char* methodName) const { return connection_->createMethodCall(destination_.c_str(), objectPath_.c_str(), interfaceName, methodName); } @@ -103,7 +103,7 @@ MethodReply Proxy::callMethod(const MethodCall& message, uint64_t timeout) { SDBUS_THROW_ERROR_IF(!message.isValid(), "Invalid method call message provided", EINVAL); - return connection_->callMethod(message, timeout); + return message.send(timeout); } PendingAsyncCall Proxy::callMethodAsync(const MethodCall& message, async_reply_handler asyncReplyCallback) @@ -124,11 +124,7 @@ PendingAsyncCall Proxy::callMethodAsync(const MethodCall& message, async_reply_h , .proxy = *this , .floating = false }); - asyncCallInfo->slot = connection_->callMethod( message - , (void*)&Proxy::sdbus_async_reply_handler - , asyncCallInfo.get() - , timeout - , return_slot ); + asyncCallInfo->slot = message.send((void*)&Proxy::sdbus_async_reply_handler, asyncCallInfo.get(), timeout, return_slot); auto asyncCallInfoWeakPtr = std::weak_ptr{asyncCallInfo}; @@ -145,11 +141,7 @@ Slot Proxy::callMethodAsync(const MethodCall& message, async_reply_handler async , .proxy = *this , .floating = true }); - asyncCallInfo->slot = connection_->callMethod( message - , (void*)&Proxy::sdbus_async_reply_handler - , asyncCallInfo.get() - , timeout - , return_slot ); + asyncCallInfo->slot = message.send((void*)&Proxy::sdbus_async_reply_handler, asyncCallInfo.get(), timeout, return_slot); return {asyncCallInfo.release(), [](void *ptr){ delete static_cast(ptr); }}; } @@ -259,7 +251,7 @@ int Proxy::sdbus_async_reply_handler(sd_bus_message *sdbusMessage, void *userDat proxy.floatingAsyncCallSlots_.erase(asyncCallInfo); }; - auto message = Message::Factory::create(sdbusMessage, &proxy.connection_->getSdBusInterface()); + auto message = Message::Factory::create(sdbusMessage, proxy.connection_.get()); auto ok = invokeHandlerAndCatchErrors([&] { @@ -284,8 +276,7 @@ int Proxy::sdbus_signal_handler(sd_bus_message *sdbusMessage, void *userData, sd assert(signalInfo != nullptr); assert(signalInfo->callback); - // TODO: Hide Message factory invocation under Connection API (tell, don't ask principle), then we can remove getSdBusInterface() - auto message = Message::Factory::create(sdbusMessage, &signalInfo->proxy.connection_->getSdBusInterface()); + auto message = Message::Factory::create(sdbusMessage, signalInfo->proxy.connection_.get()); auto ok = invokeHandlerAndCatchErrors([&](){ signalInfo->callback(std::move(message)); }, retError); diff --git a/src/Proxy.h b/src/Proxy.h index 4ee2047..bd42dcf 100644 --- a/src/Proxy.h +++ b/src/Proxy.h @@ -56,8 +56,8 @@ namespace sdbus::internal { , ObjectPath objectPath , dont_run_event_loop_thread_t ); - MethodCall createMethodCall(const InterfaceName& interfaceName, const MethodName& methodName) override; - MethodCall createMethodCall(const char* interfaceName, const char* methodName) override; + MethodCall createMethodCall(const InterfaceName& interfaceName, const MethodName& methodName) const override; + MethodCall createMethodCall(const char* interfaceName, const char* methodName) const override; MethodReply callMethod(const MethodCall& message) override; MethodReply callMethod(const MethodCall& message, uint64_t timeout) override; PendingAsyncCall callMethodAsync(const MethodCall& message, async_reply_handler asyncReplyCallback) override; diff --git a/src/SdBus.cpp b/src/SdBus.cpp index 7828400..f42760b 100644 --- a/src/SdBus.cpp +++ b/src/SdBus.cpp @@ -454,6 +454,13 @@ int SdBus::sd_bus_query_sender_creds(sd_bus_message *m, uint64_t mask, sd_bus_cr return ::sd_bus_query_sender_creds(m, mask, c); } +sd_bus_creds* SdBus::sd_bus_creds_ref(sd_bus_creds *c) +{ + std::lock_guard lock(sdbusMutex_); + + return ::sd_bus_creds_ref(c); +} + sd_bus_creds* SdBus::sd_bus_creds_unref(sd_bus_creds *c) { std::lock_guard lock(sdbusMutex_); diff --git a/src/SdBus.h b/src/SdBus.h index ebe903d..87c459e 100644 --- a/src/SdBus.h +++ b/src/SdBus.h @@ -90,6 +90,7 @@ public: virtual int sd_bus_message_set_destination(sd_bus_message *m, const char *destination) override; virtual int sd_bus_query_sender_creds(sd_bus_message *m, uint64_t mask, sd_bus_creds **c) override; + virtual sd_bus_creds* sd_bus_creds_ref(sd_bus_creds *c) override; virtual sd_bus_creds* sd_bus_creds_unref(sd_bus_creds *c) override; virtual int sd_bus_creds_get_pid(sd_bus_creds *c, pid_t *pid) override; diff --git a/tests/unittests/mocks/SdBusMock.h b/tests/unittests/mocks/SdBusMock.h index 8216705..f20f4df 100644 --- a/tests/unittests/mocks/SdBusMock.h +++ b/tests/unittests/mocks/SdBusMock.h @@ -89,6 +89,7 @@ public: MOCK_METHOD2(sd_bus_message_set_destination, int(sd_bus_message *m, const char *destination)); MOCK_METHOD3(sd_bus_query_sender_creds, int(sd_bus_message *, uint64_t, sd_bus_creds **)); + MOCK_METHOD1(sd_bus_creds_ref, sd_bus_creds*(sd_bus_creds *)); MOCK_METHOD1(sd_bus_creds_unref, sd_bus_creds*(sd_bus_creds *)); MOCK_METHOD2(sd_bus_creds_get_pid, int(sd_bus_creds *, pid_t *));