]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-123344: Add missing ast optimizations for PEP 696 (#123377)
authorBogdan Romanyuk <65823030+wrongnull@users.noreply.github.com>
Wed, 28 Aug 2024 13:38:56 +0000 (16:38 +0300)
committerGitHub <noreply@github.com>
Wed, 28 Aug 2024 13:38:56 +0000 (06:38 -0700)
Co-authored-by: Kirill Podoprigora <kirill.bast9@mail.ru>
Co-authored-by: Jelle Zijlstra <jelle.zijlstra@gmail.com>
Lib/test/test_ast/test_ast.py
Misc/NEWS.d/next/Core and Builtins/2024-08-27-13-16-40.gh-issue-123344.56Or78.rst [new file with mode: 0644]
Python/ast_opt.c

index c37b24adcc71550d9b1d72faaa445f32a8b3bf09..77596eca5d8b74798beae5fb778fa85fbc7ae55b 100644 (file)
@@ -3062,8 +3062,8 @@ class ASTOptimiziationTests(unittest.TestCase):
     def wrap_expr(self, expr):
         return ast.Module(body=[ast.Expr(value=expr)])
 
-    def wrap_for(self, for_statement):
-        return ast.Module(body=[for_statement])
+    def wrap_statement(self, statement):
+        return ast.Module(body=[statement])
 
     def assert_ast(self, code, non_optimized_target, optimized_target):
         non_optimized_tree = ast.parse(code, optimize=-1)
@@ -3090,16 +3090,16 @@ class ASTOptimiziationTests(unittest.TestCase):
             f"{ast.dump(optimized_tree)}",
         )
 
+    def create_binop(self, operand, left=ast.Constant(1), right=ast.Constant(1)):
+            return ast.BinOp(left=left, op=self.binop[operand], right=right)
+
     def test_folding_binop(self):
         code = "1 %s 1"
         operators = self.binop.keys()
 
-        def create_binop(operand, left=ast.Constant(1), right=ast.Constant(1)):
-            return ast.BinOp(left=left, op=self.binop[operand], right=right)
-
         for op in operators:
             result_code = code % op
-            non_optimized_target = self.wrap_expr(create_binop(op))
+            non_optimized_target = self.wrap_expr(self.create_binop(op))
             optimized_target = self.wrap_expr(ast.Constant(value=eval(result_code)))
 
             with self.subTest(
@@ -3111,7 +3111,7 @@ class ASTOptimiziationTests(unittest.TestCase):
 
         # Multiplication of constant tuples must be folded
         code = "(1,) * 3"
-        non_optimized_target = self.wrap_expr(create_binop("*", ast.Tuple(elts=[ast.Constant(value=1)]), ast.Constant(value=3)))
+        non_optimized_target = self.wrap_expr(self.create_binop("*", ast.Tuple(elts=[ast.Constant(value=1)]), ast.Constant(value=3)))
         optimized_target = self.wrap_expr(ast.Constant(eval(code)))
 
         self.assert_ast(code, non_optimized_target, optimized_target)
@@ -3222,12 +3222,12 @@ class ASTOptimiziationTests(unittest.TestCase):
         ]
 
         for left, right, ast_cls, optimized_iter in braces:
-            non_optimized_target = self.wrap_for(ast.For(
+            non_optimized_target = self.wrap_statement(ast.For(
                 target=ast.Name(id="_", ctx=ast.Store()),
                 iter=ast_cls(elts=[ast.Constant(1)]),
                 body=[ast.Pass()]
             ))
-            optimized_target = self.wrap_for(ast.For(
+            optimized_target = self.wrap_statement(ast.For(
                 target=ast.Name(id="_", ctx=ast.Store()),
                 iter=ast.Constant(value=optimized_iter),
                 body=[ast.Pass()]
@@ -3245,6 +3245,92 @@ class ASTOptimiziationTests(unittest.TestCase):
 
         self.assert_ast(code, non_optimized_target, optimized_target)
 
+    def test_folding_type_param_in_function_def(self):
+        code = "def foo[%s = 1 + 1](): pass"
+
+        unoptimized_binop = self.create_binop("+")
+        unoptimized_type_params = [
+            ("T", "T", ast.TypeVar),
+            ("**P", "P", ast.ParamSpec),
+            ("*Ts", "Ts", ast.TypeVarTuple),
+        ]
+
+        for type, name, type_param in unoptimized_type_params:
+            result_code = code % type
+            optimized_target = self.wrap_statement(
+                ast.FunctionDef(
+                    name='foo',
+                    args=ast.arguments(),
+                    body=[ast.Pass()],
+                    type_params=[type_param(name=name, default_value=ast.Constant(2))]
+                )
+            )
+            non_optimized_target = self.wrap_statement(
+                ast.FunctionDef(
+                    name='foo',
+                    args=ast.arguments(),
+                    body=[ast.Pass()],
+                    type_params=[type_param(name=name, default_value=unoptimized_binop)]
+                )
+            )
+            self.assert_ast(result_code, non_optimized_target, optimized_target)
+
+    def test_folding_type_param_in_class_def(self):
+        code = "class foo[%s = 1 + 1]: pass"
+
+        unoptimized_binop = self.create_binop("+")
+        unoptimized_type_params = [
+            ("T", "T", ast.TypeVar),
+            ("**P", "P", ast.ParamSpec),
+            ("*Ts", "Ts", ast.TypeVarTuple),
+        ]
+
+        for type, name, type_param in unoptimized_type_params:
+            result_code = code % type
+            optimized_target = self.wrap_statement(
+                ast.ClassDef(
+                    name='foo',
+                    body=[ast.Pass()],
+                    type_params=[type_param(name=name, default_value=ast.Constant(2))]
+                )
+            )
+            non_optimized_target = self.wrap_statement(
+                ast.ClassDef(
+                    name='foo',
+                    body=[ast.Pass()],
+                    type_params=[type_param(name=name, default_value=unoptimized_binop)]
+                )
+            )
+            self.assert_ast(result_code, non_optimized_target, optimized_target)
+
+    def test_folding_type_param_in_type_alias(self):
+        code = "type foo[%s = 1 + 1] = 1"
+
+        unoptimized_binop = self.create_binop("+")
+        unoptimized_type_params = [
+            ("T", "T", ast.TypeVar),
+            ("**P", "P", ast.ParamSpec),
+            ("*Ts", "Ts", ast.TypeVarTuple),
+        ]
+
+        for type, name, type_param in unoptimized_type_params:
+            result_code = code % type
+            optimized_target = self.wrap_statement(
+                ast.TypeAlias(
+                    name=ast.Name(id='foo', ctx=ast.Store()),
+                    type_params=[type_param(name=name, default_value=ast.Constant(2))],
+                    value=ast.Constant(value=1),
+                )
+            )
+            non_optimized_target = self.wrap_statement(
+                ast.TypeAlias(
+                    name=ast.Name(id='foo', ctx=ast.Store()),
+                    type_params=[type_param(name=name, default_value=unoptimized_binop)],
+                    value=ast.Constant(value=1),
+                )
+            )
+            self.assert_ast(result_code, non_optimized_target, optimized_target)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/Misc/NEWS.d/next/Core and Builtins/2024-08-27-13-16-40.gh-issue-123344.56Or78.rst b/Misc/NEWS.d/next/Core and Builtins/2024-08-27-13-16-40.gh-issue-123344.56Or78.rst
new file mode 100644 (file)
index 0000000..b8b373d
--- /dev/null
@@ -0,0 +1 @@
+Add AST optimizations for type parameter defaults.
index d7a26e64150e55e5e2781c6a0f54feb956a23c69..503715e7405aefef41a674b4eb6c69e631c4c408 100644 (file)
@@ -1087,10 +1087,13 @@ astfold_type_param(type_param_ty node_, PyArena *ctx_, _PyASTOptimizeState *stat
     switch (node_->kind) {
         case TypeVar_kind:
             CALL_OPT(astfold_expr, expr_ty, node_->v.TypeVar.bound);
+            CALL_OPT(astfold_expr, expr_ty, node_->v.TypeVar.default_value);
             break;
         case ParamSpec_kind:
+            CALL_OPT(astfold_expr, expr_ty, node_->v.ParamSpec.default_value);
             break;
         case TypeVarTuple_kind:
+            CALL_OPT(astfold_expr, expr_ty, node_->v.TypeVarTuple.default_value);
             break;
     }
     return 1;