]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- further fixes and even better tests for this block
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 19 Jan 2015 16:47:28 +0000 (11:47 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 19 Jan 2015 16:47:28 +0000 (11:47 -0500)
lib/sqlalchemy/orm/persistence.py
test/orm/test_versioning.py

index 7f81a5c99abf8664b856e36611406a0b55aad405..e553f399de6ee73561069fb1b114b774ebc624bc 100644 (file)
@@ -642,8 +642,10 @@ def _emit_update_statements(base_mapper, uowtransaction,
                         c.context.compiled_parameters[0],
                         value_params)
                 rows += c.rowcount
+                check_rowcount = True
         else:
             if not allow_multirow:
+                check_rowcount = assert_singlerow
                 for state, state_dict, params, mapper, \
                         connection, value_params in records:
                     c = cached_connections[connection].\
@@ -661,6 +663,11 @@ def _emit_update_statements(base_mapper, uowtransaction,
             else:
                 multiparams = [rec[2] for rec in records]
 
+                check_rowcount = assert_multirow or (
+                    assert_singlerow and
+                    len(multiparams) == 1
+                )
+
                 c = cached_connections[connection].\
                     execute(statement, multiparams)
 
@@ -677,9 +684,7 @@ def _emit_update_statements(base_mapper, uowtransaction,
                         c.context.compiled_parameters[0],
                         value_params)
 
-        if hasvalue or assert_multirow or (
-                assert_singlerow and
-                len(multiparams)) == 1:
+        if check_rowcount:
             if rows != len(records):
                 raise orm_exc.StaleDataError(
                     "UPDATE statement on table '%s' expected to "
index 55ce586b5f5529c2aa6f98a121b501ce5523d529..8348cb58832645953c4ff2addf6577aacbe2fb52 100644 (file)
@@ -1,7 +1,8 @@
 import datetime
 import sqlalchemy as sa
-from sqlalchemy.testing import engines
+from sqlalchemy.testing import engines, config
 from sqlalchemy import testing
+from sqlalchemy.testing.mock import patch
 from sqlalchemy import (
     Integer, String, Date, ForeignKey, orm, exc, select, TypeDecorator)
 from sqlalchemy.testing.schema import Table, Column
@@ -12,6 +13,7 @@ from sqlalchemy.testing import (
     eq_, assert_raises, assert_raises_message, fixtures)
 from sqlalchemy.testing.assertsql import CompiledSQL
 import uuid
+from sqlalchemy import util
 
 
 def make_uuid():
@@ -223,6 +225,30 @@ class VersioningTest(fixtures.MappedTest):
         s1.refresh(f1s1, lockmode='update_nowait')
         assert f1s1.version_id == f1s2.version_id
 
+    def test_update_multi_missing_broken_multi_rowcount(self):
+        @util.memoized_property
+        def rowcount(self):
+            if len(self.context.compiled_parameters) > 1:
+                return -1
+            else:
+                return self.context.rowcount
+
+        with patch.object(
+                config.db.dialect, "supports_sane_multi_rowcount", False):
+            with patch(
+                    "sqlalchemy.engine.result.ResultProxy.rowcount",
+                    rowcount):
+
+                Foo = self.classes.Foo
+                s1 = self._fixture()
+                f1s1 = Foo(value='f1 value')
+                s1.add(f1s1)
+                s1.commit()
+
+                f1s1.value = 'f2 value'
+                s1.flush()
+                eq_(f1s1.version_id, 2)
+
     @testing.emits_warning(r'.*does not support updated rowcount')
     @engines.close_open_connections
     def test_noversioncheck(self):