diff --git a/include/sdbus-c++/Message.h b/include/sdbus-c++/Message.h index f252401..133a7bc 100644 --- a/include/sdbus-c++/Message.h +++ b/include/sdbus-c++/Message.h @@ -37,6 +37,7 @@ #include #include #include +#include // Forward declarations namespace sdbus { @@ -138,6 +139,14 @@ namespace sdbus { void seal(); void rewind(bool complete); + pid_t getCredsPid() const; + uid_t getCredsUid() const; + uid_t getCredsEuid() const; + gid_t getCredsGid() const; + gid_t getCredsEgid() const; + std::vector getCredsSupplementaryGids() const; + std::string getSELinuxContext() const; + class Factory; protected: diff --git a/src/ISdBus.h b/src/ISdBus.h index 5044b06..ae40ba1 100644 --- a/src/ISdBus.h +++ b/src/ISdBus.h @@ -82,6 +82,17 @@ namespace sdbus::internal { virtual int sd_bus_flush(sd_bus *bus) = 0; virtual sd_bus *sd_bus_flush_close_unref(sd_bus *bus) = 0; + + virtual int sd_bus_query_sender_creds(sd_bus_message *m, uint64_t mask, sd_bus_creds **c) = 0; + virtual sd_bus_creds* sd_bus_creds_unref(sd_bus_creds *c) = 0; + + virtual int sd_bus_creds_get_pid(sd_bus_creds *c, pid_t *pid) = 0; + virtual int sd_bus_creds_get_uid(sd_bus_creds *c, uid_t *uid) = 0; + virtual int sd_bus_creds_get_euid(sd_bus_creds *c, uid_t *uid) = 0; + virtual int sd_bus_creds_get_gid(sd_bus_creds *c, gid_t *gid) = 0; + virtual int sd_bus_creds_get_egid(sd_bus_creds *c, gid_t *egid) = 0; + virtual int sd_bus_creds_get_supplementary_gids(sd_bus_creds *c, const gid_t **gids) = 0; + virtual int sd_bus_creds_get_selinux_context(sd_bus_creds *c, const char **label) = 0; }; } diff --git a/src/Message.cpp b/src/Message.cpp index 4c99541..45f8f36 100644 --- a/src/Message.cpp +++ b/src/Message.cpp @@ -616,6 +616,113 @@ bool Message::isEmpty() const return sd_bus_message_is_empty((sd_bus_message*)msg_) != 0; } +pid_t Message::getCredsPid() const +{ + uint64_t mask = SD_BUS_CREDS_PID | SD_BUS_CREDS_AUGMENT; + sd_bus_creds *creds = nullptr; + SCOPE_EXIT{ sdbus_->sd_bus_creds_unref(creds); }; + + int r = sdbus_->sd_bus_query_sender_creds((sd_bus_message*)msg_, mask, &creds); + SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus creds", -r); + + pid_t pid = 0; + r = sdbus_->sd_bus_creds_get_pid(creds, &pid); + SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus cred pid", -r); + return pid; +} + +uid_t Message::getCredsUid() const +{ + uint64_t mask = SD_BUS_CREDS_UID | SD_BUS_CREDS_AUGMENT; + sd_bus_creds *creds = nullptr; + SCOPE_EXIT{ sdbus_->sd_bus_creds_unref(creds); }; + int r = sdbus_->sd_bus_query_sender_creds((sd_bus_message*)msg_, mask, &creds); + SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus creds", -r); + + uid_t uid = (uid_t)-1; + r = sdbus_->sd_bus_creds_get_uid(creds, &uid); + SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus cred uid", -r); + return uid; +} + +uid_t Message::getCredsEuid() const +{ + uint64_t mask = SD_BUS_CREDS_EUID | SD_BUS_CREDS_AUGMENT; + sd_bus_creds *creds = nullptr; + SCOPE_EXIT{ sdbus_->sd_bus_creds_unref(creds); }; + int r = sdbus_->sd_bus_query_sender_creds((sd_bus_message*)msg_, mask, &creds); + SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus creds", -r); + + uid_t euid = (uid_t)-1; + r = sdbus_->sd_bus_creds_get_euid(creds, &euid); + SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus cred euid", -r); + return euid; +} + +gid_t Message::getCredsGid() const +{ + uint64_t mask = SD_BUS_CREDS_GID | SD_BUS_CREDS_AUGMENT; + sd_bus_creds *creds = nullptr; + SCOPE_EXIT{ sdbus_->sd_bus_creds_unref(creds); }; + int r = sdbus_->sd_bus_query_sender_creds((sd_bus_message*)msg_, mask, &creds); + SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus creds", -r); + + gid_t gid = (gid_t)-1; + r = sdbus_->sd_bus_creds_get_gid(creds, &gid); + SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus cred gid", -r); + return gid; +} + +gid_t Message::getCredsEgid() const +{ + uint64_t mask = SD_BUS_CREDS_EGID | SD_BUS_CREDS_AUGMENT; + sd_bus_creds *creds = nullptr; + SCOPE_EXIT{ sdbus_->sd_bus_creds_unref(creds); }; + int r = sdbus_->sd_bus_query_sender_creds((sd_bus_message*)msg_, mask, &creds); + SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus creds", -r); + + gid_t egid = (gid_t)-1; + r = sdbus_->sd_bus_creds_get_egid(creds, &egid); + SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus cred egid", -r); + return egid; +} + +std::vector Message::getCredsSupplementaryGids() const +{ + uint64_t mask = SD_BUS_CREDS_SUPPLEMENTARY_GIDS | SD_BUS_CREDS_AUGMENT; + sd_bus_creds *creds = nullptr; + SCOPE_EXIT{ sdbus_->sd_bus_creds_unref(creds); }; + int r = sdbus_->sd_bus_query_sender_creds((sd_bus_message*)msg_, mask, &creds); + SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus creds", -r); + + const gid_t *cGids = nullptr; + r = sdbus_->sd_bus_creds_get_supplementary_gids(creds, &cGids); + SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus cred supplementary gids", -r); + + std::vector gids{}; + if (cGids != nullptr) + { + for (int i = 0; i < r; i++) + gids.push_back(cGids[i]); + } + + return gids; +} + +std::string Message::getSELinuxContext() const +{ + uint64_t mask = SD_BUS_CREDS_AUGMENT | SD_BUS_CREDS_SELINUX_CONTEXT; + sd_bus_creds *creds = nullptr; + SCOPE_EXIT{ sdbus_->sd_bus_creds_unref(creds); }; + int r = sdbus_->sd_bus_query_sender_creds((sd_bus_message*)msg_, mask, &creds); + SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus creds", -r); + + const char *cLabel = nullptr; + r = sdbus_->sd_bus_creds_get_selinux_context(creds, &cLabel); + SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus cred selinux context", -r); + return cLabel; +} + void MethodCall::dontExpectReply() { auto r = sd_bus_message_set_expect_reply((sd_bus_message*)msg_, 0); @@ -717,4 +824,5 @@ PlainMessage createPlainMessage() return connection->createPlainMessage(); } + } diff --git a/src/SdBus.cpp b/src/SdBus.cpp index 3bc689c..60a622f 100644 --- a/src/SdBus.cpp +++ b/src/SdBus.cpp @@ -260,4 +260,67 @@ sd_bus* SdBus::sd_bus_flush_close_unref(sd_bus *bus) return ::sd_bus_flush_close_unref(bus); } +int SdBus::sd_bus_query_sender_creds(sd_bus_message *m, uint64_t mask, sd_bus_creds **c) +{ + std::lock_guard lock(sdbusMutex_); + + return ::sd_bus_query_sender_creds(m, mask, c); +} + +sd_bus_creds* SdBus::sd_bus_creds_unref(sd_bus_creds *c) +{ + std::lock_guard lock(sdbusMutex_); + + return ::sd_bus_creds_unref(c); +} + +int SdBus::sd_bus_creds_get_pid(sd_bus_creds *c, pid_t *pid) +{ + std::lock_guard lock(sdbusMutex_); + + return ::sd_bus_creds_get_pid(c, pid); +} + +int SdBus::sd_bus_creds_get_uid(sd_bus_creds *c, uid_t *uid) +{ + std::lock_guard lock(sdbusMutex_); + + return ::sd_bus_creds_get_uid(c, uid); +} + +int SdBus::sd_bus_creds_get_euid(sd_bus_creds *c, uid_t *euid) +{ + std::lock_guard lock(sdbusMutex_); + + return ::sd_bus_creds_get_euid(c, euid); +} + +int SdBus::sd_bus_creds_get_gid(sd_bus_creds *c, gid_t *gid) +{ + std::lock_guard lock(sdbusMutex_); + + return ::sd_bus_creds_get_gid(c, gid); +} + +int SdBus::sd_bus_creds_get_egid(sd_bus_creds *c, uid_t *egid) +{ + std::lock_guard lock(sdbusMutex_); + + return ::sd_bus_creds_get_egid(c, egid); +} + +int SdBus::sd_bus_creds_get_supplementary_gids(sd_bus_creds *c, const gid_t **gids) +{ + std::lock_guard lock(sdbusMutex_); + + return ::sd_bus_creds_get_supplementary_gids(c, gids); +} + +int SdBus::sd_bus_creds_get_selinux_context(sd_bus_creds *c, const char **label) +{ + std::lock_guard lock(sdbusMutex_); + + return ::sd_bus_creds_get_selinux_context(c, label); +} + } diff --git a/src/SdBus.h b/src/SdBus.h index ebc0a3c..5ed4161 100644 --- a/src/SdBus.h +++ b/src/SdBus.h @@ -75,6 +75,17 @@ public: virtual int sd_bus_flush(sd_bus *bus) override; virtual sd_bus *sd_bus_flush_close_unref(sd_bus *bus) override; + virtual int sd_bus_query_sender_creds(sd_bus_message *m, uint64_t mask, sd_bus_creds **c) override; + virtual sd_bus_creds* sd_bus_creds_unref(sd_bus_creds *c) override; + + virtual int sd_bus_creds_get_pid(sd_bus_creds *c, pid_t *pid) override; + virtual int sd_bus_creds_get_uid(sd_bus_creds *c, uid_t *uid) override; + virtual int sd_bus_creds_get_euid(sd_bus_creds *c, uid_t *euid) override; + virtual int sd_bus_creds_get_gid(sd_bus_creds *c, gid_t *gid) override; + virtual int sd_bus_creds_get_egid(sd_bus_creds *c, gid_t *egid) override; + virtual int sd_bus_creds_get_supplementary_gids(sd_bus_creds *c, const gid_t **gids) override; + virtual int sd_bus_creds_get_selinux_context(sd_bus_creds *c, const char **label) override; + private: std::recursive_mutex sdbusMutex_; }; diff --git a/tests/unittests/mocks/SdBusMock.h b/tests/unittests/mocks/SdBusMock.h index 2a81118..6062436 100644 --- a/tests/unittests/mocks/SdBusMock.h +++ b/tests/unittests/mocks/SdBusMock.h @@ -73,6 +73,17 @@ public: MOCK_METHOD1(sd_bus_flush, int(sd_bus *bus)); MOCK_METHOD1(sd_bus_flush_close_unref, sd_bus *(sd_bus *bus)); + + MOCK_METHOD3(sd_bus_query_sender_creds, int(sd_bus_message *, uint64_t, sd_bus_creds **)); + MOCK_METHOD1(sd_bus_creds_unref, sd_bus_creds*(sd_bus_creds *)); + + MOCK_METHOD2(sd_bus_creds_get_pid, int(sd_bus_creds *, pid_t *)); + MOCK_METHOD2(sd_bus_creds_get_uid, int(sd_bus_creds *, uid_t *)); + MOCK_METHOD2(sd_bus_creds_get_euid, int(sd_bus_creds *, uid_t *)); + MOCK_METHOD2(sd_bus_creds_get_gid, int(sd_bus_creds *, gid_t *)); + MOCK_METHOD2(sd_bus_creds_get_egid, int(sd_bus_creds *, gid_t *)); + MOCK_METHOD2(sd_bus_creds_get_supplementary_gids, int(sd_bus_creds *, const gid_t **)); + MOCK_METHOD2(sd_bus_creds_get_selinux_context, int(sd_bus_creds *, const char **)); }; #endif //SDBUS_CXX_SDBUS_MOCK_H