SF patch 936813: fast modular exponentiation

This checkin is adapted from part 2 (of 3) of Trevor Perrin's patch set.

BACKWARD INCOMPATIBILITY:  SHIFT must now be divisible by 5.  AFAIK,
nobody will care.  long_pow() could be complicated to worm around that,
if necessary.

long_pow():
  - BUGFIX:  This leaked the base and power when the power was negative
    (and so the computation delegated to float pow).
  - Instead of doing right-to-left exponentiation, do left-to-right.  This
    is more efficient for small bases, which is the common case.
  - In addition, if the exponent is large (more than FIVEARY_CUTOFF
    digits), precompute [a**i % c for i in range(32)], and go left to
    right 5 bits at a time.
l_divmod():
  - The signature changed so that callers who don't want the quotient,
    or don't want the remainder, can pass NULL in the slot they don't
    want.  This saves them from having to declare a variable for unwanted
    stuff, and remembering to decref it.
long_mod(), long_div(), long_classic_div():
  - Adjust to new l_divmod() signature, and simplified as a result.
This commit is contained in:
Tim Peters 2004-08-30 02:44:38 +00:00
parent 48bd7f3a71
commit 47e52ee0c5
3 changed files with 211 additions and 116 deletions

View file

@ -15,7 +15,8 @@ extern "C" {
(at most (BASE-1)*(2*BASE+1) == MASK*(2*MASK+3)).
Also, x_sub assumes that 'digit' is an unsigned type, and overflow
is handled by taking the result mod 2**N for some N > SHIFT.
And, at some places it is assumed that MASK fits in an int, as well. */
And, at some places it is assumed that MASK fits in an int, as well.
long_pow() requires that SHIFT be divisible by 5. */
typedef unsigned short digit;
typedef unsigned int wdigit; /* digit widened to parameter size */
@ -27,6 +28,10 @@ typedef BASE_TWODIGITS_TYPE stwodigits; /* signed variant of twodigits */
#define BASE ((digit)1 << SHIFT)
#define MASK ((int)(BASE - 1))
#if SHIFT % 5 != 0
#error "longobject.c requires that SHIFT be divisible by 5"
#endif
/* Long integer representation.
The absolute value of a number is equal to
SUM(for i=0 through abs(ob_size)-1) ob_digit[i] * 2**(SHIFT*i)

View file

@ -20,7 +20,11 @@ Core and builtins
to compute 17**1000000 dropped from about 14 seconds to 9 on my box due
to this much. The cutoff for Karatsuba multiplication was raised,
since gradeschool multiplication got quicker, and the cutoff was
aggressively small regardless.
aggressively small regardless. The exponentiation algorithm was switched
from right-to-left to left-to-right, which is more efficient for small
bases. In addition, if the exponent is large, the algorithm now does
5 bits (instead of 1 bit) at a time. That cut the time to compute
17**1000000 on my box in half again, down to about 4.5 seconds.
- OverflowWarning is no longer generated. PEP 237 scheduled this to
occur in Python 2.3, but since OverflowWarning was disabled by default,
@ -156,6 +160,14 @@ Tools/Demos
Build
-----
- Backward incompatibility: longintrepr.h now triggers a compile-time
error if SHIFT (the number of bits in a Python long "digit") isn't
divisible by 5. This new requirement allows simple code for the new
5-bits-at-a-time long_pow() implementation. If necessary, the
restriction could be removed (by complicating long_pow(), or by
falling back to the 1-bit-at-a-time algorithm), but there are no
plans to do so.
- bug #991962: When building with --disable-toolbox-glue on Darwin no
attempt to build Mac-specific modules occurs.

View file

@ -15,6 +15,13 @@
#define KARATSUBA_CUTOFF 70
#define KARATSUBA_SQUARE_CUTOFF (2 * KARATSUBA_CUTOFF)
/* For exponentiation, use the binary left-to-right algorithm
* unless the exponent contains more than FIVEARY_CUTOFF digits.
* In that case, do 5 bits at a time. The potential drawback is that
* a table of 2**5 intermediate results is computed.
*/
#define FIVEARY_CUTOFF 8
#define ABS(x) ((x) < 0 ? -(x) : (x))
#undef MIN
@ -2136,6 +2143,12 @@ long_mul(PyLongObject *v, PyLongObject *w)
have different signs. We then subtract one from the 'div'
part of the outcome to keep the invariant intact. */
/* Compute
* *pdiv, *pmod = divmod(v, w)
* NULL can be passed for pdiv or pmod, in which case that part of
* the result is simply thrown away. The caller owns a reference to
* each of these it requests (does not pass NULL for).
*/
static int
l_divmod(PyLongObject *v, PyLongObject *w,
PyLongObject **pdiv, PyLongObject **pmod)
@ -2167,44 +2180,43 @@ l_divmod(PyLongObject *v, PyLongObject *w,
Py_DECREF(div);
div = temp;
}
*pdiv = div;
*pmod = mod;
if (pdiv != NULL)
*pdiv = div;
else
Py_DECREF(div);
if (pmod != NULL)
*pmod = mod;
else
Py_DECREF(mod);
return 0;
}
static PyObject *
long_div(PyObject *v, PyObject *w)
{
PyLongObject *a, *b, *div, *mod;
PyLongObject *a, *b, *div;
CONVERT_BINOP(v, w, &a, &b);
if (l_divmod(a, b, &div, &mod) < 0) {
Py_DECREF(a);
Py_DECREF(b);
return NULL;
}
if (l_divmod(a, b, &div, NULL) < 0)
div = NULL;
Py_DECREF(a);
Py_DECREF(b);
Py_DECREF(mod);
return (PyObject *)div;
}
static PyObject *
long_classic_div(PyObject *v, PyObject *w)
{
PyLongObject *a, *b, *div, *mod;
PyLongObject *a, *b, *div;
CONVERT_BINOP(v, w, &a, &b);
if (Py_DivisionWarningFlag &&
PyErr_Warn(PyExc_DeprecationWarning, "classic long division") < 0)
div = NULL;
else if (l_divmod(a, b, &div, &mod) < 0)
else if (l_divmod(a, b, &div, NULL) < 0)
div = NULL;
else
Py_DECREF(mod);
Py_DECREF(a);
Py_DECREF(b);
return (PyObject *)div;
@ -2255,18 +2267,14 @@ long_true_divide(PyObject *v, PyObject *w)
static PyObject *
long_mod(PyObject *v, PyObject *w)
{
PyLongObject *a, *b, *div, *mod;
PyLongObject *a, *b, *mod;
CONVERT_BINOP(v, w, &a, &b);
if (l_divmod(a, b, &div, &mod) < 0) {
Py_DECREF(a);
Py_DECREF(b);
return NULL;
}
if (l_divmod(a, b, NULL, &mod) < 0)
mod = NULL;
Py_DECREF(a);
Py_DECREF(b);
Py_DECREF(div);
return (PyObject *)mod;
}
@ -2297,22 +2305,33 @@ long_divmod(PyObject *v, PyObject *w)
return z;
}
/* pow(v, w, x) */
static PyObject *
long_pow(PyObject *v, PyObject *w, PyObject *x)
{
PyLongObject *a, *b;
PyObject *c;
PyLongObject *z, *div, *mod;
int size_b, i;
PyLongObject *a, *b, *c; /* a,b,c = v,w,x */
int negativeOutput = 0; /* if x<0 return negative output */
PyLongObject *z = NULL; /* accumulated result */
int i, j, k; /* counters */
PyLongObject *temp = NULL;
/* 5-ary values. If the exponent is large enough, table is
* precomputed so that table[i] == a**i % c for i in range(32).
*/
PyLongObject *table[32] = {0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0};
/* a, b, c = v, w, x */
CONVERT_BINOP(v, w, &a, &b);
if (PyLong_Check(x) || Py_None == x) {
c = x;
if (PyLong_Check(x)) {
c = (PyLongObject *)x;
Py_INCREF(x);
}
else if (PyInt_Check(x)) {
c = PyLong_FromLong(PyInt_AS_LONG(x));
}
else if (PyInt_Check(x))
c = (PyLongObject *)PyLong_FromLong(PyInt_AS_LONG(x));
else if (x == Py_None)
c = NULL;
else {
Py_DECREF(a);
Py_DECREF(b);
@ -2320,95 +2339,154 @@ long_pow(PyObject *v, PyObject *w, PyObject *x)
return Py_NotImplemented;
}
if (c != Py_None && ((PyLongObject *)c)->ob_size == 0) {
PyErr_SetString(PyExc_ValueError,
"pow() 3rd argument cannot be 0");
z = NULL;
goto error;
}
size_b = b->ob_size;
if (size_b < 0) {
Py_DECREF(a);
Py_DECREF(b);
Py_DECREF(c);
if (x != Py_None) {
if (b->ob_size < 0) { /* if exponent is negative */
if (c) {
PyErr_SetString(PyExc_TypeError, "pow() 2nd argument "
"cannot be negative when 3rd argument specified");
"cannot be negative when 3rd argument specified");
return NULL;
}
/* Return a float. This works because we know that
this calls float_pow() which converts its
arguments to double. */
return PyFloat_Type.tp_as_number->nb_power(v, w, x);
}
z = (PyLongObject *)PyLong_FromLong(1L);
for (i = 0; i < size_b; ++i) {
digit bi = b->ob_digit[i];
int j;
for (j = 0; j < SHIFT; ++j) {
PyLongObject *temp;
if (bi & 1) {
temp = (PyLongObject *)long_mul(z, a);
Py_DECREF(z);
if (c!=Py_None && temp!=NULL) {
if (l_divmod(temp,(PyLongObject *)c,
&div,&mod) < 0) {
Py_DECREF(temp);
z = NULL;
goto error;
}
Py_XDECREF(div);
Py_DECREF(temp);
temp = mod;
}
z = temp;
if (z == NULL)
break;
}
bi >>= 1;
if (bi == 0 && i+1 == size_b)
break;
temp = (PyLongObject *)long_mul(a, a);
Py_DECREF(a);
if (c!=Py_None && temp!=NULL) {
if (l_divmod(temp, (PyLongObject *)c, &div,
&mod) < 0) {
Py_DECREF(temp);
z = NULL;
goto error;
}
Py_XDECREF(div);
Py_DECREF(temp);
temp = mod;
}
a = temp;
if (a == NULL) {
Py_DECREF(z);
z = NULL;
break;
}
}
if (a == NULL || z == NULL)
break;
}
if (c!=Py_None && z!=NULL) {
if (l_divmod(z, (PyLongObject *)c, &div, &mod) < 0) {
Py_DECREF(z);
z = NULL;
}
else {
Py_XDECREF(div);
Py_DECREF(z);
z = mod;
/* else return a float. This works because we know
that this calls float_pow() which converts its
arguments to double. */
Py_DECREF(a);
Py_DECREF(b);
return PyFloat_Type.tp_as_number->nb_power(v, w, x);
}
}
error:
if (c) {
/* if modulus == 0:
raise ValueError() */
if (c->ob_size == 0) {
PyErr_SetString(PyExc_ValueError,
"pow() 3rd argument cannot be 0");
goto Done;
}
/* if modulus < 0:
negativeOutput = True
modulus = -modulus */
if (c->ob_size < 0) {
negativeOutput = 1;
temp = (PyLongObject *)_PyLong_Copy(c);
if (temp == NULL)
goto Error;
Py_DECREF(c);
c = temp;
temp = NULL;
c->ob_size = - c->ob_size;
}
/* if modulus == 1:
return 0 */
if ((c->ob_size == 1) && (c->ob_digit[0] == 1)) {
z = (PyLongObject *)PyLong_FromLong(0L);
goto Done;
}
/* if base < 0:
base = base % modulus
Having the base positive just makes things easier. */
if (a->ob_size < 0) {
if (l_divmod(a, c, NULL, &temp) < 0)
goto Done;
Py_DECREF(a);
a = temp;
temp = NULL;
}
}
/* At this point a, b, and c are guaranteed non-negative UNLESS
c is NULL, in which case a may be negative. */
z = (PyLongObject *)PyLong_FromLong(1L);
if (z == NULL)
goto Error;
/* Perform a modular reduction, X = X % c, but leave X alone if c
* is NULL.
*/
#define REDUCE(X) \
if (c != NULL) { \
if (l_divmod(X, c, NULL, &temp) < 0) \
goto Error; \
Py_XDECREF(X); \
X = temp; \
temp = NULL; \
}
/* Multiply two values, then reduce the result:
result = X*Y % c. If c is NULL, skip the mod. */
#define MULT(X, Y, result) \
{ \
temp = (PyLongObject *)long_mul(X, Y); \
if (temp == NULL) \
goto Error; \
Py_XDECREF(result); \
result = temp; \
temp = NULL; \
REDUCE(result) \
}
if (b->ob_size <= FIVEARY_CUTOFF) {
/* Left-to-right binary exponentiation (HAC Algorithm 14.79) */
/* http://www.cacr.math.uwaterloo.ca/hac/about/chap14.pdf */
for (i = b->ob_size - 1; i >= 0; --i) {
digit bi = b->ob_digit[i];
for (j = 1 << (SHIFT-1); j != 0; j >>= 1) {
MULT(z, z, z)
if (bi & j)
MULT(z, a, z)
}
}
}
else {
/* Left-to-right 5-ary exponentiation (HAC Algorithm 14.82) */
Py_INCREF(z); /* still holds 1L */
table[0] = z;
for (i = 1; i < 32; ++i)
MULT(table[i-1], a, table[i])
for (i = b->ob_size - 1; i >= 0; --i) {
const digit bi = b->ob_digit[i];
for (j = SHIFT - 5; j >= 0; j -= 5) {
const int index = (bi >> j) & 0x1f;
for (k = 0; k < 5; ++k)
MULT(z, z, z)
if (index)
MULT(z, table[index], z)
}
}
}
if (negativeOutput && (z->ob_size != 0)) {
temp = (PyLongObject *)long_sub(z, c);
if (temp == NULL)
goto Error;
Py_DECREF(z);
z = temp;
temp = NULL;
}
goto Done;
Error:
if (z != NULL) {
Py_DECREF(z);
z = NULL;
}
/* fall through */
Done:
Py_XDECREF(a);
Py_DECREF(b);
Py_DECREF(c);
Py_XDECREF(b);
Py_XDECREF(c);
Py_XDECREF(temp);
if (b->ob_size > FIVEARY_CUTOFF) {
for (i = 0; i < 32; ++i)
Py_XDECREF(table[i]);
}
return (PyObject *)z;
}