]> git.ipfire.org Git - thirdparty/systemd.git/blob - src/socket-proxy/socket-proxyd.c
event: make sure we keep a reference to all events we dispatch while we do so.
[thirdparty/systemd.git] / src / socket-proxy / socket-proxyd.c
1 /*-*- Mode: C; c-basic-offset: 8; indent-tabs-mode: nil -*-*/
2
3 /***
4 This file is part of systemd.
5
6 Copyright 2013 David Strauss
7
8 systemd is free software; you can redistribute it and/or modify it
9 under the terms of the GNU Lesser General Public License as published by
10 the Free Software Foundation; either version 2.1 of the License, or
11 (at your option) any later version.
12
13 systemd is distributed in the hope that it will be useful, but
14 WITHOUT ANY WARRANTY; without even the implied warranty of
15 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16 Lesser General Public License for more details.
17
18 You should have received a copy of the GNU Lesser General Public License
19 along with systemd; If not, see <http://www.gnu.org/licenses/>.
20 ***/
21
22 #include <arpa/inet.h>
23 #include <errno.h>
24 #include <getopt.h>
25 #include <stdio.h>
26 #include <stdlib.h>
27 #include <string.h>
28 #include <netdb.h>
29 #include <sys/fcntl.h>
30 #include <sys/socket.h>
31 #include <sys/un.h>
32 #include <unistd.h>
33
34 #include "sd-daemon.h"
35 #include "sd-event.h"
36 #include "log.h"
37 #include "socket-util.h"
38 #include "util.h"
39 #include "event-util.h"
40 #include "build.h"
41 #include "set.h"
42 #include "path-util.h"
43
44 #define BUFFER_SIZE (256 * 1024)
45 #define CONNECTIONS_MAX 256
46
47 #define _cleanup_freeaddrinfo_ _cleanup_(freeaddrinfop)
48 DEFINE_TRIVIAL_CLEANUP_FUNC(struct addrinfo *, freeaddrinfo);
49
50 typedef struct Context {
51 Set *listen;
52 Set *connections;
53 } Context;
54
55 typedef struct Connection {
56 int server_fd, client_fd;
57 int server_to_client_buffer[2]; /* a pipe */
58 int client_to_server_buffer[2]; /* a pipe */
59
60 size_t server_to_client_buffer_full, client_to_server_buffer_full;
61 size_t server_to_client_buffer_size, client_to_server_buffer_size;
62
63 sd_event_source *server_event_source, *client_event_source;
64 } Connection;
65
66 union sockaddr_any {
67 struct sockaddr sa;
68 struct sockaddr_un un;
69 struct sockaddr_in in;
70 struct sockaddr_in6 in6;
71 struct sockaddr_storage storage;
72 };
73
74 static const char *arg_remote_host = NULL;
75
76 static void connection_free(Connection *c) {
77 assert(c);
78
79 sd_event_source_unref(c->server_event_source);
80 sd_event_source_unref(c->client_event_source);
81
82 if (c->server_fd >= 0)
83 close_nointr_nofail(c->server_fd);
84 if (c->client_fd >= 0)
85 close_nointr_nofail(c->client_fd);
86
87 close_pipe(c->server_to_client_buffer);
88 close_pipe(c->client_to_server_buffer);
89
90 free(c);
91 }
92
93 static void context_free(Context *context) {
94 sd_event_source *es;
95 Connection *c;
96
97 assert(context);
98
99 while ((es = set_steal_first(context->listen)))
100 sd_event_source_unref(es);
101
102 while ((c = set_steal_first(context->connections)))
103 connection_free(c);
104
105 set_free(context->listen);
106 set_free(context->connections);
107 }
108
109 static int get_remote_sockaddr(union sockaddr_any *sa, socklen_t *salen) {
110 int r;
111
112 assert(sa);
113 assert(salen);
114
115 if (path_is_absolute(arg_remote_host)) {
116 sa->un.sun_family = AF_UNIX;
117 strncpy(sa->un.sun_path, arg_remote_host, sizeof(sa->un.sun_path)-1);
118 sa->un.sun_path[sizeof(sa->un.sun_path)-1] = 0;
119
120 *salen = offsetof(union sockaddr_any, un.sun_path) + strlen(sa->un.sun_path);
121
122 } else if (arg_remote_host[0] == '@') {
123 sa->un.sun_family = AF_UNIX;
124 sa->un.sun_path[0] = 0;
125 strncpy(sa->un.sun_path+1, arg_remote_host+1, sizeof(sa->un.sun_path)-2);
126 sa->un.sun_path[sizeof(sa->un.sun_path)-1] = 0;
127
128 *salen = offsetof(union sockaddr_any, un.sun_path) + 1 + strlen(sa->un.sun_path + 1);
129
130 } else {
131 _cleanup_freeaddrinfo_ struct addrinfo *result = NULL;
132 const char *node, *service;
133
134 struct addrinfo hints = {
135 .ai_family = AF_UNSPEC,
136 .ai_socktype = SOCK_STREAM,
137 .ai_flags = AI_ADDRCONFIG
138 };
139
140 service = strrchr(arg_remote_host, ':');
141 if (service) {
142 node = strndupa(arg_remote_host, service - arg_remote_host);
143 service ++;
144 } else {
145 node = arg_remote_host;
146 service = "80";
147 }
148
149 log_debug("Looking up address info for %s:%s", node, service);
150 r = getaddrinfo(node, service, &hints, &result);
151 if (r != 0) {
152 log_error("Failed to resolve host %s:%s: %s", node, service, gai_strerror(r));
153 return -EHOSTUNREACH;
154 }
155
156 assert(result);
157 if (result->ai_addrlen > sizeof(union sockaddr_any)) {
158 log_error("Address too long.");
159 return -E2BIG;
160 }
161
162 memcpy(sa, result->ai_addr, result->ai_addrlen);
163 *salen = result->ai_addrlen;
164 }
165
166 return 0;
167 }
168
169 static int connection_create_pipes(Connection *c, int buffer[2], size_t *sz) {
170 int r;
171
172 assert(c);
173 assert(buffer);
174 assert(sz);
175
176 if (buffer[0] >= 0)
177 return 0;
178
179 r = pipe2(buffer, O_CLOEXEC|O_NONBLOCK);
180 if (r < 0) {
181 log_error("Failed to allocate pipe buffer: %m");
182 return -errno;
183 }
184
185 fcntl(buffer[0], F_SETPIPE_SZ, BUFFER_SIZE);
186
187 r = fcntl(buffer[0], F_GETPIPE_SZ);
188 if (r < 0) {
189 log_error("Failed to get pipe buffer size: %m");
190 return -errno;
191 }
192
193 assert(r > 0);
194 *sz = r;
195
196 return 0;
197 }
198
199 static int connection_shovel(
200 Connection *c,
201 int *from, int buffer[2], int *to,
202 size_t *full, size_t *sz,
203 sd_event_source **from_source, sd_event_source **to_source) {
204
205 bool shoveled;
206
207 assert(c);
208 assert(from);
209 assert(buffer);
210 assert(buffer[0] >= 0);
211 assert(buffer[1] >= 0);
212 assert(to);
213 assert(full);
214 assert(sz);
215 assert(from_source);
216 assert(to_source);
217
218 do {
219 ssize_t z;
220
221 shoveled = false;
222
223 if (*full < *sz && *from >= 0 && *to >= 0) {
224 z = splice(*from, NULL, buffer[1], NULL, *sz - *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
225 if (z > 0) {
226 *full += z;
227 shoveled = true;
228 } else if (z == 0 || errno == EPIPE || errno == ECONNRESET) {
229 *from_source = sd_event_source_unref(*from_source);
230 close_nointr_nofail(*from);
231 *from = -1;
232 } else if (errno != EAGAIN && errno != EINTR) {
233 log_error("Failed to splice: %m");
234 return -errno;
235 }
236 }
237
238 if (*full > 0 && *to >= 0) {
239 z = splice(buffer[0], NULL, *to, NULL, *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
240 if (z > 0) {
241 *full -= z;
242 shoveled = true;
243 } else if (z == 0 || errno == EPIPE || errno == ECONNRESET) {
244 *to_source = sd_event_source_unref(*to_source);
245 close_nointr_nofail(*to);
246 *to = -1;
247 } else if (errno != EAGAIN && errno != EINTR) {
248 log_error("Failed to splice: %m");
249 return -errno;
250 }
251 }
252 } while (shoveled);
253
254 return 0;
255 }
256
257 static int connection_enable_event_sources(Connection *c, sd_event *event);
258
259 static int traffic_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
260 Connection *c = userdata;
261 int r;
262
263 assert(s);
264 assert(fd >= 0);
265 assert(c);
266
267 r = connection_shovel(c,
268 &c->server_fd, c->server_to_client_buffer, &c->client_fd,
269 &c->server_to_client_buffer_full, &c->server_to_client_buffer_size,
270 &c->server_event_source, &c->client_event_source);
271 if (r < 0)
272 goto quit;
273
274 r = connection_shovel(c,
275 &c->client_fd, c->client_to_server_buffer, &c->server_fd,
276 &c->client_to_server_buffer_full, &c->client_to_server_buffer_size,
277 &c->client_event_source, &c->server_event_source);
278 if (r < 0)
279 goto quit;
280
281 /* EOF on both sides? */
282 if (c->server_fd == -1 && c->client_fd == -1)
283 goto quit;
284
285 /* Server closed, and all data written to client? */
286 if (c->server_fd == -1 && c->server_to_client_buffer_full <= 0)
287 goto quit;
288
289 /* Client closed, and all data written to server? */
290 if (c->client_fd == -1 && c->client_to_server_buffer_full <= 0)
291 goto quit;
292
293 r = connection_enable_event_sources(c, sd_event_get(s));
294 if (r < 0)
295 goto quit;
296
297 return 1;
298
299 quit:
300 connection_free(c);
301 return 0; /* ignore errors, continue serving */
302 }
303
304 static int connection_enable_event_sources(Connection *c, sd_event *event) {
305 uint32_t a = 0, b = 0;
306 int r;
307
308 assert(c);
309 assert(event);
310
311 if (c->server_to_client_buffer_full > 0)
312 b |= EPOLLOUT;
313 if (c->server_to_client_buffer_full < c->server_to_client_buffer_size)
314 a |= EPOLLIN;
315
316 if (c->client_to_server_buffer_full > 0)
317 a |= EPOLLOUT;
318 if (c->client_to_server_buffer_full < c->client_to_server_buffer_size)
319 b |= EPOLLIN;
320
321 if (c->server_event_source)
322 r = sd_event_source_set_io_events(c->server_event_source, a);
323 else if (c->server_fd >= 0)
324 r = sd_event_add_io(event, c->server_fd, a, traffic_cb, c, &c->server_event_source);
325 else
326 r = 0;
327
328 if (r < 0) {
329 log_error("Failed to set up server event source: %s", strerror(-r));
330 return r;
331 }
332
333 if (c->client_event_source)
334 r = sd_event_source_set_io_events(c->client_event_source, b);
335 else if (c->client_fd >= 0)
336 r = sd_event_add_io(event, c->client_fd, b, traffic_cb, c, &c->client_event_source);
337 else
338 r = 0;
339
340 if (r < 0) {
341 log_error("Failed to set up client event source: %s", strerror(-r));
342 return r;
343 }
344
345 return 0;
346 }
347
348 static int connect_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
349 Connection *c = userdata;
350 socklen_t solen;
351 int error, r;
352
353 assert(s);
354 assert(fd >= 0);
355 assert(c);
356
357 solen = sizeof(error);
358 r = getsockopt(fd, SOL_SOCKET, SO_ERROR, &error, &solen);
359 if (r < 0) {
360 log_error("Failed to issue SO_ERROR: %m");
361 goto fail;
362 }
363
364 if (error != 0) {
365 log_error("Failed to connect to remote host: %s", strerror(error));
366 goto fail;
367 }
368
369 c->client_event_source = sd_event_source_unref(c->client_event_source);
370
371 r = connection_create_pipes(c, c->server_to_client_buffer, &c->server_to_client_buffer_size);
372 if (r < 0)
373 goto fail;
374
375 r = connection_create_pipes(c, c->client_to_server_buffer, &c->client_to_server_buffer_size);
376 if (r < 0)
377 goto fail;
378
379 r = connection_enable_event_sources(c, sd_event_get(s));
380 if (r < 0)
381 goto fail;
382
383 return 0;
384
385 fail:
386 connection_free(c);
387 return 0; /* ignore errors, continue serving */
388 }
389
390 static int add_connection_socket(Context *context, sd_event *event, int fd) {
391 union sockaddr_any sa = {};
392 socklen_t salen;
393 Connection *c;
394 int r;
395
396 assert(context);
397 assert(event);
398 assert(fd >= 0);
399
400 if (set_size(context->connections) > CONNECTIONS_MAX) {
401 log_warning("Hit connection limit, refusing connection.");
402 close_nointr_nofail(fd);
403 return 0;
404 }
405
406 r = set_ensure_allocated(&context->connections, trivial_hash_func, trivial_compare_func);
407 if (r < 0)
408 return log_oom();
409
410 c = new0(Connection, 1);
411 if (!c)
412 return log_oom();
413
414 c->server_fd = fd;
415 c->client_fd = -1;
416 c->server_to_client_buffer[0] = c->server_to_client_buffer[1] = -1;
417 c->client_to_server_buffer[0] = c->client_to_server_buffer[1] = -1;
418
419 r = get_remote_sockaddr(&sa, &salen);
420 if (r < 0)
421 goto fail;
422
423 c->client_fd = socket(sa.sa.sa_family, SOCK_STREAM|SOCK_NONBLOCK|SOCK_CLOEXEC, 0);
424 if (c->client_fd < 0) {
425 log_error("Failed to get remote socket: %m");
426 goto fail;
427 }
428
429 r = connect(c->client_fd, &sa.sa, salen);
430 if (r < 0) {
431 if (errno == EINPROGRESS) {
432 r = sd_event_add_io(event, c->client_fd, EPOLLOUT, connect_cb, c, &c->client_event_source);
433 if (r < 0) {
434 log_error("Failed to add connection socket: %s", strerror(-r));
435 goto fail;
436 }
437
438 r = sd_event_source_set_enabled(c->client_event_source, SD_EVENT_ONESHOT);
439 if (r < 0) {
440 log_error("Failed to enable oneshot event source: %s", strerror(-r));
441 goto fail;
442 }
443 } else {
444 log_error("Failed to connect to remote host: %m");
445 goto fail;
446 }
447 } else {
448 r = connection_enable_event_sources(c, event);
449 if (r < 0)
450 goto fail;
451 }
452
453 return 0;
454
455 fail:
456 connection_free(c);
457 return 0; /* ignore non-OOM errors, continue serving */
458 }
459
460 static int accept_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
461 Context *context = userdata;
462 int nfd = -1, r;
463
464 assert(s);
465 assert(fd >= 0);
466 assert(revents & EPOLLIN);
467 assert(context);
468
469 nfd = accept4(fd, NULL, NULL, SOCK_NONBLOCK|SOCK_CLOEXEC);
470 if (nfd >= 0) {
471 _cleanup_free_ char *peer = NULL;
472
473 getpeername_pretty(nfd, &peer);
474 log_debug("New connection from %s", strna(peer));
475
476 r = add_connection_socket(context, sd_event_get(s), nfd);
477 if (r < 0) {
478 close_nointr_nofail(fd);
479 return r;
480 }
481
482 } else if (errno != -EAGAIN)
483 log_warning("Failed to accept() socket: %m");
484
485 r = sd_event_source_set_enabled(s, SD_EVENT_ONESHOT);
486 if (r < 0) {
487 log_error("Error %d while re-enabling listener with ONESHOT: %s", r, strerror(-r));
488 return r;
489 }
490
491 return 1;
492 }
493
494 static int add_listen_socket(Context *context, sd_event *event, int fd) {
495 sd_event_source *source;
496 int r;
497
498 assert(context);
499 assert(event);
500 assert(fd >= 0);
501
502 log_info("Listening on %i", fd);
503
504 r = set_ensure_allocated(&context->listen, trivial_hash_func, trivial_compare_func);
505 if (r < 0) {
506 log_oom();
507 return r;
508 }
509
510 r = sd_is_socket(fd, 0, SOCK_STREAM, 1);
511 if (r < 0) {
512 log_error("Failed to determine socket type: %s", strerror(-r));
513 return r;
514 }
515 if (r == 0) {
516 log_error("Passed in socket is not a stream socket.");
517 return -EINVAL;
518 }
519
520 r = fd_nonblock(fd, true);
521 if (r < 0) {
522 log_error("Failed to mark file descriptor non-blocking: %s", strerror(-r));
523 return r;
524 }
525
526 r = sd_event_add_io(event, fd, EPOLLIN, accept_cb, context, &source);
527 if (r < 0) {
528 log_error("Failed to add event source: %s", strerror(-r));
529 return r;
530 }
531
532 r = set_put(context->listen, source);
533 if (r < 0) {
534 log_error("Failed to add source to set: %s", strerror(-r));
535 sd_event_source_unref(source);
536 return r;
537 }
538
539 /* Set the watcher to oneshot in case other processes are also
540 * watching to accept(). */
541 r = sd_event_source_set_enabled(source, SD_EVENT_ONESHOT);
542 if (r < 0) {
543 log_error("Failed to enable oneshot mode: %s", strerror(-r));
544 return r;
545 }
546
547 return 0;
548 }
549
550 static int help(void) {
551
552 printf("%s [HOST:PORT]\n"
553 "%s [SOCKET]\n\n"
554 "Bidirectionally proxy local sockets to another (possibly remote) socket.\n\n"
555 " -h --help Show this help\n"
556 " --version Show package version\n",
557 program_invocation_short_name,
558 program_invocation_short_name);
559
560 return 0;
561 }
562
563 static int parse_argv(int argc, char *argv[]) {
564
565 enum {
566 ARG_VERSION = 0x100,
567 ARG_IGNORE_ENV
568 };
569
570 static const struct option options[] = {
571 { "help", no_argument, NULL, 'h' },
572 { "version", no_argument, NULL, ARG_VERSION },
573 {}
574 };
575
576 int c;
577
578 assert(argc >= 0);
579 assert(argv);
580
581 while ((c = getopt_long(argc, argv, "h", options, NULL)) >= 0) {
582
583 switch (c) {
584
585 case 'h':
586 return help();
587
588 case ARG_VERSION:
589 puts(PACKAGE_STRING);
590 puts(SYSTEMD_FEATURES);
591 return 0;
592
593 case '?':
594 return -EINVAL;
595
596 default:
597 assert_not_reached("Unhandled option");
598 }
599 }
600
601 if (optind >= argc) {
602 log_error("Not enough parameters.");
603 return -EINVAL;
604 }
605
606 if (argc != optind+1) {
607 log_error("Too many parameters.");
608 return -EINVAL;
609 }
610
611 arg_remote_host = argv[optind];
612 return 1;
613 }
614
615 int main(int argc, char *argv[]) {
616 _cleanup_event_unref_ sd_event *event = NULL;
617 Context context = {};
618 int r, n, fd;
619
620 log_parse_environment();
621 log_open();
622
623 r = parse_argv(argc, argv);
624 if (r <= 0)
625 goto finish;
626
627 r = sd_event_new(&event);
628 if (r < 0) {
629 log_error("Failed to allocate event loop: %s", strerror(-r));
630 goto finish;
631 }
632
633 n = sd_listen_fds(1);
634 if (n < 0) {
635 log_error("Failed to receive sockets from parent.");
636 r = n;
637 goto finish;
638 } else if (n == 0) {
639 log_error("Didn't get any sockets passed in.");
640 r = -EINVAL;
641 goto finish;
642 }
643
644 for (fd = SD_LISTEN_FDS_START; fd < SD_LISTEN_FDS_START + n; fd++) {
645 r = add_listen_socket(&context, event, fd);
646 if (r < 0)
647 goto finish;
648 }
649
650 r = sd_event_loop(event);
651 if (r < 0) {
652 log_error("Failed to run event loop: %s", strerror(-r));
653 goto finish;
654 }
655
656 finish:
657 context_free(&context);
658
659 return r < 0 ? EXIT_FAILURE : EXIT_SUCCESS;
660 }