From: Mike Bayer Date: Tue, 14 Feb 2006 00:30:30 +0000 (+0000) Subject: latest reorgnanization of the objectstore, the Session is a simpler object that just... X-Git-Tag: rel_0_1_0~12 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=791e2f7f7da88bd13a1002540755f920e6703711;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git latest reorgnanization of the objectstore, the Session is a simpler object that just maintains begin/commit state --- diff --git a/lib/sqlalchemy/mapping/mapper.py b/lib/sqlalchemy/mapping/mapper.py index 3a900569c8..24bf11fd8f 100644 --- a/lib/sqlalchemy/mapping/mapper.py +++ b/lib/sqlalchemy/mapping/mapper.py @@ -204,7 +204,7 @@ class Mapper(object): oldinit = self.class_.__init__ def init(self, *args, **kwargs): nohist = kwargs.pop('_mapper_nohistory', False) - session = kwargs.pop('_sa_session', objectstore.session()) + session = kwargs.pop('_sa_session', objectstore.get_session()) if oldinit is not None: try: oldinit(self, *args, **kwargs) @@ -244,7 +244,7 @@ class Mapper(object): # store new stuff in the identity map for value in imap.values(): - objectstore.session().register_clean(value) + objectstore.get_session().register_clean(value) if len(mappers): return [result] + otherresults @@ -261,7 +261,7 @@ class Mapper(object): def _get(self, key, ident=None): try: - return objectstore.session()._get(key) + return objectstore.get_session()._get(key) except KeyError: if ident is None: ident = key[2] @@ -688,8 +688,8 @@ class Mapper(object): # including modifying any of its related items lists, as its already # been exposed to being modified by the application. identitykey = self._identity_key(row) - if objectstore.session().has_key(identitykey): - instance = objectstore.session()._get(identitykey) + if objectstore.get_session().has_key(identitykey): + instance = objectstore.get_session()._get(identitykey) isnew = False if populate_existing: diff --git a/lib/sqlalchemy/mapping/objectstore.py b/lib/sqlalchemy/mapping/objectstore.py index 078c5a1798..c1549ffb73 100644 --- a/lib/sqlalchemy/mapping/objectstore.py +++ b/lib/sqlalchemy/mapping/objectstore.py @@ -32,31 +32,25 @@ class Session(object): The registry is capable of maintaining object instances on a thread-local, per-application, or custom user-defined basis.""" - def __init__(self, scope="application", getter=None, hash_key=None, keyfunc=None): + def __init__(self, nest_transactions=False, hash_key=None): """Initialize the objectstore with a UnitOfWork registry. If called with no arguments, creates a single UnitOfWork for all operations. - scope - "application" or "thread", the two default scopes - getter - a callable that takes this Session as an argument and returns a - new UnitOfWork. + nest_transactions - indicates begin/commit statements can be executed in a + "nested", defaults to False which indicates "only commit on the outermost begin/commit" hash_key - the hash_key used to identify objects against this session, which defaults to the id of the Session instance. - keyfunc - allows custom scopes by providing a callable to return the "key" - identifying the desired UnitOfWork. """ - if keyfunc is None: - if scope=="thread": - keyfunc = thread.get_ident - elif scope=="application": - keyfunc = lambda: True - if getter is None: - def createfunc(): - return UnitOfWork(self) + self.uow = UnitOfWork() + self.parent_uow = None + self.begin_count = 0 + self.nest_transactions = nest_transactions + if hash_key is None: + self.hash_key = id(self) else: - createfunc = lambda: getter(self) - self.registry = util.ScopedRegistry(createfunc, keyfunc) - self._hash_key = hash_key - + self.hash_key = hash_key + _sessions[self.hash_key] = self + def get_id_key(ident, class_, table): """returns an identity-map key for use in storing/retrieving an item from the identity map, given a tuple of the object's primary key values. @@ -92,29 +86,69 @@ class Session(object): return (class_, table.hash_key(), tuple([row[column] for column in primary_key])) get_row_key = staticmethod(get_row_key) - def _set_uow(self, uow): - self.registry.set(uow) - uow = property(lambda s:s.registry(), _set_uow, doc="Returns a scope-specific UnitOfWork object for this session.") - - hash_key = property(lambda s:s._hash_key or id(s)) + def begin(self): + """begins a new UnitOfWork transaction. the next commit will affect only + objects that are created, modified, or deleted following the begin statement.""" + self.begin_count += 1 + if self.parent_uow is not None: + return + self.parent_uow = self.uow + self.uow = UnitOfWork(identity_map = self.uow.identity_map) + + def commit(self, *objects): + """commits the current UnitOfWork transaction. if a transaction was begun + via begin(), commits only those objects that were created, modified, or deleted + since that begin statement. otherwise commits all objects that have been + changed. + if individual objects are submitted, then only those objects are committed, and the + begin/commit cycle is not affected.""" + # if an object list is given, commit just those but dont + # change begin/commit status + if len(objects): + self.uow.commit(*objects) + return + if self.parent_uow is not None: + self.begin_count -= 1 + if self.begin_count > 0: + return + self.uow.commit() + if self.parent_uow is not None: + self.uow = self.parent_uow + self.parent_uow = None + + def rollback(self): + """rolls back the current UnitOfWork transaction, in the case that begin() + has been called. The changes logged since the begin() call are discarded.""" + if self.parent_uow is None: + raise "UOW transaction is not begun" + self.uow = self.parent_uow + self.parent_uow = None + self.begin_count = 0 + + def register_clean(self, obj): + self._bind_to(obj) + self.uow.register_clean(obj) + + def register_new(self, obj): + self._bind_to(obj) + self.uow.register_new(obj) - def bind_to(self, obj): + def _bind_to(self, obj): """given an object, binds it to this session. changes on the object will affect the currently scoped UnitOfWork maintained by this session.""" obj._sa_session_id = self.hash_key def __getattr__(self, key): """proxy other methods to our underlying UnitOfWork""" - return getattr(self.registry(), key) + return getattr(self.uow, key) def clear(self): - self.registry.clear() + self.uow = UnitOfWork() - def delete(*obj): + def delete(self, *obj): """registers the given objects as to be deleted upon the next commit""" - u = registry() for o in obj: - u.register_deleted(o) + self.uow.register_deleted(o) def import_instance(self, instance): """places the given instance in the current thread's unit of work context, @@ -130,7 +164,7 @@ class Session(object): key = getattr(instance, '_instance_key', None) mapper = object_mapper(instance) key = (key[0], mapper.table.hash_key(), key[2]) - u = self.registry() + u = self.uow if key is not None: if u.identity_map.has_key(key): return u.identity_map[key] @@ -141,7 +175,6 @@ class Session(object): else: u.register_new(instance) return instance - def get_id_key(ident, class_, table): return Session.get_id_key(ident, class_, table) @@ -152,53 +185,54 @@ def get_row_key(row, class_, table, primary_key): def begin(): """begins a new UnitOfWork transaction. the next commit will affect only objects that are created, modified, or deleted following the begin statement.""" - session().begin() + get_session().begin() def commit(*obj): """commits the current UnitOfWork transaction. if a transaction was begun via begin(), commits only those objects that were created, modified, or deleted since that begin statement. otherwise commits all objects that have been - changed.""" - session().commit(*obj) + changed. + + if individual objects are submitted, then only those objects are committed, and the + begin/commit cycle is not affected.""" + get_session().commit(*obj) def clear(): """removes all current UnitOfWorks and IdentityMaps for this thread and establishes a new one. It is probably a good idea to discard all current mapped object instances, as they are no longer in the Identity Map.""" - session().clear() + get_session().clear() def delete(*obj): """registers the given objects as to be deleted upon the next commit""" - s = session() - for o in obj: - s.register_deleted(o) + s = get_session().delete(*obj) def has_key(key): """returns True if the current thread-local IdentityMap contains the given instance key""" - return session().has_key(key) + return get_session().has_key(key) def has_instance(instance): """returns True if the current thread-local IdentityMap contains the given instance""" - return session().has_instance(instance) + return get_session().has_instance(instance) def is_dirty(obj): """returns True if the given object is in the current UnitOfWork's new or dirty list, or if its a modified list attribute on an object.""" - return session().is_dirty(obj) + return get_session().is_dirty(obj) def instance_key(instance): """returns the IdentityMap key for the given instance""" - return session().instance_key(instance) + return get_session().instance_key(instance) def import_instance(instance): - return session().import_instance(instance) + return get_session().import_instance(instance) class UOWListElement(attributes.ListElement): def __init__(self, obj, key, data=None, deleteremoved=False, **kwargs): attributes.ListElement.__init__(self, obj, key, data=data, **kwargs) self.deleteremoved = deleteremoved def list_value_changed(self, obj, key, item, listval, isdelete): - sess = session(obj) + sess = get_session(obj) if not isdelete and sess.deleted.contains(item): raise "re-inserting a deleted value into a list" sess.modified_lists.append(self) @@ -216,23 +250,17 @@ class UOWAttributeManager(attributes.AttributeManager): def value_changed(self, obj, key, value): if hasattr(obj, '_instance_key'): - session(obj).register_dirty(obj) + get_session(obj).register_dirty(obj) else: - session(obj).register_new(obj) + get_session(obj).register_new(obj) def create_list(self, obj, key, list_, **kwargs): return UOWListElement(obj, key, list_, **kwargs) class UnitOfWork(object): - def __init__(self, session, parent=None, is_begun=False): - self.session = session - self.is_begun = is_begun - if is_begun: - self.begin_count = 1 - else: - self.begin_count = 0 - if parent is not None: - self.identity_map = parent.identity_map + def __init__(self, identity_map=None): + if identity_map is not None: + self.identity_map = identity_map else: self.identity_map = weakref.WeakValueDictionary() @@ -241,7 +269,6 @@ class UnitOfWork(object): self.dirty = util.HashSet() self.modified_lists = util.HashSet() self.deleted = util.HashSet() - self.parent = parent def get(self, class_, *id): """given a class and a list of primary key values in their table-order, locates the mapper @@ -305,12 +332,10 @@ class UnitOfWork(object): if not hasattr(obj, '_instance_key'): mapper = object_mapper(obj) obj._instance_key = mapper.instance_key(obj) - self.session.bind_to(obj) self._put(obj._instance_key, obj) self.attributes.commit(obj) def register_new(self, obj): - self.session.bind_to(obj) self.new.append(obj) def register_dirty(self, obj): @@ -335,19 +360,7 @@ class UnitOfWork(object): except KeyError: pass - # TODO: tie in register_new/register_dirty with table transaction begins ? - def begin(self): - if self.is_begun: - self.begin_count += 1 - return - u = UnitOfWork(self.session, parent=self, is_begun=True) - self.session.registry.set(u) - def commit(self, *objects): - if self.is_begun: - self.begin_count -= 1 - if self.begin_count > 0: - return commit_context = UOWTransaction(self) if len(objects): @@ -394,16 +407,12 @@ class UnitOfWork(object): except: for e in engines: e.rollback() - if self.parent: - self.session.registry.set(self.parent) raise for e in engines: e.commit() commit_context.post_exec() - if self.parent: - self.session.registry.set(self.parent) def rollback_object(self, obj): """'rolls back' the attributes that have been changed on an object instance.""" @@ -975,13 +984,11 @@ def object_mapper(obj): global_attributes = UOWAttributeManager() -global_session = Session(scope="thread", hash_key='thread') -uow = global_session.registry # Note: this is not a UnitOfWork, it is a ScopedRegistry that manages UnitOfWork objects -_sessions = weakref.WeakValueDictionary() -_sessions[global_session.hash_key] = global_session +session_registry = util.ScopedRegistry(Session) # Default session registry +_sessions = weakref.WeakValueDictionary() # all referenced sessions (including user-created) -def session(obj=None): +def get_session(obj=None): # object-specific session ? if obj is not None: # does it have a hash key ? @@ -993,12 +1000,9 @@ def session(obj=None): except KeyError: raise "Session '%s' referenced by object '%s' no longer exists" % (hashkey, repr(obj)) - try: - # have a thread-locally defined session (via using_session) ? - return _sessions[thread.get_ident()] - except KeyError: - # nope, return the regular session - return global_session + return session_registry() + +uow = get_session # deprecated def push_session(sess): old = _sessions.get(thread.get_ident(), None) diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index 2172106466..633091dd32 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -377,7 +377,7 @@ class ScopedRegistry(object): def __init__(self, createfunc, scopefunc=None): self.createfunc = createfunc if scopefunc is None: - scopefunc = thread.get_ident + self.scopefunc = thread.get_ident else: self.scopefunc = scopefunc self.registry = {} diff --git a/test/objectstore.py b/test/objectstore.py index 8551603927..bc90ec5389 100644 --- a/test/objectstore.py +++ b/test/objectstore.py @@ -72,6 +72,47 @@ class HistoryTest(AssertMixin): u = m.select()[0] print u.addresses[0].user +class SessionTest(AssertMixin): + def setUpAll(self): + db.echo = False + users.create() + tables.user_data() + db.echo = testbase.echo + def tearDownAll(self): + db.echo = False + users.drop() + db.echo = testbase.echo + def setUp(self): + objectstore.get_session().clear() + clear_mappers() + + def test_nested_begin_commit(self): + """test nested session.begin/commit""" + class User(object):pass + m = mapper(User, users) + def name_of(id): + return users.select(users.c.user_id == id).execute().fetchone().user_name + name1 = "Oliver Twist" + name2 = 'Mr. Bumble' + self.assert_(name_of(7) != name1, msg="user_name should not be %s" % name1) + self.assert_(name_of(8) != name2, msg="user_name should not be %s" % name2) + s = objectstore.get_session() + s.begin() + s.begin() + m.get(7).user_name = name1 + s.begin() + m.get(8).user_name = name2 + s.commit() + self.assert_(name_of(7) != name1, msg="user_name should not be %s" % name1) + self.assert_(name_of(8) != name2, msg="user_name should not be %s" % name2) + s.commit() + self.assert_(name_of(7) != name1, msg="user_name should not be %s" % name1) + self.assert_(name_of(8) != name2, msg="user_name should not be %s" % name2) + s.commit() + self.assert_(name_of(7) == name1, msg="user_name should be %s" % name1) + self.assert_(name_of(8) == name2, msg="user_name should be %s" % name2) + + class PKTest(AssertMixin): def setUpAll(self): db.echo = False diff --git a/test/tables.py b/test/tables.py index 00f946af48..fecd86bc4b 100644 --- a/test/tables.py +++ b/test/tables.py @@ -71,6 +71,13 @@ def delete(): users.delete().execute() db.commit() +def user_data(): + users.insert().execute( + dict(user_id = 7, user_name = 'jack'), + dict(user_id = 8, user_name = 'ed'), + dict(user_id = 9, user_name = 'fred') + ) + def data(): delete()