From 4fd5c39ab56de046e68c08f6c20cd1f7b2cb0e0d Mon Sep 17 00:00:00 2001 From: David Evans Date: Mon, 8 Jan 2024 15:06:55 +0000 Subject: [PATCH] Fix type of CASE expressions which include NULLs Fixes: #10843 --- doc/build/changelog/unreleased_20/10843.rst | 7 ++++ lib/sqlalchemy/sql/elements.py | 22 +++++++----- test/sql/test_case_statement.py | 38 +++++++++++++++++++++ 3 files changed, 58 insertions(+), 9 deletions(-) create mode 100644 doc/build/changelog/unreleased_20/10843.rst diff --git a/doc/build/changelog/unreleased_20/10843.rst b/doc/build/changelog/unreleased_20/10843.rst new file mode 100644 index 0000000000..400d8157f4 --- /dev/null +++ b/doc/build/changelog/unreleased_20/10843.rst @@ -0,0 +1,7 @@ +.. change:: + :tags: bug + :tickets: 10843 + + Enhanced case() to derive the ultimate type by scanning the given + expressions for the last non-NULL expression, rather than hardcoding to + the last WHEN case. diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 45eb8f3c55..8f48e78ed0 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -3411,7 +3411,7 @@ class Case(ColumnElement[_T]): except TypeError: pass - whenlist = [ + self.whens = [ ( coercions.expect( roles.ExpressionElementRole, @@ -3423,24 +3423,28 @@ class Case(ColumnElement[_T]): for (c, r) in new_whens ] - if whenlist: - type_ = whenlist[-1][-1].type - else: - type_ = None - if value is None: self.value = None else: self.value = coercions.expect(roles.ExpressionElementRole, value) - self.type = cast(_T, type_) - self.whens = whenlist - if else_ is not None: self.else_ = coercions.expect(roles.ExpressionElementRole, else_) else: self.else_ = None + type_ = next( + ( + then.type + # Iterate `whens` in reverse to match previous behaviour + # where type of final element took priority + for *_, then in reversed(self.whens) + if not then.type._isnull + ), + self.else_.type if self.else_ is not None else type_api.NULLTYPE, + ) + self.type = cast(_T, type_) + @util.ro_non_memoized_property def _from_objects(self) -> List[FromClause]: return list( diff --git a/test/sql/test_case_statement.py b/test/sql/test_case_statement.py index 6907d21325..08575f79bc 100644 --- a/test/sql/test_case_statement.py +++ b/test/sql/test_case_statement.py @@ -13,6 +13,7 @@ from sqlalchemy import testing from sqlalchemy import text from sqlalchemy.sql import column from sqlalchemy.sql import table +from sqlalchemy.sql.sqltypes import NullType from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures @@ -294,3 +295,40 @@ class CaseTest(fixtures.TestBase, AssertsCompiledSQL): ("two", 2), ("other", 3), ] + + def test_type_of_case_expression_with_all_nulls(self): + expr = case( + (info_table.c.pk < 0, None), + (info_table.c.pk > 9, None), + ) + + assert isinstance(expr.type, NullType) + + def test_type_of_case_expression_with_all_nulls_with_else(self): + expr = case( + (info_table.c.pk < 0, None), + (info_table.c.pk > 9, None), + else_=column("q"), + ) + + assert isinstance(expr.type, NullType) + + def test_type_of_case_expression_with_null_case_and_no_else_clause(self): + expr = case( + (info_table.c.pk < 0, None), + # This mixing of types is not legal in most DBMSs, but we want to + # test that types of later cases take priority over earlier ones + (info_table.c.pk < 5, "five"), + (info_table.c.pk <= 9, info_table.c.pk), + (info_table.c.pk > 9, None), + ) + + assert isinstance(expr.type, Integer) + + def test_type_of_case_expression_with_null_case_and_else_clause(self): + expr = case( + (info_table.c.pk < 0, None), + else_=info_table.c.pk, + ) + + assert isinstance(expr.type, Integer) -- 2.47.3