]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
use ExpressionElementRole for case targets in case()
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 4 Nov 2021 01:26:44 +0000 (21:26 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 4 Nov 2021 01:34:32 +0000 (21:34 -0400)
Fixed regression where the :func:`_sql.text` construct would no longer be
accepted as a target case in the "whens" list within a :func:`_sql.case`
construct. The regression appears related to an attempt to guard against
some forms of literal values that were considered to be ambiguous when
passed here; however, there's no reason the target cases shouldn't be
interpreted as open-ended SQL expressions just like anywhere else, and a
literal string or tuple will be converted to a bound parameter as would be
the case elsewhere.

Fixes: #7287
Change-Id: I75478adfa115f3292cb1362cc5b2fdf152b0ed6f

doc/build/changelog/unreleased_14/7287.rst [new file with mode: 0644]
lib/sqlalchemy/sql/elements.py
test/sql/test_case_statement.py
test/sql/test_text.py

diff --git a/doc/build/changelog/unreleased_14/7287.rst b/doc/build/changelog/unreleased_14/7287.rst
new file mode 100644 (file)
index 0000000..14c72a8
--- /dev/null
@@ -0,0 +1,12 @@
+.. change::
+    :tags: bug, sql, regression
+    :tickets: 7287
+
+    Fixed regression where the :func:`_sql.text` construct would no longer be
+    accepted as a target case in the "whens" list within a :func:`_sql.case`
+    construct. The regression appears related to an attempt to guard against
+    some forms of literal values that were considered to be ambiguous when
+    passed here; however, there's no reason the target cases shouldn't be
+    interpreted as open-ended SQL expressions just like anywhere else, and a
+    literal string or tuple will be converted to a bound parameter as would be
+    the case elsewhere.
\ No newline at end of file
index f1fe46fd2346a518cfa1bb8d007ca1e9d519e398..c8faebbd9dbbdd3ebb5bdc278ea8bd64cbca252a 100644 (file)
@@ -2943,28 +2943,18 @@ class Case(ColumnElement):
             pass
 
         value = kw.pop("value", None)
-        if value is not None:
-            whenlist = [
-                (
-                    coercions.expect(
-                        roles.ExpressionElementRole,
-                        c,
-                        apply_propagate_attrs=self,
-                    ).self_group(),
-                    coercions.expect(roles.ExpressionElementRole, r),
-                )
-                for (c, r) in whens
-            ]
-        else:
-            whenlist = [
-                (
-                    coercions.expect(
-                        roles.ColumnArgumentRole, c, apply_propagate_attrs=self
-                    ).self_group(),
-                    coercions.expect(roles.ExpressionElementRole, r),
-                )
-                for (c, r) in whens
-            ]
+
+        whenlist = [
+            (
+                coercions.expect(
+                    roles.ExpressionElementRole,
+                    c,
+                    apply_propagate_attrs=self,
+                ).self_group(),
+                coercions.expect(roles.ExpressionElementRole, r),
+            )
+            for (c, r) in whens
+        ]
 
         if whenlist:
             type_ = list(whenlist[-1])[-1].type
index 63491524c2af541a686862a0121a563c10692516..db7f16194f773ce46473b5506d17340c957c18ce 100644 (file)
@@ -2,7 +2,7 @@ from sqlalchemy import and_
 from sqlalchemy import case
 from sqlalchemy import cast
 from sqlalchemy import Column
-from sqlalchemy import exc
+from sqlalchemy import func
 from sqlalchemy import Integer
 from sqlalchemy import literal_column
 from sqlalchemy import MetaData
@@ -13,7 +13,6 @@ from sqlalchemy import testing
 from sqlalchemy import text
 from sqlalchemy.sql import column
 from sqlalchemy.sql import table
-from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
@@ -125,23 +124,62 @@ class CaseTest(fixtures.TestBase, AssertsCompiledSQL):
             ],
         )
 
-    def test_literal_interpretation_ambiguous(self):
-        assert_raises_message(
-            exc.ArgumentError,
-            r"Column expression expected, got 'x'",
-            case,
-            ("x", "y"),
+    def test_literal_interpretation_one(self):
+        """note this is modified as of #7287 to accept strings, tuples
+        and other literal values as input
+        where they are interpreted as bound values just like any other
+        expression.
+
+        Previously, an exception would be raised that the literal was
+        ambiguous.
+
+
+        """
+        self.assert_compile(
+            case(("x", "y")),
+            "CASE WHEN :param_1 THEN :param_2 END",
+            checkparams={"param_1": "x", "param_2": "y"},
         )
 
-    def test_literal_interpretation_ambiguous_tuple(self):
-        assert_raises_message(
-            exc.ArgumentError,
-            r"Column expression expected, got \('x', 'y'\)",
-            case,
-            (("x", "y"), "z"),
+    def test_literal_interpretation_two(self):
+        """note this is modified as of #7287 to accept strings, tuples
+        and other literal values as input
+        where they are interpreted as bound values just like any other
+        expression.
+
+        Previously, an exception would be raised that the literal was
+        ambiguous.
+
+
+        """
+        self.assert_compile(
+            case(
+                (("x", "y"), "z"),
+            ),
+            "CASE WHEN :param_1 THEN :param_2 END",
+            checkparams={"param_1": ("x", "y"), "param_2": "z"},
         )
 
-    def test_literal_interpretation(self):
+    def test_literal_interpretation_two_point_five(self):
+        """note this is modified as of #7287 to accept strings, tuples
+        and other literal values as input
+        where they are interpreted as bound values just like any other
+        expression.
+
+        Previously, an exception would be raised that the literal was
+        ambiguous.
+
+
+        """
+        self.assert_compile(
+            case(
+                (12, "z"),
+            ),
+            "CASE WHEN :param_1 THEN :param_2 END",
+            checkparams={"param_1": 12, "param_2": "z"},
+        )
+
+    def test_literal_interpretation_three(self):
         t = table("test", column("col1"))
 
         self.assert_compile(
@@ -220,6 +258,16 @@ class CaseTest(fixtures.TestBase, AssertsCompiledSQL):
                 [("no",), ("no",), ("no",), ("yes",), ("no",), ("no",)],
             )
 
+    def test_text_doenst_explode_even_in_whenlist(self):
+        """test #7287"""
+        self.assert_compile(
+            case(
+                (text(":case = 'upper'"), func.upper(literal_column("q"))),
+                else_=func.lower(literal_column("q")),
+            ),
+            "CASE WHEN :case = 'upper' THEN upper(q) ELSE lower(q) END",
+        )
+
     def testcase_with_dict(self):
         query = select(
             case(
index 15f6f604861bc20de04a4ddc7f382160d107140c..4fff0ed7ef39bc1f6032349d408578a3d77badae 100644 (file)
@@ -6,6 +6,7 @@ from sqlalchemy import bindparam
 from sqlalchemy import Column
 from sqlalchemy import desc
 from sqlalchemy import exc
+from sqlalchemy import extract
 from sqlalchemy import Float
 from sqlalchemy import func
 from sqlalchemy import Integer
@@ -182,6 +183,14 @@ class SelectCompositionTest(fixtures.TestBase, AssertsCompiledSQL):
             "(select f from bar where lala=heyhey) foo WHERE foo.f = t.id",
         )
 
+    def test_expression_element_role(self):
+        """test #7287"""
+
+        self.assert_compile(
+            extract("year", text("some_date + :param")),
+            "EXTRACT(year FROM some_date + :param)",
+        )
+
     @testing.combinations(
         (
             None,