]> git.ipfire.org Git - thirdparty/systemd.git/blame_incremental - src/socket-proxy/socket-proxyd.c
man/systemd-sysext: list ephemeral/ephemeral-import in the list of options
[thirdparty/systemd.git] / src / socket-proxy / socket-proxyd.c
... / ...
CommitLineData
1/* SPDX-License-Identifier: LGPL-2.1-or-later */
2
3#include <fcntl.h>
4#include <getopt.h>
5#include <netdb.h>
6#include <stdio.h>
7#include <unistd.h>
8
9#include "sd-daemon.h"
10#include "sd-event.h"
11#include "sd-resolve.h"
12
13#include "alloc-util.h"
14#include "build.h"
15#include "daemon-util.h"
16#include "errno-util.h"
17#include "event-util.h"
18#include "fd-util.h"
19#include "log.h"
20#include "main-func.h"
21#include "parse-util.h"
22#include "pretty-print.h"
23#include "resolve-private.h"
24#include "set.h"
25#include "socket-util.h"
26#include "string-util.h"
27#include "time-util.h"
28
/* Pipe capacity requested via F_SETPIPE_SZ for each per-direction splice buffer. */
#define BUFFER_SIZE (256 * 1024)

/* Connection limit; overridden by -c/--connections-max=. */
static unsigned arg_connections_max = 256;
/* Proxy destination, taken from the single positional command line argument. */
static const char *arg_remote_host = NULL;
/* Idle period set by --exit-idle-time=; the idle timer is only armed when this is below USEC_INFINITY. */
static usec_t arg_exit_idle_time = USEC_INFINITY;
34
/* Process-wide proxy state shared by all connections. */
typedef struct Context {
        sd_event *event;                /* event loop all I/O and timer sources are attached to */
        sd_resolve *resolve;            /* resolver used to look up the remote host */
        sd_event_source *idle_time;     /* idle-exit timer, (re)armed when no connections are left */

        Set *listen;                    /* event sources of the listening sockets */
        Set *connections;               /* all live Connection objects */
} Context;
43
44typedef struct Connection {
45 Context *context;
46
47 int server_fd, client_fd;
48 int server_to_client_buffer[2]; /* a pipe */
49 int client_to_server_buffer[2]; /* a pipe */
50
51 size_t server_to_client_buffer_full, client_to_server_buffer_full;
52 size_t server_to_client_buffer_size, client_to_server_buffer_size;
53
54 sd_event_source *server_event_source, *client_event_source;
55
56 sd_resolve_query *resolve_query;
57} Connection;
58
59static Connection* connection_free(Connection *c) {
60 if (!c)
61 return NULL;
62
63 if (c->context)
64 set_remove(c->context->connections, c);
65
66 sd_event_source_unref(c->server_event_source);
67 sd_event_source_unref(c->client_event_source);
68
69 safe_close(c->server_fd);
70 safe_close(c->client_fd);
71
72 safe_close_pair(c->server_to_client_buffer);
73 safe_close_pair(c->client_to_server_buffer);
74
75 sd_resolve_query_unref(c->resolve_query);
76
77 return mfree(c);
78}
79
/* Provides connection_freep(), used with _cleanup_() on Connection pointers. */
DEFINE_TRIVIAL_CLEANUP_FUNC(Connection*, connection_free);

/* Hash ops for the set of live connections, keyed by pointer identity.
 * NOTE(review): presumably the destructor makes set_free() release remaining entries through
 * connection_free() - confirm against the hash ops macro definition. */
DEFINE_PRIVATE_HASH_OPS_WITH_VALUE_DESTRUCTOR(
                connection_hash_ops,
                void, trivial_hash_func, trivial_compare_func,
                Connection, connection_free);
86
/* Releases everything a Context holds; used as the _cleanup_() handler for the instance in run().
 * The Context structure itself is not freed. */
static void context_done(Context *context) {
        assert(context);

        /* Drop the listener and connection sets first ... */
        set_free(context->listen);
        set_free(context->connections);

        /* ... then our references on the event loop, the resolver and the idle timer. */
        sd_event_unref(context->event);
        sd_resolve_unref(context->resolve);
        sd_event_source_unref(context->idle_time);
}
97
98static int idle_time_cb(sd_event_source *s, uint64_t usec, void *userdata) {
99 Context *c = userdata;
100 int r;
101
102 if (!set_isempty(c->connections)) {
103 log_warning("Idle timer fired even though there are connections, ignoring");
104 return 0;
105 }
106
107 r = sd_event_exit(c->event, 0);
108 if (r < 0) {
109 log_warning_errno(r, "Error while stopping event loop, ignoring: %m");
110 return 0;
111 }
112 return 0;
113}
114
115static void context_reset_timer(Context *context) {
116 int r;
117
118 assert(context);
119
120 if (arg_exit_idle_time < USEC_INFINITY && set_isempty(context->connections)) {
121 r = event_reset_time_relative(
122 context->event, &context->idle_time, CLOCK_MONOTONIC,
123 arg_exit_idle_time, 0, idle_time_cb, context,
124 SD_EVENT_PRIORITY_NORMAL, "idle-timer", /* force_reset = */ true);
125 if (r < 0)
126 log_warning_errno(r, "Failed to reset idle timer, ignoring: %m");
127 }
128}
129
130static void connection_release(Connection *c) {
131 Context *context = ASSERT_PTR(ASSERT_PTR(c)->context);
132
133 connection_free(c);
134 context_reset_timer(context);
135}
136
137static int connection_create_pipes(Connection *c, int buffer[static 2], size_t *sz) {
138 int r;
139
140 assert(c);
141 assert(buffer);
142 assert(sz);
143
144 if (buffer[0] >= 0)
145 return 0;
146
147 r = pipe2(buffer, O_CLOEXEC|O_NONBLOCK);
148 if (r < 0)
149 return log_error_errno(errno, "Failed to allocate pipe buffer: %m");
150
151 (void) fcntl(buffer[0], F_SETPIPE_SZ, BUFFER_SIZE);
152
153 r = fcntl(buffer[0], F_GETPIPE_SZ);
154 if (r < 0)
155 return log_error_errno(errno, "Failed to get pipe buffer size: %m");
156
157 assert(r > 0);
158 *sz = r;
159
160 return 0;
161}
162
163static int connection_shovel(
164 Connection *c,
165 int *from, int buffer[2], int *to,
166 size_t *full, size_t *sz,
167 sd_event_source **from_source, sd_event_source **to_source) {
168
169 bool shoveled;
170
171 assert(c);
172 assert(from);
173 assert(buffer);
174 assert(buffer[0] >= 0);
175 assert(buffer[1] >= 0);
176 assert(to);
177 assert(full);
178 assert(sz);
179 assert(from_source);
180 assert(to_source);
181
182 do {
183 ssize_t z;
184
185 shoveled = false;
186
187 if (*full < *sz && *from >= 0 && *to >= 0) {
188 z = splice(*from, NULL, buffer[1], NULL, *sz - *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
189 if (z > 0) {
190 *full += z;
191 shoveled = true;
192 } else if (z == 0 || ERRNO_IS_DISCONNECT(errno)) {
193 *from_source = sd_event_source_unref(*from_source);
194 *from = safe_close(*from);
195 } else if (!ERRNO_IS_TRANSIENT(errno))
196 return log_error_errno(errno, "Failed to splice: %m");
197 }
198
199 if (*full > 0 && *to >= 0) {
200 z = splice(buffer[0], NULL, *to, NULL, *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
201 if (z > 0) {
202 *full -= z;
203 shoveled = true;
204 } else if (z == 0 || ERRNO_IS_DISCONNECT(errno)) {
205 *to_source = sd_event_source_unref(*to_source);
206 *to = safe_close(*to);
207 } else if (!ERRNO_IS_TRANSIENT(errno))
208 return log_error_errno(errno, "Failed to splice: %m");
209 }
210 } while (shoveled);
211
212 return 0;
213}
214
215static int connection_enable_event_sources(Connection *c);
216
217static int traffic_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
218 Connection *c = ASSERT_PTR(userdata);
219 int r;
220
221 assert(s);
222 assert(fd >= 0);
223
224 r = connection_shovel(c,
225 &c->server_fd, c->server_to_client_buffer, &c->client_fd,
226 &c->server_to_client_buffer_full, &c->server_to_client_buffer_size,
227 &c->server_event_source, &c->client_event_source);
228 if (r < 0)
229 goto quit;
230
231 r = connection_shovel(c,
232 &c->client_fd, c->client_to_server_buffer, &c->server_fd,
233 &c->client_to_server_buffer_full, &c->client_to_server_buffer_size,
234 &c->client_event_source, &c->server_event_source);
235 if (r < 0)
236 goto quit;
237
238 /* EOF on both sides? */
239 if (c->server_fd < 0 && c->client_fd < 0)
240 goto quit;
241
242 /* Server closed, and all data written to client? */
243 if (c->server_fd < 0 && c->server_to_client_buffer_full <= 0)
244 goto quit;
245
246 /* Client closed, and all data written to server? */
247 if (c->client_fd < 0 && c->client_to_server_buffer_full <= 0)
248 goto quit;
249
250 r = connection_enable_event_sources(c);
251 if (r < 0)
252 goto quit;
253
254 return 1;
255
256quit:
257 connection_release(c);
258 return 0; /* ignore errors, continue serving */
259}
260
261static int connection_enable_event_sources(Connection *c) {
262 uint32_t a = 0, b = 0;
263 int r;
264
265 assert(c);
266
267 if (c->server_to_client_buffer_full > 0)
268 b |= EPOLLOUT;
269 if (c->server_to_client_buffer_full < c->server_to_client_buffer_size)
270 a |= EPOLLIN;
271
272 if (c->client_to_server_buffer_full > 0)
273 a |= EPOLLOUT;
274 if (c->client_to_server_buffer_full < c->client_to_server_buffer_size)
275 b |= EPOLLIN;
276
277 if (c->server_event_source)
278 r = sd_event_source_set_io_events(c->server_event_source, a);
279 else if (c->server_fd >= 0)
280 r = sd_event_add_io(c->context->event, &c->server_event_source, c->server_fd, a, traffic_cb, c);
281 else
282 r = 0;
283
284 if (r < 0)
285 return log_error_errno(r, "Failed to set up server event source: %m");
286
287 if (c->client_event_source)
288 r = sd_event_source_set_io_events(c->client_event_source, b);
289 else if (c->client_fd >= 0)
290 r = sd_event_add_io(c->context->event, &c->client_event_source, c->client_fd, b, traffic_cb, c);
291 else
292 r = 0;
293
294 if (r < 0)
295 return log_error_errno(r, "Failed to set up client event source: %m");
296
297 return 0;
298}
299
300static int connection_complete(Connection *c) {
301 int r;
302
303 assert(c);
304
305 r = connection_create_pipes(c, c->server_to_client_buffer, &c->server_to_client_buffer_size);
306 if (r < 0)
307 return r;
308
309 r = connection_create_pipes(c, c->client_to_server_buffer, &c->client_to_server_buffer_size);
310 if (r < 0)
311 return r;
312
313 r = connection_enable_event_sources(c);
314 if (r < 0)
315 return r;
316
317 return 0;
318}
319
320static int connect_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
321 Connection *c = ASSERT_PTR(userdata);
322 socklen_t solen;
323 int error;
324
325 assert(s);
326 assert(fd >= 0);
327
328 solen = sizeof(error);
329 if (getsockopt(fd, SOL_SOCKET, SO_ERROR, &error, &solen) < 0) {
330 log_error_errno(errno, "Failed to issue SO_ERROR: %m");
331 goto fail;
332 }
333
334 if (error != 0) {
335 log_error_errno(error, "Failed to connect to remote host: %m");
336 goto fail;
337 }
338
339 c->client_event_source = sd_event_source_unref(c->client_event_source);
340
341 if (connection_complete(c) < 0)
342 goto fail;
343
344 return 0;
345
346fail:
347 connection_release(c);
348 return 0; /* ignore errors, continue serving */
349}
350
351static int connection_start(Connection *c, struct sockaddr *sa, socklen_t salen) {
352 int r;
353
354 assert(c);
355 assert(sa);
356 assert(salen);
357
358 c->client_fd = socket(sa->sa_family, SOCK_STREAM|SOCK_NONBLOCK|SOCK_CLOEXEC, 0);
359 if (c->client_fd < 0)
360 return log_error_errno(errno, "Failed to get remote socket: %m");
361
362 r = connect(c->client_fd, sa, salen);
363 if (r < 0) {
364 if (errno != EINPROGRESS)
365 return log_error_errno(errno, "Failed to connect to remote host: %m");
366
367 r = sd_event_add_io(c->context->event, &c->client_event_source, c->client_fd, EPOLLOUT, connect_cb, c);
368 if (r < 0)
369 return log_error_errno(r, "Failed to add connection socket: %m");
370
371 r = sd_event_source_set_enabled(c->client_event_source, SD_EVENT_ONESHOT);
372 if (r < 0)
373 return log_error_errno(r, "Failed to enable oneshot event source: %m");
374
375 return 0;
376 }
377
378 return connection_complete(c);
379}
380
381static int resolve_handler(sd_resolve_query *q, int ret, const struct addrinfo *ai, Connection *c) {
382 assert(q);
383 assert(c);
384
385 if (ret != 0) {
386 log_error("Failed to resolve host: %s", gai_strerror(ret));
387 goto fail;
388 }
389
390 c->resolve_query = sd_resolve_query_unref(c->resolve_query);
391
392 if (connection_start(c, ai->ai_addr, ai->ai_addrlen) < 0)
393 goto fail;
394
395 return 0;
396
397fail:
398 connection_release(c);
399 return 0; /* ignore errors, continue serving */
400}
401
402static int resolve_remote(Connection *c) {
403
404 static const struct addrinfo hints = {
405 .ai_family = AF_UNSPEC,
406 .ai_socktype = SOCK_STREAM,
407 };
408
409 const char *node, *service;
410 int r;
411
412 if (IN_SET(arg_remote_host[0], '/', '@')) {
413 union sockaddr_union sa;
414 int sa_len;
415
416 r = sockaddr_un_set_path(&sa.un, arg_remote_host);
417 if (r < 0)
418 return log_error_errno(r, "Specified address doesn't fit in an AF_UNIX address, refusing: %m");
419 sa_len = r;
420
421 return connection_start(c, &sa.sa, sa_len);
422 }
423
424 service = strrchr(arg_remote_host, ':');
425 if (service) {
426 node = strndupa_safe(arg_remote_host,
427 service - arg_remote_host);
428 service++;
429 } else {
430 node = arg_remote_host;
431 service = "80";
432 }
433
434 log_debug("Looking up address info for %s:%s", node, service);
435 r = resolve_getaddrinfo(c->context->resolve, &c->resolve_query, node, service, &hints, resolve_handler, NULL, c);
436 if (r < 0)
437 return log_error_errno(r, "Failed to resolve remote host: %m");
438
439 return 0;
440}
441
442static int context_add_connection(Context *context, int fd) {
443 int r;
444
445 assert(context);
446
447 _cleanup_close_ int nfd = accept4(fd, NULL, NULL, SOCK_NONBLOCK|SOCK_CLOEXEC);
448 if (nfd < 0) {
449 if (!ERRNO_IS_ACCEPT_AGAIN(errno))
450 log_warning_errno(errno, "Failed to accept() socket, ignoring: %m");
451
452 return -errno;
453 }
454
455 if (DEBUG_LOGGING) {
456 _cleanup_free_ char *peer = NULL;
457 (void) getpeername_pretty(nfd, true, &peer);
458 log_debug("New connection from %s", strna(peer));
459 }
460
461 if (set_size(context->connections) > arg_connections_max)
462 return log_warning_errno(SYNTHETIC_ERRNO(EBUSY), "Hit connection limit, refusing connection.");
463
464 r = sd_event_source_set_enabled(context->idle_time, SD_EVENT_OFF);
465 if (r < 0)
466 log_warning_errno(r, "Unable to disable idle timer, continuing: %m");
467
468 _cleanup_(connection_freep) Connection *c = new(Connection, 1);
469 if (!c)
470 return log_oom();
471
472 *c = (Connection) {
473 .server_fd = TAKE_FD(nfd),
474 .client_fd = -EBADF,
475 .server_to_client_buffer = EBADF_PAIR,
476 .client_to_server_buffer = EBADF_PAIR,
477 };
478
479 r = set_ensure_put(&context->connections, &connection_hash_ops, c);
480 if (r < 0)
481 return log_oom();
482
483 c->context = context;
484
485 r = resolve_remote(c);
486 if (r < 0)
487 return r;
488
489 TAKE_PTR(c);
490 return 0;
491}
492
493static int accept_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
494 Context *context = ASSERT_PTR(userdata);
495 int r;
496
497 assert(s);
498 assert(fd >= 0);
499 assert(revents & EPOLLIN);
500
501 if (context_add_connection(context, fd) < 0)
502 context_reset_timer(context);
503
504 r = sd_event_source_set_enabled(s, SD_EVENT_ONESHOT);
505 if (r < 0)
506 return log_error_errno(r, "Error while re-enabling listener with ONESHOT: %m");
507
508 return 1;
509}
510
511static int add_listen_socket(Context *context, int fd) {
512 int r;
513
514 assert(context);
515 assert(fd >= 0);
516
517 r = sd_is_socket(fd, 0, SOCK_STREAM, 1);
518 if (r < 0)
519 return log_error_errno(r, "Failed to determine socket type: %m");
520 if (r == 0)
521 return log_error_errno(SYNTHETIC_ERRNO(EINVAL),
522 "Passed in socket is not a stream socket.");
523
524 r = fd_nonblock(fd, true);
525 if (r < 0)
526 return log_error_errno(r, "Failed to mark file descriptor non-blocking: %m");
527
528 _cleanup_(sd_event_source_unrefp) sd_event_source *source = NULL;
529 r = sd_event_add_io(context->event, &source, fd, EPOLLIN, accept_cb, context);
530 if (r < 0)
531 return log_error_errno(r, "Failed to add event source: %m");
532
533 r = sd_event_source_set_exit_on_failure(source, true);
534 if (r < 0)
535 return log_error_errno(r, "Failed to enable exit-on-failure logic: %m");
536
537 /* Set the watcher to oneshot in case other processes are also
538 * watching to accept(). */
539 r = sd_event_source_set_enabled(source, SD_EVENT_ONESHOT);
540 if (r < 0)
541 return log_error_errno(r, "Failed to enable oneshot mode: %m");
542
543 r = set_ensure_consume(&context->listen, &event_source_hash_ops, TAKE_PTR(source));
544 if (r < 0)
545 return log_error_errno(r, "Failed to add source to set: %m");
546
547 return 0;
548}
549
550static int help(void) {
551 _cleanup_free_ char *link = NULL;
552 _cleanup_free_ char *time_link = NULL;
553 int r;
554
555 r = terminal_urlify_man("systemd-socket-proxyd", "8", &link);
556 if (r < 0)
557 return log_oom();
558 r = terminal_urlify_man("systemd.time", "7", &time_link);
559 if (r < 0)
560 return log_oom();
561
562 printf("%1$s [HOST:PORT]\n"
563 "%1$s [SOCKET]\n\n"
564 "%2$sBidirectionally proxy local sockets to another (possibly remote) socket.%3$s\n\n"
565 " -c --connections-max= Set the maximum number of connections to be accepted\n"
566 " --exit-idle-time= Exit when without a connection for this duration. See\n"
567 " the %4$s for time span format\n"
568 " -h --help Show this help\n"
569 " --version Show package version\n"
570 "\nSee the %5$s for details.\n",
571 program_invocation_short_name,
572 ansi_highlight(),
573 ansi_normal(),
574 time_link,
575 link);
576
577 return 0;
578}
579
580static int parse_argv(int argc, char *argv[]) {
581
582 enum {
583 ARG_VERSION = 0x100,
584 ARG_EXIT_IDLE,
585 ARG_IGNORE_ENV
586 };
587
588 static const struct option options[] = {
589 { "connections-max", required_argument, NULL, 'c' },
590 { "exit-idle-time", required_argument, NULL, ARG_EXIT_IDLE },
591 { "help", no_argument, NULL, 'h' },
592 { "version", no_argument, NULL, ARG_VERSION },
593 {}
594 };
595
596 int c, r;
597
598 assert(argc >= 0);
599 assert(argv);
600
601 while ((c = getopt_long(argc, argv, "c:h", options, NULL)) >= 0)
602
603 switch (c) {
604
605 case 'h':
606 return help();
607
608 case ARG_VERSION:
609 return version();
610
611 case 'c':
612 r = safe_atou(optarg, &arg_connections_max);
613 if (r < 0) {
614 log_error("Failed to parse --connections-max= argument: %s", optarg);
615 return r;
616 }
617
618 if (arg_connections_max < 1)
619 return log_error_errno(SYNTHETIC_ERRNO(EINVAL),
620 "Connection limit is too low.");
621
622 break;
623
624 case ARG_EXIT_IDLE:
625 r = parse_sec(optarg, &arg_exit_idle_time);
626 if (r < 0)
627 return log_error_errno(r, "Failed to parse --exit-idle-time= argument: %s", optarg);
628 break;
629
630 case '?':
631 return -EINVAL;
632
633 default:
634 assert_not_reached();
635 }
636
637 if (optind >= argc)
638 return log_error_errno(SYNTHETIC_ERRNO(EINVAL),
639 "Not enough parameters.");
640
641 if (argc != optind+1)
642 return log_error_errno(SYNTHETIC_ERRNO(EINVAL),
643 "Too many parameters.");
644
645 arg_remote_host = argv[optind];
646 return 1;
647}
648
649static int run(int argc, char *argv[]) {
650 _cleanup_(context_done) Context context = {};
651 _unused_ _cleanup_(notify_on_cleanup) const char *notify_stop = NULL;
652 int r, n, fd;
653
654 log_setup();
655
656 r = parse_argv(argc, argv);
657 if (r <= 0)
658 return r;
659
660 r = sd_event_default(&context.event);
661 if (r < 0)
662 return log_error_errno(r, "Failed to allocate event loop: %m");
663
664 r = sd_resolve_default(&context.resolve);
665 if (r < 0)
666 return log_error_errno(r, "Failed to allocate resolver: %m");
667
668 r = sd_resolve_attach_event(context.resolve, context.event, 0);
669 if (r < 0)
670 return log_error_errno(r, "Failed to attach resolver: %m");
671
672 sd_event_set_watchdog(context.event, true);
673
674 r = sd_listen_fds(1);
675 if (r < 0)
676 return log_error_errno(r, "Failed to receive sockets from parent.");
677 if (r == 0)
678 return log_error_errno(SYNTHETIC_ERRNO(EINVAL), "Didn't get any sockets passed in.");
679
680 n = r;
681
682 for (fd = SD_LISTEN_FDS_START; fd < SD_LISTEN_FDS_START + n; fd++) {
683 r = add_listen_socket(&context, fd);
684 if (r < 0)
685 return r;
686 }
687
688 notify_stop = notify_start(NOTIFY_READY_MESSAGE, NOTIFY_STOPPING_MESSAGE);
689 r = sd_event_loop(context.event);
690 if (r < 0)
691 return log_error_errno(r, "Failed to run event loop: %m");
692
693 return 0;
694}
695
696DEFINE_MAIN_FUNCTION(run);