From: Brian Wellington Date: Tue, 7 Jan 2020 21:03:13 +0000 (-0800) Subject: DoH cleanup. X-Git-Tag: v2.0.0rc1~341^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=7ec39e21ab0a6761a34ec405a4a59dc4ebe54924;p=thirdparty%2Fdnspython.git DoH cleanup. --- diff --git a/dns/query.py b/dns/query.py index c36248e3..5876623e 100644 --- a/dns/query.py +++ b/dns/query.py @@ -189,14 +189,16 @@ def _addresses_equal(af, a1, a2): return n1 == n2 and a1[1:] == a2[1:] -def _destination_and_source(af, where, port, source, source_port): +def _destination_and_source(af, where, port, source, source_port, + default_to_inet=True): # Apply defaults and compute destination and source tuples # suitable for use in connect(), sendto(), or bind(). if af is None: try: af = dns.inet.af_for_address(where) except Exception: - af = dns.inet.AF_INET + if default_to_inet: + af = dns.inet.AF_INET if af == dns.inet.AF_INET: destination = (where, port) if source is not None or source_port != 0: @@ -209,6 +211,9 @@ def _destination_and_source(af, where, port, source, source_port): if source is None: source = '::' source = (source, source_port, 0, 0) + else: + source = None + destination = None return (af, destination, source) def send_https(session, what, lifetime=None): @@ -225,9 +230,10 @@ def send_https(session, what, lifetime=None): what = what.prepare() return session.send(what, timeout=lifetime) -def https(q, where, session, timeout=None, port=443, path='/dns-query', post=True, - bootstrap_address=None, verify=True, source=None, source_port=0, - one_rr_per_rrset=False, ignore_trailing=False): +def https(q, where, timeout=None, port=443, af=None, source=None, source_port=0, + one_rr_per_rrset=False, ignore_trailing=False, + session=None, path='/dns-query', post=True, + bootstrap_address=None, verify=True): """Return the response obtained after sending a query via DNS-over-HTTPS. *q*, a ``dns.message.Message``, the query to send. @@ -236,21 +242,15 @@ def https(q, where, session, timeout=None, port=443, path='/dns-query', post=Tru address is given, the URL will be constructed using the following schema: https://:/. - *session*, a ``requests.session.Session``, the session to use to send the - queries. This argument is required to allow for connection reuse. - *timeout*, a ``float`` or ``None``, the number of seconds to wait before the query times out. If ``None``, the default, wait forever. - *port*, a ``int``, the port to send the query to. Default is 443. - - *path*, a ``str``. If *where* is an IP address, then *path* will be used to - construct the URL to send the DNS query to. - - *post*, a ``bool``. If ``True``, the default, POST method will be used. + *port*, a ``int``, the port to send the query to. The default is 443. - *bootstrap_address*, a ``str``, the IP address to use to bypass the system's - DNS resolver. + *af*, an ``int``, the address family to use. The default is ``None``, + which causes the address family to use to be inferred from the form of + *where*, or uses the system default. Setting this to AF_INET or + AF_INET6 currently has no effect. *source*, a ``str`` containing an IPv4 or IPv6 address, specifying the source address. The default is the wildcard address. @@ -264,13 +264,27 @@ def https(q, where, session, timeout=None, port=443, path='/dns-query', post=Tru *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the received message. + *session*, a ``requests.session.Session``. If provided, the session to use + to send the queries. + + *path*, a ``str``. If *where* is an IP address, then *path* will be used to + construct the URL to send the DNS query to. + + *post*, a ``bool``. If ``True``, the default, POST method will be used. + + *bootstrap_address*, a ``str``, the IP address to use to bypass the + system's DNS resolver. + + *verify*, a ``str`, containing a path to a certificate file or directory. + Returns a ``dns.message.Message``. """ wire = q.to_wire() - af = None (af, destination, source) = _destination_and_source(af, where, port, - source, source_port) + source, source_port, + False) + transport_adapter = None headers = { "accept": "application/dns-message" } @@ -282,31 +296,49 @@ def https(q, where, session, timeout=None, port=443, path='/dns-query', post=Tru split_url = urllib.parse.urlsplit(where) headers['Host'] = split_url.hostname url = where.replace(split_url.hostname, bootstrap_address) - session.mount(url, HostHeaderSSLAdapter()) + transport_adapter = HostHeaderSSLAdapter() else: url = where if source is not None: # set source port and source address - session.mount(url, SourceAddressAdapter(source)) - - # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH GET and POST examples - if post: - headers.update({ - "content-type": "application/dns-message", - "content-length": str(len(wire)) - }) - response = session.post(url, headers=headers, data=wire, stream=True, - timeout=timeout, verify=verify) + transport_adapter = SourceAddressAdapter(source) + + if session: + close_session = False else: - wire = base64.urlsafe_b64encode(wire).decode('utf-8').strip("=") - url += "?dns={}".format(wire) - response = session.get(url, headers=headers, stream=True, - timeout=timeout, verify=verify) + session = requests.sessions.Session() + close_session = True + + try: + if transport_adapter: + session.mount(url, transport_adapter) + + # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH + # GET and POST examples + if post: + headers.update({ + "content-type": "application/dns-message", + "content-length": str(len(wire)) + }) + response = session.post(url, headers=headers, data=wire, + stream=True, timeout=timeout, + verify=verify) + else: + wire = base64.urlsafe_b64encode(wire).decode('utf-8').strip("=") + url += "?dns={}".format(wire) + response = session.get(url, headers=headers, stream=True, + timeout=timeout, verify=verify) + finally: + if close_session: + session.close() - # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH status codes + # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH + # status codes if response.status_code < 200 or response.status_code > 299: - raise ValueError('{} responded with status code {}\nResponse body: {}'.format( - where, response.status_code, response.content)) + raise ValueError('{} responded with status code {}' + '\nResponse body: {}'.format(where, + response.status_code, + response.content)) r = dns.message.from_wire(response.content, keyring=q.keyring, request_mac=q.request_mac, diff --git a/dns/query.pyi b/dns/query.pyi index e943ba5f..93461232 100644 --- a/dns/query.pyi +++ b/dns/query.pyi @@ -8,9 +8,8 @@ except ImportError: class ssl(object): SSLContext = {} -def https(q : message.Message, where: str, session: Session, timeout : Optional[float] = None, port : Optional[int] = 443, path : Optional[str] = '/dns-query', post : Optional[bool] = True, - bootstrap_address : Optional[str] = None, verify : Optional[bool] = True, source : Optional[str] = None, source_port : Optional[int] = 0, - one_rr_per_rrset : Optional[bool] = False, ignore_trailing : Optional[bool] = False) -> message.Message: +def https(q : message.Message, where: str, timeout : Optional[float] = None, port : Optional[int] = 443, af : Optional[int] = None, source : Optional[str] = None, source_port : Optional[int] = 0, + session: Optional[Session], path : Optional[str] = '/dns-query', post : Optional[bool] = True, bootstrap_address : Optional[str] = None, verify : Optional[bool] = True) -> message.Message: pass def tcp(q : message.Message, where : str, timeout : float = None, port=53, af : Optional[int] = None, source : Optional[str] = None, source_port : Optional[int] = 0, diff --git a/dns/resolver.py b/dns/resolver.py index 735de9f5..3f5e4518 100644 --- a/dns/resolver.py +++ b/dns/resolver.py @@ -909,30 +909,37 @@ class Resolver(object): try: if protocol == 'https': tcp_attempt = True - response = dns.query.https(request, nameserver, timeout) + response = dns.query.https(request, nameserver, + timeout=timeout) elif protocol: continue else: tcp_attempt = tcp if tcp: response = dns.query.tcp(request, nameserver, - timeout, port, + timeout=timeout, + port=port, source=source, - source_port=source_port) + source_port=\ + source_port) else: try: - response = dns.query.udp(request, nameserver, - timeout, port, + response = dns.query.udp(request, + nameserver, + timeout=timeout, + port=port, source=source, source_port=\ source_port) except dns.message.Truncated: # Response truncated; retry with TCP. tcp_attempt = True - timeout = self._compute_timeout(start, lifetime) + timeout = self._compute_timeout(start, + lifetime) response = \ dns.query.tcp(request, nameserver, - timeout, port, + timeout=timeout, + port=port, source=source, source_port=source_port) except (socket.error, dns.exception.Timeout) as ex: diff --git a/examples/doh.py b/examples/doh.py index 01c562f8..eff9ae75 100644 --- a/examples/doh.py +++ b/examples/doh.py @@ -18,7 +18,7 @@ def main(): # one method is to use context manager, session will automatically close with requests.sessions.Session() as session: q = dns.message.make_query(qname, dns.rdatatype.A) - r = dns.query.https(q, where, session) + r = dns.query.https(q, where, session=session) for answer in r.answer: print(answer) @@ -29,7 +29,7 @@ def main(): # second method, close session manually session = requests.sessions.Session() q = dns.message.make_query(qname, dns.rdatatype.A) - r = dns.query.https(q, where, session) + r = dns.query.https(q, where, session=session) for answer in r.answer: print(answer) diff --git a/tests/test_doh.py b/tests/test_doh.py index 3819b1a4..acda5af6 100644 --- a/tests/test_doh.py +++ b/tests/test_doh.py @@ -40,13 +40,13 @@ class DNSOverHTTPSTestCase(unittest.TestCase): def test_get_request(self): nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS) q = dns.message.make_query('example.com.', dns.rdatatype.A) - r = dns.query.https(q, nameserver_url, self.session, post=False) + r = dns.query.https(q, nameserver_url, session=self.session, post=False) self.assertTrue(q.is_response(r)) def test_post_request(self): nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS) q = dns.message.make_query('example.com.', dns.rdatatype.A) - r = dns.query.https(q, nameserver_url, self.session, post=True) + r = dns.query.https(q, nameserver_url, session=self.session, post=True) self.assertTrue(q.is_response(r)) def test_build_url_from_ip(self): @@ -54,7 +54,7 @@ class DNSOverHTTPSTestCase(unittest.TestCase): q = dns.message.make_query('example.com.', dns.rdatatype.A) # For some reason Google's DNS over HTTPS fails when you POST to https://8.8.8.8/dns-query # So we're just going to do GET requests here - r = dns.query.https(q, nameserver_ip, self.session, post=False) + r = dns.query.https(q, nameserver_ip, session=self.session, post=False) self.assertTrue(q.is_response(r)) def test_bootstrap_address(self): @@ -64,9 +64,9 @@ class DNSOverHTTPSTestCase(unittest.TestCase): q = dns.message.make_query('example.com.', dns.rdatatype.A) # make sure CleanBrowsing's IP address will fail TLS certificate check with self.assertRaises(SSLError): - dns.query.https(q, invalid_tls_url, self.session) + dns.query.https(q, invalid_tls_url, session=self.session) # use host header - r = dns.query.https(q, valid_tls_url, self.session, bootstrap_address=ip) + r = dns.query.https(q, valid_tls_url, session=self.session, bootstrap_address=ip) self.assertTrue(q.is_response(r)) def test_send_https(self): @@ -79,5 +79,11 @@ class DNSOverHTTPSTestCase(unittest.TestCase): dns_resp = dns.message.from_wire(response.content) self.assertTrue(q.is_response(dns_resp)) + def test_new_session(self): + nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS) + q = dns.message.make_query('example.com.', dns.rdatatype.A) + r = dns.query.https(q, nameserver_url) + self.assertTrue(q.is_response(r)) + if __name__ == '__main__': unittest.main()