implicit_returning = True
full_returning = True
+ insert_returning = True
+ delete_returning = True
colspecs = {
sqltypes.DateTime: _MSDateTime,
from ...sql import coercions
from ...sql import compiler
from ...sql import elements
+from ...sql import expression
from ...sql import functions
from ...sql import operators
from ...sql import roles
return tmp
+ def returning_clause(self, stmt, returning_cols):
+ columns = [
+ self._label_returning_column(stmt, c)
+ for c in expression._select_iterables(returning_cols)
+ ]
+
+ return "RETURNING " + ", ".join(columns)
+
def limit_clause(self, select, **kw):
# MySQL supports:
# LIMIT <limit>
server_version_info = tuple(version)
- self._set_mariadb(server_version_info and is_mariadb, val)
+ self._set_mariadb(server_version_info and is_mariadb,
+ server_version_info)
if not is_mariadb:
self._mariadb_normalized_version_info = server_version_info
if not is_mariadb and self.is_mariadb:
raise exc.InvalidRequestError(
"MySQL version %s is not a MariaDB variant."
- % (server_version_info,)
+ % ('.'.join(map(str, server_version_info)),)
)
self.is_mariadb = is_mariadb
+ if server_version_info is not None:
+ if server_version_info >= (10, 5):
+ self.insert_returning = True
+ if server_version_info >= (10, 0, 5):
+ self.delete_returning = True
def do_begin_twophase(self, connection, xid):
connection.execute(sql.text("XA BEGIN :xid"), dict(xid=xid))
not self.is_mariadb and self.server_version_info >= (8,)
)
+ self.delete_returning = (
+ self.is_mariadb and self.server_version_info >= (10, 0, 5)
+ )
+
+ self.insert_returning = (
+ self.is_mariadb and self.server_version_info >= (10, 5)
+ )
+
self._warn_for_known_db_issues()
def _warn_for_known_db_issues(self):
implicit_returning = True
full_returning = True
+ delete_returning = True
+ insert_returning = True
connection_characteristics = (
default.DefaultDialect.connection_characteristics
if self.server_version_info <= (8, 2):
self.full_returning = self.implicit_returning = False
+ self.delete_returning = self.insert_returning = False
self.supports_native_enum = self.server_version_info >= (8, 3)
if not self.supports_native_enum:
postfetch_lastrowid = True
implicit_returning = False
full_returning = False
+ delete_returning = False
+ insert_returning = False
insert_executemany_returning = False
cte_follows_insert = False
)
select_stmt._where_criteria = statement._where_criteria
- def skip_for_full_returning(orm_context):
+ def skip_for_returning(orm_context):
bind = orm_context.session.get_bind(**orm_context.bind_arguments)
- if bind.dialect.full_returning:
+ if (
+ (cls == BulkORMDelete and bind.dialect.delete_returning) or
+ bind.dialect.full_returning
+ ):
return _result.null_result()
else:
return None
params,
execution_options,
bind_arguments,
- _add_event=skip_for_full_returning,
+ _add_event=skip_for_returning,
)
matched_rows = result.fetchall()
statement = statement.where(*new_crit)
if (
- mapper
- and compiler._annotations.get("synchronize_session", None)
- == "fetch"
- and compiler.dialect.full_returning
+ mapper and compiler.dialect.delete_returning and
+ compiler._annotations.get("synchronize_session", None) == "fetch"
):
statement = statement.returning(*mapper.primary_key)
return exclusions.open()
+ @property
+ def insert_returning(self):
+ """target platform supports INSERT ... RETURNING."""
+
+ return exclusions.only_if(
+ lambda config: config.db.dialect.insert_returning,
+ "%(database)s %(does_support)s 'INSERT ... RETURNING'",
+ )
+
@property
def full_returning(self):
"""target platform supports RETURNING completely, including
synchronize_session="fetch"
)
- if testing.db.dialect.full_returning:
+ if testing.db.dialect.delete_returning:
asserter.assert_(
CompiledSQL(
"DELETE FROM users WHERE users.age_int > %(age_int_1)s "
stmt, execution_options={"synchronize_session": "fetch"}
)
- if testing.db.dialect.full_returning:
+ if testing.db.dialect.delete_returning:
asserter.assert_(
CompiledSQL(
"DELETE FROM users WHERE users.age_int > %(age_int_1)s "
class LoadFromReturningTest(fixtures.MappedTest):
__backend__ = True
- __requires__ = ("full_returning",)
+ __requires__ = ("insert_returning",)
@classmethod
def define_tables(cls, metadata):
},
)
+ @testing.requires.full_returning
def test_load_from_update(self, connection):
User = self.classes.User