]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Allow a socket to be passed to udp(), and a stream to stream().
authorBob Halley <halley@dnspython.org>
Sat, 6 Jun 2020 22:43:06 +0000 (15:43 -0700)
committerBob Halley <halley@dnspython.org>
Sat, 6 Jun 2020 22:43:06 +0000 (15:43 -0700)
dns/trio/query.py
dns/trio/query.pyi
tests/test_trio.py

index 11af1744c330d26e4c602a9728ed6ff964d1c8af..53b8fe50b9329543959b5feeea30d7e7ec33c675 100644 (file)
@@ -2,6 +2,7 @@
 
 """trio async I/O library query support"""
 
+import contextlib
 import socket
 import struct
 import time
@@ -27,7 +28,7 @@ socket_factory = trio.socket.socket
 async def send_udp(sock, what, destination):
     """Asynchronously send a DNS message to the specified UDP socket.
 
-    *sock*, a ``trio.socket``.
+    *sock*, a ``trio.socket.socket``.
 
     *what*, a ``bytes`` or ``dns.message.Message``, the message to send.
 
@@ -49,7 +50,7 @@ async def receive_udp(sock, destination, ignore_unexpected=False,
                       ignore_trailing=False, raise_on_truncation=False):
     """Asynchronously read a DNS message from a UDP socket.
 
-    *sock*, a ``trio.socket``.
+    *sock*, a ``trio.socket.socket``.
 
     *destination*, a destination tuple appropriate for the address family
     of the socket, specifying where the associated query was sent.
@@ -97,7 +98,8 @@ async def receive_udp(sock, destination, ignore_unexpected=False,
 
 async def udp(q, where, port=53, source=None, source_port=0,
               ignore_unexpected=False, one_rr_per_rrset=False,
-              ignore_trailing=False, raise_on_truncation=False):
+              ignore_trailing=False, raise_on_truncation=False,
+              sock=None):
     """Asynchronously return the response obtained after sending a query
     via UDP.
 
@@ -126,18 +128,27 @@ async def udp(q, where, port=53, source=None, source_port=0,
     *raise_on_truncation*, a ``bool``.  If ``True``, raise an exception if
     the TC bit is set.
 
+    *sock*, a ``trio.socket.socket``, or ``None``, the socket to use
+    for the query.  If ``None``, the default, a socket is created.  if
+    a socket is provided, the *source* and *source_port* are ignored.
+
     Returns a ``dns.message.Message``.
+
     """
 
     wire = q.to_wire()
     (af, destination, source) = \
         dns.query._destination_and_source(None, where, port, source,
                                           source_port)
-    with socket_factory(af, socket.SOCK_DGRAM, 0) as s:
-        received_time = None
-        sent_time = None
-        if source is not None:
-            await s.bind(source)
+    # We can use an ExitStack here as exiting a trio.socket.socket does
+    # not await.
+    with contextlib.ExitStack() as stack:
+        if sock:
+            s = sock
+        else:
+            s = stack.enter_context(socket_factory(af, socket.SOCK_DGRAM, 0))
+            if source is not None:
+                await s.bind(source)
         (_, sent_time) = await send_udp(s, wire, destination)
         (r, received_time) = await receive_udp(s, destination,
                                                ignore_unexpected,
@@ -260,7 +271,7 @@ async def receive_stream(stream, one_rr_per_rrset=False, keyring=None,
 
 async def stream(q, where, tls=False, port=None, source=None, source_port=0,
                  one_rr_per_rrset=False, ignore_trailing=False,
-                 ssl_context=None, server_hostname=None):
+                 stream=None, ssl_context=None, server_hostname=None):
     """Return the response obtained after sending a query using TCP or TLS.
 
     *q*, a ``dns.message.Message``, the query to send.
@@ -287,6 +298,12 @@ async def stream(q, where, tls=False, port=None, source=None, source_port=0,
     *ignore_trailing*, a ``bool``.  If ``True``, ignore trailing
     junk at end of the received message.
 
+    *stream*, a ``trio.abc.Stream``, or ``None``, the stream to use for
+    the query.  If ``None``, the default, a stream is created.  if a
+    socket is provided, it must be connected, and the *where*, *port*,
+    *tls*, *source*, *source_port*, *ssl_context*, and
+    *server_hostname* parameters are ignored.
+
     *ssl_context*, an ``ssl.SSLContext``, the context to use when establishing
     a TLS connection. If ``None``, the default, creates one with the default
     configuration.  If this value is not ``None``, then the *tls* parameter
@@ -297,8 +314,8 @@ async def stream(q, where, tls=False, port=None, source=None, source_port=0,
     SSL context is created, hostname checking will be disabled.
 
     Returns a ``dns.message.Message``.
-    """
 
+    """
     if ssl_context is not None:
         tls = True
     if port is None:
@@ -307,28 +324,50 @@ async def stream(q, where, tls=False, port=None, source=None, source_port=0,
         else:
             port = 53
     wire = q.to_wire()
-    (af, destination, source) = \
-        dns.query._destination_and_source(None, where, port, source,
-                                          source_port)
-    with socket_factory(af, socket.SOCK_STREAM, 0) as s:
-        begin_time = time.time()
-        if source is not None:
-            await s.bind(source)
-        await s.connect(destination)
-        stream = trio.SocketStream(s)
-        if tls and ssl_context is None:
-            ssl_context = ssl.create_default_context()
-            if server_hostname is None:
-                ssl_context.check_hostname = False
-        if ssl_context:
-            stream = trio.SSLStream(stream, ssl_context,
-                                    server_hostname=server_hostname)
-        async with stream:
-            await send_stream(stream, wire)
-            (r, received_time) = await receive_stream(stream, one_rr_per_rrset,
-                                                      q.keyring, q.mac,
-                                                      ignore_trailing)
-            if not q.is_response(r):
-                raise BadResponse
-            r.time = received_time - begin_time
-            return r
+    # We'd like to be able to use an AsyncExitStack here, but that's a 3.7
+    # feature, so we are forced to try ... finally.
+    sock = None
+    s = None
+    begin_time = time.time()
+    try:
+        if stream:
+            #
+            # Verify that the socket is connected, as if it's not connected,
+            # it's not writable, and the polling in send_tcp() will time out or
+            # hang forever.
+            if isinstance(stream, trio.SSLStream):
+                tsock = stream.transport_stream.socket
+            else:
+                tsock = stream.socket
+            tsock.getpeername()
+            s = stream
+        else:
+            (af, destination, source) = \
+                dns.query._destination_and_source(None, where, port, source,
+                                                  source_port)
+            sock = socket_factory(af, socket.SOCK_STREAM, 0)
+            if source is not None:
+                await sock.bind(source)
+            await sock.connect(destination)
+            s = trio.SocketStream(sock)
+            sock = None
+            if tls and ssl_context is None:
+                ssl_context = ssl.create_default_context()
+                if server_hostname is None:
+                    ssl_context.check_hostname = False
+            if ssl_context:
+                s = trio.SSLStream(s, ssl_context,
+                                   server_hostname=server_hostname)
+        await send_stream(s, wire)
+        (r, received_time) = await receive_stream(s, one_rr_per_rrset,
+                                                  q.keyring, q.mac,
+                                                  ignore_trailing)
+        if not q.is_response(r):
+            raise BadResponse
+        r.time = received_time - begin_time
+        return r
+    finally:
+        if sock:
+            sock.close()
+        if s and s != stream:
+            await s.aclose()
index c51f000b7cc95a32d1d0de712fb1383cfed2c279..0a5ab923203cd70887204029dcde06208e5e26aa 100644 (file)
@@ -12,11 +12,14 @@ except ImportError:
     class ssl:    # type: ignore
         SSLContext : Dict = {}
 
+import trio
+
 def udp(q : message.Message, where : str, port=53,
         source : Optional[str] = None, source_port : Optional[int] = 0,
         ignore_unexpected : Optional[bool] = False,
         one_rr_per_rrset : Optional[bool] = False,
-        ignore_trailing : Optional[bool] = False) -> message.Message:
+        ignore_trailing : Optional[bool] = False,
+        sock : Optional[trio.socket.socket] = None) -> message.Message:
     ...
 
 def stream(q : message.Message, where : str, tls : Optional[bool] = False,
@@ -24,6 +27,7 @@ def stream(q : message.Message, where : str, tls : Optional[bool] = False,
            source_port : Optional[int] = 0,
            one_rr_per_rrset : Optional[bool] = False,
            ignore_trailing : Optional[bool] = False,
+           stream : Optional[trio.abc.Stream] = None,
            ssl_context: Optional[ssl.SSLContext] = None,
            server_hostname: Optional[str] = None) -> message.Message:
     ...
index d519844d7306d9cee143ac1003491c0e28166a08..8304a1f812b9a45c807e76e76968cc3eb2710cb2 100644 (file)
@@ -20,6 +20,7 @@ import unittest
 
 try:
     import trio
+    import trio.socket
 
     import dns.message
     import dns.name
@@ -99,6 +100,20 @@ try:
             self.assertTrue('8.8.8.8' in seen)
             self.assertTrue('8.8.4.4' in seen)
 
+        def testQueryUDPWithSocket(self):
+            qname = dns.name.from_text('dns.google.')
+            async def run():
+                with trio.socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
+                    q = dns.message.make_query(qname, dns.rdatatype.A)
+                    return await dns.trio.query.udp(q, '8.8.8.8', sock=s)
+            response = trio.run(run)
+            rrs = response.get_rrset(response.answer, qname,
+                                     dns.rdataclass.IN, dns.rdatatype.A)
+            self.assertTrue(rrs is not None)
+            seen = set([rdata.address for rdata in rrs])
+            self.assertTrue('8.8.8.8' in seen)
+            self.assertTrue('8.8.4.4' in seen)
+
         def testQueryTCP(self):
             qname = dns.name.from_text('dns.google.')
             async def run():
@@ -112,6 +127,20 @@ try:
             self.assertTrue('8.8.8.8' in seen)
             self.assertTrue('8.8.4.4' in seen)
 
+        def testQueryTCPWithSocket(self):
+            qname = dns.name.from_text('dns.google.')
+            async def run():
+                async with await trio.open_tcp_stream('8.8.8.8', 53) as s:
+                    q = dns.message.make_query(qname, dns.rdatatype.A)
+                    return await dns.trio.query.stream(q, '8.8.8.8', stream=s)
+            response = trio.run(run)
+            rrs = response.get_rrset(response.answer, qname,
+                                     dns.rdataclass.IN, dns.rdatatype.A)
+            self.assertTrue(rrs is not None)
+            seen = set([rdata.address for rdata in rrs])
+            self.assertTrue('8.8.8.8' in seen)
+            self.assertTrue('8.8.4.4' in seen)
+
         def testQueryTLS(self):
             qname = dns.name.from_text('dns.google.')
             async def run():
@@ -125,6 +154,21 @@ try:
             self.assertTrue('8.8.8.8' in seen)
             self.assertTrue('8.8.4.4' in seen)
 
+        def testQueryTLSWithSocket(self):
+            qname = dns.name.from_text('dns.google.')
+            async def run():
+                async with await trio.open_ssl_over_tcp_stream('8.8.8.8',
+                                                               853) as s:
+                    q = dns.message.make_query(qname, dns.rdatatype.A)
+                    return await dns.trio.query.stream(q, '8.8.8.8', stream=s)
+            response = trio.run(run)
+            rrs = response.get_rrset(response.answer, qname,
+                                     dns.rdataclass.IN, dns.rdatatype.A)
+            self.assertTrue(rrs is not None)
+            seen = set([rdata.address for rdata in rrs])
+            self.assertTrue('8.8.8.8' in seen)
+            self.assertTrue('8.8.4.4' in seen)
+
         def testQueryUDPFallback(self):
             qname = dns.name.from_text('.')
             async def run():