From a141e02bbd7c8a9ad47a130b28a87eb05204a4fa Mon Sep 17 00:00:00 2001 From: RomeoDespres Date: Mon, 28 Aug 2023 16:38:04 +0200 Subject: [PATCH] Make `Mapped` covariant (Fixes #10288) --- lib/sqlalchemy/ext/hybrid.py | 4 +- lib/sqlalchemy/orm/attributes.py | 19 +++--- lib/sqlalchemy/orm/base.py | 65 ++++++++++--------- lib/sqlalchemy/orm/interfaces.py | 5 +- lib/sqlalchemy/sql/elements.py | 29 ++++----- lib/sqlalchemy/sql/roles.py | 5 +- .../plain_files/orm/mapped_covariant.py | 53 +++++++++++++++ 7 files changed, 120 insertions(+), 60 deletions(-) create mode 100644 test/typing/plain_files/orm/mapped_covariant.py diff --git a/lib/sqlalchemy/ext/hybrid.py b/lib/sqlalchemy/ext/hybrid.py index 83dfb50337..1ac6fafc11 100644 --- a/lib/sqlalchemy/ext/hybrid.py +++ b/lib/sqlalchemy/ext/hybrid.py @@ -927,10 +927,10 @@ class _HybridDeleterType(Protocol[_T_co]): ... -class _HybridExprCallableType(Protocol[_T]): +class _HybridExprCallableType(Protocol[_T_co]): def __call__( s, cls: Any - ) -> Union[_HasClauseElement, SQLColumnExpression[_T]]: + ) -> Union[_HasClauseElement, SQLColumnExpression[_T_co]]: ... diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 6a9766c6f7..b1bda22819 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -120,6 +120,7 @@ if TYPE_CHECKING: _T = TypeVar("_T") +_T_co = TypeVar("_T_co", bound=Any, covariant=True) _AllPendingType = Sequence[ @@ -132,10 +133,10 @@ _UNKNOWN_ATTR_KEY = object() @inspection._self_inspects class QueryableAttribute( - _DeclarativeMapped[_T], - SQLORMExpression[_T], + _DeclarativeMapped[_T_co], + SQLORMExpression[_T_co], interfaces.InspectionAttr, - interfaces.PropComparator[_T], + interfaces.PropComparator[_T_co], roles.JoinTargetRole, roles.OnClauseRole, sql_base.Immutable, @@ -178,13 +179,13 @@ class QueryableAttribute( is_attribute = True - dispatch: dispatcher[QueryableAttribute[_T]] + dispatch: dispatcher[QueryableAttribute[_T_co]] class_: _ExternalEntityType[Any] key: str parententity: _InternalEntityType[Any] impl: AttributeImpl - comparator: interfaces.PropComparator[_T] + comparator: interfaces.PropComparator[_T_co] _of_type: Optional[_InternalEntityType[Any]] _extra_criteria: Tuple[ColumnElement[bool], ...] _doc: Optional[str] @@ -198,7 +199,7 @@ class QueryableAttribute( class_: _ExternalEntityType[_O], key: str, parententity: _InternalEntityType[_O], - comparator: interfaces.PropComparator[_T], + comparator: interfaces.PropComparator[_T_co], impl: Optional[AttributeImpl] = None, of_type: Optional[_InternalEntityType[Any]] = None, extra_criteria: Tuple[ColumnElement[bool], ...] = (), @@ -314,7 +315,7 @@ class QueryableAttribute( """ - expression: ColumnElement[_T] + expression: ColumnElement[_T_co] """The SQL expression object represented by this :class:`.QueryableAttribute`. @@ -376,7 +377,7 @@ class QueryableAttribute( def _annotations(self) -> _AnnotationDict: return self.__clause_element__()._annotations - def __clause_element__(self) -> ColumnElement[_T]: + def __clause_element__(self) -> ColumnElement[_T_co]: return self.expression @property @@ -443,7 +444,7 @@ class QueryableAttribute( extra_criteria=self._extra_criteria, ) - def label(self, name: Optional[str]) -> Label[_T]: + def label(self, name: Optional[str]) -> Label[_T_co]: return self.__clause_element__().label(name) def operate( diff --git a/lib/sqlalchemy/orm/base.py b/lib/sqlalchemy/orm/base.py index ecb10591a3..362346cc2a 100644 --- a/lib/sqlalchemy/orm/base.py +++ b/lib/sqlalchemy/orm/base.py @@ -56,6 +56,7 @@ if typing.TYPE_CHECKING: from ..sql.operators import OperatorType _T = TypeVar("_T", bound=Any) +_T_co = TypeVar("_T_co", bound=Any, covariant=True) _O = TypeVar("_O", bound=object) @@ -678,12 +679,12 @@ class InspectionAttrInfo(InspectionAttr): return {} -class SQLORMOperations(SQLCoreOperations[_T], TypingOnly): +class SQLORMOperations(SQLCoreOperations[_T_co], TypingOnly): __slots__ = () if typing.TYPE_CHECKING: - def of_type(self, class_: _EntityType[Any]) -> PropComparator[_T]: + def of_type(self, class_: _EntityType[Any]) -> PropComparator[_T_co]: ... def and_( @@ -706,7 +707,7 @@ class SQLORMOperations(SQLCoreOperations[_T], TypingOnly): ... -class ORMDescriptor(Generic[_T], TypingOnly): +class ORMDescriptor(Generic[_T_co], TypingOnly): """Represent any Python descriptor that provides a SQL expression construct at the class level.""" @@ -717,26 +718,26 @@ class ORMDescriptor(Generic[_T], TypingOnly): @overload def __get__( self, instance: Any, owner: Literal[None] - ) -> ORMDescriptor[_T]: + ) -> ORMDescriptor[_T_co]: ... @overload def __get__( self, instance: Literal[None], owner: Any - ) -> SQLCoreOperations[_T]: + ) -> SQLCoreOperations[_T_co]: ... @overload - def __get__(self, instance: object, owner: Any) -> _T: + def __get__(self, instance: object, owner: Any) -> _T_co: ... def __get__( self, instance: object, owner: Any - ) -> Union[ORMDescriptor[_T], SQLCoreOperations[_T], _T]: + ) -> Union[ORMDescriptor[_T_co], SQLCoreOperations[_T_co], _T_co]: ... -class _MappedAnnotationBase(Generic[_T], TypingOnly): +class _MappedAnnotationBase(Generic[_T_co], TypingOnly): """common class for Mapped and similar ORM container classes. these are classes that can appear on the left side of an ORM declarative @@ -749,7 +750,7 @@ class _MappedAnnotationBase(Generic[_T], TypingOnly): class SQLORMExpression( - SQLORMOperations[_T], SQLColumnExpression[_T], TypingOnly + SQLORMOperations[_T_co], SQLColumnExpression[_T_co], TypingOnly ): """A type that may be used to indicate any ORM-level attribute or object that acts in place of one, in the context of SQL expression @@ -771,9 +772,9 @@ class SQLORMExpression( class Mapped( - SQLORMExpression[_T], - ORMDescriptor[_T], - _MappedAnnotationBase[_T], + SQLORMExpression[_T_co], + ORMDescriptor[_T_co], + _MappedAnnotationBase[_T_co], roles.DDLConstraintColumnRole, ): """Represent an ORM mapped attribute on a mapped class. @@ -819,24 +820,24 @@ class Mapped( @overload def __get__( self, instance: None, owner: Any - ) -> InstrumentedAttribute[_T]: + ) -> InstrumentedAttribute[_T_co]: ... @overload - def __get__(self, instance: object, owner: Any) -> _T: + def __get__(self, instance: object, owner: Any) -> _T_co: ... def __get__( self, instance: Optional[object], owner: Any - ) -> Union[InstrumentedAttribute[_T], _T]: + ) -> Union[InstrumentedAttribute[_T_co], _T_co]: ... @classmethod - def _empty_constructor(cls, arg1: Any) -> Mapped[_T]: + def _empty_constructor(cls, arg1: Any) -> Mapped[_T_co]: ... def __set__( - self, instance: Any, value: Union[SQLCoreOperations[_T], _T] + self, instance: Any, value: Union[SQLCoreOperations[_T_co], _T_co] ) -> None: ... @@ -844,7 +845,7 @@ class Mapped( ... -class _MappedAttribute(Generic[_T], TypingOnly): +class _MappedAttribute(Generic[_T_co], TypingOnly): """Mixin for attributes which should be replaced by mapper-assigned attributes. @@ -853,7 +854,7 @@ class _MappedAttribute(Generic[_T], TypingOnly): __slots__ = () -class _DeclarativeMapped(Mapped[_T], _MappedAttribute[_T]): +class _DeclarativeMapped(Mapped[_T_co], _MappedAttribute[_T_co]): """Mixin for :class:`.MapperProperty` subclasses that allows them to be compatible with ORM-annotated declarative mappings. @@ -878,7 +879,7 @@ class _DeclarativeMapped(Mapped[_T], _MappedAttribute[_T]): return NotImplemented -class DynamicMapped(_MappedAnnotationBase[_T]): +class DynamicMapped(_MappedAnnotationBase[_T_co]): """Represent the ORM mapped attribute type for a "dynamic" relationship. The :class:`_orm.DynamicMapped` type annotation may be used in an @@ -918,23 +919,27 @@ class DynamicMapped(_MappedAnnotationBase[_T]): @overload def __get__( self, instance: None, owner: Any - ) -> InstrumentedAttribute[_T]: + ) -> InstrumentedAttribute[_T_co]: ... @overload - def __get__(self, instance: object, owner: Any) -> AppenderQuery[_T]: + def __get__( + self, instance: object, owner: Any + ) -> AppenderQuery[_T_co]: ... def __get__( self, instance: Optional[object], owner: Any - ) -> Union[InstrumentedAttribute[_T], AppenderQuery[_T]]: + ) -> Union[InstrumentedAttribute[_T_co], AppenderQuery[_T_co]]: ... - def __set__(self, instance: Any, value: typing.Collection[_T]) -> None: + def __set__( + self, instance: Any, value: typing.Collection[_T_co] + ) -> None: ... -class WriteOnlyMapped(_MappedAnnotationBase[_T]): +class WriteOnlyMapped(_MappedAnnotationBase[_T_co]): """Represent the ORM mapped attribute type for a "write only" relationship. The :class:`_orm.WriteOnlyMapped` type annotation may be used in an @@ -970,19 +975,21 @@ class WriteOnlyMapped(_MappedAnnotationBase[_T]): @overload def __get__( self, instance: None, owner: Any - ) -> InstrumentedAttribute[_T]: + ) -> InstrumentedAttribute[_T_co]: ... @overload def __get__( self, instance: object, owner: Any - ) -> WriteOnlyCollection[_T]: + ) -> WriteOnlyCollection[_T_co]: ... def __get__( self, instance: Optional[object], owner: Any - ) -> Union[InstrumentedAttribute[_T], WriteOnlyCollection[_T]]: + ) -> Union[InstrumentedAttribute[_T_co], WriteOnlyCollection[_T_co]]: ... - def __set__(self, instance: Any, value: typing.Collection[_T]) -> None: + def __set__( + self, instance: Any, value: typing.Collection[_T_co] + ) -> None: ... diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index 88cafcb645..daba973cb3 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -107,6 +107,7 @@ if typing.TYPE_CHECKING: _StrategyKey = Tuple[Any, ...] _T = TypeVar("_T", bound=Any) +_T_co = TypeVar("_T_co", bound=Any, covariant=True) _TLS = TypeVar("_TLS", bound="Type[LoaderStrategy]") @@ -653,7 +654,7 @@ class MapperProperty( @inspection._self_inspects -class PropComparator(SQLORMOperations[_T], Generic[_T], ColumnOperators): +class PropComparator(SQLORMOperations[_T_co], Generic[_T_co], ColumnOperators): r"""Defines SQL operations for ORM mapped attributes. SQLAlchemy allows for operators to @@ -740,7 +741,7 @@ class PropComparator(SQLORMOperations[_T], Generic[_T], ColumnOperators): _parententity: _InternalEntityType[Any] _adapt_to_entity: Optional[AliasedInsp[Any]] - prop: RODescriptorReference[MapperProperty[_T]] + prop: RODescriptorReference[MapperProperty[_T_co]] def __init__( self, diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index a6b4c8ab16..cc63b96abc 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -117,6 +117,7 @@ _NUMERIC = Union[float, Decimal] _NUMBER = Union[float, int, Decimal] _T = TypeVar("_T", bound="Any") +_T_co = TypeVar("_T_co", bound=Any, covariant=True) _OPT = TypeVar("_OPT", bound="Any") _NT = TypeVar("_NT", bound="_NUMERIC") @@ -804,7 +805,7 @@ class CompilerColumnElement( # SQLCoreOperations should be suiting the ExpressionElementRole # and ColumnsClauseRole. however the MRO issues become too elaborate # at the moment. -class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly): +class SQLCoreOperations(Generic[_T_co], ColumnOperators, TypingOnly): __slots__ = () # annotations for comparison methods @@ -873,7 +874,7 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly): def __or__(self, other: Any) -> BooleanClauseList: ... - def __invert__(self) -> ColumnElement[_T]: + def __invert__(self) -> ColumnElement[_T_co]: ... def __lt__(self, other: Any) -> ColumnElement[bool]: @@ -900,7 +901,7 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly): def __ge__(self, other: Any) -> ColumnElement[bool]: ... - def __neg__(self) -> UnaryExpression[_T]: + def __neg__(self) -> UnaryExpression[_T_co]: ... def __contains__(self, other: Any) -> ColumnElement[bool]: @@ -961,7 +962,7 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly): def bitwise_and(self, other: Any) -> BinaryExpression[Any]: ... - def bitwise_not(self) -> UnaryExpression[_T]: + def bitwise_not(self) -> UnaryExpression[_T_co]: ... def bitwise_lshift(self, other: Any) -> BinaryExpression[Any]: @@ -1074,22 +1075,22 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly): ) -> ColumnElement[str]: ... - def desc(self) -> UnaryExpression[_T]: + def desc(self) -> UnaryExpression[_T_co]: ... - def asc(self) -> UnaryExpression[_T]: + def asc(self) -> UnaryExpression[_T_co]: ... - def nulls_first(self) -> UnaryExpression[_T]: + def nulls_first(self) -> UnaryExpression[_T_co]: ... - def nullsfirst(self) -> UnaryExpression[_T]: + def nullsfirst(self) -> UnaryExpression[_T_co]: ... - def nulls_last(self) -> UnaryExpression[_T]: + def nulls_last(self) -> UnaryExpression[_T_co]: ... - def nullslast(self) -> UnaryExpression[_T]: + def nullslast(self) -> UnaryExpression[_T_co]: ... def collate(self, collation: str) -> CollationClause: @@ -1100,7 +1101,7 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly): ) -> BinaryExpression[bool]: ... - def distinct(self: _SQO[_T]) -> UnaryExpression[_T]: + def distinct(self: _SQO[_T_co]) -> UnaryExpression[_T_co]: ... def any_(self) -> CollectionAggregate[Any]: @@ -1139,10 +1140,6 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly): def __radd__(self: _SQO[_NT], other: Any) -> ColumnElement[_NT]: ... - @overload - def __radd__(self: _SQO[int], other: Any) -> ColumnElement[int]: - ... - @overload def __radd__(self: _SQO[str], other: Any) -> ColumnElement[str]: ... @@ -1282,7 +1279,7 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly): class SQLColumnExpression( - SQLCoreOperations[_T], roles.ExpressionElementRole[_T], TypingOnly + SQLCoreOperations[_T_co], roles.ExpressionElementRole[_T_co], TypingOnly ): """A type that may be used to indicate any SQL column element or object that acts in place of one. diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py index f8aac70b99..6f29922432 100644 --- a/lib/sqlalchemy/sql/roles.py +++ b/lib/sqlalchemy/sql/roles.py @@ -23,6 +23,7 @@ if TYPE_CHECKING: from .selectable import Subquery _T = TypeVar("_T", bound=Any) +_T_co = TypeVar("_T_co", bound=Any, covariant=True) class SQLRole: @@ -110,7 +111,7 @@ class ColumnsClauseRole(AllowsLambdaRole, UsesInspection, ColumnListRole): raise NotImplementedError() -class TypedColumnsClauseRole(Generic[_T], SQLRole): +class TypedColumnsClauseRole(Generic[_T_co], SQLRole): """element-typed form of ColumnsClauseRole""" __slots__ = () @@ -162,7 +163,7 @@ class WhereHavingRole(OnClauseRole): _role_name = "SQL expression for WHERE/HAVING role" -class ExpressionElementRole(TypedColumnsClauseRole[_T]): +class ExpressionElementRole(TypedColumnsClauseRole[_T_co]): # note when using generics for ExpressionElementRole, # the generic type needs to be in # sqlalchemy.sql.coercions._impl_lookup mapping also. diff --git a/test/typing/plain_files/orm/mapped_covariant.py b/test/typing/plain_files/orm/mapped_covariant.py new file mode 100644 index 0000000000..58cc1ab6c2 --- /dev/null +++ b/test/typing/plain_files/orm/mapped_covariant.py @@ -0,0 +1,53 @@ +"""Tests Mapped covariance.""" + +from typing import Protocol + +from sqlalchemy import ForeignKey +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import relationship + + +# Protocols + + +class ParentProtocol(Protocol): + name: Mapped[str] + + +class ChildProtocol(Protocol): + # Read-only for simplicity, mutable protocol members are complicated, + # see https://mypy.readthedocs.io/en/latest/common_issues.html#covariant-subtyping-of-mutable-protocol-members-is-rejected + @property + def parent(self) -> Mapped[ParentProtocol]: + ... + + +def get_parent_name(child: ChildProtocol) -> str: + return child.parent.name + + +# Implementations + + +class Base(DeclarativeBase): + pass + + +class Parent(Base): + __tablename__ = "parent" + + name: Mapped[str] = mapped_column(primary_key=True) + + +class Child(Base): + __tablename__ = "child" + + name: Mapped[str] = mapped_column(primary_key=True) + parent_name: Mapped[str] = mapped_column(ForeignKey(Parent.name)) + + parent: Mapped[Parent] = relationship() + + +assert get_parent_name(Child(parent=Parent(name="foo"))) == "foo" -- 2.47.3