]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
still not locating more nested expressions, may need to match on name
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 28 May 2013 01:05:16 +0000 (21:05 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 28 May 2013 01:05:16 +0000 (21:05 -0400)
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/testing/suite/__init__.py
lib/sqlalchemy/testing/suite/test_select.py [new file with mode: 0644]
test/orm/test_eager_relations.py
test/sql/test_compiler.py

index 8eb8c5fd88abedea8ea79bfffc97e9fafdf9d6e4..5fbfa34f3d35c4b82337cf939c94f6f928e2d2a7 100644 (file)
@@ -390,15 +390,13 @@ class SQLCompiler(engine.Compiled):
                             add_to_result_map=None,
                             within_label_clause=False,
                             within_columns_clause=False,
-                            order_by_labels=None, **kw):
+                            render_label_as_label=None,
+                            **kw):
         # only render labels within the columns clause
         # or ORDER BY clause of a select.  dialect-specific compilers
         # can modify this behavior.
-#        if order_by_labels:
-#            import pdb
-#            pdb.set_trace()
         render_label_with_as = within_columns_clause and not within_label_clause
-        render_label_only = order_by_labels and label in order_by_labels
+        render_label_only = render_label_as_label is label
 
         if render_label_only or render_label_with_as:
             if isinstance(label.name, sql._truncated_label):
@@ -518,7 +516,9 @@ class SQLCompiler(engine.Compiled):
     def visit_false(self, expr, **kw):
         return 'false'
 
-    def visit_clauselist(self, clauselist, **kwargs):
+    def visit_clauselist(self, clauselist, order_by_select=None, **kw):
+        if order_by_select is not None:
+            return self._order_by_clauselist(clauselist, order_by_select, **kw)
         sep = clauselist.operator
         if sep is None:
             sep = " "
@@ -526,8 +526,34 @@ class SQLCompiler(engine.Compiled):
             sep = OPERATORS[clauselist.operator]
         return sep.join(
                     s for s in
-                    (c._compiler_dispatch(self, **kwargs)
-                    for c in clauselist.clauses)
+                    (
+                        c._compiler_dispatch(self, **kw)
+                        for c in clauselist.clauses)
+                    if s)
+
+    def _order_by_clauselist(self, clauselist, order_by_select, **kw):
+        # look through raw columns collection for labels.
+        # note that its OK we aren't expanding tables and other selectables
+        # here; we can only add a label in the ORDER BY for an individual
+        # label expression in the columns clause.
+        raw_col = set(order_by_select._raw_columns)
+        def label_ok(c):
+            if c in raw_col:
+                return c
+            elif getattr(c, 'modifier', None) in \
+                    (operators.desc_op, operators.asc_op) and \
+                    c.element.proxy_set.intersection(raw_col):
+                return c.element
+            else:
+                return None
+
+        return ", ".join(
+                    s for s in
+                    (
+                        c._compiler_dispatch(self,
+                                render_label_as_label=label_ok(c),
+                                **kw)
+                        for c in clauselist.clauses)
                     if s)
 
     def visit_case(self, clause, **kwargs):
@@ -1192,12 +1218,12 @@ class SQLCompiler(engine.Compiled):
 
         if select._order_by_clause.clauses:
             if self.dialect.supports_simple_order_by_label:
-                order_by_labels = set(c for k, c in select._columns_plus_names)
+                order_by_select = select
             else:
-                order_by_labels = None
+                order_by_select = None
 
             text += self.order_by_clause(select,
-                                    order_by_labels=order_by_labels, **kwargs)
+                            order_by_select=order_by_select, **kwargs)
         if select._limit is not None or select._offset is not None:
             text += self.limit_clause(select)
         if select.for_update:
index f65dd1a3431ead281288a22f96b9dfa2af0bb8bd..780aa40aa71b31053d73ffd8f03bab78a23b5e17 100644 (file)
@@ -2,6 +2,7 @@
 from sqlalchemy.testing.suite.test_ddl import *
 from sqlalchemy.testing.suite.test_insert import *
 from sqlalchemy.testing.suite.test_sequence import *
+from sqlalchemy.testing.suite.test_select import *
 from sqlalchemy.testing.suite.test_results import *
 from sqlalchemy.testing.suite.test_update_delete import *
 from sqlalchemy.testing.suite.test_reflection import *
diff --git a/lib/sqlalchemy/testing/suite/test_select.py b/lib/sqlalchemy/testing/suite/test_select.py
new file mode 100644 (file)
index 0000000..b040c8f
--- /dev/null
@@ -0,0 +1,83 @@
+from .. import fixtures, config
+from ..assertions import eq_
+
+from sqlalchemy import Integer, String, select, func
+
+from ..schema import Table, Column
+
+
+class OrderByLabelTest(fixtures.TablesTest):
+    """Test the dialect sends appropriate ORDER BY expressions when
+    labels are used.
+
+    This essentially exercises the "supports_simple_order_by_label"
+    setting.
+
+    """
+    @classmethod
+    def define_tables(cls, metadata):
+        Table("some_table", metadata,
+            Column('id', Integer, primary_key=True),
+            Column('x', Integer),
+            Column('y', Integer),
+            Column('q', String(50)),
+            Column('p', String(50))
+            )
+
+    @classmethod
+    def insert_data(cls):
+        config.db.execute(
+            cls.tables.some_table.insert(),
+            [
+                {"id": 1, "x": 1, "y": 2, "q": "q1", "p": "p3"},
+                {"id": 2, "x": 2, "y": 3, "q": "q2", "p": "p2"},
+                {"id": 3, "x": 3, "y": 4, "q": "q3", "p": "p1"},
+            ]
+        )
+
+    def _assert_result(self, select, result):
+        eq_(
+            config.db.execute(select).fetchall(),
+            result
+        )
+
+    def test_plain(self):
+        table = self.tables.some_table
+        lx = table.c.x.label('lx')
+        self._assert_result(
+            select([lx]).order_by(lx),
+            [(1, ), (2, ), (3, )]
+        )
+
+    def test_composed_int(self):
+        table = self.tables.some_table
+        lx = (table.c.x + table.c.y).label('lx')
+        self._assert_result(
+            select([lx]).order_by(lx),
+            [(3, ), (5, ), (7, )]
+        )
+
+    def test_composed_multiple(self):
+        table = self.tables.some_table
+        lx = (table.c.x + table.c.y).label('lx')
+        ly = (func.lower(table.c.q) + table.c.p).label('ly')
+        self._assert_result(
+            select([lx, ly]).order_by(lx, ly.desc()),
+            [(3, u'q1p3'), (5, u'q2p2'), (7, u'q3p1')]
+        )
+
+    def test_plain_desc(self):
+        table = self.tables.some_table
+        lx = table.c.x.label('lx')
+        self._assert_result(
+            select([lx]).order_by(lx.desc()),
+            [(3, ), (2, ), (1, )]
+        )
+
+    def test_composed_int_desc(self):
+        table = self.tables.some_table
+        lx = (table.c.x + table.c.y).label('lx')
+        self._assert_result(
+            select([lx]).order_by(lx.desc()),
+            [(7, ), (5, ), (3, )]
+        )
index bd85e4ce830f01070dc11fa60db1c0292d7b7b6b..b240d29f6d894779acb837b5f98888948d0e8234 100644 (file)
@@ -1386,8 +1386,7 @@ class SubqueryAliasingTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
             "AS anon_1_a_id, b_1.id AS b_1_id, b_1.a_id AS "
             "b_1_a_id, b_1.value AS b_1_value FROM (SELECT "
             "(SELECT sum(b.value) AS sum_1 FROM b WHERE b.a_id = a.id) "
-            "AS anon_2, a.id AS a_id FROM a ORDER BY (SELECT "
-            "sum(b.value) AS sum_1 FROM b WHERE b.a_id = a.id) "
+            "AS anon_2, a.id AS a_id FROM a ORDER BY anon_2 "
             "LIMIT :param_1) AS anon_1 LEFT OUTER JOIN b AS b_1 ON "
             "anon_1.a_id = b_1.a_id ORDER BY anon_1.anon_2"
         )
@@ -1409,8 +1408,7 @@ class SubqueryAliasingTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
             "AS anon_1_a_id, b_1.id AS b_1_id, b_1.a_id AS "
             "b_1_a_id, b_1.value AS b_1_value FROM (SELECT "
             "(SELECT sum(b.value) AS sum_1 FROM b WHERE b.a_id = a.id) "
-            "AS anon_2, a.id AS a_id FROM a ORDER BY (SELECT "
-            "sum(b.value) AS sum_1 FROM b WHERE b.a_id = a.id) DESC "
+            "AS anon_2, a.id AS a_id FROM a ORDER BY anon_2 DESC "
             "LIMIT :param_1) AS anon_1 LEFT OUTER JOIN b AS b_1 ON "
             "anon_1.a_id = b_1.a_id ORDER BY anon_1.anon_2 DESC"
         )
@@ -1433,8 +1431,7 @@ class SubqueryAliasingTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
             "AS anon_1_a_id, b_1.id AS b_1_id, b_1.a_id AS "
             "b_1_a_id, b_1.value AS b_1_value FROM (SELECT "
             "(SELECT sum(b.value) AS sum_1 FROM b WHERE b.a_id = a.id) "
-            "AS anon_2, a.id AS a_id FROM a ORDER BY (SELECT "
-            "sum(b.value) AS sum_1 FROM b WHERE b.a_id = a.id) "
+            "AS anon_2, a.id AS a_id FROM a ORDER BY anon_2 "
             "LIMIT :param_1) AS anon_1 LEFT OUTER JOIN b AS b_1 ON "
             "anon_1.a_id = b_1.a_id ORDER BY anon_1.anon_2"
         )
@@ -1479,8 +1476,7 @@ class SubqueryAliasingTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
             "AS anon_1_foo, b_1.id AS b_1_id, b_1.a_id AS "
             "b_1_a_id, b_1.value AS b_1_value FROM (SELECT a.id "
             "AS a_id, (SELECT sum(b.value) AS sum_1 FROM b WHERE "
-            "b.a_id = a.id) AS foo FROM a ORDER BY (SELECT "
-            "sum(b.value) AS sum_1 FROM b WHERE b.a_id = a.id) "
+            "b.a_id = a.id) AS foo FROM a ORDER BY foo "
             "LIMIT :param_1) AS anon_1 LEFT OUTER JOIN b AS b_1 "
             "ON anon_1.a_id = b_1.a_id ORDER BY "
             "anon_1.foo"
index 887676f948813bb63f79f517649d8155ac3fa836..d5f52bdf3557755b90f16150ba68ed9ea29c9323 100644 (file)
@@ -746,19 +746,24 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
                 'foo || :param_1')
 
 
-    def test_labels_in_expressions(self):
-        """test that label() constructs in ORDER BY render as the labelname.
+    def test_foo(self):
+        lx = (table1.c.myid + table1.c.myid).label('lx')
+        ly = (func.lower(table1.c.name) + table1.c.description).label('ly')
+        dialect = default.DefaultDialect()
 
-        Postgres' behavior was used as the guide for this,
-        so that only a simple label expression
-        and not a more complex expression involving the label
-        name would be rendered using the label name.
+        self.assert_compile(
+            select([lx, ly]).order_by(lx, ly.desc()),
+            "SELECT mytable.myid + mytable.myid AS lx, "
+            "lower(mytable.name) || mytable.description AS ly "
+            "FROM mytable ORDER BY lx, ly DESC",
+            dialect=dialect
+            )
 
-        """
-        lab1 = (table1.c.myid + "12").label('foo')
+    def test_labels_in_expressions(self):
+        lab1 = (table1.c.myid + 12).label('foo')
         lab2 = func.somefunc(table1.c.name).label('bar')
-
         dialect = default.DefaultDialect()
+
         self.assert_compile(select([lab1, lab2]).order_by(lab1, desc(lab2)),
             "SELECT mytable.myid + :myid_1 AS foo, "
             "somefunc(mytable.name) AS bar FROM mytable "
@@ -786,9 +791,22 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
         # labels within functions in the columns clause render
         # with the expression
         self.assert_compile(
-            select([lab1, func.foo(lab1)]),
+            select([lab1, func.foo(lab1)]).order_by(lab1, func.foo(lab1)),
             "SELECT mytable.myid + :myid_1 AS foo, "
-            "foo(mytable.myid + :myid_1) AS foo_1 FROM mytable",
+            "foo(mytable.myid + :myid_1) AS foo_1 FROM mytable "
+            "ORDER BY foo, foo(mytable.myid + :myid_1)",
+            dialect=dialect
+            )
+
+
+        lx = (table1.c.myid + table1.c.myid).label('lx')
+        ly = (func.lower(table1.c.name) + table1.c.description).label('ly')
+
+        self.assert_compile(
+            select([lx, ly]).order_by(lx, ly.desc()),
+            "SELECT mytable.myid + mytable.myid AS lx, "
+            "lower(mytable.name) || mytable.description AS ly "
+            "FROM mytable ORDER BY lx, ly DESC",
             dialect=dialect
             )