]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
Rework integer overflow path in math.prod and add more tests (GH-11809)
authorPablo Galindo <Pablogsal@gmail.com>
Sat, 9 Mar 2019 19:18:08 +0000 (19:18 +0000)
committerGitHub <noreply@github.com>
Sat, 9 Mar 2019 19:18:08 +0000 (19:18 +0000)
The overflow check was relying on undefined behaviour as it was using the result of the multiplication to do the check, and once the overflow has already happened, any operation on the result is undefined behaviour.

Some extra checks that exercise code paths related to this are also added.

Lib/test/test_math.py
Modules/mathmodule.c

index 856b1e8ac11e3bf0dbf6c0d4eb015cf68b1ccaf3..cb05dee0e0fd3c1949aff02846a0c29bc558bf0a 100644 (file)
@@ -1595,6 +1595,92 @@ class MathTests(unittest.TestCase):
             self.fail('Failures in test_mtestfile:\n  ' +
                       '\n  '.join(failures))
 
+    def test_prod(self):
+        prod = math.prod
+        self.assertEqual(prod([]), 1)
+        self.assertEqual(prod([], start=5), 5)
+        self.assertEqual(prod(list(range(2,8))), 5040)
+        self.assertEqual(prod(iter(list(range(2,8)))), 5040)
+        self.assertEqual(prod(range(1, 10), start=10), 3628800)
+
+        self.assertEqual(prod([1, 2, 3, 4, 5]), 120)
+        self.assertEqual(prod([1.0, 2.0, 3.0, 4.0, 5.0]), 120.0)
+        self.assertEqual(prod([1, 2, 3, 4.0, 5.0]), 120.0)
+        self.assertEqual(prod([1.0, 2.0, 3.0, 4, 5]), 120.0)
+
+        # Test overflow in fast-path for integers
+        self.assertEqual(prod([1, 1, 2**32, 1, 1]), 2**32)
+        # Test overflow in fast-path for floats
+        self.assertEqual(prod([1.0, 1.0, 2**32, 1, 1]), float(2**32))
+
+        self.assertRaises(TypeError, prod)
+        self.assertRaises(TypeError, prod, 42)
+        self.assertRaises(TypeError, prod, ['a', 'b', 'c'])
+        self.assertRaises(TypeError, prod, ['a', 'b', 'c'], '')
+        self.assertRaises(TypeError, prod, [b'a', b'c'], b'')
+        values = [bytearray(b'a'), bytearray(b'b')]
+        self.assertRaises(TypeError, prod, values, bytearray(b''))
+        self.assertRaises(TypeError, prod, [[1], [2], [3]])
+        self.assertRaises(TypeError, prod, [{2:3}])
+        self.assertRaises(TypeError, prod, [{2:3}]*2, {2:3})
+        self.assertRaises(TypeError, prod, [[1], [2], [3]], [])
+        with self.assertRaises(TypeError):
+            prod([10, 20], [30, 40])     # start is a keyword-only argument
+
+        self.assertEqual(prod([0, 1, 2, 3]), 0)
+        self.assertEqual(prod([1, 0, 2, 3]), 0)
+        self.assertEqual(prod([1, 2, 3, 0]), 0)
+
+        def _naive_prod(iterable, start=1):
+            for elem in iterable:
+                start *= elem
+            return start
+
+        # Big integers
+
+        iterable = range(1, 10000)
+        self.assertEqual(prod(iterable), _naive_prod(iterable))
+        iterable = range(-10000, -1)
+        self.assertEqual(prod(iterable), _naive_prod(iterable))
+        iterable = range(-1000, 1000)
+        self.assertEqual(prod(iterable), 0)
+
+        # Big floats
+
+        iterable = [float(x) for x in range(1, 1000)]
+        self.assertEqual(prod(iterable), _naive_prod(iterable))
+        iterable = [float(x) for x in range(-1000, -1)]
+        self.assertEqual(prod(iterable), _naive_prod(iterable))
+        iterable = [float(x) for x in range(-1000, 1000)]
+        self.assertIsNaN(prod(iterable))
+
+        # Float tests
+
+        self.assertIsNaN(prod([1, 2, 3, float("nan"), 2, 3]))
+        self.assertIsNaN(prod([1, 0, float("nan"), 2, 3]))
+        self.assertIsNaN(prod([1, float("nan"), 0, 3]))
+        self.assertIsNaN(prod([1, float("inf"), float("nan"),3]))
+        self.assertIsNaN(prod([1, float("-inf"), float("nan"),3]))
+        self.assertIsNaN(prod([1, float("nan"), float("inf"),3]))
+        self.assertIsNaN(prod([1, float("nan"), float("-inf"),3]))
+
+        self.assertEqual(prod([1, 2, 3, float('inf'),-3,4]), float('-inf'))
+        self.assertEqual(prod([1, 2, 3, float('-inf'),-3,4]), float('inf'))
+
+        self.assertIsNaN(prod([1,2,0,float('inf'), -3, 4]))
+        self.assertIsNaN(prod([1,2,0,float('-inf'), -3, 4]))
+        self.assertIsNaN(prod([1, 2, 3, float('inf'), -3, 0, 3]))
+        self.assertIsNaN(prod([1, 2, 3, float('-inf'), -3, 0, 2]))
+
+        # Type preservation
+
+        self.assertEqual(type(prod([1, 2, 3, 4, 5, 6])), int)
+        self.assertEqual(type(prod([1, 2.0, 3, 4, 5, 6])), float)
+        self.assertEqual(type(prod(range(1, 10000))), int)
+        self.assertEqual(type(prod(range(1, 10000), start=1.0)), float)
+        self.assertEqual(type(prod([1, decimal.Decimal(2.0), 3, 4, 5, 6])),
+                         decimal.Decimal)
+
     # Custom assertions.
 
     def assertIsNaN(self, value):
@@ -1724,41 +1810,6 @@ class IsCloseTests(unittest.TestCase):
         self.assertAllClose(fraction_examples, rel_tol=1e-8)
         self.assertAllNotClose(fraction_examples, rel_tol=1e-9)
 
-    def test_prod(self):
-        prod = math.prod
-        self.assertEqual(prod([]), 1)
-        self.assertEqual(prod([], start=5), 5)
-        self.assertEqual(prod(list(range(2,8))), 5040)
-        self.assertEqual(prod(iter(list(range(2,8)))), 5040)
-        self.assertEqual(prod(range(1, 10), start=10), 3628800)
-
-        self.assertEqual(prod([1, 2, 3, 4, 5]), 120)
-        self.assertEqual(prod([1.0, 2.0, 3.0, 4.0, 5.0]), 120.0)
-        self.assertEqual(prod([1, 2, 3, 4.0, 5.0]), 120.0)
-        self.assertEqual(prod([1.0, 2.0, 3.0, 4, 5]), 120.0)
-
-        # Test overflow in fast-path for integers
-        self.assertEqual(prod([1, 1, 2**32, 1, 1]), 2**32)
-        # Test overflow in fast-path for floats
-        self.assertEqual(prod([1.0, 1.0, 2**32, 1, 1]), float(2**32))
-
-        self.assertRaises(TypeError, prod)
-        self.assertRaises(TypeError, prod, 42)
-        self.assertRaises(TypeError, prod, ['a', 'b', 'c'])
-        self.assertRaises(TypeError, prod, ['a', 'b', 'c'], '')
-        self.assertRaises(TypeError, prod, [b'a', b'c'], b'')
-        values = [bytearray(b'a'), bytearray(b'b')]
-        self.assertRaises(TypeError, prod, values, bytearray(b''))
-        self.assertRaises(TypeError, prod, [[1], [2], [3]])
-        self.assertRaises(TypeError, prod, [{2:3}])
-        self.assertRaises(TypeError, prod, [{2:3}]*2, {2:3})
-        self.assertRaises(TypeError, prod, [[1], [2], [3]], [])
-        with self.assertRaises(TypeError):
-            prod([10, 20], [30, 40])     # start is a keyword-only argument
-
-        self.assertEqual(prod([0, 1, 2, 3]), 0)
-        self.assertEqual(prod([1, 0, 2, 3]), 0)
-        self.assertEqual(prod(range(10)), 0)
 
 def test_main():
     from doctest import DocFileSuite
index fd0eb327c74329031b3ba3464fcd863e0be6abb1..ba8423211c2b53f04e6b65dbf824d89a7faf1f0c 100644 (file)
@@ -2493,6 +2493,55 @@ math_isclose_impl(PyObject *module, double a, double b, double rel_tol,
             (diff <= abs_tol));
 }
 
+static inline int
+_check_long_mult_overflow(long a, long b) {
+
+    /* From Python2's int_mul code:
+
+    Integer overflow checking for * is painful:  Python tried a couple ways, but
+    they didn't work on all platforms, or failed in endcases (a product of
+    -sys.maxint-1 has been a particular pain).
+
+    Here's another way:
+
+    The native long product x*y is either exactly right or *way* off, being
+    just the last n bits of the true product, where n is the number of bits
+    in a long (the delivered product is the true product plus i*2**n for
+    some integer i).
+
+    The native double product (double)x * (double)y is subject to three
+    rounding errors:  on a sizeof(long)==8 box, each cast to double can lose
+    info, and even on a sizeof(long)==4 box, the multiplication can lose info.
+    But, unlike the native long product, it's not in *range* trouble:  even
+    if sizeof(long)==32 (256-bit longs), the product easily fits in the
+    dynamic range of a double.  So the leading 50 (or so) bits of the double
+    product are correct.
+
+    We check these two ways against each other, and declare victory if they're
+    approximately the same.  Else, because the native long product is the only
+    one that can lose catastrophic amounts of information, it's the native long
+    product that must have overflowed.
+
+    */
+
+    long longprod = (long)((unsigned long)a * b);
+    double doubleprod = (double)a * (double)b;
+    double doubled_longprod = (double)longprod;
+
+    if (doubled_longprod == doubleprod) {
+        return 0;
+    }
+
+    const double diff = doubled_longprod - doubleprod;
+    const double absdiff = diff >= 0.0 ? diff : -diff;
+    const double absprod = doubleprod >= 0.0 ? doubleprod : -doubleprod;
+
+    if (32.0 * absdiff <= absprod) {
+        return 0;
+    }
+
+    return 1;
+}
 
 /*[clinic input]
 math.prod
@@ -2558,11 +2607,8 @@ math_prod_impl(PyObject *module, PyObject *iterable, PyObject *start)
             }
             if (PyLong_CheckExact(item)) {
                 long b = PyLong_AsLongAndOverflow(item, &overflow);
-                long x = i_result * b;
-                /* Continue if there is no overflow */
-                if (overflow == 0
-                    && x < LONG_MAX && x > LONG_MIN
-                    && !(b != 0 && x / b != i_result)) {
+                if (overflow == 0 && !_check_long_mult_overflow(i_result, b)) {
+                    long x = i_result * b;
                     i_result = x;
                     Py_DECREF(item);
                     continue;