From 73344fd0d35bd2bf4c4bb8f2a8534a97d7f241af Mon Sep 17 00:00:00 2001 From: Federico Caselli Date: Tue, 27 Aug 2024 20:02:00 +0200 Subject: [PATCH] Fix select.with_only_columns type hints Fixes: #11782 Change-Id: Idce218a9730986d3ca70547c83aa1c0f8b5ee5b2 --- doc/build/changelog/unreleased_20/11782.rst | 5 +++ lib/sqlalchemy/sql/selectable.py | 40 +++++++++++++++++---- tools/format_docs_code.py | 21 ++++++----- tools/generate_proxy_methods.py | 29 ++++++++------- tools/generate_sql_functions.py | 1 + tools/generate_tuple_map_overloads.py | 31 ++++++++++------ tools/trace_orm_adapter.py | 1 + 7 files changed, 91 insertions(+), 37 deletions(-) create mode 100644 doc/build/changelog/unreleased_20/11782.rst diff --git a/doc/build/changelog/unreleased_20/11782.rst b/doc/build/changelog/unreleased_20/11782.rst new file mode 100644 index 0000000000..df8e1f5c3b --- /dev/null +++ b/doc/build/changelog/unreleased_20/11782.rst @@ -0,0 +1,5 @@ +.. change:: + :tags: bug, typing + :tickets: 11782 + + Fixed typing issue with :meth:`_sql.Select.with_only_columns`. diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index f38e6cea0a..958638b106 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -5838,22 +5838,35 @@ class Select( ) return woc - # START OVERLOADED FUNCTIONS self.with_only_columns Select 8 + # START OVERLOADED FUNCTIONS self.with_only_columns Select 1-8 ", *, maintain_column_froms: bool =..." # noqa: E501 # code within this block is **programmatically, - # statically generated** by tools/generate_sel_v1_overloads.py + # statically generated** by tools/generate_tuple_map_overloads.py @overload - def with_only_columns(self, __ent0: _TCCA[_T0]) -> Select[_T0]: ... + def with_only_columns( + self, __ent0: _TCCA[_T0], /, *, maintain_column_froms: bool = ... + ) -> Select[_T0]: ... @overload def with_only_columns( - self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1] + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + /, + *, + maintain_column_froms: bool = ..., ) -> Select[_T0, _T1]: ... @overload def with_only_columns( - self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + /, + *, + maintain_column_froms: bool = ..., ) -> Select[_T0, _T1, _T2]: ... @overload @@ -5863,6 +5876,9 @@ class Select( __ent1: _TCCA[_T1], __ent2: _TCCA[_T2], __ent3: _TCCA[_T3], + /, + *, + maintain_column_froms: bool = ..., ) -> Select[_T0, _T1, _T2, _T3]: ... @overload @@ -5873,6 +5889,9 @@ class Select( __ent2: _TCCA[_T2], __ent3: _TCCA[_T3], __ent4: _TCCA[_T4], + /, + *, + maintain_column_froms: bool = ..., ) -> Select[_T0, _T1, _T2, _T3, _T4]: ... @overload @@ -5884,6 +5903,9 @@ class Select( __ent3: _TCCA[_T3], __ent4: _TCCA[_T4], __ent5: _TCCA[_T5], + /, + *, + maintain_column_froms: bool = ..., ) -> Select[_T0, _T1, _T2, _T3, _T4, _T5]: ... @overload @@ -5896,6 +5918,9 @@ class Select( __ent4: _TCCA[_T4], __ent5: _TCCA[_T5], __ent6: _TCCA[_T6], + /, + *, + maintain_column_froms: bool = ..., ) -> Select[_T0, _T1, _T2, _T3, _T4, _T5, _T6]: ... @overload @@ -5909,7 +5934,10 @@ class Select( __ent5: _TCCA[_T5], __ent6: _TCCA[_T6], __ent7: _TCCA[_T7], - ) -> Select[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]: ... + /, + *entities: _ColumnsClauseArgument[Any], + maintain_column_froms: bool = ..., + ) -> Select[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, Unpack[TupleAny]]: ... # END OVERLOADED FUNCTIONS self.with_only_columns diff --git a/tools/format_docs_code.py b/tools/format_docs_code.py index 7bae0126b0..8d24a9163a 100644 --- a/tools/format_docs_code.py +++ b/tools/format_docs_code.py @@ -6,6 +6,7 @@ that it extracts from the documentation. .. versionadded:: 2.0 """ + # mypy: ignore-errors from argparse import ArgumentParser @@ -316,11 +317,13 @@ def main( print( f"{to_reformat} file(s) would be reformatted;", ( - f"{sum(formatting_error_counts)} formatting errors " - f"reported in {len(formatting_error_counts)} files" - ) - if formatting_error_counts - else "no formatting errors reported", + ( + f"{sum(formatting_error_counts)} formatting errors " + f"reported in {len(formatting_error_counts)} files" + ) + if formatting_error_counts + else "no formatting errors reported" + ), ) exit(1) @@ -388,9 +391,11 @@ Use --report-doctest to ignore errors on plain code blocks. for val in config.get("target_version", []) if val != "py27" }, - line_length=config.get("line_length", DEFAULT_LINE_LENGTH) - if args.project_line_length - else DEFAULT_LINE_LENGTH, + line_length=( + config.get("line_length", DEFAULT_LINE_LENGTH) + if args.project_line_length + else DEFAULT_LINE_LENGTH + ), ) REPORT_ONLY_DOCTEST = args.report_doctest diff --git a/tools/generate_proxy_methods.py b/tools/generate_proxy_methods.py index 9881d26426..31832ae8bf 100644 --- a/tools/generate_proxy_methods.py +++ b/tools/generate_proxy_methods.py @@ -40,6 +40,7 @@ typed by hand. .. versionadded:: 2.0 """ + # mypy: ignore-errors from __future__ import annotations @@ -85,9 +86,9 @@ class _repr_sym: return self.sym -classes: collections.defaultdict[ - str, Dict[str, Tuple[Any, ...]] -] = collections.defaultdict(dict) +classes: collections.defaultdict[str, Dict[str, Tuple[Any, ...]]] = ( + collections.defaultdict(dict) +) _T = TypeVar("_T", bound="Any") @@ -214,18 +215,22 @@ def process_class( if spec.defaults: new_defaults = tuple( - _repr_sym("util.EMPTY_DICT") - if df is util.EMPTY_DICT - else df + ( + _repr_sym("util.EMPTY_DICT") + if df is util.EMPTY_DICT + else df + ) for df in spec.defaults ) elem[3] = new_defaults if spec.kwonlydefaults: new_kwonlydefaults = { - name: _repr_sym("util.EMPTY_DICT") - if df is util.EMPTY_DICT - else df + name: ( + _repr_sym("util.EMPTY_DICT") + if df is util.EMPTY_DICT + else df + ) for name, df in spec.kwonlydefaults.items() } elem[5] = new_kwonlydefaults @@ -415,9 +420,9 @@ def main(cmd: code_writer_cmd) -> None: from sqlalchemy import util from sqlalchemy.util import langhelpers - util.create_proxy_methods = ( - langhelpers.create_proxy_methods - ) = create_proxy_methods + util.create_proxy_methods = langhelpers.create_proxy_methods = ( + create_proxy_methods + ) for entry in entries: if cmd.args.module in {"all", entry}: diff --git a/tools/generate_sql_functions.py b/tools/generate_sql_functions.py index 411cfed721..b777ae406a 100644 --- a/tools/generate_sql_functions.py +++ b/tools/generate_sql_functions.py @@ -1,6 +1,7 @@ """Generate inline stubs for generic functions on func """ + # mypy: ignore-errors from __future__ import annotations diff --git a/tools/generate_tuple_map_overloads.py b/tools/generate_tuple_map_overloads.py index 9ca648333c..a7a2eb5f43 100644 --- a/tools/generate_tuple_map_overloads.py +++ b/tools/generate_tuple_map_overloads.py @@ -16,6 +16,7 @@ combinatoric generated code approach. .. versionadded:: 2.0 """ + # mypy: ignore-errors from __future__ import annotations @@ -36,10 +37,13 @@ is_posix = os.name == "posix" sys.path.append(str(Path(__file__).parent.parent)) -def process_module(modname: str, filename: str, cmd: code_writer_cmd) -> str: +def process_module( + modname: str, filename: str, expected_number: int, cmd: code_writer_cmd +) -> str: # use tempfile in same path as the module, or at least in the # current working directory, so that black / zimports use # local pyproject.toml + found = 0 with NamedTemporaryFile( mode="w", delete=False, @@ -54,6 +58,7 @@ def process_module(modname: str, filename: str, cmd: code_writer_cmd) -> str: line, ) if m: + found += 1 indent = m.group(1) given_fnname = current_fnname = m.group(2) if current_fnname.startswith("self."): @@ -116,16 +121,20 @@ def {current_fnname}( if not in_block: buf.write(line) + if found != expected_number: + raise Exception( + f"{modname} processed {found}. expected {expected_number}" + ) return buf.name -def run_module(modname: str, cmd: code_writer_cmd) -> None: +def run_module(modname: str, count: int, cmd: code_writer_cmd) -> None: cmd.write_status(f"importing module {modname}\n") mod = importlib.import_module(modname) destination_path = mod.__file__ assert destination_path is not None - tempfile = process_module(modname, destination_path, cmd) + tempfile = process_module(modname, destination_path, count, cmd) cmd.run_zimports(tempfile) cmd.run_black(tempfile) @@ -133,17 +142,17 @@ def run_module(modname: str, cmd: code_writer_cmd) -> None: def main(cmd: code_writer_cmd) -> None: - for modname in entries: + for modname, count in entries: if cmd.args.module in {"all", modname}: - run_module(modname, cmd) + run_module(modname, count, cmd) entries = [ - "sqlalchemy.sql._selectable_constructors", - "sqlalchemy.orm.session", - "sqlalchemy.orm.query", - "sqlalchemy.sql.selectable", - "sqlalchemy.sql.dml", + ("sqlalchemy.sql._selectable_constructors", 1), + ("sqlalchemy.orm.session", 1), + ("sqlalchemy.orm.query", 1), + ("sqlalchemy.sql.selectable", 1), + ("sqlalchemy.sql.dml", 3), ] if __name__ == "__main__": @@ -152,7 +161,7 @@ if __name__ == "__main__": with cmd.add_arguments() as parser: parser.add_argument( "--module", - choices=entries + ["all"], + choices=[n for n, _ in entries] + ["all"], default="all", help="Which file to generate. Default is to regenerate all files", ) diff --git a/tools/trace_orm_adapter.py b/tools/trace_orm_adapter.py index de8098bcb8..966705690d 100644 --- a/tools/trace_orm_adapter.py +++ b/tools/trace_orm_adapter.py @@ -23,6 +23,7 @@ You can then set a breakpoint at the end of any adapt step: """ # noqa: E501 + # mypy: ignore-errors -- 2.47.2