]> git.ipfire.org Git - thirdparty/systemd.git/blob - src/socket-proxy/socket-proxyd.c
99d4b62139b8ebc5307dc03a3fc5fd0464f31c41
[thirdparty/systemd.git] / src / socket-proxy / socket-proxyd.c
1 /***
2 This file is part of systemd.
3
4 Copyright 2013 David Strauss
5
6 systemd is free software; you can redistribute it and/or modify it
7 under the terms of the GNU Lesser General Public License as published by
8 the Free Software Foundation; either version 2.1 of the License, or
9 (at your option) any later version.
10
11 systemd is distributed in the hope that it will be useful, but
12 WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14 Lesser General Public License for more details.
15
16 You should have received a copy of the GNU Lesser General Public License
17 along with systemd; If not, see <http://www.gnu.org/licenses/>.
18 ***/
19
20 #include <errno.h>
21 #include <fcntl.h>
22 #include <getopt.h>
23 #include <netdb.h>
24 #include <stdio.h>
25 #include <stdlib.h>
26 #include <string.h>
27 #include <sys/socket.h>
28 #include <sys/un.h>
29 #include <unistd.h>
30
31 #include "sd-daemon.h"
32 #include "sd-event.h"
33 #include "sd-resolve.h"
34
35 #include "alloc-util.h"
36 #include "fd-util.h"
37 #include "log.h"
38 #include "path-util.h"
39 #include "set.h"
40 #include "socket-util.h"
41 #include "string-util.h"
42 #include "util.h"
43
/* Capacity, in bytes, requested for each pipe that relays data between the two sockets. */
#define BUFFER_SIZE (256 * 1024)

/* Threshold the number of tracked connections is compared against before another one is admitted. */
#define CONNECTIONS_MAX 256

/* Destination taken from the command line; filled in by parse_argv(). */
static const char *arg_remote_host = NULL;
48
49 typedef struct Context {
50 sd_event *event;
51 sd_resolve *resolve;
52
53 Set *listen;
54 Set *connections;
55 } Context;
56
57 typedef struct Connection {
58 Context *context;
59
60 int server_fd, client_fd;
61 int server_to_client_buffer[2]; /* a pipe */
62 int client_to_server_buffer[2]; /* a pipe */
63
64 size_t server_to_client_buffer_full, client_to_server_buffer_full;
65 size_t server_to_client_buffer_size, client_to_server_buffer_size;
66
67 sd_event_source *server_event_source, *client_event_source;
68
69 sd_resolve_query *resolve_query;
70 } Connection;
71
72 static void connection_free(Connection *c) {
73 assert(c);
74
75 if (c->context)
76 set_remove(c->context->connections, c);
77
78 sd_event_source_unref(c->server_event_source);
79 sd_event_source_unref(c->client_event_source);
80
81 safe_close(c->server_fd);
82 safe_close(c->client_fd);
83
84 safe_close_pair(c->server_to_client_buffer);
85 safe_close_pair(c->client_to_server_buffer);
86
87 sd_resolve_query_unref(c->resolve_query);
88
89 free(c);
90 }
91
92 static void context_free(Context *context) {
93 sd_event_source *es;
94 Connection *c;
95
96 assert(context);
97
98 while ((es = set_steal_first(context->listen)))
99 sd_event_source_unref(es);
100
101 while ((c = set_first(context->connections)))
102 connection_free(c);
103
104 set_free(context->listen);
105 set_free(context->connections);
106
107 sd_event_unref(context->event);
108 sd_resolve_unref(context->resolve);
109 }
110
111 static int connection_create_pipes(Connection *c, int buffer[2], size_t *sz) {
112 int r;
113
114 assert(c);
115 assert(buffer);
116 assert(sz);
117
118 if (buffer[0] >= 0)
119 return 0;
120
121 r = pipe2(buffer, O_CLOEXEC|O_NONBLOCK);
122 if (r < 0)
123 return log_error_errno(errno, "Failed to allocate pipe buffer: %m");
124
125 (void) fcntl(buffer[0], F_SETPIPE_SZ, BUFFER_SIZE);
126
127 r = fcntl(buffer[0], F_GETPIPE_SZ);
128 if (r < 0)
129 return log_error_errno(errno, "Failed to get pipe buffer size: %m");
130
131 assert(r > 0);
132 *sz = r;
133
134 return 0;
135 }
136
137 static int connection_shovel(
138 Connection *c,
139 int *from, int buffer[2], int *to,
140 size_t *full, size_t *sz,
141 sd_event_source **from_source, sd_event_source **to_source) {
142
143 bool shoveled;
144
145 assert(c);
146 assert(from);
147 assert(buffer);
148 assert(buffer[0] >= 0);
149 assert(buffer[1] >= 0);
150 assert(to);
151 assert(full);
152 assert(sz);
153 assert(from_source);
154 assert(to_source);
155
156 do {
157 ssize_t z;
158
159 shoveled = false;
160
161 if (*full < *sz && *from >= 0 && *to >= 0) {
162 z = splice(*from, NULL, buffer[1], NULL, *sz - *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
163 if (z > 0) {
164 *full += z;
165 shoveled = true;
166 } else if (z == 0 || errno == EPIPE || errno == ECONNRESET) {
167 *from_source = sd_event_source_unref(*from_source);
168 *from = safe_close(*from);
169 } else if (errno != EAGAIN && errno != EINTR)
170 return log_error_errno(errno, "Failed to splice: %m");
171 }
172
173 if (*full > 0 && *to >= 0) {
174 z = splice(buffer[0], NULL, *to, NULL, *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
175 if (z > 0) {
176 *full -= z;
177 shoveled = true;
178 } else if (z == 0 || errno == EPIPE || errno == ECONNRESET) {
179 *to_source = sd_event_source_unref(*to_source);
180 *to = safe_close(*to);
181 } else if (errno != EAGAIN && errno != EINTR)
182 return log_error_errno(errno, "Failed to splice: %m");
183 }
184 } while (shoveled);
185
186 return 0;
187 }
188
189 static int connection_enable_event_sources(Connection *c);
190
191 static int traffic_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
192 Connection *c = userdata;
193 int r;
194
195 assert(s);
196 assert(fd >= 0);
197 assert(c);
198
199 r = connection_shovel(c,
200 &c->server_fd, c->server_to_client_buffer, &c->client_fd,
201 &c->server_to_client_buffer_full, &c->server_to_client_buffer_size,
202 &c->server_event_source, &c->client_event_source);
203 if (r < 0)
204 goto quit;
205
206 r = connection_shovel(c,
207 &c->client_fd, c->client_to_server_buffer, &c->server_fd,
208 &c->client_to_server_buffer_full, &c->client_to_server_buffer_size,
209 &c->client_event_source, &c->server_event_source);
210 if (r < 0)
211 goto quit;
212
213 /* EOF on both sides? */
214 if (c->server_fd == -1 && c->client_fd == -1)
215 goto quit;
216
217 /* Server closed, and all data written to client? */
218 if (c->server_fd == -1 && c->server_to_client_buffer_full <= 0)
219 goto quit;
220
221 /* Client closed, and all data written to server? */
222 if (c->client_fd == -1 && c->client_to_server_buffer_full <= 0)
223 goto quit;
224
225 r = connection_enable_event_sources(c);
226 if (r < 0)
227 goto quit;
228
229 return 1;
230
231 quit:
232 connection_free(c);
233 return 0; /* ignore errors, continue serving */
234 }
235
236 static int connection_enable_event_sources(Connection *c) {
237 uint32_t a = 0, b = 0;
238 int r;
239
240 assert(c);
241
242 if (c->server_to_client_buffer_full > 0)
243 b |= EPOLLOUT;
244 if (c->server_to_client_buffer_full < c->server_to_client_buffer_size)
245 a |= EPOLLIN;
246
247 if (c->client_to_server_buffer_full > 0)
248 a |= EPOLLOUT;
249 if (c->client_to_server_buffer_full < c->client_to_server_buffer_size)
250 b |= EPOLLIN;
251
252 if (c->server_event_source)
253 r = sd_event_source_set_io_events(c->server_event_source, a);
254 else if (c->server_fd >= 0)
255 r = sd_event_add_io(c->context->event, &c->server_event_source, c->server_fd, a, traffic_cb, c);
256 else
257 r = 0;
258
259 if (r < 0)
260 return log_error_errno(r, "Failed to set up server event source: %m");
261
262 if (c->client_event_source)
263 r = sd_event_source_set_io_events(c->client_event_source, b);
264 else if (c->client_fd >= 0)
265 r = sd_event_add_io(c->context->event, &c->client_event_source, c->client_fd, b, traffic_cb, c);
266 else
267 r = 0;
268
269 if (r < 0)
270 return log_error_errno(r, "Failed to set up client event source: %m");
271
272 return 0;
273 }
274
275 static int connection_complete(Connection *c) {
276 int r;
277
278 assert(c);
279
280 r = connection_create_pipes(c, c->server_to_client_buffer, &c->server_to_client_buffer_size);
281 if (r < 0)
282 goto fail;
283
284 r = connection_create_pipes(c, c->client_to_server_buffer, &c->client_to_server_buffer_size);
285 if (r < 0)
286 goto fail;
287
288 r = connection_enable_event_sources(c);
289 if (r < 0)
290 goto fail;
291
292 return 0;
293
294 fail:
295 connection_free(c);
296 return 0; /* ignore errors, continue serving */
297 }
298
299 static int connect_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
300 Connection *c = userdata;
301 socklen_t solen;
302 int error, r;
303
304 assert(s);
305 assert(fd >= 0);
306 assert(c);
307
308 solen = sizeof(error);
309 r = getsockopt(fd, SOL_SOCKET, SO_ERROR, &error, &solen);
310 if (r < 0) {
311 log_error_errno(errno, "Failed to issue SO_ERROR: %m");
312 goto fail;
313 }
314
315 if (error != 0) {
316 log_error_errno(error, "Failed to connect to remote host: %m");
317 goto fail;
318 }
319
320 c->client_event_source = sd_event_source_unref(c->client_event_source);
321
322 return connection_complete(c);
323
324 fail:
325 connection_free(c);
326 return 0; /* ignore errors, continue serving */
327 }
328
329 static int connection_start(Connection *c, struct sockaddr *sa, socklen_t salen) {
330 int r;
331
332 assert(c);
333 assert(sa);
334 assert(salen);
335
336 c->client_fd = socket(sa->sa_family, SOCK_STREAM|SOCK_NONBLOCK|SOCK_CLOEXEC, 0);
337 if (c->client_fd < 0) {
338 log_error_errno(errno, "Failed to get remote socket: %m");
339 goto fail;
340 }
341
342 r = connect(c->client_fd, sa, salen);
343 if (r < 0) {
344 if (errno == EINPROGRESS) {
345 r = sd_event_add_io(c->context->event, &c->client_event_source, c->client_fd, EPOLLOUT, connect_cb, c);
346 if (r < 0) {
347 log_error_errno(r, "Failed to add connection socket: %m");
348 goto fail;
349 }
350
351 r = sd_event_source_set_enabled(c->client_event_source, SD_EVENT_ONESHOT);
352 if (r < 0) {
353 log_error_errno(r, "Failed to enable oneshot event source: %m");
354 goto fail;
355 }
356 } else {
357 log_error_errno(errno, "Failed to connect to remote host: %m");
358 goto fail;
359 }
360 } else {
361 r = connection_complete(c);
362 if (r < 0)
363 goto fail;
364 }
365
366 return 0;
367
368 fail:
369 connection_free(c);
370 return 0; /* ignore errors, continue serving */
371 }
372
373 static int resolve_cb(sd_resolve_query *q, int ret, const struct addrinfo *ai, void *userdata) {
374 Connection *c = userdata;
375
376 assert(q);
377 assert(c);
378
379 if (ret != 0) {
380 log_error("Failed to resolve host: %s", gai_strerror(ret));
381 goto fail;
382 }
383
384 c->resolve_query = sd_resolve_query_unref(c->resolve_query);
385
386 return connection_start(c, ai->ai_addr, ai->ai_addrlen);
387
388 fail:
389 connection_free(c);
390 return 0; /* ignore errors, continue serving */
391 }
392
393 static int resolve_remote(Connection *c) {
394
395 static const struct addrinfo hints = {
396 .ai_family = AF_UNSPEC,
397 .ai_socktype = SOCK_STREAM,
398 .ai_flags = AI_ADDRCONFIG
399 };
400
401 union sockaddr_union sa = {};
402 const char *node, *service;
403 socklen_t salen;
404 int r;
405
406 if (path_is_absolute(arg_remote_host)) {
407 sa.un.sun_family = AF_UNIX;
408 strncpy(sa.un.sun_path, arg_remote_host, sizeof(sa.un.sun_path)-1);
409 sa.un.sun_path[sizeof(sa.un.sun_path)-1] = 0;
410
411 salen = offsetof(union sockaddr_union, un.sun_path) + strlen(sa.un.sun_path);
412
413 return connection_start(c, &sa.sa, salen);
414 }
415
416 if (arg_remote_host[0] == '@') {
417 sa.un.sun_family = AF_UNIX;
418 sa.un.sun_path[0] = 0;
419 strncpy(sa.un.sun_path+1, arg_remote_host+1, sizeof(sa.un.sun_path)-2);
420 sa.un.sun_path[sizeof(sa.un.sun_path)-1] = 0;
421
422 salen = offsetof(union sockaddr_union, un.sun_path) + 1 + strlen(sa.un.sun_path + 1);
423
424 return connection_start(c, &sa.sa, salen);
425 }
426
427 service = strrchr(arg_remote_host, ':');
428 if (service) {
429 node = strndupa(arg_remote_host, service - arg_remote_host);
430 service ++;
431 } else {
432 node = arg_remote_host;
433 service = "80";
434 }
435
436 log_debug("Looking up address info for %s:%s", node, service);
437 r = sd_resolve_getaddrinfo(c->context->resolve, &c->resolve_query, node, service, &hints, resolve_cb, c);
438 if (r < 0) {
439 log_error_errno(r, "Failed to resolve remote host: %m");
440 goto fail;
441 }
442
443 return 0;
444
445 fail:
446 connection_free(c);
447 return 0; /* ignore errors, continue serving */
448 }
449
450 static int add_connection_socket(Context *context, int fd) {
451 Connection *c;
452 int r;
453
454 assert(context);
455 assert(fd >= 0);
456
457 if (set_size(context->connections) > CONNECTIONS_MAX) {
458 log_warning("Hit connection limit, refusing connection.");
459 safe_close(fd);
460 return 0;
461 }
462
463 r = set_ensure_allocated(&context->connections, NULL);
464 if (r < 0) {
465 log_oom();
466 return 0;
467 }
468
469 c = new0(Connection, 1);
470 if (!c) {
471 log_oom();
472 return 0;
473 }
474
475 c->context = context;
476 c->server_fd = fd;
477 c->client_fd = -1;
478 c->server_to_client_buffer[0] = c->server_to_client_buffer[1] = -1;
479 c->client_to_server_buffer[0] = c->client_to_server_buffer[1] = -1;
480
481 r = set_put(context->connections, c);
482 if (r < 0) {
483 free(c);
484 log_oom();
485 return 0;
486 }
487
488 return resolve_remote(c);
489 }
490
491 static int accept_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
492 _cleanup_free_ char *peer = NULL;
493 Context *context = userdata;
494 int nfd = -1, r;
495
496 assert(s);
497 assert(fd >= 0);
498 assert(revents & EPOLLIN);
499 assert(context);
500
501 nfd = accept4(fd, NULL, NULL, SOCK_NONBLOCK|SOCK_CLOEXEC);
502 if (nfd < 0) {
503 if (errno != -EAGAIN)
504 log_warning_errno(errno, "Failed to accept() socket: %m");
505 } else {
506 getpeername_pretty(nfd, true, &peer);
507 log_debug("New connection from %s", strna(peer));
508
509 r = add_connection_socket(context, nfd);
510 if (r < 0) {
511 log_error_errno(r, "Failed to accept connection, ignoring: %m");
512 safe_close(fd);
513 }
514 }
515
516 r = sd_event_source_set_enabled(s, SD_EVENT_ONESHOT);
517 if (r < 0) {
518 log_error_errno(r, "Error while re-enabling listener with ONESHOT: %m");
519 sd_event_exit(context->event, r);
520 return r;
521 }
522
523 return 1;
524 }
525
526 static int add_listen_socket(Context *context, int fd) {
527 sd_event_source *source;
528 int r;
529
530 assert(context);
531 assert(fd >= 0);
532
533 r = set_ensure_allocated(&context->listen, NULL);
534 if (r < 0) {
535 log_oom();
536 return r;
537 }
538
539 r = sd_is_socket(fd, 0, SOCK_STREAM, 1);
540 if (r < 0)
541 return log_error_errno(r, "Failed to determine socket type: %m");
542 if (r == 0) {
543 log_error("Passed in socket is not a stream socket.");
544 return -EINVAL;
545 }
546
547 r = fd_nonblock(fd, true);
548 if (r < 0)
549 return log_error_errno(r, "Failed to mark file descriptor non-blocking: %m");
550
551 r = sd_event_add_io(context->event, &source, fd, EPOLLIN, accept_cb, context);
552 if (r < 0)
553 return log_error_errno(r, "Failed to add event source: %m");
554
555 r = set_put(context->listen, source);
556 if (r < 0) {
557 log_error_errno(r, "Failed to add source to set: %m");
558 sd_event_source_unref(source);
559 return r;
560 }
561
562 /* Set the watcher to oneshot in case other processes are also
563 * watching to accept(). */
564 r = sd_event_source_set_enabled(source, SD_EVENT_ONESHOT);
565 if (r < 0)
566 return log_error_errno(r, "Failed to enable oneshot mode: %m");
567
568 return 0;
569 }
570
571 static void help(void) {
572 printf("%1$s [HOST:PORT]\n"
573 "%1$s [SOCKET]\n\n"
574 "Bidirectionally proxy local sockets to another (possibly remote) socket.\n\n"
575 " -h --help Show this help\n"
576 " --version Show package version\n",
577 program_invocation_short_name);
578 }
579
580 static int parse_argv(int argc, char *argv[]) {
581
582 enum {
583 ARG_VERSION = 0x100,
584 ARG_IGNORE_ENV
585 };
586
587 static const struct option options[] = {
588 { "help", no_argument, NULL, 'h' },
589 { "version", no_argument, NULL, ARG_VERSION },
590 {}
591 };
592
593 int c;
594
595 assert(argc >= 0);
596 assert(argv);
597
598 while ((c = getopt_long(argc, argv, "h", options, NULL)) >= 0)
599
600 switch (c) {
601
602 case 'h':
603 help();
604 return 0;
605
606 case ARG_VERSION:
607 return version();
608
609 case '?':
610 return -EINVAL;
611
612 default:
613 assert_not_reached("Unhandled option");
614 }
615
616 if (optind >= argc) {
617 log_error("Not enough parameters.");
618 return -EINVAL;
619 }
620
621 if (argc != optind+1) {
622 log_error("Too many parameters.");
623 return -EINVAL;
624 }
625
626 arg_remote_host = argv[optind];
627 return 1;
628 }
629
630 int main(int argc, char *argv[]) {
631 Context context = {};
632 int r, n, fd;
633
634 log_parse_environment();
635 log_open();
636
637 r = parse_argv(argc, argv);
638 if (r <= 0)
639 goto finish;
640
641 r = sd_event_default(&context.event);
642 if (r < 0) {
643 log_error_errno(r, "Failed to allocate event loop: %m");
644 goto finish;
645 }
646
647 r = sd_resolve_default(&context.resolve);
648 if (r < 0) {
649 log_error_errno(r, "Failed to allocate resolver: %m");
650 goto finish;
651 }
652
653 r = sd_resolve_attach_event(context.resolve, context.event, 0);
654 if (r < 0) {
655 log_error_errno(r, "Failed to attach resolver: %m");
656 goto finish;
657 }
658
659 sd_event_set_watchdog(context.event, true);
660
661 n = sd_listen_fds(1);
662 if (n < 0) {
663 log_error("Failed to receive sockets from parent.");
664 r = n;
665 goto finish;
666 } else if (n == 0) {
667 log_error("Didn't get any sockets passed in.");
668 r = -EINVAL;
669 goto finish;
670 }
671
672 for (fd = SD_LISTEN_FDS_START; fd < SD_LISTEN_FDS_START + n; fd++) {
673 r = add_listen_socket(&context, fd);
674 if (r < 0)
675 goto finish;
676 }
677
678 r = sd_event_loop(context.event);
679 if (r < 0) {
680 log_error_errno(r, "Failed to run event loop: %m");
681 goto finish;
682 }
683
684 finish:
685 context_free(&context);
686
687 return r < 0 ? EXIT_FAILURE : EXIT_SUCCESS;
688 }