]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
fix CASE statement when else_ is zero
authorRick Morrison <rickmorrison@gmail.com>
Thu, 15 Mar 2007 01:58:46 +0000 (01:58 +0000)
committerRick Morrison <rickmorrison@gmail.com>
Thu, 15 Mar 2007 01:58:46 +0000 (01:58 +0000)
CHANGES
lib/sqlalchemy/sql.py
test/sql/case_statement.py

diff --git a/CHANGES b/CHANGES
index 01348d468b6e6cf3706f866e4dc9ae14889ed363..745922297b3e5001b5d3ca74d1e4ad1844258f53 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -45,6 +45,9 @@
     with an easy method of limiting a traversal (just pass flags which
     are picked up by appropriate get_children() methods). [ticket:501]
 
+    - the "else_" parameter to the case statement now properly works when
+    set to zero.
+
 - oracle:
     - got binary working for any size input !  cx_oracle works fine,
       it was my fault as BINARY was being passed and not BLOB for
index 83124069732692f7296af0cda1d6d4477a9aaefe..d48e385f25e7939e259825ea9a1c70eb43177ed2 100644 (file)
@@ -243,7 +243,7 @@ def case(whens, value=None, else_=None):
     """
 
     whenlist = [_CompoundClause(None, 'WHEN', c, 'THEN', r) for (c,r) in whens]
-    if else_:
+    if not else_ is None:
         whenlist.append(_CompoundClause(None, 'ELSE', else_))
     cc = _CalculatedClause(None, 'CASE', value, *whenlist + ['END'])
     for c in cc.clauses:
index 1a75d5289ee84e0786585441394a8ac4dada73d4..946279b9dcf52248717c29365a7836655296722d 100644 (file)
@@ -17,15 +17,20 @@ class CaseTest(testbase.PersistTest):
                {'pk':1, 'info':'pk_1_data'},
                {'pk':2, 'info':'pk_2_data'},
                {'pk':3, 'info':'pk_3_data'},
-               {'pk':4, 'info':'pk_4_data'},
-           {'pk':5, 'info':'pk_5_data'})
+               {'pk':4, 'info':'pk_4_data'},
+               {'pk':5, 'info':'pk_5_data'},
+               {'pk':6, 'info':'pk_6_data'})
     def tearDownAll(self):
         info_table.drop()
     
     def testcase(self):
-        inner = select([case([[info_table.c.pk < 3, literal('lessthan3', type=String)],
-               [info_table.c.pk >= 3, literal('gt3', type=String)]]).label('x'),
-               info_table.c.pk, info_table.c.info], from_obj=[info_table]).alias('q_inner')
+        inner = select([case([
+               [info_table.c.pk < 3, 
+                        literal('lessthan3', type=String)],
+               [and_(info_table.c.pk >= 3, info_table.c.pk < 7), 
+                        literal('gt3', type=String)]]).label('x'),
+               info_table.c.pk, info_table.c.info], 
+                from_obj=[info_table]).alias('q_inner')
 
         inner_result = inner.execute().fetchall()
 
@@ -35,12 +40,14 @@ class CaseTest(testbase.PersistTest):
         # gt3 3 pk_3_data
         # gt3 4 pk_4_data
         # gt3 5 pk_5_data
+        # gt3 6 pk_6_data
         assert inner_result == [
             ('lessthan3', 1, 'pk_1_data'),
             ('lessthan3', 2, 'pk_2_data'),
             ('gt3', 3, 'pk_3_data'),
             ('gt3', 4, 'pk_4_data'),
             ('gt3', 5, 'pk_5_data'),
+            ('gt3', 6, 'pk_6_data')
         ]
 
         outer = select([inner])
@@ -53,6 +60,27 @@ class CaseTest(testbase.PersistTest):
             ('gt3', 3, 'pk_3_data'),
             ('gt3', 4, 'pk_4_data'),
             ('gt3', 5, 'pk_5_data'),
+            ('gt3', 6, 'pk_6_data')
+        ]
+
+        w_else = select([case([
+               [info_table.c.pk < 3, 
+                        literal(3, type=Integer)],
+               [and_(info_table.c.pk >= 3, info_table.c.pk < 6), 
+                        literal(6, type=Integer)]],
+                else_ = 0).label('x'),
+               info_table.c.pk, info_table.c.info], 
+                from_obj=[info_table]).alias('q_inner')
+
+        else_result = w_else.execute().fetchall()
+
+        assert else_result == [
+            (3, 1, 'pk_1_data'),
+            (3, 2, 'pk_2_data'),
+            (6, 3, 'pk_3_data'),
+            (6, 4, 'pk_4_data'),
+            (6, 5, 'pk_5_data'),
+            (0, 6, 'pk_6_data')
         ]
 
 if __name__ == "__main__":