From: Mike Bayer Date: Sat, 3 Jan 2009 02:42:34 +0000 (+0000) Subject: - Fixed some deep "column correspondence" issues which could X-Git-Tag: rel_0_5_0~29 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=9fe69cb503de4bcbace8d74b14dd6d096d457e72;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - Fixed some deep "column correspondence" issues which could impact a Query made against a selectable containing multiple versions of the same table, as well as unions and similar which contained the same table columns in different column positions at different levels. [ticket:1268] --- diff --git a/CHANGES b/CHANGES index 84759a9fd1..03c833a643 100644 --- a/CHANGES +++ b/CHANGES @@ -75,6 +75,13 @@ CHANGES next compile() call. This issue occurs frequently when using declarative. + - Fixed some deep "column correspondence" issues which could + impact a Query made against a selectable containing + multiple versions of the same table, as well as + unions and similar which contained the same table columns + in different column positions at different levels. + [ticket:1268] + - Custom comparator classes used in conjunction with column_property(), relation() etc. can define new comparison methods on the Comparator, which will diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 1e257f9a47..5c0629dafb 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -926,7 +926,6 @@ class Mapper(object): } """ - result = util.column_dict() def visit_binary(binary): if binary.operator == operators.eq: diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 7204e29564..7eeff0660d 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -33,6 +33,7 @@ from sqlalchemy import util, exc from sqlalchemy.sql import operators from sqlalchemy.sql.visitors import Visitable, cloned_traverse from sqlalchemy import types as sqltypes +import operator functions, schema, sql_util = None, None, None DefaultDialect, ClauseAdapter, Annotated = None, None, None @@ -1840,9 +1841,32 @@ class FromClause(Selectable): for c in cols: i = c.proxy_set.intersection(target_set) if i and \ - (not require_embedded or c.proxy_set.issuperset(target_set)) and \ - (intersect is None or len(i) > len(intersect)): - col, intersect = c, i + (not require_embedded or c.proxy_set.issuperset(target_set)): + + if col is None: + # no corresponding column yet, pick this one. + col, intersect = c, i + elif len(i) > len(intersect): + # 'c' has a larger field of correspondence than 'col'. + # i.e. selectable.c.a1_x->a1.c.x->table.c.x matches a1.c.x->table.c.x better than + # selectable.c.x->table.c.x does. + col, intersect = c, i + elif i == intersect: + # they have the same field of correspondence. + # see which proxy_set has fewer columns in it, which indicates a + # closer relationship with the root column. Also take into account the + # "weight" attribute which CompoundSelect() uses to give higher precedence to + # columns based on vertical position in the compound statement, and discard columns + # that have no reference to the target column (also occurs with CompoundSelect) + col_distance = util.reduce(operator.add, + [sc._annotations.get('weight', 1) for sc in col.proxy_set if sc.shares_lineage(column)] + ) + c_distance = util.reduce(operator.add, + [sc._annotations.get('weight', 1) for sc in c.proxy_set if sc.shares_lineage(column)] + ) + if \ + c_distance < col_distance: + col, intersect = c, i return col @property @@ -3097,8 +3121,12 @@ class CompoundSelect(_SelectBaseMixin, FromClause): def _populate_column_collection(self): for cols in zip(*[s.c for s in self.selects]): proxy = cols[0]._make_proxy(self, name=self.use_labels and cols[0]._label or None) - proxy.proxies = cols - + + # place a 'weight' annotation corresponding to how low in the list of select()s + # the column occurs, so that the corresponding_column() operation + # can resolve conflicts + proxy.proxies = [c._annotate({'weight':i + 1}) for i, c in enumerate(cols)] + def _copy_internals(self, clone=_clone): self._reset_exported() self.selects = [clone(s) for s in self.selects] diff --git a/test/orm/query.py b/test/orm/query.py index 076c1c9406..284ddd1944 100644 --- a/test/orm/query.py +++ b/test/orm/query.py @@ -1692,10 +1692,14 @@ class MixedEntitiesTest(QueryTest): sess = create_session() oalias = aliased(Order) - + for q in [ sess.query(Order, oalias).filter(Order.user_id==oalias.user_id).filter(Order.user_id==7).filter(Order.id>oalias.id).order_by(Order.id, oalias.id), sess.query(Order, oalias)._from_self().filter(Order.user_id==oalias.user_id).filter(Order.user_id==7).filter(Order.id>oalias.id).order_by(Order.id, oalias.id), + + # same thing, but reversed. + sess.query(oalias, Order)._from_self().filter(oalias.user_id==Order.user_id).filter(oalias.user_id==7).filter(Order.idoalias.id)._from_self().order_by(Order.id, oalias.id).limit(10).options(eagerload(Order.items)), @@ -2252,6 +2256,15 @@ class SelfReferentialTest(ORMTest): (Node(data='n122'), Node(data='n12'), Node(data='n1')) ) + # same, change order around + self.assertEquals( + sess.query(parent, grandparent, Node).\ + join((Node.parent, parent), (parent.parent, grandparent)).\ + filter(Node.data=='n122').filter(parent.data=='n12').\ + filter(grandparent.data=='n1')._from_self().first(), + (Node(data='n12'), Node(data='n1'), Node(data='n122')) + ) + self.assertEquals( sess.query(Node, parent, grandparent).\ join((Node.parent, parent), (parent.parent, grandparent)).\ diff --git a/test/sql/selectable.py b/test/sql/selectable.py index eb8bc861f5..e9ed5f5653 100755 --- a/test/sql/selectable.py +++ b/test/sql/selectable.py @@ -1,6 +1,4 @@ -"""tests that various From objects properly export their columns, as well as -useable primary keys and foreign keys. Full relational algebra depends on -every selectable unit behaving nicely with others..""" +"""Test various algorithmic properties of selectables.""" import testenv; testenv.configure_for_tests() from sqlalchemy import * @@ -27,7 +25,7 @@ table2 = Table('table2', metadata, ) class SelectableTest(TestBase, AssertsExecutionResults): - def test_distance(self): + def test_distance_on_labels(self): # same column three times s = select([table1.c.col1.label('c2'), table1.c.col1, table1.c.col1.label('c1')]) @@ -36,6 +34,17 @@ class SelectableTest(TestBase, AssertsExecutionResults): assert s.corresponding_column(s.c.col1) is s.c.col1 assert s.corresponding_column(s.c.c1) is s.c.c1 + def test_distance_on_aliases(self): + a1 = table1.alias('a1') + + for s in ( + select([a1, table1], use_labels=True), + select([table1, a1], use_labels=True) + ): + assert s.corresponding_column(table1.c.col1) is s.c.table1_col1 + assert s.corresponding_column(a1.c.col1) is s.c.a1_col1 + + def test_join_against_self(self): jj = select([table1.c.col1.label('bar_col1')]) jjj = join(table1, jj, table1.c.col1==jj.c.bar_col1) @@ -45,10 +54,7 @@ class SelectableTest(TestBase, AssertsExecutionResults): assert jjj.corresponding_column(jj.c.bar_col1) is jjj.c.bar_col1 - # test alias of the join, targets the column with the least - # "distance" between the requested column and the returned column - # (i.e. there is less indirection between j2.c.table1_col1 and table1.c.col1, than - # there is from j2.c.bar_col1 to table1.c.col1) + # test alias of the join j2 = jjj.alias('foo') assert j2.corresponding_column(table1.c.col1) is j2.c.table1_col1 @@ -66,7 +72,6 @@ class SelectableTest(TestBase, AssertsExecutionResults): assert jjj.corresponding_column(jjj.c.table1_col1) is jjj.c.table1_col1 j2 = jjj.alias('foo') - print j2.corresponding_column(jjj.c.table1_col1) assert j2.corresponding_column(jjj.c.table1_col1) is j2.c.table1_col1 assert jjj.corresponding_column(jj.c.bar_col1) is jj.c.bar_col1 @@ -87,14 +92,29 @@ class SelectableTest(TestBase, AssertsExecutionResults): ) s1 = table1.select(use_labels=True) s2 = table2.select(use_labels=True) - print ["%d %s" % (id(c),c.key) for c in u.c] c = u.corresponding_column(s1.c.table1_col2) - print "%d %s" % (id(c), c.key) - print id(u.corresponding_column(s1.c.table1_col2).table) - print id(u.c.col2.table) assert u.corresponding_column(s1.c.table1_col2) is u.c.col2 assert u.corresponding_column(s2.c.table2_col2) is u.c.col2 + def test_union_precedence(self): + # conflicting column correspondence should be resolved based on + # the order of the select()s in the union + + s1 = select([table1.c.col1, table1.c.col2]) + s2 = select([table1.c.col2, table1.c.col1]) + s3 = select([table1.c.col3, table1.c.colx]) + s4 = select([table1.c.colx, table1.c.col3]) + + u1 = union(s1, s2) + assert u1.corresponding_column(table1.c.col1) is u1.c.col1 + assert u1.corresponding_column(table1.c.col2) is u1.c.col2 + + u1 = union(s1, s2, s3, s4) + assert u1.corresponding_column(table1.c.col1) is u1.c.col1 + assert u1.corresponding_column(table1.c.col2) is u1.c.col2 + assert u1.corresponding_column(table1.c.colx) is u1.c.col2 + assert u1.corresponding_column(table1.c.col3) is u1.c.col1 + def test_singular_union(self): u = union(select([table1.c.col1, table1.c.col2, table1.c.col3]), select([table1.c.col1, table1.c.col2, table1.c.col3])) @@ -153,7 +173,6 @@ class SelectableTest(TestBase, AssertsExecutionResults): def test_select_labels(self): a = table1.select(use_labels=True) - print str(a.select()) j = join(a, table2) criterion = a.c.table1_col1 == table2.c.col2 @@ -196,9 +215,6 @@ class SelectableTest(TestBase, AssertsExecutionResults): j3 = a.join(j2, j2.c.aid==a.c.id) j4 = select([j3]).alias('foo') - print j4 - print j4.corresponding_column(j2.c.aid) - print j4.c.aid assert j4.corresponding_column(j2.c.aid) is j4.c.aid assert j4.corresponding_column(a.c.id) is j4.c.id