]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
introduce generalized decorator to prevent invalid method calls
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 16 Jan 2021 17:39:51 +0000 (12:39 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 16 Jan 2021 23:44:21 +0000 (18:44 -0500)
This introduces the ``_exclusive_against()`` utility decorator
that can be used to prevent repeated invocations of methods that
typically should only be called once.

An informative error message is now raised for a selected set of DML
methods (currently all part of :class:`_dml.Insert` constructs) if they are
called a second time, which would implicitly cancel out the previous
setting.  The methods altered include:
:class:`_sqlite.Insert.on_conflict_do_update`,
:class:`_sqlite.Insert.on_conflict_do_nothing` (SQLite),
:class:`_postgresql.Insert.on_conflict_do_update`,
:class:`_postgresql.Insert.on_conflict_do_nothing` (PostgreSQL),
:class:`_mysql.Insert.on_duplicate_key_update` (MySQL)

Fixes: #5169
Change-Id: I9278fa87cd3470dcf296ff96bb0fb17a3236d49d

lib/sqlalchemy/dialects/mysql/dml.py
lib/sqlalchemy/dialects/postgresql/dml.py
lib/sqlalchemy/dialects/sqlite/dml.py
lib/sqlalchemy/sql/base.py
lib/sqlalchemy/sql/dml.py
test/dialect/mysql/test_compiler.py
test/dialect/postgresql/test_compiler.py
test/dialect/test_sqlite.py
test/sql/test_update.py

index 9f8177c598a37d2251e3d2fc5958d2422ad9e6f5..6c50dcca9166ba8bf0d24435c5a5e77f3036d1b4 100644 (file)
@@ -1,5 +1,6 @@
 from ... import exc
 from ... import util
+from ...sql.base import _exclusive_against
 from ...sql.base import _generative
 from ...sql.dml import Insert as StandardInsert
 from ...sql.elements import ClauseElement
@@ -49,6 +50,13 @@ class Insert(StandardInsert):
         return alias(self.table, name="inserted")
 
     @_generative
+    @_exclusive_against(
+        "_post_values_clause",
+        msgs={
+            "_post_values_clause": "This Insert construct already "
+            "has an ON DUPLICATE KEY clause present"
+        },
+    )
     def on_duplicate_key_update(self, *args, **kw):
         r"""
         Specifies the ON DUPLICATE KEY UPDATE clause.
index 76dfafd0497017b2a1c3710bb5f7d4e3cd256135..bff61e173674f4e388ac743ba9c7b74527bb99c1 100644 (file)
@@ -10,6 +10,7 @@ from ... import util
 from ...sql import coercions
 from ...sql import roles
 from ...sql import schema
+from ...sql.base import _exclusive_against
 from ...sql.base import _generative
 from ...sql.dml import Insert as StandardInsert
 from ...sql.elements import ClauseElement
@@ -50,7 +51,16 @@ class Insert(StandardInsert):
         """
         return alias(self.table, name="excluded").columns
 
+    _on_conflict_exclusive = _exclusive_against(
+        "_post_values_clause",
+        msgs={
+            "_post_values_clause": "This Insert construct already has "
+            "an ON CONFLICT clause established"
+        },
+    )
+
     @_generative
+    @_on_conflict_exclusive
     def on_conflict_do_update(
         self,
         constraint=None,
@@ -117,6 +127,7 @@ class Insert(StandardInsert):
         )
 
     @_generative
+    @_on_conflict_exclusive
     def on_conflict_do_nothing(
         self, constraint=None, index_elements=None, index_where=None
     ):
index 9c8f10f7bc18284a1c0ee72353e14ddad7450bea..be32781c7a643f983a5f9e470805ce522cd6838c 100644 (file)
@@ -7,6 +7,7 @@
 from ... import util
 from ...sql import coercions
 from ...sql import roles
+from ...sql.base import _exclusive_against
 from ...sql.base import _generative
 from ...sql.dml import Insert as StandardInsert
 from ...sql.elements import ClauseElement
@@ -46,7 +47,16 @@ class Insert(StandardInsert):
         """
         return alias(self.table, name="excluded").columns
 
+    _on_conflict_exclusive = _exclusive_against(
+        "_post_values_clause",
+        msgs={
+            "_post_values_clause": "This Insert construct already has "
+            "an ON CONFLICT clause established"
+        },
+    )
+
     @_generative
+    @_on_conflict_exclusive
     def on_conflict_do_update(
         self,
         index_elements=None,
@@ -99,6 +109,7 @@ class Insert(StandardInsert):
         )
 
     @_generative
+    @_on_conflict_exclusive
     def on_conflict_do_nothing(self, index_elements=None, index_where=None):
         """
         Specifies a DO NOTHING action for ON CONFLICT clause.
index 550111020eb05b048f8bde3d556a6fbd3e6420e7..220bbb115b7df4f3f532e42f7c4fa697d4ed54c4 100644 (file)
@@ -102,6 +102,31 @@ def _generative(fn):
     return decorated
 
 
+def _exclusive_against(*names, **kw):
+    msgs = kw.pop("msgs", {})
+
+    defaults = kw.pop("defaults", {})
+
+    getters = [
+        (name, operator.attrgetter(name), defaults.get(name, None))
+        for name in names
+    ]
+
+    @util.decorator
+    def check(fn, self, *args, **kw):
+        for name, getter, default_ in getters:
+            if getter(self) is not default_:
+                msg = msgs.get(
+                    name,
+                    "Method %s() has already been invoked on this %s construct"
+                    % (fn.__name__, self.__class__),
+                )
+                raise exc.InvalidRequestError(msg)
+        return fn(self, *args, **kw)
+
+    return check
+
+
 def _clone(element, **kw):
     return element._clone()
 
index c402de12160ce1f352bf7d46e90dea8bdd5d97d5..3f492a490eb38ff74f59d0afc81dc98c8116b29e 100644 (file)
@@ -14,6 +14,7 @@ from . import coercions
 from . import roles
 from . import util as sql_util
 from .base import _entity_namespace_key
+from .base import _exclusive_against
 from .base import _from_objects
 from .base import _generative
 from .base import ColumnCollection
@@ -495,6 +496,15 @@ class ValuesBase(UpdateBase):
             self._setup_prefixes(prefixes)
 
     @_generative
+    @_exclusive_against(
+        "_select_names",
+        "_ordered_values",
+        msgs={
+            "_select_names": "This construct already inserts from a SELECT",
+            "_ordered_values": "This statement already has ordered "
+            "values present",
+        },
+    )
     def values(self, *args, **kwargs):
         r"""Specify a fixed VALUES clause for an INSERT statement, or the SET
         clause for an UPDATE.
@@ -607,15 +617,6 @@ class ValuesBase(UpdateBase):
 
 
         """
-        if self._select_names:
-            raise exc.InvalidRequestError(
-                "This construct already inserts from a SELECT"
-            )
-        elif self._ordered_values:
-            raise exc.ArgumentError(
-                "This statement already has ordered values present"
-            )
-
         if args:
             # positional case.  this is currently expensive.   we don't
             # yet have positional-only args so we have to check the length.
@@ -699,6 +700,13 @@ class ValuesBase(UpdateBase):
                 self._values = util.immutabledict(arg)
 
     @_generative
+    @_exclusive_against(
+        "_returning",
+        msgs={
+            "_returning": "RETURNING is already configured on this statement"
+        },
+        defaults={"_returning": _returning},
+    )
     def return_defaults(self, *cols):
         """Make use of a :term:`RETURNING` clause for the purpose
         of fetching server-side expressions and defaults.
@@ -783,10 +791,6 @@ class ValuesBase(UpdateBase):
             :attr:`_engine.CursorResult.inserted_primary_key_rows`
 
         """
-        if self._returning:
-            raise exc.InvalidRequestError(
-                "RETURNING is already configured on this statement"
-            )
         self._return_defaults = cols or True
 
 
index 7fd24e8b51c2c069912c2c1c540dd741a39fe0ee..84646d3802a96e58af154c154d2fb4e8da10d37f 100644 (file)
@@ -1000,6 +1000,22 @@ class InsertOnDuplicateTest(fixtures.TestBase, AssertsCompiledSQL):
             Column("baz", String(10)),
         )
 
+    def test_no_call_twice(self):
+        stmt = insert(self.table).values(
+            [{"id": 1, "bar": "ab"}, {"id": 2, "bar": "b"}]
+        )
+        stmt = stmt.on_duplicate_key_update(
+            bar=stmt.inserted.bar, baz=stmt.inserted.baz
+        )
+        with testing.expect_raises_message(
+            exc.InvalidRequestError,
+            "This Insert construct already has an "
+            "ON DUPLICATE KEY clause present",
+        ):
+            stmt = stmt.on_duplicate_key_update(
+                bar=stmt.inserted.bar, baz=stmt.inserted.baz
+            )
+
     def test_from_values(self):
         stmt = insert(self.table).values(
             [{"id": 1, "bar": "ab"}, {"id": 2, "bar": "b"}]
index b3a0b9bbded4adbd8c692e0c079ffbb8fa8517b4..eb39091ae4b1a461da8c0a1474b7a564d35f2bb3 100644 (file)
@@ -1842,6 +1842,27 @@ class InsertOnConflictTest(fixtures.TestBase, AssertsCompiledSQL):
             "goofy_index", table1.c.name, postgresql_where=table1.c.name > "m"
         )
 
+    def test_on_conflict_do_no_call_twice(self):
+        users = self.table1
+
+        for stmt in (
+            insert(users).on_conflict_do_nothing(),
+            insert(users).on_conflict_do_update(
+                index_elements=[users.c.myid], set_=dict(name="foo")
+            ),
+        ):
+            for meth in (
+                stmt.on_conflict_do_nothing,
+                stmt.on_conflict_do_update,
+            ):
+
+                with testing.expect_raises_message(
+                    exc.InvalidRequestError,
+                    "This Insert construct already has an "
+                    "ON CONFLICT clause established",
+                ):
+                    meth()
+
     def test_do_nothing_no_target(self):
 
         i = insert(
index 0500a20bd63f857d759dc0c60b4a1609545e8a3b..23ceb88b3cf61c2eea3cbb1d29fb85d8dd3b26fd 100644 (file)
@@ -2717,6 +2717,27 @@ class OnConflictTest(fixtures.TablesTest):
             ValueError, insert(self.tables.users).on_conflict_do_update
         )
 
+    def test_on_conflict_do_no_call_twice(self):
+        users = self.tables.users
+
+        for stmt in (
+            insert(users).on_conflict_do_nothing(),
+            insert(users).on_conflict_do_update(
+                index_elements=[users.c.id], set_=dict(name="foo")
+            ),
+        ):
+            for meth in (
+                stmt.on_conflict_do_nothing,
+                stmt.on_conflict_do_update,
+            ):
+
+                with testing.expect_raises_message(
+                    exc.InvalidRequestError,
+                    "This Insert construct already has an "
+                    "ON CONFLICT clause established",
+                ):
+                    meth()
+
     def test_on_conflict_do_nothing(self, connection):
         users = self.tables.users
 
index 946a01651a554a6a1ded4b443f7f58db4e1cf633..26b0f6217ad9336471dc4e6c7bb80329be177976 100644 (file)
@@ -672,7 +672,7 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL):
         stmt = table1.update().ordered_values(("myid", 1), ("name", "d1"))
 
         assert_raises_message(
-            exc.ArgumentError,
+            exc.InvalidRequestError,
             "This statement already has ordered values present",
             stmt.values,
             {"myid": 2, "name": "d2"},