From a34a4af8a80f4edd12b022753b69065025818e20 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Wed, 21 Jul 2021 11:18:01 -0400 Subject: [PATCH] implement cache key for return_defaults token Fixed critical caching issue where the ORM's persistence feature using INSERT..RETURNING would cache an incorrect query when mixing the "bulk save" and standard "flush" forms of INSERT. Fixes: #6793 Change-Id: Ifeb61c1226d3fa6d5e1c2e29b6f5ff77a27d6a2d --- doc/build/changelog/unreleased_14/6793.rst | 7 +++ lib/sqlalchemy/sql/crud.py | 6 +-- lib/sqlalchemy/sql/dml.py | 21 +++++++-- test/orm/test_bulk.py | 54 ++++++++++++++++++++++ test/sql/test_compare.py | 6 +++ 5 files changed, 88 insertions(+), 6 deletions(-) create mode 100644 doc/build/changelog/unreleased_14/6793.rst diff --git a/doc/build/changelog/unreleased_14/6793.rst b/doc/build/changelog/unreleased_14/6793.rst new file mode 100644 index 0000000000..059bdac65e --- /dev/null +++ b/doc/build/changelog/unreleased_14/6793.rst @@ -0,0 +1,7 @@ +.. change:: + :tags: bug, orm, regression + :tickets: 6793 + + Fixed critical caching issue where the ORM's persistence feature using + INSERT..RETURNING would cache an incorrect query when mixing the "bulk + save" and standard "flush" forms of INSERT. diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index 74f5a1d05b..b8f8cb4cef 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -760,7 +760,7 @@ def _append_param_update( compiler.postfetch.append(c) elif ( implicit_return_defaults - and stmt._return_defaults is not True + and (stmt._return_defaults_columns or not stmt._return_defaults) and c in implicit_return_defaults ): compiler.returning.append(c) @@ -1024,10 +1024,10 @@ def _get_returning_modifiers(compiler, stmt, compile_state): implicit_return_defaults = False # pragma: no cover if implicit_return_defaults: - if stmt._return_defaults is True: + if not stmt._return_defaults_columns: implicit_return_defaults = set(stmt.table.c) else: - implicit_return_defaults = set(stmt._return_defaults) + implicit_return_defaults = set(stmt._return_defaults_columns) postfetch_lastrowid = need_pks and compiler.dialect.postfetch_lastrowid diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 048475040f..158cb40f27 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -214,7 +214,8 @@ class UpdateBase( _hints = util.immutabledict() named_with_column = False - _return_defaults = None + _return_defaults = False + _return_defaults_columns = None _returning = () is_dml = True @@ -794,7 +795,8 @@ class ValuesBase(UpdateBase): :attr:`_engine.CursorResult.inserted_primary_key_rows` """ - self._return_defaults = cols or True + self._return_defaults = True + self._return_defaults_columns = cols class Insert(ValuesBase): @@ -825,6 +827,11 @@ class Insert(ValuesBase): ("_post_values_clause", InternalTraversal.dp_clauseelement), ("_returning", InternalTraversal.dp_clauseelement_list), ("_hints", InternalTraversal.dp_table_hint_list), + ("_return_defaults", InternalTraversal.dp_boolean), + ( + "_return_defaults_columns", + InternalTraversal.dp_clauseelement_list, + ), ] + HasPrefixes._has_prefixes_traverse_internals + DialectKWArgs._dialect_kwargs_traverse_internals @@ -929,7 +936,10 @@ class Insert(ValuesBase): if dialect_kw: self._validate_dialect_kwargs_deprecated(dialect_kw) - self._return_defaults = return_defaults + if return_defaults: + self._return_defaults = True + if not isinstance(return_defaults, bool): + self._return_defaults_columns = return_defaults @_generative def inline(self): @@ -1116,6 +1126,11 @@ class Update(DMLWhereBase, ValuesBase): ("_values", InternalTraversal.dp_dml_values), ("_returning", InternalTraversal.dp_clauseelement_list), ("_hints", InternalTraversal.dp_table_hint_list), + ("_return_defaults", InternalTraversal.dp_boolean), + ( + "_return_defaults_columns", + InternalTraversal.dp_clauseelement_list, + ), ] + HasPrefixes._has_prefixes_traverse_internals + DialectKWArgs._dialect_kwargs_traverse_internals diff --git a/test/orm/test_bulk.py b/test/orm/test_bulk.py index 32ee807083..7e47507c00 100644 --- a/test/orm/test_bulk.py +++ b/test/orm/test_bulk.py @@ -866,3 +866,57 @@ class BulkInheritanceTest(BulkTest, fixtures.MappedTest): ], ), ) + + +class BulkIssue6793Test(BulkTest, fixtures.DeclarativeMappedTest): + @classmethod + def setup_classes(cls): + Base = cls.DeclarativeBasic + + class User(Base): + __tablename__ = "users" + id = Column(Integer, primary_key=True) + name = Column(String(255), nullable=False) + + def test_issue_6793(self): + User = self.classes.User + + session = fixture_session() + + with self.sql_execution_asserter() as asserter: + + session.bulk_save_objects([User(name="A"), User(name="B")]) + + session.add(User(name="C")) + session.add(User(name="D")) + session.flush() + + asserter.assert_( + Conditional( + testing.db.dialect.insert_executemany_returning, + [ + CompiledSQL( + "INSERT INTO users (name) VALUES (:name)", + [{"name": "A"}, {"name": "B"}], + ), + CompiledSQL( + "INSERT INTO users (name) VALUES (:name)", + [{"name": "C"}, {"name": "D"}], + ), + ], + [ + CompiledSQL( + "INSERT INTO users (name) VALUES (:name)", + [{"name": "A"}, {"name": "B"}], + ), + CompiledSQL( + "INSERT INTO users (name) VALUES (:name)", + [{"name": "C"}], + ), + CompiledSQL( + "INSERT INTO users (name) VALUES (:name)", + [{"name": "D"}], + ), + ], + ) + ) diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py index 188d9337ee..371e68a8ad 100644 --- a/test/sql/test_compare.py +++ b/test/sql/test_compare.py @@ -534,6 +534,9 @@ class CoreFixtures(object): ), lambda: ( table_a.insert(), + table_a.insert().return_defaults(), + table_a.insert().return_defaults(table_a.c.a), + table_a.insert().return_defaults(table_a.c.b), table_a.insert().values({})._annotate({"nocache": True}), table_b.insert(), table_b.insert().with_dialect_options(sqlite_foo="some value"), @@ -570,6 +573,9 @@ class CoreFixtures(object): ), lambda: ( table_b.update(), + table_b.update().return_defaults(), + table_b.update().return_defaults(table_b.c.a), + table_b.update().return_defaults(table_b.c.b), table_b.update().where(table_b.c.a == 5), table_b.update().where(table_b.c.b == 5), table_b.update() -- 2.47.2