]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
this version actually works for all existing tests plus simple self-referential.
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 23 Mar 2010 22:33:31 +0000 (18:33 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 23 Mar 2010 22:33:31 +0000 (18:33 -0400)
I don't like how difficult it was to get Query() to do it, however.

lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/strategies.py
test/orm/test_subquery_relations.py

index f067172174ca5b638e1f4798733f7b3e22065ba6..2dfefc43326eafcb68ad6452b02a783c3e5dc66f 100644 (file)
@@ -134,7 +134,7 @@ class Query(object):
                 self._polymorphic_adapters[m.mapped_table] = self._polymorphic_adapters[m.local_table] = adapter
 
     def _set_select_from(self, *obj):
-        
+
         fa = []
         for from_obj in obj:
             if isinstance(from_obj, expression._SelectBaseMixin):
@@ -143,9 +143,8 @@ class Query(object):
 
         self._from_obj = tuple(fa)
 
-        # TODO: only use this adapter for from_self() ?   right
-        # now its usage is somewhat arbitrary.
-        if len(self._from_obj) == 1 and isinstance(self._from_obj[0], expression.Alias):
+        if len(self._from_obj) == 1 and \
+            isinstance(self._from_obj[0], expression.Alias):
             equivs = self.__all_equivs()
             self._from_obj_alias = sql_util.ColumnAdapter(self._from_obj[0], equivs)
         
@@ -625,7 +624,7 @@ class Query(object):
         if entities:
             q._set_entities(entities)
         return q
-
+    
     @_generative()
     def _from_selectable(self, fromclause):
         for attr in ('_statement', '_criterion', '_order_by', '_group_by',
@@ -2139,6 +2138,7 @@ class _MapperEntity(_QueryEntity):
         self._with_polymorphic = with_polymorphic
         self._polymorphic_discriminator = None
         self.is_aliased_class = is_aliased_class
+        self.disable_aliasing = False
         if is_aliased_class:
             self.path_entity = self.entity = self.entity_zero = entity
         else:
@@ -2170,7 +2170,9 @@ class _MapperEntity(_QueryEntity):
         query._entities.append(self)
 
     def _get_entity_clauses(self, query, context):
-
+        if self.disable_aliasing:
+            return None
+            
         adapter = None
         if not self.is_aliased_class and query._polymorphic_adapters:
             adapter = query._polymorphic_adapters.get(self.mapper, None)
@@ -2251,7 +2253,6 @@ class _MapperEntity(_QueryEntity):
     def __str__(self):
         return str(self.mapper)
 
-
 class _ColumnEntity(_QueryEntity):
     """Column/expression based entity."""
 
index f507bfbe5327a92c21248ebc20e7b606f512cc00..4431b408fc59fc506d11804e92cb80cc4e50d160 100644 (file)
@@ -692,11 +692,6 @@ class SubqueryLoader(AbstractRelationshipLoader):
             for c in leftmost_cols
         ]
 
-        local_attr = [
-            self.parent._get_col_to_prop(c).class_attribute
-            for c in local_cols
-        ]
-        
         # modify the query to just look for parent columns in the 
         # join condition
         
@@ -713,24 +708,44 @@ class SubqueryLoader(AbstractRelationshipLoader):
         q._attributes[('subquery_path', None)] = subq_path
 
         # now select from it as a subquery.
-        q = q.from_self(self.mapper, *local_attr)
+        local_attr = [
+            self.parent._get_col_to_prop(c).class_attribute
+            for c in local_cols
+        ]
+
+        q = q.from_self(self.mapper)
+        q._entities[0].disable_aliasing = True
 
-        # and join to the related thing we want
-        # to load.
-        for mapper, key in [(subq_path[i], subq_path[i+1]) 
-                            for i in xrange(0, len(subq_path), 2)]:
+        to_join = [(subq_path[i], subq_path[i+1]) 
+                            for i in xrange(0, len(subq_path), 2)]
+        
+        for i, (mapper, key) in enumerate(to_join):
+            alias_join = i < len(to_join) - 1
+            second_to_last = i == len(to_join) - 2
+            
             prop = mapper.get_property(key)
-            q = q.join(prop.class_attribute)
+            q = q.join(prop.class_attribute, aliased=alias_join)
             
-        #join_on = [(subq_path[i], subq_path[i+1]) 
-        #        for i in xrange(0, len(subq_path), 2)]
-        #for i, (mapper, key) in enumerate(join_on):
-        #    aliased = i != len(join_on) - 1
-        #    prop = mapper.get_property(key)
-        #    q = q.join(prop.class_attribute, aliased=aliased)
-
-        q = q.order_by(*local_attr)
+            if alias_join and second_to_last:
+                cols = [
+                    q._adapt_clause(col, True, False)
+                    for col in local_cols
+                ]
+                for col in cols:
+                    q = q.add_column(col)
+                q = q.order_by(*cols)
         
+        if len(to_join) < 2:
+            local_attr = [
+                self.parent._get_col_to_prop(c).class_attribute
+                for c in local_cols
+            ]
+
+            for col in local_attr:
+                q = q.add_column(col)
+            q = q.order_by(*local_attr)
+                
+
         # propagate loader options etc. to the new query
         q = q._with_current_path(subq_path)
         q = q._conditional_options(*orig_query._with_options)
@@ -774,7 +789,6 @@ class SubqueryLoader(AbstractRelationshipLoader):
             
         local_cols, remote_cols = self._local_remote_columns(self.parent_property)
 
-        local_attr = [self.parent._get_col_to_prop(c).key for c in local_cols]
         remote_attr = [
                         self.mapper._get_col_to_prop(c).key 
                         for c in remote_cols]
index e1372fbfe67e7df55bd7d1a6ce0dc94284787b20..1be81568628c422c5e207b38f044c55b9db1570f 100644 (file)
@@ -569,7 +569,7 @@ class OrderBySecondaryTest(_base.MappedTest):
             ])
         self.assert_sql_count(testing.db, go, 2)
 
-class SelfReferentialEagerTest(_base.MappedTest):
+class SelfReferentialTest(_base.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table('nodes', metadata,
@@ -579,7 +579,7 @@ class SelfReferentialEagerTest(_base.MappedTest):
 
     @testing.fails_on('maxdb', 'FIXME: unknown')
     @testing.resolve_artifact_names
-    def _test_basic(self):
+    def test_basic(self):
         class Node(_base.ComparableEntity):
             def append(self, node):
                 self.children.append(node)
@@ -594,13 +594,13 @@ class SelfReferentialEagerTest(_base.MappedTest):
         n1.append(Node(data='n11'))
         n1.append(Node(data='n12'))
         n1.append(Node(data='n13'))
-#        n1.children[1].append(Node(data='n121'))
-#        n1.children[1].append(Node(data='n122'))
-#        n1.children[1].append(Node(data='n123'))
+        n1.children[1].append(Node(data='n121'))
+        n1.children[1].append(Node(data='n122'))
+        n1.children[1].append(Node(data='n123'))
         n2 = Node(data='n2')
         n2.append(Node(data='n21'))
-#        n2.children[0].append(Node(data='n211'))
-#        n2.children[0].append(Node(data='n212'))
+        n2.children[0].append(Node(data='n211'))
+        n2.children[0].append(Node(data='n212'))
         
         sess.add(n1)
         sess.add(n2)
@@ -612,20 +612,20 @@ class SelfReferentialEagerTest(_base.MappedTest):
             eq_([Node(data='n1', children=[
                     Node(data='n11'),
                     Node(data='n12', children=[
-#                        Node(data='n121'),
-#                        Node(data='n122'),
-#                        Node(data='n123')
+                        Node(data='n121'),
+                        Node(data='n122'),
+                        Node(data='n123')
                     ]),
                     Node(data='n13')
                 ]),
                 Node(data='n2', children=[
                     Node(data='n21', children=[
-#                        Node(data='n211'),
-#                        Node(data='n212'),
+                        Node(data='n211'),
+                        Node(data='n212'),
                     ])
                 ])
             ], d)
-        self.assert_sql_count(testing.db, go, 1)
+        self.assert_sql_count(testing.db, go, 4)