From: Bob Halley Date: Sun, 26 Mar 2023 19:28:02 +0000 (-0700) Subject: Optionally allow server hostname to be checked by QUIC. X-Git-Tag: v2.4.0rc1~39 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=4c25a1ef00f3e16ced7d5b487a3e07b3b1580ed4;p=thirdparty%2Fdnspython.git Optionally allow server hostname to be checked by QUIC. --- diff --git a/dns/asyncquery.py b/dns/asyncquery.py index ea539116..a2bd06e1 100644 --- a/dns/asyncquery.py +++ b/dns/asyncquery.py @@ -710,6 +710,7 @@ async def quic( connection: Optional[dns.quic.AsyncQuicConnection] = None, verify: Union[bool, str] = True, backend: Optional[dns.asyncbackend.Backend] = None, + server_hostname: Optional[str] = None, ) -> dns.message.Message: """Return the response obtained after sending an asynchronous query via DNS-over-QUIC. @@ -735,7 +736,9 @@ async def quic( (cfactory, mfactory) = dns.quic.factories_for_backend(backend) async with cfactory() as context: - async with mfactory(context, verify_mode=verify) as the_manager: + async with mfactory( + context, verify_mode=verify, server_name=server_hostname + ) as the_manager: if not connection: the_connection = the_manager.connect(where, port, source, source_port) start = time.time() diff --git a/dns/query.py b/dns/query.py index f34330ff..7ba3add1 100644 --- a/dns/query.py +++ b/dns/query.py @@ -1127,6 +1127,7 @@ def quic( ignore_trailing: bool = False, connection: Optional[dns.quic.SyncQuicConnection] = None, verify: Union[bool, str] = True, + server_hostname: Optional[str] = None, ) -> dns.message.Message: """Return the response obtained after sending a query via DNS-over-QUIC. @@ -1158,6 +1159,10 @@ def quic( verification is done; if a `str` then it specifies the path to a certificate file or directory which will be used for verification. + *server_hostname*, a ``str`` containing the server's hostname. The + default is ``None``, which means that no hostname is known, and if an + SSL context is created, hostname checking will be disabled. + Returns a ``dns.message.Message``. """ @@ -1172,7 +1177,9 @@ def quic( manager: contextlib.AbstractContextManager = contextlib.nullcontext(None) the_connection = connection else: - manager = dns.quic.SyncQuicManager(verify_mode=verify) + manager = dns.quic.SyncQuicManager( + verify_mode=verify, server_name=server_hostname + ) the_manager = manager # for type checking happiness with manager: diff --git a/dns/quic/_asyncio.py b/dns/quic/_asyncio.py index 80f244d1..69d884c1 100644 --- a/dns/quic/_asyncio.py +++ b/dns/quic/_asyncio.py @@ -193,8 +193,8 @@ class AsyncioQuicConnection(AsyncQuicConnection): class AsyncioQuicManager(AsyncQuicManager): - def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED): - super().__init__(conf, verify_mode, AsyncioQuicConnection) + def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None): + super().__init__(conf, verify_mode, AsyncioQuicConnection, server_name) def connect(self, address, port=853, source=None, source_port=0): (connection, start) = self._connect(address, port, source, source_port) diff --git a/dns/quic/_common.py b/dns/quic/_common.py index d8f6f7fd..625fab7f 100644 --- a/dns/quic/_common.py +++ b/dns/quic/_common.py @@ -140,7 +140,7 @@ class AsyncQuicConnection(BaseQuicConnection): class BaseQuicManager: - def __init__(self, conf, verify_mode, connection_factory): + def __init__(self, conf, verify_mode, connection_factory, server_name=None): self._connections = {} self._connection_factory = connection_factory if conf is None: @@ -151,6 +151,7 @@ class BaseQuicManager: conf = aioquic.quic.configuration.QuicConfiguration( alpn_protocols=["doq", "doq-i03"], verify_mode=verify_mode, + server_name=server_name ) if verify_path is not None: conf.load_verify_locations(verify_path) diff --git a/dns/quic/_sync.py b/dns/quic/_sync.py index bc034fa9..bc9f172c 100644 --- a/dns/quic/_sync.py +++ b/dns/quic/_sync.py @@ -197,8 +197,8 @@ class SyncQuicConnection(BaseQuicConnection): class SyncQuicManager(BaseQuicManager): - def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED): - super().__init__(conf, verify_mode, SyncQuicConnection) + def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None): + super().__init__(conf, verify_mode, SyncQuicConnection, server_name) self._lock = threading.Lock() def connect(self, address, port=853, source=None, source_port=0): diff --git a/dns/quic/_trio.py b/dns/quic/_trio.py index 7f81061c..38eab3e9 100644 --- a/dns/quic/_trio.py +++ b/dns/quic/_trio.py @@ -157,8 +157,10 @@ class TrioQuicConnection(AsyncQuicConnection): class TrioQuicManager(AsyncQuicManager): - def __init__(self, nursery, conf=None, verify_mode=ssl.CERT_REQUIRED): - super().__init__(conf, verify_mode, TrioQuicConnection) + def __init__( + self, nursery, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None + ): + super().__init__(conf, verify_mode, TrioQuicConnection, server_name) self._nursery = nursery def connect(self, address, port=853, source=None, source_port=0):