]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Use context managers in the query methods. 478/head
authorBrian Wellington <bwelling@xbill.org>
Thu, 21 May 2020 01:48:31 +0000 (18:48 -0700)
committerBrian Wellington <bwelling@xbill.org>
Thu, 21 May 2020 01:48:31 +0000 (18:48 -0700)
dns/query.py

index 080a66ddba8a9a9d5dbdd0a9f35d1eda7ae1b7ee..a21bd6501ee50dcbcee96739c996b3fe9b2c3759 100644 (file)
@@ -463,10 +463,7 @@ def udp(q, where, timeout=None, port=53, af=None, source=None, source_port=0,
     wire = q.to_wire()
     (af, destination, source) = _destination_and_source(af, where, port,
                                                         source, source_port)
-    s = socket_factory(af, socket.SOCK_DGRAM, 0)
-    received_time = None
-    sent_time = None
-    try:
+    with socket_factory(af, socket.SOCK_DGRAM, 0) as s:
         expiration = _compute_expiration(timeout)
         s.setblocking(0)
         if source is not None:
@@ -475,16 +472,10 @@ def udp(q, where, timeout=None, port=53, af=None, source=None, source_port=0,
         (r, received_time) = receive_udp(s, destination, expiration,
                                          ignore_unexpected, one_rr_per_rrset,
                                          q.keyring, q.mac, ignore_trailing)
-    finally:
-        if sent_time is None or received_time is None:
-            response_time = 0
-        else:
-            response_time = received_time - sent_time
-        s.close()
-    r.time = response_time
-    if not q.is_response(r):
-        raise BadResponse
-    return r
+        r.time = received_time - sent_time
+        if not q.is_response(r):
+            raise BadResponse
+        return r
 
 
 def _net_read(sock, count, expiration):
@@ -637,10 +628,7 @@ def tcp(q, where, timeout=None, port=53, af=None, source=None, source_port=0,
     wire = q.to_wire()
     (af, destination, source) = _destination_and_source(af, where, port,
                                                         source, source_port)
-    s = socket_factory(af, socket.SOCK_STREAM, 0)
-    begin_time = None
-    received_time = None
-    try:
+    with socket_factory(af, socket.SOCK_STREAM, 0) as s:
         expiration = _compute_expiration(timeout)
         s.setblocking(0)
         begin_time = time.time()
@@ -650,16 +638,21 @@ def tcp(q, where, timeout=None, port=53, af=None, source=None, source_port=0,
         send_tcp(s, wire, expiration)
         (r, received_time) = receive_tcp(s, expiration, one_rr_per_rrset,
                                          q.keyring, q.mac, ignore_trailing)
-    finally:
-        if begin_time is None or received_time is None:
-            response_time = 0
-        else:
-            response_time = received_time - begin_time
-        s.close()
-    r.time = response_time
-    if not q.is_response(r):
-        raise BadResponse
-    return r
+        r.time = received_time - begin_time
+        if not q.is_response(r):
+            raise BadResponse
+        return r
+
+
+def _tls_handshake(s, expiration):
+    while True:
+        try:
+            s.do_handshake()
+            return
+        except ssl.SSLWantReadError:
+            _wait_for_readable(s, expiration)
+        except ssl.SSLWantWriteError:
+            _wait_for_writable(s, expiration)
 
 
 def tls(q, where, timeout=None, port=853, af=None, source=None, source_port=0,
@@ -708,43 +701,27 @@ def tls(q, where, timeout=None, port=853, af=None, source=None, source_port=0,
     wire = q.to_wire()
     (af, destination, source) = _destination_and_source(af, where, port,
                                                         source, source_port)
-    s = socket_factory(af, socket.SOCK_STREAM, 0)
-    begin_time = None
-    received_time = None
-    try:
+    if ssl_context is None:
+        ssl_context = ssl.create_default_context()
+        if server_hostname is None:
+            ssl_context.check_hostname = False
+    with ssl_context.wrap_socket(socket_factory(af, socket.SOCK_STREAM, 0),
+                                 do_handshake_on_connect=False,
+                                 server_hostname=server_hostname) as s:
         expiration = _compute_expiration(timeout)
         s.setblocking(0)
         begin_time = time.time()
         if source is not None:
             s.bind(source)
         _connect(s, destination, expiration)
-        if ssl_context is None:
-            ssl_context = ssl.create_default_context()
-            if server_hostname is None:
-                ssl_context.check_hostname = False
-        s = ssl_context.wrap_socket(s, do_handshake_on_connect=False,
-                                    server_hostname=server_hostname)
-        while True:
-            try:
-                s.do_handshake()
-                break
-            except ssl.SSLWantReadError:
-                _wait_for_readable(s, expiration)
-            except ssl.SSLWantWriteError:
-                _wait_for_writable(s, expiration)
+        _tls_handshake(s, expiration)
         send_tcp(s, wire, expiration)
         (r, received_time) = receive_tcp(s, expiration, one_rr_per_rrset,
                                          q.keyring, q.mac, ignore_trailing)
-    finally:
-        if begin_time is None or received_time is None:
-            response_time = 0
-        else:
-            response_time = received_time - begin_time
-        s.close()
-    r.time = response_time
-    if not q.is_response(r):
-        raise BadResponse
-    return r
+        r.time = received_time - begin_time
+        if not q.is_response(r):
+            raise BadResponse
+        return r
 
 
 def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN,