From e28ee4ed42ac57f727a934a0916075168d87fcf3 Mon Sep 17 00:00:00 2001 From: CaselIT Date: Sat, 4 Jun 2022 05:59:23 -0400 Subject: [PATCH] Annotate batch_alter_table Fixes: #975 Closes: #1032 Pull-request: https://github.com/sqlalchemy/alembic/pull/1032 Pull-request-sha: a111d4f446e861bd01a3cea7ebd1c18a2446601a Change-Id: Idb8e1c8b6577204a64cde195f094830cdbba68ce --- alembic/op.pyi | 31 +++++++++++++++++++------------ alembic/operations/base.py | 27 ++++++++++++++++----------- alembic/util/compat.py | 4 ++-- tools/write_pyi.py | 38 ++++++++++++++++++-------------------- 4 files changed, 55 insertions(+), 45 deletions(-) diff --git a/alembic/op.pyi b/alembic/op.pyi index 9e3169ad..3745684f 100644 --- a/alembic/op.pyi +++ b/alembic/op.pyi @@ -1,11 +1,16 @@ # ### this file stubs are generated by tools/write_pyi.py - do not edit ### # ### imports are manually managed - +from contextlib import contextmanager from typing import Any from typing import Callable +from typing import Dict +from typing import Iterator from typing import List +from typing import Literal +from typing import Mapping from typing import Optional from typing import Sequence +from typing import Tuple from typing import Type from typing import TYPE_CHECKING from typing import Union @@ -27,6 +32,7 @@ if TYPE_CHECKING: from sqlalchemy.sql.type_api import TypeEngine from sqlalchemy.util import immutabledict + from .operations.ops import BatchOperations from .operations.ops import MigrateOperation from .util.sqla_compat import _literal_bindparam @@ -190,18 +196,19 @@ def alter_column( """ +@contextmanager def batch_alter_table( - table_name, - schema=None, - recreate="auto", - partial_reordering=None, - copy_from=None, - table_args=(), - table_kwargs=immutabledict({}), - reflect_args=(), - reflect_kwargs=immutabledict({}), - naming_convention=None, -): + table_name: str, + schema: Optional[str] = None, + recreate: Literal["auto", "always", "never"] = "auto", + partial_reordering: Optional[tuple] = None, + copy_from: Optional["Table"] = None, + table_args: Tuple[Any, ...] = (), + table_kwargs: Mapping[str, Any] = immutabledict({}), + reflect_args: Tuple[Any, ...] = (), + reflect_kwargs: Mapping[str, Any] = immutabledict({}), + naming_convention: Optional[Dict[str, str]] = None, +) -> Iterator["BatchOperations"]: """Invoke a series of per-table migrations in batch. Batch mode allows a series of operations specific to a table diff --git a/alembic/operations/base.py b/alembic/operations/base.py index 68b620fa..9ecf3d4a 100644 --- a/alembic/operations/base.py +++ b/alembic/operations/base.py @@ -5,10 +5,13 @@ import re import textwrap from typing import Any from typing import Callable +from typing import Dict from typing import Iterator from typing import List # noqa +from typing import Mapping from typing import Optional from typing import Sequence # noqa +from typing import Tuple from typing import Type # noqa from typing import TYPE_CHECKING from typing import Union @@ -27,6 +30,8 @@ from ..util.compat import inspect_getfullargspec NoneType = type(None) if TYPE_CHECKING: + from typing import Literal + from sqlalchemy import Table # noqa from sqlalchemy.engine import Connection @@ -211,17 +216,17 @@ class Operations(util.ModuleClsProxy): @contextmanager def batch_alter_table( self, - table_name, - schema=None, - recreate="auto", - partial_reordering=None, - copy_from=None, - table_args=(), - table_kwargs=util.immutabledict(), - reflect_args=(), - reflect_kwargs=util.immutabledict(), - naming_convention=None, - ): + table_name: str, + schema: Optional[str] = None, + recreate: Literal["auto", "always", "never"] = "auto", + partial_reordering: Optional[tuple] = None, + copy_from: Optional["Table"] = None, + table_args: Tuple[Any, ...] = (), + table_kwargs: Mapping[str, Any] = util.immutabledict(), + reflect_args: Tuple[Any, ...] = (), + reflect_kwargs: Mapping[str, Any] = util.immutabledict(), + naming_convention: Optional[Dict[str, str]] = None, + ) -> Iterator["BatchOperations"]: """Invoke a series of per-table migrations in batch. Batch mode allows a series of operations specific to a table diff --git a/alembic/util/compat.py b/alembic/util/compat.py index 289aaa22..e2279756 100644 --- a/alembic/util/compat.py +++ b/alembic/util/compat.py @@ -35,9 +35,9 @@ else: def importlib_metadata_get(group: str) -> Sequence[EntryPoint]: ep = importlib_metadata.entry_points() if hasattr(ep, "select"): - return ep.select(group=group) # type: ignore + return ep.select(group=group) else: - return ep.get(group, ()) # type: ignore + return ep.get(group, ()) def formatannotation_fwdref(annotation, base_module=None): diff --git a/tools/write_pyi.py b/tools/write_pyi.py index ec928cc1..cf42d1b1 100644 --- a/tools/write_pyi.py +++ b/tools/write_pyi.py @@ -38,6 +38,7 @@ TRIM_MODULE = [ "sqlalchemy.sql.functions.", "sqlalchemy.sql.dml.", ] +CONTEXT_MANAGERS = {"op": ["batch_alter_table"]} def generate_pyi_for_proxy( @@ -46,8 +47,10 @@ def generate_pyi_for_proxy( source_path: Path, destination_path: Path, ignore_output: bool, - ignore_items: set, + file_key: str, ): + ignore_items = IGNORE_ITEMS.get(file_key, set()) + context_managers = CONTEXT_MANAGERS.get(file_key, []) if sys.version_info < (3, 9): raise RuntimeError("This script must be run with Python 3.9 or higher") @@ -93,7 +96,9 @@ def generate_pyi_for_proxy( continue meth = getattr(cls, name, None) if callable(meth): - _generate_stub_for_meth(cls, name, printer, env) + _generate_stub_for_meth( + cls, name, printer, env, name in context_managers + ) else: _generate_stub_for_attr(cls, name, printer, env) @@ -125,7 +130,7 @@ def _generate_stub_for_attr(cls, name, printer, env): printer.writeline(f"{name}: {type_}") -def _generate_stub_for_meth(cls, name, printer, env): +def _generate_stub_for_meth(cls, name, printer, env, is_context_manager): fn = getattr(cls, name) while hasattr(fn, "__wrapped__"): @@ -168,24 +173,19 @@ def _generate_stub_for_meth(cls, name, printer, env): formatannotation=_formatannotation, formatreturns=lambda val: f"-> {_formatannotation(val)}", ) - + contextmanager = "@contextmanager" if is_context_manager else "" func_text = textwrap.dedent( - """\ - def %(name)s%(argspec)s: - '''%(doc)s''' + f""" + {contextmanager} + def {name}{argspec}: + '''{fn.__doc__}''' """ - % { - "name": name, - "argspec": argspec, - "doc": fn.__doc__, - } ) - printer.write_indented_block(func_text) def run_file( - source_path: Path, cls_to_generate: type, stdout: bool, ignore_items: set + source_path: Path, cls_to_generate: type, stdout: bool, file_key: str ): progname = Path(sys.argv[0]).as_posix() if not stdout: @@ -195,7 +195,7 @@ def run_file( source_path=source_path, destination_path=source_path, ignore_output=False, - ignore_items=ignore_items, + file_key=file_key, ) else: with NamedTemporaryFile(delete=False, suffix=".pyi") as f: @@ -207,7 +207,7 @@ def run_file( source_path=source_path, destination_path=f_path, ignore_output=True, - ignore_items=ignore_items, + file_key=file_key, ) sys.stdout.write(f_path.read_text()) f_path.unlink() @@ -216,15 +216,13 @@ def run_file( def main(args): location = Path(__file__).parent.parent / "alembic" if args.file in {"all", "op"}: - run_file( - location / "op.pyi", Operations, args.stdout, IGNORE_ITEMS["op"] - ) + run_file(location / "op.pyi", Operations, args.stdout, "op") if args.file in {"all", "context"}: run_file( location / "context.pyi", EnvironmentContext, args.stdout, - IGNORE_ITEMS["context"], + "context", ) -- 2.47.2