]> git.ipfire.org Git - thirdparty/systemd.git/commitdiff
socket-util: add send/receive helpers for FD array
authorLuca Boccassi <bluca@debian.org>
Fri, 7 Jul 2023 23:18:01 +0000 (00:18 +0100)
committerLuca Boccassi <bluca@debian.org>
Sat, 29 Jul 2023 10:25:05 +0000 (11:25 +0100)
src/basic/socket-util.c
src/basic/socket-util.h
src/test/test-socket-util.c

index 6951c12c9b72f8ceeb477af691bd7579992b4974..beb64d8e6c7bf2edce4e30e3d21e0277a1a5126a 100644 (file)
 #  define IDN_FLAGS 0
 #endif
 
+/* From the kernel's include/net/scm.h */
+#ifndef SCM_MAX_FD
+#  define SCM_MAX_FD 253
+#endif
+
 static const char* const socket_address_type_table[] = {
         [SOCK_STREAM] =    "Stream",
         [SOCK_DGRAM] =     "Datagram",
@@ -951,6 +956,53 @@ int getpeergroups(int fd, gid_t **ret) {
         return (int) n;
 }
 
+ssize_t send_many_fds_iov_sa(
+                int transport_fd,
+                int *fds_array, size_t n_fds_array,
+                const struct iovec *iov, size_t iovlen,
+                const struct sockaddr *sa, socklen_t len,
+                int flags) {
+
+        _cleanup_free_ struct cmsghdr *cmsg = NULL;
+        struct msghdr mh = {
+                .msg_name = (struct sockaddr*) sa,
+                .msg_namelen = len,
+                .msg_iov = (struct iovec *)iov,
+                .msg_iovlen = iovlen,
+        };
+        ssize_t k;
+
+        assert(transport_fd >= 0);
+        assert(fds_array || n_fds_array == 0);
+
+        /* The kernel will reject sending more than SCM_MAX_FD FDs at once */
+        if (n_fds_array > SCM_MAX_FD)
+                return -E2BIG;
+
+        /* We need either an FD array or data to send. If there's nothing, return an error. */
+        if (n_fds_array == 0 && !iov)
+                return -EINVAL;
+
+        if (n_fds_array > 0) {
+                mh.msg_controllen = CMSG_SPACE(sizeof(int) * n_fds_array);
+                mh.msg_control = cmsg = malloc(mh.msg_controllen);
+                if (!cmsg)
+                        return -ENOMEM;
+
+                *cmsg = (struct cmsghdr) {
+                        .cmsg_len = CMSG_LEN(sizeof(int) * n_fds_array),
+                        .cmsg_level = SOL_SOCKET,
+                        .cmsg_type = SCM_RIGHTS,
+                };
+                memcpy(CMSG_DATA(cmsg), fds_array, sizeof(int) * n_fds_array);
+        }
+        k = sendmsg(transport_fd, &mh, MSG_NOSIGNAL | flags);
+        if (k < 0)
+                return (ssize_t) -errno;
+
+        return k;
+}
+
 ssize_t send_one_fd_iov_sa(
                 int transport_fd,
                 int fd,
@@ -1006,6 +1058,78 @@ int send_one_fd_sa(
         return (int) send_one_fd_iov_sa(transport_fd, fd, NULL, 0, sa, len, flags);
 }
 
+ssize_t receive_many_fds_iov(
+                int transport_fd,
+                struct iovec *iov, size_t iovlen,
+                int **ret_fds_array, size_t *ret_n_fds_array,
+                int flags) {
+
+        CMSG_BUFFER_TYPE(CMSG_SPACE(sizeof(int) * SCM_MAX_FD)) control;
+        struct msghdr mh = {
+                .msg_control = &control,
+                .msg_controllen = sizeof(control),
+                .msg_iov = iov,
+                .msg_iovlen = iovlen,
+        };
+        _cleanup_free_ int *fds_array = NULL;
+        size_t n_fds_array = 0;
+        struct cmsghdr *cmsg;
+        ssize_t k;
+
+        assert(transport_fd >= 0);
+        assert(ret_fds_array);
+        assert(ret_n_fds_array);
+
+        /*
+         * Receive many FDs via @transport_fd. We don't care for the transport-type. We retrieve all the FDs
+         * at once. This is best used in combination with send_many_fds().
+         */
+
+        k = recvmsg_safe(transport_fd, &mh, MSG_CMSG_CLOEXEC | flags);
+        if (k < 0)
+                return k;
+
+        CMSG_FOREACH(cmsg, &mh)
+                if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) {
+                        size_t n = (cmsg->cmsg_len - CMSG_LEN(0)) / sizeof(int);
+
+                        fds_array = GREEDY_REALLOC(fds_array, n_fds_array + n);
+                        if (!fds_array) {
+                                cmsg_close_all(&mh);
+                                return -ENOMEM;
+                        }
+
+                        memcpy(fds_array + n_fds_array, CMSG_TYPED_DATA(cmsg, int), sizeof(int) * n);
+                        n_fds_array += n;
+                }
+
+        if (n_fds_array == 0) {
+                cmsg_close_all(&mh);
+
+                /* If didn't receive an FD or any data, return an error. */
+                if (k == 0)
+                        return -EIO;
+        }
+
+        *ret_fds_array = TAKE_PTR(fds_array);
+        *ret_n_fds_array = n_fds_array;
+
+        return k;
+}
+
+int receive_many_fds(int transport_fd, int **ret_fds_array, size_t *ret_n_fds_array, int flags) {
+        ssize_t k;
+
+        k = receive_many_fds_iov(transport_fd, NULL, 0, ret_fds_array, ret_n_fds_array, flags);
+        if (k == 0)
+                return 0;
+
+        /* k must be negative, since receive_many_fds_iov() only returns a positive value if data was received
+         * through the iov. */
+        assert(k < 0);
+        return (int) k;
+}
+
 ssize_t receive_one_fd_iov(
                 int transport_fd,
                 struct iovec *iov, size_t iovlen,
index 9c4c95bd3a1a0b9591b93e0436f6c7a9848f2df0..9a11df834d113ebd26a76a88458dc62d52eef185 100644 (file)
@@ -153,6 +153,28 @@ int getpeercred(int fd, struct ucred *ucred);
 int getpeersec(int fd, char **ret);
 int getpeergroups(int fd, gid_t **ret);
 
+ssize_t send_many_fds_iov_sa(
+                int transport_fd,
+                int *fds_array, size_t n_fds_array,
+                const struct iovec *iov, size_t iovlen,
+                const struct sockaddr *sa, socklen_t len,
+                int flags);
+static inline ssize_t send_many_fds_iov(
+                int transport_fd,
+                int *fds_array, size_t n_fds_array,
+                const struct iovec *iov, size_t iovlen,
+                int flags) {
+
+        return send_many_fds_iov_sa(transport_fd, fds_array, n_fds_array, iov, iovlen, NULL, 0, flags);
+}
+static inline int send_many_fds(
+                int transport_fd,
+                int *fds_array,
+                size_t n_fds_array,
+                int flags) {
+
+        return send_many_fds_iov_sa(transport_fd, fds_array, n_fds_array, NULL, 0, NULL, 0, flags);
+}
 ssize_t send_one_fd_iov_sa(
                 int transport_fd,
                 int fd,
@@ -167,6 +189,8 @@ int send_one_fd_sa(int transport_fd,
 #define send_one_fd(transport_fd, fd, flags) send_one_fd_iov_sa(transport_fd, fd, NULL, 0, NULL, 0, flags)
 ssize_t receive_one_fd_iov(int transport_fd, struct iovec *iov, size_t iovlen, int flags, int *ret_fd);
 int receive_one_fd(int transport_fd, int flags);
+ssize_t receive_many_fds_iov(int transport_fd, struct iovec *iov, size_t iovlen, int **ret_fds_array, size_t *ret_n_fds_array, int flags);
+int receive_many_fds(int transport_fd, int **ret_fds_array, size_t *ret_n_fds_array, int flags);
 
 ssize_t next_datagram_size_fd(int fd);
 
index 0259cbf3bb67c905c6565411d6b00e8ee0aeed13..2c5d31e5f8545fab514eff28ef62091955d614f6 100644 (file)
@@ -310,6 +310,71 @@ TEST(passfd_contents_read) {
         assert_se(streq(buf, file_contents));
 }
 
+TEST(pass_many_fds_contents_read) {
+        _cleanup_close_pair_ int pair[2];
+        static const char file_contents[][STRLEN("test contents in the fileX") + 1] = {
+                "test contents in the file0",
+                "test contents in the file1",
+                "test contents in the file2"
+        };
+        static const char wire_contents[] = "test contents on the wire";
+        int r;
+
+        assert_se(socketpair(AF_UNIX, SOCK_DGRAM, 0, pair) >= 0);
+
+        r = safe_fork("(passfd_contents_read)", FORK_DEATHSIG|FORK_LOG|FORK_WAIT, NULL);
+        assert_se(r >= 0);
+
+        if (r == 0) {
+                /* Child */
+                struct iovec iov = IOVEC_MAKE_STRING(wire_contents);
+                char tmpfile[][STRLEN("/tmp/test-socket-util-passfd-contents-read-XXXXXX") + 1] = {
+                        "/tmp/test-socket-util-passfd-contents-read-XXXXXX",
+                        "/tmp/test-socket-util-passfd-contents-read-XXXXXX",
+                        "/tmp/test-socket-util-passfd-contents-read-XXXXXX"
+                };
+                int tmpfds[3] = { -EBADF, -EBADF, -EBADF };
+
+                pair[0] = safe_close(pair[0]);
+
+                for (size_t i = 0; i < 3; ++i) {
+                        assert_se(write_tmpfile(tmpfile[i], file_contents[i]) == 0);
+                        tmpfds[i] = open(tmpfile[i], O_RDONLY);
+                        assert_se(tmpfds[i] >= 0);
+                        assert_se(unlink(tmpfile[i]) == 0);
+                }
+
+                assert_se(send_many_fds_iov(pair[1], tmpfds, 3, &iov, 1, MSG_DONTWAIT) > 0);
+                close_many(tmpfds, 3);
+                _exit(EXIT_SUCCESS);
+        }
+
+        /* Parent */
+        char buf[64];
+        struct iovec iov = IOVEC_MAKE(buf, sizeof(buf)-1);
+        _cleanup_free_ int *fds = NULL;
+        size_t n_fds = 0;
+        ssize_t k;
+
+        pair[1] = safe_close(pair[1]);
+
+        k = receive_many_fds_iov(pair[0], &iov, 1, &fds, &n_fds, MSG_DONTWAIT);
+        assert_se(k > 0);
+        buf[k] = 0;
+        assert_se(streq(buf, wire_contents));
+
+        assert_se(n_fds == 3);
+
+        for (size_t i = 0; i < 3; ++i) {
+                assert_se(fds[i] >= 0);
+                r = read(fds[i], buf, sizeof(buf)-1);
+                assert_se(r >= 0);
+                buf[r] = 0;
+                assert_se(streq(buf, file_contents[i]));
+                safe_close(fds[i]);
+        }
+}
+
 TEST(receive_nopassfd) {
         _cleanup_close_pair_ int pair[2];
         static const char wire_contents[] = "no fd passed here";