From b7da95bdb11d8e0afa7e880127c60e1140f060ea Mon Sep 17 00:00:00 2001 From: Bob Halley Date: Fri, 12 Jun 2020 07:31:57 -0700 Subject: [PATCH] Add TLS for Trio and Curio. --- dns/_asyncbackend.py | 2 +- dns/_curio_backend.py | 25 ++++++++------ dns/_trio_backend.py | 24 ++++++++++++-- dns/asyncquery.py | 76 ++++++++++++++++++++++++++++++++++++++++++- 4 files changed, 112 insertions(+), 15 deletions(-) diff --git a/dns/_asyncbackend.py b/dns/_asyncbackend.py index 0dbcd742..dc1330e1 100644 --- a/dns/_asyncbackend.py +++ b/dns/_asyncbackend.py @@ -73,6 +73,6 @@ class Backend: return 'unknown' async def make_socket(self, af, socktype, proto=0, - source=None, raw_source=None, + source=None, destination=None, timeout=None, ssl_context=None, server_hostname=None): raise NotImplementedError diff --git a/dns/_curio_backend.py b/dns/_curio_backend.py index e37fea39..2efd25db 100644 --- a/dns/_curio_backend.py +++ b/dns/_curio_backend.py @@ -72,19 +72,24 @@ class Backend(dns._asyncbackend.Backend): async def make_socket(self, af, socktype, proto=0, source=None, destination=None, timeout=None, ssl_context=None, server_hostname=None): - s = curio.socket.socket(af, socktype, proto) - try: - if source: - s.bind(_lltuple(af, source)) - if socktype == socket.SOCK_STREAM: - with _maybe_timeout(timeout): - await s.connect(_lltuple(af, destination)) - except Exception: - await s.close() - raise if socktype == socket.SOCK_DGRAM: + s = curio.socket.socket(af, socktype, proto) + try: + if source: + s.bind(_lltuple(af, source)) + except Exception: + await s.close() + raise return DatagramSocket(s) elif socktype == socket.SOCK_STREAM: + if source: + source_addr = (_lltuple(af, source)) + else: + source_addr = None + s = await curio.open_connection(destination[0], destination[1], + ssl=ssl_context, + source_addr=source_addr, + server_hostname=server_hostname) return StreamSocket(s) raise NotImplementedError(f'unsupported socket type {socktype}') diff --git a/dns/_trio_backend.py b/dns/_trio_backend.py index bcaddcca..d6a93873 100644 --- a/dns/_trio_backend.py +++ b/dns/_trio_backend.py @@ -44,9 +44,10 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket): class StreamSocket(dns._asyncbackend.DatagramSocket): - def __init__(self, family, stream): + def __init__(self, family, stream, tls=False): self.family = family self.stream = stream + self.tls = tls async def sendall(self, what, timeout): with _maybe_timeout(timeout): @@ -62,7 +63,10 @@ class StreamSocket(dns._asyncbackend.DatagramSocket): await self.stream.aclose() async def getpeername(self): - return self.stream.socket.getpeername() + if self.tls: + return self.stream.transport_stream.socket.getpeername() + else: + return self.stream.socket.getpeername() class Backend(dns._asyncbackend.Backend): @@ -73,6 +77,7 @@ class Backend(dns._asyncbackend.Backend): destination=None, timeout=None, ssl_context=None, server_hostname=None): s = trio.socket.socket(af, socktype, proto) + stream = None try: if source: await s.bind(_lltuple(af, source)) @@ -85,7 +90,20 @@ class Backend(dns._asyncbackend.Backend): if socktype == socket.SOCK_DGRAM: return DatagramSocket(s) elif socktype == socket.SOCK_STREAM: - return StreamSocket(af, trio.SocketStream(s)) + stream = trio.SocketStream(s) + s = None + tls = False + if ssl_context: + print('TLS') + tls = True + try: + stream = trio.SSLStream(stream, ssl_context, + server_hostname=server_hostname) + except Exception: + await stream.aclose() + raise + print('SOCKET') + return StreamSocket(af, stream, tls) raise NotImplementedError(f'unsupported socket type {socktype}') async def sleep(self, interval): diff --git a/dns/asyncquery.py b/dns/asyncquery.py index 3e377278..b9f7212f 100644 --- a/dns/asyncquery.py +++ b/dns/asyncquery.py @@ -31,7 +31,7 @@ import dns.rdataclass import dns.rdatatype from dns.query import _addresses_equal, _compute_times, UnexpectedSource, \ - BadResponse + BadResponse, ssl # for brevity @@ -420,3 +420,77 @@ async def tcp(q, where, timeout=None, port=53, source=None, source_port=0, finally: if not sock and s: await s.close() + +async def tls(q, where, timeout=None, port=853, source=None, source_port=0, + one_rr_per_rrset=False, ignore_trailing=False, sock=None, + backend=None, ssl_context=None, server_hostname=None): + """Return the response obtained after sending a query via TLS. + + *q*, a ``dns.message.Message``, the query to send + + *where*, a ``str`` containing an IPv4 or IPv6 address, where + to send the message. + + *timeout*, a ``float`` or ``None``, the number of seconds to wait before the + query times out. If ``None``, the default, wait forever. + + *port*, an ``int``, the port send the message to. The default is 853. + + *source*, a ``str`` containing an IPv4 or IPv6 address, specifying + the source address. The default is the wildcard address. + + *source_port*, an ``int``, the port from which to send the message. + The default is 0. + + *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own + RRset. + + *ignore_trailing*, a ``bool``. If ``True``, ignore trailing + junk at end of the received message. + + *sock*, an ``asyncbackend.StreamSocket``, or ``None``, the socket + to use for the query. If ``None``, the default, a socket is + created. Note that if a socket is provided, it must be a + connected SSL stream socket, and *where*, *port*, + *source*, *source_port*, and *ssl_context* are ignored. + + *backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``, + the default, then dnspython will use the default backend. + + *ssl_context*, an ``ssl.SSLContext``, the context to use when establishing + a TLS connection. If ``None``, the default, creates one with the default + configuration. + + *server_hostname*, a ``str`` containing the server's hostname. The + default is ``None``, which means that no hostname is known, and if an + SSL context is created, hostname checking will be disabled. + + Returns a ``dns.message.Message``. + """ + if not backend: + backend = dns.asyncbackend.get_default_backend() + if not sock: + if ssl_context is None: + ssl_context = ssl.create_default_context() + if server_hostname is None: + ssl_context.check_hostname = False + else: + ssl_context = None + server_hostname = None + af = dns.inet.af_for_address(where) + stuple = _source_tuple(af, source, source_port) + dtuple = (where, port) + s = await backend.make_socket(af, socket.SOCK_STREAM, 0, stuple, + dtuple, timeout, ssl_context, + server_hostname) + else: + s = sock + try: + # + # If a socket was provided, there's no special TLS handling needed. + # + return await tcp(q, where, timeout, port, source, source_port, + one_rr_per_rrset, ignore_trailing, s, backend) + finally: + if not sock and s: + await s.close() -- 2.47.3