]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
[3.12] gh-112281: Allow `Union` with unhashable `Annotated` metadata (GH-112283)...
authorMiss Islington (bot) <31488909+miss-islington@users.noreply.github.com>
Fri, 1 Mar 2024 18:01:27 +0000 (19:01 +0100)
committerGitHub <noreply@github.com>
Fri, 1 Mar 2024 18:01:27 +0000 (18:01 +0000)
Co-authored-by: Nikita Sobolev <mail@sobolevn.me>
Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
Lib/test/test_types.py
Lib/test/test_typing.py
Lib/typing.py
Misc/NEWS.d/next/Library/2023-11-20-16-15-44.gh-issue-112281.gH4EVk.rst [new file with mode: 0644]

index b86392f43cc5bee5c441380af45a617d575082a9..5ffe4085f095481369b6a6ed0edf71c4e0688d5a 100644 (file)
@@ -709,6 +709,26 @@ class UnionTests(unittest.TestCase):
         self.assertEqual(hash(int | str), hash(str | int))
         self.assertEqual(hash(int | str), hash(typing.Union[int, str]))
 
+    def test_union_of_unhashable(self):
+        class UnhashableMeta(type):
+            __hash__ = None
+
+        class A(metaclass=UnhashableMeta): ...
+        class B(metaclass=UnhashableMeta): ...
+
+        self.assertEqual((A | B).__args__, (A, B))
+        union1 = A | B
+        with self.assertRaises(TypeError):
+            hash(union1)
+
+        union2 = int | B
+        with self.assertRaises(TypeError):
+            hash(union2)
+
+        union3 = A | int
+        with self.assertRaises(TypeError):
+            hash(union3)
+
     def test_instancecheck_and_subclasscheck(self):
         for x in (int | str, typing.Union[int, str]):
             with self.subTest(x=x):
index 7f9c10dd2a54e6bd4411d7007bb17ad41715beb5..e0f71464796bcaa8a6b1cc68d7e74635b68d7253 100644 (file)
@@ -2,10 +2,11 @@ import contextlib
 import collections
 import collections.abc
 from collections import defaultdict
-from functools import lru_cache, wraps
+from functools import lru_cache, wraps, reduce
 import gc
 import inspect
 import itertools
+import operator
 import pickle
 import re
 import sys
@@ -1770,6 +1771,26 @@ class UnionTests(BaseTestCase):
         v = Union[u, Employee]
         self.assertEqual(v, Union[int, float, Employee])
 
+    def test_union_of_unhashable(self):
+        class UnhashableMeta(type):
+            __hash__ = None
+
+        class A(metaclass=UnhashableMeta): ...
+        class B(metaclass=UnhashableMeta): ...
+
+        self.assertEqual(Union[A, B].__args__, (A, B))
+        union1 = Union[A, B]
+        with self.assertRaises(TypeError):
+            hash(union1)
+
+        union2 = Union[int, B]
+        with self.assertRaises(TypeError):
+            hash(union2)
+
+        union3 = Union[A, int]
+        with self.assertRaises(TypeError):
+            hash(union3)
+
     def test_repr(self):
         self.assertEqual(repr(Union), 'typing.Union')
         u = Union[Employee, int]
@@ -5295,10 +5316,8 @@ class OverrideDecoratorTests(BaseTestCase):
         self.assertFalse(hasattr(WithOverride.some, "__override__"))
 
     def test_multiple_decorators(self):
-        import functools
-
         def with_wraps(f):  # similar to `lru_cache` definition
-            @functools.wraps(f)
+            @wraps(f)
             def wrapper(*args, **kwargs):
                 return f(*args, **kwargs)
             return wrapper
@@ -8183,6 +8202,76 @@ class AnnotatedTests(BaseTestCase):
         self.assertEqual(A.__metadata__, (4, 5))
         self.assertEqual(A.__origin__, int)
 
+    def test_deduplicate_from_union(self):
+        # Regular:
+        self.assertEqual(get_args(Annotated[int, 1] | int),
+                         (Annotated[int, 1], int))
+        self.assertEqual(get_args(Union[Annotated[int, 1], int]),
+                         (Annotated[int, 1], int))
+        self.assertEqual(get_args(Annotated[int, 1] | Annotated[int, 2] | int),
+                         (Annotated[int, 1], Annotated[int, 2], int))
+        self.assertEqual(get_args(Union[Annotated[int, 1], Annotated[int, 2], int]),
+                         (Annotated[int, 1], Annotated[int, 2], int))
+        self.assertEqual(get_args(Annotated[int, 1] | Annotated[str, 1] | int),
+                         (Annotated[int, 1], Annotated[str, 1], int))
+        self.assertEqual(get_args(Union[Annotated[int, 1], Annotated[str, 1], int]),
+                         (Annotated[int, 1], Annotated[str, 1], int))
+
+        # Duplicates:
+        self.assertEqual(Annotated[int, 1] | Annotated[int, 1] | int,
+                         Annotated[int, 1] | int)
+        self.assertEqual(Union[Annotated[int, 1], Annotated[int, 1], int],
+                         Union[Annotated[int, 1], int])
+
+        # Unhashable metadata:
+        self.assertEqual(get_args(str | Annotated[int, {}] | Annotated[int, set()] | int),
+                         (str, Annotated[int, {}], Annotated[int, set()], int))
+        self.assertEqual(get_args(Union[str, Annotated[int, {}], Annotated[int, set()], int]),
+                         (str, Annotated[int, {}], Annotated[int, set()], int))
+        self.assertEqual(get_args(str | Annotated[int, {}] | Annotated[str, {}] | int),
+                         (str, Annotated[int, {}], Annotated[str, {}], int))
+        self.assertEqual(get_args(Union[str, Annotated[int, {}], Annotated[str, {}], int]),
+                         (str, Annotated[int, {}], Annotated[str, {}], int))
+
+        self.assertEqual(get_args(Annotated[int, 1] | str | Annotated[str, {}] | int),
+                         (Annotated[int, 1], str, Annotated[str, {}], int))
+        self.assertEqual(get_args(Union[Annotated[int, 1], str, Annotated[str, {}], int]),
+                         (Annotated[int, 1], str, Annotated[str, {}], int))
+
+        import dataclasses
+        @dataclasses.dataclass
+        class ValueRange:
+            lo: int
+            hi: int
+        v = ValueRange(1, 2)
+        self.assertEqual(get_args(Annotated[int, v] | None),
+                         (Annotated[int, v], types.NoneType))
+        self.assertEqual(get_args(Union[Annotated[int, v], None]),
+                         (Annotated[int, v], types.NoneType))
+        self.assertEqual(get_args(Optional[Annotated[int, v]]),
+                         (Annotated[int, v], types.NoneType))
+
+        # Unhashable metadata duplicated:
+        self.assertEqual(Annotated[int, {}] | Annotated[int, {}] | int,
+                         Annotated[int, {}] | int)
+        self.assertEqual(Annotated[int, {}] | Annotated[int, {}] | int,
+                         int | Annotated[int, {}])
+        self.assertEqual(Union[Annotated[int, {}], Annotated[int, {}], int],
+                         Union[Annotated[int, {}], int])
+        self.assertEqual(Union[Annotated[int, {}], Annotated[int, {}], int],
+                         Union[int, Annotated[int, {}]])
+
+    def test_order_in_union(self):
+        expr1 = Annotated[int, 1] | str | Annotated[str, {}] | int
+        for args in itertools.permutations(get_args(expr1)):
+            with self.subTest(args=args):
+                self.assertEqual(expr1, reduce(operator.or_, args))
+
+        expr2 = Union[Annotated[int, 1], str, Annotated[str, {}], int]
+        for args in itertools.permutations(get_args(expr2)):
+            with self.subTest(args=args):
+                self.assertEqual(expr2, Union[args])
+
     def test_specialize(self):
         L = Annotated[List[T], "my decoration"]
         LI = Annotated[List[int], "my decoration"]
@@ -8203,6 +8292,16 @@ class AnnotatedTests(BaseTestCase):
             {Annotated[int, 4, 5], Annotated[int, 4, 5], Annotated[T, 4, 5]},
             {Annotated[int, 4, 5], Annotated[T, 4, 5]}
         )
+        # Unhashable `metadata` raises `TypeError`:
+        a1 = Annotated[int, []]
+        with self.assertRaises(TypeError):
+            hash(a1)
+
+        class A:
+            __hash__ = None
+        a2 = Annotated[int, A()]
+        with self.assertRaises(TypeError):
+            hash(a2)
 
     def test_instantiate(self):
         class C:
index 1e4c725be473b74eb247e1017eb1b3648d2b3f44..7581c16119d851782532a62f70a79f4356471c01 100644 (file)
@@ -314,19 +314,33 @@ def _unpack_args(args):
             newargs.append(arg)
     return newargs
 
-def _deduplicate(params):
+def _deduplicate(params, *, unhashable_fallback=False):
     # Weed out strict duplicates, preserving the first of each occurrence.
-    all_params = set(params)
-    if len(all_params) < len(params):
-        new_params = []
-        for t in params:
-            if t in all_params:
-                new_params.append(t)
-                all_params.remove(t)
-        params = new_params
-        assert not all_params, all_params
-    return params
-
+    try:
+        return dict.fromkeys(params)
+    except TypeError:
+        if not unhashable_fallback:
+            raise
+        # Happens for cases like `Annotated[dict, {'x': IntValidator()}]`
+        return _deduplicate_unhashable(params)
+
+def _deduplicate_unhashable(unhashable_params):
+    new_unhashable = []
+    for t in unhashable_params:
+        if t not in new_unhashable:
+            new_unhashable.append(t)
+    return new_unhashable
+
+def _compare_args_orderless(first_args, second_args):
+    first_unhashable = _deduplicate_unhashable(first_args)
+    second_unhashable = _deduplicate_unhashable(second_args)
+    t = list(second_unhashable)
+    try:
+        for elem in first_unhashable:
+            t.remove(elem)
+    except ValueError:
+        return False
+    return not t
 
 def _remove_dups_flatten(parameters):
     """Internal helper for Union creation and substitution.
@@ -341,7 +355,7 @@ def _remove_dups_flatten(parameters):
         else:
             params.append(p)
 
-    return tuple(_deduplicate(params))
+    return tuple(_deduplicate(params, unhashable_fallback=True))
 
 
 def _flatten_literal_params(parameters):
@@ -1548,7 +1562,10 @@ class _UnionGenericAlias(_NotIterable, _GenericAlias, _root=True):
     def __eq__(self, other):
         if not isinstance(other, (_UnionGenericAlias, types.UnionType)):
             return NotImplemented
-        return set(self.__args__) == set(other.__args__)
+        try:  # fast path
+            return set(self.__args__) == set(other.__args__)
+        except TypeError:  # not hashable, slow path
+            return _compare_args_orderless(self.__args__, other.__args__)
 
     def __hash__(self):
         return hash(frozenset(self.__args__))
diff --git a/Misc/NEWS.d/next/Library/2023-11-20-16-15-44.gh-issue-112281.gH4EVk.rst b/Misc/NEWS.d/next/Library/2023-11-20-16-15-44.gh-issue-112281.gH4EVk.rst
new file mode 100644 (file)
index 0000000..01f6689
--- /dev/null
@@ -0,0 +1,2 @@
+Allow creating :ref:`union of types<types-union>` for
+:class:`typing.Annotated` with unhashable metadata.