From: Mike Bayer Date: Tue, 18 Jan 2011 21:34:34 +0000 (-0500) Subject: - Session.merge() will check the version id of the incoming X-Git-Tag: rel_0_7b1~62 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=b3dd50a8dac9f7660b7f497f444e6175fdf85713;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - Session.merge() will check the version id of the incoming state against that of the database, assuming the mapping uses version ids and incoming state has a version_id assigned, and raise StaleDataError if they don't match. [ticket:2027] --- diff --git a/CHANGES b/CHANGES index 9d6e25c651..6b39e23ac1 100644 --- a/CHANGES +++ b/CHANGES @@ -10,6 +10,12 @@ CHANGES (i.e. not a mapped class), will return element.alias() instead of raising an error on AliasedClass. [ticket:2018] + - Session.merge() will check the version id of the incoming + state against that of the database, assuming the mapping + uses version ids and incoming state has a version_id + assigned, and raise StaleDataError if they don't + match. [ticket:2027] + - sql - Added NULLS FIRST and NULLS LAST support. It's implemented as an extension to the asc() and desc() operators, called diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 47420e207f..0e1818241f 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -1242,6 +1242,34 @@ class Session(object): # check that we didn't just pull the exact same # state out. if state is not merged_state: + # version check if applicable + if mapper.version_id_col is not None: + existing_version = mapper._get_state_attr_by_column( + state, + state_dict, + mapper.version_id_col, + passive=attributes.PASSIVE_NO_INITIALIZE) + + merged_version = mapper._get_state_attr_by_column( + merged_state, + merged_dict, + mapper.version_id_col, + passive=attributes.PASSIVE_NO_INITIALIZE) + + if existing_version is not attributes.PASSIVE_NO_RESULT and \ + merged_version is not attributes.PASSIVE_NO_RESULT and \ + existing_version != merged_version: + raise exc.StaleDataError( + "Version id '%s' on merged state %s " + "does not match existing version '%s'. " + "Leave the version attribute unset when " + "merging to update the most recent version." + % ( + existing_version, + mapperutil.state_str(merged_state), + merged_version + )) + merged_state.load_path = state.load_path merged_state.load_options = state.load_options diff --git a/test/orm/test_versioning.py b/test/orm/test_versioning.py index 9da97dc1e9..a42b8e8fe2 100644 --- a/test/orm/test_versioning.py +++ b/test/orm/test_versioning.py @@ -4,7 +4,8 @@ from sqlalchemy import Integer, String, ForeignKey, literal_column, \ orm, exc, select from test.lib.schema import Table, Column from sqlalchemy.orm import mapper, relationship, Session, \ - create_session, column_property, sessionmaker + create_session, column_property, sessionmaker,\ + exc as orm_exc from test.lib.testing import eq_, ne_, assert_raises, assert_raises_message from test.orm import _base, _fixtures from test.engine import _base as engine_base @@ -46,24 +47,20 @@ class VersioningTest(_base.MappedTest): class Foo(_base.ComparableEntity): pass + @testing.resolve_artifact_names + def _fixture(self): + mapper(Foo, version_table, + version_id_col=version_table.c.version_id) + s1 = Session() + return s1 + @engines.close_open_connections @testing.resolve_artifact_names def test_notsane_warning(self): - # clear the warning module's ignores to force the SAWarning this - # test relies on to be emitted (it may have already been ignored - # forever by other VersioningTests) - try: - del __warningregistry__ - except NameError: - pass - save = testing.db.dialect.supports_sane_rowcount testing.db.dialect.supports_sane_rowcount = False try: - mapper(Foo, version_table, - version_id_col=version_table.c.version_id) - - s1 = create_session(autocommit=False) + s1 = self._fixture() f1 = Foo(value='f1') f2 = Foo(value='f2') s1.add_all((f1, f2)) @@ -75,13 +72,9 @@ class VersioningTest(_base.MappedTest): testing.db.dialect.supports_sane_rowcount = save @testing.emits_warning(r'.*does not support updated rowcount') - @engines.close_open_connections @testing.resolve_artifact_names def test_basic(self): - mapper(Foo, version_table, - version_id_col=version_table.c.version_id) - - s1 = create_session(autocommit=False) + s1 = self._fixture() f1 = Foo(value='f1') f2 = Foo(value='f2') s1.add_all((f1, f2)) @@ -137,10 +130,7 @@ class VersioningTest(_base.MappedTest): state. """ - mapper(Foo, version_table, - version_id_col=version_table.c.version_id) - - s1 = sessionmaker()() + s1 = self._fixture() f1 = Foo(value='f1') s1.add(f1) s1.commit() @@ -167,9 +157,7 @@ class VersioningTest(_base.MappedTest): def test_versioncheck(self): """query.with_lockmode performs a 'version check' on an already loaded instance""" - s1 = create_session(autocommit=False) - - mapper(Foo, version_table, version_id_col=version_table.c.version_id) + s1 = self._fixture() f1s1 = Foo(value='f1 value') s1.add(f1s1) s1.commit() @@ -205,9 +193,7 @@ class VersioningTest(_base.MappedTest): def test_versioncheck_for_update(self): """query.with_lockmode performs a 'version check' on an already loaded instance""" - s1 = create_session(autocommit=False) - - mapper(Foo, version_table, version_id_col=version_table.c.version_id) + s1 = self._fixture() f1s1 = Foo(value='f1 value') s1.add(f1s1) s1.commit() @@ -243,7 +229,79 @@ class VersioningTest(_base.MappedTest): assert f1s2.id == f1s1.id assert f1s2.value == f1s1.value + @testing.resolve_artifact_names + def test_merge_no_version(self): + s1 = self._fixture() + f1 = Foo(value='f1') + s1.add(f1) + s1.commit() + + f1.value = 'f2' + s1.commit() + + f2 = Foo(id=f1.id, value='f3') + f3 = s1.merge(f2) + assert f3 is f1 + s1.commit() + eq_(f3.version_id, 3) + + @testing.resolve_artifact_names + def test_merge_correct_version(self): + s1 = self._fixture() + f1 = Foo(value='f1') + s1.add(f1) + s1.commit() + + f1.value = 'f2' + s1.commit() + + f2 = Foo(id=f1.id, value='f3', version_id=2) + f3 = s1.merge(f2) + assert f3 is f1 + s1.commit() + eq_(f3.version_id, 3) + + @testing.resolve_artifact_names + def test_merge_incorrect_version(self): + s1 = self._fixture() + f1 = Foo(value='f1') + s1.add(f1) + s1.commit() + + f1.value = 'f2' + s1.commit() + f2 = Foo(id=f1.id, value='f3', version_id=1) + assert_raises_message( + orm_exc.StaleDataError, + "Version id '1' on merged state " + " does not match existing version '2'. " + "Leave the version attribute unset when " + "merging to update the most recent version.", + s1.merge, f2 + ) + + @testing.resolve_artifact_names + def test_merge_incorrect_version_not_in_session(self): + s1 = self._fixture() + f1 = Foo(value='f1') + s1.add(f1) + s1.commit() + + f1.value = 'f2' + s1.commit() + + f2 = Foo(id=f1.id, value='f3', version_id=1) + s1.close() + + assert_raises_message( + orm_exc.StaleDataError, + "Version id '1' on merged state " + " does not match existing version '2'. " + "Leave the version attribute unset when " + "merging to update the most recent version.", + s1.merge, f2 + ) class RowSwitchTest(_base.MappedTest): @classmethod