]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Fixed some deep "column correspondence" issues which could
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 3 Jan 2009 02:42:34 +0000 (02:42 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 3 Jan 2009 02:42:34 +0000 (02:42 +0000)
impact a Query made against a selectable containing
multiple versions of the same table, as well as
unions and similar which contained the same table columns
in different column positions at different levels.
[ticket:1268]

CHANGES
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/sql/expression.py
test/orm/query.py
test/sql/selectable.py

diff --git a/CHANGES b/CHANGES
index 84759a9fd18301ef0d8a2138dafcf0a764ae5216..03c833a64373e61fd174448072cd2c58dd6925ac 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -75,6 +75,13 @@ CHANGES
       next compile() call.  This issue occurs frequently
       when using declarative.
 
+    - Fixed some deep "column correspondence" issues which could
+      impact a Query made against a selectable containing
+      multiple versions of the same table, as well as 
+      unions and similar which contained the same table columns
+      in different column positions at different levels.
+      [ticket:1268]
+      
     - Custom comparator classes used in conjunction with 
       column_property(), relation() etc. can define 
       new comparison methods on the Comparator, which will
index 1e257f9a471e44f86131384474873369f30af7c3..5c0629dafb9077357351ddf56d112f6854a0985f 100644 (file)
@@ -926,7 +926,6 @@ class Mapper(object):
         }
 
         """
-
         result = util.column_dict()
         def visit_binary(binary):
             if binary.operator == operators.eq:
index 7204e29564b7d991a7641a7a0d6647b1db7ba171..7eeff0660d1fdd3fad9c6193ba0a63ea73323ecb 100644 (file)
@@ -33,6 +33,7 @@ from sqlalchemy import util, exc
 from sqlalchemy.sql import operators
 from sqlalchemy.sql.visitors import Visitable, cloned_traverse
 from sqlalchemy import types as sqltypes
+import operator
 
 functions, schema, sql_util = None, None, None
 DefaultDialect, ClauseAdapter, Annotated = None, None, None
@@ -1840,9 +1841,32 @@ class FromClause(Selectable):
         for c in cols:
             i = c.proxy_set.intersection(target_set)
             if i and \
-                (not require_embedded or c.proxy_set.issuperset(target_set)) and \
-                (intersect is None or len(i) > len(intersect)):
-                col, intersect = c, i
+                (not require_embedded or c.proxy_set.issuperset(target_set)):
+                
+                if col is None:
+                    # no corresponding column yet, pick this one.
+                    col, intersect = c, i
+                elif len(i) > len(intersect):
+                    # 'c' has a larger field of correspondence than 'col'.
+                    # i.e. selectable.c.a1_x->a1.c.x->table.c.x matches a1.c.x->table.c.x better than 
+                    # selectable.c.x->table.c.x does.
+                    col, intersect = c, i
+                elif i == intersect:
+                    # they have the same field of correspondence.
+                    # see which proxy_set has fewer columns in it, which indicates a
+                    # closer relationship with the root column.  Also take into account the 
+                    # "weight" attribute which CompoundSelect() uses to give higher precedence to
+                    # columns based on vertical position in the compound statement, and discard columns
+                    # that have no reference to the target column (also occurs with CompoundSelect)
+                    col_distance = util.reduce(operator.add, 
+                                        [sc._annotations.get('weight', 1) for sc in col.proxy_set if sc.shares_lineage(column)]
+                                    )
+                    c_distance = util.reduce(operator.add, 
+                                        [sc._annotations.get('weight', 1) for sc in c.proxy_set if sc.shares_lineage(column)]
+                                    )
+                    if \
+                        c_distance < col_distance:
+                        col, intersect = c, i
         return col
 
     @property
@@ -3097,8 +3121,12 @@ class CompoundSelect(_SelectBaseMixin, FromClause):
     def _populate_column_collection(self):
         for cols in zip(*[s.c for s in self.selects]):
             proxy = cols[0]._make_proxy(self, name=self.use_labels and cols[0]._label or None)
-            proxy.proxies = cols
-
+            
+            # place a 'weight' annotation corresponding to how low in the list of select()s
+            # the column occurs, so that the corresponding_column() operation
+            # can resolve conflicts
+            proxy.proxies = [c._annotate({'weight':i + 1}) for i, c in enumerate(cols)]
+            
     def _copy_internals(self, clone=_clone):
         self._reset_exported()
         self.selects = [clone(s) for s in self.selects]
index 076c1c9406dd9ece32b51bd58d5acba71848d7f5..284ddd1944ba7076d2c4cb1403290a5a99da3cb0 100644 (file)
@@ -1692,10 +1692,14 @@ class MixedEntitiesTest(QueryTest):
         
         sess = create_session()
         oalias = aliased(Order)
-        
+
         for q in [
             sess.query(Order, oalias).filter(Order.user_id==oalias.user_id).filter(Order.user_id==7).filter(Order.id>oalias.id).order_by(Order.id, oalias.id),
             sess.query(Order, oalias)._from_self().filter(Order.user_id==oalias.user_id).filter(Order.user_id==7).filter(Order.id>oalias.id).order_by(Order.id, oalias.id),
+            
+            # same thing, but reversed.  
+            sess.query(oalias, Order)._from_self().filter(oalias.user_id==Order.user_id).filter(oalias.user_id==7).filter(Order.id<oalias.id).order_by(oalias.id, Order.id),
+            
             # here we go....two layers of aliasing
             sess.query(Order, oalias).filter(Order.user_id==oalias.user_id).filter(Order.user_id==7).filter(Order.id>oalias.id)._from_self().order_by(Order.id, oalias.id).limit(10).options(eagerload(Order.items)),
 
@@ -2252,6 +2256,15 @@ class SelfReferentialTest(ORMTest):
             (Node(data='n122'), Node(data='n12'), Node(data='n1'))
         )
 
+        # same, change order around
+        self.assertEquals(
+            sess.query(parent, grandparent, Node).\
+                join((Node.parent, parent), (parent.parent, grandparent)).\
+                    filter(Node.data=='n122').filter(parent.data=='n12').\
+                    filter(grandparent.data=='n1')._from_self().first(),
+            (Node(data='n12'), Node(data='n1'), Node(data='n122'))
+        )
+
         self.assertEquals(
             sess.query(Node, parent, grandparent).\
                 join((Node.parent, parent), (parent.parent, grandparent)).\
index eb8bc861f5c922ef0198de5bd956f60b0455cb9a..e9ed5f5653df9bef1210f7607c00fce1e84e3b0b 100755 (executable)
@@ -1,6 +1,4 @@
-"""tests that various From objects properly export their columns, as well as
-useable primary keys and foreign keys.  Full relational algebra depends on
-every selectable unit behaving nicely with others.."""
+"""Test various algorithmic properties of selectables."""
 
 import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
@@ -27,7 +25,7 @@ table2 = Table('table2', metadata,
 )
 
 class SelectableTest(TestBase, AssertsExecutionResults):
-    def test_distance(self):
+    def test_distance_on_labels(self):
         # same column three times
         s = select([table1.c.col1.label('c2'), table1.c.col1, table1.c.col1.label('c1')])
 
@@ -36,6 +34,17 @@ class SelectableTest(TestBase, AssertsExecutionResults):
         assert s.corresponding_column(s.c.col1) is s.c.col1
         assert s.corresponding_column(s.c.c1) is s.c.c1
 
+    def test_distance_on_aliases(self):
+        a1 = table1.alias('a1')
+        
+        for s in (
+            select([a1, table1], use_labels=True),
+            select([table1, a1], use_labels=True)
+        ):
+            assert s.corresponding_column(table1.c.col1) is s.c.table1_col1
+            assert s.corresponding_column(a1.c.col1) is s.c.a1_col1
+        
+
     def test_join_against_self(self):
         jj = select([table1.c.col1.label('bar_col1')])
         jjj = join(table1, jj, table1.c.col1==jj.c.bar_col1)
@@ -45,10 +54,7 @@ class SelectableTest(TestBase, AssertsExecutionResults):
 
         assert jjj.corresponding_column(jj.c.bar_col1) is jjj.c.bar_col1
 
-        # test alias of the join, targets the column with the least
-        # "distance" between the requested column and the returned column
-        # (i.e. there is less indirection between j2.c.table1_col1 and table1.c.col1, than
-        # there is from j2.c.bar_col1 to table1.c.col1)
+        # test alias of the join
         j2 = jjj.alias('foo')
         assert j2.corresponding_column(table1.c.col1) is j2.c.table1_col1
 
@@ -66,7 +72,6 @@ class SelectableTest(TestBase, AssertsExecutionResults):
         assert jjj.corresponding_column(jjj.c.table1_col1) is jjj.c.table1_col1
 
         j2 = jjj.alias('foo')
-        print j2.corresponding_column(jjj.c.table1_col1)
         assert j2.corresponding_column(jjj.c.table1_col1) is j2.c.table1_col1
 
         assert jjj.corresponding_column(jj.c.bar_col1) is jj.c.bar_col1
@@ -87,14 +92,29 @@ class SelectableTest(TestBase, AssertsExecutionResults):
             )
         s1 = table1.select(use_labels=True)
         s2 = table2.select(use_labels=True)
-        print ["%d %s" % (id(c),c.key) for c in u.c]
         c = u.corresponding_column(s1.c.table1_col2)
-        print "%d %s" % (id(c), c.key)
-        print id(u.corresponding_column(s1.c.table1_col2).table)
-        print id(u.c.col2.table)
         assert u.corresponding_column(s1.c.table1_col2) is u.c.col2
         assert u.corresponding_column(s2.c.table2_col2) is u.c.col2
 
+    def test_union_precedence(self):
+        # conflicting column correspondence should be resolved based on 
+        # the order of the select()s in the union
+        
+        s1 = select([table1.c.col1, table1.c.col2])
+        s2 = select([table1.c.col2, table1.c.col1])
+        s3 = select([table1.c.col3, table1.c.colx])
+        s4 = select([table1.c.colx, table1.c.col3])
+        
+        u1 = union(s1, s2)
+        assert u1.corresponding_column(table1.c.col1) is u1.c.col1
+        assert u1.corresponding_column(table1.c.col2) is u1.c.col2
+        
+        u1 = union(s1, s2, s3, s4)
+        assert u1.corresponding_column(table1.c.col1) is u1.c.col1
+        assert u1.corresponding_column(table1.c.col2) is u1.c.col2
+        assert u1.corresponding_column(table1.c.colx) is u1.c.col2
+        assert u1.corresponding_column(table1.c.col3) is u1.c.col1
+        
     def test_singular_union(self):
         u = union(select([table1.c.col1, table1.c.col2, table1.c.col3]), select([table1.c.col1, table1.c.col2, table1.c.col3]))
 
@@ -153,7 +173,6 @@ class SelectableTest(TestBase, AssertsExecutionResults):
 
     def test_select_labels(self):
         a = table1.select(use_labels=True)
-        print str(a.select())
         j = join(a, table2)
 
         criterion = a.c.table1_col1 == table2.c.col2
@@ -196,9 +215,6 @@ class SelectableTest(TestBase, AssertsExecutionResults):
         j3 = a.join(j2, j2.c.aid==a.c.id)
 
         j4 = select([j3]).alias('foo')
-        print j4
-        print j4.corresponding_column(j2.c.aid)
-        print j4.c.aid
         assert j4.corresponding_column(j2.c.aid) is j4.c.aid
         assert j4.corresponding_column(a.c.id) is j4.c.id