]> git.ipfire.org Git - thirdparty/systemd.git/blobdiff - src/resolve/resolved-dns-packet.c
Add SPDX license identifiers to source files under the LGPL
[thirdparty/systemd.git] / src / resolve / resolved-dns-packet.c
index a486216d68da26d25aa307e100a46b037dfaf753..40f35475fdce63344927c304e969029adc2ba615 100644 (file)
@@ -1,3 +1,4 @@
+/* SPDX-License-Identifier: LGPL-2.1+ */
 /***
   This file is part of systemd.
 
@@ -28,8 +29,7 @@
 
 #define EDNS0_OPT_DO (1<<15)
 
-#define DNS_PACKET_SIZE_START 512u
-assert_cc(DNS_PACKET_SIZE_START > UDP_PACKET_HEADER_SIZE)
+assert_cc(DNS_PACKET_SIZE_START > DNS_PACKET_HEADER_SIZE)
 
 typedef struct DnsPacketRewinder {
         DnsPacket *packet;
@@ -44,27 +44,44 @@ static void rewind_dns_packet(DnsPacketRewinder *rewinder) {
 #define INIT_REWINDER(rewinder, p) do { rewinder.packet = p; rewinder.saved_rindex = p->rindex; } while (0)
 #define CANCEL_REWINDER(rewinder) do { rewinder.packet = NULL; } while (0)
 
-int dns_packet_new(DnsPacket **ret, DnsProtocol protocol, size_t mtu) {
+int dns_packet_new(
+                DnsPacket **ret,
+                DnsProtocol protocol,
+                size_t min_alloc_dsize,
+                size_t max_size) {
+
         DnsPacket *p;
         size_t a;
 
         assert(ret);
+        assert(max_size >= DNS_PACKET_HEADER_SIZE);
+
+        if (max_size > DNS_PACKET_SIZE_MAX)
+                max_size = DNS_PACKET_SIZE_MAX;
+
+        /* The caller may not check what is going to be truly allocated, so do not allow to
+         * allocate a DNS packet bigger than DNS_PACKET_SIZE_MAX.
+         */
+        if (min_alloc_dsize > DNS_PACKET_SIZE_MAX) {
+                log_error("Requested packet data size too big: %zu", min_alloc_dsize);
+                return -EFBIG;
+        }
 
-        /* When dns_packet_new() is called with mtu == 0, allocate more than the
+        /* When dns_packet_new() is called with min_alloc_dsize == 0, allocate more than the
          * absolute minimum (which is the dns packet header size), to avoid
          * resizing immediately again after appending the first data to the packet.
          */
-        if (mtu < UDP_PACKET_HEADER_SIZE)
+        if (min_alloc_dsize < DNS_PACKET_HEADER_SIZE)
                 a = DNS_PACKET_SIZE_START;
         else
-                a = MAX(mtu, DNS_PACKET_HEADER_SIZE);
+                a = min_alloc_dsize;
 
         /* round up to next page size */
         a = PAGE_ALIGN(ALIGN(sizeof(DnsPacket)) + a) - ALIGN(sizeof(DnsPacket));
 
         /* make sure we never allocate more than useful */
-        if (a > DNS_PACKET_SIZE_MAX)
-                a = DNS_PACKET_SIZE_MAX;
+        if (a > max_size)
+                a = max_size;
 
         p = malloc0(ALIGN(sizeof(DnsPacket)) + a);
         if (!p)
@@ -72,6 +89,7 @@ int dns_packet_new(DnsPacket **ret, DnsProtocol protocol, size_t mtu) {
 
         p->size = p->rindex = DNS_PACKET_HEADER_SIZE;
         p->allocated = a;
+        p->max_size = max_size;
         p->protocol = protocol;
         p->opt_start = p->opt_size = (size_t) -1;
         p->n_ref = 1;
@@ -131,13 +149,13 @@ void dns_packet_set_flags(DnsPacket *p, bool dnssec_checking_disabled, bool trun
         }
 }
 
-int dns_packet_new_query(DnsPacket **ret, DnsProtocol protocol, size_t mtu, bool dnssec_checking_disabled) {
+int dns_packet_new_query(DnsPacket **ret, DnsProtocol protocol, size_t min_alloc_dsize, bool dnssec_checking_disabled) {
         DnsPacket *p;
         int r;
 
         assert(ret);
 
-        r = dns_packet_new(&p, protocol, mtu);
+        r = dns_packet_new(&p, protocol, min_alloc_dsize, DNS_PACKET_SIZE_MAX);
         if (r < 0)
                 return r;
 
@@ -306,11 +324,13 @@ static int dns_packet_extend(DnsPacket *p, size_t add, void **ret, size_t *start
         assert(p);
 
         if (p->size + add > p->allocated) {
-                size_t a;
+                size_t a, ms;
 
                 a = PAGE_ALIGN((p->size + add) * 2);
-                if (a > DNS_PACKET_SIZE_MAX)
-                        a = DNS_PACKET_SIZE_MAX;
+
+                ms = dns_packet_size_max(p);
+                if (a > ms)
+                        a = ms;
 
                 if (p->size + add > a)
                         return -EMSGSIZE;
@@ -1495,7 +1515,7 @@ static int dns_packet_read_type_window(DnsPacket *p, Bitmap **types, size_t *sta
 
                 found = true;
 
-                while (bitmask) {
+                for (; bitmask; bit++, bitmask >>= 1)
                         if (bitmap[i] & bitmask) {
                                 uint16_t n;
 
@@ -1509,10 +1529,6 @@ static int dns_packet_read_type_window(DnsPacket *p, Bitmap **types, size_t *sta
                                 if (r < 0)
                                         return r;
                         }
-
-                        bit++;
-                        bitmask >>= 1;
-                }
         }
 
         if (!found)