]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
For DoH, use httpx and with HTTP/2 if we can, but fall back to requests if we have to.
authorBob Halley <halley@dnspython.org>
Wed, 10 Nov 2021 16:42:22 +0000 (08:42 -0800)
committerBob Halley <halley@dnspython.org>
Wed, 10 Nov 2021 16:42:22 +0000 (08:42 -0800)
dns/query.py
tests/test_doh.py

index fee5d6af7861bd9a1a96f75684614d59b7d1ca7b..5bf471ab97a28905563220478583b1746e478895 100644 (file)
@@ -42,9 +42,17 @@ try:
     import requests
     from requests_toolbelt.adapters.source import SourceAddressAdapter
     from requests_toolbelt.adapters.host_header_ssl import HostHeaderSSLAdapter
-    have_doh = True
+    _have_requests = True
 except ImportError:  # pragma: no cover
-    have_doh = False
+    _have_requests = False
+
+try:
+    import httpx
+    _have_httpx = True
+except ImportError:  # pragma: no cover
+    _have_httpx = False
+
+have_doh = _have_requests or _have_httpx
 
 try:
     import ssl
@@ -277,10 +285,13 @@ def https(q, where, timeout=None, port=443, source=None, source_port=0,
     if not have_doh:
         raise NoDOH  # pragma: no cover
 
+    _httpx_ok = True
+
     wire = q.to_wire()
     (af, _, source) = _destination_and_source(where, port, source, source_port,
                                               False)
     transport_adapter = None
+    transport = None
     headers = {
         "accept": "application/dns-message"
     }
@@ -290,19 +301,47 @@ def https(q, where, timeout=None, port=443, source=None, source_port=0,
         elif af == socket.AF_INET6:
             url = 'https://[{}]:{}{}'.format(where, port, path)
     elif bootstrap_address is not None:
+        _httpx_ok = False
         split_url = urllib.parse.urlsplit(where)
         headers['Host'] = split_url.hostname
         url = where.replace(split_url.hostname, bootstrap_address)
-        transport_adapter = HostHeaderSSLAdapter()
+        if _have_requests:
+            transport_adapter = HostHeaderSSLAdapter()
     else:
         url = where
     if source is not None:
         # set source port and source address
-        transport_adapter = SourceAddressAdapter(source)
+        if _have_httpx:
+            if source_port == 0:
+                transport = httpx.HTTPTransport(local_address=source[0])
+            else:
+                _httpx_ok = False
+        if _have_requests:
+            transport_adapter = SourceAddressAdapter(source)
+
+    if not _httpx_ok and not _have_requests:
+        raise NoDOH('Cannot use httpx for this operation, and '
+                    'requests is not available.')
 
     with contextlib.ExitStack() as stack:
+        if session:
+            if _have_httpx:
+                _is_httpx = isinstance(session, httpx.Client)
+            else:
+                _is_httpx = False
+            if _is_httpx and not _httpx_ok:
+                # we can't use this session
+                session = None
         if not session:
-            session = stack.enter_context(requests.sessions.Session())
+            if _have_httpx and _httpx_ok:
+                _is_httpx = True
+                session = stack.enter_context(httpx.Client(http1=True,
+                                                           http2=True,
+                                                           verify=verify,
+                                                           transport=transport))
+            else:
+                _is_httpx = False
+                session = stack.enter_context(requests.sessions.Session())
 
         if transport_adapter:
             session.mount(url, transport_adapter)
@@ -314,13 +353,23 @@ def https(q, where, timeout=None, port=443, source=None, source_port=0,
                 "content-type": "application/dns-message",
                 "content-length": str(len(wire))
             })
-            response = session.post(url, headers=headers, data=wire,
-                                    timeout=timeout, verify=verify)
+            if _is_httpx:
+                response = session.post(url, headers=headers, content=wire,
+                                        timeout=timeout)
+            else:
+                response = session.post(url, headers=headers, data=wire,
+                                        timeout=timeout, verify=verify)
         else:
             wire = base64.urlsafe_b64encode(wire).rstrip(b"=")
-            response = session.get(url, headers=headers,
-                                   timeout=timeout, verify=verify,
-                                   params={"dns": wire})
+            if _is_httpx:
+                wire = wire.decode()  # httpx does a repr() if we give it bytes
+                response = session.get(url, headers=headers,
+                                       timeout=timeout,
+                                       params={"dns": wire})
+            else:
+                response = session.get(url, headers=headers,
+                                       timeout=timeout, verify=verify,
+                                       params={"dns": wire})
 
     # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
     # status codes
index 835e07daa1d0b2fe31285951805b108a0799749b..b5750546df7133ef4c1828caa0e4d159860ef908 100644 (file)
@@ -23,10 +23,13 @@ import dns.query
 import dns.rdatatype
 import dns.resolver
 
-if dns.query.have_doh:
+if dns.query._have_requests:
     import requests
     from requests.exceptions import SSLError
 
+if dns.query._have_httpx:
+    import httpx
+
 # Probe for IPv4 and IPv6
 resolver_v4_addresses = []
 resolver_v6_addresses = []
@@ -66,9 +69,10 @@ try:
 except socket.gaierror:
     _network_available = False
 
-@unittest.skipUnless(dns.query.have_doh and _network_available,
+
+@unittest.skipUnless(dns.query._have_requests and _network_available,
                      "Python requests cannot be imported; no DNS over HTTPS (DOH)")
-class DNSOverHTTPSTestCase(unittest.TestCase):
+class DNSOverHTTPSTestCaseRequests(unittest.TestCase):
     def setUp(self):
         self.session = requests.sessions.Session()
 
@@ -140,5 +144,79 @@ class DNSOverHTTPSTestCase(unittest.TestCase):
         self.assertTrue('8.8.4.4' in seen)
 
 
+@unittest.skipUnless(dns.query._have_httpx and _network_available,
+                     "Python httpx cannot be imported; no DNS over HTTPS (DOH)")
+class DNSOverHTTPSTestCaseHttpx(unittest.TestCase):
+    def setUp(self):
+        self.session = httpx.Client(http1=True, http2=True, verify=True)
+
+    def tearDown(self):
+        self.session.close()
+
+    def test_get_request(self):
+        nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS)
+        q = dns.message.make_query('example.com.', dns.rdatatype.A)
+        r = dns.query.https(q, nameserver_url, session=self.session, post=False,
+                            timeout=4)
+        self.assertTrue(q.is_response(r))
+
+    def test_post_request(self):
+        nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS)
+        q = dns.message.make_query('example.com.', dns.rdatatype.A)
+        r = dns.query.https(q, nameserver_url, session=self.session, post=True,
+                            timeout=4)
+        self.assertTrue(q.is_response(r))
+
+    def test_build_url_from_ip(self):
+        self.assertTrue(resolver_v4_addresses or resolver_v6_addresses)
+        if resolver_v4_addresses:
+            nameserver_ip = random.choice(resolver_v4_addresses)
+            q = dns.message.make_query('example.com.', dns.rdatatype.A)
+            # For some reason Google's DNS over HTTPS fails when you POST to
+            # https://8.8.8.8/dns-query
+            # So we're just going to do GET requests here
+            r = dns.query.https(q, nameserver_ip, session=self.session,
+                                post=False, timeout=4)
+
+            self.assertTrue(q.is_response(r))
+        if resolver_v6_addresses:
+            nameserver_ip = random.choice(resolver_v6_addresses)
+            q = dns.message.make_query('example.com.', dns.rdatatype.A)
+            r = dns.query.https(q, nameserver_ip, session=self.session,
+                                post=False, timeout=4)
+            self.assertTrue(q.is_response(r))
+
+    def test_bootstrap_address(self):
+        # We test this to see if v4 is available
+        if resolver_v4_addresses:
+            ip = '185.228.168.168'
+            invalid_tls_url = 'https://{}/doh/family-filter/'.format(ip)
+            valid_tls_url = 'https://doh.cleanbrowsing.org/doh/family-filter/'
+            q = dns.message.make_query('example.com.', dns.rdatatype.A)
+            # make sure CleanBrowsing's IP address will fail TLS certificate
+            # check
+            with self.assertRaises(httpx.ConnectError):
+                dns.query.https(q, invalid_tls_url, session=self.session,
+                                timeout=4)
+            # use host header
+            r = dns.query.https(q, valid_tls_url, session=self.session,
+                                bootstrap_address=ip, timeout=4)
+            self.assertTrue(q.is_response(r))
+
+    def test_new_session(self):
+        nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS)
+        q = dns.message.make_query('example.com.', dns.rdatatype.A)
+        r = dns.query.https(q, nameserver_url, timeout=4)
+        self.assertTrue(q.is_response(r))
+
+    def test_resolver(self):
+        res = dns.resolver.Resolver(configure=False)
+        res.nameservers = ['https://dns.google/dns-query']
+        answer = res.resolve('dns.google', 'A')
+        seen = set([rdata.address for rdata in answer])
+        self.assertTrue('8.8.8.8' in seen)
+        self.assertTrue('8.8.4.4' in seen)
+
+
 if __name__ == '__main__':
     unittest.main()