]> git.ipfire.org Git - thirdparty/systemd.git/blobdiff - src/socket-proxy/socket-proxyd.c
tree-wide: drop license boilerplate
[thirdparty/systemd.git] / src / socket-proxy / socket-proxyd.c
index 52b4db8875fb2de71dab07cd3153b8b7b5fb28f6..9fa7359cb75540ce2f9c4f1fb8cedca327d3ca21 100644 (file)
@@ -1,20 +1,8 @@
+/* SPDX-License-Identifier: LGPL-2.1+ */
 /***
   This file is part of systemd.
 
   Copyright 2013 David Strauss
-
-  systemd is free software; you can redistribute it and/or modify it
-  under the terms of the GNU Lesser General Public License as published by
-  the Free Software Foundation; either version 2.1 of the License, or
-  (at your option) any later version.
-
-  systemd is distributed in the hope that it will be useful, but
-  WITHOUT ANY WARRANTY; without even the implied warranty of
-  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
-  Lesser General Public License for more details.
-
-  You should have received a copy of the GNU Lesser General Public License
-  along with systemd; If not, see <http://www.gnu.org/licenses/>.
  ***/
 
 #include <errno.h>
 #include "set.h"
 #include "socket-util.h"
 #include "string-util.h"
+#include "parse-util.h"
 #include "util.h"
 
 #define BUFFER_SIZE (256 * 1024)
-#define CONNECTIONS_MAX 256
+static unsigned arg_connections_max = 256;
 
 static const char *arg_remote_host = NULL;
 
@@ -90,19 +79,10 @@ static void connection_free(Connection *c) {
 }
 
 static void context_free(Context *context) {
-        sd_event_source *es;
-        Connection *c;
-
         assert(context);
 
-        while ((es = set_steal_first(context->listen)))
-                sd_event_source_unref(es);
-
-        while ((c = set_first(context->connections)))
-                connection_free(c);
-
-        set_free(context->listen);
-        set_free(context->connections);
+        set_free_with_destructor(context->listen, sd_event_source_unref);
+        set_free_with_destructor(context->connections, connection_free);
 
         sd_event_unref(context->event);
         sd_resolve_unref(context->resolve);
@@ -163,10 +143,10 @@ static int connection_shovel(
                         if (z > 0) {
                                 *full += z;
                                 shoveled = true;
-                        } else if (z == 0 || errno == EPIPE || errno == ECONNRESET) {
+                        } else if (z == 0 || IN_SET(errno, EPIPE, ECONNRESET)) {
                                 *from_source = sd_event_source_unref(*from_source);
                                 *from = safe_close(*from);
-                        } else if (errno != EAGAIN && errno != EINTR)
+                        } else if (!IN_SET(errno, EAGAIN, EINTR))
                                 return log_error_errno(errno, "Failed to splice: %m");
                 }
 
@@ -175,10 +155,10 @@ static int connection_shovel(
                         if (z > 0) {
                                 *full -= z;
                                 shoveled = true;
-                        } else if (z == 0 || errno == EPIPE || errno == ECONNRESET) {
+                        } else if (z == 0 || IN_SET(errno, EPIPE, ECONNRESET)) {
                                 *to_source = sd_event_source_unref(*to_source);
                                 *to = safe_close(*to);
-                        } else if (errno != EAGAIN && errno != EINTR)
+                        } else if (!IN_SET(errno, EAGAIN, EINTR))
                                 return log_error_errno(errno, "Failed to splice: %m");
                 }
         } while (shoveled);
@@ -445,7 +425,7 @@ static int add_connection_socket(Context *context, int fd) {
         assert(context);
         assert(fd >= 0);
 
-        if (set_size(context->connections) > CONNECTIONS_MAX) {
+        if (set_size(context->connections) > arg_connections_max) {
                 log_warning("Hit connection limit, refusing connection.");
                 safe_close(fd);
                 return 0;
@@ -563,6 +543,7 @@ static void help(void) {
         printf("%1$s [HOST:PORT]\n"
                "%1$s [SOCKET]\n\n"
                "Bidirectionally proxy local sockets to another (possibly remote) socket.\n\n"
+               "  -c --connections-max=  Set the maximum number of connections to be accepted\n"
                "  -h --help              Show this help\n"
                "     --version           Show package version\n",
                program_invocation_short_name);
@@ -576,17 +557,18 @@ static int parse_argv(int argc, char *argv[]) {
         };
 
         static const struct option options[] = {
-                { "help",       no_argument, NULL, 'h'           },
-                { "version",    no_argument, NULL, ARG_VERSION   },
+                { "connections-max", required_argument, NULL, 'c'           },
+                { "help",            no_argument,       NULL, 'h'           },
+                { "version",         no_argument,       NULL, ARG_VERSION   },
                 {}
         };
 
-        int c;
+        int c, r;
 
         assert(argc >= 0);
         assert(argv);
 
-        while ((c = getopt_long(argc, argv, "h", options, NULL)) >= 0)
+        while ((c = getopt_long(argc, argv, "c:h", options, NULL)) >= 0)
 
                 switch (c) {
 
@@ -594,6 +576,20 @@ static int parse_argv(int argc, char *argv[]) {
                         help();
                         return 0;
 
+                case 'c':
+                        r = safe_atou(optarg, &arg_connections_max);
+                        if (r < 0) {
+                                log_error("Failed to parse --connections-max= argument: %s", optarg);
+                                return r;
+                        }
+
+                        if (arg_connections_max < 1) {
+                                log_error("Connection limit is too low.");
+                                return -EINVAL;
+                        }
+
+                        break;
+
                 case ARG_VERSION:
                         return version();