]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add type annotations to indexable extension code 12763/head
authorDenis Laxalde <denis@laxalde.org>
Mon, 28 Jul 2025 08:02:25 +0000 (10:02 +0200)
committerDenis Laxalde <denis@laxalde.org>
Mon, 28 Jul 2025 12:48:14 +0000 (14:48 +0200)
A typing test case (plain_files/ext/indexable.py) is also added.

In order to make the methods of index_property conform with type
definitions of `fget`, `fset` and `fdel` arguments of hybrid_property,
we need to make the signature of protocols (e.g. `_HybridGetterType`)
`__call__`) method positional only.

Related to #6810.

lib/sqlalchemy/ext/hybrid.py
lib/sqlalchemy/ext/indexable.py
test/typing/plain_files/ext/indexable.py [new file with mode: 0644]

index cbf5e591c1b175115d27c1b4337295fde6342881..c28d5faf200efd1bd328e3af3b9e0200e38e12c3 100644 (file)
@@ -923,11 +923,11 @@ class HybridExtensionType(InspectionAttrExtensionType):
 
 
 class _HybridGetterType(Protocol[_T_co]):
-    def __call__(s, self: Any) -> _T_co: ...
+    def __call__(s, self: Any, /) -> _T_co: ...
 
 
 class _HybridSetterType(Protocol[_T_con]):
-    def __call__(s, self: Any, value: _T_con) -> None: ...
+    def __call__(s, self: Any, value: _T_con, /) -> None: ...
 
 
 class _HybridUpdaterType(Protocol[_T_con]):
@@ -939,12 +939,12 @@ class _HybridUpdaterType(Protocol[_T_con]):
 
 
 class _HybridDeleterType(Protocol[_T_co]):
-    def __call__(s, self: Any) -> None: ...
+    def __call__(s, self: Any, /) -> None: ...
 
 
 class _HybridExprCallableType(Protocol[_T_co]):
     def __call__(
-        s, cls: Any
+        s, cls: Any, /
     ) -> Union[_HasClauseElement[_T_co], SQLColumnExpression[_T_co]]: ...
 
 
index 883d9742078cac9314154d742877ca40c41b7b5c..01fcfd6a890884b1126a41bbd6f9eb7e1bcc02f4 100644 (file)
@@ -4,7 +4,6 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
 
 """Define attributes on ORM-mapped classes that have "index" attributes for
 columns with :class:`_types.Indexable` types.
@@ -224,15 +223,32 @@ The above query will render:
     WHERE CAST(person.data ->> %(data_1)s AS INTEGER) < %(param_1)s
 
 """  # noqa
+
+from __future__ import annotations
+
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import TypeVar
+from typing import Union
+
 from .. import inspect
 from ..ext.hybrid import hybrid_property
 from ..orm.attributes import flag_modified
 
+if TYPE_CHECKING:
+    from ..sql import SQLColumnExpression
+    from ..sql._typing import _HasClauseElement
+
 
 __all__ = ["index_property"]
 
+_T = TypeVar("_T")
+
 
-class index_property(hybrid_property):  # noqa
+class index_property(hybrid_property[_T]):
     """A property generator. The generated property describes an object
     attribute that corresponds to an :class:`_types.Indexable`
     column.
@@ -243,16 +259,16 @@ class index_property(hybrid_property):  # noqa
 
     """
 
-    _NO_DEFAULT_ARGUMENT = object()
+    _NO_DEFAULT_ARGUMENT = cast(_T, object())
 
     def __init__(
         self,
-        attr_name,
-        index,
-        default=_NO_DEFAULT_ARGUMENT,
-        datatype=None,
-        mutable=True,
-        onebased=True,
+        attr_name: str,
+        index: Union[int, str],
+        default: _T = _NO_DEFAULT_ARGUMENT,
+        datatype: Optional[Callable[[], Any]] = None,
+        mutable: bool = True,
+        onebased: bool = True,
     ):
         """Create a new :class:`.index_property`.
 
@@ -291,18 +307,18 @@ class index_property(hybrid_property):  # noqa
             self.datatype = datatype
         else:
             if is_numeric:
-                self.datatype = lambda: [None for x in range(index + 1)]
+                self.datatype = lambda: [None for x in range(index + 1)]  # type: ignore[operator]  # noqa: E501
             else:
                 self.datatype = dict
         self.onebased = onebased
 
-    def _fget_default(self, err=None):
+    def _fget_default(self, err: Optional[BaseException] = None) -> _T:
         if self.default == self._NO_DEFAULT_ARGUMENT:
             raise AttributeError(self.attr_name) from err
         else:
             return self.default
 
-    def fget(self, instance):
+    def fget(self, instance: Any, /) -> _T:
         attr_name = self.attr_name
         column_value = getattr(instance, attr_name)
         if column_value is None:
@@ -312,9 +328,9 @@ class index_property(hybrid_property):  # noqa
         except (KeyError, IndexError) as err:
             return self._fget_default(err)
         else:
-            return value
+            return value  # type: ignore[no-any-return]
 
-    def fset(self, instance, value):
+    def fset(self, instance: Any, value: _T) -> None:
         attr_name = self.attr_name
         column_value = getattr(instance, attr_name, None)
         if column_value is None:
@@ -325,7 +341,7 @@ class index_property(hybrid_property):  # noqa
         if attr_name in inspect(instance).mapper.attrs:
             flag_modified(instance, attr_name)
 
-    def fdel(self, instance):
+    def fdel(self, instance: Any) -> None:
         attr_name = self.attr_name
         column_value = getattr(instance, attr_name)
         if column_value is None:
@@ -338,9 +354,13 @@ class index_property(hybrid_property):  # noqa
             setattr(instance, attr_name, column_value)
             flag_modified(instance, attr_name)
 
-    def expr(self, model):
+    def expr(
+        self, model: Any
+    ) -> Union[_HasClauseElement[_T], SQLColumnExpression[_T]]:
         column = getattr(model, self.attr_name)
         index = self.index
         if self.onebased:
+            if TYPE_CHECKING:
+                assert isinstance(index, int)
             index += 1
-        return column[index]
+        return column[index]  # type: ignore[no-any-return]
diff --git a/test/typing/plain_files/ext/indexable.py b/test/typing/plain_files/ext/indexable.py
new file mode 100644 (file)
index 0000000..c6c1c35
--- /dev/null
@@ -0,0 +1,66 @@
+from __future__ import annotations
+
+from datetime import date
+from typing import Dict
+from typing import List
+
+from sqlalchemy import ARRAY
+from sqlalchemy import JSON
+from sqlalchemy import select
+from sqlalchemy.ext.hybrid import hybrid_property
+from sqlalchemy.ext.indexable import index_property
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
+
+
+class Base(DeclarativeBase):
+    pass
+
+
+class Article(Base):
+    __tablename__ = "articles"
+
+    id: Mapped[int] = mapped_column(primary_key=True)
+
+    tags: Mapped[Dict[str, str]] = mapped_column(JSON)
+    topic: hybrid_property[str] = index_property("tags", "topic")
+
+    updates: Mapped[List[date]] = mapped_column(ARRAY[date])
+    created_at = index_property(
+        "updates", 0, mutable=True, default=date.today()
+    )
+    updated_at: hybrid_property[date] = index_property("updates", -1)
+
+
+a = Article(
+    tags={"topic": "database", "subject": "programming"},
+    updates=[date(2025, 7, 28), date(2025, 7, 29)],
+)
+
+# EXPECTED_TYPE: str
+reveal_type(a.topic)
+
+# EXPECTED_RE_TYPE: sqlalchemy.*._HybridClassLevelAccessor\[builtins.str\*?\]
+reveal_type(Article.topic)
+
+# EXPECTED_TYPE: date
+reveal_type(a.created_at)
+
+# EXPECTED_TYPE: date
+reveal_type(a.updated_at)
+
+a.created_at = date(2025, 7, 30)
+
+# EXPECTED_RE_TYPE: sqlalchemy.*._HybridClassLevelAccessor\[datetime.date\*?\]
+reveal_type(Article.created_at)
+
+# EXPECTED_RE_TYPE: sqlalchemy.*._HybridClassLevelAccessor\[datetime.date\*?\]
+reveal_type(Article.updated_at)
+
+stmt = select(Article.id, Article.topic, Article.created_at).where(
+    Article.id == 1
+)
+
+# EXPECTED_RE_TYPE: .*Select\[.*int, .*str, datetime\.date\]
+reveal_type(stmt)