the relationship.
- eager loading is slightly more strict about detecting "self-referential"
relationships, specifically between polymorphic mappers.
+ - improved support for complex queries embedded into "where" criterion
+ for query.select() [ticket:449]
+ - contains_eager('foo') automatically implies eagerload('foo')
- fixed bug where cascade operations incorrectly included deleted collection
items in the cascade [ticket:445]
- fix to deferred so that load operation doesnt mistakenly occur when only
a custom row decorator.
used when feeding SQL result sets directly into
- query.instances()."""
- return strategies.RowDecorateOption(key, decorator=decorator)
+ query.instances(). Also bundles an EagerLazyOption to turn on eager loading in case it isnt already."""
+ return (strategies.EagerLazyOption(key, lazy=False), strategies.RowDecorateOption(key, decorator=decorator))
def defer(name):
"""return a MapperOption that will convert the column property of the given
self.options = options
self.attributes = {}
self.recursion_stack = util.Set()
- for opt in options:
+ for opt in util.flatten_iterator(options):
self.accept_option(opt)
def accept_option(self, opt):
pass
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-from sqlalchemy import sql, util, exceptions, sql_util, logging
+from sqlalchemy import sql, util, exceptions, sql_util, logging, schema
from sqlalchemy.orm import mapper, class_mapper
from sqlalchemy.orm.interfaces import OperationContext
_get_clause.clauses.append(primary_key == sql.bindparam(primary_key._label, type=primary_key.type))
self.mapper._get_clause = _get_clause
self._get_clause = self.mapper._get_clause
- for opt in self.with_options:
+ for opt in util.flatten_iterator(self.with_options):
opt.process_query(self)
def _insert_extension(self, ext):
if order_by:
order_by = util.to_list(order_by) or []
cf = sql_util.ColumnFinder()
- [o.accept_visitor(cf) for o in order_by]
+ for o in order_by:
+ o.accept_visitor(cf)
else:
cf = []
s2.order_by(*util.to_list(order_by))
s3 = s2.alias('tbl_row_count')
crit = s3.primary_key==self.table.primary_key
- statement = sql.select([], crit, from_obj=[self.table], use_labels=True, for_update=for_update)
+ statement = sql.select([], crit, use_labels=True, for_update=for_update)
# now for the order by, convert the columns to their corresponding columns
# in the "rowcount" query, and tack that new order by onto the "rowcount" query
if order_by:
- class Aliasizer(sql_util.Aliasizer):
- def get_alias(self, table):
- return s3
- order_by = [o.copy_container() for o in order_by]
- aliasizer = Aliasizer(*[t for t in sql_util.TableFinder(s3)])
- [o.accept_visitor(aliasizer) for o in order_by]
- statement.order_by(*util.to_list(order_by))
+ statement.order_by(*sql_util.ClauseAdapter(s3).copy_and_process(order_by))
else:
statement = sql.select([], whereclause, from_obj=from_obj, use_labels=True, for_update=for_update, **context.select_args())
if order_by:
def _aliasize_orderby(self, orderby, copy=True):
if copy:
- orderby = [o.copy_container() for o in util.to_list(orderby)]
+ return self.aliasizer.copy_and_process(util.to_list(orderby))
else:
orderby = util.to_list(orderby)
- for i in range(0, len(orderby)):
- if isinstance(orderby[i], schema.Column):
- orderby[i] = self.eagertarget.corresponding_column(orderby[i])
- else:
- orderby[i].accept_visitor(self.aliasizer)
- return orderby
+ self.aliasizer.process_list(orderby)
+ return orderby
def _create_decorator_row(self):
class EagerRowAdapter(object):
self.columns.add(c)
def __iter__(self):
return iter(self.columns)
-
-class Aliasizer(sql.ClauseVisitor):
+
+class ColumnsInClause(sql.ClauseVisitor):
+ """given a selectable, visits clauses and determines if any columns from the clause are in the selectable"""
+ def __init__(self, selectable):
+ self.selectable = selectable
+ self.result = False
+ def visit_column(self, column):
+ if self.selectable.c.get(column.key) is column:
+ self.result = True
+
+class AbstractClauseProcessor(sql.ClauseVisitor):
+ """traverses a clause and attempts to convert the contents of container elements
+ to a converted element. the conversion operation is defined by subclasses."""
+ def convert_element(self, elem):
+ """define the 'conversion' method for this AbstractClauseProcessor"""
+ raise NotImplementedError()
+ def copy_and_process(self, list_):
+ """copy the container elements in the given list to a new list and
+ process the new list."""
+ list_ = [o.copy_container() for o in list_]
+ self.process_list(list_)
+ return list_
+
+ def process_list(self, list_):
+ """process all elements of the given list in-place"""
+ for i in range(0, len(list_)):
+ elem = self.convert_element(list_[i])
+ if elem is not None:
+ list_[i] = elem
+ else:
+ list_[i].accept_visitor(self)
+ def visit_compound(self, compound):
+ self.visit_clauselist(compound)
+ def visit_clauselist(self, clist):
+ for i in range(0, len(clist.clauses)):
+ n = self.convert_element(clist.clauses[i])
+ if n is not None:
+ clist.clauses[i] = n
+ def visit_binary(self, binary):
+ elem = self.convert_element(binary.left)
+ if elem is not None:
+ binary.left = elem
+ elem = self.convert_element(binary.right)
+ if elem is not None:
+ binary.right = elem
+
+class Aliasizer(AbstractClauseProcessor):
"""converts a table instance within an expression to be an alias of that table."""
def __init__(self, *tables, **kwargs):
self.tables = {}
self.binary = None
def get_alias(self, table):
return self.aliases[table]
- def visit_compound(self, compound):
- self.visit_clauselist(compound)
- def visit_clauselist(self, clist):
- for i in range(0, len(clist.clauses)):
- if isinstance(clist.clauses[i], schema.Column) and self.tables.has_key(clist.clauses[i].table):
- orig = clist.clauses[i]
- clist.clauses[i] = self.get_alias(clist.clauses[i].table).corresponding_column(clist.clauses[i])
- def visit_binary(self, binary):
- if isinstance(binary.left, schema.Column) and self.tables.has_key(binary.left.table):
- binary.left = self.get_alias(binary.left.table).corresponding_column(binary.left)
- if isinstance(binary.right, schema.Column) and self.tables.has_key(binary.right.table):
- binary.right = self.get_alias(binary.right.table).corresponding_column(binary.right)
-
+ def convert_element(self, elem):
+ if isinstance(elem, sql.ColumnElement) and hasattr(elem, 'table') and self.tables.has_key(elem.table):
+ return self.get_alias(elem.table).corresponding_column(elem)
+ else:
+ return None
-class ClauseAdapter(sql.ClauseVisitor):
+class ClauseAdapter(AbstractClauseProcessor):
"""given a clause (like as in a WHERE criterion), locates columns which 'correspond' to a given selectable,
and changes those columns to be that of the selectable.
self.include = include
self.exclude = exclude
self.equivalents = equivalents
- def include_col(self, col):
+
+ def convert_element(self, col):
if not isinstance(col, sql.ColumnElement):
return None
if self.include is not None:
if newcol is None and self.equivalents is not None and col in self.equivalents:
newcol = self.selectable.corresponding_column(self.equivalents[col], raiseerr=False, keys_ok=False)
return newcol
- def visit_binary(self, binary):
- col = self.include_col(binary.left)
- if col is not None:
- binary.left = col
- col = self.include_col(binary.right)
- if col is not None:
- binary.right = col
-
-class ColumnsInClause(sql.ClauseVisitor):
- """given a selectable, visits clauses and determines if any columns from the clause are in the selectable"""
- def __init__(self, selectable):
- self.selectable = selectable
- self.result = False
- def visit_column(self, column):
- if self.selectable.c.get(column.key) is column:
- self.result = True
+
else:
return x
+def flatten_iterator(x):
+ """given an iterator of which further sub-elements may also be iterators,
+ flatten the sub-elements into a single iterator."""
+ for elem in x:
+ if hasattr(elem, '__iter__'):
+ for y in flatten_iterator(elem):
+ yield y
+ else:
+ yield elem
+
def reversed(seq):
try:
return __builtin__.reversed(seq)
import testbase
from sqlalchemy import *
from sqlalchemy.ext.selectresults import SelectResults
+import random
class EagerTest(AssertMixin):
def setUpAll(self):
session.clear()
obj = session.query(Left).get_by(tag='tag1')
print obj.middle.right[0]
+
+class EagerTest3(testbase.ORMTest):
+ """test eager loading combined with nested SELECT statements, functions, and aggregates"""
+ def define_tables(self, metadata):
+ global datas, foo, stats
+ datas=Table( 'datas',metadata,
+ Column ( 'id', Integer, primary_key=True,nullable=False ),
+ Column ( 'a', Integer , nullable=False ) )
+
+ foo=Table('foo',metadata,
+ Column ( 'data_id', Integer, ForeignKey('datas.id'),nullable=False,primary_key=True ),
+ Column ( 'bar', Integer ) )
+
+ stats=Table('stats',metadata,
+ Column ( 'id', Integer, primary_key=True, nullable=False ),
+ Column ( 'data_id', Integer, ForeignKey('datas.id')),
+ Column ( 'somedata', Integer, nullable=False ))
+
+ def test_nesting_with_functions(self):
+ class Data(object): pass
+ class Foo(object):pass
+ class Stat(object): pass
+
+ Data.mapper=mapper(Data,datas)
+ Foo.mapper=mapper(Foo,foo,properties={'data':relation(Data,backref=backref('foo',uselist=False))})
+ Stat.mapper=mapper(Stat,stats,properties={'data':relation(Data)})
+
+ s=create_session()
+ data = []
+ for x in range(5):
+ d=Data()
+ d.a=x
+ s.save(d)
+ data.append(d)
+
+ for x in range(10):
+ rid=random.randint(0,len(data) - 1)
+ somedata=random.randint(1,50000)
+ stat=Stat()
+ stat.data = data[rid]
+ stat.somedata=somedata
+ s.save(stat)
+
+ s.flush()
+
+ arb_data=select(
+ [stats.c.data_id,func.max(stats.c.somedata).label('max')],
+ stats.c.data_id<=25,
+ group_by=[stats.c.data_id]).alias('arb')
+
+ arb_result = arb_data.execute().fetchall()
+ # order the result list descending based on 'max'
+ arb_result.sort(lambda a, b:cmp(b['max'],a['max']))
+ # extract just the "data_id" from it
+ arb_result = [row['data_id'] for row in arb_result]
+
+ # now query for Data objects using that above select, adding the
+ # "order by max desc" separately
+ q=s.query(Data).options(eagerload('foo')).select(
+ from_obj=[datas.join(arb_data,arb_data.c.data_id==datas.c.id)],
+ order_by=[desc(arb_data.c.max)],limit=10)
+
+ # extract "data_id" from the list of result objects
+ verify_result = [d.id for d in q]
+ # assert equality including ordering (may break if the DB "ORDER BY" and python's sort() used differing
+ # algorithms and there are repeated 'somedata' values in the list)
+ assert verify_result == arb_result
if __name__ == "__main__":
testbase.main()
def testcustomeagerquery(self):
mapper(User, users, properties={
- 'addresses':relation(Address, lazy=False)
+ # setting lazy=True - the contains_eager() option below
+ # should imply eagerload()
+ 'addresses':relation(Address, lazy=True)
})
mapper(Address, addresses)