]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Add EDNS padding. 797/head
authorBob Halley <halley@dnspython.org>
Sun, 20 Mar 2022 19:19:54 +0000 (12:19 -0700)
committerBob Halley <halley@dnspython.org>
Wed, 23 Mar 2022 12:47:26 +0000 (05:47 -0700)
dns/message.py
dns/query.py
dns/renderer.py
dns/tsig.py
tests/test_message.py
tests/test_renderer.py

index 9541df68e54dc8854e0952fbef8dcbc060bc0c43..507a5b08eeb37153ee82be2d2f9b5d2532804935 100644 (file)
@@ -152,6 +152,7 @@ class Message:
         self.sections: List[List[dns.rrset.RRset]] = [[], [], [], []]
         self.opt: Optional[dns.rrset.RRset] = None
         self.request_payload = 0
+        self.pad = 0
         self.keyring: Any = None
         self.tsig: Optional[dns.rrset.RRset] = None
         self.request_mac = b""
@@ -460,6 +461,38 @@ class Message:
             rrset = None
         return rrset
 
+    def _compute_opt_reserve(self) -> int:
+        """Compute the size required for the OPT RR, padding excluded"""
+        if not self.opt:
+            return 0
+        # 1 byte for the root name, 10 for the standard RR fields
+        size = 11
+        # This would be more efficient if options had a size() method, but we won't
+        # worry about that for now.  We also don't worry if there is an existing padding
+        # option, as it is unlikely and probably harmless, as the worst case is that we
+        # may add another, and this seems to be legal.
+        for option in self.opt[0].options:
+            wire = option.to_wire()
+            # We add 4 here to account for the option type and length
+            size += len(wire) + 4
+        if self.pad:
+            # Padding will be added, so again add the option type and length.
+            size += 4
+        return size
+
+    def _compute_tsig_reserve(self) -> int:
+        """Compute the size required for the TSIG RR"""
+        # This would be more efficient if TSIGs had a size method, but we won't
+        # worry about for now.  Also, we can't really cope with the potential
+        # compressibility of the TSIG owner name, so we estimate with the uncompressed
+        # size.  We will disable compression when TSIG and padding are both is active
+        # so that the padding comes out right.
+        if not self.tsig:
+            return 0
+        f = io.BytesIO()
+        self.tsig.to_wire(f)
+        return len(f.getvalue())
+
     def to_wire(
         self,
         origin: Optional[dns.name.Name] = None,
@@ -505,16 +538,21 @@ class Message:
         elif max_size > 65535:
             max_size = 65535
         r = dns.renderer.Renderer(self.id, self.flags, max_size, origin)
+        opt_reserve = self._compute_opt_reserve()
+        r.reserve(opt_reserve)
+        tsig_reserve = self._compute_tsig_reserve()
+        r.reserve(tsig_reserve)
         for rrset in self.question:
             r.add_question(rrset.name, rrset.rdtype, rrset.rdclass)
         for rrset in self.answer:
             r.add_rrset(dns.renderer.ANSWER, rrset, **kw)
         for rrset in self.authority:
             r.add_rrset(dns.renderer.AUTHORITY, rrset, **kw)
-        if self.opt is not None:
-            r.add_rrset(dns.renderer.ADDITIONAL, self.opt)
         for rrset in self.additional:
             r.add_rrset(dns.renderer.ADDITIONAL, rrset, **kw)
+        r.release_reserved()
+        if self.opt is not None:
+            r.add_opt(self.opt, self.pad, opt_reserve, tsig_reserve)
         r.write_header()
         if self.tsig is not None:
             (new_tsig, ctx) = dns.tsig.sign(
@@ -619,7 +657,7 @@ class Message:
             self.keyring.algorithm,
             0,
             fudge,
-            b"",
+            b"\x00" * dns.tsig.mac_sizes[self.keyring.algorithm],
             original_id,
             tsig_error,
             other_data,
@@ -669,26 +707,29 @@ class Message:
         payload: int = DEFAULT_EDNS_PAYLOAD,
         request_payload: Optional[int] = None,
         options: Optional[List[dns.edns.Option]] = None,
+        pad: int = 0,
     ) -> None:
         """Configure EDNS behavior.
 
-        *edns*, an ``int``, is the EDNS level to use.  Specifying
-        ``None``, ``False``, or ``-1`` means "do not use EDNS", and in this case
-        the other parameters are ignored.  Specifying ``True`` is
-        equivalent to specifying 0, i.e. "use EDNS0".
+        *edns*, an ``int``, is the EDNS level to use.  Specifying ``None``, ``False``,
+        or ``-1`` means "do not use EDNS", and in this case the other parameters are
+        ignored.  Specifying ``True`` is equivalent to specifying 0, i.e. "use EDNS0".
 
         *ednsflags*, an ``int``, the EDNS flag values.
 
-        *payload*, an ``int``, is the EDNS sender's payload field, which is the
-        maximum size of UDP datagram the sender can handle.  I.e. how big
-        a response to this message can be.
+        *payload*, an ``int``, is the EDNS sender's payload field, which is the maximum
+        size of UDP datagram the sender can handle.  I.e. how big a response to this
+        message can be.
+
+        *request_payload*, an ``int``, is the EDNS payload size to use when sending this
+        message.  If not specified, defaults to the value of *payload*.
 
-        *request_payload*, an ``int``, is the EDNS payload size to use when
-        sending this message.  If not specified, defaults to the value of
-        *payload*.
+        *options*, a list of ``dns.edns.Option`` objects or ``None``, the EDNS options.
 
-        *options*, a list of ``dns.edns.Option`` objects or ``None``, the EDNS
-        options.
+        *pad*, a non-negative ``int``.  If 0, the default, do not pad; otherwise add
+        padding bytes to make the message size a multiple of *pad*.  Note that if
+        padding is non-zero, an EDNS PADDING option will always be added to the
+        message.
         """
 
         if edns is None or edns is False:
@@ -708,6 +749,7 @@ class Message:
             if request_payload is None:
                 request_payload = payload
             self.request_payload = request_payload
+            self.pad = pad
 
     @property
     def edns(self) -> int:
@@ -1607,6 +1649,7 @@ def make_query(
     idna_codec: Optional[dns.name.IDNACodec] = None,
     id: Optional[int] = None,
     flags: int = dns.flags.RD,
+    pad: int = 0,
 ) -> QueryMessage:
     """Make a query message.
 
@@ -1655,6 +1698,11 @@ def make_query(
     *flags*, an ``int``, the desired query flags.  The default is
     ``dns.flags.RD``.
 
+    *pad*, a non-negative ``int``.  If 0, the default, do not pad; otherwise add
+    padding bytes to make the message size a multiple of *pad*.  Note that if
+    padding is non-zero, an EDNS PADDING option will always be added to the
+    message.
+
     Returns a ``dns.message.QueryMessage``
     """
 
@@ -1682,6 +1730,7 @@ def make_query(
     if kwargs and use_edns is None:
         use_edns = 0
     kwargs["edns"] = use_edns
+    kwargs["pad"] = pad
     m.use_edns(**kwargs)
     m.want_dnssec(want_dnssec)
     return m
index 7dec23e34f7efc8f0d0e7b470047e2938e5de8cd..9d069bb4e2255678ab9f9b8386ad46d3844d26b0 100644 (file)
@@ -288,37 +288,36 @@ def https(
 
     *q*, a ``dns.message.Message``, the query to send.
 
-    *where*, a ``str``, the nameserver IP address or the full URL. If an IP
-    address is given, the URL will be constructed using the following schema:
+    *where*, a ``str``, the nameserver IP address or the full URL. If an IP address is
+    given, the URL will be constructed using the following schema:
     https://<IP-address>:<port>/<path>.
 
-    *timeout*, a ``float`` or ``None``, the number of seconds to
-    wait before the query times out. If ``None``, the default, wait forever.
+    *timeout*, a ``float`` or ``None``, the number of seconds to wait before the query
+    times out. If ``None``, the default, wait forever.
 
     *port*, a ``int``, the port to send the query to. The default is 443.
 
-    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
-    the source address.  The default is the wildcard address.
+    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying the source
+    address.  The default is the wildcard address.
 
-    *source_port*, an ``int``, the port from which to send the message.
-    The default is 0.
+    *source_port*, an ``int``, the port from which to send the message. The default is
+    0.
 
-    *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
-    RRset.
+    *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own RRset.
 
-    *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
-    junk at end of the received message.
+    *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the
+    received message.
 
-    *session*, an ``httpx.Client`` or ``requests.session.Session``.  If
-    provided, the client/session to use to send the queries.
+    *session*, an ``httpx.Client`` or ``requests.session.Session``.  If provided, the
+    client/session to use to send the queries.
 
     *path*, a ``str``. If *where* is an IP address, then *path* will be used to
     construct the URL to send the DNS query to.
 
     *post*, a ``bool``. If ``True``, the default, POST method will be used.
 
-    *bootstrap_address*, a ``str``, the IP address to use to bypass the
-    system's DNS resolver.
+    *bootstrap_address*, a ``str``, the IP address to use to bypass the system's DNS
+    resolver.
 
     *verify*, a ``str``, containing a path to a certificate file or directory.
 
index 95e8bd3ad48fec1c6359631d8872c9e84338ae97..63fb3fac5b7860467cf11099b281856b5c08ef72 100644 (file)
@@ -48,13 +48,17 @@ class Renderer:
         r.add_rrset(dns.renderer.ANSWER, rrset_1)
         r.add_rrset(dns.renderer.ANSWER, rrset_2)
         r.add_rrset(dns.renderer.AUTHORITY, ns_rrset)
-        r.add_edns(0, 0, 4096)
         r.add_rrset(dns.renderer.ADDITIONAL, ad_rrset_1)
         r.add_rrset(dns.renderer.ADDITIONAL, ad_rrset_2)
+        r.add_edns(0, 0, 4096)
         r.write_header()
         r.add_tsig(keyname, secret, 300, 1, 0, '', request_mac)
         wire = r.get_wire()
 
+    If padding is going to be used, then the OPT record MUST be
+    written after everything else in the additional section except for
+    the TSIG (if any).
+
     output, an io.BytesIO, where rendering is written
 
     id: the message id
@@ -90,6 +94,8 @@ class Renderer:
         self.counts = [0, 0, 0, 0]
         self.output.write(b"\x00" * 12)
         self.mac = ""
+        self.reserved = 0
+        self.was_padded = False
 
     def _rollback(self, where):
         """Truncate the output buffer at offset *where*, and remove any
@@ -163,6 +169,29 @@ class Renderer:
             n = rdataset.to_wire(name, self.output, self.compress, self.origin, **kw)
         self.counts[section] += n
 
+    def add_opt(self, opt, pad=0, opt_size=0, tsig_size=0):
+        """Add *opt* to the additional section, applying padding if desired.  The
+        padding will take the specified precomputed OPT size and TSIG size into
+        account.
+
+        Note that we don't have reliable way of knowing how big a GSS-TSIG digest
+        might be, so we we might not get an even multiple of the pad in that case."""
+        if pad:
+            ttl = opt.ttl
+            assert opt_size >= 11
+            opt_rdata = opt[0]
+            size_without_padding = self.output.tell() + opt_size + tsig_size
+            remainder = size_without_padding % pad
+            if remainder:
+                pad = b"\x00" * (pad - remainder)
+            else:
+                pad = b""
+            options = list(opt_rdata.options)
+            options.append(dns.edns.GenericOption(dns.edns.OptionType.PADDING, pad))
+            opt = dns.message.Message._make_opt(ttl, opt_rdata.rdclass, options)
+            self.was_padded = True
+        self.add_rrset(ADDITIONAL, opt)
+
     def add_edns(self, edns, ednsflags, payload, options=None):
         """Add an EDNS OPT record to the message."""
 
@@ -170,7 +199,7 @@ class Renderer:
         ednsflags &= 0xFF00FFFF
         ednsflags |= edns << 16
         opt = dns.message.Message._make_opt(ednsflags, payload, options)
-        self.add_rrset(ADDITIONAL, opt)
+        self.add_opt(opt)
 
     def add_tsig(
         self,
@@ -233,9 +262,13 @@ class Renderer:
         return ctx
 
     def _write_tsig(self, tsig, keyname):
+        if self.was_padded:
+            compress = None
+        else:
+            compress = self.compress
         self._set_section(ADDITIONAL)
         with self._track_size():
-            keyname.to_wire(self.output, self.compress, self.origin)
+            keyname.to_wire(self.output, compress, self.origin)
             self.output.write(
                 struct.pack("!HHIH", dns.rdatatype.TSIG, dns.rdataclass.ANY, 0, 0)
             )
@@ -276,3 +309,17 @@ class Renderer:
         """Return the wire format message."""
 
         return self.output.getvalue()
+
+    def reserve(self, size: int) -> None:
+        """Reserve *size* bytes."""
+        if size < 0:
+            raise ValueError(f"reserved amount must be non-negative")
+        if size > self.max_size:
+            raise ValueError(f"cannot reserve more than the maximum size")
+        self.reserved += size
+        self.max_size -= size
+
+    def release_reserved(self) -> None:
+        """Release the reserved bytes."""
+        self.max_size += self.reserved
+        self.reserved = 0
index b3f525169571edfee08adc3f86d35db1419b24cf..2476fdfb67390c4b9805749195c1b45c15560400 100644 (file)
@@ -88,6 +88,19 @@ GSS_TSIG = dns.name.from_text("gss-tsig")
 
 default_algorithm = HMAC_SHA256
 
+mac_sizes = {
+    HMAC_SHA1: 20,
+    HMAC_SHA224: 28,
+    HMAC_SHA256: 32,
+    HMAC_SHA256_128: 16,
+    HMAC_SHA384: 48,
+    HMAC_SHA384_192: 24,
+    HMAC_SHA512: 64,
+    HMAC_SHA512_256: 32,
+    HMAC_MD5: 16,
+    GSS_TSIG: 128,  # This is what we assume to be the worst case!
+}
+
 
 class GSSTSig:
     """
index d2f5b0ea16f8855d4890f64479925d96defdd740..29632833c2bc0d600234cae4edb8ce2e9739f5c4 100644 (file)
@@ -20,6 +20,7 @@ import unittest
 import binascii
 
 import dns.exception
+import dns.edns
 import dns.flags
 import dns.message
 import dns.name
@@ -30,6 +31,7 @@ import dns.tsig
 import dns.update
 import dns.rdtypes.ANY.OPT
 import dns.rdtypes.ANY.TSIG
+import dns.tsigkeyring
 
 from tests.util import here
 
@@ -810,6 +812,56 @@ www.dnspython.org. 300 IN A 1.2.3.4
         )
         self.assertEqual(m, expected_message)
 
+    def test_padding_basic(self):
+        q = dns.message.make_query("www.example", "a", use_edns=0, pad=0)
+        w = q.to_wire()
+        self.assertEqual(len(w), 40)
+        q = dns.message.make_query("www.example", "a", use_edns=0, pad=128)
+        w = q.to_wire()
+        self.assertEqual(len(w), 128)
+        q2 = dns.message.from_wire(w)
+        self.assertEqual(q, q2)
+
+    def test_padding_various(self):
+        q = dns.message.make_query("www.example", "a", use_edns=0, pad=1)
+        w = q.to_wire()
+        self.assertEqual(len(w), 44)
+        q = dns.message.make_query("www.example", "a", use_edns=0, pad=2)
+        w = q.to_wire()
+        self.assertEqual(len(w), 44)
+        q = dns.message.make_query("www.example", "a", use_edns=0, pad=3)
+        w = q.to_wire()
+        self.assertEqual(len(w), 45)
+        q = dns.message.make_query("www.example", "a", use_edns=0, pad=44)
+        w = q.to_wire()
+        self.assertEqual(len(w), 44)
+        q = dns.message.make_query("www.example", "a", use_edns=0, pad=67)
+        w = q.to_wire()
+        self.assertEqual(len(w), 67)
+
+    def test_padding_with_option(self):
+        options = [dns.edns.ECSOption("1.2.3.0", 24)]
+        q = dns.message.make_query(
+            "www.example", "a", use_edns=0, pad=128, options=options
+        )
+        w = q.to_wire()
+        self.assertEqual(len(w), 128)
+        q2 = dns.message.from_wire(w)
+        self.assertEqual(q, q2)
+
+    def test_padding_with_tsig_and_option(self):
+        keyring = dns.tsigkeyring.from_text({"keyname.": "NjHwPsMKjdN++dOfE5iAiQ=="})
+        options = [dns.edns.ECSOption("1.2.3.0", 24)]
+        q = dns.message.make_query(
+            "www.example", "a", use_edns=0, options=options, pad=128
+        )
+        q.use_tsig(keyring)
+        w = q.to_wire()
+        self.assertEqual(len(w), 256)
+        q2 = dns.message.from_wire(w, keyring=keyring)
+        self.assertIsNotNone(q2.tsig)
+        self.assertEqual(q, q2)
+
 
 if __name__ == "__main__":
     unittest.main()
index ca5a85e6c4f6f8ee0c4870692e53ccc78068c951..1682fd1722c04ddcbb20eccb673ae2bc5515b48a 100644 (file)
@@ -103,3 +103,14 @@ class RendererTestCase(unittest.TestCase):
             r.add_rdataset(dns.renderer.ANSWER, qname, rds)
 
         self.assertRaises(dns.exception.FormError, bad)
+
+    def test_reservation(self):
+        r = dns.renderer.Renderer(flags=dns.flags.QR, max_size=512)
+        r.reserve(100)
+        assert r.max_size == 412
+        r.release_reserved()
+        assert r.max_size == 512
+        with self.assertRaises(ValueError):
+            r.reserve(-1)
+        with self.assertRaises(ValueError):
+            r.reserve(513)