]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
bpo-46020: Optimize long_pow for the common case (GH-30555)
authorTim Peters <tim.peters@gmail.com>
Wed, 12 Jan 2022 18:55:02 +0000 (12:55 -0600)
committerGitHub <noreply@github.com>
Wed, 12 Jan 2022 18:55:02 +0000 (12:55 -0600)
This cuts a bit of overhead by not initializing the table of small
odd powers unless it's needed for a large exponent.

Objects/longobject.c

index 2db8701a841a9447256a2fe36a216a72e809ce1f..5d181aa0850aa8bf1c3f9e64c2ed144e33ac522e 100644 (file)
@@ -4215,8 +4215,13 @@ long_pow(PyObject *v, PyObject *w, PyObject *x)
     /* k-ary values.  If the exponent is large enough, table is
      * precomputed so that table[i] == a**(2*i+1) % c for i in
      * range(EXP_TABLE_LEN).
+     * Note: this is uninitialzed stack trash: don't pay to set it to known
+     * values unless it's needed. Instead ensure that num_table_entries is
+     * set to the number of entries actually filled whenever a branch to the
+     * Error or Done labels is possible.
      */
-    PyLongObject *table[EXP_TABLE_LEN] = {0};
+    PyLongObject *table[EXP_TABLE_LEN];
+    Py_ssize_t num_table_entries = 0;
 
     /* a, b, c = v, w, x */
     CHECK_BINOP(v, w);
@@ -4408,10 +4413,14 @@ long_pow(PyObject *v, PyObject *w, PyObject *x)
          */
         Py_INCREF(a);
         table[0] = a;
+        num_table_entries = 1;
         MULT(a, a, a2);
         /* table[i] == a**(2*i + 1) % c */
-        for (i = 1; i < EXP_TABLE_LEN; ++i)
+        for (i = 1; i < EXP_TABLE_LEN; ++i) {
+            table[i] = NULL; /* must set to known value for MULT */
             MULT(table[i-1], a2, table[i]);
+            ++num_table_entries; /* incremented iff MULT succeeded */
+        }
         Py_CLEAR(a2);
 
         /* Repeatedly extract the next (no more than) EXP_WINDOW_SIZE bits
@@ -4472,10 +4481,8 @@ long_pow(PyObject *v, PyObject *w, PyObject *x)
     Py_CLEAR(z);
     /* fall through */
   Done:
-    if (Py_SIZE(b) > HUGE_EXP_CUTOFF / PyLong_SHIFT) {
-        for (i = 0; i < EXP_TABLE_LEN; ++i)
-            Py_XDECREF(table[i]);
-    }
+    for (i = 0; i < num_table_entries; ++i)
+        Py_DECREF(table[i]);
     Py_DECREF(a);
     Py_DECREF(b);
     Py_XDECREF(c);