]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
bpo-46258: Streamline isqrt fast path (#30333)
authorMark Dickinson <mdickinson@enthought.com>
Sat, 15 Jan 2022 09:58:04 +0000 (09:58 +0000)
committerGitHub <noreply@github.com>
Sat, 15 Jan 2022 09:58:04 +0000 (09:58 +0000)
Misc/NEWS.d/next/Library/2022-01-04-18-05-25.bpo-46258.DYgwRo.rst [new file with mode: 0644]
Modules/mathmodule.c

diff --git a/Misc/NEWS.d/next/Library/2022-01-04-18-05-25.bpo-46258.DYgwRo.rst b/Misc/NEWS.d/next/Library/2022-01-04-18-05-25.bpo-46258.DYgwRo.rst
new file mode 100644 (file)
index 0000000..b918ed1
--- /dev/null
@@ -0,0 +1,2 @@
+Speed up :func:`math.isqrt` for small positive integers by replacing two
+division steps with a lookup table.
index 3ab1a0776046dd2b08d8d567d2527b10beb4ae29..0c7d4de0686213a6d3b27c0143350e9297411384 100644 (file)
@@ -1718,20 +1718,49 @@ completes the proof sketch.
 
 */
 
+/*
+    The _approximate_isqrt_tab table provides approximate square roots for
+    16-bit integers. For any n in the range 2**14 <= n < 2**16, the value
+
+        a = _approximate_isqrt_tab[(n >> 8) - 64]
+
+    is an approximate square root of n, satisfying (a - 1)**2 < n < (a + 1)**2.
+
+    The table was computed in Python using the expression:
+
+        [min(round(sqrt(256*n + 128)), 255) for n in range(64, 256)]
+*/
+
+static const uint8_t _approximate_isqrt_tab[192] = {
+    128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139,
+    140, 141, 142, 143, 144, 144, 145, 146, 147, 148, 149, 150,
+    151, 151, 152, 153, 154, 155, 156, 156, 157, 158, 159, 160,
+    160, 161, 162, 163, 164, 164, 165, 166, 167, 167, 168, 169,
+    170, 170, 171, 172, 173, 173, 174, 175, 176, 176, 177, 178,
+    179, 179, 180, 181, 181, 182, 183, 183, 184, 185, 186, 186,
+    187, 188, 188, 189, 190, 190, 191, 192, 192, 193, 194, 194,
+    195, 196, 196, 197, 198, 198, 199, 200, 200, 201, 201, 202,
+    203, 203, 204, 205, 205, 206, 206, 207, 208, 208, 209, 210,
+    210, 211, 211, 212, 213, 213, 214, 214, 215, 216, 216, 217,
+    217, 218, 219, 219, 220, 220, 221, 221, 222, 223, 223, 224,
+    224, 225, 225, 226, 227, 227, 228, 228, 229, 229, 230, 230,
+    231, 232, 232, 233, 233, 234, 234, 235, 235, 236, 237, 237,
+    238, 238, 239, 239, 240, 240, 241, 241, 242, 242, 243, 243,
+    244, 244, 245, 246, 246, 247, 247, 248, 248, 249, 249, 250,
+    250, 251, 251, 252, 252, 253, 253, 254, 254, 255, 255, 255,
+};
 
 /* Approximate square root of a large 64-bit integer.
 
    Given `n` satisfying `2**62 <= n < 2**64`, return `a`
    satisfying `(a - 1)**2 < n < (a + 1)**2`. */
 
-static uint64_t
+static inline uint32_t
 _approximate_isqrt(uint64_t n)
 {
-    uint32_t u = 1U + (n >> 62);
-    u = (u << 1) + (n >> 59) / u;
-    u = (u << 3) + (n >> 53) / u;
-    u = (u << 7) + (n >> 41) / u;
-    return (u << 15) + (n >> 17) / u;
+    uint32_t u = _approximate_isqrt_tab[(n >> 56) - 64];
+    u = (u << 7) + (uint32_t)(n >> 41) / u;
+    return (u << 15) + (uint32_t)((n >> 17) / u);
 }
 
 /*[clinic input]
@@ -1749,7 +1778,8 @@ math_isqrt(PyObject *module, PyObject *n)
 {
     int a_too_large, c_bit_length;
     size_t c, d;
-    uint64_t m, u;
+    uint64_t m;
+    uint32_t u;
     PyObject *a = NULL, *b;
 
     n = _PyNumber_Index(n);
@@ -1776,18 +1806,17 @@ math_isqrt(PyObject *module, PyObject *n)
     c = (c - 1U) / 2U;
 
     /* Fast path: if c <= 31 then n < 2**64 and we can compute directly with a
-       fast, almost branch-free algorithm. In the final correction, we use `u*u
-       - 1 >= m` instead of the simpler `u*u > m` in order to get the correct
-       result in the corner case where `u=2**32`. */
+       fast, almost branch-free algorithm. */
     if (c <= 31U) {
+        int shift = 31 - (int)c;
         m = (uint64_t)PyLong_AsUnsignedLongLong(n);
         Py_DECREF(n);
         if (m == (uint64_t)(-1) && PyErr_Occurred()) {
             return NULL;
         }
-        u = _approximate_isqrt(m << (62U - 2U*c)) >> (31U - c);
-        u -= u * u - 1U >= m;
-        return PyLong_FromUnsignedLongLong((unsigned long long)u);
+        u = _approximate_isqrt(m << 2*shift) >> shift;
+        u -= (uint64_t)u * u > m;
+        return PyLong_FromUnsignedLong(u);
     }
 
     /* Slow path: n >= 2**64. We perform the first five iterations in C integer
@@ -1811,7 +1840,7 @@ math_isqrt(PyObject *module, PyObject *n)
         goto error;
     }
     u = _approximate_isqrt(m) >> (31U - d);
-    a = PyLong_FromUnsignedLongLong((unsigned long long)u);
+    a = PyLong_FromUnsignedLong(u);
     if (a == NULL) {
         goto error;
     }