From 299676fd32c0d0bf8a6700f8bc4029219c13abca Mon Sep 17 00:00:00 2001 From: Daniel Black Date: Thu, 9 Sep 2021 18:55:01 +1000 Subject: [PATCH] Add insert/delete returning for MariaDB As MariaDB doesn't support update inserting the full_returning is complimented with insert_returning and delete_returning. Fixes: #7011 --- lib/sqlalchemy/dialects/mssql/base.py | 2 ++ lib/sqlalchemy/dialects/mysql/base.py | 27 ++++++++++++++++++++-- lib/sqlalchemy/dialects/postgresql/base.py | 3 +++ lib/sqlalchemy/engine/default.py | 2 ++ lib/sqlalchemy/orm/persistence.py | 15 ++++++------ lib/sqlalchemy/testing/requirements.py | 9 ++++++++ test/orm/test_update_delete.py | 7 +++--- 7 files changed, 53 insertions(+), 12 deletions(-) diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 7946633eb5..4a517d7b09 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -2601,6 +2601,8 @@ class MSDialect(default.DefaultDialect): implicit_returning = True full_returning = True + insert_returning = True + delete_returning = True colspecs = { sqltypes.DateTime: _MSDateTime, diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 2bba2f81a7..c8e204895c 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -991,6 +991,7 @@ from ...engine import reflection from ...sql import coercions from ...sql import compiler from ...sql import elements +from ...sql import expression from ...sql import functions from ...sql import operators from ...sql import roles @@ -1797,6 +1798,14 @@ class MySQLCompiler(compiler.SQLCompiler): return tmp + def returning_clause(self, stmt, returning_cols): + columns = [ + self._label_returning_column(stmt, c) + for c in expression._select_iterables(returning_cols) + ] + + return "RETURNING " + ", ".join(columns) + def limit_clause(self, select, **kw): # MySQL supports: # LIMIT @@ -2776,7 +2785,8 @@ class MySQLDialect(default.DefaultDialect): server_version_info = tuple(version) - self._set_mariadb(server_version_info and is_mariadb, val) + self._set_mariadb(server_version_info and is_mariadb, + server_version_info) if not is_mariadb: self._mariadb_normalized_version_info = server_version_info @@ -2798,9 +2808,14 @@ class MySQLDialect(default.DefaultDialect): if not is_mariadb and self.is_mariadb: raise exc.InvalidRequestError( "MySQL version %s is not a MariaDB variant." - % (server_version_info,) + % ('.'.join(map(str, server_version_info)),) ) self.is_mariadb = is_mariadb + if server_version_info is not None: + if server_version_info >= (10, 5): + self.insert_returning = True + if server_version_info >= (10, 0, 5): + self.delete_returning = True def do_begin_twophase(self, connection, xid): connection.execute(sql.text("XA BEGIN :xid"), dict(xid=xid)) @@ -2975,6 +2990,14 @@ class MySQLDialect(default.DefaultDialect): not self.is_mariadb and self.server_version_info >= (8,) ) + self.delete_returning = ( + self.is_mariadb and self.server_version_info >= (10, 0, 5) + ) + + self.insert_returning = ( + self.is_mariadb and self.server_version_info >= (10, 5) + ) + self._warn_for_known_db_issues() def _warn_for_known_db_issues(self): diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index f33542ee80..da550d34e2 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -3127,6 +3127,8 @@ class PGDialect(default.DefaultDialect): implicit_returning = True full_returning = True + delete_returning = True + insert_returning = True connection_characteristics = ( default.DefaultDialect.connection_characteristics @@ -3191,6 +3193,7 @@ class PGDialect(default.DefaultDialect): if self.server_version_info <= (8, 2): self.full_returning = self.implicit_returning = False + self.delete_returning = self.insert_returning = False self.supports_native_enum = self.server_version_info >= (8, 3) if not self.supports_native_enum: diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index eff28e3400..dab2dfd6b4 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -78,6 +78,8 @@ class DefaultDialect(interfaces.Dialect): postfetch_lastrowid = True implicit_returning = False full_returning = False + delete_returning = False + insert_returning = False insert_executemany_returning = False cte_follows_insert = False diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index fd484b52b3..2133f767bc 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -2069,9 +2069,12 @@ class BulkUDCompileState(CompileState): ) select_stmt._where_criteria = statement._where_criteria - def skip_for_full_returning(orm_context): + def skip_for_returning(orm_context): bind = orm_context.session.get_bind(**orm_context.bind_arguments) - if bind.dialect.full_returning: + if ( + (cls == BulkORMDelete and bind.dialect.delete_returning) or + bind.dialect.full_returning + ): return _result.null_result() else: return None @@ -2081,7 +2084,7 @@ class BulkUDCompileState(CompileState): params, execution_options, bind_arguments, - _add_event=skip_for_full_returning, + _add_event=skip_for_returning, ) matched_rows = result.fetchall() @@ -2311,10 +2314,8 @@ class BulkORMDelete(DeleteDMLState, BulkUDCompileState): statement = statement.where(*new_crit) if ( - mapper - and compiler._annotations.get("synchronize_session", None) - == "fetch" - and compiler.dialect.full_returning + mapper and compiler.dialect.delete_returning and + compiler._annotations.get("synchronize_session", None) == "fetch" ): statement = statement.returning(*mapper.primary_key) diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py index f6e79042c9..ad866c851a 100644 --- a/lib/sqlalchemy/testing/requirements.py +++ b/lib/sqlalchemy/testing/requirements.py @@ -359,6 +359,15 @@ class SuiteRequirements(Requirements): return exclusions.open() + @property + def insert_returning(self): + """target platform supports INSERT ... RETURNING.""" + + return exclusions.only_if( + lambda config: config.db.dialect.insert_returning, + "%(database)s %(does_support)s 'INSERT ... RETURNING'", + ) + @property def full_returning(self): """target platform supports RETURNING completely, including diff --git a/test/orm/test_update_delete.py b/test/orm/test_update_delete.py index 54a9d163dd..3c25d3043f 100644 --- a/test/orm/test_update_delete.py +++ b/test/orm/test_update_delete.py @@ -973,7 +973,7 @@ class UpdateDeleteTest(fixtures.MappedTest): synchronize_session="fetch" ) - if testing.db.dialect.full_returning: + if testing.db.dialect.delete_returning: asserter.assert_( CompiledSQL( "DELETE FROM users WHERE users.age_int > %(age_int_1)s " @@ -1018,7 +1018,7 @@ class UpdateDeleteTest(fixtures.MappedTest): stmt, execution_options={"synchronize_session": "fetch"} ) - if testing.db.dialect.full_returning: + if testing.db.dialect.delete_returning: asserter.assert_( CompiledSQL( "DELETE FROM users WHERE users.age_int > %(age_int_1)s " @@ -2084,7 +2084,7 @@ class SingleTablePolymorphicTest(fixtures.DeclarativeMappedTest): class LoadFromReturningTest(fixtures.MappedTest): __backend__ = True - __requires__ = ("full_returning",) + __requires__ = ("insert_returning",) @classmethod def define_tables(cls, metadata): @@ -2133,6 +2133,7 @@ class LoadFromReturningTest(fixtures.MappedTest): }, ) + @testing.requires.full_returning def test_load_from_update(self, connection): User = self.classes.User -- 2.47.3