]> git.ipfire.org Git - thirdparty/systemd.git/commitdiff
resolve: call dns_stream_take_read_packet() in on_stream_io()
authorYu Watanabe <watanabe.yu+github@gmail.com>
Thu, 27 Jan 2022 23:57:05 +0000 (08:57 +0900)
committerYu Watanabe <watanabe.yu+github@gmail.com>
Fri, 28 Jan 2022 00:00:56 +0000 (09:00 +0900)
As dns_stream_take_read_packet() is called only in on_packet callbacks,
and all on_packet callbacks call it.

src/resolve/resolved-dns-stream.c
src/resolve/resolved-dns-stream.h
src/resolve/resolved-dns-stub.c
src/resolve/resolved-dns-transaction.c
src/resolve/resolved-llmnr.c
src/resolve/test-resolved-stream.c

index bdf46170d18e1c8b5fdfab6074778f9aeed157f3..1b2db5121296a6e45f6da349323a3ecf12678848 100644 (file)
@@ -281,6 +281,22 @@ static int on_stream_timeout(sd_event_source *es, usec_t usec, void *userdata) {
         return dns_stream_complete(s, ETIMEDOUT);
 }
 
+static DnsPacket *dns_stream_take_read_packet(DnsStream *s) {
+        assert(s);
+
+        if (!s->read_packet)
+                return NULL;
+
+        if (s->n_read < sizeof(s->read_size))
+                return NULL;
+
+        if (s->n_read < sizeof(s->read_size) + be16toh(s->read_size))
+                return NULL;
+
+        s->n_read = 0;
+        return TAKE_PTR(s->read_packet);
+}
+
 static int on_stream_io_impl(DnsStream *s, uint32_t revents) {
         bool progressed = false;
         int r;
@@ -413,9 +429,10 @@ static int on_stream_io_impl(DnsStream *s, uint32_t revents) {
 
                         /* Are we done? If so, call the packet handler and re-enable EPOLLIN for the
                          * event source if necessary. */
-                        if (s->n_read >= sizeof(s->read_size) + be16toh(s->read_size)) {
+                        _cleanup_(dns_packet_unrefp) DnsPacket *p = dns_stream_take_read_packet(s);
+                        if (p) {
                                 assert(s->on_packet);
-                                r = s->on_packet(s);
+                                r = s->on_packet(s, p);
                                 if (r < 0)
                                         return r;
 
@@ -520,7 +537,7 @@ int dns_stream_new(
                 DnsProtocol protocol,
                 int fd,
                 const union sockaddr_union *tfo_address,
-                int (on_packet)(DnsStream*),
+                int (on_packet)(DnsStream*, DnsPacket*),
                 int (complete)(DnsStream*, int), /* optional */
                 usec_t connect_timeout_usec) {
 
@@ -604,22 +621,6 @@ int dns_stream_write_packet(DnsStream *s, DnsPacket *p) {
         return dns_stream_update_io(s);
 }
 
-DnsPacket *dns_stream_take_read_packet(DnsStream *s) {
-        assert(s);
-
-        if (!s->read_packet)
-                return NULL;
-
-        if (s->n_read < sizeof(s->read_size))
-                return NULL;
-
-        if (s->n_read < sizeof(s->read_size) + be16toh(s->read_size))
-                return NULL;
-
-        s->n_read = 0;
-        return TAKE_PTR(s->read_packet);
-}
-
 void dns_stream_detach(DnsStream *s) {
         assert(s);
 
index 548b2edc9ef546057a564157521a61379e643f48..fedbab2da2cadea6fad0fe315d3d3a68898cc039 100644 (file)
@@ -78,7 +78,7 @@ struct DnsStream {
         size_t n_written, n_read;
         OrderedSet *write_queue;
 
-        int (*on_packet)(DnsStream *s);
+        int (*on_packet)(DnsStream *s, DnsPacket *p);
         int (*complete)(DnsStream *s, int error);
 
         LIST_HEAD(DnsTransaction, transactions); /* when used by the transaction logic */
@@ -100,7 +100,7 @@ int dns_stream_new(
                 DnsProtocol protocol,
                 int fd,
                 const union sockaddr_union *tfo_address,
-                int (on_packet)(DnsStream*),
+                int (on_packet)(DnsStream*, DnsPacket*),
                 int (complete)(DnsStream*, int), /* optional */
                 usec_t connect_timeout_usec);
 #if ENABLE_DNS_OVER_TLS
@@ -123,6 +123,4 @@ static inline bool DNS_STREAM_QUEUED(DnsStream *s) {
         return !!s->write_packet;
 }
 
-DnsPacket *dns_stream_take_read_packet(DnsStream *s);
-
 void dns_stream_detach(DnsStream *s);
index 7eb93f117473ff58bc3bc67948dcfc81643d28f5..992ae19bbc74c1142a00ac87d8b79764716aaa6b 100644 (file)
@@ -1044,12 +1044,9 @@ static int on_dns_stub_packet_extra(sd_event_source *s, int fd, uint32_t revents
         return on_dns_stub_packet_internal(s, fd, revents, l->manager, l);
 }
 
-static int on_dns_stub_stream_packet(DnsStream *s) {
-        _cleanup_(dns_packet_unrefp) DnsPacket *p = NULL;
-
+static int on_dns_stub_stream_packet(DnsStream *s, DnsPacket *p) {
         assert(s);
-
-        p = dns_stream_take_read_packet(s);
+        assert(s->manager);
         assert(p);
 
         if (dns_packet_validate_query(p) > 0) {
index 20d257bbf3bb9a12f2e82ab714f006b315fbe054..f937f9f7b594cde0de4a8dde8673fdb08c9330bc 100644 (file)
@@ -644,14 +644,12 @@ static int on_stream_complete(DnsStream *s, int error) {
         return 0;
 }
 
-static int on_stream_packet(DnsStream *s) {
-        _cleanup_(dns_packet_unrefp) DnsPacket *p = NULL;
+static int on_stream_packet(DnsStream *s, DnsPacket *p) {
         DnsTransaction *t;
 
         assert(s);
-
-        /* Take ownership of packet to be able to receive new packets */
-        assert_se(p = dns_stream_take_read_packet(s));
+        assert(s->manager);
+        assert(p);
 
         t = hashmap_get(s->manager->dns_transactions, UINT_TO_PTR(DNS_PACKET_ID(p)));
         if (t && t->stream == s) /* Validate that the stream we got this on actually is the stream the
index 150cbab18633a9ab2db80e6a6c0908297e122b3e..b4e551c219ddec303401ac84fad5af2e840e7268 100644 (file)
@@ -277,13 +277,11 @@ int manager_llmnr_ipv6_udp_fd(Manager *m) {
         return m->llmnr_ipv6_udp_fd = TAKE_FD(s);
 }
 
-static int on_llmnr_stream_packet(DnsStream *s) {
-        _cleanup_(dns_packet_unrefp) DnsPacket *p = NULL;
+static int on_llmnr_stream_packet(DnsStream *s, DnsPacket *p) {
         DnsScope *scope;
 
         assert(s);
-
-        p = dns_stream_take_read_packet(s);
+        assert(s->manager);
         assert(p);
 
         scope = manager_find_scope(s->manager, p);
index 76467629fbdf1d809c92e91644c2fca073ff02d2..8a01460a0eee3992f084659f78079a580ef3319c 100644 (file)
@@ -194,9 +194,9 @@ static const size_t MAX_RECEIVED_PACKETS = 2;
 static DnsPacket *received_packets[2] = {};
 static size_t n_received_packets = 0;
 
-static int on_stream_packet(DnsStream *stream) {
+static int on_stream_packet(DnsStream *stream, DnsPacket *p) {
         assert_se(n_received_packets < MAX_RECEIVED_PACKETS);
-        assert_se(received_packets[n_received_packets++] = dns_stream_take_read_packet(stream));
+        assert_se(received_packets[n_received_packets++] = dns_packet_ref(p));
         return 0;
 }