]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Minor improvements in collections.pyx
authorMatus Valo <matusvalo@gmail.com>
Tue, 14 Mar 2023 08:19:30 +0000 (04:19 -0400)
committersqla-tester <sqla-tester@sqlalchemy.org>
Tue, 14 Mar 2023 08:19:30 +0000 (04:19 -0400)
### Description

This PR introduces minor improvements to collections.pyx:

* Adds missed type annotations yielding slightly more optimised code
* Adds missed `cpdef` methods used internally
* Marks private methods with `@cython.final`

Fixes #9477

### Checklist
<!-- go over following points. check them with an `x` if they do apply, (they turn into clickable checkboxes once the PR is submitted, so no need to do everything at once)

-->

This pull request is:

- [ ] A documentation / typographical error fix
- Good to go, no issue or tests are needed
- [X] A short code fix
- please include the issue number, and create an issue if none exists, which
  must include a complete example of the issue.  one line code fixes without an
  issue and demonstration will not be accepted.
- Please include: `Fixes: #<issue number>` in the commit message
- please include tests.   one line code fixes without tests will not be accepted.
- [ ] A new feature implementation
- please include the issue number, and create an issue if none exists, which must
  include a complete example of how the feature would look.
- Please include: `Fixes: #<issue number>` in the commit message
- please include tests.

**Have a nice day!**

Closes: #9478
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/9478
Pull-request-sha: c006c76c2c50491ea1be9c723c278da16c151397

Change-Id: I74b3df2bc790db49e331b8f8085c797249364b07

lib/sqlalchemy/cyextension/collections.pyx

index 07bc85e23d4af79539ad43a3af80b1d55288a480..e6667dddd027499d3ed318234727c51254c7d4de 100644 (file)
@@ -1,3 +1,4 @@
+cimport cython
 from cpython.dict cimport PyDict_Merge, PyDict_Update
 from cpython.long cimport PyLong_FromLong
 from cpython.set cimport PySet_Add
@@ -38,12 +39,14 @@ cdef class OrderedSet(set):
         else:
             self._list = []
 
+    @cython.final
     cdef OrderedSet _copy(self):
         cdef OrderedSet cp = OrderedSet.__new__(OrderedSet)
         cp._list = list(self._list)
         set.update(cp, cp._list)
         return cp
 
+    @cython.final
     cdef OrderedSet _from_list(self, list new_list):
         cdef OrderedSet new = OrderedSet.__new__(OrderedSet)
         new._list = new_list
@@ -88,7 +91,7 @@ cdef class OrderedSet(set):
 
     __str__ = __repr__
 
-    def update(self, iterable):
+    cpdef OrderedSet update(self, iterable):
         for e in iterable:
             if e not in self:
                 self._list.append(e)
@@ -107,6 +110,7 @@ cdef class OrderedSet(set):
     def __or__(self, other):
         return self.union(other)
 
+    @cython.final
     cdef set _to_set(self, other):
         cdef set other_set
         if isinstance(other, set):
@@ -116,7 +120,7 @@ cdef class OrderedSet(set):
         return other_set
 
     def intersection(self, *other):
-        cdef other_set = set.intersection(self, *other)
+        cdef set other_set = set.intersection(self, *other)
         return self._from_list([a for a in self._list if a in other_set])
 
     def __and__(self, other):
@@ -133,7 +137,7 @@ cdef class OrderedSet(set):
         return self.symmetric_difference(other)
 
     def difference(self, *other):
-        cdef other_set = set.difference(self, *other)
+        cdef set other_set = set.difference(self, *other)
         return self._from_list([a for a in self._list if a in other_set])
 
     def __sub__(self, other):
@@ -147,7 +151,7 @@ cdef class OrderedSet(set):
         self.intersection_update(other)
         return self
 
-    def symmetric_difference_update(self, other):
+    cpdef symmetric_difference_update(self, other):
         set.symmetric_difference_update(self, other)
         self._list = [a for a in self._list if a in self]
         self._list += [a for a in other if a in self]
@@ -296,7 +300,7 @@ cdef class IdentitySet:
         self.update(other)
         return self
 
-    cpdef difference(self, iterable):
+    cpdef IdentitySet difference(self, iterable):
         cdef IdentitySet result = self.__new__(self.__class__)
         if isinstance(iterable, self.__class__):
             other = (<IdentitySet>iterable)._members
@@ -320,7 +324,7 @@ cdef class IdentitySet:
         self.difference_update(other)
         return self
 
-    cpdef intersection(self, iterable):
+    cpdef IdentitySet intersection(self, iterable):
         cdef IdentitySet result = self.__new__(self.__class__)
         if isinstance(iterable, self.__class__):
             other = (<IdentitySet>iterable)._members
@@ -344,7 +348,7 @@ cdef class IdentitySet:
         self.intersection_update(other)
         return self
 
-    cpdef symmetric_difference(self, iterable):
+    cpdef IdentitySet symmetric_difference(self, iterable):
         cdef IdentitySet result = self.__new__(self.__class__)
         cdef dict other
         if isinstance(iterable, self.__class__):
@@ -372,7 +376,7 @@ cdef class IdentitySet:
         self.symmetric_difference(other)
         return self
 
-    cpdef copy(self):
+    cpdef IdentitySet copy(self):
         cdef IdentitySet cp = self.__new__(self.__class__)
         cp._members = self._members.copy()
         return cp