from ..sql.type_api import _MatchedOnType
_T = TypeVar("_T", bound=Any)
+_T_co = TypeVar("_T_co", bound=Any, covariant=True)
_TT = TypeVar("_TT", bound=Any)
_MutableTypeAnnotationMapType = Dict[Any, "_TypeEngineArgument[Any]"]
_DeclaredAttrDecorated = Callable[
- ..., Union[Mapped[_T], ORMDescriptor[_T], SQLCoreOperations[_T]]
+ ..., Union[Mapped[_T_co], ORMDescriptor[_T_co], SQLCoreOperations[_T_co]]
]
...
-class declared_attr(interfaces._MappedAttribute[_T], _declared_attr_common):
+class declared_attr(interfaces._MappedAttribute[_T_co], _declared_attr_common):
"""Mark a class-level method as representing the definition of
a mapped property or Declarative directive.
def __init__(
self,
- fn: _DeclaredAttrDecorated[_T],
+ fn: _DeclaredAttrDecorated[_T_co],
cascading: bool = False,
): ...
@overload
def __get__(
self, instance: None, owner: Any
- ) -> InstrumentedAttribute[_T]: ...
+ ) -> InstrumentedAttribute[_T_co]: ...
@overload
- def __get__(self, instance: object, owner: Any) -> _T: ...
+ def __get__(self, instance: object, owner: Any) -> _T_co: ...
def __get__(
self, instance: Optional[object], owner: Any
- ) -> Union[InstrumentedAttribute[_T], _T]: ...
+ ) -> Union[InstrumentedAttribute[_T_co], _T_co]: ...
@hybridmethod
- def _stateful(cls, **kw: Any) -> _stateful_declared_attr[_T]:
+ def _stateful(cls, **kw: Any) -> _stateful_declared_attr[_T_co]:
return _stateful_declared_attr(**kw)
@hybridproperty
return _declared_directive # type: ignore
@hybridproperty
- def cascading(cls) -> _stateful_declared_attr[_T]:
+ def cascading(cls) -> _stateful_declared_attr[_T_co]:
# see mapping_api.rst for docstring
return cls._stateful(cascading=True)
-class _stateful_declared_attr(declared_attr[_T]):
+class _stateful_declared_attr(declared_attr[_T_co]):
kw: Dict[str, Any]
def __init__(self, **kw: Any):
self.kw = kw
@hybridmethod
- def _stateful(self, **kw: Any) -> _stateful_declared_attr[_T]:
+ def _stateful(self, **kw: Any) -> _stateful_declared_attr[_T_co]:
new_kw = self.kw.copy()
new_kw.update(kw)
return _stateful_declared_attr(**new_kw)
- def __call__(self, fn: _DeclaredAttrDecorated[_T]) -> declared_attr[_T]:
+ def __call__(
+ self, fn: _DeclaredAttrDecorated[_T_co]
+ ) -> declared_attr[_T_co]:
return declared_attr(fn, **self.kw)
--- /dev/null
+# Regression tests for the declared_attr typing issue reported in #10673.
+
+import typing
+from typing import assert_type
+from typing import Protocol
+from uuid import UUID
+from uuid import uuid4
+
+import sqlalchemy as sa
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import declared_attr
+from sqlalchemy.orm import Mapped
+
+
+class ModelBase(DeclarativeBase):
+ pass
+
+
+class CompareProtocol(Protocol):
+ @property
+ def id(self) -> Mapped[int | UUID]: ...
+
+
+class CompareMixin:
+ def compare(self: CompareProtocol, other: CompareProtocol) -> bool:
+ return self.id == other.id
+
+
+class IntIdMixin:
+ @declared_attr
+ def id(cls) -> Mapped[int]:
+ return sa.orm.mapped_column(sa.Integer, primary_key=True)
+
+
+class UuidIdMixin:
+ @declared_attr
+ def id(cls) -> Mapped[UUID]:
+ return sa.orm.mapped_column(sa.UUID, primary_key=True, default=uuid4)
+
+
+class MyModel(CompareMixin, IntIdMixin, ModelBase):
+ __tablename__ = "my_model"
+
+
+class MyUuidModel(CompareMixin, UuidIdMixin, ModelBase):
+ __tablename__ = "my_uuid_model"
+
+
+m1 = MyModel()
+m2 = MyModel()
+u1 = MyUuidModel()
+
+
+def _int_id(cls: type[object]) -> Mapped[int]:
+ return sa.orm.mapped_column(sa.Integer, primary_key=True)
+
+
+int_id_attr: declared_attr[int] = declared_attr(_int_id)
+union_id_attr: declared_attr[int | UUID] = int_id_attr
+assert union_id_attr
+
+if typing.TYPE_CHECKING:
+ assert_type(m1.compare(m2), bool)
+ assert_type(m1.compare(u1), bool)