"node",
"opcode",
"query",
+ "quic",
"rcode",
"rdata",
"rdataclass",
import dns.inet
import dns.name
import dns.message
+import dns.quic
import dns.rcode
import dns.rdataclass
import dns.rdatatype
_have_httpx,
_have_http2,
NoDOH,
+ NoDOQ,
)
if _have_httpx:
tsig_ctx = r.tsig_ctx
if not retry and query.keyring and not r.had_tsig:
raise dns.exception.FormError("missing TSIG")
+
+
+async def quic(
+ q: dns.message.Message,
+ where: str,
+ timeout: Optional[float] = None,
+ port: int = 853,
+ source: Optional[str] = None,
+ source_port: int = 0,
+ one_rr_per_rrset: bool = False,
+ ignore_trailing: bool = False,
+ connection: Optional[dns.quic.AsyncQuicConnection] = None,
+ verify: Union[bool, str] = True,
+ backend: Optional[dns.asyncbackend.Backend] = None,
+) -> dns.message.Message:
+ """Return the response obtained after sending an asynchronous query via
+ DNS-over-QUIC.
+
+ *backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
+ the default, then dnspython will use the default backend.
+
+ See :py:func:`dns.query.doq()` for the documentation of the other
+ parameters, exceptions, and return type of this method.
+ """
+
+ if not dns.quic.have_quic:
+ raise NoDOQ("DNS-over-QUIC is not available.") # pragma: no cover
+
+ q.id = 0
+ wire = q.to_wire()
+ the_connection: dns.quic.AsyncQuicConnection
+ if connection:
+ cfactory = dns.quic.null_factory
+ mfactory = dns.quic.null_factory
+ the_connection = connection
+ else:
+ (cfactory, mfactory) = dns.quic.factories_for_backend(backend)
+
+ async with cfactory() as context:
+ async with mfactory(context, verify_mode=verify) as the_manager:
+ if not connection:
+ the_connection = the_manager.connect(where, port, source, source_port)
+ start = time.time()
+ stream = await the_connection.make_stream()
+ async with stream:
+ await stream.send(wire, True)
+ wire = await stream.receive(timeout)
+ finish = time.time()
+ r = dns.message.from_wire(
+ wire,
+ keyring=q.keyring,
+ request_mac=q.request_mac,
+ one_rr_per_rrset=one_rr_per_rrset,
+ ignore_trailing=ignore_trailing,
+ )
+ r.time = max(finish - start, 0.0)
+ if not q.is_response(r):
+ raise BadResponse
+ return r
keyptr = keyptr[octets:]
dsa_y = keyptr[0:octets]
try:
- dsa_public_key = dsa.DSAPublicNumbers(
+ dsa_public_key = dsa.DSAPublicNumbers( # type: ignore
_bytes_to_long(dsa_y),
dsa.DSAParameterNumbers(
_bytes_to_long(dsa_p), _bytes_to_long(dsa_q), _bytes_to_long(dsa_g)
import dns.inet
import dns.name
import dns.message
+import dns.quic
import dns.rcode
import dns.rdataclass
import dns.rdatatype
available."""
+class NoDOQ(dns.exception.DNSException):
+ """DNS over QUIC (DOQ) was requested but the aioquic module is not
+ available."""
+
+
# for backwards compatibility
TransferError = dns.xfr.TransferError
)
+def quic(
+ q: dns.message.Message,
+ where: str,
+ timeout: Optional[float] = None,
+ port: int = 853,
+ source: Optional[str] = None,
+ source_port: int = 0,
+ one_rr_per_rrset: bool = False,
+ ignore_trailing: bool = False,
+ connection: Optional[dns.quic.SyncQuicConnection] = None,
+ verify: Union[bool, str] = True,
+) -> dns.message.Message:
+ """Return the response obtained after sending a query via DNS-over-QUIC.
+
+ *q*, a ``dns.message.Message``, the query to send.
+
+ *where*, a ``str``, the nameserver IP address.
+
+ *timeout*, a ``float`` or ``None``, the number of seconds to wait before the query
+ times out. If ``None``, the default, wait forever.
+
+ *port*, a ``int``, the port to send the query to. The default is 443.
+
+ *source*, a ``str`` containing an IPv4 or IPv6 address, specifying the source
+ address. The default is the wildcard address.
+
+ *source_port*, an ``int``, the port from which to send the message. The default is
+ 0.
+
+ *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own RRset.
+
+ *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the
+ received message.
+
+ *connection*, a ``dns.quic.SyncQuicConnection``. If provided, the
+ connection to use to send the query.
+
+ *verify*, a ``bool`` or ``str``. If a ``True``, then TLS certificate verification
+ of the server is done using the default CA bundle; if ``False``, then no
+ verification is done; if a `str` then it specifies the path to a certificate file or
+ directory which will be used for verification.
+
+ Returns a ``dns.message.Message``.
+ """
+
+ if not dns.quic.have_quic:
+ raise NoDOQ("DNS-over-QUIC is not available.") # pragma: no cover
+
+ q.id = 0
+ wire = q.to_wire()
+ the_connection: dns.quic.SyncQuicConnection
+ the_manager: dns.quic.SyncQuicManager
+ if connection:
+ manager: contextlib.AbstractContextManager = contextlib.nullcontext(None)
+ the_connection = connection
+ else:
+ manager = dns.quic.SyncQuicManager(verify_mode=verify)
+ the_manager = manager # for type checking happiness
+
+ with manager:
+ if not connection:
+ the_connection = the_manager.connect(where, port)
+ start = time.time()
+ with the_connection.make_stream() as stream:
+ stream.send(wire, True)
+ wire = stream.receive(timeout)
+ finish = time.time()
+ r = dns.message.from_wire(
+ wire,
+ keyring=q.keyring,
+ request_mac=q.request_mac,
+ one_rr_per_rrset=one_rr_per_rrset,
+ ignore_trailing=ignore_trailing,
+ )
+ r.time = max(finish - start, 0.0)
+ if not q.is_response(r):
+ raise BadResponse
+ return r
+
+
def xfr(
where: str,
zone: Union[dns.name.Name, str],
--- /dev/null
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+try:
+ import aioquic.quic.configuration # type: ignore
+
+ import dns.asyncbackend
+ from dns._asyncbackend import NullContext
+ from dns.quic._sync import SyncQuicManager, SyncQuicConnection, SyncQuicStream
+ from dns.quic._asyncio import (
+ AsyncioQuicManager,
+ AsyncioQuicConnection,
+ AsyncioQuicStream,
+ )
+ from dns.quic._common import AsyncQuicConnection, AsyncQuicManager
+
+ have_quic = True
+
+ def null_factory(*args, **kwargs):
+ return NullContext(None)
+
+ def _asyncio_manager_factory(context, *args, **kwargs):
+ return AsyncioQuicManager(*args, **kwargs)
+
+ # We have a context factory and a manager factory as for trio we need to have
+ # a nursery.
+
+ _async_factories = {"asyncio": (null_factory, _asyncio_manager_factory)}
+
+ try:
+ import trio
+ from dns.quic._trio import TrioQuicManager, TrioQuicConnection, TrioQuicStream
+
+ def _trio_context_factory():
+ return trio.open_nursery()
+
+ def _trio_manager_factory(context, *args, **kwargs):
+ return TrioQuicManager(context, *args, **kwargs)
+
+ _async_factories["trio"] = (_trio_context_factory, _trio_manager_factory)
+ except ImportError:
+ pass
+
+ def factories_for_backend(backend=None):
+ if backend is None:
+ backend = dns.asyncbackend.get_default_backend()
+ return _async_factories[backend.name()]
+
+except ImportError:
+ have_quic = False
+
+ from typing import Any
+
+ class AsyncQuicStream: # type: ignore
+ pass
+
+ class AsyncQuicConnection: # type: ignore
+ async def make_stream(self) -> Any:
+ raise NotImplementedError
+
+ class SyncQuicStream: # type: ignore
+ pass
+
+ class SyncQuicConnection: # type: ignore
+ def make_stream(self) -> Any:
+ raise NotImplementedError
--- /dev/null
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+import asyncio
+import socket
+import ssl
+import struct
+import time
+
+import aioquic.quic.configuration # type: ignore
+import aioquic.quic.connection # type: ignore
+import aioquic.quic.events # type: ignore
+import dns.inet
+import dns.asyncbackend
+
+from dns.quic._common import (
+ BaseQuicStream,
+ AsyncQuicConnection,
+ AsyncQuicManager,
+ QUIC_MAX_DATAGRAM,
+)
+
+
+class AsyncioQuicStream(BaseQuicStream):
+ def __init__(self, connection, stream_id):
+ super().__init__(connection, stream_id)
+ self._wake_up = asyncio.Condition()
+
+ async def _wait_for_wake_up(self):
+ async with self._wake_up:
+ await self._wake_up.wait()
+
+ async def wait_for(self, amount, expiration):
+ timeout = self._timeout_from_expiration(expiration)
+ while True:
+ if self._buffer.have(amount):
+ return
+ self._expecting = amount
+ try:
+ await asyncio.wait_for(self._wait_for_wake_up(), timeout)
+ except Exception:
+ pass
+ self._expecting = 0
+
+ async def receive(self, timeout=None):
+ expiration = self._expiration_from_timeout(timeout)
+ await self.wait_for(2, expiration)
+ (size,) = struct.unpack("!H", self._buffer.get(2))
+ await self.wait_for(size, expiration)
+ return self._buffer.get(size)
+
+ async def send(self, datagram, is_end=False):
+ data = self._encapsulate(datagram)
+ await self._connection.write(self._stream_id, data, is_end)
+
+ async def _add_input(self, data, is_end):
+ if self._common_add_input(data, is_end):
+ async with self._wake_up:
+ self._wake_up.notify()
+
+ async def close(self):
+ self._close()
+
+ # Streams are async context managers
+
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
+ await self.close()
+ async with self._wake_up:
+ self._wake_up.notify()
+ return False
+
+
+class AsyncioQuicConnection(AsyncQuicConnection):
+ def __init__(self, connection, address, port, source, source_port, manager=None):
+ super().__init__(connection, address, port, source, source_port, manager)
+ self._socket = None
+ self._handshake_complete = asyncio.Event()
+ self._socket_created = asyncio.Event()
+ self._wake_timer = asyncio.Condition()
+ self._receiver_task = None
+ self._sender_task = None
+
+ async def _receiver(self):
+ try:
+ af = dns.inet.af_for_address(self._address)
+ backend = dns.asyncbackend.get_backend("asyncio")
+ self._socket = await backend.make_socket(
+ af, socket.SOCK_DGRAM, 0, self._source, self._peer
+ )
+ self._socket_created.set()
+ async with self._socket:
+ while not self._done:
+ (datagram, address) = await self._socket.recvfrom(
+ QUIC_MAX_DATAGRAM, None
+ )
+ if address[0] != self._peer[0] or address[1] != self._peer[1]:
+ continue
+ self._connection.receive_datagram(
+ datagram, self._peer[0], time.time()
+ )
+ # Wake up the timer in case the sender is sleeping, as there may be
+ # stuff to send now.
+ async with self._wake_timer:
+ self._wake_timer.notify_all()
+ except Exception:
+ pass
+
+ async def _wait_for_wake_timer(self):
+ async with self._wake_timer:
+ await self._wake_timer.wait()
+
+ async def _sender(self):
+ await self._socket_created.wait()
+ while not self._done:
+ datagrams = self._connection.datagrams_to_send(time.time())
+ for (datagram, address) in datagrams:
+ assert address == self._peer[0]
+ await self._socket.sendto(datagram, self._peer, None)
+ (expiration, interval) = self._get_timer_values()
+ try:
+ await asyncio.wait_for(self._wait_for_wake_timer(), interval)
+ except Exception:
+ pass
+ self._handle_timer(expiration)
+ await self._handle_events()
+
+ async def _handle_events(self):
+ count = 0
+ while True:
+ event = self._connection.next_event()
+ if event is None:
+ return
+ if isinstance(event, aioquic.quic.events.StreamDataReceived):
+ stream = self._streams.get(event.stream_id)
+ if stream:
+ await stream._add_input(event.data, event.end_stream)
+ elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
+ self._handshake_complete.set()
+ elif isinstance(
+ event, aioquic.quic.events.ConnectionTerminated
+ ) or isinstance(event, aioquic.quic.events.StreamReset):
+ self._done = True
+ self._receiver_task.cancel()
+ count += 1
+ if count > 10:
+ # yield
+ count = 0
+ await asyncio.sleep(0)
+
+ async def write(self, stream, data, is_end=False):
+ self._connection.send_stream_data(stream, data, is_end)
+ async with self._wake_timer:
+ self._wake_timer.notify_all()
+
+ def run(self):
+ if self._closed:
+ return
+ self._receiver_task = asyncio.Task(self._receiver())
+ self._sender_task = asyncio.Task(self._sender())
+
+ async def make_stream(self):
+ await self._handshake_complete.wait()
+ stream_id = self._connection.get_next_available_stream_id(False)
+ stream = AsyncioQuicStream(self, stream_id)
+ self._streams[stream_id] = stream
+ return stream
+
+ async def close(self):
+ if not self._closed:
+ self._manager.closed(self._peer[0], self._peer[1])
+ self._closed = True
+ self._connection.close()
+ async with self._wake_timer:
+ self._wake_timer.notify_all()
+ try:
+ await self._receiver_task
+ except asyncio.CancelledError:
+ pass
+ try:
+ await self._sender_task
+ except asyncio.CancelledError:
+ pass
+
+
+class AsyncioQuicManager(AsyncQuicManager):
+ def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED):
+ super().__init__(conf, verify_mode, AsyncioQuicConnection)
+
+ def connect(self, address, port=853, source=None, source_port=0):
+ (connection, start) = self._connect(address, port, source, source_port)
+ if start:
+ connection.run()
+ return connection
+
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
+ # Copy the itertor into a list as exiting things will mutate the connections
+ # table.
+ connections = list(self._connections.values())
+ for connection in connections:
+ await connection.close()
+ return False
--- /dev/null
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+import socket
+import struct
+import time
+
+import aioquic.quic.configuration # type: ignore
+import aioquic.quic.connection # type: ignore
+import dns.inet
+from dns._asyncbackend import NullContext
+
+from typing import Any
+
+
+QUIC_MAX_DATAGRAM = 2048
+
+
+class UnexpectedEOF(Exception):
+ pass
+
+
+class Buffer:
+ def __init__(self):
+ self._buffer = b""
+ self._seen_end = False
+
+ def put(self, data, is_end):
+ if self._seen_end:
+ return
+ self._buffer += data
+ if is_end:
+ self._seen_end = True
+
+ def have(self, amount):
+ if len(self._buffer) >= amount:
+ return True
+ if self._seen_end:
+ raise UnexpectedEOF
+ return False
+
+ def seen_end(self):
+ return self._seen_end
+
+ def get(self, amount):
+ assert self.have(amount)
+ data = self._buffer[:amount]
+ self._buffer = self._buffer[amount:]
+ return data
+
+
+class BaseQuicStream:
+ def __init__(self, connection, stream_id):
+ self._connection = connection
+ self._stream_id = stream_id
+ self._buffer = Buffer()
+ self._expecting = 0
+
+ def id(self):
+ return self._stream_id
+
+ def _expiration_from_timeout(self, timeout):
+ if timeout is not None:
+ expiration = time.time() + timeout
+ else:
+ expiration = None
+ return expiration
+
+ def _timeout_from_expiration(self, expiration):
+ if expiration is not None:
+ timeout = max(expiration - time.time(), 0.0)
+ else:
+ timeout = None
+ return timeout
+
+ # Subclass must implement receive() as sync / async and which returns a message
+ # or raises UnexpectedEOF.
+
+ def _encapsulate(self, datagram):
+ l = len(datagram)
+ return struct.pack("!H", l) + datagram
+
+ def _common_add_input(self, data, is_end):
+ self._buffer.put(data, is_end)
+ return self._expecting > 0 and self._buffer.have(self._expecting)
+
+ def _close(self):
+ self._connection.close_stream(self._stream_id)
+ self._buffer.put(b"", True) # send EOF in case we haven't seen it.
+
+
+class BaseQuicConnection:
+ def __init__(
+ self, connection, address, port, source=None, source_port=0, manager=None
+ ):
+ self._done = False
+ self._connection = connection
+ self._address = address
+ self._port = port
+ self._closed = False
+ self._manager = manager
+ self._streams = {}
+ self._af = dns.inet.af_for_address(address)
+ self._peer = dns.inet.low_level_address_tuple((address, port))
+ if source is None and source_port != 0:
+ if self._af == socket.AF_INET:
+ source = "0.0.0.0"
+ elif self._af == socket.AF_INET6:
+ source = "::"
+ else:
+ raise NotImplementedError
+ if source:
+ self._source = (source, source_port)
+ else:
+ self._source = None
+
+ def close_stream(self, stream_id):
+ del self._streams[stream_id]
+
+ def _get_timer_values(self, closed_is_special=True):
+ now = time.time()
+ expiration = self._connection.get_timer()
+ if expiration is None:
+ expiration = now + 3600 # arbitrary "big" value
+ interval = max(expiration - now, 0)
+ if self._closed and closed_is_special:
+ # lower sleep interval to avoid a race in the closing process
+ # which can lead to higher latency closing due to sleeping when
+ # we have events.
+ interval = min(interval, 0.05)
+ return (expiration, interval)
+
+ def _handle_timer(self, expiration):
+ now = time.time()
+ if expiration <= now:
+ self._connection.handle_timer(now)
+
+
+class AsyncQuicConnection(BaseQuicConnection):
+ async def make_stream(self) -> Any:
+ pass
+
+
+class BaseQuicManager:
+ def __init__(self, conf, verify_mode, connection_factory):
+ self._connections = {}
+ self._connection_factory = connection_factory
+ if conf is None:
+ verify_path = None
+ if isinstance(verify_mode, str):
+ verify_path = verify_mode
+ verify_mode = True
+ conf = aioquic.quic.configuration.QuicConfiguration(
+ alpn_protocols=["doq", "doq-i03"],
+ verify_mode=verify_mode,
+ )
+ if verify_path is not None:
+ conf.load_verify_locations(verify_path)
+ self._conf = conf
+
+ def _connect(self, address, port=853, source=None, source_port=0):
+ connection = self._connections.get((address, port))
+ if connection is not None:
+ return (connection, False)
+ qconn = aioquic.quic.connection.QuicConnection(configuration=self._conf)
+ qconn.connect(address, time.time())
+ connection = self._connection_factory(
+ qconn, address, port, source, source_port, self
+ )
+ self._connections[(address, port)] = connection
+ return (connection, True)
+
+ def closed(self, address, port):
+ try:
+ del self._connections[(address, port)]
+ except KeyError:
+ pass
+
+
+class AsyncQuicManager(BaseQuicManager):
+ def connect(self, address, port=853, source=None, source_port=0):
+ raise NotImplementedError
--- /dev/null
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+import socket
+import ssl
+import selectors
+import struct
+import threading
+import time
+
+import aioquic.quic.configuration # type: ignore
+import aioquic.quic.connection # type: ignore
+import aioquic.quic.events # type: ignore
+import dns.inet
+
+from dns.quic._common import (
+ BaseQuicStream,
+ BaseQuicConnection,
+ BaseQuicManager,
+ QUIC_MAX_DATAGRAM,
+)
+
+# Avoid circularity with dns.query
+if hasattr(selectors, "PollSelector"):
+ _selector_class = selectors.PollSelector # type: ignore
+else:
+ _selector_class = selectors.SelectSelector # type: ignore
+
+
+class SyncQuicStream(BaseQuicStream):
+ def __init__(self, connection, stream_id):
+ super().__init__(connection, stream_id)
+ self._wake_up = threading.Condition()
+ self._lock = threading.Lock()
+
+ def wait_for(self, amount, expiration):
+ timeout = self._timeout_from_expiration(expiration)
+ while True:
+ with self._lock:
+ if self._buffer.have(amount):
+ return
+ self._expecting = amount
+ with self._wake_up:
+ self._wake_up.wait(timeout)
+ self._expecting = 0
+
+ def receive(self, timeout=None):
+ expiration = self._expiration_from_timeout(timeout)
+ self.wait_for(2, expiration)
+ with self._lock:
+ (size,) = struct.unpack("!H", self._buffer.get(2))
+ self.wait_for(size, expiration)
+ with self._lock:
+ return self._buffer.get(size)
+
+ def send(self, datagram, is_end=False):
+ data = self._encapsulate(datagram)
+ self._connection.write(self._stream_id, data, is_end)
+
+ def _add_input(self, data, is_end):
+ if self._common_add_input(data, is_end):
+ with self._wake_up:
+ self._wake_up.notify()
+
+ def close(self):
+ with self._lock:
+ self._close()
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.close()
+ with self._wake_up:
+ self._wake_up.notify()
+ return False
+
+
+class SyncQuicConnection(BaseQuicConnection):
+ def __init__(self, connection, address, port, source, source_port, manager):
+ super().__init__(connection, address, port, source, source_port, manager)
+ self._socket = socket.socket(self._af, socket.SOCK_DGRAM, 0)
+ self._socket.connect(self._peer)
+ (self._send_wakeup, self._receive_wakeup) = socket.socketpair()
+ self._receive_wakeup.setblocking(False)
+ self._socket.setblocking(False)
+ if self._source is not None:
+ try:
+ self._socket.bind(
+ dns.inet.low_level_address_tuple(self._source, self._af)
+ )
+ except Exception:
+ self._socket.close()
+ raise
+ self._handshake_complete = threading.Event()
+ self._worker_thread = None
+ self._lock = threading.Lock()
+
+ def _read(self):
+ count = 0
+ while count < 10:
+ count += 1
+ try:
+ datagram = self._socket.recv(QUIC_MAX_DATAGRAM)
+ except BlockingIOError:
+ return
+ with self._lock:
+ self._connection.receive_datagram(datagram, self._peer[0], time.time())
+
+ def _drain_wakeup(self):
+ while True:
+ try:
+ self._receive_wakeup.recv(32)
+ except BlockingIOError:
+ return
+
+ def _worker(self):
+ sel = _selector_class()
+ sel.register(self._socket, selectors.EVENT_READ, self._read)
+ sel.register(self._receive_wakeup, selectors.EVENT_READ, self._drain_wakeup)
+ while not self._done:
+ (expiration, interval) = self._get_timer_values(False)
+ items = sel.select(interval)
+ for (key, _) in items:
+ key.data()
+ with self._lock:
+ self._handle_timer(expiration)
+ datagrams = self._connection.datagrams_to_send(time.time())
+ for (datagram, _) in datagrams:
+ try:
+ self._socket.send(datagram)
+ except BlockingIOError:
+ # we let QUIC handle any lossage
+ pass
+ self._handle_events()
+
+ def _handle_events(self):
+ while True:
+ with self._lock:
+ event = self._connection.next_event()
+ if event is None:
+ return
+ if isinstance(event, aioquic.quic.events.StreamDataReceived):
+ with self._lock:
+ stream = self._streams.get(event.stream_id)
+ if stream:
+ stream._add_input(event.data, event.end_stream)
+ elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
+ self._handshake_complete.set()
+ elif isinstance(
+ event, aioquic.quic.events.ConnectionTerminated
+ ) or isinstance(event, aioquic.quic.events.StreamReset):
+ with self._lock:
+ self._done = True
+
+ def write(self, stream, data, is_end=False):
+ with self._lock:
+ self._connection.send_stream_data(stream, data, is_end)
+ self._send_wakeup.send(b"\x01")
+
+ def run(self):
+ if self._closed:
+ return
+ self._worker_thread = threading.Thread(target=self._worker)
+ self._worker_thread.start()
+
+ def make_stream(self):
+ self._handshake_complete.wait()
+ with self._lock:
+ stream_id = self._connection.get_next_available_stream_id(False)
+ stream = SyncQuicStream(self, stream_id)
+ self._streams[stream_id] = stream
+ return stream
+
+ def close_stream(self, stream_id):
+ with self._lock:
+ super().close_stream(stream_id)
+
+ def close(self):
+ with self._lock:
+ if self._closed:
+ return
+ self._manager.closed(self._peer[0], self._peer[1])
+ self._closed = True
+ self._connection.close()
+ self._send_wakeup.send(b"\x01")
+ self._worker_thread.join()
+
+
+class SyncQuicManager(BaseQuicManager):
+ def __init__(self, conf=None, verify_mode=ssl.CERT_REQUIRED):
+ super().__init__(conf, verify_mode, SyncQuicConnection)
+ self._lock = threading.Lock()
+
+ def connect(self, address, port=853, source=None, source_port=0):
+ with self._lock:
+ (connection, start) = self._connect(address, port, source, source_port)
+ if start:
+ connection.run()
+ return connection
+
+ def closed(self, address, port):
+ with self._lock:
+ super().closed(address, port)
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ # Copy the itertor into a list as exiting things will mutate the connections
+ # table.
+ connections = list(self._connections.values())
+ for connection in connections:
+ connection.close()
+ return False
--- /dev/null
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+import socket
+import ssl
+import struct
+import time
+
+import aioquic.quic.configuration # type: ignore
+import aioquic.quic.connection # type: ignore
+import aioquic.quic.events # type: ignore
+import dns.inet
+import trio
+
+from dns.quic._common import (
+ BaseQuicStream,
+ AsyncQuicConnection,
+ AsyncQuicManager,
+ NullContext,
+ QUIC_MAX_DATAGRAM,
+)
+
+
+class TrioQuicStream(BaseQuicStream):
+ def __init__(self, connection, stream_id):
+ super().__init__(connection, stream_id)
+ self._wake_up = trio.Condition()
+
+ async def wait_for(self, amount):
+ while True:
+ if self._buffer.have(amount):
+ return
+ self._expecting = amount
+ async with self._wake_up:
+ await self._wake_up.wait()
+ self._expecting = 0
+
+ async def receive(self, timeout=None):
+ if timeout is None:
+ context = NullContext(None)
+ else:
+ context = trio.move_on_after(timeout)
+ with context:
+ await self.wait_for(2)
+ (size,) = struct.unpack("!H", self._buffer.get(2))
+ await self.wait_for(size)
+ return self._buffer.get(size)
+
+ async def send(self, datagram, is_end=False):
+ data = self._encapsulate(datagram)
+ await self._connection.write(self._stream_id, data, is_end)
+
+ async def _add_input(self, data, is_end):
+ if self._common_add_input(data, is_end):
+ async with self._wake_up:
+ self._wake_up.notify()
+
+ async def close(self):
+ self._close()
+
+ # Streams are async context managers
+
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
+ await self.close()
+ async with self._wake_up:
+ self._wake_up.notify()
+ return False
+
+
+class TrioQuicConnection(AsyncQuicConnection):
+ def __init__(self, connection, address, port, source, source_port, manager=None):
+ super().__init__(connection, address, port, source, source_port, manager)
+ self._socket = trio.socket.socket(self._af, socket.SOCK_DGRAM, 0)
+ if self._source:
+ trio.socket.bind(dns.inet.low_level_address_tuple(self._source, self._af))
+ self._handshake_complete = trio.Event()
+ self._run_done = trio.Event()
+ self._worker_scope = None
+
+ async def _worker(self):
+ await self._socket.connect(self._peer)
+ while not self._done:
+ (expiration, interval) = self._get_timer_values(False)
+ with trio.CancelScope(
+ deadline=trio.current_time() + interval
+ ) as self._worker_scope:
+ datagram = await self._socket.recv(QUIC_MAX_DATAGRAM)
+ self._connection.receive_datagram(datagram, self._peer[0], time.time())
+ self._worker_scope = None
+ self._handle_timer(expiration)
+ datagrams = self._connection.datagrams_to_send(time.time())
+ for (datagram, _) in datagrams:
+ await self._socket.send(datagram)
+ await self._handle_events()
+
+ async def _handle_events(self):
+ count = 0
+ while True:
+ event = self._connection.next_event()
+ if event is None:
+ return
+ if isinstance(event, aioquic.quic.events.StreamDataReceived):
+ stream = self._streams.get(event.stream_id)
+ if stream:
+ await stream._add_input(event.data, event.end_stream)
+ elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
+ self._handshake_complete.set()
+ elif isinstance(
+ event, aioquic.quic.events.ConnectionTerminated
+ ) or isinstance(event, aioquic.quic.events.StreamReset):
+ self._done = True
+ self._socket.close()
+ count += 1
+ if count > 10:
+ # yield
+ count = 0
+ await trio.sleep(0)
+
+ async def write(self, stream, data, is_end=False):
+ self._connection.send_stream_data(stream, data, is_end)
+ if self._worker_scope is not None:
+ self._worker_scope.cancel()
+
+ async def run(self):
+ if self._closed:
+ return
+ async with trio.open_nursery() as nursery:
+ nursery.start_soon(self._worker)
+ self._run_done.set()
+
+ async def make_stream(self):
+ await self._handshake_complete.wait()
+ stream_id = self._connection.get_next_available_stream_id(False)
+ stream = TrioQuicStream(self, stream_id)
+ self._streams[stream_id] = stream
+ return stream
+
+ async def close(self):
+ if not self._closed:
+ self._manager.closed(self._peer[0], self._peer[1])
+ self._closed = True
+ self._connection.close()
+ if self._worker_scope is not None:
+ self._worker_scope.cancel()
+ await self._run_done.wait()
+
+
+class TrioQuicManager(AsyncQuicManager):
+ def __init__(self, nursery, conf=None, verify_mode=ssl.CERT_REQUIRED):
+ super().__init__(conf, verify_mode, TrioQuicConnection)
+ self._nursery = nursery
+
+ def connect(self, address, port=853, source=None, source_port=0):
+ (connection, start) = self._connect(address, port, source, source_port)
+ if start:
+ self._nursery.start_soon(connection.run)
+ return connection
+
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
+ # Copy the itertor into a list as exiting things will mutate the connections
+ # table.
+ connections = list(self._connections.values())
+ for connection in connections:
+ await connection.close()
+ return False
--- /dev/null
+import asyncio
+import threading
+
+import dns.asyncbackend
+import dns.asyncquery
+import dns.message
+import dns.query
+import dns.quic
+import dns.rdatatype
+
+try:
+ import trio
+
+ have_trio = True
+except ImportError:
+ have_trio = False
+
+# This demo assumes you have the aioquic example doq_server.py running on localhost
+# on port 4784 on localhost.
+peer_address = "127.0.0.1"
+peer_port = 4784
+query_name = "www.dnspython.org"
+tls_verify_mode = False
+
+
+def squery(rdtype="A", connection=None):
+ q = dns.message.make_query(query_name, rdtype)
+ r = dns.query.quic(
+ q, peer_address, port=peer_port, connection=connection, verify=tls_verify_mode
+ )
+ print(r)
+
+
+def srun():
+ squery()
+
+
+def smultirun():
+ with dns.quic.SyncQuicManager(verify_mode=tls_verify_mode) as manager:
+ connection = manager.connect(peer_address, peer_port)
+ t1 = threading.Thread(target=squery, args=["A", connection])
+ t1.start()
+ t2 = threading.Thread(target=squery, args=["AAAA", connection])
+ t2.start()
+ t1.join()
+ t2.join()
+
+
+async def aquery(rdtype="A", connection=None):
+ q = dns.message.make_query(query_name, rdtype)
+ r = await dns.asyncquery.quic(
+ q, peer_address, port=peer_port, connection=connection, verify=tls_verify_mode
+ )
+ print(r)
+
+
+def arun():
+ asyncio.run(aquery())
+
+
+async def amulti():
+ async with dns.quic.AsyncioQuicManager(verify_mode=tls_verify_mode) as manager:
+ connection = manager.connect(peer_address, peer_port)
+ t1 = asyncio.Task(aquery("A", connection))
+ t2 = asyncio.Task(aquery("AAAA", connection))
+ await t1
+ await t2
+
+
+def amultirun():
+ asyncio.run(amulti())
+
+
+if have_trio:
+
+ def trun():
+ trio.run(aquery)
+
+ async def tmulti():
+ async with trio.open_nursery() as nursery:
+ async with dns.quic.TrioQuicManager(
+ nursery, verify_mode=tls_verify_mode
+ ) as manager:
+ async with trio.open_nursery() as query_nursery:
+ # We run queries in a separate nursery so we can demonstrate
+ # waiting for them all to exit without waiting for the manager to
+ # exit as well.
+ connection = manager.connect(peer_address, peer_port)
+ query_nursery.start_soon(aquery, "A", connection)
+ query_nursery.start_soon(aquery, "AAAA", connection)
+
+ def tmultirun():
+ trio.run(tmulti)
+
+
+def main():
+ print("*** Single Queries ***")
+ print("--- Sync ---")
+ srun()
+ print("--- Asyncio ---")
+ dns.asyncbackend.set_default_backend("asyncio")
+ arun()
+ if have_trio:
+ print("--- Trio ---")
+ dns.asyncbackend.set_default_backend("trio")
+ trun()
+ print("*** Multi-connection Queries ***")
+ print("--- Sync ---")
+ smultirun()
+ print("--- Asyncio ---")
+ dns.asyncbackend.set_default_backend("asyncio")
+ amultirun()
+ if have_trio:
+ print("--- Trio ---")
+ dns.asyncbackend.set_default_backend("trio")
+ tmultirun()
+
+
+if __name__ == "__main__":
+ main()
{ path="tests/*.pickle", format="sdist" },
{ path="tests/*.text", format="sdist" },
{ path="tests/*.generic", format="sdist" },
+ { path="tests/tls/*.crt", format="sdist" },
+ { path="tests/tls/*.pem", format="sdist" },
{ path="util/**", format="sdist" },
{ path="setup.cfg", format="sdist" },
]
curio = {version="^1.2", optional=true}
sniffio = {version="^1.1", optional=true}
wmi = {version="^1.5.1", optional=true}
+aioquic = {version="^0.9.20", optional=true}
[tool.poetry.dev-dependencies]
pytest = ">=5.4.1,<8"
trio = ['trio']
curio = ['curio', 'sniffio']
wmi = ['wmi']
+doq = ['aioquic']
[build-system]
requires = ["poetry-core"]
--- /dev/null
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+try:
+ import asyncio
+ import socket
+ import struct
+ import threading
+
+ import aioquic.asyncio
+ import aioquic.asyncio.server
+ import aioquic.quic.configuration
+ import aioquic.quic.events
+
+ import dns.asyncquery
+ import dns.message
+ import dns.rcode
+
+ from tests.util import here
+
+ have_quic = True
+
+ class Request:
+ def __init__(self, message, wire):
+ self.message = message
+ self.wire = wire
+
+ @property
+ def question(self):
+ return self.message.question[0]
+
+ @property
+ def qname(self):
+ return self.question.name
+
+ @property
+ def qclass(self):
+ return self.question.rdclass
+
+ @property
+ def qtype(self):
+ return self.question.rdtype
+
+ class NanoQuic(aioquic.asyncio.QuicConnectionProtocol):
+ def quic_event_received(self, event):
+ # This is a bit hackish and not fully general, but this is a test server!
+ if isinstance(event, aioquic.quic.events.StreamDataReceived):
+ data = bytes(event.data)
+ (wire_len,) = struct.unpack("!H", data[:2])
+ wire = self.handle_wire(data[2 : 2 + wire_len])
+ if wire is not None:
+ self._quic.send_stream_data(event.stream_id, wire, end_stream=True)
+
+ def handle(self, request):
+ r = dns.message.make_response(request.message)
+ r.set_rcode(dns.rcode.REFUSED)
+ return r
+
+ def handle_wire(self, wire):
+ response = None
+ try:
+ q = dns.message.from_wire(wire)
+ except dns.message.ShortHeader:
+ return
+ except Exception as e:
+ try:
+ q = dns.message.from_wire(wire, question_only=True)
+ response = dns.message.make_response(q)
+ response.set_rcode(dns.rcode.FORMERR)
+ except Exception:
+ return
+ if response is None:
+ try:
+ request = Request(q, wire)
+ response = self.handle(request)
+ except Exception:
+ response = dns.message.make_response(q)
+ response.set_rcode(dns.rcode.SERVFAIL)
+ wire = response.to_wire()
+ return struct.pack("!H", len(wire)) + wire
+
+ class Server(threading.Thread):
+ def __init__(self):
+ super().__init__()
+ self.transport = None
+ self.protocol = None
+ self.left = None
+ self.right = None
+
+ def __enter__(self):
+ self.left, self.right = socket.socketpair()
+ self.start()
+
+ def __exit__(self, ex_ty, ex_va, ex_tr):
+ if self.protocol is not None:
+ self.protocol.close()
+ if self.transport is not None:
+ self.transport.close()
+ if self.left:
+ self.left.close()
+ if self.is_alive():
+ self.join()
+ if self.right:
+ self.right.close()
+
+ async def arun(self):
+ reader, _ = await asyncio.open_connection(sock=self.right)
+ conf = aioquic.quic.configuration.QuicConfiguration(
+ alpn_protocols=["doq"],
+ is_client=False,
+ )
+ conf.load_cert_chain(here("tls/public.crt"), here("tls/private.pem"))
+ loop = asyncio.get_event_loop()
+ (self.transport, self.protocol) = await loop.create_datagram_endpoint(
+ lambda: aioquic.asyncio.server.QuicServer(
+ configuration=conf, create_protocol=NanoQuic
+ ),
+ local_addr=("127.0.0.1", 8853),
+ )
+ try:
+ await reader.read(1)
+ except Exception:
+ pass
+
+ def run(self):
+ asyncio.run(self.arun())
+
+except ImportError:
+ have_quic = False
+
+ class NanoQuic:
+ pass
--- /dev/null
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+import asyncio
+import sys
+
+import pytest
+
+import dns.asyncbackend
+import dns.asyncquery
+import dns.message
+import dns.query
+import dns.rcode
+
+from .util import here
+
+try:
+ from .nanoquic import Server
+
+ _nanoquic_available = True
+except ImportError:
+ _nanoquic_available = False
+
+ class Server(object):
+ pass
+
+
+@pytest.mark.skipif(not _nanoquic_available, reason="requires nanoquic")
+def test_basic_sync():
+ with Server() as server:
+ q = dns.message.make_query("www.example.", "A")
+ r = dns.query.quic(q, "127.0.0.1", port=8853, verify=here("tls/ca.crt"))
+ assert r.rcode() == dns.rcode.REFUSED
+
+
+async def amain():
+ q = dns.message.make_query("www.example.", "A")
+ r = await dns.asyncquery.quic(q, "127.0.0.1", port=8853, verify=here("tls/ca.crt"))
+ assert r.rcode() == dns.rcode.REFUSED
+
+
+def test_basic_asyncio():
+ dns.asyncbackend.set_default_backend("asyncio")
+ with Server() as server:
+ asyncio.run(amain())
+
+
+try:
+ import trio
+
+ def test_basic_trio():
+ dns.asyncbackend.set_default_backend("trio")
+ with Server() as server:
+ trio.run(amain)
+
+except ImportError:
+ pass
--- /dev/null
+-----BEGIN CERTIFICATE-----
+MIIDTDCCAjSgAwIBAgIUUCWxpsMnzETqwNKJ38le9z7oFEEwDQYJKoZIhvcNAQEL
+BQAwHTEbMBkGA1UEAxMScXVpYy5kbnNweXRob24ub3JnMB4XDTIyMDcwOTIyMjQw
+N1oXDTMyMDcwNjIyMjQzN1owHTEbMBkGA1UEAxMScXVpYy5kbnNweXRob24ub3Jn
+MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA0AMlXDsx/7Kis4lUhAML
+yaL4wtvhPGnqz20Gnhd/b2uAjZbtLtKDG2aRC0QtHL6N0vfBhj+KUV/unT60Mf7G
+Pm2Z8fOxiwh/UJ8oxoJe59izklrwM0PL2iR21OMCCsiYcjiOOx75RUZ/6KEGMTgd
+3wvqwEV320yd3WInkdO72n9jlQTN3VtwLwkIkSbINiuUCKgP9hy28K7HjMHvEIlf
+QZfh9wIHhbqs/JP3dirRL7MKWFAv3MlmMffb/6NBBFb6FaRjS6WjojD8qaSTr14/
+tyqrK7zL32npKm/TbzxC8hFwYdwd3HURgpWInA6CRIcyZM/k4y7dHQlI4ID7hmcC
+1QIDAQABo4GDMIGAMA4GA1UdDwEB/wQEAwIBBjAPBgNVHRMBAf8EBTADAQH/MB0G
+A1UdDgQWBBQrNPKeL6rBhPV+Eb1RnvIkeax5sDAfBgNVHSMEGDAWgBQrNPKeL6rB
+hPV+Eb1RnvIkeax5sDAdBgNVHREEFjAUghJxdWljLmRuc3B5dGhvbi5vcmcwDQYJ
+KoZIhvcNAQELBQADggEBAADpAtDvceOrhn5FReYip9DlTW7KKrRDDFCo0SNdhvN3
+6mU/Hn3jNXYu9Ym3NDVL8q9UWzLRcSNLUo1qjkK3aOlgwcO6PuGKXukF7Zdd8wVb
+pPdUqooBmj6akqmNvmloZyDmQ+aXcYhR83hcEHFOK+C7pGLqSFChN1mgDT1/mgBk
+pODOZkcLtZI8YJyQ2sn3WhUJS52D6xfmPigliUcYqi6i+w1vxD45QilWbvqCwnN/
+6qmb3JQsMf+3MCtogVcSZjE9cf4CwlmKqgMxsBKz+/Qk9YPMpDuecEbd76L+Htdl
+HWuDlemBzyhd5qO5y/UGarqmuh3MgkOdFVQWAUygcCM=
+-----END CERTIFICATE-----
--- /dev/null
+-----BEGIN PRIVATE KEY-----
+MC4CAQAwBQYDK2VwBCIEIL2OxuOo+awfhPvvm82EBZ4VA6ULQHlebxGCamZ/H5Rt
+-----END PRIVATE KEY-----
--- /dev/null
+-----BEGIN CERTIFICATE-----
+MIICZjCCAU6gAwIBAgIUBTlEzhtkXYQvZl5CYRNBxOG4GpEwDQYJKoZIhvcNAQEL
+BQAwHTEbMBkGA1UEAxMScXVpYy5kbnNweXRob24ub3JnMB4XDTIyMTAwOTE2MjYw
+OFoXDTMwMTIyNjE2MjYzOFowFDESMBAGA1UEAxMJbG9jYWxob3N0MCowBQYDK2Vw
+AyEAKpQbO2JXhCGnQs2MrWmGBK5LcmJMWPXCzM2PfWbo1TyjgaAwgZ0wDgYDVR0P
+AQH/BAQDAgOoMB0GA1UdJQQWMBQGCCsGAQUFBwMBBggrBgEFBQcDAjAdBgNVHQ4E
+FgQUM2pZy8pH78CvP+FnuF190KEJkjUwHwYDVR0jBBgwFoAUKzTyni+qwYT1fhG9
+UZ7yJHmsebAwLAYDVR0RBCUwI4IJbG9jYWxob3N0hwR/AAABhxAAAAAAAAAAAAAA
+AAAAAAABMA0GCSqGSIb3DQEBCwUAA4IBAQA0JlNLrLz3ajCzSVfOQsUdd3a3wR7Q
+Dr28mYoDHSY9mhnJ9IQeInmGvPMLA4dgiRPFqxWsKh+lxzZObkbMjf1IAIVykfh6
+LynePm58/lnRrhdvf8vFfccuTyeb2aD0ZBA/RyhZam79J6JjRRovkSj9TyIqKfif
+6T6QWXOXwAF89rH8YHAKnRSl32pqZuDhOnM0Ien+Sa6KpCvgIDogHQxIVbe1egZl
+2Ec0LVQUaXhoICd1c6xoRoAa5UzDFJ7ujeu1XNGWKIiXESlcIo7SZjzusL2p5vv/
+frM+r43khtZ4s+F70A+B3AndcVSeKTQ5KlftN9CBuiQoYzhY29NmL93X
+-----END CERTIFICATE-----
+-----BEGIN CERTIFICATE-----
+MIIDTDCCAjSgAwIBAgIUUCWxpsMnzETqwNKJ38le9z7oFEEwDQYJKoZIhvcNAQEL
+BQAwHTEbMBkGA1UEAxMScXVpYy5kbnNweXRob24ub3JnMB4XDTIyMDcwOTIyMjQw
+N1oXDTMyMDcwNjIyMjQzN1owHTEbMBkGA1UEAxMScXVpYy5kbnNweXRob24ub3Jn
+MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA0AMlXDsx/7Kis4lUhAML
+yaL4wtvhPGnqz20Gnhd/b2uAjZbtLtKDG2aRC0QtHL6N0vfBhj+KUV/unT60Mf7G
+Pm2Z8fOxiwh/UJ8oxoJe59izklrwM0PL2iR21OMCCsiYcjiOOx75RUZ/6KEGMTgd
+3wvqwEV320yd3WInkdO72n9jlQTN3VtwLwkIkSbINiuUCKgP9hy28K7HjMHvEIlf
+QZfh9wIHhbqs/JP3dirRL7MKWFAv3MlmMffb/6NBBFb6FaRjS6WjojD8qaSTr14/
+tyqrK7zL32npKm/TbzxC8hFwYdwd3HURgpWInA6CRIcyZM/k4y7dHQlI4ID7hmcC
+1QIDAQABo4GDMIGAMA4GA1UdDwEB/wQEAwIBBjAPBgNVHRMBAf8EBTADAQH/MB0G
+A1UdDgQWBBQrNPKeL6rBhPV+Eb1RnvIkeax5sDAfBgNVHSMEGDAWgBQrNPKeL6rB
+hPV+Eb1RnvIkeax5sDAdBgNVHREEFjAUghJxdWljLmRuc3B5dGhvbi5vcmcwDQYJ
+KoZIhvcNAQELBQADggEBAADpAtDvceOrhn5FReYip9DlTW7KKrRDDFCo0SNdhvN3
+6mU/Hn3jNXYu9Ym3NDVL8q9UWzLRcSNLUo1qjkK3aOlgwcO6PuGKXukF7Zdd8wVb
+pPdUqooBmj6akqmNvmloZyDmQ+aXcYhR83hcEHFOK+C7pGLqSFChN1mgDT1/mgBk
+pODOZkcLtZI8YJyQ2sn3WhUJS52D6xfmPigliUcYqi6i+w1vxD45QilWbvqCwnN/
+6qmb3JQsMf+3MCtogVcSZjE9cf4CwlmKqgMxsBKz+/Qk9YPMpDuecEbd76L+Htdl
+HWuDlemBzyhd5qO5y/UGarqmuh3MgkOdFVQWAUygcCM=
+-----END CERTIFICATE-----