-from .api import _render_migration_diffs
-from .api import compare_metadata
-from .api import produce_migrations
-from .api import render_python_code
-from .api import RevisionContext
-from .compare import _produce_net_changes
-from .compare import comparators
-from .render import render_op_text
-from .render import renderers
-from .rewriter import Rewriter
+from .api import _render_migration_diffs as _render_migration_diffs
+from .api import compare_metadata as compare_metadata
+from .api import produce_migrations as produce_migrations
+from .api import render_python_code as render_python_code
+from .api import RevisionContext as RevisionContext
+from .compare import _produce_net_changes as _produce_net_changes
+from .compare import comparators as comparators
+from .render import render_op_text as render_op_text
+from .render import renderers as renderers
+from .rewriter import Rewriter as Rewriter
from sqlalchemy.engine import Inspector
from sqlalchemy.sql.schema import MetaData
from sqlalchemy.sql.schema import SchemaItem
+ from sqlalchemy.sql.schema import Table
from ..config import Config
from ..operations.ops import DowngradeOps
"""
migration_script = produce_migrations(context, metadata)
+ assert migration_script.upgrade_ops is not None
return migration_script.upgrade_ops.as_diffs()
self,
migration_context: MigrationContext,
metadata: Optional[MetaData] = None,
- opts: Optional[dict] = None,
+ opts: Optional[Dict[str, Any]] = None,
autogenerate: bool = True,
) -> None:
if (
run_filters = run_object_filters
@util.memoized_property
- def sorted_tables(self):
+ def sorted_tables(self) -> List[Table]:
"""Return an aggregate of the :attr:`.MetaData.sorted_tables`
collection(s).
return result
@util.memoized_property
- def table_key_to_table(self):
+ def table_key_to_table(self) -> Dict[str, Table]:
"""Return an aggregate of the :attr:`.MetaData.tables` dictionaries.
The :attr:`.MetaData.tables` collection is a dictionary of table key
objects contain the same table key, an exception is raised.
"""
- result = {}
+ result: Dict[str, Table] = {}
for m in util.to_list(self.metadata):
intersect = set(result).intersection(set(m.tables))
if intersect:
+# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
+# mypy: no-warn-return-any, allow-any-generics
+
from __future__ import annotations
import contextlib
# 5. index things by name, for those objects that have names
metadata_names = {
cast(str, c.md_name_to_sql_name(autogen_context)): c
- for c in metadata_unique_constraints_sig.union(
- metadata_indexes_sig # type:ignore[arg-type]
- )
+ for c in metadata_unique_constraints_sig.union(metadata_indexes_sig)
if c.is_named
}
obj.const, obj.name, "foreign_key_constraint", False, compare_to
):
modify_table_ops.ops.append(
- ops.CreateForeignKeyOp.from_constraint(const.const)
+ ops.CreateForeignKeyOp.from_constraint(const.const) # type: ignore[has-type] # noqa: E501
)
log.info(
+# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
+# mypy: no-warn-return-any, allow-any-generics
+
from __future__ import annotations
from io import StringIO
) -> str:
base_type, variant_mapping = sqla_compat._get_variant_mapping(type_)
base = _repr_type(base_type, autogen_context, _skip_variants=True)
- assert base is not None and base is not False
+ assert base is not None and base is not False # type: ignore[comparison-overlap] # noqa:E501
for dialect in sorted(variant_mapping):
typ = variant_mapping[dialect]
base += ".with_variant(%s, %r)" % (
won't fail if the remote table can't be resolved.
"""
- colspec = fk._get_colspec() # type:ignore[attr-defined]
+ colspec = fk._get_colspec()
tokens = colspec.split(".")
tname, colname = tokens[-2:]
% {
"prefix": _sqlalchemy_autogenerate_prefix(autogen_context),
"cols": ", ".join(
- "%r" % _ident(cast("Column", f.parent).name)
- for f in constraint.elements
+ repr(_ident(f.parent.name)) for f in constraint.elements
),
"refcols": ", ".join(
repr(_fk_colspec(f, apply_metadata_schema, namespace_metadata))
# ideally SQLAlchemy would give us more of a first class
# way to detect this.
if (
- constraint._create_rule # type:ignore[attr-defined]
- and hasattr(
- constraint._create_rule, "target" # type:ignore[attr-defined]
- )
+ constraint._create_rule
+ and hasattr(constraint._create_rule, "target")
and isinstance(
- constraint._create_rule.target, # type:ignore[attr-defined]
+ constraint._create_rule.target,
sqltypes.TypeEngine,
)
):
from ..operations.ops import AddColumnOp
from ..operations.ops import AlterColumnOp
from ..operations.ops import CreateTableOp
+ from ..operations.ops import DowngradeOps
from ..operations.ops import MigrateOperation
from ..operations.ops import MigrationScript
from ..operations.ops import ModifyTableOps
from ..operations.ops import OpContainer
- from ..runtime.environment import _GetRevArg
+ from ..operations.ops import UpgradeOps
from ..runtime.migration import MigrationContext
+ from ..script.revision import _GetRevArg
class Rewriter:
Type[CreateTableOp],
Type[ModifyTableOps],
],
- ) -> Callable:
+ ) -> Callable[..., Any]:
"""Register a function as rewriter for a given type.
The function should receive three arguments, which are
revision: _GetRevArg,
directive: MigrationScript,
) -> None:
- upgrade_ops_list = []
+ upgrade_ops_list: List[UpgradeOps] = []
for upgrade_ops in directive.upgrade_ops_list:
ret = self._traverse_for(context, revision, upgrade_ops)
if len(ret) != 1:
"Can only return single object for UpgradeOps traverse"
)
upgrade_ops_list.append(ret[0])
- directive.upgrade_ops = upgrade_ops_list
- downgrade_ops_list = []
+ directive.upgrade_ops = upgrade_ops_list # type: ignore
+
+ downgrade_ops_list: List[DowngradeOps] = []
for downgrade_ops in directive.downgrade_ops_list:
ret = self._traverse_for(context, revision, downgrade_ops)
if len(ret) != 1:
"Can only return single object for DowngradeOps traverse"
)
downgrade_ops_list.append(ret[0])
- directive.downgrade_ops = downgrade_ops_list
+ directive.downgrade_ops = downgrade_ops_list # type: ignore
@_traverse.dispatch_for(ops.OpContainer)
def _traverse_op_container(
+# mypy: allow-untyped-defs, allow-untyped-calls
+
from __future__ import annotations
import os
from .runtime.environment import ProcessRevisionDirectiveFn
-def list_templates(config: Config):
+def list_templates(config: Config) -> None:
"""List available templates.
:param config: a :class:`.Config` object.
from typing import Mapping
from typing import Optional
from typing import overload
+from typing import Sequence
from typing import TextIO
from typing import Union
stdout: TextIO = sys.stdout,
cmd_opts: Optional[Namespace] = None,
config_args: Mapping[str, Any] = util.immutabledict(),
- attributes: Optional[dict] = None,
+ attributes: Optional[Dict[str, Any]] = None,
) -> None:
"""Construct a new :class:`.Config`"""
self.config_file_name = file_
"""
@util.memoized_property
- def attributes(self):
+ def attributes(self) -> Dict[str, Any]:
"""A Python dictionary for storage of additional state.
"""
return {}
- def print_stdout(self, text: str, *arg) -> None:
+ def print_stdout(self, text: str, *arg: Any) -> None:
"""Render a message to standard out.
When :meth:`.Config.print_stdout` is called with additional args
util.write_outstream(self.stdout, output, "\n", **self.messaging_opts)
@util.memoized_property
- def file_config(self):
+ def file_config(self) -> ConfigParser:
"""Return the underlying ``ConfigParser`` object.
Direct access to the .ini file is available here,
) -> Optional[str]:
...
- def get_main_option(self, name, default=None):
+ def get_main_option(
+ self, name: str, default: Optional[str] = None
+ ) -> Optional[str]:
"""Return an option from the 'main' section of the .ini file.
This defaults to being a key from the ``[alembic]``
self._generate_args(prog)
def _generate_args(self, prog: Optional[str]) -> None:
- def add_options(fn, parser, positional, kwargs):
+ def add_options(
+ fn: Any, parser: Any, positional: Any, kwargs: Any
+ ) -> None:
kwargs_opts = {
"template": (
"-t",
)
subparsers = parser.add_subparsers()
- positional_translations = {command.stamp: {"revision": "revisions"}}
+ positional_translations: Dict[Any, Any] = {
+ command.stamp: {"revision": "revisions"}
+ }
for fn in [getattr(command, n) for n in dir(command)]:
if (
else:
util.err(str(e), **config.messaging_opts)
- def main(self, argv=None):
+ def main(self, argv: Optional[Sequence[str]] = None) -> None:
options = self.parser.parse_args(argv)
if not hasattr(options, "cmd"):
# see http://bugs.python.org/issue9253, argparse
self.run_cmd(cfg, options)
-def main(argv=None, prog=None, **kwargs):
+def main(
+ argv: Optional[Sequence[str]] = None,
+ prog: Optional[str] = None,
+ **kwargs: Any,
+) -> None:
"""The console runner function for Alembic."""
CommandLine(prog=prog).main(argv=argv)
MigrationContext,
Column[Any],
Column[Any],
- TypeEngine,
- TypeEngine,
+ TypeEngine[Any],
+ TypeEngine[Any],
],
Optional[bool],
],
"""
def execute(
- sql: Union[Executable, str], execution_options: Optional[dict] = None
+ sql: Union[Executable, str],
+ execution_options: Optional[Dict[str, Any]] = None,
) -> None:
"""Execute the given SQL using the current change context.
"""
-def is_transactional_ddl():
+def is_transactional_ddl() -> bool:
"""Return True if the context is configured to expect a
transactional DDL capable backend.
from . import oracle
from . import postgresql
from . import sqlite
-from .impl import DefaultImpl
+from .impl import DefaultImpl as DefaultImpl
+# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
+# mypy: no-warn-return-any, allow-any-generics
+
from __future__ import annotations
from typing import Any
from sqlalchemy.sql.schema import UniqueConstraint
from typing_extensions import TypeGuard
-from alembic.ddl.base import _fk_spec
from .. import util
from ..util import sqla_compat
ondelete,
deferrable,
initially,
- ) = _fk_spec(const)
+ ) = sqla_compat._fk_spec(const)
self._sig: Tuple[Any, ...] = (
self.source_schema,
+# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
+# mypy: no-warn-return-any, allow-any-generics
+
from __future__ import annotations
import functools
self.comment = comment
-@compiles(RenameTable)
+@compiles(RenameTable) # type: ignore[misc]
def visit_rename_table(
element: RenameTable, compiler: DDLCompiler, **kw
) -> str:
)
-@compiles(AddColumn)
+@compiles(AddColumn) # type: ignore[misc]
def visit_add_column(element: AddColumn, compiler: DDLCompiler, **kw) -> str:
return "%s %s" % (
alter_table(compiler, element.table_name, element.schema),
)
-@compiles(DropColumn)
+@compiles(DropColumn) # type: ignore[misc]
def visit_drop_column(element: DropColumn, compiler: DDLCompiler, **kw) -> str:
return "%s %s" % (
alter_table(compiler, element.table_name, element.schema),
)
-@compiles(ColumnNullable)
+@compiles(ColumnNullable) # type: ignore[misc]
def visit_column_nullable(
element: ColumnNullable, compiler: DDLCompiler, **kw
) -> str:
)
-@compiles(ColumnType)
+@compiles(ColumnType) # type: ignore[misc]
def visit_column_type(element: ColumnType, compiler: DDLCompiler, **kw) -> str:
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
)
-@compiles(ColumnName)
+@compiles(ColumnName) # type: ignore[misc]
def visit_column_name(element: ColumnName, compiler: DDLCompiler, **kw) -> str:
return "%s RENAME %s TO %s" % (
alter_table(compiler, element.table_name, element.schema),
)
-@compiles(ColumnDefault)
+@compiles(ColumnDefault) # type: ignore[misc]
def visit_column_default(
element: ColumnDefault, compiler: DDLCompiler, **kw
) -> str:
)
-@compiles(ComputedColumnDefault)
+@compiles(ComputedColumnDefault) # type: ignore[misc]
def visit_computed_column(
element: ComputedColumnDefault, compiler: DDLCompiler, **kw
):
)
-@compiles(IdentityColumnDefault)
+@compiles(IdentityColumnDefault) # type: ignore[misc]
def visit_identity_column(
element: IdentityColumnDefault, compiler: DDLCompiler, **kw
):
+# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
+# mypy: no-warn-return-any, allow-any-generics
+
from __future__ import annotations
import logging
from . import _autogen
from . import base
-from ._autogen import _constraint_sig
-from ._autogen import ComparisonResult
+from ._autogen import _constraint_sig as _constraint_sig
+from ._autogen import ComparisonResult as ComparisonResult
from .. import util
from ..util import sqla_compat
+# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
+# mypy: no-warn-return-any, allow-any-generics
+
from __future__ import annotations
import re
from typing import Union
from sqlalchemy import types as sqltypes
-from sqlalchemy.ext.compiler import compiles
from sqlalchemy.schema import Column
from sqlalchemy.schema import CreateIndex
from sqlalchemy.sql.base import Executable
from .impl import DefaultImpl
from .. import util
from ..util import sqla_compat
+from ..util.sqla_compat import compiles
if TYPE_CHECKING:
from typing import Literal
+# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
+# mypy: no-warn-return-any, allow-any-generics
+
from __future__ import annotations
import re
from sqlalchemy import schema
from sqlalchemy import types as sqltypes
-from sqlalchemy.ext.compiler import compiles
from .base import alter_table
from .base import AlterColumn
from ..util import sqla_compat
from ..util.sqla_compat import _is_mariadb
from ..util.sqla_compat import _is_type_bound
+from ..util.sqla_compat import compiles
if TYPE_CHECKING:
from typing import Literal
) -> bool:
return (
type_ is not None
- and type_._type_affinity # type:ignore[attr-defined]
- is sqltypes.DateTime
+ and type_._type_affinity is sqltypes.DateTime
and server_default is not None
)
+# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
+# mypy: no-warn-return-any, allow-any-generics
+
from __future__ import annotations
import re
from typing import Optional
from typing import TYPE_CHECKING
-from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql import sqltypes
from .base import AddColumn
from .base import IdentityColumnDefault
from .base import RenameTable
from .impl import DefaultImpl
+from ..util.sqla_compat import compiles
if TYPE_CHECKING:
from sqlalchemy.dialects.oracle.base import OracleDDLCompiler
+# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
+# mypy: no-warn-return-any, allow-any-generics
+
from __future__ import annotations
import logging
from .base import alter_table
from .base import AlterColumn
from .base import ColumnComment
-from .base import compiles
from .base import format_column_name
from .base import format_table_name
from .base import format_type
from ..operations.base import BatchOperations
from ..operations.base import Operations
from ..util import sqla_compat
+from ..util.sqla_compat import compiles
if TYPE_CHECKING:
from typing import Literal
metadata_default = literal_column(metadata_default)
# run a real compare against the server
- return not self.connection.scalar(
+ conn = self.connection
+ assert conn is not None
+ return not conn.scalar(
sqla_compat._select(
literal_column(conn_col_default) == metadata_default
)
return cls(
constraint.name,
constraint_table.name,
- [
- (expr, op)
- for expr, name, op in constraint._render_exprs # type:ignore[attr-defined] # noqa
+ [ # type: ignore
+ (expr, op) for expr, name, op in constraint._render_exprs
],
where=cast("ColumnElement[bool] | None", constraint.where),
schema=constraint_table.schema,
expr,
name,
oper,
- ) in excl._render_exprs: # type:ignore[attr-defined]
+ ) in excl._render_exprs:
t.append_column(Column(name, NULLTYPE))
t.append_constraint(excl)
return excl
constraint_name: str,
*elements: Any,
**kw: Any,
- ):
+ ) -> Optional[Table]:
"""Issue a "create exclude constraint" instruction using the
current batch migration context.
args = [
"(%s, %r)"
% (
- _render_potential_column(sqltext, autogen_context),
+ _render_potential_column(
+ sqltext, # type:ignore[arg-type]
+ autogen_context,
+ ),
opstring,
)
- for sqltext, name, opstring in constraint._render_exprs # type:ignore[attr-defined] # noqa
+ for sqltext, name, opstring in constraint._render_exprs
]
if constraint.where is not None:
args.append(
+# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
+# mypy: no-warn-return-any, allow-any-generics
+
from __future__ import annotations
import re
from sqlalchemy import JSON
from sqlalchemy import schema
from sqlalchemy import sql
-from sqlalchemy.ext.compiler import compiles
from .base import alter_table
from .base import format_table_name
from .base import RenameTable
from .impl import DefaultImpl
from .. import util
+from ..util.sqla_compat import compiles
if TYPE_CHECKING:
from sqlalchemy.engine.reflection import Inspector
def add_constraint(self, const: Constraint):
# attempt to distinguish between an
# auto-gen constraint and an explicit one
- if const._create_rule is None: # type:ignore[attr-defined]
+ if const._create_rule is None:
raise NotImplementedError(
"No support for ALTER of constraints in SQLite dialect. "
"Please refer to the batch mode feature which allows for "
"SQLite migrations using a copy-and-move strategy."
)
- elif const._create_rule(self): # type:ignore[attr-defined]
+ elif const._create_rule(self):
util.warn(
"Skipping unsupported ALTER for "
"creation of implicit constraint. "
)
def drop_constraint(self, const: Constraint):
- if const._create_rule is None: # type:ignore[attr-defined]
+ if const._create_rule is None:
raise NotImplementedError(
"No support for ALTER of constraints in SQLite dialect. "
"Please refer to the batch mode feature which allows for "
new_type: TypeEngine,
) -> None:
if (
- existing.type._type_affinity # type:ignore[attr-defined]
- is not new_type._type_affinity # type:ignore[attr-defined]
+ existing.type._type_affinity is not new_type._type_affinity
and not isinstance(new_type, JSON)
):
existing_transfer["expr"] = cast(
from typing import Literal
from typing import Mapping
from typing import Optional
+from typing import overload
from typing import Sequence
from typing import Tuple
from typing import Type
from sqlalchemy.sql.type_api import TypeEngine
from sqlalchemy.util import immutabledict
- from .operations.ops import BatchOperations
+ from .operations.base import BatchOperations
+ from .operations.ops import AddColumnOp
+ from .operations.ops import AddConstraintOp
+ from .operations.ops import AlterColumnOp
+ from .operations.ops import AlterTableOp
+ from .operations.ops import BulkInsertOp
+ from .operations.ops import CreateIndexOp
+ from .operations.ops import CreateTableCommentOp
+ from .operations.ops import CreateTableOp
+ from .operations.ops import DropColumnOp
+ from .operations.ops import DropConstraintOp
+ from .operations.ops import DropIndexOp
+ from .operations.ops import DropTableCommentOp
+ from .operations.ops import DropTableOp
+ from .operations.ops import ExecuteSQLOp
from .operations.ops import MigrateOperation
from .runtime.migration import MigrationContext
from .util.sqla_compat import _literal_bindparam
_T = TypeVar("_T")
+_C = TypeVar("_C", bound=Callable[..., Any])
+
### end imports ###
def add_column(
comment: Union[str, Literal[False], None] = False,
server_default: Any = False,
new_column_name: Optional[str] = None,
- type_: Union[TypeEngine, Type[TypeEngine], None] = None,
- existing_type: Union[TypeEngine, Type[TypeEngine], None] = None,
+ type_: Union[TypeEngine[Any], Type[TypeEngine[Any]], None] = None,
+ existing_type: Union[TypeEngine[Any], Type[TypeEngine[Any]], None] = None,
existing_server_default: Union[
str, bool, Identity, Computed, None
] = False,
table_name: str,
schema: Optional[str] = None,
recreate: Literal["auto", "always", "never"] = "auto",
- partial_reordering: Optional[tuple] = None,
+ partial_reordering: Optional[Tuple[Any, ...]] = None,
copy_from: Optional[Table] = None,
table_args: Tuple[Any, ...] = (),
table_kwargs: Mapping[str, Any] = immutabledict({}),
def bulk_insert(
table: Union[Table, TableClause],
- rows: List[dict],
+ rows: List[Dict[str, Any]],
*,
multiinsert: bool = True,
) -> None:
"""
-def implementation_for(op_cls: Any) -> Callable[..., Any]:
+def implementation_for(op_cls: Any) -> Callable[[_C], _C]:
"""Register an implementation for a given :class:`.MigrateOperation`.
This is part of the operation extensibility API.
"""
def inline_literal(
- value: Union[str, int], type_: Optional[TypeEngine] = None
+ value: Union[str, int], type_: Optional[TypeEngine[Any]] = None
) -> _literal_bindparam:
r"""Produce an 'inline literal' expression, suitable for
using in an INSERT, UPDATE, or DELETE statement.
"""
+@overload
+def invoke(operation: CreateTableOp) -> Table: ...
+@overload
+def invoke(
+ operation: Union[
+ AddConstraintOp,
+ DropConstraintOp,
+ CreateIndexOp,
+ DropIndexOp,
+ AddColumnOp,
+ AlterColumnOp,
+ AlterTableOp,
+ CreateTableCommentOp,
+ DropTableCommentOp,
+ DropColumnOp,
+ BulkInsertOp,
+ DropTableOp,
+ ExecuteSQLOp,
+ ]
+) -> None: ...
+@overload
def invoke(operation: MigrateOperation) -> Any:
"""Given a :class:`.MigrateOperation`, invoke it in terms of
this :class:`.Operations` instance.
def register_operation(
name: str, sourcename: Optional[str] = None
-) -> Callable[[_T], _T]:
+) -> Callable[[Type[_T]], Type[_T]]:
"""Register a new operation for this class.
This method is normally used to add new operations
+# mypy: allow-untyped-calls
+
from __future__ import annotations
from contextlib import contextmanager
from typing import Iterator
from typing import List # noqa
from typing import Mapping
+from typing import NoReturn
from typing import Optional
+from typing import overload
from typing import Sequence # noqa
from typing import Tuple
from typing import Type # noqa
from sqlalchemy.types import TypeEngine
from .batch import BatchOperationsImpl
+ from .ops import AddColumnOp
+ from .ops import AddConstraintOp
+ from .ops import AlterColumnOp
+ from .ops import AlterTableOp
+ from .ops import BulkInsertOp
+ from .ops import CreateIndexOp
+ from .ops import CreateTableCommentOp
+ from .ops import CreateTableOp
+ from .ops import DropColumnOp
+ from .ops import DropConstraintOp
+ from .ops import DropIndexOp
+ from .ops import DropTableCommentOp
+ from .ops import DropTableOp
+ from .ops import ExecuteSQLOp
from .ops import MigrateOperation
from ..ddl import DefaultImpl
from ..runtime.migration import MigrationContext
__all__ = ("Operations", "BatchOperations")
_T = TypeVar("_T")
+_C = TypeVar("_C", bound=Callable[..., Any])
+
class AbstractOperations(util.ModuleClsProxy):
"""Base class for Operations and BatchOperations.
@classmethod
def register_operation(
cls, name: str, sourcename: Optional[str] = None
- ) -> Callable[[_T], _T]:
+ ) -> Callable[[Type[_T]], Type[_T]]:
"""Register a new operation for this class.
This method is normally used to add new operations
"""
- def register(op_cls):
+ def register(op_cls: Type[_T]) -> Type[_T]:
if sourcename is None:
fn = getattr(op_cls, name)
source_name = fn.__name__
*spec, formatannotation=formatannotation_fwdref
)
num_defaults = len(spec[3]) if spec[3] else 0
+
+ defaulted_vals: Tuple[Any, ...]
+
if num_defaults:
- defaulted_vals = name_args[0 - num_defaults :]
+ defaulted_vals = tuple(name_args[0 - num_defaults :])
else:
defaulted_vals = ()
globals_ = dict(globals())
globals_.update({"op_cls": op_cls})
- lcl = {}
+ lcl: Dict[str, Any] = {}
exec(func_text, globals_, lcl)
setattr(cls, name, lcl[name])
return register
@classmethod
- def implementation_for(cls, op_cls: Any) -> Callable[..., Any]:
+ def implementation_for(cls, op_cls: Any) -> Callable[[_C], _C]:
"""Register an implementation for a given :class:`.MigrateOperation`.
This is part of the operation extensibility API.
"""
- def decorate(fn):
+ def decorate(fn: _C) -> _C:
cls._to_impl.dispatch_for(op_cls)(fn)
return fn
table_name: str,
schema: Optional[str] = None,
recreate: Literal["auto", "always", "never"] = "auto",
- partial_reordering: Optional[tuple] = None,
+ partial_reordering: Optional[Tuple[Any, ...]] = None,
copy_from: Optional[Table] = None,
table_args: Tuple[Any, ...] = (),
table_kwargs: Mapping[str, Any] = util.immutabledict(),
return self.migration_context
+ @overload
+ def invoke(self, operation: CreateTableOp) -> Table:
+ ...
+
+ @overload
+ def invoke(
+ self,
+ operation: Union[
+ AddConstraintOp,
+ DropConstraintOp,
+ CreateIndexOp,
+ DropIndexOp,
+ AddColumnOp,
+ AlterColumnOp,
+ AlterTableOp,
+ CreateTableCommentOp,
+ DropTableCommentOp,
+ DropColumnOp,
+ BulkInsertOp,
+ DropTableOp,
+ ExecuteSQLOp,
+ ],
+ ) -> None:
+ ...
+
+ @overload
+ def invoke(self, operation: MigrateOperation) -> Any:
+ ...
+
def invoke(self, operation: MigrateOperation) -> Any:
"""Given a :class:`.MigrateOperation`, invoke it in terms of
this :class:`.Operations` instance.
comment: Union[str, Literal[False], None] = False,
server_default: Any = False,
new_column_name: Optional[str] = None,
- type_: Union[TypeEngine, Type[TypeEngine], None] = None,
- existing_type: Union[TypeEngine, Type[TypeEngine], None] = None,
+ type_: Union[TypeEngine[Any], Type[TypeEngine[Any]], None] = None,
+ existing_type: Union[
+ TypeEngine[Any], Type[TypeEngine[Any]], None
+ ] = None,
existing_server_default: Union[
str, bool, Identity, Computed, None
] = False,
def bulk_insert(
self,
table: Union[Table, TableClause],
- rows: List[dict],
+ rows: List[Dict[str, Any]],
*,
multiinsert: bool = True,
) -> None:
impl: BatchOperationsImpl
- def _noop(self, operation):
+ def _noop(self, operation: Any) -> NoReturn:
raise NotImplementedError(
"The %s method does not apply to a batch table alter operation."
% operation
comment: Union[str, Literal[False], None] = False,
server_default: Any = False,
new_column_name: Optional[str] = None,
- type_: Union[TypeEngine, Type[TypeEngine], None] = None,
- existing_type: Union[TypeEngine, Type[TypeEngine], None] = None,
+ type_: Union[TypeEngine[Any], Type[TypeEngine[Any]], None] = None,
+ existing_type: Union[
+ TypeEngine[Any], Type[TypeEngine[Any]], None
+ ] = None,
existing_server_default: Union[
str, bool, Identity, Computed, None
] = False,
def create_exclude_constraint(
self, constraint_name: str, *elements: Any, **kw: Any
- ):
+ ) -> Optional[Table]:
"""Issue a "create exclude constraint" instruction using the
current batch migration context.
+# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
+# mypy: no-warn-return-any, allow-any-generics
+
from __future__ import annotations
from typing import Any
from sqlalchemy import schema as sql_schema
from sqlalchemy import Table
from sqlalchemy import types as sqltypes
-from sqlalchemy.events import SchemaEventTarget
+from sqlalchemy.sql.schema import SchemaEventTarget
from sqlalchemy.util import OrderedDict
from sqlalchemy.util import topological
for idx_existing in self.indexes.values():
# this is a lift-and-move from Table.to_metadata
- if idx_existing._column_flag: # type: ignore
+ if idx_existing._column_flag:
continue
idx_copy = Index(
def _setup_referent(
self, metadata: MetaData, constraint: ForeignKeyConstraint
) -> None:
- spec = constraint.elements[
- 0
- ]._get_colspec() # type:ignore[attr-defined]
+ spec = constraint.elements[0]._get_colspec()
parts = spec.split(".")
tname = parts[-2]
if len(parts) == 3:
else:
sql_schema.DefaultClause(
server_default # type: ignore[arg-type]
- )._set_parent( # type:ignore[attr-defined]
- existing
- )
+ )._set_parent(existing)
if autoincrement is not None:
existing.autoincrement = bool(autoincrement)
from typing import Any
from typing import Callable
from typing import cast
+from typing import Dict
from typing import FrozenSet
from typing import Iterator
from typing import List
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
+from typing import TypeVar
from typing import Union
from sqlalchemy.types import NULLTYPE
from ..runtime.migration import MigrationContext
from ..script.revision import _RevIdType
+_T = TypeVar("_T", bound=Any)
+_AC = TypeVar("_AC", bound="AddConstraintOp")
+
class MigrateOperation:
"""base class for migration command and organization objects.
"""
@util.memoized_property
- def info(self):
+ def info(self) -> Dict[Any, Any]:
"""A dictionary that may be used to store arbitrary information
along with this :class:`.MigrateOperation` object.
add_constraint_ops = util.Dispatcher()
@property
- def constraint_type(self):
+ def constraint_type(self) -> str:
raise NotImplementedError()
@classmethod
- def register_add_constraint(cls, type_: str) -> Callable:
- def go(klass):
+ def register_add_constraint(
+ cls, type_: str
+ ) -> Callable[[Type[_AC]], Type[_AC]]:
+ def go(klass: Type[_AC]) -> Type[_AC]:
cls.add_constraint_ops.dispatch_for(type_)(klass.from_constraint)
return klass
@classmethod
def from_constraint(cls, constraint: Constraint) -> AddConstraintOp:
- return cls.add_constraint_ops.dispatch(constraint.__visit_name__)(
+ return cls.add_constraint_ops.dispatch(constraint.__visit_name__)( # type: ignore[no-any-return] # noqa: E501
constraint
)
uq_constraint = cast("UniqueConstraint", constraint)
- kw: dict = {}
+ kw: Dict[str, Any] = {}
if uq_constraint.deferrable:
kw["deferrable"] = uq_constraint.deferrable
if uq_constraint.initially:
@classmethod
def from_constraint(cls, constraint: Constraint) -> CreateForeignKeyOp:
fk_constraint = cast("ForeignKeyConstraint", constraint)
- kw: dict = {}
+ kw: Dict[str, Any] = {}
if fk_constraint.onupdate:
kw["onupdate"] = fk_constraint.onupdate
if fk_constraint.ondelete:
def from_index(cls, index: Index) -> CreateIndexOp:
assert index.table is not None
return cls(
- index.name, # type: ignore[arg-type]
+ index.name,
index.table.name,
index.expressions,
schema=index.table.schema,
return cls(
table.name,
- list(table.c) + list(table.constraints), # type:ignore[arg-type]
+ list(table.c) + list(table.constraints),
schema=table.schema,
_namespace_metadata=_namespace_metadata,
# given a Table() object, this Table will contain full Index()
)
return operations.invoke(op)
- def reverse(self):
+ def reverse(self) -> Union[CreateTableCommentOp, DropTableCommentOp]:
"""Reverses the COMMENT ON operation against a table."""
if self.existing_comment is None:
return DropTableCommentOp(
schema=self.schema,
)
- def to_table(self, migration_context=None):
+ def to_table(
+ self, migration_context: Optional[MigrationContext] = None
+ ) -> Table:
schema_obj = schemaobj.SchemaObjects(migration_context)
return schema_obj.table(
self.table_name, schema=self.schema, comment=self.comment
)
- def to_diff_tuple(self):
+ def to_diff_tuple(self) -> Tuple[Any, ...]:
return ("add_table_comment", self.to_table(), self.existing_comment)
)
return operations.invoke(op)
- def reverse(self):
+ def reverse(self) -> CreateTableCommentOp:
"""Reverses the COMMENT ON operation against a table."""
return CreateTableCommentOp(
self.table_name, self.existing_comment, schema=self.schema
)
- def to_table(self, migration_context=None):
+ def to_table(
+ self, migration_context: Optional[MigrationContext] = None
+ ) -> Table:
schema_obj = schemaobj.SchemaObjects(migration_context)
return schema_obj.table(self.table_name, schema=self.schema)
- def to_diff_tuple(self):
+ def to_diff_tuple(self) -> Tuple[Any, ...]:
return ("remove_table_comment", self.to_table())
comment: Optional[Union[str, Literal[False]]] = False,
server_default: Any = False,
new_column_name: Optional[str] = None,
- type_: Optional[Union[TypeEngine, Type[TypeEngine]]] = None,
- existing_type: Optional[Union[TypeEngine, Type[TypeEngine]]] = None,
+ type_: Optional[Union[TypeEngine[Any], Type[TypeEngine[Any]]]] = None,
+ existing_type: Optional[
+ Union[TypeEngine[Any], Type[TypeEngine[Any]]]
+ ] = None,
existing_server_default: Optional[
Union[str, bool, Identity, Computed]
] = False,
comment: Optional[Union[str, Literal[False]]] = False,
server_default: Any = False,
new_column_name: Optional[str] = None,
- type_: Optional[Union[TypeEngine, Type[TypeEngine]]] = None,
- existing_type: Optional[Union[TypeEngine, Type[TypeEngine]]] = None,
+ type_: Optional[Union[TypeEngine[Any], Type[TypeEngine[Any]]]] = None,
+ existing_type: Optional[
+ Union[TypeEngine[Any], Type[TypeEngine[Any]]]
+ ] = None,
existing_server_default: Optional[
Union[str, bool, Identity, Computed]
] = False,
) -> Tuple[str, Optional[str], str, Column[Any]]:
return ("add_column", self.schema, self.table_name, self.column)
- def to_column(self) -> Column:
+ def to_column(self) -> Column[Any]:
return self.column
@classmethod
- def from_column(cls, col: Column) -> AddColumnOp:
+ def from_column(cls, col: Column[Any]) -> AddColumnOp:
return cls(col.table.name, col, schema=col.table.schema)
@classmethod
def to_column(
self, migration_context: Optional[MigrationContext] = None
- ) -> Column:
+ ) -> Column[Any]:
if self._reverse is not None:
return self._reverse.column
schema_obj = schemaobj.SchemaObjects(migration_context)
def __init__(
self,
table: Union[Table, TableClause],
- rows: List[dict],
+ rows: List[Dict[str, Any]],
*,
multiinsert: bool = True,
) -> None:
cls,
operations: Operations,
table: Union[Table, TableClause],
- rows: List[dict],
+ rows: List[Dict[str, Any]],
*,
multiinsert: bool = True,
) -> None:
self.upgrade_token = upgrade_token
def reverse_into(self, downgrade_ops: DowngradeOps) -> DowngradeOps:
- downgrade_ops.ops[:] = list( # type:ignore[index]
+ downgrade_ops.ops[:] = list(
reversed([op.reverse() for op in self.ops])
)
return downgrade_ops
super().__init__(ops=ops)
self.downgrade_token = downgrade_token
- def reverse(self):
+ def reverse(self) -> UpgradeOps:
return UpgradeOps(
ops=list(reversed([op.reverse() for op in self.ops]))
)
"""
_needs_render: Optional[bool]
+ _upgrade_ops: List[UpgradeOps]
+ _downgrade_ops: List[DowngradeOps]
def __init__(
self,
self.downgrade_ops = downgrade_ops
@property
- def upgrade_ops(self):
+ def upgrade_ops(self) -> Optional[UpgradeOps]:
"""An instance of :class:`.UpgradeOps`.
.. seealso::
return self._upgrade_ops[0]
@upgrade_ops.setter
- def upgrade_ops(self, upgrade_ops):
+ def upgrade_ops(
+ self, upgrade_ops: Union[UpgradeOps, List[UpgradeOps]]
+ ) -> None:
self._upgrade_ops = util.to_list(upgrade_ops)
for elem in self._upgrade_ops:
assert isinstance(elem, UpgradeOps)
@property
- def downgrade_ops(self):
+ def downgrade_ops(self) -> Optional[DowngradeOps]:
"""An instance of :class:`.DowngradeOps`.
.. seealso::
return self._downgrade_ops[0]
@downgrade_ops.setter
- def downgrade_ops(self, downgrade_ops):
+ def downgrade_ops(
+ self, downgrade_ops: Union[DowngradeOps, List[DowngradeOps]]
+ ) -> None:
self._downgrade_ops = util.to_list(downgrade_ops)
for elem in self._downgrade_ops:
assert isinstance(elem, DowngradeOps)
+# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
+# mypy: no-warn-return-any, allow-any-generics
+
from __future__ import annotations
from typing import Any
ForeignKey.
"""
- if isinstance(fk._colspec, str): # type:ignore[attr-defined]
- table_key, cname = fk._colspec.rsplit( # type:ignore[attr-defined]
- ".", 1
- )
+ if isinstance(fk._colspec, str):
+ table_key, cname = fk._colspec.rsplit(".", 1)
sname, tname = self._parse_table_key(table_key)
if table_key not in metadata.tables:
rel_t = sa_schema.Table(tname, metadata, schema=sname)
+# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
+# mypy: no-warn-return-any, allow-any-generics
+
from typing import TYPE_CHECKING
from sqlalchemy import schema as sa_schema
has been configured.
"""
- return self.context_opts.get("as_sql", False)
+ return self.context_opts.get("as_sql", False) # type: ignore[no-any-return] # noqa: E501
- def is_transactional_ddl(self):
+ def is_transactional_ddl(self) -> bool:
"""Return True if the context is configured to expect a
transactional DDL capable backend.
line.
"""
- return self.context_opts.get("tag", None)
+ return self.context_opts.get("tag", None) # type: ignore[no-any-return] # noqa: E501
@overload
def get_x_argument(self, as_dictionary: Literal[False]) -> List[str]:
def execute(
self,
sql: Union[Executable, str],
- execution_options: Optional[dict] = None,
+ execution_options: Optional[Dict[str, Any]] = None,
) -> None:
"""Execute the given SQL using the current change context.
+# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
+# mypy: no-warn-return-any, allow-any-generics
+
from __future__ import annotations
from contextlib import contextmanager
start_from_rev = None
elif start_from_rev is not None and self.script:
start_from_rev = [
- cast("Script", self.script.get_revision(sfr)).revision
+ self.script.get_revision(sfr).revision
for sfr in util.to_list(start_from_rev)
if sfr not in (None, "base")
]
def execute(
self,
sql: Union[Executable, str],
- execution_options: Optional[dict] = None,
+ execution_options: Optional[Dict[str, Any]] = None,
) -> None:
"""Execute a SQL construct or string statement.
is_upgrade: bool
migration_fn: Any
+ if TYPE_CHECKING:
+
+ @property
+ def doc(self) -> Optional[str]:
+ ...
+
@property
def name(self) -> str:
return self.migration_fn.__name__
self.revision = revision
self.is_upgrade = is_upgrade
if is_upgrade:
- self.migration_fn = (
- revision.module.upgrade # type:ignore[attr-defined]
- )
+ self.migration_fn = revision.module.upgrade
else:
- self.migration_fn = (
- revision.module.downgrade # type:ignore[attr-defined]
- )
+ self.migration_fn = revision.module.downgrade
def __repr__(self):
return "RevisionStep(%r, is_upgrade=%r)" % (
)
@property
- def doc(self) -> str:
+ def doc(self) -> Optional[str]:
return self.revision.doc
@property
def __eq__(self, other):
return (
isinstance(other, StampStep)
- and other.from_revisions == self.revisions
+ and other.from_revisions == self.from_revisions
and other.to_revisions == self.to_revisions
and other.branch_move == self.branch_move
and self.is_upgrade == other.is_upgrade
from zoneinfo import ZoneInfoNotFoundError
else:
from backports.zoneinfo import ZoneInfo # type: ignore[import-not-found,no-redef] # noqa: E501
- from backports.zoneinfo import ZoneInfoNotFoundError # type: ignore[import-not-found,no-redef] # noqa: E501
+ from backports.zoneinfo import ZoneInfoNotFoundError # type: ignore[no-redef] # noqa: E501
except ImportError:
ZoneInfo = None # type: ignore[assignment, misc]
return loc[0]
@util.memoized_property
- def _version_locations(self):
+ def _version_locations(self) -> Sequence[str]:
if self.version_locations:
return [
os.path.abspath(util.coerce_resource_to_filename(location))
):
yield cast(Script, rev)
- def get_revisions(self, id_: _GetRevArg) -> Tuple[Optional[Script], ...]:
+ def get_revisions(self, id_: _GetRevArg) -> Tuple[Script, ...]:
"""Return the :class:`.Script` instance with the given rev identifier,
symbolic name, or sequence of identifiers.
"""
with self._catch_revision_errors():
return cast(
- Tuple[Optional[Script], ...],
+ Tuple[Script, ...],
self.revision_map.get_revisions(id_),
)
- def get_all_current(self, id_: Tuple[str, ...]) -> Set[Optional[Script]]:
+ def get_all_current(self, id_: Tuple[str, ...]) -> Set[Script]:
with self._catch_revision_errors():
- return cast(
- Set[Optional[Script]], self.revision_map._get_all_current(id_)
- )
+ return cast(Set[Script], self.revision_map._get_all_current(id_))
- def get_revision(self, id_: str) -> Optional[Script]:
+ def get_revision(self, id_: str) -> Script:
"""Return the :class:`.Script` instance with the given rev id.
.. seealso::
"""
with self._catch_revision_errors():
- return cast(Optional[Script], self.revision_map.get_revision(id_))
+ return cast(Script, self.revision_map.get_revision(id_))
def as_revision_number(
self, id_: Optional[str]
util.load_python_file(self.dir, "env.py")
@property
- def env_py_location(self):
+ def env_py_location(self) -> str:
return os.path.abspath(os.path.join(self.dir, "env.py"))
def _generate_template(self, src: str, dest: str, **kw: Any) -> None:
self.revision_map.get_revisions(head),
)
for h in heads:
- assert h != "base"
+ assert h != "base" # type: ignore[comparison-overlap]
if len(set(heads)) != len(heads):
raise util.CommandError("Duplicate head revisions specified")
self.path = path
super().__init__(
rev_id,
- module.down_revision, # type: ignore[attr-defined]
+ module.down_revision,
branch_labels=util.to_tuple(
getattr(module, "branch_labels", None), default=()
),
if doc:
if hasattr(self.module, "_alembic_source_encoding"):
doc = doc.decode( # type: ignore[attr-defined]
- self.module._alembic_source_encoding # type: ignore[attr-defined] # noqa
+ self.module._alembic_source_encoding
)
return doc.strip() # type: ignore[union-attr]
else:
)
return entry
- def __str__(self):
+ def __str__(self) -> str:
return "%s -> %s%s%s%s, %s" % (
self._format_down_revision(),
self.revision,
from typing import List
from typing import Optional
from typing import overload
+from typing import Protocol
from typing import Sequence
from typing import Set
from typing import Tuple
_revision_illegal_chars = ["@", "-", "+"]
+class _CollectRevisionsProtocol(Protocol):
+ def __call__(
+ self,
+ upper: _RevisionIdentifierType,
+ lower: _RevisionIdentifierType,
+ inclusive: bool,
+ implicit_base: bool,
+ assert_relative_length: bool,
+ ) -> Tuple[Set[Revision], Tuple[Optional[_RevisionOrBase], ...]]:
+ ...
+
+
class RevisionError(Exception):
pass
for rev in self._get_ancestor_nodes(
[revision],
include_dependencies=False,
- map_=cast(_RevisionMapType, map_),
+ map_=map_,
):
if rev is revision:
continue
The iterator yields :class:`.Revision` objects.
"""
- fn: Callable
+ fn: _CollectRevisionsProtocol
if select_for_downgrade:
fn = self._collect_downgrade_revisions
else:
) -> Iterator[Any]:
if omit_immediate_dependencies:
- def fn(rev):
+ def fn(rev: Revision) -> Iterable[str]:
if rev not in targets:
return rev._all_nextrev
else:
elif include_dependencies:
- def fn(rev):
+ def fn(rev: Revision) -> Iterable[str]:
return rev._all_nextrev
else:
- def fn(rev):
+ def fn(rev: Revision) -> Iterable[str]:
return rev.nextrev
return self._iterate_related_revisions(
) -> Iterator[Revision]:
if include_dependencies:
- def fn(rev):
+ def fn(rev: Revision) -> Iterable[str]:
return rev._normalized_down_revisions
else:
- def fn(rev):
+ def fn(rev: Revision) -> Iterable[str]:
return rev._versioned_down_revisions
return self._iterate_related_revisions(
def _iterate_related_revisions(
self,
- fn: Callable,
+ fn: Callable[[Revision], Iterable[str]],
targets: Collection[Optional[_RevisionOrBase]],
map_: Optional[_RevisionMapType],
check: bool = False,
id_to_rev = self._revision_map
- def get_ancestors(rev_id):
+ def get_ancestors(rev_id: str) -> Set[str]:
return {
r.revision
for r in self._get_ancestor_nodes([id_to_rev[rev_id]])
children: Sequence[Optional[_RevisionOrBase]]
for _ in range(abs(steps)):
if steps > 0:
- assert initial != "base"
+ assert initial != "base" # type: ignore[comparison-overlap]
# Walk up
walk_up = [
is_revision(rev)
children = walk_up
else:
# Walk down
- if initial == "base":
+ if initial == "base": # type: ignore[comparison-overlap]
children = ()
else:
children = self.get_revisions(
# No relative destination given, revision specified is absolute.
branch_label, _, symbol = target.rpartition("@")
if not branch_label:
- branch_label = None # type:ignore[assignment]
+ branch_label = None
return branch_label, self.get_revision(symbol)
def _parse_upgrade_target(
def _collect_downgrade_revisions(
self,
upper: _RevisionIdentifierType,
- target: _RevisionIdentifierType,
+ lower: _RevisionIdentifierType,
inclusive: bool,
implicit_base: bool,
assert_relative_length: bool,
- ) -> Any:
+ ) -> Tuple[Set[Revision], Tuple[Optional[_RevisionOrBase], ...]]:
"""
Compute the set of current revisions specified by :upper, and the
downgrade target specified by :target. Return all dependents of target
branch_label, target_revision = self._parse_downgrade_target(
current_revisions=upper,
- target=target,
+ target=lower,
assert_relative_length=assert_relative_length,
)
if target_revision == "base":
inclusive: bool,
implicit_base: bool,
assert_relative_length: bool,
- ) -> Tuple[Set[Revision], Tuple[Optional[_RevisionOrBase]]]:
+ ) -> Tuple[Set[Revision], Tuple[Revision, ...]]:
"""
Compute the set of required revisions specified by :upper, and the
current set of active revisions specified by :lower. Find the
)
needs.intersection_update(lower_descendents)
- return needs, tuple(targets) # type:ignore[return-value]
+ return needs, tuple(targets)
def _get_all_current(
self, id_: Tuple[str, ...]
+# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
+# mypy: no-warn-return-any, allow-any-generics
+
from __future__ import annotations
import shlex
-from .editor import open_in_editor
-from .exc import AutogenerateDiffsDetected
-from .exc import CommandError
-from .langhelpers import _with_legacy_names
-from .langhelpers import asbool
-from .langhelpers import dedupe_tuple
-from .langhelpers import Dispatcher
-from .langhelpers import EMPTY_DICT
-from .langhelpers import immutabledict
-from .langhelpers import memoized_property
-from .langhelpers import ModuleClsProxy
-from .langhelpers import not_none
-from .langhelpers import rev_id
-from .langhelpers import to_list
-from .langhelpers import to_tuple
-from .langhelpers import unique_list
-from .messaging import err
-from .messaging import format_as_comma
-from .messaging import msg
-from .messaging import obfuscate_url_pw
-from .messaging import status
-from .messaging import warn
-from .messaging import write_outstream
-from .pyfiles import coerce_resource_to_filename
-from .pyfiles import load_python_file
-from .pyfiles import pyc_file_from_path
-from .pyfiles import template_to_file
-from .sqla_compat import has_computed
-from .sqla_compat import sqla_13
-from .sqla_compat import sqla_14
-from .sqla_compat import sqla_2
+from .editor import open_in_editor as open_in_editor
+from .exc import AutogenerateDiffsDetected as AutogenerateDiffsDetected
+from .exc import CommandError as CommandError
+from .langhelpers import _with_legacy_names as _with_legacy_names
+from .langhelpers import asbool as asbool
+from .langhelpers import dedupe_tuple as dedupe_tuple
+from .langhelpers import Dispatcher as Dispatcher
+from .langhelpers import EMPTY_DICT as EMPTY_DICT
+from .langhelpers import immutabledict as immutabledict
+from .langhelpers import memoized_property as memoized_property
+from .langhelpers import ModuleClsProxy as ModuleClsProxy
+from .langhelpers import not_none as not_none
+from .langhelpers import rev_id as rev_id
+from .langhelpers import to_list as to_list
+from .langhelpers import to_tuple as to_tuple
+from .langhelpers import unique_list as unique_list
+from .messaging import err as err
+from .messaging import format_as_comma as format_as_comma
+from .messaging import msg as msg
+from .messaging import obfuscate_url_pw as obfuscate_url_pw
+from .messaging import status as status
+from .messaging import warn as warn
+from .messaging import write_outstream as write_outstream
+from .pyfiles import coerce_resource_to_filename as coerce_resource_to_filename
+from .pyfiles import load_python_file as load_python_file
+from .pyfiles import pyc_file_from_path as pyc_file_from_path
+from .pyfiles import template_to_file as template_to_file
+from .sqla_compat import has_computed as has_computed
+from .sqla_compat import sqla_13 as sqla_13
+from .sqla_compat import sqla_14 as sqla_14
+from .sqla_compat import sqla_2 as sqla_2
if not sqla_13:
+# mypy: no-warn-unused-ignores
+
from __future__ import annotations
from configparser import ConfigParser
import os
import sys
import typing
+from typing import Any
+from typing import List
+from typing import Optional
from typing import Sequence
from typing import Union
-from sqlalchemy.util import inspect_getfullargspec # noqa
-from sqlalchemy.util.compat import inspect_formatargspec # noqa
+if True:
+ # zimports hack for too-long names
+ from sqlalchemy.util import ( # noqa: F401
+ inspect_getfullargspec as inspect_getfullargspec,
+ )
+ from sqlalchemy.util.compat import ( # noqa: F401
+ inspect_formatargspec as inspect_formatargspec,
+ )
is_posix = os.name == "posix"
if py39:
- from importlib import resources as importlib_resources
- from importlib import metadata as importlib_metadata
- from importlib.metadata import EntryPoint
+ from importlib import resources as _resources
+
+ importlib_resources = _resources
+ from importlib import metadata as _metadata
+
+ importlib_metadata = _metadata
+ from importlib.metadata import EntryPoint as EntryPoint
else:
import importlib_resources # type:ignore # noqa
import importlib_metadata # type:ignore # noqa
def importlib_metadata_get(group: str) -> Sequence[EntryPoint]:
ep = importlib_metadata.entry_points()
if hasattr(ep, "select"):
- return ep.select(group=group) # type: ignore
+ return ep.select(group=group)
else:
return ep.get(group, ()) # type: ignore
-def formatannotation_fwdref(annotation, base_module=None):
+def formatannotation_fwdref(
+ annotation: Any, base_module: Optional[Any] = None
+) -> str:
"""vendored from python 3.7"""
# copied over _formatannotation from sqlalchemy 2.0
def read_config_parser(
file_config: ConfigParser,
file_argument: Sequence[Union[str, os.PathLike[str]]],
-) -> list[str]:
+) -> List[str]:
if py310:
return file_config.read(file_argument, encoding="locale")
else:
import textwrap
from typing import Any
from typing import Callable
+from typing import cast
from typing import Dict
from typing import List
from typing import Mapping
+from typing import MutableMapping
+from typing import NoReturn
from typing import Optional
from typing import overload
from typing import Sequence
+from typing import Set
from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
import uuid
import warnings
-from sqlalchemy.util import asbool # noqa
-from sqlalchemy.util import immutabledict # noqa
-from sqlalchemy.util import memoized_property # noqa
-from sqlalchemy.util import to_list # noqa
-from sqlalchemy.util import unique_list # noqa
+from sqlalchemy.util import asbool as asbool # noqa: F401
+from sqlalchemy.util import immutabledict as immutabledict # noqa: F401
+from sqlalchemy.util import to_list as to_list # noqa: F401
+from sqlalchemy.util import unique_list as unique_list
from .compat import inspect_getfullargspec
+if True:
+ # zimports workaround :(
+ from sqlalchemy.util import ( # noqa: F401
+ memoized_property as memoized_property,
+ )
+
EMPTY_DICT: Mapping[Any, Any] = immutabledict()
-_T = TypeVar("_T")
+_T = TypeVar("_T", bound=Any)
+
+_C = TypeVar("_C", bound=Callable[..., Any])
class _ModuleClsMeta(type):
- def __setattr__(cls, key: str, value: Callable) -> None:
+ def __setattr__(cls, key: str, value: Callable[..., Any]) -> None:
super().__setattr__(key, value)
cls._update_module_proxies(key) # type: ignore
"""
- _setups: Dict[type, Tuple[set, list]] = collections.defaultdict(
- lambda: (set(), [])
- )
+ _setups: Dict[
+ Type[Any],
+ Tuple[
+ Set[str],
+ List[Tuple[MutableMapping[str, Any], MutableMapping[str, Any]]],
+ ],
+ ] = collections.defaultdict(lambda: (set(), []))
@classmethod
def _update_module_proxies(cls, name: str) -> None:
del globals_[attr_name]
@classmethod
- def create_module_class_proxy(cls, globals_, locals_):
+ def create_module_class_proxy(
+ cls,
+ globals_: MutableMapping[str, Any],
+ locals_: MutableMapping[str, Any],
+ ) -> None:
attr_names, modules = cls._setups[cls]
modules.append((globals_, locals_))
cls._setup_proxy(globals_, locals_, attr_names)
@classmethod
- def _setup_proxy(cls, globals_, locals_, attr_names):
+ def _setup_proxy(
+ cls,
+ globals_: MutableMapping[str, Any],
+ locals_: MutableMapping[str, Any],
+ attr_names: Set[str],
+ ) -> None:
for methname in dir(cls):
cls._add_proxied_attribute(methname, globals_, locals_, attr_names)
@classmethod
- def _add_proxied_attribute(cls, methname, globals_, locals_, attr_names):
+ def _add_proxied_attribute(
+ cls,
+ methname: str,
+ globals_: MutableMapping[str, Any],
+ locals_: MutableMapping[str, Any],
+ attr_names: Set[str],
+ ) -> None:
if not methname.startswith("_"):
meth = getattr(cls, methname)
if callable(meth):
attr_names.add(methname)
@classmethod
- def _create_method_proxy(cls, name, globals_, locals_):
+ def _create_method_proxy(
+ cls,
+ name: str,
+ globals_: MutableMapping[str, Any],
+ locals_: MutableMapping[str, Any],
+ ) -> Callable[..., Any]:
fn = getattr(cls, name)
- def _name_error(name, from_):
+ def _name_error(name: str, from_: Exception) -> NoReturn:
raise NameError(
"Can't invoke function '%s', as the proxy object has "
"not yet been "
translations,
)
- def translate(fn_name, spec, translations, args, kw):
+ def translate(
+ fn_name: str, spec: Any, translations: Any, args: Any, kw: Any
+ ) -> Any:
return_kw = {}
return_args = []
"doc": fn.__doc__,
}
)
- lcl = {}
+ lcl: MutableMapping[str, Any] = {}
- exec(func_text, globals_, lcl)
- return lcl[name]
+ exec(func_text, cast("Dict[str, Any]", globals_), lcl)
+ return cast("Callable[..., Any]", lcl[name])
-def _with_legacy_names(translations):
- def decorate(fn):
- fn._legacy_translations = translations
+def _with_legacy_names(translations: Any) -> Any:
+ def decorate(fn: _C) -> _C:
+ fn._legacy_translations = translations # type: ignore[attr-defined]
return fn
return decorate
@overload
-def to_tuple(x: Any, default: tuple) -> tuple:
+def to_tuple(x: Any, default: Tuple[Any, ...]) -> Tuple[Any, ...]:
...
@overload
-def to_tuple(x: None, default: Optional[_T] = None) -> _T:
+def to_tuple(x: None, default: Optional[_T] = ...) -> _T:
...
@overload
-def to_tuple(x: Any, default: Optional[tuple] = None) -> tuple:
+def to_tuple(
+ x: Any, default: Optional[Tuple[Any, ...]] = None
+) -> Tuple[Any, ...]:
...
-def to_tuple(x, default=None):
+def to_tuple(
+ x: Any, default: Optional[Tuple[Any, ...]] = None
+) -> Optional[Tuple[Any, ...]]:
if x is None:
return default
elif isinstance(x, str):
class Dispatcher:
def __init__(self, uselist: bool = False) -> None:
- self._registry: Dict[tuple, Any] = {}
+ self._registry: Dict[Tuple[Any, ...], Any] = {}
self.uselist = uselist
def dispatch_for(
self, target: Any, qualifier: str = "default"
- ) -> Callable:
- def decorate(fn):
+ ) -> Callable[[_C], _C]:
+ def decorate(fn: _C) -> _C:
if self.uselist:
self._registry.setdefault((target, qualifier), []).append(fn)
else:
def dispatch(self, obj: Any, qualifier: str = "default") -> Any:
if isinstance(obj, str):
- targets: Sequence = [obj]
+ targets: Sequence[Any] = [obj]
elif isinstance(obj, type):
targets = obj.__mro__
else:
raise ValueError("no dispatch function for object: %s" % obj)
def _fn_or_list(
- self, fn_or_list: Union[List[Callable], Callable]
- ) -> Callable:
+ self, fn_or_list: Union[List[Callable[..., Any]], Callable[..., Any]]
+ ) -> Callable[..., Any]:
if self.uselist:
- def go(*arg, **kw):
+ def go(*arg: Any, **kw: Any) -> None:
+ if TYPE_CHECKING:
+ assert isinstance(fn_or_list, Sequence)
for fn in fn_or_list:
fn(*arg, **kw)
import logging
import sys
import textwrap
+from typing import Iterator
from typing import Optional
from typing import TextIO
from typing import Union
@contextmanager
-def status(status_msg: str, newline: bool = False, quiet: bool = False):
+def status(
+ status_msg: str, newline: bool = False, quiet: bool = False
+) -> Iterator[None]:
msg(status_msg + " ...", newline, flush=True, quiet=quiet)
try:
yield
write_outstream(sys.stdout, " done\n")
-def err(message: str, quiet: bool = False):
+def err(message: str, quiet: bool = False) -> None:
log.error(message)
msg(f"FAILED: {message}", quiet=quiet)
sys.exit(-1)
def obfuscate_url_pw(input_url: str) -> str:
u = url.make_url(input_url)
- return sqla_compat.url_render_as_string(u, hide_password=True)
+ return sqla_compat.url_render_as_string(u, hide_password=True) # type: ignore # noqa: E501
def warn(msg: str, stacklevel: int = 2) -> None:
import os
import re
import tempfile
+from types import ModuleType
+from typing import Any
from typing import Optional
from mako import exceptions
def template_to_file(
- template_file: str, dest: str, output_encoding: str, **kw
+ template_file: str, dest: str, output_encoding: str, **kw: Any
) -> None:
template = Template(filename=template_file)
try:
return None
-def load_python_file(dir_: str, filename: str):
+def load_python_file(dir_: str, filename: str) -> ModuleType:
"""Load a file from the given path as a Python module."""
module_id = re.sub(r"\W", "_", filename)
module = load_module_py(module_id, pyc_path)
elif ext in (".pyc", ".pyo"):
module = load_module_py(module_id, path)
+ else:
+ assert False
return module
-def load_module_py(module_id: str, path: str):
+def load_module_py(module_id: str, path: str) -> ModuleType:
spec = importlib.util.spec_from_file_location(module_id, path)
assert spec
module = importlib.util.module_from_spec(spec)
+# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
+# mypy: no-warn-return-any, allow-any-generics
+
from __future__ import annotations
import contextlib
import re
from typing import Any
+from typing import Callable
from typing import Dict
from typing import Iterable
from typing import Iterator
from typing import Mapping
from typing import Optional
+from typing import Protocol
+from typing import Set
+from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
from sqlalchemy import sql
from sqlalchemy import types as sqltypes
from sqlalchemy.engine import url
-from sqlalchemy.ext.compiler import compiles
from sqlalchemy.schema import CheckConstraint
from sqlalchemy.schema import Column
from sqlalchemy.schema import ForeignKeyConstraint
from typing_extensions import TypeGuard
if TYPE_CHECKING:
+ from sqlalchemy import ClauseElement
from sqlalchemy import Index
from sqlalchemy import Table
from sqlalchemy.engine import Connection
_CE = TypeVar("_CE", bound=Union["ColumnElement[Any]", "SchemaItem"])
+class _CompilerProtocol(Protocol):
+ def __call__(self, element: Any, compiler: Any, **kw: Any) -> str:
+ ...
+
+
def _safe_int(value: str) -> Union[int, str]:
try:
return int(value)
sqlalchemy_version = __version__
try:
- from sqlalchemy.sql.naming import _NONE_NAME as _NONE_NAME
+ from sqlalchemy.sql.naming import _NONE_NAME as _NONE_NAME # type: ignore[attr-defined] # noqa: E501
except ImportError:
from sqlalchemy.sql.elements import _NONE_NAME as _NONE_NAME # type: ignore # noqa: E501
"Placeholder for unsupported SQLAlchemy classes"
+if TYPE_CHECKING:
+
+ def compiles(
+ element: Type[ClauseElement], *dialects: str
+ ) -> Callable[[_CompilerProtocol], _CompilerProtocol]:
+ ...
+
+else:
+ from sqlalchemy.ext.compiler import compiles
+
try:
- from sqlalchemy import Computed
+ from sqlalchemy import Computed as Computed
except ImportError:
if not TYPE_CHECKING:
has_computed_reflection = _vers >= (1, 3, 16)
try:
- from sqlalchemy import Identity
+ from sqlalchemy import Identity as Identity
except ImportError:
if not TYPE_CHECKING:
def _copy(schema_item: _CE, **kw) -> _CE:
if hasattr(schema_item, "_copy"):
- return schema_item._copy(**kw) # type: ignore[union-attr]
+ return schema_item._copy(**kw)
else:
return schema_item.copy(**kw) # type: ignore[union-attr]
return type_.impl, type_.mapping
-def _fk_spec(constraint):
+def _fk_spec(constraint: ForeignKeyConstraint) -> Any:
+ if TYPE_CHECKING:
+ assert constraint.columns is not None
+ assert constraint.elements is not None
+ assert isinstance(constraint.parent, Table)
+
source_columns = [
constraint.columns[key].name for key in constraint.column_keys
]
def _fk_is_self_referential(constraint: ForeignKeyConstraint) -> bool:
- spec = constraint.elements[0]._get_colspec() # type: ignore[attr-defined]
+ spec = constraint.elements[0]._get_colspec()
tokens = spec.split(".")
tokens.pop(-1) # colname
tablekey = ".".join(tokens)
# this deals with SQLAlchemy #3260, don't copy CHECK constraints
# that will be generated by the type.
# new feature added for #3260
- return constraint._type_bound # type: ignore[attr-defined]
+ return constraint._type_bound
def _find_columns(clause):
"""locate Column objects within the given expression."""
- cols = set()
+ cols: Set[ColumnElement[Any]] = set()
traverse(clause, {}, {"column": cols.add})
return cols
if isinstance(constraint, schema.Index):
# name should not be quoted.
d = dialect.ddl_compiler(dialect, None) # type: ignore[arg-type]
- return d._prepared_index_name( # type: ignore[attr-defined]
- constraint
- )
+ return d._prepared_index_name(constraint)
else:
# name should not be quoted.
return dialect.identifier_preparer.format_constraint(constraint)
if sqla_14:
from sqlalchemy import create_mock_engine
- from sqlalchemy import select as _select
+
+ # weird mypy workaround
+ from sqlalchemy import select as _sa_select
+
+ _select = _sa_select
else:
from sqlalchemy import create_engine
"postgresql://", strategy="mock", executor=executor
)
- def _select(*columns, **kw) -> Select: # type: ignore[no-redef]
+ def _select(*columns, **kw) -> Select:
return sql.select(list(columns), **kw) # type: ignore[call-overload]
--- /dev/null
+.. change::
+ :tags: bug, typing
+ :tickets: 1377
+
+    Updated pep-484 typing to pass mypy "strict" mode, while retaining
+    per-module qualifications for specific typing elements that are not
+    yet complete. This allows us to catch specific typing issues that
+    have been ongoing, such as import symbols not being properly exported.
+
show_error_codes = true
[[tool.mypy.overrides]]
+
module = [
- 'alembic.operations.ops',
- 'alembic.op',
- 'alembic.context',
- 'alembic.autogenerate.api',
- 'alembic.runtime.*',
+ "alembic.*"
]
-disallow_incomplete_defs = true
+warn_unused_ignores = true
+strict = true
+
+
[[tool.mypy.overrides]]
module = [
markers =
backend: tests that should run on all backends; typically dialect-sensitive
-[mypy]
-show_error_codes = True
-allow_redefinition = True
-
-[mypy-mako.*]
-ignore_missing_imports = True
-
-[mypy-sqlalchemy.testing.*]
-ignore_missing_imports = True
-
-[mypy-importlib_resources.*]
-ignore_missing_imports = True
-
-[mypy-importlib_metadata.*]
-ignore_missing_imports = True
{"entrypoint": "zimports", "options": "-e"},
ignore_output=ignore_output,
)
- # note that we do not distribute pyproject.toml with the distribution
- # right now due to user complaints, so we can't refer to it here because
- # this all has to run as part of the test suite
+
console_scripts(
str(destination_path),
{"entrypoint": "black", "options": "-l79"},
else:
retval = annotation
+ retval = re.sub(r"TypeEngine\b", "TypeEngine[Any]", retval)
+
retval = retval.replace("~", "") # typevar repr as "~T"
for trim in TRIM_MODULE:
retval = retval.replace(trim, "")