def unique_list(seq, hashfunc=None):
return cunique_list(seq, hashfunc)
-cdef class OrderedSet:
+cdef class OrderedSet(set):
cdef list _list
- cdef set _set
def __init__(self, d=None):
+ set.__init__(self)
if d is not None:
self._list = cunique_list(d)
- self._set = set(self._list)
+ set.update(self, self._list)
else:
self._list = []
- self._set = set()
cdef OrderedSet _copy(self):
cdef OrderedSet cp = OrderedSet.__new__(OrderedSet)
cp._list = list(self._list)
- cp._set = set(cp._list)
+ set.update(cp, cp._list)
return cp
cdef OrderedSet _from_list(self, list new_list):
cdef OrderedSet new = OrderedSet.__new__(OrderedSet)
new._list = new_list
- new._set = set(new_list)
+ set.update(new, new_list)
return new
def add(self, element):
if element not in self:
self._list.append(element)
- PySet_Add(self._set, element)
+ PySet_Add(self, element)
def remove(self, element):
# set.remove will raise if element is not in self
- self._set.remove(element)
+ set.remove(self, element)
self._list.remove(element)
def insert(self, Py_ssize_t pos, element):
if element not in self:
self._list.insert(pos, element)
- PySet_Add(self._set, element)
+ PySet_Add(self, element)
def discard(self, element):
if element in self:
- self._set.remove(element)
+ set.remove(self, element)
self._list.remove(element)
def clear(self):
- self._set.clear()
+ set.clear(self)
self._list = []
def __getitem__(self, key):
__str__ = __repr__
- def update(self, *iterables):
- for iterable in iterables:
- for e in iterable:
- if e not in self:
- self._list.append(e)
- self._set.add(e)
+ def update(self, iterable):
+ for e in iterable:
+ if e not in self:
+ self._list.append(e)
+ set.add(self, e)
+ return self
def __ior__(self, iterable):
- self.update(iterable)
- return self
+ return self.update(iterable)
- def union(self, other):
+ def union(self, *other):
result = self._copy()
- result.update(other)
+ for o in other:
+ result.update(o)
return result
- def __len__(self) -> int:
- return len(self._set)
-
- def __eq__(self, other):
- return self._set == other
-
- def __ne__(self, other):
- return self._set != other
-
- def __contains__(self, element):
- return element in self._set
-
def __or__(self, other):
return self.union(other)
other_set = set(other)
return other_set
- def intersection(self, other):
- cdef set other_set = self._to_set(other)
+ def intersection(self, *other):
+ cdef other_set = set.intersection(self, *other)
return self._from_list([a for a in self._list if a in other_set])
def __and__(self, other):
def __xor__(self, other):
return self.symmetric_difference(other)
- def difference(self, other):
- cdef set other_set = self._to_set(other)
- return self._from_list([a for a in self._list if a not in other_set])
+ def difference(self, *other):
+ cdef other_set = set.difference(self, *other)
+ return self._from_list([a for a in self._list if a in other_set])
def __sub__(self, other):
return self.difference(other)
- def intersection_update(self, other):
- cdef set other_set = self._to_set(other)
- set.intersection_update(self, other_set)
- self._list = [a for a in self._list if a in other_set]
+ def intersection_update(self, *other):
+ set.intersection_update(self, *other)
+ self._list = [a for a in self._list if a in self]
def __iand__(self, other):
self.intersection_update(other)
self.symmetric_difference_update(other)
return self
- def difference_update(self, other):
- set.difference_update(self, other)
+ def difference_update(self, *other):
+ set.difference_update(self, *other)
self._list = [a for a in self._list if a in self]
def __isub__(self, other):
self.difference_update(other)
return self
-
cdef object cy_id(object item):
return PyLong_FromLong(<long> (<void *>item))
from itertools import filterfalse
+from typing import AbstractSet
from typing import Any
+from typing import cast
from typing import Dict
-from typing import Generic
from typing import Iterable
from typing import Iterator
from typing import List
from typing import Optional
from typing import Set
from typing import TypeVar
+from typing import Union
_T = TypeVar("_T", bound=Any)
_KT = TypeVar("_KT", bound=Any)
return "immutabledict(%s)" % dict.__repr__(self)
-class OrderedSet(Generic[_T]):
- __slots__ = ("_list", "_set", "__weakref__")
+_S = TypeVar("_S", bound=Any)
+
+
+class OrderedSet(Set[_T]):
+ __slots__ = ("_list",)
_list: List[_T]
- _set: Set[_T]
def __init__(self, d=None):
if d is not None:
self._list = unique_list(d)
- self._set = set(self._list)
+ super().update(self._list)
else:
self._list = []
- self._set = set()
def __reduce__(self):
return (OrderedSet, (self._list,))
def add(self, element: _T) -> None:
if element not in self:
self._list.append(element)
- self._set.add(element)
+ super().add(element)
def remove(self, element: _T) -> None:
- self._set.remove(element)
+ super().remove(element)
self._list.remove(element)
def insert(self, pos: int, element: _T) -> None:
if element not in self:
self._list.insert(pos, element)
- self._set.add(element)
+ super().add(element)
def discard(self, element: _T) -> None:
if element in self:
self._list.remove(element)
- self._set.remove(element)
+ super().remove(element)
def clear(self) -> None:
- self._set.clear()
+ super().clear()
self._list = []
- def __len__(self) -> int:
- return len(self._set)
-
- def __eq__(self, other):
- if not isinstance(other, OrderedSet):
- return self._set == other
- else:
- return self._set == other._set
-
- def __ne__(self, other):
- if not isinstance(other, OrderedSet):
- return self._set != other
- else:
- return self._set != other._set
-
- def __contains__(self, element: Any) -> bool:
- return element in self._set
-
def __getitem__(self, key: int) -> _T:
return self._list[key]
for e in iterable:
if e not in self:
self._list.append(e)
- self._set.add(e)
+ super().add(e)
- def __ior__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
- self.update(other)
- return self
+ def __ior__(self, other: AbstractSet[_S]) -> "OrderedSet[Union[_T, _S]]":
+ self.update(other) # type: ignore
+ return self # type: ignore
- def union(self, other: Iterable[_T]) -> "OrderedSet[_T]":
- result = self.__class__(self)
- result.update(other)
+ def union(self, *other: Iterable[_S]) -> "OrderedSet[Union[_T, _S]]":
+ result: "OrderedSet[Union[_T, _S]]" = self.__class__(self) # type: ignore # noqa E501
+ for o in other:
+ result.update(o)
return result
- def __or__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
+ def __or__(self, other: AbstractSet[_S]) -> "OrderedSet[Union[_T, _S]]":
return self.union(other)
- def intersection(self, other: Iterable[_T]) -> "OrderedSet[_T]":
- other = other if isinstance(other, set) else set(other)
- return self.__class__(a for a in self if a in other)
+ def intersection(self, *other: Iterable[Any]) -> "OrderedSet[_T]":
+ other_set: Set[Any] = set()
+ other_set.update(*other)
+ return self.__class__(a for a in self if a in other_set)
- def __and__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
+ def __and__(self, other: AbstractSet[object]) -> "OrderedSet[_T]":
return self.intersection(other)
def symmetric_difference(self, other: Iterable[_T]) -> "OrderedSet[_T]":
result.update(a for a in other if a not in self)
return result
- def __xor__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
- return self.symmetric_difference(other)
+ def __xor__(self, other: AbstractSet[_S]) -> "OrderedSet[Union[_T, _S]]":
+ return cast("OrderedSet[Union[_T, _S]]", self).symmetric_difference(
+ other
+ )
- def difference(self, other: Iterable[_T]) -> "OrderedSet[_T]":
- other = other if isinstance(other, set) else set(other)
- return self.__class__(a for a in self if a not in other)
+ def difference(self, *other: Iterable[Any]) -> "OrderedSet[_T]":
+ other_set = super().difference(*other)
+ return self.__class__(a for a in self._list if a in other_set)
- def __sub__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
+ def __sub__(self, other: AbstractSet[_T | None]) -> "OrderedSet[_T]":
return self.difference(other)
- def intersection_update(self, other: Iterable[_T]) -> None:
- other = other if isinstance(other, set) else set(other)
- self._set.intersection_update(other)
- self._list = [a for a in self._list if a in other]
+ def intersection_update(self, *other: Iterable[Any]) -> None:
+ super().intersection_update(*other)
+ self._list = [a for a in self._list if a in self]
- def __iand__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
+ def __iand__(self, other: AbstractSet[object]) -> "OrderedSet[_T]":
self.intersection_update(other)
return self
- def symmetric_difference_update(self, other: Iterable[_T]) -> None:
- self._set.symmetric_difference_update(other)
+ def symmetric_difference_update(self, other: Iterable[Any]) -> None:
+ super().symmetric_difference_update(other)
self._list = [a for a in self._list if a in self]
self._list += [a for a in other if a in self]
- def __ixor__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
+ def __ixor__(self, other: AbstractSet[_S]) -> "OrderedSet[Union[_T, _S]]":
self.symmetric_difference_update(other)
- return self
+ return cast("OrderedSet[Union[_T, _S]]", self)
- def difference_update(self, other: Iterable[_T]) -> None:
- self._set.difference_update(other)
+ def difference_update(self, *other: Iterable[Any]) -> None:
+ super().difference_update(*other)
self._list = [a for a in self._list if a in self]
- def __isub__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
+ def __isub__(self, other: AbstractSet[_T | None]) -> "OrderedSet[_T]":
self.difference_update(other)
return self