]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-118610: Centralize power caching in `_pylong.py` (#118611)
authorTim Peters <tim.peters@gmail.com>
Wed, 8 May 2024 00:09:09 +0000 (19:09 -0500)
committerGitHub <noreply@github.com>
Wed, 8 May 2024 00:09:09 +0000 (19:09 -0500)
A new `compute_powers()` function computes all and only the powers of the base the various base-conversion functions need, as efficiently as reasonably possible (turns out that invoking `**`is needed at most once). This typically gives a few % speedup, but the primary point is to simplify the base-conversion functions, which no longer need their own, ad hoc, and less efficient power-caching schemes.

Co-authored-by: Serhiy Storchaka <storchaka@gmail.com>
Lib/_pylong.py
Lib/test/test_int.py

index 30bee6fc9ef54ca6f76c15d09bffd59f59a91e09..4970eb3fa67b2b91e8ecc42fbb8c454a777f39c0 100644 (file)
@@ -19,6 +19,86 @@ try:
 except ImportError:
     _decimal = None
 
+# A number of functions have this form, where `w` is a desired number of
+# digits in base `base`:
+#
+#    def inner(...w...):
+#        if w <= LIMIT:
+#            return something
+#        lo = w >> 1
+#        hi = w - lo
+#        something involving base**lo, inner(...lo...), j, and inner(...hi...)
+#    figure out largest w needed
+#    result = inner(w)
+#
+# They all had some on-the-fly scheme to cache `base**lo` results for reuse.
+# Power is costly.
+#
+# This routine aims to compute all amd only the needed powers in advance, as
+# efficiently as reasonably possible. This isn't trivial, and all the
+# on-the-fly methods did needless work in many cases. The driving code above
+# changes to:
+#
+#    figure out largest w needed
+#    mycache = compute_powers(w, base, LIMIT)
+#    result = inner(w)
+#
+# and `mycache[lo]` replaces `base**lo` in the inner function.
+#
+# While this does give minor speedups (a few percent at best), the primary
+# intent is to simplify the functions using this, by eliminating the need for
+# them to craft their own ad-hoc caching schemes.
+def compute_powers(w, base, more_than, show=False):
+    seen = set()
+    need = set()
+    ws = {w}
+    while ws:
+        w = ws.pop() # any element is fine to use next
+        if w in seen or w <= more_than:
+            continue
+        seen.add(w)
+        lo = w >> 1
+        # only _need_ lo here; some other path may, or may not, need hi
+        need.add(lo)
+        ws.add(lo)
+        if w & 1:
+            ws.add(lo + 1)
+
+    d = {}
+    if not need:
+        return d
+    it = iter(sorted(need))
+    first = next(it)
+    if show:
+        print("pow at", first)
+    d[first] = base ** first
+    for this in it:
+        if this - 1 in d:
+            if show:
+                print("* base at", this)
+            d[this] = d[this - 1] * base # cheap
+        else:
+            lo = this >> 1
+            hi = this - lo
+            assert lo in d
+            if show:
+                print("square at", this)
+            # Multiplying a bigint by itself (same object!) is about twice
+            # as fast in CPython.
+            sq = d[lo] * d[lo]
+            if hi != lo:
+                assert hi == lo + 1
+                if show:
+                    print("    and * base")
+                sq *= base
+            d[this] = sq
+    return d
+
+_unbounded_dec_context = decimal.getcontext().copy()
+_unbounded_dec_context.prec = decimal.MAX_PREC
+_unbounded_dec_context.Emax = decimal.MAX_EMAX
+_unbounded_dec_context.Emin = decimal.MIN_EMIN
+_unbounded_dec_context.traps[decimal.Inexact] = 1 # sanity check
 
 def int_to_decimal(n):
     """Asymptotically fast conversion of an 'int' to Decimal."""
@@ -33,57 +113,32 @@ def int_to_decimal(n):
     # "clever" recursive way.  If we want a string representation, we
     # apply str to _that_.
 
-    D = decimal.Decimal
-    D2 = D(2)
-
-    BITLIM = 128
-
-    mem = {}
-
-    def w2pow(w):
-        """Return D(2)**w and store the result. Also possibly save some
-        intermediate results. In context, these are likely to be reused
-        across various levels of the conversion to Decimal."""
-        if (result := mem.get(w)) is None:
-            if w <= BITLIM:
-                result = D2**w
-            elif w - 1 in mem:
-                result = (t := mem[w - 1]) + t
-            else:
-                w2 = w >> 1
-                # If w happens to be odd, w-w2 is one larger then w2
-                # now. Recurse on the smaller first (w2), so that it's
-                # in the cache and the larger (w-w2) can be handled by
-                # the cheaper `w-1 in mem` branch instead.
-                result = w2pow(w2) * w2pow(w - w2)
-            mem[w] = result
-        return result
+    from decimal import Decimal as D
+    BITLIM = 200
 
+    # Don't bother caching the "lo" mask in this; the time to compute it is
+    # tiny compared to the multiply.
     def inner(n, w):
         if w <= BITLIM:
             return D(n)
         w2 = w >> 1
         hi = n >> w2
-        lo = n - (hi << w2)
-        return inner(lo, w2) + inner(hi, w - w2) * w2pow(w2)
-
-    with decimal.localcontext() as ctx:
-        ctx.prec = decimal.MAX_PREC
-        ctx.Emax = decimal.MAX_EMAX
-        ctx.Emin = decimal.MIN_EMIN
-        ctx.traps[decimal.Inexact] = 1
+        lo = n & ((1 << w2) - 1)
+        return inner(lo, w2) + inner(hi, w - w2) * w2pow[w2]
 
+    with decimal.localcontext(_unbounded_dec_context):
+        nbits = n.bit_length()
+        w2pow = compute_powers(nbits, D(2), BITLIM)
         if n < 0:
             negate = True
             n = -n
         else:
             negate = False
-        result = inner(n, n.bit_length())
+        result = inner(n, nbits)
         if negate:
             result = -result
     return result
 
-
 def int_to_decimal_string(n):
     """Asymptotically fast conversion of an 'int' to a decimal string."""
     w = n.bit_length()
@@ -97,14 +152,13 @@ def int_to_decimal_string(n):
     # available.  This algorithm is asymptotically worse than the algorithm
     # using the decimal module, but better than the quadratic time
     # implementation in longobject.c.
+
+    DIGLIM = 1000
     def inner(n, w):
-        if w <= 1000:
+        if w <= DIGLIM:
             return str(n)
         w2 = w >> 1
-        d = pow10_cache.get(w2)
-        if d is None:
-            d = pow10_cache[w2] = 5**w2 << w2 # 10**i = (5*2)**i = 5**i * 2**i
-        hi, lo = divmod(n, d)
+        hi, lo = divmod(n, pow10[w2])
         return inner(hi, w - w2) + inner(lo, w2).zfill(w2)
 
     # The estimation of the number of decimal digits.
@@ -115,7 +169,9 @@ def int_to_decimal_string(n):
     # only if the number has way more than 10**15 digits, that exceeds
     # the 52-bit physical address limit in both Intel64 and AMD64.
     w = int(w * 0.3010299956639812 + 1)  # log10(2)
-    pow10_cache = {}
+    pow10 = compute_powers(w, 5, DIGLIM)
+    for k, v in pow10.items():
+        pow10[k] = v << k # 5**k << k == 5**k * 2**k == 10**k
     if n < 0:
         n = -n
         sign = '-'
@@ -128,7 +184,6 @@ def int_to_decimal_string(n):
         s = s.lstrip('0')
     return sign + s
 
-
 def _str_to_int_inner(s):
     """Asymptotically fast conversion of a 'str' to an 'int'."""
 
@@ -144,35 +199,15 @@ def _str_to_int_inner(s):
 
     DIGLIM = 2048
 
-    mem = {}
-
-    def w5pow(w):
-        """Return 5**w and store the result.
-        Also possibly save some intermediate results. In context, these
-        are likely to be reused across various levels of the conversion
-        to 'int'.
-        """
-        if (result := mem.get(w)) is None:
-            if w <= DIGLIM:
-                result = 5**w
-            elif w - 1 in mem:
-                result = mem[w - 1] * 5
-            else:
-                w2 = w >> 1
-                # If w happens to be odd, w-w2 is one larger then w2
-                # now. Recurse on the smaller first (w2), so that it's
-                # in the cache and the larger (w-w2) can be handled by
-                # the cheaper `w-1 in mem` branch instead.
-                result = w5pow(w2) * w5pow(w - w2)
-            mem[w] = result
-        return result
-
     def inner(a, b):
         if b - a <= DIGLIM:
             return int(s[a:b])
         mid = (a + b + 1) >> 1
-        return inner(mid, b) + ((inner(a, mid) * w5pow(b - mid)) << (b - mid))
+        return (inner(mid, b)
+                + ((inner(a, mid) * w5pow[b - mid])
+                    << (b - mid)))
 
+    w5pow = compute_powers(len(s), 5, DIGLIM)
     return inner(0, len(s))
 
 
@@ -186,7 +221,6 @@ def int_from_string(s):
     s = s.rstrip().replace('_', '')
     return _str_to_int_inner(s)
 
-
 def str_to_int(s):
     """Asymptotically fast version of decimal string to 'int' conversion."""
     # FIXME: this doesn't support the full syntax that int() supports.
index c8626398b35b896adfa8740b307e27e75fbbfcd1..8959ffb6dcc236f3e6419afcece6964954a026f2 100644 (file)
@@ -906,6 +906,18 @@ class PyLongModuleTests(unittest.TestCase):
             with self.assertRaises(RuntimeError):
                 int(big_value)
 
+    def test_pylong_roundtrip(self):
+        from random import randrange, getrandbits
+        bits = 5000
+        while bits <= 1_000_000:
+            bits += randrange(-100, 101) # break bitlength patterns
+            hibit = 1 << (bits - 1)
+            n = hibit | getrandbits(bits - 1)
+            assert n.bit_length() == bits
+            sn = str(n)
+            self.assertFalse(sn.startswith('0'))
+            self.assertEqual(n, int(sn))
+            bits <<= 1
 
 if __name__ == "__main__":
     unittest.main()