]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Add TLS for Trio and Curio.
authorBob Halley <halley@dnspython.org>
Fri, 12 Jun 2020 14:31:57 +0000 (07:31 -0700)
committerBob Halley <halley@dnspython.org>
Fri, 12 Jun 2020 14:31:57 +0000 (07:31 -0700)
dns/_asyncbackend.py
dns/_curio_backend.py
dns/_trio_backend.py
dns/asyncquery.py

index 0dbcd742ecb6f708b712cb11a16379b4edeb8666..dc1330e1593700cc00e5670d6ccbf425c522ce8c 100644 (file)
@@ -73,6 +73,6 @@ class Backend:
         return 'unknown'
 
     async def make_socket(self, af, socktype, proto=0,
-                          source=None, raw_source=None,
+                          source=None, destination=None, timeout=None,
                           ssl_context=None, server_hostname=None):
         raise NotImplementedError
index e37fea39503883b6ef98ef71a24d9f8a08b114bd..2efd25db24d3a933ede8a3fe71170922fe954ff4 100644 (file)
@@ -72,19 +72,24 @@ class Backend(dns._asyncbackend.Backend):
     async def make_socket(self, af, socktype, proto=0,
                           source=None, destination=None, timeout=None,
                           ssl_context=None, server_hostname=None):
-        s = curio.socket.socket(af, socktype, proto)
-        try:
-            if source:
-                s.bind(_lltuple(af, source))
-            if socktype == socket.SOCK_STREAM:
-                with _maybe_timeout(timeout):
-                    await s.connect(_lltuple(af, destination))
-        except Exception:
-            await s.close()
-            raise
         if socktype == socket.SOCK_DGRAM:
+            s = curio.socket.socket(af, socktype, proto)
+            try:
+                if source:
+                    s.bind(_lltuple(af, source))
+            except Exception:
+                await s.close()
+                raise
             return DatagramSocket(s)
         elif socktype == socket.SOCK_STREAM:
+            if source:
+                source_addr = (_lltuple(af, source))
+            else:
+                source_addr = None
+            s = await curio.open_connection(destination[0], destination[1],
+                                            ssl=ssl_context,
+                                            source_addr=source_addr,
+                                            server_hostname=server_hostname)
             return StreamSocket(s)
         raise NotImplementedError(f'unsupported socket type {socktype}')
 
index bcaddcca571d7579f72c58dd7a2e6f615cc4060e..d6a9387300958b05a5ee96dc89ba5737e24251fd 100644 (file)
@@ -44,9 +44,10 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket):
 
 
 class StreamSocket(dns._asyncbackend.DatagramSocket):
-    def __init__(self, family, stream):
+    def __init__(self, family, stream, tls=False):
         self.family = family
         self.stream = stream
+        self.tls = tls
 
     async def sendall(self, what, timeout):
         with _maybe_timeout(timeout):
@@ -62,7 +63,10 @@ class StreamSocket(dns._asyncbackend.DatagramSocket):
         await self.stream.aclose()
 
     async def getpeername(self):
-        return self.stream.socket.getpeername()
+        if self.tls:
+            return self.stream.transport_stream.socket.getpeername()
+        else:
+            return self.stream.socket.getpeername()
 
 
 class Backend(dns._asyncbackend.Backend):
@@ -73,6 +77,7 @@ class Backend(dns._asyncbackend.Backend):
                           destination=None, timeout=None,
                           ssl_context=None, server_hostname=None):
         s = trio.socket.socket(af, socktype, proto)
+        stream = None
         try:
             if source:
                 await s.bind(_lltuple(af, source))
@@ -85,7 +90,20 @@ class Backend(dns._asyncbackend.Backend):
         if socktype == socket.SOCK_DGRAM:
             return DatagramSocket(s)
         elif socktype == socket.SOCK_STREAM:
-            return StreamSocket(af, trio.SocketStream(s))
+            stream = trio.SocketStream(s)
+            s = None
+            tls = False
+            if ssl_context:
+                print('TLS')
+                tls = True
+                try:
+                    stream = trio.SSLStream(stream, ssl_context,
+                                            server_hostname=server_hostname)
+                except Exception:
+                    await stream.aclose()
+                    raise
+            print('SOCKET')
+            return StreamSocket(af, stream, tls)
         raise NotImplementedError(f'unsupported socket type {socktype}')
 
     async def sleep(self, interval):
index 3e377278364c64903e069019b979e5b5e580f937..b9f7212f283b757ac52c378a05ad1abbc23450f3 100644 (file)
@@ -31,7 +31,7 @@ import dns.rdataclass
 import dns.rdatatype
 
 from dns.query import _addresses_equal, _compute_times, UnexpectedSource, \
-    BadResponse
+    BadResponse, ssl
 
 
 # for brevity
@@ -420,3 +420,77 @@ async def tcp(q, where, timeout=None, port=53, source=None, source_port=0,
     finally:
         if not sock and s:
             await s.close()
+
+async def tls(q, where, timeout=None, port=853, source=None, source_port=0,
+              one_rr_per_rrset=False, ignore_trailing=False, sock=None,
+              backend=None, ssl_context=None, server_hostname=None):
+    """Return the response obtained after sending a query via TLS.
+
+    *q*, a ``dns.message.Message``, the query to send
+
+    *where*, a ``str`` containing an IPv4 or IPv6 address,  where
+    to send the message.
+
+    *timeout*, a ``float`` or ``None``, the number of seconds to wait before the
+    query times out.  If ``None``, the default, wait forever.
+
+    *port*, an ``int``, the port send the message to.  The default is 853.
+
+    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
+    the source address.  The default is the wildcard address.
+
+    *source_port*, an ``int``, the port from which to send the message.
+    The default is 0.
+
+    *one_rr_per_rrset*, a ``bool``.  If ``True``, put each RR into its own
+    RRset.
+
+    *ignore_trailing*, a ``bool``.  If ``True``, ignore trailing
+    junk at end of the received message.
+
+    *sock*, an ``asyncbackend.StreamSocket``, or ``None``, the socket
+    to use for the query.  If ``None``, the default, a socket is
+    created.  Note that if a socket is provided, it must be a
+    connected SSL stream socket, and *where*, *port*,
+    *source*, *source_port*, and *ssl_context* are ignored.
+
+    *backend*, a ``dns.asyncbackend.Backend``, or ``None``.  If ``None``,
+    the default, then dnspython will use the default backend.
+
+    *ssl_context*, an ``ssl.SSLContext``, the context to use when establishing
+    a TLS connection. If ``None``, the default, creates one with the default
+    configuration.
+
+    *server_hostname*, a ``str`` containing the server's hostname.  The
+    default is ``None``, which means that no hostname is known, and if an
+    SSL context is created, hostname checking will be disabled.
+
+    Returns a ``dns.message.Message``.
+    """
+    if not backend:
+        backend = dns.asyncbackend.get_default_backend()
+    if not sock:
+        if ssl_context is None:
+            ssl_context = ssl.create_default_context()
+            if server_hostname is None:
+                ssl_context.check_hostname = False
+        else:
+            ssl_context = None
+            server_hostname = None
+        af = dns.inet.af_for_address(where)
+        stuple = _source_tuple(af, source, source_port)
+        dtuple = (where, port)
+        s = await backend.make_socket(af, socket.SOCK_STREAM, 0, stuple,
+                                      dtuple, timeout, ssl_context,
+                                      server_hostname)
+    else:
+        s = sock
+    try:
+        #
+        # If a socket was provided, there's no special TLS handling needed.
+        #
+        return await tcp(q, where, timeout, port, source, source_port,
+                         one_rr_per_rrset, ignore_trailing, s, backend)
+    finally:
+        if not sock and s:
+            await s.close()