From: kimbo Date: Tue, 24 Dec 2019 13:59:00 +0000 (-0700) Subject: reset allowed_gai_family after every https() call X-Git-Tag: v2.0.0rc1~342^2~7 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=43c804166762bd9da33c56fe29aeab080be09262;p=thirdparty%2Fdnspython.git reset allowed_gai_family after every https() call --- diff --git a/dns/query.py b/dns/query.py index 47a3ccb0..262ac767 100644 --- a/dns/query.py +++ b/dns/query.py @@ -210,6 +210,9 @@ def _destination_and_source(af, where, port, source, source_port): source = (source, source_port, 0, 0) return (af, destination, source) +# keep the original function so we can reset it to avoid +# unintentional breakages +_allowed_gai_family = urllib3.util.connection.allowed_gai_family def https(q, where, timeout=None, port=443, path='/dns-query', post=True, verify=True, af=None, source=None, source_port=0, @@ -251,29 +254,30 @@ def https(q, where, timeout=None, port=443, path='/dns-query', post=True, Returns a ``dns.message.Message``. """ + wire = q.to_wire() + # This will effectively set the address family passed to getaddrinfo() + # in urllib3.util.connection.create_connection(), which is used by requests + if af is not None: + urllib3.util.connection.allowed_gai_family = lambda: af + (af, destination, source) = _destination_and_source(af, where, port, source, source_port) if source is None: source = ('', 0) - with requests.Session() as session: + try: + _ = ipaddress.ip_address(where) + url = 'https://{}:{}{}'.format(where, port, path) + except ValueError: + url = where + session = requests.sessions.Session() + try: # set source port and source address # see https://github.com/requests/toolbelt/blob/master/requests_toolbelt/adapters/source.py session.mount('http://', SourceAddressAdapter(source)) session.mount('https://', SourceAddressAdapter(source)) - # This will effectively set the address family passed to getaddrinfo() - # in urllib3.util.connection.create_connection(), which is used by requests - if af is not None: - urllib3.util.connection.allowed_gai_family = lambda: af - - try: - _ = ipaddress.ip_address(where) - url = 'https://{}:{}{}'.format(where, port, path) - except ValueError: - url = where - # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH GET and POST examples if post: headers = { @@ -301,6 +305,9 @@ def https(q, where, timeout=None, port=443, path='/dns-query', post=True, request_mac=q.request_mac, one_rr_per_rrset=one_rr_per_rrset, ignore_trailing=ignore_trailing) + finally: + session.close() + urllib3.util.connection.allowed_gai_family = _allowed_gai_family r.time = response.elapsed if not q.is_response(r): raise BadResponse