]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Patched a case where query.join() would adapt the
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 21 Sep 2010 14:11:10 +0000 (10:11 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 21 Sep 2010 14:11:10 +0000 (10:11 -0400)
right side to the right side of the left's join
inappropriately [ticket:1925]

CHANGES
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/sql/util.py
test/orm/test_query.py

diff --git a/CHANGES b/CHANGES
index 4e49bb9237659536b0d4f994a069f5830510a157..9d9bc9f01e92b2fe9dfeef2fd8d0d8cbea48ee57 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -16,6 +16,10 @@ CHANGES
     passed an empty list to "include_properties" on 
     mapper() [ticket:1918]
   
+  - Patched a case where query.join() would adapt the
+    right side to the right side of the left's join
+    inappropriately [ticket:1925]
+    
   - The exception raised by Session when it is used
     subsequent to a subtransaction rollback (which is what
     happens when a flush fails in autocommit=False mode) has
index b22a10b55ecb0c6f1c839fa3c92e65f74d63402c..0ce84435ffc672e4432c0b8ecc410a80489f8cff 100644 (file)
@@ -1250,7 +1250,7 @@ class Query(object):
                         (left, right))
             
         left_mapper, left_selectable, left_is_aliased = _entity_info(left)
-        right_mapper, right_selectable, is_aliased_class = _entity_info(right)
+        right_mapper, right_selectable, right_is_aliased = _entity_info(right)
 
         if right_mapper and prop and \
                 not right_mapper.common_parent(prop.mapper):
@@ -1279,7 +1279,7 @@ class Query(object):
             need_adapter = True
 
         aliased_entity = right_mapper and \
-                            not is_aliased_class and \
+                            not right_is_aliased and \
                             (
                                 right_mapper.with_polymorphic or
                                 isinstance(
@@ -1342,8 +1342,16 @@ class Query(object):
                         )
                     )
         
-        join_to_left = not is_aliased_class and not left_is_aliased
-
+        # this is an overly broad assumption here, but there's a 
+        # very wide variety of situations where we rely upon orm.join's
+        # adaption to glue clauses together, with joined-table inheritance's
+        # wide array of variables taking up most of the space.
+        # Setting the flag here is still a guess, so it is a bug
+        # that we don't have definitive criterion to determine when 
+        # adaption should be enabled (or perhaps that we're even doing the 
+        # whole thing the way we are here).
+        join_to_left = not right_is_aliased and not left_is_aliased
+            
         if self._from_obj and left_selectable is not None:
             replace_clause_index, clause = sql_util.find_join_source(
                                                     self._from_obj, 
@@ -1351,10 +1359,16 @@ class Query(object):
             if clause is not None:
                 # the entire query's FROM clause is an alias of itself (i.e.
                 # from_self(), similar). if the left clause is that one,
-                # ensure it aliases to the left side.
+                # ensure it adapts to the left side.
                 if self._from_obj_alias and clause is self._from_obj[0]:
                     join_to_left = True
-
+                
+                # An exception case where adaption to the left edge is not
+                # desirable.  See above note on join_to_left.
+                if join_to_left and isinstance(clause, expression.Join) and \
+                    sql_util.clause_is_present(left_selectable, clause):
+                    join_to_left = False
+                    
                 clause = orm_join(clause, 
                                     right, 
                                     onclause, isouter=outerjoin, 
index bd4f70247fac5ffef6f5ede83d5e616a082c498a..638549e12c2abbac17e80c4249e9999cf00df1e1 100644 (file)
@@ -92,6 +92,25 @@ def find_columns(clause):
     visitors.traverse(clause, {}, {'column':cols.add})
     return cols
 
+def clause_is_present(clause, search):
+    """Given a target clause and a second to search within, return True
+    if the target is plainly present in the search without any
+    subqueries or aliases involved.
+    
+    Basically descends through Joins.
+    
+    """
+
+    stack = [search]
+    while stack:
+        elem = stack.pop()
+        if clause is elem:
+            return True
+        elif isinstance(elem, expression.Join):
+            stack.extend((elem.left, elem.right))
+    return False
+    
+    
 def bind_values(clause):
     """Return an ordered list of "bound" values in the given clause.
 
index 3a6436610f6eafcaed40bd5761fc211c3f6c38b5..22e5ac84fc1cd2e6f9b006b9fb388b7f881d0fdc 100644 (file)
@@ -3618,6 +3618,88 @@ class CustomJoinTest(QueryTest):
             [User(id=7)]
         )
 
+class SelfRefMixedTest(_base.MappedTest, AssertsCompiledSQL):
+    run_setup_mappers = 'once'
+
+    @classmethod
+    def define_tables(cls, metadata):
+        nodes = Table('nodes', metadata,
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
+            Column('parent_id', Integer, ForeignKey('nodes.id'))
+        )
+        
+        sub_table = Table('sub_table', metadata,
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
+            Column('node_id', Integer, ForeignKey('nodes.id')),
+        )
+        
+        assoc_table = Table('assoc_table', metadata,
+            Column('left_id', Integer, ForeignKey('nodes.id')),
+            Column('right_id', Integer, ForeignKey('nodes.id'))
+        )
+        
+    @classmethod
+    @testing.resolve_artifact_names
+    def setup_classes(cls):
+        class Node(Base):
+            pass
+        
+        class Sub(Base):
+            pass
+
+        mapper(Node, nodes, properties={
+            'children':relationship(Node, lazy='select', join_depth=3,
+                backref=backref('parent', remote_side=[nodes.c.id])
+            ),
+            'subs' : relationship(Sub),
+            'assoc':relationship(Node, 
+                            secondary=assoc_table, 
+                            primaryjoin=nodes.c.id==assoc_table.c.left_id, 
+                            secondaryjoin=nodes.c.id==assoc_table.c.right_id)
+        })
+        mapper(Sub, sub_table)
+
+    @testing.resolve_artifact_names
+    def test_o2m_aliased_plus_o2m(self):
+        sess = create_session()
+        n1 = aliased(Node)
+
+        self.assert_compile(
+            sess.query(Node).join((n1, Node.children)).join((Sub, n1.subs)),
+            "SELECT nodes.id AS nodes_id, nodes.parent_id AS nodes_parent_id "
+            "FROM nodes JOIN nodes AS nodes_1 ON nodes.id = nodes_1.parent_id "
+            "JOIN sub_table ON nodes_1.id = sub_table.node_id"
+        )
+    
+        self.assert_compile(
+            sess.query(Node).join((n1, Node.children)).join((Sub, Node.subs)),
+            "SELECT nodes.id AS nodes_id, nodes.parent_id AS nodes_parent_id "
+            "FROM nodes JOIN nodes AS nodes_1 ON nodes.id = nodes_1.parent_id "
+            "JOIN sub_table ON nodes.id = sub_table.node_id"
+        )
+
+    @testing.resolve_artifact_names
+    def test_m2m_aliased_plus_o2m(self):
+        sess = create_session()
+        n1 = aliased(Node)
+
+        self.assert_compile(
+            sess.query(Node).join((n1, Node.assoc)).join((Sub, n1.subs)),
+            "SELECT nodes.id AS nodes_id, nodes.parent_id AS nodes_parent_id "
+            "FROM nodes JOIN assoc_table AS assoc_table_1 ON nodes.id = "
+            "assoc_table_1.left_id JOIN nodes AS nodes_1 ON nodes_1.id = "
+            "assoc_table_1.right_id JOIN sub_table ON nodes_1.id = sub_table.node_id",
+        )
+    
+        self.assert_compile(
+            sess.query(Node).join((n1, Node.assoc)).join((Sub, Node.subs)),
+            "SELECT nodes.id AS nodes_id, nodes.parent_id AS nodes_parent_id "
+            "FROM nodes JOIN assoc_table AS assoc_table_1 ON nodes.id = "
+            "assoc_table_1.left_id JOIN nodes AS nodes_1 ON nodes_1.id = "
+            "assoc_table_1.right_id JOIN sub_table ON nodes.id = sub_table.node_id",
+        )
+        
+    
 class SelfReferentialTest(_base.MappedTest, AssertsCompiledSQL):
     run_setup_mappers = 'once'
     run_inserts = 'once'
@@ -3630,20 +3712,23 @@ class SelfReferentialTest(_base.MappedTest, AssertsCompiledSQL):
             Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
             Column('parent_id', Integer, ForeignKey('nodes.id')),
             Column('data', String(30)))
-
+        
     @classmethod
     def insert_data(cls):
-        global Node
-    
+        # TODO: somehow using setup_classes()
+        # here normally is screwing up the other tests.
+        
+        global Node, Sub
         class Node(Base):
             def append(self, node):
                 self.children.append(node)
-
+        
         mapper(Node, nodes, properties={
             'children':relationship(Node, lazy='select', join_depth=3,
                 backref=backref('parent', remote_side=[nodes.c.id])
-            )
+            ),
         })
+
         sess = create_session()
         n1 = Node(data='n1')
         n1.append(Node(data='n11'))
@@ -3656,6 +3741,7 @@ class SelfReferentialTest(_base.MappedTest, AssertsCompiledSQL):
         sess.flush()
         sess.close()
     
+    @testing.resolve_artifact_names
     def test_join(self):
         sess = create_session()
 
@@ -3673,6 +3759,7 @@ class SelfReferentialTest(_base.MappedTest, AssertsCompiledSQL):
             join('parent', aliased=True, from_joinpoint=True).filter_by(data='n1').first()
         assert node.data == 'n122'
     
+    @testing.resolve_artifact_names
     def test_string_or_prop_aliased(self):
         """test that join('foo') behaves the same as join(Cls.foo) in a self
         referential scenario.
@@ -3721,6 +3808,7 @@ class SelfReferentialTest(_base.MappedTest, AssertsCompiledSQL):
                 use_default_dialect=True
             )
         
+    @testing.resolve_artifact_names
     def test_from_self_inside_excludes_outside(self):
         """test the propagation of aliased() from inside to outside
         on a from_self()..
@@ -3774,11 +3862,43 @@ class SelfReferentialTest(_base.MappedTest, AssertsCompiledSQL):
             use_default_dialect=True
         )
         
+    @testing.resolve_artifact_names
     def test_explicit_join(self):
         sess = create_session()
     
         n1 = aliased(Node)
         n2 = aliased(Node)
+        
+        self.assert_compile(
+            join(Node, n1, 'children').join(n2, 'children'),
+            "nodes JOIN nodes AS nodes_1 ON nodes.id = nodes_1.parent_id JOIN nodes AS nodes_2 ON nodes_1.id = nodes_2.parent_id"
+        )
+
+        self.assert_compile(
+            join(Node, n1, Node.children).join(n2, n1.children),
+            "nodes JOIN nodes AS nodes_1 ON nodes.id = nodes_1.parent_id JOIN nodes AS nodes_2 ON nodes_1.id = nodes_2.parent_id"
+        )
+
+        # the join_to_left=False here is unfortunate.   the default on this flag should
+        # be False.
+        self.assert_compile(
+            join(Node, n1, Node.children).join(n2, Node.children, join_to_left=False),
+            "nodes JOIN nodes AS nodes_1 ON nodes.id = nodes_1.parent_id JOIN nodes AS nodes_2 ON nodes.id = nodes_2.parent_id"
+        )
+
+        self.assert_compile(
+            sess.query(Node).join((n1, Node.children)).join((n2, n1.children)),
+            "SELECT nodes.id AS nodes_id, nodes.parent_id AS nodes_parent_id, nodes.data AS "
+            "nodes_data FROM nodes JOIN nodes AS nodes_1 ON nodes.id = nodes_1.parent_id "
+            "JOIN nodes AS nodes_2 ON nodes_1.id = nodes_2.parent_id"
+        )
+    
+        self.assert_compile(
+            sess.query(Node).join((n1, Node.children)).join((n2, Node.children)),
+            "SELECT nodes.id AS nodes_id, nodes.parent_id AS nodes_parent_id, nodes.data AS "
+            "nodes_data FROM nodes JOIN nodes AS nodes_1 ON nodes.id = nodes_1.parent_id "
+            "JOIN nodes AS nodes_2 ON nodes.id = nodes_2.parent_id"
+        )
     
         node = sess.query(Node).select_from(join(Node, n1, 'children')).filter(n1.data=='n122').first()
         assert node.data=='n12'
@@ -3800,7 +3920,8 @@ class SelfReferentialTest(_base.MappedTest, AssertsCompiledSQL):
             list(sess.query(Node).select_from(join(Node, n1, 'parent').join(n2, 'parent')).\
             filter(and_(Node.data=='n122', n1.data=='n12', n2.data=='n1')).values(Node.data, n1.data, n2.data)),
             [('n122', 'n12', 'n1')])
-
+    
+    @testing.resolve_artifact_names
     def test_join_to_nonaliased(self):
         sess = create_session()
     
@@ -3819,6 +3940,7 @@ class SelfReferentialTest(_base.MappedTest, AssertsCompiledSQL):
         )
     
         
+    @testing.resolve_artifact_names
     def test_multiple_explicit_entities(self):
         sess = create_session()
     
@@ -3868,6 +3990,7 @@ class SelfReferentialTest(_base.MappedTest, AssertsCompiledSQL):
         )
     
     
+    @testing.resolve_artifact_names
     def test_any(self):
         sess = create_session()
         eq_(sess.query(Node).filter(Node.children.any(Node.data=='n1')).all(), [])
@@ -3875,6 +3998,7 @@ class SelfReferentialTest(_base.MappedTest, AssertsCompiledSQL):
         eq_(sess.query(Node).filter(~Node.children.any()).order_by(Node.id).all(), 
                 [Node(data='n11'), Node(data='n13'),Node(data='n121'),Node(data='n122'),Node(data='n123'),])
 
+    @testing.resolve_artifact_names
     def test_has(self):
         sess = create_session()
     
@@ -3883,6 +4007,7 @@ class SelfReferentialTest(_base.MappedTest, AssertsCompiledSQL):
         eq_(sess.query(Node).filter(Node.parent.has(Node.data=='n122')).all(), [])
         eq_(sess.query(Node).filter(~Node.parent.has()).all(), [Node(data='n1')])
 
+    @testing.resolve_artifact_names
     def test_contains(self):
         sess = create_session()
     
@@ -3892,6 +4017,7 @@ class SelfReferentialTest(_base.MappedTest, AssertsCompiledSQL):
         n13 = sess.query(Node).filter(Node.data=='n13').one()
         eq_(sess.query(Node).filter(Node.children.contains(n13)).all(), [Node(data='n1')])
 
+    @testing.resolve_artifact_names
     def test_eq_ne(self):
         sess = create_session()