From 0836f5cdfdaff67f3232a6bdbd58ac75924dcafa Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Fri, 20 Jul 2007 03:20:33 +0000 Subject: [PATCH] - Eager loading now functions at any arbitrary depth along self-referential and cyclical structures. When loading cyclical structures, specify "join_depth" on relation() indicating how many times you'd like the table to join to itself; each level gets a distinct table alias. - adds a stack to the Mapper->eagerloader->Mapper process which is the single point of tracking the various AliasedClause objects both at query compile time as well as result fetching time. self-referential relationships narrow down the "aliasing" of tables more sharply so as to produce the correct eager joins in those cases without stepping on more generalized cases. the mechanism of detecting "too deep" of an eager load now works based on locating a true cycle, but only if join_depth is not specified; otherwise join_depth is used. [ticket:659] --- CHANGES | 31 ++++-- lib/sqlalchemy/orm/interfaces.py | 25 +++++ lib/sqlalchemy/orm/mapper.py | 9 +- lib/sqlalchemy/orm/properties.py | 5 +- lib/sqlalchemy/orm/query.py | 3 +- lib/sqlalchemy/orm/strategies.py | 166 +++++++++++++++++------------- test/orm/cycles.py | 16 --- test/orm/eager_relations.py | 79 ++++++++++++++ test/orm/inheritance/polymorph.py | 23 ----- 9 files changed, 232 insertions(+), 125 deletions(-) diff --git a/CHANGES b/CHANGES index 6f72ac3820..dc93d7d0f1 100644 --- a/CHANGES +++ b/CHANGES @@ -38,15 +38,32 @@ querying divergent criteria. ClauseElements at the front of filter_by() are removed (use filter()). - - added operator support to class-instrumented attributes. you can now - filter() (or whatever) using .==. - for column based properties, all column operators work (i.e. ==, <, >, - like(), in_(), etc.). For relation() and composite column properties, - ==, !=, and == are implemented so far. + - Eager loading has been enhanced to allow even more joins in more places. + It now functions at any arbitrary depth along self-referential + and cyclical structures. When loading cyclical structures, specify "join_depth" + on relation() indicating how many times you'd like the table to join + to itself; each level gets a distinct table alias. The alias names + themselves are generated at compile time using a simple counting + scheme now and are a lot easier on the eyes, as well as of course + completely deterministic. [ticket:659] + + - Class-level properties are now usable as query elements ...no + more '.c.' ! "Class.c.propname" is now superceded by "Class.propname". + All clause operators are supported, as well as higher level operators + such as Class.prop== for scalar attributes and + Class.prop.contains() for collection-based attributes + (both are also negatable). Table-based column expressions as well as + columns mounted on mapped classes via 'c' are of course still fully available + and can be freely mixed with the new attributes. [ticket:643] - - added composite column properties. using the composite(cls, *columns) - function inside of the "properties" dict, instances of cls will be + - added composite column properties. This allows you to create a + type which is represented by more than one column, when using the + ORM. Objects of the new type are fully functional in query expressions, + comparisons, query.get() clauses, etc. and act as though they are regular + single-column scalars..except they're not ! + Use the function composite(cls, *columns) inside of the + mapper's "properties" dict, and instances of cls will be created/mapped to a single attribute, comprised of the values correponding to *columns [ticket:211] diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index e1209fabf9..79b575cc2f 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -388,6 +388,31 @@ class StrategizedProperty(MapperProperty): if self.is_primary(): self.strategy.init_class_attribute() +class LoaderStack(object): + """a stack object used during load operations to track the + current position among a chain of mappers to eager loaders.""" + + def __init__(self): + self.__stack = [] + + def push_property(self, key): + self.__stack.append(key) + + def push_mapper(self, mapper): + self.__stack.append(mapper.base_mapper()) + + def pop(self): + self.__stack.pop() + + def snapshot(self): + """return an 'snapshot' of this stack. + + this is a tuple form of the stack which can be used as a hash key.""" + return tuple(self.__stack) + + def __str__(self): + return "->".join([str(s) for s in self.__stack]) + class OperationContext(object): """Serve as a context during a query construction or instance loading operation. diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 712f7b90a1..097c906abf 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -805,6 +805,7 @@ class Mapper(object): def base_mapper(self): """Return the ultimate base mapper in an inheritance chain.""" + # TODO: calculate this at mapper setup time if self.inherits is not None: return self.inherits.base_mapper() else: @@ -1596,8 +1597,8 @@ class Mapper(object): def populate_instance(self, selectcontext, instance, row, ispostselect=None, **flags): """populate an instance from a result row.""" - - populators = selectcontext.attributes.get(('instance_populators', self, ispostselect), None) + selectcontext.stack.push_mapper(self) + populators = selectcontext.attributes.get(('instance_populators', self, selectcontext.stack.snapshot(), ispostselect), None) if populators is None: populators = [] post_processors = [] @@ -1612,11 +1613,13 @@ class Mapper(object): if poly_select_loader is not None: post_processors.append(poly_select_loader) - selectcontext.attributes[('instance_populators', self, ispostselect)] = populators + selectcontext.attributes[('instance_populators', self, selectcontext.stack.snapshot(), ispostselect)] = populators selectcontext.attributes[('post_processors', self, ispostselect)] = post_processors for p in populators: p(instance, row, ispostselect=ispostselect, **flags) + + selectcontext.stack.pop() if self.non_primary: selectcontext.attributes[('populating_mapper', instance)] = self diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 010413c640..4635b8e4b7 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -118,7 +118,7 @@ class PropertyLoader(StrategizedProperty): of items that correspond to a related database table. """ - def __init__(self, argument, secondary, primaryjoin, secondaryjoin, entity_name=None, foreign_keys=None, foreignkey=None, uselist=None, private=False, association=None, order_by=False, attributeext=None, backref=None, is_backref=False, post_update=False, cascade=None, viewonly=False, lazy=True, collection_class=None, passive_deletes=False, remote_side=None, enable_typechecks=True): + def __init__(self, argument, secondary, primaryjoin, secondaryjoin, entity_name=None, foreign_keys=None, foreignkey=None, uselist=None, private=False, association=None, order_by=False, attributeext=None, backref=None, is_backref=False, post_update=False, cascade=None, viewonly=False, lazy=True, collection_class=None, passive_deletes=False, remote_side=None, enable_typechecks=True, join_depth=None): self.uselist = uselist self.argument = argument self.entity_name = entity_name @@ -137,7 +137,8 @@ class PropertyLoader(StrategizedProperty): self.enable_typechecks = enable_typechecks self._parent_join_cache = {} self.comparator = PropertyLoader.Comparator(self) - + self.join_depth = join_depth + if cascade is not None: self.cascade = mapperutil.CascadeOptions(cascade) else: diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index d487db5742..aeae60365c 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -6,7 +6,7 @@ from sqlalchemy import sql, util, exceptions, sql_util, logging from sqlalchemy.orm import mapper, object_mapper -from sqlalchemy.orm.interfaces import OperationContext +from sqlalchemy.orm.interfaces import OperationContext, LoaderStack import operator __all__ = ['Query', 'QueryContext', 'SelectionContext'] @@ -1137,6 +1137,7 @@ class SelectionContext(OperationContext): self.session = session self.extension = extension self.identity_map = {} + self.stack = LoaderStack() super(SelectionContext, self).__init__(mapper, kwargs.pop('with_options', []), **kwargs) def accept_option(self, opt): diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 4597e6d72b..473fe7729a 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -439,18 +439,12 @@ class EagerLoader(AbstractRelationLoader): def init(self): super(EagerLoader, self).init() - if self.parent.isa(self.mapper): - raise exceptions.ArgumentError( - "Error creating eager relationship '%s' on parent class '%s' " - "to child class '%s': Cant use eager loading on a self " - "referential relationship." % - (self.key, repr(self.parent.class_), repr(self.mapper.class_))) if self.is_default: self.parent._eager_loaders.add(self.parent_property) self.clauses = {} - self.clauses_by_lead_mapper = {} - + self.join_depth = self.parent_property.join_depth + class AliasedClauses(object): """Defines a set of join conditions and table aliases which are aliased on a randomly-generated alias name, corresponding @@ -471,17 +465,6 @@ class EagerLoader(AbstractRelationLoader): (EagerLoader 'keywords') --> mapper C - will generate:: - - EagerLoader 'items' --> { - None : AliasedClauses(items, None, alias_suffix='AB34') # mappera JOIN mapperb_AB34 - } - - EagerLoader 'keywords' --> [ - None : AliasedClauses(keywords, None, alias_suffix='43EF') # mapperb JOIN mapperc_43EF - AliasedClauses(items, None, alias_suffix='AB34') : - AliasedClauses(keywords, items, alias_suffix='8F44') # mapperb_AB34 JOIN mapperc_8F44 - ] """ def __init__(self, eagerloader, parentclauses=None): @@ -489,6 +472,10 @@ class EagerLoader(AbstractRelationLoader): self.target = eagerloader.select_table self.eagertarget = eagerloader.select_table.alias(None) self.extra_cols = {} + if parentclauses is not None: + self.path = parentclauses.path + (self.parent.parent, self.parent.key) + else: + self.path = (self.parent.parent, self.parent.key) if eagerloader.secondary: self.eagersecondary = eagerloader.secondary.alias(None) @@ -505,11 +492,20 @@ class EagerLoader(AbstractRelationLoader): self.eagerprimary = aliasizer.traverse(self.eagerprimary, clone=True) else: self.eagerprimary = eagerloader.polymorphic_primaryjoin + + # for self-referential eager load, the "aliasing" of each side of the join condition + # must be limited to exactly the cols we know are on "our side". for non-self-referntial, + # be more liberal to include other elements of the join condition which deal with "our" table + if eagerloader.parent_property._is_self_referential(): + include = eagerloader.parent_property.remote_side + else: + include = None + if parentclauses is not None: - aliasizer = sql_util.ClauseAdapter(self.eagertarget) + aliasizer = sql_util.ClauseAdapter(self.eagertarget, include=include) aliasizer.chain(sql_util.ClauseAdapter(parentclauses.eagertarget, exclude=eagerloader.parent_property.remote_side)) else: - aliasizer = sql_util.ClauseAdapter(self.eagertarget) + aliasizer = sql_util.ClauseAdapter(self.eagertarget, include=include) self.eagerprimary = aliasizer.traverse(self.eagerprimary, clone=True) if eagerloader.order_by: @@ -518,7 +514,10 @@ class EagerLoader(AbstractRelationLoader): self.eager_order_by = None self._row_decorator = self._create_decorator_row() - + + def __str__(self): + return "->".join([str(s) for s in self.path]) + def aliased_column(self, column): """return the aliased version of the given column, creating a new label for it if not already present in this AliasedClauses eagertable.""" @@ -573,16 +572,30 @@ class EagerLoader(AbstractRelationLoader): def setup_query(self, context, eagertable=None, parentclauses=None, parentmapper=None, **kwargs): """Add a left outer join to the statement thats being constructed.""" + # build a path as we setup the query. the format of this path + # matches that of interfaces.LoaderStack, and will be used in the + # row-loading phase to match up AliasedClause objects with the current + # LoaderStack position. + if parentclauses: + path = parentclauses.path + (self.parent.base_mapper(), self.key) + else: + path = (self.parent.base_mapper(), self.key) + + + if self.join_depth: + if len(path) / 2 > self.join_depth: + return + else: + if self.mapper in path: + return + + #print "CREATING EAGER PATH FOR", "->".join([str(s) for s in path]) + if parentmapper is None: localparent = context.mapper else: localparent = parentmapper - if self.mapper in context.recursion_stack: - return - else: - context.recursion_stack.add(self.parent) - statement = context.statement if hasattr(statement, '_outerjoin'): @@ -604,16 +617,13 @@ class EagerLoader(AbstractRelationLoader): break else: raise exceptions.InvalidRequestError("EagerLoader cannot locate a clause with which to outer join to, in query '%s' %s" % (str(statement), localparent.mapped_table)) - + try: - clauses = self.clauses[parentclauses] + clauses = self.clauses[path] except KeyError: clauses = EagerLoader.AliasedClauses(self, parentclauses) - self.clauses[parentclauses] = clauses - - if context.mapper not in self.clauses_by_lead_mapper: - self.clauses_by_lead_mapper[context.mapper] = clauses - + self.clauses[path] = clauses + if self.secondaryjoin is not None: statement._outerjoin = sql.outerjoin(towrap, clauses.eagersecondary, clauses.eagerprimary).outerjoin(clauses.eagertarget, clauses.eagersecondaryjoin) if self.order_by is False and self.secondary.default_order_by() is not None: @@ -630,8 +640,8 @@ class EagerLoader(AbstractRelationLoader): for value in self.select_mapper.iterate_properties: value.setup(context, eagertable=clauses.eagertarget, parentclauses=clauses, parentmapper=self.select_mapper) - - def _create_row_decorator(self, selectcontext, row): + + def _create_row_decorator(self, selectcontext, row, path): """Create a *row decorating* function that will apply eager aliasing to the row. @@ -639,6 +649,8 @@ class EagerLoader(AbstractRelationLoader): else return None. """ + #print "creating row decorator for path ", "->".join([str(s) for s in path]) + # check for a user-defined decorator in the SelectContext (which was set up by the contains_eager() option) if selectcontext.attributes.has_key(("eager_row_processor", self.parent_property)): # custom row decoration function, placed in the selectcontext by the @@ -649,11 +661,13 @@ class EagerLoader(AbstractRelationLoader): else: try: # decorate the row according to the stored AliasedClauses for this eager load - clauses = self.clauses_by_lead_mapper[selectcontext.mapper] + clauses = self.clauses[path] decorator = clauses._row_decorator except KeyError, k: # no stored AliasedClauses: eager loading was not set up in the query and # AliasedClauses never got initialized + if self._should_log_debug: + self.logger.debug("Could not locate aliased clauses for key: " + str(path)) return None try: @@ -669,54 +683,60 @@ class EagerLoader(AbstractRelationLoader): return None def create_row_processor(self, selectcontext, mapper, row): - row_decorator = self._create_row_decorator(selectcontext, row) + selectcontext.stack.push_property(self.key) + path = selectcontext.stack.snapshot() + + row_decorator = self._create_row_decorator(selectcontext, row, path) if row_decorator is not None: def execute(instance, row, isnew, **flags): - if self in selectcontext.recursion_stack: - return decorated_row = row_decorator(row) - # TODO: recursion check a speed hit...? try to get a "termination point" into the AliasedClauses - # or EagerRowAdapter ? - selectcontext.recursion_stack.add(self) - try: - if not self.uselist: - if self._should_log_debug: - self.logger.debug("eagerload scalar instance on %s" % mapperutil.attribute_str(instance, self.key)) - if isnew: - # set a scalar object instance directly on the - # parent object, bypassing InstrumentedAttribute - # event handlers. - # - # FIXME: instead of... - sessionlib.attribute_manager.get_attribute(instance, self.key).set_raw_value(instance, self.mapper._instance(selectcontext, decorated_row, None)) - # bypass and set directly: - #instance.__dict__[self.key] = ... - else: - # call _instance on the row, even though the object has been created, - # so that we further descend into properties - self.mapper._instance(selectcontext, decorated_row, None) + selectcontext.stack.push_property(self.key) + + if not self.uselist: + if self._should_log_debug: + self.logger.debug("eagerload scalar instance on %s" % mapperutil.attribute_str(instance, self.key)) + if isnew: + # set a scalar object instance directly on the + # parent object, bypassing InstrumentedAttribute + # event handlers. + # + # FIXME: instead of... + sessionlib.attribute_manager.get_attribute(instance, self.key).set_raw_value(instance, self.mapper._instance(selectcontext, decorated_row, None)) + # bypass and set directly: + #instance.__dict__[self.key] = ... else: - if isnew: - if self._should_log_debug: - self.logger.debug("initialize UniqueAppender on %s" % mapperutil.attribute_str(instance, self.key)) + # call _instance on the row, even though the object has been created, + # so that we further descend into properties + self.mapper._instance(selectcontext, decorated_row, None) + else: + if isnew: + if self._should_log_debug: + self.logger.debug("initialize UniqueAppender on %s" % mapperutil.attribute_str(instance, self.key)) - collection = sessionlib.attribute_manager.init_collection(instance, self.key) - appender = util.UniqueAppender(collection, 'append_without_event') + collection = sessionlib.attribute_manager.init_collection(instance, self.key) + appender = util.UniqueAppender(collection, 'append_without_event') - # store it in the "scratch" area, which is local to this load operation. - selectcontext.attributes[(instance, self.key)] = appender - result_list = selectcontext.attributes[(instance, self.key)] - if self._should_log_debug: - self.logger.debug("eagerload list instance on %s" % mapperutil.attribute_str(instance, self.key)) - self.select_mapper._instance(selectcontext, decorated_row, result_list) - finally: - selectcontext.recursion_stack.remove(self) + # store it in the "scratch" area, which is local to this load operation. + selectcontext.attributes[(instance, self.key)] = appender + result_list = selectcontext.attributes[(instance, self.key)] + if self._should_log_debug: + self.logger.debug("eagerload list instance on %s" % mapperutil.attribute_str(instance, self.key)) + + self.select_mapper._instance(selectcontext, decorated_row, result_list) + selectcontext.stack.pop() + + selectcontext.stack.pop() return (execute, None) else: self.logger.debug("eager loader %s degrading to lazy loader" % str(self)) + selectcontext.stack.pop() return self.parent_property._get_strategy(LazyLoader).create_row_processor(selectcontext, mapper, row) + + def __str__(self): + return str(self.parent) + "." + self.key + EagerLoader.logger = logging.class_logger(EagerLoader) class EagerLazyOption(StrategizedOption): diff --git a/test/orm/cycles.py b/test/orm/cycles.py index 7e1716ca3b..bdbb146e9f 100644 --- a/test/orm/cycles.py +++ b/test/orm/cycles.py @@ -109,22 +109,6 @@ class SelfReferentialTest(AssertMixin): sess.delete(a) sess.flush() - def testeagerassertion(self): - """test that an eager self-referential relationship raises an error.""" - class C1(Tester): - pass - class C2(Tester): - pass - - m1 = mapper(C1, t1, properties = { - 'c1s' : relation(C1, lazy=False), - }) - - try: - m1.compile() - assert False - except exceptions.ArgumentError: - assert True class SelfReferentialNoPKTest(AssertMixin): """test self-referential relationship that joins on a column other than the primary key column""" diff --git a/test/orm/eager_relations.py b/test/orm/eager_relations.py index 895ec4a138..c2d1b28bc9 100644 --- a/test/orm/eager_relations.py +++ b/test/orm/eager_relations.py @@ -185,6 +185,46 @@ class EagerTest(QueryTest): ] == q.all() self.assert_sql_count(testbase.db, go, 1) + + def test_double_same_mappers(self): + """tests lazy loading with two relations simulatneously, from the same table, using aliases. """ + + mapper(Address, addresses) + mapper(Order, orders, properties={ + 'items':relation(Item, secondary=order_items, lazy=False, order_by=items.c.id), + }) + mapper(Item, items) + mapper(User, users, properties = dict( + addresses = relation(Address, lazy=False), + open_orders = relation(Order, primaryjoin = and_(orders.c.isopen == 1, users.c.id==orders.c.user_id), lazy=False), + closed_orders = relation(Order, primaryjoin = and_(orders.c.isopen == 0, users.c.id==orders.c.user_id), lazy=False) + )) + q = create_session().query(User) + + def go(): + assert [ + User( + id=7, + addresses=[Address(id=1)], + open_orders = [Order(id=3, items=[Item(id=3), Item(id=4), Item(id=5)])], + closed_orders = [Order(id=1, items=[Item(id=1), Item(id=2), Item(id=3)]), Order(id=5, items=[Item(id=5)])] + ), + User( + id=8, + addresses=[Address(id=2), Address(id=3), Address(id=4)], + open_orders = [], + closed_orders = [] + ), + User( + id=9, + addresses=[Address(id=5)], + open_orders = [Order(id=4, items=[Item(id=1), Item(id=5)])], + closed_orders = [Order(id=2, items=[Item(id=1), Item(id=2), Item(id=3)])] + ), + User(id=10) + + ] == q.all() + self.assert_sql_count(testbase.db, go, 1) def test_limit(self): """test limit operations combined with lazy-load relationships.""" @@ -398,5 +438,44 @@ class EagerTest(QueryTest): l = q.filter(addresses.c.email_address == 'ed@lala.com').filter(Address.user_id==User.id) assert fixtures.user_address_result[1:2] == l.all() +class SelfReferentialEagerTest(testbase.ORMTest): + def define_tables(self, metadata): + global nodes + nodes = Table('nodes', metadata, + Column('id', Integer, primary_key=True), + Column('parent_id', Integer, ForeignKey('nodes.id')), + Column('data', String(30))) + + def test_basic(self): + class Node(Base): + def append(self, node): + self.children.append(node) + + mapper(Node, nodes, properties={ + 'children':relation(Node, lazy=False, join_depth=3) + }) + sess = create_session() + n1 = Node(data='n1') + n1.append(Node(data='n11')) + n1.append(Node(data='n12')) + n1.append(Node(data='n13')) + n1.children[1].append(Node(data='n121')) + n1.children[1].append(Node(data='n122')) + n1.children[1].append(Node(data='n123')) + sess.save(n1) + sess.flush() + sess.clear() + def go(): + d = sess.query(Node).filter_by(data='n1').first() + assert Node(data='n1', children=[ + Node(data='n11'), + Node(data='n12', children=[ + Node(data='n121'), + Node(data='n122'), + Node(data='n123') + ]), + Node(data='n13') + ]) == d + self.assert_sql_count(testbase.db, go, 1) if __name__ == '__main__': testbase.main() diff --git a/test/orm/inheritance/polymorph.py b/test/orm/inheritance/polymorph.py index d7900610ff..2aae9eb6a4 100644 --- a/test/orm/inheritance/polymorph.py +++ b/test/orm/inheritance/polymorph.py @@ -112,29 +112,6 @@ class CompileTest(PolymorphTest): #person_mapper.compile() class_mapper(Manager).compile() - def testcompile3(self): - """test that a mapper referencing an inheriting mapper in a self-referential relationship does - not allow an eager load to be set up.""" - person_join = polymorphic_union( { - 'engineer':people.join(engineers), - 'manager':people.join(managers), - 'person':people.select(people.c.type=='person'), - }, None, 'pjoin') - - person_mapper = mapper(Person, people, select_table=person_join, polymorphic_on=person_join.c.type, - polymorphic_identity='person', - properties = dict(managers = relation(Manager, lazy=False)) - ) - - mapper(Engineer, engineers, inherits=person_mapper, polymorphic_identity='engineer') - mapper(Manager, managers, inherits=person_mapper, polymorphic_identity='manager') - - try: - class_mapper(Manager).compile() - assert False - except exceptions.ArgumentError: - assert True - class InsertOrderTest(PolymorphTest): def test_insert_order(self): """test that classes of multiple types mix up mapper inserts -- 2.47.3