git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- post_update behavior improved; does a better job at not
author Mike Bayer <mike_mp@zzzcomputing.com>
Tue, 19 Sep 2006 17:49:51 +0000 (17:49 +0000)
committer Mike Bayer <mike_mp@zzzcomputing.com>
Tue, 19 Sep 2006 17:49:51 +0000 (17:49 +0000)
updating too many rows, updates only required columns
[ticket:208]

CHANGES
lib/sqlalchemy/orm/dependency.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/sync.py
lib/sqlalchemy/orm/unitofwork.py
test/orm/cycles.py

diff --git a/CHANGES b/CHANGES
index bf64c367d993b1b53a522c588c81f49c6717c44c..89c0877b56172fa02880df2874b51631208e42f1 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -19,6 +19,9 @@ after commit/rollback
 - added an implicit close() on the cursor in ResultProxy
 when the result closes
 - added scalar() method to ComposedSQLEngine
+- post_update behavior improved; does a better job at not 
+updating too many rows, updates only required columns
+[ticket:208]
 
 0.2.8
 - cleanup on connection methods + documentation.  custom DBAPI
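diff --git a/lib/sqlalchemy/orm/dependency.py b/lib/sqlalchemy/orm/dependency.py
index 1af368431bc7d5097078a445c4313251c85bf000..b734b34becc510d13f5e3a1b1599e0e24ae9d148 100644 (file)
--- a/lib/sqlalchemy/orm/dependency.py
+++ b/lib/sqlalchemy/orm/dependency.py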
@@ -58,7 +58,7 @@ class DependencyProcessor(object):
             return (obj1, obj2)
         else:
             return (obj2, obj1)
-
+            
     def process_dependencies(self, task, deplist, uowcommit, delete = False):
         """this method is called during a flush operation to synchronize data between a parent and child object.
         it is called within the context of the various mappers and sometimes individual objects sorted according to their
@@ -81,6 +81,12 @@ class DependencyProcessor(object):
         this dependency processor represents"""
         return sessionlib.attribute_manager.get_history(obj, self.key, passive = passive)
 
+    def _conditional_post_update(self, obj, uowcommit, related):
+        if obj is not None and self.post_update:
+            for x in related:
+                if x is not None and (uowcommit.is_deleted(x) or not hasattr(x, '_instance_key')):
+                    uowcommit.register_object(obj, postupdate=True, post_update_cols=self.syncrules.dest_columns())
+                    break
 
 class OneToManyDP(DependencyProcessor):
     def register_dependencies(self, uowcommit):
@@ -103,21 +109,18 @@ class OneToManyDP(DependencyProcessor):
                     for child in childlist.deleted_items():
                         if child is not None and childlist.hasparent(child) is False:
                             self._synchronize(obj, child, None, True)
-                            if self.post_update:
-                                uowcommit.register_object(child, postupdate=True)
+                            self._conditional_post_update(child, uowcommit, [obj])
                     for child in childlist.unchanged_items():
                         if child is not None:
                             self._synchronize(obj, child, None, True)
-                            if self.post_update:
-                                uowcommit.register_object(child, postupdate=True)
+                            self._conditional_post_update(child, uowcommit, [obj])
         else:
             for obj in deplist:
                 childlist = self.get_object_dependencies(obj, uowcommit, passive=True)
                 if childlist is not None:
                     for child in childlist.added_items():
                         self._synchronize(obj, child, None, False)
-                        if child is not None and self.post_update:
-                            uowcommit.register_object(child, postupdate=True)
+                        self._conditional_post_update(child, uowcommit, [obj])
                     for child in childlist.deleted_items():
                         if not self.cascade.delete_orphan:
                             self._synchronize(obj, child, None, True)
@@ -194,16 +197,16 @@ class ManyToOneDP(DependencyProcessor):
                 # before we can DELETE the row
                 for obj in deplist:
                     self._synchronize(obj, None, None, True)
-                    uowcommit.register_object(obj, postupdate=True)
+                    childlist = self.get_object_dependencies(obj, uowcommit, passive=True)
+                    self._conditional_post_update(obj, uowcommit, childlist.deleted_items() + childlist.unchanged_items() + childlist.added_items())
         else:
             for obj in deplist:
                 childlist = self.get_object_dependencies(obj, uowcommit, passive=True)
                 if childlist is not None:
                     for child in childlist.added_items():
                         self._synchronize(obj, child, None, False)
-                if self.post_update:
-                    uowcommit.register_object(obj, postupdate=True)
-            
+                    self._conditional_post_update(obj, uowcommit, childlist.deleted_items() + childlist.unchanged_items() + childlist.added_items())
+                        
     def preprocess_dependencies(self, task, deplist, uowcommit, delete = False):
         #print self.mapper.mapped_table.name + " " + self.key + " " + repr(len(deplist)) + " PRE process_dep isdelete " + repr(delete) + " direction " + repr(self.direction)
         # TODO: post_update instructions should be established in this step as well
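diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index 837f17a3327f0426a18ffb047363183d10685fe8..194b2b9e6fd589d35e159f30b77c77d40dc22a21 100644 (file)
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py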
@@ -703,12 +703,13 @@ class Mapper(object):
 
     def _setattrbycolumn(self, obj, column, value):
         self.columntoproperty[column][0].setattr(obj, value)
-            
-    def save_obj(self, objects, uow, postupdate=False):
+    
+    def save_obj(self, objects, uow, postupdate=False, post_update_cols=None):
         """called by a UnitOfWork object to save objects, which involves either an INSERT or
         an UPDATE statement for each table used by this mapper, for each element of the
         list."""
         #print "SAVE_OBJ MAPPER", self.class_.__name__, objects
+        
         connection = uow.transaction.connection(self)
 
         if not postupdate:
@@ -785,6 +786,8 @@ class Mapper(object):
                             # doing an UPDATE ? get the history for the attribute, with "passive"
                             # so as not to trigger any deferred loads.  if there is a new
                             # value, add it to the bind parameters
+                            if post_update_cols is not None and col not in post_update_cols:
+                                continue
                             prop = self._getpropbycolumn(col, False)
                             if prop is None:
                                 continue
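diff --git a/lib/sqlalchemy/orm/sync.py b/lib/sqlalchemy/orm/sync.py
index 8004776c8b69f8af9e4114ae74df608003f45b37..c43bd69e0aaa5e7eb6169a3f255ccb85c13815b1 100644 (file)
--- a/lib/sqlalchemy/orm/sync.py
+++ b/lib/sqlalchemy/orm/sync.py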
@@ -81,7 +81,10 @@ class ClauseSynchronizer(object):
         sqlclause.accept_visitor(processor)
         if len(self.syncrules) == rules_added:
             raise ArgumentError("No syncrules generated for join criterion " + str(sqlclause))
-        
+    
+    def dest_columns(self):
+        return [r.dest_column for r in self.syncrules if r.dest_column is not None]
+
     def execute(self, source, dest, obj=None, child=None, clearkeys=None):
         for rule in self.syncrules:
             rule.execute(source, dest, obj, child, clearkeys)
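diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py
index 2dac05b1a6c1973845e41cc83b24d9c43c1d5770..0a5669227180fae96e3ae8a2dbd8239d54f1acf2 100644 (file)
--- a/lib/sqlalchemy/orm/unitofwork.py
+++ b/lib/sqlalchemy/orm/unitofwork.py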
@@ -239,7 +239,7 @@ class UOWTransaction(object):
         self.__is_executing = False
         
     # TODO: shouldnt be able to register stuff here that is not in the enclosing Session
-    def register_object(self, obj, isdelete = False, listonly = False, postupdate=False, **kwargs):
+    def register_object(self, obj, isdelete = False, listonly = False, postupdate=False, post_update_cols=None, **kwargs):
         """adds an object to this UOWTransaction to be updated in the database.
 
         'isdelete' indicates whether the object is to be deleted or saved (update/inserted).
@@ -258,7 +258,7 @@ class UOWTransaction(object):
         self.mappers.add(mapper)
         task = self.get_task_by_mapper(mapper)
         if postupdate:
-            mod = task.append_postupdate(obj)
+            mod = task.append_postupdate(obj, post_update_cols)
             if mod: self._mark_modified()
             return
                 
@@ -283,7 +283,13 @@ class UOWTransaction(object):
         #if self.__is_executing:
         #    raise "test assertion failed"
         self.__modified = True
+    
         
+    def is_deleted(self, obj):
+        mapper = object_mapper(obj)
+        task = self.get_task_by_mapper(mapper)
+        return task.is_deleted(obj)
+            
     def get_task_by_mapper(self, mapper, dontcreate=False):
         """every individual mapper involved in the transaction has a single
         corresponding UOWTask object, which stores all the operations involved
@@ -650,9 +656,11 @@ class UOWTask(object):
             rec.isdelete = True
         return retval
     
-    def append_postupdate(self, obj):
+    def append_postupdate(self, obj, post_update_cols):
         # postupdates are UPDATED immeditely (for now)
-        self.mapper.save_obj([obj], self.uowtransaction, postupdate=True)
+        # convert post_update_cols list to a Set so that __hash__ is used to compare columns
+        # instead of __eq__
+        self.mapper.save_obj([obj], self.uowtransaction, postupdate=True, post_update_cols=util.Set(post_update_cols))
         return True
             
     def delete(self, obj):
@@ -711,7 +719,16 @@ class UOWTask(object):
             if obj in self.objects:
                 return True
         return False
-        
+
+    def is_inserted(self, obj):
+        return not hasattr(obj, '_instance_key')
+    
+    def is_deleted(self, obj):
+        try:
+            return self.objects[obj].isdelete
+        except KeyError:
+            return False
+          
     def get_elements(self, polymorphic=False):
         if polymorphic:
             for task in self.polymorphic_tasks():
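diff --git a/test/orm/cycles.py b/test/orm/cycles.py
index dbc234d858f6b2b64f31d27a2206002d63deddce..dd2836469707624903cf2aa0123c8a2607c4a42d 100644 (file)
--- a/test/orm/cycles.py
+++ b/test/orm/cycles.py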
@@ -214,10 +214,12 @@ class OneToManyManyToOneTest(AssertMixin):
         ball = Table('ball', metadata,
          Column('id', Integer, Sequence('ball_id_seq', optional=True), primary_key=True),
          Column('person_id', Integer),
+         Column('data', String(30))
          )
         person = Table('person', metadata,
          Column('id', Integer, Sequence('person_id_seq', optional=True), primary_key=True),
          Column('favorite_ball_id', Integer, ForeignKey('ball.id')),
+         Column('data', String(30))
 #         Column('favorite_ball_id', Integer),
          )
 
@@ -267,10 +269,12 @@ class OneToManyManyToOneTest(AssertMixin):
     def testpostupdate_m2o(self):
         """tests a cycle between two rows, with a post_update on the many-to-one"""
         class Person(object):
-         pass
+            def __init__(self, data):
+                self.data = data
 
         class Ball(object):
-         pass
+            def __init__(self, data):
+                self.data = data
 
         Ball.mapper = mapper(Ball, ball)
         Person.mapper = mapper(Person, person, properties= dict(
@@ -281,12 +285,12 @@ class OneToManyManyToOneTest(AssertMixin):
 
         print str(Person.mapper.props['balls'].primaryjoin)
 
-        b = Ball()
-        p = Person()
+        b = Ball('some data')
+        p = Person('some data')
         p.balls.append(b)
-        p.balls.append(Ball())
-        p.balls.append(Ball())
-        p.balls.append(Ball())
+        p.balls.append(Ball('some data'))
+        p.balls.append(Ball('some data'))
+        p.balls.append(Ball('some data'))
         p.favorateBall = b
         sess = create_session()
         sess.save(b)
@@ -294,24 +298,24 @@ class OneToManyManyToOneTest(AssertMixin):
         
         self.assert_sql(db, lambda: sess.flush(), [
             (
-                "INSERT INTO person (favorite_ball_id) VALUES (:favorite_ball_id)",
-                {'favorite_ball_id': None}
+                "INSERT INTO person (favorite_ball_id, data) VALUES (:favorite_ball_id, :data)",
+                {'favorite_ball_id': None, 'data':'some data'}
             ),
             (
-                "INSERT INTO ball (person_id) VALUES (:person_id)",
-                lambda ctx:{'person_id':p.id}
+                "INSERT INTO ball (person_id, data) VALUES (:person_id, :data)",
+                lambda ctx:{'person_id':p.id, 'data':'some data'}
             ),
             (
-                "INSERT INTO ball (person_id) VALUES (:person_id)",
-                lambda ctx:{'person_id':p.id}
+                "INSERT INTO ball (person_id, data) VALUES (:person_id, :data)",
+                lambda ctx:{'person_id':p.id, 'data':'some data'}
             ),
             (
-                "INSERT INTO ball (person_id) VALUES (:person_id)",
-                lambda ctx:{'person_id':p.id}
+                "INSERT INTO ball (person_id, data) VALUES (:person_id, :data)",
+                lambda ctx:{'person_id':p.id, 'data':'some data'}
             ),
             (
-                "INSERT INTO ball (person_id) VALUES (:person_id)",
-                lambda ctx:{'person_id':p.id}
+                "INSERT INTO ball (person_id, data) VALUES (:person_id, :data)",
+                lambda ctx:{'person_id':p.id, 'data':'some data'}
             ),
             (
                 "UPDATE person SET favorite_ball_id=:favorite_ball_id WHERE person.id = :person_id",
@@ -320,24 +324,24 @@ class OneToManyManyToOneTest(AssertMixin):
         ], 
         with_sequences= [
                 (
-                    "INSERT INTO person (id, favorite_ball_id) VALUES (:id, :favorite_ball_id)",
-                    lambda ctx:{'id':ctx.last_inserted_ids()[0], 'favorite_ball_id': None}
+                    "INSERT INTO person (id, favorite_ball_id, data) VALUES (:id, :favorite_ball_id, :data)",
+                    lambda ctx:{'id':ctx.last_inserted_ids()[0], 'favorite_ball_id': None, 'data':'some data'}
                 ),
                 (
-                    "INSERT INTO ball (id, person_id) VALUES (:id, :person_id)",
-                    lambda ctx:{'id':ctx.last_inserted_ids()[0],'person_id':p.id}
+                    "INSERT INTO ball (id, person_id, data) VALUES (:id, :person_id, :data)",
+                    lambda ctx:{'id':ctx.last_inserted_ids()[0],'person_id':p.id, 'data':'some data'}
                 ),
                 (
-                    "INSERT INTO ball (id, person_id) VALUES (:id, :person_id)",
-                    lambda ctx:{'id':ctx.last_inserted_ids()[0],'person_id':p.id}
+                    "INSERT INTO ball (id, person_id, data) VALUES (:id, :person_id, :data)",
+                    lambda ctx:{'id':ctx.last_inserted_ids()[0],'person_id':p.id, 'data':'some data'}
                 ),
                 (
-                    "INSERT INTO ball (id, person_id) VALUES (:id, :person_id)",
-                    lambda ctx:{'id':ctx.last_inserted_ids()[0],'person_id':p.id}
+                    "INSERT INTO ball (id, person_id, data) VALUES (:id, :person_id, :data)",
+                    lambda ctx:{'id':ctx.last_inserted_ids()[0],'person_id':p.id, 'data':'some data'}
                 ),
                 (
-                    "INSERT INTO ball (id, person_id) VALUES (:id, :person_id)",
-                    lambda ctx:{'id':ctx.last_inserted_ids()[0],'person_id':p.id}
+                    "INSERT INTO ball (id, person_id, data) VALUES (:id, :person_id, :data)",
+                    lambda ctx:{'id':ctx.last_inserted_ids()[0],'person_id':p.id, 'data':'some data'}
                 ),
                 # heres the post update 
                 (
@@ -369,10 +373,12 @@ class OneToManyManyToOneTest(AssertMixin):
     def testpostupdate_o2m(self):
         """tests a cycle between two rows, with a post_update on the one-to-many"""
         class Person(object):
-         pass
+            def __init__(self, data):
+                self.data = data
 
         class Ball(object):
-         pass
+            def __init__(self, data):
+                self.data = data
 
         Ball.mapper = mapper(Ball, ball)
         Person.mapper = mapper(Person, person, properties= dict(
@@ -383,14 +389,14 @@ class OneToManyManyToOneTest(AssertMixin):
 
         print str(Person.mapper.props['balls'].primaryjoin)
 
-        b = Ball()
-        p = Person()
+        b = Ball('some data')
+        p = Person('some data')
         p.balls.append(b)
-        b2 = Ball()
+        b2 = Ball('some data')
         p.balls.append(b2)
-        b3 = Ball()
+        b3 = Ball('some data')
         p.balls.append(b3)
-        b4 = Ball()
+        b4 = Ball('some data')
         p.balls.append(b4)
         p.favorateBall = b
         sess = create_session()
@@ -398,24 +404,24 @@ class OneToManyManyToOneTest(AssertMixin):
 
         self.assert_sql(db, lambda: sess.flush(), [
                 (
-                    "INSERT INTO ball (person_id) VALUES (:person_id)",
-                    {'person_id':None}
+                    "INSERT INTO ball (person_id, data) VALUES (:person_id, :data)",
+                    {'person_id':None, 'data':'some data'}
                 ),
                 (
-                    "INSERT INTO ball (person_id) VALUES (:person_id)",
-                    {'person_id':None}
+                    "INSERT INTO ball (person_id, data) VALUES (:person_id, :data)",
+                    {'person_id':None, 'data':'some data'}
                 ),
                 (
-                    "INSERT INTO ball (person_id) VALUES (:person_id)",
-                    {'person_id':None}
+                    "INSERT INTO ball (person_id, data) VALUES (:person_id, :data)",
+                    {'person_id':None, 'data':'some data'}
                 ),
                 (
-                    "INSERT INTO ball (person_id) VALUES (:person_id)",
-                    {'person_id':None}
+                    "INSERT INTO ball (person_id, data) VALUES (:person_id, :data)",
+                    {'person_id':None, 'data':'some data'}
                 ),
                 (
-                    "INSERT INTO person (favorite_ball_id) VALUES (:favorite_ball_id)",
-                    lambda ctx:{'favorite_ball_id':b.id}
+                    "INSERT INTO person (favorite_ball_id, data) VALUES (:favorite_ball_id, :data)",
+                    lambda ctx:{'favorite_ball_id':b.id, 'data':'some data'}
                 ),
                 # heres the post update on each one-to-many item
                 (
@@ -437,24 +443,24 @@ class OneToManyManyToOneTest(AssertMixin):
         ],
         with_sequences=[
             (
-                "INSERT INTO ball (id, person_id) VALUES (:id, :person_id)",
-                lambda ctx:{'id':ctx.last_inserted_ids()[0], 'person_id':None}
+                "INSERT INTO ball (id, person_id, data) VALUES (:id, :person_id, :data)",
+                lambda ctx:{'id':ctx.last_inserted_ids()[0], 'person_id':None, 'data':'some data'}
             ),
             (
-                "INSERT INTO ball (id, person_id) VALUES (:id, :person_id)",
-                lambda ctx:{'id':ctx.last_inserted_ids()[0], 'person_id':None}
+                "INSERT INTO ball (id, person_id, data) VALUES (:id, :person_id, :data)",
+                lambda ctx:{'id':ctx.last_inserted_ids()[0], 'person_id':None, 'data':'some data'}
             ),
             (
-                "INSERT INTO ball (id, person_id) VALUES (:id, :person_id)",
-                lambda ctx:{'id':ctx.last_inserted_ids()[0], 'person_id':None}
+                "INSERT INTO ball (id, person_id, data) VALUES (:id, :person_id, :data)",
+                lambda ctx:{'id':ctx.last_inserted_ids()[0], 'person_id':None, 'data':'some data'}
             ),
             (
-                "INSERT INTO ball (id, person_id) VALUES (:id, :person_id)",
-                lambda ctx:{'id':ctx.last_inserted_ids()[0], 'person_id':None}
+                "INSERT INTO ball (id, person_id, data) VALUES (:id, :person_id, :data)",
+                lambda ctx:{'id':ctx.last_inserted_ids()[0], 'person_id':None, 'data':'some data'}
             ),
             (
-                "INSERT INTO person (id, favorite_ball_id) VALUES (:id, :favorite_ball_id)",
-                lambda ctx:{'id':ctx.last_inserted_ids()[0], 'favorite_ball_id':b.id}
+                "INSERT INTO person (id, favorite_ball_id, data) VALUES (:id, :favorite_ball_id, :data)",
+                lambda ctx:{'id':ctx.last_inserted_ids()[0], 'favorite_ball_id':b.id, 'data':'some data'}
             ),
             (
                 "UPDATE ball SET person_id=:person_id WHERE ball.id = :ball_id",
@@ -498,11 +504,109 @@ class OneToManyManyToOneTest(AssertMixin):
             ),
             (
                 "DELETE FROM ball WHERE ball.id = :id",
-                None
-                # the order of deletion is not predictable, but its roughly:
-                # lambda ctx:[{'id': b.id}, {'id': b2.id}, {'id': b3.id}, {'id': b4.id}]
+                lambda ctx:[{'id': b.id}, {'id': b2.id}, {'id': b3.id}, {'id': b4.id}]
             )
         ])
+
+class SelfReferentialPostUpdateTest(AssertMixin):
+    def setUpAll(self):
+        global metadata, node_table
+        metadata = BoundMetaData(testbase.db)
+        node_table = Table('node', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('path', String(50), nullable=False),
+            Column('parent_id', Integer, ForeignKey('node.id'), nullable=True),
+            Column('prev_sibling_id', Integer, ForeignKey('node.id'), nullable=True),
+            Column('next_sibling_id', Integer, ForeignKey('node.id'), nullable=True)
+        )
+        node_table.create()
+    def tearDownAll(self):
+        node_table.drop()
+    
+    def testbasic(self):
+        """test that post_update only fires off when needed.
+        
+        this test case used to produce many superfluous update statements, particularly upon delete"""
+        class Node(object):
+            def __init__(self, path=''):
+                self.path = path
+
+        n_mapper = mapper(Node, node_table, properties={
+            'children': relation(
+                Node,
+                primaryjoin=node_table.c.id==node_table.c.parent_id,
+                lazy=True,
+                cascade="all",
+                backref=backref("parent", primaryjoin=node_table.c.parent_id==node_table.c.id, foreignkey=node_table.c.id)
+            ),
+            'prev_sibling': relation(
+                Node,
+                primaryjoin=node_table.c.prev_sibling_id==node_table.c.id,
+                foreignkey=node_table.c.id,
+                lazy=True,
+                uselist=False
+            ),
+            'next_sibling': relation(
+                Node,
+                primaryjoin=node_table.c.next_sibling_id==node_table.c.id,
+                foreignkey=node_table.c.id,
+                lazy=True,
+                uselist=False,
+                post_update=True
+            )
+        })
+
+        session = create_session()
+
+        def append_child(parent, child):
+            if len(parent.children):
+                parent.children[-1].next_sibling = child
+                child.prev_sibling = parent.children[-1]
+            parent.children.append(child)
+        
+        def remove_child(parent, child):
+            child.parent = None
+            node = child.next_sibling
+            node.prev_sibling = child.prev_sibling
+            child.prev_sibling.next_sibling = node
+            session.delete(child)
+        root = Node('root')
+
+        about = Node('about')
+        cats = Node('cats')
+        stories = Node('stories')
+        bruce = Node('bruce')
+
+        append_child(root, about)
+        assert(about.prev_sibling is None)
+        append_child(root, cats)
+        assert(cats.prev_sibling is about)
+        assert(cats.next_sibling is None)
+        assert(about.next_sibling is cats)
+        assert(about.prev_sibling is None)
+        append_child(root, stories)
+        append_child(root, bruce)
+        session.save(root)
+        session.flush()
+
+        remove_child(root, cats)
+        # pre-trigger lazy loader on 'cats' to make the test easier
+        cats.children
+
+        self.assert_sql(db, lambda: session.flush(), [
+            (
+                "UPDATE node SET prev_sibling_id=:prev_sibling_id WHERE node.id = :node_id",
+                lambda ctx:{'prev_sibling_id':about.id, 'node_id':stories.id}
+            ),
+            (
+                "UPDATE node SET next_sibling_id=:next_sibling_id WHERE node.id = :node_id",
+                lambda ctx:{'next_sibling_id':stories.id, 'node_id':about.id}
+            ),
+            (
+                "DELETE FROM node WHERE node.id = :id",
+                lambda ctx:[{'id':cats.id}]
+            ),
+        ])
         
 if __name__ == "__main__":
     testbase.main()