From: Mike Bayer Date: Mon, 24 Jan 2022 23:13:05 +0000 (-0500) Subject: restore set-as-superclass for OrderedSet X-Git-Tag: rel_2_0_0b1~510^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=5aee5fe12afdeb4569e588344f00aa56c9250215;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git restore set-as-superclass for OrderedSet OrderedSet again subclasses set, spent some time with the stubs at https://github.com/python/typeshed/blob/master/stdlib/builtins.pyi#L887 to more deeply understand what they are doing here so that we can type check fully. Change-Id: Iec9b5ab43befd30e1f2c5cc40e59ab852dd28e75 --- diff --git a/lib/sqlalchemy/cyextension/collections.pyx b/lib/sqlalchemy/cyextension/collections.pyx index 5a344da432..c33a6e4a50 100644 --- a/lib/sqlalchemy/cyextension/collections.pyx +++ b/lib/sqlalchemy/cyextension/collections.pyx @@ -22,53 +22,52 @@ cdef list cunique_list(seq, hashfunc=None): def unique_list(seq, hashfunc=None): return cunique_list(seq, hashfunc) -cdef class OrderedSet: +cdef class OrderedSet(set): cdef list _list - cdef set _set def __init__(self, d=None): + set.__init__(self) if d is not None: self._list = cunique_list(d) - self._set = set(self._list) + set.update(self, self._list) else: self._list = [] - self._set = set() cdef OrderedSet _copy(self): cdef OrderedSet cp = OrderedSet.__new__(OrderedSet) cp._list = list(self._list) - cp._set = set(cp._list) + set.update(cp, cp._list) return cp cdef OrderedSet _from_list(self, list new_list): cdef OrderedSet new = OrderedSet.__new__(OrderedSet) new._list = new_list - new._set = set(new_list) + set.update(new, new_list) return new def add(self, element): if element not in self: self._list.append(element) - PySet_Add(self._set, element) + PySet_Add(self, element) def remove(self, element): # set.remove will raise if element is not in self - self._set.remove(element) + set.remove(self, element) self._list.remove(element) def insert(self, Py_ssize_t pos, element): if element not in self: self._list.insert(pos, element) - PySet_Add(self._set, element) + PySet_Add(self, element) def discard(self, element): if element in self: - self._set.remove(element) + set.remove(self, element) self._list.remove(element) def clear(self): - self._set.clear() + set.clear(self) self._list = [] def __getitem__(self, key): @@ -85,34 +84,22 @@ cdef class OrderedSet: __str__ = __repr__ - def update(self, *iterables): - for iterable in iterables: - for e in iterable: - if e not in self: - self._list.append(e) - self._set.add(e) + def update(self, iterable): + for e in iterable: + if e not in self: + self._list.append(e) + set.add(self, e) + return self def __ior__(self, iterable): - self.update(iterable) - return self + return self.update(iterable) - def union(self, other): + def union(self, *other): result = self._copy() - result.update(other) + for o in other: + result.update(o) return result - def __len__(self) -> int: - return len(self._set) - - def __eq__(self, other): - return self._set == other - - def __ne__(self, other): - return self._set != other - - def __contains__(self, element): - return element in self._set - def __or__(self, other): return self.union(other) @@ -124,8 +111,8 @@ cdef class OrderedSet: other_set = set(other) return other_set - def intersection(self, other): - cdef set other_set = self._to_set(other) + def intersection(self, *other): + cdef other_set = set.intersection(self, *other) return self._from_list([a for a in self._list if a in other_set]) def __and__(self, other): @@ -141,17 +128,16 @@ cdef class OrderedSet: def __xor__(self, other): return self.symmetric_difference(other) - def difference(self, other): - cdef set other_set = self._to_set(other) - return self._from_list([a for a in self._list if a not in other_set]) + def difference(self, *other): + cdef other_set = set.difference(self, *other) + return self._from_list([a for a in self._list if a in other_set]) def __sub__(self, other): return self.difference(other) - def intersection_update(self, other): - cdef set other_set = self._to_set(other) - set.intersection_update(self, other_set) - self._list = [a for a in self._list if a in other_set] + def intersection_update(self, *other): + set.intersection_update(self, *other) + self._list = [a for a in self._list if a in self] def __iand__(self, other): self.intersection_update(other) @@ -166,15 +152,14 @@ cdef class OrderedSet: self.symmetric_difference_update(other) return self - def difference_update(self, other): - set.difference_update(self, other) + def difference_update(self, *other): + set.difference_update(self, *other) self._list = [a for a in self._list if a in self] def __isub__(self, other): self.difference_update(other) return self - cdef object cy_id(object item): return PyLong_FromLong( (item)) diff --git a/lib/sqlalchemy/util/_py_collections.py b/lib/sqlalchemy/util/_py_collections.py index a4e4b8b5db..7914507cdd 100644 --- a/lib/sqlalchemy/util/_py_collections.py +++ b/lib/sqlalchemy/util/_py_collections.py @@ -1,7 +1,8 @@ from itertools import filterfalse +from typing import AbstractSet from typing import Any +from typing import cast from typing import Dict -from typing import Generic from typing import Iterable from typing import Iterator from typing import List @@ -9,6 +10,7 @@ from typing import NoReturn from typing import Optional from typing import Set from typing import TypeVar +from typing import Union _T = TypeVar("_T", bound=Any) _KT = TypeVar("_KT", bound=Any) @@ -96,19 +98,20 @@ class immutabledict(ImmutableDictBase[_KT, _VT]): return "immutabledict(%s)" % dict.__repr__(self) -class OrderedSet(Generic[_T]): - __slots__ = ("_list", "_set", "__weakref__") +_S = TypeVar("_S", bound=Any) + + +class OrderedSet(Set[_T]): + __slots__ = ("_list",) _list: List[_T] - _set: Set[_T] def __init__(self, d=None): if d is not None: self._list = unique_list(d) - self._set = set(self._list) + super().update(self._list) else: self._list = [] - self._set = set() def __reduce__(self): return (OrderedSet, (self._list,)) @@ -116,44 +119,26 @@ class OrderedSet(Generic[_T]): def add(self, element: _T) -> None: if element not in self: self._list.append(element) - self._set.add(element) + super().add(element) def remove(self, element: _T) -> None: - self._set.remove(element) + super().remove(element) self._list.remove(element) def insert(self, pos: int, element: _T) -> None: if element not in self: self._list.insert(pos, element) - self._set.add(element) + super().add(element) def discard(self, element: _T) -> None: if element in self: self._list.remove(element) - self._set.remove(element) + super().remove(element) def clear(self) -> None: - self._set.clear() + super().clear() self._list = [] - def __len__(self) -> int: - return len(self._set) - - def __eq__(self, other): - if not isinstance(other, OrderedSet): - return self._set == other - else: - return self._set == other._set - - def __ne__(self, other): - if not isinstance(other, OrderedSet): - return self._set != other - else: - return self._set != other._set - - def __contains__(self, element: Any) -> bool: - return element in self._set - def __getitem__(self, key: int) -> _T: return self._list[key] @@ -173,25 +158,27 @@ class OrderedSet(Generic[_T]): for e in iterable: if e not in self: self._list.append(e) - self._set.add(e) + super().add(e) - def __ior__(self, other: Iterable[_T]) -> "OrderedSet[_T]": - self.update(other) - return self + def __ior__(self, other: AbstractSet[_S]) -> "OrderedSet[Union[_T, _S]]": + self.update(other) # type: ignore + return self # type: ignore - def union(self, other: Iterable[_T]) -> "OrderedSet[_T]": - result = self.__class__(self) - result.update(other) + def union(self, *other: Iterable[_S]) -> "OrderedSet[Union[_T, _S]]": + result: "OrderedSet[Union[_T, _S]]" = self.__class__(self) # type: ignore # noqa E501 + for o in other: + result.update(o) return result - def __or__(self, other: Iterable[_T]) -> "OrderedSet[_T]": + def __or__(self, other: AbstractSet[_S]) -> "OrderedSet[Union[_T, _S]]": return self.union(other) - def intersection(self, other: Iterable[_T]) -> "OrderedSet[_T]": - other = other if isinstance(other, set) else set(other) - return self.__class__(a for a in self if a in other) + def intersection(self, *other: Iterable[Any]) -> "OrderedSet[_T]": + other_set: Set[Any] = set() + other_set.update(*other) + return self.__class__(a for a in self if a in other_set) - def __and__(self, other: Iterable[_T]) -> "OrderedSet[_T]": + def __and__(self, other: AbstractSet[object]) -> "OrderedSet[_T]": return self.intersection(other) def symmetric_difference(self, other: Iterable[_T]) -> "OrderedSet[_T]": @@ -200,39 +187,40 @@ class OrderedSet(Generic[_T]): result.update(a for a in other if a not in self) return result - def __xor__(self, other: Iterable[_T]) -> "OrderedSet[_T]": - return self.symmetric_difference(other) + def __xor__(self, other: AbstractSet[_S]) -> "OrderedSet[Union[_T, _S]]": + return cast("OrderedSet[Union[_T, _S]]", self).symmetric_difference( + other + ) - def difference(self, other: Iterable[_T]) -> "OrderedSet[_T]": - other = other if isinstance(other, set) else set(other) - return self.__class__(a for a in self if a not in other) + def difference(self, *other: Iterable[Any]) -> "OrderedSet[_T]": + other_set = super().difference(*other) + return self.__class__(a for a in self._list if a in other_set) - def __sub__(self, other: Iterable[_T]) -> "OrderedSet[_T]": + def __sub__(self, other: AbstractSet[_T | None]) -> "OrderedSet[_T]": return self.difference(other) - def intersection_update(self, other: Iterable[_T]) -> None: - other = other if isinstance(other, set) else set(other) - self._set.intersection_update(other) - self._list = [a for a in self._list if a in other] + def intersection_update(self, *other: Iterable[Any]) -> None: + super().intersection_update(*other) + self._list = [a for a in self._list if a in self] - def __iand__(self, other: Iterable[_T]) -> "OrderedSet[_T]": + def __iand__(self, other: AbstractSet[object]) -> "OrderedSet[_T]": self.intersection_update(other) return self - def symmetric_difference_update(self, other: Iterable[_T]) -> None: - self._set.symmetric_difference_update(other) + def symmetric_difference_update(self, other: Iterable[Any]) -> None: + super().symmetric_difference_update(other) self._list = [a for a in self._list if a in self] self._list += [a for a in other if a in self] - def __ixor__(self, other: Iterable[_T]) -> "OrderedSet[_T]": + def __ixor__(self, other: AbstractSet[_S]) -> "OrderedSet[Union[_T, _S]]": self.symmetric_difference_update(other) - return self + return cast("OrderedSet[Union[_T, _S]]", self) - def difference_update(self, other: Iterable[_T]) -> None: - self._set.difference_update(other) + def difference_update(self, *other: Iterable[Any]) -> None: + super().difference_update(*other) self._list = [a for a in self._list if a in self] - def __isub__(self, other: Iterable[_T]) -> "OrderedSet[_T]": + def __isub__(self, other: AbstractSet[_T | None]) -> "OrderedSet[_T]": self.difference_update(other) return self