From: Brian Wellington Date: Mon, 1 Jun 2020 17:09:37 +0000 (-0700) Subject: Adds sock parameters to query methods. X-Git-Tag: v2.0.0rc1~133^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=52f46fc0d23d760f6b4b1e5d8d46cda031ed86e9;p=thirdparty%2Fdnspython.git Adds sock parameters to query methods. Allow passing a socket into dns.query.{udp,tcp,tls,udp_with_fallback}, and add tests for this. --- diff --git a/dns/query.py b/dns/query.py index 8f3fdab2..8ae073f1 100644 --- a/dns/query.py +++ b/dns/query.py @@ -87,11 +87,12 @@ class NoDOH(dns.exception.DNSException): available.""" -def _compute_expiration(timeout): +def _compute_times(timeout): + now = time.time() if timeout is None: - return None + return (now, None) else: - return time.time() + timeout + return (now, now + timeout) # This module can use either poll() or select() as the "polling backend". # @@ -230,6 +231,21 @@ def _destination_and_source(af, where, port, source, source_port, destination = None return (af, destination, source) +def _make_socket(af, type, source, ssl_context=None, server_hostname=None): + s = socket_factory(af, type) + try: + s.setblocking(False) + if source is not None: + s.bind(source) + if ssl_context: + return ssl_context.wrap_socket(s, do_handshake_on_connect=False, + server_hostname=server_hostname) + else: + return s + except Exception: + s.close() + raise + def https(q, where, timeout=None, port=443, source=None, source_port=0, one_rr_per_rrset=False, ignore_trailing=False, session=None, path='/dns-query', post=True, @@ -424,7 +440,7 @@ def receive_udp(sock, destination, expiration=None, def udp(q, where, timeout=None, port=53, source=None, source_port=0, ignore_unexpected=False, one_rr_per_rrset=False, ignore_trailing=False, - raise_on_truncation=False): + raise_on_truncation=False, sock=None): """Return the response obtained after sending a query via UDP. *q*, a ``dns.message.Message``, the query to send @@ -455,30 +471,37 @@ def udp(q, where, timeout=None, port=53, source=None, source_port=0, *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if the TC bit is set. + *sock*, a ``socket.socket``, or ``None``, the socket to use for the + query. If ``None``, the default, a socket is created. Note that + if a socket is provided, it must be a nonblocking datagram socket, + and the *source* and *source_port* are ignored. + Returns a ``dns.message.Message``. """ wire = q.to_wire() (af, destination, source) = _destination_and_source(None, where, port, source, source_port) - with socket_factory(af, socket.SOCK_DGRAM, 0) as s: - expiration = _compute_expiration(timeout) - s.setblocking(0) - if source is not None: - s.bind(source) - (_, sent_time) = send_udp(s, wire, destination, expiration) + (begin_time, expiration) = _compute_times(timeout) + with contextlib.ExitStack() as stack: + if sock: + s = sock + else: + s = stack.enter_context(_make_socket(af, socket.SOCK_DGRAM, source)) + send_udp(s, wire, destination, expiration) (r, received_time) = receive_udp(s, destination, expiration, ignore_unexpected, one_rr_per_rrset, q.keyring, q.mac, ignore_trailing, raise_on_truncation) - r.time = received_time - sent_time + r.time = received_time - begin_time if not q.is_response(r): raise BadResponse return r def udp_with_fallback(q, where, timeout=None, port=53, source=None, source_port=0, ignore_unexpected=False, - one_rr_per_rrset=False, ignore_trailing=False): + one_rr_per_rrset=False, ignore_trailing=False, + udp_sock=None, tcp_sock=None): """Return the response to the query, trying UDP first and falling back to TCP if UDP results in a truncated response. @@ -507,17 +530,28 @@ def udp_with_fallback(q, where, timeout=None, port=53, source=None, *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the received message. + *udp_sock*, a ``socket.socket``, or ``None``, the socket to use for the + UDP query. If ``None``, the default, a socket is created. Note that + if a socket is provided, it must be a nonblocking datagram socket, + and the *source* and *source_port* are ignored for the UDP query. + + *tcp_sock*, a ``socket.socket``, or ``None``, the socket to use for the + TCP query. If ``None``, the default, a socket is created. Note that + if a socket is provided, it must be a nonblocking connected stream + socket, and *where*, *source* and *source_port* are ignored for the TCP + query. + Returns a (``dns.message.Message``, tcp) tuple where tcp is ``True`` if and only if TCP was used. """ try: response = udp(q, where, timeout, port, source, source_port, ignore_unexpected, one_rr_per_rrset, - ignore_trailing, True) + ignore_trailing, True, udp_sock) return (response, False) except dns.message.Truncated: response = tcp(q, where, timeout, port, source, source_port, - one_rr_per_rrset, ignore_trailing) + one_rr_per_rrset, ignore_trailing, tcp_sock) return (response, True) def _net_read(sock, count, expiration): @@ -634,12 +668,12 @@ def _connect(s, address, expiration): def tcp(q, where, timeout=None, port=53, source=None, source_port=0, - one_rr_per_rrset=False, ignore_trailing=False): + one_rr_per_rrset=False, ignore_trailing=False, sock=None): """Return the response obtained after sending a query via TCP. *q*, a ``dns.message.Message``, the query to send - *where*, a ``str`` containing an IPv4 or IPv6 address, where + *where*, a ``str`` containing an IPv4 or IPv6 address, where to send the message. *timeout*, a ``float`` or ``None``, the number of seconds to wait before the @@ -659,19 +693,31 @@ def tcp(q, where, timeout=None, port=53, source=None, source_port=0, *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the received message. + *sock*, a ``socket.socket``, or ``None``, the socket to use for the + query. If ``None``, the default, a socket is created. Note that + if a socket is provided, it must be a nonblocking connected stream + socket, and *where*, *source* and *source_port* are ignored. + Returns a ``dns.message.Message``. """ wire = q.to_wire() - (af, destination, source) = _destination_and_source(None, where, port, - source, source_port) - with socket_factory(af, socket.SOCK_STREAM, 0) as s: - expiration = _compute_expiration(timeout) - s.setblocking(0) - begin_time = time.time() - if source is not None: - s.bind(source) - _connect(s, destination, expiration) + (begin_time, expiration) = _compute_times(timeout) + with contextlib.ExitStack() as stack: + if sock: + # + # Verify that the socket is connected, as if it's not connected, + # it's not writable, and the polling in send_tcp() will time out or + # hang forever. + sock.getpeername() + s = sock + else: + (af, destination, source) = _destination_and_source(None, where, + port, source, + source_port) + s = stack.enter_context(_make_socket(af, socket.SOCK_STREAM, + source)) + _connect(s, destination, expiration) send_tcp(s, wire, expiration) (r, received_time) = receive_tcp(s, expiration, one_rr_per_rrset, q.keyring, q.mac, ignore_trailing) @@ -693,7 +739,7 @@ def _tls_handshake(s, expiration): def tls(q, where, timeout=None, port=853, source=None, source_port=0, - one_rr_per_rrset=False, ignore_trailing=False, + one_rr_per_rrset=False, ignore_trailing=False, sock=None, ssl_context=None, server_hostname=None): """Return the response obtained after sending a query via TLS. @@ -719,6 +765,11 @@ def tls(q, where, timeout=None, port=853, source=None, source_port=0, *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the received message. + *sock*, an ``ssl.SSLSocket``, or ``None``, the socket to use for the + query. If ``None``, the default, a socket is created. Note that + if a socket is provided, it must be a nonblocking connected SSL stream + socket, and *where*, *source*, *source_port*, and *ssl_context* are ignored. + *ssl_context*, an ``ssl.SSLContext``, the context to use when establishing a TLS connection. If ``None``, the default, creates one with the default configuration. @@ -730,21 +781,24 @@ def tls(q, where, timeout=None, port=853, source=None, source_port=0, Returns a ``dns.message.Message``. """ + if sock: + # + # If a socket was provided, there's no special TLS handling needed. + # + return tcp(q, where, timeout, port, source, source_port, + one_rr_per_rrset, ignore_trailing, sock) + wire = q.to_wire() + (begin_time, expiration) = _compute_times(timeout) (af, destination, source) = _destination_and_source(None, where, port, - source, source_port) - if ssl_context is None: + source, source_port) + if ssl_context is None and not sock: ssl_context = ssl.create_default_context() if server_hostname is None: ssl_context.check_hostname = False - with ssl_context.wrap_socket(socket_factory(af, socket.SOCK_STREAM, 0), - do_handshake_on_connect=False, - server_hostname=server_hostname) as s: - expiration = _compute_expiration(timeout) - s.setblocking(0) - begin_time = time.time() - if source is not None: - s.bind(source) + + with _make_socket(af, socket.SOCK_STREAM, source, ssl_context=ssl_context, + server_hostname=server_hostname) as s: _connect(s, destination, expiration) _tls_handshake(s, expiration) send_tcp(s, wire, expiration) @@ -828,11 +882,8 @@ def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN, if use_udp and rdtype != dns.rdatatype.IXFR: raise ValueError('cannot do a UDP AXFR') sock_type = socket.SOCK_DGRAM if use_udp else socket.SOCK_STREAM - with socket_factory(af, sock_type, 0) as s: - s.setblocking(0) - if source is not None: - s.bind(source) - expiration = _compute_expiration(lifetime) + with _make_socket(af, sock_type, source) as s: + (_, expiration) = _compute_times(lifetime) _connect(s, destination, expiration) l = len(wire) if use_udp: @@ -854,7 +905,7 @@ def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN, tsig_ctx = None first = True while not done: - mexpiration = _compute_expiration(timeout) + (_, mexpiration) = _compute_times(timeout) if mexpiration is None or \ (expiration is not None and mexpiration > expiration): mexpiration = expiration diff --git a/tests/test_query.py b/tests/test_query.py index 9c632171..e031cfd1 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -18,6 +18,12 @@ import socket import unittest +try: + import ssl + have_ssl = True +except Exception: + have_ssl = False + import dns.message import dns.name import dns.rdataclass @@ -46,6 +52,19 @@ class QueryTests(unittest.TestCase): self.assertTrue('8.8.8.8' in seen) self.assertTrue('8.8.4.4' in seen) + def testQueryUDPWithSocket(self): + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: + s.setblocking(0) + qname = dns.name.from_text('dns.google.') + q = dns.message.make_query(qname, dns.rdatatype.A) + response = dns.query.udp(q, '8.8.8.8', sock=s) + rrs = response.get_rrset(response.answer, qname, + dns.rdataclass.IN, dns.rdatatype.A) + self.assertTrue(rrs is not None) + seen = set([rdata.address for rdata in rrs]) + self.assertTrue('8.8.8.8' in seen) + self.assertTrue('8.8.4.4' in seen) + def testQueryTCP(self): qname = dns.name.from_text('dns.google.') q = dns.message.make_query(qname, dns.rdatatype.A) @@ -57,6 +76,20 @@ class QueryTests(unittest.TestCase): self.assertTrue('8.8.8.8' in seen) self.assertTrue('8.8.4.4' in seen) + def testQueryTCPWithSocket(self): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.connect(('8.8.8.8', 53)) + s.setblocking(0) + qname = dns.name.from_text('dns.google.') + q = dns.message.make_query(qname, dns.rdatatype.A) + response = dns.query.tcp(q, None, sock=s) + rrs = response.get_rrset(response.answer, qname, + dns.rdataclass.IN, dns.rdatatype.A) + self.assertTrue(rrs is not None) + seen = set([rdata.address for rdata in rrs]) + self.assertTrue('8.8.8.8' in seen) + self.assertTrue('8.8.4.4' in seen) + def testQueryTLS(self): qname = dns.name.from_text('dns.google.') q = dns.message.make_query(qname, dns.rdatatype.A) @@ -68,12 +101,42 @@ class QueryTests(unittest.TestCase): self.assertTrue('8.8.8.8' in seen) self.assertTrue('8.8.4.4' in seen) + @unittest.skipUnless(have_ssl, "No SSL support") + def testQueryTLSWithSocket(self): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.connect(('8.8.8.8', 853)) + ctx = ssl.create_default_context() + s = ctx.wrap_socket(s, server_hostname='dns.google') + s.setblocking(0) + qname = dns.name.from_text('dns.google.') + q = dns.message.make_query(qname, dns.rdatatype.A) + response = dns.query.tls(q, None, sock=s) + rrs = response.get_rrset(response.answer, qname, + dns.rdataclass.IN, dns.rdatatype.A) + self.assertTrue(rrs is not None) + seen = set([rdata.address for rdata in rrs]) + self.assertTrue('8.8.8.8' in seen) + self.assertTrue('8.8.4.4' in seen) + def testQueryUDPFallback(self): qname = dns.name.from_text('.') q = dns.message.make_query(qname, dns.rdatatype.DNSKEY) (_, tcp) = dns.query.udp_with_fallback(q, '8.8.8.8') self.assertTrue(tcp) + def testQueryUDPFallbackWithSocket(self): + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as udp_s: + udp_s.setblocking(0) + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as tcp_s: + tcp_s.connect(('8.8.8.8', 53)) + tcp_s.setblocking(0) + qname = dns.name.from_text('.') + q = dns.message.make_query(qname, dns.rdatatype.DNSKEY) + (_, tcp) = dns.query.udp_with_fallback(q, '8.8.8.8', + udp_sock=udp_s, + tcp_sock=tcp_s) + self.assertTrue(tcp) + def testQueryUDPFallbackNoFallback(self): qname = dns.name.from_text('dns.google.') q = dns.message.make_query(qname, dns.rdatatype.A)