]> git.ipfire.org Git - thirdparty/systemd.git/commitdiff
shared: extract `socket_forward_new()` helper from socket-proxyd 41168/head
authorMichael Vogt <michael@amutable.com>
Wed, 18 Mar 2026 10:38:48 +0000 (11:38 +0100)
committerMichael Vogt <michael@amutable.com>
Fri, 20 Mar 2026 09:13:34 +0000 (10:13 +0100)
This commit extracts the socket forwarding code from the existing
socket-proxyd into a new shared helper that will be used by the
varlinkctl protocol upgrade support code and is used as is in
the socket-proxyd.c.

It tries to keep the changes as small as possible, its mostly
renaming like:
* connection_create_pipes -> socket_forward_create_pipes
* connection_shovel -> socket_forward_shovel
* connection_enable_event_sources -> socket_forward_enable_event_sources
* traffic_cb -> socket_forward_traffic_cb

and a new socket_forward_new() that creates/starts the forwarding.

All log_error_errno() got downgraded to log_debug_errno().

mkosi/mkosi.sanitizers/mkosi.postinst
src/shared/meson.build
src/shared/socket-forward.c [new file with mode: 0644]
src/shared/socket-forward.h [new file with mode: 0644]
src/socket-proxy/socket-proxyd.c

index 229a5368b92f4ec9eb4c60a6fa97e66486f8dcb2..72356005e93371cdc50f1c85ab2a57e4a10c8aaf 100755 (executable)
@@ -42,8 +42,8 @@ fi
 
 wrap=(
     /usr/lib/polkit-1/polkitd
-    /usr/libexec/polkit-1/polkitd
     /usr/lib/systemd/tests/testdata/TEST-74-AUX-UTILS.units/proxy-echo.py
+    /usr/libexec/polkit-1/polkitd
     agetty
     btrfs
     capsh
index bbc03079993242e3eb7a4f1c3164aec74e60068b..e8a86b11b0659c806c6b4a5228116292d17bca98 100644 (file)
@@ -181,6 +181,7 @@ shared_sources = files(
         'smack-util.c',
         'smbios11.c',
         'snapshot-util.c',
+        'socket-forward.c',
         'socket-label.c',
         'socket-netlink.c',
         'specifier.c',
diff --git a/src/shared/socket-forward.c b/src/shared/socket-forward.c
new file mode 100644 (file)
index 0000000..2601b25
--- /dev/null
@@ -0,0 +1,256 @@
+/* SPDX-License-Identifier: LGPL-2.1-or-later */
+
+#include <fcntl.h>
+#include <unistd.h>
+
+#include "sd-event.h"
+
+#include "alloc-util.h"
+#include "errno-util.h"
+#include "fd-util.h"
+#include "log.h"
+#include "socket-forward.h"
+
+#define SOCKET_FORWARD_BUFFER_SIZE (256 * 1024)
+
+struct SocketForward {
+        sd_event *event;
+
+        int server_fd, client_fd;
+
+        int server_to_client_buffer[2]; /* a pipe */
+        int client_to_server_buffer[2]; /* a pipe */
+
+        size_t server_to_client_buffer_full, client_to_server_buffer_full;
+        size_t server_to_client_buffer_size, client_to_server_buffer_size;
+
+        sd_event_source *server_event_source, *client_event_source;
+
+        socket_forward_done_t on_done;
+        void *userdata;
+};
+
+SocketForward* socket_forward_free(SocketForward *sf) {
+        if (!sf)
+                return NULL;
+
+        sd_event_source_unref(sf->server_event_source);
+        sd_event_source_unref(sf->client_event_source);
+
+        safe_close(sf->server_fd);
+        safe_close(sf->client_fd);
+
+        safe_close_pair(sf->server_to_client_buffer);
+        safe_close_pair(sf->client_to_server_buffer);
+
+        sd_event_unref(sf->event);
+
+        return mfree(sf);
+}
+
+static int socket_forward_create_pipes(int buffer[static 2], size_t *ret_size) {
+        int r;
+
+        assert(buffer);
+        assert(ret_size);
+
+        if (buffer[0] >= 0)
+                return 0;
+
+        r = pipe2(buffer, O_CLOEXEC|O_NONBLOCK);
+        if (r < 0)
+                return log_debug_errno(errno, "Failed to allocate pipe buffer: %m");
+
+        (void) fcntl(buffer[0], F_SETPIPE_SZ, SOCKET_FORWARD_BUFFER_SIZE);
+
+        r = fcntl(buffer[0], F_GETPIPE_SZ);
+        if (r < 0)
+                return log_debug_errno(errno, "Failed to get pipe buffer size: %m");
+
+        assert(r > 0);
+        *ret_size = r;
+
+        return 0;
+}
+
+static int socket_forward_shovel(
+                int *from, int buffer[2], int *to,
+                size_t *full, size_t *sz,
+                sd_event_source **from_source, sd_event_source **to_source) {
+
+        bool shoveled;
+
+        assert(from);
+        assert(buffer);
+        assert(buffer[0] >= 0);
+        assert(buffer[1] >= 0);
+        assert(to);
+        assert(full);
+        assert(sz);
+        assert(from_source);
+        assert(to_source);
+
+        do {
+                ssize_t z;
+
+                shoveled = false;
+
+                if (*full < *sz && *from >= 0 && *to >= 0) {
+                        z = splice(*from, NULL, buffer[1], NULL, *sz - *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
+                        if (z > 0) {
+                                *full += z;
+                                shoveled = true;
+                        } else if (z == 0 || ERRNO_IS_DISCONNECT(errno)) {
+                                *from_source = sd_event_source_unref(*from_source);
+                                *from = safe_close(*from);
+                        } else if (!ERRNO_IS_TRANSIENT(errno))
+                                return log_debug_errno(errno, "Failed to splice: %m");
+                }
+
+                if (*full > 0 && *to >= 0) {
+                        z = splice(buffer[0], NULL, *to, NULL, *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
+                        if (z > 0) {
+                                *full -= z;
+                                shoveled = true;
+                        } else if (z == 0 || ERRNO_IS_DISCONNECT(errno)) {
+                                *to_source = sd_event_source_unref(*to_source);
+                                *to = safe_close(*to);
+                        } else if (!ERRNO_IS_TRANSIENT(errno))
+                                return log_debug_errno(errno, "Failed to splice: %m");
+                }
+        } while (shoveled);
+
+        return 0;
+}
+
+static int socket_forward_enable_event_sources(SocketForward *sf);
+
+static int socket_forward_traffic_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
+        SocketForward *sf = ASSERT_PTR(userdata);
+        int r;
+
+        assert(s);
+        assert(fd >= 0);
+
+        r = socket_forward_shovel(
+                        &sf->server_fd, sf->server_to_client_buffer, &sf->client_fd,
+                        &sf->server_to_client_buffer_full, &sf->server_to_client_buffer_size,
+                        &sf->server_event_source, &sf->client_event_source);
+        if (r < 0)
+                goto quit;
+
+        r = socket_forward_shovel(
+                        &sf->client_fd, sf->client_to_server_buffer, &sf->server_fd,
+                        &sf->client_to_server_buffer_full, &sf->client_to_server_buffer_size,
+                        &sf->client_event_source, &sf->server_event_source);
+        if (r < 0)
+                goto quit;
+
+        /* EOF on both sides? */
+        if (sf->server_fd < 0 && sf->client_fd < 0)
+                goto quit;
+
+        /* Server closed, and all data written to client? */
+        if (sf->server_fd < 0 && sf->server_to_client_buffer_full <= 0)
+                goto quit;
+
+        /* Client closed, and all data written to server? */
+        if (sf->client_fd < 0 && sf->client_to_server_buffer_full <= 0)
+                goto quit;
+
+        r = socket_forward_enable_event_sources(sf);
+        if (r < 0)
+                goto quit;
+
+        return 1;
+
+quit:
+        return sf->on_done(sf, r, sf->userdata);
+}
+
+static int socket_forward_enable_event_sources(SocketForward *sf) {
+        uint32_t a = 0, b = 0;
+        int r;
+
+        assert(sf);
+
+        if (sf->server_to_client_buffer_full > 0)
+                b |= EPOLLOUT;
+        if (sf->server_to_client_buffer_full < sf->server_to_client_buffer_size)
+                a |= EPOLLIN;
+
+        if (sf->client_to_server_buffer_full > 0)
+                a |= EPOLLOUT;
+        if (sf->client_to_server_buffer_full < sf->client_to_server_buffer_size)
+                b |= EPOLLIN;
+
+        if (sf->server_event_source)
+                r = sd_event_source_set_io_events(sf->server_event_source, a);
+        else if (sf->server_fd >= 0)
+                r = sd_event_add_io(sf->event, &sf->server_event_source, sf->server_fd, a, socket_forward_traffic_cb, sf);
+        else
+                r = 0;
+        if (r < 0)
+                return log_debug_errno(r, "Failed to set up server event source: %m");
+
+        if (sf->client_event_source)
+                r = sd_event_source_set_io_events(sf->client_event_source, b);
+        else if (sf->client_fd >= 0)
+                r = sd_event_add_io(sf->event, &sf->client_event_source, sf->client_fd, b, socket_forward_traffic_cb, sf);
+        else
+                r = 0;
+        if (r < 0)
+                return log_debug_errno(r, "Failed to set up client event source: %m");
+
+        return 0;
+}
+
+int socket_forward_new(
+                sd_event *event,
+                int server_fd,
+                int client_fd,
+                socket_forward_done_t on_done,
+                void *userdata,
+                SocketForward **ret) {
+
+        _cleanup_(socket_forward_freep) SocketForward *sf = NULL;
+        int r;
+
+        assert(event);
+        assert(server_fd >= 0);
+        assert(client_fd >= 0);
+        assert(on_done);
+        assert(ret);
+
+        sf = new(SocketForward, 1);
+        if (!sf) {
+                safe_close(server_fd);
+                safe_close(client_fd);
+                return log_oom_debug();
+        }
+
+        *sf = (SocketForward) {
+                .event = sd_event_ref(event),
+                .server_fd = server_fd,
+                .client_fd = client_fd,
+                .server_to_client_buffer = EBADF_PAIR,
+                .client_to_server_buffer = EBADF_PAIR,
+                .on_done = on_done,
+                .userdata = userdata,
+        };
+
+        r = socket_forward_create_pipes(sf->server_to_client_buffer, &sf->server_to_client_buffer_size);
+        if (r < 0)
+                return r;
+
+        r = socket_forward_create_pipes(sf->client_to_server_buffer, &sf->client_to_server_buffer_size);
+        if (r < 0)
+                return r;
+
+        r = socket_forward_enable_event_sources(sf);
+        if (r < 0)
+                return r;
+
+        *ret = TAKE_PTR(sf);
+        return 0;
+}
diff --git a/src/shared/socket-forward.h b/src/shared/socket-forward.h
new file mode 100644 (file)
index 0000000..a2d34da
--- /dev/null
@@ -0,0 +1,29 @@
+/* SPDX-License-Identifier: LGPL-2.1-or-later */
+#pragma once
+
+#include "shared-forward.h"
+
+/* Bidirectional socket forwarder using splice().
+ *
+ * Forwards data between two bidirectional sockets ("server" and "client") via kernel pipe buffers,
+ * avoiding userspace copies.
+ *
+ * When forwarding completes (both directions reach EOF or error), the completion callback is invoked.
+ *
+ * The SocketForward takes ownership of both fds - they are closed when the SocketForward is freed
+ * (or earlier, during normal forwarding when EOF/disconnect is detected). */
+
+typedef struct SocketForward SocketForward;
+
+typedef int (*socket_forward_done_t)(SocketForward *sf, int error, void *userdata);
+
+int socket_forward_new(
+                sd_event *event,
+                int server_fd,
+                int client_fd,
+                socket_forward_done_t on_done,
+                void *userdata,
+                SocketForward **ret);
+
+SocketForward* socket_forward_free(SocketForward *sf);
+DEFINE_TRIVIAL_CLEANUP_FUNC(SocketForward*, socket_forward_free);
index 71172326da125603bbcc2d77201733663b122bab..e1eec1dd41c829818f6867514da81d4b5604c1ec 100644 (file)
@@ -1,6 +1,5 @@
 /* SPDX-License-Identifier: LGPL-2.1-or-later */
 
-#include <fcntl.h>
 #include <getopt.h>
 #include <netdb.h>
 #include <stdio.h>
 #include "pretty-print.h"
 #include "resolve-private.h"
 #include "set.h"
+#include "socket-forward.h"
 #include "socket-util.h"
 #include "string-util.h"
 #include "time-util.h"
 
-#define BUFFER_SIZE (256 * 1024)
-
 static unsigned arg_connections_max = 256;
 static const char *arg_remote_host = NULL;
 static usec_t arg_exit_idle_time = USEC_INFINITY;
@@ -45,13 +43,10 @@ typedef struct Connection {
         Context *context;
 
         int server_fd, client_fd;
-        int server_to_client_buffer[2]; /* a pipe */
-        int client_to_server_buffer[2]; /* a pipe */
 
-        size_t server_to_client_buffer_full, client_to_server_buffer_full;
-        size_t server_to_client_buffer_size, client_to_server_buffer_size;
+        sd_event_source *connect_event_source;
 
-        sd_event_source *server_event_source, *client_event_source;
+        SocketForward *forward;
 
         sd_resolve_query *resolve_query;
 } Connection;
@@ -63,15 +58,12 @@ static Connection* connection_free(Connection *c) {
         if (c->context)
                 set_remove(c->context->connections, c);
 
-        sd_event_source_unref(c->server_event_source);
-        sd_event_source_unref(c->client_event_source);
+        sd_event_source_unref(c->connect_event_source);
+        socket_forward_free(c->forward);
 
         safe_close(c->server_fd);
         safe_close(c->client_fd);
 
-        safe_close_pair(c->server_to_client_buffer);
-        safe_close_pair(c->client_to_server_buffer);
-
         sd_resolve_query_unref(c->resolve_query);
 
         return mfree(c);
@@ -134,185 +126,29 @@ static void connection_release(Connection *c) {
         context_reset_timer(context);
 }
 
-static int connection_create_pipes(Connection *c, int buffer[static 2], size_t *sz) {
-        int r;
-
-        assert(c);
-        assert(buffer);
-        assert(sz);
-
-        if (buffer[0] >= 0)
-                return 0;
-
-        r = pipe2(buffer, O_CLOEXEC|O_NONBLOCK);
-        if (r < 0)
-                return log_error_errno(errno, "Failed to allocate pipe buffer: %m");
-
-        (void) fcntl(buffer[0], F_SETPIPE_SZ, BUFFER_SIZE);
-
-        r = fcntl(buffer[0], F_GETPIPE_SZ);
-        if (r < 0)
-                return log_error_errno(errno, "Failed to get pipe buffer size: %m");
-
-        assert(r > 0);
-        *sz = r;
-
-        return 0;
-}
-
-static int connection_shovel(
-                Connection *c,
-                int *from, int buffer[2], int *to,
-                size_t *full, size_t *sz,
-                sd_event_source **from_source, sd_event_source **to_source) {
-
-        bool shoveled;
-
-        assert(c);
-        assert(from);
-        assert(buffer);
-        assert(buffer[0] >= 0);
-        assert(buffer[1] >= 0);
-        assert(to);
-        assert(full);
-        assert(sz);
-        assert(from_source);
-        assert(to_source);
-
-        do {
-                ssize_t z;
-
-                shoveled = false;
-
-                if (*full < *sz && *from >= 0 && *to >= 0) {
-                        z = splice(*from, NULL, buffer[1], NULL, *sz - *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
-                        if (z > 0) {
-                                *full += z;
-                                shoveled = true;
-                        } else if (z == 0 || ERRNO_IS_DISCONNECT(errno)) {
-                                *from_source = sd_event_source_unref(*from_source);
-                                *from = safe_close(*from);
-                        } else if (!ERRNO_IS_TRANSIENT(errno))
-                                return log_error_errno(errno, "Failed to splice: %m");
-                }
-
-                if (*full > 0 && *to >= 0) {
-                        z = splice(buffer[0], NULL, *to, NULL, *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
-                        if (z > 0) {
-                                *full -= z;
-                                shoveled = true;
-                        } else if (z == 0 || ERRNO_IS_DISCONNECT(errno)) {
-                                *to_source = sd_event_source_unref(*to_source);
-                                *to = safe_close(*to);
-                        } else if (!ERRNO_IS_TRANSIENT(errno))
-                                return log_error_errno(errno, "Failed to splice: %m");
-                }
-        } while (shoveled);
-
-        return 0;
-}
-
-static int connection_enable_event_sources(Connection *c);
-
-static int traffic_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
+static int connection_forward_done(SocketForward *sf, int error, void *userdata) {
         Connection *c = ASSERT_PTR(userdata);
-        int r;
-
-        assert(s);
-        assert(fd >= 0);
 
-        r = connection_shovel(c,
-                              &c->server_fd, c->server_to_client_buffer, &c->client_fd,
-                              &c->server_to_client_buffer_full, &c->server_to_client_buffer_size,
-                              &c->server_event_source, &c->client_event_source);
-        if (r < 0)
-                goto quit;
-
-        r = connection_shovel(c,
-                              &c->client_fd, c->client_to_server_buffer, &c->server_fd,
-                              &c->client_to_server_buffer_full, &c->client_to_server_buffer_size,
-                              &c->client_event_source, &c->server_event_source);
-        if (r < 0)
-                goto quit;
-
-        /* EOF on both sides? */
-        if (c->server_fd < 0 && c->client_fd < 0)
-                goto quit;
-
-        /* Server closed, and all data written to client? */
-        if (c->server_fd < 0 && c->server_to_client_buffer_full <= 0)
-                goto quit;
-
-        /* Client closed, and all data written to server? */
-        if (c->client_fd < 0 && c->client_to_server_buffer_full <= 0)
-                goto quit;
-
-        r = connection_enable_event_sources(c);
-        if (r < 0)
-                goto quit;
+        if (error < 0)
+                log_error_errno(error, "Forwarding failed: %m");
 
-        return 1;
-
-quit:
         connection_release(c);
         return 0; /* ignore errors, continue serving */
 }
 
-static int connection_enable_event_sources(Connection *c) {
-        uint32_t a = 0, b = 0;
-        int r;
-
-        assert(c);
-
-        if (c->server_to_client_buffer_full > 0)
-                b |= EPOLLOUT;
-        if (c->server_to_client_buffer_full < c->server_to_client_buffer_size)
-                a |= EPOLLIN;
-
-        if (c->client_to_server_buffer_full > 0)
-                a |= EPOLLOUT;
-        if (c->client_to_server_buffer_full < c->client_to_server_buffer_size)
-                b |= EPOLLIN;
-
-        if (c->server_event_source)
-                r = sd_event_source_set_io_events(c->server_event_source, a);
-        else if (c->server_fd >= 0)
-                r = sd_event_add_io(c->context->event, &c->server_event_source, c->server_fd, a, traffic_cb, c);
-        else
-                r = 0;
-
-        if (r < 0)
-                return log_error_errno(r, "Failed to set up server event source: %m");
-
-        if (c->client_event_source)
-                r = sd_event_source_set_io_events(c->client_event_source, b);
-        else if (c->client_fd >= 0)
-                r = sd_event_add_io(c->context->event, &c->client_event_source, c->client_fd, b, traffic_cb, c);
-        else
-                r = 0;
-
-        if (r < 0)
-                return log_error_errno(r, "Failed to set up client event source: %m");
-
-        return 0;
-}
-
 static int connection_complete(Connection *c) {
         int r;
 
         assert(c);
 
-        r = connection_create_pipes(c, c->server_to_client_buffer, &c->server_to_client_buffer_size);
-        if (r < 0)
-                return r;
-
-        r = connection_create_pipes(c, c->client_to_server_buffer, &c->client_to_server_buffer_size);
+        r = socket_forward_new(
+                        c->context->event,
+                        TAKE_FD(c->server_fd),
+                        TAKE_FD(c->client_fd),
+                        connection_forward_done, c,
+                        &c->forward);
         if (r < 0)
-                return r;
-
-        r = connection_enable_event_sources(c);
-        if (r < 0)
-                return r;
+                return log_error_errno(r, "Failed to set up forwarding: %m");
 
         return 0;
 }
@@ -336,7 +172,7 @@ static int connect_cb(sd_event_source *s, int fd, uint32_t revents, void *userda
                 goto fail;
         }
 
-        c->client_event_source = sd_event_source_unref(c->client_event_source);
+        c->connect_event_source = sd_event_source_unref(c->connect_event_source);
 
         if (connection_complete(c) < 0)
                 goto fail;
@@ -364,11 +200,11 @@ static int connection_start(Connection *c, struct sockaddr *sa, socklen_t salen)
                 if (errno != EINPROGRESS)
                         return log_error_errno(errno, "Failed to connect to remote host: %m");
 
-                r = sd_event_add_io(c->context->event, &c->client_event_source, c->client_fd, EPOLLOUT, connect_cb, c);
+                r = sd_event_add_io(c->context->event, &c->connect_event_source, c->client_fd, EPOLLOUT, connect_cb, c);
                 if (r < 0)
                         return log_error_errno(r, "Failed to add connection socket: %m");
 
-                r = sd_event_source_set_enabled(c->client_event_source, SD_EVENT_ONESHOT);
+                r = sd_event_source_set_enabled(c->connect_event_source, SD_EVENT_ONESHOT);
                 if (r < 0)
                         return log_error_errno(r, "Failed to enable oneshot event source: %m");
 
@@ -472,8 +308,6 @@ static int context_add_connection(Context *context, int fd) {
         *c = (Connection) {
                 .server_fd = TAKE_FD(nfd),
                 .client_fd = -EBADF,
-                .server_to_client_buffer = EBADF_PAIR,
-                .client_to_server_buffer = EBADF_PAIR,
         };
 
         r = set_ensure_put(&context->connections, &connection_hash_ops, c);