]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
ensure assoc proxy from aliased() is generated in correct context
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 21 Aug 2025 15:20:11 +0000 (11:20 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 22 Aug 2025 14:51:09 +0000 (10:51 -0400)
Improved association proxy to behave slightly better when the parent class
is used in an :func:`_orm.aliased` construct, so that the proxy as
delivered by the :class:`.Aliased` behaves appropriate in terms of that
aliased construct, including operators like ``.any()`` and ``.has()`` work
correctly.

Fixes: #11622
Change-Id: I6220d984d4323a01a38bd89cfbb1bae46d81c24e

doc/build/changelog/unreleased_20/11622.rst [new file with mode: 0644]
lib/sqlalchemy/ext/associationproxy.py
test/ext/test_associationproxy.py

diff --git a/doc/build/changelog/unreleased_20/11622.rst b/doc/build/changelog/unreleased_20/11622.rst
new file mode 100644 (file)
index 0000000..25569da
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 11622
+
+    Improved association proxy to behave slightly better when the parent class
+    is used in an :func:`_orm.aliased` construct, so that the proxy as
+    delivered by the :class:`.Aliased` behaves appropriate in terms of that
+    aliased construct, including operators like ``.any()`` and ``.has()`` work
+    correctly.
index 22d2bb570d76d72e9e511588a92d09fcb481a891..776ad6a44d4555429d4d649788458f71779647f5 100644 (file)
@@ -69,6 +69,7 @@ if typing.TYPE_CHECKING:
     from ..orm.interfaces import MapperProperty
     from ..orm.interfaces import PropComparator
     from ..orm.mapper import Mapper
+    from ..orm.util import AliasedInsp
     from ..sql._typing import _ColumnExpressionArgument
     from ..sql._typing import _InfoType
 
@@ -1231,6 +1232,11 @@ class ObjectAssociationProxyInstance(AssociationProxyInstance[_T]):
     _target_is_object: bool = True
     _is_canonical = True
 
+    def adapt_to_entity(
+        self, aliased_insp: AliasedInsp[Any]
+    ) -> AliasedAssociationProxyInstance[_T]:
+        return AliasedAssociationProxyInstance(self, aliased_insp)
+
     def contains(self, other: Any, **kw: Any) -> ColumnElement[bool]:
         """Produce a proxied 'contains' expression using EXISTS.
 
@@ -1284,6 +1290,44 @@ class ObjectAssociationProxyInstance(AssociationProxyInstance[_T]):
         )
 
 
+class AliasedAssociationProxyInstance(ObjectAssociationProxyInstance[_T]):
+    def __init__(
+        self,
+        parent_instance: ObjectAssociationProxyInstance[_T],
+        aliased_insp: AliasedInsp[Any],
+    ) -> None:
+        self.parent = parent_instance.parent
+        self.owning_class = parent_instance.owning_class
+        self.aliased_insp = aliased_insp
+        self.target_collection = parent_instance.target_collection
+        self.collection_class = None
+        self.target_class = parent_instance.target_class
+        self.value_attr = parent_instance.value_attr
+
+    @property
+    def _comparator(self) -> PropComparator[Any]:
+        return getattr(  # type: ignore
+            self.aliased_insp.entity, self.target_collection
+        ).comparator
+
+    @property
+    def local_attr(self) -> SQLORMOperations[Any]:
+        """The 'local' class attribute referenced by this
+        :class:`.AssociationProxyInstance`.
+
+        .. seealso::
+
+            :attr:`.AssociationProxyInstance.attr`
+
+            :attr:`.AssociationProxyInstance.remote_attr`
+
+        """
+        return cast(
+            "SQLORMOperations[Any]",
+            getattr(self.aliased_insp.entity, self.target_collection),
+        )
+
+
 class ColumnAssociationProxyInstance(AssociationProxyInstance[_T]):
     """an :class:`.AssociationProxyInstance` that has a database column as a
     target.
index 1aca0c97e259ad9109abbae84f8369fb7ce41f3e..792e6a23ad0b9b904e8f5868730783e0c018534a 100644 (file)
@@ -16,6 +16,7 @@ from sqlalchemy import inspect
 from sqlalchemy import Integer
 from sqlalchemy import MetaData
 from sqlalchemy import or_
+from sqlalchemy import select
 from sqlalchemy import String
 from sqlalchemy import testing
 from sqlalchemy.engine import default
@@ -1825,13 +1826,18 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL):
         # TODO: this is not the correct pattern, use session per test
         cls.session = Session(testing.db)
 
+    def _query_equivalent(self, q_proxy, q_direct):
+        self._equivalent(q_proxy.statement, q_direct.statement)
+        eq_(q_proxy.all(), q_direct.all())
+
     def _equivalent(self, q_proxy, q_direct):
-        proxy_sql = q_proxy.statement.compile(dialect=default.DefaultDialect())
-        direct_sql = q_direct.statement.compile(
-            dialect=default.DefaultDialect()
-        )
+        proxy_sql = q_proxy.compile(dialect=default.DefaultDialect())
+        direct_sql = q_direct.compile(dialect=default.DefaultDialect())
         eq_(str(proxy_sql), str(direct_sql))
-        eq_(q_proxy.all(), q_direct.all())
+
+    def _statement_equivalent(self, session, q_proxy, q_direct):
+        self._equivalent(q_proxy, q_direct)
+        eq_(session.scalars(q_proxy).all(), session.scalars(q_direct).all())
 
     def test_no_straight_expr(self):
         User = self.classes.User
@@ -1871,12 +1877,12 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL):
         q2 = self.session.query(User).filter(
             User.user_keywords.any(UserKeyword.value == "singular8")
         )
-        self._equivalent(q1, q2)
+        self._query_equivalent(q1, q2)
 
     def test_filter_any_kwarg_ul_nul(self):
         UserKeyword, User = self.classes.UserKeyword, self.classes.User
 
-        self._equivalent(
+        self._query_equivalent(
             self.session.query(User).filter(
                 User.keywords.any(keyword="jumped")
             ),
@@ -1890,7 +1896,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL):
     def test_filter_has_kwarg_nul_nul(self):
         UserKeyword, Keyword = self.classes.UserKeyword, self.classes.Keyword
 
-        self._equivalent(
+        self._query_equivalent(
             self.session.query(Keyword).filter(Keyword.user.has(name="user2")),
             self.session.query(Keyword).filter(
                 Keyword.user_keyword.has(UserKeyword.user.has(name="user2"))
@@ -1900,7 +1906,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL):
     def test_filter_has_kwarg_nul_ul(self):
         User, Singular = self.classes.User, self.classes.Singular
 
-        self._equivalent(
+        self._query_equivalent(
             self.session.query(User).filter(
                 User.singular_keywords.any(keyword="jumped")
             ),
@@ -1916,7 +1922,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL):
             self.classes.Keyword,
         )
 
-        self._equivalent(
+        self._query_equivalent(
             self.session.query(User).filter(
                 User.keywords.any(Keyword.keyword == "jumped")
             ),
@@ -1934,7 +1940,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL):
             self.classes.Keyword,
         )
 
-        self._equivalent(
+        self._query_equivalent(
             self.session.query(Keyword).filter(
                 Keyword.user.has(User.name == "user2")
             ),
@@ -1952,7 +1958,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL):
             self.classes.Singular,
         )
 
-        self._equivalent(
+        self._query_equivalent(
             self.session.query(User).filter(
                 User.singular_keywords.any(Keyword.keyword == "jumped")
             ),
@@ -1966,7 +1972,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL):
     def test_filter_contains_ul_nul(self):
         User = self.classes.User
 
-        self._equivalent(
+        self._query_equivalent(
             self.session.query(User).filter(User.keywords.contains(self.kw)),
             self.session.query(User).filter(
                 User.user_keywords.any(keyword=self.kw)
@@ -1979,7 +1985,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL):
         with expect_warnings(
             "Got None for value of column keywords.singular_id;"
         ):
-            self._equivalent(
+            self._query_equivalent(
                 self.session.query(User).filter(
                     User.singular_keywords.contains(self.kw)
                 ),
@@ -1991,7 +1997,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL):
     def test_filter_eq_nul_nul(self):
         Keyword = self.classes.Keyword
 
-        self._equivalent(
+        self._query_equivalent(
             self.session.query(Keyword).filter(Keyword.user == self.u),
             self.session.query(Keyword).filter(
                 Keyword.user_keyword.has(user=self.u)
@@ -2002,7 +2008,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL):
         Keyword = self.classes.Keyword
         UserKeyword = self.classes.UserKeyword
 
-        self._equivalent(
+        self._query_equivalent(
             self.session.query(Keyword).filter(Keyword.user != self.u),
             self.session.query(Keyword).filter(
                 Keyword.user_keyword.has(UserKeyword.user != self.u)
@@ -2012,7 +2018,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL):
     def test_filter_eq_null_nul_nul(self):
         UserKeyword, Keyword = self.classes.UserKeyword, self.classes.Keyword
 
-        self._equivalent(
+        self._query_equivalent(
             self.session.query(Keyword).filter(Keyword.user == None),  # noqa
             self.session.query(Keyword).filter(
                 or_(
@@ -2025,7 +2031,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL):
     def test_filter_ne_null_nul_nul(self):
         UserKeyword, Keyword = self.classes.UserKeyword, self.classes.Keyword
 
-        self._equivalent(
+        self._query_equivalent(
             self.session.query(Keyword).filter(Keyword.user != None),  # noqa
             self.session.query(Keyword).filter(
                 Keyword.user_keyword.has(UserKeyword.user != None)
@@ -2036,7 +2042,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL):
         UserKeyword = self.classes.UserKeyword
         User = self.classes.User
 
-        self._equivalent(
+        self._query_equivalent(
             self.session.query(UserKeyword).filter(
                 UserKeyword.singular == None
             ),  # noqa
@@ -2052,7 +2058,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL):
         User = self.classes.User
         Singular = self.classes.Singular
 
-        self._equivalent(
+        self._query_equivalent(
             self.session.query(User).filter(
                 User.singular_value == None
             ),  # noqa
@@ -2070,7 +2076,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL):
         Singular = self.classes.Singular
 
         s4 = self.session.query(Singular).filter_by(value="singular4").one()
-        self._equivalent(
+        self._query_equivalent(
             self.session.query(UserKeyword).filter(UserKeyword.singular != s4),
             self.session.query(UserKeyword).filter(
                 UserKeyword.user.has(User.singular != s4)
@@ -2081,7 +2087,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL):
         User = self.classes.User
         Singular = self.classes.Singular
 
-        self._equivalent(
+        self._query_equivalent(
             self.session.query(User).filter(
                 User.singular_value != "singular4"
             ),
@@ -2094,7 +2100,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL):
         User = self.classes.User
         Singular = self.classes.Singular
 
-        self._equivalent(
+        self._query_equivalent(
             self.session.query(User).filter(
                 User.singular_value == "singular4"
             ),
@@ -2107,7 +2113,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL):
         User = self.classes.User
         Singular = self.classes.Singular
 
-        self._equivalent(
+        self._query_equivalent(
             self.session.query(User).filter(
                 User.singular_value != None
             ),  # noqa
@@ -2122,7 +2128,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL):
         User = self.classes.User
         self.classes.Singular
 
-        self._equivalent(
+        self._query_equivalent(
             self.session.query(User).filter(User.singular_value.has()),
             self.session.query(User).filter(User.singular.has()),
         )
@@ -2133,7 +2139,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL):
         User = self.classes.User
         self.classes.Singular
 
-        self._equivalent(
+        self._query_equivalent(
             self.session.query(User).filter(~User.singular_value.has()),
             self.session.query(User).filter(~User.singular.has()),
         )
@@ -2176,7 +2182,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL):
                 )
             )
         )
-        self._equivalent(q1, q2)
+        self._query_equivalent(q1, q2)
 
     def test_filter_has_chained_has_to_any(self):
         User = self.classes.User
@@ -2205,7 +2211,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL):
                 Singular.keywords.any(Keyword.keyword == "brown")
             )
         )
-        self._equivalent(q1, q2)
+        self._query_equivalent(q1, q2)
 
     def test_filter_has_scalar_raises(self):
         User = self.classes.User
@@ -2241,7 +2247,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL):
             )
         )
 
-        self._equivalent(q1, q2)
+        self._query_equivalent(q1, q2)
 
     def test_filter_contains_chained_any_to_has(self):
         User = self.classes.User
@@ -2270,7 +2276,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL):
                 UserKeyword.keyword.has(Keyword.keyword == "brown")
             )
         )
-        self._equivalent(q1, q2)
+        self._query_equivalent(q1, q2)
 
     def test_filter_contains_chained_any_to_has_to_eq(self):
         User = self.classes.User
@@ -2304,7 +2310,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL):
                 UserKeyword.user.has(User.singular == singular)
             )
         )
-        self._equivalent(q1, q2)
+        self._query_equivalent(q1, q2)
 
     def test_has_criterion_nul(self):
         # but we don't allow that with any criterion...
@@ -2348,7 +2354,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL):
         User = self.classes.User
         Singular = self.classes.Singular
 
-        self._equivalent(
+        self._query_equivalent(
             self.session.query(User).filter(User.singular_value.like("foo")),
             self.session.query(User).filter(
                 User.singular.has(Singular.value.like("foo"))
@@ -2359,7 +2365,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL):
         User = self.classes.User
         Singular = self.classes.Singular
 
-        self._equivalent(
+        self._query_equivalent(
             self.session.query(User).filter(
                 User.singular_value.contains("foo")
             ),
@@ -2372,7 +2378,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL):
         User = self.classes.User
         Singular = self.classes.Singular
 
-        self._equivalent(
+        self._query_equivalent(
             self.session.query(User).filter(User.singular_value == "foo"),
             self.session.query(User).filter(
                 User.singular.has(Singular.value == "foo")
@@ -2383,7 +2389,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL):
         User = self.classes.User
         Singular = self.classes.Singular
 
-        self._equivalent(
+        self._query_equivalent(
             self.session.query(User).filter(User.singular_value != "foo"),
             self.session.query(User).filter(
                 User.singular.has(Singular.value != "foo")
@@ -2394,7 +2400,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL):
         User = self.classes.User
         Singular = self.classes.Singular
 
-        self._equivalent(
+        self._query_equivalent(
             self.session.query(User).filter(User.singular_value == None),
             self.session.query(User).filter(
                 or_(
@@ -2451,6 +2457,48 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL):
             "userkeywords.keyword_id",
         )
 
+    @testing.variation("use_aliased", [True, False])
+    def test_aliased_class_one(self, use_aliased):
+        """test #11622"""
+        User = self.classes.User
+        UserKeyword = self.classes.UserKeyword
+
+        if use_aliased:
+            second_user = aliased(
+                User,
+                select(User).where(User.name == "second").cte("second_user"),
+            )
+        else:
+            second_user = User
+
+        s1 = select(second_user).where(second_user.keywords.any())
+        s2 = select(second_user).where(
+            second_user.user_keywords.any(UserKeyword.keyword.has())
+        )
+
+        self._statement_equivalent(fixture_session(), s1, s2)
+
+        if use_aliased:
+            self.assert_compile(
+                s1,
+                "WITH second_user AS (SELECT users.id AS id, users.name AS "
+                "name, users.singular_id AS singular_id FROM users "
+                "WHERE users.name = :name_1) SELECT second_user.id, "
+                "second_user.name, second_user.singular_id FROM second_user "
+                "WHERE EXISTS (SELECT 1 FROM userkeywords "
+                "WHERE second_user.id = userkeywords.user_id AND "
+                "(EXISTS (SELECT 1 FROM keywords WHERE keywords.id = "
+                "userkeywords.keyword_id)))",
+            )
+        else:
+            self.assert_compile(
+                s1,
+                "SELECT users.id, users.name, users.singular_id FROM users "
+                "WHERE EXISTS (SELECT 1 FROM userkeywords WHERE users.id = "
+                "userkeywords.user_id AND (EXISTS (SELECT 1 FROM keywords "
+                "WHERE keywords.id = userkeywords.keyword_id)))",
+            )
+
 
 class DictOfTupleUpdateTest(fixtures.MappedTest):
     run_create_tables = None