From: Denis Laxalde Date: Mon, 28 Jul 2025 08:02:25 +0000 (+0200) Subject: Add type annotations to indexable extension code X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=refs%2Fpull%2F12763%2Fhead;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Add type annotations to indexable extension code A typing test case (plain_files/ext/indexable.py) is also added. In order to make the methods of index_property conform with type definitions of `fget`, `fset` and `fdel` arguments of hybrid_property, we need to make the signature of protocols (e.g. `_HybridGetterType`) `__call__`) method positional only. Related to #6810. --- diff --git a/lib/sqlalchemy/ext/hybrid.py b/lib/sqlalchemy/ext/hybrid.py index cbf5e591c1..c28d5faf20 100644 --- a/lib/sqlalchemy/ext/hybrid.py +++ b/lib/sqlalchemy/ext/hybrid.py @@ -923,11 +923,11 @@ class HybridExtensionType(InspectionAttrExtensionType): class _HybridGetterType(Protocol[_T_co]): - def __call__(s, self: Any) -> _T_co: ... + def __call__(s, self: Any, /) -> _T_co: ... class _HybridSetterType(Protocol[_T_con]): - def __call__(s, self: Any, value: _T_con) -> None: ... + def __call__(s, self: Any, value: _T_con, /) -> None: ... class _HybridUpdaterType(Protocol[_T_con]): @@ -939,12 +939,12 @@ class _HybridUpdaterType(Protocol[_T_con]): class _HybridDeleterType(Protocol[_T_co]): - def __call__(s, self: Any) -> None: ... + def __call__(s, self: Any, /) -> None: ... class _HybridExprCallableType(Protocol[_T_co]): def __call__( - s, cls: Any + s, cls: Any, / ) -> Union[_HasClauseElement[_T_co], SQLColumnExpression[_T_co]]: ... diff --git a/lib/sqlalchemy/ext/indexable.py b/lib/sqlalchemy/ext/indexable.py index 883d974207..01fcfd6a89 100644 --- a/lib/sqlalchemy/ext/indexable.py +++ b/lib/sqlalchemy/ext/indexable.py @@ -4,7 +4,6 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors """Define attributes on ORM-mapped classes that have "index" attributes for columns with :class:`_types.Indexable` types. @@ -224,15 +223,32 @@ The above query will render: WHERE CAST(person.data ->> %(data_1)s AS INTEGER) < %(param_1)s """ # noqa + +from __future__ import annotations + +from typing import Any +from typing import Callable +from typing import cast +from typing import Optional +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + from .. import inspect from ..ext.hybrid import hybrid_property from ..orm.attributes import flag_modified +if TYPE_CHECKING: + from ..sql import SQLColumnExpression + from ..sql._typing import _HasClauseElement + __all__ = ["index_property"] +_T = TypeVar("_T") + -class index_property(hybrid_property): # noqa +class index_property(hybrid_property[_T]): """A property generator. The generated property describes an object attribute that corresponds to an :class:`_types.Indexable` column. @@ -243,16 +259,16 @@ class index_property(hybrid_property): # noqa """ - _NO_DEFAULT_ARGUMENT = object() + _NO_DEFAULT_ARGUMENT = cast(_T, object()) def __init__( self, - attr_name, - index, - default=_NO_DEFAULT_ARGUMENT, - datatype=None, - mutable=True, - onebased=True, + attr_name: str, + index: Union[int, str], + default: _T = _NO_DEFAULT_ARGUMENT, + datatype: Optional[Callable[[], Any]] = None, + mutable: bool = True, + onebased: bool = True, ): """Create a new :class:`.index_property`. @@ -291,18 +307,18 @@ class index_property(hybrid_property): # noqa self.datatype = datatype else: if is_numeric: - self.datatype = lambda: [None for x in range(index + 1)] + self.datatype = lambda: [None for x in range(index + 1)] # type: ignore[operator] # noqa: E501 else: self.datatype = dict self.onebased = onebased - def _fget_default(self, err=None): + def _fget_default(self, err: Optional[BaseException] = None) -> _T: if self.default == self._NO_DEFAULT_ARGUMENT: raise AttributeError(self.attr_name) from err else: return self.default - def fget(self, instance): + def fget(self, instance: Any, /) -> _T: attr_name = self.attr_name column_value = getattr(instance, attr_name) if column_value is None: @@ -312,9 +328,9 @@ class index_property(hybrid_property): # noqa except (KeyError, IndexError) as err: return self._fget_default(err) else: - return value + return value # type: ignore[no-any-return] - def fset(self, instance, value): + def fset(self, instance: Any, value: _T) -> None: attr_name = self.attr_name column_value = getattr(instance, attr_name, None) if column_value is None: @@ -325,7 +341,7 @@ class index_property(hybrid_property): # noqa if attr_name in inspect(instance).mapper.attrs: flag_modified(instance, attr_name) - def fdel(self, instance): + def fdel(self, instance: Any) -> None: attr_name = self.attr_name column_value = getattr(instance, attr_name) if column_value is None: @@ -338,9 +354,13 @@ class index_property(hybrid_property): # noqa setattr(instance, attr_name, column_value) flag_modified(instance, attr_name) - def expr(self, model): + def expr( + self, model: Any + ) -> Union[_HasClauseElement[_T], SQLColumnExpression[_T]]: column = getattr(model, self.attr_name) index = self.index if self.onebased: + if TYPE_CHECKING: + assert isinstance(index, int) index += 1 - return column[index] + return column[index] # type: ignore[no-any-return] diff --git a/test/typing/plain_files/ext/indexable.py b/test/typing/plain_files/ext/indexable.py new file mode 100644 index 0000000000..c6c1c35299 --- /dev/null +++ b/test/typing/plain_files/ext/indexable.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from datetime import date +from typing import Dict +from typing import List + +from sqlalchemy import ARRAY +from sqlalchemy import JSON +from sqlalchemy import select +from sqlalchemy.ext.hybrid import hybrid_property +from sqlalchemy.ext.indexable import index_property +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column + + +class Base(DeclarativeBase): + pass + + +class Article(Base): + __tablename__ = "articles" + + id: Mapped[int] = mapped_column(primary_key=True) + + tags: Mapped[Dict[str, str]] = mapped_column(JSON) + topic: hybrid_property[str] = index_property("tags", "topic") + + updates: Mapped[List[date]] = mapped_column(ARRAY[date]) + created_at = index_property( + "updates", 0, mutable=True, default=date.today() + ) + updated_at: hybrid_property[date] = index_property("updates", -1) + + +a = Article( + tags={"topic": "database", "subject": "programming"}, + updates=[date(2025, 7, 28), date(2025, 7, 29)], +) + +# EXPECTED_TYPE: str +reveal_type(a.topic) + +# EXPECTED_RE_TYPE: sqlalchemy.*._HybridClassLevelAccessor\[builtins.str\*?\] +reveal_type(Article.topic) + +# EXPECTED_TYPE: date +reveal_type(a.created_at) + +# EXPECTED_TYPE: date +reveal_type(a.updated_at) + +a.created_at = date(2025, 7, 30) + +# EXPECTED_RE_TYPE: sqlalchemy.*._HybridClassLevelAccessor\[datetime.date\*?\] +reveal_type(Article.created_at) + +# EXPECTED_RE_TYPE: sqlalchemy.*._HybridClassLevelAccessor\[datetime.date\*?\] +reveal_type(Article.updated_at) + +stmt = select(Article.id, Article.topic, Article.created_at).where( + Article.id == 1 +) + +# EXPECTED_RE_TYPE: .*Select\[.*int, .*str, datetime\.date\] +reveal_type(stmt)