From 058e10f2b7e5686198dc744107b32952e55dc93c Mon Sep 17 00:00:00 2001 From: Ethan Langevin Date: Mon, 11 Mar 2024 07:41:58 -0400 Subject: [PATCH] Make instrumented attribute covariant as well Allows mapped relationships to use covariant types which makes it possible to define methods that operate on relationships in a typesafe way ### Description See: https://github.com/sqlalchemy/sqlalchemy/issues/11112 for a more in depth explanation. Just changed the type parameter in `InstrumentedAttribute` from `_T` to `_T_co`. ### Checklist This pull request is: - [ ] A documentation / typographical / small typing error fix - Good to go, no issue or tests are needed - [x] A short code fix - please include the issue number, and create an issue if none exists, which must include a complete example of the issue. one line code fixes without an issue and demonstration will not be accepted. - Please include: `Fixes: #` in the commit message - please include tests. one line code fixes without tests will not be accepted. - [ ] A new feature implementation - please include the issue number, and create an issue if none exists, which must include a complete example of how the feature would look. - Please include: `Fixes: #` in the commit message - please include tests. **Have a nice day!** Closes: #11113 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/11113 Pull-request-sha: 3c100f28661f3440769175a17c2763ed25f4b83a Change-Id: Iff715c24f1556d5604dcd33661a0ee7232b9404b --- lib/sqlalchemy/orm/attributes.py | 8 +++---- .../plain_files/orm/mapped_covariant.py | 22 +++++++++++++++++++ 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index d9b2d8213d..5b16ce3d6b 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -503,7 +503,7 @@ def _queryable_attribute_unreduce( return getattr(entity, key) -class InstrumentedAttribute(QueryableAttribute[_T]): +class InstrumentedAttribute(QueryableAttribute[_T_co]): """Class bound instrumented attribute which adds basic :term:`descriptor` methods. @@ -544,14 +544,14 @@ class InstrumentedAttribute(QueryableAttribute[_T]): @overload def __get__( self, instance: None, owner: Any - ) -> InstrumentedAttribute[_T]: ... + ) -> InstrumentedAttribute[_T_co]: ... @overload - def __get__(self, instance: object, owner: Any) -> _T: ... + def __get__(self, instance: object, owner: Any) -> _T_co: ... def __get__( self, instance: Optional[object], owner: Any - ) -> Union[InstrumentedAttribute[_T], _T]: + ) -> Union[InstrumentedAttribute[_T_co], _T_co]: if instance is None: return self diff --git a/test/typing/plain_files/orm/mapped_covariant.py b/test/typing/plain_files/orm/mapped_covariant.py index 9f964021b3..680e925de3 100644 --- a/test/typing/plain_files/orm/mapped_covariant.py +++ b/test/typing/plain_files/orm/mapped_covariant.py @@ -2,12 +2,15 @@ from datetime import datetime from typing import Protocol +from typing import Sequence +from typing import TypeVar from typing import Union from sqlalchemy import ForeignKey from sqlalchemy import func from sqlalchemy import Nullable from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import InstrumentedAttribute from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column from sqlalchemy.orm import relationship @@ -43,6 +46,8 @@ class Parent(Base): name: Mapped[str] = mapped_column(primary_key=True) + children: Mapped[Sequence["Child"]] = relationship("Child") + class Child(Base): __tablename__ = "child" @@ -55,6 +60,23 @@ class Child(Base): assert get_parent_name(Child(parent=Parent(name="foo"))) == "foo" +# Make sure that relationships are covariant as well +_BaseT = TypeVar("_BaseT", bound=Base, covariant=True) +RelationshipType = ( + InstrumentedAttribute[_BaseT] + | InstrumentedAttribute[Sequence[_BaseT]] + | InstrumentedAttribute[_BaseT | None] +) + + +def operate_on_relationships( + relationships: list[RelationshipType[_BaseT]], +) -> int: + return len(relationships) + + +assert operate_on_relationships([Parent.children, Child.parent]) == 2 + # other test -- 2.47.2