]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Start trio async support.
authorBob Halley <halley@dnspython.org>
Sat, 16 May 2020 23:51:28 +0000 (16:51 -0700)
committerBob Halley <halley@dnspython.org>
Sat, 16 May 2020 23:52:13 +0000 (16:52 -0700)
dns/trio/__init__.py [new file with mode: 0644]
dns/trio/query.py [new file with mode: 0644]
examples/trio.py [new file with mode: 0644]

diff --git a/dns/trio/__init__.py b/dns/trio/__init__.py
new file mode 100644 (file)
index 0000000..7aa0c7f
--- /dev/null
@@ -0,0 +1,7 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+"""trio async I/O library helpers"""
+
+__all__ = [
+    'query',
+]
diff --git a/dns/trio/query.py b/dns/trio/query.py
new file mode 100644 (file)
index 0000000..494d8d8
--- /dev/null
@@ -0,0 +1,280 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+"""trio async I/O library query support"""
+
+import socket
+import struct
+import trio
+import trio.socket
+
+import dns.exception
+import dns.inet
+import dns.name
+import dns.message
+import dns.query
+import dns.rcode
+import dns.rdataclass
+import dns.rdatatype
+
+ssl = dns.query.ssl
+
+# Function used to create a socket.  Can be overridden if needed in special
+# situations.
+socket_factory = trio.socket.socket
+
+async def send_udp(sock, what, destination):
+    """Asynchronously send a DNS message to the specified UDP socket.
+
+    *sock*, a ``trio.socket``.
+
+    *what*, a ``bytes`` or ``dns.message.Message``, the message to send.
+
+    *destination*, a destination tuple appropriate for the address family
+    of the socket, specifying where to send the query.
+
+    Returns an ``(int, float)`` tuple of bytes sent and the sent time.
+    """
+
+    if isinstance(what, dns.message.Message):
+        what = what.to_wire()
+    sent_time = trio.current_time()
+    n = await sock.sendto(what, destination)
+    return (n, sent_time)
+
+
+async def receive_udp(sock, destination, ignore_unexpected=False,
+                      one_rr_per_rrset=False, keyring=None, request_mac=b'',
+                      ignore_trailing=False):
+    """Asynchronously read a DNS message from a UDP socket.
+
+    *sock*, a ``trio.socket``.
+
+    *destination*, a destination tuple appropriate for the address family
+    of the socket, specifying where the associated query was sent.
+
+    *ignore_unexpected*, a ``bool``.  If ``True``, ignore responses from
+    unexpected sources.
+
+    *one_rr_per_rrset*, a ``bool``.  If ``True``, put each RR into its own
+    RRset.
+
+    *keyring*, a ``dict``, the keyring to use for TSIG.
+
+    *request_mac*, a ``bytes``, the MAC of the request (for TSIG).
+
+    *ignore_trailing*, a ``bool``.  If ``True``, ignore trailing
+    junk at end of the received message.
+
+    Raises if the message is malformed, if network errors occur, of if
+    there is a timeout.
+
+    Returns a ``dns.message.Message`` object.
+    """
+
+    wire = b''
+    while True:
+        (wire, from_address) = await sock.recvfrom(65535)
+        if dns.query._addresses_equal(sock.family, from_address,
+                                      destination) or \
+           (dns.inet.is_multicast(destination[0]) and
+            from_address[1:] == destination[1:]):
+            break
+        if not ignore_unexpected:
+            raise UnexpectedSource('got a response from '
+                                   '%s instead of %s' % (from_address,
+                                                         destination))
+    received_time = trio.current_time()
+    r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac,
+                              one_rr_per_rrset=one_rr_per_rrset,
+                              ignore_trailing=ignore_trailing)
+    return (r, received_time)
+
+async def udp(q, where, port=53, source=None, source_port=0,
+              ignore_unexpected=False, one_rr_per_rrset=False,
+              ignore_trailing=False):
+    """Asynchronously return the response obtained after sending a query
+    via UDP.
+
+    *q*, a ``dns.message.Message``, the query to send
+
+    *where*, a ``str`` containing an IPv4 or IPv6 address,  where
+    to send the message.
+
+    *port*, an ``int``, the port send the message to.  The default is 53.
+
+    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
+    the source address.  The default is the wildcard address.
+
+    *source_port*, an ``int``, the port from which to send the message.
+    The default is 0.
+
+    *ignore_unexpected*, a ``bool``.  If ``True``, ignore responses from
+    unexpected sources.
+
+    *one_rr_per_rrset*, a ``bool``.  If ``True``, put each RR into its own
+    RRset.
+
+    *ignore_trailing*, a ``bool``.  If ``True``, ignore trailing
+    junk at end of the received message.
+
+    Returns a ``dns.message.Message``.
+    """
+
+    wire = q.to_wire()
+    (af, destination, source) = \
+        dns.query._destination_and_source(None, where, port, source,
+                                          source_port)
+    with socket_factory(af, socket.SOCK_DGRAM, 0) as s:
+        received_time = None
+        sent_time = None
+        if source is not None:
+            await s.bind(source)
+        (_, sent_time) = await send_udp(s, wire, destination)
+        (r, received_time) = await receive_udp(s, destination,
+                                               ignore_unexpected,
+                                               one_rr_per_rrset, q.keyring,
+                                               q.mac, ignore_trailing)
+        if not q.is_response(r):
+            raise BadResponse
+        r.time = received_time - sent_time
+        return r
+
+async def send_stream(stream, what):
+    """Asynchronously send a DNS message to the specified stream.
+
+    *stream*, a ``trio.abc.Stream``.
+
+    *what*, a ``bytes`` or ``dns.message.Message``, the message to send.
+
+    Returns an ``(int, float)`` tuple of bytes sent and the sent time.
+    """
+
+    if isinstance(what, dns.message.Message):
+        what = what.to_wire()
+    l = len(what)
+    # copying the wire into tcpmsg is inefficient, but lets us
+    # avoid writev() or doing a short write that would get pushed
+    # onto the net
+    stream_message = struct.pack("!H", l) + what
+    sent_time = trio.current_time()
+    await stream.send_all(stream_message)
+    return (len(stream_message), sent_time)
+
+async def _read_exactly(stream, count):
+    """Read the specified number of bytes from stream.  Keep trying until we
+    either get the desired amount, or we hit EOF.
+    """
+    s = b''
+    while count > 0:
+        n = await stream.receive_some(count)
+        if n == b'':
+            raise EOFError
+        count = count - len(n)
+        s = s + n
+    return s
+
+async def receive_stream(stream, one_rr_per_rrset=False, keyring=None,
+                         request_mac=b'', ignore_trailing=False):
+    """Read a DNS message from a stream.
+
+    *stream*, a ``trio.abc.Stream``.
+
+    *one_rr_per_rrset*, a ``bool``.  If ``True``, put each RR into its own
+    RRset.
+
+    *keyring*, a ``dict``, the keyring to use for TSIG.
+
+    *request_mac*, a ``bytes``, the MAC of the request (for TSIG).
+
+    *ignore_trailing*, a ``bool``.  If ``True``, ignore trailing
+    junk at end of the received message.
+
+    Raises if the message is malformed, if network errors occur, of if
+    there is a timeout.
+
+    Returns a ``dns.message.Message`` object.
+    """
+
+    ldata = await _read_exactly(stream, 2)
+    (l,) = struct.unpack("!H", ldata)
+    wire = await _read_exactly(stream, l)
+    received_time = trio.current_time()
+    r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac,
+                              one_rr_per_rrset=one_rr_per_rrset,
+                              ignore_trailing=ignore_trailing)
+    return (r, received_time)
+
+async def stream(q, where, tls=False, port=None, source=None, source_port=0,
+                 one_rr_per_rrset=False, ignore_trailing=False,
+                 ssl_context=None, server_hostname=None):
+    """Return the response obtained after sending a query using TCP or TLS.
+
+    *q*, a ``dns.message.Message``, the query to send.
+
+    *where*, a ``str`` containing an IPv4 or IPv6 address,  where
+    to send the message.
+
+    *tls*, a ``bool``.  If ``False``, the default, the query will be
+    sent using TCP and *port* will default to 53.  If ``True``, the
+    query is sent using TLS, and *port* will default to 853.
+
+    *port*, an ``int``, the port send the message to.  The default is as
+    specified in the description for *tls*.
+
+    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
+    the source address.  The default is the wildcard address.
+
+    *source_port*, an ``int``, the port from which to send the message.
+    The default is 0.
+
+    *one_rr_per_rrset*, a ``bool``.  If ``True``, put each RR into its own
+    RRset.
+
+    *ignore_trailing*, a ``bool``.  If ``True``, ignore trailing
+    junk at end of the received message.
+
+    *ssl_context*, an ``ssl.SSLContext``, the context to use when establishing
+    a TLS connection. If ``None``, the default, creates one with the default
+    configuration.  If this value is not ``None``, then the *tls* parameter
+    is treated as if it were ``True`` regardless of its value.
+
+    *server_hostname*, a ``str`` containing the server's hostname.  The
+    default is ``None``, which means that no hostname is known, and if an
+    SSL context is created, hostname checking will be disabled.
+
+    Returns a ``dns.message.Message``.
+    """
+
+    if ssl_context is not None:
+        tls = True
+    if port is None:
+        if tls:
+            port = 853
+        else:
+            port = 53
+    wire = q.to_wire()
+    (af, destination, source) = \
+        dns.query._destination_and_source(None, where, port, source,
+                                          source_port)
+    with socket_factory(af, socket.SOCK_STREAM, 0) as s:
+        begin_time = trio.current_time()
+        if source is not None:
+            await s.bind(source)
+        await s.connect(destination)
+        stream = trio.SocketStream(s)
+        if tls and ssl_context is None:
+            ssl_context = ssl.create_default_context()
+            if server_hostname is None:
+                ssl_context.check_hostname = False
+        if ssl_context:
+            stream = trio.SSLStream(stream, ssl_context,
+                                    server_hostname=server_hostname)
+        async with stream:
+            await send_stream(stream, wire)
+            (r, received_time) = await receive_stream(stream, one_rr_per_rrset,
+                                                      q.keyring, q.mac,
+                                                      ignore_trailing)
+            if not q.is_response(r):
+                raise BadResponse
+            r.time = received_time - begin_time
+            return r
diff --git a/examples/trio.py b/examples/trio.py
new file mode 100644 (file)
index 0000000..5bf889c
--- /dev/null
@@ -0,0 +1,24 @@
+
+import sys
+import trio
+
+import dns.message
+import dns.trio.query
+
+async def main():
+    if len(sys.argv) > 1:
+        host = sys.argv[0]
+    else:
+        host = 'www.dnspython.org'
+    q = dns.message.make_query(host, 'A')
+    r = await dns.trio.query.udp(q, '8.8.8.8')
+    print(r)
+    q = dns.message.make_query(host, 'A')
+    r = await dns.trio.query.stream(q, '8.8.8.8')
+    print(r)
+    q = dns.message.make_query(host, 'A')
+    r = await dns.trio.query.stream(q, '8.8.8.8', tls=True)
+    print(r)
+
+if __name__ == '__main__':
+    trio.run(main)