]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Use context managers to simplify code. 461/head
authorBrian Wellington <bwelling@xbill.org>
Sat, 2 May 2020 00:20:21 +0000 (17:20 -0700)
committerBrian Wellington <bwelling@xbill.org>
Sat, 2 May 2020 00:20:21 +0000 (17:20 -0700)
Simplify code using try/finally to use context managers.

In some cases, contextlib.ExitStack() is used; this could probably be
further simplified to use contextlib.nullcontext() once Python 3.7+ is a
requirement.

dns/entropy.py
dns/message.py
dns/query.py
dns/resolver.py
dns/zone.py

index da4332c8e6ec031ad5c90ba33c2fe1fe268dc719..8c00009912dabe98323b193d8e757739766a7635 100644 (file)
@@ -15,6 +15,7 @@
 # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
 # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 
+import contextlib
 import os
 import hashlib
 import random
@@ -41,26 +42,24 @@ class EntropyPool(object):
         self.hash_len = 20
         self.pool = bytearray(b'\0' * self.hash_len)
         if seed is not None:
-            self.stir(bytearray(seed))
+            self._stir(bytearray(seed))
             self.seeded = True
             self.seed_pid = os.getpid()
         else:
             self.seeded = False
             self.seed_pid = 0
 
-    def stir(self, entropy, already_locked=False):
-        if not already_locked:
-            self.lock.acquire()
-        try:
-            for c in entropy:
-                if self.pool_index == self.hash_len:
-                    self.pool_index = 0
-                b = c & 0xff
-                self.pool[self.pool_index] ^= b
-                self.pool_index += 1
-        finally:
-            if not already_locked:
-                self.lock.release()
+    def _stir(self, entropy):
+        for c in entropy:
+            if self.pool_index == self.hash_len:
+                self.pool_index = 0
+            b = c & 0xff
+            self.pool[self.pool_index] ^= b
+            self.pool_index += 1
+
+    def stir(self, entropy):
+        with self.lock:
+            self._stir(entropy)
 
     def _maybe_seed(self):
         if not self.seeded or self.seed_pid != os.getpid():
@@ -68,32 +67,26 @@ class EntropyPool(object):
                 seed = os.urandom(16)
             except Exception:
                 try:
-                    r = open('/dev/urandom', 'rb', 0)
-                    try:
+                    with open('/dev/urandom', 'rb', 0) as r:
                         seed = r.read(16)
-                    finally:
-                        r.close()
                 except Exception:
                     seed = str(time.time())
             self.seeded = True
             self.seed_pid = os.getpid()
             self.digest = None
             seed = bytearray(seed)
-            self.stir(seed, True)
+            self._stir(seed)
 
     def random_8(self):
-        self.lock.acquire()
-        try:
+        with self.lock:
             self._maybe_seed()
             if self.digest is None or self.next_byte == self.hash_len:
                 self.hash.update(bytes(self.pool))
                 self.digest = bytearray(self.hash.digest())
-                self.stir(self.digest, True)
+                self._stir(self.digest)
                 self.next_byte = 0
             value = self.digest[self.next_byte]
             self.next_byte += 1
-        finally:
-            self.lock.release()
         return value
 
     def random_16(self):
index 935d6b87a8a975a854d8ddeb74b9c3fc2664f588..b9b0253d0d7dc1e644f740aa9c98b52520d6c5a2 100644 (file)
@@ -17,6 +17,7 @@
 
 """DNS Messages"""
 
+import contextlib
 import io
 import struct
 import time
@@ -1046,21 +1047,10 @@ def from_file(f):
     Returns a ``dns.message.Message object``
     """
 
-    str_type = str
-    opts = 'rU'
-
-    if isinstance(f, str_type):
-        f = open(f, opts)
-        want_close = True
-    else:
-        want_close = False
-
-    try:
-        m = from_text(f)
-    finally:
-        if want_close:
-            f.close()
-    return m
+    with contextlib.ExitStack() as stack:
+        if isinstance(f, str):
+            f = stack.enter_context(open(f))
+        return from_text(f)
 
 
 def make_query(qname, rdtype, rdclass=dns.rdataclass.IN, use_edns=None,
index 41bf919775a9a2ad749f56383551ca871685ba18..152937743451e8d2068150c6700db7698b0087cc 100644 (file)
@@ -17,6 +17,7 @@
 
 """Talk to a DNS server."""
 
+import contextlib
 import errno
 import os
 import select
@@ -316,13 +317,10 @@ def https(q, where, timeout=None, port=443, af=None, source=None, source_port=0,
         # set source port and source address
         transport_adapter = SourceAddressAdapter(source)
 
-    if session:
-        close_session = False
-    else:
-        session = requests.sessions.Session()
-        close_session = True
+    with contextlib.ExitStack() as stack:
+        if not session:
+            session = stack.enter_context(requests.sessions.Session())
 
-    try:
         if transport_adapter:
             session.mount(url, transport_adapter)
 
@@ -341,9 +339,6 @@ def https(q, where, timeout=None, port=443, af=None, source=None, source_port=0,
             url += "?dns={}".format(wire)
             response = session.get(url, headers=headers, stream=True,
                                    timeout=timeout, verify=verify)
-    finally:
-        if close_session:
-            session.close()
 
     # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
     # status codes
index 548e561499530cf1e2a90b14776c18b5de5c9e90..805c5bd597a0e4378e732b1f87486259ba4c844c 100644 (file)
@@ -17,6 +17,7 @@
 
 """DNS stub resolver."""
 from urllib.parse import urlparse
+import contextlib
 import socket
 import sys
 import time
@@ -327,15 +328,12 @@ class Cache(object):
         Returns a ``dns.resolver.Answer`` or ``None``.
         """
 
-        try:
-            self.lock.acquire()
+        with self.lock:
             self._maybe_clean()
             v = self.data.get(key)
             if v is None or v.expiration <= time.time():
                 return None
             return v
-        finally:
-            self.lock.release()
 
     def put(self, key, value):
         """Associate key and value in the cache.
@@ -346,12 +344,9 @@ class Cache(object):
         *value*, a ``dns.resolver.Answer``, the answer.
         """
 
-        try:
-            self.lock.acquire()
+        with self.lock:
             self._maybe_clean()
             self.data[key] = value
-        finally:
-            self.lock.release()
 
     def flush(self, key=None):
         """Flush the cache.
@@ -363,16 +358,13 @@ class Cache(object):
         query name, rdtype, and rdclass respectively.
         """
 
-        try:
-            self.lock.acquire()
+        with self.lock:
             if key is not None:
                 if key in self.data:
                     del self.data[key]
             else:
                 self.data = {}
                 self.next_cleaning = time.time() + self.cleaning_interval
-        finally:
-            self.lock.release()
 
 
 class LRUCacheNode(object):
@@ -437,8 +429,7 @@ class LRUCache(object):
         Returns a ``dns.resolver.Answer`` or ``None``.
         """
 
-        try:
-            self.lock.acquire()
+        with self.lock:
             node = self.data.get(key)
             if node is None:
                 return None
@@ -450,8 +441,6 @@ class LRUCache(object):
                 return None
             node.link_after(self.sentinel)
             return node.value
-        finally:
-            self.lock.release()
 
     def put(self, key, value):
         """Associate key and value in the cache.
@@ -462,8 +451,7 @@ class LRUCache(object):
         *value*, a ``dns.resolver.Answer``, the answer.
         """
 
-        try:
-            self.lock.acquire()
+        with self.lock:
             node = self.data.get(key)
             if node is not None:
                 node.unlink()
@@ -475,8 +463,6 @@ class LRUCache(object):
             node = LRUCacheNode(key, value)
             node.link_after(self.sentinel)
             self.data[key] = node
-        finally:
-            self.lock.release()
 
     def flush(self, key=None):
         """Flush the cache.
@@ -488,8 +474,7 @@ class LRUCache(object):
         query name, rdtype, and rdclass respectively.
         """
 
-        try:
-            self.lock.acquire()
+        with self.lock:
             if key is not None:
                 node = self.data.get(key)
                 if node is not None:
@@ -503,8 +488,6 @@ class LRUCache(object):
                     node.next = None
                     node = next
                 self.data = {}
-        finally:
-            self.lock.release()
 
 
 class Resolver(object):
@@ -590,16 +573,14 @@ class Resolver(object):
 
         """
 
-        if isinstance(f, str):
-            try:
-                f = open(f, 'r')
-            except IOError:
-                # /etc/resolv.conf doesn't exist, can't be read, etc.
-                raise NoResolverConfiguration
-            want_close = True
-        else:
-            want_close = False
-        try:
+        with contextlib.ExitStack() as stack:
+            if isinstance(f, str):
+                try:
+                    f = stack.enter_context(open(f))
+                except IOError:
+                    # /etc/resolv.conf doesn't exist, can't be read, etc.
+                    raise NoResolverConfiguration
+
             for l in f:
                 if len(l) == 0 or l[0] == '#' or l[0] == ';':
                     continue
@@ -632,9 +613,6 @@ class Resolver(object):
                                 self.ndots = int(opt.split(':')[1])
                             except (ValueError, IndexError):
                                 pass
-        finally:
-            if want_close:
-                f.close()
         if len(self.nameservers) == 0:
             raise NoResolverConfiguration
 
index f403bd84fc011956472c8d3653bbdfb365afdb27..73c9bf439e738e993c2c1dec2ded9078d62e8314 100644 (file)
@@ -17,6 +17,7 @@
 
 """DNS Zones."""
 
+import contextlib
 import io
 import os
 import re
@@ -490,28 +491,26 @@ class Zone(object):
         @type nl: string or None
         """
 
-        if isinstance(f, str):
-            f = open(f, 'wb')
-            want_close = True
-        else:
-            want_close = False
-
-        # must be in this way, f.encoding may contain None, or even attribute
-        # may not be there
-        file_enc = getattr(f, 'encoding', None)
-        if file_enc is None:
-            file_enc = 'utf-8'
-
-        if nl is None:
-            nl_b = os.linesep.encode(file_enc)  # binary mode, '\n' is not enough
-            nl = '\n'
-        elif isinstance(nl, str):
-            nl_b = nl.encode(file_enc)
-        else:
-            nl_b = nl
-            nl = nl.decode()
+        with contextlib.ExitStack() as stack:
+            if isinstance(f, str):
+                f = stack.enter_context(open(f, 'wb'))
+
+            # must be in this way, f.encoding may contain None, or even attribute
+            # may not be there
+            file_enc = getattr(f, 'encoding', None)
+            if file_enc is None:
+                file_enc = 'utf-8'
+
+            if nl is None:
+                # binary mode, '\n' is not enough
+                nl_b = os.linesep.encode(file_enc)
+                nl = '\n'
+            elif isinstance(nl, str):
+                nl_b = nl.encode(file_enc)
+            else:
+                nl_b = nl
+                nl = nl.decode()
 
-        try:
             if sorted:
                 names = list(self.keys())
                 names.sort()
@@ -532,9 +531,6 @@ class Zone(object):
                 except TypeError:  # textual mode
                     f.write(l)
                     f.write(nl)
-        finally:
-            if want_close:
-                f.close()
 
     def to_text(self, sorted=True, relativize=True, nl=None):
         """Return a zone's text as though it were written to a file.
@@ -1061,23 +1057,13 @@ def from_file(f, origin=None, rdclass=dns.rdataclass.IN,
     @rtype: dns.zone.Zone object
     """
 
-    if isinstance(f, str):
-        if filename is None:
-            filename = f
-        f = open(f, 'r')
-        want_close = True
-    else:
-        if filename is None:
-            filename = '<file>'
-        want_close = False
-
-    try:
-        z = from_text(f, origin, rdclass, relativize, zone_factory,
-                      filename, allow_include, check_origin)
-    finally:
-        if want_close:
-            f.close()
-    return z
+    with contextlib.ExitStack() as stack:
+        if isinstance(f, str):
+            if filename is None:
+                filename = f
+            f = stack.enter_context(open(f))
+        return from_text(f, origin, rdclass, relativize, zone_factory,
+                         filename, allow_include, check_origin)
 
 
 def from_xfr(xfr, zone_factory=Zone, relativize=True, check_origin=True):