]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
reset allowed_gai_family after every https() call
authorkimbo <kimballleavitt@gmail.com>
Tue, 24 Dec 2019 13:59:00 +0000 (06:59 -0700)
committerkimbo <kimballleavitt@gmail.com>
Tue, 24 Dec 2019 13:59:00 +0000 (06:59 -0700)
dns/query.py

index 47a3ccb0c3635aaa0d2780dcc2d2107965ad2bd2..262ac76753db70b43bbe4e99a311c18c58f0fd21 100644 (file)
@@ -210,6 +210,9 @@ def _destination_and_source(af, where, port, source, source_port):
             source = (source, source_port, 0, 0)
     return (af, destination, source)
 
+# keep the original function so we can reset it to avoid
+# unintentional breakages
+_allowed_gai_family = urllib3.util.connection.allowed_gai_family
 
 def https(q, where, timeout=None, port=443, path='/dns-query', post=True,
           verify=True, af=None, source=None, source_port=0,
@@ -251,29 +254,30 @@ def https(q, where, timeout=None, port=443, path='/dns-query', post=True,
 
     Returns a ``dns.message.Message``.
     """
+
     wire = q.to_wire()
 
+    # This will effectively set the address family passed to getaddrinfo()
+    # in urllib3.util.connection.create_connection(), which is used by requests
+    if af is not None:
+        urllib3.util.connection.allowed_gai_family = lambda: af
+
     (af, destination, source) = _destination_and_source(af, where, port,
                                                         source, source_port)
     if source is None:
         source = ('', 0)
-    with requests.Session() as session:
+    try:
+        _ = ipaddress.ip_address(where)
+        url = 'https://{}:{}{}'.format(where, port, path)
+    except ValueError:
+        url = where
+    session = requests.sessions.Session()
+    try:
         # set source port and source address
         # see https://github.com/requests/toolbelt/blob/master/requests_toolbelt/adapters/source.py
         session.mount('http://', SourceAddressAdapter(source))
         session.mount('https://', SourceAddressAdapter(source))
 
-        # This will effectively set the address family passed to getaddrinfo()
-        # in urllib3.util.connection.create_connection(), which is used by requests
-        if af is not None:
-            urllib3.util.connection.allowed_gai_family = lambda: af
-
-        try:
-            _ = ipaddress.ip_address(where)
-            url = 'https://{}:{}{}'.format(where, port, path)
-        except ValueError:
-            url = where
-
         # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH GET and POST examples
         if post:
             headers = {
@@ -301,6 +305,9 @@ def https(q, where, timeout=None, port=443, path='/dns-query', post=True,
                                   request_mac=q.request_mac,
                                   one_rr_per_rrset=one_rr_per_rrset,
                                   ignore_trailing=ignore_trailing)
+    finally:
+        session.close()
+        urllib3.util.connection.allowed_gai_family = _allowed_gai_family
     r.time = response.elapsed
     if not q.is_response(r):
         raise BadResponse