From b9f1a92493dcbe0600866a0ad0370c3dea7b41f3 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Tue, 30 Mar 2010 19:19:41 -0400 Subject: [PATCH] This is turning out to be a rewrite of the accounting system of unitofwork.py, but the overarching method of doing things stays the same. it should be easy to add new dependencies between actions and to change the structure of how things are done. --- lib/sqlalchemy/orm/dependency.py | 44 ++++++++++++---- lib/sqlalchemy/orm/mapper.py | 60 +++------------------ lib/sqlalchemy/orm/properties.py | 4 -- lib/sqlalchemy/orm/unitofwork.py | 90 +++++++++++++++++++++++++------- test/base/test_dependency.py | 54 +++++++++---------- 5 files changed, 135 insertions(+), 117 deletions(-) diff --git a/lib/sqlalchemy/orm/dependency.py b/lib/sqlalchemy/orm/dependency.py index fa33318044..a120a626f2 100644 --- a/lib/sqlalchemy/orm/dependency.py +++ b/lib/sqlalchemy/orm/dependency.py @@ -14,7 +14,7 @@ dependencies at flush time. from sqlalchemy import sql, util import sqlalchemy.exceptions as sa_exc -from sqlalchemy.orm import attributes, exc, sync +from sqlalchemy.orm import attributes, exc, sync, unitofwork from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE, MANYTOMANY @@ -188,18 +188,42 @@ class DependencyProcessor(object): return "%s(%s)" % (self.__class__.__name__, self.prop) class OneToManyDP(DependencyProcessor): - def register_dependencies(self, uowcommit): + + def per_mapper_flush_actions(self, uow): if self.post_update: - uowcommit.register_dependency(self.mapper, self.dependency_marker) - uowcommit.register_dependency(self.parent, self.dependency_marker) + # ... else: - uowcommit.register_dependency(self.parent, self.mapper) + after_save = unitofwork.ProcessAll(uow, self, False) + before_delete = unitofwork.ProcessAll(uow, self, True) + + parent_saves = unitofwork.SaveUpdateAll(uow, self.parent) + child_saves = unitofwork.SaveUpdateAll(uow, self.mapper) + + parent_deletes = unitofwork.DeleteAll(uow, self.parent) + child_deletes = unitofwork.DeleteAll(uow, self.mapper) + + uowtransaction.dependencies.update([ + (parent_saves, after_save), + (after_save, child_saves), + (child_deletes, before_delete), + (before_delete, parent_deletes) + ]) + + +# def register_dependencies(self, uowcommit): +# if self.post_update: +# uowcommit.register_dependency(self.mapper, self.dependency_marker) +# uowcommit.register_dependency(self.parent, self.dependency_marker) +# else: +# uowcommit.register_dependency(self.parent, self.mapper) +# +# +# def register_processors(self, uowcommit): +# if self.post_update: +# uowcommit.register_processor(self.dependency_marker, self, self.parent) +# else: +# uowcommit.register_processor(self.parent, self, self.parent) - def register_processors(self, uowcommit): - if self.post_update: - uowcommit.register_processor(self.dependency_marker, self, self.parent) - else: - uowcommit.register_processor(self.parent, self, self.parent) def process_dependencies(self, task, deplist, uowcommit, delete = False): if delete: diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 18436e211a..a1787933f2 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1254,35 +1254,14 @@ class Mapper(object): return sqlutil.sort_tables(l) - - def get_flush_actions(self, uowtransaction, state): - if isdelete: - type_ = Delete - tables = reversed(mapper._sorted_table_list) - elif not _state_has_identity(state): - type_ = Insert - tables = mapper._sorted_table_list - else: - type_ = Update - tables = mapper._sorted_table_list - - recs = [ - type_(state, table) - for table in tables - ] - for i, rec in enumerate(recs): - if i > 0: - self._dependency(recs[i - 1], recs[i]) - recs.append(SyncKeys(state, recs[i - 1].table, recs[i].table)) - - dep_recs = [] - for prop in mapper._props.values(): - dp = prop.get_flush_actions(uowtransaction, recs, state) - if dp: - dep_recs.extend(dp) - - return recs + dep_recs + def per_mapper_flush_actions(self, uowtransaction): + unitofwork.SaveUpdateAll(uow, self.base_mapper) + unitofwork.DeleteAll(uow, self.base_mapper) + for prop in self._props.values(): + dp = prop._dependency_processor + if dp is not None: + dp.per_mapper_flush_actions(uowtransaction) def _save_obj(self, states, uowtransaction, postupdate=False, post_update_cols=None, single=False): @@ -1877,31 +1856,6 @@ class Mapper(object): log.class_logger(Mapper) -class Insert(unitofwork.Rec): - def __init__(self, mapper, state, table): - self.mapper = mapper - self.state = state - self.table = table - -class Update(unitofwork.Rec): - def __init__(self, mapper, state, table): - self.mapper = mapper - self.state = state - self.table = table - -class Delete(unitofwork.Rec): - def __init__(self, mapper, state, table): - self.mapper = mapper - self.state = state - self.table = table - -class SyncKeys(unitofwork.Rec): - def __init__(self, mapper, state, parent, child): - self.mapper = mapper - self.state = state - self.parent = parent - self.child = child - def reconstructor(fn): """Decorate a method as the 'reconstructor' hook. diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index ec21b27d62..ff1f234765 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -1285,10 +1285,6 @@ class RelationshipProperty(StrategizedProperty): source_selectable, dest_selectable, secondary, target_adapter) - def get_flush_actions(self, uowtransaction, records, state): - if not self.viewonly: - return self._depency_processor.get_flush_actions(uowtransaction, records, state) - PropertyLoader = RelationProperty = RelationshipProperty log.class_logger(RelationshipProperty) diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index dc4adff966..ea0639192b 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -88,12 +88,12 @@ class UOWTransaction(object): # information. self.attributes = {} - self.recs = [] - self.states = set() - self.dependencies = [] - - def _dependency(self, rec1, rec2): - self.dependencies.append((rec1, rec2)) + self.mappers = collections.defaultdict(set) + self.actions = {} + self.saves = set() + self.deletes = set() + self.etc = set() + self.dependencies = set() def get_attribute_history(self, state, key, passive=True): hashkey = ("history", state, key) @@ -129,21 +129,39 @@ class UOWTransaction(object): return mapper = _state_mapper(state) + self.mappers[mapper].add(state) + self._state_collection(isdelete, listonly).add(state) + + def register_dependency(self, parent, child): + self.dependencies.add((parent, child)) - self.states.add(state) + def _state_collection(self, isdelete, listonly): + if isdelete: + return self.deletes + elif not listonly: + return self.saves + else: + return self.etc - self.recs.extend( - mapper.get_flush_actions(self, state) - ) + def states_for_mapper(self, mapper, isdelete, listonly): + return iter(self._state_collection(isdelete, listonly)[mapper]) + + def states_for_mapper_hierarchy(self, mapper, isdelete, listonly): + collection = self._state_collection(isdelete, listonly) + for mapper in mapper.base_mapper.polymorphic_iterator(): + for state in collection[mapper]: + yield state + + def execute(self): + for mapper in self.mappers: + mapper.per_mapper_flush_actions(self) + +# if cycles: +# break up actions into finer grained actions along those cycles - def execute(self): - # so here, thinking we could figure out a way to get - # consecutive, "compatible" records to collapse together, - # i.e. a bunch of updates become an executemany(), etc. - # even though we usually need individual executes. - for rec in topological.sort(self.dependencies, self.recs): - rec.execute() +# for rec in topological.sort(self.dependencies, self.actions): +# rec.execute() def finalize_flush_changes(self): """mark processed objects as clean / deleted after a successful flush(). @@ -160,8 +178,40 @@ class UOWTransaction(object): log.class_logger(UOWTransaction) -# TODO: don't know what these should be. -# its very hard not to use subclasses to define behavior here. class Rec(object): - pass + def __new__(self, uow, *args): + key = (self.__class__, ) + args + if key in uow.actions: + return uow.actions[key] + else: + uow.actions[key] = ret = object.__new__(self) + return ret + +class SaveUpdateAll(Rec): + def __init__(self, uow, mapper): + self.mapper = mapper + +class DeleteAll(Rec): + def __init__(self, mapper): + self.mapper = mapper + +class ProcessAll(Rec): + def __init__(self, uow, dependency_processor, delete): + self.dependency_processor = dependency_processor + self.delete = delete + +class ProcessState(Rec): + def __init__(self, uow, dependency_processor, delete, state): + self.dependency_processor = dependency_processor + self.delete = delete + self.state = state + +class SaveUpdateState(Rec): + def __init__(self, uow, state): + self.state = state + +class DeleteState(Rec): + def __init__(self, uow, state): + self.state = state + diff --git a/test/base/test_dependency.py b/test/base/test_dependency.py index 1a7e1d1f1d..7dc55ea99e 100644 --- a/test/base/test_dependency.py +++ b/test/base/test_dependency.py @@ -11,6 +11,7 @@ class DependencySortTest(TestBase): for parent, child in tuples: deps[parent].add(child) + assert len(result) for i, node in enumerate(result): for n in result[i:]: assert node not in deps[n] @@ -35,8 +36,7 @@ class DependencySortTest(TestBase): (node4, subnode3), (node4, subnode4) ] - head = topological.sort(tuples, []) - self.assert_sort(tuples, head) + self.assert_sort(tuples, topological.sort(tuples, [])) def testsort2(self): node1 = 'node1' @@ -53,28 +53,7 @@ class DependencySortTest(TestBase): (node5, node6), (node6, node2) ] - head = topological.sort(tuples, [node7]) - self.assert_sort(tuples, head) - - def testsort3(self): - ['Mapper|Keyword|keywords,Mapper|IKAssociation|itemkeywords', 'Mapper|Item|items,Mapper|IKAssociation|itemkeywords'] - node1 = 'keywords' - node2 = 'itemkeyowrds' - node3 = 'items' - tuples = [ - (node1, node2), - (node3, node2), - (node1,node3) - ] - head1 = topological.sort(tuples, [node1, node2, node3]) - head2 = topological.sort(tuples, [node3, node1, node2]) - head3 = topological.sort(tuples, [node3, node2, node1]) - - # TODO: figure out a "node == node2" function - #self.assert_(str(head1) == str(head2) == str(head3)) - print "\n" + str(head1) - print "\n" + str(head2) - print "\n" + str(head3) + self.assert_sort(tuples, topological.sort(tuples, [node7])) def testsort4(self): node1 = 'keywords' @@ -87,8 +66,7 @@ class DependencySortTest(TestBase): (node1, node3), (node3, node2) ] - head = topological.sort(tuples, []) - self.assert_sort(tuples, head) + self.assert_sort(tuples, topological.sort(tuples, [])) def testcircular(self): node1 = 'node1' @@ -107,6 +85,8 @@ class DependencySortTest(TestBase): allitems = [node1, node2, node3, node4] assert_raises(exc.CircularDependencyError, topological.sort, tuples, allitems) + # TODO: test find_cycles + def testcircular2(self): # this condition was arising from ticket:362 # and was not treated properly by topological sort @@ -123,20 +103,34 @@ class DependencySortTest(TestBase): ] assert_raises(exc.CircularDependencyError, topological.sort, tuples, []) + # TODO: test find_cycles + def testcircular3(self): question, issue, providerservice, answer, provider = "Question", "Issue", "ProviderService", "Answer", "Provider" - tuples = [(question, issue), (providerservice, issue), (provider, question), (question, provider), (providerservice, question), (provider, providerservice), (question, answer), (issue, question)] + tuples = [(question, issue), (providerservice, issue), (provider, question), + (question, provider), (providerservice, question), + (provider, providerservice), (question, answer), (issue, question)] assert_raises(exc.CircularDependencyError, topological.sort, tuples, []) - + + # TODO: test find_cycles + def testbigsort(self): tuples = [(i, i + 1) for i in range(0, 1500, 2)] - head = topological.sort(tuples, []) + self.assert_sort( + tuples, + topological.sort(tuples, []) + ) def testids(self): # ticket:1380 regression: would raise a KeyError - topological.sort([(id(i), i) for i in range(3)], []) + tuples = [(id(i), i) for i in range(3)] + self.assert_sort( + tuples, + topological.sort(tuples, []) + ) + -- 2.47.3