bpo-34659: Adds initial kwarg to itertools.accumulate() (GH-9345)

This commit is contained in:
Lisa Roach 2018-09-23 17:34:59 -07:00 committed by GitHub
parent c87d9f406b
commit 9718b59ee5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 65 additions and 19 deletions

View file

@ -86,29 +86,38 @@ The following module functions all construct and return iterators. Some provide
streams of infinite length, so they should only be accessed by functions or
loops that truncate the stream.
.. function:: accumulate(iterable[, func])
.. function:: accumulate(iterable[, func, *, initial=None])
Make an iterator that returns accumulated sums, or accumulated
results of other binary functions (specified via the optional
*func* argument). If *func* is supplied, it should be a function
*func* argument).
If *func* is supplied, it should be a function
of two arguments. Elements of the input *iterable* may be any type
that can be accepted as arguments to *func*. (For example, with
the default operation of addition, elements may be any addable
type including :class:`~decimal.Decimal` or
:class:`~fractions.Fraction`.) If the input iterable is empty, the
output iterable will also be empty.
:class:`~fractions.Fraction`.)
Usually, the number of elements output matches the input iterable.
However, if the keyword argument *initial* is provided, the
accumulation leads off with the *initial* value so that the output
has one more element than the input iterable.
Roughly equivalent to::
def accumulate(iterable, func=operator.add):
def accumulate(iterable, func=operator.add, *, initial=None):
'Return running totals'
# accumulate([1,2,3,4,5]) --> 1 3 6 10 15
# accumulate([1,2,3,4,5], initial=100) --> 100 101 103 106 110 115
# accumulate([1,2,3,4,5], operator.mul) --> 1 2 6 24 120
it = iter(iterable)
try:
total = next(it)
except StopIteration:
return
total = initial
if initial is None:
try:
total = next(it)
except StopIteration:
return
yield total
for element in it:
total = func(total, element)
@ -152,6 +161,9 @@ loops that truncate the stream.
.. versionchanged:: 3.3
Added the optional *func* parameter.
.. versionchanged:: 3.8
Added the optional *initial* parameter.
.. function:: chain(*iterables)
Make an iterator that returns elements from the first iterable until it is

View file

@ -147,6 +147,12 @@ def test_accumulate(self):
list(accumulate(s, chr)) # unary-operation
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
self.pickletest(proto, accumulate(range(10))) # test pickling
self.pickletest(proto, accumulate(range(10), initial=7))
self.assertEqual(list(accumulate([10, 5, 1], initial=None)), [10, 15, 16])
self.assertEqual(list(accumulate([10, 5, 1], initial=100)), [100, 110, 115, 116])
self.assertEqual(list(accumulate([], initial=100)), [100])
with self.assertRaises(TypeError):
list(accumulate([10, 20], 100))
def test_chain(self):

View file

@ -0,0 +1 @@
Add an optional *initial* argument to itertools.accumulate().

View file

@ -382,29 +382,30 @@ exit:
}
PyDoc_STRVAR(itertools_accumulate__doc__,
"accumulate(iterable, func=None)\n"
"accumulate(iterable, func=None, *, initial=None)\n"
"--\n"
"\n"
"Return series of accumulated sums (or other binary function results).");
static PyObject *
itertools_accumulate_impl(PyTypeObject *type, PyObject *iterable,
PyObject *binop);
PyObject *binop, PyObject *initial);
static PyObject *
itertools_accumulate(PyTypeObject *type, PyObject *args, PyObject *kwargs)
{
PyObject *return_value = NULL;
static const char * const _keywords[] = {"iterable", "func", NULL};
static _PyArg_Parser _parser = {"O|O:accumulate", _keywords, 0};
static const char * const _keywords[] = {"iterable", "func", "initial", NULL};
static _PyArg_Parser _parser = {"O|O$O:accumulate", _keywords, 0};
PyObject *iterable;
PyObject *binop = Py_None;
PyObject *initial = Py_None;
if (!_PyArg_ParseTupleAndKeywordsFast(args, kwargs, &_parser,
&iterable, &binop)) {
&iterable, &binop, &initial)) {
goto exit;
}
return_value = itertools_accumulate_impl(type, iterable, binop);
return_value = itertools_accumulate_impl(type, iterable, binop, initial);
exit:
return return_value;
@ -509,4 +510,4 @@ itertools_count(PyTypeObject *type, PyObject *args, PyObject *kwargs)
exit:
return return_value;
}
/*[clinic end generated code: output=d9eb9601bd3296ef input=a9049054013a1b77]*/
/*[clinic end generated code: output=c8c47b766deeffc3 input=a9049054013a1b77]*/

View file

@ -3475,6 +3475,7 @@ typedef struct {
PyObject *total;
PyObject *it;
PyObject *binop;
PyObject *initial;
} accumulateobject;
static PyTypeObject accumulate_type;
@ -3484,18 +3485,19 @@ static PyTypeObject accumulate_type;
itertools.accumulate.__new__
iterable: object
func as binop: object = None
*
initial: object = None
Return series of accumulated sums (or other binary function results).
[clinic start generated code]*/
static PyObject *
itertools_accumulate_impl(PyTypeObject *type, PyObject *iterable,
PyObject *binop)
/*[clinic end generated code: output=514d0fb30ba14d55 input=6d9d16aaa1d3cbfc]*/
PyObject *binop, PyObject *initial)
/*[clinic end generated code: output=66da2650627128f8 input=c4ce20ac59bf7ffd]*/
{
PyObject *it;
accumulateobject *lz;
/* Get iterator. */
it = PyObject_GetIter(iterable);
if (it == NULL)
@ -3514,6 +3516,8 @@ itertools_accumulate_impl(PyTypeObject *type, PyObject *iterable,
}
lz->total = NULL;
lz->it = it;
Py_XINCREF(initial);
lz->initial = initial;
return (PyObject *)lz;
}
@ -3524,6 +3528,7 @@ accumulate_dealloc(accumulateobject *lz)
Py_XDECREF(lz->binop);
Py_XDECREF(lz->total);
Py_XDECREF(lz->it);
Py_XDECREF(lz->initial);
Py_TYPE(lz)->tp_free(lz);
}
@ -3533,6 +3538,7 @@ accumulate_traverse(accumulateobject *lz, visitproc visit, void *arg)
Py_VISIT(lz->binop);
Py_VISIT(lz->it);
Py_VISIT(lz->total);
Py_VISIT(lz->initial);
return 0;
}
@ -3541,6 +3547,13 @@ accumulate_next(accumulateobject *lz)
{
PyObject *val, *newtotal;
if (lz->initial != Py_None) {
lz->total = lz->initial;
Py_INCREF(Py_None);
lz->initial = Py_None;
Py_INCREF(lz->total);
return lz->total;
}
val = (*Py_TYPE(lz->it)->tp_iternext)(lz->it);
if (val == NULL)
return NULL;
@ -3567,6 +3580,19 @@ accumulate_next(accumulateobject *lz)
static PyObject *
accumulate_reduce(accumulateobject *lz, PyObject *Py_UNUSED(ignored))
{
if (lz->initial != Py_None) {
PyObject *it;
assert(lz->total == NULL);
if (PyType_Ready(&chain_type) < 0)
return NULL;
it = PyObject_CallFunction((PyObject *)&chain_type, "(O)O",
lz->initial, lz->it);
if (it == NULL)
return NULL;
return Py_BuildValue("O(NO)O", Py_TYPE(lz),
it, lz->binop?lz->binop:Py_None, Py_None);
}
if (lz->total == Py_None) {
PyObject *it;