mirror of
https://github.com/python/cpython
synced 2024-10-14 13:08:09 +00:00
bpo-46032: Check types in singledispatch's register() at declaration time (GH-30050)
The registry() method of functools.singledispatch() functions checks now the first argument or the first parameter annotation and raises a TypeError if it is not supported. Previously unsupported "types" were ignored (e.g. typing.List[int]) or caused an error at calling time (e.g. list[int]).
This commit is contained in:
parent
1b30660c3b
commit
078abb676c
|
@ -740,6 +740,7 @@ def _compose_mro(cls, types):
|
|||
# Remove entries which are already present in the __mro__ or unrelated.
|
||||
def is_related(typ):
|
||||
return (typ not in bases and hasattr(typ, '__mro__')
|
||||
and not isinstance(typ, GenericAlias)
|
||||
and issubclass(cls, typ))
|
||||
types = [n for n in types if is_related(n)]
|
||||
# Remove entries which are strict bases of other entries (they will end up
|
||||
|
@ -841,9 +842,13 @@ def _is_union_type(cls):
|
|||
from typing import get_origin, Union
|
||||
return get_origin(cls) in {Union, types.UnionType}
|
||||
|
||||
def _is_valid_union_type(cls):
|
||||
def _is_valid_dispatch_type(cls):
|
||||
if isinstance(cls, type) and not isinstance(cls, GenericAlias):
|
||||
return True
|
||||
from typing import get_args
|
||||
return _is_union_type(cls) and all(isinstance(arg, type) for arg in get_args(cls))
|
||||
return (_is_union_type(cls) and
|
||||
all(isinstance(arg, type) and not isinstance(arg, GenericAlias)
|
||||
for arg in get_args(cls)))
|
||||
|
||||
def register(cls, func=None):
|
||||
"""generic_func.register(cls, func) -> func
|
||||
|
@ -852,9 +857,15 @@ def register(cls, func=None):
|
|||
|
||||
"""
|
||||
nonlocal cache_token
|
||||
if func is None:
|
||||
if isinstance(cls, type) or _is_valid_union_type(cls):
|
||||
if _is_valid_dispatch_type(cls):
|
||||
if func is None:
|
||||
return lambda f: register(cls, f)
|
||||
else:
|
||||
if func is not None:
|
||||
raise TypeError(
|
||||
f"Invalid first argument to `register()`. "
|
||||
f"{cls!r} is not a class or union type."
|
||||
)
|
||||
ann = getattr(cls, '__annotations__', {})
|
||||
if not ann:
|
||||
raise TypeError(
|
||||
|
@ -867,7 +878,7 @@ def register(cls, func=None):
|
|||
# only import typing if annotation parsing is necessary
|
||||
from typing import get_type_hints
|
||||
argname, cls = next(iter(get_type_hints(func).items()))
|
||||
if not isinstance(cls, type) and not _is_valid_union_type(cls):
|
||||
if not _is_valid_dispatch_type(cls):
|
||||
if _is_union_type(cls):
|
||||
raise TypeError(
|
||||
f"Invalid annotation for {argname!r}. "
|
||||
|
|
|
@ -2722,6 +2722,74 @@ def _(arg: int | float):
|
|||
self.assertEqual(f(1), "types.UnionType")
|
||||
self.assertEqual(f(1.0), "types.UnionType")
|
||||
|
||||
def test_register_genericalias(self):
|
||||
@functools.singledispatch
|
||||
def f(arg):
|
||||
return "default"
|
||||
|
||||
with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
|
||||
f.register(list[int], lambda arg: "types.GenericAlias")
|
||||
with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
|
||||
f.register(typing.List[int], lambda arg: "typing.GenericAlias")
|
||||
with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
|
||||
f.register(list[int] | str, lambda arg: "types.UnionTypes(types.GenericAlias)")
|
||||
with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
|
||||
f.register(typing.List[float] | bytes, lambda arg: "typing.Union[typing.GenericAlias]")
|
||||
with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
|
||||
f.register(typing.Any, lambda arg: "typing.Any")
|
||||
|
||||
self.assertEqual(f([1]), "default")
|
||||
self.assertEqual(f([1.0]), "default")
|
||||
self.assertEqual(f(""), "default")
|
||||
self.assertEqual(f(b""), "default")
|
||||
|
||||
def test_register_genericalias_decorator(self):
|
||||
@functools.singledispatch
|
||||
def f(arg):
|
||||
return "default"
|
||||
|
||||
with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
|
||||
f.register(list[int])
|
||||
with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
|
||||
f.register(typing.List[int])
|
||||
with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
|
||||
f.register(list[int] | str)
|
||||
with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
|
||||
f.register(typing.List[int] | str)
|
||||
with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
|
||||
f.register(typing.Any)
|
||||
|
||||
def test_register_genericalias_annotation(self):
|
||||
@functools.singledispatch
|
||||
def f(arg):
|
||||
return "default"
|
||||
|
||||
with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"):
|
||||
@f.register
|
||||
def _(arg: list[int]):
|
||||
return "types.GenericAlias"
|
||||
with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"):
|
||||
@f.register
|
||||
def _(arg: typing.List[float]):
|
||||
return "typing.GenericAlias"
|
||||
with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"):
|
||||
@f.register
|
||||
def _(arg: list[int] | str):
|
||||
return "types.UnionType(types.GenericAlias)"
|
||||
with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"):
|
||||
@f.register
|
||||
def _(arg: typing.List[float] | bytes):
|
||||
return "typing.Union[typing.GenericAlias]"
|
||||
with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"):
|
||||
@f.register
|
||||
def _(arg: typing.Any):
|
||||
return "typing.Any"
|
||||
|
||||
self.assertEqual(f([1]), "default")
|
||||
self.assertEqual(f([1.0]), "default")
|
||||
self.assertEqual(f(""), "default")
|
||||
self.assertEqual(f(b""), "default")
|
||||
|
||||
|
||||
class CachedCostItem:
|
||||
_cost = 1
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
The ``registry()`` method of :func:`functools.singledispatch` functions
|
||||
checks now the first argument or the first parameter annotation and raises a
|
||||
TypeError if it is not supported. Previously unsupported "types" were
|
||||
ignored (e.g. ``typing.List[int]``) or caused an error at calling time (e.g.
|
||||
``list[int]``).
|
Loading…
Reference in a new issue