]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Invoke column_expression() for subsequent SELECTs in CompoundSelect
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 29 Jul 2019 18:49:22 +0000 (14:49 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 29 Jul 2019 18:53:19 +0000 (14:53 -0400)
Fixed bug where :meth:`.TypeEngine.column_expression` method would not be
applied to subsequent SELECT statements inside of a UNION or other
:class:`.CompoundSelect`, even though the SELECT statements are rendered at
the topmost level of the statement.   New logic now differentiates between
rendering the column expression, which is needed for all SELECTs in the
list, vs. gathering the returned data type for the result row, which is
needed only for the first SELECT.

Fixes: #4787
Change-Id: Iceb63e430e76d2365649aa25ead09c4e2a062e10
(cherry picked from commit 2ce8a04e726daecbf060684dcee7559634506700)

doc/build/changelog/unreleased_13/4787.rst [new file with mode: 0644]
lib/sqlalchemy/sql/compiler.py
test/sql/test_type_expressions.py

diff --git a/doc/build/changelog/unreleased_13/4787.rst b/doc/build/changelog/unreleased_13/4787.rst
new file mode 100644 (file)
index 0000000..911a287
--- /dev/null
@@ -0,0 +1,11 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 4787
+
+    Fixed bug where :meth:`.TypeEngine.column_expression` method would not be
+    applied to subsequent SELECT statements inside of a UNION or other
+    :class:`.CompoundSelect`, even though the SELECT statements are rendered at
+    the topmost level of the statement.   New logic now differentiates between
+    rendering the column expression, which is needed for all SELECTs in the
+    list, vs. gathering the returned data type for the result row, which is
+    needed only for the first SELECT.
index b94857fed73d14bafc72ed8d27281a00fe489608..7c4a7a518356b42dc6ffd62de1b19cbdbb8b78c0 100644 (file)
@@ -1773,18 +1773,26 @@ class SQLCompiler(Compiled):
         column_clause_args,
         name=None,
         within_columns_clause=True,
+        need_column_expressions=False,
     ):
         """produce labeled columns present in a select()."""
 
         impl = column.type.dialect_impl(self.dialect)
-        if impl._has_column_expression and populate_result_map:
+
+        if impl._has_column_expression and (
+            need_column_expressions or populate_result_map
+        ):
             col_expr = impl.column_expression(column)
 
-            def add_to_result_map(keyname, name, objects, type_):
-                self._add_to_result_map(
-                    keyname, name, (column,) + objects, type_
-                )
+            if populate_result_map:
 
+                def add_to_result_map(keyname, name, objects, type_):
+                    self._add_to_result_map(
+                        keyname, name, (column,) + objects, type_
+                    )
+
+            else:
+                add_to_result_map = None
         else:
             col_expr = column
             if populate_result_map:
@@ -2039,15 +2047,15 @@ class SQLCompiler(Compiled):
         toplevel = not self.stack
         entry = self._default_stack_entry if toplevel else self.stack[-1]
 
-        populate_result_map = (
+        populate_result_map = need_column_expressions = (
             toplevel
-            or (
-                compound_index == 0
-                and entry.get("need_result_map_for_compound", False)
-            )
+            or entry.get("need_result_map_for_compound", False)
             or entry.get("need_result_map_for_nested", False)
         )
 
+        if compound_index > 0:
+            populate_result_map = False
+
         # this was first proposed as part of #3372; however, it is not
         # reached in current tests and could possibly be an assertion
         # instead.
@@ -2092,6 +2100,7 @@ class SQLCompiler(Compiled):
                     asfrom,
                     column_clause_args,
                     name=name,
+                    need_column_expressions=need_column_expressions,
                 )
                 for name, column in select._columns_plus_names
             ]
index 18877c338f917b6704a963e11b03ce1969202266..fe383d31467a3ef1f389033c9a30d1a39ffe0b21 100644 (file)
@@ -7,6 +7,7 @@ from sqlalchemy import String
 from sqlalchemy import Table
 from sqlalchemy import testing
 from sqlalchemy import TypeDecorator
+from sqlalchemy import union
 from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
@@ -272,6 +273,35 @@ class SelectTest(_ExprFixture, fixtures.TestBase, AssertsCompiledSQL):
             "test_table.y = outside_bind(inside_bind(:y_1))",
         )
 
+    def test_compound_select(self):
+        table = self._fixture()
+
+        s1 = select([table]).where(table.c.y == "hi")
+        s2 = select([table]).where(table.c.y == "there")
+
+        self.assert_compile(
+            union(s1, s2),
+            "SELECT test_table.x, lower(test_table.y) AS y "
+            "FROM test_table WHERE test_table.y = lower(:y_1) "
+            "UNION SELECT test_table.x, lower(test_table.y) AS y "
+            "FROM test_table WHERE test_table.y = lower(:y_2)",
+        )
+
+    def test_select_of_compound_select(self):
+        table = self._fixture()
+
+        s1 = select([table]).where(table.c.y == "hi")
+        s2 = select([table]).where(table.c.y == "there")
+
+        self.assert_compile(
+            union(s1, s2).alias().select(),
+            "SELECT anon_1.x, lower(anon_1.y) AS y FROM "
+            "(SELECT test_table.x AS x, test_table.y AS y "
+            "FROM test_table WHERE test_table.y = lower(:y_1) "
+            "UNION SELECT test_table.x AS x, test_table.y AS y "
+            "FROM test_table WHERE test_table.y = lower(:y_2)) AS anon_1",
+        )
+
 
 class DerivedTest(_ExprFixture, fixtures.TestBase, AssertsCompiledSQL):
     __dialect__ = "default"