]> git.ipfire.org Git - thirdparty/systemd.git/blob - src/resolve/resolved-dns-stream.c
resolved: TCP Fast Open and TLS Session Tickets for OpenSSL
[thirdparty/systemd.git] / src / resolve / resolved-dns-stream.c
1 /* SPDX-License-Identifier: LGPL-2.1+ */
2
3 #include <netinet/tcp.h>
4
5 #include "alloc-util.h"
6 #include "fd-util.h"
7 #include "io-util.h"
8 #include "missing.h"
9 #include "resolved-dns-stream.h"
10
/* Delay, relative to stream creation, after which the stream's timeout event
 * source is scheduled to fire. */
#define DNS_STREAM_TIMEOUT_USEC (10 * USEC_PER_SEC)

/* Threshold the manager's stream counter is compared against before a new
 * stream is allocated. */
#define DNS_STREAMS_MAX 128
13
14 static void dns_stream_stop(DnsStream *s) {
15 assert(s);
16
17 s->io_event_source = sd_event_source_unref(s->io_event_source);
18 s->timeout_event_source = sd_event_source_unref(s->timeout_event_source);
19 s->fd = safe_close(s->fd);
20 }
21
22 static int dns_stream_update_io(DnsStream *s) {
23 int f = 0;
24
25 assert(s);
26
27 if (s->write_packet && s->n_written < sizeof(s->write_size) + s->write_packet->size)
28 f |= EPOLLOUT;
29 else if (!ordered_set_isempty(s->write_queue)) {
30 dns_packet_unref(s->write_packet);
31 s->write_packet = ordered_set_steal_first(s->write_queue);
32 s->write_size = htobe16(s->write_packet->size);
33 s->n_written = 0;
34 f |= EPOLLOUT;
35 }
36 if (!s->read_packet || s->n_read < sizeof(s->read_size) + s->read_packet->size)
37 f |= EPOLLIN;
38
39 #if ENABLE_DNS_OVER_TLS
40 /* For handshake and clean closing purposes, TLS can override requested events */
41 if (s->dnstls_events)
42 f = s->dnstls_events;
43 #endif
44
45 return sd_event_source_set_io_events(s->io_event_source, f);
46 }
47
48 static int dns_stream_complete(DnsStream *s, int error) {
49 assert(s);
50
51 #if ENABLE_DNS_OVER_TLS
52 if (s->encrypted) {
53 int r;
54
55 r = dnstls_stream_shutdown(s, error);
56 if (r != -EAGAIN)
57 dns_stream_stop(s);
58 } else
59 #endif
60 dns_stream_stop(s);
61
62 if (s->complete)
63 s->complete(s, error);
64 else /* the default action if no completion function is set is to close the stream */
65 dns_stream_unref(s);
66
67 return 0;
68 }
69
70 static int dns_stream_identify(DnsStream *s) {
71 union {
72 struct cmsghdr header; /* For alignment */
73 uint8_t buffer[CMSG_SPACE(MAXSIZE(struct in_pktinfo, struct in6_pktinfo))
74 + EXTRA_CMSG_SPACE /* kernel appears to require extra space */];
75 } control;
76 struct msghdr mh = {};
77 struct cmsghdr *cmsg;
78 socklen_t sl;
79 int r;
80
81 assert(s);
82
83 if (s->identified)
84 return 0;
85
86 /* Query the local side */
87 s->local_salen = sizeof(s->local);
88 r = getsockname(s->fd, &s->local.sa, &s->local_salen);
89 if (r < 0)
90 return -errno;
91 if (s->local.sa.sa_family == AF_INET6 && s->ifindex <= 0)
92 s->ifindex = s->local.in6.sin6_scope_id;
93
94 /* Query the remote side */
95 s->peer_salen = sizeof(s->peer);
96 r = getpeername(s->fd, &s->peer.sa, &s->peer_salen);
97 if (r < 0)
98 return -errno;
99 if (s->peer.sa.sa_family == AF_INET6 && s->ifindex <= 0)
100 s->ifindex = s->peer.in6.sin6_scope_id;
101
102 /* Check consistency */
103 assert(s->peer.sa.sa_family == s->local.sa.sa_family);
104 assert(IN_SET(s->peer.sa.sa_family, AF_INET, AF_INET6));
105
106 /* Query connection meta information */
107 sl = sizeof(control);
108 if (s->peer.sa.sa_family == AF_INET) {
109 r = getsockopt(s->fd, IPPROTO_IP, IP_PKTOPTIONS, &control, &sl);
110 if (r < 0)
111 return -errno;
112 } else if (s->peer.sa.sa_family == AF_INET6) {
113
114 r = getsockopt(s->fd, IPPROTO_IPV6, IPV6_2292PKTOPTIONS, &control, &sl);
115 if (r < 0)
116 return -errno;
117 } else
118 return -EAFNOSUPPORT;
119
120 mh.msg_control = &control;
121 mh.msg_controllen = sl;
122
123 CMSG_FOREACH(cmsg, &mh) {
124
125 if (cmsg->cmsg_level == IPPROTO_IPV6) {
126 assert(s->peer.sa.sa_family == AF_INET6);
127
128 switch (cmsg->cmsg_type) {
129
130 case IPV6_PKTINFO: {
131 struct in6_pktinfo *i = (struct in6_pktinfo*) CMSG_DATA(cmsg);
132
133 if (s->ifindex <= 0)
134 s->ifindex = i->ipi6_ifindex;
135 break;
136 }
137
138 case IPV6_HOPLIMIT:
139 s->ttl = *(int *) CMSG_DATA(cmsg);
140 break;
141 }
142
143 } else if (cmsg->cmsg_level == IPPROTO_IP) {
144 assert(s->peer.sa.sa_family == AF_INET);
145
146 switch (cmsg->cmsg_type) {
147
148 case IP_PKTINFO: {
149 struct in_pktinfo *i = (struct in_pktinfo*) CMSG_DATA(cmsg);
150
151 if (s->ifindex <= 0)
152 s->ifindex = i->ipi_ifindex;
153 break;
154 }
155
156 case IP_TTL:
157 s->ttl = *(int *) CMSG_DATA(cmsg);
158 break;
159 }
160 }
161 }
162
163 /* The Linux kernel sets the interface index to the loopback
164 * device if the connection came from the local host since it
165 * avoids the routing table in such a case. Let's unset the
166 * interface index in such a case. */
167 if (s->ifindex == LOOPBACK_IFINDEX)
168 s->ifindex = 0;
169
170 /* If we don't know the interface index still, we look for the
171 * first local interface with a matching address. Yuck! */
172 if (s->ifindex <= 0)
173 s->ifindex = manager_find_ifindex(s->manager, s->local.sa.sa_family, s->local.sa.sa_family == AF_INET ? (union in_addr_union*) &s->local.in.sin_addr : (union in_addr_union*) &s->local.in6.sin6_addr);
174
175 if (s->protocol == DNS_PROTOCOL_LLMNR && s->ifindex > 0) {
176 uint32_t ifindex = htobe32(s->ifindex);
177
178 /* Make sure all packets for this connection are sent on the same interface */
179 if (s->local.sa.sa_family == AF_INET) {
180 r = setsockopt(s->fd, IPPROTO_IP, IP_UNICAST_IF, &ifindex, sizeof(ifindex));
181 if (r < 0)
182 log_debug_errno(errno, "Failed to invoke IP_UNICAST_IF: %m");
183 } else if (s->local.sa.sa_family == AF_INET6) {
184 r = setsockopt(s->fd, IPPROTO_IPV6, IPV6_UNICAST_IF, &ifindex, sizeof(ifindex));
185 if (r < 0)
186 log_debug_errno(errno, "Failed to invoke IPV6_UNICAST_IF: %m");
187 }
188 }
189
190 s->identified = true;
191
192 return 0;
193 }
194
195 ssize_t dns_stream_writev(DnsStream *s, const struct iovec *iov, size_t iovcnt, int flags) {
196 ssize_t r;
197
198 assert(s);
199 assert(iov);
200
201 #if ENABLE_DNS_OVER_TLS
202 if (s->encrypted && !(flags & DNS_STREAM_WRITE_TLS_DATA)) {
203 ssize_t ss;
204 size_t i;
205
206 r = 0;
207 for (i = 0; i < iovcnt; i++) {
208 ss = dnstls_stream_write(s, iov[i].iov_base, iov[i].iov_len);
209 if (ss < 0)
210 return ss;
211
212 r += ss;
213 if (ss != (ssize_t) iov[i].iov_len)
214 continue;
215 }
216 } else
217 #endif
218 if (s->tfo_salen > 0) {
219 struct msghdr hdr = {
220 .msg_iov = (struct iovec*) iov,
221 .msg_iovlen = iovcnt,
222 .msg_name = &s->tfo_address.sa,
223 .msg_namelen = s->tfo_salen
224 };
225
226 r = sendmsg(s->fd, &hdr, MSG_FASTOPEN);
227 if (r < 0) {
228 if (errno == EOPNOTSUPP) {
229 s->tfo_salen = 0;
230 r = connect(s->fd, &s->tfo_address.sa, s->tfo_salen);
231 if (r < 0)
232 return -errno;
233
234 r = -EAGAIN;
235 } else if (errno == EINPROGRESS)
236 r = -EAGAIN;
237 else
238 r = -errno;
239 } else
240 s->tfo_salen = 0; /* connection is made */
241 } else {
242 r = writev(s->fd, iov, iovcnt);
243 if (r < 0)
244 r = -errno;
245 }
246
247 return r;
248 }
249
250 static ssize_t dns_stream_read(DnsStream *s, void *buf, size_t count) {
251 ssize_t ss;
252
253 #if ENABLE_DNS_OVER_TLS
254 if (s->encrypted)
255 ss = dnstls_stream_read(s, buf, count);
256 else
257 #endif
258 {
259 ss = read(s->fd, buf, count);
260 if (ss < 0)
261 ss = -errno;
262 }
263
264 return ss;
265 }
266
267 static int on_stream_timeout(sd_event_source *es, usec_t usec, void *userdata) {
268 DnsStream *s = userdata;
269
270 assert(s);
271
272 return dns_stream_complete(s, ETIMEDOUT);
273 }
274
275 static int on_stream_io(sd_event_source *es, int fd, uint32_t revents, void *userdata) {
276 DnsStream *s = userdata;
277 int r;
278
279 assert(s);
280
281 #if ENABLE_DNS_OVER_TLS
282 if (s->encrypted) {
283 r = dnstls_stream_on_io(s, revents);
284
285 if (r == DNSTLS_STREAM_CLOSED)
286 return 0;
287 else if (r == -EAGAIN)
288 return dns_stream_update_io(s);
289 else if (r < 0) {
290 return dns_stream_complete(s, -r);
291 } else {
292 r = dns_stream_update_io(s);
293 if (r < 0)
294 return r;
295 }
296 }
297 #endif
298
299 /* only identify after connecting */
300 if (s->tfo_salen == 0) {
301 r = dns_stream_identify(s);
302 if (r < 0)
303 return dns_stream_complete(s, -r);
304 }
305
306 if ((revents & EPOLLOUT) &&
307 s->write_packet &&
308 s->n_written < sizeof(s->write_size) + s->write_packet->size) {
309
310 struct iovec iov[2];
311 ssize_t ss;
312
313 iov[0].iov_base = &s->write_size;
314 iov[0].iov_len = sizeof(s->write_size);
315 iov[1].iov_base = DNS_PACKET_DATA(s->write_packet);
316 iov[1].iov_len = s->write_packet->size;
317
318 IOVEC_INCREMENT(iov, 2, s->n_written);
319
320 ss = dns_stream_writev(s, iov, 2, 0);
321 if (ss < 0) {
322 if (!IN_SET(-ss, EINTR, EAGAIN))
323 return dns_stream_complete(s, -ss);
324 } else
325 s->n_written += ss;
326
327 /* Are we done? If so, disable the event source for EPOLLOUT */
328 if (s->n_written >= sizeof(s->write_size) + s->write_packet->size) {
329 r = dns_stream_update_io(s);
330 if (r < 0)
331 return dns_stream_complete(s, -r);
332 }
333 }
334
335 if ((revents & (EPOLLIN|EPOLLHUP|EPOLLRDHUP)) &&
336 (!s->read_packet ||
337 s->n_read < sizeof(s->read_size) + s->read_packet->size)) {
338
339 if (s->n_read < sizeof(s->read_size)) {
340 ssize_t ss;
341
342 ss = dns_stream_read(s, (uint8_t*) &s->read_size + s->n_read, sizeof(s->read_size) - s->n_read);
343 if (ss < 0) {
344 if (!IN_SET(-ss, EINTR, EAGAIN))
345 return dns_stream_complete(s, -ss);
346 } else if (ss == 0)
347 return dns_stream_complete(s, ECONNRESET);
348 else
349 s->n_read += ss;
350 }
351
352 if (s->n_read >= sizeof(s->read_size)) {
353
354 if (be16toh(s->read_size) < DNS_PACKET_HEADER_SIZE)
355 return dns_stream_complete(s, EBADMSG);
356
357 if (s->n_read < sizeof(s->read_size) + be16toh(s->read_size)) {
358 ssize_t ss;
359
360 if (!s->read_packet) {
361 r = dns_packet_new(&s->read_packet, s->protocol, be16toh(s->read_size), DNS_PACKET_SIZE_MAX);
362 if (r < 0)
363 return dns_stream_complete(s, -r);
364
365 s->read_packet->size = be16toh(s->read_size);
366 s->read_packet->ipproto = IPPROTO_TCP;
367 s->read_packet->family = s->peer.sa.sa_family;
368 s->read_packet->ttl = s->ttl;
369 s->read_packet->ifindex = s->ifindex;
370
371 if (s->read_packet->family == AF_INET) {
372 s->read_packet->sender.in = s->peer.in.sin_addr;
373 s->read_packet->sender_port = be16toh(s->peer.in.sin_port);
374 s->read_packet->destination.in = s->local.in.sin_addr;
375 s->read_packet->destination_port = be16toh(s->local.in.sin_port);
376 } else {
377 assert(s->read_packet->family == AF_INET6);
378 s->read_packet->sender.in6 = s->peer.in6.sin6_addr;
379 s->read_packet->sender_port = be16toh(s->peer.in6.sin6_port);
380 s->read_packet->destination.in6 = s->local.in6.sin6_addr;
381 s->read_packet->destination_port = be16toh(s->local.in6.sin6_port);
382
383 if (s->read_packet->ifindex == 0)
384 s->read_packet->ifindex = s->peer.in6.sin6_scope_id;
385 if (s->read_packet->ifindex == 0)
386 s->read_packet->ifindex = s->local.in6.sin6_scope_id;
387 }
388 }
389
390 ss = dns_stream_read(s,
391 (uint8_t*) DNS_PACKET_DATA(s->read_packet) + s->n_read - sizeof(s->read_size),
392 sizeof(s->read_size) + be16toh(s->read_size) - s->n_read);
393 if (ss < 0) {
394 if (!IN_SET(errno, EINTR, EAGAIN))
395 return dns_stream_complete(s, errno);
396 } else if (ss == 0)
397 return dns_stream_complete(s, ECONNRESET);
398 else
399 s->n_read += ss;
400 }
401
402 /* Are we done? If so, disable the event source for EPOLLIN */
403 if (s->n_read >= sizeof(s->read_size) + be16toh(s->read_size)) {
404 /* If there's a packet handler
405 * installed, call that. Note that
406 * this is optional... */
407 if (s->on_packet) {
408 r = s->on_packet(s);
409 if (r < 0)
410 return r;
411 }
412
413 r = dns_stream_update_io(s);
414 if (r < 0)
415 return dns_stream_complete(s, -r);
416 }
417 }
418 }
419
420 if ((s->write_packet && s->n_written >= sizeof(s->write_size) + s->write_packet->size) &&
421 (s->read_packet && s->n_read >= sizeof(s->read_size) + s->read_packet->size))
422 return dns_stream_complete(s, 0);
423
424 return 0;
425 }
426
427 DnsStream *dns_stream_unref(DnsStream *s) {
428 DnsPacket *p;
429 Iterator i;
430
431 if (!s)
432 return NULL;
433
434 assert(s->n_ref > 0);
435 s->n_ref--;
436
437 if (s->n_ref > 0)
438 return NULL;
439
440 dns_stream_stop(s);
441
442 if (s->server && s->server->stream == s)
443 s->server->stream = NULL;
444
445 if (s->manager) {
446 LIST_REMOVE(streams, s->manager->dns_streams, s);
447 s->manager->n_dns_streams--;
448 }
449
450 #if ENABLE_DNS_OVER_TLS
451 if (s->encrypted)
452 dnstls_stream_free(s);
453 #endif
454
455 ORDERED_SET_FOREACH(p, s->write_queue, i)
456 dns_packet_unref(ordered_set_remove(s->write_queue, p));
457
458 dns_packet_unref(s->write_packet);
459 dns_packet_unref(s->read_packet);
460 dns_server_unref(s->server);
461
462 ordered_set_free(s->write_queue);
463
464 return mfree(s);
465 }
466
467 DnsStream *dns_stream_ref(DnsStream *s) {
468 if (!s)
469 return NULL;
470
471 assert(s->n_ref > 0);
472 s->n_ref++;
473
474 return s;
475 }
476
477 int dns_stream_new(Manager *m, DnsStream **ret, DnsProtocol protocol, int fd, const union sockaddr_union *tfo_address) {
478 _cleanup_(dns_stream_unrefp) DnsStream *s = NULL;
479 int r;
480
481 assert(m);
482 assert(fd >= 0);
483
484 if (m->n_dns_streams > DNS_STREAMS_MAX)
485 return -EBUSY;
486
487 s = new0(DnsStream, 1);
488 if (!s)
489 return -ENOMEM;
490
491 r = ordered_set_ensure_allocated(&s->write_queue, &dns_packet_hash_ops);
492 if (r < 0)
493 return r;
494
495 s->n_ref = 1;
496 s->fd = -1;
497 s->protocol = protocol;
498
499 r = sd_event_add_io(m->event, &s->io_event_source, fd, EPOLLIN, on_stream_io, s);
500 if (r < 0)
501 return r;
502
503 (void) sd_event_source_set_description(s->io_event_source, "dns-stream-io");
504
505 r = sd_event_add_time(
506 m->event,
507 &s->timeout_event_source,
508 clock_boottime_or_monotonic(),
509 now(clock_boottime_or_monotonic()) + DNS_STREAM_TIMEOUT_USEC, 0,
510 on_stream_timeout, s);
511 if (r < 0)
512 return r;
513
514 (void) sd_event_source_set_description(s->timeout_event_source, "dns-stream-timeout");
515
516 LIST_PREPEND(streams, m->dns_streams, s);
517 s->manager = m;
518 s->fd = fd;
519 if (tfo_address) {
520 s->tfo_address = *tfo_address;
521 s->tfo_salen = tfo_address->sa.sa_family == AF_INET6 ? sizeof(tfo_address->in6) : sizeof(tfo_address->in);
522 }
523
524 m->n_dns_streams++;
525
526 *ret = TAKE_PTR(s);
527
528 return 0;
529 }
530
531 int dns_stream_write_packet(DnsStream *s, DnsPacket *p) {
532 int r;
533
534 assert(s);
535
536 r = ordered_set_put(s->write_queue, p);
537 if (r < 0)
538 return r;
539
540 dns_packet_ref(p);
541
542 return dns_stream_update_io(s);
543 }