From: David Evans Date: Mon, 15 Jan 2024 15:13:53 +0000 (-0500) Subject: Fix type of CASE expressions which include NULLs X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=8f4ac0c0f07509d2f8a4bce9cbb07ac08ad04044;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Fix type of CASE expressions which include NULLs Fixed issues in :func:`_sql.case` where the logic for determining the type of the expression could result in :class:`.NullType` if the last element in the "whens" had no type, or in other cases where the type could resolve to ``None``. The logic has been updated to scan all given expressions so that the first non-null type is used, as well as to always ensure a type is present. Pull request courtesy David Evans. updates to test suite to use modern fixture patterns by Mike Fixes: #10843 Closes: #10847 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/10847 Pull-request-sha: 4fd5c39ab56de046e68c08f6c20cd1f7b2cb0e0d Change-Id: I40f905ac336a8a42b617ff9473dbd9c22ac57505 --- diff --git a/doc/build/changelog/unreleased_20/10843.rst b/doc/build/changelog/unreleased_20/10843.rst new file mode 100644 index 0000000000..838f6a8beb --- /dev/null +++ b/doc/build/changelog/unreleased_20/10843.rst @@ -0,0 +1,10 @@ +.. change:: + :tags: bug, sql + :tickets: 10843 + + Fixed issues in :func:`_sql.case` where the logic for determining the + type of the expression could result in :class:`.NullType` if the last + element in the "whens" had no type, or in other cases where the type + could resolve to ``None``. The logic has been updated to scan all + given expressions so that the first non-null type is used, as well as + to always ensure a type is present. Pull request courtesy David Evans. diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 45eb8f3c55..8f48e78ed0 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -3411,7 +3411,7 @@ class Case(ColumnElement[_T]): except TypeError: pass - whenlist = [ + self.whens = [ ( coercions.expect( roles.ExpressionElementRole, @@ -3423,24 +3423,28 @@ class Case(ColumnElement[_T]): for (c, r) in new_whens ] - if whenlist: - type_ = whenlist[-1][-1].type - else: - type_ = None - if value is None: self.value = None else: self.value = coercions.expect(roles.ExpressionElementRole, value) - self.type = cast(_T, type_) - self.whens = whenlist - if else_ is not None: self.else_ = coercions.expect(roles.ExpressionElementRole, else_) else: self.else_ = None + type_ = next( + ( + then.type + # Iterate `whens` in reverse to match previous behaviour + # where type of final element took priority + for *_, then in reversed(self.whens) + if not then.type._isnull + ), + self.else_.type if self.else_ is not None else type_api.NULLTYPE, + ) + self.type = cast(_T, type_) + @util.ro_non_memoized_property def _from_objects(self) -> List[FromClause]: return list( diff --git a/test/sql/test_case_statement.py b/test/sql/test_case_statement.py index 6907d21325..5e95d3cb2f 100644 --- a/test/sql/test_case_statement.py +++ b/test/sql/test_case_statement.py @@ -5,7 +5,6 @@ from sqlalchemy import Column from sqlalchemy import func from sqlalchemy import Integer from sqlalchemy import literal_column -from sqlalchemy import MetaData from sqlalchemy import select from sqlalchemy import String from sqlalchemy import Table @@ -13,50 +12,48 @@ from sqlalchemy import testing from sqlalchemy import text from sqlalchemy.sql import column from sqlalchemy.sql import table +from sqlalchemy.sql.sqltypes import NullType from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures -info_table = None - - -class CaseTest(fixtures.TestBase, AssertsCompiledSQL): +class CaseTest(fixtures.TablesTest, AssertsCompiledSQL): __dialect__ = "default" + run_inserts = "once" + run_deletes = "never" + @classmethod - def setup_test_class(cls): - metadata = MetaData() - global info_table - info_table = Table( - "infos", + def define_tables(cls, metadata): + Table( + "info_table", metadata, Column("pk", Integer, primary_key=True), Column("info", String(30)), ) - with testing.db.begin() as conn: - info_table.create(conn) - - conn.execute( - info_table.insert(), - [ - {"pk": 1, "info": "pk_1_data"}, - {"pk": 2, "info": "pk_2_data"}, - {"pk": 3, "info": "pk_3_data"}, - {"pk": 4, "info": "pk_4_data"}, - {"pk": 5, "info": "pk_5_data"}, - {"pk": 6, "info": "pk_6_data"}, - ], - ) - @classmethod - def teardown_test_class(cls): - with testing.db.begin() as conn: - info_table.drop(conn) + def insert_data(cls, connection): + info_table = cls.tables.info_table + + connection.execute( + info_table.insert(), + [ + {"pk": 1, "info": "pk_1_data"}, + {"pk": 2, "info": "pk_2_data"}, + {"pk": 3, "info": "pk_3_data"}, + {"pk": 4, "info": "pk_4_data"}, + {"pk": 5, "info": "pk_5_data"}, + {"pk": 6, "info": "pk_6_data"}, + ], + ) + connection.commit() @testing.requires.subqueries def test_case(self, connection): + info_table = self.tables.info_table + inner = select( case( (info_table.c.pk < 3, "lessthan3"), @@ -222,6 +219,8 @@ class CaseTest(fixtures.TestBase, AssertsCompiledSQL): ) def test_text_doesnt_explode(self, connection): + info_table = self.tables.info_table + for s in [ select( case( @@ -255,6 +254,8 @@ class CaseTest(fixtures.TestBase, AssertsCompiledSQL): ) def testcase_with_dict(self): + info_table = self.tables.info_table + query = select( case( { @@ -294,3 +295,61 @@ class CaseTest(fixtures.TestBase, AssertsCompiledSQL): ("two", 2), ("other", 3), ] + + @testing.variation("add_else", [True, False]) + def test_type_of_case_expression_with_all_nulls(self, add_else): + info_table = self.tables.info_table + + expr = case( + (info_table.c.pk < 0, None), + (info_table.c.pk > 9, None), + else_=column("q") if add_else else None, + ) + + assert isinstance(expr.type, NullType) + + @testing.combinations( + lambda info_table: ( + [ + # test non-None in middle of WHENS takes precedence over Nones + (info_table.c.pk < 0, None), + (info_table.c.pk < 5, "five"), + (info_table.c.pk <= 9, info_table.c.pk), + (info_table.c.pk > 9, None), + ], + None, + ), + lambda info_table: ( + # test non-None ELSE takes precedence over WHENs that are None + [(info_table.c.pk < 0, None)], + info_table.c.pk, + ), + lambda info_table: ( + # test non-None WHEN takes precedence over non-None ELSE + [ + (info_table.c.pk < 0, None), + (info_table.c.pk <= 9, info_table.c.pk), + (info_table.c.pk > 9, None), + ], + column("q", String), + ), + lambda info_table: ( + # test last WHEN in list takes precedence + [ + (info_table.c.pk < 0, String), + (info_table.c.pk > 9, None), + (info_table.c.pk <= 9, info_table.c.pk), + ], + column("q", String), + ), + ) + def test_type_of_case_expression(self, when_lambda): + info_table = self.tables.info_table + + whens, else_ = testing.resolve_lambda( + when_lambda, info_table=info_table + ) + + expr = case(*whens, else_=else_) + + assert isinstance(expr.type, Integer)