From: Mike Bayer Date: Fri, 20 Jan 2023 20:17:44 +0000 (-0500) Subject: generate stubs for func known functions X-Git-Tag: rel_2_0_0~19^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=6499098e36497d15d5972696983ce0ae4cc99409;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git generate stubs for func known functions Added typing for the built-in generic functions that are available from the :data:`_sql.func` namespace, which accept a particular set of arguments and return a particular type, such as for :class:`_sql.count`, :class:`_sql.current_timestamp`, etc. Fixes: #9129 Change-Id: I1a2e0dcca3048c77e84dc786843a7df05c457dfa --- diff --git a/doc/build/changelog/unreleased_20/9129.rst b/doc/build/changelog/unreleased_20/9129.rst new file mode 100644 index 0000000000..7aa13c51cd --- /dev/null +++ b/doc/build/changelog/unreleased_20/9129.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: bug, typing + :tickets: 9129 + + Added typing for the built-in generic functions that are available from the + :data:`_sql.func` namespace, which accept a particular set of arguments and + return a particular type, such as for :class:`_sql.count`, + :class:`_sql.current_timestamp`, etc. diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index 26929761aa..6054be98a7 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -910,6 +910,155 @@ class _FunctionGenerator: self.__names[-1], packagenames=tuple(self.__names[0:-1]), *c, **o ) + if TYPE_CHECKING: + + # START GENERATED FUNCTION ACCESSORS + + # code within this block is **programmatically, + # statically generated** by tools/generate_sql_functions.py + + @property + def ansifunction(self) -> Type[AnsiFunction[Any]]: + ... + + @property + def array_agg(self) -> Type[array_agg[Any]]: + ... + + @property + def cast(self) -> Type[Cast[Any]]: + ... + + @property + def char_length(self) -> Type[char_length]: + ... + + @property + def coalesce(self) -> Type[coalesce[Any]]: + ... + + @property + def concat(self) -> Type[concat]: + ... + + @property + def count(self) -> Type[count]: + ... + + @property + def cube(self) -> Type[cube[Any]]: + ... + + @property + def cume_dist(self) -> Type[cume_dist[Any]]: + ... + + @property + def current_date(self) -> Type[current_date]: + ... + + @property + def current_time(self) -> Type[current_time]: + ... + + @property + def current_timestamp(self) -> Type[current_timestamp]: + ... + + @property + def current_user(self) -> Type[current_user]: + ... + + @property + def dense_rank(self) -> Type[dense_rank]: + ... + + @property + def extract(self) -> Type[Extract]: + ... + + @property + def grouping_sets(self) -> Type[grouping_sets[Any]]: + ... + + @property + def localtime(self) -> Type[localtime]: + ... + + @property + def localtimestamp(self) -> Type[localtimestamp]: + ... + + @property + def max(self) -> Type[max[Any]]: # noqa: A001 + ... + + @property + def min(self) -> Type[min[Any]]: # noqa: A001 + ... + + @property + def mode(self) -> Type[mode[Any]]: + ... + + @property + def next_value(self) -> Type[next_value]: + ... + + @property + def now(self) -> Type[now]: + ... + + @property + def orderedsetagg(self) -> Type[OrderedSetAgg[Any]]: + ... + + @property + def percent_rank(self) -> Type[percent_rank[Any]]: + ... + + @property + def percentile_cont(self) -> Type[percentile_cont[Any]]: + ... + + @property + def percentile_disc(self) -> Type[percentile_disc[Any]]: + ... + + @property + def random(self) -> Type[random]: + ... + + @property + def rank(self) -> Type[rank]: + ... + + @property + def returntypefromargs(self) -> Type[ReturnTypeFromArgs[Any]]: + ... + + @property + def rollup(self) -> Type[rollup[Any]]: + ... + + @property + def session_user(self) -> Type[session_user]: + ... + + @property + def sum(self) -> Type[sum[Any]]: # noqa: A001 + ... + + @property + def sysdate(self) -> Type[sysdate]: + ... + + @property + def user(self) -> Type[user]: + ... + + # END GENERATED FUNCTION ACCESSORS + func = _FunctionGenerator() func.__doc__ = _FunctionGenerator.__doc__ diff --git a/test/ext/mypy/plain_files/functions.py b/test/ext/mypy/plain_files/functions.py new file mode 100644 index 0000000000..ecd404010e --- /dev/null +++ b/test/ext/mypy/plain_files/functions.py @@ -0,0 +1,119 @@ +"""this file is generated by tools/generate_sql_functions.py""" + +from sqlalchemy import column +from sqlalchemy import func +from sqlalchemy import select + +# START GENERATED FUNCTION TYPING TESTS + +# code within this block is **programmatically, +# statically generated** by tools/generate_sql_functions.py + +stmt1 = select(func.char_length(column("x"))) + +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] +reveal_type(stmt1) + + +stmt2 = select(func.concat()) + +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\] +reveal_type(stmt2) + + +stmt3 = select(func.count(column("x"))) + +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] +reveal_type(stmt3) + + +stmt4 = select(func.cume_dist()) + +# EXPECTED_RE_TYPE: .*Select\[Any\] +reveal_type(stmt4) + + +stmt5 = select(func.current_date()) + +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*date\]\] +reveal_type(stmt5) + + +stmt6 = select(func.current_time()) + +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*time\]\] +reveal_type(stmt6) + + +stmt7 = select(func.current_timestamp()) + +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\] +reveal_type(stmt7) + + +stmt8 = select(func.current_user()) + +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\] +reveal_type(stmt8) + + +stmt9 = select(func.dense_rank()) + +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] +reveal_type(stmt9) + + +stmt10 = select(func.localtime()) + +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\] +reveal_type(stmt10) + + +stmt11 = select(func.localtimestamp()) + +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\] +reveal_type(stmt11) + + +stmt12 = select(func.next_value(column("x"))) + +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] +reveal_type(stmt12) + + +stmt13 = select(func.now()) + +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\] +reveal_type(stmt13) + + +stmt14 = select(func.percent_rank()) + +# EXPECTED_RE_TYPE: .*Select\[Any\] +reveal_type(stmt14) + + +stmt15 = select(func.rank()) + +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] +reveal_type(stmt15) + + +stmt16 = select(func.session_user()) + +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\] +reveal_type(stmt16) + + +stmt17 = select(func.sysdate()) + +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\] +reveal_type(stmt17) + + +stmt18 = select(func.user()) + +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\] +reveal_type(stmt18) + +# END GENERATED FUNCTION TYPING TESTS diff --git a/tools/generate_sql_functions.py b/tools/generate_sql_functions.py new file mode 100644 index 0000000000..d207c62bcd --- /dev/null +++ b/tools/generate_sql_functions.py @@ -0,0 +1,160 @@ +"""Generate inline stubs for generic functions on func + +""" +# mypy: ignore-errors + +from __future__ import annotations + +from decimal import Decimal +import inspect +import re +from tempfile import NamedTemporaryFile +import textwrap +from typing import Any + +from sqlalchemy.sql.functions import _registry +from sqlalchemy.types import TypeEngine +from sqlalchemy.util.tool_support import code_writer_cmd + + +def _fns_in_deterministic_order(): + reg = _registry["_default"] + for key in sorted(reg): + yield key, reg[key] + + +def process_functions(filename: str, cmd: code_writer_cmd) -> str: + + with NamedTemporaryFile( + mode="w", + delete=False, + suffix=".py", + ) as buf, open(filename) as orig_py: + indent = "" + in_block = False + + for line in orig_py: + m = re.match( + r"^( *)# START GENERATED FUNCTION ACCESSORS", + line, + ) + if m: + in_block = True + buf.write(line) + indent = m.group(1) + buf.write( + textwrap.indent( + """ +# code within this block is **programmatically, +# statically generated** by tools/generate_sql_functions.py +""", + indent, + ) + ) + + builtins = set(dir(__builtins__)) + for key, fn_class in _fns_in_deterministic_order(): + is_reserved_word = key in builtins + + guess_its_generic = bool(fn_class.__parameters__) + + buf.write( + textwrap.indent( + f""" +@property +def {key}(self) -> Type[{fn_class.__name__}{ + '[Any]' if guess_its_generic else '' +}]:{ + ' # noqa: A001' if is_reserved_word else '' +} + ... + +""", + indent, + ) + ) + + m = re.match( + r"^( *)# START GENERATED FUNCTION TYPING TESTS", + line, + ) + if m: + in_block = True + buf.write(line) + indent = m.group(1) + + buf.write( + textwrap.indent( + """ +# code within this block is **programmatically, +# statically generated** by tools/generate_sql_functions.py +""", + indent, + ) + ) + + count = 0 + for key, fn_class in _fns_in_deterministic_order(): + if hasattr(fn_class, "type") and isinstance( + fn_class.type, TypeEngine + ): + python_type = fn_class.type.python_type + + # TODO: numeric types don't seem to be coming out + # at the moment, because Numeric is typed generically + # in that it can return Decimal or float. We would need + # to further break out Numeric / Float into types + # that type out as returning an exact Decimal or float + if python_type is Decimal: + python_type = Any + python_expr = f"{python_type.__name__}" + else: + python_expr = rf"Tuple\[.*{python_type.__name__}\]" + argspec = inspect.getfullargspec(fn_class) + args = ", ".join( + 'column("x")' for elem in argspec.args[1:] + ) + count += 1 + + buf.write( + textwrap.indent( + rf""" +stmt{count} = select(func.{key}({args})) + +# EXPECTED_RE_TYPE: .*Select\[{python_expr}\] +reveal_type(stmt{count}) + +""", + indent, + ) + ) + + if in_block and line.startswith( + f"{indent}# END GENERATED FUNCTION" + ): + in_block = False + + if not in_block: + buf.write(line) + return buf.name + + +def main(cmd: code_writer_cmd) -> None: + for path in [functions_py, test_functions_py]: + destination_path = path + tempfile = process_functions(destination_path, cmd) + cmd.run_zimports(tempfile) + cmd.run_black(tempfile) + cmd.write_output_file_from_tempfile(tempfile, destination_path) + + +functions_py = "lib/sqlalchemy/sql/functions.py" +test_functions_py = "test/ext/mypy/plain_files/functions.py" + + +if __name__ == "__main__": + + cmd = code_writer_cmd(__file__) + + with cmd.run_program(): + main(cmd) diff --git a/tox.ini b/tox.ini index 144de79a79..503b3b8dd4 100644 --- a/tox.ini +++ b/tox.ini @@ -217,6 +217,7 @@ commands = python ./tools/generate_tuple_map_overloads.py --check python ./tools/generate_proxy_methods.py --check python ./tools/sync_test_files.py --check + python ./tools/generate_sql_functions.py --check # "pep8" env was renamed to "lint".