]> git.ipfire.org Git - thirdparty/systemd.git/blobdiff - src/resolve/resolved-dns-packet.c
Add SPDX license identifiers to source files under the LGPL
[thirdparty/systemd.git] / src / resolve / resolved-dns-packet.c
index a8ad8fe342b88a919863ca09ec0718948b39010c..40f35475fdce63344927c304e969029adc2ba615 100644 (file)
@@ -1,3 +1,4 @@
+/* SPDX-License-Identifier: LGPL-2.1+ */
 /***
   This file is part of systemd.
 
@@ -28,6 +29,8 @@
 
 #define EDNS0_OPT_DO (1<<15)
 
+assert_cc(DNS_PACKET_SIZE_START > DNS_PACKET_HEADER_SIZE)
+
 typedef struct DnsPacketRewinder {
         DnsPacket *packet;
         size_t saved_rindex;
@@ -41,26 +44,44 @@ static void rewind_dns_packet(DnsPacketRewinder *rewinder) {
 #define INIT_REWINDER(rewinder, p) do { rewinder.packet = p; rewinder.saved_rindex = p->rindex; } while (0)
 #define CANCEL_REWINDER(rewinder) do { rewinder.packet = NULL; } while (0)
 
-int dns_packet_new(DnsPacket **ret, DnsProtocol protocol, size_t mtu) {
+int dns_packet_new(
+                DnsPacket **ret,
+                DnsProtocol protocol,
+                size_t min_alloc_dsize,
+                size_t max_size) {
+
         DnsPacket *p;
         size_t a;
 
         assert(ret);
+        assert(max_size >= DNS_PACKET_HEADER_SIZE);
 
-        if (mtu <= UDP_PACKET_HEADER_SIZE)
+        if (max_size > DNS_PACKET_SIZE_MAX)
+                max_size = DNS_PACKET_SIZE_MAX;
+
+        /* The caller may not check what is going to be truly allocated, so do not allow to
+         * allocate a DNS packet bigger than DNS_PACKET_SIZE_MAX.
+         */
+        if (min_alloc_dsize > DNS_PACKET_SIZE_MAX) {
+                log_error("Requested packet data size too big: %zu", min_alloc_dsize);
+                return -EFBIG;
+        }
+
+        /* When dns_packet_new() is called with min_alloc_dsize == 0, allocate more than the
+         * absolute minimum (which is the dns packet header size), to avoid
+         * resizing immediately again after appending the first data to the packet.
+         */
+        if (min_alloc_dsize < DNS_PACKET_HEADER_SIZE)
                 a = DNS_PACKET_SIZE_START;
         else
-                a = mtu - UDP_PACKET_HEADER_SIZE;
-
-        if (a < DNS_PACKET_HEADER_SIZE)
-                a = DNS_PACKET_HEADER_SIZE;
+                a = min_alloc_dsize;
 
         /* round up to next page size */
         a = PAGE_ALIGN(ALIGN(sizeof(DnsPacket)) + a) - ALIGN(sizeof(DnsPacket));
 
         /* make sure we never allocate more than useful */
-        if (a > DNS_PACKET_SIZE_MAX)
-                a = DNS_PACKET_SIZE_MAX;
+        if (a > max_size)
+                a = max_size;
 
         p = malloc0(ALIGN(sizeof(DnsPacket)) + a);
         if (!p)
@@ -68,6 +89,7 @@ int dns_packet_new(DnsPacket **ret, DnsProtocol protocol, size_t mtu) {
 
         p->size = p->rindex = DNS_PACKET_HEADER_SIZE;
         p->allocated = a;
+        p->max_size = max_size;
         p->protocol = protocol;
         p->opt_start = p->opt_size = (size_t) -1;
         p->n_ref = 1;
@@ -127,13 +149,13 @@ void dns_packet_set_flags(DnsPacket *p, bool dnssec_checking_disabled, bool trun
         }
 }
 
-int dns_packet_new_query(DnsPacket **ret, DnsProtocol protocol, size_t mtu, bool dnssec_checking_disabled) {
+int dns_packet_new_query(DnsPacket **ret, DnsProtocol protocol, size_t min_alloc_dsize, bool dnssec_checking_disabled) {
         DnsPacket *p;
         int r;
 
         assert(ret);
 
-        r = dns_packet_new(&p, protocol, mtu);
+        r = dns_packet_new(&p, protocol, min_alloc_dsize, DNS_PACKET_SIZE_MAX);
         if (r < 0)
                 return r;
 
@@ -302,11 +324,13 @@ static int dns_packet_extend(DnsPacket *p, size_t add, void **ret, size_t *start
         assert(p);
 
         if (p->size + add > p->allocated) {
-                size_t a;
+                size_t a, ms;
 
                 a = PAGE_ALIGN((p->size + add) * 2);
-                if (a > DNS_PACKET_SIZE_MAX)
-                        a = DNS_PACKET_SIZE_MAX;
+
+                ms = dns_packet_size_max(p);
+                if (a > ms)
+                        a = ms;
 
                 if (p->size + add > a)
                         return -EMSGSIZE;
@@ -569,8 +593,9 @@ fail:
         return r;
 }
 
-int dns_packet_append_key(DnsPacket *p, const DnsResourceKey *k, size_t *start) {
+int dns_packet_append_key(DnsPacket *p, const DnsResourceKey *k, const DnsAnswerFlags flags, size_t *start) {
         size_t saved_size;
+        uint16_t class;
         int r;
 
         assert(p);
@@ -586,7 +611,8 @@ int dns_packet_append_key(DnsPacket *p, const DnsResourceKey *k, size_t *start)
         if (r < 0)
                 goto fail;
 
-        r = dns_packet_append_uint16(p, k->class, NULL);
+        class = flags & DNS_ANSWER_CACHE_FLUSH ? k->class | MDNS_RR_CACHE_FLUSH : k->class;
+        r = dns_packet_append_uint16(p, class, NULL);
         if (r < 0)
                 goto fail;
 
@@ -791,9 +817,10 @@ int dns_packet_truncate_opt(DnsPacket *p) {
         return 1;
 }
 
-int dns_packet_append_rr(DnsPacket *p, const DnsResourceRecord *rr, size_t *start, size_t *rdata_start) {
+int dns_packet_append_rr(DnsPacket *p, const DnsResourceRecord *rr, const DnsAnswerFlags flags, size_t *start, size_t *rdata_start) {
 
         size_t saved_size, rdlength_offset, end, rdlength, rds;
+        uint32_t ttl;
         int r;
 
         assert(p);
@@ -801,11 +828,12 @@ int dns_packet_append_rr(DnsPacket *p, const DnsResourceRecord *rr, size_t *star
 
         saved_size = p->size;
 
-        r = dns_packet_append_key(p, rr->key, NULL);
+        r = dns_packet_append_key(p, rr->key, flags, NULL);
         if (r < 0)
                 goto fail;
 
-        r = dns_packet_append_uint32(p, rr->ttl, NULL);
+        ttl = flags & DNS_ANSWER_GOODBYE ? 0 : rr->ttl;
+        r = dns_packet_append_uint32(p, ttl, NULL);
         if (r < 0)
                 goto fail;
 
@@ -1143,7 +1171,7 @@ int dns_packet_append_question(DnsPacket *p, DnsQuestion *q) {
         assert(p);
 
         DNS_QUESTION_FOREACH(key, q) {
-                r = dns_packet_append_key(p, key, NULL);
+                r = dns_packet_append_key(p, key, 0, NULL);
                 if (r < 0)
                         return r;
         }
@@ -1153,12 +1181,13 @@ int dns_packet_append_question(DnsPacket *p, DnsQuestion *q) {
 
 int dns_packet_append_answer(DnsPacket *p, DnsAnswer *a) {
         DnsResourceRecord *rr;
+        DnsAnswerFlags flags;
         int r;
 
         assert(p);
 
-        DNS_ANSWER_FOREACH(rr, a) {
-                r = dns_packet_append_rr(p, rr, NULL, NULL);
+        DNS_ANSWER_FOREACH_FLAGS(rr, flags, a) {
+                r = dns_packet_append_rr(p, rr, flags, NULL, NULL);
                 if (r < 0)
                         return r;
         }
@@ -1486,7 +1515,7 @@ static int dns_packet_read_type_window(DnsPacket *p, Bitmap **types, size_t *sta
 
                 found = true;
 
-                while (bitmask) {
+                for (; bitmask; bit++, bitmask >>= 1)
                         if (bitmap[i] & bitmask) {
                                 uint16_t n;
 
@@ -1500,10 +1529,6 @@ static int dns_packet_read_type_window(DnsPacket *p, Bitmap **types, size_t *sta
                                 if (r < 0)
                                         return r;
                         }
-
-                        bit++;
-                        bitmask >>= 1;
-                }
         }
 
         if (!found)
@@ -2143,7 +2168,7 @@ int dns_packet_extract(DnsPacket *p) {
 
                 for (i = 0; i < n; i++) {
                         _cleanup_(dns_resource_record_unrefp) DnsResourceRecord *rr = NULL;
-                        bool cache_flush;
+                        bool cache_flush = false;
 
                         r = dns_packet_read_rr(p, &rr, &cache_flush, NULL);
                         if (r < 0)
@@ -2264,6 +2289,9 @@ int dns_packet_is_reply_for(DnsPacket *p, const DnsResourceKey *key) {
         if (r < 0)
                 return r;
 
+        if (!p->question)
+                return 0;
+
         if (p->question->n_keys != 1)
                 return 0;
 
@@ -2289,6 +2317,7 @@ static const char* const dns_rcode_table[_DNS_RCODE_MAX_DEFINED] = {
         [DNS_RCODE_BADNAME] = "BADNAME",
         [DNS_RCODE_BADALG] = "BADALG",
         [DNS_RCODE_BADTRUNC] = "BADTRUNC",
+        [DNS_RCODE_BADCOOKIE] = "BADCOOKIE",
 };
 DEFINE_STRING_TABLE_LOOKUP(dns_rcode, int);