]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
QUIC code should process verify correctly when given a directory (#1179)
authorBob Halley <halley@dnspython.org>
Wed, 29 Jan 2025 19:36:31 +0000 (11:36 -0800)
committerGitHub <noreply@github.com>
Wed, 29 Jan 2025 19:36:31 +0000 (11:36 -0800)
path. [#1174]

dns/_tls_util.py [new file with mode: 0644]
dns/query.py
dns/quic/_common.py

diff --git a/dns/_tls_util.py b/dns/_tls_util.py
new file mode 100644 (file)
index 0000000..79c421d
--- /dev/null
@@ -0,0 +1,19 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+import os
+from typing import Optional, Tuple, Union
+
+
+def convert_verify_to_cafile_and_capath(
+    verify: Union[bool, str],
+) -> Tuple[Optional[str], Optional[str]]:
+    cafile: Optional[str] = None
+    capath: Optional[str] = None
+    if isinstance(verify, str):
+        if os.path.isfile(verify):
+            cafile = verify
+        elif os.path.isdir(verify):
+            capath = verify
+        else:
+            raise ValueError("invalid verify string")
+    return cafile, capath
index 5af4a36c3dafb6873a4cb200c6ae411602c2a41a..b7ebe1ecdb857490cdba0d28ca9daafcbae6f668 100644 (file)
@@ -32,6 +32,7 @@ import urllib.parse
 from typing import Any, Dict, Optional, Tuple, Union, cast
 
 import dns._features
+import dns._tls_util
 import dns.exception
 import dns.inet
 import dns.message
@@ -1213,15 +1214,7 @@ def _tls_handshake(s, expiration):
 def _make_dot_ssl_context(
     server_hostname: Optional[str], verify: Union[bool, str]
 ) -> ssl.SSLContext:
-    cafile: Optional[str] = None
-    capath: Optional[str] = None
-    if isinstance(verify, str):
-        if os.path.isfile(verify):
-            cafile = verify
-        elif os.path.isdir(verify):
-            capath = verify
-        else:
-            raise ValueError("invalid verify string")
+    cafile, capath = dns._tls_util.convert_verify_to_cafile_and_capath(verify)
     ssl_context = ssl.create_default_context(cafile=cafile, capath=capath)
     ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2
     if server_hostname is None:
index 930cf660cacbb4ca413110a6a9e567f67107d8b3..d21ceea6181bfa8460fa39963261bae320ffed16 100644 (file)
@@ -14,6 +14,7 @@ import aioquic.h3.events  # type: ignore
 import aioquic.quic.configuration  # type: ignore
 import aioquic.quic.connection  # type: ignore
 
+import dns._tls_util
 import dns.inet
 
 QUIC_MAX_DATAGRAM = 2048
@@ -245,7 +246,10 @@ class BaseQuicManager:
                 server_name=server_name,
             )
             if verify_path is not None:
-                conf.load_verify_locations(verify_path)
+                cafile, capath = dns._tls_util.convert_verify_to_cafile_and_capath(
+                    verify_path
+                )
+                conf.load_verify_locations(cafile=cafile, capath=capath)
         self._conf = conf
 
     def _connect(