]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
bpo-46233: Minor speedup for bigint squaring (GH-30345)
authorTim Peters <tim.peters@gmail.com>
Tue, 4 Jan 2022 02:41:16 +0000 (20:41 -0600)
committerGitHub <noreply@github.com>
Tue, 4 Jan 2022 02:41:16 +0000 (20:41 -0600)
x_mul()'s squaring code can do some redundant and/or useless
work at the end of each digit pass. A more careful analysis
of worst-case carries at various digit positions allows
making that code leaner.

Lib/test/test_long.py
Objects/longobject.c

index 3c8e9e22e17a196d472ce62c0b77ec68418104eb..f2a622b5868f0851956469f1b5ee053a6753e335 100644 (file)
@@ -1502,6 +1502,17 @@ class LongTest(unittest.TestCase):
             self.assertEqual(type(numerator), int)
             self.assertEqual(type(denominator), int)
 
+    def test_square(self):
+        # Multiplication makes a special case of multiplying an int with
+        # itself, using a special, faster algorithm. This test is mostly
+        # to ensure that no asserts in the implementation trigger, in
+        # cases with a maximal amount of carries.
+        for bitlen in range(1, 400):
+            n = (1 << bitlen) - 1 # solid string of 1 bits
+            with self.subTest(bitlen=bitlen, n=n):
+                # (2**i - 1)**2 = 2**(2*i) - 2*2**i + 1
+                self.assertEqual(n**2,
+                    (1 << (2 * bitlen)) - (1 << (bitlen + 1)) + 1)
 
 if __name__ == "__main__":
     unittest.main()
index b5648fca7dc5cc2846a8992046c3864cf198ca78..2db8701a841a9447256a2fe36a216a72e809ce1f 100644 (file)
@@ -3237,12 +3237,12 @@ x_mul(PyLongObject *a, PyLongObject *b)
          * via exploiting that each entry in the multiplication
          * pyramid appears twice (except for the size_a squares).
          */
+        digit *paend = a->ob_digit + size_a;
         for (i = 0; i < size_a; ++i) {
             twodigits carry;
             twodigits f = a->ob_digit[i];
             digit *pz = z->ob_digit + (i << 1);
             digit *pa = a->ob_digit + i + 1;
-            digit *paend = a->ob_digit + size_a;
 
             SIGCHECK({
                     Py_DECREF(z);
@@ -3265,13 +3265,27 @@ x_mul(PyLongObject *a, PyLongObject *b)
                 assert(carry <= (PyLong_MASK << 1));
             }
             if (carry) {
+                /* See comment below. pz points at the highest possible
+                 * carry position from the last outer loop iteration, so
+                 * *pz is at most 1.
+                 */
+                assert(*pz <= 1);
                 carry += *pz;
-                *pz++ = (digit)(carry & PyLong_MASK);
+                *pz = (digit)(carry & PyLong_MASK);
                 carry >>= PyLong_SHIFT;
+                if (carry) {
+                    /* If there's still a carry, it must be into a position
+                     * that still holds a 0. Where the base
+                     ^ B is 1 << PyLong_SHIFT, the last add was of a carry no
+                     * more than 2*B - 2 to a stored digit no more than 1.
+                     * So the sum was no more than 2*B - 1, so the current
+                     * carry no more than floor((2*B - 1)/B) = 1.
+                     */
+                    assert(carry == 1);
+                    assert(pz[1] == 0);
+                    pz[1] = (digit)carry;
+                }
             }
-            if (carry)
-                *pz += (digit)(carry & PyLong_MASK);
-            assert((carry >> PyLong_SHIFT) == 0);
         }
     }
     else {      /* a is not the same as b -- gradeschool int mult */