From: Federico Caselli Date: Mon, 24 Mar 2025 20:50:45 +0000 (+0100) Subject: improve overloads applied to generic functions X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=5cc6a65c61798078959455f5d74f535681c119b7;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git improve overloads applied to generic functions try again to remove the overloads to the generic functionn generator (like coalesce, array_agg, etc). As of mypy 1.15 it still does now work, but a simpler version is added in this change Change-Id: I8b97ae00298ec6f6bf8580090e5defff71e1ceb0 --- diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index c35cbf4adc..7b619ec589 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -5,7 +5,6 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php - """SQL function API, factories, and built-in functions.""" from __future__ import annotations @@ -153,7 +152,9 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): clause_expr: Grouping[Any] - def __init__(self, *clauses: _ColumnExpressionOrLiteralArgument[Any]): + def __init__( + self, *clauses: _ColumnExpressionOrLiteralArgument[Any] + ) -> None: r"""Construct a :class:`.FunctionElement`. :param \*clauses: list of column expressions that form the arguments @@ -775,7 +776,7 @@ class FunctionAsBinary(BinaryExpression[Any]): def __init__( self, fn: FunctionElement[Any], left_index: int, right_index: int - ): + ) -> None: self.sql_function = fn self.left_index = left_index self.right_index = right_index @@ -827,7 +828,7 @@ class ScalarFunctionColumn(NamedColumn[_T]): fn: FunctionElement[_T], name: str, type_: Optional[_TypeEngineArgument[_T]] = None, - ): + ) -> None: self.fn = fn self.name = name @@ -926,7 +927,7 @@ class _FunctionGenerator: """ # noqa - def __init__(self, **opts: Any): + def __init__(self, **opts: Any) -> None: self.__names: List[str] = [] self.opts = opts @@ -986,10 +987,10 @@ class _FunctionGenerator: @property def ansifunction(self) -> Type[AnsiFunction[Any]]: ... - # set ColumnElement[_T] as a separate overload, to appease mypy - # which seems to not want to accept _T from _ColumnExpressionArgument. - # this is even if all non-generic types are removed from it, so - # reasons remain unclear for why this does not work + # set ColumnElement[_T] as a separate overload, to appease + # mypy which seems to not want to accept _T from + # _ColumnExpressionArgument. Seems somewhat related to the covariant + # _HasClauseElement as of mypy 1.15 @overload def array_agg( @@ -1010,7 +1011,7 @@ class _FunctionGenerator: @overload def array_agg( self, - col: _ColumnExpressionOrLiteralArgument[_T], + col: _T, *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, ) -> array_agg[_T]: ... @@ -1028,10 +1029,10 @@ class _FunctionGenerator: @property def char_length(self) -> Type[char_length]: ... - # set ColumnElement[_T] as a separate overload, to appease mypy - # which seems to not want to accept _T from _ColumnExpressionArgument. - # this is even if all non-generic types are removed from it, so - # reasons remain unclear for why this does not work + # set ColumnElement[_T] as a separate overload, to appease + # mypy which seems to not want to accept _T from + # _ColumnExpressionArgument. Seems somewhat related to the covariant + # _HasClauseElement as of mypy 1.15 @overload def coalesce( @@ -1052,7 +1053,7 @@ class _FunctionGenerator: @overload def coalesce( self, - col: _ColumnExpressionOrLiteralArgument[_T], + col: _T, *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, ) -> coalesce[_T]: ... @@ -1103,10 +1104,10 @@ class _FunctionGenerator: @property def localtimestamp(self) -> Type[localtimestamp]: ... - # set ColumnElement[_T] as a separate overload, to appease mypy - # which seems to not want to accept _T from _ColumnExpressionArgument. - # this is even if all non-generic types are removed from it, so - # reasons remain unclear for why this does not work + # set ColumnElement[_T] as a separate overload, to appease + # mypy which seems to not want to accept _T from + # _ColumnExpressionArgument. Seems somewhat related to the covariant + # _HasClauseElement as of mypy 1.15 @overload def max( # noqa: A001 @@ -1127,7 +1128,7 @@ class _FunctionGenerator: @overload def max( # noqa: A001 self, - col: _ColumnExpressionOrLiteralArgument[_T], + col: _T, *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, ) -> max[_T]: ... @@ -1139,10 +1140,10 @@ class _FunctionGenerator: **kwargs: Any, ) -> max[_T]: ... - # set ColumnElement[_T] as a separate overload, to appease mypy - # which seems to not want to accept _T from _ColumnExpressionArgument. - # this is even if all non-generic types are removed from it, so - # reasons remain unclear for why this does not work + # set ColumnElement[_T] as a separate overload, to appease + # mypy which seems to not want to accept _T from + # _ColumnExpressionArgument. Seems somewhat related to the covariant + # _HasClauseElement as of mypy 1.15 @overload def min( # noqa: A001 @@ -1163,7 +1164,7 @@ class _FunctionGenerator: @overload def min( # noqa: A001 self, - col: _ColumnExpressionOrLiteralArgument[_T], + col: _T, *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, ) -> min[_T]: ... @@ -1208,10 +1209,10 @@ class _FunctionGenerator: @property def session_user(self) -> Type[session_user]: ... - # set ColumnElement[_T] as a separate overload, to appease mypy - # which seems to not want to accept _T from _ColumnExpressionArgument. - # this is even if all non-generic types are removed from it, so - # reasons remain unclear for why this does not work + # set ColumnElement[_T] as a separate overload, to appease + # mypy which seems to not want to accept _T from + # _ColumnExpressionArgument. Seems somewhat related to the covariant + # _HasClauseElement as of mypy 1.15 @overload def sum( # noqa: A001 @@ -1232,7 +1233,7 @@ class _FunctionGenerator: @overload def sum( # noqa: A001 self, - col: _ColumnExpressionOrLiteralArgument[_T], + col: _T, *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, ) -> sum[_T]: ... @@ -1328,7 +1329,7 @@ class Function(FunctionElement[_T]): *clauses: _ColumnExpressionOrLiteralArgument[_T], type_: None = ..., packagenames: Optional[Tuple[str, ...]] = ..., - ): ... + ) -> None: ... @overload def __init__( @@ -1337,7 +1338,7 @@ class Function(FunctionElement[_T]): *clauses: _ColumnExpressionOrLiteralArgument[Any], type_: _TypeEngineArgument[_T] = ..., packagenames: Optional[Tuple[str, ...]] = ..., - ): ... + ) -> None: ... def __init__( self, @@ -1345,7 +1346,7 @@ class Function(FunctionElement[_T]): *clauses: _ColumnExpressionOrLiteralArgument[Any], type_: Optional[_TypeEngineArgument[_T]] = None, packagenames: Optional[Tuple[str, ...]] = None, - ): + ) -> None: """Construct a :class:`.Function`. The :data:`.func` construct is normally used to construct @@ -1521,7 +1522,7 @@ class GenericFunction(Function[_T]): def __init__( self, *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any - ): + ) -> None: parsed_args = kwargs.pop("_parsed_args", None) if parsed_args is None: parsed_args = [ @@ -1568,7 +1569,7 @@ class next_value(GenericFunction[int]): ("sequence", InternalTraversal.dp_named_ddl_element) ] - def __init__(self, seq: schema.Sequence, **kw: Any): + def __init__(self, seq: schema.Sequence, **kw: Any) -> None: assert isinstance( seq, schema.Sequence ), "next_value() accepts a Sequence object as input." @@ -1593,7 +1594,9 @@ class AnsiFunction(GenericFunction[_T]): inherit_cache = True - def __init__(self, *args: _ColumnExpressionArgument[Any], **kwargs: Any): + def __init__( + self, *args: _ColumnExpressionArgument[Any], **kwargs: Any + ) -> None: GenericFunction.__init__(self, *args, **kwargs) @@ -1604,10 +1607,10 @@ class ReturnTypeFromArgs(GenericFunction[_T]): inherit_cache = True - # set ColumnElement[_T] as a separate overload, to appease mypy which seems - # to not want to accept _T from _ColumnExpressionArgument. this is even if - # all non-generic types are removed from it, so reasons remain unclear for - # why this does not work + # set ColumnElement[_T] as a separate overload, to appease + # mypy which seems to not want to accept _T from + # _ColumnExpressionArgument. Seems somewhat related to the covariant + # _HasClauseElement as of mypy 1.15 @overload def __init__( @@ -1615,7 +1618,7 @@ class ReturnTypeFromArgs(GenericFunction[_T]): col: ColumnElement[_T], *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, - ): ... + ) -> None: ... @overload def __init__( @@ -1623,19 +1626,19 @@ class ReturnTypeFromArgs(GenericFunction[_T]): col: _ColumnExpressionArgument[_T], *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, - ): ... + ) -> None: ... @overload def __init__( self, - col: _ColumnExpressionOrLiteralArgument[_T], + col: _T, *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, - ): ... + ) -> None: ... def __init__( - self, *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any - ): + self, *args: _ColumnExpressionOrLiteralArgument[_T], **kwargs: Any + ) -> None: fn_args: Sequence[ColumnElement[Any]] = [ coercions.expect( roles.ExpressionElementRole, @@ -1717,7 +1720,7 @@ class char_length(GenericFunction[int]): type = sqltypes.Integer() inherit_cache = True - def __init__(self, arg: _ColumnExpressionArgument[str], **kw: Any): + def __init__(self, arg: _ColumnExpressionArgument[str], **kw: Any) -> None: # slight hack to limit to just one positional argument # not sure why this one function has this special treatment super().__init__(arg, **kw) @@ -1763,7 +1766,7 @@ class count(GenericFunction[int]): _ColumnExpressionArgument[Any], _StarOrOne, None ] = None, **kwargs: Any, - ): + ) -> None: if expression is None: expression = literal_column("*") super().__init__(expression, **kwargs) @@ -1852,7 +1855,9 @@ class array_agg(ReturnTypeFromArgs[Sequence[_T]]): inherit_cache = True - def __init__(self, *args: _ColumnExpressionArgument[Any], **kwargs: Any): + def __init__( + self, *args: _ColumnExpressionArgument[Any], **kwargs: Any + ) -> None: fn_args: Sequence[ColumnElement[Any]] = [ coercions.expect( roles.ExpressionElementRole, c, apply_propagate_attrs=self @@ -2079,5 +2084,7 @@ class aggregate_strings(GenericFunction[str]): _has_args = True inherit_cache = True - def __init__(self, clause: _ColumnExpressionArgument[Any], separator: str): + def __init__( + self, clause: _ColumnExpressionArgument[Any], separator: str + ) -> None: super().__init__(clause, separator) diff --git a/test/typing/plain_files/sql/functions_again.py b/test/typing/plain_files/sql/functions_again.py index c3acf0ed27..fc000277d0 100644 --- a/test/typing/plain_files/sql/functions_again.py +++ b/test/typing/plain_files/sql/functions_again.py @@ -1,4 +1,6 @@ +from sqlalchemy import column from sqlalchemy import func +from sqlalchemy import Integer from sqlalchemy import select from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import Mapped @@ -53,6 +55,10 @@ reveal_type(stmt1) # test #10818 # EXPECTED_TYPE: coalesce[str] reveal_type(func.coalesce(Foo.c, "a", "b")) +# EXPECTED_TYPE: coalesce[str] +reveal_type(func.coalesce("a", "b")) +# EXPECTED_TYPE: coalesce[int] +reveal_type(func.coalesce(column("x", Integer), 3)) stmt2 = select(Foo.a, func.coalesce(Foo.c, "a", "b")).group_by(Foo.a) diff --git a/tools/generate_sql_functions.py b/tools/generate_sql_functions.py index a88a7d7022..7b6c93de14 100644 --- a/tools/generate_sql_functions.py +++ b/tools/generate_sql_functions.py @@ -67,10 +67,10 @@ def process_functions(filename: str, cmd: code_writer_cmd) -> str: textwrap.indent( f""" -# set ColumnElement[_T] as a separate overload, to appease mypy -# which seems to not want to accept _T from _ColumnExpressionArgument. -# this is even if all non-generic types are removed from it, so -# reasons remain unclear for why this does not work +# set ColumnElement[_T] as a separate overload, to appease +# mypy which seems to not want to accept _T from +# _ColumnExpressionArgument. Seems somewhat related to the covariant +# _HasClauseElement as of mypy 1.15 @overload def {key}( {' # noqa: A001' if is_reserved_word else ''} @@ -90,17 +90,15 @@ def {key}( {' # noqa: A001' if is_reserved_word else ''} ) -> {fn_class.__name__}[_T]: ... - @overload def {key}( {' # noqa: A001' if is_reserved_word else ''} self, - col: _ColumnExpressionOrLiteralArgument[_T], + col: _T, *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, ) -> {fn_class.__name__}[_T]: ... - def {key}( {' # noqa: A001' if is_reserved_word else ''} self, col: _ColumnExpressionOrLiteralArgument[_T],