]> git.ipfire.org Git - thirdparty/systemd.git/commitdiff
socket-proxy: use hash_ops with destructor for managing Connection
authorYu Watanabe <watanabe.yu+github@gmail.com>
Sat, 12 Apr 2025 16:16:00 +0000 (01:16 +0900)
committerYu Watanabe <watanabe.yu+github@gmail.com>
Sat, 12 Apr 2025 19:28:35 +0000 (04:28 +0900)
This also renames context_clear() -> context_done(), to follow our
recent coding style.

src/socket-proxy/socket-proxyd.c

index 40bdfe070e04fa4335bffd445a973bec8aeaec45..bf4264a94cc475906d7732c5e6340523405bcdd1 100644 (file)
@@ -80,6 +80,24 @@ static Connection* connection_free(Connection *c) {
         return mfree(c);
 }
 
+DEFINE_TRIVIAL_CLEANUP_FUNC(Connection*, connection_free);
+
+DEFINE_PRIVATE_HASH_OPS_WITH_VALUE_DESTRUCTOR(
+                connection_hash_ops,
+                void, trivial_hash_func, trivial_compare_func,
+                Connection, connection_free);
+
+static void context_done(Context *context) {
+        assert(context);
+
+        set_free_with_destructor(context->listen, sd_event_source_unref);
+        set_free(context->connections);
+
+        sd_event_unref(context->event);
+        sd_resolve_unref(context->resolve);
+        sd_event_source_unref(context->idle_time);
+}
+
 static int idle_time_cb(sd_event_source *s, uint64_t usec, void *userdata) {
         Context *c = userdata;
         int r;
@@ -119,17 +137,6 @@ static void connection_release(Connection *c) {
         context_reset_timer(c->context);
 }
 
-static void context_clear(Context *context) {
-        assert(context);
-
-        set_free_with_destructor(context->listen, sd_event_source_unref);
-        set_free_with_destructor(context->connections, connection_free);
-
-        sd_event_unref(context->event);
-        sd_resolve_unref(context->resolve);
-        sd_event_source_unref(context->idle_time);
-}
-
 static int connection_create_pipes(Connection *c, int buffer[static 2], size_t *sz) {
         int r;
 
@@ -456,70 +463,62 @@ fail:
         return 0; /* ignore errors, continue serving */
 }
 
-static int add_connection_socket(Context *context, int fd) {
-        Connection *c;
+static int context_add_connection(Context *context, int fd) {
         int r;
 
         assert(context);
-        assert(fd >= 0);
 
-        if (set_size(context->connections) > arg_connections_max) {
-                log_warning("Hit connection limit, refusing connection.");
-                safe_close(fd);
-                return 0;
+        _cleanup_close_ int nfd = accept4(fd, NULL, NULL, SOCK_NONBLOCK|SOCK_CLOEXEC);
+        if (nfd < 0) {
+                if (!ERRNO_IS_ACCEPT_AGAIN(errno))
+                        log_warning_errno(errno, "Failed to accept() socket, ignoring: %m");
+
+                return -errno;
         }
 
+        if (DEBUG_LOGGING) {
+                _cleanup_free_ char *peer = NULL;
+                (void) getpeername_pretty(nfd, true, &peer);
+                log_debug("New connection from %s", strna(peer));
+        }
+
+        if (set_size(context->connections) > arg_connections_max)
+                return log_warning_errno(SYNTHETIC_ERRNO(EBUSY), "Hit connection limit, refusing connection.");
+
         r = sd_event_source_set_enabled(context->idle_time, SD_EVENT_OFF);
         if (r < 0)
                 log_warning_errno(r, "Unable to disable idle timer, continuing: %m");
 
-        c = new(Connection, 1);
-        if (!c) {
-                log_oom();
-                return 0;
-        }
+        _cleanup_(connection_freep) Connection *c = new(Connection, 1);
+        if (!c)
+                return log_oom();
 
         *c = (Connection) {
-               .context = context,
-               .server_fd = fd,
-               .client_fd = -EBADF,
-               .server_to_client_buffer = EBADF_PAIR,
-               .client_to_server_buffer = EBADF_PAIR,
+                .server_fd = TAKE_FD(nfd),
+                .client_fd = -EBADF,
+                .server_to_client_buffer = EBADF_PAIR,
+                .client_to_server_buffer = EBADF_PAIR,
         };
 
-        r = set_ensure_put(&context->connections, NULL, c);
-        if (r < 0) {
-                free(c);
-                log_oom();
-                return 0;
-        }
+        r = set_ensure_put(&context->connections, &connection_hash_ops, c);
+        if (r < 0)
+                return log_oom();
+
+        c->context = context;
 
-        return resolve_remote(c);
+        return resolve_remote(TAKE_PTR(c));
 }
 
 static int accept_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
-        _cleanup_free_ char *peer = NULL;
         Context *context = ASSERT_PTR(userdata);
-        int nfd = -EBADF, r;
+        int r;
 
         assert(s);
         assert(fd >= 0);
         assert(revents & EPOLLIN);
 
-        nfd = accept4(fd, NULL, NULL, SOCK_NONBLOCK|SOCK_CLOEXEC);
-        if (nfd < 0) {
-                if (!ERRNO_IS_ACCEPT_AGAIN(errno))
-                        log_warning_errno(errno, "Failed to accept() socket: %m");
-        } else {
-                (void) getpeername_pretty(nfd, true, &peer);
-                log_debug("New connection from %s", strna(peer));
-
-                r = add_connection_socket(context, nfd);
-                if (r < 0) {
-                        log_warning_errno(r, "Failed to accept connection, ignoring: %m");
-                        safe_close(nfd);
-                }
-        }
+        if (context_add_connection(context, fd) < 0)
+                context_reset_timer(context);
 
         r = sd_event_source_set_enabled(s, SD_EVENT_ONESHOT);
         if (r < 0)
@@ -669,7 +668,7 @@ static int parse_argv(int argc, char *argv[]) {
 }
 
 static int run(int argc, char *argv[]) {
-        _cleanup_(context_clear) Context context = {};
+        _cleanup_(context_done) Context context = {};
         _unused_ _cleanup_(notify_on_cleanup) const char *notify_stop = NULL;
         int r, n, fd;