]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
Close #13062: Add inspect.getclosurevars to simplify testing stateful closures
authorNick Coghlan <ncoghlan@gmail.com>
Sat, 23 Jun 2012 09:39:55 +0000 (19:39 +1000)
committerNick Coghlan <ncoghlan@gmail.com>
Sat, 23 Jun 2012 09:39:55 +0000 (19:39 +1000)
Doc/library/inspect.rst
Doc/whatsnew/3.3.rst
Lib/inspect.py
Lib/test/test_inspect.py
Misc/NEWS

index 611f780254fc89b4d6996445be602d0bdfccd4ef..04b724b737834e7445d7b8ba938308d6b37c4a38 100644 (file)
@@ -497,6 +497,22 @@ Classes and functions
    .. versionadded:: 3.2
 
 
+.. function:: getclosurevars(func)
+
+   Get the mapping of external name references in a Python function or
+   method *func* to their current values. A
+   :term:`named tuple` ``ClosureVars(nonlocals, globals, builtins, unbound)``
+   is returned. *nonlocals* maps referenced names to lexical closure
+   variables, *globals* to the function's module globals and *builtins* to
+   the builtins visible from the function body. *unbound* is the set of names
+   referenced in the function that could not be resolved at all given the
+   current module globals and builtins.
+
+   :exc:`TypeError` is raised if *func* is not a Python function or method.
+
+   .. versionadded:: 3.3
+
+
 .. _inspect-stack:
 
 The interpreter stack
index 974cc7130289343da34752e9db5f4370c157be0c..592f9c99253567ae7a0e0381967a3fb104b2a40a 100644 (file)
@@ -1027,6 +1027,16 @@ parameter to control parameters of the secure channel.
 (Contributed by Sijin Joseph in :issue:`8808`)
 
 
+inspect
+-------
+
+A new :func:`~inspect.getclosurevars` function has been added. This function
+reports the current binding of all names referenced from the function body and
+where those names were resolved, making it easier to verify correct internal
+state when testing code that relies on stateful closures.
+
+(Contributed by Meador Inge and Nick Coghlan in :issue:`13062`)
+
 io
 --
 
index 484c9c358d30a236f248e3a1bcd68bf9d14d112d..dd2de6422123594eaba4c0db3bae11a0dc3b66dc 100644 (file)
@@ -42,6 +42,7 @@ import tokenize
 import types
 import warnings
 import functools
+import builtins
 from operator import attrgetter
 from collections import namedtuple, OrderedDict
 
@@ -1036,6 +1037,59 @@ def getcallargs(func, *positional, **named):
         _missing_arguments(f_name, kwonlyargs, False, arg2value)
     return arg2value
 
+ClosureVars = namedtuple('ClosureVars', 'nonlocals globals builtins unbound')
+
+def getclosurevars(func):
+    """
+    Get the mapping of free variables to their current values.
+
+    Returns a named tuple of dics mapping the current nonlocal, global
+    and builtin references as seen by the body of the function. A final
+    set of unbound names that could not be resolved is also provided.
+    """
+
+    if ismethod(func):
+        func = func.__func__
+
+    if not isfunction(func):
+        raise TypeError("'{!r}' is not a Python function".format(func))
+
+    code = func.__code__
+    # Nonlocal references are named in co_freevars and resolved
+    # by looking them up in __closure__ by positional index
+    if func.__closure__ is None:
+        nonlocal_vars = {}
+    else:
+        nonlocal_vars = {
+            var : cell.cell_contents
+            for var, cell in zip(code.co_freevars, func.__closure__)
+       }
+
+    # Global and builtin references are named in co_names and resolved
+    # by looking them up in __globals__ or __builtins__
+    global_ns = func.__globals__
+    builtin_ns = global_ns.get("__builtins__", builtins.__dict__)
+    if ismodule(builtin_ns):
+        builtin_ns = builtin_ns.__dict__
+    global_vars = {}
+    builtin_vars = {}
+    unbound_names = set()
+    for name in code.co_names:
+        if name in ("None", "True", "False"):
+            # Because these used to be builtins instead of keywords, they
+            # may still show up as name references. We ignore them.
+            continue
+        try:
+            global_vars[name] = global_ns[name]
+        except KeyError:
+            try:
+                builtin_vars[name] = builtin_ns[name]
+            except KeyError:
+                unbound_names.add(name)
+
+    return ClosureVars(nonlocal_vars, global_vars,
+                       builtin_vars, unbound_names)
+
 # -------------------------------------------------- stack frame extraction
 
 Traceback = namedtuple('Traceback', 'filename lineno function code_context index')
index 53c947fc9de2d2a0d276366f279f90d6ee7e4140..83277215a8bdd9c5c1b1abaadb3c0677bb0672db 100644 (file)
@@ -665,6 +665,105 @@ class TestClassesAndFunctions(unittest.TestCase):
         self.assertIn(('f', b.f), inspect.getmembers(b, inspect.ismethod))
 
 
+_global_ref = object()
+class TestGetClosureVars(unittest.TestCase):
+
+    def test_name_resolution(self):
+        # Basic test of the 4 different resolution mechanisms
+        def f(nonlocal_ref):
+            def g(local_ref):
+                print(local_ref, nonlocal_ref, _global_ref, unbound_ref)
+            return g
+        _arg = object()
+        nonlocal_vars = {"nonlocal_ref": _arg}
+        global_vars = {"_global_ref": _global_ref}
+        builtin_vars = {"print": print}
+        unbound_names = {"unbound_ref"}
+        expected = inspect.ClosureVars(nonlocal_vars, global_vars,
+                                       builtin_vars, unbound_names)
+        self.assertEqual(inspect.getclosurevars(f(_arg)), expected)
+
+    def test_generator_closure(self):
+        def f(nonlocal_ref):
+            def g(local_ref):
+                print(local_ref, nonlocal_ref, _global_ref, unbound_ref)
+                yield
+            return g
+        _arg = object()
+        nonlocal_vars = {"nonlocal_ref": _arg}
+        global_vars = {"_global_ref": _global_ref}
+        builtin_vars = {"print": print}
+        unbound_names = {"unbound_ref"}
+        expected = inspect.ClosureVars(nonlocal_vars, global_vars,
+                                       builtin_vars, unbound_names)
+        self.assertEqual(inspect.getclosurevars(f(_arg)), expected)
+
+    def test_method_closure(self):
+        class C:
+            def f(self, nonlocal_ref):
+                def g(local_ref):
+                    print(local_ref, nonlocal_ref, _global_ref, unbound_ref)
+                return g
+        _arg = object()
+        nonlocal_vars = {"nonlocal_ref": _arg}
+        global_vars = {"_global_ref": _global_ref}
+        builtin_vars = {"print": print}
+        unbound_names = {"unbound_ref"}
+        expected = inspect.ClosureVars(nonlocal_vars, global_vars,
+                                       builtin_vars, unbound_names)
+        self.assertEqual(inspect.getclosurevars(C().f(_arg)), expected)
+
+    def test_nonlocal_vars(self):
+        # More complex tests of nonlocal resolution
+        def _nonlocal_vars(f):
+            return inspect.getclosurevars(f).nonlocals
+
+        def make_adder(x):
+            def add(y):
+                return x + y
+            return add
+
+        def curry(func, arg1):
+            return lambda arg2: func(arg1, arg2)
+
+        def less_than(a, b):
+            return a < b
+
+        # The infamous Y combinator.
+        def Y(le):
+            def g(f):
+                return le(lambda x: f(f)(x))
+            Y.g_ref = g
+            return g(g)
+
+        def check_y_combinator(func):
+            self.assertEqual(_nonlocal_vars(func), {'f': Y.g_ref})
+
+        inc = make_adder(1)
+        add_two = make_adder(2)
+        greater_than_five = curry(less_than, 5)
+
+        self.assertEqual(_nonlocal_vars(inc), {'x': 1})
+        self.assertEqual(_nonlocal_vars(add_two), {'x': 2})
+        self.assertEqual(_nonlocal_vars(greater_than_five),
+                         {'arg1': 5, 'func': less_than})
+        self.assertEqual(_nonlocal_vars((lambda x: lambda y: x + y)(3)),
+                         {'x': 3})
+        Y(check_y_combinator)
+
+    def test_getclosurevars_empty(self):
+        def foo(): pass
+        _empty = inspect.ClosureVars({}, {}, {}, set())
+        self.assertEqual(inspect.getclosurevars(lambda: True), _empty)
+        self.assertEqual(inspect.getclosurevars(foo), _empty)
+
+    def test_getclosurevars_error(self):
+        class T: pass
+        self.assertRaises(TypeError, inspect.getclosurevars, 1)
+        self.assertRaises(TypeError, inspect.getclosurevars, list)
+        self.assertRaises(TypeError, inspect.getclosurevars, {})
+
+
 class TestGetcallargsFunctions(unittest.TestCase):
 
     def assertEqualCallArgs(self, func, call_params_string, locs=None):
@@ -2100,7 +2199,7 @@ def test_main():
         TestGetcallargsFunctions, TestGetcallargsMethods,
         TestGetcallargsUnboundMethods, TestGetattrStatic, TestGetGeneratorState,
         TestNoEOL, TestSignatureObject, TestSignatureBind, TestParameterObject,
-        TestBoundArguments
+        TestBoundArguments, TestGetClosureVars
     )
 
 if __name__ == "__main__":
index b6f735fc7a36600850d2d24e39bfe5149be17536..0c69f3506ed3c292d9d413dd2f6d71fa77303af0 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -40,7 +40,10 @@ Core and Builtins
 Library
 -------
 
-- Issues #11024: Fixes and additional tests for Time2Internaldate.
+- Issue #13062: Added inspect.getclosurevars to simplify testing stateful
+  closures
+
+- Issue #11024: Fixes and additional tests for Time2Internaldate.
 
 - Issue #14626: Large refactoring of functions / parameters in the os module.
   Many functions now support "dir_fd" and "follow_symlinks" parameters;