import enum
import errno
import os
-import os.path
import random
import selectors
import socket
import struct
import time
import urllib.parse
-from typing import Any, Dict, Optional, Tuple, Union, cast
+from typing import Any, Callable, Dict, Optional, Tuple, Union, cast
import dns._features
import dns._tls_util
import dns.message
import dns.name
import dns.quic
-import dns.rcode
import dns.rdata
import dns.rdataclass
import dns.rdatatype
-import dns.serial
import dns.transaction
import dns.tsig
import dns.xfr
+try:
+ import ssl
+except ImportError:
+ import dns._no_ssl as ssl # type: ignore
+
def _remaining(expiration):
if expiration is None:
else:
source = None
try:
- sock = _make_socket(af, socket.SOCK_STREAM, source)
+ sock = make_socket(af, socket.SOCK_STREAM, source)
attempt_expiration = _expiration_for_this_attempt(2.0, expiration)
_connect(
sock,
have_doh = _have_httpx
-try:
- import ssl # pyright: ignore
-except ImportError: # pragma: no cover
-
- class ssl: # type: ignore
- CERT_NONE = 0
-
- class WantReadException(Exception):
- pass
-
- class WantWriteException(Exception):
- pass
-
- class SSLWantReadError(Exception):
- pass
-
- class SSLWantWriteError(Exception):
- pass
-
- class SSLContext:
- pass
-
- class SSLSocket:
- def pending(self) -> bool:
- return False
- @classmethod
- def create_default_context(cls, *args, **kwargs):
- raise Exception("no ssl support") # pylint: disable=broad-exception-raised
+def default_socket_factory(
+ af: Union[socket.AddressFamily, int],
+ kind: socket.SocketKind,
+ proto: int,
+) -> socket.socket:
+ return socket.socket(af, kind, proto)
# Function used to create a socket. Can be overridden if needed in special
# situations.
-socket_factory = socket.socket
+socket_factory: Callable[
+ [Union[socket.AddressFamily, int], socket.SocketKind, int], socket.socket
+] = default_socket_factory
class UnexpectedSource(dns.exception.DNSException):
return (af, destination, source)
-def _make_socket(af, type, source, ssl_context=None, server_hostname=None):
- s = socket_factory(af, type)
+def make_socket(
+ af: Union[socket.AddressFamily, int],
+ type: socket.SocketKind,
+ source: Optional[Any] = None,
+) -> socket.socket:
+ """Make a socket.
+
+ This function uses the module's ``socket_factory`` to make a socket of the
+ specified address family and type.
+
+ *af*, a ``socket.AddressFamily`` or ``int`` is the address family, either
+ ``socket.AF_INET`` or ``socket.AF_INET6``.
+
+ *type*, a ``socket.SocketKind`` is the type of socket, e.g. ``socket.SOCK_DGRAM``,
+ a datagram socket, or ``socket.SOCK_STREAM``, a stream socket. Note that the
+ ``proto`` attribute of a socket is always zero with this API, so a datagram socket
+ will always be a UDP socket, and a stream socket will always be a TCP socket.
+
+ *source* is the source address and port to bind to, if any. The default is
+ ``None`` which will bind to the wildcard address and a randomly chosen port.
+ If not ``None``, it should be a (low-level) address tuple appropriate for *af*.
+ """
+ s = socket_factory(af, type, 0)
try:
s.setblocking(False)
if source is not None:
s.bind(source)
- if ssl_context:
- # LGTM gets a false positive here, as our default context is OK
- return ssl_context.wrap_socket(
- s,
- do_handshake_on_connect=False, # lgtm[py/insecure-protocol]
- server_hostname=server_hostname,
- )
- else:
- return s
+ return s
except Exception:
s.close()
raise
+def make_ssl_socket(
+ af: Union[socket.AddressFamily, int],
+ type: socket.SocketKind,
+ ssl_context: ssl.SSLContext,
+ server_hostname: Optional[Union[dns.name.Name, str]] = None,
+ source: Optional[Any] = None,
+) -> ssl.SSLSocket:
+ """Make a socket.
+
+ This function uses the module's ``socket_factory`` to make a socket of the
+ specified address family and type.
+
+ *af*, a ``socket.AddressFamily`` or ``int`` is the address family, either
+ ``socket.AF_INET`` or ``socket.AF_INET6``.
+
+ *type*, a ``socket.SocketKind`` is the type of socket, e.g. ``socket.SOCK_DGRAM``,
+ a datagram socket, or ``socket.SOCK_STREAM``, a stream socket. Note that the
+ ``proto`` attribute of a socket is always zero with this API, so a datagram socket
+ will always be a UDP socket, and a stream socket will always be a TCP socket.
+
+ If *ssl_context* is not ``None``, then it specifies the SSL context to use,
+ typically created with ``make_ssl_context()``.
+
+ If *server_hostname* is not ``None``, then it is the hostname to use for server
+ certificate validation. A valid hostname must be supplied if *ssl_context*
+ requires hostname checking.
+
+ *source* is the source address and port to bind to, if any. The default is
+ ``None`` which will bind to the wildcard address and a randomly chosen port.
+ If not ``None``, it should be a (low-level) address tuple appropriate for *af*.
+ """
+ sock = make_socket(af, type, source)
+ if isinstance(server_hostname, dns.name.Name):
+ server_hostname = server_hostname.to_text()
+ # LGTM gets a false positive here, as our default context is OK
+ return ssl_context.wrap_socket(
+ sock,
+ do_handshake_on_connect=False, # lgtm[py/insecure-protocol]
+ server_hostname=server_hostname,
+ )
+
+
+# for backwards compatibility
+def _make_socket(
+ af,
+ type,
+ source,
+ ssl_context,
+ server_hostname,
+):
+ if ssl_context is not None:
+ return make_ssl_socket(af, type, ssl_context, server_hostname, source)
+ else:
+ return make_socket(af, type, source)
+
+
def _maybe_get_resolver(
resolver: Optional["dns.resolver.Resolver"], # pyright: ignore
) -> "dns.resolver.Resolver": # pyright: ignore
family=family, # pyright: ignore
)
- cm = httpx.Client( # pyright: ignore
- http1=h1, http2=h2, verify=verify, transport=transport # pyright: ignore
+ cm = httpx.Client( # type: ignore
+ http1=h1, http2=h2, verify=verify, transport=transport # type: ignore
)
with cm as session:
# see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
wire = q.to_wire()
(af, destination, source) = _destination_and_source(
- where, port, source, source_port
+ where, port, source, source_port, True
)
(begin_time, expiration) = _compute_times(timeout)
if sock:
cm: contextlib.AbstractContextManager = contextlib.nullcontext(sock)
else:
- cm = _make_socket(af, socket.SOCK_DGRAM, source)
+ assert af is not None
+ cm = make_socket(af, socket.SOCK_DGRAM, source)
with cm as s:
send_udp(s, wire, destination, expiration)
(r, received_time) = receive_udp(
cm: contextlib.AbstractContextManager = contextlib.nullcontext(sock)
else:
(af, destination, source) = _destination_and_source(
- where, port, source, source_port
+ where, port, source, source_port, True
)
- cm = _make_socket(af, socket.SOCK_STREAM, source)
+ assert af is not None
+ cm = make_socket(af, socket.SOCK_STREAM, source)
with cm as s:
if not sock:
# pylint: disable=possibly-used-before-assignment
_wait_for_writable(s, expiration)
-def _make_dot_ssl_context(
- server_hostname: Optional[str], verify: Union[bool, str]
+def make_ssl_context(
+ verify: Union[bool, str] = True,
+ check_hostname: bool = True,
+ alpns: Optional[list[str]] = None,
) -> ssl.SSLContext:
+ """Make an SSL context
+
+ If *verify* is ``True``, the default, then certificate verification will occur using
+ the standard CA roots. If *verify* is ``False``, then certificate verification will
+ be disabled. If *verify* is a string which is a valid pathname, then if the
+ pathname is a regular file, the CA roots will be taken from the file, otherwise if
+ the pathname is a directory roots will be taken from the directory.
+
+ If *check_hostname* is ``True``, the default, then the hostname of the server must
+ be specified when connecting and the server's certificate must authorize the
+ hostname. If ``False``, then hostname checking is disabled.
+
+ *aplns* is ``None`` or a list of TLS ALPN (Application Layer Protocol Negotiation)
+ strings to use in negotiation. For DNS-over-TLS, the right value is `["dot"]`.
+ """
cafile, capath = dns._tls_util.convert_verify_to_cafile_and_capath(verify)
ssl_context = ssl.create_default_context(cafile=cafile, capath=capath)
- ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2
- if server_hostname is None:
- ssl_context.check_hostname = False
- ssl_context.set_alpn_protocols(["dot"])
+ # the pyright ignores below are because it gets confused between the
+ # _no_ssl compatibility types and the real ones.
+ ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2 # type: ignore
+ ssl_context.check_hostname = check_hostname
if verify is False:
- ssl_context.verify_mode = ssl.CERT_NONE
- return ssl_context
+ ssl_context.verify_mode = ssl.CERT_NONE # type: ignore
+ if alpns is not None:
+ ssl_context.set_alpn_protocols(alpns)
+ return ssl_context # type: ignore
+
+
+# for backwards compatibility
+def _make_dot_ssl_context(
+ server_hostname: Optional[str], verify: Union[bool, str]
+) -> ssl.SSLContext:
+ return make_ssl_context(verify, server_hostname is not None, ["dot"])
def tls(
wire = q.to_wire()
(begin_time, expiration) = _compute_times(timeout)
(af, destination, source) = _destination_and_source(
- where, port, source, source_port
+ where, port, source, source_port, True
)
- if ssl_context is None and not sock:
- ssl_context = _make_dot_ssl_context(server_hostname, verify)
+ assert af is not None # where must be an address
+ if ssl_context is None:
+ ssl_context = make_ssl_context(verify, server_hostname is not None, ["dot"])
- with _make_socket(
+ with make_ssl_socket(
af,
socket.SOCK_STREAM,
- source,
ssl_context=ssl_context,
server_hostname=server_hostname,
+ source=source,
) as s:
_connect(s, destination, expiration)
_tls_handshake(s, expiration)
def _inbound_xfr(
txn_manager: dns.transaction.TransactionManager,
- s: socket.socket,
+ s: Union[socket.socket, ssl.SSLSocket],
query: dns.message.Message,
serial: Optional[int],
timeout: Optional[float],
is_ixfr = rdtype == dns.rdatatype.IXFR
origin = txn_manager.from_wire_origin()
wire = query.to_wire()
- is_udp = s.type == socket.SOCK_DGRAM
+ is_udp = isinstance(s, socket.socket) and s.type == socket.SOCK_DGRAM
if is_udp:
_udp_send(s, wire, None, expiration)
else:
if keyring is not None:
q.use_tsig(keyring, keyname, algorithm=keyalgorithm)
(af, destination, source) = _destination_and_source(
- where, port, source, source_port
+ where, port, source, source_port, True
)
+ assert af is not None
(_, expiration) = _compute_times(lifetime)
tm = DummyTransactionManager(zone, relativize)
if use_udp and rdtype != dns.rdatatype.IXFR:
raise ValueError("cannot do a UDP AXFR")
sock_type = socket.SOCK_DGRAM if use_udp else socket.SOCK_STREAM
- with _make_socket(af, sock_type, source) as s:
+ with make_socket(af, sock_type, source) as s:
_connect(s, destination, expiration)
yield from _inbound_xfr(tm, s, q, serial, timeout, expiration)
serial = dns.xfr.extract_serial_from_query(query)
(af, destination, source) = _destination_and_source(
- where, port, source, source_port
+ where, port, source, source_port, True
)
+ assert af is not None
(_, expiration) = _compute_times(lifetime)
if query.question[0].rdtype == dns.rdatatype.IXFR and udp_mode != UDPMode.NEVER:
- with _make_socket(af, socket.SOCK_DGRAM, source) as s:
+ with make_socket(af, socket.SOCK_DGRAM, source) as s:
_connect(s, destination, expiration)
try:
for _ in _inbound_xfr(
if udp_mode == UDPMode.ONLY:
raise
- with _make_socket(af, socket.SOCK_STREAM, source) as s:
+ with make_socket(af, socket.SOCK_STREAM, source) as s:
_connect(s, destination, expiration)
for _ in _inbound_xfr(txn_manager, s, query, serial, timeout, expiration):
pass