From d160cb5314239ef9487c84aa5173e946d57804fd Mon Sep 17 00:00:00 2001 From: =?utf8?q?Yannick=20P=C3=89ROUX?= Date: Tue, 4 Nov 2025 12:58:03 -0500 Subject: [PATCH] Typing: fix type of func.coalesce when used with hybrid properties MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Fixed typing issue where :class:`.coalesce` would not return the correct return type when a nullable form of that argument were passed, even though this function is meant to select the non-null entry among possibly null arguments. Pull request courtesy Yannick PÉROUX. Closes: #12963 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12963 Pull-request-sha: 05d0d9784d4497fb3bfee540fbc51747c1767c90 Change-Id: Ife83a384ea57faf446c1fdb542df14627348f40f --- doc/build/changelog/unreleased_20/12963.rst | 8 ++++ lib/sqlalchemy/sql/functions.py | 43 +++++++++++++++++-- .../typing/plain_files/sql/functions_again.py | 7 +++ tools/generate_sql_functions.py | 14 ++++-- 4 files changed, 64 insertions(+), 8 deletions(-) create mode 100644 doc/build/changelog/unreleased_20/12963.rst diff --git a/doc/build/changelog/unreleased_20/12963.rst b/doc/build/changelog/unreleased_20/12963.rst new file mode 100644 index 0000000000..3e457db351 --- /dev/null +++ b/doc/build/changelog/unreleased_20/12963.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: bug, typing + + Fixed typing issue where :class:`.coalesce` would not return the correct + return type when a nullable form of that argument were passed, even though + this function is meant to select the non-null entry among possibly null + arguments. Pull request courtesy Yannick PÉROUX. + diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index d4aafd3625..3e3fc27132 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -1076,7 +1076,7 @@ class _FunctionGenerator: @overload def coalesce( self, - col: _ColumnExpressionArgument[_T], + col: _ColumnExpressionArgument[Optional[_T]], *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, ) -> coalesce[_T]: ... @@ -1084,14 +1084,14 @@ class _FunctionGenerator: @overload def coalesce( self, - col: _T, + col: Optional[_T], *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, ) -> coalesce[_T]: ... def coalesce( self, - col: _ColumnExpressionOrLiteralArgument[_T], + col: _ColumnExpressionOrLiteralArgument[Optional[_T]], *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, ) -> coalesce[_T]: ... @@ -1720,7 +1720,42 @@ class ReturnTypeFromArgs(GenericFunction[_T]): super().__init__(*fn_args, **kwargs) -class coalesce(ReturnTypeFromArgs[_T]): +class ReturnTypeFromOptionalArgs(ReturnTypeFromArgs[_T]): + inherit_cache = True + + @overload + def __init__( + self, + col: ColumnElement[_T], + *args: _ColumnExpressionOrLiteralArgument[Any], + **kwargs: Any, + ) -> None: ... + + @overload + def __init__( + self, + col: _ColumnExpressionArgument[Optional[_T]], + *args: _ColumnExpressionOrLiteralArgument[Any], + **kwargs: Any, + ) -> None: ... + + @overload + def __init__( + self, + col: Optional[_T], + *args: _ColumnExpressionOrLiteralArgument[Any], + **kwargs: Any, + ) -> None: ... + + def __init__( + self, + *args: _ColumnExpressionOrLiteralArgument[Optional[_T]], + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) # type: ignore + + +class coalesce(ReturnTypeFromOptionalArgs[_T]): _has_args = True inherit_cache = True diff --git a/test/typing/plain_files/sql/functions_again.py b/test/typing/plain_files/sql/functions_again.py index 1be8c5ce78..a961f307be 100644 --- a/test/typing/plain_files/sql/functions_again.py +++ b/test/typing/plain_files/sql/functions_again.py @@ -7,6 +7,7 @@ from sqlalchemy import Function from sqlalchemy import Integer from sqlalchemy import Select from sqlalchemy import select +from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column @@ -29,6 +30,11 @@ class Foo(Base): a: Mapped[int] b: Mapped[int] c: Mapped[str] + _d: Mapped[int | None] = mapped_column("d") + + @hybrid_property + def d(self) -> int | None: + return self._d assert_type( @@ -66,6 +72,7 @@ assert_type(func.coalesce(Foo.c, "a", "b"), coalesce[str]) assert_type(func.coalesce("a", "b"), coalesce[str]) assert_type(func.coalesce(column("x", Integer), 3), coalesce[int]) +assert_type(func.coalesce(Foo._d, 100), coalesce[int]) stmt2 = select(Foo.a, func.coalesce(Foo.c, "a", "b")).group_by(Foo.a) assert_type(stmt2, Select[int, str]) diff --git a/tools/generate_sql_functions.py b/tools/generate_sql_functions.py index a78e2492a5..d7e80538ff 100644 --- a/tools/generate_sql_functions.py +++ b/tools/generate_sql_functions.py @@ -14,6 +14,7 @@ import typing_extensions from sqlalchemy.sql.functions import _registry from sqlalchemy.sql.functions import ReturnTypeFromArgs +from sqlalchemy.sql.functions import ReturnTypeFromOptionalArgs from sqlalchemy.types import TypeEngine from sqlalchemy.util.tool_support import code_writer_cmd @@ -22,7 +23,7 @@ def _fns_in_deterministic_order(): reg = _registry["_default"] for key in sorted(reg): cls = reg[key] - if cls is ReturnTypeFromArgs: + if cls is ReturnTypeFromArgs or cls is ReturnTypeFromOptionalArgs: continue yield key, cls @@ -63,6 +64,11 @@ def process_functions(filename: str, cmd: code_writer_cmd) -> str: is_reserved_word = key in builtins if issubclass(fn_class, ReturnTypeFromArgs): + if issubclass(fn_class, ReturnTypeFromOptionalArgs): + _TEE = "Optional[_T]" + else: + _TEE = "_T" + buf.write( textwrap.indent( f""" @@ -84,7 +90,7 @@ def {key}( {' # noqa: A001' if is_reserved_word else ''} @overload def {key}( {' # noqa: A001' if is_reserved_word else ''} self, - col: _ColumnExpressionArgument[_T], + col: _ColumnExpressionArgument[{_TEE}], *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, ) -> {fn_class.__name__}[_T]: @@ -93,7 +99,7 @@ def {key}( {' # noqa: A001' if is_reserved_word else ''} @overload def {key}( {' # noqa: A001' if is_reserved_word else ''} self, - col: _T, + col: {_TEE}, *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, ) -> {fn_class.__name__}[_T]: @@ -101,7 +107,7 @@ def {key}( {' # noqa: A001' if is_reserved_word else ''} def {key}( {' # noqa: A001' if is_reserved_word else ''} self, - col: _ColumnExpressionOrLiteralArgument[_T], + col: _ColumnExpressionOrLiteralArgument[{_TEE}], *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, ) -> {fn_class.__name__}[_T]: -- 2.47.3