From: Mike Bayer Date: Wed, 24 Mar 2021 16:15:53 +0000 (-0400) Subject: Support __visit_name__ on PropComparator to work in cloning X-Git-Tag: rel_1_4_3~10^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=0bf04029dcdd912a1e5a4cdac1cccf14b60b3ec9;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Support __visit_name__ on PropComparator to work in cloning Repaired support so that the :meth:`_sql.Select.params` method can work correctly with a :class:`_sql.Select` object that includes joins across ORM relationship structures, which is a new feature in 1.4. Fixes: #6124 Change-Id: Ia92fc33c3acbe66910e9e3bf00af9100de19b2b8 --- diff --git a/doc/build/changelog/unreleased_14/6124.rst b/doc/build/changelog/unreleased_14/6124.rst new file mode 100644 index 0000000000..ac08eaf9fa --- /dev/null +++ b/doc/build/changelog/unreleased_14/6124.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: bug, orm + :tickets: 6124 + + Repaired support so that the :meth:`_sql.Select.params` method can work + correctly with a :class:`_sql.Select` object that includes joins across ORM + relationship structures, which is a new feature in 1.4. + diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 2e48695f51..610ee2726b 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -85,6 +85,10 @@ class QueryableAttribute( is_attribute = True + # PropComparator has a __visit_name__ to participate within + # traversals. Disambiguate the attribute vs. a comparator. + __visit_name__ = "orm_instrumented_attribute" + def __init__( self, class_, diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index a660d7e1a5..e2cc369997 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -393,6 +393,8 @@ class PropComparator(operators.ColumnOperators): __slots__ = "prop", "property", "_parententity", "_adapt_to_entity" + __visit_name__ = "orm_prop_comparator" + def __init__( self, prop, # type: MapperProperty diff --git a/test/orm/test_core_compilation.py b/test/orm/test_core_compilation.py index 1a58356e35..a53b15bcbf 100644 --- a/test/orm/test_core_compilation.py +++ b/test/orm/test_core_compilation.py @@ -1,3 +1,4 @@ +from sqlalchemy import bindparam from sqlalchemy import exc from sqlalchemy import func from sqlalchemy import insert @@ -24,6 +25,7 @@ from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ from sqlalchemy.testing.fixtures import fixture_session +from sqlalchemy.testing.util import resolve_lambda from .inheritance import _poly_fixtures from .test_query import QueryTest @@ -257,6 +259,52 @@ class JoinTest(QueryTest, AssertsCompiledSQL): }, ) + @testing.combinations( + ( + lambda User: select(User).where(User.id == bindparam("foo")), + "SELECT users.id, users.name FROM users WHERE users.id = :foo", + {"foo": "bar"}, + {"foo": "bar"}, + ), + ( + lambda User, Address: select(User) + .join_from(User, Address) + .where(User.id == bindparam("foo")), + "SELECT users.id, users.name FROM users JOIN addresses " + "ON users.id = addresses.user_id WHERE users.id = :foo", + {"foo": "bar"}, + {"foo": "bar"}, + ), + ( + lambda User, Address: select(User) + .join_from(User, Address, User.addresses) + .where(User.id == bindparam("foo")), + "SELECT users.id, users.name FROM users JOIN addresses " + "ON users.id = addresses.user_id WHERE users.id = :foo", + {"foo": "bar"}, + {"foo": "bar"}, + ), + ( + lambda User, Address: select(User) + .join(User.addresses) + .where(User.id == bindparam("foo")), + "SELECT users.id, users.name FROM users JOIN addresses " + "ON users.id = addresses.user_id WHERE users.id = :foo", + {"foo": "bar"}, + {"foo": "bar"}, + ), + ) + def test_params_with_join( + self, test_case, expected, bindparams, expected_params + ): + User, Address = self.classes("User", "Address") + + stmt = resolve_lambda(test_case, **locals()) + + stmt = stmt.params(**bindparams) + + self.assert_compile(stmt, expected, checkparams=expected_params) + class LoadersInSubqueriesTest(QueryTest, AssertsCompiledSQL): """The Query object calls eanble_eagerloads(False) when you call diff --git a/test/orm/test_mapper.py b/test/orm/test_mapper.py index e4c89c7e8a..3c8f83f919 100644 --- a/test/orm/test_mapper.py +++ b/test/orm/test_mapper.py @@ -86,7 +86,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): assert um.attrs.addresses.primaryjoin.compare( users.c.id == addresses.c.user_id ) - assert um.attrs.addresses.order_by[0].compare(Address.id) + assert um.attrs.addresses.order_by[0].compare(Address.id.expression) configure_mappers()