]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- added extra check to "stop" cascading on save/update/save-update if
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 29 Nov 2006 17:16:41 +0000 (17:16 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 29 Nov 2006 17:16:41 +0000 (17:16 +0000)
an instance is detected to be already in the session.

CHANGES
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/orm/util.py
test/perf/cascade_speed.py [new file with mode: 0644]

diff --git a/CHANGES b/CHANGES
index 516309c3d568cd6b5ca21137af05b56fd49b8bc3..0befda974c128c11c42b6cde85dbf0e5b1eccc3d 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -3,6 +3,8 @@
 [ticket:247]
 - added label() function to Select class, when scalar=True is used
 to create a scalar subquery.
+- added extra check to "stop" cascading on save/update/save-update if
+an instance is detected to be already in the session.
 
 0.3.1
 - Engine/Pool:
index d6c4204f95bc91bce19ce8903c07a37db479b729..0f607ebeecb8a5f348878f0c0b90906369be264a 100644 (file)
@@ -18,9 +18,9 @@ class MapperProperty(object):
         """called when the mapper receives a row.  instance is the parent instance
         corresponding to the row. """
         raise NotImplementedError()
-    def cascade_iterator(self, type, object, recursive=None):
+    def cascade_iterator(self, type, object, recursive=None, halt_on=None):
         return []
-    def cascade_callable(self, type, object, callable_, recursive=None):
+    def cascade_callable(self, type, object, callable_, recursive=None, halt_on=None):
         return []
     def get_criterion(self, query, key, value):
         """Returns a WHERE clause suitable for this MapperProperty corresponding to the 
index 5c1a3b5647931bac4b2b27ba12e27e4027fac82a..5fc986d6befafce5197252dbc94cf1476a06f5fa 100644 (file)
@@ -1082,7 +1082,7 @@ class Mapper(object):
         for prop in self.__props.values():
             prop.register_dependencies(uowcommit, *args, **kwargs)
     
-    def cascade_iterator(self, type, object, recursive=None):
+    def cascade_iterator(self, type, object, recursive=None, halt_on=None):
         """iterate each element in an object graph, for all relations taht meet the given cascade rule.
         
         type - the name of the cascade rule (i.e. save-update, delete, etc.)
@@ -1094,10 +1094,10 @@ class Mapper(object):
         if recursive is None:
             recursive=util.Set()
         for prop in self.__props.values():
-            for c in prop.cascade_iterator(type, object, recursive):
+            for c in prop.cascade_iterator(type, object, recursive, halt_on=halt_on):
                 yield c
 
-    def cascade_callable(self, type, object, callable_, recursive=None):
+    def cascade_callable(self, type, object, callable_, recursive=None, halt_on=None):
         """execute a callable for each element in an object graph, for all relations that meet the given cascade rule.
         
         type - the name of the cascade rule (i.e. save-update, delete, etc.)
@@ -1111,7 +1111,7 @@ class Mapper(object):
         if recursive is None:
             recursive=util.Set()
         for prop in self.__props.values():
-            prop.cascade_callable(type, object, callable_, recursive)
+            prop.cascade_callable(type, object, callable_, recursive, halt_on=halt_on)
             
 
     def get_select_mapper(self):
index 6d11436118fe4ba2362a02c783f95d7473026682..8379a1d1f0b2cb7642dcaea29aaebaa8c56fe28c 100644 (file)
@@ -118,7 +118,7 @@ class PropertyLoader(StrategizedProperty):
     def __str__(self):
         return self.__class__.__name__ + " " + str(self.parent) + "->" + self.key + "->" + str(self.mapper)
         
-    def cascade_iterator(self, type, object, recursive):
+    def cascade_iterator(self, type, object, recursive, halt_on=None):
         if not type in self.cascade:
             return
         passive = type != 'delete' or self.passive_deletes
@@ -127,7 +127,7 @@ class PropertyLoader(StrategizedProperty):
             return
         mapper = self.mapper.primary_mapper()
         for c in childlist.added_items() + childlist.deleted_items() + childlist.unchanged_items():
-            if c is not None and c not in recursive:
+            if c is not None and c not in recursive and (halt_on is None or not halt_on(c)):
                 if not isinstance(c, self.mapper.class_):
                     raise exceptions.AssertionError("Attribute '%s' on class '%s' doesn't handle objects of type '%s'" % (self.key, str(self.parent.class_), str(c.__class__)))
                 recursive.add(c)
@@ -135,14 +135,14 @@ class PropertyLoader(StrategizedProperty):
                 for c2 in mapper.cascade_iterator(type, c, recursive):
                     yield c2
 
-    def cascade_callable(self, type, object, callable_, recursive):
+    def cascade_callable(self, type, object, callable_, recursive, halt_on=None):
         if not type in self.cascade:
             return
         
         mapper = self.mapper.primary_mapper()
         passive = type != 'delete' or self.passive_deletes
         for c in sessionlib.attribute_manager.get_as_list(object, self.key, passive=passive):
-            if c is not None and c not in recursive:
+            if c is not None and c not in recursive and (halt_on is None or not halt_on(c)):
                 if not isinstance(c, self.mapper.class_):
                     raise exceptions.AssertionError("Attribute '%s' on class '%s' doesn't handle objects of type '%s'" % (self.key, str(self.parent.class_), str(c.__class__)))
                 recursive.add(c)
index 921449a6af845c4c3174d3ec48c3de2c1137fbbf..0a196d3d9d5cb434f53faaad934ebeaf5bf93f21 100644 (file)
@@ -287,7 +287,7 @@ class Session(object):
         instance.
         """
         self._save_impl(object, entity_name=entity_name)
-        _object_mapper(object).cascade_callable('save-update', object, lambda c, e:self._save_or_update_impl(c, e))
+        _object_mapper(object).cascade_callable('save-update', object, lambda c, e:self._save_or_update_impl(c, e), halt_on=lambda c:c in self)
 
     def update(self, object, entity_name=None):
         """Bring the given detached (saved) instance into this Session.
@@ -298,7 +298,7 @@ class Session(object):
         This operation cascades the "save_or_update" method to associated instances if the relation is mapped 
         with cascade="save-update"."""
         self._update_impl(object, entity_name=entity_name)
-        _object_mapper(object).cascade_callable('save-update', object, lambda c, e:self._save_or_update_impl(c, e))
+        _object_mapper(object).cascade_callable('save-update', object, lambda c, e:self._save_or_update_impl(c, e), halt_on=lambda c:c in self)
 
     def save_or_update(self, object, entity_name=None):
         """save or update the given object into this Session.
@@ -306,7 +306,7 @@ class Session(object):
         The presence of an '_instance_key' attribute on the instance determines whether to
         save() or update() the instance."""
         self._save_or_update_impl(object, entity_name=entity_name)
-        _object_mapper(object).cascade_callable('save-update', object, lambda c, e:self._save_or_update_impl(c, e))
+        _object_mapper(object).cascade_callable('save-update', object, lambda c, e:self._save_or_update_impl(c, e), halt_on=lambda c:c in self)
     
     def _save_or_update_impl(self, object, entity_name=None):
         key = getattr(object, '_instance_key', None)
index d90a0a1d8d529a7030b770f0e73c86226e3d5889..47b358fb9d3f9357d0842019e79c9a7b2d185c4c 100644 (file)
@@ -19,7 +19,8 @@ class CascadeOptions(object):
         #self.refresh_expire = "refresh-expire" in values or "all" in values
     def __contains__(self, item):
         return getattr(self, item.replace("-", "_"), False)
-    
+    def __repr__(self):
+        return "CascadeOptions(arg=%s)" % repr(",".join([x for x in ['delete', 'save_update', 'merge', 'expunge', 'delete_orphan'] if getattr(self, x, False) is True]))
 
 def polymorphic_union(table_map, typecolname, aliasname='p_union'):
     """create a UNION statement used by a polymorphic mapper.
diff --git a/test/perf/cascade_speed.py b/test/perf/cascade_speed.py
new file mode 100644 (file)
index 0000000..a09e9dd
--- /dev/null
@@ -0,0 +1,90 @@
+from sqlalchemy import *
+from timeit import Timer
+import sys
+
+meta = DynamicMetaData("time_trial")
+
+orders = Table('orders', meta,
+    Column('id', Integer, Sequence('order_id_seq'), primary_key = True),
+)
+items = Table('items', meta,
+    Column('id', Integer, Sequence('item_id_seq'), primary_key = True),
+    Column('order_id', Integer, ForeignKey(orders.c.id), nullable=False),
+)
+attributes = Table('attributes', meta,
+    Column('id', Integer, Sequence('attribute_id_seq'), primary_key = True),
+    Column('item_id', Integer, ForeignKey(items.c.id), nullable=False),
+)
+values = Table('values', meta,
+    Column('id', Integer, Sequence('value_id_seq'), primary_key = True),
+    Column('attribute_id', Integer, ForeignKey(attributes.c.id), nullable=False),
+)
+
+class Order(object): pass
+class Item(object): pass
+class Attribute(object): pass
+class Value(object): pass
+
+valueMapper = mapper(Value, values)
+attrMapper = mapper(Attribute, attributes, properties=dict(
+    values=relation(valueMapper, cascade="save-update", backref="attribute")
+))
+itemMapper = mapper(Item, items, properties=dict(
+    attributes=relation(attrMapper, cascade="save-update", backref="item")
+))
+orderMapper = mapper(Order, orders, properties=dict(
+    items=relation(itemMapper, cascade="save-update", backref="order")
+))
+
+
+
+class TimeTrial(object):
+
+    def create_fwd_assoc(self):
+        item = Item()
+        self.order.items.append(item)
+        for attrid in range(10):
+            attr = Attribute()
+            item.attributes.append(attr)
+            for valueid in range(5):
+                val = Value()
+                attr.values.append(val)
+
+    def create_back_assoc(self):
+        item = Item()
+        item.order = self.order
+        for attrid in range(10):
+            attr = Attribute()
+            attr.item = item
+            for valueid in range(5):
+                val = Value()
+                val.attribute = attr
+                
+    def run(self, number):
+        s = create_session()
+        self.order = order = Order()
+        s.save(order)
+
+        ctime = 0.0
+        timer = Timer("create_it()", "from __main__ import create_it")
+        for n in xrange(number):
+            t = timer.timeit(1)
+            print "Time to create item %i: %.5f sec" % (n, t)
+            ctime += t
+
+        assert len(order.items) == 10
+        assert sum(len(item.attributes) for item in order.items) == 100
+        assert sum(len(attr.values) for item in order.items for attr in item.attributes) == 500
+        assert len(s.new) == 611
+        print "Created 610 objects in %.5f sec" % ctime
+
+if __name__ == "__main__":
+    tt = TimeTrial()
+
+    print "\nCreate forward associations"
+    create_it = tt.create_fwd_assoc
+    tt.run(10)
+
+    print "\nCreate backward associations"
+    create_it = tt.create_back_assoc
+    tt.run(10)
\ No newline at end of file