From: Mike Bayer Date: Thu, 2 Feb 2023 22:22:22 +0000 (-0500) Subject: dont add non-server-side cols to returning for versioning X-Git-Tag: rel_2_0_2~13^2 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=92e3c21ea9b8192ff3d6ad856389186dfe8b3d3d;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git dont add non-server-side cols to returning for versioning Fixed regression where using the :paramref:`_orm.Mapper.version_id_col` feature with a regular Python-side incrementing column would fail to work for SQLite and other databases that don't support "rowcount" with "RETURNING", as "RETURNING" would be assumed for such columns even though that's not what actually takes place. Fixes: #9228 Change-Id: I6a1a7fa4d63e183fe4ef0fbfd3cb5cac03b26d78 --- diff --git a/doc/build/changelog/unreleased_20/9228.rst b/doc/build/changelog/unreleased_20/9228.rst new file mode 100644 index 0000000000..7e96c24611 --- /dev/null +++ b/doc/build/changelog/unreleased_20/9228.rst @@ -0,0 +1,9 @@ +.. change:: + :tags: orm, bug, regression + :tickets: 9228 + + Fixed regression where using the :paramref:`_orm.Mapper.version_id_col` + feature with a regular Python-side incrementing column would fail to work + for SQLite and other databases that don't support "rowcount" with + "RETURNING", as "RETURNING" would be assumed for such columns even though + that's not what actually takes place. diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index a3b209e4a6..bb7e470ff6 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -83,6 +83,7 @@ from ..sql import util as sql_util from ..sql import visitors from ..sql.cache_key import MemoizedHasCacheKey from ..sql.elements import KeyedColumnElement +from ..sql.schema import Column from ..sql.schema import Table from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL from ..util import HasMemoized @@ -112,7 +113,6 @@ if TYPE_CHECKING: from ..sql.base import ReadOnlyColumnCollection from ..sql.elements import ColumnClause from ..sql.elements import ColumnElement - from ..sql.schema import Column from ..sql.selectable import FromClause from ..util import OrderedSet @@ -2522,6 +2522,24 @@ class Mapper( return from_obj + @HasMemoized.memoized_attribute + def _version_id_has_server_side_value(self) -> bool: + vid_col = self.version_id_col + + if vid_col is None: + return False + + elif not isinstance(vid_col, Column): + return True + else: + return vid_col.server_default is not None or ( + vid_col.default is not None + and ( + not vid_col.default.is_scalar + and not vid_col.default.is_callable + ) + ) + @HasMemoized.memoized_attribute def _single_table_criterion(self): if self.single and self.inherits and self.polymorphic_on is not None: diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index cc7e321b43..b8368001b2 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -801,7 +801,7 @@ def _emit_update_statements( ) return_defaults = True - if mapper.version_id_col is not None: + if mapper._version_id_has_server_side_value: statement = statement.return_defaults(mapper.version_id_col) return_defaults = True @@ -1268,13 +1268,16 @@ def _emit_post_update_statements( stmt = table.update().where(clauses) - if mapper.version_id_col is not None: - stmt = stmt.return_defaults(mapper.version_id_col) - return stmt statement = base_mapper._memo(("post_update", table), update_stmt) + if mapper._version_id_has_server_side_value: + statement = statement.return_defaults(mapper.version_id_col) + return_defaults = True + else: + return_defaults = False + # execute each UPDATE in the order according to the original # list of states to guarantee row access order, but # also group them into common (connection, cols) sets @@ -1290,7 +1293,7 @@ def _emit_post_update_statements( assert_singlerow = ( connection.dialect.supports_sane_rowcount - if mapper.version_id_col is None + if not return_defaults else connection.dialect.supports_sane_rowcount_returning ) assert_multirow = ( diff --git a/test/orm/inheritance/test_basic.py b/test/orm/inheritance/test_basic.py index 37368f3ad6..02f3527864 100644 --- a/test/orm/inheritance/test_basic.py +++ b/test/orm/inheritance/test_basic.py @@ -2101,8 +2101,7 @@ class VersioningTest(fixtures.MappedTest): Column("parent", Integer, ForeignKey("base.id")), ) - @testing.emits_warning(r".*updated rowcount") - @testing.requires.sane_rowcount_w_returning + @testing.requires.sane_rowcount def test_save_update(self): subtable, base, stuff = ( self.tables.subtable, @@ -2170,8 +2169,7 @@ class VersioningTest(fixtures.MappedTest): s2.subdata = "sess2 subdata" sess2.flush() - @testing.emits_warning(r".*(update|delete)d rowcount") - @testing.requires.sane_rowcount_w_returning + @testing.requires.sane_rowcount def test_delete(self): subtable, base = self.tables.subtable, self.tables.base diff --git a/test/orm/test_versioning.py b/test/orm/test_versioning.py index 1a1801311b..f6b9f18fc4 100644 --- a/test/orm/test_versioning.py +++ b/test/orm/test_versioning.py @@ -140,9 +140,7 @@ class NullVersionIdTest(fixtures.MappedTest): f1.value = "f1rev2" - with conditional_sane_rowcount_warnings( - update=True, only_returning=True - ): + with conditional_sane_rowcount_warnings(update=True): f1.version_id = None assert_raises_message( sa.orm.exc.FlushError, @@ -209,24 +207,20 @@ class VersioningTest(fixtures.MappedTest): s1.commit() f1.value = "f1rev2" - with conditional_sane_rowcount_warnings( - update=True, only_returning=True - ): + with conditional_sane_rowcount_warnings(update=True): s1.commit() s2 = fixture_session() f1_s = s2.get(Foo, f1.id) f1_s.value = "f1rev3" - with conditional_sane_rowcount_warnings( - update=True, only_returning=True - ): + with conditional_sane_rowcount_warnings(update=True): s2.commit() f1.value = "f1rev3mine" # Only dialects with a sane rowcount can detect the # StaleDataError - if testing.db.dialect.supports_sane_rowcount_returning: + if testing.db.dialect.supports_sane_rowcount: assert_raises_message( sa.orm.exc.StaleDataError, r"UPDATE statement on table 'version_table' expected " @@ -235,9 +229,7 @@ class VersioningTest(fixtures.MappedTest): ), s1.rollback() else: - with conditional_sane_rowcount_warnings( - update=True, only_returning=True - ): + with conditional_sane_rowcount_warnings(update=True): s1.commit() # new in 0.5 ! don't need to close the session @@ -245,9 +237,7 @@ class VersioningTest(fixtures.MappedTest): f2 = s1.get(Foo, f2.id) f1_s.value = "f1rev4" - with conditional_sane_rowcount_warnings( - update=True, only_returning=True - ): + with conditional_sane_rowcount_warnings(update=True): s2.commit() s1.delete(f1) @@ -275,9 +265,7 @@ class VersioningTest(fixtures.MappedTest): f1.value = "f1rev2" f2.value = "f2rev2" - with conditional_sane_rowcount_warnings( - update=True, only_returning=True - ): + with conditional_sane_rowcount_warnings(update=True): s1.commit() eq_( @@ -306,9 +294,7 @@ class VersioningTest(fixtures.MappedTest): s1.add_all((f1, f2)) s1.commit() - with conditional_sane_rowcount_warnings( - update=True, only_returning=True - ): + with conditional_sane_rowcount_warnings(update=True): s1.bulk_update_mappings( Foo, [ @@ -340,9 +326,7 @@ class VersioningTest(fixtures.MappedTest): s1.commit() eq_(f1.version_id, 1) f1.version_id = 2 - with conditional_sane_rowcount_warnings( - update=True, only_returning=True - ): + with conditional_sane_rowcount_warnings(update=True): s1.commit() eq_(f1.version_id, 2) @@ -350,9 +334,7 @@ class VersioningTest(fixtures.MappedTest): # is honored f1.version_id = 4 f1.value = "something new" - with conditional_sane_rowcount_warnings( - update=True, only_returning=True - ): + with conditional_sane_rowcount_warnings(update=True): s1.commit() eq_(f1.version_id, 4) @@ -377,9 +359,7 @@ class VersioningTest(fixtures.MappedTest): s2 = fixture_session() f1s2 = s2.get(Foo, f1s1.id) f1s2.value = "f1 new value" - with conditional_sane_rowcount_warnings( - update=True, only_returning=True - ): + with conditional_sane_rowcount_warnings(update=True): s2.commit() # load, version is wrong @@ -465,9 +445,7 @@ class VersioningTest(fixtures.MappedTest): s1.commit() f1s1.value = "f2 value" - with conditional_sane_rowcount_warnings( - update=True, only_returning=True - ): + with conditional_sane_rowcount_warnings(update=True): s1.flush() eq_(f1s1.version_id, 2) @@ -532,17 +510,13 @@ class VersioningTest(fixtures.MappedTest): s1.commit() f1.value = "f2" - with conditional_sane_rowcount_warnings( - update=True, only_returning=True - ): + with conditional_sane_rowcount_warnings(update=True): s1.commit() f2 = Foo(id=f1.id, value="f3") f3 = s1.merge(f2) assert f3 is f1 - with conditional_sane_rowcount_warnings( - update=True, only_returning=True - ): + with conditional_sane_rowcount_warnings(update=True): s1.commit() eq_(f3.version_id, 3) @@ -555,17 +529,13 @@ class VersioningTest(fixtures.MappedTest): s1.commit() f1.value = "f2" - with conditional_sane_rowcount_warnings( - update=True, only_returning=True - ): + with conditional_sane_rowcount_warnings(update=True): s1.commit() f2 = Foo(id=f1.id, value="f3", version_id=2) f3 = s1.merge(f2) assert f3 is f1 - with conditional_sane_rowcount_warnings( - update=True, only_returning=True - ): + with conditional_sane_rowcount_warnings(update=True): s1.commit() eq_(f3.version_id, 3) @@ -578,9 +548,7 @@ class VersioningTest(fixtures.MappedTest): s1.commit() f1.value = "f2" - with conditional_sane_rowcount_warnings( - update=True, only_returning=True - ): + with conditional_sane_rowcount_warnings(update=True): s1.commit() f2 = Foo(id=f1.id, value="f3", version_id=1) @@ -603,9 +571,7 @@ class VersioningTest(fixtures.MappedTest): s1.commit() f1.value = "f2" - with conditional_sane_rowcount_warnings( - update=True, only_returning=True - ): + with conditional_sane_rowcount_warnings(update=True): s1.commit() f2 = Foo(id=f1.id, value="f3", version_id=1) @@ -670,9 +636,7 @@ class VersionOnPostUpdateTest(fixtures.MappedTest): s, n1, n2 = self._fixture(o2m=True, post_update=False) n1.related.append(n2) - with conditional_sane_rowcount_warnings( - update=True, only_returning=True - ): + with conditional_sane_rowcount_warnings(update=True): s.flush() eq_(n1.version_id, 1) @@ -682,9 +646,7 @@ class VersionOnPostUpdateTest(fixtures.MappedTest): s, n1, n2 = self._fixture(o2m=False, post_update=False) n1.related = n2 - with conditional_sane_rowcount_warnings( - update=True, only_returning=True - ): + with conditional_sane_rowcount_warnings(update=True): s.flush() eq_(n1.version_id, 2) @@ -694,9 +656,7 @@ class VersionOnPostUpdateTest(fixtures.MappedTest): s, n1, n2 = self._fixture(o2m=True, post_update=True) n1.related.append(n2) - with conditional_sane_rowcount_warnings( - update=True, only_returning=True - ): + with conditional_sane_rowcount_warnings(update=True): s.flush() eq_(n1.version_id, 1) @@ -706,9 +666,7 @@ class VersionOnPostUpdateTest(fixtures.MappedTest): s, n1, n2 = self._fixture(o2m=False, post_update=True) n1.related = n2 - with conditional_sane_rowcount_warnings( - update=True, only_returning=True - ): + with conditional_sane_rowcount_warnings(update=True): s.flush() eq_(n1.version_id, 2) @@ -719,9 +677,7 @@ class VersionOnPostUpdateTest(fixtures.MappedTest): n1.related.append(n2) s.add_all([n1, n2]) - with conditional_sane_rowcount_warnings( - update=True, only_returning=True - ): + with conditional_sane_rowcount_warnings(update=True): s.flush() eq_(n1.version_id, 1) @@ -732,15 +688,13 @@ class VersionOnPostUpdateTest(fixtures.MappedTest): n1.related = n2 s.add_all([n1, n2]) - with conditional_sane_rowcount_warnings( - update=True, only_returning=True - ): + with conditional_sane_rowcount_warnings(update=True): s.flush() eq_(n1.version_id, 1) eq_(n2.version_id, 1) - @testing.requires.sane_rowcount_w_returning + @testing.requires.sane_rowcount def test_o2m_post_update_version_assert(self): Node = self.classes.Node s, n1, n2 = self._fixture(o2m=True, post_update=True) @@ -782,7 +736,7 @@ class VersionOnPostUpdateTest(fixtures.MappedTest): ): s.flush() - @testing.requires.sane_rowcount_w_returning + @testing.requires.sane_rowcount def test_m2o_post_update_version_assert(self): Node = self.classes.Node @@ -944,9 +898,7 @@ class ColumnTypeTest(fixtures.MappedTest): s1.commit() f1.value = "f1rev2" - with conditional_sane_rowcount_warnings( - update=True, only_returning=True - ): + with conditional_sane_rowcount_warnings(update=True): s1.commit() @@ -1007,9 +959,7 @@ class RowSwitchTest(fixtures.MappedTest): p = session.query(P).first() session.delete(p) session.add(P(id="P1", data="really a row-switch")) - with conditional_sane_rowcount_warnings( - update=True, only_returning=True - ): + with conditional_sane_rowcount_warnings(update=True): session.commit() def test_child_row_switch(self): @@ -1028,9 +978,7 @@ class RowSwitchTest(fixtures.MappedTest): p = session.query(P).first() p.c = C(data="child row-switch") - with conditional_sane_rowcount_warnings( - update=True, only_returning=True - ): + with conditional_sane_rowcount_warnings(update=True): session.commit() @@ -1096,9 +1044,7 @@ class AlternateGeneratorTest(fixtures.MappedTest): p = session.query(P).first() session.delete(p) session.add(P(id="P1", data="really a row-switch")) - with conditional_sane_rowcount_warnings( - update=True, only_returning=True - ): + with conditional_sane_rowcount_warnings(update=True): session.commit() def test_child_row_switch_one(self): @@ -1117,12 +1063,10 @@ class AlternateGeneratorTest(fixtures.MappedTest): p = session.query(P).first() p.c = C(data="child row-switch") - with conditional_sane_rowcount_warnings( - update=True, only_returning=True - ): + with conditional_sane_rowcount_warnings(update=True): session.commit() - @testing.requires.sane_rowcount_w_returning + @testing.requires.sane_rowcount def test_child_row_switch_two(self): P = self.classes.P @@ -1206,9 +1150,7 @@ class PlainInheritanceTest(fixtures.MappedTest): s.commit() s1.sub_data = "s2" - with conditional_sane_rowcount_warnings( - update=True, only_returning=True - ): + with conditional_sane_rowcount_warnings(update=True): s.commit() eq_(s1.version_id, 2) @@ -1799,14 +1741,12 @@ class ManualVersionTest(fixtures.MappedTest): a1.vid = 2 a1.data = "d2" - with conditional_sane_rowcount_warnings( - update=True, only_returning=True - ): + with conditional_sane_rowcount_warnings(update=True): sess.commit() eq_(a1.vid, 2) - @testing.requires.sane_rowcount_w_returning + @testing.requires.sane_rowcount def test_update_concurrent_check(self): sess = fixture_session() a1 = self.classes.A() @@ -1833,18 +1773,14 @@ class ManualVersionTest(fixtures.MappedTest): # change the data and UPDATE without # incrementing version id a1.data = "d2" - with conditional_sane_rowcount_warnings( - update=True, only_returning=True - ): + with conditional_sane_rowcount_warnings(update=True): sess.commit() eq_(a1.vid, 1) a1.data = "d3" a1.vid = 2 - with conditional_sane_rowcount_warnings( - update=True, only_returning=True - ): + with conditional_sane_rowcount_warnings(update=True): sess.commit() eq_(a1.vid, 2) @@ -1907,18 +1843,14 @@ class ManualInheritanceVersionTest(fixtures.MappedTest): # change col on subtable only without # incrementing version id b1.b_data = "bd2" - with conditional_sane_rowcount_warnings( - update=True, only_returning=True - ): + with conditional_sane_rowcount_warnings(update=True): sess.commit() eq_(b1.vid, 1) b1.b_data = "d3" b1.vid = 2 - with conditional_sane_rowcount_warnings( - update=True, only_returning=True - ): + with conditional_sane_rowcount_warnings(update=True): sess.commit() eq_(b1.vid, 2) @@ -1990,9 +1922,7 @@ class VersioningMappedSelectTest(fixtures.MappedTest): f1.value = "f1rev2" f2.value = "f2rev2" - with conditional_sane_rowcount_warnings( - update=True, only_returning=True - ): + with conditional_sane_rowcount_warnings(update=True): s1.commit() eq_( @@ -2015,9 +1945,7 @@ class VersioningMappedSelectTest(fixtures.MappedTest): f1.version_id = 2 f2.value = "f2rev2" f2.version_id = 2 - with conditional_sane_rowcount_warnings( - update=True, only_returning=True - ): + with conditional_sane_rowcount_warnings(update=True): s1.flush() eq_( @@ -2057,9 +1985,7 @@ class VersioningMappedSelectTest(fixtures.MappedTest): s1.expire_all() - with conditional_sane_rowcount_warnings( - update=True, only_returning=True - ): + with conditional_sane_rowcount_warnings(update=True): f1.value = "f2" f1.version_id = 2 s1.flush() @@ -2109,9 +2035,7 @@ class QuotedBindVersioningTest(fixtures.MappedTest): fixture_session.commit() f1.value = "v2" - with conditional_sane_rowcount_warnings( - update=True, only_returning=True - ): + with conditional_sane_rowcount_warnings(update=True): fixture_session.commit() eq_(f1.version, 2)