]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
improvements to dns.query.https
authorkimbo <kimballleavitt@gmail.com>
Fri, 20 Dec 2019 00:25:27 +0000 (17:25 -0700)
committerkimbo <kimballleavitt@gmail.com>
Fri, 20 Dec 2019 00:25:27 +0000 (17:25 -0700)
- use requests module (instead of urllib)
- added option to pass in an IP address or a URL
- added basic tests (for dns.query.https)

TODO: af, source_port

dns/query.py
tests/test_doh.py [new file with mode: 0644]

index 5ed82c541e63f465120460887eb2b420fc8e3ab6..48b319db51a49829057ae87c83b4dab68d317a9b 100644 (file)
@@ -19,7 +19,6 @@
 
 from __future__ import generators
 
-import urllib.request
 import errno
 import os
 import select
@@ -28,6 +27,7 @@ import struct
 import sys
 import time
 import base64
+import ipaddress
 
 import dns.exception
 import dns.inet
@@ -37,6 +37,8 @@ import dns.rcode
 import dns.rdataclass
 import dns.rdatatype
 
+import requests
+
 try:
     import ssl
 except ImportError:
@@ -207,17 +209,28 @@ def _destination_and_source(af, where, port, source, source_port):
     return (af, destination, source)
 
 
-def https(query, url, timeout=None, post=True, one_rr_per_rrset=False, ignore_trailing=False):
+def https(q, where, timeout=None, af=None, source_port=0,
+          one_rr_per_rrset=False, ignore_trailing=False,
+          post=True, path='/dns-query', verify=True):
     """Return the response obtained after sending a query via DNS-over-HTTPS.
 
-    *query*, a ``dns.message.Message``, the query to send.
+    *q*, a ``dns.message.Message``, the query to send.
 
-    *url*, a ``str``, the nameserver URL.
+    *where*, a ``str``, the nameserver IP address or the full URL.
 
     *timeout*, a ``float`` or ``None``, the number of seconds to
     wait before the query times out. If ``None``, the default, wait forever.
 
-    *post*, a ``bool``. If ``True``, the default, POST method should be used.
+    *af*, an ``int``, the address family to use.  The default is ``None``,
+    which causes the address family to use to be inferred from the form of
+    *where*.  If the inference attempt fails, AF_INET is used.  This
+    parameter is historical; you need never set it.
+
+    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
+    the source address.  The default is the wildcard address.
+
+    *source_port*, an ``int``, the port from which to send the message.
+    The default is 0.
 
     *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
     RRset.
@@ -225,27 +238,51 @@ def https(query, url, timeout=None, post=True, one_rr_per_rrset=False, ignore_tr
     *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
     junk at end of the received message.
 
+    *post*, a ``bool``. If ``True``, the default, POST method should be used.
+
+    *path*, a ``str``. If *where* is an IP address, then *path* will be used to
+    construct the URL to send the DNS query to
+
     Returns a ``dns.message.Message``.
     """
+    wire = q.to_wire()
 
-    wirequery = query.to_wire()
-    headers = {
-        'Accept': 'application/dns-message',
-        'Content-Type': 'application/dns-message',
-    }
+    try:
+        _ = ipaddress.ip_address(where)
+        url = 'https://{}{}'.format(where, path)
+    except ValueError:
+        url = where
 
+    # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH GET and POST examples
     if post:
-        request = urllib.request.Request(url, data=wirequery, headers=headers)
+        headers = {
+            "accept": "application/dns-message",
+            "content-type": "application/dns-message",
+            "content-length": str(len(wire))
+        }
+        response = requests.post(url, headers=headers, data=wire, stream=True, timeout=timeout, verify=verify)
     else:
-        wirequery = base64.urlsafe_b64encode(wirequery).decode('utf-8').strip('=')
-        request = urllib.request.Request(url + '?dns=' + wirequery, headers=headers)
-
-    response = urllib.request.urlopen(request, timeout=timeout).read()
-    return dns.message.from_wire(response,
-                                 keyring=query.keyring,
-                                 request_mac=query.request_mac,
-                                 one_rr_per_rrset=one_rr_per_rrset,
-                                 ignore_trailing=ignore_trailing)
+        wire = base64.urlsafe_b64encode(wire).decode('utf-8').strip("=")
+        headers = {
+            "accept": "application/dns-message"
+        }
+        url += "?dns={}".format(wire)
+        response = requests.get(url, headers=headers, stream=True, timeout=timeout, verify=verify)
+
+    # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for more info about DoH status codes
+    if 200 > response.status_code > 299:
+        raise ValueError('{} responded with status code {}\nResponse body: {}'.format(
+            where, response.status_code, response.content))
+    r = dns.message.from_wire(response.content,
+                              keyring=q.keyring,
+                              request_mac=q.request_mac,
+                              one_rr_per_rrset=one_rr_per_rrset,
+                              ignore_trailing=ignore_trailing)
+    r.time = response.elapsed
+    r.status_code = response.status_code
+    if not q.is_response(r):
+        raise BadResponse
+    return r
 
 def send_udp(sock, what, destination, expiration=None):
     """Send a DNS message to the specified UDP socket.
diff --git a/tests/test_doh.py b/tests/test_doh.py
new file mode 100644 (file)
index 0000000..8574790
--- /dev/null
@@ -0,0 +1,64 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+import unittest
+import random
+
+import dns.query
+import dns.rdatatype
+import dns.message
+
+KNOWN_ANYCAST_DOH_RESOLVER_IPS = ['1.1.1.1', '8.8.8.8', '9.9.9.9']
+KNOWN_ANYCAST_DOH_RESOLVER_URLS = ['https://cloudflare-dns.com/dns-query',
+                                   'https://dns.google/dns-query',
+                                   'https://dns11.quad9.net/dns-query']
+
+class DNSOverHTTPSTestCase(unittest.TestCase):
+    nameserver_ip = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_IPS)
+
+    def test_get_request(self):
+        nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS)
+        q = dns.message.make_query('example.com.', dns.rdatatype.A)
+        r = dns.query.https(q, nameserver_url, post=False)
+        self.assertTrue(q.is_response(r))
+
+    def test_post_request(self):
+        nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS)
+        q = dns.message.make_query('example.com.', dns.rdatatype.A)
+        r = dns.query.https(q, nameserver_url, post=True)
+        self.assertTrue(q.is_response(r))
+
+    def test_build_url_from_ip(self):
+        nameserver_ip = '8.8.8.8' #random.choice(KNOWN_ANYCAST_DOH_RESOLVER_IPS)
+        q = dns.message.make_query('example.com.', dns.rdatatype.A)
+        # For some reason Google's DNS over HTTPS fails when you POST to https://8.8.8.8/dns-query
+        # So we're just going to do the GET request
+        r = dns.query.https(q, nameserver_ip, post=False)
+        self.assertTrue(q.is_response(r))
+
+    def test_custom_path(self):
+        cleanbrowsing_ip = '185.228.168.168'
+        cleanbrowsing_path = '/doh/security-filter/'
+        q = dns.message.make_query('example.com.', dns.rdatatype.A)
+        r = dns.query.https(q, cleanbrowsing_ip, path=cleanbrowsing_path, verify=False)
+        self.assertTrue(q.is_response(r))
+
+    def test_use_full_url(self):
+        pass
+
+if __name__ == '__main__':
+    unittest.main()