"""trio async I/O library query support"""
+import contextlib
import socket
import struct
import time
async def send_udp(sock, what, destination):
"""Asynchronously send a DNS message to the specified UDP socket.
- *sock*, a ``trio.socket``.
+ *sock*, a ``trio.socket.socket``.
*what*, a ``bytes`` or ``dns.message.Message``, the message to send.
ignore_trailing=False, raise_on_truncation=False):
"""Asynchronously read a DNS message from a UDP socket.
- *sock*, a ``trio.socket``.
+ *sock*, a ``trio.socket.socket``.
*destination*, a destination tuple appropriate for the address family
of the socket, specifying where the associated query was sent.
async def udp(q, where, port=53, source=None, source_port=0,
ignore_unexpected=False, one_rr_per_rrset=False,
- ignore_trailing=False, raise_on_truncation=False):
+ ignore_trailing=False, raise_on_truncation=False,
+ sock=None):
"""Asynchronously return the response obtained after sending a query
via UDP.
*raise_on_truncation*, a ``bool``. If ``True``, raise an exception if
the TC bit is set.
+ *sock*, a ``trio.socket.socket``, or ``None``, the socket to use
+ for the query. If ``None``, the default, a socket is created. if
+ a socket is provided, the *source* and *source_port* are ignored.
+
Returns a ``dns.message.Message``.
+
"""
wire = q.to_wire()
(af, destination, source) = \
dns.query._destination_and_source(None, where, port, source,
source_port)
- with socket_factory(af, socket.SOCK_DGRAM, 0) as s:
- received_time = None
- sent_time = None
- if source is not None:
- await s.bind(source)
+ # We can use an ExitStack here as exiting a trio.socket.socket does
+ # not await.
+ with contextlib.ExitStack() as stack:
+ if sock:
+ s = sock
+ else:
+ s = stack.enter_context(socket_factory(af, socket.SOCK_DGRAM, 0))
+ if source is not None:
+ await s.bind(source)
(_, sent_time) = await send_udp(s, wire, destination)
(r, received_time) = await receive_udp(s, destination,
ignore_unexpected,
async def stream(q, where, tls=False, port=None, source=None, source_port=0,
one_rr_per_rrset=False, ignore_trailing=False,
- ssl_context=None, server_hostname=None):
+ stream=None, ssl_context=None, server_hostname=None):
"""Return the response obtained after sending a query using TCP or TLS.
*q*, a ``dns.message.Message``, the query to send.
*ignore_trailing*, a ``bool``. If ``True``, ignore trailing
junk at end of the received message.
+ *stream*, a ``trio.abc.Stream``, or ``None``, the stream to use for
+ the query. If ``None``, the default, a stream is created. if a
+ socket is provided, it must be connected, and the *where*, *port*,
+ *tls*, *source*, *source_port*, *ssl_context*, and
+ *server_hostname* parameters are ignored.
+
*ssl_context*, an ``ssl.SSLContext``, the context to use when establishing
a TLS connection. If ``None``, the default, creates one with the default
configuration. If this value is not ``None``, then the *tls* parameter
SSL context is created, hostname checking will be disabled.
Returns a ``dns.message.Message``.
- """
+ """
if ssl_context is not None:
tls = True
if port is None:
else:
port = 53
wire = q.to_wire()
- (af, destination, source) = \
- dns.query._destination_and_source(None, where, port, source,
- source_port)
- with socket_factory(af, socket.SOCK_STREAM, 0) as s:
- begin_time = time.time()
- if source is not None:
- await s.bind(source)
- await s.connect(destination)
- stream = trio.SocketStream(s)
- if tls and ssl_context is None:
- ssl_context = ssl.create_default_context()
- if server_hostname is None:
- ssl_context.check_hostname = False
- if ssl_context:
- stream = trio.SSLStream(stream, ssl_context,
- server_hostname=server_hostname)
- async with stream:
- await send_stream(stream, wire)
- (r, received_time) = await receive_stream(stream, one_rr_per_rrset,
- q.keyring, q.mac,
- ignore_trailing)
- if not q.is_response(r):
- raise BadResponse
- r.time = received_time - begin_time
- return r
+ # We'd like to be able to use an AsyncExitStack here, but that's a 3.7
+ # feature, so we are forced to try ... finally.
+ sock = None
+ s = None
+ begin_time = time.time()
+ try:
+ if stream:
+ #
+ # Verify that the socket is connected, as if it's not connected,
+ # it's not writable, and the polling in send_tcp() will time out or
+ # hang forever.
+ if isinstance(stream, trio.SSLStream):
+ tsock = stream.transport_stream.socket
+ else:
+ tsock = stream.socket
+ tsock.getpeername()
+ s = stream
+ else:
+ (af, destination, source) = \
+ dns.query._destination_and_source(None, where, port, source,
+ source_port)
+ sock = socket_factory(af, socket.SOCK_STREAM, 0)
+ if source is not None:
+ await sock.bind(source)
+ await sock.connect(destination)
+ s = trio.SocketStream(sock)
+ sock = None
+ if tls and ssl_context is None:
+ ssl_context = ssl.create_default_context()
+ if server_hostname is None:
+ ssl_context.check_hostname = False
+ if ssl_context:
+ s = trio.SSLStream(s, ssl_context,
+ server_hostname=server_hostname)
+ await send_stream(s, wire)
+ (r, received_time) = await receive_stream(s, one_rr_per_rrset,
+ q.keyring, q.mac,
+ ignore_trailing)
+ if not q.is_response(r):
+ raise BadResponse
+ r.time = received_time - begin_time
+ return r
+ finally:
+ if sock:
+ sock.close()
+ if s and s != stream:
+ await s.aclose()
try:
import trio
+ import trio.socket
import dns.message
import dns.name
self.assertTrue('8.8.8.8' in seen)
self.assertTrue('8.8.4.4' in seen)
+ def testQueryUDPWithSocket(self):
+ qname = dns.name.from_text('dns.google.')
+ async def run():
+ with trio.socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
+ q = dns.message.make_query(qname, dns.rdatatype.A)
+ return await dns.trio.query.udp(q, '8.8.8.8', sock=s)
+ response = trio.run(run)
+ rrs = response.get_rrset(response.answer, qname,
+ dns.rdataclass.IN, dns.rdatatype.A)
+ self.assertTrue(rrs is not None)
+ seen = set([rdata.address for rdata in rrs])
+ self.assertTrue('8.8.8.8' in seen)
+ self.assertTrue('8.8.4.4' in seen)
+
def testQueryTCP(self):
qname = dns.name.from_text('dns.google.')
async def run():
self.assertTrue('8.8.8.8' in seen)
self.assertTrue('8.8.4.4' in seen)
+ def testQueryTCPWithSocket(self):
+ qname = dns.name.from_text('dns.google.')
+ async def run():
+ async with await trio.open_tcp_stream('8.8.8.8', 53) as s:
+ q = dns.message.make_query(qname, dns.rdatatype.A)
+ return await dns.trio.query.stream(q, '8.8.8.8', stream=s)
+ response = trio.run(run)
+ rrs = response.get_rrset(response.answer, qname,
+ dns.rdataclass.IN, dns.rdatatype.A)
+ self.assertTrue(rrs is not None)
+ seen = set([rdata.address for rdata in rrs])
+ self.assertTrue('8.8.8.8' in seen)
+ self.assertTrue('8.8.4.4' in seen)
+
def testQueryTLS(self):
qname = dns.name.from_text('dns.google.')
async def run():
self.assertTrue('8.8.8.8' in seen)
self.assertTrue('8.8.4.4' in seen)
+ def testQueryTLSWithSocket(self):
+ qname = dns.name.from_text('dns.google.')
+ async def run():
+ async with await trio.open_ssl_over_tcp_stream('8.8.8.8',
+ 853) as s:
+ q = dns.message.make_query(qname, dns.rdatatype.A)
+ return await dns.trio.query.stream(q, '8.8.8.8', stream=s)
+ response = trio.run(run)
+ rrs = response.get_rrset(response.answer, qname,
+ dns.rdataclass.IN, dns.rdatatype.A)
+ self.assertTrue(rrs is not None)
+ seen = set([rdata.address for rdata in rrs])
+ self.assertTrue('8.8.8.8' in seen)
+ self.assertTrue('8.8.4.4' in seen)
+
def testQueryUDPFallback(self):
qname = dns.name.from_text('.')
async def run():