]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Fix type of CASE expressions which include NULLs 10847/head
authorDavid Evans <dave@evansd.net>
Mon, 8 Jan 2024 15:06:55 +0000 (15:06 +0000)
committerDavid Evans <dave@evansd.net>
Thu, 11 Jan 2024 08:27:24 +0000 (08:27 +0000)
Fixes: #10843
doc/build/changelog/unreleased_20/10843.rst [new file with mode: 0644]
lib/sqlalchemy/sql/elements.py
test/sql/test_case_statement.py

diff --git a/doc/build/changelog/unreleased_20/10843.rst b/doc/build/changelog/unreleased_20/10843.rst
new file mode 100644 (file)
index 0000000..400d815
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: bug
+    :tickets: 10843
+
+    Enhanced case() to derive the ultimate type by scanning the given
+    expressions for the last non-NULL expression, rather than hardcoding to
+    the last WHEN case.
index 45eb8f3c55be13d47ae7550b672dff2a2b47b1b3..8f48e78ed0ffe7b13855e22b763abbd982615478 100644 (file)
@@ -3411,7 +3411,7 @@ class Case(ColumnElement[_T]):
         except TypeError:
             pass
 
-        whenlist = [
+        self.whens = [
             (
                 coercions.expect(
                     roles.ExpressionElementRole,
@@ -3423,24 +3423,28 @@ class Case(ColumnElement[_T]):
             for (c, r) in new_whens
         ]
 
-        if whenlist:
-            type_ = whenlist[-1][-1].type
-        else:
-            type_ = None
-
         if value is None:
             self.value = None
         else:
             self.value = coercions.expect(roles.ExpressionElementRole, value)
 
-        self.type = cast(_T, type_)
-        self.whens = whenlist
-
         if else_ is not None:
             self.else_ = coercions.expect(roles.ExpressionElementRole, else_)
         else:
             self.else_ = None
 
+        type_ = next(
+            (
+                then.type
+                # Iterate `whens` in reverse to match previous behaviour
+                # where type of final element took priority
+                for *_, then in reversed(self.whens)
+                if not then.type._isnull
+            ),
+            self.else_.type if self.else_ is not None else type_api.NULLTYPE,
+        )
+        self.type = cast(_T, type_)
+
     @util.ro_non_memoized_property
     def _from_objects(self) -> List[FromClause]:
         return list(
index 6907d213257afc66f72266c4d9b37faef38be041..08575f79bcb3f4d7cff4aaa37cb7c5adcd8446d3 100644 (file)
@@ -13,6 +13,7 @@ from sqlalchemy import testing
 from sqlalchemy import text
 from sqlalchemy.sql import column
 from sqlalchemy.sql import table
+from sqlalchemy.sql.sqltypes import NullType
 from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
@@ -294,3 +295,40 @@ class CaseTest(fixtures.TestBase, AssertsCompiledSQL):
             ("two", 2),
             ("other", 3),
         ]
+
+    def test_type_of_case_expression_with_all_nulls(self):
+        expr = case(
+            (info_table.c.pk < 0, None),
+            (info_table.c.pk > 9, None),
+        )
+
+        assert isinstance(expr.type, NullType)
+
+    def test_type_of_case_expression_with_all_nulls_with_else(self):
+        expr = case(
+            (info_table.c.pk < 0, None),
+            (info_table.c.pk > 9, None),
+            else_=column("q"),
+        )
+
+        assert isinstance(expr.type, NullType)
+
+    def test_type_of_case_expression_with_null_case_and_no_else_clause(self):
+        expr = case(
+            (info_table.c.pk < 0, None),
+            # This mixing of types is not legal in most DBMSs, but we want to
+            # test that types of later cases take priority over earlier ones
+            (info_table.c.pk < 5, "five"),
+            (info_table.c.pk <= 9, info_table.c.pk),
+            (info_table.c.pk > 9, None),
+        )
+
+        assert isinstance(expr.type, Integer)
+
+    def test_type_of_case_expression_with_null_case_and_else_clause(self):
+        expr = case(
+            (info_table.c.pk < 0, None),
+            else_=info_table.c.pk,
+        )
+
+        assert isinstance(expr.type, Integer)