From: Denis Laxalde Date: Mon, 24 Mar 2025 20:35:07 +0000 (-0400) Subject: Type array_agg() X-Git-Tag: rel_2_0_40~2^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=56600630ffec6929c167c053fb852b0d77d55f14;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Type array_agg() The return type of `array_agg()` is declared as a `Sequence[T]` where `T` is bound to the type of the input argument. This is implemented by making `array_agg()` inherit from `ReturnTypeFromArgs`, which provides appropriate overloads of `__init__()` to support this. This usage of ReturnTypeFromArgs is a bit different from previous ones as the return type of the function is not exactly the same as that of its arguments, but a "collection" (a generic, namely a Sequence here) of the argument types. Accordingly, we adjust the code of `tools/generate_sql_functions.py` to retrieve the "collection" type from the 'fn_class' annotation and generate the expected return type. Also add a couple of hand-written typing tests for PostgreSQL. Related to #6810 Closes: #12461 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12461 Pull-request-sha: ba27cbb8639dcd35127ab6a2928b7b5b3667e287 Change-Id: I3fd538cc7092a0492c26970f0b825bf70ddb66cd (cherry picked from commit 543acbd8d1c7e3037877ca74a6b05f62592ef153) --- diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index ea02279d48..bd7d6877c3 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -6,9 +6,7 @@ # the MIT License: https://www.opensource.org/licenses/mit-license.php -"""SQL function API, factories, and built-in functions. - -""" +"""SQL function API, factories, and built-in functions.""" from __future__ import annotations @@ -990,8 +988,41 @@ class _FunctionGenerator: @property def ansifunction(self) -> Type[AnsiFunction[Any]]: ... - @property - def array_agg(self) -> Type[array_agg[Any]]: ... 
+ # set ColumnElement[_T] as a separate overload, to appease mypy + # which seems to not want to accept _T from _ColumnExpressionArgument. + # this is even if all non-generic types are removed from it, so + # reasons remain unclear for why this does not work + + @overload + def array_agg( + self, + col: ColumnElement[_T], + *args: _ColumnExpressionOrLiteralArgument[Any], + **kwargs: Any, + ) -> array_agg[_T]: ... + + @overload + def array_agg( + self, + col: _ColumnExpressionArgument[_T], + *args: _ColumnExpressionOrLiteralArgument[Any], + **kwargs: Any, + ) -> array_agg[_T]: ... + + @overload + def array_agg( + self, + col: _ColumnExpressionOrLiteralArgument[_T], + *args: _ColumnExpressionOrLiteralArgument[Any], + **kwargs: Any, + ) -> array_agg[_T]: ... + + def array_agg( + self, + col: _ColumnExpressionOrLiteralArgument[_T], + *args: _ColumnExpressionOrLiteralArgument[Any], + **kwargs: Any, + ) -> array_agg[_T]: ... @property def cast(self) -> Type[Cast[Any]]: ... @@ -1575,7 +1606,9 @@ class AnsiFunction(GenericFunction[_T]): class ReturnTypeFromArgs(GenericFunction[_T]): - """Define a function whose return type is the same as its arguments.""" + """Define a function whose return type is bound to the type of its + arguments. + """ inherit_cache = True @@ -1807,7 +1840,7 @@ class user(AnsiFunction[str]): inherit_cache = True -class array_agg(GenericFunction[_T]): +class array_agg(ReturnTypeFromArgs[Sequence[_T]]): """Support for the ARRAY_AGG function. 
The ``func.array_agg(expr)`` construct returns an expression of diff --git a/test/typing/plain_files/dialects/postgresql/pg_stuff.py b/test/typing/plain_files/dialects/postgresql/pg_stuff.py index bc05ef8c44..3dbb949878 100644 --- a/test/typing/plain_files/dialects/postgresql/pg_stuff.py +++ b/test/typing/plain_files/dialects/postgresql/pg_stuff.py @@ -123,3 +123,11 @@ reveal_type(ARRAY(Text)) # EXPECTED_TYPE: Column[Sequence[int]] reveal_type(Column(type_=ARRAY(Integer))) + +stmt_array_agg = select(func.array_agg(Column("num", type_=Integer))) + +# EXPECTED_TYPE: Select[Tuple[Sequence[int]]] +reveal_type(stmt_array_agg) + +# EXPECTED_TYPE: Select[Tuple[Sequence[str]]] +reveal_type(select(func.array_agg(Test.ident_str))) diff --git a/test/typing/plain_files/sql/functions.py b/test/typing/plain_files/sql/functions.py index f657a48571..e1cea4193e 100644 --- a/test/typing/plain_files/sql/functions.py +++ b/test/typing/plain_files/sql/functions.py @@ -21,137 +21,143 @@ stmt1 = select(func.aggregate_strings(column("x", String), ",")) reveal_type(stmt1) -stmt2 = select(func.char_length(column("x"))) +stmt2 = select(func.array_agg(column("x", Integer))) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*Sequence\[.*int\]\]\] reveal_type(stmt2) -stmt3 = select(func.coalesce(column("x", Integer))) +stmt3 = select(func.char_length(column("x"))) # EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] reveal_type(stmt3) -stmt4 = select(func.concat()) +stmt4 = select(func.coalesce(column("x", Integer))) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] reveal_type(stmt4) -stmt5 = select(func.count(column("x"))) +stmt5 = select(func.concat()) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\] reveal_type(stmt5) -stmt6 = select(func.cume_dist()) +stmt6 = select(func.count(column("x"))) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*Decimal\]\] +# EXPECTED_RE_TYPE: 
.*Select\[Tuple\[.*int\]\] reveal_type(stmt6) -stmt7 = select(func.current_date()) +stmt7 = select(func.cume_dist()) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*date\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*Decimal\]\] reveal_type(stmt7) -stmt8 = select(func.current_time()) +stmt8 = select(func.current_date()) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*time\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*date\]\] reveal_type(stmt8) -stmt9 = select(func.current_timestamp()) +stmt9 = select(func.current_time()) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*time\]\] reveal_type(stmt9) -stmt10 = select(func.current_user()) +stmt10 = select(func.current_timestamp()) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\] reveal_type(stmt10) -stmt11 = select(func.dense_rank()) +stmt11 = select(func.current_user()) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\] reveal_type(stmt11) -stmt12 = select(func.localtime()) +stmt12 = select(func.dense_rank()) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] reveal_type(stmt12) -stmt13 = select(func.localtimestamp()) +stmt13 = select(func.localtime()) # EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\] reveal_type(stmt13) -stmt14 = select(func.max(column("x", Integer))) +stmt14 = select(func.localtimestamp()) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\] reveal_type(stmt14) -stmt15 = select(func.min(column("x", Integer))) +stmt15 = select(func.max(column("x", Integer))) # EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] reveal_type(stmt15) -stmt16 = select(func.next_value(Sequence("x_seq"))) +stmt16 = select(func.min(column("x", Integer))) # EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] reveal_type(stmt16) -stmt17 = select(func.now()) +stmt17 = select(func.next_value(Sequence("x_seq"))) -# 
EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] reveal_type(stmt17) -stmt18 = select(func.percent_rank()) +stmt18 = select(func.now()) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*Decimal\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\] reveal_type(stmt18) -stmt19 = select(func.rank()) +stmt19 = select(func.percent_rank()) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*Decimal\]\] reveal_type(stmt19) -stmt20 = select(func.session_user()) +stmt20 = select(func.rank()) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] reveal_type(stmt20) -stmt21 = select(func.sum(column("x", Integer))) +stmt21 = select(func.session_user()) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\] reveal_type(stmt21) -stmt22 = select(func.sysdate()) +stmt22 = select(func.sum(column("x", Integer))) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] reveal_type(stmt22) -stmt23 = select(func.user()) +stmt23 = select(func.sysdate()) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\] reveal_type(stmt23) + +stmt24 = select(func.user()) + +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\] +reveal_type(stmt24) + # END GENERATED FUNCTION TYPING TESTS stmt_count: Select[Tuple[int, int, int]] = select( diff --git a/tools/generate_sql_functions.py b/tools/generate_sql_functions.py index 0e5104352f..5049ce5206 100644 --- a/tools/generate_sql_functions.py +++ b/tools/generate_sql_functions.py @@ -1,6 +1,4 @@ -"""Generate inline stubs for generic functions on func - -""" +"""Generate inline stubs for generic functions on func""" # mypy: ignore-errors @@ -10,6 +8,9 @@ import inspect import re from tempfile import NamedTemporaryFile import textwrap +import typing + +import typing_extensions from sqlalchemy.sql.functions import 
_registry from sqlalchemy.sql.functions import ReturnTypeFromArgs @@ -168,12 +169,25 @@ def {key}(self) -> Type[{_type}]:{_reserved_word} if issubclass(fn_class, ReturnTypeFromArgs): count += 1 + # Would be ReturnTypeFromArgs + (orig_base,) = typing_extensions.get_original_bases( + fn_class + ) + # Type parameter of ReturnTypeFromArgs + (rtype,) = typing.get_args(orig_base) + # The origin type, if rtype is a generic + orig_type = typing.get_origin(rtype) + if orig_type is not None: + coltype = rf".*{orig_type.__name__}\[.*int\]" + else: + coltype = ".*int" + buf.write( textwrap.indent( rf""" stmt{count} = select(func.{key}(column('x', Integer))) -# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[{coltype}\]\] reveal_type(stmt{count}) """,