]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-126835: Move constant unaryop & binop folding to CFG (#129550)
authorYan Yanchii <yyanchiy@gmail.com>
Fri, 21 Feb 2025 17:54:22 +0000 (18:54 +0100)
committerGitHub <noreply@github.com>
Fri, 21 Feb 2025 17:54:22 +0000 (17:54 +0000)
Lib/test/test_ast/test_ast.py
Lib/test/test_ast/utils.py
Lib/test/test_builtin.py
Lib/test/test_peepholer.py
Python/ast_opt.c
Python/flowgraph.c

index 07a444856375527c925b5e76085ba27ecb061c30..1c202f82e941e83f13710185590a095ce09fa6a2 100644 (file)
@@ -154,18 +154,17 @@ class AST_Tests(unittest.TestCase):
                         self.assertEqual(res.body[0].value.id, expected)
 
     def test_optimization_levels_const_folding(self):
-        folded = ('Expr', (1, 0, 1, 5), ('Constant', (1, 0, 1, 5), 3, None))
-        not_folded = ('Expr', (1, 0, 1, 5),
-                         ('BinOp', (1, 0, 1, 5),
-                             ('Constant', (1, 0, 1, 1), 1, None),
-                             ('Add',),
-                             ('Constant', (1, 4, 1, 5), 2, None)))
+        folded = ('Expr', (1, 0, 1, 6), ('Constant', (1, 0, 1, 6), (1, 2), None))
+        not_folded = ('Expr', (1, 0, 1, 6),
+                         ('Tuple', (1, 0, 1, 6),
+                             [('Constant', (1, 1, 1, 2), 1, None),
+                             ('Constant', (1, 4, 1, 5), 2, None)], ('Load',)))
 
         cases = [(-1, not_folded), (0, not_folded), (1, folded), (2, folded)]
         for (optval, expected) in cases:
             with self.subTest(optval=optval):
-                tree1 = ast.parse("1 + 2", optimize=optval)
-                tree2 = ast.parse(ast.parse("1 + 2"), optimize=optval)
+                tree1 = ast.parse("(1, 2)", optimize=optval)
+                tree2 = ast.parse(ast.parse("(1, 2)"), optimize=optval)
                 for tree in [tree1, tree2]:
                     res = to_tuple(tree.body[0])
                     self.assertEqual(res, expected)
@@ -3089,27 +3088,6 @@ class ASTMainTests(unittest.TestCase):
 
 
 class ASTOptimiziationTests(unittest.TestCase):
-    binop = {
-        "+": ast.Add(),
-        "-": ast.Sub(),
-        "*": ast.Mult(),
-        "/": ast.Div(),
-        "%": ast.Mod(),
-        "<<": ast.LShift(),
-        ">>": ast.RShift(),
-        "|": ast.BitOr(),
-        "^": ast.BitXor(),
-        "&": ast.BitAnd(),
-        "//": ast.FloorDiv(),
-        "**": ast.Pow(),
-    }
-
-    unaryop = {
-        "~": ast.Invert(),
-        "+": ast.UAdd(),
-        "-": ast.USub(),
-    }
-
     def wrap_expr(self, expr):
         return ast.Module(body=[ast.Expr(value=expr)])
 
@@ -3141,83 +3119,6 @@ class ASTOptimiziationTests(unittest.TestCase):
             f"{ast.dump(optimized_tree)}",
         )
 
-    def create_binop(self, operand, left=ast.Constant(1), right=ast.Constant(1)):
-            return ast.BinOp(left=left, op=self.binop[operand], right=right)
-
-    def test_folding_binop(self):
-        code = "1 %s 1"
-        operators = self.binop.keys()
-
-        for op in operators:
-            result_code = code % op
-            non_optimized_target = self.wrap_expr(self.create_binop(op))
-            optimized_target = self.wrap_expr(ast.Constant(value=eval(result_code)))
-
-            with self.subTest(
-                result_code=result_code,
-                non_optimized_target=non_optimized_target,
-                optimized_target=optimized_target
-            ):
-                self.assert_ast(result_code, non_optimized_target, optimized_target)
-
-        # Multiplication of constant tuples must be folded
-        code = "(1,) * 3"
-        non_optimized_target = self.wrap_expr(self.create_binop("*", ast.Tuple(elts=[ast.Constant(value=1)]), ast.Constant(value=3)))
-        optimized_target = self.wrap_expr(ast.Constant(eval(code)))
-
-        self.assert_ast(code, non_optimized_target, optimized_target)
-
-    def test_folding_unaryop(self):
-        code = "%s1"
-        operators = self.unaryop.keys()
-
-        def create_unaryop(operand):
-            return ast.UnaryOp(op=self.unaryop[operand], operand=ast.Constant(1))
-
-        for op in operators:
-            result_code = code % op
-            non_optimized_target = self.wrap_expr(create_unaryop(op))
-            optimized_target = self.wrap_expr(ast.Constant(eval(result_code)))
-
-            with self.subTest(
-                result_code=result_code,
-                non_optimized_target=non_optimized_target,
-                optimized_target=optimized_target
-            ):
-                self.assert_ast(result_code, non_optimized_target, optimized_target)
-
-    def test_folding_not(self):
-        code = "not (1 %s (1,))"
-        operators = {
-            "in": ast.In(),
-            "is": ast.Is(),
-        }
-        opt_operators = {
-            "is": ast.IsNot(),
-            "in": ast.NotIn(),
-        }
-
-        def create_notop(operand):
-            return ast.UnaryOp(op=ast.Not(), operand=ast.Compare(
-                left=ast.Constant(value=1),
-                ops=[operators[operand]],
-                comparators=[ast.Tuple(elts=[ast.Constant(value=1)])]
-            ))
-
-        for op in operators.keys():
-            result_code = code % op
-            non_optimized_target = self.wrap_expr(create_notop(op))
-            optimized_target = self.wrap_expr(
-                ast.Compare(left=ast.Constant(1), ops=[opt_operators[op]], comparators=[ast.Constant(value=(1,))])
-            )
-
-            with self.subTest(
-                result_code=result_code,
-                non_optimized_target=non_optimized_target,
-                optimized_target=optimized_target
-            ):
-                self.assert_ast(result_code, non_optimized_target, optimized_target)
-
     def test_folding_format(self):
         code = "'%s' % (a,)"
 
@@ -3247,9 +3148,9 @@ class ASTOptimiziationTests(unittest.TestCase):
         self.assert_ast(code, non_optimized_target, optimized_target)
 
     def test_folding_type_param_in_function_def(self):
-        code = "def foo[%s = 1 + 1](): pass"
+        code = "def foo[%s = (1, 2)](): pass"
 
-        unoptimized_binop = self.create_binop("+")
+        unoptimized_tuple = ast.Tuple(elts=[ast.Constant(1), ast.Constant(2)])
         unoptimized_type_params = [
             ("T", "T", ast.TypeVar),
             ("**P", "P", ast.ParamSpec),
@@ -3263,7 +3164,7 @@ class ASTOptimiziationTests(unittest.TestCase):
                     name='foo',
                     args=ast.arguments(),
                     body=[ast.Pass()],
-                    type_params=[type_param(name=name, default_value=ast.Constant(2))]
+                    type_params=[type_param(name=name, default_value=ast.Constant((1, 2)))]
                 )
             )
             non_optimized_target = self.wrap_statement(
@@ -3271,15 +3172,15 @@ class ASTOptimiziationTests(unittest.TestCase):
                     name='foo',
                     args=ast.arguments(),
                     body=[ast.Pass()],
-                    type_params=[type_param(name=name, default_value=unoptimized_binop)]
+                    type_params=[type_param(name=name, default_value=unoptimized_tuple)]
                 )
             )
             self.assert_ast(result_code, non_optimized_target, optimized_target)
 
     def test_folding_type_param_in_class_def(self):
-        code = "class foo[%s = 1 + 1]: pass"
+        code = "class foo[%s = (1, 2)]: pass"
 
-        unoptimized_binop = self.create_binop("+")
+        unoptimized_tuple = ast.Tuple(elts=[ast.Constant(1), ast.Constant(2)])
         unoptimized_type_params = [
             ("T", "T", ast.TypeVar),
             ("**P", "P", ast.ParamSpec),
@@ -3292,22 +3193,22 @@ class ASTOptimiziationTests(unittest.TestCase):
                 ast.ClassDef(
                     name='foo',
                     body=[ast.Pass()],
-                    type_params=[type_param(name=name, default_value=ast.Constant(2))]
+                    type_params=[type_param(name=name, default_value=ast.Constant((1, 2)))]
                 )
             )
             non_optimized_target = self.wrap_statement(
                 ast.ClassDef(
                     name='foo',
                     body=[ast.Pass()],
-                    type_params=[type_param(name=name, default_value=unoptimized_binop)]
+                    type_params=[type_param(name=name, default_value=unoptimized_tuple)]
                 )
             )
             self.assert_ast(result_code, non_optimized_target, optimized_target)
 
     def test_folding_type_param_in_type_alias(self):
-        code = "type foo[%s = 1 + 1] = 1"
+        code = "type foo[%s = (1, 2)] = 1"
 
-        unoptimized_binop = self.create_binop("+")
+        unoptimized_tuple = ast.Tuple(elts=[ast.Constant(1), ast.Constant(2)])
         unoptimized_type_params = [
             ("T", "T", ast.TypeVar),
             ("**P", "P", ast.ParamSpec),
@@ -3319,19 +3220,80 @@ class ASTOptimiziationTests(unittest.TestCase):
             optimized_target = self.wrap_statement(
                 ast.TypeAlias(
                     name=ast.Name(id='foo', ctx=ast.Store()),
-                    type_params=[type_param(name=name, default_value=ast.Constant(2))],
+                    type_params=[type_param(name=name, default_value=ast.Constant((1, 2)))],
                     value=ast.Constant(value=1),
                 )
             )
             non_optimized_target = self.wrap_statement(
                 ast.TypeAlias(
                     name=ast.Name(id='foo', ctx=ast.Store()),
-                    type_params=[type_param(name=name, default_value=unoptimized_binop)],
+                    type_params=[type_param(name=name, default_value=unoptimized_tuple)],
                     value=ast.Constant(value=1),
                 )
             )
             self.assert_ast(result_code, non_optimized_target, optimized_target)
 
+    def test_folding_match_case_allowed_expressions(self):
+        def get_match_case_values(node):
+            result = []
+            if isinstance(node, ast.Constant):
+                result.append(node.value)
+            elif isinstance(node, ast.MatchValue):
+                result.extend(get_match_case_values(node.value))
+            elif isinstance(node, ast.MatchMapping):
+                for key in node.keys:
+                    result.extend(get_match_case_values(key))
+            elif isinstance(node, ast.MatchSequence):
+                for pat in node.patterns:
+                    result.extend(get_match_case_values(pat))
+            else:
+                self.fail(f"Unexpected node {node}")
+            return result
+
+        tests = [
+            ("-0", [0]),
+            ("-0.1", [-0.1]),
+            ("-0j", [complex(0, 0)]),
+            ("-0.1j", [complex(0, -0.1)]),
+            ("1 + 2j", [complex(1, 2)]),
+            ("1 - 2j", [complex(1, -2)]),
+            ("1.1 + 2.1j", [complex(1.1, 2.1)]),
+            ("1.1 - 2.1j", [complex(1.1, -2.1)]),
+            ("-0 + 1j", [complex(0, 1)]),
+            ("-0 - 1j", [complex(0, -1)]),
+            ("-0.1 + 1.1j", [complex(-0.1, 1.1)]),
+            ("-0.1 - 1.1j", [complex(-0.1, -1.1)]),
+            ("{-0: 0}", [0]),
+            ("{-0.1: 0}", [-0.1]),
+            ("{-0j: 0}", [complex(0, 0)]),
+            ("{-0.1j: 0}", [complex(0, -0.1)]),
+            ("{1 + 2j: 0}", [complex(1, 2)]),
+            ("{1 - 2j: 0}", [complex(1, -2)]),
+            ("{1.1 + 2.1j: 0}", [complex(1.1, 2.1)]),
+            ("{1.1 - 2.1j: 0}", [complex(1.1, -2.1)]),
+            ("{-0 + 1j: 0}", [complex(0, 1)]),
+            ("{-0 - 1j: 0}", [complex(0, -1)]),
+            ("{-0.1 + 1.1j: 0}", [complex(-0.1, 1.1)]),
+            ("{-0.1 - 1.1j: 0}", [complex(-0.1, -1.1)]),
+            ("{-0: 0, 0 + 1j: 0, 0.1 + 1j: 0}", [0, complex(0, 1), complex(0.1, 1)]),
+            ("[-0, -0.1, -0j, -0.1j]", [0, -0.1, complex(0, 0), complex(0, -0.1)]),
+            ("[[[[-0, -0.1, -0j, -0.1j]]]]", [0, -0.1, complex(0, 0), complex(0, -0.1)]),
+            ("[[-0, -0.1], -0j, -0.1j]", [0, -0.1, complex(0, 0), complex(0, -0.1)]),
+            ("[[-0, -0.1], [-0j, -0.1j]]", [0, -0.1, complex(0, 0), complex(0, -0.1)]),
+            ("(-0, -0.1, -0j, -0.1j)", [0, -0.1, complex(0, 0), complex(0, -0.1)]),
+            ("((((-0, -0.1, -0j, -0.1j))))", [0, -0.1, complex(0, 0), complex(0, -0.1)]),
+            ("((-0, -0.1), -0j, -0.1j)", [0, -0.1, complex(0, 0), complex(0, -0.1)]),
+            ("((-0, -0.1), (-0j, -0.1j))", [0, -0.1, complex(0, 0), complex(0, -0.1)]),
+        ]
+        for match_expr, constants in tests:
+            with self.subTest(match_expr):
+                src = f"match 0:\n\t case {match_expr}: pass"
+                tree = ast.parse(src, optimize=1)
+                match_stmt = tree.body[0]
+                case = match_stmt.cases[0]
+                values = get_match_case_values(case.pattern)
+                self.assertListEqual(constants, values)
+
 
 if __name__ == '__main__':
     if len(sys.argv) > 1 and sys.argv[1] == '--snapshot-update':
index 145e89ee94e9352356bd82e886d8b46184bb63f9..e7054f3f710f1f0990a4ecc66c210530ab929cc7 100644 (file)
@@ -1,5 +1,5 @@
 def to_tuple(t):
-    if t is None or isinstance(t, (str, int, complex, float, bytes)) or t is Ellipsis:
+    if t is None or isinstance(t, (str, int, complex, float, bytes, tuple)) or t is Ellipsis:
         return t
     elif isinstance(t, list):
         return [to_tuple(e) for e in t]
index 913d007a126d72beb2c3a3e1068c8856475ebce0..d15964fe9dd88b079c4cfa537bc69633ba05955e 100644 (file)
@@ -555,7 +555,7 @@ class BuiltinTest(ComplexesAreIdenticalMixin, unittest.TestCase):
         self.assertEqual(type(glob['ticker']()), AsyncGeneratorType)
 
     def test_compile_ast(self):
-        args = ("a*(1+2)", "f.py", "exec")
+        args = ("a*(1,2)", "f.py", "exec")
         raw = compile(*args, flags = ast.PyCF_ONLY_AST).body[0]
         opt1 = compile(*args, flags = ast.PyCF_OPTIMIZED_AST).body[0]
         opt2 = compile(ast.parse(args[0]), *args[1:], flags = ast.PyCF_OPTIMIZED_AST).body[0]
@@ -566,17 +566,14 @@ class BuiltinTest(ComplexesAreIdenticalMixin, unittest.TestCase):
             self.assertIsInstance(tree.value.left, ast.Name)
             self.assertEqual(tree.value.left.id, 'a')
 
-        raw_right = raw.value.right  # expect BinOp(1, '+', 2)
-        self.assertIsInstance(raw_right, ast.BinOp)
-        self.assertIsInstance(raw_right.left, ast.Constant)
-        self.assertEqual(raw_right.left.value, 1)
-        self.assertIsInstance(raw_right.right, ast.Constant)
-        self.assertEqual(raw_right.right.value, 2)
+        raw_right = raw.value.right  # expect Tuple((1, 2))
+        self.assertIsInstance(raw_right, ast.Tuple)
+        self.assertListEqual([elt.value for elt in raw_right.elts], [1, 2])
 
         for opt in [opt1, opt2]:
-            opt_right = opt.value.right  # expect Constant(3)
+            opt_right = opt.value.right  # expect Constant((1,2))
             self.assertIsInstance(opt_right, ast.Constant)
-            self.assertEqual(opt_right.value, 3)
+            self.assertEqual(opt_right.value, (1, 2))
 
     def test_delattr(self):
         sys.spam = 1
index 4471c5129b96df374fe793a9a13f0136954c77e9..98f6b29dc7fc5e0522f430ef27aed033c806b6d8 100644 (file)
@@ -483,15 +483,28 @@ class TestTranforms(BytecodeTestCase):
 
     def test_constant_folding_small_int(self):
         tests = [
-            # subscript
             ('(0, )[0]', 0),
             ('(1 + 2, )[0]', 3),
             ('(2 + 2 * 2, )[0]', 6),
             ('(1, (1 + 2 + 3, ))[1][0]', 6),
+            ('1 + 2', 3),
+            ('2 + 2 * 2 // 2 - 2', 2),
             ('(255, )[0]', 255),
             ('(256, )[0]', None),
             ('(1000, )[0]', None),
             ('(1 - 2, )[0]', None),
+            ('255 + 0', 255),
+            ('255 + 1', None),
+            ('-1', None),
+            ('--1', 1),
+            ('--255', 255),
+            ('--256', None),
+            ('~1', None),
+            ('~~1', 1),
+            ('~~255', 255),
+            ('~~256', None),
+            ('++255', 255),
+            ('++256', None),
         ]
         for expr, oparg in tests:
             with self.subTest(expr=expr, oparg=oparg):
@@ -502,37 +515,97 @@ class TestTranforms(BytecodeTestCase):
                     self.assertNotInBytecode(code, 'LOAD_SMALL_INT')
                 self.check_lnotab(code)
 
-    def test_folding_subscript(self):
+    def test_folding_unaryop(self):
+        intrinsic_positive = 5
         tests = [
-            ('(1, )[0]', False),
-            ('(1, )[-1]', False),
-            ('(1 + 2, )[0]', False),
-            ('(1, (1, 2))[1][1]', False),
-            ('(1, 2)[2-1]', False),
-            ('(1, (1, 2))[1][2-1]', False),
-            ('(1, (1, 2))[1:6][0][2-1]', False),
-            ('"a"[0]', False),
-            ('("a" + "b")[1]', False),
-            ('("a" + "b", )[0][1]', False),
-            ('("a" * 10)[9]', False),
-            ('(1, )[1]', True),
-            ('(1, )[-2]', True),
-            ('"a"[1]', True),
-            ('"a"[-2]', True),
-            ('("a" + "b")[2]', True),
-            ('("a" + "b", )[0][2]', True),
-            ('("a" + "b", )[1][0]', True),
-            ('("a" * 10)[10]', True),
-            ('(1, (1, 2))[2:6][0][2-1]', True),
-        ]
-        subscr_argval = get_binop_argval('NB_SUBSCR')
-        for expr, has_error in tests:
+            ('---1', 'UNARY_NEGATIVE', None, True),
+            ('---""', 'UNARY_NEGATIVE', None, False),
+            ('~~~1', 'UNARY_INVERT', None, True),
+            ('~~~""', 'UNARY_INVERT', None, False),
+            ('not not True', 'UNARY_NOT', None, True),
+            ('not not x', 'UNARY_NOT', None, True),  # this should be optimized regardless of constant or not
+            ('+++1', 'CALL_INTRINSIC_1', intrinsic_positive, True),
+            ('---x', 'UNARY_NEGATIVE', None, False),
+            ('~~~x', 'UNARY_INVERT', None, False),
+            ('+++x', 'CALL_INTRINSIC_1', intrinsic_positive, False),
+        ]
+
+        for expr, opcode, oparg, optimized in tests:
+            with self.subTest(expr=expr, optimized=optimized):
+                code = compile(expr, '', 'single')
+                if optimized:
+                    self.assertNotInBytecode(code, opcode, argval=oparg)
+                else:
+                    self.assertInBytecode(code, opcode, argval=oparg)
+                self.check_lnotab(code)
+
+    def test_folding_binop(self):
+        tests = [
+            ('1 + 2', False, 'NB_ADD'),
+            ('1 + 2 + 3', False, 'NB_ADD'),
+            ('1 + ""', True, 'NB_ADD'),
+            ('1 - 2', False, 'NB_SUBTRACT'),
+            ('1 - 2 - 3', False, 'NB_SUBTRACT'),
+            ('1 - ""', True, 'NB_SUBTRACT'),
+            ('2 * 2', False, 'NB_MULTIPLY'),
+            ('2 * 2 * 2', False, 'NB_MULTIPLY'),
+            ('2 / 2', False, 'NB_TRUE_DIVIDE'),
+            ('2 / 2 / 2', False, 'NB_TRUE_DIVIDE'),
+            ('2 / ""', True, 'NB_TRUE_DIVIDE'),
+            ('2 // 2', False, 'NB_FLOOR_DIVIDE'),
+            ('2 // 2 // 2', False, 'NB_FLOOR_DIVIDE'),
+            ('2 // ""', True, 'NB_FLOOR_DIVIDE'),
+            ('2 % 2', False, 'NB_REMAINDER'),
+            ('2 % 2 % 2', False, 'NB_REMAINDER'),
+            ('2 % ()', True, 'NB_REMAINDER'),
+            ('2 ** 2', False, 'NB_POWER'),
+            ('2 ** 2 ** 2', False, 'NB_POWER'),
+            ('2 ** ""', True, 'NB_POWER'),
+            ('2 << 2', False, 'NB_LSHIFT'),
+            ('2 << 2 << 2', False, 'NB_LSHIFT'),
+            ('2 << ""', True, 'NB_LSHIFT'),
+            ('2 >> 2', False, 'NB_RSHIFT'),
+            ('2 >> 2 >> 2', False, 'NB_RSHIFT'),
+            ('2 >> ""', True, 'NB_RSHIFT'),
+            ('2 | 2', False, 'NB_OR'),
+            ('2 | 2 | 2', False, 'NB_OR'),
+            ('2 | ""', True, 'NB_OR'),
+            ('2 & 2', False, 'NB_AND'),
+            ('2 & 2 & 2', False, 'NB_AND'),
+            ('2 & ""', True, 'NB_AND'),
+            ('2 ^ 2', False, 'NB_XOR'),
+            ('2 ^ 2 ^ 2', False, 'NB_XOR'),
+            ('2 ^ ""', True, 'NB_XOR'),
+            ('(1, )[0]', False, 'NB_SUBSCR'),
+            ('(1, )[-1]', False, 'NB_SUBSCR'),
+            ('(1 + 2, )[0]', False, 'NB_SUBSCR'),
+            ('(1, (1, 2))[1][1]', False, 'NB_SUBSCR'),
+            ('(1, 2)[2-1]', False, 'NB_SUBSCR'),
+            ('(1, (1, 2))[1][2-1]', False, 'NB_SUBSCR'),
+            ('(1, (1, 2))[1:6][0][2-1]', False, 'NB_SUBSCR'),
+            ('"a"[0]', False, 'NB_SUBSCR'),
+            ('("a" + "b")[1]', False, 'NB_SUBSCR'),
+            ('("a" + "b", )[0][1]', False, 'NB_SUBSCR'),
+            ('("a" * 10)[9]', False, 'NB_SUBSCR'),
+            ('(1, )[1]', True, 'NB_SUBSCR'),
+            ('(1, )[-2]', True, 'NB_SUBSCR'),
+            ('"a"[1]', True, 'NB_SUBSCR'),
+            ('"a"[-2]', True, 'NB_SUBSCR'),
+            ('("a" + "b")[2]', True, 'NB_SUBSCR'),
+            ('("a" + "b", )[0][2]', True, 'NB_SUBSCR'),
+            ('("a" + "b", )[1][0]', True, 'NB_SUBSCR'),
+            ('("a" * 10)[10]', True, 'NB_SUBSCR'),
+            ('(1, (1, 2))[2:6][0][2-1]', True, 'NB_SUBSCR'),
+
+        ]
+        for expr, has_error, nb_op in tests:
             with self.subTest(expr=expr, has_error=has_error):
                 code = compile(expr, '', 'single')
+                nb_op_val = get_binop_argval(nb_op)
                 if not has_error:
-                    self.assertNotInBytecode(code, 'BINARY_OP', argval=subscr_argval)
+                    self.assertNotInBytecode(code, 'BINARY_OP', argval=nb_op_val)
                 else:
-                    self.assertInBytecode(code, 'BINARY_OP', argval=subscr_argval)
+                    self.assertInBytecode(code, 'BINARY_OP', argval=nb_op_val)
                 self.check_lnotab(code)
 
     def test_constant_folding_remove_nop_location(self):
@@ -1173,21 +1246,59 @@ class DirectCfgOptimizerTests(CfgOptimizationTestCase):
             }
         self.assertEqual(f(), frozenset(range(40)))
 
-    def test_multiple_foldings(self):
+    def test_nested_const_foldings(self):
+        # (1, (--2 + ++2 * 2 // 2 - 2, )[0], ~~3, not not True)  ==>  (1, 2, 3, True)
+        intrinsic_positive = 5
         before = [
             ('LOAD_SMALL_INT', 1, 0),
+            ('NOP', None, 0),
+            ('LOAD_SMALL_INT', 2, 0),
+            ('UNARY_NEGATIVE', None, 0),
+            ('NOP', None, 0),
+            ('UNARY_NEGATIVE', None, 0),
+            ('NOP', None, 0),
+            ('NOP', None, 0),
             ('LOAD_SMALL_INT', 2, 0),
+            ('CALL_INTRINSIC_1', intrinsic_positive, 0),
+            ('NOP', None, 0),
+            ('CALL_INTRINSIC_1', intrinsic_positive, 0),
+            ('BINARY_OP', get_binop_argval('NB_MULTIPLY')),
+            ('LOAD_SMALL_INT', 2, 0),
+            ('NOP', None, 0),
+            ('BINARY_OP', get_binop_argval('NB_FLOOR_DIVIDE')),
+            ('NOP', None, 0),
+            ('LOAD_SMALL_INT', 2, 0),
+            ('BINARY_OP', get_binop_argval('NB_ADD')),
+            ('NOP', None, 0),
+            ('LOAD_SMALL_INT', 2, 0),
+            ('NOP', None, 0),
+            ('BINARY_OP', get_binop_argval('NB_SUBTRACT')),
+            ('NOP', None, 0),
             ('BUILD_TUPLE', 1, 0),
             ('LOAD_SMALL_INT', 0, 0),
             ('BINARY_OP', get_binop_argval('NB_SUBSCR'), 0),
-            ('BUILD_TUPLE', 2, 0),
+            ('NOP', None, 0),
+            ('LOAD_SMALL_INT', 3, 0),
+            ('NOP', None, 0),
+            ('UNARY_INVERT', None, 0),
+            ('NOP', None, 0),
+            ('UNARY_INVERT', None, 0),
+            ('NOP', None, 0),
+            ('LOAD_SMALL_INT', 3, 0),
+            ('NOP', None, 0),
+            ('UNARY_NOT', None, 0),
+            ('NOP', None, 0),
+            ('UNARY_NOT', None, 0),
+            ('NOP', None, 0),
+            ('BUILD_TUPLE', 4, 0),
+            ('NOP', None, 0),
             ('RETURN_VALUE', None, 0)
         ]
         after = [
             ('LOAD_CONST', 1, 0),
             ('RETURN_VALUE', None, 0)
         ]
-        self.cfg_optimization_test(before, after, consts=[], expected_consts=[(2,), (1, 2)])
+        self.cfg_optimization_test(before, after, consts=[], expected_consts=[-2, (1, 2, 3, True)])
 
     def test_build_empty_tuple(self):
         before = [
@@ -1535,6 +1646,502 @@ class DirectCfgOptimizerTests(CfgOptimizationTestCase):
         ]
         self.cfg_optimization_test(same, same, consts=[None], expected_consts=[None])
 
+    def test_optimize_unary_not(self):
+        # test folding
+        before = [
+            ('LOAD_SMALL_INT', 1, 0),
+            ('TO_BOOL', None, 0),
+            ('UNARY_NOT', None, 0),
+            ('RETURN_VALUE', None, 0),
+        ]
+        after = [
+            ('LOAD_CONST', 1, 0),
+            ('RETURN_VALUE', None, 0),
+        ]
+        self.cfg_optimization_test(before, after, consts=[], expected_consts=[True, False])
+
+        # test cancel out
+        before = [
+            ('LOAD_NAME', 0, 0),
+            ('TO_BOOL', None, 0),
+            ('UNARY_NOT', None, 0),
+            ('UNARY_NOT', None, 0),
+            ('UNARY_NOT', None, 0),
+            ('UNARY_NOT', None, 0),
+            ('RETURN_VALUE', None, 0),
+        ]
+        after = [
+            ('LOAD_NAME', 0, 0),
+            ('TO_BOOL', None, 0),
+            ('RETURN_VALUE', None, 0),
+        ]
+        self.cfg_optimization_test(before, after, consts=[], expected_consts=[])
+
+        # test eliminate to bool
+        before = [
+            ('LOAD_NAME', 0, 0),
+            ('TO_BOOL', None, 0),
+            ('UNARY_NOT', None, 0),
+            ('TO_BOOL', None, 0),
+            ('TO_BOOL', None, 0),
+            ('TO_BOOL', None, 0),
+            ('RETURN_VALUE', None, 0),
+        ]
+        after = [
+            ('LOAD_NAME', 0, 0),
+            ('TO_BOOL', None, 0),
+            ('UNARY_NOT', None, 0),
+            ('RETURN_VALUE', None, 0),
+        ]
+        self.cfg_optimization_test(before, after, consts=[], expected_consts=[])
+
+        # test folding & cancel out
+        before = [
+            ('LOAD_SMALL_INT', 1, 0),
+            ('TO_BOOL', None, 0),
+            ('UNARY_NOT', None, 0),
+            ('UNARY_NOT', None, 0),
+            ('UNARY_NOT', None, 0),
+            ('UNARY_NOT', None, 0),
+            ('RETURN_VALUE', None, 0),
+        ]
+        after = [
+            ('LOAD_CONST', 0, 0),
+            ('RETURN_VALUE', None, 0),
+        ]
+        self.cfg_optimization_test(before, after, consts=[], expected_consts=[True])
+
+        # test folding & eliminate to bool
+        before = [
+            ('LOAD_SMALL_INT', 1, 0),
+            ('TO_BOOL', None, 0),
+            ('UNARY_NOT', None, 0),
+            ('TO_BOOL', None, 0),
+            ('TO_BOOL', None, 0),
+            ('TO_BOOL', None, 0),
+            ('RETURN_VALUE', None, 0),
+        ]
+        after = [
+            ('LOAD_CONST', 1, 0),
+            ('RETURN_VALUE', None, 0),
+        ]
+        self.cfg_optimization_test(before, after, consts=[], expected_consts=[True, False])
+
+        # test cancel out & eliminate to bool (to bool stays as we are not iterating to a fixed point)
+        before = [
+            ('LOAD_NAME', 0, 0),
+            ('TO_BOOL', None, 0),
+            ('UNARY_NOT', None, 0),
+            ('UNARY_NOT', None, 0),
+            ('UNARY_NOT', None, 0),
+            ('UNARY_NOT', None, 0),
+            ('TO_BOOL', None, 0),
+            ('RETURN_VALUE', None, 0),
+        ]
+        after = [
+            ('LOAD_NAME', 0, 0),
+            ('TO_BOOL', None, 0),
+            ('TO_BOOL', None, 0),
+            ('RETURN_VALUE', None, 0),
+        ]
+        self.cfg_optimization_test(before, after, consts=[], expected_consts=[])
+
+        is_ = in_ = 0
+        isnot = notin = 1
+
+        # test is/isnot
+        before = [
+            ('LOAD_NAME', 0, 0),
+            ('LOAD_NAME', 1, 0),
+            ('IS_OP', is_, 0),
+            ('UNARY_NOT', None, 0),
+            ('RETURN_VALUE', None, 0),
+        ]
+        after = [
+            ('LOAD_NAME', 0, 0),
+            ('LOAD_NAME', 1, 0),
+            ('IS_OP', isnot, 0),
+            ('RETURN_VALUE', None, 0),
+        ]
+        self.cfg_optimization_test(before, after, consts=[], expected_consts=[])
+
+        # test is/isnot cancel out
+        before = [
+            ('LOAD_NAME', 0, 0),
+            ('LOAD_NAME', 1, 0),
+            ('IS_OP', is_, 0),
+            ('UNARY_NOT', None, 0),
+            ('UNARY_NOT', None, 0),
+            ('RETURN_VALUE', None, 0),
+        ]
+        after = [
+            ('LOAD_NAME', 0, 0),
+            ('LOAD_NAME', 1, 0),
+            ('IS_OP', is_, 0),
+            ('RETURN_VALUE', None, 0),
+        ]
+        self.cfg_optimization_test(before, after, consts=[], expected_consts=[])
+
+        # test is/isnot eliminate to bool
+        before = [
+            ('LOAD_NAME', 0, 0),
+            ('LOAD_NAME', 1, 0),
+            ('IS_OP', is_, 0),
+            ('UNARY_NOT', None, 0),
+            ('TO_BOOL', None, 0),
+            ('TO_BOOL', None, 0),
+            ('TO_BOOL', None, 0),
+            ('RETURN_VALUE', None, 0),
+        ]
+        after = [
+            ('LOAD_NAME', 0, 0),
+            ('LOAD_NAME', 1, 0),
+            ('IS_OP', isnot, 0),
+            ('RETURN_VALUE', None, 0),
+        ]
+        self.cfg_optimization_test(before, after, consts=[], expected_consts=[])
+
+        # test is/isnot cancel out & eliminate to bool
+        before = [
+            ('LOAD_NAME', 0, 0),
+            ('LOAD_NAME', 1, 0),
+            ('IS_OP', is_, 0),
+            ('UNARY_NOT', None, 0),
+            ('UNARY_NOT', None, 0),
+            ('TO_BOOL', None, 0),
+            ('TO_BOOL', None, 0),
+            ('TO_BOOL', None, 0),
+            ('RETURN_VALUE', None, 0),
+        ]
+        after = [
+            ('LOAD_NAME', 0, 0),
+            ('LOAD_NAME', 1, 0),
+            ('IS_OP', is_, 0),
+            ('RETURN_VALUE', None, 0),
+        ]
+        self.cfg_optimization_test(before, after, consts=[], expected_consts=[])
+
+        # test in/notin
+        before = [
+            ('LOAD_NAME', 0, 0),
+            ('LOAD_NAME', 1, 0),
+            ('CONTAINS_OP', in_, 0),
+            ('UNARY_NOT', None, 0),
+            ('RETURN_VALUE', None, 0),
+        ]
+        after = [
+            ('LOAD_NAME', 0, 0),
+            ('LOAD_NAME', 1, 0),
+            ('CONTAINS_OP', notin, 0),
+            ('RETURN_VALUE', None, 0),
+        ]
+        self.cfg_optimization_test(before, after, consts=[], expected_consts=[])
+
+        # test in/notin cancel out
+        before = [
+            ('LOAD_NAME', 0, 0),
+            ('LOAD_NAME', 1, 0),
+            ('CONTAINS_OP', in_, 0),
+            ('UNARY_NOT', None, 0),
+            ('UNARY_NOT', None, 0),
+            ('RETURN_VALUE', None, 0),
+        ]
+        after = [
+            ('LOAD_NAME', 0, 0),
+            ('LOAD_NAME', 1, 0),
+            ('CONTAINS_OP', in_, 0),
+            ('RETURN_VALUE', None, 0),
+        ]
+        self.cfg_optimization_test(before, after, consts=[], expected_consts=[])
+
+        # test is/isnot & eliminate to bool
+        before = [
+            ('LOAD_NAME', 0, 0),
+            ('LOAD_NAME', 1, 0),
+            ('CONTAINS_OP', in_, 0),
+            ('UNARY_NOT', None, 0),
+            ('TO_BOOL', None, 0),
+            ('TO_BOOL', None, 0),
+            ('TO_BOOL', None, 0),
+            ('RETURN_VALUE', None, 0),
+        ]
+        after = [
+            ('LOAD_NAME', 0, 0),
+            ('LOAD_NAME', 1, 0),
+            ('CONTAINS_OP', notin, 0),
+            ('RETURN_VALUE', None, 0),
+        ]
+        self.cfg_optimization_test(before, after, consts=[], expected_consts=[])
+
+        # test in/notin cancel out & eliminate to bool
+        before = [
+            ('LOAD_NAME', 0, 0),
+            ('LOAD_NAME', 1, 0),
+            ('CONTAINS_OP', in_, 0),
+            ('UNARY_NOT', None, 0),
+            ('UNARY_NOT', None, 0),
+            ('TO_BOOL', None, 0),
+            ('RETURN_VALUE', None, 0),
+        ]
+        after = [
+            ('LOAD_NAME', 0, 0),
+            ('LOAD_NAME', 1, 0),
+            ('CONTAINS_OP', in_, 0),
+            ('RETURN_VALUE', None, 0),
+        ]
+        self.cfg_optimization_test(before, after, consts=[], expected_consts=[])
+
+    def test_optimize_if_const_unaryop(self):
+        # test unary negative
+        before = [
+            ('LOAD_SMALL_INT', 2, 0),
+            ('UNARY_NEGATIVE', None, 0),
+            ('UNARY_NEGATIVE', None, 0),
+            ('RETURN_VALUE', None, 0)
+        ]
+        after = [
+            ('LOAD_SMALL_INT', 2, 0),
+            ('RETURN_VALUE', None, 0),
+        ]
+        self.cfg_optimization_test(before, after, consts=[], expected_consts=[-2])
+
+        # test unary invert
+        before = [
+            ('LOAD_SMALL_INT', 2, 0),
+            ('UNARY_INVERT', None, 0),
+            ('UNARY_INVERT', None, 0),
+            ('RETURN_VALUE', None, 0)
+        ]
+        after = [
+            ('LOAD_SMALL_INT', 2, 0),
+            ('RETURN_VALUE', None, 0),
+        ]
+        self.cfg_optimization_test(before, after, consts=[], expected_consts=[-3])
+
+        # test unary positive
+        before = [
+            ('LOAD_SMALL_INT', 2, 0),
+            ('CALL_INTRINSIC_1', 5, 0),
+            ('CALL_INTRINSIC_1', 5, 0),
+            ('RETURN_VALUE', None, 0)
+        ]
+        after = [
+            ('LOAD_SMALL_INT', 2, 0),
+            ('RETURN_VALUE', None, 0),
+        ]
+        self.cfg_optimization_test(before, after, consts=[], expected_consts=[])
+
+    def test_optimize_if_const_binop(self):
+        add = get_binop_argval('NB_ADD')
+        sub = get_binop_argval('NB_SUBTRACT')
+        mul = get_binop_argval('NB_MULTIPLY')
+        div = get_binop_argval('NB_TRUE_DIVIDE')
+        floor = get_binop_argval('NB_FLOOR_DIVIDE')
+        rem = get_binop_argval('NB_REMAINDER')
+        pow = get_binop_argval('NB_POWER')
+        lshift = get_binop_argval('NB_LSHIFT')
+        rshift = get_binop_argval('NB_RSHIFT')
+        or_ = get_binop_argval('NB_OR')
+        and_ = get_binop_argval('NB_AND')
+        xor = get_binop_argval('NB_XOR')
+        subscr = get_binop_argval('NB_SUBSCR')
+
+        # test add
+        before = [
+            ('LOAD_SMALL_INT', 2, 0),
+            ('LOAD_SMALL_INT', 2, 0),
+            ('BINARY_OP', add, 0),
+            ('LOAD_SMALL_INT', 2, 0),
+            ('BINARY_OP', add, 0),
+            ('RETURN_VALUE', None, 0)
+        ]
+        after = [
+            ('LOAD_SMALL_INT', 6, 0),
+            ('RETURN_VALUE', None, 0)
+        ]
+        self.cfg_optimization_test(before, after, consts=[], expected_consts=[])
+
+        # test sub
+        before = [
+            ('LOAD_SMALL_INT', 2, 0),
+            ('LOAD_SMALL_INT', 2, 0),
+            ('BINARY_OP', sub, 0),
+            ('LOAD_SMALL_INT', 2, 0),
+            ('BINARY_OP', sub, 0),
+            ('RETURN_VALUE', None, 0)
+        ]
+        after = [
+            ('LOAD_CONST', 0, 0),
+            ('RETURN_VALUE', None, 0)
+        ]
+        self.cfg_optimization_test(before, after, consts=[], expected_consts=[-2])
+
+        # test mul
+        before = [
+            ('LOAD_SMALL_INT', 2, 0),
+            ('LOAD_SMALL_INT', 2, 0),
+            ('BINARY_OP', mul, 0),
+            ('LOAD_SMALL_INT', 2, 0),
+            ('BINARY_OP', mul, 0),
+            ('RETURN_VALUE', None, 0)
+        ]
+        after = [
+            ('LOAD_SMALL_INT', 8, 0),
+            ('RETURN_VALUE', None, 0)
+        ]
+        self.cfg_optimization_test(before, after, consts=[], expected_consts=[])
+
+        # test div
+        before = [
+            ('LOAD_SMALL_INT', 2, 0),
+            ('LOAD_SMALL_INT', 2, 0),
+            ('BINARY_OP', div, 0),
+            ('LOAD_SMALL_INT', 2, 0),
+            ('BINARY_OP', div, 0),
+            ('RETURN_VALUE', None, 0)
+        ]
+        after = [
+            ('LOAD_CONST', 1, 0),
+            ('RETURN_VALUE', None, 0)
+        ]
+        self.cfg_optimization_test(before, after, consts=[], expected_consts=[1.0, 0.5])
+
+        # test floor
+        before = [
+            ('LOAD_SMALL_INT', 2, 0),
+            ('LOAD_SMALL_INT', 2, 0),
+            ('BINARY_OP', floor, 0),
+            ('LOAD_SMALL_INT', 2, 0),
+            ('BINARY_OP', floor, 0),
+            ('RETURN_VALUE', None, 0)
+        ]
+        after = [
+            ('LOAD_SMALL_INT', 0, 0),
+            ('RETURN_VALUE', None, 0)
+        ]
+        self.cfg_optimization_test(before, after, consts=[], expected_consts=[])
+
+        # test rem
+        before = [
+            ('LOAD_SMALL_INT', 2, 0),
+            ('LOAD_SMALL_INT', 2, 0),
+            ('BINARY_OP', rem, 0),
+            ('LOAD_SMALL_INT', 2, 0),
+            ('BINARY_OP', rem, 0),
+            ('RETURN_VALUE', None, 0)
+        ]
+        after = [
+            ('LOAD_SMALL_INT', 0, 0),
+            ('RETURN_VALUE', None, 0)
+        ]
+        self.cfg_optimization_test(before, after, consts=[], expected_consts=[])
+
+        # test pow
+        before = [
+            ('LOAD_SMALL_INT', 2, 0),
+            ('LOAD_SMALL_INT', 2, 0),
+            ('BINARY_OP', pow, 0),
+            ('LOAD_SMALL_INT', 2, 0),
+            ('BINARY_OP', pow, 0),
+            ('RETURN_VALUE', None, 0)
+        ]
+        after = [
+            ('LOAD_SMALL_INT', 16, 0),
+            ('RETURN_VALUE', None, 0)
+        ]
+        self.cfg_optimization_test(before, after, consts=[], expected_consts=[])
+
+        # test lshift
+        before = [
+            ('LOAD_SMALL_INT', 1, 0),
+            ('LOAD_SMALL_INT', 1, 0),
+            ('BINARY_OP', lshift, 0),
+            ('LOAD_SMALL_INT', 1, 0),
+            ('BINARY_OP', lshift, 0),
+            ('RETURN_VALUE', None, 0)
+        ]
+        after = [
+            ('LOAD_SMALL_INT', 4, 0),
+            ('RETURN_VALUE', None, 0)
+        ]
+        self.cfg_optimization_test(before, after, consts=[], expected_consts=[])
+
+        # test rshift
+        before = [
+            ('LOAD_SMALL_INT', 4, 0),
+            ('LOAD_SMALL_INT', 1, 0),
+            ('BINARY_OP', rshift, 0),
+            ('LOAD_SMALL_INT', 1, 0),
+            ('BINARY_OP', rshift, 0),
+            ('RETURN_VALUE', None, 0)
+        ]
+        after = [
+            ('LOAD_SMALL_INT', 1, 0),
+            ('RETURN_VALUE', None, 0)
+        ]
+        self.cfg_optimization_test(before, after, consts=[], expected_consts=[])
+
+        # test or
+        before = [
+            ('LOAD_SMALL_INT', 1, 0),
+            ('LOAD_SMALL_INT', 2, 0),
+            ('BINARY_OP', or_, 0),
+            ('LOAD_SMALL_INT', 4, 0),
+            ('BINARY_OP', or_, 0),
+            ('RETURN_VALUE', None, 0)
+        ]
+        after = [
+            ('LOAD_SMALL_INT', 7, 0),
+            ('RETURN_VALUE', None, 0)
+        ]
+        self.cfg_optimization_test(before, after, consts=[], expected_consts=[])
+
+        # test and
+        before = [
+            ('LOAD_SMALL_INT', 1, 0),
+            ('LOAD_SMALL_INT', 1, 0),
+            ('BINARY_OP', and_, 0),
+            ('LOAD_SMALL_INT', 1, 0),
+            ('BINARY_OP', and_, 0),
+            ('RETURN_VALUE', None, 0)
+        ]
+        after = [
+            ('LOAD_SMALL_INT', 1, 0),
+            ('RETURN_VALUE', None, 0)
+        ]
+        self.cfg_optimization_test(before, after, consts=[], expected_consts=[])
+
+        # test xor
+        before = [
+            ('LOAD_SMALL_INT', 2, 0),
+            ('LOAD_SMALL_INT', 2, 0),
+            ('BINARY_OP', xor, 0),
+            ('LOAD_SMALL_INT', 2, 0),
+            ('BINARY_OP', xor, 0),
+            ('RETURN_VALUE', None, 0)
+        ]
+        after = [
+            ('LOAD_SMALL_INT', 2, 0),
+            ('RETURN_VALUE', None, 0)
+        ]
+        self.cfg_optimization_test(before, after, consts=[], expected_consts=[])
+
+        # test subscr
+        before = [
+            ('LOAD_CONST', 0, 0),
+            ('LOAD_SMALL_INT', 1, 0),
+            ('BINARY_OP', subscr, 0),
+            ('LOAD_SMALL_INT', 2, 0),
+            ('BINARY_OP', subscr, 0),
+            ('RETURN_VALUE', None, 0)
+        ]
+        after = [
+            ('LOAD_SMALL_INT', 3, 0),
+            ('RETURN_VALUE', None, 0)
+        ]
+        self.cfg_optimization_test(before, after, consts=[(1, (1, 2, 3))], expected_consts=[(1, (1, 2, 3))])
+
+
     def test_conditional_jump_forward_const_condition(self):
         # The unreachable branch of the jump is removed, the jump
         # becomes redundant and is replaced by a NOP (for the lineno)
index ab1ee96b045362541d0c04151d6a9c272121d660..2c6e16817f2aad317d9734be5302746b9bdaf337 100644 (file)
@@ -56,199 +56,6 @@ has_starred(asdl_expr_seq *elts)
     return 0;
 }
 
-
-static PyObject*
-unary_not(PyObject *v)
-{
-    int r = PyObject_IsTrue(v);
-    if (r < 0)
-        return NULL;
-    return PyBool_FromLong(!r);
-}
-
-static int
-fold_unaryop(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
-{
-    expr_ty arg = node->v.UnaryOp.operand;
-
-    if (arg->kind != Constant_kind) {
-        /* Fold not into comparison */
-        if (node->v.UnaryOp.op == Not && arg->kind == Compare_kind &&
-                asdl_seq_LEN(arg->v.Compare.ops) == 1) {
-            /* Eq and NotEq are often implemented in terms of one another, so
-               folding not (self == other) into self != other breaks implementation
-               of !=. Detecting such cases doesn't seem worthwhile.
-               Python uses </> for 'is subset'/'is superset' operations on sets.
-               They don't satisfy not folding laws. */
-            cmpop_ty op = asdl_seq_GET(arg->v.Compare.ops, 0);
-            switch (op) {
-            case Is:
-                op = IsNot;
-                break;
-            case IsNot:
-                op = Is;
-                break;
-            case In:
-                op = NotIn;
-                break;
-            case NotIn:
-                op = In;
-                break;
-            // The remaining comparison operators can't be safely inverted
-            case Eq:
-            case NotEq:
-            case Lt:
-            case LtE:
-            case Gt:
-            case GtE:
-                op = 0; // The AST enums leave "0" free as an "unused" marker
-                break;
-            // No default case, so the compiler will emit a warning if new
-            // comparison operators are added without being handled here
-            }
-            if (op) {
-                asdl_seq_SET(arg->v.Compare.ops, 0, op);
-                COPY_NODE(node, arg);
-                return 1;
-            }
-        }
-        return 1;
-    }
-
-    typedef PyObject *(*unary_op)(PyObject*);
-    static const unary_op ops[] = {
-        [Invert] = PyNumber_Invert,
-        [Not] = unary_not,
-        [UAdd] = PyNumber_Positive,
-        [USub] = PyNumber_Negative,
-    };
-    PyObject *newval = ops[node->v.UnaryOp.op](arg->v.Constant.value);
-    return make_const(node, newval, arena);
-}
-
-/* Check whether a collection doesn't containing too much items (including
-   subcollections).  This protects from creating a constant that needs
-   too much time for calculating a hash.
-   "limit" is the maximal number of items.
-   Returns the negative number if the total number of items exceeds the
-   limit.  Otherwise returns the limit minus the total number of items.
-*/
-
-static Py_ssize_t
-check_complexity(PyObject *obj, Py_ssize_t limit)
-{
-    if (PyTuple_Check(obj)) {
-        Py_ssize_t i;
-        limit -= PyTuple_GET_SIZE(obj);
-        for (i = 0; limit >= 0 && i < PyTuple_GET_SIZE(obj); i++) {
-            limit = check_complexity(PyTuple_GET_ITEM(obj, i), limit);
-        }
-        return limit;
-    }
-    return limit;
-}
-
-#define MAX_INT_SIZE           128  /* bits */
-#define MAX_COLLECTION_SIZE    256  /* items */
-#define MAX_STR_SIZE          4096  /* characters */
-#define MAX_TOTAL_ITEMS       1024  /* including nested collections */
-
-static PyObject *
-safe_multiply(PyObject *v, PyObject *w)
-{
-    if (PyLong_Check(v) && PyLong_Check(w) &&
-        !_PyLong_IsZero((PyLongObject *)v) && !_PyLong_IsZero((PyLongObject *)w)
-    ) {
-        int64_t vbits = _PyLong_NumBits(v);
-        int64_t wbits = _PyLong_NumBits(w);
-        assert(vbits >= 0);
-        assert(wbits >= 0);
-        if (vbits + wbits > MAX_INT_SIZE) {
-            return NULL;
-        }
-    }
-    else if (PyLong_Check(v) && PyTuple_Check(w)) {
-        Py_ssize_t size = PyTuple_GET_SIZE(w);
-        if (size) {
-            long n = PyLong_AsLong(v);
-            if (n < 0 || n > MAX_COLLECTION_SIZE / size) {
-                return NULL;
-            }
-            if (n && check_complexity(w, MAX_TOTAL_ITEMS / n) < 0) {
-                return NULL;
-            }
-        }
-    }
-    else if (PyLong_Check(v) && (PyUnicode_Check(w) || PyBytes_Check(w))) {
-        Py_ssize_t size = PyUnicode_Check(w) ? PyUnicode_GET_LENGTH(w) :
-                                               PyBytes_GET_SIZE(w);
-        if (size) {
-            long n = PyLong_AsLong(v);
-            if (n < 0 || n > MAX_STR_SIZE / size) {
-                return NULL;
-            }
-        }
-    }
-    else if (PyLong_Check(w) &&
-             (PyTuple_Check(v) || PyUnicode_Check(v) || PyBytes_Check(v)))
-    {
-        return safe_multiply(w, v);
-    }
-
-    return PyNumber_Multiply(v, w);
-}
-
-static PyObject *
-safe_power(PyObject *v, PyObject *w)
-{
-    if (PyLong_Check(v) && PyLong_Check(w) &&
-        !_PyLong_IsZero((PyLongObject *)v) && _PyLong_IsPositive((PyLongObject *)w)
-    ) {
-        int64_t vbits = _PyLong_NumBits(v);
-        size_t wbits = PyLong_AsSize_t(w);
-        assert(vbits >= 0);
-        if (wbits == (size_t)-1) {
-            return NULL;
-        }
-        if ((uint64_t)vbits > MAX_INT_SIZE / wbits) {
-            return NULL;
-        }
-    }
-
-    return PyNumber_Power(v, w, Py_None);
-}
-
-static PyObject *
-safe_lshift(PyObject *v, PyObject *w)
-{
-    if (PyLong_Check(v) && PyLong_Check(w) &&
-        !_PyLong_IsZero((PyLongObject *)v) && !_PyLong_IsZero((PyLongObject *)w)
-    ) {
-        int64_t vbits = _PyLong_NumBits(v);
-        size_t wbits = PyLong_AsSize_t(w);
-        assert(vbits >= 0);
-        if (wbits == (size_t)-1) {
-            return NULL;
-        }
-        if (wbits > MAX_INT_SIZE || (uint64_t)vbits > MAX_INT_SIZE - wbits) {
-            return NULL;
-        }
-    }
-
-    return PyNumber_Lshift(v, w);
-}
-
-static PyObject *
-safe_mod(PyObject *v, PyObject *w)
-{
-    if (PyUnicode_Check(v) || PyBytes_Check(v)) {
-        return NULL;
-    }
-
-    return PyNumber_Remainder(v, w);
-}
-
-
 static expr_ty
 parse_literal(PyObject *fmt, Py_ssize_t *ppos, PyArena *arena)
 {
@@ -468,58 +275,7 @@ fold_binop(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
         return optimize_format(node, lv, rhs->v.Tuple.elts, arena);
     }
 
-    if (rhs->kind != Constant_kind) {
-        return 1;
-    }
-
-    PyObject *rv = rhs->v.Constant.value;
-    PyObject *newval = NULL;
-
-    switch (node->v.BinOp.op) {
-    case Add:
-        newval = PyNumber_Add(lv, rv);
-        break;
-    case Sub:
-        newval = PyNumber_Subtract(lv, rv);
-        break;
-    case Mult:
-        newval = safe_multiply(lv, rv);
-        break;
-    case Div:
-        newval = PyNumber_TrueDivide(lv, rv);
-        break;
-    case FloorDiv:
-        newval = PyNumber_FloorDivide(lv, rv);
-        break;
-    case Mod:
-        newval = safe_mod(lv, rv);
-        break;
-    case Pow:
-        newval = safe_power(lv, rv);
-        break;
-    case LShift:
-        newval = safe_lshift(lv, rv);
-        break;
-    case RShift:
-        newval = PyNumber_Rshift(lv, rv);
-        break;
-    case BitOr:
-        newval = PyNumber_Or(lv, rv);
-        break;
-    case BitXor:
-        newval = PyNumber_Xor(lv, rv);
-        break;
-    case BitAnd:
-        newval = PyNumber_And(lv, rv);
-        break;
-    // No builtin constants implement the following operators
-    case MatMult:
-        return 1;
-    // No default case, so the compiler will emit a warning if new binary
-    // operators are added without being handled here
-    }
-
-    return make_const(node, newval, arena);
+    return 1;
 }
 
 static PyObject*
@@ -670,7 +426,6 @@ astfold_expr(expr_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
         break;
     case UnaryOp_kind:
         CALL(astfold_expr, expr_ty, node_->v.UnaryOp.operand);
-        CALL(fold_unaryop, expr_ty, node_);
         break;
     case Lambda_kind:
         CALL(astfold_arguments, arguments_ty, node_->v.Lambda.args);
@@ -961,6 +716,44 @@ astfold_withitem(withitem_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
     return 1;
 }
 
+static int
+fold_const_match_patterns(expr_ty node, PyArena *ctx_, _PyASTOptimizeState *state)
+{
+    switch (node->kind)
+    {
+        case UnaryOp_kind:
+        {
+            if (node->v.UnaryOp.op == USub &&
+                node->v.UnaryOp.operand->kind == Constant_kind)
+            {
+                PyObject *operand = node->v.UnaryOp.operand->v.Constant.value;
+                PyObject *folded = PyNumber_Negative(operand);
+                return make_const(node, folded, ctx_);
+            }
+            break;
+        }
+        case BinOp_kind:
+        {
+            operator_ty op = node->v.BinOp.op;
+            if ((op == Add || op == Sub) &&
+                node->v.BinOp.right->kind == Constant_kind)
+            {
+                CALL(fold_const_match_patterns, expr_ty, node->v.BinOp.left);
+                if (node->v.BinOp.left->kind == Constant_kind) {
+                    PyObject *left = node->v.BinOp.left->v.Constant.value;
+                    PyObject *right = node->v.BinOp.right->v.Constant.value;
+                    PyObject *folded = op == Add ? PyNumber_Add(left, right) : PyNumber_Subtract(left, right);
+                    return make_const(node, folded, ctx_);
+                }
+            }
+            break;
+        }
+        default:
+            break;
+    }
+    return 1;
+}
+
 static int
 astfold_pattern(pattern_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
 {
@@ -970,7 +763,7 @@ astfold_pattern(pattern_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
     ENTER_RECURSIVE();
     switch (node_->kind) {
         case MatchValue_kind:
-            CALL(astfold_expr, expr_ty, node_->v.MatchValue.value);
+            CALL(fold_const_match_patterns, expr_ty, node_->v.MatchValue.value);
             break;
         case MatchSingleton_kind:
             break;
@@ -978,7 +771,7 @@ astfold_pattern(pattern_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
             CALL_SEQ(astfold_pattern, pattern, node_->v.MatchSequence.patterns);
             break;
         case MatchMapping_kind:
-            CALL_SEQ(astfold_expr, expr, node_->v.MatchMapping.keys);
+            CALL_SEQ(fold_const_match_patterns, expr, node_->v.MatchMapping.keys);
             CALL_SEQ(astfold_pattern, pattern, node_->v.MatchMapping.patterns);
             break;
         case MatchClass_kind:
index 38fb40831f3735f25697bc755fb3c0fc2aafda80..c5bdf105545459c360aede938b489cbf664f972c 100644 (file)
@@ -1406,6 +1406,26 @@ nop_out(basicblock *bb, int start, int count)
     }
 }
 
+/* Steals reference to "newconst" */
+static int
+instr_make_load_const(cfg_instr *instr, PyObject *newconst,
+                      PyObject *consts, PyObject *const_cache)
+{
+    if (PyLong_CheckExact(newconst)) {
+        int overflow;
+        long val = PyLong_AsLongAndOverflow(newconst, &overflow);
+        if (!overflow && _PY_IS_SMALL_INT(val)) {
+            assert(_Py_IsImmortal(newconst));
+            INSTR_SET_OP1(instr, LOAD_SMALL_INT, (int)val);
+            return SUCCESS;
+        }
+    }
+    int oparg = add_const(newconst, consts, const_cache);
+    RETURN_IF_ERROR(oparg);
+    INSTR_SET_OP1(instr, LOAD_CONST, oparg);
+    return SUCCESS;
+}
+
 /* Replace LOAD_CONST c1, LOAD_CONST c2 ... LOAD_CONST cn, BUILD_TUPLE n
    with    LOAD_CONST (c1, c2, ... cn).
    The consts table must still be in list form so that the
@@ -1413,25 +1433,23 @@ nop_out(basicblock *bb, int start, int count)
    Called with codestr pointing to the first LOAD_CONST.
 */
 static int
-fold_tuple_of_constants(basicblock *bb, int n, PyObject *consts, PyObject *const_cache)
+fold_tuple_of_constants(basicblock *bb, int i, PyObject *consts, PyObject *const_cache)
 {
     /* Pre-conditions */
     assert(PyDict_CheckExact(const_cache));
     assert(PyList_CheckExact(consts));
-    cfg_instr *instr = &bb->b_instr[n];
+    cfg_instr *instr = &bb->b_instr[i];
     assert(instr->i_opcode == BUILD_TUPLE);
     int seq_size = instr->i_oparg;
     PyObject *newconst;
-    RETURN_IF_ERROR(get_constant_sequence(bb, n-1, seq_size, consts, &newconst));
+    RETURN_IF_ERROR(get_constant_sequence(bb, i-1, seq_size, consts, &newconst));
     if (newconst == NULL) {
         /* not a const sequence */
         return SUCCESS;
     }
-    assert(PyTuple_CheckExact(newconst) && PyTuple_GET_SIZE(newconst) == seq_size);
-    int index = add_const(newconst, consts, const_cache);
-    RETURN_IF_ERROR(index);
-    nop_out(bb, n-1, seq_size);
-    INSTR_SET_OP1(instr, LOAD_CONST, index);
+    assert(PyTuple_Size(newconst) == seq_size);
+    RETURN_IF_ERROR(instr_make_load_const(instr, newconst, consts, const_cache));
+    nop_out(bb, i-1, seq_size);
     return SUCCESS;
 }
 
@@ -1469,7 +1487,7 @@ optimize_lists_and_sets(basicblock *bb, int i, int nextop,
         }
         return SUCCESS;
     }
-    assert(PyTuple_CheckExact(newconst) && PyTuple_GET_SIZE(newconst) == seq_size);
+    assert(PyTuple_Size(newconst) == seq_size);
     if (instr->i_opcode == BUILD_SET) {
         PyObject *frozenset = PyFrozenSet_New(newconst);
         if (frozenset == NULL) {
@@ -1497,45 +1515,200 @@ optimize_lists_and_sets(basicblock *bb, int i, int nextop,
     return SUCCESS;
 }
 
-/* Determine opcode & oparg for freshly folded constant. */
-static int
-newop_from_folded(PyObject *newconst, PyObject *consts,
-                  PyObject *const_cache, int *newopcode, int *newoparg)
+/* Check whether the total number of items in the (possibly nested) collection obj exceeds
+ * limit. Return a negative number if it does, and a non-negative number otherwise.
+ * Used to avoid creating constants which are slow to hash.
+ */
+static Py_ssize_t
+const_folding_check_complexity(PyObject *obj, Py_ssize_t limit)
 {
-    if (PyLong_CheckExact(newconst)) {
-        int overflow;
-        long val = PyLong_AsLongAndOverflow(newconst, &overflow);
-        if (!overflow && _PY_IS_SMALL_INT(val)) {
-            *newopcode = LOAD_SMALL_INT;
-            *newoparg = val;
-            return SUCCESS;
+    if (PyTuple_Check(obj)) {
+        Py_ssize_t i;
+        limit -= PyTuple_GET_SIZE(obj);
+        for (i = 0; limit >= 0 && i < PyTuple_GET_SIZE(obj); i++) {
+            limit = const_folding_check_complexity(PyTuple_GET_ITEM(obj, i), limit);
+            if (limit < 0) {
+                return limit;
+            }
         }
     }
-    *newopcode = LOAD_CONST;
-    *newoparg = add_const(newconst, consts, const_cache);
-    RETURN_IF_ERROR(*newoparg);
-    return SUCCESS;
+    return limit;
+}
+
+#define MAX_INT_SIZE           128  /* bits */
+#define MAX_COLLECTION_SIZE    256  /* items */
+#define MAX_STR_SIZE          4096  /* characters */
+#define MAX_TOTAL_ITEMS       1024  /* including nested collections */
+
+static PyObject *
+const_folding_safe_multiply(PyObject *v, PyObject *w)
+{
+    if (PyLong_Check(v) && PyLong_Check(w) &&
+        !_PyLong_IsZero((PyLongObject *)v) && !_PyLong_IsZero((PyLongObject *)w)
+    ) {
+        int64_t vbits = _PyLong_NumBits(v);
+        int64_t wbits = _PyLong_NumBits(w);
+        assert(vbits >= 0);
+        assert(wbits >= 0);
+        if (vbits + wbits > MAX_INT_SIZE) {
+            return NULL;
+        }
+    }
+    else if (PyLong_Check(v) && PyTuple_Check(w)) {
+        Py_ssize_t size = PyTuple_GET_SIZE(w);
+        if (size) {
+            long n = PyLong_AsLong(v);
+            if (n < 0 || n > MAX_COLLECTION_SIZE / size) {
+                return NULL;
+            }
+            if (n && const_folding_check_complexity(w, MAX_TOTAL_ITEMS / n) < 0) {
+                return NULL;
+            }
+        }
+    }
+    else if (PyLong_Check(v) && (PyUnicode_Check(w) || PyBytes_Check(w))) {
+        Py_ssize_t size = PyUnicode_Check(w) ? PyUnicode_GET_LENGTH(w) :
+                                               PyBytes_GET_SIZE(w);
+        if (size) {
+            long n = PyLong_AsLong(v);
+            if (n < 0 || n > MAX_STR_SIZE / size) {
+                return NULL;
+            }
+        }
+    }
+    else if (PyLong_Check(w) &&
+             (PyTuple_Check(v) || PyUnicode_Check(v) || PyBytes_Check(v)))
+    {
+        return const_folding_safe_multiply(w, v);
+    }
+
+    return PyNumber_Multiply(v, w);
+}
+
+static PyObject *
+const_folding_safe_power(PyObject *v, PyObject *w)
+{
+    if (PyLong_Check(v) && PyLong_Check(w) &&
+        !_PyLong_IsZero((PyLongObject *)v) && _PyLong_IsPositive((PyLongObject *)w)
+    ) {
+        int64_t vbits = _PyLong_NumBits(v);
+        size_t wbits = PyLong_AsSize_t(w);
+        assert(vbits >= 0);
+        if (wbits == (size_t)-1) {
+            return NULL;
+        }
+        if ((uint64_t)vbits > MAX_INT_SIZE / wbits) {
+            return NULL;
+        }
+    }
+
+    return PyNumber_Power(v, w, Py_None);
+}
+
+static PyObject *
+const_folding_safe_lshift(PyObject *v, PyObject *w)
+{
+    if (PyLong_Check(v) && PyLong_Check(w) &&
+        !_PyLong_IsZero((PyLongObject *)v) && !_PyLong_IsZero((PyLongObject *)w)
+    ) {
+        int64_t vbits = _PyLong_NumBits(v);
+        size_t wbits = PyLong_AsSize_t(w);
+        assert(vbits >= 0);
+        if (wbits == (size_t)-1) {
+            return NULL;
+        }
+        if (wbits > MAX_INT_SIZE || (uint64_t)vbits > MAX_INT_SIZE - wbits) {
+            return NULL;
+        }
+    }
+
+    return PyNumber_Lshift(v, w);
+}
+
+static PyObject *
+const_folding_safe_mod(PyObject *v, PyObject *w)
+{
+    if (PyUnicode_Check(v) || PyBytes_Check(v)) {
+        return NULL;
+    }
+
+    return PyNumber_Remainder(v, w);
+}
+
+static PyObject *
+eval_const_binop(PyObject *left, int op, PyObject *right)
+{
+    assert(left != NULL && right != NULL);
+    assert(op >= 0 && op <= NB_OPARG_LAST);
+
+    PyObject *result = NULL;
+    switch (op) {
+        case NB_ADD:
+            result = PyNumber_Add(left, right);
+            break;
+        case NB_SUBTRACT:
+            result = PyNumber_Subtract(left, right);
+            break;
+        case NB_MULTIPLY:
+            result = const_folding_safe_multiply(left, right);
+            break;
+        case NB_TRUE_DIVIDE:
+            result = PyNumber_TrueDivide(left, right);
+            break;
+        case NB_FLOOR_DIVIDE:
+            result = PyNumber_FloorDivide(left, right);
+            break;
+        case NB_REMAINDER:
+            result = const_folding_safe_mod(left, right);
+            break;
+        case NB_POWER:
+            result = const_folding_safe_power(left, right);
+            break;
+        case NB_LSHIFT:
+            result = const_folding_safe_lshift(left, right);
+            break;
+        case NB_RSHIFT:
+            result = PyNumber_Rshift(left, right);
+            break;
+        case NB_OR:
+            result = PyNumber_Or(left, right);
+            break;
+        case NB_XOR:
+            result = PyNumber_Xor(left, right);
+            break;
+        case NB_AND:
+            result = PyNumber_And(left, right);
+            break;
+        case NB_SUBSCR:
+            result = PyObject_GetItem(left, right);
+            break;
+        case NB_MATRIX_MULTIPLY:
+            // No builtin constants implement matrix multiplication
+            break;
+        default:
+            Py_UNREACHABLE();
+    }
+    return result;
 }
 
 static int
-optimize_if_const_binop(basicblock *bb, int i, PyObject *consts, PyObject *const_cache)
+fold_const_binop(basicblock *bb, int i, PyObject *consts, PyObject *const_cache)
 {
+    #define BINOP_OPERAND_COUNT 2
+    assert(PyDict_CheckExact(const_cache));
+    assert(PyList_CheckExact(consts));
     cfg_instr *binop = &bb->b_instr[i];
     assert(binop->i_opcode == BINARY_OP);
-    if (binop->i_oparg != NB_SUBSCR) {
-        /* TODO: support other binary ops */
-        return SUCCESS;
-    }
     PyObject *pair;
-    RETURN_IF_ERROR(get_constant_sequence(bb, i-1, 2, consts, &pair));
+    RETURN_IF_ERROR(get_constant_sequence(bb, i-1, BINOP_OPERAND_COUNT, consts, &pair));
     if (pair == NULL) {
+        /* not a const sequence */
         return SUCCESS;
     }
-    assert(PyTuple_CheckExact(pair) && PyTuple_Size(pair) == 2);
+    assert(PyTuple_Size(pair) == BINOP_OPERAND_COUNT);
     PyObject *left = PyTuple_GET_ITEM(pair, 0);
     PyObject *right = PyTuple_GET_ITEM(pair, 1);
-    assert(left != NULL && right != NULL);
-    PyObject *newconst = PyObject_GetItem(left, right);
+    PyObject *newconst = eval_const_binop(left, binop->i_oparg, right);
     Py_DECREF(pair);
     if (newconst == NULL) {
         if (PyErr_ExceptionMatches(PyExc_KeyboardInterrupt)) {
@@ -1544,10 +1717,78 @@ optimize_if_const_binop(basicblock *bb, int i, PyObject *consts, PyObject *const
         PyErr_Clear();
         return SUCCESS;
     }
-    int newopcode, newoparg;
-    RETURN_IF_ERROR(newop_from_folded(newconst, consts, const_cache, &newopcode, &newoparg));
-    nop_out(bb, i-1, 2);
-    INSTR_SET_OP1(binop, newopcode, newoparg);
+    RETURN_IF_ERROR(instr_make_load_const(binop, newconst, consts, const_cache));
+    nop_out(bb, i-1, BINOP_OPERAND_COUNT);
+    return SUCCESS;
+}
+
+static PyObject *
+eval_const_unaryop(PyObject *operand, int opcode, int oparg)
+{
+    assert(operand != NULL);
+    assert(
+        opcode == UNARY_NEGATIVE ||
+        opcode == UNARY_INVERT ||
+        opcode == UNARY_NOT ||
+        (opcode == CALL_INTRINSIC_1 && oparg == INTRINSIC_UNARY_POSITIVE)
+    );
+    PyObject *result;
+    switch (opcode) {
+        case UNARY_NEGATIVE:
+            result = PyNumber_Negative(operand);
+            break;
+        case UNARY_INVERT:
+            result = PyNumber_Invert(operand);
+            break;
+        case UNARY_NOT: {
+            int r = PyObject_IsTrue(operand);
+            if (r < 0) {
+                return NULL;
+            }
+            result = PyBool_FromLong(!r);
+            break;
+        }
+        case CALL_INTRINSIC_1:
+            if (oparg != INTRINSIC_UNARY_POSITIVE) {
+                Py_UNREACHABLE();
+            }
+            result = PyNumber_Positive(operand);
+            break;
+        default:
+            Py_UNREACHABLE();
+    }
+    return result;
+}
+
+static int
+fold_const_unaryop(basicblock *bb, int i, PyObject *consts, PyObject *const_cache)
+{
+    #define UNARYOP_OPERAND_COUNT 1
+    assert(PyDict_CheckExact(const_cache));
+    assert(PyList_CheckExact(consts));
+    cfg_instr *instr = &bb->b_instr[i];
+    PyObject *seq;
+    RETURN_IF_ERROR(get_constant_sequence(bb, i-1, UNARYOP_OPERAND_COUNT, consts, &seq));
+    if (seq == NULL) {
+        /* not a const */
+        return SUCCESS;
+    }
+    assert(PyTuple_Size(seq) == UNARYOP_OPERAND_COUNT);
+    PyObject *operand = PyTuple_GET_ITEM(seq, 0);
+    PyObject *newconst = eval_const_unaryop(operand, instr->i_opcode, instr->i_oparg);
+    Py_DECREF(seq);
+    if (newconst == NULL) {
+        if (PyErr_ExceptionMatches(PyExc_KeyboardInterrupt)) {
+            return ERROR;
+        }
+        PyErr_Clear();
+        return SUCCESS;
+    }
+    if (instr->i_opcode == UNARY_NOT) {
+        assert(PyBool_Check(newconst));
+    }
+    RETURN_IF_ERROR(instr_make_load_const(instr, newconst, consts, const_cache));
+    nop_out(bb, i-1, UNARYOP_OPERAND_COUNT);
     return SUCCESS;
 }
 
@@ -2023,6 +2264,13 @@ optimize_basic_block(PyObject *const_cache, basicblock *bb, PyObject *consts)
                     INSTR_SET_OP1(&bb->b_instr[i + 1], opcode, oparg);
                     continue;
                 }
+                if (nextop == UNARY_NOT) {
+                    INSTR_SET_OP0(inst, NOP);
+                    int inverted = oparg ^ 1;
+                    assert(inverted == 0 || inverted == 1);
+                    INSTR_SET_OP1(&bb->b_instr[i + 1], opcode, inverted);
+                    continue;
+                }
                 break;
             case TO_BOOL:
                 if (nextop == TO_BOOL) {
@@ -2041,15 +2289,22 @@ optimize_basic_block(PyObject *const_cache, basicblock *bb, PyObject *consts)
                     INSTR_SET_OP0(&bb->b_instr[i + 1], NOP);
                     continue;
                 }
+                _Py_FALLTHROUGH;
+            case UNARY_INVERT:
+            case UNARY_NEGATIVE:
+                RETURN_IF_ERROR(fold_const_unaryop(bb, i, consts, const_cache));
                 break;
             case CALL_INTRINSIC_1:
                 // for _ in (*foo, *bar) -> for _ in [*foo, *bar]
                 if (oparg == INTRINSIC_LIST_TO_TUPLE && nextop == GET_ITER) {
                     INSTR_SET_OP0(inst, NOP);
                 }
+                else if (oparg == INTRINSIC_UNARY_POSITIVE) {
+                    RETURN_IF_ERROR(fold_const_unaryop(bb, i, consts, const_cache));
+                }
                 break;
             case BINARY_OP:
-                RETURN_IF_ERROR(optimize_if_const_binop(bb, i, consts, const_cache));
+                RETURN_IF_ERROR(fold_const_binop(bb, i, consts, const_cache));
                 break;
         }
     }