From: Mike Bayer Date: Thu, 31 Aug 2017 17:12:50 +0000 (-0400) Subject: Add new sane_rowcount_w_returning flag X-Git-Tag: origin~31 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=b9b1e374bfbcece8259a4df5372ca68d45aaaf01;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Add new sane_rowcount_w_returning flag Added a new class of "rowcount support" for dialects that is specific to when "RETURNING", which on SQL Server looks like "OUTPUT inserted", is in use, as the PyODBC backend isn't able to give us rowcount on an UPDATE or DELETE statement when OUTPUT is in effect. This primarily affects the ORM when a flush is updating a row that contains server-calcluated values, raising an error if the backend does not return the expected row count. PyODBC now states that it supports rowcount except if OUTPUT.inserted is present, which is taken into account by the ORM during a flush as to whether it will look for a rowcount. ORM tests are implicit in existing tests run against PyODBC Fixes: #4062 Change-Id: Iff17cbe4c7a5742971ed85a4d58660c18cc569c2 --- diff --git a/doc/build/changelog/unreleased_12/4062.rst b/doc/build/changelog/unreleased_12/4062.rst new file mode 100644 index 0000000000..3a89a1ad67 --- /dev/null +++ b/doc/build/changelog/unreleased_12/4062.rst @@ -0,0 +1,13 @@ +.. change:: + :tags: bug, mssql, orm + :tickets: 4062 + + Added a new class of "rowcount support" for dialects that is specific to + when "RETURNING", which on SQL Server looks like "OUTPUT inserted", is in + use, as the PyODBC backend isn't able to give us rowcount on an UPDATE or + DELETE statement when OUTPUT is in effect. This primarily affects the ORM + when a flush is updating a row that contains server-calcluated values, + raising an error if the backend does not return the expected row count. + PyODBC now states that it supports rowcount except if OUTPUT.inserted is + present, which is taken into account by the ORM during a flush as to + whether it will look for a rowcount. diff --git a/lib/sqlalchemy/connectors/pyodbc.py b/lib/sqlalchemy/connectors/pyodbc.py index 65fe37212e..66acf00725 100644 --- a/lib/sqlalchemy/connectors/pyodbc.py +++ b/lib/sqlalchemy/connectors/pyodbc.py @@ -16,6 +16,7 @@ import re class PyODBCConnector(Connector): driver = 'pyodbc' + supports_sane_rowcount_returning = False supports_sane_multi_rowcount = False if util.py2k: diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index d1b54ab01a..8b72c0001f 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -249,6 +249,10 @@ class DefaultDialect(interfaces.Dialect): def dialect_description(self): return self.name + "+" + self.driver + @property + def supports_sane_rowcount_returning(self): + return self.supports_sane_rowcount + @classmethod def get_pool_class(cls, url): return getattr(cls, 'poolclass', pool.QueuePool) diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index faacd018e1..24c9743d4a 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -693,22 +693,28 @@ def _emit_update_statements(base_mapper, uowtransaction, records = list(records) statement = cached_stmt - - # TODO: would be super-nice to not have to determine this boolean - # inside the loop here, in the 99.9999% of the time there's only - # one connection in use - assert_singlerow = connection.dialect.supports_sane_rowcount - assert_multirow = assert_singlerow and \ - connection.dialect.supports_sane_multi_rowcount - allow_multirow = has_all_defaults and not needs_version_id + return_defaults = False if not has_all_pks: statement = statement.return_defaults() + return_defaults = True elif bookkeeping and not has_all_defaults and \ mapper.base_mapper.eager_defaults: statement = statement.return_defaults() + return_defaults = True elif mapper.version_id_col is not None: statement = statement.return_defaults(mapper.version_id_col) + return_defaults = True + + assert_singlerow = ( + connection.dialect.supports_sane_rowcount + if not return_defaults + else connection.dialect.supports_sane_rowcount_returning + ) + + assert_multirow = assert_singlerow and \ + connection.dialect.supports_sane_multi_rowcount + allow_multirow = has_all_defaults and not needs_version_id if hasvalue: for state, state_dict, params, mapper, \ @@ -728,7 +734,7 @@ def _emit_update_statements(base_mapper, uowtransaction, c.context.compiled_parameters[0], value_params) rows += c.rowcount - check_rowcount = True + check_rowcount = assert_singlerow else: if not allow_multirow: check_rowcount = assert_singlerow diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py index 08a7b1cedc..327362bf6c 100644 --- a/lib/sqlalchemy/testing/requirements.py +++ b/lib/sqlalchemy/testing/requirements.py @@ -193,6 +193,29 @@ class SuiteRequirements(Requirements): return exclusions.open() + + @property + def sane_rowcount(self): + return exclusions.skip_if( + lambda config: not config.db.dialect.supports_sane_rowcount, + "driver doesn't support 'sane' rowcount" + ) + + @property + def sane_multi_rowcount(self): + return exclusions.fails_if( + lambda config: not config.db.dialect.supports_sane_multi_rowcount, + "driver %(driver)s %(doesnt_support)s 'sane' multi row count" + ) + + @property + def sane_rowcount_w_returning(self): + return exclusions.fails_if( + lambda config: + not config.db.dialect.supports_sane_rowcount_returning, + "driver doesn't support 'sane' rowcount when returning is on" + ) + @property def empty_inserts(self): """target platform supports INSERT with no values, i.e. diff --git a/test/requirements.py b/test/requirements.py index 4f01eac9bc..0362e28d13 100644 --- a/test/requirements.py +++ b/test/requirements.py @@ -558,13 +558,6 @@ class DefaultRequirements(SuiteRequirements): exclude('mysql', '<', (4, 1, 1), 'no unicode connection support'), ]) - @property - def sane_rowcount(self): - return skip_if( - lambda config: not config.db.dialect.supports_sane_rowcount, - "driver doesn't support 'sane' rowcount" - ) - @property def emulated_lastrowid(self): """"target dialect retrieves cursor.lastrowid or an equivalent @@ -593,13 +586,6 @@ class DefaultRequirements(SuiteRequirements): 'sqlite+pysqlite', 'sqlite+pysqlcipher') - @property - def sane_multi_rowcount(self): - return fails_if( - lambda config: not config.db.dialect.supports_sane_multi_rowcount, - "driver %(driver)s %(doesnt_support)s 'sane' multi row count" - ) - @property def nullsordering(self): """Target backends that support nulls ordering.""" diff --git a/test/sql/test_rowcount.py b/test/sql/test_rowcount.py index 16087b94cc..3399ba7ec9 100644 --- a/test/sql/test_rowcount.py +++ b/test/sql/test_rowcount.py @@ -65,6 +65,15 @@ class FoundRowsTest(fixtures.TestBase, AssertsExecutionResults): r = employees_table.update(department == 'C').execute(department='C') assert r.rowcount == 3 + @testing.requires.sane_rowcount_w_returning + def test_update_rowcount_return_defaults(self): + department = employees_table.c.department + stmt = employees_table.update(department == 'C').values( + name=employees_table.c.department + 'Z').return_defaults() + + r = stmt.execute() + assert r.rowcount == 3 + def test_raw_sql_rowcount(self): # test issue #3622, make sure eager rowcount is called for text with testing.db.connect() as conn: @@ -117,3 +126,4 @@ class FoundRowsTest(fixtures.TestBase, AssertsExecutionResults): eq_( r.rowcount, 2 ) +