]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Implement upsert for SQLite
authorRamonWill <ramonwilliams@hotmail.co.uk>
Mon, 14 Sep 2020 22:22:34 +0000 (18:22 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 8 Nov 2020 18:34:24 +0000 (13:34 -0500)
Implemented INSERT... ON CONFLICT clause for SQLite. Pull request courtesy
Ramon Williams.

Fixes: #4010
Closes: #5580
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/5580
Pull-request-sha: fb422e0749fac442a455cbce539ef662d9512bc0

Change-Id: Ibeea44f4c2cee8dab5dc22b7ec3ae1ab95c12b65

doc/build/changelog/unreleased_14/4010.rst [new file with mode: 0644]
doc/build/dialects/sqlite.rst
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/dialects/sqlite/__init__.py
lib/sqlalchemy/dialects/sqlite/base.py
lib/sqlalchemy/dialects/sqlite/dml.py [new file with mode: 0644]
lib/sqlalchemy/sql/dml.py
test/dialect/test_sqlite.py

diff --git a/doc/build/changelog/unreleased_14/4010.rst b/doc/build/changelog/unreleased_14/4010.rst
new file mode 100644 (file)
index 0000000..a2e5ce6
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+    :tags: sqlite, usecase
+    :tickets: 4010
+
+    Implemented INSERT... ON CONFLICT clause for SQLite. Pull request courtesy
+    Ramon Williams.
+
+    .. seealso::
+
+        :ref:`sqlite_on_conflict_insert`
index 85a4bab4c9bdadda6be7f3b2cf1dd0df68339ede..0c04ce3f5138c64dc699de857e7aea48b1ce2b37 100644 (file)
@@ -27,6 +27,14 @@ they originate from :mod:`sqlalchemy.types` or from the local dialect::
 
 .. autoclass:: TIME
 
+SQLite DML Constructs
+-------------------------
+
+.. autofunction:: sqlalchemy.dialects.sqlite.insert
+
+.. autoclass:: sqlalchemy.dialects.sqlite.Insert
+  :members:
+
 Pysqlite
 --------
 
index 3ad0e3813263d0bdf5c4f0ea408384bcf1d3e92d..27c7b2239bedbb4f685a2587f3b0eb18d28216b6 100644 (file)
@@ -482,7 +482,7 @@ an error or to skip performing an UPDATE.
 existing row, using any combination of new values as well as values
 from the proposed insertion.   These values are normally specified using
 keyword arguments passed to the
-:meth:`~.mysql.Insert.on_duplicate_key_update`
+:meth:`_mysql.Insert.on_duplicate_key_update`
 given column key values (usually the name of the column, unless it
 specifies :paramref:`_schema.Column.key`
 ) as keys and literal or SQL expressions
@@ -537,7 +537,7 @@ this context is unambiguous:
 
 .. warning::
 
-    The :meth:`_expression.Insert.on_duplicate_key_update`
+    The :meth:`_mysql.Insert.on_duplicate_key_update`
     method does **not** take into
     account Python-side default UPDATE values or generation functions, e.g.
     e.g. those specified using :paramref:`_schema.Column.onupdate`.
@@ -547,8 +547,8 @@ this context is unambiguous:
 
 
 In order to refer to the proposed insertion row, the special alias
-:attr:`~.mysql.Insert.inserted` is available as an attribute on
-the :class:`.mysql.Insert` object; this object is a
+:attr:`_mysql.Insert.inserted` is available as an attribute on
+the :class:`_mysql.Insert` object; this object is a
 :class:`_expression.ColumnCollection` which contains all columns of the target
 table:
 
index c7467f5ba540b5813da9d9e101ec64ad8def4474..bf257ab3fd422747bc5c4ff2c5f1e5a8d49b5c53 100644 (file)
@@ -409,12 +409,12 @@ this row.
 
 Conflicts are determined using existing unique constraints and indexes.  These
 constraints may be identified either using their name as stated in DDL,
-or they may be *inferred* by stating the columns and conditions that comprise
+or they may be inferred by stating the columns and conditions that comprise
 the indexes.
 
 SQLAlchemy provides ``ON CONFLICT`` support via the PostgreSQL-specific
 :func:`_postgresql.insert()` function, which provides
-the generative methods :meth:`~.postgresql.Insert.on_conflict_do_update`
+the generative methods :meth:`_postgresql.Insert.on_conflict_do_update`
 and :meth:`~.postgresql.Insert.on_conflict_do_nothing`:
 
 .. sourcecode:: pycon+sql
@@ -439,10 +439,21 @@ and :meth:`~.postgresql.Insert.on_conflict_do_nothing`:
     {opensql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s)
     ON CONFLICT ON CONSTRAINT pk_my_table DO UPDATE SET data = %(param_1)s
 
+.. versionadded:: 1.1
+
+.. seealso::
+
+    `INSERT .. ON CONFLICT
+    <http://www.postgresql.org/docs/current/static/sql-insert.html#SQL-ON-CONFLICT>`_
+    - in the PostgreSQL documentation.
+
+Specifying the Target
+^^^^^^^^^^^^^^^^^^^^^
+
 Both methods supply the "target" of the conflict using either the
 named constraint or by column inference:
 
-* The :paramref:`.Insert.on_conflict_do_update.index_elements` argument
+* The :paramref:`_postgresql.Insert.on_conflict_do_update.index_elements` argument
   specifies a sequence containing string column names, :class:`_schema.Column`
   objects, and/or SQL expression elements, which would identify a unique
   index:
@@ -466,9 +477,9 @@ named constraint or by column inference:
     {opensql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s)
     ON CONFLICT (id) DO UPDATE SET data = %(param_1)s
 
-* When using :paramref:`.Insert.on_conflict_do_update.index_elements` to
+* When using :paramref:`_postgresql.Insert.on_conflict_do_update.index_elements` to
   infer an index, a partial index can be inferred by also specifying the
-  use the :paramref:`.Insert.on_conflict_do_update.index_where` parameter:
+  use the :paramref:`_postgresql.Insert.on_conflict_do_update.index_where` parameter:
 
   .. sourcecode:: pycon+sql
 
@@ -483,7 +494,7 @@ named constraint or by column inference:
     VALUES (%(data)s, %(user_email)s) ON CONFLICT (user_email)
     WHERE user_email LIKE %(user_email_1)s DO UPDATE SET data = excluded.data
 
-* The :paramref:`.Insert.on_conflict_do_update.constraint` argument is
+* The :paramref:`_postgresql.Insert.on_conflict_do_update.constraint` argument is
   used to specify an index directly rather than inferring it.  This can be
   the name of a UNIQUE constraint, a PRIMARY KEY constraint, or an INDEX:
 
@@ -507,7 +518,7 @@ named constraint or by column inference:
     ON CONFLICT ON CONSTRAINT my_table_pk DO UPDATE SET data = %(param_1)s
     {stop}
 
-* The :paramref:`.Insert.on_conflict_do_update.constraint` argument may
+* The :paramref:`_postgresql.Insert.on_conflict_do_update.constraint` argument may
   also refer to a SQLAlchemy construct representing a constraint,
   e.g. :class:`.UniqueConstraint`, :class:`.PrimaryKeyConstraint`,
   :class:`.Index`, or :class:`.ExcludeConstraint`.   In this use,
@@ -529,10 +540,13 @@ named constraint or by column inference:
     {opensql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s)
     ON CONFLICT (id) DO UPDATE SET data = %(param_1)s
 
+The SET Clause
+^^^^^^^^^^^^^^^
+
 ``ON CONFLICT...DO UPDATE`` is used to perform an update of the already
 existing row, using any combination of new values as well as values
 from the proposed insertion.   These values are specified using the
-:paramref:`.Insert.on_conflict_do_update.set_` parameter.  This
+:paramref:`_postgresql.Insert.on_conflict_do_update.set_` parameter.  This
 parameter accepts a dictionary which consists of direct values
 for UPDATE:
 
@@ -555,7 +569,10 @@ for UPDATE:
     those specified using :paramref:`_schema.Column.onupdate`.
     These values will not be exercised for an ON CONFLICT style of UPDATE,
     unless they are manually specified in the
-    :paramref:`.Insert.on_conflict_do_update.set_` dictionary.
+    :paramref:`_postgresql.Insert.on_conflict_do_update.set_` dictionary.
+
+Updating using the Excluded INSERT Values
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
 In order to refer to the proposed insertion row, the special alias
 :attr:`~.postgresql.Insert.excluded` is available as an attribute on
@@ -580,8 +597,11 @@ table:
     VALUES (%(id)s, %(data)s, %(author)s)
     ON CONFLICT (id) DO UPDATE SET data = %(param_1)s, author = excluded.author
 
+Additional WHERE Criteria
+^^^^^^^^^^^^^^^^^^^^^^^^^
+
 The :meth:`_expression.Insert.on_conflict_do_update` method also accepts
-a WHERE clause using the :paramref:`.Insert.on_conflict_do_update.where`
+a WHERE clause using the :paramref:`_postgresql.Insert.on_conflict_do_update.where`
 parameter, which will limit those rows which receive an UPDATE:
 
 .. sourcecode:: pycon+sql
@@ -602,7 +622,10 @@ parameter, which will limit those rows which receive an UPDATE:
     ON CONFLICT (id) DO UPDATE SET data = %(param_1)s, author = excluded.author
     WHERE my_table.status = %(status_1)s
 
-``ON CONFLICT`` may also be used to skip inserting a row entirely
+Skipping Rows with DO NOTHING
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+``ON CONFLICT`` may be used to skip inserting a row entirely
 if any conflict with a unique or exclusion constraint occurs; below
 this is illustrated using the
 :meth:`~.postgresql.Insert.on_conflict_do_nothing` method:
@@ -627,14 +650,6 @@ constraint violation which occurs:
     {opensql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s)
     ON CONFLICT DO NOTHING
 
-.. versionadded:: 1.1 Added support for PostgreSQL ON CONFLICT clauses
-
-.. seealso::
-
-    `INSERT .. ON CONFLICT
-    <http://www.postgresql.org/docs/current/static/sql-insert.html#SQL-ON-CONFLICT>`_
-    - in the PostgreSQL documentation.
-
 .. _postgresql_match:
 
 Full Text Search
index 142131f631bddeed30951f5caca9a58d763414ff..72402dd92bab59931899bafa1ab85c07e0e6ef62 100644 (file)
@@ -24,7 +24,8 @@ from .base import TEXT
 from .base import TIME
 from .base import TIMESTAMP
 from .base import VARCHAR
-
+from .dml import Insert
+from .dml import insert
 
 # default dialect
 base.dialect = dialect = pysqlite.dialect
@@ -47,5 +48,7 @@ __all__ = (
     "TIMESTAMP",
     "VARCHAR",
     "REAL",
+    "Insert",
+    "insert",
     "dialect",
 )
index 5efd0d9c991edc3fd7dc32afd1ab24f65da2fd4f..fc08b4b5ef87224a50e2643405ab06505431a1f5 100644 (file)
@@ -301,7 +301,11 @@ new connections through the usage of events::
 ON CONFLICT support for constraints
 -----------------------------------
 
-SQLite supports a non-standard clause known as ON CONFLICT which can be applied
+.. seealso:: This section describes the :term:`DDL` version of "ON CONFLICT" for
+   SQLite, which occurs within a CREATE TABLE statement.  For "ON CONFLICT" as
+   applied to an INSERT statement, see :ref:`sqlite_on_conflict_insert`.
+
+SQLite supports a non-standard DDL clause known as ON CONFLICT which can be applied
 to primary key, unique, check, and not null constraints.   In DDL, it is
 rendered either within the "CONSTRAINT" clause or within the column definition
 itself depending on the location of the target constraint.    To render this
@@ -402,6 +406,208 @@ resolution algorithm is applied to the constraint itself::
         PRIMARY KEY (id) ON CONFLICT FAIL
     )
 
+.. _sqlite_on_conflict_insert:
+
+INSERT...ON CONFLICT (Upsert)
+-----------------------------------
+
+.. seealso:: This section describes the :term:`DML` version of "ON CONFLICT" for
+   SQLite, which occurs within an INSERT statement.  For "ON CONFLICT" as
+   applied to a CREATE TABLE statement, see :ref:`sqlite_on_conflict_ddl`.
+
+From version 3.24.0 onwards, SQLite supports "upserts" (update or insert)
+of rows into a table via the ``ON CONFLICT`` clause of the ``INSERT``
+statement. A candidate row will only be inserted if that row does not violate
+any unique or primary key constraints. In the case of a unique constraint violation, a
+secondary action can occur which can be either “DO UPDATE”, indicating that
+the data in the target row should be updated, or “DO NOTHING”, which indicates
+to silently skip this row.
+
+Conflicts are determined using columns that are part of existing unique
+constraints and indexes.  These constraints are identified by stating the
+columns and conditions that comprise the indexes.
+
+SQLAlchemy provides ``ON CONFLICT`` support via the SQLite-specific
+:func:`_sqlite.insert()` function, which provides
+the generative methods :meth:`_sqlite.Insert.on_conflict_do_update`
+and :meth:`_sqlite.Insert.on_conflict_do_nothing`:
+
+.. sourcecode:: pycon+sql
+
+    >>> from sqlalchemy.dialects.sqlite import insert
+
+    >>> insert_stmt = insert(my_table).values(
+    ...     id='some_existing_id',
+    ...     data='inserted value')
+
+    >>> do_update_stmt = insert_stmt.on_conflict_do_update(
+    ...     index_elements=['id'],
+    ...     set_=dict(data='updated value')
+    ... )
+
+    >>> print(do_update_stmt)
+    {opensql}INSERT INTO my_table (id, data) VALUES (?, ?)
+    ON CONFLICT (id) DO UPDATE SET data = ?{stop}
+
+    >>> do_nothing_stmt = insert_stmt.on_conflict_do_nothing(
+    ...     index_elements=['id']
+    ... )
+
+    >>> print(do_nothing_stmt)
+    {opensql}INSERT INTO my_table (id, data) VALUES (?, ?)
+    ON CONFLICT (id) DO NOTHING
+
+.. versionadded:: 1.4
+
+.. seealso::
+
+    `Upsert
+    <https://sqlite.org/lang_UPSERT.html>`_
+    - in the SQLite documentation.
+
+
+Specifying the Target
+^^^^^^^^^^^^^^^^^^^^^
+
+Both methods supply the “target” of the conflict using column inference:
+
+* The :paramref:`_sqlite.Insert.on_conflict_do_update.index_elements` argument
+  specifies a sequence containing string column names, :class:`_schema.Column`
+  objects, and/or SQL expression elements, which would identify a unique index
+  or unique constraint.
+
+* When using :paramref:`_sqlite.Insert.on_conflict_do_update.index_elements`
+  to infer an index, a partial index can be inferred by also specifying the
+  :paramref:`_sqlite.Insert.on_conflict_do_update.index_where` parameter:
+
+  .. sourcecode:: pycon+sql
+
+        >>> stmt = insert(my_table).values(user_email='a@b.com', data='inserted data')
+
+        >>> do_update_stmt = stmt.on_conflict_do_update(
+        ...     index_elements=[my_table.c.user_email],
+        ...     index_where=my_table.c.user_email.like('%@gmail.com'),
+        ...     set_=dict(data=stmt.excluded.data)
+        ...     )
+
+        >>> print(do_update_stmt)
+        {opensql}INSERT INTO my_table (data, user_email) VALUES (?, ?)
+        ON CONFLICT (user_email)
+        WHERE user_email LIKE '%@gmail.com'
+        DO UPDATE SET data = excluded.data
+        >>>
+
+The SET Clause
+^^^^^^^^^^^^^^^
+
+``ON CONFLICT...DO UPDATE`` is used to perform an update of the already
+existing row, using any combination of new values as well as values
+from the proposed insertion. These values are specified using the
+:paramref:`_sqlite.Insert.on_conflict_do_update.set_` parameter.  This
+parameter accepts a dictionary which consists of direct values
+for UPDATE:
+
+.. sourcecode:: pycon+sql
+
+    >>> stmt = insert(my_table).values(id='some_id', data='inserted value')
+
+    >>> do_update_stmt = stmt.on_conflict_do_update(
+    ...     index_elements=['id'],
+    ...     set_=dict(data='updated value')
+    ... )
+
+    >>> print(do_update_stmt)
+
+    {opensql}INSERT INTO my_table (id, data) VALUES (?, ?)
+    ON CONFLICT (id) DO UPDATE SET data = ?
+
+.. warning::
+
+    The :meth:`_sqlite.Insert.on_conflict_do_update` method does **not** take
+    into account Python-side default UPDATE values or generation functions,
+    e.g. those specified using :paramref:`_schema.Column.onupdate`. These
+    values will not be exercised for an ON CONFLICT style of UPDATE, unless
+    they are manually specified in the
+    :paramref:`_sqlite.Insert.on_conflict_do_update.set_` dictionary.
+
+Updating using the Excluded INSERT Values
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+In order to refer to the proposed insertion row, the special alias
+:attr:`~.sqlite.Insert.excluded` is available as an attribute on
+the :class:`_sqlite.Insert` object; this object creates an "excluded." prefix
+on a column, that informs the DO UPDATE to update the row with the value that
+would have been inserted had the constraint not failed:
+
+.. sourcecode:: pycon+sql
+
+    >>> stmt = insert(my_table).values(
+    ...     id='some_id',
+    ...     data='inserted value',
+    ...     author='jlh'
+    ... )
+
+    >>> do_update_stmt = stmt.on_conflict_do_update(
+    ...     index_elements=['id'],
+    ...     set_=dict(data='updated value', author=stmt.excluded.author)
+    ... )
+
+    >>> print(do_update_stmt)
+    {opensql}INSERT INTO my_table (id, data, author) VALUES (?, ?, ?)
+    ON CONFLICT (id) DO UPDATE SET data = ?, author = excluded.author
+
+Additional WHERE Criteria
+^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The :meth:`_sqlite.Insert.on_conflict_do_update` method also accepts
+a WHERE clause using the :paramref:`_sqlite.Insert.on_conflict_do_update.where`
+parameter, which will limit those rows which receive an UPDATE:
+
+.. sourcecode:: pycon+sql
+
+    >>> stmt = insert(my_table).values(
+    ...     id='some_id',
+    ...     data='inserted value',
+    ...     author='jlh'
+    ... )
+
+    >>> on_update_stmt = stmt.on_conflict_do_update(
+    ...     index_elements=['id'],
+    ...     set_=dict(data='updated value', author=stmt.excluded.author),
+    ...     where=(my_table.c.status == 2)
+    ... )
+    >>> print(on_update_stmt)
+    {opensql}INSERT INTO my_table (id, data, author) VALUES (?, ?, ?)
+    ON CONFLICT (id) DO UPDATE SET data = ?, author = excluded.author
+    WHERE my_table.status = ?
+
+
+Skipping Rows with DO NOTHING
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+``ON CONFLICT`` may be used to skip inserting a row entirely
+if any conflict with a unique constraint occurs; below this is illustrated
+using the :meth:`_sqlite.Insert.on_conflict_do_nothing` method:
+
+.. sourcecode:: pycon+sql
+
+    >>> stmt = insert(my_table).values(id='some_id', data='inserted value')
+    >>> stmt = stmt.on_conflict_do_nothing(index_elements=['id'])
+    >>> print(stmt)
+    {opensql}INSERT INTO my_table (id, data) VALUES (?, ?) ON CONFLICT (id) DO NOTHING
+
+
+If ``DO NOTHING`` is used without specifying any columns or constraint,
+it has the effect of skipping the INSERT for any unique violation which
+occurs:
+
+.. sourcecode:: pycon+sql
+
+    >>> stmt = insert(my_table).values(id='some_id', data='inserted value')
+    >>> stmt = stmt.on_conflict_do_nothing()
+    >>> print(stmt)
+    {opensql}INSERT INTO my_table (id, data) VALUES (?, ?) ON CONFLICT DO NOTHING
+
 .. _sqlite_type_reflection:
 
 Type Reflection
@@ -600,8 +806,11 @@ from ... import types as sqltypes
 from ... import util
 from ...engine import default
 from ...engine import reflection
+from ...sql import coercions
 from ...sql import ColumnElement
 from ...sql import compiler
+from ...sql import elements
+from ...sql import roles
 from ...types import BLOB  # noqa
 from ...types import BOOLEAN  # noqa
 from ...types import CHAR  # noqa
@@ -1083,6 +1292,101 @@ class SQLiteCompiler(compiler.SQLCompiler):
     def visit_not_regexp_match_op_binary(self, binary, operator, **kw):
         return self._generate_generic_binary(binary, " NOT REGEXP ", **kw)
 
+    def _on_conflict_target(self, clause, **kw):
+        if clause.constraint_target is not None:
+            target_text = "(%s)" % clause.constraint_target
+        elif clause.inferred_target_elements is not None:
+            target_text = "(%s)" % ", ".join(
+                (
+                    self.preparer.quote(c)
+                    if isinstance(c, util.string_types)
+                    else self.process(c, include_table=False, use_schema=False)
+                )
+                for c in clause.inferred_target_elements
+            )
+            if clause.inferred_target_whereclause is not None:
+                target_text += " WHERE %s" % self.process(
+                    clause.inferred_target_whereclause,
+                    include_table=False,
+                    use_schema=False,
+                    literal_binds=True,
+                )
+
+        else:
+            target_text = ""
+
+        return target_text
+
+    def visit_on_conflict_do_nothing(self, on_conflict, **kw):
+
+        target_text = self._on_conflict_target(on_conflict, **kw)
+
+        if target_text:
+            return "ON CONFLICT %s DO NOTHING" % target_text
+        else:
+            return "ON CONFLICT DO NOTHING"
+
+    def visit_on_conflict_do_update(self, on_conflict, **kw):
+        clause = on_conflict
+
+        target_text = self._on_conflict_target(on_conflict, **kw)
+
+        action_set_ops = []
+
+        set_parameters = dict(clause.update_values_to_set)
+        # create a list of column assignment clauses as tuples
+
+        insert_statement = self.stack[-1]["selectable"]
+        cols = insert_statement.table.c
+        for c in cols:
+            col_key = c.key
+            if col_key in set_parameters:
+                value = set_parameters.pop(col_key)
+                if coercions._is_literal(value):
+                    value = elements.BindParameter(None, value, type_=c.type)
+
+                else:
+                    if (
+                        isinstance(value, elements.BindParameter)
+                        and value.type._isnull
+                    ):
+                        value = value._clone()
+                        value.type = c.type
+                value_text = self.process(value.self_group(), use_schema=False)
+
+                key_text = self.preparer.quote(col_key)
+                action_set_ops.append("%s = %s" % (key_text, value_text))
+
+        # check for names that don't match columns
+        if set_parameters:
+            util.warn(
+                "Additional column names not matching "
+                "any column keys in table '%s': %s"
+                % (
+                    self.current_executable.table.name,
+                    (", ".join("'%s'" % c for c in set_parameters)),
+                )
+            )
+            for k, v in set_parameters.items():
+                key_text = (
+                    self.preparer.quote(k)
+                    if isinstance(k, util.string_types)
+                    else self.process(k, use_schema=False)
+                )
+                value_text = self.process(
+                    coercions.expect(roles.ExpressionElementRole, v),
+                    use_schema=False,
+                )
+                action_set_ops.append("%s = %s" % (key_text, value_text))
+
+        action_text = ", ".join(action_set_ops)
+        if clause.update_whereclause is not None:
+            action_text += " WHERE %s" % self.process(
+                clause.update_whereclause, include_table=True, use_schema=False
+            )
+
+        return "ON CONFLICT %s DO UPDATE SET %s" % (target_text, action_text)
+
 
 class SQLiteDDLCompiler(compiler.DDLCompiler):
     def get_column_specification(self, column, **kwargs):
diff --git a/lib/sqlalchemy/dialects/sqlite/dml.py b/lib/sqlalchemy/dialects/sqlite/dml.py
new file mode 100644 (file)
index 0000000..a4d4d56
--- /dev/null
@@ -0,0 +1,160 @@
+# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+from ... import util
+from ...sql.base import _generative
+from ...sql.dml import Insert as StandardInsert
+from ...sql.elements import ClauseElement
+from ...sql.expression import alias
+from ...util.langhelpers import public_factory
+
+
+__all__ = ("Insert", "insert")
+
+
+class Insert(StandardInsert):
+    """SQLite-specific implementation of INSERT.
+
+    Adds methods for SQLite-specific syntaxes such as ON CONFLICT.
+
+    The :class:`_sqlite.Insert` object is created using the
+    :func:`sqlalchemy.dialects.sqlite.insert` function.
+
+    .. versionadded:: 1.4
+
+    .. seealso::
+
+        :ref:`sqlite_on_conflict_insert`
+
+    """
+
+    stringify_dialect = "sqlite"
+
+    @util.memoized_property
+    def excluded(self):
+        """Provide the ``excluded`` namespace for an ON CONFLICT statement
+
+        SQLite's ON CONFLICT clause allows reference to the row that would
+        be inserted, known as ``excluded``.  This attribute provides
+        all columns in this row to be referenceable.
+
+        """
+        return alias(self.table, name="excluded").columns
+
+    @_generative
+    def on_conflict_do_update(
+        self,
+        index_elements=None,
+        index_where=None,
+        set_=None,
+        where=None,
+    ):
+        r"""
+        Specifies a DO UPDATE SET action for ON CONFLICT clause.
+
+        :param index_elements:
+         A sequence consisting of string column names, :class:`_schema.Column`
+         objects, or other column expression objects that will be used
+         to infer a target index or unique constraint.
+
+        :param index_where:
+         Additional WHERE criterion that can be used to infer a
+         conditional target index.
+
+        :param set\_:
+         Required argument. A dictionary or other mapping object
+         with column names as keys and expressions or literals as values,
+         specifying the ``SET`` actions to take.
+         If the target :class:`_schema.Column` specifies a ".
+         key" attribute distinct
+         from the column name, that key should be used.
+
+         .. warning:: This dictionary does **not** take into account
+            Python-specified default UPDATE values or generation functions,
+            e.g. those specified using :paramref:`_schema.Column.onupdate`.
+            These values will not be exercised for an ON CONFLICT style of
+            UPDATE, unless they are manually specified in the
+            :paramref:`.Insert.on_conflict_do_update.set_` dictionary.
+
+        :param where:
+         Optional argument. If present, can be a literal SQL
+         string or an acceptable expression for a ``WHERE`` clause
+         that restricts the rows affected by ``DO UPDATE SET``. Rows
+         not meeting the ``WHERE`` condition will not be updated
+         (effectively a ``DO NOTHING`` for those rows).
+
+        """
+
+        self._post_values_clause = OnConflictDoUpdate(
+            index_elements, index_where, set_, where
+        )
+
+    @_generative
+    def on_conflict_do_nothing(self, index_elements=None, index_where=None):
+        """
+        Specifies a DO NOTHING action for ON CONFLICT clause.
+
+        :param index_elements:
+         A sequence consisting of string column names, :class:`_schema.Column`
+         objects, or other column expression objects that will be used
+         to infer a target index or unique constraint.
+
+        :param index_where:
+         Additional WHERE criterion that can be used to infer a
+         conditional target index.
+
+        """
+
+        self._post_values_clause = OnConflictDoNothing(
+            index_elements, index_where
+        )
+
+
+insert = public_factory(
+    Insert, ".dialects.sqlite.insert", ".dialects.sqlite.Insert"
+)
+
+
+class OnConflictClause(ClauseElement):
+    stringify_dialect = "sqlite"
+
+    def __init__(self, index_elements=None, index_where=None):
+
+        if index_elements is not None:
+            self.constraint_target = None
+            self.inferred_target_elements = index_elements
+            self.inferred_target_whereclause = index_where
+        else:
+            self.constraint_target = (
+                self.inferred_target_elements
+            ) = self.inferred_target_whereclause = None
+
+
+class OnConflictDoNothing(OnConflictClause):
+    __visit_name__ = "on_conflict_do_nothing"
+
+
+class OnConflictDoUpdate(OnConflictClause):
+    __visit_name__ = "on_conflict_do_update"
+
+    def __init__(
+        self,
+        index_elements=None,
+        index_where=None,
+        set_=None,
+        where=None,
+    ):
+        super(OnConflictDoUpdate, self).__init__(
+            index_elements=index_elements,
+            index_where=index_where,
+        )
+
+        if not isinstance(set_, dict) or not set_:
+            raise ValueError("set parameter must be a non-empty dictionary")
+        self.update_values_to_set = [
+            (key, value) for key, value in set_.items()
+        ]
+        self.update_whereclause = where
index dc2aacbeadc4b27ad7ff9bab1ed1e4f150a4a4a7..9b9a6153b7a7a41a87729cf456d28fd711c8153f 100644 (file)
@@ -863,11 +863,12 @@ class Insert(ValuesBase):
          backends that support "returning", this turns off the "implicit
          returning" feature for the statement.
 
-        If both `values` and compile-time bind parameters are present, the
-        compile-time bind parameters override the information specified
-        within `values` on a per-key basis.
+        If both :paramref:`_expression.Insert.values` and compile-time bind
+        parameters are present, the compile-time bind parameters override the
+        information specified within :paramref:`_expression.Insert.values` on a
+        per-key basis.
 
-        The keys within `values` can be either
+        The keys within :paramref:`_expression.Insert.values` can be either
         :class:`~sqlalchemy.schema.Column` objects or their string
         identifiers. Each key may reference one of:
 
index d06cd48f5fe63e4bd5c1cdef1ad005cd09a79a7d..10e43b2218acfa379bdd9274b412640388acae2f 100644 (file)
@@ -36,6 +36,7 @@ from sqlalchemy import types as sqltypes
 from sqlalchemy import UniqueConstraint
 from sqlalchemy import util
 from sqlalchemy.dialects.sqlite import base as sqlite
+from sqlalchemy.dialects.sqlite import insert
 from sqlalchemy.dialects.sqlite import pysqlite as pysqlite_dialect
 from sqlalchemy.engine.url import make_url
 from sqlalchemy.schema import CreateTable
@@ -2680,3 +2681,540 @@ class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             self.table.c.myid.regexp_replace("pattern", "rep").compile,
             dialect=sqlite.dialect(),
         )
+
+
+class OnConflictTest(fixtures.TablesTest):
+
+    __only_on__ = ("sqlite >= 3.24.0",)
+
+    @classmethod
+    def define_tables(cls, metadata):
+        Table(
+            "users",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("name", String(50)),
+        )
+
+        class SpecialType(sqltypes.TypeDecorator):
+            impl = String
+
+            def process_bind_param(self, value, dialect):
+                return value + " processed"
+
+        Table(
+            "bind_targets",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("data", SpecialType()),
+        )
+
+        users_xtra = Table(
+            "users_xtra",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("name", String(50)),
+            Column("login_email", String(50)),
+            Column("lets_index_this", String(50)),
+        )
+        cls.unique_partial_index = schema.Index(
+            "idx_unique_partial_name",
+            users_xtra.c.name,
+            users_xtra.c.lets_index_this,
+            unique=True,
+            sqlite_where=users_xtra.c.lets_index_this == "unique_name",
+        )
+
+        cls.unique_constraint = schema.UniqueConstraint(
+            users_xtra.c.login_email, name="uq_login_email"
+        )
+        cls.bogus_index = schema.Index(
+            "idx_special_ops",
+            users_xtra.c.lets_index_this,
+            sqlite_where=users_xtra.c.lets_index_this > "m",
+        )
+
+    def test_bad_args(self):
+        assert_raises(
+            ValueError, insert(self.tables.users).on_conflict_do_update
+        )
+
+    def test_on_conflict_do_nothing(self, connection):
+        users = self.tables.users
+
+        conn = connection
+        result = conn.execute(
+            insert(users).on_conflict_do_nothing(),
+            dict(id=1, name="name1"),
+        )
+        eq_(result.inserted_primary_key, (1,))
+
+        result = conn.execute(
+            insert(users).on_conflict_do_nothing(),
+            dict(id=1, name="name2"),
+        )
+        eq_(result.inserted_primary_key, (1,))
+
+        eq_(
+            conn.execute(users.select().where(users.c.id == 1)).fetchall(),
+            [(1, "name1")],
+        )
+
+    def test_on_conflict_do_nothing_connectionless(self, connection):
+        users = self.tables.users_xtra
+
+        result = connection.execute(
+            insert(users).on_conflict_do_nothing(
+                index_elements=["login_email"]
+            ),
+            dict(name="name1", login_email="email1"),
+        )
+        eq_(result.inserted_primary_key, (1,))
+
+        result = connection.execute(
+            insert(users).on_conflict_do_nothing(
+                index_elements=["login_email"]
+            ),
+            dict(name="name2", login_email="email1"),
+        )
+        eq_(result.inserted_primary_key, (1,))
+
+        eq_(
+            connection.execute(
+                users.select().where(users.c.id == 1)
+            ).fetchall(),
+            [(1, "name1", "email1", None)],
+        )
+
+    @testing.provide_metadata
+    def test_on_conflict_do_nothing_target(self, connection):
+        users = self.tables.users
+
+        conn = connection
+
+        result = conn.execute(
+            insert(users).on_conflict_do_nothing(
+                index_elements=users.primary_key.columns
+            ),
+            dict(id=1, name="name1"),
+        )
+        eq_(result.inserted_primary_key, (1,))
+
+        result = conn.execute(
+            insert(users).on_conflict_do_nothing(
+                index_elements=users.primary_key.columns
+            ),
+            dict(id=1, name="name2"),
+        )
+        eq_(result.inserted_primary_key, (1,))
+
+        eq_(
+            conn.execute(users.select().where(users.c.id == 1)).fetchall(),
+            [(1, "name1")],
+        )
+
+    def test_on_conflict_do_update_one(self, connection):
+        users = self.tables.users
+
+        conn = connection
+        conn.execute(users.insert(), dict(id=1, name="name1"))
+
+        i = insert(users)
+        i = i.on_conflict_do_update(
+            index_elements=[users.c.id], set_=dict(name=i.excluded.name)
+        )
+        result = conn.execute(i, dict(id=1, name="name1"))
+
+        eq_(result.inserted_primary_key, (1,))
+
+        eq_(
+            conn.execute(users.select().where(users.c.id == 1)).fetchall(),
+            [(1, "name1")],
+        )
+
+    def test_on_conflict_do_update_two(self, connection):
+        users = self.tables.users
+
+        conn = connection
+        conn.execute(users.insert(), dict(id=1, name="name1"))
+
+        i = insert(users)
+        i = i.on_conflict_do_update(
+            index_elements=[users.c.id],
+            set_=dict(id=i.excluded.id, name=i.excluded.name),
+        )
+
+        result = conn.execute(i, dict(id=1, name="name2"))
+        eq_(result.inserted_primary_key, (1,))
+
+        eq_(
+            conn.execute(users.select().where(users.c.id == 1)).fetchall(),
+            [(1, "name2")],
+        )
+
+    def test_on_conflict_do_update_three(self, connection):
+        users = self.tables.users
+
+        conn = connection
+        conn.execute(users.insert(), dict(id=1, name="name1"))
+
+        i = insert(users)
+        i = i.on_conflict_do_update(
+            index_elements=users.primary_key.columns,
+            set_=dict(name=i.excluded.name),
+        )
+        result = conn.execute(i, dict(id=1, name="name3"))
+        eq_(result.inserted_primary_key, (1,))
+
+        eq_(
+            conn.execute(users.select().where(users.c.id == 1)).fetchall(),
+            [(1, "name3")],
+        )
+
+    def test_on_conflict_do_update_four(self, connection):
+        users = self.tables.users
+
+        conn = connection
+        conn.execute(users.insert(), dict(id=1, name="name1"))
+
+        i = insert(users)
+        i = i.on_conflict_do_update(
+            index_elements=users.primary_key.columns,
+            set_=dict(id=i.excluded.id, name=i.excluded.name),
+        ).values(id=1, name="name4")
+
+        result = conn.execute(i)
+        eq_(result.inserted_primary_key, (1,))
+
+        eq_(
+            conn.execute(users.select().where(users.c.id == 1)).fetchall(),
+            [(1, "name4")],
+        )
+
+    def test_on_conflict_do_update_five(self, connection):
+        users = self.tables.users
+
+        conn = connection
+        conn.execute(users.insert(), dict(id=1, name="name1"))
+
+        i = insert(users)
+        i = i.on_conflict_do_update(
+            index_elements=users.primary_key.columns,
+            set_=dict(id=10, name="I'm a name"),
+        ).values(id=1, name="name4")
+
+        result = conn.execute(i)
+        eq_(result.inserted_primary_key, (1,))
+
+        eq_(
+            conn.execute(users.select().where(users.c.id == 10)).fetchall(),
+            [(10, "I'm a name")],
+        )
+
+    def test_on_conflict_do_update_multivalues(self, connection):
+        users = self.tables.users
+
+        conn = connection
+
+        conn.execute(users.insert(), dict(id=1, name="name1"))
+        conn.execute(users.insert(), dict(id=2, name="name2"))
+
+        i = insert(users)
+        i = i.on_conflict_do_update(
+            index_elements=users.primary_key.columns,
+            set_=dict(name="updated"),
+            where=(i.excluded.name != "name12"),
+        ).values(
+            [
+                dict(id=1, name="name11"),
+                dict(id=2, name="name12"),
+                dict(id=3, name="name13"),
+                dict(id=4, name="name14"),
+            ]
+        )
+
+        result = conn.execute(i)
+        eq_(result.inserted_primary_key, (None,))
+
+        eq_(
+            conn.execute(users.select().order_by(users.c.id)).fetchall(),
+            [(1, "updated"), (2, "name2"), (3, "name13"), (4, "name14")],
+        )
+
+    def _exotic_targets_fixture(self, conn):
+        users = self.tables.users_xtra
+
+        conn.execute(
+            insert(users),
+            dict(
+                id=1,
+                name="name1",
+                login_email="name1@gmail.com",
+                lets_index_this="not",
+            ),
+        )
+        conn.execute(
+            users.insert(),
+            dict(
+                id=2,
+                name="name2",
+                login_email="name2@gmail.com",
+                lets_index_this="not",
+            ),
+        )
+
+        eq_(
+            conn.execute(users.select().where(users.c.id == 1)).fetchall(),
+            [(1, "name1", "name1@gmail.com", "not")],
+        )
+
+    def test_on_conflict_do_update_exotic_targets_two(self, connection):
+        users = self.tables.users_xtra
+
+        conn = connection
+        self._exotic_targets_fixture(conn)
+        # try primary key constraint: cause an upsert on unique id column
+        i = insert(users)
+        i = i.on_conflict_do_update(
+            index_elements=users.primary_key.columns,
+            set_=dict(
+                name=i.excluded.name, login_email=i.excluded.login_email
+            ),
+        )
+        result = conn.execute(
+            i,
+            dict(
+                id=1,
+                name="name2",
+                login_email="name1@gmail.com",
+                lets_index_this="not",
+            ),
+        )
+        eq_(result.inserted_primary_key, (1,))
+
+        eq_(
+            conn.execute(users.select().where(users.c.id == 1)).fetchall(),
+            [(1, "name2", "name1@gmail.com", "not")],
+        )
+
+    def test_on_conflict_do_update_exotic_targets_three(self, connection):
+        users = self.tables.users_xtra
+
+        conn = connection
+        self._exotic_targets_fixture(conn)
+        # try unique constraint: cause an upsert on target
+        # login_email, not id
+        i = insert(users)
+        i = i.on_conflict_do_update(
+            index_elements=["login_email"],
+            set_=dict(
+                id=i.excluded.id,
+                name=i.excluded.name,
+                login_email=i.excluded.login_email,
+            ),
+        )
+        # note: lets_index_this value totally ignored in SET clause.
+        result = conn.execute(
+            i,
+            dict(
+                id=42,
+                name="nameunique",
+                login_email="name2@gmail.com",
+                lets_index_this="unique",
+            ),
+        )
+        eq_(result.inserted_primary_key, (42,))
+
+        eq_(
+            conn.execute(
+                users.select().where(users.c.login_email == "name2@gmail.com")
+            ).fetchall(),
+            [(42, "nameunique", "name2@gmail.com", "not")],
+        )
+
+    def test_on_conflict_do_update_exotic_targets_four(self, connection):
+        users = self.tables.users_xtra
+
+        conn = connection
+        self._exotic_targets_fixture(conn)
+        # try unique constraint by name: cause an
+        # upsert on target login_email, not id
+        i = insert(users)
+        i = i.on_conflict_do_update(
+            index_elements=["login_email"],
+            set_=dict(
+                id=i.excluded.id,
+                name=i.excluded.name,
+                login_email=i.excluded.login_email,
+            ),
+        )
+        # note: lets_index_this value totally ignored in SET clause.
+
+        result = conn.execute(
+            i,
+            dict(
+                id=43,
+                name="nameunique2",
+                login_email="name2@gmail.com",
+                lets_index_this="unique",
+            ),
+        )
+        eq_(result.inserted_primary_key, (43,))
+
+        eq_(
+            conn.execute(
+                users.select().where(users.c.login_email == "name2@gmail.com")
+            ).fetchall(),
+            [(43, "nameunique2", "name2@gmail.com", "not")],
+        )
+
+    def test_on_conflict_do_update_exotic_targets_four_no_pk(self, connection):
+        users = self.tables.users_xtra
+
+        conn = connection
+        self._exotic_targets_fixture(conn)
+        # try unique constraint by name: cause an
+        # upsert on target login_email, not id
+        i = insert(users)
+        i = i.on_conflict_do_update(
+            index_elements=[users.c.login_email],
+            set_=dict(
+                id=i.excluded.id,
+                name=i.excluded.name,
+                login_email=i.excluded.login_email,
+            ),
+        )
+
+        conn.execute(i, dict(name="name3", login_email="name1@gmail.com"))
+
+        eq_(
+            conn.execute(users.select().where(users.c.id == 1)).fetchall(),
+            [],
+        )
+
+        eq_(
+            conn.execute(users.select().order_by(users.c.id)).fetchall(),
+            [
+                (2, "name2", "name2@gmail.com", "not"),
+                (3, "name3", "name1@gmail.com", "not"),
+            ],
+        )
+
+    def test_on_conflict_do_update_exotic_targets_five(self, connection):
+        users = self.tables.users_xtra
+
+        conn = connection
+        self._exotic_targets_fixture(conn)
+        # try bogus index
+        i = insert(users)
+
+        i = i.on_conflict_do_update(
+            index_elements=self.bogus_index.columns,
+            index_where=self.bogus_index.dialect_options["sqlite"]["where"],
+            set_=dict(
+                name=i.excluded.name, login_email=i.excluded.login_email
+            ),
+        )
+
+        assert_raises(
+            exc.OperationalError,
+            conn.execute,
+            i,
+            dict(
+                id=1,
+                name="namebogus",
+                login_email="bogus@gmail.com",
+                lets_index_this="bogus",
+            ),
+        )
+
+    def test_on_conflict_do_update_exotic_targets_six(self, connection):
+        users = self.tables.users_xtra
+
+        conn = connection
+        conn.execute(
+            insert(users),
+            dict(
+                id=1,
+                name="name1",
+                login_email="mail1@gmail.com",
+                lets_index_this="unique_name",
+            ),
+        )
+        i = insert(users)
+        i = i.on_conflict_do_update(
+            index_elements=self.unique_partial_index.columns,
+            index_where=self.unique_partial_index.dialect_options["sqlite"][
+                "where"
+            ],
+            set_=dict(
+                name=i.excluded.name, login_email=i.excluded.login_email
+            ),
+        )
+
+        conn.execute(
+            i,
+            [
+                dict(
+                    name="name1",
+                    login_email="mail2@gmail.com",
+                    lets_index_this="unique_name",
+                )
+            ],
+        )
+
+        eq_(
+            conn.execute(users.select()).fetchall(),
+            [(1, "name1", "mail2@gmail.com", "unique_name")],
+        )
+
+    def test_on_conflict_do_update_no_row_actually_affected(self, connection):
+        users = self.tables.users_xtra
+
+        conn = connection
+        self._exotic_targets_fixture(conn)
+        i = insert(users)
+        i = i.on_conflict_do_update(
+            index_elements=[users.c.login_email],
+            set_=dict(name="new_name"),
+            where=(i.excluded.name == "other_name"),
+        )
+        result = conn.execute(
+            i, dict(name="name2", login_email="name1@gmail.com")
+        )
+
+        # The last inserted primary key should be 2 here
+        # it is taking the result from the the exotic fixture
+        eq_(result.inserted_primary_key, (2,))
+
+        eq_(
+            conn.execute(users.select()).fetchall(),
+            [
+                (1, "name1", "name1@gmail.com", "not"),
+                (2, "name2", "name2@gmail.com", "not"),
+            ],
+        )
+
+    def test_on_conflict_do_update_special_types_in_set(self, connection):
+        bind_targets = self.tables.bind_targets
+
+        conn = connection
+        i = insert(bind_targets)
+        conn.execute(i, {"id": 1, "data": "initial data"})
+
+        eq_(
+            conn.scalar(sql.select(bind_targets.c.data)),
+            "initial data processed",
+        )
+
+        i = insert(bind_targets)
+        i = i.on_conflict_do_update(
+            index_elements=[bind_targets.c.id],
+            set_=dict(data="new updated data"),
+        )
+        conn.execute(i, {"id": 1, "data": "new inserted data"})
+
+        eq_(
+            conn.scalar(sql.select(bind_targets.c.data)),
+            "new updated data processed",
+        )