]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- case() interprets the "THEN" expressions
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 3 Apr 2008 16:34:03 +0000 (16:34 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 3 Apr 2008 16:34:03 +0000 (16:34 +0000)
as values by default, meaning case([(x==y, "foo")]) will
interpret "foo" as a bound value, not a SQL expression.
use text(expr) for literal SQL expressions in this case.
For the criterion itself, these may be literal strings
only if the "value" keyword is present, otherwise SA
will force explicit usage of either text() or literal().

CHANGES
lib/sqlalchemy/sql/expression.py
test/sql/case_statement.py

diff --git a/CHANGES b/CHANGES
index 53eb7683ee8b1877a2381cfa581ddd879121832a..a83db182c18de8beae3c60562822b4dac8355eca 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -196,8 +196,13 @@ CHANGES
       symptom.
 
     - The case() function now also takes a dictionary as its whens
-      parameter. But beware that it doesn't escape literals, use
-      the literal construct for that.
+      parameter.  It also interprets the "THEN" expressions
+      as values by default, meaning case([(x==y, "foo")]) will
+      interpret "foo" as a bound value, not a SQL expression.
+      use text(expr) for literal SQL expressions in this case.
+      For the criterion itself, these may be literal strings
+      only if the "value" keyword is present, otherwise SA
+      will force explicit usage of either text() or literal().
 
 - declarative extension
     - The "synonym" function is now directly usable with
index cc97227a702832cf1a6b9138ae332bfd4641dfc4..39a2ae3eb93bb5d6b88882218caafc5903d3e3bb 100644 (file)
@@ -392,7 +392,7 @@ def not_(clause):
     result.
     """
 
-    return operators.inv(clause)
+    return operators.inv(_literal_as_binds(clause))
 
 def distinct(expr):
     """Return a ``DISTINCT`` clause."""
@@ -416,24 +416,45 @@ def case(whens, value=None, else_=None):
     """Produce a ``CASE`` statement.
 
     whens
-      A sequence of pairs or a dict to be translated into "when / then" clauses.
+      A sequence of pairs, or alternatively a dict,
+      to be translated into "WHEN / THEN" clauses.
 
     value
-      Optional for simple case statements.
+      Optional for simple case statements, produces
+      a column expression as in "CASE <expr> WHEN ..."
 
     else\_
-      Optional as well, for case defaults.
+      Optional as well, for case defaults produces 
+      the "ELSE" portion of the "CASE" statement.
+    
+    The expressions used for THEN and ELSE,
+    when specified as strings, will be interpreted 
+    as bound values. To specify textual SQL expressions 
+    for these, use the text(<string>) construct.
+    
+    The expressions used for the WHEN criterion
+    may only be literal strings when "value" is 
+    present, i.e. CASE table.somecol WHEN "x" THEN "y".  
+    Otherwise, literal strings are not accepted 
+    in this position, and either the text(<string>)
+    or literal(<string>) constructs must be used to 
+    interpret raw string values.
+      
     """
-
     try:
         whens = util.dictlike_iteritems(whens)
     except TypeError:
         pass
-
-    whenlist = [ClauseList('WHEN', c, 'THEN', r, operator=None)
+    
+    if value:
+        crit_filter = _literal_as_binds
+    else:
+        crit_filter = _no_literals
+        
+    whenlist = [ClauseList('WHEN', crit_filter(c), 'THEN', _literal_as_binds(r), operator=None)
                 for (c,r) in whens]
-    if not else_ is None:
-        whenlist.append(ClauseList('ELSE', else_, operator=None))
+    if else_ is not None:
+        whenlist.append(ClauseList('ELSE', _literal_as_binds(else_), operator=None))
     if whenlist:
         type = list(whenlist[-1])[-1].type
     else:
@@ -842,6 +863,14 @@ def _literal_as_binds(element, name=None, type_=None):
     else:
         return element
 
+def _no_literals(element):
+    if isinstance(element, Operators):
+        return element.expression_element()
+    elif _is_literal(element):
+        raise exceptions.ArgumentError("Ambiguous literal: %r.  Use the 'text()' function to indicate a SQL expression literal, or 'literal()' to indicate a bound value." % element)
+    else:
+        return element
+    
 def _corresponding_column_or_error(fromclause, column, require_embedded=False):
     c = fromclause.corresponding_column(column, require_embedded=require_embedded)
     if not c:
index 730517b21080db4a16ba2631ebc2588e125b3e2c..257298c8e5a718474c26c3055c96f8940d57a4eb 100644 (file)
@@ -2,10 +2,11 @@ import testenv; testenv.configure_for_tests()
 import sys
 from sqlalchemy import *
 from testlib import *
-from sqlalchemy import util
+from sqlalchemy import util, exceptions
+from sqlalchemy.sql import table, column
 
 
-class CaseTest(TestBase):
+class CaseTest(TestBase, AssertsCompiledSQL):
 
     def setUpAll(self):
         metadata = MetaData(testing.db)
@@ -30,9 +31,9 @@ class CaseTest(TestBase):
     def testcase(self):
         inner = select([case([
                 [info_table.c.pk < 3,
-                        literal('lessthan3', type_=String)],
+                        'lessthan3'],
         [and_(info_table.c.pk >= 3, info_table.c.pk < 7),
-                        literal('gt3', type_=String)]]).label('x'),
+                        'gt3']]).label('x'),
         info_table.c.pk, info_table.c.info],
                 from_obj=[info_table]).alias('q_inner')
 
@@ -69,9 +70,9 @@ class CaseTest(TestBase):
 
         w_else = select([case([
                 [info_table.c.pk < 3,
-                        literal(3, type_=Integer)],
+                        3],
         [and_(info_table.c.pk >= 3, info_table.c.pk < 6),
-                        literal(6, type_=Integer)]],
+                        6]],
                 else_ = 0).label('x'),
         info_table.c.pk, info_table.c.info],
                 from_obj=[info_table]).alias('q_inner')
@@ -87,12 +88,21 @@ class CaseTest(TestBase):
             (0, 6, 'pk_6_data')
         ]
 
+    def test_literal_interpretation(self):
+        t = table('test', column('col1'))
+        
+        self.assertRaises(exceptions.ArgumentError, case, [("x", "y")])
+        
+        self.assert_compile(case([("x", "y")], value=t.c.col1), "CASE test.col1 WHEN :param_1 THEN :param_2 END")
+        self.assert_compile(case([(t.c.col1==7, "y")], else_="z"), "CASE WHEN (test.col1 = :test_col1_1) THEN :param_1 ELSE :param_2 END")
+
+        
     @testing.fails_on('maxdb')
     def testcase_with_dict(self):
         query = select([case({
-                    info_table.c.pk < 3: literal('lessthan3'),
-                    info_table.c.pk >= 3: literal('gt3'),
-                }, else_=literal('other')),
+                    info_table.c.pk < 3: 'lessthan3',
+                    info_table.c.pk >= 3: 'gt3',
+                }, else_='other'),
                 info_table.c.pk, info_table.c.info
             ],
             from_obj=[info_table])
@@ -106,13 +116,14 @@ class CaseTest(TestBase):
         ]
 
         simple_query = select([case({
-                    1: literal('one'),
-                    2: literal('two'),
-                }, value=info_table.c.pk, else_=literal('other')),
+                    1: 'one',
+                    2: 'two',
+                }, value=info_table.c.pk, else_='other'),
                 info_table.c.pk
             ],
             whereclause=info_table.c.pk < 4,
             from_obj=[info_table])
+        
         assert simple_query.execute().fetchall() == [
             ('one', 1),
             ('two', 2),