git.ipfire.org Git - thirdparty/systemd.git/blob - src/socket-proxy/socket-proxyd.c
Merge pull request #1668 from ssahani/net1
[thirdparty/systemd.git] / src / socket-proxy / socket-proxyd.c
1 /*-*- Mode: C; c-basic-offset: 8; indent-tabs-mode: nil -*-*/
2
3 /***
4 This file is part of systemd.
5
6 Copyright 2013 David Strauss
7
8 systemd is free software; you can redistribute it and/or modify it
9 under the terms of the GNU Lesser General Public License as published by
10 the Free Software Foundation; either version 2.1 of the License, or
11 (at your option) any later version.
12
13 systemd is distributed in the hope that it will be useful, but
14 WITHOUT ANY WARRANTY; without even the implied warranty of
15 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16 Lesser General Public License for more details.
17
18 You should have received a copy of the GNU Lesser General Public License
19 along with systemd; If not, see <http://www.gnu.org/licenses/>.
20 ***/
21
22 #include <errno.h>
23 #include <fcntl.h>
24 #include <getopt.h>
25 #include <netdb.h>
26 #include <stdio.h>
27 #include <stdlib.h>
28 #include <string.h>
29 #include <sys/socket.h>
30 #include <sys/un.h>
31 #include <unistd.h>
32
33 #include "sd-daemon.h"
34 #include "sd-event.h"
35 #include "sd-resolve.h"
36
37 #include "fd-util.h"
38 #include "log.h"
39 #include "path-util.h"
40 #include "set.h"
41 #include "socket-util.h"
42 #include "string-util.h"
43 #include "util.h"
44
/* Pipe capacity requested for each direction via F_SETPIPE_SZ in
 * connection_create_pipes(); the size actually granted is queried afterwards. */
45 #define BUFFER_SIZE (256 * 1024)
/* Connection limit checked by add_connection_socket() before a new
 * connection is set up. */
46 #define CONNECTIONS_MAX 256
47
/* The single positional argument (set in parse_argv()): an absolute AF_UNIX
 * socket path, an abstract socket name prefixed with '@', or HOST[:PORT]
 * with the port defaulting to "80" (see resolve_remote()). */
48 static const char *arg_remote_host = NULL;
49
/* Proxy-wide state; main() keeps exactly one instance on its stack. */
50 typedef struct Context {
/* Event loop and the resolver attached to it in main(). */
51 sd_event *event;
52 sd_resolve *resolve;
53
/* Event sources of the listening sockets (filled by add_listen_socket()). */
54 Set *listen;
/* All live Connection objects; connection_free() removes an entry,
 * context_free() frees whatever is left. */
55 Set *connections;
56 } Context;
57
/* One proxied connection. Note the naming: the "server" side is the socket
 * accepted from a listener, the "client" side is the outgoing socket this
 * proxy opens towards the remote host. */
58 typedef struct Connection {
/* Owning context; connection_free() removes the connection from context->connections. */
59 Context *context;
60
/* server_fd: accepted socket (add_connection_socket()); client_fd: outgoing
 * non-blocking socket (connection_start()). Each is -1 when not open. */
61 int server_fd, client_fd;
62 int server_to_client_buffer[2]; /* pipe used as splice() buffer: [0] read end, [1] write end; -1 until created */
63 int client_to_server_buffer[2]; /* pipe used as splice() buffer: [0] read end, [1] write end; -1 until created */
64
/* *_full: bytes currently sitting in the pipe; *_size: pipe capacity as
 * reported by F_GETPIPE_SZ in connection_create_pipes(). */
65 size_t server_to_client_buffer_full, client_to_server_buffer_full;
66 size_t server_to_client_buffer_size, client_to_server_buffer_size;
67
/* I/O watchers for server_fd / client_fd. client_event_source is first used
 * for the one-shot connect watcher (connect_cb()), then for traffic. */
68 sd_event_source *server_event_source, *client_event_source;
69
/* Pending asynchronous getaddrinfo() query started by resolve_remote(), if any. */
70 sd_resolve_query *resolve_query;
71 } Connection;
72
73 static void connection_free(Connection *c) {
74 assert(c);
75
76 if (c->context)
77 set_remove(c->context->connections, c);
78
79 sd_event_source_unref(c->server_event_source);
80 sd_event_source_unref(c->client_event_source);
81
82 safe_close(c->server_fd);
83 safe_close(c->client_fd);
84
85 safe_close_pair(c->server_to_client_buffer);
86 safe_close_pair(c->client_to_server_buffer);
87
88 sd_resolve_query_unref(c->resolve_query);
89
90 free(c);
91 }
92
93 static void context_free(Context *context) {
94 sd_event_source *es;
95 Connection *c;
96
97 assert(context);
98
99 while ((es = set_steal_first(context->listen)))
100 sd_event_source_unref(es);
101
102 while ((c = set_first(context->connections)))
103 connection_free(c);
104
105 set_free(context->listen);
106 set_free(context->connections);
107
108 sd_event_unref(context->event);
109 sd_resolve_unref(context->resolve);
110 }
111
112 static int connection_create_pipes(Connection *c, int buffer[2], size_t *sz) {
113 int r;
114
115 assert(c);
116 assert(buffer);
117 assert(sz);
118
119 if (buffer[0] >= 0)
120 return 0;
121
122 r = pipe2(buffer, O_CLOEXEC|O_NONBLOCK);
123 if (r < 0)
124 return log_error_errno(errno, "Failed to allocate pipe buffer: %m");
125
126 (void) fcntl(buffer[0], F_SETPIPE_SZ, BUFFER_SIZE);
127
128 r = fcntl(buffer[0], F_GETPIPE_SZ);
129 if (r < 0)
130 return log_error_errno(errno, "Failed to get pipe buffer size: %m");
131
132 assert(r > 0);
133 *sz = r;
134
135 return 0;
136 }
137
138 static int connection_shovel(
139 Connection *c,
140 int *from, int buffer[2], int *to,
141 size_t *full, size_t *sz,
142 sd_event_source **from_source, sd_event_source **to_source) {
143
144 bool shoveled;
145
146 assert(c);
147 assert(from);
148 assert(buffer);
149 assert(buffer[0] >= 0);
150 assert(buffer[1] >= 0);
151 assert(to);
152 assert(full);
153 assert(sz);
154 assert(from_source);
155 assert(to_source);
156
157 do {
158 ssize_t z;
159
160 shoveled = false;
161
162 if (*full < *sz && *from >= 0 && *to >= 0) {
163 z = splice(*from, NULL, buffer[1], NULL, *sz - *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
164 if (z > 0) {
165 *full += z;
166 shoveled = true;
167 } else if (z == 0 || errno == EPIPE || errno == ECONNRESET) {
168 *from_source = sd_event_source_unref(*from_source);
169 *from = safe_close(*from);
170 } else if (errno != EAGAIN && errno != EINTR)
171 return log_error_errno(errno, "Failed to splice: %m");
172 }
173
174 if (*full > 0 && *to >= 0) {
175 z = splice(buffer[0], NULL, *to, NULL, *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
176 if (z > 0) {
177 *full -= z;
178 shoveled = true;
179 } else if (z == 0 || errno == EPIPE || errno == ECONNRESET) {
180 *to_source = sd_event_source_unref(*to_source);
181 *to = safe_close(*to);
182 } else if (errno != EAGAIN && errno != EINTR)
183 return log_error_errno(errno, "Failed to splice: %m");
184 }
185 } while (shoveled);
186
187 return 0;
188 }
189
190 static int connection_enable_event_sources(Connection *c);
191
192 static int traffic_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
193 Connection *c = userdata;
194 int r;
195
196 assert(s);
197 assert(fd >= 0);
198 assert(c);
199
200 r = connection_shovel(c,
201 &c->server_fd, c->server_to_client_buffer, &c->client_fd,
202 &c->server_to_client_buffer_full, &c->server_to_client_buffer_size,
203 &c->server_event_source, &c->client_event_source);
204 if (r < 0)
205 goto quit;
206
207 r = connection_shovel(c,
208 &c->client_fd, c->client_to_server_buffer, &c->server_fd,
209 &c->client_to_server_buffer_full, &c->client_to_server_buffer_size,
210 &c->client_event_source, &c->server_event_source);
211 if (r < 0)
212 goto quit;
213
214 /* EOF on both sides? */
215 if (c->server_fd == -1 && c->client_fd == -1)
216 goto quit;
217
218 /* Server closed, and all data written to client? */
219 if (c->server_fd == -1 && c->server_to_client_buffer_full <= 0)
220 goto quit;
221
222 /* Client closed, and all data written to server? */
223 if (c->client_fd == -1 && c->client_to_server_buffer_full <= 0)
224 goto quit;
225
226 r = connection_enable_event_sources(c);
227 if (r < 0)
228 goto quit;
229
230 return 1;
231
232 quit:
233 connection_free(c);
234 return 0; /* ignore errors, continue serving */
235 }
236
237 static int connection_enable_event_sources(Connection *c) {
238 uint32_t a = 0, b = 0;
239 int r;
240
241 assert(c);
242
243 if (c->server_to_client_buffer_full > 0)
244 b |= EPOLLOUT;
245 if (c->server_to_client_buffer_full < c->server_to_client_buffer_size)
246 a |= EPOLLIN;
247
248 if (c->client_to_server_buffer_full > 0)
249 a |= EPOLLOUT;
250 if (c->client_to_server_buffer_full < c->client_to_server_buffer_size)
251 b |= EPOLLIN;
252
253 if (c->server_event_source)
254 r = sd_event_source_set_io_events(c->server_event_source, a);
255 else if (c->server_fd >= 0)
256 r = sd_event_add_io(c->context->event, &c->server_event_source, c->server_fd, a, traffic_cb, c);
257 else
258 r = 0;
259
260 if (r < 0)
261 return log_error_errno(r, "Failed to set up server event source: %m");
262
263 if (c->client_event_source)
264 r = sd_event_source_set_io_events(c->client_event_source, b);
265 else if (c->client_fd >= 0)
266 r = sd_event_add_io(c->context->event, &c->client_event_source, c->client_fd, b, traffic_cb, c);
267 else
268 r = 0;
269
270 if (r < 0)
271 return log_error_errno(r, "Failed to set up client event source: %m");
272
273 return 0;
274 }
275
276 static int connection_complete(Connection *c) {
277 int r;
278
279 assert(c);
280
281 r = connection_create_pipes(c, c->server_to_client_buffer, &c->server_to_client_buffer_size);
282 if (r < 0)
283 goto fail;
284
285 r = connection_create_pipes(c, c->client_to_server_buffer, &c->client_to_server_buffer_size);
286 if (r < 0)
287 goto fail;
288
289 r = connection_enable_event_sources(c);
290 if (r < 0)
291 goto fail;
292
293 return 0;
294
295 fail:
296 connection_free(c);
297 return 0; /* ignore errors, continue serving */
298 }
299
300 static int connect_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
301 Connection *c = userdata;
302 socklen_t solen;
303 int error, r;
304
305 assert(s);
306 assert(fd >= 0);
307 assert(c);
308
309 solen = sizeof(error);
310 r = getsockopt(fd, SOL_SOCKET, SO_ERROR, &error, &solen);
311 if (r < 0) {
312 log_error_errno(errno, "Failed to issue SO_ERROR: %m");
313 goto fail;
314 }
315
316 if (error != 0) {
317 log_error_errno(error, "Failed to connect to remote host: %m");
318 goto fail;
319 }
320
321 c->client_event_source = sd_event_source_unref(c->client_event_source);
322
323 return connection_complete(c);
324
325 fail:
326 connection_free(c);
327 return 0; /* ignore errors, continue serving */
328 }
329
330 static int connection_start(Connection *c, struct sockaddr *sa, socklen_t salen) {
331 int r;
332
333 assert(c);
334 assert(sa);
335 assert(salen);
336
337 c->client_fd = socket(sa->sa_family, SOCK_STREAM|SOCK_NONBLOCK|SOCK_CLOEXEC, 0);
338 if (c->client_fd < 0) {
339 log_error_errno(errno, "Failed to get remote socket: %m");
340 goto fail;
341 }
342
343 r = connect(c->client_fd, sa, salen);
344 if (r < 0) {
345 if (errno == EINPROGRESS) {
346 r = sd_event_add_io(c->context->event, &c->client_event_source, c->client_fd, EPOLLOUT, connect_cb, c);
347 if (r < 0) {
348 log_error_errno(r, "Failed to add connection socket: %m");
349 goto fail;
350 }
351
352 r = sd_event_source_set_enabled(c->client_event_source, SD_EVENT_ONESHOT);
353 if (r < 0) {
354 log_error_errno(r, "Failed to enable oneshot event source: %m");
355 goto fail;
356 }
357 } else {
358 log_error_errno(errno, "Failed to connect to remote host: %m");
359 goto fail;
360 }
361 } else {
362 r = connection_complete(c);
363 if (r < 0)
364 goto fail;
365 }
366
367 return 0;
368
369 fail:
370 connection_free(c);
371 return 0; /* ignore errors, continue serving */
372 }
373
374 static int resolve_cb(sd_resolve_query *q, int ret, const struct addrinfo *ai, void *userdata) {
375 Connection *c = userdata;
376
377 assert(q);
378 assert(c);
379
380 if (ret != 0) {
381 log_error("Failed to resolve host: %s", gai_strerror(ret));
382 goto fail;
383 }
384
385 c->resolve_query = sd_resolve_query_unref(c->resolve_query);
386
387 return connection_start(c, ai->ai_addr, ai->ai_addrlen);
388
389 fail:
390 connection_free(c);
391 return 0; /* ignore errors, continue serving */
392 }
393
394 static int resolve_remote(Connection *c) {
395
396 static const struct addrinfo hints = {
397 .ai_family = AF_UNSPEC,
398 .ai_socktype = SOCK_STREAM,
399 .ai_flags = AI_ADDRCONFIG
400 };
401
402 union sockaddr_union sa = {};
403 const char *node, *service;
404 socklen_t salen;
405 int r;
406
407 if (path_is_absolute(arg_remote_host)) {
408 sa.un.sun_family = AF_UNIX;
409 strncpy(sa.un.sun_path, arg_remote_host, sizeof(sa.un.sun_path)-1);
410 sa.un.sun_path[sizeof(sa.un.sun_path)-1] = 0;
411
412 salen = offsetof(union sockaddr_union, un.sun_path) + strlen(sa.un.sun_path);
413
414 return connection_start(c, &sa.sa, salen);
415 }
416
417 if (arg_remote_host[0] == '@') {
418 sa.un.sun_family = AF_UNIX;
419 sa.un.sun_path[0] = 0;
420 strncpy(sa.un.sun_path+1, arg_remote_host+1, sizeof(sa.un.sun_path)-2);
421 sa.un.sun_path[sizeof(sa.un.sun_path)-1] = 0;
422
423 salen = offsetof(union sockaddr_union, un.sun_path) + 1 + strlen(sa.un.sun_path + 1);
424
425 return connection_start(c, &sa.sa, salen);
426 }
427
428 service = strrchr(arg_remote_host, ':');
429 if (service) {
430 node = strndupa(arg_remote_host, service - arg_remote_host);
431 service ++;
432 } else {
433 node = arg_remote_host;
434 service = "80";
435 }
436
437 log_debug("Looking up address info for %s:%s", node, service);
438 r = sd_resolve_getaddrinfo(c->context->resolve, &c->resolve_query, node, service, &hints, resolve_cb, c);
439 if (r < 0) {
440 log_error_errno(r, "Failed to resolve remote host: %m");
441 goto fail;
442 }
443
444 return 0;
445
446 fail:
447 connection_free(c);
448 return 0; /* ignore errors, continue serving */
449 }
450
451 static int add_connection_socket(Context *context, int fd) {
452 Connection *c;
453 int r;
454
455 assert(context);
456 assert(fd >= 0);
457
458 if (set_size(context->connections) > CONNECTIONS_MAX) {
459 log_warning("Hit connection limit, refusing connection.");
460 safe_close(fd);
461 return 0;
462 }
463
464 r = set_ensure_allocated(&context->connections, NULL);
465 if (r < 0) {
466 log_oom();
467 return 0;
468 }
469
470 c = new0(Connection, 1);
471 if (!c) {
472 log_oom();
473 return 0;
474 }
475
476 c->context = context;
477 c->server_fd = fd;
478 c->client_fd = -1;
479 c->server_to_client_buffer[0] = c->server_to_client_buffer[1] = -1;
480 c->client_to_server_buffer[0] = c->client_to_server_buffer[1] = -1;
481
482 r = set_put(context->connections, c);
483 if (r < 0) {
484 free(c);
485 log_oom();
486 return 0;
487 }
488
489 return resolve_remote(c);
490 }
491
492 static int accept_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
493 _cleanup_free_ char *peer = NULL;
494 Context *context = userdata;
495 int nfd = -1, r;
496
497 assert(s);
498 assert(fd >= 0);
499 assert(revents & EPOLLIN);
500 assert(context);
501
502 nfd = accept4(fd, NULL, NULL, SOCK_NONBLOCK|SOCK_CLOEXEC);
503 if (nfd < 0) {
504 if (errno != -EAGAIN)
505 log_warning_errno(errno, "Failed to accept() socket: %m");
506 } else {
507 getpeername_pretty(nfd, &peer);
508 log_debug("New connection from %s", strna(peer));
509
510 r = add_connection_socket(context, nfd);
511 if (r < 0) {
512 log_error_errno(r, "Failed to accept connection, ignoring: %m");
513 safe_close(fd);
514 }
515 }
516
517 r = sd_event_source_set_enabled(s, SD_EVENT_ONESHOT);
518 if (r < 0) {
519 log_error_errno(r, "Error while re-enabling listener with ONESHOT: %m");
520 sd_event_exit(context->event, r);
521 return r;
522 }
523
524 return 1;
525 }
526
527 static int add_listen_socket(Context *context, int fd) {
528 sd_event_source *source;
529 int r;
530
531 assert(context);
532 assert(fd >= 0);
533
534 r = set_ensure_allocated(&context->listen, NULL);
535 if (r < 0) {
536 log_oom();
537 return r;
538 }
539
540 r = sd_is_socket(fd, 0, SOCK_STREAM, 1);
541 if (r < 0)
542 return log_error_errno(r, "Failed to determine socket type: %m");
543 if (r == 0) {
544 log_error("Passed in socket is not a stream socket.");
545 return -EINVAL;
546 }
547
548 r = fd_nonblock(fd, true);
549 if (r < 0)
550 return log_error_errno(r, "Failed to mark file descriptor non-blocking: %m");
551
552 r = sd_event_add_io(context->event, &source, fd, EPOLLIN, accept_cb, context);
553 if (r < 0)
554 return log_error_errno(r, "Failed to add event source: %m");
555
556 r = set_put(context->listen, source);
557 if (r < 0) {
558 log_error_errno(r, "Failed to add source to set: %m");
559 sd_event_source_unref(source);
560 return r;
561 }
562
563 /* Set the watcher to oneshot in case other processes are also
564 * watching to accept(). */
565 r = sd_event_source_set_enabled(source, SD_EVENT_ONESHOT);
566 if (r < 0)
567 return log_error_errno(r, "Failed to enable oneshot mode: %m");
568
569 return 0;
570 }
571
572 static void help(void) {
573 printf("%1$s [HOST:PORT]\n"
574 "%1$s [SOCKET]\n\n"
575 "Bidirectionally proxy local sockets to another (possibly remote) socket.\n\n"
576 " -h --help Show this help\n"
577 " --version Show package version\n",
578 program_invocation_short_name);
579 }
580
581 static int parse_argv(int argc, char *argv[]) {
582
583 enum {
584 ARG_VERSION = 0x100,
585 ARG_IGNORE_ENV
586 };
587
588 static const struct option options[] = {
589 { "help", no_argument, NULL, 'h' },
590 { "version", no_argument, NULL, ARG_VERSION },
591 {}
592 };
593
594 int c;
595
596 assert(argc >= 0);
597 assert(argv);
598
599 while ((c = getopt_long(argc, argv, "h", options, NULL)) >= 0)
600
601 switch (c) {
602
603 case 'h':
604 help();
605 return 0;
606
607 case ARG_VERSION:
608 return version();
609
610 case '?':
611 return -EINVAL;
612
613 default:
614 assert_not_reached("Unhandled option");
615 }
616
617 if (optind >= argc) {
618 log_error("Not enough parameters.");
619 return -EINVAL;
620 }
621
622 if (argc != optind+1) {
623 log_error("Too many parameters.");
624 return -EINVAL;
625 }
626
627 arg_remote_host = argv[optind];
628 return 1;
629 }
630
631 int main(int argc, char *argv[]) {
632 Context context = {};
633 int r, n, fd;
634
635 log_parse_environment();
636 log_open();
637
638 r = parse_argv(argc, argv);
639 if (r <= 0)
640 goto finish;
641
642 r = sd_event_default(&context.event);
643 if (r < 0) {
644 log_error_errno(r, "Failed to allocate event loop: %m");
645 goto finish;
646 }
647
648 r = sd_resolve_default(&context.resolve);
649 if (r < 0) {
650 log_error_errno(r, "Failed to allocate resolver: %m");
651 goto finish;
652 }
653
654 r = sd_resolve_attach_event(context.resolve, context.event, 0);
655 if (r < 0) {
656 log_error_errno(r, "Failed to attach resolver: %m");
657 goto finish;
658 }
659
660 sd_event_set_watchdog(context.event, true);
661
662 n = sd_listen_fds(1);
663 if (n < 0) {
664 log_error("Failed to receive sockets from parent.");
665 r = n;
666 goto finish;
667 } else if (n == 0) {
668 log_error("Didn't get any sockets passed in.");
669 r = -EINVAL;
670 goto finish;
671 }
672
673 for (fd = SD_LISTEN_FDS_START; fd < SD_LISTEN_FDS_START + n; fd++) {
674 r = add_listen_socket(&context, fd);
675 if (r < 0)
676 goto finish;
677 }
678
679 r = sd_event_loop(context.event);
680 if (r < 0) {
681 log_error_errno(r, "Failed to run event loop: %m");
682 goto finish;
683 }
684
685 finish:
686 context_free(&context);
687
688 return r < 0 ? EXIT_FAILURE : EXIT_SUCCESS;
689 }