From 279f7b788b98e193bb8fcd2b6aaed1d57ce2c29b Mon Sep 17 00:00:00 2001 From: Brian Wellington Date: Fri, 1 May 2020 17:20:21 -0700 Subject: [PATCH] Use context managers to simplify code. Simplify code using try/finally to use context managers. In some cases, contextlib.ExitStack() is used; this could probably be further simplified to use contextlib.nullcontext() once Python 3.7+ is a requirement. --- dns/entropy.py | 41 +++++++++++++---------------- dns/message.py | 20 ++++----------- dns/query.py | 13 +++------- dns/resolver.py | 52 +++++++++++-------------------------- dns/zone.py | 68 ++++++++++++++++++++----------------------------- 5 files changed, 68 insertions(+), 126 deletions(-) diff --git a/dns/entropy.py b/dns/entropy.py index da4332c8..8c000099 100644 --- a/dns/entropy.py +++ b/dns/entropy.py @@ -15,6 +15,7 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +import contextlib import os import hashlib import random @@ -41,26 +42,24 @@ class EntropyPool(object): self.hash_len = 20 self.pool = bytearray(b'\0' * self.hash_len) if seed is not None: - self.stir(bytearray(seed)) + self._stir(bytearray(seed)) self.seeded = True self.seed_pid = os.getpid() else: self.seeded = False self.seed_pid = 0 - def stir(self, entropy, already_locked=False): - if not already_locked: - self.lock.acquire() - try: - for c in entropy: - if self.pool_index == self.hash_len: - self.pool_index = 0 - b = c & 0xff - self.pool[self.pool_index] ^= b - self.pool_index += 1 - finally: - if not already_locked: - self.lock.release() + def _stir(self, entropy): + for c in entropy: + if self.pool_index == self.hash_len: + self.pool_index = 0 + b = c & 0xff + self.pool[self.pool_index] ^= b + self.pool_index += 1 + + def stir(self, entropy): + with self.lock: + self._stir(entropy) def _maybe_seed(self): if not self.seeded or self.seed_pid != os.getpid(): @@ -68,32 +67,26 @@ class EntropyPool(object): seed = os.urandom(16) except Exception: try: - r = open('/dev/urandom', 'rb', 0) - try: + with open('/dev/urandom', 'rb', 0) as r: seed = r.read(16) - finally: - r.close() except Exception: seed = str(time.time()) self.seeded = True self.seed_pid = os.getpid() self.digest = None seed = bytearray(seed) - self.stir(seed, True) + self._stir(seed) def random_8(self): - self.lock.acquire() - try: + with self.lock: self._maybe_seed() if self.digest is None or self.next_byte == self.hash_len: self.hash.update(bytes(self.pool)) self.digest = bytearray(self.hash.digest()) - self.stir(self.digest, True) + self._stir(self.digest) self.next_byte = 0 value = self.digest[self.next_byte] self.next_byte += 1 - finally: - self.lock.release() return value def random_16(self): diff --git a/dns/message.py b/dns/message.py index 935d6b87..b9b0253d 100644 --- a/dns/message.py +++ b/dns/message.py @@ -17,6 +17,7 @@ """DNS Messages""" +import contextlib import io import struct import time @@ -1046,21 +1047,10 @@ def from_file(f): Returns a ``dns.message.Message object`` """ - str_type = str - opts = 'rU' - - if isinstance(f, str_type): - f = open(f, opts) - want_close = True - else: - want_close = False - - try: - m = from_text(f) - finally: - if want_close: - f.close() - return m + with contextlib.ExitStack() as stack: + if isinstance(f, str): + f = stack.enter_context(open(f)) + return from_text(f) def make_query(qname, rdtype, rdclass=dns.rdataclass.IN, use_edns=None, diff --git a/dns/query.py b/dns/query.py index 41bf9197..15293774 100644 --- a/dns/query.py +++ b/dns/query.py @@ -17,6 +17,7 @@ """Talk to a DNS server.""" +import contextlib import errno import os import select @@ -316,13 +317,10 @@ def https(q, where, timeout=None, port=443, af=None, source=None, source_port=0, # set source port and source address transport_adapter = SourceAddressAdapter(source) - if session: - close_session = False - else: - session = requests.sessions.Session() - close_session = True + with contextlib.ExitStack() as stack: + if not session: + session = stack.enter_context(requests.sessions.Session()) - try: if transport_adapter: session.mount(url, transport_adapter) @@ -341,9 +339,6 @@ def https(q, where, timeout=None, port=443, af=None, source=None, source_port=0, url += "?dns={}".format(wire) response = session.get(url, headers=headers, stream=True, timeout=timeout, verify=verify) - finally: - if close_session: - session.close() # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH # status codes diff --git a/dns/resolver.py b/dns/resolver.py index 548e5614..805c5bd5 100644 --- a/dns/resolver.py +++ b/dns/resolver.py @@ -17,6 +17,7 @@ """DNS stub resolver.""" from urllib.parse import urlparse +import contextlib import socket import sys import time @@ -327,15 +328,12 @@ class Cache(object): Returns a ``dns.resolver.Answer`` or ``None``. """ - try: - self.lock.acquire() + with self.lock: self._maybe_clean() v = self.data.get(key) if v is None or v.expiration <= time.time(): return None return v - finally: - self.lock.release() def put(self, key, value): """Associate key and value in the cache. @@ -346,12 +344,9 @@ class Cache(object): *value*, a ``dns.resolver.Answer``, the answer. """ - try: - self.lock.acquire() + with self.lock: self._maybe_clean() self.data[key] = value - finally: - self.lock.release() def flush(self, key=None): """Flush the cache. @@ -363,16 +358,13 @@ class Cache(object): query name, rdtype, and rdclass respectively. """ - try: - self.lock.acquire() + with self.lock: if key is not None: if key in self.data: del self.data[key] else: self.data = {} self.next_cleaning = time.time() + self.cleaning_interval - finally: - self.lock.release() class LRUCacheNode(object): @@ -437,8 +429,7 @@ class LRUCache(object): Returns a ``dns.resolver.Answer`` or ``None``. """ - try: - self.lock.acquire() + with self.lock: node = self.data.get(key) if node is None: return None @@ -450,8 +441,6 @@ class LRUCache(object): return None node.link_after(self.sentinel) return node.value - finally: - self.lock.release() def put(self, key, value): """Associate key and value in the cache. @@ -462,8 +451,7 @@ class LRUCache(object): *value*, a ``dns.resolver.Answer``, the answer. """ - try: - self.lock.acquire() + with self.lock: node = self.data.get(key) if node is not None: node.unlink() @@ -475,8 +463,6 @@ class LRUCache(object): node = LRUCacheNode(key, value) node.link_after(self.sentinel) self.data[key] = node - finally: - self.lock.release() def flush(self, key=None): """Flush the cache. @@ -488,8 +474,7 @@ class LRUCache(object): query name, rdtype, and rdclass respectively. """ - try: - self.lock.acquire() + with self.lock: if key is not None: node = self.data.get(key) if node is not None: @@ -503,8 +488,6 @@ class LRUCache(object): node.next = None node = next self.data = {} - finally: - self.lock.release() class Resolver(object): @@ -590,16 +573,14 @@ class Resolver(object): """ - if isinstance(f, str): - try: - f = open(f, 'r') - except IOError: - # /etc/resolv.conf doesn't exist, can't be read, etc. - raise NoResolverConfiguration - want_close = True - else: - want_close = False - try: + with contextlib.ExitStack() as stack: + if isinstance(f, str): + try: + f = stack.enter_context(open(f)) + except IOError: + # /etc/resolv.conf doesn't exist, can't be read, etc. + raise NoResolverConfiguration + for l in f: if len(l) == 0 or l[0] == '#' or l[0] == ';': continue @@ -632,9 +613,6 @@ class Resolver(object): self.ndots = int(opt.split(':')[1]) except (ValueError, IndexError): pass - finally: - if want_close: - f.close() if len(self.nameservers) == 0: raise NoResolverConfiguration diff --git a/dns/zone.py b/dns/zone.py index f403bd84..73c9bf43 100644 --- a/dns/zone.py +++ b/dns/zone.py @@ -17,6 +17,7 @@ """DNS Zones.""" +import contextlib import io import os import re @@ -490,28 +491,26 @@ class Zone(object): @type nl: string or None """ - if isinstance(f, str): - f = open(f, 'wb') - want_close = True - else: - want_close = False - - # must be in this way, f.encoding may contain None, or even attribute - # may not be there - file_enc = getattr(f, 'encoding', None) - if file_enc is None: - file_enc = 'utf-8' - - if nl is None: - nl_b = os.linesep.encode(file_enc) # binary mode, '\n' is not enough - nl = '\n' - elif isinstance(nl, str): - nl_b = nl.encode(file_enc) - else: - nl_b = nl - nl = nl.decode() + with contextlib.ExitStack() as stack: + if isinstance(f, str): + f = stack.enter_context(open(f, 'wb')) + + # must be in this way, f.encoding may contain None, or even attribute + # may not be there + file_enc = getattr(f, 'encoding', None) + if file_enc is None: + file_enc = 'utf-8' + + if nl is None: + # binary mode, '\n' is not enough + nl_b = os.linesep.encode(file_enc) + nl = '\n' + elif isinstance(nl, str): + nl_b = nl.encode(file_enc) + else: + nl_b = nl + nl = nl.decode() - try: if sorted: names = list(self.keys()) names.sort() @@ -532,9 +531,6 @@ class Zone(object): except TypeError: # textual mode f.write(l) f.write(nl) - finally: - if want_close: - f.close() def to_text(self, sorted=True, relativize=True, nl=None): """Return a zone's text as though it were written to a file. @@ -1061,23 +1057,13 @@ def from_file(f, origin=None, rdclass=dns.rdataclass.IN, @rtype: dns.zone.Zone object """ - if isinstance(f, str): - if filename is None: - filename = f - f = open(f, 'r') - want_close = True - else: - if filename is None: - filename = '' - want_close = False - - try: - z = from_text(f, origin, rdclass, relativize, zone_factory, - filename, allow_include, check_origin) - finally: - if want_close: - f.close() - return z + with contextlib.ExitStack() as stack: + if isinstance(f, str): + if filename is None: + filename = f + f = stack.enter_context(open(f)) + return from_text(f, origin, rdclass, relativize, zone_factory, + filename, allow_include, check_origin) def from_xfr(xfr, zone_factory=Zone, relativize=True, check_origin=True): -- 2.47.3