From 1c796045eda93992a319432e792d4959479f9192 Mon Sep 17 00:00:00 2001 From: Bob Halley Date: Fri, 22 May 2020 08:54:44 -0700 Subject: [PATCH] revision of truncation handling --- dns/message.py | 31 ++++++++++++++--- dns/query.py | 61 +++++++++++++++++++++++++++++--- dns/resolver.py | 3 +- dns/trio/query.py | 58 ++++++++++++++++++++++++++++--- dns/trio/resolver.py | 3 +- doc/query.rst | 1 + doc/whatsnew.rst | 3 -- tests/test_message.py | 4 +-- tests/test_query.py | 81 +++++++++++++++++++++++++++++++++++++++++++ tests/test_trio.py | 15 ++++++++ 10 files changed, 240 insertions(+), 20 deletions(-) create mode 100644 tests/test_query.py diff --git a/dns/message.py b/dns/message.py index 931a14ab..103bec88 100644 --- a/dns/message.py +++ b/dns/message.py @@ -68,6 +68,15 @@ class UnknownTSIGKey(dns.exception.DNSException): class Truncated(dns.exception.DNSException): """The truncated flag is set.""" + supp_kwargs = {'message'} + + def message(self): + """As much of the message as could be processed. + + Returns a ``dns.message.Message``. + """ + return self.kwargs['message'] + #: The question section number QUESTION = 0 @@ -745,8 +754,6 @@ class _WireReader(object): self.message.original_id = self.message.id if dns.opcode.is_update(self.message.flags): self.updating = True - if self.message.flags & dns.flags.TC: - raise Truncated self._get_question(qcount) if self.question_only: return @@ -763,7 +770,7 @@ class _WireReader(object): def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None, tsig_ctx=None, multi=False, first=True, question_only=False, one_rr_per_rrset=False, - ignore_trailing=False): + ignore_trailing=False, raise_on_truncation=False): """Convert a DNS wire format message into a message object. @@ -798,6 +805,9 @@ def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None, *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the message. + *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if + the TC bit is set. + Raises ``dns.message.ShortHeader`` if the message is less than 12 octets long. @@ -810,7 +820,8 @@ def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None, Raises ``dns.message.BadTSIG`` if a TSIG record was not the last record of the additional data section. - Raises ``dns.message.Truncated`` if the TC flag is set. + Raises ``dns.message.Truncated`` if the TC flag is set and + *raise_on_truncation* is ``True``. Returns a ``dns.message.Message``. """ @@ -826,7 +837,17 @@ def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None, reader = _WireReader(wire, m, question_only, one_rr_per_rrset, ignore_trailing) - reader.read() + try: + reader.read() + except dns.exception.FormError: + if m.flags & dns.flags.TC and raise_on_truncation: + raise Truncated(message=m) + else: + raise + # Reading a truncated message might not have any errors, so we + # have to do this check here too. + if m.flags & dns.flags.TC and raise_on_truncation: + raise Truncated(message=m) return m diff --git a/dns/query.py b/dns/query.py index b02d173a..8f3fdab2 100644 --- a/dns/query.py +++ b/dns/query.py @@ -368,7 +368,8 @@ def send_udp(sock, what, destination, expiration=None): def receive_udp(sock, destination, expiration=None, ignore_unexpected=False, one_rr_per_rrset=False, - keyring=None, request_mac=b'', ignore_trailing=False): + keyring=None, request_mac=b'', ignore_trailing=False, + raise_on_truncation=False): """Read a DNS message from a UDP socket. *sock*, a ``socket``. @@ -393,6 +394,9 @@ def receive_udp(sock, destination, expiration=None, *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the received message. + *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if + the TC bit is set. + Raises if the message is malformed, if network errors occur, of if there is a timeout. @@ -414,11 +418,13 @@ def receive_udp(sock, destination, expiration=None, received_time = time.time() r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac, one_rr_per_rrset=one_rr_per_rrset, - ignore_trailing=ignore_trailing) + ignore_trailing=ignore_trailing, + raise_on_truncation=raise_on_truncation) return (r, received_time) def udp(q, where, timeout=None, port=53, source=None, source_port=0, - ignore_unexpected=False, one_rr_per_rrset=False, ignore_trailing=False): + ignore_unexpected=False, one_rr_per_rrset=False, ignore_trailing=False, + raise_on_truncation=False): """Return the response obtained after sending a query via UDP. *q*, a ``dns.message.Message``, the query to send @@ -446,6 +452,9 @@ def udp(q, where, timeout=None, port=53, source=None, source_port=0, *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the received message. + *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if + the TC bit is set. + Returns a ``dns.message.Message``. """ @@ -460,12 +469,56 @@ def udp(q, where, timeout=None, port=53, source=None, source_port=0, (_, sent_time) = send_udp(s, wire, destination, expiration) (r, received_time) = receive_udp(s, destination, expiration, ignore_unexpected, one_rr_per_rrset, - q.keyring, q.mac, ignore_trailing) + q.keyring, q.mac, ignore_trailing, + raise_on_truncation) r.time = received_time - sent_time if not q.is_response(r): raise BadResponse return r +def udp_with_fallback(q, where, timeout=None, port=53, source=None, + source_port=0, ignore_unexpected=False, + one_rr_per_rrset=False, ignore_trailing=False): + """Return the response to the query, trying UDP first and falling back + to TCP if UDP results in a truncated response. + + *q*, a ``dns.message.Message``, the query to send + + *where*, a ``str`` containing an IPv4 or IPv6 address, where + to send the message. + + *timeout*, a ``float`` or ``None``, the number of seconds to wait before the + query times out. If ``None``, the default, wait forever. + + *port*, an ``int``, the port send the message to. The default is 53. + + *source*, a ``str`` containing an IPv4 or IPv6 address, specifying + the source address. The default is the wildcard address. + + *source_port*, an ``int``, the port from which to send the message. + The default is 0. + + *ignore_unexpected*, a ``bool``. If ``True``, ignore responses from + unexpected sources. + + *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own + RRset. + + *ignore_trailing*, a ``bool``. If ``True``, ignore trailing + junk at end of the received message. + + Returns a (``dns.message.Message``, tcp) tuple where tcp is ``True`` + if and only if TCP was used. + """ + try: + response = udp(q, where, timeout, port, source, source_port, + ignore_unexpected, one_rr_per_rrset, + ignore_trailing, True) + return (response, False) + except dns.message.Truncated: + response = tcp(q, where, timeout, port, source, source_port, + one_rr_per_rrset, ignore_trailing) + return (response, True) def _net_read(sock, count, expiration): """Read the specified number of bytes from sock. Keep trying until we diff --git a/dns/resolver.py b/dns/resolver.py index cc1f78b3..4339a06b 100644 --- a/dns/resolver.py +++ b/dns/resolver.py @@ -1076,7 +1076,8 @@ class Resolver(object): timeout=timeout, port=port, source=source, - source_port=source_port) + source_port=source_port, + raise_on_truncation=True) else: protocol = urlparse(nameserver).scheme if protocol == 'https': diff --git a/dns/trio/query.py b/dns/trio/query.py index e7e79a52..9dcbbacc 100644 --- a/dns/trio/query.py +++ b/dns/trio/query.py @@ -46,7 +46,7 @@ async def send_udp(sock, what, destination): async def receive_udp(sock, destination, ignore_unexpected=False, one_rr_per_rrset=False, keyring=None, request_mac=b'', - ignore_trailing=False): + ignore_trailing=False, raise_on_truncation=False): """Asynchronously read a DNS message from a UDP socket. *sock*, a ``trio.socket``. @@ -67,6 +67,9 @@ async def receive_udp(sock, destination, ignore_unexpected=False, *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the received message. + *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if + the TC bit is set. + Raises if the message is malformed, if network errors occur, of if there is a timeout. @@ -88,12 +91,13 @@ async def receive_udp(sock, destination, ignore_unexpected=False, received_time = time.time() r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac, one_rr_per_rrset=one_rr_per_rrset, - ignore_trailing=ignore_trailing) + ignore_trailing=ignore_trailing, + raise_on_truncation=raise_on_truncation) return (r, received_time) async def udp(q, where, port=53, source=None, source_port=0, ignore_unexpected=False, one_rr_per_rrset=False, - ignore_trailing=False): + ignore_trailing=False, raise_on_truncation=False): """Asynchronously return the response obtained after sending a query via UDP. @@ -119,6 +123,9 @@ async def udp(q, where, port=53, source=None, source_port=0, *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the received message. + *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if + the TC bit is set. + Returns a ``dns.message.Message``. """ @@ -135,12 +142,55 @@ async def udp(q, where, port=53, source=None, source_port=0, (r, received_time) = await receive_udp(s, destination, ignore_unexpected, one_rr_per_rrset, q.keyring, - q.mac, ignore_trailing) + q.mac, ignore_trailing, + raise_on_truncation) if not q.is_response(r): raise BadResponse r.time = received_time - sent_time return r +async def udp_with_fallback(q, where, timeout=None, port=53, source=None, + source_port=0, ignore_unexpected=False, + one_rr_per_rrset=False, ignore_trailing=False): + """Return the response to the query, trying UDP first and falling back + to TCP if UDP results in a truncated response. + + *q*, a ``dns.message.Message``, the query to send + + *where*, a ``str`` containing an IPv4 or IPv6 address, where + to send the message. + + *port*, an ``int``, the port send the message to. The default is 53. + + *source*, a ``str`` containing an IPv4 or IPv6 address, specifying + the source address. The default is the wildcard address. + + *source_port*, an ``int``, the port from which to send the message. + The default is 0. + + *ignore_unexpected*, a ``bool``. If ``True``, ignore responses from + unexpected sources. + + *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own + RRset. + + *ignore_trailing*, a ``bool``. If ``True``, ignore trailing + junk at end of the received message. + + Returns a (``dns.message.Message``, tcp) tuple where tcp is ``True`` + if and only if TCP was used. + """ + try: + response = await udp(q, where, port, source, source_port, + ignore_unexpected, one_rr_per_rrset, + ignore_trailing, True) + return (response, False) + except dns.message.Truncated: + response = await stream(q, where, False, port, source, source_port, + one_rr_per_rrset, ignore_trailing) + + return (response, True) + # pylint: disable=redefined-outer-name async def send_stream(stream, what): diff --git a/dns/trio/resolver.py b/dns/trio/resolver.py index 6a445458..d2bc12be 100644 --- a/dns/trio/resolver.py +++ b/dns/trio/resolver.py @@ -114,7 +114,8 @@ class Resolver(dns.resolver.Resolver): nameserver, port=port, source=source, - source_port=source_port) + source_port=source_port, + raise_on_truncation=True) else: # We don't do DoH yet. raise NotImplementedError diff --git a/doc/query.rst b/doc/query.rst index 55aa5ab5..08940b43 100644 --- a/doc/query.rst +++ b/doc/query.rst @@ -17,6 +17,7 @@ UDP --- .. autofunction:: dns.query.udp +.. autofunction:: dns.query.udp_with_fallback .. autofunction:: dns.query.send_udp .. autofunction:: dns.query.receive_udp diff --git a/doc/whatsnew.rst b/doc/whatsnew.rst index 4dc80471..f2e7adf3 100644 --- a/doc/whatsnew.rst +++ b/doc/whatsnew.rst @@ -57,9 +57,6 @@ What's New in dnspython 2.0.0 * The NINFO record is supported. -* When decoding from wire format, if a message as TC (truncated) set, - a ``Truncated`` exception is now raised. - * The ``dns.hash`` module has been removed; just use Python's native ``hashlib`` module. diff --git a/tests/test_message.py b/tests/test_message.py index 166d7e86..55ae6fd0 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -225,7 +225,7 @@ class MessageTestCase(unittest.TestCase): a = dns.message.from_text(answer_text) a.flags |= dns.flags.TC wire = a.to_wire(want_shuffle=False) - dns.message.from_wire(wire) + dns.message.from_wire(wire, raise_on_truncation=True) self.assertRaises(dns.message.Truncated, bad) def test_MessyTruncated(self): @@ -233,7 +233,7 @@ class MessageTestCase(unittest.TestCase): a = dns.message.from_text(answer_text) a.flags |= dns.flags.TC wire = a.to_wire(want_shuffle=False) - dns.message.from_wire(wire[:-3]) + dns.message.from_wire(wire[:-3], raise_on_truncation=True) self.assertRaises(dns.message.Truncated, bad) def test_IDNA_2003(self): diff --git a/tests/test_query.py b/tests/test_query.py new file mode 100644 index 00000000..9c632171 --- /dev/null +++ b/tests/test_query.py @@ -0,0 +1,81 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2017 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import socket +import unittest + +import dns.message +import dns.name +import dns.rdataclass +import dns.rdatatype +import dns.query + +# Some tests require the internet to be available to run, so let's +# skip those if it's not there. +_network_available = True +try: + socket.gethostbyname('dnspython.org') +except socket.gaierror: + _network_available = False + +@unittest.skipIf(not _network_available, "Internet not reachable") +class QueryTests(unittest.TestCase): + + def testQueryUDP(self): + qname = dns.name.from_text('dns.google.') + q = dns.message.make_query(qname, dns.rdatatype.A) + response = dns.query.udp(q, '8.8.8.8') + rrs = response.get_rrset(response.answer, qname, + dns.rdataclass.IN, dns.rdatatype.A) + self.assertTrue(rrs is not None) + seen = set([rdata.address for rdata in rrs]) + self.assertTrue('8.8.8.8' in seen) + self.assertTrue('8.8.4.4' in seen) + + def testQueryTCP(self): + qname = dns.name.from_text('dns.google.') + q = dns.message.make_query(qname, dns.rdatatype.A) + response = dns.query.tcp(q, '8.8.8.8') + rrs = response.get_rrset(response.answer, qname, + dns.rdataclass.IN, dns.rdatatype.A) + self.assertTrue(rrs is not None) + seen = set([rdata.address for rdata in rrs]) + self.assertTrue('8.8.8.8' in seen) + self.assertTrue('8.8.4.4' in seen) + + def testQueryTLS(self): + qname = dns.name.from_text('dns.google.') + q = dns.message.make_query(qname, dns.rdatatype.A) + response = dns.query.tls(q, '8.8.8.8') + rrs = response.get_rrset(response.answer, qname, + dns.rdataclass.IN, dns.rdatatype.A) + self.assertTrue(rrs is not None) + seen = set([rdata.address for rdata in rrs]) + self.assertTrue('8.8.8.8' in seen) + self.assertTrue('8.8.4.4' in seen) + + def testQueryUDPFallback(self): + qname = dns.name.from_text('.') + q = dns.message.make_query(qname, dns.rdatatype.DNSKEY) + (_, tcp) = dns.query.udp_with_fallback(q, '8.8.8.8') + self.assertTrue(tcp) + + def testQueryUDPFallbackNoFallback(self): + qname = dns.name.from_text('dns.google.') + q = dns.message.make_query(qname, dns.rdatatype.A) + (_, tcp) = dns.query.udp_with_fallback(q, '8.8.8.8') + self.assertFalse(tcp) diff --git a/tests/test_trio.py b/tests/test_trio.py index 961ebda4..d519844d 100644 --- a/tests/test_trio.py +++ b/tests/test_trio.py @@ -125,6 +125,21 @@ try: self.assertTrue('8.8.8.8' in seen) self.assertTrue('8.8.4.4' in seen) + def testQueryUDPFallback(self): + qname = dns.name.from_text('.') + async def run(): + q = dns.message.make_query(qname, dns.rdatatype.DNSKEY) + return await dns.trio.query.udp_with_fallback(q, '8.8.8.8') + (_, tcp) = trio.run(run) + self.assertTrue(tcp) + + def testQueryUDPFallbackNoFallback(self): + qname = dns.name.from_text('dns.google.') + async def run(): + q = dns.message.make_query(qname, dns.rdatatype.A) + return await dns.trio.query.udp_with_fallback(q, '8.8.8.8') + (_, tcp) = trio.run(run) + self.assertFalse(tcp) except ModuleNotFoundError: pass -- 2.47.3