]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Make `Mapped` covariant
authorRomeoDespres <Romeo.Despres@warnermusic.com>
Tue, 29 Aug 2023 10:11:11 +0000 (06:11 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 30 Aug 2023 14:57:08 +0000 (10:57 -0400)
Made the contained type for :class:`.Mapped` covariant; this is to allow
greater flexibility for end-user typing scenarios, such as the use of
protocols to represent particular mapped class structures that are passed
to other functions. As part of this change, the contained type was also
made covariant for dependent and related types such as
:class:`_orm.base.SQLORMOperations`, :class:`_orm.WriteOnlyMapped`, and
:class:`_sql.SQLColumnExpression`. Pull request courtesy Roméo Després.

within the change, there is a bit of adjustment to ``__radd__()`` to
match the typing of ``__add__()``, which previously was slightly
different for some reason and not passing on mypy with this change.

Fixes: #10288
Closes: #10289
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/10289
Pull-request-sha: 50eee7021cd29d59f52d8ff10c69d2970e1c1534

Change-Id: Ic55723a78b0b3b47dfff927d9ee0b94301272a6a

doc/build/changelog/unreleased_20/10288.rst [new file with mode: 0644]
lib/sqlalchemy/ext/hybrid.py
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/base.py
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/roles.py
test/typing/plain_files/orm/mapped_covariant.py [new file with mode: 0644]

diff --git a/doc/build/changelog/unreleased_20/10288.rst b/doc/build/changelog/unreleased_20/10288.rst
new file mode 100644 (file)
index 0000000..18b0bb0
--- /dev/null
@@ -0,0 +1,12 @@
+.. change::
+    :tags: usecase, typing
+    :tickets: 10288
+
+    Made the contained type for :class:`.Mapped` covariant; this is to allow
+    greater flexibility for end-user typing scenarios, such as the use of
+    protocols to represent particular mapped class structures that are passed
+    to other functions. As part of this change, the contained type was also
+    made covariant for dependent and related types such as
+    :class:`_orm.base.SQLORMOperations`, :class:`_orm.WriteOnlyMapped`, and
+    :class:`_sql.SQLColumnExpression`. Pull request courtesy Roméo Després.
+
index 83dfb5033791d7187ee696710ea576b9e321a98c..1ac6fafc11a461f1b72902dac7488e9aa6a3304f 100644 (file)
@@ -927,10 +927,10 @@ class _HybridDeleterType(Protocol[_T_co]):
         ...
 
 
-class _HybridExprCallableType(Protocol[_T]):
+class _HybridExprCallableType(Protocol[_T_co]):
     def __call__(
         s, cls: Any
-    ) -> Union[_HasClauseElement, SQLColumnExpression[_T]]:
+    ) -> Union[_HasClauseElement, SQLColumnExpression[_T_co]]:
         ...
 
 
index 6a9766c6f70b791d9e99cc8994eff6d2109354ce..b1bda2281945f4e0a63da3d833c87b842c2446a9 100644 (file)
@@ -120,6 +120,7 @@ if TYPE_CHECKING:
 
 
 _T = TypeVar("_T")
+_T_co = TypeVar("_T_co", bound=Any, covariant=True)
 
 
 _AllPendingType = Sequence[
@@ -132,10 +133,10 @@ _UNKNOWN_ATTR_KEY = object()
 
 @inspection._self_inspects
 class QueryableAttribute(
-    _DeclarativeMapped[_T],
-    SQLORMExpression[_T],
+    _DeclarativeMapped[_T_co],
+    SQLORMExpression[_T_co],
     interfaces.InspectionAttr,
-    interfaces.PropComparator[_T],
+    interfaces.PropComparator[_T_co],
     roles.JoinTargetRole,
     roles.OnClauseRole,
     sql_base.Immutable,
@@ -178,13 +179,13 @@ class QueryableAttribute(
 
     is_attribute = True
 
-    dispatch: dispatcher[QueryableAttribute[_T]]
+    dispatch: dispatcher[QueryableAttribute[_T_co]]
 
     class_: _ExternalEntityType[Any]
     key: str
     parententity: _InternalEntityType[Any]
     impl: AttributeImpl
-    comparator: interfaces.PropComparator[_T]
+    comparator: interfaces.PropComparator[_T_co]
     _of_type: Optional[_InternalEntityType[Any]]
     _extra_criteria: Tuple[ColumnElement[bool], ...]
     _doc: Optional[str]
@@ -198,7 +199,7 @@ class QueryableAttribute(
         class_: _ExternalEntityType[_O],
         key: str,
         parententity: _InternalEntityType[_O],
-        comparator: interfaces.PropComparator[_T],
+        comparator: interfaces.PropComparator[_T_co],
         impl: Optional[AttributeImpl] = None,
         of_type: Optional[_InternalEntityType[Any]] = None,
         extra_criteria: Tuple[ColumnElement[bool], ...] = (),
@@ -314,7 +315,7 @@ class QueryableAttribute(
 
     """
 
-    expression: ColumnElement[_T]
+    expression: ColumnElement[_T_co]
     """The SQL expression object represented by this
     :class:`.QueryableAttribute`.
 
@@ -376,7 +377,7 @@ class QueryableAttribute(
     def _annotations(self) -> _AnnotationDict:
         return self.__clause_element__()._annotations
 
-    def __clause_element__(self) -> ColumnElement[_T]:
+    def __clause_element__(self) -> ColumnElement[_T_co]:
         return self.expression
 
     @property
@@ -443,7 +444,7 @@ class QueryableAttribute(
             extra_criteria=self._extra_criteria,
         )
 
-    def label(self, name: Optional[str]) -> Label[_T]:
+    def label(self, name: Optional[str]) -> Label[_T_co]:
         return self.__clause_element__().label(name)
 
     def operate(
index ecb10591a3a68ef21bdb4033683e9571c9a3b4bc..362346cc2a8a784ba2ea309cf26f4f014eea7f33 100644 (file)
@@ -56,6 +56,7 @@ if typing.TYPE_CHECKING:
     from ..sql.operators import OperatorType
 
 _T = TypeVar("_T", bound=Any)
+_T_co = TypeVar("_T_co", bound=Any, covariant=True)
 
 _O = TypeVar("_O", bound=object)
 
@@ -678,12 +679,12 @@ class InspectionAttrInfo(InspectionAttr):
         return {}
 
 
-class SQLORMOperations(SQLCoreOperations[_T], TypingOnly):
+class SQLORMOperations(SQLCoreOperations[_T_co], TypingOnly):
     __slots__ = ()
 
     if typing.TYPE_CHECKING:
 
-        def of_type(self, class_: _EntityType[Any]) -> PropComparator[_T]:
+        def of_type(self, class_: _EntityType[Any]) -> PropComparator[_T_co]:
             ...
 
         def and_(
@@ -706,7 +707,7 @@ class SQLORMOperations(SQLCoreOperations[_T], TypingOnly):
             ...
 
 
-class ORMDescriptor(Generic[_T], TypingOnly):
+class ORMDescriptor(Generic[_T_co], TypingOnly):
     """Represent any Python descriptor that provides a SQL expression
     construct at the class level."""
 
@@ -717,26 +718,26 @@ class ORMDescriptor(Generic[_T], TypingOnly):
         @overload
         def __get__(
             self, instance: Any, owner: Literal[None]
-        ) -> ORMDescriptor[_T]:
+        ) -> ORMDescriptor[_T_co]:
             ...
 
         @overload
         def __get__(
             self, instance: Literal[None], owner: Any
-        ) -> SQLCoreOperations[_T]:
+        ) -> SQLCoreOperations[_T_co]:
             ...
 
         @overload
-        def __get__(self, instance: object, owner: Any) -> _T:
+        def __get__(self, instance: object, owner: Any) -> _T_co:
             ...
 
         def __get__(
             self, instance: object, owner: Any
-        ) -> Union[ORMDescriptor[_T], SQLCoreOperations[_T], _T]:
+        ) -> Union[ORMDescriptor[_T_co], SQLCoreOperations[_T_co], _T_co]:
             ...
 
 
-class _MappedAnnotationBase(Generic[_T], TypingOnly):
+class _MappedAnnotationBase(Generic[_T_co], TypingOnly):
     """common class for Mapped and similar ORM container classes.
 
     these are classes that can appear on the left side of an ORM declarative
@@ -749,7 +750,7 @@ class _MappedAnnotationBase(Generic[_T], TypingOnly):
 
 
 class SQLORMExpression(
-    SQLORMOperations[_T], SQLColumnExpression[_T], TypingOnly
+    SQLORMOperations[_T_co], SQLColumnExpression[_T_co], TypingOnly
 ):
     """A type that may be used to indicate any ORM-level attribute or
     object that acts in place of one, in the context of SQL expression
@@ -771,9 +772,9 @@ class SQLORMExpression(
 
 
 class Mapped(
-    SQLORMExpression[_T],
-    ORMDescriptor[_T],
-    _MappedAnnotationBase[_T],
+    SQLORMExpression[_T_co],
+    ORMDescriptor[_T_co],
+    _MappedAnnotationBase[_T_co],
     roles.DDLConstraintColumnRole,
 ):
     """Represent an ORM mapped attribute on a mapped class.
@@ -819,24 +820,24 @@ class Mapped(
         @overload
         def __get__(
             self, instance: None, owner: Any
-        ) -> InstrumentedAttribute[_T]:
+        ) -> InstrumentedAttribute[_T_co]:
             ...
 
         @overload
-        def __get__(self, instance: object, owner: Any) -> _T:
+        def __get__(self, instance: object, owner: Any) -> _T_co:
             ...
 
         def __get__(
             self, instance: Optional[object], owner: Any
-        ) -> Union[InstrumentedAttribute[_T], _T]:
+        ) -> Union[InstrumentedAttribute[_T_co], _T_co]:
             ...
 
         @classmethod
-        def _empty_constructor(cls, arg1: Any) -> Mapped[_T]:
+        def _empty_constructor(cls, arg1: Any) -> Mapped[_T_co]:
             ...
 
         def __set__(
-            self, instance: Any, value: Union[SQLCoreOperations[_T], _T]
+            self, instance: Any, value: Union[SQLCoreOperations[_T_co], _T_co]
         ) -> None:
             ...
 
@@ -844,7 +845,7 @@ class Mapped(
             ...
 
 
-class _MappedAttribute(Generic[_T], TypingOnly):
+class _MappedAttribute(Generic[_T_co], TypingOnly):
     """Mixin for attributes which should be replaced by mapper-assigned
     attributes.
 
@@ -853,7 +854,7 @@ class _MappedAttribute(Generic[_T], TypingOnly):
     __slots__ = ()
 
 
-class _DeclarativeMapped(Mapped[_T], _MappedAttribute[_T]):
+class _DeclarativeMapped(Mapped[_T_co], _MappedAttribute[_T_co]):
     """Mixin for :class:`.MapperProperty` subclasses that allows them to
     be compatible with ORM-annotated declarative mappings.
 
@@ -878,7 +879,7 @@ class _DeclarativeMapped(Mapped[_T], _MappedAttribute[_T]):
         return NotImplemented
 
 
-class DynamicMapped(_MappedAnnotationBase[_T]):
+class DynamicMapped(_MappedAnnotationBase[_T_co]):
     """Represent the ORM mapped attribute type for a "dynamic" relationship.
 
     The :class:`_orm.DynamicMapped` type annotation may be used in an
@@ -918,23 +919,27 @@ class DynamicMapped(_MappedAnnotationBase[_T]):
         @overload
         def __get__(
             self, instance: None, owner: Any
-        ) -> InstrumentedAttribute[_T]:
+        ) -> InstrumentedAttribute[_T_co]:
             ...
 
         @overload
-        def __get__(self, instance: object, owner: Any) -> AppenderQuery[_T]:
+        def __get__(
+            self, instance: object, owner: Any
+        ) -> AppenderQuery[_T_co]:
             ...
 
         def __get__(
             self, instance: Optional[object], owner: Any
-        ) -> Union[InstrumentedAttribute[_T], AppenderQuery[_T]]:
+        ) -> Union[InstrumentedAttribute[_T_co], AppenderQuery[_T_co]]:
             ...
 
-        def __set__(self, instance: Any, value: typing.Collection[_T]) -> None:
+        def __set__(
+            self, instance: Any, value: typing.Collection[_T_co]
+        ) -> None:
             ...
 
 
-class WriteOnlyMapped(_MappedAnnotationBase[_T]):
+class WriteOnlyMapped(_MappedAnnotationBase[_T_co]):
     """Represent the ORM mapped attribute type for a "write only" relationship.
 
     The :class:`_orm.WriteOnlyMapped` type annotation may be used in an
@@ -970,19 +975,21 @@ class WriteOnlyMapped(_MappedAnnotationBase[_T]):
         @overload
         def __get__(
             self, instance: None, owner: Any
-        ) -> InstrumentedAttribute[_T]:
+        ) -> InstrumentedAttribute[_T_co]:
             ...
 
         @overload
         def __get__(
             self, instance: object, owner: Any
-        ) -> WriteOnlyCollection[_T]:
+        ) -> WriteOnlyCollection[_T_co]:
             ...
 
         def __get__(
             self, instance: Optional[object], owner: Any
-        ) -> Union[InstrumentedAttribute[_T], WriteOnlyCollection[_T]]:
+        ) -> Union[InstrumentedAttribute[_T_co], WriteOnlyCollection[_T_co]]:
             ...
 
-        def __set__(self, instance: Any, value: typing.Collection[_T]) -> None:
+        def __set__(
+            self, instance: Any, value: typing.Collection[_T_co]
+        ) -> None:
             ...
index 88cafcb64524081e761d43ef710e6fd6507bb896..daba973cb3a4e4d18624a50bd0d4524f4a3c0ef9 100644 (file)
@@ -107,6 +107,7 @@ if typing.TYPE_CHECKING:
 _StrategyKey = Tuple[Any, ...]
 
 _T = TypeVar("_T", bound=Any)
+_T_co = TypeVar("_T_co", bound=Any, covariant=True)
 
 _TLS = TypeVar("_TLS", bound="Type[LoaderStrategy]")
 
@@ -653,7 +654,7 @@ class MapperProperty(
 
 
 @inspection._self_inspects
-class PropComparator(SQLORMOperations[_T], Generic[_T], ColumnOperators):
+class PropComparator(SQLORMOperations[_T_co], Generic[_T_co], ColumnOperators):
     r"""Defines SQL operations for ORM mapped attributes.
 
     SQLAlchemy allows for operators to
@@ -740,7 +741,7 @@ class PropComparator(SQLORMOperations[_T], Generic[_T], ColumnOperators):
 
     _parententity: _InternalEntityType[Any]
     _adapt_to_entity: Optional[AliasedInsp[Any]]
-    prop: RODescriptorReference[MapperProperty[_T]]
+    prop: RODescriptorReference[MapperProperty[_T_co]]
 
     def __init__(
         self,
index a6b4c8ab16f1e9afd020fb26df3efe8a60fddc24..3917b5f02397d90403c49091b0fa10c9ec58e95c 100644 (file)
@@ -117,6 +117,7 @@ _NUMERIC = Union[float, Decimal]
 _NUMBER = Union[float, int, Decimal]
 
 _T = TypeVar("_T", bound="Any")
+_T_co = TypeVar("_T_co", bound=Any, covariant=True)
 _OPT = TypeVar("_OPT", bound="Any")
 _NT = TypeVar("_NT", bound="_NUMERIC")
 
@@ -804,7 +805,7 @@ class CompilerColumnElement(
 # SQLCoreOperations should be suiting the ExpressionElementRole
 # and ColumnsClauseRole.   however the MRO issues become too elaborate
 # at the moment.
-class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly):
+class SQLCoreOperations(Generic[_T_co], ColumnOperators, TypingOnly):
     __slots__ = ()
 
     # annotations for comparison methods
@@ -873,7 +874,7 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly):
         def __or__(self, other: Any) -> BooleanClauseList:
             ...
 
-        def __invert__(self) -> ColumnElement[_T]:
+        def __invert__(self) -> ColumnElement[_T_co]:
             ...
 
         def __lt__(self, other: Any) -> ColumnElement[bool]:
@@ -900,7 +901,7 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly):
         def __ge__(self, other: Any) -> ColumnElement[bool]:
             ...
 
-        def __neg__(self) -> UnaryExpression[_T]:
+        def __neg__(self) -> UnaryExpression[_T_co]:
             ...
 
         def __contains__(self, other: Any) -> ColumnElement[bool]:
@@ -961,7 +962,7 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly):
         def bitwise_and(self, other: Any) -> BinaryExpression[Any]:
             ...
 
-        def bitwise_not(self) -> UnaryExpression[_T]:
+        def bitwise_not(self) -> UnaryExpression[_T_co]:
             ...
 
         def bitwise_lshift(self, other: Any) -> BinaryExpression[Any]:
@@ -1074,22 +1075,22 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly):
         ) -> ColumnElement[str]:
             ...
 
-        def desc(self) -> UnaryExpression[_T]:
+        def desc(self) -> UnaryExpression[_T_co]:
             ...
 
-        def asc(self) -> UnaryExpression[_T]:
+        def asc(self) -> UnaryExpression[_T_co]:
             ...
 
-        def nulls_first(self) -> UnaryExpression[_T]:
+        def nulls_first(self) -> UnaryExpression[_T_co]:
             ...
 
-        def nullsfirst(self) -> UnaryExpression[_T]:
+        def nullsfirst(self) -> UnaryExpression[_T_co]:
             ...
 
-        def nulls_last(self) -> UnaryExpression[_T]:
+        def nulls_last(self) -> UnaryExpression[_T_co]:
             ...
 
-        def nullslast(self) -> UnaryExpression[_T]:
+        def nullslast(self) -> UnaryExpression[_T_co]:
             ...
 
         def collate(self, collation: str) -> CollationClause:
@@ -1100,7 +1101,7 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly):
         ) -> BinaryExpression[bool]:
             ...
 
-        def distinct(self: _SQO[_T]) -> UnaryExpression[_T]:
+        def distinct(self: _SQO[_T_co]) -> UnaryExpression[_T_co]:
             ...
 
         def any_(self) -> CollectionAggregate[Any]:
@@ -1128,19 +1129,11 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly):
         ) -> ColumnElement[str]:
             ...
 
-        @overload
-        def __add__(self, other: Any) -> ColumnElement[Any]:
-            ...
-
         def __add__(self, other: Any) -> ColumnElement[Any]:
             ...
 
         @overload
-        def __radd__(self: _SQO[_NT], other: Any) -> ColumnElement[_NT]:
-            ...
-
-        @overload
-        def __radd__(self: _SQO[int], other: Any) -> ColumnElement[int]:
+        def __radd__(self: _SQO[_NMT], other: Any) -> ColumnElement[_NMT]:
             ...
 
         @overload
@@ -1282,7 +1275,7 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly):
 
 
 class SQLColumnExpression(
-    SQLCoreOperations[_T], roles.ExpressionElementRole[_T], TypingOnly
+    SQLCoreOperations[_T_co], roles.ExpressionElementRole[_T_co], TypingOnly
 ):
     """A type that may be used to indicate any SQL column element or object
     that acts in place of one.
index f8aac70b998dc3aad81fc12f85dc1505d2444ff1..6f299224328a7a1137bfa1670347c54ecd79a49d 100644 (file)
@@ -23,6 +23,7 @@ if TYPE_CHECKING:
     from .selectable import Subquery
 
 _T = TypeVar("_T", bound=Any)
+_T_co = TypeVar("_T_co", bound=Any, covariant=True)
 
 
 class SQLRole:
@@ -110,7 +111,7 @@ class ColumnsClauseRole(AllowsLambdaRole, UsesInspection, ColumnListRole):
         raise NotImplementedError()
 
 
-class TypedColumnsClauseRole(Generic[_T], SQLRole):
+class TypedColumnsClauseRole(Generic[_T_co], SQLRole):
     """element-typed form of ColumnsClauseRole"""
 
     __slots__ = ()
@@ -162,7 +163,7 @@ class WhereHavingRole(OnClauseRole):
     _role_name = "SQL expression for WHERE/HAVING role"
 
 
-class ExpressionElementRole(TypedColumnsClauseRole[_T]):
+class ExpressionElementRole(TypedColumnsClauseRole[_T_co]):
     # note when using generics for ExpressionElementRole,
     # the generic type needs to be in
     # sqlalchemy.sql.coercions._impl_lookup mapping also.
diff --git a/test/typing/plain_files/orm/mapped_covariant.py b/test/typing/plain_files/orm/mapped_covariant.py
new file mode 100644 (file)
index 0000000..58cc1ab
--- /dev/null
@@ -0,0 +1,53 @@
+"""Tests Mapped covariance."""\r
+\r
+from typing import Protocol\r
+\r
+from sqlalchemy import ForeignKey\r
+from sqlalchemy.orm import DeclarativeBase\r
+from sqlalchemy.orm import Mapped\r
+from sqlalchemy.orm import mapped_column\r
+from sqlalchemy.orm import relationship\r
+\r
+\r
+# Protocols\r
+\r
+\r
+class ParentProtocol(Protocol):\r
+    name: Mapped[str]\r
+\r
+\r
+class ChildProtocol(Protocol):\r
+    # Read-only for simplicity, mutable protocol members are complicated,\r
+    # see https://mypy.readthedocs.io/en/latest/common_issues.html#covariant-subtyping-of-mutable-protocol-members-is-rejected\r
+    @property\r
+    def parent(self) -> Mapped[ParentProtocol]:\r
+        ...\r
+\r
+\r
+def get_parent_name(child: ChildProtocol) -> str:\r
+    return child.parent.name\r
+\r
+\r
+# Implementations\r
+\r
+\r
+class Base(DeclarativeBase):\r
+    pass\r
+\r
+\r
+class Parent(Base):\r
+    __tablename__ = "parent"\r
+\r
+    name: Mapped[str] = mapped_column(primary_key=True)\r
+\r
+\r
+class Child(Base):\r
+    __tablename__ = "child"\r
+\r
+    name: Mapped[str] = mapped_column(primary_key=True)\r
+    parent_name: Mapped[str] = mapped_column(ForeignKey(Parent.name))\r
+\r
+    parent: Mapped[Parent] = relationship()\r
+\r
+\r
+assert get_parent_name(Child(parent=Parent(name="foo"))) == "foo"\r