]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
bpo-46032: Check types in singledispatch's register() at declaration time (GH-30050)
authorSerhiy Storchaka <storchaka@gmail.com>
Sat, 25 Dec 2021 12:16:14 +0000 (14:16 +0200)
committerGitHub <noreply@github.com>
Sat, 25 Dec 2021 12:16:14 +0000 (14:16 +0200)
The registry() method of functools.singledispatch() functions checks now
the first argument or the first parameter annotation and raises a TypeError if it is
not supported. Previously unsupported "types" were ignored (e.g. typing.List[int])
or caused an error at calling time (e.g. list[int]).

Lib/functools.py
Lib/test/test_functools.py
Misc/NEWS.d/next/Library/2021-12-11-15-45-07.bpo-46032.HmciLT.rst [new file with mode: 0644]

index ccac6f89996b61776b9dcc998a5a2f20d23599aa..91b678c2269662d3bb058c1f90f7af4dd64a7952 100644 (file)
@@ -740,6 +740,7 @@ def _compose_mro(cls, types):
     # Remove entries which are already present in the __mro__ or unrelated.
     def is_related(typ):
         return (typ not in bases and hasattr(typ, '__mro__')
+                                 and not isinstance(typ, GenericAlias)
                                  and issubclass(cls, typ))
     types = [n for n in types if is_related(n)]
     # Remove entries which are strict bases of other entries (they will end up
@@ -841,9 +842,13 @@ def singledispatch(func):
         from typing import get_origin, Union
         return get_origin(cls) in {Union, types.UnionType}
 
-    def _is_valid_union_type(cls):
+    def _is_valid_dispatch_type(cls):
+        if isinstance(cls, type) and not isinstance(cls, GenericAlias):
+            return True
         from typing import get_args
-        return _is_union_type(cls) and all(isinstance(arg, type) for arg in get_args(cls))
+        return (_is_union_type(cls) and
+                all(isinstance(arg, type) and not isinstance(arg, GenericAlias)
+                    for arg in get_args(cls)))
 
     def register(cls, func=None):
         """generic_func.register(cls, func) -> func
@@ -852,9 +857,15 @@ def singledispatch(func):
 
         """
         nonlocal cache_token
-        if func is None:
-            if isinstance(cls, type) or _is_valid_union_type(cls):
+        if _is_valid_dispatch_type(cls):
+            if func is None:
                 return lambda f: register(cls, f)
+        else:
+            if func is not None:
+                raise TypeError(
+                    f"Invalid first argument to `register()`. "
+                    f"{cls!r} is not a class or union type."
+                )
             ann = getattr(cls, '__annotations__', {})
             if not ann:
                 raise TypeError(
@@ -867,7 +878,7 @@ def singledispatch(func):
             # only import typing if annotation parsing is necessary
             from typing import get_type_hints
             argname, cls = next(iter(get_type_hints(func).items()))
-            if not isinstance(cls, type) and not _is_valid_union_type(cls):
+            if not _is_valid_dispatch_type(cls):
                 if _is_union_type(cls):
                     raise TypeError(
                         f"Invalid annotation for {argname!r}. "
index 755ac038792b7959797e865b6142fda9224aee21..70ae8e06bb47584cecc67be3a73afd2eb42d087c 100644 (file)
@@ -2722,6 +2722,74 @@ class TestSingleDispatch(unittest.TestCase):
         self.assertEqual(f(1), "types.UnionType")
         self.assertEqual(f(1.0), "types.UnionType")
 
+    def test_register_genericalias(self):
+        @functools.singledispatch
+        def f(arg):
+            return "default"
+
+        with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
+            f.register(list[int], lambda arg: "types.GenericAlias")
+        with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
+            f.register(typing.List[int], lambda arg: "typing.GenericAlias")
+        with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
+            f.register(list[int] | str, lambda arg: "types.UnionTypes(types.GenericAlias)")
+        with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
+            f.register(typing.List[float] | bytes, lambda arg: "typing.Union[typing.GenericAlias]")
+        with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
+            f.register(typing.Any, lambda arg: "typing.Any")
+
+        self.assertEqual(f([1]), "default")
+        self.assertEqual(f([1.0]), "default")
+        self.assertEqual(f(""), "default")
+        self.assertEqual(f(b""), "default")
+
+    def test_register_genericalias_decorator(self):
+        @functools.singledispatch
+        def f(arg):
+            return "default"
+
+        with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
+            f.register(list[int])
+        with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
+            f.register(typing.List[int])
+        with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
+            f.register(list[int] | str)
+        with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
+            f.register(typing.List[int] | str)
+        with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
+            f.register(typing.Any)
+
+    def test_register_genericalias_annotation(self):
+        @functools.singledispatch
+        def f(arg):
+            return "default"
+
+        with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"):
+            @f.register
+            def _(arg: list[int]):
+                return "types.GenericAlias"
+        with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"):
+            @f.register
+            def _(arg: typing.List[float]):
+                return "typing.GenericAlias"
+        with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"):
+            @f.register
+            def _(arg: list[int] | str):
+                return "types.UnionType(types.GenericAlias)"
+        with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"):
+            @f.register
+            def _(arg: typing.List[float] | bytes):
+                return "typing.Union[typing.GenericAlias]"
+        with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"):
+            @f.register
+            def _(arg: typing.Any):
+                return "typing.Any"
+
+        self.assertEqual(f([1]), "default")
+        self.assertEqual(f([1.0]), "default")
+        self.assertEqual(f(""), "default")
+        self.assertEqual(f(b""), "default")
+
 
 class CachedCostItem:
     _cost = 1
diff --git a/Misc/NEWS.d/next/Library/2021-12-11-15-45-07.bpo-46032.HmciLT.rst b/Misc/NEWS.d/next/Library/2021-12-11-15-45-07.bpo-46032.HmciLT.rst
new file mode 100644 (file)
index 0000000..97a553d
--- /dev/null
@@ -0,0 +1,5 @@
+The ``registry()`` method of :func:`functools.singledispatch` functions
+checks now the first argument or the first parameter annotation and raises a
+TypeError if it is not supported. Previously unsupported "types" were
+ignored (e.g. ``typing.List[int]``) or caused an error at calling time (e.g.
+``list[int]``).