]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
Annotate batch_alter_table
authorCaselIT <cfederico87@gmail.com>
Sat, 4 Jun 2022 09:59:23 +0000 (05:59 -0400)
committerCaselIT <cfederico87@gmail.com>
Sat, 4 Jun 2022 10:05:46 +0000 (12:05 +0200)
Fixes: #975
Closes: #1032
Pull-request: https://github.com/sqlalchemy/alembic/pull/1032
Pull-request-sha: a111d4f446e861bd01a3cea7ebd1c18a2446601a

Change-Id: Idb8e1c8b6577204a64cde195f094830cdbba68ce

alembic/op.pyi
alembic/operations/base.py
alembic/util/compat.py
tools/write_pyi.py

index 9e3169ad3d77ac2aa8000cd4a1e94d928d42ba54..3745684f7600867ac1e26655be52c6a3714ab219 100644 (file)
@@ -1,11 +1,16 @@
 # ### this file stubs are generated by tools/write_pyi.py - do not edit ###
 # ### imports are manually managed
-
+from contextlib import contextmanager
 from typing import Any
 from typing import Callable
+from typing import Dict
+from typing import Iterator
 from typing import List
+from typing import Literal
+from typing import Mapping
 from typing import Optional
 from typing import Sequence
+from typing import Tuple
 from typing import Type
 from typing import TYPE_CHECKING
 from typing import Union
@@ -27,6 +32,7 @@ if TYPE_CHECKING:
     from sqlalchemy.sql.type_api import TypeEngine
     from sqlalchemy.util import immutabledict
 
+    from .operations.ops import BatchOperations
     from .operations.ops import MigrateOperation
     from .util.sqla_compat import _literal_bindparam
 
@@ -190,18 +196,19 @@ def alter_column(
 
     """
 
+@contextmanager
 def batch_alter_table(
-    table_name,
-    schema=None,
-    recreate="auto",
-    partial_reordering=None,
-    copy_from=None,
-    table_args=(),
-    table_kwargs=immutabledict({}),
-    reflect_args=(),
-    reflect_kwargs=immutabledict({}),
-    naming_convention=None,
-):
+    table_name: str,
+    schema: Optional[str] = None,
+    recreate: Literal["auto", "always", "never"] = "auto",
+    partial_reordering: Optional[tuple] = None,
+    copy_from: Optional["Table"] = None,
+    table_args: Tuple[Any, ...] = (),
+    table_kwargs: Mapping[str, Any] = immutabledict({}),
+    reflect_args: Tuple[Any, ...] = (),
+    reflect_kwargs: Mapping[str, Any] = immutabledict({}),
+    naming_convention: Optional[Dict[str, str]] = None,
+) -> Iterator["BatchOperations"]:
     """Invoke a series of per-table migrations in batch.
 
     Batch mode allows a series of operations specific to a table
index 68b620fad7b9479c078a0be717877d59c49a6e46..9ecf3d4a9ad2c8a38c71f34dd7230e4ddc26c011 100644 (file)
@@ -5,10 +5,13 @@ import re
 import textwrap
 from typing import Any
 from typing import Callable
+from typing import Dict
 from typing import Iterator
 from typing import List  # noqa
+from typing import Mapping
 from typing import Optional
 from typing import Sequence  # noqa
+from typing import Tuple
 from typing import Type  # noqa
 from typing import TYPE_CHECKING
 from typing import Union
@@ -27,6 +30,8 @@ from ..util.compat import inspect_getfullargspec
 NoneType = type(None)
 
 if TYPE_CHECKING:
+    from typing import Literal
+
     from sqlalchemy import Table  # noqa
     from sqlalchemy.engine import Connection
 
@@ -211,17 +216,17 @@ class Operations(util.ModuleClsProxy):
     @contextmanager
     def batch_alter_table(
         self,
-        table_name,
-        schema=None,
-        recreate="auto",
-        partial_reordering=None,
-        copy_from=None,
-        table_args=(),
-        table_kwargs=util.immutabledict(),
-        reflect_args=(),
-        reflect_kwargs=util.immutabledict(),
-        naming_convention=None,
-    ):
+        table_name: str,
+        schema: Optional[str] = None,
+        recreate: Literal["auto", "always", "never"] = "auto",
+        partial_reordering: Optional[tuple] = None,
+        copy_from: Optional["Table"] = None,
+        table_args: Tuple[Any, ...] = (),
+        table_kwargs: Mapping[str, Any] = util.immutabledict(),
+        reflect_args: Tuple[Any, ...] = (),
+        reflect_kwargs: Mapping[str, Any] = util.immutabledict(),
+        naming_convention: Optional[Dict[str, str]] = None,
+    ) -> Iterator["BatchOperations"]:
         """Invoke a series of per-table migrations in batch.
 
         Batch mode allows a series of operations specific to a table
index 289aaa228a8027efd4b66ea98448dd2a18953fb5..e2279756eb0bad0fe8d6d749762bd081ffb85ad5 100644 (file)
@@ -35,9 +35,9 @@ else:
 def importlib_metadata_get(group: str) -> Sequence[EntryPoint]:
     ep = importlib_metadata.entry_points()
     if hasattr(ep, "select"):
-        return ep.select(group=group)  # type: ignore
+        return ep.select(group=group)
     else:
-        return ep.get(group, ())  # type: ignore
+        return ep.get(group, ())
 
 
 def formatannotation_fwdref(annotation, base_module=None):
index ec928cc114cd754837042a67f3ce520fd6939de8..cf42d1b1f9ceba92e4bddd5a60f792a16bdf34c2 100644 (file)
@@ -38,6 +38,7 @@ TRIM_MODULE = [
     "sqlalchemy.sql.functions.",
     "sqlalchemy.sql.dml.",
 ]
+CONTEXT_MANAGERS = {"op": ["batch_alter_table"]}
 
 
 def generate_pyi_for_proxy(
@@ -46,8 +47,10 @@ def generate_pyi_for_proxy(
     source_path: Path,
     destination_path: Path,
     ignore_output: bool,
-    ignore_items: set,
+    file_key: str,
 ):
+    ignore_items = IGNORE_ITEMS.get(file_key, set())
+    context_managers = CONTEXT_MANAGERS.get(file_key, [])
     if sys.version_info < (3, 9):
         raise RuntimeError("This script must be run with Python 3.9 or higher")
 
@@ -93,7 +96,9 @@ def generate_pyi_for_proxy(
                 continue
             meth = getattr(cls, name, None)
             if callable(meth):
-                _generate_stub_for_meth(cls, name, printer, env)
+                _generate_stub_for_meth(
+                    cls, name, printer, env, name in context_managers
+                )
             else:
                 _generate_stub_for_attr(cls, name, printer, env)
 
@@ -125,7 +130,7 @@ def _generate_stub_for_attr(cls, name, printer, env):
     printer.writeline(f"{name}: {type_}")
 
 
-def _generate_stub_for_meth(cls, name, printer, env):
+def _generate_stub_for_meth(cls, name, printer, env, is_context_manager):
 
     fn = getattr(cls, name)
     while hasattr(fn, "__wrapped__"):
@@ -168,24 +173,19 @@ def _generate_stub_for_meth(cls, name, printer, env):
         formatannotation=_formatannotation,
         formatreturns=lambda val: f"-> {_formatannotation(val)}",
     )
-
+    contextmanager = "@contextmanager" if is_context_manager else ""
     func_text = textwrap.dedent(
-        """\
-    def %(name)s%(argspec)s:
-        '''%(doc)s'''
+        f"""
+    {contextmanager}
+    def {name}{argspec}:
+        '''{fn.__doc__}'''
     """
-        % {
-            "name": name,
-            "argspec": argspec,
-            "doc": fn.__doc__,
-        }
     )
-
     printer.write_indented_block(func_text)
 
 
 def run_file(
-    source_path: Path, cls_to_generate: type, stdout: bool, ignore_items: set
+    source_path: Path, cls_to_generate: type, stdout: bool, file_key: str
 ):
     progname = Path(sys.argv[0]).as_posix()
     if not stdout:
@@ -195,7 +195,7 @@ def run_file(
             source_path=source_path,
             destination_path=source_path,
             ignore_output=False,
-            ignore_items=ignore_items,
+            file_key=file_key,
         )
     else:
         with NamedTemporaryFile(delete=False, suffix=".pyi") as f:
@@ -207,7 +207,7 @@ def run_file(
                 source_path=source_path,
                 destination_path=f_path,
                 ignore_output=True,
-                ignore_items=ignore_items,
+                file_key=file_key,
             )
             sys.stdout.write(f_path.read_text())
         f_path.unlink()
@@ -216,15 +216,13 @@ def run_file(
 def main(args):
     location = Path(__file__).parent.parent / "alembic"
     if args.file in {"all", "op"}:
-        run_file(
-            location / "op.pyi", Operations, args.stdout, IGNORE_ITEMS["op"]
-        )
+        run_file(location / "op.pyi", Operations, args.stdout, "op")
     if args.file in {"all", "context"}:
         run_file(
             location / "context.pyi",
             EnvironmentContext,
             args.stdout,
-            IGNORE_ITEMS["context"],
+            "context",
         )