From: Mike Bayer Date: Sun, 5 Jul 2020 17:44:58 +0000 (-0400) Subject: Ensure synchronize_session works with lambda statements X-Git-Tag: rel_1_4_0b1~240 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=c6c9d5f925e4418c10c93c47fef53200dca11f00;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Ensure synchronize_session works with lambda statements A few places have logic that assumes the top-level statement is the actual UPDATE or DELETE which is not the case with a lambda. Ensure the correct object is used. This fixes issues specific to both "fetch" strategy as well as "evaluate" strategy. Fixes: #5442 Change-Id: Ic9cc01c696c3c338d5bc79688507e6717c4c169b --- diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index cbe7bde332..45ac2442a3 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -1888,8 +1888,18 @@ class BulkUDCompileState(CompileState): from_=err, ) - if statement.__visit_name__ == "update": - resolved_values = cls._get_resolved_values(mapper, statement) + if statement.__visit_name__ == "lambda_element": + # ._resolved is called on every LambdaElement in order to + # generate the cache key, so this access does not add + # additional expense + effective_statement = statement._resolved + else: + effective_statement = statement + + if effective_statement.__visit_name__ == "update": + resolved_values = cls._get_resolved_values( + mapper, effective_statement + ) value_evaluators = {} resolved_keys_as_propnames = cls._resolved_keys_as_propnames( mapper, resolved_values @@ -2012,10 +2022,20 @@ class BulkUDCompileState(CompileState): value_evaluators = _EMPTY_DICT - if statement.__visit_name__ == "update": + if statement.__visit_name__ == "lambda_element": + # ._resolved is called on every LambdaElement in order to + # generate the cache key, so this access does not add + # additional expense + effective_statement = statement._resolved + else: + effective_statement = statement + + if effective_statement.__visit_name__ == "update": target_cls = mapper.class_ evaluator_compiler = evaluator.EvaluatorCompiler(target_cls) - resolved_values = cls._get_resolved_values(mapper, statement) + resolved_values = cls._get_resolved_values( + mapper, effective_statement + ) resolved_keys_as_propnames = cls._resolved_keys_as_propnames( mapper, resolved_values ) @@ -2073,8 +2093,12 @@ class BulkORMUpdate(UpdateDMLState, BulkUDCompileState): elif statement._values: new_stmt._values = self._resolved_values + # if we are against a lambda statement we might not be the + # topmost object that received per-execute annotations + top_level_stmt = compiler.statement if ( - statement._annotations.get("synchronize_session", None) == "fetch" + top_level_stmt._annotations.get("synchronize_session", None) + == "fetch" and compiler.dialect.full_returning ): new_stmt = new_stmt.returning(*mapper.primary_key) @@ -2187,9 +2211,10 @@ class BulkORMDelete(DeleteDMLState, BulkUDCompileState): "parentmapper", None ) + top_level_stmt = compiler.statement if ( mapper - and statement._annotations.get("synchronize_session", None) + and top_level_stmt._annotations.get("synchronize_session", None) == "fetch" and compiler.dialect.full_returning ): diff --git a/test/orm/test_update_delete.py b/test/orm/test_update_delete.py index b0d7183154..8ec64c586d 100644 --- a/test/orm/test_update_delete.py +++ b/test/orm/test_update_delete.py @@ -1,11 +1,13 @@ from sqlalchemy import Boolean from sqlalchemy import case from sqlalchemy import column +from sqlalchemy import delete from sqlalchemy import event from sqlalchemy import exc from sqlalchemy import ForeignKey from sqlalchemy import func from sqlalchemy import Integer +from sqlalchemy import lambda_stmt from sqlalchemy import or_ from sqlalchemy import select from sqlalchemy import String @@ -485,6 +487,62 @@ class UpdateDeleteTest(fixtures.MappedTest): list(zip([15, 27, 19, 27])), ) + def test_update_future_lambda(self): + User, users = self.classes.User, self.tables.users + + sess = Session() + + john, jack, jill, jane = ( + sess.execute(future_select(User).order_by(User.id)).scalars().all() + ) + + sess.execute( + lambda_stmt( + lambda: update(User) + .where(User.age > 29) + .values({"age": User.age - 10}) + .execution_options(synchronize_session="evaluate") + ), + ) + + eq_([john.age, jack.age, jill.age, jane.age], [25, 37, 29, 27]) + eq_( + sess.execute(future_select(User.age).order_by(User.id)).all(), + list(zip([25, 37, 29, 27])), + ) + + sess.execute( + lambda_stmt( + lambda: update(User) + .where(User.age > 29) + .values({User.age: User.age - 10}) + .execution_options(synchronize_session="evaluate") + ) + ) + eq_([john.age, jack.age, jill.age, jane.age], [25, 27, 29, 27]) + eq_( + sess.query(User.age).order_by(User.id).all(), + list(zip([25, 27, 29, 27])), + ) + + sess.query(User).filter(User.age > 27).update( + {users.c.age_int: User.age - 10}, synchronize_session="evaluate" + ) + eq_([john.age, jack.age, jill.age, jane.age], [25, 27, 19, 27]) + eq_( + sess.query(User.age).order_by(User.id).all(), + list(zip([25, 27, 19, 27])), + ) + + sess.query(User).filter(User.age == 25).update( + {User.age: User.age - 10}, synchronize_session="fetch" + ) + eq_([john.age, jack.age, jill.age, jane.age], [15, 27, 19, 27]) + eq_( + sess.query(User.age).order_by(User.id).all(), + list(zip([15, 27, 19, 27])), + ) + def test_update_against_table_col(self): User, users = self.classes.User, self.tables.users @@ -565,6 +623,52 @@ class UpdateDeleteTest(fixtures.MappedTest): ), ) + def test_update_fetch_returning_lambda(self): + User = self.classes.User + + sess = Session() + + john, jack, jill, jane = ( + sess.execute(future_select(User).order_by(User.id)).scalars().all() + ) + + with self.sql_execution_asserter() as asserter: + stmt = lambda_stmt( + lambda: update(User) + .where(User.age > 29) + .values({"age": User.age - 10}) + ) + sess.execute( + stmt, execution_options={"synchronize_session": "fetch"} + ) + + # these are simple values, these are now evaluated even with + # the "fetch" strategy, new in 1.4, so there is no expiry + eq_([john.age, jack.age, jill.age, jane.age], [25, 37, 29, 27]) + + if testing.db.dialect.full_returning: + asserter.assert_( + CompiledSQL( + "UPDATE users SET age_int=(users.age_int - %(age_int_1)s) " + "WHERE users.age_int > %(age_int_2)s RETURNING users.id", + [{"age_int_1": 10, "age_int_2": 29}], + dialect="postgresql", + ), + ) + else: + asserter.assert_( + CompiledSQL( + "SELECT users.id FROM users " + "WHERE users.age_int > :age_int_1", + [{"age_int_1": 29}], + ), + CompiledSQL( + "UPDATE users SET age_int=(users.age_int - :age_int_1) " + "WHERE users.age_int > :age_int_2", + [{"age_int_1": 10, "age_int_2": 29}], + ), + ) + def test_delete_fetch_returning(self): User = self.classes.User @@ -607,6 +711,51 @@ class UpdateDeleteTest(fixtures.MappedTest): in_(jill, sess) not_in_(jane, sess) + def test_delete_fetch_returning_lambda(self): + User = self.classes.User + + sess = Session() + + john, jack, jill, jane = ( + sess.execute(future_select(User).order_by(User.id)).scalars().all() + ) + + in_(john, sess) + in_(jack, sess) + + with self.sql_execution_asserter() as asserter: + stmt = lambda_stmt(lambda: delete(User).where(User.age > 29)) + sess.execute( + stmt, execution_options={"synchronize_session": "fetch"} + ) + + if testing.db.dialect.full_returning: + asserter.assert_( + CompiledSQL( + "DELETE FROM users WHERE users.age_int > %(age_int_1)s " + "RETURNING users.id", + [{"age_int_1": 29}], + dialect="postgresql", + ), + ) + else: + asserter.assert_( + CompiledSQL( + "SELECT users.id FROM users " + "WHERE users.age_int > :age_int_1", + [{"age_int_1": 29}], + ), + CompiledSQL( + "DELETE FROM users WHERE users.age_int > :age_int_1", + [{"age_int_1": 29}], + ), + ) + + in_(john, sess) + not_in_(jack, sess) + in_(jill, sess) + not_in_(jane, sess) + def test_update_with_filter_statement(self): """test for [ticket:4556] """