]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- compiler visit_label() checks a flag "within_order_by" and will render its own...
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 3 Aug 2008 21:19:32 +0000 (21:19 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 3 Aug 2008 21:19:32 +0000 (21:19 +0000)
and not its contained expression, if the dialect reports true for supports_simple_order_by_label.
the flag is not propagated forwards, meant to closely mimic the syntax Postgres expects which is
that only a simple name can be in the ORDER BY, not a more complex expression or function call
with the label name embedded (mysql and sqlite support more complex expressions).

This further sets the standard for propigation of **kwargs within compiler, that we can't just send
**kwargs along blindly to each XXX.process() call; whenever a **kwarg needs to propagate through,
most methods will have to be aware of it and know when they should send it on forward and when not.
This was actually already the case with result_map as well.

The supports_simple_order_by dialect flag defaults to True but is conservatively explicitly set to
False on all dialects except SQLite/MySQL/Postgres to start.

[ticket:1068]

13 files changed:
CHANGES
lib/sqlalchemy/databases/access.py
lib/sqlalchemy/databases/firebird.py
lib/sqlalchemy/databases/informix.py
lib/sqlalchemy/databases/maxdb.py
lib/sqlalchemy/databases/mssql.py
lib/sqlalchemy/databases/oracle.py
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/databases/sybase.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/sql/compiler.py
test/sql/query.py
test/sql/select.py

diff --git a/CHANGES b/CHANGES
index efd5dfb4eb9041d926e4706670595b728a9f6067..dd8a004988095604af8b98a35a76d110b286d112 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -76,6 +76,13 @@ CHANGES
 - sql
     - func.count() with no arguments renders as COUNT(*),
       equivalent to func.count(text('*')). 
+      
+    - simple label names in ORDER BY expressions render as 
+      themselves, and not as a re-statement of their 
+      corresponding expression.  This feature is currently 
+      enabled only for SQLite, MySQL, and Postgres.  
+      It can be enabled on other dialects as each is shown 
+      to support this behavior. [ticket:1068]
  
 - ext
     - Class-bound attributes sent as arguments to 
index 425c1eb69c3bc1194ae8f444d5b5cd3fdf88a7ba..0dfa62888262ab202989d5abcf3750c3e53c48ff 100644 (file)
@@ -170,6 +170,7 @@ class AccessDialect(default.DefaultDialect):
     name = 'access'
     supports_sane_rowcount = False
     supports_sane_multi_rowcount = False
+    supports_simple_order_by_label = False
 
     def type_descriptor(self, typeobj):
         newobj = types.adapt_type(typeobj, self.colspecs)
index 33ae4feab011f5f83a2080d1b601a055527273ac..412f7bce6025202297628045ee49b711d750bfed 100644 (file)
@@ -310,6 +310,7 @@ class FBDialect(default.DefaultDialect):
     max_identifier_length = 31
     preexecute_pk_sequences = True
     supports_pk_autoincrement = False
+    supports_simple_order_by_label = False
 
     def __init__(self, type_conv=200, concurrency_level=1, **kwargs):
         default.DefaultDialect.__init__(self, **kwargs)
@@ -675,7 +676,7 @@ class FBCompiler(sql.compiler.DefaultCompiler):
                         yield co
                 else:
                     yield c
-        columns = [self.process(c, render_labels=True)
+        columns = [self.process(c, within_columns_clause=True)
                    for c in flatten_columnlist(returning_cols)]
         text += ' RETURNING ' + ', '.join(columns)
         return text
index 130b08c41ad3ee5da1f527ff42296a29ab0148b4..39d0ee96a50915c4fd364c504216ef4c1554adfb 100644 (file)
@@ -202,6 +202,7 @@ class InfoDialect(default.DefaultDialect):
     default_paramstyle = 'qmark'
     # for informix 7.31
     max_identifier_length = 18
+    supports_simple_order_by_label = False
 
     def __init__(self, use_ansi=True, **kwargs):
         self.use_ansi = use_ansi
@@ -414,12 +415,12 @@ class InfoCompiler(compiler.DefaultCompiler):
         else:
             return compiler.DefaultCompiler.visit_function( self , func )
 
-    def visit_clauselist(self, list):
+    def visit_clauselist(self, list, within_order_by=False, **kwargs):
         try:
             li = [ c for c in list.clauses if c.name != 'oid' ]
         except:
             li = [ c for c in list.clauses ]
-        return ', '.join([s for s in [self.process(c) for c in li] if s is not None])
+        return ', '.join([s for s in [self.process(c, within_order_by=within_order_by) for c in li] if s is not None])
 
 class InfoSchemaGenerator(compiler.SchemaGenerator):
     def get_column_specification(self, column, first_pk=False):
index c9ea2b57923a7f54b4b708003911a80b2222c668..c51f1b43134170e5e12e4a3e1d6969109d2419cd 100644 (file)
@@ -473,6 +473,7 @@ class MaxDBDialect(default.DefaultDialect):
     supports_sane_rowcount = True
     supports_sane_multi_rowcount = False
     preexecute_pk_sequences = True
+    supports_simple_order_by_label = False
 
     # MaxDB-specific
     datetimeformat = 'internal'
index 0341d182307c90307b746580a99dcf9df335ba7b..f0385673680e1efe13be9b1bb8fa548e3604f161 100644 (file)
@@ -361,6 +361,8 @@ class MSSQLExecutionContext_pyodbc (MSSQLExecutionContext):
 
 class MSSQLDialect(default.DefaultDialect):
     name = 'mssql'
+    supports_simple_order_by_label = False
+
     colspecs = {
         sqltypes.Unicode : MSNVarchar,
         sqltypes.Integer : MSInteger,
index f2e5ba2f64ab1eb86773d2ac5b8e48329f14a96f..8341e3401ef7a53c8753f2c6fe103c7ecbd47303 100644 (file)
@@ -240,6 +240,7 @@ class OracleDialect(default.DefaultDialect):
     preexecute_pk_sequences = True
     supports_pk_autoincrement = False
     default_paramstyle = 'named'
+    supports_simple_order_by_label = False
 
     def __init__(self, use_ansi=True, auto_setinputsizes=True, auto_convert_lobs=True, threaded=True, allow_twophase=True, arraysize=50, **kwargs):
         default.DefaultDialect.__init__(self, **kwargs)
index 5d61b32ad475cf1ae15c7e438bb862a477a7b39a..744a573c90b275bf49182c284c62a21652ceaecb 100644 (file)
@@ -703,7 +703,7 @@ class PGCompiler(compiler.DefaultCompiler):
                         yield co
                 else:
                     yield c
-        columns = [self.process(c, render_labels=True) for c in flatten_columnlist(returning_cols)]
+        columns = [self.process(c, within_columns_clause=True) for c in flatten_columnlist(returning_cols)]
         text += ' RETURNING ' + string.join(columns, ', ')
         return text
 
index aea77f8bfe647bed85561afa7e1e16b1810a41e3..dd55ac0d2f76ba56511e5061b7280bd2563cc5b3 100644 (file)
@@ -455,6 +455,7 @@ class SybaseSQLDialect(default.DefaultDialect):
     supports_unicode_statements = False
     supports_sane_rowcount = False
     supports_sane_multi_rowcount = False
+    supports_simple_order_by_label = False
 
     def __new__(cls, dbapi=None, *args, **kwargs):
         if cls != SybaseSQLDialect:
index a3ae6d456585f8afd6a00e203e44f682e07ffa32..b240130f6d55d153d9beda7fb36958ad01d5e391 100644 (file)
@@ -40,6 +40,7 @@ class DefaultDialect(base.Dialect):
     dbapi_type_map = {}
     default_paramstyle = 'named'
     supports_default_values = True
+    supports_simple_order_by_label = True
 
     def __init__(self, convert_unicode=False, assert_unicode=False, encoding='utf-8', paramstyle=None, dbapi=None, **kwargs):
         self.convert_unicode = convert_unicode
index 044e5d5fe6d76d00fc3c55ede069a58a6417eab7..4ad07b49dc32c04273ebd06ce7cbe31e10bc4e50 100644 (file)
@@ -226,17 +226,24 @@ class DefaultCompiler(engine.Compiled):
     def visit_grouping(self, grouping, **kwargs):
         return "(" + self.process(grouping.element) + ")"
 
-    def visit_label(self, label, result_map=None, render_labels=False):
-        if not render_labels:
-            return self.process(label.element)
-            
-        labelname = self._truncated_identifier("colident", label.name)
+    def visit_label(self, label, result_map=None, within_columns_clause=False, within_order_by=False):
+        # only render labels within the columns clause
+        # or ORDER BY clause of a select.  dialect-specific compilers
+        # can modify this behavior.
+        if within_columns_clause:
+            labelname = self._truncated_identifier("colident", label.name)
 
-        if result_map is not None:
-            result_map[labelname.lower()] = (label.name, (label, label.element, labelname), label.element.type)
+            if result_map is not None:
+                result_map[labelname.lower()] = (label.name, (label, label.element, labelname), label.element.type)
 
-        return " ".join([self.process(label.element), self.operator_string(operators.as_), self.preparer.format_label(label, labelname)])
+            return " ".join([self.process(label.element), self.operator_string(operators.as_), self.preparer.format_label(label, labelname)])
+        elif within_order_by and self.dialect.supports_simple_order_by_label:
+            labelname = self._truncated_identifier("colident", label.name)
 
+            return self.preparer.format_label(label, labelname)
+        else:
+            return self.process(label.element)
+            
     def visit_column(self, column, result_map=None, **kwargs):
 
         if column._is_oid:
@@ -304,7 +311,7 @@ class DefaultCompiler(engine.Compiled):
     def visit_null(self, null, **kwargs):
         return 'NULL'
 
-    def visit_clauselist(self, clauselist, **kwargs):
+    def visit_clauselist(self, clauselist, within_order_by=False, **kwargs):
         sep = clauselist.operator
         if sep is None:
             sep = " "
@@ -312,7 +319,7 @@ class DefaultCompiler(engine.Compiled):
             sep = ', '
         else:
             sep = " " + self.operator_string(clauselist.operator) + " "
-        return sep.join(s for s in (self.process(c) for c in clauselist.clauses)
+        return sep.join(s for s in (self.process(c, within_order_by=within_order_by) for c in clauselist.clauses)
                         if s is not None)
 
     def visit_calculatedclause(self, clause, **kwargs):
@@ -332,8 +339,8 @@ class DefaultCompiler(engine.Compiled):
         else:
             return ".".join(func.packagenames + [name]) % {'expr':self.function_argspec(func)}
 
-    def function_argspec(self, func):
-        return self.process(func.clause_expr)
+    def function_argspec(self, func, **kwargs):
+        return self.process(func.clause_expr, **kwargs)
 
     def function_string(self, func):
         return self.functions.get(func.__class__, self.functions.get(func.name, func.name + "%(expr)s"))
@@ -364,8 +371,8 @@ class DefaultCompiler(engine.Compiled):
         else:
             return text
 
-    def visit_unary(self, unary, **kwargs):
-        s = self.process(unary.element)
+    def visit_unary(self, unary, within_order_by=False, **kwargs):
+        s = self.process(unary.element, within_order_by=within_order_by)
         if unary.operator:
             s = self.operator_string(unary.operator) + " " + s
         if unary.modifier:
@@ -505,7 +512,7 @@ class DefaultCompiler(engine.Compiled):
             [c for c in [
                 self.process(
                     self.label_select_column(select, co, asfrom=asfrom), 
-                    render_labels=True,
+                    within_columns_clause=True,
                     **column_clause_args) 
                 for co in select.inner_columns
             ]
@@ -557,7 +564,7 @@ class DefaultCompiler(engine.Compiled):
         return select._distinct and "DISTINCT " or ""
 
     def order_by_clause(self, select):
-        order_by = self.process(select._order_by_clause)
+        order_by = self.process(select._order_by_clause, within_order_by=True)
         if order_by:
             return " ORDER BY " + order_by
         else:
index a8432853185ab2b076b98df8da12033925ec68c4..6bed07a9bd0d699f7b2e322289c92a3cbf1cf619 100644 (file)
@@ -154,6 +154,38 @@ class QueryTest(TestBase):
             assert row['anon_1'] == 8
             assert row['anon_2'] == 10
 
+    def test_order_by_label(self):
+        """test that a label within an ORDER BY works on each backend.
+        
+        simple labels in ORDER BYs now render as the actual labelname 
+        which not every database supports.
+        
+        """
+        users.insert().execute(
+            {'user_id':7, 'user_name':'jack'},
+            {'user_id':8, 'user_name':'ed'},
+            {'user_id':9, 'user_name':'fred'},
+        )
+        
+        concat = ("test: " + users.c.user_name).label('thedata')
+        self.assertEquals(
+            select([concat]).order_by(concat).execute().fetchall(),
+            [("test: ed",), ("test: fred",), ("test: jack",)]
+        )
+
+        concat = ("test: " + users.c.user_name).label('thedata')
+        self.assertEquals(
+            select([concat]).order_by(desc(concat)).execute().fetchall(),
+            [("test: jack",), ("test: fred",), ("test: ed",)]
+        )
+
+        concat = ("test: " + users.c.user_name).label('thedata')
+        self.assertEquals(
+            select([concat]).order_by(concat + "x").execute().fetchall(),
+            [("test: ed",), ("test: fred",), ("test: jack",)]
+        )
+        
+        
     def test_row_comparison(self):
         users.insert().execute(user_id = 7, user_name = 'jack')
         rp = users.select().execute().fetchone()
index 70d21798c64f682724200545d1694f6900cdf415..b4e47c3e06cbf62689f04ccd31805269b106abe5 100644 (file)
@@ -3,6 +3,7 @@ import datetime, re, operator
 from sqlalchemy import *
 from sqlalchemy import exc, sql, util
 from sqlalchemy.sql import table, column, compiler
+from sqlalchemy.engine import default
 from sqlalchemy.databases import sqlite, postgres, mysql, oracle, firebird, mssql
 from testlib import *
 
@@ -326,6 +327,54 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A
         x = func.lala(table1.c.myid).label('foo')
         self.assert_compile(select([x], x==5), "SELECT lala(mytable.myid) AS foo FROM mytable WHERE lala(mytable.myid) = :param_1")
 
+    def test_labels_in_expressions(self):
+        """test that label() constructs in ORDER BY render as the labelname.
+        
+        Postgres' behavior was used as the guide for this,
+        so that only a simple label expression 
+        and not a more complex expression involving the label
+        name would be rendered using the label name.
+        
+        """
+        lab1 = (table1.c.myid + "12").label('foo')
+        lab2 = func.somefunc(table1.c.name).label('bar')
+
+        dialect = default.DefaultDialect()
+        self.assert_compile(select([lab1, lab2]).order_by(lab1, desc(lab2)), 
+            "SELECT mytable.myid + :myid_1 AS foo, somefunc(mytable.name) AS bar FROM mytable ORDER BY foo, bar DESC",
+            dialect=dialect
+        )
+
+        # the function embedded label renders as the function
+        self.assert_compile(select([lab1, lab2]).order_by(func.hoho(lab1), desc(lab2)), 
+            "SELECT mytable.myid + :myid_1 AS foo, somefunc(mytable.name) AS bar FROM mytable ORDER BY hoho(mytable.myid + :myid_1), bar DESC",
+            dialect=dialect
+        )
+
+        # binary expressions render as the expression without labels
+        self.assert_compile(select([lab1, lab2]).order_by(lab1 + "test"), 
+            "SELECT mytable.myid + :myid_1 AS foo, somefunc(mytable.name) AS bar FROM mytable ORDER BY mytable.myid + :myid_1 + :param_1",
+            dialect=dialect
+        )
+
+        # labels within functions in the columns clause render with the expression
+        self.assert_compile(
+            select([lab1, func.foo(lab1)]), 
+            "SELECT mytable.myid + :myid_1 AS foo, foo(mytable.myid + :myid_1) AS foo_1 FROM mytable",
+            dialect=dialect
+            )
+
+        dialect = default.DefaultDialect()
+        dialect.supports_simple_order_by_label = False
+        self.assert_compile(select([lab1, lab2]).order_by(lab1, desc(lab2)), 
+            "SELECT mytable.myid + :myid_1 AS foo, somefunc(mytable.name) AS bar FROM mytable ORDER BY mytable.myid + :myid_1, somefunc(mytable.name) DESC",
+            dialect=dialect
+        )
+        self.assert_compile(select([lab1, lab2]).order_by(func.hoho(lab1), desc(lab2)), 
+            "SELECT mytable.myid + :myid_1 AS foo, somefunc(mytable.name) AS bar FROM mytable ORDER BY hoho(mytable.myid + :myid_1), somefunc(mytable.name) DESC",
+            dialect=dialect
+        )
+
     def test_conjunctions(self):
         self.assert_compile(
             and_(table1.c.myid == 12, table1.c.name=='asdf', table2.c.othername == 'foo', "sysdate() = today()"),