From: Bob Halley Date: Sat, 1 Oct 2022 22:24:13 +0000 (-0700) Subject: Initial DoQ support. X-Git-Tag: v2.3.0rc1~31^2~2 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=948d5a6a2b3eb08647acaf2a67df92e9bcc2979a;p=thirdparty%2Fdnspython.git Initial DoQ support. --- diff --git a/dns/__init__.py b/dns/__init__.py index 196be22d..9abdf018 100644 --- a/dns/__init__.py +++ b/dns/__init__.py @@ -38,6 +38,7 @@ __all__ = [ "node", "opcode", "query", + "quic", "rcode", "rdata", "rdataclass", diff --git a/dns/asyncquery.py b/dns/asyncquery.py index 107f7667..15612e58 100644 --- a/dns/asyncquery.py +++ b/dns/asyncquery.py @@ -30,6 +30,7 @@ import dns.exception import dns.inet import dns.name import dns.message +import dns.quic import dns.rcode import dns.rdataclass import dns.rdatatype @@ -45,6 +46,7 @@ from dns.query import ( _have_httpx, _have_http2, NoDOH, + NoDOQ, ) if _have_httpx: @@ -670,3 +672,62 @@ async def inbound_xfr( tsig_ctx = r.tsig_ctx if not retry and query.keyring and not r.had_tsig: raise dns.exception.FormError("missing TSIG") + + +async def quic( + q: dns.message.Message, + where: str, + timeout: Optional[float] = None, + port: int = 853, + source: Optional[str] = None, + source_port: int = 0, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + connection: Optional[dns.quic.AsyncQuicConnection] = None, + verify: Union[bool, str] = True, + backend: Optional[dns.asyncbackend.Backend] = None, +) -> dns.message.Message: + """Return the response obtained after sending an asynchronous query via + DNS-over-QUIC. + + *backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``, + the default, then dnspython will use the default backend. + + See :py:func:`dns.query.doq()` for the documentation of the other + parameters, exceptions, and return type of this method. + """ + + if not dns.quic.have_quic: + raise NoDOQ("DNS-over-QUIC is not available.") # pragma: no cover + + q.id = 0 + wire = q.to_wire() + the_connection: dns.quic.AsyncQuicConnection + if connection: + cfactory = dns.quic.null_factory + mfactory = dns.quic.null_factory + the_connection = connection + else: + (cfactory, mfactory) = dns.quic.factories_for_backend(backend) + + async with cfactory() as context: + async with mfactory(context, verify_mode=verify) as the_manager: + if not connection: + the_connection = the_manager.connect(where, port, source, source_port) + start = time.time() + stream = await the_connection.make_stream() + async with stream: + await stream.send(wire, True) + wire = await stream.receive(timeout) + finish = time.time() + r = dns.message.from_wire( + wire, + keyring=q.keyring, + request_mac=q.request_mac, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + ) + r.time = max(finish - start, 0.0) + if not q.is_response(r): + raise BadResponse + return r diff --git a/dns/dnssec.py b/dns/dnssec.py index b325f9f8..13415bdf 100644 --- a/dns/dnssec.py +++ b/dns/dnssec.py @@ -269,7 +269,7 @@ def _validate_signature(sig: bytes, data: bytes, key: DNSKEY, chosen_hash: Any) keyptr = keyptr[octets:] dsa_y = keyptr[0:octets] try: - dsa_public_key = dsa.DSAPublicNumbers( + dsa_public_key = dsa.DSAPublicNumbers( # type: ignore _bytes_to_long(dsa_y), dsa.DSAParameterNumbers( _bytes_to_long(dsa_p), _bytes_to_long(dsa_q), _bytes_to_long(dsa_g) diff --git a/dns/query.py b/dns/query.py index 5663e234..871b67f4 100644 --- a/dns/query.py +++ b/dns/query.py @@ -34,6 +34,7 @@ import dns.exception import dns.inet import dns.name import dns.message +import dns.quic import dns.rcode import dns.rdataclass import dns.rdatatype @@ -108,6 +109,11 @@ class NoDOH(dns.exception.DNSException): available.""" +class NoDOQ(dns.exception.DNSException): + """DNS over QUIC (DOQ) was requested but the aioquic module is not + available.""" + + # for backwards compatibility TransferError = dns.xfr.TransferError @@ -1059,6 +1065,86 @@ def tls( ) +def quic( + q: dns.message.Message, + where: str, + timeout: Optional[float] = None, + port: int = 853, + source: Optional[str] = None, + source_port: int = 0, + one_rr_per_rrset: bool = False, + ignore_trailing: bool = False, + connection: Optional[dns.quic.SyncQuicConnection] = None, + verify: Union[bool, str] = True, +) -> dns.message.Message: + """Return the response obtained after sending a query via DNS-over-QUIC. + + *q*, a ``dns.message.Message``, the query to send. + + *where*, a ``str``, the nameserver IP address. + + *timeout*, a ``float`` or ``None``, the number of seconds to wait before the query + times out. If ``None``, the default, wait forever. + + *port*, a ``int``, the port to send the query to. The default is 443. + + *source*, a ``str`` containing an IPv4 or IPv6 address, specifying the source + address. The default is the wildcard address. + + *source_port*, an ``int``, the port from which to send the message. The default is + 0. + + *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own RRset. + + *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the + received message. + + *connection*, a ``dns.quic.SyncQuicConnection``. If provided, the + connection to use to send the query. + + *verify*, a ``bool`` or ``str``. If a ``True``, then TLS certificate verification + of the server is done using the default CA bundle; if ``False``, then no + verification is done; if a `str` then it specifies the path to a certificate file or + directory which will be used for verification. + + Returns a ``dns.message.Message``. + """ + + if not dns.quic.have_quic: + raise NoDOQ("DNS-over-QUIC is not available.") # pragma: no cover + + q.id = 0 + wire = q.to_wire() + the_connection: dns.quic.SyncQuicConnection + the_manager: dns.quic.SyncQuicManager + if connection: + manager: contextlib.AbstractContextManager = contextlib.nullcontext(None) + the_connection = connection + else: + manager = dns.quic.SyncQuicManager(verify_mode=verify) + the_manager = manager # for type checking happiness + + with manager: + if not connection: + the_connection = the_manager.connect(where, port) + start = time.time() + with the_connection.make_stream() as stream: + stream.send(wire, True) + wire = stream.receive(timeout) + finish = time.time() + r = dns.message.from_wire( + wire, + keyring=q.keyring, + request_mac=q.request_mac, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + ) + r.time = max(finish - start, 0.0) + if not q.is_response(r): + raise BadResponse + return r + + def xfr( where: str, zone: Union[dns.name.Name, str], diff --git a/dns/quic/__init__.py b/dns/quic/__init__.py new file mode 100644 index 00000000..88c58624 --- /dev/null +++ b/dns/quic/__init__.py @@ -0,0 +1,65 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +try: + import aioquic.quic.configuration # type: ignore + + import dns.asyncbackend + from dns._asyncbackend import NullContext + from dns.quic._sync import SyncQuicManager, SyncQuicConnection, SyncQuicStream + from dns.quic._asyncio import ( + AsyncioQuicManager, + AsyncioQuicConnection, + AsyncioQuicStream, + ) + from dns.quic._common import AsyncQuicConnection, AsyncQuicManager + + have_quic = True + + def null_factory(*args, **kwargs): + return NullContext(None) + + def _asyncio_manager_factory(context, *args, **kwargs): + return AsyncioQuicManager(*args, **kwargs) + + # We have a context factory and a manager factory as for trio we need to have + # a nursery. + + _async_factories = {"asyncio": (null_factory, _asyncio_manager_factory)} + + try: + import trio + from dns.quic._trio import TrioQuicManager, TrioQuicConnection, TrioQuicStream + + def _trio_context_factory(): + return trio.open_nursery() + + def _trio_manager_factory(context, *args, **kwargs): + return TrioQuicManager(context, *args, **kwargs) + + _async_factories["trio"] = (_trio_context_factory, _trio_manager_factory) + except ImportError: + pass + + def factories_for_backend(backend=None): + if backend is None: + backend = dns.asyncbackend.get_default_backend() + return _async_factories[backend.name()] + +except ImportError: + have_quic = False + + from typing import Any + + class AsyncQuicStream: # type: ignore + pass + + class AsyncQuicConnection: # type: ignore + async def make_stream(self) -> Any: + raise NotImplementedError + + class SyncQuicStream: # type: ignore + pass + + class SyncQuicConnection: # type: ignore + def make_stream(self) -> Any: + raise NotImplementedError diff --git a/dns/quic/_asyncio.py b/dns/quic/_asyncio.py new file mode 100644 index 00000000..0a2e220d --- /dev/null +++ b/dns/quic/_asyncio.py @@ -0,0 +1,206 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import asyncio +import socket +import ssl +import struct +import time + +import aioquic.quic.configuration # type: ignore +import aioquic.quic.connection # type: ignore +import aioquic.quic.events # type: ignore +import dns.inet +import dns.asyncbackend + +from dns.quic._common import ( + BaseQuicStream, + AsyncQuicConnection, + AsyncQuicManager, + QUIC_MAX_DATAGRAM, +) + + +class AsyncioQuicStream(BaseQuicStream): + def __init__(self, connection, stream_id): + super().__init__(connection, stream_id) + self._wake_up = asyncio.Condition() + + async def _wait_for_wake_up(self): + async with self._wake_up: + await self._wake_up.wait() + + async def wait_for(self, amount, expiration): + timeout = self._timeout_from_expiration(expiration) + while True: + if self._buffer.have(amount): + return + self._expecting = amount + try: + await asyncio.wait_for(self._wait_for_wake_up(), timeout) + except Exception: + pass + self._expecting = 0 + + async def receive(self, timeout=None): + expiration = self._expiration_from_timeout(timeout) + await self.wait_for(2, expiration) + (size,) = struct.unpack("!H", self._buffer.get(2)) + await self.wait_for(size, expiration) + return self._buffer.get(size) + + async def send(self, datagram, is_end=False): + data = self._encapsulate(datagram) + await self._connection.write(self._stream_id, data, is_end) + + async def _add_input(self, data, is_end): + if self._common_add_input(data, is_end): + async with self._wake_up: + self._wake_up.notify() + + async def close(self): + self._close() + + # Streams are async context managers + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.close() + async with self._wake_up: + self._wake_up.notify() + return False + + +class AsyncioQuicConnection(AsyncQuicConnection): + def __init__(self, connection, address, port, source, source_port, manager=None): + super().__init__(connection, address, port, source, source_port, manager) + self._socket = None + self._handshake_complete = asyncio.Event() + self._socket_created = asyncio.Event() + self._wake_timer = asyncio.Condition() + self._receiver_task = None + self._sender_task = None + + async def _receiver(self): + try: + af = dns.inet.af_for_address(self._address) + backend = dns.asyncbackend.get_backend("asyncio") + self._socket = await backend.make_socket( + af, socket.SOCK_DGRAM, 0, self._source, self._peer + ) + self._socket_created.set() + async with self._socket: + while not self._done: + (datagram, address) = await self._socket.recvfrom( + QUIC_MAX_DATAGRAM, None + ) + if address[0] != self._peer[0] or address[1] != self._peer[1]: + continue + self._connection.receive_datagram( + datagram, self._peer[0], time.time() + ) + # Wake up the timer in case the sender is sleeping, as there may be + # stuff to send now. + async with self._wake_timer: + self._wake_timer.notify_all() + except Exception: + pass + + async def _wait_for_wake_timer(self): + async with self._wake_timer: + await self._wake_timer.wait() + + async def _sender(self): + await self._socket_created.wait() + while not self._done: + datagrams = self._connection.datagrams_to_send(time.time()) + for (datagram, address) in datagrams: + assert address == self._peer[0] + await self._socket.sendto(datagram, self._peer, None) + (expiration, interval) = self._get_timer_values() + try: + await asyncio.wait_for(self._wait_for_wake_timer(), interval) + except Exception: + pass + self._handle_timer(expiration) + await self._handle_events() + + async def _handle_events(self): + count = 0 + while True: + event = self._connection.next_event() + if event is None: + return + if isinstance(event, aioquic.quic.events.StreamDataReceived): + stream = self._streams.get(event.stream_id) + if stream: + await stream._add_input(event.data, event.end_stream) + elif isinstance(event, aioquic.quic.events.HandshakeCompleted): + self._handshake_complete.set() + elif isinstance( + event, aioquic.quic.events.ConnectionTerminated + ) or isinstance(event, aioquic.quic.events.StreamReset): + self._done = True + self._receiver_task.cancel() + count += 1 + if count > 10: + # yield + count = 0 + await asyncio.sleep(0) + + async def write(self, stream, data, is_end=False): + self._connection.send_stream_data(stream, data, is_end) + async with self._wake_timer: + self._wake_timer.notify_all() + + def run(self): + if self._closed: + return + self._receiver_task = asyncio.Task(self._receiver()) + self._sender_task = asyncio.Task(self._sender()) + + async def make_stream(self): + await self._handshake_complete.wait() + stream_id = self._connection.get_next_available_stream_id(False) + stream = AsyncioQuicStream(self, stream_id) + self._streams[stream_id] = stream + return stream + + async def close(self): + if not self._closed: + self._manager.closed(self._peer[0], self._peer[1]) + self._closed = True + self._connection.close() + async with self._wake_timer: + self._wake_timer.notify_all() + try: + await self._receiver_task + except asyncio.CancelledError: + pass + try: + await self._sender_task + except asyncio.CancelledError: + pass + + +class AsyncioQuicManager(AsyncQuicManager): + def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED): + super().__init__(conf, verify_mode, AsyncioQuicConnection) + + def connect(self, address, port=853, source=None, source_port=0): + (connection, start) = self._connect(address, port, source, source_port) + if start: + connection.run() + return connection + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + # Copy the itertor into a list as exiting things will mutate the connections + # table. + connections = list(self._connections.values()) + for connection in connections: + await connection.close() + return False diff --git a/dns/quic/_common.py b/dns/quic/_common.py new file mode 100644 index 00000000..2b14c232 --- /dev/null +++ b/dns/quic/_common.py @@ -0,0 +1,181 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import socket +import struct +import time + +import aioquic.quic.configuration # type: ignore +import aioquic.quic.connection # type: ignore +import dns.inet +from dns._asyncbackend import NullContext + +from typing import Any + + +QUIC_MAX_DATAGRAM = 2048 + + +class UnexpectedEOF(Exception): + pass + + +class Buffer: + def __init__(self): + self._buffer = b"" + self._seen_end = False + + def put(self, data, is_end): + if self._seen_end: + return + self._buffer += data + if is_end: + self._seen_end = True + + def have(self, amount): + if len(self._buffer) >= amount: + return True + if self._seen_end: + raise UnexpectedEOF + return False + + def seen_end(self): + return self._seen_end + + def get(self, amount): + assert self.have(amount) + data = self._buffer[:amount] + self._buffer = self._buffer[amount:] + return data + + +class BaseQuicStream: + def __init__(self, connection, stream_id): + self._connection = connection + self._stream_id = stream_id + self._buffer = Buffer() + self._expecting = 0 + + def id(self): + return self._stream_id + + def _expiration_from_timeout(self, timeout): + if timeout is not None: + expiration = time.time() + timeout + else: + expiration = None + return expiration + + def _timeout_from_expiration(self, expiration): + if expiration is not None: + timeout = max(expiration - time.time(), 0.0) + else: + timeout = None + return timeout + + # Subclass must implement receive() as sync / async and which returns a message + # or raises UnexpectedEOF. + + def _encapsulate(self, datagram): + l = len(datagram) + return struct.pack("!H", l) + datagram + + def _common_add_input(self, data, is_end): + self._buffer.put(data, is_end) + return self._expecting > 0 and self._buffer.have(self._expecting) + + def _close(self): + self._connection.close_stream(self._stream_id) + self._buffer.put(b"", True) # send EOF in case we haven't seen it. + + +class BaseQuicConnection: + def __init__( + self, connection, address, port, source=None, source_port=0, manager=None + ): + self._done = False + self._connection = connection + self._address = address + self._port = port + self._closed = False + self._manager = manager + self._streams = {} + self._af = dns.inet.af_for_address(address) + self._peer = dns.inet.low_level_address_tuple((address, port)) + if source is None and source_port != 0: + if self._af == socket.AF_INET: + source = "0.0.0.0" + elif self._af == socket.AF_INET6: + source = "::" + else: + raise NotImplementedError + if source: + self._source = (source, source_port) + else: + self._source = None + + def close_stream(self, stream_id): + del self._streams[stream_id] + + def _get_timer_values(self, closed_is_special=True): + now = time.time() + expiration = self._connection.get_timer() + if expiration is None: + expiration = now + 3600 # arbitrary "big" value + interval = max(expiration - now, 0) + if self._closed and closed_is_special: + # lower sleep interval to avoid a race in the closing process + # which can lead to higher latency closing due to sleeping when + # we have events. + interval = min(interval, 0.05) + return (expiration, interval) + + def _handle_timer(self, expiration): + now = time.time() + if expiration <= now: + self._connection.handle_timer(now) + + +class AsyncQuicConnection(BaseQuicConnection): + async def make_stream(self) -> Any: + pass + + +class BaseQuicManager: + def __init__(self, conf, verify_mode, connection_factory): + self._connections = {} + self._connection_factory = connection_factory + if conf is None: + verify_path = None + if isinstance(verify_mode, str): + verify_path = verify_mode + verify_mode = True + conf = aioquic.quic.configuration.QuicConfiguration( + alpn_protocols=["doq", "doq-i03"], + verify_mode=verify_mode, + ) + if verify_path is not None: + conf.load_verify_locations(verify_path) + self._conf = conf + + def _connect(self, address, port=853, source=None, source_port=0): + connection = self._connections.get((address, port)) + if connection is not None: + return (connection, False) + qconn = aioquic.quic.connection.QuicConnection(configuration=self._conf) + qconn.connect(address, time.time()) + connection = self._connection_factory( + qconn, address, port, source, source_port, self + ) + self._connections[(address, port)] = connection + return (connection, True) + + def closed(self, address, port): + try: + del self._connections[(address, port)] + except KeyError: + pass + + +class AsyncQuicManager(BaseQuicManager): + def connect(self, address, port=853, source=None, source_port=0): + raise NotImplementedError diff --git a/dns/quic/_sync.py b/dns/quic/_sync.py new file mode 100644 index 00000000..be005ba9 --- /dev/null +++ b/dns/quic/_sync.py @@ -0,0 +1,214 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import socket +import ssl +import selectors +import struct +import threading +import time + +import aioquic.quic.configuration # type: ignore +import aioquic.quic.connection # type: ignore +import aioquic.quic.events # type: ignore +import dns.inet + +from dns.quic._common import ( + BaseQuicStream, + BaseQuicConnection, + BaseQuicManager, + QUIC_MAX_DATAGRAM, +) + +# Avoid circularity with dns.query +if hasattr(selectors, "PollSelector"): + _selector_class = selectors.PollSelector # type: ignore +else: + _selector_class = selectors.SelectSelector # type: ignore + + +class SyncQuicStream(BaseQuicStream): + def __init__(self, connection, stream_id): + super().__init__(connection, stream_id) + self._wake_up = threading.Condition() + self._lock = threading.Lock() + + def wait_for(self, amount, expiration): + timeout = self._timeout_from_expiration(expiration) + while True: + with self._lock: + if self._buffer.have(amount): + return + self._expecting = amount + with self._wake_up: + self._wake_up.wait(timeout) + self._expecting = 0 + + def receive(self, timeout=None): + expiration = self._expiration_from_timeout(timeout) + self.wait_for(2, expiration) + with self._lock: + (size,) = struct.unpack("!H", self._buffer.get(2)) + self.wait_for(size, expiration) + with self._lock: + return self._buffer.get(size) + + def send(self, datagram, is_end=False): + data = self._encapsulate(datagram) + self._connection.write(self._stream_id, data, is_end) + + def _add_input(self, data, is_end): + if self._common_add_input(data, is_end): + with self._wake_up: + self._wake_up.notify() + + def close(self): + with self._lock: + self._close() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + with self._wake_up: + self._wake_up.notify() + return False + + +class SyncQuicConnection(BaseQuicConnection): + def __init__(self, connection, address, port, source, source_port, manager): + super().__init__(connection, address, port, source, source_port, manager) + self._socket = socket.socket(self._af, socket.SOCK_DGRAM, 0) + self._socket.connect(self._peer) + (self._send_wakeup, self._receive_wakeup) = socket.socketpair() + self._receive_wakeup.setblocking(False) + self._socket.setblocking(False) + if self._source is not None: + try: + self._socket.bind( + dns.inet.low_level_address_tuple(self._source, self._af) + ) + except Exception: + self._socket.close() + raise + self._handshake_complete = threading.Event() + self._worker_thread = None + self._lock = threading.Lock() + + def _read(self): + count = 0 + while count < 10: + count += 1 + try: + datagram = self._socket.recv(QUIC_MAX_DATAGRAM) + except BlockingIOError: + return + with self._lock: + self._connection.receive_datagram(datagram, self._peer[0], time.time()) + + def _drain_wakeup(self): + while True: + try: + self._receive_wakeup.recv(32) + except BlockingIOError: + return + + def _worker(self): + sel = _selector_class() + sel.register(self._socket, selectors.EVENT_READ, self._read) + sel.register(self._receive_wakeup, selectors.EVENT_READ, self._drain_wakeup) + while not self._done: + (expiration, interval) = self._get_timer_values(False) + items = sel.select(interval) + for (key, _) in items: + key.data() + with self._lock: + self._handle_timer(expiration) + datagrams = self._connection.datagrams_to_send(time.time()) + for (datagram, _) in datagrams: + try: + self._socket.send(datagram) + except BlockingIOError: + # we let QUIC handle any lossage + pass + self._handle_events() + + def _handle_events(self): + while True: + with self._lock: + event = self._connection.next_event() + if event is None: + return + if isinstance(event, aioquic.quic.events.StreamDataReceived): + with self._lock: + stream = self._streams.get(event.stream_id) + if stream: + stream._add_input(event.data, event.end_stream) + elif isinstance(event, aioquic.quic.events.HandshakeCompleted): + self._handshake_complete.set() + elif isinstance( + event, aioquic.quic.events.ConnectionTerminated + ) or isinstance(event, aioquic.quic.events.StreamReset): + with self._lock: + self._done = True + + def write(self, stream, data, is_end=False): + with self._lock: + self._connection.send_stream_data(stream, data, is_end) + self._send_wakeup.send(b"\x01") + + def run(self): + if self._closed: + return + self._worker_thread = threading.Thread(target=self._worker) + self._worker_thread.start() + + def make_stream(self): + self._handshake_complete.wait() + with self._lock: + stream_id = self._connection.get_next_available_stream_id(False) + stream = SyncQuicStream(self, stream_id) + self._streams[stream_id] = stream + return stream + + def close_stream(self, stream_id): + with self._lock: + super().close_stream(stream_id) + + def close(self): + with self._lock: + if self._closed: + return + self._manager.closed(self._peer[0], self._peer[1]) + self._closed = True + self._connection.close() + self._send_wakeup.send(b"\x01") + self._worker_thread.join() + + +class SyncQuicManager(BaseQuicManager): + def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED): + super().__init__(conf, verify_mode, SyncQuicConnection) + self._lock = threading.Lock() + + def connect(self, address, port=853, source=None, source_port=0): + with self._lock: + (connection, start) = self._connect(address, port, source, source_port) + if start: + connection.run() + return connection + + def closed(self, address, port): + with self._lock: + super().closed(address, port) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + # Copy the itertor into a list as exiting things will mutate the connections + # table. + connections = list(self._connections.values()) + for connection in connections: + connection.close() + return False diff --git a/dns/quic/_trio.py b/dns/quic/_trio.py new file mode 100644 index 00000000..09b69508 --- /dev/null +++ b/dns/quic/_trio.py @@ -0,0 +1,170 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import socket +import ssl +import struct +import time + +import aioquic.quic.configuration # type: ignore +import aioquic.quic.connection # type: ignore +import aioquic.quic.events # type: ignore +import dns.inet +import trio + +from dns.quic._common import ( + BaseQuicStream, + AsyncQuicConnection, + AsyncQuicManager, + NullContext, + QUIC_MAX_DATAGRAM, +) + + +class TrioQuicStream(BaseQuicStream): + def __init__(self, connection, stream_id): + super().__init__(connection, stream_id) + self._wake_up = trio.Condition() + + async def wait_for(self, amount): + while True: + if self._buffer.have(amount): + return + self._expecting = amount + async with self._wake_up: + await self._wake_up.wait() + self._expecting = 0 + + async def receive(self, timeout=None): + if timeout is None: + context = NullContext(None) + else: + context = trio.move_on_after(timeout) + with context: + await self.wait_for(2) + (size,) = struct.unpack("!H", self._buffer.get(2)) + await self.wait_for(size) + return self._buffer.get(size) + + async def send(self, datagram, is_end=False): + data = self._encapsulate(datagram) + await self._connection.write(self._stream_id, data, is_end) + + async def _add_input(self, data, is_end): + if self._common_add_input(data, is_end): + async with self._wake_up: + self._wake_up.notify() + + async def close(self): + self._close() + + # Streams are async context managers + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.close() + async with self._wake_up: + self._wake_up.notify() + return False + + +class TrioQuicConnection(AsyncQuicConnection): + def __init__(self, connection, address, port, source, source_port, manager=None): + super().__init__(connection, address, port, source, source_port, manager) + self._socket = trio.socket.socket(self._af, socket.SOCK_DGRAM, 0) + if self._source: + trio.socket.bind(dns.inet.low_level_address_tuple(self._source, self._af)) + self._handshake_complete = trio.Event() + self._run_done = trio.Event() + self._worker_scope = None + + async def _worker(self): + await self._socket.connect(self._peer) + while not self._done: + (expiration, interval) = self._get_timer_values(False) + with trio.CancelScope( + deadline=trio.current_time() + interval + ) as self._worker_scope: + datagram = await self._socket.recv(QUIC_MAX_DATAGRAM) + self._connection.receive_datagram(datagram, self._peer[0], time.time()) + self._worker_scope = None + self._handle_timer(expiration) + datagrams = self._connection.datagrams_to_send(time.time()) + for (datagram, _) in datagrams: + await self._socket.send(datagram) + await self._handle_events() + + async def _handle_events(self): + count = 0 + while True: + event = self._connection.next_event() + if event is None: + return + if isinstance(event, aioquic.quic.events.StreamDataReceived): + stream = self._streams.get(event.stream_id) + if stream: + await stream._add_input(event.data, event.end_stream) + elif isinstance(event, aioquic.quic.events.HandshakeCompleted): + self._handshake_complete.set() + elif isinstance( + event, aioquic.quic.events.ConnectionTerminated + ) or isinstance(event, aioquic.quic.events.StreamReset): + self._done = True + self._socket.close() + count += 1 + if count > 10: + # yield + count = 0 + await trio.sleep(0) + + async def write(self, stream, data, is_end=False): + self._connection.send_stream_data(stream, data, is_end) + if self._worker_scope is not None: + self._worker_scope.cancel() + + async def run(self): + if self._closed: + return + async with trio.open_nursery() as nursery: + nursery.start_soon(self._worker) + self._run_done.set() + + async def make_stream(self): + await self._handshake_complete.wait() + stream_id = self._connection.get_next_available_stream_id(False) + stream = TrioQuicStream(self, stream_id) + self._streams[stream_id] = stream + return stream + + async def close(self): + if not self._closed: + self._manager.closed(self._peer[0], self._peer[1]) + self._closed = True + self._connection.close() + if self._worker_scope is not None: + self._worker_scope.cancel() + await self._run_done.wait() + + +class TrioQuicManager(AsyncQuicManager): + def __init__(self, nursery, conf=None, verify_mode=ssl.CERT_REQUIRED): + super().__init__(conf, verify_mode, TrioQuicConnection) + self._nursery = nursery + + def connect(self, address, port=853, source=None, source_port=0): + (connection, start) = self._connect(address, port, source, source_port) + if start: + self._nursery.start_soon(connection.run) + return connection + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + # Copy the itertor into a list as exiting things will mutate the connections + # table. + connections = list(self._connections.values()) + for connection in connections: + await connection.close() + return False diff --git a/examples/doq.py b/examples/doq.py new file mode 100644 index 00000000..281b96dd --- /dev/null +++ b/examples/doq.py @@ -0,0 +1,120 @@ +import asyncio +import threading + +import dns.asyncbackend +import dns.asyncquery +import dns.message +import dns.query +import dns.quic +import dns.rdatatype + +try: + import trio + + have_trio = True +except ImportError: + have_trio = False + +# This demo assumes you have the aioquic example doq_server.py running on localhost +# on port 4784 on localhost. +peer_address = "127.0.0.1" +peer_port = 4784 +query_name = "www.dnspython.org" +tls_verify_mode = False + + +def squery(rdtype="A", connection=None): + q = dns.message.make_query(query_name, rdtype) + r = dns.query.quic( + q, peer_address, port=peer_port, connection=connection, verify=tls_verify_mode + ) + print(r) + + +def srun(): + squery() + + +def smultirun(): + with dns.quic.SyncQuicManager(verify_mode=tls_verify_mode) as manager: + connection = manager.connect(peer_address, peer_port) + t1 = threading.Thread(target=squery, args=["A", connection]) + t1.start() + t2 = threading.Thread(target=squery, args=["AAAA", connection]) + t2.start() + t1.join() + t2.join() + + +async def aquery(rdtype="A", connection=None): + q = dns.message.make_query(query_name, rdtype) + r = await dns.asyncquery.quic( + q, peer_address, port=peer_port, connection=connection, verify=tls_verify_mode + ) + print(r) + + +def arun(): + asyncio.run(aquery()) + + +async def amulti(): + async with dns.quic.AsyncioQuicManager(verify_mode=tls_verify_mode) as manager: + connection = manager.connect(peer_address, peer_port) + t1 = asyncio.Task(aquery("A", connection)) + t2 = asyncio.Task(aquery("AAAA", connection)) + await t1 + await t2 + + +def amultirun(): + asyncio.run(amulti()) + + +if have_trio: + + def trun(): + trio.run(aquery) + + async def tmulti(): + async with trio.open_nursery() as nursery: + async with dns.quic.TrioQuicManager( + nursery, verify_mode=tls_verify_mode + ) as manager: + async with trio.open_nursery() as query_nursery: + # We run queries in a separate nursery so we can demonstrate + # waiting for them all to exit without waiting for the manager to + # exit as well. + connection = manager.connect(peer_address, peer_port) + query_nursery.start_soon(aquery, "A", connection) + query_nursery.start_soon(aquery, "AAAA", connection) + + def tmultirun(): + trio.run(tmulti) + + +def main(): + print("*** Single Queries ***") + print("--- Sync ---") + srun() + print("--- Asyncio ---") + dns.asyncbackend.set_default_backend("asyncio") + arun() + if have_trio: + print("--- Trio ---") + dns.asyncbackend.set_default_backend("trio") + trun() + print("*** Multi-connection Queries ***") + print("--- Sync ---") + smultirun() + print("--- Asyncio ---") + dns.asyncbackend.set_default_backend("asyncio") + amultirun() + if have_trio: + print("--- Trio ---") + dns.asyncbackend.set_default_backend("trio") + tmultirun() + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index 7c3a3128..f3c93ea8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,8 @@ include = [ { path="tests/*.pickle", format="sdist" }, { path="tests/*.text", format="sdist" }, { path="tests/*.generic", format="sdist" }, + { path="tests/tls/*.crt", format="sdist" }, + { path="tests/tls/*.pem", format="sdist" }, { path="util/**", format="sdist" }, { path="setup.cfg", format="sdist" }, ] @@ -47,6 +49,7 @@ trio = {version=">=0.14,<0.23", optional=true} curio = {version="^1.2", optional=true} sniffio = {version="^1.1", optional=true} wmi = {version="^1.5.1", optional=true} +aioquic = {version="^0.9.20", optional=true} [tool.poetry.dev-dependencies] pytest = ">=5.4.1,<8" @@ -67,6 +70,7 @@ dnssec = ['cryptography'] trio = ['trio'] curio = ['curio', 'sniffio'] wmi = ['wmi'] +doq = ['aioquic'] [build-system] requires = ["poetry-core"] diff --git a/tests/nanoquic.py b/tests/nanoquic.py new file mode 100644 index 00000000..0bcb12c7 --- /dev/null +++ b/tests/nanoquic.py @@ -0,0 +1,131 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +try: + import asyncio + import socket + import struct + import threading + + import aioquic.asyncio + import aioquic.asyncio.server + import aioquic.quic.configuration + import aioquic.quic.events + + import dns.asyncquery + import dns.message + import dns.rcode + + from tests.util import here + + have_quic = True + + class Request: + def __init__(self, message, wire): + self.message = message + self.wire = wire + + @property + def question(self): + return self.message.question[0] + + @property + def qname(self): + return self.question.name + + @property + def qclass(self): + return self.question.rdclass + + @property + def qtype(self): + return self.question.rdtype + + class NanoQuic(aioquic.asyncio.QuicConnectionProtocol): + def quic_event_received(self, event): + # This is a bit hackish and not fully general, but this is a test server! + if isinstance(event, aioquic.quic.events.StreamDataReceived): + data = bytes(event.data) + (wire_len,) = struct.unpack("!H", data[:2]) + wire = self.handle_wire(data[2 : 2 + wire_len]) + if wire is not None: + self._quic.send_stream_data(event.stream_id, wire, end_stream=True) + + def handle(self, request): + r = dns.message.make_response(request.message) + r.set_rcode(dns.rcode.REFUSED) + return r + + def handle_wire(self, wire): + response = None + try: + q = dns.message.from_wire(wire) + except dns.message.ShortHeader: + return + except Exception as e: + try: + q = dns.message.from_wire(wire, question_only=True) + response = dns.message.make_response(q) + response.set_rcode(dns.rcode.FORMERR) + except Exception: + return + if response is None: + try: + request = Request(q, wire) + response = self.handle(request) + except Exception: + response = dns.message.make_response(q) + response.set_rcode(dns.rcode.SERVFAIL) + wire = response.to_wire() + return struct.pack("!H", len(wire)) + wire + + class Server(threading.Thread): + def __init__(self): + super().__init__() + self.transport = None + self.protocol = None + self.left = None + self.right = None + + def __enter__(self): + self.left, self.right = socket.socketpair() + self.start() + + def __exit__(self, ex_ty, ex_va, ex_tr): + if self.protocol is not None: + self.protocol.close() + if self.transport is not None: + self.transport.close() + if self.left: + self.left.close() + if self.is_alive(): + self.join() + if self.right: + self.right.close() + + async def arun(self): + reader, _ = await asyncio.open_connection(sock=self.right) + conf = aioquic.quic.configuration.QuicConfiguration( + alpn_protocols=["doq"], + is_client=False, + ) + conf.load_cert_chain(here("tls/public.crt"), here("tls/private.pem")) + loop = asyncio.get_event_loop() + (self.transport, self.protocol) = await loop.create_datagram_endpoint( + lambda: aioquic.asyncio.server.QuicServer( + configuration=conf, create_protocol=NanoQuic + ), + local_addr=("127.0.0.1", 8853), + ) + try: + await reader.read(1) + except Exception: + pass + + def run(self): + asyncio.run(self.arun()) + +except ImportError: + have_quic = False + + class NanoQuic: + pass diff --git a/tests/test_doq.py b/tests/test_doq.py new file mode 100644 index 00000000..571581a4 --- /dev/null +++ b/tests/test_doq.py @@ -0,0 +1,56 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import asyncio +import sys + +import pytest + +import dns.asyncbackend +import dns.asyncquery +import dns.message +import dns.query +import dns.rcode + +from .util import here + +try: + from .nanoquic import Server + + _nanoquic_available = True +except ImportError: + _nanoquic_available = False + + class Server(object): + pass + + +@pytest.mark.skipif(not _nanoquic_available, reason="requires nanoquic") +def test_basic_sync(): + with Server() as server: + q = dns.message.make_query("www.example.", "A") + r = dns.query.quic(q, "127.0.0.1", port=8853, verify=here("tls/ca.crt")) + assert r.rcode() == dns.rcode.REFUSED + + +async def amain(): + q = dns.message.make_query("www.example.", "A") + r = await dns.asyncquery.quic(q, "127.0.0.1", port=8853, verify=here("tls/ca.crt")) + assert r.rcode() == dns.rcode.REFUSED + + +def test_basic_asyncio(): + dns.asyncbackend.set_default_backend("asyncio") + with Server() as server: + asyncio.run(amain()) + + +try: + import trio + + def test_basic_trio(): + dns.asyncbackend.set_default_backend("trio") + with Server() as server: + trio.run(amain) + +except ImportError: + pass diff --git a/tests/tls/ca.crt b/tests/tls/ca.crt new file mode 100644 index 00000000..81c76825 --- /dev/null +++ b/tests/tls/ca.crt @@ -0,0 +1,20 @@ +-----BEGIN CERTIFICATE----- +MIIDTDCCAjSgAwIBAgIUUCWxpsMnzETqwNKJ38le9z7oFEEwDQYJKoZIhvcNAQEL +BQAwHTEbMBkGA1UEAxMScXVpYy5kbnNweXRob24ub3JnMB4XDTIyMDcwOTIyMjQw +N1oXDTMyMDcwNjIyMjQzN1owHTEbMBkGA1UEAxMScXVpYy5kbnNweXRob24ub3Jn +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA0AMlXDsx/7Kis4lUhAML +yaL4wtvhPGnqz20Gnhd/b2uAjZbtLtKDG2aRC0QtHL6N0vfBhj+KUV/unT60Mf7G +Pm2Z8fOxiwh/UJ8oxoJe59izklrwM0PL2iR21OMCCsiYcjiOOx75RUZ/6KEGMTgd +3wvqwEV320yd3WInkdO72n9jlQTN3VtwLwkIkSbINiuUCKgP9hy28K7HjMHvEIlf +QZfh9wIHhbqs/JP3dirRL7MKWFAv3MlmMffb/6NBBFb6FaRjS6WjojD8qaSTr14/ +tyqrK7zL32npKm/TbzxC8hFwYdwd3HURgpWInA6CRIcyZM/k4y7dHQlI4ID7hmcC +1QIDAQABo4GDMIGAMA4GA1UdDwEB/wQEAwIBBjAPBgNVHRMBAf8EBTADAQH/MB0G +A1UdDgQWBBQrNPKeL6rBhPV+Eb1RnvIkeax5sDAfBgNVHSMEGDAWgBQrNPKeL6rB +hPV+Eb1RnvIkeax5sDAdBgNVHREEFjAUghJxdWljLmRuc3B5dGhvbi5vcmcwDQYJ +KoZIhvcNAQELBQADggEBAADpAtDvceOrhn5FReYip9DlTW7KKrRDDFCo0SNdhvN3 +6mU/Hn3jNXYu9Ym3NDVL8q9UWzLRcSNLUo1qjkK3aOlgwcO6PuGKXukF7Zdd8wVb +pPdUqooBmj6akqmNvmloZyDmQ+aXcYhR83hcEHFOK+C7pGLqSFChN1mgDT1/mgBk +pODOZkcLtZI8YJyQ2sn3WhUJS52D6xfmPigliUcYqi6i+w1vxD45QilWbvqCwnN/ +6qmb3JQsMf+3MCtogVcSZjE9cf4CwlmKqgMxsBKz+/Qk9YPMpDuecEbd76L+Htdl +HWuDlemBzyhd5qO5y/UGarqmuh3MgkOdFVQWAUygcCM= +-----END CERTIFICATE----- diff --git a/tests/tls/private.pem b/tests/tls/private.pem new file mode 100644 index 00000000..06a01fad --- /dev/null +++ b/tests/tls/private.pem @@ -0,0 +1,3 @@ +-----BEGIN PRIVATE KEY----- +MC4CAQAwBQYDK2VwBCIEIL2OxuOo+awfhPvvm82EBZ4VA6ULQHlebxGCamZ/H5Rt +-----END PRIVATE KEY----- diff --git a/tests/tls/public.crt b/tests/tls/public.crt new file mode 100644 index 00000000..96129a1b --- /dev/null +++ b/tests/tls/public.crt @@ -0,0 +1,35 @@ +-----BEGIN CERTIFICATE----- +MIICZjCCAU6gAwIBAgIUBTlEzhtkXYQvZl5CYRNBxOG4GpEwDQYJKoZIhvcNAQEL +BQAwHTEbMBkGA1UEAxMScXVpYy5kbnNweXRob24ub3JnMB4XDTIyMTAwOTE2MjYw +OFoXDTMwMTIyNjE2MjYzOFowFDESMBAGA1UEAxMJbG9jYWxob3N0MCowBQYDK2Vw +AyEAKpQbO2JXhCGnQs2MrWmGBK5LcmJMWPXCzM2PfWbo1TyjgaAwgZ0wDgYDVR0P +AQH/BAQDAgOoMB0GA1UdJQQWMBQGCCsGAQUFBwMBBggrBgEFBQcDAjAdBgNVHQ4E +FgQUM2pZy8pH78CvP+FnuF190KEJkjUwHwYDVR0jBBgwFoAUKzTyni+qwYT1fhG9 +UZ7yJHmsebAwLAYDVR0RBCUwI4IJbG9jYWxob3N0hwR/AAABhxAAAAAAAAAAAAAA +AAAAAAABMA0GCSqGSIb3DQEBCwUAA4IBAQA0JlNLrLz3ajCzSVfOQsUdd3a3wR7Q +Dr28mYoDHSY9mhnJ9IQeInmGvPMLA4dgiRPFqxWsKh+lxzZObkbMjf1IAIVykfh6 +LynePm58/lnRrhdvf8vFfccuTyeb2aD0ZBA/RyhZam79J6JjRRovkSj9TyIqKfif +6T6QWXOXwAF89rH8YHAKnRSl32pqZuDhOnM0Ien+Sa6KpCvgIDogHQxIVbe1egZl +2Ec0LVQUaXhoICd1c6xoRoAa5UzDFJ7ujeu1XNGWKIiXESlcIo7SZjzusL2p5vv/ +frM+r43khtZ4s+F70A+B3AndcVSeKTQ5KlftN9CBuiQoYzhY29NmL93X +-----END CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIIDTDCCAjSgAwIBAgIUUCWxpsMnzETqwNKJ38le9z7oFEEwDQYJKoZIhvcNAQEL +BQAwHTEbMBkGA1UEAxMScXVpYy5kbnNweXRob24ub3JnMB4XDTIyMDcwOTIyMjQw +N1oXDTMyMDcwNjIyMjQzN1owHTEbMBkGA1UEAxMScXVpYy5kbnNweXRob24ub3Jn +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA0AMlXDsx/7Kis4lUhAML +yaL4wtvhPGnqz20Gnhd/b2uAjZbtLtKDG2aRC0QtHL6N0vfBhj+KUV/unT60Mf7G +Pm2Z8fOxiwh/UJ8oxoJe59izklrwM0PL2iR21OMCCsiYcjiOOx75RUZ/6KEGMTgd +3wvqwEV320yd3WInkdO72n9jlQTN3VtwLwkIkSbINiuUCKgP9hy28K7HjMHvEIlf +QZfh9wIHhbqs/JP3dirRL7MKWFAv3MlmMffb/6NBBFb6FaRjS6WjojD8qaSTr14/ +tyqrK7zL32npKm/TbzxC8hFwYdwd3HURgpWInA6CRIcyZM/k4y7dHQlI4ID7hmcC +1QIDAQABo4GDMIGAMA4GA1UdDwEB/wQEAwIBBjAPBgNVHRMBAf8EBTADAQH/MB0G +A1UdDgQWBBQrNPKeL6rBhPV+Eb1RnvIkeax5sDAfBgNVHSMEGDAWgBQrNPKeL6rB +hPV+Eb1RnvIkeax5sDAdBgNVHREEFjAUghJxdWljLmRuc3B5dGhvbi5vcmcwDQYJ +KoZIhvcNAQELBQADggEBAADpAtDvceOrhn5FReYip9DlTW7KKrRDDFCo0SNdhvN3 +6mU/Hn3jNXYu9Ym3NDVL8q9UWzLRcSNLUo1qjkK3aOlgwcO6PuGKXukF7Zdd8wVb +pPdUqooBmj6akqmNvmloZyDmQ+aXcYhR83hcEHFOK+C7pGLqSFChN1mgDT1/mgBk +pODOZkcLtZI8YJyQ2sn3WhUJS52D6xfmPigliUcYqi6i+w1vxD45QilWbvqCwnN/ +6qmb3JQsMf+3MCtogVcSZjE9cf4CwlmKqgMxsBKz+/Qk9YPMpDuecEbd76L+Htdl +HWuDlemBzyhd5qO5y/UGarqmuh3MgkOdFVQWAUygcCM= +-----END CERTIFICATE-----