From 00da425b07549cc297541d19a5408d44c156a37e Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Wed, 29 Nov 2006 17:16:41 +0000 Subject: [PATCH] - added extra check to "stop" cascading on save/update/save-update if an instance is detected to be already in the session. --- CHANGES | 2 + lib/sqlalchemy/orm/interfaces.py | 4 +- lib/sqlalchemy/orm/mapper.py | 8 +-- lib/sqlalchemy/orm/properties.py | 8 +-- lib/sqlalchemy/orm/session.py | 6 +-- lib/sqlalchemy/orm/util.py | 3 +- test/perf/cascade_speed.py | 90 ++++++++++++++++++++++++++++++++ 7 files changed, 107 insertions(+), 14 deletions(-) create mode 100644 test/perf/cascade_speed.py diff --git a/CHANGES b/CHANGES index 516309c3d5..0befda974c 100644 --- a/CHANGES +++ b/CHANGES @@ -3,6 +3,8 @@ [ticket:247] - added label() function to Select class, when scalar=True is used to create a scalar subquery. +- added extra check to "stop" cascading on save/update/save-update if +an instance is detected to be already in the session. 0.3.1 - Engine/Pool: diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index d6c4204f95..0f607ebeec 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -18,9 +18,9 @@ class MapperProperty(object): """called when the mapper receives a row. instance is the parent instance corresponding to the row. """ raise NotImplementedError() - def cascade_iterator(self, type, object, recursive=None): + def cascade_iterator(self, type, object, recursive=None, halt_on=None): return [] - def cascade_callable(self, type, object, callable_, recursive=None): + def cascade_callable(self, type, object, callable_, recursive=None, halt_on=None): return [] def get_criterion(self, query, key, value): """Returns a WHERE clause suitable for this MapperProperty corresponding to the diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 5c1a3b5647..5fc986d6be 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1082,7 +1082,7 @@ class Mapper(object): for prop in self.__props.values(): prop.register_dependencies(uowcommit, *args, **kwargs) - def cascade_iterator(self, type, object, recursive=None): + def cascade_iterator(self, type, object, recursive=None, halt_on=None): """iterate each element in an object graph, for all relations taht meet the given cascade rule. type - the name of the cascade rule (i.e. save-update, delete, etc.) @@ -1094,10 +1094,10 @@ class Mapper(object): if recursive is None: recursive=util.Set() for prop in self.__props.values(): - for c in prop.cascade_iterator(type, object, recursive): + for c in prop.cascade_iterator(type, object, recursive, halt_on=halt_on): yield c - def cascade_callable(self, type, object, callable_, recursive=None): + def cascade_callable(self, type, object, callable_, recursive=None, halt_on=None): """execute a callable for each element in an object graph, for all relations that meet the given cascade rule. type - the name of the cascade rule (i.e. save-update, delete, etc.) @@ -1111,7 +1111,7 @@ class Mapper(object): if recursive is None: recursive=util.Set() for prop in self.__props.values(): - prop.cascade_callable(type, object, callable_, recursive) + prop.cascade_callable(type, object, callable_, recursive, halt_on=halt_on) def get_select_mapper(self): diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 6d11436118..8379a1d1f0 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -118,7 +118,7 @@ class PropertyLoader(StrategizedProperty): def __str__(self): return self.__class__.__name__ + " " + str(self.parent) + "->" + self.key + "->" + str(self.mapper) - def cascade_iterator(self, type, object, recursive): + def cascade_iterator(self, type, object, recursive, halt_on=None): if not type in self.cascade: return passive = type != 'delete' or self.passive_deletes @@ -127,7 +127,7 @@ class PropertyLoader(StrategizedProperty): return mapper = self.mapper.primary_mapper() for c in childlist.added_items() + childlist.deleted_items() + childlist.unchanged_items(): - if c is not None and c not in recursive: + if c is not None and c not in recursive and (halt_on is None or not halt_on(c)): if not isinstance(c, self.mapper.class_): raise exceptions.AssertionError("Attribute '%s' on class '%s' doesn't handle objects of type '%s'" % (self.key, str(self.parent.class_), str(c.__class__))) recursive.add(c) @@ -135,14 +135,14 @@ class PropertyLoader(StrategizedProperty): for c2 in mapper.cascade_iterator(type, c, recursive): yield c2 - def cascade_callable(self, type, object, callable_, recursive): + def cascade_callable(self, type, object, callable_, recursive, halt_on=None): if not type in self.cascade: return mapper = self.mapper.primary_mapper() passive = type != 'delete' or self.passive_deletes for c in sessionlib.attribute_manager.get_as_list(object, self.key, passive=passive): - if c is not None and c not in recursive: + if c is not None and c not in recursive and (halt_on is None or not halt_on(c)): if not isinstance(c, self.mapper.class_): raise exceptions.AssertionError("Attribute '%s' on class '%s' doesn't handle objects of type '%s'" % (self.key, str(self.parent.class_), str(c.__class__))) recursive.add(c) diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 921449a6af..0a196d3d9d 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -287,7 +287,7 @@ class Session(object): instance. """ self._save_impl(object, entity_name=entity_name) - _object_mapper(object).cascade_callable('save-update', object, lambda c, e:self._save_or_update_impl(c, e)) + _object_mapper(object).cascade_callable('save-update', object, lambda c, e:self._save_or_update_impl(c, e), halt_on=lambda c:c in self) def update(self, object, entity_name=None): """Bring the given detached (saved) instance into this Session. @@ -298,7 +298,7 @@ class Session(object): This operation cascades the "save_or_update" method to associated instances if the relation is mapped with cascade="save-update".""" self._update_impl(object, entity_name=entity_name) - _object_mapper(object).cascade_callable('save-update', object, lambda c, e:self._save_or_update_impl(c, e)) + _object_mapper(object).cascade_callable('save-update', object, lambda c, e:self._save_or_update_impl(c, e), halt_on=lambda c:c in self) def save_or_update(self, object, entity_name=None): """save or update the given object into this Session. @@ -306,7 +306,7 @@ class Session(object): The presence of an '_instance_key' attribute on the instance determines whether to save() or update() the instance.""" self._save_or_update_impl(object, entity_name=entity_name) - _object_mapper(object).cascade_callable('save-update', object, lambda c, e:self._save_or_update_impl(c, e)) + _object_mapper(object).cascade_callable('save-update', object, lambda c, e:self._save_or_update_impl(c, e), halt_on=lambda c:c in self) def _save_or_update_impl(self, object, entity_name=None): key = getattr(object, '_instance_key', None) diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index d90a0a1d8d..47b358fb9d 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -19,7 +19,8 @@ class CascadeOptions(object): #self.refresh_expire = "refresh-expire" in values or "all" in values def __contains__(self, item): return getattr(self, item.replace("-", "_"), False) - + def __repr__(self): + return "CascadeOptions(arg=%s)" % repr(",".join([x for x in ['delete', 'save_update', 'merge', 'expunge', 'delete_orphan'] if getattr(self, x, False) is True])) def polymorphic_union(table_map, typecolname, aliasname='p_union'): """create a UNION statement used by a polymorphic mapper. diff --git a/test/perf/cascade_speed.py b/test/perf/cascade_speed.py new file mode 100644 index 0000000000..a09e9dd6f7 --- /dev/null +++ b/test/perf/cascade_speed.py @@ -0,0 +1,90 @@ +from sqlalchemy import * +from timeit import Timer +import sys + +meta = DynamicMetaData("time_trial") + +orders = Table('orders', meta, + Column('id', Integer, Sequence('order_id_seq'), primary_key = True), +) +items = Table('items', meta, + Column('id', Integer, Sequence('item_id_seq'), primary_key = True), + Column('order_id', Integer, ForeignKey(orders.c.id), nullable=False), +) +attributes = Table('attributes', meta, + Column('id', Integer, Sequence('attribute_id_seq'), primary_key = True), + Column('item_id', Integer, ForeignKey(items.c.id), nullable=False), +) +values = Table('values', meta, + Column('id', Integer, Sequence('value_id_seq'), primary_key = True), + Column('attribute_id', Integer, ForeignKey(attributes.c.id), nullable=False), +) + +class Order(object): pass +class Item(object): pass +class Attribute(object): pass +class Value(object): pass + +valueMapper = mapper(Value, values) +attrMapper = mapper(Attribute, attributes, properties=dict( + values=relation(valueMapper, cascade="save-update", backref="attribute") +)) +itemMapper = mapper(Item, items, properties=dict( + attributes=relation(attrMapper, cascade="save-update", backref="item") +)) +orderMapper = mapper(Order, orders, properties=dict( + items=relation(itemMapper, cascade="save-update", backref="order") +)) + + + +class TimeTrial(object): + + def create_fwd_assoc(self): + item = Item() + self.order.items.append(item) + for attrid in range(10): + attr = Attribute() + item.attributes.append(attr) + for valueid in range(5): + val = Value() + attr.values.append(val) + + def create_back_assoc(self): + item = Item() + item.order = self.order + for attrid in range(10): + attr = Attribute() + attr.item = item + for valueid in range(5): + val = Value() + val.attribute = attr + + def run(self, number): + s = create_session() + self.order = order = Order() + s.save(order) + + ctime = 0.0 + timer = Timer("create_it()", "from __main__ import create_it") + for n in xrange(number): + t = timer.timeit(1) + print "Time to create item %i: %.5f sec" % (n, t) + ctime += t + + assert len(order.items) == 10 + assert sum(len(item.attributes) for item in order.items) == 100 + assert sum(len(attr.values) for item in order.items for attr in item.attributes) == 500 + assert len(s.new) == 611 + print "Created 610 objects in %.5f sec" % ctime + +if __name__ == "__main__": + tt = TimeTrial() + + print "\nCreate forward associations" + create_it = tt.create_fwd_assoc + tt.run(10) + + print "\nCreate backward associations" + create_it = tt.create_back_assoc + tt.run(10) \ No newline at end of file -- 2.47.2