]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
req'd session param to https() for connection reuse
authorkimbo <kimballleavitt@gmail.com>
Mon, 30 Dec 2019 21:02:33 +0000 (14:02 -0700)
committerkimbo <kimballleavitt@gmail.com>
Mon, 30 Dec 2019 21:02:33 +0000 (14:02 -0700)
dns/query.py
dns/query.pyi
tests/test_doh.py

index 23711453952b8a5957d36c21923238eecb9a6468..867acab33c4c04069b1ef70a786a4cfb4107a03e 100644 (file)
@@ -210,7 +210,7 @@ def _destination_and_source(af, where, port, source, source_port):
     return (af, destination, source)
 
 
-def https(q, where, timeout=None, port=443, path='/dns-query', post=True,
+def https(q, where, session, timeout=None, port=443, path='/dns-query', post=True,
           verify=True, source=None, source_port=0,
           one_rr_per_rrset=False, ignore_trailing=False):
     """Return the response obtained after sending a query via DNS-over-HTTPS.
@@ -221,6 +221,9 @@ def https(q, where, timeout=None, port=443, path='/dns-query', post=True,
     address is given, the URL will be constructed using the following schema:
     https:<IP-address>:<port>/<path>.
 
+    *session*, a ``requests.session.Session``, the session to use to send the
+    queries. This argument is required to allow for connection reuse.
+
     *timeout*, a ``float`` or ``None``, the number of seconds to
     wait before the query times out. If ``None``, the default, wait forever.
 
@@ -257,12 +260,11 @@ def https(q, where, timeout=None, port=443, path='/dns-query', post=True,
         url = 'https://{}:{}{}'.format(where, port, path)
     except ValueError:
         url = where
-    session = requests.sessions.Session()
+    session = requests.sessions.Session()
     try:
         # set source port and source address
         # see https://github.com/requests/toolbelt/blob/master/requests_toolbelt/adapters/source.py
-        session.mount('http://', SourceAddressAdapter(source))
-        session.mount('https://', SourceAddressAdapter(source))
+        session.mount(url, SourceAddressAdapter(source))
 
         # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH GET and POST examples
         if post:
index 7c6f09c5454499b1199e1ab395606feb0759451c..faedf8ce65de9cae82ecf111715e8c6cebb51ed6 100644 (file)
@@ -1,5 +1,6 @@
 from typing import Optional, Union, Dict, Generator, Any
 from . import tsig, rdatatype, rdataclass, name, message
+from requests.sessions import Session
 
 try:
     import ssl
@@ -7,7 +8,7 @@ except ImportError:
     class ssl(object):
         SSLContext = {}
 
-def https(q : message.Message, where: str, timeout : Optional[float] = None, port : Optional[int] = 443, path : Optional[str] = '/dns-query', post : Optional[bool] = True,
+def https(q : message.Message, where: str, session: Session, timeout : Optional[float] = None, port : Optional[int] = 443, path : Optional[str] = '/dns-query', post : Optional[bool] = True,
           verify : Optional[bool] = True, source : Optional[str] = None, source_port : Optional[int] = 0,
           one_rr_per_rrset : Optional[bool] = False, ignore_trailing : Optional[bool] = False) -> message.Message:
     pass
index ded2f7ff327750ef56089361136a866469593928..5e9297643cba484d9dbc137d6b1da54c612fcc55 100644 (file)
@@ -17,6 +17,8 @@
 import unittest
 import random
 
+import requests
+
 import dns.query
 import dns.rdatatype
 import dns.message
@@ -27,18 +29,22 @@ KNOWN_ANYCAST_DOH_RESOLVER_URLS = ['https://cloudflare-dns.com/dns-query',
                                    'https://dns11.quad9.net/dns-query']
 
 class DNSOverHTTPSTestCase(unittest.TestCase):
-    nameserver_ip = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_IPS)
+    def setUp(self):
+        self.session = requests.sessions.Session()
+
+    def tearDown(self):
+        self.session.close()
 
     def test_get_request(self):
         nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS)
         q = dns.message.make_query('example.com.', dns.rdatatype.A)
-        r = dns.query.https(q, nameserver_url, post=False)
+        r = dns.query.https(q, nameserver_url, self.session, post=False)
         self.assertTrue(q.is_response(r))
 
     def test_post_request(self):
         nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS)
         q = dns.message.make_query('example.com.', dns.rdatatype.A)
-        r = dns.query.https(q, nameserver_url, post=True)
+        r = dns.query.https(q, nameserver_url, self.session, post=True)
         self.assertTrue(q.is_response(r))
 
     def test_build_url_from_ip(self):
@@ -46,7 +52,7 @@ class DNSOverHTTPSTestCase(unittest.TestCase):
         q = dns.message.make_query('example.com.', dns.rdatatype.A)
         # For some reason Google's DNS over HTTPS fails when you POST to https://8.8.8.8/dns-query
         # So we're just going to do GET requests here
-        r = dns.query.https(q, nameserver_ip, post=False)
+        r = dns.query.https(q, nameserver_ip, self.session, post=False)
         self.assertTrue(q.is_response(r))
 
 if __name__ == '__main__':