From: Mike Bayer Date: Sun, 22 Jul 2007 14:07:15 +0000 (+0000) Subject: - got self-referential query.join()/query.outerjoin() to work. X-Git-Tag: rel_0_4_6~54 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=94ea86c195cc8c39ccd9a109be5bd31c2e9ca7cf;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - got self-referential query.join()/query.outerjoin() to work. - PropertyLoader adds local_side set which is the opposite of remote_side, makes the self-referential aliasing a snap. - added "id" argument to join()/outerjoin() to allow access to the aliased joins in add_entity(). - added "alias" argument to add_entity() to get at entities selected from an explicit Alias - starting to move EagerLoader.AliasedClauses to a general utility function which will be used by Query as well --- diff --git a/CHANGES b/CHANGES index ce23ba9e63..f6403612f0 100644 --- a/CHANGES +++ b/CHANGES @@ -34,23 +34,34 @@ - Class-level properties are now usable as query elements ...no more '.c.' ! "Class.c.propname" is now superceded by "Class.propname". All clause operators are supported, as well as higher level operators - such as Class.prop== for scalar attributes and - Class.prop.contains() for collection-based attributes - (both are also negatable). Table-based column expressions as well as - columns mounted on mapped classes via 'c' are of course still fully available - and can be freely mixed with the new attributes. + such as Class.prop== for scalar attributes, + Class.prop.contains() and Class.prop.any() + for collection-based attributes (all are also negatable). Table-based column + expressions as well as columns mounted on mapped classes via 'c' are of + course still fully available and can be freely mixed with the new attributes. [ticket:643] - removed ancient query.select_by_attributename() capability. - - - added "aliased joins" positional argument to the front of - filter_by(). 
this allows auto-creation of joins that are aliased - locally to the individual filter_by() call. This allows the - auto-construction of joins which cross the same paths but are - querying divergent criteria. ClauseElements at the front of - filter_by() are removed (use filter()). - - - added query.populate_existing().. - marks the query to reload + + - the aliasing logic used by eager loading has been generalized, so that + it also adds full automatic aliasing support to Query. It's no longer + necessary to create an explicit Alias to join to the same tables multiple times; + *even for self-referential relationships!!* + - join() and outerjoin() take arguments "aliased=True". this causes + their joins to be built on aliased tables; subsequent calls + to filter() and filter_by() will translate all table expressions + (yes, real expressions using the original mapped Table) to be that of + the Alias for the duration of that join() (i.e. until reset_joinpoint() + or another join() is called). + - join() and outerjoin() take arguments "id=". when used + with "aliased=True", the id can be referenced by add_entity(cls, id=) + so that you can select the joined instances even if they're from an alias. + - join() and outerjoin() now work with self-referential relationships! using + "aliased=True", you can join as many levels deep as desired, i.e. + query.join(['children', 'children'], aliased=True); filter criterion will + be against the rightmost joined table + + - added query.populate_existing() - marks the query to reload all attributes and collections of all instances touched in the query, including eagerly-loaded entities [ticket:660] diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index d259270580..a335cdd69c 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -488,16 +488,13 @@ class PropertyLoader(StrategizedProperty): "argument." 
% (str(self))) def _determine_remote_side(self): - if len(self.remote_side): - return - self.remote_side = util.Set() + if not len(self.remote_side): + if self.direction is sync.MANYTOONE: + self.remote_side = util.Set(self._opposite_side) + elif self.direction is sync.ONETOMANY or self.direction is sync.MANYTOMANY: + self.remote_side = util.Set(self.foreign_keys) - if self.direction is sync.MANYTOONE: - for c in self._opposite_side: - self.remote_side.add(c) - elif self.direction is sync.ONETOMANY or self.direction is sync.MANYTOMANY: - for c in self.foreign_keys: - self.remote_side.add(c) + self.local_side = util.Set(self._opposite_side).union(util.Set(self.foreign_keys)).difference(self.remote_side) def _create_polymorphic_joins(self): # get ready to create "polymorphic" primary/secondary join clauses. @@ -575,18 +572,20 @@ class PropertyLoader(StrategizedProperty): def _is_self_referential(self): return self.parent.mapped_table is self.target or self.parent.select_table is self.target - def get_join(self, parent, primary=True, secondary=True): + def get_join(self, parent, primary=True, secondary=True, polymorphic_parent=True): try: - return self._parent_join_cache[(parent, primary, secondary)] + return self._parent_join_cache[(parent, primary, secondary, polymorphic_parent)] except KeyError: parent_equivalents = parent._get_equivalent_columns() secondaryjoin = self.polymorphic_secondaryjoin - if self.direction is sync.ONETOMANY: - primaryjoin = sql_util.ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin, clone=True) - elif self.direction is sync.MANYTOONE: - primaryjoin = sql_util.ClauseAdapter(parent.select_table, include=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin, clone=True) - elif self.secondaryjoin: - primaryjoin = sql_util.ClauseAdapter(parent.select_table, exclude=self.foreign_keys, 
equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin, clone=True) + if polymorphic_parent: + # adapt the "parent" side of our join condition to the "polymorphic" select of the parent + if self.direction is sync.ONETOMANY: + primaryjoin = sql_util.ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin, clone=True) + elif self.direction is sync.MANYTOONE: + primaryjoin = sql_util.ClauseAdapter(parent.select_table, include=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin, clone=True) + elif self.secondaryjoin: + primaryjoin = sql_util.ClauseAdapter(parent.select_table, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(self.polymorphic_primaryjoin, clone=True) if secondaryjoin is not None: if secondary and not primary: @@ -597,7 +596,7 @@ class PropertyLoader(StrategizedProperty): j = primaryjoin else: j = primaryjoin - self._parent_join_cache[(parent, primary, secondary)] = j + self._parent_join_cache[(parent, primary, secondary, polymorphic_parent)] = j return j def register_dependencies(self, uowcommit): diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 004024e58d..4f59dd4550 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -38,6 +38,7 @@ class Query(object): self._column_aggregate = None self._joinpoint = self.mapper self._aliases = None + self._alias_ids = {} self._from_obj = [self.table] self._populate_existing = False self._version_check = False @@ -167,7 +168,7 @@ class Query(object): prop = mapper.get_property(property, resolve_synonyms=True) return self.filter(prop.compare(operator.eq, instance, value_is_parent=True)) - def add_entity(self, entity): + def add_entity(self, entity, alias=None, id=None): """add a mapped entity to the list of result columns to be returned. 
This will have the effect of all result-returning methods returning a tuple @@ -183,9 +184,16 @@ class Query(object): entity a class or mapper which will be added to the results. + alias + a sqlalchemy.sql.Alias object which will be used to select rows. this + will match the usage of the given Alias in filter(), order_by(), etc. expressions + + id + a string ID matching that given to query.join() or query.outerjoin(); rows will be + selected from the aliased join created via those methods. """ q = self._clone() - q._entities = q._entities + [entity] + q._entities = q._entities + [(entity, alias, id)] return q def add_column(self, column): @@ -255,11 +263,11 @@ class Query(object): if criterion is not None and not isinstance(criterion, sql.ClauseElement): raise exceptions.ArgumentError("filter() argument must be of type sqlalchemy.sql.ClauseElement or string") + if self._aliases is not None: - adapter = sql_util.ClauseAdapter(self._aliases[0]) - for alias in self._aliases[1:]: - adapter.chain(sql_util.ClauseAdapter(alias)) - criterion = adapter.traverse(criterion, clone=True) + # adapt only the *last* alias in the list for now. + # this helps a self-referential join to work, i.e. table.join(table.alias(a)).join(table.alias(b)) + criterion = sql_util.ClauseAdapter(self._aliases[-1]).traverse(criterion, clone=True) q = self._clone() if q._criterion is not None: @@ -316,6 +324,8 @@ class Query(object): # TODO: create_aliases automatically ? probably raise exceptions.InvalidRequestError("Self-referential query on '%s' property requries create_aliases=True argument." % str(prop)) # dont re-join to a table already in our from objects + # TODO: this code has a little bit of overlap with strategies.EagerLoader.AliasedClauses. 
possibly + # look into generalizing that functionality for usage in both places if prop.select_table not in currenttables or create_aliases: if outerjoin: if prop.secondary: @@ -346,10 +356,10 @@ class Query(object): if create_aliases: join = prop.get_join(mapper) if alias is not None: - join = sql_util.ClauseAdapter(alias).traverse(join, clone=True) + join = sql_util.ClauseAdapter(alias, exclude=prop.remote_side).traverse(join, clone=True) alias = prop.select_table.alias() aliases.append(alias) - join = sql_util.ClauseAdapter(alias).traverse(join, clone=True) + join = sql_util.ClauseAdapter(alias, exclude=prop.local_side).traverse(join, clone=True) clause = clause.join(alias, join) else: clause = clause.join(prop.select_table, prop.get_join(mapper)) @@ -447,7 +457,7 @@ class Query(object): q._group_by = q._group_by + util.to_list(criterion) return q - def join(self, prop, aliased=False): + def join(self, prop, aliased=False, id=None): """create a join of this ``Query`` object's criterion to a relationship and return the newly resulting ``Query``. @@ -460,9 +470,11 @@ class Query(object): q._from_obj = [clause] q._joinpoint = mapper q._aliases = aliases + if id: + q._alias_ids[id] = aliases[-1] return q - def outerjoin(self, prop, aliased=False): + def outerjoin(self, prop, aliased=False, id=None): """create a left outer join of this ``Query`` object's criterion to a relationship and return the newly resulting ``Query``. 
@@ -474,6 +486,8 @@ class Query(object): q._from_obj = [clause] q._joinpoint = mapper q._aliases = aliases + if id: + q._alias_ids[id] = aliases[-1] return q def reset_joinpoint(self): @@ -645,13 +659,23 @@ class Query(object): mappers_or_columns = tuple(self._entities) + mappers_or_columns if mappers_or_columns: for m in mappers_or_columns: + if isinstance(m, tuple): + (m, alias, alias_id) = m + if alias_id is not None: + try: + alias = self._alias_ids[alias_id] + except KeyError: + raise exceptions.InvalidRequestError("Query has no alias identified by '%s'" % alias_id) + else: + alias = alias_id = None if isinstance(m, type): m = mapper.class_mapper(m) if isinstance(m, mapper.Mapper): def x(m): + row_adapter = sql_util.create_row_adapter(alias, m.select_table) appender = [] def proc(context, row): - if not m._instance(context, row, appender): + if not m._instance(context, row_adapter(row), appender): appender.append(None) process.append((proc, appender)) x(m) @@ -865,11 +889,18 @@ class Query(object): # additional entities/columns, add those to selection criterion for m in self._entities: - if isinstance(m, type): - m = mapper.class_mapper(m) - if isinstance(m, mapper.Mapper): - for value in m.iterate_properties: - value.setup(context) + if isinstance(m, tuple): + (m, alias, alias_id) = m + if alias_id is not None: + try: + alias = self._alias_ids[alias_id] + except KeyError: + raise exceptions.InvalidRequestError("Query has no alias identified by '%s'" % alias_id) + if isinstance(m, type): + m = mapper.class_mapper(m) + if isinstance(m, mapper.Mapper): + for value in m.iterate_properties: + value.setup(context, eagertable=alias) elif isinstance(m, sql.ColumnElement): statement.append_column(m) diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 0ceb969559..3e26280fca 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -24,6 +24,8 @@ class ColumnLoader(LoaderStrategy): for c in self.columns: 
if parentclauses is not None: context.statement.append_column(parentclauses.aliased_column(c)) + elif eagertable is not None: + context.statement.append_column(eagertable.corresponding_column(c)) else: context.statement.append_column(c) @@ -493,19 +495,11 @@ class EagerLoader(AbstractRelationLoader): else: self.eagerprimary = eagerloader.polymorphic_primaryjoin - # for self-referential eager load, the "aliasing" of each side of the join condition - # must be limited to exactly the cols we know are on "our side". for non-self-referntial, - # be more liberal to include other elements of the join condition which deal with "our" table - if eagerloader.parent_property._is_self_referential(): - include = eagerloader.parent_property.remote_side - else: - include = None - if parentclauses is not None: - aliasizer = sql_util.ClauseAdapter(self.eagertarget, include=include) + aliasizer = sql_util.ClauseAdapter(self.eagertarget, exclude=eagerloader.parent_property.local_side) aliasizer.chain(sql_util.ClauseAdapter(parentclauses.eagertarget, exclude=eagerloader.parent_property.remote_side)) else: - aliasizer = sql_util.ClauseAdapter(self.eagertarget, include=include) + aliasizer = sql_util.ClauseAdapter(self.eagertarget, exclude=eagerloader.parent_property.local_side) self.eagerprimary = aliasizer.traverse(self.eagerprimary, clone=True) if eagerloader.order_by: @@ -513,7 +507,7 @@ class EagerLoader(AbstractRelationLoader): else: self.eager_order_by = None - self._row_decorator = self._create_decorator_row() + self._row_decorator = sql_util.create_row_adapter(self.eagertarget, self.target) def __str__(self): return "->".join([str(s) for s in self.path]) @@ -543,29 +537,6 @@ class EagerLoader(AbstractRelationLoader): self.extra_cols[column] = aliased_column return aliased_column - def _create_decorator_row(self): - class EagerRowAdapter(object): - def __init__(self, row): - self.row = row - def __contains__(self, key): - return key in map or key in self.row - def has_key(self, 
key): - return key in self - def __getitem__(self, key): - if key in map: - key = map[key] - return self.row[key] - def keys(self): - return map.keys() - map = {} - for c in self.eagertarget.c: - parent = self.target.corresponding_column(c) - map[parent] = c - map[parent._label] = c - map[parent.name] = c - EagerRowAdapter.map = map - return EagerRowAdapter - def init_class_attribute(self): self.parent_property._get_strategy(LazyLoader).init_class_attribute() @@ -639,7 +610,7 @@ class EagerLoader(AbstractRelationLoader): statement.append_from(statement._outerjoin) for value in self.select_mapper.iterate_properties: - value.setup(context, eagertable=clauses.eagertarget, parentclauses=clauses, parentmapper=self.select_mapper) + value.setup(context, parentclauses=clauses, parentmapper=self.select_mapper) def _create_row_decorator(self, selectcontext, row, path): """Create a *row decorating* function that will apply eager diff --git a/lib/sqlalchemy/sql_util.py b/lib/sqlalchemy/sql_util.py index 96dae5e0a2..78383c78ff 100644 --- a/lib/sqlalchemy/sql_util.py +++ b/lib/sqlalchemy/sql_util.py @@ -233,3 +233,39 @@ class ClauseAdapter(AbstractClauseProcessor): if newcol: return newcol return newcol + + +def create_row_adapter(alias, table): + """given a sql.Alias and a target selectable, return a callable which, + when passed a RowProxy, will return a new dict-like object + that translates Column objects to that of the Alias before calling upon the row. + + This allows a regular Table to be used to target columns in a row that was in reality generated from an alias + of that table, in such a way that the row can be passed to logic which knows nothing about the aliased form + of the table. 
+ """ + + if alias is None: + return lambda row:row + + class AliasedRowAdapter(object): + def __init__(self, row): + self.row = row + def __contains__(self, key): + return key in map or key in self.row + def has_key(self, key): + return key in self + def __getitem__(self, key): + if key in map: + key = map[key] + return self.row[key] + def keys(self): + return map.keys() + map = {} + for c in alias.c: + parent = table.corresponding_column(c) + map[parent] = c + map[parent._label] = c + map[parent.name] = c + AliasedRowAdapter.map = map + return AliasedRowAdapter diff --git a/test/orm/eager_relations.py b/test/orm/eager_relations.py index 4dcdeac37d..396c28bf94 100644 --- a/test/orm/eager_relations.py +++ b/test/orm/eager_relations.py @@ -148,7 +148,7 @@ class EagerTest(QueryTest): assert fixtures.user_address_result == sess.query(User).all() def test_double(self): - """tests lazy loading with two relations simulatneously, from the same table, using aliases. """ + """tests eager loading with two relations simulatneously, from the same table, using aliases. """ openorders = alias(orders, 'openorders') closedorders = alias(orders, 'closedorders') @@ -187,7 +187,7 @@ class EagerTest(QueryTest): self.assert_sql_count(testbase.db, go, 1) def test_double_same_mappers(self): - """tests lazy loading with two relations simulatneously, from the same table, using aliases. """ + """tests eager loading with two relations simulatneously, from the same table, using aliases. 
""" mapper(Address, addresses) mapper(Order, orders, properties={ diff --git a/test/orm/generative.py b/test/orm/generative.py index 3d8f2bb392..4924908449 100644 --- a/test/orm/generative.py +++ b/test/orm/generative.py @@ -7,6 +7,8 @@ from sqlalchemy.orm import * from sqlalchemy import exceptions from testbase import Table, Column +# TODO: these are more tests that should be updated to be part of test/orm/query.py + class Foo(object): def __init__(self, **kwargs): for k in kwargs: @@ -256,13 +258,13 @@ class SelfRefTest(ORMTest): sess.query(T).join('children').select_by(id=7) assert False except exceptions.InvalidRequestError, e: - assert str(e) == "Self-referential query on 'T.children (T)' property must be constructed manually using an Alias object for the related table.", str(e) + assert str(e) == "Self-referential query on 'T.children (T)' property requries create_aliases=True argument.", str(e) try: sess.query(T).join(['children']).select_by(id=7) assert False except exceptions.InvalidRequestError, e: - assert str(e) == "Self-referential query on 'T.children (T)' property must be constructed manually using an Alias object for the related table.", str(e) + assert str(e) == "Self-referential query on 'T.children (T)' property requries create_aliases=True argument.", str(e) diff --git a/test/orm/query.py b/test/orm/query.py index 6fa5c9644c..82bc0c6d41 100644 --- a/test/orm/query.py +++ b/test/orm/query.py @@ -429,6 +429,27 @@ class JoinTest(QueryTest): assert [] == q.all() assert q.count() == 0 + def test_aliased_add_entity(self): + """test the usage of aliased joins with add_entity()""" + sess = create_session() + q = sess.query(User).join('orders', aliased=True, id='order1').filter(Order.description=="order 3").join(['orders', 'items'], aliased=True, id='item1').filter(Item.description=="item 1") + + try: + q.add_entity(Order, id='fakeid').compile() + assert False + except exceptions.InvalidRequestError, e: + assert str(e) == "Query has no alias identified 
by 'fakeid'" + + try: + q.add_entity(Order, id='fakeid').instances(None) + assert False + except exceptions.InvalidRequestError, e: + assert str(e) == "Query has no alias identified by 'fakeid'" + + q = q.add_entity(Order, id='order1').add_entity(Item, id='item1') + assert q.count() == 1 + assert [(User(id=7), Order(description='order 3'), Item(description='item 1'))] == q.all() + class SynonymTest(QueryTest): keep_mappers = True @@ -573,6 +594,30 @@ class InstancesTest(QueryTest): q = sess.query(User, Address).join('addresses').options(eagerload('addresses')).filter_by(email_address='ed@bettyboop.com') assert q.all() == [(user8, address3)] + def test_aliased_multi_mappers(self): + sess = create_session() + + (user7, user8, user9, user10) = sess.query(User).all() + (address1, address2, address3, address4, address5) = sess.query(Address).all() + + # note the result is a cartesian product + expected = [(user7, address1), + (user8, address2), + (user8, address3), + (user8, address4), + (user9, address5), + (user10, None)] + + q = sess.query(User) + adalias = addresses.alias('adalias') + q = q.add_entity(Address, alias=adalias).select_from(users.outerjoin(adalias)) + l = q.all() + assert l == expected + + q = sess.query(User).add_entity(Address, alias=adalias) + l = q.select_from(users.outerjoin(adalias)).filter(adalias.c.email_address=='ed@bettyboop.com').all() + assert l == [(user8, address3)] + def test_multi_columns(self): sess = create_session() (user7, user8, user9, user10) = sess.query(User).all() @@ -612,8 +657,29 @@ class InstancesTest(QueryTest): assert q.all() == expected -# this test not working yet -class SelfReferentialTest(object): #testbase.ORMTest): +class CustomJoinTest(QueryTest): + keep_mappers = False + + def setup_mappers(self): + pass + + def test_double_same_mappers(self): + """test aliasing of joins with a custom join condition""" + mapper(Address, addresses) + mapper(Order, orders, properties={ + 'items':relation(Item, secondary=order_items, 
lazy=False, order_by=items.c.id), + }) + mapper(Item, items) + mapper(User, users, properties = dict( + addresses = relation(Address, lazy=False), + open_orders = relation(Order, primaryjoin = and_(orders.c.isopen == 1, users.c.id==orders.c.user_id), lazy=False), + closed_orders = relation(Order, primaryjoin = and_(orders.c.isopen == 0, users.c.id==orders.c.user_id), lazy=False) + )) + q = create_session().query(User) + + assert [User(id=7)] == q.join(['open_orders', 'items'], aliased=True).filter(Item.id==4).join(['closed_orders', 'items'], aliased=True).filter(Item.id==3).all() + +class SelfReferentialJoinTest(testbase.ORMTest): def define_tables(self, metadata): global nodes nodes = Table('nodes', metadata, @@ -647,6 +713,9 @@ class SelfReferentialTest(object): #testbase.ORMTest): node = sess.query(Node).join('children', aliased=True).filter_by(data='n122').first() assert node.data=='n12' + node = sess.query(Node).join(['children', 'children'], aliased=True).filter_by(data='n122').first() + assert node.data=='n1' + if __name__ == '__main__': testbase.main()