]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Fixes: #10673: make declared_attr covariant
authorLuiz Felipe Neves <luizfneves@proton.me>
Thu, 30 Apr 2026 17:06:25 +0000 (13:06 -0400)
committersqla-tester <sqla-tester@sqlalchemy.org>
Thu, 30 Apr 2026 17:06:25 +0000 (13:06 -0400)
<!-- Provide a general summary of your proposed changes in the Title field above -->

### Description
<!-- Describe your changes in detail -->
I made declared_attr covariant as suggested in #10673. mypy didn't seem to complain. Added a regression test for the use case that was asked for. Unfortunately, it seems like using `Mapped[int | UUID]` directly in the Protocol won't work:

```python
class CompareProtocol(Protocol):
    id: Mapped[int | UUID]
```

Because mypy will see this as a settable variable and not as a SQLAlchemy descriptor. Using `@property` instead seems to work and it's what I used in the test (perhaps it should be documented as the way to achieve this?):

```python
class CompareProtocol(Protocol):
    @property
    def id(self) -> Mapped[int | UUID]: ...
```

### Checklist
<!-- go over following points. check them with an `x` if they do apply, (they turn into clickable checkboxes once the PR is submitted, so no need to do everything at once)

-->

This pull request is:

- [ ] A documentation / typographical / small typing error fix
- Good to go, no issue or tests are needed
- [X] A short code fix
- please include the issue number, and create an issue if none exists, which
  must include a complete example of the issue.  one line code fixes without an
  issue and demonstration will not be accepted.
- Please include: `Fixes: #<issue number>` in the commit message
- please include tests.   one line code fixes without tests will not be accepted.
- [ ] A new feature implementation
- please include the issue number, and create an issue if none exists, which must
  include a complete example of how the feature would look.
- Please include: `Fixes: #<issue number>` in the commit message
- please include tests.

Closes: #13266
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/13266
Pull-request-sha: 8edd2841f4bbe61f8bb9bc15a7a57e0560698779

Change-Id: I7d63ad43df0ab34ee7c7389a007191be91efa574

lib/sqlalchemy/orm/decl_api.py
test/typing/plain_files/orm/declared_attr_three.py [new file with mode: 0644]

index 0df31d236a0a363920d63f65fef479fedee3b84f..505df8cfbe18cee0b6eb421a076d5b95b43df3ee 100644 (file)
@@ -96,6 +96,7 @@ if TYPE_CHECKING:
     from ..util.typing import _MatchedOnType
 
 _T = TypeVar("_T", bound=Any)
+_T_co = TypeVar("_T_co", bound=Any, covariant=True)
 
 _TT = TypeVar("_TT", bound=Any)
 
@@ -105,7 +106,7 @@ _TypeAnnotationMapType = Mapping[Any, "_TypeEngineArgument[Any]"]
 _MutableTypeAnnotationMapType = Dict[Any, "_TypeEngineArgument[Any]"]
 
 _DeclaredAttrDecorated = Callable[
-    ..., Union[Mapped[_T], ORMDescriptor[_T], SQLCoreOperations[_T]]
+    ..., Union[Mapped[_T_co], ORMDescriptor[_T_co], SQLCoreOperations[_T_co]]
 ]
 
 
@@ -330,7 +331,7 @@ class _declared_directive(_declared_attr_common, Generic[_T]):
             ...
 
 
-class declared_attr(interfaces._MappedAttribute[_T], _declared_attr_common):
+class declared_attr(interfaces._MappedAttribute[_T_co], _declared_attr_common):
     """Mark a class-level method as representing the definition of
     a mapped property or Declarative directive.
 
@@ -427,7 +428,7 @@ class declared_attr(interfaces._MappedAttribute[_T], _declared_attr_common):
 
         def __init__(
             self,
-            fn: _DeclaredAttrDecorated[_T],
+            fn: _DeclaredAttrDecorated[_T_co],
             cascading: bool = False,
         ): ...
 
@@ -442,17 +443,17 @@ class declared_attr(interfaces._MappedAttribute[_T], _declared_attr_common):
         @overload
         def __get__(
             self, instance: None, owner: Any
-        ) -> InstrumentedAttribute[_T]: ...
+        ) -> InstrumentedAttribute[_T_co]: ...
 
         @overload
-        def __get__(self, instance: object, owner: Any) -> _T: ...
+        def __get__(self, instance: object, owner: Any) -> _T_co: ...
 
         def __get__(
             self, instance: Optional[object], owner: Any
-        ) -> Union[InstrumentedAttribute[_T], _T]: ...
+        ) -> Union[InstrumentedAttribute[_T_co], _T_co]: ...
 
     @hybridmethod
-    def _stateful(cls, **kw: Any) -> _stateful_declared_attr[_T]:
+    def _stateful(cls, **kw: Any) -> _stateful_declared_attr[_T_co]:
         return _stateful_declared_attr(**kw)
 
     @hybridproperty
@@ -461,24 +462,26 @@ class declared_attr(interfaces._MappedAttribute[_T], _declared_attr_common):
         return _declared_directive  # type: ignore
 
     @hybridproperty
-    def cascading(cls) -> _stateful_declared_attr[_T]:
+    def cascading(cls) -> _stateful_declared_attr[_T_co]:
         # see mapping_api.rst for docstring
         return cls._stateful(cascading=True)
 
 
-class _stateful_declared_attr(declared_attr[_T]):
+class _stateful_declared_attr(declared_attr[_T_co]):
     kw: Dict[str, Any]
 
     def __init__(self, **kw: Any):
         self.kw = kw
 
     @hybridmethod
-    def _stateful(self, **kw: Any) -> _stateful_declared_attr[_T]:
+    def _stateful(self, **kw: Any) -> _stateful_declared_attr[_T_co]:
         new_kw = self.kw.copy()
         new_kw.update(kw)
         return _stateful_declared_attr(**new_kw)
 
-    def __call__(self, fn: _DeclaredAttrDecorated[_T]) -> declared_attr[_T]:
+    def __call__(
+        self, fn: _DeclaredAttrDecorated[_T_co]
+    ) -> declared_attr[_T_co]:
         return declared_attr(fn, **self.kw)
 
 
diff --git a/test/typing/plain_files/orm/declared_attr_three.py b/test/typing/plain_files/orm/declared_attr_three.py
new file mode 100644 (file)
index 0000000..b2b0c36
--- /dev/null
@@ -0,0 +1,64 @@
+# Regression tests for the declared_attr typing issue reported in #10673.
+
+import typing
+from typing import assert_type
+from typing import Protocol
+from uuid import UUID
+from uuid import uuid4
+
+import sqlalchemy as sa
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import declared_attr
+from sqlalchemy.orm import Mapped
+
+
+class ModelBase(DeclarativeBase):
+    pass
+
+
+class CompareProtocol(Protocol):
+    @property
+    def id(self) -> Mapped[int | UUID]: ...
+
+
+class CompareMixin:
+    def compare(self: CompareProtocol, other: CompareProtocol) -> bool:
+        return self.id == other.id
+
+
+class IntIdMixin:
+    @declared_attr
+    def id(cls) -> Mapped[int]:
+        return sa.orm.mapped_column(sa.Integer, primary_key=True)
+
+
+class UuidIdMixin:
+    @declared_attr
+    def id(cls) -> Mapped[UUID]:
+        return sa.orm.mapped_column(sa.UUID, primary_key=True, default=uuid4)
+
+
+class MyModel(CompareMixin, IntIdMixin, ModelBase):
+    __tablename__ = "my_model"
+
+
+class MyUuidModel(CompareMixin, UuidIdMixin, ModelBase):
+    __tablename__ = "my_uuid_model"
+
+
+m1 = MyModel()
+m2 = MyModel()
+u1 = MyUuidModel()
+
+
+def _int_id(cls: type[object]) -> Mapped[int]:
+    return sa.orm.mapped_column(sa.Integer, primary_key=True)
+
+
+int_id_attr: declared_attr[int] = declared_attr(_int_id)
+union_id_attr: declared_attr[int | UUID] = int_id_attr
+assert union_id_attr
+
+if typing.TYPE_CHECKING:
+    assert_type(m1.compare(m2), bool)
+    assert_type(m1.compare(u1), bool)