From: Mike Bayer Date: Wed, 18 Dec 2024 16:24:58 +0000 (-0500) Subject: harden typing / coercion for on conflict/on duplicate key X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=219bcb3a77edd72ef8fc36c8ded921d6fb9a34a5;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git harden typing / coercion for on conflict/on duplicate key in 2.1 we want these structures to be cacheable, so start by cleaning up types and adding coercions to enforce those types. these will be more locked down in 2.1 as we will need to move bound parameter coercion outside of compilation, but here do some small starts and introduce in 2.0. in one interest of cachability, a "literal_binds" that found its way into SQLite's compiler is replaced with "literal_execute", the difference being that the latter is cacheable. This literal is apparently necessary to suit SQLite's query planner for the "index criteria" portion of the on conflict clause that otherwise can't work with a real bound parameter. Change-Id: I4d66ec1473321616a1707da324a7dfe7a61ec94e --- diff --git a/lib/sqlalchemy/dialects/_typing.py b/lib/sqlalchemy/dialects/_typing.py index 9ee6e4bca1..811e125fd5 100644 --- a/lib/sqlalchemy/dialects/_typing.py +++ b/lib/sqlalchemy/dialects/_typing.py @@ -12,14 +12,16 @@ from typing import Mapping from typing import Optional from typing import Union -from ..sql._typing import _DDLColumnArgument -from ..sql.elements import DQLDMLClauseElement +from ..sql import roles +from ..sql.schema import Column from ..sql.schema import ColumnCollectionConstraint from ..sql.schema import Index _OnConflictConstraintT = Union[str, ColumnCollectionConstraint, Index, None] -_OnConflictIndexElementsT = Optional[Iterable[_DDLColumnArgument]] -_OnConflictIndexWhereT = Optional[DQLDMLClauseElement] +_OnConflictIndexElementsT = Optional[ + Iterable[Union[Column[Any], str, roles.DDLConstraintColumnRole]] +] +_OnConflictIndexWhereT = Optional[roles.WhereHavingRole] _OnConflictSetT = Optional[Mapping[Any, Any]] -_OnConflictWhereT = Union[DQLDMLClauseElement, str, None] +_OnConflictWhereT = Optional[roles.WhereHavingRole] diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 42e80cf273..25d293d533 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1405,6 +1405,8 @@ class MySQLCompiler(compiler.SQLCompiler): for column in (col for col in cols if col.key in on_duplicate.update): val = on_duplicate.update[column.key] + # TODO: this coercion should be up front. we can't cache + # SQL constructs with non-bound literals buried in them if coercions._is_literal(val): val = elements.BindParameter(None, val, type_=column.type) value_text = self.process(val.self_group(), use_schema=False) diff --git a/lib/sqlalchemy/dialects/mysql/dml.py b/lib/sqlalchemy/dialects/mysql/dml.py index d9164317b0..731d1943aa 100644 --- a/lib/sqlalchemy/dialects/mysql/dml.py +++ b/lib/sqlalchemy/dialects/mysql/dml.py @@ -7,6 +7,7 @@ from __future__ import annotations from typing import Any +from typing import Dict from typing import List from typing import Mapping from typing import Optional @@ -185,6 +186,7 @@ class OnDuplicateClause(ClauseElement): _parameter_ordering: Optional[List[str]] = None + update: Dict[str, Any] stringify_dialect = "mysql" def __init__( diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 6b14ace174..b917cfcde7 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -2085,6 +2085,8 @@ class PGCompiler(compiler.SQLCompiler): else: continue + # TODO: this coercion should be up front. we can't cache + # SQL constructs with non-bound literals buried in them if coercions._is_literal(value): value = elements.BindParameter(None, value, type_=c.type) diff --git a/lib/sqlalchemy/dialects/postgresql/dml.py b/lib/sqlalchemy/dialects/postgresql/dml.py index 4404ecd37b..1615506c0b 100644 --- a/lib/sqlalchemy/dialects/postgresql/dml.py +++ b/lib/sqlalchemy/dialects/postgresql/dml.py @@ -7,7 +7,10 @@ from __future__ import annotations from typing import Any +from typing import List from typing import Optional +from typing import Tuple +from typing import Union from . import ext from .._typing import _OnConflictConstraintT @@ -26,7 +29,9 @@ from ...sql.base import ColumnCollection from ...sql.base import ReadOnlyColumnCollection from ...sql.dml import Insert as StandardInsert from ...sql.elements import ClauseElement +from ...sql.elements import ColumnElement from ...sql.elements import KeyedColumnElement +from ...sql.elements import TextClause from ...sql.expression import alias from ...util.typing import Self @@ -153,11 +158,10 @@ class Insert(StandardInsert): :paramref:`.Insert.on_conflict_do_update.set_` dictionary. :param where: - Optional argument. If present, can be a literal SQL - string or an acceptable expression for a ``WHERE`` clause - that restricts the rows affected by ``DO UPDATE SET``. Rows - not meeting the ``WHERE`` condition will not be updated - (effectively a ``DO NOTHING`` for those rows). + Optional argument. An expression object representing a ``WHERE`` + clause that restricts the rows affected by ``DO UPDATE SET``. Rows not + meeting the ``WHERE`` condition will not be updated (effectively a + ``DO NOTHING`` for those rows). .. seealso:: @@ -212,8 +216,10 @@ class OnConflictClause(ClauseElement): stringify_dialect = "postgresql" constraint_target: Optional[str] - inferred_target_elements: _OnConflictIndexElementsT - inferred_target_whereclause: _OnConflictIndexWhereT + inferred_target_elements: Optional[List[Union[str, schema.Column[Any]]]] + inferred_target_whereclause: Optional[ + Union[ColumnElement[Any], TextClause] + ] def __init__( self, @@ -254,8 +260,24 @@ class OnConflictClause(ClauseElement): if index_elements is not None: self.constraint_target = None - self.inferred_target_elements = index_elements - self.inferred_target_whereclause = index_where + self.inferred_target_elements = [ + coercions.expect(roles.DDLConstraintColumnRole, column) + for column in index_elements + ] + + self.inferred_target_whereclause = ( + coercions.expect( + ( + roles.StatementOptionRole + if isinstance(constraint, ext.ExcludeConstraint) + else roles.WhereHavingRole + ), + index_where, + ) + if index_where is not None + else None + ) + elif constraint is None: self.constraint_target = self.inferred_target_elements = ( self.inferred_target_whereclause @@ -269,6 +291,9 @@ class OnConflictDoNothing(OnConflictClause): class OnConflictDoUpdate(OnConflictClause): __visit_name__ = "on_conflict_do_update" + update_values_to_set: List[Tuple[Union[schema.Column[Any], str], Any]] + update_whereclause: Optional[ColumnElement[Any]] + def __init__( self, constraint: _OnConflictConstraintT = None, @@ -307,4 +332,8 @@ class OnConflictDoUpdate(OnConflictClause): (coercions.expect(roles.DMLColumnRole, key), value) for key, value in set_.items() ] - self.update_whereclause = where + self.update_whereclause = ( + coercions.expect(roles.WhereHavingRole, where) + if where is not None + else None + ) diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index 5ae7ffbf0f..51b957cf9a 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -1481,9 +1481,7 @@ class SQLiteCompiler(compiler.SQLCompiler): return self._generate_generic_binary(binary, " NOT REGEXP ", **kw) def _on_conflict_target(self, clause, **kw): - if clause.constraint_target is not None: - target_text = "(%s)" % clause.constraint_target - elif clause.inferred_target_elements is not None: + if clause.inferred_target_elements is not None: target_text = "(%s)" % ", ".join( ( self.preparer.quote(c) @@ -1497,7 +1495,7 @@ class SQLiteCompiler(compiler.SQLCompiler): clause.inferred_target_whereclause, include_table=False, use_schema=False, - literal_binds=True, + literal_execute=True, ) else: diff --git a/lib/sqlalchemy/dialects/sqlite/dml.py b/lib/sqlalchemy/dialects/sqlite/dml.py index dcf5e4482e..163a6ed28b 100644 --- a/lib/sqlalchemy/dialects/sqlite/dml.py +++ b/lib/sqlalchemy/dialects/sqlite/dml.py @@ -7,6 +7,10 @@ from __future__ import annotations from typing import Any +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union from .._typing import _OnConflictIndexElementsT from .._typing import _OnConflictIndexWhereT @@ -15,6 +19,7 @@ from .._typing import _OnConflictWhereT from ... import util from ...sql import coercions from ...sql import roles +from ...sql import schema from ...sql._typing import _DMLTableArgument from ...sql.base import _exclusive_against from ...sql.base import _generative @@ -22,7 +27,9 @@ from ...sql.base import ColumnCollection from ...sql.base import ReadOnlyColumnCollection from ...sql.dml import Insert as StandardInsert from ...sql.elements import ClauseElement +from ...sql.elements import ColumnElement from ...sql.elements import KeyedColumnElement +from ...sql.elements import TextClause from ...sql.expression import alias from ...util.typing import Self @@ -141,11 +148,10 @@ class Insert(StandardInsert): :paramref:`.Insert.on_conflict_do_update.set_` dictionary. :param where: - Optional argument. If present, can be a literal SQL - string or an acceptable expression for a ``WHERE`` clause - that restricts the rows affected by ``DO UPDATE SET``. Rows - not meeting the ``WHERE`` condition will not be updated - (effectively a ``DO NOTHING`` for those rows). + Optional argument. An expression object representing a ``WHERE`` + clause that restricts the rows affected by ``DO UPDATE SET``. Rows not + meeting the ``WHERE`` condition will not be updated (effectively a + ``DO NOTHING`` for those rows). """ @@ -184,9 +190,10 @@ class Insert(StandardInsert): class OnConflictClause(ClauseElement): stringify_dialect = "sqlite" - constraint_target: None - inferred_target_elements: _OnConflictIndexElementsT - inferred_target_whereclause: _OnConflictIndexWhereT + inferred_target_elements: Optional[List[Union[str, schema.Column[Any]]]] + inferred_target_whereclause: Optional[ + Union[ColumnElement[Any], TextClause] + ] def __init__( self, @@ -194,11 +201,20 @@ class OnConflictClause(ClauseElement): index_where: _OnConflictIndexWhereT = None, ): if index_elements is not None: - self.constraint_target = None - self.inferred_target_elements = index_elements - self.inferred_target_whereclause = index_where + self.inferred_target_elements = [ + coercions.expect(roles.DDLConstraintColumnRole, column) + for column in index_elements + ] + self.inferred_target_whereclause = ( + coercions.expect( + roles.WhereHavingRole, + index_where, + ) + if index_where is not None + else None + ) else: - self.constraint_target = self.inferred_target_elements = ( + self.inferred_target_elements = ( self.inferred_target_whereclause ) = None @@ -210,6 +226,9 @@ class OnConflictDoNothing(OnConflictClause): class OnConflictDoUpdate(OnConflictClause): __visit_name__ = "on_conflict_do_update" + update_values_to_set: List[Tuple[Union[schema.Column[Any], str], Any]] + update_whereclause: Optional[ColumnElement[Any]] + def __init__( self, index_elements: _OnConflictIndexElementsT = None, @@ -237,4 +256,8 @@ class OnConflictDoUpdate(OnConflictClause): (coercions.expect(roles.DMLColumnRole, key), value) for key, value in set_.items() ] - self.update_whereclause = where + self.update_whereclause = ( + coercions.expect(roles.WhereHavingRole, where) + if where is not None + else None + ) diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index 63f9f85529..c30258a890 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -57,9 +57,9 @@ if typing.TYPE_CHECKING: from .elements import ClauseElement from .elements import ColumnClause from .elements import ColumnElement - from .elements import DQLDMLClauseElement from .elements import NamedColumn from .elements import SQLCoreOperations + from .elements import TextClause from .schema import Column from .selectable import _ColumnsClauseElement from .selectable import _JoinTargetProtocol @@ -190,7 +190,7 @@ def expect( role: Type[roles.DDLReferredColumnRole], element: Any, **kw: Any, -) -> Column[Any]: ... +) -> Union[Column[Any], str]: ... @overload @@ -206,7 +206,7 @@ def expect( role: Type[roles.StatementOptionRole], element: Any, **kw: Any, -) -> DQLDMLClauseElement: ... +) -> Union[ColumnElement[Any], TextClause]: ... @overload diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index 6539e303fa..de6d37f439 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -4293,6 +4293,10 @@ class ColumnCollectionMixin: ] = _gather_expressions if processed_expressions is not None: + + # this is expected to be an empty list + assert not processed_expressions + self._pending_colargs = [] for ( expr, diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py index bb2dc653f8..f02b42c0b2 100644 --- a/test/dialect/postgresql/test_compiler.py +++ b/test/dialect/postgresql/test_compiler.py @@ -62,6 +62,7 @@ from sqlalchemy.sql import operators from sqlalchemy.sql import table from sqlalchemy.sql import util as sql_util from sqlalchemy.sql.functions import GenericFunction +from sqlalchemy.testing import expect_raises from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import fixtures from sqlalchemy.testing.assertions import assert_raises @@ -2699,6 +2700,11 @@ class InsertOnConflictTest(fixtures.TablesTest, AssertsCompiledSQL): (cls.table_with_metadata.c.description, "&&"), where=cls.table_with_metadata.c.description != "foo", ) + cls.excl_constr_anon_str = ExcludeConstraint( + (cls.table_with_metadata.c.name, "="), + (cls.table_with_metadata.c.description, "&&"), + where="description != 'foo'", + ) cls.goofy_index = Index( "goofy_index", table1.c.name, postgresql_where=table1.c.name > "m" ) @@ -2717,6 +2723,69 @@ class InsertOnConflictTest(fixtures.TablesTest, AssertsCompiledSQL): Column("name", String(50), key="name_keyed"), ) + @testing.combinations( + ( + lambda users, stmt: stmt.on_conflict_do_nothing( + index_elements=["id"], index_where=text("name = 'hi'") + ), + "ON CONFLICT (id) WHERE name = 'hi' DO NOTHING", + ), + ( + lambda users, stmt: stmt.on_conflict_do_nothing( + index_elements=[users.c.id], index_where=users.c.name == "hi" + ), + "ON CONFLICT (id) WHERE name = %(name_1)s DO NOTHING", + ), + ( + lambda users, stmt: stmt.on_conflict_do_nothing( + index_elements=["id"], index_where="name = 'hi'" + ), + exc.ArgumentError, + ), + ( + lambda users, stmt: stmt.on_conflict_do_update( + index_elements=[users.c.id], + set_={users.c.name: "there"}, + where=users.c.name == "hi", + ), + "ON CONFLICT (id) DO UPDATE SET name = %(param_1)s " + "WHERE users.name = %(name_1)s", + ), + ( + lambda users, stmt: stmt.on_conflict_do_update( + index_elements=[users.c.id], + set_={users.c.name: "there"}, + where=text("name = 'hi'"), + ), + "ON CONFLICT (id) DO UPDATE SET name = %(param_1)s " + "WHERE name = 'hi'", + ), + ( + lambda users, stmt: stmt.on_conflict_do_update( + index_elements=[users.c.id], + set_={users.c.name: "there"}, + where="name = 'hi'", + ), + exc.ArgumentError, + ), + ) + def test_assorted_arg_coercion(self, case, expected): + stmt = insert(self.tables.users) + + if isinstance(expected, type) and issubclass(expected, Exception): + with expect_raises(expected): + testing.resolve_lambda( + case, stmt=stmt, users=self.tables.users + ), + else: + self.assert_compile( + testing.resolve_lambda( + case, stmt=stmt, users=self.tables.users + ), + f"INSERT INTO users (id, name) VALUES (%(id)s, %(name)s) " + f"{expected}", + ) + @testing.combinations("control", "excluded", "dict") def test_set_excluded(self, scenario): """test #8014, sending all of .excluded to set""" @@ -3110,6 +3179,20 @@ class InsertOnConflictTest(fixtures.TablesTest, AssertsCompiledSQL): "DO UPDATE SET name = excluded.name", ) + def test_do_update_unnamed_exclude_constraint_string_target(self): + i = insert(self.table1).values(dict(name="foo")) + i = i.on_conflict_do_update( + constraint=self.excl_constr_anon_str, + set_=dict(name=i.excluded.name), + ) + self.assert_compile( + i, + "INSERT INTO mytable (name) VALUES " + "(%(name)s) ON CONFLICT (name, description) " + "WHERE description != 'foo' " + "DO UPDATE SET name = excluded.name", + ) + def test_do_update_add_whereclause(self): i = insert(self.table1).values(dict(name="foo")) i = i.on_conflict_do_update( @@ -3130,6 +3213,26 @@ class InsertOnConflictTest(fixtures.TablesTest, AssertsCompiledSQL): "AND mytable.description != %(description_2)s", ) + def test_do_update_str_index_where(self): + i = insert(self.table1).values(dict(name="foo")) + i = i.on_conflict_do_update( + constraint=self.excl_constr_anon_str, + set_=dict(name=i.excluded.name), + where=( + (self.table1.c.name != "brah") + & (self.table1.c.description != "brah") + ), + ) + self.assert_compile( + i, + "INSERT INTO mytable (name) VALUES " + "(%(name)s) ON CONFLICT (name, description) " + "WHERE description != 'foo' " + "DO UPDATE SET name = excluded.name " + "WHERE mytable.name != %(name_1)s " + "AND mytable.description != %(description_1)s", + ) + def test_do_update_add_whereclause_references_excluded(self): i = insert(self.table1).values(dict(name="foo")) i = i.on_conflict_do_update( diff --git a/test/dialect/postgresql/test_on_conflict.py b/test/dialect/postgresql/test_on_conflict.py index a9320f2c50..691f6c3962 100644 --- a/test/dialect/postgresql/test_on_conflict.py +++ b/test/dialect/postgresql/test_on_conflict.py @@ -583,7 +583,10 @@ class OnConflictTest(fixtures.TablesTest): [(43, "nameunique2", "name2@gmail.com", "not")], ) - def test_on_conflict_do_update_exotic_targets_four_no_pk(self, connection): + @testing.variation("string_index_elements", [True, False]) + def test_on_conflict_do_update_exotic_targets_four_no_pk( + self, connection, string_index_elements + ): users = self.tables.users_xtra self._exotic_targets_fixture(connection) @@ -591,7 +594,11 @@ class OnConflictTest(fixtures.TablesTest): # upsert on target login_email, not id i = insert(users) i = i.on_conflict_do_update( - index_elements=[users.c.login_email], + index_elements=( + ["login_email"] + if string_index_elements + else [users.c.login_email] + ), set_=dict( id=i.excluded.id, name=i.excluded.name, diff --git a/test/dialect/test_sqlite.py b/test/dialect/test_sqlite.py index d24a75f67d..5f483214b6 100644 --- a/test/dialect/test_sqlite.py +++ b/test/dialect/test_sqlite.py @@ -2938,7 +2938,176 @@ class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL): ) -class OnConflictTest(AssertsCompiledSQL, fixtures.TablesTest): +class OnConflictCompileTest(AssertsCompiledSQL): + __dialect__ = "sqlite" + + @testing.combinations( + ( + lambda users, stmt: stmt.on_conflict_do_nothing( + index_elements=["id"], index_where=text("name = 'hi'") + ), + "ON CONFLICT (id) WHERE name = 'hi' DO NOTHING", + ), + ( + lambda users, stmt: stmt.on_conflict_do_nothing( + index_elements=["id"], index_where="name = 'hi'" + ), + exc.ArgumentError, + ), + ( + lambda users, stmt: stmt.on_conflict_do_nothing( + index_elements=[users.c.id], index_where=users.c.name == "hi" + ), + "ON CONFLICT (id) WHERE name = __[POSTCOMPILE_name_1] DO NOTHING", + ), + ( + lambda users, stmt: stmt.on_conflict_do_update( + index_elements=[users.c.id], + set_={users.c.name: "there"}, + where=users.c.name == "hi", + ), + "ON CONFLICT (id) DO UPDATE SET name = ? " "WHERE users.name = ?", + ), + ( + lambda users, stmt: stmt.on_conflict_do_update( + index_elements=[users.c.id], + set_={users.c.name: "there"}, + where=text("name = 'hi'"), + ), + "ON CONFLICT (id) DO UPDATE SET name = ? " "WHERE name = 'hi'", + ), + ( + lambda users, stmt: stmt.on_conflict_do_update( + index_elements=[users.c.id], + set_={users.c.name: "there"}, + where="name = 'hi'", + ), + exc.ArgumentError, + ), + argnames="case,expected", + ) + def test_assorted_arg_coercion(self, users, case, expected): + stmt = insert(users) + + if isinstance(expected, type) and issubclass(expected, Exception): + with expect_raises(expected): + testing.resolve_lambda(case, stmt=stmt, users=users), + else: + self.assert_compile( + testing.resolve_lambda(case, stmt=stmt, users=users), + f"INSERT INTO users (id, name) VALUES (?, ?) {expected}", + ) + + @testing.combinations("control", "excluded", "dict") + def test_set_excluded(self, scenario, users): + """test #8014, sending all of .excluded to set""" + + if scenario == "control": + + stmt = insert(users) + self.assert_compile( + stmt.on_conflict_do_update(set_=stmt.excluded), + "INSERT INTO users (id, name) VALUES (?, ?) ON CONFLICT " + "DO UPDATE SET id = excluded.id, name = excluded.name", + ) + else: + users_w_key = self.tables.users_w_key + + stmt = insert(users_w_key) + + if scenario == "excluded": + self.assert_compile( + stmt.on_conflict_do_update(set_=stmt.excluded), + "INSERT INTO users_w_key (id, name) VALUES (?, ?) " + "ON CONFLICT " + "DO UPDATE SET id = excluded.id, name = excluded.name", + ) + else: + self.assert_compile( + stmt.on_conflict_do_update( + set_={ + "id": stmt.excluded.id, + "name_keyed": stmt.excluded.name_keyed, + } + ), + "INSERT INTO users_w_key (id, name) VALUES (?, ?) " + "ON CONFLICT " + "DO UPDATE SET id = excluded.id, name = excluded.name", + ) + + def test_on_conflict_do_update_exotic_targets_six( + self, connection, users_xtra + ): + users = users_xtra + + unique_partial_index = schema.Index( + "idx_unique_partial_name", + users_xtra.c.name, + users_xtra.c.lets_index_this, + unique=True, + sqlite_where=users_xtra.c.lets_index_this == "unique_name", + ) + + conn = connection + conn.execute( + insert(users), + dict( + id=1, + name="name1", + login_email="mail1@gmail.com", + lets_index_this="unique_name", + ), + ) + i = insert(users) + i = i.on_conflict_do_update( + index_elements=unique_partial_index.columns, + index_where=unique_partial_index.dialect_options["sqlite"][ + "where" + ], + set_=dict( + name=i.excluded.name, login_email=i.excluded.login_email + ), + ) + + # this test illustrates that the index_where clause can't use + # bound parameters, where we see below a literal_execute parameter is + # used (will be sent as literal to the DBAPI). SQLite otherwise + # fails here with "(sqlite3.OperationalError) ON CONFLICT clause does + # not match any PRIMARY KEY or UNIQUE constraint" if sent as a real + # bind parameter. + self.assert_compile( + i, + "INSERT INTO users_xtra (id, name, login_email, lets_index_this) " + "VALUES (?, ?, ?, ?) ON CONFLICT (name, lets_index_this) " + "WHERE lets_index_this = __[POSTCOMPILE_lets_index_this_1] " + "DO UPDATE " + "SET name = excluded.name, login_email = excluded.login_email", + ) + + @testing.fixture + def users(self): + metadata = MetaData() + return Table( + "users", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(50)), + ) + + @testing.fixture + def users_xtra(self): + metadata = MetaData() + return Table( + "users_xtra", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(50)), + Column("login_email", String(50)), + Column("lets_index_this", String(50)), + ) + + +class OnConflictTest(fixtures.TablesTest): __only_on__ = ("sqlite >= 3.24.0",) __backend__ = True @@ -2998,49 +3167,8 @@ class OnConflictTest(AssertsCompiledSQL, fixtures.TablesTest): ) def test_bad_args(self): - assert_raises( - ValueError, insert(self.tables.users).on_conflict_do_update - ) - - @testing.combinations("control", "excluded", "dict") - @testing.skip_if("+pysqlite_numeric") - @testing.skip_if("+pysqlite_dollar") - def test_set_excluded(self, scenario): - """test #8014, sending all of .excluded to set""" - - if scenario == "control": - users = self.tables.users - - stmt = insert(users) - self.assert_compile( - stmt.on_conflict_do_update(set_=stmt.excluded), - "INSERT INTO users (id, name) VALUES (?, ?) ON CONFLICT " - "DO UPDATE SET id = excluded.id, name = excluded.name", - ) - else: - users_w_key = self.tables.users_w_key - - stmt = insert(users_w_key) - - if scenario == "excluded": - self.assert_compile( - stmt.on_conflict_do_update(set_=stmt.excluded), - "INSERT INTO users_w_key (id, name) VALUES (?, ?) " - "ON CONFLICT " - "DO UPDATE SET id = excluded.id, name = excluded.name", - ) - else: - self.assert_compile( - stmt.on_conflict_do_update( - set_={ - "id": stmt.excluded.id, - "name_keyed": stmt.excluded.name_keyed, - } - ), - "INSERT INTO users_w_key (id, name) VALUES (?, ?) " - "ON CONFLICT " - "DO UPDATE SET id = excluded.id, name = excluded.name", - ) + with expect_raises(ValueError): + insert(self.tables.users).on_conflict_do_update() def test_on_conflict_do_no_call_twice(self): users = self.tables.users