git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Type array_agg()
author Denis Laxalde <denis@laxalde.org>
Mon, 24 Mar 2025 20:35:07 +0000 (16:35 -0400)
committer Federico Caselli <cfederico87@gmail.com>
Mon, 24 Mar 2025 21:22:35 +0000 (22:22 +0100)
The return type of `array_agg()` is declared as a `Sequence[T]` where `T` is bound to the type of the input argument.

This is implemented by making `array_agg()` inherit from `ReturnTypeFromArgs`, which provides appropriate overloads of `__init__()` to support this.

This usage of `ReturnTypeFromArgs` is a bit different from previous ones, as the return type of the function is not exactly the same as that of its arguments, but a "collection" (a generic, namely a `Sequence` here) of the argument types.  Accordingly, we adjust the code of `tools/generate_sql_functions.py` to retrieve the "collection" type from the `fn_class` annotation and generate the expected return type.

Also add a couple of hand-written typing tests for PostgreSQL.

Related to #6810

Closes: #12461
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12461
Pull-request-sha: ba27cbb8639dcd35127ab6a2928b7b5b3667e287

Change-Id: I3fd538cc7092a0492c26970f0b825bf70ddb66cd

lib/sqlalchemy/sql/functions.py
test/typing/plain_files/dialects/postgresql/pg_stuff.py
test/typing/plain_files/sql/functions.py
tools/generate_sql_functions.py

index 87a68cfd90b8733cba6c14dd1a4e113c77ded5c8..c35cbf4adc570ee46d6014cdf76a38006461ea56 100644 (file)
@@ -6,9 +6,7 @@
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
 
 
-"""SQL function API, factories, and built-in functions.
-
-"""
+"""SQL function API, factories, and built-in functions."""
 
 from __future__ import annotations
 
@@ -988,8 +986,41 @@ class _FunctionGenerator:
         @property
         def ansifunction(self) -> Type[AnsiFunction[Any]]: ...
 
-        @property
-        def array_agg(self) -> Type[array_agg[Any]]: ...
+        # set ColumnElement[_T] as a separate overload, to appease mypy
+        # which seems to not want to accept _T from _ColumnExpressionArgument.
+        # this is even if all non-generic types are removed from it, so
+        # reasons remain unclear for why this does not work
+
+        @overload
+        def array_agg(
+            self,
+            col: ColumnElement[_T],
+            *args: _ColumnExpressionOrLiteralArgument[Any],
+            **kwargs: Any,
+        ) -> array_agg[_T]: ...
+
+        @overload
+        def array_agg(
+            self,
+            col: _ColumnExpressionArgument[_T],
+            *args: _ColumnExpressionOrLiteralArgument[Any],
+            **kwargs: Any,
+        ) -> array_agg[_T]: ...
+
+        @overload
+        def array_agg(
+            self,
+            col: _ColumnExpressionOrLiteralArgument[_T],
+            *args: _ColumnExpressionOrLiteralArgument[Any],
+            **kwargs: Any,
+        ) -> array_agg[_T]: ...
+
+        def array_agg(
+            self,
+            col: _ColumnExpressionOrLiteralArgument[_T],
+            *args: _ColumnExpressionOrLiteralArgument[Any],
+            **kwargs: Any,
+        ) -> array_agg[_T]: ...
 
         @property
         def cast(self) -> Type[Cast[Any]]: ...
@@ -1567,7 +1598,9 @@ class AnsiFunction(GenericFunction[_T]):
 
 
 class ReturnTypeFromArgs(GenericFunction[_T]):
-    """Define a function whose return type is the same as its arguments."""
+    """Define a function whose return type is bound to the type of its
+    arguments.
+    """
 
     inherit_cache = True
 
@@ -1799,7 +1832,7 @@ class user(AnsiFunction[str]):
     inherit_cache = True
 
 
-class array_agg(GenericFunction[_T]):
+class array_agg(ReturnTypeFromArgs[Sequence[_T]]):
     """Support for the ARRAY_AGG function.
 
     The ``func.array_agg(expr)`` construct returns an expression of
index b74ea53082c0cfc31efc5431b48bc29a1fd4a325..6dda180c4f92cd33787ec079c5fa191690c4a9a2 100644 (file)
@@ -123,3 +123,11 @@ reveal_type(ARRAY(Text))
 
 # EXPECTED_TYPE: Column[Sequence[int]]
 reveal_type(Column(type_=ARRAY(Integer)))
+
+stmt_array_agg = select(func.array_agg(Column("num", type_=Integer)))
+
+# EXPECTED_TYPE: Select[Sequence[int]]
+reveal_type(stmt_array_agg)
+
+# EXPECTED_TYPE: Select[Sequence[str]]
+reveal_type(select(func.array_agg(Test.ident_str)))
index 9f307e5d921b60857425dd6caba5f00e2038266b..800ed90a99060c89284268fcfc11fdee48405fd4 100644 (file)
@@ -19,137 +19,143 @@ stmt1 = select(func.aggregate_strings(column("x", String), ","))
 reveal_type(stmt1)
 
 
-stmt2 = select(func.char_length(column("x")))
+stmt2 = select(func.array_agg(column("x", Integer)))
 
-# EXPECTED_RE_TYPE: .*Select\[.*int\]
+# EXPECTED_RE_TYPE: .*Select\[.*Sequence\[.*int\]\]
 reveal_type(stmt2)
 
 
-stmt3 = select(func.coalesce(column("x", Integer)))
+stmt3 = select(func.char_length(column("x")))
 
 # EXPECTED_RE_TYPE: .*Select\[.*int\]
 reveal_type(stmt3)
 
 
-stmt4 = select(func.concat())
+stmt4 = select(func.coalesce(column("x", Integer)))
 
-# EXPECTED_RE_TYPE: .*Select\[.*str\]
+# EXPECTED_RE_TYPE: .*Select\[.*int\]
 reveal_type(stmt4)
 
 
-stmt5 = select(func.count(column("x")))
+stmt5 = select(func.concat())
 
-# EXPECTED_RE_TYPE: .*Select\[.*int\]
+# EXPECTED_RE_TYPE: .*Select\[.*str\]
 reveal_type(stmt5)
 
 
-stmt6 = select(func.cume_dist())
+stmt6 = select(func.count(column("x")))
 
-# EXPECTED_RE_TYPE: .*Select\[.*Decimal\]
+# EXPECTED_RE_TYPE: .*Select\[.*int\]
 reveal_type(stmt6)
 
 
-stmt7 = select(func.current_date())
+stmt7 = select(func.cume_dist())
 
-# EXPECTED_RE_TYPE: .*Select\[.*date\]
+# EXPECTED_RE_TYPE: .*Select\[.*Decimal\]
 reveal_type(stmt7)
 
 
-stmt8 = select(func.current_time())
+stmt8 = select(func.current_date())
 
-# EXPECTED_RE_TYPE: .*Select\[.*time\]
+# EXPECTED_RE_TYPE: .*Select\[.*date\]
 reveal_type(stmt8)
 
 
-stmt9 = select(func.current_timestamp())
+stmt9 = select(func.current_time())
 
-# EXPECTED_RE_TYPE: .*Select\[.*datetime\]
+# EXPECTED_RE_TYPE: .*Select\[.*time\]
 reveal_type(stmt9)
 
 
-stmt10 = select(func.current_user())
+stmt10 = select(func.current_timestamp())
 
-# EXPECTED_RE_TYPE: .*Select\[.*str\]
+# EXPECTED_RE_TYPE: .*Select\[.*datetime\]
 reveal_type(stmt10)
 
 
-stmt11 = select(func.dense_rank())
+stmt11 = select(func.current_user())
 
-# EXPECTED_RE_TYPE: .*Select\[.*int\]
+# EXPECTED_RE_TYPE: .*Select\[.*str\]
 reveal_type(stmt11)
 
 
-stmt12 = select(func.localtime())
+stmt12 = select(func.dense_rank())
 
-# EXPECTED_RE_TYPE: .*Select\[.*datetime\]
+# EXPECTED_RE_TYPE: .*Select\[.*int\]
 reveal_type(stmt12)
 
 
-stmt13 = select(func.localtimestamp())
+stmt13 = select(func.localtime())
 
 # EXPECTED_RE_TYPE: .*Select\[.*datetime\]
 reveal_type(stmt13)
 
 
-stmt14 = select(func.max(column("x", Integer)))
+stmt14 = select(func.localtimestamp())
 
-# EXPECTED_RE_TYPE: .*Select\[.*int\]
+# EXPECTED_RE_TYPE: .*Select\[.*datetime\]
 reveal_type(stmt14)
 
 
-stmt15 = select(func.min(column("x", Integer)))
+stmt15 = select(func.max(column("x", Integer)))
 
 # EXPECTED_RE_TYPE: .*Select\[.*int\]
 reveal_type(stmt15)
 
 
-stmt16 = select(func.next_value(Sequence("x_seq")))
+stmt16 = select(func.min(column("x", Integer)))
 
 # EXPECTED_RE_TYPE: .*Select\[.*int\]
 reveal_type(stmt16)
 
 
-stmt17 = select(func.now())
+stmt17 = select(func.next_value(Sequence("x_seq")))
 
-# EXPECTED_RE_TYPE: .*Select\[.*datetime\]
+# EXPECTED_RE_TYPE: .*Select\[.*int\]
 reveal_type(stmt17)
 
 
-stmt18 = select(func.percent_rank())
+stmt18 = select(func.now())
 
-# EXPECTED_RE_TYPE: .*Select\[.*Decimal\]
+# EXPECTED_RE_TYPE: .*Select\[.*datetime\]
 reveal_type(stmt18)
 
 
-stmt19 = select(func.rank())
+stmt19 = select(func.percent_rank())
 
-# EXPECTED_RE_TYPE: .*Select\[.*int\]
+# EXPECTED_RE_TYPE: .*Select\[.*Decimal\]
 reveal_type(stmt19)
 
 
-stmt20 = select(func.session_user())
+stmt20 = select(func.rank())
 
-# EXPECTED_RE_TYPE: .*Select\[.*str\]
+# EXPECTED_RE_TYPE: .*Select\[.*int\]
 reveal_type(stmt20)
 
 
-stmt21 = select(func.sum(column("x", Integer)))
+stmt21 = select(func.session_user())
 
-# EXPECTED_RE_TYPE: .*Select\[.*int\]
+# EXPECTED_RE_TYPE: .*Select\[.*str\]
 reveal_type(stmt21)
 
 
-stmt22 = select(func.sysdate())
+stmt22 = select(func.sum(column("x", Integer)))
 
-# EXPECTED_RE_TYPE: .*Select\[.*datetime\]
+# EXPECTED_RE_TYPE: .*Select\[.*int\]
 reveal_type(stmt22)
 
 
-stmt23 = select(func.user())
+stmt23 = select(func.sysdate())
 
-# EXPECTED_RE_TYPE: .*Select\[.*str\]
+# EXPECTED_RE_TYPE: .*Select\[.*datetime\]
 reveal_type(stmt23)
 
+
+stmt24 = select(func.user())
+
+# EXPECTED_RE_TYPE: .*Select\[.*str\]
+reveal_type(stmt24)
+
 # END GENERATED FUNCTION TYPING TESTS
 
 stmt_count: Select[int, int, int] = select(
index dc68b40f0a195e72ebe3763fae1d88daead34091..a88a7d702204c4deb22db0af3e3eb540b20d41a7 100644 (file)
@@ -1,6 +1,4 @@
-"""Generate inline stubs for generic functions on func
-
-"""
+"""Generate inline stubs for generic functions on func"""
 
 # mypy: ignore-errors
 
@@ -10,6 +8,9 @@ import inspect
 import re
 from tempfile import NamedTemporaryFile
 import textwrap
+import typing
+
+import typing_extensions
 
 from sqlalchemy.sql.functions import _registry
 from sqlalchemy.sql.functions import ReturnTypeFromArgs
@@ -168,12 +169,25 @@ def {key}(self) -> Type[{_type}]:{_reserved_word}
                     if issubclass(fn_class, ReturnTypeFromArgs):
                         count += 1
 
+                        # Would be ReturnTypeFromArgs
+                        (orig_base,) = typing_extensions.get_original_bases(
+                            fn_class
+                        )
+                        # Type parameter of ReturnTypeFromArgs
+                        (rtype,) = typing.get_args(orig_base)
+                        # The origin type, if rtype is a generic
+                        orig_type = typing.get_origin(rtype)
+                        if orig_type is not None:
+                            coltype = rf".*{orig_type.__name__}\[.*int\]"
+                        else:
+                            coltype = ".*int"
+
                         buf.write(
                             textwrap.indent(
                                 rf"""
 stmt{count} = select(func.{key}(column('x', Integer)))
 
-# EXPECTED_RE_TYPE: .*Select\[.*int\]
+# EXPECTED_RE_TYPE: .*Select\[{coltype}\]
 reveal_type(stmt{count})
 
 """,