]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Uniform TLS verify argument support. (#1027)
authorBob Halley <halley@dnspython.org>
Thu, 28 Dec 2023 18:47:44 +0000 (10:47 -0800)
committerGitHub <noreply@github.com>
Thu, 28 Dec 2023 18:47:44 +0000 (10:47 -0800)
* Uniform TLS verify argument support.

* async TLS should get verify too

dns/asyncquery.py
dns/nameserver.py
dns/query.py

index 13d317d148a2ff0e47c909eabb330502d0330d19..7e6b389929edbc7334bf6e28aed1d521b055af5d 100644 (file)
@@ -42,6 +42,7 @@ from dns.query import (
     UDPMode,
     _compute_times,
     _have_http2,
+    _make_dot_ssl_context,
     _matches_destination,
     _remaining,
     have_doh,
@@ -297,7 +298,7 @@ async def send_tcp(
         # copying the wire into tcpmsg is inefficient, but lets us
         # avoid writev() or doing a short write that would get pushed
         # onto the net
-        tcpmsg = len(what).to_bytes(2, 'big') + what
+        tcpmsg = len(what).to_bytes(2, "big") + what
     sent_time = time.time()
     await sock.sendall(tcpmsg, _timeout(expiration, sent_time))
     return (len(tcpmsg), sent_time)
@@ -416,6 +417,7 @@ async def tls(
     backend: Optional[dns.asyncbackend.Backend] = None,
     ssl_context: Optional[ssl.SSLContext] = None,
     server_hostname: Optional[str] = None,
+    verify: Union[bool, str] = True,
 ) -> dns.message.Message:
     """Return the response obtained after sending a query via TLS.
 
@@ -437,11 +439,7 @@ async def tls(
         cm: contextlib.AbstractAsyncContextManager = NullContext(sock)
     else:
         if ssl_context is None:
-            # See the comment about ssl.create_default_context() in query.py
-            ssl_context = ssl.create_default_context()  # lgtm[py/insecure-protocol]
-            ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2
-            if server_hostname is None:
-                ssl_context.check_hostname = False
+            ssl_context = _make_dot_ssl_context(server_hostname, verify)
         af = dns.inet.af_for_address(where)
         stuple = _source_tuple(af, source, source_port)
         dtuple = (where, port)
index 2639efa35e06e68980f7fce9b5d5147c3ff7da1f..a1fb54987ad97a2e656de008f2ec1d40ba9d821d 100644 (file)
@@ -158,10 +158,16 @@ class Do53Nameserver(AddressAndPortNameserver):
 
 
 class DoHNameserver(Nameserver):
-    def __init__(self, url: str, bootstrap_address: Optional[str] = None):
+    def __init__(
+        self,
+        url: str,
+        bootstrap_address: Optional[str] = None,
+        verify: Union[bool, str] = True,
+    ):
         super().__init__()
         self.url = url
         self.bootstrap_address = bootstrap_address
+        self.verify = verify
 
     def kind(self):
         return "DoH"
@@ -198,6 +204,7 @@ class DoHNameserver(Nameserver):
             bootstrap_address=self.bootstrap_address,
             one_rr_per_rrset=one_rr_per_rrset,
             ignore_trailing=ignore_trailing,
+            verify=self.verify,
         )
 
     async def async_query(
@@ -218,13 +225,21 @@ class DoHNameserver(Nameserver):
             bootstrap_address=self.bootstrap_address,
             one_rr_per_rrset=one_rr_per_rrset,
             ignore_trailing=ignore_trailing,
+            verify=self.verify,
         )
 
 
 class DoTNameserver(AddressAndPortNameserver):
-    def __init__(self, address: str, port: int = 853, hostname: Optional[str] = None):
+    def __init__(
+        self,
+        address: str,
+        port: int = 853,
+        hostname: Optional[str] = None,
+        verify: Union[bool, str] = True,
+    ):
         super().__init__(address, port)
         self.hostname = hostname
+        self.verify = verify
 
     def kind(self):
         return "DoT"
@@ -247,6 +262,7 @@ class DoTNameserver(AddressAndPortNameserver):
             one_rr_per_rrset=one_rr_per_rrset,
             ignore_trailing=ignore_trailing,
             server_hostname=self.hostname,
+            verify=self.verify,
         )
 
     async def async_query(
@@ -268,6 +284,7 @@ class DoTNameserver(AddressAndPortNameserver):
             one_rr_per_rrset=one_rr_per_rrset,
             ignore_trailing=ignore_trailing,
             server_hostname=self.hostname,
+            verify=self.verify,
         )
 
 
index 869d34189669e08865416e8ad376fe1b89f1de5a..9ffdbf4fa6d712fcc8d2347cdb518a61fb49290a 100644 (file)
@@ -22,6 +22,7 @@ import contextlib
 import enum
 import errno
 import os
+import os.path
 import selectors
 import socket
 import struct
@@ -161,6 +162,8 @@ try:
 except ImportError:  # pragma: no cover
 
     class ssl:  # type: ignore
+        CERT_NONE = 0
+
         class WantReadException(Exception):
             pass
 
@@ -1012,6 +1015,28 @@ def _tls_handshake(s, expiration):
             _wait_for_writable(s, expiration)
 
 
+def _make_dot_ssl_context(
+    server_hostname: Optional[str], verify: Union[bool, str]
+) -> ssl.SSLContext:
+    cafile: Optional[str] = None
+    capath: Optional[str] = None
+    if isinstance(verify, str):
+        if os.path.isfile(verify):
+            cafile = verify
+        elif os.path.isdir(verify):
+            capath = verify
+        else:
+            raise ValueError("invalid verify string")
+    ssl_context = ssl.create_default_context(cafile=cafile, capath=capath)
+    ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2
+    if server_hostname is None:
+        ssl_context.check_hostname = False
+    ssl_context.set_alpn_protocols(["dot"])
+    if verify is False:
+        ssl_context.verify_mode = ssl.CERT_NONE
+    return ssl_context
+
+
 def tls(
     q: dns.message.Message,
     where: str,
@@ -1024,6 +1049,7 @@ def tls(
     sock: Optional[ssl.SSLSocket] = None,
     ssl_context: Optional[ssl.SSLContext] = None,
     server_hostname: Optional[str] = None,
+    verify: Union[bool, str] = True,
 ) -> dns.message.Message:
     """Return the response obtained after sending a query via TLS.
 
@@ -1063,6 +1089,11 @@ def tls(
     default is ``None``, which means that no hostname is known, and if an
     SSL context is created, hostname checking will be disabled.
 
+    *verify*, a ``bool`` or ``str``.  If a ``True``, then TLS certificate verification
+    of the server is done using the default CA bundle; if ``False``, then no
+    verification is done; if a `str` then it specifies the path to a certificate file or
+    directory which will be used for verification.
+
     Returns a ``dns.message.Message``.
 
     """
@@ -1089,11 +1120,7 @@ def tls(
         where, port, source, source_port
     )
     if ssl_context is None and not sock:
-        ssl_context = ssl.create_default_context()
-        ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2
-        if server_hostname is None:
-            ssl_context.check_hostname = False
-        ssl_context.set_alpn_protocols(["dot"])
+        ssl_context = _make_dot_ssl_context(server_hostname, verify)
 
     with _make_socket(
         af,