From: Matus Valo Date: Tue, 14 Mar 2023 08:19:30 +0000 (-0400) Subject: Minor improvements in collections.pyx X-Git-Tag: rel_2_0_8~9^2 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=f3baf6194c3984525e3ce259e1b70c763c0ad824;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Minor improvements in collections.pyx ### Description This PR introduces minor improvements to collections.pyx: * Adds missed type annotations yielding slightly more optimised code * Adds missed `cpdef` methods used internally * Marks private methods with `@cython.final` Fixes #9477 ### Checklist This pull request is: - [ ] A documentation / typographical error fix - Good to go, no issue or tests are needed - [X] A short code fix - please include the issue number, and create an issue if none exists, which must include a complete example of the issue. one line code fixes without an issue and demonstration will not be accepted. - Please include: `Fixes: #` in the commit message - please include tests. one line code fixes without tests will not be accepted. - [ ] A new feature implementation - please include the issue number, and create an issue if none exists, which must include a complete example of how the feature would look. - Please include: `Fixes: #` in the commit message - please include tests. **Have a nice day!** Closes: #9478 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/9478 Pull-request-sha: c006c76c2c50491ea1be9c723c278da16c151397 Change-Id: I74b3df2bc790db49e331b8f8085c797249364b07 --- diff --git a/lib/sqlalchemy/cyextension/collections.pyx b/lib/sqlalchemy/cyextension/collections.pyx index 07bc85e23d..e6667dddd0 100644 --- a/lib/sqlalchemy/cyextension/collections.pyx +++ b/lib/sqlalchemy/cyextension/collections.pyx @@ -1,3 +1,4 @@ +cimport cython from cpython.dict cimport PyDict_Merge, PyDict_Update from cpython.long cimport PyLong_FromLong from cpython.set cimport PySet_Add @@ -38,12 +39,14 @@ cdef class OrderedSet(set): else: self._list = [] + @cython.final cdef OrderedSet _copy(self): cdef OrderedSet cp = OrderedSet.__new__(OrderedSet) cp._list = list(self._list) set.update(cp, cp._list) return cp + @cython.final cdef OrderedSet _from_list(self, list new_list): cdef OrderedSet new = OrderedSet.__new__(OrderedSet) new._list = new_list @@ -88,7 +91,7 @@ cdef class OrderedSet(set): __str__ = __repr__ - def update(self, iterable): + cpdef OrderedSet update(self, iterable): for e in iterable: if e not in self: self._list.append(e) @@ -107,6 +110,7 @@ cdef class OrderedSet(set): def __or__(self, other): return self.union(other) + @cython.final cdef set _to_set(self, other): cdef set other_set if isinstance(other, set): @@ -116,7 +120,7 @@ cdef class OrderedSet(set): return other_set def intersection(self, *other): - cdef other_set = set.intersection(self, *other) + cdef set other_set = set.intersection(self, *other) return self._from_list([a for a in self._list if a in other_set]) def __and__(self, other): @@ -133,7 +137,7 @@ cdef class OrderedSet(set): return self.symmetric_difference(other) def difference(self, *other): - cdef other_set = set.difference(self, *other) + cdef set other_set = set.difference(self, *other) return self._from_list([a for a in self._list if a in other_set]) def __sub__(self, other): @@ -147,7 +151,7 @@ cdef class OrderedSet(set): self.intersection_update(other) return self - def symmetric_difference_update(self, other): + cpdef symmetric_difference_update(self, other): set.symmetric_difference_update(self, other) self._list = [a for a in self._list if a in self] self._list += [a for a in other if a in self] @@ -296,7 +300,7 @@ cdef class IdentitySet: self.update(other) return self - cpdef difference(self, iterable): + cpdef IdentitySet difference(self, iterable): cdef IdentitySet result = self.__new__(self.__class__) if isinstance(iterable, self.__class__): other = (iterable)._members @@ -320,7 +324,7 @@ cdef class IdentitySet: self.difference_update(other) return self - cpdef intersection(self, iterable): + cpdef IdentitySet intersection(self, iterable): cdef IdentitySet result = self.__new__(self.__class__) if isinstance(iterable, self.__class__): other = (iterable)._members @@ -344,7 +348,7 @@ cdef class IdentitySet: self.intersection_update(other) return self - cpdef symmetric_difference(self, iterable): + cpdef IdentitySet symmetric_difference(self, iterable): cdef IdentitySet result = self.__new__(self.__class__) cdef dict other if isinstance(iterable, self.__class__): @@ -372,7 +376,7 @@ cdef class IdentitySet: self.symmetric_difference(other) return self - cpdef copy(self): + cpdef IdentitySet copy(self): cdef IdentitySet cp = self.__new__(self.__class__) cp._members = self._members.copy() return cp