feat: add API to get message credentials (#151)

* sdbus-cpp: Add API to get message credentials

Signed-off-by: Alexander Livenets <a.livenets@gmail.com>

* fix: add <sys/types.h> include for gid_t and other types

Co-authored-by: Stanislav Angelovič <angelovic.s@gmail.com>
This commit is contained in:
alivenets
2021-03-12 14:14:23 +01:00
committed by GitHub
parent 3f74512f8e
commit 5e03e78451
6 changed files with 213 additions and 0 deletions

View File

@ -37,6 +37,7 @@
#include <cstdint>
#include <cassert>
#include <functional>
#include <sys/types.h>
// Forward declarations
namespace sdbus {
@ -138,6 +139,14 @@ namespace sdbus {
void seal();
void rewind(bool complete);
pid_t getCredsPid() const;
uid_t getCredsUid() const;
uid_t getCredsEuid() const;
gid_t getCredsGid() const;
gid_t getCredsEgid() const;
std::vector<gid_t> getCredsSupplementaryGids() const;
std::string getSELinuxContext() const;
class Factory;
protected:

View File

@ -82,6 +82,17 @@ namespace sdbus::internal {
virtual int sd_bus_flush(sd_bus *bus) = 0;
virtual sd_bus *sd_bus_flush_close_unref(sd_bus *bus) = 0;
virtual int sd_bus_query_sender_creds(sd_bus_message *m, uint64_t mask, sd_bus_creds **c) = 0;
virtual sd_bus_creds* sd_bus_creds_unref(sd_bus_creds *c) = 0;
virtual int sd_bus_creds_get_pid(sd_bus_creds *c, pid_t *pid) = 0;
virtual int sd_bus_creds_get_uid(sd_bus_creds *c, uid_t *uid) = 0;
virtual int sd_bus_creds_get_euid(sd_bus_creds *c, uid_t *uid) = 0;
virtual int sd_bus_creds_get_gid(sd_bus_creds *c, gid_t *gid) = 0;
virtual int sd_bus_creds_get_egid(sd_bus_creds *c, gid_t *egid) = 0;
virtual int sd_bus_creds_get_supplementary_gids(sd_bus_creds *c, const gid_t **gids) = 0;
virtual int sd_bus_creds_get_selinux_context(sd_bus_creds *c, const char **label) = 0;
};
}

View File

@ -616,6 +616,113 @@ bool Message::isEmpty() const
return sd_bus_message_is_empty((sd_bus_message*)msg_) != 0;
}
pid_t Message::getCredsPid() const
{
uint64_t mask = SD_BUS_CREDS_PID | SD_BUS_CREDS_AUGMENT;
sd_bus_creds *creds = nullptr;
SCOPE_EXIT{ sdbus_->sd_bus_creds_unref(creds); };
int r = sdbus_->sd_bus_query_sender_creds((sd_bus_message*)msg_, mask, &creds);
SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus creds", -r);
pid_t pid = 0;
r = sdbus_->sd_bus_creds_get_pid(creds, &pid);
SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus cred pid", -r);
return pid;
}
uid_t Message::getCredsUid() const
{
uint64_t mask = SD_BUS_CREDS_UID | SD_BUS_CREDS_AUGMENT;
sd_bus_creds *creds = nullptr;
SCOPE_EXIT{ sdbus_->sd_bus_creds_unref(creds); };
int r = sdbus_->sd_bus_query_sender_creds((sd_bus_message*)msg_, mask, &creds);
SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus creds", -r);
uid_t uid = (uid_t)-1;
r = sdbus_->sd_bus_creds_get_uid(creds, &uid);
SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus cred uid", -r);
return uid;
}
uid_t Message::getCredsEuid() const
{
uint64_t mask = SD_BUS_CREDS_EUID | SD_BUS_CREDS_AUGMENT;
sd_bus_creds *creds = nullptr;
SCOPE_EXIT{ sdbus_->sd_bus_creds_unref(creds); };
int r = sdbus_->sd_bus_query_sender_creds((sd_bus_message*)msg_, mask, &creds);
SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus creds", -r);
uid_t euid = (uid_t)-1;
r = sdbus_->sd_bus_creds_get_euid(creds, &euid);
SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus cred euid", -r);
return euid;
}
gid_t Message::getCredsGid() const
{
uint64_t mask = SD_BUS_CREDS_GID | SD_BUS_CREDS_AUGMENT;
sd_bus_creds *creds = nullptr;
SCOPE_EXIT{ sdbus_->sd_bus_creds_unref(creds); };
int r = sdbus_->sd_bus_query_sender_creds((sd_bus_message*)msg_, mask, &creds);
SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus creds", -r);
gid_t gid = (gid_t)-1;
r = sdbus_->sd_bus_creds_get_gid(creds, &gid);
SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus cred gid", -r);
return gid;
}
gid_t Message::getCredsEgid() const
{
uint64_t mask = SD_BUS_CREDS_EGID | SD_BUS_CREDS_AUGMENT;
sd_bus_creds *creds = nullptr;
SCOPE_EXIT{ sdbus_->sd_bus_creds_unref(creds); };
int r = sdbus_->sd_bus_query_sender_creds((sd_bus_message*)msg_, mask, &creds);
SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus creds", -r);
gid_t egid = (gid_t)-1;
r = sdbus_->sd_bus_creds_get_egid(creds, &egid);
SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus cred egid", -r);
return egid;
}
std::vector<gid_t> Message::getCredsSupplementaryGids() const
{
uint64_t mask = SD_BUS_CREDS_SUPPLEMENTARY_GIDS | SD_BUS_CREDS_AUGMENT;
sd_bus_creds *creds = nullptr;
SCOPE_EXIT{ sdbus_->sd_bus_creds_unref(creds); };
int r = sdbus_->sd_bus_query_sender_creds((sd_bus_message*)msg_, mask, &creds);
SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus creds", -r);
const gid_t *cGids = nullptr;
r = sdbus_->sd_bus_creds_get_supplementary_gids(creds, &cGids);
SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus cred supplementary gids", -r);
std::vector<gid_t> gids{};
if (cGids != nullptr)
{
for (int i = 0; i < r; i++)
gids.push_back(cGids[i]);
}
return gids;
}
std::string Message::getSELinuxContext() const
{
uint64_t mask = SD_BUS_CREDS_AUGMENT | SD_BUS_CREDS_SELINUX_CONTEXT;
sd_bus_creds *creds = nullptr;
SCOPE_EXIT{ sdbus_->sd_bus_creds_unref(creds); };
int r = sdbus_->sd_bus_query_sender_creds((sd_bus_message*)msg_, mask, &creds);
SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus creds", -r);
const char *cLabel = nullptr;
r = sdbus_->sd_bus_creds_get_selinux_context(creds, &cLabel);
SDBUS_THROW_ERROR_IF(r < 0, "Failed to get bus cred selinux context", -r);
return cLabel;
}
void MethodCall::dontExpectReply()
{
auto r = sd_bus_message_set_expect_reply((sd_bus_message*)msg_, 0);
@ -717,4 +824,5 @@ PlainMessage createPlainMessage()
return connection->createPlainMessage();
}
}

View File

@ -260,4 +260,67 @@ sd_bus* SdBus::sd_bus_flush_close_unref(sd_bus *bus)
return ::sd_bus_flush_close_unref(bus);
}
int SdBus::sd_bus_query_sender_creds(sd_bus_message *m, uint64_t mask, sd_bus_creds **c)
{
std::lock_guard lock(sdbusMutex_);
return ::sd_bus_query_sender_creds(m, mask, c);
}
sd_bus_creds* SdBus::sd_bus_creds_unref(sd_bus_creds *c)
{
std::lock_guard lock(sdbusMutex_);
return ::sd_bus_creds_unref(c);
}
int SdBus::sd_bus_creds_get_pid(sd_bus_creds *c, pid_t *pid)
{
std::lock_guard lock(sdbusMutex_);
return ::sd_bus_creds_get_pid(c, pid);
}
int SdBus::sd_bus_creds_get_uid(sd_bus_creds *c, uid_t *uid)
{
std::lock_guard lock(sdbusMutex_);
return ::sd_bus_creds_get_uid(c, uid);
}
int SdBus::sd_bus_creds_get_euid(sd_bus_creds *c, uid_t *euid)
{
std::lock_guard lock(sdbusMutex_);
return ::sd_bus_creds_get_euid(c, euid);
}
int SdBus::sd_bus_creds_get_gid(sd_bus_creds *c, gid_t *gid)
{
std::lock_guard lock(sdbusMutex_);
return ::sd_bus_creds_get_gid(c, gid);
}
int SdBus::sd_bus_creds_get_egid(sd_bus_creds *c, uid_t *egid)
{
std::lock_guard lock(sdbusMutex_);
return ::sd_bus_creds_get_egid(c, egid);
}
int SdBus::sd_bus_creds_get_supplementary_gids(sd_bus_creds *c, const gid_t **gids)
{
std::lock_guard lock(sdbusMutex_);
return ::sd_bus_creds_get_supplementary_gids(c, gids);
}
int SdBus::sd_bus_creds_get_selinux_context(sd_bus_creds *c, const char **label)
{
std::lock_guard lock(sdbusMutex_);
return ::sd_bus_creds_get_selinux_context(c, label);
}
}

View File

@ -75,6 +75,17 @@ public:
virtual int sd_bus_flush(sd_bus *bus) override;
virtual sd_bus *sd_bus_flush_close_unref(sd_bus *bus) override;
virtual int sd_bus_query_sender_creds(sd_bus_message *m, uint64_t mask, sd_bus_creds **c) override;
virtual sd_bus_creds* sd_bus_creds_unref(sd_bus_creds *c) override;
virtual int sd_bus_creds_get_pid(sd_bus_creds *c, pid_t *pid) override;
virtual int sd_bus_creds_get_uid(sd_bus_creds *c, uid_t *uid) override;
virtual int sd_bus_creds_get_euid(sd_bus_creds *c, uid_t *euid) override;
virtual int sd_bus_creds_get_gid(sd_bus_creds *c, gid_t *gid) override;
virtual int sd_bus_creds_get_egid(sd_bus_creds *c, gid_t *egid) override;
virtual int sd_bus_creds_get_supplementary_gids(sd_bus_creds *c, const gid_t **gids) override;
virtual int sd_bus_creds_get_selinux_context(sd_bus_creds *c, const char **label) override;
private:
std::recursive_mutex sdbusMutex_;
};

View File

@ -73,6 +73,17 @@ public:
MOCK_METHOD1(sd_bus_flush, int(sd_bus *bus));
MOCK_METHOD1(sd_bus_flush_close_unref, sd_bus *(sd_bus *bus));
MOCK_METHOD3(sd_bus_query_sender_creds, int(sd_bus_message *, uint64_t, sd_bus_creds **));
MOCK_METHOD1(sd_bus_creds_unref, sd_bus_creds*(sd_bus_creds *));
MOCK_METHOD2(sd_bus_creds_get_pid, int(sd_bus_creds *, pid_t *));
MOCK_METHOD2(sd_bus_creds_get_uid, int(sd_bus_creds *, uid_t *));
MOCK_METHOD2(sd_bus_creds_get_euid, int(sd_bus_creds *, uid_t *));
MOCK_METHOD2(sd_bus_creds_get_gid, int(sd_bus_creds *, gid_t *));
MOCK_METHOD2(sd_bus_creds_get_egid, int(sd_bus_creds *, gid_t *));
MOCK_METHOD2(sd_bus_creds_get_supplementary_gids, int(sd_bus_creds *, const gid_t **));
MOCK_METHOD2(sd_bus_creds_get_selinux_context, int(sd_bus_creds *, const char **));
};
#endif //SDBUS_CXX_SDBUS_MOCK_H