]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Session.merge() will check the version id of the incoming
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 18 Jan 2011 21:34:34 +0000 (16:34 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 18 Jan 2011 21:34:34 +0000 (16:34 -0500)
state against that of the database, assuming the mapping
uses version ids and incoming state has a version_id
assigned, and raise StaleDataError if they don't
match.  [ticket:2027]

CHANGES
lib/sqlalchemy/orm/session.py
test/orm/test_versioning.py

diff --git a/CHANGES b/CHANGES
index 9d6e25c651b7ff44ede31eb94e1373e118223fdd..6b39e23ac1a181715d26d76f906e8798ae690a1b 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -10,6 +10,12 @@ CHANGES
     (i.e. not a mapped class), will return element.alias() 
     instead of raising an error on AliasedClass.  [ticket:2018]
 
+  - Session.merge() will check the version id of the incoming
+    state against that of the database, assuming the mapping
+    uses version ids and incoming state has a version_id
+    assigned, and raise StaleDataError if they don't 
+    match.  [ticket:2027]
+
 - sql
   - Added NULLS FIRST and NULLS LAST support. It's implemented
     as an extension to the asc() and desc() operators, called
index 47420e207f7b88bf8d546cb2b69f50ab53d24839..0e1818241faf3b559e4571850932d64b6bb0b63b 100644 (file)
@@ -1242,6 +1242,34 @@ class Session(object):
         # check that we didn't just pull the exact same
         # state out.
         if state is not merged_state:
+            # version check if applicable
+            if mapper.version_id_col is not None:
+                existing_version = mapper._get_state_attr_by_column(
+                            state, 
+                            state_dict, 
+                            mapper.version_id_col,
+                            passive=attributes.PASSIVE_NO_INITIALIZE)
+
+                merged_version = mapper._get_state_attr_by_column(
+                            merged_state, 
+                            merged_dict, 
+                            mapper.version_id_col,
+                            passive=attributes.PASSIVE_NO_INITIALIZE)
+
+                if existing_version is not attributes.PASSIVE_NO_RESULT and \
+                    merged_version is not attributes.PASSIVE_NO_RESULT and \
+                    existing_version != merged_version:
+                    raise exc.StaleDataError(
+                            "Version id '%s' on merged state %s "
+                            "does not match existing version '%s'. "
+                            "Leave the version attribute unset when "
+                            "merging to update the most recent version."
+                            % (
+                                existing_version,
+                                mapperutil.state_str(merged_state),
+                                merged_version
+                            ))
+
             merged_state.load_path = state.load_path
             merged_state.load_options = state.load_options
 
index 9da97dc1e92163362ce827c65c051453b54d2905..a42b8e8fe2b3fdef58eaa84653ae4f47b1fc6206 100644 (file)
@@ -4,7 +4,8 @@ from sqlalchemy import Integer, String, ForeignKey, literal_column, \
     orm, exc, select
 from test.lib.schema import Table, Column
 from sqlalchemy.orm import mapper, relationship, Session, \
-    create_session, column_property, sessionmaker
+    create_session, column_property, sessionmaker,\
+    exc as orm_exc
 from test.lib.testing import eq_, ne_, assert_raises, assert_raises_message
 from test.orm import _base, _fixtures
 from test.engine import _base as engine_base
@@ -46,24 +47,20 @@ class VersioningTest(_base.MappedTest):
         class Foo(_base.ComparableEntity):
             pass
 
+    @testing.resolve_artifact_names
+    def _fixture(self):
+        mapper(Foo, version_table, 
+                version_id_col=version_table.c.version_id)
+        s1 = Session()
+        return s1
+
     @engines.close_open_connections
     @testing.resolve_artifact_names
     def test_notsane_warning(self):
-        # clear the warning module's ignores to force the SAWarning this
-        # test relies on to be emitted (it may have already been ignored
-        # forever by other VersioningTests)
-        try:
-            del __warningregistry__
-        except NameError:
-            pass
-
         save = testing.db.dialect.supports_sane_rowcount
         testing.db.dialect.supports_sane_rowcount = False
         try:
-            mapper(Foo, version_table, 
-                    version_id_col=version_table.c.version_id)
-
-            s1 = create_session(autocommit=False)
+            s1 = self._fixture()
             f1 = Foo(value='f1')
             f2 = Foo(value='f2')
             s1.add_all((f1, f2))
@@ -75,13 +72,9 @@ class VersioningTest(_base.MappedTest):
             testing.db.dialect.supports_sane_rowcount = save
 
     @testing.emits_warning(r'.*does not support updated rowcount')
-    @engines.close_open_connections
     @testing.resolve_artifact_names
     def test_basic(self):
-        mapper(Foo, version_table, 
-                version_id_col=version_table.c.version_id)
-
-        s1 = create_session(autocommit=False)
+        s1 = self._fixture()
         f1 = Foo(value='f1')
         f2 = Foo(value='f2')
         s1.add_all((f1, f2))
@@ -137,10 +130,7 @@ class VersioningTest(_base.MappedTest):
         state.
 
         """
-        mapper(Foo, version_table, 
-                version_id_col=version_table.c.version_id)
-
-        s1 = sessionmaker()()
+        s1 = self._fixture()
         f1 = Foo(value='f1')
         s1.add(f1)
         s1.commit()
@@ -167,9 +157,7 @@ class VersioningTest(_base.MappedTest):
     def test_versioncheck(self):
         """query.with_lockmode performs a 'version check' on an already loaded instance"""
 
-        s1 = create_session(autocommit=False)
-
-        mapper(Foo, version_table, version_id_col=version_table.c.version_id)
+        s1 = self._fixture()
         f1s1 = Foo(value='f1 value')
         s1.add(f1s1)
         s1.commit()
@@ -205,9 +193,7 @@ class VersioningTest(_base.MappedTest):
     def test_versioncheck_for_update(self):
         """query.with_lockmode performs a 'version check' on an already loaded instance"""
 
-        s1 = create_session(autocommit=False)
-
-        mapper(Foo, version_table, version_id_col=version_table.c.version_id)
+        s1 = self._fixture()
         f1s1 = Foo(value='f1 value')
         s1.add(f1s1)
         s1.commit()
@@ -243,7 +229,79 @@ class VersioningTest(_base.MappedTest):
         assert f1s2.id == f1s1.id
         assert f1s2.value == f1s1.value
 
+    @testing.resolve_artifact_names
+    def test_merge_no_version(self):
+        s1 = self._fixture()
+        f1 = Foo(value='f1')
+        s1.add(f1)
+        s1.commit()
+
+        f1.value = 'f2'
+        s1.commit()
+
+        f2 = Foo(id=f1.id, value='f3')
+        f3 = s1.merge(f2)
+        assert f3 is f1
+        s1.commit()
+        eq_(f3.version_id, 3)
+
+    @testing.resolve_artifact_names
+    def test_merge_correct_version(self):
+        s1 = self._fixture()
+        f1 = Foo(value='f1')
+        s1.add(f1)
+        s1.commit()
+
+        f1.value = 'f2'
+        s1.commit()
+
+        f2 = Foo(id=f1.id, value='f3', version_id=2)
+        f3 = s1.merge(f2)
+        assert f3 is f1
+        s1.commit()
+        eq_(f3.version_id, 3)
+
+    @testing.resolve_artifact_names
+    def test_merge_incorrect_version(self):
+        s1 = self._fixture()
+        f1 = Foo(value='f1')
+        s1.add(f1)
+        s1.commit()
+
+        f1.value = 'f2'
+        s1.commit()
 
+        f2 = Foo(id=f1.id, value='f3', version_id=1)
+        assert_raises_message(
+            orm_exc.StaleDataError,
+            "Version id '1' on merged state "
+            "<Foo at .*?> does not match existing version '2'. "
+            "Leave the version attribute unset when "
+            "merging to update the most recent version.",
+            s1.merge, f2
+        )
+
+    @testing.resolve_artifact_names
+    def test_merge_incorrect_version_not_in_session(self):
+        s1 = self._fixture()
+        f1 = Foo(value='f1')
+        s1.add(f1)
+        s1.commit()
+
+        f1.value = 'f2'
+        s1.commit()
+
+        f2 = Foo(id=f1.id, value='f3', version_id=1)
+        s1.close()
+
+        assert_raises_message(
+            orm_exc.StaleDataError,
+            "Version id '1' on merged state "
+            "<Foo at .*?> does not match existing version '2'. "
+            "Leave the version attribute unset when "
+            "merging to update the most recent version.",
+            s1.merge, f2
+        )
 
 class RowSwitchTest(_base.MappedTest):
     @classmethod