available."""
-def _compute_expiration(timeout):
+def _compute_times(timeout):
+ now = time.time()
if timeout is None:
- return None
+ return (now, None)
else:
- return time.time() + timeout
+ return (now, now + timeout)
# This module can use either poll() or select() as the "polling backend".
#
destination = None
return (af, destination, source)
+def _make_socket(af, type, source, ssl_context=None, server_hostname=None):
+ s = socket_factory(af, type)
+ try:
+ s.setblocking(False)
+ if source is not None:
+ s.bind(source)
+ if ssl_context:
+ return ssl_context.wrap_socket(s, do_handshake_on_connect=False,
+ server_hostname=server_hostname)
+ else:
+ return s
+ except Exception:
+ s.close()
+ raise
+
def https(q, where, timeout=None, port=443, source=None, source_port=0,
one_rr_per_rrset=False, ignore_trailing=False,
session=None, path='/dns-query', post=True,
def udp(q, where, timeout=None, port=53, source=None, source_port=0,
ignore_unexpected=False, one_rr_per_rrset=False, ignore_trailing=False,
- raise_on_truncation=False):
+ raise_on_truncation=False, sock=None):
"""Return the response obtained after sending a query via UDP.
*q*, a ``dns.message.Message``, the query to send
*raise_on_truncation*, a ``bool``. If ``True``, raise an exception if
the TC bit is set.
+ *sock*, a ``socket.socket``, or ``None``, the socket to use for the
+ query. If ``None``, the default, a socket is created. Note that
+ if a socket is provided, it must be a nonblocking datagram socket,
+ and the *source* and *source_port* are ignored.
+
Returns a ``dns.message.Message``.
"""
wire = q.to_wire()
(af, destination, source) = _destination_and_source(None, where, port,
source, source_port)
- with socket_factory(af, socket.SOCK_DGRAM, 0) as s:
- expiration = _compute_expiration(timeout)
- s.setblocking(0)
- if source is not None:
- s.bind(source)
- (_, sent_time) = send_udp(s, wire, destination, expiration)
+ (begin_time, expiration) = _compute_times(timeout)
+ with contextlib.ExitStack() as stack:
+ if sock:
+ s = sock
+ else:
+ s = stack.enter_context(_make_socket(af, socket.SOCK_DGRAM, source))
+ send_udp(s, wire, destination, expiration)
(r, received_time) = receive_udp(s, destination, expiration,
ignore_unexpected, one_rr_per_rrset,
q.keyring, q.mac, ignore_trailing,
raise_on_truncation)
- r.time = received_time - sent_time
+ r.time = received_time - begin_time
if not q.is_response(r):
raise BadResponse
return r
def udp_with_fallback(q, where, timeout=None, port=53, source=None,
source_port=0, ignore_unexpected=False,
- one_rr_per_rrset=False, ignore_trailing=False):
+ one_rr_per_rrset=False, ignore_trailing=False,
+ udp_sock=None, tcp_sock=None):
"""Return the response to the query, trying UDP first and falling back
to TCP if UDP results in a truncated response.
*ignore_trailing*, a ``bool``. If ``True``, ignore trailing
junk at end of the received message.
+ *udp_sock*, a ``socket.socket``, or ``None``, the socket to use for the
+ UDP query. If ``None``, the default, a socket is created. Note that
+ if a socket is provided, it must be a nonblocking datagram socket,
+ and the *source* and *source_port* are ignored for the UDP query.
+
+ *tcp_sock*, a ``socket.socket``, or ``None``, the socket to use for the
+ TCP query. If ``None``, the default, a socket is created. Note that
+ if a socket is provided, it must be a nonblocking connected stream
+ socket, and *where*, *source* and *source_port* are ignored for the TCP
+ query.
+
Returns a (``dns.message.Message``, tcp) tuple where tcp is ``True``
if and only if TCP was used.
"""
try:
response = udp(q, where, timeout, port, source, source_port,
ignore_unexpected, one_rr_per_rrset,
- ignore_trailing, True)
+ ignore_trailing, True, udp_sock)
return (response, False)
except dns.message.Truncated:
response = tcp(q, where, timeout, port, source, source_port,
- one_rr_per_rrset, ignore_trailing)
+ one_rr_per_rrset, ignore_trailing, tcp_sock)
return (response, True)
def _net_read(sock, count, expiration):
def tcp(q, where, timeout=None, port=53, source=None, source_port=0,
- one_rr_per_rrset=False, ignore_trailing=False):
+ one_rr_per_rrset=False, ignore_trailing=False, sock=None):
"""Return the response obtained after sending a query via TCP.
*q*, a ``dns.message.Message``, the query to send
- *where*, a ``str`` containing an IPv4 or IPv6 address, where
+ *where*, a ``str`` containing an IPv4 or IPv6 address, where
to send the message.
*timeout*, a ``float`` or ``None``, the number of seconds to wait before the
*ignore_trailing*, a ``bool``. If ``True``, ignore trailing
junk at end of the received message.
+ *sock*, a ``socket.socket``, or ``None``, the socket to use for the
+ query. If ``None``, the default, a socket is created. Note that
+ if a socket is provided, it must be a nonblocking connected stream
+ socket, and *where*, *source* and *source_port* are ignored.
+
Returns a ``dns.message.Message``.
"""
wire = q.to_wire()
- (af, destination, source) = _destination_and_source(None, where, port,
- source, source_port)
- with socket_factory(af, socket.SOCK_STREAM, 0) as s:
- expiration = _compute_expiration(timeout)
- s.setblocking(0)
- begin_time = time.time()
- if source is not None:
- s.bind(source)
- _connect(s, destination, expiration)
+ (begin_time, expiration) = _compute_times(timeout)
+ with contextlib.ExitStack() as stack:
+ if sock:
+ #
+ # Verify that the socket is connected, as if it's not connected,
+ # it's not writable, and the polling in send_tcp() will time out or
+ # hang forever.
+ sock.getpeername()
+ s = sock
+ else:
+ (af, destination, source) = _destination_and_source(None, where,
+ port, source,
+ source_port)
+ s = stack.enter_context(_make_socket(af, socket.SOCK_STREAM,
+ source))
+ _connect(s, destination, expiration)
send_tcp(s, wire, expiration)
(r, received_time) = receive_tcp(s, expiration, one_rr_per_rrset,
q.keyring, q.mac, ignore_trailing)
def tls(q, where, timeout=None, port=853, source=None, source_port=0,
- one_rr_per_rrset=False, ignore_trailing=False,
+ one_rr_per_rrset=False, ignore_trailing=False, sock=None,
ssl_context=None, server_hostname=None):
"""Return the response obtained after sending a query via TLS.
*ignore_trailing*, a ``bool``. If ``True``, ignore trailing
junk at end of the received message.
+ *sock*, an ``ssl.SSLSocket``, or ``None``, the socket to use for the
+ query. If ``None``, the default, a socket is created. Note that
+ if a socket is provided, it must be a nonblocking connected SSL stream
+ socket, and *where*, *source*, *source_port*, and *ssl_context* are ignored.
+
*ssl_context*, an ``ssl.SSLContext``, the context to use when establishing
a TLS connection. If ``None``, the default, creates one with the default
configuration.
Returns a ``dns.message.Message``.
"""
+ if sock:
+ #
+ # If a socket was provided, there's no special TLS handling needed.
+ #
+ return tcp(q, where, timeout, port, source, source_port,
+ one_rr_per_rrset, ignore_trailing, sock)
+
wire = q.to_wire()
+ (begin_time, expiration) = _compute_times(timeout)
(af, destination, source) = _destination_and_source(None, where, port,
- source, source_port)
- if ssl_context is None:
+ source, source_port)
+ if ssl_context is None and not sock:
ssl_context = ssl.create_default_context()
if server_hostname is None:
ssl_context.check_hostname = False
- with ssl_context.wrap_socket(socket_factory(af, socket.SOCK_STREAM, 0),
- do_handshake_on_connect=False,
- server_hostname=server_hostname) as s:
- expiration = _compute_expiration(timeout)
- s.setblocking(0)
- begin_time = time.time()
- if source is not None:
- s.bind(source)
+
+ with _make_socket(af, socket.SOCK_STREAM, source, ssl_context=ssl_context,
+ server_hostname=server_hostname) as s:
_connect(s, destination, expiration)
_tls_handshake(s, expiration)
send_tcp(s, wire, expiration)
if use_udp and rdtype != dns.rdatatype.IXFR:
raise ValueError('cannot do a UDP AXFR')
sock_type = socket.SOCK_DGRAM if use_udp else socket.SOCK_STREAM
- with socket_factory(af, sock_type, 0) as s:
- s.setblocking(0)
- if source is not None:
- s.bind(source)
- expiration = _compute_expiration(lifetime)
+ with _make_socket(af, sock_type, source) as s:
+ (_, expiration) = _compute_times(lifetime)
_connect(s, destination, expiration)
l = len(wire)
if use_udp:
tsig_ctx = None
first = True
while not done:
- mexpiration = _compute_expiration(timeout)
+ (_, mexpiration) = _compute_times(timeout)
if mexpiration is None or \
(expiration is not None and mexpiration > expiration):
mexpiration = expiration
import socket
import unittest
+try:
+ import ssl
+ have_ssl = True
+except Exception:
+ have_ssl = False
+
import dns.message
import dns.name
import dns.rdataclass
self.assertTrue('8.8.8.8' in seen)
self.assertTrue('8.8.4.4' in seen)
+ def testQueryUDPWithSocket(self):
+ with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
+ s.setblocking(0)
+ qname = dns.name.from_text('dns.google.')
+ q = dns.message.make_query(qname, dns.rdatatype.A)
+ response = dns.query.udp(q, '8.8.8.8', sock=s)
+ rrs = response.get_rrset(response.answer, qname,
+ dns.rdataclass.IN, dns.rdatatype.A)
+ self.assertTrue(rrs is not None)
+ seen = set([rdata.address for rdata in rrs])
+ self.assertTrue('8.8.8.8' in seen)
+ self.assertTrue('8.8.4.4' in seen)
+
def testQueryTCP(self):
qname = dns.name.from_text('dns.google.')
q = dns.message.make_query(qname, dns.rdatatype.A)
self.assertTrue('8.8.8.8' in seen)
self.assertTrue('8.8.4.4' in seen)
+ def testQueryTCPWithSocket(self):
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+ s.connect(('8.8.8.8', 53))
+ s.setblocking(0)
+ qname = dns.name.from_text('dns.google.')
+ q = dns.message.make_query(qname, dns.rdatatype.A)
+ response = dns.query.tcp(q, None, sock=s)
+ rrs = response.get_rrset(response.answer, qname,
+ dns.rdataclass.IN, dns.rdatatype.A)
+ self.assertTrue(rrs is not None)
+ seen = set([rdata.address for rdata in rrs])
+ self.assertTrue('8.8.8.8' in seen)
+ self.assertTrue('8.8.4.4' in seen)
+
def testQueryTLS(self):
qname = dns.name.from_text('dns.google.')
q = dns.message.make_query(qname, dns.rdatatype.A)
self.assertTrue('8.8.8.8' in seen)
self.assertTrue('8.8.4.4' in seen)
+ @unittest.skipUnless(have_ssl, "No SSL support")
+ def testQueryTLSWithSocket(self):
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+ s.connect(('8.8.8.8', 853))
+ ctx = ssl.create_default_context()
+ s = ctx.wrap_socket(s, server_hostname='dns.google')
+ s.setblocking(0)
+ qname = dns.name.from_text('dns.google.')
+ q = dns.message.make_query(qname, dns.rdatatype.A)
+ response = dns.query.tls(q, None, sock=s)
+ rrs = response.get_rrset(response.answer, qname,
+ dns.rdataclass.IN, dns.rdatatype.A)
+ self.assertTrue(rrs is not None)
+ seen = set([rdata.address for rdata in rrs])
+ self.assertTrue('8.8.8.8' in seen)
+ self.assertTrue('8.8.4.4' in seen)
+
def testQueryUDPFallback(self):
qname = dns.name.from_text('.')
q = dns.message.make_query(qname, dns.rdatatype.DNSKEY)
(_, tcp) = dns.query.udp_with_fallback(q, '8.8.8.8')
self.assertTrue(tcp)
+ def testQueryUDPFallbackWithSocket(self):
+ with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as udp_s:
+ udp_s.setblocking(0)
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as tcp_s:
+ tcp_s.connect(('8.8.8.8', 53))
+ tcp_s.setblocking(0)
+ qname = dns.name.from_text('.')
+ q = dns.message.make_query(qname, dns.rdatatype.DNSKEY)
+ (_, tcp) = dns.query.udp_with_fallback(q, '8.8.8.8',
+ udp_sock=udp_s,
+ tcp_sock=tcp_s)
+ self.assertTrue(tcp)
+
def testQueryUDPFallbackNoFallback(self):
qname = dns.name.from_text('dns.google.')
q = dns.message.make_query(qname, dns.rdatatype.A)