]> git.ipfire.org Git - thirdparty/systemd.git/blob - src/socket-proxy/socket-proxyd.c
api: in constructor function calls, always put the returned object pointer first...
[thirdparty/systemd.git] / src / socket-proxy / socket-proxyd.c
1 /*-*- Mode: C; c-basic-offset: 8; indent-tabs-mode: nil -*-*/
2
3 /***
4 This file is part of systemd.
5
6 Copyright 2013 David Strauss
7
8 systemd is free software; you can redistribute it and/or modify it
9 under the terms of the GNU Lesser General Public License as published by
10 the Free Software Foundation; either version 2.1 of the License, or
11 (at your option) any later version.
12
13 systemd is distributed in the hope that it will be useful, but
14 WITHOUT ANY WARRANTY; without even the implied warranty of
15 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16 Lesser General Public License for more details.
17
18 You should have received a copy of the GNU Lesser General Public License
19 along with systemd; If not, see <http://www.gnu.org/licenses/>.
20 ***/
21
22 #include <arpa/inet.h>
23 #include <errno.h>
24 #include <getopt.h>
25 #include <stdio.h>
26 #include <stdlib.h>
27 #include <string.h>
28 #include <netdb.h>
29 #include <sys/fcntl.h>
30 #include <sys/socket.h>
31 #include <sys/un.h>
32 #include <unistd.h>
33
34 #include "sd-daemon.h"
35 #include "sd-event.h"
36 #include "log.h"
37 #include "socket-util.h"
38 #include "util.h"
39 #include "event-util.h"
40 #include "build.h"
41 #include "set.h"
42 #include "path-util.h"
43
44 #define BUFFER_SIZE (256 * 1024)
45 #define CONNECTIONS_MAX 256
46
47 #define _cleanup_freeaddrinfo_ _cleanup_(freeaddrinfop)
48 DEFINE_TRIVIAL_CLEANUP_FUNC(struct addrinfo *, freeaddrinfo);
49
50 typedef struct Context {
51 Set *listen;
52 Set *connections;
53 } Context;
54
55 typedef struct Connection {
56 Context *context;
57
58 int server_fd, client_fd;
59 int server_to_client_buffer[2]; /* a pipe */
60 int client_to_server_buffer[2]; /* a pipe */
61
62 size_t server_to_client_buffer_full, client_to_server_buffer_full;
63 size_t server_to_client_buffer_size, client_to_server_buffer_size;
64
65 sd_event_source *server_event_source, *client_event_source;
66 } Connection;
67
68 static const char *arg_remote_host = NULL;
69
70 static void connection_free(Connection *c) {
71 assert(c);
72
73 if (c->context)
74 set_remove(c->context->connections, c);
75
76 sd_event_source_unref(c->server_event_source);
77 sd_event_source_unref(c->client_event_source);
78
79 if (c->server_fd >= 0)
80 close_nointr_nofail(c->server_fd);
81 if (c->client_fd >= 0)
82 close_nointr_nofail(c->client_fd);
83
84 close_pipe(c->server_to_client_buffer);
85 close_pipe(c->client_to_server_buffer);
86
87 free(c);
88 }
89
90 static void context_free(Context *context) {
91 sd_event_source *es;
92 Connection *c;
93
94 assert(context);
95
96 while ((es = set_steal_first(context->listen)))
97 sd_event_source_unref(es);
98
99 while ((c = set_first(context->connections)))
100 connection_free(c);
101
102 set_free(context->listen);
103 set_free(context->connections);
104 }
105
106 static int get_remote_sockaddr(union sockaddr_union *sa, socklen_t *salen) {
107 int r;
108
109 assert(sa);
110 assert(salen);
111
112 if (path_is_absolute(arg_remote_host)) {
113 sa->un.sun_family = AF_UNIX;
114 strncpy(sa->un.sun_path, arg_remote_host, sizeof(sa->un.sun_path)-1);
115 sa->un.sun_path[sizeof(sa->un.sun_path)-1] = 0;
116
117 *salen = offsetof(union sockaddr_union, un.sun_path) + strlen(sa->un.sun_path);
118
119 } else if (arg_remote_host[0] == '@') {
120 sa->un.sun_family = AF_UNIX;
121 sa->un.sun_path[0] = 0;
122 strncpy(sa->un.sun_path+1, arg_remote_host+1, sizeof(sa->un.sun_path)-2);
123 sa->un.sun_path[sizeof(sa->un.sun_path)-1] = 0;
124
125 *salen = offsetof(union sockaddr_union, un.sun_path) + 1 + strlen(sa->un.sun_path + 1);
126
127 } else {
128 _cleanup_freeaddrinfo_ struct addrinfo *result = NULL;
129 const char *node, *service;
130
131 struct addrinfo hints = {
132 .ai_family = AF_UNSPEC,
133 .ai_socktype = SOCK_STREAM,
134 .ai_flags = AI_ADDRCONFIG
135 };
136
137 service = strrchr(arg_remote_host, ':');
138 if (service) {
139 node = strndupa(arg_remote_host, service - arg_remote_host);
140 service ++;
141 } else {
142 node = arg_remote_host;
143 service = "80";
144 }
145
146 log_debug("Looking up address info for %s:%s", node, service);
147 r = getaddrinfo(node, service, &hints, &result);
148 if (r != 0) {
149 log_error("Failed to resolve host %s:%s: %s", node, service, gai_strerror(r));
150 return -EHOSTUNREACH;
151 }
152
153 assert(result);
154 if (result->ai_addrlen > sizeof(union sockaddr_union)) {
155 log_error("Address too long.");
156 return -E2BIG;
157 }
158
159 memcpy(sa, result->ai_addr, result->ai_addrlen);
160 *salen = result->ai_addrlen;
161 }
162
163 return 0;
164 }
165
166 static int connection_create_pipes(Connection *c, int buffer[2], size_t *sz) {
167 int r;
168
169 assert(c);
170 assert(buffer);
171 assert(sz);
172
173 if (buffer[0] >= 0)
174 return 0;
175
176 r = pipe2(buffer, O_CLOEXEC|O_NONBLOCK);
177 if (r < 0) {
178 log_error("Failed to allocate pipe buffer: %m");
179 return -errno;
180 }
181
182 fcntl(buffer[0], F_SETPIPE_SZ, BUFFER_SIZE);
183
184 r = fcntl(buffer[0], F_GETPIPE_SZ);
185 if (r < 0) {
186 log_error("Failed to get pipe buffer size: %m");
187 return -errno;
188 }
189
190 assert(r > 0);
191 *sz = r;
192
193 return 0;
194 }
195
196 static int connection_shovel(
197 Connection *c,
198 int *from, int buffer[2], int *to,
199 size_t *full, size_t *sz,
200 sd_event_source **from_source, sd_event_source **to_source) {
201
202 bool shoveled;
203
204 assert(c);
205 assert(from);
206 assert(buffer);
207 assert(buffer[0] >= 0);
208 assert(buffer[1] >= 0);
209 assert(to);
210 assert(full);
211 assert(sz);
212 assert(from_source);
213 assert(to_source);
214
215 do {
216 ssize_t z;
217
218 shoveled = false;
219
220 if (*full < *sz && *from >= 0 && *to >= 0) {
221 z = splice(*from, NULL, buffer[1], NULL, *sz - *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
222 if (z > 0) {
223 *full += z;
224 shoveled = true;
225 } else if (z == 0 || errno == EPIPE || errno == ECONNRESET) {
226 *from_source = sd_event_source_unref(*from_source);
227 close_nointr_nofail(*from);
228 *from = -1;
229 } else if (errno != EAGAIN && errno != EINTR) {
230 log_error("Failed to splice: %m");
231 return -errno;
232 }
233 }
234
235 if (*full > 0 && *to >= 0) {
236 z = splice(buffer[0], NULL, *to, NULL, *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
237 if (z > 0) {
238 *full -= z;
239 shoveled = true;
240 } else if (z == 0 || errno == EPIPE || errno == ECONNRESET) {
241 *to_source = sd_event_source_unref(*to_source);
242 close_nointr_nofail(*to);
243 *to = -1;
244 } else if (errno != EAGAIN && errno != EINTR) {
245 log_error("Failed to splice: %m");
246 return -errno;
247 }
248 }
249 } while (shoveled);
250
251 return 0;
252 }
253
254 static int connection_enable_event_sources(Connection *c, sd_event *event);
255
256 static int traffic_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
257 Connection *c = userdata;
258 int r;
259
260 assert(s);
261 assert(fd >= 0);
262 assert(c);
263
264 r = connection_shovel(c,
265 &c->server_fd, c->server_to_client_buffer, &c->client_fd,
266 &c->server_to_client_buffer_full, &c->server_to_client_buffer_size,
267 &c->server_event_source, &c->client_event_source);
268 if (r < 0)
269 goto quit;
270
271 r = connection_shovel(c,
272 &c->client_fd, c->client_to_server_buffer, &c->server_fd,
273 &c->client_to_server_buffer_full, &c->client_to_server_buffer_size,
274 &c->client_event_source, &c->server_event_source);
275 if (r < 0)
276 goto quit;
277
278 /* EOF on both sides? */
279 if (c->server_fd == -1 && c->client_fd == -1)
280 goto quit;
281
282 /* Server closed, and all data written to client? */
283 if (c->server_fd == -1 && c->server_to_client_buffer_full <= 0)
284 goto quit;
285
286 /* Client closed, and all data written to server? */
287 if (c->client_fd == -1 && c->client_to_server_buffer_full <= 0)
288 goto quit;
289
290 r = connection_enable_event_sources(c, sd_event_source_get_event(s));
291 if (r < 0)
292 goto quit;
293
294 return 1;
295
296 quit:
297 connection_free(c);
298 return 0; /* ignore errors, continue serving */
299 }
300
301 static int connection_enable_event_sources(Connection *c, sd_event *event) {
302 uint32_t a = 0, b = 0;
303 int r;
304
305 assert(c);
306 assert(event);
307
308 if (c->server_to_client_buffer_full > 0)
309 b |= EPOLLOUT;
310 if (c->server_to_client_buffer_full < c->server_to_client_buffer_size)
311 a |= EPOLLIN;
312
313 if (c->client_to_server_buffer_full > 0)
314 a |= EPOLLOUT;
315 if (c->client_to_server_buffer_full < c->client_to_server_buffer_size)
316 b |= EPOLLIN;
317
318 if (c->server_event_source)
319 r = sd_event_source_set_io_events(c->server_event_source, a);
320 else if (c->server_fd >= 0)
321 r = sd_event_add_io(event, &c->server_event_source, c->server_fd, a, traffic_cb, c);
322 else
323 r = 0;
324
325 if (r < 0) {
326 log_error("Failed to set up server event source: %s", strerror(-r));
327 return r;
328 }
329
330 if (c->client_event_source)
331 r = sd_event_source_set_io_events(c->client_event_source, b);
332 else if (c->client_fd >= 0)
333 r = sd_event_add_io(event, &c->client_event_source, c->client_fd, b, traffic_cb, c);
334 else
335 r = 0;
336
337 if (r < 0) {
338 log_error("Failed to set up client event source: %s", strerror(-r));
339 return r;
340 }
341
342 return 0;
343 }
344
345 static int connect_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
346 Connection *c = userdata;
347 socklen_t solen;
348 int error, r;
349
350 assert(s);
351 assert(fd >= 0);
352 assert(c);
353
354 solen = sizeof(error);
355 r = getsockopt(fd, SOL_SOCKET, SO_ERROR, &error, &solen);
356 if (r < 0) {
357 log_error("Failed to issue SO_ERROR: %m");
358 goto fail;
359 }
360
361 if (error != 0) {
362 log_error("Failed to connect to remote host: %s", strerror(error));
363 goto fail;
364 }
365
366 c->client_event_source = sd_event_source_unref(c->client_event_source);
367
368 r = connection_create_pipes(c, c->server_to_client_buffer, &c->server_to_client_buffer_size);
369 if (r < 0)
370 goto fail;
371
372 r = connection_create_pipes(c, c->client_to_server_buffer, &c->client_to_server_buffer_size);
373 if (r < 0)
374 goto fail;
375
376 r = connection_enable_event_sources(c, sd_event_source_get_event(s));
377 if (r < 0)
378 goto fail;
379
380 return 0;
381
382 fail:
383 connection_free(c);
384 return 0; /* ignore errors, continue serving */
385 }
386
387 static int add_connection_socket(Context *context, sd_event *event, int fd) {
388 union sockaddr_union sa = {};
389 socklen_t salen;
390 Connection *c;
391 int r;
392
393 assert(context);
394 assert(event);
395 assert(fd >= 0);
396
397 if (set_size(context->connections) > CONNECTIONS_MAX) {
398 log_warning("Hit connection limit, refusing connection.");
399 close_nointr_nofail(fd);
400 return 0;
401 }
402
403 r = set_ensure_allocated(&context->connections, trivial_hash_func, trivial_compare_func);
404 if (r < 0)
405 return log_oom();
406
407 c = new0(Connection, 1);
408 if (!c)
409 return log_oom();
410
411 c->context = context;
412 c->server_fd = fd;
413 c->client_fd = -1;
414 c->server_to_client_buffer[0] = c->server_to_client_buffer[1] = -1;
415 c->client_to_server_buffer[0] = c->client_to_server_buffer[1] = -1;
416
417 r = set_put(context->connections, c);
418 if (r < 0) {
419 free(c);
420 return log_oom();
421 }
422
423 r = get_remote_sockaddr(&sa, &salen);
424 if (r < 0)
425 goto fail;
426
427 c->client_fd = socket(sa.sa.sa_family, SOCK_STREAM|SOCK_NONBLOCK|SOCK_CLOEXEC, 0);
428 if (c->client_fd < 0) {
429 log_error("Failed to get remote socket: %m");
430 goto fail;
431 }
432
433 r = connect(c->client_fd, &sa.sa, salen);
434 if (r < 0) {
435 if (errno == EINPROGRESS) {
436 r = sd_event_add_io(event, &c->client_event_source, c->client_fd, EPOLLOUT, connect_cb, c);
437 if (r < 0) {
438 log_error("Failed to add connection socket: %s", strerror(-r));
439 goto fail;
440 }
441
442 r = sd_event_source_set_enabled(c->client_event_source, SD_EVENT_ONESHOT);
443 if (r < 0) {
444 log_error("Failed to enable oneshot event source: %s", strerror(-r));
445 goto fail;
446 }
447 } else {
448 log_error("Failed to connect to remote host: %m");
449 goto fail;
450 }
451 } else {
452 r = connection_enable_event_sources(c, event);
453 if (r < 0)
454 goto fail;
455 }
456
457 return 0;
458
459 fail:
460 connection_free(c);
461 return 0; /* ignore non-OOM errors, continue serving */
462 }
463
464 static int accept_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
465 _cleanup_free_ char *peer = NULL;
466 Context *context = userdata;
467 int nfd = -1, r;
468
469 assert(s);
470 assert(fd >= 0);
471 assert(revents & EPOLLIN);
472 assert(context);
473
474 nfd = accept4(fd, NULL, NULL, SOCK_NONBLOCK|SOCK_CLOEXEC);
475 if (nfd < 0) {
476 if (errno != -EAGAIN)
477 log_warning("Failed to accept() socket: %m");
478 } else {
479 getpeername_pretty(nfd, &peer);
480 log_debug("New connection from %s", strna(peer));
481
482 r = add_connection_socket(context, sd_event_source_get_event(s), nfd);
483 if (r < 0) {
484 log_error("Failed to accept connection, ignoring: %s", strerror(-r));
485 close_nointr_nofail(fd);
486 }
487 }
488
489 r = sd_event_source_set_enabled(s, SD_EVENT_ONESHOT);
490 if (r < 0) {
491 log_error("Error while re-enabling listener with ONESHOT: %s", strerror(-r));
492 sd_event_exit(sd_event_source_get_event(s), r);
493 return r;
494 }
495
496 return 1;
497 }
498
499 static int add_listen_socket(Context *context, sd_event *event, int fd) {
500 sd_event_source *source;
501 int r;
502
503 assert(context);
504 assert(event);
505 assert(fd >= 0);
506
507 r = set_ensure_allocated(&context->listen, trivial_hash_func, trivial_compare_func);
508 if (r < 0) {
509 log_oom();
510 return r;
511 }
512
513 r = sd_is_socket(fd, 0, SOCK_STREAM, 1);
514 if (r < 0) {
515 log_error("Failed to determine socket type: %s", strerror(-r));
516 return r;
517 }
518 if (r == 0) {
519 log_error("Passed in socket is not a stream socket.");
520 return -EINVAL;
521 }
522
523 r = fd_nonblock(fd, true);
524 if (r < 0) {
525 log_error("Failed to mark file descriptor non-blocking: %s", strerror(-r));
526 return r;
527 }
528
529 r = sd_event_add_io(event, &source, fd, EPOLLIN, accept_cb, context);
530 if (r < 0) {
531 log_error("Failed to add event source: %s", strerror(-r));
532 return r;
533 }
534
535 r = set_put(context->listen, source);
536 if (r < 0) {
537 log_error("Failed to add source to set: %s", strerror(-r));
538 sd_event_source_unref(source);
539 return r;
540 }
541
542 /* Set the watcher to oneshot in case other processes are also
543 * watching to accept(). */
544 r = sd_event_source_set_enabled(source, SD_EVENT_ONESHOT);
545 if (r < 0) {
546 log_error("Failed to enable oneshot mode: %s", strerror(-r));
547 return r;
548 }
549
550 return 0;
551 }
552
553 static int help(void) {
554
555 printf("%s [HOST:PORT]\n"
556 "%s [SOCKET]\n\n"
557 "Bidirectionally proxy local sockets to another (possibly remote) socket.\n\n"
558 " -h --help Show this help\n"
559 " --version Show package version\n",
560 program_invocation_short_name,
561 program_invocation_short_name);
562
563 return 0;
564 }
565
566 static int parse_argv(int argc, char *argv[]) {
567
568 enum {
569 ARG_VERSION = 0x100,
570 ARG_IGNORE_ENV
571 };
572
573 static const struct option options[] = {
574 { "help", no_argument, NULL, 'h' },
575 { "version", no_argument, NULL, ARG_VERSION },
576 {}
577 };
578
579 int c;
580
581 assert(argc >= 0);
582 assert(argv);
583
584 while ((c = getopt_long(argc, argv, "h", options, NULL)) >= 0) {
585
586 switch (c) {
587
588 case 'h':
589 return help();
590
591 case ARG_VERSION:
592 puts(PACKAGE_STRING);
593 puts(SYSTEMD_FEATURES);
594 return 0;
595
596 case '?':
597 return -EINVAL;
598
599 default:
600 assert_not_reached("Unhandled option");
601 }
602 }
603
604 if (optind >= argc) {
605 log_error("Not enough parameters.");
606 return -EINVAL;
607 }
608
609 if (argc != optind+1) {
610 log_error("Too many parameters.");
611 return -EINVAL;
612 }
613
614 arg_remote_host = argv[optind];
615 return 1;
616 }
617
618 int main(int argc, char *argv[]) {
619 _cleanup_event_unref_ sd_event *event = NULL;
620 Context context = {};
621 int r, n, fd;
622
623 log_parse_environment();
624 log_open();
625
626 r = parse_argv(argc, argv);
627 if (r <= 0)
628 goto finish;
629
630 r = sd_event_default(&event);
631 if (r < 0) {
632 log_error("Failed to allocate event loop: %s", strerror(-r));
633 goto finish;
634 }
635
636 sd_event_set_watchdog(event, true);
637
638 n = sd_listen_fds(1);
639 if (n < 0) {
640 log_error("Failed to receive sockets from parent.");
641 r = n;
642 goto finish;
643 } else if (n == 0) {
644 log_error("Didn't get any sockets passed in.");
645 r = -EINVAL;
646 goto finish;
647 }
648
649 for (fd = SD_LISTEN_FDS_START; fd < SD_LISTEN_FDS_START + n; fd++) {
650 r = add_listen_socket(&context, event, fd);
651 if (r < 0)
652 goto finish;
653 }
654
655 r = sd_event_loop(event);
656 if (r < 0) {
657 log_error("Failed to run event loop: %s", strerror(-r));
658 goto finish;
659 }
660
661 finish:
662 context_free(&context);
663
664 return r < 0 ? EXIT_FAILURE : EXIT_SUCCESS;
665 }