From 267b868f23a85c6a63a06452c85487355cf9ab8a Mon Sep 17 00:00:00 2001 From: Raymond Hettinger Date: Sun, 27 Mar 2005 10:47:39 +0000 Subject: [PATCH] * Fix decimal's handling of foreign types. Now returns NotImplemented instead of raising a TypeError. Allows other types to successfully implement __radd__() style methods. * Remove future division import from test suite. * Remove test suite's shadowing of __builtin__.dir(). --- Lib/decimal.py | 43 +++++++++++++++++++++++++++--- Lib/test/test_decimal.py | 57 +++++++++++++++++++++++++++++++++++----- Misc/NEWS | 5 ++++ 3 files changed, 94 insertions(+), 11 deletions(-) diff --git a/Lib/decimal.py b/Lib/decimal.py index fb11e8f3e4f..e3e7fd5c11c 100644 --- a/Lib/decimal.py +++ b/Lib/decimal.py @@ -645,6 +645,8 @@ def __nonzero__(self): def __cmp__(self, other, context=None): other = _convert_other(other) + if other is NotImplemented: + return other if self._is_special or other._is_special: ans = self._check_nans(other, context) @@ -696,12 +698,12 @@ def __cmp__(self, other, context=None): def __eq__(self, other): if not isinstance(other, (Decimal, int, long)): - return False + return NotImplemented return self.__cmp__(other) == 0 def __ne__(self, other): if not isinstance(other, (Decimal, int, long)): - return True + return NotImplemented return self.__cmp__(other) != 0 def compare(self, other, context=None): @@ -714,6 +716,8 @@ def compare(self, other, context=None): Like __cmp__, but returns Decimal instances. """ other = _convert_other(other) + if other is NotImplemented: + return other #compare(NaN, NaN) = NaN if (self._is_special or other and other._is_special): @@ -919,6 +923,8 @@ def __add__(self, other, context=None): -INF + INF (or the reverse) cause InvalidOperation errors. """ other = _convert_other(other) + if other is NotImplemented: + return other if context is None: context = getcontext() @@ -1006,6 +1012,8 @@ def __add__(self, other, context=None): def __sub__(self, other, context=None): """Return self + (-other)""" other = _convert_other(other) + if other is NotImplemented: + return other if self._is_special or other._is_special: ans = self._check_nans(other, context=context) @@ -1023,6 +1031,8 @@ def __sub__(self, other, context=None): def __rsub__(self, other, context=None): """Return other + (-self)""" other = _convert_other(other) + if other is NotImplemented: + return other tmp = Decimal(self) tmp._sign = 1 - tmp._sign @@ -1068,6 +1078,8 @@ def __mul__(self, other, context=None): (+-) INF * 0 (or its reverse) raise InvalidOperation. """ other = _convert_other(other) + if other is NotImplemented: + return other if context is None: context = getcontext() @@ -1140,6 +1152,10 @@ def _divide(self, other, divmod = 0, context=None): computing the other value are not raised. """ other = _convert_other(other) + if other is NotImplemented: + if divmod in (0, 1): + return NotImplemented + return (NotImplemented, NotImplemented) if context is None: context = getcontext() @@ -1292,6 +1308,8 @@ def _divide(self, other, divmod = 0, context=None): def __rdiv__(self, other, context=None): """Swaps self/other and returns __div__.""" other = _convert_other(other) + if other is NotImplemented: + return other return other.__div__(self, context=context) __rtruediv__ = __rdiv__ @@ -1304,6 +1322,8 @@ def __divmod__(self, other, context=None): def __rdivmod__(self, other, context=None): """Swaps self/other and returns __divmod__.""" other = _convert_other(other) + if other is NotImplemented: + return other return other.__divmod__(self, context=context) def __mod__(self, other, context=None): @@ -1311,6 +1331,8 @@ def __mod__(self, other, context=None): self % other """ other = _convert_other(other) + if other is NotImplemented: + return other if self._is_special or other._is_special: ans = self._check_nans(other, context) @@ -1325,6 +1347,8 @@ def __mod__(self, other, context=None): def __rmod__(self, other, context=None): """Swaps self/other and returns __mod__.""" other = _convert_other(other) + if other is NotImplemented: + return other return other.__mod__(self, context=context) def remainder_near(self, other, context=None): @@ -1332,6 +1356,8 @@ def remainder_near(self, other, context=None): Remainder nearest to 0- abs(remainder-near) <= other/2 """ other = _convert_other(other) + if other is NotImplemented: + return other if self._is_special or other._is_special: ans = self._check_nans(other, context) @@ -1411,6 +1437,8 @@ def __floordiv__(self, other, context=None): def __rfloordiv__(self, other, context=None): """Swaps self/other and returns __floordiv__.""" other = _convert_other(other) + if other is NotImplemented: + return other return other.__floordiv__(self, context=context) def __float__(self): @@ -1661,6 +1689,8 @@ def __pow__(self, n, modulo = None, context=None): If modulo is None (default), don't take it mod modulo. """ n = _convert_other(n) + if n is NotImplemented: + return n if context is None: context = getcontext() @@ -1747,6 +1777,8 @@ def __pow__(self, n, modulo = None, context=None): def __rpow__(self, other, context=None): """Swaps self/other and returns __pow__.""" other = _convert_other(other) + if other is NotImplemented: + return other return other.__pow__(self, context=context) def normalize(self, context=None): @@ -2001,6 +2033,8 @@ def max(self, other, context=None): NaN (and signals if one is sNaN). Also rounds. """ other = _convert_other(other) + if other is NotImplemented: + return other if self._is_special or other._is_special: # if one operand is a quiet NaN and the other is number, then the @@ -2048,6 +2082,8 @@ def min(self, other, context=None): NaN (and signals if one is sNaN). Also rounds. """ other = _convert_other(other) + if other is NotImplemented: + return other if self._is_special or other._is_special: # if one operand is a quiet NaN and the other is number, then the @@ -2874,8 +2910,7 @@ def _convert_other(other): return other if isinstance(other, (int, long)): return Decimal(other) - - raise TypeError, "You can interact Decimal only with int, long or Decimal data types." + return NotImplemented _infinity_map = { 'inf' : 1, diff --git a/Lib/test/test_decimal.py b/Lib/test/test_decimal.py index fc1e0482846..34f034b850b 100644 --- a/Lib/test/test_decimal.py +++ b/Lib/test/test_decimal.py @@ -24,8 +24,6 @@ with the corresponding argument. """ -from __future__ import division - import unittest import glob import os, sys @@ -54,9 +52,9 @@ else: file = __file__ testdir = os.path.dirname(file) or os.curdir -dir = testdir + os.sep + TESTDATADIR + os.sep +directory = testdir + os.sep + TESTDATADIR + os.sep -skip_expected = not os.path.isdir(dir) +skip_expected = not os.path.isdir(directory) # Make sure it actually raises errors when not expected and caught in flags # Slower, since it runs some things several times. @@ -109,7 +107,6 @@ class DecimalTest(unittest.TestCase): Changed for unittest. """ def setUp(self): - global dir self.context = Context() for key in DefaultContext.traps.keys(): DefaultContext.traps[key] = 1 @@ -302,11 +299,11 @@ def change_clamp(self, clamp): # Dynamically build custom test definition for each file in the test # directory and add the definitions to the DecimalTest class. This # procedure insures that new files do not get skipped. -for filename in os.listdir(dir): +for filename in os.listdir(directory): if '.decTest' not in filename: continue head, tail = filename.split('.') - tester = lambda self, f=filename: self.eval_file(dir + f) + tester = lambda self, f=filename: self.eval_file(directory + f) setattr(DecimalTest, 'test_' + head, tester) del filename, head, tail, tester @@ -476,6 +473,52 @@ def test_implicit_from_float(self): def test_implicit_from_Decimal(self): self.assertEqual(Decimal(5) + Decimal(45), Decimal(50)) + def test_rop(self): + # Allow other classes to be trained to interact with Decimals + class E: + def __divmod__(self, other): + return 'divmod ' + str(other) + def __rdivmod__(self, other): + return str(other) + ' rdivmod' + def __lt__(self, other): + return 'lt ' + str(other) + def __gt__(self, other): + return 'gt ' + str(other) + def __le__(self, other): + return 'le ' + str(other) + def __ge__(self, other): + return 'ge ' + str(other) + def __eq__(self, other): + return 'eq ' + str(other) + def __ne__(self, other): + return 'ne ' + str(other) + + self.assertEqual(divmod(E(), Decimal(10)), 'divmod 10') + self.assertEqual(divmod(Decimal(10), E()), '10 rdivmod') + self.assertEqual(eval('Decimal(10) < E()'), 'gt 10') + self.assertEqual(eval('Decimal(10) > E()'), 'lt 10') + self.assertEqual(eval('Decimal(10) <= E()'), 'ge 10') + self.assertEqual(eval('Decimal(10) >= E()'), 'le 10') + self.assertEqual(eval('Decimal(10) == E()'), 'eq 10') + self.assertEqual(eval('Decimal(10) != E()'), 'ne 10') + + # insert operator methods and then exercise them + for sym, lop, rop in ( + ('+', '__add__', '__radd__'), + ('-', '__sub__', '__rsub__'), + ('*', '__mul__', '__rmul__'), + ('/', '__div__', '__rdiv__'), + ('%', '__mod__', '__rmod__'), + ('//', '__floordiv__', '__rfloordiv__'), + ('**', '__pow__', '__rpow__'), + ): + + setattr(E, lop, lambda self, other: 'str' + lop + str(other)) + setattr(E, rop, lambda self, other: str(other) + rop + 'str') + self.assertEqual(eval('E()' + sym + 'Decimal(10)'), + 'str' + lop + '10') + self.assertEqual(eval('Decimal(10)' + sym + 'E()'), + '10' + rop + 'str') class DecimalArithmeticOperatorsTest(unittest.TestCase): '''Unit tests for all arithmetic operators, binary and unary.''' diff --git a/Misc/NEWS b/Misc/NEWS index 1706874c24d..9a63f210381 100644 --- a/Misc/NEWS +++ b/Misc/NEWS @@ -91,6 +91,11 @@ Library - distutils.commands.upload was added to support uploading distribution files to PyPI. +- decimal operator and comparison methods now return NotImplemented + instead of raising a TypeError when interacting with other types. This + allows other classes to implement __radd__ style methods and have them + work as expected. + - Bug #1163325: Decimal infinities failed to hash. Attempting to hash a NaN raised an InvalidOperation instead of a TypeError.