From: Bob Halley Date: Thu, 28 Dec 2023 18:47:44 +0000 (-0800) Subject: Uniform TLS verify argument support. (#1027) X-Git-Tag: v2.5.0rc1~5 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=609d6b2e7ba4d01e3541558e01e9f7357bc6d0c6;p=thirdparty%2Fdnspython.git Uniform TLS verify argument support. (#1027) * Uniform TLS verify argument support. * async TLS should get verify too --- diff --git a/dns/asyncquery.py b/dns/asyncquery.py index 13d317d1..7e6b3899 100644 --- a/dns/asyncquery.py +++ b/dns/asyncquery.py @@ -42,6 +42,7 @@ from dns.query import ( UDPMode, _compute_times, _have_http2, + _make_dot_ssl_context, _matches_destination, _remaining, have_doh, @@ -297,7 +298,7 @@ async def send_tcp( # copying the wire into tcpmsg is inefficient, but lets us # avoid writev() or doing a short write that would get pushed # onto the net - tcpmsg = len(what).to_bytes(2, 'big') + what + tcpmsg = len(what).to_bytes(2, "big") + what sent_time = time.time() await sock.sendall(tcpmsg, _timeout(expiration, sent_time)) return (len(tcpmsg), sent_time) @@ -416,6 +417,7 @@ async def tls( backend: Optional[dns.asyncbackend.Backend] = None, ssl_context: Optional[ssl.SSLContext] = None, server_hostname: Optional[str] = None, + verify: Union[bool, str] = True, ) -> dns.message.Message: """Return the response obtained after sending a query via TLS. @@ -437,11 +439,7 @@ async def tls( cm: contextlib.AbstractAsyncContextManager = NullContext(sock) else: if ssl_context is None: - # See the comment about ssl.create_default_context() in query.py - ssl_context = ssl.create_default_context() # lgtm[py/insecure-protocol] - ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2 - if server_hostname is None: - ssl_context.check_hostname = False + ssl_context = _make_dot_ssl_context(server_hostname, verify) af = dns.inet.af_for_address(where) stuple = _source_tuple(af, source, source_port) dtuple = (where, port) diff --git a/dns/nameserver.py b/dns/nameserver.py index 2639efa3..a1fb5498 100644 --- a/dns/nameserver.py +++ b/dns/nameserver.py @@ -158,10 +158,16 @@ class Do53Nameserver(AddressAndPortNameserver): class DoHNameserver(Nameserver): - def __init__(self, url: str, bootstrap_address: Optional[str] = None): + def __init__( + self, + url: str, + bootstrap_address: Optional[str] = None, + verify: Union[bool, str] = True, + ): super().__init__() self.url = url self.bootstrap_address = bootstrap_address + self.verify = verify def kind(self): return "DoH" @@ -198,6 +204,7 @@ class DoHNameserver(Nameserver): bootstrap_address=self.bootstrap_address, one_rr_per_rrset=one_rr_per_rrset, ignore_trailing=ignore_trailing, + verify=self.verify, ) async def async_query( @@ -218,13 +225,21 @@ class DoHNameserver(Nameserver): bootstrap_address=self.bootstrap_address, one_rr_per_rrset=one_rr_per_rrset, ignore_trailing=ignore_trailing, + verify=self.verify, ) class DoTNameserver(AddressAndPortNameserver): - def __init__(self, address: str, port: int = 853, hostname: Optional[str] = None): + def __init__( + self, + address: str, + port: int = 853, + hostname: Optional[str] = None, + verify: Union[bool, str] = True, + ): super().__init__(address, port) self.hostname = hostname + self.verify = verify def kind(self): return "DoT" @@ -247,6 +262,7 @@ class DoTNameserver(AddressAndPortNameserver): one_rr_per_rrset=one_rr_per_rrset, ignore_trailing=ignore_trailing, server_hostname=self.hostname, + verify=self.verify, ) async def async_query( @@ -268,6 +284,7 @@ class DoTNameserver(AddressAndPortNameserver): one_rr_per_rrset=one_rr_per_rrset, ignore_trailing=ignore_trailing, server_hostname=self.hostname, + verify=self.verify, ) diff --git a/dns/query.py b/dns/query.py index 869d3418..9ffdbf4f 100644 --- a/dns/query.py +++ b/dns/query.py @@ -22,6 +22,7 @@ import contextlib import enum import errno import os +import os.path import selectors import socket import struct @@ -161,6 +162,8 @@ try: except ImportError: # pragma: no cover class ssl: # type: ignore + CERT_NONE = 0 + class WantReadException(Exception): pass @@ -1012,6 +1015,28 @@ def _tls_handshake(s, expiration): _wait_for_writable(s, expiration) +def _make_dot_ssl_context( + server_hostname: Optional[str], verify: Union[bool, str] +) -> ssl.SSLContext: + cafile: Optional[str] = None + capath: Optional[str] = None + if isinstance(verify, str): + if os.path.isfile(verify): + cafile = verify + elif os.path.isdir(verify): + capath = verify + else: + raise ValueError("invalid verify string") + ssl_context = ssl.create_default_context(cafile=cafile, capath=capath) + ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2 + if server_hostname is None: + ssl_context.check_hostname = False + ssl_context.set_alpn_protocols(["dot"]) + if verify is False: + ssl_context.verify_mode = ssl.CERT_NONE + return ssl_context + + def tls( q: dns.message.Message, where: str, @@ -1024,6 +1049,7 @@ def tls( sock: Optional[ssl.SSLSocket] = None, ssl_context: Optional[ssl.SSLContext] = None, server_hostname: Optional[str] = None, + verify: Union[bool, str] = True, ) -> dns.message.Message: """Return the response obtained after sending a query via TLS. @@ -1063,6 +1089,11 @@ def tls( default is ``None``, which means that no hostname is known, and if an SSL context is created, hostname checking will be disabled. + *verify*, a ``bool`` or ``str``. If a ``True``, then TLS certificate verification + of the server is done using the default CA bundle; if ``False``, then no + verification is done; if a `str` then it specifies the path to a certificate file or + directory which will be used for verification. + Returns a ``dns.message.Message``. """ @@ -1089,11 +1120,7 @@ def tls( where, port, source, source_port ) if ssl_context is None and not sock: - ssl_context = ssl.create_default_context() - ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2 - if server_hostname is None: - ssl_context.check_hostname = False - ssl_context.set_alpn_protocols(["dot"]) + ssl_context = _make_dot_ssl_context(server_hostname, verify) with _make_socket( af,