--- /dev/null
+.. change::
+ :tags: usecase, typing
+ :tickets: 10288
+
+ Made the contained type for :class:`.Mapped` covariant; this is to allow
+ greater flexibility for end-user typing scenarios, such as the use of
+ protocols to represent particular mapped class structures that are passed
+ to other functions. As part of this change, the contained type was also
+ made covariant for dependent and related types such as
+ :class:`_orm.base.SQLORMOperations`, :class:`_orm.WriteOnlyMapped`, and
+ :class:`_sql.SQLColumnExpression`. Pull request courtesy Roméo Després.
+
...
-class _HybridExprCallableType(Protocol[_T]):
+class _HybridExprCallableType(Protocol[_T_co]):
def __call__(
s, cls: Any
- ) -> Union[_HasClauseElement, SQLColumnExpression[_T]]:
+ ) -> Union[_HasClauseElement, SQLColumnExpression[_T_co]]:
...
_T = TypeVar("_T")
+_T_co = TypeVar("_T_co", bound=Any, covariant=True)
_AllPendingType = Sequence[
@inspection._self_inspects
class QueryableAttribute(
- _DeclarativeMapped[_T],
- SQLORMExpression[_T],
+ _DeclarativeMapped[_T_co],
+ SQLORMExpression[_T_co],
interfaces.InspectionAttr,
- interfaces.PropComparator[_T],
+ interfaces.PropComparator[_T_co],
roles.JoinTargetRole,
roles.OnClauseRole,
sql_base.Immutable,
is_attribute = True
- dispatch: dispatcher[QueryableAttribute[_T]]
+ dispatch: dispatcher[QueryableAttribute[_T_co]]
class_: _ExternalEntityType[Any]
key: str
parententity: _InternalEntityType[Any]
impl: AttributeImpl
- comparator: interfaces.PropComparator[_T]
+ comparator: interfaces.PropComparator[_T_co]
_of_type: Optional[_InternalEntityType[Any]]
_extra_criteria: Tuple[ColumnElement[bool], ...]
_doc: Optional[str]
class_: _ExternalEntityType[_O],
key: str,
parententity: _InternalEntityType[_O],
- comparator: interfaces.PropComparator[_T],
+ comparator: interfaces.PropComparator[_T_co],
impl: Optional[AttributeImpl] = None,
of_type: Optional[_InternalEntityType[Any]] = None,
extra_criteria: Tuple[ColumnElement[bool], ...] = (),
"""
- expression: ColumnElement[_T]
+ expression: ColumnElement[_T_co]
"""The SQL expression object represented by this
:class:`.QueryableAttribute`.
def _annotations(self) -> _AnnotationDict:
return self.__clause_element__()._annotations
- def __clause_element__(self) -> ColumnElement[_T]:
+ def __clause_element__(self) -> ColumnElement[_T_co]:
return self.expression
@property
extra_criteria=self._extra_criteria,
)
- def label(self, name: Optional[str]) -> Label[_T]:
+ def label(self, name: Optional[str]) -> Label[_T_co]:
return self.__clause_element__().label(name)
def operate(
from ..sql.operators import OperatorType
_T = TypeVar("_T", bound=Any)
+_T_co = TypeVar("_T_co", bound=Any, covariant=True)
_O = TypeVar("_O", bound=object)
return {}
-class SQLORMOperations(SQLCoreOperations[_T], TypingOnly):
+class SQLORMOperations(SQLCoreOperations[_T_co], TypingOnly):
__slots__ = ()
if typing.TYPE_CHECKING:
- def of_type(self, class_: _EntityType[Any]) -> PropComparator[_T]:
+ def of_type(self, class_: _EntityType[Any]) -> PropComparator[_T_co]:
...
def and_(
...
-class ORMDescriptor(Generic[_T], TypingOnly):
+class ORMDescriptor(Generic[_T_co], TypingOnly):
"""Represent any Python descriptor that provides a SQL expression
construct at the class level."""
@overload
def __get__(
self, instance: Any, owner: Literal[None]
- ) -> ORMDescriptor[_T]:
+ ) -> ORMDescriptor[_T_co]:
...
@overload
def __get__(
self, instance: Literal[None], owner: Any
- ) -> SQLCoreOperations[_T]:
+ ) -> SQLCoreOperations[_T_co]:
...
@overload
- def __get__(self, instance: object, owner: Any) -> _T:
+ def __get__(self, instance: object, owner: Any) -> _T_co:
...
def __get__(
self, instance: object, owner: Any
- ) -> Union[ORMDescriptor[_T], SQLCoreOperations[_T], _T]:
+ ) -> Union[ORMDescriptor[_T_co], SQLCoreOperations[_T_co], _T_co]:
...
-class _MappedAnnotationBase(Generic[_T], TypingOnly):
+class _MappedAnnotationBase(Generic[_T_co], TypingOnly):
"""common class for Mapped and similar ORM container classes.
these are classes that can appear on the left side of an ORM declarative
class SQLORMExpression(
- SQLORMOperations[_T], SQLColumnExpression[_T], TypingOnly
+ SQLORMOperations[_T_co], SQLColumnExpression[_T_co], TypingOnly
):
"""A type that may be used to indicate any ORM-level attribute or
object that acts in place of one, in the context of SQL expression
class Mapped(
- SQLORMExpression[_T],
- ORMDescriptor[_T],
- _MappedAnnotationBase[_T],
+ SQLORMExpression[_T_co],
+ ORMDescriptor[_T_co],
+ _MappedAnnotationBase[_T_co],
roles.DDLConstraintColumnRole,
):
"""Represent an ORM mapped attribute on a mapped class.
@overload
def __get__(
self, instance: None, owner: Any
- ) -> InstrumentedAttribute[_T]:
+ ) -> InstrumentedAttribute[_T_co]:
...
@overload
- def __get__(self, instance: object, owner: Any) -> _T:
+ def __get__(self, instance: object, owner: Any) -> _T_co:
...
def __get__(
self, instance: Optional[object], owner: Any
- ) -> Union[InstrumentedAttribute[_T], _T]:
+ ) -> Union[InstrumentedAttribute[_T_co], _T_co]:
...
@classmethod
- def _empty_constructor(cls, arg1: Any) -> Mapped[_T]:
+ def _empty_constructor(cls, arg1: Any) -> Mapped[_T_co]:
...
def __set__(
- self, instance: Any, value: Union[SQLCoreOperations[_T], _T]
+ self, instance: Any, value: Union[SQLCoreOperations[_T_co], _T_co]
) -> None:
...
...
-class _MappedAttribute(Generic[_T], TypingOnly):
+class _MappedAttribute(Generic[_T_co], TypingOnly):
"""Mixin for attributes which should be replaced by mapper-assigned
attributes.
__slots__ = ()
-class _DeclarativeMapped(Mapped[_T], _MappedAttribute[_T]):
+class _DeclarativeMapped(Mapped[_T_co], _MappedAttribute[_T_co]):
"""Mixin for :class:`.MapperProperty` subclasses that allows them to
be compatible with ORM-annotated declarative mappings.
return NotImplemented
-class DynamicMapped(_MappedAnnotationBase[_T]):
+class DynamicMapped(_MappedAnnotationBase[_T_co]):
"""Represent the ORM mapped attribute type for a "dynamic" relationship.
The :class:`_orm.DynamicMapped` type annotation may be used in an
@overload
def __get__(
self, instance: None, owner: Any
- ) -> InstrumentedAttribute[_T]:
+ ) -> InstrumentedAttribute[_T_co]:
...
@overload
- def __get__(self, instance: object, owner: Any) -> AppenderQuery[_T]:
+ def __get__(
+ self, instance: object, owner: Any
+ ) -> AppenderQuery[_T_co]:
...
def __get__(
self, instance: Optional[object], owner: Any
- ) -> Union[InstrumentedAttribute[_T], AppenderQuery[_T]]:
+ ) -> Union[InstrumentedAttribute[_T_co], AppenderQuery[_T_co]]:
...
- def __set__(self, instance: Any, value: typing.Collection[_T]) -> None:
+ def __set__(
+ self, instance: Any, value: typing.Collection[_T_co]
+ ) -> None:
...
-class WriteOnlyMapped(_MappedAnnotationBase[_T]):
+class WriteOnlyMapped(_MappedAnnotationBase[_T_co]):
"""Represent the ORM mapped attribute type for a "write only" relationship.
The :class:`_orm.WriteOnlyMapped` type annotation may be used in an
@overload
def __get__(
self, instance: None, owner: Any
- ) -> InstrumentedAttribute[_T]:
+ ) -> InstrumentedAttribute[_T_co]:
...
@overload
def __get__(
self, instance: object, owner: Any
- ) -> WriteOnlyCollection[_T]:
+ ) -> WriteOnlyCollection[_T_co]:
...
def __get__(
self, instance: Optional[object], owner: Any
- ) -> Union[InstrumentedAttribute[_T], WriteOnlyCollection[_T]]:
+ ) -> Union[InstrumentedAttribute[_T_co], WriteOnlyCollection[_T_co]]:
...
- def __set__(self, instance: Any, value: typing.Collection[_T]) -> None:
+ def __set__(
+ self, instance: Any, value: typing.Collection[_T_co]
+ ) -> None:
...
_StrategyKey = Tuple[Any, ...]
_T = TypeVar("_T", bound=Any)
+_T_co = TypeVar("_T_co", bound=Any, covariant=True)
_TLS = TypeVar("_TLS", bound="Type[LoaderStrategy]")
@inspection._self_inspects
-class PropComparator(SQLORMOperations[_T], Generic[_T], ColumnOperators):
+class PropComparator(SQLORMOperations[_T_co], Generic[_T_co], ColumnOperators):
r"""Defines SQL operations for ORM mapped attributes.
SQLAlchemy allows for operators to
_parententity: _InternalEntityType[Any]
_adapt_to_entity: Optional[AliasedInsp[Any]]
- prop: RODescriptorReference[MapperProperty[_T]]
+ prop: RODescriptorReference[MapperProperty[_T_co]]
def __init__(
self,
_NUMBER = Union[float, int, Decimal]
_T = TypeVar("_T", bound="Any")
+_T_co = TypeVar("_T_co", bound=Any, covariant=True)
_OPT = TypeVar("_OPT", bound="Any")
_NT = TypeVar("_NT", bound="_NUMERIC")
# SQLCoreOperations should be suiting the ExpressionElementRole
# and ColumnsClauseRole. however the MRO issues become too elaborate
# at the moment.
-class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly):
+class SQLCoreOperations(Generic[_T_co], ColumnOperators, TypingOnly):
__slots__ = ()
# annotations for comparison methods
def __or__(self, other: Any) -> BooleanClauseList:
...
- def __invert__(self) -> ColumnElement[_T]:
+ def __invert__(self) -> ColumnElement[_T_co]:
...
def __lt__(self, other: Any) -> ColumnElement[bool]:
def __ge__(self, other: Any) -> ColumnElement[bool]:
...
- def __neg__(self) -> UnaryExpression[_T]:
+ def __neg__(self) -> UnaryExpression[_T_co]:
...
def __contains__(self, other: Any) -> ColumnElement[bool]:
def bitwise_and(self, other: Any) -> BinaryExpression[Any]:
...
- def bitwise_not(self) -> UnaryExpression[_T]:
+ def bitwise_not(self) -> UnaryExpression[_T_co]:
...
def bitwise_lshift(self, other: Any) -> BinaryExpression[Any]:
) -> ColumnElement[str]:
...
- def desc(self) -> UnaryExpression[_T]:
+ def desc(self) -> UnaryExpression[_T_co]:
...
- def asc(self) -> UnaryExpression[_T]:
+ def asc(self) -> UnaryExpression[_T_co]:
...
- def nulls_first(self) -> UnaryExpression[_T]:
+ def nulls_first(self) -> UnaryExpression[_T_co]:
...
- def nullsfirst(self) -> UnaryExpression[_T]:
+ def nullsfirst(self) -> UnaryExpression[_T_co]:
...
- def nulls_last(self) -> UnaryExpression[_T]:
+ def nulls_last(self) -> UnaryExpression[_T_co]:
...
- def nullslast(self) -> UnaryExpression[_T]:
+ def nullslast(self) -> UnaryExpression[_T_co]:
...
def collate(self, collation: str) -> CollationClause:
) -> BinaryExpression[bool]:
...
- def distinct(self: _SQO[_T]) -> UnaryExpression[_T]:
+ def distinct(self: _SQO[_T_co]) -> UnaryExpression[_T_co]:
...
def any_(self) -> CollectionAggregate[Any]:
) -> ColumnElement[str]:
...
- @overload
- def __add__(self, other: Any) -> ColumnElement[Any]:
- ...
-
def __add__(self, other: Any) -> ColumnElement[Any]:
...
@overload
- def __radd__(self: _SQO[_NT], other: Any) -> ColumnElement[_NT]:
- ...
-
- @overload
- def __radd__(self: _SQO[int], other: Any) -> ColumnElement[int]:
+ def __radd__(self: _SQO[_NMT], other: Any) -> ColumnElement[_NMT]:
...
@overload
class SQLColumnExpression(
- SQLCoreOperations[_T], roles.ExpressionElementRole[_T], TypingOnly
+ SQLCoreOperations[_T_co], roles.ExpressionElementRole[_T_co], TypingOnly
):
"""A type that may be used to indicate any SQL column element or object
that acts in place of one.
from .selectable import Subquery
_T = TypeVar("_T", bound=Any)
+_T_co = TypeVar("_T_co", bound=Any, covariant=True)
class SQLRole:
raise NotImplementedError()
-class TypedColumnsClauseRole(Generic[_T], SQLRole):
+class TypedColumnsClauseRole(Generic[_T_co], SQLRole):
"""element-typed form of ColumnsClauseRole"""
__slots__ = ()
_role_name = "SQL expression for WHERE/HAVING role"
-class ExpressionElementRole(TypedColumnsClauseRole[_T]):
+class ExpressionElementRole(TypedColumnsClauseRole[_T_co]):
# note when using generics for ExpressionElementRole,
# the generic type needs to be in
# sqlalchemy.sql.coercions._impl_lookup mapping also.
--- /dev/null
+"""Tests Mapped covariance."""\r
+\r
+from typing import Protocol\r
+\r
+from sqlalchemy import ForeignKey\r
+from sqlalchemy.orm import DeclarativeBase\r
+from sqlalchemy.orm import Mapped\r
+from sqlalchemy.orm import mapped_column\r
+from sqlalchemy.orm import relationship\r
+\r
+\r
+# Protocols\r
+\r
+\r
+class ParentProtocol(Protocol):\r
+ name: Mapped[str]\r
+\r
+\r
+class ChildProtocol(Protocol):\r
+ # Read-only for simplicity, mutable protocol members are complicated,\r
+ # see https://mypy.readthedocs.io/en/latest/common_issues.html#covariant-subtyping-of-mutable-protocol-members-is-rejected\r
+ @property\r
+ def parent(self) -> Mapped[ParentProtocol]:\r
+ ...\r
+\r
+\r
+def get_parent_name(child: ChildProtocol) -> str:\r
+ return child.parent.name\r
+\r
+\r
+# Implementations\r
+\r
+\r
+class Base(DeclarativeBase):\r
+ pass\r
+\r
+\r
+class Parent(Base):\r
+ __tablename__ = "parent"\r
+\r
+ name: Mapped[str] = mapped_column(primary_key=True)\r
+\r
+\r
+class Child(Base):\r
+ __tablename__ = "child"\r
+\r
+ name: Mapped[str] = mapped_column(primary_key=True)\r
+ parent_name: Mapped[str] = mapped_column(ForeignKey(Parent.name))\r
+\r
+ parent: Mapped[Parent] = relationship()\r
+\r
+\r
+assert get_parent_name(Child(parent=Parent(name="foo"))) == "foo"\r