From e9d58f27a8ededf5764cae170b4c8dc3c2e9a308 Mon Sep 17 00:00:00 2001 From: Bob Halley Date: Sat, 24 Feb 2024 05:33:57 -0800 Subject: [PATCH] DNS-over-HTTP3 (#1048) * Implement DNS-over-HTTP3 using aioquic directly. * Add h3 support for DoHNameserver. --- dns/asyncquery.py | 120 +++++++++++++++++++++++++++++--- dns/nameserver.py | 4 ++ dns/query.py | 159 ++++++++++++++++++++++++++++++++++++++----- dns/quic/__init__.py | 5 ++ dns/quic/_asyncio.py | 53 ++++++++++++--- dns/quic/_common.py | 99 +++++++++++++++++++++++++-- dns/quic/_sync.py | 79 +++++++++++++++++---- dns/quic/_trio.py | 53 ++++++++++++--- tests/test_async.py | 40 +++++++++++ tests/test_doh.py | 38 +++++++++++ 10 files changed, 585 insertions(+), 65 deletions(-) diff --git a/dns/asyncquery.py b/dns/asyncquery.py index 4d9ab9ae..e5ebad41 100644 --- a/dns/asyncquery.py +++ b/dns/asyncquery.py @@ -19,9 +19,11 @@ import base64 import contextlib +import random import socket import struct import time +import urllib.parse from typing import Any, Dict, Optional, Tuple, Union import dns.asyncbackend @@ -40,6 +42,7 @@ from dns.query import ( NoDOH, NoDOQ, UDPMode, + _check_status, _compute_times, _make_dot_ssl_context, _matches_destination, @@ -500,6 +503,20 @@ async def tls( return response +def _maybe_get_resolver( + resolver: Optional["dns.asyncresolver.Resolver"], +) -> "dns.asyncresolver.Resolver": + # We need a separate method for this to avoid overriding the global + # variable "dns" with the as-yet undefined local variable "dns" + # in https(). + if resolver is None: + # pylint: disable=import-outside-toplevel,redefined-outer-name + import dns.asyncresolver + + resolver = dns.asyncresolver.Resolver() + return resolver + + async def https( q: dns.message.Message, where: str, @@ -515,7 +532,8 @@ async def https( verify: Union[bool, str] = True, bootstrap_address: Optional[str] = None, resolver: Optional["dns.asyncresolver.Resolver"] = None, - family: Optional[int] = socket.AF_UNSPEC, + family: int = socket.AF_UNSPEC, + h3: bool = False, ) -> dns.message.Message: """Return the response obtained after sending a query via DNS-over-HTTPS. @@ -529,18 +547,10 @@ async def https( parameters, exceptions, and return type of this method. """ - if not have_doh: - raise NoDOH # pragma: no cover - if client and not isinstance(client, httpx.AsyncClient): - raise ValueError("session parameter must be an httpx.AsyncClient") - - wire = q.to_wire() try: af = dns.inet.af_for_address(where) except ValueError: af = None - transport = None - headers = {"accept": "application/dns-message"} if af is not None and dns.inet.is_address(where): if af == socket.AF_INET: url = "https://{}:{}{}".format(where, port, path) @@ -549,6 +559,39 @@ async def https( else: url = where + if h3: + if bootstrap_address is None: + parsed = urllib.parse.urlparse(url) + resolver = _maybe_get_resolver(resolver) + if parsed.hostname is None: + raise ValueError("no hostname in URL") + answers = await resolver.resolve_name(parsed.hostname, family) + bootstrap_address = random.choice(list(answers.addresses())) + if parsed.port is not None: + port = parsed.port + return await _http3( + q, + bootstrap_address, + url, + timeout, + port, + source, + source_port, + one_rr_per_rrset, + ignore_trailing, + verify=verify, + post=post, + ) + + if not have_doh: + raise NoDOH # pragma: no cover + if client and not isinstance(client, httpx.AsyncClient): + raise ValueError("session parameter must be an httpx.AsyncClient") + + wire = q.to_wire() + transport = None + headers = {"accept": "application/dns-message"} + backend = dns.asyncbackend.get_default_backend() if source is None: @@ -617,6 +660,57 @@ async def https( return r +async def _http3( + q: dns.message.Message, + where: str, + url: str, + timeout: Optional[float] = None, + port: int = 853, + source: Optional[str] = None, + source_port: int = 0, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + verify: Union[bool, str] = True, + backend: Optional[dns.asyncbackend.Backend] = None, + hostname: Optional[str] = None, + post: bool = True, +) -> dns.message.Message: + if not dns.quic.have_quic: + raise NoDOH("DNS-over-HTTP3 is not available.") # pragma: no cover + + url_parts = urllib.parse.urlparse(url) + hostname = url_parts.hostname + + q.id = 0 + wire = q.to_wire() + (cfactory, mfactory) = dns.quic.factories_for_backend(backend) + + async with cfactory() as context: + async with mfactory( + context, verify_mode=verify, server_name=hostname, h3=True + ) as the_manager: + the_connection = the_manager.connect(where, port, source, source_port) + (start, expiration) = _compute_times(timeout) + stream = await the_connection.make_stream(timeout) + async with stream: + # note that send_h3() does not need await + stream.send_h3(url, wire, post) + wire = await stream.receive(_remaining(expiration)) + _check_status(stream.headers(), where, wire) + finish = time.time() + r = dns.message.from_wire( + wire, + keyring=q.keyring, + request_mac=q.request_mac, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + ) + r.time = max(finish - start, 0.0) + if not q.is_response(r): + raise BadResponse + return r + + async def inbound_xfr( where: str, txn_manager: dns.transaction.TransactionManager, @@ -730,6 +824,7 @@ async def quic( connection: Optional[dns.quic.AsyncQuicConnection] = None, verify: Union[bool, str] = True, backend: Optional[dns.asyncbackend.Backend] = None, + hostname: Optional[str] = None, server_hostname: Optional[str] = None, ) -> dns.message.Message: """Return the response obtained after sending an asynchronous query via @@ -745,6 +840,9 @@ async def quic( if not dns.quic.have_quic: raise NoDOQ("DNS-over-QUIC is not available.") # pragma: no cover + if server_hostname is not None and hostname is None: + hostname = server_hostname + q.id = 0 wire = q.to_wire() the_connection: dns.quic.AsyncQuicConnection @@ -757,7 +855,9 @@ async def quic( async with cfactory() as context: async with mfactory( - context, verify_mode=verify, server_name=server_hostname + context, + verify_mode=verify, + server_name=server_hostname, ) as the_manager: if not connection: the_connection = the_manager.connect(where, port, source, source_port) diff --git a/dns/nameserver.py b/dns/nameserver.py index 5dbb4e8b..e8068e7e 100644 --- a/dns/nameserver.py +++ b/dns/nameserver.py @@ -168,12 +168,14 @@ class DoHNameserver(Nameserver): bootstrap_address: Optional[str] = None, verify: Union[bool, str] = True, want_get: bool = False, + h3: bool = False, ): super().__init__() self.url = url self.bootstrap_address = bootstrap_address self.verify = verify self.want_get = want_get + self.h3 = h3 def kind(self): return "DoH" @@ -214,6 +216,7 @@ class DoHNameserver(Nameserver): ignore_trailing=ignore_trailing, verify=self.verify, post=(not self.want_get), + h3=self.h3, ) async def async_query( @@ -238,6 +241,7 @@ class DoHNameserver(Nameserver): ignore_trailing=ignore_trailing, verify=self.verify, post=(not self.want_get), + h3=self.h3, ) diff --git a/dns/query.py b/dns/query.py index 384bf31e..8f82ab67 100644 --- a/dns/query.py +++ b/dns/query.py @@ -23,11 +23,13 @@ import enum import errno import os import os.path +import random import selectors import socket import struct import time -from typing import Any, Dict, Optional, Tuple, Union +import urllib.parse +from typing import Any, Dict, List, Optional, Tuple, Union import dns._features import dns.exception @@ -335,6 +337,20 @@ def _make_socket(af, type, source, ssl_context=None, server_hostname=None): raise +def _maybe_get_resolver( + resolver: Optional["dns.resolver.Resolver"], +) -> "dns.resolver.Resolver": + # We need a separate method for this to avoid overriding the global + # variable "dns" with the as-yet undefined local variable "dns" + # in https(). + if resolver is None: + # pylint: disable=import-outside-toplevel,redefined-outer-name + import dns.resolver + + resolver = dns.resolver.Resolver() + return resolver + + def https( q: dns.message.Message, where: str, @@ -350,7 +366,8 @@ def https( bootstrap_address: Optional[str] = None, verify: Union[bool, str] = True, resolver: Optional["dns.resolver.Resolver"] = None, - family: Optional[int] = socket.AF_UNSPEC, + family: int = socket.AF_UNSPEC, + h3: bool = False, ) -> dns.message.Message: """Return the response obtained after sending a query via DNS-over-HTTPS. @@ -400,20 +417,14 @@ def https( *family*, an ``int``, the address family. If socket.AF_UNSPEC (the default), both A and AAAA records will be retrieved. + *h3*, a ``bool``. If ``True``, use HTTP/3 otherwise use HTTP/2 or HTTP/1.1. + Returns a ``dns.message.Message``. """ - if not have_doh: - raise NoDOH # pragma: no cover - if session and not isinstance(session, httpx.Client): - raise ValueError("session parameter must be an httpx.Client") - - wire = q.to_wire() (af, _, the_source) = _destination_and_source( where, port, source, source_port, False ) - transport = None - headers = {"accept": "application/dns-message"} if af is not None and dns.inet.is_address(where): if af == socket.AF_INET: url = "https://{}:{}{}".format(where, port, path) @@ -422,6 +433,39 @@ def https( else: url = where + if h3: + if bootstrap_address is None: + parsed = urllib.parse.urlparse(url) + resolver = _maybe_get_resolver(resolver) + if parsed.hostname is None: + raise ValueError("no hostname in URL") + answers = resolver.resolve_name(parsed.hostname, family) + bootstrap_address = random.choice(list(answers.addresses())) + if parsed.port is not None: + port = parsed.port + return _http3( + q, + bootstrap_address, + url, + timeout, + port, + source, + source_port, + one_rr_per_rrset, + ignore_trailing, + verify=verify, + post=post, + ) + + if not have_doh: + raise NoDOH # pragma: no cover + if session and not isinstance(session, httpx.Client): + raise ValueError("session parameter must be an httpx.Client") + + wire = q.to_wire() + transport = None + headers = {"accept": "application/dns-message"} + # set source port and source address if the_source is None: @@ -483,6 +527,79 @@ def https( return r +def _find_header(headers: dns.quic.Headers, name: bytes) -> bytes: + if headers is None: + raise KeyError + for header, value in headers: + if header == name: + return value + raise KeyError + + +def _check_status(headers: dns.quic.Headers, peer: str, wire: bytes) -> None: + value = _find_header(headers, b":status") + if value is None: + raise SyntaxError("no :status header in response") + status = int(value) + if status < 0: + raise SyntaxError("status is negative") + if status < 200 or status > 299: + error = "" + if len(wire) > 0: + try: + error = ": " + wire.decode() + except Exception: + pass + raise ValueError(f"{peer} responded with status code {status}{error}") + + +def _http3( + q: dns.message.Message, + where: str, + url: str, + timeout: Optional[float] = None, + port: int = 853, + source: Optional[str] = None, + source_port: int = 0, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + verify: Union[bool, str] = True, + hostname: Optional[str] = None, + post: bool = True, +) -> dns.message.Message: + if not dns.quic.have_quic: + raise NoDOH("DNS-over-HTTP3 is not available.") # pragma: no cover + + url_parts = urllib.parse.urlparse(url) + hostname = url_parts.hostname + + q.id = 0 + wire = q.to_wire() + manager = dns.quic.SyncQuicManager( + verify_mode=verify, server_name=hostname, h3=True + ) + + with manager: + connection = manager.connect(where, port, source, source_port) + (start, expiration) = _compute_times(timeout) + with connection.make_stream(timeout) as stream: + stream.send_h3(url, wire, post) + wire = stream.receive(_remaining(expiration)) + _check_status(stream.headers(), where, wire) + finish = time.time() + r = dns.message.from_wire( + wire, + keyring=q.keyring, + request_mac=q.request_mac, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + ) + r.time = max(finish - start, 0.0) + if not q.is_response(r): + raise BadResponse + return r + + def _udp_recv(sock, max_size, expiration): """Reads a datagram from the socket. A Timeout exception will be raised if the operation is not completed @@ -1168,6 +1285,7 @@ def quic( ignore_trailing: bool = False, connection: Optional[dns.quic.SyncQuicConnection] = None, verify: Union[bool, str] = True, + hostname: Optional[str] = None, server_hostname: Optional[str] = None, ) -> dns.message.Message: """Return the response obtained after sending a query via DNS-over-QUIC. @@ -1192,17 +1310,21 @@ def quic( *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the received message. - *connection*, a ``dns.quic.SyncQuicConnection``. If provided, the - connection to use to send the query. + *connection*, a ``dns.quic.SyncQuicConnection``. If provided, the connection to use + to send the query. *verify*, a ``bool`` or ``str``. If a ``True``, then TLS certificate verification of the server is done using the default CA bundle; if ``False``, then no verification is done; if a `str` then it specifies the path to a certificate file or directory which will be used for verification. - *server_hostname*, a ``str`` containing the server's hostname. The - default is ``None``, which means that no hostname is known, and if an - SSL context is created, hostname checking will be disabled. + *hostname*, a ``str`` containing the server's hostname or ``None``. The default is + ``None``, which means that no hostname is known, and if an SSL context is created, + hostname checking will be disabled. This value is ignored if *url* is not + ``None``. + + *server_hostname*, a ``str`` or ``None``. This item is for backwards compatibility + only, and has the same meaning as *hostname*. Returns a ``dns.message.Message``. """ @@ -1210,6 +1332,9 @@ def quic( if not dns.quic.have_quic: raise NoDOQ("DNS-over-QUIC is not available.") # pragma: no cover + if server_hostname is not None and hostname is None: + hostname = server_hostname + q.id = 0 wire = q.to_wire() the_connection: dns.quic.SyncQuicConnection @@ -1218,9 +1343,7 @@ def quic( manager: contextlib.AbstractContextManager = contextlib.nullcontext(None) the_connection = connection else: - manager = dns.quic.SyncQuicManager( - verify_mode=verify, server_name=server_hostname - ) + manager = dns.quic.SyncQuicManager(verify_mode=verify, server_name=hostname) the_manager = manager # for type checking happiness with manager: diff --git a/dns/quic/__init__.py b/dns/quic/__init__.py index 20aff345..0750e729 100644 --- a/dns/quic/__init__.py +++ b/dns/quic/__init__.py @@ -1,5 +1,7 @@ # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license +from typing import List, Tuple + import dns._features import dns.asyncbackend @@ -73,3 +75,6 @@ else: # pragma: no cover class SyncQuicConnection: # type: ignore def make_stream(self) -> Any: raise NotImplementedError + + +Headers = List[Tuple[bytes, bytes]] diff --git a/dns/quic/_asyncio.py b/dns/quic/_asyncio.py index 0f44331f..069387f4 100644 --- a/dns/quic/_asyncio.py +++ b/dns/quic/_asyncio.py @@ -43,12 +43,26 @@ class AsyncioQuicStream(BaseQuicStream): raise dns.exception.Timeout self._expecting = 0 + async def wait_for_end(self, expiration): + while True: + timeout = self._timeout_from_expiration(expiration) + if self._buffer.seen_end(): + return + try: + await asyncio.wait_for(self._wait_for_wake_up(), timeout) + except TimeoutError: + raise dns.exception.Timeout + async def receive(self, timeout=None): expiration = self._expiration_from_timeout(timeout) - await self.wait_for(2, expiration) - (size,) = struct.unpack("!H", self._buffer.get(2)) - await self.wait_for(size, expiration) - return self._buffer.get(size) + if self._connection.is_h3(): + await self.wait_for_end(expiration) + return self._buffer.get_all() + else: + await self.wait_for(2, expiration) + (size,) = struct.unpack("!H", self._buffer.get(2)) + await self.wait_for(size, expiration) + return self._buffer.get(size) async def send(self, datagram, is_end=False): data = self._encapsulate(datagram) @@ -140,9 +154,28 @@ class AsyncioQuicConnection(AsyncQuicConnection): if event is None: return if isinstance(event, aioquic.quic.events.StreamDataReceived): - stream = self._streams.get(event.stream_id) - if stream: - await stream._add_input(event.data, event.end_stream) + if self.is_h3(): + h3_events = self._h3_conn.handle_event(event) + for h3_event in h3_events: + if isinstance(h3_event, aioquic.h3.events.HeadersReceived): + stream = self._streams.get(event.stream_id) + if stream: + if stream._headers is None: + stream._headers = h3_event.headers + elif stream._trailers is None: + stream._trailers = h3_event.headers + if h3_event.stream_ended: + await stream._add_input(b"", True) + elif isinstance(h3_event, aioquic.h3.events.DataReceived): + stream = self._streams.get(event.stream_id) + if stream: + await stream._add_input( + h3_event.data, h3_event.stream_ended + ) + else: + stream = self._streams.get(event.stream_id) + if stream: + await stream._add_input(event.data, event.end_stream) elif isinstance(event, aioquic.quic.events.HandshakeCompleted): self._handshake_complete.set() elif isinstance(event, aioquic.quic.events.ConnectionTerminated): @@ -203,8 +236,10 @@ class AsyncioQuicConnection(AsyncQuicConnection): class AsyncioQuicManager(AsyncQuicManager): - def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None): - super().__init__(conf, verify_mode, AsyncioQuicConnection, server_name) + def __init__( + self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None, h3=False + ): + super().__init__(conf, verify_mode, AsyncioQuicConnection, server_name, h3) def connect( self, address, port=853, source=None, source_port=0, want_session_ticket=True diff --git a/dns/quic/_common.py b/dns/quic/_common.py index 0eacc691..5e6c40d3 100644 --- a/dns/quic/_common.py +++ b/dns/quic/_common.py @@ -1,12 +1,16 @@ # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license +import base64 import copy import functools import socket import struct import time +import urllib from typing import Any, Optional +import aioquic.h3.connection # type: ignore +import aioquic.h3.events # type: ignore import aioquic.quic.configuration # type: ignore import aioquic.quic.connection # type: ignore @@ -51,6 +55,12 @@ class Buffer: self._buffer = self._buffer[amount:] return data + def get_all(self): + assert self.seen_end() + data = self._buffer + self._buffer = b"" + return data + class BaseQuicStream: def __init__(self, connection, stream_id): @@ -58,10 +68,18 @@ class BaseQuicStream: self._stream_id = stream_id self._buffer = Buffer() self._expecting = 0 + self._headers = None + self._trailers = None def id(self): return self._stream_id + def headers(self): + return self._headers + + def trailers(self): + return self._trailers + def _expiration_from_timeout(self, timeout): if timeout is not None: expiration = time.time() + timeout @@ -77,16 +95,51 @@ class BaseQuicStream: return timeout # Subclass must implement receive() as sync / async and which returns a message - # or raises UnexpectedEOF. + # or raises. + + # Subclass must implement send() as sync / async and which takes a message and + # an EOF indicator. + + def send_h3(self, url, datagram, post=True): + if not self._connection.is_h3(): + raise SyntaxError("cannot send H3 to a non-H3 connection") + url_parts = urllib.parse.urlparse(url) + path = url_parts.path.encode() + if post: + method = b"POST" + else: + method = b"GET" + path += b"?dns=" + base64.urlsafe_b64encode(datagram).rstrip(b"=") + headers = [ + (b":method", method), + (b":scheme", url_parts.scheme.encode()), + (b":authority", url_parts.netloc.encode()), + (b":path", path), + (b"accept", b"application/dns-message"), + ] + if post: + headers.extend( + [ + (b"content-type", b"application/dns-message"), + (b"content-length", str(len(datagram)).encode()), + ] + ) + self._connection.send_headers(self._stream_id, headers, not post) + if post: + self._connection.send_data(self._stream_id, datagram, True) def _encapsulate(self, datagram): + if self._connection.is_h3(): + return datagram l = len(datagram) return struct.pack("!H", l) + datagram def _common_add_input(self, data, is_end): self._buffer.put(data, is_end) try: - return self._expecting > 0 and self._buffer.have(self._expecting) + return ( + self._expecting > 0 and self._buffer.have(self._expecting) + ) or self._buffer.seen_end except UnexpectedEOF: return True @@ -97,7 +150,13 @@ class BaseQuicStream: class BaseQuicConnection: def __init__( - self, connection, address, port, source=None, source_port=0, manager=None + self, + connection, + address, + port, + source=None, + source_port=0, + manager=None, ): self._done = False self._connection = connection @@ -106,6 +165,10 @@ class BaseQuicConnection: self._closed = False self._manager = manager self._streams = {} + if manager.is_h3(): + self._h3_conn = aioquic.h3.connection.H3Connection(connection, False) + else: + self._h3_conn = None self._af = dns.inet.af_for_address(address) self._peer = dns.inet.low_level_address_tuple((address, port)) if source is None and source_port != 0: @@ -120,9 +183,18 @@ class BaseQuicConnection: else: self._source = None + def is_h3(self): + return self._h3_conn is not None + def close_stream(self, stream_id): del self._streams[stream_id] + def send_headers(self, stream_id, headers, is_end=False): + self._h3_conn.send_headers(stream_id, headers, is_end) + + def send_data(self, stream_id, data, is_end=False): + self._h3_conn.send_data(stream_id, data, is_end) + def _get_timer_values(self, closed_is_special=True): now = time.time() expiration = self._connection.get_timer() @@ -148,17 +220,24 @@ class AsyncQuicConnection(BaseQuicConnection): class BaseQuicManager: - def __init__(self, conf, verify_mode, connection_factory, server_name=None): + def __init__( + self, conf, verify_mode, connection_factory, server_name=None, h3=False + ): self._connections = {} self._connection_factory = connection_factory self._session_tickets = {} + self._h3 = h3 if conf is None: verify_path = None if isinstance(verify_mode, str): verify_path = verify_mode verify_mode = True + if h3: + alpn_protocols = ["h3"] + else: + alpn_protocols = ["doq", "doq-i03"] conf = aioquic.quic.configuration.QuicConfiguration( - alpn_protocols=["doq", "doq-i03"], + alpn_protocols=alpn_protocols, verify_mode=verify_mode, server_name=server_name, ) @@ -167,7 +246,12 @@ class BaseQuicManager: self._conf = conf def _connect( - self, address, port=853, source=None, source_port=0, want_session_ticket=True + self, + address, + port=853, + source=None, + source_port=0, + want_session_ticket=True, ): connection = self._connections.get((address, port)) if connection is not None: @@ -207,6 +291,9 @@ class BaseQuicManager: except KeyError: pass + def is_h3(self): + return self._h3 + def save_session_ticket(self, address, port, ticket): # We rely on dictionaries keys() being in insertion order here. We # can't just popitem() as that would be LIFO which is the opposite of diff --git a/dns/quic/_sync.py b/dns/quic/_sync.py index 6ef5dc94..a1062f58 100644 --- a/dns/quic/_sync.py +++ b/dns/quic/_sync.py @@ -43,14 +43,29 @@ class SyncQuicStream(BaseQuicStream): raise dns.exception.Timeout self._expecting = 0 + def wait_for_end(self, expiration): + while True: + timeout = self._timeout_from_expiration(expiration) + with self._lock: + if self._buffer.seen_end(): + return + with self._wake_up: + if not self._wake_up.wait(timeout): + raise dns.exception.Timeout + def receive(self, timeout=None): expiration = self._expiration_from_timeout(timeout) - self.wait_for(2, expiration) - with self._lock: - (size,) = struct.unpack("!H", self._buffer.get(2)) - self.wait_for(size, expiration) - with self._lock: - return self._buffer.get(size) + if self._connection.is_h3(): + self.wait_for_end(expiration) + with self._lock: + return self._buffer.get_all() + else: + self.wait_for(2, expiration) + with self._lock: + (size,) = struct.unpack("!H", self._buffer.get(2)) + self.wait_for(size, expiration) + with self._lock: + return self._buffer.get(size) def send(self, datagram, is_end=False): data = self._encapsulate(datagram) @@ -147,10 +162,29 @@ class SyncQuicConnection(BaseQuicConnection): if event is None: return if isinstance(event, aioquic.quic.events.StreamDataReceived): - with self._lock: - stream = self._streams.get(event.stream_id) - if stream: - stream._add_input(event.data, event.end_stream) + if self.is_h3(): + h3_events = self._h3_conn.handle_event(event) + for h3_event in h3_events: + if isinstance(h3_event, aioquic.h3.events.HeadersReceived): + with self._lock: + stream = self._streams.get(event.stream_id) + if stream: + if stream._headers is None: + stream._headers = h3_event.headers + elif stream._trailers is None: + stream._trailers = h3_event.headers + if h3_event.stream_ended: + stream._add_input(b"", True) + elif isinstance(h3_event, aioquic.h3.events.DataReceived): + with self._lock: + stream = self._streams.get(event.stream_id) + if stream: + stream._add_input(h3_event.data, h3_event.stream_ended) + else: + with self._lock: + stream = self._streams.get(event.stream_id) + if stream: + stream._add_input(event.data, event.end_stream) elif isinstance(event, aioquic.quic.events.HandshakeCompleted): self._handshake_complete.set() elif isinstance(event, aioquic.quic.events.ConnectionTerminated): @@ -167,6 +201,18 @@ class SyncQuicConnection(BaseQuicConnection): self._connection.send_stream_data(stream, data, is_end) self._send_wakeup.send(b"\x01") + def send_headers(self, stream_id, headers, is_end=False): + with self._lock: + super().send_headers(stream_id, headers, is_end) + if is_end: + self._send_wakeup.send(b"\x01") + + def send_data(self, stream_id, data, is_end=False): + with self._lock: + super().send_data(stream_id, data, is_end) + if is_end: + self._send_wakeup.send(b"\x01") + def run(self): if self._closed: return @@ -200,12 +246,19 @@ class SyncQuicConnection(BaseQuicConnection): class SyncQuicManager(BaseQuicManager): - def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None): - super().__init__(conf, verify_mode, SyncQuicConnection, server_name) + def __init__( + self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None, h3=False + ): + super().__init__(conf, verify_mode, SyncQuicConnection, server_name, h3) self._lock = threading.Lock() def connect( - self, address, port=853, source=None, source_port=0, want_session_ticket=True + self, + address, + port=853, + source=None, + source_port=0, + want_session_ticket=True, ): with self._lock: (connection, start) = self._connect( diff --git a/dns/quic/_trio.py b/dns/quic/_trio.py index 35e36b98..bf284557 100644 --- a/dns/quic/_trio.py +++ b/dns/quic/_trio.py @@ -36,16 +36,27 @@ class TrioQuicStream(BaseQuicStream): await self._wake_up.wait() self._expecting = 0 + async def wait_for_end(self): + while True: + if self._buffer.seen_end(): + return + async with self._wake_up: + await self._wake_up.wait() + async def receive(self, timeout=None): if timeout is None: context = NullContext(None) else: context = trio.move_on_after(timeout) with context: - await self.wait_for(2) - (size,) = struct.unpack("!H", self._buffer.get(2)) - await self.wait_for(size) - return self._buffer.get(size) + if self._connection.is_h3(): + await self.wait_for_end() + return self._buffer.get_all() + else: + await self.wait_for(2) + (size,) = struct.unpack("!H", self._buffer.get(2)) + await self.wait_for(size) + return self._buffer.get(size) raise dns.exception.Timeout async def send(self, datagram, is_end=False): @@ -124,9 +135,28 @@ class TrioQuicConnection(AsyncQuicConnection): if event is None: return if isinstance(event, aioquic.quic.events.StreamDataReceived): - stream = self._streams.get(event.stream_id) - if stream: - await stream._add_input(event.data, event.end_stream) + if self.is_h3(): + h3_events = self._h3_conn.handle_event(event) + for h3_event in h3_events: + if isinstance(h3_event, aioquic.h3.events.HeadersReceived): + stream = self._streams.get(event.stream_id) + if stream: + if stream._headers is None: + stream._headers = h3_event.headers + elif stream._trailers is None: + stream._trailers = h3_event.headers + if h3_event.stream_ended: + await stream._add_input(b"", True) + elif isinstance(h3_event, aioquic.h3.events.DataReceived): + stream = self._streams.get(event.stream_id) + if stream: + await stream._add_input( + h3_event.data, h3_event.stream_ended + ) + else: + stream = self._streams.get(event.stream_id) + if stream: + await stream._add_input(event.data, event.end_stream) elif isinstance(event, aioquic.quic.events.HandshakeCompleted): self._handshake_complete.set() elif isinstance(event, aioquic.quic.events.ConnectionTerminated): @@ -183,9 +213,14 @@ class TrioQuicConnection(AsyncQuicConnection): class TrioQuicManager(AsyncQuicManager): def __init__( - self, nursery, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None + self, + nursery, + conf=None, + verify_mode=ssl.CERT_REQUIRED, + server_name=None, + h3=False, ): - super().__init__(conf, verify_mode, TrioQuicConnection, server_name) + super().__init__(conf, verify_mode, TrioQuicConnection, server_name, h3) self._nursery = nursery def connect( diff --git a/tests/test_async.py b/tests/test_async.py index 9373548d..18ba96db 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -27,6 +27,7 @@ import dns.asyncresolver import dns.message import dns.name import dns.query +import dns.quic import dns.rcode import dns.rdataclass import dns.rdatatype @@ -78,6 +79,11 @@ KNOWN_ANYCAST_DOH_RESOLVER_URLS = [ # 'https://dns11.quad9.net/dns-query', ] +KNOWN_ANYCAST_DOH3_RESOLVER_URLS = [ + "https://cloudflare-dns.com/dns-query", + "https://dns.google/dns-query", +] + class AsyncDetectionTests(unittest.TestCase): sniff_result = "asyncio" @@ -553,6 +559,40 @@ class AsyncTests(unittest.TestCase): self.async_run(run) + @unittest.skipIf(not dns.quic.have_quic, "aioquic not available") + def testDoH3GetRequest(self): + async def run(): + nameserver_url = random.choice(KNOWN_ANYCAST_DOH3_RESOLVER_URLS) + q = dns.message.make_query("dns.google.", dns.rdatatype.A) + r = await dns.asyncquery.https( + q, + nameserver_url, + post=False, + timeout=4, + family=family, + h3=True, + ) + self.assertTrue(q.is_response(r)) + + self.async_run(run) + + @unittest.skipIf(not dns.quic.have_quic, "aioquic not available") + def TestDoH3PostRequest(self): + async def run(): + nameserver_url = random.choice(KNOWN_ANYCAST_DOH3_RESOLVER_URLS) + q = dns.message.make_query("dns.google.", dns.rdatatype.A) + r = await dns.asyncquery.https( + q, + nameserver_url, + post=True, + timeout=4, + family=family, + h3=True, + ) + self.assertTrue(q.is_response(r)) + + self.async_run(run) + @unittest.skipIf(not dns.query._have_httpx, "httpx not available") def testResolverDOH(self): async def run(): diff --git a/tests/test_doh.py b/tests/test_doh.py index 0a5908f9..8912dd6c 100644 --- a/tests/test_doh.py +++ b/tests/test_doh.py @@ -28,6 +28,7 @@ except Exception: import dns.edns import dns.message import dns.query +import dns.quic import dns.rdatatype import dns.resolver @@ -65,6 +66,11 @@ KNOWN_ANYCAST_DOH_RESOLVER_URLS = [ # 'https://dns11.quad9.net/dns-query', ] +KNOWN_ANYCAST_DOH3_RESOLVER_URLS = [ + "https://cloudflare-dns.com/dns-query", + "https://dns.google/dns-query", +] + KNOWN_PAD_AWARE_DOH_RESOLVER_URLS = [ "https://cloudflare-dns.com/dns-query", "https://dns.google/dns-query", @@ -183,5 +189,37 @@ class DNSOverHTTPSTestCaseHttpx(unittest.TestCase): self.assertTrue(has_pad) +@unittest.skipUnless( + dns.quic.have_quic and tests.util.is_internet_reachable() and _have_ssl, + "Aioquic cannot be imported; no DNS over HTTP3 (DOH3)", +) +class DNSOverHTTP3TestCase(unittest.TestCase): + def testDoH3GetRequest(self): + nameserver_url = random.choice(KNOWN_ANYCAST_DOH3_RESOLVER_URLS) + q = dns.message.make_query("dns.google.", dns.rdatatype.A) + r = dns.query.https( + q, + nameserver_url, + post=False, + timeout=4, + family=family, + h3=True, + ) + self.assertTrue(q.is_response(r)) + + def testDoH3PostRequest(self): + nameserver_url = random.choice(KNOWN_ANYCAST_DOH3_RESOLVER_URLS) + q = dns.message.make_query("dns.google.", dns.rdatatype.A) + r = dns.query.https( + q, + nameserver_url, + post=True, + timeout=4, + family=family, + h3=True, + ) + self.assertTrue(q.is_response(r)) + + if __name__ == "__main__": unittest.main() -- 2.47.3