]> git.ipfire.org Git - thirdparty/systemd.git/blobdiff - src/libsystemd/sd-netlink/netlink-socket.c
Add SPDX license identifiers to source files under the LGPL
[thirdparty/systemd.git] / src / libsystemd / sd-netlink / netlink-socket.c
index 84ff7c38c925e7b5dca7fb0571cd53d32512dcba..22be94382a79c9b3a39e0753f20154520598ab9d 100644 (file)
@@ -1,5 +1,4 @@
-/*-*- Mode: C; c-basic-offset: 8; indent-tabs-mode: nil -*-*/
-
+/* SPDX-License-Identifier: LGPL-2.1+ */
 /***
   This file is part of systemd.
 
 #include <stdbool.h>
 #include <unistd.h>
 
-#include "util.h"
-#include "socket-util.h"
-#include "formats-util.h"
-#include "refcnt.h"
-#include "missing.h"
-
 #include "sd-netlink.h"
-#include "netlink-util.h"
+
+#include "alloc-util.h"
+#include "format-util.h"
+#include "missing.h"
 #include "netlink-internal.h"
 #include "netlink-types.h"
+#include "netlink-util.h"
+#include "refcnt.h"
+#include "socket-util.h"
+#include "util.h"
 
 int socket_open(int family) {
         int fd;
@@ -44,6 +44,65 @@ int socket_open(int family) {
         return fd;
 }
 
+static int broadcast_groups_get(sd_netlink *nl) {
+        _cleanup_free_ uint32_t *groups = NULL;
+        socklen_t len = 0, old_len;
+        unsigned i, j;
+        int r;
+
+        assert(nl);
+        assert(nl->fd >= 0);
+
+        r = getsockopt(nl->fd, SOL_NETLINK, NETLINK_LIST_MEMBERSHIPS, NULL, &len);
+        if (r < 0) {
+                if (errno == ENOPROTOOPT) {
+                        nl->broadcast_group_dont_leave = true;
+                        return 0;
+                } else
+                        return -errno;
+        }
+
+        if (len == 0)
+                return 0;
+
+        groups = new0(uint32_t, len);
+        if (!groups)
+                return -ENOMEM;
+
+        old_len = len;
+
+        r = getsockopt(nl->fd, SOL_NETLINK, NETLINK_LIST_MEMBERSHIPS, groups, &len);
+        if (r < 0)
+                return -errno;
+
+        if (old_len != len)
+                return -EIO;
+
+        r = hashmap_ensure_allocated(&nl->broadcast_group_refs, NULL);
+        if (r < 0)
+                return r;
+
+        for (i = 0; i < len; i++) {
+                for (j = 0; j < sizeof(uint32_t) * 8; j++) {
+                        uint32_t offset;
+                        unsigned group;
+
+                        offset = 1U << j;
+
+                        if (!(groups[i] & offset))
+                                continue;
+
+                        group = i * sizeof(uint32_t) * 8 + j + 1;
+
+                        r = hashmap_put(nl->broadcast_group_refs, UINT_TO_PTR(group), UINT_TO_PTR(1));
+                        if (r < 0)
+                                return r;
+                }
+        }
+
+        return 0;
+}
+
 int socket_bind(sd_netlink *nl) {
         socklen_t addrlen;
         int r, one = 1;
@@ -63,11 +122,32 @@ int socket_bind(sd_netlink *nl) {
         if (r < 0)
                 return -errno;
 
+        r = broadcast_groups_get(nl);
+        if (r < 0)
+                return r;
+
         return 0;
 }
 
+static unsigned broadcast_group_get_ref(sd_netlink *nl, unsigned group) {
+        assert(nl);
+
+        return PTR_TO_UINT(hashmap_get(nl->broadcast_group_refs, UINT_TO_PTR(group)));
+}
 
-int socket_join_broadcast_group(sd_netlink *nl, unsigned group) {
+static int broadcast_group_set_ref(sd_netlink *nl, unsigned group, unsigned n_ref) {
+        int r;
+
+        assert(nl);
+
+        r = hashmap_replace(nl->broadcast_group_refs, UINT_TO_PTR(group), UINT_TO_PTR(n_ref));
+        if (r < 0)
+                return r;
+
+        return 0;
+}
+
+static int broadcast_group_join(sd_netlink *nl, unsigned group) {
         int r;
 
         assert(nl);
@@ -81,6 +161,79 @@ int socket_join_broadcast_group(sd_netlink *nl, unsigned group) {
         return 0;
 }
 
+int socket_broadcast_group_ref(sd_netlink *nl, unsigned group) {
+        unsigned n_ref;
+        int r;
+
+        assert(nl);
+
+        n_ref = broadcast_group_get_ref(nl, group);
+
+        n_ref++;
+
+        r = hashmap_ensure_allocated(&nl->broadcast_group_refs, NULL);
+        if (r < 0)
+                return r;
+
+        r = broadcast_group_set_ref(nl, group, n_ref);
+        if (r < 0)
+                return r;
+
+        if (n_ref > 1)
+                /* not yet in the group */
+                return 0;
+
+        r = broadcast_group_join(nl, group);
+        if (r < 0)
+                return r;
+
+        return 0;
+}
+
+static int broadcast_group_leave(sd_netlink *nl, unsigned group) {
+        int r;
+
+        assert(nl);
+        assert(nl->fd >= 0);
+        assert(group > 0);
+
+        if (nl->broadcast_group_dont_leave)
+                return 0;
+
+        r = setsockopt(nl->fd, SOL_NETLINK, NETLINK_DROP_MEMBERSHIP, &group, sizeof(group));
+        if (r < 0)
+                return -errno;
+
+        return 0;
+}
+
+int socket_broadcast_group_unref(sd_netlink *nl, unsigned group) {
+        unsigned n_ref;
+        int r;
+
+        assert(nl);
+
+        n_ref = broadcast_group_get_ref(nl, group);
+
+        assert(n_ref > 0);
+
+        n_ref--;
+
+        r = broadcast_group_set_ref(nl, group, n_ref);
+        if (r < 0)
+                return r;
+
+        if (n_ref > 0)
+                /* still refs left */
+                return 0;
+
+        r = broadcast_group_leave(nl, group);
+        if (r < 0)
+                return r;
+
+        return 0;
+}
+
 /* returns the number of bytes sent, or a negative error code */
 int socket_write_message(sd_netlink *nl, sd_netlink_message *m) {
         union {
@@ -129,7 +282,7 @@ static int socket_recv_message(int fd, struct iovec *iov, uint32_t *_group, bool
                 else if (errno == EAGAIN)
                         log_debug("rtnl: no data in socket");
 
-                return (errno == EAGAIN || errno == EINTR) ? 0 : -errno;
+                return IN_SET(errno, EAGAIN, EINTR) ? 0 : -errno;
         }
 
         if (sender.nl.nl_pid != 0) {
@@ -140,7 +293,7 @@ static int socket_recv_message(int fd, struct iovec *iov, uint32_t *_group, bool
                         /* drop the message */
                         r = recvmsg(fd, &msg, 0);
                         if (r < 0)
-                                return (errno == EAGAIN || errno == EINTR) ? 0 : -errno;
+                                return IN_SET(errno, EAGAIN, EINTR) ? 0 : -errno;
                 }
 
                 return 0;
@@ -169,7 +322,7 @@ static int socket_recv_message(int fd, struct iovec *iov, uint32_t *_group, bool
  * On failure, a negative error code is returned.
  */
 int socket_read_message(sd_netlink *rtnl) {
-        _cleanup_netlink_message_unref_ sd_netlink_message *first = NULL;
+        _cleanup_(sd_netlink_message_unrefp) sd_netlink_message *first = NULL;
         struct iovec iov = {};
         uint32_t group = 0;
         bool multi_part = false, done = false;
@@ -222,7 +375,7 @@ int socket_read_message(sd_netlink *rtnl) {
         }
 
         for (new_msg = rtnl->rbuffer; NLMSG_OK(new_msg, len) && !done; new_msg = NLMSG_NEXT(new_msg, len)) {
-                _cleanup_netlink_message_unref_ sd_netlink_message *m = NULL;
+                _cleanup_(sd_netlink_message_unrefp) sd_netlink_message *m = NULL;
                 const NLType *nl_type;
 
                 if (!group && new_msg->nlmsg_pid != rtnl->sockaddr.nl.nl_pid)
@@ -292,14 +445,14 @@ int socket_read_message(sd_netlink *rtnl) {
                 if (r < 0)
                         return r;
 
-                rtnl->rqueue[rtnl->rqueue_size ++] = first;
+                rtnl->rqueue[rtnl->rqueue_size++] = first;
                 first = NULL;
 
                 if (multi_part && (i < rtnl->rqueue_partial_size)) {
                         /* remove the message form the partial read queue */
                         memmove(rtnl->rqueue_partial + i,rtnl->rqueue_partial + i + 1,
                                 sizeof(sd_netlink_message*) * (rtnl->rqueue_partial_size - i - 1));
-                        rtnl->rqueue_partial_size --;
+                        rtnl->rqueue_partial_size--;
                 }
 
                 return 1;
@@ -313,7 +466,7 @@ int socket_read_message(sd_netlink *rtnl) {
                         if (r < 0)
                                 return r;
 
-                        rtnl->rqueue_partial[rtnl->rqueue_partial_size ++] = first;
+                        rtnl->rqueue_partial[rtnl->rqueue_partial_size++] = first;
                 }
                 first = NULL;