]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Minor improvements in collections.pyx 9478/head
authorMatus Valo <matusvalo@gmail.com>
Tue, 7 Mar 2023 16:01:23 +0000 (17:01 +0100)
committerMatus Valo <matusvalo@gmail.com>
Mon, 13 Mar 2023 21:58:51 +0000 (22:58 +0100)
Fixes #9477

lib/sqlalchemy/cyextension/collections.pyx

index 07bc85e23d4af79539ad43a3af80b1d55288a480..e6667dddd027499d3ed318234727c51254c7d4de 100644 (file)
@@ -1,3 +1,4 @@
+cimport cython
 from cpython.dict cimport PyDict_Merge, PyDict_Update
 from cpython.long cimport PyLong_FromLong
 from cpython.set cimport PySet_Add
@@ -38,12 +39,14 @@ cdef class OrderedSet(set):
         else:
             self._list = []
 
+    @cython.final
     cdef OrderedSet _copy(self):
         cdef OrderedSet cp = OrderedSet.__new__(OrderedSet)
         cp._list = list(self._list)
         set.update(cp, cp._list)
         return cp
 
+    @cython.final
     cdef OrderedSet _from_list(self, list new_list):
         cdef OrderedSet new = OrderedSet.__new__(OrderedSet)
         new._list = new_list
@@ -88,7 +91,7 @@ cdef class OrderedSet(set):
 
     __str__ = __repr__
 
-    def update(self, iterable):
+    cpdef OrderedSet update(self, iterable):
         for e in iterable:
             if e not in self:
                 self._list.append(e)
@@ -107,6 +110,7 @@ cdef class OrderedSet(set):
     def __or__(self, other):
         return self.union(other)
 
+    @cython.final
     cdef set _to_set(self, other):
         cdef set other_set
         if isinstance(other, set):
@@ -116,7 +120,7 @@ cdef class OrderedSet(set):
         return other_set
 
     def intersection(self, *other):
-        cdef other_set = set.intersection(self, *other)
+        cdef set other_set = set.intersection(self, *other)
         return self._from_list([a for a in self._list if a in other_set])
 
     def __and__(self, other):
@@ -133,7 +137,7 @@ cdef class OrderedSet(set):
         return self.symmetric_difference(other)
 
     def difference(self, *other):
-        cdef other_set = set.difference(self, *other)
+        cdef set other_set = set.difference(self, *other)
         return self._from_list([a for a in self._list if a in other_set])
 
     def __sub__(self, other):
@@ -147,7 +151,7 @@ cdef class OrderedSet(set):
         self.intersection_update(other)
         return self
 
-    def symmetric_difference_update(self, other):
+    cpdef symmetric_difference_update(self, other):
         set.symmetric_difference_update(self, other)
         self._list = [a for a in self._list if a in self]
         self._list += [a for a in other if a in self]
@@ -296,7 +300,7 @@ cdef class IdentitySet:
         self.update(other)
         return self
 
-    cpdef difference(self, iterable):
+    cpdef IdentitySet difference(self, iterable):
         cdef IdentitySet result = self.__new__(self.__class__)
         if isinstance(iterable, self.__class__):
             other = (<IdentitySet>iterable)._members
@@ -320,7 +324,7 @@ cdef class IdentitySet:
         self.difference_update(other)
         return self
 
-    cpdef intersection(self, iterable):
+    cpdef IdentitySet intersection(self, iterable):
         cdef IdentitySet result = self.__new__(self.__class__)
         if isinstance(iterable, self.__class__):
             other = (<IdentitySet>iterable)._members
@@ -344,7 +348,7 @@ cdef class IdentitySet:
         self.intersection_update(other)
         return self
 
-    cpdef symmetric_difference(self, iterable):
+    cpdef IdentitySet symmetric_difference(self, iterable):
         cdef IdentitySet result = self.__new__(self.__class__)
         cdef dict other
         if isinstance(iterable, self.__class__):
@@ -372,7 +376,7 @@ cdef class IdentitySet:
         self.symmetric_difference(other)
         return self
 
-    cpdef copy(self):
+    cpdef IdentitySet copy(self):
         cdef IdentitySet cp = self.__new__(self.__class__)
         cp._members = self._members.copy()
         return cp