]> git.ipfire.org Git - thirdparty/systemd.git/blob - src/resolve/resolved-dns-stream.c
Merge pull request #22276 from mrc0mmand/TEST-64-workaround
[thirdparty/systemd.git] / src / resolve / resolved-dns-stream.c
1 /* SPDX-License-Identifier: LGPL-2.1-or-later */
2
3 #include <netinet/tcp.h>
4 #include <unistd.h>
5
6 #include "alloc-util.h"
7 #include "fd-util.h"
8 #include "io-util.h"
9 #include "macro.h"
10 #include "missing_network.h"
11 #include "resolved-dns-stream.h"
12 #include "resolved-manager.h"
13
14 #define DNS_STREAMS_MAX 128
15
16 #define DNS_QUERIES_PER_STREAM 32
17
18 static void dns_stream_stop(DnsStream *s) {
19 assert(s);
20
21 s->io_event_source = sd_event_source_disable_unref(s->io_event_source);
22 s->timeout_event_source = sd_event_source_disable_unref(s->timeout_event_source);
23 s->fd = safe_close(s->fd);
24
25 /* Disconnect us from the server object if we are now not usable anymore */
26 dns_stream_detach(s);
27 }
28
29 static int dns_stream_update_io(DnsStream *s) {
30 int f = 0;
31
32 assert(s);
33
34 if (s->write_packet && s->n_written < sizeof(s->write_size) + s->write_packet->size)
35 f |= EPOLLOUT;
36 else if (!ordered_set_isempty(s->write_queue)) {
37 dns_packet_unref(s->write_packet);
38 s->write_packet = ordered_set_steal_first(s->write_queue);
39 s->write_size = htobe16(s->write_packet->size);
40 s->n_written = 0;
41 f |= EPOLLOUT;
42 }
43
44 /* Let's read a packet if we haven't queued any yet. Except if we already hit a limit of parallel
45 * queries for this connection. */
46 if ((!s->read_packet || s->n_read < sizeof(s->read_size) + s->read_packet->size) &&
47 set_size(s->queries) < DNS_QUERIES_PER_STREAM)
48 f |= EPOLLIN;
49
50 #if ENABLE_DNS_OVER_TLS
51 /* For handshake and clean closing purposes, TLS can override requested events */
52 if (s->dnstls_events != 0)
53 f = s->dnstls_events;
54 #endif
55
56 return sd_event_source_set_io_events(s->io_event_source, f);
57 }
58
59 static int dns_stream_complete(DnsStream *s, int error) {
60 _cleanup_(dns_stream_unrefp) _unused_ DnsStream *ref = dns_stream_ref(s); /* Protect stream while we process it */
61
62 assert(s);
63 assert(error >= 0);
64
65 /* Error is > 0 when the connection failed for some reason in the network stack. It's == 0 if we sent
66 * and received exactly one packet each (in the LLMNR client case). */
67
68 #if ENABLE_DNS_OVER_TLS
69 if (s->encrypted) {
70 int r;
71
72 r = dnstls_stream_shutdown(s, error);
73 if (r != -EAGAIN)
74 dns_stream_stop(s);
75 } else
76 #endif
77 dns_stream_stop(s);
78
79 dns_stream_detach(s);
80
81 if (s->complete)
82 s->complete(s, error);
83 else /* the default action if no completion function is set is to close the stream */
84 dns_stream_unref(s);
85
86 return 0;
87 }
88
89 static int dns_stream_identify(DnsStream *s) {
90 CMSG_BUFFER_TYPE(CMSG_SPACE(MAXSIZE(struct in_pktinfo, struct in6_pktinfo))
91 + CMSG_SPACE(int) + /* for the TTL */
92 + EXTRA_CMSG_SPACE /* kernel appears to require extra space */) control;
93 struct msghdr mh = {};
94 struct cmsghdr *cmsg;
95 socklen_t sl;
96 int r;
97
98 assert(s);
99
100 if (s->identified)
101 return 0;
102
103 /* Query the local side */
104 s->local_salen = sizeof(s->local);
105 r = getsockname(s->fd, &s->local.sa, &s->local_salen);
106 if (r < 0)
107 return -errno;
108 if (s->local.sa.sa_family == AF_INET6 && s->ifindex <= 0)
109 s->ifindex = s->local.in6.sin6_scope_id;
110
111 /* Query the remote side */
112 s->peer_salen = sizeof(s->peer);
113 r = getpeername(s->fd, &s->peer.sa, &s->peer_salen);
114 if (r < 0)
115 return -errno;
116 if (s->peer.sa.sa_family == AF_INET6 && s->ifindex <= 0)
117 s->ifindex = s->peer.in6.sin6_scope_id;
118
119 /* Check consistency */
120 assert(s->peer.sa.sa_family == s->local.sa.sa_family);
121 assert(IN_SET(s->peer.sa.sa_family, AF_INET, AF_INET6));
122
123 /* Query connection meta information */
124 sl = sizeof(control);
125 if (s->peer.sa.sa_family == AF_INET) {
126 r = getsockopt(s->fd, IPPROTO_IP, IP_PKTOPTIONS, &control, &sl);
127 if (r < 0)
128 return -errno;
129 } else if (s->peer.sa.sa_family == AF_INET6) {
130
131 r = getsockopt(s->fd, IPPROTO_IPV6, IPV6_2292PKTOPTIONS, &control, &sl);
132 if (r < 0)
133 return -errno;
134 } else
135 return -EAFNOSUPPORT;
136
137 mh.msg_control = &control;
138 mh.msg_controllen = sl;
139
140 CMSG_FOREACH(cmsg, &mh) {
141
142 if (cmsg->cmsg_level == IPPROTO_IPV6) {
143 assert(s->peer.sa.sa_family == AF_INET6);
144
145 switch (cmsg->cmsg_type) {
146
147 case IPV6_PKTINFO: {
148 struct in6_pktinfo *i = (struct in6_pktinfo*) CMSG_DATA(cmsg);
149
150 if (s->ifindex <= 0)
151 s->ifindex = i->ipi6_ifindex;
152 break;
153 }
154
155 case IPV6_HOPLIMIT:
156 s->ttl = *(int *) CMSG_DATA(cmsg);
157 break;
158 }
159
160 } else if (cmsg->cmsg_level == IPPROTO_IP) {
161 assert(s->peer.sa.sa_family == AF_INET);
162
163 switch (cmsg->cmsg_type) {
164
165 case IP_PKTINFO: {
166 struct in_pktinfo *i = (struct in_pktinfo*) CMSG_DATA(cmsg);
167
168 if (s->ifindex <= 0)
169 s->ifindex = i->ipi_ifindex;
170 break;
171 }
172
173 case IP_TTL:
174 s->ttl = *(int *) CMSG_DATA(cmsg);
175 break;
176 }
177 }
178 }
179
180 /* The Linux kernel sets the interface index to the loopback
181 * device if the connection came from the local host since it
182 * avoids the routing table in such a case. Let's unset the
183 * interface index in such a case. */
184 if (s->ifindex == LOOPBACK_IFINDEX)
185 s->ifindex = 0;
186
187 /* If we don't know the interface index still, we look for the
188 * first local interface with a matching address. Yuck! */
189 if (s->ifindex <= 0)
190 s->ifindex = manager_find_ifindex(s->manager, s->local.sa.sa_family, sockaddr_in_addr(&s->local.sa));
191
192 if (s->protocol == DNS_PROTOCOL_LLMNR && s->ifindex > 0) {
193 /* Make sure all packets for this connection are sent on the same interface */
194 r = socket_set_unicast_if(s->fd, s->local.sa.sa_family, s->ifindex);
195 if (r < 0)
196 log_debug_errno(errno, "Failed to invoke IP_UNICAST_IF/IPV6_UNICAST_IF: %m");
197 }
198
199 s->identified = true;
200
201 return 0;
202 }
203
204 ssize_t dns_stream_writev(DnsStream *s, const struct iovec *iov, size_t iovcnt, int flags) {
205 ssize_t m;
206
207 assert(s);
208 assert(iov);
209
210 #if ENABLE_DNS_OVER_TLS
211 if (s->encrypted && !(flags & DNS_STREAM_WRITE_TLS_DATA)) {
212 ssize_t ss;
213 size_t i;
214
215 m = 0;
216 for (i = 0; i < iovcnt; i++) {
217 ss = dnstls_stream_write(s, iov[i].iov_base, iov[i].iov_len);
218 if (ss < 0)
219 return ss;
220
221 m += ss;
222 if (ss != (ssize_t) iov[i].iov_len)
223 continue;
224 }
225 } else
226 #endif
227 if (s->tfo_salen > 0) {
228 struct msghdr hdr = {
229 .msg_iov = (struct iovec*) iov,
230 .msg_iovlen = iovcnt,
231 .msg_name = &s->tfo_address.sa,
232 .msg_namelen = s->tfo_salen
233 };
234
235 m = sendmsg(s->fd, &hdr, MSG_FASTOPEN);
236 if (m < 0) {
237 if (errno == EOPNOTSUPP) {
238 s->tfo_salen = 0;
239 if (connect(s->fd, &s->tfo_address.sa, s->tfo_salen) < 0)
240 return -errno;
241
242 return -EAGAIN;
243 }
244 if (errno == EINPROGRESS)
245 return -EAGAIN;
246
247 return -errno;
248 } else
249 s->tfo_salen = 0; /* connection is made */
250 } else {
251 m = writev(s->fd, iov, iovcnt);
252 if (m < 0)
253 return -errno;
254 }
255
256 return m;
257 }
258
259 static ssize_t dns_stream_read(DnsStream *s, void *buf, size_t count) {
260 ssize_t ss;
261
262 #if ENABLE_DNS_OVER_TLS
263 if (s->encrypted)
264 ss = dnstls_stream_read(s, buf, count);
265 else
266 #endif
267 {
268 ss = read(s->fd, buf, count);
269 if (ss < 0)
270 return -errno;
271 }
272
273 return ss;
274 }
275
276 static int on_stream_timeout(sd_event_source *es, usec_t usec, void *userdata) {
277 DnsStream *s = userdata;
278
279 assert(s);
280
281 return dns_stream_complete(s, ETIMEDOUT);
282 }
283
284 static int on_stream_io_impl(DnsStream *s, uint32_t revents) {
285 bool progressed = false;
286 int r;
287
288 assert(s);
289
290 /* This returns 1 when possible remaining stream exists, 0 on completed
291 stream or recoverable error, and negative errno on failure. */
292
293 #if ENABLE_DNS_OVER_TLS
294 if (s->encrypted) {
295 r = dnstls_stream_on_io(s, revents);
296 if (r == DNSTLS_STREAM_CLOSED)
297 return 0;
298 if (r == -EAGAIN)
299 return dns_stream_update_io(s);
300 if (r < 0)
301 return dns_stream_complete(s, -r);
302
303 r = dns_stream_update_io(s);
304 if (r < 0)
305 return r;
306 }
307 #endif
308
309 /* only identify after connecting */
310 if (s->tfo_salen == 0) {
311 r = dns_stream_identify(s);
312 if (r < 0)
313 return dns_stream_complete(s, -r);
314 }
315
316 if ((revents & EPOLLOUT) &&
317 s->write_packet &&
318 s->n_written < sizeof(s->write_size) + s->write_packet->size) {
319
320 struct iovec iov[] = {
321 IOVEC_MAKE(&s->write_size, sizeof(s->write_size)),
322 IOVEC_MAKE(DNS_PACKET_DATA(s->write_packet), s->write_packet->size),
323 };
324
325 IOVEC_INCREMENT(iov, ELEMENTSOF(iov), s->n_written);
326
327 ssize_t ss = dns_stream_writev(s, iov, ELEMENTSOF(iov), 0);
328 if (ss < 0) {
329 if (!ERRNO_IS_TRANSIENT(ss))
330 return dns_stream_complete(s, -ss);
331 } else {
332 progressed = true;
333 s->n_written += ss;
334 }
335
336 /* Are we done? If so, disable the event source for EPOLLOUT */
337 if (s->n_written >= sizeof(s->write_size) + s->write_packet->size) {
338 r = dns_stream_update_io(s);
339 if (r < 0)
340 return dns_stream_complete(s, -r);
341 }
342 }
343
344 if ((revents & (EPOLLIN|EPOLLHUP|EPOLLRDHUP)) &&
345 (!s->read_packet ||
346 s->n_read < sizeof(s->read_size) + s->read_packet->size)) {
347
348 if (s->n_read < sizeof(s->read_size)) {
349 ssize_t ss;
350
351 ss = dns_stream_read(s, (uint8_t*) &s->read_size + s->n_read, sizeof(s->read_size) - s->n_read);
352 if (ss < 0) {
353 if (!ERRNO_IS_TRANSIENT(ss))
354 return dns_stream_complete(s, -ss);
355 } else if (ss == 0)
356 return dns_stream_complete(s, ECONNRESET);
357 else {
358 progressed = true;
359 s->n_read += ss;
360 }
361 }
362
363 if (s->n_read >= sizeof(s->read_size)) {
364
365 if (be16toh(s->read_size) < DNS_PACKET_HEADER_SIZE)
366 return dns_stream_complete(s, EBADMSG);
367
368 if (s->n_read < sizeof(s->read_size) + be16toh(s->read_size)) {
369 ssize_t ss;
370
371 if (!s->read_packet) {
372 r = dns_packet_new(&s->read_packet, s->protocol, be16toh(s->read_size), DNS_PACKET_SIZE_MAX);
373 if (r < 0)
374 return dns_stream_complete(s, -r);
375
376 s->read_packet->size = be16toh(s->read_size);
377 s->read_packet->ipproto = IPPROTO_TCP;
378 s->read_packet->family = s->peer.sa.sa_family;
379 s->read_packet->ttl = s->ttl;
380 s->read_packet->ifindex = s->ifindex;
381 s->read_packet->timestamp = now(clock_boottime_or_monotonic());
382
383 if (s->read_packet->family == AF_INET) {
384 s->read_packet->sender.in = s->peer.in.sin_addr;
385 s->read_packet->sender_port = be16toh(s->peer.in.sin_port);
386 s->read_packet->destination.in = s->local.in.sin_addr;
387 s->read_packet->destination_port = be16toh(s->local.in.sin_port);
388 } else {
389 assert(s->read_packet->family == AF_INET6);
390 s->read_packet->sender.in6 = s->peer.in6.sin6_addr;
391 s->read_packet->sender_port = be16toh(s->peer.in6.sin6_port);
392 s->read_packet->destination.in6 = s->local.in6.sin6_addr;
393 s->read_packet->destination_port = be16toh(s->local.in6.sin6_port);
394
395 if (s->read_packet->ifindex == 0)
396 s->read_packet->ifindex = s->peer.in6.sin6_scope_id;
397 if (s->read_packet->ifindex == 0)
398 s->read_packet->ifindex = s->local.in6.sin6_scope_id;
399 }
400 }
401
402 ss = dns_stream_read(s,
403 (uint8_t*) DNS_PACKET_DATA(s->read_packet) + s->n_read - sizeof(s->read_size),
404 sizeof(s->read_size) + be16toh(s->read_size) - s->n_read);
405 if (ss < 0) {
406 if (!ERRNO_IS_TRANSIENT(ss))
407 return dns_stream_complete(s, -ss);
408 } else if (ss == 0)
409 return dns_stream_complete(s, ECONNRESET);
410 else
411 s->n_read += ss;
412 }
413
414 /* Are we done? If so, disable the event source for EPOLLIN */
415 if (s->n_read >= sizeof(s->read_size) + be16toh(s->read_size)) {
416 /* If there's a packet handler
417 * installed, call that. Note that
418 * this is optional... */
419 if (s->on_packet) {
420 r = s->on_packet(s);
421 if (r < 0)
422 return r;
423 }
424
425 r = dns_stream_update_io(s);
426 if (r < 0)
427 return dns_stream_complete(s, -r);
428 }
429 }
430 }
431
432 /* Call "complete" callback if finished reading and writing one packet, and there's nothing else left
433 * to write. */
434 if (s->type == DNS_STREAM_LLMNR_SEND &&
435 (s->write_packet && s->n_written >= sizeof(s->write_size) + s->write_packet->size) &&
436 ordered_set_isempty(s->write_queue) &&
437 (s->read_packet && s->n_read >= sizeof(s->read_size) + s->read_packet->size))
438 return dns_stream_complete(s, 0);
439
440 /* If we did something, let's restart the timeout event source */
441 if (progressed && s->timeout_event_source) {
442 r = sd_event_source_set_time_relative(s->timeout_event_source, DNS_STREAM_ESTABLISHED_TIMEOUT_USEC);
443 if (r < 0)
444 log_warning_errno(errno, "Couldn't restart TCP connection timeout, ignoring: %m");
445 }
446
447 return 1;
448 }
449
450 static int on_stream_io(sd_event_source *es, int fd, uint32_t revents, void *userdata) {
451 _cleanup_(dns_stream_unrefp) DnsStream *s = dns_stream_ref(userdata); /* Protect stream while we process it */
452 int r;
453
454 assert(s);
455
456 r = on_stream_io_impl(s, revents);
457 if (r <= 0)
458 return r;
459
460 #if ENABLE_DNS_OVER_TLS
461 if (!s->encrypted)
462 return 0;
463
464 /* When using DNS-over-TLS, the underlying TLS library may read the entire TLS record
465 and buffer it internally. If this happens, we will not receive further EPOLLIN events,
466 and unless there's some unrelated activity on the socket, we will hang until time out.
467 To avoid this, if there's buffered TLS data, generate a "fake" EPOLLIN event.
468 This is hacky, but it makes this case transparent to the rest of the IO code. */
469 while (dnstls_stream_has_buffered_data(s)) {
470 uint32_t events;
471
472 /* Make sure the stream still wants to process more data... */
473 r = sd_event_source_get_io_events(s->io_event_source, &events);
474 if (r < 0)
475 return r;
476 if (!FLAGS_SET(events, EPOLLIN))
477 break;
478
479 r = on_stream_io_impl(s, EPOLLIN);
480 if (r <= 0)
481 return r;
482 }
483 #endif
484
485 return 0;
486 }
487
488 static DnsStream *dns_stream_free(DnsStream *s) {
489 DnsPacket *p;
490
491 assert(s);
492
493 dns_stream_stop(s);
494
495 if (s->manager) {
496 LIST_REMOVE(streams, s->manager->dns_streams, s);
497 s->manager->n_dns_streams[s->type]--;
498 }
499
500 #if ENABLE_DNS_OVER_TLS
501 if (s->encrypted)
502 dnstls_stream_free(s);
503 #endif
504
505 ORDERED_SET_FOREACH(p, s->write_queue)
506 dns_packet_unref(ordered_set_remove(s->write_queue, p));
507
508 dns_packet_unref(s->write_packet);
509 dns_packet_unref(s->read_packet);
510 dns_server_unref(s->server);
511
512 ordered_set_free(s->write_queue);
513
514 return mfree(s);
515 }
516
517 DEFINE_TRIVIAL_REF_UNREF_FUNC(DnsStream, dns_stream, dns_stream_free);
518
519 int dns_stream_new(
520 Manager *m,
521 DnsStream **ret,
522 DnsStreamType type,
523 DnsProtocol protocol,
524 int fd,
525 const union sockaddr_union *tfo_address,
526 usec_t connect_timeout_usec) {
527
528 _cleanup_(dns_stream_unrefp) DnsStream *s = NULL;
529 int r;
530
531 assert(m);
532 assert(ret);
533 assert(type >= 0);
534 assert(type < _DNS_STREAM_TYPE_MAX);
535 assert(protocol >= 0);
536 assert(protocol < _DNS_PROTOCOL_MAX);
537 assert(fd >= 0);
538
539 if (m->n_dns_streams[type] > DNS_STREAMS_MAX)
540 return -EBUSY;
541
542 s = new(DnsStream, 1);
543 if (!s)
544 return -ENOMEM;
545
546 *s = (DnsStream) {
547 .n_ref = 1,
548 .fd = -1,
549 .protocol = protocol,
550 .type = type,
551 };
552
553 r = ordered_set_ensure_allocated(&s->write_queue, &dns_packet_hash_ops);
554 if (r < 0)
555 return r;
556
557 r = sd_event_add_io(m->event, &s->io_event_source, fd, EPOLLIN, on_stream_io, s);
558 if (r < 0)
559 return r;
560
561 (void) sd_event_source_set_description(s->io_event_source, "dns-stream-io");
562
563 r = sd_event_add_time_relative(
564 m->event,
565 &s->timeout_event_source,
566 clock_boottime_or_monotonic(),
567 connect_timeout_usec, 0,
568 on_stream_timeout, s);
569 if (r < 0)
570 return r;
571
572 (void) sd_event_source_set_description(s->timeout_event_source, "dns-stream-timeout");
573
574 LIST_PREPEND(streams, m->dns_streams, s);
575 m->n_dns_streams[type]++;
576 s->manager = m;
577
578 s->fd = fd;
579
580 if (tfo_address) {
581 s->tfo_address = *tfo_address;
582 s->tfo_salen = tfo_address->sa.sa_family == AF_INET6 ? sizeof(tfo_address->in6) : sizeof(tfo_address->in);
583 }
584
585 *ret = TAKE_PTR(s);
586
587 return 0;
588 }
589
590 int dns_stream_write_packet(DnsStream *s, DnsPacket *p) {
591 int r;
592
593 assert(s);
594 assert(p);
595
596 r = ordered_set_put(s->write_queue, p);
597 if (r < 0)
598 return r;
599
600 dns_packet_ref(p);
601
602 return dns_stream_update_io(s);
603 }
604
605 DnsPacket *dns_stream_take_read_packet(DnsStream *s) {
606 assert(s);
607
608 if (!s->read_packet)
609 return NULL;
610
611 if (s->n_read < sizeof(s->read_size))
612 return NULL;
613
614 if (s->n_read < sizeof(s->read_size) + be16toh(s->read_size))
615 return NULL;
616
617 s->n_read = 0;
618 return TAKE_PTR(s->read_packet);
619 }
620
621 void dns_stream_detach(DnsStream *s) {
622 assert(s);
623
624 if (!s->server)
625 return;
626
627 if (s->server->stream != s)
628 return;
629
630 dns_server_unref_stream(s->server);
631 }