]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Fixes: #11112 - Make instrumented attribute covariant as well 11113/head
authorEthan Langevin <ethan.langevin@anthropic.com>
Thu, 7 Mar 2024 01:18:30 +0000 (20:18 -0500)
committerEthan Langevin <ethan.langevin@anthropic.com>
Thu, 7 Mar 2024 01:18:30 +0000 (20:18 -0500)
lib/sqlalchemy/orm/attributes.py
test/typing/plain_files/orm/mapped_covariant.py

index d9b2d8213d1c03b6890d62abd5c11244de0b5592..d3d773055bb27204e6745a4d946c6d999667f552 100644 (file)
@@ -503,7 +503,7 @@ def _queryable_attribute_unreduce(
         return getattr(entity, key)
 
 
-class InstrumentedAttribute(QueryableAttribute[_T]):
+class InstrumentedAttribute(QueryableAttribute[_T_co]):
     """Class bound instrumented attribute which adds basic
     :term:`descriptor` methods.
 
index 9f964021b313ebc437e4f89df40d528be59ab2f0..398db309fbab5af82393d7abbc9eec6db847b2d1 100644 (file)
@@ -2,12 +2,15 @@
 \r
 from datetime import datetime\r
 from typing import Protocol\r
+from typing import Sequence\r
+from typing import TypeVar\r
 from typing import Union\r
 \r
 from sqlalchemy import ForeignKey\r
 from sqlalchemy import func\r
 from sqlalchemy import Nullable\r
 from sqlalchemy.orm import DeclarativeBase\r
+from sqlalchemy.orm import InstrumentedAttribute\r
 from sqlalchemy.orm import Mapped\r
 from sqlalchemy.orm import mapped_column\r
 from sqlalchemy.orm import relationship\r
@@ -24,7 +27,8 @@ class ChildProtocol(Protocol):
     # Read-only for simplicity, mutable protocol members are complicated,\r
     # see https://mypy.readthedocs.io/en/latest/common_issues.html#covariant-subtyping-of-mutable-protocol-members-is-rejected\r
     @property\r
-    def parent(self) -> Mapped[ParentProtocol]: ...\r
+    def parent(self) -> Mapped[ParentProtocol]:\r
+        ...\r
 \r
 \r
 def get_parent_name(child: ChildProtocol) -> str:\r
@@ -43,6 +47,8 @@ class Parent(Base):
 \r
     name: Mapped[str] = mapped_column(primary_key=True)\r
 \r
+    children: Mapped[Sequence["Child"]] = relationship("Child")\r
+\r
 \r
 class Child(Base):\r
     __tablename__ = "child"\r
@@ -55,6 +61,23 @@ class Child(Base):
 \r
 assert get_parent_name(Child(parent=Parent(name="foo"))) == "foo"\r
 \r
+# Make sure that relationships are covariant as well\r
+_BaseT = TypeVar("_BaseT", bound=Base, covariant=True)\r
+RelationshipType = (\r
+    InstrumentedAttribute[_BaseT]\r
+    | InstrumentedAttribute[Sequence[_BaseT]]\r
+    | InstrumentedAttribute[_BaseT | None]\r
+)\r
+\r
+\r
+def operate_on_relationships(\r
+    relationships: list[RelationshipType[_BaseT]],\r
+) -> int:\r
+    return len(relationships)\r
+\r
+\r
+assert operate_on_relationships([Parent.children, Child.parent]) == 2\r
+\r
 # other test\r
 \r
 \r