]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- [bug] Fixed bug whereby usage of a UNION
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 22 Aug 2012 23:06:19 +0000 (19:06 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 22 Aug 2012 23:06:19 +0000 (19:06 -0400)
or similar inside of an embedded subquery
would interfere with result-column targeting,
in the case that a result-column had the same
ultimate name as a name inside the embedded
UNION. [ticket:2552]

CHANGES
lib/sqlalchemy/sql/compiler.py
test/sql/test_compiler.py

diff --git a/CHANGES b/CHANGES
index 87ef0446e54558eddb7288cc9552078db7fe5cc3..0044262517e7721196b028cf13079492701db8c6 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -64,6 +64,13 @@ CHANGES
     could be incorrect in certain "clone+replace"
     scenarios.  [ticket:2518]
 
+  - [bug] Fixed bug whereby usage of a UNION
+    or similar inside of an embedded subquery
+    would interfere with result-column targeting,
+    in the case that a result-column had the same
+    ultimate name as a name inside the embedded
+    UNION. [ticket:2552]
+
 - engine
   - [bug] Fixed bug whereby
     a disconnect detect + dispose that occurs
index 4ed1468444947cebbad7babcec87bb28e822916e..bd115314cf4a900a6676ee11c346768cb7730ffd 100644 (file)
@@ -547,7 +547,7 @@ class SQLCompiler(engine.Compiled):
         else:
             name = FUNCTIONS.get(func.__class__, func.name + "%(expr)s")
             return ".".join(list(func.packagenames) + [name]) % \
-                            {'expr':self.function_argspec(func, **kwargs)}
+                            {'expr': self.function_argspec(func, **kwargs)}
 
     def visit_next_value_func(self, next_value, **kw):
         return self.visit_sequence(next_value.sequence)
@@ -561,9 +561,10 @@ class SQLCompiler(engine.Compiled):
         return func.clause_expr._compiler_dispatch(self, **kwargs)
 
     def visit_compound_select(self, cs, asfrom=False,
-                            parens=True, compound_index=1, **kwargs):
+                            parens=True, compound_index=0, **kwargs):
         entry = self.stack and self.stack[-1] or {}
-        self.stack.append({'from':entry.get('from', None), 'iswrapper':True})
+        self.stack.append({'from': entry.get('from', None),
+                    'iswrapper': not entry})
 
         keyword = self.compound_keywords.get(cs.keyword)
 
@@ -584,7 +585,7 @@ class SQLCompiler(engine.Compiled):
                         self.limit_clause(cs) or ""
 
         if self.ctes and \
-            compound_index==1 and not entry:
+            compound_index == 0 and not entry:
             text = self._render_cte_clause() + text
 
         self.stack.pop(-1)
@@ -913,7 +914,7 @@ class SQLCompiler(engine.Compiled):
 
     def visit_select(self, select, asfrom=False, parens=True,
                             iswrapper=False, fromhints=None,
-                            compound_index=1,
+                            compound_index=0,
                             positional_names=None, **kwargs):
 
         entry = self.stack and self.stack[-1] or {}
@@ -929,14 +930,18 @@ class SQLCompiler(engine.Compiled):
         # to outermost if existingfroms: correlate_froms =
         # correlate_froms.union(existingfroms)
 
-        self.stack.append({'from': correlate_froms, 'iswrapper'
-                          : iswrapper})
+        populate_result_map = compound_index == 0 and (
+                                not entry or \
+                                entry.get('iswrapper', False)
+                            )
+
+        self.stack.append({'from': correlate_froms, 'iswrapper': iswrapper})
 
-        if compound_index==1 and not entry or entry.get('iswrapper', False):
-            column_clause_args = {'result_map':self.result_map,
-                                    'positional_names':positional_names}
+        if populate_result_map:
+            column_clause_args = {'result_map': self.result_map,
+                                    'positional_names': positional_names}
         else:
-            column_clause_args = {'positional_names':positional_names}
+            column_clause_args = {'positional_names': positional_names}
 
         # the actual list of columns to print in the SELECT column list.
         inner_columns = [
@@ -1012,7 +1017,7 @@ class SQLCompiler(engine.Compiled):
             text += self.for_update_clause(select)
 
         if self.ctes and \
-            compound_index==1 and not entry:
+            compound_index == 0 and not entry:
             text  = self._render_cte_clause() + text
 
         self.stack.pop(-1)
index 49de52d899ee31b37155b36035dc5ef2fc4c5d08..f62a6cdc6cefbced192d06e8ab7b05e713896070 100644 (file)
@@ -1,6 +1,6 @@
 #! coding:utf-8
 
-from test.lib.testing import eq_, assert_raises, assert_raises_message
+from test.lib.testing import eq_, is_, assert_raises, assert_raises_message
 import datetime, re, operator, decimal
 from sqlalchemy import *
 from sqlalchemy import exc, sql, util, types, schema
@@ -3159,3 +3159,54 @@ class CoercionTest(fixtures.TestBase, AssertsCompiledSQL):
         self.assert_compile(and_(t.c.id == 1, null()),
                             "foo.id = :id_1 AND NULL")
 
+
+class ResultMapTest(fixtures.TestBase):
+    """test the behavior of the 'entry stack' and the determination
+    when the result_map needs to be populated.
+
+    """
+    def test_compound_populates(self):
+        t = Table('t', MetaData(), Column('a', Integer), Column('b', Integer))
+        stmt = select([t]).union(select([t]))
+        comp = stmt.compile()
+        eq_(
+            comp.result_map,
+             {'a': ('a', (t.c.a, 'a', 'a'), t.c.a.type),
+             'b': ('b', (t.c.b, 'b', 'b'), t.c.b.type)}
+        )
+
+    def test_compound_not_toplevel_doesnt_populate(self):
+        t = Table('t', MetaData(), Column('a', Integer), Column('b', Integer))
+        subq = select([t]).union(select([t]))
+        stmt = select([t.c.a]).select_from(t.join(subq, t.c.a == subq.c.a))
+        comp = stmt.compile()
+        eq_(
+            comp.result_map,
+             {'a': ('a', (t.c.a, 'a', 'a'), t.c.a.type)}
+        )
+
+    def test_compound_only_top_populates(self):
+        t = Table('t', MetaData(), Column('a', Integer), Column('b', Integer))
+        stmt = select([t.c.a]).union(select([t.c.b]))
+        comp = stmt.compile()
+        eq_(
+            comp.result_map,
+             {'a': ('a', (t.c.a, 'a', 'a'), t.c.a.type)},
+        )
+
+    def test_label_conflict_union(self):
+        t1 = Table('t1', MetaData(), Column('a', Integer), Column('b', Integer))
+        t2 = Table('t2', MetaData(), Column('t1_a', Integer))
+        union = select([t2]).union(select([t2])).alias()
+
+        t1_alias = t1.alias()
+        stmt = select([t1, t1_alias]).select_from(
+                        t1.join(union, t1.c.a == union.c.t1_a)).apply_labels()
+        comp = stmt.compile()
+        eq_(
+            set(comp.result_map),
+            set(['t1_1_b', 't1_1_a', 't1_a', 't1_b'])
+        )
+        is_(
+            comp.result_map['t1_a'][1][1], t1.c.a
+        )
\ No newline at end of file