From f0537442eb7d3a3b2e702c8843c3c277fbfda0ac Mon Sep 17 00:00:00 2001 From: Federico Caselli Date: Wed, 7 Feb 2024 22:11:25 +0100 Subject: [PATCH] Add support for preserve_rowcount execution_option Added new core execution option paramref:`_engine.Connection.execution_options.preserve_rowcount` to unconditionally save the ``rowcount`` attribute from the cursor in the class:`_engine.Result` returned from an execution, regardless of the statement being executed. When this option is provided the correct value is also set when an INSERT makes use of the "insertmanyvalues" mode, that may use more than one actualy cursor execution. Fixes: #10974 Change-Id: Icecef6b7539be9f0a1a02b9539864f5f163dcfbc --- doc/build/changelog/unreleased_20/10974.rst | 15 ++++ doc/build/tutorial/data_update.rst | 21 +++-- lib/sqlalchemy/dialects/mssql/base.py | 8 -- .../dialects/mysql/mariadbconnector.py | 7 -- lib/sqlalchemy/dialects/mysql/mysqldb.py | 7 +- .../dialects/postgresql/__init__.py | 2 +- lib/sqlalchemy/engine/base.py | 29 ++++++- lib/sqlalchemy/engine/cursor.py | 78 ++++++++++++------ lib/sqlalchemy/engine/default.py | 25 ++++-- lib/sqlalchemy/engine/interfaces.py | 4 + lib/sqlalchemy/ext/asyncio/engine.py | 1 + lib/sqlalchemy/orm/query.py | 1 + lib/sqlalchemy/sql/base.py | 1 + lib/sqlalchemy/sql/compiler.py | 23 ++++-- lib/sqlalchemy/testing/fixtures/sql.py | 5 +- test/requirements.py | 14 ++++ test/sql/test_insert_exec.py | 26 +++++- test/sql/test_resultset.py | 82 ++++++++++++++++--- test/typing/test_overloads.py | 1 + 19 files changed, 258 insertions(+), 92 deletions(-) create mode 100644 doc/build/changelog/unreleased_20/10974.rst diff --git a/doc/build/changelog/unreleased_20/10974.rst b/doc/build/changelog/unreleased_20/10974.rst new file mode 100644 index 0000000000..a5da62475e --- /dev/null +++ b/doc/build/changelog/unreleased_20/10974.rst @@ -0,0 +1,15 @@ +.. change:: + :tags: engine, usecase + :tickets: 10974 + + Added new core execution option + :paramref:`_engine.Connection.execution_options.preserve_rowcount`. When + set, the ``cursor.rowcount`` attribute from the DBAPI cursor will be + unconditionally memoized at statement execution time, so that whatever + value the DBAPI offers for any kind of statement will be available using + the :attr:`_engine.CursorResult.rowcount` attribute from the + :class:`_engine.CursorResult`. This allows the rowcount to be accessed for + statments such as INSERT and SELECT, to the degree supported by the DBAPI + in use. The :ref:`engine_insertmanyvalues` also supports this option and + will ensure :attr:`_engine.CursorResult.rowcount` is correctly set for a + bulk INSERT of rows when set. diff --git a/doc/build/tutorial/data_update.rst b/doc/build/tutorial/data_update.rst index a82f070a3f..48cf5c058a 100644 --- a/doc/build/tutorial/data_update.rst +++ b/doc/build/tutorial/data_update.rst @@ -279,17 +279,24 @@ Facts about :attr:`_engine.CursorResult.rowcount`: the statement. It does not matter if the row were actually modified or not. * :attr:`_engine.CursorResult.rowcount` is not necessarily available for an UPDATE - or DELETE statement that uses RETURNING. + or DELETE statement that uses RETURNING, or for one that uses an + :ref:`executemany ` execution. The availablility + depends on the DBAPI module in use. -* For an :ref:`executemany ` execution, - :attr:`_engine.CursorResult.rowcount` may not be available either, which depends - highly on the DBAPI module in use as well as configured options. The - attribute :attr:`_engine.CursorResult.supports_sane_multi_rowcount` indicates - if this value will be available for the current backend in use. +* In any case where the DBAPI does not determine the rowcount for some type + of statement, the returned value will be ``-1``. + +* SQLAlchemy pre-memoizes the DBAPIs ``cursor.rowcount`` value before the cursor + is closed, as some DBAPIs don't support accessing this attribute after the + fact. In order to pre-memoize ``cursor.rowcount`` for a statement that is + not UPDATE or DELETE, such as INSERT or SELECT, the + :paramref:`_engine.Connection.execution_options.preserve_rowcount` execution + option may be used. * Some drivers, particularly third party dialects for non-relational databases, may not support :attr:`_engine.CursorResult.rowcount` at all. The - :attr:`_engine.CursorResult.supports_sane_rowcount` will indicate this. + :attr:`_engine.CursorResult.supports_sane_rowcount` cursor attribute will + indicate this. * "rowcount" is used by the ORM :term:`unit of work` process to validate that an UPDATE or DELETE statement matched the expected number of rows, and is diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 98f7f6dce6..ff69d6aa14 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -1841,7 +1841,6 @@ class MSExecutionContext(default.DefaultExecutionContext): _enable_identity_insert = False _select_lastrowid = False _lastrowid = None - _rowcount = None dialect: MSDialect @@ -1961,13 +1960,6 @@ class MSExecutionContext(default.DefaultExecutionContext): def get_lastrowid(self): return self._lastrowid - @property - def rowcount(self): - if self._rowcount is not None: - return self._rowcount - else: - return self.cursor.rowcount - def handle_dbapi_exception(self, e): if self._enable_identity_insert: try: diff --git a/lib/sqlalchemy/dialects/mysql/mariadbconnector.py b/lib/sqlalchemy/dialects/mysql/mariadbconnector.py index 86bc59d45a..c33ccd3b93 100644 --- a/lib/sqlalchemy/dialects/mysql/mariadbconnector.py +++ b/lib/sqlalchemy/dialects/mysql/mariadbconnector.py @@ -88,13 +88,6 @@ class MySQLExecutionContext_mariadbconnector(MySQLExecutionContext): if self.isinsert and self.compiled.postfetch_lastrowid: self._lastrowid = self.cursor.lastrowid - @property - def rowcount(self): - if self._rowcount is not None: - return self._rowcount - else: - return self.cursor.rowcount - def get_lastrowid(self): return self._lastrowid diff --git a/lib/sqlalchemy/dialects/mysql/mysqldb.py b/lib/sqlalchemy/dialects/mysql/mysqldb.py index d46d159d4c..0c632b66f3 100644 --- a/lib/sqlalchemy/dialects/mysql/mysqldb.py +++ b/lib/sqlalchemy/dialects/mysql/mysqldb.py @@ -97,12 +97,7 @@ from ... import util class MySQLExecutionContext_mysqldb(MySQLExecutionContext): - @property - def rowcount(self): - if hasattr(self, "_rowcount"): - return self._rowcount - else: - return self.cursor.rowcount + pass class MySQLCompiler_mysqldb(MySQLCompiler): diff --git a/lib/sqlalchemy/dialects/postgresql/__init__.py b/lib/sqlalchemy/dialects/postgresql/__init__.py index 17b14f4d05..325ea88699 100644 --- a/lib/sqlalchemy/dialects/postgresql/__init__.py +++ b/lib/sqlalchemy/dialects/postgresql/__init__.py @@ -8,7 +8,7 @@ from types import ModuleType -from . import array as arraylib # noqa # must be above base and other dialects +from . import array as arraylib # noqa # keep above base and other dialects from . import asyncpg # noqa from . import base from . import pg8000 # noqa diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index b3577ecca2..63631bdbd7 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -254,6 +254,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): yield_per: int = ..., insertmanyvalues_page_size: int = ..., schema_translate_map: Optional[SchemaTranslateMapType] = ..., + preserve_rowcount: bool = False, **opt: Any, ) -> Connection: ... @@ -494,6 +495,18 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): :ref:`schema_translating` + :param preserve_rowcount: Boolean; when True, the ``cursor.rowcount`` + attribute will be unconditionally memoized within the result and + made available via the :attr:`.CursorResult.rowcount` attribute. + Normally, this attribute is only preserved for UPDATE and DELETE + statements. Using this option, the DBAPIs rowcount value can + be accessed for other kinds of statements such as INSERT and SELECT, + to the degree that the DBAPI supports these statements. See + :attr:`.CursorResult.rowcount` for notes regarding the behavior + of this attribute. + + .. versionadded:: 2.0.28 + .. seealso:: :meth:`_engine.Engine.execution_options` @@ -1835,10 +1848,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): context.pre_exec() if context.execute_style is ExecuteStyle.INSERTMANYVALUES: - return self._exec_insertmany_context( - dialect, - context, - ) + return self._exec_insertmany_context(dialect, context) else: return self._exec_single_context( dialect, context, statement, parameters @@ -2022,6 +2032,11 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): if self._echo: stats = context._get_cache_stats() + " (insertmanyvalues)" + preserve_rowcount = context.execution_options.get( + "preserve_rowcount", False + ) + rowcount = 0 + for imv_batch in dialect._deliver_insertmanyvalues_batches( cursor, str_statement, @@ -2132,9 +2147,15 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): context.executemany, ) + if preserve_rowcount: + rowcount += imv_batch.current_batch_size + try: context.post_exec() + if preserve_rowcount: + context._rowcount = rowcount # type: ignore[attr-defined] + result = context._setup_result_proxy() except BaseException as e: diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py index 89a443bc0b..004274ec5a 100644 --- a/lib/sqlalchemy/engine/cursor.py +++ b/lib/sqlalchemy/engine/cursor.py @@ -1981,8 +1981,28 @@ class CursorResult(Result[Unpack[_Ts]]): def rowcount(self) -> int: """Return the 'rowcount' for this result. - The 'rowcount' reports the number of rows *matched* - by the WHERE criterion of an UPDATE or DELETE statement. + The primary purpose of 'rowcount' is to report the number of rows + matched by the WHERE criterion of an UPDATE or DELETE statement + executed once (i.e. for a single parameter set), which may then be + compared to the number of rows expected to be updated or deleted as a + means of asserting data integrity. + + This attribute is transferred from the ``cursor.rowcount`` attribute + of the DBAPI before the cursor is closed, to support DBAPIs that + don't make this value available after cursor close. Some DBAPIs may + offer meaningful values for other kinds of statements, such as INSERT + and SELECT statements as well. In order to retrieve ``cursor.rowcount`` + for these statements, set the + :paramref:`.Connection.execution_options.preserve_rowcount` + execution option to True, which will cause the ``cursor.rowcount`` + value to be unconditionally memoized before any results are returned + or the cursor is closed, regardless of statement type. + + For cases where the DBAPI does not support rowcount for a particular + kind of statement and/or execution, the returned value will be ``-1``, + which is delivered directly from the DBAPI and is part of :pep:`249`. + All DBAPIs should support rowcount for single-parameter-set + UPDATE and DELETE statements, however. .. note:: @@ -1991,38 +2011,47 @@ class CursorResult(Result[Unpack[_Ts]]): * This attribute returns the number of rows *matched*, which is not necessarily the same as the number of rows - that were actually *modified* - an UPDATE statement, for example, + that were actually *modified*. For example, an UPDATE statement may have no net change on a given row if the SET values given are the same as those present in the row already. Such a row would be matched but not modified. On backends that feature both styles, such as MySQL, - rowcount is configured by default to return the match + rowcount is configured to return the match count in all cases. - * :attr:`_engine.CursorResult.rowcount` - is *only* useful in conjunction - with an UPDATE or DELETE statement. Contrary to what the Python - DBAPI says, it does *not* reliably return the - number of rows available from the results of a SELECT statement - as DBAPIs cannot support this functionality when rows are - unbuffered. - - * :attr:`_engine.CursorResult.rowcount` - may not be fully implemented by - all dialects. In particular, most DBAPIs do not support an - aggregate rowcount result from an executemany call. - The :meth:`_engine.CursorResult.supports_sane_rowcount` and - :meth:`_engine.CursorResult.supports_sane_multi_rowcount` methods - will report from the dialect if each usage is known to be - supported. - - * Statements that use RETURNING may not return a correct - rowcount. + * :attr:`_engine.CursorResult.rowcount` in the default case is + *only* useful in conjunction with an UPDATE or DELETE statement, + and only with a single set of parameters. For other kinds of + statements, SQLAlchemy will not attempt to pre-memoize the value + unless the + :paramref:`.Connection.execution_options.preserve_rowcount` + execution option is used. Note that contrary to :pep:`249`, many + DBAPIs do not support rowcount values for statements that are not + UPDATE or DELETE, particularly when rows are being returned which + are not fully pre-buffered. DBAPIs that dont support rowcount + for a particular kind of statement should return the value ``-1`` + for such statements. + + * :attr:`_engine.CursorResult.rowcount` may not be meaningful + when executing a single statement with multiple parameter sets + (i.e. an :term:`executemany`). Most DBAPIs do not sum "rowcount" + values across multiple parameter sets and will return ``-1`` + when accessed. + + * SQLAlchemy's :ref:`engine_insertmanyvalues` feature does support + a correct population of :attr:`_engine.CursorResult.rowcount` + when the :paramref:`.Connection.execution_options.preserve_rowcount` + execution option is set to True. + + * Statements that use RETURNING may not support rowcount, returning + a ``-1`` value instead. .. seealso:: :ref:`tutorial_update_delete_rowcount` - in the :ref:`unified_tutorial` + :paramref:`.Connection.execution_options.preserve_rowcount` + """ # noqa: E501 try: return self.context.rowcount @@ -2118,8 +2147,7 @@ class CursorResult(Result[Unpack[_Ts]]): self, *others: Result[Unpack[TupleAny]] ) -> MergedResult[Unpack[TupleAny]]: merged_result = super().merge(*others) - setup_rowcounts = self.context._has_rowcount - if setup_rowcounts: + if self.context._has_rowcount: merged_result.rowcount = sum( cast("CursorResult[Any]", result).rowcount for result in (self,) + others diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 7eb7d0eb8b..b6782ff32e 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -1207,7 +1207,7 @@ class DefaultExecutionContext(ExecutionContext): _soft_closed = False - _has_rowcount = False + _rowcount: Optional[int] = None # a hook for SQLite's translation of # result column names @@ -1797,7 +1797,14 @@ class DefaultExecutionContext(ExecutionContext): @util.non_memoized_property def rowcount(self) -> int: - return self.cursor.rowcount + if self._rowcount is not None: + return self._rowcount + else: + return self.cursor.rowcount + + @property + def _has_rowcount(self): + return self._rowcount is not None def supports_sane_rowcount(self): return self.dialect.supports_sane_rowcount @@ -1808,6 +1815,9 @@ class DefaultExecutionContext(ExecutionContext): def _setup_result_proxy(self): exec_opt = self.execution_options + if self._rowcount is None and exec_opt.get("preserve_rowcount", False): + self._rowcount = self.cursor.rowcount + if self.is_crud or self.is_text: result = self._setup_dml_or_text_result() yp = sr = False @@ -1964,8 +1974,7 @@ class DefaultExecutionContext(ExecutionContext): if rows: self.returned_default_rows = rows - result.rowcount = len(rows) - self._has_rowcount = True + self._rowcount = len(rows) if self._is_supplemental_returning: result._rewind(rows) @@ -1979,12 +1988,12 @@ class DefaultExecutionContext(ExecutionContext): elif not result._metadata.returns_rows: # no results, get rowcount # (which requires open cursor on some drivers) - result.rowcount - self._has_rowcount = True + if self._rowcount is None: + self._rowcount = self.cursor.rowcount result._soft_close() elif self.isupdate or self.isdelete: - result.rowcount - self._has_rowcount = True + if self._rowcount is None: + self._rowcount = self.cursor.rowcount return result @util.memoized_property diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py index 62476696e8..d4c5aef797 100644 --- a/lib/sqlalchemy/engine/interfaces.py +++ b/lib/sqlalchemy/engine/interfaces.py @@ -270,6 +270,7 @@ class _CoreKnownExecutionOptions(TypedDict, total=False): yield_per: int insertmanyvalues_page_size: int schema_translate_map: Optional[SchemaTranslateMapType] + preserve_rowcount: bool _ExecuteOptions = immutabledict[str, Any] @@ -2977,6 +2978,9 @@ class ExecutionContext: inline SQL expression value was fired off. Applies to inserts and updates.""" + execution_options: _ExecuteOptions + """Execution options associated with the current statement execution""" + @classmethod def _init_ddl( cls, diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py index 2b3a85465d..ae04833ad6 100644 --- a/lib/sqlalchemy/ext/asyncio/engine.py +++ b/lib/sqlalchemy/ext/asyncio/engine.py @@ -417,6 +417,7 @@ class AsyncConnection( yield_per: int = ..., insertmanyvalues_page_size: int = ..., schema_translate_map: Optional[SchemaTranslateMapType] = ..., + preserve_rowcount: bool = False, **opt: Any, ) -> AsyncConnection: ... diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 6a9fd22b65..3a94340052 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -1732,6 +1732,7 @@ class Query( schema_translate_map: Optional[SchemaTranslateMapType] = ..., populate_existing: bool = False, autoflush: bool = False, + preserve_rowcount: bool = False, **opt: Any, ) -> Self: ... diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 798a35eed4..a7bc18c5a4 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -1166,6 +1166,7 @@ class Executable(roles.StatementRole): render_nulls: bool = ..., is_delete_using: bool = ..., is_update_from: bool = ..., + preserve_rowcount: bool = False, **opt: Any, ) -> Self: ... diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 4c30b93638..9d4becf5a6 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -602,7 +602,7 @@ class _InsertManyValuesBatch(NamedTuple): replaced_parameters: _DBAPIAnyExecuteParams processed_setinputsizes: Optional[_GenericSetInputSizesType] batch: Sequence[_DBAPISingleExecuteParams] - batch_size: int + current_batch_size: int batchnum: int total_batches: int rows_sorted: bool @@ -5406,7 +5406,7 @@ class SQLCompiler(Compiled): param, generic_setinputsizes, [param], - batch_size, + 1, batchnum, lenparams, sort_by_parameter_order, @@ -5437,7 +5437,7 @@ class SQLCompiler(Compiled): ), ) - batches = list(parameters) + batches = cast("List[Sequence[Any]]", list(parameters)) processed_setinputsizes: Optional[_GenericSetInputSizesType] = None batchnum = 1 @@ -5531,8 +5531,12 @@ class SQLCompiler(Compiled): ) while batches: - batch = cast("Sequence[Any]", batches[0:batch_size]) + batch = batches[0:batch_size] batches[0:batch_size] = [] + if batches: + current_batch_size = batch_size + else: + current_batch_size = len(batch) if generic_setinputsizes: # if setinputsizes is present, expand this collection to @@ -5542,7 +5546,7 @@ class SQLCompiler(Compiled): (new_key, len_, typ) for new_key, len_, typ in ( (f"{key}_{index}", len_, typ) - for index in range(len(batch)) + for index in range(current_batch_size) for key, len_, typ in generic_setinputsizes ) ] @@ -5552,6 +5556,9 @@ class SQLCompiler(Compiled): num_ins_params = imv.num_positional_params_counted batch_iterator: Iterable[Sequence[Any]] + extra_params_left: Sequence[Any] + extra_params_right: Sequence[Any] + if num_ins_params == len(batch[0]): extra_params_left = extra_params_right = () batch_iterator = batch @@ -5574,7 +5581,7 @@ class SQLCompiler(Compiled): )[:-2] else: expanded_values_string = ( - (executemany_values_w_comma * len(batch)) + (executemany_values_w_comma * current_batch_size) )[:-2] if self._numeric_binds and num_ins_params > 0: @@ -5590,7 +5597,7 @@ class SQLCompiler(Compiled): assert not extra_params_right start = expand_pos_lower_index + 1 - end = num_ins_params * (len(batch)) + start + end = num_ins_params * (current_batch_size) + start # need to format here, since statement may contain # unescaped %, while values_string contains just (%s, %s) @@ -5640,7 +5647,7 @@ class SQLCompiler(Compiled): replaced_parameters, processed_setinputsizes, batch, - batch_size, + current_batch_size, batchnum, total_batches, sort_by_parameter_order, diff --git a/lib/sqlalchemy/testing/fixtures/sql.py b/lib/sqlalchemy/testing/fixtures/sql.py index 1448510625..ab532ab0e6 100644 --- a/lib/sqlalchemy/testing/fixtures/sql.py +++ b/lib/sqlalchemy/testing/fixtures/sql.py @@ -478,10 +478,7 @@ def insertmanyvalues_fixture( yield batch - def _exec_insertmany_context( - dialect, - context, - ): + def _exec_insertmany_context(dialect, context): with mock.patch.object( dialect, "_deliver_insertmanyvalues_batches", diff --git a/test/requirements.py b/test/requirements.py index a692cd3fee..2e80884bc1 100644 --- a/test/requirements.py +++ b/test/requirements.py @@ -2061,3 +2061,17 @@ class DefaultRequirements(SuiteRequirements): return False return only_if(go, "json_each is required") + + @property + def rowcount_always_cached(self): + """Indicates that ``cursor.rowcount`` is always accessed, + usually in an ``ExecutionContext.post_exec``. + """ + return only_on(["+mariadbconnector"]) + + @property + def rowcount_always_cached_on_insert(self): + """Indicates that ``cursor.rowcount`` is always accessed in an insert + statement. + """ + return only_on(["mssql"]) diff --git a/test/sql/test_insert_exec.py b/test/sql/test_insert_exec.py index ce4caf30e9..16300aad0f 100644 --- a/test/sql/test_insert_exec.py +++ b/test/sql/test_insert_exec.py @@ -787,7 +787,8 @@ class InsertManyValuesTest(fixtures.RemovesEvents, fixtures.TablesTest): eq_(connection.execute(table.select()).all(), [(1, 1), (2, 2), (3, 3)]) - def test_insert_returning_values(self, connection): + @testing.variation("preserve_rowcount", [True, False]) + def test_insert_returning_values(self, connection, preserve_rowcount): t = self.tables.data conn = connection @@ -796,7 +797,14 @@ class InsertManyValuesTest(fixtures.RemovesEvents, fixtures.TablesTest): {"x": "x%d" % i, "y": "y%d" % i} for i in range(1, page_size * 2 + 27) ] - result = conn.execute(t.insert().returning(t.c.x, t.c.y), data) + if preserve_rowcount: + eo = {"preserve_rowcount": True} + else: + eo = {} + + result = conn.execute( + t.insert().returning(t.c.x, t.c.y), data, execution_options=eo + ) eq_([tup[0] for tup in result.cursor.description], ["x", "y"]) eq_(result.keys(), ["x", "y"]) @@ -814,6 +822,9 @@ class InsertManyValuesTest(fixtures.RemovesEvents, fixtures.TablesTest): # assert result.closed assert result.cursor is None + if preserve_rowcount: + eq_(result.rowcount, len(data)) + def test_insert_returning_preexecute_pk(self, metadata, connection): counter = itertools.count(1) @@ -1036,10 +1047,14 @@ class InsertManyValuesTest(fixtures.RemovesEvents, fixtures.TablesTest): eq_(result.all(), [("p1_p1", "y1"), ("p2_p2", "y2")]) - def test_insert_returning_defaults(self, connection): + @testing.variation("preserve_rowcount", [True, False]) + def test_insert_returning_defaults(self, connection, preserve_rowcount): t = self.tables.data - conn = connection + if preserve_rowcount: + conn = connection.execution_options(preserve_rowcount=True) + else: + conn = connection result = conn.execute(t.insert(), {"x": "x0", "y": "y0"}) first_pk = result.inserted_primary_key[0] @@ -1054,6 +1069,9 @@ class InsertManyValuesTest(fixtures.RemovesEvents, fixtures.TablesTest): [(pk, 5) for pk in range(1 + first_pk, total_rows + first_pk)], ) + if preserve_rowcount: + eq_(result.rowcount, total_rows - 1) # range starts from 1 + def test_insert_return_pks_default_values(self, connection): """test sending multiple, empty rows into an INSERT and getting primary key values back. diff --git a/test/sql/test_resultset.py b/test/sql/test_resultset.py index e1b43b7fd1..938df1ac3a 100644 --- a/test/sql/test_resultset.py +++ b/test/sql/test_resultset.py @@ -1,3 +1,4 @@ +from collections import defaultdict import collections.abc as collections_abc from contextlib import contextmanager import csv @@ -1733,6 +1734,29 @@ class CursorResultTest(fixtures.TablesTest): eq_(proxy.key, "value") eq_(proxy._mapping["key"], "value") + @contextmanager + def cursor_wrapper(self, engine): + calls = defaultdict(int) + + class CursorWrapper: + def __init__(self, real_cursor): + self.real_cursor = real_cursor + + def __getattr__(self, name): + calls[name] += 1 + return getattr(self.real_cursor, name) + + create_cursor = engine.dialect.execution_ctx_cls.create_cursor + + def new_create(context): + cursor = create_cursor(context) + return CursorWrapper(cursor) + + with patch.object( + engine.dialect.execution_ctx_cls, "create_cursor", new_create + ): + yield calls + def test_no_rowcount_on_selects_inserts(self, metadata, testing_engine): """assert that rowcount is only called on deletes and updates. @@ -1744,33 +1768,71 @@ class CursorResultTest(fixtures.TablesTest): engine = testing_engine() + req = testing.requires + t = Table("t1", metadata, Column("data", String(10))) metadata.create_all(engine) - - with patch.object( - engine.dialect.execution_ctx_cls, "rowcount" - ) as mock_rowcount: + count = 0 + with self.cursor_wrapper(engine) as call_counts: with engine.begin() as conn: - mock_rowcount.__get__ = Mock() conn.execute( t.insert(), [{"data": "d1"}, {"data": "d2"}, {"data": "d3"}], ) - - eq_(len(mock_rowcount.__get__.mock_calls), 0) + if ( + req.rowcount_always_cached.enabled + or req.rowcount_always_cached_on_insert.enabled + ): + count += 1 + eq_(call_counts["rowcount"], count) eq_( conn.execute(t.select()).fetchall(), [("d1",), ("d2",), ("d3",)], ) - eq_(len(mock_rowcount.__get__.mock_calls), 0) + if req.rowcount_always_cached.enabled: + count += 1 + eq_(call_counts["rowcount"], count) + + conn.execute(t.update(), {"data": "d4"}) + + count += 1 + eq_(call_counts["rowcount"], count) + + conn.execute(t.delete()) + count += 1 + eq_(call_counts["rowcount"], count) + + def test_rowcount_always_called_when_preserve_rowcount( + self, metadata, testing_engine + ): + """assert that rowcount is called on any statement when + ``preserve_rowcount=True``. + + """ + + engine = testing_engine() + + t = Table("t1", metadata, Column("data", String(10))) + metadata.create_all(engine) + + with self.cursor_wrapper(engine) as call_counts: + with engine.begin() as conn: + conn = conn.execution_options(preserve_rowcount=True) + # Do not use insertmanyvalues on any driver + conn.execute(t.insert(), {"data": "d1"}) + + eq_(call_counts["rowcount"], 1) + + eq_(conn.execute(t.select()).fetchall(), [("d1",)]) + eq_(call_counts["rowcount"], 2) conn.execute(t.update(), {"data": "d4"}) - eq_(len(mock_rowcount.__get__.mock_calls), 1) + eq_(call_counts["rowcount"], 3) conn.execute(t.delete()) - eq_(len(mock_rowcount.__get__.mock_calls), 2) + eq_(call_counts["rowcount"], 4) def test_row_is_sequence(self): row = Row(object(), [None], {}, ["value"]) diff --git a/test/typing/test_overloads.py b/test/typing/test_overloads.py index 968b60d926..66209f5036 100644 --- a/test/typing/test_overloads.py +++ b/test/typing/test_overloads.py @@ -24,6 +24,7 @@ core_execution_options = { "stream_results": "bool", "max_row_buffer": "int", "yield_per": "int", + "preserve_rowcount": "bool", } orm_dql_execution_options = { -- 2.47.2