]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Ensure synchronize_session works with lambda statements
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 5 Jul 2020 17:44:58 +0000 (13:44 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 5 Jul 2020 18:45:35 +0000 (14:45 -0400)
A few places have logic that assumes the top-level statement
is the actual UPDATE or DELETE which is not the case with a
lambda.  Ensure the correct object is used.  This
fixes issues specific to both "fetch" strategy
as well as "evaluate" strategy.

Fixes: #5442
Change-Id: Ic9cc01c696c3c338d5bc79688507e6717c4c169b

lib/sqlalchemy/orm/persistence.py
test/orm/test_update_delete.py

index cbe7bde33230e1ea07d32ecbeaed6b09b9ca90d9..45ac2442a3bc7c6505c748de5890bfbd1ec063c0 100644 (file)
@@ -1888,8 +1888,18 @@ class BulkUDCompileState(CompileState):
                 from_=err,
             )
 
-        if statement.__visit_name__ == "update":
-            resolved_values = cls._get_resolved_values(mapper, statement)
+        if statement.__visit_name__ == "lambda_element":
+            # ._resolved is called on every LambdaElement in order to
+            # generate the cache key, so this access does not add
+            # additional expense
+            effective_statement = statement._resolved
+        else:
+            effective_statement = statement
+
+        if effective_statement.__visit_name__ == "update":
+            resolved_values = cls._get_resolved_values(
+                mapper, effective_statement
+            )
             value_evaluators = {}
             resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
                 mapper, resolved_values
@@ -2012,10 +2022,20 @@ class BulkUDCompileState(CompileState):
 
         value_evaluators = _EMPTY_DICT
 
-        if statement.__visit_name__ == "update":
+        if statement.__visit_name__ == "lambda_element":
+            # ._resolved is called on every LambdaElement in order to
+            # generate the cache key, so this access does not add
+            # additional expense
+            effective_statement = statement._resolved
+        else:
+            effective_statement = statement
+
+        if effective_statement.__visit_name__ == "update":
             target_cls = mapper.class_
             evaluator_compiler = evaluator.EvaluatorCompiler(target_cls)
-            resolved_values = cls._get_resolved_values(mapper, statement)
+            resolved_values = cls._get_resolved_values(
+                mapper, effective_statement
+            )
             resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
                 mapper, resolved_values
             )
@@ -2073,8 +2093,12 @@ class BulkORMUpdate(UpdateDMLState, BulkUDCompileState):
         elif statement._values:
             new_stmt._values = self._resolved_values
 
+        # if we are against a lambda statement we might not be the
+        # topmost object that received per-execute annotations
+        top_level_stmt = compiler.statement
         if (
-            statement._annotations.get("synchronize_session", None) == "fetch"
+            top_level_stmt._annotations.get("synchronize_session", None)
+            == "fetch"
             and compiler.dialect.full_returning
         ):
             new_stmt = new_stmt.returning(*mapper.primary_key)
@@ -2187,9 +2211,10 @@ class BulkORMDelete(DeleteDMLState, BulkUDCompileState):
             "parentmapper", None
         )
 
+        top_level_stmt = compiler.statement
         if (
             mapper
-            and statement._annotations.get("synchronize_session", None)
+            and top_level_stmt._annotations.get("synchronize_session", None)
             == "fetch"
             and compiler.dialect.full_returning
         ):
index b0d7183154bdf6b7b80c2957be6e806f2cc1457a..8ec64c586dd0ec00e8ac184b920adc959813c173 100644 (file)
@@ -1,11 +1,13 @@
 from sqlalchemy import Boolean
 from sqlalchemy import case
 from sqlalchemy import column
+from sqlalchemy import delete
 from sqlalchemy import event
 from sqlalchemy import exc
 from sqlalchemy import ForeignKey
 from sqlalchemy import func
 from sqlalchemy import Integer
+from sqlalchemy import lambda_stmt
 from sqlalchemy import or_
 from sqlalchemy import select
 from sqlalchemy import String
@@ -485,6 +487,62 @@ class UpdateDeleteTest(fixtures.MappedTest):
             list(zip([15, 27, 19, 27])),
         )
 
+    def test_update_future_lambda(self):
+        User, users = self.classes.User, self.tables.users
+
+        sess = Session()
+
+        john, jack, jill, jane = (
+            sess.execute(future_select(User).order_by(User.id)).scalars().all()
+        )
+
+        sess.execute(
+            lambda_stmt(
+                lambda: update(User)
+                .where(User.age > 29)
+                .values({"age": User.age - 10})
+                .execution_options(synchronize_session="evaluate")
+            ),
+        )
+
+        eq_([john.age, jack.age, jill.age, jane.age], [25, 37, 29, 27])
+        eq_(
+            sess.execute(future_select(User.age).order_by(User.id)).all(),
+            list(zip([25, 37, 29, 27])),
+        )
+
+        sess.execute(
+            lambda_stmt(
+                lambda: update(User)
+                .where(User.age > 29)
+                .values({User.age: User.age - 10})
+                .execution_options(synchronize_session="evaluate")
+            )
+        )
+        eq_([john.age, jack.age, jill.age, jane.age], [25, 27, 29, 27])
+        eq_(
+            sess.query(User.age).order_by(User.id).all(),
+            list(zip([25, 27, 29, 27])),
+        )
+
+        sess.query(User).filter(User.age > 27).update(
+            {users.c.age_int: User.age - 10}, synchronize_session="evaluate"
+        )
+        eq_([john.age, jack.age, jill.age, jane.age], [25, 27, 19, 27])
+        eq_(
+            sess.query(User.age).order_by(User.id).all(),
+            list(zip([25, 27, 19, 27])),
+        )
+
+        sess.query(User).filter(User.age == 25).update(
+            {User.age: User.age - 10}, synchronize_session="fetch"
+        )
+        eq_([john.age, jack.age, jill.age, jane.age], [15, 27, 19, 27])
+        eq_(
+            sess.query(User.age).order_by(User.id).all(),
+            list(zip([15, 27, 19, 27])),
+        )
+
     def test_update_against_table_col(self):
         User, users = self.classes.User, self.tables.users
 
@@ -565,6 +623,52 @@ class UpdateDeleteTest(fixtures.MappedTest):
                 ),
             )
 
+    def test_update_fetch_returning_lambda(self):
+        User = self.classes.User
+
+        sess = Session()
+
+        john, jack, jill, jane = (
+            sess.execute(future_select(User).order_by(User.id)).scalars().all()
+        )
+
+        with self.sql_execution_asserter() as asserter:
+            stmt = lambda_stmt(
+                lambda: update(User)
+                .where(User.age > 29)
+                .values({"age": User.age - 10})
+            )
+            sess.execute(
+                stmt, execution_options={"synchronize_session": "fetch"}
+            )
+
+            # these are simple values, these are now evaluated even with
+            # the "fetch" strategy, new in 1.4, so there is no expiry
+            eq_([john.age, jack.age, jill.age, jane.age], [25, 37, 29, 27])
+
+        if testing.db.dialect.full_returning:
+            asserter.assert_(
+                CompiledSQL(
+                    "UPDATE users SET age_int=(users.age_int - %(age_int_1)s) "
+                    "WHERE users.age_int > %(age_int_2)s RETURNING users.id",
+                    [{"age_int_1": 10, "age_int_2": 29}],
+                    dialect="postgresql",
+                ),
+            )
+        else:
+            asserter.assert_(
+                CompiledSQL(
+                    "SELECT users.id FROM users "
+                    "WHERE users.age_int > :age_int_1",
+                    [{"age_int_1": 29}],
+                ),
+                CompiledSQL(
+                    "UPDATE users SET age_int=(users.age_int - :age_int_1) "
+                    "WHERE users.age_int > :age_int_2",
+                    [{"age_int_1": 10, "age_int_2": 29}],
+                ),
+            )
+
     def test_delete_fetch_returning(self):
         User = self.classes.User
 
@@ -607,6 +711,51 @@ class UpdateDeleteTest(fixtures.MappedTest):
         in_(jill, sess)
         not_in_(jane, sess)
 
+    def test_delete_fetch_returning_lambda(self):
+        User = self.classes.User
+
+        sess = Session()
+
+        john, jack, jill, jane = (
+            sess.execute(future_select(User).order_by(User.id)).scalars().all()
+        )
+
+        in_(john, sess)
+        in_(jack, sess)
+
+        with self.sql_execution_asserter() as asserter:
+            stmt = lambda_stmt(lambda: delete(User).where(User.age > 29))
+            sess.execute(
+                stmt, execution_options={"synchronize_session": "fetch"}
+            )
+
+        if testing.db.dialect.full_returning:
+            asserter.assert_(
+                CompiledSQL(
+                    "DELETE FROM users WHERE users.age_int > %(age_int_1)s "
+                    "RETURNING users.id",
+                    [{"age_int_1": 29}],
+                    dialect="postgresql",
+                ),
+            )
+        else:
+            asserter.assert_(
+                CompiledSQL(
+                    "SELECT users.id FROM users "
+                    "WHERE users.age_int > :age_int_1",
+                    [{"age_int_1": 29}],
+                ),
+                CompiledSQL(
+                    "DELETE FROM users WHERE users.age_int > :age_int_1",
+                    [{"age_int_1": 29}],
+                ),
+            )
+
+        in_(john, sess)
+        not_in_(jack, sess)
+        in_(jill, sess)
+        not_in_(jane, sess)
+
     def test_update_with_filter_statement(self):
         """test for [ticket:4556] """