Simplify code using try/finally to use context managers.
In some cases, contextlib.ExitStack() is used; this could probably be
further simplified to use contextlib.nullcontext() once Python 3.7+ is a
requirement.
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+import contextlib
import os
import hashlib
import random
self.hash_len = 20
self.pool = bytearray(b'\0' * self.hash_len)
if seed is not None:
- self.stir(bytearray(seed))
+ self._stir(bytearray(seed))
self.seeded = True
self.seed_pid = os.getpid()
else:
self.seeded = False
self.seed_pid = 0
- def stir(self, entropy, already_locked=False):
- if not already_locked:
- self.lock.acquire()
- try:
- for c in entropy:
- if self.pool_index == self.hash_len:
- self.pool_index = 0
- b = c & 0xff
- self.pool[self.pool_index] ^= b
- self.pool_index += 1
- finally:
- if not already_locked:
- self.lock.release()
+ def _stir(self, entropy):
+ for c in entropy:
+ if self.pool_index == self.hash_len:
+ self.pool_index = 0
+ b = c & 0xff
+ self.pool[self.pool_index] ^= b
+ self.pool_index += 1
+
+ def stir(self, entropy):
+ with self.lock:
+ self._stir(entropy)
def _maybe_seed(self):
if not self.seeded or self.seed_pid != os.getpid():
seed = os.urandom(16)
except Exception:
try:
- r = open('/dev/urandom', 'rb', 0)
- try:
+ with open('/dev/urandom', 'rb', 0) as r:
seed = r.read(16)
- finally:
- r.close()
except Exception:
seed = str(time.time())
self.seeded = True
self.seed_pid = os.getpid()
self.digest = None
seed = bytearray(seed)
- self.stir(seed, True)
+ self._stir(seed)
def random_8(self):
- self.lock.acquire()
- try:
+ with self.lock:
self._maybe_seed()
if self.digest is None or self.next_byte == self.hash_len:
self.hash.update(bytes(self.pool))
self.digest = bytearray(self.hash.digest())
- self.stir(self.digest, True)
+ self._stir(self.digest)
self.next_byte = 0
value = self.digest[self.next_byte]
self.next_byte += 1
- finally:
- self.lock.release()
return value
def random_16(self):
"""DNS Messages"""
+import contextlib
import io
import struct
import time
Returns a ``dns.message.Message object``
"""
- str_type = str
- opts = 'rU'
-
- if isinstance(f, str_type):
- f = open(f, opts)
- want_close = True
- else:
- want_close = False
-
- try:
- m = from_text(f)
- finally:
- if want_close:
- f.close()
- return m
+ with contextlib.ExitStack() as stack:
+ if isinstance(f, str):
+ f = stack.enter_context(open(f))
+ return from_text(f)
def make_query(qname, rdtype, rdclass=dns.rdataclass.IN, use_edns=None,
"""Talk to a DNS server."""
+import contextlib
import errno
import os
import select
# set source port and source address
transport_adapter = SourceAddressAdapter(source)
- if session:
- close_session = False
- else:
- session = requests.sessions.Session()
- close_session = True
+ with contextlib.ExitStack() as stack:
+ if not session:
+ session = stack.enter_context(requests.sessions.Session())
- try:
if transport_adapter:
session.mount(url, transport_adapter)
url += "?dns={}".format(wire)
response = session.get(url, headers=headers, stream=True,
timeout=timeout, verify=verify)
- finally:
- if close_session:
- session.close()
# see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
# status codes
"""DNS stub resolver."""
from urllib.parse import urlparse
+import contextlib
import socket
import sys
import time
Returns a ``dns.resolver.Answer`` or ``None``.
"""
- try:
- self.lock.acquire()
+ with self.lock:
self._maybe_clean()
v = self.data.get(key)
if v is None or v.expiration <= time.time():
return None
return v
- finally:
- self.lock.release()
def put(self, key, value):
"""Associate key and value in the cache.
*value*, a ``dns.resolver.Answer``, the answer.
"""
- try:
- self.lock.acquire()
+ with self.lock:
self._maybe_clean()
self.data[key] = value
- finally:
- self.lock.release()
def flush(self, key=None):
"""Flush the cache.
query name, rdtype, and rdclass respectively.
"""
- try:
- self.lock.acquire()
+ with self.lock:
if key is not None:
if key in self.data:
del self.data[key]
else:
self.data = {}
self.next_cleaning = time.time() + self.cleaning_interval
- finally:
- self.lock.release()
class LRUCacheNode(object):
Returns a ``dns.resolver.Answer`` or ``None``.
"""
- try:
- self.lock.acquire()
+ with self.lock:
node = self.data.get(key)
if node is None:
return None
return None
node.link_after(self.sentinel)
return node.value
- finally:
- self.lock.release()
def put(self, key, value):
"""Associate key and value in the cache.
*value*, a ``dns.resolver.Answer``, the answer.
"""
- try:
- self.lock.acquire()
+ with self.lock:
node = self.data.get(key)
if node is not None:
node.unlink()
node = LRUCacheNode(key, value)
node.link_after(self.sentinel)
self.data[key] = node
- finally:
- self.lock.release()
def flush(self, key=None):
"""Flush the cache.
query name, rdtype, and rdclass respectively.
"""
- try:
- self.lock.acquire()
+ with self.lock:
if key is not None:
node = self.data.get(key)
if node is not None:
node.next = None
node = next
self.data = {}
- finally:
- self.lock.release()
class Resolver(object):
"""
- if isinstance(f, str):
- try:
- f = open(f, 'r')
- except IOError:
- # /etc/resolv.conf doesn't exist, can't be read, etc.
- raise NoResolverConfiguration
- want_close = True
- else:
- want_close = False
- try:
+ with contextlib.ExitStack() as stack:
+ if isinstance(f, str):
+ try:
+ f = stack.enter_context(open(f))
+ except IOError:
+ # /etc/resolv.conf doesn't exist, can't be read, etc.
+ raise NoResolverConfiguration
+
for l in f:
if len(l) == 0 or l[0] == '#' or l[0] == ';':
continue
self.ndots = int(opt.split(':')[1])
except (ValueError, IndexError):
pass
- finally:
- if want_close:
- f.close()
if len(self.nameservers) == 0:
raise NoResolverConfiguration
"""DNS Zones."""
+import contextlib
import io
import os
import re
@type nl: string or None
"""
- if isinstance(f, str):
- f = open(f, 'wb')
- want_close = True
- else:
- want_close = False
-
- # must be in this way, f.encoding may contain None, or even attribute
- # may not be there
- file_enc = getattr(f, 'encoding', None)
- if file_enc is None:
- file_enc = 'utf-8'
-
- if nl is None:
- nl_b = os.linesep.encode(file_enc) # binary mode, '\n' is not enough
- nl = '\n'
- elif isinstance(nl, str):
- nl_b = nl.encode(file_enc)
- else:
- nl_b = nl
- nl = nl.decode()
+ with contextlib.ExitStack() as stack:
+ if isinstance(f, str):
+ f = stack.enter_context(open(f, 'wb'))
+
+ # must be in this way, f.encoding may contain None, or even attribute
+ # may not be there
+ file_enc = getattr(f, 'encoding', None)
+ if file_enc is None:
+ file_enc = 'utf-8'
+
+ if nl is None:
+ # binary mode, '\n' is not enough
+ nl_b = os.linesep.encode(file_enc)
+ nl = '\n'
+ elif isinstance(nl, str):
+ nl_b = nl.encode(file_enc)
+ else:
+ nl_b = nl
+ nl = nl.decode()
- try:
if sorted:
names = list(self.keys())
names.sort()
except TypeError: # textual mode
f.write(l)
f.write(nl)
- finally:
- if want_close:
- f.close()
def to_text(self, sorted=True, relativize=True, nl=None):
"""Return a zone's text as though it were written to a file.
@rtype: dns.zone.Zone object
"""
- if isinstance(f, str):
- if filename is None:
- filename = f
- f = open(f, 'r')
- want_close = True
- else:
- if filename is None:
- filename = '<file>'
- want_close = False
-
- try:
- z = from_text(f, origin, rdclass, relativize, zone_factory,
- filename, allow_include, check_origin)
- finally:
- if want_close:
- f.close()
- return z
+ with contextlib.ExitStack() as stack:
+ if isinstance(f, str):
+ if filename is None:
+ filename = f
+ f = stack.enter_context(open(f))
+ return from_text(f, origin, rdclass, relativize, zone_factory,
+ filename, allow_include, check_origin)
def from_xfr(xfr, zone_factory=Zone, relativize=True, check_origin=True):