]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Make `Mapped` covariant (Fixes #10288)
authorRomeoDespres <Romeo.Despres@warnermusic.com>
Mon, 28 Aug 2023 14:38:04 +0000 (16:38 +0200)
committerRomeoDespres <Romeo.Despres@warnermusic.com>
Mon, 28 Aug 2023 16:11:41 +0000 (18:11 +0200)
lib/sqlalchemy/ext/hybrid.py
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/base.py
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/roles.py
test/typing/plain_files/orm/mapped_covariant.py [new file with mode: 0644]

index 83dfb5033791d7187ee696710ea576b9e321a98c..1ac6fafc11a461f1b72902dac7488e9aa6a3304f 100644 (file)
@@ -927,10 +927,10 @@ class _HybridDeleterType(Protocol[_T_co]):
         ...
 
 
-class _HybridExprCallableType(Protocol[_T]):
+class _HybridExprCallableType(Protocol[_T_co]):
     def __call__(
         s, cls: Any
-    ) -> Union[_HasClauseElement, SQLColumnExpression[_T]]:
+    ) -> Union[_HasClauseElement, SQLColumnExpression[_T_co]]:
         ...
 
 
index 6a9766c6f70b791d9e99cc8994eff6d2109354ce..b1bda2281945f4e0a63da3d833c87b842c2446a9 100644 (file)
@@ -120,6 +120,7 @@ if TYPE_CHECKING:
 
 
 _T = TypeVar("_T")
+_T_co = TypeVar("_T_co", bound=Any, covariant=True)
 
 
 _AllPendingType = Sequence[
@@ -132,10 +133,10 @@ _UNKNOWN_ATTR_KEY = object()
 
 @inspection._self_inspects
 class QueryableAttribute(
-    _DeclarativeMapped[_T],
-    SQLORMExpression[_T],
+    _DeclarativeMapped[_T_co],
+    SQLORMExpression[_T_co],
     interfaces.InspectionAttr,
-    interfaces.PropComparator[_T],
+    interfaces.PropComparator[_T_co],
     roles.JoinTargetRole,
     roles.OnClauseRole,
     sql_base.Immutable,
@@ -178,13 +179,13 @@ class QueryableAttribute(
 
     is_attribute = True
 
-    dispatch: dispatcher[QueryableAttribute[_T]]
+    dispatch: dispatcher[QueryableAttribute[_T_co]]
 
     class_: _ExternalEntityType[Any]
     key: str
     parententity: _InternalEntityType[Any]
     impl: AttributeImpl
-    comparator: interfaces.PropComparator[_T]
+    comparator: interfaces.PropComparator[_T_co]
     _of_type: Optional[_InternalEntityType[Any]]
     _extra_criteria: Tuple[ColumnElement[bool], ...]
     _doc: Optional[str]
@@ -198,7 +199,7 @@ class QueryableAttribute(
         class_: _ExternalEntityType[_O],
         key: str,
         parententity: _InternalEntityType[_O],
-        comparator: interfaces.PropComparator[_T],
+        comparator: interfaces.PropComparator[_T_co],
         impl: Optional[AttributeImpl] = None,
         of_type: Optional[_InternalEntityType[Any]] = None,
         extra_criteria: Tuple[ColumnElement[bool], ...] = (),
@@ -314,7 +315,7 @@ class QueryableAttribute(
 
     """
 
-    expression: ColumnElement[_T]
+    expression: ColumnElement[_T_co]
     """The SQL expression object represented by this
     :class:`.QueryableAttribute`.
 
@@ -376,7 +377,7 @@ class QueryableAttribute(
     def _annotations(self) -> _AnnotationDict:
         return self.__clause_element__()._annotations
 
-    def __clause_element__(self) -> ColumnElement[_T]:
+    def __clause_element__(self) -> ColumnElement[_T_co]:
         return self.expression
 
     @property
@@ -443,7 +444,7 @@ class QueryableAttribute(
             extra_criteria=self._extra_criteria,
         )
 
-    def label(self, name: Optional[str]) -> Label[_T]:
+    def label(self, name: Optional[str]) -> Label[_T_co]:
         return self.__clause_element__().label(name)
 
     def operate(
index ecb10591a3a68ef21bdb4033683e9571c9a3b4bc..362346cc2a8a784ba2ea309cf26f4f014eea7f33 100644 (file)
@@ -56,6 +56,7 @@ if typing.TYPE_CHECKING:
     from ..sql.operators import OperatorType
 
 _T = TypeVar("_T", bound=Any)
+_T_co = TypeVar("_T_co", bound=Any, covariant=True)
 
 _O = TypeVar("_O", bound=object)
 
@@ -678,12 +679,12 @@ class InspectionAttrInfo(InspectionAttr):
         return {}
 
 
-class SQLORMOperations(SQLCoreOperations[_T], TypingOnly):
+class SQLORMOperations(SQLCoreOperations[_T_co], TypingOnly):
     __slots__ = ()
 
     if typing.TYPE_CHECKING:
 
-        def of_type(self, class_: _EntityType[Any]) -> PropComparator[_T]:
+        def of_type(self, class_: _EntityType[Any]) -> PropComparator[_T_co]:
             ...
 
         def and_(
@@ -706,7 +707,7 @@ class SQLORMOperations(SQLCoreOperations[_T], TypingOnly):
             ...
 
 
-class ORMDescriptor(Generic[_T], TypingOnly):
+class ORMDescriptor(Generic[_T_co], TypingOnly):
     """Represent any Python descriptor that provides a SQL expression
     construct at the class level."""
 
@@ -717,26 +718,26 @@ class ORMDescriptor(Generic[_T], TypingOnly):
         @overload
         def __get__(
             self, instance: Any, owner: Literal[None]
-        ) -> ORMDescriptor[_T]:
+        ) -> ORMDescriptor[_T_co]:
             ...
 
         @overload
         def __get__(
             self, instance: Literal[None], owner: Any
-        ) -> SQLCoreOperations[_T]:
+        ) -> SQLCoreOperations[_T_co]:
             ...
 
         @overload
-        def __get__(self, instance: object, owner: Any) -> _T:
+        def __get__(self, instance: object, owner: Any) -> _T_co:
             ...
 
         def __get__(
             self, instance: object, owner: Any
-        ) -> Union[ORMDescriptor[_T], SQLCoreOperations[_T], _T]:
+        ) -> Union[ORMDescriptor[_T_co], SQLCoreOperations[_T_co], _T_co]:
             ...
 
 
-class _MappedAnnotationBase(Generic[_T], TypingOnly):
+class _MappedAnnotationBase(Generic[_T_co], TypingOnly):
     """common class for Mapped and similar ORM container classes.
 
     these are classes that can appear on the left side of an ORM declarative
@@ -749,7 +750,7 @@ class _MappedAnnotationBase(Generic[_T], TypingOnly):
 
 
 class SQLORMExpression(
-    SQLORMOperations[_T], SQLColumnExpression[_T], TypingOnly
+    SQLORMOperations[_T_co], SQLColumnExpression[_T_co], TypingOnly
 ):
     """A type that may be used to indicate any ORM-level attribute or
     object that acts in place of one, in the context of SQL expression
@@ -771,9 +772,9 @@ class SQLORMExpression(
 
 
 class Mapped(
-    SQLORMExpression[_T],
-    ORMDescriptor[_T],
-    _MappedAnnotationBase[_T],
+    SQLORMExpression[_T_co],
+    ORMDescriptor[_T_co],
+    _MappedAnnotationBase[_T_co],
     roles.DDLConstraintColumnRole,
 ):
     """Represent an ORM mapped attribute on a mapped class.
@@ -819,24 +820,24 @@ class Mapped(
         @overload
         def __get__(
             self, instance: None, owner: Any
-        ) -> InstrumentedAttribute[_T]:
+        ) -> InstrumentedAttribute[_T_co]:
             ...
 
         @overload
-        def __get__(self, instance: object, owner: Any) -> _T:
+        def __get__(self, instance: object, owner: Any) -> _T_co:
             ...
 
         def __get__(
             self, instance: Optional[object], owner: Any
-        ) -> Union[InstrumentedAttribute[_T], _T]:
+        ) -> Union[InstrumentedAttribute[_T_co], _T_co]:
             ...
 
         @classmethod
-        def _empty_constructor(cls, arg1: Any) -> Mapped[_T]:
+        def _empty_constructor(cls, arg1: Any) -> Mapped[_T_co]:
             ...
 
         def __set__(
-            self, instance: Any, value: Union[SQLCoreOperations[_T], _T]
+            self, instance: Any, value: Union[SQLCoreOperations[_T_co], _T_co]
         ) -> None:
             ...
 
@@ -844,7 +845,7 @@ class Mapped(
             ...
 
 
-class _MappedAttribute(Generic[_T], TypingOnly):
+class _MappedAttribute(Generic[_T_co], TypingOnly):
     """Mixin for attributes which should be replaced by mapper-assigned
     attributes.
 
@@ -853,7 +854,7 @@ class _MappedAttribute(Generic[_T], TypingOnly):
     __slots__ = ()
 
 
-class _DeclarativeMapped(Mapped[_T], _MappedAttribute[_T]):
+class _DeclarativeMapped(Mapped[_T_co], _MappedAttribute[_T_co]):
     """Mixin for :class:`.MapperProperty` subclasses that allows them to
     be compatible with ORM-annotated declarative mappings.
 
@@ -878,7 +879,7 @@ class _DeclarativeMapped(Mapped[_T], _MappedAttribute[_T]):
         return NotImplemented
 
 
-class DynamicMapped(_MappedAnnotationBase[_T]):
+class DynamicMapped(_MappedAnnotationBase[_T_co]):
     """Represent the ORM mapped attribute type for a "dynamic" relationship.
 
     The :class:`_orm.DynamicMapped` type annotation may be used in an
@@ -918,23 +919,27 @@ class DynamicMapped(_MappedAnnotationBase[_T]):
         @overload
         def __get__(
             self, instance: None, owner: Any
-        ) -> InstrumentedAttribute[_T]:
+        ) -> InstrumentedAttribute[_T_co]:
             ...
 
         @overload
-        def __get__(self, instance: object, owner: Any) -> AppenderQuery[_T]:
+        def __get__(
+            self, instance: object, owner: Any
+        ) -> AppenderQuery[_T_co]:
             ...
 
         def __get__(
             self, instance: Optional[object], owner: Any
-        ) -> Union[InstrumentedAttribute[_T], AppenderQuery[_T]]:
+        ) -> Union[InstrumentedAttribute[_T_co], AppenderQuery[_T_co]]:
             ...
 
-        def __set__(self, instance: Any, value: typing.Collection[_T]) -> None:
+        def __set__(
+            self, instance: Any, value: typing.Collection[_T_co]
+        ) -> None:
             ...
 
 
-class WriteOnlyMapped(_MappedAnnotationBase[_T]):
+class WriteOnlyMapped(_MappedAnnotationBase[_T_co]):
     """Represent the ORM mapped attribute type for a "write only" relationship.
 
     The :class:`_orm.WriteOnlyMapped` type annotation may be used in an
@@ -970,19 +975,21 @@ class WriteOnlyMapped(_MappedAnnotationBase[_T]):
         @overload
         def __get__(
             self, instance: None, owner: Any
-        ) -> InstrumentedAttribute[_T]:
+        ) -> InstrumentedAttribute[_T_co]:
             ...
 
         @overload
         def __get__(
             self, instance: object, owner: Any
-        ) -> WriteOnlyCollection[_T]:
+        ) -> WriteOnlyCollection[_T_co]:
             ...
 
         def __get__(
             self, instance: Optional[object], owner: Any
-        ) -> Union[InstrumentedAttribute[_T], WriteOnlyCollection[_T]]:
+        ) -> Union[InstrumentedAttribute[_T_co], WriteOnlyCollection[_T_co]]:
             ...
 
-        def __set__(self, instance: Any, value: typing.Collection[_T]) -> None:
+        def __set__(
+            self, instance: Any, value: typing.Collection[_T_co]
+        ) -> None:
             ...
index 88cafcb64524081e761d43ef710e6fd6507bb896..daba973cb3a4e4d18624a50bd0d4524f4a3c0ef9 100644 (file)
@@ -107,6 +107,7 @@ if typing.TYPE_CHECKING:
 _StrategyKey = Tuple[Any, ...]
 
 _T = TypeVar("_T", bound=Any)
+_T_co = TypeVar("_T_co", bound=Any, covariant=True)
 
 _TLS = TypeVar("_TLS", bound="Type[LoaderStrategy]")
 
@@ -653,7 +654,7 @@ class MapperProperty(
 
 
 @inspection._self_inspects
-class PropComparator(SQLORMOperations[_T], Generic[_T], ColumnOperators):
+class PropComparator(SQLORMOperations[_T_co], Generic[_T_co], ColumnOperators):
     r"""Defines SQL operations for ORM mapped attributes.
 
     SQLAlchemy allows for operators to
@@ -740,7 +741,7 @@ class PropComparator(SQLORMOperations[_T], Generic[_T], ColumnOperators):
 
     _parententity: _InternalEntityType[Any]
     _adapt_to_entity: Optional[AliasedInsp[Any]]
-    prop: RODescriptorReference[MapperProperty[_T]]
+    prop: RODescriptorReference[MapperProperty[_T_co]]
 
     def __init__(
         self,
index a6b4c8ab16f1e9afd020fb26df3efe8a60fddc24..cc63b96abc4ff69a899052fd067a6146f1f1e2ae 100644 (file)
@@ -117,6 +117,7 @@ _NUMERIC = Union[float, Decimal]
 _NUMBER = Union[float, int, Decimal]
 
 _T = TypeVar("_T", bound="Any")
+_T_co = TypeVar("_T_co", bound=Any, covariant=True)
 _OPT = TypeVar("_OPT", bound="Any")
 _NT = TypeVar("_NT", bound="_NUMERIC")
 
@@ -804,7 +805,7 @@ class CompilerColumnElement(
 # SQLCoreOperations should be suiting the ExpressionElementRole
 # and ColumnsClauseRole.   however the MRO issues become too elaborate
 # at the moment.
-class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly):
+class SQLCoreOperations(Generic[_T_co], ColumnOperators, TypingOnly):
     __slots__ = ()
 
     # annotations for comparison methods
@@ -873,7 +874,7 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly):
         def __or__(self, other: Any) -> BooleanClauseList:
             ...
 
-        def __invert__(self) -> ColumnElement[_T]:
+        def __invert__(self) -> ColumnElement[_T_co]:
             ...
 
         def __lt__(self, other: Any) -> ColumnElement[bool]:
@@ -900,7 +901,7 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly):
         def __ge__(self, other: Any) -> ColumnElement[bool]:
             ...
 
-        def __neg__(self) -> UnaryExpression[_T]:
+        def __neg__(self) -> UnaryExpression[_T_co]:
             ...
 
         def __contains__(self, other: Any) -> ColumnElement[bool]:
@@ -961,7 +962,7 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly):
         def bitwise_and(self, other: Any) -> BinaryExpression[Any]:
             ...
 
-        def bitwise_not(self) -> UnaryExpression[_T]:
+        def bitwise_not(self) -> UnaryExpression[_T_co]:
             ...
 
         def bitwise_lshift(self, other: Any) -> BinaryExpression[Any]:
@@ -1074,22 +1075,22 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly):
         ) -> ColumnElement[str]:
             ...
 
-        def desc(self) -> UnaryExpression[_T]:
+        def desc(self) -> UnaryExpression[_T_co]:
             ...
 
-        def asc(self) -> UnaryExpression[_T]:
+        def asc(self) -> UnaryExpression[_T_co]:
             ...
 
-        def nulls_first(self) -> UnaryExpression[_T]:
+        def nulls_first(self) -> UnaryExpression[_T_co]:
             ...
 
-        def nullsfirst(self) -> UnaryExpression[_T]:
+        def nullsfirst(self) -> UnaryExpression[_T_co]:
             ...
 
-        def nulls_last(self) -> UnaryExpression[_T]:
+        def nulls_last(self) -> UnaryExpression[_T_co]:
             ...
 
-        def nullslast(self) -> UnaryExpression[_T]:
+        def nullslast(self) -> UnaryExpression[_T_co]:
             ...
 
         def collate(self, collation: str) -> CollationClause:
@@ -1100,7 +1101,7 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly):
         ) -> BinaryExpression[bool]:
             ...
 
-        def distinct(self: _SQO[_T]) -> UnaryExpression[_T]:
+        def distinct(self: _SQO[_T_co]) -> UnaryExpression[_T_co]:
             ...
 
         def any_(self) -> CollectionAggregate[Any]:
@@ -1139,10 +1140,6 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly):
         def __radd__(self: _SQO[_NT], other: Any) -> ColumnElement[_NT]:
             ...
 
-        @overload
-        def __radd__(self: _SQO[int], other: Any) -> ColumnElement[int]:
-            ...
-
         @overload
         def __radd__(self: _SQO[str], other: Any) -> ColumnElement[str]:
             ...
@@ -1282,7 +1279,7 @@ class SQLCoreOperations(Generic[_T], ColumnOperators, TypingOnly):
 
 
 class SQLColumnExpression(
-    SQLCoreOperations[_T], roles.ExpressionElementRole[_T], TypingOnly
+    SQLCoreOperations[_T_co], roles.ExpressionElementRole[_T_co], TypingOnly
 ):
     """A type that may be used to indicate any SQL column element or object
     that acts in place of one.
index f8aac70b998dc3aad81fc12f85dc1505d2444ff1..6f299224328a7a1137bfa1670347c54ecd79a49d 100644 (file)
@@ -23,6 +23,7 @@ if TYPE_CHECKING:
     from .selectable import Subquery
 
 _T = TypeVar("_T", bound=Any)
+_T_co = TypeVar("_T_co", bound=Any, covariant=True)
 
 
 class SQLRole:
@@ -110,7 +111,7 @@ class ColumnsClauseRole(AllowsLambdaRole, UsesInspection, ColumnListRole):
         raise NotImplementedError()
 
 
-class TypedColumnsClauseRole(Generic[_T], SQLRole):
+class TypedColumnsClauseRole(Generic[_T_co], SQLRole):
     """element-typed form of ColumnsClauseRole"""
 
     __slots__ = ()
@@ -162,7 +163,7 @@ class WhereHavingRole(OnClauseRole):
     _role_name = "SQL expression for WHERE/HAVING role"
 
 
-class ExpressionElementRole(TypedColumnsClauseRole[_T]):
+class ExpressionElementRole(TypedColumnsClauseRole[_T_co]):
     # note when using generics for ExpressionElementRole,
     # the generic type needs to be in
     # sqlalchemy.sql.coercions._impl_lookup mapping also.
diff --git a/test/typing/plain_files/orm/mapped_covariant.py b/test/typing/plain_files/orm/mapped_covariant.py
new file mode 100644 (file)
index 0000000..58cc1ab
--- /dev/null
@@ -0,0 +1,53 @@
+"""Tests Mapped covariance."""\r
+\r
+from typing import Protocol\r
+\r
+from sqlalchemy import ForeignKey\r
+from sqlalchemy.orm import DeclarativeBase\r
+from sqlalchemy.orm import Mapped\r
+from sqlalchemy.orm import mapped_column\r
+from sqlalchemy.orm import relationship\r
+\r
+\r
+# Protocols\r
+\r
+\r
+class ParentProtocol(Protocol):\r
+    name: Mapped[str]\r
+\r
+\r
+class ChildProtocol(Protocol):\r
+    # Read-only for simplicity, mutable protocol members are complicated,\r
+    # see https://mypy.readthedocs.io/en/latest/common_issues.html#covariant-subtyping-of-mutable-protocol-members-is-rejected\r
+    @property\r
+    def parent(self) -> Mapped[ParentProtocol]:\r
+        ...\r
+\r
+\r
+def get_parent_name(child: ChildProtocol) -> str:\r
+    return child.parent.name\r
+\r
+\r
+# Implementations\r
+\r
+\r
+class Base(DeclarativeBase):\r
+    pass\r
+\r
+\r
+class Parent(Base):\r
+    __tablename__ = "parent"\r
+\r
+    name: Mapped[str] = mapped_column(primary_key=True)\r
+\r
+\r
+class Child(Base):\r
+    __tablename__ = "child"\r
+\r
+    name: Mapped[str] = mapped_column(primary_key=True)\r
+    parent_name: Mapped[str] = mapped_column(ForeignKey(Parent.name))\r
+\r
+    parent: Mapped[Parent] = relationship()\r
+\r
+\r
+assert get_parent_name(Child(parent=Parent(name="foo"))) == "foo"\r