]> git.ipfire.org Git - thirdparty/systemd.git/blobdiff - src/resolve/resolved-dns-stream.c
tree-wide: use -EBADF for fd initialization
[thirdparty/systemd.git] / src / resolve / resolved-dns-stream.c
index e6f72f00b43a28553b107f5ba52bd696481cbe48..0a10a0d17e50ca98e04e0d1fb6534fc27437e323 100644 (file)
@@ -1,4 +1,4 @@
-/* SPDX-License-Identifier: LGPL-2.1+ */
+/* SPDX-License-Identifier: LGPL-2.1-or-later */
 
 #include <netinet/tcp.h>
 #include <unistd.h>
@@ -6,11 +6,11 @@
 #include "alloc-util.h"
 #include "fd-util.h"
 #include "io-util.h"
+#include "macro.h"
 #include "missing_network.h"
 #include "resolved-dns-stream.h"
 #include "resolved-manager.h"
 
-#define DNS_STREAM_TIMEOUT_USEC (10 * USEC_PER_SEC)
 #define DNS_STREAMS_MAX 128
 
 #define DNS_QUERIES_PER_STREAM 32
@@ -18,8 +18,8 @@
 static void dns_stream_stop(DnsStream *s) {
         assert(s);
 
-        s->io_event_source = sd_event_source_unref(s->io_event_source);
-        s->timeout_event_source = sd_event_source_unref(s->timeout_event_source);
+        s->io_event_source = sd_event_source_disable_unref(s->io_event_source);
+        s->timeout_event_source = sd_event_source_disable_unref(s->timeout_event_source);
         s->fd = safe_close(s->fd);
 
         /* Disconnect us from the server object if we are now not usable anymore */
@@ -27,7 +27,7 @@ static void dns_stream_stop(DnsStream *s) {
 }
 
 static int dns_stream_update_io(DnsStream *s) {
-        int f = 0;
+        uint32_t f = 0;
 
         assert(s);
 
@@ -47,6 +47,8 @@ static int dns_stream_update_io(DnsStream *s) {
                 set_size(s->queries) < DNS_QUERIES_PER_STREAM)
                 f |= EPOLLIN;
 
+        s->requested_events = f;
+
 #if ENABLE_DNS_OVER_TLS
         /* For handshake and clean closing purposes, TLS can override requested events */
         if (s->dnstls_events != 0)
@@ -187,7 +189,7 @@ static int dns_stream_identify(DnsStream *s) {
         /* If we don't know the interface index still, we look for the
          * first local interface with a matching address. Yuck! */
         if (s->ifindex <= 0)
-                s->ifindex = manager_find_ifindex(s->manager, s->local.sa.sa_family, s->local.sa.sa_family == AF_INET ? (union in_addr_union*) &s->local.in.sin_addr : (union in_addr_union*)  &s->local.in6.sin6_addr);
+                s->ifindex = manager_find_ifindex(s->manager, s->local.sa.sa_family, sockaddr_in_addr(&s->local.sa));
 
         if (s->protocol == DNS_PROTOCOL_LLMNR && s->ifindex > 0) {
                 /* Make sure all packets for this connection are sent on the same interface */
@@ -208,22 +210,10 @@ ssize_t dns_stream_writev(DnsStream *s, const struct iovec *iov, size_t iovcnt,
         assert(iov);
 
 #if ENABLE_DNS_OVER_TLS
-        if (s->encrypted && !(flags & DNS_STREAM_WRITE_TLS_DATA)) {
-                ssize_t ss;
-                size_t i;
-
-                m = 0;
-                for (i = 0; i < iovcnt; i++) {
-                        ss = dnstls_stream_write(s, iov[i].iov_base, iov[i].iov_len);
-                        if (ss < 0)
-                                return ss;
-
-                        m += ss;
-                        if (ss != (ssize_t) iov[i].iov_len)
-                                continue;
-                }
-        } else
+        if (s->encrypted && !(flags & DNS_STREAM_WRITE_TLS_DATA))
+                return dnstls_stream_writev(s, iov, iovcnt);
 #endif
+
         if (s->tfo_salen > 0) {
                 struct msghdr hdr = {
                         .msg_iov = (struct iovec*) iov,
@@ -274,11 +264,32 @@ static ssize_t dns_stream_read(DnsStream *s, void *buf, size_t count) {
 }
 
 static int on_stream_timeout(sd_event_source *es, usec_t usec, void *userdata) {
-        DnsStream *s = userdata;
+        DnsStream *s = ASSERT_PTR(userdata);
 
+        return dns_stream_complete(s, ETIMEDOUT);
+}
+
+static DnsPacket *dns_stream_take_read_packet(DnsStream *s) {
         assert(s);
 
-        return dns_stream_complete(s, ETIMEDOUT);
+        /* Note, dns_stream_update() should be called after this is called. When this is called, the
+         * stream may be already full and the EPOLLIN flag is dropped from the stream IO event source.
+         * Even this makes a room to read in the stream, this does not call dns_stream_update(), hence
+         * EPOLLIN flag is not set automatically. So, to read further packets from the stream,
+         * dns_stream_update() must be called explicitly. Currently, this is only called from
+         * on_stream_io(), and there dns_stream_update() is called. */
+
+        if (!s->read_packet)
+                return NULL;
+
+        if (s->n_read < sizeof(s->read_size))
+                return NULL;
+
+        if (s->n_read < sizeof(s->read_size) + be16toh(s->read_size))
+                return NULL;
+
+        s->n_read = 0;
+        return TAKE_PTR(s->read_packet);
 }
 
 static int on_stream_io(sd_event_source *es, int fd, uint32_t revents, void *userdata) {
@@ -315,17 +326,16 @@ static int on_stream_io(sd_event_source *es, int fd, uint32_t revents, void *use
             s->write_packet &&
             s->n_written < sizeof(s->write_size) + s->write_packet->size) {
 
-                struct iovec iov[2];
-                ssize_t ss;
-
-                iov[0] = IOVEC_MAKE(&s->write_size, sizeof(s->write_size));
-                iov[1] = IOVEC_MAKE(DNS_PACKET_DATA(s->write_packet), s->write_packet->size);
+                struct iovec iov[] = {
+                        IOVEC_MAKE(&s->write_size, sizeof(s->write_size)),
+                        IOVEC_MAKE(DNS_PACKET_DATA(s->write_packet), s->write_packet->size),
+                };
 
-                IOVEC_INCREMENT(iov, 2, s->n_written);
+                IOVEC_INCREMENT(iov, ELEMENTSOF(iov), s->n_written);
 
-                ss = dns_stream_writev(s, iov, 2, 0);
+                ssize_t ss = dns_stream_writev(s, iov, ELEMENTSOF(iov), 0);
                 if (ss < 0) {
-                        if (!IN_SET(-ss, EINTR, EAGAIN))
+                        if (!ERRNO_IS_TRANSIENT(ss))
                                 return dns_stream_complete(s, -ss);
                 } else {
                         progressed = true;
@@ -340,17 +350,18 @@ static int on_stream_io(sd_event_source *es, int fd, uint32_t revents, void *use
                 }
         }
 
-        if ((revents & (EPOLLIN|EPOLLHUP|EPOLLRDHUP)) &&
-            (!s->read_packet ||
-             s->n_read < sizeof(s->read_size) + s->read_packet->size)) {
+        while ((revents & (EPOLLIN|EPOLLHUP|EPOLLRDHUP)) &&
+               (!s->read_packet ||
+                s->n_read < sizeof(s->read_size) + s->read_packet->size)) {
 
                 if (s->n_read < sizeof(s->read_size)) {
                         ssize_t ss;
 
                         ss = dns_stream_read(s, (uint8_t*) &s->read_size + s->n_read, sizeof(s->read_size) - s->n_read);
                         if (ss < 0) {
-                                if (!IN_SET(-ss, EINTR, EAGAIN))
+                                if (!ERRNO_IS_TRANSIENT(ss))
                                         return dns_stream_complete(s, -ss);
+                                break;
                         } else if (ss == 0)
                                 return dns_stream_complete(s, ECONNRESET);
                         else {
@@ -377,6 +388,7 @@ static int on_stream_io(sd_event_source *es, int fd, uint32_t revents, void *use
                                         s->read_packet->family = s->peer.sa.sa_family;
                                         s->read_packet->ttl = s->ttl;
                                         s->read_packet->ifindex = s->ifindex;
+                                        s->read_packet->timestamp = now(CLOCK_BOOTTIME);
 
                                         if (s->read_packet->family == AF_INET) {
                                                 s->read_packet->sender.in = s->peer.in.sin_addr;
@@ -401,43 +413,46 @@ static int on_stream_io(sd_event_source *es, int fd, uint32_t revents, void *use
                                           (uint8_t*) DNS_PACKET_DATA(s->read_packet) + s->n_read - sizeof(s->read_size),
                                           sizeof(s->read_size) + be16toh(s->read_size) - s->n_read);
                                 if (ss < 0) {
-                                        if (!IN_SET(-ss, EINTR, EAGAIN))
+                                        if (!ERRNO_IS_TRANSIENT(ss))
                                                 return dns_stream_complete(s, -ss);
+                                        break;
                                 } else if (ss == 0)
                                         return dns_stream_complete(s, ECONNRESET);
                                 else
                                         s->n_read += ss;
                         }
 
-                        /* Are we done? If so, disable the event source for EPOLLIN */
-                        if (s->n_read >= sizeof(s->read_size) + be16toh(s->read_size)) {
-                                /* If there's a packet handler
-                                 * installed, call that. Note that
-                                 * this is optional... */
-                                if (s->on_packet) {
-                                        r = s->on_packet(s);
-                                        if (r < 0)
-                                                return r;
-                                }
+                        /* Are we done? If so, call the packet handler and re-enable EPOLLIN for the
+                         * event source if necessary. */
+                        _cleanup_(dns_packet_unrefp) DnsPacket *p = dns_stream_take_read_packet(s);
+                        if (p) {
+                                assert(s->on_packet);
+                                r = s->on_packet(s, p);
+                                if (r < 0)
+                                        return r;
 
                                 r = dns_stream_update_io(s);
                                 if (r < 0)
                                         return dns_stream_complete(s, -r);
+
+                                s->packet_received = true;
+
+                                /* If we just disabled the read event, stop reading */
+                                if (!FLAGS_SET(s->requested_events, EPOLLIN))
+                                        break;
                         }
                 }
         }
 
-        /* Call "complete" callback if finished reading and writing one packet, and there's nothing else left
-         * to write. */
-        if (s->type == DNS_STREAM_LLMNR_SEND &&
-            (s->write_packet && s->n_written >= sizeof(s->write_size) + s->write_packet->size) &&
-            ordered_set_isempty(s->write_queue) &&
-            (s->read_packet && s->n_read >= sizeof(s->read_size) + s->read_packet->size))
+        /* Complete the stream if finished reading and writing one packet, and there's nothing
+         * else left to write. */
+        if (s->type == DNS_STREAM_LLMNR_SEND && s->packet_received &&
+            !FLAGS_SET(s->requested_events, EPOLLOUT))
                 return dns_stream_complete(s, 0);
 
         /* If we did something, let's restart the timeout event source */
         if (progressed && s->timeout_event_source) {
-                r = sd_event_source_set_time_relative(s->timeout_event_source, DNS_STREAM_TIMEOUT_USEC);
+                r = sd_event_source_set_time_relative(s->timeout_event_source, DNS_STREAM_ESTABLISHED_TIMEOUT_USEC);
                 if (r < 0)
                         log_warning_errno(errno, "Couldn't restart TCP connection timeout, ignoring: %m");
         }
@@ -482,7 +497,10 @@ int dns_stream_new(
                 DnsStreamType type,
                 DnsProtocol protocol,
                 int fd,
-                const union sockaddr_union *tfo_address) {
+                const union sockaddr_union *tfo_address,
+                int (on_packet)(DnsStream*, DnsPacket*),
+                int (complete)(DnsStream*, int), /* optional */
+                usec_t connect_timeout_usec) {
 
         _cleanup_(dns_stream_unrefp) DnsStream *s = NULL;
         int r;
@@ -494,6 +512,7 @@ int dns_stream_new(
         assert(protocol >= 0);
         assert(protocol < _DNS_PROTOCOL_MAX);
         assert(fd >= 0);
+        assert(on_packet);
 
         if (m->n_dns_streams[type] > DNS_STREAMS_MAX)
                 return -EBUSY;
@@ -504,7 +523,7 @@ int dns_stream_new(
 
         *s = (DnsStream) {
                 .n_ref = 1,
-                .fd = -1,
+                .fd = -EBADF,
                 .protocol = protocol,
                 .type = type,
         };
@@ -522,8 +541,8 @@ int dns_stream_new(
         r = sd_event_add_time_relative(
                         m->event,
                         &s->timeout_event_source,
-                        clock_boottime_or_monotonic(),
-                        DNS_STREAM_TIMEOUT_USEC, 0,
+                        CLOCK_BOOTTIME,
+                        connect_timeout_usec, 0,
                         on_stream_timeout, s);
         if (r < 0)
                 return r;
@@ -535,6 +554,8 @@ int dns_stream_new(
         s->manager = m;
 
         s->fd = fd;
+        s->on_packet = on_packet;
+        s->complete = complete;
 
         if (tfo_address) {
                 s->tfo_address = *tfo_address;
@@ -561,22 +582,6 @@ int dns_stream_write_packet(DnsStream *s, DnsPacket *p) {
         return dns_stream_update_io(s);
 }
 
-DnsPacket *dns_stream_take_read_packet(DnsStream *s) {
-        assert(s);
-
-        if (!s->read_packet)
-                return NULL;
-
-        if (s->n_read < sizeof(s->read_size))
-                return NULL;
-
-        if (s->n_read < sizeof(s->read_size) + be16toh(s->read_size))
-                return NULL;
-
-        s->n_read = 0;
-        return TAKE_PTR(s->read_packet);
-}
-
 void dns_stream_detach(DnsStream *s) {
         assert(s);