wire = q.to_wire()
(af, destination, source) = _destination_and_source(af, where, port,
source, source_port)
- s = socket_factory(af, socket.SOCK_DGRAM, 0)
- received_time = None
- sent_time = None
- try:
+ with socket_factory(af, socket.SOCK_DGRAM, 0) as s:
expiration = _compute_expiration(timeout)
s.setblocking(0)
if source is not None:
(r, received_time) = receive_udp(s, destination, expiration,
ignore_unexpected, one_rr_per_rrset,
q.keyring, q.mac, ignore_trailing)
- finally:
- if sent_time is None or received_time is None:
- response_time = 0
- else:
- response_time = received_time - sent_time
- s.close()
- r.time = response_time
- if not q.is_response(r):
- raise BadResponse
- return r
+ r.time = received_time - sent_time
+ if not q.is_response(r):
+ raise BadResponse
+ return r
def _net_read(sock, count, expiration):
wire = q.to_wire()
(af, destination, source) = _destination_and_source(af, where, port,
source, source_port)
- s = socket_factory(af, socket.SOCK_STREAM, 0)
- begin_time = None
- received_time = None
- try:
+ with socket_factory(af, socket.SOCK_STREAM, 0) as s:
expiration = _compute_expiration(timeout)
s.setblocking(0)
begin_time = time.time()
send_tcp(s, wire, expiration)
(r, received_time) = receive_tcp(s, expiration, one_rr_per_rrset,
q.keyring, q.mac, ignore_trailing)
- finally:
- if begin_time is None or received_time is None:
- response_time = 0
- else:
- response_time = received_time - begin_time
- s.close()
- r.time = response_time
- if not q.is_response(r):
- raise BadResponse
- return r
+ r.time = received_time - begin_time
+ if not q.is_response(r):
+ raise BadResponse
+ return r
+
+
+def _tls_handshake(s, expiration):
+ while True:
+ try:
+ s.do_handshake()
+ return
+ except ssl.SSLWantReadError:
+ _wait_for_readable(s, expiration)
+ except ssl.SSLWantWriteError:
+ _wait_for_writable(s, expiration)
def tls(q, where, timeout=None, port=853, af=None, source=None, source_port=0,
wire = q.to_wire()
(af, destination, source) = _destination_and_source(af, where, port,
source, source_port)
- s = socket_factory(af, socket.SOCK_STREAM, 0)
- begin_time = None
- received_time = None
- try:
+ if ssl_context is None:
+ ssl_context = ssl.create_default_context()
+ if server_hostname is None:
+ ssl_context.check_hostname = False
+ with ssl_context.wrap_socket(socket_factory(af, socket.SOCK_STREAM, 0),
+ do_handshake_on_connect=False,
+ server_hostname=server_hostname) as s:
expiration = _compute_expiration(timeout)
s.setblocking(0)
begin_time = time.time()
if source is not None:
s.bind(source)
_connect(s, destination, expiration)
- if ssl_context is None:
- ssl_context = ssl.create_default_context()
- if server_hostname is None:
- ssl_context.check_hostname = False
- s = ssl_context.wrap_socket(s, do_handshake_on_connect=False,
- server_hostname=server_hostname)
- while True:
- try:
- s.do_handshake()
- break
- except ssl.SSLWantReadError:
- _wait_for_readable(s, expiration)
- except ssl.SSLWantWriteError:
- _wait_for_writable(s, expiration)
+ _tls_handshake(s, expiration)
send_tcp(s, wire, expiration)
(r, received_time) = receive_tcp(s, expiration, one_rr_per_rrset,
q.keyring, q.mac, ignore_trailing)
- finally:
- if begin_time is None or received_time is None:
- response_time = 0
- else:
- response_time = received_time - begin_time
- s.close()
- r.time = response_time
- if not q.is_response(r):
- raise BadResponse
- return r
+ r.time = received_time - begin_time
+ if not q.is_response(r):
+ raise BadResponse
+ return r
def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN,