From: Bob Halley Date: Mon, 18 Aug 2025 18:11:32 +0000 (-0700) Subject: Make better APIs for creating sockets and TLS contexts (#1217) X-Git-Tag: v2.8.0rc1~16 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=24156b03601778172877b72a51f1983e75f5d1b1;p=thirdparty%2Fdnspython.git Make better APIs for creating sockets and TLS contexts (#1217) * Make better APIs for creating sockets and TLS contexts for code that wants to have persistent connections. [#1176]. This code keeps the ability to use dnspython when the ssl module doesn't work, but moves the helper code to another module to declutter and make testing and type checking easier. We still have to make some type checking compromises, but we are making fewer than before. --- diff --git a/dns/_no_ssl.py b/dns/_no_ssl.py new file mode 100644 index 00000000..c64f0906 --- /dev/null +++ b/dns/_no_ssl.py @@ -0,0 +1,61 @@ +import enum +from typing import Any, Optional + +CERT_NONE = 0 + + +class TLSVersion(enum.IntEnum): + TLSv1_2 = 12 + + +class WantReadException(Exception): + pass + + +class WantWriteException(Exception): + pass + + +class SSLWantReadError(Exception): + pass + + +class SSLWantWriteError(Exception): + pass + + +class SSLContext: + def __init__(self) -> None: + self.minimum_version: Any = TLSVersion.TLSv1_2 + self.check_hostname: bool = False + self.verify_mode: int = CERT_NONE + + def wrap_socket(self, *args, **kwargs) -> "SSLSocket": # type: ignore + raise Exception("no ssl support") # pylint: disable=broad-exception-raised + + def set_alpn_protocols(self, *args, **kwargs): # type: ignore + raise Exception("no ssl support") # pylint: disable=broad-exception-raised + + +class SSLSocket: + def pending(self) -> bool: + raise Exception("no ssl support") # pylint: disable=broad-exception-raised + + def do_handshake(self) -> None: + raise Exception("no ssl support") # pylint: disable=broad-exception-raised + + def settimeout(self, value: Any) -> None: + pass + + def getpeercert(self) -> Any: + raise Exception("no ssl support") # pylint: disable=broad-exception-raised + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + return False + + +def create_default_context(*args, **kwargs) -> SSLContext: # type: ignore + raise Exception("no ssl support") # pylint: disable=broad-exception-raised diff --git a/dns/asyncquery.py b/dns/asyncquery.py index 1e31b2e3..c7b43083 100644 --- a/dns/asyncquery.py +++ b/dns/asyncquery.py @@ -32,8 +32,6 @@ import dns.inet import dns.message import dns.name import dns.quic -import dns.rcode -import dns.rdataclass import dns.rdatatype import dns.transaction import dns.tsig @@ -47,13 +45,17 @@ from dns.query import ( UDPMode, _check_status, _compute_times, - _make_dot_ssl_context, _matches_destination, _remaining, have_doh, - ssl, + make_ssl_context, ) +try: + import ssl +except ImportError: + import dns._no_ssl as ssl # type: ignore + if have_doh: import httpx @@ -476,9 +478,7 @@ async def tls( cm: contextlib.AbstractAsyncContextManager = NullContext(sock) else: if ssl_context is None: - ssl_context = _make_dot_ssl_context( - server_hostname, verify - ) # pyright: ignore + ssl_context = make_ssl_context(verify, server_hostname is not None, ["dot"]) af = dns.inet.af_for_address(where) stuple = _source_tuple(af, source, source_port) dtuple = (where, port) @@ -648,7 +648,7 @@ async def https( ) cm = httpx.AsyncClient( # pyright: ignore - http1=h1, http2=h2, verify=verify, transport=transport + http1=h1, http2=h2, verify=verify, transport=transport # type: ignore ) async with cm as the_client: diff --git a/dns/query.py b/dns/query.py index 00255768..ed0a34cf 100644 --- a/dns/query.py +++ b/dns/query.py @@ -22,14 +22,13 @@ import contextlib import enum import errno import os -import os.path import random import selectors import socket import struct import time import urllib.parse -from typing import Any, Dict, Optional, Tuple, Union, cast +from typing import Any, Callable, Dict, Optional, Tuple, Union, cast import dns._features import dns._tls_util @@ -38,15 +37,18 @@ import dns.inet import dns.message import dns.name import dns.quic -import dns.rcode import dns.rdata import dns.rdataclass import dns.rdatatype -import dns.serial import dns.transaction import dns.tsig import dns.xfr +try: + import ssl +except ImportError: + import dns._no_ssl as ssl # type: ignore + def _remaining(expiration): if expiration is None: @@ -108,7 +110,7 @@ if _have_httpx: else: source = None try: - sock = _make_socket(af, socket.SOCK_STREAM, source) + sock = make_socket(af, socket.SOCK_STREAM, source) attempt_expiration = _expiration_for_this_attempt(2.0, expiration) _connect( sock, @@ -165,40 +167,20 @@ else: have_doh = _have_httpx -try: - import ssl # pyright: ignore -except ImportError: # pragma: no cover - - class ssl: # type: ignore - CERT_NONE = 0 - - class WantReadException(Exception): - pass - - class WantWriteException(Exception): - pass - - class SSLWantReadError(Exception): - pass - - class SSLWantWriteError(Exception): - pass - - class SSLContext: - pass - - class SSLSocket: - def pending(self) -> bool: - return False - @classmethod - def create_default_context(cls, *args, **kwargs): - raise Exception("no ssl support") # pylint: disable=broad-exception-raised +def default_socket_factory( + af: Union[socket.AddressFamily, int], + kind: socket.SocketKind, + proto: int, +) -> socket.socket: + return socket.socket(af, kind, proto) # Function used to create a socket. Can be overridden if needed in special # situations. -socket_factory = socket.socket +socket_factory: Callable[ + [Union[socket.AddressFamily, int], socket.SocketKind, int], socket.socket +] = default_socket_factory class UnexpectedSource(dns.exception.DNSException): @@ -339,26 +321,95 @@ def _destination_and_source( return (af, destination, source) -def _make_socket(af, type, source, ssl_context=None, server_hostname=None): - s = socket_factory(af, type) +def make_socket( + af: Union[socket.AddressFamily, int], + type: socket.SocketKind, + source: Optional[Any] = None, +) -> socket.socket: + """Make a socket. + + This function uses the module's ``socket_factory`` to make a socket of the + specified address family and type. + + *af*, a ``socket.AddressFamily`` or ``int`` is the address family, either + ``socket.AF_INET`` or ``socket.AF_INET6``. + + *type*, a ``socket.SocketKind`` is the type of socket, e.g. ``socket.SOCK_DGRAM``, + a datagram socket, or ``socket.SOCK_STREAM``, a stream socket. Note that the + ``proto`` attribute of a socket is always zero with this API, so a datagram socket + will always be a UDP socket, and a stream socket will always be a TCP socket. + + *source* is the source address and port to bind to, if any. The default is + ``None`` which will bind to the wildcard address and a randomly chosen port. + If not ``None``, it should be a (low-level) address tuple appropriate for *af*. + """ + s = socket_factory(af, type, 0) try: s.setblocking(False) if source is not None: s.bind(source) - if ssl_context: - # LGTM gets a false positive here, as our default context is OK - return ssl_context.wrap_socket( - s, - do_handshake_on_connect=False, # lgtm[py/insecure-protocol] - server_hostname=server_hostname, - ) - else: - return s + return s except Exception: s.close() raise +def make_ssl_socket( + af: Union[socket.AddressFamily, int], + type: socket.SocketKind, + ssl_context: ssl.SSLContext, + server_hostname: Optional[Union[dns.name.Name, str]] = None, + source: Optional[Any] = None, +) -> ssl.SSLSocket: + """Make a socket. + + This function uses the module's ``socket_factory`` to make a socket of the + specified address family and type. + + *af*, a ``socket.AddressFamily`` or ``int`` is the address family, either + ``socket.AF_INET`` or ``socket.AF_INET6``. + + *type*, a ``socket.SocketKind`` is the type of socket, e.g. ``socket.SOCK_DGRAM``, + a datagram socket, or ``socket.SOCK_STREAM``, a stream socket. Note that the + ``proto`` attribute of a socket is always zero with this API, so a datagram socket + will always be a UDP socket, and a stream socket will always be a TCP socket. + + If *ssl_context* is not ``None``, then it specifies the SSL context to use, + typically created with ``make_ssl_context()``. + + If *server_hostname* is not ``None``, then it is the hostname to use for server + certificate validation. A valid hostname must be supplied if *ssl_context* + requires hostname checking. + + *source* is the source address and port to bind to, if any. The default is + ``None`` which will bind to the wildcard address and a randomly chosen port. + If not ``None``, it should be a (low-level) address tuple appropriate for *af*. + """ + sock = make_socket(af, type, source) + if isinstance(server_hostname, dns.name.Name): + server_hostname = server_hostname.to_text() + # LGTM gets a false positive here, as our default context is OK + return ssl_context.wrap_socket( + sock, + do_handshake_on_connect=False, # lgtm[py/insecure-protocol] + server_hostname=server_hostname, + ) + + +# for backwards compatibility +def _make_socket( + af, + type, + source, + ssl_context, + server_hostname, +): + if ssl_context is not None: + return make_ssl_socket(af, type, ssl_context, server_hostname, source) + else: + return make_socket(af, type, source) + + def _maybe_get_resolver( resolver: Optional["dns.resolver.Resolver"], # pyright: ignore ) -> "dns.resolver.Resolver": # pyright: ignore @@ -544,8 +595,8 @@ def https( family=family, # pyright: ignore ) - cm = httpx.Client( # pyright: ignore - http1=h1, http2=h2, verify=verify, transport=transport # pyright: ignore + cm = httpx.Client( # type: ignore + http1=h1, http2=h2, verify=verify, transport=transport # type: ignore ) with cm as session: # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH @@ -903,13 +954,14 @@ def udp( wire = q.to_wire() (af, destination, source) = _destination_and_source( - where, port, source, source_port + where, port, source, source_port, True ) (begin_time, expiration) = _compute_times(timeout) if sock: cm: contextlib.AbstractContextManager = contextlib.nullcontext(sock) else: - cm = _make_socket(af, socket.SOCK_DGRAM, source) + assert af is not None + cm = make_socket(af, socket.SOCK_DGRAM, source) with cm as s: send_udp(s, wire, destination, expiration) (r, received_time) = receive_udp( @@ -1198,9 +1250,10 @@ def tcp( cm: contextlib.AbstractContextManager = contextlib.nullcontext(sock) else: (af, destination, source) = _destination_and_source( - where, port, source, source_port + where, port, source, source_port, True ) - cm = _make_socket(af, socket.SOCK_STREAM, source) + assert af is not None + cm = make_socket(af, socket.SOCK_STREAM, source) with cm as s: if not sock: # pylint: disable=possibly-used-before-assignment @@ -1229,18 +1282,44 @@ def _tls_handshake(s, expiration): _wait_for_writable(s, expiration) -def _make_dot_ssl_context( - server_hostname: Optional[str], verify: Union[bool, str] +def make_ssl_context( + verify: Union[bool, str] = True, + check_hostname: bool = True, + alpns: Optional[list[str]] = None, ) -> ssl.SSLContext: + """Make an SSL context + + If *verify* is ``True``, the default, then certificate verification will occur using + the standard CA roots. If *verify* is ``False``, then certificate verification will + be disabled. If *verify* is a string which is a valid pathname, then if the + pathname is a regular file, the CA roots will be taken from the file, otherwise if + the pathname is a directory roots will be taken from the directory. + + If *check_hostname* is ``True``, the default, then the hostname of the server must + be specified when connecting and the server's certificate must authorize the + hostname. If ``False``, then hostname checking is disabled. + + *aplns* is ``None`` or a list of TLS ALPN (Application Layer Protocol Negotiation) + strings to use in negotiation. For DNS-over-TLS, the right value is `["dot"]`. + """ cafile, capath = dns._tls_util.convert_verify_to_cafile_and_capath(verify) ssl_context = ssl.create_default_context(cafile=cafile, capath=capath) - ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2 - if server_hostname is None: - ssl_context.check_hostname = False - ssl_context.set_alpn_protocols(["dot"]) + # the pyright ignores below are because it gets confused between the + # _no_ssl compatibility types and the real ones. + ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2 # type: ignore + ssl_context.check_hostname = check_hostname if verify is False: - ssl_context.verify_mode = ssl.CERT_NONE - return ssl_context + ssl_context.verify_mode = ssl.CERT_NONE # type: ignore + if alpns is not None: + ssl_context.set_alpn_protocols(alpns) + return ssl_context # type: ignore + + +# for backwards compatibility +def _make_dot_ssl_context( + server_hostname: Optional[str], verify: Union[bool, str] +) -> ssl.SSLContext: + return make_ssl_context(verify, server_hostname is not None, ["dot"]) def tls( @@ -1323,17 +1402,18 @@ def tls( wire = q.to_wire() (begin_time, expiration) = _compute_times(timeout) (af, destination, source) = _destination_and_source( - where, port, source, source_port + where, port, source, source_port, True ) - if ssl_context is None and not sock: - ssl_context = _make_dot_ssl_context(server_hostname, verify) + assert af is not None # where must be an address + if ssl_context is None: + ssl_context = make_ssl_context(verify, server_hostname is not None, ["dot"]) - with _make_socket( + with make_ssl_socket( af, socket.SOCK_STREAM, - source, ssl_context=ssl_context, server_hostname=server_hostname, + source=source, ) as s: _connect(s, destination, expiration) _tls_handshake(s, expiration) @@ -1462,7 +1542,7 @@ class UDPMode(enum.IntEnum): def _inbound_xfr( txn_manager: dns.transaction.TransactionManager, - s: socket.socket, + s: Union[socket.socket, ssl.SSLSocket], query: dns.message.Message, serial: Optional[int], timeout: Optional[float], @@ -1473,7 +1553,7 @@ def _inbound_xfr( is_ixfr = rdtype == dns.rdatatype.IXFR origin = txn_manager.from_wire_origin() wire = query.to_wire() - is_udp = s.type == socket.SOCK_DGRAM + is_udp = isinstance(s, socket.socket) and s.type == socket.SOCK_DGRAM if is_udp: _udp_send(s, wire, None, expiration) else: @@ -1617,14 +1697,15 @@ def xfr( if keyring is not None: q.use_tsig(keyring, keyname, algorithm=keyalgorithm) (af, destination, source) = _destination_and_source( - where, port, source, source_port + where, port, source, source_port, True ) + assert af is not None (_, expiration) = _compute_times(lifetime) tm = DummyTransactionManager(zone, relativize) if use_udp and rdtype != dns.rdatatype.IXFR: raise ValueError("cannot do a UDP AXFR") sock_type = socket.SOCK_DGRAM if use_udp else socket.SOCK_STREAM - with _make_socket(af, sock_type, source) as s: + with make_socket(af, sock_type, source) as s: _connect(s, destination, expiration) yield from _inbound_xfr(tm, s, q, serial, timeout, expiration) @@ -1682,11 +1763,12 @@ def inbound_xfr( serial = dns.xfr.extract_serial_from_query(query) (af, destination, source) = _destination_and_source( - where, port, source, source_port + where, port, source, source_port, True ) + assert af is not None (_, expiration) = _compute_times(lifetime) if query.question[0].rdtype == dns.rdatatype.IXFR and udp_mode != UDPMode.NEVER: - with _make_socket(af, socket.SOCK_DGRAM, source) as s: + with make_socket(af, socket.SOCK_DGRAM, source) as s: _connect(s, destination, expiration) try: for _ in _inbound_xfr( @@ -1698,7 +1780,7 @@ def inbound_xfr( if udp_mode == UDPMode.ONLY: raise - with _make_socket(af, socket.SOCK_STREAM, source) as s: + with make_socket(af, socket.SOCK_STREAM, source) as s: _connect(s, destination, expiration) for _ in _inbound_xfr(txn_manager, s, query, serial, timeout, expiration): pass diff --git a/doc/async-backend.rst b/doc/async-backend.rst index 5d6d5a2b..46caa4e4 100644 --- a/doc/async-backend.rst +++ b/doc/async-backend.rst @@ -4,13 +4,13 @@ module:: dns.asyncbackend Asynchronous Backend Functions ============================== -Dnspython has a "backend" for Trio, Curio, and asyncio which implements +Dnspython has "backends" for Trio and asyncio which implement the library-specific functionality needed by the generic asynchronous DNS code. Dnspython attempts to determine which backend is in use by "sniffing" for it with the ``sniffio`` module if it is installed. If sniffio is not available, -dnspython try to detect asyncio directly. +dnspython will try to detect asyncio directly. .. autofunction:: dns.asyncbackend.get_default_backend .. autofunction:: dns.asyncbackend.set_default_backend diff --git a/doc/async.rst b/doc/async.rst index 9f066833..3c7afecd 100644 --- a/doc/async.rst +++ b/doc/async.rst @@ -7,16 +7,13 @@ The ``dns.asyncquery`` and ``dns.asyncresolver`` modules offer asynchronous APIs equivalent to those of ``dns.query`` and ``dns.resolver``. -Dnspython presents a uniform API, but offers three different backend -implementations, to support the Trio, Curio, and asyncio libraries. +Dnspython presents a uniform API, but offers two different backend +implementations, to support the Trio and asyncio libraries. Dnspython attempts to detect which library is in use by using the ``sniffio`` library if it is available. It's also possible to explicitly select a "backend" library, or to pass a backend to a particular call, allowing for use in mixed library situations. -Note that Curio is not supported for DNS-over-HTTPS, due to a -lack of support in the anyio library used by httpx. - .. toctree:: async-query