]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
dont add non-server-side cols to returning for versioning
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 2 Feb 2023 22:22:22 +0000 (17:22 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 3 Feb 2023 14:45:22 +0000 (09:45 -0500)
Fixed regression where using the :paramref:`_orm.Mapper.version_id_col`
feature with a regular Python-side incrementing column would fail to work
for SQLite and other databases that don't support "rowcount" with
"RETURNING", as "RETURNING" would be assumed for such columns even though
that's not what actually takes place.

Fixes: #9228
Change-Id: I6a1a7fa4d63e183fe4ef0fbfd3cb5cac03b26d78

doc/build/changelog/unreleased_20/9228.rst [new file with mode: 0644]
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/persistence.py
test/orm/inheritance/test_basic.py
test/orm/test_versioning.py

diff --git a/doc/build/changelog/unreleased_20/9228.rst b/doc/build/changelog/unreleased_20/9228.rst
new file mode 100644 (file)
index 0000000..7e96c24
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: orm, bug, regression
+    :tickets: 9228
+
+    Fixed regression where using the :paramref:`_orm.Mapper.version_id_col`
+    feature with a regular Python-side incrementing column would fail to work
+    for SQLite and other databases that don't support "rowcount" with
+    "RETURNING", as "RETURNING" would be assumed for such columns even though
+    that's not what actually takes place.
index a3b209e4a6507042d40de7394df0bc892c4f6375..bb7e470ff654ef19a704e536c17e1853390022cc 100644 (file)
@@ -83,6 +83,7 @@ from ..sql import util as sql_util
 from ..sql import visitors
 from ..sql.cache_key import MemoizedHasCacheKey
 from ..sql.elements import KeyedColumnElement
+from ..sql.schema import Column
 from ..sql.schema import Table
 from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
 from ..util import HasMemoized
@@ -112,7 +113,6 @@ if TYPE_CHECKING:
     from ..sql.base import ReadOnlyColumnCollection
     from ..sql.elements import ColumnClause
     from ..sql.elements import ColumnElement
-    from ..sql.schema import Column
     from ..sql.selectable import FromClause
     from ..util import OrderedSet
 
@@ -2522,6 +2522,24 @@ class Mapper(
 
         return from_obj
 
+    @HasMemoized.memoized_attribute
+    def _version_id_has_server_side_value(self) -> bool:
+        vid_col = self.version_id_col
+
+        if vid_col is None:
+            return False
+
+        elif not isinstance(vid_col, Column):
+            return True
+        else:
+            return vid_col.server_default is not None or (
+                vid_col.default is not None
+                and (
+                    not vid_col.default.is_scalar
+                    and not vid_col.default.is_callable
+                )
+            )
+
     @HasMemoized.memoized_attribute
     def _single_table_criterion(self):
         if self.single and self.inherits and self.polymorphic_on is not None:
index cc7e321b43f39e9bd86cac9ab163cbfb59cfb964..b8368001b2ba678605f12c0942f562e7841d3a94 100644 (file)
@@ -801,7 +801,7 @@ def _emit_update_statements(
             )
             return_defaults = True
 
-        if mapper.version_id_col is not None:
+        if mapper._version_id_has_server_side_value:
             statement = statement.return_defaults(mapper.version_id_col)
             return_defaults = True
 
@@ -1268,13 +1268,16 @@ def _emit_post_update_statements(
 
         stmt = table.update().where(clauses)
 
-        if mapper.version_id_col is not None:
-            stmt = stmt.return_defaults(mapper.version_id_col)
-
         return stmt
 
     statement = base_mapper._memo(("post_update", table), update_stmt)
 
+    if mapper._version_id_has_server_side_value:
+        statement = statement.return_defaults(mapper.version_id_col)
+        return_defaults = True
+    else:
+        return_defaults = False
+
     # execute each UPDATE in the order according to the original
     # list of states to guarantee row access order, but
     # also group them into common (connection, cols) sets
@@ -1290,7 +1293,7 @@ def _emit_post_update_statements(
 
         assert_singlerow = (
             connection.dialect.supports_sane_rowcount
-            if mapper.version_id_col is None
+            if not return_defaults
             else connection.dialect.supports_sane_rowcount_returning
         )
         assert_multirow = (
index 37368f3ad63ca2d56faf12a516e9369539b5dcae..02f3527864f12df031796d4e63659e84e711b804 100644 (file)
@@ -2101,8 +2101,7 @@ class VersioningTest(fixtures.MappedTest):
             Column("parent", Integer, ForeignKey("base.id")),
         )
 
-    @testing.emits_warning(r".*updated rowcount")
-    @testing.requires.sane_rowcount_w_returning
+    @testing.requires.sane_rowcount
     def test_save_update(self):
         subtable, base, stuff = (
             self.tables.subtable,
@@ -2170,8 +2169,7 @@ class VersioningTest(fixtures.MappedTest):
         s2.subdata = "sess2 subdata"
         sess2.flush()
 
-    @testing.emits_warning(r".*(update|delete)d rowcount")
-    @testing.requires.sane_rowcount_w_returning
+    @testing.requires.sane_rowcount
     def test_delete(self):
         subtable, base = self.tables.subtable, self.tables.base
 
index 1a1801311b3841bc062523d482aaa46d0fbe04f9..f6b9f18fc4245fd9b21750a232e754e9438579cc 100644 (file)
@@ -140,9 +140,7 @@ class NullVersionIdTest(fixtures.MappedTest):
 
         f1.value = "f1rev2"
 
-        with conditional_sane_rowcount_warnings(
-            update=True, only_returning=True
-        ):
+        with conditional_sane_rowcount_warnings(update=True):
             f1.version_id = None
             assert_raises_message(
                 sa.orm.exc.FlushError,
@@ -209,24 +207,20 @@ class VersioningTest(fixtures.MappedTest):
         s1.commit()
 
         f1.value = "f1rev2"
-        with conditional_sane_rowcount_warnings(
-            update=True, only_returning=True
-        ):
+        with conditional_sane_rowcount_warnings(update=True):
             s1.commit()
 
         s2 = fixture_session()
         f1_s = s2.get(Foo, f1.id)
         f1_s.value = "f1rev3"
-        with conditional_sane_rowcount_warnings(
-            update=True, only_returning=True
-        ):
+        with conditional_sane_rowcount_warnings(update=True):
             s2.commit()
 
         f1.value = "f1rev3mine"
 
         # Only dialects with a sane rowcount can detect the
         # StaleDataError
-        if testing.db.dialect.supports_sane_rowcount_returning:
+        if testing.db.dialect.supports_sane_rowcount:
             assert_raises_message(
                 sa.orm.exc.StaleDataError,
                 r"UPDATE statement on table 'version_table' expected "
@@ -235,9 +229,7 @@ class VersioningTest(fixtures.MappedTest):
             ),
             s1.rollback()
         else:
-            with conditional_sane_rowcount_warnings(
-                update=True, only_returning=True
-            ):
+            with conditional_sane_rowcount_warnings(update=True):
                 s1.commit()
 
         # new in 0.5 !  don't need to close the session
@@ -245,9 +237,7 @@ class VersioningTest(fixtures.MappedTest):
         f2 = s1.get(Foo, f2.id)
 
         f1_s.value = "f1rev4"
-        with conditional_sane_rowcount_warnings(
-            update=True, only_returning=True
-        ):
+        with conditional_sane_rowcount_warnings(update=True):
             s2.commit()
 
         s1.delete(f1)
@@ -275,9 +265,7 @@ class VersioningTest(fixtures.MappedTest):
 
         f1.value = "f1rev2"
         f2.value = "f2rev2"
-        with conditional_sane_rowcount_warnings(
-            update=True, only_returning=True
-        ):
+        with conditional_sane_rowcount_warnings(update=True):
             s1.commit()
 
         eq_(
@@ -306,9 +294,7 @@ class VersioningTest(fixtures.MappedTest):
         s1.add_all((f1, f2))
         s1.commit()
 
-        with conditional_sane_rowcount_warnings(
-            update=True, only_returning=True
-        ):
+        with conditional_sane_rowcount_warnings(update=True):
             s1.bulk_update_mappings(
                 Foo,
                 [
@@ -340,9 +326,7 @@ class VersioningTest(fixtures.MappedTest):
         s1.commit()
         eq_(f1.version_id, 1)
         f1.version_id = 2
-        with conditional_sane_rowcount_warnings(
-            update=True, only_returning=True
-        ):
+        with conditional_sane_rowcount_warnings(update=True):
             s1.commit()
         eq_(f1.version_id, 2)
 
@@ -350,9 +334,7 @@ class VersioningTest(fixtures.MappedTest):
         # is honored
         f1.version_id = 4
         f1.value = "something new"
-        with conditional_sane_rowcount_warnings(
-            update=True, only_returning=True
-        ):
+        with conditional_sane_rowcount_warnings(update=True):
             s1.commit()
         eq_(f1.version_id, 4)
 
@@ -377,9 +359,7 @@ class VersioningTest(fixtures.MappedTest):
         s2 = fixture_session()
         f1s2 = s2.get(Foo, f1s1.id)
         f1s2.value = "f1 new value"
-        with conditional_sane_rowcount_warnings(
-            update=True, only_returning=True
-        ):
+        with conditional_sane_rowcount_warnings(update=True):
             s2.commit()
 
         # load, version is wrong
@@ -465,9 +445,7 @@ class VersioningTest(fixtures.MappedTest):
             s1.commit()
 
             f1s1.value = "f2 value"
-            with conditional_sane_rowcount_warnings(
-                update=True, only_returning=True
-            ):
+            with conditional_sane_rowcount_warnings(update=True):
                 s1.flush()
             eq_(f1s1.version_id, 2)
 
@@ -532,17 +510,13 @@ class VersioningTest(fixtures.MappedTest):
         s1.commit()
 
         f1.value = "f2"
-        with conditional_sane_rowcount_warnings(
-            update=True, only_returning=True
-        ):
+        with conditional_sane_rowcount_warnings(update=True):
             s1.commit()
 
         f2 = Foo(id=f1.id, value="f3")
         f3 = s1.merge(f2)
         assert f3 is f1
-        with conditional_sane_rowcount_warnings(
-            update=True, only_returning=True
-        ):
+        with conditional_sane_rowcount_warnings(update=True):
             s1.commit()
         eq_(f3.version_id, 3)
 
@@ -555,17 +529,13 @@ class VersioningTest(fixtures.MappedTest):
         s1.commit()
 
         f1.value = "f2"
-        with conditional_sane_rowcount_warnings(
-            update=True, only_returning=True
-        ):
+        with conditional_sane_rowcount_warnings(update=True):
             s1.commit()
 
         f2 = Foo(id=f1.id, value="f3", version_id=2)
         f3 = s1.merge(f2)
         assert f3 is f1
-        with conditional_sane_rowcount_warnings(
-            update=True, only_returning=True
-        ):
+        with conditional_sane_rowcount_warnings(update=True):
             s1.commit()
         eq_(f3.version_id, 3)
 
@@ -578,9 +548,7 @@ class VersioningTest(fixtures.MappedTest):
         s1.commit()
 
         f1.value = "f2"
-        with conditional_sane_rowcount_warnings(
-            update=True, only_returning=True
-        ):
+        with conditional_sane_rowcount_warnings(update=True):
             s1.commit()
 
         f2 = Foo(id=f1.id, value="f3", version_id=1)
@@ -603,9 +571,7 @@ class VersioningTest(fixtures.MappedTest):
         s1.commit()
 
         f1.value = "f2"
-        with conditional_sane_rowcount_warnings(
-            update=True, only_returning=True
-        ):
+        with conditional_sane_rowcount_warnings(update=True):
             s1.commit()
 
         f2 = Foo(id=f1.id, value="f3", version_id=1)
@@ -670,9 +636,7 @@ class VersionOnPostUpdateTest(fixtures.MappedTest):
         s, n1, n2 = self._fixture(o2m=True, post_update=False)
 
         n1.related.append(n2)
-        with conditional_sane_rowcount_warnings(
-            update=True, only_returning=True
-        ):
+        with conditional_sane_rowcount_warnings(update=True):
             s.flush()
 
         eq_(n1.version_id, 1)
@@ -682,9 +646,7 @@ class VersionOnPostUpdateTest(fixtures.MappedTest):
         s, n1, n2 = self._fixture(o2m=False, post_update=False)
 
         n1.related = n2
-        with conditional_sane_rowcount_warnings(
-            update=True, only_returning=True
-        ):
+        with conditional_sane_rowcount_warnings(update=True):
             s.flush()
 
         eq_(n1.version_id, 2)
@@ -694,9 +656,7 @@ class VersionOnPostUpdateTest(fixtures.MappedTest):
         s, n1, n2 = self._fixture(o2m=True, post_update=True)
 
         n1.related.append(n2)
-        with conditional_sane_rowcount_warnings(
-            update=True, only_returning=True
-        ):
+        with conditional_sane_rowcount_warnings(update=True):
             s.flush()
 
         eq_(n1.version_id, 1)
@@ -706,9 +666,7 @@ class VersionOnPostUpdateTest(fixtures.MappedTest):
         s, n1, n2 = self._fixture(o2m=False, post_update=True)
 
         n1.related = n2
-        with conditional_sane_rowcount_warnings(
-            update=True, only_returning=True
-        ):
+        with conditional_sane_rowcount_warnings(update=True):
             s.flush()
 
         eq_(n1.version_id, 2)
@@ -719,9 +677,7 @@ class VersionOnPostUpdateTest(fixtures.MappedTest):
 
         n1.related.append(n2)
         s.add_all([n1, n2])
-        with conditional_sane_rowcount_warnings(
-            update=True, only_returning=True
-        ):
+        with conditional_sane_rowcount_warnings(update=True):
             s.flush()
 
         eq_(n1.version_id, 1)
@@ -732,15 +688,13 @@ class VersionOnPostUpdateTest(fixtures.MappedTest):
 
         n1.related = n2
         s.add_all([n1, n2])
-        with conditional_sane_rowcount_warnings(
-            update=True, only_returning=True
-        ):
+        with conditional_sane_rowcount_warnings(update=True):
             s.flush()
 
         eq_(n1.version_id, 1)
         eq_(n2.version_id, 1)
 
-    @testing.requires.sane_rowcount_w_returning
+    @testing.requires.sane_rowcount
     def test_o2m_post_update_version_assert(self):
         Node = self.classes.Node
         s, n1, n2 = self._fixture(o2m=True, post_update=True)
@@ -782,7 +736,7 @@ class VersionOnPostUpdateTest(fixtures.MappedTest):
             ):
                 s.flush()
 
-    @testing.requires.sane_rowcount_w_returning
+    @testing.requires.sane_rowcount
     def test_m2o_post_update_version_assert(self):
         Node = self.classes.Node
 
@@ -944,9 +898,7 @@ class ColumnTypeTest(fixtures.MappedTest):
         s1.commit()
 
         f1.value = "f1rev2"
-        with conditional_sane_rowcount_warnings(
-            update=True, only_returning=True
-        ):
+        with conditional_sane_rowcount_warnings(update=True):
             s1.commit()
 
 
@@ -1007,9 +959,7 @@ class RowSwitchTest(fixtures.MappedTest):
         p = session.query(P).first()
         session.delete(p)
         session.add(P(id="P1", data="really a row-switch"))
-        with conditional_sane_rowcount_warnings(
-            update=True, only_returning=True
-        ):
+        with conditional_sane_rowcount_warnings(update=True):
             session.commit()
 
     def test_child_row_switch(self):
@@ -1028,9 +978,7 @@ class RowSwitchTest(fixtures.MappedTest):
 
         p = session.query(P).first()
         p.c = C(data="child row-switch")
-        with conditional_sane_rowcount_warnings(
-            update=True, only_returning=True
-        ):
+        with conditional_sane_rowcount_warnings(update=True):
             session.commit()
 
 
@@ -1096,9 +1044,7 @@ class AlternateGeneratorTest(fixtures.MappedTest):
         p = session.query(P).first()
         session.delete(p)
         session.add(P(id="P1", data="really a row-switch"))
-        with conditional_sane_rowcount_warnings(
-            update=True, only_returning=True
-        ):
+        with conditional_sane_rowcount_warnings(update=True):
             session.commit()
 
     def test_child_row_switch_one(self):
@@ -1117,12 +1063,10 @@ class AlternateGeneratorTest(fixtures.MappedTest):
 
         p = session.query(P).first()
         p.c = C(data="child row-switch")
-        with conditional_sane_rowcount_warnings(
-            update=True, only_returning=True
-        ):
+        with conditional_sane_rowcount_warnings(update=True):
             session.commit()
 
-    @testing.requires.sane_rowcount_w_returning
+    @testing.requires.sane_rowcount
     def test_child_row_switch_two(self):
         P = self.classes.P
 
@@ -1206,9 +1150,7 @@ class PlainInheritanceTest(fixtures.MappedTest):
         s.commit()
 
         s1.sub_data = "s2"
-        with conditional_sane_rowcount_warnings(
-            update=True, only_returning=True
-        ):
+        with conditional_sane_rowcount_warnings(update=True):
             s.commit()
 
         eq_(s1.version_id, 2)
@@ -1799,14 +1741,12 @@ class ManualVersionTest(fixtures.MappedTest):
         a1.vid = 2
         a1.data = "d2"
 
-        with conditional_sane_rowcount_warnings(
-            update=True, only_returning=True
-        ):
+        with conditional_sane_rowcount_warnings(update=True):
             sess.commit()
 
         eq_(a1.vid, 2)
 
-    @testing.requires.sane_rowcount_w_returning
+    @testing.requires.sane_rowcount
     def test_update_concurrent_check(self):
         sess = fixture_session()
         a1 = self.classes.A()
@@ -1833,18 +1773,14 @@ class ManualVersionTest(fixtures.MappedTest):
         # change the data and UPDATE without
         # incrementing version id
         a1.data = "d2"
-        with conditional_sane_rowcount_warnings(
-            update=True, only_returning=True
-        ):
+        with conditional_sane_rowcount_warnings(update=True):
             sess.commit()
 
         eq_(a1.vid, 1)
 
         a1.data = "d3"
         a1.vid = 2
-        with conditional_sane_rowcount_warnings(
-            update=True, only_returning=True
-        ):
+        with conditional_sane_rowcount_warnings(update=True):
             sess.commit()
 
         eq_(a1.vid, 2)
@@ -1907,18 +1843,14 @@ class ManualInheritanceVersionTest(fixtures.MappedTest):
         # change col on subtable only without
         # incrementing version id
         b1.b_data = "bd2"
-        with conditional_sane_rowcount_warnings(
-            update=True, only_returning=True
-        ):
+        with conditional_sane_rowcount_warnings(update=True):
             sess.commit()
 
         eq_(b1.vid, 1)
 
         b1.b_data = "d3"
         b1.vid = 2
-        with conditional_sane_rowcount_warnings(
-            update=True, only_returning=True
-        ):
+        with conditional_sane_rowcount_warnings(update=True):
             sess.commit()
 
         eq_(b1.vid, 2)
@@ -1990,9 +1922,7 @@ class VersioningMappedSelectTest(fixtures.MappedTest):
 
         f1.value = "f1rev2"
         f2.value = "f2rev2"
-        with conditional_sane_rowcount_warnings(
-            update=True, only_returning=True
-        ):
+        with conditional_sane_rowcount_warnings(update=True):
             s1.commit()
 
         eq_(
@@ -2015,9 +1945,7 @@ class VersioningMappedSelectTest(fixtures.MappedTest):
         f1.version_id = 2
         f2.value = "f2rev2"
         f2.version_id = 2
-        with conditional_sane_rowcount_warnings(
-            update=True, only_returning=True
-        ):
+        with conditional_sane_rowcount_warnings(update=True):
             s1.flush()
 
         eq_(
@@ -2057,9 +1985,7 @@ class VersioningMappedSelectTest(fixtures.MappedTest):
 
         s1.expire_all()
 
-        with conditional_sane_rowcount_warnings(
-            update=True, only_returning=True
-        ):
+        with conditional_sane_rowcount_warnings(update=True):
             f1.value = "f2"
             f1.version_id = 2
             s1.flush()
@@ -2109,9 +2035,7 @@ class QuotedBindVersioningTest(fixtures.MappedTest):
         fixture_session.commit()
 
         f1.value = "v2"
-        with conditional_sane_rowcount_warnings(
-            update=True, only_returning=True
-        ):
+        with conditional_sane_rowcount_warnings(update=True):
             fixture_session.commit()
 
         eq_(f1.version, 2)