From: RomeoDespres Date: Tue, 29 Aug 2023 10:11:11 +0000 (-0400) Subject: Make `Mapped` covariant X-Git-Tag: rel_2_0_21~24^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=2ec9b21fa744947319bfc49b6bdddc165487844a;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Make `Mapped` covariant Made the contained type for :class:`.Mapped` covariant; this is to allow greater flexibility for end-user typing scenarios, such as the use of protocols to represent particular mapped class structures that are passed to other functions. As part of this change, the contained type was also made covariant for dependent and related types such as :class:`_orm.base.SQLORMOperations`, :class:`_orm.WriteOnlyMapped`, and :class:`_sql.SQLColumnExpression`. Pull request courtesy Roméo Després. within the change, there is a bit of adjustment to ``__radd__()`` to match the typing of ``__add__()``, which previously was slightly different for some reason and not passing on mypy with this change. Fixes: #10288 Closes: #10289 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/10289 Pull-request-sha: 50eee7021cd29d59f52d8ff10c69d2970e1c1534 Change-Id: Ic55723a78b0b3b47dfff927d9ee0b94301272a6a --- diff --git a/doc/build/changelog/unreleased_20/10288.rst b/doc/build/changelog/unreleased_20/10288.rst new file mode 100644 index 0000000000..18b0bb0702 --- /dev/null +++ b/doc/build/changelog/unreleased_20/10288.rst @@ -0,0 +1,12 @@ +.. change:: + :tags: usecase, typing + :tickets: 10288 + + Made the contained type for :class:`.Mapped` covariant; this is to allow + greater flexibility for end-user typing scenarios, such as the use of + protocols to represent particular mapped class structures that are passed + to other functions. As part of this change, the contained type was also + made covariant for dependent and related types such as + :class:`_orm.base.SQLORMOperations`, :class:`_orm.WriteOnlyMapped`, and + :class:`_sql.SQLColumnExpression`. Pull request courtesy Roméo Després. + diff --git a/lib/sqlalchemy/ext/hybrid.py b/lib/sqlalchemy/ext/hybrid.py index 83dfb50337..1ac6fafc11 100644 --- a/lib/sqlalchemy/ext/hybrid.py +++ b/lib/sqlalchemy/ext/hybrid.py @@ -927,10 +927,10 @@ class _HybridDeleterType(Protocol[_T_co]): ... -class _HybridExprCallableType(Protocol[_T]): +class _HybridExprCallableType(Protocol[_T_co]): def __call__( s, cls: Any - ) -> Union[_HasClauseElement, SQLColumnExpression[_T]]: + ) -> Union[_HasClauseElement, SQLColumnExpression[_T_co]]: ... diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 6a9766c6f7..b1bda22819 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -120,6 +120,7 @@ if TYPE_CHECKING: _T = TypeVar("_T") +_T_co = TypeVar("_T_co", bound=Any, covariant=True) _AllPendingType = Sequence[ @@ -132,10 +133,10 @@ _UNKNOWN_ATTR_KEY = object() @inspection._self_inspects class QueryableAttribute( - _DeclarativeMapped[_T], - SQLORMExpression[_T], + _DeclarativeMapped[_T_co], + SQLORMExpression[_T_co], interfaces.InspectionAttr, - interfaces.PropComparator[_T], + interfaces.PropComparator[_T_co], roles.JoinTargetRole, roles.OnClauseRole, sql_base.Immutable, @@ -178,13 +179,13 @@ class QueryableAttribute( is_attribute = True - dispatch: dispatcher[QueryableAttribute[_T]] + dispatch: dispatcher[QueryableAttribute[_T_co]] class_: _ExternalEntityType[Any] key: str parententity: _InternalEntityType[Any] impl: AttributeImpl - comparator: interfaces.PropComparator[_T] + comparator: interfaces.PropComparator[_T_co] _of_type: Optional[_InternalEntityType[Any]] _extra_criteria: Tuple[ColumnElement[bool], ...] _doc: Optional[str] @@ -198,7 +199,7 @@ class QueryableAttribute( class_: _ExternalEntityType[_O], key: str, parententity: _InternalEntityType[_O], - comparator: interfaces.PropComparator[_T], + comparator: interfaces.PropComparator[_T_co], impl: Optional[AttributeImpl] = None, of_type: Optional[_InternalEntityType[Any]] = None, extra_criteria: Tuple[ColumnElement[bool], ...] = (), @@ -314,7 +315,7 @@ class QueryableAttribute( """ - expression: ColumnElement[_T] + expression: ColumnElement[_T_co] """The SQL expression object represented by this :class:`.QueryableAttribute`. @@ -376,7 +377,7 @@ class QueryableAttribute( def _annotations(self) -> _AnnotationDict: return self.__clause_element__()._annotations - def __clause_element__(self) -> ColumnElement[_T]: + def __clause_element__(self) -> ColumnElement[_T_co]: return self.expression @property @@ -443,7 +444,7 @@ class QueryableAttribute( extra_criteria=self._extra_criteria, ) - def label(self, name: Optional[str]) -> Label[_T]: + def label(self, name: Optional[str]) -> Label[_T_co]: return self.__clause_element__().label(name) def operate( diff --git a/lib/sqlalchemy/orm/base.py b/lib/sqlalchemy/orm/base.py index ecb10591a3..362346cc2a 100644 --- a/lib/sqlalchemy/orm/base.py +++ b/lib/sqlalchemy/orm/base.py @@ -56,6 +56,7 @@ if typing.TYPE_CHECKING: from ..sql.operators import OperatorType _T = TypeVar("_T", bound=Any) +_T_co = TypeVar("_T_co", bound=Any, covariant=True) _O = TypeVar("_O", bound=object) @@ -678,12 +679,12 @@ class InspectionAttrInfo(InspectionAttr): return {} -class SQLORMOperations(SQLCoreOperations[_T], TypingOnly): +class SQLORMOperations(SQLCoreOperations[_T_co], TypingOnly): __slots__ = () if typing.TYPE_CHECKING: - def of_type(self, class_: _EntityType[Any]) -> PropComparator[_T]: + def of_type(self, class_: _EntityType[Any]) -> PropComparator[_T_co]: ... def and_( @@ -706,7 +707,7 @@ class SQLORMOperations(SQLCoreOperations[_T], TypingOnly): ... -class ORMDescriptor(Generic[_T], TypingOnly): +class ORMDescriptor(Generic[_T_co], TypingOnly): """Represent any Python descriptor that provides a SQL expression construct at the class level.""" @@ -717,26 +718,26 @@ class ORMDescriptor(Generic[_T], TypingOnly): @overload def __get__( self, instance: Any, owner: Literal[None] - ) -> ORMDescriptor[_T]: + ) -> ORMDescriptor[_T_co]: ... @overload def __get__( self, instance: Literal[None], owner: Any - ) -> SQLCoreOperations[_T]: + ) -> SQLCoreOperations[_T_co]: ... @overload - def __get__(self, instance: object, owner: Any) -> _T: + def __get__(self, instance: object, owner: Any) -> _T_co: ... def __get__( self, instance: object, owner: Any - ) -> Union[ORMDescriptor[_T], SQLCoreOperations[_T], _T]: + ) -> Union[ORMDescriptor[_T_co], SQLCoreOperations[_T_co], _T_co]: ... -class _MappedAnnotationBase(Generic[_T], TypingOnly): +class _MappedAnnotationBase(Generic[_T_co], TypingOnly): """common class for Mapped and similar ORM container classes. these are classes that can appear on the left side of an ORM declarative @@ -749,7 +750,7 @@ class _MappedAnnotationBase(Generic[_T], TypingOnly): class SQLORMExpression( - SQLORMOperations[_T], SQLColumnExpression[_T], TypingOnly + SQLORMOperations[_T_co], SQLColumnExpression[_T_co], TypingOnly ): """A type that may be used to indicate any ORM-level attribute or object that acts in place of one, in the context of SQL expression @@ -771,9 +772,9 @@ class SQLORMExpression( class Mapped( - SQLORMExpression[_T], - ORMDescriptor[_T], - _MappedAnnotationBase[_T], + SQLORMExpression[_T_co], + ORMDescriptor[_T_co], + _MappedAnnotationBase[_T_co], roles.DDLConstraintColumnRole, ): """Represent an ORM mapped attribute on a mapped class. @@ -819,24 +820,24 @@ class Mapped( @overload def __get__( self, instance: None, owner: Any - ) -> InstrumentedAttribute[_T]: + ) -> InstrumentedAttribute[_T_co]: ... @overload - def __get__(self, instance: object, owner: Any) -> _T: + def __get__(self, instance: object, owner: Any) -> _T_co: ... def __get__( self, instance: Optional[object], owner: Any - ) -> Union[InstrumentedAttribute[_T], _T]: + ) -> Union[InstrumentedAttribute[_T_co], _T_co]: ... @classmethod - def _empty_constructor(cls, arg1: Any) -> Mapped[_T]: + def _empty_constructor(cls, arg1: Any) -> Mapped[_T_co]: ... def __set__( - self, instance: Any, value: Union[SQLCoreOperations[_T], _T] + self, instance: Any, value: Union[SQLCoreOperations[_T_co], _T_co] ) -> None: ... @@ -844,7 +845,7 @@ class Mapped( ... -class _MappedAttribute(Generic[_T], TypingOnly): +class _MappedAttribute(Generic[_T_co], TypingOnly): """Mixin for attributes which should be replaced by mapper-assigned attributes. @@ -853,7 +854,7 @@ class _MappedAttribute(Generic[_T], TypingOnly): __slots__ = () -class _DeclarativeMapped(Mapped[_T], _MappedAttribute[_T]): +class _DeclarativeMapped(Mapped[_T_co], _MappedAttribute[_T_co]): """Mixin for :class:`.MapperProperty` subclasses that allows them to be compatible with ORM-annotated declarative mappings. @@ -878,7 +879,7 @@ class _DeclarativeMapped(Mapped[_T], _MappedAttribute[_T]): return NotImplemented -class DynamicMapped(_MappedAnnotationBase[_T]): +class DynamicMapped(_MappedAnnotationBase[_T_co]): """Represent the ORM mapped attribute type for a "dynamic" relationship. The :class:`_orm.DynamicMapped` type annotation may be used in an @@ -918,23 +919,27 @@ class DynamicMapped(_MappedAnnotationBase[_T]): @overload def __get__( self, instance: None, owner: Any - ) -> InstrumentedAttribute[_T]: + ) -> InstrumentedAttribute[_T_co]: ... @overload - def __get__(self, instance: object, owner: Any) -> AppenderQuery[_T]: + def __get__( + self, instance: object, owner: Any + ) -> AppenderQuery[_T_co]: ... def __get__( self, instance: Optional[object], owner: Any - ) -> Union[InstrumentedAttribute[_T], AppenderQuery[_T]]: + ) -> Union[InstrumentedAttribute[_T_co], AppenderQuery[_T_co]]: ... - def __set__(self, instance: Any, value: typing.Collection[_T]) -> None: + def __set__( + self, instance: Any, value: typing.Collection[_T_co] + ) -> None: ... -class WriteOnlyMapped(_MappedAnnotationBase[_T]): +class WriteOnlyMapped(_MappedAnnotationBase[_T_co]): """Represent the ORM mapped attribute type for a "write only" relationship. The :class:`_orm.WriteOnlyMapped` type annotation may be used in an @@ -970,19 +975,21 @@ class WriteOnlyMapped(_MappedAnnotationBase[_T]): @overload def __get__( self, instance: None, owner: Any - ) -> InstrumentedAttribute[_T]: + ) -> InstrumentedAttribute[_T_co]: ... @overload def __get__( self, instance: object, owner: Any - ) -> WriteOnlyCollection[_T]: + ) -> WriteOnlyCollection[_T_co]: ... def __get__( self, instance: Optional[object], owner: Any - ) -> Union[InstrumentedAttribute[_T], WriteOnlyCollection[_T]]: + ) -> Union[InstrumentedAttribute[_T_co], WriteOnlyCollection[_T_co]]: ... - def __set__(self, instance: Any, value: typing.Collection[_T]) -> None: + def __set__( + self, instance: Any, value: typing.Collection[_T_co] + ) -> None: ... diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index 88cafcb645..daba973cb3 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -107,6 +107,7 @@ if typing.TYPE_CHECKING: _StrategyKey = Tuple[Any, ...] _T = TypeVar("_T", bound=Any) +_T_co = TypeVar("_T_co", bound=Any, covariant=True) _TLS = TypeVar("_TLS", bound="Type[LoaderStrategy]") @@ -653,7 +654,7 @@ class MapperProperty( @inspection._self_inspects -class PropComparator(SQLORMOperations[_T], Generic[_T], ColumnOperators): +class PropComparator(SQLORMOperations[_T_co], Generic[_T_co], ColumnOperators): r"""Defines SQL operations for ORM mapped attributes. SQLAlchemy allows for operators to @@ -740,7 +741,7 @@ class PropComparator(SQLORMOperations[_T], Generic[_T], ColumnOperators): _parententity: _InternalEntityType[Any] _adapt_to_entity: Optional[AliasedInsp[Any]] - prop: RODescriptorReference[MapperProperty[_T]] + prop: RODescriptorReference[MapperProperty[_T_co]] def __init__( self, diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index a6b4c8ab16..3917b5f023 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -117,6 +117,7 @@ _NUMERIC = Union[float, Decimal] _NUMBER = Union[float, int, Decimal] _T = TypeVar("_T", bound="Any") +_T_co = TypeVar("_T_co", bound=Any, covariant=True) _OPT = TypeVar("_OPT", bound="Any") _NT = TypeVar("_NT", bound="_NUMERIC") @@ -804,7 +805,7 @@ class CompilerColumnElement( # SQLCoreOperations should be suiting the ExpressionElementRole # and ColumnsClauseRole. however the MRO issues become too elaborate # at the moment. -class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly): +class SQLCoreOperations(Generic[_T_co], ColumnOperators, TypingOnly): __slots__ = () # annotations for comparison methods @@ -873,7 +874,7 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly): def __or__(self, other: Any) -> BooleanClauseList: ... - def __invert__(self) -> ColumnElement[_T]: + def __invert__(self) -> ColumnElement[_T_co]: ... def __lt__(self, other: Any) -> ColumnElement[bool]: @@ -900,7 +901,7 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly): def __ge__(self, other: Any) -> ColumnElement[bool]: ... - def __neg__(self) -> UnaryExpression[_T]: + def __neg__(self) -> UnaryExpression[_T_co]: ... def __contains__(self, other: Any) -> ColumnElement[bool]: @@ -961,7 +962,7 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly): def bitwise_and(self, other: Any) -> BinaryExpression[Any]: ... - def bitwise_not(self) -> UnaryExpression[_T]: + def bitwise_not(self) -> UnaryExpression[_T_co]: ... def bitwise_lshift(self, other: Any) -> BinaryExpression[Any]: @@ -1074,22 +1075,22 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly): ) -> ColumnElement[str]: ... - def desc(self) -> UnaryExpression[_T]: + def desc(self) -> UnaryExpression[_T_co]: ... - def asc(self) -> UnaryExpression[_T]: + def asc(self) -> UnaryExpression[_T_co]: ... - def nulls_first(self) -> UnaryExpression[_T]: + def nulls_first(self) -> UnaryExpression[_T_co]: ... - def nullsfirst(self) -> UnaryExpression[_T]: + def nullsfirst(self) -> UnaryExpression[_T_co]: ... - def nulls_last(self) -> UnaryExpression[_T]: + def nulls_last(self) -> UnaryExpression[_T_co]: ... - def nullslast(self) -> UnaryExpression[_T]: + def nullslast(self) -> UnaryExpression[_T_co]: ... def collate(self, collation: str) -> CollationClause: @@ -1100,7 +1101,7 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly): ) -> BinaryExpression[bool]: ... - def distinct(self: _SQO[_T]) -> UnaryExpression[_T]: + def distinct(self: _SQO[_T_co]) -> UnaryExpression[_T_co]: ... def any_(self) -> CollectionAggregate[Any]: @@ -1128,19 +1129,11 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly): ) -> ColumnElement[str]: ... - @overload - def __add__(self, other: Any) -> ColumnElement[Any]: - ... - def __add__(self, other: Any) -> ColumnElement[Any]: ... @overload - def __radd__(self: _SQO[_NT], other: Any) -> ColumnElement[_NT]: - ... - - @overload - def __radd__(self: _SQO[int], other: Any) -> ColumnElement[int]: + def __radd__(self: _SQO[_NMT], other: Any) -> ColumnElement[_NMT]: ... @overload @@ -1282,7 +1275,7 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly): class SQLColumnExpression( - SQLCoreOperations[_T], roles.ExpressionElementRole[_T], TypingOnly + SQLCoreOperations[_T_co], roles.ExpressionElementRole[_T_co], TypingOnly ): """A type that may be used to indicate any SQL column element or object that acts in place of one. diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py index f8aac70b99..6f29922432 100644 --- a/lib/sqlalchemy/sql/roles.py +++ b/lib/sqlalchemy/sql/roles.py @@ -23,6 +23,7 @@ if TYPE_CHECKING: from .selectable import Subquery _T = TypeVar("_T", bound=Any) +_T_co = TypeVar("_T_co", bound=Any, covariant=True) class SQLRole: @@ -110,7 +111,7 @@ class ColumnsClauseRole(AllowsLambdaRole, UsesInspection, ColumnListRole): raise NotImplementedError() -class TypedColumnsClauseRole(Generic[_T], SQLRole): +class TypedColumnsClauseRole(Generic[_T_co], SQLRole): """element-typed form of ColumnsClauseRole""" __slots__ = () @@ -162,7 +163,7 @@ class WhereHavingRole(OnClauseRole): _role_name = "SQL expression for WHERE/HAVING role" -class ExpressionElementRole(TypedColumnsClauseRole[_T]): +class ExpressionElementRole(TypedColumnsClauseRole[_T_co]): # note when using generics for ExpressionElementRole, # the generic type needs to be in # sqlalchemy.sql.coercions._impl_lookup mapping also. diff --git a/test/typing/plain_files/orm/mapped_covariant.py b/test/typing/plain_files/orm/mapped_covariant.py new file mode 100644 index 0000000000..58cc1ab6c2 --- /dev/null +++ b/test/typing/plain_files/orm/mapped_covariant.py @@ -0,0 +1,53 @@ +"""Tests Mapped covariance.""" + +from typing import Protocol + +from sqlalchemy import ForeignKey +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import relationship + + +# Protocols + + +class ParentProtocol(Protocol): + name: Mapped[str] + + +class ChildProtocol(Protocol): + # Read-only for simplicity, mutable protocol members are complicated, + # see https://mypy.readthedocs.io/en/latest/common_issues.html#covariant-subtyping-of-mutable-protocol-members-is-rejected + @property + def parent(self) -> Mapped[ParentProtocol]: + ... + + +def get_parent_name(child: ChildProtocol) -> str: + return child.parent.name + + +# Implementations + + +class Base(DeclarativeBase): + pass + + +class Parent(Base): + __tablename__ = "parent" + + name: Mapped[str] = mapped_column(primary_key=True) + + +class Child(Base): + __tablename__ = "child" + + name: Mapped[str] = mapped_column(primary_key=True) + parent_name: Mapped[str] = mapped_column(ForeignKey(Parent.name)) + + parent: Mapped[Parent] = relationship() + + +assert get_parent_name(Child(parent=Parent(name="foo"))) == "foo"