]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-108346: Fix failed benchmark in decimal (#108353)
authorCharlie Zhao <zhaoyu_hit@qq.com>
Wed, 13 Sep 2023 04:17:55 +0000 (12:17 +0800)
committerGitHub <noreply@github.com>
Wed, 13 Sep 2023 04:17:55 +0000 (21:17 -0700)
Fix benchmark in decimal to work again after the int str conversion limits.

Modules/_decimal/tests/bench.py

index 24e091b6887ccdf4b71250b7b72bb2043c1ca14a..640290f2ec7962cd761c5f5f9b8cb4494d721778 100644 (file)
@@ -7,6 +7,8 @@
 
 
 import time
+import sys
+from functools import wraps
 from test.support.import_helper import import_fresh_module
 
 C = import_fresh_module('decimal', fresh=['_decimal'])
@@ -64,66 +66,85 @@ def factorial(n, m):
     else:
         return factorial(n, (n+m)//2) * factorial((n+m)//2 + 1, m)
 
+# Fix failed test cases caused by CVE-2020-10735 patch.
+# See gh-95778 for details.
+def increase_int_max_str_digits(maxdigits):
+    def _increase_int_max_str_digits(func, maxdigits=maxdigits):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            previous_int_limit = sys.get_int_max_str_digits()
+            sys.set_int_max_str_digits(maxdigits)
+            ans = func(*args, **kwargs)
+            sys.set_int_max_str_digits(previous_int_limit)
+            return ans
+        return wrapper
+    return _increase_int_max_str_digits
+
+def test_calc_pi():
+    print("\n# ======================================================================")
+    print("#                   Calculating pi, 10000 iterations")
+    print("# ======================================================================\n")
+
+    to_benchmark = [pi_float, pi_decimal]
+    if C is not None:
+        to_benchmark.insert(1, pi_cdecimal)
+
+    for prec in [9, 19]:
+        print("\nPrecision: %d decimal digits\n" % prec)
+        for func in to_benchmark:
+            start = time.time()
+            if C is not None:
+                C.getcontext().prec = prec
+            P.getcontext().prec = prec
+            for i in range(10000):
+                x = func()
+            print("%s:" % func.__name__.replace("pi_", ""))
+            print("result: %s" % str(x))
+            print("time: %fs\n" % (time.time()-start))
+
+@increase_int_max_str_digits(maxdigits=10000000)
+def test_factorial():
+    print("\n# ======================================================================")
+    print("#                               Factorial")
+    print("# ======================================================================\n")
 
-print("\n# ======================================================================")
-print("#                   Calculating pi, 10000 iterations")
-print("# ======================================================================\n")
-
-to_benchmark = [pi_float, pi_decimal]
-if C is not None:
-    to_benchmark.insert(1, pi_cdecimal)
-
-for prec in [9, 19]:
-    print("\nPrecision: %d decimal digits\n" % prec)
-    for func in to_benchmark:
-        start = time.time()
-        if C is not None:
-            C.getcontext().prec = prec
-        P.getcontext().prec = prec
-        for i in range(10000):
-            x = func()
-        print("%s:" % func.__name__.replace("pi_", ""))
-        print("result: %s" % str(x))
-        print("time: %fs\n" % (time.time()-start))
-
-
-print("\n# ======================================================================")
-print("#                               Factorial")
-print("# ======================================================================\n")
-
-if C is not None:
-    c = C.getcontext()
-    c.prec = C.MAX_PREC
-    c.Emax = C.MAX_EMAX
-    c.Emin = C.MIN_EMIN
+    if C is not None:
+        c = C.getcontext()
+        c.prec = C.MAX_PREC
+        c.Emax = C.MAX_EMAX
+        c.Emin = C.MIN_EMIN
 
-for n in [100000, 1000000]:
+    for n in [100000, 1000000]:
 
-    print("n = %d\n" % n)
+        print("n = %d\n" % n)
 
-    if C is not None:
-        # C version of decimal
+        if C is not None:
+            # C version of decimal
+            start_calc = time.time()
+            x = factorial(C.Decimal(n), 0)
+            end_calc = time.time()
+            start_conv = time.time()
+            sx = str(x)
+            end_conv = time.time()
+            print("cdecimal:")
+            print("calculation time: %fs" % (end_calc-start_calc))
+            print("conversion time: %fs\n" % (end_conv-start_conv))
+
+        # Python integers
         start_calc = time.time()
-        x = factorial(C.Decimal(n), 0)
+        y = factorial(n, 0)
         end_calc = time.time()
         start_conv = time.time()
-        sx = str(x)
-        end_conv = time.time()
-        print("cdecimal:")
-        print("calculation time: %fs" % (end_calc-start_calc))
-        print("conversion time: %fs\n" % (end_conv-start_conv))
+        sy = str(y)
+        end_conv =  time.time()
 
-    # Python integers
-    start_calc = time.time()
-    y = factorial(n, 0)
-    end_calc = time.time()
-    start_conv = time.time()
-    sy = str(y)
-    end_conv =  time.time()
+        print("int:")
+        print("calculation time: %fs" % (end_calc-start_calc))
+        print("conversion time: %fs\n\n" % (end_conv-start_conv))
 
-    print("int:")
-    print("calculation time: %fs" % (end_calc-start_calc))
-    print("conversion time: %fs\n\n" % (end_conv-start_conv))
+        if C is not None:
+            assert(sx == sy)
 
-    if C is not None:
-        assert(sx == sy)
+if __name__ == "__main__":
+    test_calc_pi()
+    test_factorial()