]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-143006: Fix and optimize mixed comparison of float and int (GH-143084)
authorSerhiy Storchaka <storchaka@gmail.com>
Fri, 9 Jan 2026 17:06:45 +0000 (19:06 +0200)
committerGitHub <noreply@github.com>
Fri, 9 Jan 2026 17:06:45 +0000 (19:06 +0200)
When comparing negative non-integer float and int with the same number
of bits in the integer part, __neg__() in the int subclass returning
not an int caused an assertion error.

Now the integer is no longer negated. Also, reduced the number of
temporary created Python objects.

Lib/test/test_float.py
Misc/NEWS.d/next/Core_and_Builtins/2025-12-22-22-37-53.gh-issue-143006.ZBQwbN.rst [new file with mode: 0644]
Objects/floatobject.c

index 00518abcb11b46b31c4fa1d3ca2991a8d7f9430c..c03b0a09f718899b3b51f78eb1a8a09b7d2d1966 100644 (file)
@@ -651,6 +651,24 @@ class GeneralFloatCases(unittest.TestCase):
         value = F('nan')
         self.assertEqual(hash(value), object.__hash__(value))
 
+    def test_issue_gh143006(self):
+        # When comparing negative non-integer float and int with the
+        # same number of bits in the integer part, __neg__() in the
+        # int subclass returning not an int caused an assertion error.
+        class EvilInt(int):
+            def __neg__(self):
+                return ""
+
+        i = -1 << 50
+        f = float(i) - 0.5
+        i = EvilInt(i)
+        self.assertFalse(f == i)
+        self.assertTrue(f != i)
+        self.assertTrue(f < i)
+        self.assertTrue(f <= i)
+        self.assertFalse(f > i)
+        self.assertFalse(f >= i)
+
 
 @unittest.skipUnless(hasattr(float, "__getformat__"), "requires __getformat__")
 class FormatFunctionsTestCase(unittest.TestCase):
diff --git a/Misc/NEWS.d/next/Core_and_Builtins/2025-12-22-22-37-53.gh-issue-143006.ZBQwbN.rst b/Misc/NEWS.d/next/Core_and_Builtins/2025-12-22-22-37-53.gh-issue-143006.ZBQwbN.rst
new file mode 100644 (file)
index 0000000..f256203
--- /dev/null
@@ -0,0 +1,2 @@
+Fix a possible assertion error when comparing negative non-integer ``float``
+and ``int`` with the same number of bits in the integer part.
index 2cb690748d9de43b5ba85b363246aa9061ad3560..579765281ca484c89899d47beace9361235d07e6 100644 (file)
@@ -435,82 +435,67 @@ float_richcompare(PyObject *v, PyObject *w, int op)
         assert(vsign != 0); /* if vsign were 0, then since wsign is
                              * not 0, we would have taken the
                              * vsign != wsign branch at the start */
-        /* We want to work with non-negative numbers. */
-        if (vsign < 0) {
-            /* "Multiply both sides" by -1; this also swaps the
-             * comparator.
-             */
-            i = -i;
-            op = _Py_SwappedOp[op];
-        }
-        assert(i > 0.0);
         (void) frexp(i, &exponent);
         /* exponent is the # of bits in v before the radix point;
          * we know that nbits (the # of bits in w) > 48 at this point
          */
         if (exponent < nbits) {
-            i = 1.0;
-            j = 2.0;
+            j = i;
+            i = 0.0;
             goto Compare;
         }
         if (exponent > nbits) {
-            i = 2.0;
-            j = 1.0;
+            j = 0.0;
             goto Compare;
         }
         /* v and w have the same number of bits before the radix
-         * point.  Construct two ints that have the same comparison
-         * outcome.
+         * point.  Construct an int from the integer part of v and
+         * update op if necessary, so comparing two ints has the same outcome.
          */
         {
             double fracpart;
             double intpart;
             PyObject *result = NULL;
             PyObject *vv = NULL;
-            PyObject *ww = w;
 
-            if (wsign < 0) {
-                ww = PyNumber_Negative(w);
-                if (ww == NULL)
-                    goto Error;
+            fracpart = modf(i, &intpart);
+            if (fracpart != 0.0) {
+                switch (op) {
+                    /* Non-integer float never equals to an int. */
+                    case Py_EQ:
+                        Py_RETURN_FALSE;
+                    case Py_NE:
+                        Py_RETURN_TRUE;
+                    /* For non-integer float, v <= w <=> v < w.
+                     * If v > 0: trunc(v) < v < trunc(v) + 1
+                     *   v < w => trunc(v) < w
+                     *   trunc(v) < w => trunc(v) + 1 <= w => v < w
+                     * If v < 0: trunc(v) - 1 < v < trunc(v)
+                     *   v < w => trunc(v) - 1 < w => trunc(v) <= w
+                     *   trunc(v) <= w => v < w
+                     */
+                    case Py_LT:
+                    case Py_LE:
+                        op = vsign > 0 ? Py_LT : Py_LE;
+                        break;
+                    /* The same as above, but with opposite directions. */
+                    case Py_GT:
+                    case Py_GE:
+                        op = vsign > 0 ? Py_GE : Py_GT;
+                        break;
+                }
             }
-            else
-                Py_INCREF(ww);
 
-            fracpart = modf(i, &intpart);
             vv = PyLong_FromDouble(intpart);
             if (vv == NULL)
                 goto Error;
 
-            if (fracpart != 0.0) {
-                /* Shift left, and or a 1 bit into vv
-                 * to represent the lost fraction.
-                 */
-                PyObject *temp;
-
-                temp = _PyLong_Lshift(ww, 1);
-                if (temp == NULL)
-                    goto Error;
-                Py_SETREF(ww, temp);
-
-                temp = _PyLong_Lshift(vv, 1);
-                if (temp == NULL)
-                    goto Error;
-                Py_SETREF(vv, temp);
-
-                temp = PyNumber_Or(vv, _PyLong_GetOne());
-                if (temp == NULL)
-                    goto Error;
-                Py_SETREF(vv, temp);
-            }
-
-            r = PyObject_RichCompareBool(vv, ww, op);
+            r = PyObject_RichCompareBool(vv, w, op);
             if (r < 0)
                 goto Error;
             result = PyBool_FromLong(r);
          Error:
             Py_XDECREF(vv);
-            Py_XDECREF(ww);
             return result;
         }
     } /* else if (PyLong_Check(w)) */