]> git.ipfire.org Git - thirdparty/systemd.git/blob - src/socket-proxy/socket-proxyd.c
Merge pull request #12013 from yuwata/fix-switchroot-11997
[thirdparty/systemd.git] / src / socket-proxy / socket-proxyd.c
1 /* SPDX-License-Identifier: LGPL-2.1+ */
2
3 #include <errno.h>
4 #include <fcntl.h>
5 #include <getopt.h>
6 #include <netdb.h>
7 #include <stdio.h>
8 #include <stdlib.h>
9 #include <string.h>
10 #include <sys/socket.h>
11 #include <sys/un.h>
12 #include <unistd.h>
13
14 #include "sd-daemon.h"
15 #include "sd-event.h"
16 #include "sd-resolve.h"
17
18 #include "alloc-util.h"
19 #include "errno-util.h"
20 #include "fd-util.h"
21 #include "log.h"
22 #include "main-func.h"
23 #include "parse-util.h"
24 #include "path-util.h"
25 #include "pretty-print.h"
26 #include "resolve-private.h"
27 #include "set.h"
28 #include "socket-util.h"
29 #include "string-util.h"
30 #include "util.h"
31
/* Capacity, in bytes, requested for each per-direction pipe buffer. */
#define BUFFER_SIZE (256 * 1024)

/* Connection limit; can be overridden with -c/--connections-max=. */
static unsigned arg_connections_max = 256;

/* The single positional command line argument: where to proxy to. */
static const char *arg_remote_host = NULL;
36
37 typedef struct Context {
38 sd_event *event;
39 sd_resolve *resolve;
40
41 Set *listen;
42 Set *connections;
43 } Context;
44
45 typedef struct Connection {
46 Context *context;
47
48 int server_fd, client_fd;
49 int server_to_client_buffer[2]; /* a pipe */
50 int client_to_server_buffer[2]; /* a pipe */
51
52 size_t server_to_client_buffer_full, client_to_server_buffer_full;
53 size_t server_to_client_buffer_size, client_to_server_buffer_size;
54
55 sd_event_source *server_event_source, *client_event_source;
56
57 sd_resolve_query *resolve_query;
58 } Connection;
59
60 static void connection_free(Connection *c) {
61 assert(c);
62
63 if (c->context)
64 set_remove(c->context->connections, c);
65
66 sd_event_source_unref(c->server_event_source);
67 sd_event_source_unref(c->client_event_source);
68
69 safe_close(c->server_fd);
70 safe_close(c->client_fd);
71
72 safe_close_pair(c->server_to_client_buffer);
73 safe_close_pair(c->client_to_server_buffer);
74
75 sd_resolve_query_unref(c->resolve_query);
76
77 free(c);
78 }
79
80 static void context_clear(Context *context) {
81 assert(context);
82
83 set_free_with_destructor(context->listen, sd_event_source_unref);
84 set_free_with_destructor(context->connections, connection_free);
85
86 sd_event_unref(context->event);
87 sd_resolve_unref(context->resolve);
88 }
89
90 static int connection_create_pipes(Connection *c, int buffer[static 2], size_t *sz) {
91 int r;
92
93 assert(c);
94 assert(buffer);
95 assert(sz);
96
97 if (buffer[0] >= 0)
98 return 0;
99
100 r = pipe2(buffer, O_CLOEXEC|O_NONBLOCK);
101 if (r < 0)
102 return log_error_errno(errno, "Failed to allocate pipe buffer: %m");
103
104 (void) fcntl(buffer[0], F_SETPIPE_SZ, BUFFER_SIZE);
105
106 r = fcntl(buffer[0], F_GETPIPE_SZ);
107 if (r < 0)
108 return log_error_errno(errno, "Failed to get pipe buffer size: %m");
109
110 assert(r > 0);
111 *sz = r;
112
113 return 0;
114 }
115
116 static int connection_shovel(
117 Connection *c,
118 int *from, int buffer[2], int *to,
119 size_t *full, size_t *sz,
120 sd_event_source **from_source, sd_event_source **to_source) {
121
122 bool shoveled;
123
124 assert(c);
125 assert(from);
126 assert(buffer);
127 assert(buffer[0] >= 0);
128 assert(buffer[1] >= 0);
129 assert(to);
130 assert(full);
131 assert(sz);
132 assert(from_source);
133 assert(to_source);
134
135 do {
136 ssize_t z;
137
138 shoveled = false;
139
140 if (*full < *sz && *from >= 0 && *to >= 0) {
141 z = splice(*from, NULL, buffer[1], NULL, *sz - *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
142 if (z > 0) {
143 *full += z;
144 shoveled = true;
145 } else if (z == 0 || ERRNO_IS_DISCONNECT(errno)) {
146 *from_source = sd_event_source_unref(*from_source);
147 *from = safe_close(*from);
148 } else if (!IN_SET(errno, EAGAIN, EINTR))
149 return log_error_errno(errno, "Failed to splice: %m");
150 }
151
152 if (*full > 0 && *to >= 0) {
153 z = splice(buffer[0], NULL, *to, NULL, *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
154 if (z > 0) {
155 *full -= z;
156 shoveled = true;
157 } else if (z == 0 || ERRNO_IS_DISCONNECT(errno)) {
158 *to_source = sd_event_source_unref(*to_source);
159 *to = safe_close(*to);
160 } else if (!IN_SET(errno, EAGAIN, EINTR))
161 return log_error_errno(errno, "Failed to splice: %m");
162 }
163 } while (shoveled);
164
165 return 0;
166 }
167
168 static int connection_enable_event_sources(Connection *c);
169
170 static int traffic_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
171 Connection *c = userdata;
172 int r;
173
174 assert(s);
175 assert(fd >= 0);
176 assert(c);
177
178 r = connection_shovel(c,
179 &c->server_fd, c->server_to_client_buffer, &c->client_fd,
180 &c->server_to_client_buffer_full, &c->server_to_client_buffer_size,
181 &c->server_event_source, &c->client_event_source);
182 if (r < 0)
183 goto quit;
184
185 r = connection_shovel(c,
186 &c->client_fd, c->client_to_server_buffer, &c->server_fd,
187 &c->client_to_server_buffer_full, &c->client_to_server_buffer_size,
188 &c->client_event_source, &c->server_event_source);
189 if (r < 0)
190 goto quit;
191
192 /* EOF on both sides? */
193 if (c->server_fd == -1 && c->client_fd == -1)
194 goto quit;
195
196 /* Server closed, and all data written to client? */
197 if (c->server_fd == -1 && c->server_to_client_buffer_full <= 0)
198 goto quit;
199
200 /* Client closed, and all data written to server? */
201 if (c->client_fd == -1 && c->client_to_server_buffer_full <= 0)
202 goto quit;
203
204 r = connection_enable_event_sources(c);
205 if (r < 0)
206 goto quit;
207
208 return 1;
209
210 quit:
211 connection_free(c);
212 return 0; /* ignore errors, continue serving */
213 }
214
215 static int connection_enable_event_sources(Connection *c) {
216 uint32_t a = 0, b = 0;
217 int r;
218
219 assert(c);
220
221 if (c->server_to_client_buffer_full > 0)
222 b |= EPOLLOUT;
223 if (c->server_to_client_buffer_full < c->server_to_client_buffer_size)
224 a |= EPOLLIN;
225
226 if (c->client_to_server_buffer_full > 0)
227 a |= EPOLLOUT;
228 if (c->client_to_server_buffer_full < c->client_to_server_buffer_size)
229 b |= EPOLLIN;
230
231 if (c->server_event_source)
232 r = sd_event_source_set_io_events(c->server_event_source, a);
233 else if (c->server_fd >= 0)
234 r = sd_event_add_io(c->context->event, &c->server_event_source, c->server_fd, a, traffic_cb, c);
235 else
236 r = 0;
237
238 if (r < 0)
239 return log_error_errno(r, "Failed to set up server event source: %m");
240
241 if (c->client_event_source)
242 r = sd_event_source_set_io_events(c->client_event_source, b);
243 else if (c->client_fd >= 0)
244 r = sd_event_add_io(c->context->event, &c->client_event_source, c->client_fd, b, traffic_cb, c);
245 else
246 r = 0;
247
248 if (r < 0)
249 return log_error_errno(r, "Failed to set up client event source: %m");
250
251 return 0;
252 }
253
254 static int connection_complete(Connection *c) {
255 int r;
256
257 assert(c);
258
259 r = connection_create_pipes(c, c->server_to_client_buffer, &c->server_to_client_buffer_size);
260 if (r < 0)
261 goto fail;
262
263 r = connection_create_pipes(c, c->client_to_server_buffer, &c->client_to_server_buffer_size);
264 if (r < 0)
265 goto fail;
266
267 r = connection_enable_event_sources(c);
268 if (r < 0)
269 goto fail;
270
271 return 0;
272
273 fail:
274 connection_free(c);
275 return 0; /* ignore errors, continue serving */
276 }
277
278 static int connect_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
279 Connection *c = userdata;
280 socklen_t solen;
281 int error, r;
282
283 assert(s);
284 assert(fd >= 0);
285 assert(c);
286
287 solen = sizeof(error);
288 r = getsockopt(fd, SOL_SOCKET, SO_ERROR, &error, &solen);
289 if (r < 0) {
290 log_error_errno(errno, "Failed to issue SO_ERROR: %m");
291 goto fail;
292 }
293
294 if (error != 0) {
295 log_error_errno(error, "Failed to connect to remote host: %m");
296 goto fail;
297 }
298
299 c->client_event_source = sd_event_source_unref(c->client_event_source);
300
301 return connection_complete(c);
302
303 fail:
304 connection_free(c);
305 return 0; /* ignore errors, continue serving */
306 }
307
308 static int connection_start(Connection *c, struct sockaddr *sa, socklen_t salen) {
309 int r;
310
311 assert(c);
312 assert(sa);
313 assert(salen);
314
315 c->client_fd = socket(sa->sa_family, SOCK_STREAM|SOCK_NONBLOCK|SOCK_CLOEXEC, 0);
316 if (c->client_fd < 0) {
317 log_error_errno(errno, "Failed to get remote socket: %m");
318 goto fail;
319 }
320
321 r = connect(c->client_fd, sa, salen);
322 if (r < 0) {
323 if (errno == EINPROGRESS) {
324 r = sd_event_add_io(c->context->event, &c->client_event_source, c->client_fd, EPOLLOUT, connect_cb, c);
325 if (r < 0) {
326 log_error_errno(r, "Failed to add connection socket: %m");
327 goto fail;
328 }
329
330 r = sd_event_source_set_enabled(c->client_event_source, SD_EVENT_ONESHOT);
331 if (r < 0) {
332 log_error_errno(r, "Failed to enable oneshot event source: %m");
333 goto fail;
334 }
335 } else {
336 log_error_errno(errno, "Failed to connect to remote host: %m");
337 goto fail;
338 }
339 } else {
340 r = connection_complete(c);
341 if (r < 0)
342 goto fail;
343 }
344
345 return 0;
346
347 fail:
348 connection_free(c);
349 return 0; /* ignore errors, continue serving */
350 }
351
352 static int resolve_handler(sd_resolve_query *q, int ret, const struct addrinfo *ai, Connection *c) {
353 assert(q);
354 assert(c);
355
356 if (ret != 0) {
357 log_error("Failed to resolve host: %s", gai_strerror(ret));
358 goto fail;
359 }
360
361 c->resolve_query = sd_resolve_query_unref(c->resolve_query);
362
363 return connection_start(c, ai->ai_addr, ai->ai_addrlen);
364
365 fail:
366 connection_free(c);
367 return 0; /* ignore errors, continue serving */
368 }
369
370 static int resolve_remote(Connection *c) {
371
372 static const struct addrinfo hints = {
373 .ai_family = AF_UNSPEC,
374 .ai_socktype = SOCK_STREAM,
375 .ai_flags = AI_ADDRCONFIG
376 };
377
378 union sockaddr_union sa = {};
379 const char *node, *service;
380 int r;
381
382 if (IN_SET(arg_remote_host[0], '/', '@')) {
383 int salen;
384
385 salen = sockaddr_un_set_path(&sa.un, arg_remote_host);
386 if (salen < 0) {
387 log_error_errno(salen, "Specified address doesn't fit in an AF_UNIX address, refusing: %m");
388 goto fail;
389 }
390
391 return connection_start(c, &sa.sa, salen);
392 }
393
394 service = strrchr(arg_remote_host, ':');
395 if (service) {
396 node = strndupa(arg_remote_host, service - arg_remote_host);
397 service++;
398 } else {
399 node = arg_remote_host;
400 service = "80";
401 }
402
403 log_debug("Looking up address info for %s:%s", node, service);
404 r = resolve_getaddrinfo(c->context->resolve, &c->resolve_query, node, service, &hints, resolve_handler, NULL, c);
405 if (r < 0) {
406 log_error_errno(r, "Failed to resolve remote host: %m");
407 goto fail;
408 }
409
410 return 0;
411
412 fail:
413 connection_free(c);
414 return 0; /* ignore errors, continue serving */
415 }
416
417 static int add_connection_socket(Context *context, int fd) {
418 Connection *c;
419 int r;
420
421 assert(context);
422 assert(fd >= 0);
423
424 if (set_size(context->connections) > arg_connections_max) {
425 log_warning("Hit connection limit, refusing connection.");
426 safe_close(fd);
427 return 0;
428 }
429
430 r = set_ensure_allocated(&context->connections, NULL);
431 if (r < 0) {
432 log_oom();
433 return 0;
434 }
435
436 c = new0(Connection, 1);
437 if (!c) {
438 log_oom();
439 return 0;
440 }
441
442 c->context = context;
443 c->server_fd = fd;
444 c->client_fd = -1;
445 c->server_to_client_buffer[0] = c->server_to_client_buffer[1] = -1;
446 c->client_to_server_buffer[0] = c->client_to_server_buffer[1] = -1;
447
448 r = set_put(context->connections, c);
449 if (r < 0) {
450 free(c);
451 log_oom();
452 return 0;
453 }
454
455 return resolve_remote(c);
456 }
457
458 static int accept_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
459 _cleanup_free_ char *peer = NULL;
460 Context *context = userdata;
461 int nfd = -1, r;
462
463 assert(s);
464 assert(fd >= 0);
465 assert(revents & EPOLLIN);
466 assert(context);
467
468 nfd = accept4(fd, NULL, NULL, SOCK_NONBLOCK|SOCK_CLOEXEC);
469 if (nfd < 0) {
470 if (errno != -EAGAIN)
471 log_warning_errno(errno, "Failed to accept() socket: %m");
472 } else {
473 getpeername_pretty(nfd, true, &peer);
474 log_debug("New connection from %s", strna(peer));
475
476 r = add_connection_socket(context, nfd);
477 if (r < 0) {
478 log_error_errno(r, "Failed to accept connection, ignoring: %m");
479 safe_close(fd);
480 }
481 }
482
483 r = sd_event_source_set_enabled(s, SD_EVENT_ONESHOT);
484 if (r < 0) {
485 log_error_errno(r, "Error while re-enabling listener with ONESHOT: %m");
486 sd_event_exit(context->event, r);
487 return r;
488 }
489
490 return 1;
491 }
492
493 static int add_listen_socket(Context *context, int fd) {
494 sd_event_source *source;
495 int r;
496
497 assert(context);
498 assert(fd >= 0);
499
500 r = set_ensure_allocated(&context->listen, NULL);
501 if (r < 0) {
502 log_oom();
503 return r;
504 }
505
506 r = sd_is_socket(fd, 0, SOCK_STREAM, 1);
507 if (r < 0)
508 return log_error_errno(r, "Failed to determine socket type: %m");
509 if (r == 0)
510 return log_error_errno(SYNTHETIC_ERRNO(EINVAL),
511 "Passed in socket is not a stream socket.");
512
513 r = fd_nonblock(fd, true);
514 if (r < 0)
515 return log_error_errno(r, "Failed to mark file descriptor non-blocking: %m");
516
517 r = sd_event_add_io(context->event, &source, fd, EPOLLIN, accept_cb, context);
518 if (r < 0)
519 return log_error_errno(r, "Failed to add event source: %m");
520
521 r = set_put(context->listen, source);
522 if (r < 0) {
523 log_error_errno(r, "Failed to add source to set: %m");
524 sd_event_source_unref(source);
525 return r;
526 }
527
528 /* Set the watcher to oneshot in case other processes are also
529 * watching to accept(). */
530 r = sd_event_source_set_enabled(source, SD_EVENT_ONESHOT);
531 if (r < 0)
532 return log_error_errno(r, "Failed to enable oneshot mode: %m");
533
534 return 0;
535 }
536
537 static int help(void) {
538 _cleanup_free_ char *link = NULL;
539 int r;
540
541 r = terminal_urlify_man("systemd-socket-proxyd", "8", &link);
542 if (r < 0)
543 return log_oom();
544
545 printf("%1$s [HOST:PORT]\n"
546 "%1$s [SOCKET]\n\n"
547 "Bidirectionally proxy local sockets to another (possibly remote) socket.\n\n"
548 " -c --connections-max= Set the maximum number of connections to be accepted\n"
549 " -h --help Show this help\n"
550 " --version Show package version\n"
551 "\nSee the %2$s for details.\n"
552 , program_invocation_short_name
553 , link
554 );
555
556 return 0;
557 }
558
559 static int parse_argv(int argc, char *argv[]) {
560
561 enum {
562 ARG_VERSION = 0x100,
563 ARG_IGNORE_ENV
564 };
565
566 static const struct option options[] = {
567 { "connections-max", required_argument, NULL, 'c' },
568 { "help", no_argument, NULL, 'h' },
569 { "version", no_argument, NULL, ARG_VERSION },
570 {}
571 };
572
573 int c, r;
574
575 assert(argc >= 0);
576 assert(argv);
577
578 while ((c = getopt_long(argc, argv, "c:h", options, NULL)) >= 0)
579
580 switch (c) {
581
582 case 'h':
583 return help();
584
585 case ARG_VERSION:
586 return version();
587
588 case 'c':
589 r = safe_atou(optarg, &arg_connections_max);
590 if (r < 0) {
591 log_error("Failed to parse --connections-max= argument: %s", optarg);
592 return r;
593 }
594
595 if (arg_connections_max < 1)
596 return log_error_errno(SYNTHETIC_ERRNO(EINVAL),
597 "Connection limit is too low.");
598
599 break;
600
601 case '?':
602 return -EINVAL;
603
604 default:
605 assert_not_reached("Unhandled option");
606 }
607
608 if (optind >= argc)
609 return log_error_errno(SYNTHETIC_ERRNO(EINVAL),
610 "Not enough parameters.");
611
612 if (argc != optind+1)
613 return log_error_errno(SYNTHETIC_ERRNO(EINVAL),
614 "Too many parameters.");
615
616 arg_remote_host = argv[optind];
617 return 1;
618 }
619
620 static int run(int argc, char *argv[]) {
621 _cleanup_(context_clear) Context context = {};
622 int r, n, fd;
623
624 log_parse_environment();
625 log_open();
626
627 r = parse_argv(argc, argv);
628 if (r <= 0)
629 return r;
630
631 r = sd_event_default(&context.event);
632 if (r < 0)
633 return log_error_errno(r, "Failed to allocate event loop: %m");
634
635 r = sd_resolve_default(&context.resolve);
636 if (r < 0)
637 return log_error_errno(r, "Failed to allocate resolver: %m");
638
639 r = sd_resolve_attach_event(context.resolve, context.event, 0);
640 if (r < 0)
641 return log_error_errno(r, "Failed to attach resolver: %m");
642
643 sd_event_set_watchdog(context.event, true);
644
645 r = sd_listen_fds(1);
646 if (r < 0)
647 return log_error_errno(r, "Failed to receive sockets from parent.");
648 if (r == 0)
649 return log_error_errno(SYNTHETIC_ERRNO(EINVAL), "Didn't get any sockets passed in.");
650
651 n = r;
652
653 for (fd = SD_LISTEN_FDS_START; fd < SD_LISTEN_FDS_START + n; fd++) {
654 r = add_listen_socket(&context, fd);
655 if (r < 0)
656 return r;
657 }
658
659 r = sd_event_loop(context.event);
660 if (r < 0)
661 return log_error_errno(r, "Failed to run event loop: %m");
662
663 return 0;
664 }
665
/* Program entry point. NOTE(review): DEFINE_MAIN_FUNCTION is not defined in
 * this file (presumably main-func.h) and presumably expands to a main() that
 * calls run() and maps its result to the exit status; confirm there. */
DEFINE_MAIN_FUNCTION(run);