From: Mike Bayer Date: Sat, 17 Sep 2005 07:49:31 +0000 (+0000) Subject: (no commit message) X-Git-Tag: rel_0_1_0~692 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=776c2396a0c1856c3c252b3ef22bace0acdf1741;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git --- diff --git a/lib/sqlalchemy/objectstore.py b/lib/sqlalchemy/objectstore.py index a78dafc4c9..dffa4afa21 100644 --- a/lib/sqlalchemy/objectstore.py +++ b/lib/sqlalchemy/objectstore.py @@ -114,7 +114,8 @@ class UOWListElement(util.HistoryArraySet): return res class UnitOfWork(object): - def __init__(self, parent = None): + def __init__(self, parent = None, is_begun = False): + self.is_begun = is_begun self.new = util.HashSet() self.dirty = util.HashSet() self.modified_lists = util.HashSet() @@ -135,7 +136,13 @@ class UnitOfWork(object): obj.__dict__[key] = v self.register_attribute(obj, key).setattr_clean(v) return v - + + def rollback_attribute(self, obj, key): + if self.attribute_history.has_key(obj): + h = self.attribute_history[obj][key] + h.rollback() + obj.__dict__[key] = h.current + def set_attribute(self, obj, key, value, usehistory = False): if usehistory: self.register_attribute(obj, key).setattr(value) @@ -153,7 +160,18 @@ class UnitOfWork(object): self.register_dirty(obj) else: self.register_new(obj) - + + def rollback_obj(self, obj): + try: + attributes = self.attribute_history[obj] + for key, hist in attributes.iteritems(): + hist.rollback() + obj.__dict__[key] = hist.current + except KeyError: + pass + for value in obj.__dict__.values(): + if isinstance(value, util.HistoryArraySet): + value.rollback() def register_attribute(self, obj, key): try: attributes = self.attribute_history[obj] @@ -166,24 +184,36 @@ class UnitOfWork(object): def register_list_attribute(self, obj, key, data = None): try: - childlist = obj.__dict__[key] + attributes = self.attribute_history[obj] except KeyError: - childlist = UOWListElement(obj) - obj.__dict__[key] = childlist - - if callable(childlist): - childlist = UOWListElement(obj, childlist()) - obj.__dict__[key] = childlist - elif not isinstance(childlist, util.HistoryArraySet): - childlist = UOWListElement(obj, childlist) - obj.__dict__[key] = childlist + attributes = self.attribute_history.setdefault(obj, {}) + try: + childlist = attributes[key] + except KeyError: + try: + list = obj.__dict__[key] + if callable(list): + list = list() + except KeyError: + list = [] + obj.__dict__[key] = list + + childlist = UOWListElement(obj, list) + if data is not None and childlist.data != data: try: childlist.set_data(data) except TypeError: raise "object " + repr(data) + " is not an iterable object" return childlist - + + def rollback_list_attribute(self, obj, key): + try: + childlist = obj.__dict__[key] + if isinstance(childlist, util.HistoryArraySet): + childlist.rollback() + except KeyError: + pass def register_clean(self, obj, scope="thread"): try: del self.dirty[obj] @@ -202,7 +232,7 @@ class UnitOfWork(object): def register_dirty(self, obj): self.dirty.append(obj) - + def is_dirty(self, obj): if not self.dirty.contains(obj): return False @@ -214,14 +244,14 @@ class UnitOfWork(object): # TODO: tie in register_new/register_dirty with table transaction begins ? def begin(self): - u = UnitOfWork(self) + u = UnitOfWork(self, True) uow.set(u) def commit(self, *objects): import sqlalchemy.mapper commit_context = UOWTransaction(self) - + if len(objects): for obj in objects: commit_context.append_task(obj) @@ -232,9 +262,25 @@ class UnitOfWork(object): obj = item.obj() commit_context.append_task(obj) - commit_context.execute() - - # TODO: deleted stuff + engines = util.HashSet() + for mapper in commit_context.mappers.keys(): + for e in mapper.engines: + engines.append(e) + + for e in engines: + e.begin() + try: + commit_context.execute() + except: + for e in engines: + e.rollback() + if self.parent: + uow.set(self.parent) + raise + for e in engines: + e.commit() + + commit_context.post_exec() if self.parent: uow.set(self.parent) @@ -243,6 +289,7 @@ class UOWTransaction(object): def __init__(self, uow): self.uow = uow self.mappers = {} + self.engines = util.HashSet() self.dependencies = {} self.tasks = {} self.saved_objects = util.HashSet() @@ -295,17 +342,14 @@ class UOWTransaction(object): return 0 mapperlist.sort(compare) - try: - # TODO: db tranasction boundary - for task in mapperlist: - obj_list = task.objects - task.mapper.save_obj(obj_list, self) - for dep in task.dependencies: - (processor, stuff_to_process) = dep - processor.process_dependencies(stuff_to_process, self) - except: - raise + for task in mapperlist: + obj_list = task.objects + task.mapper.save_obj(obj_list, self) + for dep in task.dependencies: + (processor, stuff_to_process) = dep + processor.process_dependencies(stuff_to_process, self) + def post_exec(self): for obj in self.saved_objects: mapper = self.object_mapper(obj) obj._instance_key = mapper.identity_key(obj) diff --git a/test/mapper.py b/test/mapper.py index 2d457542d2..e2db30ca8c 100644 --- a/test/mapper.py +++ b/test/mapper.py @@ -5,6 +5,7 @@ import sqlalchemy.objectstore as objectstore #ECHO = True ECHO = False +DATA = True execfile("test/tables.py") db.echo = True