]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
bpo-46014: Add ability to use typing.Union with singledispatch (GH-30017)
authorYurii Karabas <1998uriyyo@gmail.com>
Fri, 10 Dec 2021 23:27:55 +0000 (01:27 +0200)
committerGitHub <noreply@github.com>
Fri, 10 Dec 2021 23:27:55 +0000 (00:27 +0100)
Lib/functools.py
Lib/test/test_functools.py
Misc/NEWS.d/next/Library/2021-12-10-03-13-57.bpo-46014.3xYdST.rst [new file with mode: 0644]

index 77ec852805c10456cdb5141942c5bbdc81457a04..ccac6f89996b61776b9dcc998a5a2f20d23599aa 100644 (file)
@@ -837,6 +837,14 @@ def singledispatch(func):
             dispatch_cache[cls] = impl
         return impl
 
+    def _is_union_type(cls):
+        from typing import get_origin, Union
+        return get_origin(cls) in {Union, types.UnionType}
+
+    def _is_valid_union_type(cls):
+        from typing import get_args
+        return _is_union_type(cls) and all(isinstance(arg, type) for arg in get_args(cls))
+
     def register(cls, func=None):
         """generic_func.register(cls, func) -> func
 
@@ -845,7 +853,7 @@ def singledispatch(func):
         """
         nonlocal cache_token
         if func is None:
-            if isinstance(cls, type):
+            if isinstance(cls, type) or _is_valid_union_type(cls):
                 return lambda f: register(cls, f)
             ann = getattr(cls, '__annotations__', {})
             if not ann:
@@ -859,12 +867,25 @@ def singledispatch(func):
             # only import typing if annotation parsing is necessary
             from typing import get_type_hints
             argname, cls = next(iter(get_type_hints(func).items()))
-            if not isinstance(cls, type):
-                raise TypeError(
-                    f"Invalid annotation for {argname!r}. "
-                    f"{cls!r} is not a class."
-                )
-        registry[cls] = func
+            if not isinstance(cls, type) and not _is_valid_union_type(cls):
+                if _is_union_type(cls):
+                    raise TypeError(
+                        f"Invalid annotation for {argname!r}. "
+                        f"{cls!r} not all arguments are classes."
+                    )
+                else:
+                    raise TypeError(
+                        f"Invalid annotation for {argname!r}. "
+                        f"{cls!r} is not a class."
+                    )
+
+        if _is_union_type(cls):
+            from typing import get_args
+
+            for arg in get_args(cls):
+                registry[arg] = func
+        else:
+            registry[cls] = func
         if cache_token is None and hasattr(cls, '__abstractmethods__'):
             cache_token = get_cache_token()
         dispatch_cache.clear()
index 08cf457cc17db549043f6e58edcf0c0b9035e2aa..755ac038792b7959797e865b6142fda9224aee21 100644 (file)
@@ -2684,6 +2684,17 @@ class TestSingleDispatch(unittest.TestCase):
             'typing.Iterable[str] is not a class.'
         ))
 
+        with self.assertRaises(TypeError) as exc:
+            @i.register
+            def _(arg: typing.Union[int, typing.Iterable[str]]):
+                return "Invalid Union"
+        self.assertTrue(str(exc.exception).startswith(
+            "Invalid annotation for 'arg'."
+        ))
+        self.assertTrue(str(exc.exception).endswith(
+            'typing.Union[int, typing.Iterable[str]] not all arguments are classes.'
+        ))
+
     def test_invalid_positional_argument(self):
         @functools.singledispatch
         def f(*args):
@@ -2692,6 +2703,25 @@ class TestSingleDispatch(unittest.TestCase):
         with self.assertRaisesRegex(TypeError, msg):
             f()
 
+    def test_union(self):
+        @functools.singledispatch
+        def f(arg):
+            return "default"
+
+        @f.register
+        def _(arg: typing.Union[str, bytes]):
+            return "typing.Union"
+
+        @f.register
+        def _(arg: int | float):
+            return "types.UnionType"
+
+        self.assertEqual(f([]), "default")
+        self.assertEqual(f(""), "typing.Union")
+        self.assertEqual(f(b""), "typing.Union")
+        self.assertEqual(f(1), "types.UnionType")
+        self.assertEqual(f(1.0), "types.UnionType")
+
 
 class CachedCostItem:
     _cost = 1
diff --git a/Misc/NEWS.d/next/Library/2021-12-10-03-13-57.bpo-46014.3xYdST.rst b/Misc/NEWS.d/next/Library/2021-12-10-03-13-57.bpo-46014.3xYdST.rst
new file mode 100644 (file)
index 0000000..90aacaf
--- /dev/null
@@ -0,0 +1,2 @@
+Add ability to use ``typing.Union`` and ``types.UnionType`` as dispatch
+argument to ``functools.singledispatch``. Patch provided by Yurii Karabas.