]> git.ipfire.org Git - thirdparty/systemd.git/blobdiff - src/socket-proxy/socket-proxyd.c
tree-wide: make use of new relative time events in sd-event.h
[thirdparty/systemd.git] / src / socket-proxy / socket-proxyd.c
index 62b7051073c740ff7ae922cbe76299aa31978e10..50f5b9c9c536a21ffc83508761e3b633b38401be 100644 (file)
@@ -6,8 +6,6 @@
 #include <netdb.h>
 #include <stdio.h>
 #include <stdlib.h>
-#include <string.h>
-#include <sys/socket.h>
 #include <sys/un.h>
 #include <unistd.h>
 
 #include "sd-resolve.h"
 
 #include "alloc-util.h"
+#include "errno-util.h"
 #include "fd-util.h"
 #include "log.h"
+#include "main-func.h"
 #include "parse-util.h"
 #include "path-util.h"
 #include "pretty-print.h"
+#include "resolve-private.h"
 #include "set.h"
 #include "socket-util.h"
 #include "string-util.h"
 #include "util.h"
 
 #define BUFFER_SIZE (256 * 1024)
-static unsigned arg_connections_max = 256;
 
+static unsigned arg_connections_max = 256;
 static const char *arg_remote_host = NULL;
+static usec_t arg_exit_idle_time = USEC_INFINITY;
 
 typedef struct Context {
         sd_event *event;
         sd_resolve *resolve;
+        sd_event_source *idle_time;
 
         Set *listen;
         Set *connections;
@@ -74,7 +77,51 @@ static void connection_free(Connection *c) {
         free(c);
 }
 
-static void context_free(Context *context) {
+static int idle_time_cb(sd_event_source *s, uint64_t usec, void *userdata) {
+        Context *c = userdata;
+        int r;
+
+        if (!set_isempty(c->connections)) {
+                log_warning("Idle timer fired even though there are connections, ignoring");
+                return 0;
+        }
+
+        r = sd_event_exit(c->event, 0);
+        if (r < 0) {
+                log_warning_errno(r, "Error while stopping event loop, ignoring: %m");
+                return 0;
+        }
+        return 0;
+}
+
+static int connection_release(Connection *c) {
+        Context *context = c->context;
+        int r;
+
+        connection_free(c);
+
+        if (arg_exit_idle_time < USEC_INFINITY && set_isempty(context->connections)) {
+                if (context->idle_time) {
+                        r = sd_event_source_set_time_relative(context->idle_time, arg_exit_idle_time);
+                        if (r < 0)
+                                return log_error_errno(r, "Error while setting idle time: %m");
+
+                        r = sd_event_source_set_enabled(context->idle_time, SD_EVENT_ONESHOT);
+                        if (r < 0)
+                                return log_error_errno(r, "Error while enabling idle time: %m");
+                } else {
+                        r = sd_event_add_time_relative(
+                                        context->event, &context->idle_time, CLOCK_MONOTONIC,
+                                        arg_exit_idle_time, 0, idle_time_cb, context);
+                        if (r < 0)
+                                return log_error_errno(r, "Failed to create idle timer: %m");
+                }
+        }
+
+        return 0;
+}
+
+static void context_clear(Context *context) {
         assert(context);
 
         set_free_with_destructor(context->listen, sd_event_source_unref);
@@ -82,9 +129,10 @@ static void context_free(Context *context) {
 
         sd_event_unref(context->event);
         sd_resolve_unref(context->resolve);
+        sd_event_source_unref(context->idle_time);
 }
 
-static int connection_create_pipes(Connection *c, int buffer[2], size_t *sz) {
+static int connection_create_pipes(Connection *c, int buffer[static 2], size_t *sz) {
         int r;
 
         assert(c);
@@ -139,7 +187,7 @@ static int connection_shovel(
                         if (z > 0) {
                                 *full += z;
                                 shoveled = true;
-                        } else if (z == 0 || IN_SET(errno, EPIPE, ECONNRESET)) {
+                        } else if (z == 0 || ERRNO_IS_DISCONNECT(errno)) {
                                 *from_source = sd_event_source_unref(*from_source);
                                 *from = safe_close(*from);
                         } else if (!IN_SET(errno, EAGAIN, EINTR))
@@ -151,7 +199,7 @@ static int connection_shovel(
                         if (z > 0) {
                                 *full -= z;
                                 shoveled = true;
-                        } else if (z == 0 || IN_SET(errno, EPIPE, ECONNRESET)) {
+                        } else if (z == 0 || ERRNO_IS_DISCONNECT(errno)) {
                                 *to_source = sd_event_source_unref(*to_source);
                                 *to = safe_close(*to);
                         } else if (!IN_SET(errno, EAGAIN, EINTR))
@@ -205,7 +253,7 @@ static int traffic_cb(sd_event_source *s, int fd, uint32_t revents, void *userda
         return 1;
 
 quit:
-        connection_free(c);
+        connection_release(c);
         return 0; /* ignore errors, continue serving */
 }
 
@@ -268,7 +316,7 @@ static int connection_complete(Connection *c) {
         return 0;
 
 fail:
-        connection_free(c);
+        connection_release(c);
         return 0; /* ignore errors, continue serving */
 }
 
@@ -298,7 +346,7 @@ static int connect_cb(sd_event_source *s, int fd, uint32_t revents, void *userda
         return connection_complete(c);
 
 fail:
-        connection_free(c);
+        connection_release(c);
         return 0; /* ignore errors, continue serving */
 }
 
@@ -342,13 +390,11 @@ static int connection_start(Connection *c, struct sockaddr *sa, socklen_t salen)
         return 0;
 
 fail:
-        connection_free(c);
+        connection_release(c);
         return 0; /* ignore errors, continue serving */
 }
 
-static int resolve_cb(sd_resolve_query *q, int ret, const struct addrinfo *ai, void *userdata) {
-        Connection *c = userdata;
-
+static int resolve_handler(sd_resolve_query *q, int ret, const struct addrinfo *ai, Connection *c) {
         assert(q);
         assert(c);
 
@@ -362,7 +408,7 @@ static int resolve_cb(sd_resolve_query *q, int ret, const struct addrinfo *ai, v
         return connection_start(c, ai->ai_addr, ai->ai_addrlen);
 
 fail:
-        connection_free(c);
+        connection_release(c);
         return 0; /* ignore errors, continue serving */
 }
 
@@ -374,20 +420,21 @@ static int resolve_remote(Connection *c) {
                 .ai_flags = AI_ADDRCONFIG
         };
 
-        union sockaddr_union sa = {};
         const char *node, *service;
         int r;
 
         if (IN_SET(arg_remote_host[0], '/', '@')) {
-                int salen;
+                union sockaddr_union sa;
+                int sa_len;
 
-                salen = sockaddr_un_set_path(&sa.un, arg_remote_host);
-                if (salen < 0) {
-                        log_error_errno(salen, "Specified address doesn't fit in an AF_UNIX address, refusing: %m");
+                r = sockaddr_un_set_path(&sa.un, arg_remote_host);
+                if (r < 0) {
+                        log_error_errno(r, "Specified address doesn't fit in an AF_UNIX address, refusing: %m");
                         goto fail;
                 }
+                sa_len = r;
 
-                return connection_start(c, &sa.sa, salen);
+                return connection_start(c, &sa.sa, sa_len);
         }
 
         service = strrchr(arg_remote_host, ':');
@@ -400,7 +447,7 @@ static int resolve_remote(Connection *c) {
         }
 
         log_debug("Looking up address info for %s:%s", node, service);
-        r = sd_resolve_getaddrinfo(c->context->resolve, &c->resolve_query, node, service, &hints, resolve_cb, c);
+        r = resolve_getaddrinfo(c->context->resolve, &c->resolve_query, node, service, &hints, resolve_handler, NULL, c);
         if (r < 0) {
                 log_error_errno(r, "Failed to resolve remote host: %m");
                 goto fail;
@@ -409,7 +456,7 @@ static int resolve_remote(Connection *c) {
         return 0;
 
 fail:
-        connection_free(c);
+        connection_release(c);
         return 0; /* ignore errors, continue serving */
 }
 
@@ -426,25 +473,27 @@ static int add_connection_socket(Context *context, int fd) {
                 return 0;
         }
 
-        r = set_ensure_allocated(&context->connections, NULL);
-        if (r < 0) {
-                log_oom();
-                return 0;
+        if (context->idle_time) {
+                r = sd_event_source_set_enabled(context->idle_time, SD_EVENT_OFF);
+                if (r < 0)
+                        log_warning_errno(r, "Unable to disable idle timer, continuing: %m");
         }
 
-        c = new0(Connection, 1);
+        c = new(Connection, 1);
         if (!c) {
                 log_oom();
                 return 0;
         }
 
-        c->context = context;
-        c->server_fd = fd;
-        c->client_fd = -1;
-        c->server_to_client_buffer[0] = c->server_to_client_buffer[1] = -1;
-        c->client_to_server_buffer[0] = c->client_to_server_buffer[1] = -1;
+        *c = (Connection) {
+               .context = context,
+               .server_fd = fd,
+               .client_fd = -1,
+               .server_to_client_buffer = {-1, -1},
+               .client_to_server_buffer = {-1, -1},
+        };
 
-        r = set_put(context->connections, c);
+        r = set_ensure_put(&context->connections, NULL, c);
         if (r < 0) {
                 free(c);
                 log_oom();
@@ -466,10 +515,10 @@ static int accept_cb(sd_event_source *s, int fd, uint32_t revents, void *userdat
 
         nfd = accept4(fd, NULL, NULL, SOCK_NONBLOCK|SOCK_CLOEXEC);
         if (nfd < 0) {
-                if (errno != -EAGAIN)
+                if (!ERRNO_IS_ACCEPT_AGAIN(errno))
                         log_warning_errno(errno, "Failed to accept() socket: %m");
         } else {
-                getpeername_pretty(nfd, true, &peer);
+                (void) getpeername_pretty(nfd, true, &peer);
                 log_debug("New connection from %s", strna(peer));
 
                 r = add_connection_socket(context, nfd);
@@ -496,12 +545,6 @@ static int add_listen_socket(Context *context, int fd) {
         assert(context);
         assert(fd >= 0);
 
-        r = set_ensure_allocated(&context->listen, NULL);
-        if (r < 0) {
-                log_oom();
-                return r;
-        }
-
         r = sd_is_socket(fd, 0, SOCK_STREAM, 1);
         if (r < 0)
                 return log_error_errno(r, "Failed to determine socket type: %m");
@@ -517,7 +560,7 @@ static int add_listen_socket(Context *context, int fd) {
         if (r < 0)
                 return log_error_errno(r, "Failed to add event source: %m");
 
-        r = set_put(context->listen, source);
+        r = set_ensure_put(&context->listen, NULL, source);
         if (r < 0) {
                 log_error_errno(r, "Failed to add source to set: %m");
                 sd_event_source_unref(source);
@@ -535,9 +578,13 @@ static int add_listen_socket(Context *context, int fd) {
 
 static int help(void) {
         _cleanup_free_ char *link = NULL;
+        _cleanup_free_ char *time_link = NULL;
         int r;
 
         r = terminal_urlify_man("systemd-socket-proxyd", "8", &link);
+        if (r < 0)
+                return log_oom();
+        r = terminal_urlify_man("systemd.time", "7", &time_link);
         if (r < 0)
                 return log_oom();
 
@@ -545,11 +592,14 @@ static int help(void) {
                "%1$s [SOCKET]\n\n"
                "Bidirectionally proxy local sockets to another (possibly remote) socket.\n\n"
                "  -c --connections-max=  Set the maximum number of connections to be accepted\n"
+               "     --exit-idle-time=   Exit when without a connection for this duration. See\n"
+               "                         the %3$s for time span format\n"
                "  -h --help              Show this help\n"
                "     --version           Show package version\n"
                "\nSee the %2$s for details.\n"
                , program_invocation_short_name
                , link
+               , time_link
         );
 
         return 0;
@@ -559,11 +609,13 @@ static int parse_argv(int argc, char *argv[]) {
 
         enum {
                 ARG_VERSION = 0x100,
+                ARG_EXIT_IDLE,
                 ARG_IGNORE_ENV
         };
 
         static const struct option options[] = {
                 { "connections-max", required_argument, NULL, 'c'           },
+                { "exit-idle-time",  required_argument, NULL, ARG_EXIT_IDLE },
                 { "help",            no_argument,       NULL, 'h'           },
                 { "version",         no_argument,       NULL, ARG_VERSION   },
                 {}
@@ -597,6 +649,12 @@ static int parse_argv(int argc, char *argv[]) {
 
                         break;
 
+                case ARG_EXIT_IDLE:
+                        r = parse_sec(optarg, &arg_exit_idle_time);
+                        if (r < 0)
+                                return log_error_errno(r, "Failed to parse --exit-idle-time= argument: %s", optarg);
+                        break;
+
                 case '?':
                         return -EINVAL;
 
@@ -616,8 +674,8 @@ static int parse_argv(int argc, char *argv[]) {
         return 1;
 }
 
-int main(int argc, char *argv[]) {
-        Context context = {};
+static int run(int argc, char *argv[]) {
+        _cleanup_(context_clear) Context context = {};
         int r, n, fd;
 
         log_parse_environment();
@@ -625,53 +683,41 @@ int main(int argc, char *argv[]) {
 
         r = parse_argv(argc, argv);
         if (r <= 0)
-                goto finish;
+                return r;
 
         r = sd_event_default(&context.event);
-        if (r < 0) {
-                log_error_errno(r, "Failed to allocate event loop: %m");
-                goto finish;
-        }
+        if (r < 0)
+                return log_error_errno(r, "Failed to allocate event loop: %m");
 
         r = sd_resolve_default(&context.resolve);
-        if (r < 0) {
-                log_error_errno(r, "Failed to allocate resolver: %m");
-                goto finish;
-        }
+        if (r < 0)
+                return log_error_errno(r, "Failed to allocate resolver: %m");
 
         r = sd_resolve_attach_event(context.resolve, context.event, 0);
-        if (r < 0) {
-                log_error_errno(r, "Failed to attach resolver: %m");
-                goto finish;
-        }
+        if (r < 0)
+                return log_error_errno(r, "Failed to attach resolver: %m");
 
         sd_event_set_watchdog(context.event, true);
 
-        n = sd_listen_fds(1);
-        if (n < 0) {
-                log_error("Failed to receive sockets from parent.");
-                r = n;
-                goto finish;
-        } else if (n == 0) {
-                log_error("Didn't get any sockets passed in.");
-                r = -EINVAL;
-                goto finish;
-        }
+        r = sd_listen_fds(1);
+        if (r < 0)
+                return log_error_errno(r, "Failed to receive sockets from parent.");
+        if (r == 0)
+                return log_error_errno(SYNTHETIC_ERRNO(EINVAL), "Didn't get any sockets passed in.");
+
+        n = r;
 
         for (fd = SD_LISTEN_FDS_START; fd < SD_LISTEN_FDS_START + n; fd++) {
                 r = add_listen_socket(&context, fd);
                 if (r < 0)
-                        goto finish;
+                        return r;
         }
 
         r = sd_event_loop(context.event);
-        if (r < 0) {
-                log_error_errno(r, "Failed to run event loop: %m");
-                goto finish;
-        }
-
-finish:
-        context_free(&context);
+        if (r < 0)
+                return log_error_errno(r, "Failed to run event loop: %m");
 
-        return r < 0 ? EXIT_FAILURE : EXIT_SUCCESS;
+        return 0;
 }
+
+DEFINE_MAIN_FUNCTION(run);