]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Add support for DoT (DNS over TLS)
authorBrian Wellington <bwelling@xbill.org>
Mon, 30 Sep 2019 18:51:16 +0000 (11:51 -0700)
committerBrian Wellington <bwelling@xbill.org>
Mon, 30 Sep 2019 18:51:16 +0000 (11:51 -0700)
dns/query.py

index ac15895ae1cac270186bcdf949ce10ea0c51d0bc..9ed51b7fe907749f2f648b920394cebb59a64b18 100644 (file)
@@ -23,6 +23,7 @@ import errno
 import os
 import select
 import socket
+import ssl
 import struct
 import sys
 import time
@@ -123,6 +124,8 @@ def _wait_for(fd, readable, writable, error, expiration):
             if timeout <= 0.0:
                 raise dns.exception.Timeout
         try:
+            if isinstance(fd, ssl.SSLSocket) and readable and fd.pending() > 0:
+                return True
             if not _polling_backend(fd, readable, writable, error, timeout):
                 raise dns.exception.Timeout
         except OSError as e:
@@ -339,7 +342,13 @@ def _net_read(sock, count, expiration):
     s = b''
     while count > 0:
         _wait_for_readable(sock, expiration)
-        n = sock.recv(count)
+        try:
+            n = sock.recv(count)
+        except ssl.SSLWantReadError:
+            continue
+        except ssl.SSLWantWriteError:
+            _wait_for_writable(sock, expiration)
+            continue
         if n == b'':
             raise EOFError
         count = count - len(n)
@@ -356,7 +365,13 @@ def _net_write(sock, data, expiration):
     l = len(data)
     while current < l:
         _wait_for_writable(sock, expiration)
-        current += sock.send(data[current:])
+        try:
+            current += sock.send(data[current:])
+        except ssl.SSLWantReadError:
+            _wait_for_readable(sock, expiration)
+            continue
+        except ssl.SSLWantWriteError:
+            continue
 
 
 def send_tcp(sock, what, expiration=None):
@@ -500,6 +515,84 @@ def tcp(q, where, timeout=None, port=53, af=None, source=None, source_port=0,
     return r
 
 
+def tls(q, where, timeout=None, port=853, af=None, source=None, source_port=0,
+        one_rr_per_rrset=False, ignore_trailing=False,
+        ssl_context=None):
+    """Return the response obtained after sending a query via TLS.
+
+    *q*, a ``dns.message.Message``, the query to send
+
+    *where*, a ``text`` containing an IPv4 or IPv6 address,  where
+    to send the message.
+
+    *timeout*, a ``float`` or ``None``, the number of seconds to wait before the
+    query times out.  If ``None``, the default, wait forever.
+
+    *port*, an ``int``, the port send the message to.  The default is 853.
+
+    *af*, an ``int``, the address family to use.  The default is ``None``,
+    which causes the address family to use to be inferred from the form of
+    *where*.  If the inference attempt fails, AF_INET is used.  This
+    parameter is historical; you need never set it.
+
+    *source*, a ``text`` containing an IPv4 or IPv6 address, specifying
+    the source address.  The default is the wildcard address.
+
+    *source_port*, an ``int``, the port from which to send the message.
+    The default is 0.
+
+    *one_rr_per_rrset*, a ``bool``.  If ``True``, put each RR into its own
+    RRset.
+
+    *ignore_trailing*, a ``bool``.  If ``True``, ignore trailing
+    junk at end of the received message.
+
+    *ssl_context*, an ``ssl.SSLContext``, the context to use when establishing
+    a TLS connection. If ``None``, the default, creates one with the default
+    configuration.
+
+    Returns a ``dns.message.Message``.
+    """
+
+    wire = q.to_wire()
+    (af, destination, source) = _destination_and_source(af, where, port,
+                                                        source, source_port)
+    s = socket_factory(af, socket.SOCK_STREAM, 0)
+    begin_time = None
+    received_time = None
+    try:
+        expiration = _compute_expiration(timeout)
+        s.setblocking(0)
+        begin_time = time.time()
+        if source is not None:
+            s.bind(source)
+        _connect(s, destination, expiration)
+        if ssl_context is None:
+            ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
+        s = ssl_context.wrap_socket(s, do_handshake_on_connect=False)
+        while True:
+            try:
+                s.do_handshake()
+                break
+            except ssl.SSLWantReadError:
+                _wait_for_readable(s, expiration)
+            except ssl.SSLWantWriteError:
+                _wait_for_writable(s, expiration)
+        send_tcp(s, wire, expiration)
+        (r, received_time) = receive_tcp(s, expiration, one_rr_per_rrset,
+                                         q.keyring, q.mac, ignore_trailing)
+    finally:
+        if begin_time is None or received_time is None:
+            response_time = 0
+        else:
+            response_time = received_time - begin_time
+        s.close()
+    r.time = response_time
+    if not q.is_response(r):
+        raise BadResponse
+    return r
+
+
 def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN,
         timeout=None, port=53, keyring=None, keyname=None, relativize=True,
         af=None, lifetime=None, source=None, source_port=0, serial=0,