From: Luiz Felipe Neves Date: Thu, 30 Apr 2026 17:06:25 +0000 (-0400) Subject: Fixes: #10673: make declared_attr covariant X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=f1dfe10237d754ae25b9790c9a9e4d9defccb52b;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Fixes: #10673: make declared_attr covariant ### Description I made declared_attr covariant as suggested in #10673. mypy didn't seem to complain. Added a regression test for the use case that was asked for. Unfortunately, it seems like using `Mapped[int | UUID]` directly in the Protocol won't work: ```python class CompareProtocol(Protocol): id: Mapped[int | UUID] ``` Because mypy will see this as a settable variable and not as a SQLAlchemy descriptor. Using `@property` instead seems to work and it's what I used in the test (perhaps it should be documented as the way to achieve this?): ```python class CompareProtocol(Protocol): @property def id(self) -> Mapped[int | UUID]: ... ``` ### Checklist This pull request is: - [ ] A documentation / typographical / small typing error fix - Good to go, no issue or tests are needed - [X] A short code fix - please include the issue number, and create an issue if none exists, which must include a complete example of the issue. one line code fixes without an issue and demonstration will not be accepted. - Please include: `Fixes: #` in the commit message - please include tests. one line code fixes without tests will not be accepted. - [ ] A new feature implementation - please include the issue number, and create an issue if none exists, which must include a complete example of how the feature would look. - Please include: `Fixes: #` in the commit message - please include tests. Closes: #13266 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/13266 Pull-request-sha: 8edd2841f4bbe61f8bb9bc15a7a57e0560698779 Change-Id: I7d63ad43df0ab34ee7c7389a007191be91efa574 --- diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py index 0df31d236a..505df8cfbe 100644 --- a/lib/sqlalchemy/orm/decl_api.py +++ b/lib/sqlalchemy/orm/decl_api.py @@ -96,6 +96,7 @@ if TYPE_CHECKING: from ..util.typing import _MatchedOnType _T = TypeVar("_T", bound=Any) +_T_co = TypeVar("_T_co", bound=Any, covariant=True) _TT = TypeVar("_TT", bound=Any) @@ -105,7 +106,7 @@ _TypeAnnotationMapType = Mapping[Any, "_TypeEngineArgument[Any]"] _MutableTypeAnnotationMapType = Dict[Any, "_TypeEngineArgument[Any]"] _DeclaredAttrDecorated = Callable[ - ..., Union[Mapped[_T], ORMDescriptor[_T], SQLCoreOperations[_T]] + ..., Union[Mapped[_T_co], ORMDescriptor[_T_co], SQLCoreOperations[_T_co]] ] @@ -330,7 +331,7 @@ class _declared_directive(_declared_attr_common, Generic[_T]): ... -class declared_attr(interfaces._MappedAttribute[_T], _declared_attr_common): +class declared_attr(interfaces._MappedAttribute[_T_co], _declared_attr_common): """Mark a class-level method as representing the definition of a mapped property or Declarative directive. @@ -427,7 +428,7 @@ class declared_attr(interfaces._MappedAttribute[_T], _declared_attr_common): def __init__( self, - fn: _DeclaredAttrDecorated[_T], + fn: _DeclaredAttrDecorated[_T_co], cascading: bool = False, ): ... @@ -442,17 +443,17 @@ class declared_attr(interfaces._MappedAttribute[_T], _declared_attr_common): @overload def __get__( self, instance: None, owner: Any - ) -> InstrumentedAttribute[_T]: ... + ) -> InstrumentedAttribute[_T_co]: ... @overload - def __get__(self, instance: object, owner: Any) -> _T: ... + def __get__(self, instance: object, owner: Any) -> _T_co: ... def __get__( self, instance: Optional[object], owner: Any - ) -> Union[InstrumentedAttribute[_T], _T]: ... + ) -> Union[InstrumentedAttribute[_T_co], _T_co]: ... @hybridmethod - def _stateful(cls, **kw: Any) -> _stateful_declared_attr[_T]: + def _stateful(cls, **kw: Any) -> _stateful_declared_attr[_T_co]: return _stateful_declared_attr(**kw) @hybridproperty @@ -461,24 +462,26 @@ class declared_attr(interfaces._MappedAttribute[_T], _declared_attr_common): return _declared_directive # type: ignore @hybridproperty - def cascading(cls) -> _stateful_declared_attr[_T]: + def cascading(cls) -> _stateful_declared_attr[_T_co]: # see mapping_api.rst for docstring return cls._stateful(cascading=True) -class _stateful_declared_attr(declared_attr[_T]): +class _stateful_declared_attr(declared_attr[_T_co]): kw: Dict[str, Any] def __init__(self, **kw: Any): self.kw = kw @hybridmethod - def _stateful(self, **kw: Any) -> _stateful_declared_attr[_T]: + def _stateful(self, **kw: Any) -> _stateful_declared_attr[_T_co]: new_kw = self.kw.copy() new_kw.update(kw) return _stateful_declared_attr(**new_kw) - def __call__(self, fn: _DeclaredAttrDecorated[_T]) -> declared_attr[_T]: + def __call__( + self, fn: _DeclaredAttrDecorated[_T_co] + ) -> declared_attr[_T_co]: return declared_attr(fn, **self.kw) diff --git a/test/typing/plain_files/orm/declared_attr_three.py b/test/typing/plain_files/orm/declared_attr_three.py new file mode 100644 index 0000000000..b2b0c36a9b --- /dev/null +++ b/test/typing/plain_files/orm/declared_attr_three.py @@ -0,0 +1,64 @@ +# Regression tests for the declared_attr typing issue reported in #10673. + +import typing +from typing import assert_type +from typing import Protocol +from uuid import UUID +from uuid import uuid4 + +import sqlalchemy as sa +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import declared_attr +from sqlalchemy.orm import Mapped + + +class ModelBase(DeclarativeBase): + pass + + +class CompareProtocol(Protocol): + @property + def id(self) -> Mapped[int | UUID]: ... + + +class CompareMixin: + def compare(self: CompareProtocol, other: CompareProtocol) -> bool: + return self.id == other.id + + +class IntIdMixin: + @declared_attr + def id(cls) -> Mapped[int]: + return sa.orm.mapped_column(sa.Integer, primary_key=True) + + +class UuidIdMixin: + @declared_attr + def id(cls) -> Mapped[UUID]: + return sa.orm.mapped_column(sa.UUID, primary_key=True, default=uuid4) + + +class MyModel(CompareMixin, IntIdMixin, ModelBase): + __tablename__ = "my_model" + + +class MyUuidModel(CompareMixin, UuidIdMixin, ModelBase): + __tablename__ = "my_uuid_model" + + +m1 = MyModel() +m2 = MyModel() +u1 = MyUuidModel() + + +def _int_id(cls: type[object]) -> Mapped[int]: + return sa.orm.mapped_column(sa.Integer, primary_key=True) + + +int_id_attr: declared_attr[int] = declared_attr(_int_id) +union_id_attr: declared_attr[int | UUID] = int_id_attr +assert union_id_attr + +if typing.TYPE_CHECKING: + assert_type(m1.compare(m2), bool) + assert_type(m1.compare(u1), bool)