PEP 465: a dedicated infix operator for matrix multiplication (closes #21176)

This commit is contained in:
Benjamin Peterson 2014-04-09 23:55:56 -04:00
parent 2aad6ef774
commit d51374ed78
42 changed files with 803 additions and 442 deletions

View file

@ -30,6 +30,14 @@ Number Protocol
the equivalent of the Python expression ``o1 * o2``.
.. c:function:: PyObject* PyNumber_MatrixMultiply(PyObject *o1, PyObject *o2)
Returns the result of matrix multiplication on *o1* and *o2*, or *NULL* on
failure. This is the equivalent of the Python expression ``o1 @ o2``.
.. versionadded:: 3.5
.. c:function:: PyObject* PyNumber_FloorDivide(PyObject *o1, PyObject *o2)
Return the floor of *o1* divided by *o2*, or *NULL* on failure. This is
@ -146,6 +154,15 @@ Number Protocol
the Python statement ``o1 *= o2``.
.. c:function:: PyObject* PyNumber_InPlaceMatrixMultiply(PyObject *o1, PyObject *o2)
Returns the result of matrix multiplication on *o1* and *o2*, or *NULL* on
failure. The operation is done *in-place* when *o1* supports it. This is
the equivalent of the Python statement ``o1 @= o2``.
.. versionadded:: 3.5
.. c:function:: PyObject* PyNumber_InPlaceFloorDivide(PyObject *o1, PyObject *o2)
Returns the mathematical floor of dividing *o1* by *o2*, or *NULL* on failure.

View file

@ -1121,6 +1121,9 @@ Number Object Structures
binaryfunc nb_inplace_true_divide;
unaryfunc nb_index;
binaryfunc nb_matrix_multiply;
binaryfunc nb_inplace_matrix_multiply;
} PyNumberMethods;
.. note::

View file

@ -364,6 +364,11 @@ result back on the stack.
Implements ``TOS = TOS1 * TOS``.
.. opcode:: BINARY_MATRIX_MULTIPLY
Implements ``TOS = TOS1 @ TOS``.
.. opcode:: BINARY_FLOOR_DIVIDE
Implements ``TOS = TOS1 // TOS``.
@ -436,6 +441,11 @@ the original TOS1.
Implements in-place ``TOS = TOS1 * TOS``.
.. opcode:: INPLACE_MATRIX_MULTIPLY
Implements in-place ``TOS = TOS1 @ TOS``.
.. opcode:: INPLACE_FLOOR_DIVIDE
Implements in-place ``TOS = TOS1 // TOS``.

View file

@ -138,6 +138,14 @@ The mathematical and bitwise operations are the most numerous:
Return ``a * b``, for *a* and *b* numbers.
.. function:: matmul(a, b)
__matmul__(a, b)
Return ``a @ b``.
.. versionadded:: 3.5
.. function:: neg(obj)
__neg__(obj)
@ -400,6 +408,8 @@ Python syntax and the functions in the :mod:`operator` module.
+-----------------------+-------------------------+---------------------------------------+
| Multiplication | ``a * b`` | ``mul(a, b)`` |
+-----------------------+-------------------------+---------------------------------------+
| Matrix Multiplication | ``a @ b`` | ``matmul(a, b)`` |
+-----------------------+-------------------------+---------------------------------------+
| Negation (Arithmetic) | ``- a`` | ``neg(a)`` |
+-----------------------+-------------------------+---------------------------------------+
| Negation (Logical) | ``not a`` | ``not_(a)`` |
@ -508,6 +518,14 @@ will perform the update, so no subsequent assignment is necessary:
``a = imul(a, b)`` is equivalent to ``a *= b``.
.. function:: imatmul(a, b)
__imatmul__(a, b)
``a = imatmul(a, b)`` is equivalent to ``a @= b``.
.. versionadded:: 3.5
.. function:: ior(a, b)
__ior__(a, b)

View file

@ -93,6 +93,7 @@ The token constants are:
DOUBLESLASH
DOUBLESLASHEQUAL
AT
ATEQUAL
RARROW
ELLIPSIS
OP

View file

@ -1970,6 +1970,7 @@ left undefined.
.. method:: object.__add__(self, other)
object.__sub__(self, other)
object.__mul__(self, other)
object.__matmul__(self, other)
object.__truediv__(self, other)
object.__floordiv__(self, other)
object.__mod__(self, other)
@ -1986,15 +1987,16 @@ left undefined.
builtin: pow
builtin: pow
These methods are called to implement the binary arithmetic operations (``+``,
``-``, ``*``, ``/``, ``//``, ``%``, :func:`divmod`, :func:`pow`, ``**``, ``<<``,
``>>``, ``&``, ``^``, ``|``). For instance, to evaluate the expression
``x + y``, where *x* is an instance of a class that has an :meth:`__add__`
method, ``x.__add__(y)`` is called. The :meth:`__divmod__` method should be the
equivalent to using :meth:`__floordiv__` and :meth:`__mod__`; it should not be
related to :meth:`__truediv__`. Note that :meth:`__pow__` should be defined
to accept an optional third argument if the ternary version of the built-in
:func:`pow` function is to be supported.
These methods are called to implement the binary arithmetic operations
(``+``, ``-``, ``*``, ``@``, ``/``, ``//``, ``%``, :func:`divmod`,
:func:`pow`, ``**``, ``<<``, ``>>``, ``&``, ``^``, ``|``). For instance, to
evaluate the expression ``x + y``, where *x* is an instance of a class that
has an :meth:`__add__` method, ``x.__add__(y)`` is called. The
:meth:`__divmod__` method should be the equivalent to using
:meth:`__floordiv__` and :meth:`__mod__`; it should not be related to
:meth:`__truediv__`. Note that :meth:`__pow__` should be defined to accept
an optional third argument if the ternary version of the built-in :func:`pow`
function is to be supported.
If one of those methods does not support the operation with the supplied
arguments, it should return ``NotImplemented``.
@ -2003,6 +2005,7 @@ left undefined.
.. method:: object.__radd__(self, other)
object.__rsub__(self, other)
object.__rmul__(self, other)
object.__rmatmul__(self, other)
object.__rtruediv__(self, other)
object.__rfloordiv__(self, other)
object.__rmod__(self, other)
@ -2018,14 +2021,14 @@ left undefined.
builtin: divmod
builtin: pow
These methods are called to implement the binary arithmetic operations (``+``,
``-``, ``*``, ``/``, ``//``, ``%``, :func:`divmod`, :func:`pow`, ``**``,
``<<``, ``>>``, ``&``, ``^``, ``|``) with reflected (swapped) operands.
These functions are only called if the left operand does not support the
corresponding operation and the operands are of different types. [#]_ For
instance, to evaluate the expression ``x - y``, where *y* is an instance of
a class that has an :meth:`__rsub__` method, ``y.__rsub__(x)`` is called if
``x.__sub__(y)`` returns *NotImplemented*.
These methods are called to implement the binary arithmetic operations
(``+``, ``-``, ``*``, ``@``, ``/``, ``//``, ``%``, :func:`divmod`,
:func:`pow`, ``**``, ``<<``, ``>>``, ``&``, ``^``, ``|``) with reflected
(swapped) operands. These functions are only called if the left operand does
not support the corresponding operation and the operands are of different
types. [#]_ For instance, to evaluate the expression ``x - y``, where *y* is
an instance of a class that has an :meth:`__rsub__` method, ``y.__rsub__(x)``
is called if ``x.__sub__(y)`` returns *NotImplemented*.
.. index:: builtin: pow
@ -2043,6 +2046,7 @@ left undefined.
.. method:: object.__iadd__(self, other)
object.__isub__(self, other)
object.__imul__(self, other)
object.__imatmul__(self, other)
object.__itruediv__(self, other)
object.__ifloordiv__(self, other)
object.__imod__(self, other)
@ -2054,17 +2058,17 @@ left undefined.
object.__ior__(self, other)
These methods are called to implement the augmented arithmetic assignments
(``+=``, ``-=``, ``*=``, ``/=``, ``//=``, ``%=``, ``**=``, ``<<=``, ``>>=``,
``&=``, ``^=``, ``|=``). These methods should attempt to do the operation
in-place (modifying *self*) and return the result (which could be, but does
not have to be, *self*). If a specific method is not defined, the augmented
assignment falls back to the normal methods. For instance, if *x* is an
instance of a class with an :meth:`__iadd__` method, ``x += y`` is equivalent
to ``x = x.__iadd__(y)`` . Otherwise, ``x.__add__(y)`` and ``y.__radd__(x)``
are considered, as with the evaluation of ``x + y``. In certain situations,
augmented assignment can result in unexpected errors (see
:ref:`faq-augmented-assignment-tuple-error`), but this behavior is in
fact part of the data model.
(``+=``, ``-=``, ``*=``, ``@=``, ``/=``, ``//=``, ``%=``, ``**=``, ``<<=``,
``>>=``, ``&=``, ``^=``, ``|=``). These methods should attempt to do the
operation in-place (modifying *self*) and return the result (which could be,
but does not have to be, *self*). If a specific method is not defined, the
augmented assignment falls back to the normal methods. For instance, if *x*
is an instance of a class with an :meth:`__iadd__` method, ``x += y`` is
equivalent to ``x = x.__iadd__(y)`` . Otherwise, ``x.__add__(y)`` and
``y.__radd__(x)`` are considered, as with the evaluation of ``x + y``. In
certain situations, augmented assignment can result in unexpected errors (see
:ref:`faq-augmented-assignment-tuple-error`), but this behavior is in fact
part of the data model.
.. method:: object.__neg__(self)

View file

@ -892,8 +892,9 @@ from the power operator, there are only two levels, one for multiplicative
operators and one for additive operators:
.. productionlist::
m_expr: `u_expr` | `m_expr` "*" `u_expr` | `m_expr` "//" `u_expr` | `m_expr` "/" `u_expr`
: | `m_expr` "%" `u_expr`
m_expr: `u_expr` | `m_expr` "*" `u_expr` | `m_expr` "@" `m_expr` |
: `m_expr` "//" `u_expr`| `m_expr` "/" `u_expr` |
: `m_expr` "%" `u_expr`
a_expr: `m_expr` | `a_expr` "+" `m_expr` | `a_expr` "-" `m_expr`
.. index:: single: multiplication
@ -904,6 +905,13 @@ the other must be a sequence. In the former case, the numbers are converted to a
common type and then multiplied together. In the latter case, sequence
repetition is performed; a negative repetition factor yields an empty sequence.
.. index:: single: matrix multiplication
The ``@`` (at) operator is intended to be used for matrix multiplication. No
builtin Python types implement this operator.
.. versionadded:: 3.5
.. index::
exception: ZeroDivisionError
single: division
@ -1346,8 +1354,9 @@ groups from right to left).
+-----------------------------------------------+-------------------------------------+
| ``+``, ``-`` | Addition and subtraction |
+-----------------------------------------------+-------------------------------------+
| ``*``, ``/``, ``//``, ``%`` | Multiplication, division, remainder |
| | [#]_ |
| ``*``, ``@``, ``/``, ``//``, ``%`` | Multiplication, matrix |
| | multiplication division, |
| | remainder [#]_ |
+-----------------------------------------------+-------------------------------------+
| ``+x``, ``-x``, ``~x`` | Positive, negative, bitwise NOT |
+-----------------------------------------------+-------------------------------------+

View file

@ -267,7 +267,7 @@ operation and an assignment statement:
.. productionlist::
augmented_assignment_stmt: `augtarget` `augop` (`expression_list` | `yield_expression`)
augtarget: `identifier` | `attributeref` | `subscription` | `slicing`
augop: "+=" | "-=" | "*=" | "/=" | "//=" | "%=" | "**="
augop: "+=" | "-=" | "*=" | "@=" | "/=" | "//=" | "%=" | "**="
: | ">>=" | "<<=" | "&=" | "^=" | "|="
(See section :ref:`primaries` for the syntax definitions for the last three

View file

@ -40,7 +40,7 @@ small_stmt: (expr_stmt | del_stmt | pass_stmt | flow_stmt |
expr_stmt: testlist_star_expr (augassign (yield_expr|testlist) |
('=' (yield_expr|testlist_star_expr))*)
testlist_star_expr: (test|star_expr) (',' (test|star_expr))* [',']
augassign: ('+=' | '-=' | '*=' | '/=' | '%=' | '&=' | '|=' | '^=' |
augassign: ('+=' | '-=' | '*=' | '@=' | '/=' | '%=' | '&=' | '|=' | '^=' |
'<<=' | '>>=' | '**=' | '//=')
# For normal assignments, additional restrictions enforced by the interpreter
del_stmt: 'del' exprlist
@ -97,7 +97,7 @@ xor_expr: and_expr ('^' and_expr)*
and_expr: shift_expr ('&' shift_expr)*
shift_expr: arith_expr (('<<'|'>>') arith_expr)*
arith_expr: term (('+'|'-') term)*
term: factor (('*'|'/'|'%'|'//') factor)*
term: factor (('*'|'@'|'/'|'%'|'//') factor)*
factor: ('+'|'-'|'~') factor | power
power: atom trailer* ['**' factor]
atom: ('(' [yield_expr|testlist_comp] ')' |

View file

@ -15,9 +15,9 @@ typedef struct _slice *slice_ty;
typedef enum _boolop { And=1, Or=2 } boolop_ty;
typedef enum _operator { Add=1, Sub=2, Mult=3, Div=4, Mod=5, Pow=6, LShift=7,
RShift=8, BitOr=9, BitXor=10, BitAnd=11, FloorDiv=12 }
operator_ty;
typedef enum _operator { Add=1, Sub=2, Mult=3, MatMult=4, Div=5, Mod=6, Pow=7,
LShift=8, RShift=9, BitOr=10, BitXor=11, BitAnd=12,
FloorDiv=13 } operator_ty;
typedef enum _unaryop { Invert=1, Not=2, UAdd=3, USub=4 } unaryop_ty;

View file

@ -658,6 +658,12 @@ xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx*/
o1*o2.
*/
PyAPI_FUNC(PyObject *) PyNumber_MatrixMultiply(PyObject *o1, PyObject *o2);
/*
This is the equivalent of the Python expression: o1 @ o2.
*/
PyAPI_FUNC(PyObject *) PyNumber_FloorDivide(PyObject *o1, PyObject *o2);
/*
@ -832,6 +838,12 @@ xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx*/
o1 *= o2.
*/
PyAPI_FUNC(PyObject *) PyNumber_InPlaceMatrixMultiply(PyObject *o1, PyObject *o2);
/*
This is the equivalent of the Python expression: o1 @= o2.
*/
PyAPI_FUNC(PyObject *) PyNumber_InPlaceFloorDivide(PyObject *o1,
PyObject *o2);

View file

@ -275,6 +275,9 @@ typedef struct {
binaryfunc nb_inplace_true_divide;
unaryfunc nb_index;
binaryfunc nb_matrix_multiply;
binaryfunc nb_inplace_matrix_multiply;
} PyNumberMethods;
typedef struct {

View file

@ -20,6 +20,9 @@ extern "C" {
#define UNARY_INVERT 15
#define BINARY_MATRIX_MULTIPLY 16
#define INPLACE_MATRIX_MULTIPLY 17
#define BINARY_POWER 19
#define BINARY_MULTIPLY 20

View file

@ -58,13 +58,14 @@ extern "C" {
#define DOUBLESTAREQUAL 46
#define DOUBLESLASH 47
#define DOUBLESLASHEQUAL 48
#define AT 49
#define RARROW 50
#define ELLIPSIS 51
#define AT 49
#define ATEQUAL 50
#define RARROW 51
#define ELLIPSIS 52
/* Don't forget to update the table _PyParser_TokenNames in tokenizer.c! */
#define OP 52
#define ERRORTOKEN 53
#define N_TOKENS 54
#define OP 53
#define ERRORTOKEN 54
#define N_TOKENS 55
/* Special definitions for cooperation with parser */

View file

@ -74,3 +74,5 @@
#define Py_tp_members 72
#define Py_tp_getset 73
#define Py_tp_free 74
#define Py_nb_matrix_multiply 75
#define Py_nb_inplace_matrix_multiply 76

View file

@ -419,12 +419,13 @@ def _call_with_frames_removed(f, *args, **kwds):
# Python 3.4a4 3290 (changes to __qualname__ computation)
# Python 3.4a4 3300 (more changes to __qualname__ computation)
# Python 3.4rc2 3310 (alter __qualname__ computation)
# Python 3.5a0 3320 (matrix multiplication operator)
#
# MAGIC must change whenever the bytecode emitted by the compiler may no
# longer be understood by older implementations of the eval loop (usually
# due to the addition of new opcodes).
MAGIC_NUMBER = (3310).to_bytes(2, 'little') + b'\r\n'
MAGIC_NUMBER = (3320).to_bytes(2, 'little') + b'\r\n'
_RAW_MAGIC_NUMBER = int.from_bytes(MAGIC_NUMBER, 'little') # For import.c
_PYCACHE = '__pycache__'

View file

@ -70,6 +70,9 @@ def jabs_op(name, op):
def_op('UNARY_INVERT', 15)
def_op('BINARY_MATRIX_MULTIPLY', 16)
def_op('INPLACE_MATRIX_MULTIPLY', 17)
def_op('BINARY_POWER', 19)
def_op('BINARY_MULTIPLY', 20)

View file

@ -105,6 +105,10 @@ def mul(a, b):
"Same as a * b."
return a * b
def matmul(a, b):
"Same as a @ b."
return a @ b
def neg(a):
"Same as -a."
return -a
@ -326,6 +330,11 @@ def imul(a, b):
a *= b
return a
def imatmul(a, b):
"Same as a @= b."
a @= b
return a
def ior(a, b):
"Same as a |= b."
a |= b
@ -383,6 +392,7 @@ def ixor(a, b):
__lshift__ = lshift
__mod__ = mod
__mul__ = mul
__matmul__ = matmul
__neg__ = neg
__or__ = or_
__pos__ = pos
@ -403,6 +413,7 @@ def ixor(a, b):
__ilshift__ = ilshift
__imod__ = imod
__imul__ = imul
__imatmul__ = imatmul
__ior__ = ior
__ipow__ = ipow
__irshift__ = irshift

View file

@ -136,6 +136,14 @@ def __imul__(self, val):
output.append("__imul__ called")
return self
def __matmul__(self, val):
output.append("__matmul__ called")
def __rmatmul__(self, val):
output.append("__rmatmul__ called")
def __imatmul__(self, val):
output.append("__imatmul__ called")
return self
def __div__(self, val):
output.append("__div__ called")
def __rdiv__(self, val):
@ -233,6 +241,10 @@ def __ilshift__(self, val):
1 * x
x *= 1
x @ 1
1 @ x
x @= 1
x / 1
1 / x
x /= 1
@ -279,6 +291,9 @@ def __ilshift__(self, val):
__mul__ called
__rmul__ called
__imul__ called
__matmul__ called
__rmatmul__ called
__imatmul__ called
__truediv__ called
__rtruediv__ called
__itruediv__ called

View file

@ -150,6 +150,23 @@ def test_docstring_signature_parsing(self):
self.assertEqual(_testcapi.docstring_with_signature_and_extra_newlines.__text_signature__,
"($module, /, parameter)")
def test_c_type_with_matrix_multiplication(self):
M = _testcapi.matmulType
m1 = M()
m2 = M()
self.assertEqual(m1 @ m2, ("matmul", m1, m2))
self.assertEqual(m1 @ 42, ("matmul", m1, 42))
self.assertEqual(42 @ m1, ("matmul", 42, m1))
o = m1
o @= m2
self.assertEqual(o, ("imatmul", m1, m2))
o = m1
o @= 42
self.assertEqual(o, ("imatmul", m1, 42))
o = 42
o @= m1
self.assertEqual(o, ("matmul", 42, m1))
@unittest.skipUnless(threading, 'Threading required for this test.')
class TestPendingCalls(unittest.TestCase):

View file

@ -4160,6 +4160,7 @@ def check(expr, x, y):
('__add__', 'x + y', 'x += y'),
('__sub__', 'x - y', 'x -= y'),
('__mul__', 'x * y', 'x *= y'),
('__matmul__', 'x @ y', 'x @= y'),
('__truediv__', 'operator.truediv(x, y)', None),
('__floordiv__', 'operator.floordiv(x, y)', None),
('__div__', 'x / y', 'x /= y'),

View file

@ -985,6 +985,20 @@ def test_paren_evaluation(self):
self.assertFalse((False is 2) is 3)
self.assertFalse(False is 2 is 3)
def test_matrix_mul(self):
# This is not intended to be a comprehensive test, rather just to be few
# samples of the @ operator in test_grammar.py.
class M:
def __matmul__(self, o):
return 4
def __imatmul__(self, o):
self.other = o
return self
m = M()
self.assertEqual(m @ m, 4)
m @= 42
self.assertEqual(m.other, 42)
def test_main():
run_unittest(TokenTests, GrammarTests)

View file

@ -203,6 +203,15 @@ def test_mul(self):
self.assertRaises(TypeError, operator.mul, None, None)
self.assertTrue(operator.mul(5, 2) == 10)
def test_matmul(self):
operator = self.module
self.assertRaises(TypeError, operator.matmul)
self.assertRaises(TypeError, operator.matmul, 42, 42)
class M:
def __matmul__(self, other):
return other - 1
self.assertEqual(M() @ 42, 41)
def test_neg(self):
operator = self.module
self.assertRaises(TypeError, operator.neg)
@ -416,6 +425,7 @@ def __ifloordiv__(self, other): return "ifloordiv"
def __ilshift__ (self, other): return "ilshift"
def __imod__ (self, other): return "imod"
def __imul__ (self, other): return "imul"
def __imatmul__ (self, other): return "imatmul"
def __ior__ (self, other): return "ior"
def __ipow__ (self, other): return "ipow"
def __irshift__ (self, other): return "irshift"
@ -430,6 +440,7 @@ def __getitem__(self, other): return 5 # so that C is a sequence
self.assertEqual(operator.ilshift (c, 5), "ilshift")
self.assertEqual(operator.imod (c, 5), "imod")
self.assertEqual(operator.imul (c, 5), "imul")
self.assertEqual(operator.imatmul (c, 5), "imatmul")
self.assertEqual(operator.ior (c, 5), "ior")
self.assertEqual(operator.ipow (c, 5), "ipow")
self.assertEqual(operator.irshift (c, 5), "irshift")

View file

@ -952,7 +952,7 @@ def delx(self): del self.__x
check(int, s)
# (PyTypeObject + PyNumberMethods + PyMappingMethods +
# PySequenceMethods + PyBufferProcs + 4P)
s = vsize('P2n15Pl4Pn9Pn11PIP') + struct.calcsize('34P 3P 10P 2P 4P')
s = vsize('P2n17Pl4Pn9Pn11PIP') + struct.calcsize('34P 3P 10P 2P 4P')
# Separate block for PyDictKeysObject with 4 entries
s += struct.calcsize("2nPn") + 4*struct.calcsize("n2P")
# class

View file

@ -464,7 +464,7 @@
Multiplicative
>>> dump_tokens("x = 1//1*1/5*12%0x12")
>>> dump_tokens("x = 1//1*1/5*12%0x12@42")
ENCODING 'utf-8' (0, 0) (0, 0)
NAME 'x' (1, 0) (1, 1)
OP '=' (1, 2) (1, 3)
@ -479,6 +479,8 @@
NUMBER '12' (1, 13) (1, 15)
OP '%' (1, 15) (1, 16)
NUMBER '0x12' (1, 16) (1, 20)
OP '@' (1, 20) (1, 21)
NUMBER '42' (1, 21) (1, 23)
Unary
@ -1154,6 +1156,7 @@ def test_exact_type(self):
self.assertExactTypeEqual('//', token.DOUBLESLASH)
self.assertExactTypeEqual('//=', token.DOUBLESLASHEQUAL)
self.assertExactTypeEqual('@', token.AT)
self.assertExactTypeEqual('@=', token.ATEQUAL)
self.assertExactTypeEqual('a**2+b**2==c**2',
NAME, token.DOUBLESTAR, NUMBER,

View file

@ -60,11 +60,12 @@ DOUBLESTAREQUAL = 46
DOUBLESLASH = 47
DOUBLESLASHEQUAL = 48
AT = 49
RARROW = 50
ELLIPSIS = 51
OP = 52
ERRORTOKEN = 53
N_TOKENS = 54
ATEQUAL = 50
RARROW = 51
ELLIPSIS = 52
OP = 53
ERRORTOKEN = 54
N_TOKENS = 55
NT_OFFSET = 256
#--end constants--

View file

@ -91,7 +91,8 @@
'**=': DOUBLESTAREQUAL,
'//': DOUBLESLASH,
'//=': DOUBLESLASHEQUAL,
'@': AT
'@': AT,
'@=': ATEQUAL,
}
class TokenInfo(collections.namedtuple('TokenInfo', 'type string start end line')):
@ -150,7 +151,7 @@ def maybe(*choices): return group(*choices) + '?'
# recognized as two instances of =).
Operator = group(r"\*\*=?", r">>=?", r"<<=?", r"!=",
r"//=?", r"->",
r"[+\-*/%&|^=<>]=?",
r"[+\-*/%&@|^=<>]=?",
r"~")
Bracket = '[][(){}]'

View file

@ -10,6 +10,8 @@ Release date: TBA
Core and Builtins
-----------------
- PEP 465 and Issue #21176: Add the '@' operator for matrix multiplication.
- Issue #21134: Fix segfault when str is called on an uninitialized
UnicodeEncodeError, UnicodeDecodeError, or UnicodeTranslateError object.

View file

@ -69,6 +69,7 @@ spami(truth , PyObject_IsTrue)
spam2(op_add , PyNumber_Add)
spam2(op_sub , PyNumber_Subtract)
spam2(op_mul , PyNumber_Multiply)
spam2(op_matmul , PyNumber_MatrixMultiply)
spam2(op_floordiv , PyNumber_FloorDivide)
spam2(op_truediv , PyNumber_TrueDivide)
spam2(op_mod , PyNumber_Remainder)
@ -86,6 +87,7 @@ spam2(op_or_ , PyNumber_Or)
spam2(op_iadd , PyNumber_InPlaceAdd)
spam2(op_isub , PyNumber_InPlaceSubtract)
spam2(op_imul , PyNumber_InPlaceMultiply)
spam2(op_imatmul , PyNumber_InPlaceMatrixMultiply)
spam2(op_ifloordiv , PyNumber_InPlaceFloorDivide)
spam2(op_itruediv , PyNumber_InPlaceTrueDivide)
spam2(op_imod , PyNumber_InPlaceRemainder)
@ -343,6 +345,7 @@ spam2o(index, "index(a) -- Same as a.__index__()")
spam2(add, "add(a, b) -- Same as a + b.")
spam2(sub, "sub(a, b) -- Same as a - b.")
spam2(mul, "mul(a, b) -- Same as a * b.")
spam2(matmul, "matmul(a, b) -- Same as a @ b.")
spam2(floordiv, "floordiv(a, b) -- Same as a // b.")
spam2(truediv, "truediv(a, b) -- Same as a / b.")
spam2(mod, "mod(a, b) -- Same as a % b.")
@ -360,6 +363,7 @@ spam2(or_, "or_(a, b) -- Same as a | b.")
spam2(iadd, "a = iadd(a, b) -- Same as a += b.")
spam2(isub, "a = isub(a, b) -- Same as a -= b.")
spam2(imul, "a = imul(a, b) -- Same as a *= b.")
spam2(imatmul, "a = imatmul(a, b) -- Same as a @= b.")
spam2(ifloordiv, "a = ifloordiv(a, b) -- Same as a //= b.")
spam2(itruediv, "a = itruediv(a, b) -- Same as a /= b")
spam2(imod, "a = imod(a, b) -- Same as a %= b.")

View file

@ -3298,6 +3298,109 @@ static PyTypeObject test_structmembersType = {
};
typedef struct {
PyObject_HEAD
} matmulObject;
static PyObject *
matmulType_matmul(PyObject *self, PyObject *other)
{
return Py_BuildValue("(sOO)", "matmul", self, other);
}
static PyObject *
matmulType_imatmul(PyObject *self, PyObject *other)
{
return Py_BuildValue("(sOO)", "imatmul", self, other);
}
static void
matmulType_dealloc(PyObject *self)
{
return Py_TYPE(self)->tp_free(self);
}
static PyNumberMethods matmulType_as_number = {
0, /* nb_add */
0, /* nb_subtract */
0, /* nb_multiply */
0, /* nb_remainde r*/
0, /* nb_divmod */
0, /* nb_power */
0, /* nb_negative */
0, /* tp_positive */
0, /* tp_absolute */
0, /* tp_bool */
0, /* nb_invert */
0, /* nb_lshift */
0, /* nb_rshift */
0, /* nb_and */
0, /* nb_xor */
0, /* nb_or */
0, /* nb_int */
0, /* nb_reserved */
0, /* nb_float */
0, /* nb_inplace_add */
0, /* nb_inplace_subtract */
0, /* nb_inplace_multiply */
0, /* nb_inplace_remainder */
0, /* nb_inplace_power */
0, /* nb_inplace_lshift */
0, /* nb_inplace_rshift */
0, /* nb_inplace_and */
0, /* nb_inplace_xor */
0, /* nb_inplace_or */
0, /* nb_floor_divide */
0, /* nb_true_divide */
0, /* nb_inplace_floor_divide */
0, /* nb_inplace_true_divide */
0, /* nb_index */
matmulType_matmul, /* nb_matrix_multiply */
matmulType_imatmul /* nb_matrix_inplace_multiply */
};
static PyTypeObject matmulType = {
PyVarObject_HEAD_INIT(NULL, 0)
"matmulType",
sizeof(matmulObject), /* tp_basicsize */
0, /* tp_itemsize */
matmulType_dealloc, /* destructor tp_dealloc */
0, /* tp_print */
0, /* tp_getattr */
0, /* tp_setattr */
0, /* tp_reserved */
0, /* tp_repr */
&matmulType_as_number, /* tp_as_number */
0, /* tp_as_sequence */
0, /* tp_as_mapping */
0, /* tp_hash */
0, /* tp_call */
0, /* tp_str */
PyObject_GenericGetAttr, /* tp_getattro */
PyObject_GenericSetAttr, /* tp_setattro */
0, /* tp_as_buffer */
0, /* tp_flags */
"C level type with matrix operations defined",
0, /* traverseproc tp_traverse */
0, /* tp_clear */
0, /* tp_richcompare */
0, /* tp_weaklistoffset */
0, /* tp_iter */
0, /* tp_iternext */
0, /* tp_methods */
0, /* tp_members */
0,
0,
0,
0,
0,
0,
0,
0,
PyType_GenericNew, /* tp_new */
PyObject_Del, /* tp_free */
};
static struct PyModuleDef _testcapimodule = {
PyModuleDef_HEAD_INIT,
@ -3327,6 +3430,10 @@ PyInit__testcapi(void)
/* don't use a name starting with "test", since we don't want
test_capi to automatically call this */
PyModule_AddObject(m, "_test_structmembersType", (PyObject *)&test_structmembersType);
if (PyType_Ready(&matmulType) < 0)
return NULL;
Py_INCREF(&matmulType);
PyModule_AddObject(m, "matmulType", (PyObject *)&matmulType);
PyModule_AddObject(m, "CHAR_MAX", PyLong_FromLong(CHAR_MAX));
PyModule_AddObject(m, "CHAR_MIN", PyLong_FromLong(CHAR_MIN));

View file

@ -931,6 +931,12 @@ PyNumber_Multiply(PyObject *v, PyObject *w)
return result;
}
PyObject *
PyNumber_MatrixMultiply(PyObject *v, PyObject *w)
{
return binary_op(v, w, NB_SLOT(nb_matrix_multiply), "@");
}
PyObject *
PyNumber_FloorDivide(PyObject *v, PyObject *w)
{
@ -1012,6 +1018,7 @@ INPLACE_BINOP(PyNumber_InPlaceAnd, nb_inplace_and, nb_and, "&=")
INPLACE_BINOP(PyNumber_InPlaceLshift, nb_inplace_lshift, nb_lshift, "<<=")
INPLACE_BINOP(PyNumber_InPlaceRshift, nb_inplace_rshift, nb_rshift, ">>=")
INPLACE_BINOP(PyNumber_InPlaceSubtract, nb_inplace_subtract, nb_subtract, "-=")
INPLACE_BINOP(PyNumber_InMatrixMultiply, nb_inplace_matrix_multiply, nb_matrix_multiply, "@=")
PyObject *
PyNumber_InPlaceFloorDivide(PyObject *v, PyObject *w)
@ -1077,6 +1084,13 @@ PyNumber_InPlaceMultiply(PyObject *v, PyObject *w)
return result;
}
PyObject *
PyNumber_InPlaceMatrixMultiply(PyObject *v, PyObject *w)
{
return binary_iop(v, w, NB_SLOT(nb_inplace_matrix_multiply),
NB_SLOT(nb_matrix_multiply), "@=");
}
PyObject *
PyNumber_InPlaceRemainder(PyObject *v, PyObject *w)
{

View file

@ -4469,6 +4469,8 @@ inherit_slots(PyTypeObject *type, PyTypeObject *base)
COPYNUM(nb_inplace_true_divide);
COPYNUM(nb_inplace_floor_divide);
COPYNUM(nb_index);
COPYNUM(nb_matrix_multiply);
COPYNUM(nb_inplace_matrix_multiply);
}
if (type->tp_as_sequence != NULL && base->tp_as_sequence != NULL) {
@ -5605,6 +5607,7 @@ slot_mp_ass_subscript(PyObject *self, PyObject *key, PyObject *value)
SLOT1BIN(slot_nb_add, nb_add, "__add__", "__radd__")
SLOT1BIN(slot_nb_subtract, nb_subtract, "__sub__", "__rsub__")
SLOT1BIN(slot_nb_multiply, nb_multiply, "__mul__", "__rmul__")
SLOT1BIN(slot_nb_matrix_multiply, nb_matrix_multiply, "__matmul__", "__rmatmul__")
SLOT1BIN(slot_nb_remainder, nb_remainder, "__mod__", "__rmod__")
SLOT1BIN(slot_nb_divmod, nb_divmod, "__divmod__", "__rdivmod__")
@ -5698,6 +5701,7 @@ SLOT0(slot_nb_float, "__float__")
SLOT1(slot_nb_inplace_add, "__iadd__", PyObject *, "O")
SLOT1(slot_nb_inplace_subtract, "__isub__", PyObject *, "O")
SLOT1(slot_nb_inplace_multiply, "__imul__", PyObject *, "O")
SLOT1(slot_nb_inplace_matrix_multiply, "__imatmul__", PyObject *, "O")
SLOT1(slot_nb_inplace_remainder, "__imod__", PyObject *, "O")
/* Can't use SLOT1 here, because nb_inplace_power is ternary */
static PyObject *
@ -6278,6 +6282,12 @@ static slotdef slotdefs[] = {
"__index__($self, /)\n--\n\n"
"Return self converted to an integer, if self is suitable"
"for use as an index into a list."),
BINSLOT("__matmul__", nb_matrix_multiply, slot_nb_matrix_multiply,
"@"),
RBINSLOT("__rmatmul__", nb_matrix_multiply, slot_nb_matrix_multiply,
"@"),
IBSLOT("__imatmul__", nb_inplace_matrix_multiply, slot_nb_inplace_matrix_multiply,
wrap_binaryfunc, "@="),
MPSLOT("__len__", mp_length, slot_mp_length, wrap_lenfunc,
"__len__($self, /)\n--\n\nReturn len(self)."),
MPSLOT("__getitem__", mp_subscript, slot_mp_subscript,

View file

@ -73,3 +73,5 @@ offsetof(PyHeapTypeObject, ht_type.tp_traverse),
offsetof(PyHeapTypeObject, ht_type.tp_members),
offsetof(PyHeapTypeObject, ht_type.tp_getset),
offsetof(PyHeapTypeObject, ht_type.tp_free),
offsetof(PyHeapTypeObject, as_number.nb_matrix_multiply),
offsetof(PyHeapTypeObject, as_number.nb_inplace_matrix_multiply),

View file

@ -91,7 +91,7 @@ module Python
boolop = And | Or
operator = Add | Sub | Mult | Div | Mod | Pow | LShift
operator = Add | Sub | Mult | MatMult | Div | Mod | Pow | LShift
| RShift | BitOr | BitXor | BitAnd | FloorDiv
unaryop = Invert | Not | UAdd | USub

View file

@ -98,6 +98,7 @@ const char *_PyParser_TokenNames[] = {
"DOUBLESLASH",
"DOUBLESLASHEQUAL",
"AT",
"ATEQUAL",
"RARROW",
"ELLIPSIS",
/* This table must match the #defines in token.h! */
@ -1131,7 +1132,7 @@ PyToken_OneChar(int c)
case '}': return RBRACE;
case '^': return CIRCUMFLEX;
case '~': return TILDE;
case '@': return AT;
case '@': return AT;
default: return OP;
}
}
@ -1207,6 +1208,11 @@ PyToken_TwoChars(int c1, int c2)
case '=': return CIRCUMFLEXEQUAL;
}
break;
case '@':
switch (c2) {
case '=': return ATEQUAL;
}
break;
}
return OP;
}

View file

@ -349,13 +349,14 @@ static PyTypeObject *And_type;
static PyTypeObject *Or_type;
static PyTypeObject *operator_type;
static PyObject *Add_singleton, *Sub_singleton, *Mult_singleton,
*Div_singleton, *Mod_singleton, *Pow_singleton, *LShift_singleton,
*RShift_singleton, *BitOr_singleton, *BitXor_singleton, *BitAnd_singleton,
*FloorDiv_singleton;
*MatMult_singleton, *Div_singleton, *Mod_singleton, *Pow_singleton,
*LShift_singleton, *RShift_singleton, *BitOr_singleton, *BitXor_singleton,
*BitAnd_singleton, *FloorDiv_singleton;
static PyObject* ast2obj_operator(operator_ty);
static PyTypeObject *Add_type;
static PyTypeObject *Sub_type;
static PyTypeObject *Mult_type;
static PyTypeObject *MatMult_type;
static PyTypeObject *Div_type;
static PyTypeObject *Mod_type;
static PyTypeObject *Pow_type;
@ -970,6 +971,10 @@ static int init_types(void)
if (!Mult_type) return 0;
Mult_singleton = PyType_GenericNew(Mult_type, NULL, NULL);
if (!Mult_singleton) return 0;
MatMult_type = make_type("MatMult", operator_type, NULL, 0);
if (!MatMult_type) return 0;
MatMult_singleton = PyType_GenericNew(MatMult_type, NULL, NULL);
if (!MatMult_singleton) return 0;
Div_type = make_type("Div", operator_type, NULL, 0);
if (!Div_type) return 0;
Div_singleton = PyType_GenericNew(Div_type, NULL, NULL);
@ -3232,6 +3237,9 @@ PyObject* ast2obj_operator(operator_ty o)
case Mult:
Py_INCREF(Mult_singleton);
return Mult_singleton;
case MatMult:
Py_INCREF(MatMult_singleton);
return MatMult_singleton;
case Div:
Py_INCREF(Div_singleton);
return Div_singleton;
@ -6175,6 +6183,14 @@ obj2ast_operator(PyObject* obj, operator_ty* out, PyArena* arena)
*out = Mult;
return 0;
}
isinstance = PyObject_IsInstance(obj, (PyObject *)MatMult_type);
if (isinstance == -1) {
return 1;
}
if (isinstance) {
*out = MatMult;
return 0;
}
isinstance = PyObject_IsInstance(obj, (PyObject *)Div_type);
if (isinstance == -1) {
return 1;
@ -6956,6 +6972,8 @@ PyInit__ast(void)
if (PyDict_SetItemString(d, "Add", (PyObject*)Add_type) < 0) return NULL;
if (PyDict_SetItemString(d, "Sub", (PyObject*)Sub_type) < 0) return NULL;
if (PyDict_SetItemString(d, "Mult", (PyObject*)Mult_type) < 0) return NULL;
if (PyDict_SetItemString(d, "MatMult", (PyObject*)MatMult_type) < 0) return
NULL;
if (PyDict_SetItemString(d, "Div", (PyObject*)Div_type) < 0) return NULL;
if (PyDict_SetItemString(d, "Mod", (PyObject*)Mod_type) < 0) return NULL;
if (PyDict_SetItemString(d, "Pow", (PyObject*)Pow_type) < 0) return NULL;

View file

@ -825,6 +825,8 @@ get_operator(const node *n)
return Sub;
case STAR:
return Mult;
case AT:
return MatMult;
case SLASH:
return Div;
case DOUBLESLASH:
@ -1030,6 +1032,8 @@ ast_for_augassign(struct compiling *c, const node *n)
return Pow;
else
return Mult;
case '@':
return MatMult;
default:
PyErr_Format(PyExc_SystemError, "invalid augassign: %s", STR(n));
return (operator_ty)0;
@ -2266,7 +2270,7 @@ ast_for_expr(struct compiling *c, const node *n)
and_expr: shift_expr ('&' shift_expr)*
shift_expr: arith_expr (('<<'|'>>') arith_expr)*
arith_expr: term (('+'|'-') term)*
term: factor (('*'|'/'|'%'|'//') factor)*
term: factor (('*'|'@'|'/'|'%'|'//') factor)*
factor: ('+'|'-'|'~') factor | power
power: atom trailer* ('**' factor)*
*/
@ -2577,7 +2581,7 @@ ast_for_expr_stmt(struct compiling *c, const node *n)
/* expr_stmt: testlist_star_expr (augassign (yield_expr|testlist)
| ('=' (yield_expr|testlist))*)
testlist_star_expr: (test|star_expr) (',' test|star_expr)* [',']
augassign: '+=' | '-=' | '*=' | '/=' | '%=' | '&=' | '|=' | '^='
augassign: '+=' | '-=' | '*=' | '@=' | '/=' | '%=' | '&=' | '|=' | '^='
| '<<=' | '>>=' | '**=' | '//='
test: ... here starts the operator precendence dance
*/

View file

@ -1495,6 +1495,18 @@ PyEval_EvalFrameEx(PyFrameObject *f, int throwflag)
DISPATCH();
}
TARGET(BINARY_MATRIX_MULTIPLY) {
PyObject *right = POP();
PyObject *left = TOP();
PyObject *res = PyNumber_MatrixMultiply(left, right);
Py_DECREF(left);
Py_DECREF(right);
SET_TOP(res);
if (res == NULL)
goto error;
DISPATCH();
}
TARGET(BINARY_TRUE_DIVIDE) {
PyObject *divisor = POP();
PyObject *dividend = TOP();
@ -1685,6 +1697,18 @@ PyEval_EvalFrameEx(PyFrameObject *f, int throwflag)
DISPATCH();
}
TARGET(INPLACE_MATRIX_MULTIPLY) {
PyObject *right = POP();
PyObject *left = TOP();
PyObject *res = PyNumber_InPlaceMatrixMultiply(left, right);
Py_DECREF(left);
Py_DECREF(right);
SET_TOP(res);
if (res == NULL)
goto error;
DISPATCH();
}
TARGET(INPLACE_TRUE_DIVIDE) {
PyObject *divisor = POP();
PyObject *dividend = TOP();

View file

@ -881,6 +881,7 @@ PyCompile_OpcodeStackEffect(int opcode, int oparg)
case BINARY_POWER:
case BINARY_MULTIPLY:
case BINARY_MATRIX_MULTIPLY:
case BINARY_MODULO:
case BINARY_ADD:
case BINARY_SUBTRACT:
@ -895,6 +896,7 @@ PyCompile_OpcodeStackEffect(int opcode, int oparg)
case INPLACE_ADD:
case INPLACE_SUBTRACT:
case INPLACE_MULTIPLY:
case INPLACE_MATRIX_MULTIPLY:
case INPLACE_MODULO:
return -1;
case STORE_SUBSCR:
@ -2625,6 +2627,8 @@ binop(struct compiler *c, operator_ty op)
return BINARY_SUBTRACT;
case Mult:
return BINARY_MULTIPLY;
case MatMult:
return BINARY_MATRIX_MULTIPLY;
case Div:
return BINARY_TRUE_DIVIDE;
case Mod:
@ -2689,6 +2693,8 @@ inplace_binop(struct compiler *c, operator_ty op)
return INPLACE_SUBTRACT;
case Mult:
return INPLACE_MULTIPLY;
case MatMult:
return INPLACE_MATRIX_MULTIPLY;
case Div:
return INPLACE_TRUE_DIVIDE;
case Mod:

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -15,8 +15,8 @@ static void *opcode_targets[256] = {
&&_unknown_opcode,
&&_unknown_opcode,
&&TARGET_UNARY_INVERT,
&&_unknown_opcode,
&&_unknown_opcode,
&&TARGET_BINARY_MATRIX_MULTIPLY,
&&TARGET_INPLACE_MATRIX_MULTIPLY,
&&_unknown_opcode,
&&TARGET_BINARY_POWER,
&&TARGET_BINARY_MULTIPLY,