]> git.ipfire.org Git - thirdparty/systemd.git/blob - src/socket-proxy/socket-proxyd.c
treewide: use log_*_errno whenever %m is in the format string
[thirdparty/systemd.git] / src / socket-proxy / socket-proxyd.c
1 /*-*- Mode: C; c-basic-offset: 8; indent-tabs-mode: nil -*-*/
2
3 /***
4 This file is part of systemd.
5
6 Copyright 2013 David Strauss
7
8 systemd is free software; you can redistribute it and/or modify it
9 under the terms of the GNU Lesser General Public License as published by
10 the Free Software Foundation; either version 2.1 of the License, or
11 (at your option) any later version.
12
13 systemd is distributed in the hope that it will be useful, but
14 WITHOUT ANY WARRANTY; without even the implied warranty of
15 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16 Lesser General Public License for more details.
17
18 You should have received a copy of the GNU Lesser General Public License
19 along with systemd; If not, see <http://www.gnu.org/licenses/>.
20 ***/
21
22 #include <arpa/inet.h>
23 #include <errno.h>
24 #include <getopt.h>
25 #include <stdio.h>
26 #include <stdlib.h>
27 #include <string.h>
28 #include <netdb.h>
29 #include <fcntl.h>
30 #include <sys/socket.h>
31 #include <sys/un.h>
32 #include <unistd.h>
33
34 #include "sd-daemon.h"
35 #include "sd-event.h"
36 #include "sd-resolve.h"
37 #include "log.h"
38 #include "socket-util.h"
39 #include "util.h"
40 #include "event-util.h"
41 #include "build.h"
42 #include "set.h"
43 #include "path-util.h"
44
/* Pipe capacity requested (via F_SETPIPE_SZ) for each direction's splice()
 * buffer: 256 KiB. The kernel may grant a different size. */
#define BUFFER_SIZE (256 * 1024)

/* Limit on the number of simultaneously proxied connections. */
#define CONNECTIONS_MAX 256

/* Destination from the command line: "HOST[:PORT]", an absolute AF_UNIX
 * socket path, or "@NAME" for an abstract-namespace socket. Static storage,
 * hence NULL until parse_argv() sets it. */
static const char *arg_remote_host;
49
50 typedef struct Context {
51 sd_event *event;
52 sd_resolve *resolve;
53
54 Set *listen;
55 Set *connections;
56 } Context;
57
58 typedef struct Connection {
59 Context *context;
60
61 int server_fd, client_fd;
62 int server_to_client_buffer[2]; /* a pipe */
63 int client_to_server_buffer[2]; /* a pipe */
64
65 size_t server_to_client_buffer_full, client_to_server_buffer_full;
66 size_t server_to_client_buffer_size, client_to_server_buffer_size;
67
68 sd_event_source *server_event_source, *client_event_source;
69
70 sd_resolve_query *resolve_query;
71 } Connection;
72
73 static void connection_free(Connection *c) {
74 assert(c);
75
76 if (c->context)
77 set_remove(c->context->connections, c);
78
79 sd_event_source_unref(c->server_event_source);
80 sd_event_source_unref(c->client_event_source);
81
82 safe_close(c->server_fd);
83 safe_close(c->client_fd);
84
85 safe_close_pair(c->server_to_client_buffer);
86 safe_close_pair(c->client_to_server_buffer);
87
88 sd_resolve_query_unref(c->resolve_query);
89
90 free(c);
91 }
92
93 static void context_free(Context *context) {
94 sd_event_source *es;
95 Connection *c;
96
97 assert(context);
98
99 while ((es = set_steal_first(context->listen)))
100 sd_event_source_unref(es);
101
102 while ((c = set_first(context->connections)))
103 connection_free(c);
104
105 set_free(context->listen);
106 set_free(context->connections);
107
108 sd_event_unref(context->event);
109 sd_resolve_unref(context->resolve);
110 }
111
112 static int connection_create_pipes(Connection *c, int buffer[2], size_t *sz) {
113 int r;
114
115 assert(c);
116 assert(buffer);
117 assert(sz);
118
119 if (buffer[0] >= 0)
120 return 0;
121
122 r = pipe2(buffer, O_CLOEXEC|O_NONBLOCK);
123 if (r < 0) {
124 log_error_errno(errno, "Failed to allocate pipe buffer: %m");
125 return -errno;
126 }
127
128 (void) fcntl(buffer[0], F_SETPIPE_SZ, BUFFER_SIZE);
129
130 r = fcntl(buffer[0], F_GETPIPE_SZ);
131 if (r < 0) {
132 log_error_errno(errno, "Failed to get pipe buffer size: %m");
133 return -errno;
134 }
135
136 assert(r > 0);
137 *sz = r;
138
139 return 0;
140 }
141
142 static int connection_shovel(
143 Connection *c,
144 int *from, int buffer[2], int *to,
145 size_t *full, size_t *sz,
146 sd_event_source **from_source, sd_event_source **to_source) {
147
148 bool shoveled;
149
150 assert(c);
151 assert(from);
152 assert(buffer);
153 assert(buffer[0] >= 0);
154 assert(buffer[1] >= 0);
155 assert(to);
156 assert(full);
157 assert(sz);
158 assert(from_source);
159 assert(to_source);
160
161 do {
162 ssize_t z;
163
164 shoveled = false;
165
166 if (*full < *sz && *from >= 0 && *to >= 0) {
167 z = splice(*from, NULL, buffer[1], NULL, *sz - *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
168 if (z > 0) {
169 *full += z;
170 shoveled = true;
171 } else if (z == 0 || errno == EPIPE || errno == ECONNRESET) {
172 *from_source = sd_event_source_unref(*from_source);
173 *from = safe_close(*from);
174 } else if (errno != EAGAIN && errno != EINTR) {
175 log_error_errno(errno, "Failed to splice: %m");
176 return -errno;
177 }
178 }
179
180 if (*full > 0 && *to >= 0) {
181 z = splice(buffer[0], NULL, *to, NULL, *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
182 if (z > 0) {
183 *full -= z;
184 shoveled = true;
185 } else if (z == 0 || errno == EPIPE || errno == ECONNRESET) {
186 *to_source = sd_event_source_unref(*to_source);
187 *to = safe_close(*to);
188 } else if (errno != EAGAIN && errno != EINTR) {
189 log_error_errno(errno, "Failed to splice: %m");
190 return -errno;
191 }
192 }
193 } while (shoveled);
194
195 return 0;
196 }
197
198 static int connection_enable_event_sources(Connection *c);
199
200 static int traffic_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
201 Connection *c = userdata;
202 int r;
203
204 assert(s);
205 assert(fd >= 0);
206 assert(c);
207
208 r = connection_shovel(c,
209 &c->server_fd, c->server_to_client_buffer, &c->client_fd,
210 &c->server_to_client_buffer_full, &c->server_to_client_buffer_size,
211 &c->server_event_source, &c->client_event_source);
212 if (r < 0)
213 goto quit;
214
215 r = connection_shovel(c,
216 &c->client_fd, c->client_to_server_buffer, &c->server_fd,
217 &c->client_to_server_buffer_full, &c->client_to_server_buffer_size,
218 &c->client_event_source, &c->server_event_source);
219 if (r < 0)
220 goto quit;
221
222 /* EOF on both sides? */
223 if (c->server_fd == -1 && c->client_fd == -1)
224 goto quit;
225
226 /* Server closed, and all data written to client? */
227 if (c->server_fd == -1 && c->server_to_client_buffer_full <= 0)
228 goto quit;
229
230 /* Client closed, and all data written to server? */
231 if (c->client_fd == -1 && c->client_to_server_buffer_full <= 0)
232 goto quit;
233
234 r = connection_enable_event_sources(c);
235 if (r < 0)
236 goto quit;
237
238 return 1;
239
240 quit:
241 connection_free(c);
242 return 0; /* ignore errors, continue serving */
243 }
244
245 static int connection_enable_event_sources(Connection *c) {
246 uint32_t a = 0, b = 0;
247 int r;
248
249 assert(c);
250
251 if (c->server_to_client_buffer_full > 0)
252 b |= EPOLLOUT;
253 if (c->server_to_client_buffer_full < c->server_to_client_buffer_size)
254 a |= EPOLLIN;
255
256 if (c->client_to_server_buffer_full > 0)
257 a |= EPOLLOUT;
258 if (c->client_to_server_buffer_full < c->client_to_server_buffer_size)
259 b |= EPOLLIN;
260
261 if (c->server_event_source)
262 r = sd_event_source_set_io_events(c->server_event_source, a);
263 else if (c->server_fd >= 0)
264 r = sd_event_add_io(c->context->event, &c->server_event_source, c->server_fd, a, traffic_cb, c);
265 else
266 r = 0;
267
268 if (r < 0)
269 return log_error_errno(r, "Failed to set up server event source: %m");
270
271 if (c->client_event_source)
272 r = sd_event_source_set_io_events(c->client_event_source, b);
273 else if (c->client_fd >= 0)
274 r = sd_event_add_io(c->context->event, &c->client_event_source, c->client_fd, b, traffic_cb, c);
275 else
276 r = 0;
277
278 if (r < 0)
279 return log_error_errno(r, "Failed to set up client event source: %m");
280
281 return 0;
282 }
283
284 static int connection_complete(Connection *c) {
285 int r;
286
287 assert(c);
288
289 r = connection_create_pipes(c, c->server_to_client_buffer, &c->server_to_client_buffer_size);
290 if (r < 0)
291 goto fail;
292
293 r = connection_create_pipes(c, c->client_to_server_buffer, &c->client_to_server_buffer_size);
294 if (r < 0)
295 goto fail;
296
297 r = connection_enable_event_sources(c);
298 if (r < 0)
299 goto fail;
300
301 return 0;
302
303 fail:
304 connection_free(c);
305 return 0; /* ignore errors, continue serving */
306 }
307
308 static int connect_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
309 Connection *c = userdata;
310 socklen_t solen;
311 int error, r;
312
313 assert(s);
314 assert(fd >= 0);
315 assert(c);
316
317 solen = sizeof(error);
318 r = getsockopt(fd, SOL_SOCKET, SO_ERROR, &error, &solen);
319 if (r < 0) {
320 log_error_errno(errno, "Failed to issue SO_ERROR: %m");
321 goto fail;
322 }
323
324 if (error != 0) {
325 log_error_errno(error, "Failed to connect to remote host: %m");
326 goto fail;
327 }
328
329 c->client_event_source = sd_event_source_unref(c->client_event_source);
330
331 return connection_complete(c);
332
333 fail:
334 connection_free(c);
335 return 0; /* ignore errors, continue serving */
336 }
337
338 static int connection_start(Connection *c, struct sockaddr *sa, socklen_t salen) {
339 int r;
340
341 assert(c);
342 assert(sa);
343 assert(salen);
344
345 c->client_fd = socket(sa->sa_family, SOCK_STREAM|SOCK_NONBLOCK|SOCK_CLOEXEC, 0);
346 if (c->client_fd < 0) {
347 log_error_errno(errno, "Failed to get remote socket: %m");
348 goto fail;
349 }
350
351 r = connect(c->client_fd, sa, salen);
352 if (r < 0) {
353 if (errno == EINPROGRESS) {
354 r = sd_event_add_io(c->context->event, &c->client_event_source, c->client_fd, EPOLLOUT, connect_cb, c);
355 if (r < 0) {
356 log_error_errno(r, "Failed to add connection socket: %m");
357 goto fail;
358 }
359
360 r = sd_event_source_set_enabled(c->client_event_source, SD_EVENT_ONESHOT);
361 if (r < 0) {
362 log_error_errno(r, "Failed to enable oneshot event source: %m");
363 goto fail;
364 }
365 } else {
366 log_error_errno(errno, "Failed to connect to remote host: %m");
367 goto fail;
368 }
369 } else {
370 r = connection_complete(c);
371 if (r < 0)
372 goto fail;
373 }
374
375 return 0;
376
377 fail:
378 connection_free(c);
379 return 0; /* ignore errors, continue serving */
380 }
381
382 static int resolve_cb(sd_resolve_query *q, int ret, const struct addrinfo *ai, void *userdata) {
383 Connection *c = userdata;
384
385 assert(q);
386 assert(c);
387
388 if (ret != 0) {
389 log_error("Failed to resolve host: %s", gai_strerror(ret));
390 goto fail;
391 }
392
393 c->resolve_query = sd_resolve_query_unref(c->resolve_query);
394
395 return connection_start(c, ai->ai_addr, ai->ai_addrlen);
396
397 fail:
398 connection_free(c);
399 return 0; /* ignore errors, continue serving */
400 }
401
402 static int resolve_remote(Connection *c) {
403
404 static const struct addrinfo hints = {
405 .ai_family = AF_UNSPEC,
406 .ai_socktype = SOCK_STREAM,
407 .ai_flags = AI_ADDRCONFIG
408 };
409
410 union sockaddr_union sa = {};
411 const char *node, *service;
412 socklen_t salen;
413 int r;
414
415 if (path_is_absolute(arg_remote_host)) {
416 sa.un.sun_family = AF_UNIX;
417 strncpy(sa.un.sun_path, arg_remote_host, sizeof(sa.un.sun_path)-1);
418 sa.un.sun_path[sizeof(sa.un.sun_path)-1] = 0;
419
420 salen = offsetof(union sockaddr_union, un.sun_path) + strlen(sa.un.sun_path);
421
422 return connection_start(c, &sa.sa, salen);
423 }
424
425 if (arg_remote_host[0] == '@') {
426 sa.un.sun_family = AF_UNIX;
427 sa.un.sun_path[0] = 0;
428 strncpy(sa.un.sun_path+1, arg_remote_host+1, sizeof(sa.un.sun_path)-2);
429 sa.un.sun_path[sizeof(sa.un.sun_path)-1] = 0;
430
431 salen = offsetof(union sockaddr_union, un.sun_path) + 1 + strlen(sa.un.sun_path + 1);
432
433 return connection_start(c, &sa.sa, salen);
434 }
435
436 service = strrchr(arg_remote_host, ':');
437 if (service) {
438 node = strndupa(arg_remote_host, service - arg_remote_host);
439 service ++;
440 } else {
441 node = arg_remote_host;
442 service = "80";
443 }
444
445 log_debug("Looking up address info for %s:%s", node, service);
446 r = sd_resolve_getaddrinfo(c->context->resolve, &c->resolve_query, node, service, &hints, resolve_cb, c);
447 if (r < 0) {
448 log_error_errno(r, "Failed to resolve remote host: %m");
449 goto fail;
450 }
451
452 return 0;
453
454 fail:
455 connection_free(c);
456 return 0; /* ignore errors, continue serving */
457 }
458
459 static int add_connection_socket(Context *context, int fd) {
460 Connection *c;
461 int r;
462
463 assert(context);
464 assert(fd >= 0);
465
466 if (set_size(context->connections) > CONNECTIONS_MAX) {
467 log_warning("Hit connection limit, refusing connection.");
468 safe_close(fd);
469 return 0;
470 }
471
472 r = set_ensure_allocated(&context->connections, NULL);
473 if (r < 0) {
474 log_oom();
475 return 0;
476 }
477
478 c = new0(Connection, 1);
479 if (!c) {
480 log_oom();
481 return 0;
482 }
483
484 c->context = context;
485 c->server_fd = fd;
486 c->client_fd = -1;
487 c->server_to_client_buffer[0] = c->server_to_client_buffer[1] = -1;
488 c->client_to_server_buffer[0] = c->client_to_server_buffer[1] = -1;
489
490 r = set_put(context->connections, c);
491 if (r < 0) {
492 free(c);
493 log_oom();
494 return 0;
495 }
496
497 return resolve_remote(c);
498 }
499
500 static int accept_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
501 _cleanup_free_ char *peer = NULL;
502 Context *context = userdata;
503 int nfd = -1, r;
504
505 assert(s);
506 assert(fd >= 0);
507 assert(revents & EPOLLIN);
508 assert(context);
509
510 nfd = accept4(fd, NULL, NULL, SOCK_NONBLOCK|SOCK_CLOEXEC);
511 if (nfd < 0) {
512 if (errno != -EAGAIN)
513 log_warning_errno(errno, "Failed to accept() socket: %m");
514 } else {
515 getpeername_pretty(nfd, &peer);
516 log_debug("New connection from %s", strna(peer));
517
518 r = add_connection_socket(context, nfd);
519 if (r < 0) {
520 log_error_errno(r, "Failed to accept connection, ignoring: %m");
521 safe_close(fd);
522 }
523 }
524
525 r = sd_event_source_set_enabled(s, SD_EVENT_ONESHOT);
526 if (r < 0) {
527 log_error_errno(r, "Error while re-enabling listener with ONESHOT: %m");
528 sd_event_exit(context->event, r);
529 return r;
530 }
531
532 return 1;
533 }
534
535 static int add_listen_socket(Context *context, int fd) {
536 sd_event_source *source;
537 int r;
538
539 assert(context);
540 assert(fd >= 0);
541
542 r = set_ensure_allocated(&context->listen, NULL);
543 if (r < 0) {
544 log_oom();
545 return r;
546 }
547
548 r = sd_is_socket(fd, 0, SOCK_STREAM, 1);
549 if (r < 0)
550 return log_error_errno(r, "Failed to determine socket type: %m");
551 if (r == 0) {
552 log_error("Passed in socket is not a stream socket.");
553 return -EINVAL;
554 }
555
556 r = fd_nonblock(fd, true);
557 if (r < 0)
558 return log_error_errno(r, "Failed to mark file descriptor non-blocking: %m");
559
560 r = sd_event_add_io(context->event, &source, fd, EPOLLIN, accept_cb, context);
561 if (r < 0)
562 return log_error_errno(r, "Failed to add event source: %m");
563
564 r = set_put(context->listen, source);
565 if (r < 0) {
566 log_error_errno(r, "Failed to add source to set: %m");
567 sd_event_source_unref(source);
568 return r;
569 }
570
571 /* Set the watcher to oneshot in case other processes are also
572 * watching to accept(). */
573 r = sd_event_source_set_enabled(source, SD_EVENT_ONESHOT);
574 if (r < 0)
575 return log_error_errno(r, "Failed to enable oneshot mode: %m");
576
577 return 0;
578 }
579
580 static void help(void) {
581 printf("%1$s [HOST:PORT]\n"
582 "%1$s [SOCKET]\n\n"
583 "Bidirectionally proxy local sockets to another (possibly remote) socket.\n\n"
584 " -h --help Show this help\n"
585 " --version Show package version\n",
586 program_invocation_short_name);
587 }
588
589 static int parse_argv(int argc, char *argv[]) {
590
591 enum {
592 ARG_VERSION = 0x100,
593 ARG_IGNORE_ENV
594 };
595
596 static const struct option options[] = {
597 { "help", no_argument, NULL, 'h' },
598 { "version", no_argument, NULL, ARG_VERSION },
599 {}
600 };
601
602 int c;
603
604 assert(argc >= 0);
605 assert(argv);
606
607 while ((c = getopt_long(argc, argv, "h", options, NULL)) >= 0)
608
609 switch (c) {
610
611 case 'h':
612 help();
613 return 0;
614
615 case ARG_VERSION:
616 puts(PACKAGE_STRING);
617 puts(SYSTEMD_FEATURES);
618 return 0;
619
620 case '?':
621 return -EINVAL;
622
623 default:
624 assert_not_reached("Unhandled option");
625 }
626
627 if (optind >= argc) {
628 log_error("Not enough parameters.");
629 return -EINVAL;
630 }
631
632 if (argc != optind+1) {
633 log_error("Too many parameters.");
634 return -EINVAL;
635 }
636
637 arg_remote_host = argv[optind];
638 return 1;
639 }
640
641 int main(int argc, char *argv[]) {
642 Context context = {};
643 int r, n, fd;
644
645 log_parse_environment();
646 log_open();
647
648 r = parse_argv(argc, argv);
649 if (r <= 0)
650 goto finish;
651
652 r = sd_event_default(&context.event);
653 if (r < 0) {
654 log_error_errno(r, "Failed to allocate event loop: %m");
655 goto finish;
656 }
657
658 r = sd_resolve_default(&context.resolve);
659 if (r < 0) {
660 log_error_errno(r, "Failed to allocate resolver: %m");
661 goto finish;
662 }
663
664 r = sd_resolve_attach_event(context.resolve, context.event, 0);
665 if (r < 0) {
666 log_error_errno(r, "Failed to attach resolver: %m");
667 goto finish;
668 }
669
670 sd_event_set_watchdog(context.event, true);
671
672 n = sd_listen_fds(1);
673 if (n < 0) {
674 log_error("Failed to receive sockets from parent.");
675 r = n;
676 goto finish;
677 } else if (n == 0) {
678 log_error("Didn't get any sockets passed in.");
679 r = -EINVAL;
680 goto finish;
681 }
682
683 for (fd = SD_LISTEN_FDS_START; fd < SD_LISTEN_FDS_START + n; fd++) {
684 r = add_listen_socket(&context, fd);
685 if (r < 0)
686 goto finish;
687 }
688
689 r = sd_event_loop(context.event);
690 if (r < 0) {
691 log_error_errno(r, "Failed to run event loop: %m");
692 goto finish;
693 }
694
695 finish:
696 context_free(&context);
697
698 return r < 0 ? EXIT_FAILURE : EXIT_SUCCESS;
699 }