]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Correct for coercion from list args to positional for case
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 19 Mar 2021 14:34:31 +0000 (10:34 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 19 Mar 2021 14:34:31 +0000 (10:34 -0400)
Fixed regression in the :func:`_sql.case` construct, where the "dictionary"
form of argument specification failed to work correctly if it were passed
positionally, rather than as a "whens" keyword argument.

Fixes: #6097
Change-Id: I4138f54309a08c8e4e63cfafc211176e0b9a76c7

doc/build/changelog/unreleased_14/6097.rst [new file with mode: 0644]
lib/sqlalchemy/sql/coercions.py
lib/sqlalchemy/sql/elements.py
test/sql/test_case_statement.py

diff --git a/doc/build/changelog/unreleased_14/6097.rst b/doc/build/changelog/unreleased_14/6097.rst
new file mode 100644 (file)
index 0000000..ade5286
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: bug, sql, regression
+    :tickets: 6097
+
+    Fixed regression in the :func:`_sql.case` construct, where the "dictionary"
+    form of argument specification failed to work correctly if it were passed
+    positionally, rather than as a "whens" keyword argument.
index 76ba7e2146933c5c9639a112774e49979becdc24..35ac1a5ba117512b0d95e60f825985b1f47cf2a9 100644 (file)
@@ -98,7 +98,7 @@ def _document_text_coercion(paramname, meth_rst, param_rst):
 
 
 def _expression_collection_was_a_list(attrname, fnname, args):
-    if args and isinstance(args[0], (list, set)) and len(args) == 1:
+    if args and isinstance(args[0], (list, set, dict)) and len(args) == 1:
         util.warn_deprecated_20(
             'The "%s" argument to %s() is now passed as a series of '
             "positional "
index 26c03b57bcf29dbd881b920f948daf1d57ecc311..b3b0413856d725b09ac518078693c9088de59231 100644 (file)
@@ -2618,7 +2618,7 @@ class Case(ColumnElement):
         acts somewhat analogously to an "if/then" construct in other
         languages.  It returns an instance of :class:`.Case`.
 
-        :func:`.case` in its usual form is passed a list of "when"
+        :func:`.case` in its usual form is passed a series of "when"
         constructs, that is, a list of conditions and results as tuples::
 
             from sqlalchemy import case
@@ -2653,7 +2653,7 @@ class Case(ColumnElement):
             stmt = select(users_table).\
                         where(
                             case(
-                                {"wendy": "W", "jack": "J"},
+                                whens={"wendy": "W", "jack": "J"},
                                 value=users_table.c.name,
                                 else_='E'
                             )
index b44971cecd24b89efe046c59b1300d13a2dabaf9..7dd66840f5d122cdd7a1874e7df0917b1b3b754c 100644 (file)
@@ -154,6 +154,53 @@ class CaseTest(fixtures.TestBase, AssertsCompiledSQL):
             "CASE WHEN (test.col1 = :col1_1) THEN :param_1 ELSE :param_2 END",
         )
 
+    @testing.combinations(
+        (
+            (lambda t: ({"x": "y"}, t.c.col1, None)),
+            "CASE test.col1 WHEN :param_1 THEN :param_2 END",
+        ),
+        (
+            (lambda t: ({"x": "y", "p": "q"}, t.c.col1, None)),
+            "CASE test.col1 WHEN :param_1 THEN :param_2 "
+            "WHEN :param_3 THEN :param_4 END",
+        ),
+        (
+            (lambda t: ({t.c.col1 == 7: "x"}, None, 10)),
+            "CASE WHEN (test.col1 = :col1_1) THEN :param_1 ELSE :param_2 END",
+        ),
+        (
+            (lambda t: ({t.c.col1 == 7: "x", t.c.col1 == 10: "y"}, None, 10)),
+            "CASE WHEN (test.col1 = :col1_1) THEN :param_1 "
+            "WHEN (test.col1 = :col1_2) THEN :param_2 ELSE :param_3 END",
+        ),
+        argnames="test_case, expected",
+    )
+    @testing.combinations(("positional",), ("kwarg",), argnames="argstyle")
+    def test_when_dicts(self, argstyle, test_case, expected):
+        t = table("test", column("col1"))
+
+        whens, value, else_ = testing.resolve_lambda(test_case, t=t)
+
+        def _case_args(whens, value=None, else_=None):
+            kw = {}
+            if value is not None:
+                kw["value"] = value
+            if else_ is not None:
+                kw["else_"] = else_
+
+            if argstyle == "kwarg":
+                return case(whens=whens, **kw)
+            elif argstyle == "positional":
+                return case(whens, **kw)
+
+            # note: 1.3 also does not allow this form
+            # case([whens], **kw)
+
+        self.assert_compile(
+            _case_args(whens=whens, value=value, else_=else_),
+            expected,
+        )
+
     def test_text_doesnt_explode(self, connection):
 
         for s in [