async def make_socket(self, af, socktype, proto=0,
source=None, destination=None, timeout=None,
ssl_context=None, server_hostname=None):
- s = curio.socket.socket(af, socktype, proto)
- try:
- if source:
- s.bind(_lltuple(af, source))
- if socktype == socket.SOCK_STREAM:
- with _maybe_timeout(timeout):
- await s.connect(_lltuple(af, destination))
- except Exception:
- await s.close()
- raise
if socktype == socket.SOCK_DGRAM:
+ s = curio.socket.socket(af, socktype, proto)
+ try:
+ if source:
+ s.bind(_lltuple(af, source))
+ except Exception:
+ await s.close()
+ raise
return DatagramSocket(s)
elif socktype == socket.SOCK_STREAM:
+ if source:
+ source_addr = (_lltuple(af, source))
+ else:
+ source_addr = None
+ s = await curio.open_connection(destination[0], destination[1],
+ ssl=ssl_context,
+ source_addr=source_addr,
+ server_hostname=server_hostname)
return StreamSocket(s)
raise NotImplementedError(f'unsupported socket type {socktype}')
class StreamSocket(dns._asyncbackend.DatagramSocket):
- def __init__(self, family, stream):
+ def __init__(self, family, stream, tls=False):
self.family = family
self.stream = stream
+ self.tls = tls
async def sendall(self, what, timeout):
with _maybe_timeout(timeout):
await self.stream.aclose()
async def getpeername(self):
- return self.stream.socket.getpeername()
+ if self.tls:
+ return self.stream.transport_stream.socket.getpeername()
+ else:
+ return self.stream.socket.getpeername()
class Backend(dns._asyncbackend.Backend):
destination=None, timeout=None,
ssl_context=None, server_hostname=None):
s = trio.socket.socket(af, socktype, proto)
+ stream = None
try:
if source:
await s.bind(_lltuple(af, source))
if socktype == socket.SOCK_DGRAM:
return DatagramSocket(s)
elif socktype == socket.SOCK_STREAM:
- return StreamSocket(af, trio.SocketStream(s))
+ stream = trio.SocketStream(s)
+ s = None
+ tls = False
+ if ssl_context:
+ print('TLS')
+ tls = True
+ try:
+ stream = trio.SSLStream(stream, ssl_context,
+ server_hostname=server_hostname)
+ except Exception:
+ await stream.aclose()
+ raise
+ print('SOCKET')
+ return StreamSocket(af, stream, tls)
raise NotImplementedError(f'unsupported socket type {socktype}')
async def sleep(self, interval):
import dns.rdatatype
from dns.query import _addresses_equal, _compute_times, UnexpectedSource, \
- BadResponse
+ BadResponse, ssl
# for brevity
finally:
if not sock and s:
await s.close()
+
+async def tls(q, where, timeout=None, port=853, source=None, source_port=0,
+ one_rr_per_rrset=False, ignore_trailing=False, sock=None,
+ backend=None, ssl_context=None, server_hostname=None):
+ """Return the response obtained after sending a query via TLS.
+
+ *q*, a ``dns.message.Message``, the query to send
+
+ *where*, a ``str`` containing an IPv4 or IPv6 address, where
+ to send the message.
+
+ *timeout*, a ``float`` or ``None``, the number of seconds to wait before the
+ query times out. If ``None``, the default, wait forever.
+
+ *port*, an ``int``, the port send the message to. The default is 853.
+
+ *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
+ the source address. The default is the wildcard address.
+
+ *source_port*, an ``int``, the port from which to send the message.
+ The default is 0.
+
+ *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
+ RRset.
+
+ *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
+ junk at end of the received message.
+
+ *sock*, an ``asyncbackend.StreamSocket``, or ``None``, the socket
+ to use for the query. If ``None``, the default, a socket is
+ created. Note that if a socket is provided, it must be a
+ connected SSL stream socket, and *where*, *port*,
+ *source*, *source_port*, and *ssl_context* are ignored.
+
+ *backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
+ the default, then dnspython will use the default backend.
+
+ *ssl_context*, an ``ssl.SSLContext``, the context to use when establishing
+ a TLS connection. If ``None``, the default, creates one with the default
+ configuration.
+
+ *server_hostname*, a ``str`` containing the server's hostname. The
+ default is ``None``, which means that no hostname is known, and if an
+ SSL context is created, hostname checking will be disabled.
+
+ Returns a ``dns.message.Message``.
+ """
+ if not backend:
+ backend = dns.asyncbackend.get_default_backend()
+ if not sock:
+ if ssl_context is None:
+ ssl_context = ssl.create_default_context()
+ if server_hostname is None:
+ ssl_context.check_hostname = False
+ else:
+ ssl_context = None
+ server_hostname = None
+ af = dns.inet.af_for_address(where)
+ stuple = _source_tuple(af, source, source_port)
+ dtuple = (where, port)
+ s = await backend.make_socket(af, socket.SOCK_STREAM, 0, stuple,
+ dtuple, timeout, ssl_context,
+ server_hostname)
+ else:
+ s = sock
+ try:
+ #
+ # If a socket was provided, there's no special TLS handling needed.
+ #
+ return await tcp(q, where, timeout, port, source, source_port,
+ one_rr_per_rrset, ignore_trailing, s, backend)
+ finally:
+ if not sock and s:
+ await s.close()