]> git.ipfire.org Git - thirdparty/systemd.git/blob - src/socket-proxy/socket-proxyd.c
Merge pull request #1372 from jemk/prefsrc
[thirdparty/systemd.git] / src / socket-proxy / socket-proxyd.c
1 /*-*- Mode: C; c-basic-offset: 8; indent-tabs-mode: nil -*-*/
2
3 /***
4 This file is part of systemd.
5
6 Copyright 2013 David Strauss
7
8 systemd is free software; you can redistribute it and/or modify it
9 under the terms of the GNU Lesser General Public License as published by
10 the Free Software Foundation; either version 2.1 of the License, or
11 (at your option) any later version.
12
13 systemd is distributed in the hope that it will be useful, but
14 WITHOUT ANY WARRANTY; without even the implied warranty of
15 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16 Lesser General Public License for more details.
17
18 You should have received a copy of the GNU Lesser General Public License
19 along with systemd; If not, see <http://www.gnu.org/licenses/>.
20 ***/
21
22 #include <errno.h>
23 #include <fcntl.h>
24 #include <getopt.h>
25 #include <netdb.h>
26 #include <stdio.h>
27 #include <stdlib.h>
28 #include <string.h>
29 #include <sys/socket.h>
30 #include <sys/un.h>
31 #include <unistd.h>
32
33 #include "sd-daemon.h"
34 #include "sd-event.h"
35 #include "sd-resolve.h"
36
37 #include "log.h"
38 #include "path-util.h"
39 #include "set.h"
40 #include "socket-util.h"
41 #include "util.h"
42
/* Requested capacity of each splice pipe; applied with F_SETPIPE_SZ on a
 * best-effort basis in connection_create_pipes(). */
#define BUFFER_SIZE (256 * 1024)
/* Connection limit checked in add_connection_socket(). */
#define CONNECTIONS_MAX 256

/* Proxy target, taken from the single positional command-line argument:
 * an absolute AF_UNIX socket path, "@name" for an abstract AF_UNIX socket,
 * or "host[:port]" (port defaults to 80) resolved via sd-resolve. */
static const char *arg_remote_host = NULL;
47
/* Process-wide proxy state shared by all listeners and connections. */
typedef struct Context {
        sd_event *event;        /* main event loop */
        sd_resolve *resolve;    /* asynchronous resolver, attached to event in main() */

        Set *listen;            /* sd_event_source* watching each listening socket */
        Set *connections;       /* every live Connection* */
} Context;
55
/* One proxied connection. "server" is the socket accepted from the local
 * peer (this process acts as the server there); "client" is the socket this
 * process opens towards arg_remote_host. */
typedef struct Connection {
        Context *context;       /* owning context; this object is in context->connections */

        int server_fd, client_fd;
        int server_to_client_buffer[2]; /* a pipe */
        int client_to_server_buffer[2]; /* a pipe */

        /* Bytes currently queued in each pipe, and each pipe's capacity. */
        size_t server_to_client_buffer_full, client_to_server_buffer_full;
        size_t server_to_client_buffer_size, client_to_server_buffer_size;

        sd_event_source *server_event_source, *client_event_source;

        /* Outstanding getaddrinfo() query while the remote host is resolved. */
        sd_resolve_query *resolve_query;
} Connection;
70
/* Destroys a connection: detaches it from its context, releases both event
 * sources, closes both sockets and both splice pipes, drops any pending
 * resolver query and frees the object. */
static void connection_free(Connection *c) {
        assert(c);

        /* Must leave the set: context_free() keeps freeing set_first() until
         * the set is empty and relies on this removal to make progress. */
        if (c->context)
                set_remove(c->context->connections, c);

        /* Release the event sources before closing the fds they watch. */
        sd_event_source_unref(c->server_event_source);
        sd_event_source_unref(c->client_event_source);

        safe_close(c->server_fd);
        safe_close(c->client_fd);

        safe_close_pair(c->server_to_client_buffer);
        safe_close_pair(c->client_to_server_buffer);

        sd_resolve_query_unref(c->resolve_query);

        free(c);
}
90
91 static void context_free(Context *context) {
92 sd_event_source *es;
93 Connection *c;
94
95 assert(context);
96
97 while ((es = set_steal_first(context->listen)))
98 sd_event_source_unref(es);
99
100 while ((c = set_first(context->connections)))
101 connection_free(c);
102
103 set_free(context->listen);
104 set_free(context->connections);
105
106 sd_event_unref(context->event);
107 sd_resolve_unref(context->resolve);
108 }
109
110 static int connection_create_pipes(Connection *c, int buffer[2], size_t *sz) {
111 int r;
112
113 assert(c);
114 assert(buffer);
115 assert(sz);
116
117 if (buffer[0] >= 0)
118 return 0;
119
120 r = pipe2(buffer, O_CLOEXEC|O_NONBLOCK);
121 if (r < 0)
122 return log_error_errno(errno, "Failed to allocate pipe buffer: %m");
123
124 (void) fcntl(buffer[0], F_SETPIPE_SZ, BUFFER_SIZE);
125
126 r = fcntl(buffer[0], F_GETPIPE_SZ);
127 if (r < 0)
128 return log_error_errno(errno, "Failed to get pipe buffer size: %m");
129
130 assert(r > 0);
131 *sz = r;
132
133 return 0;
134 }
135
/* Forwards data in one direction: *from -> pipe (buffer) -> *to, using
 * non-blocking splice() calls. It repeats until a full pass makes no
 * progress.
 *
 * *full is the number of bytes currently held in the pipe; *sz is the pipe's
 * capacity. When a splice() on a side reports EOF (returns 0), EPIPE or
 * ECONNRESET, that side's event source is released and its fd is closed via
 * safe_close(), with the result stored back into *from / *to.
 *
 * Returns 0, or the log_error_errno() result for any other splice() error
 * (EAGAIN and EINTR just mean "no progress right now"). */
static int connection_shovel(
                Connection *c,
                int *from, int buffer[2], int *to,
                size_t *full, size_t *sz,
                sd_event_source **from_source, sd_event_source **to_source) {

        bool shoveled;

        assert(c);
        assert(from);
        assert(buffer);
        assert(buffer[0] >= 0);
        assert(buffer[1] >= 0);
        assert(to);
        assert(full);
        assert(sz);
        assert(from_source);
        assert(to_source);

        do {
                ssize_t z;

                shoveled = false;

                /* Fill the pipe from the source, but only while the pipe has
                 * room and there is still a destination to deliver to. */
                if (*full < *sz && *from >= 0 && *to >= 0) {
                        z = splice(*from, NULL, buffer[1], NULL, *sz - *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
                        if (z > 0) {
                                *full += z;
                                shoveled = true;
                        } else if (z == 0 || errno == EPIPE || errno == ECONNRESET) {
                                /* Source finished or went away: stop watching it and close it. */
                                *from_source = sd_event_source_unref(*from_source);
                                *from = safe_close(*from);
                        } else if (errno != EAGAIN && errno != EINTR)
                                return log_error_errno(errno, "Failed to splice: %m");
                }

                /* Drain whatever the pipe holds into the destination. */
                if (*full > 0 && *to >= 0) {
                        z = splice(buffer[0], NULL, *to, NULL, *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
                        if (z > 0) {
                                *full -= z;
                                shoveled = true;
                        } else if (z == 0 || errno == EPIPE || errno == ECONNRESET) {
                                /* Destination went away: stop watching it and close it. */
                                *to_source = sd_event_source_unref(*to_source);
                                *to = safe_close(*to);
                        } else if (errno != EAGAIN && errno != EINTR)
                                return log_error_errno(errno, "Failed to splice: %m");
                }
        } while (shoveled);

        return 0;
}
187
188 static int connection_enable_event_sources(Connection *c);
189
190 static int traffic_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
191 Connection *c = userdata;
192 int r;
193
194 assert(s);
195 assert(fd >= 0);
196 assert(c);
197
198 r = connection_shovel(c,
199 &c->server_fd, c->server_to_client_buffer, &c->client_fd,
200 &c->server_to_client_buffer_full, &c->server_to_client_buffer_size,
201 &c->server_event_source, &c->client_event_source);
202 if (r < 0)
203 goto quit;
204
205 r = connection_shovel(c,
206 &c->client_fd, c->client_to_server_buffer, &c->server_fd,
207 &c->client_to_server_buffer_full, &c->client_to_server_buffer_size,
208 &c->client_event_source, &c->server_event_source);
209 if (r < 0)
210 goto quit;
211
212 /* EOF on both sides? */
213 if (c->server_fd == -1 && c->client_fd == -1)
214 goto quit;
215
216 /* Server closed, and all data written to client? */
217 if (c->server_fd == -1 && c->server_to_client_buffer_full <= 0)
218 goto quit;
219
220 /* Client closed, and all data written to server? */
221 if (c->client_fd == -1 && c->client_to_server_buffer_full <= 0)
222 goto quit;
223
224 r = connection_enable_event_sources(c);
225 if (r < 0)
226 goto quit;
227
228 return 1;
229
230 quit:
231 connection_free(c);
232 return 0; /* ignore errors, continue serving */
233 }
234
235 static int connection_enable_event_sources(Connection *c) {
236 uint32_t a = 0, b = 0;
237 int r;
238
239 assert(c);
240
241 if (c->server_to_client_buffer_full > 0)
242 b |= EPOLLOUT;
243 if (c->server_to_client_buffer_full < c->server_to_client_buffer_size)
244 a |= EPOLLIN;
245
246 if (c->client_to_server_buffer_full > 0)
247 a |= EPOLLOUT;
248 if (c->client_to_server_buffer_full < c->client_to_server_buffer_size)
249 b |= EPOLLIN;
250
251 if (c->server_event_source)
252 r = sd_event_source_set_io_events(c->server_event_source, a);
253 else if (c->server_fd >= 0)
254 r = sd_event_add_io(c->context->event, &c->server_event_source, c->server_fd, a, traffic_cb, c);
255 else
256 r = 0;
257
258 if (r < 0)
259 return log_error_errno(r, "Failed to set up server event source: %m");
260
261 if (c->client_event_source)
262 r = sd_event_source_set_io_events(c->client_event_source, b);
263 else if (c->client_fd >= 0)
264 r = sd_event_add_io(c->context->event, &c->client_event_source, c->client_fd, b, traffic_cb, c);
265 else
266 r = 0;
267
268 if (r < 0)
269 return log_error_errno(r, "Failed to set up client event source: %m");
270
271 return 0;
272 }
273
274 static int connection_complete(Connection *c) {
275 int r;
276
277 assert(c);
278
279 r = connection_create_pipes(c, c->server_to_client_buffer, &c->server_to_client_buffer_size);
280 if (r < 0)
281 goto fail;
282
283 r = connection_create_pipes(c, c->client_to_server_buffer, &c->client_to_server_buffer_size);
284 if (r < 0)
285 goto fail;
286
287 r = connection_enable_event_sources(c);
288 if (r < 0)
289 goto fail;
290
291 return 0;
292
293 fail:
294 connection_free(c);
295 return 0; /* ignore errors, continue serving */
296 }
297
298 static int connect_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
299 Connection *c = userdata;
300 socklen_t solen;
301 int error, r;
302
303 assert(s);
304 assert(fd >= 0);
305 assert(c);
306
307 solen = sizeof(error);
308 r = getsockopt(fd, SOL_SOCKET, SO_ERROR, &error, &solen);
309 if (r < 0) {
310 log_error_errno(errno, "Failed to issue SO_ERROR: %m");
311 goto fail;
312 }
313
314 if (error != 0) {
315 log_error_errno(error, "Failed to connect to remote host: %m");
316 goto fail;
317 }
318
319 c->client_event_source = sd_event_source_unref(c->client_event_source);
320
321 return connection_complete(c);
322
323 fail:
324 connection_free(c);
325 return 0; /* ignore errors, continue serving */
326 }
327
328 static int connection_start(Connection *c, struct sockaddr *sa, socklen_t salen) {
329 int r;
330
331 assert(c);
332 assert(sa);
333 assert(salen);
334
335 c->client_fd = socket(sa->sa_family, SOCK_STREAM|SOCK_NONBLOCK|SOCK_CLOEXEC, 0);
336 if (c->client_fd < 0) {
337 log_error_errno(errno, "Failed to get remote socket: %m");
338 goto fail;
339 }
340
341 r = connect(c->client_fd, sa, salen);
342 if (r < 0) {
343 if (errno == EINPROGRESS) {
344 r = sd_event_add_io(c->context->event, &c->client_event_source, c->client_fd, EPOLLOUT, connect_cb, c);
345 if (r < 0) {
346 log_error_errno(r, "Failed to add connection socket: %m");
347 goto fail;
348 }
349
350 r = sd_event_source_set_enabled(c->client_event_source, SD_EVENT_ONESHOT);
351 if (r < 0) {
352 log_error_errno(r, "Failed to enable oneshot event source: %m");
353 goto fail;
354 }
355 } else {
356 log_error_errno(errno, "Failed to connect to remote host: %m");
357 goto fail;
358 }
359 } else {
360 r = connection_complete(c);
361 if (r < 0)
362 goto fail;
363 }
364
365 return 0;
366
367 fail:
368 connection_free(c);
369 return 0; /* ignore errors, continue serving */
370 }
371
372 static int resolve_cb(sd_resolve_query *q, int ret, const struct addrinfo *ai, void *userdata) {
373 Connection *c = userdata;
374
375 assert(q);
376 assert(c);
377
378 if (ret != 0) {
379 log_error("Failed to resolve host: %s", gai_strerror(ret));
380 goto fail;
381 }
382
383 c->resolve_query = sd_resolve_query_unref(c->resolve_query);
384
385 return connection_start(c, ai->ai_addr, ai->ai_addrlen);
386
387 fail:
388 connection_free(c);
389 return 0; /* ignore errors, continue serving */
390 }
391
392 static int resolve_remote(Connection *c) {
393
394 static const struct addrinfo hints = {
395 .ai_family = AF_UNSPEC,
396 .ai_socktype = SOCK_STREAM,
397 .ai_flags = AI_ADDRCONFIG
398 };
399
400 union sockaddr_union sa = {};
401 const char *node, *service;
402 socklen_t salen;
403 int r;
404
405 if (path_is_absolute(arg_remote_host)) {
406 sa.un.sun_family = AF_UNIX;
407 strncpy(sa.un.sun_path, arg_remote_host, sizeof(sa.un.sun_path)-1);
408 sa.un.sun_path[sizeof(sa.un.sun_path)-1] = 0;
409
410 salen = offsetof(union sockaddr_union, un.sun_path) + strlen(sa.un.sun_path);
411
412 return connection_start(c, &sa.sa, salen);
413 }
414
415 if (arg_remote_host[0] == '@') {
416 sa.un.sun_family = AF_UNIX;
417 sa.un.sun_path[0] = 0;
418 strncpy(sa.un.sun_path+1, arg_remote_host+1, sizeof(sa.un.sun_path)-2);
419 sa.un.sun_path[sizeof(sa.un.sun_path)-1] = 0;
420
421 salen = offsetof(union sockaddr_union, un.sun_path) + 1 + strlen(sa.un.sun_path + 1);
422
423 return connection_start(c, &sa.sa, salen);
424 }
425
426 service = strrchr(arg_remote_host, ':');
427 if (service) {
428 node = strndupa(arg_remote_host, service - arg_remote_host);
429 service ++;
430 } else {
431 node = arg_remote_host;
432 service = "80";
433 }
434
435 log_debug("Looking up address info for %s:%s", node, service);
436 r = sd_resolve_getaddrinfo(c->context->resolve, &c->resolve_query, node, service, &hints, resolve_cb, c);
437 if (r < 0) {
438 log_error_errno(r, "Failed to resolve remote host: %m");
439 goto fail;
440 }
441
442 return 0;
443
444 fail:
445 connection_free(c);
446 return 0; /* ignore errors, continue serving */
447 }
448
449 static int add_connection_socket(Context *context, int fd) {
450 Connection *c;
451 int r;
452
453 assert(context);
454 assert(fd >= 0);
455
456 if (set_size(context->connections) > CONNECTIONS_MAX) {
457 log_warning("Hit connection limit, refusing connection.");
458 safe_close(fd);
459 return 0;
460 }
461
462 r = set_ensure_allocated(&context->connections, NULL);
463 if (r < 0) {
464 log_oom();
465 return 0;
466 }
467
468 c = new0(Connection, 1);
469 if (!c) {
470 log_oom();
471 return 0;
472 }
473
474 c->context = context;
475 c->server_fd = fd;
476 c->client_fd = -1;
477 c->server_to_client_buffer[0] = c->server_to_client_buffer[1] = -1;
478 c->client_to_server_buffer[0] = c->client_to_server_buffer[1] = -1;
479
480 r = set_put(context->connections, c);
481 if (r < 0) {
482 free(c);
483 log_oom();
484 return 0;
485 }
486
487 return resolve_remote(c);
488 }
489
490 static int accept_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
491 _cleanup_free_ char *peer = NULL;
492 Context *context = userdata;
493 int nfd = -1, r;
494
495 assert(s);
496 assert(fd >= 0);
497 assert(revents & EPOLLIN);
498 assert(context);
499
500 nfd = accept4(fd, NULL, NULL, SOCK_NONBLOCK|SOCK_CLOEXEC);
501 if (nfd < 0) {
502 if (errno != -EAGAIN)
503 log_warning_errno(errno, "Failed to accept() socket: %m");
504 } else {
505 getpeername_pretty(nfd, &peer);
506 log_debug("New connection from %s", strna(peer));
507
508 r = add_connection_socket(context, nfd);
509 if (r < 0) {
510 log_error_errno(r, "Failed to accept connection, ignoring: %m");
511 safe_close(fd);
512 }
513 }
514
515 r = sd_event_source_set_enabled(s, SD_EVENT_ONESHOT);
516 if (r < 0) {
517 log_error_errno(r, "Error while re-enabling listener with ONESHOT: %m");
518 sd_event_exit(context->event, r);
519 return r;
520 }
521
522 return 1;
523 }
524
525 static int add_listen_socket(Context *context, int fd) {
526 sd_event_source *source;
527 int r;
528
529 assert(context);
530 assert(fd >= 0);
531
532 r = set_ensure_allocated(&context->listen, NULL);
533 if (r < 0) {
534 log_oom();
535 return r;
536 }
537
538 r = sd_is_socket(fd, 0, SOCK_STREAM, 1);
539 if (r < 0)
540 return log_error_errno(r, "Failed to determine socket type: %m");
541 if (r == 0) {
542 log_error("Passed in socket is not a stream socket.");
543 return -EINVAL;
544 }
545
546 r = fd_nonblock(fd, true);
547 if (r < 0)
548 return log_error_errno(r, "Failed to mark file descriptor non-blocking: %m");
549
550 r = sd_event_add_io(context->event, &source, fd, EPOLLIN, accept_cb, context);
551 if (r < 0)
552 return log_error_errno(r, "Failed to add event source: %m");
553
554 r = set_put(context->listen, source);
555 if (r < 0) {
556 log_error_errno(r, "Failed to add source to set: %m");
557 sd_event_source_unref(source);
558 return r;
559 }
560
561 /* Set the watcher to oneshot in case other processes are also
562 * watching to accept(). */
563 r = sd_event_source_set_enabled(source, SD_EVENT_ONESHOT);
564 if (r < 0)
565 return log_error_errno(r, "Failed to enable oneshot mode: %m");
566
567 return 0;
568 }
569
570 static void help(void) {
571 printf("%1$s [HOST:PORT]\n"
572 "%1$s [SOCKET]\n\n"
573 "Bidirectionally proxy local sockets to another (possibly remote) socket.\n\n"
574 " -h --help Show this help\n"
575 " --version Show package version\n",
576 program_invocation_short_name);
577 }
578
579 static int parse_argv(int argc, char *argv[]) {
580
581 enum {
582 ARG_VERSION = 0x100,
583 ARG_IGNORE_ENV
584 };
585
586 static const struct option options[] = {
587 { "help", no_argument, NULL, 'h' },
588 { "version", no_argument, NULL, ARG_VERSION },
589 {}
590 };
591
592 int c;
593
594 assert(argc >= 0);
595 assert(argv);
596
597 while ((c = getopt_long(argc, argv, "h", options, NULL)) >= 0)
598
599 switch (c) {
600
601 case 'h':
602 help();
603 return 0;
604
605 case ARG_VERSION:
606 return version();
607
608 case '?':
609 return -EINVAL;
610
611 default:
612 assert_not_reached("Unhandled option");
613 }
614
615 if (optind >= argc) {
616 log_error("Not enough parameters.");
617 return -EINVAL;
618 }
619
620 if (argc != optind+1) {
621 log_error("Too many parameters.");
622 return -EINVAL;
623 }
624
625 arg_remote_host = argv[optind];
626 return 1;
627 }
628
629 int main(int argc, char *argv[]) {
630 Context context = {};
631 int r, n, fd;
632
633 log_parse_environment();
634 log_open();
635
636 r = parse_argv(argc, argv);
637 if (r <= 0)
638 goto finish;
639
640 r = sd_event_default(&context.event);
641 if (r < 0) {
642 log_error_errno(r, "Failed to allocate event loop: %m");
643 goto finish;
644 }
645
646 r = sd_resolve_default(&context.resolve);
647 if (r < 0) {
648 log_error_errno(r, "Failed to allocate resolver: %m");
649 goto finish;
650 }
651
652 r = sd_resolve_attach_event(context.resolve, context.event, 0);
653 if (r < 0) {
654 log_error_errno(r, "Failed to attach resolver: %m");
655 goto finish;
656 }
657
658 sd_event_set_watchdog(context.event, true);
659
660 n = sd_listen_fds(1);
661 if (n < 0) {
662 log_error("Failed to receive sockets from parent.");
663 r = n;
664 goto finish;
665 } else if (n == 0) {
666 log_error("Didn't get any sockets passed in.");
667 r = -EINVAL;
668 goto finish;
669 }
670
671 for (fd = SD_LISTEN_FDS_START; fd < SD_LISTEN_FDS_START + n; fd++) {
672 r = add_listen_socket(&context, fd);
673 if (r < 0)
674 goto finish;
675 }
676
677 r = sd_event_loop(context.event);
678 if (r < 0) {
679 log_error_errno(r, "Failed to run event loop: %m");
680 goto finish;
681 }
682
683 finish:
684 context_free(&context);
685
686 return r < 0 ? EXIT_FAILURE : EXIT_SUCCESS;
687 }