from __future__ import with_statement
+from types import NoneType
+
import contextlib
import functools
import itertools
def __exit__(self, type, value, traceback):
_state.contexts = self.old_contexts
+class _StackContextWrapper(functools.partial):
+ pass
+
def wrap(fn):
'''Returns a callable object that will resore the current StackContext
when executed.
different execution context (either in a different thread or
asynchronously in the same thread).
'''
- if fn is None:
- return None
# functools.wraps doesn't appear to work on functools.partial objects
#@functools.wraps(fn)
def wrapped(callback, contexts, *args, **kwargs):
+ if contexts is _state.contexts or not contexts:
+ callback(*args, **kwargs)
+ return
+ if not _state.contexts:
+ new_contexts = [cls(arg) for (cls, arg) in contexts]
# If we're moving down the stack, _state.contexts is a prefix
# of contexts. For each element of contexts not in that prefix,
# create a new StackContext object.
# If we're moving up the stack (or to an entirely different stack),
# _state.contexts will have elements not in contexts. Use
# NullContext to clear the state and then recreate from contexts.
- if (len(_state.contexts) > len(contexts) or
+ elif (len(_state.contexts) > len(contexts) or
any(a[1] is not b[1]
for a, b in itertools.izip(_state.contexts, contexts))):
# contexts have been removed or changed, so start over
callback(*args, **kwargs)
else:
callback(*args, **kwargs)
- if getattr(fn, 'stack_context_wrapped', False):
+ if isinstance(fn, (_StackContextWrapper, NoneType)):
return fn
- contexts = _state.contexts
- result = functools.partial(wrapped, fn, contexts)
- result.stack_context_wrapped = True
- return result
+ return _StackContextWrapper(wrapped, fn, _state.contexts)