]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Type array_agg()
author Denis Laxalde <denis@laxalde.org>
Mon, 24 Mar 2025 20:35:07 +0000 (16:35 -0400)
committer Federico Caselli <cfederico87@gmail.com>
Tue, 25 Mar 2025 22:37:32 +0000 (23:37 +0100)
The return type of `array_agg()` is declared as a `Sequence[T]`, where `T` is bound to the type of the input argument.

This is implemented by making `array_agg()` inherit from `ReturnTypeFromArgs`, which provides appropriate overloads of `__init__()` to support this.

This usage of `ReturnTypeFromArgs` is a bit different from previous ones, as the return type of the function is not exactly the same as that of its arguments, but a "collection" (a generic, namely a `Sequence` here) of the argument types.  Accordingly, we adjust the code of `tools/generate_sql_functions.py` to retrieve the "collection" type from the `fn_class` annotation and generate the expected return type.

Also add a couple of hand-written typing tests for PostgreSQL.

Related to #6810

Closes: #12461
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12461
Pull-request-sha: ba27cbb8639dcd35127ab6a2928b7b5b3667e287

Change-Id: I3fd538cc7092a0492c26970f0b825bf70ddb66cd
(cherry picked from commit 543acbd8d1c7e3037877ca74a6b05f62592ef153)

lib/sqlalchemy/sql/functions.py
test/typing/plain_files/dialects/postgresql/pg_stuff.py
test/typing/plain_files/sql/functions.py
tools/generate_sql_functions.py

index ea02279d4809bdc7756aa37791e556cdfe531e49..bd7d6877c3ed5ce6d03fdad28dfb70a37a9e3487 100644 (file)
@@ -6,9 +6,7 @@
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
 
 
-"""SQL function API, factories, and built-in functions.
-
-"""
+"""SQL function API, factories, and built-in functions."""
 
 from __future__ import annotations
 
@@ -990,8 +988,41 @@ class _FunctionGenerator:
         @property
         def ansifunction(self) -> Type[AnsiFunction[Any]]: ...
 
-        @property
-        def array_agg(self) -> Type[array_agg[Any]]: ...
+        # set ColumnElement[_T] as a separate overload, to appease mypy,
+        # which does not seem to accept _T from _ColumnExpressionArgument.
+        # this happens even if all non-generic types are removed from it,
+        # so the reason this does not work remains unclear
+
+        @overload
+        def array_agg(
+            self,
+            col: ColumnElement[_T],
+            *args: _ColumnExpressionOrLiteralArgument[Any],
+            **kwargs: Any,
+        ) -> array_agg[_T]: ...
+
+        @overload
+        def array_agg(
+            self,
+            col: _ColumnExpressionArgument[_T],
+            *args: _ColumnExpressionOrLiteralArgument[Any],
+            **kwargs: Any,
+        ) -> array_agg[_T]: ...
+
+        @overload
+        def array_agg(
+            self,
+            col: _ColumnExpressionOrLiteralArgument[_T],
+            *args: _ColumnExpressionOrLiteralArgument[Any],
+            **kwargs: Any,
+        ) -> array_agg[_T]: ...
+
+        def array_agg(
+            self,
+            col: _ColumnExpressionOrLiteralArgument[_T],
+            *args: _ColumnExpressionOrLiteralArgument[Any],
+            **kwargs: Any,
+        ) -> array_agg[_T]: ...
 
         @property
         def cast(self) -> Type[Cast[Any]]: ...
@@ -1575,7 +1606,9 @@ class AnsiFunction(GenericFunction[_T]):
 
 
 class ReturnTypeFromArgs(GenericFunction[_T]):
-    """Define a function whose return type is the same as its arguments."""
+    """Define a function whose return type is bound to the type of its
+    arguments.
+    """
 
     inherit_cache = True
 
@@ -1807,7 +1840,7 @@ class user(AnsiFunction[str]):
     inherit_cache = True
 
 
-class array_agg(GenericFunction[_T]):
+class array_agg(ReturnTypeFromArgs[Sequence[_T]]):
     """Support for the ARRAY_AGG function.
 
     The ``func.array_agg(expr)`` construct returns an expression of
index bc05ef8c4418d5b99a23e61a4b544751fdee57da..3dbb94987879f5ae6616675a2420e75d046c0b14 100644 (file)
@@ -123,3 +123,11 @@ reveal_type(ARRAY(Text))
 
 # EXPECTED_TYPE: Column[Sequence[int]]
 reveal_type(Column(type_=ARRAY(Integer)))
+
+stmt_array_agg = select(func.array_agg(Column("num", type_=Integer)))
+
+# EXPECTED_TYPE: Select[Tuple[Sequence[int]]]
+reveal_type(stmt_array_agg)
+
+# EXPECTED_TYPE: Select[Tuple[Sequence[str]]]
+reveal_type(select(func.array_agg(Test.ident_str)))
index f657a48571aa9c7d22c940eb3cce18ce8dcb378c..e1cea4193e4283c88acc53baa1ba9332e089aa53 100644 (file)
@@ -21,137 +21,143 @@ stmt1 = select(func.aggregate_strings(column("x", String), ","))
 reveal_type(stmt1)
 
 
-stmt2 = select(func.char_length(column("x")))
+stmt2 = select(func.array_agg(column("x", Integer)))
 
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*Sequence\[.*int\]\]\]
 reveal_type(stmt2)
 
 
-stmt3 = select(func.coalesce(column("x", Integer)))
+stmt3 = select(func.char_length(column("x")))
 
 # EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
 reveal_type(stmt3)
 
 
-stmt4 = select(func.concat())
+stmt4 = select(func.coalesce(column("x", Integer)))
 
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
 reveal_type(stmt4)
 
 
-stmt5 = select(func.count(column("x")))
+stmt5 = select(func.concat())
 
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
 reveal_type(stmt5)
 
 
-stmt6 = select(func.cume_dist())
+stmt6 = select(func.count(column("x")))
 
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*Decimal\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
 reveal_type(stmt6)
 
 
-stmt7 = select(func.current_date())
+stmt7 = select(func.cume_dist())
 
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*date\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*Decimal\]\]
 reveal_type(stmt7)
 
 
-stmt8 = select(func.current_time())
+stmt8 = select(func.current_date())
 
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*time\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*date\]\]
 reveal_type(stmt8)
 
 
-stmt9 = select(func.current_timestamp())
+stmt9 = select(func.current_time())
 
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*time\]\]
 reveal_type(stmt9)
 
 
-stmt10 = select(func.current_user())
+stmt10 = select(func.current_timestamp())
 
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
 reveal_type(stmt10)
 
 
-stmt11 = select(func.dense_rank())
+stmt11 = select(func.current_user())
 
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
 reveal_type(stmt11)
 
 
-stmt12 = select(func.localtime())
+stmt12 = select(func.dense_rank())
 
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
 reveal_type(stmt12)
 
 
-stmt13 = select(func.localtimestamp())
+stmt13 = select(func.localtime())
 
 # EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
 reveal_type(stmt13)
 
 
-stmt14 = select(func.max(column("x", Integer)))
+stmt14 = select(func.localtimestamp())
 
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
 reveal_type(stmt14)
 
 
-stmt15 = select(func.min(column("x", Integer)))
+stmt15 = select(func.max(column("x", Integer)))
 
 # EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
 reveal_type(stmt15)
 
 
-stmt16 = select(func.next_value(Sequence("x_seq")))
+stmt16 = select(func.min(column("x", Integer)))
 
 # EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
 reveal_type(stmt16)
 
 
-stmt17 = select(func.now())
+stmt17 = select(func.next_value(Sequence("x_seq")))
 
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
 reveal_type(stmt17)
 
 
-stmt18 = select(func.percent_rank())
+stmt18 = select(func.now())
 
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*Decimal\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
 reveal_type(stmt18)
 
 
-stmt19 = select(func.rank())
+stmt19 = select(func.percent_rank())
 
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*Decimal\]\]
 reveal_type(stmt19)
 
 
-stmt20 = select(func.session_user())
+stmt20 = select(func.rank())
 
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
 reveal_type(stmt20)
 
 
-stmt21 = select(func.sum(column("x", Integer)))
+stmt21 = select(func.session_user())
 
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
 reveal_type(stmt21)
 
 
-stmt22 = select(func.sysdate())
+stmt22 = select(func.sum(column("x", Integer)))
 
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
 reveal_type(stmt22)
 
 
-stmt23 = select(func.user())
+stmt23 = select(func.sysdate())
 
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
 reveal_type(stmt23)
 
+
+stmt24 = select(func.user())
+
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
+reveal_type(stmt24)
+
 # END GENERATED FUNCTION TYPING TESTS
 
 stmt_count: Select[Tuple[int, int, int]] = select(
index 0e5104352f58ab7f88c0b1785539efdec90ba28b..5049ce52066f861cd5dba5e65e0701d4e9f3ce7c 100644 (file)
@@ -1,6 +1,4 @@
-"""Generate inline stubs for generic functions on func
-
-"""
+"""Generate inline stubs for generic functions on func"""
 
 # mypy: ignore-errors
 
@@ -10,6 +8,9 @@ import inspect
 import re
 from tempfile import NamedTemporaryFile
 import textwrap
+import typing
+
+import typing_extensions
 
 from sqlalchemy.sql.functions import _registry
 from sqlalchemy.sql.functions import ReturnTypeFromArgs
@@ -168,12 +169,25 @@ def {key}(self) -> Type[{_type}]:{_reserved_word}
                     if issubclass(fn_class, ReturnTypeFromArgs):
                         count += 1
 
+                        # Would be ReturnTypeFromArgs
+                        (orig_base,) = typing_extensions.get_original_bases(
+                            fn_class
+                        )
+                        # Type parameter of ReturnTypeFromArgs
+                        (rtype,) = typing.get_args(orig_base)
+                        # The origin type, if rtype is a generic
+                        orig_type = typing.get_origin(rtype)
+                        if orig_type is not None:
+                            coltype = rf".*{orig_type.__name__}\[.*int\]"
+                        else:
+                            coltype = ".*int"
+
                         buf.write(
                             textwrap.indent(
                                 rf"""
 stmt{count} = select(func.{key}(column('x', Integer)))
 
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[{coltype}\]\]
 reveal_type(stmt{count})
 
 """,