]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
The case() function now also takes a dictionary as its whens parameter. But beware...
authorAnts Aasma <ants.aasma@gmail.com>
Thu, 3 Apr 2008 14:08:22 +0000 (14:08 +0000)
committerAnts Aasma <ants.aasma@gmail.com>
Thu, 3 Apr 2008 14:08:22 +0000 (14:08 +0000)
CHANGES
lib/sqlalchemy/sql/expression.py
test/sql/case_statement.py

diff --git a/CHANGES b/CHANGES
index 2f30c883cf7cd02e5acd718d547d7f19e390703b..53eb7683ee8b1877a2381cfa581ddd879121832a 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -195,6 +195,10 @@ CHANGES
       queries with lots of eager loads might have seen this
       symptom.
 
+    - The case() function now also takes a dictionary as its whens
+      parameter. But beware that it doesn't escape literals, use
+      the literal construct for that.
+
 - declarative extension
     - The "synonym" function is now directly usable with
       "declarative".  Pass in the decorated property using the
index c487ee173d01df4a621a72d408bb001398a84f39..cc97227a702832cf1a6b9138ae332bfd4641dfc4 100644 (file)
@@ -416,7 +416,7 @@ def case(whens, value=None, else_=None):
     """Produce a ``CASE`` statement.
 
     whens
-      A sequence of pairs to be translated into "when / then" clauses.
+      A sequence of pairs or a dict to be translated into "when / then" clauses.
 
     value
       Optional for simple case statements.
@@ -425,6 +425,11 @@ def case(whens, value=None, else_=None):
       Optional as well, for case defaults.
     """
 
+    try:
+        whens = util.dictlike_iteritems(whens)
+    except TypeError:
+        pass
+
     whenlist = [ClauseList('WHEN', c, 'THEN', r, operator=None)
                 for (c,r) in whens]
     if not else_ is None:
index ab68b52109f4f91ad521c3a9d88aae4f2037b12d..730517b21080db4a16ba2631ebc2588e125b3e2c 100644 (file)
@@ -2,6 +2,7 @@ import testenv; testenv.configure_for_tests()
 import sys
 from sqlalchemy import *
 from testlib import *
+from sqlalchemy import util
 
 
 class CaseTest(TestBase):
@@ -86,5 +87,37 @@ class CaseTest(TestBase):
             (0, 6, 'pk_6_data')
         ]
 
+    @testing.fails_on('maxdb')
+    def testcase_with_dict(self):
+        query = select([case({
+                    info_table.c.pk < 3: literal('lessthan3'),
+                    info_table.c.pk >= 3: literal('gt3'),
+                }, else_=literal('other')),
+                info_table.c.pk, info_table.c.info
+            ],
+            from_obj=[info_table])
+        assert query.execute().fetchall() == [
+            ('lessthan3', 1, 'pk_1_data'),
+            ('lessthan3', 2, 'pk_2_data'),
+            ('gt3', 3, 'pk_3_data'),
+            ('gt3', 4, 'pk_4_data'),
+            ('gt3', 5, 'pk_5_data'),
+            ('gt3', 6, 'pk_6_data')
+        ]
+
+        simple_query = select([case({
+                    1: literal('one'),
+                    2: literal('two'),
+                }, value=info_table.c.pk, else_=literal('other')),
+                info_table.c.pk
+            ],
+            whereclause=info_table.c.pk < 4,
+            from_obj=[info_table])
+        assert simple_query.execute().fetchall() == [
+            ('one', 1),
+            ('two', 2),
+            ('other', 3),
+        ]
+
 if __name__ == "__main__":
     testenv.main()