I don't like how difficult it was to get Query() to do it, however.
self._polymorphic_adapters[m.mapped_table] = self._polymorphic_adapters[m.local_table] = adapter
def _set_select_from(self, *obj):
-
+
fa = []
for from_obj in obj:
if isinstance(from_obj, expression._SelectBaseMixin):
self._from_obj = tuple(fa)
- # TODO: only use this adapter for from_self() ? right
- # now its usage is somewhat arbitrary.
- if len(self._from_obj) == 1 and isinstance(self._from_obj[0], expression.Alias):
+ if len(self._from_obj) == 1 and \
+ isinstance(self._from_obj[0], expression.Alias):
equivs = self.__all_equivs()
self._from_obj_alias = sql_util.ColumnAdapter(self._from_obj[0], equivs)
if entities:
q._set_entities(entities)
return q
-
+
@_generative()
def _from_selectable(self, fromclause):
for attr in ('_statement', '_criterion', '_order_by', '_group_by',
self._with_polymorphic = with_polymorphic
self._polymorphic_discriminator = None
self.is_aliased_class = is_aliased_class
+ self.disable_aliasing = False
if is_aliased_class:
self.path_entity = self.entity = self.entity_zero = entity
else:
query._entities.append(self)
def _get_entity_clauses(self, query, context):
-
+ if self.disable_aliasing:
+ return None
+
adapter = None
if not self.is_aliased_class and query._polymorphic_adapters:
adapter = query._polymorphic_adapters.get(self.mapper, None)
def __str__(self):
return str(self.mapper)
-
class _ColumnEntity(_QueryEntity):
"""Column/expression based entity."""
for c in leftmost_cols
]
- local_attr = [
- self.parent._get_col_to_prop(c).class_attribute
- for c in local_cols
- ]
-
# modify the query to just look for parent columns in the
# join condition
q._attributes[('subquery_path', None)] = subq_path
# now select from it as a subquery.
- q = q.from_self(self.mapper, *local_attr)
+ local_attr = [
+ self.parent._get_col_to_prop(c).class_attribute
+ for c in local_cols
+ ]
+
+ q = q.from_self(self.mapper)
+ q._entities[0].disable_aliasing = True
- # and join to the related thing we want
- # to load.
- for mapper, key in [(subq_path[i], subq_path[i+1])
- for i in xrange(0, len(subq_path), 2)]:
+ to_join = [(subq_path[i], subq_path[i+1])
+ for i in xrange(0, len(subq_path), 2)]
+
+ for i, (mapper, key) in enumerate(to_join):
+ alias_join = i < len(to_join) - 1
+ second_to_last = i == len(to_join) - 2
+
prop = mapper.get_property(key)
- q = q.join(prop.class_attribute)
+ q = q.join(prop.class_attribute, aliased=alias_join)
- #join_on = [(subq_path[i], subq_path[i+1])
- # for i in xrange(0, len(subq_path), 2)]
- #for i, (mapper, key) in enumerate(join_on):
- # aliased = i != len(join_on) - 1
- # prop = mapper.get_property(key)
- # q = q.join(prop.class_attribute, aliased=aliased)
-
- q = q.order_by(*local_attr)
+ if alias_join and second_to_last:
+ cols = [
+ q._adapt_clause(col, True, False)
+ for col in local_cols
+ ]
+ for col in cols:
+ q = q.add_column(col)
+ q = q.order_by(*cols)
+ if len(to_join) < 2:
+ local_attr = [
+ self.parent._get_col_to_prop(c).class_attribute
+ for c in local_cols
+ ]
+
+ for col in local_attr:
+ q = q.add_column(col)
+ q = q.order_by(*local_attr)
+
+
# propagate loader options etc. to the new query
q = q._with_current_path(subq_path)
q = q._conditional_options(*orig_query._with_options)
local_cols, remote_cols = self._local_remote_columns(self.parent_property)
- local_attr = [self.parent._get_col_to_prop(c).key for c in local_cols]
remote_attr = [
self.mapper._get_col_to_prop(c).key
for c in remote_cols]
])
self.assert_sql_count(testing.db, go, 2)
-class SelfReferentialEagerTest(_base.MappedTest):
+class SelfReferentialTest(_base.MappedTest):
@classmethod
def define_tables(cls, metadata):
Table('nodes', metadata,
@testing.fails_on('maxdb', 'FIXME: unknown')
@testing.resolve_artifact_names
- def _test_basic(self):
+ def test_basic(self):
class Node(_base.ComparableEntity):
def append(self, node):
self.children.append(node)
n1.append(Node(data='n11'))
n1.append(Node(data='n12'))
n1.append(Node(data='n13'))
-# n1.children[1].append(Node(data='n121'))
-# n1.children[1].append(Node(data='n122'))
-# n1.children[1].append(Node(data='n123'))
+ n1.children[1].append(Node(data='n121'))
+ n1.children[1].append(Node(data='n122'))
+ n1.children[1].append(Node(data='n123'))
n2 = Node(data='n2')
n2.append(Node(data='n21'))
-# n2.children[0].append(Node(data='n211'))
-# n2.children[0].append(Node(data='n212'))
+ n2.children[0].append(Node(data='n211'))
+ n2.children[0].append(Node(data='n212'))
sess.add(n1)
sess.add(n2)
eq_([Node(data='n1', children=[
Node(data='n11'),
Node(data='n12', children=[
-# Node(data='n121'),
-# Node(data='n122'),
-# Node(data='n123')
+ Node(data='n121'),
+ Node(data='n122'),
+ Node(data='n123')
]),
Node(data='n13')
]),
Node(data='n2', children=[
Node(data='n21', children=[
-# Node(data='n211'),
-# Node(data='n212'),
+ Node(data='n211'),
+ Node(data='n212'),
])
])
], d)
- self.assert_sql_count(testing.db, go, 1)
+ self.assert_sql_count(testing.db, go, 4)