]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
DNS-over-HTTP3 (#1048)
authorBob Halley <halley@dnspython.org>
Sat, 24 Feb 2024 13:33:57 +0000 (05:33 -0800)
committerGitHub <noreply@github.com>
Sat, 24 Feb 2024 13:33:57 +0000 (05:33 -0800)
* Implement DNS-over-HTTP3 using aioquic directly.

* Add h3 support for DoHNameserver.

dns/asyncquery.py
dns/nameserver.py
dns/query.py
dns/quic/__init__.py
dns/quic/_asyncio.py
dns/quic/_common.py
dns/quic/_sync.py
dns/quic/_trio.py
tests/test_async.py
tests/test_doh.py

index 4d9ab9ae49385e83515143ced8a04b01938fcab1..e5ebad41c9a282a5208b6629a85311696e108d63 100644 (file)
 
 import base64
 import contextlib
+import random
 import socket
 import struct
 import time
+import urllib.parse
 from typing import Any, Dict, Optional, Tuple, Union
 
 import dns.asyncbackend
@@ -40,6 +42,7 @@ from dns.query import (
     NoDOH,
     NoDOQ,
     UDPMode,
+    _check_status,
     _compute_times,
     _make_dot_ssl_context,
     _matches_destination,
@@ -500,6 +503,20 @@ async def tls(
         return response
 
 
+def _maybe_get_resolver(
+    resolver: Optional["dns.asyncresolver.Resolver"],
+) -> "dns.asyncresolver.Resolver":
+    # We need a separate method for this to avoid overriding the global
+    # variable "dns" with the as-yet undefined local variable "dns"
+    # in https().
+    if resolver is None:
+        # pylint: disable=import-outside-toplevel,redefined-outer-name
+        import dns.asyncresolver
+
+        resolver = dns.asyncresolver.Resolver()
+    return resolver
+
+
 async def https(
     q: dns.message.Message,
     where: str,
@@ -515,7 +532,8 @@ async def https(
     verify: Union[bool, str] = True,
     bootstrap_address: Optional[str] = None,
     resolver: Optional["dns.asyncresolver.Resolver"] = None,
-    family: Optional[int] = socket.AF_UNSPEC,
+    family: int = socket.AF_UNSPEC,
+    h3: bool = False,
 ) -> dns.message.Message:
     """Return the response obtained after sending a query via DNS-over-HTTPS.
 
@@ -529,18 +547,10 @@ async def https(
     parameters, exceptions, and return type of this method.
     """
 
-    if not have_doh:
-        raise NoDOH  # pragma: no cover
-    if client and not isinstance(client, httpx.AsyncClient):
-        raise ValueError("session parameter must be an httpx.AsyncClient")
-
-    wire = q.to_wire()
     try:
         af = dns.inet.af_for_address(where)
     except ValueError:
         af = None
-    transport = None
-    headers = {"accept": "application/dns-message"}
     if af is not None and dns.inet.is_address(where):
         if af == socket.AF_INET:
             url = "https://{}:{}{}".format(where, port, path)
@@ -549,6 +559,39 @@ async def https(
     else:
         url = where
 
+    if h3:
+        if bootstrap_address is None:
+            parsed = urllib.parse.urlparse(url)
+            resolver = _maybe_get_resolver(resolver)
+            if parsed.hostname is None:
+                raise ValueError("no hostname in URL")
+            answers = await resolver.resolve_name(parsed.hostname, family)
+            bootstrap_address = random.choice(list(answers.addresses()))
+            if parsed.port is not None:
+                port = parsed.port
+        return await _http3(
+            q,
+            bootstrap_address,
+            url,
+            timeout,
+            port,
+            source,
+            source_port,
+            one_rr_per_rrset,
+            ignore_trailing,
+            verify=verify,
+            post=post,
+        )
+
+    if not have_doh:
+        raise NoDOH  # pragma: no cover
+    if client and not isinstance(client, httpx.AsyncClient):
+        raise ValueError("session parameter must be an httpx.AsyncClient")
+
+    wire = q.to_wire()
+    transport = None
+    headers = {"accept": "application/dns-message"}
+
     backend = dns.asyncbackend.get_default_backend()
 
     if source is None:
@@ -617,6 +660,57 @@ async def https(
     return r
 
 
+async def _http3(
+    q: dns.message.Message,
+    where: str,
+    url: str,
+    timeout: Optional[float] = None,
+    port: int = 853,
+    source: Optional[str] = None,
+    source_port: int = 0,
+    one_rr_per_rrset: bool = False,
+    ignore_trailing: bool = False,
+    verify: Union[bool, str] = True,
+    backend: Optional[dns.asyncbackend.Backend] = None,
+    hostname: Optional[str] = None,
+    post: bool = True,
+) -> dns.message.Message:
+    if not dns.quic.have_quic:
+        raise NoDOH("DNS-over-HTTP3 is not available.")  # pragma: no cover
+
+    url_parts = urllib.parse.urlparse(url)
+    hostname = url_parts.hostname
+
+    q.id = 0
+    wire = q.to_wire()
+    (cfactory, mfactory) = dns.quic.factories_for_backend(backend)
+
+    async with cfactory() as context:
+        async with mfactory(
+            context, verify_mode=verify, server_name=hostname, h3=True
+        ) as the_manager:
+            the_connection = the_manager.connect(where, port, source, source_port)
+            (start, expiration) = _compute_times(timeout)
+            stream = await the_connection.make_stream(timeout)
+            async with stream:
+                # note that send_h3() does not need await
+                stream.send_h3(url, wire, post)
+                wire = await stream.receive(_remaining(expiration))
+                _check_status(stream.headers(), where, wire)
+            finish = time.time()
+        r = dns.message.from_wire(
+            wire,
+            keyring=q.keyring,
+            request_mac=q.request_mac,
+            one_rr_per_rrset=one_rr_per_rrset,
+            ignore_trailing=ignore_trailing,
+        )
+    r.time = max(finish - start, 0.0)
+    if not q.is_response(r):
+        raise BadResponse
+    return r
+
+
 async def inbound_xfr(
     where: str,
     txn_manager: dns.transaction.TransactionManager,
@@ -730,6 +824,7 @@ async def quic(
     connection: Optional[dns.quic.AsyncQuicConnection] = None,
     verify: Union[bool, str] = True,
     backend: Optional[dns.asyncbackend.Backend] = None,
+    hostname: Optional[str] = None,
     server_hostname: Optional[str] = None,
 ) -> dns.message.Message:
     """Return the response obtained after sending an asynchronous query via
@@ -745,6 +840,9 @@ async def quic(
     if not dns.quic.have_quic:
         raise NoDOQ("DNS-over-QUIC is not available.")  # pragma: no cover
 
+    if server_hostname is not None and hostname is None:
+        hostname = server_hostname
+
     q.id = 0
     wire = q.to_wire()
     the_connection: dns.quic.AsyncQuicConnection
@@ -757,7 +855,9 @@ async def quic(
 
     async with cfactory() as context:
         async with mfactory(
-            context, verify_mode=verify, server_name=server_hostname
+            context,
+            verify_mode=verify,
+            server_name=server_hostname,
         ) as the_manager:
             if not connection:
                 the_connection = the_manager.connect(where, port, source, source_port)
index 5dbb4e8baf00a4086f54ee6b796de6a89e462fb2..e8068e7e456cf8fe66455f3ec5663a00da4f0631 100644 (file)
@@ -168,12 +168,14 @@ class DoHNameserver(Nameserver):
         bootstrap_address: Optional[str] = None,
         verify: Union[bool, str] = True,
         want_get: bool = False,
+        h3: bool = False,
     ):
         super().__init__()
         self.url = url
         self.bootstrap_address = bootstrap_address
         self.verify = verify
         self.want_get = want_get
+        self.h3 = h3
 
     def kind(self):
         return "DoH"
@@ -214,6 +216,7 @@ class DoHNameserver(Nameserver):
             ignore_trailing=ignore_trailing,
             verify=self.verify,
             post=(not self.want_get),
+            h3=self.h3,
         )
 
     async def async_query(
@@ -238,6 +241,7 @@ class DoHNameserver(Nameserver):
             ignore_trailing=ignore_trailing,
             verify=self.verify,
             post=(not self.want_get),
+            h3=self.h3,
         )
 
 
index 384bf31e388f0f09f4d2e6696f038c0ee4d1f150..8f82ab676789c46ae137ca25a6409fb86476e70b 100644 (file)
@@ -23,11 +23,13 @@ import enum
 import errno
 import os
 import os.path
+import random
 import selectors
 import socket
 import struct
 import time
-from typing import Any, Dict, Optional, Tuple, Union
+import urllib.parse
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import dns._features
 import dns.exception
@@ -335,6 +337,20 @@ def _make_socket(af, type, source, ssl_context=None, server_hostname=None):
         raise
 
 
+def _maybe_get_resolver(
+    resolver: Optional["dns.resolver.Resolver"],
+) -> "dns.resolver.Resolver":
+    # We need a separate method for this to avoid overriding the global
+    # variable "dns" with the as-yet undefined local variable "dns"
+    # in https().
+    if resolver is None:
+        # pylint: disable=import-outside-toplevel,redefined-outer-name
+        import dns.resolver
+
+        resolver = dns.resolver.Resolver()
+    return resolver
+
+
 def https(
     q: dns.message.Message,
     where: str,
@@ -350,7 +366,8 @@ def https(
     bootstrap_address: Optional[str] = None,
     verify: Union[bool, str] = True,
     resolver: Optional["dns.resolver.Resolver"] = None,
-    family: Optional[int] = socket.AF_UNSPEC,
+    family: int = socket.AF_UNSPEC,
+    h3: bool = False,
 ) -> dns.message.Message:
     """Return the response obtained after sending a query via DNS-over-HTTPS.
 
@@ -400,20 +417,14 @@ def https(
     *family*, an ``int``, the address family.  If socket.AF_UNSPEC (the default), both A
     and AAAA records will be retrieved.
 
+    *h3*, a ``bool``.  If ``True``, use HTTP/3 otherwise use HTTP/2 or HTTP/1.1.
+
     Returns a ``dns.message.Message``.
     """
 
-    if not have_doh:
-        raise NoDOH  # pragma: no cover
-    if session and not isinstance(session, httpx.Client):
-        raise ValueError("session parameter must be an httpx.Client")
-
-    wire = q.to_wire()
     (af, _, the_source) = _destination_and_source(
         where, port, source, source_port, False
     )
-    transport = None
-    headers = {"accept": "application/dns-message"}
     if af is not None and dns.inet.is_address(where):
         if af == socket.AF_INET:
             url = "https://{}:{}{}".format(where, port, path)
@@ -422,6 +433,39 @@ def https(
     else:
         url = where
 
+    if h3:
+        if bootstrap_address is None:
+            parsed = urllib.parse.urlparse(url)
+            resolver = _maybe_get_resolver(resolver)
+            if parsed.hostname is None:
+                raise ValueError("no hostname in URL")
+            answers = resolver.resolve_name(parsed.hostname, family)
+            bootstrap_address = random.choice(list(answers.addresses()))
+            if parsed.port is not None:
+                port = parsed.port
+        return _http3(
+            q,
+            bootstrap_address,
+            url,
+            timeout,
+            port,
+            source,
+            source_port,
+            one_rr_per_rrset,
+            ignore_trailing,
+            verify=verify,
+            post=post,
+        )
+
+    if not have_doh:
+        raise NoDOH  # pragma: no cover
+    if session and not isinstance(session, httpx.Client):
+        raise ValueError("session parameter must be an httpx.Client")
+
+    wire = q.to_wire()
+    transport = None
+    headers = {"accept": "application/dns-message"}
+
     # set source port and source address
 
     if the_source is None:
@@ -483,6 +527,79 @@ def https(
     return r
 
 
+def _find_header(headers: dns.quic.Headers, name: bytes) -> bytes:
+    if headers is None:
+        raise KeyError
+    for header, value in headers:
+        if header == name:
+            return value
+    raise KeyError
+
+
+def _check_status(headers: dns.quic.Headers, peer: str, wire: bytes) -> None:
+    value = _find_header(headers, b":status")
+    if value is None:
+        raise SyntaxError("no :status header in response")
+    status = int(value)
+    if status < 0:
+        raise SyntaxError("status is negative")
+    if status < 200 or status > 299:
+        error = ""
+        if len(wire) > 0:
+            try:
+                error = ": " + wire.decode()
+            except Exception:
+                pass
+        raise ValueError(f"{peer} responded with status code {status}{error}")
+
+
+def _http3(
+    q: dns.message.Message,
+    where: str,
+    url: str,
+    timeout: Optional[float] = None,
+    port: int = 853,
+    source: Optional[str] = None,
+    source_port: int = 0,
+    one_rr_per_rrset: bool = False,
+    ignore_trailing: bool = False,
+    verify: Union[bool, str] = True,
+    hostname: Optional[str] = None,
+    post: bool = True,
+) -> dns.message.Message:
+    if not dns.quic.have_quic:
+        raise NoDOH("DNS-over-HTTP3 is not available.")  # pragma: no cover
+
+    url_parts = urllib.parse.urlparse(url)
+    hostname = url_parts.hostname
+
+    q.id = 0
+    wire = q.to_wire()
+    manager = dns.quic.SyncQuicManager(
+        verify_mode=verify, server_name=hostname, h3=True
+    )
+
+    with manager:
+        connection = manager.connect(where, port, source, source_port)
+        (start, expiration) = _compute_times(timeout)
+        with connection.make_stream(timeout) as stream:
+            stream.send_h3(url, wire, post)
+            wire = stream.receive(_remaining(expiration))
+            _check_status(stream.headers(), where, wire)
+        finish = time.time()
+    r = dns.message.from_wire(
+        wire,
+        keyring=q.keyring,
+        request_mac=q.request_mac,
+        one_rr_per_rrset=one_rr_per_rrset,
+        ignore_trailing=ignore_trailing,
+    )
+    r.time = max(finish - start, 0.0)
+    if not q.is_response(r):
+        raise BadResponse
+    return r
+
+
 def _udp_recv(sock, max_size, expiration):
     """Reads a datagram from the socket.
     A Timeout exception will be raised if the operation is not completed
@@ -1168,6 +1285,7 @@ def quic(
     ignore_trailing: bool = False,
     connection: Optional[dns.quic.SyncQuicConnection] = None,
     verify: Union[bool, str] = True,
+    hostname: Optional[str] = None,
     server_hostname: Optional[str] = None,
 ) -> dns.message.Message:
     """Return the response obtained after sending a query via DNS-over-QUIC.
@@ -1192,17 +1310,21 @@ def quic(
     *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the
     received message.
 
-    *connection*, a ``dns.quic.SyncQuicConnection``.  If provided, the
-    connection to use to send the query.
+    *connection*, a ``dns.quic.SyncQuicConnection``.  If provided, the connection to use
+    to send the query.
 
     *verify*, a ``bool`` or ``str``.  If a ``True``, then TLS certificate verification
     of the server is done using the default CA bundle; if ``False``, then no
     verification is done; if a `str` then it specifies the path to a certificate file or
     directory which will be used for verification.
 
-    *server_hostname*, a ``str`` containing the server's hostname.  The
-    default is ``None``, which means that no hostname is known, and if an
-    SSL context is created, hostname checking will be disabled.
+    *hostname*, a ``str`` containing the server's hostname or ``None``.  The default is
+    ``None``, which means that no hostname is known, and if an SSL context is created,
+    hostname checking will be disabled.  This value is ignored if *url* is not
+    ``None``.
+
+    *server_hostname*, a ``str`` or ``None``.  This item is for backwards compatibility
+    only, and has the same meaning as *hostname*.
 
     Returns a ``dns.message.Message``.
     """
@@ -1210,6 +1332,9 @@ def quic(
     if not dns.quic.have_quic:
         raise NoDOQ("DNS-over-QUIC is not available.")  # pragma: no cover
 
+    if server_hostname is not None and hostname is None:
+        hostname = server_hostname
+
     q.id = 0
     wire = q.to_wire()
     the_connection: dns.quic.SyncQuicConnection
@@ -1218,9 +1343,7 @@ def quic(
         manager: contextlib.AbstractContextManager = contextlib.nullcontext(None)
         the_connection = connection
     else:
-        manager = dns.quic.SyncQuicManager(
-            verify_mode=verify, server_name=server_hostname
-        )
+        manager = dns.quic.SyncQuicManager(verify_mode=verify, server_name=hostname)
         the_manager = manager  # for type checking happiness
 
     with manager:
index 20aff34552771a68db84e027912c4a461adadc6e..0750e729b4401e77bf3da1f8716e23e4538c1d24 100644 (file)
@@ -1,5 +1,7 @@
 # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
 
+from typing import List, Tuple
+
 import dns._features
 import dns.asyncbackend
 
@@ -73,3 +75,6 @@ else:  # pragma: no cover
     class SyncQuicConnection:  # type: ignore
         def make_stream(self) -> Any:
             raise NotImplementedError
+
+
+Headers = List[Tuple[bytes, bytes]]
index 0f44331f61830b3d9c7da6bb26b4f72e89744d64..069387f4f0c94592d119e35706c236f71dac04d5 100644 (file)
@@ -43,12 +43,26 @@ class AsyncioQuicStream(BaseQuicStream):
                 raise dns.exception.Timeout
             self._expecting = 0
 
+    async def wait_for_end(self, expiration):
+        while True:
+            timeout = self._timeout_from_expiration(expiration)
+            if self._buffer.seen_end():
+                return
+            try:
+                await asyncio.wait_for(self._wait_for_wake_up(), timeout)
+            except TimeoutError:
+                raise dns.exception.Timeout
+
     async def receive(self, timeout=None):
         expiration = self._expiration_from_timeout(timeout)
-        await self.wait_for(2, expiration)
-        (size,) = struct.unpack("!H", self._buffer.get(2))
-        await self.wait_for(size, expiration)
-        return self._buffer.get(size)
+        if self._connection.is_h3():
+            await self.wait_for_end(expiration)
+            return self._buffer.get_all()
+        else:
+            await self.wait_for(2, expiration)
+            (size,) = struct.unpack("!H", self._buffer.get(2))
+            await self.wait_for(size, expiration)
+            return self._buffer.get(size)
 
     async def send(self, datagram, is_end=False):
         data = self._encapsulate(datagram)
@@ -140,9 +154,28 @@ class AsyncioQuicConnection(AsyncQuicConnection):
             if event is None:
                 return
             if isinstance(event, aioquic.quic.events.StreamDataReceived):
-                stream = self._streams.get(event.stream_id)
-                if stream:
-                    await stream._add_input(event.data, event.end_stream)
+                if self.is_h3():
+                    h3_events = self._h3_conn.handle_event(event)
+                    for h3_event in h3_events:
+                        if isinstance(h3_event, aioquic.h3.events.HeadersReceived):
+                            stream = self._streams.get(event.stream_id)
+                            if stream:
+                                if stream._headers is None:
+                                    stream._headers = h3_event.headers
+                                elif stream._trailers is None:
+                                    stream._trailers = h3_event.headers
+                                if h3_event.stream_ended:
+                                    await stream._add_input(b"", True)
+                        elif isinstance(h3_event, aioquic.h3.events.DataReceived):
+                            stream = self._streams.get(event.stream_id)
+                            if stream:
+                                await stream._add_input(
+                                    h3_event.data, h3_event.stream_ended
+                                )
+                else:
+                    stream = self._streams.get(event.stream_id)
+                    if stream:
+                        await stream._add_input(event.data, event.end_stream)
             elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
                 self._handshake_complete.set()
             elif isinstance(event, aioquic.quic.events.ConnectionTerminated):
@@ -203,8 +236,10 @@ class AsyncioQuicConnection(AsyncQuicConnection):
 
 
 class AsyncioQuicManager(AsyncQuicManager):
-    def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None):
-        super().__init__(conf, verify_mode, AsyncioQuicConnection, server_name)
+    def __init__(
+        self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None, h3=False
+    ):
+        super().__init__(conf, verify_mode, AsyncioQuicConnection, server_name, h3)
 
     def connect(
         self, address, port=853, source=None, source_port=0, want_session_ticket=True
index 0eacc691aac712294ff24afbeef91b7dafbcb674..5e6c40d3c4fda845d28b5348151fb37a604ee436 100644 (file)
@@ -1,12 +1,16 @@
 # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
 
+import base64
 import copy
 import functools
 import socket
 import struct
 import time
+import urllib
 from typing import Any, Optional
 
+import aioquic.h3.connection  # type: ignore
+import aioquic.h3.events  # type: ignore
 import aioquic.quic.configuration  # type: ignore
 import aioquic.quic.connection  # type: ignore
 
@@ -51,6 +55,12 @@ class Buffer:
         self._buffer = self._buffer[amount:]
         return data
 
+    def get_all(self):
+        assert self.seen_end()
+        data = self._buffer
+        self._buffer = b""
+        return data
+
 
 class BaseQuicStream:
     def __init__(self, connection, stream_id):
@@ -58,10 +68,18 @@ class BaseQuicStream:
         self._stream_id = stream_id
         self._buffer = Buffer()
         self._expecting = 0
+        self._headers = None
+        self._trailers = None
 
     def id(self):
         return self._stream_id
 
+    def headers(self):
+        return self._headers
+
+    def trailers(self):
+        return self._trailers
+
     def _expiration_from_timeout(self, timeout):
         if timeout is not None:
             expiration = time.time() + timeout
@@ -77,16 +95,51 @@ class BaseQuicStream:
         return timeout
 
     # Subclass must implement receive() as sync / async and which returns a message
-    # or raises UnexpectedEOF.
+    # or raises.
+
+    # Subclass must implement send() as sync / async and which takes a message and
+    # an EOF indicator.
+
+    def send_h3(self, url, datagram, post=True):
+        if not self._connection.is_h3():
+            raise SyntaxError("cannot send H3 to a non-H3 connection")
+        url_parts = urllib.parse.urlparse(url)
+        path = url_parts.path.encode()
+        if post:
+            method = b"POST"
+        else:
+            method = b"GET"
+            path += b"?dns=" + base64.urlsafe_b64encode(datagram).rstrip(b"=")
+        headers = [
+            (b":method", method),
+            (b":scheme", url_parts.scheme.encode()),
+            (b":authority", url_parts.netloc.encode()),
+            (b":path", path),
+            (b"accept", b"application/dns-message"),
+        ]
+        if post:
+            headers.extend(
+                [
+                    (b"content-type", b"application/dns-message"),
+                    (b"content-length", str(len(datagram)).encode()),
+                ]
+            )
+        self._connection.send_headers(self._stream_id, headers, not post)
+        if post:
+            self._connection.send_data(self._stream_id, datagram, True)
 
     def _encapsulate(self, datagram):
+        if self._connection.is_h3():
+            return datagram
         l = len(datagram)
         return struct.pack("!H", l) + datagram
 
     def _common_add_input(self, data, is_end):
         self._buffer.put(data, is_end)
         try:
-            return self._expecting > 0 and self._buffer.have(self._expecting)
+            return (
+                self._expecting > 0 and self._buffer.have(self._expecting)
+            ) or self._buffer.seen_end
         except UnexpectedEOF:
             return True
 
@@ -97,7 +150,13 @@ class BaseQuicStream:
 
 class BaseQuicConnection:
     def __init__(
-        self, connection, address, port, source=None, source_port=0, manager=None
+        self,
+        connection,
+        address,
+        port,
+        source=None,
+        source_port=0,
+        manager=None,
     ):
         self._done = False
         self._connection = connection
@@ -106,6 +165,10 @@ class BaseQuicConnection:
         self._closed = False
         self._manager = manager
         self._streams = {}
+        if manager.is_h3():
+            self._h3_conn = aioquic.h3.connection.H3Connection(connection, False)
+        else:
+            self._h3_conn = None
         self._af = dns.inet.af_for_address(address)
         self._peer = dns.inet.low_level_address_tuple((address, port))
         if source is None and source_port != 0:
@@ -120,9 +183,18 @@ class BaseQuicConnection:
         else:
             self._source = None
 
+    def is_h3(self):
+        return self._h3_conn is not None
+
     def close_stream(self, stream_id):
         del self._streams[stream_id]
 
+    def send_headers(self, stream_id, headers, is_end=False):
+        self._h3_conn.send_headers(stream_id, headers, is_end)
+
+    def send_data(self, stream_id, data, is_end=False):
+        self._h3_conn.send_data(stream_id, data, is_end)
+
     def _get_timer_values(self, closed_is_special=True):
         now = time.time()
         expiration = self._connection.get_timer()
@@ -148,17 +220,24 @@ class AsyncQuicConnection(BaseQuicConnection):
 
 
 class BaseQuicManager:
-    def __init__(self, conf, verify_mode, connection_factory, server_name=None):
+    def __init__(
+        self, conf, verify_mode, connection_factory, server_name=None, h3=False
+    ):
         self._connections = {}
         self._connection_factory = connection_factory
         self._session_tickets = {}
+        self._h3 = h3
         if conf is None:
             verify_path = None
             if isinstance(verify_mode, str):
                 verify_path = verify_mode
                 verify_mode = True
+            if h3:
+                alpn_protocols = ["h3"]
+            else:
+                alpn_protocols = ["doq", "doq-i03"]
             conf = aioquic.quic.configuration.QuicConfiguration(
-                alpn_protocols=["doq", "doq-i03"],
+                alpn_protocols=alpn_protocols,
                 verify_mode=verify_mode,
                 server_name=server_name,
             )
@@ -167,7 +246,12 @@ class BaseQuicManager:
         self._conf = conf
 
     def _connect(
-        self, address, port=853, source=None, source_port=0, want_session_ticket=True
+        self,
+        address,
+        port=853,
+        source=None,
+        source_port=0,
+        want_session_ticket=True,
     ):
         connection = self._connections.get((address, port))
         if connection is not None:
@@ -207,6 +291,9 @@ class BaseQuicManager:
         except KeyError:
             pass
 
+    def is_h3(self):
+        return self._h3
+
     def save_session_ticket(self, address, port, ticket):
         # We rely on dictionaries keys() being in insertion order here.  We
         # can't just popitem() as that would be LIFO which is the opposite of
index 6ef5dc9423b52c87f5a6a0a4c50ddda0a44caedc..a1062f5869d827015be319f63656bba9e90d7c31 100644 (file)
@@ -43,14 +43,29 @@ class SyncQuicStream(BaseQuicStream):
                     raise dns.exception.Timeout
             self._expecting = 0
 
+    def wait_for_end(self, expiration):
+        while True:
+            timeout = self._timeout_from_expiration(expiration)
+            with self._lock:
+                if self._buffer.seen_end():
+                    return
+            with self._wake_up:
+                if not self._wake_up.wait(timeout):
+                    raise dns.exception.Timeout
+
     def receive(self, timeout=None):
         expiration = self._expiration_from_timeout(timeout)
-        self.wait_for(2, expiration)
-        with self._lock:
-            (size,) = struct.unpack("!H", self._buffer.get(2))
-        self.wait_for(size, expiration)
-        with self._lock:
-            return self._buffer.get(size)
+        if self._connection.is_h3():
+            self.wait_for_end(expiration)
+            with self._lock:
+                return self._buffer.get_all()
+        else:
+            self.wait_for(2, expiration)
+            with self._lock:
+                (size,) = struct.unpack("!H", self._buffer.get(2))
+            self.wait_for(size, expiration)
+            with self._lock:
+                return self._buffer.get(size)
 
     def send(self, datagram, is_end=False):
         data = self._encapsulate(datagram)
@@ -147,10 +162,29 @@ class SyncQuicConnection(BaseQuicConnection):
             if event is None:
                 return
             if isinstance(event, aioquic.quic.events.StreamDataReceived):
-                with self._lock:
-                    stream = self._streams.get(event.stream_id)
-                if stream:
-                    stream._add_input(event.data, event.end_stream)
+                if self.is_h3():
+                    h3_events = self._h3_conn.handle_event(event)
+                    for h3_event in h3_events:
+                        if isinstance(h3_event, aioquic.h3.events.HeadersReceived):
+                            with self._lock:
+                                stream = self._streams.get(event.stream_id)
+                            if stream:
+                                if stream._headers is None:
+                                    stream._headers = h3_event.headers
+                                elif stream._trailers is None:
+                                    stream._trailers = h3_event.headers
+                                if h3_event.stream_ended:
+                                    stream._add_input(b"", True)
+                        elif isinstance(h3_event, aioquic.h3.events.DataReceived):
+                            with self._lock:
+                                stream = self._streams.get(event.stream_id)
+                            if stream:
+                                stream._add_input(h3_event.data, h3_event.stream_ended)
+                else:
+                    with self._lock:
+                        stream = self._streams.get(event.stream_id)
+                    if stream:
+                        stream._add_input(event.data, event.end_stream)
             elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
                 self._handshake_complete.set()
             elif isinstance(event, aioquic.quic.events.ConnectionTerminated):
@@ -167,6 +201,18 @@ class SyncQuicConnection(BaseQuicConnection):
             self._connection.send_stream_data(stream, data, is_end)
         self._send_wakeup.send(b"\x01")
 
+    def send_headers(self, stream_id, headers, is_end=False):
+        with self._lock:
+            super().send_headers(stream_id, headers, is_end)
+        if is_end:
+            self._send_wakeup.send(b"\x01")
+
+    def send_data(self, stream_id, data, is_end=False):
+        with self._lock:
+            super().send_data(stream_id, data, is_end)
+        if is_end:
+            self._send_wakeup.send(b"\x01")
+
     def run(self):
         if self._closed:
             return
@@ -200,12 +246,19 @@ class SyncQuicConnection(BaseQuicConnection):
 
 
 class SyncQuicManager(BaseQuicManager):
-    def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None):
-        super().__init__(conf, verify_mode, SyncQuicConnection, server_name)
+    def __init__(
+        self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None, h3=False
+    ):
+        super().__init__(conf, verify_mode, SyncQuicConnection, server_name, h3)
         self._lock = threading.Lock()
 
     def connect(
-        self, address, port=853, source=None, source_port=0, want_session_ticket=True
+        self,
+        address,
+        port=853,
+        source=None,
+        source_port=0,
+        want_session_ticket=True,
     ):
         with self._lock:
             (connection, start) = self._connect(
index 35e36b982f71df873fd5ac70edd46179a8091ab4..bf2845579a9747d97e6fe5fdd2cd83d7a20df885 100644 (file)
@@ -36,16 +36,27 @@ class TrioQuicStream(BaseQuicStream):
                 await self._wake_up.wait()
             self._expecting = 0
 
+    async def wait_for_end(self):
+        while True:
+            if self._buffer.seen_end():
+                return
+            async with self._wake_up:
+                await self._wake_up.wait()
+
     async def receive(self, timeout=None):
         if timeout is None:
             context = NullContext(None)
         else:
             context = trio.move_on_after(timeout)
         with context:
-            await self.wait_for(2)
-            (size,) = struct.unpack("!H", self._buffer.get(2))
-            await self.wait_for(size)
-            return self._buffer.get(size)
+            if self._connection.is_h3():
+                await self.wait_for_end()
+                return self._buffer.get_all()
+            else:
+                await self.wait_for(2)
+                (size,) = struct.unpack("!H", self._buffer.get(2))
+                await self.wait_for(size)
+                return self._buffer.get(size)
         raise dns.exception.Timeout
 
     async def send(self, datagram, is_end=False):
@@ -124,9 +135,28 @@ class TrioQuicConnection(AsyncQuicConnection):
             if event is None:
                 return
             if isinstance(event, aioquic.quic.events.StreamDataReceived):
-                stream = self._streams.get(event.stream_id)
-                if stream:
-                    await stream._add_input(event.data, event.end_stream)
+                if self.is_h3():
+                    h3_events = self._h3_conn.handle_event(event)
+                    for h3_event in h3_events:
+                        if isinstance(h3_event, aioquic.h3.events.HeadersReceived):
+                            stream = self._streams.get(event.stream_id)
+                            if stream:
+                                if stream._headers is None:
+                                    stream._headers = h3_event.headers
+                                elif stream._trailers is None:
+                                    stream._trailers = h3_event.headers
+                                if h3_event.stream_ended:
+                                    await stream._add_input(b"", True)
+                        elif isinstance(h3_event, aioquic.h3.events.DataReceived):
+                            stream = self._streams.get(event.stream_id)
+                            if stream:
+                                await stream._add_input(
+                                    h3_event.data, h3_event.stream_ended
+                                )
+                else:
+                    stream = self._streams.get(event.stream_id)
+                    if stream:
+                        await stream._add_input(event.data, event.end_stream)
             elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
                 self._handshake_complete.set()
             elif isinstance(event, aioquic.quic.events.ConnectionTerminated):
@@ -183,9 +213,14 @@ class TrioQuicConnection(AsyncQuicConnection):
 
 class TrioQuicManager(AsyncQuicManager):
     def __init__(
-        self, nursery, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None
+        self,
+        nursery,
+        conf=None,
+        verify_mode=ssl.CERT_REQUIRED,
+        server_name=None,
+        h3=False,
     ):
-        super().__init__(conf, verify_mode, TrioQuicConnection, server_name)
+        super().__init__(conf, verify_mode, TrioQuicConnection, server_name, h3)
         self._nursery = nursery
 
     def connect(
index 9373548d7682a81fca688b0568cd00c7bc530304..18ba96dbefd2681fccd2809bd379ba092b513c14 100644 (file)
@@ -27,6 +27,7 @@ import dns.asyncresolver
 import dns.message
 import dns.name
 import dns.query
+import dns.quic
 import dns.rcode
 import dns.rdataclass
 import dns.rdatatype
@@ -78,6 +79,11 @@ KNOWN_ANYCAST_DOH_RESOLVER_URLS = [
     # 'https://dns11.quad9.net/dns-query',
 ]
 
+KNOWN_ANYCAST_DOH3_RESOLVER_URLS = [
+    "https://cloudflare-dns.com/dns-query",
+    "https://dns.google/dns-query",
+]
+
 
 class AsyncDetectionTests(unittest.TestCase):
     sniff_result = "asyncio"
@@ -553,6 +559,40 @@ class AsyncTests(unittest.TestCase):
 
         self.async_run(run)
 
+    @unittest.skipIf(not dns.quic.have_quic, "aioquic not available")
+    def testDoH3GetRequest(self):
+        async def run():
+            nameserver_url = random.choice(KNOWN_ANYCAST_DOH3_RESOLVER_URLS)
+            q = dns.message.make_query("dns.google.", dns.rdatatype.A)
+            r = await dns.asyncquery.https(
+                q,
+                nameserver_url,
+                post=False,
+                timeout=4,
+                family=family,
+                h3=True,
+            )
+            self.assertTrue(q.is_response(r))
+
+        self.async_run(run)
+
+    @unittest.skipIf(not dns.quic.have_quic, "aioquic not available")
+    def TestDoH3PostRequest(self):
+        async def run():
+            nameserver_url = random.choice(KNOWN_ANYCAST_DOH3_RESOLVER_URLS)
+            q = dns.message.make_query("dns.google.", dns.rdatatype.A)
+            r = await dns.asyncquery.https(
+                q,
+                nameserver_url,
+                post=True,
+                timeout=4,
+                family=family,
+                h3=True,
+            )
+            self.assertTrue(q.is_response(r))
+
+        self.async_run(run)
+
     @unittest.skipIf(not dns.query._have_httpx, "httpx not available")
     def testResolverDOH(self):
         async def run():
index 0a5908f95b18d1241b36cf4328f53feeec859623..8912dd6cf135f9108091cc44c6fe4104a8cf83a5 100644 (file)
@@ -28,6 +28,7 @@ except Exception:
 import dns.edns
 import dns.message
 import dns.query
+import dns.quic
 import dns.rdatatype
 import dns.resolver
 
@@ -65,6 +66,11 @@ KNOWN_ANYCAST_DOH_RESOLVER_URLS = [
     # 'https://dns11.quad9.net/dns-query',
 ]
 
+KNOWN_ANYCAST_DOH3_RESOLVER_URLS = [
+    "https://cloudflare-dns.com/dns-query",
+    "https://dns.google/dns-query",
+]
+
 KNOWN_PAD_AWARE_DOH_RESOLVER_URLS = [
     "https://cloudflare-dns.com/dns-query",
     "https://dns.google/dns-query",
@@ -183,5 +189,37 @@ class DNSOverHTTPSTestCaseHttpx(unittest.TestCase):
         self.assertTrue(has_pad)
 
 
+@unittest.skipUnless(
+    dns.quic.have_quic and tests.util.is_internet_reachable() and _have_ssl,
+    "Aioquic cannot be imported; no DNS over HTTP3 (DOH3)",
+)
+class DNSOverHTTP3TestCase(unittest.TestCase):
+    def testDoH3GetRequest(self):
+        nameserver_url = random.choice(KNOWN_ANYCAST_DOH3_RESOLVER_URLS)
+        q = dns.message.make_query("dns.google.", dns.rdatatype.A)
+        r = dns.query.https(
+            q,
+            nameserver_url,
+            post=False,
+            timeout=4,
+            family=family,
+            h3=True,
+        )
+        self.assertTrue(q.is_response(r))
+
+    def testDoH3PostRequest(self):
+        nameserver_url = random.choice(KNOWN_ANYCAST_DOH3_RESOLVER_URLS)
+        q = dns.message.make_query("dns.google.", dns.rdatatype.A)
+        r = dns.query.https(
+            q,
+            nameserver_url,
+            post=True,
+            timeout=4,
+            family=family,
+            h3=True,
+        )
+        self.assertTrue(q.is_response(r))
+
+
 if __name__ == "__main__":
     unittest.main()