]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
restate all upsert in terms of statement extensions (patch 3)
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 24 Feb 2025 22:53:40 +0000 (17:53 -0500)
committerMichael Bayer <mike_mp@zzzcomputing.com>
Thu, 6 Mar 2025 18:28:18 +0000 (18:28 +0000)
Change-Id: I0595ba8e2bd930e22f4c06d7a813bcd23060cb7a

lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/mysql/dml.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/dialects/postgresql/dml.py
lib/sqlalchemy/dialects/sqlite/base.py
lib/sqlalchemy/dialects/sqlite/dml.py
test/dialect/mysql/test_compiler.py
test/dialect/mysql/test_on_duplicate.py
test/dialect/postgresql/test_compiler.py
test/dialect/test_sqlite.py
test/sql/test_compare.py

index 7838b455b92a35b059b4c3a7252c6d1e264ed5ee..df4d93c4811d9291c0fc56d959e6bdcf3ca53ea0 100644 (file)
@@ -1444,41 +1444,32 @@ class MySQLCompiler(compiler.SQLCompiler):
         for column in (col for col in cols if col.key in on_duplicate_update):
             val = on_duplicate_update[column.key]
 
-            # TODO: this coercion should be up front.  we can't cache
-            # SQL constructs with non-bound literals buried in them
-            if coercions._is_literal(val):
-                val = elements.BindParameter(None, val, type_=column.type)
-                value_text = self.process(val.self_group(), use_schema=False)
-            else:
-
-                def replace(obj):
-                    if (
-                        isinstance(obj, elements.BindParameter)
-                        and obj.type._isnull
-                    ):
-                        obj = obj._clone()
-                        obj.type = column.type
-                        return obj
-                    elif (
-                        isinstance(obj, elements.ColumnClause)
-                        and obj.table is on_duplicate.inserted_alias
-                    ):
-                        if requires_mysql8_alias:
-                            column_literal_clause = (
-                                f"{_on_dup_alias_name}."
-                                f"{self.preparer.quote(obj.name)}"
-                            )
-                        else:
-                            column_literal_clause = (
-                                f"VALUES({self.preparer.quote(obj.name)})"
-                            )
-                        return literal_column(column_literal_clause)
+            def replace(obj):
+                if (
+                    isinstance(obj, elements.BindParameter)
+                    and obj.type._isnull
+                ):
+                    return obj._with_binary_element_type(column.type)
+                elif (
+                    isinstance(obj, elements.ColumnClause)
+                    and obj.table is on_duplicate.inserted_alias
+                ):
+                    if requires_mysql8_alias:
+                        column_literal_clause = (
+                            f"{_on_dup_alias_name}."
+                            f"{self.preparer.quote(obj.name)}"
+                        )
                     else:
-                        # element is not replaced
-                        return None
+                        column_literal_clause = (
+                            f"VALUES({self.preparer.quote(obj.name)})"
+                        )
+                    return literal_column(column_literal_clause)
+                else:
+                    # element is not replaced
+                    return None
 
-                val = visitors.replacement_traverse(val, {}, replace)
-                value_text = self.process(val.self_group(), use_schema=False)
+            val = visitors.replacement_traverse(val, {}, replace)
+            value_text = self.process(val.self_group(), use_schema=False)
 
             name_text = self.preparer.quote(column.name)
             clauses.append("%s = %s" % (name_text, value_text))
index f3be3c395d28f8ceb339bffe8f1c8dfce3063d4f..61476af0229e7be1ed4f930e18767bfad6c93f6c 100644 (file)
@@ -21,7 +21,6 @@ from ...sql import coercions
 from ...sql import roles
 from ...sql._typing import _DMLTableArgument
 from ...sql.base import _exclusive_against
-from ...sql.base import _generative
 from ...sql.base import ColumnCollection
 from ...sql.base import ReadOnlyColumnCollection
 from ...sql.base import SyntaxExtension
@@ -30,6 +29,7 @@ from ...sql.elements import ClauseElement
 from ...sql.elements import KeyedColumnElement
 from ...sql.expression import alias
 from ...sql.selectable import NamedFromClause
+from ...sql.sqltypes import NULLTYPE
 from ...sql.visitors import InternalTraversal
 from ...util.typing import Self
 
@@ -37,6 +37,7 @@ if TYPE_CHECKING:
     from ...sql._typing import _LimitOffsetType
     from ...sql.dml import Delete
     from ...sql.dml import Update
+    from ...sql.elements import ColumnElement
     from ...sql.visitors import _TraverseInternalsType
 
 __all__ = ("Insert", "insert")
@@ -114,7 +115,7 @@ class Insert(StandardInsert):
     """
 
     stringify_dialect = "mysql"
-    inherit_cache = False
+    inherit_cache = True
 
     @property
     def inserted(
@@ -154,7 +155,6 @@ class Insert(StandardInsert):
     def inserted_alias(self) -> NamedFromClause:
         return alias(self.table, name="inserted")
 
-    @_generative
     @_exclusive_against(
         "_post_values_clause",
         msgs={
@@ -225,20 +225,22 @@ class Insert(StandardInsert):
         else:
             values = kw
 
-        self._post_values_clause = OnDuplicateClause(
-            self.inserted_alias, values
-        )
-        return self
+        return self.ext(OnDuplicateClause(self.inserted_alias, values))
 
 
-class OnDuplicateClause(ClauseElement):
+class OnDuplicateClause(SyntaxExtension, ClauseElement):
     __visit_name__ = "on_duplicate_key_update"
 
     _parameter_ordering: Optional[List[str]] = None
 
-    update: Dict[str, Any]
+    update: Dict[str, ColumnElement[Any]]
     stringify_dialect = "mysql"
 
+    _traverse_internals = [
+        ("_parameter_ordering", InternalTraversal.dp_string_list),
+        ("update", InternalTraversal.dp_dml_values),
+    ]
+
     def __init__(
         self, inserted_alias: NamedFromClause, update: _UpdateArg
     ) -> None:
@@ -267,7 +269,18 @@ class OnDuplicateClause(ClauseElement):
                 "or a ColumnCollection such as the `.c.` collection "
                 "of a Table object"
             )
-        self.update = update
+
+        self.update = {
+            k: coercions.expect(
+                roles.ExpressionElementRole, v, type_=NULLTYPE, is_crud=True
+            )
+            for k, v in update.items()
+        }
+
+    def apply_to_insert(self, insert_stmt: StandardInsert) -> None:
+        insert_stmt.apply_syntax_extension_point(
+            self.append_replacing_same_type, "post_values"
+        )
 
 
 _UpdateArg = Union[
index 83bd99d7f0a0ad04fa786e16a848738cb072855a..38e834cf27e7ad9aa4869cc7faed7287a5991f68 100644 (file)
@@ -2085,18 +2085,12 @@ class PGCompiler(compiler.SQLCompiler):
             else:
                 continue
 
-            # TODO: this coercion should be up front.  we can't cache
-            # SQL constructs with non-bound literals buried in them
-            if coercions._is_literal(value):
-                value = elements.BindParameter(None, value, type_=c.type)
-
-            else:
-                if (
-                    isinstance(value, elements.BindParameter)
-                    and value.type._isnull
-                ):
-                    value = value._clone()
-                    value.type = c.type
+            assert not coercions._is_literal(value)
+            if (
+                isinstance(value, elements.BindParameter)
+                and value.type._isnull
+            ):
+                value = value._with_binary_element_type(c.type)
             value_text = self.process(value.self_group(), use_schema=False)
 
             key_text = self.preparer.quote(c.name)
index 1187b6bf5f03a71b92c5b7cbe870b691aff5ca46..69647546610dce54748ef95f9369a83a8216fbf9 100644 (file)
@@ -7,9 +7,9 @@
 from __future__ import annotations
 
 from typing import Any
+from typing import Dict
 from typing import List
 from typing import Optional
-from typing import Tuple
 from typing import Union
 
 from . import ext
@@ -24,18 +24,20 @@ from ...sql import roles
 from ...sql import schema
 from ...sql._typing import _DMLTableArgument
 from ...sql.base import _exclusive_against
-from ...sql.base import _generative
 from ...sql.base import ColumnCollection
 from ...sql.base import ReadOnlyColumnCollection
+from ...sql.base import SyntaxExtension
+from ...sql.dml import _DMLColumnElement
 from ...sql.dml import Insert as StandardInsert
 from ...sql.elements import ClauseElement
 from ...sql.elements import ColumnElement
 from ...sql.elements import KeyedColumnElement
 from ...sql.elements import TextClause
 from ...sql.expression import alias
+from ...sql.type_api import NULLTYPE
+from ...sql.visitors import InternalTraversal
 from ...util.typing import Self
 
-
 __all__ = ("Insert", "insert")
 
 
@@ -70,7 +72,7 @@ class Insert(StandardInsert):
     """
 
     stringify_dialect = "postgresql"
-    inherit_cache = False
+    inherit_cache = True
 
     @util.memoized_property
     def excluded(
@@ -109,7 +111,6 @@ class Insert(StandardInsert):
         },
     )
 
-    @_generative
     @_on_conflict_exclusive
     def on_conflict_do_update(
         self,
@@ -169,12 +170,12 @@ class Insert(StandardInsert):
             :ref:`postgresql_insert_on_conflict`
 
         """
-        self._post_values_clause = OnConflictDoUpdate(
-            constraint, index_elements, index_where, set_, where
+        return self.ext(
+            OnConflictDoUpdate(
+                constraint, index_elements, index_where, set_, where
+            )
         )
-        return self
 
-    @_generative
     @_on_conflict_exclusive
     def on_conflict_do_nothing(
         self,
@@ -206,13 +207,12 @@ class Insert(StandardInsert):
             :ref:`postgresql_insert_on_conflict`
 
         """
-        self._post_values_clause = OnConflictDoNothing(
-            constraint, index_elements, index_where
+        return self.ext(
+            OnConflictDoNothing(constraint, index_elements, index_where)
         )
-        return self
 
 
-class OnConflictClause(ClauseElement):
+class OnConflictClause(SyntaxExtension, ClauseElement):
     stringify_dialect = "postgresql"
 
     constraint_target: Optional[str]
@@ -221,6 +221,12 @@ class OnConflictClause(ClauseElement):
         Union[ColumnElement[Any], TextClause]
     ]
 
+    _traverse_internals = [
+        ("constraint_target", InternalTraversal.dp_string),
+        ("inferred_target_elements", InternalTraversal.dp_multi_list),
+        ("inferred_target_whereclause", InternalTraversal.dp_clauseelement),
+    ]
+
     def __init__(
         self,
         constraint: _OnConflictConstraintT = None,
@@ -283,17 +289,29 @@ class OnConflictClause(ClauseElement):
                 self.inferred_target_whereclause
             ) = None
 
+    def apply_to_insert(self, insert_stmt: StandardInsert) -> None:
+        insert_stmt.apply_syntax_extension_point(
+            self.append_replacing_same_type, "post_values"
+        )
+
 
 class OnConflictDoNothing(OnConflictClause):
     __visit_name__ = "on_conflict_do_nothing"
 
+    inherit_cache = True
+
 
 class OnConflictDoUpdate(OnConflictClause):
     __visit_name__ = "on_conflict_do_update"
 
-    update_values_to_set: List[Tuple[Union[schema.Column[Any], str], Any]]
+    update_values_to_set: Dict[_DMLColumnElement, ColumnElement[Any]]
     update_whereclause: Optional[ColumnElement[Any]]
 
+    _traverse_internals = OnConflictClause._traverse_internals + [
+        ("update_values_to_set", InternalTraversal.dp_dml_values),
+        ("update_whereclause", InternalTraversal.dp_clauseelement),
+    ]
+
     def __init__(
         self,
         constraint: _OnConflictConstraintT = None,
@@ -328,10 +346,13 @@ class OnConflictDoUpdate(OnConflictClause):
                 "or a ColumnCollection such as the `.c.` collection "
                 "of a Table object"
             )
-        self.update_values_to_set = [
-            (coercions.expect(roles.DMLColumnRole, key), value)
-            for key, value in set_.items()
-        ]
+
+        self.update_values_to_set = {
+            coercions.expect(roles.DMLColumnRole, k): coercions.expect(
+                roles.ExpressionElementRole, v, type_=NULLTYPE, is_crud=True
+            )
+            for k, v in set_.items()
+        }
         self.update_whereclause = (
             coercions.expect(roles.WhereHavingRole, where)
             if where is not None
index 96b2414ccec4a20ca90ea732850054b608032067..7b8e42a2854ba23484fb88a4b80bc18656b511a9 100644 (file)
@@ -1533,16 +1533,11 @@ class SQLiteCompiler(compiler.SQLCompiler):
             else:
                 continue
 
-            if coercions._is_literal(value):
-                value = elements.BindParameter(None, value, type_=c.type)
-
-            else:
-                if (
-                    isinstance(value, elements.BindParameter)
-                    and value.type._isnull
-                ):
-                    value = value._clone()
-                    value.type = c.type
+            if (
+                isinstance(value, elements.BindParameter)
+                and value.type._isnull
+            ):
+                value = value._with_binary_element_type(c.type)
             value_text = self.process(value.self_group(), use_schema=False)
 
             key_text = self.preparer.quote(c.name)
index 84cdb8bec234b3e981255db0582fd4a9c49bdf7c..fc16f1eaa43964b9af9ff917ad4c5ab871cfcc9f 100644 (file)
@@ -7,9 +7,9 @@
 from __future__ import annotations
 
 from typing import Any
+from typing import Dict
 from typing import List
 from typing import Optional
-from typing import Tuple
 from typing import Union
 
 from .._typing import _OnConflictIndexElementsT
@@ -22,15 +22,18 @@ from ...sql import roles
 from ...sql import schema
 from ...sql._typing import _DMLTableArgument
 from ...sql.base import _exclusive_against
-from ...sql.base import _generative
 from ...sql.base import ColumnCollection
 from ...sql.base import ReadOnlyColumnCollection
+from ...sql.base import SyntaxExtension
+from ...sql.dml import _DMLColumnElement
 from ...sql.dml import Insert as StandardInsert
 from ...sql.elements import ClauseElement
 from ...sql.elements import ColumnElement
 from ...sql.elements import KeyedColumnElement
 from ...sql.elements import TextClause
 from ...sql.expression import alias
+from ...sql.sqltypes import NULLTYPE
+from ...sql.visitors import InternalTraversal
 from ...util.typing import Self
 
 __all__ = ("Insert", "insert")
@@ -73,7 +76,7 @@ class Insert(StandardInsert):
     """
 
     stringify_dialect = "sqlite"
-    inherit_cache = False
+    inherit_cache = True
 
     @util.memoized_property
     def excluded(
@@ -107,7 +110,6 @@ class Insert(StandardInsert):
         },
     )
 
-    @_generative
     @_on_conflict_exclusive
     def on_conflict_do_update(
         self,
@@ -155,12 +157,10 @@ class Insert(StandardInsert):
 
         """
 
-        self._post_values_clause = OnConflictDoUpdate(
-            index_elements, index_where, set_, where
+        return self.ext(
+            OnConflictDoUpdate(index_elements, index_where, set_, where)
         )
-        return self
 
-    @_generative
     @_on_conflict_exclusive
     def on_conflict_do_nothing(
         self,
@@ -181,13 +181,10 @@ class Insert(StandardInsert):
 
         """
 
-        self._post_values_clause = OnConflictDoNothing(
-            index_elements, index_where
-        )
-        return self
+        return self.ext(OnConflictDoNothing(index_elements, index_where))
 
 
-class OnConflictClause(ClauseElement):
+class OnConflictClause(SyntaxExtension, ClauseElement):
     stringify_dialect = "sqlite"
 
     inferred_target_elements: Optional[List[Union[str, schema.Column[Any]]]]
@@ -195,6 +192,11 @@ class OnConflictClause(ClauseElement):
         Union[ColumnElement[Any], TextClause]
     ]
 
+    _traverse_internals = [
+        ("inferred_target_elements", InternalTraversal.dp_multi_list),
+        ("inferred_target_whereclause", InternalTraversal.dp_clauseelement),
+    ]
+
     def __init__(
         self,
         index_elements: _OnConflictIndexElementsT = None,
@@ -218,17 +220,29 @@ class OnConflictClause(ClauseElement):
                 self.inferred_target_whereclause
             ) = None
 
+    def apply_to_insert(self, insert_stmt: StandardInsert) -> None:
+        insert_stmt.apply_syntax_extension_point(
+            self.append_replacing_same_type, "post_values"
+        )
+
 
 class OnConflictDoNothing(OnConflictClause):
     __visit_name__ = "on_conflict_do_nothing"
 
+    inherit_cache = True
+
 
 class OnConflictDoUpdate(OnConflictClause):
     __visit_name__ = "on_conflict_do_update"
 
-    update_values_to_set: List[Tuple[Union[schema.Column[Any], str], Any]]
+    update_values_to_set: Dict[_DMLColumnElement, ColumnElement[Any]]
     update_whereclause: Optional[ColumnElement[Any]]
 
+    _traverse_internals = OnConflictClause._traverse_internals + [
+        ("update_values_to_set", InternalTraversal.dp_dml_values),
+        ("update_whereclause", InternalTraversal.dp_clauseelement),
+    ]
+
     def __init__(
         self,
         index_elements: _OnConflictIndexElementsT = None,
@@ -252,10 +266,12 @@ class OnConflictDoUpdate(OnConflictClause):
                 "or a ColumnCollection such as the `.c.` collection "
                 "of a Table object"
             )
-        self.update_values_to_set = [
-            (coercions.expect(roles.DMLColumnRole, key), value)
-            for key, value in set_.items()
-        ]
+        self.update_values_to_set = {
+            coercions.expect(roles.DMLColumnRole, k): coercions.expect(
+                roles.ExpressionElementRole, v, type_=NULLTYPE, is_crud=True
+            )
+            for k, v in set_.items()
+        }
         self.update_whereclause = (
             coercions.expect(roles.WhereHavingRole, where)
             if where is not None
index 5c98be3f6ae1affab0fb1e138837b411fca637dd..553298c549bc9575b2968df1bc42250eb21e658d 100644 (file)
@@ -1,3 +1,5 @@
+import random
+
 from sqlalchemy import BLOB
 from sqlalchemy import BOOLEAN
 from sqlalchemy import Boolean
@@ -630,6 +632,51 @@ class CustomExtensionTest(
 ):
     __dialect__ = "mysql"
 
+    @fixtures.CacheKeySuite.run_suite_tests
+    def test_insert_on_duplicate_key_cache_key(self):
+        table = Table(
+            "foos",
+            MetaData(),
+            Column("id", Integer, primary_key=True),
+            Column("bar", String(10)),
+            Column("baz", String(10)),
+        )
+
+        def stmt0():
+            # note a multivalues INSERT is not cacheable; use just one
+            # set of values
+            return insert(table).values(
+                {"id": 1, "bar": "ab"},
+            )
+
+        def stmt1():
+            stmt = stmt0()
+            return stmt.on_duplicate_key_update(
+                bar=stmt.inserted.bar, baz=stmt.inserted.baz
+            )
+
+        def stmt15():
+            stmt = insert(table).values(
+                {"id": 1},
+            )
+            return stmt.on_duplicate_key_update(
+                bar=stmt.inserted.bar, baz=stmt.inserted.baz
+            )
+
+        def stmt2():
+            stmt = stmt0()
+            return stmt.on_duplicate_key_update(bar=stmt.inserted.bar)
+
+        def stmt3():
+            stmt = stmt0()
+            # use different literal values; ensure each cache key is
+            # identical
+            return stmt.on_duplicate_key_update(
+                bar=random.choice(["a", "b", "c"])
+            )
+
+        return lambda: [stmt0(), stmt1(), stmt15(), stmt2(), stmt3()]
+
     @fixtures.CacheKeySuite.run_suite_tests
     def test_dml_limit_cache_key(self):
         t = sql.table("t", sql.column("col1"), sql.column("col2"))
index 35aebb470c30bf0976986a057f9fa5b3c2ac09b3..307057c8e359a0ca6da629c745dba2b8316bde42 100644 (file)
@@ -1,3 +1,5 @@
+import random
+
 from sqlalchemy import Boolean
 from sqlalchemy import Column
 from sqlalchemy import exc
@@ -211,3 +213,25 @@ class OnDuplicateTest(fixtures.TablesTest):
             stmt.on_duplicate_key_update(bar=stmt.inserted.bar, baz="newbz")
         )
         eq_(result.inserted_primary_key, (1,))
+
+    def test_bound_caching(self, connection):
+        foos = self.tables.foos
+        connection.execute(insert(foos).values(dict(id=1, bar="b", baz="bz")))
+
+        for scenario in [
+            (random.choice(["c", "d", "e"]), random.choice(["f", "g", "h"]))
+            for i in range(10)
+        ]:
+            stmt = insert(foos).values(dict(id=1, bar="q"))
+            stmt = stmt.on_duplicate_key_update(
+                bar=scenario[0], baz=scenario[1]
+            )
+
+            connection.execute(stmt)
+
+            eq_(
+                connection.execute(
+                    foos.select().where(foos.c.id == 1)
+                ).fetchall(),
+                [(1, scenario[0], scenario[1], False)],
+            )
index f02b42c0b21bbdad276c9ab601a3492721b92c6c..b6bd6257088dd3420de4e6e4be2fbbaed4d1929c 100644 (file)
@@ -1,3 +1,5 @@
+import random
+
 from sqlalchemy import and_
 from sqlalchemy import BigInteger
 from sqlalchemy import bindparam
@@ -2667,7 +2669,9 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
         )
 
 
-class InsertOnConflictTest(fixtures.TablesTest, AssertsCompiledSQL):
+class InsertOnConflictTest(
+    fixtures.TablesTest, AssertsCompiledSQL, fixtures.CacheKeySuite
+):
     __dialect__ = postgresql.dialect()
 
     run_create_tables = None
@@ -2786,6 +2790,111 @@ class InsertOnConflictTest(fixtures.TablesTest, AssertsCompiledSQL):
                 f"{expected}",
             )
 
+    @fixtures.CacheKeySuite.run_suite_tests
+    def test_insert_on_conflict_cache_key(self):
+        table = Table(
+            "foos",
+            MetaData(),
+            Column("id", Integer, primary_key=True),
+            Column("bar", String(10)),
+            Column("baz", String(10)),
+        )
+        Index("foo_idx", table.c.id)
+
+        def stmt0():
+            # note a multivalues INSERT is not cacheable; use just one
+            # set of values
+            return insert(table).values(
+                {"id": 1, "bar": "ab"},
+            )
+
+        def stmt1():
+            stmt = stmt0()
+            return stmt.on_conflict_do_nothing()
+
+        def stmt2():
+            stmt = stmt0()
+            return stmt.on_conflict_do_nothing(index_elements=["id"])
+
+        def stmt21():
+            stmt = stmt0()
+            return stmt.on_conflict_do_nothing(index_elements=[table.c.id])
+
+        def stmt22():
+            stmt = stmt0()
+            return stmt.on_conflict_do_nothing(
+                index_elements=["id", table.c.bar]
+            )
+
+        def stmt23():
+            stmt = stmt0()
+            return stmt.on_conflict_do_nothing(index_elements=["id", "bar"])
+
+        def stmt24():
+            stmt = insert(table).values(
+                {"id": 1, "bar": "ab", "baz": "xy"},
+            )
+            return stmt.on_conflict_do_nothing(index_elements=["id", "bar"])
+
+        def stmt3():
+            stmt = stmt0()
+            return stmt.on_conflict_do_update(
+                index_elements=["id"],
+                set_={
+                    "bar": random.choice(["a", "b", "c"]),
+                    "baz": random.choice(["d", "e", "f"]),
+                },
+            )
+
+        def stmt31():
+            stmt = stmt0()
+            return stmt.on_conflict_do_update(
+                index_elements=["id"],
+                set_={
+                    "baz": random.choice(["d", "e", "f"]),
+                },
+            )
+
+        def stmt4():
+            stmt = stmt0()
+
+            return stmt.on_conflict_do_update(
+                constraint=table.primary_key, set_=stmt.excluded
+            )
+
+        def stmt41():
+            stmt = stmt0()
+
+            return stmt.on_conflict_do_update(
+                constraint=table.primary_key,
+                set_=stmt.excluded,
+                where=table.c.bar != random.choice(["q", "p", "r", "z"]),
+            )
+
+        def stmt42():
+            stmt = stmt0()
+
+            return stmt.on_conflict_do_update(
+                constraint=table.primary_key,
+                set_=stmt.excluded,
+                where=table.c.baz != random.choice(["q", "p", "r", "z"]),
+            )
+
+        return lambda: [
+            stmt0(),
+            stmt1(),
+            stmt2(),
+            stmt21(),
+            stmt22(),
+            stmt23(),
+            stmt24(),
+            stmt3(),
+            stmt31(),
+            stmt4(),
+            stmt41(),
+            stmt42(),
+        ]
+
     @testing.combinations("control", "excluded", "dict")
     def test_set_excluded(self, scenario):
         """test #8014, sending all of .excluded to set"""
@@ -2832,6 +2941,34 @@ class InsertOnConflictTest(fixtures.TablesTest, AssertsCompiledSQL):
                     "SET id = excluded.id, name = excluded.name",
                 )
 
+    def test_dont_consume_set_collection(self):
+        users = self.tables.users
+        stmt = insert(users).values(
+            [
+                {
+                    "name": "spongebob",
+                },
+                {
+                    "name": "sandy",
+                },
+            ]
+        )
+        stmt = stmt.on_conflict_do_update(
+            index_elements=[users.c.name], set_=dict(name=stmt.excluded.name)
+        )
+        self.assert_compile(
+            stmt,
+            "INSERT INTO users (name) VALUES (%(name_m0)s), (%(name_m1)s) "
+            "ON CONFLICT (name) DO UPDATE SET name = excluded.name",
+        )
+        stmt = stmt.returning(users)
+        self.assert_compile(
+            stmt,
+            "INSERT INTO users (name) VALUES (%(name_m0)s), (%(name_m1)s) "
+            "ON CONFLICT (name) DO UPDATE SET name = excluded.name "
+            "RETURNING users.id, users.name",
+        )
+
     def test_on_conflict_do_no_call_twice(self):
         users = self.table1
 
index ecb9510c937078ec8b8f18654c60e574f7a54f49..c5b4f62e2969e67d8c90904ba080eb457685e799 100644 (file)
@@ -3,6 +3,7 @@
 import datetime
 import json
 import os
+import random
 
 from sqlalchemy import and_
 from sqlalchemy import bindparam
@@ -2952,7 +2953,9 @@ class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         )
 
 
-class OnConflictCompileTest(AssertsCompiledSQL, fixtures.TestBase):
+class OnConflictCompileTest(
+    AssertsCompiledSQL, fixtures.CacheKeySuite, fixtures.TestBase
+):
     __dialect__ = "sqlite"
 
     @testing.combinations(
@@ -3012,6 +3015,83 @@ class OnConflictCompileTest(AssertsCompiledSQL, fixtures.TestBase):
                 f"INSERT INTO users (id, name) VALUES (?, ?) {expected}",
             )
 
+    @fixtures.CacheKeySuite.run_suite_tests
+    def test_insert_on_conflict_cache_key(self):
+        table = Table(
+            "foos",
+            MetaData(),
+            Column("id", Integer, primary_key=True),
+            Column("bar", String(10)),
+            Column("baz", String(10)),
+        )
+        Index("foo_idx", table.c.id)
+
+        def stmt0():
+            # note a multivalues INSERT is not cacheable; use just one
+            # set of values
+            return insert(table).values(
+                {"id": 1, "bar": "ab"},
+            )
+
+        def stmt1():
+            stmt = stmt0()
+            return stmt.on_conflict_do_nothing()
+
+        def stmt2():
+            stmt = stmt0()
+            return stmt.on_conflict_do_nothing(index_elements=["id"])
+
+        def stmt21():
+            stmt = stmt0()
+            return stmt.on_conflict_do_nothing(index_elements=[table.c.id])
+
+        def stmt22():
+            stmt = stmt0()
+            return stmt.on_conflict_do_nothing(
+                index_elements=["id", table.c.bar]
+            )
+
+        def stmt23():
+            stmt = stmt0()
+            return stmt.on_conflict_do_nothing(index_elements=["id", "bar"])
+
+        def stmt24():
+            stmt = insert(table).values(
+                {"id": 1, "bar": "ab", "baz": "xy"},
+            )
+            return stmt.on_conflict_do_nothing(index_elements=["id", "bar"])
+
+        def stmt3():
+            stmt = stmt0()
+            return stmt.on_conflict_do_update(
+                index_elements=["id"],
+                set_={
+                    "bar": random.choice(["a", "b", "c"]),
+                    "baz": random.choice(["d", "e", "f"]),
+                },
+            )
+
+        def stmt31():
+            stmt = stmt0()
+            return stmt.on_conflict_do_update(
+                index_elements=["id"],
+                set_={
+                    "baz": random.choice(["d", "e", "f"]),
+                },
+            )
+
+        return lambda: [
+            stmt0(),
+            stmt1(),
+            stmt2(),
+            stmt21(),
+            stmt22(),
+            stmt23(),
+            stmt24(),
+            stmt3(),
+            stmt31(),
+        ]
+
     @testing.combinations("control", "excluded", "dict", argnames="scenario")
     def test_set_excluded(self, scenario, users, users_w_key):
         """test #8014, sending all of .excluded to set"""
@@ -3048,6 +3128,33 @@ class OnConflictCompileTest(AssertsCompiledSQL, fixtures.TestBase):
                     "DO UPDATE SET id = excluded.id, name = excluded.name",
                 )
 
+    def test_dont_consume_set_collection(self, users):
+        stmt = insert(users).values(
+            [
+                {
+                    "name": "spongebob",
+                },
+                {
+                    "name": "sandy",
+                },
+            ]
+        )
+        stmt = stmt.on_conflict_do_update(
+            index_elements=[users.c.name], set_=dict(name=stmt.excluded.name)
+        )
+        self.assert_compile(
+            stmt,
+            "INSERT INTO users (name) VALUES (?), (?) "
+            "ON CONFLICT (name) DO UPDATE SET name = excluded.name",
+        )
+        stmt = stmt.returning(users)
+        self.assert_compile(
+            stmt,
+            "INSERT INTO users (name) VALUES (?), (?) "
+            "ON CONFLICT (name) DO UPDATE SET name = excluded.name "
+            "RETURNING id, name",
+        )
+
     def test_on_conflict_do_update_exotic_targets_six(self, users_xtra):
         users = users_xtra
 
index d499609b49595ab764b9a1aaefae15aee0649b88..8b1869e8d0d97aa57b9daecb5b85e7b4a24170bc 100644 (file)
@@ -31,8 +31,6 @@ from sqlalchemy import TypeDecorator
 from sqlalchemy import union
 from sqlalchemy import union_all
 from sqlalchemy import values
-from sqlalchemy.dialects import mysql
-from sqlalchemy.dialects import postgresql
 from sqlalchemy.schema import Sequence
 from sqlalchemy.sql import bindparam
 from sqlalchemy.sql import ColumnElement
@@ -1226,17 +1224,7 @@ class CoreFixtures:
 
 
 class CacheKeyTest(fixtures.CacheKeyFixture, CoreFixtures, fixtures.TestBase):
-    # we are slightly breaking the policy of not having external dialect
-    # stuff in here, but use pg/mysql as test cases to ensure that these
-    # objects don't report an inaccurate cache key, which is dependent
-    # on the base insert sending out _post_values_clause and the caching
-    # system properly recognizing these constructs as not cacheable
-
     @testing.combinations(
-        postgresql.insert(table_a).on_conflict_do_update(
-            index_elements=[table_a.c.a], set_={"name": "foo"}
-        ),
-        mysql.insert(table_a).on_duplicate_key_update(updated_once=None),
         table_a.insert().values(  # multivalues doesn't cache
             [
                 {"name": "some name"},