]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Add support for persistent H3 connections. (#1184)
authorBrian Wellington <bwelling@xbill.org>
Thu, 20 Mar 2025 14:43:40 +0000 (07:43 -0700)
committerGitHub <noreply@github.com>
Thu, 20 Mar 2025 14:43:40 +0000 (07:43 -0700)
* Add support for persistent H3 connections.

* Make mypy happy.

* Make pyright happy.

dns/asyncquery.py
dns/query.py

index 883e8afc06ef10f576a1d5faa5054f1a97e641df..03e384adf005d499aa88ceed66ed393ecc753ea0 100644 (file)
@@ -536,7 +536,7 @@ async def https(
     source_port: int = 0,  # pylint: disable=W0613
     one_rr_per_rrset: bool = False,
     ignore_trailing: bool = False,
-    client: Optional["httpx.AsyncClient"] = None,
+    client: Optional["httpx.AsyncClient|dns.quic.AsyncQuicConnection"] = None,
     path: str = "/dns-query",
     post: bool = True,
     verify: Union[bool, str] = True,
@@ -591,6 +591,9 @@ async def https(
                 parsed.hostname, family  # pyright: ignore
             )
             bootstrap_address = random.choice(list(answers.addresses()))
+        if client and not isinstance(client, dns.quic.AsyncQuicConnection):  # pyright: ignore
+            raise ValueError("client parameter must be a dns.quic.AsyncQuicConnection.")
+        assert client is None or isinstance(client, dns.quic.AsyncQuicConnection)
         return await _http3(
             q,
             bootstrap_address,
@@ -603,13 +606,14 @@ async def https(
             ignore_trailing,
             verify=verify,
             post=post,
+            connection=client,
         )
 
     if not have_doh:
         raise NoDOH  # pragma: no cover
     # pylint: disable=possibly-used-before-assignment
     if client and not isinstance(client, httpx.AsyncClient):  # pyright: ignore
-        raise ValueError("session parameter must be an httpx.AsyncClient")
+        raise ValueError("client parameter must be an httpx.AsyncClient")
     # pylint: enable=possibly-used-before-assignment
 
     wire = q.to_wire()
@@ -711,6 +715,7 @@ async def _http3(
     backend: Optional[dns.asyncbackend.Backend] = None,
     hostname: Optional[str] = None,
     post: bool = True,
+    connection: Optional[dns.quic.AsyncQuicConnection] = None,
 ) -> dns.message.Message:
     if not dns.quic.have_quic:
         raise NoDOH("DNS-over-HTTP3 is not available.")  # pragma: no cover
@@ -722,15 +727,25 @@ async def _http3(
 
     q.id = 0
     wire = q.to_wire()
-    (cfactory, mfactory) = dns.quic.factories_for_backend(backend)
+    the_connection: dns.quic.AsyncQuicConnection
+    if connection:
+        cfactory = dns.quic.null_factory
+        mfactory = dns.quic.null_factory
+    else:
+        (cfactory, mfactory) = dns.quic.factories_for_backend(backend)
 
     async with cfactory() as context:
         async with mfactory(
             context, verify_mode=verify, server_name=hostname, h3=True
         ) as the_manager:
-            the_connection = the_manager.connect(where, port, source, source_port)
+            if connection:
+                the_connection = connection
+            else:
+                the_connection = the_manager.connect(  # pyright: ignore
+                    where, port, source, source_port
+                )
             (start, expiration) = _compute_times(timeout)
-            stream = await the_connection.make_stream(timeout)
+            stream = await the_connection.make_stream(timeout)  # pyright: ignore
             async with stream:
                 # note that send_h3() does not need await
                 stream.send_h3(url, wire, post)
index b7ebe1ecdb857490cdba0d28ca9daafcbae6f668..b81ffd18a9802e80a556adbe905b4cbb9f4483b6 100644 (file)
@@ -491,6 +491,8 @@ def https(
             assert parsed.hostname is not None  # pyright: ignore
             answers = resolver.resolve_name(parsed.hostname, family)  # pyright: ignore
             bootstrap_address = random.choice(list(answers.addresses()))
+        if session and not isinstance(session, dns.quic.SyncQuicConnection):  # pyright: ignore
+            raise ValueError("session parameter must be a dns.quic.SyncQuicConnection.")
         return _http3(
             q,
             bootstrap_address,
@@ -503,6 +505,7 @@ def https(
             ignore_trailing,
             verify=verify,
             post=post,
+            connection=session,
         )
 
     if not have_doh:
@@ -629,6 +632,7 @@ def _http3(
     verify: Union[bool, str] = True,
     hostname: Optional[str] = None,
     post: bool = True,
+    connection: Optional[dns.quic.SyncQuicConnection] = None,
 ) -> dns.message.Message:
     if not dns.quic.have_quic:
         raise NoDOH("DNS-over-HTTP3 is not available.")  # pragma: no cover
@@ -640,14 +644,25 @@ def _http3(
 
     q.id = 0
     wire = q.to_wire()
-    manager = dns.quic.SyncQuicManager(
-        verify_mode=verify, server_name=hostname, h3=True  # pyright: ignore
-    )
+    the_connection: dns.quic.SyncQuicConnection
+    the_manager: dns.quic.SyncQuicManager
+    if connection:
+        manager: contextlib.AbstractContextManager = contextlib.nullcontext(None)
+    else:
+        manager = dns.quic.SyncQuicManager(
+            verify_mode=verify, server_name=hostname, h3=True  # pyright: ignore
+        )
+        the_manager = manager  # for type checking happiness
 
     with manager:
-        connection = manager.connect(where, port, source, source_port)
+        if connection:
+            the_connection = connection
+        else:
+            the_connection = the_manager.connect(  # pyright: ignore
+                where, port, source, source_port
+            )
         (start, expiration) = _compute_times(timeout)
-        with connection.make_stream(timeout) as stream:
+        with the_connection.make_stream(timeout) as stream:  # pyright: ignore
             stream.send_h3(url, wire, post)
             wire = stream.receive(_remaining(expiration))
             _check_status(stream.headers(), where, wire)