]> git.ipfire.org Git - thirdparty/systemd.git/blob - src/socket-proxy/socket-proxyd.c
Merge pull request #7388 from keszybz/doc-tweak
[thirdparty/systemd.git] / src / socket-proxy / socket-proxyd.c
1 /***
2 This file is part of systemd.
3
4 Copyright 2013 David Strauss
5
6 systemd is free software; you can redistribute it and/or modify it
7 under the terms of the GNU Lesser General Public License as published by
8 the Free Software Foundation; either version 2.1 of the License, or
9 (at your option) any later version.
10
11 systemd is distributed in the hope that it will be useful, but
12 WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14 Lesser General Public License for more details.
15
16 You should have received a copy of the GNU Lesser General Public License
17 along with systemd; If not, see <http://www.gnu.org/licenses/>.
18 ***/
19
20 #include <errno.h>
21 #include <fcntl.h>
22 #include <getopt.h>
23 #include <netdb.h>
24 #include <stdio.h>
25 #include <stdlib.h>
26 #include <string.h>
27 #include <sys/socket.h>
28 #include <sys/un.h>
29 #include <unistd.h>
30
31 #include "sd-daemon.h"
32 #include "sd-event.h"
33 #include "sd-resolve.h"
34
35 #include "alloc-util.h"
36 #include "fd-util.h"
37 #include "log.h"
38 #include "path-util.h"
39 #include "set.h"
40 #include "socket-util.h"
41 #include "string-util.h"
42 #include "parse-util.h"
43 #include "util.h"
44
/* Pipe capacity requested for each splice() buffer. The request is best effort:
 * the size actually granted is read back after the pipe is created. */
#define BUFFER_SIZE (256 * 1024)

/* Connection limit, settable with --connections-max= (default 256). */
static unsigned arg_connections_max = 256;

/* Destination from the command line: HOST:PORT, an absolute AF_UNIX socket
 * path, or an "@"-prefixed abstract AF_UNIX socket name. */
static const char *arg_remote_host = NULL;
49
50 typedef struct Context {
51 sd_event *event;
52 sd_resolve *resolve;
53
54 Set *listen;
55 Set *connections;
56 } Context;
57
58 typedef struct Connection {
59 Context *context;
60
61 int server_fd, client_fd;
62 int server_to_client_buffer[2]; /* a pipe */
63 int client_to_server_buffer[2]; /* a pipe */
64
65 size_t server_to_client_buffer_full, client_to_server_buffer_full;
66 size_t server_to_client_buffer_size, client_to_server_buffer_size;
67
68 sd_event_source *server_event_source, *client_event_source;
69
70 sd_resolve_query *resolve_query;
71 } Connection;
72
73 static void connection_free(Connection *c) {
74 assert(c);
75
76 if (c->context)
77 set_remove(c->context->connections, c);
78
79 sd_event_source_unref(c->server_event_source);
80 sd_event_source_unref(c->client_event_source);
81
82 safe_close(c->server_fd);
83 safe_close(c->client_fd);
84
85 safe_close_pair(c->server_to_client_buffer);
86 safe_close_pair(c->client_to_server_buffer);
87
88 sd_resolve_query_unref(c->resolve_query);
89
90 free(c);
91 }
92
93 static void context_free(Context *context) {
94 sd_event_source *es;
95 Connection *c;
96
97 assert(context);
98
99 while ((es = set_steal_first(context->listen)))
100 sd_event_source_unref(es);
101
102 while ((c = set_first(context->connections)))
103 connection_free(c);
104
105 set_free(context->listen);
106 set_free(context->connections);
107
108 sd_event_unref(context->event);
109 sd_resolve_unref(context->resolve);
110 }
111
112 static int connection_create_pipes(Connection *c, int buffer[2], size_t *sz) {
113 int r;
114
115 assert(c);
116 assert(buffer);
117 assert(sz);
118
119 if (buffer[0] >= 0)
120 return 0;
121
122 r = pipe2(buffer, O_CLOEXEC|O_NONBLOCK);
123 if (r < 0)
124 return log_error_errno(errno, "Failed to allocate pipe buffer: %m");
125
126 (void) fcntl(buffer[0], F_SETPIPE_SZ, BUFFER_SIZE);
127
128 r = fcntl(buffer[0], F_GETPIPE_SZ);
129 if (r < 0)
130 return log_error_errno(errno, "Failed to get pipe buffer size: %m");
131
132 assert(r > 0);
133 *sz = r;
134
135 return 0;
136 }
137
/* Pump data from *from to *to through the pipe 'buffer' using non-blocking splice().
 *
 * Two steps repeat until an iteration makes no progress:
 *   1. fill the pipe from *from (only while the pipe has room and *to is still open);
 *   2. drain the pipe into *to.
 *
 * *full is the number of bytes currently queued in the pipe; *sz is its capacity.
 * When splice() reports EOF (returns 0) or fails with EPIPE/ECONNRESET on an end,
 * that end's event source is released and its fd is closed via safe_close().
 * EAGAIN/EINTR mean "nothing to do right now" and are not errors.
 *
 * Returns 0, or the negative value from log_error_errno() on any other splice() failure. */
static int connection_shovel(
                Connection *c,
                int *from, int buffer[2], int *to,
                size_t *full, size_t *sz,
                sd_event_source **from_source, sd_event_source **to_source) {

        bool shoveled;

        assert(c);
        assert(from);
        assert(buffer);
        assert(buffer[0] >= 0);
        assert(buffer[1] >= 0);
        assert(to);
        assert(full);
        assert(sz);
        assert(from_source);
        assert(to_source);

        do {
                ssize_t z;

                shoveled = false;

                /* Source -> pipe: only while there is free space and the destination is open. */
                if (*full < *sz && *from >= 0 && *to >= 0) {
                        z = splice(*from, NULL, buffer[1], NULL, *sz - *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
                        if (z > 0) {
                                *full += z;
                                shoveled = true;
                        } else if (z == 0 || IN_SET(errno, EPIPE, ECONNRESET)) {
                                /* The source has hung up: stop watching it and close it. */
                                *from_source = sd_event_source_unref(*from_source);
                                *from = safe_close(*from);
                        } else if (!IN_SET(errno, EAGAIN, EINTR))
                                return log_error_errno(errno, "Failed to splice: %m");
                }

                /* Pipe -> destination: flush whatever is queued. */
                if (*full > 0 && *to >= 0) {
                        z = splice(buffer[0], NULL, *to, NULL, *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
                        if (z > 0) {
                                *full -= z;
                                shoveled = true;
                        } else if (z == 0 || IN_SET(errno, EPIPE, ECONNRESET)) {
                                /* The destination is gone: stop watching it and close it. */
                                *to_source = sd_event_source_unref(*to_source);
                                *to = safe_close(*to);
                        } else if (!IN_SET(errno, EAGAIN, EINTR))
                                return log_error_errno(errno, "Failed to splice: %m");
                }
        } while (shoveled);

        return 0;
}
189
190 static int connection_enable_event_sources(Connection *c);
191
192 static int traffic_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
193 Connection *c = userdata;
194 int r;
195
196 assert(s);
197 assert(fd >= 0);
198 assert(c);
199
200 r = connection_shovel(c,
201 &c->server_fd, c->server_to_client_buffer, &c->client_fd,
202 &c->server_to_client_buffer_full, &c->server_to_client_buffer_size,
203 &c->server_event_source, &c->client_event_source);
204 if (r < 0)
205 goto quit;
206
207 r = connection_shovel(c,
208 &c->client_fd, c->client_to_server_buffer, &c->server_fd,
209 &c->client_to_server_buffer_full, &c->client_to_server_buffer_size,
210 &c->client_event_source, &c->server_event_source);
211 if (r < 0)
212 goto quit;
213
214 /* EOF on both sides? */
215 if (c->server_fd == -1 && c->client_fd == -1)
216 goto quit;
217
218 /* Server closed, and all data written to client? */
219 if (c->server_fd == -1 && c->server_to_client_buffer_full <= 0)
220 goto quit;
221
222 /* Client closed, and all data written to server? */
223 if (c->client_fd == -1 && c->client_to_server_buffer_full <= 0)
224 goto quit;
225
226 r = connection_enable_event_sources(c);
227 if (r < 0)
228 goto quit;
229
230 return 1;
231
232 quit:
233 connection_free(c);
234 return 0; /* ignore errors, continue serving */
235 }
236
237 static int connection_enable_event_sources(Connection *c) {
238 uint32_t a = 0, b = 0;
239 int r;
240
241 assert(c);
242
243 if (c->server_to_client_buffer_full > 0)
244 b |= EPOLLOUT;
245 if (c->server_to_client_buffer_full < c->server_to_client_buffer_size)
246 a |= EPOLLIN;
247
248 if (c->client_to_server_buffer_full > 0)
249 a |= EPOLLOUT;
250 if (c->client_to_server_buffer_full < c->client_to_server_buffer_size)
251 b |= EPOLLIN;
252
253 if (c->server_event_source)
254 r = sd_event_source_set_io_events(c->server_event_source, a);
255 else if (c->server_fd >= 0)
256 r = sd_event_add_io(c->context->event, &c->server_event_source, c->server_fd, a, traffic_cb, c);
257 else
258 r = 0;
259
260 if (r < 0)
261 return log_error_errno(r, "Failed to set up server event source: %m");
262
263 if (c->client_event_source)
264 r = sd_event_source_set_io_events(c->client_event_source, b);
265 else if (c->client_fd >= 0)
266 r = sd_event_add_io(c->context->event, &c->client_event_source, c->client_fd, b, traffic_cb, c);
267 else
268 r = 0;
269
270 if (r < 0)
271 return log_error_errno(r, "Failed to set up client event source: %m");
272
273 return 0;
274 }
275
276 static int connection_complete(Connection *c) {
277 int r;
278
279 assert(c);
280
281 r = connection_create_pipes(c, c->server_to_client_buffer, &c->server_to_client_buffer_size);
282 if (r < 0)
283 goto fail;
284
285 r = connection_create_pipes(c, c->client_to_server_buffer, &c->client_to_server_buffer_size);
286 if (r < 0)
287 goto fail;
288
289 r = connection_enable_event_sources(c);
290 if (r < 0)
291 goto fail;
292
293 return 0;
294
295 fail:
296 connection_free(c);
297 return 0; /* ignore errors, continue serving */
298 }
299
300 static int connect_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
301 Connection *c = userdata;
302 socklen_t solen;
303 int error, r;
304
305 assert(s);
306 assert(fd >= 0);
307 assert(c);
308
309 solen = sizeof(error);
310 r = getsockopt(fd, SOL_SOCKET, SO_ERROR, &error, &solen);
311 if (r < 0) {
312 log_error_errno(errno, "Failed to issue SO_ERROR: %m");
313 goto fail;
314 }
315
316 if (error != 0) {
317 log_error_errno(error, "Failed to connect to remote host: %m");
318 goto fail;
319 }
320
321 c->client_event_source = sd_event_source_unref(c->client_event_source);
322
323 return connection_complete(c);
324
325 fail:
326 connection_free(c);
327 return 0; /* ignore errors, continue serving */
328 }
329
330 static int connection_start(Connection *c, struct sockaddr *sa, socklen_t salen) {
331 int r;
332
333 assert(c);
334 assert(sa);
335 assert(salen);
336
337 c->client_fd = socket(sa->sa_family, SOCK_STREAM|SOCK_NONBLOCK|SOCK_CLOEXEC, 0);
338 if (c->client_fd < 0) {
339 log_error_errno(errno, "Failed to get remote socket: %m");
340 goto fail;
341 }
342
343 r = connect(c->client_fd, sa, salen);
344 if (r < 0) {
345 if (errno == EINPROGRESS) {
346 r = sd_event_add_io(c->context->event, &c->client_event_source, c->client_fd, EPOLLOUT, connect_cb, c);
347 if (r < 0) {
348 log_error_errno(r, "Failed to add connection socket: %m");
349 goto fail;
350 }
351
352 r = sd_event_source_set_enabled(c->client_event_source, SD_EVENT_ONESHOT);
353 if (r < 0) {
354 log_error_errno(r, "Failed to enable oneshot event source: %m");
355 goto fail;
356 }
357 } else {
358 log_error_errno(errno, "Failed to connect to remote host: %m");
359 goto fail;
360 }
361 } else {
362 r = connection_complete(c);
363 if (r < 0)
364 goto fail;
365 }
366
367 return 0;
368
369 fail:
370 connection_free(c);
371 return 0; /* ignore errors, continue serving */
372 }
373
374 static int resolve_cb(sd_resolve_query *q, int ret, const struct addrinfo *ai, void *userdata) {
375 Connection *c = userdata;
376
377 assert(q);
378 assert(c);
379
380 if (ret != 0) {
381 log_error("Failed to resolve host: %s", gai_strerror(ret));
382 goto fail;
383 }
384
385 c->resolve_query = sd_resolve_query_unref(c->resolve_query);
386
387 return connection_start(c, ai->ai_addr, ai->ai_addrlen);
388
389 fail:
390 connection_free(c);
391 return 0; /* ignore errors, continue serving */
392 }
393
394 static int resolve_remote(Connection *c) {
395
396 static const struct addrinfo hints = {
397 .ai_family = AF_UNSPEC,
398 .ai_socktype = SOCK_STREAM,
399 .ai_flags = AI_ADDRCONFIG
400 };
401
402 union sockaddr_union sa = {};
403 const char *node, *service;
404 int r;
405
406 if (path_is_absolute(arg_remote_host)) {
407 sa.un.sun_family = AF_UNIX;
408 strncpy(sa.un.sun_path, arg_remote_host, sizeof(sa.un.sun_path));
409 return connection_start(c, &sa.sa, SOCKADDR_UN_LEN(sa.un));
410 }
411
412 if (arg_remote_host[0] == '@') {
413 sa.un.sun_family = AF_UNIX;
414 sa.un.sun_path[0] = 0;
415 strncpy(sa.un.sun_path+1, arg_remote_host+1, sizeof(sa.un.sun_path)-1);
416 return connection_start(c, &sa.sa, SOCKADDR_UN_LEN(sa.un));
417 }
418
419 service = strrchr(arg_remote_host, ':');
420 if (service) {
421 node = strndupa(arg_remote_host, service - arg_remote_host);
422 service++;
423 } else {
424 node = arg_remote_host;
425 service = "80";
426 }
427
428 log_debug("Looking up address info for %s:%s", node, service);
429 r = sd_resolve_getaddrinfo(c->context->resolve, &c->resolve_query, node, service, &hints, resolve_cb, c);
430 if (r < 0) {
431 log_error_errno(r, "Failed to resolve remote host: %m");
432 goto fail;
433 }
434
435 return 0;
436
437 fail:
438 connection_free(c);
439 return 0; /* ignore errors, continue serving */
440 }
441
442 static int add_connection_socket(Context *context, int fd) {
443 Connection *c;
444 int r;
445
446 assert(context);
447 assert(fd >= 0);
448
449 if (set_size(context->connections) > arg_connections_max) {
450 log_warning("Hit connection limit, refusing connection.");
451 safe_close(fd);
452 return 0;
453 }
454
455 r = set_ensure_allocated(&context->connections, NULL);
456 if (r < 0) {
457 log_oom();
458 return 0;
459 }
460
461 c = new0(Connection, 1);
462 if (!c) {
463 log_oom();
464 return 0;
465 }
466
467 c->context = context;
468 c->server_fd = fd;
469 c->client_fd = -1;
470 c->server_to_client_buffer[0] = c->server_to_client_buffer[1] = -1;
471 c->client_to_server_buffer[0] = c->client_to_server_buffer[1] = -1;
472
473 r = set_put(context->connections, c);
474 if (r < 0) {
475 free(c);
476 log_oom();
477 return 0;
478 }
479
480 return resolve_remote(c);
481 }
482
483 static int accept_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
484 _cleanup_free_ char *peer = NULL;
485 Context *context = userdata;
486 int nfd = -1, r;
487
488 assert(s);
489 assert(fd >= 0);
490 assert(revents & EPOLLIN);
491 assert(context);
492
493 nfd = accept4(fd, NULL, NULL, SOCK_NONBLOCK|SOCK_CLOEXEC);
494 if (nfd < 0) {
495 if (errno != -EAGAIN)
496 log_warning_errno(errno, "Failed to accept() socket: %m");
497 } else {
498 getpeername_pretty(nfd, true, &peer);
499 log_debug("New connection from %s", strna(peer));
500
501 r = add_connection_socket(context, nfd);
502 if (r < 0) {
503 log_error_errno(r, "Failed to accept connection, ignoring: %m");
504 safe_close(fd);
505 }
506 }
507
508 r = sd_event_source_set_enabled(s, SD_EVENT_ONESHOT);
509 if (r < 0) {
510 log_error_errno(r, "Error while re-enabling listener with ONESHOT: %m");
511 sd_event_exit(context->event, r);
512 return r;
513 }
514
515 return 1;
516 }
517
518 static int add_listen_socket(Context *context, int fd) {
519 sd_event_source *source;
520 int r;
521
522 assert(context);
523 assert(fd >= 0);
524
525 r = set_ensure_allocated(&context->listen, NULL);
526 if (r < 0) {
527 log_oom();
528 return r;
529 }
530
531 r = sd_is_socket(fd, 0, SOCK_STREAM, 1);
532 if (r < 0)
533 return log_error_errno(r, "Failed to determine socket type: %m");
534 if (r == 0) {
535 log_error("Passed in socket is not a stream socket.");
536 return -EINVAL;
537 }
538
539 r = fd_nonblock(fd, true);
540 if (r < 0)
541 return log_error_errno(r, "Failed to mark file descriptor non-blocking: %m");
542
543 r = sd_event_add_io(context->event, &source, fd, EPOLLIN, accept_cb, context);
544 if (r < 0)
545 return log_error_errno(r, "Failed to add event source: %m");
546
547 r = set_put(context->listen, source);
548 if (r < 0) {
549 log_error_errno(r, "Failed to add source to set: %m");
550 sd_event_source_unref(source);
551 return r;
552 }
553
554 /* Set the watcher to oneshot in case other processes are also
555 * watching to accept(). */
556 r = sd_event_source_set_enabled(source, SD_EVENT_ONESHOT);
557 if (r < 0)
558 return log_error_errno(r, "Failed to enable oneshot mode: %m");
559
560 return 0;
561 }
562
563 static void help(void) {
564 printf("%1$s [HOST:PORT]\n"
565 "%1$s [SOCKET]\n\n"
566 "Bidirectionally proxy local sockets to another (possibly remote) socket.\n\n"
567 " -c --connections-max= Set the maximum number of connections to be accepted\n"
568 " -h --help Show this help\n"
569 " --version Show package version\n",
570 program_invocation_short_name);
571 }
572
573 static int parse_argv(int argc, char *argv[]) {
574
575 enum {
576 ARG_VERSION = 0x100,
577 ARG_IGNORE_ENV
578 };
579
580 static const struct option options[] = {
581 { "connections-max", required_argument, NULL, 'c' },
582 { "help", no_argument, NULL, 'h' },
583 { "version", no_argument, NULL, ARG_VERSION },
584 {}
585 };
586
587 int c, r;
588
589 assert(argc >= 0);
590 assert(argv);
591
592 while ((c = getopt_long(argc, argv, "c:h", options, NULL)) >= 0)
593
594 switch (c) {
595
596 case 'h':
597 help();
598 return 0;
599
600 case 'c':
601 r = safe_atou(optarg, &arg_connections_max);
602 if (r < 0) {
603 log_error("Failed to parse --connections-max= argument: %s", optarg);
604 return r;
605 }
606
607 if (arg_connections_max < 1) {
608 log_error("Connection limit is too low.");
609 return -EINVAL;
610 }
611
612 break;
613
614 case ARG_VERSION:
615 return version();
616
617 case '?':
618 return -EINVAL;
619
620 default:
621 assert_not_reached("Unhandled option");
622 }
623
624 if (optind >= argc) {
625 log_error("Not enough parameters.");
626 return -EINVAL;
627 }
628
629 if (argc != optind+1) {
630 log_error("Too many parameters.");
631 return -EINVAL;
632 }
633
634 arg_remote_host = argv[optind];
635 return 1;
636 }
637
638 int main(int argc, char *argv[]) {
639 Context context = {};
640 int r, n, fd;
641
642 log_parse_environment();
643 log_open();
644
645 r = parse_argv(argc, argv);
646 if (r <= 0)
647 goto finish;
648
649 r = sd_event_default(&context.event);
650 if (r < 0) {
651 log_error_errno(r, "Failed to allocate event loop: %m");
652 goto finish;
653 }
654
655 r = sd_resolve_default(&context.resolve);
656 if (r < 0) {
657 log_error_errno(r, "Failed to allocate resolver: %m");
658 goto finish;
659 }
660
661 r = sd_resolve_attach_event(context.resolve, context.event, 0);
662 if (r < 0) {
663 log_error_errno(r, "Failed to attach resolver: %m");
664 goto finish;
665 }
666
667 sd_event_set_watchdog(context.event, true);
668
669 n = sd_listen_fds(1);
670 if (n < 0) {
671 log_error("Failed to receive sockets from parent.");
672 r = n;
673 goto finish;
674 } else if (n == 0) {
675 log_error("Didn't get any sockets passed in.");
676 r = -EINVAL;
677 goto finish;
678 }
679
680 for (fd = SD_LISTEN_FDS_START; fd < SD_LISTEN_FDS_START + n; fd++) {
681 r = add_listen_socket(&context, fd);
682 if (r < 0)
683 goto finish;
684 }
685
686 r = sd_event_loop(context.event);
687 if (r < 0) {
688 log_error_errno(r, "Failed to run event loop: %m");
689 goto finish;
690 }
691
692 finish:
693 context_free(&context);
694
695 return r < 0 ? EXIT_FAILURE : EXIT_SUCCESS;
696 }