]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Make instrumented attribute covariant as well
authorEthan Langevin <ethan.langevin@anthropic.com>
Mon, 11 Mar 2024 11:41:58 +0000 (07:41 -0400)
committerFederico Caselli <cfederico87@gmail.com>
Tue, 12 Mar 2024 22:14:56 +0000 (23:14 +0100)
<!-- Provide a general summary of your proposed changes in the Title field above -->

Allows mapped relationships to use covariant types which makes it possible to define methods that operate on relationships in a typesafe way

### Description

See: https://github.com/sqlalchemy/sqlalchemy/issues/11112 for a more in depth explanation.

Just changed the type parameter in `InstrumentedAttribute` from `_T` to `_T_co`.

### Checklist
<!-- go over following points. check them with an `x` if they do apply, (they turn into clickable checkboxes once the PR is submitted, so no need to do everything at once)

-->

This pull request is:

- [ ] A documentation / typographical / small typing error fix
- Good to go, no issue or tests are needed
- [x] A short code fix
- please include the issue number, and create an issue if none exists, which
  must include a complete example of the issue.  one line code fixes without an
  issue and demonstration will not be accepted.
- Please include: `Fixes: #<issue number>` in the commit message
- please include tests.   one line code fixes without tests will not be accepted.
- [ ] A new feature implementation
- please include the issue number, and create an issue if none exists, which must
  include a complete example of how the feature would look.
- Please include: `Fixes: #<issue number>` in the commit message
- please include tests.

**Have a nice day!**

Closes: #11113
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/11113
Pull-request-sha: 3c100f28661f3440769175a17c2763ed25f4b83a

Change-Id: Iff715c24f1556d5604dcd33661a0ee7232b9404b

lib/sqlalchemy/orm/attributes.py
test/typing/plain_files/orm/mapped_covariant.py

index d9b2d8213d1c03b6890d62abd5c11244de0b5592..5b16ce3d6b32bbcd53901a55859098e28f49a481 100644 (file)
@@ -503,7 +503,7 @@ def _queryable_attribute_unreduce(
         return getattr(entity, key)
 
 
-class InstrumentedAttribute(QueryableAttribute[_T]):
+class InstrumentedAttribute(QueryableAttribute[_T_co]):
     """Class bound instrumented attribute which adds basic
     :term:`descriptor` methods.
 
@@ -544,14 +544,14 @@ class InstrumentedAttribute(QueryableAttribute[_T]):
     @overload
     def __get__(
         self, instance: None, owner: Any
-    ) -> InstrumentedAttribute[_T]: ...
+    ) -> InstrumentedAttribute[_T_co]: ...
 
     @overload
-    def __get__(self, instance: object, owner: Any) -> _T: ...
+    def __get__(self, instance: object, owner: Any) -> _T_co: ...
 
     def __get__(
         self, instance: Optional[object], owner: Any
-    ) -> Union[InstrumentedAttribute[_T], _T]:
+    ) -> Union[InstrumentedAttribute[_T_co], _T_co]:
         if instance is None:
             return self
 
index 9f964021b313ebc437e4f89df40d528be59ab2f0..680e925de3641e3508f27ff71f5280e8bc38c326 100644 (file)
@@ -2,12 +2,15 @@
 \r
 from datetime import datetime\r
 from typing import Protocol\r
+from typing import Sequence\r
+from typing import TypeVar\r
 from typing import Union\r
 \r
 from sqlalchemy import ForeignKey\r
 from sqlalchemy import func\r
 from sqlalchemy import Nullable\r
 from sqlalchemy.orm import DeclarativeBase\r
+from sqlalchemy.orm import InstrumentedAttribute\r
 from sqlalchemy.orm import Mapped\r
 from sqlalchemy.orm import mapped_column\r
 from sqlalchemy.orm import relationship\r
@@ -43,6 +46,8 @@ class Parent(Base):
 \r
     name: Mapped[str] = mapped_column(primary_key=True)\r
 \r
+    children: Mapped[Sequence["Child"]] = relationship("Child")\r
+\r
 \r
 class Child(Base):\r
     __tablename__ = "child"\r
@@ -55,6 +60,23 @@ class Child(Base):
 \r
 assert get_parent_name(Child(parent=Parent(name="foo"))) == "foo"\r
 \r
+# Make sure that relationships are covariant as well\r
+_BaseT = TypeVar("_BaseT", bound=Base, covariant=True)\r
+RelationshipType = (\r
+    InstrumentedAttribute[_BaseT]\r
+    | InstrumentedAttribute[Sequence[_BaseT]]\r
+    | InstrumentedAttribute[_BaseT | None]\r
+)\r
+\r
+\r
+def operate_on_relationships(\r
+    relationships: list[RelationshipType[_BaseT]],\r
+) -> int:\r
+    return len(relationships)\r
+\r
+\r
+assert operate_on_relationships([Parent.children, Child.parent]) == 2\r
+\r
 # other test\r
 \r
 \r