From: Mike Bayer Date: Sat, 16 Jan 2021 17:39:51 +0000 (-0500) Subject: introduce generalized decorator to prevent invalid method calls X-Git-Tag: rel_1_4_0b2~43 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=8860117c9655a4bdeafeba;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git introduce generalized decorator to prevent invalid method calls This introduces the ``_exclusive_against()`` utility decorator that can be used to prevent repeated invocations of methods that typically should only be called once. An informative error message is now raised for a selected set of DML methods (currently all part of :class:`_dml.Insert` constructs) if they are called a second time, which would implicitly cancel out the previous setting. The methods altered include: :class:`_sqlite.Insert.on_conflict_do_update`, :class:`_sqlite.Insert.on_conflict_do_nothing` (SQLite), :class:`_postgresql.Insert.on_conflict_do_update`, :class:`_postgresql.Insert.on_conflict_do_nothing` (PostgreSQL), :class:`_mysql.Insert.on_duplicate_key_update` (MySQL) Fixes: #5169 Change-Id: I9278fa87cd3470dcf296ff96bb0fb17a3236d49d --- diff --git a/lib/sqlalchemy/dialects/mysql/dml.py b/lib/sqlalchemy/dialects/mysql/dml.py index 9f8177c598..6c50dcca91 100644 --- a/lib/sqlalchemy/dialects/mysql/dml.py +++ b/lib/sqlalchemy/dialects/mysql/dml.py @@ -1,5 +1,6 @@ from ... import exc from ... import util +from ...sql.base import _exclusive_against from ...sql.base import _generative from ...sql.dml import Insert as StandardInsert from ...sql.elements import ClauseElement @@ -49,6 +50,13 @@ class Insert(StandardInsert): return alias(self.table, name="inserted") @_generative + @_exclusive_against( + "_post_values_clause", + msgs={ + "_post_values_clause": "This Insert construct already " + "has an ON DUPLICATE KEY clause present" + }, + ) def on_duplicate_key_update(self, *args, **kw): r""" Specifies the ON DUPLICATE KEY UPDATE clause. diff --git a/lib/sqlalchemy/dialects/postgresql/dml.py b/lib/sqlalchemy/dialects/postgresql/dml.py index 76dfafd049..bff61e1736 100644 --- a/lib/sqlalchemy/dialects/postgresql/dml.py +++ b/lib/sqlalchemy/dialects/postgresql/dml.py @@ -10,6 +10,7 @@ from ... import util from ...sql import coercions from ...sql import roles from ...sql import schema +from ...sql.base import _exclusive_against from ...sql.base import _generative from ...sql.dml import Insert as StandardInsert from ...sql.elements import ClauseElement @@ -50,7 +51,16 @@ class Insert(StandardInsert): """ return alias(self.table, name="excluded").columns + _on_conflict_exclusive = _exclusive_against( + "_post_values_clause", + msgs={ + "_post_values_clause": "This Insert construct already has " + "an ON CONFLICT clause established" + }, + ) + @_generative + @_on_conflict_exclusive def on_conflict_do_update( self, constraint=None, @@ -117,6 +127,7 @@ class Insert(StandardInsert): ) @_generative + @_on_conflict_exclusive def on_conflict_do_nothing( self, constraint=None, index_elements=None, index_where=None ): diff --git a/lib/sqlalchemy/dialects/sqlite/dml.py b/lib/sqlalchemy/dialects/sqlite/dml.py index 9c8f10f7bc..be32781c7a 100644 --- a/lib/sqlalchemy/dialects/sqlite/dml.py +++ b/lib/sqlalchemy/dialects/sqlite/dml.py @@ -7,6 +7,7 @@ from ... import util from ...sql import coercions from ...sql import roles +from ...sql.base import _exclusive_against from ...sql.base import _generative from ...sql.dml import Insert as StandardInsert from ...sql.elements import ClauseElement @@ -46,7 +47,16 @@ class Insert(StandardInsert): """ return alias(self.table, name="excluded").columns + _on_conflict_exclusive = _exclusive_against( + "_post_values_clause", + msgs={ + "_post_values_clause": "This Insert construct already has " + "an ON CONFLICT clause established" + }, + ) + @_generative + @_on_conflict_exclusive def on_conflict_do_update( self, index_elements=None, @@ -99,6 +109,7 @@ class Insert(StandardInsert): ) @_generative + @_on_conflict_exclusive def on_conflict_do_nothing(self, index_elements=None, index_where=None): """ Specifies a DO NOTHING action for ON CONFLICT clause. diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 550111020e..220bbb115b 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -102,6 +102,31 @@ def _generative(fn): return decorated +def _exclusive_against(*names, **kw): + msgs = kw.pop("msgs", {}) + + defaults = kw.pop("defaults", {}) + + getters = [ + (name, operator.attrgetter(name), defaults.get(name, None)) + for name in names + ] + + @util.decorator + def check(fn, self, *args, **kw): + for name, getter, default_ in getters: + if getter(self) is not default_: + msg = msgs.get( + name, + "Method %s() has already been invoked on this %s construct" + % (fn.__name__, self.__class__), + ) + raise exc.InvalidRequestError(msg) + return fn(self, *args, **kw) + + return check + + def _clone(element, **kw): return element._clone() diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index c402de1216..3f492a490e 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -14,6 +14,7 @@ from . import coercions from . import roles from . import util as sql_util from .base import _entity_namespace_key +from .base import _exclusive_against from .base import _from_objects from .base import _generative from .base import ColumnCollection @@ -495,6 +496,15 @@ class ValuesBase(UpdateBase): self._setup_prefixes(prefixes) @_generative + @_exclusive_against( + "_select_names", + "_ordered_values", + msgs={ + "_select_names": "This construct already inserts from a SELECT", + "_ordered_values": "This statement already has ordered " + "values present", + }, + ) def values(self, *args, **kwargs): r"""Specify a fixed VALUES clause for an INSERT statement, or the SET clause for an UPDATE. @@ -607,15 +617,6 @@ class ValuesBase(UpdateBase): """ - if self._select_names: - raise exc.InvalidRequestError( - "This construct already inserts from a SELECT" - ) - elif self._ordered_values: - raise exc.ArgumentError( - "This statement already has ordered values present" - ) - if args: # positional case. this is currently expensive. we don't # yet have positional-only args so we have to check the length. @@ -699,6 +700,13 @@ class ValuesBase(UpdateBase): self._values = util.immutabledict(arg) @_generative + @_exclusive_against( + "_returning", + msgs={ + "_returning": "RETURNING is already configured on this statement" + }, + defaults={"_returning": _returning}, + ) def return_defaults(self, *cols): """Make use of a :term:`RETURNING` clause for the purpose of fetching server-side expressions and defaults. @@ -783,10 +791,6 @@ class ValuesBase(UpdateBase): :attr:`_engine.CursorResult.inserted_primary_key_rows` """ - if self._returning: - raise exc.InvalidRequestError( - "RETURNING is already configured on this statement" - ) self._return_defaults = cols or True diff --git a/test/dialect/mysql/test_compiler.py b/test/dialect/mysql/test_compiler.py index 7fd24e8b51..84646d3802 100644 --- a/test/dialect/mysql/test_compiler.py +++ b/test/dialect/mysql/test_compiler.py @@ -1000,6 +1000,22 @@ class InsertOnDuplicateTest(fixtures.TestBase, AssertsCompiledSQL): Column("baz", String(10)), ) + def test_no_call_twice(self): + stmt = insert(self.table).values( + [{"id": 1, "bar": "ab"}, {"id": 2, "bar": "b"}] + ) + stmt = stmt.on_duplicate_key_update( + bar=stmt.inserted.bar, baz=stmt.inserted.baz + ) + with testing.expect_raises_message( + exc.InvalidRequestError, + "This Insert construct already has an " + "ON DUPLICATE KEY clause present", + ): + stmt = stmt.on_duplicate_key_update( + bar=stmt.inserted.bar, baz=stmt.inserted.baz + ) + def test_from_values(self): stmt = insert(self.table).values( [{"id": 1, "bar": "ab"}, {"id": 2, "bar": "b"}] diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py index b3a0b9bbde..eb39091ae4 100644 --- a/test/dialect/postgresql/test_compiler.py +++ b/test/dialect/postgresql/test_compiler.py @@ -1842,6 +1842,27 @@ class InsertOnConflictTest(fixtures.TestBase, AssertsCompiledSQL): "goofy_index", table1.c.name, postgresql_where=table1.c.name > "m" ) + def test_on_conflict_do_no_call_twice(self): + users = self.table1 + + for stmt in ( + insert(users).on_conflict_do_nothing(), + insert(users).on_conflict_do_update( + index_elements=[users.c.myid], set_=dict(name="foo") + ), + ): + for meth in ( + stmt.on_conflict_do_nothing, + stmt.on_conflict_do_update, + ): + + with testing.expect_raises_message( + exc.InvalidRequestError, + "This Insert construct already has an " + "ON CONFLICT clause established", + ): + meth() + def test_do_nothing_no_target(self): i = insert( diff --git a/test/dialect/test_sqlite.py b/test/dialect/test_sqlite.py index 0500a20bd6..23ceb88b3c 100644 --- a/test/dialect/test_sqlite.py +++ b/test/dialect/test_sqlite.py @@ -2717,6 +2717,27 @@ class OnConflictTest(fixtures.TablesTest): ValueError, insert(self.tables.users).on_conflict_do_update ) + def test_on_conflict_do_no_call_twice(self): + users = self.tables.users + + for stmt in ( + insert(users).on_conflict_do_nothing(), + insert(users).on_conflict_do_update( + index_elements=[users.c.id], set_=dict(name="foo") + ), + ): + for meth in ( + stmt.on_conflict_do_nothing, + stmt.on_conflict_do_update, + ): + + with testing.expect_raises_message( + exc.InvalidRequestError, + "This Insert construct already has an " + "ON CONFLICT clause established", + ): + meth() + def test_on_conflict_do_nothing(self, connection): users = self.tables.users diff --git a/test/sql/test_update.py b/test/sql/test_update.py index 946a01651a..26b0f6217a 100644 --- a/test/sql/test_update.py +++ b/test/sql/test_update.py @@ -672,7 +672,7 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL): stmt = table1.update().ordered_values(("myid", 1), ("name", "d1")) assert_raises_message( - exc.ArgumentError, + exc.InvalidRequestError, "This statement already has ordered values present", stmt.values, {"myid": 2, "name": "d2"},