]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Optionally allow server hostname to be checked by QUIC.
authorBob Halley <halley@dnspython.org>
Sun, 26 Mar 2023 19:28:02 +0000 (12:28 -0700)
committerBob Halley <halley@dnspython.org>
Sun, 26 Mar 2023 19:28:02 +0000 (12:28 -0700)
dns/asyncquery.py
dns/query.py
dns/quic/_asyncio.py
dns/quic/_common.py
dns/quic/_sync.py
dns/quic/_trio.py

index ea5391165176e4a95d115bcc45ea4f7b9ec15355..a2bd06e108ba339f2bb1190494e1826376d65e1b 100644 (file)
@@ -710,6 +710,7 @@ async def quic(
     connection: Optional[dns.quic.AsyncQuicConnection] = None,
     verify: Union[bool, str] = True,
     backend: Optional[dns.asyncbackend.Backend] = None,
+    server_hostname: Optional[str] = None,
 ) -> dns.message.Message:
     """Return the response obtained after sending an asynchronous query via
     DNS-over-QUIC.
@@ -735,7 +736,9 @@ async def quic(
         (cfactory, mfactory) = dns.quic.factories_for_backend(backend)
 
     async with cfactory() as context:
-        async with mfactory(context, verify_mode=verify) as the_manager:
+        async with mfactory(
+            context, verify_mode=verify, server_name=server_hostname
+        ) as the_manager:
             if not connection:
                 the_connection = the_manager.connect(where, port, source, source_port)
             start = time.time()
index f34330ff9082045d3b2d74873fa7696db0cea3d1..7ba3add17f2651ab256b7d799bcf0bbc54defc39 100644 (file)
@@ -1127,6 +1127,7 @@ def quic(
     ignore_trailing: bool = False,
     connection: Optional[dns.quic.SyncQuicConnection] = None,
     verify: Union[bool, str] = True,
+    server_hostname: Optional[str] = None,
 ) -> dns.message.Message:
     """Return the response obtained after sending a query via DNS-over-QUIC.
 
@@ -1158,6 +1159,10 @@ def quic(
     verification is done; if a `str` then it specifies the path to a certificate file or
     directory which will be used for verification.
 
+    *server_hostname*, a ``str`` containing the server's hostname.  The
+    default is ``None``, which means that no hostname is known, and if an
+    SSL context is created, hostname checking will be disabled.
+
     Returns a ``dns.message.Message``.
     """
 
@@ -1172,7 +1177,9 @@ def quic(
         manager: contextlib.AbstractContextManager = contextlib.nullcontext(None)
         the_connection = connection
     else:
-        manager = dns.quic.SyncQuicManager(verify_mode=verify)
+        manager = dns.quic.SyncQuicManager(
+            verify_mode=verify, server_name=server_hostname
+        )
         the_manager = manager  # for type checking happiness
 
     with manager:
index 80f244d125c8c358cf51a40a01c9037d0b0282d1..69d884c1cccc4db2f7466f3fbc5f8bf70df5bfe4 100644 (file)
@@ -193,8 +193,8 @@ class AsyncioQuicConnection(AsyncQuicConnection):
 
 
 class AsyncioQuicManager(AsyncQuicManager):
-    def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED):
-        super().__init__(conf, verify_mode, AsyncioQuicConnection)
+    def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None):
+        super().__init__(conf, verify_mode, AsyncioQuicConnection, server_name)
 
     def connect(self, address, port=853, source=None, source_port=0):
         (connection, start) = self._connect(address, port, source, source_port)
index d8f6f7fd6ec3734a9cffeec3ceef2294f2e46ed6..625fab7f3d46f56665db20441e29d54a3e3b3b3d 100644 (file)
@@ -140,7 +140,7 @@ class AsyncQuicConnection(BaseQuicConnection):
 
 
 class BaseQuicManager:
-    def __init__(self, conf, verify_mode, connection_factory):
+    def __init__(self, conf, verify_mode, connection_factory, server_name=None):
         self._connections = {}
         self._connection_factory = connection_factory
         if conf is None:
@@ -151,6 +151,7 @@ class BaseQuicManager:
             conf = aioquic.quic.configuration.QuicConfiguration(
                 alpn_protocols=["doq", "doq-i03"],
                 verify_mode=verify_mode,
+                server_name=server_name
             )
             if verify_path is not None:
                 conf.load_verify_locations(verify_path)
index bc034fa93569766b28a57f1e0cf0b9b117187c4a..bc9f172c1d870068994f1e072b24fef7574e48f5 100644 (file)
@@ -197,8 +197,8 @@ class SyncQuicConnection(BaseQuicConnection):
 
 
 class SyncQuicManager(BaseQuicManager):
-    def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED):
-        super().__init__(conf, verify_mode, SyncQuicConnection)
+    def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None):
+        super().__init__(conf, verify_mode, SyncQuicConnection, server_name)
         self._lock = threading.Lock()
 
     def connect(self, address, port=853, source=None, source_port=0):
index 7f81061c970ac550dc93da7c4909918fecc22053..38eab3e967a2a07666b5e5ae9ffe6aaea2260f46 100644 (file)
@@ -157,8 +157,10 @@ class TrioQuicConnection(AsyncQuicConnection):
 
 
 class TrioQuicManager(AsyncQuicManager):
-    def __init__(self, nursery, conf=None, verify_mode=ssl.CERT_REQUIRED):
-        super().__init__(conf, verify_mode, TrioQuicConnection)
+    def __init__(
+        self, nursery, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None
+    ):
+        super().__init__(conf, verify_mode, TrioQuicConnection, server_name)
         self._nursery = nursery
 
     def connect(self, address, port=853, source=None, source_port=0):