]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
cleaned up query.py and test_doh
authorkimbo <kimballleavitt@gmail.com>
Mon, 23 Dec 2019 21:09:11 +0000 (14:09 -0700)
committerkimbo <kimballleavitt@gmail.com>
Mon, 23 Dec 2019 21:09:11 +0000 (14:09 -0700)
dns/query.py
tests/test_doh.py

index 49608a1a3b09b8d6b6bf18d16949e10cf27bda0a..a2d126e88aac301204a9543e347db518001cfb1e 100644 (file)
@@ -38,6 +38,8 @@ import dns.rdataclass
 import dns.rdatatype
 
 import requests
+import requests.packages.urllib3.util.connection as urllib3_cn
+from requests_toolbelt.adapters.source import SourceAddressAdapter
 
 try:
     import ssl
@@ -209,18 +211,27 @@ def _destination_and_source(af, where, port, source, source_port):
     return (af, destination, source)
 
 
-def https(q, where, timeout=None, af=None, source=None, source_port=0,
-          one_rr_per_rrset=False, ignore_trailing=False,
-          post=True, path='/dns-query', verify=True):
+def https(q, where, timeout=None, port=443, path='/dns-query', post=True,
+          verify=True, af=None, source=None, source_port=0,
+          one_rr_per_rrset=False, ignore_trailing=False):
     """Return the response obtained after sending a query via DNS-over-HTTPS.
 
     *q*, a ``dns.message.Message``, the query to send.
 
-    *where*, a ``str``, the nameserver IP address or the full URL.
+    *where*, a ``str``, the nameserver IP address or the full URL. If an IP
+    address is given, the URL will be constructed using the following schema:
+    https:<IP-address>:<port>/<path>.
 
     *timeout*, a ``float`` or ``None``, the number of seconds to
     wait before the query times out. If ``None``, the default, wait forever.
 
+    *port*, a ``int``, the port to send the query to. Default is 443.
+
+    *path*, a ``str``. If *where* is an IP address, then *path* will be used to
+    construct the URL to send the DNS query to.
+
+    *post*, a ``bool``. If ``True``, the default, POST method will be used.
+
     *af*, an ``int``, the address family to use.  The default is ``None``,
     which causes the address family to use to be inferred from the form of
     *where*.  If the inference attempt fails, AF_INET is used.  This
@@ -238,66 +249,60 @@ def https(q, where, timeout=None, af=None, source=None, source_port=0,
     *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
     junk at end of the received message.
 
-    *post*, a ``bool``. If ``True``, the default, POST method should be used.
-
-    *path*, a ``str``. If *where* is an IP address, then *path* will be used to
-    construct the URL to send the DNS query to
-
     Returns a ``dns.message.Message``.
     """
     wire = q.to_wire()
 
-    port = 443
     (af, destination, source) = _destination_and_source(af, where, port,
                                                         source, source_port)
+    if source is None:
+        source = ('', 0)
+    if headers is None:
+        headers = {}
+    with requests.Session() as session:
+        # set source port and source address
+        # see https://github.com/requests/toolbelt/blob/master/requests_toolbelt/adapters/source.py
+        session.mount('http://', SourceAddressAdapter(source))
+        session.mount('https://', SourceAddressAdapter(source))
+
+        # effectively set address family
+        # see https://stackoverflow.com/a/46972341/9638991
+        urllib3_cn.allowed_gai_family = lambda: af
 
-    # see https://github.com/requests/toolbelt/blob/master/requests_toolbelt/adapters/source.py
-    from requests_toolbelt.adapters.source import SourceAddressAdapter
-    session = requests.Session()
-    session.mount('http://', SourceAddressAdapter(source))
-    session.mount('https://', SourceAddressAdapter(source))
-
-    # see https://stackoverflow.com/a/46972341/9638991
-    import requests.packages.urllib3.util.connection as urllib3_cn
-
-    def allowed_gai_family():
-        return af
-
-    urllib3_cn.allowed_gai_family = allowed_gai_family
-
-    try:
-        _ = ipaddress.ip_address(where)
-        url = 'https://{}{}'.format(where, path)
-    except ValueError:
-        url = where
-
-    # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH GET and POST examples
-    if post:
-        headers = {
-            "accept": "application/dns-message",
-            "content-type": "application/dns-message",
-            "content-length": str(len(wire))
-        }
-        response = session.post(url, headers=headers, data=wire, stream=True, timeout=timeout, verify=verify)
-    else:
-        wire = base64.urlsafe_b64encode(wire).decode('utf-8').strip("=")
-        headers = {
-            "accept": "application/dns-message"
-        }
-        url += "?dns={}".format(wire)
-        response = session.get(url, headers=headers, stream=True, timeout=timeout, verify=verify)
-
-    # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for more info about DoH status codes
-    if 200 > response.status_code > 299:
-        raise ValueError('{} responded with status code {}\nResponse body: {}'.format(
-            where, response.status_code, response.content))
-    r = dns.message.from_wire(response.content,
-                              keyring=q.keyring,
-                              request_mac=q.request_mac,
-                              one_rr_per_rrset=one_rr_per_rrset,
-                              ignore_trailing=ignore_trailing)
+        try:
+            _ = ipaddress.ip_address(where)
+            url = 'https://{}:{}{}'.format(where, port, path)
+        except ValueError:
+            url = where
+
+        # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH GET and POST examples
+        if post:
+            headers = {
+                "accept": "application/dns-message",
+                "content-type": "application/dns-message",
+                "content-length": str(len(wire))
+            }
+            response = session.post(url, headers=headers, data=wire, stream=True,
+                                    timeout=timeout, verify=verify)
+        else:
+            wire = base64.urlsafe_b64encode(wire).decode('utf-8').strip("=")
+            headers = {
+                "accept": "application/dns-message"
+            }
+            url += "?dns={}".format(wire)
+            response = session.get(url, headers=headers, stream=True,
+                                   timeout=timeout, verify=verify)
+
+        # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH status codes
+        if response.status_code < 200 or response.status_code > 299:
+            raise ValueError('{} responded with status code {}\nResponse body: {}'.format(
+                where, response.status_code, response.content))
+        r = dns.message.from_wire(response.content,
+                                  keyring=q.keyring,
+                                  request_mac=q.request_mac,
+                                  one_rr_per_rrset=one_rr_per_rrset,
+                                  ignore_trailing=ignore_trailing)
     r.time = response.elapsed
-    r.status_code = response.status_code
     if not q.is_response(r):
         raise BadResponse
     return r
index 8574790e7be6327c1117439f52f313af77ef02f1..ded2f7ff327750ef56089361136a866469593928 100644 (file)
@@ -14,7 +14,6 @@
 # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
 # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
 # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
-
 import unittest
 import random
 
@@ -43,22 +42,12 @@ class DNSOverHTTPSTestCase(unittest.TestCase):
         self.assertTrue(q.is_response(r))
 
     def test_build_url_from_ip(self):
-        nameserver_ip = '8.8.8.8' #random.choice(KNOWN_ANYCAST_DOH_RESOLVER_IPS)
+        nameserver_ip = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_IPS)
         q = dns.message.make_query('example.com.', dns.rdatatype.A)
         # For some reason Google's DNS over HTTPS fails when you POST to https://8.8.8.8/dns-query
-        # So we're just going to do the GET request
+        # So we're just going to do GET requests here
         r = dns.query.https(q, nameserver_ip, post=False)
         self.assertTrue(q.is_response(r))
 
-    def test_custom_path(self):
-        cleanbrowsing_ip = '185.228.168.168'
-        cleanbrowsing_path = '/doh/security-filter/'
-        q = dns.message.make_query('example.com.', dns.rdatatype.A)
-        r = dns.query.https(q, cleanbrowsing_ip, path=cleanbrowsing_path, verify=False)
-        self.assertTrue(q.is_response(r))
-
-    def test_use_full_url(self):
-        pass
-
 if __name__ == '__main__':
     unittest.main()