]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Make better APIs for creating sockets and TLS contexts (#1217)
authorBob Halley <halley@dnspython.org>
Mon, 18 Aug 2025 18:11:32 +0000 (11:11 -0700)
committerGitHub <noreply@github.com>
Mon, 18 Aug 2025 18:11:32 +0000 (11:11 -0700)
* Make better APIs for creating sockets and TLS contexts for
code that wants to have persistent connections.  [#1176].

This code keeps the ability to use dnspython when the ssl module
doesn't work, but moves the helper code to another module to declutter
and make testing and type checking easier.

We still have to make some type checking compromises, but we are
making fewer than before.

dns/_no_ssl.py [new file with mode: 0644]
dns/asyncquery.py
dns/query.py
doc/async-backend.rst
doc/async.rst

diff --git a/dns/_no_ssl.py b/dns/_no_ssl.py
new file mode 100644 (file)
index 0000000..c64f090
--- /dev/null
@@ -0,0 +1,61 @@
+import enum
+from typing import Any, Optional
+
+CERT_NONE = 0
+
+
+class TLSVersion(enum.IntEnum):
+    TLSv1_2 = 12
+
+
+class WantReadException(Exception):
+    pass
+
+
+class WantWriteException(Exception):
+    pass
+
+
+class SSLWantReadError(Exception):
+    pass
+
+
+class SSLWantWriteError(Exception):
+    pass
+
+
+class SSLContext:
+    def __init__(self) -> None:
+        self.minimum_version: Any = TLSVersion.TLSv1_2
+        self.check_hostname: bool = False
+        self.verify_mode: int = CERT_NONE
+
+    def wrap_socket(self, *args, **kwargs) -> "SSLSocket":  # type: ignore
+        raise Exception("no ssl support")  # pylint: disable=broad-exception-raised
+
+    def set_alpn_protocols(self, *args, **kwargs):  # type: ignore
+        raise Exception("no ssl support")  # pylint: disable=broad-exception-raised
+
+
+class SSLSocket:
+    def pending(self) -> bool:
+        raise Exception("no ssl support")  # pylint: disable=broad-exception-raised
+
+    def do_handshake(self) -> None:
+        raise Exception("no ssl support")  # pylint: disable=broad-exception-raised
+
+    def settimeout(self, value: Any) -> None:
+        pass
+
+    def getpeercert(self) -> Any:
+        raise Exception("no ssl support")  # pylint: disable=broad-exception-raised
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        return False
+
+
+def create_default_context(*args, **kwargs) -> SSLContext:  # type: ignore
+    raise Exception("no ssl support")  # pylint: disable=broad-exception-raised
index 1e31b2e3e0e029e0ea745d5059ab8266773734cb..c7b43083bef4023c0ccb11ed53a35be33369aa73 100644 (file)
@@ -32,8 +32,6 @@ import dns.inet
 import dns.message
 import dns.name
 import dns.quic
-import dns.rcode
-import dns.rdataclass
 import dns.rdatatype
 import dns.transaction
 import dns.tsig
@@ -47,13 +45,17 @@ from dns.query import (
     UDPMode,
     _check_status,
     _compute_times,
-    _make_dot_ssl_context,
     _matches_destination,
     _remaining,
     have_doh,
-    ssl,
+    make_ssl_context,
 )
 
+try:
+    import ssl
+except ImportError:
+    import dns._no_ssl as ssl  # type: ignore
+
 if have_doh:
     import httpx
 
@@ -476,9 +478,7 @@ async def tls(
         cm: contextlib.AbstractAsyncContextManager = NullContext(sock)
     else:
         if ssl_context is None:
-            ssl_context = _make_dot_ssl_context(
-                server_hostname, verify
-            )  # pyright: ignore
+            ssl_context = make_ssl_context(verify, server_hostname is not None, ["dot"])
         af = dns.inet.af_for_address(where)
         stuple = _source_tuple(af, source, source_port)
         dtuple = (where, port)
@@ -648,7 +648,7 @@ async def https(
         )
 
         cm = httpx.AsyncClient(  # pyright: ignore
-            http1=h1, http2=h2, verify=verify, transport=transport
+            http1=h1, http2=h2, verify=verify, transport=transport  # type: ignore
         )
 
     async with cm as the_client:
index 002557680ced32d2787fc843bdc3019cb7ccc792..ed0a34cfbb9a4d08373577bbc14cc650b7b48e06 100644 (file)
@@ -22,14 +22,13 @@ import contextlib
 import enum
 import errno
 import os
-import os.path
 import random
 import selectors
 import socket
 import struct
 import time
 import urllib.parse
-from typing import Any, Dict, Optional, Tuple, Union, cast
+from typing import Any, Callable, Dict, Optional, Tuple, Union, cast
 
 import dns._features
 import dns._tls_util
@@ -38,15 +37,18 @@ import dns.inet
 import dns.message
 import dns.name
 import dns.quic
-import dns.rcode
 import dns.rdata
 import dns.rdataclass
 import dns.rdatatype
-import dns.serial
 import dns.transaction
 import dns.tsig
 import dns.xfr
 
+try:
+    import ssl
+except ImportError:
+    import dns._no_ssl as ssl  # type: ignore
+
 
 def _remaining(expiration):
     if expiration is None:
@@ -108,7 +110,7 @@ if _have_httpx:
                 else:
                     source = None
                 try:
-                    sock = _make_socket(af, socket.SOCK_STREAM, source)
+                    sock = make_socket(af, socket.SOCK_STREAM, source)
                     attempt_expiration = _expiration_for_this_attempt(2.0, expiration)
                     _connect(
                         sock,
@@ -165,40 +167,20 @@ else:
 
 have_doh = _have_httpx
 
-try:
-    import ssl  # pyright: ignore
-except ImportError:  # pragma: no cover
-
-    class ssl:  # type: ignore
-        CERT_NONE = 0
-
-        class WantReadException(Exception):
-            pass
-
-        class WantWriteException(Exception):
-            pass
-
-        class SSLWantReadError(Exception):
-            pass
-
-        class SSLWantWriteError(Exception):
-            pass
-
-        class SSLContext:
-            pass
-
-        class SSLSocket:
-            def pending(self) -> bool:
-                return False
 
-        @classmethod
-        def create_default_context(cls, *args, **kwargs):
-            raise Exception("no ssl support")  # pylint: disable=broad-exception-raised
+def default_socket_factory(
+    af: Union[socket.AddressFamily, int],
+    kind: socket.SocketKind,
+    proto: int,
+) -> socket.socket:
+    return socket.socket(af, kind, proto)
 
 
 # Function used to create a socket.  Can be overridden if needed in special
 # situations.
-socket_factory = socket.socket
+socket_factory: Callable[
+    [Union[socket.AddressFamily, int], socket.SocketKind, int], socket.socket
+] = default_socket_factory
 
 
 class UnexpectedSource(dns.exception.DNSException):
@@ -339,26 +321,95 @@ def _destination_and_source(
     return (af, destination, source)
 
 
-def _make_socket(af, type, source, ssl_context=None, server_hostname=None):
-    s = socket_factory(af, type)
+def make_socket(
+    af: Union[socket.AddressFamily, int],
+    type: socket.SocketKind,
+    source: Optional[Any] = None,
+) -> socket.socket:
+    """Make a socket.
+
+    This function uses the module's ``socket_factory`` to make a socket of the
+    specified address family and type.
+
+    *af*, a ``socket.AddressFamily`` or ``int`` is the address family, either
+    ``socket.AF_INET`` or ``socket.AF_INET6``.
+
+    *type*, a ``socket.SocketKind`` is the type of socket, e.g. ``socket.SOCK_DGRAM``,
+    a datagram socket, or ``socket.SOCK_STREAM``, a stream socket.  Note that the
+    ``proto`` attribute of a socket is always zero with this API, so a datagram socket
+    will always be a UDP socket, and a stream socket will always be a TCP socket.
+
+    *source* is the source address and port to bind to, if any.  The default is
+    ``None`` which will bind to the wildcard address and a randomly chosen port.
+    If not ``None``, it should be a (low-level) address tuple appropriate for *af*.
+    """
+    s = socket_factory(af, type, 0)
     try:
         s.setblocking(False)
         if source is not None:
             s.bind(source)
-        if ssl_context:
-            # LGTM gets a false positive here, as our default context is OK
-            return ssl_context.wrap_socket(
-                s,
-                do_handshake_on_connect=False,  # lgtm[py/insecure-protocol]
-                server_hostname=server_hostname,
-            )
-        else:
-            return s
+        return s
     except Exception:
         s.close()
         raise
 
 
+def make_ssl_socket(
+    af: Union[socket.AddressFamily, int],
+    type: socket.SocketKind,
+    ssl_context: ssl.SSLContext,
+    server_hostname: Optional[Union[dns.name.Name, str]] = None,
+    source: Optional[Any] = None,
+) -> ssl.SSLSocket:
+    """Make a socket.
+
+    This function uses the module's ``socket_factory`` to make a socket of the
+    specified address family and type.
+
+    *af*, a ``socket.AddressFamily`` or ``int`` is the address family, either
+    ``socket.AF_INET`` or ``socket.AF_INET6``.
+
+    *type*, a ``socket.SocketKind`` is the type of socket, e.g. ``socket.SOCK_DGRAM``,
+    a datagram socket, or ``socket.SOCK_STREAM``, a stream socket.  Note that the
+    ``proto`` attribute of a socket is always zero with this API, so a datagram socket
+    will always be a UDP socket, and a stream socket will always be a TCP socket.
+
+    If *ssl_context* is not ``None``, then it specifies the SSL context to use,
+    typically created with ``make_ssl_context()``.
+
+    If *server_hostname* is not ``None``, then it is the hostname to use for server
+    certificate validation.  A valid hostname must be supplied if *ssl_context*
+    requires hostname checking.
+
+    *source* is the source address and port to bind to, if any.  The default is
+    ``None`` which will bind to the wildcard address and a randomly chosen port.
+    If not ``None``, it should be a (low-level) address tuple appropriate for *af*.
+    """
+    sock = make_socket(af, type, source)
+    if isinstance(server_hostname, dns.name.Name):
+        server_hostname = server_hostname.to_text()
+    # LGTM gets a false positive here, as our default context is OK
+    return ssl_context.wrap_socket(
+        sock,
+        do_handshake_on_connect=False,  # lgtm[py/insecure-protocol]
+        server_hostname=server_hostname,
+    )
+
+
+# for backwards compatibility
+def _make_socket(
+    af,
+    type,
+    source,
+    ssl_context,
+    server_hostname,
+):
+    if ssl_context is not None:
+        return make_ssl_socket(af, type, ssl_context, server_hostname, source)
+    else:
+        return make_socket(af, type, source)
+
+
 def _maybe_get_resolver(
     resolver: Optional["dns.resolver.Resolver"],  # pyright: ignore
 ) -> "dns.resolver.Resolver":  # pyright: ignore
@@ -544,8 +595,8 @@ def https(
             family=family,  # pyright: ignore
         )
 
-        cm = httpx.Client(  # pyright: ignore
-            http1=h1, http2=h2, verify=verify, transport=transport  # pyright: ignore
+        cm = httpx.Client(  # type: ignore
+            http1=h1, http2=h2, verify=verify, transport=transport  # type: ignore
         )
     with cm as session:
         # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
@@ -903,13 +954,14 @@ def udp(
 
     wire = q.to_wire()
     (af, destination, source) = _destination_and_source(
-        where, port, source, source_port
+        where, port, source, source_port, True
     )
     (begin_time, expiration) = _compute_times(timeout)
     if sock:
         cm: contextlib.AbstractContextManager = contextlib.nullcontext(sock)
     else:
-        cm = _make_socket(af, socket.SOCK_DGRAM, source)
+        assert af is not None
+        cm = make_socket(af, socket.SOCK_DGRAM, source)
     with cm as s:
         send_udp(s, wire, destination, expiration)
         (r, received_time) = receive_udp(
@@ -1198,9 +1250,10 @@ def tcp(
         cm: contextlib.AbstractContextManager = contextlib.nullcontext(sock)
     else:
         (af, destination, source) = _destination_and_source(
-            where, port, source, source_port
+            where, port, source, source_port, True
         )
-        cm = _make_socket(af, socket.SOCK_STREAM, source)
+        assert af is not None
+        cm = make_socket(af, socket.SOCK_STREAM, source)
     with cm as s:
         if not sock:
             # pylint: disable=possibly-used-before-assignment
@@ -1229,18 +1282,44 @@ def _tls_handshake(s, expiration):
             _wait_for_writable(s, expiration)
 
 
-def _make_dot_ssl_context(
-    server_hostname: Optional[str], verify: Union[bool, str]
+def make_ssl_context(
+    verify: Union[bool, str] = True,
+    check_hostname: bool = True,
+    alpns: Optional[list[str]] = None,
 ) -> ssl.SSLContext:
+    """Make an SSL context
+
+    If *verify* is ``True``, the default, then certificate verification will occur using
+    the standard CA roots.  If *verify* is ``False``, then certificate verification will
+    be disabled.  If *verify* is a string which is a valid pathname, then if the
+    pathname is a regular file, the CA roots will be taken from the file, otherwise if
+    the pathname is a directory roots will be taken from the directory.
+
+    If *check_hostname* is ``True``, the default, then the hostname of the server must
+    be specified when connecting and the server's certificate must authorize the
+    hostname.  If ``False``, then hostname checking is disabled.
+
+    *aplns* is ``None`` or a list of TLS ALPN (Application Layer Protocol Negotiation)
+    strings to use in negotiation.  For DNS-over-TLS, the right value is `["dot"]`.
+    """
     cafile, capath = dns._tls_util.convert_verify_to_cafile_and_capath(verify)
     ssl_context = ssl.create_default_context(cafile=cafile, capath=capath)
-    ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2
-    if server_hostname is None:
-        ssl_context.check_hostname = False
-    ssl_context.set_alpn_protocols(["dot"])
+    # the pyright ignores below are because it gets confused between the
+    # _no_ssl compatibility types and the real ones.
+    ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2  # type: ignore
+    ssl_context.check_hostname = check_hostname
     if verify is False:
-        ssl_context.verify_mode = ssl.CERT_NONE
-    return ssl_context
+        ssl_context.verify_mode = ssl.CERT_NONE  # type: ignore
+    if alpns is not None:
+        ssl_context.set_alpn_protocols(alpns)
+    return ssl_context  # type: ignore
+
+
+# for backwards compatibility
+def _make_dot_ssl_context(
+    server_hostname: Optional[str], verify: Union[bool, str]
+) -> ssl.SSLContext:
+    return make_ssl_context(verify, server_hostname is not None, ["dot"])
 
 
 def tls(
@@ -1323,17 +1402,18 @@ def tls(
     wire = q.to_wire()
     (begin_time, expiration) = _compute_times(timeout)
     (af, destination, source) = _destination_and_source(
-        where, port, source, source_port
+        where, port, source, source_port, True
     )
-    if ssl_context is None and not sock:
-        ssl_context = _make_dot_ssl_context(server_hostname, verify)
+    assert af is not None  # where must be an address
+    if ssl_context is None:
+        ssl_context = make_ssl_context(verify, server_hostname is not None, ["dot"])
 
-    with _make_socket(
+    with make_ssl_socket(
         af,
         socket.SOCK_STREAM,
-        source,
         ssl_context=ssl_context,
         server_hostname=server_hostname,
+        source=source,
     ) as s:
         _connect(s, destination, expiration)
         _tls_handshake(s, expiration)
@@ -1462,7 +1542,7 @@ class UDPMode(enum.IntEnum):
 
 def _inbound_xfr(
     txn_manager: dns.transaction.TransactionManager,
-    s: socket.socket,
+    s: Union[socket.socket, ssl.SSLSocket],
     query: dns.message.Message,
     serial: Optional[int],
     timeout: Optional[float],
@@ -1473,7 +1553,7 @@ def _inbound_xfr(
     is_ixfr = rdtype == dns.rdatatype.IXFR
     origin = txn_manager.from_wire_origin()
     wire = query.to_wire()
-    is_udp = s.type == socket.SOCK_DGRAM
+    is_udp = isinstance(s, socket.socket) and s.type == socket.SOCK_DGRAM
     if is_udp:
         _udp_send(s, wire, None, expiration)
     else:
@@ -1617,14 +1697,15 @@ def xfr(
     if keyring is not None:
         q.use_tsig(keyring, keyname, algorithm=keyalgorithm)
     (af, destination, source) = _destination_and_source(
-        where, port, source, source_port
+        where, port, source, source_port, True
     )
+    assert af is not None
     (_, expiration) = _compute_times(lifetime)
     tm = DummyTransactionManager(zone, relativize)
     if use_udp and rdtype != dns.rdatatype.IXFR:
         raise ValueError("cannot do a UDP AXFR")
     sock_type = socket.SOCK_DGRAM if use_udp else socket.SOCK_STREAM
-    with _make_socket(af, sock_type, source) as s:
+    with make_socket(af, sock_type, source) as s:
         _connect(s, destination, expiration)
         yield from _inbound_xfr(tm, s, q, serial, timeout, expiration)
 
@@ -1682,11 +1763,12 @@ def inbound_xfr(
         serial = dns.xfr.extract_serial_from_query(query)
 
     (af, destination, source) = _destination_and_source(
-        where, port, source, source_port
+        where, port, source, source_port, True
     )
+    assert af is not None
     (_, expiration) = _compute_times(lifetime)
     if query.question[0].rdtype == dns.rdatatype.IXFR and udp_mode != UDPMode.NEVER:
-        with _make_socket(af, socket.SOCK_DGRAM, source) as s:
+        with make_socket(af, socket.SOCK_DGRAM, source) as s:
             _connect(s, destination, expiration)
             try:
                 for _ in _inbound_xfr(
@@ -1698,7 +1780,7 @@ def inbound_xfr(
                 if udp_mode == UDPMode.ONLY:
                     raise
 
-    with _make_socket(af, socket.SOCK_STREAM, source) as s:
+    with make_socket(af, socket.SOCK_STREAM, source) as s:
         _connect(s, destination, expiration)
         for _ in _inbound_xfr(txn_manager, s, query, serial, timeout, expiration):
             pass
index 5d6d5a2b077935ee83e2ccd8e28650a75e0de082..46caa4e4a08b34749b700002958b314d1a21a53b 100644 (file)
@@ -4,13 +4,13 @@ module:: dns.asyncbackend
 Asynchronous Backend Functions
 ==============================
 
-Dnspython has a "backend" for Trio, Curio, and asyncio which implements
+Dnspython has "backends" for Trio and asyncio which implement
 the library-specific functionality needed by the generic asynchronous
 DNS code.
 
 Dnspython attempts to determine which backend is in use by "sniffing" for it
 with the ``sniffio`` module if it is installed.  If sniffio is not available,
-dnspython try to detect asyncio directly.
+dnspython will try to detect asyncio directly.
 
 .. autofunction:: dns.asyncbackend.get_default_backend
 .. autofunction:: dns.asyncbackend.set_default_backend
index 9f066833479c569423049304ade40ae2aa1d561d..3c7afecd34121f9d1121a71f7a3c14e185235c9d 100644 (file)
@@ -7,16 +7,13 @@ The ``dns.asyncquery`` and ``dns.asyncresolver`` modules offer
 asynchronous APIs equivalent to those of ``dns.query`` and
 ``dns.resolver``.
 
-Dnspython presents a uniform API, but offers three different backend
-implementations, to support the Trio, Curio, and asyncio libraries.
+Dnspython presents a uniform API, but offers two different backend
+implementations, to support the Trio and asyncio libraries.
 Dnspython attempts to detect which library is in use by using the
 ``sniffio`` library if it is available.  It's also possible to
 explicitly select a "backend" library, or to pass a backend to
 a particular call, allowing for use in mixed library situations.
 
-Note that Curio is not supported for DNS-over-HTTPS, due to a
-lack of support in the anyio library used by httpx.
-
 .. toctree::
 
    async-query