]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
SF patch 936813: fast modular exponentiation
authorTim Peters <tim.peters@gmail.com>
Sun, 29 Aug 2004 22:16:50 +0000 (22:16 +0000)
committerTim Peters <tim.peters@gmail.com>
Sun, 29 Aug 2004 22:16:50 +0000 (22:16 +0000)
This checkin is adapted from part 1 (of 3) of Trevor Perrin's patch set.

x_mul()
  - sped a little by optimizing the C
  - sped a lot (~2X) if it's doing a square; note that long_pow() squares
    often
k_mul()
  - more cache-friendly now if it's doing a square
KARATSUBA_CUTOFF
  - boosted; gradeschool mult is quicker now, and it may have been too low
    for many platforms anyway
KARATSUBA_SQUARE_CUTOFF
  - new
  - since x_mul is a lot faster at squaring now, the point at which
    Karatsuba pays for squaring is much higher than for general mult

Include/longintrepr.h
Misc/ACKS
Misc/NEWS
Objects/longobject.c

index 5755adb3066655da8273a073d2080b00ed7dfa10..9ed1fe737b7e7ad32ca6438e302a0b40a560c2c4 100644 (file)
@@ -12,7 +12,7 @@ extern "C" {
    contains at least 16 bits, but it's made changeable anyway.
    Note: 'digit' should be able to hold 2*MASK+1, and 'twodigits'
    should be able to hold the intermediate results in 'mul'
-   (at most MASK << SHIFT).
+   (at most (BASE-1)*(2*BASE+1) == MASK*(2*MASK+3)).
    Also, x_sub assumes that 'digit' is an unsigned type, and overflow
    is handled by taking the result mod 2**N for some N > SHIFT.
    And, at some places it is assumed that MASK fits in an int, as well. */
index 6eb0f648202fc5b2151bb926b7c7a6f7c8b0fe9b..dfdf005ea8f97986abb690bbc2d3b3dcabc14582 100644 (file)
--- a/Misc/ACKS
+++ b/Misc/ACKS
@@ -442,6 +442,7 @@ Steven Pemberton
 Eduardo Pérez
 Fernando Pérez
 Mark Perrego
+Trevor Perrin
 Tim Peters
 Chris Petrilli
 Bjorn Pettersen
index 4656fa2e03cc8b2689995f80be4d698edcf59b24..431b343aa06c7bf8664eed331ee16df709c4f93d 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -12,6 +12,16 @@ What's New in Python 2.4 alpha 3?
 Core and builtins
 -----------------
 
+- Some speedups for long arithmetic, thanks to Trevor Perrin.  Gradeschool
+  multiplication was sped a little by optimizing the C code.  Gradeschool
+  squaring was sped by about a factor of 2, by exploiting that about half
+  the digit products are duplicates in a square.  Because exponentiation
+  uses squaring often, this also speeds long power.  For example, the time
+  to compute 17**1000000 dropped from about 14 seconds to 9 on my box due
+  to this much.  The cutoff for Karatsuba multiplication was raised,
+  since gradeschool multiplication got quicker, and the cutoff was
+  aggressively small regardless.
+
 - OverflowWarning is no longer generated.  PEP 237 scheduled this to
   occur in Python 2.3, but since OverflowWarning was disabled by default,
   nobody realized it was still being generated.  On the chance that user
index f246bd2320f2183648e899f637aaab3150940ca0..2f6d103bfec5fc74797669f97e4af4c197cdec9f 100644 (file)
@@ -12,7 +12,8 @@
  * both operands contain more than KARATSUBA_CUTOFF digits (this
  * being an internal Python long digit, in base BASE).
  */
-#define KARATSUBA_CUTOFF 35
+#define KARATSUBA_CUTOFF 70
+#define KARATSUBA_SQUARE_CUTOFF (2 * KARATSUBA_CUTOFF)
 
 #define ABS(x) ((x) < 0 ? -(x) : (x))
 
@@ -1717,26 +1718,72 @@ x_mul(PyLongObject *a, PyLongObject *b)
                return NULL;
 
        memset(z->ob_digit, 0, z->ob_size * sizeof(digit));
-       for (i = 0; i < size_a; ++i) {
-               twodigits carry = 0;
-               twodigits f = a->ob_digit[i];
-               int j;
-               digit *pz = z->ob_digit + i;
+       if (a == b) {
+               /* Efficient squaring per HAC, Algorithm 14.16:
+                * http://www.cacr.math.uwaterloo.ca/hac/about/chap14.pdf
+                * Gives slightly less than a 2x speedup when a == b,
+                * via exploiting that each entry in the multiplication
+                * pyramid appears twice (except for the size_a squares).
+                */
+               for (i = 0; i < size_a; ++i) {
+                       twodigits carry;
+                       twodigits f = a->ob_digit[i];
+                       digit *pz = z->ob_digit + (i << 1);
+                       digit *pa = a->ob_digit + i + 1;
+                       digit *paend = a->ob_digit + size_a;
 
-               SIGCHECK({
-                       Py_DECREF(z);
-                       return NULL;
-               })
-               for (j = 0; j < size_b; ++j) {
-                       carry += *pz + b->ob_digit[j] * f;
-                       *pz++ = (digit) (carry & MASK);
+                       SIGCHECK({
+                               Py_DECREF(z);
+                               return NULL;
+                       })
+
+                       carry = *pz + f * f;
+                       *pz++ = (digit)(carry & MASK);
                        carry >>= SHIFT;
+                       assert(carry <= MASK);
+
+                       /* Now f is added in twice in each column of the
+                        * pyramid it appears.  Same as adding f<<1 once.
+                        */
+                       f <<= 1;
+                       while (pa < paend) {
+                               carry += *pz + *pa++ * f;
+                               *pz++ = (digit)(carry & MASK);
+                               carry >>= SHIFT;
+                               assert(carry <= (MASK << 1));
+                       }
+                       if (carry) {
+                               carry += *pz;
+                               *pz++ = (digit)(carry & MASK);
+                               carry >>= SHIFT;
+                       }
+                       if (carry)
+                               *pz += (digit)(carry & MASK);
+                       assert((carry >> SHIFT) == 0);
                }
-               for (; carry != 0; ++j) {
-                       assert(i+j < z->ob_size);
-                       carry += *pz;
-                       *pz++ = (digit) (carry & MASK);
-                       carry >>= SHIFT;
+       }
+       else {  /* a is not the same as b -- gradeschool long mult */
+               for (i = 0; i < size_a; ++i) {
+                       twodigits carry = 0;
+                       twodigits f = a->ob_digit[i];
+                       digit *pz = z->ob_digit + i;
+                       digit *pb = b->ob_digit;
+                       digit *pbend = b->ob_digit + size_b;
+
+                       SIGCHECK({
+                               Py_DECREF(z);
+                               return NULL;
+                       })
+
+                       while (pb < pbend) {
+                               carry += *pz + *pb++ * f;
+                               *pz++ = (digit)(carry & MASK);
+                               carry >>= SHIFT;
+                               assert(carry <= MASK);
+                       }
+                       if (carry)
+                               *pz += (digit)(carry & MASK);
+                       assert((carry >> SHIFT) == 0);
                }
        }
        return long_normalize(z);
@@ -1816,7 +1863,8 @@ k_mul(PyLongObject *a, PyLongObject *b)
        }
 
        /* Use gradeschool math when either number is too small. */
-       if (asize <= KARATSUBA_CUTOFF) {
+       i = a == b ? KARATSUBA_SQUARE_CUTOFF : KARATSUBA_CUTOFF;
+       if (asize <= i) {
                if (asize == 0)
                        return _PyLong_New(0);
                else
@@ -1837,7 +1885,13 @@ k_mul(PyLongObject *a, PyLongObject *b)
        if (kmul_split(a, shift, &ah, &al) < 0) goto fail;
        assert(ah->ob_size > 0);        /* the split isn't degenerate */
 
-       if (kmul_split(b, shift, &bh, &bl) < 0) goto fail;
+       if (a == b) {
+               bh = ah;
+               bl = al;
+               Py_INCREF(bh);
+               Py_INCREF(bl);
+       }
+       else if (kmul_split(b, shift, &bh, &bl) < 0) goto fail;
 
        /* The plan:
         * 1. Allocate result space (asize + bsize digits:  that's always
@@ -1906,7 +1960,11 @@ k_mul(PyLongObject *a, PyLongObject *b)
        Py_DECREF(al);
        ah = al = NULL;
 
-       if ((t2 = x_add(bh, bl)) == NULL) {
+       if (a == b) {
+               t2 = t1;
+               Py_INCREF(t2);
+       }
+       else if ((t2 = x_add(bh, bl)) == NULL) {
                Py_DECREF(t1);
                goto fail;
        }