from typing import Optional
from typing import Union
-from ..sql._typing import _DDLColumnArgument
-from ..sql.elements import DQLDMLClauseElement
+from ..sql import roles
+from ..sql.schema import Column
from ..sql.schema import ColumnCollectionConstraint
from ..sql.schema import Index
_OnConflictConstraintT = Union[str, ColumnCollectionConstraint, Index, None]
-_OnConflictIndexElementsT = Optional[Iterable[_DDLColumnArgument]]
-_OnConflictIndexWhereT = Optional[DQLDMLClauseElement]
+_OnConflictIndexElementsT = Optional[
+ Iterable[Union[Column[Any], str, roles.DDLConstraintColumnRole]]
+]
+_OnConflictIndexWhereT = Optional[roles.WhereHavingRole]
_OnConflictSetT = Optional[Mapping[Any, Any]]
-_OnConflictWhereT = Union[DQLDMLClauseElement, str, None]
+_OnConflictWhereT = Optional[roles.WhereHavingRole]
for column in (col for col in cols if col.key in on_duplicate.update):
val = on_duplicate.update[column.key]
+ # TODO: this coercion should be up front. we can't cache
+ # SQL constructs with non-bound literals buried in them
if coercions._is_literal(val):
val = elements.BindParameter(None, val, type_=column.type)
value_text = self.process(val.self_group(), use_schema=False)
from __future__ import annotations
from typing import Any
+from typing import Dict
from typing import List
from typing import Mapping
from typing import Optional
_parameter_ordering: Optional[List[str]] = None
+ update: Dict[str, Any]
stringify_dialect = "mysql"
def __init__(
else:
continue
+ # TODO: this coercion should be up front. we can't cache
+ # SQL constructs with non-bound literals buried in them
if coercions._is_literal(value):
value = elements.BindParameter(None, value, type_=c.type)
from __future__ import annotations
from typing import Any
+from typing import List
from typing import Optional
+from typing import Tuple
+from typing import Union
from . import ext
from .._typing import _OnConflictConstraintT
from ...sql.base import ReadOnlyColumnCollection
from ...sql.dml import Insert as StandardInsert
from ...sql.elements import ClauseElement
+from ...sql.elements import ColumnElement
from ...sql.elements import KeyedColumnElement
+from ...sql.elements import TextClause
from ...sql.expression import alias
from ...util.typing import Self
:paramref:`.Insert.on_conflict_do_update.set_` dictionary.
:param where:
- Optional argument. If present, can be a literal SQL
- string or an acceptable expression for a ``WHERE`` clause
- that restricts the rows affected by ``DO UPDATE SET``. Rows
- not meeting the ``WHERE`` condition will not be updated
- (effectively a ``DO NOTHING`` for those rows).
+ Optional argument. An expression object representing a ``WHERE``
+ clause that restricts the rows affected by ``DO UPDATE SET``. Rows not
+ meeting the ``WHERE`` condition will not be updated (effectively a
+ ``DO NOTHING`` for those rows).
.. seealso::
stringify_dialect = "postgresql"
constraint_target: Optional[str]
- inferred_target_elements: _OnConflictIndexElementsT
- inferred_target_whereclause: _OnConflictIndexWhereT
+ inferred_target_elements: Optional[List[Union[str, schema.Column[Any]]]]
+ inferred_target_whereclause: Optional[
+ Union[ColumnElement[Any], TextClause]
+ ]
def __init__(
self,
if index_elements is not None:
self.constraint_target = None
- self.inferred_target_elements = index_elements
- self.inferred_target_whereclause = index_where
+ self.inferred_target_elements = [
+ coercions.expect(roles.DDLConstraintColumnRole, column)
+ for column in index_elements
+ ]
+
+ self.inferred_target_whereclause = (
+ coercions.expect(
+ (
+ roles.StatementOptionRole
+ if isinstance(constraint, ext.ExcludeConstraint)
+ else roles.WhereHavingRole
+ ),
+ index_where,
+ )
+ if index_where is not None
+ else None
+ )
+
elif constraint is None:
self.constraint_target = self.inferred_target_elements = (
self.inferred_target_whereclause
class OnConflictDoUpdate(OnConflictClause):
__visit_name__ = "on_conflict_do_update"
+ update_values_to_set: List[Tuple[Union[schema.Column[Any], str], Any]]
+ update_whereclause: Optional[ColumnElement[Any]]
+
def __init__(
self,
constraint: _OnConflictConstraintT = None,
(coercions.expect(roles.DMLColumnRole, key), value)
for key, value in set_.items()
]
- self.update_whereclause = where
+ self.update_whereclause = (
+ coercions.expect(roles.WhereHavingRole, where)
+ if where is not None
+ else None
+ )
return self._generate_generic_binary(binary, " NOT REGEXP ", **kw)
def _on_conflict_target(self, clause, **kw):
- if clause.constraint_target is not None:
- target_text = "(%s)" % clause.constraint_target
- elif clause.inferred_target_elements is not None:
+ if clause.inferred_target_elements is not None:
target_text = "(%s)" % ", ".join(
(
self.preparer.quote(c)
clause.inferred_target_whereclause,
include_table=False,
use_schema=False,
- literal_binds=True,
+ literal_execute=True,
)
else:
from __future__ import annotations
from typing import Any
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
from .._typing import _OnConflictIndexElementsT
from .._typing import _OnConflictIndexWhereT
from ... import util
from ...sql import coercions
from ...sql import roles
+from ...sql import schema
from ...sql._typing import _DMLTableArgument
from ...sql.base import _exclusive_against
from ...sql.base import _generative
from ...sql.base import ReadOnlyColumnCollection
from ...sql.dml import Insert as StandardInsert
from ...sql.elements import ClauseElement
+from ...sql.elements import ColumnElement
from ...sql.elements import KeyedColumnElement
+from ...sql.elements import TextClause
from ...sql.expression import alias
from ...util.typing import Self
:paramref:`.Insert.on_conflict_do_update.set_` dictionary.
:param where:
- Optional argument. If present, can be a literal SQL
- string or an acceptable expression for a ``WHERE`` clause
- that restricts the rows affected by ``DO UPDATE SET``. Rows
- not meeting the ``WHERE`` condition will not be updated
- (effectively a ``DO NOTHING`` for those rows).
+ Optional argument. An expression object representing a ``WHERE``
+ clause that restricts the rows affected by ``DO UPDATE SET``. Rows not
+ meeting the ``WHERE`` condition will not be updated (effectively a
+ ``DO NOTHING`` for those rows).
"""
class OnConflictClause(ClauseElement):
stringify_dialect = "sqlite"
- constraint_target: None
- inferred_target_elements: _OnConflictIndexElementsT
- inferred_target_whereclause: _OnConflictIndexWhereT
+ inferred_target_elements: Optional[List[Union[str, schema.Column[Any]]]]
+ inferred_target_whereclause: Optional[
+ Union[ColumnElement[Any], TextClause]
+ ]
def __init__(
self,
index_where: _OnConflictIndexWhereT = None,
):
if index_elements is not None:
- self.constraint_target = None
- self.inferred_target_elements = index_elements
- self.inferred_target_whereclause = index_where
+ self.inferred_target_elements = [
+ coercions.expect(roles.DDLConstraintColumnRole, column)
+ for column in index_elements
+ ]
+ self.inferred_target_whereclause = (
+ coercions.expect(
+ roles.WhereHavingRole,
+ index_where,
+ )
+ if index_where is not None
+ else None
+ )
else:
- self.constraint_target = self.inferred_target_elements = (
+ self.inferred_target_elements = (
self.inferred_target_whereclause
) = None
class OnConflictDoUpdate(OnConflictClause):
__visit_name__ = "on_conflict_do_update"
+ update_values_to_set: List[Tuple[Union[schema.Column[Any], str], Any]]
+ update_whereclause: Optional[ColumnElement[Any]]
+
def __init__(
self,
index_elements: _OnConflictIndexElementsT = None,
(coercions.expect(roles.DMLColumnRole, key), value)
for key, value in set_.items()
]
- self.update_whereclause = where
+ self.update_whereclause = (
+ coercions.expect(roles.WhereHavingRole, where)
+ if where is not None
+ else None
+ )
from .elements import ClauseElement
from .elements import ColumnClause
from .elements import ColumnElement
- from .elements import DQLDMLClauseElement
from .elements import NamedColumn
from .elements import SQLCoreOperations
+ from .elements import TextClause
from .schema import Column
from .selectable import _ColumnsClauseElement
from .selectable import _JoinTargetProtocol
role: Type[roles.DDLReferredColumnRole],
element: Any,
**kw: Any,
-) -> Column[Any]: ...
+) -> Union[Column[Any], str]: ...
@overload
role: Type[roles.StatementOptionRole],
element: Any,
**kw: Any,
-) -> DQLDMLClauseElement: ...
+) -> Union[ColumnElement[Any], TextClause]: ...
@overload
] = _gather_expressions
if processed_expressions is not None:
+
+ # this is expected to be an empty list
+ assert not processed_expressions
+
self._pending_colargs = []
for (
expr,
from sqlalchemy.sql import table
from sqlalchemy.sql import util as sql_util
from sqlalchemy.sql.functions import GenericFunction
+from sqlalchemy.testing import expect_raises
from sqlalchemy.testing import expect_raises_message
from sqlalchemy.testing import fixtures
from sqlalchemy.testing.assertions import assert_raises
(cls.table_with_metadata.c.description, "&&"),
where=cls.table_with_metadata.c.description != "foo",
)
+ cls.excl_constr_anon_str = ExcludeConstraint(
+ (cls.table_with_metadata.c.name, "="),
+ (cls.table_with_metadata.c.description, "&&"),
+ where="description != 'foo'",
+ )
cls.goofy_index = Index(
"goofy_index", table1.c.name, postgresql_where=table1.c.name > "m"
)
Column("name", String(50), key="name_keyed"),
)
+ @testing.combinations(
+ (
+ lambda users, stmt: stmt.on_conflict_do_nothing(
+ index_elements=["id"], index_where=text("name = 'hi'")
+ ),
+ "ON CONFLICT (id) WHERE name = 'hi' DO NOTHING",
+ ),
+ (
+ lambda users, stmt: stmt.on_conflict_do_nothing(
+ index_elements=[users.c.id], index_where=users.c.name == "hi"
+ ),
+ "ON CONFLICT (id) WHERE name = %(name_1)s DO NOTHING",
+ ),
+ (
+ lambda users, stmt: stmt.on_conflict_do_nothing(
+ index_elements=["id"], index_where="name = 'hi'"
+ ),
+ exc.ArgumentError,
+ ),
+ (
+ lambda users, stmt: stmt.on_conflict_do_update(
+ index_elements=[users.c.id],
+ set_={users.c.name: "there"},
+ where=users.c.name == "hi",
+ ),
+ "ON CONFLICT (id) DO UPDATE SET name = %(param_1)s "
+ "WHERE users.name = %(name_1)s",
+ ),
+ (
+ lambda users, stmt: stmt.on_conflict_do_update(
+ index_elements=[users.c.id],
+ set_={users.c.name: "there"},
+ where=text("name = 'hi'"),
+ ),
+ "ON CONFLICT (id) DO UPDATE SET name = %(param_1)s "
+ "WHERE name = 'hi'",
+ ),
+ (
+ lambda users, stmt: stmt.on_conflict_do_update(
+ index_elements=[users.c.id],
+ set_={users.c.name: "there"},
+ where="name = 'hi'",
+ ),
+ exc.ArgumentError,
+ ),
+ )
+ def test_assorted_arg_coercion(self, case, expected):
+ stmt = insert(self.tables.users)
+
+ if isinstance(expected, type) and issubclass(expected, Exception):
+ with expect_raises(expected):
+ testing.resolve_lambda(
+ case, stmt=stmt, users=self.tables.users
+ ),
+ else:
+ self.assert_compile(
+ testing.resolve_lambda(
+ case, stmt=stmt, users=self.tables.users
+ ),
+ f"INSERT INTO users (id, name) VALUES (%(id)s, %(name)s) "
+ f"{expected}",
+ )
+
@testing.combinations("control", "excluded", "dict")
def test_set_excluded(self, scenario):
"""test #8014, sending all of .excluded to set"""
"DO UPDATE SET name = excluded.name",
)
+ def test_do_update_unnamed_exclude_constraint_string_target(self):
+ i = insert(self.table1).values(dict(name="foo"))
+ i = i.on_conflict_do_update(
+ constraint=self.excl_constr_anon_str,
+ set_=dict(name=i.excluded.name),
+ )
+ self.assert_compile(
+ i,
+ "INSERT INTO mytable (name) VALUES "
+ "(%(name)s) ON CONFLICT (name, description) "
+ "WHERE description != 'foo' "
+ "DO UPDATE SET name = excluded.name",
+ )
+
def test_do_update_add_whereclause(self):
i = insert(self.table1).values(dict(name="foo"))
i = i.on_conflict_do_update(
"AND mytable.description != %(description_2)s",
)
+ def test_do_update_str_index_where(self):
+ i = insert(self.table1).values(dict(name="foo"))
+ i = i.on_conflict_do_update(
+ constraint=self.excl_constr_anon_str,
+ set_=dict(name=i.excluded.name),
+ where=(
+ (self.table1.c.name != "brah")
+ & (self.table1.c.description != "brah")
+ ),
+ )
+ self.assert_compile(
+ i,
+ "INSERT INTO mytable (name) VALUES "
+ "(%(name)s) ON CONFLICT (name, description) "
+ "WHERE description != 'foo' "
+ "DO UPDATE SET name = excluded.name "
+ "WHERE mytable.name != %(name_1)s "
+ "AND mytable.description != %(description_1)s",
+ )
+
def test_do_update_add_whereclause_references_excluded(self):
i = insert(self.table1).values(dict(name="foo"))
i = i.on_conflict_do_update(
[(43, "nameunique2", "name2@gmail.com", "not")],
)
- def test_on_conflict_do_update_exotic_targets_four_no_pk(self, connection):
+ @testing.variation("string_index_elements", [True, False])
+ def test_on_conflict_do_update_exotic_targets_four_no_pk(
+ self, connection, string_index_elements
+ ):
users = self.tables.users_xtra
self._exotic_targets_fixture(connection)
# upsert on target login_email, not id
i = insert(users)
i = i.on_conflict_do_update(
- index_elements=[users.c.login_email],
+ index_elements=(
+ ["login_email"]
+ if string_index_elements
+ else [users.c.login_email]
+ ),
set_=dict(
id=i.excluded.id,
name=i.excluded.name,
)
-class OnConflictTest(AssertsCompiledSQL, fixtures.TablesTest):
+class OnConflictCompileTest(AssertsCompiledSQL):
+ __dialect__ = "sqlite"
+
+ @testing.combinations(
+ (
+ lambda users, stmt: stmt.on_conflict_do_nothing(
+ index_elements=["id"], index_where=text("name = 'hi'")
+ ),
+ "ON CONFLICT (id) WHERE name = 'hi' DO NOTHING",
+ ),
+ (
+ lambda users, stmt: stmt.on_conflict_do_nothing(
+ index_elements=["id"], index_where="name = 'hi'"
+ ),
+ exc.ArgumentError,
+ ),
+ (
+ lambda users, stmt: stmt.on_conflict_do_nothing(
+ index_elements=[users.c.id], index_where=users.c.name == "hi"
+ ),
+ "ON CONFLICT (id) WHERE name = __[POSTCOMPILE_name_1] DO NOTHING",
+ ),
+ (
+ lambda users, stmt: stmt.on_conflict_do_update(
+ index_elements=[users.c.id],
+ set_={users.c.name: "there"},
+ where=users.c.name == "hi",
+ ),
+ "ON CONFLICT (id) DO UPDATE SET name = ? " "WHERE users.name = ?",
+ ),
+ (
+ lambda users, stmt: stmt.on_conflict_do_update(
+ index_elements=[users.c.id],
+ set_={users.c.name: "there"},
+ where=text("name = 'hi'"),
+ ),
+ "ON CONFLICT (id) DO UPDATE SET name = ? " "WHERE name = 'hi'",
+ ),
+ (
+ lambda users, stmt: stmt.on_conflict_do_update(
+ index_elements=[users.c.id],
+ set_={users.c.name: "there"},
+ where="name = 'hi'",
+ ),
+ exc.ArgumentError,
+ ),
+ argnames="case,expected",
+ )
+ def test_assorted_arg_coercion(self, users, case, expected):
+ stmt = insert(users)
+
+ if isinstance(expected, type) and issubclass(expected, Exception):
+ with expect_raises(expected):
+ testing.resolve_lambda(case, stmt=stmt, users=users),
+ else:
+ self.assert_compile(
+ testing.resolve_lambda(case, stmt=stmt, users=users),
+ f"INSERT INTO users (id, name) VALUES (?, ?) {expected}",
+ )
+
+ @testing.combinations("control", "excluded", "dict")
+ def test_set_excluded(self, scenario, users):
+ """test #8014, sending all of .excluded to set"""
+
+ if scenario == "control":
+
+ stmt = insert(users)
+ self.assert_compile(
+ stmt.on_conflict_do_update(set_=stmt.excluded),
+ "INSERT INTO users (id, name) VALUES (?, ?) ON CONFLICT "
+ "DO UPDATE SET id = excluded.id, name = excluded.name",
+ )
+ else:
+ users_w_key = self.tables.users_w_key
+
+ stmt = insert(users_w_key)
+
+ if scenario == "excluded":
+ self.assert_compile(
+ stmt.on_conflict_do_update(set_=stmt.excluded),
+ "INSERT INTO users_w_key (id, name) VALUES (?, ?) "
+ "ON CONFLICT "
+ "DO UPDATE SET id = excluded.id, name = excluded.name",
+ )
+ else:
+ self.assert_compile(
+ stmt.on_conflict_do_update(
+ set_={
+ "id": stmt.excluded.id,
+ "name_keyed": stmt.excluded.name_keyed,
+ }
+ ),
+ "INSERT INTO users_w_key (id, name) VALUES (?, ?) "
+ "ON CONFLICT "
+ "DO UPDATE SET id = excluded.id, name = excluded.name",
+ )
+
+ def test_on_conflict_do_update_exotic_targets_six(
+ self, connection, users_xtra
+ ):
+ users = users_xtra
+
+ unique_partial_index = schema.Index(
+ "idx_unique_partial_name",
+ users_xtra.c.name,
+ users_xtra.c.lets_index_this,
+ unique=True,
+ sqlite_where=users_xtra.c.lets_index_this == "unique_name",
+ )
+
+ conn = connection
+ conn.execute(
+ insert(users),
+ dict(
+ id=1,
+ name="name1",
+ login_email="mail1@gmail.com",
+ lets_index_this="unique_name",
+ ),
+ )
+ i = insert(users)
+ i = i.on_conflict_do_update(
+ index_elements=unique_partial_index.columns,
+ index_where=unique_partial_index.dialect_options["sqlite"][
+ "where"
+ ],
+ set_=dict(
+ name=i.excluded.name, login_email=i.excluded.login_email
+ ),
+ )
+
+ # this test illustrates that the index_where clause can't use
+ # bound parameters, where we see below a literal_execute parameter is
+ # used (will be sent as literal to the DBAPI). SQLite otherwise
+ # fails here with "(sqlite3.OperationalError) ON CONFLICT clause does
+ # not match any PRIMARY KEY or UNIQUE constraint" if sent as a real
+ # bind parameter.
+ self.assert_compile(
+ i,
+ "INSERT INTO users_xtra (id, name, login_email, lets_index_this) "
+ "VALUES (?, ?, ?, ?) ON CONFLICT (name, lets_index_this) "
+ "WHERE lets_index_this = __[POSTCOMPILE_lets_index_this_1] "
+ "DO UPDATE "
+ "SET name = excluded.name, login_email = excluded.login_email",
+ )
+
+ @testing.fixture
+ def users(self):
+ metadata = MetaData()
+ return Table(
+ "users",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("name", String(50)),
+ )
+
+ @testing.fixture
+ def users_xtra(self):
+ metadata = MetaData()
+ return Table(
+ "users_xtra",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("name", String(50)),
+ Column("login_email", String(50)),
+ Column("lets_index_this", String(50)),
+ )
+
+
+class OnConflictTest(fixtures.TablesTest):
__only_on__ = ("sqlite >= 3.24.0",)
__backend__ = True
)
def test_bad_args(self):
- assert_raises(
- ValueError, insert(self.tables.users).on_conflict_do_update
- )
-
- @testing.combinations("control", "excluded", "dict")
- @testing.skip_if("+pysqlite_numeric")
- @testing.skip_if("+pysqlite_dollar")
- def test_set_excluded(self, scenario):
- """test #8014, sending all of .excluded to set"""
-
- if scenario == "control":
- users = self.tables.users
-
- stmt = insert(users)
- self.assert_compile(
- stmt.on_conflict_do_update(set_=stmt.excluded),
- "INSERT INTO users (id, name) VALUES (?, ?) ON CONFLICT "
- "DO UPDATE SET id = excluded.id, name = excluded.name",
- )
- else:
- users_w_key = self.tables.users_w_key
-
- stmt = insert(users_w_key)
-
- if scenario == "excluded":
- self.assert_compile(
- stmt.on_conflict_do_update(set_=stmt.excluded),
- "INSERT INTO users_w_key (id, name) VALUES (?, ?) "
- "ON CONFLICT "
- "DO UPDATE SET id = excluded.id, name = excluded.name",
- )
- else:
- self.assert_compile(
- stmt.on_conflict_do_update(
- set_={
- "id": stmt.excluded.id,
- "name_keyed": stmt.excluded.name_keyed,
- }
- ),
- "INSERT INTO users_w_key (id, name) VALUES (?, ?) "
- "ON CONFLICT "
- "DO UPDATE SET id = excluded.id, name = excluded.name",
- )
+ with expect_raises(ValueError):
+ insert(self.tables.users).on_conflict_do_update()
def test_on_conflict_do_no_call_twice(self):
users = self.tables.users