]> git.ipfire.org Git - thirdparty/systemd.git/blobdiff - src/libsystemd/sd-netlink/sd-netlink.c
Add SPDX license identifiers to source files under the LGPL
[thirdparty/systemd.git] / src / libsystemd / sd-netlink / sd-netlink.c
index 2adc4499b6376d79ce30b9bf24826356e3b7a1ee..924b0c954fdfdc3c81b2810db157ba45b348e130 100644 (file)
@@ -1,5 +1,4 @@
-/*-*- Mode: C; c-basic-offset: 8; indent-tabs-mode: nil -*-*/
-
+/* SPDX-License-Identifier: LGPL-2.1+ */
 /***
   This file is part of systemd.
 
 
 #include "sd-netlink.h"
 
+#include "alloc-util.h"
+#include "fd-util.h"
 #include "hashmap.h"
 #include "macro.h"
 #include "missing.h"
 #include "netlink-internal.h"
 #include "netlink-util.h"
+#include "socket-util.h"
 #include "util.h"
 
 static int sd_netlink_new(sd_netlink **ret) {
-        _cleanup_netlink_unref_ sd_netlink *rtnl = NULL;
+        _cleanup_(sd_netlink_unrefp) sd_netlink *rtnl = NULL;
 
         assert_return(ret, -EINVAL);
 
@@ -41,12 +43,9 @@ static int sd_netlink_new(sd_netlink **ret) {
                 return -ENOMEM;
 
         rtnl->n_ref = REFCNT_INIT;
-
         rtnl->fd = -1;
-
         rtnl->sockaddr.nl.nl_family = AF_NETLINK;
-
-        rtnl->original_pid = getpid();
+        rtnl->original_pid = getpid_cached();
 
         LIST_HEAD_INIT(rtnl->match_callbacks);
 
@@ -68,7 +67,7 @@ static int sd_netlink_new(sd_netlink **ret) {
 }
 
 int sd_netlink_new_from_netlink(sd_netlink **ret, int fd) {
-        _cleanup_netlink_unref_ sd_netlink *rtnl = NULL;
+        _cleanup_(sd_netlink_unrefp) sd_netlink *rtnl = NULL;
         socklen_t addrlen;
         int r;
 
@@ -84,6 +83,9 @@ int sd_netlink_new_from_netlink(sd_netlink **ret, int fd) {
         if (r < 0)
                 return -errno;
 
+        if (rtnl->sockaddr.nl.nl_family != AF_NETLINK)
+                return -EINVAL;
+
         rtnl->fd = fd;
 
         *ret = rtnl;
@@ -98,11 +100,11 @@ static bool rtnl_pid_changed(sd_netlink *rtnl) {
         /* We don't support people creating an rtnl connection and
          * keeping it around over a fork(). Let's complain. */
 
-        return rtnl->original_pid != getpid();
+        return rtnl->original_pid != getpid_cached();
 }
 
 int sd_netlink_open_fd(sd_netlink **ret, int fd) {
-        _cleanup_netlink_unref_ sd_netlink *rtnl = NULL;
+        _cleanup_(sd_netlink_unrefp) sd_netlink *rtnl = NULL;
         int r;
 
         assert_return(ret, -EINVAL);
@@ -115,8 +117,10 @@ int sd_netlink_open_fd(sd_netlink **ret, int fd) {
         rtnl->fd = fd;
 
         r = socket_bind(rtnl);
-        if (r < 0)
+        if (r < 0) {
+                rtnl->fd = -1; /* on failure, the caller remains owner of the fd, hence don't close it here */
                 return r;
+        }
 
         *ret = rtnl;
         rtnl = NULL;
@@ -141,7 +145,10 @@ int sd_netlink_open(sd_netlink **ret) {
         return 0;
 }
 
-int sd_netlink_inc_rcvbuf(const sd_netlink *const rtnl, const int size) {
+int sd_netlink_inc_rcvbuf(sd_netlink *rtnl, size_t size) {
+        assert_return(rtnl, -EINVAL);
+        assert_return(!rtnl_pid_changed(rtnl), -ECHILD);
+
         return fd_inc_rcvbuf(rtnl->fd, size);
 }
 
@@ -270,20 +277,24 @@ static int dispatch_rqueue(sd_netlink *rtnl, sd_netlink_message **message) {
         if (rtnl->rqueue_size <= 0) {
                 /* Try to read a new message */
                 r = socket_read_message(rtnl);
+                if (r == -ENOBUFS) { /* FIXME: ignore buffer overruns for now */
+                        log_debug_errno(r, "Got ENOBUFS from netlink socket, ignoring.");
+                        return 1;
+                }
                 if (r <= 0)
                         return r;
         }
 
         /* Dispatch a queued message */
         *message = rtnl->rqueue[0];
-        rtnl->rqueue_size --;
+        rtnl->rqueue_size--;
         memmove(rtnl->rqueue, rtnl->rqueue + 1, sizeof(sd_netlink_message*) * rtnl->rqueue_size);
 
         return 1;
 }
 
 static int process_timeout(sd_netlink *rtnl) {
-        _cleanup_netlink_message_unref_ sd_netlink_message *m = NULL;
+        _cleanup_(sd_netlink_message_unrefp) sd_netlink_message *m = NULL;
         struct reply_callback *c;
         usec_t n;
         int r;
@@ -373,7 +384,7 @@ static int process_match(sd_netlink *rtnl, sd_netlink_message *m) {
 }
 
 static int process_running(sd_netlink *rtnl, sd_netlink_message **ret) {
-        _cleanup_netlink_message_unref_ sd_netlink_message *m = NULL;
+        _cleanup_(sd_netlink_message_unrefp) sd_netlink_message *m = NULL;
         int r;
 
         assert(rtnl);
@@ -415,7 +426,7 @@ null_message:
 }
 
 int sd_netlink_process(sd_netlink *rtnl, sd_netlink_message **ret) {
-        RTNL_DONT_DESTROY(rtnl);
+        NETLINK_DONT_DESTROY(rtnl);
         int r;
 
         assert_return(rtnl, -EINVAL);
@@ -620,7 +631,7 @@ int sd_netlink_call(sd_netlink *rtnl,
                         received_serial = rtnl_message_get_serial(rtnl->rqueue[i]);
 
                         if (received_serial == serial) {
-                                _cleanup_netlink_message_unref_ sd_netlink_message *incoming = NULL;
+                                _cleanup_(sd_netlink_message_unrefp) sd_netlink_message *incoming = NULL;
                                 uint16_t type;
 
                                 incoming = rtnl->rqueue[i];
@@ -771,7 +782,7 @@ static int prepare_callback(sd_event_source *s, void *userdata) {
         return 1;
 }
 
-int sd_netlink_attach_event(sd_netlink *rtnl, sd_event *event, int priority) {
+int sd_netlink_attach_event(sd_netlink *rtnl, sd_event *event, int64_t priority) {
         int r;
 
         assert_return(rtnl, -EINVAL);
@@ -884,6 +895,16 @@ int sd_netlink_add_match(sd_netlink *rtnl,
                         if (r < 0)
                                 return r;
                         break;
+                case RTM_NEWRULE:
+                case RTM_DELRULE:
+                        r = socket_broadcast_group_ref(rtnl, RTNLGRP_IPV4_RULE);
+                        if (r < 0)
+                                return r;
+
+                        r = socket_broadcast_group_ref(rtnl, RTNLGRP_IPV6_RULE);
+                        if (r < 0)
+                                return r;
+                        break;
                 default:
                         return -EOPNOTSUPP;
         }