]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add new sane_rowcount_w_returning flag
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 31 Aug 2017 17:12:50 +0000 (13:12 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 31 Aug 2017 17:20:57 +0000 (13:20 -0400)
Added a new class of "rowcount support" for dialects that is specific to
when "RETURNING", which on SQL Server looks like "OUTPUT inserted", is in
use, as the PyODBC backend isn't able to give us rowcount on an UPDATE or
DELETE statement when OUTPUT is in effect.  This primarily affects the ORM
when a flush is updating a row that contains server-calcluated values,
raising an error if the backend does not return the expected row count.
PyODBC now states that it supports rowcount except if OUTPUT.inserted is
present, which is taken into account by the ORM during a flush as to
whether it will look for a rowcount.

ORM tests are implicit in existing tests run against PyODBC

Fixes: #4062
Change-Id: Iff17cbe4c7a5742971ed85a4d58660c18cc569c2

doc/build/changelog/unreleased_12/4062.rst [new file with mode: 0644]
lib/sqlalchemy/connectors/pyodbc.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/orm/persistence.py
lib/sqlalchemy/testing/requirements.py
test/requirements.py
test/sql/test_rowcount.py

diff --git a/doc/build/changelog/unreleased_12/4062.rst b/doc/build/changelog/unreleased_12/4062.rst
new file mode 100644 (file)
index 0000000..3a89a1a
--- /dev/null
@@ -0,0 +1,13 @@
+.. change::
+    :tags: bug, mssql, orm
+    :tickets: 4062
+
+    Added a new class of "rowcount support" for dialects that is specific to
+    when "RETURNING", which on SQL Server looks like "OUTPUT inserted", is in
+    use, as the PyODBC backend isn't able to give us rowcount on an UPDATE or
+    DELETE statement when OUTPUT is in effect.  This primarily affects the ORM
+    when a flush is updating a row that contains server-calcluated values,
+    raising an error if the backend does not return the expected row count.
+    PyODBC now states that it supports rowcount except if OUTPUT.inserted is
+    present, which is taken into account by the ORM during a flush as to
+    whether it will look for a rowcount.
index 65fe37212e0abfedcd7f525625a57dc74220f02e..66acf007252a476f6276c2f5a93d5ed0f9790f64 100644 (file)
@@ -16,6 +16,7 @@ import re
 class PyODBCConnector(Connector):
     driver = 'pyodbc'
 
+    supports_sane_rowcount_returning = False
     supports_sane_multi_rowcount = False
 
     if util.py2k:
index d1b54ab01a0a568e921573fbd852c0b7a535fa4e..8b72c0001fb8c88f2ec708606b3e058a29383b78 100644 (file)
@@ -249,6 +249,10 @@ class DefaultDialect(interfaces.Dialect):
     def dialect_description(self):
         return self.name + "+" + self.driver
 
+    @property
+    def supports_sane_rowcount_returning(self):
+        return self.supports_sane_rowcount
+
     @classmethod
     def get_pool_class(cls, url):
         return getattr(cls, 'poolclass', pool.QueuePool)
index faacd018e1c5852a812b806d33f0178865da205a..24c9743d4adb9cb5e6941034df8c2423e7035db8 100644 (file)
@@ -693,22 +693,28 @@ def _emit_update_statements(base_mapper, uowtransaction,
         records = list(records)
 
         statement = cached_stmt
-
-        # TODO: would be super-nice to not have to determine this boolean
-        # inside the loop here, in the 99.9999% of the time there's only
-        # one connection in use
-        assert_singlerow = connection.dialect.supports_sane_rowcount
-        assert_multirow = assert_singlerow and \
-            connection.dialect.supports_sane_multi_rowcount
-        allow_multirow = has_all_defaults and not needs_version_id
+        return_defaults = False
 
         if not has_all_pks:
             statement = statement.return_defaults()
+            return_defaults = True
         elif bookkeeping and not has_all_defaults and \
                 mapper.base_mapper.eager_defaults:
             statement = statement.return_defaults()
+            return_defaults = True
         elif mapper.version_id_col is not None:
             statement = statement.return_defaults(mapper.version_id_col)
+            return_defaults = True
+
+        assert_singlerow = (
+            connection.dialect.supports_sane_rowcount
+            if not return_defaults
+            else connection.dialect.supports_sane_rowcount_returning
+        )
+
+        assert_multirow = assert_singlerow and \
+            connection.dialect.supports_sane_multi_rowcount
+        allow_multirow = has_all_defaults and not needs_version_id
 
         if hasvalue:
             for state, state_dict, params, mapper, \
@@ -728,7 +734,7 @@ def _emit_update_statements(base_mapper, uowtransaction,
                         c.context.compiled_parameters[0],
                         value_params)
                 rows += c.rowcount
-                check_rowcount = True
+                check_rowcount = assert_singlerow
         else:
             if not allow_multirow:
                 check_rowcount = assert_singlerow
index 08a7b1cedc0ff29f46edf6ce5d2e8c89b2c56007..327362bf6c64d55545f3ec3d3a40a9ff505af498 100644 (file)
@@ -193,6 +193,29 @@ class SuiteRequirements(Requirements):
 
         return exclusions.open()
 
+
+    @property
+    def sane_rowcount(self):
+        return exclusions.skip_if(
+            lambda config: not config.db.dialect.supports_sane_rowcount,
+            "driver doesn't support 'sane' rowcount"
+        )
+
+    @property
+    def sane_multi_rowcount(self):
+        return exclusions.fails_if(
+            lambda config: not config.db.dialect.supports_sane_multi_rowcount,
+            "driver %(driver)s %(doesnt_support)s 'sane' multi row count"
+        )
+
+    @property
+    def sane_rowcount_w_returning(self):
+        return exclusions.fails_if(
+            lambda config:
+                not config.db.dialect.supports_sane_rowcount_returning,
+            "driver doesn't support 'sane' rowcount when returning is on"
+        )
+
     @property
     def empty_inserts(self):
         """target platform supports INSERT with no values, i.e.
index 4f01eac9bcd3e3077ad98f8c4eaa575dc356a866..0362e28d139d2abd96aa5fc904533844c6389c19 100644 (file)
@@ -558,13 +558,6 @@ class DefaultRequirements(SuiteRequirements):
             exclude('mysql', '<', (4, 1, 1), 'no unicode connection support'),
         ])
 
-    @property
-    def sane_rowcount(self):
-        return skip_if(
-            lambda config: not config.db.dialect.supports_sane_rowcount,
-            "driver doesn't support 'sane' rowcount"
-        )
-
     @property
     def emulated_lastrowid(self):
         """"target dialect retrieves cursor.lastrowid or an equivalent
@@ -593,13 +586,6 @@ class DefaultRequirements(SuiteRequirements):
                                        'sqlite+pysqlite',
                                        'sqlite+pysqlcipher')
 
-    @property
-    def sane_multi_rowcount(self):
-        return fails_if(
-            lambda config: not config.db.dialect.supports_sane_multi_rowcount,
-            "driver %(driver)s %(doesnt_support)s 'sane' multi row count"
-        )
-
     @property
     def nullsordering(self):
         """Target backends that support nulls ordering."""
index 16087b94ccd749880c77dadcc41a1e920b746dff..3399ba7ec9f43672e28cdba26f5515d072a927c9 100644 (file)
@@ -65,6 +65,15 @@ class FoundRowsTest(fixtures.TestBase, AssertsExecutionResults):
         r = employees_table.update(department == 'C').execute(department='C')
         assert r.rowcount == 3
 
+    @testing.requires.sane_rowcount_w_returning
+    def test_update_rowcount_return_defaults(self):
+        department = employees_table.c.department
+        stmt = employees_table.update(department == 'C').values(
+            name=employees_table.c.department + 'Z').return_defaults()
+
+        r = stmt.execute()
+        assert r.rowcount == 3
+
     def test_raw_sql_rowcount(self):
         # test issue #3622, make sure eager rowcount is called for text
         with testing.db.connect() as conn:
@@ -117,3 +126,4 @@ class FoundRowsTest(fixtures.TestBase, AssertsExecutionResults):
         eq_(
             r.rowcount, 2
         )
+