]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Better DNS-over-HTTPS support. (#908)
authorBob Halley <halley@dnspython.org>
Sun, 19 Mar 2023 16:24:32 +0000 (09:24 -0700)
committerGitHub <noreply@github.com>
Sun, 19 Mar 2023 16:24:32 +0000 (09:24 -0700)
This change:

Allows resolution hostnames in URLs using dnspython's resolver
or via a bootstrap address, without rewriting URLs.

Adds full support for source addresses and ports to
httpx, except for asyncio I/O where only the source address
can be specified.

Removes support for requests.

15 files changed:
.github/workflows/codeql-analysis.yml
dns/_asyncbackend.py
dns/_asyncio_backend.py
dns/_trio_backend.py
dns/asyncquery.py
dns/inet.py
dns/query.py
doc/installation.rst
doc/whatsnew.rst
examples/doh-json.py
examples/doh.py
pyproject.toml
setup.cfg
tests/test_async.py
tests/test_doh.py

index 84e4ecc461943a9e77d9f3a2c32d3fd023c3cace..3f887fea1854214cf81c3cb719e967e2d9b0227c 100644 (file)
@@ -60,7 +60,7 @@ jobs:
         sudo apt install -y gnome-keyring
         python -m pip install --upgrade pip
         python -m pip install poetry
-        poetry install -E dnssec -E doh -E idna -E trio
+        poetry install -E dnssec -E doh -E idna -E trio -E curio
 
     - name: Perform CodeQL Analysis
       uses: github/codeql-action/analyze@v2
index ff24604f8d2e4acb5ba0f0a048988f1a52343adc..7fd4926b9b2ce9fe913950ea33bc974875f20e7c 100644 (file)
@@ -61,6 +61,11 @@ class StreamSocket(Socket):  # pragma: no cover
         raise NotImplementedError
 
 
+class NullTransport:
+    async def connect_tcp(self, host, port, timeout, local_address):
+        raise NotImplementedError
+
+
 class Backend:  # pragma: no cover
     def name(self):
         return "unknown"
@@ -83,3 +88,6 @@ class Backend:  # pragma: no cover
 
     async def sleep(self, interval):
         raise NotImplementedError
+
+    def get_transport_class(self):
+        raise NotImplementedError
index 82a06249733567c3985dbc403ce2cb9f46bc2e08..98971be9e9fe584a0fb815d61f1700fe97bc8297 100644 (file)
@@ -113,6 +113,82 @@ class StreamSocket(dns._asyncbackend.StreamSocket):
         return self.writer.get_extra_info("sockname")
 
 
+try:
+    import anyio
+    import httpx
+
+    import httpcore
+    import httpcore.backends.base
+    import httpcore.backends.asyncio
+
+    from dns.query import _compute_times, _remaining, _expiration_for_this_attempt
+
+    class _NetworkBackend(httpcore.backends.base.AsyncNetworkBackend):
+        def __init__(self, resolver, local_port, bootstrap_address, family):
+            super().__init__()
+            self._local_port = local_port
+            self._resolver = resolver
+            self._bootstrap_address = bootstrap_address
+            self._family = family
+            if local_port != 0:
+                raise NotImplementedError(
+                    "the asyncio transport for HTTPX cannot set the local port"
+                )
+
+        async def connect_tcp(self, host, port, timeout, local_address):
+            addresses = []
+            now, expiration = _compute_times(timeout)
+            if dns.inet.is_address(host):
+                addresses.append(host)
+            elif self._bootstrap_address is not None:
+                addresses.append(self._bootstrap_address)
+            else:
+                timeout = _remaining(expiration)
+                family = self._family
+                if local_address:
+                    family = dns.inet.af_for_address(local_address)
+                answers = await self._resolver.resolve_name(
+                    host, family=family, lifetime=timeout
+                )
+                addresses = answers.addresses()
+            for address in addresses:
+                try:
+                    attempt_expiration = _expiration_for_this_attempt(2.0, expiration)
+                    timeout = _remaining(attempt_expiration)
+                    with anyio.fail_after(timeout):
+                        stream = await anyio.connect_tcp(
+                            remote_host=host,
+                            remote_port=port,
+                            local_host=local_address,
+                        )
+                    return httpcore.backends.asyncio.AsyncIOStream(stream)
+                except Exception:
+                    pass
+            raise httpcore.ConnectError
+
+    class _HTTPTransport(httpx.AsyncHTTPTransport):
+        def __init__(
+            self,
+            *args,
+            local_port=0,
+            bootstrap_address=None,
+            resolver=None,
+            family=socket.AF_UNSPEC,
+            **kwargs,
+        ):
+            if resolver is None:
+                import dns.asyncresolver
+
+                resolver = dns.asyncresolver.Resolver()
+            super().__init__(*args, **kwargs)
+            self._pool._network_backend = _NetworkBackend(
+                resolver, local_port, bootstrap_address, family
+            )
+
+except ImportError:
+    _HTTPTransport = dns._asyncbackend.NullTransport  # type: ignore
+
+
 class Backend(dns._asyncbackend.Backend):
     def name(self):
         return "asyncio"
@@ -171,3 +247,6 @@ class Backend(dns._asyncbackend.Backend):
 
     def datagram_connection_required(self):
         return _is_win32
+
+    def get_transport_class(self):
+        return _HTTPTransport
index b0c021033779a1f4d939e4d025bfd9e6cb2a5727..08101f9abb240ca18e5878fb1a80053170443220 100644 (file)
@@ -83,6 +83,80 @@ class StreamSocket(dns._asyncbackend.StreamSocket):
             return self.stream.socket.getsockname()
 
 
+try:
+    import httpx
+
+    import httpcore
+    import httpcore.backends.base
+    import httpcore.backends.trio
+
+    from dns.query import _compute_times, _remaining, _expiration_for_this_attempt
+
+    class _NetworkBackend(httpcore.backends.base.AsyncNetworkBackend):
+        def __init__(self, resolver, local_port, bootstrap_address, family):
+            super().__init__()
+            self._local_port = local_port
+            self._resolver = resolver
+            self._bootstrap_address = bootstrap_address
+            self._family = family
+
+        async def connect_tcp(self, host, port, timeout, local_address):
+            addresses = []
+            now, expiration = _compute_times(timeout)
+            if dns.inet.is_address(host):
+                addresses.append(host)
+            elif self._bootstrap_address is not None:
+                addresses.append(self._bootstrap_address)
+            else:
+                timeout = _remaining(expiration)
+                family = self._family
+                if local_address:
+                    family = dns.inet.af_for_address(local_address)
+                answers = await self._resolver.resolve_name(
+                    host, family=family, lifetime=timeout
+                )
+                addresses = answers.addresses()
+            for address in addresses:
+                try:
+                    af = dns.inet.af_for_address(address)
+                    if local_address is not None or self._local_port != 0:
+                        source = (local_address, self._local_port)
+                    else:
+                        source = None
+                    destination = (address, port)
+                    attempt_expiration = _expiration_for_this_attempt(2.0, expiration)
+                    timeout = _remaining(attempt_expiration)
+                    sock = await Backend().make_socket(
+                        af, socket.SOCK_STREAM, 0, source, destination, timeout
+                    )
+                    return httpcore.backends.trio.TrioStream(sock.stream)
+                except Exception:
+                    continue
+            raise httpcore.ConnectError
+
+    class _HTTPTransport(httpx.AsyncHTTPTransport):
+        def __init__(
+            self,
+            *args,
+            local_port=0,
+            bootstrap_address=None,
+            resolver=None,
+            family=socket.AF_UNSPEC,
+            **kwargs,
+        ):
+            if resolver is None:
+                import dns.asyncresolver
+
+                resolver = dns.asyncresolver.Resolver()
+            super().__init__(*args, **kwargs)
+            self._pool._network_backend = _NetworkBackend(
+                resolver, local_port, bootstrap_address, family
+            )
+
+except ImportError:
+    _HTTPTransport = dns._asyncbackend.NullTransport  # type: ignore
+
+
 class Backend(dns._asyncbackend.Backend):
     def name(self):
         return "trio"
@@ -104,8 +178,14 @@ class Backend(dns._asyncbackend.Backend):
             if source:
                 await s.bind(_lltuple(source, af))
             if socktype == socket.SOCK_STREAM:
+                connected = False
                 with _maybe_timeout(timeout):
                     await s.connect(_lltuple(destination, af))
+                    connected = True
+                if not connected:
+                    raise dns.exception.Timeout(
+                        timeout=timeout
+                    )  # lgtm[py/unreachable-statement]
         except Exception:  # pragma: no cover
             s.close()
             raise
@@ -130,3 +210,6 @@ class Backend(dns._asyncbackend.Backend):
 
     async def sleep(self, interval):
         await trio.sleep(interval)
+
+    def get_transport_class(self):
+        return _HTTPTransport
index 459c611d27dc47363ce4005582e73fb1a0cfd991..ea5391165176e4a95d115bcc45ea4f7b9ec15355 100644 (file)
@@ -43,13 +43,13 @@ from dns.query import (
     BadResponse,
     ssl,
     UDPMode,
-    _have_httpx,
+    have_doh,
     _have_http2,
     NoDOH,
     NoDOQ,
 )
 
-if _have_httpx:
+if have_doh:
     import httpx
 
 # for brevity
@@ -495,6 +495,9 @@ async def https(
     path: str = "/dns-query",
     post: bool = True,
     verify: Union[bool, str] = True,
+    bootstrap_address: Optional[str] = None,
+    resolver: Optional["dns.asyncresolver.Resolver"] = None,
+    family: Optional[int] = socket.AF_UNSPEC,
 ) -> dns.message.Message:
     """Return the response obtained after sending a query via DNS-over-HTTPS.
 
@@ -508,8 +511,10 @@ async def https(
     parameters, exceptions, and return type of this method.
     """
 
-    if not _have_httpx:
+    if not have_doh:
         raise NoDOH("httpx is not available.")  # pragma: no cover
+    if client and not isinstance(client, httpx.AsyncClient):
+        raise ValueError("session parameter must be an httpx.AsyncClient")
 
     wire = q.to_wire()
     try:
@@ -518,15 +523,30 @@ async def https(
         af = None
     transport = None
     headers = {"accept": "application/dns-message"}
-    if af is not None:
+    if af is not None and dns.inet.is_address(where):
         if af == socket.AF_INET:
             url = "https://{}:{}{}".format(where, port, path)
         elif af == socket.AF_INET6:
             url = "https://[{}]:{}{}".format(where, port, path)
     else:
         url = where
-    if source is not None:
-        transport = httpx.AsyncHTTPTransport(local_address=source[0])
+
+    backend = dns.asyncbackend.get_default_backend()
+
+    if source is None:
+        local_address = None
+        local_port = 0
+    else:
+        local_address = source
+        local_port = source_port
+    transport = backend.get_transport_class()(
+        local_address=local_address,
+        verify=verify,
+        local_port=local_port,
+        bootstrap_address=bootstrap_address,
+        resolver=resolver,
+        family=family,
+    )
 
     if client:
         cm: contextlib.AbstractAsyncContextManager = NullContext(client)
index 11180c969117c65898434fffcdf828b658c3d689..23a4a86e2c3bd8da80b3e55c4ef8276cb6c94d6a 100644 (file)
@@ -171,3 +171,12 @@ def low_level_address_tuple(
             return tup
     else:
         raise NotImplementedError(f"unknown address family {af}")
+
+
+def any_for_af(af):
+    """Return the 'any' address for the specified address family."""
+    if af == socket.AF_INET:
+        return "0.0.0.0"
+    elif af == socket.AF_INET6:
+        return "::"
+    raise NotImplementedError(f"unknown address family {af}")
index b4cd69f75446f322b25d3e0628cb1f33b2540522..517bab02eaf4c0954f9a26cb96f0941b54887feb 100644 (file)
@@ -43,14 +43,21 @@ import dns.transaction
 import dns.tsig
 import dns.xfr
 
-try:
-    import requests
-    from requests_toolbelt.adapters.source import SourceAddressAdapter
-    from requests_toolbelt.adapters.host_header_ssl import HostHeaderSSLAdapter
 
-    _have_requests = True
-except ImportError:  # pragma: no cover
-    _have_requests = False
+def _remaining(expiration):
+    if expiration is None:
+        return None
+    timeout = expiration - time.time()
+    if timeout <= 0.0:
+        raise dns.exception.Timeout
+    return timeout
+
+
+def _expiration_for_this_attempt(timeout, expiration):
+    if expiration is None:
+        return None
+    return min(time.time() + timeout, expiration)
+
 
 _have_httpx = False
 _have_http2 = False
@@ -64,10 +71,83 @@ try:
             _have_http2 = True
     except Exception:
         pass
+
+    import httpcore
+    import httpcore.backends.base
+    import httpcore.backends.sync
+
+    class _NetworkBackend(httpcore.backends.base.NetworkBackend):
+        def __init__(self, resolver, local_port, bootstrap_address, family):
+            super().__init__()
+            self._local_port = local_port
+            self._resolver = resolver
+            self._bootstrap_address = bootstrap_address
+            self._family = family
+
+        def connect_tcp(self, host, port, timeout, local_address):
+            addresses = []
+            now, expiration = _compute_times(timeout)
+            if dns.inet.is_address(host):
+                addresses.append(host)
+            elif self._bootstrap_address is not None:
+                addresses.append(self._bootstrap_address)
+            else:
+                timeout = _remaining(expiration)
+                family = self._family
+                if local_address:
+                    family = dns.inet.af_for_address(local_address)
+                answers = self._resolver.resolve_name(
+                    host, family=family, lifetime=timeout
+                )
+                addresses = answers.addresses()
+            for address in addresses:
+                af = dns.inet.af_for_address(address)
+                if local_address is not None or self._local_port != 0:
+                    source = dns.inet.low_level_address_tuple(
+                        (local_address, self._local_port), af
+                    )
+                else:
+                    source = None
+                sock = _make_socket(af, socket.SOCK_STREAM, source)
+                attempt_expiration = _expiration_for_this_attempt(2.0, expiration)
+                try:
+                    _connect(
+                        sock,
+                        dns.inet.low_level_address_tuple((address, port), af),
+                        attempt_expiration,
+                    )
+                    return httpcore.backends.sync.SyncStream(sock)
+                except Exception:
+                    pass
+            raise httpcore.ConnectError
+
+    class _HTTPTransport(httpx.HTTPTransport):
+        def __init__(
+            self,
+            *args,
+            local_port=0,
+            bootstrap_address=None,
+            resolver=None,
+            family=socket.AF_UNSPEC,
+            **kwargs,
+        ):
+            if resolver is None:
+                import dns.resolver
+
+                resolver = dns.resolver.Resolver()
+            super().__init__(*args, **kwargs)
+            self._pool._network_backend = _NetworkBackend(
+                resolver, local_port, bootstrap_address, family
+            )
+
 except ImportError:  # pragma: no cover
-    pass
 
-have_doh = _have_requests or _have_httpx
+    class _HTTPTransport:  # type: ignore
+        def connect_tcp(self, host, port, timeout, local_address):
+            raise NotImplementedError
+
+
+have_doh = _have_httpx
 
 try:
     import ssl
@@ -240,11 +320,10 @@ def _destination_and_source(
         # Caller has specified a source_port but not an address, so we
         # need to return a source, and we need to use the appropriate
         # wildcard address as the address.
-        if af == socket.AF_INET:
-            source = "0.0.0.0"
-        elif af == socket.AF_INET6:
-            source = "::"
-        else:
+        try:
+            source = dns.inet.any_for_af(af)
+        except Exception:
+            # we catch this and raise ValueError for backwards compatibility
             raise ValueError("source_port specified but address family is unknown")
     # Convert high-level (address, port) tuples into low-level address
     # tuples.
@@ -289,6 +368,8 @@ def https(
     post: bool = True,
     bootstrap_address: Optional[str] = None,
     verify: Union[bool, str] = True,
+    resolver: Optional["dns.resolver.Resolver"] = None,
+    family: Optional[int] = socket.AF_UNSPEC,
 ) -> dns.message.Message:
     """Return the response obtained after sending a query via DNS-over-HTTPS.
 
@@ -314,91 +395,76 @@ def https(
     *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the
     received message.
 
-    *session*, an ``httpx.Client`` or ``requests.session.Session``.  If provided, the
-    client/session to use to send the queries.
+    *session*, an ``httpx.Client``.  If provided, the client session to use to send the
+    queries.
 
     *path*, a ``str``. If *where* is an IP address, then *path* will be used to
     construct the URL to send the DNS query to.
 
     *post*, a ``bool``. If ``True``, the default, POST method will be used.
 
-    *bootstrap_address*, a ``str``, the IP address to use to bypass the system's DNS
-    resolver.
+    *bootstrap_address*, a ``str``, the IP address to use to bypass resolution.
 
     *verify*, a ``bool`` or ``str``.  If a ``True``, then TLS certificate verification
     of the server is done using the default CA bundle; if ``False``, then no
     verification is done; if a `str` then it specifies the path to a certificate file or
     directory which will be used for verification.
 
+    *resolver*, a ``dns.resolver.Resolver`` or ``None``, the resolver to use for
+    resolution of hostnames in URLs.  If not specified, a new resolver with a default
+    configuration will be used; note this is *not* the default resolver as that resolver
+    might have been configured to use DoH causing a chicken-and-egg problem.  This
+    parameter only has an effect if the HTTP library is httpx.
+
+    *family*, an ``int``, the address family.  If socket.AF_UNSPEC (the default), both A
+    and AAAA records will be retrieved.
+
     Returns a ``dns.message.Message``.
     """
 
     if not have_doh:
-        raise NoDOH("Neither httpx nor requests is available.")  # pragma: no cover
-
-    _httpx_ok = _have_httpx
+        raise NoDOH("DNS-over-HTTPS is not available.")  # pragma: no cover
+    if session and not isinstance(session, httpx.Client):
+        raise ValueError("session parameter must be an httpx.Client")
 
     wire = q.to_wire()
-    (af, _, source) = _destination_and_source(where, port, source, source_port, False)
-    transport_adapter = None
+    (af, _, the_source) = _destination_and_source(
+        where, port, source, source_port, False
+    )
     transport = None
     headers = {"accept": "application/dns-message"}
-    if af is not None:
+    if af is not None and dns.inet.is_address(where):
         if af == socket.AF_INET:
             url = "https://{}:{}{}".format(where, port, path)
         elif af == socket.AF_INET6:
             url = "https://[{}]:{}{}".format(where, port, path)
-    elif bootstrap_address is not None:
-        _httpx_ok = False
-        split_url = urllib.parse.urlsplit(where)
-        if split_url.hostname is None:
-            raise ValueError("DoH URL has no hostname")
-        headers["Host"] = split_url.hostname
-        url = where.replace(split_url.hostname, bootstrap_address)
-        if _have_requests:
-            transport_adapter = HostHeaderSSLAdapter()
     else:
         url = where
-    if source is not None:
-        # set source port and source address
-        if _have_httpx:
-            if source_port == 0:
-                transport = httpx.HTTPTransport(local_address=source[0], verify=verify)
-            else:
-                _httpx_ok = False
-        if _have_requests:
-            transport_adapter = SourceAddressAdapter(source)
 
-    if session:
-        if _have_httpx:
-            _is_httpx = isinstance(session, httpx.Client)
-        else:
-            _is_httpx = False
-        if _is_httpx and not _httpx_ok:
-            raise NoDOH(
-                "Session is httpx, but httpx cannot be used for "
-                "the requested operation."
-            )
-    else:
-        _is_httpx = _httpx_ok
+    # set source port and source address
 
-    if not _httpx_ok and not _have_requests:
-        raise NoDOH(
-            "Cannot use httpx for this operation, and requests is not available."
-        )
+    if the_source is None:
+        local_address = None
+        local_port = 0
+    else:
+        local_address = the_source[0]
+        local_port = the_source[1]
+    transport = _HTTPTransport(
+        local_address=local_address,
+        verify=verify,
+        local_port=local_port,
+        bootstrap_address=bootstrap_address,
+        resolver=resolver,
+        family=family,
+    )
 
     if session:
         cm: contextlib.AbstractContextManager = contextlib.nullcontext(session)
-    elif _is_httpx:
+    else:
         cm = httpx.Client(
             http1=True, http2=_have_http2, verify=verify, transport=transport
         )
-    else:
-        cm = requests.sessions.Session()
     with cm as session:
-        if transport_adapter and not _is_httpx:
-            session.mount(url, transport_adapter)
-
         # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
         # GET and POST examples
         if post:
@@ -408,29 +474,13 @@ def https(
                     "content-length": str(len(wire)),
                 }
             )
-            if _is_httpx:
-                response = session.post(
-                    url, headers=headers, content=wire, timeout=timeout
-                )
-            else:
-                response = session.post(
-                    url, headers=headers, data=wire, timeout=timeout, verify=verify
-                )
+            response = session.post(url, headers=headers, content=wire, timeout=timeout)
         else:
             wire = base64.urlsafe_b64encode(wire).rstrip(b"=")
-            if _is_httpx:
-                twire = wire.decode()  # httpx does a repr() if we give it bytes
-                response = session.get(
-                    url, headers=headers, timeout=timeout, params={"dns": twire}
-                )
-            else:
-                response = session.get(
-                    url,
-                    headers=headers,
-                    timeout=timeout,
-                    verify=verify,
-                    params={"dns": wire},
-                )
+            twire = wire.decode()  # httpx does a repr() if we give it bytes
+            response = session.get(
+                url, headers=headers, timeout=timeout, params={"dns": twire}
+            )
 
     # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
     # status codes
index 81dcdb04d5b6763fbf6116ac56819bd691f75f18..35b46ae2edfbe120dffd306500653c37b5f77a2a 100644 (file)
@@ -45,10 +45,11 @@ Optional Modules
 
 The following modules are optional, but recommended for full functionality.
 
-If ``requests`` and ``requests-toolbelt`` are installed, then DNS-over-HTTPS
-will be available.
+If ``httpx`` is installed, then DNS-over-HTTPS will be available.
 
 If ``cryptography`` is installed, then dnspython will be
-able to do low-level DNSSEC RSA, DSA, ECDSA and EdDSA signature validation.
+able to do low-level DNSSEC signature generation and validation.
 
 If ``idna`` is installed, then IDNA 2008 will be available.
+
+If ``aioquic`` is installed, the DNS-over-QUIC will be available.
index 07b800a87108c67df411fd19e534797109d2fb51..3b5fc32f5b38c4582c0b70123f24bcb458f7bb78 100644 (file)
@@ -12,6 +12,11 @@ What's New in dnspython
   an IPv4, IPv6, or HTTPS URL as a nameserver, instances of ``dns.nameserver.Nameserver``
   are now permitted.
 
+* The DNS-over-HTTPS bootstrap address no longer causes URL rewriting.
+
+* DNS-over-HTTPS now only uses httpx; support for requests has been dropped.  A source
+  port may now be supplied when using httpx.
+
 2.3.0
 -----
 
index e9fa087637a9e1c29a0945bebf1b9a588cd11793..c8d830ba51410319930e6dedf55269b4fd16cd56 100755 (executable)
@@ -2,7 +2,7 @@
 
 import copy
 import json
-import requests
+import httpx
 
 import dns.flags
 import dns.message
@@ -92,7 +92,7 @@ def from_doh_simple(simple, add_qr=False):
 a = dns.resolver.resolve("www.dnspython.org", "a")
 p = to_doh_simple(a.response)
 print(json.dumps(p, indent=4))
-response = requests.get(
+response = httpx.get(
     "https://dns.google/resolve?",
     verify=True,
     params={"name": "www.dnspython.org", "type": 1},
index 17787ed38c422552700176bed272cded1d57bf8f..2fd44ff3b02c55bed1fb6d51e28ab4c564aff6a9 100755 (executable)
@@ -1,11 +1,7 @@
 #!/usr/bin/env python3
 #
 # This is an example of sending DNS queries over HTTPS (DoH) with dnspython.
-# Requires use of the requests module's Session object.
-#
-# See https://2.python-requests.org/en/latest/user/advanced/#session-objects
-# for more details about Session objects
-import requests
+import httpx
 
 import dns.message
 import dns.query
@@ -13,31 +9,16 @@ import dns.rdatatype
 
 
 def main():
-    where = "1.1.1.1"
+    where = "https://dns.google/dns-query"
     qname = "example.com."
-    # one method is to use context manager, session will automatically close
-    with requests.sessions.Session() as session:
+    with httpx.Client() as client:
         q = dns.message.make_query(qname, dns.rdatatype.A)
-        r = dns.query.https(q, where, session=session)
+        r = dns.query.https(q, where, session=client)
         for answer in r.answer:
             print(answer)
 
         # ... do more lookups
 
-    where = "https://dns.google/dns-query"
-    qname = "example.net."
-    # second method, close session manually
-    session = requests.sessions.Session()
-    q = dns.message.make_query(qname, dns.rdatatype.A)
-    r = dns.query.https(q, where, session=session)
-    for answer in r.answer:
-        print(answer)
-
-    # ... do more lookups
-
-    # close the session when you're done
-    session.close()
-
 
 if __name__ == "__main__":
     main()
index 7617862e554003ba83ae2cbd33a0562ef255824c..8f785b564da9b87f3ead61dc95124fa942e4a8c1 100644 (file)
@@ -41,8 +41,6 @@ documentation = "https://dnspython.readthedocs.io/en/stable/"
 python = "^3.7"
 httpx = {version=">=0.21.1", optional=true, python=">=3.6.2"}
 h2 = {version=">=4.1.0", optional=true, python=">=3.6.2"}
-requests-toolbelt = {version=">=0.9.1,<0.11.0", optional=true}
-requests = {version="^2.23.0", optional=true}
 idna = {version=">=2.1,<4.0", optional=true}
 cryptography = {version=">=2.6,<40.0", optional=true}
 trio = {version=">=0.14,<0.23", optional=true}
@@ -63,7 +61,7 @@ mypy = ">=1.0.1"
 black = "^23.1.0"
 
 [tool.poetry.extras]
-doh = ['httpx', 'h2', 'requests', 'requests-toolbelt']
+doh = ['httpx', 'h2']
 idna = ['idna']
 dnssec = ['cryptography']
 trio = ['trio']
index 52325276c12de6d38273512187b1d79e5d21ee16..4a27fbf068990423d09aff63f703977d55da32a1 100644 (file)
--- a/setup.cfg
+++ b/setup.cfg
@@ -50,7 +50,7 @@ test_suite = tests
 setup_requires = setuptools>=44; setuptools_scm[toml]>=3.4.3
 
 [options.extras_require]
-DOH = httpx>=0.21.1; h2>=4.1.0; requests; requests-toolbelt
+DOH = httpx>=0.21.1; h2>=4.1.0
 IDNA = idna>=2.1
 DNSSEC = cryptography>=2.6
 trio = trio>=0.14.0
index 62f7fc5a190ff2d1cc298fa3524777610cbd9dc8..5ae8854b586a003cc25444466694fd027a556b8b 100644 (file)
@@ -54,9 +54,18 @@ except Exception:
     pass
 
 query_addresses = []
+family = socket.AF_UNSPEC
 if tests.util.have_ipv4():
     query_addresses.append("8.8.8.8")
+    family = socket.AF_INET
 if tests.util.have_ipv6():
+    have_v6 = True
+    if family == socket.AF_INET:
+        # we have both working, go back to UNSPEC
+        family = socket.AF_UNSPEC
+    else:
+        # v6 only
+        family = socket.AF_INET6
     query_addresses.append("2001:4860:4860::8888")
 
 KNOWN_ANYCAST_DOH_RESOLVER_URLS = [
@@ -503,7 +512,9 @@ class AsyncTests(unittest.TestCase):
         async def run():
             nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS)
             q = dns.message.make_query("example.com.", dns.rdatatype.A)
-            r = await dns.asyncquery.https(q, nameserver_url, post=False, timeout=4)
+            r = await dns.asyncquery.https(
+                q, nameserver_url, post=False, timeout=4, family=family
+            )
             self.assertTrue(q.is_response(r))
 
         self.async_run(run)
@@ -516,7 +527,9 @@ class AsyncTests(unittest.TestCase):
                 dns.query._have_http2 = False
                 nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS)
                 q = dns.message.make_query("example.com.", dns.rdatatype.A)
-                r = await dns.asyncquery.https(q, nameserver_url, post=False, timeout=4)
+                r = await dns.asyncquery.https(
+                    q, nameserver_url, post=False, timeout=4, family=family
+                )
                 self.assertTrue(q.is_response(r))
             finally:
                 dns.query._have_http2 = saved_have_http2
@@ -528,7 +541,9 @@ class AsyncTests(unittest.TestCase):
         async def run():
             nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS)
             q = dns.message.make_query("example.com.", dns.rdatatype.A)
-            r = await dns.asyncquery.https(q, nameserver_url, post=True, timeout=4)
+            r = await dns.asyncquery.https(
+                q, nameserver_url, post=True, timeout=4, family=family
+            )
             self.assertTrue(q.is_response(r))
 
         self.async_run(run)
index 3626bf37df726db45b61577071929c73bfa64952..f43b1c759d1404c63a491282a26083bab7fe0bce 100644 (file)
@@ -31,10 +31,6 @@ import dns.query
 import dns.rdatatype
 import dns.resolver
 
-if dns.query._have_requests:
-    import requests
-    from requests.exceptions import SSLError
-
 if dns.query._have_httpx:
     import httpx
 
@@ -42,18 +38,26 @@ import tests.util
 
 resolver_v4_addresses = []
 resolver_v6_addresses = []
+family = socket.AF_UNSPEC
 if tests.util.have_ipv4():
     resolver_v4_addresses = [
         "1.1.1.1",
         "8.8.8.8",
         # '9.9.9.9',
     ]
+    family = socket.AF_INET
 if tests.util.have_ipv6():
     resolver_v6_addresses = [
         "2606:4700:4700::1111",
         "2001:4860:4860::8888",
         # '2620:fe::fe',
     ]
+    if family == socket.AF_INET:
+        # we have both working, go back to UNSPEC
+        family = socket.AF_UNSPEC
+    else:
+        # v6 only
+        family = socket.AF_INET6
 
 KNOWN_ANYCAST_DOH_RESOLVER_URLS = [
     "https://cloudflare-dns.com/dns-query",
@@ -67,86 +71,6 @@ KNOWN_PAD_AWARE_DOH_RESOLVER_URLS = [
 ]
 
 
-@unittest.skipUnless(
-    dns.query._have_requests and tests.util.is_internet_reachable(),
-    "Python requests cannot be imported; no DNS over HTTPS (DOH)",
-)
-class DNSOverHTTPSTestCaseRequests(unittest.TestCase):
-    def setUp(self):
-        self.session = requests.sessions.Session()
-
-    def tearDown(self):
-        self.session.close()
-
-    def test_get_request(self):
-        nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS)
-        q = dns.message.make_query("example.com.", dns.rdatatype.A)
-        r = dns.query.https(
-            q, nameserver_url, session=self.session, post=False, timeout=4
-        )
-        self.assertTrue(q.is_response(r))
-
-    def test_post_request(self):
-        nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS)
-        q = dns.message.make_query("example.com.", dns.rdatatype.A)
-        r = dns.query.https(
-            q, nameserver_url, session=self.session, post=True, timeout=4
-        )
-        self.assertTrue(q.is_response(r))
-
-    def test_build_url_from_ip(self):
-        self.assertTrue(resolver_v4_addresses or resolver_v6_addresses)
-        if resolver_v4_addresses:
-            nameserver_ip = random.choice(resolver_v4_addresses)
-            q = dns.message.make_query("example.com.", dns.rdatatype.A)
-            # For some reason Google's DNS over HTTPS fails when you POST to
-            # https://8.8.8.8/dns-query
-            # So we're just going to do GET requests here
-            r = dns.query.https(
-                q, nameserver_ip, session=self.session, post=False, timeout=4
-            )
-
-            self.assertTrue(q.is_response(r))
-        if resolver_v6_addresses:
-            nameserver_ip = random.choice(resolver_v6_addresses)
-            q = dns.message.make_query("example.com.", dns.rdatatype.A)
-            r = dns.query.https(
-                q, nameserver_ip, session=self.session, post=False, timeout=4
-            )
-            self.assertTrue(q.is_response(r))
-
-    def test_bootstrap_address(self):
-        # We test this to see if v4 is available
-        if resolver_v4_addresses:
-            ip = "185.228.168.168"
-            invalid_tls_url = "https://{}/doh/family-filter/".format(ip)
-            valid_tls_url = "https://doh.cleanbrowsing.org/doh/family-filter/"
-            q = dns.message.make_query("example.com.", dns.rdatatype.A)
-            # make sure CleanBrowsing's IP address will fail TLS certificate
-            # check
-            with self.assertRaises(SSLError):
-                dns.query.https(q, invalid_tls_url, session=self.session, timeout=4)
-            # use host header
-            r = dns.query.https(
-                q, valid_tls_url, session=self.session, bootstrap_address=ip, timeout=4
-            )
-            self.assertTrue(q.is_response(r))
-
-    def test_new_session(self):
-        nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS)
-        q = dns.message.make_query("example.com.", dns.rdatatype.A)
-        r = dns.query.https(q, nameserver_url, timeout=4)
-        self.assertTrue(q.is_response(r))
-
-    def test_resolver(self):
-        res = dns.resolver.Resolver(configure=False)
-        res.nameservers = ["https://dns.google/dns-query"]
-        answer = res.resolve("dns.google", "A")
-        seen = set([rdata.address for rdata in answer])
-        self.assertTrue("8.8.8.8" in seen)
-        self.assertTrue("8.8.4.4" in seen)
-
-
 @unittest.skipUnless(
     dns.query._have_httpx and tests.util.is_internet_reachable() and _have_ssl,
     "Python httpx cannot be imported; no DNS over HTTPS (DOH)",
@@ -162,7 +86,12 @@ class DNSOverHTTPSTestCaseHttpx(unittest.TestCase):
         nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS)
         q = dns.message.make_query("example.com.", dns.rdatatype.A)
         r = dns.query.https(
-            q, nameserver_url, session=self.session, post=False, timeout=4
+            q,
+            nameserver_url,
+            session=self.session,
+            post=False,
+            timeout=4,
+            family=family,
         )
         self.assertTrue(q.is_response(r))
 
@@ -173,7 +102,12 @@ class DNSOverHTTPSTestCaseHttpx(unittest.TestCase):
             nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS)
             q = dns.message.make_query("example.com.", dns.rdatatype.A)
             r = dns.query.https(
-                q, nameserver_url, session=self.session, post=False, timeout=4
+                q,
+                nameserver_url,
+                session=self.session,
+                post=False,
+                timeout=4,
+                family=family,
             )
             self.assertTrue(q.is_response(r))
         finally:
@@ -183,7 +117,12 @@ class DNSOverHTTPSTestCaseHttpx(unittest.TestCase):
         nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS)
         q = dns.message.make_query("example.com.", dns.rdatatype.A)
         r = dns.query.https(
-            q, nameserver_url, session=self.session, post=True, timeout=4
+            q,
+            nameserver_url,
+            session=self.session,
+            post=True,
+            timeout=4,
+            family=family,
         )
         self.assertTrue(q.is_response(r))
 
@@ -219,17 +158,15 @@ class DNSOverHTTPSTestCaseHttpx(unittest.TestCase):
             # check.
             with self.assertRaises(httpx.ConnectError):
                 dns.query.https(q, invalid_tls_url, session=self.session, timeout=4)
-            # We can't do the Host header and SNI magic with httpx, but
-            # we are demanding httpx be used by providing a session, so
-            # we should get a NoDOH exception.
-            with self.assertRaises(dns.query.NoDOH):
-                dns.query.https(
-                    q,
-                    valid_tls_url,
-                    session=self.session,
-                    bootstrap_address=ip,
-                    timeout=4,
-                )
+            # And if we don't mangle the URL, it should work.
+            r = dns.query.https(
+                q,
+                valid_tls_url,
+                session=self.session,
+                bootstrap_address=ip,
+                timeout=4,
+            )
+            self.assertTrue(q.is_response(r))
 
     def test_new_session(self):
         nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS)