From 0fa0d197f9cd978d16c7d2ff73f6319173d7ff45 Mon Sep 17 00:00:00 2001 From: Brian Wellington Date: Wed, 20 May 2020 18:48:31 -0700 Subject: [PATCH] Use context managers in the query methods. --- dns/query.py | 89 +++++++++++++++++++--------------------------------- 1 file changed, 33 insertions(+), 56 deletions(-) diff --git a/dns/query.py b/dns/query.py index 080a66dd..a21bd650 100644 --- a/dns/query.py +++ b/dns/query.py @@ -463,10 +463,7 @@ def udp(q, where, timeout=None, port=53, af=None, source=None, source_port=0, wire = q.to_wire() (af, destination, source) = _destination_and_source(af, where, port, source, source_port) - s = socket_factory(af, socket.SOCK_DGRAM, 0) - received_time = None - sent_time = None - try: + with socket_factory(af, socket.SOCK_DGRAM, 0) as s: expiration = _compute_expiration(timeout) s.setblocking(0) if source is not None: @@ -475,16 +472,10 @@ def udp(q, where, timeout=None, port=53, af=None, source=None, source_port=0, (r, received_time) = receive_udp(s, destination, expiration, ignore_unexpected, one_rr_per_rrset, q.keyring, q.mac, ignore_trailing) - finally: - if sent_time is None or received_time is None: - response_time = 0 - else: - response_time = received_time - sent_time - s.close() - r.time = response_time - if not q.is_response(r): - raise BadResponse - return r + r.time = received_time - sent_time + if not q.is_response(r): + raise BadResponse + return r def _net_read(sock, count, expiration): @@ -637,10 +628,7 @@ def tcp(q, where, timeout=None, port=53, af=None, source=None, source_port=0, wire = q.to_wire() (af, destination, source) = _destination_and_source(af, where, port, source, source_port) - s = socket_factory(af, socket.SOCK_STREAM, 0) - begin_time = None - received_time = None - try: + with socket_factory(af, socket.SOCK_STREAM, 0) as s: expiration = _compute_expiration(timeout) s.setblocking(0) begin_time = time.time() @@ -650,16 +638,21 @@ def tcp(q, where, timeout=None, port=53, af=None, source=None, source_port=0, send_tcp(s, wire, expiration) (r, received_time) = receive_tcp(s, expiration, one_rr_per_rrset, q.keyring, q.mac, ignore_trailing) - finally: - if begin_time is None or received_time is None: - response_time = 0 - else: - response_time = received_time - begin_time - s.close() - r.time = response_time - if not q.is_response(r): - raise BadResponse - return r + r.time = received_time - begin_time + if not q.is_response(r): + raise BadResponse + return r + + +def _tls_handshake(s, expiration): + while True: + try: + s.do_handshake() + return + except ssl.SSLWantReadError: + _wait_for_readable(s, expiration) + except ssl.SSLWantWriteError: + _wait_for_writable(s, expiration) def tls(q, where, timeout=None, port=853, af=None, source=None, source_port=0, @@ -708,43 +701,27 @@ def tls(q, where, timeout=None, port=853, af=None, source=None, source_port=0, wire = q.to_wire() (af, destination, source) = _destination_and_source(af, where, port, source, source_port) - s = socket_factory(af, socket.SOCK_STREAM, 0) - begin_time = None - received_time = None - try: + if ssl_context is None: + ssl_context = ssl.create_default_context() + if server_hostname is None: + ssl_context.check_hostname = False + with ssl_context.wrap_socket(socket_factory(af, socket.SOCK_STREAM, 0), + do_handshake_on_connect=False, + server_hostname=server_hostname) as s: expiration = _compute_expiration(timeout) s.setblocking(0) begin_time = time.time() if source is not None: s.bind(source) _connect(s, destination, expiration) - if ssl_context is None: - ssl_context = ssl.create_default_context() - if server_hostname is None: - ssl_context.check_hostname = False - s = ssl_context.wrap_socket(s, do_handshake_on_connect=False, - server_hostname=server_hostname) - while True: - try: - s.do_handshake() - break - except ssl.SSLWantReadError: - _wait_for_readable(s, expiration) - except ssl.SSLWantWriteError: - _wait_for_writable(s, expiration) + _tls_handshake(s, expiration) send_tcp(s, wire, expiration) (r, received_time) = receive_tcp(s, expiration, one_rr_per_rrset, q.keyring, q.mac, ignore_trailing) - finally: - if begin_time is None or received_time is None: - response_time = 0 - else: - response_time = received_time - begin_time - s.close() - r.time = response_time - if not q.is_response(r): - raise BadResponse - return r + r.time = received_time - begin_time + if not q.is_response(r): + raise BadResponse + return r def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN, -- 2.47.3