]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
attempt number one, doesn't detect though if the label in the order by is not directl...
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 27 May 2013 23:22:59 +0000 (19:22 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 27 May 2013 23:22:59 +0000 (19:22 -0400)
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/sql/compiler.py
test/sql/test_compiler.py

index daa9fe0855a8eff1999e6e8075055f8cd3e68dcf..dc45e12b16ef7673be4840948b0805f4336ad411 100644 (file)
@@ -52,6 +52,8 @@ class DefaultDialect(interfaces.Dialect):
     supports_native_enum = False
     supports_native_boolean = False
 
+    supports_simple_order_by_label = True
+
     # if the NUMERIC type
     # returns decimal.Decimal.
     # *not* the FLOAT type however.
index b902f9ffc27b461c78ef0585331403258cbc29a2..8eb8c5fd88abedea8ea79bfffc97e9fafdf9d6e4 100644 (file)
@@ -389,16 +389,24 @@ class SQLCompiler(engine.Compiled):
     def visit_label(self, label,
                             add_to_result_map=None,
                             within_label_clause=False,
-                            within_columns_clause=False, **kw):
+                            within_columns_clause=False,
+                            order_by_labels=None, **kw):
         # only render labels within the columns clause
         # or ORDER BY clause of a select.  dialect-specific compilers
         # can modify this behavior.
-        if within_columns_clause and not within_label_clause:
+#        if order_by_labels:
+#            import pdb
+#            pdb.set_trace()
+        render_label_with_as = within_columns_clause and not within_label_clause
+        render_label_only = order_by_labels and label in order_by_labels
+
+        if render_label_only or render_label_with_as:
             if isinstance(label.name, sql._truncated_label):
                 labelname = self._truncated_identifier("colident", label.name)
             else:
                 labelname = label.name
 
+        if render_label_with_as:
             if add_to_result_map is not None:
                 add_to_result_map(
                         labelname,
@@ -413,6 +421,8 @@ class SQLCompiler(engine.Compiled):
                                     **kw) + \
                         OPERATORS[operators.as_] + \
                         self.preparer.format_label(label, labelname)
+        elif render_label_only:
+            return labelname
         else:
             return label.element._compiler_dispatch(self,
                                     within_columns_clause=False,
@@ -1181,7 +1191,13 @@ class SQLCompiler(engine.Compiled):
                 text += " \nHAVING " + t
 
         if select._order_by_clause.clauses:
-            text += self.order_by_clause(select, **kwargs)
+            if self.dialect.supports_simple_order_by_label:
+                order_by_labels = set(c for k, c in select._columns_plus_names)
+            else:
+                order_by_labels = None
+
+            text += self.order_by_clause(select,
+                                    order_by_labels=order_by_labels, **kwargs)
         if select._limit is not None or select._offset is not None:
             text += self.limit_clause(select)
         if select.for_update:
index 9cd893c1adc871c22764f9e2bfe9721b5733506c..887676f948813bb63f79f517649d8155ac3fa836 100644 (file)
@@ -746,6 +746,69 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
                 'foo || :param_1')
 
 
+    def test_labels_in_expressions(self):
+        """test that label() constructs in ORDER BY render as the labelname.
+
+        Postgres' behavior was used as the guide for this,
+        so that only a simple label expression
+        and not a more complex expression involving the label
+        name would be rendered using the label name.
+
+        """
+        lab1 = (table1.c.myid + "12").label('foo')
+        lab2 = func.somefunc(table1.c.name).label('bar')
+
+        dialect = default.DefaultDialect()
+        self.assert_compile(select([lab1, lab2]).order_by(lab1, desc(lab2)),
+            "SELECT mytable.myid + :myid_1 AS foo, "
+            "somefunc(mytable.name) AS bar FROM mytable "
+            "ORDER BY foo, bar DESC",
+            dialect=dialect
+        )
+
+        # the function embedded label renders as the function
+        self.assert_compile(
+            select([lab1, lab2]).order_by(func.hoho(lab1), desc(lab2)),
+            "SELECT mytable.myid + :myid_1 AS foo, "
+            "somefunc(mytable.name) AS bar FROM mytable "
+            "ORDER BY hoho(mytable.myid + :myid_1), bar DESC",
+            dialect=dialect
+        )
+
+        # binary expressions render as the expression without labels
+        self.assert_compile(select([lab1, lab2]).order_by(lab1 + "test"),
+            "SELECT mytable.myid + :myid_1 AS foo, "
+            "somefunc(mytable.name) AS bar FROM mytable "
+            "ORDER BY mytable.myid + :myid_1 + :param_1",
+            dialect=dialect
+        )
+
+        # labels within functions in the columns clause render
+        # with the expression
+        self.assert_compile(
+            select([lab1, func.foo(lab1)]),
+            "SELECT mytable.myid + :myid_1 AS foo, "
+            "foo(mytable.myid + :myid_1) AS foo_1 FROM mytable",
+            dialect=dialect
+            )
+
+        dialect = default.DefaultDialect()
+        dialect.supports_simple_order_by_label = False
+        self.assert_compile(select([lab1, lab2]).order_by(lab1, desc(lab2)),
+            "SELECT mytable.myid + :myid_1 AS foo, "
+            "somefunc(mytable.name) AS bar FROM mytable "
+            "ORDER BY mytable.myid + :myid_1, somefunc(mytable.name) DESC",
+            dialect=dialect
+        )
+        self.assert_compile(
+            select([lab1, lab2]).order_by(func.hoho(lab1), desc(lab2)),
+            "SELECT mytable.myid + :myid_1 AS foo, "
+            "somefunc(mytable.name) AS bar FROM mytable "
+            "ORDER BY hoho(mytable.myid + :myid_1), "
+            "somefunc(mytable.name) DESC",
+            dialect=dialect
+        )
+
     def test_conjunctions(self):
         a, b, c = 'a', 'b', 'c'
         x = and_(a, b, c)