]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add orderinglist type annotations
authorMartijn Pieters <mj@zopatista.com>
Mon, 25 Nov 2024 19:38:48 +0000 (14:38 -0500)
committerFederico Caselli <cfederico87@gmail.com>
Sat, 26 Jul 2025 16:42:19 +0000 (18:42 +0200)
Closes: #10889
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/10889
Pull-request-sha: 2ddeeb190630a8965a4fae567e2649ed16722c99

Change-Id: I9a0d6e2776d8b6756af4a3c54668bdcd1a1f40f8

lib/sqlalchemy/ext/orderinglist.py
lib/sqlalchemy/orm/collections.py
test/ext/test_orderinglist.py
test/typing/plain_files/ext/orderinglist/orderinglist_one.py [new file with mode: 0644]

index 3cc67b189649eae8a98454f8f894717dd967e9f6..80bf688eaf11dccd91e070d5814fa6856342078b 100644 (file)
@@ -4,7 +4,6 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
 
 """A custom list that manages index/position information for contained
 elements.
@@ -129,17 +128,24 @@ start numbering at 1 or some other integer, provide ``count_from=1``.
 """
 from __future__ import annotations
 
+from typing import Any
 from typing import Callable
+from typing import Dict
+from typing import Iterable
 from typing import List
 from typing import Optional
+from typing import overload
 from typing import Sequence
+from typing import SupportsIndex
+from typing import Type
 from typing import TypeVar
+from typing import Union
 
 from ..orm.collections import collection
 from ..orm.collections import collection_adapter
 
 _T = TypeVar("_T")
-OrderingFunc = Callable[[int, Sequence[_T]], int]
+OrderingFunc = Callable[[int, Sequence[_T]], object]
 
 
 __all__ = ["ordering_list"]
@@ -148,9 +154,9 @@ __all__ = ["ordering_list"]
 def ordering_list(
     attr: str,
     count_from: Optional[int] = None,
-    ordering_func: Optional[OrderingFunc] = None,
+    ordering_func: Optional[OrderingFunc[_T]] = None,
     reorder_on_append: bool = False,
-) -> Callable[[], OrderingList]:
+) -> Callable[[], OrderingList[_T]]:
     """Prepares an :class:`OrderingList` factory for use in mapper definitions.
 
     Returns an object suitable for use as an argument to a Mapper
@@ -196,22 +202,22 @@ def ordering_list(
 # Ordering utility functions
 
 
-def count_from_0(index, collection):
+def count_from_0(index: int, collection: object) -> int:
     """Numbering function: consecutive integers starting at 0."""
 
     return index
 
 
-def count_from_1(index, collection):
+def count_from_1(index: int, collection: object) -> int:
     """Numbering function: consecutive integers starting at 1."""
 
     return index + 1
 
 
-def count_from_n_factory(start):
+def count_from_n_factory(start: int) -> OrderingFunc[Any]:
     """Numbering function: consecutive integers starting at arbitrary start."""
 
-    def f(index, collection):
+    def f(index: int, collection: object) -> int:
         return index + start
 
     try:
@@ -221,7 +227,7 @@ def count_from_n_factory(start):
     return f
 
 
-def _unsugar_count_from(**kw):
+def _unsugar_count_from(**kw: Any) -> Dict[str, Any]:
     """Builds counting functions from keyword arguments.
 
     Keyword argument filter, prepares a simple ``ordering_func`` from a
@@ -249,13 +255,13 @@ class OrderingList(List[_T]):
     """
 
     ordering_attr: str
-    ordering_func: OrderingFunc
+    ordering_func: OrderingFunc[_T]
     reorder_on_append: bool
 
     def __init__(
         self,
-        ordering_attr: Optional[str] = None,
-        ordering_func: Optional[OrderingFunc] = None,
+        ordering_attr: str,
+        ordering_func: Optional[OrderingFunc[_T]] = None,
         reorder_on_append: bool = False,
     ):
         """A custom list that manages position information for its children.
@@ -315,10 +321,10 @@ class OrderingList(List[_T]):
 
     # More complex serialization schemes (multi column, e.g.) are possible by
     # subclassing and reimplementing these two methods.
-    def _get_order_value(self, entity):
+    def _get_order_value(self, entity: _T) -> Any:
         return getattr(entity, self.ordering_attr)
 
-    def _set_order_value(self, entity, value):
+    def _set_order_value(self, entity: _T, value: Any) -> None:
         setattr(entity, self.ordering_attr, value)
 
     def reorder(self) -> None:
@@ -334,7 +340,9 @@ class OrderingList(List[_T]):
     # As of 0.5, _reorder is no longer semi-private
     _reorder = reorder
 
-    def _order_entity(self, index, entity, reorder=True):
+    def _order_entity(
+        self, index: int, entity: _T, reorder: bool = True
+    ) -> None:
         have = self._get_order_value(entity)
 
         # Don't disturb existing ordering if reorder is False
@@ -345,34 +353,44 @@ class OrderingList(List[_T]):
         if have != should_be:
             self._set_order_value(entity, should_be)
 
-    def append(self, entity):
+    def append(self, entity: _T) -> None:
         super().append(entity)
         self._order_entity(len(self) - 1, entity, self.reorder_on_append)
 
-    def _raw_append(self, entity):
+    def _raw_append(self, entity: _T) -> None:
         """Append without any ordering behavior."""
 
         super().append(entity)
 
     _raw_append = collection.adds(1)(_raw_append)
 
-    def insert(self, index, entity):
+    def insert(self, index: SupportsIndex, entity: _T) -> None:
         super().insert(index, entity)
         self._reorder()
 
-    def remove(self, entity):
+    def remove(self, entity: _T) -> None:
         super().remove(entity)
 
         adapter = collection_adapter(self)
         if adapter and adapter._referenced_by_owner:
             self._reorder()
 
-    def pop(self, index=-1):
+    def pop(self, index: SupportsIndex = -1) -> _T:
         entity = super().pop(index)
         self._reorder()
         return entity
 
-    def __setitem__(self, index, entity):
+    @overload
+    def __setitem__(self, index: SupportsIndex, entity: _T) -> None: ...
+
+    @overload
+    def __setitem__(self, index: slice, entity: Iterable[_T]) -> None: ...
+
+    def __setitem__(
+        self,
+        index: Union[SupportsIndex, slice],
+        entity: Union[_T, Iterable[_T]],
+    ) -> None:
         if isinstance(index, slice):
             step = index.step or 1
             start = index.start or 0
@@ -381,26 +399,18 @@ class OrderingList(List[_T]):
             stop = index.stop or len(self)
             if stop < 0:
                 stop += len(self)
-
+            entities = list(entity)  # type: ignore[arg-type]
             for i in range(start, stop, step):
-                self.__setitem__(i, entity[i])
+                self.__setitem__(i, entities[i])
         else:
-            self._order_entity(index, entity, True)
-            super().__setitem__(index, entity)
+            self._order_entity(int(index), entity, True)  # type: ignore[arg-type] # noqa: E501
+            super().__setitem__(index, entity)  # type: ignore[assignment]
 
-    def __delitem__(self, index):
+    def __delitem__(self, index: Union[SupportsIndex, slice]) -> None:
         super().__delitem__(index)
         self._reorder()
 
-    def __setslice__(self, start, end, values):
-        super().__setslice__(start, end, values)
-        self._reorder()
-
-    def __delslice__(self, start, end):
-        super().__delslice__(start, end)
-        self._reorder()
-
-    def __reduce__(self):
+    def __reduce__(self) -> Any:
         return _reconstitute, (self.__class__, self.__dict__, list(self))
 
     for func_name, func in list(locals().items()):
@@ -414,7 +424,9 @@ class OrderingList(List[_T]):
     del func_name, func
 
 
-def _reconstitute(cls, dict_, items):
+def _reconstitute(
+    cls: Type[OrderingList[_T]], dict_: Dict[str, Any], items: List[_T]
+) -> OrderingList[_T]:
     """Reconstitute an :class:`.OrderingList`.
 
     This is the adjoint to :meth:`.OrderingList.__reduce__`.  It is used for
index 1b6cfbc087da315228009ce87dc9aab7cbd30b2a..1670e1cebc6bb992627459f37e89cc38b14ab42d 100644 (file)
@@ -318,7 +318,7 @@ class collection:
         return fn
 
     @staticmethod
-    def adds(arg):
+    def adds(arg: int) -> Callable[[_FN], _FN]:
         """Mark the method as adding an entity to the collection.
 
         Adds "add to collection" handling to the method.  The decorator
index 98e2a8207f9cb9bb3a07fc7346045d164da4ce5c..d3fd2a9a045e79a931b2fcb4c6ccc6dc417969bf 100644 (file)
@@ -322,7 +322,7 @@ class OrderingListTest(fixtures.MappedTest):
         s1 = Slide("Slide #1")
 
         # 1, 2, 3
-        s1.bullets[0:3] = b[0:3]
+        s1.bullets[0:3] = iter(b[0:3])
         for i in 0, 1, 2:
             self.assert_(s1.bullets[i].position == i)
             self.assert_(s1.bullets[i] == b[i])
@@ -490,3 +490,11 @@ class DummyItem:
 
     def __ne__(self, other):
         return not (self == other)
+
+
+class MockIndex:
+    def __init__(self, value):
+        self.value = value
+
+    def __index__(self):
+        return self.value
diff --git a/test/typing/plain_files/ext/orderinglist/orderinglist_one.py b/test/typing/plain_files/ext/orderinglist/orderinglist_one.py
new file mode 100644 (file)
index 0000000..d2b7c5e
--- /dev/null
@@ -0,0 +1,54 @@
+from __future__ import annotations
+
+import re
+from typing import Sequence
+from typing import TYPE_CHECKING
+
+from sqlalchemy import ForeignKey
+from sqlalchemy.ext.orderinglist import ordering_list
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
+from sqlalchemy.orm import relationship
+
+
+class Base(DeclarativeBase):
+    pass
+
+
+def text_to_pos(index: int, items: Sequence[Bullet]) -> int:
+    match = re.search(r"(\d+)", items[index].text)
+    return int(match[1]) if match else index
+
+
+pos_from_text = ordering_list("position", ordering_func=text_to_pos)
+
+
+class Slide(Base):
+    __tablename__ = "slide"
+
+    id: Mapped[int] = mapped_column(primary_key=True)
+    name: Mapped[str]
+
+    bullets: Mapped[list[Bullet]] = relationship(
+        "Bullet", order_by="Bullet.position", collection_class=pos_from_text
+    )
+
+
+class Bullet(Base):
+    __tablename__ = "bullet"
+    id: Mapped[int] = mapped_column(primary_key=True)
+    slide_id: Mapped[int] = mapped_column(ForeignKey("slide.id"))
+    position: Mapped[int]
+    text: Mapped[str]
+
+
+slide = Slide()
+
+
+if TYPE_CHECKING:
+    # EXPECTED_RE_TYPE: def \(\) -> sqlalchemy.*.orderinglist.OrderingList\[orderinglist_one.Bullet\]
+    reveal_type(pos_from_text)
+
+    # EXPECTED_TYPE: builtins.list[orderinglist_one.Bullet]
+    reveal_type(slide.bullets)