From: Ants Aasma Date: Thu, 3 Apr 2008 14:08:22 +0000 (+0000) Subject: The case() function now also takes a dictionary as its whens parameter. But beware... X-Git-Tag: rel_0_4_5~9 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=921efb250c284ee121e9fc9d0f12eb2612047645;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git The case() function now also takes a dictionary as its whens parameter. But beware that it doesn't escape literals, use the literal construct for that. --- diff --git a/CHANGES b/CHANGES index 2f30c883cf..53eb7683ee 100644 --- a/CHANGES +++ b/CHANGES @@ -195,6 +195,10 @@ CHANGES queries with lots of eager loads might have seen this symptom. + - The case() function now also takes a dictionary as its whens + parameter. But beware that it doesn't escape literals, use + the literal construct for that. + - declarative extension - The "synonym" function is now directly usable with "declarative". Pass in the decorated property using the diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index c487ee173d..cc97227a70 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -416,7 +416,7 @@ def case(whens, value=None, else_=None): """Produce a ``CASE`` statement. whens - A sequence of pairs to be translated into "when / then" clauses. + A sequence of pairs or a dict to be translated into "when / then" clauses. value Optional for simple case statements. @@ -425,6 +425,11 @@ def case(whens, value=None, else_=None): Optional as well, for case defaults. """ + try: + whens = util.dictlike_iteritems(whens) + except TypeError: + pass + whenlist = [ClauseList('WHEN', c, 'THEN', r, operator=None) for (c,r) in whens] if not else_ is None: diff --git a/test/sql/case_statement.py b/test/sql/case_statement.py index ab68b52109..730517b210 100644 --- a/test/sql/case_statement.py +++ b/test/sql/case_statement.py @@ -2,6 +2,7 @@ import testenv; testenv.configure_for_tests() import sys from sqlalchemy import * from testlib import * +from sqlalchemy import util class CaseTest(TestBase): @@ -86,5 +87,37 @@ class CaseTest(TestBase): (0, 6, 'pk_6_data') ] + @testing.fails_on('maxdb') + def testcase_with_dict(self): + query = select([case({ + info_table.c.pk < 3: literal('lessthan3'), + info_table.c.pk >= 3: literal('gt3'), + }, else_=literal('other')), + info_table.c.pk, info_table.c.info + ], + from_obj=[info_table]) + assert query.execute().fetchall() == [ + ('lessthan3', 1, 'pk_1_data'), + ('lessthan3', 2, 'pk_2_data'), + ('gt3', 3, 'pk_3_data'), + ('gt3', 4, 'pk_4_data'), + ('gt3', 5, 'pk_5_data'), + ('gt3', 6, 'pk_6_data') + ] + + simple_query = select([case({ + 1: literal('one'), + 2: literal('two'), + }, value=info_table.c.pk, else_=literal('other')), + info_table.c.pk + ], + whereclause=info_table.c.pk < 4, + from_obj=[info_table]) + assert simple_query.execute().fetchall() == [ + ('one', 1), + ('two', 2), + ('other', 3), + ] + if __name__ == "__main__": testenv.main()