import contextlib
import collections
from collections import defaultdict
-from functools import lru_cache, wraps
+from functools import lru_cache, wraps, reduce
import inspect
import itertools
import gc
+import operator
import pickle
import re
import sys
v = Union[u, Employee]
self.assertEqual(v, Union[int, float, Employee])
+ def test_union_of_unhashable(self):
+ class UnhashableMeta(type):
+ __hash__ = None
+
+ class A(metaclass=UnhashableMeta): ...
+ class B(metaclass=UnhashableMeta): ...
+
+ self.assertEqual(Union[A, B].__args__, (A, B))
+ union1 = Union[A, B]
+ with self.assertRaises(TypeError):
+ hash(union1)
+
+ union2 = Union[int, B]
+ with self.assertRaises(TypeError):
+ hash(union2)
+
+ union3 = Union[A, int]
+ with self.assertRaises(TypeError):
+ hash(union3)
+
def test_repr(self):
self.assertEqual(repr(Union), 'typing.Union')
u = Union[Employee, int]
self.assertEqual(A.__metadata__, (4, 5))
self.assertEqual(A.__origin__, int)
+ def test_deduplicate_from_union(self):
+ # Regular:
+ self.assertEqual(get_args(Annotated[int, 1] | int),
+ (Annotated[int, 1], int))
+ self.assertEqual(get_args(Union[Annotated[int, 1], int]),
+ (Annotated[int, 1], int))
+ self.assertEqual(get_args(Annotated[int, 1] | Annotated[int, 2] | int),
+ (Annotated[int, 1], Annotated[int, 2], int))
+ self.assertEqual(get_args(Union[Annotated[int, 1], Annotated[int, 2], int]),
+ (Annotated[int, 1], Annotated[int, 2], int))
+ self.assertEqual(get_args(Annotated[int, 1] | Annotated[str, 1] | int),
+ (Annotated[int, 1], Annotated[str, 1], int))
+ self.assertEqual(get_args(Union[Annotated[int, 1], Annotated[str, 1], int]),
+ (Annotated[int, 1], Annotated[str, 1], int))
+
+ # Duplicates:
+ self.assertEqual(Annotated[int, 1] | Annotated[int, 1] | int,
+ Annotated[int, 1] | int)
+ self.assertEqual(Union[Annotated[int, 1], Annotated[int, 1], int],
+ Union[Annotated[int, 1], int])
+
+ # Unhashable metadata:
+ self.assertEqual(get_args(str | Annotated[int, {}] | Annotated[int, set()] | int),
+ (str, Annotated[int, {}], Annotated[int, set()], int))
+ self.assertEqual(get_args(Union[str, Annotated[int, {}], Annotated[int, set()], int]),
+ (str, Annotated[int, {}], Annotated[int, set()], int))
+ self.assertEqual(get_args(str | Annotated[int, {}] | Annotated[str, {}] | int),
+ (str, Annotated[int, {}], Annotated[str, {}], int))
+ self.assertEqual(get_args(Union[str, Annotated[int, {}], Annotated[str, {}], int]),
+ (str, Annotated[int, {}], Annotated[str, {}], int))
+
+ self.assertEqual(get_args(Annotated[int, 1] | str | Annotated[str, {}] | int),
+ (Annotated[int, 1], str, Annotated[str, {}], int))
+ self.assertEqual(get_args(Union[Annotated[int, 1], str, Annotated[str, {}], int]),
+ (Annotated[int, 1], str, Annotated[str, {}], int))
+
+ import dataclasses
+ @dataclasses.dataclass
+ class ValueRange:
+ lo: int
+ hi: int
+ v = ValueRange(1, 2)
+ self.assertEqual(get_args(Annotated[int, v] | None),
+ (Annotated[int, v], types.NoneType))
+ self.assertEqual(get_args(Union[Annotated[int, v], None]),
+ (Annotated[int, v], types.NoneType))
+ self.assertEqual(get_args(Optional[Annotated[int, v]]),
+ (Annotated[int, v], types.NoneType))
+
+ # Unhashable metadata duplicated:
+ self.assertEqual(Annotated[int, {}] | Annotated[int, {}] | int,
+ Annotated[int, {}] | int)
+ self.assertEqual(Annotated[int, {}] | Annotated[int, {}] | int,
+ int | Annotated[int, {}])
+ self.assertEqual(Union[Annotated[int, {}], Annotated[int, {}], int],
+ Union[Annotated[int, {}], int])
+ self.assertEqual(Union[Annotated[int, {}], Annotated[int, {}], int],
+ Union[int, Annotated[int, {}]])
+
+ def test_order_in_union(self):
+ expr1 = Annotated[int, 1] | str | Annotated[str, {}] | int
+ for args in itertools.permutations(get_args(expr1)):
+ with self.subTest(args=args):
+ self.assertEqual(expr1, reduce(operator.or_, args))
+
+ expr2 = Union[Annotated[int, 1], str, Annotated[str, {}], int]
+ for args in itertools.permutations(get_args(expr2)):
+ with self.subTest(args=args):
+ self.assertEqual(expr2, Union[args])
+
def test_specialize(self):
L = Annotated[List[T], "my decoration"]
LI = Annotated[List[int], "my decoration"]
{Annotated[int, 4, 5], Annotated[int, 4, 5], Annotated[T, 4, 5]},
{Annotated[int, 4, 5], Annotated[T, 4, 5]}
)
+ # Unhashable `metadata` raises `TypeError`:
+ a1 = Annotated[int, []]
+ with self.assertRaises(TypeError):
+ hash(a1)
+
+ class A:
+ __hash__ = None
+ a2 = Annotated[int, A()]
+ with self.assertRaises(TypeError):
+ hash(a2)
def test_instantiate(self):
class C:
newargs.append(arg)
return newargs
-def _deduplicate(params):
+def _deduplicate(params, *, unhashable_fallback=False):
# Weed out strict duplicates, preserving the first of each occurrence.
- all_params = set(params)
- if len(all_params) < len(params):
- new_params = []
- for t in params:
- if t in all_params:
- new_params.append(t)
- all_params.remove(t)
- params = new_params
- assert not all_params, all_params
- return params
-
+ try:
+ return dict.fromkeys(params)
+ except TypeError:
+ if not unhashable_fallback:
+ raise
+ # Happens for cases like `Annotated[dict, {'x': IntValidator()}]`
+ return _deduplicate_unhashable(params)
+
+def _deduplicate_unhashable(unhashable_params):
+ new_unhashable = []
+ for t in unhashable_params:
+ if t not in new_unhashable:
+ new_unhashable.append(t)
+ return new_unhashable
+
+def _compare_args_orderless(first_args, second_args):
+ first_unhashable = _deduplicate_unhashable(first_args)
+ second_unhashable = _deduplicate_unhashable(second_args)
+ t = list(second_unhashable)
+ try:
+ for elem in first_unhashable:
+ t.remove(elem)
+ except ValueError:
+ return False
+ return not t
def _remove_dups_flatten(parameters):
"""Internal helper for Union creation and substitution.
else:
params.append(p)
- return tuple(_deduplicate(params))
+ return tuple(_deduplicate(params, unhashable_fallback=True))
def _flatten_literal_params(parameters):
def __eq__(self, other):
if not isinstance(other, (_UnionGenericAlias, types.UnionType)):
return NotImplemented
- return set(self.__args__) == set(other.__args__)
+ try: # fast path
+ return set(self.__args__) == set(other.__args__)
+ except TypeError: # not hashable, slow path
+ return _compare_args_orderless(self.__args__, other.__args__)
def __hash__(self):
return hash(frozenset(self.__args__))