# the MIT License: https://www.opensource.org/licenses/mit-license.php
-"""SQL function API, factories, and built-in functions.
-
-"""
+"""SQL function API, factories, and built-in functions."""
from __future__ import annotations
@property
def ansifunction(self) -> Type[AnsiFunction[Any]]: ...
- @property
- def array_agg(self) -> Type[array_agg[Any]]: ...
+ # set ColumnElement[_T] as a separate overload, to appease mypy
+ # which seems to not want to accept _T from _ColumnExpressionArgument.
+ # this is even if all non-generic types are removed from it, so
+ # reasons remain unclear for why this does not work
+
+ @overload
+ def array_agg(
+ self,
+ col: ColumnElement[_T],
+ *args: _ColumnExpressionOrLiteralArgument[Any],
+ **kwargs: Any,
+ ) -> array_agg[_T]: ...
+
+ @overload
+ def array_agg(
+ self,
+ col: _ColumnExpressionArgument[_T],
+ *args: _ColumnExpressionOrLiteralArgument[Any],
+ **kwargs: Any,
+ ) -> array_agg[_T]: ...
+
+ @overload
+ def array_agg(
+ self,
+ col: _ColumnExpressionOrLiteralArgument[_T],
+ *args: _ColumnExpressionOrLiteralArgument[Any],
+ **kwargs: Any,
+ ) -> array_agg[_T]: ...
+
+ def array_agg(
+ self,
+ col: _ColumnExpressionOrLiteralArgument[_T],
+ *args: _ColumnExpressionOrLiteralArgument[Any],
+ **kwargs: Any,
+ ) -> array_agg[_T]: ...
@property
def cast(self) -> Type[Cast[Any]]: ...
class ReturnTypeFromArgs(GenericFunction[_T]):
- """Define a function whose return type is the same as its arguments."""
+ """Define a function whose return type is bound to the type of its
+ arguments.
+ """
inherit_cache = True
inherit_cache = True
-class array_agg(GenericFunction[_T]):
+class array_agg(ReturnTypeFromArgs[Sequence[_T]]):
"""Support for the ARRAY_AGG function.
The ``func.array_agg(expr)`` construct returns an expression of
reveal_type(stmt1)
-stmt2 = select(func.char_length(column("x")))
+stmt2 = select(func.array_agg(column("x", Integer)))
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*Sequence\[.*int\]\]\]
reveal_type(stmt2)
-stmt3 = select(func.coalesce(column("x", Integer)))
+stmt3 = select(func.char_length(column("x")))
# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
reveal_type(stmt3)
-stmt4 = select(func.concat())
+stmt4 = select(func.coalesce(column("x", Integer)))
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
reveal_type(stmt4)
-stmt5 = select(func.count(column("x")))
+stmt5 = select(func.concat())
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
reveal_type(stmt5)
-stmt6 = select(func.cume_dist())
+stmt6 = select(func.count(column("x")))
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*Decimal\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
reveal_type(stmt6)
-stmt7 = select(func.current_date())
+stmt7 = select(func.cume_dist())
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*date\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*Decimal\]\]
reveal_type(stmt7)
-stmt8 = select(func.current_time())
+stmt8 = select(func.current_date())
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*time\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*date\]\]
reveal_type(stmt8)
-stmt9 = select(func.current_timestamp())
+stmt9 = select(func.current_time())
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*time\]\]
reveal_type(stmt9)
-stmt10 = select(func.current_user())
+stmt10 = select(func.current_timestamp())
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
reveal_type(stmt10)
-stmt11 = select(func.dense_rank())
+stmt11 = select(func.current_user())
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
reveal_type(stmt11)
-stmt12 = select(func.localtime())
+stmt12 = select(func.dense_rank())
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
reveal_type(stmt12)
-stmt13 = select(func.localtimestamp())
+stmt13 = select(func.localtime())
# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
reveal_type(stmt13)
-stmt14 = select(func.max(column("x", Integer)))
+stmt14 = select(func.localtimestamp())
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
reveal_type(stmt14)
-stmt15 = select(func.min(column("x", Integer)))
+stmt15 = select(func.max(column("x", Integer)))
# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
reveal_type(stmt15)
-stmt16 = select(func.next_value(Sequence("x_seq")))
+stmt16 = select(func.min(column("x", Integer)))
# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
reveal_type(stmt16)
-stmt17 = select(func.now())
+stmt17 = select(func.next_value(Sequence("x_seq")))
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
reveal_type(stmt17)
-stmt18 = select(func.percent_rank())
+stmt18 = select(func.now())
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*Decimal\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
reveal_type(stmt18)
-stmt19 = select(func.rank())
+stmt19 = select(func.percent_rank())
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*Decimal\]\]
reveal_type(stmt19)
-stmt20 = select(func.session_user())
+stmt20 = select(func.rank())
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
reveal_type(stmt20)
-stmt21 = select(func.sum(column("x", Integer)))
+stmt21 = select(func.session_user())
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
reveal_type(stmt21)
-stmt22 = select(func.sysdate())
+stmt22 = select(func.sum(column("x", Integer)))
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
reveal_type(stmt22)
-stmt23 = select(func.user())
+stmt23 = select(func.sysdate())
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
reveal_type(stmt23)
+
+stmt24 = select(func.user())
+
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
+reveal_type(stmt24)
+
# END GENERATED FUNCTION TYPING TESTS
stmt_count: Select[Tuple[int, int, int]] = select(
-"""Generate inline stubs for generic functions on func
-
-"""
+"""Generate inline stubs for generic functions on func"""
# mypy: ignore-errors
import re
from tempfile import NamedTemporaryFile
import textwrap
+import typing
+
+import typing_extensions
from sqlalchemy.sql.functions import _registry
from sqlalchemy.sql.functions import ReturnTypeFromArgs
if issubclass(fn_class, ReturnTypeFromArgs):
count += 1
+ # Would be ReturnTypeFromArgs
+ (orig_base,) = typing_extensions.get_original_bases(
+ fn_class
+ )
+ # Type parameter of ReturnTypeFromArgs
+ (rtype,) = typing.get_args(orig_base)
+ # The origin type, if rtype is a generic
+ orig_type = typing.get_origin(rtype)
+ if orig_type is not None:
+ coltype = rf".*{orig_type.__name__}\[.*int\]"
+ else:
+ coltype = ".*int"
+
buf.write(
textwrap.indent(
rf"""
stmt{count} = select(func.{key}(column('x', Integer)))
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[{coltype}\]\]
reveal_type(stmt{count})
""",