From 0a16d33906fcff93d01ed4423d0ef87d2624cfc7 Mon Sep 17 00:00:00 2001 From: Brian Wellington Date: Mon, 30 Sep 2019 11:51:16 -0700 Subject: [PATCH] Add support for DoT (DNS over TLS) --- dns/query.py | 97 ++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 95 insertions(+), 2 deletions(-) diff --git a/dns/query.py b/dns/query.py index ac15895a..9ed51b7f 100644 --- a/dns/query.py +++ b/dns/query.py @@ -23,6 +23,7 @@ import errno import os import select import socket +import ssl import struct import sys import time @@ -123,6 +124,8 @@ def _wait_for(fd, readable, writable, error, expiration): if timeout <= 0.0: raise dns.exception.Timeout try: + if isinstance(fd, ssl.SSLSocket) and readable and fd.pending() > 0: + return True if not _polling_backend(fd, readable, writable, error, timeout): raise dns.exception.Timeout except OSError as e: @@ -339,7 +342,13 @@ def _net_read(sock, count, expiration): s = b'' while count > 0: _wait_for_readable(sock, expiration) - n = sock.recv(count) + try: + n = sock.recv(count) + except ssl.SSLWantReadError: + continue + except ssl.SSLWantWriteError: + _wait_for_writable(sock, expiration) + continue if n == b'': raise EOFError count = count - len(n) @@ -356,7 +365,13 @@ def _net_write(sock, data, expiration): l = len(data) while current < l: _wait_for_writable(sock, expiration) - current += sock.send(data[current:]) + try: + current += sock.send(data[current:]) + except ssl.SSLWantReadError: + _wait_for_readable(sock, expiration) + continue + except ssl.SSLWantWriteError: + continue def send_tcp(sock, what, expiration=None): @@ -500,6 +515,84 @@ def tcp(q, where, timeout=None, port=53, af=None, source=None, source_port=0, return r +def tls(q, where, timeout=None, port=853, af=None, source=None, source_port=0, + one_rr_per_rrset=False, ignore_trailing=False, + ssl_context=None): + """Return the response obtained after sending a query via TLS. + + *q*, a ``dns.message.Message``, the query to send + + *where*, a ``text`` containing an IPv4 or IPv6 address, where + to send the message. + + *timeout*, a ``float`` or ``None``, the number of seconds to wait before the + query times out. If ``None``, the default, wait forever. + + *port*, an ``int``, the port send the message to. The default is 853. + + *af*, an ``int``, the address family to use. The default is ``None``, + which causes the address family to use to be inferred from the form of + *where*. If the inference attempt fails, AF_INET is used. This + parameter is historical; you need never set it. + + *source*, a ``text`` containing an IPv4 or IPv6 address, specifying + the source address. The default is the wildcard address. + + *source_port*, an ``int``, the port from which to send the message. + The default is 0. + + *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own + RRset. + + *ignore_trailing*, a ``bool``. If ``True``, ignore trailing + junk at end of the received message. + + *ssl_context*, an ``ssl.SSLContext``, the context to use when establishing + a TLS connection. If ``None``, the default, creates one with the default + configuration. + + Returns a ``dns.message.Message``. + """ + + wire = q.to_wire() + (af, destination, source) = _destination_and_source(af, where, port, + source, source_port) + s = socket_factory(af, socket.SOCK_STREAM, 0) + begin_time = None + received_time = None + try: + expiration = _compute_expiration(timeout) + s.setblocking(0) + begin_time = time.time() + if source is not None: + s.bind(source) + _connect(s, destination, expiration) + if ssl_context is None: + ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + s = ssl_context.wrap_socket(s, do_handshake_on_connect=False) + while True: + try: + s.do_handshake() + break + except ssl.SSLWantReadError: + _wait_for_readable(s, expiration) + except ssl.SSLWantWriteError: + _wait_for_writable(s, expiration) + send_tcp(s, wire, expiration) + (r, received_time) = receive_tcp(s, expiration, one_rr_per_rrset, + q.keyring, q.mac, ignore_trailing) + finally: + if begin_time is None or received_time is None: + response_time = 0 + else: + response_time = received_time - begin_time + s.close() + r.time = response_time + if not q.is_response(r): + raise BadResponse + return r + + def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN, timeout=None, port=53, keyring=None, keyname=None, relativize=True, af=None, lifetime=None, source=None, source_port=0, serial=0, -- 2.47.3