]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Typing: fix type of func.coalesce when used with hybrid properties
authorYannick PÉROUX <yannick.peroux@getalma.eu>
Tue, 4 Nov 2025 17:58:03 +0000 (12:58 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 12 Nov 2025 00:45:17 +0000 (19:45 -0500)
Fixed typing issue where :class:`.coalesce` would not return the correct
return type when a nullable form of that argument were passed, even though
this function is meant to select the non-null entry among possibly null
arguments.  Pull request courtesy Yannick PÉROUX.

Closes: #12963
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12963
Pull-request-sha: 05d0d9784d4497fb3bfee540fbc51747c1767c90

Change-Id: Ife83a384ea57faf446c1fdb542df14627348f40f
(cherry picked from commit d160cb5314239ef9487c84aa5173e946d57804fd)

doc/build/changelog/unreleased_20/12963.rst [new file with mode: 0644]
lib/sqlalchemy/sql/functions.py
test/typing/plain_files/sql/functions_again.py
tools/generate_sql_functions.py

diff --git a/doc/build/changelog/unreleased_20/12963.rst b/doc/build/changelog/unreleased_20/12963.rst
new file mode 100644 (file)
index 0000000..3e457db
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, typing
+
+    Fixed typing issue where :class:`.coalesce` would not return the correct
+    return type when a nullable form of that argument were passed, even though
+    this function is meant to select the non-null entry among possibly null
+    arguments.  Pull request courtesy Yannick PÉROUX.
+
index fece73c9d8fa7250534a8b875f5644cdc074bae5..31f5015b524298445fedec667407678bd817c68b 100644 (file)
@@ -1047,7 +1047,7 @@ class _FunctionGenerator:
         @overload
         def coalesce(
             self,
-            col: _ColumnExpressionArgument[_T],
+            col: _ColumnExpressionArgument[Optional[_T]],
             *args: _ColumnExpressionOrLiteralArgument[Any],
             **kwargs: Any,
         ) -> coalesce[_T]: ...
@@ -1055,14 +1055,14 @@ class _FunctionGenerator:
         @overload
         def coalesce(
             self,
-            col: _T,
+            col: Optional[_T],
             *args: _ColumnExpressionOrLiteralArgument[Any],
             **kwargs: Any,
         ) -> coalesce[_T]: ...
 
         def coalesce(
             self,
-            col: _ColumnExpressionOrLiteralArgument[_T],
+            col: _ColumnExpressionOrLiteralArgument[Optional[_T]],
             *args: _ColumnExpressionOrLiteralArgument[Any],
             **kwargs: Any,
         ) -> coalesce[_T]: ...
@@ -1661,7 +1661,42 @@ class ReturnTypeFromArgs(GenericFunction[_T]):
         super().__init__(*fn_args, **kwargs)
 
 
-class coalesce(ReturnTypeFromArgs[_T]):
+class ReturnTypeFromOptionalArgs(ReturnTypeFromArgs[_T]):
+    inherit_cache = True
+
+    @overload
+    def __init__(
+        self,
+        col: ColumnElement[_T],
+        *args: _ColumnExpressionOrLiteralArgument[Any],
+        **kwargs: Any,
+    ) -> None: ...
+
+    @overload
+    def __init__(
+        self,
+        col: _ColumnExpressionArgument[Optional[_T]],
+        *args: _ColumnExpressionOrLiteralArgument[Any],
+        **kwargs: Any,
+    ) -> None: ...
+
+    @overload
+    def __init__(
+        self,
+        col: Optional[_T],
+        *args: _ColumnExpressionOrLiteralArgument[Any],
+        **kwargs: Any,
+    ) -> None: ...
+
+    def __init__(
+        self,
+        *args: _ColumnExpressionOrLiteralArgument[Optional[_T]],
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(*args, **kwargs)  # type: ignore
+
+
+class coalesce(ReturnTypeFromOptionalArgs[_T]):
     _has_args = True
     inherit_cache = True
 
index 24b720f67107345b86a3610a756dd75f21f9f0b4..63ca442b6d12b895d7f3a070bd4a467825e3d9b3 100644 (file)
@@ -1,10 +1,14 @@
+from typing import assert_type
+
 from sqlalchemy import column
 from sqlalchemy import func
 from sqlalchemy import Integer
 from sqlalchemy import select
+from sqlalchemy.ext.hybrid import hybrid_property
 from sqlalchemy.orm import DeclarativeBase
 from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
+from sqlalchemy.sql.functions import coalesce
 
 
 class Base(DeclarativeBase):
@@ -18,6 +22,11 @@ class Foo(Base):
     a: Mapped[int]
     b: Mapped[int]
     c: Mapped[str]
+    _d: Mapped[int | None] = mapped_column("d")
+
+    @hybrid_property
+    def d(self) -> int | None:
+        return self._d
 
 
 # EXPECTED_TYPE: Over[Any]
@@ -60,6 +69,7 @@ reveal_type(func.coalesce("a", "b"))
 # EXPECTED_TYPE: coalesce[int]
 reveal_type(func.coalesce(column("x", Integer), 3))
 
+assert_type(func.coalesce(Foo._d, 100), coalesce[int])
 
 stmt2 = select(Foo.a, func.coalesce(Foo.c, "a", "b")).group_by(Foo.a)
 # EXPECTED_TYPE: Select[Tuple[int, str]]
index 624fbb75ed21cbd0854d12c62e6828e4dbe5c8a5..49844947bb1bb3a3f0b5bcdf2926834540a9d54f 100644 (file)
@@ -14,6 +14,7 @@ import typing_extensions
 
 from sqlalchemy.sql.functions import _registry
 from sqlalchemy.sql.functions import ReturnTypeFromArgs
+from sqlalchemy.sql.functions import ReturnTypeFromOptionalArgs
 from sqlalchemy.types import TypeEngine
 from sqlalchemy.util.tool_support import code_writer_cmd
 
@@ -22,7 +23,7 @@ def _fns_in_deterministic_order():
     reg = _registry["_default"]
     for key in sorted(reg):
         cls = reg[key]
-        if cls is ReturnTypeFromArgs:
+        if cls is ReturnTypeFromArgs or cls is ReturnTypeFromOptionalArgs:
             continue
         yield key, cls
 
@@ -63,6 +64,11 @@ def process_functions(filename: str, cmd: code_writer_cmd) -> str:
                     is_reserved_word = key in builtins
 
                     if issubclass(fn_class, ReturnTypeFromArgs):
+                        if issubclass(fn_class, ReturnTypeFromOptionalArgs):
+                            _TEE = "Optional[_T]"
+                        else:
+                            _TEE = "_T"
+
                         buf.write(
                             textwrap.indent(
                                 f"""
@@ -84,7 +90,7 @@ def {key}( {'  # noqa: A001' if is_reserved_word else ''}
 @overload
 def {key}( {'  # noqa: A001' if is_reserved_word else ''}
     self,
-    col: _ColumnExpressionArgument[_T],
+    col: _ColumnExpressionArgument[{_TEE}],
     *args: _ColumnExpressionOrLiteralArgument[Any],
     **kwargs: Any,
 ) -> {fn_class.__name__}[_T]:
@@ -93,7 +99,7 @@ def {key}( {'  # noqa: A001' if is_reserved_word else ''}
 @overload
 def {key}( {'  # noqa: A001' if is_reserved_word else ''}
     self,
-    col: _T,
+    col: {_TEE},
     *args: _ColumnExpressionOrLiteralArgument[Any],
     **kwargs: Any,
 ) -> {fn_class.__name__}[_T]:
@@ -101,7 +107,7 @@ def {key}( {'  # noqa: A001' if is_reserved_word else ''}
 
 def {key}( {'  # noqa: A001' if is_reserved_word else ''}
     self,
-    col: _ColumnExpressionOrLiteralArgument[_T],
+    col: _ColumnExpressionOrLiteralArgument[{_TEE}],
     *args: _ColumnExpressionOrLiteralArgument[Any],
     **kwargs: Any,
 ) -> {fn_class.__name__}[_T]: