]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
improve support for declared_attr returning ORMDescriptor
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 15 Jun 2023 13:18:38 +0000 (09:18 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 15 Jun 2023 14:38:12 +0000 (10:38 -0400)
Fixed issue in ORM Annotated Declarative which prevented a
:class:`_orm.declared_attr` with or without
:attr:`_orm.declared_attr.directive` from being used on a mixin which did
not return a :class:`.Mapped` datatype, and instead returned a supplemental
ORM datatype such as :class:`.AssociationProxy`.  The Declarative runtime
would erroneously try to interpret this annotation as needing to be
:class:`.Mapped` and raise an error.

Fixed typing issue where using the :class:`.AssociationProxy` return type
from a :class:`_orm.declared_attr` function was disallowed.

Fixes: #9957
Change-Id: I797c5bbdb3d1e81a04ed21c6558ec349b970476f

doc/build/changelog/unreleased_20/9957.rst [new file with mode: 0644]
lib/sqlalchemy/ext/associationproxy.py
lib/sqlalchemy/orm/decl_api.py
lib/sqlalchemy/orm/util.py
test/ext/mypy/plain_files/association_proxy_three.py [new file with mode: 0644]
test/orm/declarative/test_tm_future_annotations_sync.py
test/orm/declarative/test_typed_mapping.py

diff --git a/doc/build/changelog/unreleased_20/9957.rst b/doc/build/changelog/unreleased_20/9957.rst
new file mode 100644 (file)
index 0000000..4a04cba
--- /dev/null
@@ -0,0 +1,18 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 9957
+
+    Fixed issue in ORM Annotated Declarative which prevented a
+    :class:`_orm.declared_attr` from being used on a mixin which did not return
+    a :class:`.Mapped` datatype, and instead returned a supplemental ORM
+    datatype such as :class:`.AssociationProxy`.  The Declarative runtime would
+    erroneously try to interpret this annotation as needing to be
+    :class:`.Mapped` and raise an error.
+
+
+.. change::
+    :tags: bug, orm, typing
+    :tickets: 9957
+
+    Fixed typing issue where using the :class:`.AssociationProxy` return type
+    from a :class:`_orm.declared_attr` function was disallowed.
index 76605e26e3e7e26567f9032d5f88ffaffafb43fd..0b74e647d4e81a15ad9d1444c850793a5f29121a 100644 (file)
@@ -399,7 +399,7 @@ class AssociationProxy(
             self._attribute_options = _DEFAULT_ATTRIBUTE_OPTIONS
 
     @overload
-    def __get__(self, instance: Any, owner: Literal[None]) -> Self:
+    def __get__(self, instance: Literal[None], owner: Literal[None]) -> Self:
         ...
 
     @overload
index 01b78aa99c58acad26641b75ade90e299d7250e7..6439561d82f5790bbbc4848eabb3d2e0c7eda57e 100644 (file)
@@ -46,6 +46,7 @@ from .attributes import InstrumentedAttribute
 from .base import _inspect_mapped_class
 from .base import _is_mapped_class
 from .base import Mapped
+from .base import ORMDescriptor
 from .decl_base import _add_attribute
 from .decl_base import _as_declarative
 from .decl_base import _ClassScanMapperConfig
@@ -98,7 +99,7 @@ _TypeAnnotationMapType = Mapping[Any, "_TypeEngineArgument[Any]"]
 _MutableTypeAnnotationMapType = Dict[Any, "_TypeEngineArgument[Any]"]
 
 _DeclaredAttrDecorated = Callable[
-    ..., Union[Mapped[_T], SQLCoreOperations[_T]]
+    ..., Union[Mapped[_T], ORMDescriptor[_T], SQLCoreOperations[_T]]
 ]
 
 
index f87168c0890b332d582582e6ffe79813a8627b43..f50dba52bc0539170e3ec24756732485cd703655 100644 (file)
@@ -51,6 +51,7 @@ from .base import Mapped
 from .base import object_mapper as object_mapper
 from .base import object_state as object_state  # noqa: F401
 from .base import opt_manager_of_class
+from .base import ORMDescriptor
 from .base import state_attribute_str as state_attribute_str  # noqa: F401
 from .base import state_class_str as state_class_str  # noqa: F401
 from .base import state_str as state_str  # noqa: F401
@@ -2347,10 +2348,18 @@ def _extract_mapped_subtype(
             annotated, _MappedAnnotationBase
         ):
             if expect_mapped:
-                if getattr(annotated, "__origin__", None) is typing.ClassVar:
+                if not raiseerr:
                     return None
 
-                if not raiseerr:
+                origin = getattr(annotated, "__origin__", None)
+                if origin is typing.ClassVar:
+                    return None
+
+                # check for other kind of ORM descriptor like AssociationProxy,
+                # don't raise for that (issue #9957)
+                elif isinstance(origin, type) and issubclass(
+                    origin, ORMDescriptor
+                ):
                     return None
 
                 raise sa_exc.ArgumentError(
diff --git a/test/ext/mypy/plain_files/association_proxy_three.py b/test/ext/mypy/plain_files/association_proxy_three.py
new file mode 100644 (file)
index 0000000..f338681
--- /dev/null
@@ -0,0 +1,46 @@
+from __future__ import annotations
+
+from typing import List
+
+from sqlalchemy import ForeignKey
+from sqlalchemy.ext.associationproxy import association_proxy
+from sqlalchemy.ext.associationproxy import AssociationProxy
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import declared_attr
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
+from sqlalchemy.orm import relationship
+
+
+class Base(DeclarativeBase):
+    pass
+
+
+class Milestone:
+    id: Mapped[int] = mapped_column(primary_key=True)
+
+    @declared_attr
+    def users(self) -> Mapped[List["User"]]:
+        return relationship("User")
+
+    @declared_attr
+    def user_ids(self) -> AssociationProxy[List[int]]:
+        return association_proxy("users", "id")
+
+
+class BranchMilestone(Milestone, Base):
+    __tablename__ = "branch_milestones"
+
+
+class User(Base):
+    __tablename__ = "user"
+    id: Mapped[int] = mapped_column(primary_key=True)
+    branch_id: Mapped[int] = mapped_column(ForeignKey("branch_milestones.id"))
+
+
+bm = BranchMilestone()
+
+x1 = bm.user_ids
+
+# EXPECTED_TYPE: list[int]
+reveal_type(x1)
index f0a74b26c3030f687b50414077f3829d8a42cb5e..b41021003a22d2cb6f65790fd4f544402428c38b 100644 (file)
@@ -49,6 +49,8 @@ from sqlalchemy import testing
 from sqlalchemy import types
 from sqlalchemy import VARCHAR
 from sqlalchemy.exc import ArgumentError
+from sqlalchemy.ext.associationproxy import association_proxy
+from sqlalchemy.ext.associationproxy import AssociationProxy
 from sqlalchemy.ext.mutable import MutableDict
 from sqlalchemy.orm import as_declarative
 from sqlalchemy.orm import composite
@@ -1890,6 +1892,58 @@ class MixinTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             "JOIN related ON related.id = b.related_id",
         )
 
+    @testing.variation("use_directive", [True, False])
+    @testing.variation("use_annotation", [True, False])
+    def test_supplemental_declared_attr(
+        self, decl_base, use_directive, use_annotation
+    ):
+        """test #9957"""
+
+        class User(decl_base):
+            __tablename__ = "user"
+            id: Mapped[int] = mapped_column(primary_key=True)
+            branch_id: Mapped[int] = mapped_column(ForeignKey("thing.id"))
+
+        class Mixin:
+            id: Mapped[int] = mapped_column(primary_key=True)
+
+            @declared_attr
+            def users(self) -> Mapped[List[User]]:
+                return relationship(User)
+
+            if use_directive:
+                if use_annotation:
+
+                    @declared_attr.directive
+                    def user_ids(self) -> AssociationProxy[List[int]]:
+                        return association_proxy("users", "id")
+
+                else:
+
+                    @declared_attr.directive
+                    def user_ids(self):
+                        return association_proxy("users", "id")
+
+            else:
+                if use_annotation:
+
+                    @declared_attr
+                    def user_ids(self) -> AssociationProxy[List[int]]:
+                        return association_proxy("users", "id")
+
+                else:
+
+                    @declared_attr
+                    def user_ids(self):
+                        return association_proxy("users", "id")
+
+        class Thing(Mixin, decl_base):
+            __tablename__ = "thing"
+
+        t1 = Thing()
+        t1.users.extend([User(id=1), User(id=2)])
+        eq_(t1.user_ids, [1, 2])
+
 
 class RelationshipLHSTest(fixtures.TestBase, testing.AssertsCompiledSQL):
     __dialect__ = "default"
index 564db647228a78c34ad99e8235113f3115463b55..e29cb5ff87db73231d6bdf0e87624ccb701d4097 100644 (file)
@@ -40,6 +40,8 @@ from sqlalchemy import testing
 from sqlalchemy import types
 from sqlalchemy import VARCHAR
 from sqlalchemy.exc import ArgumentError
+from sqlalchemy.ext.associationproxy import association_proxy
+from sqlalchemy.ext.associationproxy import AssociationProxy
 from sqlalchemy.ext.mutable import MutableDict
 from sqlalchemy.orm import as_declarative
 from sqlalchemy.orm import composite
@@ -1881,6 +1883,58 @@ class MixinTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             "JOIN related ON related.id = b.related_id",
         )
 
+    @testing.variation("use_directive", [True, False])
+    @testing.variation("use_annotation", [True, False])
+    def test_supplemental_declared_attr(
+        self, decl_base, use_directive, use_annotation
+    ):
+        """test #9957"""
+
+        class User(decl_base):
+            __tablename__ = "user"
+            id: Mapped[int] = mapped_column(primary_key=True)
+            branch_id: Mapped[int] = mapped_column(ForeignKey("thing.id"))
+
+        class Mixin:
+            id: Mapped[int] = mapped_column(primary_key=True)
+
+            @declared_attr
+            def users(self) -> Mapped[List[User]]:
+                return relationship(User)
+
+            if use_directive:
+                if use_annotation:
+
+                    @declared_attr.directive
+                    def user_ids(self) -> AssociationProxy[List[int]]:
+                        return association_proxy("users", "id")
+
+                else:
+
+                    @declared_attr.directive
+                    def user_ids(self):
+                        return association_proxy("users", "id")
+
+            else:
+                if use_annotation:
+
+                    @declared_attr
+                    def user_ids(self) -> AssociationProxy[List[int]]:
+                        return association_proxy("users", "id")
+
+                else:
+
+                    @declared_attr
+                    def user_ids(self):
+                        return association_proxy("users", "id")
+
+        class Thing(Mixin, decl_base):
+            __tablename__ = "thing"
+
+        t1 = Thing()
+        t1.users.extend([User(id=1), User(id=2)])
+        eq_(t1.user_ids, [1, 2])
+
 
 class RelationshipLHSTest(fixtures.TestBase, testing.AssertsCompiledSQL):
     __dialect__ = "default"