]> git.ipfire.org Git - thirdparty/systemd.git/blob - src/resolve/resolved-dns-stream.c
grypt-util: drop two emacs modelines
[thirdparty/systemd.git] / src / resolve / resolved-dns-stream.c
1 /* SPDX-License-Identifier: LGPL-2.1+ */
2 /***
3 Copyright 2014 Lennart Poettering
4 ***/
5
6 #include <netinet/tcp.h>
7
8 #include "alloc-util.h"
9 #include "fd-util.h"
10 #include "io-util.h"
11 #include "missing.h"
12 #include "resolved-dns-stream.h"
13
/* Deadline, counted from stream creation, after which the stream is completed with ETIMEDOUT
 * (the timer is armed once in dns_stream_new()). */
#define DNS_STREAM_TIMEOUT_USEC (10 * USEC_PER_SEC)

enum {
        /* Upper bound on the number of streams a manager keeps allocated at the same time */
        DNS_STREAMS_MAX = 128,
};

enum {
        /* dns_stream_writev() flag: write raw (already TLS-encrypted) data straight to the socket */
        WRITE_TLS_DATA = 1,
};
18
19 static void dns_stream_stop(DnsStream *s) {
20 assert(s);
21
22 s->io_event_source = sd_event_source_unref(s->io_event_source);
23 s->timeout_event_source = sd_event_source_unref(s->timeout_event_source);
24 s->fd = safe_close(s->fd);
25 }
26
27 static int dns_stream_update_io(DnsStream *s) {
28 int f = 0;
29
30 assert(s);
31
32 if (s->write_packet && s->n_written < sizeof(s->write_size) + s->write_packet->size)
33 f |= EPOLLOUT;
34 else if (!ordered_set_isempty(s->write_queue)) {
35 dns_packet_unref(s->write_packet);
36 s->write_packet = ordered_set_steal_first(s->write_queue);
37 s->write_size = htobe16(s->write_packet->size);
38 s->n_written = 0;
39 f |= EPOLLOUT;
40 }
41 if (!s->read_packet || s->n_read < sizeof(s->read_size) + s->read_packet->size)
42 f |= EPOLLIN;
43
44 return sd_event_source_set_io_events(s->io_event_source, f);
45 }
46
47 static int dns_stream_complete(DnsStream *s, int error) {
48 assert(s);
49
50 #if HAVE_GNUTLS
51 if (s->tls_session && IN_SET(error, ETIMEDOUT, 0)) {
52 int r;
53
54 r = gnutls_bye(s->tls_session, GNUTLS_SHUT_RDWR);
55 if (r == GNUTLS_E_AGAIN && !s->tls_bye) {
56 dns_stream_ref(s); /* keep reference for closing TLS session */
57 s->tls_bye = true;
58 } else
59 dns_stream_stop(s);
60 } else
61 #endif
62 dns_stream_stop(s);
63
64 if (s->complete)
65 s->complete(s, error);
66 else /* the default action if no completion function is set is to close the stream */
67 dns_stream_unref(s);
68
69 return 0;
70 }
71
72 static int dns_stream_identify(DnsStream *s) {
73 union {
74 struct cmsghdr header; /* For alignment */
75 uint8_t buffer[CMSG_SPACE(MAXSIZE(struct in_pktinfo, struct in6_pktinfo))
76 + EXTRA_CMSG_SPACE /* kernel appears to require extra space */];
77 } control;
78 struct msghdr mh = {};
79 struct cmsghdr *cmsg;
80 socklen_t sl;
81 int r;
82
83 assert(s);
84
85 if (s->identified)
86 return 0;
87
88 /* Query the local side */
89 s->local_salen = sizeof(s->local);
90 r = getsockname(s->fd, &s->local.sa, &s->local_salen);
91 if (r < 0)
92 return -errno;
93 if (s->local.sa.sa_family == AF_INET6 && s->ifindex <= 0)
94 s->ifindex = s->local.in6.sin6_scope_id;
95
96 /* Query the remote side */
97 s->peer_salen = sizeof(s->peer);
98 r = getpeername(s->fd, &s->peer.sa, &s->peer_salen);
99 if (r < 0)
100 return -errno;
101 if (s->peer.sa.sa_family == AF_INET6 && s->ifindex <= 0)
102 s->ifindex = s->peer.in6.sin6_scope_id;
103
104 /* Check consistency */
105 assert(s->peer.sa.sa_family == s->local.sa.sa_family);
106 assert(IN_SET(s->peer.sa.sa_family, AF_INET, AF_INET6));
107
108 /* Query connection meta information */
109 sl = sizeof(control);
110 if (s->peer.sa.sa_family == AF_INET) {
111 r = getsockopt(s->fd, IPPROTO_IP, IP_PKTOPTIONS, &control, &sl);
112 if (r < 0)
113 return -errno;
114 } else if (s->peer.sa.sa_family == AF_INET6) {
115
116 r = getsockopt(s->fd, IPPROTO_IPV6, IPV6_2292PKTOPTIONS, &control, &sl);
117 if (r < 0)
118 return -errno;
119 } else
120 return -EAFNOSUPPORT;
121
122 mh.msg_control = &control;
123 mh.msg_controllen = sl;
124
125 CMSG_FOREACH(cmsg, &mh) {
126
127 if (cmsg->cmsg_level == IPPROTO_IPV6) {
128 assert(s->peer.sa.sa_family == AF_INET6);
129
130 switch (cmsg->cmsg_type) {
131
132 case IPV6_PKTINFO: {
133 struct in6_pktinfo *i = (struct in6_pktinfo*) CMSG_DATA(cmsg);
134
135 if (s->ifindex <= 0)
136 s->ifindex = i->ipi6_ifindex;
137 break;
138 }
139
140 case IPV6_HOPLIMIT:
141 s->ttl = *(int *) CMSG_DATA(cmsg);
142 break;
143 }
144
145 } else if (cmsg->cmsg_level == IPPROTO_IP) {
146 assert(s->peer.sa.sa_family == AF_INET);
147
148 switch (cmsg->cmsg_type) {
149
150 case IP_PKTINFO: {
151 struct in_pktinfo *i = (struct in_pktinfo*) CMSG_DATA(cmsg);
152
153 if (s->ifindex <= 0)
154 s->ifindex = i->ipi_ifindex;
155 break;
156 }
157
158 case IP_TTL:
159 s->ttl = *(int *) CMSG_DATA(cmsg);
160 break;
161 }
162 }
163 }
164
165 /* The Linux kernel sets the interface index to the loopback
166 * device if the connection came from the local host since it
167 * avoids the routing table in such a case. Let's unset the
168 * interface index in such a case. */
169 if (s->ifindex == LOOPBACK_IFINDEX)
170 s->ifindex = 0;
171
172 /* If we don't know the interface index still, we look for the
173 * first local interface with a matching address. Yuck! */
174 if (s->ifindex <= 0)
175 s->ifindex = manager_find_ifindex(s->manager, s->local.sa.sa_family, s->local.sa.sa_family == AF_INET ? (union in_addr_union*) &s->local.in.sin_addr : (union in_addr_union*) &s->local.in6.sin6_addr);
176
177 if (s->protocol == DNS_PROTOCOL_LLMNR && s->ifindex > 0) {
178 uint32_t ifindex = htobe32(s->ifindex);
179
180 /* Make sure all packets for this connection are sent on the same interface */
181 if (s->local.sa.sa_family == AF_INET) {
182 r = setsockopt(s->fd, IPPROTO_IP, IP_UNICAST_IF, &ifindex, sizeof(ifindex));
183 if (r < 0)
184 log_debug_errno(errno, "Failed to invoke IP_UNICAST_IF: %m");
185 } else if (s->local.sa.sa_family == AF_INET6) {
186 r = setsockopt(s->fd, IPPROTO_IPV6, IPV6_UNICAST_IF, &ifindex, sizeof(ifindex));
187 if (r < 0)
188 log_debug_errno(errno, "Failed to invoke IPV6_UNICAST_IF: %m");
189 }
190 }
191
192 s->identified = true;
193
194 return 0;
195 }
196
197 static ssize_t dns_stream_writev(DnsStream *s, const struct iovec *iov, size_t iovcnt, int flags) {
198 ssize_t r;
199
200 assert(s);
201 assert(iov);
202
203 #if HAVE_GNUTLS
204 if (s->tls_session && !(flags & WRITE_TLS_DATA)) {
205 ssize_t ss;
206 size_t i;
207
208 r = 0;
209 for (i = 0; i < iovcnt; i++) {
210 ss = gnutls_record_send(s->tls_session, iov[i].iov_base, iov[i].iov_len);
211 if (ss < 0) {
212 switch(ss) {
213
214 case GNUTLS_E_INTERRUPTED:
215 return -EINTR;
216 case GNUTLS_E_AGAIN:
217 return -EAGAIN;
218 default:
219 log_debug("Failed to invoke gnutls_record_send: %s", gnutls_strerror(ss));
220 return -EIO;
221 }
222 }
223
224 r += ss;
225 if (ss != (ssize_t) iov[i].iov_len)
226 continue;
227 }
228 } else
229 #endif
230 if (s->tfo_salen > 0) {
231 struct msghdr hdr = {
232 .msg_iov = (struct iovec*) iov,
233 .msg_iovlen = iovcnt,
234 .msg_name = &s->tfo_address.sa,
235 .msg_namelen = s->tfo_salen
236 };
237
238 r = sendmsg(s->fd, &hdr, MSG_FASTOPEN);
239 if (r < 0) {
240 if (errno == EOPNOTSUPP) {
241 s->tfo_salen = 0;
242 r = connect(s->fd, &s->tfo_address.sa, s->tfo_salen);
243 if (r < 0)
244 return -errno;
245
246 r = -EAGAIN;
247 } else if (errno == EINPROGRESS)
248 r = -EAGAIN;
249 } else
250 s->tfo_salen = 0; /* connection is made */
251 } else
252 r = writev(s->fd, iov, iovcnt);
253
254 return r;
255 }
256
257 static ssize_t dns_stream_read(DnsStream *s, void *buf, size_t count) {
258 ssize_t ss;
259
260 #if HAVE_GNUTLS
261 if (s->tls_session) {
262 ss = gnutls_record_recv(s->tls_session, buf, count);
263 if (ss < 0) {
264 switch(ss) {
265
266 case GNUTLS_E_INTERRUPTED:
267 return -EINTR;
268 case GNUTLS_E_AGAIN:
269 return -EAGAIN;
270 default:
271 log_debug("Failed to invoke gnutls_record_send: %s", gnutls_strerror(ss));
272 return -EIO;
273 }
274 } else if (s->on_connection) {
275 int r;
276
277 r = s->on_connection(s);
278 s->on_connection = NULL; /* only call once */
279 if (r < 0)
280 return r;
281 }
282 } else
283 #endif
284 ss = read(s->fd, buf, count);
285
286 return ss;
287 }
288
289 #if HAVE_GNUTLS
290 static ssize_t dns_stream_tls_writev(gnutls_transport_ptr_t p, const giovec_t * iov, int iovcnt) {
291 int r;
292
293 assert(p);
294
295 r = dns_stream_writev((DnsStream*) p, (struct iovec*) iov, iovcnt, WRITE_TLS_DATA);
296 if (r < 0) {
297 errno = -r;
298 return -1;
299 }
300
301 return r;
302 }
303 #endif
304
305 static int on_stream_timeout(sd_event_source *es, usec_t usec, void *userdata) {
306 DnsStream *s = userdata;
307
308 assert(s);
309
310 return dns_stream_complete(s, ETIMEDOUT);
311 }
312
313 static int on_stream_io(sd_event_source *es, int fd, uint32_t revents, void *userdata) {
314 DnsStream *s = userdata;
315 int r;
316
317 assert(s);
318
319 #if HAVE_GNUTLS
320 if (s->tls_bye) {
321 assert(s->tls_session);
322
323 r = gnutls_bye(s->tls_session, GNUTLS_SHUT_RDWR);
324 if (r != GNUTLS_E_AGAIN) {
325 s->tls_bye = false;
326 dns_stream_unref(s);
327 }
328
329 return 0;
330 }
331
332 if (s->tls_handshake < 0) {
333 assert(s->tls_session);
334
335 s->tls_handshake = gnutls_handshake(s->tls_session);
336 if (s->tls_handshake >= 0) {
337 if (s->on_connection && !(gnutls_session_get_flags(s->tls_session) & GNUTLS_SFLAGS_FALSE_START)) {
338 r = s->on_connection(s);
339 s->on_connection = NULL; /* only call once */
340 if (r < 0)
341 return r;
342 }
343 } else {
344 if (gnutls_error_is_fatal(s->tls_handshake))
345 return dns_stream_complete(s, ECONNREFUSED);
346 else
347 return 0;
348 }
349
350 }
351 #endif
352
353 /* only identify after connecting */
354 if (s->tfo_salen == 0) {
355 r = dns_stream_identify(s);
356 if (r < 0)
357 return dns_stream_complete(s, -r);
358 }
359
360 if ((revents & EPOLLOUT) &&
361 s->write_packet &&
362 s->n_written < sizeof(s->write_size) + s->write_packet->size) {
363
364 struct iovec iov[2];
365 ssize_t ss;
366
367 iov[0].iov_base = &s->write_size;
368 iov[0].iov_len = sizeof(s->write_size);
369 iov[1].iov_base = DNS_PACKET_DATA(s->write_packet);
370 iov[1].iov_len = s->write_packet->size;
371
372 IOVEC_INCREMENT(iov, 2, s->n_written);
373
374 ss = dns_stream_writev(s, iov, 2, 0);
375 if (ss < 0) {
376 if (!IN_SET(errno, EINTR, EAGAIN))
377 return dns_stream_complete(s, errno);
378 } else
379 s->n_written += ss;
380
381 /* Are we done? If so, disable the event source for EPOLLOUT */
382 if (s->n_written >= sizeof(s->write_size) + s->write_packet->size) {
383 r = dns_stream_update_io(s);
384 if (r < 0)
385 return dns_stream_complete(s, -r);
386 }
387 }
388
389 if ((revents & (EPOLLIN|EPOLLHUP|EPOLLRDHUP)) &&
390 (!s->read_packet ||
391 s->n_read < sizeof(s->read_size) + s->read_packet->size)) {
392
393 if (s->n_read < sizeof(s->read_size)) {
394 ssize_t ss;
395
396 ss = dns_stream_read(s, (uint8_t*) &s->read_size + s->n_read, sizeof(s->read_size) - s->n_read);
397 if (ss < 0) {
398 if (!IN_SET(errno, EINTR, EAGAIN))
399 return dns_stream_complete(s, errno);
400 } else if (ss == 0)
401 return dns_stream_complete(s, ECONNRESET);
402 else
403 s->n_read += ss;
404 }
405
406 if (s->n_read >= sizeof(s->read_size)) {
407
408 if (be16toh(s->read_size) < DNS_PACKET_HEADER_SIZE)
409 return dns_stream_complete(s, EBADMSG);
410
411 if (s->n_read < sizeof(s->read_size) + be16toh(s->read_size)) {
412 ssize_t ss;
413
414 if (!s->read_packet) {
415 r = dns_packet_new(&s->read_packet, s->protocol, be16toh(s->read_size), DNS_PACKET_SIZE_MAX);
416 if (r < 0)
417 return dns_stream_complete(s, -r);
418
419 s->read_packet->size = be16toh(s->read_size);
420 s->read_packet->ipproto = IPPROTO_TCP;
421 s->read_packet->family = s->peer.sa.sa_family;
422 s->read_packet->ttl = s->ttl;
423 s->read_packet->ifindex = s->ifindex;
424
425 if (s->read_packet->family == AF_INET) {
426 s->read_packet->sender.in = s->peer.in.sin_addr;
427 s->read_packet->sender_port = be16toh(s->peer.in.sin_port);
428 s->read_packet->destination.in = s->local.in.sin_addr;
429 s->read_packet->destination_port = be16toh(s->local.in.sin_port);
430 } else {
431 assert(s->read_packet->family == AF_INET6);
432 s->read_packet->sender.in6 = s->peer.in6.sin6_addr;
433 s->read_packet->sender_port = be16toh(s->peer.in6.sin6_port);
434 s->read_packet->destination.in6 = s->local.in6.sin6_addr;
435 s->read_packet->destination_port = be16toh(s->local.in6.sin6_port);
436
437 if (s->read_packet->ifindex == 0)
438 s->read_packet->ifindex = s->peer.in6.sin6_scope_id;
439 if (s->read_packet->ifindex == 0)
440 s->read_packet->ifindex = s->local.in6.sin6_scope_id;
441 }
442 }
443
444 ss = dns_stream_read(s,
445 (uint8_t*) DNS_PACKET_DATA(s->read_packet) + s->n_read - sizeof(s->read_size),
446 sizeof(s->read_size) + be16toh(s->read_size) - s->n_read);
447 if (ss < 0) {
448 if (!IN_SET(errno, EINTR, EAGAIN))
449 return dns_stream_complete(s, errno);
450 } else if (ss == 0)
451 return dns_stream_complete(s, ECONNRESET);
452 else
453 s->n_read += ss;
454 }
455
456 /* Are we done? If so, disable the event source for EPOLLIN */
457 if (s->n_read >= sizeof(s->read_size) + be16toh(s->read_size)) {
458 /* If there's a packet handler
459 * installed, call that. Note that
460 * this is optional... */
461 if (s->on_packet) {
462 r = s->on_packet(s);
463 if (r < 0)
464 return r;
465 }
466
467 r = dns_stream_update_io(s);
468 if (r < 0)
469 return dns_stream_complete(s, -r);
470 }
471 }
472 }
473
474 if ((s->write_packet && s->n_written >= sizeof(s->write_size) + s->write_packet->size) &&
475 (s->read_packet && s->n_read >= sizeof(s->read_size) + s->read_packet->size))
476 return dns_stream_complete(s, 0);
477
478 return 0;
479 }
480
481 DnsStream *dns_stream_unref(DnsStream *s) {
482 DnsPacket *p;
483 Iterator i;
484
485 if (!s)
486 return NULL;
487
488 assert(s->n_ref > 0);
489 s->n_ref--;
490
491 if (s->n_ref > 0)
492 return NULL;
493
494 dns_stream_stop(s);
495
496 if (s->server && s->server->stream == s)
497 s->server->stream = NULL;
498
499 if (s->manager) {
500 LIST_REMOVE(streams, s->manager->dns_streams, s);
501 s->manager->n_dns_streams--;
502 }
503
504 #if HAVE_GNUTLS
505 if (s->tls_session)
506 gnutls_deinit(s->tls_session);
507 #endif
508
509 ORDERED_SET_FOREACH(p, s->write_queue, i)
510 dns_packet_unref(ordered_set_remove(s->write_queue, p));
511
512 dns_packet_unref(s->write_packet);
513 dns_packet_unref(s->read_packet);
514 dns_server_unref(s->server);
515
516 ordered_set_free(s->write_queue);
517
518 return mfree(s);
519 }
520
521 DnsStream *dns_stream_ref(DnsStream *s) {
522 if (!s)
523 return NULL;
524
525 assert(s->n_ref > 0);
526 s->n_ref++;
527
528 return s;
529 }
530
531 int dns_stream_new(Manager *m, DnsStream **ret, DnsProtocol protocol, int fd, const union sockaddr_union *tfo_address) {
532 _cleanup_(dns_stream_unrefp) DnsStream *s = NULL;
533 int r;
534
535 assert(m);
536 assert(fd >= 0);
537
538 if (m->n_dns_streams > DNS_STREAMS_MAX)
539 return -EBUSY;
540
541 s = new0(DnsStream, 1);
542 if (!s)
543 return -ENOMEM;
544
545 r = ordered_set_ensure_allocated(&s->write_queue, &dns_packet_hash_ops);
546 if (r < 0)
547 return r;
548
549 s->n_ref = 1;
550 s->fd = -1;
551 s->protocol = protocol;
552
553 r = sd_event_add_io(m->event, &s->io_event_source, fd, EPOLLIN, on_stream_io, s);
554 if (r < 0)
555 return r;
556
557 (void) sd_event_source_set_description(s->io_event_source, "dns-stream-io");
558
559 r = sd_event_add_time(
560 m->event,
561 &s->timeout_event_source,
562 clock_boottime_or_monotonic(),
563 now(clock_boottime_or_monotonic()) + DNS_STREAM_TIMEOUT_USEC, 0,
564 on_stream_timeout, s);
565 if (r < 0)
566 return r;
567
568 (void) sd_event_source_set_description(s->timeout_event_source, "dns-stream-timeout");
569
570 LIST_PREPEND(streams, m->dns_streams, s);
571 s->manager = m;
572 s->fd = fd;
573 if (tfo_address) {
574 s->tfo_address = *tfo_address;
575 s->tfo_salen = tfo_address->sa.sa_family == AF_INET6 ? sizeof(tfo_address->in6) : sizeof(tfo_address->in);
576 }
577
578 m->n_dns_streams++;
579
580 *ret = TAKE_PTR(s);
581
582 return 0;
583 }
584
#if HAVE_GNUTLS
/* Attaches tls_session to the stream (which takes ownership of it) and starts the TLS handshake.
 *
 * The socket fd is passed as the receive transport pointer. No pull function is installed here, so
 * receiving presumably uses GnuTLS' default transport on that fd. Sending goes through
 * dns_stream_tls_writev() with the stream as context.
 *
 * A non-fatal handshake result (e.g. GNUTLS_E_AGAIN) is stored in s->tls_handshake and continued
 * from the I/O handler. Returns 0, or -ECONNREFUSED if the handshake failed fatally. */
int dns_stream_connect_tls(DnsStream *s, gnutls_session_t tls_session) {
        int k;

        gnutls_transport_set_ptr2(tls_session, (gnutls_transport_ptr_t) (long) s->fd, s);
        gnutls_transport_set_vec_push_function(tls_session, &dns_stream_tls_writev);

        s->encrypted = true;
        s->tls_session = tls_session;

        k = gnutls_handshake(tls_session);
        s->tls_handshake = k;

        if (k < 0 && gnutls_error_is_fatal(k))
                return -ECONNREFUSED;

        return 0;
}
#endif
599
600 int dns_stream_write_packet(DnsStream *s, DnsPacket *p) {
601 int r;
602
603 assert(s);
604
605 r = ordered_set_put(s->write_queue, p);
606 if (r < 0)
607 return r;
608
609 dns_packet_ref(p);
610
611 return dns_stream_update_io(s);
612 }