git.ipfire.org Git - thirdparty/systemd.git/blob - src/socket-proxy/socket-proxyd.c
Add SPDX license identifiers to source files under the LGPL
[thirdparty/systemd.git] / src / socket-proxy / socket-proxyd.c
1 /* SPDX-License-Identifier: LGPL-2.1+ */
2 /***
3 This file is part of systemd.
4
5 Copyright 2013 David Strauss
6
7 systemd is free software; you can redistribute it and/or modify it
8 under the terms of the GNU Lesser General Public License as published by
9 the Free Software Foundation; either version 2.1 of the License, or
10 (at your option) any later version.
11
12 systemd is distributed in the hope that it will be useful, but
13 WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
15 Lesser General Public License for more details.
16
17 You should have received a copy of the GNU Lesser General Public License
18 along with systemd; If not, see <http://www.gnu.org/licenses/>.
19 ***/
20
21 #include <errno.h>
22 #include <fcntl.h>
23 #include <getopt.h>
24 #include <netdb.h>
25 #include <stdio.h>
26 #include <stdlib.h>
27 #include <string.h>
28 #include <sys/socket.h>
29 #include <sys/un.h>
30 #include <unistd.h>
31
32 #include "sd-daemon.h"
33 #include "sd-event.h"
34 #include "sd-resolve.h"
35
36 #include "alloc-util.h"
37 #include "fd-util.h"
38 #include "log.h"
39 #include "path-util.h"
40 #include "set.h"
41 #include "socket-util.h"
42 #include "string-util.h"
43 #include "parse-util.h"
44 #include "util.h"
45
/* Pipe capacity, in bytes, requested for each splice buffer. */
#define BUFFER_SIZE (256 * 1024)

/* Limit on the number of proxied connections (-c/--connections-max=). */
static unsigned arg_connections_max = 256;

/* Destination given on the command line; filled in by parse_argv(). */
static const char *arg_remote_host = NULL;
50
51 typedef struct Context {
52 sd_event *event;
53 sd_resolve *resolve;
54
55 Set *listen;
56 Set *connections;
57 } Context;
58
59 typedef struct Connection {
60 Context *context;
61
62 int server_fd, client_fd;
63 int server_to_client_buffer[2]; /* a pipe */
64 int client_to_server_buffer[2]; /* a pipe */
65
66 size_t server_to_client_buffer_full, client_to_server_buffer_full;
67 size_t server_to_client_buffer_size, client_to_server_buffer_size;
68
69 sd_event_source *server_event_source, *client_event_source;
70
71 sd_resolve_query *resolve_query;
72 } Connection;
73
74 static void connection_free(Connection *c) {
75 assert(c);
76
77 if (c->context)
78 set_remove(c->context->connections, c);
79
80 sd_event_source_unref(c->server_event_source);
81 sd_event_source_unref(c->client_event_source);
82
83 safe_close(c->server_fd);
84 safe_close(c->client_fd);
85
86 safe_close_pair(c->server_to_client_buffer);
87 safe_close_pair(c->client_to_server_buffer);
88
89 sd_resolve_query_unref(c->resolve_query);
90
91 free(c);
92 }
93
94 static void context_free(Context *context) {
95 sd_event_source *es;
96 Connection *c;
97
98 assert(context);
99
100 while ((es = set_steal_first(context->listen)))
101 sd_event_source_unref(es);
102
103 while ((c = set_first(context->connections)))
104 connection_free(c);
105
106 set_free(context->listen);
107 set_free(context->connections);
108
109 sd_event_unref(context->event);
110 sd_resolve_unref(context->resolve);
111 }
112
113 static int connection_create_pipes(Connection *c, int buffer[2], size_t *sz) {
114 int r;
115
116 assert(c);
117 assert(buffer);
118 assert(sz);
119
120 if (buffer[0] >= 0)
121 return 0;
122
123 r = pipe2(buffer, O_CLOEXEC|O_NONBLOCK);
124 if (r < 0)
125 return log_error_errno(errno, "Failed to allocate pipe buffer: %m");
126
127 (void) fcntl(buffer[0], F_SETPIPE_SZ, BUFFER_SIZE);
128
129 r = fcntl(buffer[0], F_GETPIPE_SZ);
130 if (r < 0)
131 return log_error_errno(errno, "Failed to get pipe buffer size: %m");
132
133 assert(r > 0);
134 *sz = r;
135
136 return 0;
137 }
138
139 static int connection_shovel(
140 Connection *c,
141 int *from, int buffer[2], int *to,
142 size_t *full, size_t *sz,
143 sd_event_source **from_source, sd_event_source **to_source) {
144
145 bool shoveled;
146
147 assert(c);
148 assert(from);
149 assert(buffer);
150 assert(buffer[0] >= 0);
151 assert(buffer[1] >= 0);
152 assert(to);
153 assert(full);
154 assert(sz);
155 assert(from_source);
156 assert(to_source);
157
158 do {
159 ssize_t z;
160
161 shoveled = false;
162
163 if (*full < *sz && *from >= 0 && *to >= 0) {
164 z = splice(*from, NULL, buffer[1], NULL, *sz - *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
165 if (z > 0) {
166 *full += z;
167 shoveled = true;
168 } else if (z == 0 || IN_SET(errno, EPIPE, ECONNRESET)) {
169 *from_source = sd_event_source_unref(*from_source);
170 *from = safe_close(*from);
171 } else if (!IN_SET(errno, EAGAIN, EINTR))
172 return log_error_errno(errno, "Failed to splice: %m");
173 }
174
175 if (*full > 0 && *to >= 0) {
176 z = splice(buffer[0], NULL, *to, NULL, *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
177 if (z > 0) {
178 *full -= z;
179 shoveled = true;
180 } else if (z == 0 || IN_SET(errno, EPIPE, ECONNRESET)) {
181 *to_source = sd_event_source_unref(*to_source);
182 *to = safe_close(*to);
183 } else if (!IN_SET(errno, EAGAIN, EINTR))
184 return log_error_errno(errno, "Failed to splice: %m");
185 }
186 } while (shoveled);
187
188 return 0;
189 }
190
191 static int connection_enable_event_sources(Connection *c);
192
193 static int traffic_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
194 Connection *c = userdata;
195 int r;
196
197 assert(s);
198 assert(fd >= 0);
199 assert(c);
200
201 r = connection_shovel(c,
202 &c->server_fd, c->server_to_client_buffer, &c->client_fd,
203 &c->server_to_client_buffer_full, &c->server_to_client_buffer_size,
204 &c->server_event_source, &c->client_event_source);
205 if (r < 0)
206 goto quit;
207
208 r = connection_shovel(c,
209 &c->client_fd, c->client_to_server_buffer, &c->server_fd,
210 &c->client_to_server_buffer_full, &c->client_to_server_buffer_size,
211 &c->client_event_source, &c->server_event_source);
212 if (r < 0)
213 goto quit;
214
215 /* EOF on both sides? */
216 if (c->server_fd == -1 && c->client_fd == -1)
217 goto quit;
218
219 /* Server closed, and all data written to client? */
220 if (c->server_fd == -1 && c->server_to_client_buffer_full <= 0)
221 goto quit;
222
223 /* Client closed, and all data written to server? */
224 if (c->client_fd == -1 && c->client_to_server_buffer_full <= 0)
225 goto quit;
226
227 r = connection_enable_event_sources(c);
228 if (r < 0)
229 goto quit;
230
231 return 1;
232
233 quit:
234 connection_free(c);
235 return 0; /* ignore errors, continue serving */
236 }
237
238 static int connection_enable_event_sources(Connection *c) {
239 uint32_t a = 0, b = 0;
240 int r;
241
242 assert(c);
243
244 if (c->server_to_client_buffer_full > 0)
245 b |= EPOLLOUT;
246 if (c->server_to_client_buffer_full < c->server_to_client_buffer_size)
247 a |= EPOLLIN;
248
249 if (c->client_to_server_buffer_full > 0)
250 a |= EPOLLOUT;
251 if (c->client_to_server_buffer_full < c->client_to_server_buffer_size)
252 b |= EPOLLIN;
253
254 if (c->server_event_source)
255 r = sd_event_source_set_io_events(c->server_event_source, a);
256 else if (c->server_fd >= 0)
257 r = sd_event_add_io(c->context->event, &c->server_event_source, c->server_fd, a, traffic_cb, c);
258 else
259 r = 0;
260
261 if (r < 0)
262 return log_error_errno(r, "Failed to set up server event source: %m");
263
264 if (c->client_event_source)
265 r = sd_event_source_set_io_events(c->client_event_source, b);
266 else if (c->client_fd >= 0)
267 r = sd_event_add_io(c->context->event, &c->client_event_source, c->client_fd, b, traffic_cb, c);
268 else
269 r = 0;
270
271 if (r < 0)
272 return log_error_errno(r, "Failed to set up client event source: %m");
273
274 return 0;
275 }
276
277 static int connection_complete(Connection *c) {
278 int r;
279
280 assert(c);
281
282 r = connection_create_pipes(c, c->server_to_client_buffer, &c->server_to_client_buffer_size);
283 if (r < 0)
284 goto fail;
285
286 r = connection_create_pipes(c, c->client_to_server_buffer, &c->client_to_server_buffer_size);
287 if (r < 0)
288 goto fail;
289
290 r = connection_enable_event_sources(c);
291 if (r < 0)
292 goto fail;
293
294 return 0;
295
296 fail:
297 connection_free(c);
298 return 0; /* ignore errors, continue serving */
299 }
300
301 static int connect_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
302 Connection *c = userdata;
303 socklen_t solen;
304 int error, r;
305
306 assert(s);
307 assert(fd >= 0);
308 assert(c);
309
310 solen = sizeof(error);
311 r = getsockopt(fd, SOL_SOCKET, SO_ERROR, &error, &solen);
312 if (r < 0) {
313 log_error_errno(errno, "Failed to issue SO_ERROR: %m");
314 goto fail;
315 }
316
317 if (error != 0) {
318 log_error_errno(error, "Failed to connect to remote host: %m");
319 goto fail;
320 }
321
322 c->client_event_source = sd_event_source_unref(c->client_event_source);
323
324 return connection_complete(c);
325
326 fail:
327 connection_free(c);
328 return 0; /* ignore errors, continue serving */
329 }
330
331 static int connection_start(Connection *c, struct sockaddr *sa, socklen_t salen) {
332 int r;
333
334 assert(c);
335 assert(sa);
336 assert(salen);
337
338 c->client_fd = socket(sa->sa_family, SOCK_STREAM|SOCK_NONBLOCK|SOCK_CLOEXEC, 0);
339 if (c->client_fd < 0) {
340 log_error_errno(errno, "Failed to get remote socket: %m");
341 goto fail;
342 }
343
344 r = connect(c->client_fd, sa, salen);
345 if (r < 0) {
346 if (errno == EINPROGRESS) {
347 r = sd_event_add_io(c->context->event, &c->client_event_source, c->client_fd, EPOLLOUT, connect_cb, c);
348 if (r < 0) {
349 log_error_errno(r, "Failed to add connection socket: %m");
350 goto fail;
351 }
352
353 r = sd_event_source_set_enabled(c->client_event_source, SD_EVENT_ONESHOT);
354 if (r < 0) {
355 log_error_errno(r, "Failed to enable oneshot event source: %m");
356 goto fail;
357 }
358 } else {
359 log_error_errno(errno, "Failed to connect to remote host: %m");
360 goto fail;
361 }
362 } else {
363 r = connection_complete(c);
364 if (r < 0)
365 goto fail;
366 }
367
368 return 0;
369
370 fail:
371 connection_free(c);
372 return 0; /* ignore errors, continue serving */
373 }
374
375 static int resolve_cb(sd_resolve_query *q, int ret, const struct addrinfo *ai, void *userdata) {
376 Connection *c = userdata;
377
378 assert(q);
379 assert(c);
380
381 if (ret != 0) {
382 log_error("Failed to resolve host: %s", gai_strerror(ret));
383 goto fail;
384 }
385
386 c->resolve_query = sd_resolve_query_unref(c->resolve_query);
387
388 return connection_start(c, ai->ai_addr, ai->ai_addrlen);
389
390 fail:
391 connection_free(c);
392 return 0; /* ignore errors, continue serving */
393 }
394
395 static int resolve_remote(Connection *c) {
396
397 static const struct addrinfo hints = {
398 .ai_family = AF_UNSPEC,
399 .ai_socktype = SOCK_STREAM,
400 .ai_flags = AI_ADDRCONFIG
401 };
402
403 union sockaddr_union sa = {};
404 const char *node, *service;
405 int r;
406
407 if (path_is_absolute(arg_remote_host)) {
408 sa.un.sun_family = AF_UNIX;
409 strncpy(sa.un.sun_path, arg_remote_host, sizeof(sa.un.sun_path));
410 return connection_start(c, &sa.sa, SOCKADDR_UN_LEN(sa.un));
411 }
412
413 if (arg_remote_host[0] == '@') {
414 sa.un.sun_family = AF_UNIX;
415 sa.un.sun_path[0] = 0;
416 strncpy(sa.un.sun_path+1, arg_remote_host+1, sizeof(sa.un.sun_path)-1);
417 return connection_start(c, &sa.sa, SOCKADDR_UN_LEN(sa.un));
418 }
419
420 service = strrchr(arg_remote_host, ':');
421 if (service) {
422 node = strndupa(arg_remote_host, service - arg_remote_host);
423 service++;
424 } else {
425 node = arg_remote_host;
426 service = "80";
427 }
428
429 log_debug("Looking up address info for %s:%s", node, service);
430 r = sd_resolve_getaddrinfo(c->context->resolve, &c->resolve_query, node, service, &hints, resolve_cb, c);
431 if (r < 0) {
432 log_error_errno(r, "Failed to resolve remote host: %m");
433 goto fail;
434 }
435
436 return 0;
437
438 fail:
439 connection_free(c);
440 return 0; /* ignore errors, continue serving */
441 }
442
443 static int add_connection_socket(Context *context, int fd) {
444 Connection *c;
445 int r;
446
447 assert(context);
448 assert(fd >= 0);
449
450 if (set_size(context->connections) > arg_connections_max) {
451 log_warning("Hit connection limit, refusing connection.");
452 safe_close(fd);
453 return 0;
454 }
455
456 r = set_ensure_allocated(&context->connections, NULL);
457 if (r < 0) {
458 log_oom();
459 return 0;
460 }
461
462 c = new0(Connection, 1);
463 if (!c) {
464 log_oom();
465 return 0;
466 }
467
468 c->context = context;
469 c->server_fd = fd;
470 c->client_fd = -1;
471 c->server_to_client_buffer[0] = c->server_to_client_buffer[1] = -1;
472 c->client_to_server_buffer[0] = c->client_to_server_buffer[1] = -1;
473
474 r = set_put(context->connections, c);
475 if (r < 0) {
476 free(c);
477 log_oom();
478 return 0;
479 }
480
481 return resolve_remote(c);
482 }
483
484 static int accept_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
485 _cleanup_free_ char *peer = NULL;
486 Context *context = userdata;
487 int nfd = -1, r;
488
489 assert(s);
490 assert(fd >= 0);
491 assert(revents & EPOLLIN);
492 assert(context);
493
494 nfd = accept4(fd, NULL, NULL, SOCK_NONBLOCK|SOCK_CLOEXEC);
495 if (nfd < 0) {
496 if (errno != -EAGAIN)
497 log_warning_errno(errno, "Failed to accept() socket: %m");
498 } else {
499 getpeername_pretty(nfd, true, &peer);
500 log_debug("New connection from %s", strna(peer));
501
502 r = add_connection_socket(context, nfd);
503 if (r < 0) {
504 log_error_errno(r, "Failed to accept connection, ignoring: %m");
505 safe_close(fd);
506 }
507 }
508
509 r = sd_event_source_set_enabled(s, SD_EVENT_ONESHOT);
510 if (r < 0) {
511 log_error_errno(r, "Error while re-enabling listener with ONESHOT: %m");
512 sd_event_exit(context->event, r);
513 return r;
514 }
515
516 return 1;
517 }
518
519 static int add_listen_socket(Context *context, int fd) {
520 sd_event_source *source;
521 int r;
522
523 assert(context);
524 assert(fd >= 0);
525
526 r = set_ensure_allocated(&context->listen, NULL);
527 if (r < 0) {
528 log_oom();
529 return r;
530 }
531
532 r = sd_is_socket(fd, 0, SOCK_STREAM, 1);
533 if (r < 0)
534 return log_error_errno(r, "Failed to determine socket type: %m");
535 if (r == 0) {
536 log_error("Passed in socket is not a stream socket.");
537 return -EINVAL;
538 }
539
540 r = fd_nonblock(fd, true);
541 if (r < 0)
542 return log_error_errno(r, "Failed to mark file descriptor non-blocking: %m");
543
544 r = sd_event_add_io(context->event, &source, fd, EPOLLIN, accept_cb, context);
545 if (r < 0)
546 return log_error_errno(r, "Failed to add event source: %m");
547
548 r = set_put(context->listen, source);
549 if (r < 0) {
550 log_error_errno(r, "Failed to add source to set: %m");
551 sd_event_source_unref(source);
552 return r;
553 }
554
555 /* Set the watcher to oneshot in case other processes are also
556 * watching to accept(). */
557 r = sd_event_source_set_enabled(source, SD_EVENT_ONESHOT);
558 if (r < 0)
559 return log_error_errno(r, "Failed to enable oneshot mode: %m");
560
561 return 0;
562 }
563
564 static void help(void) {
565 printf("%1$s [HOST:PORT]\n"
566 "%1$s [SOCKET]\n\n"
567 "Bidirectionally proxy local sockets to another (possibly remote) socket.\n\n"
568 " -c --connections-max= Set the maximum number of connections to be accepted\n"
569 " -h --help Show this help\n"
570 " --version Show package version\n",
571 program_invocation_short_name);
572 }
573
574 static int parse_argv(int argc, char *argv[]) {
575
576 enum {
577 ARG_VERSION = 0x100,
578 ARG_IGNORE_ENV
579 };
580
581 static const struct option options[] = {
582 { "connections-max", required_argument, NULL, 'c' },
583 { "help", no_argument, NULL, 'h' },
584 { "version", no_argument, NULL, ARG_VERSION },
585 {}
586 };
587
588 int c, r;
589
590 assert(argc >= 0);
591 assert(argv);
592
593 while ((c = getopt_long(argc, argv, "c:h", options, NULL)) >= 0)
594
595 switch (c) {
596
597 case 'h':
598 help();
599 return 0;
600
601 case 'c':
602 r = safe_atou(optarg, &arg_connections_max);
603 if (r < 0) {
604 log_error("Failed to parse --connections-max= argument: %s", optarg);
605 return r;
606 }
607
608 if (arg_connections_max < 1) {
609 log_error("Connection limit is too low.");
610 return -EINVAL;
611 }
612
613 break;
614
615 case ARG_VERSION:
616 return version();
617
618 case '?':
619 return -EINVAL;
620
621 default:
622 assert_not_reached("Unhandled option");
623 }
624
625 if (optind >= argc) {
626 log_error("Not enough parameters.");
627 return -EINVAL;
628 }
629
630 if (argc != optind+1) {
631 log_error("Too many parameters.");
632 return -EINVAL;
633 }
634
635 arg_remote_host = argv[optind];
636 return 1;
637 }
638
639 int main(int argc, char *argv[]) {
640 Context context = {};
641 int r, n, fd;
642
643 log_parse_environment();
644 log_open();
645
646 r = parse_argv(argc, argv);
647 if (r <= 0)
648 goto finish;
649
650 r = sd_event_default(&context.event);
651 if (r < 0) {
652 log_error_errno(r, "Failed to allocate event loop: %m");
653 goto finish;
654 }
655
656 r = sd_resolve_default(&context.resolve);
657 if (r < 0) {
658 log_error_errno(r, "Failed to allocate resolver: %m");
659 goto finish;
660 }
661
662 r = sd_resolve_attach_event(context.resolve, context.event, 0);
663 if (r < 0) {
664 log_error_errno(r, "Failed to attach resolver: %m");
665 goto finish;
666 }
667
668 sd_event_set_watchdog(context.event, true);
669
670 n = sd_listen_fds(1);
671 if (n < 0) {
672 log_error("Failed to receive sockets from parent.");
673 r = n;
674 goto finish;
675 } else if (n == 0) {
676 log_error("Didn't get any sockets passed in.");
677 r = -EINVAL;
678 goto finish;
679 }
680
681 for (fd = SD_LISTEN_FDS_START; fd < SD_LISTEN_FDS_START + n; fd++) {
682 r = add_listen_socket(&context, fd);
683 if (r < 0)
684 goto finish;
685 }
686
687 r = sd_event_loop(context.event);
688 if (r < 0) {
689 log_error_errno(r, "Failed to run event loop: %m");
690 goto finish;
691 }
692
693 finish:
694 context_free(&context);
695
696 return r < 0 ? EXIT_FAILURE : EXIT_SUCCESS;
697 }