]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Support trio, curio, and asyncio with one API!
authorBob Halley <halley@dnspython.org>
Fri, 12 Jun 2020 01:50:30 +0000 (18:50 -0700)
committerBob Halley <halley@dnspython.org>
Fri, 12 Jun 2020 01:50:30 +0000 (18:50 -0700)
16 files changed:
dns/__init__.py
dns/_asyncbackend.py [new file with mode: 0644]
dns/_asyncio_backend.py [new file with mode: 0644]
dns/_curio_backend.py [new file with mode: 0644]
dns/_trio_backend.py [new file with mode: 0644]
dns/asyncbackend.py [new file with mode: 0644]
dns/asyncquery.py [new file with mode: 0644]
dns/asyncresolver.py [moved from dns/trio/resolver.py with 81% similarity]
dns/trio/__init__.py [deleted file]
dns/trio/query.py [deleted file]
dns/trio/query.pyi [deleted file]
dns/trio/resolver.pyi [deleted file]
pyproject.toml
setup.py
tests/test_async.py [new file with mode: 0644]
tests/test_trio.py [deleted file]

index d5cadb89f4115707e48e4ff66859963c2e36d058..6412fb5e19eaa78480894dad60e07619524d3cef 100644 (file)
@@ -18,6 +18,8 @@
 """dnspython DNS toolkit"""
 
 __all__ = [
+    'asyncquery.py',
+    'asyncresolver.py',
     'dnssec',
     'e164',
     'edns',
diff --git a/dns/_asyncbackend.py b/dns/_asyncbackend.py
new file mode 100644 (file)
index 0000000..9bfdaba
--- /dev/null
@@ -0,0 +1,77 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+
+import dns.inet
+
+
+# This is a nullcontext for both sync and async
+
+class NullContext:
+    def __init__(self, enter_result=None):
+        self.enter_result = enter_result
+
+    def __enter__(self):
+        return self.enter_result
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        pass
+
+    async def __aenter__(self):
+        return self.enter_result
+
+    async def __aexit__(self, exc_type, exc_value, traceback):
+        pass
+
+
+# This is handy, but should probably move somewhere else!
+
+def low_level_address_tuple(af, high_level_address_tuple):
+    address, port = high_level_address_tuple
+    if af == dns.inet.AF_INET:
+        return (address, port)
+    elif af == dns.inet.AF_INET6:
+        ai_flags = socket.AI_NUMERICHOST
+        ((*_, tup), *_) = socket.getaddrinfo(address, port, flags=ai_flags)
+        return tup
+    else:
+        raise NotImplementedError(f'unknown address family {af}')
+
+
+# These are declared here so backends can import them without creating
+# circular dependencies with dns.asyncbackend.
+
+class Socket:
+    async def close(self):
+        pass
+
+    async def __aenter__(self):
+        pass
+
+    async def __aexit__(self, exc_type, exc_value, traceback):
+        await self.close()
+
+
+class DatagramSocket(Socket):
+    async def sendto(self, what, destination, timeout):
+        pass
+
+    async def recvfrom(self, size, timeout):
+        pass
+
+
+class StreamSocket(Socket):
+    async def sendall(self, what, destination, timeout):
+        pass
+
+    async def recv(self, size, timeout):
+        pass
+
+
+class Backend:
+    def name(self):
+        return 'unknown'
+
+    async def make_socket(self, af, socktype, proto=0,
+                          source=None, raw_source=None,
+                          ssl_context=None, server_hostname=None):
+        raise NotImplementedError
diff --git a/dns/_asyncio_backend.py b/dns/_asyncio_backend.py
new file mode 100644 (file)
index 0000000..42c6e66
--- /dev/null
@@ -0,0 +1,118 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+"""asyncio library query support"""
+
+import socket
+import asyncio
+
+import dns._asyncbackend
+import dns.exception
+
+class _DatagramProtocol:
+    def __init__(self):
+        self.transport = None
+        self.recvfrom = None
+
+    def connection_made(self, transport):
+        self.transport = transport
+
+    def datagram_received(self, data, addr):
+        if self.recvfrom:
+            self.recvfrom.set_result((data, addr))
+            self.recvfrom = None
+
+    def error_received(self, exc):
+        if self.recvfrom:
+            self.recvfrom.set_exception(exc)
+
+    def connection_lost(self, exc):
+        if self.recvfrom:
+            self.recvfrom.set_exception(exc)
+
+    def close(self):
+        self.transport.close()
+
+
+async def _maybe_wait_for(awaitable, timeout):
+    if timeout:
+        try:
+            return await asyncio.wait_for(awaitable, timeout)
+        except asyncio.TimeoutError:
+            raise dns.exception.Timeout(timeout=timeout)
+    else:
+        return await awaitable
+
+class DatagramSocket(dns._asyncbackend.DatagramSocket):
+    def __init__(self, family, transport, protocol):
+        self.family = family
+        self.transport = transport
+        self.protocol = protocol
+
+    async def sendto(self, what, destination, timeout):
+        # no timeout for asyncio sendto
+        self.transport.sendto(what, destination)
+
+    async def recvfrom(self, timeout):
+        done = asyncio.get_running_loop().create_future()
+        assert self.protocol.recvfrom is None
+        self.protocol.recvfrom = done
+        await _maybe_wait_for(done, timeout)
+        return done.result()
+
+    async def close(self):
+        self.protocol.close()
+
+    async def getpeername(self):
+        return self.transport.get_extra_info('peername')
+
+
+class StreamSocket(dns._asyncbackend.DatagramSocket):
+    def __init__(self, af, reader, writer):
+        self.family = af
+        self.reader = reader
+        self.writer = writer
+
+    async def sendall(self, what, timeout):
+        self.writer.write(what),
+        return await _maybe_wait_for(self.writer.drain(), timeout)
+        raise dns.exception.Timeout(timeout=timeout)
+
+    async def recv(self, count, timeout):
+        return await _maybe_wait_for(self.reader.read(count),
+                                     timeout)
+        raise dns.exception.Timeout(timeout=timeout)
+
+    async def close(self):
+        self.writer.close()
+        await self.writer.wait_closed()
+
+    async def getpeername(self):
+        return self.reader.get_extra_info('peername')
+
+
+class Backend(dns._asyncbackend.Backend):
+    def name(self):
+        return 'asyncio'
+
+    async def make_socket(self, af, socktype, proto=0,
+                          source=None, destination=None, timeout=None,
+                          ssl_context=None, server_hostname=None):
+        loop = asyncio.get_running_loop()
+        if socktype == socket.SOCK_DGRAM:
+            transport, protocol = await loop.create_datagram_endpoint(
+                _DatagramProtocol, source, family=af,
+                proto=proto)
+            return DatagramSocket(af, transport, protocol)
+        elif socktype == socket.SOCK_STREAM:
+            (r, w) = await _maybe_wait_for(
+                asyncio.open_connection(destination[0],
+                                        destination[1],
+                                        family=af,
+                                        proto=proto,
+                                        local_addr=source),
+                timeout)
+            return StreamSocket(af, r, w)
+        raise NotImplementedError(f'unsupported socket type {socktype}')
+
+    async def sleep(self, interval):
+        await asyncio.sleep(interval)
diff --git a/dns/_curio_backend.py b/dns/_curio_backend.py
new file mode 100644 (file)
index 0000000..e37fea3
--- /dev/null
@@ -0,0 +1,92 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+"""curio async I/O library query support"""
+
+import socket
+import curio
+import curio.socket  # type: ignore
+
+import dns._asyncbackend
+import dns.exception
+
+
+def _maybe_timeout(timeout):
+    if timeout:
+        return curio.ignore_after(timeout)
+    else:
+        return dns._asyncbackend.NullContext()
+
+
+# for brevity
+_lltuple = dns._asyncbackend.low_level_address_tuple
+
+
+class DatagramSocket(dns._asyncbackend.DatagramSocket):
+    def __init__(self, socket):
+        self.socket = socket
+        self.family = socket.family
+
+    async def sendto(self, what, destination, timeout):
+        async with _maybe_timeout(timeout):
+            return await self.socket.sendto(what, destination)
+        raise dns.exception.Timeout(timeout=timeout)
+
+    async def recvfrom(self, timeout):
+        async with _maybe_timeout(timeout):
+            return await self.socket.recvfrom(65535)
+        raise dns.exception.Timeout(timeout=timeout)
+
+    async def close(self):
+        await self.socket.close()
+
+    async def getpeername(self):
+        return self.socket.getpeername()
+
+
+class StreamSocket(dns._asyncbackend.DatagramSocket):
+    def __init__(self, socket):
+        self.socket = socket
+        self.family = socket.family
+
+    async def sendall(self, what, timeout):
+        async with _maybe_timeout(timeout):
+            return await self.socket.sendall(what)
+        raise dns.exception.Timeout(timeout=timeout)
+
+    async def recv(self, count, timeout):
+        async with _maybe_timeout(timeout):
+            return await self.socket.recv(count)
+        raise dns.exception.Timeout(timeout=timeout)
+
+    async def close(self):
+        await self.socket.close()
+
+    async def getpeername(self):
+        return self.socket.getpeername()
+
+
+class Backend(dns._asyncbackend.Backend):
+    def name(self):
+        return 'curio'
+
+    async def make_socket(self, af, socktype, proto=0,
+                          source=None, destination=None, timeout=None,
+                          ssl_context=None, server_hostname=None):
+        s = curio.socket.socket(af, socktype, proto)
+        try:
+            if source:
+                s.bind(_lltuple(af, source))
+            if socktype == socket.SOCK_STREAM:
+                with _maybe_timeout(timeout):
+                    await s.connect(_lltuple(af, destination))
+        except Exception:
+            await s.close()
+            raise
+        if socktype == socket.SOCK_DGRAM:
+            return DatagramSocket(s)
+        elif socktype == socket.SOCK_STREAM:
+            return StreamSocket(s)
+        raise NotImplementedError(f'unsupported socket type {socktype}')
+
+    async def sleep(self, interval):
+        await curio.sleep(interval)
diff --git a/dns/_trio_backend.py b/dns/_trio_backend.py
new file mode 100644 (file)
index 0000000..bcaddcc
--- /dev/null
@@ -0,0 +1,92 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+"""trio async I/O library query support"""
+
+import socket
+import trio
+import trio.socket  # type: ignore
+
+import dns._asyncbackend
+import dns.exception
+
+
+def _maybe_timeout(timeout):
+    if timeout:
+        return trio.move_on_after(timeout)
+    else:
+        return dns._asyncbackend.NullContext()
+
+
+# for brevity
+_lltuple = dns._asyncbackend.low_level_address_tuple
+
+
+class DatagramSocket(dns._asyncbackend.DatagramSocket):
+    def __init__(self, socket):
+        self.socket = socket
+        self.family = socket.family
+
+    async def sendto(self, what, destination, timeout):
+        with _maybe_timeout(timeout):
+            return await self.socket.sendto(what, destination)
+        raise dns.exception.Timeout(timeout=timeout)
+
+    async def recvfrom(self, timeout):
+        with _maybe_timeout(timeout):
+            return await self.socket.recvfrom(65535)
+        raise dns.exception.Timeout(timeout=timeout)
+
+    async def close(self):
+        self.socket.close()
+
+    async def getpeername(self):
+        return self.socket.getpeername()
+
+
+class StreamSocket(dns._asyncbackend.DatagramSocket):
+    def __init__(self, family, stream):
+        self.family = family
+        self.stream = stream
+
+    async def sendall(self, what, timeout):
+        with _maybe_timeout(timeout):
+            return await self.stream.send_all(what)
+        raise dns.exception.Timeout(timeout=timeout)
+
+    async def recv(self, count, timeout):
+        with _maybe_timeout(timeout):
+            return await self.stream.receive_some(count)
+        raise dns.exception.Timeout(timeout=timeout)
+
+    async def close(self):
+        await self.stream.aclose()
+
+    async def getpeername(self):
+        return self.stream.socket.getpeername()
+
+
+class Backend(dns._asyncbackend.Backend):
+    def name(self):
+        return 'trio'
+
+    async def make_socket(self, af, socktype, proto=0, source=None,
+                          destination=None, timeout=None,
+                          ssl_context=None, server_hostname=None):
+        s = trio.socket.socket(af, socktype, proto)
+        try:
+            if source:
+                await s.bind(_lltuple(af, source))
+            if socktype == socket.SOCK_STREAM:
+                with _maybe_timeout(timeout):
+                    await s.connect(_lltuple(af, destination))
+        except Exception:
+            s.close()
+            raise
+        if socktype == socket.SOCK_DGRAM:
+            return DatagramSocket(s)
+        elif socktype == socket.SOCK_STREAM:
+            return StreamSocket(af, trio.SocketStream(s))
+        raise NotImplementedError(f'unsupported socket type {socktype}')
+
+    async def sleep(self, interval):
+        await trio.sleep(interval)
diff --git a/dns/asyncbackend.py b/dns/asyncbackend.py
new file mode 100644 (file)
index 0000000..92a1ae3
--- /dev/null
@@ -0,0 +1,43 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+
+from dns._asyncbackend import Socket, DatagramSocket, \
+    StreamSocket, Backend, low_level_address_tuple
+
+
+_default_backend = None
+
+
+def get_default_backend():
+    if _default_backend:
+        return _default_backend
+
+    return set_default_backend(sniff())
+
+
+def sniff():
+    name = 'asyncio'
+    try:
+        import sniffio
+        name = sniffio.current_async_library()
+    except Exception:
+        pass
+    return name
+
+
+def set_default_backend(name):
+    global _default_backend
+
+    if name == 'trio':
+        import dns._trio_backend
+        _default_backend = dns._trio_backend.Backend()
+    elif name == 'curio':
+        import dns._curio_backend
+        _default_backend = dns._curio_backend.Backend()
+    elif name == 'asyncio':
+        import dns._asyncio_backend
+        _default_backend = dns._asyncio_backend.Backend()
+    else:
+        raise NotImplementedException(f'unimplemented async backend {name}')
+
+    return _default_backend
diff --git a/dns/asyncquery.py b/dns/asyncquery.py
new file mode 100644 (file)
index 0000000..ed51fdc
--- /dev/null
@@ -0,0 +1,422 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+"""Talk to a DNS server."""
+
+import os
+import socket
+import struct
+import time
+import base64
+import ipaddress
+
+import dns.asyncbackend
+import dns.exception
+import dns.inet
+import dns.name
+import dns.message
+import dns.rcode
+import dns.rdataclass
+import dns.rdatatype
+
+from dns.query import _addresses_equal, _destination_and_source, \
+    _compute_times, UnexpectedSource
+
+
+# for brevity
+_lltuple = dns.asyncbackend.low_level_address_tuple
+
+
+def _source_tuple(af, address, port):
+    # Make a high level source tuple, or return None if address and port
+    # are both None
+    if address or port:
+        if address is None:
+            if af == socket.AF_INET:
+                address = '0.0.0.0'
+            elif af == socket.AF_INET6:
+                address = '::'
+            else:
+                raise NotImplementedError(f'unknown address family {af}')
+        return (address, port)
+    else:
+        return None
+
+
+def _timeout(expiration, now=None):
+    if expiration:
+        if not now:
+            now = time.time()
+        return max(expiration - now, 0)
+    else:
+        return None
+
+
+async def send_udp(sock, what, destination, expiration=None):
+    """Send a DNS message to the specified UDP socket.
+
+    *sock*, a ``dns.asyncbackend.DatagramSocket``.
+
+    *what*, a ``bytes`` or ``dns.message.Message``, the message to send.
+
+    *destination*, a destination tuple appropriate for the address family
+    of the socket, specifying where to send the query.
+
+    *expiration*, a ``float`` or ``None``, the absolute time at which
+    a timeout exception should be raised.  If ``None``, no timeout will
+    occur.
+
+    Returns an ``(int, float)`` tuple of bytes sent and the sent time.
+    """
+
+    if isinstance(what, dns.message.Message):
+        what = what.to_wire()
+    sent_time = time.time()
+    n = await sock.sendto(what, destination, _timeout(expiration, sent_time))
+    return (n, sent_time)
+
+
+async def receive_udp(sock, destination, expiration=None,
+                      ignore_unexpected=False, one_rr_per_rrset=False,
+                      keyring=None, request_mac=b'', ignore_trailing=False,
+                      raise_on_truncation=False):
+    """Read a DNS message from a UDP socket.
+
+    *sock*, a ``dns.asyncbackend.DatagramSocket``.
+
+    *destination*, a destination tuple appropriate for the address family
+    of the socket, specifying where the associated query was sent.
+
+    *expiration*, a ``float`` or ``None``, the absolute time at which
+    a timeout exception should be raised.  If ``None``, no timeout will
+    occur.
+
+    *ignore_unexpected*, a ``bool``.  If ``True``, ignore responses from
+    unexpected sources.
+
+    *one_rr_per_rrset*, a ``bool``.  If ``True``, put each RR into its own
+    RRset.
+
+    *keyring*, a ``dict``, the keyring to use for TSIG.
+
+    *request_mac*, a ``bytes``, the MAC of the request (for TSIG).
+
+    *ignore_trailing*, a ``bool``.  If ``True``, ignore trailing
+    junk at end of the received message.
+
+    *raise_on_truncation*, a ``bool``.  If ``True``, raise an exception if
+    the TC bit is set.
+
+    Raises if the message is malformed, if network errors occur, of if
+    there is a timeout.
+
+    Returns a ``dns.message.Message`` object.
+    """
+
+    wire = b''
+    while 1:
+        (wire, from_address) = await sock.recvfrom(65535)
+        if _addresses_equal(sock.family, from_address, destination) or \
+           (dns.inet.is_multicast(destination[0]) and
+            from_address[1:] == destination[1:]):
+            break
+        if not ignore_unexpected:
+            raise UnexpectedSource('got a response from '
+                                   '%s instead of %s' % (from_address,
+                                                         destination))
+    received_time = time.time()
+    r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac,
+                              one_rr_per_rrset=one_rr_per_rrset,
+                              ignore_trailing=ignore_trailing,
+                              raise_on_truncation=raise_on_truncation)
+    return (r, received_time)
+
+async def udp(q, where, timeout=None, port=53, source=None, source_port=0,
+              ignore_unexpected=False, one_rr_per_rrset=False,
+              ignore_trailing=False, raise_on_truncation=False, sock=None,
+              backend=None):
+    """Return the response obtained after sending a query via UDP.
+
+    *q*, a ``dns.message.Message``, the query to send
+
+    *where*, a ``str`` containing an IPv4 or IPv6 address,  where
+    to send the message.
+
+    *timeout*, a ``float`` or ``None``, the number of seconds to wait before the
+    query times out.  If ``None``, the default, wait forever.
+
+    *port*, an ``int``, the port send the message to.  The default is 53.
+
+    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
+    the source address.  The default is the wildcard address.
+
+    *source_port*, an ``int``, the port from which to send the message.
+    The default is 0.
+
+    *ignore_unexpected*, a ``bool``.  If ``True``, ignore responses from
+    unexpected sources.
+
+    *one_rr_per_rrset*, a ``bool``.  If ``True``, put each RR into its own
+    RRset.
+
+    *ignore_trailing*, a ``bool``.  If ``True``, ignore trailing
+    junk at end of the received message.
+
+    *raise_on_truncation*, a ``bool``.  If ``True``, raise an exception if
+    the TC bit is set.
+
+    *sock*, a ``dns.asyncbackend.DatagramSocket``, or ``None``,
+    the socket to use for the query.  If ``None``, the default, a
+    socket is created.  Note that if a socket is provided, the
+    *source* and *source_port* are ignored.
+
+    *backend*, a ``dns.asyncbackend.Backend``, or ``None``.  If ``None``,
+    the default, then dnspython will use the default backend.
+
+    Returns a ``dns.message.Message``.
+    """
+    if not backend:
+        backend = dns.asyncbackend.get_default_backend()
+    wire = q.to_wire()
+    (begin_time, expiration) = _compute_times(timeout)
+    s = None
+    try:
+        if sock:
+            s = sock
+        else:
+            af = dns.inet.af_for_address(where)
+            stuple = _source_tuple(af, source, source_port)
+            s = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple)
+            destination = _lltuple(af, (where, port))
+        await send_udp(s, wire, destination, expiration)
+        (r, received_time) = await receive_udp(s, destination, expiration,
+                                               ignore_unexpected,
+                                               one_rr_per_rrset,
+                                               q.keyring, q.mac,
+                                               ignore_trailing,
+                                               raise_on_truncation)
+        r.time = received_time - begin_time
+        if not q.is_response(r):
+            raise BadResponse
+        return r
+    finally:
+        if not sock and s:
+            await s.close()
+
+async def udp_with_fallback(q, where, timeout=None, port=53, source=None,
+                            source_port=0, ignore_unexpected=False,
+                            one_rr_per_rrset=False, ignore_trailing=False,
+                            udp_sock=None, tcp_sock=None):
+    """Return the response to the query, trying UDP first and falling back
+    to TCP if UDP results in a truncated response.
+
+    *q*, a ``dns.message.Message``, the query to send
+
+    *where*, a ``str`` containing an IPv4 or IPv6 address,  where
+    to send the message.
+
+    *timeout*, a ``float`` or ``None``, the number of seconds to wait before the
+    query times out.  If ``None``, the default, wait forever.
+
+    *port*, an ``int``, the port send the message to.  The default is 53.
+
+    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
+    the source address.  The default is the wildcard address.
+
+    *source_port*, an ``int``, the port from which to send the message.
+    The default is 0.
+
+    *ignore_unexpected*, a ``bool``.  If ``True``, ignore responses from
+    unexpected sources.
+
+    *one_rr_per_rrset*, a ``bool``.  If ``True``, put each RR into its own
+    RRset.
+
+    *ignore_trailing*, a ``bool``.  If ``True``, ignore trailing
+    junk at end of the received message.
+
+    *udp_sock*, a ``dns.asyncbackend.DatagramSocket``, or ``None``,
+    the socket to use for the UDP query.  If ``None``, the default, a
+    socket is created.  Note that if a socket is provided the *source*
+    and *source_port* are ignored for the UDP query.
+
+    *tcp_sock*, a ``dns.asyncbackend.StreamSocket``, or ``None``, the
+    socket to use for the TCP query.  If ``None``, the default, a
+    socket is created.  Note that if a socket is provided *where*,
+    *source* and *source_port* are ignored for the TCP query.
+
+    Returns a (``dns.message.Message``, tcp) tuple where tcp is ``True``
+    if and only if TCP was used.
+    """
+    try:
+        response = await udp(q, where, timeout, port, source, source_port,
+                             ignore_unexpected, one_rr_per_rrset,
+                             ignore_trailing, True, udp_sock)
+        return (response, False)
+    except dns.message.Truncated:
+        response = await tcp(q, where, timeout, port, source, source_port,
+                             one_rr_per_rrset, ignore_trailing, tcp_sock)
+        return (response, True)
+
+
+
+async def send_tcp(sock, what, expiration=None):
+    """Send a DNS message to the specified TCP socket.
+
+    *sock*, a ``socket``.
+
+    *what*, a ``bytes`` or ``dns.message.Message``, the message to send.
+
+    *expiration*, a ``float`` or ``None``, the absolute time at which
+    a timeout exception should be raised.  If ``None``, no timeout will
+    occur.
+
+    Returns an ``(int, float)`` tuple of bytes sent and the sent time.
+    """
+
+    if isinstance(what, dns.message.Message):
+        what = what.to_wire()
+    l = len(what)
+    # copying the wire into tcpmsg is inefficient, but lets us
+    # avoid writev() or doing a short write that would get pushed
+    # onto the net
+    tcpmsg = struct.pack("!H", l) + what
+    sent_time = time.time()
+    await sock.sendall(tcpmsg, expiration)
+    return (len(tcpmsg), sent_time)
+
+
+async def read_exactly(sock, count, expiration):
+    """Read the specified number of bytes from stream.  Keep trying until we
+    either get the desired amount, or we hit EOF.
+    """
+    s = b''
+    while count > 0:
+        n = await sock.recv(count, _timeout(expiration))
+        if n == b'':
+            raise EOFError
+        count = count - len(n)
+        s = s + n
+    return s
+
+
+async def receive_tcp(sock, expiration=None, one_rr_per_rrset=False,
+                      keyring=None, request_mac=b'', ignore_trailing=False):
+    """Read a DNS message from a TCP socket.
+
+    *sock*, a ``socket``.
+
+    *expiration*, a ``float`` or ``None``, the absolute time at which
+    a timeout exception should be raised.  If ``None``, no timeout will
+    occur.
+
+    *one_rr_per_rrset*, a ``bool``.  If ``True``, put each RR into its own
+    RRset.
+
+    *keyring*, a ``dict``, the keyring to use for TSIG.
+
+    *request_mac*, a ``bytes``, the MAC of the request (for TSIG).
+
+    *ignore_trailing*, a ``bool``.  If ``True``, ignore trailing
+    junk at end of the received message.
+
+    Raises if the message is malformed, if network errors occur, of if
+    there is a timeout.
+
+    Returns a ``dns.message.Message`` object.
+    """
+
+    ldata = await read_exactly(sock, 2, expiration)
+    (l,) = struct.unpack("!H", ldata)
+    wire = await read_exactly(sock, l, expiration)
+    received_time = time.time()
+    r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac,
+                              one_rr_per_rrset=one_rr_per_rrset,
+                              ignore_trailing=ignore_trailing)
+    return (r, received_time)
+
+
+async def tcp(q, where, timeout=None, port=53, source=None, source_port=0,
+              one_rr_per_rrset=False, ignore_trailing=False, sock=None,
+              backend=None):
+    """Return the response obtained after sending a query via TCP.
+
+    *q*, a ``dns.message.Message``, the query to send
+
+    *where*, a ``str`` containing an IPv4 or IPv6 address, where
+    to send the message.
+
+    *timeout*, a ``float`` or ``None``, the number of seconds to wait before the
+    query times out.  If ``None``, the default, wait forever.
+
+    *port*, an ``int``, the port send the message to.  The default is 53.
+
+    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
+    the source address.  The default is the wildcard address.
+
+    *source_port*, an ``int``, the port from which to send the message.
+    The default is 0.
+
+    *one_rr_per_rrset*, a ``bool``.  If ``True``, put each RR into its own
+    RRset.
+
+    *ignore_trailing*, a ``bool``.  If ``True``, ignore trailing
+    junk at end of the received message.
+
+    *sock*, a ``dns.asyncbacket.StreamSocket``, or ``None``, the
+    socket to use for the query.  If ``None``, the default, a socket
+    is created.  Note that if a socket is provided
+    *where*, *port*, *source* and *source_port* are ignored.
+
+    *backend*, a ``dns.asyncbackend.Backend``, or ``None``.  If ``None``,
+    the default, then dnspython will use the default backend.
+
+    Returns a ``dns.message.Message``.
+    """
+
+    if not backend:
+        backend = dns.asyncbackend.get_default_backend()
+    wire = q.to_wire()
+    (begin_time, expiration) = _compute_times(timeout)
+    s = None
+    try:
+        if sock:
+            # Verify that the socket is connected, as if it's not connected,
+            # it's not writable, and the polling in send_tcp() will time out or
+            # hang forever.
+            await sock.getpeername()
+            s = sock
+        else:
+            # These are simple (address, port) pairs, not
+            # family-dependent tuples you pass to lowlevel socket
+            # code.
+            af = dns.inet.af_for_address(where)
+            stuple = _source_tuple(af, source, source_port)
+            dtuple = (where, port)
+            s = await backend.make_socket(af, socket.SOCK_STREAM, 0, stuple,
+                                          dtuple, timeout)
+        await send_tcp(s, wire, expiration)
+        (r, received_time) = await receive_tcp(s, expiration, one_rr_per_rrset,
+                                               q.keyring, q.mac,
+                                               ignore_trailing)
+        r.time = received_time - begin_time
+        if not q.is_response(r):
+            raise BadResponse
+        return r
+    finally:
+        if not sock and s:
+            await s.close()
similarity index 81%
rename from dns/trio/resolver.py
rename to dns/asyncresolver.py
index 07e70f97d6f9f06a24adf8172579f38a065a4c02..b45a35b5af59da62808efb62b1b62c2e1454a7b5 100644 (file)
 # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
 # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 
-"""trio async I/O library DNS stub resolver."""
+"""Asynchronous DNS stub resolver."""
 
-import trio
+import time
 
+import dns.asyncbackend
+import dns.asyncquery
 import dns.exception
 import dns.query
 import dns.resolver
-import dns.trio.query
 
 # import some resolver symbols for brevity
 from dns.resolver import NXDOMAIN, NoAnswer, NotAbsolute, NoRootSOA
 
-# we do this for indentation reasons below
-_udp = dns.trio.query.udp
-_stream = dns.trio.query.stream
 
-class TooManyAttempts(dns.exception.DNSException):
-    """A resolution had too many unsuccessful attempts."""
+# for identation purposes below
+_udp = dns.asyncquery.udp
+_tcp = dns.asyncquery.tcp
+
 
 class Resolver(dns.resolver.Resolver):
 
     async def resolve(self, qname, rdtype=dns.rdatatype.A,
                       rdclass=dns.rdataclass.IN,
                       tcp=False, source=None, raise_on_no_answer=True,
-                      source_port=0, search=None):
+                      source_port=0, lifetime=None, search=None,
+                      backend=None):
         """Query nameservers asynchronously to find the answer to the question.
 
         The *qname*, *rdtype*, and *rdclass* parameters may be objects
@@ -62,6 +63,9 @@ class Resolver(dns.resolver.Resolver):
 
         *source_port*, an ``int``, the port from which to send the message.
 
+        *lifetime*, a ``float``, how many seconds a query should run
+         before timing out.
+
         *search*, a ``bool`` or ``None``, determines whether the
         search list configured in the system's resolver configuration
         are used for relative names, and whether the resolver's domain
@@ -69,6 +73,9 @@ class Resolver(dns.resolver.Resolver):
         which causes the value of the resolver's
         ``use_search_by_default`` attribute to be used.
 
+        *backend*, a ``dns.asyncbackend.Backend``, or ``None``.  If ``None``,
+        the default, then dnspython will use the default backend.
+
         Raises ``dns.resolver.NXDOMAIN`` if the query name does not exist.
 
         Raises ``dns.resolver.YXDOMAIN`` if the query name is too long after
@@ -87,6 +94,9 @@ class Resolver(dns.resolver.Resolver):
 
         resolution = dns.resolver._Resolution(self, qname, rdtype, rdclass, tcp,
                                               raise_on_no_answer, search)
+        if not backend:
+            backend = dns.asyncbackend.get_default_backend()
+        start = time.time()
         while True:
             (request, answer) = resolution.next_request()
             # Note we need to say "if answer is not None" and not just
@@ -101,30 +111,24 @@ class Resolver(dns.resolver.Resolver):
             while not done:
                 (nameserver, port, tcp, backoff) = resolution.next_nameserver()
                 if backoff:
-                    loops += 1
-                    if loops >= 5:
-                        raise TooManyAttempts
-                    await trio.sleep(backoff)
+                    await backend.sleep(backoff)
+                timeout = self._compute_timeout(start, lifetime)
                 try:
-                    with trio.fail_after(self.timeout):
-                        if dns.inet.is_address(nameserver):
-                            if tcp:
-                                response = await \
-                                    _stream(request, nameserver,
-                                            port=port,
-                                            source=source,
-                                            source_port=source_port)
-                            else:
-                                response = await \
-                                    _udp(request,
-                                         nameserver,
-                                         port=port,
-                                         source=source,
-                                         source_port=source_port,
-                                         raise_on_truncation=True)
+                    if dns.inet.is_address(nameserver):
+                        if tcp:
+                            response = await _tcp(request, nameserver,
+                                                  timeout, port,
+                                                  source, source_port,
+                                                  backend=backend)
                         else:
-                            # We don't do DoH yet.
-                            raise NotImplementedError
+                            response = await _udp(request, nameserver,
+                                                  timeout, port,
+                                                  source, source_port,
+                                                  raise_on_truncation=True,
+                                                  backend=backend)
+                    else:
+                        # We don't do DoH yet.
+                        raise NotImplementedError
                 except Exception as ex:
                     (_, done) = resolution.query_result(None, ex)
                     continue
@@ -191,7 +195,7 @@ async def resolve(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN,
     This is a convenience function that uses the default resolver
     object to make the query.
 
-    See ``dns.trio.resolver.Resolver.resolve`` for more information on the
+    See ``dns.asyncresolver.Resolver.resolve`` for more information on the
     parameters.
     """
 
@@ -203,7 +207,7 @@ async def resolve(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN,
 async def resolve_address(ipaddr, *args, **kwargs):
     """Use a resolver to run a reverse query for PTR records.
 
-    See ``dns.trio.resolver.Resolver.resolve_address`` for more
+    See ``dns.asyncresolver.Resolver.resolve_address`` for more
     information on the parameters.
     """
 
@@ -220,7 +224,7 @@ async def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False,
 
     *tcp*, a ``bool``.  If ``True``, use TCP to make the query.
 
-    *resolver*, a ``dns.trio.resolver.Resolver`` or ``None``, the
+    *resolver*, a ``dns.asyncresolver.Resolver`` or ``None``, the
     resolver to use.  If ``None``, the default resolver is used.
 
     Raises ``dns.resolver.NoRootSOA`` if there is no SOA RR at the DNS
diff --git a/dns/trio/__init__.py b/dns/trio/__init__.py
deleted file mode 100644 (file)
index 744f880..0000000
+++ /dev/null
@@ -1,8 +0,0 @@
-# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
-
-"""trio async I/O library helpers"""
-
-__all__ = [
-    'query',
-    'resolver',
-]
diff --git a/dns/trio/query.py b/dns/trio/query.py
deleted file mode 100644 (file)
index a3a28fe..0000000
+++ /dev/null
@@ -1,374 +0,0 @@
-# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
-
-"""trio async I/O library query support"""
-
-import contextlib
-import socket
-import struct
-import time
-import trio
-import trio.socket  # type: ignore
-
-import dns.exception
-import dns.inet
-import dns.name
-import dns.message
-import dns.query
-import dns.rcode
-import dns.rdataclass
-import dns.rdatatype
-
-# import query symbols for compatibility and brevity
-from dns.query import ssl, UnexpectedSource, BadResponse
-
-# Function used to create a socket.  Can be overridden if needed in special
-# situations.
-socket_factory = trio.socket.socket
-
-async def send_udp(sock, what, destination):
-    """Asynchronously send a DNS message to the specified UDP socket.
-
-    *sock*, a ``trio.socket.socket``.
-
-    *what*, a ``bytes`` or ``dns.message.Message``, the message to send.
-
-    *destination*, a destination tuple appropriate for the address family
-    of the socket, specifying where to send the query.
-
-    Returns an ``(int, float)`` tuple of bytes sent and the sent time.
-    """
-
-    if isinstance(what, dns.message.Message):
-        what = what.to_wire()
-    sent_time = time.time()
-    n = await sock.sendto(what, destination)
-    return (n, sent_time)
-
-
-async def receive_udp(sock, destination, ignore_unexpected=False,
-                      one_rr_per_rrset=False, keyring=None, request_mac=b'',
-                      ignore_trailing=False, raise_on_truncation=False):
-    """Asynchronously read a DNS message from a UDP socket.
-
-    *sock*, a ``trio.socket.socket``.
-
-    *destination*, a destination tuple appropriate for the address family
-    of the socket, specifying where the associated query was sent.
-
-    *ignore_unexpected*, a ``bool``.  If ``True``, ignore responses from
-    unexpected sources.
-
-    *one_rr_per_rrset*, a ``bool``.  If ``True``, put each RR into its own
-    RRset.
-
-    *keyring*, a ``dict``, the keyring to use for TSIG.
-
-    *request_mac*, a ``bytes``, the MAC of the request (for TSIG).
-
-    *ignore_trailing*, a ``bool``.  If ``True``, ignore trailing
-    junk at end of the received message.
-
-    *raise_on_truncation*, a ``bool``.  If ``True``, raise an exception if
-    the TC bit is set.
-
-    Raises if the message is malformed, if network errors occur, of if
-    there is a timeout.
-
-    Returns a ``dns.message.Message`` object.
-    """
-
-    wire = b''
-    while True:
-        (wire, from_address) = await sock.recvfrom(65535)
-        if dns.query._addresses_equal(sock.family, from_address,
-                                      destination) or \
-           (dns.inet.is_multicast(destination[0]) and
-            from_address[1:] == destination[1:]):
-            break
-        if not ignore_unexpected:
-            raise UnexpectedSource('got a response from '
-                                   '%s instead of %s' % (from_address,
-                                                         destination))
-    received_time = time.time()
-    r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac,
-                              one_rr_per_rrset=one_rr_per_rrset,
-                              ignore_trailing=ignore_trailing,
-                              raise_on_truncation=raise_on_truncation)
-    return (r, received_time)
-
-async def udp(q, where, port=53, source=None, source_port=0,
-              ignore_unexpected=False, one_rr_per_rrset=False,
-              ignore_trailing=False, raise_on_truncation=False,
-              sock=None):
-    """Asynchronously return the response obtained after sending a query
-    via UDP.
-
-    *q*, a ``dns.message.Message``, the query to send
-
-    *where*, a ``str`` containing an IPv4 or IPv6 address,  where
-    to send the message.
-
-    *port*, an ``int``, the port send the message to.  The default is 53.
-
-    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
-    the source address.  The default is the wildcard address.
-
-    *source_port*, an ``int``, the port from which to send the message.
-    The default is 0.
-
-    *ignore_unexpected*, a ``bool``.  If ``True``, ignore responses from
-    unexpected sources.
-
-    *one_rr_per_rrset*, a ``bool``.  If ``True``, put each RR into its own
-    RRset.
-
-    *ignore_trailing*, a ``bool``.  If ``True``, ignore trailing
-    junk at end of the received message.
-
-    *raise_on_truncation*, a ``bool``.  If ``True``, raise an exception if
-    the TC bit is set.
-
-    *sock*, a ``trio.socket.socket``, or ``None``, the socket to use
-    for the query.  If ``None``, the default, a socket is created.  if
-    a socket is provided, the *source* and *source_port* are ignored.
-
-    Returns a ``dns.message.Message``.
-
-    """
-
-    wire = q.to_wire()
-    (af, destination, source) = \
-        dns.query._destination_and_source(None, where, port, source,
-                                          source_port)
-    # We can use an ExitStack here as exiting a trio.socket.socket does
-    # not await.
-    with contextlib.ExitStack() as stack:
-        if sock:
-            s = sock
-        else:
-            s = stack.enter_context(socket_factory(af, socket.SOCK_DGRAM, 0))
-            if source is not None:
-                await s.bind(source)
-        (_, sent_time) = await send_udp(s, wire, destination)
-        (r, received_time) = await receive_udp(s, destination,
-                                               ignore_unexpected,
-                                               one_rr_per_rrset, q.keyring,
-                                               q.mac, ignore_trailing,
-                                               raise_on_truncation)
-        if not q.is_response(r):
-            raise BadResponse
-        r.time = received_time - sent_time
-        return r
-
-async def udp_with_fallback(q, where, timeout=None, port=53, source=None,
-                            source_port=0, ignore_unexpected=False,
-                            one_rr_per_rrset=False, ignore_trailing=False):
-    """Return the response to the query, trying UDP first and falling back
-    to TCP if UDP results in a truncated response.
-
-    *q*, a ``dns.message.Message``, the query to send
-
-    *where*, a ``str`` containing an IPv4 or IPv6 address,  where
-    to send the message.
-
-    *port*, an ``int``, the port send the message to.  The default is 53.
-
-    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
-    the source address.  The default is the wildcard address.
-
-    *source_port*, an ``int``, the port from which to send the message.
-    The default is 0.
-
-    *ignore_unexpected*, a ``bool``.  If ``True``, ignore responses from
-    unexpected sources.
-
-    *one_rr_per_rrset*, a ``bool``.  If ``True``, put each RR into its own
-    RRset.
-
-    *ignore_trailing*, a ``bool``.  If ``True``, ignore trailing
-    junk at end of the received message.
-
-    Returns a (``dns.message.Message``, tcp) tuple where tcp is ``True``
-    if and only if TCP was used.
-    """
-    try:
-        response = await udp(q, where, port, source, source_port,
-                             ignore_unexpected, one_rr_per_rrset,
-                             ignore_trailing, True)
-        return (response, False)
-    except dns.message.Truncated:
-        response = await stream(q, where, False, port, source, source_port,
-                                one_rr_per_rrset, ignore_trailing)
-
-        return (response, True)
-
-# pylint: disable=redefined-outer-name
-
-async def send_stream(stream, what):
-    """Asynchronously send a DNS message to the specified stream.
-
-    *stream*, a ``trio.abc.Stream``.
-
-    *what*, a ``bytes`` or ``dns.message.Message``, the message to send.
-
-    Returns an ``(int, float)`` tuple of bytes sent and the sent time.
-    """
-
-    if isinstance(what, dns.message.Message):
-        what = what.to_wire()
-    l = len(what)
-    # copying the wire into tcpmsg is inefficient, but lets us
-    # avoid writev() or doing a short write that would get pushed
-    # onto the net
-    stream_message = struct.pack("!H", l) + what
-    sent_time = time.time()
-    await stream.send_all(stream_message)
-    return (len(stream_message), sent_time)
-
-async def read_exactly(stream, count):
-    """Read the specified number of bytes from stream.  Keep trying until we
-    either get the desired amount, or we hit EOF.
-    """
-    s = b''
-    while count > 0:
-        n = await stream.receive_some(count)
-        if n == b'':
-            raise EOFError
-        count = count - len(n)
-        s = s + n
-    return s
-
-async def receive_stream(stream, one_rr_per_rrset=False, keyring=None,
-                         request_mac=b'', ignore_trailing=False):
-    """Read a DNS message from a stream.
-
-    *stream*, a ``trio.abc.Stream``.
-
-    *one_rr_per_rrset*, a ``bool``.  If ``True``, put each RR into its own
-    RRset.
-
-    *keyring*, a ``dict``, the keyring to use for TSIG.
-
-    *request_mac*, a ``bytes``, the MAC of the request (for TSIG).
-
-    *ignore_trailing*, a ``bool``.  If ``True``, ignore trailing
-    junk at end of the received message.
-
-    Raises if the message is malformed, if network errors occur, of if
-    there is a timeout.
-
-    Returns a ``dns.message.Message`` object.
-    """
-
-    ldata = await read_exactly(stream, 2)
-    (l,) = struct.unpack("!H", ldata)
-    wire = await read_exactly(stream, l)
-    received_time = time.time()
-    r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac,
-                              one_rr_per_rrset=one_rr_per_rrset,
-                              ignore_trailing=ignore_trailing)
-    return (r, received_time)
-
-async def stream(q, where, tls=False, port=None, source=None, source_port=0,
-                 one_rr_per_rrset=False, ignore_trailing=False,
-                 stream=None, ssl_context=None, server_hostname=None):
-    """Return the response obtained after sending a query using TCP or TLS.
-
-    *q*, a ``dns.message.Message``, the query to send.
-
-    *where*, a ``str`` containing an IPv4 or IPv6 address,  where
-    to send the message.
-
-    *tls*, a ``bool``.  If ``False``, the default, the query will be
-    sent using TCP and *port* will default to 53.  If ``True``, the
-    query is sent using TLS, and *port* will default to 853.
-
-    *port*, an ``int``, the port send the message to.  The default is as
-    specified in the description for *tls*.
-
-    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
-    the source address.  The default is the wildcard address.
-
-    *source_port*, an ``int``, the port from which to send the message.
-    The default is 0.
-
-    *one_rr_per_rrset*, a ``bool``.  If ``True``, put each RR into its own
-    RRset.
-
-    *ignore_trailing*, a ``bool``.  If ``True``, ignore trailing
-    junk at end of the received message.
-
-    *stream*, a ``trio.abc.Stream``, or ``None``, the stream to use for
-    the query.  If ``None``, the default, a stream is created.  if a
-    socket is provided, it must be connected, and the *where*, *port*,
-    *tls*, *source*, *source_port*, *ssl_context*, and
-    *server_hostname* parameters are ignored.
-
-    *ssl_context*, an ``ssl.SSLContext``, the context to use when establishing
-    a TLS connection. If ``None``, the default, creates one with the default
-    configuration.  If this value is not ``None``, then the *tls* parameter
-    is treated as if it were ``True`` regardless of its value.
-
-    *server_hostname*, a ``str`` containing the server's hostname.  The
-    default is ``None``, which means that no hostname is known, and if an
-    SSL context is created, hostname checking will be disabled.
-
-    Returns a ``dns.message.Message``.
-
-    """
-    if ssl_context is not None:
-        tls = True
-    if port is None:
-        if tls:
-            port = 853
-        else:
-            port = 53
-    wire = q.to_wire()
-    # We'd like to be able to use an AsyncExitStack here, because
-    # unlike closing a socket, closing a stream requires an await, but
-    # that's a 3.7 feature, so we are forced to try ... finally.
-    sock = None
-    s = None
-    begin_time = time.time()
-    try:
-        if stream:
-            #
-            # Verify that the socket is connected, as if it's not connected,
-            # it's not writable, and the polling in send_tcp() will time out or
-            # hang forever.
-            if isinstance(stream, trio.SSLStream):
-                tsock = stream.transport_stream.socket
-            else:
-                tsock = stream.socket
-            tsock.getpeername()
-            s = stream
-        else:
-            (af, destination, source) = \
-                dns.query._destination_and_source(None, where, port, source,
-                                                  source_port)
-            sock = socket_factory(af, socket.SOCK_STREAM, 0)
-            if source is not None:
-                await sock.bind(source)
-            await sock.connect(destination)
-            s = trio.SocketStream(sock)
-            sock = None
-            if tls and ssl_context is None:
-                ssl_context = ssl.create_default_context()
-                if server_hostname is None:
-                    ssl_context.check_hostname = False
-            if ssl_context:
-                s = trio.SSLStream(s, ssl_context,
-                                   server_hostname=server_hostname)
-        await send_stream(s, wire)
-        (r, received_time) = await receive_stream(s, one_rr_per_rrset,
-                                                  q.keyring, q.mac,
-                                                  ignore_trailing)
-        if not q.is_response(r):
-            raise BadResponse
-        r.time = received_time - begin_time
-        return r
-    finally:
-        if sock:
-            sock.close()
-        if s and s != stream:
-            await s.aclose()
diff --git a/dns/trio/query.pyi b/dns/trio/query.pyi
deleted file mode 100644 (file)
index 0a5ab92..0000000
+++ /dev/null
@@ -1,33 +0,0 @@
-from typing import Optional, Dict, Any
-from . import rdatatype, rdataclass, name, message
-
-# If the ssl import works, then
-#
-#    error: Name 'ssl' already defined (by an import)
-#
-# is expected and can be ignored.
-try:
-    import ssl
-except ImportError:
-    class ssl:    # type: ignore
-        SSLContext : Dict = {}
-
-import trio
-
-def udp(q : message.Message, where : str, port=53,
-        source : Optional[str] = None, source_port : Optional[int] = 0,
-        ignore_unexpected : Optional[bool] = False,
-        one_rr_per_rrset : Optional[bool] = False,
-        ignore_trailing : Optional[bool] = False,
-        sock : Optional[trio.socket.socket] = None) -> message.Message:
-    ...
-
-def stream(q : message.Message, where : str, tls : Optional[bool] = False,
-           port=53, source : Optional[str] = None,
-           source_port : Optional[int] = 0,
-           one_rr_per_rrset : Optional[bool] = False,
-           ignore_trailing : Optional[bool] = False,
-           stream : Optional[trio.abc.Stream] = None,
-           ssl_context: Optional[ssl.SSLContext] = None,
-           server_hostname: Optional[str] = None) -> message.Message:
-    ...
diff --git a/dns/trio/resolver.pyi b/dns/trio/resolver.pyi
deleted file mode 100644 (file)
index d84419b..0000000
+++ /dev/null
@@ -1,26 +0,0 @@
-from typing import Union, Optional, List, Any, Dict
-from .. import exception, rdataclass, name, rdatatype
-
-def resolve(qname : str, rdtype : Union[int,str] = 0,
-            rdclass : Union[int,str] = 0,
-            tcp=False, source=None, raise_on_no_answer=True,
-            source_port=0, search : Optional[bool]=None):
-    ...
-
-def resolve_address(self, ipaddr: str, *args: Any, **kwargs: Optional[Dict]):
-    ...
-
-def zone_for_name(name, rdclass : int = rdataclass.IN, tcp=False,
-                  resolver : Optional[Resolver] = None):
-    ...
-
-class Resolver:
-    def __init__(self, filename : Optional[str] = '/etc/resolv.conf',
-                 configure : Optional[bool] = True):
-        self.nameservers : List[str]
-    def resolve(self, qname : str, rdtype : Union[int,str] = rdatatype.A,
-                rdclass : Union[int,str] = rdataclass.IN,
-                tcp : bool = False, source : Optional[str] = None,
-                raise_on_no_answer=True, source_port : int = 0,
-                 search : Optional[bool]=None):
-        ...
index 44a1dd70852eec1dab26278b3f4a39223baf0c00..b33fe00d97c6e722599b64eeb3ac086800752c37 100644 (file)
@@ -14,7 +14,7 @@ requests-toolbelt = {version="^0.9.1", optional=true}
 requests = {version="^2.23.0", optional=true}
 idna = {version="^2.1", optional=true}
 cryptography = {version="^2.6", optional=true}
-trio = {version="^0.14.0", optional=true}
+trio = {version="^0.14", optional=true}
 
 [tool.poetry.dev-dependencies]
 mypy = "^0.770"
@@ -28,6 +28,7 @@ doh = ['requests', 'requests-toolbelt']
 idna = ['idna']
 dnssec = ['cryptography']
 trio = ['trio']
+curio = ['curio', 'sniffio']
 
 [build-system]
 requires = ["poetry>=0.12"]
index 50a3da101d5f8feb93da26170d0aca3e03975ad5..5dfc61e1171280e84c95e474b40f263a52649c2a 100755 (executable)
--- a/setup.py
+++ b/setup.py
@@ -50,7 +50,7 @@ direct manipulation of DNS zones, messages, names, and records.""",
     'license' : 'ISC',
     'url' : 'http://www.dnspython.org',
     'packages' : ['dns', 'dns.rdtypes', 'dns.rdtypes.IN', 'dns.rdtypes.ANY',
-                  'dns.rdtypes.CH', 'dns.trio'],
+                  'dns.rdtypes.CH'],
     'package_data' : {'dns': ['py.typed']},
     'download_url' : \
     'http://www.dnspython.org/kits/{}/dnspython-{}.tar.gz'.format(version, version),
diff --git a/tests/test_async.py b/tests/test_async.py
new file mode 100644 (file)
index 0000000..c09941b
--- /dev/null
@@ -0,0 +1,219 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+import asyncio
+import socket
+import unittest
+
+import dns.asyncbackend
+import dns.asyncquery
+import dns.asyncresolver
+import dns.message
+import dns.name
+import dns.rdataclass
+import dns.rdatatype
+import dns.resolver
+
+# Some tests require the internet to be available to run, so let's
+# skip those if it's not there.
+_network_available = True
+try:
+    socket.gethostbyname('dnspython.org')
+except socket.gaierror:
+    _network_available = False
+
+@unittest.skipIf(not _network_available, "Internet not reachable")
+class AsyncTests(unittest.TestCase):
+
+    def setUp(self):
+        self.backend = dns.asyncbackend.set_default_backend('asyncio')
+
+    def async_run(self, afunc):
+        return asyncio.run(afunc())
+
+    def testResolve(self):
+        async def run():
+            answer = await dns.asyncresolver.resolve('dns.google.', 'A')
+            return set([rdata.address for rdata in answer])
+        seen = self.async_run(run)
+        self.assertTrue('8.8.8.8' in seen)
+        self.assertTrue('8.8.4.4' in seen)
+
+    def testResolveAddress(self):
+        async def run():
+            return await dns.asyncresolver.resolve_address('8.8.8.8')
+        answer = self.async_run(run)
+        dnsgoogle = dns.name.from_text('dns.google.')
+        self.assertEqual(answer[0].target, dnsgoogle)
+
+    def testZoneForName1(self):
+        async def run():
+            name = dns.name.from_text('www.dnspython.org.')
+            return await dns.asyncresolver.zone_for_name(name)
+        ezname = dns.name.from_text('dnspython.org.')
+        zname = self.async_run(run)
+        self.assertEqual(zname, ezname)
+
+    def testZoneForName2(self):
+        async def run():
+            name = dns.name.from_text('a.b.www.dnspython.org.')
+            return await dns.asyncresolver.zone_for_name(name)
+        ezname = dns.name.from_text('dnspython.org.')
+        zname = self.async_run(run)
+        self.assertEqual(zname, ezname)
+
+    def testZoneForName3(self):
+        async def run():
+            name = dns.name.from_text('dnspython.org.')
+            return await dns.asyncresolver.zone_for_name(name)
+        ezname = dns.name.from_text('dnspython.org.')
+        zname = self.async_run(run)
+        self.assertEqual(zname, ezname)
+
+    def testZoneForName4(self):
+        def bad():
+            name = dns.name.from_text('dnspython.org', None)
+            async def run():
+                return await dns.asyncresolver.zone_for_name(name)
+            self.async_run(run)
+        self.assertRaises(dns.resolver.NotAbsolute, bad)
+
+    def testQueryUDP(self):
+        qname = dns.name.from_text('dns.google.')
+        async def run():
+            q = dns.message.make_query(qname, dns.rdatatype.A)
+            return await dns.asyncquery.udp(q, '8.8.8.8')
+        response = self.async_run(run)
+        rrs = response.get_rrset(response.answer, qname,
+                                 dns.rdataclass.IN, dns.rdatatype.A)
+        self.assertTrue(rrs is not None)
+        seen = set([rdata.address for rdata in rrs])
+        self.assertTrue('8.8.8.8' in seen)
+        self.assertTrue('8.8.4.4' in seen)
+
+    def testQueryUDPWithSocket(self):
+        qname = dns.name.from_text('dns.google.')
+        async def run():
+            async with await self.backend.make_socket(socket.AF_INET,
+                                                      socket.SOCK_DGRAM) as s:
+                q = dns.message.make_query(qname, dns.rdatatype.A)
+                return await dns.asyncquery.udp(q, '8.8.8.8', sock=s)
+        response = self.async_run(run)
+        rrs = response.get_rrset(response.answer, qname,
+                                 dns.rdataclass.IN, dns.rdatatype.A)
+        self.assertTrue(rrs is not None)
+        seen = set([rdata.address for rdata in rrs])
+        self.assertTrue('8.8.8.8' in seen)
+        self.assertTrue('8.8.4.4' in seen)
+
+    def testQueryTCP(self):
+        qname = dns.name.from_text('dns.google.')
+        async def run():
+            q = dns.message.make_query(qname, dns.rdatatype.A)
+            return await dns.asyncquery.tcp(q, '8.8.8.8')
+        response = self.async_run(run)
+        rrs = response.get_rrset(response.answer, qname,
+                                 dns.rdataclass.IN, dns.rdatatype.A)
+        self.assertTrue(rrs is not None)
+        seen = set([rdata.address for rdata in rrs])
+        self.assertTrue('8.8.8.8' in seen)
+        self.assertTrue('8.8.4.4' in seen)
+
+    def testQueryTCPWithSocket(self):
+        qname = dns.name.from_text('dns.google.')
+        async def run():
+            async with await self.backend.make_socket(socket.AF_INET,
+                                                      socket.SOCK_STREAM, 0,
+                                                      None,
+                                                      ('8.8.8.8', 53)) as s:
+                q = dns.message.make_query(qname, dns.rdatatype.A)
+                return await dns.asyncquery.tcp(q, '8.8.8.8', sock=s)
+        response = self.async_run(run)
+        rrs = response.get_rrset(response.answer, qname,
+                                 dns.rdataclass.IN, dns.rdatatype.A)
+        self.assertTrue(rrs is not None)
+        seen = set([rdata.address for rdata in rrs])
+        self.assertTrue('8.8.8.8' in seen)
+        self.assertTrue('8.8.4.4' in seen)
+
+    # def testQueryTLS(self):
+    #     qname = dns.name.from_text('dns.google.')
+    #     async def run():
+    #         q = dns.message.make_query(qname, dns.rdatatype.A)
+    #         return await dns.asyncquery.stream(q, '8.8.8.8', True)
+    #     response = self.async_run(run)
+    #     rrs = response.get_rrset(response.answer, qname,
+    #                              dns.rdataclass.IN, dns.rdatatype.A)
+    #     self.assertTrue(rrs is not None)
+    #     seen = set([rdata.address for rdata in rrs])
+    #     self.assertTrue('8.8.8.8' in seen)
+    #     self.assertTrue('8.8.4.4' in seen)
+
+    # def testQueryTLSWithSocket(self):
+    #     qname = dns.name.from_text('dns.google.')
+    #     async def run():
+    #         async with await trio.open_ssl_over_tcp_stream('8.8.8.8',
+    #                                                        853) as s:
+    #             q = dns.message.make_query(qname, dns.rdatatype.A)
+    #             return await dns.asyncquery.stream(q, '8.8.8.8', stream=s)
+    #     response = self.async_run(run)
+    #     rrs = response.get_rrset(response.answer, qname,
+    #                              dns.rdataclass.IN, dns.rdatatype.A)
+    #     self.assertTrue(rrs is not None)
+    #     seen = set([rdata.address for rdata in rrs])
+    #     self.assertTrue('8.8.8.8' in seen)
+    #     self.assertTrue('8.8.4.4' in seen)
+
+    def testQueryUDPFallback(self):
+        qname = dns.name.from_text('.')
+        async def run():
+            q = dns.message.make_query(qname, dns.rdatatype.DNSKEY)
+            return await dns.asyncquery.udp_with_fallback(q, '8.8.8.8')
+        (_, tcp) = self.async_run(run)
+        self.assertTrue(tcp)
+
+    def testQueryUDPFallbackNoFallback(self):
+        qname = dns.name.from_text('dns.google.')
+        async def run():
+            q = dns.message.make_query(qname, dns.rdatatype.A)
+            return await dns.asyncquery.udp_with_fallback(q, '8.8.8.8')
+        (_, tcp) = self.async_run(run)
+        self.assertFalse(tcp)
+
+try:
+    import trio
+
+    class TrioAsyncTests(AsyncTests):
+        def setUp(self):
+            self.backend = dns.asyncbackend.set_default_backend('trio')
+
+        def async_run(self, afunc):
+            return trio.run(afunc)
+except ImportError:
+    pass
+
+try:
+    import curio
+
+    class CurioAsyncTests(AsyncTests):
+        def setUp(self):
+            self.backend = dns.asyncbackend.set_default_backend('curio')
+
+        def async_run(self, afunc):
+            return curio.run(afunc)
+except ImportError:
+    pass
diff --git a/tests/test_trio.py b/tests/test_trio.py
deleted file mode 100644 (file)
index 8304a1f..0000000
+++ /dev/null
@@ -1,189 +0,0 @@
-# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
-
-# Copyright (C) 2003-2017 Nominum, Inc.
-#
-# Permission to use, copy, modify, and distribute this software and its
-# documentation for any purpose with or without fee is hereby granted,
-# provided that the above copyright notice and this permission notice
-# appear in all copies.
-#
-# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
-# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
-# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
-# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
-# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
-# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
-# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
-
-import socket
-import unittest
-
-try:
-    import trio
-    import trio.socket
-
-    import dns.message
-    import dns.name
-    import dns.rdataclass
-    import dns.rdatatype
-    import dns.trio.query
-    import dns.trio.resolver
-
-    # Some tests require the internet to be available to run, so let's
-    # skip those if it's not there.
-    _network_available = True
-    try:
-        socket.gethostbyname('dnspython.org')
-    except socket.gaierror:
-        _network_available = False
-
-    @unittest.skipIf(not _network_available, "Internet not reachable")
-    class TrioTests(unittest.TestCase):
-
-        def testResolve(self):
-            async def run():
-                answer = await dns.trio.resolver.resolve('dns.google.', 'A')
-                return set([rdata.address for rdata in answer])
-            seen = trio.run(run)
-            self.assertTrue('8.8.8.8' in seen)
-            self.assertTrue('8.8.4.4' in seen)
-
-        def testResolveAddress(self):
-            async def run():
-                return await dns.trio.resolver.resolve_address('8.8.8.8')
-            answer = trio.run(run)
-            dnsgoogle = dns.name.from_text('dns.google.')
-            self.assertEqual(answer[0].target, dnsgoogle)
-
-        def testZoneForName1(self):
-            async def run():
-                name = dns.name.from_text('www.dnspython.org.')
-                return await dns.trio.resolver.zone_for_name(name)
-            ezname = dns.name.from_text('dnspython.org.')
-            zname = trio.run(run)
-            self.assertEqual(zname, ezname)
-
-        def testZoneForName2(self):
-            async def run():
-                name = dns.name.from_text('a.b.www.dnspython.org.')
-                return await dns.trio.resolver.zone_for_name(name)
-            ezname = dns.name.from_text('dnspython.org.')
-            zname = trio.run(run)
-            self.assertEqual(zname, ezname)
-
-        def testZoneForName3(self):
-            async def run():
-                name = dns.name.from_text('dnspython.org.')
-                return await dns.trio.resolver.zone_for_name(name)
-            ezname = dns.name.from_text('dnspython.org.')
-            zname = trio.run(run)
-            self.assertEqual(zname, ezname)
-
-        def testZoneForName4(self):
-            def bad():
-                name = dns.name.from_text('dnspython.org', None)
-                async def run():
-                    return await dns.trio.resolver.zone_for_name(name)
-                trio.run(run)
-            self.assertRaises(dns.resolver.NotAbsolute, bad)
-
-        def testQueryUDP(self):
-            qname = dns.name.from_text('dns.google.')
-            async def run():
-                q = dns.message.make_query(qname, dns.rdatatype.A)
-                return await dns.trio.query.udp(q, '8.8.8.8')
-            response = trio.run(run)
-            rrs = response.get_rrset(response.answer, qname,
-                                     dns.rdataclass.IN, dns.rdatatype.A)
-            self.assertTrue(rrs is not None)
-            seen = set([rdata.address for rdata in rrs])
-            self.assertTrue('8.8.8.8' in seen)
-            self.assertTrue('8.8.4.4' in seen)
-
-        def testQueryUDPWithSocket(self):
-            qname = dns.name.from_text('dns.google.')
-            async def run():
-                with trio.socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
-                    q = dns.message.make_query(qname, dns.rdatatype.A)
-                    return await dns.trio.query.udp(q, '8.8.8.8', sock=s)
-            response = trio.run(run)
-            rrs = response.get_rrset(response.answer, qname,
-                                     dns.rdataclass.IN, dns.rdatatype.A)
-            self.assertTrue(rrs is not None)
-            seen = set([rdata.address for rdata in rrs])
-            self.assertTrue('8.8.8.8' in seen)
-            self.assertTrue('8.8.4.4' in seen)
-
-        def testQueryTCP(self):
-            qname = dns.name.from_text('dns.google.')
-            async def run():
-                q = dns.message.make_query(qname, dns.rdatatype.A)
-                return await dns.trio.query.stream(q, '8.8.8.8')
-            response = trio.run(run)
-            rrs = response.get_rrset(response.answer, qname,
-                                     dns.rdataclass.IN, dns.rdatatype.A)
-            self.assertTrue(rrs is not None)
-            seen = set([rdata.address for rdata in rrs])
-            self.assertTrue('8.8.8.8' in seen)
-            self.assertTrue('8.8.4.4' in seen)
-
-        def testQueryTCPWithSocket(self):
-            qname = dns.name.from_text('dns.google.')
-            async def run():
-                async with await trio.open_tcp_stream('8.8.8.8', 53) as s:
-                    q = dns.message.make_query(qname, dns.rdatatype.A)
-                    return await dns.trio.query.stream(q, '8.8.8.8', stream=s)
-            response = trio.run(run)
-            rrs = response.get_rrset(response.answer, qname,
-                                     dns.rdataclass.IN, dns.rdatatype.A)
-            self.assertTrue(rrs is not None)
-            seen = set([rdata.address for rdata in rrs])
-            self.assertTrue('8.8.8.8' in seen)
-            self.assertTrue('8.8.4.4' in seen)
-
-        def testQueryTLS(self):
-            qname = dns.name.from_text('dns.google.')
-            async def run():
-                q = dns.message.make_query(qname, dns.rdatatype.A)
-                return await dns.trio.query.stream(q, '8.8.8.8', True)
-            response = trio.run(run)
-            rrs = response.get_rrset(response.answer, qname,
-                                     dns.rdataclass.IN, dns.rdatatype.A)
-            self.assertTrue(rrs is not None)
-            seen = set([rdata.address for rdata in rrs])
-            self.assertTrue('8.8.8.8' in seen)
-            self.assertTrue('8.8.4.4' in seen)
-
-        def testQueryTLSWithSocket(self):
-            qname = dns.name.from_text('dns.google.')
-            async def run():
-                async with await trio.open_ssl_over_tcp_stream('8.8.8.8',
-                                                               853) as s:
-                    q = dns.message.make_query(qname, dns.rdatatype.A)
-                    return await dns.trio.query.stream(q, '8.8.8.8', stream=s)
-            response = trio.run(run)
-            rrs = response.get_rrset(response.answer, qname,
-                                     dns.rdataclass.IN, dns.rdatatype.A)
-            self.assertTrue(rrs is not None)
-            seen = set([rdata.address for rdata in rrs])
-            self.assertTrue('8.8.8.8' in seen)
-            self.assertTrue('8.8.4.4' in seen)
-
-        def testQueryUDPFallback(self):
-            qname = dns.name.from_text('.')
-            async def run():
-                q = dns.message.make_query(qname, dns.rdatatype.DNSKEY)
-                return await dns.trio.query.udp_with_fallback(q, '8.8.8.8')
-            (_, tcp) = trio.run(run)
-            self.assertTrue(tcp)
-
-        def testQueryUDPFallbackNoFallback(self):
-            qname = dns.name.from_text('dns.google.')
-            async def run():
-                q = dns.message.make_query(qname, dns.rdatatype.A)
-                return await dns.trio.query.udp_with_fallback(q, '8.8.8.8')
-            (_, tcp) = trio.run(run)
-            self.assertFalse(tcp)
-
-except ModuleNotFoundError:
-    pass