]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Additional fixes to sane rowcount
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 31 Aug 2017 19:27:26 +0000 (15:27 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 31 Aug 2017 19:27:26 +0000 (15:27 -0400)
Implement rowcount assertions and single row check
for post_update as well as deletes.

Change-Id: I4e5ba7e8747bf0e0b41f569089eb8cdbf064b7a9
Fixes: #4062
lib/sqlalchemy/orm/persistence.py
test/orm/inheritance/test_basic.py
test/orm/test_bulk.py
test/orm/test_session.py
test/orm/test_versioning.py

index 24c9743d4adb9cb5e6941034df8c2423e7035db8..b8fd0c79f2810c14d378f8ab7063b6a6f3fa705a 100644 (file)
@@ -934,11 +934,16 @@ def _emit_post_update_statements(base_mapper, uowtransaction,
         records = list(records)
         connection = key[0]
 
-        assert_singlerow = connection.dialect.supports_sane_rowcount
+        assert_singlerow = (
+            connection.dialect.supports_sane_rowcount
+            if mapper.version_id_col is None
+            else connection.dialect.supports_sane_rowcount_returning
+        )
         assert_multirow = assert_singlerow and \
             connection.dialect.supports_sane_multi_rowcount
         allow_multirow = not needs_version_id or assert_multirow
 
+
         if not allow_multirow:
             check_rowcount = assert_singlerow
             for state, state_dict, mapper_rec, \
@@ -1043,7 +1048,12 @@ def _emit_delete_statements(base_mapper, uowtransaction, cached_connections,
                     stacklevel=12)
                 connection.execute(statement, del_objects)
         else:
-            connection.execute(statement, del_objects)
+            c = connection.execute(statement, del_objects)
+
+            if not need_version_id:
+                only_warn = True
+
+            rows_matched = c.rowcount
 
         if base_mapper.confirm_deleted_rows and \
                 rows_matched > -1 and expected != rows_matched:
index 65c4b030916b1117a9c3c982355a7dfa2057bbbc..007061d60cf0c57078ab145c68812738bba2c8a2 100644 (file)
@@ -1619,6 +1619,7 @@ class VersioningTest(fixtures.MappedTest):
               Column('parent', Integer, ForeignKey('base.id')))
 
     @testing.emits_warning(r".*updated rowcount")
+    @testing.requires.sane_rowcount_w_returning
     @engines.close_open_connections
     def test_save_update(self):
         subtable, base, stuff = (self.tables.subtable,
@@ -1675,6 +1676,7 @@ class VersioningTest(fixtures.MappedTest):
         sess2.flush()
 
     @testing.emits_warning(r".*(update|delete)d rowcount")
+    @testing.requires.sane_rowcount_w_returning
     def test_delete(self):
         subtable, base = self.tables.subtable, self.tables.base
 
index 59be0c88fc3e55404002d29c8742a6e99afcb702..0763fe70cd905cfbe3d68f132877b943f04a2eee 100644 (file)
@@ -33,6 +33,7 @@ class BulkInsertUpdateVersionId(BulkTest, fixtures.MappedTest):
 
         mapper(Foo, version_table, version_id_col=version_table.c.version_id)
 
+    @testing.emits_warning(r".*versioning cannot be verified")
     def test_bulk_insert_via_save(self):
         Foo = self.classes.Foo
 
@@ -45,6 +46,7 @@ class BulkInsertUpdateVersionId(BulkTest, fixtures.MappedTest):
             [Foo(version_id=1, value='value')]
         )
 
+    @testing.emits_warning(r".*versioning cannot be verified")
     def test_bulk_update_via_save(self):
         Foo = self.classes.Foo
 
@@ -349,9 +351,11 @@ class BulkUDTestAltColKeys(BulkTest, fixtures.MappedTest):
     def test_update_keys(self):
         self._test_update(self.classes.PersonKeys)
 
+    @testing.requires.updateable_autoincrement_pks
     def test_update_attrs(self):
         self._test_update(self.classes.PersonAttrs)
 
+    @testing.requires.updateable_autoincrement_pks
     def test_update_both(self):
         # want to make sure that before [ticket:3849], this did not have
         # a successful behavior or workaround
index 4fb90d603415875266fca916944bef3575ac96a4..363d5f7821bd4995ac5f57655b351d03bdedb772 100644 (file)
@@ -355,6 +355,7 @@ class SessionStateTest(_fixtures.FixtureTest):
 
         eq_(sess.query(User).count(), 1)
 
+    @testing.requires.sane_rowcount
     def test_deleted_adds_to_imap_unconditionally(self):
         users, User = self.tables.users, self.classes.User
 
index 4d9d6883acb1e056d887972cbfcc00b86b7ef8af..ed5f78465f23a9fa2dbf4f3ed04231c6bf9c072b 100644 (file)
@@ -71,6 +71,7 @@ class NullVersionIdTest(fixtures.MappedTest):
             "Instance does not contain a non-NULL version value",
             s1.commit)
 
+    @testing.emits_warning(r".*versioning cannot be verified")
     def test_null_version_id_update(self):
         Foo = self.classes.Foo
 
@@ -134,8 +135,8 @@ class VersioningTest(fixtures.MappedTest):
         finally:
             testing.db.dialect.supports_sane_rowcount = save
 
-    @testing.emits_warning_on(
-        '+zxjdbc', r'.*does not support (update|delete)d rowcount')
+    @testing.emits_warning(r".*versioning cannot be verified")
+    @testing.requires.sane_rowcount_w_returning
     def test_basic(self):
         Foo = self.classes.Foo
 
@@ -185,6 +186,7 @@ class VersioningTest(fixtures.MappedTest):
         else:
             s1.commit()
 
+    @testing.emits_warning(r".*versioning cannot be verified")
     def test_multiple_updates(self):
         Foo = self.classes.Foo
 
@@ -203,6 +205,7 @@ class VersioningTest(fixtures.MappedTest):
             [(f1.id, 'f1rev2', 2), (f2.id, 'f2rev2', 2)]
         )
 
+    @testing.emits_warning(r".*versioning cannot be verified")
     def test_bulk_insert(self):
         Foo = self.classes.Foo
 
@@ -216,6 +219,7 @@ class VersioningTest(fixtures.MappedTest):
             [(1, 'f1', 1), (2, 'f2', 1)]
         )
 
+    @testing.emits_warning(r".*versioning cannot be verified")
     def test_bulk_update(self):
         Foo = self.classes.Foo
 
@@ -240,8 +244,7 @@ class VersioningTest(fixtures.MappedTest):
             [(f1.id, 'f1rev2', 2), (f2.id, 'f2rev2', 2)]
         )
 
-    @testing.emits_warning_on(
-        '+zxjdbc', r'.*does not support (update|delete)d rowcount')
+    @testing.emits_warning(r".*versioning cannot be verified")
     def test_bump_version(self):
         """test that version number can be bumped.
 
@@ -274,7 +277,7 @@ class VersioningTest(fixtures.MappedTest):
         s1.commit()
         eq_(s1.query(Foo).count(), 0)
 
-    @testing.emits_warning(r'.*does not support updated rowcount')
+    @testing.emits_warning(r".*versioning cannot be verified")
     @engines.close_open_connections
     def test_versioncheck(self):
         """query.with_lockmode performs a 'version check' on an already loaded
@@ -310,7 +313,7 @@ class VersioningTest(fixtures.MappedTest):
         s1.close()
         s1.query(Foo).with_for_update(read=True).get(f1s1.id)
 
-    @testing.emits_warning(r'.*does not support updated rowcount')
+    @testing.emits_warning(r".*versioning cannot be verified")
     @engines.close_open_connections
     def test_versioncheck_legacy(self):
         """query.with_lockmode performs a 'version check' on an already loaded
@@ -360,7 +363,7 @@ class VersioningTest(fixtures.MappedTest):
         s1.commit()
         s1.query(Foo).with_lockmode('read').get(f1s1.id)
 
-    @testing.emits_warning(r'.*does not support updated rowcount')
+    @testing.emits_warning(r".*versioning cannot be verified")
     @engines.close_open_connections
     @testing.requires.update_nowait
     def test_versioncheck_for_update(self):
@@ -390,7 +393,7 @@ class VersioningTest(fixtures.MappedTest):
         s1.refresh(f1s1, with_for_update={"nowait": True})
         assert f1s1.version_id == f1s2.version_id
 
-    @testing.emits_warning(r'.*does not support updated rowcount')
+    @testing.emits_warning(r".*versioning cannot be verified")
     @engines.close_open_connections
     @testing.requires.update_nowait
     def test_versioncheck_for_update_legacy(self):
@@ -419,6 +422,7 @@ class VersioningTest(fixtures.MappedTest):
         s1.refresh(f1s1, lockmode='update_nowait')
         assert f1s1.version_id == f1s2.version_id
 
+    @testing.emits_warning(r".*versioning cannot be verified")
     def test_update_multi_missing_broken_multi_rowcount(self):
         @util.memoized_property
         def rowcount(self):
@@ -462,8 +466,7 @@ class VersioningTest(fixtures.MappedTest):
         assert f1s2.id == f1s1.id
         assert f1s2.value == f1s1.value
 
-    @testing.emits_warning_on(
-        '+zxjdbc', r'.*does not support updated rowcount')
+    @testing.emits_warning(r".*versioning cannot be verified")
     def test_merge_no_version(self):
         Foo = self.classes.Foo
 
@@ -481,8 +484,7 @@ class VersioningTest(fixtures.MappedTest):
         s1.commit()
         eq_(f3.version_id, 3)
 
-    @testing.emits_warning_on(
-        '+zxjdbc', r'.*does not support updated rowcount')
+    @testing.emits_warning(r".*versioning cannot be verified")
     def test_merge_correct_version(self):
         Foo = self.classes.Foo
 
@@ -500,8 +502,7 @@ class VersioningTest(fixtures.MappedTest):
         s1.commit()
         eq_(f3.version_id, 3)
 
-    @testing.emits_warning_on(
-        '+zxjdbc', r'.*does not support updated rowcount')
+    @testing.emits_warning(r".*versioning cannot be verified")
     def test_merge_incorrect_version(self):
         Foo = self.classes.Foo
 
@@ -523,8 +524,7 @@ class VersioningTest(fixtures.MappedTest):
             s1.merge, f2
         )
 
-    @testing.emits_warning_on(
-        '+zxjdbc', r'.*does not support updated rowcount')
+    @testing.emits_warning(r".*versioning cannot be verified")
     def test_merge_incorrect_version_not_in_session(self):
         Foo = self.classes.Foo
 
@@ -587,6 +587,7 @@ class VersionOnPostUpdateTest(fixtures.MappedTest):
             s.flush()
         return s, n1, n2
 
+    @testing.emits_warning(r".*versioning cannot be verified")
     def test_o2m_plain(self):
         s, n1, n2 = self._fixture(o2m=True, post_update=False)
 
@@ -596,6 +597,7 @@ class VersionOnPostUpdateTest(fixtures.MappedTest):
         eq_(n1.version_id, 1)
         eq_(n2.version_id, 2)
 
+    @testing.emits_warning(r".*versioning cannot be verified")
     def test_m2o_plain(self):
         s, n1, n2 = self._fixture(o2m=False, post_update=False)
 
@@ -605,6 +607,7 @@ class VersionOnPostUpdateTest(fixtures.MappedTest):
         eq_(n1.version_id, 2)
         eq_(n2.version_id, 1)
 
+    @testing.emits_warning(r".*versioning cannot be verified")
     def test_o2m_post_update(self):
         s, n1, n2 = self._fixture(o2m=True, post_update=True)
 
@@ -614,6 +617,7 @@ class VersionOnPostUpdateTest(fixtures.MappedTest):
         eq_(n1.version_id, 1)
         eq_(n2.version_id, 2)
 
+    @testing.emits_warning(r".*versioning cannot be verified")
     def test_m2o_post_update(self):
         s, n1, n2 = self._fixture(o2m=False, post_update=True)
 
@@ -623,6 +627,7 @@ class VersionOnPostUpdateTest(fixtures.MappedTest):
         eq_(n1.version_id, 2)
         eq_(n2.version_id, 1)
 
+    @testing.emits_warning(r".*versioning cannot be verified")
     def test_o2m_post_update_not_assoc_w_insert(self):
         s, n1, n2 = self._fixture(o2m=True, post_update=True, insert=False)
 
@@ -633,6 +638,7 @@ class VersionOnPostUpdateTest(fixtures.MappedTest):
         eq_(n1.version_id, 1)
         eq_(n2.version_id, 1)
 
+    @testing.emits_warning(r".*versioning cannot be verified")
     def test_m2o_post_update_not_assoc_w_insert(self):
         s, n1, n2 = self._fixture(o2m=False, post_update=True, insert=False)
 
@@ -643,6 +649,7 @@ class VersionOnPostUpdateTest(fixtures.MappedTest):
         eq_(n1.version_id, 1)
         eq_(n2.version_id, 1)
 
+    @testing.requires.sane_rowcount_w_returning
     def test_o2m_post_update_version_assert(self):
         Node = self.classes.Node
         s, n1, n2 = self._fixture(o2m=True, post_update=True)
@@ -663,6 +670,7 @@ class VersionOnPostUpdateTest(fixtures.MappedTest):
             s.flush
         )
 
+    @testing.requires.sane_rowcount_w_returning
     def test_m2o_post_update_version_assert(self):
         Node = self.classes.Node
 
@@ -805,6 +813,7 @@ class ColumnTypeTest(fixtures.MappedTest):
         s1 = Session()
         return s1
 
+    @testing.emits_warning(r".*versioning cannot be verified")
     @engines.close_open_connections
     def test_update(self):
         Foo = self.classes.Foo
@@ -855,8 +864,7 @@ class RowSwitchTest(fixtures.MappedTest):
                     C, uselist=False, cascade='all, delete-orphan')})
         mapper(C, c, version_id_col=c.c.version_id)
 
-    @testing.emits_warning_on(
-        '+zxjdbc', r'.*does not support updated rowcount')
+    @testing.emits_warning(r".*versioning cannot be verified")
     def test_row_switch(self):
         P = self.classes.P
 
@@ -870,8 +878,7 @@ class RowSwitchTest(fixtures.MappedTest):
         session.add(P(id='P1', data="really a row-switch"))
         session.commit()
 
-    @testing.emits_warning_on(
-        '+zxjdbc', r'.*does not support updated rowcount')
+    @testing.emits_warning(r".*versioning cannot be verified")
     def test_child_row_switch(self):
         P, C = self.classes.P, self.classes.C
 
@@ -934,8 +941,7 @@ class AlternateGeneratorTest(fixtures.MappedTest):
             version_id_generator=lambda x: make_uuid(),
         )
 
-    @testing.emits_warning_on(
-        '+zxjdbc', r'.*does not support updated rowcount')
+    @testing.emits_warning(r".*versioning cannot be verified")
     def test_row_switch(self):
         P = self.classes.P
 
@@ -949,8 +955,7 @@ class AlternateGeneratorTest(fixtures.MappedTest):
         session.add(P(id='P1', data="really a row-switch"))
         session.commit()
 
-    @testing.emits_warning_on(
-        '+zxjdbc', r'.*does not support (update|delete)d rowcount')
+    @testing.emits_warning(r".*versioning cannot be verified")
     def test_child_row_switch_one(self):
         P, C = self.classes.P, self.classes.C
 
@@ -969,8 +974,8 @@ class AlternateGeneratorTest(fixtures.MappedTest):
         p.c = C(data='child row-switch')
         session.commit()
 
-    @testing.emits_warning_on(
-        '+zxjdbc', r'.*does not support (update|delete)d rowcount')
+    @testing.emits_warning(r".*versioning cannot be verified")
+    @testing.requires.sane_rowcount_w_returning
     def test_child_row_switch_two(self):
         P = self.classes.P
 
@@ -1036,6 +1041,7 @@ class PlainInheritanceTest(fixtures.MappedTest):
         class Sub(Base):
             pass
 
+    @testing.emits_warning(r".*versioning cannot be verified")
     def test_update_child_table_only(self):
         Base, sub, base, Sub = (
             self.classes.Base, self.tables.sub, self.tables.base,
@@ -1259,6 +1265,7 @@ class ServerVersioningTest(fixtures.MappedTest):
     def test_update_col_eager_defaults(self):
         self._test_update_col(eager_defaults=True)
 
+    @testing.emits_warning(r".*versioning cannot be verified")
     def _test_update_col(self, **kw):
         sess = self._fixture(**kw)
 
@@ -1295,6 +1302,8 @@ class ServerVersioningTest(fixtures.MappedTest):
             )
         self.assert_sql_execution(testing.db, sess.flush, *statements)
 
+    @testing.emits_warning(r".*versioning cannot be verified")
+    @testing.requires.updateable_autoincrement_pks
     def test_sql_expr_bump(self):
         sess = self._fixture()
 
@@ -1310,6 +1319,8 @@ class ServerVersioningTest(fixtures.MappedTest):
 
         eq_(f1.version_id, 2)
 
+    @testing.emits_warning(r".*versioning cannot be verified")
+    @testing.requires.updateable_autoincrement_pks
     @testing.requires.returning
     def test_sql_expr_w_mods_bump(self):
         sess = self._fixture()
@@ -1327,6 +1338,7 @@ class ServerVersioningTest(fixtures.MappedTest):
         eq_(f1.id, 5)
         eq_(f1.version_id, 2)
 
+    @testing.emits_warning(r".*versioning cannot be verified")
     def test_multi_update(self):
         sess = self._fixture()
 
@@ -1419,6 +1431,7 @@ class ServerVersioningTest(fixtures.MappedTest):
         ]
         self.assert_sql_execution(testing.db, sess.flush, *statements)
 
+    @testing.requires.sane_rowcount_w_returning
     def test_concurrent_mod_err_expire_on_commit(self):
         sess = self._fixture()
 
@@ -1442,6 +1455,7 @@ class ServerVersioningTest(fixtures.MappedTest):
             sess.commit
         )
 
+    @testing.requires.sane_rowcount_w_returning
     def test_concurrent_mod_err_noexpire_on_commit(self):
         sess = self._fixture(expire_on_commit=False)
 
@@ -1505,6 +1519,7 @@ class ManualVersionTest(fixtures.MappedTest):
 
         eq_(a1.vid, 1)
 
+    @testing.emits_warning(r".*versioning cannot be verified")
     def test_update(self):
         sess = Session()
         a1 = self.classes.A()
@@ -1521,6 +1536,7 @@ class ManualVersionTest(fixtures.MappedTest):
 
         eq_(a1.vid, 2)
 
+    @testing.requires.sane_rowcount_w_returning
     def test_update_concurrent_check(self):
         sess = Session()
         a1 = self.classes.A()
@@ -1538,6 +1554,7 @@ class ManualVersionTest(fixtures.MappedTest):
             sess.commit
         )
 
+    @testing.emits_warning(r".*versioning cannot be verified")
     def test_update_version_conditional(self):
         sess = Session()
         a1 = self.classes.A()
@@ -1600,6 +1617,7 @@ class ManualInheritanceVersionTest(fixtures.MappedTest):
         mapper(
             cls.classes.B, cls.tables.b, inherits=cls.classes.A)
 
+    @testing.emits_warning(r".*versioning cannot be verified")
     def test_no_increment(self):
         sess = Session()
         b1 = self.classes.B()