]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-104377: fix cell in comprehension that is free in outer scope (#104394)
authorCarl Meyer <carl@oddbird.net>
Thu, 11 May 2023 23:48:21 +0000 (17:48 -0600)
committerGitHub <noreply@github.com>
Thu, 11 May 2023 23:48:21 +0000 (16:48 -0700)
Lib/test/test_listcomps.py
Python/compile.c

index 1cc202bb599ae66d087d8841cd7ef8490cc60dad..b2a3b7ea3e49b62454c1cfbb5766a1794dc3d5ae 100644 (file)
@@ -117,15 +117,15 @@ class ListComprehensionTest(unittest.TestCase):
                     newcode = code
                     def get_output(moddict, name):
                         return moddict[name]
-                ns = ns or {}
+                newns = ns.copy() if ns else {}
                 try:
-                    exec(newcode, ns)
+                    exec(newcode, newns)
                 except raises as e:
                     # We care about e.g. NameError vs UnboundLocalError
                     self.assertIs(type(e), raises)
                 else:
                     for k, v in (outputs or {}).items():
-                        self.assertEqual(get_output(ns, k), v)
+                        self.assertEqual(get_output(newns, k), v)
 
     def test_lambdas_with_iteration_var_as_default(self):
         code = """
@@ -180,6 +180,26 @@ class ListComprehensionTest(unittest.TestCase):
             z = [x() for x in items]
         """
         outputs = {"z": [2, 2, 2, 2, 2]}
+        self._check_in_scopes(code, outputs, scopes=["module", "function"])
+
+    def test_cell_inner_free_outer(self):
+        code = """
+            def f():
+                return [lambda: x for x in (x, [1])[1]]
+            x = ...
+            y = [fn() for fn in f()]
+        """
+        outputs = {"y": [1]}
+        self._check_in_scopes(code, outputs, scopes=["module", "function"])
+
+    def test_free_inner_cell_outer(self):
+        code = """
+            g = 2
+            def f():
+                return g
+            y = [g for x in [1]]
+        """
+        outputs = {"y": [2]}
         self._check_in_scopes(code, outputs)
 
     def test_inner_cell_shadows_outer_redefined(self):
@@ -203,6 +223,37 @@ class ListComprehensionTest(unittest.TestCase):
         outputs = {"x": -1}
         self._check_in_scopes(code, outputs, ns={"g": -1})
 
+    def test_explicit_global(self):
+        code = """
+            global g
+            x = g
+            g = 2
+            items = [g for g in [1]]
+            y = g
+        """
+        outputs = {"x": 1, "y": 2, "items": [1]}
+        self._check_in_scopes(code, outputs, ns={"g": 1})
+
+    def test_explicit_global_2(self):
+        code = """
+            global g
+            x = g
+            g = 2
+            items = [g for x in [1]]
+            y = g
+        """
+        outputs = {"x": 1, "y": 2, "items": [2]}
+        self._check_in_scopes(code, outputs, ns={"g": 1})
+
+    def test_explicit_global_3(self):
+        code = """
+            global g
+            fns = [lambda: g for g in [2]]
+            items = [fn() for fn in fns]
+        """
+        outputs = {"items": [2]}
+        self._check_in_scopes(code, outputs, ns={"g": 1})
+
     def test_assignment_expression(self):
         code = """
             x = -1
@@ -250,7 +301,7 @@ class ListComprehensionTest(unittest.TestCase):
             g()
         """
         outputs = {"x": 1}
-        self._check_in_scopes(code, outputs)
+        self._check_in_scopes(code, outputs, scopes=["module", "function"])
 
     def test_introspecting_frame_locals(self):
         code = """
index 941c6e9d4fdbb7ba0b17b8e338e1fea3ff446ae8..f8d0197e9f0682e982caa40f432d4837fd998552 100644 (file)
@@ -5028,14 +5028,19 @@ push_inlined_comprehension_state(struct compiler *c, location loc,
             long scope = (symbol >> SCOPE_OFFSET) & SCOPE_MASK;
             PyObject *outv = PyDict_GetItemWithError(c->u->u_ste->ste_symbols, k);
             if (outv == NULL) {
+                assert(PyErr_Occurred());
                 return ERROR;
             }
             assert(PyLong_Check(outv));
             long outsc = (PyLong_AS_LONG(outv) >> SCOPE_OFFSET) & SCOPE_MASK;
-            if (scope != outsc) {
+            if (scope != outsc && !(scope == CELL && outsc == FREE)) {
                 // If a name has different scope inside than outside the
                 // comprehension, we need to temporarily handle it with the
-                // right scope while compiling the comprehension.
+                // right scope while compiling the comprehension. (If it's free
+                // in outer scope and cell in inner scope, we can't treat it as
+                // both cell and free in the same function, but treating it as
+                // free throughout is fine; it's *_DEREF either way.)
+
                 if (state->temp_symbols == NULL) {
                     state->temp_symbols = PyDict_New();
                     if (state->temp_symbols == NULL) {
@@ -5071,7 +5076,11 @@ push_inlined_comprehension_state(struct compiler *c, location loc,
                 // comprehension and restore the original one after
                 ADDOP_NAME(c, loc, LOAD_FAST_AND_CLEAR, k, varnames);
                 if (scope == CELL) {
-                    ADDOP_NAME(c, loc, MAKE_CELL, k, cellvars);
+                    if (outsc == FREE) {
+                        ADDOP_NAME(c, loc, MAKE_CELL, k, freevars);
+                    } else {
+                        ADDOP_NAME(c, loc, MAKE_CELL, k, cellvars);
+                    }
                 }
                 if (PyList_Append(state->pushed_locals, k) < 0) {
                     return ERROR;