git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
GH-127809: Fix the JIT's understanding of ** (GH-127844)
author: Brandt Bucher <brandtbucher@microsoft.com>
Wed, 8 Jan 2025 01:25:48 +0000 (17:25 -0800)
committer: GitHub <noreply@github.com>
Wed, 8 Jan 2025 01:25:48 +0000 (17:25 -0800)
Lib/test/test_capi/test_opt.py
Misc/NEWS.d/next/Core_and_Builtins/2024-12-11-14-32-22.gh-issue-127809.0W8khe.rst [new file with mode: 0644]
Python/bytecodes.c
Python/executor_cases.c.h
Python/generated_cases.c.h
Python/optimizer_bytecodes.c
Python/optimizer_cases.c.h
Tools/cases_generator/analyzer.py

index 4cf9b66170c055f9b0290f8006bfdd54494ac92b..d84702411afe41d0f0d8f2c30bbf96759a79d10a 100644 (file)
@@ -1,4 +1,5 @@
 import contextlib
+import itertools
 import sys
 import textwrap
 import unittest
@@ -1511,6 +1512,49 @@ class TestUopsOptimization(unittest.TestCase):
         with self.assertRaises(TypeError):
             {item for item in items}
 
+    def test_power_type_depends_on_input_values(self):
+        template = textwrap.dedent("""
+            import _testinternalcapi
+
+            L, R, X, Y = {l}, {r}, {x}, {y}
+
+            def check(actual: complex, expected: complex) -> None:
+                assert actual == expected, (actual, expected)
+                assert type(actual) is type(expected), (actual, expected)
+
+            def f(l: complex, r: complex) -> None:
+                expected_local_local = pow(l, r) + pow(l, r)
+                expected_const_local = pow(L, r) + pow(L, r)
+                expected_local_const = pow(l, R) + pow(l, R)
+                expected_const_const = pow(L, R) + pow(L, R)
+                for _ in range(_testinternalcapi.TIER2_THRESHOLD):
+                    # Narrow types:
+                    l + l, r + r
+                    # The powers produce results, and the addition is unguarded:
+                    check(l ** r + l ** r, expected_local_local)
+                    check(L ** r + L ** r, expected_const_local)
+                    check(l ** R + l ** R, expected_local_const)
+                    check(L ** R + L ** R, expected_const_const)
+
+            # JIT for one pair of values...
+            f(L, R)
+            # ...then run with another:
+            f(X, Y)
+        """)
+        interesting = [
+            (1, 1),  # int ** int -> int
+            (1, -1),  # int ** int -> float
+            (1.0, 1),  # float ** int -> float
+            (1, 1.0),  # int ** float -> float
+            (-1, 0.5),  # int ** float -> complex
+            (1.0, 1.0),  # float ** float -> float
+            (-1.0, 0.5),  # float ** float -> complex
+        ]
+        for (l, r), (x, y) in itertools.product(interesting, repeat=2):
+            s = template.format(l=l, r=r, x=x, y=y)
+            with self.subTest(l=l, r=r, x=x, y=y):
+                script_helper.assert_python_ok("-c", s)
+
 
 def global_identity(x):
     return x
diff --git a/Misc/NEWS.d/next/Core_and_Builtins/2024-12-11-14-32-22.gh-issue-127809.0W8khe.rst b/Misc/NEWS.d/next/Core_and_Builtins/2024-12-11-14-32-22.gh-issue-127809.0W8khe.rst
new file mode 100644 (file)
index 0000000..19c8cc6
--- /dev/null
@@ -0,0 +1,2 @@
+Fix an issue where the experimental JIT may infer an incorrect result type
+for exponentiation (``**`` and ``**=``), leading to bugs or crashes.
index ec1cd00962ac0aee02d8abd6db4ca7771cde2a6a..8bab4ea16b629b6f047814220972feac04c84a86 100644 (file)
@@ -530,6 +530,8 @@ dummy_func(
         pure op(_BINARY_OP_MULTIPLY_INT, (left, right -- res)) {
             PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
             PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
+            assert(PyLong_CheckExact(left_o));
+            assert(PyLong_CheckExact(right_o));
 
             STAT_INC(BINARY_OP, hit);
             PyObject *res_o = _PyLong_Multiply((PyLongObject *)left_o, (PyLongObject *)right_o);
@@ -543,6 +545,8 @@ dummy_func(
         pure op(_BINARY_OP_ADD_INT, (left, right -- res)) {
             PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
             PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
+            assert(PyLong_CheckExact(left_o));
+            assert(PyLong_CheckExact(right_o));
 
             STAT_INC(BINARY_OP, hit);
             PyObject *res_o = _PyLong_Add((PyLongObject *)left_o, (PyLongObject *)right_o);
@@ -556,6 +560,8 @@ dummy_func(
         pure op(_BINARY_OP_SUBTRACT_INT, (left, right -- res)) {
             PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
             PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
+            assert(PyLong_CheckExact(left_o));
+            assert(PyLong_CheckExact(right_o));
 
             STAT_INC(BINARY_OP, hit);
             PyObject *res_o = _PyLong_Subtract((PyLongObject *)left_o, (PyLongObject *)right_o);
@@ -593,6 +599,8 @@ dummy_func(
         pure op(_BINARY_OP_MULTIPLY_FLOAT, (left, right -- res)) {
             PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
             PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
+            assert(PyFloat_CheckExact(left_o));
+            assert(PyFloat_CheckExact(right_o));
 
             STAT_INC(BINARY_OP, hit);
             double dres =
@@ -607,6 +615,8 @@ dummy_func(
         pure op(_BINARY_OP_ADD_FLOAT, (left, right -- res)) {
             PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
             PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
+            assert(PyFloat_CheckExact(left_o));
+            assert(PyFloat_CheckExact(right_o));
 
             STAT_INC(BINARY_OP, hit);
             double dres =
@@ -621,6 +631,8 @@ dummy_func(
         pure op(_BINARY_OP_SUBTRACT_FLOAT, (left, right -- res)) {
             PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
             PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
+            assert(PyFloat_CheckExact(left_o));
+            assert(PyFloat_CheckExact(right_o));
 
             STAT_INC(BINARY_OP, hit);
             double dres =
@@ -650,6 +662,8 @@ dummy_func(
         pure op(_BINARY_OP_ADD_UNICODE, (left, right -- res)) {
             PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
             PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
+            assert(PyUnicode_CheckExact(left_o));
+            assert(PyUnicode_CheckExact(right_o));
 
             STAT_INC(BINARY_OP, hit);
             PyObject *res_o = PyUnicode_Concat(left_o, right_o);
@@ -672,6 +686,8 @@ dummy_func(
         op(_BINARY_OP_INPLACE_ADD_UNICODE, (left, right --)) {
             PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
             PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
+            assert(PyUnicode_CheckExact(left_o));
+            assert(PyUnicode_CheckExact(right_o));
 
             int next_oparg;
         #if TIER_ONE
index ac2f69b7e98dc33aecb5ae52592ee274a17e7144..e40fa88be891728daeeada9d5c5f4b8b72fa02ec 100644 (file)
             left = stack_pointer[-2];
             PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
             PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
+            assert(PyLong_CheckExact(left_o));
+            assert(PyLong_CheckExact(right_o));
             STAT_INC(BINARY_OP, hit);
             PyObject *res_o = _PyLong_Multiply((PyLongObject *)left_o, (PyLongObject *)right_o);
             PyStackRef_CLOSE_SPECIALIZED(right, _PyLong_ExactDealloc);
             left = stack_pointer[-2];
             PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
             PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
+            assert(PyLong_CheckExact(left_o));
+            assert(PyLong_CheckExact(right_o));
             STAT_INC(BINARY_OP, hit);
             PyObject *res_o = _PyLong_Add((PyLongObject *)left_o, (PyLongObject *)right_o);
             PyStackRef_CLOSE_SPECIALIZED(right, _PyLong_ExactDealloc);
             left = stack_pointer[-2];
             PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
             PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
+            assert(PyLong_CheckExact(left_o));
+            assert(PyLong_CheckExact(right_o));
             STAT_INC(BINARY_OP, hit);
             PyObject *res_o = _PyLong_Subtract((PyLongObject *)left_o, (PyLongObject *)right_o);
             PyStackRef_CLOSE_SPECIALIZED(right, _PyLong_ExactDealloc);
             left = stack_pointer[-2];
             PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
             PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
+            assert(PyFloat_CheckExact(left_o));
+            assert(PyFloat_CheckExact(right_o));
             STAT_INC(BINARY_OP, hit);
             double dres =
             ((PyFloatObject *)left_o)->ob_fval *
             left = stack_pointer[-2];
             PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
             PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
+            assert(PyFloat_CheckExact(left_o));
+            assert(PyFloat_CheckExact(right_o));
             STAT_INC(BINARY_OP, hit);
             double dres =
             ((PyFloatObject *)left_o)->ob_fval +
             left = stack_pointer[-2];
             PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
             PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
+            assert(PyFloat_CheckExact(left_o));
+            assert(PyFloat_CheckExact(right_o));
             STAT_INC(BINARY_OP, hit);
             double dres =
             ((PyFloatObject *)left_o)->ob_fval -
             left = stack_pointer[-2];
             PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
             PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
+            assert(PyUnicode_CheckExact(left_o));
+            assert(PyUnicode_CheckExact(right_o));
             STAT_INC(BINARY_OP, hit);
             PyObject *res_o = PyUnicode_Concat(left_o, right_o);
             PyStackRef_CLOSE_SPECIALIZED(left, _PyUnicode_ExactDealloc);
             left = stack_pointer[-2];
             PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
             PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
+            assert(PyUnicode_CheckExact(left_o));
+            assert(PyUnicode_CheckExact(right_o));
             int next_oparg;
             #if TIER_ONE
             assert(next_instr->op.code == STORE_FAST);
index eaa8a5634640686c2ab21cae1bb25acc2884da3d..7028ba52faae96aa7bb08c286e603a3ef1125733 100644 (file)
@@ -80,6 +80,8 @@
             {
                 PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
                 PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
+                assert(PyFloat_CheckExact(left_o));
+                assert(PyFloat_CheckExact(right_o));
                 STAT_INC(BINARY_OP, hit);
                 double dres =
                 ((PyFloatObject *)left_o)->ob_fval +
             {
                 PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
                 PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
+                assert(PyLong_CheckExact(left_o));
+                assert(PyLong_CheckExact(right_o));
                 STAT_INC(BINARY_OP, hit);
                 PyObject *res_o = _PyLong_Add((PyLongObject *)left_o, (PyLongObject *)right_o);
                 PyStackRef_CLOSE_SPECIALIZED(right, _PyLong_ExactDealloc);
             {
                 PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
                 PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
+                assert(PyUnicode_CheckExact(left_o));
+                assert(PyUnicode_CheckExact(right_o));
                 STAT_INC(BINARY_OP, hit);
                 PyObject *res_o = PyUnicode_Concat(left_o, right_o);
                 PyStackRef_CLOSE_SPECIALIZED(left, _PyUnicode_ExactDealloc);
             {
                 PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
                 PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
+                assert(PyUnicode_CheckExact(left_o));
+                assert(PyUnicode_CheckExact(right_o));
                 int next_oparg;
                 #if TIER_ONE
                 assert(next_instr->op.code == STORE_FAST);
             {
                 PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
                 PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
+                assert(PyFloat_CheckExact(left_o));
+                assert(PyFloat_CheckExact(right_o));
                 STAT_INC(BINARY_OP, hit);
                 double dres =
                 ((PyFloatObject *)left_o)->ob_fval *
             {
                 PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
                 PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
+                assert(PyLong_CheckExact(left_o));
+                assert(PyLong_CheckExact(right_o));
                 STAT_INC(BINARY_OP, hit);
                 PyObject *res_o = _PyLong_Multiply((PyLongObject *)left_o, (PyLongObject *)right_o);
                 PyStackRef_CLOSE_SPECIALIZED(right, _PyLong_ExactDealloc);
             {
                 PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
                 PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
+                assert(PyFloat_CheckExact(left_o));
+                assert(PyFloat_CheckExact(right_o));
                 STAT_INC(BINARY_OP, hit);
                 double dres =
                 ((PyFloatObject *)left_o)->ob_fval -
             {
                 PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
                 PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
+                assert(PyLong_CheckExact(left_o));
+                assert(PyLong_CheckExact(right_o));
                 STAT_INC(BINARY_OP, hit);
                 PyObject *res_o = _PyLong_Subtract((PyLongObject *)left_o, (PyLongObject *)right_o);
                 PyStackRef_CLOSE_SPECIALIZED(right, _PyLong_ExactDealloc);
index a14d119b7a1dece0d352cd382edb36c84c9a1830..86394480f76bb8bb185cc67de1137dda51cd453c 100644 (file)
@@ -167,23 +167,56 @@ dummy_func(void) {
     }
 
     op(_BINARY_OP, (left, right -- res)) {
-        PyTypeObject *ltype = sym_get_type(left);
-        PyTypeObject *rtype = sym_get_type(right);
-        if (ltype != NULL && (ltype == &PyLong_Type || ltype == &PyFloat_Type) &&
-            rtype != NULL && (rtype == &PyLong_Type || rtype == &PyFloat_Type))
-        {
-            if (oparg != NB_TRUE_DIVIDE && oparg != NB_INPLACE_TRUE_DIVIDE &&
-                ltype == &PyLong_Type && rtype == &PyLong_Type) {
-                /* If both inputs are ints and the op is not division the result is an int */
-                res = sym_new_type(ctx, &PyLong_Type);
+        bool lhs_int = sym_matches_type(left, &PyLong_Type);
+        bool rhs_int = sym_matches_type(right, &PyLong_Type);
+        bool lhs_float = sym_matches_type(left, &PyFloat_Type);
+        bool rhs_float = sym_matches_type(right, &PyFloat_Type);
+        if (!((lhs_int || lhs_float) && (rhs_int || rhs_float))) {
+            // There's something other than an int or float involved:
+            res = sym_new_unknown(ctx);
+        }
+        else if (oparg == NB_POWER || oparg == NB_INPLACE_POWER) {
+            // This one's fun... the *type* of the result depends on the
+            // *values* being exponentiated. However, exponents with one
+            // constant part are reasonably common, so it's probably worth
+            // trying to infer some simple cases:
+            // - A: 1 ** 1 -> 1 (int ** int -> int)
+            // - B: 1 ** -1 -> 1.0 (int ** int -> float)
+            // - C: 1.0 ** 1 -> 1.0 (float ** int -> float)
+            // - D: 1 ** 1.0 -> 1.0 (int ** float -> float)
+            // - E: -1 ** 0.5 ~> 1j (int ** float -> complex)
+            // - F: 1.0 ** 1.0 -> 1.0 (float ** float -> float)
+            // - G: -1.0 ** 0.5 ~> 1j (float ** float -> complex)
+            if (rhs_float) {
+                // Case D, E, F, or G... can't know without the sign of the LHS
+                // or whether the RHS is whole, which isn't worth the effort:
+                res = sym_new_unknown(ctx);
             }
-            else {
-                /* For any other op combining ints/floats the result is a float */
+            else if (lhs_float) {
+                // Case C:
                 res = sym_new_type(ctx, &PyFloat_Type);
             }
+            else if (!sym_is_const(right)) {
+                // Case A or B... can't know without the sign of the RHS:
+                res = sym_new_unknown(ctx);
+            }
+            else if (_PyLong_IsNegative((PyLongObject *)sym_get_const(right))) {
+                // Case B:
+                res = sym_new_type(ctx, &PyFloat_Type);
+            }
+            else {
+                // Case A:
+                res = sym_new_type(ctx, &PyLong_Type);
+            }
+        }
+        else if (oparg == NB_TRUE_DIVIDE || oparg == NB_INPLACE_TRUE_DIVIDE) {
+            res = sym_new_type(ctx, &PyFloat_Type);
+        }
+        else if (lhs_int && rhs_int) {
+            res = sym_new_type(ctx, &PyLong_Type);
         }
         else {
-            res = sym_new_unknown(ctx);
+            res = sym_new_type(ctx, &PyFloat_Type);
         }
     }
 
index be3e06108aec92461b887c1eea707803aaac2e4f..c72ae7b6281e80cc1bbd6677880c7441f2e432ce 100644 (file)
             _Py_UopsSymbol *res;
             right = stack_pointer[-1];
             left = stack_pointer[-2];
-            PyTypeObject *ltype = sym_get_type(left);
-            PyTypeObject *rtype = sym_get_type(right);
-            if (ltype != NULL && (ltype == &PyLong_Type || ltype == &PyFloat_Type) &&
-                rtype != NULL && (rtype == &PyLong_Type || rtype == &PyFloat_Type))
-            {
-                if (oparg != NB_TRUE_DIVIDE && oparg != NB_INPLACE_TRUE_DIVIDE &&
-                    ltype == &PyLong_Type && rtype == &PyLong_Type) {
-                    /* If both inputs are ints and the op is not division the result is an int */
-                    res = sym_new_type(ctx, &PyLong_Type);
+            bool lhs_int = sym_matches_type(left, &PyLong_Type);
+            bool rhs_int = sym_matches_type(right, &PyLong_Type);
+            bool lhs_float = sym_matches_type(left, &PyFloat_Type);
+            bool rhs_float = sym_matches_type(right, &PyFloat_Type);
+            if (!((lhs_int || lhs_float) && (rhs_int || rhs_float))) {
+                // There's something other than an int or float involved:
+                res = sym_new_unknown(ctx);
+            }
+            else {
+                if (oparg == NB_POWER || oparg == NB_INPLACE_POWER) {
+                    // This one's fun... the *type* of the result depends on the
+                    // *values* being exponentiated. However, exponents with one
+                    // constant part are reasonably common, so it's probably worth
+                    // trying to infer some simple cases:
+                    // - A: 1 ** 1 -> 1 (int ** int -> int)
+                    // - B: 1 ** -1 -> 1.0 (int ** int -> float)
+                    // - C: 1.0 ** 1 -> 1.0 (float ** int -> float)
+                    // - D: 1 ** 1.0 -> 1.0 (int ** float -> float)
+                    // - E: -1 ** 0.5 ~> 1j (int ** float -> complex)
+                    // - F: 1.0 ** 1.0 -> 1.0 (float ** float -> float)
+                    // - G: -1.0 ** 0.5 ~> 1j (float ** float -> complex)
+                    if (rhs_float) {
+                        // Case D, E, F, or G... can't know without the sign of the LHS
+                        // or whether the RHS is whole, which isn't worth the effort:
+                        res = sym_new_unknown(ctx);
+                    }
+                    else {
+                        if (lhs_float) {
+                            // Case C:
+                            res = sym_new_type(ctx, &PyFloat_Type);
+                        }
+                        else {
+                            if (!sym_is_const(right)) {
+                                // Case A or B... can't know without the sign of the RHS:
+                                res = sym_new_unknown(ctx);
+                            }
+                            else {
+                                if (_PyLong_IsNegative((PyLongObject *)sym_get_const(right))) {
+                                    // Case B:
+                                    res = sym_new_type(ctx, &PyFloat_Type);
+                                }
+                                else {
+                                    // Case A:
+                                    res = sym_new_type(ctx, &PyLong_Type);
+                                }
+                            }
+                        }
+                    }
                 }
                 else {
-                    /* For any other op combining ints/floats the result is a float */
-                    res = sym_new_type(ctx, &PyFloat_Type);
+                    if (oparg == NB_TRUE_DIVIDE || oparg == NB_INPLACE_TRUE_DIVIDE) {
+                        res = sym_new_type(ctx, &PyFloat_Type);
+                    }
+                    else {
+                        if (lhs_int && rhs_int) {
+                            res = sym_new_type(ctx, &PyLong_Type);
+                        }
+                        else {
+                            res = sym_new_type(ctx, &PyFloat_Type);
+                        }
+                    }
                 }
             }
-            else {
-                res = sym_new_unknown(ctx);
-            }
             stack_pointer[-2] = res;
             stack_pointer += -1;
             assert(WITHIN_STACK_BOUNDS());
index c0a370a936aa943b443abf6b93e78abfc3f0d110..679beca3ec3a9dabd5596ad21dc795614781c29e 100644 (file)
@@ -599,6 +599,7 @@ NON_ESCAPING_FUNCTIONS = (
     "_PyLong_CompactValue",
     "_PyLong_DigitCount",
     "_PyLong_IsCompact",
+    "_PyLong_IsNegative",
     "_PyLong_IsNonNegativeCompact",
     "_PyLong_IsZero",
     "_PyLong_Multiply",