]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Adds sock parameters to query methods. 498/head
authorBrian Wellington <bwelling@xbill.org>
Mon, 1 Jun 2020 17:09:37 +0000 (10:09 -0700)
committerBrian Wellington <bwelling@xbill.org>
Mon, 1 Jun 2020 17:09:37 +0000 (10:09 -0700)
Allow passing a socket into dns.query.{udp,tcp,tls,udp_with_fallback},
and add tests for this.

dns/query.py
tests/test_query.py

index 8f3fdab295b98f96ef571aa57dd60325de7a8ceb..8ae073f16d2c4786405e2e7ee8a8deb427e258f4 100644 (file)
@@ -87,11 +87,12 @@ class NoDOH(dns.exception.DNSException):
     available."""
 
 
-def _compute_expiration(timeout):
+def _compute_times(timeout):
+    now = time.time()
     if timeout is None:
-        return None
+        return (now, None)
     else:
-        return time.time() + timeout
+        return (now, now + timeout)
 
 # This module can use either poll() or select() as the "polling backend".
 #
@@ -230,6 +231,21 @@ def _destination_and_source(af, where, port, source, source_port,
         destination = None
     return (af, destination, source)
 
+def _make_socket(af, type, source, ssl_context=None, server_hostname=None):
+    s = socket_factory(af, type)
+    try:
+        s.setblocking(False)
+        if source is not None:
+            s.bind(source)
+        if ssl_context:
+            return ssl_context.wrap_socket(s, do_handshake_on_connect=False,
+                                           server_hostname=server_hostname)
+        else:
+            return s
+    except Exception:
+        s.close()
+        raise
+
 def https(q, where, timeout=None, port=443, source=None, source_port=0,
           one_rr_per_rrset=False, ignore_trailing=False,
           session=None, path='/dns-query', post=True,
@@ -424,7 +440,7 @@ def receive_udp(sock, destination, expiration=None,
 
 def udp(q, where, timeout=None, port=53, source=None, source_port=0,
         ignore_unexpected=False, one_rr_per_rrset=False, ignore_trailing=False,
-        raise_on_truncation=False):
+        raise_on_truncation=False, sock=None):
     """Return the response obtained after sending a query via UDP.
 
     *q*, a ``dns.message.Message``, the query to send
@@ -455,30 +471,37 @@ def udp(q, where, timeout=None, port=53, source=None, source_port=0,
     *raise_on_truncation*, a ``bool``.  If ``True``, raise an exception if
     the TC bit is set.
 
+    *sock*, a ``socket.socket``, or ``None``, the socket to use for the
+    query.  If ``None``, the default, a socket is created.  Note that
+    if a socket is provided, it must be a nonblocking datagram socket,
+    and the *source* and *source_port* are ignored.
+
     Returns a ``dns.message.Message``.
     """
 
     wire = q.to_wire()
     (af, destination, source) = _destination_and_source(None, where, port,
                                                         source, source_port)
-    with socket_factory(af, socket.SOCK_DGRAM, 0) as s:
-        expiration = _compute_expiration(timeout)
-        s.setblocking(0)
-        if source is not None:
-            s.bind(source)
-        (_, sent_time) = send_udp(s, wire, destination, expiration)
+    (begin_time, expiration) = _compute_times(timeout)
+    with contextlib.ExitStack() as stack:
+        if sock:
+            s = sock
+        else:
+            s = stack.enter_context(_make_socket(af, socket.SOCK_DGRAM, source))
+        send_udp(s, wire, destination, expiration)
         (r, received_time) = receive_udp(s, destination, expiration,
                                          ignore_unexpected, one_rr_per_rrset,
                                          q.keyring, q.mac, ignore_trailing,
                                          raise_on_truncation)
-        r.time = received_time - sent_time
+        r.time = received_time - begin_time
         if not q.is_response(r):
             raise BadResponse
         return r
 
 def udp_with_fallback(q, where, timeout=None, port=53, source=None,
                       source_port=0, ignore_unexpected=False,
-                      one_rr_per_rrset=False, ignore_trailing=False):
+                      one_rr_per_rrset=False, ignore_trailing=False,
+                      udp_sock=None, tcp_sock=None):
     """Return the response to the query, trying UDP first and falling back
     to TCP if UDP results in a truncated response.
 
@@ -507,17 +530,28 @@ def udp_with_fallback(q, where, timeout=None, port=53, source=None,
     *ignore_trailing*, a ``bool``.  If ``True``, ignore trailing
     junk at end of the received message.
 
+    *udp_sock*, a ``socket.socket``, or ``None``, the socket to use for the
+    UDP query.  If ``None``, the default, a socket is created.  Note that
+    if a socket is provided, it must be a nonblocking datagram socket,
+    and the *source* and *source_port* are ignored for the UDP query.
+
+    *tcp_sock*, a ``socket.socket``, or ``None``, the socket to use for the
+    TCP query.  If ``None``, the default, a socket is created.  Note that
+    if a socket is provided, it must be a nonblocking connected stream
+    socket, and *where*, *source* and *source_port* are ignored for the TCP
+    query.
+
     Returns a (``dns.message.Message``, tcp) tuple where tcp is ``True``
     if and only if TCP was used.
     """
     try:
         response = udp(q, where, timeout, port, source, source_port,
                        ignore_unexpected, one_rr_per_rrset,
-                       ignore_trailing, True)
+                       ignore_trailing, True, udp_sock)
         return (response, False)
     except dns.message.Truncated:
         response = tcp(q, where, timeout, port, source, source_port,
-                       one_rr_per_rrset, ignore_trailing)
+                       one_rr_per_rrset, ignore_trailing, tcp_sock)
         return (response, True)
 
 def _net_read(sock, count, expiration):
@@ -634,12 +668,12 @@ def _connect(s, address, expiration):
 
 
 def tcp(q, where, timeout=None, port=53, source=None, source_port=0,
-        one_rr_per_rrset=False, ignore_trailing=False):
+        one_rr_per_rrset=False, ignore_trailing=False, sock=None):
     """Return the response obtained after sending a query via TCP.
 
     *q*, a ``dns.message.Message``, the query to send
 
-    *where*, a ``str`` containing an IPv4 or IPv6 address,  where
+    *where*, a ``str`` containing an IPv4 or IPv6 address, where
     to send the message.
 
     *timeout*, a ``float`` or ``None``, the number of seconds to wait before the
@@ -659,19 +693,31 @@ def tcp(q, where, timeout=None, port=53, source=None, source_port=0,
     *ignore_trailing*, a ``bool``.  If ``True``, ignore trailing
     junk at end of the received message.
 
+    *sock*, a ``socket.socket``, or ``None``, the socket to use for the
+    query.  If ``None``, the default, a socket is created.  Note that
+    if a socket is provided, it must be a nonblocking connected stream
+    socket, and *where*, *source* and *source_port* are ignored.
+
     Returns a ``dns.message.Message``.
     """
 
     wire = q.to_wire()
-    (af, destination, source) = _destination_and_source(None, where, port,
-                                                        source, source_port)
-    with socket_factory(af, socket.SOCK_STREAM, 0) as s:
-        expiration = _compute_expiration(timeout)
-        s.setblocking(0)
-        begin_time = time.time()
-        if source is not None:
-            s.bind(source)
-        _connect(s, destination, expiration)
+    (begin_time, expiration) = _compute_times(timeout)
+    with contextlib.ExitStack() as stack:
+        if sock:
+            #
+            # Verify that the socket is connected, as if it's not connected,
+            # it's not writable, and the polling in send_tcp() will time out or
+            # hang forever.
+            sock.getpeername()
+            s = sock
+        else:
+            (af, destination, source) = _destination_and_source(None, where,
+                                                                port, source,
+                                                                source_port)
+            s = stack.enter_context(_make_socket(af, socket.SOCK_STREAM,
+                                                 source))
+            _connect(s, destination, expiration)
         send_tcp(s, wire, expiration)
         (r, received_time) = receive_tcp(s, expiration, one_rr_per_rrset,
                                          q.keyring, q.mac, ignore_trailing)
@@ -693,7 +739,7 @@ def _tls_handshake(s, expiration):
 
 
 def tls(q, where, timeout=None, port=853, source=None, source_port=0,
-        one_rr_per_rrset=False, ignore_trailing=False,
+        one_rr_per_rrset=False, ignore_trailing=False, sock=None,
         ssl_context=None, server_hostname=None):
     """Return the response obtained after sending a query via TLS.
 
@@ -719,6 +765,11 @@ def tls(q, where, timeout=None, port=853, source=None, source_port=0,
     *ignore_trailing*, a ``bool``.  If ``True``, ignore trailing
     junk at end of the received message.
 
+    *sock*, an ``ssl.SSLSocket``, or ``None``, the socket to use for the
+    query.  If ``None``, the default, a socket is created.  Note that
+    if a socket is provided, it must be a nonblocking connected SSL stream
+    socket, and *where*, *source*, *source_port*, and *ssl_context* are ignored.
+
     *ssl_context*, an ``ssl.SSLContext``, the context to use when establishing
     a TLS connection. If ``None``, the default, creates one with the default
     configuration.
@@ -730,21 +781,24 @@ def tls(q, where, timeout=None, port=853, source=None, source_port=0,
     Returns a ``dns.message.Message``.
     """
 
+    if sock:
+        #
+        # If a socket was provided, there's no special TLS handling needed.
+        #
+        return tcp(q, where, timeout, port, source, source_port,
+                   one_rr_per_rrset, ignore_trailing, sock)
+
     wire = q.to_wire()
+    (begin_time, expiration) = _compute_times(timeout)
     (af, destination, source) = _destination_and_source(None, where, port,
-                                                        source, source_port)
-    if ssl_context is None:
+                                                       source, source_port)
+    if ssl_context is None and not sock:
         ssl_context = ssl.create_default_context()
         if server_hostname is None:
             ssl_context.check_hostname = False
-    with ssl_context.wrap_socket(socket_factory(af, socket.SOCK_STREAM, 0),
-                                 do_handshake_on_connect=False,
-                                 server_hostname=server_hostname) as s:
-        expiration = _compute_expiration(timeout)
-        s.setblocking(0)
-        begin_time = time.time()
-        if source is not None:
-            s.bind(source)
+
+    with _make_socket(af, socket.SOCK_STREAM, source, ssl_context=ssl_context,
+                      server_hostname=server_hostname) as s:
         _connect(s, destination, expiration)
         _tls_handshake(s, expiration)
         send_tcp(s, wire, expiration)
@@ -828,11 +882,8 @@ def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN,
     if use_udp and rdtype != dns.rdatatype.IXFR:
         raise ValueError('cannot do a UDP AXFR')
     sock_type = socket.SOCK_DGRAM if use_udp else socket.SOCK_STREAM
-    with socket_factory(af, sock_type, 0) as s:
-        s.setblocking(0)
-        if source is not None:
-            s.bind(source)
-        expiration = _compute_expiration(lifetime)
+    with _make_socket(af, sock_type, source) as s:
+        (_, expiration) = _compute_times(lifetime)
         _connect(s, destination, expiration)
         l = len(wire)
         if use_udp:
@@ -854,7 +905,7 @@ def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN,
         tsig_ctx = None
         first = True
         while not done:
-            mexpiration = _compute_expiration(timeout)
+            (_, mexpiration) = _compute_times(timeout)
             if mexpiration is None or \
                (expiration is not None and mexpiration > expiration):
                 mexpiration = expiration
index 9c6321712dc960f53aeb5543f478ad5391548773..e031cfd19325ff6d6f2d287a43ef978df7862757 100644 (file)
 import socket
 import unittest
 
+try:
+    import ssl
+    have_ssl = True
+except Exception:
+    have_ssl = False
+
 import dns.message
 import dns.name
 import dns.rdataclass
@@ -46,6 +52,19 @@ class QueryTests(unittest.TestCase):
         self.assertTrue('8.8.8.8' in seen)
         self.assertTrue('8.8.4.4' in seen)
 
+    def testQueryUDPWithSocket(self):
+        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
+            s.setblocking(0)
+            qname = dns.name.from_text('dns.google.')
+            q = dns.message.make_query(qname, dns.rdatatype.A)
+            response = dns.query.udp(q, '8.8.8.8', sock=s)
+            rrs = response.get_rrset(response.answer, qname,
+                                     dns.rdataclass.IN, dns.rdatatype.A)
+            self.assertTrue(rrs is not None)
+            seen = set([rdata.address for rdata in rrs])
+            self.assertTrue('8.8.8.8' in seen)
+            self.assertTrue('8.8.4.4' in seen)
+
     def testQueryTCP(self):
         qname = dns.name.from_text('dns.google.')
         q = dns.message.make_query(qname, dns.rdatatype.A)
@@ -57,6 +76,20 @@ class QueryTests(unittest.TestCase):
         self.assertTrue('8.8.8.8' in seen)
         self.assertTrue('8.8.4.4' in seen)
 
+    def testQueryTCPWithSocket(self):
+        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+            s.connect(('8.8.8.8', 53))
+            s.setblocking(0)
+            qname = dns.name.from_text('dns.google.')
+            q = dns.message.make_query(qname, dns.rdatatype.A)
+            response = dns.query.tcp(q, None, sock=s)
+            rrs = response.get_rrset(response.answer, qname,
+                                     dns.rdataclass.IN, dns.rdatatype.A)
+            self.assertTrue(rrs is not None)
+            seen = set([rdata.address for rdata in rrs])
+            self.assertTrue('8.8.8.8' in seen)
+            self.assertTrue('8.8.4.4' in seen)
+
     def testQueryTLS(self):
         qname = dns.name.from_text('dns.google.')
         q = dns.message.make_query(qname, dns.rdatatype.A)
@@ -68,12 +101,42 @@ class QueryTests(unittest.TestCase):
         self.assertTrue('8.8.8.8' in seen)
         self.assertTrue('8.8.4.4' in seen)
 
+    @unittest.skipUnless(have_ssl, "No SSL support")
+    def testQueryTLSWithSocket(self):
+        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+            s.connect(('8.8.8.8', 853))
+            ctx = ssl.create_default_context()
+            s = ctx.wrap_socket(s, server_hostname='dns.google')
+            s.setblocking(0)
+            qname = dns.name.from_text('dns.google.')
+            q = dns.message.make_query(qname, dns.rdatatype.A)
+            response = dns.query.tls(q, None, sock=s)
+            rrs = response.get_rrset(response.answer, qname,
+                                     dns.rdataclass.IN, dns.rdatatype.A)
+            self.assertTrue(rrs is not None)
+            seen = set([rdata.address for rdata in rrs])
+            self.assertTrue('8.8.8.8' in seen)
+            self.assertTrue('8.8.4.4' in seen)
+
     def testQueryUDPFallback(self):
         qname = dns.name.from_text('.')
         q = dns.message.make_query(qname, dns.rdatatype.DNSKEY)
         (_, tcp) = dns.query.udp_with_fallback(q, '8.8.8.8')
         self.assertTrue(tcp)
 
+    def testQueryUDPFallbackWithSocket(self):
+        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as udp_s:
+            udp_s.setblocking(0)
+            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as tcp_s:
+                tcp_s.connect(('8.8.8.8', 53))
+                tcp_s.setblocking(0)
+                qname = dns.name.from_text('.')
+                q = dns.message.make_query(qname, dns.rdatatype.DNSKEY)
+                (_, tcp) = dns.query.udp_with_fallback(q, '8.8.8.8',
+                                                      udp_sock=udp_s,
+                                                      tcp_sock=tcp_s)
+                self.assertTrue(tcp)
+
     def testQueryUDPFallbackNoFallback(self):
         qname = dns.name.from_text('dns.google.')
         q = dns.message.make_query(qname, dns.rdatatype.A)