From 466ed5b53a3af83f337c93be95715e4b3ab1255e Mon Sep 17 00:00:00 2001 From: Daniel Black Date: Tue, 28 Sep 2021 14:20:06 -0400 Subject: [PATCH] Generalize RETURNING and suppor for MariaDB / SQLite As almost every dialect supports RETURNING now, RETURNING is also made more of a default assumption. * the default compiler generates a RETURNING clause now when specified; CompileError is no longer raised. * The dialect-level implicit_returning parameter now has no effect. It's not fully clear if there are real world cases relying on the dialect-level parameter, so we will see once 2.0 is released. ORM-level RETURNING can be disabled at the table level, and perhaps "implicit returning" should become an ORM-level option at some point as that's where it applies. * Altered ORM update() / delete() to respect table-level implicit returning for fetch. * Since MariaDB doesnt support UPDATE returning, "full_returning" is now split into insert_returning, update_returning, delete_returning * Crazy new thing. Dialects that have *both* cursor.lastrowid *and* returning. so now we can pick between them for SQLite and mariadb. so, we are trying to keep it on .lastrowid for simple inserts with an autoincrement column, this helps with some edge case test scenarios and i bet .lastrowid is faster anyway. any return_defaults() / multiparams etc then we use returning * SQLite decided they dont want to return rows that match in ON CONFLICT. this is flat out wrong, but for now we need to work with it. Fixes: #6195 Fixes: #7011 Closes: #7047 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/7047 Pull-request-sha: d25d5ea3abe094f282c53c7dd87f5f53a9e85248 Co-authored-by: Mike Bayer Change-Id: I9908ce0ff7bdc50bd5b27722081767c31c19a950 --- doc/build/glossary.rst | 41 +++ doc/build/orm/persistence_techniques.rst | 18 +- doc/build/orm/versioning.rst | 3 +- lib/sqlalchemy/dialects/mssql/base.py | 5 +- lib/sqlalchemy/dialects/mssql/pyodbc.py | 2 + lib/sqlalchemy/dialects/mysql/base.py | 45 ++- lib/sqlalchemy/dialects/oracle/base.py | 43 +-- lib/sqlalchemy/dialects/postgresql/base.py | 21 +- lib/sqlalchemy/dialects/postgresql/psycopg.py | 2 +- .../dialects/postgresql/psycopg2.py | 2 +- lib/sqlalchemy/dialects/sqlite/base.py | 66 ++++ lib/sqlalchemy/engine/create.py | 20 +- lib/sqlalchemy/engine/cursor.py | 2 +- lib/sqlalchemy/engine/default.py | 33 +- lib/sqlalchemy/engine/interfaces.py | 30 +- lib/sqlalchemy/ext/horizontal_shard.py | 1 - lib/sqlalchemy/orm/persistence.py | 45 ++- lib/sqlalchemy/sql/compiler.py | 15 +- lib/sqlalchemy/sql/crud.py | 66 +++- lib/sqlalchemy/sql/dml.py | 32 +- lib/sqlalchemy/sql/schema.py | 11 +- lib/sqlalchemy/sql/selectable.py | 4 + lib/sqlalchemy/testing/assertsql.py | 22 +- lib/sqlalchemy/testing/fixtures.py | 14 + lib/sqlalchemy/testing/requirements.py | 42 +-- lib/sqlalchemy/testing/suite/test_insert.py | 10 +- test/dialect/oracle/test_dialect.py | 2 +- test/dialect/oracle/test_types.py | 4 +- test/engine/test_deprecations.py | 58 +--- test/orm/test_defaults.py | 6 +- test/orm/test_events.py | 2 +- test/orm/test_naturalpks.py | 4 +- test/orm/test_unitofwork.py | 7 +- test/orm/test_unitofworkv2.py | 12 +- test/orm/test_update_delete.py | 60 +++- test/orm/test_versioning.py | 231 ++++++++----- test/requirements.py | 2 +- test/sql/test_defaults.py | 21 +- test/sql/test_insert.py | 10 +- test/sql/test_insert_exec.py | 30 +- test/sql/test_returning.py | 313 ++++++++++-------- test/sql/test_sequences.py | 26 +- test/sql/test_type_expressions.py | 2 +- 43 files changed, 868 insertions(+), 517 deletions(-) diff --git a/doc/build/glossary.rst b/doc/build/glossary.rst index a54d7715e5..ccaa27e9c1 100644 --- a/doc/build/glossary.rst +++ b/doc/build/glossary.rst @@ -169,6 +169,47 @@ Glossary also known as :term:`DML`, and typically refers to the ``INSERT``, ``UPDATE``, and ``DELETE`` statements. + executemany + This term refers to a part of the :pep:`249` DBAPI specification + indicating a single SQL statement that may be invoked against a + database connection with multiple parameter sets. The specific + method is known as ``cursor.executemany()``, and it has many + behavioral differences in comparison to the ``cursor.execute()`` + method which is used for single-statement invocation. The "executemany" + method executes the given SQL statement multiple times, once for + each set of parameters passed. As such, DBAPIs generally cannot + return result sets when ``cursor.executemany()`` is used. An additional + limitation of ``cursor.executemany()`` is that database drivers which + support the ``cursor.lastrowid`` attribute, returning the most recently + inserted integer primary key value, also don't support this attribute + when using ``cursor.executemany()``. + + SQLAlchemy makes use of ``cursor.executemany()`` when the + :meth:`_engine.Connection.execute` method is used, passing a list of + parameter dictionaries, instead of just a single parameter dictionary. + When using this form, the returned :class:`_result.Result` object will + not return any rows, even if the given SQL statement uses a form such + as RETURNING. + + Since "executemany" makes it generally impossible to receive results + back that indicate the newly generated values of server-generated + identifiers, the SQLAlchemy ORM can use "executemany" style + statement invocations only in certain circumstances when INSERTing + rows; while "executemany" is generally + associated with faster performance for running many INSERT statements + at once, the SQLAlchemy ORM can only make use of it in those + circumstances where it does not need to fetch newly generated primary + key values or server side default values. Newer versions of SQLAlchemy + make use of an alternate form of INSERT which is to pass a single + VALUES clause with many parameter sets at once, which does support + RETURNING. This form is available + in SQLAlchemy Core using the :meth:`.Insert.values` method. + + .. seealso:: + + :ref:`tutorial_multiple_parameters` - tutorial introduction to + "executemany" + marshalling data marshalling The process of transforming the memory representation of an object to diff --git a/doc/build/orm/persistence_techniques.rst b/doc/build/orm/persistence_techniques.rst index 8d18ac7ceb..7ad4b307a7 100644 --- a/doc/build/orm/persistence_techniques.rst +++ b/doc/build/orm/persistence_techniques.rst @@ -35,13 +35,12 @@ expired, so that when next accessed the newly generated value will be loaded from the database. The feature also has conditional support to work in conjunction with -primary key columns. A database that supports RETURNING, e.g. PostgreSQL, -Oracle, or SQL Server, or as a special case when using SQLite with the pysqlite -driver and a single auto-increment column, a SQL expression may be assigned -to a primary key column as well. This allows both the SQL expression to -be evaluated, as well as allows any server side triggers that modify the -primary key value on INSERT, to be successfully retrieved by the ORM as -part of the object's primary key:: +primary key columns. For backends that have RETURNING support +(including Oracle, SQL Server, MariaDB 10.5, SQLite 3.35) a +SQL expression may be assigned to a primary key column as well. This allows +both the SQL expression to be evaluated, as well as allows any server side +triggers that modify the primary key value on INSERT, to be successfully +retrieved by the ORM as part of the object's primary key:: class Foo(Base): @@ -271,9 +270,8 @@ so care must be taken to use the appropriate method. The two questions to be answered are, 1. is this column part of the primary key or not, and 2. does the database support RETURNING or an equivalent, such as "OUTPUT inserted"; these are SQL phrases which return a server-generated value at the same time as the -INSERT or UPDATE statement is invoked. Databases that support RETURNING or -equivalent include PostgreSQL, Oracle, and SQL Server. Databases that do not -include SQLite and MySQL. +INSERT or UPDATE statement is invoked. RETURNING is currently supported +by PostgreSQL, Oracle, MariaDB 10.5, SQLite 3.35, and SQL Server. Case 1: non primary key, RETURNING or equivalent is supported ------------------------------------------------------------- diff --git a/doc/build/orm/versioning.rst b/doc/build/orm/versioning.rst index 7aeca08738..ffc22fae68 100644 --- a/doc/build/orm/versioning.rst +++ b/doc/build/orm/versioning.rst @@ -204,8 +204,7 @@ missed version counters:: It is *strongly recommended* that server side version counters only be used when absolutely necessary and only on backends that support :term:`RETURNING`, -e.g. PostgreSQL, Oracle, SQL Server (though SQL Server has -`major caveats `_ when triggers are used), Firebird. +currently PostgreSQL, Oracle, MariaDB 10.5, SQLite 3.35, and SQL Server. .. versionadded:: 0.9.0 diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 4295e0ed06..12f495d6e2 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -2807,8 +2807,9 @@ class MSDialect(default.DefaultDialect): max_identifier_length = 128 schema_name = "dbo" - implicit_returning = True - full_returning = True + insert_returning = True + update_returning = True + delete_returning = True colspecs = { sqltypes.DateTime: _MSDateTime, diff --git a/lib/sqlalchemy/dialects/mssql/pyodbc.py b/lib/sqlalchemy/dialects/mssql/pyodbc.py index 28cca56f7f..6d64fdc3ed 100644 --- a/lib/sqlalchemy/dialects/mssql/pyodbc.py +++ b/lib/sqlalchemy/dialects/mssql/pyodbc.py @@ -522,6 +522,8 @@ class MSDialect_pyodbc(PyODBCConnector, MSDialect): # mssql still has problems with this on Linux supports_sane_rowcount_returning = False + favor_returning_over_lastrowid = True + execution_ctx_cls = MSExecutionContext_pyodbc colspecs = util.update_copy( diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index b585ea992c..68653d9765 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -488,6 +488,37 @@ available. :class:`_mysql.match` +INSERT/DELETE...RETURNING +------------------------- + +The MariaDB dialect supports 10.5+'s ``INSERT..RETURNING`` and +``DELETE..RETURNING`` (10.0+) syntaxes. ``INSERT..RETURNING`` may be used +automatically in some cases in order to fetch newly generated identifiers in +place of the traditional approach of using ``cursor.lastrowid``, however +``cursor.lastrowid`` is currently still preferred for simple single-statement +cases for its better performance. + +To specify an explicit ``RETURNING`` clause, use the +:meth:`._UpdateBase.returning` method on a per-statement basis:: + + # INSERT..RETURNING + result = connection.execute( + table.insert(). + values(name='foo'). + returning(table.c.col1, table.c.col2) + ) + print(result.all()) + + # DELETE..RETURNING + result = connection.execute( + table.delete(). + where(table.c.name=='foo'). + returning(table.c.col1, table.c.col2) + ) + print(result.all()) + +.. versionadded:: 2.0 Added support for MariaDB RETURNING + .. _mysql_insert_on_duplicate_key_update: INSERT...ON DUPLICATE KEY UPDATE (Upsert) @@ -2500,7 +2531,9 @@ class MySQLDialect(default.DefaultDialect): server_version_info = tuple(version) - self._set_mariadb(server_version_info and is_mariadb, val) + self._set_mariadb( + server_version_info and is_mariadb, server_version_info + ) if not is_mariadb: self._mariadb_normalized_version_info = server_version_info @@ -2522,7 +2555,7 @@ class MySQLDialect(default.DefaultDialect): if not is_mariadb and self.is_mariadb: raise exc.InvalidRequestError( "MySQL version %s is not a MariaDB variant." - % (server_version_info,) + % (".".join(map(str, server_version_info)),) ) if is_mariadb: self.preparer = MariaDBIdentifierPreparer @@ -2717,6 +2750,14 @@ class MySQLDialect(default.DefaultDialect): not self.is_mariadb and self.server_version_info >= (8,) ) + self.delete_returning = ( + self.is_mariadb and self.server_version_info >= (10, 0, 5) + ) + + self.insert_returning = ( + self.is_mariadb and self.server_version_info >= (10, 5) + ) + self._warn_for_known_db_issues() def _warn_for_known_db_issues(self): diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index 37b81e1dd1..faac0deb74 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -293,40 +293,16 @@ added in a future release. RETURNING Support ----------------- -The Oracle database supports a limited form of RETURNING, in order to retrieve -result sets of matched rows from INSERT, UPDATE and DELETE statements. -Oracle's RETURNING..INTO syntax only supports one row being returned, as it -relies upon OUT parameters in order to function. In addition, supported -DBAPIs have further limitations (see :ref:`cx_oracle_returning`). +The Oracle database supports RETURNING fully for INSERT, UPDATE and DELETE +statements that are invoked with a single collection of bound parameters +(that is, a ``cursor.execute()`` style statement; SQLAlchemy does not generally +support RETURNING with :term:`executemany` statements). Multiple rows may be +returned as well. -SQLAlchemy's "implicit returning" feature, which employs RETURNING within an -INSERT and sometimes an UPDATE statement in order to fetch newly generated -primary key values and other SQL defaults and expressions, is normally enabled -on the Oracle backend. By default, "implicit returning" typically only -fetches the value of a single ``nextval(some_seq)`` expression embedded into -an INSERT in order to increment a sequence within an INSERT statement and get -the value back at the same time. To disable this feature across the board, -specify ``implicit_returning=False`` to :func:`_sa.create_engine`:: +.. versionchanged:: 2.0 the Oracle backend has full support for RETURNING + on parity with other backends. - engine = create_engine("oracle+cx_oracle://scott:tiger@dsn", - implicit_returning=False) -Implicit returning can also be disabled on a table-by-table basis as a table -option:: - - # Core Table - my_table = Table("my_table", metadata, ..., implicit_returning=False) - - - # declarative - class MyClass(Base): - __tablename__ = 'my_table' - __table_args__ = {"implicit_returning": False} - -.. seealso:: - - :ref:`cx_oracle_returning` - additional cx_oracle-specific restrictions on - implicit returning. ON UPDATE CASCADE ----------------- @@ -1572,8 +1548,9 @@ class OracleDialect(default.DefaultDialect): supports_alter = True max_identifier_length = 128 - implicit_returning = True - full_returning = True + insert_returning = True + update_returning = True + delete_returning = True div_is_floordiv = False diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 146e59c4d1..83e46151f0 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -44,8 +44,6 @@ subsequent insert. Note that when an apply; no RETURNING clause is emitted nor is the sequence pre-executed in this case. -To force the usage of RETURNING by default off, specify the flag -``implicit_returning=False`` to :func:`_sa.create_engine`. PostgreSQL 10 and above IDENTITY columns ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -2351,16 +2349,6 @@ class PGCompiler(compiler.SQLCompiler): return tmp - def returning_clause( - self, stmt, returning_cols, *, populate_result_map, **kw - ): - columns = [ - self._label_returning_column(stmt, c, populate_result_map) - for c in expression._select_iterables(returning_cols) - ] - - return "RETURNING " + ", ".join(columns) - def visit_substring_func(self, func, **kw): s = self.process(func.clauses.clauses[0], **kw) start = self.process(func.clauses.clauses[1], **kw) @@ -3207,8 +3195,9 @@ class PGDialect(default.DefaultDialect): execution_ctx_cls = PGExecutionContext inspector = PGInspector - implicit_returning = True - full_returning = True + update_returning = True + delete_returning = True + insert_returning = True connection_characteristics = ( default.DefaultDialect.connection_characteristics @@ -3274,7 +3263,9 @@ class PGDialect(default.DefaultDialect): super(PGDialect, self).initialize(connection) if self.server_version_info <= (8, 2): - self.full_returning = self.implicit_returning = False + self.delete_returning = ( + self.update_returning + ) = self.insert_returning = False self.supports_native_enum = self.server_version_info >= (8, 3) if not self.supports_native_enum: diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg.py b/lib/sqlalchemy/dialects/postgresql/psycopg.py index 7ec26cb4ec..90bae61e1a 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg.py @@ -256,7 +256,7 @@ class PGDialect_psycopg(_PGDialect_common_psycopg): # PGDialect.initialize() checks server version for <= 8.2 and sets # this flag to False if so - if not self.full_returning: + if not self.insert_returning: self.insert_executemany_returning = False # HSTORE can't be registered until we have a connection so that diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py index f5d84a5a35..3f4ee2a20a 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py @@ -613,7 +613,7 @@ class PGDialect_psycopg2(_PGDialect_common_psycopg): # PGDialect.initialize() checks server version for <= 8.2 and sets # this flag to False if so - if not self.full_returning: + if not self.insert_returning: self.insert_executemany_returning = False self.executemany_mode = EXECUTEMANY_PLAIN diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index 2ce2984368..fdcd1340bb 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -221,6 +221,46 @@ by *not even emitting BEGIN* until the first write operation. :ref:`dbapi_autocommit` +INSERT/UPDATE/DELETE...RETURNING +--------------------------------- + +The SQLite dialect supports SQLite 3.35's ``INSERT|UPDATE|DELETE..RETURNING`` +syntax. ``INSERT..RETURNING`` may be used +automatically in some cases in order to fetch newly generated identifiers in +place of the traditional approach of using ``cursor.lastrowid``, however +``cursor.lastrowid`` is currently still preferred for simple single-statement +cases for its better performance. + +To specify an explicit ``RETURNING`` clause, use the +:meth:`._UpdateBase.returning` method on a per-statement basis:: + + # INSERT..RETURNING + result = connection.execute( + table.insert(). + values(name='foo'). + returning(table.c.col1, table.c.col2) + ) + print(result.all()) + + # UPDATE..RETURNING + result = connection.execute( + table.update(). + where(table.c.name=='foo'). + values(name='bar'). + returning(table.c.col1, table.c.col2) + ) + print(result.all()) + + # DELETE..RETURNING + result = connection.execute( + table.delete(). + where(table.c.name=='foo'). + returning(table.c.col1, table.c.col2) + ) + print(result.all()) + +.. versionadded:: 2.0 Added support for SQLite RETURNING + SAVEPOINT Support ---------------------------- @@ -1280,6 +1320,19 @@ class SQLiteCompiler(compiler.SQLCompiler): "%s is not a valid extract argument." % extract.field ) from err + def returning_clause( + self, + stmt, + returning_cols, + *, + populate_result_map, + **kw, + ): + kw["include_table"] = False + return super().returning_clause( + stmt, returning_cols, populate_result_map=populate_result_map, **kw + ) + def limit_clause(self, select, **kw): text = "" if select._limit_clause is not None: @@ -1372,6 +1425,11 @@ class SQLiteCompiler(compiler.SQLCompiler): return target_text + def visit_insert(self, insert_stmt, **kw): + if insert_stmt._post_values_clause is not None: + kw["disable_implicit_returning"] = True + return super().visit_insert(insert_stmt, **kw) + def visit_on_conflict_do_nothing(self, on_conflict, **kw): target_text = self._on_conflict_target(on_conflict, **kw) @@ -1831,6 +1889,9 @@ class SQLiteDialect(default.DefaultDialect): supports_default_values = True supports_default_metavalue = False + # https://github.com/python/cpython/issues/93421 + supports_sane_rowcount_returning = False + supports_empty_insert = False supports_cast = True supports_multivalues_insert = True @@ -1944,6 +2005,11 @@ class SQLiteDialect(default.DefaultDialect): 14, ) + if self.dbapi.sqlite_version_info >= (3, 35): + self.update_returning = ( + self.delete_returning + ) = self.insert_returning = True + _isolation_lookup = util.immutabledict( {"READ UNCOMMITTED": 1, "SERIALIZABLE": 0} ) diff --git a/lib/sqlalchemy/engine/create.py b/lib/sqlalchemy/engine/create.py index 68a6b81e2c..36119ab242 100644 --- a/lib/sqlalchemy/engine/create.py +++ b/lib/sqlalchemy/engine/create.py @@ -57,7 +57,7 @@ def create_engine( execution_options: _ExecuteOptions = ..., future: Literal[True], hide_parameters: bool = ..., - implicit_returning: bool = ..., + implicit_returning: Literal[True] = ..., isolation_level: _IsolationLevel = ..., json_deserializer: Callable[..., Any] = ..., json_serializer: Callable[..., Any] = ..., @@ -266,18 +266,12 @@ def create_engine(url: Union[str, "_url.URL"], **kwargs: Any) -> Engine: :ref:`dbengine_logging` - further detail on how to configure logging. - :param implicit_returning=True: Legacy flag that when set to ``False`` - will disable the use of ``RETURNING`` on supporting backends where it - would normally be used to fetch newly generated primary key values for - single-row INSERT statements that do not otherwise specify a RETURNING - clause. This behavior applies primarily to the PostgreSQL, Oracle, - SQL Server backends. - - .. warning:: this flag originally allowed the "implicit returning" - feature to be *enabled* back when it was very new and there was not - well-established database support. In modern SQLAlchemy, this flag - should **always be set to True**. Some SQLAlchemy features will - fail to function properly if this flag is set to ``False``. + :param implicit_returning=True: Legacy parameter that may only be set + to True. In SQLAlchemy 2.0, this parameter does nothing. In order to + disable "implicit returning" for statements invoked by the ORM, + configure this on a per-table basis using the + :paramref:`.Table.implicit_returning` parameter. + :param isolation_level: optional string name of an isolation level which will be set on all new connections unconditionally. diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py index ec1e1abe18..7947456afe 100644 --- a/lib/sqlalchemy/engine/cursor.py +++ b/lib/sqlalchemy/engine/cursor.py @@ -1817,7 +1817,7 @@ class CursorResult(Result[_T]): def merge(self, *others: Result[Any]) -> MergedResult[Any]: merged_result = super().merge(*others) - setup_rowcounts = not self._metadata.returns_rows + setup_rowcounts = self.context._has_rowcount if setup_rowcounts: merged_result.rowcount = sum( cast("CursorResult[Any]", result).rowcount diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index bcbe83f3fd..6b76601ffe 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -57,6 +57,7 @@ from ..sql.compiler import DDLCompiler from ..sql.compiler import SQLCompiler from ..sql.elements import quoted_name from ..sql.schema import default_is_scalar +from ..util.typing import Literal if typing.TYPE_CHECKING: from types import ModuleType @@ -135,9 +136,11 @@ class DefaultDialect(Dialect): preexecute_autoincrement_sequences = False supports_identity_columns = False postfetch_lastrowid = True + favor_returning_over_lastrowid = False insert_null_pk_still_autoincrements = False - implicit_returning = False - full_returning = False + update_returning = False + delete_returning = False + insert_returning = False insert_executemany_returning = False cte_follows_insert = False @@ -258,7 +261,7 @@ class DefaultDialect(Dialect): paramstyle: Optional[_ParamStyle] = None, isolation_level: Optional[_IsolationLevel] = None, dbapi: Optional[ModuleType] = None, - implicit_returning: Optional[bool] = None, + implicit_returning: Literal[True] = True, supports_native_boolean: Optional[bool] = None, max_identifier_length: Optional[int] = None, label_length: Optional[int] = None, @@ -296,8 +299,6 @@ class DefaultDialect(Dialect): self.paramstyle = self.dbapi.paramstyle else: self.paramstyle = self.default_paramstyle - if implicit_returning is not None: - self.implicit_returning = implicit_returning self.positional = self.paramstyle in ("qmark", "format", "numeric") self.identifier_preparer = self.preparer(self) self._on_connect_isolation_level = isolation_level @@ -324,6 +325,18 @@ class DefaultDialect(Dialect): self.label_length = label_length self.compiler_linting = compiler_linting + @util.deprecated_property( + "2.0", + "full_returning is deprecated, please use insert_returning, " + "update_returning, delete_returning", + ) + def full_returning(self): + return ( + self.insert_returning + and self.update_returning + and self.delete_returning + ) + @util.memoized_property def loaded_dbapi(self) -> ModuleType: if self.dbapi is None: @@ -771,7 +784,6 @@ class StrCompileDialect(DefaultDialect): supports_sequences = True sequences_optional = True preexecute_autoincrement_sequences = False - implicit_returning = False supports_native_boolean = True @@ -806,6 +818,8 @@ class DefaultExecutionContext(ExecutionContext): _soft_closed = False + _has_rowcount = False + # a hook for SQLite's translation of # result column names # NOTE: pyhive is using this hook, can't remove it :( @@ -1450,6 +1464,7 @@ class DefaultExecutionContext(ExecutionContext): # is testing this, and psycopg will no longer return # rowcount after cursor is closed. result.rowcount + self._has_rowcount = True row = result.fetchone() if row is not None: @@ -1465,7 +1480,12 @@ class DefaultExecutionContext(ExecutionContext): # no results, get rowcount # (which requires open cursor on some drivers) result.rowcount + self._has_rowcount = True result._soft_close() + elif self.isupdate or self.isdelete: + result.rowcount + self._has_rowcount = True + return result @util.memoized_property @@ -1479,7 +1499,6 @@ class DefaultExecutionContext(ExecutionContext): getter = cast( SQLCompiler, self.compiled )._inserted_primary_key_from_lastrowid_getter - lastrowid = self.get_lastrowid() return [getter(lastrowid, self.compiled_parameters[0])] diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py index 4020af354b..cd6efb904d 100644 --- a/lib/sqlalchemy/engine/interfaces.py +++ b/lib/sqlalchemy/engine/interfaces.py @@ -737,14 +737,32 @@ class Dialect(EventTarget): PostgreSQL. """ - implicit_returning: bool - """For dialects that support RETURNING, indicate RETURNING may be used - to fetch newly generated primary key values and other defaults from - an INSERT statement automatically. + insert_returning: bool + """if the dialect supports RETURNING with INSERT - .. seealso:: + .. versionadded:: 2.0 + + """ + + update_returning: bool + """if the dialect supports RETURNING with UPDATE + + .. versionadded:: 2.0 + + """ + + delete_returning: bool + """if the dialect supports RETURNING with DELETE + + .. versionadded:: 2.0 + + """ + + favor_returning_over_lastrowid: bool + """for backends that support both a lastrowid and a RETURNING insert + strategy, favor RETURNING for simple single-int pk inserts. - :paramref:`_schema.Table.implicit_returning` + cursor.lastrowid tends to be more performant on most backends. """ diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py index 5588fd5870..7afe2343d2 100644 --- a/lib/sqlalchemy/ext/horizontal_shard.py +++ b/lib/sqlalchemy/ext/horizontal_shard.py @@ -253,5 +253,4 @@ def execute_and_instances(orm_context): for shard_id in session.execute_chooser(orm_context): result_ = iter_for_shard(shard_id, load_options, update_options) partial.append(result_) - return partial[0].merge(*partial[1:]) diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 56e7cca1ad..0c035e7cfa 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -39,6 +39,7 @@ from .. import exc as sa_exc from .. import future from .. import sql from .. import util +from ..engine import Dialect from ..engine import result as _result from ..sql import coercions from ..sql import expression @@ -57,6 +58,7 @@ from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL if TYPE_CHECKING: from .mapper import Mapper + from .session import ORMExecuteState from .session import SessionTransaction from .state import InstanceState @@ -1103,7 +1105,8 @@ def _emit_insert_statements( or ( has_all_defaults or not base_mapper.eager_defaults - or not connection.dialect.implicit_returning + or not base_mapper.local_table.implicit_returning + or not connection.dialect.insert_returning ) and has_all_pks and not hasvalue @@ -1118,7 +1121,6 @@ def _emit_insert_statements( c = connection.execute( statement, multiparams, execution_options=execution_options ) - if bookkeeping: for ( ( @@ -1802,6 +1804,10 @@ class BulkUDCompileState(CompileState): _matched_rows = None _refresh_identity_token = None + @classmethod + def can_use_returning(cls, dialect: Dialect, mapper: Mapper[Any]) -> bool: + raise NotImplementedError() + @classmethod def orm_pre_session_exec( cls, @@ -2093,9 +2099,10 @@ class BulkUDCompileState(CompileState): ) select_stmt._where_criteria = statement._where_criteria - def skip_for_full_returning(orm_context): + def skip_for_returning(orm_context: ORMExecuteState) -> Any: bind = orm_context.session.get_bind(**orm_context.bind_arguments) - if bind.dialect.full_returning: + + if cls.can_use_returning(bind.dialect, mapper): return _result.null_result() else: return None @@ -2105,7 +2112,7 @@ class BulkUDCompileState(CompileState): params, execution_options=execution_options, bind_arguments=bind_arguments, - _add_event=skip_for_full_returning, + _add_event=skip_for_returning, ) matched_rows = result.fetchall() @@ -2283,10 +2290,9 @@ class BulkORMUpdate(ORMDMLState, UpdateDMLState, BulkUDCompileState): # if we are against a lambda statement we might not be the # topmost object that received per-execute annotations - if ( - compiler._annotations.get("synchronize_session", None) == "fetch" - and compiler.dialect.full_returning - ): + if compiler._annotations.get( + "synchronize_session", None + ) == "fetch" and self.can_use_returning(compiler.dialect, mapper): if new_stmt._returning: raise sa_exc.InvalidRequestError( "Can't use synchronize_session='fetch' " @@ -2298,6 +2304,12 @@ class BulkORMUpdate(ORMDMLState, UpdateDMLState, BulkUDCompileState): return self + @classmethod + def can_use_returning(cls, dialect: Dialect, mapper: Mapper[Any]) -> bool: + return ( + dialect.update_returning and mapper.local_table.implicit_returning + ) + @classmethod def _get_crud_kv_pairs(cls, statement, kv_iterator): plugin_subject = statement._propagate_attrs["plugin_subject"] @@ -2478,18 +2490,21 @@ class BulkORMDelete(ORMDMLState, DeleteDMLState, BulkUDCompileState): if new_crit: statement = statement.where(*new_crit) - if ( - mapper - and compiler._annotations.get("synchronize_session", None) - == "fetch" - and compiler.dialect.full_returning - ): + if compiler._annotations.get( + "synchronize_session", None + ) == "fetch" and self.can_use_returning(compiler.dialect, mapper): statement = statement.returning(*mapper.primary_key) DeleteDMLState.__init__(self, statement, compiler, **kw) return self + @classmethod + def can_use_returning(cls, dialect: Dialect, mapper: Mapper[Any]) -> bool: + return ( + dialect.delete_returning and mapper.local_table.implicit_returning + ) + @classmethod def _do_post_synchronize_evaluate(cls, session, result, update_options): diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 3685751b0d..78c6af38ba 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -3482,7 +3482,7 @@ class SQLCompiler(Compiled): ) def _label_returning_column( - self, stmt, column, populate_result_map, column_clause_args=None + self, stmt, column, populate_result_map, column_clause_args=None, **kw ): """Render a column with necessary labels inside of a RETURNING clause. @@ -3499,6 +3499,7 @@ class SQLCompiler(Compiled): populate_result_map, False, {} if column_clause_args is None else column_clause_args, + **kw, ) def _label_select_column( @@ -3514,6 +3515,7 @@ class SQLCompiler(Compiled): within_columns_clause=True, column_is_repeated=False, need_column_expressions=False, + include_table=True, ): """produce labeled columns present in a select().""" impl = column.type.dialect_impl(self.dialect) @@ -3661,6 +3663,7 @@ class SQLCompiler(Compiled): column_clause_args.update( within_columns_clause=within_columns_clause, add_to_result_map=add_to_result_map, + include_table=include_table, ) return result_expr._compiler_dispatch(self, **column_clause_args) @@ -4218,10 +4221,12 @@ class SQLCompiler(Compiled): populate_result_map: bool, **kw: Any, ) -> str: - raise exc.CompileError( - "RETURNING is not supported by this " - "dialect's statement compiler." - ) + columns = [ + self._label_returning_column(stmt, c, populate_result_map, **kw) + for c in base._select_iterables(returning_cols) + ] + + return "RETURNING " + ", ".join(columns) def limit_clause(self, select, **kw): text = "" diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index 913e4d4333..81151a26b7 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -568,6 +568,7 @@ def _scan_cols( _col_bind_name, implicit_returning, implicit_return_defaults, + postfetch_lastrowid, values, autoincrement_col, insert_null_pk_still_autoincrements, @@ -649,6 +650,7 @@ def _append_param_parameter( _col_bind_name, implicit_returning, implicit_return_defaults, + postfetch_lastrowid, values, autoincrement_col, insert_null_pk_still_autoincrements, @@ -668,11 +670,12 @@ def _append_param_parameter( and c is autoincrement_col ): # support use case for #7998, fetch autoincrement cols - # even if value was given - if implicit_returning: - compiler.implicit_returning.append(c) - elif compiler.dialect.postfetch_lastrowid: + # even if value was given. + + if postfetch_lastrowid: compiler.postfetch_lastrowid = True + elif implicit_returning: + compiler.implicit_returning.append(c) value = _create_bind_param( compiler, @@ -1281,7 +1284,12 @@ def _get_stmt_parameter_tuples_params( def _get_returning_modifiers(compiler, stmt, compile_state, toplevel): + """determines RETURNING strategy, if any, for the statement. + + This is where it's determined what we need to fetch from the + INSERT or UPDATE statement after it's invoked. + """ need_pks = ( toplevel and compile_state.isinsert @@ -1296,19 +1304,58 @@ def _get_returning_modifiers(compiler, stmt, compile_state, toplevel): and not stmt._returning and not compile_state._has_multi_parameters ) + + # check if we have access to simple cursor.lastrowid. we can use that + # after the INSERT if that's all we need. + postfetch_lastrowid = ( + need_pks + and compiler.dialect.postfetch_lastrowid + and stmt.table._autoincrement_column is not None + ) + + # see if we want to add RETURNING to an INSERT in order to get + # primary key columns back. This would be instead of postfetch_lastrowid + # if that's set. implicit_returning = ( + # statement itself can veto it need_pks - and compiler.dialect.implicit_returning - and stmt.table.implicit_returning + # the dialect can veto it if it just doesnt support RETURNING + # with INSERT + and compiler.dialect.insert_returning + # user-defined implicit_returning on Table can veto it + and compile_state._primary_table.implicit_returning + # the compile_state can veto it (SQlite uses this to disable + # RETURNING for an ON CONFLICT insert, as SQLite does not return + # for rows that were updated, which is wrong) + and compile_state._supports_implicit_returning + and ( + # since we support MariaDB and SQLite which also support lastrowid, + # decide if we should use lastrowid or RETURNING. for insert + # that didnt call return_defaults() and has just one set of + # parameters, we can use lastrowid. this is more "traditional" + # and a lot of weird use cases are supported by it. + # SQLite lastrowid times 3x faster than returning, + # Mariadb lastrowid 2x faster than returning + ( + not postfetch_lastrowid + or compiler.dialect.favor_returning_over_lastrowid + ) + or compile_state._has_multi_parameters + or stmt._return_defaults + ) ) + if implicit_returning: + postfetch_lastrowid = False + if compile_state.isinsert: implicit_return_defaults = implicit_returning and stmt._return_defaults elif compile_state.isupdate: implicit_return_defaults = ( - compiler.dialect.implicit_returning - and stmt.table.implicit_returning - and stmt._return_defaults + stmt._return_defaults + and compile_state._primary_table.implicit_returning + and compile_state._supports_implicit_returning + and compiler.dialect.update_returning ) else: # this line is unused, currently we are always @@ -1321,7 +1368,6 @@ def _get_returning_modifiers(compiler, stmt, compile_state, toplevel): else: implicit_return_defaults = set(stmt._return_defaults_columns) - postfetch_lastrowid = need_pks and compiler.dialect.postfetch_lastrowid return ( need_pks, implicit_returning, diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index e63a34454d..28ea512a77 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -119,6 +119,8 @@ class DMLState(CompileState): _ordered_values: Optional[List[Tuple[_DMLColumnElement, Any]]] = None _parameter_ordering: Optional[List[_DMLColumnElement]] = None _has_multi_parameters = False + _primary_table: FromClause + _supports_implicit_returning = True isupdate = False isdelete = False @@ -182,11 +184,14 @@ class DMLState(CompileState): for k, v in kv_iterator ] - def _make_extra_froms(self, statement: DMLWhereBase) -> List[FromClause]: + def _make_extra_froms( + self, statement: DMLWhereBase + ) -> Tuple[FromClause, List[FromClause]]: froms: List[FromClause] = [] all_tables = list(sql_util.tables_from_leftmost(statement.table)) - seen = {all_tables[0]} + primary_table = all_tables[0] + seen = {primary_table} for crit in statement._where_criteria: for item in _from_objects(crit): @@ -195,7 +200,7 @@ class DMLState(CompileState): seen.update(item._cloned_set) froms.extend(all_tables[1:]) - return froms + return primary_table, froms def _process_multi_values(self, statement: ValuesBase) -> None: if not statement._supports_multi_parameters: @@ -286,8 +291,18 @@ class InsertDMLState(DMLState): include_table_with_column_exprs = False - def __init__(self, statement: Insert, compiler: SQLCompiler, **kw: Any): + def __init__( + self, + statement: Insert, + compiler: SQLCompiler, + disable_implicit_returning: bool = False, + **kw: Any, + ): self.statement = statement + self._primary_table = statement.table + + if disable_implicit_returning: + self._supports_implicit_returning = False self.isinsert = True if statement._select_names: @@ -306,6 +321,7 @@ class UpdateDMLState(DMLState): def __init__(self, statement: Update, compiler: SQLCompiler, **kw: Any): self.statement = statement + self.isupdate = True if statement._ordered_values is not None: self._process_ordered_values(statement) @@ -313,7 +329,9 @@ class UpdateDMLState(DMLState): self._process_values(statement) elif statement._multi_values: self._process_multi_values(statement) - self._extra_froms = ef = self._make_extra_froms(statement) + t, ef = self._make_extra_froms(statement) + self._primary_table = t + self._extra_froms = ef self.is_multitable = mt = ef @@ -330,7 +348,9 @@ class DeleteDMLState(DMLState): self.statement = statement self.isdelete = True - self._extra_froms = self._make_extra_froms(statement) + t, ef = self._make_extra_froms(statement) + self._primary_table = t + self._extra_froms = ef SelfUpdateBase = typing.TypeVar("SelfUpdateBase", bound="UpdateBase") diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index 598bacc593..447e102ed1 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -639,10 +639,13 @@ class Table( :param implicit_returning: True by default - indicates that - RETURNING can be used by default to fetch newly inserted primary key - values, for backends which support this. Note that - :func:`_sa.create_engine` also provides an ``implicit_returning`` - flag. + RETURNING can be used, typically by the ORM, in order to fetch + server-generated values such as primary key values and + server side defaults, on those backends which support RETURNING. + + In modern SQLAlchemy there is generally no reason to alter this + setting, except in the case of some backends such as SQL Server + when INSERT triggers are used for that table. :param include_columns: A list of strings indicating a subset of columns to be loaded via the ``autoload`` operation; table columns who diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 53dcf51c77..eebefb8776 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -1635,6 +1635,10 @@ class AliasedReturnsRows(NoInit, NamedFromClause): return name + @util.ro_non_memoized_property + def implicit_returning(self): + return self.element.implicit_returning # type: ignore + @property def original(self): """Legacy for dialects that are referring to Alias.original.""" diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py index a6e3c87644..4416fe630e 100644 --- a/lib/sqlalchemy/testing/assertsql.py +++ b/lib/sqlalchemy/testing/assertsql.py @@ -67,10 +67,13 @@ class CursorSQL(SQLMatchRule): class CompiledSQL(SQLMatchRule): - def __init__(self, statement, params=None, dialect="default"): + def __init__( + self, statement, params=None, dialect="default", enable_returning=False + ): self.statement = statement self.params = params self.dialect = dialect + self.enable_returning = enable_returning def _compare_sql(self, execute_observed, received_statement): stmt = re.sub(r"[\n\t]", "", self.statement) @@ -82,14 +85,14 @@ class CompiledSQL(SQLMatchRule): # this is currently what tests are expecting # dialect.supports_default_values = True dialect.supports_default_metavalue = True + + if self.enable_returning: + dialect.insert_returning = ( + dialect.update_returning + ) = dialect.delete_returning = True return dialect else: - # ugh - if self.dialect == "postgresql": - params = {"implicit_returning": True} - else: - params = {} - return url.URL.create(self.dialect).get_dialect()(**params) + return url.URL.create(self.dialect).get_dialect()() def _received_statement(self, execute_observed): """reconstruct the statement and params in terms @@ -221,12 +224,15 @@ class CompiledSQL(SQLMatchRule): class RegexSQL(CompiledSQL): - def __init__(self, regex, params=None, dialect="default"): + def __init__( + self, regex, params=None, dialect="default", enable_returning=False + ): SQLMatchRule.__init__(self) self.regex = re.compile(regex) self.orig_regex = regex self.params = params self.dialect = dialect + self.enable_returning = enable_returning def _failure_message(self, execute_observed, expected_params): return ( diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py index ae7a42488e..d0e7d8f3cf 100644 --- a/lib/sqlalchemy/testing/fixtures.py +++ b/lib/sqlalchemy/testing/fixtures.py @@ -89,6 +89,20 @@ class TestBase: # run a close all connections. conn.close() + @config.fixture() + def close_result_when_finished(self): + to_close = [] + + def go(result): + to_close.append(result) + + yield go + for r in to_close: + try: + r.close() + except: + pass + @config.fixture() def registry(self, metadata): reg = registry( diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py index 4fff6546ec..4f9c73cf6e 100644 --- a/lib/sqlalchemy/testing/requirements.py +++ b/lib/sqlalchemy/testing/requirements.py @@ -365,15 +365,30 @@ class SuiteRequirements(Requirements): return exclusions.open() @property - def full_returning(self): - """target platform supports RETURNING completely, including - multiple rows returned. + def delete_returning(self): + """target platform supports DELETE ... RETURNING.""" - """ + return exclusions.only_if( + lambda config: config.db.dialect.delete_returning, + "%(database)s %(does_support)s 'DELETE ... RETURNING'", + ) + + @property + def insert_returning(self): + """target platform supports INSERT ... RETURNING.""" + + return exclusions.only_if( + lambda config: config.db.dialect.insert_returning, + "%(database)s %(does_support)s 'INSERT ... RETURNING'", + ) + + @property + def update_returning(self): + """target platform supports UPDATE ... RETURNING.""" return exclusions.only_if( - lambda config: config.db.dialect.full_returning, - "%(database)s %(does_support)s 'RETURNING of multiple rows'", + lambda config: config.db.dialect.update_returning, + "%(database)s %(does_support)s 'UPDATE ... RETURNING'", ) @property @@ -390,21 +405,6 @@ class SuiteRequirements(Requirements): "multiple rows with INSERT executemany'", ) - @property - def returning(self): - """target platform supports RETURNING for at least one row. - - .. seealso:: - - :attr:`.Requirements.full_returning` - - """ - - return exclusions.only_if( - lambda config: config.db.dialect.implicit_returning, - "%(database)s %(does_support)s 'RETURNING of a single row'", - ) - @property def tuple_in(self): """Target platform supports the syntax diff --git a/lib/sqlalchemy/testing/suite/test_insert.py b/lib/sqlalchemy/testing/suite/test_insert.py index f0e4bfcc6d..2307d3b3f6 100644 --- a/lib/sqlalchemy/testing/suite/test_insert.py +++ b/lib/sqlalchemy/testing/suite/test_insert.py @@ -125,10 +125,14 @@ class InsertBehaviorTest(fixtures.TablesTest): # case, the row had to have been consumed at least. assert not r.returns_rows or r.fetchone() is None - @requirements.returning + @requirements.insert_returning def test_autoclose_on_insert_implicit_returning(self, connection): r = connection.execute( - self.tables.autoinc_pk.insert(), dict(data="some data") + # return_defaults() ensures RETURNING will be used, + # new in 2.0 as sqlite/mariadb offer both RETURNING and + # cursor.lastrowid + self.tables.autoinc_pk.insert().return_defaults(), + dict(data="some data"), ) assert r._soft_closed assert not r.closed @@ -295,7 +299,7 @@ class InsertBehaviorTest(fixtures.TablesTest): class ReturningTest(fixtures.TablesTest): run_create_tables = "each" - __requires__ = "returning", "autoincrement_insert" + __requires__ = "insert_returning", "autoincrement_insert" __backend__ = True def _assert_round_trip(self, table, conn): diff --git a/test/dialect/oracle/test_dialect.py b/test/dialect/oracle/test_dialect.py index 8d74c1f489..eda0fc9867 100644 --- a/test/dialect/oracle/test_dialect.py +++ b/test/dialect/oracle/test_dialect.py @@ -626,7 +626,7 @@ class CompatFlagsTest(fixtures.TestBase, AssertsCompiledSQL): dialect.initialize(Mock()) # oracle 8 / 8i support returning - assert dialect.implicit_returning + assert dialect.insert_returning assert not dialect._supports_char_length assert not dialect.use_ansi diff --git a/test/dialect/oracle/test_types.py b/test/dialect/oracle/test_types.py index 799a5e7b65..23df01a0bc 100644 --- a/test/dialect/oracle/test_types.py +++ b/test/dialect/oracle/test_types.py @@ -229,7 +229,7 @@ class TypesTest(fixtures.TestBase): [(2, "value 2 ")], ) - @testing.requires.returning + @testing.requires.insert_returning def test_int_not_float(self, metadata, connection): m = metadata t1 = Table("t1", m, Column("foo", Integer)) @@ -243,7 +243,7 @@ class TypesTest(fixtures.TestBase): assert x == 5 assert isinstance(x, int) - @testing.requires.returning + @testing.requires.insert_returning def test_int_not_float_no_coerce_decimal(self, metadata): engine = testing_engine(options=dict(coerce_to_decimal=False)) diff --git a/test/engine/test_deprecations.py b/test/engine/test_deprecations.py index 5b723b8718..f7602f98a7 100644 --- a/test/engine/test_deprecations.py +++ b/test/engine/test_deprecations.py @@ -4,13 +4,8 @@ from unittest.mock import Mock import sqlalchemy as tsa from sqlalchemy import create_engine from sqlalchemy import event -from sqlalchemy import exc -from sqlalchemy import insert -from sqlalchemy import Integer -from sqlalchemy import MetaData from sqlalchemy import pool from sqlalchemy import select -from sqlalchemy import String from sqlalchemy import testing from sqlalchemy.engine import BindTyping from sqlalchemy.engine import reflection @@ -29,10 +24,7 @@ from sqlalchemy.testing import is_ from sqlalchemy.testing import is_instance_of from sqlalchemy.testing import mock from sqlalchemy.testing.assertions import expect_deprecated -from sqlalchemy.testing.assertions import expect_raises_message from sqlalchemy.testing.engines import testing_engine -from sqlalchemy.testing.schema import Column -from sqlalchemy.testing.schema import Table def _string_deprecation_expect(): @@ -442,55 +434,11 @@ class ImplicitReturningFlagTest(fixtures.TestBase): @testing.combinations(True, False, None, argnames="implicit_returning") def test_implicit_returning_engine_parameter(self, implicit_returning): if implicit_returning is None: - e = engines.testing_engine() + engines.testing_engine() else: with assertions.expect_deprecated(ce_implicit_returning): - e = engines.testing_engine( + engines.testing_engine( options={"implicit_returning": implicit_returning} ) - if implicit_returning is None: - eq_( - e.dialect.implicit_returning, - testing.db.dialect.implicit_returning, - ) - else: - eq_(e.dialect.implicit_returning, implicit_returning) - - t = Table( - "t", - MetaData(), - Column("id", Integer, primary_key=True), - Column("data", String(50)), - ) - - t2 = Table( - "t", - MetaData(), - Column("id", Integer, primary_key=True), - Column("data", String(50)), - implicit_returning=False, - ) - - with e.connect() as conn: - stmt = insert(t).values(data="data") - - if implicit_returning: - if not testing.requires.returning.enabled: - with expect_raises_message( - exc.CompileError, "RETURNING is not supported" - ): - stmt.compile(conn) - else: - eq_(stmt.compile(conn).implicit_returning, [t.c.id]) - elif ( - implicit_returning is None - and testing.db.dialect.implicit_returning - ): - eq_(stmt.compile(conn).implicit_returning, [t.c.id]) - else: - eq_(stmt.compile(conn).implicit_returning, []) - - # table setting it to False disables it - stmt2 = insert(t2).values(data="data") - eq_(stmt2.compile(conn).implicit_returning, []) + # parameter has no effect diff --git a/test/orm/test_defaults.py b/test/orm/test_defaults.py index fc8e455ea5..7860f5eb1d 100644 --- a/test/orm/test_defaults.py +++ b/test/orm/test_defaults.py @@ -278,7 +278,7 @@ class ComputedDefaultsOnUpdateTest(fixtures.MappedTest): asserter.assert_( Conditional( - eager and testing.db.dialect.implicit_returning, + eager and testing.db.dialect.insert_returning, [ Conditional( testing.db.dialect.insert_executemany_returning, @@ -361,7 +361,7 @@ class ComputedDefaultsOnUpdateTest(fixtures.MappedTest): eq_(t1.bar, 5 + 42) eq_(t2.bar, 6 + 42) - if eager and testing.db.dialect.implicit_returning: + if eager and testing.db.dialect.update_returning: asserter.assert_( CompiledSQL( "UPDATE test SET foo=%(foo)s " @@ -462,7 +462,7 @@ class IdentityDefaultsOnUpdateTest(fixtures.MappedTest): asserter.assert_( Conditional( - testing.db.dialect.implicit_returning, + testing.db.dialect.insert_returning, [ Conditional( testing.db.dialect.insert_executemany_returning, diff --git a/test/orm/test_events.py b/test/orm/test_events.py index be1919614d..7e1b29cb1b 100644 --- a/test/orm/test_events.py +++ b/test/orm/test_events.py @@ -3411,7 +3411,7 @@ class RefreshFlushInReturningTest(fixtures.MappedTest): s.add(t1) s.flush() - if testing.requires.returning.enabled: + if testing.requires.insert_returning.enabled: # ordering is deterministic in this test b.c. the routine # appends the "returning" params before the "prefetch" # ones. if there were more than one attribute in each category, diff --git a/test/orm/test_naturalpks.py b/test/orm/test_naturalpks.py index 0dc71f8b3a..64c033ec47 100644 --- a/test/orm/test_naturalpks.py +++ b/test/orm/test_naturalpks.py @@ -157,7 +157,7 @@ class NaturalPKTest(fixtures.MappedTest): assert sess.get(User, "jack") is None assert sess.get(User, "ed").fullname == "jack" - @testing.requires.returning + @testing.requires.update_returning def test_update_to_sql_expr(self): users, User = self.tables.users, self.classes.User @@ -169,6 +169,8 @@ class NaturalPKTest(fixtures.MappedTest): sess.add(u1) sess.flush() + # note this is the primary key, so you need UPDATE..RETURNING + # to catch this u1.username = User.username + " jones" sess.flush() diff --git a/test/orm/test_unitofwork.py b/test/orm/test_unitofwork.py index 39223a3550..881eee4dde 100644 --- a/test/orm/test_unitofwork.py +++ b/test/orm/test_unitofwork.py @@ -1214,7 +1214,7 @@ class DefaultTest(fixtures.MappedTest): session = fixture_session() session.add(h1) - if testing.db.dialect.implicit_returning: + if testing.db.dialect.insert_returning: self.sql_count_(1, session.flush) else: self.sql_count_(2, session.flush) @@ -3502,7 +3502,10 @@ class NoRowInsertedTest(fixtures.TestBase): """ __backend__ = True - __requires__ = ("returning",) + + # the test manipulates INSERTS to become UPDATES to simulate + # "INSERT that returns no row" so both are needed + __requires__ = ("insert_returning", "update_returning") @testing.fixture def null_server_default_fixture(self, registry, connection): diff --git a/test/orm/test_unitofworkv2.py b/test/orm/test_unitofworkv2.py index 7f4c046521..68099a7a0e 100644 --- a/test/orm/test_unitofworkv2.py +++ b/test/orm/test_unitofworkv2.py @@ -2409,7 +2409,7 @@ class EagerDefaultsTest(fixtures.MappedTest): s.add_all([t1, t2]) - if testing.db.dialect.implicit_returning: + if testing.db.dialect.insert_returning: self.assert_sql_execution( testing.db, s.flush, @@ -2469,7 +2469,7 @@ class EagerDefaultsTest(fixtures.MappedTest): testing.db, s.commit, Conditional( - testing.db.dialect.implicit_returning, + testing.db.dialect.insert_returning, [ Conditional( testing.db.dialect.insert_executemany_returning, @@ -2541,7 +2541,7 @@ class EagerDefaultsTest(fixtures.MappedTest): testing.db, s.flush, Conditional( - testing.db.dialect.implicit_returning, + testing.db.dialect.update_returning, [ CompiledSQL( "UPDATE test2 SET foo=%(foo)s " @@ -2633,7 +2633,7 @@ class EagerDefaultsTest(fixtures.MappedTest): t4.foo = 8 t4.bar = text("5 + 7") - if testing.db.dialect.implicit_returning: + if testing.db.dialect.update_returning: self.assert_sql_execution( testing.db, s.flush, @@ -3211,7 +3211,7 @@ class EnsureCacheTest(UOWTest): class ORMOnlyPrimaryKeyTest(fixtures.TestBase): @testing.requires.identity_columns - @testing.requires.returning + @testing.requires.insert_returning def test_a(self, base, run_test): class A(base): __tablename__ = "a" @@ -3224,7 +3224,7 @@ class ORMOnlyPrimaryKeyTest(fixtures.TestBase): run_test(A, A()) @testing.requires.sequences_as_server_defaults - @testing.requires.returning + @testing.requires.insert_returning def test_b(self, base, run_test): seq = Sequence("x_seq") diff --git a/test/orm/test_update_delete.py b/test/orm/test_update_delete.py index 427e49e5e6..22d827be97 100644 --- a/test/orm/test_update_delete.py +++ b/test/orm/test_update_delete.py @@ -9,6 +9,7 @@ from sqlalchemy import func from sqlalchemy import insert from sqlalchemy import Integer from sqlalchemy import lambda_stmt +from sqlalchemy import MetaData from sqlalchemy import or_ from sqlalchemy import select from sqlalchemy import String @@ -56,6 +57,19 @@ class UpdateDeleteTest(fixtures.MappedTest): Column("user_id", ForeignKey("users.id")), ) + m = MetaData() + users_no_returning = Table( + "users", + m, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(32)), + Column("age_int", Integer), + implicit_returning=False, + ) + cls.tables.users_no_returning = users_no_returning + @classmethod def setup_classes(cls): class User(cls.Comparable): @@ -64,6 +78,9 @@ class UpdateDeleteTest(fixtures.MappedTest): class Address(cls.Comparable): pass + class UserNoReturning(cls.Comparable): + pass + @classmethod def insert_data(cls, connection): users = cls.tables.users @@ -96,6 +113,16 @@ class UpdateDeleteTest(fixtures.MappedTest): ) cls.mapper_registry.map_imperatively(Address, addresses) + UserNoReturning = cls.classes.UserNoReturning + users_no_returning = cls.tables.users_no_returning + cls.mapper_registry.map_imperatively( + UserNoReturning, + users_no_returning, + properties={ + "age": users_no_returning.c.age_int, + }, + ) + @testing.combinations("table", "mapper", "both", argnames="bind_type") @testing.combinations( "update", "insert", "delete", argnames="statement_type" @@ -445,7 +472,7 @@ class UpdateDeleteTest(fixtures.MappedTest): {"age": User.age + 10}, synchronize_session="fetch" ) - if testing.db.dialect.full_returning: + if testing.db.dialect.update_returning: asserter.assert_( CompiledSQL( "UPDATE users SET age_int=(users.age_int + %(age_int_1)s) " @@ -857,8 +884,12 @@ class UpdateDeleteTest(fixtures.MappedTest): list(zip([25, 37, 29, 27])), ) - def test_update_fetch_returning(self): - User = self.classes.User + @testing.combinations(True, False, argnames="implicit_returning") + def test_update_fetch_returning(self, implicit_returning): + if implicit_returning: + User = self.classes.User + else: + User = self.classes.UserNoReturning sess = fixture_session() @@ -873,7 +904,7 @@ class UpdateDeleteTest(fixtures.MappedTest): # the "fetch" strategy, new in 1.4, so there is no expiry eq_([john.age, jack.age, jill.age, jane.age], [25, 37, 29, 27]) - if testing.db.dialect.full_returning: + if implicit_returning and testing.db.dialect.update_returning: asserter.assert_( CompiledSQL( "UPDATE users SET age_int=(users.age_int - %(age_int_1)s) " @@ -919,7 +950,7 @@ class UpdateDeleteTest(fixtures.MappedTest): # the "fetch" strategy, new in 1.4, so there is no expiry eq_([john.age, jack.age, jill.age, jane.age], [25, 37, 29, 27]) - if testing.db.dialect.full_returning: + if testing.db.dialect.update_returning: asserter.assert_( CompiledSQL( "UPDATE users SET age_int=(users.age_int - %(age_int_1)s) " @@ -942,7 +973,7 @@ class UpdateDeleteTest(fixtures.MappedTest): ), ) - @testing.requires.full_returning + @testing.requires.update_returning def test_update_explicit_returning(self): User = self.classes.User @@ -974,7 +1005,7 @@ class UpdateDeleteTest(fixtures.MappedTest): ), ) - @testing.requires.full_returning + @testing.requires.update_returning def test_no_fetch_w_explicit_returning(self): User = self.classes.User @@ -994,8 +1025,12 @@ class UpdateDeleteTest(fixtures.MappedTest): ): sess.execute(stmt) - def test_delete_fetch_returning(self): - User = self.classes.User + @testing.combinations(True, False, argnames="implicit_returning") + def test_delete_fetch_returning(self, implicit_returning): + if implicit_returning: + User = self.classes.User + else: + User = self.classes.UserNoReturning sess = fixture_session() @@ -1009,7 +1044,7 @@ class UpdateDeleteTest(fixtures.MappedTest): synchronize_session="fetch" ) - if testing.db.dialect.full_returning: + if implicit_returning and testing.db.dialect.delete_returning: asserter.assert_( CompiledSQL( "DELETE FROM users WHERE users.age_int > %(age_int_1)s " @@ -1054,7 +1089,7 @@ class UpdateDeleteTest(fixtures.MappedTest): stmt, execution_options={"synchronize_session": "fetch"} ) - if testing.db.dialect.full_returning: + if testing.db.dialect.delete_returning: asserter.assert_( CompiledSQL( "DELETE FROM users WHERE users.age_int > %(age_int_1)s " @@ -2148,7 +2183,7 @@ class SingleTablePolymorphicTest(fixtures.DeclarativeMappedTest): class LoadFromReturningTest(fixtures.MappedTest): __backend__ = True - __requires__ = ("full_returning",) + __requires__ = ("insert_returning",) @classmethod def define_tables(cls, metadata): @@ -2197,6 +2232,7 @@ class LoadFromReturningTest(fixtures.MappedTest): }, ) + @testing.requires.update_returning def test_load_from_update(self, connection): User = self.classes.User diff --git a/test/orm/test_versioning.py b/test/orm/test_versioning.py index 4898cb1228..abd5833bee 100644 --- a/test/orm/test_versioning.py +++ b/test/orm/test_versioning.py @@ -1347,6 +1347,7 @@ class InheritanceTwoVersionIdsTest(fixtures.MappedTest): class ServerVersioningTest(fixtures.MappedTest): run_define_tables = "each" + __backend__ = True @classmethod @@ -1432,7 +1433,7 @@ class ServerVersioningTest(fixtures.MappedTest): lambda ctx: [{"value": "f1"}], ) ] - if not testing.db.dialect.implicit_returning: + if not testing.db.dialect.insert_returning: # DBs without implicit returning, we must immediately # SELECT for the new version id statements.append( @@ -1460,34 +1461,46 @@ class ServerVersioningTest(fixtures.MappedTest): f1.value = "f2" - statements = [ - # note that the assertsql tests the rule against - # "default" - on a "returning" backend, the statement - # includes "RETURNING" - CompiledSQL( - "UPDATE version_table SET version_id=2, value=:value " - "WHERE version_table.id = :version_table_id AND " - "version_table.version_id = :version_table_version_id", - lambda ctx: [ - { - "version_table_id": 1, - "version_table_version_id": 1, - "value": "f2", - } - ], - ) - ] - if not testing.db.dialect.implicit_returning: + if testing.db.dialect.update_returning: + statements = [ + CompiledSQL( + "UPDATE version_table SET version_id=2, value=:value " + "WHERE version_table.id = :version_table_id AND " + "version_table.version_id = :version_table_version_id " + "RETURNING version_table.version_id", + lambda ctx: [ + { + "version_table_id": 1, + "version_table_version_id": 1, + "value": "f2", + } + ], + enable_returning=True, + ) + ] + else: # DBs without implicit returning, we must immediately # SELECT for the new version id - statements.append( + statements = [ + CompiledSQL( + "UPDATE version_table SET version_id=2, value=:value " + "WHERE version_table.id = :version_table_id AND " + "version_table.version_id = :version_table_version_id", + lambda ctx: [ + { + "version_table_id": 1, + "version_table_version_id": 1, + "value": "f2", + } + ], + ), CompiledSQL( "SELECT version_table.version_id " "AS version_table_version_id " "FROM version_table WHERE version_table.id = :pk_1", lambda ctx: [{"pk_1": 1}], - ) - ) + ), + ] with conditional_sane_rowcount_warnings( update=True, only_returning=True ): @@ -1512,8 +1525,9 @@ class ServerVersioningTest(fixtures.MappedTest): eq_(f1.version_id, 2) + @testing.requires.sane_rowcount_w_returning @testing.requires.updateable_autoincrement_pks - @testing.requires.returning + @testing.requires.update_returning def test_sql_expr_w_mods_bump(self): sess = self._fixture() @@ -1544,72 +1558,111 @@ class ServerVersioningTest(fixtures.MappedTest): f2.value = "f2a" f3.value = "f3a" - statements = [ - # note that the assertsql tests the rule against - # "default" - on a "returning" backend, the statement - # includes "RETURNING" - CompiledSQL( - "UPDATE version_table SET version_id=2, value=:value " - "WHERE version_table.id = :version_table_id AND " - "version_table.version_id = :version_table_version_id", - lambda ctx: [ - { - "version_table_id": 1, - "version_table_version_id": 1, - "value": "f1a", - } - ], - ), - CompiledSQL( - "UPDATE version_table SET version_id=2, value=:value " - "WHERE version_table.id = :version_table_id AND " - "version_table.version_id = :version_table_version_id", - lambda ctx: [ - { - "version_table_id": 2, - "version_table_version_id": 1, - "value": "f2a", - } - ], - ), - CompiledSQL( - "UPDATE version_table SET version_id=2, value=:value " - "WHERE version_table.id = :version_table_id AND " - "version_table.version_id = :version_table_version_id", - lambda ctx: [ - { - "version_table_id": 3, - "version_table_version_id": 1, - "value": "f3a", - } - ], - ), - ] - if not testing.db.dialect.implicit_returning: - # DBs without implicit returning, we must immediately + if testing.db.dialect.update_returning: + statements = [ + CompiledSQL( + "UPDATE version_table SET version_id=2, value=:value " + "WHERE version_table.id = :version_table_id AND " + "version_table.version_id = :version_table_version_id " + "RETURNING version_table.version_id", + lambda ctx: [ + { + "version_table_id": 1, + "version_table_version_id": 1, + "value": "f1a", + } + ], + enable_returning=True, + ), + CompiledSQL( + "UPDATE version_table SET version_id=2, value=:value " + "WHERE version_table.id = :version_table_id AND " + "version_table.version_id = :version_table_version_id " + "RETURNING version_table.version_id", + lambda ctx: [ + { + "version_table_id": 2, + "version_table_version_id": 1, + "value": "f2a", + } + ], + enable_returning=True, + ), + CompiledSQL( + "UPDATE version_table SET version_id=2, value=:value " + "WHERE version_table.id = :version_table_id AND " + "version_table.version_id = :version_table_version_id " + "RETURNING version_table.version_id", + lambda ctx: [ + { + "version_table_id": 3, + "version_table_version_id": 1, + "value": "f3a", + } + ], + enable_returning=True, + ), + ] + else: + # DBs without update returning, we must immediately # SELECT for the new version id - statements.extend( - [ - CompiledSQL( - "SELECT version_table.version_id " - "AS version_table_version_id " - "FROM version_table WHERE version_table.id = :pk_1", - lambda ctx: [{"pk_1": 1}], - ), - CompiledSQL( - "SELECT version_table.version_id " - "AS version_table_version_id " - "FROM version_table WHERE version_table.id = :pk_1", - lambda ctx: [{"pk_1": 2}], - ), - CompiledSQL( - "SELECT version_table.version_id " - "AS version_table_version_id " - "FROM version_table WHERE version_table.id = :pk_1", - lambda ctx: [{"pk_1": 3}], - ), - ] - ) + statements = [ + CompiledSQL( + "UPDATE version_table SET version_id=2, value=:value " + "WHERE version_table.id = :version_table_id AND " + "version_table.version_id = :version_table_version_id", + lambda ctx: [ + { + "version_table_id": 1, + "version_table_version_id": 1, + "value": "f1a", + } + ], + ), + CompiledSQL( + "UPDATE version_table SET version_id=2, value=:value " + "WHERE version_table.id = :version_table_id AND " + "version_table.version_id = :version_table_version_id", + lambda ctx: [ + { + "version_table_id": 2, + "version_table_version_id": 1, + "value": "f2a", + } + ], + ), + CompiledSQL( + "UPDATE version_table SET version_id=2, value=:value " + "WHERE version_table.id = :version_table_id AND " + "version_table.version_id = :version_table_version_id", + lambda ctx: [ + { + "version_table_id": 3, + "version_table_version_id": 1, + "value": "f3a", + } + ], + ), + CompiledSQL( + "SELECT version_table.version_id " + "AS version_table_version_id " + "FROM version_table WHERE version_table.id = :pk_1", + lambda ctx: [{"pk_1": 1}], + ), + CompiledSQL( + "SELECT version_table.version_id " + "AS version_table_version_id " + "FROM version_table WHERE version_table.id = :pk_1", + lambda ctx: [{"pk_1": 2}], + ), + CompiledSQL( + "SELECT version_table.version_id " + "AS version_table_version_id " + "FROM version_table WHERE version_table.id = :pk_1", + lambda ctx: [{"pk_1": 3}], + ), + ] + with conditional_sane_rowcount_warnings( update=True, only_returning=True ): @@ -1638,6 +1691,7 @@ class ServerVersioningTest(fixtures.MappedTest): with conditional_sane_rowcount_warnings(delete=True): self.assert_sql_execution(testing.db, sess.flush, *statements) + @testing.requires.independent_connections @testing.requires.sane_rowcount_w_returning def test_concurrent_mod_err_expire_on_commit(self): sess = self._fixture() @@ -1662,6 +1716,7 @@ class ServerVersioningTest(fixtures.MappedTest): sess.commit, ) + @testing.requires.independent_connections @testing.requires.sane_rowcount_w_returning def test_concurrent_mod_err_noexpire_on_commit(self): sess = self._fixture(expire_on_commit=False) diff --git a/test/requirements.py b/test/requirements.py index f5cbbbf8de..6870fba2f7 100644 --- a/test/requirements.py +++ b/test/requirements.py @@ -416,7 +416,7 @@ class DefaultRequirements(SuiteRequirements): @property def sql_expressions_inserted_as_primary_key(self): - return only_if([self.returning, self.sqlite]) + return only_if([self.insert_returning, self.sqlite]) @property def computed_columns_on_update_returning(self): diff --git a/test/sql/test_defaults.py b/test/sql/test_defaults.py index 0fa51e04c1..08911a6c56 100644 --- a/test/sql/test_defaults.py +++ b/test/sql/test_defaults.py @@ -870,7 +870,7 @@ class DefaultRoundTripTest(fixtures.TablesTest): class CTEDefaultTest(fixtures.TablesTest): - __requires__ = ("ctes", "returning", "ctes_on_dml") + __requires__ = ("ctes", "insert_returning", "ctes_on_dml") __backend__ = True @classmethod @@ -993,8 +993,11 @@ class PKDefaultTest(fixtures.TestBase): return go + @testing.crashes( + "+mariadbconnector", "https://jira.mariadb.org/browse/CONPY-206" + ) @testing.combinations( - (True, testing.requires.returning), + (True, testing.requires.insert_returning), (False,), argnames="implicit_returning", ) @@ -1278,7 +1281,7 @@ class SpecialTypePKTest(fixtures.TestBase): # we don't pre-fetch 'server_default'. if "server_default" in kw and ( - not testing.db.dialect.implicit_returning + not testing.db.dialect.insert_returning or not implicit_returning ): eq_(r.inserted_primary_key, (None,)) @@ -1321,15 +1324,18 @@ class SpecialTypePKTest(fixtures.TestBase): def test_server_default_no_autoincrement(self): self._run_test(server_default="1", autoincrement=False) + @testing.crashes( + "+mariadbconnector", "https://jira.mariadb.org/browse/CONPY-206" + ) def test_clause(self): stmt = select(cast("INT_1", type_=self.MyInteger)).scalar_subquery() self._run_test(default=stmt) - @testing.requires.returning + @testing.requires.insert_returning def test_no_implicit_returning(self): self._run_test(implicit_returning=False) - @testing.requires.returning + @testing.requires.insert_returning def test_server_default_no_implicit_returning(self): self._run_test(server_default="1", autoincrement=False) @@ -1363,7 +1369,7 @@ class ServerDefaultsOnPKTest(fixtures.TestBase): eq_(r.inserted_primary_key, (None,)) eq_(list(connection.execute(t.select())), [("key_one", "data")]) - @testing.requires.returning + @testing.requires.insert_returning @testing.provide_metadata def test_string_default_on_insert_with_returning(self, connection): """With implicit_returning, we get a string PK default back no @@ -1441,8 +1447,9 @@ class ServerDefaultsOnPKTest(fixtures.TestBase): else: eq_(list(connection.execute(t2.select())), [(5, "data")]) - @testing.requires.returning + @testing.requires.insert_returning @testing.provide_metadata + @testing.fails_on("sqlite", "sqlite doesn't like our default trick here") def test_int_default_on_insert_with_returning(self, connection): metadata = self.metadata t = Table( diff --git a/test/sql/test_insert.py b/test/sql/test_insert.py index 3a6217f671..808b047a28 100644 --- a/test/sql/test_insert.py +++ b/test/sql/test_insert.py @@ -332,10 +332,10 @@ class InsertTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL): table1 = self.tables.mytable stmt = table1.insert().returning(table1.c.myid) - assert_raises_message( - exc.CompileError, - "RETURNING is not supported by this dialect's statement compiler.", - stmt.compile, + self.assert_compile( + stmt, + "INSERT INTO mytable (myid, name, description) " + "VALUES (:myid, :name, :description) RETURNING mytable.myid", dialect=default.DefaultDialect(), ) @@ -1028,7 +1028,7 @@ class InsertImplicitReturningTest( Column("q", Integer), ) - dialect = postgresql.dialect(implicit_returning=True) + dialect = postgresql.dialect() dialect.insert_null_pk_still_autoincrements = ( insert_null_still_autoincrements ) diff --git a/test/sql/test_insert_exec.py b/test/sql/test_insert_exec.py index 45f4098b22..3e51e9450c 100644 --- a/test/sql/test_insert_exec.py +++ b/test/sql/test_insert_exec.py @@ -100,8 +100,9 @@ class InsertExecTest(fixtures.TablesTest): # verify implicit_returning is working if ( - connection.dialect.implicit_returning + connection.dialect.insert_returning and table_.implicit_returning + and not connection.dialect.postfetch_lastrowid ): ins = table_.insert() comp = ins.compile(connection, column_keys=list(values)) @@ -146,7 +147,7 @@ class InsertExecTest(fixtures.TablesTest): @testing.requires.supports_autoincrement_w_composite_pk @testing.combinations( - (True, testing.requires.returning), + (True, testing.requires.insert_returning), (False,), argnames="implicit_returning", ) @@ -173,7 +174,7 @@ class InsertExecTest(fixtures.TablesTest): @testing.requires.supports_autoincrement_w_composite_pk @testing.combinations( - (True, testing.requires.returning), + (True, testing.requires.insert_returning), (False,), argnames="implicit_returning", ) @@ -200,7 +201,7 @@ class InsertExecTest(fixtures.TablesTest): ) @testing.combinations( - (True, testing.requires.returning), + (True, testing.requires.insert_returning), (False,), argnames="implicit_returning", ) @@ -223,7 +224,7 @@ class InsertExecTest(fixtures.TablesTest): @testing.requires.sequences @testing.combinations( - (True, testing.requires.returning), + (True, testing.requires.insert_returning), (False,), argnames="implicit_returning", ) @@ -251,7 +252,7 @@ class InsertExecTest(fixtures.TablesTest): @testing.requires.sequences @testing.combinations( - (True, testing.requires.returning), + (True, testing.requires.insert_returning), (False,), argnames="implicit_returning", ) @@ -277,7 +278,7 @@ class InsertExecTest(fixtures.TablesTest): ) @testing.combinations( - (True, testing.requires.returning), + (True, testing.requires.insert_returning), (False,), argnames="implicit_returning", ) @@ -299,7 +300,7 @@ class InsertExecTest(fixtures.TablesTest): @testing.requires.supports_autoincrement_w_composite_pk @testing.combinations( - (True, testing.requires.returning), + (True, testing.requires.insert_returning), (False,), argnames="implicit_returning", ) @@ -338,6 +339,7 @@ class InsertExecTest(fixtures.TablesTest): self.metadata, Column("x", Integer, primary_key=True), Column("y", Integer), + implicit_returning=False, ) t.create(connection) with mock.patch.object( @@ -403,12 +405,16 @@ class InsertExecTest(fixtures.TablesTest): eq_(r.inserted_primary_key, (None,)) @testing.requires.empty_inserts - @testing.requires.returning - def test_no_inserted_pk_on_returning(self, connection): + @testing.requires.insert_returning + def test_no_inserted_pk_on_returning( + self, connection, close_result_when_finished + ): users = self.tables.users result = connection.execute( users.insert().returning(users.c.user_id, users.c.user_name) ) + close_result_when_finished(result) + assert_raises_message( exc.InvalidRequestError, r"Can't call inserted_primary_key when returning\(\) is used.", @@ -566,7 +572,7 @@ class TableInsertTest(fixtures.TablesTest): inserted_primary_key=(1,), ) - @testing.requires.returning + @testing.requires.insert_returning def test_uppercase_direct_params_returning(self, connection): t = self.tables.foo self._test( @@ -599,7 +605,7 @@ class TableInsertTest(fixtures.TablesTest): inserted_primary_key=(), ) - @testing.requires.returning + @testing.requires.insert_returning def test_direct_params_returning(self, connection): t = self._fixture() self._test( diff --git a/test/sql/test_returning.py b/test/sql/test_returning.py index bacdbaf3fb..c458e3262e 100644 --- a/test/sql/test_returning.py +++ b/test/sql/test_returning.py @@ -199,8 +199,8 @@ class ReturnCombinationTests(fixtures.TestBase, AssertsCompiledSQL): ) -class ReturningTest(fixtures.TablesTest, AssertsExecutionResults): - __requires__ = ("returning",) +class InsertReturningTest(fixtures.TablesTest, AssertsExecutionResults): + __requires__ = ("insert_returning",) __backend__ = True run_create_tables = "each" @@ -286,26 +286,6 @@ class ReturningTest(fixtures.TablesTest, AssertsExecutionResults): row = result.first() eq_(row[0], 30) - def test_update_returning(self, connection): - table = self.tables.tables - connection.execute( - table.insert(), - [{"persons": 5, "full": False}, {"persons": 3, "full": False}], - ) - - result = connection.execute( - table.update() - .values(dict(full=True)) - .where(table.c.persons > 4) - .returning(table.c.id) - ) - eq_(result.fetchall(), [(1,)]) - - result2 = connection.execute( - select(table.c.id, table.c.full).order_by(table.c.id) - ) - eq_(result2.fetchall(), [(1, True), (2, False)]) - @testing.fails_on( "mssql", "driver has unknown issue with string concatenation " @@ -339,6 +319,94 @@ class ReturningTest(fixtures.TablesTest, AssertsExecutionResults): ) eq_(result2.fetchall(), [(1, "FOOsomegoofyBAR")]) + def test_no_ipk_on_returning(self, connection, close_result_when_finished): + table = self.tables.tables + result = connection.execute( + table.insert().returning(table.c.id), {"persons": 1, "full": False} + ) + close_result_when_finished(result) + assert_raises_message( + sa_exc.InvalidRequestError, + r"Can't call inserted_primary_key when returning\(\) is used.", + getattr, + result, + "inserted_primary_key", + ) + + def test_insert_returning(self, connection): + table = self.tables.tables + result = connection.execute( + table.insert().returning(table.c.id), {"persons": 1, "full": False} + ) + + eq_(result.fetchall(), [(1,)]) + + @testing.requires.multivalues_inserts + def test_multirow_returning(self, connection): + table = self.tables.tables + ins = ( + table.insert() + .returning(table.c.id, table.c.persons) + .values( + [ + {"persons": 1, "full": False}, + {"persons": 2, "full": True}, + {"persons": 3, "full": False}, + ] + ) + ) + result = connection.execute(ins) + eq_(result.fetchall(), [(1, 1), (2, 2), (3, 3)]) + + @testing.fails_on_everything_except( + "postgresql", "mariadb>=10.5", "sqlite>=3.34" + ) + def test_literal_returning(self, connection): + if testing.against("mariadb"): + quote = "`" + else: + quote = '"' + if testing.against("postgresql"): + literal_true = "true" + else: + literal_true = "1" + + result4 = connection.exec_driver_sql( + "insert into tables (id, persons, %sfull%s) " + "values (5, 10, %s) returning persons" + % (quote, quote, literal_true) + ) + eq_([dict(row._mapping) for row in result4], [{"persons": 10}]) + + +class UpdateReturningTest(fixtures.TablesTest, AssertsExecutionResults): + __requires__ = ("update_returning",) + __backend__ = True + + run_create_tables = "each" + + define_tables = InsertReturningTest.define_tables + + def test_update_returning(self, connection): + table = self.tables.tables + connection.execute( + table.insert(), + [{"persons": 5, "full": False}, {"persons": 3, "full": False}], + ) + + result = connection.execute( + table.update() + .values(dict(full=True)) + .where(table.c.persons > 4) + .returning(table.c.id) + ) + eq_(result.fetchall(), [(1,)]) + + result2 = connection.execute( + select(table.c.id, table.c.full).order_by(table.c.id) + ) + eq_(result2.fetchall(), [(1, True), (2, False)]) + def test_update_returning_w_expression_one(self, connection): table = self.tables.tables connection.execute( @@ -388,7 +456,6 @@ class ReturningTest(fixtures.TablesTest, AssertsExecutionResults): [(1, "FOOnewgoofyBAR"), (2, "FOOsomegoofy2BAR")], ) - @testing.requires.full_returning def test_update_full_returning(self, connection): table = self.tables.tables connection.execute( @@ -404,69 +471,14 @@ class ReturningTest(fixtures.TablesTest, AssertsExecutionResults): ) eq_(result.fetchall(), [(1, True), (2, True)]) - @testing.requires.full_returning - def test_delete_full_returning(self, connection): - table = self.tables.tables - connection.execute( - table.insert(), - [{"persons": 5, "full": False}, {"persons": 3, "full": False}], - ) - - result = connection.execute( - table.delete().returning(table.c.id, table.c.full) - ) - eq_(result.fetchall(), [(1, False), (2, False)]) - - def test_insert_returning(self, connection): - table = self.tables.tables - result = connection.execute( - table.insert().returning(table.c.id), {"persons": 1, "full": False} - ) - eq_(result.fetchall(), [(1,)]) - - @testing.requires.multivalues_inserts - def test_multirow_returning(self, connection): - table = self.tables.tables - ins = ( - table.insert() - .returning(table.c.id, table.c.persons) - .values( - [ - {"persons": 1, "full": False}, - {"persons": 2, "full": True}, - {"persons": 3, "full": False}, - ] - ) - ) - result = connection.execute(ins) - eq_(result.fetchall(), [(1, 1), (2, 2), (3, 3)]) - - def test_no_ipk_on_returning(self, connection): - table = self.tables.tables - result = connection.execute( - table.insert().returning(table.c.id), {"persons": 1, "full": False} - ) - assert_raises_message( - sa_exc.InvalidRequestError, - r"Can't call inserted_primary_key when returning\(\) is used.", - getattr, - result, - "inserted_primary_key", - ) +class DeleteReturningTest(fixtures.TablesTest, AssertsExecutionResults): + __requires__ = ("delete_returning",) + __backend__ = True - @testing.fails_on_everything_except("postgresql") - def test_literal_returning(self, connection): - if testing.against("postgresql"): - literal_true = "true" - else: - literal_true = "1" + run_create_tables = "each" - result4 = connection.exec_driver_sql( - 'insert into tables (id, persons, "full") ' - "values (5, 10, %s) returning persons" % literal_true - ) - eq_([dict(row._mapping) for row in result4], [{"persons": 10}]) + define_tables = InsertReturningTest.define_tables def test_delete_returning(self, connection): table = self.tables.tables @@ -487,7 +499,7 @@ class ReturningTest(fixtures.TablesTest, AssertsExecutionResults): class CompositeStatementTest(fixtures.TestBase): - __requires__ = ("returning",) + __requires__ = ("insert_returning",) __backend__ = True @testing.provide_metadata @@ -517,7 +529,7 @@ class CompositeStatementTest(fixtures.TestBase): class SequenceReturningTest(fixtures.TablesTest): - __requires__ = "returning", "sequences" + __requires__ = "insert_returning", "sequences" __backend__ = True @classmethod @@ -552,7 +564,7 @@ class KeyReturningTest(fixtures.TablesTest, AssertsExecutionResults): """test returning() works with columns that define 'key'.""" - __requires__ = ("returning",) + __requires__ = ("insert_returning",) __backend__ = True @classmethod @@ -583,8 +595,8 @@ class KeyReturningTest(fixtures.TablesTest, AssertsExecutionResults): assert row[table.c.foo_id] == row["id"] == 1 -class ReturnDefaultsTest(fixtures.TablesTest): - __requires__ = ("returning",) +class InsertReturnDefaultsTest(fixtures.TablesTest): + __requires__ = ("insert_returning",) run_define_tables = "each" __backend__ = True @@ -639,67 +651,99 @@ class ReturnDefaultsTest(fixtures.TablesTest): [1, 0], ) - def test_chained_update_pk(self, connection): + def test_insert_non_default(self, connection): + """test that a column not marked at all as a + default works with this feature.""" + t1 = self.tables.t1 - connection.execute(t1.insert().values(upddef=1)) result = connection.execute( - t1.update().values(data="d1").return_defaults(t1.c.upddef) + t1.insert().values(upddef=1).return_defaults(t1.c.data) ) eq_( - [result.returned_defaults._mapping[k] for k in (t1.c.upddef,)], [1] + [ + result.returned_defaults._mapping[k] + for k in (t1.c.id, t1.c.data) + ], + [1, None], ) - def test_arg_update_pk(self, connection): + def test_insert_sql_expr(self, connection): + from sqlalchemy import literal + t1 = self.tables.t1 - connection.execute(t1.insert().values(upddef=1)) result = connection.execute( - t1.update().return_defaults(t1.c.upddef).values(data="d1") + t1.insert().return_defaults().values(insdef=literal(10) + 5) ) + eq_( - [result.returned_defaults._mapping[k] for k in (t1.c.upddef,)], [1] + result.returned_defaults._mapping, + {"id": 1, "data": None, "insdef": 15, "upddef": None}, ) - def test_insert_non_default(self, connection): - """test that a column not marked at all as a - default works with this feature.""" + def test_insert_non_default_plus_default(self, connection): + t1 = self.tables.t1 + result = connection.execute( + t1.insert() + .values(upddef=1) + .return_defaults(t1.c.data, t1.c.insdef) + ) + eq_( + dict(result.returned_defaults._mapping), + {"id": 1, "data": None, "insdef": 0}, + ) + eq_(result.inserted_primary_key, (1,)) + def test_insert_all(self, connection): t1 = self.tables.t1 result = connection.execute( - t1.insert().values(upddef=1).return_defaults(t1.c.data) + t1.insert().values(upddef=1).return_defaults() ) eq_( - [ - result.returned_defaults._mapping[k] - for k in (t1.c.id, t1.c.data) - ], - [1, None], + dict(result.returned_defaults._mapping), + {"id": 1, "data": None, "insdef": 0}, ) + eq_(result.inserted_primary_key, (1,)) - def test_update_non_default(self, connection): - """test that a column not marked at all as a - default works with this feature.""" +class UpdatedReturnDefaultsTest(fixtures.TablesTest): + __requires__ = ("update_returning",) + run_define_tables = "each" + __backend__ = True + + define_tables = InsertReturnDefaultsTest.define_tables + + def test_chained_update_pk(self, connection): t1 = self.tables.t1 connection.execute(t1.insert().values(upddef=1)) result = connection.execute( - t1.update().values(upddef=2).return_defaults(t1.c.data) + t1.update().values(data="d1").return_defaults(t1.c.upddef) ) eq_( - [result.returned_defaults._mapping[k] for k in (t1.c.data,)], - [None], + [result.returned_defaults._mapping[k] for k in (t1.c.upddef,)], [1] ) - def test_insert_sql_expr(self, connection): - from sqlalchemy import literal - + def test_arg_update_pk(self, connection): t1 = self.tables.t1 + connection.execute(t1.insert().values(upddef=1)) result = connection.execute( - t1.insert().return_defaults().values(insdef=literal(10) + 5) + t1.update().return_defaults(t1.c.upddef).values(data="d1") ) + eq_( + [result.returned_defaults._mapping[k] for k in (t1.c.upddef,)], [1] + ) + + def test_update_non_default(self, connection): + """test that a column not marked at all as a + default works with this feature.""" + t1 = self.tables.t1 + connection.execute(t1.insert().values(upddef=1)) + result = connection.execute( + t1.update().values(upddef=2).return_defaults(t1.c.data) + ) eq_( - result.returned_defaults._mapping, - {"id": 1, "data": None, "insdef": 15, "upddef": None}, + [result.returned_defaults._mapping[k] for k in (t1.c.data,)], + [None], ) def test_update_sql_expr(self, connection): @@ -713,19 +757,6 @@ class ReturnDefaultsTest(fixtures.TablesTest): eq_(result.returned_defaults._mapping, {"upddef": 15}) - def test_insert_non_default_plus_default(self, connection): - t1 = self.tables.t1 - result = connection.execute( - t1.insert() - .values(upddef=1) - .return_defaults(t1.c.data, t1.c.insdef) - ) - eq_( - dict(result.returned_defaults._mapping), - {"id": 1, "data": None, "insdef": 0}, - ) - eq_(result.inserted_primary_key, (1,)) - def test_update_non_default_plus_default(self, connection): t1 = self.tables.t1 connection.execute(t1.insert().values(upddef=1)) @@ -739,17 +770,6 @@ class ReturnDefaultsTest(fixtures.TablesTest): {"data": None, "upddef": 1}, ) - def test_insert_all(self, connection): - t1 = self.tables.t1 - result = connection.execute( - t1.insert().values(upddef=1).return_defaults() - ) - eq_( - dict(result.returned_defaults._mapping), - {"id": 1, "data": None, "insdef": 0}, - ) - eq_(result.inserted_primary_key, (1,)) - def test_update_all(self, connection): t1 = self.tables.t1 connection.execute(t1.insert().values(upddef=1)) @@ -758,7 +778,14 @@ class ReturnDefaultsTest(fixtures.TablesTest): ) eq_(dict(result.returned_defaults._mapping), {"upddef": 1}) - @testing.requires.insert_executemany_returning + +class InsertManyReturnDefaultsTest(fixtures.TablesTest): + __requires__ = ("insert_executemany_returning",) + run_define_tables = "each" + __backend__ = True + + define_tables = InsertReturnDefaultsTest.define_tables + def test_insert_executemany_no_defaults_passed(self, connection): t1 = self.tables.t1 result = connection.execute( @@ -802,7 +829,6 @@ class ReturnDefaultsTest(fixtures.TablesTest): lambda: result.inserted_primary_key, ) - @testing.requires.insert_executemany_returning def test_insert_executemany_insdefault_passed(self, connection): t1 = self.tables.t1 result = connection.execute( @@ -846,7 +872,6 @@ class ReturnDefaultsTest(fixtures.TablesTest): lambda: result.inserted_primary_key, ) - @testing.requires.insert_executemany_returning def test_insert_executemany_only_pk_passed(self, connection): t1 = self.tables.t1 result = connection.execute( diff --git a/test/sql/test_sequences.py b/test/sql/test_sequences.py index be74153cec..19f95c6619 100644 --- a/test/sql/test_sequences.py +++ b/test/sql/test_sequences.py @@ -218,9 +218,15 @@ class SequenceExecTest(fixtures.TestBase): @testing.combinations( ("implicit_returning",), ("no_implicit_returning",), - ("explicit_returning", testing.requires.returning), - ("return_defaults_no_implicit_returning", testing.requires.returning), - ("return_defaults_implicit_returning", testing.requires.returning), + ("explicit_returning", testing.requires.insert_returning), + ( + "return_defaults_no_implicit_returning", + testing.requires.insert_returning, + ), + ( + "return_defaults_implicit_returning", + testing.requires.insert_returning, + ), argnames="returning", ) @testing.requires.multivalues_inserts @@ -264,17 +270,17 @@ class SequenceExecTest(fixtures.TestBase): ("no_implicit_returning",), ( "explicit_returning", - testing.requires.returning + testing.requires.insert_returning + testing.requires.insert_executemany_returning, ), ( "return_defaults_no_implicit_returning", - testing.requires.returning + testing.requires.insert_returning + testing.requires.insert_executemany_returning, ), ( "return_defaults_implicit_returning", - testing.requires.returning + testing.requires.insert_returning + testing.requires.insert_executemany_returning, ), argnames="returning", @@ -318,7 +324,7 @@ class SequenceExecTest(fixtures.TestBase): [(1, "d1"), (2, "d2"), (3, "d3")], ) - @testing.requires.returning + @testing.requires.insert_returning def test_inserted_pk_implicit_returning(self, connection, metadata): """test inserted_primary_key contains the result when pk_col=next_value(), when implicit returning is used.""" @@ -435,7 +441,7 @@ class SequenceTest(fixtures.TestBase, testing.AssertsCompiledSQL): assert not self._has_sequence(connection, "s1") assert not self._has_sequence(connection, "s2") - @testing.requires.returning + @testing.requires.insert_returning @testing.requires.supports_sequence_for_autoincrement_column @testing.provide_metadata def test_freestanding_sequence_via_autoinc(self, connection): @@ -545,7 +551,7 @@ class TableBoundSequenceTest(fixtures.TablesTest): return go @testing.combinations( - (True, testing.requires.returning), + (True, testing.requires.insert_returning), (False,), argnames="implicit_returning", ) @@ -571,7 +577,7 @@ class TableBoundSequenceTest(fixtures.TablesTest): ) @testing.combinations( - (True, testing.requires.returning), + (True, testing.requires.insert_returning), (False,), argnames="implicit_returning", ) diff --git a/test/sql/test_type_expressions.py b/test/sql/test_type_expressions.py index 70c8839e30..901be7132a 100644 --- a/test/sql/test_type_expressions.py +++ b/test/sql/test_type_expressions.py @@ -496,7 +496,7 @@ class TypeDecRoundTripTest(fixtures.TablesTest, RoundTripTestBase): class ReturningTest(fixtures.TablesTest): - __requires__ = ("returning",) + __requires__ = ("insert_returning",) @classmethod def define_tables(cls, metadata): -- 2.47.2