From: Mike Bayer Date: Thu, 15 Jun 2023 13:18:38 +0000 (-0400) Subject: improve support for declared_attr returning ORMDescriptor X-Git-Tag: rel_2_0_17~13^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=108f5ec3feed145d371cbd1c54d55d6601bbd0f7;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git improve support for declared_attr returning ORMDescriptor Fixed issue in ORM Annotated Declarative which prevented a :class:`_orm.declared_attr` with or without :attr:`_orm.declared_attr.directive` from being used on a mixin which did not return a :class:`.Mapped` datatype, and instead returned a supplemental ORM datatype such as :class:`.AssociationProxy`. The Declarative runtime would erroneously try to interpret this annotation as needing to be :class:`.Mapped` and raise an error. Fixed typing issue where using the :class:`.AssociationProxy` return type from a :class:`_orm.declared_attr` function was disallowed. Fixes: #9957 Change-Id: I797c5bbdb3d1e81a04ed21c6558ec349b970476f --- diff --git a/doc/build/changelog/unreleased_20/9957.rst b/doc/build/changelog/unreleased_20/9957.rst new file mode 100644 index 0000000000..4a04cba503 --- /dev/null +++ b/doc/build/changelog/unreleased_20/9957.rst @@ -0,0 +1,18 @@ +.. change:: + :tags: bug, orm + :tickets: 9957 + + Fixed issue in ORM Annotated Declarative which prevented a + :class:`_orm.declared_attr` from being used on a mixin which did not return + a :class:`.Mapped` datatype, and instead returned a supplemental ORM + datatype such as :class:`.AssociationProxy`. The Declarative runtime would + erroneously try to interpret this annotation as needing to be + :class:`.Mapped` and raise an error. + + +.. change:: + :tags: bug, orm, typing + :tickets: 9957 + + Fixed typing issue where using the :class:`.AssociationProxy` return type + from a :class:`_orm.declared_attr` function was disallowed. diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py index 76605e26e3..0b74e647d4 100644 --- a/lib/sqlalchemy/ext/associationproxy.py +++ b/lib/sqlalchemy/ext/associationproxy.py @@ -399,7 +399,7 @@ class AssociationProxy( self._attribute_options = _DEFAULT_ATTRIBUTE_OPTIONS @overload - def __get__(self, instance: Any, owner: Literal[None]) -> Self: + def __get__(self, instance: Literal[None], owner: Literal[None]) -> Self: ... @overload diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py index 01b78aa99c..6439561d82 100644 --- a/lib/sqlalchemy/orm/decl_api.py +++ b/lib/sqlalchemy/orm/decl_api.py @@ -46,6 +46,7 @@ from .attributes import InstrumentedAttribute from .base import _inspect_mapped_class from .base import _is_mapped_class from .base import Mapped +from .base import ORMDescriptor from .decl_base import _add_attribute from .decl_base import _as_declarative from .decl_base import _ClassScanMapperConfig @@ -98,7 +99,7 @@ _TypeAnnotationMapType = Mapping[Any, "_TypeEngineArgument[Any]"] _MutableTypeAnnotationMapType = Dict[Any, "_TypeEngineArgument[Any]"] _DeclaredAttrDecorated = Callable[ - ..., Union[Mapped[_T], SQLCoreOperations[_T]] + ..., Union[Mapped[_T], ORMDescriptor[_T], SQLCoreOperations[_T]] ] diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index f87168c089..f50dba52bc 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -51,6 +51,7 @@ from .base import Mapped from .base import object_mapper as object_mapper from .base import object_state as object_state # noqa: F401 from .base import opt_manager_of_class +from .base import ORMDescriptor from .base import state_attribute_str as state_attribute_str # noqa: F401 from .base import state_class_str as state_class_str # noqa: F401 from .base import state_str as state_str # noqa: F401 @@ -2347,10 +2348,18 @@ def _extract_mapped_subtype( annotated, _MappedAnnotationBase ): if expect_mapped: - if getattr(annotated, "__origin__", None) is typing.ClassVar: + if not raiseerr: return None - if not raiseerr: + origin = getattr(annotated, "__origin__", None) + if origin is typing.ClassVar: + return None + + # check for other kind of ORM descriptor like AssociationProxy, + # don't raise for that (issue #9957) + elif isinstance(origin, type) and issubclass( + origin, ORMDescriptor + ): return None raise sa_exc.ArgumentError( diff --git a/test/ext/mypy/plain_files/association_proxy_three.py b/test/ext/mypy/plain_files/association_proxy_three.py new file mode 100644 index 0000000000..f338681f7c --- /dev/null +++ b/test/ext/mypy/plain_files/association_proxy_three.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from typing import List + +from sqlalchemy import ForeignKey +from sqlalchemy.ext.associationproxy import association_proxy +from sqlalchemy.ext.associationproxy import AssociationProxy +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import declared_attr +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import relationship + + +class Base(DeclarativeBase): + pass + + +class Milestone: + id: Mapped[int] = mapped_column(primary_key=True) + + @declared_attr + def users(self) -> Mapped[List["User"]]: + return relationship("User") + + @declared_attr + def user_ids(self) -> AssociationProxy[List[int]]: + return association_proxy("users", "id") + + +class BranchMilestone(Milestone, Base): + __tablename__ = "branch_milestones" + + +class User(Base): + __tablename__ = "user" + id: Mapped[int] = mapped_column(primary_key=True) + branch_id: Mapped[int] = mapped_column(ForeignKey("branch_milestones.id")) + + +bm = BranchMilestone() + +x1 = bm.user_ids + +# EXPECTED_TYPE: list[int] +reveal_type(x1) diff --git a/test/orm/declarative/test_tm_future_annotations_sync.py b/test/orm/declarative/test_tm_future_annotations_sync.py index f0a74b26c3..b41021003a 100644 --- a/test/orm/declarative/test_tm_future_annotations_sync.py +++ b/test/orm/declarative/test_tm_future_annotations_sync.py @@ -49,6 +49,8 @@ from sqlalchemy import testing from sqlalchemy import types from sqlalchemy import VARCHAR from sqlalchemy.exc import ArgumentError +from sqlalchemy.ext.associationproxy import association_proxy +from sqlalchemy.ext.associationproxy import AssociationProxy from sqlalchemy.ext.mutable import MutableDict from sqlalchemy.orm import as_declarative from sqlalchemy.orm import composite @@ -1890,6 +1892,58 @@ class MixinTest(fixtures.TestBase, testing.AssertsCompiledSQL): "JOIN related ON related.id = b.related_id", ) + @testing.variation("use_directive", [True, False]) + @testing.variation("use_annotation", [True, False]) + def test_supplemental_declared_attr( + self, decl_base, use_directive, use_annotation + ): + """test #9957""" + + class User(decl_base): + __tablename__ = "user" + id: Mapped[int] = mapped_column(primary_key=True) + branch_id: Mapped[int] = mapped_column(ForeignKey("thing.id")) + + class Mixin: + id: Mapped[int] = mapped_column(primary_key=True) + + @declared_attr + def users(self) -> Mapped[List[User]]: + return relationship(User) + + if use_directive: + if use_annotation: + + @declared_attr.directive + def user_ids(self) -> AssociationProxy[List[int]]: + return association_proxy("users", "id") + + else: + + @declared_attr.directive + def user_ids(self): + return association_proxy("users", "id") + + else: + if use_annotation: + + @declared_attr + def user_ids(self) -> AssociationProxy[List[int]]: + return association_proxy("users", "id") + + else: + + @declared_attr + def user_ids(self): + return association_proxy("users", "id") + + class Thing(Mixin, decl_base): + __tablename__ = "thing" + + t1 = Thing() + t1.users.extend([User(id=1), User(id=2)]) + eq_(t1.user_ids, [1, 2]) + class RelationshipLHSTest(fixtures.TestBase, testing.AssertsCompiledSQL): __dialect__ = "default" diff --git a/test/orm/declarative/test_typed_mapping.py b/test/orm/declarative/test_typed_mapping.py index 564db64722..e29cb5ff87 100644 --- a/test/orm/declarative/test_typed_mapping.py +++ b/test/orm/declarative/test_typed_mapping.py @@ -40,6 +40,8 @@ from sqlalchemy import testing from sqlalchemy import types from sqlalchemy import VARCHAR from sqlalchemy.exc import ArgumentError +from sqlalchemy.ext.associationproxy import association_proxy +from sqlalchemy.ext.associationproxy import AssociationProxy from sqlalchemy.ext.mutable import MutableDict from sqlalchemy.orm import as_declarative from sqlalchemy.orm import composite @@ -1881,6 +1883,58 @@ class MixinTest(fixtures.TestBase, testing.AssertsCompiledSQL): "JOIN related ON related.id = b.related_id", ) + @testing.variation("use_directive", [True, False]) + @testing.variation("use_annotation", [True, False]) + def test_supplemental_declared_attr( + self, decl_base, use_directive, use_annotation + ): + """test #9957""" + + class User(decl_base): + __tablename__ = "user" + id: Mapped[int] = mapped_column(primary_key=True) + branch_id: Mapped[int] = mapped_column(ForeignKey("thing.id")) + + class Mixin: + id: Mapped[int] = mapped_column(primary_key=True) + + @declared_attr + def users(self) -> Mapped[List[User]]: + return relationship(User) + + if use_directive: + if use_annotation: + + @declared_attr.directive + def user_ids(self) -> AssociationProxy[List[int]]: + return association_proxy("users", "id") + + else: + + @declared_attr.directive + def user_ids(self): + return association_proxy("users", "id") + + else: + if use_annotation: + + @declared_attr + def user_ids(self) -> AssociationProxy[List[int]]: + return association_proxy("users", "id") + + else: + + @declared_attr + def user_ids(self): + return association_proxy("users", "id") + + class Thing(Mixin, decl_base): + __tablename__ = "thing" + + t1 = Thing() + t1.users.extend([User(id=1), User(id=2)]) + eq_(t1.user_ids, [1, 2]) + class RelationshipLHSTest(fixtures.TestBase, testing.AssertsCompiledSQL): __dialect__ = "default"