]> git.ipfire.org Git - thirdparty/systemd.git/blobdiff - src/socket-proxy/socket-proxyd.c
tree-wide: drop license boilerplate
[thirdparty/systemd.git] / src / socket-proxy / socket-proxyd.c
index ca5931166204a3563644ef33f3e5780feb4e9d74..9fa7359cb75540ce2f9c4f1fb8cedca327d3ca21 100644 (file)
@@ -1,32 +1,17 @@
-/*-*- Mode: C; c-basic-offset: 8; indent-tabs-mode: nil -*-*/
-
+/* SPDX-License-Identifier: LGPL-2.1+ */
 /***
   This file is part of systemd.
 
   Copyright 2013 David Strauss
-
-  systemd is free software; you can redistribute it and/or modify it
-  under the terms of the GNU Lesser General Public License as published by
-  the Free Software Foundation; either version 2.1 of the License, or
-  (at your option) any later version.
-
-  systemd is distributed in the hope that it will be useful, but
-  WITHOUT ANY WARRANTY; without even the implied warranty of
-  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
-  Lesser General Public License for more details.
-
-  You should have received a copy of the GNU Lesser General Public License
-  along with systemd; If not, see <http://www.gnu.org/licenses/>.
  ***/
 
-#include <arpa/inet.h>
 #include <errno.h>
+#include <fcntl.h>
 #include <getopt.h>
+#include <netdb.h>
 #include <stdio.h>
 #include <stdlib.h>
 #include <string.h>
-#include <netdb.h>
-#include <fcntl.h>
 #include <sys/socket.h>
 #include <sys/un.h>
 #include <unistd.h>
 #include "sd-daemon.h"
 #include "sd-event.h"
 #include "sd-resolve.h"
+
+#include "alloc-util.h"
+#include "fd-util.h"
 #include "log.h"
+#include "path-util.h"
+#include "set.h"
 #include "socket-util.h"
+#include "string-util.h"
+#include "parse-util.h"
 #include "util.h"
-#include "event-util.h"
-#include "build.h"
-#include "set.h"
-#include "path-util.h"
 
 #define BUFFER_SIZE (256 * 1024)
-#define CONNECTIONS_MAX 256
+static unsigned arg_connections_max = 256;
 
 static const char *arg_remote_host = NULL;
 
@@ -91,19 +79,10 @@ static void connection_free(Connection *c) {
 }
 
 static void context_free(Context *context) {
-        sd_event_source *es;
-        Connection *c;
-
         assert(context);
 
-        while ((es = set_steal_first(context->listen)))
-                sd_event_source_unref(es);
-
-        while ((c = set_first(context->connections)))
-                connection_free(c);
-
-        set_free(context->listen);
-        set_free(context->connections);
+        set_free_with_destructor(context->listen, sd_event_source_unref);
+        set_free_with_destructor(context->connections, connection_free);
 
         sd_event_unref(context->event);
         sd_resolve_unref(context->resolve);
@@ -120,18 +99,14 @@ static int connection_create_pipes(Connection *c, int buffer[2], size_t *sz) {
                 return 0;
 
         r = pipe2(buffer, O_CLOEXEC|O_NONBLOCK);
-        if (r < 0) {
-                log_error("Failed to allocate pipe buffer: %m");
-                return -errno;
-        }
+        if (r < 0)
+                return log_error_errno(errno, "Failed to allocate pipe buffer: %m");
 
         (void) fcntl(buffer[0], F_SETPIPE_SZ, BUFFER_SIZE);
 
         r = fcntl(buffer[0], F_GETPIPE_SZ);
-        if (r < 0) {
-                log_error("Failed to get pipe buffer size: %m");
-                return -errno;
-        }
+        if (r < 0)
+                return log_error_errno(errno, "Failed to get pipe buffer size: %m");
 
         assert(r > 0);
         *sz = r;
@@ -168,13 +143,11 @@ static int connection_shovel(
                         if (z > 0) {
                                 *full += z;
                                 shoveled = true;
-                        } else if (z == 0 || errno == EPIPE || errno == ECONNRESET) {
+                        } else if (z == 0 || IN_SET(errno, EPIPE, ECONNRESET)) {
                                 *from_source = sd_event_source_unref(*from_source);
                                 *from = safe_close(*from);
-                        } else if (errno != EAGAIN && errno != EINTR) {
-                                log_error("Failed to splice: %m");
-                                return -errno;
-                        }
+                        } else if (!IN_SET(errno, EAGAIN, EINTR))
+                                return log_error_errno(errno, "Failed to splice: %m");
                 }
 
                 if (*full > 0 && *to >= 0) {
@@ -182,13 +155,11 @@ static int connection_shovel(
                         if (z > 0) {
                                 *full -= z;
                                 shoveled = true;
-                        } else if (z == 0 || errno == EPIPE || errno == ECONNRESET) {
+                        } else if (z == 0 || IN_SET(errno, EPIPE, ECONNRESET)) {
                                 *to_source = sd_event_source_unref(*to_source);
                                 *to = safe_close(*to);
-                        } else if (errno != EAGAIN && errno != EINTR) {
-                                log_error("Failed to splice: %m");
-                                return -errno;
-                        }
+                        } else if (!IN_SET(errno, EAGAIN, EINTR))
+                                return log_error_errno(errno, "Failed to splice: %m");
                 }
         } while (shoveled);
 
@@ -317,7 +288,7 @@ static int connect_cb(sd_event_source *s, int fd, uint32_t revents, void *userda
         solen = sizeof(error);
         r = getsockopt(fd, SOL_SOCKET, SO_ERROR, &error, &solen);
         if (r < 0) {
-                log_error("Failed to issue SO_ERROR: %m");
+                log_error_errno(errno, "Failed to issue SO_ERROR: %m");
                 goto fail;
         }
 
@@ -344,7 +315,7 @@ static int connection_start(Connection *c, struct sockaddr *sa, socklen_t salen)
 
         c->client_fd = socket(sa->sa_family, SOCK_STREAM|SOCK_NONBLOCK|SOCK_CLOEXEC, 0);
         if (c->client_fd < 0) {
-                log_error("Failed to get remote socket: %m");
+                log_error_errno(errno, "Failed to get remote socket: %m");
                 goto fail;
         }
 
@@ -363,7 +334,7 @@ static int connection_start(Connection *c, struct sockaddr *sa, socklen_t salen)
                                 goto fail;
                         }
                 } else {
-                        log_error("Failed to connect to remote host: %m");
+                        log_error_errno(errno, "Failed to connect to remote host: %m");
                         goto fail;
                 }
         } else {
@@ -409,34 +380,25 @@ static int resolve_remote(Connection *c) {
 
         union sockaddr_union sa = {};
         const char *node, *service;
-        socklen_t salen;
         int r;
 
         if (path_is_absolute(arg_remote_host)) {
                 sa.un.sun_family = AF_UNIX;
-                strncpy(sa.un.sun_path, arg_remote_host, sizeof(sa.un.sun_path)-1);
-                sa.un.sun_path[sizeof(sa.un.sun_path)-1] = 0;
-
-                salen = offsetof(union sockaddr_union, un.sun_path) + strlen(sa.un.sun_path);
-
-                return connection_start(c, &sa.sa, salen);
+                strncpy(sa.un.sun_path, arg_remote_host, sizeof(sa.un.sun_path));
+                return connection_start(c, &sa.sa, SOCKADDR_UN_LEN(sa.un));
         }
 
         if (arg_remote_host[0] == '@') {
                 sa.un.sun_family = AF_UNIX;
                 sa.un.sun_path[0] = 0;
-                strncpy(sa.un.sun_path+1, arg_remote_host+1, sizeof(sa.un.sun_path)-2);
-                sa.un.sun_path[sizeof(sa.un.sun_path)-1] = 0;
-
-                salen = offsetof(union sockaddr_union, un.sun_path) + 1 + strlen(sa.un.sun_path + 1);
-
-                return connection_start(c, &sa.sa, salen);
+                strncpy(sa.un.sun_path+1, arg_remote_host+1, sizeof(sa.un.sun_path)-1);
+                return connection_start(c, &sa.sa, SOCKADDR_UN_LEN(sa.un));
         }
 
         service = strrchr(arg_remote_host, ':');
         if (service) {
                 node = strndupa(arg_remote_host, service - arg_remote_host);
-                service ++;
+                service++;
         } else {
                 node = arg_remote_host;
                 service = "80";
@@ -463,7 +425,7 @@ static int add_connection_socket(Context *context, int fd) {
         assert(context);
         assert(fd >= 0);
 
-        if (set_size(context->connections) > CONNECTIONS_MAX) {
+        if (set_size(context->connections) > arg_connections_max) {
                 log_warning("Hit connection limit, refusing connection.");
                 safe_close(fd);
                 return 0;
@@ -510,9 +472,9 @@ static int accept_cb(sd_event_source *s, int fd, uint32_t revents, void *userdat
         nfd = accept4(fd, NULL, NULL, SOCK_NONBLOCK|SOCK_CLOEXEC);
         if (nfd < 0) {
                 if (errno != -EAGAIN)
-                        log_warning("Failed to accept() socket: %m");
+                        log_warning_errno(errno, "Failed to accept() socket: %m");
         } else {
-                getpeername_pretty(nfd, &peer);
+                getpeername_pretty(nfd, true, &peer);
                 log_debug("New connection from %s", strna(peer));
 
                 r = add_connection_socket(context, nfd);
@@ -581,6 +543,7 @@ static void help(void) {
         printf("%1$s [HOST:PORT]\n"
                "%1$s [SOCKET]\n\n"
                "Bidirectionally proxy local sockets to another (possibly remote) socket.\n\n"
+               "  -c --connections-max=  Set the maximum number of connections to be accepted\n"
                "  -h --help              Show this help\n"
                "     --version           Show package version\n",
                program_invocation_short_name);
@@ -594,17 +557,18 @@ static int parse_argv(int argc, char *argv[]) {
         };
 
         static const struct option options[] = {
-                { "help",       no_argument, NULL, 'h'           },
-                { "version",    no_argument, NULL, ARG_VERSION   },
+                { "connections-max", required_argument, NULL, 'c'           },
+                { "help",            no_argument,       NULL, 'h'           },
+                { "version",         no_argument,       NULL, ARG_VERSION   },
                 {}
         };
 
-        int c;
+        int c, r;
 
         assert(argc >= 0);
         assert(argv);
 
-        while ((c = getopt_long(argc, argv, "h", options, NULL)) >= 0)
+        while ((c = getopt_long(argc, argv, "c:h", options, NULL)) >= 0)
 
                 switch (c) {
 
@@ -612,10 +576,22 @@ static int parse_argv(int argc, char *argv[]) {
                         help();
                         return 0;
 
+                case 'c':
+                        r = safe_atou(optarg, &arg_connections_max);
+                        if (r < 0) {
+                                log_error("Failed to parse --connections-max= argument: %s", optarg);
+                                return r;
+                        }
+
+                        if (arg_connections_max < 1) {
+                                log_error("Connection limit is too low.");
+                                return -EINVAL;
+                        }
+
+                        break;
+
                 case ARG_VERSION:
-                        puts(PACKAGE_STRING);
-                        puts(SYSTEMD_FEATURES);
-                        return 0;
+                        return version();
 
                 case '?':
                         return -EINVAL;