]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
revision of truncation handling
authorBob Halley <halley@dnspython.org>
Fri, 22 May 2020 15:54:44 +0000 (08:54 -0700)
committerBob Halley <halley@dnspython.org>
Fri, 22 May 2020 15:54:44 +0000 (08:54 -0700)
dns/message.py
dns/query.py
dns/resolver.py
dns/trio/query.py
dns/trio/resolver.py
doc/query.rst
doc/whatsnew.rst
tests/test_message.py
tests/test_query.py [new file with mode: 0644]
tests/test_trio.py

index 931a14abafadbdcd56ee8984a5b8592bbc801ba1..103bec88e8c2386a14f205b630066281cbc5fb28 100644 (file)
@@ -68,6 +68,15 @@ class UnknownTSIGKey(dns.exception.DNSException):
 class Truncated(dns.exception.DNSException):
     """The truncated flag is set."""
 
+    supp_kwargs = {'message'}
+
+    def message(self):
+        """As much of the message as could be processed.
+
+        Returns a ``dns.message.Message``.
+        """
+        return self.kwargs['message']
+
 
 #: The question section number
 QUESTION = 0
@@ -745,8 +754,6 @@ class _WireReader(object):
         self.message.original_id = self.message.id
         if dns.opcode.is_update(self.message.flags):
             self.updating = True
-        if self.message.flags & dns.flags.TC:
-            raise Truncated
         self._get_question(qcount)
         if self.question_only:
             return
@@ -763,7 +770,7 @@ class _WireReader(object):
 def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None,
               tsig_ctx=None, multi=False, first=True,
               question_only=False, one_rr_per_rrset=False,
-              ignore_trailing=False):
+              ignore_trailing=False, raise_on_truncation=False):
     """Convert a DNS wire format message into a message
     object.
 
@@ -798,6 +805,9 @@ def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None,
     *ignore_trailing*, a ``bool``.  If ``True``, ignore trailing
     junk at end of the message.
 
+    *raise_on_truncation*, a ``bool``.  If ``True``, raise an exception if
+    the TC bit is set.
+
     Raises ``dns.message.ShortHeader`` if the message is less than 12 octets
     long.
 
@@ -810,7 +820,8 @@ def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None,
     Raises ``dns.message.BadTSIG`` if a TSIG record was not the last
     record of the additional data section.
 
-    Raises ``dns.message.Truncated`` if the TC flag is set.
+    Raises ``dns.message.Truncated`` if the TC flag is set and
+    *raise_on_truncation* is ``True``.
 
     Returns a ``dns.message.Message``.
     """
@@ -826,7 +837,17 @@ def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None,
 
     reader = _WireReader(wire, m, question_only, one_rr_per_rrset,
                          ignore_trailing)
-    reader.read()
+    try:
+        reader.read()
+    except dns.exception.FormError:
+        if m.flags & dns.flags.TC and raise_on_truncation:
+            raise Truncated(message=m)
+        else:
+            raise
+    # Reading a truncated message might not have any errors, so we
+    # have to do this check here too.
+    if m.flags & dns.flags.TC and raise_on_truncation:
+        raise Truncated(message=m)
 
     return m
 
index b02d173a68bbf0c5a7790b1724dea1645bab2b6c..8f3fdab295b98f96ef571aa57dd60325de7a8ceb 100644 (file)
@@ -368,7 +368,8 @@ def send_udp(sock, what, destination, expiration=None):
 
 def receive_udp(sock, destination, expiration=None,
                 ignore_unexpected=False, one_rr_per_rrset=False,
-                keyring=None, request_mac=b'', ignore_trailing=False):
+                keyring=None, request_mac=b'', ignore_trailing=False,
+                raise_on_truncation=False):
     """Read a DNS message from a UDP socket.
 
     *sock*, a ``socket``.
@@ -393,6 +394,9 @@ def receive_udp(sock, destination, expiration=None,
     *ignore_trailing*, a ``bool``.  If ``True``, ignore trailing
     junk at end of the received message.
 
+    *raise_on_truncation*, a ``bool``.  If ``True``, raise an exception if
+    the TC bit is set.
+
     Raises if the message is malformed, if network errors occur, of if
     there is a timeout.
 
@@ -414,11 +418,13 @@ def receive_udp(sock, destination, expiration=None,
     received_time = time.time()
     r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac,
                               one_rr_per_rrset=one_rr_per_rrset,
-                              ignore_trailing=ignore_trailing)
+                              ignore_trailing=ignore_trailing,
+                              raise_on_truncation=raise_on_truncation)
     return (r, received_time)
 
 def udp(q, where, timeout=None, port=53, source=None, source_port=0,
-        ignore_unexpected=False, one_rr_per_rrset=False, ignore_trailing=False):
+        ignore_unexpected=False, one_rr_per_rrset=False, ignore_trailing=False,
+        raise_on_truncation=False):
     """Return the response obtained after sending a query via UDP.
 
     *q*, a ``dns.message.Message``, the query to send
@@ -446,6 +452,9 @@ def udp(q, where, timeout=None, port=53, source=None, source_port=0,
     *ignore_trailing*, a ``bool``.  If ``True``, ignore trailing
     junk at end of the received message.
 
+    *raise_on_truncation*, a ``bool``.  If ``True``, raise an exception if
+    the TC bit is set.
+
     Returns a ``dns.message.Message``.
     """
 
@@ -460,12 +469,56 @@ def udp(q, where, timeout=None, port=53, source=None, source_port=0,
         (_, sent_time) = send_udp(s, wire, destination, expiration)
         (r, received_time) = receive_udp(s, destination, expiration,
                                          ignore_unexpected, one_rr_per_rrset,
-                                         q.keyring, q.mac, ignore_trailing)
+                                         q.keyring, q.mac, ignore_trailing,
+                                         raise_on_truncation)
         r.time = received_time - sent_time
         if not q.is_response(r):
             raise BadResponse
         return r
 
+def udp_with_fallback(q, where, timeout=None, port=53, source=None,
+                      source_port=0, ignore_unexpected=False,
+                      one_rr_per_rrset=False, ignore_trailing=False):
+    """Return the response to the query, trying UDP first and falling back
+    to TCP if UDP results in a truncated response.
+
+    *q*, a ``dns.message.Message``, the query to send
+
+    *where*, a ``str`` containing an IPv4 or IPv6 address,  where
+    to send the message.
+
+    *timeout*, a ``float`` or ``None``, the number of seconds to wait before the
+    query times out.  If ``None``, the default, wait forever.
+
+    *port*, an ``int``, the port send the message to.  The default is 53.
+
+    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
+    the source address.  The default is the wildcard address.
+
+    *source_port*, an ``int``, the port from which to send the message.
+    The default is 0.
+
+    *ignore_unexpected*, a ``bool``.  If ``True``, ignore responses from
+    unexpected sources.
+
+    *one_rr_per_rrset*, a ``bool``.  If ``True``, put each RR into its own
+    RRset.
+
+    *ignore_trailing*, a ``bool``.  If ``True``, ignore trailing
+    junk at end of the received message.
+
+    Returns a (``dns.message.Message``, tcp) tuple where tcp is ``True``
+    if and only if TCP was used.
+    """
+    try:
+        response = udp(q, where, timeout, port, source, source_port,
+                       ignore_unexpected, one_rr_per_rrset,
+                       ignore_trailing, True)
+        return (response, False)
+    except dns.message.Truncated:
+        response = tcp(q, where, timeout, port, source, source_port,
+                       one_rr_per_rrset, ignore_trailing)
+        return (response, True)
 
 def _net_read(sock, count, expiration):
     """Read the specified number of bytes from sock.  Keep trying until we
index cc1f78b305ff4ec1564341214212da27a84e9aa3..4339a06b39acd2cb03c768b2ec7c09c1448db683 100644 (file)
@@ -1076,7 +1076,8 @@ class Resolver(object):
                                                      timeout=timeout,
                                                      port=port,
                                                      source=source,
-                                                     source_port=source_port)
+                                                     source_port=source_port,
+                                                     raise_on_truncation=True)
                     else:
                         protocol = urlparse(nameserver).scheme
                         if protocol == 'https':
index e7e79a523ed4ca4e0c955ca646675530a3e310dd..9dcbbacc09b663a1114e2cb15fc1dc1478f7c99b 100644 (file)
@@ -46,7 +46,7 @@ async def send_udp(sock, what, destination):
 
 async def receive_udp(sock, destination, ignore_unexpected=False,
                       one_rr_per_rrset=False, keyring=None, request_mac=b'',
-                      ignore_trailing=False):
+                      ignore_trailing=False, raise_on_truncation=False):
     """Asynchronously read a DNS message from a UDP socket.
 
     *sock*, a ``trio.socket``.
@@ -67,6 +67,9 @@ async def receive_udp(sock, destination, ignore_unexpected=False,
     *ignore_trailing*, a ``bool``.  If ``True``, ignore trailing
     junk at end of the received message.
 
+    *raise_on_truncation*, a ``bool``.  If ``True``, raise an exception if
+    the TC bit is set.
+
     Raises if the message is malformed, if network errors occur, of if
     there is a timeout.
 
@@ -88,12 +91,13 @@ async def receive_udp(sock, destination, ignore_unexpected=False,
     received_time = time.time()
     r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac,
                               one_rr_per_rrset=one_rr_per_rrset,
-                              ignore_trailing=ignore_trailing)
+                              ignore_trailing=ignore_trailing,
+                              raise_on_truncation=raise_on_truncation)
     return (r, received_time)
 
 async def udp(q, where, port=53, source=None, source_port=0,
               ignore_unexpected=False, one_rr_per_rrset=False,
-              ignore_trailing=False):
+              ignore_trailing=False, raise_on_truncation=False):
     """Asynchronously return the response obtained after sending a query
     via UDP.
 
@@ -119,6 +123,9 @@ async def udp(q, where, port=53, source=None, source_port=0,
     *ignore_trailing*, a ``bool``.  If ``True``, ignore trailing
     junk at end of the received message.
 
+    *raise_on_truncation*, a ``bool``.  If ``True``, raise an exception if
+    the TC bit is set.
+
     Returns a ``dns.message.Message``.
     """
 
@@ -135,12 +142,55 @@ async def udp(q, where, port=53, source=None, source_port=0,
         (r, received_time) = await receive_udp(s, destination,
                                                ignore_unexpected,
                                                one_rr_per_rrset, q.keyring,
-                                               q.mac, ignore_trailing)
+                                               q.mac, ignore_trailing,
+                                               raise_on_truncation)
         if not q.is_response(r):
             raise BadResponse
         r.time = received_time - sent_time
         return r
 
+async def udp_with_fallback(q, where, timeout=None, port=53, source=None,
+                            source_port=0, ignore_unexpected=False,
+                            one_rr_per_rrset=False, ignore_trailing=False):
+    """Return the response to the query, trying UDP first and falling back
+    to TCP if UDP results in a truncated response.
+
+    *q*, a ``dns.message.Message``, the query to send
+
+    *where*, a ``str`` containing an IPv4 or IPv6 address,  where
+    to send the message.
+
+    *port*, an ``int``, the port send the message to.  The default is 53.
+
+    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
+    the source address.  The default is the wildcard address.
+
+    *source_port*, an ``int``, the port from which to send the message.
+    The default is 0.
+
+    *ignore_unexpected*, a ``bool``.  If ``True``, ignore responses from
+    unexpected sources.
+
+    *one_rr_per_rrset*, a ``bool``.  If ``True``, put each RR into its own
+    RRset.
+
+    *ignore_trailing*, a ``bool``.  If ``True``, ignore trailing
+    junk at end of the received message.
+
+    Returns a (``dns.message.Message``, tcp) tuple where tcp is ``True``
+    if and only if TCP was used.
+    """
+    try:
+        response = await udp(q, where, port, source, source_port,
+                             ignore_unexpected, one_rr_per_rrset,
+                             ignore_trailing, True)
+        return (response, False)
+    except dns.message.Truncated:
+        response = await stream(q, where, False, port, source, source_port,
+                                one_rr_per_rrset, ignore_trailing)
+
+        return (response, True)
+
 # pylint: disable=redefined-outer-name
 
 async def send_stream(stream, what):
index 6a4454580bfd40e18f9836a3f8bfeaabdeb0aa4d..d2bc12be60a8425f3788e0d040081e0ddd1b5960 100644 (file)
@@ -114,7 +114,8 @@ class Resolver(dns.resolver.Resolver):
                                          nameserver,
                                          port=port,
                                          source=source,
-                                         source_port=source_port)
+                                         source_port=source_port,
+                                         raise_on_truncation=True)
                         else:
                             # We don't do DoH yet.
                             raise NotImplementedError
index 55aa5ab534e944f02ae74b1ee02cbb50ed413c16..08940b43e3cd76feddff713e6fe742cd40f317a6 100644 (file)
@@ -17,6 +17,7 @@ UDP
 ---
 
 .. autofunction:: dns.query.udp
+.. autofunction:: dns.query.udp_with_fallback
 .. autofunction:: dns.query.send_udp
 .. autofunction:: dns.query.receive_udp
 
index 4dc8047127fb87dca83edfe4b7a006e6cf28d88c..f2e7adf39dd870ce254246fab17ea65855b0d487 100644 (file)
@@ -57,9 +57,6 @@ What's New in dnspython 2.0.0
 
 * The NINFO record is supported.
 
-* When decoding from wire format, if a message as TC (truncated) set,
-  a ``Truncated`` exception is now raised.
-
 * The ``dns.hash`` module has been removed; just use Python's native
   ``hashlib`` module.
 
index 166d7e8643c189c90e2ba0d1156863cff36716d5..55ae6fd050a80d3c376d40de3fa38716ec84ea06 100644 (file)
@@ -225,7 +225,7 @@ class MessageTestCase(unittest.TestCase):
             a = dns.message.from_text(answer_text)
             a.flags |= dns.flags.TC
             wire = a.to_wire(want_shuffle=False)
-            dns.message.from_wire(wire)
+            dns.message.from_wire(wire, raise_on_truncation=True)
         self.assertRaises(dns.message.Truncated, bad)
 
     def test_MessyTruncated(self):
@@ -233,7 +233,7 @@ class MessageTestCase(unittest.TestCase):
             a = dns.message.from_text(answer_text)
             a.flags |= dns.flags.TC
             wire = a.to_wire(want_shuffle=False)
-            dns.message.from_wire(wire[:-3])
+            dns.message.from_wire(wire[:-3], raise_on_truncation=True)
         self.assertRaises(dns.message.Truncated, bad)
 
     def test_IDNA_2003(self):
diff --git a/tests/test_query.py b/tests/test_query.py
new file mode 100644 (file)
index 0000000..9c63217
--- /dev/null
@@ -0,0 +1,81 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+import socket
+import unittest
+
+import dns.message
+import dns.name
+import dns.rdataclass
+import dns.rdatatype
+import dns.query
+
+# Some tests require the internet to be available to run, so let's
+# skip those if it's not there.
+_network_available = True
+try:
+    socket.gethostbyname('dnspython.org')
+except socket.gaierror:
+    _network_available = False
+
+@unittest.skipIf(not _network_available, "Internet not reachable")
+class QueryTests(unittest.TestCase):
+
+    def testQueryUDP(self):
+        qname = dns.name.from_text('dns.google.')
+        q = dns.message.make_query(qname, dns.rdatatype.A)
+        response = dns.query.udp(q, '8.8.8.8')
+        rrs = response.get_rrset(response.answer, qname,
+                                 dns.rdataclass.IN, dns.rdatatype.A)
+        self.assertTrue(rrs is not None)
+        seen = set([rdata.address for rdata in rrs])
+        self.assertTrue('8.8.8.8' in seen)
+        self.assertTrue('8.8.4.4' in seen)
+
+    def testQueryTCP(self):
+        qname = dns.name.from_text('dns.google.')
+        q = dns.message.make_query(qname, dns.rdatatype.A)
+        response = dns.query.tcp(q, '8.8.8.8')
+        rrs = response.get_rrset(response.answer, qname,
+                                 dns.rdataclass.IN, dns.rdatatype.A)
+        self.assertTrue(rrs is not None)
+        seen = set([rdata.address for rdata in rrs])
+        self.assertTrue('8.8.8.8' in seen)
+        self.assertTrue('8.8.4.4' in seen)
+
+    def testQueryTLS(self):
+        qname = dns.name.from_text('dns.google.')
+        q = dns.message.make_query(qname, dns.rdatatype.A)
+        response = dns.query.tls(q, '8.8.8.8')
+        rrs = response.get_rrset(response.answer, qname,
+                                 dns.rdataclass.IN, dns.rdatatype.A)
+        self.assertTrue(rrs is not None)
+        seen = set([rdata.address for rdata in rrs])
+        self.assertTrue('8.8.8.8' in seen)
+        self.assertTrue('8.8.4.4' in seen)
+
+    def testQueryUDPFallback(self):
+        qname = dns.name.from_text('.')
+        q = dns.message.make_query(qname, dns.rdatatype.DNSKEY)
+        (_, tcp) = dns.query.udp_with_fallback(q, '8.8.8.8')
+        self.assertTrue(tcp)
+
+    def testQueryUDPFallbackNoFallback(self):
+        qname = dns.name.from_text('dns.google.')
+        q = dns.message.make_query(qname, dns.rdatatype.A)
+        (_, tcp) = dns.query.udp_with_fallback(q, '8.8.8.8')
+        self.assertFalse(tcp)
index 961ebda45b0be9c9a7822c1081a859dd2841e59a..d519844d7306d9cee143ac1003491c0e28166a08 100644 (file)
@@ -125,6 +125,21 @@ try:
             self.assertTrue('8.8.8.8' in seen)
             self.assertTrue('8.8.4.4' in seen)
 
+        def testQueryUDPFallback(self):
+            qname = dns.name.from_text('.')
+            async def run():
+                q = dns.message.make_query(qname, dns.rdatatype.DNSKEY)
+                return await dns.trio.query.udp_with_fallback(q, '8.8.8.8')
+            (_, tcp) = trio.run(run)
+            self.assertTrue(tcp)
+
+        def testQueryUDPFallbackNoFallback(self):
+            qname = dns.name.from_text('dns.google.')
+            async def run():
+                q = dns.message.make_query(qname, dns.rdatatype.A)
+                return await dns.trio.query.udp_with_fallback(q, '8.8.8.8')
+            (_, tcp) = trio.run(run)
+            self.assertFalse(tcp)
 
 except ModuleNotFoundError:
     pass