]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Fix type of CASE expressions which include NULLs
authorDavid Evans <dave@evansd.net>
Mon, 15 Jan 2024 15:13:53 +0000 (10:13 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 15 Jan 2024 17:08:10 +0000 (12:08 -0500)
Fixed issues in :func:`_sql.case` where the logic for determining the
type of the expression could result in :class:`.NullType` if the last
element in the "whens" had no type, or in other cases where the type
could resolve to ``None``.  The logic has been updated to scan all
given expressions so that the first non-null type is used, as well as
to always ensure a type is present.  Pull request courtesy David Evans.

updates to test suite to use modern fixture patterns by Mike

Fixes: #10843
Closes: #10847
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/10847
Pull-request-sha: 4fd5c39ab56de046e68c08f6c20cd1f7b2cb0e0d

Change-Id: I40f905ac336a8a42b617ff9473dbd9c22ac57505

doc/build/changelog/unreleased_20/10843.rst [new file with mode: 0644]
lib/sqlalchemy/sql/elements.py
test/sql/test_case_statement.py

diff --git a/doc/build/changelog/unreleased_20/10843.rst b/doc/build/changelog/unreleased_20/10843.rst
new file mode 100644 (file)
index 0000000..838f6a8
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 10843
+
+    Fixed issues in :func:`_sql.case` where the logic for determining the
+    type of the expression could result in :class:`.NullType` if the last
+    element in the "whens" had no type, or in other cases where the type
+    could resolve to ``None``.  The logic has been updated to scan all
+    given expressions so that the first non-null type is used, as well as
+    to always ensure a type is present.  Pull request courtesy David Evans.
index 45eb8f3c55be13d47ae7550b672dff2a2b47b1b3..8f48e78ed0ffe7b13855e22b763abbd982615478 100644 (file)
@@ -3411,7 +3411,7 @@ class Case(ColumnElement[_T]):
         except TypeError:
             pass
 
-        whenlist = [
+        self.whens = [
             (
                 coercions.expect(
                     roles.ExpressionElementRole,
@@ -3423,24 +3423,28 @@ class Case(ColumnElement[_T]):
             for (c, r) in new_whens
         ]
 
-        if whenlist:
-            type_ = whenlist[-1][-1].type
-        else:
-            type_ = None
-
         if value is None:
             self.value = None
         else:
             self.value = coercions.expect(roles.ExpressionElementRole, value)
 
-        self.type = cast(_T, type_)
-        self.whens = whenlist
-
         if else_ is not None:
             self.else_ = coercions.expect(roles.ExpressionElementRole, else_)
         else:
             self.else_ = None
 
+        type_ = next(
+            (
+                then.type
+                # Iterate `whens` in reverse to match previous behaviour
+                # where type of final element took priority
+                for *_, then in reversed(self.whens)
+                if not then.type._isnull
+            ),
+            self.else_.type if self.else_ is not None else type_api.NULLTYPE,
+        )
+        self.type = cast(_T, type_)
+
     @util.ro_non_memoized_property
     def _from_objects(self) -> List[FromClause]:
         return list(
index 6907d213257afc66f72266c4d9b37faef38be041..5e95d3cb2f71d09dfaaa964234b6cf1ceebeda8f 100644 (file)
@@ -5,7 +5,6 @@ from sqlalchemy import Column
 from sqlalchemy import func
 from sqlalchemy import Integer
 from sqlalchemy import literal_column
-from sqlalchemy import MetaData
 from sqlalchemy import select
 from sqlalchemy import String
 from sqlalchemy import Table
@@ -13,50 +12,48 @@ from sqlalchemy import testing
 from sqlalchemy import text
 from sqlalchemy.sql import column
 from sqlalchemy.sql import table
+from sqlalchemy.sql.sqltypes import NullType
 from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
 
 
-info_table = None
-
-
-class CaseTest(fixtures.TestBase, AssertsCompiledSQL):
+class CaseTest(fixtures.TablesTest, AssertsCompiledSQL):
     __dialect__ = "default"
 
+    run_inserts = "once"
+    run_deletes = "never"
+
     @classmethod
-    def setup_test_class(cls):
-        metadata = MetaData()
-        global info_table
-        info_table = Table(
-            "infos",
+    def define_tables(cls, metadata):
+        Table(
+            "info_table",
             metadata,
             Column("pk", Integer, primary_key=True),
             Column("info", String(30)),
         )
 
-        with testing.db.begin() as conn:
-            info_table.create(conn)
-
-            conn.execute(
-                info_table.insert(),
-                [
-                    {"pk": 1, "info": "pk_1_data"},
-                    {"pk": 2, "info": "pk_2_data"},
-                    {"pk": 3, "info": "pk_3_data"},
-                    {"pk": 4, "info": "pk_4_data"},
-                    {"pk": 5, "info": "pk_5_data"},
-                    {"pk": 6, "info": "pk_6_data"},
-                ],
-            )
-
     @classmethod
-    def teardown_test_class(cls):
-        with testing.db.begin() as conn:
-            info_table.drop(conn)
+    def insert_data(cls, connection):
+        info_table = cls.tables.info_table
+
+        connection.execute(
+            info_table.insert(),
+            [
+                {"pk": 1, "info": "pk_1_data"},
+                {"pk": 2, "info": "pk_2_data"},
+                {"pk": 3, "info": "pk_3_data"},
+                {"pk": 4, "info": "pk_4_data"},
+                {"pk": 5, "info": "pk_5_data"},
+                {"pk": 6, "info": "pk_6_data"},
+            ],
+        )
+        connection.commit()
 
     @testing.requires.subqueries
     def test_case(self, connection):
+        info_table = self.tables.info_table
+
         inner = select(
             case(
                 (info_table.c.pk < 3, "lessthan3"),
@@ -222,6 +219,8 @@ class CaseTest(fixtures.TestBase, AssertsCompiledSQL):
         )
 
     def test_text_doesnt_explode(self, connection):
+        info_table = self.tables.info_table
+
         for s in [
             select(
                 case(
@@ -255,6 +254,8 @@ class CaseTest(fixtures.TestBase, AssertsCompiledSQL):
         )
 
     def testcase_with_dict(self):
+        info_table = self.tables.info_table
+
         query = select(
             case(
                 {
@@ -294,3 +295,61 @@ class CaseTest(fixtures.TestBase, AssertsCompiledSQL):
             ("two", 2),
             ("other", 3),
         ]
+
+    @testing.variation("add_else", [True, False])
+    def test_type_of_case_expression_with_all_nulls(self, add_else):
+        info_table = self.tables.info_table
+
+        expr = case(
+            (info_table.c.pk < 0, None),
+            (info_table.c.pk > 9, None),
+            else_=column("q") if add_else else None,
+        )
+
+        assert isinstance(expr.type, NullType)
+
+    @testing.combinations(
+        lambda info_table: (
+            [
+                # test non-None in middle of WHENS takes precedence over Nones
+                (info_table.c.pk < 0, None),
+                (info_table.c.pk < 5, "five"),
+                (info_table.c.pk <= 9, info_table.c.pk),
+                (info_table.c.pk > 9, None),
+            ],
+            None,
+        ),
+        lambda info_table: (
+            # test non-None ELSE takes precedence over WHENs that are None
+            [(info_table.c.pk < 0, None)],
+            info_table.c.pk,
+        ),
+        lambda info_table: (
+            # test non-None WHEN takes precedence over non-None ELSE
+            [
+                (info_table.c.pk < 0, None),
+                (info_table.c.pk <= 9, info_table.c.pk),
+                (info_table.c.pk > 9, None),
+            ],
+            column("q", String),
+        ),
+        lambda info_table: (
+            # test last WHEN in list takes precedence
+            [
+                (info_table.c.pk < 0, String),
+                (info_table.c.pk > 9, None),
+                (info_table.c.pk <= 9, info_table.c.pk),
+            ],
+            column("q", String),
+        ),
+    )
+    def test_type_of_case_expression(self, when_lambda):
+        info_table = self.tables.info_table
+
+        whens, else_ = testing.resolve_lambda(
+            when_lambda, info_table=info_table
+        )
+
+        expr = case(*whens, else_=else_)
+
+        assert isinstance(expr.type, Integer)