]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
restore set-as-superclass for OrderedSet
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 24 Jan 2022 23:13:05 +0000 (18:13 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 25 Jan 2022 16:18:55 +0000 (11:18 -0500)
OrderedSet again subclasses set, spent some time
with the stubs at
https://github.com/python/typeshed/blob/master/stdlib/builtins.pyi#L887
to more deeply understand what they are doing here
so that we can type check fully.

Change-Id: Iec9b5ab43befd30e1f2c5cc40e59ab852dd28e75

lib/sqlalchemy/cyextension/collections.pyx
lib/sqlalchemy/util/_py_collections.py

index 5a344da43224bd1348bdadcdad13c2d7d9bc4784..c33a6e4a508621c51ddb6e07838aed2352395f7c 100644 (file)
@@ -22,53 +22,52 @@ cdef list cunique_list(seq, hashfunc=None):
 def unique_list(seq, hashfunc=None):
     return cunique_list(seq, hashfunc)
 
-cdef class OrderedSet:
+cdef class OrderedSet(set):
 
     cdef list _list
-    cdef set _set
 
     def __init__(self, d=None):
+        set.__init__(self)
         if d is not None:
             self._list = cunique_list(d)
-            self._set = set(self._list)
+            set.update(self, self._list)
         else:
             self._list = []
-            self._set = set()
 
     cdef OrderedSet _copy(self):
         cdef OrderedSet cp = OrderedSet.__new__(OrderedSet)
         cp._list = list(self._list)
-        cp._set = set(cp._list)
+        set.update(cp, cp._list)
         return cp
 
     cdef OrderedSet _from_list(self, list new_list):
         cdef OrderedSet new = OrderedSet.__new__(OrderedSet)
         new._list = new_list
-        new._set = set(new_list)
+        set.update(new, new_list)
         return new
 
     def add(self, element):
         if element not in self:
             self._list.append(element)
-            PySet_Add(self._set, element)
+            PySet_Add(self, element)
 
     def remove(self, element):
         # set.remove will raise if element is not in self
-        self._set.remove(element)
+        set.remove(self, element)
         self._list.remove(element)
 
     def insert(self, Py_ssize_t pos, element):
         if element not in self:
             self._list.insert(pos, element)
-            PySet_Add(self._set, element)
+            PySet_Add(self, element)
 
     def discard(self, element):
         if element in self:
-            self._set.remove(element)
+            set.remove(self, element)
             self._list.remove(element)
 
     def clear(self):
-        self._set.clear()
+        set.clear(self)
         self._list = []
 
     def __getitem__(self, key):
@@ -85,34 +84,22 @@ cdef class OrderedSet:
 
     __str__ = __repr__
 
-    def update(self, *iterables):
-        for iterable in iterables:
-            for e in iterable:
-                if e not in self:
-                    self._list.append(e)
-                    self._set.add(e)
+    def update(self, iterable):
+        for e in iterable:
+            if e not in self:
+                self._list.append(e)
+                set.add(self, e)
+        return self
 
     def __ior__(self, iterable):
-        self.update(iterable)
-        return self
+        return self.update(iterable)
 
-    def union(self, other):
+    def union(self, *other):
         result = self._copy()
-        result.update(other)
+        for o in other:
+            result.update(o)
         return result
 
-    def __len__(self) -> int:
-        return len(self._set)
-
-    def __eq__(self, other):
-        return self._set == other
-
-    def __ne__(self, other):
-        return self._set != other
-
-    def __contains__(self, element):
-        return element in self._set
-
     def __or__(self, other):
         return self.union(other)
 
@@ -124,8 +111,8 @@ cdef class OrderedSet:
             other_set = set(other)
         return other_set
 
-    def intersection(self, other):
-        cdef set other_set = self._to_set(other)
+    def intersection(self, *other):
+        cdef other_set = set.intersection(self, *other)
         return self._from_list([a for a in self._list if a in other_set])
 
     def __and__(self, other):
@@ -141,17 +128,16 @@ cdef class OrderedSet:
     def __xor__(self, other):
         return self.symmetric_difference(other)
 
-    def difference(self, other):
-        cdef set other_set = self._to_set(other)
-        return self._from_list([a for a in self._list if a not in other_set])
+    def difference(self, *other):
+        cdef other_set = set.difference(self, *other)
+        return self._from_list([a for a in self._list if a in other_set])
 
     def __sub__(self, other):
         return self.difference(other)
 
-    def intersection_update(self, other):
-        cdef set other_set = self._to_set(other)
-        set.intersection_update(self, other_set)
-        self._list = [a for a in self._list if a in other_set]
+    def intersection_update(self, *other):
+        set.intersection_update(self, *other)
+        self._list = [a for a in self._list if a in self]
 
     def __iand__(self, other):
         self.intersection_update(other)
@@ -166,15 +152,14 @@ cdef class OrderedSet:
         self.symmetric_difference_update(other)
         return self
 
-    def difference_update(self, other):
-        set.difference_update(self, other)
+    def difference_update(self, *other):
+        set.difference_update(self, *other)
         self._list = [a for a in self._list if a in self]
 
     def __isub__(self, other):
         self.difference_update(other)
         return self
 
-
 cdef object cy_id(object item):
     return PyLong_FromLong(<long> (<void *>item))
 
index a4e4b8b5db29f9dea2c007e03bd01d3cc9deaff2..7914507cdd145298fc9ceb0f6e5a591960caaf24 100644 (file)
@@ -1,7 +1,8 @@
 from itertools import filterfalse
+from typing import AbstractSet
 from typing import Any
+from typing import cast
 from typing import Dict
-from typing import Generic
 from typing import Iterable
 from typing import Iterator
 from typing import List
@@ -9,6 +10,7 @@ from typing import NoReturn
 from typing import Optional
 from typing import Set
 from typing import TypeVar
+from typing import Union
 
 _T = TypeVar("_T", bound=Any)
 _KT = TypeVar("_KT", bound=Any)
@@ -96,19 +98,20 @@ class immutabledict(ImmutableDictBase[_KT, _VT]):
         return "immutabledict(%s)" % dict.__repr__(self)
 
 
-class OrderedSet(Generic[_T]):
-    __slots__ = ("_list", "_set", "__weakref__")
+_S = TypeVar("_S", bound=Any)
+
+
+class OrderedSet(Set[_T]):
+    __slots__ = ("_list",)
 
     _list: List[_T]
-    _set: Set[_T]
 
     def __init__(self, d=None):
         if d is not None:
             self._list = unique_list(d)
-            self._set = set(self._list)
+            super().update(self._list)
         else:
             self._list = []
-            self._set = set()
 
     def __reduce__(self):
         return (OrderedSet, (self._list,))
@@ -116,44 +119,26 @@ class OrderedSet(Generic[_T]):
     def add(self, element: _T) -> None:
         if element not in self:
             self._list.append(element)
-        self._set.add(element)
+        super().add(element)
 
     def remove(self, element: _T) -> None:
-        self._set.remove(element)
+        super().remove(element)
         self._list.remove(element)
 
     def insert(self, pos: int, element: _T) -> None:
         if element not in self:
             self._list.insert(pos, element)
-        self._set.add(element)
+        super().add(element)
 
     def discard(self, element: _T) -> None:
         if element in self:
             self._list.remove(element)
-            self._set.remove(element)
+            super().remove(element)
 
     def clear(self) -> None:
-        self._set.clear()
+        super().clear()
         self._list = []
 
-    def __len__(self) -> int:
-        return len(self._set)
-
-    def __eq__(self, other):
-        if not isinstance(other, OrderedSet):
-            return self._set == other
-        else:
-            return self._set == other._set
-
-    def __ne__(self, other):
-        if not isinstance(other, OrderedSet):
-            return self._set != other
-        else:
-            return self._set != other._set
-
-    def __contains__(self, element: Any) -> bool:
-        return element in self._set
-
     def __getitem__(self, key: int) -> _T:
         return self._list[key]
 
@@ -173,25 +158,27 @@ class OrderedSet(Generic[_T]):
             for e in iterable:
                 if e not in self:
                     self._list.append(e)
-                    self._set.add(e)
+                    super().add(e)
 
-    def __ior__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
-        self.update(other)
-        return self
+    def __ior__(self, other: AbstractSet[_S]) -> "OrderedSet[Union[_T, _S]]":
+        self.update(other)  # type: ignore
+        return self  # type: ignore
 
-    def union(self, other: Iterable[_T]) -> "OrderedSet[_T]":
-        result = self.__class__(self)
-        result.update(other)
+    def union(self, *other: Iterable[_S]) -> "OrderedSet[Union[_T, _S]]":
+        result: "OrderedSet[Union[_T, _S]]" = self.__class__(self)  # type: ignore  # noqa E501
+        for o in other:
+            result.update(o)
         return result
 
-    def __or__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
+    def __or__(self, other: AbstractSet[_S]) -> "OrderedSet[Union[_T, _S]]":
         return self.union(other)
 
-    def intersection(self, other: Iterable[_T]) -> "OrderedSet[_T]":
-        other = other if isinstance(other, set) else set(other)
-        return self.__class__(a for a in self if a in other)
+    def intersection(self, *other: Iterable[Any]) -> "OrderedSet[_T]":
+        other_set: Set[Any] = set()
+        other_set.update(*other)
+        return self.__class__(a for a in self if a in other_set)
 
-    def __and__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
+    def __and__(self, other: AbstractSet[object]) -> "OrderedSet[_T]":
         return self.intersection(other)
 
     def symmetric_difference(self, other: Iterable[_T]) -> "OrderedSet[_T]":
@@ -200,39 +187,40 @@ class OrderedSet(Generic[_T]):
         result.update(a for a in other if a not in self)
         return result
 
-    def __xor__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
-        return self.symmetric_difference(other)
+    def __xor__(self, other: AbstractSet[_S]) -> "OrderedSet[Union[_T, _S]]":
+        return cast("OrderedSet[Union[_T, _S]]", self).symmetric_difference(
+            other
+        )
 
-    def difference(self, other: Iterable[_T]) -> "OrderedSet[_T]":
-        other = other if isinstance(other, set) else set(other)
-        return self.__class__(a for a in self if a not in other)
+    def difference(self, *other: Iterable[Any]) -> "OrderedSet[_T]":
+        other_set = super().difference(*other)
+        return self.__class__(a for a in self._list if a in other_set)
 
-    def __sub__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
+    def __sub__(self, other: AbstractSet[_T | None]) -> "OrderedSet[_T]":
         return self.difference(other)
 
-    def intersection_update(self, other: Iterable[_T]) -> None:
-        other = other if isinstance(other, set) else set(other)
-        self._set.intersection_update(other)
-        self._list = [a for a in self._list if a in other]
+    def intersection_update(self, *other: Iterable[Any]) -> None:
+        super().intersection_update(*other)
+        self._list = [a for a in self._list if a in self]
 
-    def __iand__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
+    def __iand__(self, other: AbstractSet[object]) -> "OrderedSet[_T]":
         self.intersection_update(other)
         return self
 
-    def symmetric_difference_update(self, other: Iterable[_T]) -> None:
-        self._set.symmetric_difference_update(other)
+    def symmetric_difference_update(self, other: Iterable[Any]) -> None:
+        super().symmetric_difference_update(other)
         self._list = [a for a in self._list if a in self]
         self._list += [a for a in other if a in self]
 
-    def __ixor__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
+    def __ixor__(self, other: AbstractSet[_S]) -> "OrderedSet[Union[_T, _S]]":
         self.symmetric_difference_update(other)
-        return self
+        return cast("OrderedSet[Union[_T, _S]]", self)
 
-    def difference_update(self, other: Iterable[_T]) -> None:
-        self._set.difference_update(other)
+    def difference_update(self, *other: Iterable[Any]) -> None:
+        super().difference_update(*other)
         self._list = [a for a in self._list if a in self]
 
-    def __isub__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
+    def __isub__(self, other: AbstractSet[_T | None]) -> "OrderedSet[_T]":
         self.difference_update(other)
         return self