]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add support for preserve_rowcount execution_option
authorFederico Caselli <cfederico87@gmail.com>
Wed, 7 Feb 2024 21:11:25 +0000 (22:11 +0100)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 3 Mar 2024 22:58:12 +0000 (17:58 -0500)
Added new core execution option
paramref:`_engine.Connection.execution_options.preserve_rowcount`
to unconditionally save the ``rowcount`` attribute from the cursor in the
class:`_engine.Result` returned from an execution, regardless of the
statement being executed.
When this option is provided the correct value is also set when
an INSERT makes use of the "insertmanyvalues" mode, that may use
more than one actualy cursor execution.

Fixes: #10974
Change-Id: Icecef6b7539be9f0a1a02b9539864f5f163dcfbc

19 files changed:
doc/build/changelog/unreleased_20/10974.rst [new file with mode: 0644]
doc/build/tutorial/data_update.rst
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/dialects/mysql/mariadbconnector.py
lib/sqlalchemy/dialects/mysql/mysqldb.py
lib/sqlalchemy/dialects/postgresql/__init__.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/cursor.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/engine/interfaces.py
lib/sqlalchemy/ext/asyncio/engine.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/sql/base.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/testing/fixtures/sql.py
test/requirements.py
test/sql/test_insert_exec.py
test/sql/test_resultset.py
test/typing/test_overloads.py

diff --git a/doc/build/changelog/unreleased_20/10974.rst b/doc/build/changelog/unreleased_20/10974.rst
new file mode 100644 (file)
index 0000000..a5da624
--- /dev/null
@@ -0,0 +1,15 @@
+.. change::
+    :tags: engine, usecase
+    :tickets: 10974
+
+    Added new core execution option
+    :paramref:`_engine.Connection.execution_options.preserve_rowcount`. When
+    set, the ``cursor.rowcount`` attribute from the DBAPI cursor will be
+    unconditionally memoized at statement execution time, so that whatever
+    value the DBAPI offers for any kind of statement will be available using
+    the :attr:`_engine.CursorResult.rowcount` attribute from the
+    :class:`_engine.CursorResult`.  This allows the rowcount to be accessed for
+    statments such as INSERT and SELECT, to the degree supported by the DBAPI
+    in use. The :ref:`engine_insertmanyvalues` also supports this option and
+    will ensure :attr:`_engine.CursorResult.rowcount` is correctly set for a
+    bulk INSERT of rows when set.
index a82f070a3f65169d8fbab8a1487518431181fc01..48cf5c058aa23bdd9014da9e37e4d47c50abc7dc 100644 (file)
@@ -279,17 +279,24 @@ Facts about :attr:`_engine.CursorResult.rowcount`:
   the statement.   It does not matter if the row were actually modified or not.
 
 * :attr:`_engine.CursorResult.rowcount` is not necessarily available for an UPDATE
-  or DELETE statement that uses RETURNING.
+  or DELETE statement that uses RETURNING, or for one that uses an
+  :ref:`executemany <tutorial_multiple_parameters>` execution.   The availablility
+  depends on the DBAPI module in use.
 
-* For an :ref:`executemany <tutorial_multiple_parameters>` execution,
-  :attr:`_engine.CursorResult.rowcount` may not be available either, which depends
-  highly on the DBAPI module in use as well as configured options.  The
-  attribute :attr:`_engine.CursorResult.supports_sane_multi_rowcount` indicates
-  if this value will be available for the current backend in use.
+* In any case where the DBAPI does not determine the rowcount for some type
+  of statement, the returned value will be ``-1``.
+
+* SQLAlchemy pre-memoizes the DBAPIs ``cursor.rowcount`` value before the cursor
+  is closed, as some DBAPIs don't support accessing this attribute after the
+  fact.  In order to pre-memoize ``cursor.rowcount`` for a statement that is
+  not UPDATE or DELETE, such as INSERT or SELECT, the
+  :paramref:`_engine.Connection.execution_options.preserve_rowcount` execution
+  option may be used.
 
 * Some drivers, particularly third party dialects for non-relational databases,
   may not support :attr:`_engine.CursorResult.rowcount` at all.   The
-  :attr:`_engine.CursorResult.supports_sane_rowcount` will indicate this.
+  :attr:`_engine.CursorResult.supports_sane_rowcount` cursor attribute will
+  indicate this.
 
 * "rowcount" is used by the ORM :term:`unit of work` process to validate that
   an UPDATE or DELETE statement matched the expected number of rows, and is
index 98f7f6dce6e91864db8b25ebed52f78e6f2b74a6..ff69d6aa147e5109c80fb2a828c28f68848744f3 100644 (file)
@@ -1841,7 +1841,6 @@ class MSExecutionContext(default.DefaultExecutionContext):
     _enable_identity_insert = False
     _select_lastrowid = False
     _lastrowid = None
-    _rowcount = None
 
     dialect: MSDialect
 
@@ -1961,13 +1960,6 @@ class MSExecutionContext(default.DefaultExecutionContext):
     def get_lastrowid(self):
         return self._lastrowid
 
-    @property
-    def rowcount(self):
-        if self._rowcount is not None:
-            return self._rowcount
-        else:
-            return self.cursor.rowcount
-
     def handle_dbapi_exception(self, e):
         if self._enable_identity_insert:
             try:
index 86bc59d45a39392393ffa8289fa3f5c8318b54a0..c33ccd3b9332e9e7e7176e81580dddad313c4c11 100644 (file)
@@ -88,13 +88,6 @@ class MySQLExecutionContext_mariadbconnector(MySQLExecutionContext):
         if self.isinsert and self.compiled.postfetch_lastrowid:
             self._lastrowid = self.cursor.lastrowid
 
-    @property
-    def rowcount(self):
-        if self._rowcount is not None:
-            return self._rowcount
-        else:
-            return self.cursor.rowcount
-
     def get_lastrowid(self):
         return self._lastrowid
 
index d46d159d4cd7f0da70fac1c04b43a2600409ae63..0c632b66f3eff87cea6ae1b36ad9d8dff2f84ddf 100644 (file)
@@ -97,12 +97,7 @@ from ... import util
 
 
 class MySQLExecutionContext_mysqldb(MySQLExecutionContext):
-    @property
-    def rowcount(self):
-        if hasattr(self, "_rowcount"):
-            return self._rowcount
-        else:
-            return self.cursor.rowcount
+    pass
 
 
 class MySQLCompiler_mysqldb(MySQLCompiler):
index 17b14f4d05b64523d3c0036b6aa5eb6ff64c0f9e..325ea8869905175e8e5423230dc50690febab0b6 100644 (file)
@@ -8,7 +8,7 @@
 
 from types import ModuleType
 
-from . import array as arraylib  # noqa # must be above base and other dialects
+from . import array as arraylib  # noqa # keep above base and other dialects
 from . import asyncpg  # noqa
 from . import base
 from . import pg8000  # noqa
index b3577ecca26bdef5b8fb23434c0e63896416c723..63631bdbd739cfe68918c5ab158432a24a179ea2 100644 (file)
@@ -254,6 +254,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
         yield_per: int = ...,
         insertmanyvalues_page_size: int = ...,
         schema_translate_map: Optional[SchemaTranslateMapType] = ...,
+        preserve_rowcount: bool = False,
         **opt: Any,
     ) -> Connection: ...
 
@@ -494,6 +495,18 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
 
             :ref:`schema_translating`
 
+        :param preserve_rowcount: Boolean; when True, the ``cursor.rowcount``
+          attribute will be unconditionally memoized within the result and
+          made available via the :attr:`.CursorResult.rowcount` attribute.
+          Normally, this attribute is only preserved for UPDATE and DELETE
+          statements.  Using this option, the DBAPIs rowcount value can
+          be accessed for other kinds of statements such as INSERT and SELECT,
+          to the degree that the DBAPI supports these statements.  See
+          :attr:`.CursorResult.rowcount` for notes regarding the behavior
+          of this attribute.
+
+          .. versionadded:: 2.0.28
+
         .. seealso::
 
             :meth:`_engine.Engine.execution_options`
@@ -1835,10 +1848,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
         context.pre_exec()
 
         if context.execute_style is ExecuteStyle.INSERTMANYVALUES:
-            return self._exec_insertmany_context(
-                dialect,
-                context,
-            )
+            return self._exec_insertmany_context(dialect, context)
         else:
             return self._exec_single_context(
                 dialect, context, statement, parameters
@@ -2022,6 +2032,11 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
         if self._echo:
             stats = context._get_cache_stats() + " (insertmanyvalues)"
 
+        preserve_rowcount = context.execution_options.get(
+            "preserve_rowcount", False
+        )
+        rowcount = 0
+
         for imv_batch in dialect._deliver_insertmanyvalues_batches(
             cursor,
             str_statement,
@@ -2132,9 +2147,15 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
                     context.executemany,
                 )
 
+            if preserve_rowcount:
+                rowcount += imv_batch.current_batch_size
+
         try:
             context.post_exec()
 
+            if preserve_rowcount:
+                context._rowcount = rowcount  # type: ignore[attr-defined]
+
             result = context._setup_result_proxy()
 
         except BaseException as e:
index 89a443bc0b7adb0d93b4c43cf8f71c899ed3e389..004274ec5aabad2cf52e235b8110cb3bac60dab3 100644 (file)
@@ -1981,8 +1981,28 @@ class CursorResult(Result[Unpack[_Ts]]):
     def rowcount(self) -> int:
         """Return the 'rowcount' for this result.
 
-        The 'rowcount' reports the number of rows *matched*
-        by the WHERE criterion of an UPDATE or DELETE statement.
+        The primary purpose of 'rowcount' is to report the number of rows
+        matched by the WHERE criterion of an UPDATE or DELETE statement
+        executed once (i.e. for a single parameter set), which may then be
+        compared to the number of rows expected to be updated or deleted as a
+        means of asserting data integrity.
+
+        This attribute is transferred from the ``cursor.rowcount`` attribute
+        of the DBAPI before the cursor is closed, to support DBAPIs that
+        don't make this value available after cursor close.   Some DBAPIs may
+        offer meaningful values for other kinds of statements, such as INSERT
+        and SELECT statements as well.  In order to retrieve ``cursor.rowcount``
+        for these statements, set the
+        :paramref:`.Connection.execution_options.preserve_rowcount`
+        execution option to True, which will cause the ``cursor.rowcount``
+        value to be unconditionally memoized before any results are returned
+        or the cursor is closed, regardless of statement type.
+
+        For cases where the DBAPI does not support rowcount for a particular
+        kind of statement and/or execution, the returned value will be ``-1``,
+        which is delivered directly from the DBAPI and is part of :pep:`249`.
+        All DBAPIs should support rowcount for single-parameter-set
+        UPDATE and DELETE statements, however.
 
         .. note::
 
@@ -1991,38 +2011,47 @@ class CursorResult(Result[Unpack[_Ts]]):
 
            * This attribute returns the number of rows *matched*,
              which is not necessarily the same as the number of rows
-             that were actually *modified* - an UPDATE statement, for example,
+             that were actually *modified*. For example, an UPDATE statement
              may have no net change on a given row if the SET values
              given are the same as those present in the row already.
              Such a row would be matched but not modified.
              On backends that feature both styles, such as MySQL,
-             rowcount is configured by default to return the match
+             rowcount is configured to return the match
              count in all cases.
 
-           * :attr:`_engine.CursorResult.rowcount`
-             is *only* useful in conjunction
-             with an UPDATE or DELETE statement.  Contrary to what the Python
-             DBAPI says, it does *not* reliably return the
-             number of rows available from the results of a SELECT statement
-             as DBAPIs cannot support this functionality when rows are
-             unbuffered.
-
-           * :attr:`_engine.CursorResult.rowcount`
-             may not be fully implemented by
-             all dialects.  In particular, most DBAPIs do not support an
-             aggregate rowcount result from an executemany call.
-             The :meth:`_engine.CursorResult.supports_sane_rowcount` and
-             :meth:`_engine.CursorResult.supports_sane_multi_rowcount` methods
-             will report from the dialect if each usage is known to be
-             supported.
-
-           * Statements that use RETURNING may not return a correct
-             rowcount.
+           * :attr:`_engine.CursorResult.rowcount` in the default case is
+             *only* useful in conjunction with an UPDATE or DELETE statement,
+             and only with a single set of parameters. For other kinds of
+             statements, SQLAlchemy will not attempt to pre-memoize the value
+             unless the
+             :paramref:`.Connection.execution_options.preserve_rowcount`
+             execution option is used.  Note that contrary to :pep:`249`, many
+             DBAPIs do not support rowcount values for statements that are not
+             UPDATE or DELETE, particularly when rows are being returned which
+             are not fully pre-buffered.   DBAPIs that dont support rowcount
+             for a particular kind of statement should return the value ``-1``
+             for such statements.
+
+           * :attr:`_engine.CursorResult.rowcount` may not be meaningful
+             when executing a single statement with multiple parameter sets
+             (i.e. an :term:`executemany`). Most DBAPIs do not sum "rowcount"
+             values across multiple parameter sets and will return ``-1``
+             when accessed.
+
+           * SQLAlchemy's :ref:`engine_insertmanyvalues` feature does support
+             a correct population of :attr:`_engine.CursorResult.rowcount`
+             when the :paramref:`.Connection.execution_options.preserve_rowcount`
+             execution option is set to True.
+
+           * Statements that use RETURNING may not support rowcount, returning
+             a ``-1`` value instead.
 
         .. seealso::
 
             :ref:`tutorial_update_delete_rowcount` - in the :ref:`unified_tutorial`
 
+            :paramref:`.Connection.execution_options.preserve_rowcount`
+
         """  # noqa: E501
         try:
             return self.context.rowcount
@@ -2118,8 +2147,7 @@ class CursorResult(Result[Unpack[_Ts]]):
         self, *others: Result[Unpack[TupleAny]]
     ) -> MergedResult[Unpack[TupleAny]]:
         merged_result = super().merge(*others)
-        setup_rowcounts = self.context._has_rowcount
-        if setup_rowcounts:
+        if self.context._has_rowcount:
             merged_result.rowcount = sum(
                 cast("CursorResult[Any]", result).rowcount
                 for result in (self,) + others
index 7eb7d0eb8b2657798ecddd377c300b21b8ea8a77..b6782ff32ebe9050793592bef0d69cc9ac3836f3 100644 (file)
@@ -1207,7 +1207,7 @@ class DefaultExecutionContext(ExecutionContext):
 
     _soft_closed = False
 
-    _has_rowcount = False
+    _rowcount: Optional[int] = None
 
     # a hook for SQLite's translation of
     # result column names
@@ -1797,7 +1797,14 @@ class DefaultExecutionContext(ExecutionContext):
 
     @util.non_memoized_property
     def rowcount(self) -> int:
-        return self.cursor.rowcount
+        if self._rowcount is not None:
+            return self._rowcount
+        else:
+            return self.cursor.rowcount
+
+    @property
+    def _has_rowcount(self):
+        return self._rowcount is not None
 
     def supports_sane_rowcount(self):
         return self.dialect.supports_sane_rowcount
@@ -1808,6 +1815,9 @@ class DefaultExecutionContext(ExecutionContext):
     def _setup_result_proxy(self):
         exec_opt = self.execution_options
 
+        if self._rowcount is None and exec_opt.get("preserve_rowcount", False):
+            self._rowcount = self.cursor.rowcount
+
         if self.is_crud or self.is_text:
             result = self._setup_dml_or_text_result()
             yp = sr = False
@@ -1964,8 +1974,7 @@ class DefaultExecutionContext(ExecutionContext):
 
             if rows:
                 self.returned_default_rows = rows
-            result.rowcount = len(rows)
-            self._has_rowcount = True
+            self._rowcount = len(rows)
 
             if self._is_supplemental_returning:
                 result._rewind(rows)
@@ -1979,12 +1988,12 @@ class DefaultExecutionContext(ExecutionContext):
         elif not result._metadata.returns_rows:
             # no results, get rowcount
             # (which requires open cursor on some drivers)
-            result.rowcount
-            self._has_rowcount = True
+            if self._rowcount is None:
+                self._rowcount = self.cursor.rowcount
             result._soft_close()
         elif self.isupdate or self.isdelete:
-            result.rowcount
-            self._has_rowcount = True
+            if self._rowcount is None:
+                self._rowcount = self.cursor.rowcount
         return result
 
     @util.memoized_property
index 62476696e86a0a1788f3748ab4ac429d15c202c5..d4c5aef7976ab91fc7195082978da7c7b17bb510 100644 (file)
@@ -270,6 +270,7 @@ class _CoreKnownExecutionOptions(TypedDict, total=False):
     yield_per: int
     insertmanyvalues_page_size: int
     schema_translate_map: Optional[SchemaTranslateMapType]
+    preserve_rowcount: bool
 
 
 _ExecuteOptions = immutabledict[str, Any]
@@ -2977,6 +2978,9 @@ class ExecutionContext:
       inline SQL expression value was fired off.  Applies to inserts
       and updates."""
 
+    execution_options: _ExecuteOptions
+    """Execution options associated with the current statement execution"""
+
     @classmethod
     def _init_ddl(
         cls,
index 2b3a85465d3b96062bbe7eeacd64b1ad06cbb4d6..ae04833ad60cb2e4ae701fd20b56518788897807 100644 (file)
@@ -417,6 +417,7 @@ class AsyncConnection(
         yield_per: int = ...,
         insertmanyvalues_page_size: int = ...,
         schema_translate_map: Optional[SchemaTranslateMapType] = ...,
+        preserve_rowcount: bool = False,
         **opt: Any,
     ) -> AsyncConnection: ...
 
index 6a9fd22b658394bedc7b17c2b144e03617537cbd..3a94340052688e7bf6a99421ed149231347153df 100644 (file)
@@ -1732,6 +1732,7 @@ class Query(
         schema_translate_map: Optional[SchemaTranslateMapType] = ...,
         populate_existing: bool = False,
         autoflush: bool = False,
+        preserve_rowcount: bool = False,
         **opt: Any,
     ) -> Self: ...
 
index 798a35eed4c85cfb4fc5ef0bca0aac338739276c..a7bc18c5a4e6bc5efb378828998e4ddbbf2c0734 100644 (file)
@@ -1166,6 +1166,7 @@ class Executable(roles.StatementRole):
         render_nulls: bool = ...,
         is_delete_using: bool = ...,
         is_update_from: bool = ...,
+        preserve_rowcount: bool = False,
         **opt: Any,
     ) -> Self: ...
 
index 4c30b9363829b31d22102f4dd41a40fb38004d0b..9d4becf5a66019e95ed98d1d86ff390bcbee47cd 100644 (file)
@@ -602,7 +602,7 @@ class _InsertManyValuesBatch(NamedTuple):
     replaced_parameters: _DBAPIAnyExecuteParams
     processed_setinputsizes: Optional[_GenericSetInputSizesType]
     batch: Sequence[_DBAPISingleExecuteParams]
-    batch_size: int
+    current_batch_size: int
     batchnum: int
     total_batches: int
     rows_sorted: bool
@@ -5406,7 +5406,7 @@ class SQLCompiler(Compiled):
                     param,
                     generic_setinputsizes,
                     [param],
-                    batch_size,
+                    1,
                     batchnum,
                     lenparams,
                     sort_by_parameter_order,
@@ -5437,7 +5437,7 @@ class SQLCompiler(Compiled):
                 ),
             )
 
-        batches = list(parameters)
+        batches = cast("List[Sequence[Any]]", list(parameters))
 
         processed_setinputsizes: Optional[_GenericSetInputSizesType] = None
         batchnum = 1
@@ -5531,8 +5531,12 @@ class SQLCompiler(Compiled):
                 )
 
         while batches:
-            batch = cast("Sequence[Any]", batches[0:batch_size])
+            batch = batches[0:batch_size]
             batches[0:batch_size] = []
+            if batches:
+                current_batch_size = batch_size
+            else:
+                current_batch_size = len(batch)
 
             if generic_setinputsizes:
                 # if setinputsizes is present, expand this collection to
@@ -5542,7 +5546,7 @@ class SQLCompiler(Compiled):
                     (new_key, len_, typ)
                     for new_key, len_, typ in (
                         (f"{key}_{index}", len_, typ)
-                        for index in range(len(batch))
+                        for index in range(current_batch_size)
                         for key, len_, typ in generic_setinputsizes
                     )
                 ]
@@ -5552,6 +5556,9 @@ class SQLCompiler(Compiled):
                 num_ins_params = imv.num_positional_params_counted
 
                 batch_iterator: Iterable[Sequence[Any]]
+                extra_params_left: Sequence[Any]
+                extra_params_right: Sequence[Any]
+
                 if num_ins_params == len(batch[0]):
                     extra_params_left = extra_params_right = ()
                     batch_iterator = batch
@@ -5574,7 +5581,7 @@ class SQLCompiler(Compiled):
                     )[:-2]
                 else:
                     expanded_values_string = (
-                        (executemany_values_w_comma * len(batch))
+                        (executemany_values_w_comma * current_batch_size)
                     )[:-2]
 
                 if self._numeric_binds and num_ins_params > 0:
@@ -5590,7 +5597,7 @@ class SQLCompiler(Compiled):
                     assert not extra_params_right
 
                     start = expand_pos_lower_index + 1
-                    end = num_ins_params * (len(batch)) + start
+                    end = num_ins_params * (current_batch_size) + start
 
                     # need to format here, since statement may contain
                     # unescaped %, while values_string contains just (%s, %s)
@@ -5640,7 +5647,7 @@ class SQLCompiler(Compiled):
                 replaced_parameters,
                 processed_setinputsizes,
                 batch,
-                batch_size,
+                current_batch_size,
                 batchnum,
                 total_batches,
                 sort_by_parameter_order,
index 1448510625d9ba879f3e72e6b7b9cbedb65fa037..ab532ab0e6d457e35b41a8a6b9a5a804ffd01c14 100644 (file)
@@ -478,10 +478,7 @@ def insertmanyvalues_fixture(
 
             yield batch
 
-    def _exec_insertmany_context(
-        dialect,
-        context,
-    ):
+    def _exec_insertmany_context(dialect, context):
         with mock.patch.object(
             dialect,
             "_deliver_insertmanyvalues_batches",
index a692cd3fee3c3a43e42a30102472b36a02f0b580..2e80884bc17aec686c9d779809c9b1c7692281d9 100644 (file)
@@ -2061,3 +2061,17 @@ class DefaultRequirements(SuiteRequirements):
                 return False
 
         return only_if(go, "json_each is required")
+
+    @property
+    def rowcount_always_cached(self):
+        """Indicates that ``cursor.rowcount`` is always accessed,
+        usually in an ``ExecutionContext.post_exec``.
+        """
+        return only_on(["+mariadbconnector"])
+
+    @property
+    def rowcount_always_cached_on_insert(self):
+        """Indicates that ``cursor.rowcount`` is always accessed in an insert
+        statement.
+        """
+        return only_on(["mssql"])
index ce4caf30e93b703a8fa697b409939580db31a171..16300aad0ffada21e4f1aaf2e604bd542bd3c23f 100644 (file)
@@ -787,7 +787,8 @@ class InsertManyValuesTest(fixtures.RemovesEvents, fixtures.TablesTest):
 
         eq_(connection.execute(table.select()).all(), [(1, 1), (2, 2), (3, 3)])
 
-    def test_insert_returning_values(self, connection):
+    @testing.variation("preserve_rowcount", [True, False])
+    def test_insert_returning_values(self, connection, preserve_rowcount):
         t = self.tables.data
 
         conn = connection
@@ -796,7 +797,14 @@ class InsertManyValuesTest(fixtures.RemovesEvents, fixtures.TablesTest):
             {"x": "x%d" % i, "y": "y%d" % i}
             for i in range(1, page_size * 2 + 27)
         ]
-        result = conn.execute(t.insert().returning(t.c.x, t.c.y), data)
+        if preserve_rowcount:
+            eo = {"preserve_rowcount": True}
+        else:
+            eo = {}
+
+        result = conn.execute(
+            t.insert().returning(t.c.x, t.c.y), data, execution_options=eo
+        )
 
         eq_([tup[0] for tup in result.cursor.description], ["x", "y"])
         eq_(result.keys(), ["x", "y"])
@@ -814,6 +822,9 @@ class InsertManyValuesTest(fixtures.RemovesEvents, fixtures.TablesTest):
         # assert result.closed
         assert result.cursor is None
 
+        if preserve_rowcount:
+            eq_(result.rowcount, len(data))
+
     def test_insert_returning_preexecute_pk(self, metadata, connection):
         counter = itertools.count(1)
 
@@ -1036,10 +1047,14 @@ class InsertManyValuesTest(fixtures.RemovesEvents, fixtures.TablesTest):
 
         eq_(result.all(), [("p1_p1", "y1"), ("p2_p2", "y2")])
 
-    def test_insert_returning_defaults(self, connection):
+    @testing.variation("preserve_rowcount", [True, False])
+    def test_insert_returning_defaults(self, connection, preserve_rowcount):
         t = self.tables.data
 
-        conn = connection
+        if preserve_rowcount:
+            conn = connection.execution_options(preserve_rowcount=True)
+        else:
+            conn = connection
 
         result = conn.execute(t.insert(), {"x": "x0", "y": "y0"})
         first_pk = result.inserted_primary_key[0]
@@ -1054,6 +1069,9 @@ class InsertManyValuesTest(fixtures.RemovesEvents, fixtures.TablesTest):
             [(pk, 5) for pk in range(1 + first_pk, total_rows + first_pk)],
         )
 
+        if preserve_rowcount:
+            eq_(result.rowcount, total_rows - 1)  # range starts from 1
+
     def test_insert_return_pks_default_values(self, connection):
         """test sending multiple, empty rows into an INSERT and getting primary
         key values back.
index e1b43b7fd182f5e5643f2b19b4b26ef3a45294b3..938df1ac3af828c82ab5cfe53c1834081442ddf4 100644 (file)
@@ -1,3 +1,4 @@
+from collections import defaultdict
 import collections.abc as collections_abc
 from contextlib import contextmanager
 import csv
@@ -1733,6 +1734,29 @@ class CursorResultTest(fixtures.TablesTest):
         eq_(proxy.key, "value")
         eq_(proxy._mapping["key"], "value")
 
+    @contextmanager
+    def cursor_wrapper(self, engine):
+        calls = defaultdict(int)
+
+        class CursorWrapper:
+            def __init__(self, real_cursor):
+                self.real_cursor = real_cursor
+
+            def __getattr__(self, name):
+                calls[name] += 1
+                return getattr(self.real_cursor, name)
+
+        create_cursor = engine.dialect.execution_ctx_cls.create_cursor
+
+        def new_create(context):
+            cursor = create_cursor(context)
+            return CursorWrapper(cursor)
+
+        with patch.object(
+            engine.dialect.execution_ctx_cls, "create_cursor", new_create
+        ):
+            yield calls
+
     def test_no_rowcount_on_selects_inserts(self, metadata, testing_engine):
         """assert that rowcount is only called on deletes and updates.
 
@@ -1744,33 +1768,71 @@ class CursorResultTest(fixtures.TablesTest):
 
         engine = testing_engine()
 
+        req = testing.requires
+
         t = Table("t1", metadata, Column("data", String(10)))
         metadata.create_all(engine)
-
-        with patch.object(
-            engine.dialect.execution_ctx_cls, "rowcount"
-        ) as mock_rowcount:
+        count = 0
+        with self.cursor_wrapper(engine) as call_counts:
             with engine.begin() as conn:
-                mock_rowcount.__get__ = Mock()
                 conn.execute(
                     t.insert(),
                     [{"data": "d1"}, {"data": "d2"}, {"data": "d3"}],
                 )
-
-                eq_(len(mock_rowcount.__get__.mock_calls), 0)
+                if (
+                    req.rowcount_always_cached.enabled
+                    or req.rowcount_always_cached_on_insert.enabled
+                ):
+                    count += 1
+                eq_(call_counts["rowcount"], count)
 
                 eq_(
                     conn.execute(t.select()).fetchall(),
                     [("d1",), ("d2",), ("d3",)],
                 )
-                eq_(len(mock_rowcount.__get__.mock_calls), 0)
+                if req.rowcount_always_cached.enabled:
+                    count += 1
+                eq_(call_counts["rowcount"], count)
+
+                conn.execute(t.update(), {"data": "d4"})
+
+                count += 1
+                eq_(call_counts["rowcount"], count)
+
+                conn.execute(t.delete())
+                count += 1
+                eq_(call_counts["rowcount"], count)
+
+    def test_rowcount_always_called_when_preserve_rowcount(
+        self, metadata, testing_engine
+    ):
+        """assert that rowcount is called on any statement when
+        ``preserve_rowcount=True``.
+
+        """
+
+        engine = testing_engine()
+
+        t = Table("t1", metadata, Column("data", String(10)))
+        metadata.create_all(engine)
+
+        with self.cursor_wrapper(engine) as call_counts:
+            with engine.begin() as conn:
+                conn = conn.execution_options(preserve_rowcount=True)
+                # Do not use insertmanyvalues on any driver
+                conn.execute(t.insert(), {"data": "d1"})
+
+                eq_(call_counts["rowcount"], 1)
+
+                eq_(conn.execute(t.select()).fetchall(), [("d1",)])
+                eq_(call_counts["rowcount"], 2)
 
                 conn.execute(t.update(), {"data": "d4"})
 
-                eq_(len(mock_rowcount.__get__.mock_calls), 1)
+                eq_(call_counts["rowcount"], 3)
 
                 conn.execute(t.delete())
-                eq_(len(mock_rowcount.__get__.mock_calls), 2)
+                eq_(call_counts["rowcount"], 4)
 
     def test_row_is_sequence(self):
         row = Row(object(), [None], {}, ["value"])
index 968b60d926473e9e84896371f4d900ec5bce14f8..66209f50365948caf21f25722fbbcf5a390783c3 100644 (file)
@@ -24,6 +24,7 @@ core_execution_options = {
     "stream_results": "bool",
     "max_row_buffer": "int",
     "yield_per": "int",
+    "preserve_rowcount": "bool",
 }
 
 orm_dql_execution_options = {