]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
fully type functions.py
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 24 Nov 2023 20:20:31 +0000 (15:20 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 25 Nov 2023 14:46:33 +0000 (09:46 -0500)
Completed pep-484 typing for the ``sqlalchemy.sql.functions`` module.
:func:`_sql.select` constructs made against ``func`` elements should now
have filled-in return types.

References: #6810
Change-Id: I5121583c9c5b6f7151f811348c7a281c446cf0b8
(cherry picked from commit 045732a738a10891b85be8e286eab3e5b756a445)

doc/build/changelog/unreleased_20/sql_func_typing.rst [new file with mode: 0644]
lib/sqlalchemy/sql/_elements_constructors.py
lib/sqlalchemy/sql/_typing.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/functions.py
lib/sqlalchemy/sql/util.py
test/typing/plain_files/sql/functions.py
tools/generate_sql_functions.py

diff --git a/doc/build/changelog/unreleased_20/sql_func_typing.rst b/doc/build/changelog/unreleased_20/sql_func_typing.rst
new file mode 100644 (file)
index 0000000..f4ea6f4
--- /dev/null
@@ -0,0 +1,7 @@
+ .. change::
+    :tags: bug, typing
+    :tickets: 6810
+
+    Completed pep-484 typing for the ``sqlalchemy.sql.functions`` module.
+    :func:`_sql.select` constructs made against ``func`` elements should now
+    have filled-in return types.
index 27197375d2d9cb53f644c608aafd28ad20b3382e..23e275ed5d742a2f7905fa808307ee0d5216b086 100644 (file)
@@ -10,7 +10,6 @@ from __future__ import annotations
 import typing
 from typing import Any
 from typing import Callable
-from typing import Iterable
 from typing import Mapping
 from typing import Optional
 from typing import overload
@@ -49,6 +48,7 @@ from .functions import FunctionElement
 from ..util.typing import Literal
 
 if typing.TYPE_CHECKING:
+    from ._typing import _ByArgument
     from ._typing import _ColumnExpressionArgument
     from ._typing import _ColumnExpressionOrLiteralArgument
     from ._typing import _ColumnExpressionOrStrLabelArgument
@@ -1483,18 +1483,8 @@ if not TYPE_CHECKING:
 
 def over(
     element: FunctionElement[_T],
-    partition_by: Optional[
-        Union[
-            Iterable[_ColumnExpressionArgument[Any]],
-            _ColumnExpressionArgument[Any],
-        ]
-    ] = None,
-    order_by: Optional[
-        Union[
-            Iterable[_ColumnExpressionArgument[Any]],
-            _ColumnExpressionArgument[Any],
-        ]
-    ] = None,
+    partition_by: Optional[_ByArgument] = None,
+    order_by: Optional[_ByArgument] = None,
     range_: Optional[typing_Tuple[Optional[int], Optional[int]]] = None,
     rows: Optional[typing_Tuple[Optional[int], Optional[int]]] = None,
 ) -> Over[_T]:
index c9e183058e6ab23973cedaf47ca28d6a7b5cf69d..0793fbb3db1bb5dd0cebf4240d4f2854b094a495 100644 (file)
@@ -11,6 +11,7 @@ import operator
 from typing import Any
 from typing import Callable
 from typing import Dict
+from typing import Iterable
 from typing import Mapping
 from typing import NoReturn
 from typing import Optional
@@ -198,6 +199,12 @@ _ColumnExpressionOrLiteralArgument = Union[Any, _ColumnExpressionArgument[_T]]
 
 _ColumnExpressionOrStrLabelArgument = Union[str, _ColumnExpressionArgument[_T]]
 
+_ByArgument = Union[
+    Iterable[_ColumnExpressionOrStrLabelArgument[Any]],
+    _ColumnExpressionOrStrLabelArgument[Any],
+]
+"""Used for keyword-based ``order_by`` and ``partition_by`` parameters."""
+
 
 _InfoType = Dict[Any, Any]
 """the .info dictionary accepted and used throughout Core /ORM"""
index 48dfd25829a30213c8e6b267491f1246693052ca..cafd291eee20ec3e25d483e79a19c6ac19e2c1c9 100644 (file)
@@ -80,6 +80,7 @@ from ..util.typing import Literal
 from ..util.typing import Self
 
 if typing.TYPE_CHECKING:
+    from ._typing import _ByArgument
     from ._typing import _ColumnExpressionArgument
     from ._typing import _ColumnExpressionOrStrLabelArgument
     from ._typing import _InfoType
@@ -4189,18 +4190,8 @@ class Over(ColumnElement[_T]):
     def __init__(
         self,
         element: ColumnElement[_T],
-        partition_by: Optional[
-            Union[
-                Iterable[_ColumnExpressionArgument[Any]],
-                _ColumnExpressionArgument[Any],
-            ]
-        ] = None,
-        order_by: Optional[
-            Union[
-                Iterable[_ColumnExpressionArgument[Any]],
-                _ColumnExpressionArgument[Any],
-            ]
-        ] = None,
+        partition_by: Optional[_ByArgument] = None,
+        order_by: Optional[_ByArgument] = None,
         range_: Optional[typing_Tuple[Optional[int], Optional[int]]] = None,
         rows: Optional[typing_Tuple[Optional[int], Optional[int]]] = None,
     ):
@@ -5202,12 +5193,12 @@ def _find_columns(clause: ClauseElement) -> Set[ColumnClause[Any]]:
     return cols
 
 
-def _type_from_args(args):
+def _type_from_args(args: Sequence[ColumnElement[_T]]) -> TypeEngine[_T]:
     for a in args:
         if not a.type._isnull:
             return a.type
     else:
-        return type_api.NULLTYPE
+        return type_api.NULLTYPE  # type: ignore
 
 
 def _corresponding_column_or_error(fromclause, column, require_embedded=False):
index fc23e9d2156d9d9d8868abc12072e43e9acba620..c5eb6b28115f2ba2d4744a312203568c253d6a58 100644 (file)
@@ -4,7 +4,7 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: allow-untyped-defs, allow-untyped-calls
+
 
 """SQL function API, factories, and built-in functions.
 
@@ -17,13 +17,16 @@ import decimal
 from typing import Any
 from typing import cast
 from typing import Dict
+from typing import List
 from typing import Mapping
 from typing import Optional
 from typing import overload
+from typing import Sequence
 from typing import Tuple
 from typing import Type
 from typing import TYPE_CHECKING
 from typing import TypeVar
+from typing import Union
 
 from . import annotation
 from . import coercions
@@ -59,23 +62,35 @@ from .sqltypes import TableValueType
 from .type_api import TypeEngine
 from .visitors import InternalTraversal
 from .. import util
+from ..util.typing import Self
 
 
 if TYPE_CHECKING:
+    from ._typing import _ByArgument
+    from ._typing import _ColumnExpressionArgument
+    from ._typing import _ColumnExpressionOrLiteralArgument
     from ._typing import _TypeEngineArgument
+    from .base import _EntityNamespace
+    from .elements import ClauseElement
+    from .elements import KeyedColumnElement
+    from .elements import TableValuedColumn
+    from .operators import OperatorType
     from ..engine.base import Connection
     from ..engine.cursor import CursorResult
     from ..engine.interfaces import _CoreMultiExecuteParams
     from ..engine.interfaces import CoreExecuteOptionsParameter
 
 _T = TypeVar("_T", bound=Any)
+_S = TypeVar("_S", bound=Any)
 
 _registry: util.defaultdict[
     str, Dict[str, Type[Function[Any]]]
 ] = util.defaultdict(dict)
 
 
-def register_function(identifier, fn, package="_default"):
+def register_function(
+    identifier: str, fn: Type[Function[Any]], package: str = "_default"
+) -> None:
     """Associate a callable with a particular func. name.
 
     This is normally called by GenericFunction, but is also
@@ -138,7 +153,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
 
     clause_expr: Grouping[Any]
 
-    def __init__(self, *clauses: Any):
+    def __init__(self, *clauses: _ColumnExpressionOrLiteralArgument[Any]):
         r"""Construct a :class:`.FunctionElement`.
 
         :param \*clauses: list of column expressions that form the arguments
@@ -154,7 +169,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
             :class:`.Function`
 
         """
-        args = [
+        args: Sequence[_ColumnExpressionArgument[Any]] = [
             coercions.expect(
                 roles.ExpressionElementRole,
                 c,
@@ -171,7 +186,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
     _non_anon_label = None
 
     @property
-    def _proxy_key(self):
+    def _proxy_key(self) -> Any:
         return super()._proxy_key or getattr(self, "name", None)
 
     def _execute_on_connection(
@@ -184,7 +199,9 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
             self, distilled_params, execution_options
         )
 
-    def scalar_table_valued(self, name, type_=None):
+    def scalar_table_valued(
+        self, name: str, type_: Optional[_TypeEngineArgument[_T]] = None
+    ) -> ScalarFunctionColumn[_T]:
         """Return a column expression that's against this
         :class:`_functions.FunctionElement` as a scalar
         table-valued expression.
@@ -217,7 +234,9 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
 
         return ScalarFunctionColumn(self, name, type_)
 
-    def table_valued(self, *expr, **kw):
+    def table_valued(
+        self, *expr: _ColumnExpressionArgument[Any], **kw: Any
+    ) -> TableValuedAlias:
         r"""Return a :class:`_sql.TableValuedAlias` representation of this
         :class:`_functions.FunctionElement` with table-valued expressions added.
 
@@ -303,7 +322,9 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
 
         return new_func.alias(name=name, joins_implicitly=joins_implicitly)
 
-    def column_valued(self, name=None, joins_implicitly=False):
+    def column_valued(
+        self, name: Optional[str] = None, joins_implicitly: bool = False
+    ) -> TableValuedColumn[_T]:
         """Return this :class:`_functions.FunctionElement` as a column expression that
         selects from itself as a FROM clause.
 
@@ -345,7 +366,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
         return self.alias(name=name, joins_implicitly=joins_implicitly).column
 
     @util.ro_non_memoized_property
-    def columns(self):
+    def columns(self) -> ColumnCollection[str, KeyedColumnElement[Any]]:  # type: ignore[override]  # noqa: E501
         r"""The set of columns exported by this :class:`.FunctionElement`.
 
         This is a placeholder collection that allows the function to be
@@ -371,7 +392,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
         return self.c
 
     @util.ro_memoized_property
-    def c(self):
+    def c(self) -> ColumnCollection[str, KeyedColumnElement[Any]]:  # type: ignore[override]  # noqa: E501
         """synonym for :attr:`.FunctionElement.columns`."""
 
         return ColumnCollection(
@@ -379,16 +400,21 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
         )
 
     @property
-    def _all_selected_columns(self):
+    def _all_selected_columns(self) -> Sequence[KeyedColumnElement[Any]]:
         if is_table_value_type(self.type):
-            cols = self.type._elements
+            # TODO: this might not be fully accurate
+            cols = cast(
+                "Sequence[KeyedColumnElement[Any]]", self.type._elements
+            )
         else:
             cols = [self.label(None)]
 
         return cols
 
     @property
-    def exported_columns(self):
+    def exported_columns(  # type: ignore[override]
+        self,
+    ) -> ColumnCollection[str, KeyedColumnElement[Any]]:
         return self.columns
 
     @HasMemoized.memoized_attribute
@@ -399,7 +425,14 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
         """
         return cast(ClauseList, self.clause_expr.element)
 
-    def over(self, partition_by=None, order_by=None, rows=None, range_=None):
+    def over(
+        self,
+        *,
+        partition_by: Optional[_ByArgument] = None,
+        order_by: Optional[_ByArgument] = None,
+        rows: Optional[Tuple[Optional[int], Optional[int]]] = None,
+        range_: Optional[Tuple[Optional[int], Optional[int]]] = None,
+    ) -> Over[_T]:
         """Produce an OVER clause against this function.
 
         Used against aggregate or so-called "window" functions,
@@ -431,7 +464,9 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
             range_=range_,
         )
 
-    def within_group(self, *order_by):
+    def within_group(
+        self, *order_by: _ColumnExpressionArgument[Any]
+    ) -> WithinGroup[_T]:
         """Produce a WITHIN GROUP (ORDER BY expr) clause against this function.
 
         Used against so-called "ordered set aggregate" and "hypothetical
@@ -449,7 +484,9 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
         """
         return WithinGroup(self, *order_by)
 
-    def filter(self, *criterion):
+    def filter(
+        self, *criterion: _ColumnExpressionArgument[bool]
+    ) -> Union[Self, FunctionFilter[_T]]:
         """Produce a FILTER clause against this function.
 
         Used against aggregate and window functions,
@@ -479,7 +516,9 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
             return self
         return FunctionFilter(self, *criterion)
 
-    def as_comparison(self, left_index, right_index):
+    def as_comparison(
+        self, left_index: int, right_index: int
+    ) -> FunctionAsBinary:
         """Interpret this expression as a boolean comparison between two
         values.
 
@@ -554,10 +593,12 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
         return FunctionAsBinary(self, left_index, right_index)
 
     @property
-    def _from_objects(self):
+    def _from_objects(self) -> Any:
         return self.clauses._from_objects
 
-    def within_group_type(self, within_group):
+    def within_group_type(
+        self, within_group: WithinGroup[_S]
+    ) -> Optional[TypeEngine[_S]]:
         """For types that define their return type as based on the criteria
         within a WITHIN GROUP (ORDER BY) expression, called by the
         :class:`.WithinGroup` construct.
@@ -569,7 +610,9 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
 
         return None
 
-    def alias(self, name=None, joins_implicitly=False):
+    def alias(
+        self, name: Optional[str] = None, joins_implicitly: bool = False
+    ) -> TableValuedAlias:
         r"""Produce a :class:`_expression.Alias` construct against this
         :class:`.FunctionElement`.
 
@@ -647,7 +690,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
             joins_implicitly=joins_implicitly,
         )
 
-    def select(self) -> Select[Any]:
+    def select(self) -> Select[Tuple[_T]]:
         """Produce a :func:`_expression.select` construct
         against this :class:`.FunctionElement`.
 
@@ -661,7 +704,14 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
             s = s.execution_options(**self._execution_options)
         return s
 
-    def _bind_param(self, operator, obj, type_=None, **kw):
+    def _bind_param(
+        self,
+        operator: OperatorType,
+        obj: Any,
+        type_: Optional[TypeEngine[_T]] = None,
+        expanding: bool = False,
+        **kw: Any,
+    ) -> BindParameter[_T]:
         return BindParameter(
             None,
             obj,
@@ -669,10 +719,11 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
             _compared_to_type=self.type,
             unique=True,
             type_=type_,
+            expanding=expanding,
             **kw,
         )
 
-    def self_group(self, against=None):
+    def self_group(self, against: Optional[OperatorType] = None) -> ClauseElement:  # type: ignore[override]  # noqa E501
         # for the moment, we are parenthesizing all array-returning
         # expressions against getitem.  This may need to be made
         # more portable if in the future we support other DBs
@@ -685,7 +736,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
             return super().self_group(against=against)
 
     @property
-    def entity_namespace(self):
+    def entity_namespace(self) -> _EntityNamespace:
         """overrides FromClause.entity_namespace as functions are generally
         column expressions and not FromClauses.
 
@@ -707,7 +758,7 @@ class FunctionAsBinary(BinaryExpression[Any]):
     left_index: int
     right_index: int
 
-    def _gen_cache_key(self, anon_map, bindparams):
+    def _gen_cache_key(self, anon_map: Any, bindparams: Any) -> Any:
         return ColumnElement._gen_cache_key(self, anon_map, bindparams)
 
     def __init__(
@@ -860,8 +911,8 @@ class _FunctionGenerator:
 
     """  # noqa
 
-    def __init__(self, **opts):
-        self.__names = []
+    def __init__(self, **opts: Any):
+        self.__names: List[str] = []
         self.opts = opts
 
     def __getattr__(self, name: str) -> _FunctionGenerator:
@@ -936,8 +987,33 @@ class _FunctionGenerator:
         def char_length(self) -> Type[char_length]:
             ...
 
-        @property
-        def coalesce(self) -> Type[coalesce[Any]]:
+        # appease mypy which seems to not want to accept _T from
+        # _ColumnExpressionArgument, as it includes non-generic types
+
+        @overload
+        def coalesce(
+            self,
+            col: ColumnElement[_T],
+            *args: _ColumnExpressionArgument[Any],
+            **kwargs: Any,
+        ) -> coalesce[_T]:
+            ...
+
+        @overload
+        def coalesce(
+            self,
+            col: _ColumnExpressionArgument[_T],
+            *args: _ColumnExpressionArgument[Any],
+            **kwargs: Any,
+        ) -> coalesce[_T]:
+            ...
+
+        def coalesce(
+            self,
+            col: _ColumnExpressionArgument[_T],
+            *args: _ColumnExpressionArgument[Any],
+            **kwargs: Any,
+        ) -> coalesce[_T]:
             ...
 
         @property
@@ -992,12 +1068,62 @@ class _FunctionGenerator:
         def localtimestamp(self) -> Type[localtimestamp]:
             ...
 
-        @property
-        def max(self) -> Type[max[Any]]:  # noqa: A001
+        # appease mypy which seems to not want to accept _T from
+        # _ColumnExpressionArgument, as it includes non-generic types
+
+        @overload
+        def max(  # noqa: A001
+            self,
+            col: ColumnElement[_T],
+            *args: _ColumnExpressionArgument[Any],
+            **kwargs: Any,
+        ) -> max[_T]:
             ...
 
-        @property
-        def min(self) -> Type[min[Any]]:  # noqa: A001
+        @overload
+        def max(  # noqa: A001
+            self,
+            col: _ColumnExpressionArgument[_T],
+            *args: _ColumnExpressionArgument[Any],
+            **kwargs: Any,
+        ) -> max[_T]:
+            ...
+
+        def max(  # noqa: A001
+            self,
+            col: _ColumnExpressionArgument[_T],
+            *args: _ColumnExpressionArgument[Any],
+            **kwargs: Any,
+        ) -> max[_T]:
+            ...
+
+        # appease mypy which seems to not want to accept _T from
+        # _ColumnExpressionArgument, as it includes non-generic types
+
+        @overload
+        def min(  # noqa: A001
+            self,
+            col: ColumnElement[_T],
+            *args: _ColumnExpressionArgument[Any],
+            **kwargs: Any,
+        ) -> min[_T]:
+            ...
+
+        @overload
+        def min(  # noqa: A001
+            self,
+            col: _ColumnExpressionArgument[_T],
+            *args: _ColumnExpressionArgument[Any],
+            **kwargs: Any,
+        ) -> min[_T]:
+            ...
+
+        def min(  # noqa: A001
+            self,
+            col: _ColumnExpressionArgument[_T],
+            *args: _ColumnExpressionArgument[Any],
+            **kwargs: Any,
+        ) -> min[_T]:
             ...
 
         @property
@@ -1036,10 +1162,6 @@ class _FunctionGenerator:
         def rank(self) -> Type[rank]:
             ...
 
-        @property
-        def returntypefromargs(self) -> Type[ReturnTypeFromArgs[Any]]:
-            ...
-
         @property
         def rollup(self) -> Type[rollup[Any]]:
             ...
@@ -1048,8 +1170,33 @@ class _FunctionGenerator:
         def session_user(self) -> Type[session_user]:
             ...
 
-        @property
-        def sum(self) -> Type[sum[Any]]:  # noqa: A001
+        # appease mypy which seems to not want to accept _T from
+        # _ColumnExpressionArgument, as it includes non-generic types
+
+        @overload
+        def sum(  # noqa: A001
+            self,
+            col: ColumnElement[_T],
+            *args: _ColumnExpressionArgument[Any],
+            **kwargs: Any,
+        ) -> sum[_T]:
+            ...
+
+        @overload
+        def sum(  # noqa: A001
+            self,
+            col: _ColumnExpressionArgument[_T],
+            *args: _ColumnExpressionArgument[Any],
+            **kwargs: Any,
+        ) -> sum[_T]:
+            ...
+
+        def sum(  # noqa: A001
+            self,
+            col: _ColumnExpressionArgument[_T],
+            *args: _ColumnExpressionArgument[Any],
+            **kwargs: Any,
+        ) -> sum[_T]:
             ...
 
         @property
@@ -1131,10 +1278,30 @@ class Function(FunctionElement[_T]):
 
     """
 
+    @overload
+    def __init__(
+        self,
+        name: str,
+        *clauses: _ColumnExpressionOrLiteralArgument[_T],
+        type_: None = ...,
+        packagenames: Optional[Tuple[str, ...]] = ...,
+    ):
+        ...
+
+    @overload
+    def __init__(
+        self,
+        name: str,
+        *clauses: _ColumnExpressionOrLiteralArgument[Any],
+        type_: _TypeEngineArgument[_T] = ...,
+        packagenames: Optional[Tuple[str, ...]] = ...,
+    ):
+        ...
+
     def __init__(
         self,
         name: str,
-        *clauses: Any,
+        *clauses: _ColumnExpressionOrLiteralArgument[Any],
         type_: Optional[_TypeEngineArgument[_T]] = None,
         packagenames: Optional[Tuple[str, ...]] = None,
     ):
@@ -1153,7 +1320,14 @@ class Function(FunctionElement[_T]):
 
         FunctionElement.__init__(self, *clauses)
 
-    def _bind_param(self, operator, obj, type_=None, **kw):
+    def _bind_param(
+        self,
+        operator: OperatorType,
+        obj: Any,
+        type_: Optional[TypeEngine[_T]] = None,
+        expanding: bool = False,
+        **kw: Any,
+    ) -> BindParameter[_T]:
         return BindParameter(
             self.name,
             obj,
@@ -1161,6 +1335,7 @@ class Function(FunctionElement[_T]):
             _compared_to_type=self.type,
             type_=type_,
             unique=True,
+            expanding=expanding,
             **kw,
         )
 
@@ -1306,7 +1481,9 @@ class GenericFunction(Function[_T]):
             # Set _register to True to register child classes by default
             cls._register = True
 
-    def __init__(self, *args, **kwargs):
+    def __init__(
+        self, *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any
+    ):
         parsed_args = kwargs.pop("_parsed_args", None)
         if parsed_args is None:
             parsed_args = [
@@ -1332,8 +1509,8 @@ class GenericFunction(Function[_T]):
         )
 
 
-register_function("cast", Cast)
-register_function("extract", Extract)
+register_function("cast", Cast)  # type: ignore
+register_function("extract", Extract)  # type: ignore
 
 
 class next_value(GenericFunction[int]):
@@ -1353,7 +1530,7 @@ class next_value(GenericFunction[int]):
         ("sequence", InternalTraversal.dp_named_ddl_element)
     ]
 
-    def __init__(self, seq, **kw):
+    def __init__(self, seq: schema.Sequence, **kw: Any):
         assert isinstance(
             seq, schema.Sequence
         ), "next_value() accepts a Sequence object as input."
@@ -1362,14 +1539,14 @@ class next_value(GenericFunction[int]):
             seq.data_type or getattr(self, "type", None)
         )
 
-    def compare(self, other, **kw):
+    def compare(self, other: Any, **kw: Any) -> bool:
         return (
             isinstance(other, next_value)
             and self.sequence.name == other.sequence.name
         )
 
     @property
-    def _from_objects(self):
+    def _from_objects(self) -> Any:
         return []
 
 
@@ -1378,7 +1555,7 @@ class AnsiFunction(GenericFunction[_T]):
 
     inherit_cache = True
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args: _ColumnExpressionArgument[Any], **kwargs: Any):
         GenericFunction.__init__(self, *args, **kwargs)
 
 
@@ -1387,8 +1564,29 @@ class ReturnTypeFromArgs(GenericFunction[_T]):
 
     inherit_cache = True
 
-    def __init__(self, *args, **kwargs):
-        fn_args = [
+    # appease mypy which seems to not want to accept _T from
+    # _ColumnExpressionArgument, as it includes non-generic types
+
+    @overload
+    def __init__(
+        self,
+        col: ColumnElement[_T],
+        *args: _ColumnExpressionArgument[Any],
+        **kwargs: Any,
+    ):
+        ...
+
+    @overload
+    def __init__(
+        self,
+        col: _ColumnExpressionArgument[_T],
+        *args: _ColumnExpressionArgument[Any],
+        **kwargs: Any,
+    ):
+        ...
+
+    def __init__(self, *args: _ColumnExpressionArgument[Any], **kwargs: Any):
+        fn_args: Sequence[ColumnElement[Any]] = [
             coercions.expect(
                 roles.ExpressionElementRole,
                 c,
@@ -1469,7 +1667,7 @@ class char_length(GenericFunction[int]):
     type = sqltypes.Integer()
     inherit_cache = True
 
-    def __init__(self, arg, **kw):
+    def __init__(self, arg: _ColumnExpressionArgument[str], **kw: Any):
         # slight hack to limit to just one positional argument
         # not sure why this one function has this special treatment
         super().__init__(arg, **kw)
@@ -1506,7 +1704,11 @@ class count(GenericFunction[int]):
     type = sqltypes.Integer()
     inherit_cache = True
 
-    def __init__(self, expression=None, **kwargs):
+    def __init__(
+        self,
+        expression: Optional[_ColumnExpressionArgument[Any]] = None,
+        **kwargs: Any,
+    ):
         if expression is None:
             expression = literal_column("*")
         super().__init__(expression, **kwargs)
@@ -1595,8 +1797,8 @@ class array_agg(GenericFunction[_T]):
 
     inherit_cache = True
 
-    def __init__(self, *args, **kwargs):
-        fn_args = [
+    def __init__(self, *args: _ColumnExpressionArgument[Any], **kwargs: Any):
+        fn_args: Sequence[ColumnElement[Any]] = [
             coercions.expect(
                 roles.ExpressionElementRole, c, apply_propagate_attrs=self
             )
@@ -1624,9 +1826,13 @@ class OrderedSetAgg(GenericFunction[_T]):
     array_for_multi_clause = False
     inherit_cache = True
 
-    def within_group_type(self, within_group):
+    def within_group_type(
+        self, within_group: WithinGroup[Any]
+    ) -> TypeEngine[Any]:
         func_clauses = cast(ClauseList, self.clause_expr.element)
-        order_by = sqlutil.unwrap_order_by(within_group.order_by)
+        order_by: Sequence[ColumnElement[Any]] = sqlutil.unwrap_order_by(
+            within_group.order_by
+        )
         if self.array_for_multi_clause and len(func_clauses.clauses) > 1:
             return sqltypes.ARRAY(order_by[0].type)
         else:
@@ -1824,5 +2030,5 @@ class aggregate_strings(GenericFunction[str]):
     _has_args = True
     inherit_cache = True
 
-    def __init__(self, clause, separator):
+    def __init__(self, clause: _ColumnExpressionArgument[Any], separator: str):
         super().__init__(clause, separator)
index 28480a5d437f0a89cb8c3c524ba03d3721760f7e..19551831fe337d75fcf83c74244aa95ab55e5261 100644 (file)
@@ -367,7 +367,7 @@ def find_tables(
     return tables
 
 
-def unwrap_order_by(clause):
+def unwrap_order_by(clause: Any) -> Any:
     """Break up an 'order by' expression into individual column-expressions,
     without DESC/ASC/NULLS FIRST/NULLS LAST"""
 
index e66e554cff78cb2a3833ffddbdc59fa765983009..6a345fcf6ec6463e753e7ab0e347db0661422876 100644 (file)
@@ -2,14 +2,17 @@
 
 from sqlalchemy import column
 from sqlalchemy import func
+from sqlalchemy import Integer
 from sqlalchemy import select
+from sqlalchemy import Sequence
+from sqlalchemy import String
 
 # START GENERATED FUNCTION TYPING TESTS
 
 # code within this block is **programmatically,
 # statically generated** by tools/generate_sql_functions.py
 
-stmt1 = select(func.aggregate_strings(column("x"), column("x")))
+stmt1 = select(func.aggregate_strings(column("x", String), ","))
 
 # EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
 reveal_type(stmt1)
@@ -21,105 +24,129 @@ stmt2 = select(func.char_length(column("x")))
 reveal_type(stmt2)
 
 
-stmt3 = select(func.concat())
+stmt3 = select(func.coalesce(column("x", Integer)))
 
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
 reveal_type(stmt3)
 
 
-stmt4 = select(func.count(column("x")))
+stmt4 = select(func.concat())
 
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
 reveal_type(stmt4)
 
 
-stmt5 = select(func.cume_dist())
+stmt5 = select(func.count(column("x")))
 
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*Decimal\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
 reveal_type(stmt5)
 
 
-stmt6 = select(func.current_date())
+stmt6 = select(func.cume_dist())
 
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*date\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*Decimal\]\]
 reveal_type(stmt6)
 
 
-stmt7 = select(func.current_time())
+stmt7 = select(func.current_date())
 
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*time\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*date\]\]
 reveal_type(stmt7)
 
 
-stmt8 = select(func.current_timestamp())
+stmt8 = select(func.current_time())
 
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*time\]\]
 reveal_type(stmt8)
 
 
-stmt9 = select(func.current_user())
+stmt9 = select(func.current_timestamp())
 
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
 reveal_type(stmt9)
 
 
-stmt10 = select(func.dense_rank())
+stmt10 = select(func.current_user())
 
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
 reveal_type(stmt10)
 
 
-stmt11 = select(func.localtime())
+stmt11 = select(func.dense_rank())
 
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
 reveal_type(stmt11)
 
 
-stmt12 = select(func.localtimestamp())
+stmt12 = select(func.localtime())
 
 # EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
 reveal_type(stmt12)
 
 
-stmt13 = select(func.next_value(column("x")))
+stmt13 = select(func.localtimestamp())
 
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
 reveal_type(stmt13)
 
 
-stmt14 = select(func.now())
+stmt14 = select(func.max(column("x", Integer)))
 
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
 reveal_type(stmt14)
 
 
-stmt15 = select(func.percent_rank())
+stmt15 = select(func.min(column("x", Integer)))
 
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*Decimal\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
 reveal_type(stmt15)
 
 
-stmt16 = select(func.rank())
+stmt16 = select(func.next_value(Sequence("x_seq")))
 
 # EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
 reveal_type(stmt16)
 
 
-stmt17 = select(func.session_user())
+stmt17 = select(func.now())
 
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
 reveal_type(stmt17)
 
 
-stmt18 = select(func.sysdate())
+stmt18 = select(func.percent_rank())
 
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*Decimal\]\]
 reveal_type(stmt18)
 
 
-stmt19 = select(func.user())
+stmt19 = select(func.rank())
 
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
 reveal_type(stmt19)
 
+
+stmt20 = select(func.session_user())
+
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
+reveal_type(stmt20)
+
+
+stmt21 = select(func.sum(column("x", Integer)))
+
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
+reveal_type(stmt21)
+
+
+stmt22 = select(func.sysdate())
+
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
+reveal_type(stmt22)
+
+
+stmt23 = select(func.user())
+
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
+reveal_type(stmt23)
+
 # END GENERATED FUNCTION TYPING TESTS
index 848a927225056d0b728410a70a33a0d2e874f091..348b3344845e18d820efc6cf6f2274e88c3d6974 100644 (file)
@@ -11,6 +11,7 @@ from tempfile import NamedTemporaryFile
 import textwrap
 
 from sqlalchemy.sql.functions import _registry
+from sqlalchemy.sql.functions import ReturnTypeFromArgs
 from sqlalchemy.types import TypeEngine
 from sqlalchemy.util.tool_support import code_writer_cmd
 
@@ -18,7 +19,10 @@ from sqlalchemy.util.tool_support import code_writer_cmd
 def _fns_in_deterministic_order():
     reg = _registry["_default"]
     for key in sorted(reg):
-        yield key, reg[key]
+        cls = reg[key]
+        if cls is ReturnTypeFromArgs:
+            continue
+        yield key, cls
 
 
 def process_functions(filename: str, cmd: code_writer_cmd) -> str:
@@ -53,23 +57,75 @@ def process_functions(filename: str, cmd: code_writer_cmd) -> str:
                 for key, fn_class in _fns_in_deterministic_order():
                     is_reserved_word = key in builtins
 
-                    guess_its_generic = bool(fn_class.__parameters__)
+                    if issubclass(fn_class, ReturnTypeFromArgs):
+                        buf.write(
+                            textwrap.indent(
+                                f"""
+
+# appease mypy which seems to not want to accept _T from
+# _ColumnExpressionArgument, as it includes non-generic types
+
+@overload
+def {key}( {'  # noqa: A001' if is_reserved_word else ''}
+    self,
+    col: ColumnElement[_T],
+    *args: _ColumnExpressionArgument[Any],
+    **kwargs: Any,
+) -> {fn_class.__name__}[_T]:
+    ...
 
-                    buf.write(
-                        textwrap.indent(
-                            f"""
+@overload
+def {key}( {'  # noqa: A001' if is_reserved_word else ''}
+    self,
+    col: _ColumnExpressionArgument[_T],
+    *args: _ColumnExpressionArgument[Any],
+    **kwargs: Any,
+) -> {fn_class.__name__}[_T]:
+        ...
+
+def {key}( {'  # noqa: A001' if is_reserved_word else ''}
+    self,
+    col: _ColumnExpressionArgument[_T],
+    *args: _ColumnExpressionArgument[Any],
+    **kwargs: Any,
+) -> {fn_class.__name__}[_T]:
+    ...
+
+    """,
+                                indent,
+                            )
+                        )
+                    else:
+                        guess_its_generic = bool(fn_class.__parameters__)
+
+                        # the latest flake8 is quite broken here:
+                        # 1. it insists on linting f-strings, no option
+                        #    to turn it off
+                        # 2. the f-string indentation rules are either broken
+                        #    or completely impossible to figure out
+                        # 3. there's no way to E501 a too-long f-string,
+                        #    so I can't even put the expressions all one line
+                        #    to get around the indentation errors
+                        # 4. Therefore here I have to concat part of the
+                        #    string outside of the f-string
+                        _type = fn_class.__name__
+                        _type += "[Any]" if guess_its_generic else ""
+                        _reserved_word = (
+                            "  # noqa: A001" if is_reserved_word else ""
+                        )
+
+                        # now the f-string
+                        buf.write(
+                            textwrap.indent(
+                                f"""
 @property
-def {key}(self) -> Type[{fn_class.__name__}{
-    '[Any]' if guess_its_generic else ''
-}]:{
-     '  # noqa: A001' if is_reserved_word else ''
-}
+def {key}(self) -> Type[{_type}]:{_reserved_word}
     ...
 
 """,
-                            indent,
+                                indent,
+                            )
                         )
-                    )
 
             m = re.match(
                 r"^( *)# START GENERATED FUNCTION TYPING TESTS",
@@ -92,15 +148,48 @@ def {key}(self) -> Type[{fn_class.__name__}{
 
                 count = 0
                 for key, fn_class in _fns_in_deterministic_order():
-                    if hasattr(fn_class, "type") and isinstance(
+                    if issubclass(fn_class, ReturnTypeFromArgs):
+                        count += 1
+
+                        buf.write(
+                            textwrap.indent(
+                                rf"""
+stmt{count} = select(func.{key}(column('x', Integer)))
+
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
+reveal_type(stmt{count})
+
+""",
+                                indent,
+                            )
+                        )
+                    elif fn_class.__name__ == "aggregate_strings":
+                        count += 1
+                        buf.write(
+                            textwrap.indent(
+                                rf"""
+stmt{count} = select(func.{key}(column('x', String), ','))
+
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
+reveal_type(stmt{count})
+
+""",
+                                indent,
+                            )
+                        )
+
+                    elif hasattr(fn_class, "type") and isinstance(
                         fn_class.type, TypeEngine
                     ):
                         python_type = fn_class.type.python_type
                         python_expr = rf"Tuple\[.*{python_type.__name__}\]"
                         argspec = inspect.getfullargspec(fn_class)
-                        args = ", ".join(
-                            'column("x")' for elem in argspec.args[1:]
-                        )
+                        if fn_class.__name__ == "next_value":
+                            args = "Sequence('x_seq')"
+                        else:
+                            args = ", ".join(
+                                'column("x")' for elem in argspec.args[1:]
+                            )
                         count += 1
 
                         buf.write(