]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
add typing parameters
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 7 Nov 2022 18:05:08 +0000 (13:05 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 15 Nov 2022 19:27:21 +0000 (14:27 -0500)
Fixed typing issue where :paramref:`.revision.process_revision_directives`
was not fully typed; additionally ensured all ``Callable`` and ``Dict``
arguments to :meth:`.EnvironmentContext.configure` include parameters in
the typing declaration.

Change-Id: I3ac389992f357359439be5659af33525fc290f96
Fixes: #1110
alembic/command.py
alembic/config.py
alembic/context.pyi
alembic/op.pyi
alembic/operations/base.py
alembic/runtime/environment.py
alembic/util/langhelpers.py
docs/build/unreleased/1110.rst [new file with mode: 0644]
tools/write_pyi.py

index bbff75d1f78d0d2ef452e3cb98cec4bbce295fd2..162b3d0c996ba8c73a5271b909cac4d428ff6325 100644 (file)
@@ -1,7 +1,6 @@
 from __future__ import annotations
 
 import os
-from typing import Callable
 from typing import List
 from typing import Optional
 from typing import TYPE_CHECKING
@@ -15,6 +14,7 @@ from .script import ScriptDirectory
 if TYPE_CHECKING:
     from alembic.config import Config
     from alembic.script.base import Script
+    from .runtime.environment import ProcessRevisionDirectiveFn
 
 
 def list_templates(config):
@@ -124,7 +124,7 @@ def revision(
     version_path: Optional[str] = None,
     rev_id: Optional[str] = None,
     depends_on: Optional[str] = None,
-    process_revision_directives: Callable = None,
+    process_revision_directives: Optional[ProcessRevisionDirectiveFn] = None,
 ) -> Union[Optional["Script"], List[Optional["Script"]]]:
     """Create a new revision file.
 
@@ -243,9 +243,9 @@ def revision(
 def merge(
     config: "Config",
     revisions: str,
-    message: str = None,
-    branch_label: str = None,
-    rev_id: str = None,
+    message: Optional[str] = None,
+    branch_label: Optional[str] = None,
+    rev_id: Optional[str] = None,
 ) -> Optional["Script"]:
     """Merge two revisions together.  Creates a new migration file.
 
index dcfb928862c46b9591e25abbc5170187012fbb8f..8464407d59f7ff7ef33b470c9c0235fe60440186 100644 (file)
@@ -99,7 +99,7 @@ class Config:
         stdout: TextIO = sys.stdout,
         cmd_opts: Optional[Namespace] = None,
         config_args: util.immutabledict = util.immutabledict(),
-        attributes: dict = None,
+        attributes: Optional[dict] = None,
     ) -> None:
         """Construct a new :class:`.Config`"""
         self.config_file_name = file_
index a2e53994a20f6c33724ce510d8d76434d032a2e5..9871fadddb9a85b27adef371fe6c6c3c35360fdd 100644 (file)
@@ -19,13 +19,13 @@ if TYPE_CHECKING:
     from sqlalchemy.sql.schema import MetaData
 
     from .config import Config
+    from .operations import MigrateOperation
     from .runtime.migration import _ProxyTransaction
     from .runtime.migration import MigrationContext
     from .script import ScriptDirectory
-
 ### end imports ###
 
-def begin_transaction() -> Union["_ProxyTransaction", ContextManager]:
+def begin_transaction() -> Union[_ProxyTransaction, ContextManager]:
     """Return a context manager that will
     enclose an operation within a "transaction",
     as defined by the environment's offline
@@ -75,29 +75,33 @@ def configure(
     connection: Optional[Connection] = None,
     url: Optional[str] = None,
     dialect_name: Optional[str] = None,
-    dialect_opts: Optional[dict] = None,
+    dialect_opts: Optional[Dict[str, Any]] = None,
     transactional_ddl: Optional[bool] = None,
     transaction_per_migration: bool = False,
     output_buffer: Optional[TextIO] = None,
     starting_rev: Optional[str] = None,
     tag: Optional[str] = None,
-    template_args: Optional[dict] = None,
+    template_args: Optional[Dict[str, Any]] = None,
     render_as_batch: bool = False,
     target_metadata: Optional[MetaData] = None,
-    include_name: Optional[Callable] = None,
-    include_object: Optional[Callable] = None,
+    include_name: Optional[Callable[..., bool]] = None,
+    include_object: Optional[Callable[..., bool]] = None,
     include_schemas: bool = False,
-    process_revision_directives: Optional[Callable] = None,
+    process_revision_directives: Optional[
+        Callable[
+            [MigrationContext, Tuple[str, str], List[MigrateOperation]], None
+        ]
+    ] = None,
     compare_type: bool = False,
     compare_server_default: bool = False,
-    render_item: Optional[Callable] = None,
+    render_item: Optional[Callable[..., bool]] = None,
     literal_binds: bool = False,
     upgrade_token: str = "upgrades",
     downgrade_token: str = "downgrades",
     alembic_module_prefix: str = "op.",
     sqlalchemy_module_prefix: str = "sa.",
     user_module_prefix: Optional[str] = None,
-    on_version_apply: Optional[Callable] = None,
+    on_version_apply: Optional[Callable[..., None]] = None,
     **kw: Any,
 ) -> None:
     """Configure a :class:`.MigrationContext` within this
index 490d714614fcbd53f9f457f2987f22864779b8a1..4e80a00aeaf8765ea5f5b8acca58a5aa60022a8d 100644 (file)
@@ -35,8 +35,8 @@ if TYPE_CHECKING:
 
     from .operations.ops import BatchOperations
     from .operations.ops import MigrateOperation
+    from .runtime.migration import MigrationContext
     from .util.sqla_compat import _literal_bindparam
-
 ### end imports ###
 
 def add_column(
@@ -1082,13 +1082,13 @@ def get_bind() -> Connection:
 
     """
 
-def get_context():
+def get_context() -> MigrationContext:
     """Return the :class:`.MigrationContext` object that's
     currently in use.
 
     """
 
-def implementation_for(op_cls: Any) -> Callable:
+def implementation_for(op_cls: Any) -> Callable[..., Any]:
     """Register an implementation for a given :class:`.MigrateOperation`.
 
     This is part of the operation extensibility API.
@@ -1101,7 +1101,7 @@ def implementation_for(op_cls: Any) -> Callable:
 
 def inline_literal(
     value: Union[str, int], type_: None = None
-) -> "_literal_bindparam":
+) -> _literal_bindparam:
     """Produce an 'inline literal' expression, suitable for
     using in an INSERT, UPDATE, or DELETE statement.
 
@@ -1152,7 +1152,7 @@ def invoke(operation: MigrateOperation) -> Any:
 
 def register_operation(
     name: str, sourcename: Optional[str] = None
-) -> Callable:
+) -> Callable[..., Any]:
     """Register a new operation for this class.
 
     This method is normally used to add new operations
index 535dff0f9cdb22b7b67f718f52477e0471a2f2b5..2178998a638ea4390219a7d4f37821827cdf393b 100644 (file)
@@ -25,6 +25,7 @@ from ..util import sqla_compat
 from ..util.compat import formatannotation_fwdref
 from ..util.compat import inspect_formatargspec
 from ..util.compat import inspect_getfullargspec
+from ..util.sqla_compat import _literal_bindparam
 
 
 NoneType = type(None)
@@ -39,7 +40,6 @@ if TYPE_CHECKING:
     from .ops import MigrateOperation
     from ..ddl import DefaultImpl
     from ..runtime.migration import MigrationContext
-    from ..util.sqla_compat import _literal_bindparam
 
 __all__ = ("Operations", "BatchOperations")
 
@@ -80,8 +80,8 @@ class Operations(util.ModuleClsProxy):
 
     def __init__(
         self,
-        migration_context: "MigrationContext",
-        impl: Optional["BatchOperationsImpl"] = None,
+        migration_context: MigrationContext,
+        impl: Optional[BatchOperationsImpl] = None,
     ) -> None:
         """Construct a new :class:`.Operations`
 
@@ -100,7 +100,7 @@ class Operations(util.ModuleClsProxy):
     @classmethod
     def register_operation(
         cls, name: str, sourcename: Optional[str] = None
-    ) -> Callable:
+    ) -> Callable[..., Any]:
         """Register a new operation for this class.
 
         This method is normally used to add new operations
@@ -188,7 +188,7 @@ class Operations(util.ModuleClsProxy):
         return register
 
     @classmethod
-    def implementation_for(cls, op_cls: Any) -> Callable:
+    def implementation_for(cls, op_cls: Any) -> Callable[..., Any]:
         """Register an implementation for a given :class:`.MigrateOperation`.
 
         This is part of the operation extensibility API.
@@ -208,8 +208,8 @@ class Operations(util.ModuleClsProxy):
     @classmethod
     @contextmanager
     def context(
-        cls, migration_context: "MigrationContext"
-    ) -> Iterator["Operations"]:
+        cls, migration_context: MigrationContext
+    ) -> Iterator[Operations]:
         op = Operations(migration_context)
         op._install_proxy()
         yield op
@@ -382,7 +382,7 @@ class Operations(util.ModuleClsProxy):
         yield batch_op
         impl.flush()
 
-    def get_context(self):
+    def get_context(self) -> MigrationContext:
         """Return the :class:`.MigrationContext` object that's
         currently in use.
 
@@ -390,7 +390,7 @@ class Operations(util.ModuleClsProxy):
 
         return self.migration_context
 
-    def invoke(self, operation: "MigrateOperation") -> Any:
+    def invoke(self, operation: MigrateOperation) -> Any:
         """Given a :class:`.MigrateOperation`, invoke it in terms of
         this :class:`.Operations` instance.
 
@@ -400,7 +400,7 @@ class Operations(util.ModuleClsProxy):
         )
         return fn(self, operation)
 
-    def f(self, name: str) -> "conv":
+    def f(self, name: str) -> conv:
         """Indicate a string name that has already had a naming convention
         applied to it.
 
@@ -440,7 +440,7 @@ class Operations(util.ModuleClsProxy):
 
     def inline_literal(
         self, value: Union[str, int], type_: None = None
-    ) -> "_literal_bindparam":
+    ) -> _literal_bindparam:
         r"""Produce an 'inline literal' expression, suitable for
         using in an INSERT, UPDATE, or DELETE statement.
 
@@ -484,7 +484,7 @@ class Operations(util.ModuleClsProxy):
         """
         return sqla_compat._literal_bindparam(None, value, type_=type_)
 
-    def get_bind(self) -> "Connection":
+    def get_bind(self) -> Connection:
         """Return the current 'bind'.
 
         Under normal circumstances, this is the
index 3cec5b1c32de7fbd497ee20162de383ace92ffd7..6dbbcc31c323244696f75ad34b5da351b25f6fb4 100644 (file)
@@ -12,6 +12,7 @@ from typing import Tuple
 from typing import TYPE_CHECKING
 from typing import Union
 
+from .migration import _ProxyTransaction
 from .migration import MigrationContext
 from .. import util
 from ..operations import Operations
@@ -23,13 +24,17 @@ if TYPE_CHECKING:
     from sqlalchemy.sql.elements import ClauseElement
     from sqlalchemy.sql.schema import MetaData
 
-    from .migration import _ProxyTransaction
     from ..config import Config
     from ..ddl import DefaultImpl
+    from ..operations.ops import MigrateOperation
     from ..script.base import ScriptDirectory
 
 _RevNumber = Optional[Union[str, Tuple[str, ...]]]
 
+ProcessRevisionDirectiveFn = Callable[
+    [MigrationContext, Tuple[str, str], List["MigrateOperation"]], None
+]
+
 
 class EnvironmentContext(util.ModuleClsProxy):
 
@@ -109,7 +114,7 @@ class EnvironmentContext(util.ModuleClsProxy):
     """
 
     def __init__(
-        self, config: "Config", script: "ScriptDirectory", **kw: Any
+        self, config: Config, script: ScriptDirectory, **kw: Any
     ) -> None:
         r"""Construct a new :class:`.EnvironmentContext`.
 
@@ -124,7 +129,7 @@ class EnvironmentContext(util.ModuleClsProxy):
         self.script = script
         self.context_opts = kw
 
-    def __enter__(self) -> "EnvironmentContext":
+    def __enter__(self) -> EnvironmentContext:
         """Establish a context which provides a
         :class:`.EnvironmentContext` object to
         env.py scripts.
@@ -265,13 +270,13 @@ class EnvironmentContext(util.ModuleClsProxy):
 
     @overload
     def get_x_argument(  # type:ignore[misc]
-        self, as_dictionary: "Literal[False]" = ...
+        self, as_dictionary: Literal[False] = ...
     ) -> List[str]:
         ...
 
     @overload
     def get_x_argument(  # type:ignore[misc]
-        self, as_dictionary: "Literal[True]" = ...
+        self, as_dictionary: Literal[True] = ...
     ) -> Dict[str, str]:
         ...
 
@@ -326,32 +331,34 @@ class EnvironmentContext(util.ModuleClsProxy):
 
     def configure(
         self,
-        connection: Optional["Connection"] = None,
+        connection: Optional[Connection] = None,
         url: Optional[str] = None,
         dialect_name: Optional[str] = None,
-        dialect_opts: Optional[dict] = None,
+        dialect_opts: Optional[Dict[str, Any]] = None,
         transactional_ddl: Optional[bool] = None,
         transaction_per_migration: bool = False,
         output_buffer: Optional[TextIO] = None,
         starting_rev: Optional[str] = None,
         tag: Optional[str] = None,
-        template_args: Optional[dict] = None,
+        template_args: Optional[Dict[str, Any]] = None,
         render_as_batch: bool = False,
-        target_metadata: Optional["MetaData"] = None,
-        include_name: Optional[Callable] = None,
-        include_object: Optional[Callable] = None,
+        target_metadata: Optional[MetaData] = None,
+        include_name: Optional[Callable[..., bool]] = None,
+        include_object: Optional[Callable[..., bool]] = None,
         include_schemas: bool = False,
-        process_revision_directives: Optional[Callable] = None,
+        process_revision_directives: Optional[
+            ProcessRevisionDirectiveFn
+        ] = None,
         compare_type: bool = False,
         compare_server_default: bool = False,
-        render_item: Optional[Callable] = None,
+        render_item: Optional[Callable[..., bool]] = None,
         literal_binds: bool = False,
         upgrade_token: str = "upgrades",
         downgrade_token: str = "downgrades",
         alembic_module_prefix: str = "op.",
         sqlalchemy_module_prefix: str = "sa.",
         user_module_prefix: Optional[str] = None,
-        on_version_apply: Optional[Callable] = None,
+        on_version_apply: Optional[Callable[..., None]] = None,
         **kw: Any,
     ) -> None:
         """Configure a :class:`.MigrationContext` within this
@@ -859,7 +866,7 @@ class EnvironmentContext(util.ModuleClsProxy):
 
     def execute(
         self,
-        sql: Union["ClauseElement", str],
+        sql: Union[ClauseElement, str],
         execution_options: Optional[dict] = None,
     ) -> None:
         """Execute the given SQL using the current change context.
@@ -888,7 +895,7 @@ class EnvironmentContext(util.ModuleClsProxy):
 
     def begin_transaction(
         self,
-    ) -> Union["_ProxyTransaction", ContextManager]:
+    ) -> Union[_ProxyTransaction, ContextManager]:
         """Return a context manager that will
         enclose an operation within a "transaction",
         as defined by the environment's offline
@@ -934,7 +941,7 @@ class EnvironmentContext(util.ModuleClsProxy):
 
         return self.get_context().begin_transaction()
 
-    def get_context(self) -> "MigrationContext":
+    def get_context(self) -> MigrationContext:
         """Return the current :class:`.MigrationContext` object.
 
         If :meth:`.EnvironmentContext.configure` has not been
@@ -946,7 +953,7 @@ class EnvironmentContext(util.ModuleClsProxy):
             raise Exception("No context has been configured yet.")
         return self._migration_context
 
-    def get_bind(self) -> "Connection":
+    def get_bind(self) -> Connection:
         """Return the current 'bind'.
 
         In "online" mode, this is the
@@ -959,5 +966,5 @@ class EnvironmentContext(util.ModuleClsProxy):
         """
         return self.get_context().bind  # type: ignore[return-value]
 
-    def get_impl(self) -> "DefaultImpl":
+    def get_impl(self) -> DefaultImpl:
         return self.get_context().impl
index b6ceb0cd953bdb4454fb6893dc4583912c3f7096..ff2687ce8abdd3aca0297bd2f59c88c78235bc3e 100644 (file)
@@ -198,7 +198,7 @@ def to_tuple(x: Any, default: tuple) -> tuple:
 
 
 @overload
-def to_tuple(x: None, default: _T = None) -> _T:
+def to_tuple(x: None, default: Optional[_T] = None) -> _T:
     ...
 
 
diff --git a/docs/build/unreleased/1110.rst b/docs/build/unreleased/1110.rst
new file mode 100644 (file)
index 0000000..fe9cfff
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+    :tags: bug, typing
+    :tickets: 1110
+
+    Fixed typing issue where :paramref:`.revision.process_revision_directives`
+    was not fully typed; additionally ensured all ``Callable`` and ``Dict``
+    arguments to :meth:`.EnvironmentContext.configure` include parameters in
+    the typing declaration.
+
+    Additionally updated the codebase for Mypy 0.990 compliance.
\ No newline at end of file
index 52fac3c1282efdce72e11f4c1d1534a162081836..e5112fdb0be31dc64920bd8a5a63f93db178fb3d 100644 (file)
@@ -125,7 +125,7 @@ def generate_pyi_for_proxy(
 def _generate_stub_for_attr(cls, name, printer, env):
     try:
         annotations = typing.get_type_hints(cls, env)
-    except NameError as e:
+    except NameError:
         annotations = cls.__annotations__
     type_ = annotations.get(name, "Any")
     if isinstance(type_, str) and type_[0] in "'\"":
@@ -155,10 +155,7 @@ def _generate_stub_for_meth(cls, name, printer, env, is_context_manager):
         if getattr(annotation, "__module__", None) == "typing":
             retval = repr(annotation).replace("typing.", "")
         elif isinstance(annotation, type):
-            if annotation.__module__ in ("builtins", base_module):
-                retval = annotation.__qualname__
-            else:
-                retval = annotation.__module__ + "." + annotation.__qualname__
+            retval = annotation.__qualname__
         else:
             retval = annotation
 
@@ -184,6 +181,7 @@ def _generate_stub_for_meth(cls, name, printer, env, is_context_manager):
         '''{fn.__doc__}'''
     """
     )
+
     printer.write_indented_block(func_text)