From 9cc0aa292c302afd66f306a0aff8567151fee0a9 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Mon, 11 Sep 2006 00:20:28 +0000 Subject: [PATCH] - implemented "version check" logic in Query/Mapper, used when version_id_col is in effect and query.with_lockmode() is used to get() an instance thats already loaded [ticket:292] --- CHANGES | 3 +++ lib/sqlalchemy/orm/mapper.py | 12 +++++++---- lib/sqlalchemy/orm/query.py | 4 ++-- test/orm/objectstore.py | 42 ++++++++++++++++++++++++++++++++++-- 4 files changed, 53 insertions(+), 8 deletions(-) diff --git a/CHANGES b/CHANGES index 7cb9b2cb51..b0d75ea72f 100644 --- a/CHANGES +++ b/CHANGES @@ -10,6 +10,9 @@ including "with_lockmode" function to get a Query copy that has a default locking mode. Will translate "read"/"update" arguments into a for_update argument on the select side. [ticket:292] +- implemented "version check" logic in Query/Mapper, used +when version_id_col is in effect and query.with_lockmode() +is used to get() an instance thats already loaded 0.2.8 - cleanup on connection methods + documentation. custom DBAPI diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 4197401f7d..837f17a332 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -610,6 +610,7 @@ class Mapper(object): limit = kwargs.get('limit', None) offset = kwargs.get('offset', None) populate_existing = kwargs.get('populate_existing', False) + version_check = kwargs.get('version_check', False) result = util.UniqueAppender([]) if mappers: @@ -624,7 +625,7 @@ class Mapper(object): row = cursor.fetchone() if row is None: break - self._instance(session, row, imap, result, populate_existing=populate_existing) + self._instance(session, row, imap, result, populate_existing=populate_existing, version_check=version_check) i = 0 for m in mappers: m._instance(session, row, imap, otherresults[i]) @@ -838,7 +839,7 @@ class Mapper(object): rows += c.cursor.rowcount if c.supports_sane_rowcount() and rows != len(update): - raise exceptions.FlushError("ConcurrencyError - updated rowcount %d does not match number of objects updated %d" % (rows, len(update))) + raise exceptions.ConcurrentModificationError("Updated rowcount %d does not match number of objects updated %d" % (rows, len(update))) if len(insert): statement = table.insert() @@ -932,7 +933,7 @@ class Mapper(object): statement = table.delete(clause) c = connection.execute(statement, delete) if c.supports_sane_rowcount() and c.rowcount != len(delete): - raise exceptions.FlushError("ConcurrencyError - updated rowcount %d does not match number of objects updated %d" % (c.cursor.rowcount, len(delete))) + raise exceptions.ConcurrentModificationError("Updated rowcount %d does not match number of objects updated %d" % (c.cursor.rowcount, len(delete))) [self.extension.after_delete(self, connection, obj) for obj in deleted_objects] @@ -972,7 +973,7 @@ class Mapper(object): def get_select_mapper(self): return self.__surrogate_mapper or self - def _instance(self, session, row, imap, result = None, populate_existing = False): + def _instance(self, session, row, imap, result = None, populate_existing = False, version_check=False): """pulls an object instance from the given row and appends it to the given result list. if the instance already exists in the given identity map, its not added. in either case, executes all the property loaders on the instance to also process extra @@ -994,6 +995,9 @@ class Mapper(object): if session.has_key(identitykey): instance = session._get(identitykey) isnew = False + if version_check and self.version_id_col is not None and self._getattrbycolumn(instance, self.version_id_col) != row[self.version_id_col]: + raise exceptions.ConcurrentModificationError("Instance '%s' version of %s does not match %s" % (instance, self._getattrbycolumn(instance, self.version_id_col), row[self.version_id_col])) + if populate_existing or session.is_expired(instance, unexpire=True): if not imap.has_key(identitykey): imap[identitykey] = instance diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 052d048cbb..13092b44d8 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -278,7 +278,7 @@ class Query(object): def _get(self, key, ident=None, reload=False, lockmode=None): lockmode = lockmode or self.lockmode - if not reload and not self.always_refresh and lockmode == None: + if not reload and not self.always_refresh and lockmode is None: try: return self.session._get(key) except KeyError: @@ -301,7 +301,7 @@ class Query(object): i += 1 try: statement = self.compile(self._get_clause, lockmode=lockmode) - return self._select_statement(statement, params=params, populate_existing=reload)[0] + return self._select_statement(statement, params=params, populate_existing=reload, version_check=(lockmode is not None))[0] except IndexError: return None diff --git a/test/orm/objectstore.py b/test/orm/objectstore.py index 4c35fee650..ec47749bd0 100644 --- a/test/orm/objectstore.py +++ b/test/orm/objectstore.py @@ -150,7 +150,7 @@ class VersioningTest(SessionTest): # a concurrent session has modified this, should throw # an exception s.flush() - except exceptions.SQLAlchemyError, e: + except exceptions.ConcurrentModificationError, e: #print e success = True assert success @@ -166,10 +166,48 @@ class VersioningTest(SessionTest): success = False try: s.flush() - except exceptions.SQLAlchemyError, e: + except exceptions.ConcurrentModificationError, e: #print e success = True assert success + def testversioncheck(self): + """test that query.with_lockmode performs a 'version check' on an already loaded instance""" + s1 = create_session() + class Foo(object):pass + assign_mapper(Foo, version_table, version_id_col=version_table.c.version_id) + f1s1 =Foo(value='f1', _sa_session=s1) + s1.flush() + s2 = create_session() + f1s2 = s2.query(Foo).get(f1s1.id) + f1s2.value='f1 new value' + s2.flush() + try: + # load, version is wrong + s1.query(Foo).with_lockmode('read').get(f1s1.id) + assert False + except exceptions.ConcurrentModificationError, e: + assert True + # reload it + s1.query(Foo).load(f1s1.id) + # now assert version OK + s1.query(Foo).with_lockmode('read').get(f1s1.id) + + # assert brand new load is OK too + s1.clear() + s1.query(Foo).with_lockmode('read').get(f1s1.id) + + def testnoversioncheck(self): + """test that query.with_lockmode works OK when the mapper has no version id col""" + s1 = create_session() + class Foo(object):pass + assign_mapper(Foo, version_table) + f1s1 =Foo(value='f1', _sa_session=s1) + f1s1.version_id=0 + s1.flush() + s2 = create_session() + f1s2 = s2.query(Foo).with_lockmode('read').get(f1s1.id) + assert f1s2.id == f1s1.id + assert f1s2.value == f1s1.value class UnicodeTest(SessionTest): def setUpAll(self): -- 2.47.2