path = path + (self.key,)
# check for user-defined eager alias
- if ("eager_row_processor", path) in context.attributes:
- clauses = context.attributes[("eager_row_processor", path)]
+ if ("user_defined_eager_row_processor", path) in context.attributes:
+ clauses = context.attributes[("user_defined_eager_row_processor", path)]
adapter = entity._get_entity_clauses(context.query, context)
if adapter and clauses:
- context.attributes[("eager_row_processor", path)] = clauses = adapter.wrap(clauses)
+ context.attributes[("user_defined_eager_row_processor", path)] = clauses = clauses.wrap(adapter)
elif adapter:
- context.attributes[("eager_row_processor", path)] = clauses = adapter
-
+ context.attributes[("user_defined_eager_row_processor", path)] = clauses = adapter
+
+ add_to_collection = context.primary_columns
+
else:
clauses = self._create_eager_join(context, entity, path, adapter, parentmapper)
if not clauses:
return
context.attributes[("eager_row_processor", path)] = clauses
+
+ add_to_collection = context.secondary_columns
for value in self.mapper._iterate_polymorphic_properties():
- value.setup(context, entity, path + (self.mapper.base_mapper,), clauses, parentmapper=self.mapper, column_collection=context.secondary_columns)
+ value.setup(context, entity, path + (self.mapper.base_mapper,), clauses, parentmapper=self.mapper, column_collection=add_to_collection)
def _create_eager_join(self, context, entity, path, adapter, parentmapper):
# check for join_depth or basic recursion,
return clauses
def _create_eager_adapter(self, context, row, adapter, path):
- if ("eager_row_processor", path) in context.attributes:
+ if ("user_defined_eager_row_processor", path) in context.attributes:
+ decorator = context.attributes[("user_defined_eager_row_processor", path)]
+ # user defined eagerloads are part of the "primary" portion of the load.
+ # the adapters applied to the Query should be honored.
+ if context.adapter and decorator:
+ decorator = decorator.wrap(context.adapter)
+ elif context.adapter:
+ decorator = context.adapter
+ elif ("eager_row_processor", path) in context.attributes:
decorator = context.attributes[("eager_row_processor", path)]
else:
if self._should_log_debug:
prop = mapper.get_property(propname, resolve_synonyms=True)
self.alias = prop.target.alias(self.alias)
- query._attributes[("eager_row_processor", paths[-1])] = sql_util.ColumnAdapter(self.alias)
+ query._attributes[("user_defined_eager_row_processor", paths[-1])] = sql_util.ColumnAdapter(self.alias)
else:
- query._attributes[("eager_row_processor", paths[-1])] = None
+ query._attributes[("user_defined_eager_row_processor", paths[-1])] = None
assert fixtures.user_address_result == l
self.assert_sql_count(testing.db, go, 1)
+ # same thing, but alias addresses, so that the adapter generated by select_from() is wrapped within
+ # the adapter created by contains_eager()
+ adalias = addresses.alias()
+ query = users.select(users.c.id==7).union(users.select(users.c.id>7)).alias('ulist').outerjoin(adalias).select(use_labels=True,order_by=['ulist.id', adalias.c.id])
+ def go():
+ l = sess.query(User).select_from(query).options(contains_eager('addresses', alias=adalias)).all()
+ assert fixtures.user_address_result == l
+ self.assert_sql_count(testing.db, go, 1)
+
def test_contains_eager(self):
sess = create_session()
# test that contains_eager suppresses the normal outer join rendering
q = sess.query(User).outerjoin(User.addresses).options(contains_eager(User.addresses)).order_by(User.id)
- self.assert_compile(q.with_labels().statement, "SELECT users.id AS users_id, users.name AS users_name, "\
- "addresses.id AS addresses_id, addresses.user_id AS addresses_user_id, "\
- "addresses.email_address AS addresses_email_address FROM users LEFT OUTER JOIN addresses "\
- "ON users.id = addresses.user_id ORDER BY users.id", dialect=default.DefaultDialect())
-
+ self.assert_compile(q.with_labels().statement,
+ "SELECT addresses.id AS addresses_id, addresses.user_id AS addresses_user_id, "\
+ "addresses.email_address AS addresses_email_address, users.id AS users_id, "\
+ "users.name AS users_name FROM users LEFT OUTER JOIN addresses "\
+ "ON users.id = addresses.user_id ORDER BY users.id"
+ , dialect=default.DefaultDialect())
+
def go():
assert fixtures.user_address_result == q.all()
self.assert_sql_count(testing.db, go, 1)
sess.clear()
-
def go():
l = list(q.options(contains_eager(User.addresses)).instances(selectquery.execute()))
assert fixtures.user_address_result[0:3] == l
assert fixtures.user_address_result == l.all()
self.assert_sql_count(testing.db, go, 1)
sess.clear()
-
oalias = orders.alias('o1')
ialias = items.alias('i1')
self.assert_sql_count(testing.db, go, 1)
sess.clear()
+ def test_mixed_eager_contains_with_limit(self):
+ sess = create_session()
+
+ q = sess.query(User)
+ def go():
+ # outerjoin to User.orders, offset 1/limit 2 so we get user 7 + second two orders.
+ # then eagerload the addresses. User + Order columns go into the subquery, address
+ # left outer joins to the subquery, eagerloader for User.orders applies context.adapter
+ # to result rows. This was [ticket:1180].
+ l = q.outerjoin(User.orders).options(eagerload(User.addresses), contains_eager(User.orders)).offset(1).limit(2).all()
+ eq_(l, [User(id=7,
+ addresses=[Address(email_address=u'jack@bean.com',user_id=7,id=1)],
+ name=u'jack',
+ orders=[
+ Order(address_id=1,user_id=7,description=u'order 3',isopen=1,id=3),
+ Order(address_id=None,user_id=7,description=u'order 5',isopen=0,id=5)
+ ])])
+ self.assert_sql_count(testing.db, go, 1)
+ sess.clear()
+
+ def go():
+ # same as above, except Order is aliased, so two adapters are applied by the
+ # eager loader
+ oalias = aliased(Order)
+ l = q.outerjoin(User.orders, oalias).options(eagerload(User.addresses), contains_eager(User.orders, alias=oalias)).offset(1).limit(2).all()
+ eq_(l, [User(id=7,
+ addresses=[Address(email_address=u'jack@bean.com',user_id=7,id=1)],
+ name=u'jack',
+ orders=[
+ Order(address_id=1,user_id=7,description=u'order 3',isopen=1,id=3),
+ Order(address_id=None,user_id=7,description=u'order 5',isopen=0,id=5)
+ ])])
+ self.assert_sql_count(testing.db, go, 1)
+
+
class MixedEntitiesTest(QueryTest):
def test_values(self):