1 /* SPDX-License-Identifier: LGPL-2.1-or-later */
3 #include <netinet/tcp.h>
6 #include "alloc-util.h"
10 #include "missing_network.h"
11 #include "resolved-dns-stream.h"
12 #include "resolved-manager.h"
/* Limit that the per-type stream counter of the manager (n_dns_streams[type]) is compared against when
 * a new stream is set up further down in this file. */
14 #define DNS_STREAMS_MAX 128
/* A stream only keeps asking for more input while set_size(s->queries) is below this value (see
 * dns_stream_update_io()). */
16 #define DNS_QUERIES_PER_STREAM 32
/* Shuts down the event handling and the socket of a stream: the I/O and the timeout event source are
 * both passed to sd_event_source_disable_unref(), the file descriptor is passed to safe_close(), and in
 * each case the helper's return value is stored back into the corresponding field.
 *
 * NOTE(review): the embedded line numbers skip, so parts of this function (for instance the statement
 * that implements the "disconnect from the server" remark at its end) are not visible here. */
18 static void dns_stream_stop(DnsStream
*s
) {
21 s
->io_event_source
= sd_event_source_disable_unref(s
->io_event_source
);
22 s
->timeout_event_source
= sd_event_source_disable_unref(s
->timeout_event_source
);
23 s
->fd
= safe_close(s
->fd
);
25 /* Disconnect us from the server object if we are now not usable anymore */
/* Works out which I/O events the stream should wait for and installs them on s->io_event_source with
 * sd_event_source_set_io_events(); the result of that call is what the visible return statement
 * returns.
 *
 * Decisions visible here:
 *  - a write is still pending while n_written is below sizeof(write_size) plus the packet size;
 *  - otherwise, if the write queue is not empty, the next queued packet becomes the current one;
 *  - reading is wanted while no packet has been fully read and fewer than DNS_QUERIES_PER_STREAM
 *    queries are attached to the stream;
 *  - with DNS-over-TLS compiled in, a non-zero s->dnstls_events is checked before the mask is applied.
 *
 * NOTE(review): the statements that actually build the mask 'f' are not visible in this listing. */
29 static int dns_stream_update_io(DnsStream
*s
) {
34 if (s
->write_packet
&& s
->n_written
< sizeof(s
->write_size
) + s
->write_packet
->size
)
36 else if (!ordered_set_isempty(s
->write_queue
)) {
/* Drop the reference to the current write packet, take the first queued packet instead and store
 * its size, converted with htobe16(), as the length prefix to send. */
37 dns_packet_unref(s
->write_packet
);
38 s
->write_packet
= ordered_set_steal_first(s
->write_queue
);
39 s
->write_size
= htobe16(s
->write_packet
->size
);
44 /* Let's read a packet if we haven't queued any yet. Except if we already hit a limit of parallel
45 * queries for this connection. */
46 if ((!s
->read_packet
|| s
->n_read
< sizeof(s
->read_size
) + s
->read_packet
->size
) &&
47 set_size(s
->queries
) < DNS_QUERIES_PER_STREAM
)
50 #if ENABLE_DNS_OVER_TLS
51 /* For handshake and clean closing purposes, TLS can override requested events */
52 if (s
->dnstls_events
!= 0)
56 return sd_event_source_set_io_events(s
->io_event_source
, f
);
/* Finishes processing of a stream with the given error code.
 *
 * A temporary reference is taken through dns_stream_ref() and released automatically by the
 * _cleanup_(dns_stream_unrefp) attribute. With DNS-over-TLS compiled in, dnstls_stream_shutdown() is
 * called with the same error. If a completion callback is installed it is invoked as
 * s->complete(s, error); per the comment on the else branch, the fallback is to close the stream.
 *
 * NOTE(review): the conditions around these calls and the return statements are not visible here. */
59 static int dns_stream_complete(DnsStream
*s
, int error
) {
60 _cleanup_(dns_stream_unrefp
) _unused_ DnsStream
*ref
= dns_stream_ref(s
); /* Protect stream while we process it */
65 /* Error is > 0 when the connection failed for some reason in the network stack. It's == 0 if we sent
66 * and received exactly one packet each (in the LLMNR client case). */
68 #if ENABLE_DNS_OVER_TLS
72 r
= dnstls_stream_shutdown(s
, error
);
82 s
->complete(s
, error
);
83 else /* the default action if no completion function is set is to close the stream */
/* Collects addressing metadata for a stream and records it in the DnsStream object.
 *
 * Steps visible here:
 *  1. getsockname() fills s->local, getpeername() fills s->peer; for AF_INET6 the scope id of either
 *     address is used as interface index if none is known yet (s->ifindex <= 0).
 *  2. Both addresses are asserted to share one family, which must be AF_INET or AF_INET6.
 *  3. getsockopt() fetches the packet options (IP_PKTOPTIONS for AF_INET, IPV6_2292PKTOPTIONS for
 *     AF_INET6) into 'control'; another family leads to -EAFNOSUPPORT.
 *  4. The returned control messages are walked with CMSG_FOREACH() and feed s->ifindex (from the
 *     pktinfo structures) and s->ttl (read as an int from the message payload).
 *  5. The interface index is compared against LOOPBACK_IFINDEX and, if still unknown, looked up with
 *     manager_find_ifindex() using the local address.
 *  6. LLMNR streams with a known interface are bound to it via socket_set_unicast_if().
 *  7. s->identified is set to true.
 *
 * NOTE(review): the error checks after the system calls, the case labels of both switch statements
 * and the final return are not visible in this listing. */
89 static int dns_stream_identify(DnsStream
*s
) {
90 CMSG_BUFFER_TYPE(CMSG_SPACE(MAXSIZE(struct in_pktinfo
, struct in6_pktinfo
))
91 + CMSG_SPACE(int) + /* for the TTL */
92 + EXTRA_CMSG_SPACE
/* kernel appears to require extra space */) control
;
93 struct msghdr mh
= {};
103 /* Query the local side */
104 s
->local_salen
= sizeof(s
->local
);
105 r
= getsockname(s
->fd
, &s
->local
.sa
, &s
->local_salen
);
108 if (s
->local
.sa
.sa_family
== AF_INET6
&& s
->ifindex
<= 0)
109 s
->ifindex
= s
->local
.in6
.sin6_scope_id
;
111 /* Query the remote side */
112 s
->peer_salen
= sizeof(s
->peer
);
113 r
= getpeername(s
->fd
, &s
->peer
.sa
, &s
->peer_salen
);
116 if (s
->peer
.sa
.sa_family
== AF_INET6
&& s
->ifindex
<= 0)
117 s
->ifindex
= s
->peer
.in6
.sin6_scope_id
;
119 /* Check consistency */
120 assert(s
->peer
.sa
.sa_family
== s
->local
.sa
.sa_family
);
121 assert(IN_SET(s
->peer
.sa
.sa_family
, AF_INET
, AF_INET6
));
123 /* Query connection meta information */
124 sl
= sizeof(control
);
125 if (s
->peer
.sa
.sa_family
== AF_INET
) {
126 r
= getsockopt(s
->fd
, IPPROTO_IP
, IP_PKTOPTIONS
, &control
, &sl
);
129 } else if (s
->peer
.sa
.sa_family
== AF_INET6
) {
131 r
= getsockopt(s
->fd
, IPPROTO_IPV6
, IPV6_2292PKTOPTIONS
, &control
, &sl
);
135 return -EAFNOSUPPORT
;
/* Wrap the option buffer in a msghdr so that it can be iterated like received ancillary data; 'sl'
 * holds the length that getsockopt() reported. */
137 mh
.msg_control
= &control
;
138 mh
.msg_controllen
= sl
;
140 CMSG_FOREACH(cmsg
, &mh
) {
142 if (cmsg
->cmsg_level
== IPPROTO_IPV6
) {
143 assert(s
->peer
.sa
.sa_family
== AF_INET6
);
145 switch (cmsg
->cmsg_type
) {
148 struct in6_pktinfo
*i
= (struct in6_pktinfo
*) CMSG_DATA(cmsg
);
151 s
->ifindex
= i
->ipi6_ifindex
;
156 s
->ttl
= *(int *) CMSG_DATA(cmsg
);
160 } else if (cmsg
->cmsg_level
== IPPROTO_IP
) {
161 assert(s
->peer
.sa
.sa_family
== AF_INET
);
163 switch (cmsg
->cmsg_type
) {
166 struct in_pktinfo
*i
= (struct in_pktinfo
*) CMSG_DATA(cmsg
);
169 s
->ifindex
= i
->ipi_ifindex
;
174 s
->ttl
= *(int *) CMSG_DATA(cmsg
);
180 /* The Linux kernel sets the interface index to the loopback
181 * device if the connection came from the local host since it
182 * avoids the routing table in such a case. Let's unset the
183 * interface index in such a case. */
184 if (s
->ifindex
== LOOPBACK_IFINDEX
)
187 /* If we don't know the interface index still, we look for the
188 * first local interface with a matching address. Yuck! */
190 s
->ifindex
= manager_find_ifindex(s
->manager
, s
->local
.sa
.sa_family
, sockaddr_in_addr(&s
->local
.sa
));
192 if (s
->protocol
== DNS_PROTOCOL_LLMNR
&& s
->ifindex
> 0) {
193 /* Make sure all packets for this connection are sent on the same interface */
194 r
= socket_set_unicast_if(s
->fd
, s
->local
.sa
.sa_family
, s
->ifindex
);
/* NOTE(review): the result of socket_set_unicast_if() is stored in 'r', yet the message below is
 * logged with 'errno'. Confirm which of the two actually carries the error code here. */
196 log_debug_errno(errno
, "Failed to invoke IP_UNICAST_IF/IPV6_UNICAST_IF: %m");
199 s
->identified
= true;
/* Writes the given iovec array to the stream and hands back a byte count or error as ssize_t.
 *
 * Paths visible here:
 *  - DNS-over-TLS builds: if the stream is encrypted and DNS_STREAM_WRITE_TLS_DATA is not set in
 *    'flags', every iovec element is passed to dnstls_stream_write() on its own, and the result is
 *    compared against the element's length.
 *  - TCP Fast Open: while s->tfo_salen > 0 the data is sent with sendmsg(MSG_FASTOPEN) to
 *    s->tfo_address; EOPNOTSUPP leads to a plain connect(), EINPROGRESS is checked afterwards, and
 *    s->tfo_salen is reset to 0 once the connection is made.
 *  - Otherwise: writev() on s->fd.
 *
 * NOTE(review): the error handling and return statements are not visible in this listing. */
204 ssize_t
dns_stream_writev(DnsStream
*s
, const struct iovec
*iov
, size_t iovcnt
, int flags
) {
210 #if ENABLE_DNS_OVER_TLS
211 if (s
->encrypted
&& !(flags
& DNS_STREAM_WRITE_TLS_DATA
)) {
216 for (i
= 0; i
< iovcnt
; i
++) {
217 ss
= dnstls_stream_write(s
, iov
[i
].iov_base
, iov
[i
].iov_len
);
222 if (ss
!= (ssize_t
) iov
[i
].iov_len
)
227 if (s
->tfo_salen
> 0) {
228 struct msghdr hdr
= {
229 .msg_iov
= (struct iovec
*) iov
,
230 .msg_iovlen
= iovcnt
,
231 .msg_name
= &s
->tfo_address
.sa
,
232 .msg_namelen
= s
->tfo_salen
235 m
= sendmsg(s
->fd
, &hdr
, MSG_FASTOPEN
);
237 if (errno
== EOPNOTSUPP
) {
239 if (connect(s
->fd
, &s
->tfo_address
.sa
, s
->tfo_salen
) < 0)
244 if (errno
== EINPROGRESS
)
249 s
->tfo_salen
= 0; /* connection is made */
251 m
= writev(s
->fd
, iov
, iovcnt
);
/* Reads up to 'count' bytes from the stream into 'buf': through dnstls_stream_read() in the
 * DNS-over-TLS section, through read() on s->fd otherwise.
 *
 * NOTE(review): the condition selecting between both paths and the translation of the result into
 * the return value are not visible in this listing. */
259 static ssize_t
dns_stream_read(DnsStream
*s
, void *buf
, size_t count
) {
262 #if ENABLE_DNS_OVER_TLS
264 ss
= dnstls_stream_read(s
, buf
, count
);
268 ss
= read(s
->fd
, buf
, count
);
/* Timer callback (registered together with the timeout event source further down in this file):
 * 'userdata' is the DnsStream, which is completed with ETIMEDOUT; the result of
 * dns_stream_complete() is returned. */
276 static int on_stream_timeout(sd_event_source
*es
, usec_t usec
, void *userdata
) {
277 DnsStream
*s
= userdata
;
281 return dns_stream_complete(s
, ETIMEDOUT
);
/* Handles one set of I/O events ('revents') for a stream.
 *
 * Phases visible here, in order:
 *  - DNS-over-TLS builds: dnstls_stream_on_io() runs first; its result is compared against
 *    DNSTLS_STREAM_CLOSED and may lead to dns_stream_update_io() or to completing the stream with -r.
 *  - Identification: dns_stream_identify() is called only once s->tfo_salen == 0, i.e. after
 *    connecting.
 *  - Writing (EPOLLOUT): the length prefix and the packet data form a two-element iovec that is
 *    advanced by s->n_written and passed to dns_stream_writev(); once everything is written the I/O
 *    mask is refreshed with dns_stream_update_io().
 *  - Reading (EPOLLIN, EPOLLHUP or EPOLLRDHUP): first the length prefix is read into s->read_size, then
 *    a packet of be16toh(read_size) bytes is allocated with dns_packet_new(), tagged with the
 *    connection metadata (family, TTL, interface, timestamp, sender and destination address/port)
 *    and filled by further reads. A length below DNS_PACKET_HEADER_SIZE completes the stream with
 *    EBADMSG.
 *  - LLMNR send streams are completed with 0 once one packet was fully written and one fully read
 *    and the write queue is empty.
 *  - If progress was made and a timeout source exists, the timeout is re-armed with
 *    DNS_STREAM_ESTABLISHED_TIMEOUT_USEC.
 *
 * Failures are reported by completing the stream with a positive errno value (negated r/ss, or
 * ECONNRESET/EBADMSG directly).
 *
 * NOTE(review): many conditions, the packet handler invocation and the return statements of this
 * function are not visible in this listing; the description covers only what is shown. */
284 static int on_stream_io_impl(DnsStream
*s
, uint32_t revents
) {
285 bool progressed
= false;
290 /* This returns 1 when possible remaining stream exists, 0 on completed
291 stream or recoverable error, and negative errno on failure. */
293 #if ENABLE_DNS_OVER_TLS
295 r
= dnstls_stream_on_io(s
, revents
);
296 if (r
== DNSTLS_STREAM_CLOSED
)
299 return dns_stream_update_io(s
);
301 return dns_stream_complete(s
, -r
);
303 r
= dns_stream_update_io(s
);
309 /* only identify after connecting */
310 if (s
->tfo_salen
== 0) {
311 r
= dns_stream_identify(s
);
313 return dns_stream_complete(s
, -r
);
316 if ((revents
& EPOLLOUT
) &&
318 s
->n_written
< sizeof(s
->write_size
) + s
->write_packet
->size
) {
/* Send the length prefix followed by the packet payload, skipping the s->n_written bytes that
 * earlier calls already wrote. */
320 struct iovec iov
[] = {
321 IOVEC_MAKE(&s
->write_size
, sizeof(s
->write_size
)),
322 IOVEC_MAKE(DNS_PACKET_DATA(s
->write_packet
), s
->write_packet
->size
),
325 IOVEC_INCREMENT(iov
, ELEMENTSOF(iov
), s
->n_written
);
327 ssize_t ss
= dns_stream_writev(s
, iov
, ELEMENTSOF(iov
), 0);
/* Anything that ERRNO_IS_TRANSIENT() does not accept ends the stream with that error. */
329 if (!ERRNO_IS_TRANSIENT(ss
))
330 return dns_stream_complete(s
, -ss
);
336 /* Are we done? If so, disable the event source for EPOLLOUT */
337 if (s
->n_written
>= sizeof(s
->write_size
) + s
->write_packet
->size
) {
338 r
= dns_stream_update_io(s
);
340 return dns_stream_complete(s
, -r
);
344 if ((revents
& (EPOLLIN
|EPOLLHUP
|EPOLLRDHUP
)) &&
346 s
->n_read
< sizeof(s
->read_size
) + s
->read_packet
->size
)) {
/* Stage one: the length prefix itself has not been read completely yet. */
348 if (s
->n_read
< sizeof(s
->read_size
)) {
351 ss
= dns_stream_read(s
, (uint8_t*) &s
->read_size
+ s
->n_read
, sizeof(s
->read_size
) - s
->n_read
);
353 if (!ERRNO_IS_TRANSIENT(ss
))
354 return dns_stream_complete(s
, -ss
);
356 return dns_stream_complete(s
, ECONNRESET
);
/* Stage two: the length prefix is known, now deal with the packet body. */
363 if (s
->n_read
>= sizeof(s
->read_size
)) {
365 if (be16toh(s
->read_size
) < DNS_PACKET_HEADER_SIZE
)
366 return dns_stream_complete(s
, EBADMSG
);
368 if (s
->n_read
< sizeof(s
->read_size
) + be16toh(s
->read_size
)) {
371 if (!s
->read_packet
) {
372 r
= dns_packet_new(&s
->read_packet
, s
->protocol
, be16toh(s
->read_size
), DNS_PACKET_SIZE_MAX
);
374 return dns_stream_complete(s
, -r
);
/* Label the fresh packet with what is known about the connection. */
376 s
->read_packet
->size
= be16toh(s
->read_size
);
377 s
->read_packet
->ipproto
= IPPROTO_TCP
;
378 s
->read_packet
->family
= s
->peer
.sa
.sa_family
;
379 s
->read_packet
->ttl
= s
->ttl
;
380 s
->read_packet
->ifindex
= s
->ifindex
;
381 s
->read_packet
->timestamp
= now(clock_boottime_or_monotonic());
383 if (s
->read_packet
->family
== AF_INET
) {
384 s
->read_packet
->sender
.in
= s
->peer
.in
.sin_addr
;
385 s
->read_packet
->sender_port
= be16toh(s
->peer
.in
.sin_port
);
386 s
->read_packet
->destination
.in
= s
->local
.in
.sin_addr
;
387 s
->read_packet
->destination_port
= be16toh(s
->local
.in
.sin_port
);
389 assert(s
->read_packet
->family
== AF_INET6
);
390 s
->read_packet
->sender
.in6
= s
->peer
.in6
.sin6_addr
;
391 s
->read_packet
->sender_port
= be16toh(s
->peer
.in6
.sin6_port
);
392 s
->read_packet
->destination
.in6
= s
->local
.in6
.sin6_addr
;
393 s
->read_packet
->destination_port
= be16toh(s
->local
.in6
.sin6_port
);
/* For IPv6, fall back to the scope id of the peer and then of the local address while the
 * packet's interface index is still 0. */
395 if (s
->read_packet
->ifindex
== 0)
396 s
->read_packet
->ifindex
= s
->peer
.in6
.sin6_scope_id
;
397 if (s
->read_packet
->ifindex
== 0)
398 s
->read_packet
->ifindex
= s
->local
.in6
.sin6_scope_id
;
/* Continue reading the body at the offset that corresponds to the bytes received so far. */
402 ss
= dns_stream_read(s
,
403 (uint8_t*) DNS_PACKET_DATA(s
->read_packet
) + s
->n_read
- sizeof(s
->read_size
),
404 sizeof(s
->read_size
) + be16toh(s
->read_size
) - s
->n_read
);
406 if (!ERRNO_IS_TRANSIENT(ss
))
407 return dns_stream_complete(s
, -ss
);
409 return dns_stream_complete(s
, ECONNRESET
);
414 /* Are we done? If so, disable the event source for EPOLLIN */
415 if (s
->n_read
>= sizeof(s
->read_size
) + be16toh(s
->read_size
)) {
416 /* If there's a packet handler
417 * installed, call that. Note that
418 * this is optional... */
425 r
= dns_stream_update_io(s
);
427 return dns_stream_complete(s
, -r
);
/* NOTE(review): the comment that follows has no closing delimiter on its line; its continuation
 * line appears to be missing from this listing. */
432 /* Call "complete" callback if finished reading and writing one packet, and there's nothing else left
434 if (s
->type
== DNS_STREAM_LLMNR_SEND
&&
435 (s
->write_packet
&& s
->n_written
>= sizeof(s
->write_size
) + s
->write_packet
->size
) &&
436 ordered_set_isempty(s
->write_queue
) &&
437 (s
->read_packet
&& s
->n_read
>= sizeof(s
->read_size
) + s
->read_packet
->size
))
438 return dns_stream_complete(s
, 0);
440 /* If we did something, let's restart the timeout event source */
441 if (progressed
&& s
->timeout_event_source
) {
442 r
= sd_event_source_set_time_relative(s
->timeout_event_source
, DNS_STREAM_ESTABLISHED_TIMEOUT_USEC
);
/* NOTE(review): the result of sd_event_source_set_time_relative() is stored in 'r', yet the warning
 * below is logged with 'errno'. Confirm which of the two actually carries the error code here. */
444 log_warning_errno(errno
, "Couldn't restart TCP connection timeout, ignoring: %m");
/* I/O callback of the stream's event source; 'userdata' is the DnsStream.
 *
 * The stream is referenced for the duration of the call (released automatically through
 * _cleanup_(dns_stream_unrefp)) and the events are handed to on_stream_io_impl(). In DNS-over-TLS
 * builds the function then loops while dnstls_stream_has_buffered_data() reports pending data: it
 * queries the currently enabled events of the I/O source, tests them for EPOLLIN, and calls
 * on_stream_io_impl() again with a synthetic EPOLLIN.
 *
 * NOTE(review): result checks, loop exits and the return statement are not visible here. */
450 static int on_stream_io(sd_event_source
*es
, int fd
, uint32_t revents
, void *userdata
) {
451 _cleanup_(dns_stream_unrefp
) DnsStream
*s
= dns_stream_ref(userdata
); /* Protect stream while we process it */
456 r
= on_stream_io_impl(s
, revents
);
460 #if ENABLE_DNS_OVER_TLS
464 /* When using DNS-over-TLS, the underlying TLS library may read the entire TLS record
465 and buffer it internally. If this happens, we will not receive further EPOLLIN events,
466 and unless there's some unrelated activity on the socket, we will hang until time out.
467 To avoid this, if there's buffered TLS data, generate a "fake" EPOLLIN event.
468 This is hacky, but it makes this case transparent to the rest of the IO code. */
469 while (dnstls_stream_has_buffered_data(s
)) {
472 /* Make sure the stream still wants to process more data... */
473 r
= sd_event_source_get_io_events(s
->io_event_source
, &events
);
476 if (!FLAGS_SET(events
, EPOLLIN
))
479 r
= on_stream_io_impl(s
, EPOLLIN
);
/* Releases the resources of a stream; used as the free function in the ref/unref definition below.
 *
 * Visible steps: the stream is removed from the manager's dns_streams list and the per-type counter
 * n_dns_streams[s->type] is decremented; in DNS-over-TLS builds dnstls_stream_free() is called; every
 * packet in the write queue is removed from the queue and unreferenced; the current write and read
 * packets and the server are unreferenced; finally the write queue itself is freed.
 *
 * NOTE(review): other statements of this function, including its return, are not visible here. */
488 static DnsStream
*dns_stream_free(DnsStream
*s
) {
496 LIST_REMOVE(streams
, s
->manager
->dns_streams
, s
);
497 s
->manager
->n_dns_streams
[s
->type
]--;
500 #if ENABLE_DNS_OVER_TLS
502 dnstls_stream_free(s
);
505 ORDERED_SET_FOREACH(p
, s
->write_queue
)
506 dns_packet_unref(ordered_set_remove(s
->write_queue
, p
));
508 dns_packet_unref(s
->write_packet
);
509 dns_packet_unref(s
->read_packet
);
510 dns_server_unref(s
->server
);
512 ordered_set_free(s
->write_queue
);
/* Generates the reference-counting functions for DnsStream with dns_stream_free() as the destructor.
 * Presumably this is where dns_stream_ref() (called elsewhere in this file) comes from; confirm
 * against the macro definition. */
517 DEFINE_TRIVIAL_REF_UNREF_FUNC(DnsStream
, dns_stream
, dns_stream_free
);
523 DnsProtocol protocol
,
525 const union sockaddr_union
*tfo_address
,
526 usec_t connect_timeout_usec
) {
/* Sets up a new DnsStream.
 * NOTE(review): the beginning of this function's signature (name and leading parameters) is not
 * visible in this listing; presumably this is the stream constructor.
 *
 * Visible steps: 'type' and 'protocol' are asserted to be within range; the manager's counter
 * n_dns_streams[type] is compared against DNS_STREAMS_MAX; the object is allocated with new(); the
 * write queue is allocated with dns_packet_hash_ops; an I/O event source (initially EPOLLIN, handler
 * on_stream_io) and a relative timer of connect_timeout_usec (handler on_stream_timeout) are
 * registered on the manager's event loop and given descriptions; the stream is prepended to the
 * manager's dns_streams list and the counter is incremented; the TCP Fast Open address is copied and
 * tfo_salen is set to the size of the in6 or in member depending on the address family. */
528 _cleanup_(dns_stream_unrefp
) DnsStream
*s
= NULL
;
534 assert(type
< _DNS_STREAM_TYPE_MAX
);
535 assert(protocol
>= 0);
536 assert(protocol
< _DNS_PROTOCOL_MAX
);
/* NOTE(review): with '>' a counter that already equals DNS_STREAMS_MAX still passes this check, so
 * the counter can reach DNS_STREAMS_MAX + 1 after the increment below. Confirm that this is the
 * intended limit. */
539 if (m
->n_dns_streams
[type
] > DNS_STREAMS_MAX
)
542 s
= new(DnsStream
, 1);
549 .protocol
= protocol
,
553 r
= ordered_set_ensure_allocated(&s
->write_queue
, &dns_packet_hash_ops
);
557 r
= sd_event_add_io(m
->event
, &s
->io_event_source
, fd
, EPOLLIN
, on_stream_io
, s
);
561 (void) sd_event_source_set_description(s
->io_event_source
, "dns-stream-io");
563 r
= sd_event_add_time_relative(
565 &s
->timeout_event_source
,
566 clock_boottime_or_monotonic(),
567 connect_timeout_usec
, 0,
568 on_stream_timeout
, s
);
572 (void) sd_event_source_set_description(s
->timeout_event_source
, "dns-stream-timeout");
574 LIST_PREPEND(streams
, m
->dns_streams
, s
);
575 m
->n_dns_streams
[type
]++;
581 s
->tfo_address
= *tfo_address
;
582 s
->tfo_salen
= tfo_address
->sa
.sa_family
== AF_INET6
? sizeof(tfo_address
->in6
) : sizeof(tfo_address
->in
);
/* Queues packet 'p' for transmission: it is added to s->write_queue with ordered_set_put(), and the
 * I/O event mask is then refreshed through dns_stream_update_io(), whose result the visible return
 * statement returns.
 *
 * NOTE(review): the handling of the ordered_set_put() result is not visible in this listing. */
590 int dns_stream_write_packet(DnsStream
*s
, DnsPacket
*p
) {
596 r
= ordered_set_put(s
->write_queue
, p
);
602 return dns_stream_update_io(s
);
/* Hands the stream's current read packet to the caller via TAKE_PTR(s->read_packet) (presumably
 * leaving the field cleared; confirm against the macro).
 *
 * Before that, two checks compare s->n_read against the size of the length prefix and against the
 * prefix plus the announced packet size, i.e. whether the packet has been read completely.
 *
 * NOTE(review): what is returned when those checks hit is not visible in this listing. */
605 DnsPacket
*dns_stream_take_read_packet(DnsStream
*s
) {
611 if (s
->n_read
< sizeof(s
->read_size
))
614 if (s
->n_read
< sizeof(s
->read_size
) + be16toh(s
->read_size
))
618 return TAKE_PTR(s
->read_packet
);
/* Detaches the stream from its server object: the visible code compares s->server->stream with 's'
 * and calls dns_server_unref_stream(s->server).
 *
 * NOTE(review): the guards in front of these statements and the end of the function are not visible
 * in this listing. */
621 void dns_stream_detach(DnsStream
*s
) {
627 if (s
->server
->stream
!= s
)
630 dns_server_unref_stream(s
->server
);