]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
implement cache key for return_defaults token
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 21 Jul 2021 15:18:01 +0000 (11:18 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 21 Jul 2021 17:57:22 +0000 (13:57 -0400)
Fixed critical caching issue where the ORM's persistence feature using
INSERT..RETURNING would cache an incorrect query when mixing the "bulk
save" and standard "flush" forms of INSERT.

Fixes: #6793
Change-Id: Ifeb61c1226d3fa6d5e1c2e29b6f5ff77a27d6a2d

doc/build/changelog/unreleased_14/6793.rst [new file with mode: 0644]
lib/sqlalchemy/sql/crud.py
lib/sqlalchemy/sql/dml.py
test/orm/test_bulk.py
test/sql/test_compare.py

diff --git a/doc/build/changelog/unreleased_14/6793.rst b/doc/build/changelog/unreleased_14/6793.rst
new file mode 100644 (file)
index 0000000..059bdac
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: bug, orm, regression
+    :tickets: 6793
+
+    Fixed critical caching issue where the ORM's persistence feature using
+    INSERT..RETURNING would cache an incorrect query when mixing the "bulk
+    save" and standard "flush" forms of INSERT.
index 74f5a1d05b3d8e86bd60b3d98e20d753367bf5e3..b8f8cb4cef7638c3a9a86850dc4849b1f491727f 100644 (file)
@@ -760,7 +760,7 @@ def _append_param_update(
             compiler.postfetch.append(c)
     elif (
         implicit_return_defaults
-        and stmt._return_defaults is not True
+        and (stmt._return_defaults_columns or not stmt._return_defaults)
         and c in implicit_return_defaults
     ):
         compiler.returning.append(c)
@@ -1024,10 +1024,10 @@ def _get_returning_modifiers(compiler, stmt, compile_state):
         implicit_return_defaults = False  # pragma: no cover
 
     if implicit_return_defaults:
-        if stmt._return_defaults is True:
+        if not stmt._return_defaults_columns:
             implicit_return_defaults = set(stmt.table.c)
         else:
-            implicit_return_defaults = set(stmt._return_defaults)
+            implicit_return_defaults = set(stmt._return_defaults_columns)
 
     postfetch_lastrowid = need_pks and compiler.dialect.postfetch_lastrowid
 
index 048475040f6361c006b1818593036babce20da55..158cb40f2773e5a14510803895a204255f470cfc 100644 (file)
@@ -214,7 +214,8 @@ class UpdateBase(
     _hints = util.immutabledict()
     named_with_column = False
 
-    _return_defaults = None
+    _return_defaults = False
+    _return_defaults_columns = None
     _returning = ()
 
     is_dml = True
@@ -794,7 +795,8 @@ class ValuesBase(UpdateBase):
             :attr:`_engine.CursorResult.inserted_primary_key_rows`
 
         """
-        self._return_defaults = cols or True
+        self._return_defaults = True
+        self._return_defaults_columns = cols
 
 
 class Insert(ValuesBase):
@@ -825,6 +827,11 @@ class Insert(ValuesBase):
             ("_post_values_clause", InternalTraversal.dp_clauseelement),
             ("_returning", InternalTraversal.dp_clauseelement_list),
             ("_hints", InternalTraversal.dp_table_hint_list),
+            ("_return_defaults", InternalTraversal.dp_boolean),
+            (
+                "_return_defaults_columns",
+                InternalTraversal.dp_clauseelement_list,
+            ),
         ]
         + HasPrefixes._has_prefixes_traverse_internals
         + DialectKWArgs._dialect_kwargs_traverse_internals
@@ -929,7 +936,10 @@ class Insert(ValuesBase):
         if dialect_kw:
             self._validate_dialect_kwargs_deprecated(dialect_kw)
 
-        self._return_defaults = return_defaults
+        if return_defaults:
+            self._return_defaults = True
+            if not isinstance(return_defaults, bool):
+                self._return_defaults_columns = return_defaults
 
     @_generative
     def inline(self):
@@ -1116,6 +1126,11 @@ class Update(DMLWhereBase, ValuesBase):
             ("_values", InternalTraversal.dp_dml_values),
             ("_returning", InternalTraversal.dp_clauseelement_list),
             ("_hints", InternalTraversal.dp_table_hint_list),
+            ("_return_defaults", InternalTraversal.dp_boolean),
+            (
+                "_return_defaults_columns",
+                InternalTraversal.dp_clauseelement_list,
+            ),
         ]
         + HasPrefixes._has_prefixes_traverse_internals
         + DialectKWArgs._dialect_kwargs_traverse_internals
index 32ee8070835505bb2480c2e24811ee12ce328f62..7e47507c002e8eb121547904f451db8ab3a298c4 100644 (file)
@@ -866,3 +866,57 @@ class BulkInheritanceTest(BulkTest, fixtures.MappedTest):
                 ],
             ),
         )
+
+
+class BulkIssue6793Test(BulkTest, fixtures.DeclarativeMappedTest):
+    @classmethod
+    def setup_classes(cls):
+        Base = cls.DeclarativeBasic
+
+        class User(Base):
+            __tablename__ = "users"
+            id = Column(Integer, primary_key=True)
+            name = Column(String(255), nullable=False)
+
+    def test_issue_6793(self):
+        User = self.classes.User
+
+        session = fixture_session()
+
+        with self.sql_execution_asserter() as asserter:
+
+            session.bulk_save_objects([User(name="A"), User(name="B")])
+
+            session.add(User(name="C"))
+            session.add(User(name="D"))
+            session.flush()
+
+        asserter.assert_(
+            Conditional(
+                testing.db.dialect.insert_executemany_returning,
+                [
+                    CompiledSQL(
+                        "INSERT INTO users (name) VALUES (:name)",
+                        [{"name": "A"}, {"name": "B"}],
+                    ),
+                    CompiledSQL(
+                        "INSERT INTO users (name) VALUES (:name)",
+                        [{"name": "C"}, {"name": "D"}],
+                    ),
+                ],
+                [
+                    CompiledSQL(
+                        "INSERT INTO users (name) VALUES (:name)",
+                        [{"name": "A"}, {"name": "B"}],
+                    ),
+                    CompiledSQL(
+                        "INSERT INTO users (name) VALUES (:name)",
+                        [{"name": "C"}],
+                    ),
+                    CompiledSQL(
+                        "INSERT INTO users (name) VALUES (:name)",
+                        [{"name": "D"}],
+                    ),
+                ],
+            )
+        )
index 188d9337ee90ba814fcafa854fdf022a73c94677..371e68a8ada883c56a57f054cbfb64724d9112d9 100644 (file)
@@ -534,6 +534,9 @@ class CoreFixtures(object):
         ),
         lambda: (
             table_a.insert(),
+            table_a.insert().return_defaults(),
+            table_a.insert().return_defaults(table_a.c.a),
+            table_a.insert().return_defaults(table_a.c.b),
             table_a.insert().values({})._annotate({"nocache": True}),
             table_b.insert(),
             table_b.insert().with_dialect_options(sqlite_foo="some value"),
@@ -570,6 +573,9 @@ class CoreFixtures(object):
         ),
         lambda: (
             table_b.update(),
+            table_b.update().return_defaults(),
+            table_b.update().return_defaults(table_b.c.a),
+            table_b.update().return_defaults(table_b.c.b),
             table_b.update().where(table_b.c.a == 5),
             table_b.update().where(table_b.c.b == 5),
             table_b.update()