]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- fix a regression from ref #3178, where dialects that don't actually support
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 18 Jan 2015 02:36:52 +0000 (21:36 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 18 Jan 2015 02:58:49 +0000 (21:58 -0500)
sane multi rowcount (e.g. pyodbc) would fail on multirow update.  add
a test that mocks this breakage into plain dialects

lib/sqlalchemy/orm/persistence.py
test/orm/test_unitofworkv2.py

index f477e1dd7becb87b329115dbfd04640bcaab07e2..dbf1d3eb41866546518987c0d1f77d539a2c716e 100644 (file)
@@ -617,6 +617,14 @@ def _emit_update_statements(base_mapper, uowtransaction,
         rows = 0
         records = list(records)
 
+        # TODO: would be super-nice to not have to determine this boolean
+        # inside the loop here, in the 99.9999% of the time there's only
+        # one connection in use
+        assert_singlerow = connection.dialect.supports_sane_rowcount
+        assert_multirow = assert_singlerow and \
+            connection.dialect.supports_sane_multi_rowcount
+        allow_multirow = not needs_version_id or assert_multirow
+
         if hasvalue:
             for state, state_dict, params, mapper, \
                     connection, value_params in records:
@@ -635,9 +643,7 @@ def _emit_update_statements(base_mapper, uowtransaction,
                         value_params)
                 rows += c.rowcount
         else:
-            if needs_version_id and \
-                not connection.dialect.supports_sane_multi_rowcount and \
-                    connection.dialect.supports_sane_rowcount:
+            if not allow_multirow:
                 for state, state_dict, params, mapper, \
                         connection, value_params in records:
                     c = cached_connections[connection].\
@@ -654,6 +660,7 @@ def _emit_update_statements(base_mapper, uowtransaction,
                     rows += c.rowcount
             else:
                 multiparams = [rec[2] for rec in records]
+
                 c = cached_connections[connection].\
                     execute(statement, multiparams)
 
@@ -670,7 +677,8 @@ def _emit_update_statements(base_mapper, uowtransaction,
                         c.context.compiled_parameters[0],
                         value_params)
 
-        if connection.dialect.supports_sane_rowcount:
+        if assert_multirow or assert_singlerow and \
+                len(multiparams) == 1:
             if rows != len(records):
                 raise orm_exc.StaleDataError(
                     "UPDATE statement on table '%s' expected to "
index 374a77237cc72e29c52824bf83ac2a84e28c4d41..681b104cff675881ae4371a1f4a899caac5afeab 100644 (file)
@@ -3,13 +3,13 @@ from sqlalchemy import testing
 from sqlalchemy.testing import engines
 from sqlalchemy.testing.schema import Table, Column
 from test.orm import _fixtures
-from sqlalchemy import exc
-from sqlalchemy.testing import fixtures
+from sqlalchemy import exc, util
+from sqlalchemy.testing import fixtures, config
 from sqlalchemy import Integer, String, ForeignKey, func
 from sqlalchemy.orm import mapper, relationship, backref, \
     create_session, unitofwork, attributes,\
     Session, exc as orm_exc
-from sqlalchemy.testing.mock import Mock
+from sqlalchemy.testing.mock import Mock, patch
 from sqlalchemy.testing.assertsql import AllOf, CompiledSQL
 from sqlalchemy import event
 
@@ -1473,6 +1473,67 @@ class BasicStaleChecksTest(fixtures.MappedTest):
             sess.flush
         )
 
+    def test_update_single_missing_broken_multi_rowcount(self):
+        @util.memoized_property
+        def rowcount(self):
+            if len(self.context.compiled_parameters) > 1:
+                return -1
+            else:
+                return self.context.rowcount
+
+        with patch.object(
+                config.db.dialect, "supports_sane_multi_rowcount", False):
+            with patch(
+                    "sqlalchemy.engine.result.ResultProxy.rowcount",
+                    rowcount):
+                Parent, Child = self._fixture()
+                sess = Session()
+                p1 = Parent(id=1, data=2)
+                sess.add(p1)
+                sess.flush()
+
+                sess.execute(self.tables.parent.delete())
+
+                p1.data = 3
+                assert_raises_message(
+                    orm_exc.StaleDataError,
+                    "UPDATE statement on table 'parent' expected to "
+                    "update 1 row\(s\); 0 were matched.",
+                    sess.flush
+                )
+
+    def test_update_multi_missing_broken_multi_rowcount(self):
+        @util.memoized_property
+        def rowcount(self):
+            if len(self.context.compiled_parameters) > 1:
+                return -1
+            else:
+                return self.context.rowcount
+
+        with patch.object(
+                config.db.dialect, "supports_sane_multi_rowcount", False):
+            with patch(
+                    "sqlalchemy.engine.result.ResultProxy.rowcount",
+                    rowcount):
+                Parent, Child = self._fixture()
+                sess = Session()
+                p1 = Parent(id=1, data=2)
+                p2 = Parent(id=2, data=3)
+                sess.add_all([p1, p2])
+                sess.flush()
+
+                sess.execute(self.tables.parent.delete().where(Parent.id == 1))
+
+                p1.data = 3
+                p2.data = 4
+                sess.flush()  # no exception
+
+                # update occurred for remaining row
+                eq_(
+                    sess.query(Parent.id, Parent.data).all(),
+                    [(2, 4)]
+                )
+
     @testing.requires.sane_multi_rowcount
     def test_delete_multi_missing_warning(self):
         Parent, Child = self._fixture()
@@ -1544,6 +1605,7 @@ class BatchInsertsTest(fixtures.MappedTest, testing.AssertsExecutionResults):
             T(id=10, data='t10', def_='def3'),
             T(id=11, data='t11'),
         ])
+
         self.assert_sql_execution(
             testing.db,
             sess.flush,