From: Mike Bayer Date: Mon, 24 Feb 2025 22:53:40 +0000 (-0500) Subject: restate all upsert in terms of statement extensions (patch 3) X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=d5d4189ef63e7a623894ca7a148a92c716935960;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git restate all upsert in terms of statement extensions (patch 3) Change-Id: I0595ba8e2bd930e22f4c06d7a813bcd23060cb7a --- diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 7838b455b9..df4d93c481 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1444,41 +1444,32 @@ class MySQLCompiler(compiler.SQLCompiler): for column in (col for col in cols if col.key in on_duplicate_update): val = on_duplicate_update[column.key] - # TODO: this coercion should be up front. we can't cache - # SQL constructs with non-bound literals buried in them - if coercions._is_literal(val): - val = elements.BindParameter(None, val, type_=column.type) - value_text = self.process(val.self_group(), use_schema=False) - else: - - def replace(obj): - if ( - isinstance(obj, elements.BindParameter) - and obj.type._isnull - ): - obj = obj._clone() - obj.type = column.type - return obj - elif ( - isinstance(obj, elements.ColumnClause) - and obj.table is on_duplicate.inserted_alias - ): - if requires_mysql8_alias: - column_literal_clause = ( - f"{_on_dup_alias_name}." - f"{self.preparer.quote(obj.name)}" - ) - else: - column_literal_clause = ( - f"VALUES({self.preparer.quote(obj.name)})" - ) - return literal_column(column_literal_clause) + def replace(obj): + if ( + isinstance(obj, elements.BindParameter) + and obj.type._isnull + ): + return obj._with_binary_element_type(column.type) + elif ( + isinstance(obj, elements.ColumnClause) + and obj.table is on_duplicate.inserted_alias + ): + if requires_mysql8_alias: + column_literal_clause = ( + f"{_on_dup_alias_name}." + f"{self.preparer.quote(obj.name)}" + ) else: - # element is not replaced - return None + column_literal_clause = ( + f"VALUES({self.preparer.quote(obj.name)})" + ) + return literal_column(column_literal_clause) + else: + # element is not replaced + return None - val = visitors.replacement_traverse(val, {}, replace) - value_text = self.process(val.self_group(), use_schema=False) + val = visitors.replacement_traverse(val, {}, replace) + value_text = self.process(val.self_group(), use_schema=False) name_text = self.preparer.quote(column.name) clauses.append("%s = %s" % (name_text, value_text)) diff --git a/lib/sqlalchemy/dialects/mysql/dml.py b/lib/sqlalchemy/dialects/mysql/dml.py index f3be3c395d..61476af022 100644 --- a/lib/sqlalchemy/dialects/mysql/dml.py +++ b/lib/sqlalchemy/dialects/mysql/dml.py @@ -21,7 +21,6 @@ from ...sql import coercions from ...sql import roles from ...sql._typing import _DMLTableArgument from ...sql.base import _exclusive_against -from ...sql.base import _generative from ...sql.base import ColumnCollection from ...sql.base import ReadOnlyColumnCollection from ...sql.base import SyntaxExtension @@ -30,6 +29,7 @@ from ...sql.elements import ClauseElement from ...sql.elements import KeyedColumnElement from ...sql.expression import alias from ...sql.selectable import NamedFromClause +from ...sql.sqltypes import NULLTYPE from ...sql.visitors import InternalTraversal from ...util.typing import Self @@ -37,6 +37,7 @@ if TYPE_CHECKING: from ...sql._typing import _LimitOffsetType from ...sql.dml import Delete from ...sql.dml import Update + from ...sql.elements import ColumnElement from ...sql.visitors import _TraverseInternalsType __all__ = ("Insert", "insert") @@ -114,7 +115,7 @@ class Insert(StandardInsert): """ stringify_dialect = "mysql" - inherit_cache = False + inherit_cache = True @property def inserted( @@ -154,7 +155,6 @@ class Insert(StandardInsert): def inserted_alias(self) -> NamedFromClause: return alias(self.table, name="inserted") - @_generative @_exclusive_against( "_post_values_clause", msgs={ @@ -225,20 +225,22 @@ class Insert(StandardInsert): else: values = kw - self._post_values_clause = OnDuplicateClause( - self.inserted_alias, values - ) - return self + return self.ext(OnDuplicateClause(self.inserted_alias, values)) -class OnDuplicateClause(ClauseElement): +class OnDuplicateClause(SyntaxExtension, ClauseElement): __visit_name__ = "on_duplicate_key_update" _parameter_ordering: Optional[List[str]] = None - update: Dict[str, Any] + update: Dict[str, ColumnElement[Any]] stringify_dialect = "mysql" + _traverse_internals = [ + ("_parameter_ordering", InternalTraversal.dp_string_list), + ("update", InternalTraversal.dp_dml_values), + ] + def __init__( self, inserted_alias: NamedFromClause, update: _UpdateArg ) -> None: @@ -267,7 +269,18 @@ class OnDuplicateClause(ClauseElement): "or a ColumnCollection such as the `.c.` collection " "of a Table object" ) - self.update = update + + self.update = { + k: coercions.expect( + roles.ExpressionElementRole, v, type_=NULLTYPE, is_crud=True + ) + for k, v in update.items() + } + + def apply_to_insert(self, insert_stmt: StandardInsert) -> None: + insert_stmt.apply_syntax_extension_point( + self.append_replacing_same_type, "post_values" + ) _UpdateArg = Union[ diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 83bd99d7f0..38e834cf27 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -2085,18 +2085,12 @@ class PGCompiler(compiler.SQLCompiler): else: continue - # TODO: this coercion should be up front. we can't cache - # SQL constructs with non-bound literals buried in them - if coercions._is_literal(value): - value = elements.BindParameter(None, value, type_=c.type) - - else: - if ( - isinstance(value, elements.BindParameter) - and value.type._isnull - ): - value = value._clone() - value.type = c.type + assert not coercions._is_literal(value) + if ( + isinstance(value, elements.BindParameter) + and value.type._isnull + ): + value = value._with_binary_element_type(c.type) value_text = self.process(value.self_group(), use_schema=False) key_text = self.preparer.quote(c.name) diff --git a/lib/sqlalchemy/dialects/postgresql/dml.py b/lib/sqlalchemy/dialects/postgresql/dml.py index 1187b6bf5f..6964754661 100644 --- a/lib/sqlalchemy/dialects/postgresql/dml.py +++ b/lib/sqlalchemy/dialects/postgresql/dml.py @@ -7,9 +7,9 @@ from __future__ import annotations from typing import Any +from typing import Dict from typing import List from typing import Optional -from typing import Tuple from typing import Union from . import ext @@ -24,18 +24,20 @@ from ...sql import roles from ...sql import schema from ...sql._typing import _DMLTableArgument from ...sql.base import _exclusive_against -from ...sql.base import _generative from ...sql.base import ColumnCollection from ...sql.base import ReadOnlyColumnCollection +from ...sql.base import SyntaxExtension +from ...sql.dml import _DMLColumnElement from ...sql.dml import Insert as StandardInsert from ...sql.elements import ClauseElement from ...sql.elements import ColumnElement from ...sql.elements import KeyedColumnElement from ...sql.elements import TextClause from ...sql.expression import alias +from ...sql.type_api import NULLTYPE +from ...sql.visitors import InternalTraversal from ...util.typing import Self - __all__ = ("Insert", "insert") @@ -70,7 +72,7 @@ class Insert(StandardInsert): """ stringify_dialect = "postgresql" - inherit_cache = False + inherit_cache = True @util.memoized_property def excluded( @@ -109,7 +111,6 @@ class Insert(StandardInsert): }, ) - @_generative @_on_conflict_exclusive def on_conflict_do_update( self, @@ -169,12 +170,12 @@ class Insert(StandardInsert): :ref:`postgresql_insert_on_conflict` """ - self._post_values_clause = OnConflictDoUpdate( - constraint, index_elements, index_where, set_, where + return self.ext( + OnConflictDoUpdate( + constraint, index_elements, index_where, set_, where + ) ) - return self - @_generative @_on_conflict_exclusive def on_conflict_do_nothing( self, @@ -206,13 +207,12 @@ class Insert(StandardInsert): :ref:`postgresql_insert_on_conflict` """ - self._post_values_clause = OnConflictDoNothing( - constraint, index_elements, index_where + return self.ext( + OnConflictDoNothing(constraint, index_elements, index_where) ) - return self -class OnConflictClause(ClauseElement): +class OnConflictClause(SyntaxExtension, ClauseElement): stringify_dialect = "postgresql" constraint_target: Optional[str] @@ -221,6 +221,12 @@ class OnConflictClause(ClauseElement): Union[ColumnElement[Any], TextClause] ] + _traverse_internals = [ + ("constraint_target", InternalTraversal.dp_string), + ("inferred_target_elements", InternalTraversal.dp_multi_list), + ("inferred_target_whereclause", InternalTraversal.dp_clauseelement), + ] + def __init__( self, constraint: _OnConflictConstraintT = None, @@ -283,17 +289,29 @@ class OnConflictClause(ClauseElement): self.inferred_target_whereclause ) = None + def apply_to_insert(self, insert_stmt: StandardInsert) -> None: + insert_stmt.apply_syntax_extension_point( + self.append_replacing_same_type, "post_values" + ) + class OnConflictDoNothing(OnConflictClause): __visit_name__ = "on_conflict_do_nothing" + inherit_cache = True + class OnConflictDoUpdate(OnConflictClause): __visit_name__ = "on_conflict_do_update" - update_values_to_set: List[Tuple[Union[schema.Column[Any], str], Any]] + update_values_to_set: Dict[_DMLColumnElement, ColumnElement[Any]] update_whereclause: Optional[ColumnElement[Any]] + _traverse_internals = OnConflictClause._traverse_internals + [ + ("update_values_to_set", InternalTraversal.dp_dml_values), + ("update_whereclause", InternalTraversal.dp_clauseelement), + ] + def __init__( self, constraint: _OnConflictConstraintT = None, @@ -328,10 +346,13 @@ class OnConflictDoUpdate(OnConflictClause): "or a ColumnCollection such as the `.c.` collection " "of a Table object" ) - self.update_values_to_set = [ - (coercions.expect(roles.DMLColumnRole, key), value) - for key, value in set_.items() - ] + + self.update_values_to_set = { + coercions.expect(roles.DMLColumnRole, k): coercions.expect( + roles.ExpressionElementRole, v, type_=NULLTYPE, is_crud=True + ) + for k, v in set_.items() + } self.update_whereclause = ( coercions.expect(roles.WhereHavingRole, where) if where is not None diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index 96b2414cce..7b8e42a285 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -1533,16 +1533,11 @@ class SQLiteCompiler(compiler.SQLCompiler): else: continue - if coercions._is_literal(value): - value = elements.BindParameter(None, value, type_=c.type) - - else: - if ( - isinstance(value, elements.BindParameter) - and value.type._isnull - ): - value = value._clone() - value.type = c.type + if ( + isinstance(value, elements.BindParameter) + and value.type._isnull + ): + value = value._with_binary_element_type(c.type) value_text = self.process(value.self_group(), use_schema=False) key_text = self.preparer.quote(c.name) diff --git a/lib/sqlalchemy/dialects/sqlite/dml.py b/lib/sqlalchemy/dialects/sqlite/dml.py index 84cdb8bec2..fc16f1eaa4 100644 --- a/lib/sqlalchemy/dialects/sqlite/dml.py +++ b/lib/sqlalchemy/dialects/sqlite/dml.py @@ -7,9 +7,9 @@ from __future__ import annotations from typing import Any +from typing import Dict from typing import List from typing import Optional -from typing import Tuple from typing import Union from .._typing import _OnConflictIndexElementsT @@ -22,15 +22,18 @@ from ...sql import roles from ...sql import schema from ...sql._typing import _DMLTableArgument from ...sql.base import _exclusive_against -from ...sql.base import _generative from ...sql.base import ColumnCollection from ...sql.base import ReadOnlyColumnCollection +from ...sql.base import SyntaxExtension +from ...sql.dml import _DMLColumnElement from ...sql.dml import Insert as StandardInsert from ...sql.elements import ClauseElement from ...sql.elements import ColumnElement from ...sql.elements import KeyedColumnElement from ...sql.elements import TextClause from ...sql.expression import alias +from ...sql.sqltypes import NULLTYPE +from ...sql.visitors import InternalTraversal from ...util.typing import Self __all__ = ("Insert", "insert") @@ -73,7 +76,7 @@ class Insert(StandardInsert): """ stringify_dialect = "sqlite" - inherit_cache = False + inherit_cache = True @util.memoized_property def excluded( @@ -107,7 +110,6 @@ class Insert(StandardInsert): }, ) - @_generative @_on_conflict_exclusive def on_conflict_do_update( self, @@ -155,12 +157,10 @@ class Insert(StandardInsert): """ - self._post_values_clause = OnConflictDoUpdate( - index_elements, index_where, set_, where + return self.ext( + OnConflictDoUpdate(index_elements, index_where, set_, where) ) - return self - @_generative @_on_conflict_exclusive def on_conflict_do_nothing( self, @@ -181,13 +181,10 @@ class Insert(StandardInsert): """ - self._post_values_clause = OnConflictDoNothing( - index_elements, index_where - ) - return self + return self.ext(OnConflictDoNothing(index_elements, index_where)) -class OnConflictClause(ClauseElement): +class OnConflictClause(SyntaxExtension, ClauseElement): stringify_dialect = "sqlite" inferred_target_elements: Optional[List[Union[str, schema.Column[Any]]]] @@ -195,6 +192,11 @@ class OnConflictClause(ClauseElement): Union[ColumnElement[Any], TextClause] ] + _traverse_internals = [ + ("inferred_target_elements", InternalTraversal.dp_multi_list), + ("inferred_target_whereclause", InternalTraversal.dp_clauseelement), + ] + def __init__( self, index_elements: _OnConflictIndexElementsT = None, @@ -218,17 +220,29 @@ class OnConflictClause(ClauseElement): self.inferred_target_whereclause ) = None + def apply_to_insert(self, insert_stmt: StandardInsert) -> None: + insert_stmt.apply_syntax_extension_point( + self.append_replacing_same_type, "post_values" + ) + class OnConflictDoNothing(OnConflictClause): __visit_name__ = "on_conflict_do_nothing" + inherit_cache = True + class OnConflictDoUpdate(OnConflictClause): __visit_name__ = "on_conflict_do_update" - update_values_to_set: List[Tuple[Union[schema.Column[Any], str], Any]] + update_values_to_set: Dict[_DMLColumnElement, ColumnElement[Any]] update_whereclause: Optional[ColumnElement[Any]] + _traverse_internals = OnConflictClause._traverse_internals + [ + ("update_values_to_set", InternalTraversal.dp_dml_values), + ("update_whereclause", InternalTraversal.dp_clauseelement), + ] + def __init__( self, index_elements: _OnConflictIndexElementsT = None, @@ -252,10 +266,12 @@ class OnConflictDoUpdate(OnConflictClause): "or a ColumnCollection such as the `.c.` collection " "of a Table object" ) - self.update_values_to_set = [ - (coercions.expect(roles.DMLColumnRole, key), value) - for key, value in set_.items() - ] + self.update_values_to_set = { + coercions.expect(roles.DMLColumnRole, k): coercions.expect( + roles.ExpressionElementRole, v, type_=NULLTYPE, is_crud=True + ) + for k, v in set_.items() + } self.update_whereclause = ( coercions.expect(roles.WhereHavingRole, where) if where is not None diff --git a/test/dialect/mysql/test_compiler.py b/test/dialect/mysql/test_compiler.py index 5c98be3f6a..553298c549 100644 --- a/test/dialect/mysql/test_compiler.py +++ b/test/dialect/mysql/test_compiler.py @@ -1,3 +1,5 @@ +import random + from sqlalchemy import BLOB from sqlalchemy import BOOLEAN from sqlalchemy import Boolean @@ -630,6 +632,51 @@ class CustomExtensionTest( ): __dialect__ = "mysql" + @fixtures.CacheKeySuite.run_suite_tests + def test_insert_on_duplicate_key_cache_key(self): + table = Table( + "foos", + MetaData(), + Column("id", Integer, primary_key=True), + Column("bar", String(10)), + Column("baz", String(10)), + ) + + def stmt0(): + # note a multivalues INSERT is not cacheable; use just one + # set of values + return insert(table).values( + {"id": 1, "bar": "ab"}, + ) + + def stmt1(): + stmt = stmt0() + return stmt.on_duplicate_key_update( + bar=stmt.inserted.bar, baz=stmt.inserted.baz + ) + + def stmt15(): + stmt = insert(table).values( + {"id": 1}, + ) + return stmt.on_duplicate_key_update( + bar=stmt.inserted.bar, baz=stmt.inserted.baz + ) + + def stmt2(): + stmt = stmt0() + return stmt.on_duplicate_key_update(bar=stmt.inserted.bar) + + def stmt3(): + stmt = stmt0() + # use different literal values; ensure each cache key is + # identical + return stmt.on_duplicate_key_update( + bar=random.choice(["a", "b", "c"]) + ) + + return lambda: [stmt0(), stmt1(), stmt15(), stmt2(), stmt3()] + @fixtures.CacheKeySuite.run_suite_tests def test_dml_limit_cache_key(self): t = sql.table("t", sql.column("col1"), sql.column("col2")) diff --git a/test/dialect/mysql/test_on_duplicate.py b/test/dialect/mysql/test_on_duplicate.py index 35aebb470c..307057c8e3 100644 --- a/test/dialect/mysql/test_on_duplicate.py +++ b/test/dialect/mysql/test_on_duplicate.py @@ -1,3 +1,5 @@ +import random + from sqlalchemy import Boolean from sqlalchemy import Column from sqlalchemy import exc @@ -211,3 +213,25 @@ class OnDuplicateTest(fixtures.TablesTest): stmt.on_duplicate_key_update(bar=stmt.inserted.bar, baz="newbz") ) eq_(result.inserted_primary_key, (1,)) + + def test_bound_caching(self, connection): + foos = self.tables.foos + connection.execute(insert(foos).values(dict(id=1, bar="b", baz="bz"))) + + for scenario in [ + (random.choice(["c", "d", "e"]), random.choice(["f", "g", "h"])) + for i in range(10) + ]: + stmt = insert(foos).values(dict(id=1, bar="q")) + stmt = stmt.on_duplicate_key_update( + bar=scenario[0], baz=scenario[1] + ) + + connection.execute(stmt) + + eq_( + connection.execute( + foos.select().where(foos.c.id == 1) + ).fetchall(), + [(1, scenario[0], scenario[1], False)], + ) diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py index f02b42c0b2..b6bd625708 100644 --- a/test/dialect/postgresql/test_compiler.py +++ b/test/dialect/postgresql/test_compiler.py @@ -1,3 +1,5 @@ +import random + from sqlalchemy import and_ from sqlalchemy import BigInteger from sqlalchemy import bindparam @@ -2667,7 +2669,9 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): ) -class InsertOnConflictTest(fixtures.TablesTest, AssertsCompiledSQL): +class InsertOnConflictTest( + fixtures.TablesTest, AssertsCompiledSQL, fixtures.CacheKeySuite +): __dialect__ = postgresql.dialect() run_create_tables = None @@ -2786,6 +2790,111 @@ class InsertOnConflictTest(fixtures.TablesTest, AssertsCompiledSQL): f"{expected}", ) + @fixtures.CacheKeySuite.run_suite_tests + def test_insert_on_conflict_cache_key(self): + table = Table( + "foos", + MetaData(), + Column("id", Integer, primary_key=True), + Column("bar", String(10)), + Column("baz", String(10)), + ) + Index("foo_idx", table.c.id) + + def stmt0(): + # note a multivalues INSERT is not cacheable; use just one + # set of values + return insert(table).values( + {"id": 1, "bar": "ab"}, + ) + + def stmt1(): + stmt = stmt0() + return stmt.on_conflict_do_nothing() + + def stmt2(): + stmt = stmt0() + return stmt.on_conflict_do_nothing(index_elements=["id"]) + + def stmt21(): + stmt = stmt0() + return stmt.on_conflict_do_nothing(index_elements=[table.c.id]) + + def stmt22(): + stmt = stmt0() + return stmt.on_conflict_do_nothing( + index_elements=["id", table.c.bar] + ) + + def stmt23(): + stmt = stmt0() + return stmt.on_conflict_do_nothing(index_elements=["id", "bar"]) + + def stmt24(): + stmt = insert(table).values( + {"id": 1, "bar": "ab", "baz": "xy"}, + ) + return stmt.on_conflict_do_nothing(index_elements=["id", "bar"]) + + def stmt3(): + stmt = stmt0() + return stmt.on_conflict_do_update( + index_elements=["id"], + set_={ + "bar": random.choice(["a", "b", "c"]), + "baz": random.choice(["d", "e", "f"]), + }, + ) + + def stmt31(): + stmt = stmt0() + return stmt.on_conflict_do_update( + index_elements=["id"], + set_={ + "baz": random.choice(["d", "e", "f"]), + }, + ) + + def stmt4(): + stmt = stmt0() + + return stmt.on_conflict_do_update( + constraint=table.primary_key, set_=stmt.excluded + ) + + def stmt41(): + stmt = stmt0() + + return stmt.on_conflict_do_update( + constraint=table.primary_key, + set_=stmt.excluded, + where=table.c.bar != random.choice(["q", "p", "r", "z"]), + ) + + def stmt42(): + stmt = stmt0() + + return stmt.on_conflict_do_update( + constraint=table.primary_key, + set_=stmt.excluded, + where=table.c.baz != random.choice(["q", "p", "r", "z"]), + ) + + return lambda: [ + stmt0(), + stmt1(), + stmt2(), + stmt21(), + stmt22(), + stmt23(), + stmt24(), + stmt3(), + stmt31(), + stmt4(), + stmt41(), + stmt42(), + ] + @testing.combinations("control", "excluded", "dict") def test_set_excluded(self, scenario): """test #8014, sending all of .excluded to set""" @@ -2832,6 +2941,34 @@ class InsertOnConflictTest(fixtures.TablesTest, AssertsCompiledSQL): "SET id = excluded.id, name = excluded.name", ) + def test_dont_consume_set_collection(self): + users = self.tables.users + stmt = insert(users).values( + [ + { + "name": "spongebob", + }, + { + "name": "sandy", + }, + ] + ) + stmt = stmt.on_conflict_do_update( + index_elements=[users.c.name], set_=dict(name=stmt.excluded.name) + ) + self.assert_compile( + stmt, + "INSERT INTO users (name) VALUES (%(name_m0)s), (%(name_m1)s) " + "ON CONFLICT (name) DO UPDATE SET name = excluded.name", + ) + stmt = stmt.returning(users) + self.assert_compile( + stmt, + "INSERT INTO users (name) VALUES (%(name_m0)s), (%(name_m1)s) " + "ON CONFLICT (name) DO UPDATE SET name = excluded.name " + "RETURNING users.id, users.name", + ) + def test_on_conflict_do_no_call_twice(self): users = self.table1 diff --git a/test/dialect/test_sqlite.py b/test/dialect/test_sqlite.py index ecb9510c93..c5b4f62e29 100644 --- a/test/dialect/test_sqlite.py +++ b/test/dialect/test_sqlite.py @@ -3,6 +3,7 @@ import datetime import json import os +import random from sqlalchemy import and_ from sqlalchemy import bindparam @@ -2952,7 +2953,9 @@ class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL): ) -class OnConflictCompileTest(AssertsCompiledSQL, fixtures.TestBase): +class OnConflictCompileTest( + AssertsCompiledSQL, fixtures.CacheKeySuite, fixtures.TestBase +): __dialect__ = "sqlite" @testing.combinations( @@ -3012,6 +3015,83 @@ class OnConflictCompileTest(AssertsCompiledSQL, fixtures.TestBase): f"INSERT INTO users (id, name) VALUES (?, ?) {expected}", ) + @fixtures.CacheKeySuite.run_suite_tests + def test_insert_on_conflict_cache_key(self): + table = Table( + "foos", + MetaData(), + Column("id", Integer, primary_key=True), + Column("bar", String(10)), + Column("baz", String(10)), + ) + Index("foo_idx", table.c.id) + + def stmt0(): + # note a multivalues INSERT is not cacheable; use just one + # set of values + return insert(table).values( + {"id": 1, "bar": "ab"}, + ) + + def stmt1(): + stmt = stmt0() + return stmt.on_conflict_do_nothing() + + def stmt2(): + stmt = stmt0() + return stmt.on_conflict_do_nothing(index_elements=["id"]) + + def stmt21(): + stmt = stmt0() + return stmt.on_conflict_do_nothing(index_elements=[table.c.id]) + + def stmt22(): + stmt = stmt0() + return stmt.on_conflict_do_nothing( + index_elements=["id", table.c.bar] + ) + + def stmt23(): + stmt = stmt0() + return stmt.on_conflict_do_nothing(index_elements=["id", "bar"]) + + def stmt24(): + stmt = insert(table).values( + {"id": 1, "bar": "ab", "baz": "xy"}, + ) + return stmt.on_conflict_do_nothing(index_elements=["id", "bar"]) + + def stmt3(): + stmt = stmt0() + return stmt.on_conflict_do_update( + index_elements=["id"], + set_={ + "bar": random.choice(["a", "b", "c"]), + "baz": random.choice(["d", "e", "f"]), + }, + ) + + def stmt31(): + stmt = stmt0() + return stmt.on_conflict_do_update( + index_elements=["id"], + set_={ + "baz": random.choice(["d", "e", "f"]), + }, + ) + + return lambda: [ + stmt0(), + stmt1(), + stmt2(), + stmt21(), + stmt22(), + stmt23(), + stmt24(), + stmt3(), + stmt31(), + ] + @testing.combinations("control", "excluded", "dict", argnames="scenario") def test_set_excluded(self, scenario, users, users_w_key): """test #8014, sending all of .excluded to set""" @@ -3048,6 +3128,33 @@ class OnConflictCompileTest(AssertsCompiledSQL, fixtures.TestBase): "DO UPDATE SET id = excluded.id, name = excluded.name", ) + def test_dont_consume_set_collection(self, users): + stmt = insert(users).values( + [ + { + "name": "spongebob", + }, + { + "name": "sandy", + }, + ] + ) + stmt = stmt.on_conflict_do_update( + index_elements=[users.c.name], set_=dict(name=stmt.excluded.name) + ) + self.assert_compile( + stmt, + "INSERT INTO users (name) VALUES (?), (?) " + "ON CONFLICT (name) DO UPDATE SET name = excluded.name", + ) + stmt = stmt.returning(users) + self.assert_compile( + stmt, + "INSERT INTO users (name) VALUES (?), (?) " + "ON CONFLICT (name) DO UPDATE SET name = excluded.name " + "RETURNING id, name", + ) + def test_on_conflict_do_update_exotic_targets_six(self, users_xtra): users = users_xtra diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py index d499609b49..8b1869e8d0 100644 --- a/test/sql/test_compare.py +++ b/test/sql/test_compare.py @@ -31,8 +31,6 @@ from sqlalchemy import TypeDecorator from sqlalchemy import union from sqlalchemy import union_all from sqlalchemy import values -from sqlalchemy.dialects import mysql -from sqlalchemy.dialects import postgresql from sqlalchemy.schema import Sequence from sqlalchemy.sql import bindparam from sqlalchemy.sql import ColumnElement @@ -1226,17 +1224,7 @@ class CoreFixtures: class CacheKeyTest(fixtures.CacheKeyFixture, CoreFixtures, fixtures.TestBase): - # we are slightly breaking the policy of not having external dialect - # stuff in here, but use pg/mysql as test cases to ensure that these - # objects don't report an inaccurate cache key, which is dependent - # on the base insert sending out _post_values_clause and the caching - # system properly recognizing these constructs as not cacheable - @testing.combinations( - postgresql.insert(table_a).on_conflict_do_update( - index_elements=[table_a.c.a], set_={"name": "foo"} - ), - mysql.insert(table_a).on_duplicate_key_update(updated_once=None), table_a.insert().values( # multivalues doesn't cache [ {"name": "some name"},