From: Mike Bayer Date: Fri, 19 Mar 2021 14:34:31 +0000 (-0400) Subject: Correct for coercion from list args to positional for case X-Git-Tag: rel_1_4_2~6 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=ed515f2ca16e1b40efe5ee0299417f8d6eb51b86;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Correct for coercion from list args to positional for case Fixed regression in the :func:`_sql.case` construct, where the "dictionary" form of argument specification failed to work correctly if it were passed positionally, rather than as a "whens" keyword argument. Fixes: #6097 Change-Id: I4138f54309a08c8e4e63cfafc211176e0b9a76c7 --- diff --git a/doc/build/changelog/unreleased_14/6097.rst b/doc/build/changelog/unreleased_14/6097.rst new file mode 100644 index 0000000000..ade52860d2 --- /dev/null +++ b/doc/build/changelog/unreleased_14/6097.rst @@ -0,0 +1,7 @@ +.. change:: + :tags: bug, sql, regression + :tickets: 6097 + + Fixed regression in the :func:`_sql.case` construct, where the "dictionary" + form of argument specification failed to work correctly if it were passed + positionally, rather than as a "whens" keyword argument. diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index 76ba7e2146..35ac1a5ba1 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -98,7 +98,7 @@ def _document_text_coercion(paramname, meth_rst, param_rst): def _expression_collection_was_a_list(attrname, fnname, args): - if args and isinstance(args[0], (list, set)) and len(args) == 1: + if args and isinstance(args[0], (list, set, dict)) and len(args) == 1: util.warn_deprecated_20( 'The "%s" argument to %s() is now passed as a series of ' "positional " diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 26c03b57bc..b3b0413856 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -2618,7 +2618,7 @@ class Case(ColumnElement): acts somewhat analogously to an "if/then" construct in other languages. It returns an instance of :class:`.Case`. - :func:`.case` in its usual form is passed a list of "when" + :func:`.case` in its usual form is passed a series of "when" constructs, that is, a list of conditions and results as tuples:: from sqlalchemy import case @@ -2653,7 +2653,7 @@ class Case(ColumnElement): stmt = select(users_table).\ where( case( - {"wendy": "W", "jack": "J"}, + whens={"wendy": "W", "jack": "J"}, value=users_table.c.name, else_='E' ) diff --git a/test/sql/test_case_statement.py b/test/sql/test_case_statement.py index b44971cecd..7dd66840f5 100644 --- a/test/sql/test_case_statement.py +++ b/test/sql/test_case_statement.py @@ -154,6 +154,53 @@ class CaseTest(fixtures.TestBase, AssertsCompiledSQL): "CASE WHEN (test.col1 = :col1_1) THEN :param_1 ELSE :param_2 END", ) + @testing.combinations( + ( + (lambda t: ({"x": "y"}, t.c.col1, None)), + "CASE test.col1 WHEN :param_1 THEN :param_2 END", + ), + ( + (lambda t: ({"x": "y", "p": "q"}, t.c.col1, None)), + "CASE test.col1 WHEN :param_1 THEN :param_2 " + "WHEN :param_3 THEN :param_4 END", + ), + ( + (lambda t: ({t.c.col1 == 7: "x"}, None, 10)), + "CASE WHEN (test.col1 = :col1_1) THEN :param_1 ELSE :param_2 END", + ), + ( + (lambda t: ({t.c.col1 == 7: "x", t.c.col1 == 10: "y"}, None, 10)), + "CASE WHEN (test.col1 = :col1_1) THEN :param_1 " + "WHEN (test.col1 = :col1_2) THEN :param_2 ELSE :param_3 END", + ), + argnames="test_case, expected", + ) + @testing.combinations(("positional",), ("kwarg",), argnames="argstyle") + def test_when_dicts(self, argstyle, test_case, expected): + t = table("test", column("col1")) + + whens, value, else_ = testing.resolve_lambda(test_case, t=t) + + def _case_args(whens, value=None, else_=None): + kw = {} + if value is not None: + kw["value"] = value + if else_ is not None: + kw["else_"] = else_ + + if argstyle == "kwarg": + return case(whens=whens, **kw) + elif argstyle == "positional": + return case(whens, **kw) + + # note: 1.3 also does not allow this form + # case([whens], **kw) + + self.assert_compile( + _case_args(whens=whens, value=value, else_=else_), + expected, + ) + def test_text_doesnt_explode(self, connection): for s in [