]> git.ipfire.org Git - thirdparty/systemd.git/blobdiff - src/resolve/resolved-dns-stream.c
Add SPDX license identifiers to source files under the LGPL
[thirdparty/systemd.git] / src / resolve / resolved-dns-stream.c
index 47130c4231923a42219e73358b3f602e678e8b67..52f23cd864443fa6a5647d3b96147130310370b4 100644 (file)
@@ -1,5 +1,4 @@
-/*-*- Mode: C; c-basic-offset: 8; indent-tabs-mode: nil -*-*/
-
+/* SPDX-License-Identifier: LGPL-2.1+ */
 /***
   This file is part of systemd.
 
@@ -21,6 +20,9 @@
 
 #include <netinet/tcp.h>
 
+#include "alloc-util.h"
+#include "fd-util.h"
+#include "io-util.h"
 #include "missing.h"
 #include "resolved-dns-stream.h"
 
@@ -55,8 +57,8 @@ static int dns_stream_complete(DnsStream *s, int error) {
 
         if (s->complete)
                 s->complete(s, error);
-        else
-                dns_stream_free(s);
+        else /* the default action if no completion function is set is to close the stream */
+                dns_stream_unref(s);
 
         return 0;
 }
@@ -64,7 +66,7 @@ static int dns_stream_complete(DnsStream *s, int error) {
 static int dns_stream_identify(DnsStream *s) {
         union {
                 struct cmsghdr header; /* For alignment */
-                uint8_t buffer[CMSG_SPACE(MAX(sizeof(struct in_pktinfo), sizeof(struct in6_pktinfo)))
+                uint8_t buffer[CMSG_SPACE(MAXSIZE(struct in_pktinfo, struct in6_pktinfo))
                                + EXTRA_CMSG_SPACE /* kernel appears to require extra space */];
         } control;
         struct msghdr mh = {};
@@ -113,7 +115,8 @@ static int dns_stream_identify(DnsStream *s) {
 
         mh.msg_control = &control;
         mh.msg_controllen = sl;
-        for (cmsg = CMSG_FIRSTHDR(&mh); cmsg; cmsg = CMSG_NXTHDR(&mh, cmsg)) {
+
+        CMSG_FOREACH(cmsg, &mh) {
 
                 if (cmsg->cmsg_level == IPPROTO_IPV6) {
                         assert(s->peer.sa.sa_family == AF_INET6);
@@ -157,7 +160,7 @@ static int dns_stream_identify(DnsStream *s) {
          * device if the connection came from the local host since it
          * avoids the routing table in such a case. Let's unset the
          * interface index in such a case. */
-        if (s->ifindex > 0 && manager_ifindex_is_loopback(s->manager, s->ifindex) != 0)
+        if (s->ifindex == LOOPBACK_IFINDEX)
                 s->ifindex = 0;
 
         /* If we don't know the interface index still, we look for the
@@ -172,11 +175,11 @@ static int dns_stream_identify(DnsStream *s) {
                 if (s->local.sa.sa_family == AF_INET) {
                         r = setsockopt(s->fd, IPPROTO_IP, IP_UNICAST_IF, &ifindex, sizeof(ifindex));
                         if (r < 0)
-                                return -errno;
+                                log_debug_errno(errno, "Failed to invoke IP_UNICAST_IF: %m");
                 } else if (s->local.sa.sa_family == AF_INET6) {
                         r = setsockopt(s->fd, IPPROTO_IPV6, IPV6_UNICAST_IF, &ifindex, sizeof(ifindex));
                         if (r < 0)
-                                return -errno;
+                                log_debug_errno(errno, "Failed to invoke IPV6_UNICAST_IF: %m");
                 }
         }
 
@@ -219,7 +222,7 @@ static int on_stream_io(sd_event_source *es, int fd, uint32_t revents, void *use
 
                 ss = writev(fd, iov, 2);
                 if (ss < 0) {
-                        if (errno != EINTR && errno != EAGAIN)
+                        if (!IN_SET(errno, EINTR, EAGAIN))
                                 return dns_stream_complete(s, errno);
                 } else
                         s->n_written += ss;
@@ -241,7 +244,7 @@ static int on_stream_io(sd_event_source *es, int fd, uint32_t revents, void *use
 
                         ss = read(fd, (uint8_t*) &s->read_size + s->n_read, sizeof(s->read_size) - s->n_read);
                         if (ss < 0) {
-                                if (errno != EINTR && errno != EAGAIN)
+                                if (!IN_SET(errno, EINTR, EAGAIN))
                                         return dns_stream_complete(s, errno);
                         } else if (ss == 0)
                                 return dns_stream_complete(s, ECONNRESET);
@@ -258,7 +261,7 @@ static int on_stream_io(sd_event_source *es, int fd, uint32_t revents, void *use
                                 ssize_t ss;
 
                                 if (!s->read_packet) {
-                                        r = dns_packet_new(&s->read_packet, s->protocol, be16toh(s->read_size));
+                                        r = dns_packet_new(&s->read_packet, s->protocol, be16toh(s->read_size), DNS_PACKET_SIZE_MAX);
                                         if (r < 0)
                                                 return dns_stream_complete(s, -r);
 
@@ -291,7 +294,7 @@ static int on_stream_io(sd_event_source *es, int fd, uint32_t revents, void *use
                                           (uint8_t*) DNS_PACKET_DATA(s->read_packet) + s->n_read - sizeof(s->read_size),
                                           sizeof(s->read_size) + be16toh(s->read_size) - s->n_read);
                                 if (ss < 0) {
-                                        if (errno != EINTR && errno != EAGAIN)
+                                        if (!IN_SET(errno, EINTR, EAGAIN))
                                                 return dns_stream_complete(s, errno);
                                 } else if (ss == 0)
                                         return dns_stream_complete(s, ECONNRESET);
@@ -321,10 +324,16 @@ static int on_stream_io(sd_event_source *es, int fd, uint32_t revents, void *use
         return 0;
 }
 
-DnsStream *dns_stream_free(DnsStream *s) {
+DnsStream *dns_stream_unref(DnsStream *s) {
         if (!s)
                 return NULL;
 
+        assert(s->n_ref > 0);
+        s->n_ref--;
+
+        if (s->n_ref > 0)
+                return NULL;
+
         dns_stream_stop(s);
 
         if (s->manager) {
@@ -335,16 +344,23 @@ DnsStream *dns_stream_free(DnsStream *s) {
         dns_packet_unref(s->write_packet);
         dns_packet_unref(s->read_packet);
 
-        free(s);
-
-        return 0;
+        return mfree(s);
 }
 
-DEFINE_TRIVIAL_CLEANUP_FUNC(DnsStream*, dns_stream_free);
+DEFINE_TRIVIAL_CLEANUP_FUNC(DnsStream*, dns_stream_unref);
+
+DnsStream *dns_stream_ref(DnsStream *s) {
+        if (!s)
+                return NULL;
+
+        assert(s->n_ref > 0);
+        s->n_ref++;
+
+        return s;
+}
 
 int dns_stream_new(Manager *m, DnsStream **ret, DnsProtocol protocol, int fd) {
-        static const int one = 1;
-        _cleanup_(dns_stream_freep) DnsStream *s = NULL;
+        _cleanup_(dns_stream_unrefp) DnsStream *s = NULL;
         int r;
 
         assert(m);
@@ -357,21 +373,27 @@ int dns_stream_new(Manager *m, DnsStream **ret, DnsProtocol protocol, int fd) {
         if (!s)
                 return -ENOMEM;
 
+        s->n_ref = 1;
         s->fd = -1;
         s->protocol = protocol;
 
-        r = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &one, sizeof(one));
-        if (r < 0)
-                return -errno;
-
         r = sd_event_add_io(m->event, &s->io_event_source, fd, EPOLLIN, on_stream_io, s);
         if (r < 0)
                 return r;
 
-        r = sd_event_add_time(m->event, &s->timeout_event_source, CLOCK_MONOTONIC, now(CLOCK_MONOTONIC) + DNS_STREAM_TIMEOUT_USEC, 0, on_stream_timeout, s);
+        (void) sd_event_source_set_description(s->io_event_source, "dns-stream-io");
+
+        r = sd_event_add_time(
+                        m->event,
+                        &s->timeout_event_source,
+                        clock_boottime_or_monotonic(),
+                        now(clock_boottime_or_monotonic()) + DNS_STREAM_TIMEOUT_USEC, 0,
+                        on_stream_timeout, s);
         if (r < 0)
                 return r;
 
+        (void) sd_event_source_set_description(s->timeout_event_source, "dns-stream-timeout");
+
         LIST_PREPEND(streams, m->dns_streams, s);
         s->manager = m;
         s->fd = fd;