]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
Fix unknown types reported by pyright
authorCaselIT <cfederico87@gmail.com>
Mon, 6 Mar 2023 21:15:15 +0000 (22:15 +0100)
committerCaselIT <cfederico87@gmail.com>
Fri, 10 Mar 2023 18:27:55 +0000 (19:27 +0100)
Fixed various typing issues observed with pyright, including issues
involving the combination of :class:`.Function` and
:meth:`.MigrationContext.begin_transaction`.

Fixes: #1191
Fixes: #1201
Change-Id: I9856a8f59c22130c8bbcbed3e19cf2e8a8bf0608

alembic/context.pyi
alembic/ddl/base.py
alembic/op.pyi
alembic/operations/batch.py
alembic/operations/ops.py
alembic/runtime/environment.py
alembic/runtime/migration.py
docs/build/unreleased/1191.rst [new file with mode: 0644]
tools/write_pyi.py

index 86345c4f6076239f58b1ba6e526917b2d159ab13..142a0c7b3bee8c8a3d8119c00b99ee5627e5c370 100644 (file)
@@ -27,8 +27,8 @@ if TYPE_CHECKING:
     from .script import ScriptDirectory
 ### end imports ###
 
-def begin_transaction() -> Union[_ProxyTransaction, ContextManager]:
-    """Return a context manager that will
+def begin_transaction() -> Union[_ProxyTransaction, ContextManager[None]]:
+    r"""Return a context manager that will
     enclose an operation within a "transaction",
     as defined by the environment's offline
     and transactional DDL settings.
@@ -106,7 +106,7 @@ def configure(
     on_version_apply: Optional[Callable[..., None]] = None,
     **kw: Any,
 ) -> None:
-    """Configure a :class:`.MigrationContext` within this
+    r"""Configure a :class:`.MigrationContext` within this
     :class:`.EnvironmentContext` which will provide database
     connectivity and other configuration to a series of
     migration scripts.
@@ -542,7 +542,7 @@ def configure(
 def execute(
     sql: Union[ClauseElement, str], execution_options: Optional[dict] = None
 ) -> None:
-    """Execute the given SQL using the current change context.
+    r"""Execute the given SQL using the current change context.
 
     The behavior of :meth:`.execute` is the same
     as that of :meth:`.Operations.execute`.  Please see that
@@ -555,7 +555,7 @@ def execute(
     """
 
 def get_bind() -> Connection:
-    """Return the current 'bind'.
+    r"""Return the current 'bind'.
 
     In "online" mode, this is the
     :class:`sqlalchemy.engine.Connection` currently being used
@@ -567,7 +567,7 @@ def get_bind() -> Connection:
     """
 
 def get_context() -> MigrationContext:
-    """Return the current :class:`.MigrationContext` object.
+    r"""Return the current :class:`.MigrationContext` object.
 
     If :meth:`.EnvironmentContext.configure` has not been
     called yet, raises an exception.
@@ -575,7 +575,7 @@ def get_context() -> MigrationContext:
     """
 
 def get_head_revision() -> Union[str, Tuple[str, ...], None]:
-    """Return the hex identifier of the 'head' script revision.
+    r"""Return the hex identifier of the 'head' script revision.
 
     If the script directory has multiple heads, this
     method raises a :class:`.CommandError`;
@@ -589,7 +589,7 @@ def get_head_revision() -> Union[str, Tuple[str, ...], None]:
     """
 
 def get_head_revisions() -> Union[str, Tuple[str, ...], None]:
-    """Return the hex identifier of the 'heads' script revision(s).
+    r"""Return the hex identifier of the 'heads' script revision(s).
 
     This returns a tuple containing the version number of all
     heads in the script directory.
@@ -600,7 +600,7 @@ def get_head_revisions() -> Union[str, Tuple[str, ...], None]:
     """
 
 def get_revision_argument() -> Union[str, Tuple[str, ...], None]:
-    """Get the 'destination' revision argument.
+    r"""Get the 'destination' revision argument.
 
     This is typically the argument passed to the
     ``upgrade`` or ``downgrade`` command.
@@ -615,7 +615,7 @@ def get_revision_argument() -> Union[str, Tuple[str, ...], None]:
     """
 
 def get_starting_revision_argument() -> Union[str, Tuple[str, ...], None]:
-    """Return the 'starting revision' argument,
+    r"""Return the 'starting revision' argument,
     if the revision was passed using ``start:end``.
 
     This is only meaningful in "offline" mode.
@@ -628,7 +628,7 @@ def get_starting_revision_argument() -> Union[str, Tuple[str, ...], None]:
     """
 
 def get_tag_argument() -> Optional[str]:
-    """Return the value passed for the ``--tag`` argument, if any.
+    r"""Return the value passed for the ``--tag`` argument, if any.
 
     The ``--tag`` argument is not used directly by Alembic,
     but is available for custom ``env.py`` configurations that
@@ -654,7 +654,7 @@ def get_x_argument(as_dictionary: Literal[True]) -> Dict[str, str]: ...
 def get_x_argument(
     as_dictionary: bool = ...,
 ) -> Union[List[str], Dict[str, str]]:
-    """Return the value(s) passed for the ``-x`` argument, if any.
+    r"""Return the value(s) passed for the ``-x`` argument, if any.
 
     The ``-x`` argument is an open ended flag that allows any user-defined
     value or values to be passed on the command line, then available
@@ -694,7 +694,7 @@ def get_x_argument(
     """
 
 def is_offline_mode() -> bool:
-    """Return True if the current migrations environment
+    r"""Return True if the current migrations environment
     is running in "offline mode".
 
     This is ``True`` or ``False`` depending
@@ -706,7 +706,7 @@ def is_offline_mode() -> bool:
     """
 
 def is_transactional_ddl():
-    """Return True if the context is configured to expect a
+    r"""Return True if the context is configured to expect a
     transactional DDL capable backend.
 
     This defaults to the type of database in use, and
@@ -719,7 +719,7 @@ def is_transactional_ddl():
     """
 
 def run_migrations(**kw: Any) -> None:
-    """Run migrations as determined by the current command line
+    r"""Run migrations as determined by the current command line
     configuration
     as well as versioning information present (or not) in the current
     database connection (if one is present).
@@ -742,7 +742,7 @@ def run_migrations(**kw: Any) -> None:
 script: ScriptDirectory
 
 def static_output(text: str) -> None:
-    """Emit text directly to the "offline" SQL stream.
+    r"""Emit text directly to the "offline" SQL stream.
 
     Typically this is for emitting comments that
     start with --.  The statement is not treated
index c3bdaf382be31576eb57c4a07cda2b5912e6b263..65da32f40cf24c2f93d08de20a793b572d3b5c67 100644 (file)
@@ -20,6 +20,8 @@ from ..util.sqla_compat import _is_type_bound  # noqa
 from ..util.sqla_compat import _table_for_constraint  # noqa
 
 if TYPE_CHECKING:
+    from typing import Any
+
     from sqlalchemy.sql.compiler import Compiled
     from sqlalchemy.sql.compiler import DDLCompiler
     from sqlalchemy.sql.elements import TextClause
@@ -31,7 +33,7 @@ if TYPE_CHECKING:
     from ..util.sqla_compat import Computed
     from ..util.sqla_compat import Identity
 
-_ServerDefault = Union["TextClause", "FetchedValue", "Function", str]
+_ServerDefault = Union["TextClause", "FetchedValue", "Function[Any]", str]
 
 
 class AlterTable(DDLElement):
index 7a5710eb1f5ccd31dddb4f1b684a50e0032f5baa..2f92dc3401493fcd4324869f0c89581d93175ad3 100644 (file)
@@ -42,7 +42,7 @@ if TYPE_CHECKING:
 def add_column(
     table_name: str, column: Column, schema: Optional[str] = None
 ) -> Optional[Table]:
-    """Issue an "add column" instruction using the current
+    r"""Issue an "add column" instruction using the current
     migration context.
 
     e.g.::
@@ -108,7 +108,7 @@ def alter_column(
     schema: Optional[str] = None,
     **kw: Any
 ) -> Optional[Table]:
-    """Issue an "alter column" instruction using the
+    r"""Issue an "alter column" instruction using the
     current migration context.
 
     Generally, only that aspect of the column which
@@ -210,7 +210,7 @@ def batch_alter_table(
     reflect_kwargs: Mapping[str, Any] = immutabledict({}),
     naming_convention: Optional[Dict[str, str]] = None,
 ) -> Iterator[BatchOperations]:
-    """Invoke a series of per-table migrations in batch.
+    r"""Invoke a series of per-table migrations in batch.
 
     Batch mode allows a series of operations specific to a table
     to be syntactically grouped together, and allows for alternate
@@ -352,7 +352,7 @@ def bulk_insert(
     rows: List[dict],
     multiinsert: bool = True,
 ) -> None:
-    """Issue a "bulk insert" operation using the current
+    r"""Issue a "bulk insert" operation using the current
     migration context.
 
     This provides a means of representing an INSERT of multiple rows
@@ -434,7 +434,7 @@ def create_check_constraint(
     schema: Optional[str] = None,
     **kw: Any
 ) -> Optional[Table]:
-    """Issue a "create check constraint" instruction using the
+    r"""Issue a "create check constraint" instruction using the
     current migration context.
 
     e.g.::
@@ -478,7 +478,7 @@ def create_check_constraint(
 def create_exclude_constraint(
     constraint_name: str, table_name: str, *elements: Any, **kw: Any
 ) -> Optional[Table]:
-    """Issue an alter to create an EXCLUDE constraint using the
+    r"""Issue an alter to create an EXCLUDE constraint using the
     current migration context.
 
     .. note::  This method is Postgresql specific, and additionally
@@ -530,7 +530,7 @@ def create_foreign_key(
     referent_schema: Optional[str] = None,
     **dialect_kw: Any
 ) -> Optional[Table]:
-    """Issue a "create foreign key" instruction using the
+    r"""Issue a "create foreign key" instruction using the
     current migration context.
 
     e.g.::
@@ -578,12 +578,12 @@ def create_foreign_key(
 def create_index(
     index_name: Optional[str],
     table_name: str,
-    columns: Sequence[Union[str, TextClause, Function]],
+    columns: Sequence[Union[str, TextClause, Function[Any]]],
     schema: Optional[str] = None,
     unique: bool = False,
     **kw: Any
 ) -> Optional[Table]:
-    """Issue a "create index" instruction using the current
+    r"""Issue a "create index" instruction using the current
     migration context.
 
     e.g.::
@@ -631,7 +631,7 @@ def create_primary_key(
     columns: List[str],
     schema: Optional[str] = None,
 ) -> Optional[Table]:
-    """Issue a "create primary key" instruction using the current
+    r"""Issue a "create primary key" instruction using the current
     migration context.
 
     e.g.::
@@ -671,7 +671,7 @@ def create_primary_key(
 def create_table(
     table_name: str, *columns: SchemaItem, **kw: Any
 ) -> Optional[Table]:
-    """Issue a "create table" instruction using the current migration
+    r"""Issue a "create table" instruction using the current migration
     context.
 
     This directive receives an argument list similar to that of the
@@ -754,7 +754,7 @@ def create_table_comment(
     existing_comment: None = None,
     schema: Optional[str] = None,
 ) -> Optional[Table]:
-    """Emit a COMMENT ON operation to set the comment for a table.
+    r"""Emit a COMMENT ON operation to set the comment for a table.
 
     .. versionadded:: 1.0.6
 
@@ -781,7 +781,7 @@ def create_unique_constraint(
     schema: Optional[str] = None,
     **kw: Any
 ) -> Any:
-    """Issue a "create unique constraint" instruction using the
+    r"""Issue a "create unique constraint" instruction using the
     current migration context.
 
     e.g.::
@@ -822,7 +822,7 @@ def create_unique_constraint(
 def drop_column(
     table_name: str, column_name: str, schema: Optional[str] = None, **kw: Any
 ) -> Optional[Table]:
-    """Issue a "drop column" instruction using the current
+    r"""Issue a "drop column" instruction using the current
     migration context.
 
     e.g.::
@@ -865,7 +865,7 @@ def drop_constraint(
     type_: Optional[str] = None,
     schema: Optional[str] = None,
 ) -> Optional[Table]:
-    """Drop a constraint of the given name, typically via DROP CONSTRAINT.
+    r"""Drop a constraint of the given name, typically via DROP CONSTRAINT.
 
     :param constraint_name: name of the constraint.
     :param table_name: table name.
@@ -884,7 +884,7 @@ def drop_index(
     schema: Optional[str] = None,
     **kw: Any
 ) -> Optional[Table]:
-    """Issue a "drop index" instruction using the current
+    r"""Issue a "drop index" instruction using the current
     migration context.
 
     e.g.::
@@ -909,7 +909,7 @@ def drop_index(
 def drop_table(
     table_name: str, schema: Optional[str] = None, **kw: Any
 ) -> None:
-    """Issue a "drop table" instruction using the current
+    r"""Issue a "drop table" instruction using the current
     migration context.
 
 
@@ -932,7 +932,7 @@ def drop_table_comment(
     existing_comment: Optional[str] = None,
     schema: Optional[str] = None,
 ) -> Optional[Table]:
-    """Issue a "drop table comment" operation to
+    r"""Issue a "drop table comment" operation to
     remove an existing comment set on a table.
 
     .. versionadded:: 1.0.6
@@ -952,7 +952,7 @@ def drop_table_comment(
 def execute(
     sqltext: Union[str, TextClause, Update], execution_options: None = None
 ) -> Optional[Table]:
-    """Execute the given SQL using the current migration context.
+    r"""Execute the given SQL using the current migration context.
 
     The given SQL can be a plain string, e.g.::
 
@@ -1035,7 +1035,7 @@ def execute(
     """
 
 def f(name: str) -> conv:
-    """Indicate a string name that has already had a naming convention
+    r"""Indicate a string name that has already had a naming convention
     applied to it.
 
     This feature combines with the SQLAlchemy ``naming_convention`` feature
@@ -1072,7 +1072,7 @@ def f(name: str) -> conv:
     """
 
 def get_bind() -> Connection:
-    """Return the current 'bind'.
+    r"""Return the current 'bind'.
 
     Under normal circumstances, this is the
     :class:`~sqlalchemy.engine.Connection` currently being used
@@ -1083,13 +1083,13 @@ def get_bind() -> Connection:
     """
 
 def get_context() -> MigrationContext:
-    """Return the :class:`.MigrationContext` object that's
+    r"""Return the :class:`.MigrationContext` object that's
     currently in use.
 
     """
 
 def implementation_for(op_cls: Any) -> Callable[..., Any]:
-    """Register an implementation for a given :class:`.MigrateOperation`.
+    r"""Register an implementation for a given :class:`.MigrateOperation`.
 
     This is part of the operation extensibility API.
 
@@ -1102,7 +1102,7 @@ def implementation_for(op_cls: Any) -> Callable[..., Any]:
 def inline_literal(
     value: Union[str, int], type_: None = None
 ) -> _literal_bindparam:
-    """Produce an 'inline literal' expression, suitable for
+    r"""Produce an 'inline literal' expression, suitable for
     using in an INSERT, UPDATE, or DELETE statement.
 
     When using Alembic in "offline" mode, CRUD operations
@@ -1145,7 +1145,7 @@ def inline_literal(
     """
 
 def invoke(operation: MigrateOperation) -> Any:
-    """Given a :class:`.MigrateOperation`, invoke it in terms of
+    r"""Given a :class:`.MigrateOperation`, invoke it in terms of
     this :class:`.Operations` instance.
 
     """
@@ -1153,7 +1153,7 @@ def invoke(operation: MigrateOperation) -> Any:
 def register_operation(
     name: str, sourcename: Optional[str] = None
 ) -> Callable[..., Any]:
-    """Register a new operation for this class.
+    r"""Register a new operation for this class.
 
     This method is normally used to add new operations
     to the :class:`.Operations` class, and possibly the
@@ -1172,7 +1172,7 @@ def register_operation(
 def rename_table(
     old_table_name: str, new_table_name: str, schema: Optional[str] = None
 ) -> Optional[Table]:
-    """Emit an ALTER TABLE to rename a table.
+    r"""Emit an ALTER TABLE to rename a table.
 
     :param old_table_name: old name.
     :param new_table_name: new name.
index da2caf6d2cc3c28eedebca6a008246eb38c9bb40..00f13a1bea399828c097c7826a69296f52e4315a 100644 (file)
@@ -484,7 +484,7 @@ class ApplyBatchImpl:
         table_name: str,
         column_name: str,
         nullable: Optional[bool] = None,
-        server_default: Optional[Union[Function, str, bool]] = False,
+        server_default: Optional[Union[Function[Any], str, bool]] = False,
         name: Optional[str] = None,
         type_: Optional[TypeEngine] = None,
         autoincrement: None = None,
index 48384f96654f6f809f26f43efc99c529a05ec726..a40704fdc2a8998faffbf01dc046b51232cae068 100644 (file)
@@ -911,7 +911,7 @@ class CreateIndexOp(MigrateOperation):
         operations: Operations,
         index_name: Optional[str],
         table_name: str,
-        columns: Sequence[Union[str, TextClause, Function]],
+        columns: Sequence[Union[str, TextClause, Function[Any]]],
         schema: Optional[str] = None,
         unique: bool = False,
         **kw: Any,
@@ -1885,7 +1885,7 @@ class AlterColumnOp(AlterTableOp):
         column_name: str,
         nullable: Optional[bool] = None,
         comment: Union[str, Literal[False]] = False,
-        server_default: Union[Function, bool] = False,
+        server_default: Union[Function[Any], bool] = False,
         new_column_name: Optional[str] = None,
         type_: Optional[Union[TypeEngine, Type[TypeEngine]]] = None,
         existing_type: Optional[Union[TypeEngine, Type[TypeEngine]]] = None,
index a441d1fd53739df3e1375d084defb01d377334aa..0f9d3a56c0fa8847717829d355dc1454705a5256 100644 (file)
@@ -897,7 +897,7 @@ class EnvironmentContext(util.ModuleClsProxy):
 
     def begin_transaction(
         self,
-    ) -> Union[_ProxyTransaction, ContextManager]:
+    ) -> Union[_ProxyTransaction, ContextManager[None]]:
         """Return a context manager that will
         enclose an operation within a "transaction",
         as defined by the environment's offline
index 95eb82a450a662ee1c2fd7839d9c95cc589151ea..a6156511fdfb1970837ebf04c762694005a5d106 100644 (file)
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 from contextlib import contextmanager
+from contextlib import nullcontext
 import logging
 import sys
 from typing import Any
@@ -366,7 +367,7 @@ class MigrationContext:
 
     def begin_transaction(
         self, _per_migration: bool = False
-    ) -> Union[_ProxyTransaction, ContextManager]:
+    ) -> Union[_ProxyTransaction, ContextManager[None]]:
         """Begin a logical transaction for migration operations.
 
         This method is used within an ``env.py`` script to demarcate where
@@ -408,12 +409,8 @@ class MigrationContext:
 
         """
 
-        @contextmanager
-        def do_nothing():
-            yield
-
         if self._in_external_transaction:
-            return do_nothing()
+            return nullcontext()
 
         if self.impl.transactional_ddl:
             transaction_now = _per_migration == self._transaction_per_migration
@@ -421,13 +418,13 @@ class MigrationContext:
             transaction_now = _per_migration is True
 
         if not transaction_now:
-            return do_nothing()
+            return nullcontext()
 
         elif not self.impl.transactional_ddl:
             assert _per_migration
 
             if self.as_sql:
-                return do_nothing()
+                return nullcontext()
             else:
                 # track our own notion of a "transaction block", which must be
                 # committed when complete.   Don't rely upon whether or not the
@@ -443,7 +440,7 @@ class MigrationContext:
                 in_transaction = self._transaction is not None
 
                 if in_transaction:
-                    return do_nothing()
+                    return nullcontext()
                 else:
                     assert self.connection is not None
                     self._transaction = (
diff --git a/docs/build/unreleased/1191.rst b/docs/build/unreleased/1191.rst
new file mode 100644 (file)
index 0000000..0ef1953
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: bug, typing
+    :tickets: 1191, 1201
+
+    Fixed various typing issues observed with pyright, including issues
+    involving the combination of :class:`.Function` and
+    :meth:`.MigrationContext.begin_transaction`.
index 376163b1d5aada147b4ed62865f90a407db445a7..aec18131c926f6d9c532965e530cbcd2a6591c6d 100644 (file)
@@ -217,7 +217,7 @@ def _generate_stub_for_meth(
 
     fn_doc = base_method.__doc__ if base_method else fn.__doc__
     has_docs = gen_docs and fn_doc is not None
-    docs = '"""' + f"{fn_doc}" + '"""' if has_docs else ""
+    docs = 'r"""' + f"{fn_doc}" + '"""' if has_docs else ""
 
     func_text = textwrap.dedent(
         f"""