]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Fix select.with_only_columns type hints
authorFederico Caselli <cfederico87@gmail.com>
Tue, 27 Aug 2024 18:02:00 +0000 (20:02 +0200)
committerFederico Caselli <cfederico87@gmail.com>
Wed, 28 Aug 2024 19:52:30 +0000 (21:52 +0200)
Fixes: #11782
Change-Id: Idce218a9730986d3ca70547c83aa1c0f8b5ee5b2

doc/build/changelog/unreleased_20/11782.rst [new file with mode: 0644]
lib/sqlalchemy/sql/selectable.py
tools/format_docs_code.py
tools/generate_proxy_methods.py
tools/generate_sql_functions.py
tools/generate_tuple_map_overloads.py
tools/trace_orm_adapter.py

diff --git a/doc/build/changelog/unreleased_20/11782.rst b/doc/build/changelog/unreleased_20/11782.rst
new file mode 100644 (file)
index 0000000..df8e1f5
--- /dev/null
@@ -0,0 +1,5 @@
+.. change::
+    :tags: bug, typing
+    :tickets: 11782
+
+    Fixed typing issue with :meth:`_sql.Select.with_only_columns`.
index f38e6cea0a55fcb3394fc2d850766b4b38f55b90..958638b106431030a2a1b52bbe49732da227bf66 100644 (file)
@@ -5838,22 +5838,35 @@ class Select(
         )
         return woc
 
-    # START OVERLOADED FUNCTIONS self.with_only_columns Select 8
+    # START OVERLOADED FUNCTIONS self.with_only_columns Select 1-8 ", *, maintain_column_froms: bool =..." # noqa: E501
 
     # code within this block is **programmatically,
-    # statically generated** by tools/generate_sel_v1_overloads.py
+    # statically generated** by tools/generate_tuple_map_overloads.py
 
     @overload
-    def with_only_columns(self, __ent0: _TCCA[_T0]) -> Select[_T0]: ...
+    def with_only_columns(
+        self, __ent0: _TCCA[_T0], /, *, maintain_column_froms: bool = ...
+    ) -> Select[_T0]: ...
 
     @overload
     def with_only_columns(
-        self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1]
+        self,
+        __ent0: _TCCA[_T0],
+        __ent1: _TCCA[_T1],
+        /,
+        *,
+        maintain_column_froms: bool = ...,
     ) -> Select[_T0, _T1]: ...
 
     @overload
     def with_only_columns(
-        self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2]
+        self,
+        __ent0: _TCCA[_T0],
+        __ent1: _TCCA[_T1],
+        __ent2: _TCCA[_T2],
+        /,
+        *,
+        maintain_column_froms: bool = ...,
     ) -> Select[_T0, _T1, _T2]: ...
 
     @overload
@@ -5863,6 +5876,9 @@ class Select(
         __ent1: _TCCA[_T1],
         __ent2: _TCCA[_T2],
         __ent3: _TCCA[_T3],
+        /,
+        *,
+        maintain_column_froms: bool = ...,
     ) -> Select[_T0, _T1, _T2, _T3]: ...
 
     @overload
@@ -5873,6 +5889,9 @@ class Select(
         __ent2: _TCCA[_T2],
         __ent3: _TCCA[_T3],
         __ent4: _TCCA[_T4],
+        /,
+        *,
+        maintain_column_froms: bool = ...,
     ) -> Select[_T0, _T1, _T2, _T3, _T4]: ...
 
     @overload
@@ -5884,6 +5903,9 @@ class Select(
         __ent3: _TCCA[_T3],
         __ent4: _TCCA[_T4],
         __ent5: _TCCA[_T5],
+        /,
+        *,
+        maintain_column_froms: bool = ...,
     ) -> Select[_T0, _T1, _T2, _T3, _T4, _T5]: ...
 
     @overload
@@ -5896,6 +5918,9 @@ class Select(
         __ent4: _TCCA[_T4],
         __ent5: _TCCA[_T5],
         __ent6: _TCCA[_T6],
+        /,
+        *,
+        maintain_column_froms: bool = ...,
     ) -> Select[_T0, _T1, _T2, _T3, _T4, _T5, _T6]: ...
 
     @overload
@@ -5909,7 +5934,10 @@ class Select(
         __ent5: _TCCA[_T5],
         __ent6: _TCCA[_T6],
         __ent7: _TCCA[_T7],
-    ) -> Select[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]: ...
+        /,
+        *entities: _ColumnsClauseArgument[Any],
+        maintain_column_froms: bool = ...,
+    ) -> Select[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, Unpack[TupleAny]]: ...
 
     # END OVERLOADED FUNCTIONS self.with_only_columns
 
index 7bae0126b027369aa987d86d12cc4d8d548df228..8d24a9163af8b4413269c698b4ee26c91b99f010 100644 (file)
@@ -6,6 +6,7 @@ that it extracts from the documentation.
 .. versionadded:: 2.0
 
 """
+
 # mypy: ignore-errors
 
 from argparse import ArgumentParser
@@ -316,11 +317,13 @@ def main(
             print(
                 f"{to_reformat} file(s) would be reformatted;",
                 (
-                    f"{sum(formatting_error_counts)} formatting errors "
-                    f"reported in {len(formatting_error_counts)} files"
-                )
-                if formatting_error_counts
-                else "no formatting errors reported",
+                    (
+                        f"{sum(formatting_error_counts)} formatting errors "
+                        f"reported in {len(formatting_error_counts)} files"
+                    )
+                    if formatting_error_counts
+                    else "no formatting errors reported"
+                ),
             )
 
             exit(1)
@@ -388,9 +391,11 @@ Use --report-doctest to ignore errors on plain code blocks.
             for val in config.get("target_version", [])
             if val != "py27"
         },
-        line_length=config.get("line_length", DEFAULT_LINE_LENGTH)
-        if args.project_line_length
-        else DEFAULT_LINE_LENGTH,
+        line_length=(
+            config.get("line_length", DEFAULT_LINE_LENGTH)
+            if args.project_line_length
+            else DEFAULT_LINE_LENGTH
+        ),
     )
     REPORT_ONLY_DOCTEST = args.report_doctest
 
index 9881d26426fdc738e9717613a7d51b1e889ffc09..31832ae8bfadec72306daa6e193f44477481a894 100644 (file)
@@ -40,6 +40,7 @@ typed by hand.
 .. versionadded:: 2.0
 
 """
+
 # mypy: ignore-errors
 
 from __future__ import annotations
@@ -85,9 +86,9 @@ class _repr_sym:
         return self.sym
 
 
-classes: collections.defaultdict[
-    str, Dict[str, Tuple[Any, ...]]
-] = collections.defaultdict(dict)
+classes: collections.defaultdict[str, Dict[str, Tuple[Any, ...]]] = (
+    collections.defaultdict(dict)
+)
 
 _T = TypeVar("_T", bound="Any")
 
@@ -214,18 +215,22 @@ def process_class(
 
             if spec.defaults:
                 new_defaults = tuple(
-                    _repr_sym("util.EMPTY_DICT")
-                    if df is util.EMPTY_DICT
-                    else df
+                    (
+                        _repr_sym("util.EMPTY_DICT")
+                        if df is util.EMPTY_DICT
+                        else df
+                    )
                     for df in spec.defaults
                 )
                 elem[3] = new_defaults
 
             if spec.kwonlydefaults:
                 new_kwonlydefaults = {
-                    name: _repr_sym("util.EMPTY_DICT")
-                    if df is util.EMPTY_DICT
-                    else df
+                    name: (
+                        _repr_sym("util.EMPTY_DICT")
+                        if df is util.EMPTY_DICT
+                        else df
+                    )
                     for name, df in spec.kwonlydefaults.items()
                 }
                 elem[5] = new_kwonlydefaults
@@ -415,9 +420,9 @@ def main(cmd: code_writer_cmd) -> None:
     from sqlalchemy import util
     from sqlalchemy.util import langhelpers
 
-    util.create_proxy_methods = (
-        langhelpers.create_proxy_methods
-    ) = create_proxy_methods
+    util.create_proxy_methods = langhelpers.create_proxy_methods = (
+        create_proxy_methods
+    )
 
     for entry in entries:
         if cmd.args.module in {"all", entry}:
index 411cfed7219943b38473753c64084f4315138b8b..b777ae406a28d628fca2b53b96e25d142edec246 100644 (file)
@@ -1,6 +1,7 @@
 """Generate inline stubs for generic functions on func
 
 """
+
 # mypy: ignore-errors
 
 from __future__ import annotations
index 9ca648333cde1d8608f079df5f39da8a30a74315..a7a2eb5f4308182b63caef4734289f27680ae7a7 100644 (file)
@@ -16,6 +16,7 @@ combinatoric generated code approach.
 .. versionadded:: 2.0
 
 """
+
 # mypy: ignore-errors
 
 from __future__ import annotations
@@ -36,10 +37,13 @@ is_posix = os.name == "posix"
 sys.path.append(str(Path(__file__).parent.parent))
 
 
-def process_module(modname: str, filename: str, cmd: code_writer_cmd) -> str:
+def process_module(
+    modname: str, filename: str, expected_number: int, cmd: code_writer_cmd
+) -> str:
     # use tempfile in same path as the module, or at least in the
     # current working directory, so that black / zimports use
     # local pyproject.toml
+    found = 0
     with NamedTemporaryFile(
         mode="w",
         delete=False,
@@ -54,6 +58,7 @@ def process_module(modname: str, filename: str, cmd: code_writer_cmd) -> str:
                 line,
             )
             if m:
+                found += 1
                 indent = m.group(1)
                 given_fnname = current_fnname = m.group(2)
                 if current_fnname.startswith("self."):
@@ -116,16 +121,20 @@ def {current_fnname}(
 
             if not in_block:
                 buf.write(line)
+    if found != expected_number:
+        raise Exception(
+            f"{modname} processed {found}. expected {expected_number}"
+        )
     return buf.name
 
 
-def run_module(modname: str, cmd: code_writer_cmd) -> None:
+def run_module(modname: str, count: int, cmd: code_writer_cmd) -> None:
     cmd.write_status(f"importing module {modname}\n")
     mod = importlib.import_module(modname)
     destination_path = mod.__file__
     assert destination_path is not None
 
-    tempfile = process_module(modname, destination_path, cmd)
+    tempfile = process_module(modname, destination_path, count, cmd)
 
     cmd.run_zimports(tempfile)
     cmd.run_black(tempfile)
@@ -133,17 +142,17 @@ def run_module(modname: str, cmd: code_writer_cmd) -> None:
 
 
 def main(cmd: code_writer_cmd) -> None:
-    for modname in entries:
+    for modname, count in entries:
         if cmd.args.module in {"all", modname}:
-            run_module(modname, cmd)
+            run_module(modname, count, cmd)
 
 
 entries = [
-    "sqlalchemy.sql._selectable_constructors",
-    "sqlalchemy.orm.session",
-    "sqlalchemy.orm.query",
-    "sqlalchemy.sql.selectable",
-    "sqlalchemy.sql.dml",
+    ("sqlalchemy.sql._selectable_constructors", 1),
+    ("sqlalchemy.orm.session", 1),
+    ("sqlalchemy.orm.query", 1),
+    ("sqlalchemy.sql.selectable", 1),
+    ("sqlalchemy.sql.dml", 3),
 ]
 
 if __name__ == "__main__":
@@ -152,7 +161,7 @@ if __name__ == "__main__":
     with cmd.add_arguments() as parser:
         parser.add_argument(
             "--module",
-            choices=entries + ["all"],
+            choices=[n for n, _ in entries] + ["all"],
             default="all",
             help="Which file to generate. Default is to regenerate all files",
         )
index de8098bcb8f67f027cfb6132efe7bc1dd7f29c10..966705690de4e83326b36df646a199df336fd8e3 100644 (file)
@@ -23,6 +23,7 @@ You can then set a breakpoint at the end of any adapt step:
 
 
 """  # noqa: E501
+
 # mypy: ignore-errors