]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- implemented "version check" logic in Query/Mapper, used
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 11 Sep 2006 00:20:28 +0000 (00:20 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 11 Sep 2006 00:20:28 +0000 (00:20 +0000)
when version_id_col is in effect and query.with_lockmode()
is used to get() an instance thats already loaded
[ticket:292]

CHANGES
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/query.py
test/orm/objectstore.py

diff --git a/CHANGES b/CHANGES
index 7cb9b2cb517d3be8e676685cca92fcb0ce006974..b0d75ea72fd1a09bfcf0fbf1ea508430c26a133f 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -10,6 +10,9 @@ including "with_lockmode" function to get a Query copy that has
 a default locking mode.  Will translate "read"/"update" 
 arguments into a for_update argument on the select side.
 [ticket:292]
+- implemented "version check" logic in Query/Mapper, used
+when version_id_col is in effect and query.with_lockmode()
+is used to get() an instance thats already loaded
 
 0.2.8
 - cleanup on connection methods + documentation.  custom DBAPI
index 4197401f7d4a2011190aa77d63d5eb4181772bbb..837f17a3327f0426a18ffb047363183d10685fe8 100644 (file)
@@ -610,6 +610,7 @@ class Mapper(object):
         limit = kwargs.get('limit', None)
         offset = kwargs.get('offset', None)
         populate_existing = kwargs.get('populate_existing', False)
+        version_check = kwargs.get('version_check', False)
         
         result = util.UniqueAppender([])
         if mappers:
@@ -624,7 +625,7 @@ class Mapper(object):
             row = cursor.fetchone()
             if row is None:
                 break
-            self._instance(session, row, imap, result, populate_existing=populate_existing)
+            self._instance(session, row, imap, result, populate_existing=populate_existing, version_check=version_check)
             i = 0
             for m in mappers:
                 m._instance(session, row, imap, otherresults[i])
@@ -838,7 +839,7 @@ class Mapper(object):
                     rows += c.cursor.rowcount
 
                 if c.supports_sane_rowcount() and rows != len(update):
-                    raise exceptions.FlushError("ConcurrencyError - updated rowcount %d does not match number of objects updated %d" % (rows, len(update)))
+                    raise exceptions.ConcurrentModificationError("Updated rowcount %d does not match number of objects updated %d" % (rows, len(update)))
 
             if len(insert):
                 statement = table.insert()
@@ -932,7 +933,7 @@ class Mapper(object):
                 statement = table.delete(clause)
                 c = connection.execute(statement, delete)
                 if c.supports_sane_rowcount() and c.rowcount != len(delete):
-                    raise exceptions.FlushError("ConcurrencyError - updated rowcount %d does not match number of objects updated %d" % (c.cursor.rowcount, len(delete)))
+                    raise exceptions.ConcurrentModificationError("Updated rowcount %d does not match number of objects updated %d" % (c.cursor.rowcount, len(delete)))
                     
         [self.extension.after_delete(self, connection, obj) for obj in deleted_objects]
 
@@ -972,7 +973,7 @@ class Mapper(object):
     def get_select_mapper(self):
         return self.__surrogate_mapper or self
         
-    def _instance(self, session, row, imap, result = None, populate_existing = False):
+    def _instance(self, session, row, imap, result = None, populate_existing = False, version_check=False):
         """pulls an object instance from the given row and appends it to the given result
         list. if the instance already exists in the given identity map, its not added.  in
         either case, executes all the property loaders on the instance to also process extra
@@ -994,6 +995,9 @@ class Mapper(object):
         if session.has_key(identitykey):
             instance = session._get(identitykey)
             isnew = False
+            if version_check and self.version_id_col is not None and self._getattrbycolumn(instance, self.version_id_col) != row[self.version_id_col]:
+                raise exceptions.ConcurrentModificationError("Instance '%s' version of %s does not match %s" % (instance, self._getattrbycolumn(instance, self.version_id_col), row[self.version_id_col]))
+                        
             if populate_existing or session.is_expired(instance, unexpire=True):
                 if not imap.has_key(identitykey):
                     imap[identitykey] = instance
index 052d048cbb892a19ef1287df7833a98fb3bc1ce8..13092b44d8335be9b4d33ac5909a92f64dd47c35 100644 (file)
@@ -278,7 +278,7 @@ class Query(object):
         
     def _get(self, key, ident=None, reload=False, lockmode=None):
         lockmode = lockmode or self.lockmode
-        if not reload and not self.always_refresh and lockmode == None:
+        if not reload and not self.always_refresh and lockmode is None:
             try:
                 return self.session._get(key)
             except KeyError:
@@ -301,7 +301,7 @@ class Query(object):
                 i += 1
         try:
             statement = self.compile(self._get_clause, lockmode=lockmode)
-            return self._select_statement(statement, params=params, populate_existing=reload)[0]
+            return self._select_statement(statement, params=params, populate_existing=reload, version_check=(lockmode is not None))[0]
         except IndexError:
             return None
 
index 4c35fee6505287fa52dc7b19a2b63281244fb119..ec47749bd06cd8a69cbb13a5903f344f438b635d 100644 (file)
@@ -150,7 +150,7 @@ class VersioningTest(SessionTest):
             # a concurrent session has modified this, should throw
             # an exception
             s.flush()
-        except exceptions.SQLAlchemyError, e:
+        except exceptions.ConcurrentModificationError, e:
             #print e
             success = True
         assert success
@@ -166,10 +166,48 @@ class VersioningTest(SessionTest):
         success = False
         try:
             s.flush()
-        except exceptions.SQLAlchemyError, e:
+        except exceptions.ConcurrentModificationError, e:
             #print e
             success = True
         assert success
+    def testversioncheck(self):
+        """test that query.with_lockmode performs a 'version check' on an already loaded instance"""
+        s1 = create_session()
+        class Foo(object):pass
+        assign_mapper(Foo, version_table, version_id_col=version_table.c.version_id)
+        f1s1 =Foo(value='f1', _sa_session=s1)
+        s1.flush()
+        s2 = create_session()
+        f1s2 = s2.query(Foo).get(f1s1.id)
+        f1s2.value='f1 new value'
+        s2.flush()
+        try:
+            # load, version is wrong
+            s1.query(Foo).with_lockmode('read').get(f1s1.id)
+            assert False
+        except exceptions.ConcurrentModificationError, e:
+            assert True
+        # reload it
+        s1.query(Foo).load(f1s1.id)
+        # now assert version OK
+        s1.query(Foo).with_lockmode('read').get(f1s1.id)
+        
+        # assert brand new load is OK too
+        s1.clear()
+        s1.query(Foo).with_lockmode('read').get(f1s1.id)
+        
+    def testnoversioncheck(self):
+        """test that query.with_lockmode works OK when the mapper has no version id col"""
+        s1 = create_session()
+        class Foo(object):pass
+        assign_mapper(Foo, version_table)
+        f1s1 =Foo(value='f1', _sa_session=s1)
+        f1s1.version_id=0
+        s1.flush()
+        s2 = create_session()
+        f1s2 = s2.query(Foo).with_lockmode('read').get(f1s1.id)
+        assert f1s2.id == f1s1.id
+        assert f1s2.value == f1s1.value
         
 class UnicodeTest(SessionTest):
     def setUpAll(self):