From: Martijn Pieters Date: Mon, 25 Nov 2024 19:38:48 +0000 (-0500) Subject: Add orderinglist type annotations X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=2907dededf5ca923add1106f8b71adb373d3088a;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Add orderinglist type annotations Closes: #10889 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/10889 Pull-request-sha: 2ddeeb190630a8965a4fae567e2649ed16722c99 Change-Id: I9a0d6e2776d8b6756af4a3c54668bdcd1a1f40f8 --- diff --git a/lib/sqlalchemy/ext/orderinglist.py b/lib/sqlalchemy/ext/orderinglist.py index 3cc67b1896..80bf688eaf 100644 --- a/lib/sqlalchemy/ext/orderinglist.py +++ b/lib/sqlalchemy/ext/orderinglist.py @@ -4,7 +4,6 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors """A custom list that manages index/position information for contained elements. @@ -129,17 +128,24 @@ start numbering at 1 or some other integer, provide ``count_from=1``. """ from __future__ import annotations +from typing import Any from typing import Callable +from typing import Dict +from typing import Iterable from typing import List from typing import Optional +from typing import overload from typing import Sequence +from typing import SupportsIndex +from typing import Type from typing import TypeVar +from typing import Union from ..orm.collections import collection from ..orm.collections import collection_adapter _T = TypeVar("_T") -OrderingFunc = Callable[[int, Sequence[_T]], int] +OrderingFunc = Callable[[int, Sequence[_T]], object] __all__ = ["ordering_list"] @@ -148,9 +154,9 @@ __all__ = ["ordering_list"] def ordering_list( attr: str, count_from: Optional[int] = None, - ordering_func: Optional[OrderingFunc] = None, + ordering_func: Optional[OrderingFunc[_T]] = None, reorder_on_append: bool = False, -) -> Callable[[], OrderingList]: +) -> Callable[[], OrderingList[_T]]: """Prepares an :class:`OrderingList` factory for use in mapper definitions. Returns an object suitable for use as an argument to a Mapper @@ -196,22 +202,22 @@ def ordering_list( # Ordering utility functions -def count_from_0(index, collection): +def count_from_0(index: int, collection: object) -> int: """Numbering function: consecutive integers starting at 0.""" return index -def count_from_1(index, collection): +def count_from_1(index: int, collection: object) -> int: """Numbering function: consecutive integers starting at 1.""" return index + 1 -def count_from_n_factory(start): +def count_from_n_factory(start: int) -> OrderingFunc[Any]: """Numbering function: consecutive integers starting at arbitrary start.""" - def f(index, collection): + def f(index: int, collection: object) -> int: return index + start try: @@ -221,7 +227,7 @@ def count_from_n_factory(start): return f -def _unsugar_count_from(**kw): +def _unsugar_count_from(**kw: Any) -> Dict[str, Any]: """Builds counting functions from keyword arguments. Keyword argument filter, prepares a simple ``ordering_func`` from a @@ -249,13 +255,13 @@ class OrderingList(List[_T]): """ ordering_attr: str - ordering_func: OrderingFunc + ordering_func: OrderingFunc[_T] reorder_on_append: bool def __init__( self, - ordering_attr: Optional[str] = None, - ordering_func: Optional[OrderingFunc] = None, + ordering_attr: str, + ordering_func: Optional[OrderingFunc[_T]] = None, reorder_on_append: bool = False, ): """A custom list that manages position information for its children. @@ -315,10 +321,10 @@ class OrderingList(List[_T]): # More complex serialization schemes (multi column, e.g.) are possible by # subclassing and reimplementing these two methods. - def _get_order_value(self, entity): + def _get_order_value(self, entity: _T) -> Any: return getattr(entity, self.ordering_attr) - def _set_order_value(self, entity, value): + def _set_order_value(self, entity: _T, value: Any) -> None: setattr(entity, self.ordering_attr, value) def reorder(self) -> None: @@ -334,7 +340,9 @@ class OrderingList(List[_T]): # As of 0.5, _reorder is no longer semi-private _reorder = reorder - def _order_entity(self, index, entity, reorder=True): + def _order_entity( + self, index: int, entity: _T, reorder: bool = True + ) -> None: have = self._get_order_value(entity) # Don't disturb existing ordering if reorder is False @@ -345,34 +353,44 @@ class OrderingList(List[_T]): if have != should_be: self._set_order_value(entity, should_be) - def append(self, entity): + def append(self, entity: _T) -> None: super().append(entity) self._order_entity(len(self) - 1, entity, self.reorder_on_append) - def _raw_append(self, entity): + def _raw_append(self, entity: _T) -> None: """Append without any ordering behavior.""" super().append(entity) _raw_append = collection.adds(1)(_raw_append) - def insert(self, index, entity): + def insert(self, index: SupportsIndex, entity: _T) -> None: super().insert(index, entity) self._reorder() - def remove(self, entity): + def remove(self, entity: _T) -> None: super().remove(entity) adapter = collection_adapter(self) if adapter and adapter._referenced_by_owner: self._reorder() - def pop(self, index=-1): + def pop(self, index: SupportsIndex = -1) -> _T: entity = super().pop(index) self._reorder() return entity - def __setitem__(self, index, entity): + @overload + def __setitem__(self, index: SupportsIndex, entity: _T) -> None: ... + + @overload + def __setitem__(self, index: slice, entity: Iterable[_T]) -> None: ... + + def __setitem__( + self, + index: Union[SupportsIndex, slice], + entity: Union[_T, Iterable[_T]], + ) -> None: if isinstance(index, slice): step = index.step or 1 start = index.start or 0 @@ -381,26 +399,18 @@ class OrderingList(List[_T]): stop = index.stop or len(self) if stop < 0: stop += len(self) - + entities = list(entity) # type: ignore[arg-type] for i in range(start, stop, step): - self.__setitem__(i, entity[i]) + self.__setitem__(i, entities[i]) else: - self._order_entity(index, entity, True) - super().__setitem__(index, entity) + self._order_entity(int(index), entity, True) # type: ignore[arg-type] # noqa: E501 + super().__setitem__(index, entity) # type: ignore[assignment] - def __delitem__(self, index): + def __delitem__(self, index: Union[SupportsIndex, slice]) -> None: super().__delitem__(index) self._reorder() - def __setslice__(self, start, end, values): - super().__setslice__(start, end, values) - self._reorder() - - def __delslice__(self, start, end): - super().__delslice__(start, end) - self._reorder() - - def __reduce__(self): + def __reduce__(self) -> Any: return _reconstitute, (self.__class__, self.__dict__, list(self)) for func_name, func in list(locals().items()): @@ -414,7 +424,9 @@ class OrderingList(List[_T]): del func_name, func -def _reconstitute(cls, dict_, items): +def _reconstitute( + cls: Type[OrderingList[_T]], dict_: Dict[str, Any], items: List[_T] +) -> OrderingList[_T]: """Reconstitute an :class:`.OrderingList`. This is the adjoint to :meth:`.OrderingList.__reduce__`. It is used for diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py index 1b6cfbc087..1670e1cebc 100644 --- a/lib/sqlalchemy/orm/collections.py +++ b/lib/sqlalchemy/orm/collections.py @@ -318,7 +318,7 @@ class collection: return fn @staticmethod - def adds(arg): + def adds(arg: int) -> Callable[[_FN], _FN]: """Mark the method as adding an entity to the collection. Adds "add to collection" handling to the method. The decorator diff --git a/test/ext/test_orderinglist.py b/test/ext/test_orderinglist.py index 98e2a8207f..d3fd2a9a04 100644 --- a/test/ext/test_orderinglist.py +++ b/test/ext/test_orderinglist.py @@ -322,7 +322,7 @@ class OrderingListTest(fixtures.MappedTest): s1 = Slide("Slide #1") # 1, 2, 3 - s1.bullets[0:3] = b[0:3] + s1.bullets[0:3] = iter(b[0:3]) for i in 0, 1, 2: self.assert_(s1.bullets[i].position == i) self.assert_(s1.bullets[i] == b[i]) @@ -490,3 +490,11 @@ class DummyItem: def __ne__(self, other): return not (self == other) + + +class MockIndex: + def __init__(self, value): + self.value = value + + def __index__(self): + return self.value diff --git a/test/typing/plain_files/ext/orderinglist/orderinglist_one.py b/test/typing/plain_files/ext/orderinglist/orderinglist_one.py new file mode 100644 index 0000000000..d2b7c5ece0 --- /dev/null +++ b/test/typing/plain_files/ext/orderinglist/orderinglist_one.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +import re +from typing import Sequence +from typing import TYPE_CHECKING + +from sqlalchemy import ForeignKey +from sqlalchemy.ext.orderinglist import ordering_list +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import relationship + + +class Base(DeclarativeBase): + pass + + +def text_to_pos(index: int, items: Sequence[Bullet]) -> int: + match = re.search(r"(\d+)", items[index].text) + return int(match[1]) if match else index + + +pos_from_text = ordering_list("position", ordering_func=text_to_pos) + + +class Slide(Base): + __tablename__ = "slide" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] + + bullets: Mapped[list[Bullet]] = relationship( + "Bullet", order_by="Bullet.position", collection_class=pos_from_text + ) + + +class Bullet(Base): + __tablename__ = "bullet" + id: Mapped[int] = mapped_column(primary_key=True) + slide_id: Mapped[int] = mapped_column(ForeignKey("slide.id")) + position: Mapped[int] + text: Mapped[str] + + +slide = Slide() + + +if TYPE_CHECKING: + # EXPECTED_RE_TYPE: def \(\) -> sqlalchemy.*.orderinglist.OrderingList\[orderinglist_one.Bullet\] + reveal_type(pos_from_text) + + # EXPECTED_TYPE: builtins.list[orderinglist_one.Bullet] + reveal_type(slide.bullets)