#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: allow-untyped-defs, allow-untyped-calls
+
"""SQL function API, factories, and built-in functions.
from typing import Any
from typing import cast
from typing import Dict
+from typing import List
from typing import Mapping
from typing import Optional
from typing import overload
+from typing import Sequence
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar
+from typing import Union
from . import annotation
from . import coercions
from .type_api import TypeEngine
from .visitors import InternalTraversal
from .. import util
+from ..util.typing import Self
if TYPE_CHECKING:
+ from ._typing import _ByArgument
+ from ._typing import _ColumnExpressionArgument
+ from ._typing import _ColumnExpressionOrLiteralArgument
from ._typing import _TypeEngineArgument
+ from .base import _EntityNamespace
+ from .elements import ClauseElement
+ from .elements import KeyedColumnElement
+ from .elements import TableValuedColumn
+ from .operators import OperatorType
from ..engine.base import Connection
from ..engine.cursor import CursorResult
from ..engine.interfaces import _CoreMultiExecuteParams
from ..engine.interfaces import CoreExecuteOptionsParameter
_T = TypeVar("_T", bound=Any)
+_S = TypeVar("_S", bound=Any)
_registry: util.defaultdict[
str, Dict[str, Type[Function[Any]]]
] = util.defaultdict(dict)
-def register_function(identifier, fn, package="_default"):
+def register_function(
+ identifier: str, fn: Type[Function[Any]], package: str = "_default"
+) -> None:
"""Associate a callable with a particular func. name.
This is normally called by GenericFunction, but is also
clause_expr: Grouping[Any]
- def __init__(self, *clauses: Any):
+ def __init__(self, *clauses: _ColumnExpressionOrLiteralArgument[Any]):
r"""Construct a :class:`.FunctionElement`.
:param \*clauses: list of column expressions that form the arguments
:class:`.Function`
"""
- args = [
+ args: Sequence[_ColumnExpressionArgument[Any]] = [
coercions.expect(
roles.ExpressionElementRole,
c,
_non_anon_label = None
@property
- def _proxy_key(self):
+ def _proxy_key(self) -> Any:
return super()._proxy_key or getattr(self, "name", None)
def _execute_on_connection(
self, distilled_params, execution_options
)
- def scalar_table_valued(self, name, type_=None):
+ def scalar_table_valued(
+ self, name: str, type_: Optional[_TypeEngineArgument[_T]] = None
+ ) -> ScalarFunctionColumn[_T]:
"""Return a column expression that's against this
:class:`_functions.FunctionElement` as a scalar
table-valued expression.
return ScalarFunctionColumn(self, name, type_)
- def table_valued(self, *expr, **kw):
+ def table_valued(
+ self, *expr: _ColumnExpressionArgument[Any], **kw: Any
+ ) -> TableValuedAlias:
r"""Return a :class:`_sql.TableValuedAlias` representation of this
:class:`_functions.FunctionElement` with table-valued expressions added.
return new_func.alias(name=name, joins_implicitly=joins_implicitly)
- def column_valued(self, name=None, joins_implicitly=False):
+ def column_valued(
+ self, name: Optional[str] = None, joins_implicitly: bool = False
+ ) -> TableValuedColumn[_T]:
"""Return this :class:`_functions.FunctionElement` as a column expression that
selects from itself as a FROM clause.
return self.alias(name=name, joins_implicitly=joins_implicitly).column
@util.ro_non_memoized_property
- def columns(self):
+ def columns(self) -> ColumnCollection[str, KeyedColumnElement[Any]]: # type: ignore[override] # noqa: E501
r"""The set of columns exported by this :class:`.FunctionElement`.
This is a placeholder collection that allows the function to be
return self.c
@util.ro_memoized_property
- def c(self):
+ def c(self) -> ColumnCollection[str, KeyedColumnElement[Any]]: # type: ignore[override] # noqa: E501
"""synonym for :attr:`.FunctionElement.columns`."""
return ColumnCollection(
)
@property
- def _all_selected_columns(self):
+ def _all_selected_columns(self) -> Sequence[KeyedColumnElement[Any]]:
if is_table_value_type(self.type):
- cols = self.type._elements
+ # TODO: this might not be fully accurate
+ cols = cast(
+ "Sequence[KeyedColumnElement[Any]]", self.type._elements
+ )
else:
cols = [self.label(None)]
return cols
@property
- def exported_columns(self):
+ def exported_columns( # type: ignore[override]
+ self,
+ ) -> ColumnCollection[str, KeyedColumnElement[Any]]:
return self.columns
@HasMemoized.memoized_attribute
"""
return cast(ClauseList, self.clause_expr.element)
- def over(self, partition_by=None, order_by=None, rows=None, range_=None):
+ def over(
+ self,
+ *,
+ partition_by: Optional[_ByArgument] = None,
+ order_by: Optional[_ByArgument] = None,
+ rows: Optional[Tuple[Optional[int], Optional[int]]] = None,
+ range_: Optional[Tuple[Optional[int], Optional[int]]] = None,
+ ) -> Over[_T]:
"""Produce an OVER clause against this function.
Used against aggregate or so-called "window" functions,
range_=range_,
)
- def within_group(self, *order_by):
+ def within_group(
+ self, *order_by: _ColumnExpressionArgument[Any]
+ ) -> WithinGroup[_T]:
"""Produce a WITHIN GROUP (ORDER BY expr) clause against this function.
Used against so-called "ordered set aggregate" and "hypothetical
"""
return WithinGroup(self, *order_by)
- def filter(self, *criterion):
+ def filter(
+ self, *criterion: _ColumnExpressionArgument[bool]
+ ) -> Union[Self, FunctionFilter[_T]]:
"""Produce a FILTER clause against this function.
Used against aggregate and window functions,
return self
return FunctionFilter(self, *criterion)
- def as_comparison(self, left_index, right_index):
+ def as_comparison(
+ self, left_index: int, right_index: int
+ ) -> FunctionAsBinary:
"""Interpret this expression as a boolean comparison between two
values.
return FunctionAsBinary(self, left_index, right_index)
@property
- def _from_objects(self):
+ def _from_objects(self) -> Any:
return self.clauses._from_objects
- def within_group_type(self, within_group):
+ def within_group_type(
+ self, within_group: WithinGroup[_S]
+ ) -> Optional[TypeEngine[_S]]:
"""For types that define their return type as based on the criteria
within a WITHIN GROUP (ORDER BY) expression, called by the
:class:`.WithinGroup` construct.
return None
- def alias(self, name=None, joins_implicitly=False):
+ def alias(
+ self, name: Optional[str] = None, joins_implicitly: bool = False
+ ) -> TableValuedAlias:
r"""Produce a :class:`_expression.Alias` construct against this
:class:`.FunctionElement`.
joins_implicitly=joins_implicitly,
)
- def select(self) -> Select[Any]:
+ def select(self) -> Select[Tuple[_T]]:
"""Produce a :func:`_expression.select` construct
against this :class:`.FunctionElement`.
s = s.execution_options(**self._execution_options)
return s
- def _bind_param(self, operator, obj, type_=None, **kw):
+ def _bind_param(
+ self,
+ operator: OperatorType,
+ obj: Any,
+ type_: Optional[TypeEngine[_T]] = None,
+ expanding: bool = False,
+ **kw: Any,
+ ) -> BindParameter[_T]:
return BindParameter(
None,
obj,
_compared_to_type=self.type,
unique=True,
type_=type_,
+ expanding=expanding,
**kw,
)
- def self_group(self, against=None):
+ def self_group(self, against: Optional[OperatorType] = None) -> ClauseElement: # type: ignore[override] # noqa E501
# for the moment, we are parenthesizing all array-returning
# expressions against getitem. This may need to be made
# more portable if in the future we support other DBs
return super().self_group(against=against)
@property
- def entity_namespace(self):
+ def entity_namespace(self) -> _EntityNamespace:
"""overrides FromClause.entity_namespace as functions are generally
column expressions and not FromClauses.
left_index: int
right_index: int
- def _gen_cache_key(self, anon_map, bindparams):
+ def _gen_cache_key(self, anon_map: Any, bindparams: Any) -> Any:
return ColumnElement._gen_cache_key(self, anon_map, bindparams)
def __init__(
""" # noqa
- def __init__(self, **opts):
- self.__names = []
+ def __init__(self, **opts: Any):
+ self.__names: List[str] = []
self.opts = opts
def __getattr__(self, name: str) -> _FunctionGenerator:
def char_length(self) -> Type[char_length]:
...
- @property
- def coalesce(self) -> Type[coalesce[Any]]:
+ # appease mypy which seems to not want to accept _T from
+ # _ColumnExpressionArgument, as it includes non-generic types
+
+ @overload
+ def coalesce(
+ self,
+ col: ColumnElement[_T],
+ *args: _ColumnExpressionArgument[Any],
+ **kwargs: Any,
+ ) -> coalesce[_T]:
+ ...
+
+ @overload
+ def coalesce(
+ self,
+ col: _ColumnExpressionArgument[_T],
+ *args: _ColumnExpressionArgument[Any],
+ **kwargs: Any,
+ ) -> coalesce[_T]:
+ ...
+
+ def coalesce(
+ self,
+ col: _ColumnExpressionArgument[_T],
+ *args: _ColumnExpressionArgument[Any],
+ **kwargs: Any,
+ ) -> coalesce[_T]:
...
@property
def localtimestamp(self) -> Type[localtimestamp]:
...
- @property
- def max(self) -> Type[max[Any]]: # noqa: A001
+ # appease mypy which seems to not want to accept _T from
+ # _ColumnExpressionArgument, as it includes non-generic types
+
+ @overload
+ def max( # noqa: A001
+ self,
+ col: ColumnElement[_T],
+ *args: _ColumnExpressionArgument[Any],
+ **kwargs: Any,
+ ) -> max[_T]:
...
- @property
- def min(self) -> Type[min[Any]]: # noqa: A001
+ @overload
+ def max( # noqa: A001
+ self,
+ col: _ColumnExpressionArgument[_T],
+ *args: _ColumnExpressionArgument[Any],
+ **kwargs: Any,
+ ) -> max[_T]:
+ ...
+
+ def max( # noqa: A001
+ self,
+ col: _ColumnExpressionArgument[_T],
+ *args: _ColumnExpressionArgument[Any],
+ **kwargs: Any,
+ ) -> max[_T]:
+ ...
+
+ # appease mypy which seems to not want to accept _T from
+ # _ColumnExpressionArgument, as it includes non-generic types
+
+ @overload
+ def min( # noqa: A001
+ self,
+ col: ColumnElement[_T],
+ *args: _ColumnExpressionArgument[Any],
+ **kwargs: Any,
+ ) -> min[_T]:
+ ...
+
+ @overload
+ def min( # noqa: A001
+ self,
+ col: _ColumnExpressionArgument[_T],
+ *args: _ColumnExpressionArgument[Any],
+ **kwargs: Any,
+ ) -> min[_T]:
+ ...
+
+ def min( # noqa: A001
+ self,
+ col: _ColumnExpressionArgument[_T],
+ *args: _ColumnExpressionArgument[Any],
+ **kwargs: Any,
+ ) -> min[_T]:
...
@property
def rank(self) -> Type[rank]:
...
- @property
- def returntypefromargs(self) -> Type[ReturnTypeFromArgs[Any]]:
- ...
-
@property
def rollup(self) -> Type[rollup[Any]]:
...
def session_user(self) -> Type[session_user]:
...
- @property
- def sum(self) -> Type[sum[Any]]: # noqa: A001
+ # appease mypy which seems to not want to accept _T from
+ # _ColumnExpressionArgument, as it includes non-generic types
+
+ @overload
+ def sum( # noqa: A001
+ self,
+ col: ColumnElement[_T],
+ *args: _ColumnExpressionArgument[Any],
+ **kwargs: Any,
+ ) -> sum[_T]:
+ ...
+
+ @overload
+ def sum( # noqa: A001
+ self,
+ col: _ColumnExpressionArgument[_T],
+ *args: _ColumnExpressionArgument[Any],
+ **kwargs: Any,
+ ) -> sum[_T]:
+ ...
+
+ def sum( # noqa: A001
+ self,
+ col: _ColumnExpressionArgument[_T],
+ *args: _ColumnExpressionArgument[Any],
+ **kwargs: Any,
+ ) -> sum[_T]:
...
@property
"""
+ @overload
+ def __init__(
+ self,
+ name: str,
+ *clauses: _ColumnExpressionOrLiteralArgument[_T],
+ type_: None = ...,
+ packagenames: Optional[Tuple[str, ...]] = ...,
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self,
+ name: str,
+ *clauses: _ColumnExpressionOrLiteralArgument[Any],
+ type_: _TypeEngineArgument[_T] = ...,
+ packagenames: Optional[Tuple[str, ...]] = ...,
+ ):
+ ...
+
def __init__(
self,
name: str,
- *clauses: Any,
+ *clauses: _ColumnExpressionOrLiteralArgument[Any],
type_: Optional[_TypeEngineArgument[_T]] = None,
packagenames: Optional[Tuple[str, ...]] = None,
):
FunctionElement.__init__(self, *clauses)
- def _bind_param(self, operator, obj, type_=None, **kw):
+ def _bind_param(
+ self,
+ operator: OperatorType,
+ obj: Any,
+ type_: Optional[TypeEngine[_T]] = None,
+ expanding: bool = False,
+ **kw: Any,
+ ) -> BindParameter[_T]:
return BindParameter(
self.name,
obj,
_compared_to_type=self.type,
type_=type_,
unique=True,
+ expanding=expanding,
**kw,
)
# Set _register to True to register child classes by default
cls._register = True
- def __init__(self, *args, **kwargs):
+ def __init__(
+ self, *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any
+ ):
parsed_args = kwargs.pop("_parsed_args", None)
if parsed_args is None:
parsed_args = [
)
-register_function("cast", Cast)
-register_function("extract", Extract)
+register_function("cast", Cast) # type: ignore
+register_function("extract", Extract) # type: ignore
class next_value(GenericFunction[int]):
("sequence", InternalTraversal.dp_named_ddl_element)
]
- def __init__(self, seq, **kw):
+ def __init__(self, seq: schema.Sequence, **kw: Any):
assert isinstance(
seq, schema.Sequence
), "next_value() accepts a Sequence object as input."
seq.data_type or getattr(self, "type", None)
)
- def compare(self, other, **kw):
+ def compare(self, other: Any, **kw: Any) -> bool:
return (
isinstance(other, next_value)
and self.sequence.name == other.sequence.name
)
@property
- def _from_objects(self):
+ def _from_objects(self) -> Any:
return []
inherit_cache = True
- def __init__(self, *args, **kwargs):
+ def __init__(self, *args: _ColumnExpressionArgument[Any], **kwargs: Any):
GenericFunction.__init__(self, *args, **kwargs)
inherit_cache = True
- def __init__(self, *args, **kwargs):
- fn_args = [
+ # appease mypy which seems to not want to accept _T from
+ # _ColumnExpressionArgument, as it includes non-generic types
+
+ @overload
+ def __init__(
+ self,
+ col: ColumnElement[_T],
+ *args: _ColumnExpressionArgument[Any],
+ **kwargs: Any,
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self,
+ col: _ColumnExpressionArgument[_T],
+ *args: _ColumnExpressionArgument[Any],
+ **kwargs: Any,
+ ):
+ ...
+
+ def __init__(self, *args: _ColumnExpressionArgument[Any], **kwargs: Any):
+ fn_args: Sequence[ColumnElement[Any]] = [
coercions.expect(
roles.ExpressionElementRole,
c,
type = sqltypes.Integer()
inherit_cache = True
- def __init__(self, arg, **kw):
+ def __init__(self, arg: _ColumnExpressionArgument[str], **kw: Any):
# slight hack to limit to just one positional argument
# not sure why this one function has this special treatment
super().__init__(arg, **kw)
type = sqltypes.Integer()
inherit_cache = True
- def __init__(self, expression=None, **kwargs):
+ def __init__(
+ self,
+ expression: Optional[_ColumnExpressionArgument[Any]] = None,
+ **kwargs: Any,
+ ):
if expression is None:
expression = literal_column("*")
super().__init__(expression, **kwargs)
inherit_cache = True
- def __init__(self, *args, **kwargs):
- fn_args = [
+ def __init__(self, *args: _ColumnExpressionArgument[Any], **kwargs: Any):
+ fn_args: Sequence[ColumnElement[Any]] = [
coercions.expect(
roles.ExpressionElementRole, c, apply_propagate_attrs=self
)
array_for_multi_clause = False
inherit_cache = True
- def within_group_type(self, within_group):
+ def within_group_type(
+ self, within_group: WithinGroup[Any]
+ ) -> TypeEngine[Any]:
func_clauses = cast(ClauseList, self.clause_expr.element)
- order_by = sqlutil.unwrap_order_by(within_group.order_by)
+ order_by: Sequence[ColumnElement[Any]] = sqlutil.unwrap_order_by(
+ within_group.order_by
+ )
if self.array_for_multi_clause and len(func_clauses.clauses) > 1:
return sqltypes.ARRAY(order_by[0].type)
else:
_has_args = True
inherit_cache = True
- def __init__(self, clause, separator):
+ def __init__(self, clause: _ColumnExpressionArgument[Any], separator: str):
super().__init__(clause, separator)
from sqlalchemy import column
from sqlalchemy import func
+from sqlalchemy import Integer
from sqlalchemy import select
+from sqlalchemy import Sequence
+from sqlalchemy import String
# START GENERATED FUNCTION TYPING TESTS
# code within this block is **programmatically,
# statically generated** by tools/generate_sql_functions.py
-stmt1 = select(func.aggregate_strings(column("x"), column("x")))
+stmt1 = select(func.aggregate_strings(column("x", String), ","))
# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
reveal_type(stmt1)
reveal_type(stmt2)
-stmt3 = select(func.concat())
+stmt3 = select(func.coalesce(column("x", Integer)))
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
reveal_type(stmt3)
-stmt4 = select(func.count(column("x")))
+stmt4 = select(func.concat())
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
reveal_type(stmt4)
-stmt5 = select(func.cume_dist())
+stmt5 = select(func.count(column("x")))
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*Decimal\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
reveal_type(stmt5)
-stmt6 = select(func.current_date())
+stmt6 = select(func.cume_dist())
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*date\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*Decimal\]\]
reveal_type(stmt6)
-stmt7 = select(func.current_time())
+stmt7 = select(func.current_date())
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*time\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*date\]\]
reveal_type(stmt7)
-stmt8 = select(func.current_timestamp())
+stmt8 = select(func.current_time())
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*time\]\]
reveal_type(stmt8)
-stmt9 = select(func.current_user())
+stmt9 = select(func.current_timestamp())
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
reveal_type(stmt9)
-stmt10 = select(func.dense_rank())
+stmt10 = select(func.current_user())
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
reveal_type(stmt10)
-stmt11 = select(func.localtime())
+stmt11 = select(func.dense_rank())
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
reveal_type(stmt11)
-stmt12 = select(func.localtimestamp())
+stmt12 = select(func.localtime())
# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
reveal_type(stmt12)
-stmt13 = select(func.next_value(column("x")))
+stmt13 = select(func.localtimestamp())
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
reveal_type(stmt13)
-stmt14 = select(func.now())
+stmt14 = select(func.max(column("x", Integer)))
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
reveal_type(stmt14)
-stmt15 = select(func.percent_rank())
+stmt15 = select(func.min(column("x", Integer)))
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*Decimal\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
reveal_type(stmt15)
-stmt16 = select(func.rank())
+stmt16 = select(func.next_value(Sequence("x_seq")))
# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
reveal_type(stmt16)
-stmt17 = select(func.session_user())
+stmt17 = select(func.now())
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
reveal_type(stmt17)
-stmt18 = select(func.sysdate())
+stmt18 = select(func.percent_rank())
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*Decimal\]\]
reveal_type(stmt18)
-stmt19 = select(func.user())
+stmt19 = select(func.rank())
-# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
reveal_type(stmt19)
+
+stmt20 = select(func.session_user())
+
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
+reveal_type(stmt20)
+
+
+stmt21 = select(func.sum(column("x", Integer)))
+
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
+reveal_type(stmt21)
+
+
+stmt22 = select(func.sysdate())
+
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
+reveal_type(stmt22)
+
+
+stmt23 = select(func.user())
+
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
+reveal_type(stmt23)
+
# END GENERATED FUNCTION TYPING TESTS
import textwrap
from sqlalchemy.sql.functions import _registry
+from sqlalchemy.sql.functions import ReturnTypeFromArgs
from sqlalchemy.types import TypeEngine
from sqlalchemy.util.tool_support import code_writer_cmd
def _fns_in_deterministic_order():
reg = _registry["_default"]
for key in sorted(reg):
- yield key, reg[key]
+ cls = reg[key]
+ if cls is ReturnTypeFromArgs:
+ continue
+ yield key, cls
def process_functions(filename: str, cmd: code_writer_cmd) -> str:
for key, fn_class in _fns_in_deterministic_order():
is_reserved_word = key in builtins
- guess_its_generic = bool(fn_class.__parameters__)
+ if issubclass(fn_class, ReturnTypeFromArgs):
+ buf.write(
+ textwrap.indent(
+ f"""
+
+# appease mypy which seems to not want to accept _T from
+# _ColumnExpressionArgument, as it includes non-generic types
+
+@overload
+def {key}( {' # noqa: A001' if is_reserved_word else ''}
+ self,
+ col: ColumnElement[_T],
+ *args: _ColumnExpressionArgument[Any],
+ **kwargs: Any,
+) -> {fn_class.__name__}[_T]:
+ ...
- buf.write(
- textwrap.indent(
- f"""
+@overload
+def {key}( {' # noqa: A001' if is_reserved_word else ''}
+ self,
+ col: _ColumnExpressionArgument[_T],
+ *args: _ColumnExpressionArgument[Any],
+ **kwargs: Any,
+) -> {fn_class.__name__}[_T]:
+ ...
+
+def {key}( {' # noqa: A001' if is_reserved_word else ''}
+ self,
+ col: _ColumnExpressionArgument[_T],
+ *args: _ColumnExpressionArgument[Any],
+ **kwargs: Any,
+) -> {fn_class.__name__}[_T]:
+ ...
+
+ """,
+ indent,
+ )
+ )
+ else:
+ guess_its_generic = bool(fn_class.__parameters__)
+
+ # the latest flake8 is quite broken here:
+ # 1. it insists on linting f-strings, no option
+ # to turn it off
+ # 2. the f-string indentation rules are either broken
+ # or completely impossible to figure out
+ # 3. there's no way to E501 a too-long f-string,
+ # so I can't even put the expressions all one line
+ # to get around the indentation errors
+ # 4. Therefore here I have to concat part of the
+ # string outside of the f-string
+ _type = fn_class.__name__
+ _type += "[Any]" if guess_its_generic else ""
+ _reserved_word = (
+ " # noqa: A001" if is_reserved_word else ""
+ )
+
+ # now the f-string
+ buf.write(
+ textwrap.indent(
+ f"""
@property
-def {key}(self) -> Type[{fn_class.__name__}{
- '[Any]' if guess_its_generic else ''
-}]:{
- ' # noqa: A001' if is_reserved_word else ''
-}
+def {key}(self) -> Type[{_type}]:{_reserved_word}
...
""",
- indent,
+ indent,
+ )
)
- )
m = re.match(
r"^( *)# START GENERATED FUNCTION TYPING TESTS",
count = 0
for key, fn_class in _fns_in_deterministic_order():
- if hasattr(fn_class, "type") and isinstance(
+ if issubclass(fn_class, ReturnTypeFromArgs):
+ count += 1
+
+ buf.write(
+ textwrap.indent(
+ rf"""
+stmt{count} = select(func.{key}(column('x', Integer)))
+
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
+reveal_type(stmt{count})
+
+""",
+ indent,
+ )
+ )
+ elif fn_class.__name__ == "aggregate_strings":
+ count += 1
+ buf.write(
+ textwrap.indent(
+ rf"""
+stmt{count} = select(func.{key}(column('x', String), ','))
+
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
+reveal_type(stmt{count})
+
+""",
+ indent,
+ )
+ )
+
+ elif hasattr(fn_class, "type") and isinstance(
fn_class.type, TypeEngine
):
python_type = fn_class.type.python_type
python_expr = rf"Tuple\[.*{python_type.__name__}\]"
argspec = inspect.getfullargspec(fn_class)
- args = ", ".join(
- 'column("x")' for elem in argspec.args[1:]
- )
+ if fn_class.__name__ == "next_value":
+ args = "Sequence('x_seq')"
+ else:
+ args = ", ".join(
+ 'column("x")' for elem in argspec.args[1:]
+ )
count += 1
buf.write(