* Implement DNS-over-HTTP3 using aioquic directly.
* Add h3 support for DoHNameserver.
import base64
import contextlib
+import random
import socket
import struct
import time
+import urllib.parse
from typing import Any, Dict, Optional, Tuple, Union
import dns.asyncbackend
NoDOH,
NoDOQ,
UDPMode,
+ _check_status,
_compute_times,
_make_dot_ssl_context,
_matches_destination,
return response
+def _maybe_get_resolver(
+ resolver: Optional["dns.asyncresolver.Resolver"],
+) -> "dns.asyncresolver.Resolver":
+ # We need a separate method for this to avoid overriding the global
+ # variable "dns" with the as-yet undefined local variable "dns"
+ # in https().
+ if resolver is None:
+ # pylint: disable=import-outside-toplevel,redefined-outer-name
+ import dns.asyncresolver
+
+ resolver = dns.asyncresolver.Resolver()
+ return resolver
+
+
async def https(
q: dns.message.Message,
where: str,
verify: Union[bool, str] = True,
bootstrap_address: Optional[str] = None,
resolver: Optional["dns.asyncresolver.Resolver"] = None,
- family: Optional[int] = socket.AF_UNSPEC,
+ family: int = socket.AF_UNSPEC,
+ h3: bool = False,
) -> dns.message.Message:
"""Return the response obtained after sending a query via DNS-over-HTTPS.
parameters, exceptions, and return type of this method.
"""
- if not have_doh:
- raise NoDOH # pragma: no cover
- if client and not isinstance(client, httpx.AsyncClient):
- raise ValueError("session parameter must be an httpx.AsyncClient")
-
- wire = q.to_wire()
try:
af = dns.inet.af_for_address(where)
except ValueError:
af = None
- transport = None
- headers = {"accept": "application/dns-message"}
if af is not None and dns.inet.is_address(where):
if af == socket.AF_INET:
url = "https://{}:{}{}".format(where, port, path)
else:
url = where
+ if h3:
+ if bootstrap_address is None:
+ parsed = urllib.parse.urlparse(url)
+ resolver = _maybe_get_resolver(resolver)
+ if parsed.hostname is None:
+ raise ValueError("no hostname in URL")
+ answers = await resolver.resolve_name(parsed.hostname, family)
+ bootstrap_address = random.choice(list(answers.addresses()))
+ if parsed.port is not None:
+ port = parsed.port
+ return await _http3(
+ q,
+ bootstrap_address,
+ url,
+ timeout,
+ port,
+ source,
+ source_port,
+ one_rr_per_rrset,
+ ignore_trailing,
+ verify=verify,
+ post=post,
+ )
+
+ if not have_doh:
+ raise NoDOH # pragma: no cover
+ if client and not isinstance(client, httpx.AsyncClient):
+ raise ValueError("session parameter must be an httpx.AsyncClient")
+
+ wire = q.to_wire()
+ transport = None
+ headers = {"accept": "application/dns-message"}
+
backend = dns.asyncbackend.get_default_backend()
if source is None:
return r
+async def _http3(
+ q: dns.message.Message,
+ where: str,
+ url: str,
+ timeout: Optional[float] = None,
+ port: int = 853,
+ source: Optional[str] = None,
+ source_port: int = 0,
+ one_rr_per_rrset: bool = False,
+ ignore_trailing: bool = False,
+ verify: Union[bool, str] = True,
+ backend: Optional[dns.asyncbackend.Backend] = None,
+ hostname: Optional[str] = None,
+ post: bool = True,
+) -> dns.message.Message:
+ if not dns.quic.have_quic:
+ raise NoDOH("DNS-over-HTTP3 is not available.") # pragma: no cover
+
+ url_parts = urllib.parse.urlparse(url)
+ hostname = url_parts.hostname
+
+ q.id = 0
+ wire = q.to_wire()
+ (cfactory, mfactory) = dns.quic.factories_for_backend(backend)
+
+ async with cfactory() as context:
+ async with mfactory(
+ context, verify_mode=verify, server_name=hostname, h3=True
+ ) as the_manager:
+ the_connection = the_manager.connect(where, port, source, source_port)
+ (start, expiration) = _compute_times(timeout)
+ stream = await the_connection.make_stream(timeout)
+ async with stream:
+ # note that send_h3() does not need await
+ stream.send_h3(url, wire, post)
+ wire = await stream.receive(_remaining(expiration))
+ _check_status(stream.headers(), where, wire)
+ finish = time.time()
+ r = dns.message.from_wire(
+ wire,
+ keyring=q.keyring,
+ request_mac=q.request_mac,
+ one_rr_per_rrset=one_rr_per_rrset,
+ ignore_trailing=ignore_trailing,
+ )
+ r.time = max(finish - start, 0.0)
+ if not q.is_response(r):
+ raise BadResponse
+ return r
+
+
async def inbound_xfr(
where: str,
txn_manager: dns.transaction.TransactionManager,
connection: Optional[dns.quic.AsyncQuicConnection] = None,
verify: Union[bool, str] = True,
backend: Optional[dns.asyncbackend.Backend] = None,
+ hostname: Optional[str] = None,
server_hostname: Optional[str] = None,
) -> dns.message.Message:
"""Return the response obtained after sending an asynchronous query via
if not dns.quic.have_quic:
raise NoDOQ("DNS-over-QUIC is not available.") # pragma: no cover
+ if server_hostname is not None and hostname is None:
+ hostname = server_hostname
+
q.id = 0
wire = q.to_wire()
the_connection: dns.quic.AsyncQuicConnection
async with cfactory() as context:
async with mfactory(
- context, verify_mode=verify, server_name=server_hostname
+ context,
+ verify_mode=verify,
+ server_name=server_hostname,
) as the_manager:
if not connection:
the_connection = the_manager.connect(where, port, source, source_port)
bootstrap_address: Optional[str] = None,
verify: Union[bool, str] = True,
want_get: bool = False,
+ h3: bool = False,
):
super().__init__()
self.url = url
self.bootstrap_address = bootstrap_address
self.verify = verify
self.want_get = want_get
+ self.h3 = h3
def kind(self):
return "DoH"
ignore_trailing=ignore_trailing,
verify=self.verify,
post=(not self.want_get),
+ h3=self.h3,
)
async def async_query(
ignore_trailing=ignore_trailing,
verify=self.verify,
post=(not self.want_get),
+ h3=self.h3,
)
import errno
import os
import os.path
+import random
import selectors
import socket
import struct
import time
-from typing import Any, Dict, Optional, Tuple, Union
+import urllib.parse
+from typing import Any, Dict, List, Optional, Tuple, Union
import dns._features
import dns.exception
raise
+def _maybe_get_resolver(
+ resolver: Optional["dns.resolver.Resolver"],
+) -> "dns.resolver.Resolver":
+ # We need a separate method for this to avoid overriding the global
+ # variable "dns" with the as-yet undefined local variable "dns"
+ # in https().
+ if resolver is None:
+ # pylint: disable=import-outside-toplevel,redefined-outer-name
+ import dns.resolver
+
+ resolver = dns.resolver.Resolver()
+ return resolver
+
+
def https(
q: dns.message.Message,
where: str,
bootstrap_address: Optional[str] = None,
verify: Union[bool, str] = True,
resolver: Optional["dns.resolver.Resolver"] = None,
- family: Optional[int] = socket.AF_UNSPEC,
+ family: int = socket.AF_UNSPEC,
+ h3: bool = False,
) -> dns.message.Message:
"""Return the response obtained after sending a query via DNS-over-HTTPS.
*family*, an ``int``, the address family. If socket.AF_UNSPEC (the default), both A
and AAAA records will be retrieved.
+ *h3*, a ``bool``. If ``True``, use HTTP/3 otherwise use HTTP/2 or HTTP/1.1.
+
Returns a ``dns.message.Message``.
"""
- if not have_doh:
- raise NoDOH # pragma: no cover
- if session and not isinstance(session, httpx.Client):
- raise ValueError("session parameter must be an httpx.Client")
-
- wire = q.to_wire()
(af, _, the_source) = _destination_and_source(
where, port, source, source_port, False
)
- transport = None
- headers = {"accept": "application/dns-message"}
if af is not None and dns.inet.is_address(where):
if af == socket.AF_INET:
url = "https://{}:{}{}".format(where, port, path)
else:
url = where
+ if h3:
+ if bootstrap_address is None:
+ parsed = urllib.parse.urlparse(url)
+ resolver = _maybe_get_resolver(resolver)
+ if parsed.hostname is None:
+ raise ValueError("no hostname in URL")
+ answers = resolver.resolve_name(parsed.hostname, family)
+ bootstrap_address = random.choice(list(answers.addresses()))
+ if parsed.port is not None:
+ port = parsed.port
+ return _http3(
+ q,
+ bootstrap_address,
+ url,
+ timeout,
+ port,
+ source,
+ source_port,
+ one_rr_per_rrset,
+ ignore_trailing,
+ verify=verify,
+ post=post,
+ )
+
+ if not have_doh:
+ raise NoDOH # pragma: no cover
+ if session and not isinstance(session, httpx.Client):
+ raise ValueError("session parameter must be an httpx.Client")
+
+ wire = q.to_wire()
+ transport = None
+ headers = {"accept": "application/dns-message"}
+
# set source port and source address
if the_source is None:
return r
+def _find_header(headers: dns.quic.Headers, name: bytes) -> bytes:
+ if headers is None:
+ raise KeyError
+ for header, value in headers:
+ if header == name:
+ return value
+ raise KeyError
+
+
+def _check_status(headers: dns.quic.Headers, peer: str, wire: bytes) -> None:
+ value = _find_header(headers, b":status")
+ if value is None:
+ raise SyntaxError("no :status header in response")
+ status = int(value)
+ if status < 0:
+ raise SyntaxError("status is negative")
+ if status < 200 or status > 299:
+ error = ""
+ if len(wire) > 0:
+ try:
+ error = ": " + wire.decode()
+ except Exception:
+ pass
+ raise ValueError(f"{peer} responded with status code {status}{error}")
+
+
+def _http3(
+ q: dns.message.Message,
+ where: str,
+ url: str,
+ timeout: Optional[float] = None,
+ port: int = 853,
+ source: Optional[str] = None,
+ source_port: int = 0,
+ one_rr_per_rrset: bool = False,
+ ignore_trailing: bool = False,
+ verify: Union[bool, str] = True,
+ hostname: Optional[str] = None,
+ post: bool = True,
+) -> dns.message.Message:
+ if not dns.quic.have_quic:
+ raise NoDOH("DNS-over-HTTP3 is not available.") # pragma: no cover
+
+ url_parts = urllib.parse.urlparse(url)
+ hostname = url_parts.hostname
+
+ q.id = 0
+ wire = q.to_wire()
+ manager = dns.quic.SyncQuicManager(
+ verify_mode=verify, server_name=hostname, h3=True
+ )
+
+ with manager:
+ connection = manager.connect(where, port, source, source_port)
+ (start, expiration) = _compute_times(timeout)
+ with connection.make_stream(timeout) as stream:
+ stream.send_h3(url, wire, post)
+ wire = stream.receive(_remaining(expiration))
+ _check_status(stream.headers(), where, wire)
+ finish = time.time()
+ r = dns.message.from_wire(
+ wire,
+ keyring=q.keyring,
+ request_mac=q.request_mac,
+ one_rr_per_rrset=one_rr_per_rrset,
+ ignore_trailing=ignore_trailing,
+ )
+ r.time = max(finish - start, 0.0)
+ if not q.is_response(r):
+ raise BadResponse
+ return r
+
+
def _udp_recv(sock, max_size, expiration):
"""Reads a datagram from the socket.
A Timeout exception will be raised if the operation is not completed
ignore_trailing: bool = False,
connection: Optional[dns.quic.SyncQuicConnection] = None,
verify: Union[bool, str] = True,
+ hostname: Optional[str] = None,
server_hostname: Optional[str] = None,
) -> dns.message.Message:
"""Return the response obtained after sending a query via DNS-over-QUIC.
*ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the
received message.
- *connection*, a ``dns.quic.SyncQuicConnection``. If provided, the
- connection to use to send the query.
+ *connection*, a ``dns.quic.SyncQuicConnection``. If provided, the connection to use
+ to send the query.
*verify*, a ``bool`` or ``str``. If a ``True``, then TLS certificate verification
of the server is done using the default CA bundle; if ``False``, then no
verification is done; if a `str` then it specifies the path to a certificate file or
directory which will be used for verification.
- *server_hostname*, a ``str`` containing the server's hostname. The
- default is ``None``, which means that no hostname is known, and if an
- SSL context is created, hostname checking will be disabled.
+ *hostname*, a ``str`` containing the server's hostname or ``None``. The default is
+ ``None``, which means that no hostname is known, and if an SSL context is created,
+ hostname checking will be disabled. This value is ignored if *url* is not
+ ``None``.
+
+ *server_hostname*, a ``str`` or ``None``. This item is for backwards compatibility
+ only, and has the same meaning as *hostname*.
Returns a ``dns.message.Message``.
"""
if not dns.quic.have_quic:
raise NoDOQ("DNS-over-QUIC is not available.") # pragma: no cover
+ if server_hostname is not None and hostname is None:
+ hostname = server_hostname
+
q.id = 0
wire = q.to_wire()
the_connection: dns.quic.SyncQuicConnection
manager: contextlib.AbstractContextManager = contextlib.nullcontext(None)
the_connection = connection
else:
- manager = dns.quic.SyncQuicManager(
- verify_mode=verify, server_name=server_hostname
- )
+ manager = dns.quic.SyncQuicManager(verify_mode=verify, server_name=hostname)
the_manager = manager # for type checking happiness
with manager:
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+from typing import List, Tuple
+
import dns._features
import dns.asyncbackend
class SyncQuicConnection: # type: ignore
def make_stream(self) -> Any:
raise NotImplementedError
+
+
+Headers = List[Tuple[bytes, bytes]]
raise dns.exception.Timeout
self._expecting = 0
+ async def wait_for_end(self, expiration):
+ while True:
+ timeout = self._timeout_from_expiration(expiration)
+ if self._buffer.seen_end():
+ return
+ try:
+ await asyncio.wait_for(self._wait_for_wake_up(), timeout)
+ except TimeoutError:
+ raise dns.exception.Timeout
+
async def receive(self, timeout=None):
expiration = self._expiration_from_timeout(timeout)
- await self.wait_for(2, expiration)
- (size,) = struct.unpack("!H", self._buffer.get(2))
- await self.wait_for(size, expiration)
- return self._buffer.get(size)
+ if self._connection.is_h3():
+ await self.wait_for_end(expiration)
+ return self._buffer.get_all()
+ else:
+ await self.wait_for(2, expiration)
+ (size,) = struct.unpack("!H", self._buffer.get(2))
+ await self.wait_for(size, expiration)
+ return self._buffer.get(size)
async def send(self, datagram, is_end=False):
data = self._encapsulate(datagram)
if event is None:
return
if isinstance(event, aioquic.quic.events.StreamDataReceived):
- stream = self._streams.get(event.stream_id)
- if stream:
- await stream._add_input(event.data, event.end_stream)
+ if self.is_h3():
+ h3_events = self._h3_conn.handle_event(event)
+ for h3_event in h3_events:
+ if isinstance(h3_event, aioquic.h3.events.HeadersReceived):
+ stream = self._streams.get(event.stream_id)
+ if stream:
+ if stream._headers is None:
+ stream._headers = h3_event.headers
+ elif stream._trailers is None:
+ stream._trailers = h3_event.headers
+ if h3_event.stream_ended:
+ await stream._add_input(b"", True)
+ elif isinstance(h3_event, aioquic.h3.events.DataReceived):
+ stream = self._streams.get(event.stream_id)
+ if stream:
+ await stream._add_input(
+ h3_event.data, h3_event.stream_ended
+ )
+ else:
+ stream = self._streams.get(event.stream_id)
+ if stream:
+ await stream._add_input(event.data, event.end_stream)
elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
self._handshake_complete.set()
elif isinstance(event, aioquic.quic.events.ConnectionTerminated):
class AsyncioQuicManager(AsyncQuicManager):
- def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None):
- super().__init__(conf, verify_mode, AsyncioQuicConnection, server_name)
+ def __init__(
+ self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None, h3=False
+ ):
+ super().__init__(conf, verify_mode, AsyncioQuicConnection, server_name, h3)
def connect(
self, address, port=853, source=None, source_port=0, want_session_ticket=True
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+import base64
import copy
import functools
import socket
import struct
import time
+import urllib
from typing import Any, Optional
+import aioquic.h3.connection # type: ignore
+import aioquic.h3.events # type: ignore
import aioquic.quic.configuration # type: ignore
import aioquic.quic.connection # type: ignore
self._buffer = self._buffer[amount:]
return data
+ def get_all(self):
+ assert self.seen_end()
+ data = self._buffer
+ self._buffer = b""
+ return data
+
class BaseQuicStream:
def __init__(self, connection, stream_id):
self._stream_id = stream_id
self._buffer = Buffer()
self._expecting = 0
+ self._headers = None
+ self._trailers = None
def id(self):
return self._stream_id
+ def headers(self):
+ return self._headers
+
+ def trailers(self):
+ return self._trailers
+
def _expiration_from_timeout(self, timeout):
if timeout is not None:
expiration = time.time() + timeout
return timeout
# Subclass must implement receive() as sync / async and which returns a message
- # or raises UnexpectedEOF.
+ # or raises.
+
+ # Subclass must implement send() as sync / async and which takes a message and
+ # an EOF indicator.
+
+ def send_h3(self, url, datagram, post=True):
+ if not self._connection.is_h3():
+ raise SyntaxError("cannot send H3 to a non-H3 connection")
+ url_parts = urllib.parse.urlparse(url)
+ path = url_parts.path.encode()
+ if post:
+ method = b"POST"
+ else:
+ method = b"GET"
+ path += b"?dns=" + base64.urlsafe_b64encode(datagram).rstrip(b"=")
+ headers = [
+ (b":method", method),
+ (b":scheme", url_parts.scheme.encode()),
+ (b":authority", url_parts.netloc.encode()),
+ (b":path", path),
+ (b"accept", b"application/dns-message"),
+ ]
+ if post:
+ headers.extend(
+ [
+ (b"content-type", b"application/dns-message"),
+ (b"content-length", str(len(datagram)).encode()),
+ ]
+ )
+ self._connection.send_headers(self._stream_id, headers, not post)
+ if post:
+ self._connection.send_data(self._stream_id, datagram, True)
def _encapsulate(self, datagram):
+ if self._connection.is_h3():
+ return datagram
l = len(datagram)
return struct.pack("!H", l) + datagram
def _common_add_input(self, data, is_end):
self._buffer.put(data, is_end)
try:
- return self._expecting > 0 and self._buffer.have(self._expecting)
+ return (
+ self._expecting > 0 and self._buffer.have(self._expecting)
+ ) or self._buffer.seen_end
except UnexpectedEOF:
return True
class BaseQuicConnection:
def __init__(
- self, connection, address, port, source=None, source_port=0, manager=None
+ self,
+ connection,
+ address,
+ port,
+ source=None,
+ source_port=0,
+ manager=None,
):
self._done = False
self._connection = connection
self._closed = False
self._manager = manager
self._streams = {}
+ if manager.is_h3():
+ self._h3_conn = aioquic.h3.connection.H3Connection(connection, False)
+ else:
+ self._h3_conn = None
self._af = dns.inet.af_for_address(address)
self._peer = dns.inet.low_level_address_tuple((address, port))
if source is None and source_port != 0:
else:
self._source = None
+ def is_h3(self):
+ return self._h3_conn is not None
+
def close_stream(self, stream_id):
del self._streams[stream_id]
+ def send_headers(self, stream_id, headers, is_end=False):
+ self._h3_conn.send_headers(stream_id, headers, is_end)
+
+ def send_data(self, stream_id, data, is_end=False):
+ self._h3_conn.send_data(stream_id, data, is_end)
+
def _get_timer_values(self, closed_is_special=True):
now = time.time()
expiration = self._connection.get_timer()
class BaseQuicManager:
- def __init__(self, conf, verify_mode, connection_factory, server_name=None):
+ def __init__(
+ self, conf, verify_mode, connection_factory, server_name=None, h3=False
+ ):
self._connections = {}
self._connection_factory = connection_factory
self._session_tickets = {}
+ self._h3 = h3
if conf is None:
verify_path = None
if isinstance(verify_mode, str):
verify_path = verify_mode
verify_mode = True
+ if h3:
+ alpn_protocols = ["h3"]
+ else:
+ alpn_protocols = ["doq", "doq-i03"]
conf = aioquic.quic.configuration.QuicConfiguration(
- alpn_protocols=["doq", "doq-i03"],
+ alpn_protocols=alpn_protocols,
verify_mode=verify_mode,
server_name=server_name,
)
self._conf = conf
def _connect(
- self, address, port=853, source=None, source_port=0, want_session_ticket=True
+ self,
+ address,
+ port=853,
+ source=None,
+ source_port=0,
+ want_session_ticket=True,
):
connection = self._connections.get((address, port))
if connection is not None:
except KeyError:
pass
+ def is_h3(self):
+ return self._h3
+
def save_session_ticket(self, address, port, ticket):
# We rely on dictionaries keys() being in insertion order here. We
# can't just popitem() as that would be LIFO which is the opposite of
raise dns.exception.Timeout
self._expecting = 0
+ def wait_for_end(self, expiration):
+ while True:
+ timeout = self._timeout_from_expiration(expiration)
+ with self._lock:
+ if self._buffer.seen_end():
+ return
+ with self._wake_up:
+ if not self._wake_up.wait(timeout):
+ raise dns.exception.Timeout
+
def receive(self, timeout=None):
expiration = self._expiration_from_timeout(timeout)
- self.wait_for(2, expiration)
- with self._lock:
- (size,) = struct.unpack("!H", self._buffer.get(2))
- self.wait_for(size, expiration)
- with self._lock:
- return self._buffer.get(size)
+ if self._connection.is_h3():
+ self.wait_for_end(expiration)
+ with self._lock:
+ return self._buffer.get_all()
+ else:
+ self.wait_for(2, expiration)
+ with self._lock:
+ (size,) = struct.unpack("!H", self._buffer.get(2))
+ self.wait_for(size, expiration)
+ with self._lock:
+ return self._buffer.get(size)
def send(self, datagram, is_end=False):
data = self._encapsulate(datagram)
if event is None:
return
if isinstance(event, aioquic.quic.events.StreamDataReceived):
- with self._lock:
- stream = self._streams.get(event.stream_id)
- if stream:
- stream._add_input(event.data, event.end_stream)
+ if self.is_h3():
+ h3_events = self._h3_conn.handle_event(event)
+ for h3_event in h3_events:
+ if isinstance(h3_event, aioquic.h3.events.HeadersReceived):
+ with self._lock:
+ stream = self._streams.get(event.stream_id)
+ if stream:
+ if stream._headers is None:
+ stream._headers = h3_event.headers
+ elif stream._trailers is None:
+ stream._trailers = h3_event.headers
+ if h3_event.stream_ended:
+ stream._add_input(b"", True)
+ elif isinstance(h3_event, aioquic.h3.events.DataReceived):
+ with self._lock:
+ stream = self._streams.get(event.stream_id)
+ if stream:
+ stream._add_input(h3_event.data, h3_event.stream_ended)
+ else:
+ with self._lock:
+ stream = self._streams.get(event.stream_id)
+ if stream:
+ stream._add_input(event.data, event.end_stream)
elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
self._handshake_complete.set()
elif isinstance(event, aioquic.quic.events.ConnectionTerminated):
self._connection.send_stream_data(stream, data, is_end)
self._send_wakeup.send(b"\x01")
+ def send_headers(self, stream_id, headers, is_end=False):
+ with self._lock:
+ super().send_headers(stream_id, headers, is_end)
+ if is_end:
+ self._send_wakeup.send(b"\x01")
+
+ def send_data(self, stream_id, data, is_end=False):
+ with self._lock:
+ super().send_data(stream_id, data, is_end)
+ if is_end:
+ self._send_wakeup.send(b"\x01")
+
def run(self):
if self._closed:
return
class SyncQuicManager(BaseQuicManager):
- def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None):
- super().__init__(conf, verify_mode, SyncQuicConnection, server_name)
+ def __init__(
+ self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None, h3=False
+ ):
+ super().__init__(conf, verify_mode, SyncQuicConnection, server_name, h3)
self._lock = threading.Lock()
def connect(
- self, address, port=853, source=None, source_port=0, want_session_ticket=True
+ self,
+ address,
+ port=853,
+ source=None,
+ source_port=0,
+ want_session_ticket=True,
):
with self._lock:
(connection, start) = self._connect(
await self._wake_up.wait()
self._expecting = 0
+ async def wait_for_end(self):
+ while True:
+ if self._buffer.seen_end():
+ return
+ async with self._wake_up:
+ await self._wake_up.wait()
+
async def receive(self, timeout=None):
if timeout is None:
context = NullContext(None)
else:
context = trio.move_on_after(timeout)
with context:
- await self.wait_for(2)
- (size,) = struct.unpack("!H", self._buffer.get(2))
- await self.wait_for(size)
- return self._buffer.get(size)
+ if self._connection.is_h3():
+ await self.wait_for_end()
+ return self._buffer.get_all()
+ else:
+ await self.wait_for(2)
+ (size,) = struct.unpack("!H", self._buffer.get(2))
+ await self.wait_for(size)
+ return self._buffer.get(size)
raise dns.exception.Timeout
async def send(self, datagram, is_end=False):
if event is None:
return
if isinstance(event, aioquic.quic.events.StreamDataReceived):
- stream = self._streams.get(event.stream_id)
- if stream:
- await stream._add_input(event.data, event.end_stream)
+ if self.is_h3():
+ h3_events = self._h3_conn.handle_event(event)
+ for h3_event in h3_events:
+ if isinstance(h3_event, aioquic.h3.events.HeadersReceived):
+ stream = self._streams.get(event.stream_id)
+ if stream:
+ if stream._headers is None:
+ stream._headers = h3_event.headers
+ elif stream._trailers is None:
+ stream._trailers = h3_event.headers
+ if h3_event.stream_ended:
+ await stream._add_input(b"", True)
+ elif isinstance(h3_event, aioquic.h3.events.DataReceived):
+ stream = self._streams.get(event.stream_id)
+ if stream:
+ await stream._add_input(
+ h3_event.data, h3_event.stream_ended
+ )
+ else:
+ stream = self._streams.get(event.stream_id)
+ if stream:
+ await stream._add_input(event.data, event.end_stream)
elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
self._handshake_complete.set()
elif isinstance(event, aioquic.quic.events.ConnectionTerminated):
class TrioQuicManager(AsyncQuicManager):
def __init__(
- self, nursery, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None
+ self,
+ nursery,
+ conf=None,
+ verify_mode=ssl.CERT_REQUIRED,
+ server_name=None,
+ h3=False,
):
- super().__init__(conf, verify_mode, TrioQuicConnection, server_name)
+ super().__init__(conf, verify_mode, TrioQuicConnection, server_name, h3)
self._nursery = nursery
def connect(
import dns.message
import dns.name
import dns.query
+import dns.quic
import dns.rcode
import dns.rdataclass
import dns.rdatatype
# 'https://dns11.quad9.net/dns-query',
]
+KNOWN_ANYCAST_DOH3_RESOLVER_URLS = [
+ "https://cloudflare-dns.com/dns-query",
+ "https://dns.google/dns-query",
+]
+
class AsyncDetectionTests(unittest.TestCase):
sniff_result = "asyncio"
self.async_run(run)
+ @unittest.skipIf(not dns.quic.have_quic, "aioquic not available")
+ def testDoH3GetRequest(self):
+ async def run():
+ nameserver_url = random.choice(KNOWN_ANYCAST_DOH3_RESOLVER_URLS)
+ q = dns.message.make_query("dns.google.", dns.rdatatype.A)
+ r = await dns.asyncquery.https(
+ q,
+ nameserver_url,
+ post=False,
+ timeout=4,
+ family=family,
+ h3=True,
+ )
+ self.assertTrue(q.is_response(r))
+
+ self.async_run(run)
+
+ @unittest.skipIf(not dns.quic.have_quic, "aioquic not available")
+ def TestDoH3PostRequest(self):
+ async def run():
+ nameserver_url = random.choice(KNOWN_ANYCAST_DOH3_RESOLVER_URLS)
+ q = dns.message.make_query("dns.google.", dns.rdatatype.A)
+ r = await dns.asyncquery.https(
+ q,
+ nameserver_url,
+ post=True,
+ timeout=4,
+ family=family,
+ h3=True,
+ )
+ self.assertTrue(q.is_response(r))
+
+ self.async_run(run)
+
@unittest.skipIf(not dns.query._have_httpx, "httpx not available")
def testResolverDOH(self):
async def run():
import dns.edns
import dns.message
import dns.query
+import dns.quic
import dns.rdatatype
import dns.resolver
# 'https://dns11.quad9.net/dns-query',
]
+KNOWN_ANYCAST_DOH3_RESOLVER_URLS = [
+ "https://cloudflare-dns.com/dns-query",
+ "https://dns.google/dns-query",
+]
+
KNOWN_PAD_AWARE_DOH_RESOLVER_URLS = [
"https://cloudflare-dns.com/dns-query",
"https://dns.google/dns-query",
self.assertTrue(has_pad)
+@unittest.skipUnless(
+ dns.quic.have_quic and tests.util.is_internet_reachable() and _have_ssl,
+ "Aioquic cannot be imported; no DNS over HTTP3 (DOH3)",
+)
+class DNSOverHTTP3TestCase(unittest.TestCase):
+ def testDoH3GetRequest(self):
+ nameserver_url = random.choice(KNOWN_ANYCAST_DOH3_RESOLVER_URLS)
+ q = dns.message.make_query("dns.google.", dns.rdatatype.A)
+ r = dns.query.https(
+ q,
+ nameserver_url,
+ post=False,
+ timeout=4,
+ family=family,
+ h3=True,
+ )
+ self.assertTrue(q.is_response(r))
+
+ def testDoH3PostRequest(self):
+ nameserver_url = random.choice(KNOWN_ANYCAST_DOH3_RESOLVER_URLS)
+ q = dns.message.make_query("dns.google.", dns.rdatatype.A)
+ r = dns.query.https(
+ q,
+ nameserver_url,
+ post=True,
+ timeout=4,
+ family=family,
+ h3=True,
+ )
+ self.assertTrue(q.is_response(r))
+
+
if __name__ == "__main__":
unittest.main()