From: Mike Bayer Date: Fri, 24 Nov 2023 20:20:31 +0000 (-0500) Subject: fully type functions.py X-Git-Tag: rel_2_0_24~26 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=972220878c0177531ad6f584fde2717f8e0a4315;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git fully type functions.py Completed pep-484 typing for the ``sqlalchemy.sql.functions`` module. :func:`_sql.select` constructs made against ``func`` elements should now have filled-in return types. References: #6810 Change-Id: I5121583c9c5b6f7151f811348c7a281c446cf0b8 (cherry picked from commit 045732a738a10891b85be8e286eab3e5b756a445) --- diff --git a/doc/build/changelog/unreleased_20/sql_func_typing.rst b/doc/build/changelog/unreleased_20/sql_func_typing.rst new file mode 100644 index 0000000000..f4ea6f40c3 --- /dev/null +++ b/doc/build/changelog/unreleased_20/sql_func_typing.rst @@ -0,0 +1,7 @@ + .. change:: + :tags: bug, typing + :tickets: 6810 + + Completed pep-484 typing for the ``sqlalchemy.sql.functions`` module. + :func:`_sql.select` constructs made against ``func`` elements should now + have filled-in return types. diff --git a/lib/sqlalchemy/sql/_elements_constructors.py b/lib/sqlalchemy/sql/_elements_constructors.py index 27197375d2..23e275ed5d 100644 --- a/lib/sqlalchemy/sql/_elements_constructors.py +++ b/lib/sqlalchemy/sql/_elements_constructors.py @@ -10,7 +10,6 @@ from __future__ import annotations import typing from typing import Any from typing import Callable -from typing import Iterable from typing import Mapping from typing import Optional from typing import overload @@ -49,6 +48,7 @@ from .functions import FunctionElement from ..util.typing import Literal if typing.TYPE_CHECKING: + from ._typing import _ByArgument from ._typing import _ColumnExpressionArgument from ._typing import _ColumnExpressionOrLiteralArgument from ._typing import _ColumnExpressionOrStrLabelArgument @@ -1483,18 +1483,8 @@ if not TYPE_CHECKING: def over( element: FunctionElement[_T], - partition_by: Optional[ - Union[ - Iterable[_ColumnExpressionArgument[Any]], - _ColumnExpressionArgument[Any], - ] - ] = None, - order_by: Optional[ - Union[ - Iterable[_ColumnExpressionArgument[Any]], - _ColumnExpressionArgument[Any], - ] - ] = None, + partition_by: Optional[_ByArgument] = None, + order_by: Optional[_ByArgument] = None, range_: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, rows: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, ) -> Over[_T]: diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index c9e183058e..0793fbb3db 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -11,6 +11,7 @@ import operator from typing import Any from typing import Callable from typing import Dict +from typing import Iterable from typing import Mapping from typing import NoReturn from typing import Optional @@ -198,6 +199,12 @@ _ColumnExpressionOrLiteralArgument = Union[Any, _ColumnExpressionArgument[_T]] _ColumnExpressionOrStrLabelArgument = Union[str, _ColumnExpressionArgument[_T]] +_ByArgument = Union[ + Iterable[_ColumnExpressionOrStrLabelArgument[Any]], + _ColumnExpressionOrStrLabelArgument[Any], +] +"""Used for keyword-based ``order_by`` and ``partition_by`` parameters.""" + _InfoType = Dict[Any, Any] """the .info dictionary accepted and used throughout Core /ORM""" diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 48dfd25829..cafd291eee 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -80,6 +80,7 @@ from ..util.typing import Literal from ..util.typing import Self if typing.TYPE_CHECKING: + from ._typing import _ByArgument from ._typing import _ColumnExpressionArgument from ._typing import _ColumnExpressionOrStrLabelArgument from ._typing import _InfoType @@ -4189,18 +4190,8 @@ class Over(ColumnElement[_T]): def __init__( self, element: ColumnElement[_T], - partition_by: Optional[ - Union[ - Iterable[_ColumnExpressionArgument[Any]], - _ColumnExpressionArgument[Any], - ] - ] = None, - order_by: Optional[ - Union[ - Iterable[_ColumnExpressionArgument[Any]], - _ColumnExpressionArgument[Any], - ] - ] = None, + partition_by: Optional[_ByArgument] = None, + order_by: Optional[_ByArgument] = None, range_: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, rows: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, ): @@ -5202,12 +5193,12 @@ def _find_columns(clause: ClauseElement) -> Set[ColumnClause[Any]]: return cols -def _type_from_args(args): +def _type_from_args(args: Sequence[ColumnElement[_T]]) -> TypeEngine[_T]: for a in args: if not a.type._isnull: return a.type else: - return type_api.NULLTYPE + return type_api.NULLTYPE # type: ignore def _corresponding_column_or_error(fromclause, column, require_embedded=False): diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index fc23e9d215..c5eb6b2811 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -4,7 +4,7 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: allow-untyped-defs, allow-untyped-calls + """SQL function API, factories, and built-in functions. @@ -17,13 +17,16 @@ import decimal from typing import Any from typing import cast from typing import Dict +from typing import List from typing import Mapping from typing import Optional from typing import overload +from typing import Sequence from typing import Tuple from typing import Type from typing import TYPE_CHECKING from typing import TypeVar +from typing import Union from . import annotation from . import coercions @@ -59,23 +62,35 @@ from .sqltypes import TableValueType from .type_api import TypeEngine from .visitors import InternalTraversal from .. import util +from ..util.typing import Self if TYPE_CHECKING: + from ._typing import _ByArgument + from ._typing import _ColumnExpressionArgument + from ._typing import _ColumnExpressionOrLiteralArgument from ._typing import _TypeEngineArgument + from .base import _EntityNamespace + from .elements import ClauseElement + from .elements import KeyedColumnElement + from .elements import TableValuedColumn + from .operators import OperatorType from ..engine.base import Connection from ..engine.cursor import CursorResult from ..engine.interfaces import _CoreMultiExecuteParams from ..engine.interfaces import CoreExecuteOptionsParameter _T = TypeVar("_T", bound=Any) +_S = TypeVar("_S", bound=Any) _registry: util.defaultdict[ str, Dict[str, Type[Function[Any]]] ] = util.defaultdict(dict) -def register_function(identifier, fn, package="_default"): +def register_function( + identifier: str, fn: Type[Function[Any]], package: str = "_default" +) -> None: """Associate a callable with a particular func. name. This is normally called by GenericFunction, but is also @@ -138,7 +153,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): clause_expr: Grouping[Any] - def __init__(self, *clauses: Any): + def __init__(self, *clauses: _ColumnExpressionOrLiteralArgument[Any]): r"""Construct a :class:`.FunctionElement`. :param \*clauses: list of column expressions that form the arguments @@ -154,7 +169,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): :class:`.Function` """ - args = [ + args: Sequence[_ColumnExpressionArgument[Any]] = [ coercions.expect( roles.ExpressionElementRole, c, @@ -171,7 +186,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): _non_anon_label = None @property - def _proxy_key(self): + def _proxy_key(self) -> Any: return super()._proxy_key or getattr(self, "name", None) def _execute_on_connection( @@ -184,7 +199,9 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): self, distilled_params, execution_options ) - def scalar_table_valued(self, name, type_=None): + def scalar_table_valued( + self, name: str, type_: Optional[_TypeEngineArgument[_T]] = None + ) -> ScalarFunctionColumn[_T]: """Return a column expression that's against this :class:`_functions.FunctionElement` as a scalar table-valued expression. @@ -217,7 +234,9 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): return ScalarFunctionColumn(self, name, type_) - def table_valued(self, *expr, **kw): + def table_valued( + self, *expr: _ColumnExpressionArgument[Any], **kw: Any + ) -> TableValuedAlias: r"""Return a :class:`_sql.TableValuedAlias` representation of this :class:`_functions.FunctionElement` with table-valued expressions added. @@ -303,7 +322,9 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): return new_func.alias(name=name, joins_implicitly=joins_implicitly) - def column_valued(self, name=None, joins_implicitly=False): + def column_valued( + self, name: Optional[str] = None, joins_implicitly: bool = False + ) -> TableValuedColumn[_T]: """Return this :class:`_functions.FunctionElement` as a column expression that selects from itself as a FROM clause. @@ -345,7 +366,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): return self.alias(name=name, joins_implicitly=joins_implicitly).column @util.ro_non_memoized_property - def columns(self): + def columns(self) -> ColumnCollection[str, KeyedColumnElement[Any]]: # type: ignore[override] # noqa: E501 r"""The set of columns exported by this :class:`.FunctionElement`. This is a placeholder collection that allows the function to be @@ -371,7 +392,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): return self.c @util.ro_memoized_property - def c(self): + def c(self) -> ColumnCollection[str, KeyedColumnElement[Any]]: # type: ignore[override] # noqa: E501 """synonym for :attr:`.FunctionElement.columns`.""" return ColumnCollection( @@ -379,16 +400,21 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): ) @property - def _all_selected_columns(self): + def _all_selected_columns(self) -> Sequence[KeyedColumnElement[Any]]: if is_table_value_type(self.type): - cols = self.type._elements + # TODO: this might not be fully accurate + cols = cast( + "Sequence[KeyedColumnElement[Any]]", self.type._elements + ) else: cols = [self.label(None)] return cols @property - def exported_columns(self): + def exported_columns( # type: ignore[override] + self, + ) -> ColumnCollection[str, KeyedColumnElement[Any]]: return self.columns @HasMemoized.memoized_attribute @@ -399,7 +425,14 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): """ return cast(ClauseList, self.clause_expr.element) - def over(self, partition_by=None, order_by=None, rows=None, range_=None): + def over( + self, + *, + partition_by: Optional[_ByArgument] = None, + order_by: Optional[_ByArgument] = None, + rows: Optional[Tuple[Optional[int], Optional[int]]] = None, + range_: Optional[Tuple[Optional[int], Optional[int]]] = None, + ) -> Over[_T]: """Produce an OVER clause against this function. Used against aggregate or so-called "window" functions, @@ -431,7 +464,9 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): range_=range_, ) - def within_group(self, *order_by): + def within_group( + self, *order_by: _ColumnExpressionArgument[Any] + ) -> WithinGroup[_T]: """Produce a WITHIN GROUP (ORDER BY expr) clause against this function. Used against so-called "ordered set aggregate" and "hypothetical @@ -449,7 +484,9 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): """ return WithinGroup(self, *order_by) - def filter(self, *criterion): + def filter( + self, *criterion: _ColumnExpressionArgument[bool] + ) -> Union[Self, FunctionFilter[_T]]: """Produce a FILTER clause against this function. Used against aggregate and window functions, @@ -479,7 +516,9 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): return self return FunctionFilter(self, *criterion) - def as_comparison(self, left_index, right_index): + def as_comparison( + self, left_index: int, right_index: int + ) -> FunctionAsBinary: """Interpret this expression as a boolean comparison between two values. @@ -554,10 +593,12 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): return FunctionAsBinary(self, left_index, right_index) @property - def _from_objects(self): + def _from_objects(self) -> Any: return self.clauses._from_objects - def within_group_type(self, within_group): + def within_group_type( + self, within_group: WithinGroup[_S] + ) -> Optional[TypeEngine[_S]]: """For types that define their return type as based on the criteria within a WITHIN GROUP (ORDER BY) expression, called by the :class:`.WithinGroup` construct. @@ -569,7 +610,9 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): return None - def alias(self, name=None, joins_implicitly=False): + def alias( + self, name: Optional[str] = None, joins_implicitly: bool = False + ) -> TableValuedAlias: r"""Produce a :class:`_expression.Alias` construct against this :class:`.FunctionElement`. @@ -647,7 +690,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): joins_implicitly=joins_implicitly, ) - def select(self) -> Select[Any]: + def select(self) -> Select[Tuple[_T]]: """Produce a :func:`_expression.select` construct against this :class:`.FunctionElement`. @@ -661,7 +704,14 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): s = s.execution_options(**self._execution_options) return s - def _bind_param(self, operator, obj, type_=None, **kw): + def _bind_param( + self, + operator: OperatorType, + obj: Any, + type_: Optional[TypeEngine[_T]] = None, + expanding: bool = False, + **kw: Any, + ) -> BindParameter[_T]: return BindParameter( None, obj, @@ -669,10 +719,11 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): _compared_to_type=self.type, unique=True, type_=type_, + expanding=expanding, **kw, ) - def self_group(self, against=None): + def self_group(self, against: Optional[OperatorType] = None) -> ClauseElement: # type: ignore[override] # noqa E501 # for the moment, we are parenthesizing all array-returning # expressions against getitem. This may need to be made # more portable if in the future we support other DBs @@ -685,7 +736,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): return super().self_group(against=against) @property - def entity_namespace(self): + def entity_namespace(self) -> _EntityNamespace: """overrides FromClause.entity_namespace as functions are generally column expressions and not FromClauses. @@ -707,7 +758,7 @@ class FunctionAsBinary(BinaryExpression[Any]): left_index: int right_index: int - def _gen_cache_key(self, anon_map, bindparams): + def _gen_cache_key(self, anon_map: Any, bindparams: Any) -> Any: return ColumnElement._gen_cache_key(self, anon_map, bindparams) def __init__( @@ -860,8 +911,8 @@ class _FunctionGenerator: """ # noqa - def __init__(self, **opts): - self.__names = [] + def __init__(self, **opts: Any): + self.__names: List[str] = [] self.opts = opts def __getattr__(self, name: str) -> _FunctionGenerator: @@ -936,8 +987,33 @@ class _FunctionGenerator: def char_length(self) -> Type[char_length]: ... - @property - def coalesce(self) -> Type[coalesce[Any]]: + # appease mypy which seems to not want to accept _T from + # _ColumnExpressionArgument, as it includes non-generic types + + @overload + def coalesce( + self, + col: ColumnElement[_T], + *args: _ColumnExpressionArgument[Any], + **kwargs: Any, + ) -> coalesce[_T]: + ... + + @overload + def coalesce( + self, + col: _ColumnExpressionArgument[_T], + *args: _ColumnExpressionArgument[Any], + **kwargs: Any, + ) -> coalesce[_T]: + ... + + def coalesce( + self, + col: _ColumnExpressionArgument[_T], + *args: _ColumnExpressionArgument[Any], + **kwargs: Any, + ) -> coalesce[_T]: ... @property @@ -992,12 +1068,62 @@ class _FunctionGenerator: def localtimestamp(self) -> Type[localtimestamp]: ... - @property - def max(self) -> Type[max[Any]]: # noqa: A001 + # appease mypy which seems to not want to accept _T from + # _ColumnExpressionArgument, as it includes non-generic types + + @overload + def max( # noqa: A001 + self, + col: ColumnElement[_T], + *args: _ColumnExpressionArgument[Any], + **kwargs: Any, + ) -> max[_T]: ... - @property - def min(self) -> Type[min[Any]]: # noqa: A001 + @overload + def max( # noqa: A001 + self, + col: _ColumnExpressionArgument[_T], + *args: _ColumnExpressionArgument[Any], + **kwargs: Any, + ) -> max[_T]: + ... + + def max( # noqa: A001 + self, + col: _ColumnExpressionArgument[_T], + *args: _ColumnExpressionArgument[Any], + **kwargs: Any, + ) -> max[_T]: + ... + + # appease mypy which seems to not want to accept _T from + # _ColumnExpressionArgument, as it includes non-generic types + + @overload + def min( # noqa: A001 + self, + col: ColumnElement[_T], + *args: _ColumnExpressionArgument[Any], + **kwargs: Any, + ) -> min[_T]: + ... + + @overload + def min( # noqa: A001 + self, + col: _ColumnExpressionArgument[_T], + *args: _ColumnExpressionArgument[Any], + **kwargs: Any, + ) -> min[_T]: + ... + + def min( # noqa: A001 + self, + col: _ColumnExpressionArgument[_T], + *args: _ColumnExpressionArgument[Any], + **kwargs: Any, + ) -> min[_T]: ... @property @@ -1036,10 +1162,6 @@ class _FunctionGenerator: def rank(self) -> Type[rank]: ... - @property - def returntypefromargs(self) -> Type[ReturnTypeFromArgs[Any]]: - ... - @property def rollup(self) -> Type[rollup[Any]]: ... @@ -1048,8 +1170,33 @@ class _FunctionGenerator: def session_user(self) -> Type[session_user]: ... - @property - def sum(self) -> Type[sum[Any]]: # noqa: A001 + # appease mypy which seems to not want to accept _T from + # _ColumnExpressionArgument, as it includes non-generic types + + @overload + def sum( # noqa: A001 + self, + col: ColumnElement[_T], + *args: _ColumnExpressionArgument[Any], + **kwargs: Any, + ) -> sum[_T]: + ... + + @overload + def sum( # noqa: A001 + self, + col: _ColumnExpressionArgument[_T], + *args: _ColumnExpressionArgument[Any], + **kwargs: Any, + ) -> sum[_T]: + ... + + def sum( # noqa: A001 + self, + col: _ColumnExpressionArgument[_T], + *args: _ColumnExpressionArgument[Any], + **kwargs: Any, + ) -> sum[_T]: ... @property @@ -1131,10 +1278,30 @@ class Function(FunctionElement[_T]): """ + @overload + def __init__( + self, + name: str, + *clauses: _ColumnExpressionOrLiteralArgument[_T], + type_: None = ..., + packagenames: Optional[Tuple[str, ...]] = ..., + ): + ... + + @overload + def __init__( + self, + name: str, + *clauses: _ColumnExpressionOrLiteralArgument[Any], + type_: _TypeEngineArgument[_T] = ..., + packagenames: Optional[Tuple[str, ...]] = ..., + ): + ... + def __init__( self, name: str, - *clauses: Any, + *clauses: _ColumnExpressionOrLiteralArgument[Any], type_: Optional[_TypeEngineArgument[_T]] = None, packagenames: Optional[Tuple[str, ...]] = None, ): @@ -1153,7 +1320,14 @@ class Function(FunctionElement[_T]): FunctionElement.__init__(self, *clauses) - def _bind_param(self, operator, obj, type_=None, **kw): + def _bind_param( + self, + operator: OperatorType, + obj: Any, + type_: Optional[TypeEngine[_T]] = None, + expanding: bool = False, + **kw: Any, + ) -> BindParameter[_T]: return BindParameter( self.name, obj, @@ -1161,6 +1335,7 @@ class Function(FunctionElement[_T]): _compared_to_type=self.type, type_=type_, unique=True, + expanding=expanding, **kw, ) @@ -1306,7 +1481,9 @@ class GenericFunction(Function[_T]): # Set _register to True to register child classes by default cls._register = True - def __init__(self, *args, **kwargs): + def __init__( + self, *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any + ): parsed_args = kwargs.pop("_parsed_args", None) if parsed_args is None: parsed_args = [ @@ -1332,8 +1509,8 @@ class GenericFunction(Function[_T]): ) -register_function("cast", Cast) -register_function("extract", Extract) +register_function("cast", Cast) # type: ignore +register_function("extract", Extract) # type: ignore class next_value(GenericFunction[int]): @@ -1353,7 +1530,7 @@ class next_value(GenericFunction[int]): ("sequence", InternalTraversal.dp_named_ddl_element) ] - def __init__(self, seq, **kw): + def __init__(self, seq: schema.Sequence, **kw: Any): assert isinstance( seq, schema.Sequence ), "next_value() accepts a Sequence object as input." @@ -1362,14 +1539,14 @@ class next_value(GenericFunction[int]): seq.data_type or getattr(self, "type", None) ) - def compare(self, other, **kw): + def compare(self, other: Any, **kw: Any) -> bool: return ( isinstance(other, next_value) and self.sequence.name == other.sequence.name ) @property - def _from_objects(self): + def _from_objects(self) -> Any: return [] @@ -1378,7 +1555,7 @@ class AnsiFunction(GenericFunction[_T]): inherit_cache = True - def __init__(self, *args, **kwargs): + def __init__(self, *args: _ColumnExpressionArgument[Any], **kwargs: Any): GenericFunction.__init__(self, *args, **kwargs) @@ -1387,8 +1564,29 @@ class ReturnTypeFromArgs(GenericFunction[_T]): inherit_cache = True - def __init__(self, *args, **kwargs): - fn_args = [ + # appease mypy which seems to not want to accept _T from + # _ColumnExpressionArgument, as it includes non-generic types + + @overload + def __init__( + self, + col: ColumnElement[_T], + *args: _ColumnExpressionArgument[Any], + **kwargs: Any, + ): + ... + + @overload + def __init__( + self, + col: _ColumnExpressionArgument[_T], + *args: _ColumnExpressionArgument[Any], + **kwargs: Any, + ): + ... + + def __init__(self, *args: _ColumnExpressionArgument[Any], **kwargs: Any): + fn_args: Sequence[ColumnElement[Any]] = [ coercions.expect( roles.ExpressionElementRole, c, @@ -1469,7 +1667,7 @@ class char_length(GenericFunction[int]): type = sqltypes.Integer() inherit_cache = True - def __init__(self, arg, **kw): + def __init__(self, arg: _ColumnExpressionArgument[str], **kw: Any): # slight hack to limit to just one positional argument # not sure why this one function has this special treatment super().__init__(arg, **kw) @@ -1506,7 +1704,11 @@ class count(GenericFunction[int]): type = sqltypes.Integer() inherit_cache = True - def __init__(self, expression=None, **kwargs): + def __init__( + self, + expression: Optional[_ColumnExpressionArgument[Any]] = None, + **kwargs: Any, + ): if expression is None: expression = literal_column("*") super().__init__(expression, **kwargs) @@ -1595,8 +1797,8 @@ class array_agg(GenericFunction[_T]): inherit_cache = True - def __init__(self, *args, **kwargs): - fn_args = [ + def __init__(self, *args: _ColumnExpressionArgument[Any], **kwargs: Any): + fn_args: Sequence[ColumnElement[Any]] = [ coercions.expect( roles.ExpressionElementRole, c, apply_propagate_attrs=self ) @@ -1624,9 +1826,13 @@ class OrderedSetAgg(GenericFunction[_T]): array_for_multi_clause = False inherit_cache = True - def within_group_type(self, within_group): + def within_group_type( + self, within_group: WithinGroup[Any] + ) -> TypeEngine[Any]: func_clauses = cast(ClauseList, self.clause_expr.element) - order_by = sqlutil.unwrap_order_by(within_group.order_by) + order_by: Sequence[ColumnElement[Any]] = sqlutil.unwrap_order_by( + within_group.order_by + ) if self.array_for_multi_clause and len(func_clauses.clauses) > 1: return sqltypes.ARRAY(order_by[0].type) else: @@ -1824,5 +2030,5 @@ class aggregate_strings(GenericFunction[str]): _has_args = True inherit_cache = True - def __init__(self, clause, separator): + def __init__(self, clause: _ColumnExpressionArgument[Any], separator: str): super().__init__(clause, separator) diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 28480a5d43..19551831fe 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -367,7 +367,7 @@ def find_tables( return tables -def unwrap_order_by(clause): +def unwrap_order_by(clause: Any) -> Any: """Break up an 'order by' expression into individual column-expressions, without DESC/ASC/NULLS FIRST/NULLS LAST""" diff --git a/test/typing/plain_files/sql/functions.py b/test/typing/plain_files/sql/functions.py index e66e554cff..6a345fcf6e 100644 --- a/test/typing/plain_files/sql/functions.py +++ b/test/typing/plain_files/sql/functions.py @@ -2,14 +2,17 @@ from sqlalchemy import column from sqlalchemy import func +from sqlalchemy import Integer from sqlalchemy import select +from sqlalchemy import Sequence +from sqlalchemy import String # START GENERATED FUNCTION TYPING TESTS # code within this block is **programmatically, # statically generated** by tools/generate_sql_functions.py -stmt1 = select(func.aggregate_strings(column("x"), column("x"))) +stmt1 = select(func.aggregate_strings(column("x", String), ",")) # EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\] reveal_type(stmt1) @@ -21,105 +24,129 @@ stmt2 = select(func.char_length(column("x"))) reveal_type(stmt2) -stmt3 = select(func.concat()) +stmt3 = select(func.coalesce(column("x", Integer))) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] reveal_type(stmt3) -stmt4 = select(func.count(column("x"))) +stmt4 = select(func.concat()) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\] reveal_type(stmt4) -stmt5 = select(func.cume_dist()) +stmt5 = select(func.count(column("x"))) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*Decimal\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] reveal_type(stmt5) -stmt6 = select(func.current_date()) +stmt6 = select(func.cume_dist()) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*date\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*Decimal\]\] reveal_type(stmt6) -stmt7 = select(func.current_time()) +stmt7 = select(func.current_date()) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*time\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*date\]\] reveal_type(stmt7) -stmt8 = select(func.current_timestamp()) +stmt8 = select(func.current_time()) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*time\]\] reveal_type(stmt8) -stmt9 = select(func.current_user()) +stmt9 = select(func.current_timestamp()) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\] reveal_type(stmt9) -stmt10 = select(func.dense_rank()) +stmt10 = select(func.current_user()) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\] reveal_type(stmt10) -stmt11 = select(func.localtime()) +stmt11 = select(func.dense_rank()) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] reveal_type(stmt11) -stmt12 = select(func.localtimestamp()) +stmt12 = select(func.localtime()) # EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\] reveal_type(stmt12) -stmt13 = select(func.next_value(column("x"))) +stmt13 = select(func.localtimestamp()) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\] reveal_type(stmt13) -stmt14 = select(func.now()) +stmt14 = select(func.max(column("x", Integer))) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] reveal_type(stmt14) -stmt15 = select(func.percent_rank()) +stmt15 = select(func.min(column("x", Integer))) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*Decimal\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] reveal_type(stmt15) -stmt16 = select(func.rank()) +stmt16 = select(func.next_value(Sequence("x_seq"))) # EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] reveal_type(stmt16) -stmt17 = select(func.session_user()) +stmt17 = select(func.now()) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\] reveal_type(stmt17) -stmt18 = select(func.sysdate()) +stmt18 = select(func.percent_rank()) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*Decimal\]\] reveal_type(stmt18) -stmt19 = select(func.user()) +stmt19 = select(func.rank()) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] reveal_type(stmt19) + +stmt20 = select(func.session_user()) + +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\] +reveal_type(stmt20) + + +stmt21 = select(func.sum(column("x", Integer))) + +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] +reveal_type(stmt21) + + +stmt22 = select(func.sysdate()) + +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\] +reveal_type(stmt22) + + +stmt23 = select(func.user()) + +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\] +reveal_type(stmt23) + # END GENERATED FUNCTION TYPING TESTS diff --git a/tools/generate_sql_functions.py b/tools/generate_sql_functions.py index 848a927225..348b334484 100644 --- a/tools/generate_sql_functions.py +++ b/tools/generate_sql_functions.py @@ -11,6 +11,7 @@ from tempfile import NamedTemporaryFile import textwrap from sqlalchemy.sql.functions import _registry +from sqlalchemy.sql.functions import ReturnTypeFromArgs from sqlalchemy.types import TypeEngine from sqlalchemy.util.tool_support import code_writer_cmd @@ -18,7 +19,10 @@ from sqlalchemy.util.tool_support import code_writer_cmd def _fns_in_deterministic_order(): reg = _registry["_default"] for key in sorted(reg): - yield key, reg[key] + cls = reg[key] + if cls is ReturnTypeFromArgs: + continue + yield key, cls def process_functions(filename: str, cmd: code_writer_cmd) -> str: @@ -53,23 +57,75 @@ def process_functions(filename: str, cmd: code_writer_cmd) -> str: for key, fn_class in _fns_in_deterministic_order(): is_reserved_word = key in builtins - guess_its_generic = bool(fn_class.__parameters__) + if issubclass(fn_class, ReturnTypeFromArgs): + buf.write( + textwrap.indent( + f""" + +# appease mypy which seems to not want to accept _T from +# _ColumnExpressionArgument, as it includes non-generic types + +@overload +def {key}( {' # noqa: A001' if is_reserved_word else ''} + self, + col: ColumnElement[_T], + *args: _ColumnExpressionArgument[Any], + **kwargs: Any, +) -> {fn_class.__name__}[_T]: + ... - buf.write( - textwrap.indent( - f""" +@overload +def {key}( {' # noqa: A001' if is_reserved_word else ''} + self, + col: _ColumnExpressionArgument[_T], + *args: _ColumnExpressionArgument[Any], + **kwargs: Any, +) -> {fn_class.__name__}[_T]: + ... + +def {key}( {' # noqa: A001' if is_reserved_word else ''} + self, + col: _ColumnExpressionArgument[_T], + *args: _ColumnExpressionArgument[Any], + **kwargs: Any, +) -> {fn_class.__name__}[_T]: + ... + + """, + indent, + ) + ) + else: + guess_its_generic = bool(fn_class.__parameters__) + + # the latest flake8 is quite broken here: + # 1. it insists on linting f-strings, no option + # to turn it off + # 2. the f-string indentation rules are either broken + # or completely impossible to figure out + # 3. there's no way to E501 a too-long f-string, + # so I can't even put the expressions all one line + # to get around the indentation errors + # 4. Therefore here I have to concat part of the + # string outside of the f-string + _type = fn_class.__name__ + _type += "[Any]" if guess_its_generic else "" + _reserved_word = ( + " # noqa: A001" if is_reserved_word else "" + ) + + # now the f-string + buf.write( + textwrap.indent( + f""" @property -def {key}(self) -> Type[{fn_class.__name__}{ - '[Any]' if guess_its_generic else '' -}]:{ - ' # noqa: A001' if is_reserved_word else '' -} +def {key}(self) -> Type[{_type}]:{_reserved_word} ... """, - indent, + indent, + ) ) - ) m = re.match( r"^( *)# START GENERATED FUNCTION TYPING TESTS", @@ -92,15 +148,48 @@ def {key}(self) -> Type[{fn_class.__name__}{ count = 0 for key, fn_class in _fns_in_deterministic_order(): - if hasattr(fn_class, "type") and isinstance( + if issubclass(fn_class, ReturnTypeFromArgs): + count += 1 + + buf.write( + textwrap.indent( + rf""" +stmt{count} = select(func.{key}(column('x', Integer))) + +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] +reveal_type(stmt{count}) + +""", + indent, + ) + ) + elif fn_class.__name__ == "aggregate_strings": + count += 1 + buf.write( + textwrap.indent( + rf""" +stmt{count} = select(func.{key}(column('x', String), ',')) + +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\] +reveal_type(stmt{count}) + +""", + indent, + ) + ) + + elif hasattr(fn_class, "type") and isinstance( fn_class.type, TypeEngine ): python_type = fn_class.type.python_type python_expr = rf"Tuple\[.*{python_type.__name__}\]" argspec = inspect.getfullargspec(fn_class) - args = ", ".join( - 'column("x")' for elem in argspec.args[1:] - ) + if fn_class.__name__ == "next_value": + args = "Sequence('x_seq')" + else: + args = ", ".join( + 'column("x")' for elem in argspec.args[1:] + ) count += 1 buf.write(