]> git.ipfire.org Git - thirdparty/systemd.git/blobdiff - src/resolve/resolved-dns-stream.c
Add SPDX license identifiers to source files under the LGPL
[thirdparty/systemd.git] / src / resolve / resolved-dns-stream.c
index b72e6cc06fb4585dacace3c704dc941ca3200b0d..52f23cd864443fa6a5647d3b96147130310370b4 100644 (file)
@@ -1,5 +1,4 @@
-/*-*- Mode: C; c-basic-offset: 8; indent-tabs-mode: nil -*-*/
-
+/* SPDX-License-Identifier: LGPL-2.1+ */
 /***
   This file is part of systemd.
 
@@ -58,8 +57,8 @@ static int dns_stream_complete(DnsStream *s, int error) {
 
         if (s->complete)
                 s->complete(s, error);
-        else
-                dns_stream_free(s);
+        else /* the default action if no completion function is set is to close the stream */
+                dns_stream_unref(s);
 
         return 0;
 }
@@ -223,7 +222,7 @@ static int on_stream_io(sd_event_source *es, int fd, uint32_t revents, void *use
 
                 ss = writev(fd, iov, 2);
                 if (ss < 0) {
-                        if (errno != EINTR && errno != EAGAIN)
+                        if (!IN_SET(errno, EINTR, EAGAIN))
                                 return dns_stream_complete(s, errno);
                 } else
                         s->n_written += ss;
@@ -245,7 +244,7 @@ static int on_stream_io(sd_event_source *es, int fd, uint32_t revents, void *use
 
                         ss = read(fd, (uint8_t*) &s->read_size + s->n_read, sizeof(s->read_size) - s->n_read);
                         if (ss < 0) {
-                                if (errno != EINTR && errno != EAGAIN)
+                                if (!IN_SET(errno, EINTR, EAGAIN))
                                         return dns_stream_complete(s, errno);
                         } else if (ss == 0)
                                 return dns_stream_complete(s, ECONNRESET);
@@ -262,7 +261,7 @@ static int on_stream_io(sd_event_source *es, int fd, uint32_t revents, void *use
                                 ssize_t ss;
 
                                 if (!s->read_packet) {
-                                        r = dns_packet_new(&s->read_packet, s->protocol, be16toh(s->read_size));
+                                        r = dns_packet_new(&s->read_packet, s->protocol, be16toh(s->read_size), DNS_PACKET_SIZE_MAX);
                                         if (r < 0)
                                                 return dns_stream_complete(s, -r);
 
@@ -295,7 +294,7 @@ static int on_stream_io(sd_event_source *es, int fd, uint32_t revents, void *use
                                           (uint8_t*) DNS_PACKET_DATA(s->read_packet) + s->n_read - sizeof(s->read_size),
                                           sizeof(s->read_size) + be16toh(s->read_size) - s->n_read);
                                 if (ss < 0) {
-                                        if (errno != EINTR && errno != EAGAIN)
+                                        if (!IN_SET(errno, EINTR, EAGAIN))
                                                 return dns_stream_complete(s, errno);
                                 } else if (ss == 0)
                                         return dns_stream_complete(s, ECONNRESET);
@@ -325,10 +324,16 @@ static int on_stream_io(sd_event_source *es, int fd, uint32_t revents, void *use
         return 0;
 }
 
-DnsStream *dns_stream_free(DnsStream *s) {
+DnsStream *dns_stream_unref(DnsStream *s) {
         if (!s)
                 return NULL;
 
+        assert(s->n_ref > 0);
+        s->n_ref--;
+
+        if (s->n_ref > 0)
+                return NULL;
+
         dns_stream_stop(s);
 
         if (s->manager) {
@@ -339,15 +344,23 @@ DnsStream *dns_stream_free(DnsStream *s) {
         dns_packet_unref(s->write_packet);
         dns_packet_unref(s->read_packet);
 
-        free(s);
-
-        return 0;
+        return mfree(s);
 }
 
-DEFINE_TRIVIAL_CLEANUP_FUNC(DnsStream*, dns_stream_free);
+DEFINE_TRIVIAL_CLEANUP_FUNC(DnsStream*, dns_stream_unref);
+
+DnsStream *dns_stream_ref(DnsStream *s) {
+        if (!s)
+                return NULL;
+
+        assert(s->n_ref > 0);
+        s->n_ref++;
+
+        return s;
+}
 
 int dns_stream_new(Manager *m, DnsStream **ret, DnsProtocol protocol, int fd) {
-        _cleanup_(dns_stream_freep) DnsStream *s = NULL;
+        _cleanup_(dns_stream_unrefp) DnsStream *s = NULL;
         int r;
 
         assert(m);
@@ -360,6 +373,7 @@ int dns_stream_new(Manager *m, DnsStream **ret, DnsProtocol protocol, int fd) {
         if (!s)
                 return -ENOMEM;
 
+        s->n_ref = 1;
         s->fd = -1;
         s->protocol = protocol;