From: kimbo Date: Mon, 23 Dec 2019 21:09:11 +0000 (-0700) Subject: cleaned up query.py and test_doh X-Git-Tag: v2.0.0rc1~342^2~12 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=c292519ff971cf0036ef6d4f57a5bcc1aba45200;p=thirdparty%2Fdnspython.git cleaned up query.py and test_doh --- diff --git a/dns/query.py b/dns/query.py index 49608a1a..a2d126e8 100644 --- a/dns/query.py +++ b/dns/query.py @@ -38,6 +38,8 @@ import dns.rdataclass import dns.rdatatype import requests +import requests.packages.urllib3.util.connection as urllib3_cn +from requests_toolbelt.adapters.source import SourceAddressAdapter try: import ssl @@ -209,18 +211,27 @@ def _destination_and_source(af, where, port, source, source_port): return (af, destination, source) -def https(q, where, timeout=None, af=None, source=None, source_port=0, - one_rr_per_rrset=False, ignore_trailing=False, - post=True, path='/dns-query', verify=True): +def https(q, where, timeout=None, port=443, path='/dns-query', post=True, + verify=True, af=None, source=None, source_port=0, + one_rr_per_rrset=False, ignore_trailing=False): """Return the response obtained after sending a query via DNS-over-HTTPS. *q*, a ``dns.message.Message``, the query to send. - *where*, a ``str``, the nameserver IP address or the full URL. + *where*, a ``str``, the nameserver IP address or the full URL. If an IP + address is given, the URL will be constructed using the following schema: + https::/. *timeout*, a ``float`` or ``None``, the number of seconds to wait before the query times out. If ``None``, the default, wait forever. + *port*, a ``int``, the port to send the query to. Default is 443. + + *path*, a ``str``. If *where* is an IP address, then *path* will be used to + construct the URL to send the DNS query to. + + *post*, a ``bool``. If ``True``, the default, POST method will be used. + *af*, an ``int``, the address family to use. The default is ``None``, which causes the address family to use to be inferred from the form of *where*. If the inference attempt fails, AF_INET is used. This @@ -238,66 +249,60 @@ def https(q, where, timeout=None, af=None, source=None, source_port=0, *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the received message. - *post*, a ``bool``. If ``True``, the default, POST method should be used. - - *path*, a ``str``. If *where* is an IP address, then *path* will be used to - construct the URL to send the DNS query to - Returns a ``dns.message.Message``. """ wire = q.to_wire() - port = 443 (af, destination, source) = _destination_and_source(af, where, port, source, source_port) + if source is None: + source = ('', 0) + if headers is None: + headers = {} + with requests.Session() as session: + # set source port and source address + # see https://github.com/requests/toolbelt/blob/master/requests_toolbelt/adapters/source.py + session.mount('http://', SourceAddressAdapter(source)) + session.mount('https://', SourceAddressAdapter(source)) + + # effectively set address family + # see https://stackoverflow.com/a/46972341/9638991 + urllib3_cn.allowed_gai_family = lambda: af - # see https://github.com/requests/toolbelt/blob/master/requests_toolbelt/adapters/source.py - from requests_toolbelt.adapters.source import SourceAddressAdapter - session = requests.Session() - session.mount('http://', SourceAddressAdapter(source)) - session.mount('https://', SourceAddressAdapter(source)) - - # see https://stackoverflow.com/a/46972341/9638991 - import requests.packages.urllib3.util.connection as urllib3_cn - - def allowed_gai_family(): - return af - - urllib3_cn.allowed_gai_family = allowed_gai_family - - try: - _ = ipaddress.ip_address(where) - url = 'https://{}{}'.format(where, path) - except ValueError: - url = where - - # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH GET and POST examples - if post: - headers = { - "accept": "application/dns-message", - "content-type": "application/dns-message", - "content-length": str(len(wire)) - } - response = session.post(url, headers=headers, data=wire, stream=True, timeout=timeout, verify=verify) - else: - wire = base64.urlsafe_b64encode(wire).decode('utf-8').strip("=") - headers = { - "accept": "application/dns-message" - } - url += "?dns={}".format(wire) - response = session.get(url, headers=headers, stream=True, timeout=timeout, verify=verify) - - # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for more info about DoH status codes - if 200 > response.status_code > 299: - raise ValueError('{} responded with status code {}\nResponse body: {}'.format( - where, response.status_code, response.content)) - r = dns.message.from_wire(response.content, - keyring=q.keyring, - request_mac=q.request_mac, - one_rr_per_rrset=one_rr_per_rrset, - ignore_trailing=ignore_trailing) + try: + _ = ipaddress.ip_address(where) + url = 'https://{}:{}{}'.format(where, port, path) + except ValueError: + url = where + + # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH GET and POST examples + if post: + headers = { + "accept": "application/dns-message", + "content-type": "application/dns-message", + "content-length": str(len(wire)) + } + response = session.post(url, headers=headers, data=wire, stream=True, + timeout=timeout, verify=verify) + else: + wire = base64.urlsafe_b64encode(wire).decode('utf-8').strip("=") + headers = { + "accept": "application/dns-message" + } + url += "?dns={}".format(wire) + response = session.get(url, headers=headers, stream=True, + timeout=timeout, verify=verify) + + # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH status codes + if response.status_code < 200 or response.status_code > 299: + raise ValueError('{} responded with status code {}\nResponse body: {}'.format( + where, response.status_code, response.content)) + r = dns.message.from_wire(response.content, + keyring=q.keyring, + request_mac=q.request_mac, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing) r.time = response.elapsed - r.status_code = response.status_code if not q.is_response(r): raise BadResponse return r diff --git a/tests/test_doh.py b/tests/test_doh.py index 8574790e..ded2f7ff 100644 --- a/tests/test_doh.py +++ b/tests/test_doh.py @@ -14,7 +14,6 @@ # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. - import unittest import random @@ -43,22 +42,12 @@ class DNSOverHTTPSTestCase(unittest.TestCase): self.assertTrue(q.is_response(r)) def test_build_url_from_ip(self): - nameserver_ip = '8.8.8.8' #random.choice(KNOWN_ANYCAST_DOH_RESOLVER_IPS) + nameserver_ip = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_IPS) q = dns.message.make_query('example.com.', dns.rdatatype.A) # For some reason Google's DNS over HTTPS fails when you POST to https://8.8.8.8/dns-query - # So we're just going to do the GET request + # So we're just going to do GET requests here r = dns.query.https(q, nameserver_ip, post=False) self.assertTrue(q.is_response(r)) - def test_custom_path(self): - cleanbrowsing_ip = '185.228.168.168' - cleanbrowsing_path = '/doh/security-filter/' - q = dns.message.make_query('example.com.', dns.rdatatype.A) - r = dns.query.https(q, cleanbrowsing_ip, path=cleanbrowsing_path, verify=False) - self.assertTrue(q.is_response(r)) - - def test_use_full_url(self): - pass - if __name__ == '__main__': unittest.main()