]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
added bootstrap address option to https()
authorkimbo <kimballleavitt@gmail.com>
Mon, 30 Dec 2019 22:34:46 +0000 (15:34 -0700)
committerkimbo <kimballleavitt@gmail.com>
Mon, 30 Dec 2019 22:34:46 +0000 (15:34 -0700)
dns/query.py
dns/query.pyi
tests/test_doh.py

index 626120a69cbfe7d5f9640be47c0afb976e30d7b6..0f8ced9e1d968b66e6659d9e7092868af9262434 100644 (file)
@@ -28,6 +28,7 @@ import sys
 import time
 import base64
 import ipaddress
+import urllib.parse
 
 import dns.exception
 import dns.inet
@@ -37,8 +38,8 @@ import dns.rcode
 import dns.rdataclass
 import dns.rdatatype
 
-import requests
 from requests_toolbelt.adapters.source import SourceAddressAdapter
+from requests_toolbelt.adapters.host_header_ssl import HostHeaderSSLAdapter
 
 try:
     import ssl
@@ -211,7 +212,7 @@ def _destination_and_source(af, where, port, source, source_port):
 
 
 def https(q, where, session, timeout=None, port=443, path='/dns-query', post=True,
-          verify=True, source=None, source_port=0,
+          bootstrap_address=None, verify=True, source=None, source_port=0,
           one_rr_per_rrset=False, ignore_trailing=False):
     """Return the response obtained after sending a query via DNS-over-HTTPS.
 
@@ -234,6 +235,9 @@ def https(q, where, session, timeout=None, port=443, path='/dns-query', post=Tru
 
     *post*, a ``bool``. If ``True``, the default, POST method will be used.
 
+    *bootstrap_address*, a ``str``, the IP address to use to bypass the system's
+    DNS resolver.
+
     *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
     the source address.  The default is the wildcard address.
 
@@ -253,31 +257,34 @@ def https(q, where, session, timeout=None, port=443, path='/dns-query', post=Tru
     af = None
     (af, destination, source) = _destination_and_source(af, where, port,
                                                         source, source_port)
-    if source is None:
-        source = ('', 0)
+    headers = {
+        "accept": "application/dns-message"
+    }
     try:
         _ = ipaddress.ip_address(where)
         url = 'https://{}:{}{}'.format(where, port, path)
     except ValueError:
-        url = where
-    # set source port and source address
-    # see https://github.com/requests/toolbelt/blob/master/requests_toolbelt/adapters/source.py
-    session.mount(url, SourceAddressAdapter(source))
+        if bootstrap_address is not None:
+            split_url = urllib.parse.urlsplit(where)
+            headers['Host'] = split_url.hostname
+            url = where.replace(split_url.hostname, bootstrap_address)
+            session.mount(url, HostHeaderSSLAdapter())
+        else:
+            url = where
+    if source is not None:
+        # set source port and source address
+        session.mount(url, SourceAddressAdapter(source))
 
     # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH GET and POST examples
     if post:
-        headers = {
-            "accept": "application/dns-message",
+        headers.update({
             "content-type": "application/dns-message",
             "content-length": str(len(wire))
-        }
+        })
         response = session.post(url, headers=headers, data=wire, stream=True,
                                 timeout=timeout, verify=verify)
     else:
         wire = base64.urlsafe_b64encode(wire).decode('utf-8').strip("=")
-        headers = {
-            "accept": "application/dns-message"
-        }
         url += "?dns={}".format(wire)
         response = session.get(url, headers=headers, stream=True,
                                timeout=timeout, verify=verify)
index faedf8ce65de9cae82ecf111715e8c6cebb51ed6..e943ba5fa7c4cc326431ed7900ca5468a772dcf7 100644 (file)
@@ -9,7 +9,7 @@ except ImportError:
         SSLContext = {}
 
 def https(q : message.Message, where: str, session: Session, timeout : Optional[float] = None, port : Optional[int] = 443, path : Optional[str] = '/dns-query', post : Optional[bool] = True,
-          verify : Optional[bool] = True, source : Optional[str] = None, source_port : Optional[int] = 0,
+          bootstrap_address : Optional[str] = None, verify : Optional[bool] = True, source : Optional[str] = None, source_port : Optional[int] = 0,
           one_rr_per_rrset : Optional[bool] = False, ignore_trailing : Optional[bool] = False) -> message.Message:
     pass
 
index 5e9297643cba484d9dbc137d6b1da54c612fcc55..481ca0c076179cabc0684d64359fe4cf8a6604a0 100644 (file)
@@ -18,6 +18,7 @@ import unittest
 import random
 
 import requests
+from requests.exceptions import SSLError
 
 import dns.query
 import dns.rdatatype
@@ -55,5 +56,17 @@ class DNSOverHTTPSTestCase(unittest.TestCase):
         r = dns.query.https(q, nameserver_ip, self.session, post=False)
         self.assertTrue(q.is_response(r))
 
+    def test_bootstrap_address(self):
+        ip = '185.228.168.168'
+        invalid_tls_url = 'https://{}/doh/family-filter/'.format(ip)
+        valid_tls_url = 'https://doh.cleanbrowsing.org/doh/family-filter/'
+        q = dns.message.make_query('example.com.', dns.rdatatype.A)
+        # make sure CleanBrowsing's IP address will fail TLS certificate check
+        with self.assertRaises(SSLError):
+            dns.query.https(q, invalid_tls_url, self.session)
+        # use host header
+        r = dns.query.https(q, valid_tls_url, self.session, bootstrap_address=ip)
+        self.assertTrue(q.is_response(r))
+
 if __name__ == '__main__':
     unittest.main()