From abb10856dcea07ca4d38d28df4e493d11d8fd345 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Thu, 3 Apr 2008 16:34:03 +0000 Subject: [PATCH] - case() interprets the "THEN" expressions as values by default, meaning case([(x==y, "foo")]) will interpret "foo" as a bound value, not a SQL expression. use text(expr) for literal SQL expressions in this case. For the criterion itself, these may be literal strings only if the "value" keyword is present, otherwise SA will force explicit usage of either text() or literal(). --- CHANGES | 9 ++++-- lib/sqlalchemy/sql/expression.py | 47 ++++++++++++++++++++++++++------ test/sql/case_statement.py | 35 ++++++++++++++++-------- 3 files changed, 68 insertions(+), 23 deletions(-) diff --git a/CHANGES b/CHANGES index 53eb7683ee..a83db182c1 100644 --- a/CHANGES +++ b/CHANGES @@ -196,8 +196,13 @@ CHANGES symptom. - The case() function now also takes a dictionary as its whens - parameter. But beware that it doesn't escape literals, use - the literal construct for that. + parameter. It also interprets the "THEN" expressions + as values by default, meaning case([(x==y, "foo")]) will + interpret "foo" as a bound value, not a SQL expression. + use text(expr) for literal SQL expressions in this case. + For the criterion itself, these may be literal strings + only if the "value" keyword is present, otherwise SA + will force explicit usage of either text() or literal(). - declarative extension - The "synonym" function is now directly usable with diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index cc97227a70..39a2ae3eb9 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -392,7 +392,7 @@ def not_(clause): result. """ - return operators.inv(clause) + return operators.inv(_literal_as_binds(clause)) def distinct(expr): """Return a ``DISTINCT`` clause.""" @@ -416,24 +416,45 @@ def case(whens, value=None, else_=None): """Produce a ``CASE`` statement. whens - A sequence of pairs or a dict to be translated into "when / then" clauses. + A sequence of pairs, or alternatively a dict, + to be translated into "WHEN / THEN" clauses. value - Optional for simple case statements. + Optional for simple case statements, produces + a column expression as in "CASE WHEN ..." else\_ - Optional as well, for case defaults. + Optional as well, for case defaults produces + the "ELSE" portion of the "CASE" statement. + + The expressions used for THEN and ELSE, + when specified as strings, will be interpreted + as bound values. To specify textual SQL expressions + for these, use the text() construct. + + The expressions used for the WHEN criterion + may only be literal strings when "value" is + present, i.e. CASE table.somecol WHEN "x" THEN "y". + Otherwise, literal strings are not accepted + in this position, and either the text() + or literal() constructs must be used to + interpret raw string values. + """ - try: whens = util.dictlike_iteritems(whens) except TypeError: pass - - whenlist = [ClauseList('WHEN', c, 'THEN', r, operator=None) + + if value: + crit_filter = _literal_as_binds + else: + crit_filter = _no_literals + + whenlist = [ClauseList('WHEN', crit_filter(c), 'THEN', _literal_as_binds(r), operator=None) for (c,r) in whens] - if not else_ is None: - whenlist.append(ClauseList('ELSE', else_, operator=None)) + if else_ is not None: + whenlist.append(ClauseList('ELSE', _literal_as_binds(else_), operator=None)) if whenlist: type = list(whenlist[-1])[-1].type else: @@ -842,6 +863,14 @@ def _literal_as_binds(element, name=None, type_=None): else: return element +def _no_literals(element): + if isinstance(element, Operators): + return element.expression_element() + elif _is_literal(element): + raise exceptions.ArgumentError("Ambiguous literal: %r. Use the 'text()' function to indicate a SQL expression literal, or 'literal()' to indicate a bound value." % element) + else: + return element + def _corresponding_column_or_error(fromclause, column, require_embedded=False): c = fromclause.corresponding_column(column, require_embedded=require_embedded) if not c: diff --git a/test/sql/case_statement.py b/test/sql/case_statement.py index 730517b210..257298c8e5 100644 --- a/test/sql/case_statement.py +++ b/test/sql/case_statement.py @@ -2,10 +2,11 @@ import testenv; testenv.configure_for_tests() import sys from sqlalchemy import * from testlib import * -from sqlalchemy import util +from sqlalchemy import util, exceptions +from sqlalchemy.sql import table, column -class CaseTest(TestBase): +class CaseTest(TestBase, AssertsCompiledSQL): def setUpAll(self): metadata = MetaData(testing.db) @@ -30,9 +31,9 @@ class CaseTest(TestBase): def testcase(self): inner = select([case([ [info_table.c.pk < 3, - literal('lessthan3', type_=String)], + 'lessthan3'], [and_(info_table.c.pk >= 3, info_table.c.pk < 7), - literal('gt3', type_=String)]]).label('x'), + 'gt3']]).label('x'), info_table.c.pk, info_table.c.info], from_obj=[info_table]).alias('q_inner') @@ -69,9 +70,9 @@ class CaseTest(TestBase): w_else = select([case([ [info_table.c.pk < 3, - literal(3, type_=Integer)], + 3], [and_(info_table.c.pk >= 3, info_table.c.pk < 6), - literal(6, type_=Integer)]], + 6]], else_ = 0).label('x'), info_table.c.pk, info_table.c.info], from_obj=[info_table]).alias('q_inner') @@ -87,12 +88,21 @@ class CaseTest(TestBase): (0, 6, 'pk_6_data') ] + def test_literal_interpretation(self): + t = table('test', column('col1')) + + self.assertRaises(exceptions.ArgumentError, case, [("x", "y")]) + + self.assert_compile(case([("x", "y")], value=t.c.col1), "CASE test.col1 WHEN :param_1 THEN :param_2 END") + self.assert_compile(case([(t.c.col1==7, "y")], else_="z"), "CASE WHEN (test.col1 = :test_col1_1) THEN :param_1 ELSE :param_2 END") + + @testing.fails_on('maxdb') def testcase_with_dict(self): query = select([case({ - info_table.c.pk < 3: literal('lessthan3'), - info_table.c.pk >= 3: literal('gt3'), - }, else_=literal('other')), + info_table.c.pk < 3: 'lessthan3', + info_table.c.pk >= 3: 'gt3', + }, else_='other'), info_table.c.pk, info_table.c.info ], from_obj=[info_table]) @@ -106,13 +116,14 @@ class CaseTest(TestBase): ] simple_query = select([case({ - 1: literal('one'), - 2: literal('two'), - }, value=info_table.c.pk, else_=literal('other')), + 1: 'one', + 2: 'two', + }, value=info_table.c.pk, else_='other'), info_table.c.pk ], whereclause=info_table.c.pk < 4, from_obj=[info_table]) + assert simple_query.execute().fetchall() == [ ('one', 1), ('two', 2), -- 2.47.3