The registry is capable of maintaining object instances on a thread-local,
per-application, or custom user-defined basis."""
- def __init__(self, scope="application", getter=None, hash_key=None, keyfunc=None):
+ def __init__(self, nest_transactions=False, hash_key=None):
"""Initialize the objectstore with a UnitOfWork registry. If called
with no arguments, creates a single UnitOfWork for all operations.
- scope - "application" or "thread", the two default scopes
- getter - a callable that takes this Session as an argument and returns a
- new UnitOfWork.
+ nest_transactions - indicates begin/commit statements can be executed in a
+ "nested", defaults to False which indicates "only commit on the outermost begin/commit"
hash_key - the hash_key used to identify objects against this session, which
defaults to the id of the Session instance.
- keyfunc - allows custom scopes by providing a callable to return the "key"
- identifying the desired UnitOfWork.
"""
- if keyfunc is None:
- if scope=="thread":
- keyfunc = thread.get_ident
- elif scope=="application":
- keyfunc = lambda: True
- if getter is None:
- def createfunc():
- return UnitOfWork(self)
+ self.uow = UnitOfWork()
+ self.parent_uow = None
+ self.begin_count = 0
+ self.nest_transactions = nest_transactions
+ if hash_key is None:
+ self.hash_key = id(self)
else:
- createfunc = lambda: getter(self)
- self.registry = util.ScopedRegistry(createfunc, keyfunc)
- self._hash_key = hash_key
-
+ self.hash_key = hash_key
+ _sessions[self.hash_key] = self
+
def get_id_key(ident, class_, table):
"""returns an identity-map key for use in storing/retrieving an item from the identity
map, given a tuple of the object's primary key values.
return (class_, table.hash_key(), tuple([row[column] for column in primary_key]))
get_row_key = staticmethod(get_row_key)
- def _set_uow(self, uow):
- self.registry.set(uow)
- uow = property(lambda s:s.registry(), _set_uow, doc="Returns a scope-specific UnitOfWork object for this session.")
-
- hash_key = property(lambda s:s._hash_key or id(s))
+ def begin(self):
+ """begins a new UnitOfWork transaction. the next commit will affect only
+ objects that are created, modified, or deleted following the begin statement."""
+ self.begin_count += 1
+ if self.parent_uow is not None:
+ return
+ self.parent_uow = self.uow
+ self.uow = UnitOfWork(identity_map = self.uow.identity_map)
+
+ def commit(self, *objects):
+ """commits the current UnitOfWork transaction. if a transaction was begun
+ via begin(), commits only those objects that were created, modified, or deleted
+ since that begin statement. otherwise commits all objects that have been
+ changed.
+ if individual objects are submitted, then only those objects are committed, and the
+ begin/commit cycle is not affected."""
+ # if an object list is given, commit just those but dont
+ # change begin/commit status
+ if len(objects):
+ self.uow.commit(*objects)
+ return
+ if self.parent_uow is not None:
+ self.begin_count -= 1
+ if self.begin_count > 0:
+ return
+ self.uow.commit()
+ if self.parent_uow is not None:
+ self.uow = self.parent_uow
+ self.parent_uow = None
+
+ def rollback(self):
+ """rolls back the current UnitOfWork transaction, in the case that begin()
+ has been called. The changes logged since the begin() call are discarded."""
+ if self.parent_uow is None:
+ raise "UOW transaction is not begun"
+ self.uow = self.parent_uow
+ self.parent_uow = None
+ self.begin_count = 0
+
+ def register_clean(self, obj):
+ self._bind_to(obj)
+ self.uow.register_clean(obj)
+
+ def register_new(self, obj):
+ self._bind_to(obj)
+ self.uow.register_new(obj)
- def bind_to(self, obj):
+ def _bind_to(self, obj):
"""given an object, binds it to this session. changes on the object will affect
the currently scoped UnitOfWork maintained by this session."""
obj._sa_session_id = self.hash_key
def __getattr__(self, key):
"""proxy other methods to our underlying UnitOfWork"""
- return getattr(self.registry(), key)
+ return getattr(self.uow, key)
def clear(self):
- self.registry.clear()
+ self.uow = UnitOfWork()
- def delete(*obj):
+ def delete(self, *obj):
"""registers the given objects as to be deleted upon the next commit"""
- u = registry()
for o in obj:
- u.register_deleted(o)
+ self.uow.register_deleted(o)
def import_instance(self, instance):
"""places the given instance in the current thread's unit of work context,
key = getattr(instance, '_instance_key', None)
mapper = object_mapper(instance)
key = (key[0], mapper.table.hash_key(), key[2])
- u = self.registry()
+ u = self.uow
if key is not None:
if u.identity_map.has_key(key):
return u.identity_map[key]
else:
u.register_new(instance)
return instance
-
def get_id_key(ident, class_, table):
return Session.get_id_key(ident, class_, table)
def begin():
"""begins a new UnitOfWork transaction. the next commit will affect only
objects that are created, modified, or deleted following the begin statement."""
- session().begin()
+ get_session().begin()
def commit(*obj):
"""commits the current UnitOfWork transaction. if a transaction was begun
via begin(), commits only those objects that were created, modified, or deleted
since that begin statement. otherwise commits all objects that have been
- changed."""
- session().commit(*obj)
+ changed.
+
+ if individual objects are submitted, then only those objects are committed, and the
+ begin/commit cycle is not affected."""
+ get_session().commit(*obj)
def clear():
"""removes all current UnitOfWorks and IdentityMaps for this thread and
establishes a new one. It is probably a good idea to discard all
current mapped object instances, as they are no longer in the Identity Map."""
- session().clear()
+ get_session().clear()
def delete(*obj):
"""registers the given objects as to be deleted upon the next commit"""
- s = session()
- for o in obj:
- s.register_deleted(o)
+ s = get_session().delete(*obj)
def has_key(key):
"""returns True if the current thread-local IdentityMap contains the given instance key"""
- return session().has_key(key)
+ return get_session().has_key(key)
def has_instance(instance):
"""returns True if the current thread-local IdentityMap contains the given instance"""
- return session().has_instance(instance)
+ return get_session().has_instance(instance)
def is_dirty(obj):
"""returns True if the given object is in the current UnitOfWork's new or dirty list,
or if its a modified list attribute on an object."""
- return session().is_dirty(obj)
+ return get_session().is_dirty(obj)
def instance_key(instance):
"""returns the IdentityMap key for the given instance"""
- return session().instance_key(instance)
+ return get_session().instance_key(instance)
def import_instance(instance):
- return session().import_instance(instance)
+ return get_session().import_instance(instance)
class UOWListElement(attributes.ListElement):
def __init__(self, obj, key, data=None, deleteremoved=False, **kwargs):
attributes.ListElement.__init__(self, obj, key, data=data, **kwargs)
self.deleteremoved = deleteremoved
def list_value_changed(self, obj, key, item, listval, isdelete):
- sess = session(obj)
+ sess = get_session(obj)
if not isdelete and sess.deleted.contains(item):
raise "re-inserting a deleted value into a list"
sess.modified_lists.append(self)
def value_changed(self, obj, key, value):
if hasattr(obj, '_instance_key'):
- session(obj).register_dirty(obj)
+ get_session(obj).register_dirty(obj)
else:
- session(obj).register_new(obj)
+ get_session(obj).register_new(obj)
def create_list(self, obj, key, list_, **kwargs):
return UOWListElement(obj, key, list_, **kwargs)
class UnitOfWork(object):
- def __init__(self, session, parent=None, is_begun=False):
- self.session = session
- self.is_begun = is_begun
- if is_begun:
- self.begin_count = 1
- else:
- self.begin_count = 0
- if parent is not None:
- self.identity_map = parent.identity_map
+ def __init__(self, identity_map=None):
+ if identity_map is not None:
+ self.identity_map = identity_map
else:
self.identity_map = weakref.WeakValueDictionary()
self.dirty = util.HashSet()
self.modified_lists = util.HashSet()
self.deleted = util.HashSet()
- self.parent = parent
def get(self, class_, *id):
"""given a class and a list of primary key values in their table-order, locates the mapper
if not hasattr(obj, '_instance_key'):
mapper = object_mapper(obj)
obj._instance_key = mapper.instance_key(obj)
- self.session.bind_to(obj)
self._put(obj._instance_key, obj)
self.attributes.commit(obj)
def register_new(self, obj):
- self.session.bind_to(obj)
self.new.append(obj)
def register_dirty(self, obj):
except KeyError:
pass
- # TODO: tie in register_new/register_dirty with table transaction begins ?
- def begin(self):
- if self.is_begun:
- self.begin_count += 1
- return
- u = UnitOfWork(self.session, parent=self, is_begun=True)
- self.session.registry.set(u)
-
def commit(self, *objects):
- if self.is_begun:
- self.begin_count -= 1
- if self.begin_count > 0:
- return
commit_context = UOWTransaction(self)
if len(objects):
except:
for e in engines:
e.rollback()
- if self.parent:
- self.session.registry.set(self.parent)
raise
for e in engines:
e.commit()
commit_context.post_exec()
- if self.parent:
- self.session.registry.set(self.parent)
def rollback_object(self, obj):
"""'rolls back' the attributes that have been changed on an object instance."""
global_attributes = UOWAttributeManager()
-global_session = Session(scope="thread", hash_key='thread')
-uow = global_session.registry # Note: this is not a UnitOfWork, it is a ScopedRegistry that manages UnitOfWork objects
-_sessions = weakref.WeakValueDictionary()
-_sessions[global_session.hash_key] = global_session
+session_registry = util.ScopedRegistry(Session) # Default session registry
+_sessions = weakref.WeakValueDictionary() # all referenced sessions (including user-created)
-def session(obj=None):
+def get_session(obj=None):
# object-specific session ?
if obj is not None:
# does it have a hash key ?
except KeyError:
raise "Session '%s' referenced by object '%s' no longer exists" % (hashkey, repr(obj))
- try:
- # have a thread-locally defined session (via using_session) ?
- return _sessions[thread.get_ident()]
- except KeyError:
- # nope, return the regular session
- return global_session
+ return session_registry()
+
+uow = get_session # deprecated
def push_session(sess):
old = _sessions.get(thread.get_ident(), None)