]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
improve overloads applied to generic functions
authorFederico Caselli <cfederico87@gmail.com>
Mon, 24 Mar 2025 20:50:45 +0000 (21:50 +0100)
committerFederico Caselli <cfederico87@gmail.com>
Thu, 3 Apr 2025 18:00:23 +0000 (20:00 +0200)
try again to remove the overloads to the generic functionn
generator (like coalesce, array_agg, etc).
As of mypy 1.15 it still does now work, but a simpler version
is added in this change

Change-Id: I8b97ae00298ec6f6bf8580090e5defff71e1ceb0
(cherry picked from commit 5cc6a65c61798078959455f5d74f535681c119b7)

lib/sqlalchemy/sql/functions.py
test/typing/plain_files/sql/functions_again.py
tools/generate_sql_functions.py

index 0e52a8bb736bc500b0562208e75217c4ff3e3d17..cd63e82339e9a69d6ef489e401be9307675fb292 100644 (file)
@@ -5,7 +5,6 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
 
-
 """SQL function API, factories, and built-in functions."""
 
 from __future__ import annotations
@@ -153,7 +152,9 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
 
     clause_expr: Grouping[Any]
 
-    def __init__(self, *clauses: _ColumnExpressionOrLiteralArgument[Any]):
+    def __init__(
+        self, *clauses: _ColumnExpressionOrLiteralArgument[Any]
+    ) -> None:
         r"""Construct a :class:`.FunctionElement`.
 
         :param \*clauses: list of column expressions that form the arguments
@@ -779,7 +780,7 @@ class FunctionAsBinary(BinaryExpression[Any]):
 
     def __init__(
         self, fn: FunctionElement[Any], left_index: int, right_index: int
-    ):
+    ) -> None:
         self.sql_function = fn
         self.left_index = left_index
         self.right_index = right_index
@@ -831,7 +832,7 @@ class ScalarFunctionColumn(NamedColumn[_T]):
         fn: FunctionElement[_T],
         name: str,
         type_: Optional[_TypeEngineArgument[_T]] = None,
-    ):
+    ) -> None:
         self.fn = fn
         self.name = name
 
@@ -930,7 +931,7 @@ class _FunctionGenerator:
 
     """  # noqa
 
-    def __init__(self, **opts: Any):
+    def __init__(self, **opts: Any) -> None:
         self.__names: List[str] = []
         self.opts = opts
 
@@ -990,10 +991,10 @@ class _FunctionGenerator:
         @property
         def ansifunction(self) -> Type[AnsiFunction[Any]]: ...
 
-        # set ColumnElement[_T] as a separate overload, to appease mypy
-        # which seems to not want to accept _T from _ColumnExpressionArgument.
-        # this is even if all non-generic types are removed from it, so
-        # reasons remain unclear for why this does not work
+        # set ColumnElement[_T] as a separate overload, to appease
+        # mypy which seems to not want to accept _T from
+        # _ColumnExpressionArgument. Seems somewhat related to the covariant
+        # _HasClauseElement as of mypy 1.15
 
         @overload
         def array_agg(
@@ -1014,7 +1015,7 @@ class _FunctionGenerator:
         @overload
         def array_agg(
             self,
-            col: _ColumnExpressionOrLiteralArgument[_T],
+            col: _T,
             *args: _ColumnExpressionOrLiteralArgument[Any],
             **kwargs: Any,
         ) -> array_agg[_T]: ...
@@ -1032,10 +1033,10 @@ class _FunctionGenerator:
         @property
         def char_length(self) -> Type[char_length]: ...
 
-        # set ColumnElement[_T] as a separate overload, to appease mypy
-        # which seems to not want to accept _T from _ColumnExpressionArgument.
-        # this is even if all non-generic types are removed from it, so
-        # reasons remain unclear for why this does not work
+        # set ColumnElement[_T] as a separate overload, to appease
+        # mypy which seems to not want to accept _T from
+        # _ColumnExpressionArgument. Seems somewhat related to the covariant
+        # _HasClauseElement as of mypy 1.15
 
         @overload
         def coalesce(
@@ -1056,7 +1057,7 @@ class _FunctionGenerator:
         @overload
         def coalesce(
             self,
-            col: _ColumnExpressionOrLiteralArgument[_T],
+            col: _T,
             *args: _ColumnExpressionOrLiteralArgument[Any],
             **kwargs: Any,
         ) -> coalesce[_T]: ...
@@ -1107,10 +1108,10 @@ class _FunctionGenerator:
         @property
         def localtimestamp(self) -> Type[localtimestamp]: ...
 
-        # set ColumnElement[_T] as a separate overload, to appease mypy
-        # which seems to not want to accept _T from _ColumnExpressionArgument.
-        # this is even if all non-generic types are removed from it, so
-        # reasons remain unclear for why this does not work
+        # set ColumnElement[_T] as a separate overload, to appease
+        # mypy which seems to not want to accept _T from
+        # _ColumnExpressionArgument. Seems somewhat related to the covariant
+        # _HasClauseElement as of mypy 1.15
 
         @overload
         def max(  # noqa: A001
@@ -1131,7 +1132,7 @@ class _FunctionGenerator:
         @overload
         def max(  # noqa: A001
             self,
-            col: _ColumnExpressionOrLiteralArgument[_T],
+            col: _T,
             *args: _ColumnExpressionOrLiteralArgument[Any],
             **kwargs: Any,
         ) -> max[_T]: ...
@@ -1143,10 +1144,10 @@ class _FunctionGenerator:
             **kwargs: Any,
         ) -> max[_T]: ...
 
-        # set ColumnElement[_T] as a separate overload, to appease mypy
-        # which seems to not want to accept _T from _ColumnExpressionArgument.
-        # this is even if all non-generic types are removed from it, so
-        # reasons remain unclear for why this does not work
+        # set ColumnElement[_T] as a separate overload, to appease
+        # mypy which seems to not want to accept _T from
+        # _ColumnExpressionArgument. Seems somewhat related to the covariant
+        # _HasClauseElement as of mypy 1.15
 
         @overload
         def min(  # noqa: A001
@@ -1167,7 +1168,7 @@ class _FunctionGenerator:
         @overload
         def min(  # noqa: A001
             self,
-            col: _ColumnExpressionOrLiteralArgument[_T],
+            col: _T,
             *args: _ColumnExpressionOrLiteralArgument[Any],
             **kwargs: Any,
         ) -> min[_T]: ...
@@ -1212,10 +1213,10 @@ class _FunctionGenerator:
         @property
         def session_user(self) -> Type[session_user]: ...
 
-        # set ColumnElement[_T] as a separate overload, to appease mypy
-        # which seems to not want to accept _T from _ColumnExpressionArgument.
-        # this is even if all non-generic types are removed from it, so
-        # reasons remain unclear for why this does not work
+        # set ColumnElement[_T] as a separate overload, to appease
+        # mypy which seems to not want to accept _T from
+        # _ColumnExpressionArgument. Seems somewhat related to the covariant
+        # _HasClauseElement as of mypy 1.15
 
         @overload
         def sum(  # noqa: A001
@@ -1236,7 +1237,7 @@ class _FunctionGenerator:
         @overload
         def sum(  # noqa: A001
             self,
-            col: _ColumnExpressionOrLiteralArgument[_T],
+            col: _T,
             *args: _ColumnExpressionOrLiteralArgument[Any],
             **kwargs: Any,
         ) -> sum[_T]: ...
@@ -1332,7 +1333,7 @@ class Function(FunctionElement[_T]):
         *clauses: _ColumnExpressionOrLiteralArgument[_T],
         type_: None = ...,
         packagenames: Optional[Tuple[str, ...]] = ...,
-    ): ...
+    ) -> None: ...
 
     @overload
     def __init__(
@@ -1341,7 +1342,7 @@ class Function(FunctionElement[_T]):
         *clauses: _ColumnExpressionOrLiteralArgument[Any],
         type_: _TypeEngineArgument[_T] = ...,
         packagenames: Optional[Tuple[str, ...]] = ...,
-    ): ...
+    ) -> None: ...
 
     def __init__(
         self,
@@ -1349,7 +1350,7 @@ class Function(FunctionElement[_T]):
         *clauses: _ColumnExpressionOrLiteralArgument[Any],
         type_: Optional[_TypeEngineArgument[_T]] = None,
         packagenames: Optional[Tuple[str, ...]] = None,
-    ):
+    ) -> None:
         """Construct a :class:`.Function`.
 
         The :data:`.func` construct is normally used to construct
@@ -1531,7 +1532,7 @@ class GenericFunction(Function[_T]):
 
     def __init__(
         self, *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any
-    ):
+    ) -> None:
         parsed_args = kwargs.pop("_parsed_args", None)
         if parsed_args is None:
             parsed_args = [
@@ -1578,7 +1579,7 @@ class next_value(GenericFunction[int]):
         ("sequence", InternalTraversal.dp_named_ddl_element)
     ]
 
-    def __init__(self, seq: schema.Sequence, **kw: Any):
+    def __init__(self, seq: schema.Sequence, **kw: Any) -> None:
         assert isinstance(
             seq, schema.Sequence
         ), "next_value() accepts a Sequence object as input."
@@ -1603,7 +1604,9 @@ class AnsiFunction(GenericFunction[_T]):
 
     inherit_cache = True
 
-    def __init__(self, *args: _ColumnExpressionArgument[Any], **kwargs: Any):
+    def __init__(
+        self, *args: _ColumnExpressionArgument[Any], **kwargs: Any
+    ) -> None:
         GenericFunction.__init__(self, *args, **kwargs)
 
 
@@ -1614,10 +1617,10 @@ class ReturnTypeFromArgs(GenericFunction[_T]):
 
     inherit_cache = True
 
-    # set ColumnElement[_T] as a separate overload, to appease mypy which seems
-    # to not want to accept _T from _ColumnExpressionArgument.  this is even if
-    # all non-generic types are removed from it, so reasons remain unclear for
-    # why this does not work
+    # set ColumnElement[_T] as a separate overload, to appease
+    # mypy which seems to not want to accept _T from
+    # _ColumnExpressionArgument. Seems somewhat related to the covariant
+    # _HasClauseElement as of mypy 1.15
 
     @overload
     def __init__(
@@ -1625,7 +1628,7 @@ class ReturnTypeFromArgs(GenericFunction[_T]):
         col: ColumnElement[_T],
         *args: _ColumnExpressionOrLiteralArgument[Any],
         **kwargs: Any,
-    ): ...
+    ) -> None: ...
 
     @overload
     def __init__(
@@ -1633,19 +1636,19 @@ class ReturnTypeFromArgs(GenericFunction[_T]):
         col: _ColumnExpressionArgument[_T],
         *args: _ColumnExpressionOrLiteralArgument[Any],
         **kwargs: Any,
-    ): ...
+    ) -> None: ...
 
     @overload
     def __init__(
         self,
-        col: _ColumnExpressionOrLiteralArgument[_T],
+        col: _T,
         *args: _ColumnExpressionOrLiteralArgument[Any],
         **kwargs: Any,
-    ): ...
+    ) -> None: ...
 
     def __init__(
-        self, *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any
-    ):
+        self, *args: _ColumnExpressionOrLiteralArgument[_T], **kwargs: Any
+    ) -> None:
         fn_args: Sequence[ColumnElement[Any]] = [
             coercions.expect(
                 roles.ExpressionElementRole,
@@ -1727,7 +1730,7 @@ class char_length(GenericFunction[int]):
     type = sqltypes.Integer()
     inherit_cache = True
 
-    def __init__(self, arg: _ColumnExpressionArgument[str], **kw: Any):
+    def __init__(self, arg: _ColumnExpressionArgument[str], **kw: Any) -> None:
         # slight hack to limit to just one positional argument
         # not sure why this one function has this special treatment
         super().__init__(arg, **kw)
@@ -1773,7 +1776,7 @@ class count(GenericFunction[int]):
             _ColumnExpressionArgument[Any], _StarOrOne, None
         ] = None,
         **kwargs: Any,
-    ):
+    ) -> None:
         if expression is None:
             expression = literal_column("*")
         super().__init__(expression, **kwargs)
@@ -1862,7 +1865,9 @@ class array_agg(ReturnTypeFromArgs[Sequence[_T]]):
 
     inherit_cache = True
 
-    def __init__(self, *args: _ColumnExpressionArgument[Any], **kwargs: Any):
+    def __init__(
+        self, *args: _ColumnExpressionArgument[Any], **kwargs: Any
+    ) -> None:
         fn_args: Sequence[ColumnElement[Any]] = [
             coercions.expect(
                 roles.ExpressionElementRole, c, apply_propagate_attrs=self
@@ -2095,5 +2100,7 @@ class aggregate_strings(GenericFunction[str]):
     _has_args = True
     inherit_cache = True
 
-    def __init__(self, clause: _ColumnExpressionArgument[Any], separator: str):
+    def __init__(
+        self, clause: _ColumnExpressionArgument[Any], separator: str
+    ) -> None:
         super().__init__(clause, separator)
index 67888790f6bc5ef826ce9e4988e80266a695da6f..24b720f67107345b86a3610a756dd75f21f9f0b4 100644 (file)
@@ -1,4 +1,6 @@
+from sqlalchemy import column
 from sqlalchemy import func
+from sqlalchemy import Integer
 from sqlalchemy import select
 from sqlalchemy.orm import DeclarativeBase
 from sqlalchemy.orm import Mapped
@@ -53,6 +55,10 @@ reveal_type(stmt1)
 # test #10818
 # EXPECTED_TYPE: coalesce[str]
 reveal_type(func.coalesce(Foo.c, "a", "b"))
+# EXPECTED_TYPE: coalesce[str]
+reveal_type(func.coalesce("a", "b"))
+# EXPECTED_TYPE: coalesce[int]
+reveal_type(func.coalesce(column("x", Integer), 3))
 
 
 stmt2 = select(Foo.a, func.coalesce(Foo.c, "a", "b")).group_by(Foo.a)
index 5049ce52066f861cd5dba5e65e0701d4e9f3ce7c..624fbb75ed21cbd0854d12c62e6828e4dbe5c8a5 100644 (file)
@@ -67,10 +67,10 @@ def process_functions(filename: str, cmd: code_writer_cmd) -> str:
                             textwrap.indent(
                                 f"""
 
-# set ColumnElement[_T] as a separate overload, to appease mypy
-# which seems to not want to accept _T from _ColumnExpressionArgument.
-# this is even if all non-generic types are removed from it, so
-# reasons remain unclear for why this does not work
+# set ColumnElement[_T] as a separate overload, to appease
+# mypy which seems to not want to accept _T from
+# _ColumnExpressionArgument. Seems somewhat related to the covariant
+# _HasClauseElement as of mypy 1.15
 
 @overload
 def {key}( {'  # noqa: A001' if is_reserved_word else ''}
@@ -90,17 +90,15 @@ def {key}( {'  # noqa: A001' if is_reserved_word else ''}
 ) -> {fn_class.__name__}[_T]:
         ...
 
-
 @overload
 def {key}( {'  # noqa: A001' if is_reserved_word else ''}
     self,
-    col: _ColumnExpressionOrLiteralArgument[_T],
+    col: _T,
     *args: _ColumnExpressionOrLiteralArgument[Any],
     **kwargs: Any,
 ) -> {fn_class.__name__}[_T]:
         ...
 
-
 def {key}( {'  # noqa: A001' if is_reserved_word else ''}
     self,
     col: _ColumnExpressionOrLiteralArgument[_T],