]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Typing: fix type of func.coalesce when used with hybrid properties
authorYannick PÉROUX <yannick.peroux@getalma.eu>
Tue, 4 Nov 2025 17:58:03 +0000 (12:58 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 11 Nov 2025 21:13:38 +0000 (16:13 -0500)
Fixed typing issue where :class:`.coalesce` would not return the correct
return type when a nullable form of that argument were passed, even though
this function is meant to select the non-null entry among possibly null
arguments.  Pull request courtesy Yannick PÉROUX.

Closes: #12963
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12963
Pull-request-sha: 05d0d9784d4497fb3bfee540fbc51747c1767c90

Change-Id: Ife83a384ea57faf446c1fdb542df14627348f40f

doc/build/changelog/unreleased_20/12963.rst [new file with mode: 0644]
lib/sqlalchemy/sql/functions.py
test/typing/plain_files/sql/functions_again.py
tools/generate_sql_functions.py

diff --git a/doc/build/changelog/unreleased_20/12963.rst b/doc/build/changelog/unreleased_20/12963.rst
new file mode 100644 (file)
index 0000000..3e457db
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, typing
+
+    Fixed typing issue where :class:`.coalesce` would not return the correct
+    return type when a nullable form of that argument were passed, even though
+    this function is meant to select the non-null entry among possibly null
+    arguments.  Pull request courtesy Yannick PÉROUX.
+
index d4aafd362592339db9aeaba3fcd6417f4c1074f9..3e3fc27132f764a6de26632db741d2b54babd4e6 100644 (file)
@@ -1076,7 +1076,7 @@ class _FunctionGenerator:
         @overload
         def coalesce(
             self,
-            col: _ColumnExpressionArgument[_T],
+            col: _ColumnExpressionArgument[Optional[_T]],
             *args: _ColumnExpressionOrLiteralArgument[Any],
             **kwargs: Any,
         ) -> coalesce[_T]: ...
@@ -1084,14 +1084,14 @@ class _FunctionGenerator:
         @overload
         def coalesce(
             self,
-            col: _T,
+            col: Optional[_T],
             *args: _ColumnExpressionOrLiteralArgument[Any],
             **kwargs: Any,
         ) -> coalesce[_T]: ...
 
         def coalesce(
             self,
-            col: _ColumnExpressionOrLiteralArgument[_T],
+            col: _ColumnExpressionOrLiteralArgument[Optional[_T]],
             *args: _ColumnExpressionOrLiteralArgument[Any],
             **kwargs: Any,
         ) -> coalesce[_T]: ...
@@ -1720,7 +1720,42 @@ class ReturnTypeFromArgs(GenericFunction[_T]):
         super().__init__(*fn_args, **kwargs)
 
 
-class coalesce(ReturnTypeFromArgs[_T]):
+class ReturnTypeFromOptionalArgs(ReturnTypeFromArgs[_T]):
+    inherit_cache = True
+
+    @overload
+    def __init__(
+        self,
+        col: ColumnElement[_T],
+        *args: _ColumnExpressionOrLiteralArgument[Any],
+        **kwargs: Any,
+    ) -> None: ...
+
+    @overload
+    def __init__(
+        self,
+        col: _ColumnExpressionArgument[Optional[_T]],
+        *args: _ColumnExpressionOrLiteralArgument[Any],
+        **kwargs: Any,
+    ) -> None: ...
+
+    @overload
+    def __init__(
+        self,
+        col: Optional[_T],
+        *args: _ColumnExpressionOrLiteralArgument[Any],
+        **kwargs: Any,
+    ) -> None: ...
+
+    def __init__(
+        self,
+        *args: _ColumnExpressionOrLiteralArgument[Optional[_T]],
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(*args, **kwargs)  # type: ignore
+
+
+class coalesce(ReturnTypeFromOptionalArgs[_T]):
     _has_args = True
     inherit_cache = True
 
index 1be8c5ce7822f59e0acbfee5168521a3c087e711..a961f307bef51f637bf2abc469d861cb7f18da99 100644 (file)
@@ -7,6 +7,7 @@ from sqlalchemy import Function
 from sqlalchemy import Integer
 from sqlalchemy import Select
 from sqlalchemy import select
+from sqlalchemy.ext.hybrid import hybrid_property
 from sqlalchemy.orm import DeclarativeBase
 from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
@@ -29,6 +30,11 @@ class Foo(Base):
     a: Mapped[int]
     b: Mapped[int]
     c: Mapped[str]
+    _d: Mapped[int | None] = mapped_column("d")
+
+    @hybrid_property
+    def d(self) -> int | None:
+        return self._d
 
 
 assert_type(
@@ -66,6 +72,7 @@ assert_type(func.coalesce(Foo.c, "a", "b"), coalesce[str])
 assert_type(func.coalesce("a", "b"), coalesce[str])
 assert_type(func.coalesce(column("x", Integer), 3), coalesce[int])
 
+assert_type(func.coalesce(Foo._d, 100), coalesce[int])
 
 stmt2 = select(Foo.a, func.coalesce(Foo.c, "a", "b")).group_by(Foo.a)
 assert_type(stmt2, Select[int, str])
index a78e2492a54aaf1cc08ce46437ca5adaee56c906..d7e80538ff655e2be27811ea39f3f7a775b22a2e 100644 (file)
@@ -14,6 +14,7 @@ import typing_extensions
 
 from sqlalchemy.sql.functions import _registry
 from sqlalchemy.sql.functions import ReturnTypeFromArgs
+from sqlalchemy.sql.functions import ReturnTypeFromOptionalArgs
 from sqlalchemy.types import TypeEngine
 from sqlalchemy.util.tool_support import code_writer_cmd
 
@@ -22,7 +23,7 @@ def _fns_in_deterministic_order():
     reg = _registry["_default"]
     for key in sorted(reg):
         cls = reg[key]
-        if cls is ReturnTypeFromArgs:
+        if cls is ReturnTypeFromArgs or cls is ReturnTypeFromOptionalArgs:
             continue
         yield key, cls
 
@@ -63,6 +64,11 @@ def process_functions(filename: str, cmd: code_writer_cmd) -> str:
                     is_reserved_word = key in builtins
 
                     if issubclass(fn_class, ReturnTypeFromArgs):
+                        if issubclass(fn_class, ReturnTypeFromOptionalArgs):
+                            _TEE = "Optional[_T]"
+                        else:
+                            _TEE = "_T"
+
                         buf.write(
                             textwrap.indent(
                                 f"""
@@ -84,7 +90,7 @@ def {key}( {'  # noqa: A001' if is_reserved_word else ''}
 @overload
 def {key}( {'  # noqa: A001' if is_reserved_word else ''}
     self,
-    col: _ColumnExpressionArgument[_T],
+    col: _ColumnExpressionArgument[{_TEE}],
     *args: _ColumnExpressionOrLiteralArgument[Any],
     **kwargs: Any,
 ) -> {fn_class.__name__}[_T]:
@@ -93,7 +99,7 @@ def {key}( {'  # noqa: A001' if is_reserved_word else ''}
 @overload
 def {key}( {'  # noqa: A001' if is_reserved_word else ''}
     self,
-    col: _T,
+    col: {_TEE},
     *args: _ColumnExpressionOrLiteralArgument[Any],
     **kwargs: Any,
 ) -> {fn_class.__name__}[_T]:
@@ -101,7 +107,7 @@ def {key}( {'  # noqa: A001' if is_reserved_word else ''}
 
 def {key}( {'  # noqa: A001' if is_reserved_word else ''}
     self,
-    col: _ColumnExpressionOrLiteralArgument[_T],
+    col: _ColumnExpressionOrLiteralArgument[{_TEE}],
     *args: _ColumnExpressionOrLiteralArgument[Any],
     **kwargs: Any,
 ) -> {fn_class.__name__}[_T]: