]> git.ipfire.org Git - thirdparty/systemd.git/commitdiff
sd-netlink: introduce sd_nfnl_{send,call}_batch() 23828/head
authorYu Watanabe <watanabe.yu+github@gmail.com>
Sat, 25 Jun 2022 18:52:59 +0000 (03:52 +0900)
committerYu Watanabe <watanabe.yu+github@gmail.com>
Sat, 23 Jul 2022 15:16:21 +0000 (00:16 +0900)
This also introduces sd_nfnl_message_new() which can be also used for
non-nftables subsystems.

src/libsystemd/sd-netlink/netlink-internal.h
src/libsystemd/sd-netlink/netlink-message-nfnl.c
src/shared/firewall-util-nft.c

index 22df5c86fd63463c69b4e6594f6d5a50a8e93a65..e3bb80e496ca6f0f4df0220a7486c5472dd97399 100644 (file)
@@ -176,8 +176,24 @@ int netlink_add_match_internal(
 /* nfnl */
 /* TODO: to be exported later */
 int sd_nfnl_socket_open(sd_netlink **ret);
-int sd_nfnl_message_batch_begin(sd_netlink *nfnl, sd_netlink_message **ret);
-int sd_nfnl_message_batch_end(sd_netlink *nfnl, sd_netlink_message **ret);
+int sd_nfnl_send_batch(
+                sd_netlink *nfnl,
+                sd_netlink_message **messages,
+                size_t msgcount,
+                uint32_t **ret_serials);
+int sd_nfnl_call_batch(
+                sd_netlink *nfnl,
+                sd_netlink_message **messages,
+                size_t n_messages,
+                uint64_t usec,
+                sd_netlink_message ***ret_messages);
+int sd_nfnl_message_new(
+                sd_netlink *nfnl,
+                sd_netlink_message **ret,
+                int nfproto,
+                uint16_t subsys,
+                uint16_t msg_type,
+                uint16_t flags);
 int sd_nfnl_nft_message_new_table(sd_netlink *nfnl, sd_netlink_message **ret,
                                   int nfproto, const char *table);
 int sd_nfnl_nft_message_new_basechain(sd_netlink *nfnl, sd_netlink_message **ret,
index 28f6c7e3304e376869a91e9f865f1494b73bf2f1..582f623efe70fe66d8659801fac064177f8df3e5 100644 (file)
@@ -7,8 +7,10 @@
 
 #include "sd-netlink.h"
 
+#include "io-util.h"
 #include "netlink-internal.h"
 #include "netlink-types.h"
+#include "netlink-util.h"
 
 static bool nfproto_is_valid(int nfproto) {
         return IN_SET(nfproto,
@@ -22,7 +24,7 @@ static bool nfproto_is_valid(int nfproto) {
                       NFPROTO_DECNET);
 }
 
-static int nft_message_new(sd_netlink *nfnl, sd_netlink_message **ret, int nfproto, uint16_t msg_type, uint16_t flags) {
+int sd_nfnl_message_new(sd_netlink *nfnl, sd_netlink_message **ret, int nfproto, uint16_t subsys, uint16_t msg_type, uint16_t flags) {
         _cleanup_(sd_netlink_message_unrefp) sd_netlink_message *m = NULL;
         int r;
 
@@ -31,7 +33,7 @@ static int nft_message_new(sd_netlink *nfnl, sd_netlink_message **ret, int nfpro
         assert_return(nfproto_is_valid(nfproto), -EINVAL);
         assert_return(NFNL_MSG_TYPE(msg_type) == msg_type, -EINVAL);
 
-        r = message_new(nfnl, &m, NFNL_SUBSYS_NFTABLES << 8 | msg_type);
+        r = message_new(nfnl, &m, subsys << 8 | msg_type);
         if (r < 0)
                 return r;
 
@@ -40,14 +42,40 @@ static int nft_message_new(sd_netlink *nfnl, sd_netlink_message **ret, int nfpro
         *(struct nfgenmsg*) NLMSG_DATA(m->hdr) = (struct nfgenmsg) {
                 .nfgen_family = nfproto,
                 .version = NFNETLINK_V0,
-                .res_id = nfnl->serial,
         };
 
         *ret = TAKE_PTR(m);
         return 0;
 }
 
-static int nfnl_message_batch(sd_netlink *nfnl, sd_netlink_message **ret, uint16_t msg_type) {
+static int nfnl_message_set_res_id(sd_netlink_message *m, uint16_t res_id) {
+        struct nfgenmsg *nfgen;
+
+        assert(m);
+        assert(m->hdr);
+
+        nfgen = NLMSG_DATA(m->hdr);
+        nfgen->res_id = htobe16(res_id);
+
+        return 0;
+}
+
+static int nfnl_message_get_subsys(sd_netlink_message *m, uint16_t *ret) {
+        uint16_t t;
+        int r;
+
+        assert(m);
+        assert(ret);
+
+        r = sd_netlink_message_get_type(m, &t);
+        if (r < 0)
+                return r;
+
+        *ret = NFNL_SUBSYS_ID(t);
+        return 0;
+}
+
+static int nfnl_message_new_batch(sd_netlink *nfnl, sd_netlink_message **ret, uint16_t subsys, uint16_t msg_type) {
         _cleanup_(sd_netlink_message_unrefp) sd_netlink_message *m = NULL;
         int r;
 
@@ -55,26 +83,136 @@ static int nfnl_message_batch(sd_netlink *nfnl, sd_netlink_message **ret, uint16
         assert_return(ret, -EINVAL);
         assert_return(NFNL_MSG_TYPE(msg_type) == msg_type, -EINVAL);
 
-        r = message_new(nfnl, &m, NFNL_SUBSYS_NONE << 8 | msg_type);
+        r = sd_nfnl_message_new(nfnl, &m, NFPROTO_UNSPEC, NFNL_SUBSYS_NONE, msg_type, 0);
         if (r < 0)
                 return r;
 
-        *(struct nfgenmsg*) NLMSG_DATA(m->hdr) = (struct nfgenmsg) {
-                .nfgen_family = NFPROTO_UNSPEC,
-                .version = NFNETLINK_V0,
-                .res_id = NFNL_SUBSYS_NFTABLES,
-        };
+        r = nfnl_message_set_res_id(m, subsys);
+        if (r < 0)
+                return r;
 
         *ret = TAKE_PTR(m);
         return 0;
 }
 
-int sd_nfnl_message_batch_begin(sd_netlink *nfnl, sd_netlink_message **ret) {
-        return nfnl_message_batch(nfnl, ret, NFNL_MSG_BATCH_BEGIN);
+int sd_nfnl_send_batch(
+                sd_netlink *nfnl,
+                sd_netlink_message **messages,
+                size_t n_messages,
+                uint32_t **ret_serials) {
+
+        /* iovs refs batch_begin and batch_end, hence, free iovs first, then free batch_begin and batch_end. */
+        _cleanup_(sd_netlink_message_unrefp) sd_netlink_message *batch_begin = NULL, *batch_end = NULL;
+        _cleanup_free_ struct iovec *iovs = NULL;
+        _cleanup_free_ uint32_t *serials = NULL;
+        uint16_t subsys;
+        ssize_t k;
+        size_t c = 0;
+        int r;
+
+        assert_return(nfnl, -EINVAL);
+        assert_return(!netlink_pid_changed(nfnl), -ECHILD);
+        assert_return(messages, -EINVAL);
+        assert_return(n_messages > 0, -EINVAL);
+
+        iovs = new(struct iovec, n_messages + 2);
+        if (!iovs)
+                return -ENOMEM;
+
+        if (ret_serials) {
+                serials = new(uint32_t, n_messages);
+                if (!serials)
+                        return -ENOMEM;
+        }
+
+        r = nfnl_message_get_subsys(messages[0], &subsys);
+        if (r < 0)
+                return r;
+
+        r = nfnl_message_new_batch(nfnl, &batch_begin, subsys, NFNL_MSG_BATCH_BEGIN);
+        if (r < 0)
+                return r;
+
+        netlink_seal_message(nfnl, batch_begin);
+        iovs[c++] = IOVEC_MAKE(batch_begin->hdr, batch_begin->hdr->nlmsg_len);
+
+        for (size_t i = 0; i < n_messages; i++) {
+                uint16_t s;
+
+                r = nfnl_message_get_subsys(messages[i], &s);
+                if (r < 0)
+                        return r;
+
+                if (s != subsys)
+                        return -EINVAL;
+
+                netlink_seal_message(nfnl, messages[i]);
+                if (serials)
+                        serials[i] = message_get_serial(messages[i]);
+
+                /* It seems that the kernel accepts an arbitrary number. Let's set the serial of the
+                 * first message. */
+                nfnl_message_set_res_id(messages[i], message_get_serial(batch_begin));
+
+                iovs[c++] = IOVEC_MAKE(messages[i]->hdr, messages[i]->hdr->nlmsg_len);
+        }
+
+        r = nfnl_message_new_batch(nfnl, &batch_end, subsys, NFNL_MSG_BATCH_END);
+        if (r < 0)
+                return r;
+
+        netlink_seal_message(nfnl, batch_end);
+        iovs[c++] = IOVEC_MAKE(batch_end->hdr, batch_end->hdr->nlmsg_len);
+
+        assert(c == n_messages + 2);
+        k = writev(nfnl->fd, iovs, n_messages + 2);
+        if (k < 0)
+                return -errno;
+
+        if (ret_serials)
+                *ret_serials = TAKE_PTR(serials);
+
+        return 0;
 }
 
-int sd_nfnl_message_batch_end(sd_netlink *nfnl, sd_netlink_message **ret) {
-        return nfnl_message_batch(nfnl, ret, NFNL_MSG_BATCH_END);
+int sd_nfnl_call_batch(
+                sd_netlink *nfnl,
+                sd_netlink_message **messages,
+                size_t n_messages,
+                uint64_t usec,
+                sd_netlink_message ***ret_messages) {
+
+        _cleanup_free_ sd_netlink_message **replies = NULL;
+        _cleanup_free_ uint32_t *serials = NULL;
+        int k, r;
+
+        assert_return(nfnl, -EINVAL);
+        assert_return(!netlink_pid_changed(nfnl), -ECHILD);
+        assert_return(messages, -EINVAL);
+        assert_return(n_messages > 0, -EINVAL);
+
+        if (ret_messages) {
+                replies = new0(sd_netlink_message*, n_messages);
+                if (!replies)
+                        return -ENOMEM;
+        }
+
+        r = sd_nfnl_send_batch(nfnl, messages, n_messages, &serials);
+        if (r < 0)
+                return r;
+
+        for (size_t i = 0; i < n_messages; i++) {
+                k = sd_netlink_read(nfnl, serials[i], usec, ret_messages ? replies + i : NULL);
+                if (k < 0 && r >= 0)
+                        r = k;
+        }
+        if (r < 0)
+                return r;
+
+        if (ret_messages)
+                *ret_messages = TAKE_PTR(replies);
+
+        return 0;
 }
 
 int sd_nfnl_nft_message_new_basechain(
@@ -90,7 +228,7 @@ int sd_nfnl_nft_message_new_basechain(
         _cleanup_(sd_netlink_message_unrefp) sd_netlink_message *m = NULL;
         int r;
 
-        r = nft_message_new(nfnl, &m, nfproto, NFT_MSG_NEWCHAIN, NLM_F_CREATE);
+        r = sd_nfnl_message_new(nfnl, &m, nfproto, NFNL_SUBSYS_NFTABLES, NFT_MSG_NEWCHAIN, NLM_F_CREATE);
         if (r < 0)
                 return r;
 
@@ -135,7 +273,7 @@ int sd_nfnl_nft_message_new_table(
         _cleanup_(sd_netlink_message_unrefp) sd_netlink_message *m = NULL;
         int r;
 
-        r = nft_message_new(nfnl, &m, nfproto, NFT_MSG_NEWTABLE, NLM_F_CREATE | NLM_F_EXCL);
+        r = sd_nfnl_message_new(nfnl, &m, nfproto, NFNL_SUBSYS_NFTABLES, NFT_MSG_NEWTABLE, NLM_F_CREATE | NLM_F_EXCL);
         if (r < 0)
                 return r;
 
@@ -157,7 +295,7 @@ int sd_nfnl_nft_message_new_rule(
         _cleanup_(sd_netlink_message_unrefp) sd_netlink_message *m = NULL;
         int r;
 
-        r = nft_message_new(nfnl, &m, nfproto, NFT_MSG_NEWRULE, NLM_F_CREATE);
+        r = sd_nfnl_message_new(nfnl, &m, nfproto, NFNL_SUBSYS_NFTABLES, NFT_MSG_NEWRULE, NLM_F_CREATE);
         if (r < 0)
                 return r;
 
@@ -185,7 +323,7 @@ int sd_nfnl_nft_message_new_set(
         _cleanup_(sd_netlink_message_unrefp) sd_netlink_message *m = NULL;
         int r;
 
-        r = nft_message_new(nfnl, &m, nfproto, NFT_MSG_NEWSET, NLM_F_CREATE);
+        r = sd_nfnl_message_new(nfnl, &m, nfproto, NFNL_SUBSYS_NFTABLES, NFT_MSG_NEWSET, NLM_F_CREATE);
         if (r < 0)
                 return r;
 
@@ -221,9 +359,9 @@ int sd_nfnl_nft_message_new_setelems(
         int r;
 
         if (add)
-                r = nft_message_new(nfnl, &m, nfproto, NFT_MSG_NEWSETELEM, NLM_F_CREATE);
+                r = sd_nfnl_message_new(nfnl, &m, nfproto, NFNL_SUBSYS_NFTABLES, NFT_MSG_NEWSETELEM, NLM_F_CREATE);
         else
-                r = nft_message_new(nfnl, &m, nfproto, NFT_MSG_DELSETELEM, 0);
+                r = sd_nfnl_message_new(nfnl, &m, nfproto, NFNL_SUBSYS_NFTABLES, NFT_MSG_DELSETELEM, 0);
         if (r < 0)
                 return r;
 
index b4ef1664997ce370596f271c367f7f237270ce1a..b5f0d1bab75afe184d873436878efe1a6d9187a5 100644 (file)
@@ -45,35 +45,6 @@ static sd_netlink_message **netlink_message_unref_many(sd_netlink_message **m) {
 
 DEFINE_TRIVIAL_CLEANUP_FUNC(sd_netlink_message**, netlink_message_unref_many);
 
-static int nfnl_netlink_sendv(
-                sd_netlink *nfnl,
-                sd_netlink_message *messages[static 1],
-                size_t msgcount) {
-
-        _cleanup_free_ uint32_t *serial = NULL;
-        int r;
-
-        assert(nfnl);
-        assert(messages);
-        assert(msgcount > 0);
-
-        r = sd_netlink_sendv(nfnl, messages, msgcount, &serial);
-        if (r < 0)
-                return r;
-
-        r = 0;
-        for (size_t i = 1; i < msgcount - 1; i++) {
-                int tmp;
-
-                /* If message is an error, this returns embedded errno */
-                tmp = sd_netlink_read(nfnl, serial[i], NFNL_DEFAULT_TIMEOUT_USECS, NULL);
-                if (tmp < 0 && r == 0)
-                        r = tmp;
-        }
-
-        return r;
-}
-
 static int nfnl_open_expr_container(sd_netlink_message *m, const char *name) {
         int r;
 
@@ -742,7 +713,7 @@ static uint32_t concat_types2(enum nft_key_types a, enum nft_key_types b) {
 }
 
 static int fw_nftables_init_family(sd_netlink *nfnl, int family) {
-        sd_netlink_message *messages[12] = {};
+        sd_netlink_message *messages[10] = {};
         _unused_ _cleanup_(netlink_message_unref_manyp) sd_netlink_message **unref = messages;
         size_t msgcnt = 0, ip_type_size;
         uint32_t set_id = 0;
@@ -751,10 +722,6 @@ static int fw_nftables_init_family(sd_netlink *nfnl, int family) {
         assert(nfnl);
         assert(IN_SET(family, AF_INET, AF_INET6));
 
-        r = sd_nfnl_message_batch_begin(nfnl, &messages[msgcnt++]);
-        if (r < 0)
-                return r;
-
         /* Set F_EXCL so table add fails if the table already exists. */
         r = sd_nfnl_nft_message_new_table(nfnl, &messages[msgcnt++], family, NFT_SYSTEMD_TABLE_NAME);
         if (r < 0)
@@ -816,12 +783,8 @@ static int fw_nftables_init_family(sd_netlink *nfnl, int family) {
         if (r < 0)
                 return r;
 
-        r = sd_nfnl_message_batch_end(nfnl, &messages[msgcnt++]);
-        if (r < 0)
-                return r;
-
         assert(msgcnt < ELEMENTSOF(messages));
-        r = nfnl_netlink_sendv(nfnl, messages, msgcnt);
+        r = sd_nfnl_call_batch(nfnl, messages, msgcnt, NFNL_DEFAULT_TIMEOUT_USECS, NULL);
         if (r < 0 && r != -EEXIST)
                 return r;
 
@@ -935,9 +898,7 @@ static int fw_nftables_add_masquerade_internal(
                 const union in_addr_union *source,
                 unsigned int source_prefixlen) {
 
-        sd_netlink_message *messages[4] = {};
-        _unused_ _cleanup_(netlink_message_unref_manyp) sd_netlink_message **unref = messages;
-        size_t msgcnt = 0;
+        _cleanup_(sd_netlink_message_unrefp) sd_netlink_message *m = NULL;
         int r;
 
         assert(nfnl);
@@ -949,31 +910,18 @@ static int fw_nftables_add_masquerade_internal(
         if (af == AF_INET6 && source_prefixlen < 8)
                 return -EINVAL;
 
-        r = sd_nfnl_message_batch_begin(nfnl, &messages[msgcnt++]);
-        if (r < 0)
-                return r;
-
-        r = sd_nfnl_nft_message_new_setelems(nfnl, &messages[msgcnt++], add, af, NFT_SYSTEMD_TABLE_NAME, NFT_SYSTEMD_MASQ_SET_NAME);
+        r = sd_nfnl_nft_message_new_setelems(nfnl, &m, add, af, NFT_SYSTEMD_TABLE_NAME, NFT_SYSTEMD_MASQ_SET_NAME);
         if (r < 0)
                 return r;
 
         if (af == AF_INET)
-                 r = nft_message_append_setelem_iprange(messages[msgcnt-1], source, source_prefixlen);
+                 r = nft_message_append_setelem_iprange(m, source, source_prefixlen);
         else
-                 r = nft_message_append_setelem_ip6range(messages[msgcnt-1], source, source_prefixlen);
-        if (r < 0)
-                return r;
-
-        r = sd_nfnl_message_batch_end(nfnl, &messages[msgcnt++]);
-        if (r < 0)
-                return r;
-
-        assert(msgcnt < ELEMENTSOF(messages));
-        r = nfnl_netlink_sendv(nfnl, messages, msgcnt);
+                 r = nft_message_append_setelem_ip6range(m, source, source_prefixlen);
         if (r < 0)
                 return r;
 
-        return 0;
+        return sd_nfnl_call_batch(nfnl, &m, 1, NFNL_DEFAULT_TIMEOUT_USECS, NULL);
 }
 
 int fw_nftables_add_masquerade(
@@ -1030,7 +978,7 @@ static int fw_nftables_add_local_dnat_internal(
                 uint16_t remote_port,
                 const union in_addr_union *previous_remote) {
 
-        sd_netlink_message *messages[5] = {};
+        sd_netlink_message *messages[3] = {};
         _unused_ _cleanup_(netlink_message_unref_manyp) sd_netlink_message **unref = messages;
         static bool ipv6_supported = true;
         uint32_t data[5], key[2], dlen;
@@ -1068,10 +1016,6 @@ static int fw_nftables_add_local_dnat_internal(
                 data[4] = htobe16(remote_port);
         }
 
-        r = sd_nfnl_message_batch_begin(nfnl, &messages[msgcnt++]);
-        if (r < 0)
-                return r;
-
         /* If a previous remote is set, remove its entry */
         if (add && previous_remote && !in_addr_equal(af, previous_remote, remote)) {
                 if (af == AF_INET)
@@ -1096,12 +1040,8 @@ static int fw_nftables_add_local_dnat_internal(
         if (r < 0)
                 return r;
 
-        r = sd_nfnl_message_batch_end(nfnl, &messages[msgcnt++]);
-        if (r < 0)
-                return r;
-
         assert(msgcnt < ELEMENTSOF(messages));
-        r = nfnl_netlink_sendv(nfnl, messages, msgcnt);
+        r = sd_nfnl_call_batch(nfnl, messages, msgcnt, NFNL_DEFAULT_TIMEOUT_USECS, NULL);
         if (r == -EOVERFLOW && af == AF_INET6) {
                 /* The current implementation of DNAT in systemd requires kernel's
                  * fdb9c405e35bdc6e305b9b4e20ebc141ed14fc81 (v5.8), and the older kernel returns