From 72f479b324c0b62a19e58c7b173a62b55c34a928 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Thu, 1 Feb 2007 01:47:54 +0000 Subject: [PATCH] - improved support for complex queries embedded into "where" criterion for query.select() [ticket:449] - contains_eager('foo') automatically implies eagerload('foo') - query.options() can take a combiantion MapperOptions and tuples of MapperOptions, so that functions can return groups - refactoring to Aliasizer and ClauseAdapter so that they share a common base methodology, which addresses all sql.ColumnElements instead of just schema.Column. common list-processing methods added. - query.compile and eagerloader._aliasize_orderby make usage of improved list processing on above. - query.compile, within the "nested select generate" step processes the order_by clause using the ClauseAdapter instead of Aliasizer since there is only one "target" --- CHANGES | 3 ++ lib/sqlalchemy/orm/__init__.py | 4 +- lib/sqlalchemy/orm/interfaces.py | 2 +- lib/sqlalchemy/orm/query.py | 17 +++--- lib/sqlalchemy/orm/strategies.py | 10 ++-- lib/sqlalchemy/sql_util.py | 89 ++++++++++++++++++++------------ lib/sqlalchemy/util.py | 10 ++++ test/orm/eagertest3.py | 68 ++++++++++++++++++++++++ test/orm/mapper.py | 4 +- 9 files changed, 152 insertions(+), 55 deletions(-) diff --git a/CHANGES b/CHANGES index 38474189e3..4ff29d3b00 100644 --- a/CHANGES +++ b/CHANGES @@ -15,6 +15,9 @@ the relationship. - eager loading is slightly more strict about detecting "self-referential" relationships, specifically between polymorphic mappers. + - improved support for complex queries embedded into "where" criterion + for query.select() [ticket:449] + - contains_eager('foo') automatically implies eagerload('foo') - fixed bug where cascade operations incorrectly included deleted collection items in the cascade [ticket:445] - fix to deferred so that load operation doesnt mistakenly occur when only diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py index 1e1a75b631..a933cb1b27 100644 --- a/lib/sqlalchemy/orm/__init__.py +++ b/lib/sqlalchemy/orm/__init__.py @@ -124,8 +124,8 @@ def contains_eager(key, decorator=None): a custom row decorator. used when feeding SQL result sets directly into - query.instances().""" - return strategies.RowDecorateOption(key, decorator=decorator) + query.instances(). Also bundles an EagerLazyOption to turn on eager loading in case it isnt already.""" + return (strategies.EagerLazyOption(key, lazy=False), strategies.RowDecorateOption(key, decorator=decorator)) def defer(name): """return a MapperOption that will convert the column property of the given diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index 959c9c36b3..f40856b41a 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -101,7 +101,7 @@ class OperationContext(object): self.options = options self.attributes = {} self.recursion_stack = util.Set() - for opt in options: + for opt in util.flatten_iterator(options): self.accept_option(opt) def accept_option(self, opt): pass diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index bc047ff504..adec69116b 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -4,7 +4,7 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -from sqlalchemy import sql, util, exceptions, sql_util, logging +from sqlalchemy import sql, util, exceptions, sql_util, logging, schema from sqlalchemy.orm import mapper, class_mapper from sqlalchemy.orm.interfaces import OperationContext @@ -34,7 +34,7 @@ class Query(object): _get_clause.clauses.append(primary_key == sql.bindparam(primary_key._label, type=primary_key.type)) self.mapper._get_clause = _get_clause self._get_clause = self.mapper._get_clause - for opt in self.with_options: + for opt in util.flatten_iterator(self.with_options): opt.process_query(self) def _insert_extension(self, ext): @@ -440,7 +440,8 @@ class Query(object): if order_by: order_by = util.to_list(order_by) or [] cf = sql_util.ColumnFinder() - [o.accept_visitor(cf) for o in order_by] + for o in order_by: + o.accept_visitor(cf) else: cf = [] @@ -449,17 +450,11 @@ class Query(object): s2.order_by(*util.to_list(order_by)) s3 = s2.alias('tbl_row_count') crit = s3.primary_key==self.table.primary_key - statement = sql.select([], crit, from_obj=[self.table], use_labels=True, for_update=for_update) + statement = sql.select([], crit, use_labels=True, for_update=for_update) # now for the order by, convert the columns to their corresponding columns # in the "rowcount" query, and tack that new order by onto the "rowcount" query if order_by: - class Aliasizer(sql_util.Aliasizer): - def get_alias(self, table): - return s3 - order_by = [o.copy_container() for o in order_by] - aliasizer = Aliasizer(*[t for t in sql_util.TableFinder(s3)]) - [o.accept_visitor(aliasizer) for o in order_by] - statement.order_by(*util.to_list(order_by)) + statement.order_by(*sql_util.ClauseAdapter(s3).copy_and_process(order_by)) else: statement = sql.select([], whereclause, from_obj=from_obj, use_labels=True, for_update=for_update, **context.select_args()) if order_by: diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 50a2ea27d2..29b60a8f61 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -419,15 +419,11 @@ class EagerLoader(AbstractRelationLoader): def _aliasize_orderby(self, orderby, copy=True): if copy: - orderby = [o.copy_container() for o in util.to_list(orderby)] + return self.aliasizer.copy_and_process(util.to_list(orderby)) else: orderby = util.to_list(orderby) - for i in range(0, len(orderby)): - if isinstance(orderby[i], schema.Column): - orderby[i] = self.eagertarget.corresponding_column(orderby[i]) - else: - orderby[i].accept_visitor(self.aliasizer) - return orderby + self.aliasizer.process_list(orderby) + return orderby def _create_decorator_row(self): class EagerRowAdapter(object): diff --git a/lib/sqlalchemy/sql_util.py b/lib/sqlalchemy/sql_util.py index 10d4495d93..e672feb1a9 100644 --- a/lib/sqlalchemy/sql_util.py +++ b/lib/sqlalchemy/sql_util.py @@ -78,8 +78,53 @@ class ColumnFinder(sql.ClauseVisitor): self.columns.add(c) def __iter__(self): return iter(self.columns) - -class Aliasizer(sql.ClauseVisitor): + +class ColumnsInClause(sql.ClauseVisitor): + """given a selectable, visits clauses and determines if any columns from the clause are in the selectable""" + def __init__(self, selectable): + self.selectable = selectable + self.result = False + def visit_column(self, column): + if self.selectable.c.get(column.key) is column: + self.result = True + +class AbstractClauseProcessor(sql.ClauseVisitor): + """traverses a clause and attempts to convert the contents of container elements + to a converted element. the conversion operation is defined by subclasses.""" + def convert_element(self, elem): + """define the 'conversion' method for this AbstractClauseProcessor""" + raise NotImplementedError() + def copy_and_process(self, list_): + """copy the container elements in the given list to a new list and + process the new list.""" + list_ = [o.copy_container() for o in list_] + self.process_list(list_) + return list_ + + def process_list(self, list_): + """process all elements of the given list in-place""" + for i in range(0, len(list_)): + elem = self.convert_element(list_[i]) + if elem is not None: + list_[i] = elem + else: + list_[i].accept_visitor(self) + def visit_compound(self, compound): + self.visit_clauselist(compound) + def visit_clauselist(self, clist): + for i in range(0, len(clist.clauses)): + n = self.convert_element(clist.clauses[i]) + if n is not None: + clist.clauses[i] = n + def visit_binary(self, binary): + elem = self.convert_element(binary.left) + if elem is not None: + binary.left = elem + elem = self.convert_element(binary.right) + if elem is not None: + binary.right = elem + +class Aliasizer(AbstractClauseProcessor): """converts a table instance within an expression to be an alias of that table.""" def __init__(self, *tables, **kwargs): self.tables = {} @@ -95,21 +140,13 @@ class Aliasizer(sql.ClauseVisitor): self.binary = None def get_alias(self, table): return self.aliases[table] - def visit_compound(self, compound): - self.visit_clauselist(compound) - def visit_clauselist(self, clist): - for i in range(0, len(clist.clauses)): - if isinstance(clist.clauses[i], schema.Column) and self.tables.has_key(clist.clauses[i].table): - orig = clist.clauses[i] - clist.clauses[i] = self.get_alias(clist.clauses[i].table).corresponding_column(clist.clauses[i]) - def visit_binary(self, binary): - if isinstance(binary.left, schema.Column) and self.tables.has_key(binary.left.table): - binary.left = self.get_alias(binary.left.table).corresponding_column(binary.left) - if isinstance(binary.right, schema.Column) and self.tables.has_key(binary.right.table): - binary.right = self.get_alias(binary.right.table).corresponding_column(binary.right) - + def convert_element(self, elem): + if isinstance(elem, sql.ColumnElement) and hasattr(elem, 'table') and self.tables.has_key(elem.table): + return self.get_alias(elem.table).corresponding_column(elem) + else: + return None -class ClauseAdapter(sql.ClauseVisitor): +class ClauseAdapter(AbstractClauseProcessor): """given a clause (like as in a WHERE criterion), locates columns which 'correspond' to a given selectable, and changes those columns to be that of the selectable. @@ -140,7 +177,8 @@ class ClauseAdapter(sql.ClauseVisitor): self.include = include self.exclude = exclude self.equivalents = equivalents - def include_col(self, col): + + def convert_element(self, col): if not isinstance(col, sql.ColumnElement): return None if self.include is not None: @@ -153,19 +191,4 @@ class ClauseAdapter(sql.ClauseVisitor): if newcol is None and self.equivalents is not None and col in self.equivalents: newcol = self.selectable.corresponding_column(self.equivalents[col], raiseerr=False, keys_ok=False) return newcol - def visit_binary(self, binary): - col = self.include_col(binary.left) - if col is not None: - binary.left = col - col = self.include_col(binary.right) - if col is not None: - binary.right = col - -class ColumnsInClause(sql.ClauseVisitor): - """given a selectable, visits clauses and determines if any columns from the clause are in the selectable""" - def __init__(self, selectable): - self.selectable = selectable - self.result = False - def visit_column(self, column): - if self.selectable.c.get(column.key) is column: - self.result = True + diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index 54b1afa9fb..c2e0dbc45f 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -35,6 +35,16 @@ def to_set(x): else: return x +def flatten_iterator(x): + """given an iterator of which further sub-elements may also be iterators, + flatten the sub-elements into a single iterator.""" + for elem in x: + if hasattr(elem, '__iter__'): + for y in flatten_iterator(elem): + yield y + else: + yield elem + def reversed(seq): try: return __builtin__.reversed(seq) diff --git a/test/orm/eagertest3.py b/test/orm/eagertest3.py index 10f8fce0a3..e33ce43944 100644 --- a/test/orm/eagertest3.py +++ b/test/orm/eagertest3.py @@ -2,6 +2,7 @@ from testbase import PersistTest, AssertMixin import testbase from sqlalchemy import * from sqlalchemy.ext.selectresults import SelectResults +import random class EagerTest(AssertMixin): def setUpAll(self): @@ -197,7 +198,74 @@ class EagerTest2(AssertMixin): session.clear() obj = session.query(Left).get_by(tag='tag1') print obj.middle.right[0] + +class EagerTest3(testbase.ORMTest): + """test eager loading combined with nested SELECT statements, functions, and aggregates""" + def define_tables(self, metadata): + global datas, foo, stats + datas=Table( 'datas',metadata, + Column ( 'id', Integer, primary_key=True,nullable=False ), + Column ( 'a', Integer , nullable=False ) ) + + foo=Table('foo',metadata, + Column ( 'data_id', Integer, ForeignKey('datas.id'),nullable=False,primary_key=True ), + Column ( 'bar', Integer ) ) + + stats=Table('stats',metadata, + Column ( 'id', Integer, primary_key=True, nullable=False ), + Column ( 'data_id', Integer, ForeignKey('datas.id')), + Column ( 'somedata', Integer, nullable=False )) + + def test_nesting_with_functions(self): + class Data(object): pass + class Foo(object):pass + class Stat(object): pass + + Data.mapper=mapper(Data,datas) + Foo.mapper=mapper(Foo,foo,properties={'data':relation(Data,backref=backref('foo',uselist=False))}) + Stat.mapper=mapper(Stat,stats,properties={'data':relation(Data)}) + + s=create_session() + data = [] + for x in range(5): + d=Data() + d.a=x + s.save(d) + data.append(d) + + for x in range(10): + rid=random.randint(0,len(data) - 1) + somedata=random.randint(1,50000) + stat=Stat() + stat.data = data[rid] + stat.somedata=somedata + s.save(stat) + + s.flush() + + arb_data=select( + [stats.c.data_id,func.max(stats.c.somedata).label('max')], + stats.c.data_id<=25, + group_by=[stats.c.data_id]).alias('arb') + + arb_result = arb_data.execute().fetchall() + # order the result list descending based on 'max' + arb_result.sort(lambda a, b:cmp(b['max'],a['max'])) + # extract just the "data_id" from it + arb_result = [row['data_id'] for row in arb_result] + + # now query for Data objects using that above select, adding the + # "order by max desc" separately + q=s.query(Data).options(eagerload('foo')).select( + from_obj=[datas.join(arb_data,arb_data.c.data_id==datas.c.id)], + order_by=[desc(arb_data.c.max)],limit=10) + + # extract "data_id" from the list of result objects + verify_result = [d.id for d in q] + # assert equality including ordering (may break if the DB "ORDER BY" and python's sort() used differing + # algorithms and there are repeated 'somedata' values in the list) + assert verify_result == arb_result if __name__ == "__main__": testbase.main() diff --git a/test/orm/mapper.py b/test/orm/mapper.py index 479d484539..46854b2cea 100644 --- a/test/orm/mapper.py +++ b/test/orm/mapper.py @@ -1028,7 +1028,9 @@ class EagerTest(MapperSuperTest): def testcustomeagerquery(self): mapper(User, users, properties={ - 'addresses':relation(Address, lazy=False) + # setting lazy=True - the contains_eager() option below + # should imply eagerload() + 'addresses':relation(Address, lazy=True) }) mapper(Address, addresses) -- 2.47.2