]> git.ipfire.org Git - thirdparty/systemd.git/blob - src/socket-proxy/socket-proxyd.c
Merge pull request #7493 from keszybz/revert-revert
[thirdparty/systemd.git] / src / socket-proxy / socket-proxyd.c
1 /* SPDX-License-Identifier: LGPL-2.1+ */
2 /***
3 This file is part of systemd.
4
5 Copyright 2013 David Strauss
6
7 systemd is free software; you can redistribute it and/or modify it
8 under the terms of the GNU Lesser General Public License as published by
9 the Free Software Foundation; either version 2.1 of the License, or
10 (at your option) any later version.
11
12 systemd is distributed in the hope that it will be useful, but
13 WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
15 Lesser General Public License for more details.
16
17 You should have received a copy of the GNU Lesser General Public License
18 along with systemd; If not, see <http://www.gnu.org/licenses/>.
19 ***/
20
21 #include <errno.h>
22 #include <fcntl.h>
23 #include <getopt.h>
24 #include <netdb.h>
25 #include <stdio.h>
26 #include <stdlib.h>
27 #include <string.h>
28 #include <sys/socket.h>
29 #include <sys/un.h>
30 #include <unistd.h>
31
32 #include "sd-daemon.h"
33 #include "sd-event.h"
34 #include "sd-resolve.h"
35
36 #include "alloc-util.h"
37 #include "fd-util.h"
38 #include "log.h"
39 #include "path-util.h"
40 #include "set.h"
41 #include "socket-util.h"
42 #include "string-util.h"
43 #include "parse-util.h"
44 #include "util.h"
45
/* Pipe size requested for each direction's splice() buffer. The request is
 * best effort (its result is ignored); the size actually obtained is read back
 * with F_GETPIPE_SZ. */
#define BUFFER_SIZE (256 * 1024)

/* Maximum number of concurrently proxied connections (--connections-max=). */
static unsigned arg_connections_max = 256;

/* Where to forward to: an absolute AF_UNIX path, "@name" for an abstract
 * AF_UNIX socket, or "host[:service]" (service defaults to "80"). */
static const char *arg_remote_host;
50
51 typedef struct Context {
52 sd_event *event;
53 sd_resolve *resolve;
54
55 Set *listen;
56 Set *connections;
57 } Context;
58
59 typedef struct Connection {
60 Context *context;
61
62 int server_fd, client_fd;
63 int server_to_client_buffer[2]; /* a pipe */
64 int client_to_server_buffer[2]; /* a pipe */
65
66 size_t server_to_client_buffer_full, client_to_server_buffer_full;
67 size_t server_to_client_buffer_size, client_to_server_buffer_size;
68
69 sd_event_source *server_event_source, *client_event_source;
70
71 sd_resolve_query *resolve_query;
72 } Connection;
73
74 static void connection_free(Connection *c) {
75 assert(c);
76
77 if (c->context)
78 set_remove(c->context->connections, c);
79
80 sd_event_source_unref(c->server_event_source);
81 sd_event_source_unref(c->client_event_source);
82
83 safe_close(c->server_fd);
84 safe_close(c->client_fd);
85
86 safe_close_pair(c->server_to_client_buffer);
87 safe_close_pair(c->client_to_server_buffer);
88
89 sd_resolve_query_unref(c->resolve_query);
90
91 free(c);
92 }
93
94 static void context_free(Context *context) {
95 assert(context);
96
97 set_free_with_destructor(context->listen, sd_event_source_unref);
98 set_free_with_destructor(context->connections, connection_free);
99
100 sd_event_unref(context->event);
101 sd_resolve_unref(context->resolve);
102 }
103
104 static int connection_create_pipes(Connection *c, int buffer[2], size_t *sz) {
105 int r;
106
107 assert(c);
108 assert(buffer);
109 assert(sz);
110
111 if (buffer[0] >= 0)
112 return 0;
113
114 r = pipe2(buffer, O_CLOEXEC|O_NONBLOCK);
115 if (r < 0)
116 return log_error_errno(errno, "Failed to allocate pipe buffer: %m");
117
118 (void) fcntl(buffer[0], F_SETPIPE_SZ, BUFFER_SIZE);
119
120 r = fcntl(buffer[0], F_GETPIPE_SZ);
121 if (r < 0)
122 return log_error_errno(errno, "Failed to get pipe buffer size: %m");
123
124 assert(r > 0);
125 *sz = r;
126
127 return 0;
128 }
129
130 static int connection_shovel(
131 Connection *c,
132 int *from, int buffer[2], int *to,
133 size_t *full, size_t *sz,
134 sd_event_source **from_source, sd_event_source **to_source) {
135
136 bool shoveled;
137
138 assert(c);
139 assert(from);
140 assert(buffer);
141 assert(buffer[0] >= 0);
142 assert(buffer[1] >= 0);
143 assert(to);
144 assert(full);
145 assert(sz);
146 assert(from_source);
147 assert(to_source);
148
149 do {
150 ssize_t z;
151
152 shoveled = false;
153
154 if (*full < *sz && *from >= 0 && *to >= 0) {
155 z = splice(*from, NULL, buffer[1], NULL, *sz - *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
156 if (z > 0) {
157 *full += z;
158 shoveled = true;
159 } else if (z == 0 || IN_SET(errno, EPIPE, ECONNRESET)) {
160 *from_source = sd_event_source_unref(*from_source);
161 *from = safe_close(*from);
162 } else if (!IN_SET(errno, EAGAIN, EINTR))
163 return log_error_errno(errno, "Failed to splice: %m");
164 }
165
166 if (*full > 0 && *to >= 0) {
167 z = splice(buffer[0], NULL, *to, NULL, *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
168 if (z > 0) {
169 *full -= z;
170 shoveled = true;
171 } else if (z == 0 || IN_SET(errno, EPIPE, ECONNRESET)) {
172 *to_source = sd_event_source_unref(*to_source);
173 *to = safe_close(*to);
174 } else if (!IN_SET(errno, EAGAIN, EINTR))
175 return log_error_errno(errno, "Failed to splice: %m");
176 }
177 } while (shoveled);
178
179 return 0;
180 }
181
182 static int connection_enable_event_sources(Connection *c);
183
184 static int traffic_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
185 Connection *c = userdata;
186 int r;
187
188 assert(s);
189 assert(fd >= 0);
190 assert(c);
191
192 r = connection_shovel(c,
193 &c->server_fd, c->server_to_client_buffer, &c->client_fd,
194 &c->server_to_client_buffer_full, &c->server_to_client_buffer_size,
195 &c->server_event_source, &c->client_event_source);
196 if (r < 0)
197 goto quit;
198
199 r = connection_shovel(c,
200 &c->client_fd, c->client_to_server_buffer, &c->server_fd,
201 &c->client_to_server_buffer_full, &c->client_to_server_buffer_size,
202 &c->client_event_source, &c->server_event_source);
203 if (r < 0)
204 goto quit;
205
206 /* EOF on both sides? */
207 if (c->server_fd == -1 && c->client_fd == -1)
208 goto quit;
209
210 /* Server closed, and all data written to client? */
211 if (c->server_fd == -1 && c->server_to_client_buffer_full <= 0)
212 goto quit;
213
214 /* Client closed, and all data written to server? */
215 if (c->client_fd == -1 && c->client_to_server_buffer_full <= 0)
216 goto quit;
217
218 r = connection_enable_event_sources(c);
219 if (r < 0)
220 goto quit;
221
222 return 1;
223
224 quit:
225 connection_free(c);
226 return 0; /* ignore errors, continue serving */
227 }
228
229 static int connection_enable_event_sources(Connection *c) {
230 uint32_t a = 0, b = 0;
231 int r;
232
233 assert(c);
234
235 if (c->server_to_client_buffer_full > 0)
236 b |= EPOLLOUT;
237 if (c->server_to_client_buffer_full < c->server_to_client_buffer_size)
238 a |= EPOLLIN;
239
240 if (c->client_to_server_buffer_full > 0)
241 a |= EPOLLOUT;
242 if (c->client_to_server_buffer_full < c->client_to_server_buffer_size)
243 b |= EPOLLIN;
244
245 if (c->server_event_source)
246 r = sd_event_source_set_io_events(c->server_event_source, a);
247 else if (c->server_fd >= 0)
248 r = sd_event_add_io(c->context->event, &c->server_event_source, c->server_fd, a, traffic_cb, c);
249 else
250 r = 0;
251
252 if (r < 0)
253 return log_error_errno(r, "Failed to set up server event source: %m");
254
255 if (c->client_event_source)
256 r = sd_event_source_set_io_events(c->client_event_source, b);
257 else if (c->client_fd >= 0)
258 r = sd_event_add_io(c->context->event, &c->client_event_source, c->client_fd, b, traffic_cb, c);
259 else
260 r = 0;
261
262 if (r < 0)
263 return log_error_errno(r, "Failed to set up client event source: %m");
264
265 return 0;
266 }
267
268 static int connection_complete(Connection *c) {
269 int r;
270
271 assert(c);
272
273 r = connection_create_pipes(c, c->server_to_client_buffer, &c->server_to_client_buffer_size);
274 if (r < 0)
275 goto fail;
276
277 r = connection_create_pipes(c, c->client_to_server_buffer, &c->client_to_server_buffer_size);
278 if (r < 0)
279 goto fail;
280
281 r = connection_enable_event_sources(c);
282 if (r < 0)
283 goto fail;
284
285 return 0;
286
287 fail:
288 connection_free(c);
289 return 0; /* ignore errors, continue serving */
290 }
291
292 static int connect_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
293 Connection *c = userdata;
294 socklen_t solen;
295 int error, r;
296
297 assert(s);
298 assert(fd >= 0);
299 assert(c);
300
301 solen = sizeof(error);
302 r = getsockopt(fd, SOL_SOCKET, SO_ERROR, &error, &solen);
303 if (r < 0) {
304 log_error_errno(errno, "Failed to issue SO_ERROR: %m");
305 goto fail;
306 }
307
308 if (error != 0) {
309 log_error_errno(error, "Failed to connect to remote host: %m");
310 goto fail;
311 }
312
313 c->client_event_source = sd_event_source_unref(c->client_event_source);
314
315 return connection_complete(c);
316
317 fail:
318 connection_free(c);
319 return 0; /* ignore errors, continue serving */
320 }
321
322 static int connection_start(Connection *c, struct sockaddr *sa, socklen_t salen) {
323 int r;
324
325 assert(c);
326 assert(sa);
327 assert(salen);
328
329 c->client_fd = socket(sa->sa_family, SOCK_STREAM|SOCK_NONBLOCK|SOCK_CLOEXEC, 0);
330 if (c->client_fd < 0) {
331 log_error_errno(errno, "Failed to get remote socket: %m");
332 goto fail;
333 }
334
335 r = connect(c->client_fd, sa, salen);
336 if (r < 0) {
337 if (errno == EINPROGRESS) {
338 r = sd_event_add_io(c->context->event, &c->client_event_source, c->client_fd, EPOLLOUT, connect_cb, c);
339 if (r < 0) {
340 log_error_errno(r, "Failed to add connection socket: %m");
341 goto fail;
342 }
343
344 r = sd_event_source_set_enabled(c->client_event_source, SD_EVENT_ONESHOT);
345 if (r < 0) {
346 log_error_errno(r, "Failed to enable oneshot event source: %m");
347 goto fail;
348 }
349 } else {
350 log_error_errno(errno, "Failed to connect to remote host: %m");
351 goto fail;
352 }
353 } else {
354 r = connection_complete(c);
355 if (r < 0)
356 goto fail;
357 }
358
359 return 0;
360
361 fail:
362 connection_free(c);
363 return 0; /* ignore errors, continue serving */
364 }
365
366 static int resolve_cb(sd_resolve_query *q, int ret, const struct addrinfo *ai, void *userdata) {
367 Connection *c = userdata;
368
369 assert(q);
370 assert(c);
371
372 if (ret != 0) {
373 log_error("Failed to resolve host: %s", gai_strerror(ret));
374 goto fail;
375 }
376
377 c->resolve_query = sd_resolve_query_unref(c->resolve_query);
378
379 return connection_start(c, ai->ai_addr, ai->ai_addrlen);
380
381 fail:
382 connection_free(c);
383 return 0; /* ignore errors, continue serving */
384 }
385
386 static int resolve_remote(Connection *c) {
387
388 static const struct addrinfo hints = {
389 .ai_family = AF_UNSPEC,
390 .ai_socktype = SOCK_STREAM,
391 .ai_flags = AI_ADDRCONFIG
392 };
393
394 union sockaddr_union sa = {};
395 const char *node, *service;
396 int r;
397
398 if (path_is_absolute(arg_remote_host)) {
399 sa.un.sun_family = AF_UNIX;
400 strncpy(sa.un.sun_path, arg_remote_host, sizeof(sa.un.sun_path));
401 return connection_start(c, &sa.sa, SOCKADDR_UN_LEN(sa.un));
402 }
403
404 if (arg_remote_host[0] == '@') {
405 sa.un.sun_family = AF_UNIX;
406 sa.un.sun_path[0] = 0;
407 strncpy(sa.un.sun_path+1, arg_remote_host+1, sizeof(sa.un.sun_path)-1);
408 return connection_start(c, &sa.sa, SOCKADDR_UN_LEN(sa.un));
409 }
410
411 service = strrchr(arg_remote_host, ':');
412 if (service) {
413 node = strndupa(arg_remote_host, service - arg_remote_host);
414 service++;
415 } else {
416 node = arg_remote_host;
417 service = "80";
418 }
419
420 log_debug("Looking up address info for %s:%s", node, service);
421 r = sd_resolve_getaddrinfo(c->context->resolve, &c->resolve_query, node, service, &hints, resolve_cb, c);
422 if (r < 0) {
423 log_error_errno(r, "Failed to resolve remote host: %m");
424 goto fail;
425 }
426
427 return 0;
428
429 fail:
430 connection_free(c);
431 return 0; /* ignore errors, continue serving */
432 }
433
434 static int add_connection_socket(Context *context, int fd) {
435 Connection *c;
436 int r;
437
438 assert(context);
439 assert(fd >= 0);
440
441 if (set_size(context->connections) > arg_connections_max) {
442 log_warning("Hit connection limit, refusing connection.");
443 safe_close(fd);
444 return 0;
445 }
446
447 r = set_ensure_allocated(&context->connections, NULL);
448 if (r < 0) {
449 log_oom();
450 return 0;
451 }
452
453 c = new0(Connection, 1);
454 if (!c) {
455 log_oom();
456 return 0;
457 }
458
459 c->context = context;
460 c->server_fd = fd;
461 c->client_fd = -1;
462 c->server_to_client_buffer[0] = c->server_to_client_buffer[1] = -1;
463 c->client_to_server_buffer[0] = c->client_to_server_buffer[1] = -1;
464
465 r = set_put(context->connections, c);
466 if (r < 0) {
467 free(c);
468 log_oom();
469 return 0;
470 }
471
472 return resolve_remote(c);
473 }
474
475 static int accept_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
476 _cleanup_free_ char *peer = NULL;
477 Context *context = userdata;
478 int nfd = -1, r;
479
480 assert(s);
481 assert(fd >= 0);
482 assert(revents & EPOLLIN);
483 assert(context);
484
485 nfd = accept4(fd, NULL, NULL, SOCK_NONBLOCK|SOCK_CLOEXEC);
486 if (nfd < 0) {
487 if (errno != -EAGAIN)
488 log_warning_errno(errno, "Failed to accept() socket: %m");
489 } else {
490 getpeername_pretty(nfd, true, &peer);
491 log_debug("New connection from %s", strna(peer));
492
493 r = add_connection_socket(context, nfd);
494 if (r < 0) {
495 log_error_errno(r, "Failed to accept connection, ignoring: %m");
496 safe_close(fd);
497 }
498 }
499
500 r = sd_event_source_set_enabled(s, SD_EVENT_ONESHOT);
501 if (r < 0) {
502 log_error_errno(r, "Error while re-enabling listener with ONESHOT: %m");
503 sd_event_exit(context->event, r);
504 return r;
505 }
506
507 return 1;
508 }
509
510 static int add_listen_socket(Context *context, int fd) {
511 sd_event_source *source;
512 int r;
513
514 assert(context);
515 assert(fd >= 0);
516
517 r = set_ensure_allocated(&context->listen, NULL);
518 if (r < 0) {
519 log_oom();
520 return r;
521 }
522
523 r = sd_is_socket(fd, 0, SOCK_STREAM, 1);
524 if (r < 0)
525 return log_error_errno(r, "Failed to determine socket type: %m");
526 if (r == 0) {
527 log_error("Passed in socket is not a stream socket.");
528 return -EINVAL;
529 }
530
531 r = fd_nonblock(fd, true);
532 if (r < 0)
533 return log_error_errno(r, "Failed to mark file descriptor non-blocking: %m");
534
535 r = sd_event_add_io(context->event, &source, fd, EPOLLIN, accept_cb, context);
536 if (r < 0)
537 return log_error_errno(r, "Failed to add event source: %m");
538
539 r = set_put(context->listen, source);
540 if (r < 0) {
541 log_error_errno(r, "Failed to add source to set: %m");
542 sd_event_source_unref(source);
543 return r;
544 }
545
546 /* Set the watcher to oneshot in case other processes are also
547 * watching to accept(). */
548 r = sd_event_source_set_enabled(source, SD_EVENT_ONESHOT);
549 if (r < 0)
550 return log_error_errno(r, "Failed to enable oneshot mode: %m");
551
552 return 0;
553 }
554
555 static void help(void) {
556 printf("%1$s [HOST:PORT]\n"
557 "%1$s [SOCKET]\n\n"
558 "Bidirectionally proxy local sockets to another (possibly remote) socket.\n\n"
559 " -c --connections-max= Set the maximum number of connections to be accepted\n"
560 " -h --help Show this help\n"
561 " --version Show package version\n",
562 program_invocation_short_name);
563 }
564
565 static int parse_argv(int argc, char *argv[]) {
566
567 enum {
568 ARG_VERSION = 0x100,
569 ARG_IGNORE_ENV
570 };
571
572 static const struct option options[] = {
573 { "connections-max", required_argument, NULL, 'c' },
574 { "help", no_argument, NULL, 'h' },
575 { "version", no_argument, NULL, ARG_VERSION },
576 {}
577 };
578
579 int c, r;
580
581 assert(argc >= 0);
582 assert(argv);
583
584 while ((c = getopt_long(argc, argv, "c:h", options, NULL)) >= 0)
585
586 switch (c) {
587
588 case 'h':
589 help();
590 return 0;
591
592 case 'c':
593 r = safe_atou(optarg, &arg_connections_max);
594 if (r < 0) {
595 log_error("Failed to parse --connections-max= argument: %s", optarg);
596 return r;
597 }
598
599 if (arg_connections_max < 1) {
600 log_error("Connection limit is too low.");
601 return -EINVAL;
602 }
603
604 break;
605
606 case ARG_VERSION:
607 return version();
608
609 case '?':
610 return -EINVAL;
611
612 default:
613 assert_not_reached("Unhandled option");
614 }
615
616 if (optind >= argc) {
617 log_error("Not enough parameters.");
618 return -EINVAL;
619 }
620
621 if (argc != optind+1) {
622 log_error("Too many parameters.");
623 return -EINVAL;
624 }
625
626 arg_remote_host = argv[optind];
627 return 1;
628 }
629
630 int main(int argc, char *argv[]) {
631 Context context = {};
632 int r, n, fd;
633
634 log_parse_environment();
635 log_open();
636
637 r = parse_argv(argc, argv);
638 if (r <= 0)
639 goto finish;
640
641 r = sd_event_default(&context.event);
642 if (r < 0) {
643 log_error_errno(r, "Failed to allocate event loop: %m");
644 goto finish;
645 }
646
647 r = sd_resolve_default(&context.resolve);
648 if (r < 0) {
649 log_error_errno(r, "Failed to allocate resolver: %m");
650 goto finish;
651 }
652
653 r = sd_resolve_attach_event(context.resolve, context.event, 0);
654 if (r < 0) {
655 log_error_errno(r, "Failed to attach resolver: %m");
656 goto finish;
657 }
658
659 sd_event_set_watchdog(context.event, true);
660
661 n = sd_listen_fds(1);
662 if (n < 0) {
663 log_error("Failed to receive sockets from parent.");
664 r = n;
665 goto finish;
666 } else if (n == 0) {
667 log_error("Didn't get any sockets passed in.");
668 r = -EINVAL;
669 goto finish;
670 }
671
672 for (fd = SD_LISTEN_FDS_START; fd < SD_LISTEN_FDS_START + n; fd++) {
673 r = add_listen_socket(&context, fd);
674 if (r < 0)
675 goto finish;
676 }
677
678 r = sd_event_loop(context.event);
679 if (r < 0) {
680 log_error_errno(r, "Failed to run event loop: %m");
681 goto finish;
682 }
683
684 finish:
685 context_free(&context);
686
687 return r < 0 ? EXIT_FAILURE : EXIT_SUCCESS;
688 }