]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
everything passes with this!!!!!!! holy crap !!!!! and its the simplest of all
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 25 Apr 2013 17:54:40 +0000 (13:54 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 25 Apr 2013 17:54:40 +0000 (13:54 -0400)
lib/sqlalchemy/orm/util.py
lib/sqlalchemy/sql/expression.py
lib/sqlalchemy/sql/util.py
test/orm/test_joins.py

index dd24d8bf44c0526c7ff050a13c84b04b88d6f874..9472f86980c56234ebbc297e5d3d660dd6df9c6a 100644 (file)
@@ -874,70 +874,82 @@ class _ORMJoin(expression.Join):
                             isouter=False, join_to_left=True):
 
         adapt_from = None
-        if hasattr(left, '_orm_mappers'):
-            left_mapper = left._orm_mappers[1]
-        else:
-            info = inspection.inspect(left)
-            left_mapper = getattr(info, 'mapper', None)
-
         left_info = inspection.inspect(left)
-        left_selectable = left_info.selectable
 
-        info = inspection.inspect(right)
-        right_mapper = getattr(info, 'mapper', None)
-        right = info.selectable
-        right_is_aliased = getattr(info, 'is_aliased_class', False)
+        if hasattr(left, '_orm_infos'):
+            left_orm_info = left._orm_infos[1]
+        else:
+            #if isinstance(left, expression.Join):
+            #    info = inspection.inspect(left.right)
+            #else:
+            #    info = inspection.inspect(left)
+            left_orm_info = left_info
 
-        if right_is_aliased:
-            adapt_to = right
+        right_info = inspection.inspect(right)
+
+        if getattr(right_info, 'is_aliased_class', False):
+            adapt_to = right_info.selectable
         else:
             adapt_to = None
 
-        if left_mapper or right_mapper:
-            self._orm_mappers = (left_mapper, right_mapper)
-
-            if isinstance(onclause, basestring):
-                prop = left_mapper.get_property(onclause)
-                on_selectable = prop.parent.selectable
-            elif isinstance(onclause, attributes.QueryableAttribute):
-                on_selectable = onclause.comparator._source_selectable()
-                #if adapt_from is None:
-                #    adapt_from = onclause.comparator._source_selectable()
-                prop = onclause.property
-            elif isinstance(onclause, MapperProperty):
-                prop = onclause
-                on_selectable = prop.parent.selectable
+#        import pdb
+#        pdb.set_trace()
+        self._orm_infos = (left_orm_info, right_info)
+
+        if isinstance(onclause, basestring):
+            onclause = getattr(left_orm_info.entity, onclause)
+
+        if isinstance(onclause, attributes.QueryableAttribute):
+            on_selectable = onclause.comparator._source_selectable()
+            prop = onclause.property
+        elif isinstance(onclause, MapperProperty):
+            prop = onclause
+            on_selectable = prop.parent.selectable
+        else:
+            prop = None
+
+        if prop:
+            #import pdb
+            #pdb.set_trace()
+            if sql_util.clause_is_present(on_selectable, left_info.selectable):
+                adapt_from = on_selectable
+            else:
+                adapt_from = left_info.selectable
+#                import pdb
+#                pdb.set_trace()
+                #adapt_from = left_orm_info.selectable
+                #adapt_from = left_info.selectable
+#                adapt_from = None
+#            if adapt_from is None:
+#                _derived = []
+#                for s in expression._from_objects(left_info.selectable):
+#                    if s == on_selectable:
+#                        adapt_from = s
+#                        break
+#                    elif s.is_derived_from(on_selectable):
+#                        _derived.append(s)
+#                else:
+#                    if _derived:
+#                        adapt_from = _derived[0]
+
+            #if adapt_from is None:
+#            adapt_from = left_info.selectable
+
+            #adapt_from = None
+            pj, sj, source, dest, \
+                secondary, target_adapter = prop._create_joins(
+                            source_selectable=adapt_from,
+                            dest_selectable=adapt_to,
+                            source_polymorphic=True,
+                            dest_polymorphic=True,
+                            of_type=right_info.mapper)
+
+            if sj is not None:
+                left = sql.join(left, secondary, pj, isouter)
+                onclause = sj
             else:
-                prop = None
-
-            if prop:
-                import pdb
-                pdb.set_trace()
-                _derived = []
-                for s in expression._from_objects(left_selectable):
-                    if s == on_selectable:
-                        adapt_from = s
-                        break
-                    elif s.is_derived_from(on_selectable):
-                        _derived.append(s)
-                else:
-                    if _derived:
-                        adapt_from = _derived[0]
-
-                pj, sj, source, dest, \
-                    secondary, target_adapter = prop._create_joins(
-                                source_selectable=adapt_from,
-                                dest_selectable=adapt_to,
-                                source_polymorphic=True,
-                                dest_polymorphic=True,
-                                of_type=right_mapper)
-
-                if sj is not None:
-                    left = sql.join(left, secondary, pj, isouter)
-                    onclause = sj
-                else:
-                    onclause = pj
-                self._target_adapter = target_adapter
+                onclause = pj
+            self._target_adapter = target_adapter
 
         expression.Join.__init__(self, left, right, onclause, isouter)
 
index d2e644ce2b684db3c37b7bacf37d27da9efe7a5a..92b8aea985aba51205ffb298d2a58bbbca2934dc 100644 (file)
@@ -3909,8 +3909,14 @@ class Join(FromClause):
 
     def is_derived_from(self, fromclause):
         return fromclause is self or \
-                self.left.is_derived_from(fromclause) or\
-                self.right.is_derived_from(fromclause)
+                self.left.is_derived_from(fromclause) or \
+                self.right.is_derived_from(fromclause) or \
+                (
+                    isinstance(fromclause, Join) and
+                    self.left.is_derived_from(fromclause.left) and
+                    self.right.is_derived_from(fromclause.right) and
+                    self.onclause.compare(fromclause.onclause)
+                )
 
     def self_group(self, against=None):
         return FromGrouping(self)
@@ -3947,6 +3953,12 @@ class Join(FromClause):
     def get_children(self, **kwargs):
         return self.left, self.right, self.onclause
 
+    def compare(self, other):
+        return isinstance(other, Join) and \
+            self.left.compare(other.left) and \
+            self.right.compare(other.right) and \
+            self.onclause.compare(other.onclause)
+
     def _match_primaries(self, left, right):
         if isinstance(left, Join):
             left_right = left.right
index 520c90f999f9bca09f5df699dc56f66278aea1d8..4aa2d749686465f580b29fa324de6be7b1836b66 100644 (file)
@@ -203,7 +203,7 @@ def clause_is_present(clause, search):
     stack = [search]
     while stack:
         elem = stack.pop()
-        if clause is elem:
+        if clause == elem:  # use == here so that Annotated's compare
             return True
         elif isinstance(elem, expression.Join):
             stack.extend((elem.left, elem.right))
index 629c55ce513106115848337ea1bb5ad4f467eb60..2bf0d8d9293daf77eed341fd7a8dd37d4c0eabdf 100644 (file)
@@ -215,7 +215,7 @@ class InheritedJoinTest(fixtures.MappedTest, AssertsCompiledSQL):
             , use_default_dialect = True
         )
 
-    def test_prop_with_polymorphic(self):
+    def test_prop_with_polymorphic_1(self):
         Person, Manager, Paperwork = (self.classes.Person,
                                 self.classes.Manager,
                                 self.classes.Paperwork)
@@ -238,6 +238,13 @@ class InheritedJoinTest(fixtures.MappedTest, AssertsCompiledSQL):
                 , use_default_dialect=True
             )
 
+    def test_prop_with_polymorphic_2(self):
+        Person, Manager, Paperwork = (self.classes.Person,
+                                self.classes.Manager,
+                                self.classes.Paperwork)
+
+        sess = create_session()
+
         self.assert_compile(
             sess.query(Person).with_polymorphic(Manager).
                     join('paperwork', aliased=True).
@@ -1928,34 +1935,50 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL):
             use_default_dialect=True
         )
 
-    def test_explicit_join(self):
+    def test_explicit_join_1(self):
         Node = self.classes.Node
-
-        sess = create_session()
-
         n1 = aliased(Node)
         n2 = aliased(Node)
 
         self.assert_compile(
             join(Node, n1, 'children').join(n2, 'children'),
-            "nodes JOIN nodes AS nodes_1 ON nodes.id = nodes_1.parent_id JOIN nodes AS nodes_2 ON nodes_1.id = nodes_2.parent_id",
+            "nodes JOIN nodes AS nodes_1 ON nodes.id = nodes_1.parent_id "
+            "JOIN nodes AS nodes_2 ON nodes_1.id = nodes_2.parent_id",
             use_default_dialect=True
         )
 
+    def test_explicit_join_2(self):
+        Node = self.classes.Node
+        n1 = aliased(Node)
+        n2 = aliased(Node)
+
         self.assert_compile(
             join(Node, n1, Node.children).join(n2, n1.children),
-            "nodes JOIN nodes AS nodes_1 ON nodes.id = nodes_1.parent_id JOIN nodes AS nodes_2 ON nodes_1.id = nodes_2.parent_id",
+            "nodes JOIN nodes AS nodes_1 ON nodes.id = nodes_1.parent_id "
+            "JOIN nodes AS nodes_2 ON nodes_1.id = nodes_2.parent_id",
             use_default_dialect=True
         )
 
+    def test_explicit_join_3(self):
+        Node = self.classes.Node
+        n1 = aliased(Node)
+        n2 = aliased(Node)
+
         # the join_to_left=False here is unfortunate.   the default on this flag should
         # be False.
         self.assert_compile(
             join(Node, n1, Node.children).join(n2, Node.children, join_to_left=False),
-            "nodes JOIN nodes AS nodes_1 ON nodes.id = nodes_1.parent_id JOIN nodes AS nodes_2 ON nodes.id = nodes_2.parent_id",
+            "nodes JOIN nodes AS nodes_1 ON nodes.id = nodes_1.parent_id "
+            "JOIN nodes AS nodes_2 ON nodes.id = nodes_2.parent_id",
             use_default_dialect=True
         )
 
+    def test_explicit_join_4(self):
+        Node = self.classes.Node
+        sess = create_session()
+        n1 = aliased(Node)
+        n2 = aliased(Node)
+
         self.assert_compile(
             sess.query(Node).join(n1, Node.children).join(n2, n1.children),
             "SELECT nodes.id AS nodes_id, nodes.parent_id AS nodes_parent_id, nodes.data AS "
@@ -1964,6 +1987,12 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL):
             use_default_dialect=True
         )
 
+    def test_explicit_join_5(self):
+        Node = self.classes.Node
+        sess = create_session()
+        n1 = aliased(Node)
+        n2 = aliased(Node)
+
         self.assert_compile(
             sess.query(Node).join(n1, Node.children).join(n2, Node.children),
             "SELECT nodes.id AS nodes_id, nodes.parent_id AS nodes_parent_id, nodes.data AS "
@@ -1972,25 +2001,59 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL):
             use_default_dialect=True
         )
 
-        node = sess.query(Node).select_from(join(Node, n1, 'children')).filter(n1.data=='n122').first()
-        assert node.data=='n12'
+    def test_explicit_join_6(self):
+        Node = self.classes.Node
+        sess = create_session()
+        n1 = aliased(Node)
 
-        node = sess.query(Node).select_from(join(Node, n1, 'children').join(n2, 'children')).\
-            filter(n2.data=='n122').first()
-        assert node.data=='n1'
+        node = sess.query(Node).select_from(join(Node, n1, 'children')).\
+                filter(n1.data == 'n122').first()
+        assert node.data == 'n12'
+
+    def test_explicit_join_7(self):
+        Node = self.classes.Node
+        sess = create_session()
+        n1 = aliased(Node)
+        n2 = aliased(Node)
+
+        node = sess.query(Node).select_from(
+                join(Node, n1, 'children').join(n2, 'children')).\
+            filter(n2.data == 'n122').first()
+        assert node.data == 'n1'
+
+    def test_explicit_join_8(self):
+        Node = self.classes.Node
+        sess = create_session()
+        n1 = aliased(Node)
+        n2 = aliased(Node)
 
         # mix explicit and named onclauses
-        node = sess.query(Node).select_from(join(Node, n1, Node.id==n1.parent_id).join(n2, 'children')).\
-            filter(n2.data=='n122').first()
-        assert node.data=='n1'
+        node = sess.query(Node).select_from(
+                    join(Node, n1, Node.id == n1.parent_id).join(n2, 'children')).\
+            filter(n2.data == 'n122').first()
+        assert node.data == 'n1'
+
+    def test_explicit_join_9(self):
+        Node = self.classes.Node
+        sess = create_session()
+        n1 = aliased(Node)
+        n2 = aliased(Node)
 
         node = sess.query(Node).select_from(join(Node, n1, 'parent').join(n2, 'parent')).\
-            filter(and_(Node.data=='n122', n1.data=='n12', n2.data=='n1')).first()
+            filter(and_(Node.data == 'n122', n1.data == 'n12', n2.data == 'n1')).first()
         assert node.data == 'n122'
 
+    def test_explicit_join_10(self):
+        Node = self.classes.Node
+        sess = create_session()
+        n1 = aliased(Node)
+        n2 = aliased(Node)
+
         eq_(
             list(sess.query(Node).select_from(join(Node, n1, 'parent').join(n2, 'parent')).\
-            filter(and_(Node.data=='n122', n1.data=='n12', n2.data=='n1')).values(Node.data, n1.data, n2.data)),
+            filter(and_(Node.data == 'n122',
+                        n1.data == 'n12',
+                        n2.data == 'n1')).values(Node.data, n1.data, n2.data)),
             [('n122', 'n12', 'n1')])
 
     def test_join_to_nonaliased(self):
@@ -2040,8 +2103,8 @@ class SelfReferentialTest(fixtures.MappedTest, AssertsCompiledSQL):
             sess.query(Node, parent, grandparent).\
                 join(parent, Node.parent).\
                 join(grandparent, parent.parent).\
-                    filter(Node.data=='n122').filter(parent.data=='n12').\
-                    filter(grandparent.data=='n1').from_self().first(),
+                    filter(Node.data == 'n122').filter(parent.data == 'n12').\
+                    filter(grandparent.data == 'n1').from_self().first(),
             (Node(data='n122'), Node(data='n12'), Node(data='n1'))
         )