]> git.ipfire.org Git - thirdparty/systemd.git/blob - src/resolve/resolved-dns-stream.c
3fd056bdb709e9879631f4ee5ac3a66b9c677559
[thirdparty/systemd.git] / src / resolve / resolved-dns-stream.c
1 /* SPDX-License-Identifier: LGPL-2.1+ */
2
3 #include <netinet/tcp.h>
4
5 #include "alloc-util.h"
6 #include "fd-util.h"
7 #include "io-util.h"
8 #include "missing.h"
9 #include "resolved-dns-stream.h"
10
/* Idle timeout of a stream: armed when the stream is created and re-armed whenever I/O makes progress */
#define DNS_STREAM_TIMEOUT_USEC (USEC_PER_SEC * 10)
/* Limit on the number of streams of one type, checked against the manager's per-type counter */
#define DNS_STREAMS_MAX 128
13
14 static void dns_stream_stop(DnsStream *s) {
15 assert(s);
16
17 s->io_event_source = sd_event_source_unref(s->io_event_source);
18 s->timeout_event_source = sd_event_source_unref(s->timeout_event_source);
19 s->fd = safe_close(s->fd);
20
21 /* Disconnect us from the server object if we are now not usable anymore */
22 dns_stream_detach(s);
23 }
24
25 static int dns_stream_update_io(DnsStream *s) {
26 int f = 0;
27
28 assert(s);
29
30 if (s->write_packet && s->n_written < sizeof(s->write_size) + s->write_packet->size)
31 f |= EPOLLOUT;
32 else if (!ordered_set_isempty(s->write_queue)) {
33 dns_packet_unref(s->write_packet);
34 s->write_packet = ordered_set_steal_first(s->write_queue);
35 s->write_size = htobe16(s->write_packet->size);
36 s->n_written = 0;
37 f |= EPOLLOUT;
38 }
39 if (!s->read_packet || s->n_read < sizeof(s->read_size) + s->read_packet->size)
40 f |= EPOLLIN;
41
42 #if ENABLE_DNS_OVER_TLS
43 /* For handshake and clean closing purposes, TLS can override requested events */
44 if (s->dnstls_events != 0)
45 f = s->dnstls_events;
46 #endif
47
48 return sd_event_source_set_io_events(s->io_event_source, f);
49 }
50
51 static int dns_stream_complete(DnsStream *s, int error) {
52 _cleanup_(dns_stream_unrefp) _unused_ DnsStream *ref = dns_stream_ref(s); /* Protect stream while we process it */
53
54 assert(s);
55 assert(error >= 0);
56
57 /* Error is > 0 when the connection failed for some reason in the network stack. It's == 0 if we sent
58 * and receieved exactly one packet each (in the LLMNR client case). */
59
60 #if ENABLE_DNS_OVER_TLS
61 if (s->encrypted) {
62 int r;
63
64 r = dnstls_stream_shutdown(s, error);
65 if (r != -EAGAIN)
66 dns_stream_stop(s);
67 } else
68 #endif
69 dns_stream_stop(s);
70
71 dns_stream_detach(s);
72
73 if (s->complete)
74 s->complete(s, error);
75 else /* the default action if no completion function is set is to close the stream */
76 dns_stream_unref(s);
77
78 return 0;
79 }
80
81 static int dns_stream_identify(DnsStream *s) {
82 union {
83 struct cmsghdr header; /* For alignment */
84 uint8_t buffer[CMSG_SPACE(MAXSIZE(struct in_pktinfo, struct in6_pktinfo))
85 + EXTRA_CMSG_SPACE /* kernel appears to require extra space */];
86 } control;
87 struct msghdr mh = {};
88 struct cmsghdr *cmsg;
89 socklen_t sl;
90 int r;
91
92 assert(s);
93
94 if (s->identified)
95 return 0;
96
97 /* Query the local side */
98 s->local_salen = sizeof(s->local);
99 r = getsockname(s->fd, &s->local.sa, &s->local_salen);
100 if (r < 0)
101 return -errno;
102 if (s->local.sa.sa_family == AF_INET6 && s->ifindex <= 0)
103 s->ifindex = s->local.in6.sin6_scope_id;
104
105 /* Query the remote side */
106 s->peer_salen = sizeof(s->peer);
107 r = getpeername(s->fd, &s->peer.sa, &s->peer_salen);
108 if (r < 0)
109 return -errno;
110 if (s->peer.sa.sa_family == AF_INET6 && s->ifindex <= 0)
111 s->ifindex = s->peer.in6.sin6_scope_id;
112
113 /* Check consistency */
114 assert(s->peer.sa.sa_family == s->local.sa.sa_family);
115 assert(IN_SET(s->peer.sa.sa_family, AF_INET, AF_INET6));
116
117 /* Query connection meta information */
118 sl = sizeof(control);
119 if (s->peer.sa.sa_family == AF_INET) {
120 r = getsockopt(s->fd, IPPROTO_IP, IP_PKTOPTIONS, &control, &sl);
121 if (r < 0)
122 return -errno;
123 } else if (s->peer.sa.sa_family == AF_INET6) {
124
125 r = getsockopt(s->fd, IPPROTO_IPV6, IPV6_2292PKTOPTIONS, &control, &sl);
126 if (r < 0)
127 return -errno;
128 } else
129 return -EAFNOSUPPORT;
130
131 mh.msg_control = &control;
132 mh.msg_controllen = sl;
133
134 CMSG_FOREACH(cmsg, &mh) {
135
136 if (cmsg->cmsg_level == IPPROTO_IPV6) {
137 assert(s->peer.sa.sa_family == AF_INET6);
138
139 switch (cmsg->cmsg_type) {
140
141 case IPV6_PKTINFO: {
142 struct in6_pktinfo *i = (struct in6_pktinfo*) CMSG_DATA(cmsg);
143
144 if (s->ifindex <= 0)
145 s->ifindex = i->ipi6_ifindex;
146 break;
147 }
148
149 case IPV6_HOPLIMIT:
150 s->ttl = *(int *) CMSG_DATA(cmsg);
151 break;
152 }
153
154 } else if (cmsg->cmsg_level == IPPROTO_IP) {
155 assert(s->peer.sa.sa_family == AF_INET);
156
157 switch (cmsg->cmsg_type) {
158
159 case IP_PKTINFO: {
160 struct in_pktinfo *i = (struct in_pktinfo*) CMSG_DATA(cmsg);
161
162 if (s->ifindex <= 0)
163 s->ifindex = i->ipi_ifindex;
164 break;
165 }
166
167 case IP_TTL:
168 s->ttl = *(int *) CMSG_DATA(cmsg);
169 break;
170 }
171 }
172 }
173
174 /* The Linux kernel sets the interface index to the loopback
175 * device if the connection came from the local host since it
176 * avoids the routing table in such a case. Let's unset the
177 * interface index in such a case. */
178 if (s->ifindex == LOOPBACK_IFINDEX)
179 s->ifindex = 0;
180
181 /* If we don't know the interface index still, we look for the
182 * first local interface with a matching address. Yuck! */
183 if (s->ifindex <= 0)
184 s->ifindex = manager_find_ifindex(s->manager, s->local.sa.sa_family, s->local.sa.sa_family == AF_INET ? (union in_addr_union*) &s->local.in.sin_addr : (union in_addr_union*) &s->local.in6.sin6_addr);
185
186 if (s->protocol == DNS_PROTOCOL_LLMNR && s->ifindex > 0) {
187 uint32_t ifindex = htobe32(s->ifindex);
188
189 /* Make sure all packets for this connection are sent on the same interface */
190 if (s->local.sa.sa_family == AF_INET) {
191 r = setsockopt(s->fd, IPPROTO_IP, IP_UNICAST_IF, &ifindex, sizeof(ifindex));
192 if (r < 0)
193 log_debug_errno(errno, "Failed to invoke IP_UNICAST_IF: %m");
194 } else if (s->local.sa.sa_family == AF_INET6) {
195 r = setsockopt(s->fd, IPPROTO_IPV6, IPV6_UNICAST_IF, &ifindex, sizeof(ifindex));
196 if (r < 0)
197 log_debug_errno(errno, "Failed to invoke IPV6_UNICAST_IF: %m");
198 }
199 }
200
201 s->identified = true;
202
203 return 0;
204 }
205
206 ssize_t dns_stream_writev(DnsStream *s, const struct iovec *iov, size_t iovcnt, int flags) {
207 ssize_t m;
208
209 assert(s);
210 assert(iov);
211
212 #if ENABLE_DNS_OVER_TLS
213 if (s->encrypted && !(flags & DNS_STREAM_WRITE_TLS_DATA)) {
214 ssize_t ss;
215 size_t i;
216
217 m = 0;
218 for (i = 0; i < iovcnt; i++) {
219 ss = dnstls_stream_write(s, iov[i].iov_base, iov[i].iov_len);
220 if (ss < 0)
221 return ss;
222
223 m += ss;
224 if (ss != (ssize_t) iov[i].iov_len)
225 continue;
226 }
227 } else
228 #endif
229 if (s->tfo_salen > 0) {
230 struct msghdr hdr = {
231 .msg_iov = (struct iovec*) iov,
232 .msg_iovlen = iovcnt,
233 .msg_name = &s->tfo_address.sa,
234 .msg_namelen = s->tfo_salen
235 };
236
237 m = sendmsg(s->fd, &hdr, MSG_FASTOPEN);
238 if (m < 0) {
239 if (errno == EOPNOTSUPP) {
240 s->tfo_salen = 0;
241 if (connect(s->fd, &s->tfo_address.sa, s->tfo_salen) < 0)
242 return -errno;
243
244 return -EAGAIN;
245 }
246 if (errno == EINPROGRESS)
247 return -EAGAIN;
248
249 return -errno;
250 } else
251 s->tfo_salen = 0; /* connection is made */
252 } else {
253 m = writev(s->fd, iov, iovcnt);
254 if (m < 0)
255 return -errno;
256 }
257
258 return m;
259 }
260
261 static ssize_t dns_stream_read(DnsStream *s, void *buf, size_t count) {
262 ssize_t ss;
263
264 #if ENABLE_DNS_OVER_TLS
265 if (s->encrypted)
266 ss = dnstls_stream_read(s, buf, count);
267 else
268 #endif
269 {
270 ss = read(s->fd, buf, count);
271 if (ss < 0)
272 return -errno;
273 }
274
275 return ss;
276 }
277
278 static int on_stream_timeout(sd_event_source *es, usec_t usec, void *userdata) {
279 DnsStream *s = userdata;
280
281 assert(s);
282
283 return dns_stream_complete(s, ETIMEDOUT);
284 }
285
286 static int on_stream_io(sd_event_source *es, int fd, uint32_t revents, void *userdata) {
287 _cleanup_(dns_stream_unrefp) DnsStream *s = dns_stream_ref(userdata); /* Protect stream while we process it */
288 bool progressed = false;
289 int r;
290
291 assert(s);
292
293 #if ENABLE_DNS_OVER_TLS
294 if (s->encrypted) {
295 r = dnstls_stream_on_io(s, revents);
296 if (r == DNSTLS_STREAM_CLOSED)
297 return 0;
298 if (r == -EAGAIN)
299 return dns_stream_update_io(s);
300 if (r < 0)
301 return dns_stream_complete(s, -r);
302
303 r = dns_stream_update_io(s);
304 if (r < 0)
305 return r;
306 }
307 #endif
308
309 /* only identify after connecting */
310 if (s->tfo_salen == 0) {
311 r = dns_stream_identify(s);
312 if (r < 0)
313 return dns_stream_complete(s, -r);
314 }
315
316 if ((revents & EPOLLOUT) &&
317 s->write_packet &&
318 s->n_written < sizeof(s->write_size) + s->write_packet->size) {
319
320 struct iovec iov[2];
321 ssize_t ss;
322
323 iov[0] = IOVEC_MAKE(&s->write_size, sizeof(s->write_size));
324 iov[1] = IOVEC_MAKE(DNS_PACKET_DATA(s->write_packet), s->write_packet->size);
325
326 IOVEC_INCREMENT(iov, 2, s->n_written);
327
328 ss = dns_stream_writev(s, iov, 2, 0);
329 if (ss < 0) {
330 if (!IN_SET(-ss, EINTR, EAGAIN))
331 return dns_stream_complete(s, -ss);
332 } else {
333 progressed = true;
334 s->n_written += ss;
335 }
336
337 /* Are we done? If so, disable the event source for EPOLLOUT */
338 if (s->n_written >= sizeof(s->write_size) + s->write_packet->size) {
339 r = dns_stream_update_io(s);
340 if (r < 0)
341 return dns_stream_complete(s, -r);
342 }
343 }
344
345 if ((revents & (EPOLLIN|EPOLLHUP|EPOLLRDHUP)) &&
346 (!s->read_packet ||
347 s->n_read < sizeof(s->read_size) + s->read_packet->size)) {
348
349 if (s->n_read < sizeof(s->read_size)) {
350 ssize_t ss;
351
352 ss = dns_stream_read(s, (uint8_t*) &s->read_size + s->n_read, sizeof(s->read_size) - s->n_read);
353 if (ss < 0) {
354 if (!IN_SET(-ss, EINTR, EAGAIN))
355 return dns_stream_complete(s, -ss);
356 } else if (ss == 0)
357 return dns_stream_complete(s, ECONNRESET);
358 else {
359 progressed = true;
360 s->n_read += ss;
361 }
362 }
363
364 if (s->n_read >= sizeof(s->read_size)) {
365
366 if (be16toh(s->read_size) < DNS_PACKET_HEADER_SIZE)
367 return dns_stream_complete(s, EBADMSG);
368
369 if (s->n_read < sizeof(s->read_size) + be16toh(s->read_size)) {
370 ssize_t ss;
371
372 if (!s->read_packet) {
373 r = dns_packet_new(&s->read_packet, s->protocol, be16toh(s->read_size), DNS_PACKET_SIZE_MAX);
374 if (r < 0)
375 return dns_stream_complete(s, -r);
376
377 s->read_packet->size = be16toh(s->read_size);
378 s->read_packet->ipproto = IPPROTO_TCP;
379 s->read_packet->family = s->peer.sa.sa_family;
380 s->read_packet->ttl = s->ttl;
381 s->read_packet->ifindex = s->ifindex;
382
383 if (s->read_packet->family == AF_INET) {
384 s->read_packet->sender.in = s->peer.in.sin_addr;
385 s->read_packet->sender_port = be16toh(s->peer.in.sin_port);
386 s->read_packet->destination.in = s->local.in.sin_addr;
387 s->read_packet->destination_port = be16toh(s->local.in.sin_port);
388 } else {
389 assert(s->read_packet->family == AF_INET6);
390 s->read_packet->sender.in6 = s->peer.in6.sin6_addr;
391 s->read_packet->sender_port = be16toh(s->peer.in6.sin6_port);
392 s->read_packet->destination.in6 = s->local.in6.sin6_addr;
393 s->read_packet->destination_port = be16toh(s->local.in6.sin6_port);
394
395 if (s->read_packet->ifindex == 0)
396 s->read_packet->ifindex = s->peer.in6.sin6_scope_id;
397 if (s->read_packet->ifindex == 0)
398 s->read_packet->ifindex = s->local.in6.sin6_scope_id;
399 }
400 }
401
402 ss = dns_stream_read(s,
403 (uint8_t*) DNS_PACKET_DATA(s->read_packet) + s->n_read - sizeof(s->read_size),
404 sizeof(s->read_size) + be16toh(s->read_size) - s->n_read);
405 if (ss < 0) {
406 if (!IN_SET(-ss, EINTR, EAGAIN))
407 return dns_stream_complete(s, -ss);
408 } else if (ss == 0)
409 return dns_stream_complete(s, ECONNRESET);
410 else
411 s->n_read += ss;
412 }
413
414 /* Are we done? If so, disable the event source for EPOLLIN */
415 if (s->n_read >= sizeof(s->read_size) + be16toh(s->read_size)) {
416 /* If there's a packet handler
417 * installed, call that. Note that
418 * this is optional... */
419 if (s->on_packet) {
420 r = s->on_packet(s);
421 if (r < 0)
422 return r;
423 }
424
425 r = dns_stream_update_io(s);
426 if (r < 0)
427 return dns_stream_complete(s, -r);
428 }
429 }
430 }
431
432 /* Call "complete" callback if finished reading and writing one packet, and there's nothing else left
433 * to write. */
434 if (s->type == DNS_STREAM_LLMNR_SEND &&
435 (s->write_packet && s->n_written >= sizeof(s->write_size) + s->write_packet->size) &&
436 ordered_set_isempty(s->write_queue) &&
437 (s->read_packet && s->n_read >= sizeof(s->read_size) + s->read_packet->size))
438 return dns_stream_complete(s, 0);
439
440 /* If we did something, let's restart the timeout event source */
441 if (progressed && s->timeout_event_source) {
442 r = sd_event_source_set_time(s->timeout_event_source, now(clock_boottime_or_monotonic()) + DNS_STREAM_TIMEOUT_USEC);
443 if (r < 0)
444 log_warning_errno(errno, "Couldn't restart TCP connection timeout, ignoring: %m");
445 }
446
447 return 0;
448 }
449
450 static DnsStream *dns_stream_free(DnsStream *s) {
451 DnsPacket *p;
452 Iterator i;
453
454 assert(s);
455
456 dns_stream_stop(s);
457
458 if (s->manager) {
459 LIST_REMOVE(streams, s->manager->dns_streams, s);
460 s->manager->n_dns_streams[s->type]--;
461 }
462
463 #if ENABLE_DNS_OVER_TLS
464 if (s->encrypted)
465 dnstls_stream_free(s);
466 #endif
467
468 ORDERED_SET_FOREACH(p, s->write_queue, i)
469 dns_packet_unref(ordered_set_remove(s->write_queue, p));
470
471 dns_packet_unref(s->write_packet);
472 dns_packet_unref(s->read_packet);
473 dns_server_unref(s->server);
474
475 ordered_set_free(s->write_queue);
476
477 return mfree(s);
478 }
479
480 DEFINE_TRIVIAL_REF_UNREF_FUNC(DnsStream, dns_stream, dns_stream_free);
481
482 int dns_stream_new(
483 Manager *m,
484 DnsStream **ret,
485 DnsStreamType type,
486 DnsProtocol protocol,
487 int fd,
488 const union sockaddr_union *tfo_address) {
489
490 _cleanup_(dns_stream_unrefp) DnsStream *s = NULL;
491 int r;
492
493 assert(m);
494 assert(ret);
495 assert(type >= 0);
496 assert(type < _DNS_STREAM_TYPE_MAX);
497 assert(protocol >= 0);
498 assert(protocol < _DNS_PROTOCOL_MAX);
499 assert(fd >= 0);
500
501 if (m->n_dns_streams[type] > DNS_STREAMS_MAX)
502 return -EBUSY;
503
504 s = new(DnsStream, 1);
505 if (!s)
506 return -ENOMEM;
507
508 *s = (DnsStream) {
509 .n_ref = 1,
510 .fd = -1,
511 .protocol = protocol,
512 };
513
514 r = ordered_set_ensure_allocated(&s->write_queue, &dns_packet_hash_ops);
515 if (r < 0)
516 return r;
517
518 r = sd_event_add_io(m->event, &s->io_event_source, fd, EPOLLIN, on_stream_io, s);
519 if (r < 0)
520 return r;
521
522 (void) sd_event_source_set_description(s->io_event_source, "dns-stream-io");
523
524 r = sd_event_add_time(
525 m->event,
526 &s->timeout_event_source,
527 clock_boottime_or_monotonic(),
528 now(clock_boottime_or_monotonic()) + DNS_STREAM_TIMEOUT_USEC, 0,
529 on_stream_timeout, s);
530 if (r < 0)
531 return r;
532
533 (void) sd_event_source_set_description(s->timeout_event_source, "dns-stream-timeout");
534
535 LIST_PREPEND(streams, m->dns_streams, s);
536 m->n_dns_streams[type]++;
537 s->manager = m;
538
539 s->fd = fd;
540
541 if (tfo_address) {
542 s->tfo_address = *tfo_address;
543 s->tfo_salen = tfo_address->sa.sa_family == AF_INET6 ? sizeof(tfo_address->in6) : sizeof(tfo_address->in);
544 }
545
546 *ret = TAKE_PTR(s);
547
548 return 0;
549 }
550
551 int dns_stream_write_packet(DnsStream *s, DnsPacket *p) {
552 int r;
553
554 assert(s);
555 assert(p);
556
557 r = ordered_set_put(s->write_queue, p);
558 if (r < 0)
559 return r;
560
561 dns_packet_ref(p);
562
563 return dns_stream_update_io(s);
564 }
565
566 DnsPacket *dns_stream_take_read_packet(DnsStream *s) {
567 assert(s);
568
569 if (!s->read_packet)
570 return NULL;
571
572 if (s->n_read < sizeof(s->read_size))
573 return NULL;
574
575 if (s->n_read < sizeof(s->read_size) + be16toh(s->read_size))
576 return NULL;
577
578 s->n_read = 0;
579 return TAKE_PTR(s->read_packet);
580 }
581
582 void dns_stream_detach(DnsStream *s) {
583 assert(s);
584
585 if (!s->server)
586 return;
587
588 if (s->server->stream != s)
589 return;
590
591 dns_server_unref_stream(s->server);
592 }