From cfe92fac6794515d3aa3b995e288b11d5c9437fa Mon Sep 17 00:00:00 2001 From: CaselIT Date: Thu, 21 Apr 2022 23:23:00 +0200 Subject: [PATCH] Various typing related updates Change-Id: I778b63b1c438f31964d841576f0dd54ae1a5fadc --- alembic/autogenerate/api.py | 8 +- alembic/autogenerate/compare.py | 28 ++-- alembic/autogenerate/render.py | 2 + alembic/autogenerate/rewriter.py | 7 +- alembic/command.py | 11 +- alembic/config.py | 4 +- alembic/context.pyi | 9 +- alembic/ddl/base.py | 6 +- alembic/ddl/impl.py | 8 +- alembic/ddl/mssql.py | 8 +- alembic/ddl/mysql.py | 6 +- alembic/ddl/oracle.py | 2 + alembic/ddl/postgresql.py | 12 +- alembic/ddl/sqlite.py | 2 + alembic/op.pyi | 52 +++--- alembic/operations/base.py | 2 + alembic/operations/batch.py | 18 ++- alembic/operations/ops.py | 52 +++--- alembic/operations/schemaobj.py | 24 +-- alembic/runtime/environment.py | 4 +- alembic/runtime/migration.py | 33 ++-- alembic/script/base.py | 129 ++++++++------- alembic/script/revision.py | 176 +++++++++++++-------- alembic/script/write_hooks.py | 4 +- alembic/testing/assertions.py | 4 +- alembic/testing/fixtures.py | 2 + alembic/testing/suite/_autogen_fixtures.py | 11 +- alembic/testing/util.py | 3 +- alembic/util/__init__.py | 1 + alembic/util/compat.py | 6 +- alembic/util/editor.py | 2 + alembic/util/langhelpers.py | 7 + alembic/util/messaging.py | 2 + alembic/util/pyfiles.py | 2 + alembic/util/sqla_compat.py | 3 + pyproject.toml | 14 ++ tests/requirements.py | 7 +- tools/write_pyi.py | 56 ++++++- 38 files changed, 462 insertions(+), 265 deletions(-) diff --git a/alembic/autogenerate/api.py b/alembic/autogenerate/api.py index a5528ff7..4ab8a359 100644 --- a/alembic/autogenerate/api.py +++ b/alembic/autogenerate/api.py @@ -1,5 +1,4 @@ -"""Provide the 'autogenerate' feature which can produce migration operations -automatically.""" +from __future__ import annotations import contextlib from typing import Any @@ -19,6 +18,9 @@ from . import render from .. 
import util from ..operations import ops +"""Provide the 'autogenerate' feature which can produce migration operations +automatically.""" + if TYPE_CHECKING: from sqlalchemy.engine import Connection from sqlalchemy.engine import Dialect @@ -515,7 +517,7 @@ class RevisionContext: branch_labels=migration_script.branch_label, version_path=migration_script.version_path, depends_on=migration_script.depends_on, - **template_args + **template_args, ) def run_autogenerate( diff --git a/alembic/autogenerate/compare.py b/alembic/autogenerate/compare.py index 528b17ac..693efae9 100644 --- a/alembic/autogenerate/compare.py +++ b/alembic/autogenerate/compare.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import contextlib import logging import re @@ -282,7 +284,7 @@ def _make_index(params: Dict[str, Any], conn_table: "Table") -> "Index": params["name"], *[conn_table.c[cname] for cname in params["column_names"]], unique=params["unique"], - _table=conn_table + _table=conn_table, ) if "duplicates_constraint" in params: ix.info["duplicates_constraint"] = params["duplicates_constraint"] @@ -294,7 +296,7 @@ def _make_unique_constraint( ) -> "UniqueConstraint": uq = sa_schema.UniqueConstraint( *[conn_table.c[cname] for cname in params["column_names"]], - name=params["name"] + name=params["name"], ) if "duplicates_index" in params: uq.info["duplicates_index"] = params["duplicates_index"] @@ -1245,7 +1247,7 @@ def _compare_foreign_keys( if isinstance(fk, sa_schema.ForeignKeyConstraint) ) - conn_fks = [ + conn_fks_list = [ fk for fk in inspector.get_foreign_keys(tname, schema=schema) if autogen_context.run_name_filters( @@ -1255,9 +1257,13 @@ def _compare_foreign_keys( ) ] - backend_reflects_fk_options = bool(conn_fks and "options" in conn_fks[0]) + backend_reflects_fk_options = bool( + conn_fks_list and "options" in conn_fks_list[0] + ) - conn_fks = set(_make_foreign_key(const, conn_table) for const in conn_fks) + conn_fks = set( + _make_foreign_key(const, conn_table) for const in conn_fks_list + ) # give the dialect a chance to correct the FKs to match more # closely @@ -1265,24 +1271,24 @@ def _compare_foreign_keys( conn_fks, metadata_fks ) - metadata_fks = set( + metadata_fks_sig = set( _fk_constraint_sig(fk, include_options=backend_reflects_fk_options) for fk in metadata_fks ) - conn_fks = set( + conn_fks_sig = set( _fk_constraint_sig(fk, include_options=backend_reflects_fk_options) for fk in conn_fks ) - conn_fks_by_sig = dict((c.sig, c) for c in conn_fks) - metadata_fks_by_sig = dict((c.sig, c) for c in metadata_fks) + conn_fks_by_sig = dict((c.sig, c) for c in conn_fks_sig) + metadata_fks_by_sig = dict((c.sig, c) for c in metadata_fks_sig) metadata_fks_by_name = dict( - (c.name, c) for c in metadata_fks if c.name is not None + (c.name, c) for c in metadata_fks_sig if c.name is not None ) conn_fks_by_name = dict( - (c.name, c) for c in conn_fks if c.name is not None + (c.name, c) for c in conn_fks_sig if c.name is not None ) def _add_fk(obj, compare_to): diff --git a/alembic/autogenerate/render.py b/alembic/autogenerate/render.py index 77f0a866..9c992b47 100644 --- a/alembic/autogenerate/render.py +++ b/alembic/autogenerate/render.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from collections import OrderedDict from io import StringIO import re diff --git a/alembic/autogenerate/rewriter.py b/alembic/autogenerate/rewriter.py index 0fdd3982..79f665a0 100644 --- a/alembic/autogenerate/rewriter.py +++ b/alembic/autogenerate/rewriter.py @@ -1,7 +1,10 @@ +from __future__ import annotations 
+ from typing import Any from typing import Callable from typing import Iterator from typing import List +from typing import Optional from typing import Type from typing import TYPE_CHECKING from typing import Union @@ -49,12 +52,12 @@ class Rewriter: _traverse = util.Dispatcher() - _chained = None + _chained: Optional[Rewriter] = None def __init__(self) -> None: self.dispatch = util.Dispatcher() - def chain(self, other: "Rewriter") -> "Rewriter": + def chain(self, other: Rewriter) -> Rewriter: """Produce a "chain" of this :class:`.Rewriter` to another. This allows two rewriters to operate serially on a stream, diff --git a/alembic/command.py b/alembic/command.py index c1117244..d48affc7 100644 --- a/alembic/command.py +++ b/alembic/command.py @@ -1,6 +1,7 @@ +from __future__ import annotations + import os from typing import Callable -from typing import cast from typing import List from typing import Optional from typing import TYPE_CHECKING @@ -86,8 +87,9 @@ def init( for file_ in os.listdir(template_dir): file_path = os.path.join(template_dir, file_) if file_ == "alembic.ini.mako": - config_file = os.path.abspath(cast(str, config.config_file_name)) - if os.access(cast(str, config_file), os.F_OK): + assert config.config_file_name is not None + config_file = os.path.abspath(config.config_file_name) + if os.access(config_file, os.F_OK): util.msg("File %s already exists, skipping" % config_file) else: script._generate_template( @@ -273,7 +275,7 @@ def merge( refresh=True, head=revisions, branch_labels=branch_label, - **template_args # type:ignore[arg-type] + **template_args, # type:ignore[arg-type] ) @@ -642,6 +644,7 @@ def edit(config: "Config", rev: str) -> None: "No revision files indicated by symbol '%s'" % rev ) for sc in revs: + assert sc util.open_in_editor(sc.path) diff --git a/alembic/config.py b/alembic/config.py index f868bf73..dcfb9288 100644 --- a/alembic/config.py +++ b/alembic/config.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from argparse import ArgumentParser from argparse import Namespace from configparser import ConfigParser @@ -559,7 +561,7 @@ class CommandLine: fn( config, *[getattr(options, k, None) for k in positional], - **dict((k, getattr(options, k, None)) for k in kwarg) + **dict((k, getattr(options, k, None)) for k in kwarg), ) except util.CommandError as e: if options.raiseerr: diff --git a/alembic/context.pyi b/alembic/context.pyi index 5c29d3ae..5ec47037 100644 --- a/alembic/context.pyi +++ b/alembic/context.pyi @@ -1,5 +1,6 @@ # ### this file stubs are generated by tools/write_pyi.py - do not edit ### # ### imports are manually managed +from __future__ import annotations from typing import Callable from typing import ContextManager @@ -67,7 +68,7 @@ def begin_transaction() -> Union["_ProxyTransaction", ContextManager]: config: Config def configure( - connection: Optional["Connection"] = None, + connection: Optional[Connection] = None, url: Optional[str] = None, dialect_name: Optional[str] = None, dialect_opts: Optional[dict] = None, @@ -78,7 +79,7 @@ def configure( tag: Optional[str] = None, template_args: Optional[dict] = None, render_as_batch: bool = False, - target_metadata: Optional["MetaData"] = None, + target_metadata: Optional[MetaData] = None, include_name: Optional[Callable] = None, include_object: Optional[Callable] = None, include_schemas: bool = False, @@ -93,7 +94,7 @@ def configure( sqlalchemy_module_prefix: str = "sa.", user_module_prefix: Optional[str] = None, on_version_apply: Optional[Callable] = None, - **kw + **kw, ) -> None: 
"""Configure a :class:`.MigrationContext` within this :class:`.EnvironmentContext` which will provide database @@ -553,7 +554,7 @@ def get_bind(): """ -def get_context() -> "MigrationContext": +def get_context() -> MigrationContext: """Return the current :class:`.MigrationContext` object. If :meth:`.EnvironmentContext.configure` has not been diff --git a/alembic/ddl/base.py b/alembic/ddl/base.py index 022dc244..7b0f63e8 100644 --- a/alembic/ddl/base.py +++ b/alembic/ddl/base.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import functools from typing import Optional from typing import TYPE_CHECKING @@ -114,7 +116,7 @@ class ColumnDefault(AlterColumn): name: str, column_name: str, default: Optional[_ServerDefault], - **kw + **kw, ) -> None: super(ColumnDefault, self).__init__(name, column_name, **kw) self.default = default @@ -135,7 +137,7 @@ class IdentityColumnDefault(AlterColumn): column_name: str, default: Optional["Identity"], impl: "DefaultImpl", - **kw + **kw, ) -> None: super(IdentityColumnDefault, self).__init__(name, column_name, **kw) self.default = default diff --git a/alembic/ddl/impl.py b/alembic/ddl/impl.py index 10dcc734..8c9e0b91 100644 --- a/alembic/ddl/impl.py +++ b/alembic/ddl/impl.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from collections import namedtuple import re from typing import Any @@ -215,7 +217,7 @@ class DefaultImpl(metaclass=ImplMeta): existing_server_default: Optional["_ServerDefault"] = None, existing_nullable: Optional[bool] = None, existing_autoincrement: Optional[bool] = None, - **kw: Any + **kw: Any, ) -> None: if autoincrement is not None or existing_autoincrement is not None: util.warn( @@ -266,7 +268,7 @@ class DefaultImpl(metaclass=ImplMeta): existing_server_default=existing_server_default, existing_nullable=existing_nullable, existing_comment=existing_comment, - **kw + **kw, ) ) if type_ is not None: @@ -324,7 +326,7 @@ class DefaultImpl(metaclass=ImplMeta): table_name: str, column: "Column", schema: Optional[str] = None, - **kw + **kw, ) -> None: self._exec(base.DropColumn(table_name, column, schema=schema)) diff --git a/alembic/ddl/mssql.py b/alembic/ddl/mssql.py index 4ea67180..b48f8ba9 100644 --- a/alembic/ddl/mssql.py +++ b/alembic/ddl/mssql.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Any from typing import List from typing import Optional @@ -96,7 +98,7 @@ class MSSQLImpl(DefaultImpl): existing_type: Optional["TypeEngine"] = None, existing_server_default: Optional["_ServerDefault"] = None, existing_nullable: Optional[bool] = None, - **kw: Any + **kw: Any, ) -> None: if nullable is not None: @@ -145,7 +147,7 @@ class MSSQLImpl(DefaultImpl): schema=schema, existing_type=existing_type, existing_nullable=existing_nullable, - **kw + **kw, ) if server_default is not False and used_default is False: @@ -203,7 +205,7 @@ class MSSQLImpl(DefaultImpl): table_name: str, column: "Column", schema: Optional[str] = None, - **kw + **kw, ) -> None: drop_default = kw.pop("mssql_drop_default", False) if drop_default: diff --git a/alembic/ddl/mysql.py b/alembic/ddl/mysql.py index c3d66bdf..0c03fbe1 100644 --- a/alembic/ddl/mysql.py +++ b/alembic/ddl/mysql.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import re from typing import Any from typing import Optional @@ -60,7 +62,7 @@ class MySQLImpl(DefaultImpl): existing_autoincrement: Optional[bool] = None, comment: Optional[Union[str, "Literal[False]"]] = False, existing_comment: Optional[str] = None, - **kw: Any + **kw: Any, ) -> None: if 
sqla_compat._server_default_is_identity( server_default, existing_server_default @@ -79,7 +81,7 @@ class MySQLImpl(DefaultImpl): existing_nullable=existing_nullable, server_default=server_default, existing_server_default=existing_server_default, - **kw + **kw, ) if name is not None or self._is_mysql_allowed_functional_default( type_ if type_ is not None else existing_type, server_default diff --git a/alembic/ddl/oracle.py b/alembic/ddl/oracle.py index 6dff6514..0e787fb1 100644 --- a/alembic/ddl/oracle.py +++ b/alembic/ddl/oracle.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Any from typing import Optional from typing import TYPE_CHECKING diff --git a/alembic/ddl/postgresql.py b/alembic/ddl/postgresql.py index 6174f382..019eb3c7 100644 --- a/alembic/ddl/postgresql.py +++ b/alembic/ddl/postgresql.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging import re from typing import Any @@ -143,7 +145,7 @@ class PostgresqlImpl(DefaultImpl): existing_server_default: Optional["_ServerDefault"] = None, existing_nullable: Optional[bool] = None, existing_autoincrement: Optional[bool] = None, - **kw: Any + **kw: Any, ) -> None: using = kw.pop("postgresql_using", None) @@ -179,7 +181,7 @@ class PostgresqlImpl(DefaultImpl): existing_server_default=existing_server_default, existing_nullable=existing_nullable, existing_autoincrement=existing_autoincrement, - **kw + **kw, ) def autogen_column_reflect(self, inspector, table, column_info): @@ -417,7 +419,7 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp): where: Optional[Union["BinaryExpression", str]] = None, schema: Optional[str] = None, _orig_constraint: Optional["ExcludeConstraint"] = None, - **kw + **kw, ) -> None: self.constraint_name = constraint_name self.table_name = table_name @@ -459,7 +461,7 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp): *self.elements, name=self.constraint_name, where=self.where, - **self.kw + **self.kw, ) for ( expr, @@ -477,7 +479,7 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp): constraint_name: str, table_name: str, *elements: Any, - **kw: Any + **kw: Any, ) -> Optional["Table"]: """Issue an alter to create an EXCLUDE constraint using the current migration context. diff --git a/alembic/ddl/sqlite.py b/alembic/ddl/sqlite.py index f9161472..9b387664 100644 --- a/alembic/ddl/sqlite.py +++ b/alembic/ddl/sqlite.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import re from typing import Any from typing import Dict diff --git a/alembic/op.pyi b/alembic/op.pyi index d781ac1b..9e3169ad 100644 --- a/alembic/op.pyi +++ b/alembic/op.pyi @@ -33,8 +33,8 @@ if TYPE_CHECKING: ### end imports ### def add_column( - table_name: str, column: "Column", schema: Optional[str] = None -) -> Optional["Table"]: + table_name: str, column: Column, schema: Optional[str] = None +) -> Optional[Table]: """Issue an "add column" instruction using the current migration context. 
@@ -91,16 +91,16 @@ def alter_column( comment: Union[str, bool, None] = False, server_default: Any = False, new_column_name: Optional[str] = None, - type_: Union["TypeEngine", Type["TypeEngine"], None] = None, - existing_type: Union["TypeEngine", Type["TypeEngine"], None] = None, + type_: Union[TypeEngine, Type[TypeEngine], None] = None, + existing_type: Union[TypeEngine, Type[TypeEngine], None] = None, existing_server_default: Union[ - str, bool, "Identity", "Computed", None + str, bool, Identity, Computed, None ] = False, existing_nullable: Optional[bool] = None, existing_comment: Optional[str] = None, schema: Optional[str] = None, **kw -) -> Optional["Table"]: +) -> Optional[Table]: """Issue an "alter column" instruction using the current migration context. @@ -340,7 +340,7 @@ def batch_alter_table( """ def bulk_insert( - table: Union["Table", "TableClause"], + table: Union[Table, TableClause], rows: List[dict], multiinsert: bool = True, ) -> None: @@ -422,10 +422,10 @@ def bulk_insert( def create_check_constraint( constraint_name: Optional[str], table_name: str, - condition: Union[str, "BinaryExpression"], + condition: Union[str, BinaryExpression], schema: Optional[str] = None, **kw -) -> Optional["Table"]: +) -> Optional[Table]: """Issue a "create check constraint" instruction using the current migration context. @@ -469,7 +469,7 @@ def create_check_constraint( def create_exclude_constraint( constraint_name: str, table_name: str, *elements: Any, **kw: Any -) -> Optional["Table"]: +) -> Optional[Table]: """Issue an alter to create an EXCLUDE constraint using the current migration context. @@ -521,7 +521,7 @@ def create_foreign_key( source_schema: Optional[str] = None, referent_schema: Optional[str] = None, **dialect_kw -) -> Optional["Table"]: +) -> Optional[Table]: """Issue a "create foreign key" instruction using the current migration context. @@ -570,11 +570,11 @@ def create_foreign_key( def create_index( index_name: str, table_name: str, - columns: Sequence[Union[str, "TextClause", "Function"]], + columns: Sequence[Union[str, TextClause, Function]], schema: Optional[str] = None, unique: bool = False, **kw -) -> Optional["Table"]: +) -> Optional[Table]: """Issue a "create index" instruction using the current migration context. @@ -622,7 +622,7 @@ def create_primary_key( table_name: str, columns: List[str], schema: Optional[str] = None, -) -> Optional["Table"]: +) -> Optional[Table]: """Issue a "create primary key" instruction using the current migration context. @@ -660,7 +660,7 @@ def create_primary_key( """ -def create_table(table_name: str, *columns, **kw) -> Optional["Table"]: +def create_table(table_name: str, *columns, **kw) -> Optional[Table]: """Issue a "create table" instruction using the current migration context. @@ -743,7 +743,7 @@ def create_table_comment( comment: Optional[str], existing_comment: None = None, schema: Optional[str] = None, -) -> Optional["Table"]: +) -> Optional[Table]: """Emit a COMMENT ON operation to set the comment for a table. .. versionadded:: 1.0.6 @@ -811,7 +811,7 @@ def create_unique_constraint( def drop_column( table_name: str, column_name: str, schema: Optional[str] = None, **kw -) -> Optional["Table"]: +) -> Optional[Table]: """Issue a "drop column" instruction using the current migration context. @@ -854,7 +854,7 @@ def drop_constraint( table_name: str, type_: Optional[str] = None, schema: Optional[str] = None, -) -> Optional["Table"]: +) -> Optional[Table]: """Drop a constraint of the given name, typically via DROP CONSTRAINT. 
:param constraint_name: name of the constraint. @@ -873,7 +873,7 @@ def drop_index( table_name: Optional[str] = None, schema: Optional[str] = None, **kw -) -> Optional["Table"]: +) -> Optional[Table]: """Issue a "drop index" instruction using the current migration context. @@ -921,7 +921,7 @@ def drop_table_comment( table_name: str, existing_comment: Optional[str] = None, schema: Optional[str] = None, -) -> Optional["Table"]: +) -> Optional[Table]: """Issue a "drop table comment" operation to remove an existing comment set on a table. @@ -940,8 +940,8 @@ def drop_table_comment( """ def execute( - sqltext: Union[str, "TextClause", "Update"], execution_options: None = None -) -> Optional["Table"]: + sqltext: Union[str, TextClause, Update], execution_options: None = None +) -> Optional[Table]: """Execute the given SQL using the current migration context. The given SQL can be a plain string, e.g.:: @@ -1024,7 +1024,7 @@ def execute( :meth:`sqlalchemy.engine.Connection.execution_options`. """ -def f(name: str) -> "conv": +def f(name: str) -> conv: """Indicate a string name that has already had a naming convention applied to it. @@ -1061,7 +1061,7 @@ def f(name: str) -> "conv": """ -def get_bind() -> "Connection": +def get_bind() -> Connection: """Return the current 'bind'. Under normal circumstances, this is the @@ -1134,7 +1134,7 @@ def inline_literal( """ -def invoke(operation: "MigrateOperation") -> Any: +def invoke(operation: MigrateOperation) -> Any: """Given a :class:`.MigrateOperation`, invoke it in terms of this :class:`.Operations` instance. @@ -1161,7 +1161,7 @@ def register_operation( def rename_table( old_table_name: str, new_table_name: str, schema: Optional[str] = None -) -> Optional["Table"]: +) -> Optional[Table]: """Emit an ALTER TABLE to rename a table. :param old_table_name: old name. 
diff --git a/alembic/operations/base.py b/alembic/operations/base.py index 07ddd5af..68b620fa 100644 --- a/alembic/operations/base.py +++ b/alembic/operations/base.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from contextlib import contextmanager import re import textwrap diff --git a/alembic/operations/batch.py b/alembic/operations/batch.py index ba1d1967..308bc2e8 100644 --- a/alembic/operations/batch.py +++ b/alembic/operations/batch.py @@ -1,5 +1,6 @@ +from __future__ import annotations + from typing import Any -from typing import cast from typing import Dict from typing import List from typing import Optional @@ -122,7 +123,7 @@ class BatchOperationsImpl: schema=self.schema, autoload_with=self.operations.get_bind(), *self.reflect_args, - **self.reflect_kwargs + **self.reflect_kwargs, ) reflected = True @@ -311,7 +312,7 @@ class ApplyBatchImpl: m, *(list(self.columns.values()) + list(self.table_args)), schema=schema, - **self.table_kwargs + **self.table_kwargs, ) for const in ( @@ -360,7 +361,7 @@ class ApplyBatchImpl: index.name, unique=index.unique, *[self.new_table.c[col] for col in index.columns.keys()], - **index.kwargs + **index.kwargs, ) ) return idx @@ -401,7 +402,7 @@ class ApplyBatchImpl: for elem in constraint.elements ] ], - schema=referent_schema + schema=referent_schema, ) def _create(self, op_impl: "DefaultImpl") -> None: @@ -453,7 +454,7 @@ class ApplyBatchImpl: type_: Optional["TypeEngine"] = None, autoincrement: None = None, comment: Union[str, "Literal[False]"] = False, - **kw + **kw, ) -> None: existing = self.columns[column_name] existing_transfer: Dict[str, Any] = self.column_transfers[column_name] @@ -574,7 +575,7 @@ class ApplyBatchImpl: column: "Column", insert_before: Optional[str] = None, insert_after: Optional[str] = None, - **kw + **kw, ) -> None: self._setup_dependencies_for_add_column( column.name, insert_before, insert_after @@ -647,7 +648,8 @@ class ApplyBatchImpl: if col_const.name == const.name: self.columns[col.name].constraints.remove(col_const) else: - const = self.named_constraints.pop(cast(str, const.name)) + assert const.name + const = self.named_constraints.pop(const.name) except KeyError: if _is_type_bound(const): # type-bound constraints are only included in the new diff --git a/alembic/operations/ops.py b/alembic/operations/ops.py index 99132dd6..176c6ba6 100644 --- a/alembic/operations/ops.py +++ b/alembic/operations/ops.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from abc import abstractmethod import re from typing import Any @@ -258,7 +260,7 @@ class CreatePrimaryKeyOp(AddConstraintOp): table_name: str, columns: Sequence[str], schema: Optional[str] = None, - **kw + **kw, ) -> None: self.constraint_name = constraint_name self.table_name = table_name @@ -383,7 +385,7 @@ class CreateUniqueConstraintOp(AddConstraintOp): table_name: str, columns: Sequence[str], schema: Optional[str] = None, - **kw + **kw, ) -> None: self.constraint_name = constraint_name self.table_name = table_name @@ -434,7 +436,7 @@ class CreateUniqueConstraintOp(AddConstraintOp): table_name: str, columns: Sequence[str], schema: Optional[str] = None, - **kw + **kw, ) -> Any: """Issue a "create unique constraint" instruction using the current migration context. @@ -483,7 +485,7 @@ class CreateUniqueConstraintOp(AddConstraintOp): operations: "BatchOperations", constraint_name: str, columns: Sequence[str], - **kw + **kw, ) -> Any: """Issue a "create unique constraint" instruction using the current batch migration context. 
@@ -518,7 +520,7 @@ class CreateForeignKeyOp(AddConstraintOp): referent_table: str, local_cols: List[str], remote_cols: List[str], - **kw + **kw, ) -> None: self.constraint_name = constraint_name self.source_table = source_table @@ -600,7 +602,7 @@ class CreateForeignKeyOp(AddConstraintOp): match: Optional[str] = None, source_schema: Optional[str] = None, referent_schema: Optional[str] = None, - **dialect_kw + **dialect_kw, ) -> Optional["Table"]: """Issue a "create foreign key" instruction using the current migration context. @@ -678,7 +680,7 @@ class CreateForeignKeyOp(AddConstraintOp): deferrable: None = None, initially: None = None, match: None = None, - **dialect_kw + **dialect_kw, ) -> None: """Issue a "create foreign key" instruction using the current batch migration context. @@ -734,7 +736,7 @@ class CreateCheckConstraintOp(AddConstraintOp): table_name: str, condition: Union[str, "TextClause", "ColumnElement[Any]"], schema: Optional[str] = None, - **kw + **kw, ) -> None: self.constraint_name = constraint_name self.table_name = table_name @@ -753,9 +755,7 @@ class CreateCheckConstraintOp(AddConstraintOp): return cls( ck_constraint.name, constraint_table.name, - cast( - "Union[TextClause, ColumnElement[Any]]", ck_constraint.sqltext - ), + cast("ColumnElement[Any]", ck_constraint.sqltext), schema=constraint_table.schema, **ck_constraint.dialect_kwargs, ) @@ -780,7 +780,7 @@ class CreateCheckConstraintOp(AddConstraintOp): table_name: str, condition: Union[str, "BinaryExpression"], schema: Optional[str] = None, - **kw + **kw, ) -> Optional["Table"]: """Issue a "create check constraint" instruction using the current migration context. @@ -831,7 +831,7 @@ class CreateCheckConstraintOp(AddConstraintOp): operations: "BatchOperations", constraint_name: str, condition: "TextClause", - **kw + **kw, ) -> Optional["Table"]: """Issue a "create check constraint" instruction using the current batch migration context. @@ -866,7 +866,7 @@ class CreateIndexOp(MigrateOperation): columns: Sequence[Union[str, "TextClause", "ColumnElement[Any]"]], schema: Optional[str] = None, unique: bool = False, - **kw + **kw, ) -> None: self.index_name = index_name self.table_name = table_name @@ -917,7 +917,7 @@ class CreateIndexOp(MigrateOperation): columns: Sequence[Union[str, "TextClause", "Function"]], schema: Optional[str] = None, unique: bool = False, - **kw + **kw, ) -> Optional["Table"]: r"""Issue a "create index" instruction using the current migration context. @@ -971,7 +971,7 @@ class CreateIndexOp(MigrateOperation): operations: "BatchOperations", index_name: str, columns: List[str], - **kw + **kw, ) -> Optional["Table"]: """Issue a "create index" instruction using the current batch migration context. @@ -1003,7 +1003,7 @@ class DropIndexOp(MigrateOperation): table_name: Optional[str] = None, schema: Optional[str] = None, _reverse: Optional["CreateIndexOp"] = None, - **kw + **kw, ) -> None: self.index_name = index_name self.table_name = table_name @@ -1050,7 +1050,7 @@ class DropIndexOp(MigrateOperation): index_name: str, table_name: Optional[str] = None, schema: Optional[str] = None, - **kw + **kw, ) -> Optional["Table"]: r"""Issue a "drop index" instruction using the current migration context. 
@@ -1109,7 +1109,7 @@ class CreateTableOp(MigrateOperation): schema: Optional[str] = None, _namespace_metadata: Optional["MetaData"] = None, _constraints_included: bool = False, - **kw + **kw, ) -> None: self.table_name = table_name self.columns = columns @@ -1326,7 +1326,7 @@ class DropTableOp(MigrateOperation): operations: "Operations", table_name: str, schema: Optional[str] = None, - **kw: Any + **kw: Any, ) -> None: r"""Issue a "drop table" instruction using the current migration context. @@ -1607,7 +1607,7 @@ class AlterColumnOp(AlterTableOp): modify_server_default: Any = False, modify_name: Optional[str] = None, modify_type: Optional[Any] = None, - **kw + **kw, ) -> None: super(AlterColumnOp, self).__init__(table_name, schema=schema) self.column_name = column_name @@ -1770,7 +1770,7 @@ class AlterColumnOp(AlterTableOp): existing_nullable: Optional[bool] = None, existing_comment: Optional[str] = None, schema: Optional[str] = None, - **kw + **kw, ) -> Optional["Table"]: r"""Issue an "alter column" instruction using the current migration context. @@ -1897,7 +1897,7 @@ class AlterColumnOp(AlterTableOp): existing_comment: None = None, insert_before: None = None, insert_after: None = None, - **kw + **kw, ) -> Optional["Table"]: """Issue an "alter column" instruction using the current batch migration context. @@ -1954,7 +1954,7 @@ class AddColumnOp(AlterTableOp): table_name: str, column: "Column", schema: Optional[str] = None, - **kw + **kw, ) -> None: super(AddColumnOp, self).__init__(table_name, schema=schema) self.column = column @@ -2089,7 +2089,7 @@ class DropColumnOp(AlterTableOp): column_name: str, schema: Optional[str] = None, _reverse: Optional["AddColumnOp"] = None, - **kw + **kw, ) -> None: super(DropColumnOp, self).__init__(table_name, schema=schema) self.column_name = column_name @@ -2146,7 +2146,7 @@ class DropColumnOp(AlterTableOp): table_name: str, column_name: str, schema: Optional[str] = None, - **kw + **kw, ) -> Optional["Table"]: """Issue a "drop column" instruction using the current migration context. 
diff --git a/alembic/operations/schemaobj.py b/alembic/operations/schemaobj.py index 0a27920b..6c6f9714 100644 --- a/alembic/operations/schemaobj.py +++ b/alembic/operations/schemaobj.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Any from typing import Dict from typing import List @@ -44,7 +46,7 @@ class SchemaObjects: table_name: str, cols: Sequence[str], schema: Optional[str] = None, - **dialect_kw + **dialect_kw, ) -> "PrimaryKeyConstraint": m = self.metadata() columns = [sa_schema.Column(n, NULLTYPE) for n in cols] @@ -68,7 +70,7 @@ class SchemaObjects: referent_schema: Optional[str] = None, initially: Optional[str] = None, match: Optional[str] = None, - **dialect_kw + **dialect_kw, ) -> "ForeignKeyConstraint": m = self.metadata() if source == referent and source_schema == referent_schema: @@ -79,14 +81,14 @@ class SchemaObjects: referent, m, *[sa_schema.Column(n, NULLTYPE) for n in remote_cols], - schema=referent_schema + schema=referent_schema, ) t1 = sa_schema.Table( source, m, *[sa_schema.Column(n, NULLTYPE) for n in t1_cols], - schema=source_schema + schema=source_schema, ) tname = ( @@ -105,7 +107,7 @@ class SchemaObjects: ondelete=ondelete, deferrable=deferrable, initially=initially, - **dialect_kw + **dialect_kw, ) t1.append_constraint(f) @@ -117,13 +119,13 @@ class SchemaObjects: source: str, local_cols: Sequence[str], schema: Optional[str] = None, - **kw + **kw, ) -> "UniqueConstraint": t = sa_schema.Table( source, self.metadata(), *[sa_schema.Column(n, NULLTYPE) for n in local_cols], - schema=schema + schema=schema, ) kw["name"] = name uq = sa_schema.UniqueConstraint(*[t.c[n] for n in local_cols], **kw) @@ -138,7 +140,7 @@ class SchemaObjects: source: str, condition: Union[str, "TextClause", "ColumnElement[Any]"], schema: Optional[str] = None, - **kw + **kw, ) -> Union["CheckConstraint"]: t = sa_schema.Table( source, @@ -156,7 +158,7 @@ class SchemaObjects: table_name: str, type_: Optional[str], schema: Optional[str] = None, - **kw + **kw, ) -> Any: t = self.table(table_name, schema=schema) types: Dict[Optional[str], Any] = { @@ -237,7 +239,7 @@ class SchemaObjects: tablename: Optional[str], columns: Sequence[Union[str, "TextClause", "ColumnElement[Any]"]], schema: Optional[str] = None, - **kw + **kw, ) -> "Index": t = sa_schema.Table( tablename or "no_table", @@ -248,7 +250,7 @@ class SchemaObjects: idx = sa_schema.Index( name, *[util.sqla_compat._textual_index_column(t, n) for n in columns], - **kw + **kw, ) return idx diff --git a/alembic/runtime/environment.py b/alembic/runtime/environment.py index f3473de4..5fcee573 100644 --- a/alembic/runtime/environment.py +++ b/alembic/runtime/environment.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Callable from typing import ContextManager from typing import Dict @@ -345,7 +347,7 @@ class EnvironmentContext(util.ModuleClsProxy): sqlalchemy_module_prefix: str = "sa.", user_module_prefix: Optional[str] = None, on_version_apply: Optional[Callable] = None, - **kw + **kw, ) -> None: """Configure a :class:`.MigrationContext` within this :class:`.EnvironmentContext` which will provide database diff --git a/alembic/runtime/migration.py b/alembic/runtime/migration.py index 466264ef..c09c8e41 100644 --- a/alembic/runtime/migration.py +++ b/alembic/runtime/migration.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from contextlib import contextmanager import logging import sys @@ -39,6 +41,7 @@ if TYPE_CHECKING: from ..config import Config from ..script.base import Script 
from ..script.base import ScriptDirectory + from ..script.revision import _RevisionOrBase from ..script.revision import Revision from ..script.revision import RevisionMap @@ -516,7 +519,7 @@ class MigrationContext: elif start_from_rev is not None and self.script: start_from_rev = [ - self.script.get_revision(sfr).revision + cast("Script", self.script.get_revision(sfr)).revision for sfr in util.to_list(start_from_rev) if sfr not in (None, "base") ] @@ -860,15 +863,15 @@ class MigrationInfo: """ - is_upgrade: bool = None # type:ignore[assignment] + is_upgrade: bool """True/False: indicates whether this operation ascends or descends the version tree.""" - is_stamp: bool = None # type:ignore[assignment] + is_stamp: bool """True/False: indicates whether this operation is a stamp (i.e. whether it results in any actual database operations).""" - up_revision_id: Optional[str] = None + up_revision_id: Optional[str] """Version string corresponding to :attr:`.Revision.revision`. In the case of a stamp operation, it is advised to use the @@ -882,10 +885,10 @@ class MigrationInfo: """ - up_revision_ids: Tuple[str, ...] = None # type:ignore[assignment] + up_revision_ids: Tuple[str, ...] """Tuple of version strings corresponding to :attr:`.Revision.revision`. - In the majority of cases, this tuple will be a single value, synonomous + In the majority of cases, this tuple will be a single value, synonymous with the scalar value of :attr:`.MigrationInfo.up_revision_id`. It can be multiple revision identifiers only in the case of an ``alembic stamp`` operation which is moving downwards from multiple @@ -893,7 +896,7 @@ class MigrationInfo: """ - down_revision_ids: Tuple[str, ...] = None # type:ignore[assignment] + down_revision_ids: Tuple[str, ...] """Tuple of strings representing the base revisions of this migration step. If empty, this represents a root revision; otherwise, the first item @@ -901,7 +904,7 @@ class MigrationInfo: from dependencies. """ - revision_map: "RevisionMap" = None # type:ignore[assignment] + revision_map: "RevisionMap" """The revision map inside of which this operation occurs.""" def __init__( @@ -950,7 +953,7 @@ class MigrationInfo: ) @property - def up_revision(self) -> "Revision": + def up_revision(self) -> Optional[Revision]: """Get :attr:`~.MigrationInfo.up_revision_id` as a :class:`.Revision`. 
@@ -958,25 +961,25 @@ class MigrationInfo: return self.revision_map.get_revision(self.up_revision_id) @property - def up_revisions(self) -> Tuple["Revision", ...]: + def up_revisions(self) -> Tuple[Optional[_RevisionOrBase], ...]: """Get :attr:`~.MigrationInfo.up_revision_ids` as a :class:`.Revision`.""" return self.revision_map.get_revisions(self.up_revision_ids) @property - def down_revisions(self) -> Tuple["Revision", ...]: + def down_revisions(self) -> Tuple[Optional[_RevisionOrBase], ...]: """Get :attr:`~.MigrationInfo.down_revision_ids` as a tuple of :class:`Revisions <.Revision>`.""" return self.revision_map.get_revisions(self.down_revision_ids) @property - def source_revisions(self) -> Tuple["Revision", ...]: + def source_revisions(self) -> Tuple[Optional[_RevisionOrBase], ...]: """Get :attr:`~MigrationInfo.source_revision_ids` as a tuple of :class:`Revisions <.Revision>`.""" return self.revision_map.get_revisions(self.source_revision_ids) @property - def destination_revisions(self) -> Tuple["Revision", ...]: + def destination_revisions(self) -> Tuple[Optional[_RevisionOrBase], ...]: """Get :attr:`~MigrationInfo.destination_revision_ids` as a tuple of :class:`Revisions <.Revision>`.""" return self.revision_map.get_revisions(self.destination_revision_ids) @@ -1059,7 +1062,7 @@ class RevisionStep(MigrationStep): ) @property - def doc(self): + def doc(self) -> str: return self.revision.doc @property @@ -1264,7 +1267,7 @@ class StampStep(MigrationStep): self.migration_fn = self.stamp_revision self.revision_map = revision_map - doc = None + doc: None = None def stamp_revision(self, **kw) -> None: return None diff --git a/alembic/script/base.py b/alembic/script/base.py index ef0fd52a..ccbf86c9 100644 --- a/alembic/script/base.py +++ b/alembic/script/base.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from contextlib import contextmanager import datetime import os @@ -21,11 +23,13 @@ from . import revision from . import write_hooks from .. import util from ..runtime import migration +from ..util import not_none if TYPE_CHECKING: from ..config import Config from ..runtime.migration import RevisionStep from ..runtime.migration import StampStep + from ..script.revision import Revision try: from dateutil import tz @@ -112,7 +116,7 @@ class ScriptDirectory: else: return (os.path.abspath(os.path.join(self.dir, "versions")),) - def _load_revisions(self) -> Iterator["Script"]: + def _load_revisions(self) -> Iterator[Script]: if self.version_locations: paths = [ vers @@ -139,7 +143,7 @@ class ScriptDirectory: yield script @classmethod - def from_config(cls, config: "Config") -> "ScriptDirectory": + def from_config(cls, config: Config) -> ScriptDirectory: """Produce a new :class:`.ScriptDirectory` given a :class:`.Config` instance. @@ -152,14 +156,16 @@ class ScriptDirectory: raise util.CommandError( "No 'script_location' key " "found in configuration." 
) - truncate_slug_length = cast( - Optional[int], config.get_main_option("truncate_slug_length") - ) - if truncate_slug_length is not None: - truncate_slug_length = int(truncate_slug_length) + truncate_slug_length: Optional[int] + tsl = config.get_main_option("truncate_slug_length") + if tsl is not None: + truncate_slug_length = int(tsl) + else: + truncate_slug_length = None - version_locations = config.get_main_option("version_locations") - if version_locations: + version_locations_str = config.get_main_option("version_locations") + version_locations: Optional[List[str]] + if version_locations_str: version_path_separator = config.get_main_option( "version_path_separator" ) @@ -173,7 +179,9 @@ class ScriptDirectory: } try: - split_char = split_on_path[version_path_separator] + split_char: Optional[str] = split_on_path[ + version_path_separator + ] except KeyError as ke: raise ValueError( "'%s' is not a valid value for " @@ -183,17 +191,15 @@ class ScriptDirectory: else: if split_char is None: # legacy behaviour for backwards compatibility - vl = _split_on_space_comma.split( - cast(str, version_locations) + version_locations = _split_on_space_comma.split( + version_locations_str ) - version_locations: List[str] = vl # type: ignore[no-redef] else: - vl = [ - x - for x in cast(str, version_locations).split(split_char) - if x + version_locations = [ + x for x in version_locations_str.split(split_char) if x ] - version_locations: List[str] = vl # type: ignore[no-redef] + else: + version_locations = None prepend_sys_path = config.get_main_option("prepend_sys_path") if prepend_sys_path: @@ -209,7 +215,7 @@ class ScriptDirectory: truncate_slug_length=truncate_slug_length, sourceless=config.get_main_option("sourceless") == "true", output_encoding=config.get_main_option("output_encoding", "utf-8"), - version_locations=cast("Optional[List[str]]", version_locations), + version_locations=version_locations, timezone=config.get_main_option("timezone"), hook_config=config.get_section("post_write_hooks", {}), ) @@ -262,7 +268,7 @@ class ScriptDirectory: def walk_revisions( self, base: str = "base", head: str = "heads" - ) -> Iterator["Script"]: + ) -> Iterator[Script]: """Iterate through all revisions. :param base: the base revision, or "base" to start from the @@ -279,25 +285,26 @@ class ScriptDirectory: ): yield cast(Script, rev) - def get_revisions(self, id_: _RevIdType) -> Tuple["Script", ...]: + def get_revisions(self, id_: _RevIdType) -> Tuple[Optional[Script], ...]: """Return the :class:`.Script` instance with the given rev identifier, symbolic name, or sequence of identifiers. """ with self._catch_revision_errors(): return cast( - "Tuple[Script, ...]", self.revision_map.get_revisions(id_) + Tuple[Optional[Script], ...], + self.revision_map.get_revisions(id_), ) - def get_all_current(self, id_: Tuple[str, ...]) -> Set["Script"]: + def get_all_current(self, id_: Tuple[str, ...]) -> Set[Optional[Script]]: with self._catch_revision_errors(): top_revs = cast( - "Set[Script]", + Set[Optional[Script]], set(self.revision_map.get_revisions(id_)), ) top_revs.update( cast( - "Iterator[Script]", + Iterator[Script], self.revision_map._get_ancestor_nodes( list(top_revs), include_dependencies=True ), @@ -306,7 +313,7 @@ class ScriptDirectory: top_revs = self.revision_map._filter_into_branch_heads(top_revs) return top_revs - def get_revision(self, id_: str) -> "Script": + def get_revision(self, id_: str) -> Optional[Script]: """Return the :class:`.Script` instance with the given rev id. .. 
seealso:: @@ -316,7 +323,7 @@ class ScriptDirectory: """ with self._catch_revision_errors(): - return cast(Script, self.revision_map.get_revision(id_)) + return cast(Optional[Script], self.revision_map.get_revision(id_)) def as_revision_number( self, id_: Optional[str] @@ -335,7 +342,12 @@ class ScriptDirectory: else: return rev[0] - def iterate_revisions(self, upper, lower): + def iterate_revisions( + self, + upper: Union[str, Tuple[str, ...], None], + lower: Union[str, Tuple[str, ...], None], + **kw: Any, + ) -> Iterator[Script]: """Iterate through script revisions, starting at the given upper revision identifier and ending at the lower. @@ -351,9 +363,12 @@ class ScriptDirectory: :meth:`.RevisionMap.iterate_revisions` """ - return self.revision_map.iterate_revisions(upper, lower) + return cast( + Iterator[Script], + self.revision_map.iterate_revisions(upper, lower, **kw), + ) - def get_current_head(self): + def get_current_head(self) -> Optional[str]: """Return the current head revision. If the script directory has multiple heads @@ -423,36 +438,36 @@ class ScriptDirectory: def _upgrade_revs( self, destination: str, current_rev: str - ) -> List["RevisionStep"]: + ) -> List[RevisionStep]: with self._catch_revision_errors( ancestor="Destination %(end)s is not a valid upgrade " "target from current head(s)", end=destination, ): - revs = self.revision_map.iterate_revisions( + revs = self.iterate_revisions( destination, current_rev, implicit_base=True ) return [ migration.MigrationStep.upgrade_from_script( - self.revision_map, cast(Script, script) + self.revision_map, script ) for script in reversed(list(revs)) ] def _downgrade_revs( self, destination: str, current_rev: Optional[str] - ) -> List["RevisionStep"]: + ) -> List[RevisionStep]: with self._catch_revision_errors( ancestor="Destination %(end)s is not a valid downgrade " "target from current head(s)", end=destination, ): - revs = self.revision_map.iterate_revisions( + revs = self.iterate_revisions( current_rev, destination, select_for_downgrade=True ) return [ migration.MigrationStep.downgrade_from_script( - self.revision_map, cast(Script, script) + self.revision_map, script ) for script in revs ] @@ -472,12 +487,14 @@ class ScriptDirectory: if not revision: revision = "base" - filtered_heads: List["Script"] = [] + filtered_heads: List[Script] = [] for rev in util.to_tuple(revision): if rev: filtered_heads.extend( self.revision_map.filter_for_lineage( - heads_revs, rev, include_dependencies=True + cast(Sequence[Script], heads_revs), + rev, + include_dependencies=True, ) ) filtered_heads = util.unique_list(filtered_heads) @@ -573,7 +590,7 @@ class ScriptDirectory: src, dest, self.output_encoding, - **kw + **kw, ) def _copy_file(self, src: str, dest: str) -> None: @@ -621,8 +638,8 @@ class ScriptDirectory: branch_labels: Optional[str] = None, version_path: Optional[str] = None, depends_on: Optional[_RevIdType] = None, - **kw: Any - ) -> Optional["Script"]: + **kw: Any, + ) -> Optional[Script]: """Generate a new revision file. This runs the ``script.py.mako`` template, given @@ -656,7 +673,12 @@ class ScriptDirectory: "or perform a merge." 
) ): - heads = self.revision_map.get_revisions(head) + heads = cast( + Tuple[Optional["Revision"], ...], + self.revision_map.get_revisions(head), + ) + for h in heads: + assert h != "base" if len(set(heads)) != len(heads): raise util.CommandError("Duplicate head revisions specified") @@ -702,17 +724,20 @@ class ScriptDirectory: % head_.revision ) + resolved_depends_on: Optional[List[str]] if depends_on: with self._catch_revision_errors(): - depends_on = [ + resolved_depends_on = [ dep if dep in rev.branch_labels # maintain branch labels else rev.revision # resolve partial revision identifiers for rev, dep in [ - (self.revision_map.get_revision(dep), dep) + (not_none(self.revision_map.get_revision(dep)), dep) for dep in util.to_list(depends_on) ] ] + else: + resolved_depends_on = None self._generate_template( os.path.join(self.dir, "script.py.mako"), @@ -722,13 +747,11 @@ class ScriptDirectory: tuple(h.revision if h is not None else None for h in heads) ), branch_labels=util.to_tuple(branch_labels), - depends_on=revision.tuple_rev_as_scalar( - cast("Optional[List[str]]", depends_on) - ), + depends_on=revision.tuple_rev_as_scalar(resolved_depends_on), create_date=create_date, comma=util.format_as_comma, message=message if message is not None else ("empty message"), - **kw + **kw, ) post_write_hooks = self.hook_config @@ -801,13 +824,13 @@ class Script(revision.Revision): ), ) - module: ModuleType = None # type: ignore[assignment] + module: ModuleType """The Python module representing the actual script itself.""" - path: str = None # type: ignore[assignment] + path: str """Filesystem path of the script.""" - _db_current_indicator = None + _db_current_indicator: Optional[bool] = None """Utility variable which when set will cause string output to indicate this is a "current" version in some database""" @@ -939,7 +962,7 @@ class Script(revision.Revision): @classmethod def _from_path( cls, scriptdir: ScriptDirectory, path: str - ) -> Optional["Script"]: + ) -> Optional[Script]: dir_, filename = os.path.split(path) return cls._from_filename(scriptdir, dir_, filename) @@ -969,7 +992,7 @@ class Script(revision.Revision): @classmethod def _from_filename( cls, scriptdir: ScriptDirectory, dir_: str, filename: str - ) -> Optional["Script"]: + ) -> Optional[Script]: if scriptdir.sourceless: py_match = _sourceless_rev_file.match(filename) else: diff --git a/alembic/script/revision.py b/alembic/script/revision.py index 2bfb7f9d..335314f9 100644 --- a/alembic/script/revision.py +++ b/alembic/script/revision.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import collections import re from typing import Any @@ -11,6 +13,7 @@ from typing import Iterable from typing import Iterator from typing import List from typing import Optional +from typing import overload from typing import Sequence from typing import Set from typing import Tuple @@ -21,6 +24,7 @@ from typing import Union from sqlalchemy import util as sqlautil from .. 
import util +from ..util import not_none if TYPE_CHECKING: from typing import Literal @@ -439,7 +443,7 @@ class RevisionMap: "Revision %s referenced from %s is not present" % (downrev, revision) ) - cast("Revision", map_[downrev]).add_nextrev(revision) + not_none(map_[downrev]).add_nextrev(revision) self._normalize_depends_on(revisions, map_) @@ -502,8 +506,8 @@ class RevisionMap: return self.filter_for_lineage(self.bases, identifier) def get_revisions( - self, id_: Union[str, Collection[str], None] - ) -> Tuple["Revision", ...]: + self, id_: Union[str, Collection[Optional[str]], None] + ) -> Tuple[Optional[_RevisionOrBase], ...]: """Return the :class:`.Revision` instances with the given rev id or identifiers. @@ -537,7 +541,8 @@ class RevisionMap: select_heads = tuple( head for head in select_heads - if branch_label in head.branch_labels + if branch_label + in is_revision(head).branch_labels ) return tuple( self._walk(head, steps=rint) @@ -551,7 +556,7 @@ class RevisionMap: for rev_id in resolved_id ) - def get_revision(self, id_: Optional[str]) -> "Revision": + def get_revision(self, id_: Optional[str]) -> Optional[Revision]: """Return the :class:`.Revision` instance with the given rev id. If a symbolic name such as "head" or "base" is given, resolves @@ -568,12 +573,11 @@ class RevisionMap: resolved_id, branch_label = self._resolve_revision_number(id_) if len(resolved_id) > 1: raise MultipleHeads(resolved_id, id_) - elif resolved_id: - resolved_id = resolved_id[0] # type:ignore[assignment] - return self._revision_for_ident(cast(str, resolved_id), branch_label) + resolved: Union[str, Tuple[()]] = resolved_id[0] if resolved_id else () + return self._revision_for_ident(resolved, branch_label) - def _resolve_branch(self, branch_label: str) -> "Revision": + def _resolve_branch(self, branch_label: str) -> Optional[Revision]: try: branch_rev = self._revision_map[branch_label] except KeyError: @@ -587,25 +591,28 @@ class RevisionMap: else: return nonbranch_rev else: - return cast("Revision", branch_rev) + return branch_rev def _revision_for_ident( - self, resolved_id: str, check_branch: Optional[str] = None - ) -> "Revision": - branch_rev: Optional["Revision"] + self, + resolved_id: Union[str, Tuple[()]], + check_branch: Optional[str] = None, + ) -> Optional[Revision]: + branch_rev: Optional[Revision] if check_branch: branch_rev = self._resolve_branch(check_branch) else: branch_rev = None - revision: Union["Revision", "Literal[False]"] + revision: Union[Optional[Revision], "Literal[False]"] try: - revision = cast("Revision", self._revision_map[resolved_id]) + revision = self._revision_map[resolved_id] except KeyError: # break out to avoid misleading py3k stack traces revision = False revs: Sequence[str] if revision is False: + assert resolved_id # do a partial lookup revs = [ x @@ -637,11 +644,11 @@ class RevisionMap: resolved_id, ) else: - revision = cast("Revision", self._revision_map[revs[0]]) + revision = self._revision_map[revs[0]] - revision = cast("Revision", revision) if check_branch and revision is not None: assert branch_rev is not None + assert resolved_id if not self._shares_lineage( revision.revision, branch_rev.revision ): @@ -653,11 +660,12 @@ class RevisionMap: return revision def _filter_into_branch_heads( - self, targets: Set["Script"] - ) -> Set["Script"]: + self, targets: Set[Optional[Script]] + ) -> Set[Optional[Script]]: targets = set(targets) for rev in list(targets): + assert rev if targets.intersection( self._get_descendant_nodes([rev], include_dependencies=False) 
).difference([rev]): @@ -695,9 +703,11 @@ class RevisionMap: if not test_against_revs: return True if not isinstance(target, Revision): - target = self._revision_for_ident(target) + resolved_target = not_none(self._revision_for_ident(target)) + else: + resolved_target = target - test_against_revs = [ + resolved_test_against_revs = [ self._revision_for_ident(test_against_rev) if not isinstance(test_against_rev, Revision) else test_against_rev @@ -709,15 +719,17 @@ class RevisionMap: return bool( set( self._get_descendant_nodes( - [target], include_dependencies=include_dependencies + [resolved_target], + include_dependencies=include_dependencies, ) ) .union( self._get_ancestor_nodes( - [target], include_dependencies=include_dependencies + [resolved_target], + include_dependencies=include_dependencies, ) ) - .intersection(test_against_revs) + .intersection(resolved_test_against_revs) ) def _resolve_revision_number( @@ -768,7 +780,7 @@ class RevisionMap: inclusive: bool = False, assert_relative_length: bool = True, select_for_downgrade: bool = False, - ) -> Iterator["Revision"]: + ) -> Iterator[Revision]: """Iterate through script revisions, starting at the given upper revision identifier and ending at the lower. @@ -795,11 +807,11 @@ class RevisionMap: ) for node in self._topological_sort(revisions, heads): - yield self.get_revision(node) + yield not_none(self.get_revision(node)) def _get_descendant_nodes( self, - targets: Collection["Revision"], + targets: Collection[Revision], map_: Optional[_RevisionMapType] = None, check: bool = False, omit_immediate_dependencies: bool = False, @@ -830,11 +842,11 @@ class RevisionMap: def _get_ancestor_nodes( self, - targets: Collection["Revision"], + targets: Collection[Optional[_RevisionOrBase]], map_: Optional[_RevisionMapType] = None, check: bool = False, include_dependencies: bool = True, - ) -> Iterator["Revision"]: + ) -> Iterator[Revision]: if include_dependencies: @@ -853,17 +865,17 @@ class RevisionMap: def _iterate_related_revisions( self, fn: Callable, - targets: Collection["Revision"], + targets: Collection[Optional[_RevisionOrBase]], map_: Optional[_RevisionMapType], check: bool = False, - ) -> Iterator["Revision"]: + ) -> Iterator[Revision]: if map_ is None: map_ = self._revision_map seen = set() - todo: Deque["Revision"] = collections.deque() - for target in targets: - + todo: Deque[Revision] = collections.deque() + for target_for in targets: + target = is_revision(target_for) todo.append(target) if check: per_target = set() @@ -902,7 +914,7 @@ class RevisionMap: def _topological_sort( self, - revisions: Collection["Revision"], + revisions: Collection[Revision], heads: Any, ) -> List[str]: """Yield revision ids of a collection of Revision objects in @@ -1007,11 +1019,11 @@ class RevisionMap: def _walk( self, - start: Optional[Union[str, "Revision"]], + start: Optional[Union[str, Revision]], steps: int, branch_label: Optional[str] = None, no_overwalk: bool = True, - ) -> "Revision": + ) -> Optional[_RevisionOrBase]: """ Walk the requested number of :steps up (steps > 0) or down (steps < 0) the revision tree. 
@@ -1030,20 +1042,21 @@ class RevisionMap: else: initial = start - children: Sequence[_RevisionOrBase] + children: Sequence[Optional[_RevisionOrBase]] for _ in range(abs(steps)): if steps > 0: + assert initial != "base" # Walk up - children = [ - rev + walk_up = [ + is_revision(rev) for rev in self.get_revisions( - self.bases - if initial is None - else cast("Revision", initial).nextrev + self.bases if initial is None else initial.nextrev ) ] if branch_label: - children = self.filter_for_lineage(children, branch_label) + children = self.filter_for_lineage(walk_up, branch_label) + else: + children = walk_up else: # Walk down if initial == "base": @@ -1055,17 +1068,17 @@ class RevisionMap: else initial.down_revision ) if not children: - children = cast("Tuple[Literal['base']]", ("base",)) + children = ("base",) if not children: # This will return an invalid result if no_overwalk, otherwise # further steps will stay where we are. ret = None if no_overwalk else initial - return ret # type:ignore[return-value] + return ret elif len(children) > 1: raise RevisionError("Ambiguous walk") initial = children[0] - return cast("Revision", initial) + return initial def _parse_downgrade_target( self, @@ -1170,7 +1183,7 @@ class RevisionMap: current_revisions: _RevisionIdentifierType, target: _RevisionIdentifierType, assert_relative_length: bool, - ) -> Tuple["Revision", ...]: + ) -> Tuple[Optional[_RevisionOrBase], ...]: """ Parse upgrade command syntax :target to retrieve the target revision and given the :current_revisons stamp of the database. @@ -1188,26 +1201,27 @@ class RevisionMap: # No relative destination, target is absolute. return self.get_revisions(target) - current_revisions = util.to_tuple(current_revisions) + current_revisions_tup: Union[str, Collection[Optional[str]], None] + current_revisions_tup = util.to_tuple(current_revisions) branch_label, symbol, relative_str = match.groups() relative = int(relative_str) if relative > 0: if symbol is None: - if not current_revisions: - current_revisions = (None,) + if not current_revisions_tup: + current_revisions_tup = (None,) # Try to filter to a single target (avoid ambiguous branches). - start_revs = current_revisions + start_revs = current_revisions_tup if branch_label: start_revs = self.filter_for_lineage( - self.get_revisions(current_revisions), branch_label + self.get_revisions(current_revisions_tup), branch_label ) if not start_revs: # The requested branch is not a head, so we need to # backtrack to find a branchpoint. active_on_branch = self.filter_for_lineage( self._get_ancestor_nodes( - self.get_revisions(current_revisions) + self.get_revisions(current_revisions_tup) ), branch_label, ) @@ -1294,6 +1308,7 @@ class RevisionMap: target_revision = None assert target_revision is None or isinstance(target_revision, Revision) + roots: List[Revision] # Find candidates to drop. if target_revision is None: # Downgrading back to base: find all tree roots. @@ -1307,7 +1322,10 @@ class RevisionMap: roots = [target_revision] else: # Downgrading to fixed target: find all direct children. - roots = list(self.get_revisions(target_revision.nextrev)) + roots = [ + is_revision(rev) + for rev in self.get_revisions(target_revision.nextrev) + ] if branch_label and len(roots) > 1: # Need to filter roots. @@ -1320,11 +1338,12 @@ class RevisionMap: } # Intersection gives the root revisions we are trying to # rollback with the downgrade. 
- roots = list( - self.get_revisions( + roots = [ + is_revision(rev) + for rev in self.get_revisions( {rev.revision for rev in roots}.intersection(ancestors) ) - ) + ] # Ensure we didn't throw everything away when filtering branches. if len(roots) == 0: @@ -1374,7 +1393,7 @@ class RevisionMap: inclusive: bool, implicit_base: bool, assert_relative_length: bool, - ) -> Tuple[Set["Revision"], Tuple[Optional[_RevisionOrBase]]]: + ) -> Tuple[Set[Revision], Tuple[Optional[_RevisionOrBase]]]: """ Compute the set of required revisions specified by :upper, and the current set of active revisions specified by :lower. Find the @@ -1386,11 +1405,14 @@ class RevisionMap: of the current/lower revisions. Dependencies from branches with different bases will not be included. """ - targets: Collection["Revision"] = self._parse_upgrade_target( - current_revisions=lower, - target=upper, - assert_relative_length=assert_relative_length, - ) + targets: Collection[Revision] = [ + is_revision(rev) + for rev in self._parse_upgrade_target( + current_revisions=lower, + target=upper, + assert_relative_length=assert_relative_length, + ) + ] # assert type(targets) is tuple, "targets should be a tuple" @@ -1432,6 +1454,7 @@ class RevisionMap: target=lower, assert_relative_length=assert_relative_length, ) + assert rev if rev == "base": current_revisions = tuple() lower = None @@ -1449,14 +1472,16 @@ class RevisionMap: # Include the lower revision (=current_revisions?) in the iteration if inclusive: - needs.update(self.get_revisions(lower)) + needs.update(is_revision(rev) for rev in self.get_revisions(lower)) # By default, base is implicit as we want all dependencies returned. # Base is also implicit if lower = base # implicit_base=False -> only return direct downstreams of # current_revisions if current_revisions and not implicit_base: lower_descendents = self._get_descendant_nodes( - current_revisions, check=True, include_dependencies=False + [is_revision(rev) for rev in current_revisions], + check=True, + include_dependencies=False, ) needs.intersection_update(lower_descendents) @@ -1545,7 +1570,7 @@ class Revision: args.append("branch_labels=%r" % (self.branch_labels,)) return "%s(%s)" % (self.__class__.__name__, ", ".join(args)) - def add_nextrev(self, revision: "Revision") -> None: + def add_nextrev(self, revision: Revision) -> None: self._all_nextrev = self._all_nextrev.union([revision.revision]) if self.revision in revision._versioned_down_revisions: self.nextrev = self.nextrev.union([revision.revision]) @@ -1630,12 +1655,29 @@ class Revision: return len(self._versioned_down_revisions) > 1 +@overload def tuple_rev_as_scalar( rev: Optional[Sequence[str]], ) -> Optional[Union[str, Sequence[str]]]: + ... + + +@overload +def tuple_rev_as_scalar( + rev: Optional[Sequence[Optional[str]]], +) -> Optional[Union[Optional[str], Sequence[Optional[str]]]]: + ... 
+ + +def tuple_rev_as_scalar(rev): if not rev: return None elif len(rev) == 1: return rev[0] else: return rev + + +def is_revision(rev: Any) -> Revision: + assert isinstance(rev, Revision) + return rev diff --git a/alembic/script/write_hooks.py b/alembic/script/write_hooks.py index 0cc9bb8e..8bc7ac1c 100644 --- a/alembic/script/write_hooks.py +++ b/alembic/script/write_hooks.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import shlex import subprocess import sys @@ -14,7 +16,7 @@ from ..util import compat REVISION_SCRIPT_TOKEN = "REVISION_SCRIPT_FILENAME" -_registry = {} +_registry: dict = {} def register(name: str) -> Callable: diff --git a/alembic/testing/assertions.py b/alembic/testing/assertions.py index e7a12c65..1c24066b 100644 --- a/alembic/testing/assertions.py +++ b/alembic/testing/assertions.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import contextlib import re import sys @@ -69,7 +71,7 @@ def _assert_raises( class _ErrorContainer: - error = None + error: Any = None @contextlib.contextmanager diff --git a/alembic/testing/fixtures.py b/alembic/testing/fixtures.py index 849bc830..26427507 100644 --- a/alembic/testing/fixtures.py +++ b/alembic/testing/fixtures.py @@ -1,4 +1,6 @@ # coding: utf-8 +from __future__ import annotations + import configparser from contextlib import contextmanager import io diff --git a/alembic/testing/suite/_autogen_fixtures.py b/alembic/testing/suite/_autogen_fixtures.py index c47cc10e..f97dd753 100644 --- a/alembic/testing/suite/_autogen_fixtures.py +++ b/alembic/testing/suite/_autogen_fixtures.py @@ -1,5 +1,8 @@ +from __future__ import annotations + from typing import Any from typing import Dict +from typing import Set from sqlalchemy import CHAR from sqlalchemy import CheckConstraint @@ -28,7 +31,7 @@ from ...testing import eq_ from ...testing.env import clear_staging_env from ...testing.env import staging_env -names_in_this_test = set() +names_in_this_test: Set[Any] = set() @event.listens_for(Table, "after_parent_attach") @@ -43,15 +46,15 @@ def _default_include_object(obj, name, type_, reflected, compare_to): return True -_default_object_filters = _default_include_object +_default_object_filters: Any = _default_include_object -_default_name_filters = None +_default_name_filters: Any = None class ModelOne: __requires__ = ("unique_constraint_reflection",) - schema = None + schema: Any = None @classmethod def _get_db_schema(cls): diff --git a/alembic/testing/util.py b/alembic/testing/util.py index 9d24d0fe..a82690d1 100644 --- a/alembic/testing/util.py +++ b/alembic/testing/util.py @@ -4,6 +4,7 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php +from __future__ import annotations import re import types @@ -55,7 +56,7 @@ def flag_combinations(*combinations): for d in combinations ], id_="i" + ("a" * len(keys)), - argnames=",".join(keys) + argnames=",".join(keys), ) diff --git a/alembic/util/__init__.py b/alembic/util/__init__.py index 49bee432..d5fa4d32 100644 --- a/alembic/util/__init__.py +++ b/alembic/util/__init__.py @@ -7,6 +7,7 @@ from .langhelpers import Dispatcher from .langhelpers import immutabledict from .langhelpers import memoized_property from .langhelpers import ModuleClsProxy +from .langhelpers import not_none from .langhelpers import rev_id from .langhelpers import to_list from .langhelpers import to_tuple diff --git a/alembic/util/compat.py b/alembic/util/compat.py index 54420cbc..e6a8f6e0 100644 --- a/alembic/util/compat.py +++ 
b/alembic/util/compat.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import io import os import sys @@ -26,8 +28,8 @@ if py39: from importlib import metadata as importlib_metadata from importlib.metadata import EntryPoint else: - import importlib_resources # type:ignore[no-redef] # noqa - import importlib_metadata # type:ignore[no-redef] # noqa + import importlib_resources # type:ignore # noqa + import importlib_metadata # type:ignore # noqa from importlib_metadata import EntryPoint # type:ignore # noqa diff --git a/alembic/util/editor.py b/alembic/util/editor.py index ba376c07..f1d1557f 100644 --- a/alembic/util/editor.py +++ b/alembic/util/editor.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os from os.path import exists from os.path import join diff --git a/alembic/util/langhelpers.py b/alembic/util/langhelpers.py index fd7ccb8f..b6ceb0cd 100644 --- a/alembic/util/langhelpers.py +++ b/alembic/util/langhelpers.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import collections from collections.abc import Iterable import textwrap @@ -280,3 +282,8 @@ class Dispatcher: else: d._registry.update(self._registry) return d + + +def not_none(value: Optional[_T]) -> _T: + assert value is not None + return value diff --git a/alembic/util/messaging.py b/alembic/util/messaging.py index 66f8cc25..dad222f6 100644 --- a/alembic/util/messaging.py +++ b/alembic/util/messaging.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from collections.abc import Iterable import logging import sys diff --git a/alembic/util/pyfiles.py b/alembic/util/pyfiles.py index b6662b71..75350047 100644 --- a/alembic/util/pyfiles.py +++ b/alembic/util/pyfiles.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import atexit from contextlib import ExitStack import importlib diff --git a/alembic/util/sqla_compat.py b/alembic/util/sqla_compat.py index 787b77c2..21a9f7f2 100644 --- a/alembic/util/sqla_compat.py +++ b/alembic/util/sqla_compat.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import contextlib import re from typing import Iterator @@ -56,6 +58,7 @@ _vers = tuple( sqla_13 = _vers >= (1, 3) sqla_14 = _vers >= (1, 4) sqla_14_26 = _vers >= (1, 4, 26) +sqla_2 = _vers >= (1, 5) if sqla_14: diff --git a/pyproject.toml b/pyproject.toml index a8f43fef..721b0db0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,2 +1,16 @@ [tool.black] line-length = 79 + +[tool.mypy] + +exclude = [ + 'alembic/template', + 'alembic.testing.*', +] + +[[tool.mypy.overrides]] +module = [ + 'mako.*', + 'sqlalchemy.testing.*' +] +ignore_missing_imports = true diff --git a/tests/requirements.py b/tests/requirements.py index 1060e25e..aa88f66d 100644 --- a/tests/requirements.py +++ b/tests/requirements.py @@ -404,4 +404,9 @@ class DefaultRequirements(SuiteRequirements): version = exclusions.only_if( lambda _: compat.py39, "python 3.9 is required" ) - return imports + version + + sqlalchemy = exclusions.only_if( + lambda _: sqla_compat.sqla_2, "sqlalchemy 2 is required" + ) + + return imports + version + sqlalchemy diff --git a/tools/write_pyi.py b/tools/write_pyi.py index 60728a8d..ec928cc1 100644 --- a/tools/write_pyi.py +++ b/tools/write_pyi.py @@ -4,6 +4,7 @@ import re import sys from tempfile import NamedTemporaryFile import textwrap +import typing from mako.pygen import PythonPrinter @@ -15,6 +16,8 @@ if True: # avoid flake/zimports messing with the order from alembic.script.write_hooks import console_scripts from alembic.util.compat import inspect_formatargspec from alembic.util.compat 
import inspect_getfullargspec + from alembic.operations import ops + import sqlalchemy as sa IGNORE_ITEMS = { "op": {"context", "create_module_class_proxy"}, @@ -24,6 +27,17 @@ IGNORE_ITEMS = { "requires_connection", }, } +TRIM_MODULE = [ + "alembic.runtime.migration.", + "alembic.operations.ops.", + "sqlalchemy.engine.base.", + "sqlalchemy.sql.schema.", + "sqlalchemy.sql.selectable.", + "sqlalchemy.sql.elements.", + "sqlalchemy.sql.type_api.", + "sqlalchemy.sql.functions.", + "sqlalchemy.sql.dml.", +] def generate_pyi_for_proxy( @@ -66,14 +80,22 @@ def generate_pyi_for_proxy( printer.writeline("### end imports ###") buf.write("\n\n") + module = sys.modules[cls.__module__] + env = { + **sa.__dict__, + **sa.types.__dict__, + **ops.__dict__, + **module.__dict__, + } + for name in dir(cls): if name.startswith("_") or name in ignore_items: continue - meth = getattr(cls, name) + meth = getattr(cls, name, None) if callable(meth): - _generate_stub_for_meth(cls, name, printer) + _generate_stub_for_meth(cls, name, printer, env) else: - _generate_stub_for_attr(cls, name, printer) + _generate_stub_for_attr(cls, name, printer, env) printer.close() @@ -92,18 +114,29 @@ def generate_pyi_for_proxy( ) -def _generate_stub_for_attr(cls, name, printer): - type_ = cls.__annotations__.get(name, "Any") +def _generate_stub_for_attr(cls, name, printer, env): + try: + annotations = typing.get_type_hints(cls, env) + except NameError as e: + annotations = cls.__annotations__ + type_ = annotations.get(name, "Any") + if isinstance(type_, str) and type_[0] in "'\"": + type_ = type_[1:-1] printer.writeline(f"{name}: {type_}") -def _generate_stub_for_meth(cls, name, printer): +def _generate_stub_for_meth(cls, name, printer, env): fn = getattr(cls, name) while hasattr(fn, "__wrapped__"): fn = fn.__wrapped__ spec = inspect_getfullargspec(fn) + try: + annotations = typing.get_type_hints(fn, env) + spec.annotations.update(annotations) + except NameError as e: + pass name_args = spec[0] assert name_args[0:1] == ["self"] or name_args[0:1] == ["cls"] @@ -119,7 +152,10 @@ def _generate_stub_for_meth(cls, name, printer): else: retval = annotation.__module__ + "." + annotation.__qualname__ else: - retval = repr(annotation) + retval = annotation + + for trim in TRIM_MODULE: + retval = retval.replace(trim, "") retval = re.sub( r'ForwardRef\(([\'"].+?[\'"])\)', lambda m: m.group(1), retval @@ -127,7 +163,11 @@ def _generate_stub_for_meth(cls, name, printer): retval = re.sub("NoneType", "None", retval) return retval - argspec = inspect_formatargspec(*spec, formatannotation=_formatannotation) + argspec = inspect_formatargspec( + *spec, + formatannotation=_formatannotation, + formatreturns=lambda val: f"-> {_formatannotation(val)}", + ) func_text = textwrap.dedent( """\ -- 2.47.2
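The patch leans on a pair of small narrowing helpers, util.langhelpers.not_none() and script.revision.is_revision(), which assert at runtime that a value is present (or is of the expected class) so that mypy can treat the result as a concrete type rather than an Optional or Any. Below is a minimal sketch of that pattern, separate from the patch itself; find_revision() and the id "abc123" are hypothetical stand-ins for RevisionMap lookups.

    from typing import Any, Optional, TypeVar

    _T = TypeVar("_T")


    class Revision:
        def __init__(self, revision: str) -> None:
            self.revision = revision


    def not_none(value: Optional[_T]) -> _T:
        # Same shape as the helper added to alembic/util/langhelpers.py:
        # the assert narrows Optional[_T] to _T for mypy and fails loudly
        # if the "cannot happen" None case ever does happen.
        assert value is not None
        return value


    def is_revision(rev: Any) -> Revision:
        # Same shape as the helper added to alembic/script/revision.py:
        # narrows a union such as Optional[_RevisionOrBase] to Revision.
        assert isinstance(rev, Revision)
        return rev


    def find_revision(rev_id: str) -> Optional[Revision]:
        # Hypothetical stand-in for RevisionMap._revision_for_ident(),
        # which may return None for an unknown identifier.
        return Revision(rev_id) if rev_id else None


    # Without the wrapper, ".revision" is an attribute access on an Optional
    # and mypy reports it; with the wrapper the expression has type Revision.
    head = not_none(find_revision("abc123"))
    print(head.revision)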
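tuple_rev_as_scalar() now carries @overload stubs over a single untyped implementation, so a caller that passes a Sequence[str] gets a result type with no spurious Optional[str] element to narrow away. A reduced illustration of the same technique follows, using a hypothetical function name rather than the Alembic signature.

    from typing import Optional, Sequence, Union, overload


    @overload
    def first_or_all(items: Sequence[str]) -> Union[str, Sequence[str]]:
        ...


    @overload
    def first_or_all(
        items: Sequence[Optional[str]],
    ) -> Union[Optional[str], Sequence[Optional[str]]]:
        ...


    def first_or_all(items):
        # One runtime body serves both declared signatures; mypy matches
        # the first overload whose parameter types accept the argument.
        return items[0] if len(items) == 1 else items


    x = first_or_all(["abc123"])  # typed as Union[str, Sequence[str]]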
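With "from __future__ import annotations" applied across the package, every annotation that tools/write_pyi.py reads back is a string, so the script now resolves annotations through typing.get_type_hints() against a namespace merged from sqlalchemy, ops and the proxied module, keeping the raw __annotations__ when a forward reference cannot be resolved. A standalone sketch of that fallback, with hypothetical class names in place of the proxied Alembic classes:

    from __future__ import annotations

    import typing
    from typing import Any, Dict


    class Widget:
        pass


    class Config:
        # Stored as the string "Widget" because of the __future__ import above.
        target: Widget


    def resolved_annotations(obj: type, env: Dict[str, Any]) -> Dict[str, Any]:
        # Mirrors the try/except added to _generate_stub_for_attr(): prefer
        # fully resolved hints, fall back to the raw annotation strings when
        # a name is not importable from the merged namespace.
        try:
            return typing.get_type_hints(obj, env)
        except NameError:
            return dict(obj.__annotations__)


    env = {**globals()}  # the real script also merges sa.__dict__, ops.__dict__, ...
    print(resolved_annotations(Config, env))  # {'target': <class '__main__.Widget'>}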