From 3c100f28661f3440769175a17c2763ed25f4b83a Mon Sep 17 00:00:00 2001 From: Ethan Langevin Date: Wed, 6 Mar 2024 20:18:30 -0500 Subject: [PATCH] Fixes: #11112 - Make instrumented attribute covariant as well --- lib/sqlalchemy/orm/attributes.py | 2 +- .../plain_files/orm/mapped_covariant.py | 25 ++++++++++++++++++- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index d9b2d8213d..d3d773055b 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -503,7 +503,7 @@ def _queryable_attribute_unreduce( return getattr(entity, key) -class InstrumentedAttribute(QueryableAttribute[_T]): +class InstrumentedAttribute(QueryableAttribute[_T_co]): """Class bound instrumented attribute which adds basic :term:`descriptor` methods. diff --git a/test/typing/plain_files/orm/mapped_covariant.py b/test/typing/plain_files/orm/mapped_covariant.py index 9f964021b3..398db309fb 100644 --- a/test/typing/plain_files/orm/mapped_covariant.py +++ b/test/typing/plain_files/orm/mapped_covariant.py @@ -2,12 +2,15 @@ from datetime import datetime from typing import Protocol +from typing import Sequence +from typing import TypeVar from typing import Union from sqlalchemy import ForeignKey from sqlalchemy import func from sqlalchemy import Nullable from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import InstrumentedAttribute from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column from sqlalchemy.orm import relationship @@ -24,7 +27,8 @@ class ChildProtocol(Protocol): # Read-only for simplicity, mutable protocol members are complicated, # see https://mypy.readthedocs.io/en/latest/common_issues.html#covariant-subtyping-of-mutable-protocol-members-is-rejected @property - def parent(self) -> Mapped[ParentProtocol]: ... + def parent(self) -> Mapped[ParentProtocol]: + ... def get_parent_name(child: ChildProtocol) -> str: @@ -43,6 +47,8 @@ class Parent(Base): name: Mapped[str] = mapped_column(primary_key=True) + children: Mapped[Sequence["Child"]] = relationship("Child") + class Child(Base): __tablename__ = "child" @@ -55,6 +61,23 @@ class Child(Base): assert get_parent_name(Child(parent=Parent(name="foo"))) == "foo" +# Make sure that relationships are covariant as well +_BaseT = TypeVar("_BaseT", bound=Base, covariant=True) +RelationshipType = ( + InstrumentedAttribute[_BaseT] + | InstrumentedAttribute[Sequence[_BaseT]] + | InstrumentedAttribute[_BaseT | None] +) + + +def operate_on_relationships( + relationships: list[RelationshipType[_BaseT]], +) -> int: + return len(relationships) + + +assert operate_on_relationships([Parent.children, Child.parent]) == 2 + # other test -- 2.47.2