]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Warn / raise for returning() / return_defaults() combinations
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 11 Nov 2020 16:13:27 +0000 (11:13 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 11 Nov 2020 16:36:06 +0000 (11:36 -0500)
A warning is emmitted if a returning() method such as
:meth:`_sql.Insert.returning` is called multiple times, as this does not
yet support additive operation.  Version 1.4 will support additive
operation for this.  Additionally, any combination of the
:meth:`_sql.Insert.returning` and :meth:`_sql.Insert.return_defaults`
methods now raises an error as these methods are mutually exclusive;
previously the operation would fail silently.

Fixes: #5691
Change-Id: Id95e0f9da48bba0b59439cb26564f0daa684c8e3

doc/build/changelog/unreleased_13/5691.rst [new file with mode: 0644]
doc/build/tutorial/data.rst
lib/sqlalchemy/sql/dml.py
test/sql/test_returning.py

diff --git a/doc/build/changelog/unreleased_13/5691.rst b/doc/build/changelog/unreleased_13/5691.rst
new file mode 100644 (file)
index 0000000..6180e77
--- /dev/null
@@ -0,0 +1,12 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 5691
+
+    A warning is emmitted if a returning() method such as
+    :meth:`_sql.Insert.returning` is called multiple times, as this does not
+    yet support additive operation.  Version 1.4 will support additive
+    operation for this.  Additionally, any combination of the
+    :meth:`_sql.Insert.returning` and :meth:`_sql.Insert.return_defaults`
+    methods now raises an error as these methods are mutually exclusive;
+    previously the operation would fail silently.
+
index 27a21b0978f978847c9b561702c715acd5b3772e..849b706cc12e99c0b6f24a5d6a95b07f601b0d02 100644 (file)
@@ -1502,7 +1502,7 @@ be iterated::
     ...     delete(user_table).where(user_table.c.name == 'patrick').
     ...     returning(user_table.c.id, user_table.c.name)
     ... )
-    >>> print(delete_stmt.returning(user_table.c.id, user_table.c.name))
+    >>> print(delete_stmt)
     {opensql}DELETE FROM user_account
     WHERE user_account.name = :name_1
     RETURNING user_account.id, user_account.name
index 5726cddc029986f4227971510e91625672c7c06d..ddb85224aa188f53aa6ea72ce4a896e2fb9f215a 100644 (file)
@@ -207,6 +207,8 @@ class UpdateBase(
     _hints = util.immutabledict()
     named_with_column = False
 
+    _return_defaults = None
+
     is_dml = True
 
     @classmethod
@@ -343,11 +345,10 @@ class UpdateBase(
             for server_flag, updated_timestamp in connection.execute(stmt):
                 print(server_flag, updated_timestamp)
 
-        The given collection of column expressions should be derived from
-        the table that is
-        the target of the INSERT, UPDATE, or DELETE.  While
-        :class:`_schema.Column`
-        objects are typical, the elements can also be expressions::
+        The given collection of column expressions should be derived from the
+        table that is the target of the INSERT, UPDATE, or DELETE.  While
+        :class:`_schema.Column` objects are typical, the elements can also be
+        expressions::
 
             stmt = table.insert().returning(
                 (table.c.first_name + " " + table.c.last_name).
@@ -383,6 +384,16 @@ class UpdateBase(
 
 
         """
+        if self._return_defaults:
+            raise exc.InvalidRequestError(
+                "return_defaults() is already configured on this statement"
+            )
+        if self._returning:
+            util.warn(
+                "The returning() method does not currently support multiple "
+                "additive calls.  The existing RETURNING clause being "
+                "replaced by new columns."
+            )
         self._returning = cols
 
     def _exported_columns_iterator(self):
@@ -760,6 +771,10 @@ class ValuesBase(UpdateBase):
             :attr:`_engine.CursorResult.inserted_primary_key_rows`
 
         """
+        if self._returning:
+            raise exc.InvalidRequestError(
+                "RETURNING is already configured on this statement"
+            )
         self._return_defaults = cols or True
 
 
index 601bd62730cd808c7ed323c28deb80442a08e86f..13a1b025d486647f9c6c840a6c2ad4ae0dfc97d3 100644 (file)
@@ -1,15 +1,19 @@
 import itertools
 
 from sqlalchemy import Boolean
+from sqlalchemy import delete
 from sqlalchemy import exc as sa_exc
 from sqlalchemy import func
+from sqlalchemy import insert
 from sqlalchemy import Integer
 from sqlalchemy import MetaData
 from sqlalchemy import select
 from sqlalchemy import Sequence
 from sqlalchemy import String
 from sqlalchemy import testing
+from sqlalchemy import update
 from sqlalchemy.testing import assert_raises_message
+from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import AssertsExecutionResults
 from sqlalchemy.testing import engines
 from sqlalchemy.testing import eq_
@@ -22,6 +26,76 @@ from sqlalchemy.types import TypeDecorator
 table = GoofyType = seq = None
 
 
+class ReturnCombinationTests(fixtures.TestBase, AssertsCompiledSQL):
+    __dialect__ = "postgresql"
+
+    @testing.fixture
+    def table_fixture(self):
+        return Table(
+            "foo",
+            MetaData(),
+            Column("id", Integer, primary_key=True),
+            Column("q", Integer, server_default="5"),
+            Column("x", Integer),
+            Column("y", Integer),
+        )
+
+    @testing.combinations(
+        (
+            insert,
+            "INSERT INTO foo (id, q, x, y) "
+            "VALUES (%(id)s, %(q)s, %(x)s, %(y)s)",
+        ),
+        (update, "UPDATE foo SET id=%(id)s, q=%(q)s, x=%(x)s, y=%(y)s"),
+        (delete, "DELETE FROM foo"),
+        argnames="dml_fn, sql_frag",
+        id_="na",
+    )
+    def test_return_combinations(self, table_fixture, dml_fn, sql_frag):
+        t = table_fixture
+        stmt = dml_fn(t)
+
+        stmt = stmt.returning(t.c.x)
+
+        with testing.expect_warnings(
+            r"The returning\(\) method does not currently "
+            "support multiple additive calls."
+        ):
+            stmt = stmt.returning(t.c.y)
+
+        self.assert_compile(
+            stmt,
+            "%s RETURNING foo.y" % (sql_frag),
+        )
+
+    def test_return_no_return_defaults(self, table_fixture):
+        t = table_fixture
+
+        stmt = t.insert()
+
+        stmt = stmt.returning(t.c.x)
+
+        assert_raises_message(
+            sa_exc.InvalidRequestError,
+            "RETURNING is already configured on this statement",
+            stmt.return_defaults,
+        )
+
+    def test_return_defaults_no_returning(self, table_fixture):
+        t = table_fixture
+
+        stmt = t.insert()
+
+        stmt = stmt.return_defaults()
+
+        assert_raises_message(
+            sa_exc.InvalidRequestError,
+            r"return_defaults\(\) is already configured on this statement",
+            stmt.returning,
+            t.c.x,
+        )
+
+
 class ReturningTest(fixtures.TestBase, AssertsExecutionResults):
     __requires__ = ("returning",)
     __backend__ = True