import pprint
import sys
import builtins
+import contextlib
from types import ModuleType, MethodType
from functools import wraps, partial
@wraps(func)
def patched(*args, **keywargs):
extra_args = []
- entered_patchers = []
-
- exc_info = tuple()
- try:
+ with contextlib.ExitStack() as exit_stack:
for patching in patched.patchings:
- arg = patching.__enter__()
- entered_patchers.append(patching)
+ arg = exit_stack.enter_context(patching)
if patching.attribute_name is not None:
keywargs.update(arg)
elif patching.new is DEFAULT:
args += tuple(extra_args)
return func(*args, **keywargs)
- except:
- if (patching not in entered_patchers and
- _is_started(patching)):
- # the patcher may have been started, but an exception
- # raised whilst entering one of its additional_patchers
- entered_patchers.append(patching)
- # Pass the exception to __exit__
- exc_info = sys.exc_info()
- # re-raise the exception
- raise
- finally:
- for patching in reversed(entered_patchers):
- patching.__exit__(*exc_info)
patched.patchings = [self]
return patched
self.temp_original = original
self.is_local = local
- setattr(self.target, self.attribute, new_attr)
- if self.attribute_name is not None:
- extra_args = {}
- if self.new is DEFAULT:
- extra_args[self.attribute_name] = new
- for patching in self.additional_patchers:
- arg = patching.__enter__()
- if patching.new is DEFAULT:
- extra_args.update(arg)
- return extra_args
-
- return new
-
+ self._exit_stack = contextlib.ExitStack()
+ try:
+ setattr(self.target, self.attribute, new_attr)
+ if self.attribute_name is not None:
+ extra_args = {}
+ if self.new is DEFAULT:
+ extra_args[self.attribute_name] = new
+ for patching in self.additional_patchers:
+ arg = self._exit_stack.enter_context(patching)
+ if patching.new is DEFAULT:
+ extra_args.update(arg)
+ return extra_args
+
+ return new
+ except:
+ if not self.__exit__(*sys.exc_info()):
+ raise
def __exit__(self, *exc_info):
"""Undo the patch."""
del self.temp_original
del self.is_local
del self.target
- for patcher in reversed(self.additional_patchers):
- if _is_started(patcher):
- patcher.__exit__(*exc_info)
+ exit_stack = self._exit_stack
+ del self._exit_stack
+ return exit_stack.__exit__(*exc_info)
def start(self):
# If the patch hasn't been started this will fail
pass
- return self.__exit__()
+ return self.__exit__(None, None, None)