]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- type expressions invoke in SQL, but are only for the benefit of columns
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 3 Sep 2012 14:28:26 +0000 (10:28 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 3 Sep 2012 14:28:26 +0000 (10:28 -0400)
delivered to a result set. therefore these expressions should only be rendered
for those columns that are being delivered to the result, thereby preventing
the expression from stacking onto itself within nesting scenarios.

lib/sqlalchemy/sql/compiler.py
test/sql/test_type_expressions.py

index 2fc14c84c5f7898c02721eb5f5d531be5121ef26..3778c768374def7cf55a060f87aeb97e2656ceb0 100644 (file)
@@ -986,15 +986,13 @@ class SQLCompiler(engine.Compiled):
                                     within_columns_clause=True):
         """produce labeled columns present in a select()."""
 
-        if column.type._has_column_expression:
+        if column.type._has_column_expression and \
+            populate_result_map:
             col_expr = column.type.column_expression(column)
-            if populate_result_map:
-                add_to_result_map = lambda keyname, name, objects, type_: \
-                                    self._add_to_result_map(
-                                            keyname, name,
-                                            objects + (column,), type_)
-            else:
-                add_to_result_map = None
+            add_to_result_map = lambda keyname, name, objects, type_: \
+                                self._add_to_result_map(
+                                        keyname, name,
+                                        objects + (column,), type_)
         else:
             col_expr = column
             if populate_result_map:
index 1a331d5703942b4868fa1a7ac38ee4bfe88f6750..320dc5d7c27856bfe49df5683dd955275e802d19 100644 (file)
@@ -2,9 +2,7 @@ from sqlalchemy import Table, Column, String, func, MetaData, select, TypeDecora
 from test.lib import fixtures, AssertsCompiledSQL, testing
 from test.lib.testing import eq_
 
-class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
-    __dialect__ = 'default'
-
+class _ExprFixture(object):
     def _fixture(self):
         class MyString(String):
             def bind_expression(self, bindvalue):
@@ -19,6 +17,9 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
         )
         return test_table
 
+class SelectTest(_ExprFixture, fixtures.TestBase, AssertsCompiledSQL):
+    __dialect__ = 'default'
+
     def test_select_cols(self):
         table = self._fixture()
 
@@ -84,6 +85,41 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
             "test_table WHERE test_table.y = lower(:y_1)"
         )
 
+class DerivedTest(_ExprFixture, fixtures.TestBase, AssertsCompiledSQL):
+    __dialect__ = 'default'
+
+    def test_select_from_select(self):
+        table = self._fixture()
+        self.assert_compile(
+            table.select().select(),
+            "SELECT x, lower(y) AS y FROM (SELECT test_table.x "
+                "AS x, test_table.y AS y FROM test_table)"
+        )
+
+    def test_select_from_alias(self):
+        table = self._fixture()
+        self.assert_compile(
+            table.select().alias().select(),
+            "SELECT anon_1.x, lower(anon_1.y) AS y FROM (SELECT "
+                "test_table.x AS x, test_table.y AS y "
+                "FROM test_table) AS anon_1"
+        )
+
+    def test_select_from_aliased_join(self):
+        table = self._fixture()
+        s1 = table.select().alias()
+        s2 = table.select().alias()
+        j = s1.join(s2, s1.c.x == s2.c.x)
+        s3 = j.select()
+        self.assert_compile(s3,
+            "SELECT anon_1.x, lower(anon_1.y) AS y, anon_2.x, "
+            "lower(anon_2.y) AS y "
+            "FROM (SELECT test_table.x AS x, test_table.y AS y "
+            "FROM test_table) AS anon_1 JOIN (SELECT "
+            "test_table.x AS x, test_table.y AS y "
+            "FROM test_table) AS anon_2 ON anon_1.x = anon_2.x"
+        )
+
 class RoundTripTestBase(object):
     def test_round_trip(self):
         testing.db.execute(