return res
class UnitOfWork(object):
- def __init__(self, parent = None):
+ def __init__(self, parent = None, is_begun = False):
+ self.is_begun = is_begun
self.new = util.HashSet()
self.dirty = util.HashSet()
self.modified_lists = util.HashSet()
obj.__dict__[key] = v
self.register_attribute(obj, key).setattr_clean(v)
return v
-
+
+ def rollback_attribute(self, obj, key):
+ if self.attribute_history.has_key(obj):
+ h = self.attribute_history[obj][key]
+ h.rollback()
+ obj.__dict__[key] = h.current
+
def set_attribute(self, obj, key, value, usehistory = False):
if usehistory:
self.register_attribute(obj, key).setattr(value)
self.register_dirty(obj)
else:
self.register_new(obj)
-
+
+ def rollback_obj(self, obj):
+ try:
+ attributes = self.attribute_history[obj]
+ for key, hist in attributes.iteritems():
+ hist.rollback()
+ obj.__dict__[key] = hist.current
+ except KeyError:
+ pass
+ for value in obj.__dict__.values():
+ if isinstance(value, util.HistoryArraySet):
+ value.rollback()
def register_attribute(self, obj, key):
try:
attributes = self.attribute_history[obj]
def register_list_attribute(self, obj, key, data = None):
try:
- childlist = obj.__dict__[key]
+ attributes = self.attribute_history[obj]
except KeyError:
- childlist = UOWListElement(obj)
- obj.__dict__[key] = childlist
-
- if callable(childlist):
- childlist = UOWListElement(obj, childlist())
- obj.__dict__[key] = childlist
- elif not isinstance(childlist, util.HistoryArraySet):
- childlist = UOWListElement(obj, childlist)
- obj.__dict__[key] = childlist
+ attributes = self.attribute_history.setdefault(obj, {})
+ try:
+ childlist = attributes[key]
+ except KeyError:
+ try:
+ list = obj.__dict__[key]
+ if callable(list):
+ list = list()
+ except KeyError:
+ list = []
+ obj.__dict__[key] = list
+
+ childlist = UOWListElement(obj, list)
+
if data is not None and childlist.data != data:
try:
childlist.set_data(data)
except TypeError:
raise "object " + repr(data) + " is not an iterable object"
return childlist
-
+
+ def rollback_list_attribute(self, obj, key):
+ try:
+ childlist = obj.__dict__[key]
+ if isinstance(childlist, util.HistoryArraySet):
+ childlist.rollback()
+ except KeyError:
+ pass
def register_clean(self, obj, scope="thread"):
try:
del self.dirty[obj]
def register_dirty(self, obj):
self.dirty.append(obj)
-
+
def is_dirty(self, obj):
if not self.dirty.contains(obj):
return False
# TODO: tie in register_new/register_dirty with table transaction begins ?
def begin(self):
- u = UnitOfWork(self)
+ u = UnitOfWork(self, True)
uow.set(u)
def commit(self, *objects):
import sqlalchemy.mapper
commit_context = UOWTransaction(self)
-
+
if len(objects):
for obj in objects:
commit_context.append_task(obj)
obj = item.obj()
commit_context.append_task(obj)
- commit_context.execute()
-
- # TODO: deleted stuff
+ engines = util.HashSet()
+ for mapper in commit_context.mappers.keys():
+ for e in mapper.engines:
+ engines.append(e)
+
+ for e in engines:
+ e.begin()
+ try:
+ commit_context.execute()
+ except:
+ for e in engines:
+ e.rollback()
+ if self.parent:
+ uow.set(self.parent)
+ raise
+ for e in engines:
+ e.commit()
+
+ commit_context.post_exec()
if self.parent:
uow.set(self.parent)
def __init__(self, uow):
self.uow = uow
self.mappers = {}
+ self.engines = util.HashSet()
self.dependencies = {}
self.tasks = {}
self.saved_objects = util.HashSet()
return 0
mapperlist.sort(compare)
- try:
- # TODO: db tranasction boundary
- for task in mapperlist:
- obj_list = task.objects
- task.mapper.save_obj(obj_list, self)
- for dep in task.dependencies:
- (processor, stuff_to_process) = dep
- processor.process_dependencies(stuff_to_process, self)
- except:
- raise
+ for task in mapperlist:
+ obj_list = task.objects
+ task.mapper.save_obj(obj_list, self)
+ for dep in task.dependencies:
+ (processor, stuff_to_process) = dep
+ processor.process_dependencies(stuff_to_process, self)
+ def post_exec(self):
for obj in self.saved_objects:
mapper = self.object_mapper(obj)
obj._instance_key = mapper.identity_key(obj)