]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Update orderinglist annotations, list compatibity
authorMartijn Pieters <mj@zopatista.com>
Mon, 15 Jan 2024 16:10:30 +0000 (16:10 +0000)
committerMartijn Pieters <mj@zopatista.com>
Sun, 24 Nov 2024 20:24:37 +0000 (20:24 +0000)
- Don't omit the `_T` typevar in `ordering_list` and
  `OrderingList.ordering_func``; type checkers need to understand the
  relationship between the `OrderingList` instance and the ordering
  function connected to it.
- The ordering function can return _any_ value, not just integers
- The `ordering_attr` argument to `OrderingList` is not optional
- Update list methods to accept the same signature as the overridden
  methods, including `SupportsIndex` instead of `int`, an iterable of
  `_T` when using `__setitem__` with a slice (and not just sequences)
  and converting the index value to an integer before passing it to the
  `ordering_func` callable.
- Update `__setitem__` to _not_ attempt to handle slice objects as
  handling all edge cases of slice length and iterable length is very
  tricky and most use of `OrderingList` and slices is handled by
  the SQLAlchemy collections instrumentation anyway.
- Remove the `__setslice__` and `__delslice__` methods, which were
  deprecated in Python 2.6 and removed in Python 3.0.

Fixes #10888

lib/sqlalchemy/ext/orderinglist.py
test/ext/test_orderinglist.py

index 1a12cf38c6981da4c92b667ff7064de75041cef1..dc6ca0ee6ab888117d08b894b9dfc8a7430c66eb 100644 (file)
@@ -122,17 +122,23 @@ start numbering at 1 or some other integer, provide ``count_from=1``.
 """
 from __future__ import annotations
 
+from operator import index as index_to_int
+from typing import Any
 from typing import Callable
+from typing import Iterable
 from typing import List
 from typing import Optional
+from typing import overload
 from typing import Sequence
+from typing import SupportsIndex
 from typing import TypeVar
+from typing import Union
 
 from ..orm.collections import collection
 from ..orm.collections import collection_adapter
 
 _T = TypeVar("_T")
-OrderingFunc = Callable[[int, Sequence[_T]], int]
+OrderingFunc = Callable[[int, Sequence[_T]], object]
 
 
 __all__ = ["ordering_list"]
@@ -141,9 +147,9 @@ __all__ = ["ordering_list"]
 def ordering_list(
     attr: str,
     count_from: Optional[int] = None,
-    ordering_func: Optional[OrderingFunc] = None,
+    ordering_func: Optional[OrderingFunc[_T]] = None,
     reorder_on_append: bool = False,
-) -> Callable[[], OrderingList]:
+) -> Callable[[], OrderingList[_T]]:
     """Prepares an :class:`OrderingList` factory for use in mapper definitions.
 
     Returns an object suitable for use as an argument to a Mapper
@@ -185,22 +191,22 @@ def ordering_list(
 # Ordering utility functions
 
 
-def count_from_0(index, collection):
+def count_from_0(index: int, collection: object) -> int:
     """Numbering function: consecutive integers starting at 0."""
 
     return index
 
 
-def count_from_1(index, collection):
+def count_from_1(index: int, collection: object) -> int:
     """Numbering function: consecutive integers starting at 1."""
 
     return index + 1
 
 
-def count_from_n_factory(start):
+def count_from_n_factory(start: int) -> OrderingFunc[Any]:
     """Numbering function: consecutive integers starting at arbitrary start."""
 
-    def f(index, collection):
+    def f(index: int, collection: object) -> int:
         return index + start
 
     try:
@@ -238,13 +244,13 @@ class OrderingList(List[_T]):
     """
 
     ordering_attr: str
-    ordering_func: OrderingFunc
+    ordering_func: OrderingFunc[_T]
     reorder_on_append: bool
 
     def __init__(
         self,
-        ordering_attr: Optional[str] = None,
-        ordering_func: Optional[OrderingFunc] = None,
+        ordering_attr: str,
+        ordering_func: Optional[OrderingFunc[_T]] = None,
         reorder_on_append: bool = False,
     ):
         """A custom list that manages position information for its children.
@@ -330,13 +336,13 @@ class OrderingList(List[_T]):
         if have is not None and not reorder:
             return
 
-        should_be = self.ordering_func(index, self)
+        should_be = self.ordering_func(index_to_int(index), self)
         if have != should_be:
             self._set_order_value(entity, should_be)
 
-    def append(self, entity):
-        super().append(entity)
-        self._order_entity(len(self) - 1, entity, self.reorder_on_append)
+    def append(self, __entity: _T) -> None:
+        super().append(__entity)
+        self._order_entity(len(self) - 1, __entity, self.reorder_on_append)
 
     def _raw_append(self, entity):
         """Append without any ordering behavior."""
@@ -345,48 +351,47 @@ class OrderingList(List[_T]):
 
     _raw_append = collection.adds(1)(_raw_append)
 
-    def insert(self, index, entity):
-        super().insert(index, entity)
+    def insert(self, __index: SupportsIndex, __entity: _T) -> None:
+        super().insert(__index, __entity)
         self._reorder()
 
-    def remove(self, entity):
-        super().remove(entity)
+    def remove(self, __entity: _T) -> None:
+        super().remove(__entity)
 
         adapter = collection_adapter(self)
         if adapter and adapter._referenced_by_owner:
             self._reorder()
 
-    def pop(self, index=-1):
-        entity = super().pop(index)
+    def pop(self, __index: SupportsIndex = -1) -> _T:
+        entity = super().pop(__index)
         self._reorder()
         return entity
 
-    def __setitem__(self, index, entity):
-        if isinstance(index, slice):
-            step = index.step or 1
-            start = index.start or 0
-            if start < 0:
-                start += len(self)
-            stop = index.stop or len(self)
-            if stop < 0:
-                stop += len(self)
-
-            for i in range(start, stop, step):
-                self.__setitem__(i, entity[i])
-        else:
-            self._order_entity(index, entity, True)
-            super().__setitem__(index, entity)
+    @overload
+    def __setitem__(self, __index: SupportsIndex, __entity: _T) -> None: ...
 
-    def __delitem__(self, index):
-        super().__delitem__(index)
-        self._reorder()
+    @overload
+    def __setitem__(self, __index: slice, __entity: Iterable[_T]) -> None: ...
 
-    def __setslice__(self, start, end, values):
-        super().__setslice__(start, end, values)
-        self._reorder()
+    def __setitem__(
+        self,
+        __index: Union[SupportsIndex, slice],
+        __entity: Union[_T, Iterable[_T]],
+    ) -> None:
+        if isinstance(__index, slice):
+            # there are enough edge cases with slice lengths longer or shorter
+            # than the length of the assigned items, that full-on reordering
+            # is far simpler.
+            # SQLAlchemy collection instrumentation otherwise takes care of
+            # slices in all normal use of OrderingList in ORM classes.
+            super().__setitem__(__index, __entity)
+            self._reorder()
+        else:
+            self._order_entity(__index, __entity, True)
+            super().__setitem__(__index, __entity)
 
-    def __delslice__(self, start, end):
-        super().__delslice__(start, end)
+    def __delitem__(self, __index: Union[SupportsIndex, slice]) -> None:
+        super().__delitem__(__index)
         self._reorder()
 
     def __reduce__(self):
index 90c7f385789c3f3b617801c190fa51e8d2d0c2c0..dee945bc4c5a704c9287c084e8af4a1d4de6ad83 100644 (file)
@@ -322,7 +322,7 @@ class OrderingListTest(fixtures.MappedTest):
         s1 = Slide("Slide #1")
 
         # 1, 2, 3
-        s1.bullets[0:3] = b[0:3]
+        s1.bullets[0:3] = iter(b[0:3])
         for i in 0, 1, 2:
             self.assert_(s1.bullets[i].position == i)
             self.assert_(s1.bullets[i] == b[i])
@@ -480,6 +480,20 @@ class OrderingListTest(fixtures.MappedTest):
             self.assert_(copy == olist)
             self.assert_(copy.__dict__ == olist.__dict__)
 
+    def test_index_to_int(self):
+        self._setup(ordering_list("position"))
+
+        s1 = Slide("Slide #1")
+        s1.bullets.append(Bullet("1"))
+        s1.bullets.append(Bullet("2"))
+
+        b3 = Bullet("3")
+        b3.position = 2
+        index_value = MockIndex(0)
+        s1.bullets[index_value] = b3
+        assert s1.bullets[0].text == "3"
+        assert s1.bullets[0].position == 0
+
 
 class DummyItem:
     def __init__(self, order=None):
@@ -490,3 +504,11 @@ class DummyItem:
 
     def __ne__(self, other):
         return not (self == other)
+
+
+class MockIndex:
+    def __init__(self, value):
+        self.value = value
+
+    def __index__(self):
+        return self.value