]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Support __visit_name__ on PropComparator to work in cloning
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 24 Mar 2021 16:15:53 +0000 (12:15 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 24 Mar 2021 18:14:04 +0000 (14:14 -0400)
Repaired support so that the :meth:`_sql.Select.params` method can work
correctly with a :class:`_sql.Select` object that includes joins across ORM
relationship structures, which is a new feature in 1.4.

Fixes: #6124
Change-Id: Ia92fc33c3acbe66910e9e3bf00af9100de19b2b8

doc/build/changelog/unreleased_14/6124.rst [new file with mode: 0644]
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/interfaces.py
test/orm/test_core_compilation.py
test/orm/test_mapper.py

diff --git a/doc/build/changelog/unreleased_14/6124.rst b/doc/build/changelog/unreleased_14/6124.rst
new file mode 100644 (file)
index 0000000..ac08eaf
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 6124
+
+    Repaired support so that the :meth:`_sql.Select.params` method can work
+    correctly with a :class:`_sql.Select` object that includes joins across ORM
+    relationship structures, which is a new feature in 1.4.
+
index 2e48695f51046878a979dbf1031ad08d68084a72..610ee2726b6a6cbab558e90f3af1b073c72f3bea 100644 (file)
@@ -85,6 +85,10 @@ class QueryableAttribute(
 
     is_attribute = True
 
+    # PropComparator has a __visit_name__ to participate within
+    # traversals.   Disambiguate the attribute vs. a comparator.
+    __visit_name__ = "orm_instrumented_attribute"
+
     def __init__(
         self,
         class_,
index a660d7e1a588f2ba3238db96cd4075ed71a68e9d..e2cc3699970d1c3607ce3a375467438058ea2070 100644 (file)
@@ -393,6 +393,8 @@ class PropComparator(operators.ColumnOperators):
 
     __slots__ = "prop", "property", "_parententity", "_adapt_to_entity"
 
+    __visit_name__ = "orm_prop_comparator"
+
     def __init__(
         self,
         prop,  # type: MapperProperty
index 1a58356e3513b4f7810dd85f31e31727dce0d722..a53b15bcbf6b6723f85484d7ffac12da7f1ab264 100644 (file)
@@ -1,3 +1,4 @@
+from sqlalchemy import bindparam
 from sqlalchemy import exc
 from sqlalchemy import func
 from sqlalchemy import insert
@@ -24,6 +25,7 @@ from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing.fixtures import fixture_session
+from sqlalchemy.testing.util import resolve_lambda
 from .inheritance import _poly_fixtures
 from .test_query import QueryTest
 
@@ -257,6 +259,52 @@ class JoinTest(QueryTest, AssertsCompiledSQL):
             },
         )
 
+    @testing.combinations(
+        (
+            lambda User: select(User).where(User.id == bindparam("foo")),
+            "SELECT users.id, users.name FROM users WHERE users.id = :foo",
+            {"foo": "bar"},
+            {"foo": "bar"},
+        ),
+        (
+            lambda User, Address: select(User)
+            .join_from(User, Address)
+            .where(User.id == bindparam("foo")),
+            "SELECT users.id, users.name FROM users JOIN addresses "
+            "ON users.id = addresses.user_id WHERE users.id = :foo",
+            {"foo": "bar"},
+            {"foo": "bar"},
+        ),
+        (
+            lambda User, Address: select(User)
+            .join_from(User, Address, User.addresses)
+            .where(User.id == bindparam("foo")),
+            "SELECT users.id, users.name FROM users JOIN addresses "
+            "ON users.id = addresses.user_id WHERE users.id = :foo",
+            {"foo": "bar"},
+            {"foo": "bar"},
+        ),
+        (
+            lambda User, Address: select(User)
+            .join(User.addresses)
+            .where(User.id == bindparam("foo")),
+            "SELECT users.id, users.name FROM users JOIN addresses "
+            "ON users.id = addresses.user_id WHERE users.id = :foo",
+            {"foo": "bar"},
+            {"foo": "bar"},
+        ),
+    )
+    def test_params_with_join(
+        self, test_case, expected, bindparams, expected_params
+    ):
+        User, Address = self.classes("User", "Address")
+
+        stmt = resolve_lambda(test_case, **locals())
+
+        stmt = stmt.params(**bindparams)
+
+        self.assert_compile(stmt, expected, checkparams=expected_params)
+
 
 class LoadersInSubqueriesTest(QueryTest, AssertsCompiledSQL):
     """The Query object calls eanble_eagerloads(False) when you call
index e4c89c7e8a9c42924e1c9abf20bdc0028550b36a..3c8f83f91996c8c5c37d9d4d42a2f0132092dec5 100644 (file)
@@ -86,7 +86,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL):
         assert um.attrs.addresses.primaryjoin.compare(
             users.c.id == addresses.c.user_id
         )
-        assert um.attrs.addresses.order_by[0].compare(Address.id)
+        assert um.attrs.addresses.order_by[0].compare(Address.id.expression)
 
         configure_mappers()