]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
added support for source port/address, af
authorkimbo <kimballleavitt@gmail.com>
Sat, 21 Dec 2019 05:05:40 +0000 (22:05 -0700)
committerkimbo <kimballleavitt@gmail.com>
Sat, 21 Dec 2019 05:05:40 +0000 (22:05 -0700)
dns/query.py

index 48b319db51a49829057ae87c83b4dab68d317a9b..49608a1a3b09b8d6b6bf18d16949e10cf27bda0a 100644 (file)
@@ -209,7 +209,7 @@ def _destination_and_source(af, where, port, source, source_port):
     return (af, destination, source)
 
 
-def https(q, where, timeout=None, af=None, source_port=0,
+def https(q, where, timeout=None, af=None, source=None, source_port=0,
           one_rr_per_rrset=False, ignore_trailing=False,
           post=True, path='/dns-query', verify=True):
     """Return the response obtained after sending a query via DNS-over-HTTPS.
@@ -247,6 +247,24 @@ def https(q, where, timeout=None, af=None, source_port=0,
     """
     wire = q.to_wire()
 
+    port = 443
+    (af, destination, source) = _destination_and_source(af, where, port,
+                                                        source, source_port)
+
+    # see https://github.com/requests/toolbelt/blob/master/requests_toolbelt/adapters/source.py
+    from requests_toolbelt.adapters.source import SourceAddressAdapter
+    session = requests.Session()
+    session.mount('http://', SourceAddressAdapter(source))
+    session.mount('https://', SourceAddressAdapter(source))
+
+    # see https://stackoverflow.com/a/46972341/9638991
+    import requests.packages.urllib3.util.connection as urllib3_cn
+
+    def allowed_gai_family():
+        return af
+
+    urllib3_cn.allowed_gai_family = allowed_gai_family
+
     try:
         _ = ipaddress.ip_address(where)
         url = 'https://{}{}'.format(where, path)
@@ -260,14 +278,14 @@ def https(q, where, timeout=None, af=None, source_port=0,
             "content-type": "application/dns-message",
             "content-length": str(len(wire))
         }
-        response = requests.post(url, headers=headers, data=wire, stream=True, timeout=timeout, verify=verify)
+        response = session.post(url, headers=headers, data=wire, stream=True, timeout=timeout, verify=verify)
     else:
         wire = base64.urlsafe_b64encode(wire).decode('utf-8').strip("=")
         headers = {
             "accept": "application/dns-message"
         }
         url += "?dns={}".format(wire)
-        response = requests.get(url, headers=headers, stream=True, timeout=timeout, verify=verify)
+        response = session.get(url, headers=headers, stream=True, timeout=timeout, verify=verify)
 
     # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for more info about DoH status codes
     if 200 > response.status_code > 299: