From: Mike Bayer Date: Thu, 21 Aug 2025 15:20:11 +0000 (-0400) Subject: ensure assoc proxy from aliased() is generated in correct context X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=b229245a0ad6c2381aea66e58facfe2e2b92b320;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git ensure assoc proxy from aliased() is generated in correct context Improved association proxy to behave slightly better when the parent class is used in an :func:`_orm.aliased` construct, so that the proxy as delivered by the :class:`.Aliased` behaves appropriate in terms of that aliased construct, including operators like ``.any()`` and ``.has()`` work correctly. Fixes: #11622 Change-Id: I6220d984d4323a01a38bd89cfbb1bae46d81c24e --- diff --git a/doc/build/changelog/unreleased_20/11622.rst b/doc/build/changelog/unreleased_20/11622.rst new file mode 100644 index 0000000000..25569da160 --- /dev/null +++ b/doc/build/changelog/unreleased_20/11622.rst @@ -0,0 +1,9 @@ +.. change:: + :tags: bug, orm + :tickets: 11622 + + Improved association proxy to behave slightly better when the parent class + is used in an :func:`_orm.aliased` construct, so that the proxy as + delivered by the :class:`.Aliased` behaves appropriate in terms of that + aliased construct, including operators like ``.any()`` and ``.has()`` work + correctly. diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py index 22d2bb570d..776ad6a44d 100644 --- a/lib/sqlalchemy/ext/associationproxy.py +++ b/lib/sqlalchemy/ext/associationproxy.py @@ -69,6 +69,7 @@ if typing.TYPE_CHECKING: from ..orm.interfaces import MapperProperty from ..orm.interfaces import PropComparator from ..orm.mapper import Mapper + from ..orm.util import AliasedInsp from ..sql._typing import _ColumnExpressionArgument from ..sql._typing import _InfoType @@ -1231,6 +1232,11 @@ class ObjectAssociationProxyInstance(AssociationProxyInstance[_T]): _target_is_object: bool = True _is_canonical = True + def adapt_to_entity( + self, aliased_insp: AliasedInsp[Any] + ) -> AliasedAssociationProxyInstance[_T]: + return AliasedAssociationProxyInstance(self, aliased_insp) + def contains(self, other: Any, **kw: Any) -> ColumnElement[bool]: """Produce a proxied 'contains' expression using EXISTS. @@ -1284,6 +1290,44 @@ class ObjectAssociationProxyInstance(AssociationProxyInstance[_T]): ) +class AliasedAssociationProxyInstance(ObjectAssociationProxyInstance[_T]): + def __init__( + self, + parent_instance: ObjectAssociationProxyInstance[_T], + aliased_insp: AliasedInsp[Any], + ) -> None: + self.parent = parent_instance.parent + self.owning_class = parent_instance.owning_class + self.aliased_insp = aliased_insp + self.target_collection = parent_instance.target_collection + self.collection_class = None + self.target_class = parent_instance.target_class + self.value_attr = parent_instance.value_attr + + @property + def _comparator(self) -> PropComparator[Any]: + return getattr( # type: ignore + self.aliased_insp.entity, self.target_collection + ).comparator + + @property + def local_attr(self) -> SQLORMOperations[Any]: + """The 'local' class attribute referenced by this + :class:`.AssociationProxyInstance`. + + .. seealso:: + + :attr:`.AssociationProxyInstance.attr` + + :attr:`.AssociationProxyInstance.remote_attr` + + """ + return cast( + "SQLORMOperations[Any]", + getattr(self.aliased_insp.entity, self.target_collection), + ) + + class ColumnAssociationProxyInstance(AssociationProxyInstance[_T]): """an :class:`.AssociationProxyInstance` that has a database column as a target. diff --git a/test/ext/test_associationproxy.py b/test/ext/test_associationproxy.py index 1aca0c97e2..792e6a23ad 100644 --- a/test/ext/test_associationproxy.py +++ b/test/ext/test_associationproxy.py @@ -16,6 +16,7 @@ from sqlalchemy import inspect from sqlalchemy import Integer from sqlalchemy import MetaData from sqlalchemy import or_ +from sqlalchemy import select from sqlalchemy import String from sqlalchemy import testing from sqlalchemy.engine import default @@ -1825,13 +1826,18 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): # TODO: this is not the correct pattern, use session per test cls.session = Session(testing.db) + def _query_equivalent(self, q_proxy, q_direct): + self._equivalent(q_proxy.statement, q_direct.statement) + eq_(q_proxy.all(), q_direct.all()) + def _equivalent(self, q_proxy, q_direct): - proxy_sql = q_proxy.statement.compile(dialect=default.DefaultDialect()) - direct_sql = q_direct.statement.compile( - dialect=default.DefaultDialect() - ) + proxy_sql = q_proxy.compile(dialect=default.DefaultDialect()) + direct_sql = q_direct.compile(dialect=default.DefaultDialect()) eq_(str(proxy_sql), str(direct_sql)) - eq_(q_proxy.all(), q_direct.all()) + + def _statement_equivalent(self, session, q_proxy, q_direct): + self._equivalent(q_proxy, q_direct) + eq_(session.scalars(q_proxy).all(), session.scalars(q_direct).all()) def test_no_straight_expr(self): User = self.classes.User @@ -1871,12 +1877,12 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): q2 = self.session.query(User).filter( User.user_keywords.any(UserKeyword.value == "singular8") ) - self._equivalent(q1, q2) + self._query_equivalent(q1, q2) def test_filter_any_kwarg_ul_nul(self): UserKeyword, User = self.classes.UserKeyword, self.classes.User - self._equivalent( + self._query_equivalent( self.session.query(User).filter( User.keywords.any(keyword="jumped") ), @@ -1890,7 +1896,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): def test_filter_has_kwarg_nul_nul(self): UserKeyword, Keyword = self.classes.UserKeyword, self.classes.Keyword - self._equivalent( + self._query_equivalent( self.session.query(Keyword).filter(Keyword.user.has(name="user2")), self.session.query(Keyword).filter( Keyword.user_keyword.has(UserKeyword.user.has(name="user2")) @@ -1900,7 +1906,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): def test_filter_has_kwarg_nul_ul(self): User, Singular = self.classes.User, self.classes.Singular - self._equivalent( + self._query_equivalent( self.session.query(User).filter( User.singular_keywords.any(keyword="jumped") ), @@ -1916,7 +1922,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): self.classes.Keyword, ) - self._equivalent( + self._query_equivalent( self.session.query(User).filter( User.keywords.any(Keyword.keyword == "jumped") ), @@ -1934,7 +1940,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): self.classes.Keyword, ) - self._equivalent( + self._query_equivalent( self.session.query(Keyword).filter( Keyword.user.has(User.name == "user2") ), @@ -1952,7 +1958,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): self.classes.Singular, ) - self._equivalent( + self._query_equivalent( self.session.query(User).filter( User.singular_keywords.any(Keyword.keyword == "jumped") ), @@ -1966,7 +1972,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): def test_filter_contains_ul_nul(self): User = self.classes.User - self._equivalent( + self._query_equivalent( self.session.query(User).filter(User.keywords.contains(self.kw)), self.session.query(User).filter( User.user_keywords.any(keyword=self.kw) @@ -1979,7 +1985,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): with expect_warnings( "Got None for value of column keywords.singular_id;" ): - self._equivalent( + self._query_equivalent( self.session.query(User).filter( User.singular_keywords.contains(self.kw) ), @@ -1991,7 +1997,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): def test_filter_eq_nul_nul(self): Keyword = self.classes.Keyword - self._equivalent( + self._query_equivalent( self.session.query(Keyword).filter(Keyword.user == self.u), self.session.query(Keyword).filter( Keyword.user_keyword.has(user=self.u) @@ -2002,7 +2008,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): Keyword = self.classes.Keyword UserKeyword = self.classes.UserKeyword - self._equivalent( + self._query_equivalent( self.session.query(Keyword).filter(Keyword.user != self.u), self.session.query(Keyword).filter( Keyword.user_keyword.has(UserKeyword.user != self.u) @@ -2012,7 +2018,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): def test_filter_eq_null_nul_nul(self): UserKeyword, Keyword = self.classes.UserKeyword, self.classes.Keyword - self._equivalent( + self._query_equivalent( self.session.query(Keyword).filter(Keyword.user == None), # noqa self.session.query(Keyword).filter( or_( @@ -2025,7 +2031,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): def test_filter_ne_null_nul_nul(self): UserKeyword, Keyword = self.classes.UserKeyword, self.classes.Keyword - self._equivalent( + self._query_equivalent( self.session.query(Keyword).filter(Keyword.user != None), # noqa self.session.query(Keyword).filter( Keyword.user_keyword.has(UserKeyword.user != None) @@ -2036,7 +2042,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): UserKeyword = self.classes.UserKeyword User = self.classes.User - self._equivalent( + self._query_equivalent( self.session.query(UserKeyword).filter( UserKeyword.singular == None ), # noqa @@ -2052,7 +2058,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): User = self.classes.User Singular = self.classes.Singular - self._equivalent( + self._query_equivalent( self.session.query(User).filter( User.singular_value == None ), # noqa @@ -2070,7 +2076,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): Singular = self.classes.Singular s4 = self.session.query(Singular).filter_by(value="singular4").one() - self._equivalent( + self._query_equivalent( self.session.query(UserKeyword).filter(UserKeyword.singular != s4), self.session.query(UserKeyword).filter( UserKeyword.user.has(User.singular != s4) @@ -2081,7 +2087,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): User = self.classes.User Singular = self.classes.Singular - self._equivalent( + self._query_equivalent( self.session.query(User).filter( User.singular_value != "singular4" ), @@ -2094,7 +2100,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): User = self.classes.User Singular = self.classes.Singular - self._equivalent( + self._query_equivalent( self.session.query(User).filter( User.singular_value == "singular4" ), @@ -2107,7 +2113,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): User = self.classes.User Singular = self.classes.Singular - self._equivalent( + self._query_equivalent( self.session.query(User).filter( User.singular_value != None ), # noqa @@ -2122,7 +2128,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): User = self.classes.User self.classes.Singular - self._equivalent( + self._query_equivalent( self.session.query(User).filter(User.singular_value.has()), self.session.query(User).filter(User.singular.has()), ) @@ -2133,7 +2139,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): User = self.classes.User self.classes.Singular - self._equivalent( + self._query_equivalent( self.session.query(User).filter(~User.singular_value.has()), self.session.query(User).filter(~User.singular.has()), ) @@ -2176,7 +2182,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): ) ) ) - self._equivalent(q1, q2) + self._query_equivalent(q1, q2) def test_filter_has_chained_has_to_any(self): User = self.classes.User @@ -2205,7 +2211,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): Singular.keywords.any(Keyword.keyword == "brown") ) ) - self._equivalent(q1, q2) + self._query_equivalent(q1, q2) def test_filter_has_scalar_raises(self): User = self.classes.User @@ -2241,7 +2247,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): ) ) - self._equivalent(q1, q2) + self._query_equivalent(q1, q2) def test_filter_contains_chained_any_to_has(self): User = self.classes.User @@ -2270,7 +2276,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): UserKeyword.keyword.has(Keyword.keyword == "brown") ) ) - self._equivalent(q1, q2) + self._query_equivalent(q1, q2) def test_filter_contains_chained_any_to_has_to_eq(self): User = self.classes.User @@ -2304,7 +2310,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): UserKeyword.user.has(User.singular == singular) ) ) - self._equivalent(q1, q2) + self._query_equivalent(q1, q2) def test_has_criterion_nul(self): # but we don't allow that with any criterion... @@ -2348,7 +2354,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): User = self.classes.User Singular = self.classes.Singular - self._equivalent( + self._query_equivalent( self.session.query(User).filter(User.singular_value.like("foo")), self.session.query(User).filter( User.singular.has(Singular.value.like("foo")) @@ -2359,7 +2365,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): User = self.classes.User Singular = self.classes.Singular - self._equivalent( + self._query_equivalent( self.session.query(User).filter( User.singular_value.contains("foo") ), @@ -2372,7 +2378,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): User = self.classes.User Singular = self.classes.Singular - self._equivalent( + self._query_equivalent( self.session.query(User).filter(User.singular_value == "foo"), self.session.query(User).filter( User.singular.has(Singular.value == "foo") @@ -2383,7 +2389,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): User = self.classes.User Singular = self.classes.Singular - self._equivalent( + self._query_equivalent( self.session.query(User).filter(User.singular_value != "foo"), self.session.query(User).filter( User.singular.has(Singular.value != "foo") @@ -2394,7 +2400,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): User = self.classes.User Singular = self.classes.Singular - self._equivalent( + self._query_equivalent( self.session.query(User).filter(User.singular_value == None), self.session.query(User).filter( or_( @@ -2451,6 +2457,48 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): "userkeywords.keyword_id", ) + @testing.variation("use_aliased", [True, False]) + def test_aliased_class_one(self, use_aliased): + """test #11622""" + User = self.classes.User + UserKeyword = self.classes.UserKeyword + + if use_aliased: + second_user = aliased( + User, + select(User).where(User.name == "second").cte("second_user"), + ) + else: + second_user = User + + s1 = select(second_user).where(second_user.keywords.any()) + s2 = select(second_user).where( + second_user.user_keywords.any(UserKeyword.keyword.has()) + ) + + self._statement_equivalent(fixture_session(), s1, s2) + + if use_aliased: + self.assert_compile( + s1, + "WITH second_user AS (SELECT users.id AS id, users.name AS " + "name, users.singular_id AS singular_id FROM users " + "WHERE users.name = :name_1) SELECT second_user.id, " + "second_user.name, second_user.singular_id FROM second_user " + "WHERE EXISTS (SELECT 1 FROM userkeywords " + "WHERE second_user.id = userkeywords.user_id AND " + "(EXISTS (SELECT 1 FROM keywords WHERE keywords.id = " + "userkeywords.keyword_id)))", + ) + else: + self.assert_compile( + s1, + "SELECT users.id, users.name, users.singular_id FROM users " + "WHERE EXISTS (SELECT 1 FROM userkeywords WHERE users.id = " + "userkeywords.user_id AND (EXISTS (SELECT 1 FROM keywords " + "WHERE keywords.id = userkeywords.keyword_id)))", + ) + class DictOfTupleUpdateTest(fixtures.MappedTest): run_create_tables = None