]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
DoH cleanup. 408/head
authorBrian Wellington <bwelling@xbill.org>
Tue, 7 Jan 2020 21:03:13 +0000 (13:03 -0800)
committerBrian Wellington <bwelling@xbill.org>
Tue, 7 Jan 2020 21:03:13 +0000 (13:03 -0800)
dns/query.py
dns/query.pyi
dns/resolver.py
examples/doh.py
tests/test_doh.py

index c36248e3c6b05ca2b297e109c3423d7913720c74..5876623e1c9031607d5382d632b0c864ff3b5421 100644 (file)
@@ -189,14 +189,16 @@ def _addresses_equal(af, a1, a2):
     return n1 == n2 and a1[1:] == a2[1:]
 
 
-def _destination_and_source(af, where, port, source, source_port):
+def _destination_and_source(af, where, port, source, source_port,
+                            default_to_inet=True):
     # Apply defaults and compute destination and source tuples
     # suitable for use in connect(), sendto(), or bind().
     if af is None:
         try:
             af = dns.inet.af_for_address(where)
         except Exception:
-            af = dns.inet.AF_INET
+            if default_to_inet:
+                af = dns.inet.AF_INET
     if af == dns.inet.AF_INET:
         destination = (where, port)
         if source is not None or source_port != 0:
@@ -209,6 +211,9 @@ def _destination_and_source(af, where, port, source, source_port):
             if source is None:
                 source = '::'
             source = (source, source_port, 0, 0)
+    else:
+        source = None
+        destination = None
     return (af, destination, source)
 
 def send_https(session, what, lifetime=None):
@@ -225,9 +230,10 @@ def send_https(session, what, lifetime=None):
         what = what.prepare()
     return session.send(what, timeout=lifetime)
 
-def https(q, where, session, timeout=None, port=443, path='/dns-query', post=True,
-          bootstrap_address=None, verify=True, source=None, source_port=0,
-          one_rr_per_rrset=False, ignore_trailing=False):
+def https(q, where, timeout=None, port=443, af=None, source=None, source_port=0,
+          one_rr_per_rrset=False, ignore_trailing=False,
+          session=None, path='/dns-query', post=True,
+          bootstrap_address=None, verify=True):
     """Return the response obtained after sending a query via DNS-over-HTTPS.
 
     *q*, a ``dns.message.Message``, the query to send.
@@ -236,21 +242,15 @@ def https(q, where, session, timeout=None, port=443, path='/dns-query', post=Tru
     address is given, the URL will be constructed using the following schema:
     https://<IP-address>:<port>/<path>.
 
-    *session*, a ``requests.session.Session``, the session to use to send the
-    queries. This argument is required to allow for connection reuse.
-
     *timeout*, a ``float`` or ``None``, the number of seconds to
     wait before the query times out. If ``None``, the default, wait forever.
 
-    *port*, a ``int``, the port to send the query to. Default is 443.
-
-    *path*, a ``str``. If *where* is an IP address, then *path* will be used to
-    construct the URL to send the DNS query to.
-
-    *post*, a ``bool``. If ``True``, the default, POST method will be used.
+    *port*, a ``int``, the port to send the query to. The default is 443.
 
-    *bootstrap_address*, a ``str``, the IP address to use to bypass the system's
-    DNS resolver.
+    *af*, an ``int``, the address family to use.  The default is ``None``,
+    which causes the address family to use to be inferred from the form of
+    *where*, or uses the system default.  Setting this to AF_INET or
+    AF_INET6 currently has no effect.
 
     *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
     the source address.  The default is the wildcard address.
@@ -264,13 +264,27 @@ def https(q, where, session, timeout=None, port=443, path='/dns-query', post=Tru
     *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
     junk at end of the received message.
 
+    *session*, a ``requests.session.Session``.  If provided, the session to use
+    to send the queries.
+
+    *path*, a ``str``. If *where* is an IP address, then *path* will be used to
+    construct the URL to send the DNS query to.
+
+    *post*, a ``bool``. If ``True``, the default, POST method will be used.
+
+    *bootstrap_address*, a ``str``, the IP address to use to bypass the
+    system's DNS resolver.
+
+    *verify*, a ``str`, containing a path to a certificate file or directory.
+
     Returns a ``dns.message.Message``.
     """
 
     wire = q.to_wire()
-    af = None
     (af, destination, source) = _destination_and_source(af, where, port,
-                                                        source, source_port)
+                                                        source, source_port,
+                                                        False)
+    transport_adapter = None
     headers = {
         "accept": "application/dns-message"
     }
@@ -282,31 +296,49 @@ def https(q, where, session, timeout=None, port=443, path='/dns-query', post=Tru
             split_url = urllib.parse.urlsplit(where)
             headers['Host'] = split_url.hostname
             url = where.replace(split_url.hostname, bootstrap_address)
-            session.mount(url, HostHeaderSSLAdapter())
+            transport_adapter = HostHeaderSSLAdapter()
         else:
             url = where
     if source is not None:
         # set source port and source address
-        session.mount(url, SourceAddressAdapter(source))
-
-    # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH GET and POST examples
-    if post:
-        headers.update({
-            "content-type": "application/dns-message",
-            "content-length": str(len(wire))
-        })
-        response = session.post(url, headers=headers, data=wire, stream=True,
-                                timeout=timeout, verify=verify)
+        transport_adapter = SourceAddressAdapter(source)
+
+    if session:
+        close_session = False
     else:
-        wire = base64.urlsafe_b64encode(wire).decode('utf-8').strip("=")
-        url += "?dns={}".format(wire)
-        response = session.get(url, headers=headers, stream=True,
-                               timeout=timeout, verify=verify)
+        session = requests.sessions.Session()
+        close_session = True
+
+    try:
+        if transport_adapter:
+            session.mount(url, transport_adapter)
+
+        # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
+        # GET and POST examples
+        if post:
+            headers.update({
+                "content-type": "application/dns-message",
+                "content-length": str(len(wire))
+            })
+            response = session.post(url, headers=headers, data=wire,
+                                    stream=True, timeout=timeout,
+                                    verify=verify)
+        else:
+            wire = base64.urlsafe_b64encode(wire).decode('utf-8').strip("=")
+            url += "?dns={}".format(wire)
+            response = session.get(url, headers=headers, stream=True,
+                                   timeout=timeout, verify=verify)
+    finally:
+        if close_session:
+            session.close()
 
-    # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH status codes
+    # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
+    # status codes
     if response.status_code < 200 or response.status_code > 299:
-        raise ValueError('{} responded with status code {}\nResponse body: {}'.format(
-            where, response.status_code, response.content))
+        raise ValueError('{} responded with status code {}'
+                         '\nResponse body: {}'.format(where,
+                                                      response.status_code,
+                                                      response.content))
     r = dns.message.from_wire(response.content,
                               keyring=q.keyring,
                               request_mac=q.request_mac,
index e943ba5fa7c4cc326431ed7900ca5468a772dcf7..9346123296f1e22ad7ca9d89ea65547d9d7c063e 100644 (file)
@@ -8,9 +8,8 @@ except ImportError:
     class ssl(object):
         SSLContext = {}
 
-def https(q : message.Message, where: str, session: Session, timeout : Optional[float] = None, port : Optional[int] = 443, path : Optional[str] = '/dns-query', post : Optional[bool] = True,
-          bootstrap_address : Optional[str] = None, verify : Optional[bool] = True, source : Optional[str] = None, source_port : Optional[int] = 0,
-          one_rr_per_rrset : Optional[bool] = False, ignore_trailing : Optional[bool] = False) -> message.Message:
+def https(q : message.Message, where: str, timeout : Optional[float] = None, port : Optional[int] = 443, af : Optional[int] = None, source : Optional[str] = None, source_port : Optional[int] = 0,
+          session: Optional[Session], path : Optional[str] = '/dns-query', post : Optional[bool] = True, bootstrap_address : Optional[str] = None, verify : Optional[bool] = True) -> message.Message:
     pass
 
 def tcp(q : message.Message, where : str, timeout : float = None, port=53, af : Optional[int] = None, source : Optional[str] = None, source_port : Optional[int] = 0,
index 735de9f584842a03e6ffd047a7d3607c7ba27c4f..3f5e45185947ffc5010e031d1b4c9ee58c5cd1d4 100644 (file)
@@ -909,30 +909,37 @@ class Resolver(object):
                     try:
                         if protocol == 'https':
                             tcp_attempt = True
-                            response = dns.query.https(request, nameserver, timeout)
+                            response = dns.query.https(request, nameserver,
+                                                       timeout=timeout)
                         elif protocol:
                             continue
                         else:
                             tcp_attempt = tcp
                             if tcp:
                                 response = dns.query.tcp(request, nameserver,
-                                                         timeout, port,
+                                                         timeout=timeout,
+                                                         port=port,
                                                          source=source,
-                                                         source_port=source_port)
+                                                         source_port=\
+                                                         source_port)
                             else:
                                 try:
-                                    response = dns.query.udp(request, nameserver,
-                                                             timeout, port,
+                                    response = dns.query.udp(request,
+                                                             nameserver,
+                                                             timeout=timeout,
+                                                             port=port,
                                                              source=source,
                                                              source_port=\
                                                              source_port)
                                 except dns.message.Truncated:
                                     # Response truncated; retry with TCP.
                                     tcp_attempt = True
-                                    timeout = self._compute_timeout(start, lifetime)
+                                    timeout = self._compute_timeout(start,
+                                                                    lifetime)
                                     response = \
                                         dns.query.tcp(request, nameserver,
-                                                      timeout, port,
+                                                      timeout=timeout,
+                                                      port=port,
                                                       source=source,
                                                       source_port=source_port)
                     except (socket.error, dns.exception.Timeout) as ex:
index 01c562f8836441a6e3664ef0b2d2af0256ae6fe8..eff9ae7579fbac567902dfc9be6ebf919a2d117b 100644 (file)
@@ -18,7 +18,7 @@ def main():
     # one method is to use context manager, session will automatically close
     with requests.sessions.Session() as session:
         q = dns.message.make_query(qname, dns.rdatatype.A)
-        r = dns.query.https(q, where, session)
+        r = dns.query.https(q, where, session=session)
         for answer in r.answer:
             print(answer)
 
@@ -29,7 +29,7 @@ def main():
     # second method, close session manually
     session = requests.sessions.Session()
     q = dns.message.make_query(qname, dns.rdatatype.A)
-    r = dns.query.https(q, where, session)
+    r = dns.query.https(q, where, session=session)
     for answer in r.answer:
         print(answer)
 
index 3819b1a4d6569f941ddaf57c7bac2842b1fa7271..acda5af64ffc03d06b46b691aec87472a9cb3f7f 100644 (file)
@@ -40,13 +40,13 @@ class DNSOverHTTPSTestCase(unittest.TestCase):
     def test_get_request(self):
         nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS)
         q = dns.message.make_query('example.com.', dns.rdatatype.A)
-        r = dns.query.https(q, nameserver_url, self.session, post=False)
+        r = dns.query.https(q, nameserver_url, session=self.session, post=False)
         self.assertTrue(q.is_response(r))
 
     def test_post_request(self):
         nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS)
         q = dns.message.make_query('example.com.', dns.rdatatype.A)
-        r = dns.query.https(q, nameserver_url, self.session, post=True)
+        r = dns.query.https(q, nameserver_url, session=self.session, post=True)
         self.assertTrue(q.is_response(r))
 
     def test_build_url_from_ip(self):
@@ -54,7 +54,7 @@ class DNSOverHTTPSTestCase(unittest.TestCase):
         q = dns.message.make_query('example.com.', dns.rdatatype.A)
         # For some reason Google's DNS over HTTPS fails when you POST to https://8.8.8.8/dns-query
         # So we're just going to do GET requests here
-        r = dns.query.https(q, nameserver_ip, self.session, post=False)
+        r = dns.query.https(q, nameserver_ip, session=self.session, post=False)
         self.assertTrue(q.is_response(r))
 
     def test_bootstrap_address(self):
@@ -64,9 +64,9 @@ class DNSOverHTTPSTestCase(unittest.TestCase):
         q = dns.message.make_query('example.com.', dns.rdatatype.A)
         # make sure CleanBrowsing's IP address will fail TLS certificate check
         with self.assertRaises(SSLError):
-            dns.query.https(q, invalid_tls_url, self.session)
+            dns.query.https(q, invalid_tls_url, session=self.session)
         # use host header
-        r = dns.query.https(q, valid_tls_url, self.session, bootstrap_address=ip)
+        r = dns.query.https(q, valid_tls_url, session=self.session, bootstrap_address=ip)
         self.assertTrue(q.is_response(r))
 
     def test_send_https(self):
@@ -79,5 +79,11 @@ class DNSOverHTTPSTestCase(unittest.TestCase):
         dns_resp = dns.message.from_wire(response.content)
         self.assertTrue(q.is_response(dns_resp))
 
+    def test_new_session(self):
+        nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS)
+        q = dns.message.make_query('example.com.', dns.rdatatype.A)
+        r = dns.query.https(q, nameserver_url)
+        self.assertTrue(q.is_response(r))
+
 if __name__ == '__main__':
     unittest.main()