import os
import select
import socket
+import ssl
import struct
import sys
import time
if timeout <= 0.0:
raise dns.exception.Timeout
try:
+ if isinstance(fd, ssl.SSLSocket) and readable and fd.pending() > 0:
+ return True
if not _polling_backend(fd, readable, writable, error, timeout):
raise dns.exception.Timeout
except OSError as e:
s = b''
while count > 0:
_wait_for_readable(sock, expiration)
- n = sock.recv(count)
+ try:
+ n = sock.recv(count)
+ except ssl.SSLWantReadError:
+ continue
+ except ssl.SSLWantWriteError:
+ _wait_for_writable(sock, expiration)
+ continue
if n == b'':
raise EOFError
count = count - len(n)
l = len(data)
while current < l:
_wait_for_writable(sock, expiration)
- current += sock.send(data[current:])
+ try:
+ current += sock.send(data[current:])
+ except ssl.SSLWantReadError:
+ _wait_for_readable(sock, expiration)
+ continue
+ except ssl.SSLWantWriteError:
+ continue
def send_tcp(sock, what, expiration=None):
return r
+def tls(q, where, timeout=None, port=853, af=None, source=None, source_port=0,
+ one_rr_per_rrset=False, ignore_trailing=False,
+ ssl_context=None):
+ """Return the response obtained after sending a query via TLS.
+
+ *q*, a ``dns.message.Message``, the query to send
+
+ *where*, a ``text`` containing an IPv4 or IPv6 address, where
+ to send the message.
+
+ *timeout*, a ``float`` or ``None``, the number of seconds to wait before the
+ query times out. If ``None``, the default, wait forever.
+
+ *port*, an ``int``, the port send the message to. The default is 853.
+
+ *af*, an ``int``, the address family to use. The default is ``None``,
+ which causes the address family to use to be inferred from the form of
+ *where*. If the inference attempt fails, AF_INET is used. This
+ parameter is historical; you need never set it.
+
+ *source*, a ``text`` containing an IPv4 or IPv6 address, specifying
+ the source address. The default is the wildcard address.
+
+ *source_port*, an ``int``, the port from which to send the message.
+ The default is 0.
+
+ *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
+ RRset.
+
+ *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
+ junk at end of the received message.
+
+ *ssl_context*, an ``ssl.SSLContext``, the context to use when establishing
+ a TLS connection. If ``None``, the default, creates one with the default
+ configuration.
+
+ Returns a ``dns.message.Message``.
+ """
+
+ wire = q.to_wire()
+ (af, destination, source) = _destination_and_source(af, where, port,
+ source, source_port)
+ s = socket_factory(af, socket.SOCK_STREAM, 0)
+ begin_time = None
+ received_time = None
+ try:
+ expiration = _compute_expiration(timeout)
+ s.setblocking(0)
+ begin_time = time.time()
+ if source is not None:
+ s.bind(source)
+ _connect(s, destination, expiration)
+ if ssl_context is None:
+ ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
+ s = ssl_context.wrap_socket(s, do_handshake_on_connect=False)
+ while True:
+ try:
+ s.do_handshake()
+ break
+ except ssl.SSLWantReadError:
+ _wait_for_readable(s, expiration)
+ except ssl.SSLWantWriteError:
+ _wait_for_writable(s, expiration)
+ send_tcp(s, wire, expiration)
+ (r, received_time) = receive_tcp(s, expiration, one_rr_per_rrset,
+ q.keyring, q.mac, ignore_trailing)
+ finally:
+ if begin_time is None or received_time is None:
+ response_time = 0
+ else:
+ response_time = received_time - begin_time
+ s.close()
+ r.time = response_time
+ if not q.is_response(r):
+ raise BadResponse
+ return r
+
+
def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN,
timeout=None, port=53, keyring=None, keyname=None, relativize=True,
af=None, lifetime=None, source=None, source_port=0, serial=0,