]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
its alive !
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 31 Mar 2010 17:25:13 +0000 (13:25 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 31 Mar 2010 17:25:13 +0000 (13:25 -0400)
lib/sqlalchemy/orm/dependency.py
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/orm/unitofwork.py
test/orm/test_unitofworkv2.py

index a120a626f2e2b40b6c688243175512ac658f7f9d..8301c157c1de9f9c2d0e4dfa642ae79857211d19 100644 (file)
@@ -189,123 +189,115 @@ class DependencyProcessor(object):
 
 class OneToManyDP(DependencyProcessor):
     
-    def per_mapper_flush_actions(self, uow):
+    def per_property_flush_actions(self, uow):
+        unitofwork.GetDependentObjects(uow, self, False, True)
+        
+        after_save = unitofwork.ProcessAll(uow, self, False, True)
+        before_delete = unitofwork.ProcessAll(uow, self, True, True)
+        
+        parent_saves = unitofwork.SaveUpdateAll(uow, self.parent)
+        child_saves = unitofwork.SaveUpdateAll(uow, self.mapper)
+        
+        parent_deletes = unitofwork.DeleteAll(uow, self.parent)
+        child_deletes = unitofwork.DeleteAll(uow, self.mapper)
+
         if self.post_update:
-            # ...
+            uow.dependencies.update([
+                (child_saves, after_save),
+                (parent_saves, after_save),
+                (before_delete, parent_deletes),
+                (before_delete, child_deletes),
+            ])
         else:
-            after_save = unitofwork.ProcessAll(uow, self, False)
-            before_delete = unitofwork.ProcessAll(uow, self, True)
-            
-            parent_saves = unitofwork.SaveUpdateAll(uow, self.parent)
-            child_saves = unitofwork.SaveUpdateAll(uow, self.mapper)
+            unitofwork.GetDependentObjects(uow, self, True, True)
             
-            parent_deletes = unitofwork.DeleteAll(uow, self.parent)
-            child_deletes = unitofwork.DeleteAll(uow, self.mapper)
-            
-            uowtransaction.dependencies.update([
+            uow.dependencies.update([
                 (parent_saves, after_save),
                 (after_save, child_saves),
                 (child_deletes, before_delete),
                 (before_delete, parent_deletes)
             ])
-            
-        
-#    def register_dependencies(self, uowcommit):
-#        if self.post_update:
-#            uowcommit.register_dependency(self.mapper, self.dependency_marker)
-#            uowcommit.register_dependency(self.parent, self.dependency_marker)
-#        else:
-#            uowcommit.register_dependency(self.parent, self.mapper)
-#
-#
-#    def register_processors(self, uowcommit):
-#        if self.post_update:
-#            uowcommit.register_processor(self.dependency_marker, self, self.parent)
-#        else:
-#            uowcommit.register_processor(self.parent, self, self.parent)
-
-
-    def process_dependencies(self, task, deplist, uowcommit, delete = False):
-        if delete:
-            # head object is being deleted, and we manage its list of child objects
-            # the child objects have to have their foreign key to the parent set to NULL
-            # this phase can be called safely for any cascade but is unnecessary if delete cascade
-            # is on.
-            if self.post_update or not self.passive_deletes == 'all':
-                for state in deplist:
-                    history = uowcommit.get_attribute_history(state, self.key, passive=self.passive_deletes)
-                    if history:
-                        for child in history.deleted:
-                            if child is not None and self.hasparent(child) is False:
-                                self._synchronize(state, child, None, True, uowcommit)
-                                self._conditional_post_update(child, uowcommit, [state])
-                        if self.post_update or not self.cascade.delete:
-                            for child in history.unchanged:
-                                if child is not None:
-                                    self._synchronize(state, child, None, True, uowcommit)
-                                    self._conditional_post_update(child, uowcommit, [state])
-        else:
-            for state in deplist:
-                history = uowcommit.get_attribute_history(state, self.key, passive=True)
-                if history:
-                    for child in history.added:
-                        self._synchronize(state, child, None, False, uowcommit)
+    
+    def presort_delete(self, uowcommit, states):
+        # head object is being deleted, and we manage its list of child objects
+        # the child objects have to have their foreign key to the parent set to NULL
+        should_null_fks = not self.cascade.delete and not self.passive_deletes == 'all'
+        for state in states:
+            history = uowcommit.get_attribute_history(
+                                        state, self.key, passive=self.passive_deletes)
+            if history:
+                for child in history.deleted:
+                    if child is not None and self.hasparent(child) is False:
+                        if self.cascade.delete_orphan:
+                            uowcommit.register_object(child, isdelete=True)
+                        else:
+                            uowcommit.register_object(child)
+                if should_null_fks:
+                    for child in history.unchanged:
                         if child is not None:
-                            self._conditional_post_update(child, uowcommit, [state])
-
-                    for child in history.deleted:
-                        if not self.cascade.delete_orphan and not self.hasparent(child):
-                            self._synchronize(state, child, None, True, uowcommit)
-
-                    if self._pks_changed(uowcommit, state):
-                        for child in history.unchanged:
-                            self._synchronize(state, child, None, False, uowcommit)
-                            
-    def preprocess_dependencies(self, task, deplist, uowcommit, delete = False):
-        if delete:
-            # head object is being deleted, and we manage its list of child objects
-            # the child objects have to have their foreign key to the parent set to NULL
-            if not self.post_update:
-                should_null_fks = not self.cascade.delete and not self.passive_deletes == 'all'
-                for state in deplist:
+                            uowcommit.register_object(child)
+    
+    def presort_saves(self, uowcommit, states):
+        for state in states:
+            history = uowcommit.get_attribute_history(state, self.key, passive=True)
+            if history:
+                for child in history.added:
+                    if child is not None:
+                        uowcommit.register_object(child)
+                for child in history.deleted:
+                    if not self.cascade.delete_orphan:
+                        uowcommit.register_object(child, isdelete=False)
+                    elif self.hasparent(child) is False:
+                        uowcommit.register_object(child, isdelete=True)
+                        for c, m in self.mapper.cascade_iterator('delete', child):
+                            uowcommit.register_object(
+                                attributes.instance_state(c),
+                                isdelete=True)
+            if self._pks_changed(uowcommit, state):
+                if not history:
                     history = uowcommit.get_attribute_history(
-                                                state, self.key, passive=self.passive_deletes)
-                    if history:
-                        for child in history.deleted:
-                            if child is not None and self.hasparent(child) is False:
-                                if self.cascade.delete_orphan:
-                                    uowcommit.register_object(child, isdelete=True)
-                                else:
-                                    uowcommit.register_object(child)
-                        if should_null_fks:
-                            for child in history.unchanged:
-                                if child is not None:
-                                    uowcommit.register_object(child)
-        else:
-            for state in deplist:
-                history = uowcommit.get_attribute_history(state, self.key, passive=True)
+                                        state, self.key, passive=self.passive_updates)
                 if history:
-                    for child in history.added:
+                    for child in history.unchanged:
                         if child is not None:
                             uowcommit.register_object(child)
+    
+    def process_deletes(self, uowcommit, states):
+        # head object is being deleted, and we manage its list of child objects
+        # the child objects have to have their foreign key to the parent set to NULL
+        # this phase can be called safely for any cascade but is unnecessary if delete cascade
+        # is on.
+        if self.post_update or not self.passive_deletes == 'all':
+            for state in states:
+                history = uowcommit.get_attribute_history(state, self.key, passive=self.passive_deletes)
+                if history:
                     for child in history.deleted:
-                        if not self.cascade.delete_orphan:
-                            uowcommit.register_object(child, isdelete=False)
-                        elif self.hasparent(child) is False:
-                            uowcommit.register_object(child, isdelete=True)
-                            for c, m in self.mapper.cascade_iterator('delete', child):
-                                uowcommit.register_object(
-                                    attributes.instance_state(c),
-                                    isdelete=True)
-                if self._pks_changed(uowcommit, state):
-                    if not history:
-                        history = uowcommit.get_attribute_history(
-                                            state, self.key, passive=self.passive_updates)
-                    if history:
+                        if child is not None and self.hasparent(child) is False:
+                            self._synchronize(state, child, None, True, uowcommit)
+                            self._conditional_post_update(child, uowcommit, [state])
+                    if self.post_update or not self.cascade.delete:
                         for child in history.unchanged:
                             if child is not None:
-                                uowcommit.register_object(child)
+                                self._synchronize(state, child, None, True, uowcommit)
+                                self._conditional_post_update(child, uowcommit, [state])
+    
+    def process_saves(self, uowcommit, states):
+        for state in states:
+            history = uowcommit.get_attribute_history(state, self.key, passive=True)
+            if history:
+                for child in history.added:
+                    self._synchronize(state, child, None, False, uowcommit)
+                    if child is not None:
+                        self._conditional_post_update(child, uowcommit, [state])
+
+                for child in history.deleted:
+                    if not self.cascade.delete_orphan and not self.hasparent(child):
+                        self._synchronize(state, child, None, True, uowcommit)
 
+                if self._pks_changed(uowcommit, state):
+                    for child in history.unchanged:
+                        self._synchronize(state, child, None, False, uowcommit)
+        
     def _synchronize(self, state, child, associationrow, clearkeys, uowcommit):
         source = state
         dest = child
index ca96764695c2273320ad05d9d9d5735425ebd68a..07f6a09aba16c8c7525c13c80aa225b7dcb42fcc 100644 (file)
@@ -504,7 +504,7 @@ class MapperProperty(object):
         """
         pass
 
-    def get_flush_actions(self, uowtransaction, records, state):
+    def per_property_flush_actions(self, uow):
         pass
         
     def is_primary(self):
index a1787933f269859a4b47c47f20a80ea2eaee8eff..c95bcd4c8770e06696dd026478853f011f381db4 100644 (file)
@@ -1254,14 +1254,13 @@ class Mapper(object):
         
         return sqlutil.sort_tables(l)
         
-    def per_mapper_flush_actions(self, uowtransaction):
-        unitofwork.SaveUpdateAll(uow, self.base_mapper)
-        unitofwork.DeleteAll(uow, self.base_mapper)
+    def per_mapper_flush_actions(self, uow):
+        saves = unitofwork.SaveUpdateAll(uow, self.base_mapper)
+        deletes = unitofwork.DeleteAll(uow, self.base_mapper)
+        uow.dependencies.add((saves, deletes))
         
         for prop in self._props.values():
-            dp = prop._dependency_processor
-            if dp is not None:
-                dp.per_mapper_flush_actions(uowtransaction)
+            prop.per_property_flush_actions(uow)
         
     def _save_obj(self, states, uowtransaction, postupdate=False, 
                                 post_update_cols=None, single=False):
index ff1f23476516ae7d14894608ddce155844839442..61dc9eb55f474874348c6cc7013305d7c91d5cfe 100644 (file)
@@ -1203,6 +1203,10 @@ class RelationshipProperty(StrategizedProperty):
     def _is_self_referential(self):
         return self.mapper.common_parent(self.parent)
 
+    def per_property_flush_actions(self, uow):
+        if not self.viewonly and self._dependency_processor:
+            self._dependency_processor.per_property_flush_actions(uow)
+
     def _create_joins(self, source_polymorphic=False, 
                             source_selectable=None, dest_polymorphic=False, 
                             dest_selectable=None, of_type=None):
index 0810175bf868d2b82d0d450d0484a2f039f65176..391e78fdbdec1ba61b36249addec18af881ab555 100644 (file)
@@ -1416,7 +1416,7 @@ class Session(object):
         for state in proc:
             flush_context.register_object(state, isdelete=True)
 
-        if len(flush_context.tasks) == 0:
+        if not flush_context.has_work:
             return
 
         flush_context.transaction = transaction = self.begin(
index ea0639192b1d4874b35b6704ab767cc1067111b2..5b009baea4138e0e644b8c05f01111418610b9b7 100644 (file)
@@ -88,12 +88,19 @@ class UOWTransaction(object):
         # information.
         self.attributes = {}
         
-        self.mappers = collections.defaultdict(set)
-        self.actions = {}
-        self.saves = set()
-        self.deletes = set()
-        self.etc = set()
+        self.mappers = util.defaultdict(set)
+        self.presort_actions = {}
+        self.postsort_actions = {}
+        self.states = {}
         self.dependencies = set()
+    
+    @property
+    def has_work(self):
+        return bool(self.states)
+
+    def is_deleted(self, state):
+        """return true if the given state is marked as deleted within this UOWTransaction."""
+        return state in self.states and self.states[state][0]
         
     def get_attribute_history(self, state, key, passive=True):
         hashkey = ("history", state, key)
@@ -124,39 +131,44 @@ class UOWTransaction(object):
         # if object is not in the overall session, do nothing
         if not self.session._contains_state(state):
             return
-        
-        if state in self.states:
-            return
+
+        if state not in self.states:
+            mapper = _state_mapper(state)
             
-        mapper = _state_mapper(state)
-        self.mappers[mapper].add(state)
-        self._state_collection(isdelete, listonly).add(state)
+            if mapper not in self.mappers:
+                mapper.per_mapper_flush_actions(self)
+            
+            self.mappers[mapper].add(state)
+        self.states[state] = (isdelete, listonly)
     
-    def register_dependency(self, parent, child):
-        self.dependencies.add((parent, child))
-        
-    def _state_collection(self, isdelete, listonly):
-        if isdelete:
-            return self.deletes
-        elif not listonly:
-            return self.saves
-        else:
-            return self.etc
-        
     def states_for_mapper(self, mapper, isdelete, listonly):
-        return iter(self._state_collection(isdelete, listonly)[mapper])
+        checktup = (isdelete, listonly)
+        for state, tup in self.states.iteritems():
+            if tup == checktup:
+                yield state
 
     def states_for_mapper_hierarchy(self, mapper, isdelete, listonly):
-        collection = self._state_collection(isdelete, listonly)
+        checktup = (isdelete, listonly)
         for mapper in mapper.base_mapper.polymorphic_iterator():
-            for state in collection[mapper]:
-                yield state
+            for state, tup in self.states.iteritems():
+                if tup == checktup:
+                    yield state
                 
     def execute(self):
         
-        for mapper in self.mappers:
-            mapper.per_mapper_flush_actions(self)
-
+        while True:
+            ret = False
+            for action in self.presort_actions.values():
+                if action.execute(self):
+                    ret = True
+            if not ret:
+                break
+        
+        sort = topological.sort(self.dependencies, self.postsort_actions.values())
+        print sort
+        for rec in sort:
+            rec.execute(self)
+            
 #        if cycles:
 #            break up actions into finer grained actions along those cycles
             
@@ -168,49 +180,131 @@ class UOWTransaction(object):
 
         this method is called within the flush() method after the
         execute() method has succeeded and the transaction has been committed.
+        
         """
-
-        for elem in self.elements:
-            if elem.isdelete:
-                self.session._remove_newly_deleted(elem.state)
-            elif not elem.listonly:
-                self.session._register_newly_persistent(elem.state)
+        for state, (isdelete, listonly) in self.states.iteritems():
+            if isdelete:
+                self.session._remove_newly_deleted(state)
+            elif not listonly:
+                self.session._register_newly_persistent(state)
 
 log.class_logger(UOWTransaction)
 
-class Rec(object):
-    def __new__(self, uow, *args):
-        key = (self.__class__, ) + args
-        if key in uow.actions:
-            return uow.actions[key]
+class PreSortRec(object):
+    def __new__(cls, uow, *args):
+        key = (cls, ) + args
+        if key in uow.presort_actions:
+            return uow.presort_actions[key]
         else:
-            uow.actions[key] = ret = object.__new__(self)
+            uow.presort_actions[key] = ret = object.__new__(cls)
             return ret
 
-class SaveUpdateAll(Rec):
+class PostSortRec(object):
+    def __new__(cls, uow, *args):
+        key = (cls, ) + args
+        if key in uow.postsort_actions:
+            return uow.postsort_actions[key]
+        else:
+            uow.postsort_actions[key] = ret = object.__new__(cls)
+            return ret
+    
+    def __repr__(self):
+        return "%s(%s)" % (
+            self.__class__.__name__,
+            ",".join(str(x) for x in self.__dict__.values())
+        )
+
+class PropertyRecMixin(object):
+    def __init__(self, uow, dependency_processor, delete, fromparent):
+        self.dependency_processor = dependency_processor
+        self.delete = delete
+        self.fromparent = fromparent
+        
+        self.processed = set()
+        
+        prop = dependency_processor.prop
+        if fromparent:
+            self._mappers = set(
+                m for m in dependency_processor.parent.polymorphic_iterator()
+                if m._props[prop.key] is prop
+            )
+        else:
+            self._mappers = set(
+                dependency_processor.mapper.polymorphic_iterator()
+            )
+
+    def __repr__(self):
+        return "%s(%s, delete=%s)" % (
+            self.__class__.__name__,
+            self.dependency_processor,
+            self.delete
+        )
+
+    def _elements(self, uow):
+        for mapper in self._mappers:
+            for state in uow.mappers[mapper]:
+                if state in self.processed:
+                    continue
+                (isdelete, listonly) = uow.states[state]
+                if isdelete == self.delete:
+                    yield state
+    
+class GetDependentObjects(PropertyRecMixin, PreSortRec):
+    def __init__(self, *args):
+        self.processed = set()
+        super(GetDependentObjects, self).__init__(*args)
+
+    def execute(self, uow):
+        states = list(self._elements(uow))
+        if states:
+            self.processed.update(states)
+            if self.delete:
+                self.dependency_processor.presort_deletes(uow, states)
+            else:
+                self.dependency_processor.presort_saves(uow, states)
+            return True
+        else:
+            return False
+
+class ProcessAll(PropertyRecMixin, PostSortRec):
+    def execute(self, uow):
+        states = list(self._elements(uow))
+        if self.delete:
+            self.dependency_processor.process_deletes(uow, states)
+        else:
+            self.dependency_processor.process_saves(uow, states)
+
+class SaveUpdateAll(PostSortRec):
     def __init__(self, uow, mapper):
         self.mapper = mapper
 
-class DeleteAll(Rec):
-    def __init__(self, mapper):
+    def execute(self, uow):
+        self.mapper._save_obj(
+            uow.states_for_mapper_hierarchy(self.mapper, False, False),
+            uow
+        )
+        
+class DeleteAll(PostSortRec):
+    def __init__(self, uow, mapper):
         self.mapper = mapper
 
-class ProcessAll(Rec):
-    def __init__(self, uow, dependency_processor, delete):
-        self.dependency_processor = dependency_processor
-        self.delete = delete
+    def execute(self, uow):
+        self.mapper._delete_obj(
+            uow.states_for_mapper_hierarchy(self.mapper, True, False),
+            uow
+        )
 
-class ProcessState(Rec):
+class ProcessState(PostSortRec):
     def __init__(self, uow, dependency_processor, delete, state):
         self.dependency_processor = dependency_processor
         self.delete = delete
         self.state = state
         
-class SaveUpdateState(Rec):
+class SaveUpdateState(PostSortRec):
     def __init__(self, uow, state):
         self.state = state
 
-class DeleteState(Rec):
+class DeleteState(PostSortRec):
     def __init__(self, uow, state):
         self.state = state
 
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..545dd7a602202fffd8f1a9c726b4adb95fe9ccf5 100644 (file)
@@ -0,0 +1,45 @@
+from sqlalchemy.test.testing import eq_, assert_raises, assert_raises_message
+from sqlalchemy.test import testing
+from test.orm import _fixtures
+from sqlalchemy.orm import mapper, relationship, backref, create_session
+from sqlalchemy.test.assertsql import AllOf, CompiledSQL
+
+from test.orm._fixtures import keywords, addresses, Base, Keyword,  \
+           Dingaling, item_keywords, dingalings, User, items,\
+           orders, Address, users, nodes, \
+            order_items, Item, Order, Node, \
+            composite_pk_table, CompositePk
+
+class UOWTest(_fixtures.FixtureTest, testing.AssertsExecutionResults):
+    run_inserts = None
+
+class RudimentaryFlushTest(UOWTest):
+
+    def test_one_to_many(self):
+        mapper(User, users, properties={
+            'addresses':relationship(Address),
+        })
+        mapper(Address, addresses)
+        sess = create_session()
+
+        a1, a2 = Address(email_address='a1'), Address(email_address='a2')
+        u1 = User(name='u1', addresses=[a1, a2])
+        sess.add(u1)
+    
+        self.assert_sql_execution(
+                testing.db,
+                sess.flush,
+                CompiledSQL(
+                    "INSERT INTO users (name) VALUES (:name)",
+                    {'name': 'u1'} 
+                ),
+                CompiledSQL(
+                    "INSERT INTO addresses (user_id, email_address) VALUES (:user_id, :email_address)",
+                    lambda ctx: {'email_address': 'a1', 'user_id':u1.id} 
+                ),
+                CompiledSQL(
+                    "INSERT INTO addresses (user_id, email_address) VALUES (:user_id, :email_address)",
+                    lambda ctx: {'email_address': 'a2', 'user_id':u1.id} 
+                ),
+            )
+    
\ No newline at end of file