From: Mike Bayer Date: Thu, 4 Nov 2021 01:26:44 +0000 (-0400) Subject: use ExpressionElementRole for case targets in case() X-Git-Tag: rel_2_0_0b1~670 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=77a17797ecc08736ea942e29f79df4f96bd74e0c;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git use ExpressionElementRole for case targets in case() Fixed regression where the :func:`_sql.text` construct would no longer be accepted as a target case in the "whens" list within a :func:`_sql.case` construct. The regression appears related to an attempt to guard against some forms of literal values that were considered to be ambiguous when passed here; however, there's no reason the target cases shouldn't be interpreted as open-ended SQL expressions just like anywhere else, and a literal string or tuple will be converted to a bound parameter as would be the case elsewhere. Fixes: #7287 Change-Id: I75478adfa115f3292cb1362cc5b2fdf152b0ed6f --- diff --git a/doc/build/changelog/unreleased_14/7287.rst b/doc/build/changelog/unreleased_14/7287.rst new file mode 100644 index 0000000000..14c72a8aff --- /dev/null +++ b/doc/build/changelog/unreleased_14/7287.rst @@ -0,0 +1,12 @@ +.. change:: + :tags: bug, sql, regression + :tickets: 7287 + + Fixed regression where the :func:`_sql.text` construct would no longer be + accepted as a target case in the "whens" list within a :func:`_sql.case` + construct. The regression appears related to an attempt to guard against + some forms of literal values that were considered to be ambiguous when + passed here; however, there's no reason the target cases shouldn't be + interpreted as open-ended SQL expressions just like anywhere else, and a + literal string or tuple will be converted to a bound parameter as would be + the case elsewhere. \ No newline at end of file diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index f1fe46fd23..c8faebbd9d 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -2943,28 +2943,18 @@ class Case(ColumnElement): pass value = kw.pop("value", None) - if value is not None: - whenlist = [ - ( - coercions.expect( - roles.ExpressionElementRole, - c, - apply_propagate_attrs=self, - ).self_group(), - coercions.expect(roles.ExpressionElementRole, r), - ) - for (c, r) in whens - ] - else: - whenlist = [ - ( - coercions.expect( - roles.ColumnArgumentRole, c, apply_propagate_attrs=self - ).self_group(), - coercions.expect(roles.ExpressionElementRole, r), - ) - for (c, r) in whens - ] + + whenlist = [ + ( + coercions.expect( + roles.ExpressionElementRole, + c, + apply_propagate_attrs=self, + ).self_group(), + coercions.expect(roles.ExpressionElementRole, r), + ) + for (c, r) in whens + ] if whenlist: type_ = list(whenlist[-1])[-1].type diff --git a/test/sql/test_case_statement.py b/test/sql/test_case_statement.py index 63491524c2..db7f16194f 100644 --- a/test/sql/test_case_statement.py +++ b/test/sql/test_case_statement.py @@ -2,7 +2,7 @@ from sqlalchemy import and_ from sqlalchemy import case from sqlalchemy import cast from sqlalchemy import Column -from sqlalchemy import exc +from sqlalchemy import func from sqlalchemy import Integer from sqlalchemy import literal_column from sqlalchemy import MetaData @@ -13,7 +13,6 @@ from sqlalchemy import testing from sqlalchemy import text from sqlalchemy.sql import column from sqlalchemy.sql import table -from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures @@ -125,23 +124,62 @@ class CaseTest(fixtures.TestBase, AssertsCompiledSQL): ], ) - def test_literal_interpretation_ambiguous(self): - assert_raises_message( - exc.ArgumentError, - r"Column expression expected, got 'x'", - case, - ("x", "y"), + def test_literal_interpretation_one(self): + """note this is modified as of #7287 to accept strings, tuples + and other literal values as input + where they are interpreted as bound values just like any other + expression. + + Previously, an exception would be raised that the literal was + ambiguous. + + + """ + self.assert_compile( + case(("x", "y")), + "CASE WHEN :param_1 THEN :param_2 END", + checkparams={"param_1": "x", "param_2": "y"}, ) - def test_literal_interpretation_ambiguous_tuple(self): - assert_raises_message( - exc.ArgumentError, - r"Column expression expected, got \('x', 'y'\)", - case, - (("x", "y"), "z"), + def test_literal_interpretation_two(self): + """note this is modified as of #7287 to accept strings, tuples + and other literal values as input + where they are interpreted as bound values just like any other + expression. + + Previously, an exception would be raised that the literal was + ambiguous. + + + """ + self.assert_compile( + case( + (("x", "y"), "z"), + ), + "CASE WHEN :param_1 THEN :param_2 END", + checkparams={"param_1": ("x", "y"), "param_2": "z"}, ) - def test_literal_interpretation(self): + def test_literal_interpretation_two_point_five(self): + """note this is modified as of #7287 to accept strings, tuples + and other literal values as input + where they are interpreted as bound values just like any other + expression. + + Previously, an exception would be raised that the literal was + ambiguous. + + + """ + self.assert_compile( + case( + (12, "z"), + ), + "CASE WHEN :param_1 THEN :param_2 END", + checkparams={"param_1": 12, "param_2": "z"}, + ) + + def test_literal_interpretation_three(self): t = table("test", column("col1")) self.assert_compile( @@ -220,6 +258,16 @@ class CaseTest(fixtures.TestBase, AssertsCompiledSQL): [("no",), ("no",), ("no",), ("yes",), ("no",), ("no",)], ) + def test_text_doenst_explode_even_in_whenlist(self): + """test #7287""" + self.assert_compile( + case( + (text(":case = 'upper'"), func.upper(literal_column("q"))), + else_=func.lower(literal_column("q")), + ), + "CASE WHEN :case = 'upper' THEN upper(q) ELSE lower(q) END", + ) + def testcase_with_dict(self): query = select( case( diff --git a/test/sql/test_text.py b/test/sql/test_text.py index 15f6f60486..4fff0ed7ef 100644 --- a/test/sql/test_text.py +++ b/test/sql/test_text.py @@ -6,6 +6,7 @@ from sqlalchemy import bindparam from sqlalchemy import Column from sqlalchemy import desc from sqlalchemy import exc +from sqlalchemy import extract from sqlalchemy import Float from sqlalchemy import func from sqlalchemy import Integer @@ -182,6 +183,14 @@ class SelectCompositionTest(fixtures.TestBase, AssertsCompiledSQL): "(select f from bar where lala=heyhey) foo WHERE foo.f = t.id", ) + def test_expression_element_role(self): + """test #7287""" + + self.assert_compile( + extract("year", text("some_date + :param")), + "EXTRACT(year FROM some_date + :param)", + ) + @testing.combinations( ( None,