]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
generate stubs for func known functions
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 20 Jan 2023 20:17:44 +0000 (15:17 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 23 Jan 2023 16:17:38 +0000 (11:17 -0500)
Added typing for the built-in generic functions that are available from the
:data:`_sql.func` namespace, which accept a particular set of arguments and
return a particular type, such as for :class:`_sql.count`,
:class:`_sql.current_timestamp`, etc.

Fixes: #9129
Change-Id: I1a2e0dcca3048c77e84dc786843a7df05c457dfa

doc/build/changelog/unreleased_20/9129.rst [new file with mode: 0644]
lib/sqlalchemy/sql/functions.py
test/ext/mypy/plain_files/functions.py [new file with mode: 0644]
tools/generate_sql_functions.py [new file with mode: 0644]
tox.ini

diff --git a/doc/build/changelog/unreleased_20/9129.rst b/doc/build/changelog/unreleased_20/9129.rst
new file mode 100644 (file)
index 0000000..7aa13c5
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, typing
+    :tickets: 9129
+
+    Added typing for the built-in generic functions that are available from the
+    :data:`_sql.func` namespace, which accept a particular set of arguments and
+    return a particular type, such as for :class:`_sql.count`,
+    :class:`_sql.current_timestamp`, etc.
index 26929761aa02c0add4969ca859af260a0f37834f..6054be98a7675f55fe9aa192d2da9047b8df47fd 100644 (file)
@@ -910,6 +910,155 @@ class _FunctionGenerator:
             self.__names[-1], packagenames=tuple(self.__names[0:-1]), *c, **o
         )
 
+    if TYPE_CHECKING:
+
+        # START GENERATED FUNCTION ACCESSORS
+
+        # code within this block is **programmatically,
+        # statically generated** by tools/generate_sql_functions.py
+
+        @property
+        def ansifunction(self) -> Type[AnsiFunction[Any]]:
+            ...
+
+        @property
+        def array_agg(self) -> Type[array_agg[Any]]:
+            ...
+
+        @property
+        def cast(self) -> Type[Cast[Any]]:
+            ...
+
+        @property
+        def char_length(self) -> Type[char_length]:
+            ...
+
+        @property
+        def coalesce(self) -> Type[coalesce[Any]]:
+            ...
+
+        @property
+        def concat(self) -> Type[concat]:
+            ...
+
+        @property
+        def count(self) -> Type[count]:
+            ...
+
+        @property
+        def cube(self) -> Type[cube[Any]]:
+            ...
+
+        @property
+        def cume_dist(self) -> Type[cume_dist[Any]]:
+            ...
+
+        @property
+        def current_date(self) -> Type[current_date]:
+            ...
+
+        @property
+        def current_time(self) -> Type[current_time]:
+            ...
+
+        @property
+        def current_timestamp(self) -> Type[current_timestamp]:
+            ...
+
+        @property
+        def current_user(self) -> Type[current_user]:
+            ...
+
+        @property
+        def dense_rank(self) -> Type[dense_rank]:
+            ...
+
+        @property
+        def extract(self) -> Type[Extract]:
+            ...
+
+        @property
+        def grouping_sets(self) -> Type[grouping_sets[Any]]:
+            ...
+
+        @property
+        def localtime(self) -> Type[localtime]:
+            ...
+
+        @property
+        def localtimestamp(self) -> Type[localtimestamp]:
+            ...
+
+        @property
+        def max(self) -> Type[max[Any]]:  # noqa: A001
+            ...
+
+        @property
+        def min(self) -> Type[min[Any]]:  # noqa: A001
+            ...
+
+        @property
+        def mode(self) -> Type[mode[Any]]:
+            ...
+
+        @property
+        def next_value(self) -> Type[next_value]:
+            ...
+
+        @property
+        def now(self) -> Type[now]:
+            ...
+
+        @property
+        def orderedsetagg(self) -> Type[OrderedSetAgg[Any]]:
+            ...
+
+        @property
+        def percent_rank(self) -> Type[percent_rank[Any]]:
+            ...
+
+        @property
+        def percentile_cont(self) -> Type[percentile_cont[Any]]:
+            ...
+
+        @property
+        def percentile_disc(self) -> Type[percentile_disc[Any]]:
+            ...
+
+        @property
+        def random(self) -> Type[random]:
+            ...
+
+        @property
+        def rank(self) -> Type[rank]:
+            ...
+
+        @property
+        def returntypefromargs(self) -> Type[ReturnTypeFromArgs[Any]]:
+            ...
+
+        @property
+        def rollup(self) -> Type[rollup[Any]]:
+            ...
+
+        @property
+        def session_user(self) -> Type[session_user]:
+            ...
+
+        @property
+        def sum(self) -> Type[sum[Any]]:  # noqa: A001
+            ...
+
+        @property
+        def sysdate(self) -> Type[sysdate]:
+            ...
+
+        @property
+        def user(self) -> Type[user]:
+            ...
+
+        # END GENERATED FUNCTION ACCESSORS
+
 
 func = _FunctionGenerator()
 func.__doc__ = _FunctionGenerator.__doc__
diff --git a/test/ext/mypy/plain_files/functions.py b/test/ext/mypy/plain_files/functions.py
new file mode 100644 (file)
index 0000000..ecd4040
--- /dev/null
@@ -0,0 +1,119 @@
+"""this file is generated by tools/generate_sql_functions.py"""
+
+from sqlalchemy import column
+from sqlalchemy import func
+from sqlalchemy import select
+
+# START GENERATED FUNCTION TYPING TESTS
+
+# code within this block is **programmatically,
+# statically generated** by tools/generate_sql_functions.py
+
+stmt1 = select(func.char_length(column("x")))
+
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
+reveal_type(stmt1)
+
+
+stmt2 = select(func.concat())
+
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
+reveal_type(stmt2)
+
+
+stmt3 = select(func.count(column("x")))
+
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
+reveal_type(stmt3)
+
+
+stmt4 = select(func.cume_dist())
+
+# EXPECTED_RE_TYPE: .*Select\[Any\]
+reveal_type(stmt4)
+
+
+stmt5 = select(func.current_date())
+
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*date\]\]
+reveal_type(stmt5)
+
+
+stmt6 = select(func.current_time())
+
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*time\]\]
+reveal_type(stmt6)
+
+
+stmt7 = select(func.current_timestamp())
+
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
+reveal_type(stmt7)
+
+
+stmt8 = select(func.current_user())
+
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
+reveal_type(stmt8)
+
+
+stmt9 = select(func.dense_rank())
+
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
+reveal_type(stmt9)
+
+
+stmt10 = select(func.localtime())
+
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
+reveal_type(stmt10)
+
+
+stmt11 = select(func.localtimestamp())
+
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
+reveal_type(stmt11)
+
+
+stmt12 = select(func.next_value(column("x")))
+
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
+reveal_type(stmt12)
+
+
+stmt13 = select(func.now())
+
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
+reveal_type(stmt13)
+
+
+stmt14 = select(func.percent_rank())
+
+# EXPECTED_RE_TYPE: .*Select\[Any\]
+reveal_type(stmt14)
+
+
+stmt15 = select(func.rank())
+
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\]
+reveal_type(stmt15)
+
+
+stmt16 = select(func.session_user())
+
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
+reveal_type(stmt16)
+
+
+stmt17 = select(func.sysdate())
+
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\]
+reveal_type(stmt17)
+
+
+stmt18 = select(func.user())
+
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\]
+reveal_type(stmt18)
+
+# END GENERATED FUNCTION TYPING TESTS
diff --git a/tools/generate_sql_functions.py b/tools/generate_sql_functions.py
new file mode 100644 (file)
index 0000000..d207c62
--- /dev/null
@@ -0,0 +1,160 @@
+"""Generate inline stubs for generic functions on func
+
+"""
+# mypy: ignore-errors
+
+from __future__ import annotations
+
+from decimal import Decimal
+import inspect
+import re
+from tempfile import NamedTemporaryFile
+import textwrap
+from typing import Any
+
+from sqlalchemy.sql.functions import _registry
+from sqlalchemy.types import TypeEngine
+from sqlalchemy.util.tool_support import code_writer_cmd
+
+
+def _fns_in_deterministic_order():
+    reg = _registry["_default"]
+    for key in sorted(reg):
+        yield key, reg[key]
+
+
+def process_functions(filename: str, cmd: code_writer_cmd) -> str:
+
+    with NamedTemporaryFile(
+        mode="w",
+        delete=False,
+        suffix=".py",
+    ) as buf, open(filename) as orig_py:
+        indent = ""
+        in_block = False
+
+        for line in orig_py:
+            m = re.match(
+                r"^( *)# START GENERATED FUNCTION ACCESSORS",
+                line,
+            )
+            if m:
+                in_block = True
+                buf.write(line)
+                indent = m.group(1)
+                buf.write(
+                    textwrap.indent(
+                        """
+# code within this block is **programmatically,
+# statically generated** by tools/generate_sql_functions.py
+""",
+                        indent,
+                    )
+                )
+
+                builtins = set(dir(__builtins__))
+                for key, fn_class in _fns_in_deterministic_order():
+                    is_reserved_word = key in builtins
+
+                    guess_its_generic = bool(fn_class.__parameters__)
+
+                    buf.write(
+                        textwrap.indent(
+                            f"""
+@property
+def {key}(self) -> Type[{fn_class.__name__}{
+    '[Any]' if guess_its_generic else ''
+}]:{
+     '  # noqa: A001' if is_reserved_word else ''
+}
+    ...
+
+""",
+                            indent,
+                        )
+                    )
+
+            m = re.match(
+                r"^( *)# START GENERATED FUNCTION TYPING TESTS",
+                line,
+            )
+            if m:
+                in_block = True
+                buf.write(line)
+                indent = m.group(1)
+
+                buf.write(
+                    textwrap.indent(
+                        """
+# code within this block is **programmatically,
+# statically generated** by tools/generate_sql_functions.py
+""",
+                        indent,
+                    )
+                )
+
+                count = 0
+                for key, fn_class in _fns_in_deterministic_order():
+                    if hasattr(fn_class, "type") and isinstance(
+                        fn_class.type, TypeEngine
+                    ):
+                        python_type = fn_class.type.python_type
+
+                        # TODO: numeric types don't seem to be coming out
+                        # at the moment, because Numeric is typed generically
+                        # in that it can return Decimal or float. We would need
+                        # to further break out Numeric / Float into types
+                        # that type out as returning an exact Decimal or float
+                        if python_type is Decimal:
+                            python_type = Any
+                            python_expr = f"{python_type.__name__}"
+                        else:
+                            python_expr = rf"Tuple\[.*{python_type.__name__}\]"
+                        argspec = inspect.getfullargspec(fn_class)
+                        args = ", ".join(
+                            'column("x")' for elem in argspec.args[1:]
+                        )
+                        count += 1
+
+                        buf.write(
+                            textwrap.indent(
+                                rf"""
+stmt{count} = select(func.{key}({args}))
+
+# EXPECTED_RE_TYPE: .*Select\[{python_expr}\]
+reveal_type(stmt{count})
+
+""",
+                                indent,
+                            )
+                        )
+
+            if in_block and line.startswith(
+                f"{indent}# END GENERATED FUNCTION"
+            ):
+                in_block = False
+
+            if not in_block:
+                buf.write(line)
+    return buf.name
+
+
+def main(cmd: code_writer_cmd) -> None:
+    for path in [functions_py, test_functions_py]:
+        destination_path = path
+        tempfile = process_functions(destination_path, cmd)
+        cmd.run_zimports(tempfile)
+        cmd.run_black(tempfile)
+        cmd.write_output_file_from_tempfile(tempfile, destination_path)
+
+
+functions_py = "lib/sqlalchemy/sql/functions.py"
+test_functions_py = "test/ext/mypy/plain_files/functions.py"
+
+
+if __name__ == "__main__":
+
+    cmd = code_writer_cmd(__file__)
+
+    with cmd.run_program():
+        main(cmd)
diff --git a/tox.ini b/tox.ini
index 144de79a7995ce12d893c36211cd4faa595677c7..503b3b8dd478dd467f12e835f22597d2f163f6a0 100644 (file)
--- a/tox.ini
+++ b/tox.ini
@@ -217,6 +217,7 @@ commands =
      python ./tools/generate_tuple_map_overloads.py --check
      python ./tools/generate_proxy_methods.py --check
      python ./tools/sync_test_files.py --check
+     python ./tools/generate_sql_functions.py --check
 
 
 # "pep8" env was renamed to "lint".