From: Bob Halley Date: Sun, 19 Mar 2023 16:24:32 +0000 (-0700) Subject: Better DNS-over-HTTPS support. (#908) X-Git-Tag: v2.4.0rc1~46 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=6c5f0c9d8086c999357531c38b831efd24b6b5ac;p=thirdparty%2Fdnspython.git Better DNS-over-HTTPS support. (#908) This change: Allows resolution hostnames in URLs using dnspython's resolver or via a bootstrap address, without rewriting URLs. Adds full support for source addresses and ports to httpx, except for asyncio I/O where only the source address can be specified. Removes support for requests. --- diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index 84e4ecc4..3f887fea 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -60,7 +60,7 @@ jobs: sudo apt install -y gnome-keyring python -m pip install --upgrade pip python -m pip install poetry - poetry install -E dnssec -E doh -E idna -E trio + poetry install -E dnssec -E doh -E idna -E trio -E curio - name: Perform CodeQL Analysis uses: github/codeql-action/analyze@v2 diff --git a/dns/_asyncbackend.py b/dns/_asyncbackend.py index ff24604f..7fd4926b 100644 --- a/dns/_asyncbackend.py +++ b/dns/_asyncbackend.py @@ -61,6 +61,11 @@ class StreamSocket(Socket): # pragma: no cover raise NotImplementedError +class NullTransport: + async def connect_tcp(self, host, port, timeout, local_address): + raise NotImplementedError + + class Backend: # pragma: no cover def name(self): return "unknown" @@ -83,3 +88,6 @@ class Backend: # pragma: no cover async def sleep(self, interval): raise NotImplementedError + + def get_transport_class(self): + raise NotImplementedError diff --git a/dns/_asyncio_backend.py b/dns/_asyncio_backend.py index 82a06249..98971be9 100644 --- a/dns/_asyncio_backend.py +++ b/dns/_asyncio_backend.py @@ -113,6 +113,82 @@ class StreamSocket(dns._asyncbackend.StreamSocket): return self.writer.get_extra_info("sockname") +try: + import anyio + import httpx + + import httpcore + import httpcore.backends.base + import httpcore.backends.asyncio + + from dns.query import _compute_times, _remaining, _expiration_for_this_attempt + + class _NetworkBackend(httpcore.backends.base.AsyncNetworkBackend): + def __init__(self, resolver, local_port, bootstrap_address, family): + super().__init__() + self._local_port = local_port + self._resolver = resolver + self._bootstrap_address = bootstrap_address + self._family = family + if local_port != 0: + raise NotImplementedError( + "the asyncio transport for HTTPX cannot set the local port" + ) + + async def connect_tcp(self, host, port, timeout, local_address): + addresses = [] + now, expiration = _compute_times(timeout) + if dns.inet.is_address(host): + addresses.append(host) + elif self._bootstrap_address is not None: + addresses.append(self._bootstrap_address) + else: + timeout = _remaining(expiration) + family = self._family + if local_address: + family = dns.inet.af_for_address(local_address) + answers = await self._resolver.resolve_name( + host, family=family, lifetime=timeout + ) + addresses = answers.addresses() + for address in addresses: + try: + attempt_expiration = _expiration_for_this_attempt(2.0, expiration) + timeout = _remaining(attempt_expiration) + with anyio.fail_after(timeout): + stream = await anyio.connect_tcp( + remote_host=host, + remote_port=port, + local_host=local_address, + ) + return httpcore.backends.asyncio.AsyncIOStream(stream) + except Exception: + pass + raise httpcore.ConnectError + + class _HTTPTransport(httpx.AsyncHTTPTransport): + def __init__( + self, + *args, + local_port=0, + bootstrap_address=None, + resolver=None, + family=socket.AF_UNSPEC, + **kwargs, + ): + if resolver is None: + import dns.asyncresolver + + resolver = dns.asyncresolver.Resolver() + super().__init__(*args, **kwargs) + self._pool._network_backend = _NetworkBackend( + resolver, local_port, bootstrap_address, family + ) + +except ImportError: + _HTTPTransport = dns._asyncbackend.NullTransport # type: ignore + + class Backend(dns._asyncbackend.Backend): def name(self): return "asyncio" @@ -171,3 +247,6 @@ class Backend(dns._asyncbackend.Backend): def datagram_connection_required(self): return _is_win32 + + def get_transport_class(self): + return _HTTPTransport diff --git a/dns/_trio_backend.py b/dns/_trio_backend.py index b0c02103..08101f9a 100644 --- a/dns/_trio_backend.py +++ b/dns/_trio_backend.py @@ -83,6 +83,80 @@ class StreamSocket(dns._asyncbackend.StreamSocket): return self.stream.socket.getsockname() +try: + import httpx + + import httpcore + import httpcore.backends.base + import httpcore.backends.trio + + from dns.query import _compute_times, _remaining, _expiration_for_this_attempt + + class _NetworkBackend(httpcore.backends.base.AsyncNetworkBackend): + def __init__(self, resolver, local_port, bootstrap_address, family): + super().__init__() + self._local_port = local_port + self._resolver = resolver + self._bootstrap_address = bootstrap_address + self._family = family + + async def connect_tcp(self, host, port, timeout, local_address): + addresses = [] + now, expiration = _compute_times(timeout) + if dns.inet.is_address(host): + addresses.append(host) + elif self._bootstrap_address is not None: + addresses.append(self._bootstrap_address) + else: + timeout = _remaining(expiration) + family = self._family + if local_address: + family = dns.inet.af_for_address(local_address) + answers = await self._resolver.resolve_name( + host, family=family, lifetime=timeout + ) + addresses = answers.addresses() + for address in addresses: + try: + af = dns.inet.af_for_address(address) + if local_address is not None or self._local_port != 0: + source = (local_address, self._local_port) + else: + source = None + destination = (address, port) + attempt_expiration = _expiration_for_this_attempt(2.0, expiration) + timeout = _remaining(attempt_expiration) + sock = await Backend().make_socket( + af, socket.SOCK_STREAM, 0, source, destination, timeout + ) + return httpcore.backends.trio.TrioStream(sock.stream) + except Exception: + continue + raise httpcore.ConnectError + + class _HTTPTransport(httpx.AsyncHTTPTransport): + def __init__( + self, + *args, + local_port=0, + bootstrap_address=None, + resolver=None, + family=socket.AF_UNSPEC, + **kwargs, + ): + if resolver is None: + import dns.asyncresolver + + resolver = dns.asyncresolver.Resolver() + super().__init__(*args, **kwargs) + self._pool._network_backend = _NetworkBackend( + resolver, local_port, bootstrap_address, family + ) + +except ImportError: + _HTTPTransport = dns._asyncbackend.NullTransport # type: ignore + + class Backend(dns._asyncbackend.Backend): def name(self): return "trio" @@ -104,8 +178,14 @@ class Backend(dns._asyncbackend.Backend): if source: await s.bind(_lltuple(source, af)) if socktype == socket.SOCK_STREAM: + connected = False with _maybe_timeout(timeout): await s.connect(_lltuple(destination, af)) + connected = True + if not connected: + raise dns.exception.Timeout( + timeout=timeout + ) # lgtm[py/unreachable-statement] except Exception: # pragma: no cover s.close() raise @@ -130,3 +210,6 @@ class Backend(dns._asyncbackend.Backend): async def sleep(self, interval): await trio.sleep(interval) + + def get_transport_class(self): + return _HTTPTransport diff --git a/dns/asyncquery.py b/dns/asyncquery.py index 459c611d..ea539116 100644 --- a/dns/asyncquery.py +++ b/dns/asyncquery.py @@ -43,13 +43,13 @@ from dns.query import ( BadResponse, ssl, UDPMode, - _have_httpx, + have_doh, _have_http2, NoDOH, NoDOQ, ) -if _have_httpx: +if have_doh: import httpx # for brevity @@ -495,6 +495,9 @@ async def https( path: str = "/dns-query", post: bool = True, verify: Union[bool, str] = True, + bootstrap_address: Optional[str] = None, + resolver: Optional["dns.asyncresolver.Resolver"] = None, + family: Optional[int] = socket.AF_UNSPEC, ) -> dns.message.Message: """Return the response obtained after sending a query via DNS-over-HTTPS. @@ -508,8 +511,10 @@ async def https( parameters, exceptions, and return type of this method. """ - if not _have_httpx: + if not have_doh: raise NoDOH("httpx is not available.") # pragma: no cover + if client and not isinstance(client, httpx.AsyncClient): + raise ValueError("session parameter must be an httpx.AsyncClient") wire = q.to_wire() try: @@ -518,15 +523,30 @@ async def https( af = None transport = None headers = {"accept": "application/dns-message"} - if af is not None: + if af is not None and dns.inet.is_address(where): if af == socket.AF_INET: url = "https://{}:{}{}".format(where, port, path) elif af == socket.AF_INET6: url = "https://[{}]:{}{}".format(where, port, path) else: url = where - if source is not None: - transport = httpx.AsyncHTTPTransport(local_address=source[0]) + + backend = dns.asyncbackend.get_default_backend() + + if source is None: + local_address = None + local_port = 0 + else: + local_address = source + local_port = source_port + transport = backend.get_transport_class()( + local_address=local_address, + verify=verify, + local_port=local_port, + bootstrap_address=bootstrap_address, + resolver=resolver, + family=family, + ) if client: cm: contextlib.AbstractAsyncContextManager = NullContext(client) diff --git a/dns/inet.py b/dns/inet.py index 11180c96..23a4a86e 100644 --- a/dns/inet.py +++ b/dns/inet.py @@ -171,3 +171,12 @@ def low_level_address_tuple( return tup else: raise NotImplementedError(f"unknown address family {af}") + + +def any_for_af(af): + """Return the 'any' address for the specified address family.""" + if af == socket.AF_INET: + return "0.0.0.0" + elif af == socket.AF_INET6: + return "::" + raise NotImplementedError(f"unknown address family {af}") diff --git a/dns/query.py b/dns/query.py index b4cd69f7..517bab02 100644 --- a/dns/query.py +++ b/dns/query.py @@ -43,14 +43,21 @@ import dns.transaction import dns.tsig import dns.xfr -try: - import requests - from requests_toolbelt.adapters.source import SourceAddressAdapter - from requests_toolbelt.adapters.host_header_ssl import HostHeaderSSLAdapter - _have_requests = True -except ImportError: # pragma: no cover - _have_requests = False +def _remaining(expiration): + if expiration is None: + return None + timeout = expiration - time.time() + if timeout <= 0.0: + raise dns.exception.Timeout + return timeout + + +def _expiration_for_this_attempt(timeout, expiration): + if expiration is None: + return None + return min(time.time() + timeout, expiration) + _have_httpx = False _have_http2 = False @@ -64,10 +71,83 @@ try: _have_http2 = True except Exception: pass + + import httpcore + import httpcore.backends.base + import httpcore.backends.sync + + class _NetworkBackend(httpcore.backends.base.NetworkBackend): + def __init__(self, resolver, local_port, bootstrap_address, family): + super().__init__() + self._local_port = local_port + self._resolver = resolver + self._bootstrap_address = bootstrap_address + self._family = family + + def connect_tcp(self, host, port, timeout, local_address): + addresses = [] + now, expiration = _compute_times(timeout) + if dns.inet.is_address(host): + addresses.append(host) + elif self._bootstrap_address is not None: + addresses.append(self._bootstrap_address) + else: + timeout = _remaining(expiration) + family = self._family + if local_address: + family = dns.inet.af_for_address(local_address) + answers = self._resolver.resolve_name( + host, family=family, lifetime=timeout + ) + addresses = answers.addresses() + for address in addresses: + af = dns.inet.af_for_address(address) + if local_address is not None or self._local_port != 0: + source = dns.inet.low_level_address_tuple( + (local_address, self._local_port), af + ) + else: + source = None + sock = _make_socket(af, socket.SOCK_STREAM, source) + attempt_expiration = _expiration_for_this_attempt(2.0, expiration) + try: + _connect( + sock, + dns.inet.low_level_address_tuple((address, port), af), + attempt_expiration, + ) + return httpcore.backends.sync.SyncStream(sock) + except Exception: + pass + raise httpcore.ConnectError + + class _HTTPTransport(httpx.HTTPTransport): + def __init__( + self, + *args, + local_port=0, + bootstrap_address=None, + resolver=None, + family=socket.AF_UNSPEC, + **kwargs, + ): + if resolver is None: + import dns.resolver + + resolver = dns.resolver.Resolver() + super().__init__(*args, **kwargs) + self._pool._network_backend = _NetworkBackend( + resolver, local_port, bootstrap_address, family + ) + except ImportError: # pragma: no cover - pass -have_doh = _have_requests or _have_httpx + class _HTTPTransport: # type: ignore + def connect_tcp(self, host, port, timeout, local_address): + raise NotImplementedError + + +have_doh = _have_httpx try: import ssl @@ -240,11 +320,10 @@ def _destination_and_source( # Caller has specified a source_port but not an address, so we # need to return a source, and we need to use the appropriate # wildcard address as the address. - if af == socket.AF_INET: - source = "0.0.0.0" - elif af == socket.AF_INET6: - source = "::" - else: + try: + source = dns.inet.any_for_af(af) + except Exception: + # we catch this and raise ValueError for backwards compatibility raise ValueError("source_port specified but address family is unknown") # Convert high-level (address, port) tuples into low-level address # tuples. @@ -289,6 +368,8 @@ def https( post: bool = True, bootstrap_address: Optional[str] = None, verify: Union[bool, str] = True, + resolver: Optional["dns.resolver.Resolver"] = None, + family: Optional[int] = socket.AF_UNSPEC, ) -> dns.message.Message: """Return the response obtained after sending a query via DNS-over-HTTPS. @@ -314,91 +395,76 @@ def https( *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the received message. - *session*, an ``httpx.Client`` or ``requests.session.Session``. If provided, the - client/session to use to send the queries. + *session*, an ``httpx.Client``. If provided, the client session to use to send the + queries. *path*, a ``str``. If *where* is an IP address, then *path* will be used to construct the URL to send the DNS query to. *post*, a ``bool``. If ``True``, the default, POST method will be used. - *bootstrap_address*, a ``str``, the IP address to use to bypass the system's DNS - resolver. + *bootstrap_address*, a ``str``, the IP address to use to bypass resolution. *verify*, a ``bool`` or ``str``. If a ``True``, then TLS certificate verification of the server is done using the default CA bundle; if ``False``, then no verification is done; if a `str` then it specifies the path to a certificate file or directory which will be used for verification. + *resolver*, a ``dns.resolver.Resolver`` or ``None``, the resolver to use for + resolution of hostnames in URLs. If not specified, a new resolver with a default + configuration will be used; note this is *not* the default resolver as that resolver + might have been configured to use DoH causing a chicken-and-egg problem. This + parameter only has an effect if the HTTP library is httpx. + + *family*, an ``int``, the address family. If socket.AF_UNSPEC (the default), both A + and AAAA records will be retrieved. + Returns a ``dns.message.Message``. """ if not have_doh: - raise NoDOH("Neither httpx nor requests is available.") # pragma: no cover - - _httpx_ok = _have_httpx + raise NoDOH("DNS-over-HTTPS is not available.") # pragma: no cover + if session and not isinstance(session, httpx.Client): + raise ValueError("session parameter must be an httpx.Client") wire = q.to_wire() - (af, _, source) = _destination_and_source(where, port, source, source_port, False) - transport_adapter = None + (af, _, the_source) = _destination_and_source( + where, port, source, source_port, False + ) transport = None headers = {"accept": "application/dns-message"} - if af is not None: + if af is not None and dns.inet.is_address(where): if af == socket.AF_INET: url = "https://{}:{}{}".format(where, port, path) elif af == socket.AF_INET6: url = "https://[{}]:{}{}".format(where, port, path) - elif bootstrap_address is not None: - _httpx_ok = False - split_url = urllib.parse.urlsplit(where) - if split_url.hostname is None: - raise ValueError("DoH URL has no hostname") - headers["Host"] = split_url.hostname - url = where.replace(split_url.hostname, bootstrap_address) - if _have_requests: - transport_adapter = HostHeaderSSLAdapter() else: url = where - if source is not None: - # set source port and source address - if _have_httpx: - if source_port == 0: - transport = httpx.HTTPTransport(local_address=source[0], verify=verify) - else: - _httpx_ok = False - if _have_requests: - transport_adapter = SourceAddressAdapter(source) - if session: - if _have_httpx: - _is_httpx = isinstance(session, httpx.Client) - else: - _is_httpx = False - if _is_httpx and not _httpx_ok: - raise NoDOH( - "Session is httpx, but httpx cannot be used for " - "the requested operation." - ) - else: - _is_httpx = _httpx_ok + # set source port and source address - if not _httpx_ok and not _have_requests: - raise NoDOH( - "Cannot use httpx for this operation, and requests is not available." - ) + if the_source is None: + local_address = None + local_port = 0 + else: + local_address = the_source[0] + local_port = the_source[1] + transport = _HTTPTransport( + local_address=local_address, + verify=verify, + local_port=local_port, + bootstrap_address=bootstrap_address, + resolver=resolver, + family=family, + ) if session: cm: contextlib.AbstractContextManager = contextlib.nullcontext(session) - elif _is_httpx: + else: cm = httpx.Client( http1=True, http2=_have_http2, verify=verify, transport=transport ) - else: - cm = requests.sessions.Session() with cm as session: - if transport_adapter and not _is_httpx: - session.mount(url, transport_adapter) - # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH # GET and POST examples if post: @@ -408,29 +474,13 @@ def https( "content-length": str(len(wire)), } ) - if _is_httpx: - response = session.post( - url, headers=headers, content=wire, timeout=timeout - ) - else: - response = session.post( - url, headers=headers, data=wire, timeout=timeout, verify=verify - ) + response = session.post(url, headers=headers, content=wire, timeout=timeout) else: wire = base64.urlsafe_b64encode(wire).rstrip(b"=") - if _is_httpx: - twire = wire.decode() # httpx does a repr() if we give it bytes - response = session.get( - url, headers=headers, timeout=timeout, params={"dns": twire} - ) - else: - response = session.get( - url, - headers=headers, - timeout=timeout, - verify=verify, - params={"dns": wire}, - ) + twire = wire.decode() # httpx does a repr() if we give it bytes + response = session.get( + url, headers=headers, timeout=timeout, params={"dns": twire} + ) # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH # status codes diff --git a/doc/installation.rst b/doc/installation.rst index 81dcdb04..35b46ae2 100644 --- a/doc/installation.rst +++ b/doc/installation.rst @@ -45,10 +45,11 @@ Optional Modules The following modules are optional, but recommended for full functionality. -If ``requests`` and ``requests-toolbelt`` are installed, then DNS-over-HTTPS -will be available. +If ``httpx`` is installed, then DNS-over-HTTPS will be available. If ``cryptography`` is installed, then dnspython will be -able to do low-level DNSSEC RSA, DSA, ECDSA and EdDSA signature validation. +able to do low-level DNSSEC signature generation and validation. If ``idna`` is installed, then IDNA 2008 will be available. + +If ``aioquic`` is installed, the DNS-over-QUIC will be available. diff --git a/doc/whatsnew.rst b/doc/whatsnew.rst index 07b800a8..3b5fc32f 100644 --- a/doc/whatsnew.rst +++ b/doc/whatsnew.rst @@ -12,6 +12,11 @@ What's New in dnspython an IPv4, IPv6, or HTTPS URL as a nameserver, instances of ``dns.nameserver.Nameserver`` are now permitted. +* The DNS-over-HTTPS bootstrap address no longer causes URL rewriting. + +* DNS-over-HTTPS now only uses httpx; support for requests has been dropped. A source + port may now be supplied when using httpx. + 2.3.0 ----- diff --git a/examples/doh-json.py b/examples/doh-json.py index e9fa0876..c8d830ba 100755 --- a/examples/doh-json.py +++ b/examples/doh-json.py @@ -2,7 +2,7 @@ import copy import json -import requests +import httpx import dns.flags import dns.message @@ -92,7 +92,7 @@ def from_doh_simple(simple, add_qr=False): a = dns.resolver.resolve("www.dnspython.org", "a") p = to_doh_simple(a.response) print(json.dumps(p, indent=4)) -response = requests.get( +response = httpx.get( "https://dns.google/resolve?", verify=True, params={"name": "www.dnspython.org", "type": 1}, diff --git a/examples/doh.py b/examples/doh.py index 17787ed3..2fd44ff3 100755 --- a/examples/doh.py +++ b/examples/doh.py @@ -1,11 +1,7 @@ #!/usr/bin/env python3 # # This is an example of sending DNS queries over HTTPS (DoH) with dnspython. -# Requires use of the requests module's Session object. -# -# See https://2.python-requests.org/en/latest/user/advanced/#session-objects -# for more details about Session objects -import requests +import httpx import dns.message import dns.query @@ -13,31 +9,16 @@ import dns.rdatatype def main(): - where = "1.1.1.1" + where = "https://dns.google/dns-query" qname = "example.com." - # one method is to use context manager, session will automatically close - with requests.sessions.Session() as session: + with httpx.Client() as client: q = dns.message.make_query(qname, dns.rdatatype.A) - r = dns.query.https(q, where, session=session) + r = dns.query.https(q, where, session=client) for answer in r.answer: print(answer) # ... do more lookups - where = "https://dns.google/dns-query" - qname = "example.net." - # second method, close session manually - session = requests.sessions.Session() - q = dns.message.make_query(qname, dns.rdatatype.A) - r = dns.query.https(q, where, session=session) - for answer in r.answer: - print(answer) - - # ... do more lookups - - # close the session when you're done - session.close() - if __name__ == "__main__": main() diff --git a/pyproject.toml b/pyproject.toml index 7617862e..8f785b56 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,8 +41,6 @@ documentation = "https://dnspython.readthedocs.io/en/stable/" python = "^3.7" httpx = {version=">=0.21.1", optional=true, python=">=3.6.2"} h2 = {version=">=4.1.0", optional=true, python=">=3.6.2"} -requests-toolbelt = {version=">=0.9.1,<0.11.0", optional=true} -requests = {version="^2.23.0", optional=true} idna = {version=">=2.1,<4.0", optional=true} cryptography = {version=">=2.6,<40.0", optional=true} trio = {version=">=0.14,<0.23", optional=true} @@ -63,7 +61,7 @@ mypy = ">=1.0.1" black = "^23.1.0" [tool.poetry.extras] -doh = ['httpx', 'h2', 'requests', 'requests-toolbelt'] +doh = ['httpx', 'h2'] idna = ['idna'] dnssec = ['cryptography'] trio = ['trio'] diff --git a/setup.cfg b/setup.cfg index 52325276..4a27fbf0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -50,7 +50,7 @@ test_suite = tests setup_requires = setuptools>=44; setuptools_scm[toml]>=3.4.3 [options.extras_require] -DOH = httpx>=0.21.1; h2>=4.1.0; requests; requests-toolbelt +DOH = httpx>=0.21.1; h2>=4.1.0 IDNA = idna>=2.1 DNSSEC = cryptography>=2.6 trio = trio>=0.14.0 diff --git a/tests/test_async.py b/tests/test_async.py index 62f7fc5a..5ae8854b 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -54,9 +54,18 @@ except Exception: pass query_addresses = [] +family = socket.AF_UNSPEC if tests.util.have_ipv4(): query_addresses.append("8.8.8.8") + family = socket.AF_INET if tests.util.have_ipv6(): + have_v6 = True + if family == socket.AF_INET: + # we have both working, go back to UNSPEC + family = socket.AF_UNSPEC + else: + # v6 only + family = socket.AF_INET6 query_addresses.append("2001:4860:4860::8888") KNOWN_ANYCAST_DOH_RESOLVER_URLS = [ @@ -503,7 +512,9 @@ class AsyncTests(unittest.TestCase): async def run(): nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS) q = dns.message.make_query("example.com.", dns.rdatatype.A) - r = await dns.asyncquery.https(q, nameserver_url, post=False, timeout=4) + r = await dns.asyncquery.https( + q, nameserver_url, post=False, timeout=4, family=family + ) self.assertTrue(q.is_response(r)) self.async_run(run) @@ -516,7 +527,9 @@ class AsyncTests(unittest.TestCase): dns.query._have_http2 = False nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS) q = dns.message.make_query("example.com.", dns.rdatatype.A) - r = await dns.asyncquery.https(q, nameserver_url, post=False, timeout=4) + r = await dns.asyncquery.https( + q, nameserver_url, post=False, timeout=4, family=family + ) self.assertTrue(q.is_response(r)) finally: dns.query._have_http2 = saved_have_http2 @@ -528,7 +541,9 @@ class AsyncTests(unittest.TestCase): async def run(): nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS) q = dns.message.make_query("example.com.", dns.rdatatype.A) - r = await dns.asyncquery.https(q, nameserver_url, post=True, timeout=4) + r = await dns.asyncquery.https( + q, nameserver_url, post=True, timeout=4, family=family + ) self.assertTrue(q.is_response(r)) self.async_run(run) diff --git a/tests/test_doh.py b/tests/test_doh.py index 3626bf37..f43b1c75 100644 --- a/tests/test_doh.py +++ b/tests/test_doh.py @@ -31,10 +31,6 @@ import dns.query import dns.rdatatype import dns.resolver -if dns.query._have_requests: - import requests - from requests.exceptions import SSLError - if dns.query._have_httpx: import httpx @@ -42,18 +38,26 @@ import tests.util resolver_v4_addresses = [] resolver_v6_addresses = [] +family = socket.AF_UNSPEC if tests.util.have_ipv4(): resolver_v4_addresses = [ "1.1.1.1", "8.8.8.8", # '9.9.9.9', ] + family = socket.AF_INET if tests.util.have_ipv6(): resolver_v6_addresses = [ "2606:4700:4700::1111", "2001:4860:4860::8888", # '2620:fe::fe', ] + if family == socket.AF_INET: + # we have both working, go back to UNSPEC + family = socket.AF_UNSPEC + else: + # v6 only + family = socket.AF_INET6 KNOWN_ANYCAST_DOH_RESOLVER_URLS = [ "https://cloudflare-dns.com/dns-query", @@ -67,86 +71,6 @@ KNOWN_PAD_AWARE_DOH_RESOLVER_URLS = [ ] -@unittest.skipUnless( - dns.query._have_requests and tests.util.is_internet_reachable(), - "Python requests cannot be imported; no DNS over HTTPS (DOH)", -) -class DNSOverHTTPSTestCaseRequests(unittest.TestCase): - def setUp(self): - self.session = requests.sessions.Session() - - def tearDown(self): - self.session.close() - - def test_get_request(self): - nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS) - q = dns.message.make_query("example.com.", dns.rdatatype.A) - r = dns.query.https( - q, nameserver_url, session=self.session, post=False, timeout=4 - ) - self.assertTrue(q.is_response(r)) - - def test_post_request(self): - nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS) - q = dns.message.make_query("example.com.", dns.rdatatype.A) - r = dns.query.https( - q, nameserver_url, session=self.session, post=True, timeout=4 - ) - self.assertTrue(q.is_response(r)) - - def test_build_url_from_ip(self): - self.assertTrue(resolver_v4_addresses or resolver_v6_addresses) - if resolver_v4_addresses: - nameserver_ip = random.choice(resolver_v4_addresses) - q = dns.message.make_query("example.com.", dns.rdatatype.A) - # For some reason Google's DNS over HTTPS fails when you POST to - # https://8.8.8.8/dns-query - # So we're just going to do GET requests here - r = dns.query.https( - q, nameserver_ip, session=self.session, post=False, timeout=4 - ) - - self.assertTrue(q.is_response(r)) - if resolver_v6_addresses: - nameserver_ip = random.choice(resolver_v6_addresses) - q = dns.message.make_query("example.com.", dns.rdatatype.A) - r = dns.query.https( - q, nameserver_ip, session=self.session, post=False, timeout=4 - ) - self.assertTrue(q.is_response(r)) - - def test_bootstrap_address(self): - # We test this to see if v4 is available - if resolver_v4_addresses: - ip = "185.228.168.168" - invalid_tls_url = "https://{}/doh/family-filter/".format(ip) - valid_tls_url = "https://doh.cleanbrowsing.org/doh/family-filter/" - q = dns.message.make_query("example.com.", dns.rdatatype.A) - # make sure CleanBrowsing's IP address will fail TLS certificate - # check - with self.assertRaises(SSLError): - dns.query.https(q, invalid_tls_url, session=self.session, timeout=4) - # use host header - r = dns.query.https( - q, valid_tls_url, session=self.session, bootstrap_address=ip, timeout=4 - ) - self.assertTrue(q.is_response(r)) - - def test_new_session(self): - nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS) - q = dns.message.make_query("example.com.", dns.rdatatype.A) - r = dns.query.https(q, nameserver_url, timeout=4) - self.assertTrue(q.is_response(r)) - - def test_resolver(self): - res = dns.resolver.Resolver(configure=False) - res.nameservers = ["https://dns.google/dns-query"] - answer = res.resolve("dns.google", "A") - seen = set([rdata.address for rdata in answer]) - self.assertTrue("8.8.8.8" in seen) - self.assertTrue("8.8.4.4" in seen) - - @unittest.skipUnless( dns.query._have_httpx and tests.util.is_internet_reachable() and _have_ssl, "Python httpx cannot be imported; no DNS over HTTPS (DOH)", @@ -162,7 +86,12 @@ class DNSOverHTTPSTestCaseHttpx(unittest.TestCase): nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS) q = dns.message.make_query("example.com.", dns.rdatatype.A) r = dns.query.https( - q, nameserver_url, session=self.session, post=False, timeout=4 + q, + nameserver_url, + session=self.session, + post=False, + timeout=4, + family=family, ) self.assertTrue(q.is_response(r)) @@ -173,7 +102,12 @@ class DNSOverHTTPSTestCaseHttpx(unittest.TestCase): nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS) q = dns.message.make_query("example.com.", dns.rdatatype.A) r = dns.query.https( - q, nameserver_url, session=self.session, post=False, timeout=4 + q, + nameserver_url, + session=self.session, + post=False, + timeout=4, + family=family, ) self.assertTrue(q.is_response(r)) finally: @@ -183,7 +117,12 @@ class DNSOverHTTPSTestCaseHttpx(unittest.TestCase): nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS) q = dns.message.make_query("example.com.", dns.rdatatype.A) r = dns.query.https( - q, nameserver_url, session=self.session, post=True, timeout=4 + q, + nameserver_url, + session=self.session, + post=True, + timeout=4, + family=family, ) self.assertTrue(q.is_response(r)) @@ -219,17 +158,15 @@ class DNSOverHTTPSTestCaseHttpx(unittest.TestCase): # check. with self.assertRaises(httpx.ConnectError): dns.query.https(q, invalid_tls_url, session=self.session, timeout=4) - # We can't do the Host header and SNI magic with httpx, but - # we are demanding httpx be used by providing a session, so - # we should get a NoDOH exception. - with self.assertRaises(dns.query.NoDOH): - dns.query.https( - q, - valid_tls_url, - session=self.session, - bootstrap_address=ip, - timeout=4, - ) + # And if we don't mangle the URL, it should work. + r = dns.query.https( + q, + valid_tls_url, + session=self.session, + bootstrap_address=ip, + timeout=4, + ) + self.assertTrue(q.is_response(r)) def test_new_session(self): nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS)