]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add insert/delete returning for MariaDB
authorDaniel Black <daniel@mariadb.org>
Thu, 9 Sep 2021 08:55:01 +0000 (18:55 +1000)
committerDaniel Black <daniel@mariadb.org>
Sat, 25 Sep 2021 06:23:19 +0000 (16:23 +1000)
As MariaDB doesn't support update inserting the
full_returning is complimented with insert_returning
and delete_returning.

Fixes: #7011
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/orm/persistence.py
lib/sqlalchemy/testing/requirements.py
test/orm/test_update_delete.py

index 7946633eb5861a20e275a53d4f2d0f62498192d3..4a517d7b0966c1a1e3f0cc99a7333859957bd382 100644 (file)
@@ -2601,6 +2601,8 @@ class MSDialect(default.DefaultDialect):
 
     implicit_returning = True
     full_returning = True
+    insert_returning = True
+    delete_returning = True
 
     colspecs = {
         sqltypes.DateTime: _MSDateTime,
index 2bba2f81a7249980748bccf376d1a90900ce35f0..c8e204895c348aa43754e0f8a15b1eb3f0f9a852 100644 (file)
@@ -991,6 +991,7 @@ from ...engine import reflection
 from ...sql import coercions
 from ...sql import compiler
 from ...sql import elements
+from ...sql import expression
 from ...sql import functions
 from ...sql import operators
 from ...sql import roles
@@ -1797,6 +1798,14 @@ class MySQLCompiler(compiler.SQLCompiler):
 
         return tmp
 
+    def returning_clause(self, stmt, returning_cols):
+        columns = [
+            self._label_returning_column(stmt, c)
+            for c in expression._select_iterables(returning_cols)
+        ]
+
+        return "RETURNING " + ", ".join(columns)
+
     def limit_clause(self, select, **kw):
         # MySQL supports:
         #   LIMIT <limit>
@@ -2776,7 +2785,8 @@ class MySQLDialect(default.DefaultDialect):
 
         server_version_info = tuple(version)
 
-        self._set_mariadb(server_version_info and is_mariadb, val)
+        self._set_mariadb(server_version_info and is_mariadb,
+                          server_version_info)
 
         if not is_mariadb:
             self._mariadb_normalized_version_info = server_version_info
@@ -2798,9 +2808,14 @@ class MySQLDialect(default.DefaultDialect):
         if not is_mariadb and self.is_mariadb:
             raise exc.InvalidRequestError(
                 "MySQL version %s is not a MariaDB variant."
-                % (server_version_info,)
+                % ('.'.join(map(str, server_version_info)),)
             )
         self.is_mariadb = is_mariadb
+        if server_version_info is not None:
+            if server_version_info >= (10, 5):
+                self.insert_returning = True
+            if server_version_info >= (10, 0, 5):
+                self.delete_returning = True
 
     def do_begin_twophase(self, connection, xid):
         connection.execute(sql.text("XA BEGIN :xid"), dict(xid=xid))
@@ -2975,6 +2990,14 @@ class MySQLDialect(default.DefaultDialect):
             not self.is_mariadb and self.server_version_info >= (8,)
         )
 
+        self.delete_returning = (
+            self.is_mariadb and self.server_version_info >= (10, 0, 5)
+        )
+
+        self.insert_returning = (
+            self.is_mariadb and self.server_version_info >= (10, 5)
+        )
+
         self._warn_for_known_db_issues()
 
     def _warn_for_known_db_issues(self):
index f33542ee80f304263131a14de6cd20bb1fcedebc..da550d34e2d3c43e65aba8aa1d9f08077e5a304d 100644 (file)
@@ -3127,6 +3127,8 @@ class PGDialect(default.DefaultDialect):
 
     implicit_returning = True
     full_returning = True
+    delete_returning = True
+    insert_returning = True
 
     connection_characteristics = (
         default.DefaultDialect.connection_characteristics
@@ -3191,6 +3193,7 @@ class PGDialect(default.DefaultDialect):
 
         if self.server_version_info <= (8, 2):
             self.full_returning = self.implicit_returning = False
+            self.delete_returning = self.insert_returning = False
 
         self.supports_native_enum = self.server_version_info >= (8, 3)
         if not self.supports_native_enum:
index eff28e34008a4e89e58f3d894d324f4ef3385115..dab2dfd6b4a4f72408611c4c3b0f0130b6359583 100644 (file)
@@ -78,6 +78,8 @@ class DefaultDialect(interfaces.Dialect):
     postfetch_lastrowid = True
     implicit_returning = False
     full_returning = False
+    delete_returning = False
+    insert_returning = False
     insert_executemany_returning = False
 
     cte_follows_insert = False
index fd484b52b30df3d690df97aefb0f5480bba94037..2133f767bcf9d75e5e669b4d8b92f3a8eb6b029b 100644 (file)
@@ -2069,9 +2069,12 @@ class BulkUDCompileState(CompileState):
         )
         select_stmt._where_criteria = statement._where_criteria
 
-        def skip_for_full_returning(orm_context):
+        def skip_for_returning(orm_context):
             bind = orm_context.session.get_bind(**orm_context.bind_arguments)
-            if bind.dialect.full_returning:
+            if (
+                (cls == BulkORMDelete and bind.dialect.delete_returning) or
+                bind.dialect.full_returning
+            ):
                 return _result.null_result()
             else:
                 return None
@@ -2081,7 +2084,7 @@ class BulkUDCompileState(CompileState):
             params,
             execution_options,
             bind_arguments,
-            _add_event=skip_for_full_returning,
+            _add_event=skip_for_returning,
         )
         matched_rows = result.fetchall()
 
@@ -2311,10 +2314,8 @@ class BulkORMDelete(DeleteDMLState, BulkUDCompileState):
             statement = statement.where(*new_crit)
 
         if (
-            mapper
-            and compiler._annotations.get("synchronize_session", None)
-            == "fetch"
-            and compiler.dialect.full_returning
+            mapper and compiler.dialect.delete_returning and
+            compiler._annotations.get("synchronize_session", None) == "fetch"
         ):
             statement = statement.returning(*mapper.primary_key)
 
index f6e79042c99617a133b294b67be0867c2e871536..ad866c851a4c79fb64726c40e3569f7250d495a4 100644 (file)
@@ -359,6 +359,15 @@ class SuiteRequirements(Requirements):
 
         return exclusions.open()
 
+    @property
+    def insert_returning(self):
+        """target platform supports INSERT ... RETURNING."""
+
+        return exclusions.only_if(
+            lambda config: config.db.dialect.insert_returning,
+            "%(database)s %(does_support)s 'INSERT ... RETURNING'",
+        )
+
     @property
     def full_returning(self):
         """target platform supports RETURNING completely, including
index 54a9d163dddc8b57d0510ae1e7a55d83083274cf..3c25d3043f82043b1ac99910a0709cba9d7bb177 100644 (file)
@@ -973,7 +973,7 @@ class UpdateDeleteTest(fixtures.MappedTest):
                 synchronize_session="fetch"
             )
 
-        if testing.db.dialect.full_returning:
+        if testing.db.dialect.delete_returning:
             asserter.assert_(
                 CompiledSQL(
                     "DELETE FROM users WHERE users.age_int > %(age_int_1)s "
@@ -1018,7 +1018,7 @@ class UpdateDeleteTest(fixtures.MappedTest):
                 stmt, execution_options={"synchronize_session": "fetch"}
             )
 
-        if testing.db.dialect.full_returning:
+        if testing.db.dialect.delete_returning:
             asserter.assert_(
                 CompiledSQL(
                     "DELETE FROM users WHERE users.age_int > %(age_int_1)s "
@@ -2084,7 +2084,7 @@ class SingleTablePolymorphicTest(fixtures.DeclarativeMappedTest):
 
 class LoadFromReturningTest(fixtures.MappedTest):
     __backend__ = True
-    __requires__ = ("full_returning",)
+    __requires__ = ("insert_returning",)
 
     @classmethod
     def define_tables(cls, metadata):
@@ -2133,6 +2133,7 @@ class LoadFromReturningTest(fixtures.MappedTest):
             },
         )
 
+    @testing.requires.full_returning
     def test_load_from_update(self, connection):
         User = self.classes.User