]> git.ipfire.org Git - thirdparty/systemd.git/blob - src/socket-proxy/socket-proxyd.c
util-lib: split our string related calls from util.[ch] into its own file string...
[thirdparty/systemd.git] / src / socket-proxy / socket-proxyd.c
1 /*-*- Mode: C; c-basic-offset: 8; indent-tabs-mode: nil -*-*/
2
3 /***
4 This file is part of systemd.
5
6 Copyright 2013 David Strauss
7
8 systemd is free software; you can redistribute it and/or modify it
9 under the terms of the GNU Lesser General Public License as published by
10 the Free Software Foundation; either version 2.1 of the License, or
11 (at your option) any later version.
12
13 systemd is distributed in the hope that it will be useful, but
14 WITHOUT ANY WARRANTY; without even the implied warranty of
15 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16 Lesser General Public License for more details.
17
18 You should have received a copy of the GNU Lesser General Public License
19 along with systemd; If not, see <http://www.gnu.org/licenses/>.
20 ***/
21
22 #include <errno.h>
23 #include <fcntl.h>
24 #include <getopt.h>
25 #include <netdb.h>
26 #include <stdio.h>
27 #include <stdlib.h>
28 #include <string.h>
29 #include <sys/socket.h>
30 #include <sys/un.h>
31 #include <unistd.h>
32
33 #include "sd-daemon.h"
34 #include "sd-event.h"
35 #include "sd-resolve.h"
36
37 #include "log.h"
38 #include "path-util.h"
39 #include "set.h"
40 #include "socket-util.h"
41 #include "string-util.h"
42 #include "util.h"
43
/* Tunables of the proxy.
 *  BUFFER_SIZE:     capacity requested (via F_SETPIPE_SZ) for each splice()
 *                   pipe buffer; the size actually granted is queried back.
 *  CONNECTIONS_MAX: bound used when deciding whether to refuse a new
 *                   incoming connection. */
enum {
        BUFFER_SIZE = 256 * 1024,
        CONNECTIONS_MAX = 256,
};

/* Remote endpoint taken from the single positional command line argument. */
static const char *arg_remote_host;
48
49 typedef struct Context {
50 sd_event *event;
51 sd_resolve *resolve;
52
53 Set *listen;
54 Set *connections;
55 } Context;
56
57 typedef struct Connection {
58 Context *context;
59
60 int server_fd, client_fd;
61 int server_to_client_buffer[2]; /* a pipe */
62 int client_to_server_buffer[2]; /* a pipe */
63
64 size_t server_to_client_buffer_full, client_to_server_buffer_full;
65 size_t server_to_client_buffer_size, client_to_server_buffer_size;
66
67 sd_event_source *server_event_source, *client_event_source;
68
69 sd_resolve_query *resolve_query;
70 } Connection;
71
72 static void connection_free(Connection *c) {
73 assert(c);
74
75 if (c->context)
76 set_remove(c->context->connections, c);
77
78 sd_event_source_unref(c->server_event_source);
79 sd_event_source_unref(c->client_event_source);
80
81 safe_close(c->server_fd);
82 safe_close(c->client_fd);
83
84 safe_close_pair(c->server_to_client_buffer);
85 safe_close_pair(c->client_to_server_buffer);
86
87 sd_resolve_query_unref(c->resolve_query);
88
89 free(c);
90 }
91
92 static void context_free(Context *context) {
93 sd_event_source *es;
94 Connection *c;
95
96 assert(context);
97
98 while ((es = set_steal_first(context->listen)))
99 sd_event_source_unref(es);
100
101 while ((c = set_first(context->connections)))
102 connection_free(c);
103
104 set_free(context->listen);
105 set_free(context->connections);
106
107 sd_event_unref(context->event);
108 sd_resolve_unref(context->resolve);
109 }
110
111 static int connection_create_pipes(Connection *c, int buffer[2], size_t *sz) {
112 int r;
113
114 assert(c);
115 assert(buffer);
116 assert(sz);
117
118 if (buffer[0] >= 0)
119 return 0;
120
121 r = pipe2(buffer, O_CLOEXEC|O_NONBLOCK);
122 if (r < 0)
123 return log_error_errno(errno, "Failed to allocate pipe buffer: %m");
124
125 (void) fcntl(buffer[0], F_SETPIPE_SZ, BUFFER_SIZE);
126
127 r = fcntl(buffer[0], F_GETPIPE_SZ);
128 if (r < 0)
129 return log_error_errno(errno, "Failed to get pipe buffer size: %m");
130
131 assert(r > 0);
132 *sz = r;
133
134 return 0;
135 }
136
137 static int connection_shovel(
138 Connection *c,
139 int *from, int buffer[2], int *to,
140 size_t *full, size_t *sz,
141 sd_event_source **from_source, sd_event_source **to_source) {
142
143 bool shoveled;
144
145 assert(c);
146 assert(from);
147 assert(buffer);
148 assert(buffer[0] >= 0);
149 assert(buffer[1] >= 0);
150 assert(to);
151 assert(full);
152 assert(sz);
153 assert(from_source);
154 assert(to_source);
155
156 do {
157 ssize_t z;
158
159 shoveled = false;
160
161 if (*full < *sz && *from >= 0 && *to >= 0) {
162 z = splice(*from, NULL, buffer[1], NULL, *sz - *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
163 if (z > 0) {
164 *full += z;
165 shoveled = true;
166 } else if (z == 0 || errno == EPIPE || errno == ECONNRESET) {
167 *from_source = sd_event_source_unref(*from_source);
168 *from = safe_close(*from);
169 } else if (errno != EAGAIN && errno != EINTR)
170 return log_error_errno(errno, "Failed to splice: %m");
171 }
172
173 if (*full > 0 && *to >= 0) {
174 z = splice(buffer[0], NULL, *to, NULL, *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
175 if (z > 0) {
176 *full -= z;
177 shoveled = true;
178 } else if (z == 0 || errno == EPIPE || errno == ECONNRESET) {
179 *to_source = sd_event_source_unref(*to_source);
180 *to = safe_close(*to);
181 } else if (errno != EAGAIN && errno != EINTR)
182 return log_error_errno(errno, "Failed to splice: %m");
183 }
184 } while (shoveled);
185
186 return 0;
187 }
188
189 static int connection_enable_event_sources(Connection *c);
190
191 static int traffic_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
192 Connection *c = userdata;
193 int r;
194
195 assert(s);
196 assert(fd >= 0);
197 assert(c);
198
199 r = connection_shovel(c,
200 &c->server_fd, c->server_to_client_buffer, &c->client_fd,
201 &c->server_to_client_buffer_full, &c->server_to_client_buffer_size,
202 &c->server_event_source, &c->client_event_source);
203 if (r < 0)
204 goto quit;
205
206 r = connection_shovel(c,
207 &c->client_fd, c->client_to_server_buffer, &c->server_fd,
208 &c->client_to_server_buffer_full, &c->client_to_server_buffer_size,
209 &c->client_event_source, &c->server_event_source);
210 if (r < 0)
211 goto quit;
212
213 /* EOF on both sides? */
214 if (c->server_fd == -1 && c->client_fd == -1)
215 goto quit;
216
217 /* Server closed, and all data written to client? */
218 if (c->server_fd == -1 && c->server_to_client_buffer_full <= 0)
219 goto quit;
220
221 /* Client closed, and all data written to server? */
222 if (c->client_fd == -1 && c->client_to_server_buffer_full <= 0)
223 goto quit;
224
225 r = connection_enable_event_sources(c);
226 if (r < 0)
227 goto quit;
228
229 return 1;
230
231 quit:
232 connection_free(c);
233 return 0; /* ignore errors, continue serving */
234 }
235
236 static int connection_enable_event_sources(Connection *c) {
237 uint32_t a = 0, b = 0;
238 int r;
239
240 assert(c);
241
242 if (c->server_to_client_buffer_full > 0)
243 b |= EPOLLOUT;
244 if (c->server_to_client_buffer_full < c->server_to_client_buffer_size)
245 a |= EPOLLIN;
246
247 if (c->client_to_server_buffer_full > 0)
248 a |= EPOLLOUT;
249 if (c->client_to_server_buffer_full < c->client_to_server_buffer_size)
250 b |= EPOLLIN;
251
252 if (c->server_event_source)
253 r = sd_event_source_set_io_events(c->server_event_source, a);
254 else if (c->server_fd >= 0)
255 r = sd_event_add_io(c->context->event, &c->server_event_source, c->server_fd, a, traffic_cb, c);
256 else
257 r = 0;
258
259 if (r < 0)
260 return log_error_errno(r, "Failed to set up server event source: %m");
261
262 if (c->client_event_source)
263 r = sd_event_source_set_io_events(c->client_event_source, b);
264 else if (c->client_fd >= 0)
265 r = sd_event_add_io(c->context->event, &c->client_event_source, c->client_fd, b, traffic_cb, c);
266 else
267 r = 0;
268
269 if (r < 0)
270 return log_error_errno(r, "Failed to set up client event source: %m");
271
272 return 0;
273 }
274
275 static int connection_complete(Connection *c) {
276 int r;
277
278 assert(c);
279
280 r = connection_create_pipes(c, c->server_to_client_buffer, &c->server_to_client_buffer_size);
281 if (r < 0)
282 goto fail;
283
284 r = connection_create_pipes(c, c->client_to_server_buffer, &c->client_to_server_buffer_size);
285 if (r < 0)
286 goto fail;
287
288 r = connection_enable_event_sources(c);
289 if (r < 0)
290 goto fail;
291
292 return 0;
293
294 fail:
295 connection_free(c);
296 return 0; /* ignore errors, continue serving */
297 }
298
299 static int connect_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
300 Connection *c = userdata;
301 socklen_t solen;
302 int error, r;
303
304 assert(s);
305 assert(fd >= 0);
306 assert(c);
307
308 solen = sizeof(error);
309 r = getsockopt(fd, SOL_SOCKET, SO_ERROR, &error, &solen);
310 if (r < 0) {
311 log_error_errno(errno, "Failed to issue SO_ERROR: %m");
312 goto fail;
313 }
314
315 if (error != 0) {
316 log_error_errno(error, "Failed to connect to remote host: %m");
317 goto fail;
318 }
319
320 c->client_event_source = sd_event_source_unref(c->client_event_source);
321
322 return connection_complete(c);
323
324 fail:
325 connection_free(c);
326 return 0; /* ignore errors, continue serving */
327 }
328
329 static int connection_start(Connection *c, struct sockaddr *sa, socklen_t salen) {
330 int r;
331
332 assert(c);
333 assert(sa);
334 assert(salen);
335
336 c->client_fd = socket(sa->sa_family, SOCK_STREAM|SOCK_NONBLOCK|SOCK_CLOEXEC, 0);
337 if (c->client_fd < 0) {
338 log_error_errno(errno, "Failed to get remote socket: %m");
339 goto fail;
340 }
341
342 r = connect(c->client_fd, sa, salen);
343 if (r < 0) {
344 if (errno == EINPROGRESS) {
345 r = sd_event_add_io(c->context->event, &c->client_event_source, c->client_fd, EPOLLOUT, connect_cb, c);
346 if (r < 0) {
347 log_error_errno(r, "Failed to add connection socket: %m");
348 goto fail;
349 }
350
351 r = sd_event_source_set_enabled(c->client_event_source, SD_EVENT_ONESHOT);
352 if (r < 0) {
353 log_error_errno(r, "Failed to enable oneshot event source: %m");
354 goto fail;
355 }
356 } else {
357 log_error_errno(errno, "Failed to connect to remote host: %m");
358 goto fail;
359 }
360 } else {
361 r = connection_complete(c);
362 if (r < 0)
363 goto fail;
364 }
365
366 return 0;
367
368 fail:
369 connection_free(c);
370 return 0; /* ignore errors, continue serving */
371 }
372
373 static int resolve_cb(sd_resolve_query *q, int ret, const struct addrinfo *ai, void *userdata) {
374 Connection *c = userdata;
375
376 assert(q);
377 assert(c);
378
379 if (ret != 0) {
380 log_error("Failed to resolve host: %s", gai_strerror(ret));
381 goto fail;
382 }
383
384 c->resolve_query = sd_resolve_query_unref(c->resolve_query);
385
386 return connection_start(c, ai->ai_addr, ai->ai_addrlen);
387
388 fail:
389 connection_free(c);
390 return 0; /* ignore errors, continue serving */
391 }
392
393 static int resolve_remote(Connection *c) {
394
395 static const struct addrinfo hints = {
396 .ai_family = AF_UNSPEC,
397 .ai_socktype = SOCK_STREAM,
398 .ai_flags = AI_ADDRCONFIG
399 };
400
401 union sockaddr_union sa = {};
402 const char *node, *service;
403 socklen_t salen;
404 int r;
405
406 if (path_is_absolute(arg_remote_host)) {
407 sa.un.sun_family = AF_UNIX;
408 strncpy(sa.un.sun_path, arg_remote_host, sizeof(sa.un.sun_path)-1);
409 sa.un.sun_path[sizeof(sa.un.sun_path)-1] = 0;
410
411 salen = offsetof(union sockaddr_union, un.sun_path) + strlen(sa.un.sun_path);
412
413 return connection_start(c, &sa.sa, salen);
414 }
415
416 if (arg_remote_host[0] == '@') {
417 sa.un.sun_family = AF_UNIX;
418 sa.un.sun_path[0] = 0;
419 strncpy(sa.un.sun_path+1, arg_remote_host+1, sizeof(sa.un.sun_path)-2);
420 sa.un.sun_path[sizeof(sa.un.sun_path)-1] = 0;
421
422 salen = offsetof(union sockaddr_union, un.sun_path) + 1 + strlen(sa.un.sun_path + 1);
423
424 return connection_start(c, &sa.sa, salen);
425 }
426
427 service = strrchr(arg_remote_host, ':');
428 if (service) {
429 node = strndupa(arg_remote_host, service - arg_remote_host);
430 service ++;
431 } else {
432 node = arg_remote_host;
433 service = "80";
434 }
435
436 log_debug("Looking up address info for %s:%s", node, service);
437 r = sd_resolve_getaddrinfo(c->context->resolve, &c->resolve_query, node, service, &hints, resolve_cb, c);
438 if (r < 0) {
439 log_error_errno(r, "Failed to resolve remote host: %m");
440 goto fail;
441 }
442
443 return 0;
444
445 fail:
446 connection_free(c);
447 return 0; /* ignore errors, continue serving */
448 }
449
450 static int add_connection_socket(Context *context, int fd) {
451 Connection *c;
452 int r;
453
454 assert(context);
455 assert(fd >= 0);
456
457 if (set_size(context->connections) > CONNECTIONS_MAX) {
458 log_warning("Hit connection limit, refusing connection.");
459 safe_close(fd);
460 return 0;
461 }
462
463 r = set_ensure_allocated(&context->connections, NULL);
464 if (r < 0) {
465 log_oom();
466 return 0;
467 }
468
469 c = new0(Connection, 1);
470 if (!c) {
471 log_oom();
472 return 0;
473 }
474
475 c->context = context;
476 c->server_fd = fd;
477 c->client_fd = -1;
478 c->server_to_client_buffer[0] = c->server_to_client_buffer[1] = -1;
479 c->client_to_server_buffer[0] = c->client_to_server_buffer[1] = -1;
480
481 r = set_put(context->connections, c);
482 if (r < 0) {
483 free(c);
484 log_oom();
485 return 0;
486 }
487
488 return resolve_remote(c);
489 }
490
491 static int accept_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
492 _cleanup_free_ char *peer = NULL;
493 Context *context = userdata;
494 int nfd = -1, r;
495
496 assert(s);
497 assert(fd >= 0);
498 assert(revents & EPOLLIN);
499 assert(context);
500
501 nfd = accept4(fd, NULL, NULL, SOCK_NONBLOCK|SOCK_CLOEXEC);
502 if (nfd < 0) {
503 if (errno != -EAGAIN)
504 log_warning_errno(errno, "Failed to accept() socket: %m");
505 } else {
506 getpeername_pretty(nfd, &peer);
507 log_debug("New connection from %s", strna(peer));
508
509 r = add_connection_socket(context, nfd);
510 if (r < 0) {
511 log_error_errno(r, "Failed to accept connection, ignoring: %m");
512 safe_close(fd);
513 }
514 }
515
516 r = sd_event_source_set_enabled(s, SD_EVENT_ONESHOT);
517 if (r < 0) {
518 log_error_errno(r, "Error while re-enabling listener with ONESHOT: %m");
519 sd_event_exit(context->event, r);
520 return r;
521 }
522
523 return 1;
524 }
525
526 static int add_listen_socket(Context *context, int fd) {
527 sd_event_source *source;
528 int r;
529
530 assert(context);
531 assert(fd >= 0);
532
533 r = set_ensure_allocated(&context->listen, NULL);
534 if (r < 0) {
535 log_oom();
536 return r;
537 }
538
539 r = sd_is_socket(fd, 0, SOCK_STREAM, 1);
540 if (r < 0)
541 return log_error_errno(r, "Failed to determine socket type: %m");
542 if (r == 0) {
543 log_error("Passed in socket is not a stream socket.");
544 return -EINVAL;
545 }
546
547 r = fd_nonblock(fd, true);
548 if (r < 0)
549 return log_error_errno(r, "Failed to mark file descriptor non-blocking: %m");
550
551 r = sd_event_add_io(context->event, &source, fd, EPOLLIN, accept_cb, context);
552 if (r < 0)
553 return log_error_errno(r, "Failed to add event source: %m");
554
555 r = set_put(context->listen, source);
556 if (r < 0) {
557 log_error_errno(r, "Failed to add source to set: %m");
558 sd_event_source_unref(source);
559 return r;
560 }
561
562 /* Set the watcher to oneshot in case other processes are also
563 * watching to accept(). */
564 r = sd_event_source_set_enabled(source, SD_EVENT_ONESHOT);
565 if (r < 0)
566 return log_error_errno(r, "Failed to enable oneshot mode: %m");
567
568 return 0;
569 }
570
571 static void help(void) {
572 printf("%1$s [HOST:PORT]\n"
573 "%1$s [SOCKET]\n\n"
574 "Bidirectionally proxy local sockets to another (possibly remote) socket.\n\n"
575 " -h --help Show this help\n"
576 " --version Show package version\n",
577 program_invocation_short_name);
578 }
579
580 static int parse_argv(int argc, char *argv[]) {
581
582 enum {
583 ARG_VERSION = 0x100,
584 ARG_IGNORE_ENV
585 };
586
587 static const struct option options[] = {
588 { "help", no_argument, NULL, 'h' },
589 { "version", no_argument, NULL, ARG_VERSION },
590 {}
591 };
592
593 int c;
594
595 assert(argc >= 0);
596 assert(argv);
597
598 while ((c = getopt_long(argc, argv, "h", options, NULL)) >= 0)
599
600 switch (c) {
601
602 case 'h':
603 help();
604 return 0;
605
606 case ARG_VERSION:
607 return version();
608
609 case '?':
610 return -EINVAL;
611
612 default:
613 assert_not_reached("Unhandled option");
614 }
615
616 if (optind >= argc) {
617 log_error("Not enough parameters.");
618 return -EINVAL;
619 }
620
621 if (argc != optind+1) {
622 log_error("Too many parameters.");
623 return -EINVAL;
624 }
625
626 arg_remote_host = argv[optind];
627 return 1;
628 }
629
630 int main(int argc, char *argv[]) {
631 Context context = {};
632 int r, n, fd;
633
634 log_parse_environment();
635 log_open();
636
637 r = parse_argv(argc, argv);
638 if (r <= 0)
639 goto finish;
640
641 r = sd_event_default(&context.event);
642 if (r < 0) {
643 log_error_errno(r, "Failed to allocate event loop: %m");
644 goto finish;
645 }
646
647 r = sd_resolve_default(&context.resolve);
648 if (r < 0) {
649 log_error_errno(r, "Failed to allocate resolver: %m");
650 goto finish;
651 }
652
653 r = sd_resolve_attach_event(context.resolve, context.event, 0);
654 if (r < 0) {
655 log_error_errno(r, "Failed to attach resolver: %m");
656 goto finish;
657 }
658
659 sd_event_set_watchdog(context.event, true);
660
661 n = sd_listen_fds(1);
662 if (n < 0) {
663 log_error("Failed to receive sockets from parent.");
664 r = n;
665 goto finish;
666 } else if (n == 0) {
667 log_error("Didn't get any sockets passed in.");
668 r = -EINVAL;
669 goto finish;
670 }
671
672 for (fd = SD_LISTEN_FDS_START; fd < SD_LISTEN_FDS_START + n; fd++) {
673 r = add_listen_socket(&context, fd);
674 if (r < 0)
675 goto finish;
676 }
677
678 r = sd_event_loop(context.event);
679 if (r < 0) {
680 log_error_errno(r, "Failed to run event loop: %m");
681 goto finish;
682 }
683
684 finish:
685 context_free(&context);
686
687 return r < 0 ? EXIT_FAILURE : EXIT_SUCCESS;
688 }