]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add missing methods to OrderedSet.
authorFederico Caselli <cfederico87@gmail.com>
Tue, 14 Mar 2023 22:17:07 +0000 (23:17 +0100)
committerFederico Caselli <cfederico87@gmail.com>
Thu, 30 Mar 2023 20:18:11 +0000 (22:18 +0200)
Implemented missing method ``copy`` and ``pop`` in OrderedSet class.

Fixes: #9487
Change-Id: I1d2278b64939b44422e9d5857ec7d345fff53997

doc/build/changelog/unreleased_20/9487.rst [new file with mode: 0644]
lib/sqlalchemy/cyextension/collections.pyx
lib/sqlalchemy/util/_py_collections.py
test/base/test_utils.py

diff --git a/doc/build/changelog/unreleased_20/9487.rst b/doc/build/changelog/unreleased_20/9487.rst
new file mode 100644 (file)
index 0000000..627be0e
--- /dev/null
@@ -0,0 +1,6 @@
+.. change::
+    :tags: bug, util
+    :tickets: 9487
+
+    Implemented missing methods ``copy`` and ``pop`` in
+    OrderedSet class.
index e6667dddd027499d3ed318234727c51254c7d4de..d08fa3aab24911a48210e18d3a06a6e7f124e261 100644 (file)
@@ -1,8 +1,9 @@
 cimport cython
 from cpython.dict cimport PyDict_Merge, PyDict_Update
-from cpython.long cimport PyLong_FromLong
+from cpython.long cimport PyLong_FromLongLong
 from cpython.set cimport PySet_Add
 
+from collections.abc import Collection
 from itertools import filterfalse
 
 cdef bint add_not_present(set seen, object item, hashfunc):
@@ -39,8 +40,7 @@ cdef class OrderedSet(set):
         else:
             self._list = []
 
-    @cython.final
-    cdef OrderedSet _copy(self):
+    cpdef OrderedSet copy(self):
         cdef OrderedSet cp = OrderedSet.__new__(OrderedSet)
         cp._list = list(self._list)
         set.update(cp, cp._list)
@@ -63,6 +63,14 @@ cdef class OrderedSet(set):
         set.remove(self, element)
         self._list.remove(element)
 
+    def pop(self):
+        try:
+            value = self._list.pop()
+        except IndexError:
+            raise KeyError("pop from an empty set") from None
+        set.remove(self, value)
+        return value
+
     def insert(self, Py_ssize_t pos, element):
         if element not in self:
             self._list.insert(pos, element)
@@ -91,34 +99,25 @@ cdef class OrderedSet(set):
 
     __str__ = __repr__
 
-    cpdef OrderedSet update(self, iterable):
-        for e in iterable:
-            if e not in self:
-                self._list.append(e)
-                set.add(self, e)
-        return self
+    def update(self, *iterables):
+        for iterable in iterables:
+            for e in iterable:
+                if e not in self:
+                    self._list.append(e)
+                    set.add(self, e)
 
     def __ior__(self, iterable):
-        return self.update(iterable)
+        self.update(iterable)
+        return self
 
     def union(self, *other):
-        result = self._copy()
-        for o in other:
-            result.update(o)
+        result = self.copy()
+        result.update(*other)
         return result
 
     def __or__(self, other):
         return self.union(other)
 
-    @cython.final
-    cdef set _to_set(self, other):
-        cdef set other_set
-        if isinstance(other, set):
-            other_set = <set> other
-        else:
-            other_set = set(other)
-        return other_set
-
     def intersection(self, *other):
         cdef set other_set = set.intersection(self, *other)
         return self._from_list([a for a in self._list if a in other_set])
@@ -127,10 +126,18 @@ cdef class OrderedSet(set):
         return self.intersection(other)
 
     def symmetric_difference(self, other):
-        cdef set other_set = self._to_set(other)
+        cdef set other_set
+        if isinstance(other, set):
+            other_set = <set> other
+            collection = other_set
+        elif isinstance(other, Collection):
+            collection = other
+            other_set = set(other)
+        else:
+            collection = list(other)
+            other_set = set(collection)
         result = self._from_list([a for a in self._list if a not in other_set])
-        # use other here to keep the order
-        result.update(a for a in other if a not in self)
+        result.update(a for a in collection if a not in self)
         return result
 
     def __xor__(self, other):
@@ -152,9 +159,10 @@ cdef class OrderedSet(set):
         return self
 
     cpdef symmetric_difference_update(self, other):
-        set.symmetric_difference_update(self, other)
+        collection = other if isinstance(other, Collection) else list(other)
+        set.symmetric_difference_update(self, collection)
         self._list = [a for a in self._list if a in self]
-        self._list += [a for a in other if a in self]
+        self._list += [a for a in collection if a in self]
 
     def __ixor__(self, other):
         self.symmetric_difference_update(other)
@@ -169,13 +177,12 @@ cdef class OrderedSet(set):
         return self
 
 cdef object cy_id(object item):
-    return PyLong_FromLong(<long> (<void *>item))
+    return PyLong_FromLongLong(<long long> (<void *>item))
 
 # NOTE: cython 0.x will call __add__, __sub__, etc with the parameter swapped
 # instead of the __rmeth__, so they need to check that also self is of the
 # correct type. This is fixed in cython 3.x. See:
 # https://docs.cython.org/en/latest/src/userguide/special_methods.html#arithmetic-methods
-
 cdef class IdentitySet:
     """A set that considers only object id() for uniqueness.
 
index 8810800c42ab876c1237b794c4bb23cc0a48c309..9962493b5cb375083ee0c6f0e0894cd63105b391 100644 (file)
@@ -168,8 +168,11 @@ class OrderedSet(Set[_T]):
         else:
             self._list = []
 
-    def __reduce__(self):
-        return (OrderedSet, (self._list,))
+    def copy(self) -> OrderedSet[_T]:
+        cp = self.__class__()
+        cp._list = self._list.copy()
+        set.update(cp, cp._list)
+        return cp
 
     def add(self, element: _T) -> None:
         if element not in self:
@@ -180,6 +183,14 @@ class OrderedSet(Set[_T]):
         super().remove(element)
         self._list.remove(element)
 
+    def pop(self) -> _T:
+        try:
+            value = self._list.pop()
+        except IndexError:
+            raise KeyError("pop from an empty set") from None
+        super().remove(value)
+        return value
+
     def insert(self, pos: int, element: _T) -> None:
         if element not in self:
             self._list.insert(pos, element)
@@ -220,9 +231,8 @@ class OrderedSet(Set[_T]):
         return self  # type: ignore
 
     def union(self, *other: Iterable[_S]) -> OrderedSet[Union[_T, _S]]:
-        result: OrderedSet[Union[_T, _S]] = self.__class__(self)  # type: ignore  # noqa: E501
-        for o in other:
-            result.update(o)
+        result: OrderedSet[Union[_T, _S]] = self.copy()  # type: ignore
+        result.update(*other)
         return result
 
     def __or__(self, other: AbstractSet[_S]) -> OrderedSet[Union[_T, _S]]:
@@ -237,9 +247,17 @@ class OrderedSet(Set[_T]):
         return self.intersection(other)
 
     def symmetric_difference(self, other: Iterable[_T]) -> OrderedSet[_T]:
-        other_set = other if isinstance(other, set) else set(other)
+        collection: Collection[_T]
+        if isinstance(other, set):
+            collection = other_set = other
+        elif isinstance(other, Collection):
+            collection = other
+            other_set = set(other)
+        else:
+            collection = list(other)
+            other_set = set(collection)
         result = self.__class__(a for a in self if a not in other_set)
-        result.update(a for a in other if a not in self)
+        result.update(a for a in collection if a not in self)
         return result
 
     def __xor__(self, other: AbstractSet[_S]) -> OrderedSet[Union[_T, _S]]:
@@ -263,9 +281,10 @@ class OrderedSet(Set[_T]):
         return self
 
     def symmetric_difference_update(self, other: Iterable[Any]) -> None:
-        super().symmetric_difference_update(other)
+        collection = other if isinstance(other, Collection) else list(other)
+        super().symmetric_difference_update(collection)
         self._list = [a for a in self._list if a in self]
-        self._list += [a for a in other if a in self]
+        self._list += [a for a in collection if a in self]
 
     def __ixor__(self, other: AbstractSet[_S]) -> OrderedSet[Union[_T, _S]]:
         self.symmetric_difference_update(other)
index 01877f776633891db9431fd35e8be0cd3e558c37..d77e1b0ae887854d7b3ad1ecf9083af795817d28 100644 (file)
@@ -19,9 +19,12 @@ from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import in_
 from sqlalchemy.testing import is_
 from sqlalchemy.testing import is_false
+from sqlalchemy.testing import is_instance_of
+from sqlalchemy.testing import is_none
 from sqlalchemy.testing import is_true
 from sqlalchemy.testing import mock
 from sqlalchemy.testing import ne_
+from sqlalchemy.testing import not_in
 from sqlalchemy.testing.util import gc_collect
 from sqlalchemy.testing.util import picklers
 from sqlalchemy.util import classproperty
@@ -209,6 +212,27 @@ class OrderedSetTest(fixtures.TestBase):
         eq_(o.difference(iter([3, 4])), util.OrderedSet([2, 5]))
         eq_(o.intersection(iter([3, 4, 6])), util.OrderedSet([3, 4]))
         eq_(o.union(iter([3, 4, 6])), util.OrderedSet([3, 2, 4, 5, 6]))
+        eq_(
+            o.symmetric_difference(iter([3, 4, 6])), util.OrderedSet([2, 5, 6])
+        )
+
+    def test_mutators_against_iter_update(self):
+        # testing a set modified against an iterator
+        o = util.OrderedSet([3, 2, 4, 5])
+        o.difference_update(iter([3, 4]))
+        eq_(list(o), [2, 5])
+
+        o = util.OrderedSet([3, 2, 4, 5])
+        o.intersection_update(iter([3, 4]))
+        eq_(list(o), [3, 4])
+
+        o = util.OrderedSet([3, 2, 4, 5])
+        o.update(iter([3, 4, 6]))
+        eq_(list(o), [3, 2, 4, 5, 6])
+
+        o = util.OrderedSet([3, 2, 4, 5])
+        o.symmetric_difference_update(iter([3, 4, 6]))
+        eq_(list(o), [2, 5, 6])
 
     def test_len(self):
         eq_(len(util.OrderedSet([1, 2, 3])), 3)
@@ -229,6 +253,110 @@ class OrderedSetTest(fixtures.TestBase):
         o = util.OrderedSet([3, 2, 4, 5])
         eq_(str(o), "OrderedSet([3, 2, 4, 5])")
 
+    def test_modify(self):
+        o = util.OrderedSet([3, 9, 11])
+        is_none(o.add(42))
+        in_(42, o)
+        in_(3, o)
+
+        is_none(o.remove(9))
+        not_in(9, o)
+        in_(3, o)
+
+        is_none(o.discard(11))
+        in_(3, o)
+
+        o.add(99)
+        is_none(o.insert(1, 13))
+        eq_(list(o), [3, 13, 42, 99])
+        eq_(o[2], 42)
+
+        val = o.pop()
+        eq_(val, 99)
+        not_in(99, o)
+        eq_(list(o), [3, 13, 42])
+
+        is_none(o.clear())
+        not_in(3, o)
+        is_false(bool(o))
+
+    def test_empty_pop(self):
+        with expect_raises_message(KeyError, "pop from an empty set"):
+            util.OrderedSet().pop()
+
+    @testing.combinations(
+        lambda o: o + util.OrderedSet([11, 22]),
+        lambda o: o | util.OrderedSet([11, 22]),
+        lambda o: o.union(util.OrderedSet([11, 22])),
+        lambda o: o.union([11, 2], [22, 8]),
+    )
+    def test_op(self, fn):
+        o = util.OrderedSet(range(10))
+        x = fn(o)
+        is_instance_of(x, util.OrderedSet)
+        in_(9, x)
+        in_(11, x)
+        not_in(11, o)
+
+    def test_update(self):
+        o = util.OrderedSet(range(10))
+        is_none(o.update([22, 2], [33, 11]))
+        in_(11, o)
+        in_(22, o)
+
+    def test_set_ops(self):
+        o1, o2 = util.OrderedSet([1, 3, 5, 7]), {2, 3, 4, 5}
+        eq_(o1 & o2, {3, 5})
+        eq_(o1.intersection(o2), {3, 5})
+        o3 = o1.copy()
+        o3 &= o2
+        eq_(o3, {3, 5})
+        o3 = o1.copy()
+        is_none(o3.intersection_update(o2))
+        eq_(o3, {3, 5})
+
+        eq_(o1 | o2, {1, 2, 3, 4, 5, 7})
+        eq_(o1.union(o2), {1, 2, 3, 4, 5, 7})
+        o3 = o1.copy()
+        o3 |= o2
+        eq_(o3, {1, 2, 3, 4, 5, 7})
+        o3 = o1.copy()
+        is_none(o3.update(o2))
+        eq_(o3, {1, 2, 3, 4, 5, 7})
+
+        eq_(o1 - o2, {1, 7})
+        eq_(o1.difference(o2), {1, 7})
+        o3 = o1.copy()
+        o3 -= o2
+        eq_(o3, {1, 7})
+        o3 = o1.copy()
+        is_none(o3.difference_update(o2))
+        eq_(o3, {1, 7})
+
+        eq_(o1 ^ o2, {1, 2, 4, 7})
+        eq_(o1.symmetric_difference(o2), {1, 2, 4, 7})
+        o3 = o1.copy()
+        o3 ^= o2
+        eq_(o3, {1, 2, 4, 7})
+        o3 = o1.copy()
+        is_none(o3.symmetric_difference_update(o2))
+        eq_(o3, {1, 2, 4, 7})
+
+    def test_copy(self):
+        o = util.OrderedSet([3, 2, 4, 5])
+        cp = o.copy()
+        is_instance_of(cp, util.OrderedSet)
+        eq_(o, cp)
+        o.add(42)
+        is_false(42 in cp)
+
+    def test_pickle(self):
+        o = util.OrderedSet([2, 4, 9, 42])
+        for loads, dumps in picklers():
+            l = loads(dumps(o))
+            is_instance_of(l, util.OrderedSet)
+            eq_(list(l), [2, 4, 9, 42])
+
 
 class ImmutableDictTest(fixtures.TestBase):
     def test_union_no_change(self):