class Truncated(dns.exception.DNSException):
"""The truncated flag is set."""
+ supp_kwargs = {'message'}
+
+ def message(self):
+ """As much of the message as could be processed.
+
+ Returns a ``dns.message.Message``.
+ """
+ return self.kwargs['message']
+
#: The question section number
QUESTION = 0
self.message.original_id = self.message.id
if dns.opcode.is_update(self.message.flags):
self.updating = True
- if self.message.flags & dns.flags.TC:
- raise Truncated
self._get_question(qcount)
if self.question_only:
return
def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None,
tsig_ctx=None, multi=False, first=True,
question_only=False, one_rr_per_rrset=False,
- ignore_trailing=False):
+ ignore_trailing=False, raise_on_truncation=False):
"""Convert a DNS wire format message into a message
object.
*ignore_trailing*, a ``bool``. If ``True``, ignore trailing
junk at end of the message.
+ *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if
+ the TC bit is set.
+
Raises ``dns.message.ShortHeader`` if the message is less than 12 octets
long.
Raises ``dns.message.BadTSIG`` if a TSIG record was not the last
record of the additional data section.
- Raises ``dns.message.Truncated`` if the TC flag is set.
+ Raises ``dns.message.Truncated`` if the TC flag is set and
+ *raise_on_truncation* is ``True``.
Returns a ``dns.message.Message``.
"""
reader = _WireReader(wire, m, question_only, one_rr_per_rrset,
ignore_trailing)
- reader.read()
+ try:
+ reader.read()
+ except dns.exception.FormError:
+ if m.flags & dns.flags.TC and raise_on_truncation:
+ raise Truncated(message=m)
+ else:
+ raise
+ # Reading a truncated message might not have any errors, so we
+ # have to do this check here too.
+ if m.flags & dns.flags.TC and raise_on_truncation:
+ raise Truncated(message=m)
return m
def receive_udp(sock, destination, expiration=None,
ignore_unexpected=False, one_rr_per_rrset=False,
- keyring=None, request_mac=b'', ignore_trailing=False):
+ keyring=None, request_mac=b'', ignore_trailing=False,
+ raise_on_truncation=False):
"""Read a DNS message from a UDP socket.
*sock*, a ``socket``.
*ignore_trailing*, a ``bool``. If ``True``, ignore trailing
junk at end of the received message.
+ *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if
+ the TC bit is set.
+
Raises if the message is malformed, if network errors occur, of if
there is a timeout.
received_time = time.time()
r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac,
one_rr_per_rrset=one_rr_per_rrset,
- ignore_trailing=ignore_trailing)
+ ignore_trailing=ignore_trailing,
+ raise_on_truncation=raise_on_truncation)
return (r, received_time)
def udp(q, where, timeout=None, port=53, source=None, source_port=0,
- ignore_unexpected=False, one_rr_per_rrset=False, ignore_trailing=False):
+ ignore_unexpected=False, one_rr_per_rrset=False, ignore_trailing=False,
+ raise_on_truncation=False):
"""Return the response obtained after sending a query via UDP.
*q*, a ``dns.message.Message``, the query to send
*ignore_trailing*, a ``bool``. If ``True``, ignore trailing
junk at end of the received message.
+ *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if
+ the TC bit is set.
+
Returns a ``dns.message.Message``.
"""
(_, sent_time) = send_udp(s, wire, destination, expiration)
(r, received_time) = receive_udp(s, destination, expiration,
ignore_unexpected, one_rr_per_rrset,
- q.keyring, q.mac, ignore_trailing)
+ q.keyring, q.mac, ignore_trailing,
+ raise_on_truncation)
r.time = received_time - sent_time
if not q.is_response(r):
raise BadResponse
return r
+def udp_with_fallback(q, where, timeout=None, port=53, source=None,
+ source_port=0, ignore_unexpected=False,
+ one_rr_per_rrset=False, ignore_trailing=False):
+ """Return the response to the query, trying UDP first and falling back
+ to TCP if UDP results in a truncated response.
+
+ *q*, a ``dns.message.Message``, the query to send
+
+ *where*, a ``str`` containing an IPv4 or IPv6 address, where
+ to send the message.
+
+ *timeout*, a ``float`` or ``None``, the number of seconds to wait before the
+ query times out. If ``None``, the default, wait forever.
+
+ *port*, an ``int``, the port send the message to. The default is 53.
+
+ *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
+ the source address. The default is the wildcard address.
+
+ *source_port*, an ``int``, the port from which to send the message.
+ The default is 0.
+
+ *ignore_unexpected*, a ``bool``. If ``True``, ignore responses from
+ unexpected sources.
+
+ *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
+ RRset.
+
+ *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
+ junk at end of the received message.
+
+ Returns a (``dns.message.Message``, tcp) tuple where tcp is ``True``
+ if and only if TCP was used.
+ """
+ try:
+ response = udp(q, where, timeout, port, source, source_port,
+ ignore_unexpected, one_rr_per_rrset,
+ ignore_trailing, True)
+ return (response, False)
+ except dns.message.Truncated:
+ response = tcp(q, where, timeout, port, source, source_port,
+ one_rr_per_rrset, ignore_trailing)
+ return (response, True)
def _net_read(sock, count, expiration):
"""Read the specified number of bytes from sock. Keep trying until we
timeout=timeout,
port=port,
source=source,
- source_port=source_port)
+ source_port=source_port,
+ raise_on_truncation=True)
else:
protocol = urlparse(nameserver).scheme
if protocol == 'https':
async def receive_udp(sock, destination, ignore_unexpected=False,
one_rr_per_rrset=False, keyring=None, request_mac=b'',
- ignore_trailing=False):
+ ignore_trailing=False, raise_on_truncation=False):
"""Asynchronously read a DNS message from a UDP socket.
*sock*, a ``trio.socket``.
*ignore_trailing*, a ``bool``. If ``True``, ignore trailing
junk at end of the received message.
+ *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if
+ the TC bit is set.
+
Raises if the message is malformed, if network errors occur, of if
there is a timeout.
received_time = time.time()
r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac,
one_rr_per_rrset=one_rr_per_rrset,
- ignore_trailing=ignore_trailing)
+ ignore_trailing=ignore_trailing,
+ raise_on_truncation=raise_on_truncation)
return (r, received_time)
async def udp(q, where, port=53, source=None, source_port=0,
ignore_unexpected=False, one_rr_per_rrset=False,
- ignore_trailing=False):
+ ignore_trailing=False, raise_on_truncation=False):
"""Asynchronously return the response obtained after sending a query
via UDP.
*ignore_trailing*, a ``bool``. If ``True``, ignore trailing
junk at end of the received message.
+ *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if
+ the TC bit is set.
+
Returns a ``dns.message.Message``.
"""
(r, received_time) = await receive_udp(s, destination,
ignore_unexpected,
one_rr_per_rrset, q.keyring,
- q.mac, ignore_trailing)
+ q.mac, ignore_trailing,
+ raise_on_truncation)
if not q.is_response(r):
raise BadResponse
r.time = received_time - sent_time
return r
+async def udp_with_fallback(q, where, timeout=None, port=53, source=None,
+ source_port=0, ignore_unexpected=False,
+ one_rr_per_rrset=False, ignore_trailing=False):
+ """Return the response to the query, trying UDP first and falling back
+ to TCP if UDP results in a truncated response.
+
+ *q*, a ``dns.message.Message``, the query to send
+
+ *where*, a ``str`` containing an IPv4 or IPv6 address, where
+ to send the message.
+
+ *port*, an ``int``, the port send the message to. The default is 53.
+
+ *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
+ the source address. The default is the wildcard address.
+
+ *source_port*, an ``int``, the port from which to send the message.
+ The default is 0.
+
+ *ignore_unexpected*, a ``bool``. If ``True``, ignore responses from
+ unexpected sources.
+
+ *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
+ RRset.
+
+ *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
+ junk at end of the received message.
+
+ Returns a (``dns.message.Message``, tcp) tuple where tcp is ``True``
+ if and only if TCP was used.
+ """
+ try:
+ response = await udp(q, where, port, source, source_port,
+ ignore_unexpected, one_rr_per_rrset,
+ ignore_trailing, True)
+ return (response, False)
+ except dns.message.Truncated:
+ response = await stream(q, where, False, port, source, source_port,
+ one_rr_per_rrset, ignore_trailing)
+
+ return (response, True)
+
# pylint: disable=redefined-outer-name
async def send_stream(stream, what):
nameserver,
port=port,
source=source,
- source_port=source_port)
+ source_port=source_port,
+ raise_on_truncation=True)
else:
# We don't do DoH yet.
raise NotImplementedError
---
.. autofunction:: dns.query.udp
+.. autofunction:: dns.query.udp_with_fallback
.. autofunction:: dns.query.send_udp
.. autofunction:: dns.query.receive_udp
* The NINFO record is supported.
-* When decoding from wire format, if a message as TC (truncated) set,
- a ``Truncated`` exception is now raised.
-
* The ``dns.hash`` module has been removed; just use Python's native
``hashlib`` module.
a = dns.message.from_text(answer_text)
a.flags |= dns.flags.TC
wire = a.to_wire(want_shuffle=False)
- dns.message.from_wire(wire)
+ dns.message.from_wire(wire, raise_on_truncation=True)
self.assertRaises(dns.message.Truncated, bad)
def test_MessyTruncated(self):
a = dns.message.from_text(answer_text)
a.flags |= dns.flags.TC
wire = a.to_wire(want_shuffle=False)
- dns.message.from_wire(wire[:-3])
+ dns.message.from_wire(wire[:-3], raise_on_truncation=True)
self.assertRaises(dns.message.Truncated, bad)
def test_IDNA_2003(self):
--- /dev/null
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+import socket
+import unittest
+
+import dns.message
+import dns.name
+import dns.rdataclass
+import dns.rdatatype
+import dns.query
+
+# Some tests require the internet to be available to run, so let's
+# skip those if it's not there.
+_network_available = True
+try:
+ socket.gethostbyname('dnspython.org')
+except socket.gaierror:
+ _network_available = False
+
+@unittest.skipIf(not _network_available, "Internet not reachable")
+class QueryTests(unittest.TestCase):
+
+ def testQueryUDP(self):
+ qname = dns.name.from_text('dns.google.')
+ q = dns.message.make_query(qname, dns.rdatatype.A)
+ response = dns.query.udp(q, '8.8.8.8')
+ rrs = response.get_rrset(response.answer, qname,
+ dns.rdataclass.IN, dns.rdatatype.A)
+ self.assertTrue(rrs is not None)
+ seen = set([rdata.address for rdata in rrs])
+ self.assertTrue('8.8.8.8' in seen)
+ self.assertTrue('8.8.4.4' in seen)
+
+ def testQueryTCP(self):
+ qname = dns.name.from_text('dns.google.')
+ q = dns.message.make_query(qname, dns.rdatatype.A)
+ response = dns.query.tcp(q, '8.8.8.8')
+ rrs = response.get_rrset(response.answer, qname,
+ dns.rdataclass.IN, dns.rdatatype.A)
+ self.assertTrue(rrs is not None)
+ seen = set([rdata.address for rdata in rrs])
+ self.assertTrue('8.8.8.8' in seen)
+ self.assertTrue('8.8.4.4' in seen)
+
+ def testQueryTLS(self):
+ qname = dns.name.from_text('dns.google.')
+ q = dns.message.make_query(qname, dns.rdatatype.A)
+ response = dns.query.tls(q, '8.8.8.8')
+ rrs = response.get_rrset(response.answer, qname,
+ dns.rdataclass.IN, dns.rdatatype.A)
+ self.assertTrue(rrs is not None)
+ seen = set([rdata.address for rdata in rrs])
+ self.assertTrue('8.8.8.8' in seen)
+ self.assertTrue('8.8.4.4' in seen)
+
+ def testQueryUDPFallback(self):
+ qname = dns.name.from_text('.')
+ q = dns.message.make_query(qname, dns.rdatatype.DNSKEY)
+ (_, tcp) = dns.query.udp_with_fallback(q, '8.8.8.8')
+ self.assertTrue(tcp)
+
+ def testQueryUDPFallbackNoFallback(self):
+ qname = dns.name.from_text('dns.google.')
+ q = dns.message.make_query(qname, dns.rdatatype.A)
+ (_, tcp) = dns.query.udp_with_fallback(q, '8.8.8.8')
+ self.assertFalse(tcp)
self.assertTrue('8.8.8.8' in seen)
self.assertTrue('8.8.4.4' in seen)
+ def testQueryUDPFallback(self):
+ qname = dns.name.from_text('.')
+ async def run():
+ q = dns.message.make_query(qname, dns.rdatatype.DNSKEY)
+ return await dns.trio.query.udp_with_fallback(q, '8.8.8.8')
+ (_, tcp) = trio.run(run)
+ self.assertTrue(tcp)
+
+ def testQueryUDPFallbackNoFallback(self):
+ qname = dns.name.from_text('dns.google.')
+ async def run():
+ q = dns.message.make_query(qname, dns.rdatatype.A)
+ return await dns.trio.query.udp_with_fallback(q, '8.8.8.8')
+ (_, tcp) = trio.run(run)
+ self.assertFalse(tcp)
except ModuleNotFoundError:
pass