from . import context
from . import op
-from .runtime import environment
-from .runtime import migration
__version__ = "1.7.0"
-
-sys.modules["alembic.migration"] = migration
-sys.modules["alembic.environment"] = environment
automatically."""
import contextlib
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import Iterator
+from typing import Optional
+from typing import Set
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import Union
from sqlalchemy import inspect
from .. import util
from ..operations import ops
-
-def compare_metadata(context, metadata):
+if TYPE_CHECKING:
+ from sqlalchemy.engine import Connection
+ from sqlalchemy.engine import Dialect
+ from sqlalchemy.engine import Inspector
+ from sqlalchemy.sql.schema import Column
+ from sqlalchemy.sql.schema import ForeignKeyConstraint
+ from sqlalchemy.sql.schema import Index
+ from sqlalchemy.sql.schema import MetaData
+ from sqlalchemy.sql.schema import Table
+ from sqlalchemy.sql.schema import UniqueConstraint
+
+ from alembic.config import Config
+ from alembic.operations.ops import MigrationScript
+ from alembic.operations.ops import UpgradeOps
+ from alembic.runtime.migration import MigrationContext
+ from alembic.script.base import Script
+ from alembic.script.base import ScriptDirectory
+
+
+def compare_metadata(context: "MigrationContext", metadata: "MetaData") -> Any:
"""Compare a database schema to that given in a
:class:`~sqlalchemy.schema.MetaData` instance.
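+
+    A minimal usage sketch, adapted from the documented pattern for this
+    function (the SQLite URL and table contents are illustrative only)::
+
+        from sqlalchemy import create_engine, MetaData
+        from alembic.autogenerate import compare_metadata
+        from alembic.runtime.migration import MigrationContext
+
+        engine = create_engine("sqlite://")
+        metadata = MetaData()
+        # ... Table() definitions on `metadata` ...
+        mc = MigrationContext.configure(engine.connect())
+        diff = compare_metadata(mc, metadata)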
return migration_script.upgrade_ops.as_diffs()
-def produce_migrations(context, metadata):
+def produce_migrations(
+ context: "MigrationContext", metadata: "MetaData"
+) -> "MigrationScript":
"""Produce a :class:`.MigrationScript` structure based on schema
comparison.
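+
+    For example, continuing the hypothetical ``mc`` and ``metadata`` from
+    :func:`.compare_metadata` above (sketch)::
+
+        migration_script = produce_migrations(mc, metadata)
+        for op in migration_script.upgrade_ops.ops:
+            print(op)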
def render_python_code(
- up_or_down_op,
- sqlalchemy_module_prefix="sa.",
- alembic_module_prefix="op.",
- render_as_batch=False,
- imports=(),
- render_item=None,
- migration_context=None,
-):
+ up_or_down_op: "UpgradeOps",
+ sqlalchemy_module_prefix: str = "sa.",
+ alembic_module_prefix: str = "op.",
+ render_as_batch: bool = False,
+ imports: Tuple[str, ...] = (),
+    render_item: Optional[Callable] = None,
+ migration_context: Optional["MigrationContext"] = None,
+) -> str:
"""Render Python code given an :class:`.UpgradeOps` or
:class:`.DowngradeOps` object.
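+
+    A sketch pairing this with :func:`.produce_migrations` to get the
+    autogenerated upgrade body as a string (``mc`` and ``metadata`` as in
+    the examples above)::
+
+        migration_script = produce_migrations(mc, metadata)
+        print(render_python_code(migration_script.upgrade_ops))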
)
-def _render_migration_diffs(context, template_args):
+def _render_migration_diffs(
+ context: "MigrationContext", template_args: Dict[Any, Any]
+) -> None:
"""legacy, used by test_autogen_composition at the moment"""
autogen_context = AutogenContext(context)
"""Maintains configuration and state that's specific to an
autogenerate operation."""
- metadata = None
+ metadata: Optional["MetaData"] = None
"""The :class:`~sqlalchemy.schema.MetaData` object
representing the destination.
"""
- connection = None
+ connection: Optional["Connection"] = None
"""The :class:`~sqlalchemy.engine.base.Connection` object currently
connected to the database backend being compared.
"""
- dialect = None
+ dialect: Optional["Dialect"] = None
"""The :class:`~sqlalchemy.engine.Dialect` object currently in use.
This is normally obtained from the
"""
- imports = None
+ imports: Set[str] = None # type: ignore[assignment]
"""A ``set()`` which contains string Python import directives.
The directives are to be rendered into the ``${imports}`` section
"""
- migration_context = None
+ migration_context: "MigrationContext" = None # type: ignore[assignment]
"""The :class:`.MigrationContext` established by the ``env.py`` script."""
def __init__(
- self, migration_context, metadata=None, opts=None, autogenerate=True
- ):
+ self,
+ migration_context: "MigrationContext",
+ metadata: Optional["MetaData"] = None,
+ opts: Optional[dict] = None,
+ autogenerate: bool = True,
+ ) -> None:
if (
autogenerate
self.dialect = self.migration_context.dialect
self.imports = set()
- self.opts = opts
- self._has_batch = False
+ self.opts: Dict[str, Any] = opts
+ self._has_batch: bool = False
@util.memoized_property
- def inspector(self):
+ def inspector(self) -> "Inspector":
return inspect(self.connection)
@contextlib.contextmanager
- def _within_batch(self):
+ def _within_batch(self) -> Iterator[None]:
self._has_batch = True
yield
self._has_batch = False
- def run_name_filters(self, name, type_, parent_names):
+ def run_name_filters(
+ self,
+ name: Optional[str],
+ type_: str,
+ parent_names: Dict[str, Optional[str]],
+ ) -> bool:
"""Run the context's name filters and return True if the targets
should be part of the autogenerate operation.
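+
+        The filters consulted here are the ``include_name`` hooks
+        configured in ``env.py``; a sketch of such a hook (the skipped
+        table name is illustrative only)::
+
+            def include_name(name, type_, parent_names):
+                if type_ == "table":
+                    return name != "audit_log"
+                return True
+
+            context.configure(
+                # ...
+                include_name=include_name,
+            )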
else:
return True
- def run_object_filters(self, object_, name, type_, reflected, compare_to):
+ def run_object_filters(
+ self,
+ object_: Union[
+ "Table",
+ "Index",
+ "Column",
+ "UniqueConstraint",
+ "ForeignKeyConstraint",
+ ],
+ name: Optional[str],
+ type_: str,
+ reflected: bool,
+ compare_to: Optional[
+ Union["Table", "Index", "Column", "UniqueConstraint"]
+ ],
+ ) -> bool:
"""Run the context's object filters and return True if the targets
should be part of the autogenerate operation.
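+
+        These are the ``include_object`` hooks from ``env.py``; e.g. a
+        sketch that skips reflected columns absent from the model::
+
+            def include_object(object_, name, type_, reflected, compare_to):
+                if type_ == "column" and reflected and compare_to is None:
+                    return False
+                return True
+
+            context.configure(
+                # ...
+                include_object=include_object,
+            )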
def __init__(
self,
- config,
- script_directory,
- command_args,
- process_revision_directives=None,
- ):
+ config: "Config",
+ script_directory: "ScriptDirectory",
+ command_args: Dict[str, Any],
+ process_revision_directives: Optional[Callable] = None,
+ ) -> None:
self.config = config
self.script_directory = script_directory
self.command_args = command_args
}
self.generated_revisions = [self._default_revision()]
- def _to_script(self, migration_script):
- template_args = {}
- for k, v in self.template_args.items():
- template_args.setdefault(k, v)
+ def _to_script(
+ self, migration_script: "MigrationScript"
+ ) -> Optional["Script"]:
+ template_args: Dict[str, Any] = self.template_args.copy()
if getattr(migration_script, "_needs_render", False):
autogen_context = self._last_autogen_context
autogen_context, migration_script, template_args
)
+ assert migration_script.rev_id is not None
return self.script_directory.generate_revision(
migration_script.rev_id,
migration_script.message,
**template_args
)
- def run_autogenerate(self, rev, migration_context):
+ def run_autogenerate(
+ self, rev: tuple, migration_context: "MigrationContext"
+ ):
self._run_environment(rev, migration_context, True)
- def run_no_autogenerate(self, rev, migration_context):
+ def run_no_autogenerate(
+ self, rev: tuple, migration_context: "MigrationContext"
+ ):
self._run_environment(rev, migration_context, False)
- def _run_environment(self, rev, migration_context, autogenerate):
+ def _run_environment(
+ self,
+ rev: tuple,
+ migration_context: "MigrationContext",
+ autogenerate: bool,
+ ):
if autogenerate:
if self.command_args["sql"]:
raise util.CommandError(
ops.DowngradeOps([], downgrade_token=downgrade_token)
)
- self._last_autogen_context = autogen_context = AutogenContext(
+ autogen_context = AutogenContext(
migration_context, autogenerate=autogenerate
)
+ self._last_autogen_context: AutogenContext = autogen_context
if autogenerate:
compare._populate_migration_script(
for migration_script in self.generated_revisions:
migration_script._needs_render = True
- def _default_revision(self):
+ def _default_revision(self) -> "MigrationScript":
+ command_args: Dict[str, Any] = self.command_args
op = ops.MigrationScript(
- rev_id=self.command_args["rev_id"] or util.rev_id(),
- message=self.command_args["message"],
+ rev_id=command_args["rev_id"] or util.rev_id(),
+ message=command_args["message"],
upgrade_ops=ops.UpgradeOps([]),
downgrade_ops=ops.DowngradeOps([]),
- head=self.command_args["head"],
- splice=self.command_args["splice"],
- branch_label=self.command_args["branch_label"],
- version_path=self.command_args["version_path"],
- depends_on=self.command_args["depends_on"],
+ head=command_args["head"],
+ splice=command_args["splice"],
+ branch_label=command_args["branch_label"],
+ version_path=command_args["version_path"],
+ depends_on=command_args["depends_on"],
)
return op
- def generate_scripts(self):
+ def generate_scripts(self) -> Iterator[Optional["Script"]]:
for generated_revision in self.generated_revisions:
yield self._to_script(generated_revision)
import contextlib
import logging
import re
+from typing import Any
+from typing import cast
+from typing import Dict
+from typing import Iterator
+from typing import List
+from typing import Optional
+from typing import Set
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import Union
from sqlalchemy import event
from sqlalchemy import inspect
from ..operations import ops
from ..util import sqla_compat
+if TYPE_CHECKING:
+ from typing import Literal
+
+ from sqlalchemy.engine.reflection import Inspector
+ from sqlalchemy.sql.elements import quoted_name
+ from sqlalchemy.sql.schema import Column
+ from sqlalchemy.sql.schema import ForeignKeyConstraint
+ from sqlalchemy.sql.schema import Index
+ from sqlalchemy.sql.schema import Table
+ from sqlalchemy.sql.schema import UniqueConstraint
+
+ from alembic.autogenerate.api import AutogenContext
+ from alembic.operations.ops import AlterColumnOp
+ from alembic.operations.ops import MigrationScript
+ from alembic.operations.ops import ModifyTableOps
+ from alembic.operations.ops import UpgradeOps
+
log = logging.getLogger(__name__)
-def _populate_migration_script(autogen_context, migration_script):
+def _populate_migration_script(
+ autogen_context: "AutogenContext", migration_script: "MigrationScript"
+) -> None:
upgrade_ops = migration_script.upgrade_ops_list[-1]
downgrade_ops = migration_script.downgrade_ops_list[-1]
comparators = util.Dispatcher(uselist=True)
-def _produce_net_changes(autogen_context, upgrade_ops):
+def _produce_net_changes(
+ autogen_context: "AutogenContext", upgrade_ops: "UpgradeOps"
+) -> None:
connection = autogen_context.connection
+ assert connection is not None
include_schemas = autogen_context.opts.get("include_schemas", False)
- inspector = inspect(connection)
+ inspector: "Inspector" = inspect(connection)
default_schema = connection.dialect.default_schema_name
+ schemas: Set[Optional[str]]
if include_schemas:
schemas = set(inspector.get_schema_names())
# replace default schema name with None
schemas.discard(default_schema)
schemas.add(None)
else:
- schemas = [None]
+ schemas = {None}
schemas = {
s for s in schemas if autogen_context.run_name_filters(s, "schema", {})
}
+ assert autogen_context.dialect is not None
comparators.dispatch("schema", autogen_context.dialect.name)(
autogen_context, upgrade_ops, schemas
)
@comparators.dispatch_for("schema")
-def _autogen_for_tables(autogen_context, upgrade_ops, schemas):
+def _autogen_for_tables(
+ autogen_context: "AutogenContext",
+ upgrade_ops: "UpgradeOps",
+ schemas: Union[Set[None], Set[Optional[str]]],
+) -> None:
inspector = autogen_context.inspector
- conn_table_names = set()
+ conn_table_names: Set[Tuple[Optional[str], str]] = set()
version_table_schema = (
autogen_context.migration_context.version_table_schema
def _compare_tables(
- conn_table_names,
- metadata_table_names,
- inspector,
- upgrade_ops,
- autogen_context,
-):
+    conn_table_names: set,
+    metadata_table_names: set,
+ inspector: "Inspector",
+ upgrade_ops: "UpgradeOps",
+ autogen_context: "AutogenContext",
+) -> None:
default_schema = inspector.bind.dialect.default_schema_name
upgrade_ops.ops.append(modify_table_ops)
-def _make_index(params, conn_table):
+def _make_index(params: Dict[str, Any], conn_table: "Table") -> "Index":
ix = sa_schema.Index(
params["name"],
*[conn_table.c[cname] for cname in params["column_names"]],
return ix
-def _make_unique_constraint(params, conn_table):
+def _make_unique_constraint(
+ params: Dict[str, Any], conn_table: "Table"
+) -> "UniqueConstraint":
uq = sa_schema.UniqueConstraint(
*[conn_table.c[cname] for cname in params["column_names"]],
name=params["name"]
return uq
-def _make_foreign_key(params, conn_table):
+def _make_foreign_key(
+ params: Dict[str, Any], conn_table: "Table"
+) -> "ForeignKeyConstraint":
tname = params["referred_table"]
if params["referred_schema"]:
tname = "%s.%s" % (params["referred_schema"], tname)
@contextlib.contextmanager
def _compare_columns(
- schema,
- tname,
- conn_table,
- metadata_table,
- modify_table_ops,
- autogen_context,
- inspector,
-):
+ schema: Optional[str],
+ tname: Union["quoted_name", str],
+ conn_table: "Table",
+ metadata_table: "Table",
+ modify_table_ops: "ModifyTableOps",
+ autogen_context: "AutogenContext",
+ inspector: "Inspector",
+) -> Iterator[None]:
name = "%s.%s" % (schema, tname) if schema else tname
metadata_col_names = OrderedSet(
c.name for c in metadata_table.c if not c.system
class _constraint_sig:
- def md_name_to_sql_name(self, context):
+ const: Union["UniqueConstraint", "ForeignKeyConstraint", "Index"]
+
+ def md_name_to_sql_name(self, context: "AutogenContext") -> Optional[str]:
return sqla_compat._get_constraint_final_name(
self.const, context.dialect
)
def __ne__(self, other):
return self.const != other.const
- def __hash__(self):
+ def __hash__(self) -> int:
return hash(self.const)
is_index = False
is_unique = True
- def __init__(self, const):
+ def __init__(self, const: "UniqueConstraint") -> None:
self.const = const
self.name = const.name
self.sig = tuple(sorted([col.name for col in const.columns]))
@property
- def column_names(self):
+ def column_names(self) -> List[str]:
return [col.name for col in self.const.columns]
class _ix_constraint_sig(_constraint_sig):
is_index = True
- def __init__(self, const):
+ def __init__(self, const: "Index") -> None:
self.const = const
self.name = const.name
self.sig = tuple(sorted([col.name for col in const.columns]))
self.is_unique = bool(const.unique)
- def md_name_to_sql_name(self, context):
+ def md_name_to_sql_name(self, context: "AutogenContext") -> Optional[str]:
return sqla_compat._get_constraint_final_name(
self.const, context.dialect
)
@property
- def column_names(self):
+ def column_names(self) -> Union[List["quoted_name"], List[None]]:
return sqla_compat._get_index_column_names(self.const)
class _fk_constraint_sig(_constraint_sig):
- def __init__(self, const, include_options=False):
+ def __init__(
+ self, const: "ForeignKeyConstraint", include_options: bool = False
+ ) -> None:
self.const = const
self.name = const.name
initially,
) = _fk_spec(const)
- self.sig = (
+ self.sig: Tuple[Any, ...] = (
self.source_schema,
self.source_table,
tuple(self.source_columns),
@comparators.dispatch_for("table")
def _compare_indexes_and_uniques(
- autogen_context, modify_ops, schema, tname, conn_table, metadata_table
-):
+ autogen_context: "AutogenContext",
+ modify_ops: "ModifyTableOps",
+ schema: Optional[str],
+ tname: Union["quoted_name", str],
+ conn_table: Optional["Table"],
+ metadata_table: Optional["Table"],
+) -> None:
inspector = autogen_context.inspector
is_create_table = conn_table is None
metadata_unique_constraints = set()
metadata_indexes = set()
- conn_uniques = conn_indexes = frozenset()
+ conn_uniques = conn_indexes = frozenset() # type:ignore[var-annotated]
supports_unique_constraints = False
# 1b. ... and from connection, if the table exists
if hasattr(inspector, "get_unique_constraints"):
try:
- conn_uniques = inspector.get_unique_constraints(
+ conn_uniques = inspector.get_unique_constraints( # type:ignore[assignment] # noqa
tname, schema=schema
)
supports_unique_constraints = True
# not being present
pass
else:
- conn_uniques = [
+ conn_uniques = [ # type:ignore[assignment]
uq
for uq in conn_uniques
if autogen_context.run_name_filters(
if uq.get("duplicates_index"):
unique_constraints_duplicate_unique_indexes = True
try:
- conn_indexes = inspector.get_indexes(tname, schema=schema)
+ conn_indexes = inspector.get_indexes( # type:ignore[assignment]
+ tname, schema=schema
+ )
except NotImplementedError:
pass
else:
- conn_indexes = [
+ conn_indexes = [ # type:ignore[assignment]
ix
for ix in conn_indexes
if autogen_context.run_name_filters(
# into schema objects
if is_drop_table:
# for DROP TABLE uniques are inline, don't need them
- conn_uniques = set()
+ conn_uniques = set() # type:ignore[assignment]
else:
- conn_uniques = set(
+ conn_uniques = set( # type:ignore[assignment]
_make_unique_constraint(uq_def, conn_table)
for uq_def in conn_uniques
)
- conn_indexes = set(_make_index(ix, conn_table) for ix in conn_indexes)
+ conn_indexes = set( # type:ignore[assignment]
+ _make_index(ix, conn_table) for ix in conn_indexes
+ )
# 2a. if the dialect dupes unique indexes as unique constraints
# (mysql and oracle), correct for that
# _constraint_sig() objects provide a consistent facade over both
# Index and UniqueConstraint so we can easily work with them
# interchangeably
- metadata_unique_constraints = set(
+ metadata_unique_constraints_sig = set(
_uq_constraint_sig(uq) for uq in metadata_unique_constraints
)
- metadata_indexes = set(_ix_constraint_sig(ix) for ix in metadata_indexes)
+ metadata_indexes_sig = set(
+ _ix_constraint_sig(ix) for ix in metadata_indexes
+ )
conn_unique_constraints = set(
_uq_constraint_sig(uq) for uq in conn_uniques
)
- conn_indexes = set(_ix_constraint_sig(ix) for ix in conn_indexes)
+ conn_indexes_sig = set(_ix_constraint_sig(ix) for ix in conn_indexes)
# 5. index things by name, for those objects that have names
metadata_names = dict(
- (c.md_name_to_sql_name(autogen_context), c)
- for c in metadata_unique_constraints.union(metadata_indexes)
+ (cast(str, c.md_name_to_sql_name(autogen_context)), c)
+ for c in metadata_unique_constraints_sig.union(
+ metadata_indexes_sig # type:ignore[arg-type]
+ )
if isinstance(c, _ix_constraint_sig)
or sqla_compat._constraint_is_named(c.const, autogen_context.dialect)
)
conn_uniques_by_name = dict((c.name, c) for c in conn_unique_constraints)
- conn_indexes_by_name = dict((c.name, c) for c in conn_indexes)
+ conn_indexes_by_name: Dict[Optional[str], _ix_constraint_sig] = dict(
+ (c.name, c) for c in conn_indexes_sig
+ )
conn_names = dict(
(c.name, c)
- for c in conn_unique_constraints.union(conn_indexes)
+ for c in conn_unique_constraints.union(
+ conn_indexes_sig # type:ignore[arg-type]
+ )
if c.name is not None
)
# constraints.
conn_uniques_by_sig = dict((uq.sig, uq) for uq in conn_unique_constraints)
metadata_uniques_by_sig = dict(
- (uq.sig, uq) for uq in metadata_unique_constraints
+ (uq.sig, uq) for uq in metadata_unique_constraints_sig
)
- metadata_indexes_by_sig = dict((ix.sig, ix) for ix in metadata_indexes)
+ metadata_indexes_by_sig = dict((ix.sig, ix) for ix in metadata_indexes_sig)
unnamed_metadata_uniques = dict(
(uq.sig, uq)
- for uq in metadata_unique_constraints
+ for uq in metadata_unique_constraints_sig
if not sqla_compat._constraint_is_named(
uq.const, autogen_context.dialect
)
)
for removed_name in sorted(set(conn_names).difference(metadata_names)):
- conn_obj = conn_names[removed_name]
+ conn_obj: Union[_ix_constraint_sig, _uq_constraint_sig] = conn_names[
+ removed_name
+ ]
if not conn_obj.is_index and conn_obj.sig in unnamed_metadata_uniques:
continue
elif removed_name in doubled_constraints:
@comparators.dispatch_for("column")
def _compare_nullable(
- autogen_context,
- alter_column_op,
- schema,
- tname,
- cname,
- conn_col,
- metadata_col,
-):
+ autogen_context: "AutogenContext",
+ alter_column_op: "AlterColumnOp",
+ schema: Optional[str],
+ tname: Union["quoted_name", str],
+ cname: Union["quoted_name", str],
+ conn_col: "Column",
+ metadata_col: "Column",
+) -> None:
metadata_col_nullable = metadata_col.nullable
conn_col_nullable = conn_col.nullable
@comparators.dispatch_for("column")
def _setup_autoincrement(
- autogen_context,
- alter_column_op,
- schema,
- tname,
- cname,
- conn_col,
- metadata_col,
-):
+ autogen_context: "AutogenContext",
+ alter_column_op: "AlterColumnOp",
+ schema: Optional[str],
+ tname: Union["quoted_name", str],
+ cname: "quoted_name",
+ conn_col: "Column",
+ metadata_col: "Column",
+) -> None:
if metadata_col.table._autoincrement_column is metadata_col:
alter_column_op.kw["autoincrement"] = True
@comparators.dispatch_for("column")
def _compare_type(
- autogen_context,
- alter_column_op,
- schema,
- tname,
- cname,
- conn_col,
- metadata_col,
-):
+ autogen_context: "AutogenContext",
+ alter_column_op: "AlterColumnOp",
+ schema: Optional[str],
+ tname: Union["quoted_name", str],
+ cname: Union["quoted_name", str],
+ conn_col: "Column",
+ metadata_col: "Column",
+) -> None:
conn_type = conn_col.type
alter_column_op.existing_type = conn_type
def _render_server_default_for_compare(
- metadata_default, metadata_col, autogen_context
-):
+ metadata_default: Optional[Any],
+ metadata_col: "Column",
+ autogen_context: "AutogenContext",
+) -> Optional[str]:
rendered = _user_defined_render(
"server_default", metadata_default, autogen_context
)
return None
-def _normalize_computed_default(sqltext):
+def _normalize_computed_default(sqltext: str) -> str:
"""we want to warn if a computed sql expression has changed. however
we don't want false positives and the warning is not that critical.
so filter out most forms of variability from the SQL text.
def _compare_computed_default(
- autogen_context,
- alter_column_op,
- schema,
- tname,
- cname,
- conn_col,
- metadata_col,
-):
+ autogen_context: "AutogenContext",
+ alter_column_op: "AlterColumnOp",
+ schema: Optional[str],
+ tname: "str",
+ cname: "str",
+ conn_col: "Column",
+ metadata_col: "Column",
+) -> None:
rendered_metadata_default = str(
- metadata_col.server_default.sqltext.compile(
+ cast(sa_schema.Computed, metadata_col.server_default).sqltext.compile(
dialect=autogen_context.dialect,
compile_kwargs={"literal_binds": True},
)
_warn_computed_not_supported(tname, cname)
-def _warn_computed_not_supported(tname, cname):
+def _warn_computed_not_supported(tname: str, cname: str) -> None:
util.warn("Computed default on %s.%s cannot be modified" % (tname, cname))
@comparators.dispatch_for("column")
def _compare_server_default(
- autogen_context,
- alter_column_op,
- schema,
- tname,
- cname,
- conn_col,
- metadata_col,
-):
+ autogen_context: "AutogenContext",
+ alter_column_op: "AlterColumnOp",
+ schema: Optional[str],
+ tname: Union["quoted_name", str],
+ cname: Union["quoted_name", str],
+ conn_col: "Column",
+ metadata_col: "Column",
+) -> Optional[bool]:
metadata_default = metadata_col.server_default
conn_col_default = conn_col.server_default
return False
else:
- return _compare_computed_default(
- autogen_context,
- alter_column_op,
- schema,
- tname,
- cname,
- conn_col,
- metadata_col,
+ return (
+ _compare_computed_default( # type:ignore[func-returns-value]
+ autogen_context,
+ alter_column_op,
+ schema,
+ tname,
+ cname,
+ conn_col,
+ metadata_col,
+ )
)
if sqla_compat._server_default_is_computed(conn_col_default):
_warn_computed_not_supported(tname, cname)
)
rendered_conn_default = (
- conn_col_default.arg.text if conn_col_default else None
+ cast(Any, conn_col_default).arg.text if conn_col_default else None
)
alter_column_op.existing_server_default = conn_col_default
alter_column_op.modify_server_default = metadata_default
log.info("Detected server default on column '%s.%s'", tname, cname)
+ return None
+
@comparators.dispatch_for("column")
def _compare_column_comment(
- autogen_context,
- alter_column_op,
- schema,
- tname,
- cname,
- conn_col,
- metadata_col,
-):
-
+ autogen_context: "AutogenContext",
+ alter_column_op: "AlterColumnOp",
+ schema: Optional[str],
+ tname: Union["quoted_name", str],
+ cname: "quoted_name",
+ conn_col: "Column",
+ metadata_col: "Column",
+) -> Optional["Literal[False]"]:
+
+ assert autogen_context.dialect is not None
if not autogen_context.dialect.supports_comments:
- return
+ return None
metadata_comment = metadata_col.comment
conn_col_comment = conn_col.comment
alter_column_op.modify_comment = metadata_comment
log.info("Detected column comment '%s.%s'", tname, cname)
+ return None
+
@comparators.dispatch_for("table")
def _compare_foreign_keys(
- autogen_context,
- modify_table_ops,
- schema,
- tname,
- conn_table,
- metadata_table,
-):
+ autogen_context: "AutogenContext",
+ modify_table_ops: "ModifyTableOps",
+ schema: Optional[str],
+ tname: Union["quoted_name", str],
+ conn_table: Optional["Table"],
+ metadata_table: Optional["Table"],
+) -> None:
# if we're doing CREATE TABLE, all FKs are created
# inline within the table def
)
]
- backend_reflects_fk_options = conn_fks and "options" in conn_fks[0]
+ backend_reflects_fk_options = bool(conn_fks and "options" in conn_fks[0])
conn_fks = set(_make_foreign_key(const, conn_table) for const in conn_fks)
@comparators.dispatch_for("table")
def _compare_table_comment(
- autogen_context,
- modify_table_ops,
- schema,
- tname,
- conn_table,
- metadata_table,
-):
-
+ autogen_context: "AutogenContext",
+ modify_table_ops: "ModifyTableOps",
+ schema: Optional[str],
+ tname: Union["quoted_name", str],
+ conn_table: Optional["Table"],
+ metadata_table: Optional["Table"],
+) -> None:
+
+ assert autogen_context.dialect is not None
if not autogen_context.dialect.supports_comments:
return
from collections import OrderedDict
from io import StringIO
import re
+from typing import Any
+from typing import cast
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import Union
from mako.pygen import PythonPrinter
from sqlalchemy import schema as sa_schema
from sqlalchemy import sql
from sqlalchemy import types as sqltypes
+from sqlalchemy.sql.elements import conv
from .. import util
from ..operations import ops
from ..util import sqla_compat
from ..util.compat import string_types
+if TYPE_CHECKING:
+ from typing import Literal
+
+ from sqlalchemy.sql.elements import ColumnElement
+ from sqlalchemy.sql.elements import quoted_name
+ from sqlalchemy.sql.elements import TextClause
+ from sqlalchemy.sql.schema import CheckConstraint
+ from sqlalchemy.sql.schema import Column
+ from sqlalchemy.sql.schema import Constraint
+ from sqlalchemy.sql.schema import DefaultClause
+ from sqlalchemy.sql.schema import FetchedValue
+ from sqlalchemy.sql.schema import ForeignKey
+ from sqlalchemy.sql.schema import ForeignKeyConstraint
+ from sqlalchemy.sql.schema import Index
+ from sqlalchemy.sql.schema import MetaData
+ from sqlalchemy.sql.schema import PrimaryKeyConstraint
+ from sqlalchemy.sql.schema import UniqueConstraint
+ from sqlalchemy.sql.sqltypes import ARRAY
+ from sqlalchemy.sql.type_api import TypeEngine
+ from sqlalchemy.sql.type_api import Variant
+
+ from alembic.autogenerate.api import AutogenContext
+ from alembic.config import Config
+ from alembic.operations.ops import MigrationScript
+ from alembic.operations.ops import ModifyTableOps
+ from alembic.util.sqla_compat import Computed
+ from alembic.util.sqla_compat import Identity
-MAX_PYTHON_ARGS = 255
-
-try:
- from sqlalchemy.sql.naming import conv
-
- def _render_gen_name(autogen_context, name):
- if isinstance(name, conv):
- return _f_name(_alembic_autogenerate_prefix(autogen_context), name)
- else:
- return name
+MAX_PYTHON_ARGS = 255
-except ImportError:
- def _render_gen_name(autogen_context, name):
+def _render_gen_name(
+ autogen_context: "AutogenContext",
+ name: Optional[Union["quoted_name", str]],
+) -> Optional[Union["quoted_name", str, "_f_name"]]:
+ if isinstance(name, conv):
+ return _f_name(_alembic_autogenerate_prefix(autogen_context), name)
+ else:
return name
-def _indent(text):
+def _indent(text: str) -> str:
text = re.compile(r"^", re.M).sub(" ", text).strip()
text = re.compile(r" +$", re.M).sub("", text)
return text
def _render_python_into_templatevars(
- autogen_context, migration_script, template_args
-):
+ autogen_context: "AutogenContext",
+ migration_script: "MigrationScript",
+ template_args: Dict[str, Union[str, "Config"]],
+) -> None:
imports = autogen_context.imports
for upgrade_ops, downgrade_ops in zip(
default_renderers = renderers = util.Dispatcher()
-def _render_cmd_body(op_container, autogen_context):
+def _render_cmd_body(
+ op_container: "ops.OpContainer",
+ autogen_context: "AutogenContext",
+) -> str:
buf = StringIO()
printer = PythonPrinter(buf)
has_lines = False
for op in op_container.ops:
lines = render_op(autogen_context, op)
- has_lines = has_lines or lines
+ has_lines = has_lines or bool(lines)
for line in lines:
printer.writeline(line)
return buf.getvalue()
-def render_op(autogen_context, op):
+def render_op(
+ autogen_context: "AutogenContext", op: "ops.MigrateOperation"
+) -> List[str]:
renderer = renderers.dispatch(op)
lines = util.to_list(renderer(autogen_context, op))
return lines
-def render_op_text(autogen_context, op):
+def render_op_text(
+ autogen_context: "AutogenContext", op: "ops.MigrateOperation"
+) -> str:
return "\n".join(render_op(autogen_context, op))
@renderers.dispatch_for(ops.ModifyTableOps)
-def _render_modify_table(autogen_context, op):
+def _render_modify_table(
+ autogen_context: "AutogenContext", op: "ModifyTableOps"
+) -> List[str]:
opts = autogen_context.opts
render_as_batch = opts.get("render_as_batch", False)
@renderers.dispatch_for(ops.CreateTableCommentOp)
-def _render_create_table_comment(autogen_context, op):
+def _render_create_table_comment(
+ autogen_context: "AutogenContext", op: "ops.CreateTableCommentOp"
+) -> str:
templ = (
"{prefix}create_table_comment(\n"
@renderers.dispatch_for(ops.DropTableCommentOp)
-def _render_drop_table_comment(autogen_context, op):
+def _render_drop_table_comment(
+ autogen_context: "AutogenContext", op: "ops.DropTableCommentOp"
+) -> str:
templ = (
"{prefix}drop_table_comment(\n"
@renderers.dispatch_for(ops.CreateTableOp)
-def _add_table(autogen_context, op):
+def _add_table(
+ autogen_context: "AutogenContext", op: "ops.CreateTableOp"
+) -> str:
table = op.to_table()
args = [
)
if len(args) > MAX_PYTHON_ARGS:
- args = "*[" + ",\n".join(args) + "]"
+ args_str = "*[" + ",\n".join(args) + "]"
else:
- args = ",\n".join(args)
+ args_str = ",\n".join(args)
text = "%(prefix)screate_table(%(tablename)r,\n%(args)s" % {
"tablename": _ident(op.table_name),
"prefix": _alembic_autogenerate_prefix(autogen_context),
- "args": args,
+ "args": args_str,
}
if op.schema:
text += ",\nschema=%r" % _ident(op.schema)
@renderers.dispatch_for(ops.DropTableOp)
-def _drop_table(autogen_context, op):
+def _drop_table(
+ autogen_context: "AutogenContext", op: "ops.DropTableOp"
+) -> str:
text = "%(prefix)sdrop_table(%(tname)r" % {
"prefix": _alembic_autogenerate_prefix(autogen_context),
"tname": _ident(op.table_name),
@renderers.dispatch_for(ops.CreateIndexOp)
-def _add_index(autogen_context, op):
+def _add_index(
+ autogen_context: "AutogenContext", op: "ops.CreateIndexOp"
+) -> str:
index = op.to_index()
has_batch = autogen_context._has_batch
"unique=%(unique)r%(schema)s%(kwargs)s)"
)
+ assert index.table is not None
text = tmpl % {
"prefix": _alembic_autogenerate_prefix(autogen_context),
"name": _render_gen_name(autogen_context, index.name),
@renderers.dispatch_for(ops.DropIndexOp)
-def _drop_index(autogen_context, op):
+def _drop_index(
+ autogen_context: "AutogenContext", op: "ops.DropIndexOp"
+) -> str:
index = op.to_index()
has_batch = autogen_context._has_batch
@renderers.dispatch_for(ops.CreateUniqueConstraintOp)
-def _add_unique_constraint(autogen_context, op):
+def _add_unique_constraint(
+ autogen_context: "AutogenContext", op: "ops.CreateUniqueConstraintOp"
+) -> List[str]:
return [_uq_constraint(op.to_constraint(), autogen_context, True)]
@renderers.dispatch_for(ops.CreateForeignKeyOp)
-def _add_fk_constraint(autogen_context, op):
+def _add_fk_constraint(
+ autogen_context: "AutogenContext", op: "ops.CreateForeignKeyOp"
+) -> str:
args = [repr(_render_gen_name(autogen_context, op.constraint_name))]
if not autogen_context._has_batch:
@renderers.dispatch_for(ops.DropConstraintOp)
-def _drop_constraint(autogen_context, op):
+def _drop_constraint(
+ autogen_context: "AutogenContext", op: "ops.DropConstraintOp"
+) -> str:
if autogen_context._has_batch:
template = "%(prefix)sdrop_constraint" "(%(name)r, type_=%(type)r)"
@renderers.dispatch_for(ops.AddColumnOp)
-def _add_column(autogen_context, op):
+def _add_column(
+ autogen_context: "AutogenContext", op: "ops.AddColumnOp"
+) -> str:
schema, tname, column = op.schema, op.table_name, op.column
if autogen_context._has_batch:
@renderers.dispatch_for(ops.DropColumnOp)
-def _drop_column(autogen_context, op):
+def _drop_column(
+ autogen_context: "AutogenContext", op: "ops.DropColumnOp"
+) -> str:
schema, tname, column_name = op.schema, op.table_name, op.column_name
@renderers.dispatch_for(ops.AlterColumnOp)
-def _alter_column(autogen_context, op):
+def _alter_column(
+ autogen_context: "AutogenContext", op: "ops.AlterColumnOp"
+) -> str:
tname = op.table_name
cname = op.column_name
class _f_name:
- def __init__(self, prefix, name):
+ def __init__(self, prefix: str, name: conv) -> None:
self.prefix = prefix
self.name = name
- def __repr__(self):
+ def __repr__(self) -> str:
return "%sf(%r)" % (self.prefix, _ident(self.name))
-def _ident(name):
+def _ident(name: Optional[Union["quoted_name", str]]) -> Optional[str]:
"""produce a __repr__() object for a string identifier that may
use quoted_name() in SQLAlchemy 0.9 and greater.
def _render_potential_expr(
- value, autogen_context, wrap_in_text=True, is_server_default=False
-):
+ value: Any,
+ autogen_context: "AutogenContext",
+ wrap_in_text: bool = True,
+ is_server_default: bool = False,
+) -> str:
if isinstance(value, sql.ClauseElement):
if wrap_in_text:
return repr(value)
-def _get_index_rendered_expressions(idx, autogen_context):
+def _get_index_rendered_expressions(
+ idx: "Index", autogen_context: "AutogenContext"
+) -> List[str]:
return [
repr(_ident(getattr(exp, "name", None)))
if isinstance(exp, sa_schema.Column)
]
-def _uq_constraint(constraint, autogen_context, alter):
- opts = []
+def _uq_constraint(
+ constraint: "UniqueConstraint",
+ autogen_context: "AutogenContext",
+ alter: bool,
+) -> str:
+ opts: List[Tuple[str, Any]] = []
has_batch = autogen_context._has_batch
return prefix
-def _sqlalchemy_autogenerate_prefix(autogen_context):
+def _sqlalchemy_autogenerate_prefix(autogen_context: "AutogenContext") -> str:
return autogen_context.opts["sqlalchemy_module_prefix"] or ""
-def _alembic_autogenerate_prefix(autogen_context):
+def _alembic_autogenerate_prefix(autogen_context: "AutogenContext") -> str:
if autogen_context._has_batch:
return "batch_op."
else:
return autogen_context.opts["alembic_module_prefix"] or ""
-def _user_defined_render(type_, object_, autogen_context):
+def _user_defined_render(
+ type_: str, object_: Any, autogen_context: "AutogenContext"
+) -> Union[str, "Literal[False]"]:
if "render_item" in autogen_context.opts:
render = autogen_context.opts["render_item"]
if render:
return False
-def _render_column(column, autogen_context):
+def _render_column(column: "Column", autogen_context: "AutogenContext") -> str:
rendered = _user_defined_render("column", column, autogen_context)
if rendered is not False:
return rendered
- args = []
- opts = []
+ args: List[str] = []
+ opts: List[Tuple[str, Any]] = []
if column.server_default:
- rendered = _render_server_default(
+ rendered = _render_server_default( # type:ignore[assignment]
column.server_default, autogen_context
)
if rendered:
}
-def _should_render_server_default_positionally(server_default):
+def _should_render_server_default_positionally(
+ server_default: Union["Computed", "DefaultClause"]
+) -> bool:
return sqla_compat._server_default_is_computed(
server_default
) or sqla_compat._server_default_is_identity(server_default)
-def _render_server_default(default, autogen_context, repr_=True):
+def _render_server_default(
+ default: Optional[
+ Union["FetchedValue", str, "TextClause", "ColumnElement"]
+ ],
+ autogen_context: "AutogenContext",
+ repr_: bool = True,
+) -> Optional[str]:
rendered = _user_defined_render("server_default", default, autogen_context)
if rendered is not False:
return rendered
if sqla_compat._server_default_is_computed(default):
- return _render_computed(default, autogen_context)
+ return _render_computed(cast("Computed", default), autogen_context)
elif sqla_compat._server_default_is_identity(default):
- return _render_identity(default, autogen_context)
+ return _render_identity(cast("Identity", default), autogen_context)
elif isinstance(default, sa_schema.DefaultClause):
if isinstance(default.arg, compat.string_types):
default = default.arg
if isinstance(default, string_types) and repr_:
default = repr(re.sub(r"^'|'$", "", default))
- return default
+ return cast(str, default)
-def _render_computed(computed, autogen_context):
+def _render_computed(
+ computed: "Computed", autogen_context: "AutogenContext"
+) -> str:
text = _render_potential_expr(
computed.sqltext, autogen_context, wrap_in_text=False
)
}
-def _render_identity(identity, autogen_context):
+def _render_identity(
+ identity: "Identity", autogen_context: "AutogenContext"
+) -> str:
# always=None means something different than always=False
kwargs = OrderedDict(always=identity.always)
if identity.on_null is not None:
}
-def _get_identity_options(identity_options):
+def _get_identity_options(identity_options: "Identity") -> OrderedDict:
kwargs = OrderedDict()
for attr in sqla_compat._identity_options_attrs:
value = getattr(identity_options, attr, None)
return kwargs
-def _repr_type(type_, autogen_context):
+def _repr_type(type_: "TypeEngine", autogen_context: "AutogenContext") -> str:
rendered = _user_defined_render("type", type_, autogen_context)
if rendered is not False:
return rendered
mod = type(type_).__module__
imports = autogen_context.imports
if mod.startswith("sqlalchemy.dialects"):
- dname = re.match(r"sqlalchemy\.dialects\.(\w+)", mod).group(1)
+ match = re.match(r"sqlalchemy\.dialects\.(\w+)", mod)
+ assert match is not None
+ dname = match.group(1)
if imports is not None:
imports.add("from sqlalchemy.dialects import %s" % dname)
if impl_rt:
return "%s%r" % (prefix, type_)
-def _render_ARRAY_type(type_, autogen_context):
- return _render_type_w_subtype(
- type_, autogen_context, "item_type", r"(.+?\()"
+def _render_ARRAY_type(
+ type_: "ARRAY", autogen_context: "AutogenContext"
+) -> str:
+ return cast(
+ str,
+ _render_type_w_subtype(
+ type_, autogen_context, "item_type", r"(.+?\()"
+ ),
)
-def _render_Variant_type(type_, autogen_context):
+def _render_Variant_type(
+ type_: "Variant", autogen_context: "AutogenContext"
+) -> str:
base = _repr_type(type_.impl, autogen_context)
+ assert base is not None and base is not False
for dialect in sorted(type_.mapping):
typ = type_.mapping[dialect]
base += ".with_variant(%s, %r)" % (
def _render_type_w_subtype(
- type_, autogen_context, attrname, regexp, prefix=None
-):
+ type_: "TypeEngine",
+ autogen_context: "AutogenContext",
+ attrname: str,
+ regexp: str,
+ prefix: Optional[str] = None,
+) -> Union[Optional[str], "Literal[False]"]:
outer_repr = repr(type_)
inner_type = getattr(type_, attrname, None)
if inner_type is None:
mod = type(type_).__module__
if mod.startswith("sqlalchemy.dialects"):
- dname = re.match(r"sqlalchemy\.dialects\.(\w+)", mod).group(1)
+ match = re.match(r"sqlalchemy\.dialects\.(\w+)", mod)
+ assert match is not None
+ dname = match.group(1)
return "%s.%s" % (dname, outer_type)
elif mod.startswith("sqlalchemy"):
prefix = _sqlalchemy_autogenerate_prefix(autogen_context)
_constraint_renderers = util.Dispatcher()
-def _render_constraint(constraint, autogen_context, namespace_metadata):
+def _render_constraint(
+ constraint: "Constraint",
+ autogen_context: "AutogenContext",
+ namespace_metadata: Optional["MetaData"],
+) -> Optional[str]:
try:
renderer = _constraint_renderers.dispatch(constraint)
except ValueError:
@_constraint_renderers.dispatch_for(sa_schema.PrimaryKeyConstraint)
-def _render_primary_key(constraint, autogen_context, namespace_metadata):
+def _render_primary_key(
+ constraint: "PrimaryKeyConstraint",
+ autogen_context: "AutogenContext",
+ namespace_metadata: Optional["MetaData"],
+) -> Optional[str]:
rendered = _user_defined_render("primary_key", constraint, autogen_context)
if rendered is not False:
return rendered
}
-def _fk_colspec(fk, metadata_schema, namespace_metadata):
+def _fk_colspec(
+ fk: "ForeignKey",
+ metadata_schema: Optional[str],
+ namespace_metadata: "MetaData",
+) -> str:
"""Implement a 'safe' version of ForeignKey._get_colspec() that
won't fail if the remote table can't be resolved.
"""
- colspec = fk._get_colspec()
+ colspec = fk._get_colspec() # type:ignore[attr-defined]
tokens = colspec.split(".")
tname, colname = tokens[-2:]
return colspec
-def _populate_render_fk_opts(constraint, opts):
+def _populate_render_fk_opts(
+ constraint: "ForeignKeyConstraint", opts: List[Tuple[str, str]]
+) -> None:
if constraint.onupdate:
opts.append(("onupdate", repr(constraint.onupdate)))
@_constraint_renderers.dispatch_for(sa_schema.ForeignKeyConstraint)
-def _render_foreign_key(constraint, autogen_context, namespace_metadata):
+def _render_foreign_key(
+ constraint: "ForeignKeyConstraint",
+ autogen_context: "AutogenContext",
+ namespace_metadata: "MetaData",
+) -> Optional[str]:
rendered = _user_defined_render("foreign_key", constraint, autogen_context)
if rendered is not False:
return rendered
% {
"prefix": _sqlalchemy_autogenerate_prefix(autogen_context),
"cols": ", ".join(
- "%r" % _ident(f.parent.name) for f in constraint.elements
+ "%r" % _ident(cast("Column", f.parent).name)
+ for f in constraint.elements
),
"refcols": ", ".join(
repr(_fk_colspec(f, apply_metadata_schema, namespace_metadata))
@_constraint_renderers.dispatch_for(sa_schema.UniqueConstraint)
-def _render_unique_constraint(constraint, autogen_context, namespace_metadata):
+def _render_unique_constraint(
+ constraint: "UniqueConstraint",
+ autogen_context: "AutogenContext",
+ namespace_metadata: Optional["MetaData"],
+) -> str:
rendered = _user_defined_render("unique", constraint, autogen_context)
if rendered is not False:
return rendered
@_constraint_renderers.dispatch_for(sa_schema.CheckConstraint)
-def _render_check_constraint(constraint, autogen_context, namespace_metadata):
+def _render_check_constraint(
+ constraint: "CheckConstraint",
+ autogen_context: "AutogenContext",
+ namespace_metadata: Optional["MetaData"],
+) -> Optional[str]:
rendered = _user_defined_render("check", constraint, autogen_context)
if rendered is not False:
return rendered
# ideally SQLAlchemy would give us more of a first class
# way to detect this.
if (
- constraint._create_rule
- and hasattr(constraint._create_rule, "target")
- and isinstance(constraint._create_rule.target, sqltypes.TypeEngine)
+ constraint._create_rule # type:ignore[attr-defined]
+ and hasattr(
+ constraint._create_rule, "target" # type:ignore[attr-defined]
+ )
+ and isinstance(
+ constraint._create_rule.target, # type:ignore[attr-defined]
+ sqltypes.TypeEngine,
+ )
):
return None
opts = []
@renderers.dispatch_for(ops.ExecuteSQLOp)
-def _execute_sql(autogen_context, op):
+def _execute_sql(
+ autogen_context: "AutogenContext", op: "ops.ExecuteSQLOp"
+) -> str:
if not isinstance(op.sqltext, string_types):
raise NotImplementedError(
"Autogenerate rendering of SQL Expression language constructs "
+from typing import Any
+from typing import Callable
+from typing import Iterator
+from typing import List
+from typing import Type
+from typing import TYPE_CHECKING
+from typing import Union
+
from alembic import util
from alembic.operations import ops
+if TYPE_CHECKING:
+ from alembic.operations.ops import AddColumnOp
+ from alembic.operations.ops import AlterColumnOp
+ from alembic.operations.ops import CreateTableOp
+ from alembic.operations.ops import MigrateOperation
+ from alembic.operations.ops import MigrationScript
+ from alembic.operations.ops import ModifyTableOps
+ from alembic.operations.ops import OpContainer
+ from alembic.runtime.migration import MigrationContext
+ from alembic.script.revision import Revision
+
class Rewriter:
"""A helper object that allows easy 'rewriting' of ops streams.
_chained = None
- def __init__(self):
+ def __init__(self) -> None:
self.dispatch = util.Dispatcher()
- def chain(self, other):
+ def chain(self, other: "Rewriter") -> "Rewriter":
"""Produce a "chain" of this :class:`.Rewriter` to another.
This allows two rewriters to operate serially on a stream,
wr._chained = other
return wr
- def rewrites(self, operator):
+ def rewrites(
+ self,
+ operator: Union[
+ Type["AddColumnOp"],
+ Type["MigrateOperation"],
+ Type["AlterColumnOp"],
+ Type["CreateTableOp"],
+ Type["ModifyTableOps"],
+ ],
+ ) -> Callable:
"""Register a function as rewriter for a given type.
The function should receive three arguments, which are
the :class:`.MigrationContext`, a ``revision`` tuple, and an
op directive of the type indicated.
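+
+        A sketch of registering a rewriter for ``ops.AddColumnOp``
+        (the modification shown is illustrative)::
+
+            writer = Rewriter()
+
+            @writer.rewrites(ops.AddColumnOp)
+            def add_column(context, revision, op):
+                # e.g. force new columns to be created as nullable
+                op.column.nullable = True
+                return op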
"""
return self.dispatch.dispatch_for(operator)
- def _rewrite(self, context, revision, directive):
+ def _rewrite(
+ self,
+ context: "MigrationContext",
+ revision: "Revision",
+ directive: "MigrateOperation",
+ ) -> Iterator["MigrateOperation"]:
try:
_rewriter = self.dispatch.dispatch(directive)
except ValueError:
yield directive
else:
for r_directive in util.to_list(
- _rewriter(context, revision, directive)
+ _rewriter(context, revision, directive), []
):
r_directive._mutations = r_directive._mutations.union(
[self]
)
yield r_directive
- def __call__(self, context, revision, directives):
+ def __call__(
+ self,
+ context: "MigrationContext",
+ revision: "Revision",
+ directives: List["MigrationScript"],
+ ) -> None:
self.process_revision_directives(context, revision, directives)
if self._chained:
self._chained(context, revision, directives)
@_traverse.dispatch_for(ops.MigrationScript)
- def _traverse_script(self, context, revision, directive):
+ def _traverse_script(
+ self,
+ context: "MigrationContext",
+ revision: "Revision",
+ directive: "MigrationScript",
+ ) -> None:
upgrade_ops_list = []
for upgrade_ops in directive.upgrade_ops_list:
ret = self._traverse_for(context, revision, upgrade_ops)
directive.downgrade_ops = downgrade_ops_list
@_traverse.dispatch_for(ops.OpContainer)
- def _traverse_op_container(self, context, revision, directive):
+ def _traverse_op_container(
+ self,
+ context: "MigrationContext",
+ revision: "Revision",
+ directive: "OpContainer",
+ ) -> None:
self._traverse_list(context, revision, directive.ops)
@_traverse.dispatch_for(ops.MigrateOperation)
- def _traverse_any_directive(self, context, revision, directive):
+ def _traverse_any_directive(
+ self,
+ context: "MigrationContext",
+ revision: "Revision",
+ directive: "MigrateOperation",
+ ) -> None:
pass
- def _traverse_for(self, context, revision, directive):
+ def _traverse_for(
+ self,
+ context: "MigrationContext",
+ revision: "Revision",
+ directive: "MigrateOperation",
+ ) -> Any:
directives = list(self._rewrite(context, revision, directive))
for directive in directives:
traverser = self._traverse.dispatch(directive)
traverser(self, context, revision, directive)
return directives
- def _traverse_list(self, context, revision, directives):
+ def _traverse_list(
+ self,
+ context: "MigrationContext",
+ revision: "Revision",
+ directives: Any,
+ ) -> None:
dest = []
for directive in directives:
dest.extend(self._traverse_for(context, revision, directive))
directives[:] = dest
- def process_revision_directives(self, context, revision, directives):
+ def process_revision_directives(
+ self,
+ context: "MigrationContext",
+ revision: "Revision",
+ directives: List["MigrationScript"],
+ ) -> None:
self._traverse_list(context, revision, directives)
import os
+from typing import Callable
+from typing import cast
+from typing import List
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import Union
from . import autogenerate as autogen
from . import util
from .runtime.environment import EnvironmentContext
from .script import ScriptDirectory
+if TYPE_CHECKING:
+ from alembic.config import Config
+ from alembic.script.base import Script
+
-def list_templates(config):
+def list_templates(config: "Config") -> None:
"""List available templates.
config.print_stdout("\n alembic init --template generic ./scripts")
-def init(config, directory, template="generic", package=False):
+def init(
+ config: "Config",
+ directory: str,
+ template: str = "generic",
+ package: bool = False,
+) -> None:
"""Initialize a new scripts directory.
:param config: a :class:`.Config` object.
for file_ in os.listdir(template_dir):
file_path = os.path.join(template_dir, file_)
if file_ == "alembic.ini.mako":
- config_file = os.path.abspath(config.config_file_name)
- if os.access(config_file, os.F_OK):
+ config_file = os.path.abspath(cast(str, config.config_file_name))
+            if os.access(config_file, os.F_OK):
util.msg("File %s already exists, skipping" % config_file)
else:
script._generate_template(
os.path.join(os.path.abspath(versions), "__init__.py"),
]:
file_ = util.status("Adding %s" % path, open, path, "w")
- file_.close()
+ file_.close() # type:ignore[attr-defined]
util.msg(
"Please edit configuration/connection/logging "
def revision(
- config,
- message=None,
- autogenerate=False,
- sql=False,
- head="head",
- splice=False,
- branch_label=None,
- version_path=None,
- rev_id=None,
- depends_on=None,
- process_revision_directives=None,
-):
+ config: "Config",
+ message: Optional[str] = None,
+ autogenerate: bool = False,
+ sql: bool = False,
+ head: str = "head",
+ splice: bool = False,
+ branch_label: Optional[str] = None,
+ version_path: Optional[str] = None,
+ rev_id: Optional[str] = None,
+ depends_on: Optional[str] = None,
+    process_revision_directives: Optional[Callable] = None,
+) -> Union[Optional["Script"], List[Optional["Script"]]]:
"""Create a new revision file.
:param config: a :class:`.Config` object.
return scripts
-def merge(config, revisions, message=None, branch_label=None, rev_id=None):
+def merge(
+ config: "Config",
+ revisions: str,
+    message: Optional[str] = None,
+    branch_label: Optional[str] = None,
+    rev_id: Optional[str] = None,
+) -> Optional["Script"]:
"""Merge two revisions together. Creates a new migration file.
:param config: a :class:`.Config` instance
script = ScriptDirectory.from_config(config)
template_args = {
- "config": config # Let templates use config for
+ "config": "config" # Let templates use config for
# e.g. multiple databases
}
return script.generate_revision(
refresh=True,
head=revisions,
branch_labels=branch_label,
- **template_args
+ **template_args # type:ignore[arg-type]
)
-def upgrade(config, revision, sql=False, tag=None):
+def upgrade(
+ config: "Config",
+ revision: str,
+ sql: bool = False,
+ tag: Optional[str] = None,
+) -> None:
"""Upgrade to a later version.
:param config: a :class:`.Config` instance.
script.run_env()
-def downgrade(config, revision, sql=False, tag=None):
+def downgrade(
+ config: "Config",
+ revision: str,
+ sql: bool = False,
+ tag: Optional[str] = None,
+) -> None:
"""Revert to a previous version.
:param config: a :class:`.Config` instance.
config.print_stdout(sc.log_entry)
-def history(config, rev_range=None, verbose=False, indicate_current=False):
+def history(
+ config: "Config",
+ rev_range: Optional[str] = None,
+ verbose: bool = False,
+ indicate_current: bool = False,
+) -> None:
"""List changeset scripts in chronological order.
:param config: a :class:`.Config` instance.
:param indicate_current: indicate current revision.
"""
-
+ base: Optional[str]
+ head: Optional[str]
script = ScriptDirectory.from_config(config)
if rev_range is not None:
if ":" not in rev_range:
)
-def current(config, verbose=False):
+def current(config: "Config", verbose: bool = False) -> None:
"""Display the current revision for a database.
:param config: a :class:`.Config` instance.
script.run_env()
-def stamp(config, revision, sql=False, tag=None, purge=False):
+def stamp(
+ config: "Config",
+ revision: str,
+ sql: bool = False,
+ tag: Optional[str] = None,
+ purge: bool = False,
+) -> None:
"""'stamp' the revision table with the given revision; don't
run any migrations.
script.run_env()
-def edit(config, rev):
+def edit(config: "Config", rev: str) -> None:
"""Edit revision script(s) using $EDITOR.
:param config: a :class:`.Config` instance.
from argparse import ArgumentParser
+from argparse import Namespace
from configparser import ConfigParser
import inspect
import os
import sys
+from typing import Dict
+from typing import Optional
+from typing import overload
+from typing import TextIO
from . import __version__
from . import command
def __init__(
self,
- file_=None,
- ini_section="alembic",
- output_buffer=None,
- stdout=sys.stdout,
- cmd_opts=None,
- config_args=util.immutabledict(),
- attributes=None,
- ):
+ file_: Optional[str] = None,
+ ini_section: str = "alembic",
+ output_buffer: Optional[TextIO] = None,
+ stdout: TextIO = sys.stdout,
+ cmd_opts: Optional[Namespace] = None,
+ config_args: util.immutabledict = util.immutabledict(),
+    attributes: Optional[dict] = None,
+ ) -> None:
"""Construct a new :class:`.Config`"""
self.config_file_name = file_
self.config_ini_section = ini_section
if attributes:
self.attributes.update(attributes)
- cmd_opts = None
+ cmd_opts: Optional[Namespace] = None
"""The command-line options passed to the ``alembic`` script.
Within an ``env.py`` script this can be accessed via the
"""
- config_file_name = None
+ config_file_name: Optional[str] = None
"""Filesystem path to the .ini file in use."""
- config_ini_section = None
+ config_ini_section: str = None # type:ignore[assignment]
"""Name of the config file section to read basic configuration
from. Defaults to ``alembic``, that is the ``[alembic]`` section
of the .ini file. This value is modified using the ``-n/--name``
"""
return {}
- def print_stdout(self, text, *arg):
+ def print_stdout(self, text: str, *arg) -> None:
"""Render a message to standard out.
When :meth:`.Config.print_stdout` is called with additional args
file_config.add_section(self.config_ini_section)
return file_config
- def get_template_directory(self):
+ def get_template_directory(self) -> str:
"""Return the directory where Alembic setup templates are found.
This method is used by the alembic ``init`` and ``list_templates``
package_dir = os.path.abspath(os.path.dirname(alembic.__file__))
return os.path.join(package_dir, "templates")
- def get_section(self, name, default=None):
+ @overload
+ def get_section(
+ self, name: str, default: Dict[str, str]
+ ) -> Dict[str, str]:
+ ...
+
+ @overload
+ def get_section(
+ self, name: str, default: Optional[Dict[str, str]] = ...
+ ) -> Optional[Dict[str, str]]:
+ ...
+
+ def get_section(self, name: str, default=None):
"""Return all the configuration options from a given .ini file section
as a dictionary.
return dict(self.file_config.items(name))
- def set_main_option(self, name, value):
+ def set_main_option(self, name: str, value: str) -> None:
"""Set an option programmatically within the 'main' section.
This overrides whatever was in the .ini file.
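+
+        e.g. (the URL is illustrative)::
+
+            config.set_main_option(
+                "sqlalchemy.url", "postgresql://scott:tiger@localhost/test"
+            )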
"""
self.set_section_option(self.config_ini_section, name, value)
- def remove_main_option(self, name):
+ def remove_main_option(self, name: str) -> None:
self.file_config.remove_option(self.config_ini_section, name)
- def set_section_option(self, section, name, value):
+ def set_section_option(self, section: str, name: str, value: str) -> None:
"""Set an option programmatically within the given section.
The section is created if it doesn't exist already.
self.file_config.add_section(section)
self.file_config.set(section, name, value)
- def get_section_option(self, section, name, default=None):
+ def get_section_option(
+ self, section: str, name: str, default: Optional[str] = None
+ ) -> Optional[str]:
"""Return an option from the given section of the .ini file."""
if not self.file_config.has_section(section):
raise util.CommandError(
else:
return default
+ @overload
+ def get_main_option(self, name: str, default: str) -> str:
+ ...
+
+ @overload
+ def get_main_option(
+ self, name: str, default: Optional[str] = None
+ ) -> Optional[str]:
+ ...
+
def get_main_option(self, name, default=None):
"""Return an option from the 'main' section of the .ini file.
class CommandLine:
- def __init__(self, prog=None):
+ def __init__(self, prog: Optional[str] = None) -> None:
self._generate_args(prog)
- def _generate_args(self, prog):
+ def _generate_args(self, prog: Optional[str]) -> None:
def add_options(fn, parser, positional, kwargs):
kwargs_opts = {
"template": (
else:
help_text.append(line.strip())
else:
- help_text = ""
+ help_text = []
subparser = subparsers.add_parser(
fn.__name__, help=" ".join(help_text)
)
subparser.set_defaults(cmd=(fn, positional, kwarg))
self.parser = parser
- def run_cmd(self, config, options):
+ def run_cmd(self, config: Config, options: Namespace) -> None:
fn, positional, kwarg = options.cmd
try:
import functools
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import Union
from sqlalchemy import exc
from sqlalchemy import Integer
from ..util.sqla_compat import _is_type_bound # noqa
from ..util.sqla_compat import _table_for_constraint # noqa
+if TYPE_CHECKING:
+    from sqlalchemy.sql.compiler import Compiled
+    from sqlalchemy.sql.compiler import DDLCompiler
+    from sqlalchemy.sql.elements import quoted_name
+    from sqlalchemy.sql.elements import TextClause
+    from sqlalchemy.sql.functions import Function
+    from sqlalchemy.sql.schema import Column
+    from sqlalchemy.sql.schema import FetchedValue
+ from sqlalchemy.sql.type_api import TypeEngine
+
+ from .impl import DefaultImpl
+ from ..util.sqla_compat import Computed
+ from ..util.sqla_compat import Identity
+
+_ServerDefault = Union["TextClause", "FetchedValue", "Function", str]
+
class AlterTable(DDLElement):
"""
- def __init__(self, table_name, schema=None):
+ def __init__(
+ self,
+ table_name: str,
+ schema: Optional[Union["quoted_name", str]] = None,
+ ) -> None:
self.table_name = table_name
self.schema = schema
class RenameTable(AlterTable):
- def __init__(self, old_table_name, new_table_name, schema=None):
+ def __init__(
+ self,
+ old_table_name: str,
+ new_table_name: Union["quoted_name", str],
+ schema: Optional[Union["quoted_name", str]] = None,
+ ) -> None:
super(RenameTable, self).__init__(old_table_name, schema=schema)
self.new_table_name = new_table_name
class AlterColumn(AlterTable):
def __init__(
self,
- name,
- column_name,
- schema=None,
- existing_type=None,
- existing_nullable=None,
- existing_server_default=None,
- existing_comment=None,
- ):
+ name: str,
+ column_name: str,
+ schema: Optional[str] = None,
+ existing_type: Optional["TypeEngine"] = None,
+ existing_nullable: Optional[bool] = None,
+ existing_server_default: Optional[_ServerDefault] = None,
+ existing_comment: Optional[str] = None,
+ ) -> None:
super(AlterColumn, self).__init__(name, schema=schema)
self.column_name = column_name
self.existing_type = (
class ColumnNullable(AlterColumn):
- def __init__(self, name, column_name, nullable, **kw):
+ def __init__(
+ self, name: str, column_name: str, nullable: bool, **kw
+ ) -> None:
super(ColumnNullable, self).__init__(name, column_name, **kw)
self.nullable = nullable
class ColumnType(AlterColumn):
- def __init__(self, name, column_name, type_, **kw):
+ def __init__(
+ self, name: str, column_name: str, type_: "TypeEngine", **kw
+ ) -> None:
super(ColumnType, self).__init__(name, column_name, **kw)
self.type_ = sqltypes.to_instance(type_)
class ColumnName(AlterColumn):
- def __init__(self, name, column_name, newname, **kw):
+ def __init__(
+ self, name: str, column_name: str, newname: str, **kw
+ ) -> None:
super(ColumnName, self).__init__(name, column_name, **kw)
self.newname = newname
class ColumnDefault(AlterColumn):
- def __init__(self, name, column_name, default, **kw):
+ def __init__(
+ self,
+ name: str,
+ column_name: str,
+ default: Optional[_ServerDefault],
+ **kw
+ ) -> None:
super(ColumnDefault, self).__init__(name, column_name, **kw)
self.default = default
class ComputedColumnDefault(AlterColumn):
- def __init__(self, name, column_name, default, **kw):
+ def __init__(
+ self, name: str, column_name: str, default: Optional["Computed"], **kw
+ ) -> None:
super(ComputedColumnDefault, self).__init__(name, column_name, **kw)
self.default = default
class IdentityColumnDefault(AlterColumn):
- def __init__(self, name, column_name, default, impl, **kw):
+ def __init__(
+ self,
+ name: str,
+ column_name: str,
+ default: Optional["Identity"],
+ impl: "DefaultImpl",
+ **kw
+ ) -> None:
super(IdentityColumnDefault, self).__init__(name, column_name, **kw)
self.default = default
self.impl = impl
class AddColumn(AlterTable):
- def __init__(self, name, column, schema=None):
+ def __init__(
+ self,
+ name: str,
+ column: "Column",
+ schema: Optional[Union["quoted_name", str]] = None,
+ ) -> None:
super(AddColumn, self).__init__(name, schema=schema)
self.column = column
class DropColumn(AlterTable):
- def __init__(self, name, column, schema=None):
+ def __init__(
+ self, name: str, column: "Column", schema: Optional[str] = None
+ ) -> None:
super(DropColumn, self).__init__(name, schema=schema)
self.column = column
class ColumnComment(AlterColumn):
- def __init__(self, name, column_name, comment, **kw):
+ def __init__(
+ self, name: str, column_name: str, comment: Optional[str], **kw
+ ) -> None:
super(ColumnComment, self).__init__(name, column_name, **kw)
self.comment = comment
@compiles(RenameTable)
-def visit_rename_table(element, compiler, **kw):
+def visit_rename_table(
+ element: "RenameTable", compiler: "DDLCompiler", **kw
+) -> str:
return "%s RENAME TO %s" % (
alter_table(compiler, element.table_name, element.schema),
format_table_name(compiler, element.new_table_name, element.schema),
@compiles(AddColumn)
-def visit_add_column(element, compiler, **kw):
+def visit_add_column(
+ element: "AddColumn", compiler: "DDLCompiler", **kw
+) -> str:
return "%s %s" % (
alter_table(compiler, element.table_name, element.schema),
add_column(compiler, element.column, **kw),
@compiles(DropColumn)
-def visit_drop_column(element, compiler, **kw):
+def visit_drop_column(
+ element: "DropColumn", compiler: "DDLCompiler", **kw
+) -> str:
return "%s %s" % (
alter_table(compiler, element.table_name, element.schema),
drop_column(compiler, element.column.name, **kw),
@compiles(ColumnNullable)
-def visit_column_nullable(element, compiler, **kw):
+def visit_column_nullable(
+ element: "ColumnNullable", compiler: "DDLCompiler", **kw
+) -> str:
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
@compiles(ColumnType)
-def visit_column_type(element, compiler, **kw):
+def visit_column_type(
+ element: "ColumnType", compiler: "DDLCompiler", **kw
+) -> str:
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
@compiles(ColumnName)
-def visit_column_name(element, compiler, **kw):
+def visit_column_name(
+ element: "ColumnName", compiler: "DDLCompiler", **kw
+) -> str:
return "%s RENAME %s TO %s" % (
alter_table(compiler, element.table_name, element.schema),
format_column_name(compiler, element.column_name),
@compiles(ColumnDefault)
-def visit_column_default(element, compiler, **kw):
+def visit_column_default(
+ element: "ColumnDefault", compiler: "DDLCompiler", **kw
+) -> str:
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
@compiles(ComputedColumnDefault)
-def visit_computed_column(element, compiler, **kw):
+def visit_computed_column(
+ element: "ComputedColumnDefault", compiler: "DDLCompiler", **kw
+):
raise exc.CompileError(
'Adding or removing a "computed" construct, e.g. GENERATED '
"ALWAYS AS, to or from an existing column is not supported."
@compiles(IdentityColumnDefault)
-def visit_identity_column(element, compiler, **kw):
+def visit_identity_column(
+ element: "IdentityColumnDefault", compiler: "DDLCompiler", **kw
+):
raise exc.CompileError(
'Adding, removing or modifying an "identity" construct, '
"e.g. GENERATED AS IDENTITY, to or from an existing "
)
-def quote_dotted(name, quote):
+def quote_dotted(
+ name: Union["quoted_name", str], quote: functools.partial
+) -> Union["quoted_name", str]:
"""quote the elements of a dotted name"""
if isinstance(name, quoted_name):
return result
-def format_table_name(compiler, name, schema):
+def format_table_name(
+ compiler: "Compiled",
+ name: Union["quoted_name", str],
+ schema: Optional[Union["quoted_name", str]],
+) -> Union["quoted_name", str]:
quote = functools.partial(compiler.preparer.quote)
if schema:
return quote_dotted(schema, quote) + "." + quote(name)
return quote(name)
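# Behavior sketch: quote_dotted() quotes each dot-separated identifier on its
# own, so a dotted schema such as "my.schema" renders as my.schema.tbl with
# every element individually quoted as the dialect requires, rather than as a
# single quoted string containing a dot.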
-def format_column_name(compiler, name):
+def format_column_name(
+ compiler: "DDLCompiler", name: Optional[Union["quoted_name", str]]
+) -> Union["quoted_name", str]:
return compiler.preparer.quote(name)
-def format_server_default(compiler, default):
+def format_server_default(
+ compiler: "DDLCompiler",
+ default: Optional[_ServerDefault],
+) -> str:
return compiler.get_column_default_string(
Column("x", Integer, server_default=default)
)
-def format_type(compiler, type_):
+def format_type(compiler: "DDLCompiler", type_: "TypeEngine") -> str:
return compiler.dialect.type_compiler.process(type_)
-def alter_table(compiler, name, schema):
+def alter_table(
+ compiler: "DDLCompiler",
+ name: str,
+ schema: Optional[str],
+) -> str:
return "ALTER TABLE %s" % format_table_name(compiler, name, schema)
-def drop_column(compiler, name):
+def drop_column(compiler: "DDLCompiler", name: str, **kw) -> str:
return "DROP COLUMN %s" % format_column_name(compiler, name)
-def alter_column(compiler, name):
+def alter_column(compiler: "DDLCompiler", name: str) -> str:
return "ALTER COLUMN %s" % format_column_name(compiler, name)
-def add_column(compiler, column, **kw):
+def add_column(compiler: "DDLCompiler", column: "Column", **kw) -> str:
text = "ADD COLUMN %s" % compiler.get_column_specification(column, **kw)
const = " ".join(
from collections import namedtuple
import re
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Set
+from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
+from typing import Union
from sqlalchemy import cast
from sqlalchemy import schema
from ..util.compat import string_types
from ..util.compat import text_type
+if TYPE_CHECKING:
+ from io import StringIO
+ from typing import Literal
+
+ from sqlalchemy.engine import Connection
+ from sqlalchemy.engine import Dialect
+ from sqlalchemy.engine.cursor import CursorResult
+ from sqlalchemy.engine.cursor import LegacyCursorResult
+ from sqlalchemy.engine.reflection import Inspector
+ from sqlalchemy.sql.dml import Update
+ from sqlalchemy.sql.elements import ClauseElement
+ from sqlalchemy.sql.elements import ColumnElement
+ from sqlalchemy.sql.elements import quoted_name
+ from sqlalchemy.sql.elements import TextClause
+ from sqlalchemy.sql.schema import Column
+ from sqlalchemy.sql.schema import Constraint
+ from sqlalchemy.sql.schema import ForeignKeyConstraint
+ from sqlalchemy.sql.schema import Index
+ from sqlalchemy.sql.schema import Table
+ from sqlalchemy.sql.schema import UniqueConstraint
+ from sqlalchemy.sql.selectable import TableClause
+ from sqlalchemy.sql.type_api import TypeEngine
+
+ from .base import _ServerDefault
+ from ..autogenerate.api import AutogenContext
+ from ..operations.batch import ApplyBatchImpl
+ from ..operations.batch import BatchOperationsImpl
+
class ImplMeta(type):
- def __init__(cls, classname, bases, dict_):
+ def __init__(
+ cls,
+ classname: str,
+ bases: Tuple[Type["DefaultImpl"], ...],
+ dict_: Dict[str, Any],
+ ):
newtype = type.__init__(cls, classname, bases, dict_)
if "__dialect__" in dict_:
_impls[dict_["__dialect__"]] = cls
return newtype
-_impls = {}
+_impls: dict = {}
Params = namedtuple("Params", ["token0", "tokens", "args", "kwargs"])
transactional_ddl = False
command_terminator = ";"
- type_synonyms = ({"NUMERIC", "DECIMAL"},)
- type_arg_extract = ()
+ type_synonyms: Tuple[Set[str], ...] = ({"NUMERIC", "DECIMAL"},)
+ type_arg_extract: Sequence[str] = ()
# on_null is known to be supported only by oracle
- identity_attrs_ignore = ("on_null",)
+ identity_attrs_ignore: Tuple[str, ...] = ("on_null",)
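# How these knobs are consumed (sketch): compare_type() treats members of the
# same synonym set as equivalent, so a reflected NUMERIC compared against a
# metadata DECIMAL yields no autogenerate diff.  Dialect impls extend the
# defaults, e.g. the MySQL impl further below adds:
#
#     type_synonyms = DefaultImpl.type_synonyms + ({"BOOL", "TINYINT"},)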
def __init__(
self,
- dialect,
- connection,
- as_sql,
- transactional_ddl,
- output_buffer,
- context_opts,
- ):
+ dialect: "Dialect",
+ connection: Optional["Connection"],
+ as_sql: bool,
+ transactional_ddl: Optional[bool],
+ output_buffer: Optional["StringIO"],
+ context_opts: Dict[str, Any],
+ ) -> None:
self.dialect = dialect
self.connection = connection
self.as_sql = as_sql
self.literal_binds = context_opts.get("literal_binds", False)
self.output_buffer = output_buffer
- self.memo = {}
+ self.memo: dict = {}
self.context_opts = context_opts
if transactional_ddl is not None:
self.transactional_ddl = transactional_ddl
)
@classmethod
- def get_by_dialect(cls, dialect):
+ def get_by_dialect(cls, dialect: "Dialect") -> Any:
return _impls[dialect.name]
- def static_output(self, text):
+ def static_output(self, text: str) -> None:
+ assert self.output_buffer is not None
self.output_buffer.write(text_type(text + "\n\n"))
self.output_buffer.flush()
- def requires_recreate_in_batch(self, batch_op):
+ def requires_recreate_in_batch(
+ self, batch_op: "BatchOperationsImpl"
+ ) -> bool:
"""Return True if the given :class:`.BatchOperationsImpl`
would need the table to be recreated and copied in order to
proceed.
"""
return False
- def prep_table_for_batch(self, batch_impl, table):
+ def prep_table_for_batch(
+ self, batch_impl: "ApplyBatchImpl", table: "Table"
+ ) -> None:
"""perform any operations needed on a table before a new
one is created to replace it in batch mode.
"""
@property
- def bind(self):
+ def bind(self) -> Optional["Connection"]:
return self.connection
def _exec(
self,
- construct,
- execution_options=None,
- multiparams=(),
- params=util.immutabledict(),
- ):
+ construct: Union["ClauseElement", str],
+ execution_options: Optional[dict] = None,
+ multiparams: Sequence[dict] = (),
+ params: Dict[str, Any] = util.immutabledict(),
+ ) -> Optional[Union["LegacyCursorResult", "CursorResult"]]:
if isinstance(construct, string_types):
construct = text(construct)
if self.as_sql:
.strip()
+ self.command_terminator
)
+ return None
else:
conn = self.connection
+ assert conn is not None
if execution_options:
conn = conn.execution_options(**execution_options)
if params:
+ assert isinstance(multiparams, tuple)
multiparams += (params,)
return conn.execute(construct, multiparams)
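# Dispatch sketch: plain strings are wrapped in text(); in --sql (offline)
# mode the construct is rendered and written to the output buffer, while
# online it executes on the live connection.
#
#     impl._exec("ALTER TABLE t ADD COLUMN x INTEGER")  # str -> text()
#     impl._exec(base.AddColumn("t", col))              # any ClauseElement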
- def execute(self, sql, execution_options=None):
+ def execute(
+ self,
+ sql: Union["Update", "TextClause", str],
+ execution_options: Optional[dict] = None,
+ ) -> None:
self._exec(sql, execution_options)
def alter_column(
self,
- table_name,
- column_name,
- nullable=None,
- server_default=False,
- name=None,
- type_=None,
- schema=None,
- autoincrement=None,
- comment=False,
- existing_comment=None,
- existing_type=None,
- existing_server_default=None,
- existing_nullable=None,
- existing_autoincrement=None,
- ):
+ table_name: str,
+ column_name: str,
+ nullable: Optional[bool] = None,
+ server_default: Union["_ServerDefault", "Literal[False]"] = False,
+ name: Optional[str] = None,
+ type_: Optional["TypeEngine"] = None,
+ schema: Optional[str] = None,
+ autoincrement: Optional[bool] = None,
+ comment: Optional[Union[str, "Literal[False]"]] = False,
+ existing_comment: Optional[str] = None,
+ existing_type: Optional["TypeEngine"] = None,
+ existing_server_default: Optional["_ServerDefault"] = None,
+ existing_nullable: Optional[bool] = None,
+ existing_autoincrement: Optional[bool] = None,
+ **kw: Any
+ ) -> None:
if autoincrement is not None or existing_autoincrement is not None:
util.warn(
"autoincrement and existing_autoincrement "
)
if server_default is not False:
kw = {}
+ cls_: Type[
+ Union[
+ base.ComputedColumnDefault,
+ base.IdentityColumnDefault,
+ base.ColumnDefault,
+ ]
+ ]
if sqla_compat._server_default_is_computed(
server_default, existing_server_default
):
cls_(
table_name,
column_name,
- server_default,
+ server_default, # type:ignore[arg-type]
schema=schema,
existing_type=existing_type,
existing_server_default=existing_server_default,
)
)
- def add_column(self, table_name, column, schema=None):
+ def add_column(
+ self,
+ table_name: str,
+ column: "Column",
+ schema: Optional[Union[str, "quoted_name"]] = None,
+ ) -> None:
self._exec(base.AddColumn(table_name, column, schema=schema))
- def drop_column(self, table_name, column, schema=None, **kw):
+ def drop_column(
+ self,
+ table_name: str,
+ column: "Column",
+ schema: Optional[str] = None,
+ **kw
+ ) -> None:
self._exec(base.DropColumn(table_name, column, schema=schema))
- def add_constraint(self, const):
+ def add_constraint(self, const: Any) -> None:
if const._create_rule is None or const._create_rule(self):
self._exec(schema.AddConstraint(const))
- def drop_constraint(self, const):
+ def drop_constraint(self, const: "Constraint") -> None:
self._exec(schema.DropConstraint(const))
- def rename_table(self, old_table_name, new_table_name, schema=None):
+ def rename_table(
+ self,
+ old_table_name: str,
+ new_table_name: Union[str, "quoted_name"],
+ schema: Optional[Union[str, "quoted_name"]] = None,
+ ) -> None:
self._exec(
base.RenameTable(old_table_name, new_table_name, schema=schema)
)
- def create_table(self, table):
+ def create_table(self, table: "Table") -> None:
table.dispatch.before_create(
table, self.connection, checkfirst=False, _ddl_runner=self
)
if comment and with_comment:
self.create_column_comment(column)
- def drop_table(self, table):
+ def drop_table(self, table: "Table") -> None:
self._exec(schema.DropTable(table))
- def create_index(self, index):
+ def create_index(self, index: "Index") -> None:
self._exec(schema.CreateIndex(index))
- def create_table_comment(self, table):
+ def create_table_comment(self, table: "Table") -> None:
self._exec(schema.SetTableComment(table))
- def drop_table_comment(self, table):
+ def drop_table_comment(self, table: "Table") -> None:
self._exec(schema.DropTableComment(table))
- def create_column_comment(self, column):
+ def create_column_comment(self, column: "ColumnElement") -> None:
self._exec(schema.SetColumnComment(column))
- def drop_index(self, index):
+ def drop_index(self, index: "Index") -> None:
self._exec(schema.DropIndex(index))
- def bulk_insert(self, table, rows, multiinsert=True):
+ def bulk_insert(
+ self,
+ table: Union["TableClause", "Table"],
+ rows: List[dict],
+ multiinsert: bool = True,
+ ) -> None:
if not isinstance(rows, list):
raise TypeError("List expected")
elif rows and not isinstance(rows[0], dict):
sqla_compat._insert_inline(table).values(**row)
)
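# Usage sketch mirroring the documented op.bulk_insert() call, which lands
# here; ``accounts_table`` stands in for any Table/TableClause in scope, and
# rows must be a list of dicts keyed by column name:
#
#     op.bulk_insert(
#         accounts_table,
#         [
#             {"id": 1, "name": "John Smith"},
#             {"id": 2, "name": "Ed Williams"},
#         ],
#     )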
- def _tokenize_column_type(self, column):
+ def _tokenize_column_type(self, column: "Column") -> Params:
definition = self.dialect.type_compiler.process(column.type).lower()
# tokenize the SQLAlchemy-generated version of a type, so that
return params
- def _column_types_match(self, inspector_params, metadata_params):
+ def _column_types_match(
+ self, inspector_params: "Params", metadata_params: "Params"
+ ) -> bool:
if inspector_params.token0 == metadata_params.token0:
return True
return True
return False
- def _column_args_match(self, inspected_params, meta_params):
+ def _column_args_match(
+ self, inspected_params: "Params", meta_params: "Params"
+ ) -> bool:
"""We want to compare column parameters. However, we only want
to compare parameters that are set. If they both have `collation`,
we want to make sure they are the same. However, if only one
return True
- def compare_type(self, inspector_column, metadata_column):
+ def compare_type(
+ self, inspector_column: "Column", metadata_column: "Column"
+ ) -> bool:
"""Returns True if there ARE differences between the types of the two
columns. Takes impl.type_synonyms into account between retrospected
and metadata types
def correct_for_autogen_constraints(
self,
- conn_uniques,
- conn_indexes,
- metadata_unique_constraints,
- metadata_indexes,
- ):
+ conn_uniques: Set["UniqueConstraint"],
+ conn_indexes: Set["Index"],
+ metadata_unique_constraints: Set["UniqueConstraint"],
+ metadata_indexes: Set["Index"],
+ ) -> None:
pass
def cast_for_batch_migrate(self, existing, existing_transfer, new_type):
existing_transfer["expr"], new_type
)
- def render_ddl_sql_expr(self, expr, is_server_default=False, **kw):
+ def render_ddl_sql_expr(
+ self, expr: "ClauseElement", is_server_default: bool = False, **kw
+ ) -> str:
"""Render a SQL expression that is typically a server default,
index expression, etc.
)
return text_type(expr.compile(dialect=self.dialect, **compile_kw))
- def _compat_autogen_column_reflect(self, inspector):
+ def _compat_autogen_column_reflect(
+ self, inspector: "Inspector"
+ ) -> Callable:
return self.autogen_column_reflect
- def correct_for_autogen_foreignkeys(self, conn_fks, metadata_fks):
+ def correct_for_autogen_foreignkeys(
+ self,
+ conn_fks: Set["ForeignKeyConstraint"],
+ metadata_fks: Set["ForeignKeyConstraint"],
+ ) -> None:
pass
def autogen_column_reflect(self, inspector, table, column_info):
"""
- def start_migrations(self):
+ def start_migrations(self) -> None:
"""A hook called when :meth:`.EnvironmentContext.run_migrations`
is called.
"""
- def emit_begin(self):
+ def emit_begin(self) -> None:
"""Emit the string ``BEGIN``, or the backend-specific
equivalent, on the current connection context.
"""
self.static_output("BEGIN" + self.command_terminator)
- def emit_commit(self):
+ def emit_commit(self) -> None:
"""Emit the string ``COMMIT``, or the backend-specific
equivalent, on the current connection context.
"""
self.static_output("COMMIT" + self.command_terminator)
- def render_type(self, type_obj, autogen_context):
+ def render_type(
+ self, type_obj: "TypeEngine", autogen_context: "AutogenContext"
+ ) -> Union[str, "Literal[False]"]:
return False
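# Hook contract (sketch): a dialect impl may return a rendered string to take
# over how ``type_obj`` appears in autogenerate output, or False to fall
# through to the default rendering; the Postgresql impl below uses this to
# special-case types from sqlalchemy.dialects.postgresql.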
def _compare_identity_default(self, metadata_identity, inspector_identity):
+from typing import Any
+from typing import List
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import Union
+
from sqlalchemy import types as sqltypes
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.schema import Column
from sqlalchemy.schema import CreateIndex
-from sqlalchemy.sql.expression import ClauseElement
-from sqlalchemy.sql.expression import Executable
+from sqlalchemy.sql.base import Executable
+from sqlalchemy.sql.elements import ClauseElement
from .base import AddColumn
from .base import alter_column
from .. import util
from ..util import sqla_compat
+if TYPE_CHECKING:
+ from typing import Literal
+
+ from sqlalchemy.dialects.mssql.base import MSDDLCompiler
+ from sqlalchemy.dialects.mssql.base import MSSQLCompiler
+ from sqlalchemy.engine.cursor import CursorResult
+ from sqlalchemy.engine.cursor import LegacyCursorResult
+ from sqlalchemy.sql.schema import Index
+ from sqlalchemy.sql.schema import Table
+ from sqlalchemy.sql.selectable import TableClause
+ from sqlalchemy.sql.type_api import TypeEngine
+
+ from .base import _ServerDefault
+
class MSSQLImpl(DefaultImpl):
__dialect__ = "mssql"
"order",
)
- def __init__(self, *arg, **kw):
+ def __init__(self, *arg, **kw) -> None:
super(MSSQLImpl, self).__init__(*arg, **kw)
self.batch_separator = self.context_opts.get(
"mssql_batch_separator", self.batch_separator
)
- def _exec(self, construct, *args, **kw):
+ def _exec(
+ self, construct: Any, *args, **kw
+ ) -> Optional[Union["LegacyCursorResult", "CursorResult"]]:
result = super(MSSQLImpl, self)._exec(construct, *args, **kw)
if self.as_sql and self.batch_separator:
self.static_output(self.batch_separator)
return result
- def emit_begin(self):
+ def emit_begin(self) -> None:
self.static_output("BEGIN TRANSACTION" + self.command_terminator)
- def emit_commit(self):
+ def emit_commit(self) -> None:
super(MSSQLImpl, self).emit_commit()
if self.as_sql and self.batch_separator:
self.static_output(self.batch_separator)
- def alter_column(
+ def alter_column( # type:ignore[override]
self,
- table_name,
- column_name,
- nullable=None,
- server_default=False,
- name=None,
- type_=None,
- schema=None,
- existing_type=None,
- existing_server_default=None,
- existing_nullable=None,
- **kw
- ):
+ table_name: str,
+ column_name: str,
+ nullable: Optional[bool] = None,
+ server_default: Optional[
+ Union["_ServerDefault", "Literal[False]"]
+ ] = False,
+ name: Optional[str] = None,
+ type_: Optional["TypeEngine"] = None,
+ schema: Optional[str] = None,
+ existing_type: Optional["TypeEngine"] = None,
+ existing_server_default: Optional["_ServerDefault"] = None,
+ existing_nullable: Optional[bool] = None,
+ **kw: Any
+ ) -> None:
if nullable is not None:
if existing_type is None:
table_name, column_name, schema=schema, name=name
)
- def create_index(self, index):
+ def create_index(self, index: "Index") -> None:
# this likely defaults to None if not present, so get()
# should normally not return the default value. being
# defensive in any case
mssql_include = index.kwargs.get("mssql_include", None) or ()
+ assert index.table is not None
for col in mssql_include:
if col not in index.table.c:
index.table.append_column(Column(col, sqltypes.NullType))
self._exec(CreateIndex(index))
- def bulk_insert(self, table, rows, **kw):
+ def bulk_insert( # type:ignore[override]
+ self, table: Union["TableClause", "Table"], rows: List[dict], **kw: Any
+ ) -> None:
if self.as_sql:
self._exec(
"SET IDENTITY_INSERT %s ON"
else:
super(MSSQLImpl, self).bulk_insert(table, rows, **kw)
- def drop_column(self, table_name, column, schema=None, **kw):
+ def drop_column(
+ self,
+ table_name: str,
+ column: "Column",
+ schema: Optional[str] = None,
+ **kw
+ ) -> None:
drop_default = kw.pop("mssql_drop_default", False)
if drop_default:
self._exec(
class _ExecDropConstraint(Executable, ClauseElement):
- def __init__(self, tname, colname, type_, schema):
+ def __init__(
+ self,
+ tname: str,
+ colname: Union["Column", str],
+ type_: str,
+ schema: Optional[str],
+ ) -> None:
self.tname = tname
self.colname = colname
self.type_ = type_
class _ExecDropFKConstraint(Executable, ClauseElement):
- def __init__(self, tname, colname, schema):
+ def __init__(
+ self, tname: str, colname: "Column", schema: Optional[str]
+ ) -> None:
self.tname = tname
self.colname = colname
self.schema = schema
@compiles(_ExecDropConstraint, "mssql")
-def _exec_drop_col_constraint(element, compiler, **kw):
+def _exec_drop_col_constraint(
+ element: "_ExecDropConstraint", compiler: "MSSQLCompiler", **kw
+) -> str:
schema, tname, colname, type_ = (
element.schema,
element.tname,
@compiles(_ExecDropFKConstraint, "mssql")
-def _exec_drop_col_fk_constraint(element, compiler, **kw):
+def _exec_drop_col_fk_constraint(
+ element: "_ExecDropFKConstraint", compiler: "MSSQLCompiler", **kw
+) -> str:
schema, tname, colname = element.schema, element.tname, element.colname
return """declare @const_name varchar(256)
@compiles(AddColumn, "mssql")
-def visit_add_column(element, compiler, **kw):
+def visit_add_column(
+ element: "AddColumn", compiler: "MSDDLCompiler", **kw
+) -> str:
return "%s %s" % (
alter_table(compiler, element.table_name, element.schema),
mssql_add_column(compiler, element.column, **kw),
)
-def mssql_add_column(compiler, column, **kw):
+def mssql_add_column(compiler: "MSDDLCompiler", column: "Column", **kw) -> str:
return "ADD %s" % compiler.get_column_specification(column, **kw)
@compiles(ColumnNullable, "mssql")
-def visit_column_nullable(element, compiler, **kw):
+def visit_column_nullable(
+ element: "ColumnNullable", compiler: "MSDDLCompiler", **kw
+) -> str:
return "%s %s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
@compiles(ColumnDefault, "mssql")
-def visit_column_default(element, compiler, **kw):
+def visit_column_default(
+ element: "ColumnDefault", compiler: "MSDDLCompiler", **kw
+) -> str:
# TODO: there can also be a named constraint
# with ADD CONSTRAINT here
return "%s ADD DEFAULT %s FOR %s" % (
@compiles(ColumnName, "mssql")
-def visit_rename_column(element, compiler, **kw):
+def visit_rename_column(
+ element: "ColumnName", compiler: "MSDDLCompiler", **kw
+) -> str:
return "EXEC sp_rename '%s.%s', %s, 'COLUMN'" % (
format_table_name(compiler, element.table_name, element.schema),
format_column_name(compiler, element.column_name),
@compiles(ColumnType, "mssql")
-def visit_column_type(element, compiler, **kw):
+def visit_column_type(
+ element: "ColumnType", compiler: "MSDDLCompiler", **kw
+) -> str:
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
@compiles(RenameTable, "mssql")
-def visit_rename_table(element, compiler, **kw):
+def visit_rename_table(
+ element: "RenameTable", compiler: "MSDDLCompiler", **kw
+) -> str:
return "EXEC sp_rename '%s', %s" % (
format_table_name(compiler, element.table_name, element.schema),
format_table_name(compiler, element.new_table_name, None),
import re
+from typing import Any
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import Union
from sqlalchemy import schema
from sqlalchemy import types as sqltypes
from ..util.sqla_compat import _is_mariadb
from ..util.sqla_compat import _is_type_bound
+if TYPE_CHECKING:
+ from typing import Literal
+
+ from sqlalchemy.dialects.mysql.base import MySQLDDLCompiler
+ from sqlalchemy.sql.ddl import DropConstraint
+ from sqlalchemy.sql.schema import Constraint
+ from sqlalchemy.sql.type_api import TypeEngine
+
+ from .base import _ServerDefault
+
class MySQLImpl(DefaultImpl):
__dialect__ = "mysql"
type_synonyms = DefaultImpl.type_synonyms + ({"BOOL", "TINYINT"},)
type_arg_extract = [r"character set ([\w\-_]+)", r"collate ([\w\-_]+)"]
- def alter_column(
+ def alter_column( # type:ignore[override]
self,
- table_name,
- column_name,
- nullable=None,
- server_default=False,
- name=None,
- type_=None,
- schema=None,
- existing_type=None,
- existing_server_default=None,
- existing_nullable=None,
- autoincrement=None,
- existing_autoincrement=None,
- comment=False,
- existing_comment=None,
- **kw
- ):
+ table_name: str,
+ column_name: str,
+ nullable: Optional[bool] = None,
+ server_default: Union["_ServerDefault", "Literal[False]"] = False,
+ name: Optional[str] = None,
+ type_: Optional["TypeEngine"] = None,
+ schema: Optional[str] = None,
+ existing_type: Optional["TypeEngine"] = None,
+ existing_server_default: Optional["_ServerDefault"] = None,
+ existing_nullable: Optional[bool] = None,
+ autoincrement: Optional[bool] = None,
+ existing_autoincrement: Optional[bool] = None,
+ comment: Optional[Union[str, "Literal[False]"]] = False,
+ existing_comment: Optional[str] = None,
+ **kw: Any
+ ) -> None:
if sqla_compat._server_default_is_identity(
server_default, existing_server_default
) or sqla_compat._server_default_is_computed(
)
)
- def drop_constraint(self, const):
+ def drop_constraint(
+ self,
+ const: "Constraint",
+ ) -> None:
if isinstance(const, schema.CheckConstraint) and _is_type_bound(const):
return
super(MySQLImpl, self).drop_constraint(const)
- def _is_mysql_allowed_functional_default(self, type_, server_default):
+ def _is_mysql_allowed_functional_default(
+ self,
+ type_: Optional["TypeEngine"],
+ server_default: Union["_ServerDefault", "Literal[False]"],
+ ) -> bool:
return (
type_ is not None
- and type_._type_affinity is sqltypes.DateTime
+ and type_._type_affinity # type:ignore[attr-defined]
+ is sqltypes.DateTime
and server_default is not None
)
class MySQLAlterDefault(AlterColumn):
- def __init__(self, name, column_name, default, schema=None):
+ def __init__(
+ self,
+ name: str,
+ column_name: str,
+ default: "_ServerDefault",
+ schema: Optional[str] = None,
+ ) -> None:
super(AlterColumn, self).__init__(name, schema=schema)
self.column_name = column_name
self.default = default
class MySQLChangeColumn(AlterColumn):
def __init__(
self,
- name,
- column_name,
- schema=None,
- newname=None,
- type_=None,
- nullable=None,
- default=False,
- autoincrement=None,
- comment=False,
- ):
+ name: str,
+ column_name: str,
+ schema: Optional[str] = None,
+ newname: Optional[str] = None,
+ type_: Optional["TypeEngine"] = None,
+ nullable: Optional[bool] = None,
+ default: Optional[Union["_ServerDefault", "Literal[False]"]] = False,
+ autoincrement: Optional[bool] = None,
+ comment: Optional[Union[str, "Literal[False]"]] = False,
+ ) -> None:
super(AlterColumn, self).__init__(name, schema=schema)
self.column_name = column_name
self.nullable = nullable
@compiles(MySQLAlterDefault, "mysql", "mariadb")
-def _mysql_alter_default(element, compiler, **kw):
+def _mysql_alter_default(
+ element: "MySQLAlterDefault", compiler: "MySQLDDLCompiler", **kw
+) -> str:
return "%s ALTER COLUMN %s %s" % (
alter_table(compiler, element.table_name, element.schema),
format_column_name(compiler, element.column_name),
@compiles(MySQLModifyColumn, "mysql", "mariadb")
-def _mysql_modify_column(element, compiler, **kw):
+def _mysql_modify_column(
+ element: "MySQLModifyColumn", compiler: "MySQLDDLCompiler", **kw
+) -> str:
return "%s MODIFY %s %s" % (
alter_table(compiler, element.table_name, element.schema),
format_column_name(compiler, element.column_name),
@compiles(MySQLChangeColumn, "mysql", "mariadb")
-def _mysql_change_column(element, compiler, **kw):
+def _mysql_change_column(
+ element: "MySQLChangeColumn", compiler: "MySQLDDLCompiler", **kw
+) -> str:
return "%s CHANGE %s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
format_column_name(compiler, element.column_name),
def _mysql_colspec(
- compiler, nullable, server_default, type_, autoincrement, comment
-):
+ compiler: "MySQLDDLCompiler",
+ nullable: Optional[bool],
+ server_default: Optional[Union["_ServerDefault", "Literal[False]"]],
+ type_: "TypeEngine",
+ autoincrement: Optional[bool],
+ comment: Optional[Union[str, "Literal[False]"]],
+) -> str:
spec = "%s %s" % (
compiler.dialect.type_compiler.process(type_),
"NULL" if nullable else "NOT NULL",
@compiles(schema.DropConstraint, "mysql", "mariadb")
-def _mysql_drop_constraint(element, compiler, **kw):
+def _mysql_drop_constraint(
+ element: "DropConstraint", compiler: "MySQLDDLCompiler", **kw
+) -> str:
"""Redefine SQLAlchemy's drop constraint to
raise errors for invalid constraint type."""
schema.UniqueConstraint,
),
):
- return compiler.visit_drop_constraint(element, **kw)
+ assert not kw
+ return compiler.visit_drop_constraint(element)
elif isinstance(constraint, schema.CheckConstraint):
# note that SQLAlchemy as of 1.2 does not yet support
# DROP CONSTRAINT for MySQL/MariaDB, so we implement fully
+from typing import Any
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import Union
+
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql import sqltypes
from .base import RenameTable
from .impl import DefaultImpl
+if TYPE_CHECKING:
+ from sqlalchemy.dialects.oracle.base import OracleDDLCompiler
+ from sqlalchemy.engine.cursor import CursorResult
+ from sqlalchemy.engine.cursor import LegacyCursorResult
+ from sqlalchemy.sql.schema import Column
+
class OracleImpl(DefaultImpl):
__dialect__ = "oracle"
)
identity_attrs_ignore = ()
- def __init__(self, *arg, **kw):
+ def __init__(self, *arg, **kw) -> None:
super(OracleImpl, self).__init__(*arg, **kw)
self.batch_separator = self.context_opts.get(
"oracle_batch_separator", self.batch_separator
)
- def _exec(self, construct, *args, **kw):
+ def _exec(
+ self, construct: Any, *args, **kw
+ ) -> Optional[Union["LegacyCursorResult", "CursorResult"]]:
result = super(OracleImpl, self)._exec(construct, *args, **kw)
if self.as_sql and self.batch_separator:
self.static_output(self.batch_separator)
return result
- def emit_begin(self):
+ def emit_begin(self) -> None:
self._exec("SET TRANSACTION READ WRITE")
- def emit_commit(self):
+ def emit_commit(self) -> None:
self._exec("COMMIT")
@compiles(AddColumn, "oracle")
-def visit_add_column(element, compiler, **kw):
+def visit_add_column(
+ element: "AddColumn", compiler: "OracleDDLCompiler", **kw
+) -> str:
return "%s %s" % (
alter_table(compiler, element.table_name, element.schema),
add_column(compiler, element.column, **kw),
@compiles(ColumnNullable, "oracle")
-def visit_column_nullable(element, compiler, **kw):
+def visit_column_nullable(
+ element: "ColumnNullable", compiler: "OracleDDLCompiler", **kw
+) -> str:
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
@compiles(ColumnType, "oracle")
-def visit_column_type(element, compiler, **kw):
+def visit_column_type(
+ element: "ColumnType", compiler: "OracleDDLCompiler", **kw
+) -> str:
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
@compiles(ColumnName, "oracle")
-def visit_column_name(element, compiler, **kw):
+def visit_column_name(
+ element: "ColumnName", compiler: "OracleDDLCompiler", **kw
+) -> str:
return "%s RENAME COLUMN %s TO %s" % (
alter_table(compiler, element.table_name, element.schema),
format_column_name(compiler, element.column_name),
@compiles(ColumnDefault, "oracle")
-def visit_column_default(element, compiler, **kw):
+def visit_column_default(
+ element: "ColumnDefault", compiler: "OracleDDLCompiler", **kw
+) -> str:
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
@compiles(ColumnComment, "oracle")
-def visit_column_comment(element, compiler, **kw):
+def visit_column_comment(
+ element: "ColumnComment", compiler: "OracleDDLCompiler", **kw
+) -> str:
ddl = "COMMENT ON COLUMN {table_name}.{column_name} IS {comment}"
comment = compiler.sql_compiler.render_literal_value(
@compiles(RenameTable, "oracle")
-def visit_rename_table(element, compiler, **kw):
+def visit_rename_table(
+ element: "RenameTable", compiler: "OracleDDLCompiler", **kw
+) -> str:
return "%s RENAME TO %s" % (
alter_table(compiler, element.table_name, element.schema),
format_table_name(compiler, element.new_table_name, None),
)
-def alter_column(compiler, name):
+def alter_column(compiler: "OracleDDLCompiler", name: str) -> str:
return "MODIFY %s" % format_column_name(compiler, name)
-def add_column(compiler, column, **kw):
+def add_column(compiler: "OracleDDLCompiler", column: "Column", **kw) -> str:
return "ADD %s" % compiler.get_column_specification(column, **kw)
@compiles(IdentityColumnDefault, "oracle")
-def visit_identity_column(element, compiler, **kw):
+def visit_identity_column(
+ element: "IdentityColumnDefault", compiler: "OracleDDLCompiler", **kw
+):
text = "%s %s " % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
import logging
import re
+from typing import Any
+from typing import cast
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import Union
from sqlalchemy import Column
from sqlalchemy import Numeric
from sqlalchemy.dialects.postgresql import BIGINT
from sqlalchemy.dialects.postgresql import ExcludeConstraint
from sqlalchemy.dialects.postgresql import INTEGER
-from sqlalchemy.sql.expression import ColumnClause
-from sqlalchemy.sql.expression import UnaryExpression
+from sqlalchemy.sql.elements import ColumnClause
+from sqlalchemy.sql.elements import UnaryExpression
from sqlalchemy.types import NULLTYPE
from .base import alter_column
from ..util import compat
from ..util import sqla_compat
+if TYPE_CHECKING:
+ from typing import Literal
+
+ from sqlalchemy.dialects.postgresql.array import ARRAY
+ from sqlalchemy.dialects.postgresql.base import PGDDLCompiler
+ from sqlalchemy.dialects.postgresql.hstore import HSTORE
+ from sqlalchemy.dialects.postgresql.json import JSON
+ from sqlalchemy.dialects.postgresql.json import JSONB
+ from sqlalchemy.sql.elements import BinaryExpression
+ from sqlalchemy.sql.elements import quoted_name
+ from sqlalchemy.sql.schema import MetaData
+ from sqlalchemy.sql.schema import Table
+ from sqlalchemy.sql.type_api import TypeEngine
+
+ from .base import _ServerDefault
+ from ..autogenerate.api import AutogenContext
+ from ..autogenerate.render import _f_name
+ from ..runtime.migration import MigrationContext
+
log = logging.getLogger(__name__)
)
)
- def alter_column(
+ def alter_column( # type:ignore[override]
self,
- table_name,
- column_name,
- nullable=None,
- server_default=False,
- name=None,
- type_=None,
- schema=None,
- autoincrement=None,
- existing_type=None,
- existing_server_default=None,
- existing_nullable=None,
- existing_autoincrement=None,
- **kw
- ):
+ table_name: str,
+ column_name: str,
+ nullable: Optional[bool] = None,
+ server_default: Union["_ServerDefault", "Literal[False]"] = False,
+ name: Optional[str] = None,
+ type_: Optional["TypeEngine"] = None,
+ schema: Optional[str] = None,
+ autoincrement: Optional[bool] = None,
+ existing_type: Optional["TypeEngine"] = None,
+ existing_server_default: Optional["_ServerDefault"] = None,
+ existing_nullable: Optional[bool] = None,
+ existing_autoincrement: Optional[bool] = None,
+ **kw: Any
+ ) -> None:
using = kw.pop("postgresql_using", None)
)
metadata_indexes.discard(idx)
- def render_type(self, type_, autogen_context):
+ def render_type(
+ self, type_: "TypeEngine", autogen_context: "AutogenContext"
+ ) -> Union[str, "Literal[False]"]:
mod = type(type_).__module__
if not mod.startswith("sqlalchemy.dialects.postgresql"):
return False
return False
- def _render_HSTORE_type(self, type_, autogen_context):
- return render._render_type_w_subtype(
- type_, autogen_context, "text_type", r"(.+?\(.*text_type=)"
+ def _render_HSTORE_type(
+ self, type_: "HSTORE", autogen_context: "AutogenContext"
+ ) -> str:
+ return cast(
+ str,
+ render._render_type_w_subtype(
+ type_, autogen_context, "text_type", r"(.+?\(.*text_type=)"
+ ),
)
- def _render_ARRAY_type(self, type_, autogen_context):
- return render._render_type_w_subtype(
- type_, autogen_context, "item_type", r"(.+?\()"
+ def _render_ARRAY_type(
+ self, type_: "ARRAY", autogen_context: "AutogenContext"
+ ) -> str:
+ return cast(
+ str,
+ render._render_type_w_subtype(
+ type_, autogen_context, "item_type", r"(.+?\()"
+ ),
)
- def _render_JSON_type(self, type_, autogen_context):
- return render._render_type_w_subtype(
- type_, autogen_context, "astext_type", r"(.+?\(.*astext_type=)"
+ def _render_JSON_type(
+ self, type_: "JSON", autogen_context: "AutogenContext"
+ ) -> str:
+ return cast(
+ str,
+ render._render_type_w_subtype(
+ type_, autogen_context, "astext_type", r"(.+?\(.*astext_type=)"
+ ),
)
- def _render_JSONB_type(self, type_, autogen_context):
- return render._render_type_w_subtype(
- type_, autogen_context, "astext_type", r"(.+?\(.*astext_type=)"
+ def _render_JSONB_type(
+ self, type_: "JSONB", autogen_context: "AutogenContext"
+ ) -> str:
+ return cast(
+ str,
+ render._render_type_w_subtype(
+ type_, autogen_context, "astext_type", r"(.+?\(.*astext_type=)"
+ ),
)
class PostgresqlColumnType(AlterColumn):
- def __init__(self, name, column_name, type_, **kw):
+ def __init__(
+ self, name: str, column_name: str, type_: "TypeEngine", **kw
+ ) -> None:
using = kw.pop("using", None)
super(PostgresqlColumnType, self).__init__(name, column_name, **kw)
self.type_ = sqltypes.to_instance(type_)
@compiles(RenameTable, "postgresql")
-def visit_rename_table(element, compiler, **kw):
+def visit_rename_table(
+ element: RenameTable, compiler: "PGDDLCompiler", **kw
+) -> str:
return "%s RENAME TO %s" % (
alter_table(compiler, element.table_name, element.schema),
format_table_name(compiler, element.new_table_name, None),
@compiles(PostgresqlColumnType, "postgresql")
-def visit_column_type(element, compiler, **kw):
+def visit_column_type(
+ element: PostgresqlColumnType, compiler: "PGDDLCompiler", **kw
+) -> str:
return "%s %s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
@compiles(ColumnComment, "postgresql")
-def visit_column_comment(element, compiler, **kw):
+def visit_column_comment(
+ element: "ColumnComment", compiler: "PGDDLCompiler", **kw
+) -> str:
ddl = "COMMENT ON COLUMN {table_name}.{column_name} IS {comment}"
comment = (
compiler.sql_compiler.render_literal_value(
@compiles(IdentityColumnDefault, "postgresql")
-def visit_identity_column(element, compiler, **kw):
+def visit_identity_column(
+ element: "IdentityColumnDefault", compiler: "PGDDLCompiler", **kw
+):
text = "%s %s " % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
def __init__(
self,
- constraint_name,
- table_name,
- elements,
- where=None,
- schema=None,
- _orig_constraint=None,
+ constraint_name: Optional[str],
+ table_name: Union[str, "quoted_name"],
+ elements: Union[
+ Sequence[Tuple[str, str]],
+ Sequence[Tuple["ColumnClause", str]],
+ ],
+ where: Optional[Union["BinaryExpression", str]] = None,
+ schema: Optional[str] = None,
+ _orig_constraint: Optional["ExcludeConstraint"] = None,
**kw
- ):
+ ) -> None:
self.constraint_name = constraint_name
self.table_name = table_name
self.elements = elements
self.kw = kw
@classmethod
- def from_constraint(cls, constraint):
+ def from_constraint( # type:ignore[override]
+ cls, constraint: "ExcludeConstraint"
+ ) -> "CreateExcludeConstraintOp":
constraint_table = sqla_compat._table_for_constraint(constraint)
return cls(
constraint.name,
constraint_table.name,
- [(expr, op) for expr, name, op in constraint._render_exprs],
+ [
+ (expr, op)
+ for expr, name, op in constraint._render_exprs # type:ignore[attr-defined] # noqa
+ ],
where=constraint.where,
schema=constraint_table.schema,
_orig_constraint=constraint,
using=constraint.using,
)
- def to_constraint(self, migration_context=None):
+ def to_constraint(
+ self, migration_context: Optional["MigrationContext"] = None
+ ) -> "ExcludeConstraint":
if self._orig_constraint is not None:
return self._orig_constraint
schema_obj = schemaobj.SchemaObjects(migration_context)
where=self.where,
**self.kw
)
- for expr, name, oper in excl._render_exprs:
+ for (
+ expr,
+ name,
+ oper,
+ ) in excl._render_exprs: # type:ignore[attr-defined]
t.append_column(Column(name, NULLTYPE))
t.append_constraint(excl)
return excl
@classmethod
def create_exclude_constraint(
- cls, operations, constraint_name, table_name, *elements, **kw
- ):
+ cls,
+ operations: "Operations",
+ constraint_name: str,
+ table_name: str,
+ *elements: Any,
+ **kw: Any
+ ) -> Optional["Table"]:
"""Issue an alter to create an EXCLUDE constraint using the
current migration context.
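# Usage sketch from the migration-script side; each element pairs a column
# or SQL expression with an operator:
#
#     op.create_exclude_constraint(
#         "user_excl",
#         "user",
#         ("period", "&&"),
#         where=("group != 'some group'"),
#     )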
@render.renderers.dispatch_for(CreateExcludeConstraintOp)
-def _add_exclude_constraint(autogen_context, op):
+def _add_exclude_constraint(
+ autogen_context: "AutogenContext", op: "CreateExcludeConstraintOp"
+) -> str:
return _exclude_constraint(op.to_constraint(), autogen_context, alter=True)
@render._constraint_renderers.dispatch_for(ExcludeConstraint)
def _render_inline_exclude_constraint(
- constraint, autogen_context, namespace_metadata
-):
+ constraint: "ExcludeConstraint",
+ autogen_context: "AutogenContext",
+ namespace_metadata: "MetaData",
+) -> str:
rendered = render._user_defined_render(
"exclude", constraint, autogen_context
)
return _exclude_constraint(constraint, autogen_context, False)
-def _postgresql_autogenerate_prefix(autogen_context):
+def _postgresql_autogenerate_prefix(autogen_context: "AutogenContext") -> str:
imports = autogen_context.imports
if imports is not None:
return "postgresql."
-def _exclude_constraint(constraint, autogen_context, alter):
- opts = []
+def _exclude_constraint(
+ constraint: "ExcludeConstraint",
+ autogen_context: "AutogenContext",
+ alter: bool,
+) -> str:
+ opts: List[Tuple[str, Union["quoted_name", str, "_f_name", None]]] = []
has_batch = autogen_context._has_batch
_render_potential_column(sqltext, autogen_context),
opstring,
)
- for sqltext, name, opstring in constraint._render_exprs
+ for sqltext, name, opstring in constraint._render_exprs # type:ignore[attr-defined] # noqa
]
)
if constraint.where is not None:
args = [
"(%s, %r)"
% (_render_potential_column(sqltext, autogen_context), opstring)
- for sqltext, name, opstring in constraint._render_exprs
+ for sqltext, name, opstring in constraint._render_exprs # type:ignore[attr-defined] # noqa
]
if constraint.where is not None:
args.append(
}
-def _render_potential_column(value, autogen_context):
+def _render_potential_column(
+ value: Union["ColumnClause", "Column"], autogen_context: "AutogenContext"
+) -> str:
if isinstance(value, ColumnClause):
template = "%(prefix)scolumn(%(name)r)"
import re
+from typing import Any
+from typing import Dict
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import Union
from sqlalchemy import cast
from sqlalchemy import JSON
from .impl import DefaultImpl
from .. import util
+if TYPE_CHECKING:
+ from sqlalchemy.engine.reflection import Inspector
+ from sqlalchemy.sql.elements import Cast
+ from sqlalchemy.sql.elements import ClauseElement
+ from sqlalchemy.sql.schema import Column
+ from sqlalchemy.sql.schema import Constraint
+ from sqlalchemy.sql.schema import Table
+ from sqlalchemy.sql.type_api import TypeEngine
+
+ from ..operations.batch import BatchOperationsImpl
+
class SQLiteImpl(DefaultImpl):
__dialect__ = "sqlite"
see: http://bugs.python.org/issue10740
"""
- def requires_recreate_in_batch(self, batch_op):
+ def requires_recreate_in_batch(
+ self, batch_op: "BatchOperationsImpl"
+ ) -> bool:
"""Return True if the given :class:`.BatchOperationsImpl`
would need the table to be recreated and copied in order to
proceed.
else:
return False
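# Practical consequence (sketch): most ALTERs on SQLite are routed through
# batch mode's copy-and-move strategy instead, e.g.
#
#     with op.batch_alter_table("some_table") as batch_op:
#         batch_op.add_column(Column("foo", Integer))
#         batch_op.drop_column("bar")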
- def add_constraint(self, const):
+ def add_constraint(self, const: "Constraint") -> None:
# attempt to distinguish between an
# auto-gen constraint and an explicit one
- if const._create_rule is None:
+ if const._create_rule is None: # type:ignore[attr-defined]
raise NotImplementedError(
"No support for ALTER of constraints in SQLite dialect"
"Please refer to the batch mode feature which allows for "
"SQLite migrations using a copy-and-move strategy."
)
- elif const._create_rule(self):
+ elif const._create_rule(self): # type:ignore[attr-defined]
util.warn(
"Skipping unsupported ALTER for "
"creation of implicit constraint"
"SQLite migrations using a copy-and-move strategy."
)
- def drop_constraint(self, const):
- if const._create_rule is None:
+ def drop_constraint(self, const: "Constraint") -> None:
+ if const._create_rule is None: # type:ignore[attr-defined]
raise NotImplementedError(
"No support for ALTER of constraints in SQLite dialect"
"Please refer to the batch mode feature which allows for "
def compare_server_default(
self,
- inspector_column,
- metadata_column,
- rendered_metadata_default,
- rendered_inspector_default,
- ):
+ inspector_column: "Column",
+ metadata_column: "Column",
+ rendered_metadata_default: Optional[str],
+ rendered_inspector_default: Optional[str],
+ ) -> bool:
if rendered_metadata_default is not None:
rendered_metadata_default = re.sub(
return rendered_inspector_default != rendered_metadata_default
- def _guess_if_default_is_unparenthesized_sql_expr(self, expr):
+ def _guess_if_default_is_unparenthesized_sql_expr(
+ self, expr: Optional[str]
+ ) -> bool:
"""Determine if a server default is a SQL expression or a constant.
There are too many assertions that expect server defaults to round-trip
else:
return True
- def autogen_column_reflect(self, inspector, table, column_info):
+ def autogen_column_reflect(
+ self,
+ inspector: "Inspector",
+ table: "Table",
+ column_info: Dict[str, Any],
+ ) -> None:
# SQLite expression defaults require parenthesis when sent
# as DDL
if self._guess_if_default_is_unparenthesized_sql_expr(
):
column_info["default"] = "(%s)" % (column_info["default"],)
- def render_ddl_sql_expr(self, expr, is_server_default=False, **kw):
+ def render_ddl_sql_expr(
+ self, expr: "ClauseElement", is_server_default: bool = False, **kw
+ ) -> str:
# SQLite expression defaults require parenthesis when sent
# as DDL
str_expr = super(SQLiteImpl, self).render_ddl_sql_expr(
str_expr = "(%s)" % (str_expr,)
return str_expr
- def cast_for_batch_migrate(self, existing, existing_transfer, new_type):
+ def cast_for_batch_migrate(
+ self,
+ existing: "Column",
+ existing_transfer: Dict[str, Union["TypeEngine", "Cast"]],
+ new_type: "TypeEngine",
+ ) -> None:
if (
- existing.type._type_affinity is not new_type._type_affinity
+ existing.type._type_affinity # type:ignore[attr-defined]
+ is not new_type._type_affinity # type:ignore[attr-defined]
and not isinstance(new_type, JSON)
):
existing_transfer["expr"] = cast(
--- /dev/null
+from .runtime.environment import * # noqa
--- /dev/null
+from .runtime.migration import * # noqa
from contextlib import contextmanager
import textwrap
+from typing import Any
+from typing import Callable
+from typing import Iterator
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import Union
+
+from sqlalchemy.sql.elements import conv
from . import batch
from . import schemaobj
from ..util.compat import inspect_formatargspec
from ..util.compat import inspect_getargspec
-__all__ = ("Operations", "BatchOperations")
+if TYPE_CHECKING:
+ from sqlalchemy.engine import Connection
-try:
- from sqlalchemy.sql.naming import conv
-except:
- conv = None
+ from .batch import BatchOperationsImpl
+ from .ops import MigrateOperation
+ from ..runtime.migration import MigrationContext
+ from ..util.sqla_compat import _literal_bindparam
+
+__all__ = ("Operations", "BatchOperations")
class Operations(util.ModuleClsProxy):
_to_impl = util.Dispatcher()
- def __init__(self, migration_context, impl=None):
+ def __init__(
+ self,
+ migration_context: "MigrationContext",
+ impl: Optional["BatchOperationsImpl"] = None,
+ ) -> None:
"""Construct a new :class:`.Operations`
:param migration_context: a :class:`.MigrationContext`
self.schema_obj = schemaobj.SchemaObjects(migration_context)
@classmethod
- def register_operation(cls, name, sourcename=None):
+ def register_operation(
+ cls, name: str, sourcename: Optional[str] = None
+ ) -> Callable:
"""Register a new operation for this class.
This method is normally used to add new operations
return register
@classmethod
- def implementation_for(cls, op_cls):
+ def implementation_for(cls, op_cls: Any) -> Callable:
"""Register an implementation for a given :class:`.MigrateOperation`.
This is part of the operation extensibility API.
@classmethod
@contextmanager
- def context(cls, migration_context):
+ def context(
+ cls, migration_context: "MigrationContext"
+ ) -> Iterator["Operations"]:
op = Operations(migration_context)
op._install_proxy()
yield op
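# Usage sketch: the context manager installs the module-level proxy, so
# ``alembic.op`` calls inside the block route to this instance.
#
#     with Operations.context(migration_context) as op:
#         op.alter_column("t", "x", nullable=False)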
return self.migration_context
- def invoke(self, operation):
+ def invoke(self, operation: "MigrateOperation") -> Any:
"""Given a :class:`.MigrateOperation`, invoke it in terms of
this :class:`.Operations` instance.
)
return fn(self, operation)
- def f(self, name):
+ def f(self, name: str) -> "conv":
"""Indicate a string name that has already had a naming convention
applied to it.
CONSTRAINT ck_bool_t_x CHECK (x in (1, 0)))
The function is rendered in the output of autogenerate when
- a particular constraint name is already converted, for SQLAlchemy
- version **0.9.4 and greater only**. Even though ``naming_convention``
- was introduced in 0.9.2, the string disambiguation service is new
- as of 0.9.4.
+ a particular constraint name is already converted.
"""
- if conv:
- return conv(name)
- else:
- raise NotImplementedError(
- "op.f() feature requires SQLAlchemy 0.9.4 or greater."
- )
+ return conv(name)
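# Usage sketch: with op.f() the name is used exactly as given; without it,
# the target metadata's naming convention would be applied to it again.
#
#     op.add_column("t", Column("x", Boolean(name=op.f("ck_bool_t_x"))))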
- def inline_literal(self, value, type_=None):
+ def inline_literal(
+ self, value: Union[str, int], type_: None = None
+ ) -> "_literal_bindparam":
r"""Produce an 'inline literal' expression, suitable for
using in an INSERT, UPDATE, or DELETE statement.
"""
return sqla_compat._literal_bindparam(None, value, type_=type_)
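# Usage sketch per the documented pattern (table/column/String imported from
# sqlalchemy): the literal is rendered inline into the statement itself,
# which is what makes it safe for --sql offline mode.
#
#     account = table("account", column("name", String))
#     op.execute(
#         account.update()
#         .where(account.c.name == op.inline_literal("account 1"))
#         .values({"name": op.inline_literal("account 2")})
#     )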
- def get_bind(self):
+ def get_bind(self) -> "Connection":
"""Return the current 'bind'.
Under normal circumstances, this is the
+from typing import Any
+from typing import cast
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import Union
+
from sqlalchemy import CheckConstraint
from sqlalchemy import Column
from sqlalchemy import ForeignKeyConstraint
from ..util.sqla_compat import _remove_column_from_collection
from ..util.sqla_compat import _select
+if TYPE_CHECKING:
+ from typing import Literal
+
+ from sqlalchemy.engine import Dialect
+ from sqlalchemy.sql.elements import ColumnClause
+ from sqlalchemy.sql.elements import quoted_name
+ from sqlalchemy.sql.functions import Function
+ from sqlalchemy.sql.schema import Constraint
+ from sqlalchemy.sql.type_api import TypeEngine
+
+ from ..ddl.impl import DefaultImpl
+
class BatchOperationsImpl:
def __init__(
self.batch = []
@property
- def dialect(self):
+ def dialect(self) -> "Dialect":
return self.operations.impl.dialect
@property
- def impl(self):
+ def impl(self) -> "DefaultImpl":
return self.operations.impl
- def _should_recreate(self):
+ def _should_recreate(self) -> bool:
if self.recreate == "auto":
return self.operations.impl.requires_recreate_in_batch(self)
elif self.recreate == "always":
else:
return False
- def flush(self):
+ def flush(self) -> None:
should_recreate = self._should_recreate()
with _ensure_scope_for_ddl(self.impl.connection):
batch_impl._create(self.impl)
- def alter_column(self, *arg, **kw):
+ def alter_column(self, *arg, **kw) -> None:
self.batch.append(("alter_column", arg, kw))
- def add_column(self, *arg, **kw):
+ def add_column(self, *arg, **kw) -> None:
if (
"insert_before" in kw or "insert_after" in kw
) and not self._should_recreate():
)
self.batch.append(("add_column", arg, kw))
- def drop_column(self, *arg, **kw):
+ def drop_column(self, *arg, **kw) -> None:
self.batch.append(("drop_column", arg, kw))
- def add_constraint(self, const):
+ def add_constraint(self, const: "Constraint") -> None:
self.batch.append(("add_constraint", (const,), {}))
- def drop_constraint(self, const):
+ def drop_constraint(self, const: "Constraint") -> None:
self.batch.append(("drop_constraint", (const,), {}))
def rename_table(self, *arg, **kw):
self.batch.append(("rename_table", arg, kw))
- def create_index(self, idx):
+ def create_index(self, idx: "Index") -> None:
self.batch.append(("create_index", (idx,), {}))
- def drop_index(self, idx):
+ def drop_index(self, idx: "Index") -> None:
self.batch.append(("drop_index", (idx,), {}))
def create_table_comment(self, table):
class ApplyBatchImpl:
def __init__(
self,
- impl,
- table,
- table_args,
- table_kwargs,
- reflected,
- partial_reordering=(),
- ):
+ impl: "DefaultImpl",
+ table: "Table",
+ table_args: tuple,
+ table_kwargs: Dict[str, Any],
+ reflected: bool,
+ partial_reordering: tuple = (),
+ ) -> None:
self.impl = impl
self.table = table # this is a Table object
self.table_args = table_args
self.table_kwargs = table_kwargs
self.temp_table_name = self._calc_temp_name(table.name)
- self.new_table = None
+ self.new_table: Optional[Table] = None
self.partial_reordering = partial_reordering # tuple of tuples
- self.add_col_ordering = () # tuple of tuples
+ self.add_col_ordering: Tuple[
+ Tuple[str, str], ...
+ ] = () # tuple of tuples
self.column_transfers = OrderedDict(
(c.name, {"expr": c}) for c in self.table.c
self._grab_table_elements()
@classmethod
- def _calc_temp_name(cls, tablename):
+ def _calc_temp_name(cls, tablename: "quoted_name") -> str:
return ("_alembic_tmp_%s" % tablename)[0:50]
- def _grab_table_elements(self):
+ def _grab_table_elements(self) -> None:
schema = self.table.schema
- self.columns = OrderedDict()
+ self.columns: Dict[str, "Column"] = OrderedDict()
for c in self.table.c:
c_copy = _copy(c, schema=schema)
c_copy.unique = c_copy.index = False
if isinstance(c.type, SchemaEventTarget):
assert c_copy.type is not c.type
self.columns[c.name] = c_copy
- self.named_constraints = {}
+ self.named_constraints: Dict[str, "Constraint"] = {}
self.unnamed_constraints = []
self.col_named_constraints = {}
- self.indexes = {}
- self.new_indexes = {}
+ self.indexes: Dict[str, "Index"] = {}
+ self.new_indexes: Dict[str, "Index"] = {}
for const in self.table.constraints:
if _is_type_bound(const):
for k in self.table.kwargs:
self.table_kwargs.setdefault(k, self.table.kwargs[k])
- def _adjust_self_columns_for_partial_reordering(self):
+ def _adjust_self_columns_for_partial_reordering(self) -> None:
pairs = set()
col_by_idx = list(self.columns)
# this can happen if some columns were dropped and not removed
# from existing_ordering. this should be prevented already, but
# conservatively making sure this didn't happen
- pairs = [p for p in pairs if p[0] != p[1]]
+ pairs_list = [p for p in pairs if p[0] != p[1]]
sorted_ = list(
- topological.sort(pairs, col_by_idx, deterministic_order=True)
+ topological.sort(pairs_list, col_by_idx, deterministic_order=True)
)
self.columns = OrderedDict((k, self.columns[k]) for k in sorted_)
self.column_transfers = OrderedDict(
(k, self.column_transfers[k]) for k in sorted_
)
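# A minimal sketch of the step above, with hypothetical column names: each
# pair ("id", "name") means "id must precede name", and topological.sort()
# returns one complete, stable ordering of all columns:
#
#     from sqlalchemy.util import topological
#
#     order = list(
#         topological.sort(
#             [("id", "name")], ["name", "id", "ts"], deterministic_order=True
#         )
#     )
#     # one valid result: ["id", "name", "ts"]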
- def _transfer_elements_to_new_table(self):
+ def _transfer_elements_to_new_table(self) -> None:
assert self.new_table is None, "Can only create new table once"
m = MetaData()
if not const_columns.issubset(self.column_transfers):
continue
+ const_copy: "Constraint"
if isinstance(const, ForeignKeyConstraint):
if _fk_is_self_referential(const):
# for self-referential constraint, refer to the
self._setup_referent(m, const)
new_table.append_constraint(const_copy)
- def _gather_indexes_from_both_tables(self):
- idx = []
+ def _gather_indexes_from_both_tables(self) -> List["Index"]:
+ assert self.new_table is not None
+ idx: List[Index] = []
idx.extend(self.indexes.values())
for index in self.new_indexes.values():
idx.append(
)
return idx
- def _setup_referent(self, metadata, constraint):
- spec = constraint.elements[0]._get_colspec()
+ def _setup_referent(
+ self, metadata: "MetaData", constraint: "ForeignKeyConstraint"
+ ) -> None:
+ spec = constraint.elements[
+ 0
+ ]._get_colspec() # type:ignore[attr-defined]
parts = spec.split(".")
tname = parts[-2]
if len(parts) == 3:
if tname != self.temp_table_name:
key = sql_schema._get_table_key(tname, referent_schema)
+
+ def colspec(elem: Any) -> str:
+ return elem._get_colspec()
+
if key in metadata.tables:
t = metadata.tables[key]
for elem in constraint.elements:
- colname = elem._get_colspec().split(".")[-1]
+ colname = colspec(elem).split(".")[-1]
if colname not in t.c:
t.append_column(Column(colname, sqltypes.NULLTYPE))
else:
*[
Column(n, sqltypes.NULLTYPE)
for n in [
- elem._get_colspec().split(".")[-1]
+ colspec(elem).split(".")[-1]
for elem in constraint.elements
]
],
schema=referent_schema
)
- def _create(self, op_impl):
+ def _create(self, op_impl: "DefaultImpl") -> None:
self._transfer_elements_to_new_table()
op_impl.prep_table_for_batch(self, self.table)
+ assert self.new_table is not None
op_impl.create_table(self.new_table)
try:
def alter_column(
self,
- table_name,
- column_name,
- nullable=None,
- server_default=False,
- name=None,
- type_=None,
- autoincrement=None,
- comment=False,
+ table_name: str,
+ column_name: str,
+ nullable: Optional[bool] = None,
+ server_default: Optional[Union["Function", str, bool]] = False,
+ name: Optional[str] = None,
+ type_: Optional["TypeEngine"] = None,
+ autoincrement: None = None,
+ comment: Union[str, "Literal[False]"] = False,
**kw
- ):
+ ) -> None:
existing = self.columns[column_name]
- existing_transfer = self.column_transfers[column_name]
+ existing_transfer: Dict[str, Any] = self.column_transfers[column_name]
if name is not None and name != column_name:
# note that we don't change '.key' - we keep referring
# to the renamed column by its old key in _create(). neat!
# we also ignore the drop_constraint that will come here from
# Operations.implementation_for(alter_column)
if isinstance(existing.type, SchemaEventTarget):
- existing.type._create_events = (
- existing.type.create_constraint
+ existing.type._create_events = ( # type:ignore[attr-defined]
+ existing.type.create_constraint # type:ignore[attr-defined] # noqa
) = False
self.impl.cast_for_batch_migrate(
if server_default is None:
existing.server_default = None
else:
- sql_schema.DefaultClause(server_default)._set_parent(existing)
+ sql_schema.DefaultClause(
+ server_default
+ )._set_parent( # type:ignore[attr-defined]
+ existing
+ )
if autoincrement is not None:
existing.autoincrement = bool(autoincrement)
existing.comment = comment
def _setup_dependencies_for_add_column(
- self, colname, insert_before, insert_after
- ):
+ self,
+ colname: str,
+ insert_before: Optional[str],
+ insert_after: Optional[str],
+ ) -> None:
index_cols = self.existing_ordering
col_indexes = {name: i for i, name in enumerate(index_cols)}
self.add_col_ordering += ((index_cols[-1], colname),)
def add_column(
- self, table_name, column, insert_before=None, insert_after=None, **kw
- ):
+ self,
+ table_name: str,
+ column: "Column",
+ insert_before: Optional[str] = None,
+ insert_after: Optional[str] = None,
+ **kw
+ ) -> None:
self._setup_dependencies_for_add_column(
column.name, insert_before, insert_after
)
self.columns[column.name] = _copy(column, schema=self.table.schema)
self.column_transfers[column.name] = {}
- def drop_column(self, table_name, column, **kw):
+ def drop_column(
+ self, table_name: str, column: Union["ColumnClause", "Column"], **kw
+ ) -> None:
if column.name in self.table.primary_key.columns:
_remove_column_from_collection(
self.table.primary_key.columns, column
"""
- def add_constraint(self, const):
+ def add_constraint(self, const: "Constraint") -> None:
if not const.name:
raise ValueError("Constraint must have a name")
if isinstance(const, sql_schema.PrimaryKeyConstraint):
self.named_constraints[const.name] = const
- def drop_constraint(self, const):
+ def drop_constraint(self, const: "Constraint") -> None:
if not const.name:
raise ValueError("Constraint must have a name")
try:
if col_const.name == const.name:
self.columns[col.name].constraints.remove(col_const)
else:
- const = self.named_constraints.pop(const.name)
+ const = self.named_constraints.pop(cast(str, const.name))
except KeyError:
if _is_type_bound(const):
# type-bound constraints are only included in the new
for col in const.columns:
self.columns[col.name].primary_key = False
- def create_index(self, idx):
+ def create_index(self, idx: "Index") -> None:
self.new_indexes[idx.name] = idx
- def drop_index(self, idx):
+ def drop_index(self, idx: "Index") -> None:
try:
del self.indexes[idx.name]
except KeyError:
+from abc import abstractmethod
import re
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import FrozenSet
+from typing import Iterator
+from typing import List
+from typing import MutableMapping
+from typing import Optional
+from typing import Sequence
+from typing import Set
+from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
+from typing import Union
from sqlalchemy.types import NULLTYPE
from .. import util
from ..util import sqla_compat
+if TYPE_CHECKING:
+ from sqlalchemy.sql.dml import Insert
+ from sqlalchemy.sql.dml import Update
+ from sqlalchemy.sql.elements import BinaryExpression
+ from sqlalchemy.sql.elements import ColumnElement
+ from sqlalchemy.sql.elements import conv
+ from sqlalchemy.sql.elements import quoted_name
+ from sqlalchemy.sql.elements import TextClause
+ from sqlalchemy.sql.functions import Function
+ from sqlalchemy.sql.schema import CheckConstraint
+ from sqlalchemy.sql.schema import Column
+ from sqlalchemy.sql.schema import Computed
+ from sqlalchemy.sql.schema import Constraint
+ from sqlalchemy.sql.schema import ForeignKeyConstraint
+ from sqlalchemy.sql.schema import Identity
+ from sqlalchemy.sql.schema import Index
+ from sqlalchemy.sql.schema import MetaData
+ from sqlalchemy.sql.schema import PrimaryKeyConstraint
+ from sqlalchemy.sql.schema import SchemaItem
+ from sqlalchemy.sql.schema import Table
+ from sqlalchemy.sql.schema import UniqueConstraint
+ from sqlalchemy.sql.selectable import TableClause
+ from sqlalchemy.sql.type_api import TypeEngine
+
+ from ..autogenerate.rewriter import Rewriter
+ from ..runtime.migration import MigrationContext
+
class MigrateOperation:
"""base class for migration command and organization objects.
"""
return {}
- _mutations = frozenset()
+ _mutations: FrozenSet["Rewriter"] = frozenset()
+
+ def reverse(self) -> "MigrateOperation":
+ raise NotImplementedError
+
+ def to_diff_tuple(self) -> Tuple[Any, ...]:
+ raise NotImplementedError
class AddConstraintOp(MigrateOperation):
raise NotImplementedError()
@classmethod
- def register_add_constraint(cls, type_):
+ def register_add_constraint(cls, type_: str) -> Callable:
def go(klass):
cls.add_constraint_ops.dispatch_for(type_)(klass.from_constraint)
return klass
return go
@classmethod
- def from_constraint(cls, constraint):
+ def from_constraint(cls, constraint: "Constraint") -> "AddConstraintOp":
return cls.add_constraint_ops.dispatch(constraint.__visit_name__)(
constraint
)
- def reverse(self):
+ @abstractmethod
+ def to_constraint(
+ self, migration_context: Optional["MigrationContext"] = None
+ ) -> "Constraint":
+ pass
+
+ def reverse(self) -> "DropConstraintOp":
return DropConstraintOp.from_constraint(self.to_constraint())
- def to_diff_tuple(self):
+ def to_diff_tuple(self) -> Tuple[str, "Constraint"]:
return ("add_constraint", self.to_constraint())
def __init__(
self,
- constraint_name,
- table_name,
- type_=None,
- schema=None,
- _reverse=None,
- ):
+ constraint_name: Optional[str],
+ table_name: str,
+ type_: Optional[str] = None,
+ schema: Optional[str] = None,
+ _reverse: Optional["AddConstraintOp"] = None,
+ ) -> None:
self.constraint_name = constraint_name
self.table_name = table_name
self.constraint_type = type_
self.schema = schema
self._reverse = _reverse
- def reverse(self):
+ def reverse(self) -> "AddConstraintOp":
return AddConstraintOp.from_constraint(self.to_constraint())
- def to_diff_tuple(self):
+ def to_diff_tuple(
+ self,
+ ) -> Tuple[str, "SchemaItem"]:
if self.constraint_type == "foreignkey":
return ("remove_fk", self.to_constraint())
else:
return ("remove_constraint", self.to_constraint())
@classmethod
- def from_constraint(cls, constraint):
+ def from_constraint(
+ cls,
+ constraint: "Constraint",
+ ) -> "DropConstraintOp":
types = {
"unique_constraint": "unique",
"foreign_key_constraint": "foreignkey",
_reverse=AddConstraintOp.from_constraint(constraint),
)
- def to_constraint(self):
+ def to_constraint(
+ self,
+ ) -> "Constraint":
if self._reverse is not None:
constraint = self._reverse.to_constraint()
@classmethod
def drop_constraint(
- cls, operations, constraint_name, table_name, type_=None, schema=None
- ):
+ cls,
+ operations: "Operations",
+ constraint_name: str,
+ table_name: str,
+ type_: Optional[str] = None,
+ schema: Optional[str] = None,
+ ) -> Optional["Table"]:
r"""Drop a constraint of the given name, typically via DROP CONSTRAINT.
:param constraint_name: name of the constraint.
return operations.invoke(op)
@classmethod
- def batch_drop_constraint(cls, operations, constraint_name, type_=None):
+ def batch_drop_constraint(
+ cls,
+ operations: "BatchOperations",
+ constraint_name: str,
+ type_: Optional[str] = None,
+ ) -> None:
"""Issue a "drop constraint" instruction using the
current batch migration context.
constraint_type = "primarykey"
def __init__(
- self, constraint_name, table_name, columns, schema=None, **kw
- ):
+ self,
+ constraint_name: Optional[str],
+ table_name: str,
+ columns: Sequence[str],
+ schema: Optional[str] = None,
+ **kw
+ ) -> None:
self.constraint_name = constraint_name
self.table_name = table_name
self.columns = columns
self.kw = kw
@classmethod
- def from_constraint(cls, constraint):
+ def from_constraint(cls, constraint: "Constraint") -> "CreatePrimaryKeyOp":
constraint_table = sqla_compat._table_for_constraint(constraint)
+ pk_constraint = cast("PrimaryKeyConstraint", constraint)
+
return cls(
- constraint.name,
+ pk_constraint.name,
constraint_table.name,
- constraint.columns.keys(),
+ pk_constraint.columns.keys(),
schema=constraint_table.schema,
- **constraint.dialect_kwargs,
+ **pk_constraint.dialect_kwargs,
)
- def to_constraint(self, migration_context=None):
+ def to_constraint(
+ self, migration_context: Optional["MigrationContext"] = None
+ ) -> "PrimaryKeyConstraint":
schema_obj = schemaobj.SchemaObjects(migration_context)
+
return schema_obj.primary_key_constraint(
self.constraint_name,
self.table_name,
@classmethod
def create_primary_key(
- cls, operations, constraint_name, table_name, columns, schema=None
- ):
+ cls,
+ operations: "Operations",
+ constraint_name: str,
+ table_name: str,
+ columns: List[str],
+ schema: Optional[str] = None,
+ ) -> Optional["Table"]:
"""Issue a "create primary key" instruction using the current
migration context.
return operations.invoke(op)
@classmethod
- def batch_create_primary_key(cls, operations, constraint_name, columns):
+ def batch_create_primary_key(
+ cls,
+ operations: "BatchOperations",
+ constraint_name: str,
+ columns: List[str],
+ ) -> None:
"""Issue a "create primary key" instruction using the
current batch migration context.
constraint_type = "unique"
def __init__(
- self, constraint_name, table_name, columns, schema=None, **kw
- ):
+ self,
+ constraint_name: Optional[str],
+ table_name: str,
+ columns: Sequence[str],
+ schema: Optional[str] = None,
+ **kw
+ ) -> None:
self.constraint_name = constraint_name
self.table_name = table_name
self.columns = columns
self.kw = kw
@classmethod
- def from_constraint(cls, constraint):
+ def from_constraint(
+ cls, constraint: "Constraint"
+ ) -> "CreateUniqueConstraintOp":
+
constraint_table = sqla_compat._table_for_constraint(constraint)
- kw = {}
- if constraint.deferrable:
- kw["deferrable"] = constraint.deferrable
- if constraint.initially:
- kw["initially"] = constraint.initially
- kw.update(constraint.dialect_kwargs)
+ uq_constraint = cast("UniqueConstraint", constraint)
+
+ kw: dict = {}
+ if uq_constraint.deferrable:
+ kw["deferrable"] = uq_constraint.deferrable
+ if uq_constraint.initially:
+ kw["initially"] = uq_constraint.initially
+ kw.update(uq_constraint.dialect_kwargs)
return cls(
- constraint.name,
+ uq_constraint.name,
constraint_table.name,
- [c.name for c in constraint.columns],
+ [c.name for c in uq_constraint.columns],
schema=constraint_table.schema,
**kw,
)
- def to_constraint(self, migration_context=None):
+ def to_constraint(
+ self, migration_context: Optional["MigrationContext"] = None
+ ) -> "UniqueConstraint":
schema_obj = schemaobj.SchemaObjects(migration_context)
return schema_obj.unique_constraint(
self.constraint_name,
@classmethod
def create_unique_constraint(
cls,
- operations,
- constraint_name,
- table_name,
- columns,
- schema=None,
+ operations: "Operations",
+ constraint_name: Optional[str],
+ table_name: str,
+ columns: Sequence[str],
+ schema: Optional[str] = None,
**kw
- ):
+ ) -> Any:
"""Issue a "create unique constraint" instruction using the
current migration context.
@classmethod
def batch_create_unique_constraint(
- cls, operations, constraint_name, columns, **kw
- ):
+ cls,
+ operations: "BatchOperations",
+ constraint_name: str,
+ columns: Sequence[str],
+ **kw
+ ) -> Any:
"""Issue a "create unique constraint" instruction using the
current batch migration context.
def __init__(
self,
- constraint_name,
- source_table,
- referent_table,
- local_cols,
- remote_cols,
+ constraint_name: Optional[str],
+ source_table: str,
+ referent_table: str,
+ local_cols: List[str],
+ remote_cols: List[str],
**kw
- ):
+ ) -> None:
self.constraint_name = constraint_name
self.source_table = source_table
self.referent_table = referent_table
self.remote_cols = remote_cols
self.kw = kw
- def to_diff_tuple(self):
+ def to_diff_tuple(self) -> Tuple[str, "ForeignKeyConstraint"]:
return ("add_fk", self.to_constraint())
@classmethod
- def from_constraint(cls, constraint):
- kw = {}
- if constraint.onupdate:
- kw["onupdate"] = constraint.onupdate
- if constraint.ondelete:
- kw["ondelete"] = constraint.ondelete
- if constraint.initially:
- kw["initially"] = constraint.initially
- if constraint.deferrable:
- kw["deferrable"] = constraint.deferrable
- if constraint.use_alter:
- kw["use_alter"] = constraint.use_alter
+ def from_constraint(cls, constraint: "Constraint") -> "CreateForeignKeyOp":
+
+ fk_constraint = cast("ForeignKeyConstraint", constraint)
+ kw: dict = {}
+ if fk_constraint.onupdate:
+ kw["onupdate"] = fk_constraint.onupdate
+ if fk_constraint.ondelete:
+ kw["ondelete"] = fk_constraint.ondelete
+ if fk_constraint.initially:
+ kw["initially"] = fk_constraint.initially
+ if fk_constraint.deferrable:
+ kw["deferrable"] = fk_constraint.deferrable
+ if fk_constraint.use_alter:
+ kw["use_alter"] = fk_constraint.use_alter
(
source_schema,
ondelete,
deferrable,
initially,
- ) = sqla_compat._fk_spec(constraint)
+ ) = sqla_compat._fk_spec(fk_constraint)
kw["source_schema"] = source_schema
kw["referent_schema"] = target_schema
- kw.update(constraint.dialect_kwargs)
+ kw.update(fk_constraint.dialect_kwargs)
return cls(
- constraint.name,
+ fk_constraint.name,
source_table,
target_table,
source_columns,
**kw,
)
- def to_constraint(self, migration_context=None):
+ def to_constraint(
+ self, migration_context: Optional["MigrationContext"] = None
+ ) -> "ForeignKeyConstraint":
schema_obj = schemaobj.SchemaObjects(migration_context)
return schema_obj.foreign_key_constraint(
self.constraint_name,
@classmethod
def create_foreign_key(
cls,
- operations,
- constraint_name,
- source_table,
- referent_table,
- local_cols,
- remote_cols,
- onupdate=None,
- ondelete=None,
- deferrable=None,
- initially=None,
- match=None,
- source_schema=None,
- referent_schema=None,
+ operations: "Operations",
+ constraint_name: str,
+ source_table: str,
+ referent_table: str,
+ local_cols: List[str],
+ remote_cols: List[str],
+ onupdate: Optional[str] = None,
+ ondelete: Optional[str] = None,
+ deferrable: Optional[bool] = None,
+ initially: Optional[str] = None,
+ match: Optional[str] = None,
+ source_schema: Optional[str] = None,
+ referent_schema: Optional[str] = None,
**dialect_kw
- ):
+ ) -> Optional["Table"]:
"""Issue a "create foreign key" instruction using the
current migration context.
@classmethod
def batch_create_foreign_key(
cls,
- operations,
- constraint_name,
- referent_table,
- local_cols,
- remote_cols,
- referent_schema=None,
- onupdate=None,
- ondelete=None,
- deferrable=None,
- initially=None,
- match=None,
+ operations: "BatchOperations",
+ constraint_name: str,
+ referent_table: str,
+ local_cols: List[str],
+ remote_cols: List[str],
+ referent_schema: Optional[str] = None,
+ onupdate: None = None,
+ ondelete: None = None,
+ deferrable: None = None,
+ initially: None = None,
+ match: None = None,
**dialect_kw
- ):
+ ) -> None:
"""Issue a "create foreign key" instruction using the
current batch migration context.
constraint_type = "check"
def __init__(
- self, constraint_name, table_name, condition, schema=None, **kw
- ):
+ self,
+ constraint_name: Optional[str],
+ table_name: str,
+ condition: Union["TextClause", "ColumnElement[Any]"],
+ schema: Optional[str] = None,
+ **kw
+ ) -> None:
self.constraint_name = constraint_name
self.table_name = table_name
self.condition = condition
self.kw = kw
@classmethod
- def from_constraint(cls, constraint):
+ def from_constraint(
+ cls, constraint: "Constraint"
+ ) -> "CreateCheckConstraintOp":
constraint_table = sqla_compat._table_for_constraint(constraint)
+ ck_constraint = cast("CheckConstraint", constraint)
+
return cls(
- constraint.name,
+ ck_constraint.name,
constraint_table.name,
- constraint.sqltext,
+ cast(
+ "Union[TextClause, ColumnElement[Any]]", ck_constraint.sqltext
+ ),
schema=constraint_table.schema,
- **constraint.dialect_kwargs,
+ **ck_constraint.dialect_kwargs,
)
- def to_constraint(self, migration_context=None):
+ def to_constraint(
+ self, migration_context: Optional["MigrationContext"] = None
+ ) -> "CheckConstraint":
schema_obj = schemaobj.SchemaObjects(migration_context)
return schema_obj.check_constraint(
self.constraint_name,
@classmethod
def create_check_constraint(
cls,
- operations,
- constraint_name,
- table_name,
- condition,
- schema=None,
+ operations: "Operations",
+ constraint_name: Optional[str],
+ table_name: str,
+ condition: "BinaryExpression",
+ schema: Optional[str] = None,
**kw
- ):
+ ) -> Optional["Table"]:
"""Issue a "create check constraint" instruction using the
current migration context.
@classmethod
def batch_create_check_constraint(
- cls, operations, constraint_name, condition, **kw
- ):
+ cls,
+ operations: "BatchOperations",
+ constraint_name: str,
+ condition: "TextClause",
+ **kw
+ ) -> Optional["Table"]:
"""Issue a "create check constraint" instruction using the
current batch migration context.
"""Represent a create index operation."""
def __init__(
- self, index_name, table_name, columns, schema=None, unique=False, **kw
- ):
+ self,
+ index_name: str,
+ table_name: str,
+ columns: Sequence[Union[str, "TextClause", "ColumnElement[Any]"]],
+ schema: Optional[str] = None,
+ unique: bool = False,
+ **kw
+ ) -> None:
self.index_name = index_name
self.table_name = table_name
self.columns = columns
self.unique = unique
self.kw = kw
- def reverse(self):
+ def reverse(self) -> "DropIndexOp":
return DropIndexOp.from_index(self.to_index())
- def to_diff_tuple(self):
+ def to_diff_tuple(self) -> Tuple[str, "Index"]:
return ("add_index", self.to_index())
@classmethod
- def from_index(cls, index):
+ def from_index(cls, index: "Index") -> "CreateIndexOp":
+ assert index.table is not None
return cls(
index.name,
index.table.name,
**index.kwargs,
)
- def to_index(self, migration_context=None):
+ def to_index(
+ self, migration_context: Optional["MigrationContext"] = None
+ ) -> "Index":
schema_obj = schemaobj.SchemaObjects(migration_context)
idx = schema_obj.index(
@classmethod
def create_index(
cls,
- operations,
- index_name,
- table_name,
- columns,
- schema=None,
- unique=False,
+ operations: "Operations",
+ index_name: str,
+ table_name: str,
+ columns: Sequence[Union[str, "TextClause", "Function"]],
+ schema: Optional[str] = None,
+ unique: bool = False,
**kw
- ):
+ ) -> Optional["Table"]:
r"""Issue a "create index" instruction using the current
migration context.
return operations.invoke(op)
@classmethod
- def batch_create_index(cls, operations, index_name, columns, **kw):
+ def batch_create_index(
+ cls,
+ operations: "BatchOperations",
+ index_name: str,
+ columns: List[str],
+ **kw
+ ) -> Optional["Table"]:
"""Issue a "create index" instruction using the
current batch migration context.
"""Represent a drop index operation."""
def __init__(
- self, index_name, table_name=None, schema=None, _reverse=None, **kw
- ):
+ self,
+ index_name: Union["quoted_name", str, "conv"],
+ table_name: Optional[str] = None,
+ schema: Optional[str] = None,
+ _reverse: Optional["CreateIndexOp"] = None,
+ **kw
+ ) -> None:
self.index_name = index_name
self.table_name = table_name
self.schema = schema
self._reverse = _reverse
self.kw = kw
- def to_diff_tuple(self):
+ def to_diff_tuple(self) -> Tuple[str, "Index"]:
return ("remove_index", self.to_index())
- def reverse(self):
+ def reverse(self) -> "CreateIndexOp":
return CreateIndexOp.from_index(self.to_index())
@classmethod
- def from_index(cls, index):
+ def from_index(cls, index: "Index") -> "DropIndexOp":
+ assert index.table is not None
return cls(
index.name,
index.table.name,
**index.kwargs,
)
- def to_index(self, migration_context=None):
+ def to_index(
+ self, migration_context: Optional["MigrationContext"] = None
+ ) -> "Index":
schema_obj = schemaobj.SchemaObjects(migration_context)
# need a dummy column name here since SQLAlchemy
@classmethod
def drop_index(
- cls, operations, index_name, table_name=None, schema=None, **kw
- ):
+ cls,
+ operations: "Operations",
+ index_name: str,
+ table_name: Optional[str] = None,
+ schema: Optional[str] = None,
+ **kw
+ ) -> Optional["Table"]:
r"""Issue a "drop index" instruction using the current
migration context.
return operations.invoke(op)
@classmethod
- def batch_drop_index(cls, operations, index_name, **kw):
+ def batch_drop_index(
+ cls, operations: "BatchOperations", index_name: str, **kw
+ ) -> Optional["Table"]:
"""Issue a "drop index" instruction using the
current batch migration context.
def __init__(
self,
- table_name,
- columns,
- schema=None,
- _namespace_metadata=None,
- _constraints_included=False,
+ table_name: str,
+ columns: Sequence[Union["Column", "Constraint"]],
+ schema: Optional[str] = None,
+ _namespace_metadata: Optional["MetaData"] = None,
+ _constraints_included: bool = False,
**kw
- ):
+ ) -> None:
self.table_name = table_name
self.columns = columns
self.schema = schema
self._namespace_metadata = _namespace_metadata
self._constraints_included = _constraints_included
- def reverse(self):
+ def reverse(self) -> "DropTableOp":
return DropTableOp.from_table(
self.to_table(), _namespace_metadata=self._namespace_metadata
)
- def to_diff_tuple(self):
+ def to_diff_tuple(self) -> Tuple[str, "Table"]:
return ("add_table", self.to_table())
@classmethod
- def from_table(cls, table, _namespace_metadata=None):
+ def from_table(
+ cls, table: "Table", _namespace_metadata: Optional["MetaData"] = None
+ ) -> "CreateTableOp":
if _namespace_metadata is None:
_namespace_metadata = table.metadata
return cls(
table.name,
- list(table.c) + list(table.constraints),
+ list(table.c) + list(table.constraints), # type:ignore[arg-type]
schema=table.schema,
_namespace_metadata=_namespace_metadata,
# given a Table() object, this Table will contain full Index()
# not doubled up. see #844 #848
_constraints_included=True,
comment=table.comment,
- info=table.info.copy(),
+ info=dict(table.info),
prefixes=list(table._prefixes),
**table.kwargs,
)
- def to_table(self, migration_context=None):
+ def to_table(
+ self, migration_context: Optional["MigrationContext"] = None
+ ) -> "Table":
schema_obj = schemaobj.SchemaObjects(migration_context)
return schema_obj.table(
)
@classmethod
- def create_table(cls, operations, table_name, *columns, **kw):
+ def create_table(
+ cls, operations: "Operations", table_name: str, *columns, **kw
+ ) -> Optional["Table"]:
r"""Issue a "create table" instruction using the current migration
context.
class DropTableOp(MigrateOperation):
"""Represent a drop table operation."""
- def __init__(self, table_name, schema=None, table_kw=None, _reverse=None):
+ def __init__(
+ self,
+ table_name: str,
+ schema: Optional[str] = None,
+ table_kw: Optional[MutableMapping[Any, Any]] = None,
+ _reverse: Optional["CreateTableOp"] = None,
+ ) -> None:
self.table_name = table_name
self.schema = schema
self.table_kw = table_kw or {}
self.prefixes = self.table_kw.pop("prefixes", None)
self._reverse = _reverse
- def to_diff_tuple(self):
+ def to_diff_tuple(self) -> Tuple[str, "Table"]:
return ("remove_table", self.to_table())
- def reverse(self):
+ def reverse(self) -> "CreateTableOp":
return CreateTableOp.from_table(self.to_table())
@classmethod
- def from_table(cls, table, _namespace_metadata=None):
+ def from_table(
+ cls, table: "Table", _namespace_metadata: Optional["MetaData"] = None
+ ) -> "DropTableOp":
return cls(
table.name,
schema=table.schema,
table_kw={
"comment": table.comment,
- "info": table.info.copy(),
+ "info": dict(table.info),
"prefixes": list(table._prefixes),
**table.kwargs,
},
),
)
- def to_table(self, migration_context=None):
+ def to_table(
+ self, migration_context: Optional["MigrationContext"] = None
+ ) -> "Table":
if self._reverse:
cols_and_constraints = self._reverse.columns
else:
info=self.info.copy() if self.info else {},
prefixes=list(self.prefixes) if self.prefixes else [],
schema=self.schema,
- _constraints_included=bool(self._reverse)
- and self._reverse._constraints_included,
+ _constraints_included=self._reverse._constraints_included
+ if self._reverse
+ else False,
**self.table_kw,
)
return t
@classmethod
- def drop_table(cls, operations, table_name, schema=None, **kw):
+ def drop_table(
+ cls,
+ operations: "Operations",
+ table_name: str,
+ schema: Optional[str] = None,
+ **kw: Any
+ ) -> None:
r"""Issue a "drop table" instruction using the current
migration context.
class AlterTableOp(MigrateOperation):
"""Represent an alter table operation."""
- def __init__(self, table_name, schema=None):
+ def __init__(
+ self,
+ table_name: str,
+ schema: Optional[str] = None,
+ ) -> None:
self.table_name = table_name
self.schema = schema
class RenameTableOp(AlterTableOp):
"""Represent a rename table operation."""
- def __init__(self, old_table_name, new_table_name, schema=None):
+ def __init__(
+ self,
+ old_table_name: str,
+ new_table_name: str,
+ schema: Optional[str] = None,
+ ) -> None:
super(RenameTableOp, self).__init__(old_table_name, schema=schema)
self.new_table_name = new_table_name
@classmethod
def rename_table(
- cls, operations, old_table_name, new_table_name, schema=None
- ):
+ cls,
+ operations: "Operations",
+ old_table_name: str,
+ new_table_name: str,
+ schema: Optional[str] = None,
+ ) -> Optional["Table"]:
"""Emit an ALTER TABLE to rename a table.
:param old_table_name: old name.
"""Represent a COMMENT ON `table` operation."""
def __init__(
- self, table_name, comment, schema=None, existing_comment=None
- ):
+ self,
+ table_name: str,
+ comment: Optional[str],
+ schema: Optional[str] = None,
+ existing_comment: Optional[str] = None,
+ ) -> None:
self.table_name = table_name
self.comment = comment
self.existing_comment = existing_comment
@classmethod
def create_table_comment(
cls,
- operations,
- table_name,
- comment,
- existing_comment=None,
- schema=None,
- ):
+ operations: "Operations",
+ table_name: str,
+ comment: Optional[str],
+ existing_comment: None = None,
+ schema: Optional[str] = None,
+ ) -> Optional["Table"]:
"""Emit a COMMENT ON operation to set the comment for a table.
.. versionadded:: 1.0.6
class DropTableCommentOp(AlterTableOp):
"""Represent an operation to remove the comment from a table."""
- def __init__(self, table_name, schema=None, existing_comment=None):
+ def __init__(
+ self,
+ table_name: str,
+ schema: Optional[str] = None,
+ existing_comment: Optional[str] = None,
+ ) -> None:
self.table_name = table_name
self.existing_comment = existing_comment
self.schema = schema
@classmethod
def drop_table_comment(
- cls, operations, table_name, existing_comment=None, schema=None
- ):
+ cls,
+ operations: "Operations",
+ table_name: str,
+ existing_comment: Optional[str] = None,
+ schema: Optional[str] = None,
+ ) -> Optional["Table"]:
"""Issue a "drop table comment" operation to
remove an existing comment set on a table.
def __init__(
self,
- table_name,
- column_name,
- schema=None,
- existing_type=None,
- existing_server_default=False,
- existing_nullable=None,
- existing_comment=None,
- modify_nullable=None,
- modify_comment=False,
- modify_server_default=False,
- modify_name=None,
- modify_type=None,
+ table_name: str,
+ column_name: str,
+ schema: Optional[str] = None,
+ existing_type: Optional[Any] = None,
+ existing_server_default: Any = False,
+ existing_nullable: Optional[bool] = None,
+ existing_comment: Optional[str] = None,
+ modify_nullable: Optional[bool] = None,
+ modify_comment: Optional[Union[str, bool]] = False,
+ modify_server_default: Any = False,
+ modify_name: Optional[str] = None,
+ modify_type: Optional[Any] = None,
**kw
- ):
+ ) -> None:
super(AlterColumnOp, self).__init__(table_name, schema=schema)
self.column_name = column_name
self.existing_type = existing_type
self.modify_type = modify_type
self.kw = kw
- def to_diff_tuple(self):
+ def to_diff_tuple(self) -> Any:
col_diff = []
schema, tname, cname = self.schema, self.table_name, self.column_name
return col_diff
- def has_changes(self):
+ def has_changes(self) -> bool:
hc1 = (
self.modify_nullable is not None
or self.modify_server_default is not False
else:
return False
- def reverse(self):
+ def reverse(self) -> "AlterColumnOp":
kw = self.kw.copy()
kw["existing_type"] = self.existing_type
@classmethod
def alter_column(
cls,
- operations,
- table_name,
- column_name,
- nullable=None,
- comment=False,
- server_default=False,
- new_column_name=None,
- type_=None,
- existing_type=None,
- existing_server_default=False,
- existing_nullable=None,
- existing_comment=None,
- schema=None,
+ operations: "Operations",
+ table_name: str,
+ column_name: str,
+ nullable: Optional[bool] = None,
+ comment: Optional[Union[str, bool]] = False,
+ server_default: Any = False,
+ new_column_name: Optional[str] = None,
+ type_: Optional[Union["TypeEngine", Type["TypeEngine"]]] = None,
+ existing_type: Optional[
+ Union["TypeEngine", Type["TypeEngine"]]
+ ] = None,
+ existing_server_default: Optional[
+ Union[str, bool, "Identity", "Computed"]
+ ] = False,
+ existing_nullable: Optional[bool] = None,
+ existing_comment: Optional[str] = None,
+ schema: Optional[str] = None,
**kw
- ):
+ ) -> Optional["Table"]:
r"""Issue an "alter column" instruction using the
current migration context.
@classmethod
def batch_alter_column(
cls,
- operations,
- column_name,
- nullable=None,
- comment=False,
- server_default=False,
- new_column_name=None,
- type_=None,
- existing_type=None,
- existing_server_default=False,
- existing_nullable=None,
- existing_comment=None,
- insert_before=None,
- insert_after=None,
+ operations: "BatchOperations",
+ column_name: str,
+ nullable: Optional[bool] = None,
+ comment: bool = False,
+ server_default: Union["Function", bool] = False,
+ new_column_name: Optional[str] = None,
+ type_: Optional[Union["TypeEngine", Type["TypeEngine"]]] = None,
+ existing_type: Optional[
+ Union["TypeEngine", Type["TypeEngine"]]
+ ] = None,
+ existing_server_default: bool = False,
+ existing_nullable: None = None,
+ existing_comment: None = None,
+ insert_before: None = None,
+ insert_after: None = None,
**kw
- ):
+ ) -> Optional["Table"]:
"""Issue an "alter column" instruction using the current
batch migration context.
class AddColumnOp(AlterTableOp):
"""Represent an add column operation."""
- def __init__(self, table_name, column, schema=None, **kw):
+ def __init__(
+ self,
+ table_name: str,
+ column: "Column",
+ schema: Optional[str] = None,
+ **kw
+ ) -> None:
super(AddColumnOp, self).__init__(table_name, schema=schema)
self.column = column
self.kw = kw
- def reverse(self):
+ def reverse(self) -> "DropColumnOp":
return DropColumnOp.from_column_and_tablename(
self.schema, self.table_name, self.column
)
- def to_diff_tuple(self):
+ def to_diff_tuple(
+ self,
+ ) -> Tuple[str, Optional[str], str, "Column"]:
return ("add_column", self.schema, self.table_name, self.column)
- def to_column(self):
+ def to_column(self) -> "Column":
return self.column
@classmethod
- def from_column(cls, col):
+ def from_column(cls, col: "Column") -> "AddColumnOp":
return cls(col.table.name, col, schema=col.table.schema)
@classmethod
- def from_column_and_tablename(cls, schema, tname, col):
+ def from_column_and_tablename(
+ cls,
+ schema: Optional[str],
+ tname: str,
+ col: "Column",
+ ) -> "AddColumnOp":
return cls(tname, col, schema=schema)
@classmethod
- def add_column(cls, operations, table_name, column, schema=None):
+ def add_column(
+ cls,
+ operations: "Operations",
+ table_name: str,
+ column: "Column",
+ schema: Optional[str] = None,
+ ) -> Optional["Table"]:
"""Issue an "add column" instruction using the current
migration context.
@classmethod
def batch_add_column(
- cls, operations, column, insert_before=None, insert_after=None
- ):
+ cls,
+ operations: "BatchOperations",
+ column: "Column",
+ insert_before: Optional[str] = None,
+ insert_after: Optional[str] = None,
+ ) -> Optional["Table"]:
"""Issue an "add column" instruction using the current
batch migration context.
"""Represent a drop column operation."""
def __init__(
- self, table_name, column_name, schema=None, _reverse=None, **kw
- ):
+ self,
+ table_name: str,
+ column_name: str,
+ schema: Optional[str] = None,
+ _reverse: Optional["AddColumnOp"] = None,
+ **kw
+ ) -> None:
super(DropColumnOp, self).__init__(table_name, schema=schema)
self.column_name = column_name
self.kw = kw
self._reverse = _reverse
- def to_diff_tuple(self):
+ def to_diff_tuple(
+ self,
+ ) -> Tuple[str, Optional[str], str, "Column"]:
return (
"remove_column",
self.schema,
self.to_column(),
)
- def reverse(self):
+ def reverse(self) -> "AddColumnOp":
if self._reverse is None:
raise ValueError(
"operation is not reversible; "
)
@classmethod
- def from_column_and_tablename(cls, schema, tname, col):
+ def from_column_and_tablename(
+ cls,
+ schema: Optional[str],
+ tname: str,
+ col: "Column",
+ ) -> "DropColumnOp":
return cls(
tname,
col.name,
_reverse=AddColumnOp.from_column_and_tablename(schema, tname, col),
)
- def to_column(self, migration_context=None):
+ def to_column(
+ self, migration_context: Optional["MigrationContext"] = None
+ ) -> "Column":
if self._reverse is not None:
return self._reverse.column
schema_obj = schemaobj.SchemaObjects(migration_context)
@classmethod
def drop_column(
- cls, operations, table_name, column_name, schema=None, **kw
- ):
+ cls,
+ operations: "Operations",
+ table_name: str,
+ column_name: str,
+ schema: Optional[str] = None,
+ **kw
+ ) -> Optional["Table"]:
"""Issue a "drop column" instruction using the current
migration context.
return operations.invoke(op)
@classmethod
- def batch_drop_column(cls, operations, column_name, **kw):
+ def batch_drop_column(
+ cls, operations: "BatchOperations", column_name: str, **kw
+ ) -> Optional["Table"]:
"""Issue a "drop column" instruction using the current
batch migration context.
class BulkInsertOp(MigrateOperation):
"""Represent a bulk insert operation."""
- def __init__(self, table, rows, multiinsert=True):
+ def __init__(
+ self,
+ table: Union["Table", "TableClause"],
+ rows: List[dict],
+ multiinsert: bool = True,
+ ) -> None:
self.table = table
self.rows = rows
self.multiinsert = multiinsert
@classmethod
- def bulk_insert(cls, operations, table, rows, multiinsert=True):
+ def bulk_insert(
+ cls,
+ operations: "Operations",
+ table: Union["Table", "TableClause"],
+ rows: List[dict],
+ multiinsert: bool = True,
+ ) -> None:
"""Issue a "bulk insert" operation using the current
migration context.
class ExecuteSQLOp(MigrateOperation):
"""Represent an execute SQL operation."""
- def __init__(self, sqltext, execution_options=None):
+ def __init__(
+ self,
+ sqltext: Union["Update", str, "Insert", "TextClause"],
+ execution_options: None = None,
+ ) -> None:
self.sqltext = sqltext
self.execution_options = execution_options
@classmethod
- def execute(cls, operations, sqltext, execution_options=None):
+ def execute(
+ cls,
+ operations: "Operations",
+ sqltext: Union[str, "TextClause", "Update"],
+ execution_options: None = None,
+ ) -> Optional["Table"]:
r"""Execute the given SQL using the current migration context.
The given SQL can be a plain string, e.g.::
class OpContainer(MigrateOperation):
"""Represent a sequence of operations operation."""
- def __init__(self, ops=()):
- self.ops = ops
+ def __init__(self, ops: Sequence[MigrateOperation] = ()) -> None:
+ self.ops = list(ops)
- def is_empty(self):
+ def is_empty(self) -> bool:
return not self.ops
- def as_diffs(self):
+ def as_diffs(self) -> Any:
return list(OpContainer._ops_as_diffs(self))
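# A minimal sketch, assuming ``add_col`` and ``drop_idx`` are existing
# AddColumnOp / DropIndexOp instances; nested containers flatten recursively
# into plain autogenerate-style diff tuples:
#
#     UpgradeOps(
#         ops=[ModifyTableOps("t", ops=[add_col, drop_idx])]
#     ).as_diffs()
#     # -> [("add_column", None, "t", <Column>), ("remove_index", <Index>)]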
@classmethod
- def _ops_as_diffs(cls, migrations):
+ def _ops_as_diffs(
+ cls, migrations: "OpContainer"
+ ) -> Iterator[Tuple[Any, ...]]:
for op in migrations.ops:
if hasattr(op, "ops"):
- for sub_op in cls._ops_as_diffs(op):
+ for sub_op in cls._ops_as_diffs(cast("OpContainer", op)):
yield sub_op
else:
yield op.to_diff_tuple()
class ModifyTableOps(OpContainer):
"""Contains a sequence of operations that all apply to a single Table."""
- def __init__(self, table_name, ops, schema=None):
+ def __init__(
+ self,
+ table_name: str,
+ ops: Sequence[MigrateOperation],
+ schema: Optional[str] = None,
+ ) -> None:
super(ModifyTableOps, self).__init__(ops)
self.table_name = table_name
self.schema = schema
- def reverse(self):
+ def reverse(self) -> "ModifyTableOps":
return ModifyTableOps(
self.table_name,
ops=list(reversed([op.reverse() for op in self.ops])),
"""
- def __init__(self, ops=(), upgrade_token="upgrades"):
+ def __init__(
+ self,
+ ops: Sequence[MigrateOperation] = (),
+ upgrade_token: str = "upgrades",
+ ) -> None:
super(UpgradeOps, self).__init__(ops=ops)
self.upgrade_token = upgrade_token
- def reverse_into(self, downgrade_ops):
- downgrade_ops.ops[:] = list(
+ def reverse_into(self, downgrade_ops: "DowngradeOps") -> "DowngradeOps":
+ downgrade_ops.ops[:] = list( # type:ignore[index]
reversed([op.reverse() for op in self.ops])
)
return downgrade_ops
- def reverse(self):
+ def reverse(self) -> "DowngradeOps":
return self.reverse_into(DowngradeOps(ops=[]))
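# i.e. reversal flips each operation and reverses their order, so an upgrade
# of [create_table_op, create_index_op] becomes a downgrade of
# [drop_index_op, drop_table_op].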
"""
- def __init__(self, ops=(), downgrade_token="downgrades"):
+ def __init__(
+ self,
+ ops: Sequence[MigrateOperation] = (),
+ downgrade_token: str = "downgrades",
+ ) -> None:
super(DowngradeOps, self).__init__(ops=ops)
self.downgrade_token = downgrade_token
"""
+ _needs_render: Optional[bool]
+
def __init__(
self,
- rev_id,
- upgrade_ops,
- downgrade_ops,
- message=None,
- imports=set(),
- head=None,
- splice=None,
- branch_label=None,
- version_path=None,
- depends_on=None,
- ):
+ rev_id: Optional[str],
+ upgrade_ops: "UpgradeOps",
+ downgrade_ops: "DowngradeOps",
+ message: Optional[str] = None,
+ imports: Set[str] = set(),
+ head: Optional[str] = None,
+ splice: Optional[bool] = None,
+ branch_label: Optional[str] = None,
+ version_path: Optional[str] = None,
+ depends_on: Optional[Union[str, Sequence[str]]] = None,
+ ) -> None:
self.rev_id = rev_id
self.message = message
self.imports = imports
assert isinstance(elem, DowngradeOps)
@property
- def upgrade_ops_list(self):
+ def upgrade_ops_list(self) -> List["UpgradeOps"]:
"""A list of :class:`.UpgradeOps` instances.
This is used in place of the :attr:`.MigrationScript.upgrade_ops`
return self._upgrade_ops
@property
- def downgrade_ops_list(self):
+ def downgrade_ops_list(self) -> List["DowngradeOps"]:
"""A list of :class:`.DowngradeOps` instances.
This is used in place of the :attr:`.MigrationScript.downgrade_ops`
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import Union
+
from sqlalchemy import schema as sa_schema
from sqlalchemy.sql.schema import Column
from sqlalchemy.sql.schema import Constraint
from ..util import sqla_compat
from ..util.compat import string_types
+if TYPE_CHECKING:
+ from sqlalchemy.sql.elements import ColumnElement
+ from sqlalchemy.sql.elements import TextClause
+ from sqlalchemy.sql.schema import CheckConstraint
+ from sqlalchemy.sql.schema import ForeignKey
+ from sqlalchemy.sql.schema import ForeignKeyConstraint
+ from sqlalchemy.sql.schema import MetaData
+ from sqlalchemy.sql.schema import PrimaryKeyConstraint
+ from sqlalchemy.sql.schema import Table
+ from sqlalchemy.sql.schema import UniqueConstraint
+ from sqlalchemy.sql.type_api import TypeEngine
+
+ from ..runtime.migration import MigrationContext
+
class SchemaObjects:
- def __init__(self, migration_context=None):
+ def __init__(
+ self, migration_context: Optional["MigrationContext"] = None
+ ) -> None:
self.migration_context = migration_context
- def primary_key_constraint(self, name, table_name, cols, schema=None):
+ def primary_key_constraint(
+ self,
+ name: Optional[str],
+ table_name: str,
+ cols: Sequence[str],
+ schema: Optional[str] = None,
+ **dialect_kw
+ ) -> "PrimaryKeyConstraint":
m = self.metadata()
columns = [sa_schema.Column(n, NULLTYPE) for n in cols]
t = sa_schema.Table(table_name, m, *columns, schema=schema)
- p = sa_schema.PrimaryKeyConstraint(*[t.c[n] for n in cols], name=name)
+ p = sa_schema.PrimaryKeyConstraint(
+ *[t.c[n] for n in cols], name=name, **dialect_kw
+ )
return p
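# A minimal sketch: a throwaway Table supplies the bound Column objects that
# PrimaryKeyConstraint requires; no migration context is needed:
#
#     pk = SchemaObjects().primary_key_constraint("pk_account", "account", ["id"])
#     # pk.name == "pk_account", attached to a dummy Table "account"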
def foreign_key_constraint(
self,
- name,
- source,
- referent,
- local_cols,
- remote_cols,
- onupdate=None,
- ondelete=None,
- deferrable=None,
- source_schema=None,
- referent_schema=None,
- initially=None,
- match=None,
+ name: Optional[str],
+ source: str,
+ referent: str,
+ local_cols: List[str],
+ remote_cols: List[str],
+ onupdate: Optional[str] = None,
+ ondelete: Optional[str] = None,
+ deferrable: Optional[bool] = None,
+ source_schema: Optional[str] = None,
+ referent_schema: Optional[str] = None,
+ initially: Optional[str] = None,
+ match: Optional[str] = None,
**dialect_kw
- ):
+ ) -> "ForeignKeyConstraint":
m = self.metadata()
if source == referent and source_schema == referent_schema:
t1_cols = local_cols + remote_cols
return f
- def unique_constraint(self, name, source, local_cols, schema=None, **kw):
+ def unique_constraint(
+ self,
+ name: Optional[str],
+ source: str,
+ local_cols: Sequence[str],
+ schema: Optional[str] = None,
+ **kw
+ ) -> "UniqueConstraint":
t = sa_schema.Table(
source,
self.metadata(),
t.append_constraint(uq)
return uq
- def check_constraint(self, name, source, condition, schema=None, **kw):
+ def check_constraint(
+ self,
+ name: Optional[str],
+ source: str,
+ condition: Union["TextClause", "ColumnElement[Any]"],
+ schema: Optional[str] = None,
+ **kw
+ ) -> Union["CheckConstraint"]:
t = sa_schema.Table(
source,
self.metadata(),
t.append_constraint(ck)
return ck
- def generic_constraint(self, name, table_name, type_, schema=None, **kw):
+ def generic_constraint(
+ self,
+ name: Optional[str],
+ table_name: str,
+ type_: Optional[str],
+ schema: Optional[str] = None,
+ **kw
+ ) -> Any:
t = self.table(table_name, schema=schema)
- types = {
+ types: Dict[Optional[str], Any] = {
"foreignkey": lambda name: sa_schema.ForeignKeyConstraint(
[], [], name=name
),
t.append_constraint(const)
return const
- def metadata(self):
+ def metadata(self) -> "MetaData":
kw = {}
if (
self.migration_context is not None
kw["naming_convention"] = mt.naming_convention
return sa_schema.MetaData(**kw)
- def table(self, name, *columns, **kw):
+ def table(self, name: str, *columns, **kw) -> "Table":
m = self.metadata()
cols = [
self._ensure_table_for_fk(m, f)
return t
- def column(self, name, type_, **kw):
+ def column(self, name: str, type_: "TypeEngine", **kw) -> "Column":
return sa_schema.Column(name, type_, **kw)
- def index(self, name, tablename, columns, schema=None, **kw):
+ def index(
+ self,
+ name: str,
+ tablename: Optional[str],
+ columns: Sequence[Union[str, "TextClause", "ColumnElement[Any]"]],
+ schema: Optional[str] = None,
+ **kw
+ ) -> "Index":
t = sa_schema.Table(
tablename or "no_table",
self.metadata(),
)
return idx
- def _parse_table_key(self, table_key):
+ def _parse_table_key(self, table_key: str) -> Tuple[Optional[str], str]:
if "." in table_key:
tokens = table_key.split(".")
- sname = ".".join(tokens[0:-1])
+ sname: Optional[str] = ".".join(tokens[0:-1])
tname = tokens[-1]
else:
tname = table_key
sname = None
return (sname, tname)
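# e.g. "myschema.mytable" -> ("myschema", "mytable") and "mytable" ->
# (None, "mytable"); with a dotted schema, everything up to the last token
# is treated as the schema: "my.schema.mytable" -> ("my.schema", "mytable").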
- def _ensure_table_for_fk(self, metadata, fk):
+ def _ensure_table_for_fk(
+ self, metadata: "MetaData", fk: "ForeignKey"
+ ) -> None:
"""create a placeholder Table object for the referent of a
ForeignKey.
"""
- if isinstance(fk._colspec, string_types):
- table_key, cname = fk._colspec.rsplit(".", 1)
+ if isinstance(fk._colspec, string_types): # type:ignore[attr-defined]
+ table_key, cname = fk._colspec.rsplit( # type:ignore[attr-defined]
+ ".", 1
+ )
sname, tname = self._parse_table_key(table_key)
if table_key not in metadata.tables:
rel_t = sa_schema.Table(tname, metadata, schema=sname)
+from typing import TYPE_CHECKING
+
from sqlalchemy import schema as sa_schema
from . import ops
from .base import Operations
from ..util.sqla_compat import _copy
+if TYPE_CHECKING:
+ from sqlalchemy.sql.schema import Table
+
@Operations.implementation_for(ops.AlterColumnOp)
-def alter_column(operations, operation):
+def alter_column(
+ operations: "Operations", operation: "ops.AlterColumnOp"
+) -> None:
compiler = operations.impl.dialect.statement_compiler(
operations.impl.dialect, None
@Operations.implementation_for(ops.DropTableOp)
-def drop_table(operations, operation):
+def drop_table(operations: "Operations", operation: "ops.DropTableOp") -> None:
operations.impl.drop_table(
operation.to_table(operations.migration_context)
)
@Operations.implementation_for(ops.DropColumnOp)
-def drop_column(operations, operation):
+def drop_column(
+ operations: "Operations", operation: "ops.DropColumnOp"
+) -> None:
column = operation.to_column(operations.migration_context)
operations.impl.drop_column(
operation.table_name, column, schema=operation.schema, **operation.kw
@Operations.implementation_for(ops.CreateIndexOp)
-def create_index(operations, operation):
+def create_index(
+ operations: "Operations", operation: "ops.CreateIndexOp"
+) -> None:
idx = operation.to_index(operations.migration_context)
operations.impl.create_index(idx)
@Operations.implementation_for(ops.DropIndexOp)
-def drop_index(operations, operation):
+def drop_index(operations: "Operations", operation: "ops.DropIndexOp") -> None:
operations.impl.drop_index(
operation.to_index(operations.migration_context)
)
@Operations.implementation_for(ops.CreateTableOp)
-def create_table(operations, operation):
+def create_table(
+ operations: "Operations", operation: "ops.CreateTableOp"
+) -> "Table":
table = operation.to_table(operations.migration_context)
operations.impl.create_table(table)
return table
@Operations.implementation_for(ops.RenameTableOp)
-def rename_table(operations, operation):
+def rename_table(
+ operations: "Operations", operation: "ops.RenameTableOp"
+) -> None:
operations.impl.rename_table(
operation.table_name, operation.new_table_name, schema=operation.schema
)
@Operations.implementation_for(ops.CreateTableCommentOp)
-def create_table_comment(operations, operation):
+def create_table_comment(
+ operations: "Operations", operation: "ops.CreateTableCommentOp"
+) -> None:
table = operation.to_table(operations.migration_context)
operations.impl.create_table_comment(table)
@Operations.implementation_for(ops.DropTableCommentOp)
-def drop_table_comment(operations, operation):
+def drop_table_comment(
+ operations: "Operations", operation: "ops.DropTableCommentOp"
+) -> None:
table = operation.to_table(operations.migration_context)
operations.impl.drop_table_comment(table)
@Operations.implementation_for(ops.AddColumnOp)
-def add_column(operations, operation):
+def add_column(operations: "Operations", operation: "ops.AddColumnOp") -> None:
table_name = operation.table_name
column = operation.column
schema = operation.schema
@Operations.implementation_for(ops.AddConstraintOp)
-def create_constraint(operations, operation):
+def create_constraint(
+ operations: "Operations", operation: "ops.AddConstraintOp"
+) -> None:
operations.impl.add_constraint(
operation.to_constraint(operations.migration_context)
)
@Operations.implementation_for(ops.DropConstraintOp)
-def drop_constraint(operations, operation):
+def drop_constraint(
+ operations: "Operations", operation: "ops.DropConstraintOp"
+) -> None:
operations.impl.drop_constraint(
operations.schema_obj.generic_constraint(
operation.constraint_name,
@Operations.implementation_for(ops.BulkInsertOp)
-def bulk_insert(operations, operation):
+def bulk_insert(
+ operations: "Operations", operation: "ops.BulkInsertOp"
+) -> None:
operations.impl.bulk_insert(
operation.table, operation.rows, multiinsert=operation.multiinsert
)
@Operations.implementation_for(ops.ExecuteSQLOp)
-def execute_sql(operations, operation):
+def execute_sql(
+ operations: "Operations", operation: "ops.ExecuteSQLOp"
+) -> None:
operations.migration_context.impl.execute(
operation.sqltext, execution_options=operation.execution_options
)
+from typing import Callable
+from typing import ContextManager
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import overload
+from typing import TextIO
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import Union
+
from .migration import MigrationContext
from .. import util
from ..operations import Operations
+if TYPE_CHECKING:
+ from typing import Literal
+
+ from sqlalchemy.engine.base import Connection
+ from sqlalchemy.sql.schema import MetaData
+
+ from .migration import _ProxyTransaction
+ from ..config import Config
+ from ..script.base import ScriptDirectory
+
+_RevNumber = Optional[Union[str, Tuple[str, ...]]]
+
class EnvironmentContext(util.ModuleClsProxy):
"""
- _migration_context = None
+ _migration_context: Optional["MigrationContext"] = None
- config = None
+ config: "Config" = None # type:ignore[assignment]
"""An instance of :class:`.Config` representing the
configuration file contents as well as other variables
set programmatically within it."""
- script = None
+ script: "ScriptDirectory" = None # type:ignore[assignment]
"""An instance of :class:`.ScriptDirectory` which provides
programmatic access to version files within the ``versions/``
directory.
"""
- def __init__(self, config, script, **kw):
+ def __init__(
+ self, config: "Config", script: "ScriptDirectory", **kw
+ ) -> None:
r"""Construct a new :class:`.EnvironmentContext`.
:param config: a :class:`.Config` instance.
self.script = script
self.context_opts = kw
- def __enter__(self):
+ def __enter__(self) -> "EnvironmentContext":
"""Establish a context which provides a
:class:`.EnvironmentContext` object to
env.py scripts.
self._install_proxy()
return self
- def __exit__(self, *arg, **kw):
+ def __exit__(self, *arg, **kw) -> None:
self._remove_proxy()
- def is_offline_mode(self):
+ def is_offline_mode(self) -> bool:
"""Return True if the current migrations environment
is running in "offline mode".
"""
return self.get_context().impl.transactional_ddl
- def requires_connection(self):
+ def requires_connection(self) -> bool:
return not self.is_offline_mode()
- def get_head_revision(self):
+ def get_head_revision(self) -> _RevNumber:
"""Return the hex identifier of the 'head' script revision.
If the script directory has multiple heads, this
"""
return self.script.as_revision_number("head")
- def get_head_revisions(self):
+ def get_head_revisions(self) -> _RevNumber:
"""Return the hex identifier of the 'heads' script revision(s).
This returns a tuple containing the version number of all
"""
return self.script.as_revision_number("heads")
- def get_starting_revision_argument(self):
+ def get_starting_revision_argument(self) -> _RevNumber:
"""Return the 'starting revision' argument,
if the revision was passed using ``start:end``.
"No starting revision argument is available."
)
- def get_revision_argument(self):
+ def get_revision_argument(self) -> _RevNumber:
"""Get the 'destination' revision argument.
This is typically the argument passed to the
self.context_opts["destination_rev"]
)
- def get_tag_argument(self):
+ def get_tag_argument(self) -> Optional[str]:
"""Return the value passed for the ``--tag`` argument, if any.
The ``--tag`` argument is not used directly by Alembic,
"""
return self.context_opts.get("tag", None)
- def get_x_argument(self, as_dictionary=False):
+ @overload
+ def get_x_argument( # type:ignore[misc]
+ self, as_dictionary: "Literal[False]" = ...
+ ) -> List[str]:
+ ...
+
+ @overload
+ def get_x_argument( # type:ignore[misc]
+ self, as_dictionary: "Literal[True]" = ...
+ ) -> Dict[str, str]:
+ ...
+
+ def get_x_argument(self, as_dictionary: bool = False):
"""Return the value(s) passed for the ``-x`` argument, if any.
    The ``-x`` argument is an open-ended flag that allows any user-defined
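# NOTE: a worked example of how the ``-x`` values surface in env.py.  Given
# ``alembic -x data=somedata -x dialect=mysql upgrade head``:
#
#     context.get_x_argument()                    # ["data=somedata", "dialect=mysql"]
#     context.get_x_argument(as_dictionary=True)  # {"data": "somedata", "dialect": "mysql"}
#
# The Literal overloads above let a type checker narrow the return type
# (List[str] vs. Dict[str, str]) from the value of ``as_dictionary``.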
def configure(
self,
- connection=None,
- url=None,
- dialect_name=None,
- dialect_opts=None,
- transactional_ddl=None,
- transaction_per_migration=False,
- output_buffer=None,
- starting_rev=None,
- tag=None,
- template_args=None,
- render_as_batch=False,
- target_metadata=None,
- include_name=None,
- include_object=None,
- include_schemas=False,
- process_revision_directives=None,
- compare_type=False,
- compare_server_default=False,
- render_item=None,
- literal_binds=False,
- upgrade_token="upgrades",
- downgrade_token="downgrades",
- alembic_module_prefix="op.",
- sqlalchemy_module_prefix="sa.",
- user_module_prefix=None,
- on_version_apply=None,
+ connection: Optional["Connection"] = None,
+ url: Optional[str] = None,
+ dialect_name: Optional[str] = None,
+ dialect_opts: Optional[dict] = None,
+ transactional_ddl: Optional[bool] = None,
+ transaction_per_migration: bool = False,
+ output_buffer: Optional[TextIO] = None,
+ starting_rev: Optional[str] = None,
+ tag: Optional[str] = None,
+ template_args: Optional[dict] = None,
+ render_as_batch: bool = False,
+ target_metadata: Optional["MetaData"] = None,
+ include_name: Optional[Callable] = None,
+ include_object: Optional[Callable] = None,
+ include_schemas: bool = False,
+ process_revision_directives: Optional[Callable] = None,
+ compare_type: bool = False,
+ compare_server_default: bool = False,
+ render_item: Optional[Callable] = None,
+ literal_binds: bool = False,
+ upgrade_token: str = "upgrades",
+ downgrade_token: str = "downgrades",
+ alembic_module_prefix: str = "op.",
+ sqlalchemy_module_prefix: str = "sa.",
+ user_module_prefix: Optional[str] = None,
+ on_version_apply: Optional[Callable] = None,
**kw
- ):
+ ) -> None:
"""Configure a :class:`.MigrationContext` within this
:class:`.EnvironmentContext` which will provide database
connectivity and other configuration to a series of
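# NOTE: a minimal sketch of typical "online" configuration from an env.py
# script; ``connectable`` and ``target_metadata`` are assumed to be defined
# earlier in that env.py.
#
#     from alembic import context
#
#     with connectable.connect() as connection:
#         context.configure(
#             connection=connection,
#             target_metadata=target_metadata,
#         )
#         with context.begin_transaction():
#             context.run_migrations()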
opts=opts,
)
- def run_migrations(self, **kw):
+ def run_migrations(self, **kw) -> None:
"""Run migrations as determined by the current command line
configuration
as well as versioning information present (or not) in the current
first been made available via :meth:`.configure`.
"""
+ assert self._migration_context is not None
with Operations.context(self._migration_context):
self.get_context().run_migrations(**kw)
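# NOTE: the "offline" counterpart, mirroring the stock env.py template:
# configure() receives a URL rather than a Connection, and statements are
# rendered as SQL text.  ``url`` is assumed to come from the Config.
#
#     context.configure(url=url, literal_binds=True)
#     with context.begin_transaction():
#         context.run_migrations()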
"""
self.get_context().impl.static_output(text)
- def begin_transaction(self):
+ def begin_transaction(
+ self,
+ ) -> Union["_ProxyTransaction", ContextManager]:
"""Return a context manager that will
enclose an operation within a "transaction",
as defined by the environment's offline
return self.get_context().begin_transaction()
- def get_context(self):
+ def get_context(self) -> "MigrationContext":
"""Return the current :class:`.MigrationContext` object.
If :meth:`.EnvironmentContext.configure` has not been
from contextlib import contextmanager
import logging
import sys
+from typing import Any
+from typing import cast
+from typing import Collection
+from typing import ContextManager
+from typing import Dict
+from typing import Iterator
+from typing import List
+from typing import Optional
+from typing import Set
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import Union
from sqlalchemy import Column
from sqlalchemy import literal_column
from ..util import sqla_compat
from ..util.compat import EncodedIO
+if TYPE_CHECKING:
+ from sqlalchemy.engine import Dialect
+ from sqlalchemy.engine.base import Connection
+ from sqlalchemy.engine.base import Transaction
+ from sqlalchemy.engine.mock import MockConnection
+
+ from .environment import EnvironmentContext
+ from ..config import Config
+ from ..script.base import Script
+ from ..script.base import ScriptDirectory
+ from ..script.revision import Revision
+ from ..script.revision import RevisionMap
+
log = logging.getLogger(__name__)
class _ProxyTransaction:
- def __init__(self, migration_context):
+ def __init__(self, migration_context: "MigrationContext") -> None:
self.migration_context = migration_context
@property
- def _proxied_transaction(self):
+ def _proxied_transaction(self) -> Optional["Transaction"]:
return self.migration_context._transaction
- def rollback(self):
- self._proxied_transaction.rollback()
+ def rollback(self) -> None:
+ t = self._proxied_transaction
+ assert t is not None
+ t.rollback()
self.migration_context._transaction = None
- def commit(self):
- self._proxied_transaction.commit()
+ def commit(self) -> None:
+ t = self._proxied_transaction
+ assert t is not None
+ t.commit()
self.migration_context._transaction = None
- def __enter__(self):
+ def __enter__(self) -> "_ProxyTransaction":
return self
- def __exit__(self, type_, value, traceback):
+ def __exit__(self, type_: None, value: None, traceback: None) -> None:
if self._proxied_transaction is not None:
self._proxied_transaction.__exit__(type_, value, traceback)
self.migration_context._transaction = None
"""
- def __init__(self, dialect, connection, opts, environment_context=None):
+ def __init__(
+ self,
+ dialect: "Dialect",
+ connection: Optional["Connection"],
+ opts: Dict[str, Any],
+ environment_context: Optional["EnvironmentContext"] = None,
+ ) -> None:
self.environment_context = environment_context
self.opts = opts
self.dialect = dialect
- self.script = opts.get("script")
- as_sql = opts.get("as_sql", False)
+ self.script: Optional["ScriptDirectory"] = opts.get("script")
+ as_sql: bool = opts.get("as_sql", False)
transactional_ddl = opts.get("transactional_ddl")
self._transaction_per_migration = opts.get(
"transaction_per_migration", False
)
self.on_version_apply_callbacks = opts.get("on_version_apply", ())
- self._transaction = None
+ self._transaction: Optional["Transaction"] = None
if as_sql:
- self.connection = self._stdout_connection(connection)
+ self.connection = cast(
+ Optional["Connection"], self._stdout_connection(connection)
+ )
assert self.connection is not None
self._in_external_transaction = False
else:
if "output_encoding" in opts:
self.output_buffer = EncodedIO(
- opts.get("output_buffer") or sys.stdout,
+ opts.get("output_buffer")
+ or sys.stdout, # type:ignore[arg-type]
opts["output_encoding"],
)
else:
)
)
- self._start_from_rev = opts.get("starting_rev")
+ self._start_from_rev: Optional[str] = opts.get("starting_rev")
self.impl = ddl.DefaultImpl.get_by_dialect(dialect)(
dialect,
self.connection,
@classmethod
def configure(
cls,
- connection=None,
- url=None,
- dialect_name=None,
- dialect=None,
- environment_context=None,
- dialect_opts=None,
- opts=None,
- ):
+ connection: Optional["Connection"] = None,
+ url: Optional[str] = None,
+ dialect_name: Optional[str] = None,
+ dialect: Optional["Dialect"] = None,
+ environment_context: Optional["EnvironmentContext"] = None,
+ dialect_opts: Optional[Dict[str, str]] = None,
+ opts: Optional[Any] = None,
+ ) -> "MigrationContext":
"""Create a new :class:`.MigrationContext`.
This is a factory method usually called
dialect = connection.dialect
elif url:
- url = sqla_url.make_url(url)
- dialect = url.get_dialect()(**dialect_opts)
+ url_obj = sqla_url.make_url(url)
+ dialect = url_obj.get_dialect()(**dialect_opts)
elif dialect_name:
- url = sqla_url.make_url("%s://" % dialect_name)
- dialect = url.get_dialect()(**dialect_opts)
+ url_obj = sqla_url.make_url("%s://" % dialect_name)
+ dialect = url_obj.get_dialect()(**dialect_opts)
elif not dialect:
raise Exception("Connection, url, or dialect_name is required.")
-
+ assert dialect is not None
return MigrationContext(dialect, connection, opts, environment_context)
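# NOTE: a minimal sketch of using the factory directly, e.g. to inspect the
# stamped revision outside of an env.py run; the SQLite URL is illustrative.
#
#     from sqlalchemy import create_engine
#     from alembic.runtime.migration import MigrationContext
#
#     engine = create_engine("sqlite://")
#     with engine.connect() as conn:
#         mc = MigrationContext.configure(conn)
#         current = mc.get_current_revision()  # None if never stamped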
@contextmanager
- def autocommit_block(self):
+ def autocommit_block(self) -> Iterator[None]:
"""Enter an "autocommit" block, for databases that support AUTOCOMMIT
isolation levels.
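# NOTE: autocommit_block is normally reached through the op proxy inside a
# migration, for DDL that cannot run in a transaction, e.g. on PostgreSQL:
#
#     with op.get_context().autocommit_block():
#         op.execute("CREATE INDEX CONCURRENTLY ix_a_data ON a (data)")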
self._transaction = None
if not self.as_sql:
+ assert self.connection is not None
current_level = self.connection.get_isolation_level()
base_connection = self.connection
yield
finally:
if not self.as_sql:
+ assert self.connection is not None
self.connection.execution_options(
isolation_level=current_level
)
self.impl.emit_begin()
elif _in_connection_transaction:
+ assert self.connection is not None
self._transaction = self.connection.begin()
- def begin_transaction(self, _per_migration=False):
+ def begin_transaction(
+ self, _per_migration: bool = False
+ ) -> Union["_ProxyTransaction", ContextManager]:
"""Begin a logical transaction for migration operations.
This method is used within an ``env.py`` script to demarcate where
if in_transaction:
return do_nothing()
else:
+ assert self.connection is not None
self._transaction = (
sqla_compat._safe_begin_connection_transaction(
self.connection
return begin_commit()
else:
+ assert self.connection is not None
self._transaction = sqla_compat._safe_begin_connection_transaction(
self.connection
)
return _ProxyTransaction(self)
- def get_current_revision(self):
+ def get_current_revision(self) -> Optional[str]:
"""Return the current revision, usually that which is present
in the ``alembic_version`` table in the database.
else:
return heads[0]
- def get_current_heads(self):
+ def get_current_heads(self) -> Tuple[str, ...]:
"""Return a tuple of the current 'head versions' that are represented
in the target database.
"""
if self.as_sql:
- start_from_rev = self._start_from_rev
+ start_from_rev: Any = self._start_from_rev
if start_from_rev == "base":
start_from_rev = None
elif start_from_rev is not None and self.script:
)
if not self._has_version_table():
return ()
+ assert self.connection is not None
return tuple(
row[0] for row in self.connection.execute(self._version.select())
)
- def _ensure_version_table(self, purge=False):
+ def _ensure_version_table(self, purge: bool = False) -> None:
with sqla_compat._ensure_scope_for_ddl(self.connection):
self._version.create(self.connection, checkfirst=True)
if purge:
+ assert self.connection is not None
self.connection.execute(self._version.delete())
- def _has_version_table(self):
+ def _has_version_table(self) -> bool:
+ assert self.connection is not None
return sqla_compat._connectable_has_table(
self.connection, self.version_table, self.version_table_schema
)
- def stamp(self, script_directory, revision):
+ def stamp(
+ self, script_directory: "ScriptDirectory", revision: str
+ ) -> None:
"""Stamp the version table with a specific revision.
This method calculates those branches to which the given revision
for step in script_directory._stamp_revs(revision, heads):
head_maintainer.update_to_step(step)
- def run_migrations(self, **kw):
+ def run_migrations(self, **kw) -> None:
r"""Run the migration scripts established for this
:class:`.MigrationContext`, if any.
"""
self.impl.start_migrations()
+ heads: Tuple[str, ...]
if self.purge:
if self.as_sql:
raise util.CommandError("Can't use --purge with --sql mode")
head_maintainer = HeadMaintainer(self, heads)
+ assert self._migrations_fn is not None
for step in self._migrations_fn(heads, self):
with self.begin_transaction(_per_migration=True):
if self.as_sql and not head_maintainer.heads:
self._version.drop(self.connection)
- def _in_connection_transaction(self):
+ def _in_connection_transaction(self) -> bool:
try:
- meth = self.connection.in_transaction
+ meth = self.connection.in_transaction # type:ignore[union-attr]
except AttributeError:
return False
else:
return meth()
- def execute(self, sql, execution_options=None):
+ def execute(self, sql: str, execution_options: Optional[dict] = None) -> None:
"""Execute a SQL construct or string statement.
The underlying execution mechanics are used, that is
"""
self.impl._exec(sql, execution_options)
- def _stdout_connection(self, connection):
+ def _stdout_connection(
+ self, connection: Optional["Connection"]
+ ) -> "MockConnection":
def dump(construct, *multiparams, **params):
self.impl._exec(construct)
return MockEngineStrategy.MockConnection(self.dialect, dump)
@property
- def bind(self):
+ def bind(self) -> Optional["Connection"]:
"""Return the current "bind".
In online mode, this is an instance of
return self.connection
@property
- def config(self):
+ def config(self) -> Optional["Config"]:
"""Return the :class:`.Config` used by the current environment,
if any."""
else:
return None
- def _compare_type(self, inspector_column, metadata_column):
+ def _compare_type(
+ self, inspector_column: "Column", metadata_column: "Column"
+ ) -> bool:
if self._user_compare_type is False:
return False
def _compare_server_default(
self,
- inspector_column,
- metadata_column,
- rendered_metadata_default,
- rendered_column_default,
- ):
+ inspector_column: "Column",
+ metadata_column: "Column",
+ rendered_metadata_default: Optional[str],
+ rendered_column_default: Optional[str],
+ ) -> bool:
if self._user_compare_server_default is False:
return False
class HeadMaintainer:
- def __init__(self, context, heads):
+ def __init__(self, context: "MigrationContext", heads: Any) -> None:
self.context = context
self.heads = set(heads)
- def _insert_version(self, version):
+ def _insert_version(self, version: str) -> None:
assert version not in self.heads
self.heads.add(version)
)
)
- def _delete_version(self, version):
+ def _delete_version(self, version: str) -> None:
self.heads.remove(version)
ret = self.context.impl._exec(
% (version, self.context.version_table, ret.rowcount)
)
- def _update_version(self, from_, to_):
+ def _update_version(self, from_: str, to_: str) -> None:
assert to_ not in self.heads
self.heads.remove(from_)
self.heads.add(to_)
% (from_, to_, self.context.version_table, ret.rowcount)
)
- def update_to_step(self, step):
+ def update_to_step(self, step: Union["RevisionStep", "StampStep"]) -> None:
if step.should_delete_branch(self.heads):
vers = step.delete_version_num
log.debug("branch delete %s", vers)
"""
- is_upgrade = None
+ is_upgrade: bool = None # type:ignore[assignment]
"""True/False: indicates whether this operation ascends or descends the
version tree."""
- is_stamp = None
+ is_stamp: bool = None # type:ignore[assignment]
"""True/False: indicates whether this operation is a stamp (i.e. whether
it results in any actual database operations)."""
- up_revision_id = None
+ up_revision_id: Optional[str] = None
"""Version string corresponding to :attr:`.Revision.revision`.
In the case of a stamp operation, it is advised to use the
"""
- up_revision_ids = None
+ up_revision_ids: Tuple[str, ...] = None # type:ignore[assignment]
"""Tuple of version strings corresponding to :attr:`.Revision.revision`.
    In the majority of cases, this tuple will be a single value, synonymous
"""
- down_revision_ids = None
+ down_revision_ids: Tuple[str, ...] = None # type:ignore[assignment]
"""Tuple of strings representing the base revisions of this migration step.
If empty, this represents a root revision; otherwise, the first item
from dependencies.
"""
- revision_map = None
+ revision_map: "RevisionMap" = None # type:ignore[assignment]
"""The revision map inside of which this operation occurs."""
def __init__(
- self, revision_map, is_upgrade, is_stamp, up_revisions, down_revisions
- ):
+ self,
+ revision_map: "RevisionMap",
+ is_upgrade: bool,
+ is_stamp: bool,
+ up_revisions: Union[str, Tuple[str, ...]],
+ down_revisions: Union[str, Tuple[str, ...]],
+ ) -> None:
self.revision_map = revision_map
self.is_upgrade = is_upgrade
self.is_stamp = is_stamp
self.down_revision_ids = util.to_tuple(down_revisions, default=())
@property
- def is_migration(self):
+ def is_migration(self) -> bool:
"""True/False: indicates whether this operation is a migration.
        At present this is true if and only if the migration is not a stamp.
return not self.is_stamp
@property
- def source_revision_ids(self):
+ def source_revision_ids(self) -> Tuple[str, ...]:
"""Active revisions before this migration step is applied."""
return (
self.down_revision_ids if self.is_upgrade else self.up_revision_ids
)
@property
- def destination_revision_ids(self):
+ def destination_revision_ids(self) -> Tuple[str, ...]:
"""Active revisions after this migration step is applied."""
return (
self.up_revision_ids if self.is_upgrade else self.down_revision_ids
)
@property
- def up_revision(self):
+ def up_revision(self) -> "Revision":
"""Get :attr:`~.MigrationInfo.up_revision_id` as
a :class:`.Revision`.
return self.revision_map.get_revision(self.up_revision_id)
@property
- def up_revisions(self):
+ def up_revisions(self) -> Tuple["Revision", ...]:
"""Get :attr:`~.MigrationInfo.up_revision_ids` as a
:class:`.Revision`."""
return self.revision_map.get_revisions(self.up_revision_ids)
@property
- def down_revisions(self):
+ def down_revisions(self) -> Tuple["Revision", ...]:
"""Get :attr:`~.MigrationInfo.down_revision_ids` as a tuple of
:class:`Revisions <.Revision>`."""
return self.revision_map.get_revisions(self.down_revision_ids)
@property
- def source_revisions(self):
+ def source_revisions(self) -> Tuple["Revision", ...]:
"""Get :attr:`~MigrationInfo.source_revision_ids` as a tuple of
:class:`Revisions <.Revision>`."""
return self.revision_map.get_revisions(self.source_revision_ids)
@property
- def destination_revisions(self):
+ def destination_revisions(self) -> Tuple["Revision", ...]:
"""Get :attr:`~MigrationInfo.destination_revision_ids` as a tuple of
:class:`Revisions <.Revision>`."""
return self.revision_map.get_revisions(self.destination_revision_ids)
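# NOTE: a worked example of the four properties above: for an upgrade step
# from revision "a1" to "b2" (is_upgrade=True),
#
#     up_revision_ids          == ("b2",)
#     down_revision_ids        == ("a1",)
#     source_revision_ids      == ("a1",)   # active before the step
#     destination_revision_ids == ("b2",)   # active after the step
#
# on the corresponding downgrade, source and destination swap.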
class MigrationStep:
+
+ from_revisions_no_deps: Tuple[str, ...]
+ to_revisions_no_deps: Tuple[str, ...]
+ is_upgrade: bool
+ migration_fn: Any
+
@property
- def name(self):
+ def name(self) -> str:
return self.migration_fn.__name__
@classmethod
- def upgrade_from_script(cls, revision_map, script):
+ def upgrade_from_script(
+ cls, revision_map: "RevisionMap", script: "Script"
+ ) -> "RevisionStep":
return RevisionStep(revision_map, script, True)
@classmethod
- def downgrade_from_script(cls, revision_map, script):
+ def downgrade_from_script(
+ cls, revision_map: "RevisionMap", script: "Script"
+ ) -> "RevisionStep":
return RevisionStep(revision_map, script, False)
@property
- def is_downgrade(self):
+ def is_downgrade(self) -> bool:
return not self.is_upgrade
@property
- def short_log(self):
+ def short_log(self) -> str:
return "%s %s -> %s" % (
self.name,
util.format_as_comma(self.from_revisions_no_deps),
class RevisionStep(MigrationStep):
- def __init__(self, revision_map, revision, is_upgrade):
+ def __init__(
+ self, revision_map: "RevisionMap", revision: "Script", is_upgrade: bool
+ ) -> None:
self.revision_map = revision_map
self.revision = revision
self.is_upgrade = is_upgrade
if is_upgrade:
- self.migration_fn = revision.module.upgrade
+ self.migration_fn = (
+ revision.module.upgrade # type:ignore[attr-defined]
+ )
else:
- self.migration_fn = revision.module.downgrade
+ self.migration_fn = (
+ revision.module.downgrade # type:ignore[attr-defined]
+ )
def __repr__(self):
return "RevisionStep(%r, is_upgrade=%r)" % (
self.is_upgrade,
)
- def __eq__(self, other):
+ def __eq__(self, other: object) -> bool:
return (
isinstance(other, RevisionStep)
and other.revision == self.revision
return self.revision.doc
@property
- def from_revisions(self):
+ def from_revisions(self) -> Tuple[str, ...]:
if self.is_upgrade:
return self.revision._normalized_down_revisions
else:
return (self.revision.revision,)
@property
- def from_revisions_no_deps(self):
+ def from_revisions_no_deps( # type:ignore[override]
+ self,
+ ) -> Tuple[str, ...]:
if self.is_upgrade:
return self.revision._versioned_down_revisions
else:
return (self.revision.revision,)
@property
- def to_revisions(self):
+ def to_revisions(self) -> Tuple[str, ...]:
if self.is_upgrade:
return (self.revision.revision,)
else:
return self.revision._normalized_down_revisions
@property
- def to_revisions_no_deps(self):
+ def to_revisions_no_deps( # type:ignore[override]
+ self,
+ ) -> Tuple[str, ...]:
if self.is_upgrade:
return (self.revision.revision,)
else:
return self.revision._versioned_down_revisions
@property
- def _has_scalar_down_revision(self):
+ def _has_scalar_down_revision(self) -> bool:
return len(self.revision._normalized_down_revisions) == 1
- def should_delete_branch(self, heads):
+ def should_delete_branch(self, heads: Set[str]) -> bool:
"""A delete is when we are a. in a downgrade and b.
we are going to the "base" or we are going to a version that
is implied as a dependency on another version that is remaining.
to_revisions = self._unmerge_to_revisions(heads)
return not to_revisions
- def merge_branch_idents(self, heads):
+ def merge_branch_idents(
+ self, heads: Set[str]
+ ) -> Tuple[List[str], str, str]:
other_heads = set(heads).difference(self.from_revisions)
if other_heads:
self.to_revisions[0],
)
- def _unmerge_to_revisions(self, heads):
+ def _unmerge_to_revisions(self, heads: Collection[str]) -> Tuple[str, ...]:
other_heads = set(heads).difference([self.revision.revision])
if other_heads:
ancestors = set(
self.revision_map.get_revisions(other_heads), check=False
)
)
- return list(set(self.to_revisions).difference(ancestors))
+ return tuple(set(self.to_revisions).difference(ancestors))
else:
return self.to_revisions
- def unmerge_branch_idents(self, heads):
+ def unmerge_branch_idents(
+ self, heads: Collection[str]
+ ) -> Tuple[str, str, Tuple[str, ...]]:
to_revisions = self._unmerge_to_revisions(heads)
return (
to_revisions[0:-1],
)
- def should_create_branch(self, heads):
+ def should_create_branch(self, heads: Set[str]) -> bool:
if not self.is_upgrade:
return False
else:
return False
- def should_merge_branches(self, heads):
+ def should_merge_branches(self, heads: Set[str]) -> bool:
if not self.is_upgrade:
return False
return False
- def should_unmerge_branches(self, heads):
+ def should_unmerge_branches(self, heads: Set[str]) -> bool:
if not self.is_downgrade:
return False
return False
- def update_version_num(self, heads):
+ def update_version_num(self, heads: Set[str]) -> Tuple[str, str]:
if not self._has_scalar_down_revision:
downrev = heads.intersection(
self.revision._normalized_down_revisions
return self.revision.revision, down_revision
@property
- def delete_version_num(self):
+ def delete_version_num(self) -> str:
return self.revision.revision
@property
- def insert_version_num(self):
+ def insert_version_num(self) -> str:
return self.revision.revision
@property
- def info(self):
+ def info(self) -> "MigrationInfo":
return MigrationInfo(
revision_map=self.revision_map,
up_revisions=self.revision.revision,
class StampStep(MigrationStep):
- def __init__(self, from_, to_, is_upgrade, branch_move, revision_map=None):
- self.from_ = util.to_tuple(from_, default=())
- self.to_ = util.to_tuple(to_, default=())
+ def __init__(
+ self,
+ from_: Optional[Union[str, Collection[str]]],
+ to_: Optional[Union[str, Collection[str]]],
+ is_upgrade: bool,
+ branch_move: bool,
+ revision_map: Optional["RevisionMap"] = None,
+ ) -> None:
+ self.from_: Tuple[str, ...] = util.to_tuple(from_, default=())
+ self.to_: Tuple[str, ...] = util.to_tuple(to_, default=())
self.is_upgrade = is_upgrade
self.branch_move = branch_move
self.migration_fn = self.stamp_revision
doc = None
- def stamp_revision(self, **kw):
+ def stamp_revision(self, **kw) -> None:
return None
def __eq__(self, other):
return self.from_
@property
- def to_revisions(self):
+ def to_revisions(self) -> Tuple[str, ...]:
return self.to_
@property
- def from_revisions_no_deps(self):
+ def from_revisions_no_deps( # type:ignore[override]
+ self,
+ ) -> Tuple[str, ...]:
return self.from_
@property
- def to_revisions_no_deps(self):
+ def to_revisions_no_deps( # type:ignore[override]
+ self,
+ ) -> Tuple[str, ...]:
return self.to_
@property
- def delete_version_num(self):
+ def delete_version_num(self) -> str:
assert len(self.from_) == 1
return self.from_[0]
@property
- def insert_version_num(self):
+ def insert_version_num(self) -> str:
assert len(self.to_) == 1
return self.to_[0]
- def update_version_num(self, heads):
+ def update_version_num(self, heads: Set[str]) -> Tuple[str, str]:
assert len(self.from_) == 1
assert len(self.to_) == 1
return self.from_[0], self.to_[0]
- def merge_branch_idents(self, heads):
+ def merge_branch_idents(
+ self, heads: Union[Set[str], List[str]]
+ ) -> Tuple[List[str], str, str]:
return (
# delete revs, update from rev, update to rev
list(self.from_[0:-1]),
self.to_[0],
)
- def unmerge_branch_idents(self, heads):
+ def unmerge_branch_idents(
+ self, heads: Set[str]
+ ) -> Tuple[str, str, List[str]]:
return (
# update from rev, update to rev, insert revs
self.from_[0],
list(self.to_[0:-1]),
)
- def should_delete_branch(self, heads):
+ def should_delete_branch(self, heads: Set[str]) -> bool:
# TODO: we probably need to look for self.to_ inside of heads,
# in a similar manner as should_create_branch, however we have
# no tests for this yet (stamp downgrades w/ branches)
return self.is_downgrade and self.branch_move
- def should_create_branch(self, heads):
+ def should_create_branch(self, heads: Set[str]) -> Union[Set[str], bool]:
return (
self.is_upgrade
and (self.branch_move or set(self.from_).difference(heads))
and set(self.to_).difference(heads)
)
- def should_merge_branches(self, heads):
+ def should_merge_branches(self, heads: Set[str]) -> bool:
return len(self.from_) > 1
- def should_unmerge_branches(self, heads):
+ def should_unmerge_branches(self, heads: Set[str]) -> bool:
return len(self.to_) > 1
@property
- def info(self):
+ def info(self) -> "MigrationInfo":
up, down = (
(self.to_, self.from_)
if self.is_upgrade
else (self.from_, self.to_)
)
+ assert self.revision_map is not None
return MigrationInfo(
revision_map=self.revision_map,
up_revisions=up,
import re
import shutil
import sys
+from types import ModuleType
+from typing import Any
+from typing import cast
+from typing import Dict
+from typing import Iterator
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Set
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import Union
from . import revision
from . import write_hooks
from .. import util
from ..runtime import migration
+if TYPE_CHECKING:
+ from ..config import Config
+ from ..runtime.migration import RevisionStep
+ from ..runtime.migration import StampStep
+
try:
from dateutil import tz
except ImportError:
- tz = None # noqa
+ tz = None # type: ignore[assignment]
+
+_RevIdType = Union[str, Sequence[str]]
_sourceless_rev_file = re.compile(r"(?!\.\#|__init__)(.*\.py)(c|o)?$")
_only_source_rev_file = re.compile(r"(?!\.\#|__init__)(.*\.py)$")
def __init__(
self,
- dir, # noqa
- file_template=_default_file_template,
- truncate_slug_length=40,
- version_locations=None,
- sourceless=False,
- output_encoding="utf-8",
- timezone=None,
- hook_config=None,
- ):
+ dir: str, # noqa
+ file_template: str = _default_file_template,
+ truncate_slug_length: Optional[int] = 40,
+ version_locations: Optional[List[str]] = None,
+ sourceless: bool = False,
+ output_encoding: str = "utf-8",
+ timezone: Optional[str] = None,
+ hook_config: Optional[Dict[str, str]] = None,
+ ) -> None:
self.dir = dir
self.file_template = file_template
self.version_locations = version_locations
)
@property
- def versions(self):
+ def versions(self) -> str:
loc = self._version_locations
if len(loc) > 1:
raise util.CommandError("Multiple version_locations present")
else:
return (os.path.abspath(os.path.join(self.dir, "versions")),)
- def _load_revisions(self):
+ def _load_revisions(self) -> Iterator["Script"]:
if self.version_locations:
paths = [
vers
yield script
@classmethod
- def from_config(cls, config):
+ def from_config(cls, config: "Config") -> "ScriptDirectory":
"""Produce a new :class:`.ScriptDirectory` given a :class:`.Config`
instance.
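# NOTE: a minimal sketch of the usual entry point; assumes an alembic.ini
# in the current directory:
#
#     from alembic.config import Config
#     from alembic.script import ScriptDirectory
#
#     cfg = Config("alembic.ini")
#     script = ScriptDirectory.from_config(cfg)
#     head = script.get_current_head()  # errors out if multiple heads exist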
raise util.CommandError(
"No 'script_location' key " "found in configuration."
)
- truncate_slug_length = config.get_main_option("truncate_slug_length")
+ truncate_slug_length = cast(
+ Optional[int], config.get_main_option("truncate_slug_length")
+ )
if truncate_slug_length is not None:
truncate_slug_length = int(truncate_slug_length)
else:
if split_char is None:
# legacy behaviour for backwards compatibility
- version_locations = _split_on_space_comma.split(
- version_locations
+ vl = _split_on_space_comma.split(
+ cast(str, version_locations)
)
+ version_locations: List[str] = vl # type: ignore[no-redef]
else:
- version_locations = [
- x for x in version_locations.split(split_char) if x
+ vl = [
+ x
+ for x in cast(str, version_locations).split(split_char)
+ if x
]
+ version_locations: List[str] = vl # type: ignore[no-redef]
prepend_sys_path = config.get_main_option("prepend_sys_path")
if prepend_sys_path:
truncate_slug_length=truncate_slug_length,
sourceless=config.get_main_option("sourceless") == "true",
output_encoding=config.get_main_option("output_encoding", "utf-8"),
- version_locations=version_locations,
+ version_locations=cast("Optional[List[str]]", version_locations),
timezone=config.get_main_option("timezone"),
hook_config=config.get_section("post_write_hooks", {}),
)
@contextmanager
def _catch_revision_errors(
self,
- ancestor=None,
- multiple_heads=None,
- start=None,
- end=None,
- resolution=None,
- ):
+ ancestor: Optional[str] = None,
+ multiple_heads: Optional[str] = None,
+ start: Optional[str] = None,
+ end: Optional[str] = None,
+ resolution: Optional[str] = None,
+ ) -> Iterator[None]:
try:
yield
except revision.RangeNotAncestorError as rna:
if start is None:
- start = rna.lower
+ start = cast(Any, rna.lower)
if end is None:
- end = rna.upper
+ end = cast(Any, rna.upper)
if not ancestor:
ancestor = (
"Requested range %(start)s:%(end)s does not refer to "
except revision.RevisionError as err:
raise util.CommandError(err.args[0]) from err
- def walk_revisions(self, base="base", head="heads"):
+ def walk_revisions(
+ self, base: str = "base", head: str = "heads"
+ ) -> Iterator["Script"]:
"""Iterate through all revisions.
:param base: the base revision, or "base" to start from the
for rev in self.revision_map.iterate_revisions(
head, base, inclusive=True, assert_relative_length=False
):
- yield rev
+ yield cast(Script, rev)
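# NOTE: walk_revisions iterates from head down to base, so a full listing
# of the history reads newest-first:
#
#     for sc in script.walk_revisions(base="base", head="heads"):
#         print(sc.revision, sc.doc)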
- def get_revisions(self, id_):
+ def get_revisions(self, id_: _RevIdType) -> Tuple["Script", ...]:
"""Return the :class:`.Script` instance with the given rev identifier,
symbolic name, or sequence of identifiers.
"""
with self._catch_revision_errors():
- return self.revision_map.get_revisions(id_)
+ return cast(
+ "Tuple[Script, ...]", self.revision_map.get_revisions(id_)
+ )
- def get_all_current(self, id_):
+ def get_all_current(self, id_: Tuple[str, ...]) -> Set["Script"]:
with self._catch_revision_errors():
- top_revs = set(self.revision_map.get_revisions(id_))
+ top_revs = cast(
+ "Set[Script]",
+ set(self.revision_map.get_revisions(id_)),
+ )
top_revs.update(
- self.revision_map._get_ancestor_nodes(
- list(top_revs), include_dependencies=True
+ cast(
+ "Iterator[Script]",
+ self.revision_map._get_ancestor_nodes(
+ list(top_revs), include_dependencies=True
+ ),
)
)
top_revs = self.revision_map._filter_into_branch_heads(top_revs)
return top_revs
- def get_revision(self, id_):
+ def get_revision(self, id_: str) -> "Script":
"""Return the :class:`.Script` instance with the given rev id.
.. seealso::
"""
with self._catch_revision_errors():
- return self.revision_map.get_revision(id_)
+ return cast(Script, self.revision_map.get_revision(id_))
- def as_revision_number(self, id_):
+ def as_revision_number(
+ self, id_: Optional[str]
+ ) -> Optional[Union[str, Tuple[str, ...]]]:
"""Convert a symbolic revision, i.e. 'head' or 'base', into
an actual revision number."""
):
return self.revision_map.get_current_head()
- def get_heads(self):
+ def get_heads(self) -> List[str]:
"""Return all "versioned head" revisions as strings.
This is normally a list of length one,
"""
return list(self.revision_map.heads)
- def get_base(self):
+ def get_base(self) -> Optional[str]:
"""Return the "base" revision as a string.
This is the revision number of the script that
else:
return None
- def get_bases(self):
+ def get_bases(self) -> List[str]:
"""return all "base" revisions as strings.
This is the revision number of all scripts that
"""
return list(self.revision_map.bases)
- def _upgrade_revs(self, destination, current_rev):
+ def _upgrade_revs(
+ self, destination: str, current_rev: str
+ ) -> List["RevisionStep"]:
with self._catch_revision_errors(
ancestor="Destination %(end)s is not a valid upgrade "
"target from current head(s)",
revs = self.revision_map.iterate_revisions(
destination, current_rev, implicit_base=True
)
- revs = list(revs)
return [
migration.MigrationStep.upgrade_from_script(
- self.revision_map, script
+ self.revision_map, cast(Script, script)
)
for script in reversed(list(revs))
]
- def _downgrade_revs(self, destination, current_rev):
+ def _downgrade_revs(
+ self, destination: str, current_rev: Optional[str]
+ ) -> List["RevisionStep"]:
with self._catch_revision_errors(
ancestor="Destination %(end)s is not a valid downgrade "
"target from current head(s)",
)
return [
migration.MigrationStep.downgrade_from_script(
- self.revision_map, script
+ self.revision_map, cast(Script, script)
)
for script in revs
]
- def _stamp_revs(self, revision, heads):
+ def _stamp_revs(
+ self, revision: _RevIdType, heads: _RevIdType
+ ) -> List["StampStep"]:
with self._catch_revision_errors(
multiple_heads="Multiple heads are present; please specify a "
"single target revision"
):
- heads = self.get_revisions(heads)
+ heads_revs = self.get_revisions(heads)
steps = []
if not revision:
revision = "base"
- filtered_heads = []
+ filtered_heads: List["Script"] = []
for rev in util.to_tuple(revision):
if rev:
filtered_heads.extend(
self.revision_map.filter_for_lineage(
- heads, rev, include_dependencies=True
+ heads_revs, rev, include_dependencies=True
)
)
filtered_heads = util.unique_list(filtered_heads)
return steps
- def run_env(self):
+ def run_env(self) -> None:
"""Run the script environment.
This basically runs the ``env.py`` script present
def env_py_location(self):
return os.path.abspath(os.path.join(self.dir, "env.py"))
- def _generate_template(self, src, dest, **kw):
+ def _generate_template(self, src: str, dest: str, **kw: Any) -> None:
util.status(
"Generating %s" % os.path.abspath(dest),
util.template_to_file,
**kw
)
- def _copy_file(self, src, dest):
+ def _copy_file(self, src: str, dest: str) -> None:
util.status(
"Generating %s" % os.path.abspath(dest), shutil.copy, src, dest
)
- def _ensure_directory(self, path):
+ def _ensure_directory(self, path: str) -> None:
path = os.path.abspath(path)
if not os.path.exists(path):
util.status("Creating directory %s" % path, os.makedirs, path)
- def _generate_create_date(self):
+ def _generate_create_date(self) -> "datetime.datetime":
if self.timezone is not None:
if tz is None:
raise util.CommandError(
def generate_revision(
self,
- revid,
- message,
- head=None,
- refresh=False,
- splice=False,
- branch_labels=None,
- version_path=None,
- depends_on=None,
- **kw
- ):
+ revid: str,
+ message: Optional[str],
+ head: Optional[str] = None,
+ refresh: bool = False,
+ splice: Optional[bool] = False,
+ branch_labels: Optional[str] = None,
+ version_path: Optional[str] = None,
+ depends_on: Optional[_RevIdType] = None,
+ **kw: Any
+ ) -> Optional["Script"]:
"""Generate a new revision file.
This runs the ``script.py.mako`` template, given
if version_path is None:
if len(self._version_locations) > 1:
- for head in heads:
- if head is not None:
- version_path = os.path.dirname(head.path)
+ for head_ in heads:
+ if head_ is not None:
+ assert isinstance(head_, Script)
+ version_path = os.path.dirname(head_.path)
break
else:
raise util.CommandError(
path = self._rev_path(version_path, revid, message, create_date)
if not splice:
- for head in heads:
- if head is not None and not head.is_head:
+ for head_ in heads:
+ if head_ is not None and not head_.is_head:
raise util.CommandError(
"Revision %s is not a head revision; please specify "
"--splice to create a new branch from this revision"
- % head.revision
+ % head_.revision
)
if depends_on:
tuple(h.revision if h is not None else None for h in heads)
),
branch_labels=util.to_tuple(branch_labels),
- depends_on=revision.tuple_rev_as_scalar(depends_on),
+ depends_on=revision.tuple_rev_as_scalar(
+ cast("Optional[List[str]]", depends_on)
+ ),
create_date=create_date,
comma=util.format_as_comma,
message=message if message is not None else ("empty message"),
script = Script._from_path(self, path)
except revision.RevisionError as err:
raise util.CommandError(err.args[0]) from err
+ if script is None:
+ return None
if branch_labels and not script.branch_labels:
raise util.CommandError(
"Version %s specified branch_labels %s, however the "
"'branch_labels' section?"
% (script.revision, branch_labels, script.path)
)
-
self.revision_map.add_revision(script)
return script
- def _rev_path(self, path, rev_id, message, create_date):
+ def _rev_path(
+ self,
+ path: str,
+ rev_id: str,
+ message: Optional[str],
+ create_date: "datetime.datetime",
+ ) -> str:
slug = "_".join(_slug_re.findall(message or "")).lower()
if len(slug) > self.truncate_slug_length:
slug = slug[: self.truncate_slug_length].rsplit("_", 1)[0] + "_"
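# NOTE: a worked example of the slug logic above: with the default
# file_template "%(rev)s_%(slug)s" and truncate_slug_length=40, the message
# "add user table" slugifies to "add_user_table", producing a file such as
# "1975ea83b712_add_user_table.py" (the revision id is illustrative).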
"""
- def __init__(self, module, rev_id, path):
+ def __init__(self, module: ModuleType, rev_id: str, path: str):
self.module = module
self.path = path
super(Script, self).__init__(
rev_id,
- module.down_revision,
+ module.down_revision, # type: ignore[attr-defined]
branch_labels=util.to_tuple(
getattr(module, "branch_labels", None), default=()
),
),
)
- module = None
+ module: ModuleType = None # type: ignore[assignment]
"""The Python module representing the actual script itself."""
- path = None
+ path: str = None # type: ignore[assignment]
"""Filesystem path of the script."""
_db_current_indicator = None
this is a "current" version in some database"""
@property
- def doc(self):
+ def doc(self) -> str:
"""Return the docstring given in the script."""
return re.split("\n\n", self.longdoc)[0]
@property
- def longdoc(self):
+ def longdoc(self) -> str:
"""Return the docstring given in the script."""
doc = self.module.__doc__
if doc:
if hasattr(self.module, "_alembic_source_encoding"):
- doc = doc.decode(self.module._alembic_source_encoding)
- return doc.strip()
+ doc = doc.decode( # type: ignore[attr-defined]
+ self.module._alembic_source_encoding # type: ignore[attr-defined] # noqa
+ )
+ return doc.strip() # type: ignore[union-attr]
else:
return ""
@property
- def log_entry(self):
+ def log_entry(self) -> str:
entry = "Rev: %s%s%s%s%s\n" % (
self.revision,
" (head)" if self.is_head else "",
def _head_only(
self,
- include_branches=False,
- include_doc=False,
- include_parents=False,
- tree_indicators=True,
- head_indicators=True,
- ):
+ include_branches: bool = False,
+ include_doc: bool = False,
+ include_parents: bool = False,
+ tree_indicators: bool = True,
+ head_indicators: bool = True,
+ ) -> str:
text = self.revision
if include_parents:
if self.dependencies:
)
else:
text = "%s -> %s" % (self._format_down_revision(), text)
+ assert text is not None
if include_branches and self.branch_labels:
text += " (%s)" % util.format_as_comma(self.branch_labels)
if head_indicators or tree_indicators:
def cmd_format(
self,
- verbose,
- include_branches=False,
- include_doc=False,
- include_parents=False,
- tree_indicators=True,
- ):
+ verbose: bool,
+ include_branches: bool = False,
+ include_doc: bool = False,
+ include_parents: bool = False,
+ tree_indicators: bool = True,
+ ) -> str:
if verbose:
return self.log_entry
else:
include_branches, include_doc, include_parents, tree_indicators
)
- def _format_down_revision(self):
+ def _format_down_revision(self) -> str:
if not self.down_revision:
return "<base>"
else:
return util.format_as_comma(self._versioned_down_revisions)
@classmethod
- def _from_path(cls, scriptdir, path):
+ def _from_path(
+ cls, scriptdir: ScriptDirectory, path: str
+ ) -> Optional["Script"]:
dir_, filename = os.path.split(path)
return cls._from_filename(scriptdir, dir_, filename)
@classmethod
- def _list_py_dir(cls, scriptdir, path):
+ def _list_py_dir(cls, scriptdir: ScriptDirectory, path: str) -> List[str]:
if scriptdir.sourceless:
# read files in version path, e.g. pyc or pyo files
# in the immediate path
return os.listdir(path)
@classmethod
- def _from_filename(cls, scriptdir, dir_, filename):
+ def _from_filename(
+ cls, scriptdir: ScriptDirectory, dir_: str, filename: str
+ ) -> Optional["Script"]:
if scriptdir.sourceless:
py_match = _sourceless_rev_file.match(filename)
else:
import collections
import re
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Collection
+from typing import Deque
+from typing import Dict
+from typing import FrozenSet
+from typing import Iterator
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Set
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import TypeVar
+from typing import Union
from sqlalchemy import util as sqlautil
from .. import util
from ..util import compat
+if TYPE_CHECKING:
+ from typing import Literal
+
+ from .base import Script
+
+_RevIdType = Union[str, Sequence[str]]
+_RevisionIdentifierType = Union[str, Tuple[str, ...], None]
+_RevisionOrStr = Union["Revision", str]
+_RevisionOrBase = Union["Revision", "Literal['base']"]
+_InterimRevisionMapType = Dict[str, "Revision"]
+_RevisionMapType = Dict[Union[None, str, Tuple[()]], Optional["Revision"]]
+_T = TypeVar("_T", bound=Union[str, "Revision"])
+
_relative_destination = re.compile(r"(?:(.+?)@)?(\w+)?((?:\+|-)\d+)")
_revision_illegal_chars = ["@", "-", "+"]
class RangeNotAncestorError(RevisionError):
- def __init__(self, lower, upper):
+ def __init__(
+ self, lower: _RevisionIdentifierType, upper: _RevisionIdentifierType
+ ) -> None:
self.lower = lower
self.upper = upper
super(RangeNotAncestorError, self).__init__(
class MultipleHeads(RevisionError):
- def __init__(self, heads, argument):
+ def __init__(self, heads: Sequence[str], argument: Optional[str]) -> None:
self.heads = heads
self.argument = argument
super(MultipleHeads, self).__init__(
class ResolutionError(RevisionError):
- def __init__(self, message, argument):
+ def __init__(self, message: str, argument: str) -> None:
super(ResolutionError, self).__init__(message)
self.argument = argument
class CycleDetected(RevisionError):
kind = "Cycle"
- def __init__(self, revisions):
+ def __init__(self, revisions: Sequence[str]) -> None:
self.revisions = revisions
super(CycleDetected, self).__init__(
"%s is detected in revisions (%s)"
class DependencyCycleDetected(CycleDetected):
kind = "Dependency cycle"
- def __init__(self, revisions):
+ def __init__(self, revisions: Sequence[str]) -> None:
super(DependencyCycleDetected, self).__init__(revisions)
class LoopDetected(CycleDetected):
kind = "Self-loop"
- def __init__(self, revision):
+ def __init__(self, revision: str) -> None:
super(LoopDetected, self).__init__([revision])
class DependencyLoopDetected(DependencyCycleDetected, LoopDetected):
kind = "Dependency self-loop"
- def __init__(self, revision):
+ def __init__(self, revision: Sequence[str]) -> None:
super(DependencyLoopDetected, self).__init__(revision)
"""
- def __init__(self, generator):
+ def __init__(self, generator: Callable[[], Iterator["Revision"]]) -> None:
"""Construct a new :class:`.RevisionMap`.
:param generator: a zero-arg callable that will generate an iterable
self._generator = generator
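# NOTE: a minimal sketch of building a map in memory; Revision(revision,
# down_revision) is the plain node type that Script subclasses.
#
#     from alembic.script.revision import Revision, RevisionMap
#
#     a = Revision("a1", None)    # root
#     b = Revision("b2", "a1")    # child of a1
#     rmap = RevisionMap(lambda: iter([a, b]))
#     assert rmap.heads == ("b2",) and rmap.bases == ("a1",)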
@util.memoized_property
- def heads(self):
+ def heads(self) -> Tuple[str, ...]:
"""All "head" revisions as strings.
This is normally a tuple of length one,
return self.heads
@util.memoized_property
- def bases(self):
+ def bases(self) -> Tuple[str, ...]:
"""All "base" revisions as strings.
These are revisions that have a ``down_revision`` of None,
return self.bases
@util.memoized_property
- def _real_heads(self):
+ def _real_heads(self) -> Tuple[str, ...]:
"""All "real" head revisions as strings.
:return: a tuple of string revision numbers.
return self._real_heads
@util.memoized_property
- def _real_bases(self):
+ def _real_bases(self) -> Tuple[str, ...]:
"""All "real" base revisions as strings.
:return: a tuple of string revision numbers.
return self._real_bases
@util.memoized_property
- def _revision_map(self):
+ def _revision_map(self) -> _RevisionMapType:
"""memoized attribute, initializes the revision map from the
initial collection.
"""
# Ordering required for some tests to pass (but not required in
# general)
- map_ = sqlautil.OrderedDict()
+ map_: _InterimRevisionMapType = sqlautil.OrderedDict()
- heads = sqlautil.OrderedSet()
- _real_heads = sqlautil.OrderedSet()
- bases = ()
- _real_bases = ()
+ heads: Set["Revision"] = sqlautil.OrderedSet()
+ _real_heads: Set["Revision"] = sqlautil.OrderedSet()
+ bases: Tuple["Revision", ...] = ()
+ _real_bases: Tuple["Revision", ...] = ()
has_branch_labels = set()
all_revisions = set()
# add the branch_labels to the map_. We'll need these
# to resolve the dependencies.
rev_map = map_.copy()
- self._map_branch_labels(has_branch_labels, map_)
+ self._map_branch_labels(
+ has_branch_labels, cast(_RevisionMapType, map_)
+ )
# resolve dependency names from branch labels and symbolic
# names
- self._add_depends_on(all_revisions, map_)
+ self._add_depends_on(all_revisions, cast(_RevisionMapType, map_))
for rev in map_.values():
for downrev in rev._all_down_revisions:
# once the map has downrevisions populated, the dependencies
# can be further refined to include only those which are not
# already ancestors
- self._normalize_depends_on(all_revisions, map_)
+ self._normalize_depends_on(all_revisions, cast(_RevisionMapType, map_))
self._detect_cycles(rev_map, heads, bases, _real_heads, _real_bases)
- map_[None] = map_[()] = None
+ revision_map: _RevisionMapType = dict(map_.items())
+ revision_map[None] = revision_map[()] = None
self.heads = tuple(rev.revision for rev in heads)
self._real_heads = tuple(rev.revision for rev in _real_heads)
self.bases = tuple(rev.revision for rev in bases)
self._real_bases = tuple(rev.revision for rev in _real_bases)
- self._add_branches(has_branch_labels, map_)
- return map_
+ self._add_branches(has_branch_labels, revision_map)
+ return revision_map
- def _detect_cycles(self, rev_map, heads, bases, _real_heads, _real_bases):
+ def _detect_cycles(
+ self,
+ rev_map: _InterimRevisionMapType,
+ heads: Set["Revision"],
+ bases: Tuple["Revision", ...],
+ _real_heads: Set["Revision"],
+ _real_bases: Tuple["Revision", ...],
+ ) -> None:
if not rev_map:
return
if not heads or not bases:
- raise CycleDetected(rev_map.keys())
+ raise CycleDetected(list(rev_map))
total_space = {
rev.revision
for rev in self._iterate_related_revisions(
- lambda r: r._versioned_down_revisions, heads, map_=rev_map
+ lambda r: r._versioned_down_revisions,
+ heads,
+ map_=cast(_RevisionMapType, rev_map),
)
}.intersection(
rev.revision
for rev in self._iterate_related_revisions(
- lambda r: r.nextrev, bases, map_=rev_map
+ lambda r: r.nextrev,
+ bases,
+ map_=cast(_RevisionMapType, rev_map),
)
)
deleted_revs = set(rev_map.keys()) - total_space
raise CycleDetected(sorted(deleted_revs))
if not _real_heads or not _real_bases:
- raise DependencyCycleDetected(rev_map.keys())
+ raise DependencyCycleDetected(list(rev_map))
total_space = {
rev.revision
for rev in self._iterate_related_revisions(
- lambda r: r._all_down_revisions, _real_heads, map_=rev_map
+ lambda r: r._all_down_revisions,
+ _real_heads,
+ map_=cast(_RevisionMapType, rev_map),
)
}.intersection(
rev.revision
for rev in self._iterate_related_revisions(
- lambda r: r._all_nextrev, _real_bases, map_=rev_map
+ lambda r: r._all_nextrev,
+ _real_bases,
+ map_=cast(_RevisionMapType, rev_map),
)
)
deleted_revs = set(rev_map.keys()) - total_space
if deleted_revs:
raise DependencyCycleDetected(sorted(deleted_revs))
- def _map_branch_labels(self, revisions, map_):
+ def _map_branch_labels(
+ self, revisions: Collection["Revision"], map_: _RevisionMapType
+ ) -> None:
for revision in revisions:
if revision.branch_labels:
+ assert revision._orig_branch_labels is not None
for branch_label in revision._orig_branch_labels:
if branch_label in map_:
+ map_rev = map_[branch_label]
+ assert map_rev is not None
raise RevisionError(
"Branch name '%s' in revision %s already "
"used by revision %s"
% (
branch_label,
revision.revision,
- map_[branch_label].revision,
+ map_rev.revision,
)
)
map_[branch_label] = revision
- def _add_branches(self, revisions, map_):
+ def _add_branches(
+ self, revisions: Collection["Revision"], map_: _RevisionMapType
+ ) -> None:
for revision in revisions:
if revision.branch_labels:
revision.branch_labels.update(revision.branch_labels)
else:
break
- def _add_depends_on(self, revisions, map_):
+ def _add_depends_on(
+ self, revisions: Collection["Revision"], map_: _RevisionMapType
+ ) -> None:
"""Resolve the 'dependencies' for each revision in a collection
in terms of actual revision ids, as opposed to branch labels or other
symbolic names.
map_[dep] for dep in util.to_tuple(revision.dependencies)
]
revision._resolved_dependencies = tuple(
- [d.revision for d in deps]
+ [d.revision for d in deps if d is not None]
)
else:
revision._resolved_dependencies = ()
- def _normalize_depends_on(self, revisions, map_):
+ def _normalize_depends_on(
+ self, revisions: Collection["Revision"], map_: _RevisionMapType
+ ) -> None:
"""Create a collection of "dependencies" that omits dependencies
that are already ancestor nodes for each revision in a given
collection.
if revision._resolved_dependencies:
normalized_resolved = set(revision._resolved_dependencies)
for rev in self._get_ancestor_nodes(
- [revision], include_dependencies=False, map_=map_
+ [revision],
+ include_dependencies=False,
+ map_=cast(_RevisionMapType, map_),
):
if rev is revision:
continue
else:
revision._normalized_resolved_dependencies = ()
- def add_revision(self, revision, _replace=False):
+ def add_revision(
+ self, revision: "Revision", _replace: bool = False
+ ) -> None:
"""add a single revision to an existing map.
This method is for single-revision use cases, it's not
"Revision %s referenced from %s is not present"
% (downrev, revision)
)
- map_[downrev].add_nextrev(revision)
+ cast("Revision", map_[downrev]).add_nextrev(revision)
self._normalize_depends_on(revisions, map_)
)
) + (revision.revision,)
- def get_current_head(self, branch_label=None):
+ def get_current_head(
+ self, branch_label: Optional[str] = None
+ ) -> Optional[str]:
"""Return the current head revision.
If the script directory has multiple heads
:meth:`.ScriptDirectory.get_heads`
"""
- current_heads = self.heads
+ current_heads: Sequence[str] = self.heads
if branch_label:
current_heads = self.filter_for_lineage(
current_heads, branch_label
else:
return None
- def _get_base_revisions(self, identifier):
+ def _get_base_revisions(self, identifier: str) -> Tuple[str, ...]:
return self.filter_for_lineage(self.bases, identifier)
- def get_revisions(self, id_):
+ def get_revisions(
+ self, id_: Union[str, Collection[str], None]
+ ) -> Tuple["Revision", ...]:
"""Return the :class:`.Revision` instances with the given rev id
or identifiers.
if isinstance(id_, (list, tuple, set, frozenset)):
return sum([self.get_revisions(id_elem) for id_elem in id_], ())
else:
- resolved_id, branch_label = self._resolve_revision_number(id_)
+ resolved_id, branch_label = self._resolve_revision_number(
+ id_ # type:ignore [arg-type]
+ )
if len(resolved_id) == 1:
try:
rint = int(resolved_id[0])
# branch@-n -> walk down from heads
select_heads = self.get_revisions("heads")
if branch_label is not None:
- select_heads = [
+ select_heads = tuple(
head
for head in select_heads
if branch_label in head.branch_labels
- ]
+ )
return tuple(
self._walk(head, steps=rint)
for head in select_heads
for rev_id in resolved_id
)
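# NOTE: besides plain ids, get_revisions resolves symbolic and relative
# forms -- "heads" expands to every current head, and a "branch@-n"
# identifier walks n steps down from that branch's head via _walk() as
# shown above; a plain collection simply maps element-wise:
#
#     rmap.get_revisions(("a1", "b2"))  # (Revision("a1"), Revision("b2"))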
- def get_revision(self, id_):
+ def get_revision(self, id_: Optional[str]) -> "Revision":
"""Return the :class:`.Revision` instance with the given rev id.
If a symbolic name such as "head" or "base" is given, resolves
if len(resolved_id) > 1:
raise MultipleHeads(resolved_id, id_)
elif resolved_id:
- resolved_id = resolved_id[0]
+ resolved_id = resolved_id[0] # type:ignore[assignment]
- return self._revision_for_ident(resolved_id, branch_label)
+ return self._revision_for_ident(cast(str, resolved_id), branch_label)
- def _resolve_branch(self, branch_label):
+ def _resolve_branch(self, branch_label: str) -> "Revision":
try:
branch_rev = self._revision_map[branch_label]
except KeyError:
else:
return nonbranch_rev
else:
- return branch_rev
+ return cast("Revision", branch_rev)
- def _revision_for_ident(self, resolved_id, check_branch=None):
+ def _revision_for_ident(
+ self, resolved_id: str, check_branch: Optional[str] = None
+ ) -> "Revision":
+ branch_rev: Optional["Revision"]
if check_branch:
branch_rev = self._resolve_branch(check_branch)
else:
branch_rev = None
+ revision: Union["Revision", "Literal[False]"]
try:
- revision = self._revision_map[resolved_id]
+ revision = cast("Revision", self._revision_map[resolved_id])
except KeyError:
# break out to avoid misleading py3k stack traces
revision = False
+ revs: Sequence[str]
if revision is False:
# do a partial lookup
revs = [
resolved_id,
)
else:
- revision = self._revision_map[revs[0]]
+ revision = cast("Revision", self._revision_map[revs[0]])
+ revision = cast("Revision", revision)
if check_branch and revision is not None:
+ assert branch_rev is not None
if not self._shares_lineage(
revision.revision, branch_rev.revision
):
)
return revision
- def _filter_into_branch_heads(self, targets):
+ def _filter_into_branch_heads(
+ self, targets: Set["Script"]
+ ) -> Set["Script"]:
targets = set(targets)
for rev in list(targets):
return targets
def filter_for_lineage(
- self, targets, check_against, include_dependencies=False
- ):
+ self,
+ targets: Sequence[_T],
+ check_against: Optional[str],
+ include_dependencies: bool = False,
+ ) -> Tuple[_T, ...]:
id_, branch_label = self._resolve_revision_number(check_against)
shares = []
if id_:
shares.extend(id_)
- return [
+ return tuple(
tg
for tg in targets
if self._shares_lineage(
tg, shares, include_dependencies=include_dependencies
)
- ]
+ )
def _shares_lineage(
- self, target, test_against_revs, include_dependencies=False
- ):
+ self,
+ target: _RevisionOrStr,
+ test_against_revs: Sequence[_RevisionOrStr],
+ include_dependencies: bool = False,
+ ) -> bool:
if not test_against_revs:
return True
if not isinstance(target, Revision):
.intersection(test_against_revs)
)
- def _resolve_revision_number(self, id_):
+ def _resolve_revision_number(
+ self, id_: Optional[str]
+ ) -> Tuple[Tuple[str, ...], Optional[str]]:
+ branch_label: Optional[str]
if isinstance(id_, compat.string_types) and "@" in id_:
branch_label, id_ = id_.split("@", 1)
def iterate_revisions(
self,
- upper,
- lower,
- implicit_base=False,
- inclusive=False,
- assert_relative_length=True,
- select_for_downgrade=False,
- ):
+ upper: _RevisionIdentifierType,
+ lower: _RevisionIdentifierType,
+ implicit_base: bool = False,
+ inclusive: bool = False,
+ assert_relative_length: bool = True,
+ select_for_downgrade: bool = False,
+ ) -> Iterator["Revision"]:
"""Iterate through script revisions, starting at the given
upper revision identifier and ending at the lower.
The iterator yields :class:`.Revision` objects.
"""
+ fn: Callable
if select_for_downgrade:
fn = self._collect_downgrade_revisions
else:
def _get_descendant_nodes(
self,
- targets,
- map_=None,
- check=False,
- omit_immediate_dependencies=False,
- include_dependencies=True,
- ):
+ targets: Collection["Revision"],
+ map_: Optional[_RevisionMapType] = None,
+ check: bool = False,
+ omit_immediate_dependencies: bool = False,
+ include_dependencies: bool = True,
+ ) -> Iterator[Any]:
if omit_immediate_dependencies:
)
def _get_ancestor_nodes(
- self, targets, map_=None, check=False, include_dependencies=True
- ):
+ self,
+ targets: Collection["Revision"],
+ map_: Optional[_RevisionMapType] = None,
+ check: bool = False,
+ include_dependencies: bool = True,
+ ) -> Iterator["Revision"]:
if include_dependencies:
fn, targets, map_=map_, check=check
)
- def _iterate_related_revisions(self, fn, targets, map_, check=False):
+ def _iterate_related_revisions(
+ self,
+ fn: Callable,
+ targets: Collection["Revision"],
+ map_: Optional[_RevisionMapType],
+ check: bool = False,
+ ) -> Iterator["Revision"]:
if map_ is None:
map_ = self._revision_map
seen = set()
- todo = collections.deque()
+ todo: Deque["Revision"] = collections.deque()
for target in targets:
todo.append(target)
# Check for map errors before collecting.
for rev_id in fn(rev):
next_rev = map_[rev_id]
+ assert next_rev is not None
if next_rev.revision != rev_id:
raise RevisionError(
"Dependency resolution failed; broken map"
)
)
- def _topological_sort(self, revisions, heads):
+ def _topological_sort(
+ self,
+ revisions: Collection["Revision"],
+ heads: Any,
+ ) -> List[str]:
"""Yield revision ids of a collection of Revision objects in
topological sorted order (i.e. revisions always come after their
down_revisions and dependencies). Uses the order of keys in
# now update the heads with our ancestors.
candidate_rev = id_to_rev[candidate]
+ assert candidate_rev is not None
heads_to_add = [
r
del ancestors_by_idx[current_candidate_idx]
current_candidate_idx = max(current_candidate_idx - 1, 0)
else:
-
if (
not candidate_rev._normalized_resolved_dependencies
and len(candidate_rev._versioned_down_revisions) == 1
assert not todo
return output
- def _walk(self, start, steps, branch_label=None, no_overwalk=True):
+ def _walk(
+ self,
+ start: Optional[Union[str, "Revision"]],
+ steps: int,
+ branch_label: Optional[str] = None,
+ no_overwalk: bool = True,
+ ) -> "Revision":
"""
Walk the requested number of :steps up (steps > 0) or down (steps < 0)
the revision tree.
A RevisionError is raised if there is no unambiguous revision to
walk to.
"""
-
+ initial: Optional[_RevisionOrBase]
if isinstance(start, compat.string_types):
- start = self.get_revision(start)
+ initial = self.get_revision(start)
+ else:
+ initial = start
+ children: Sequence[_RevisionOrBase]
for _ in range(abs(steps)):
if steps > 0:
# Walk up
children = [
rev
for rev in self.get_revisions(
- self.bases if start is None else start.nextrev
+ self.bases
+ if initial is None
+ else cast("Revision", initial).nextrev
)
]
if branch_label:
children = self.filter_for_lineage(children, branch_label)
else:
# Walk down
- if start == "base":
- children = tuple()
+ if initial == "base":
+ children = ()
else:
children = self.get_revisions(
- self.heads if start is None else start.down_revision
+ self.heads
+ if initial is None
+ else initial.down_revision
)
if not children:
- children = ("base",)
+ children = cast("Tuple[Literal['base']]", ("base",))
if not children:
# This will return an invalid result if no_overwalk, otherwise
# further steps will stay where we are.
- return None if no_overwalk else start
+ ret = None if no_overwalk else initial
+ return ret # type:ignore[return-value]
elif len(children) > 1:
raise RevisionError("Ambiguous walk")
- start = children[0]
+ initial = children[0]
- return start
+ return cast("Revision", initial)
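# e.g., with hypothetical ids, _walk("b2", steps=-1) returns the single
# down revision of "b2", while _walk("b2", steps=2) walks two revisions
# up; a step with more than one candidate raises
# RevisionError("Ambiguous walk")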
def _parse_downgrade_target(
- self, current_revisions, target, assert_relative_length
- ):
+ self,
+ current_revisions: _RevisionIdentifierType,
+ target: _RevisionIdentifierType,
+ assert_relative_length: bool,
+ ) -> Tuple[Optional[str], Optional[_RevisionOrBase]]:
"""
Parse downgrade command syntax :target to retrieve the target revision
and branch label (if any) given the :current_revisions stamp of the
if relative_revision:
# Find target revision relative to current state.
if branch_label:
- symbol = self.filter_for_lineage(
+ symbol_list = self.filter_for_lineage(
util.to_tuple(current_revisions), branch_label
)
- assert len(symbol) == 1
- symbol = symbol[0]
+ assert len(symbol_list) == 1
+ symbol = symbol_list[0]
else:
current_revisions = util.to_tuple(current_revisions)
if not current_revisions:
# No relative destination given, revision specified is absolute.
branch_label, _, symbol = target.rpartition("@")
if not branch_label:
- branch_label = None
+ branch_label = None # type:ignore[assignment]
return branch_label, self.get_revision(symbol)
def _parse_upgrade_target(
- self, current_revisions, target, assert_relative_length
- ):
+ self,
+ current_revisions: _RevisionIdentifierType,
+ target: _RevisionIdentifierType,
+ assert_relative_length: bool,
+ ) -> Tuple["Revision", ...]:
"""
Parse upgrade command syntax :target to retrieve the target revision
and given the :current_revisions stamp of the database.
current_revisions = util.to_tuple(current_revisions)
- branch_label, symbol, relative = match.groups()
- relative_str = relative
- relative = int(relative)
+ branch_label, symbol, relative_str = match.groups()
+ relative = int(relative_str)
if relative > 0:
if symbol is None:
if not current_revisions:
)
def _collect_downgrade_revisions(
- self, upper, target, inclusive, implicit_base, assert_relative_length
- ):
+ self,
+ upper: _RevisionIdentifierType,
+ target: _RevisionIdentifierType,
+ inclusive: bool,
+ implicit_base: bool,
+ assert_relative_length: bool,
+ ) -> Any:
"""
Compute the set of current revisions specified by :upper, and the
downgrade target specified by :target. Return all dependents of target
return downgrade_revisions, heads
def _collect_upgrade_revisions(
- self, upper, lower, inclusive, implicit_base, assert_relative_length
- ):
+ self,
+ upper: _RevisionIdentifierType,
+ lower: _RevisionIdentifierType,
+ inclusive: bool,
+ implicit_base: bool,
+ assert_relative_length: bool,
+ ) -> Tuple[Set["Revision"], Tuple[Optional[_RevisionOrBase]]]:
"""
Compute the set of required revisions specified by :upper, and the
current set of active revisions specified by :lower. Find the
of the current/lower revisions. Dependencies from branches with
different bases will not be included.
"""
- targets = self._parse_upgrade_target(
+ targets: Collection["Revision"] = self._parse_upgrade_target(
current_revisions=lower,
target=upper,
assert_relative_length=assert_relative_length,
)
- assert targets is not None
- assert type(targets) is tuple, "targets should be a tuple"
+ # assert type(targets) is tuple, "targets should be a tuple"
# Handle named bases (e.g. branch@... -> heads should only produce
# targets on the given branch)
)
needs.intersection_update(lower_descendents)
- return needs, targets
+ return needs, tuple(targets) # type:ignore[return-value]
class Revision:
"""
- nextrev = frozenset()
+ nextrev: FrozenSet[str] = frozenset()
"""following revisions, based on down_revision only."""
- _all_nextrev = frozenset()
+ _all_nextrev: FrozenSet[str] = frozenset()
- revision = None
+ revision: str = None # type: ignore[assignment]
"""The string revision number."""
- down_revision = None
+ down_revision: Optional[_RevIdType] = None
"""The ``down_revision`` identifier(s) within the migration script.
Note that the total set of "down" revisions is
"""
- dependencies = None
+ dependencies: Optional[_RevIdType] = None
"""Additional revisions which this revision is dependent on.
From a migration standpoint, these dependencies are added to the
"""
- branch_labels = None
+ branch_labels: Set[str] = None # type: ignore[assignment]
"""Optional string/tuple of symbolic names to apply to this
revision's branch"""
+ _resolved_dependencies: Tuple[str, ...]
+ _normalized_resolved_dependencies: Tuple[str, ...]
+
@classmethod
- def verify_rev_id(cls, revision):
+ def verify_rev_id(cls, revision: str) -> None:
illegal_chars = set(revision).intersection(_revision_illegal_chars)
if illegal_chars:
raise RevisionError(
)
def __init__(
- self, revision, down_revision, dependencies=None, branch_labels=None
- ):
+ self,
+ revision: str,
+ down_revision: Optional[Union[str, Tuple[str, ...]]],
+ dependencies: Optional[Tuple[str, ...]] = None,
+ branch_labels: Optional[Tuple[str, ...]] = None,
+ ) -> None:
if down_revision and revision in util.to_tuple(down_revision):
raise LoopDetected(revision)
elif dependencies is not None and revision in util.to_tuple(
self._orig_branch_labels = util.to_tuple(branch_labels, default=())
self.branch_labels = set(self._orig_branch_labels)
- def __repr__(self):
+ def __repr__(self) -> str:
args = [repr(self.revision), repr(self.down_revision)]
if self.dependencies:
args.append("dependencies=%r" % (self.dependencies,))
args.append("branch_labels=%r" % (self.branch_labels,))
return "%s(%s)" % (self.__class__.__name__, ", ".join(args))
- def add_nextrev(self, revision):
+ def add_nextrev(self, revision: "Revision") -> None:
self._all_nextrev = self._all_nextrev.union([revision.revision])
if self.revision in revision._versioned_down_revisions:
self.nextrev = self.nextrev.union([revision.revision])
@property
- def _all_down_revisions(self):
+ def _all_down_revisions(self) -> Tuple[str, ...]:
return util.dedupe_tuple(
util.to_tuple(self.down_revision, default=())
+ self._resolved_dependencies
)
@property
- def _normalized_down_revisions(self):
+ def _normalized_down_revisions(self) -> Tuple[str, ...]:
"""return immediate down revisions for a rev, omitting dependencies
that are still dependencies of ancestors.
)
@property
- def _versioned_down_revisions(self):
+ def _versioned_down_revisions(self) -> Tuple[str, ...]:
return util.to_tuple(self.down_revision, default=())
@property
- def is_head(self):
+ def is_head(self) -> bool:
"""Return True if this :class:`.Revision` is a 'head' revision.
This is determined based on whether any other :class:`.Script`
return not bool(self.nextrev)
@property
- def _is_real_head(self):
+ def _is_real_head(self) -> bool:
return not bool(self._all_nextrev)
@property
- def is_base(self):
+ def is_base(self) -> bool:
"""Return True if this :class:`.Revision` is a 'base' revision."""
return self.down_revision is None
@property
- def _is_real_base(self):
+ def _is_real_base(self) -> bool:
"""Return True if this :class:`.Revision` is a "real" base revision,
e.g. that it has no dependencies either."""
return self.down_revision is None and self.dependencies is None
@property
- def is_branch_point(self):
+ def is_branch_point(self) -> bool:
"""Return True if this :class:`.Script` is a branch point.
A branchpoint is defined as a :class:`.Script` which is referred
return len(self.nextrev) > 1
@property
- def _is_real_branch_point(self):
+ def _is_real_branch_point(self) -> bool:
"""Return True if this :class:`.Script` is a 'real' branch point,
taking into account dependencies as well.
return len(self._all_nextrev) > 1
@property
- def is_merge_point(self):
+ def is_merge_point(self) -> bool:
"""Return True if this :class:`.Script` is a merge point."""
return len(self._versioned_down_revisions) > 1
-def tuple_rev_as_scalar(rev):
+def tuple_rev_as_scalar(
+ rev: Optional[Sequence[str]],
+) -> Optional[Union[str, Sequence[str]]]:
if not rev:
return None
elif len(rev) == 1:
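# e.g. tuple_rev_as_scalar(()) -> None, tuple_rev_as_scalar(("abc",)) ->
# "abc"; longer sequences pass through unchanged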
import shlex
import subprocess
import sys
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import List
+from typing import Union
from .. import util
from ..util import compat
_registry = {}
-def register(name):
+def register(name: str) -> Callable:
"""A function decorator that will register that function as a write hook.
See the documentation linked below for an example.
return decorate
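# a minimal sketch of registering a custom write hook, following the
# documented pattern; "spaces_to_tabs" is a hypothetical hook name:
#
#     from alembic.script import write_hooks
#
#     @write_hooks.register("spaces_to_tabs")
#     def spaces_to_tabs(filename, options):
#         # rewrite the generated revision file in place
#         ...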
-def _invoke(name, revision, options):
+def _invoke(
+ name: str, revision: str, options: Dict[str, Union[str, int]]
+) -> Any:
"""Invokes the formatter registered for the given name.
:param name: The name of a formatter in the registry
return hook(revision, options)
-def _run_hooks(path, hook_config):
+def _run_hooks(path: str, hook_config: Dict[str, str]) -> None:
"""Invoke hooks for a generated revision."""
from .base import _split_on_space_comma
)
-def _parse_cmdline_options(cmdline_options_str, path):
+def _parse_cmdline_options(cmdline_options_str: str, path: str) -> List[str]:
"""Parse options from a string into a list.
Also substitutes the revision script token with the actual filename of
-from __future__ import absolute_import
-
import contextlib
import re
import sys
+from typing import Any
+from typing import Dict
from sqlalchemy import exc as sa_exc
from sqlalchemy import util
assert a == b, msg or "%r != %r" % (a, b)
-_dialect_mods = {}
+_dialect_mods: Dict[Any, Any] = {}
def _get_dialect(name):
from contextlib import contextmanager
import io
import re
+from typing import Any
+from typing import Dict
from sqlalchemy import Column
from sqlalchemy import inspect
from sqlalchemy.testing.fixtures import FutureEngineMixin
else:
- class FutureEngineMixin:
+ class FutureEngineMixin: # type:ignore[no-redef]
__requires__ = ("sqlalchemy_14",)
return engine, buf
-_engs = {}
+_engs: Dict[Any, Any] = {}
@contextmanager
-import sys
-
from sqlalchemy.testing.requirements import Requirements
from alembic import util
"SQLAlchemy 1.4 or greater required",
)
- @property
- def python3(self):
- return exclusions.skip_if(
- lambda: sys.version_info < (3,), "Python version 3.xx is required."
- )
-
@property
def comments(self):
return exclusions.only_if(
+from typing import Any
+from typing import Dict
+
from sqlalchemy import CHAR
from sqlalchemy import CheckConstraint
from sqlalchemy import Column
def _get_bind(cls):
return config.db
- configure_opts = {}
+ configure_opts: Dict[Any, Any] = {}
@classmethod
def setup_class(cls):
import inspect
import io
import os
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Type
is_posix = os.name == "posix"
)
-def inspect_getargspec(func):
+def inspect_getargspec(func: Callable) -> ArgSpec:
"""getargspec based on fully vendored getfullargspec from Python 3.3."""
if inspect.ismethod(func):
- func = func.__func__
+ func = func.__func__ # type: ignore
if not inspect.isfunction(func):
raise TypeError("{!r} is not a Python function".format(func))
if co.co_flags & inspect.CO_VARKEYWORDS:
varkw = co.co_varnames[nargs]
- return ArgSpec(args, varargs, varkw, func.__defaults__)
+ return ArgSpec(args, varargs, varkw, func.__defaults__) # type: ignore
string_types = (str,)
def inspect_formatargspec(
- args,
- varargs=None,
- varkw=None,
- defaults=None,
- kwonlyargs=(),
- kwonlydefaults={},
- annotations={},
- formatarg=str,
- formatvarargs=lambda name: "*" + name,
- formatvarkw=lambda name: "**" + name,
- formatvalue=lambda value: "=" + repr(value),
- formatreturns=lambda text: " -> " + text,
- formatannotation=_formatannotation,
-):
+ args: List[str],
+ varargs: Optional[str] = None,
+ varkw: Optional[str] = None,
+ defaults: Optional[Any] = None,
+ kwonlyargs: tuple = (),
+ kwonlydefaults: Dict[Any, Any] = {},
+ annotations: Dict[Any, Any] = {},
+ formatarg: Type[str] = str,
+ formatvarargs: Callable = lambda name: "*" + name,
+ formatvarkw: Callable = lambda name: "**" + name,
+ formatvalue: Callable = lambda value: "=" + repr(value),
+ formatreturns: Callable = lambda text: " -> " + text,
+ formatannotation: Callable = _formatannotation,
+) -> str:
"""Copy formatargspec from python 3.7 standard library.
Python 3 has deprecated formatargspec and requested that Signature
# into a given buffer, but doesn't close it.
# not sure of a more idiomatic approach to this.
class EncodedIO(io.TextIOWrapper):
- def close(self):
+ def close(self) -> None:
pass
from os.path import join
from os.path import splitext
from subprocess import check_call
+from typing import Dict
+from typing import List
+from typing import Mapping
+from typing import Optional
from .compat import is_posix
from .exc import CommandError
-def open_in_editor(filename, environ=None):
+def open_in_editor(
+ filename: str, environ: Optional[Dict[str, str]] = None
+) -> None:
"""
Opens the given file in a text editor. If the environment variable
``EDITOR`` is set, this is taken as preference.
:param environ: An optional drop-in replacement for ``os.environ``. Used
mainly for testing.
"""
-
+ env = os.environ if environ is None else environ
try:
- editor = _find_editor(environ)
+ editor = _find_editor(env)
check_call([editor, filename])
except Exception as exc:
raise CommandError("Error executing editor (%s)" % (exc,)) from exc
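# a usage sketch with an illustrative filename; per the docstring,
# $EDITOR / $VISUAL are preferred, with nano/vim as fallbacks:
#
#     open_in_editor("alembic.ini")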
-def _find_editor(environ=None):
+def _find_editor(environ: Mapping[str, str]) -> str:
candidates = _default_editors()
for i, var in enumerate(("EDITOR", "VISUAL")):
if var in environ:
)
-def _find_executable(candidate, environ):
+def _find_executable(
+ candidate: str, environ: Mapping[str, str]
+) -> Optional[str]:
# Assuming this is on the PATH, we need to determine its absolute
# location. Otherwise, ``check_call`` will fail
if not is_posix and splitext(candidate)[1] != ".exe":
return None
-def _default_editors():
+def _default_editors() -> List[str]:
# Look for an editor. Prefer the user's choice by env-var, fall back to
# most commonly installed editor (nano/vim)
if is_posix:
import collections
from collections.abc import Iterable
import textwrap
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import overload
+from typing import Sequence
+from typing import Tuple
+from typing import TypeVar
+from typing import Union
import uuid
import warnings
from .compat import string_types
+_T = TypeVar("_T")
+
+
class _ModuleClsMeta(type):
- def __setattr__(cls, key, value):
+ def __setattr__(cls, key: str, value: Callable) -> None:
super(_ModuleClsMeta, cls).__setattr__(key, value)
- cls._update_module_proxies(key)
+ cls._update_module_proxies(key) # type: ignore
class ModuleClsProxy(metaclass=_ModuleClsMeta):
"""
- _setups = collections.defaultdict(lambda: (set(), []))
+ _setups: Dict[type, Tuple[set, list]] = collections.defaultdict(
+ lambda: (set(), [])
+ )
@classmethod
- def _update_module_proxies(cls, name):
+ def _update_module_proxies(cls, name: str) -> None:
attr_names, modules = cls._setups[cls]
for globals_, locals_ in modules:
cls._add_proxied_attribute(name, globals_, locals_, attr_names)
- def _install_proxy(self):
+ def _install_proxy(self) -> None:
attr_names, modules = self._setups[self.__class__]
for globals_, locals_ in modules:
globals_["_proxy"] = self
for attr_name in attr_names:
globals_[attr_name] = getattr(self, attr_name)
- def _remove_proxy(self):
+ def _remove_proxy(self) -> None:
attr_names, modules = self._setups[self.__class__]
for globals_, locals_ in modules:
globals_["_proxy"] = None
return decorate
-def rev_id():
+def rev_id() -> str:
return uuid.uuid4().hex[-12:]
+@overload
+def to_tuple(x: Any, default: tuple) -> tuple:
+ ...
+
+
+@overload
+def to_tuple(x: None, default: _T = None) -> _T:
+ ...
+
+
+@overload
+def to_tuple(x: Any, default: Optional[tuple] = None) -> tuple:
+ ...
+
+
def to_tuple(x, default=None):
if x is None:
return default
return (x,)
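# e.g. to_tuple(None) -> the default, to_tuple("abc") -> ("abc",),
# to_tuple(["a", "b"]) -> ("a", "b")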
-def dedupe_tuple(tup):
+def dedupe_tuple(tup: Tuple[str, ...]) -> Tuple[str, ...]:
return tuple(unique_list(tup))
class Dispatcher:
- def __init__(self, uselist=False):
- self._registry = {}
+ def __init__(self, uselist: bool = False) -> None:
+ self._registry: Dict[tuple, Any] = {}
self.uselist = uselist
- def dispatch_for(self, target, qualifier="default"):
+ def dispatch_for(
+ self, target: Any, qualifier: str = "default"
+ ) -> Callable:
def decorate(fn):
if self.uselist:
self._registry.setdefault((target, qualifier), []).append(fn)
return decorate
- def dispatch(self, obj, qualifier="default"):
+ def dispatch(self, obj: Any, qualifier: str = "default") -> Any:
if isinstance(obj, string_types):
- targets = [obj]
+ targets: Sequence = [obj]
elif isinstance(obj, type):
targets = obj.__mro__
else:
else:
raise ValueError("no dispatch function for object: %s" % obj)
- def _fn_or_list(self, fn_or_list):
+ def _fn_or_list(
+ self, fn_or_list: Union[List[Callable], Callable]
+ ) -> Callable:
if self.uselist:
def go(*arg, **kw):
return go
else:
- return fn_or_list
+ return fn_or_list # type: ignore
- def branch(self):
+ def branch(self) -> "Dispatcher":
"""Return a copy of this dispatcher that is independently
writable."""
import logging
import sys
import textwrap
+from typing import Any
+from typing import Callable
+from typing import Optional
+from typing import TextIO
+from typing import Union
import warnings
from sqlalchemy.engine import url
TERMWIDTH = None
-def write_outstream(stream, *text):
+def write_outstream(stream: TextIO, *text) -> None:
encoding = getattr(stream, "encoding", "ascii") or "ascii"
for t in text:
if not isinstance(t, binary_type):
break
-def status(_statmsg, fn, *arg, **kw):
+def status(_statmsg: str, fn: Callable, *arg, **kw) -> Any:
newline = kw.pop("newline", False)
msg(_statmsg + " ...", newline, True)
try:
raise
-def err(message):
+def err(message: str):
log.error(message)
msg("FAILED: %s" % message)
sys.exit(-1)
-def obfuscate_url_pw(u):
- u = url.make_url(u)
+def obfuscate_url_pw(input_url: str) -> str:
+ u = url.make_url(input_url)
if u.password:
if sqla_compat.sqla_14:
u = u.set(password="XXXXX")
else:
- u.password = "XXXXX"
+ u.password = "XXXXX" # type: ignore[misc]
return str(u)
-def warn(msg, stacklevel=2):
+def warn(msg: str, stacklevel: int = 2) -> None:
warnings.warn(msg, UserWarning, stacklevel=stacklevel)
-def msg(msg, newline=True, flush=False):
+def msg(msg: str, newline: bool = True, flush: bool = False) -> None:
if TERMWIDTH is None:
write_outstream(sys.stdout, msg)
if newline:
sys.stdout.flush()
-def format_as_comma(value):
+def format_as_comma(value: Optional[Union[str, "Iterable[str]"]]) -> str:
if value is None:
return ""
elif isinstance(value, string_types):
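# e.g. format_as_comma(None) -> "", format_as_comma("x") -> "x",
# format_as_comma(["x", "y"]) -> "x, y"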
import os
import re
import tempfile
+from typing import Optional
from mako import exceptions
from mako.template import Template
from .exc import CommandError
-def template_to_file(template_file, dest, output_encoding, **kw):
+def template_to_file(
+ template_file: str, dest: str, output_encoding: str, **kw
+) -> None:
template = Template(filename=template_file)
try:
output = template.render_unicode(**kw).encode(output_encoding)
f.write(output)
-def coerce_resource_to_filename(fname):
+def coerce_resource_to_filename(fname: str) -> str:
"""Interpret a filename as either a filesystem location or as a package
resource.
return fname
-def pyc_file_from_path(path):
+def pyc_file_from_path(path: str) -> Optional[str]:
"""Given a python source path, locate the .pyc."""
candidate = importlib.util.cache_from_source(path)
return None
-def load_python_file(dir_, filename):
+def load_python_file(dir_: str, filename: str):
"""Load a file from the given path as a Python module."""
module_id = re.sub(r"\W", "_", filename)
if pyc_path is None:
raise ImportError("Can't find Python file %s" % path)
else:
- module = load_module_pyc(module_id, pyc_path)
+ module = load_module_py(module_id, pyc_path)
elif ext in (".pyc", ".pyo"):
- module = load_module_pyc(module_id, path)
+ module = load_module_py(module_id, path)
return module
-def load_module_py(module_id, path):
+def load_module_py(module_id: str, path: str):
spec = importlib.util.spec_from_file_location(module_id, path)
+ assert spec
module = importlib.util.module_from_spec(spec)
- spec.loader.exec_module(module)
- return module
-
-
-def load_module_pyc(module_id, path):
- spec = importlib.util.spec_from_file_location(module_id, path)
- module = importlib.util.module_from_spec(spec)
- spec.loader.exec_module(module)
+ spec.loader.exec_module(module) # type: ignore
return module
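# a minimal usage sketch with hypothetical paths:
#
#     env_module = load_python_file("/path/to/migrations", "env.py")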
import contextlib
import re
+from typing import Iterator
+from typing import Mapping
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import TypeVar
+from typing import Union
from sqlalchemy import __version__
from sqlalchemy import inspect
from sqlalchemy.schema import Column
from sqlalchemy.schema import ForeignKeyConstraint
from sqlalchemy.sql import visitors
+from sqlalchemy.sql.elements import BindParameter
from sqlalchemy.sql.elements import quoted_name
-from sqlalchemy.sql.expression import _BindParamClause
-from sqlalchemy.sql.expression import _TextClause as TextClause
+from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.visitors import traverse
from . import compat
-
-def _safe_int(value):
+if TYPE_CHECKING:
+ from sqlalchemy import Index
+ from sqlalchemy import Table
+ from sqlalchemy.engine import Connection
+ from sqlalchemy.engine import Dialect
+ from sqlalchemy.engine import Transaction
+ from sqlalchemy.engine.reflection import Inspector
+ from sqlalchemy.sql.base import ColumnCollection
+ from sqlalchemy.sql.compiler import SQLCompiler
+ from sqlalchemy.sql.dml import Insert
+ from sqlalchemy.sql.elements import ColumnClause
+ from sqlalchemy.sql.elements import ColumnElement
+ from sqlalchemy.sql.schema import Constraint
+ from sqlalchemy.sql.schema import SchemaItem
+ from sqlalchemy.sql.selectable import Select
+ from sqlalchemy.sql.selectable import TableClause
+
+_CE = TypeVar("_CE", bound=Union["ColumnElement", "SchemaItem"])
+
+
+def _safe_int(value: str) -> Union[int, str]:
try:
return int(value)
except:
try:
from sqlalchemy import Computed # noqa
except ImportError:
+ Computed = None # type: ignore
has_computed = False
has_computed_reflection = False
else:
try:
from sqlalchemy import Identity # noqa
except ImportError:
+ Identity = None # type: ignore
has_identity = False
else:
# attributes common to Identity and Sequence
@contextlib.contextmanager
-def _ensure_scope_for_ddl(connection):
+def _ensure_scope_for_ddl(
+ connection: Optional["Connection"],
+) -> Iterator[None]:
try:
- in_transaction = connection.in_transaction
+ in_transaction = connection.in_transaction # type: ignore[union-attr]
except AttributeError:
- # catch for MockConnection
+ # catch for MockConnection, None
yield
else:
if not in_transaction():
+ assert connection is not None
with connection.begin():
yield
else:
yield
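# a rough usage sketch, assuming an existing Connection "conn" and a DDL
# construct "some_ddl"; a transaction is begun only when one is not
# already in progress:
#
#     with _ensure_scope_for_ddl(conn):
#         conn.execute(some_ddl)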
-def _safe_begin_connection_transaction(connection):
+def _safe_begin_connection_transaction(
+ connection: "Connection",
+) -> "Transaction":
transaction = _get_connection_transaction(connection)
if transaction:
return transaction
return connection.begin()
-def _get_connection_in_transaction(connection):
+def _get_connection_in_transaction(connection: Optional["Connection"]) -> bool:
try:
- in_transaction = connection.in_transaction
+ in_transaction = connection.in_transaction # type: ignore
except AttributeError:
# catch for MockConnection
return False
return in_transaction()
-def _copy(schema_item, **kw):
+def _copy(schema_item: _CE, **kw) -> _CE:
if hasattr(schema_item, "_copy"):
return schema_item._copy(**kw)
else:
return schema_item.copy(**kw)
-def _get_connection_transaction(connection):
+def _get_connection_transaction(
+ connection: "Connection",
+) -> Optional["Transaction"]:
if sqla_14:
return connection.get_transaction()
else:
- return connection._root._Connection__transaction
+ r = connection._root # type: ignore[attr-defined]
+ return r._Connection__transaction
-def _create_url(*arg, **kw):
+def _create_url(*arg, **kw) -> url.URL:
if hasattr(url.URL, "create"):
return url.URL.create(*arg, **kw)
else:
return url.URL(*arg, **kw)
-def _connectable_has_table(connectable, tablename, schemaname):
+def _connectable_has_table(
+ connectable: "Connection", tablename: str, schemaname: Union[str, None]
+) -> bool:
if sqla_14:
return inspect(connectable).has_table(tablename, schemaname)
else:
)
-def _server_default_is_computed(*server_default):
+def _server_default_is_computed(*server_default) -> bool:
if not has_computed:
return False
else:
return any(isinstance(sd, Computed) for sd in server_default)
-def _server_default_is_identity(*server_default):
+def _server_default_is_identity(*server_default) -> bool:
if not sqla_14:
return False
else:
return any(isinstance(sd, Identity) for sd in server_default)
-def _table_for_constraint(constraint):
+def _table_for_constraint(constraint: "Constraint") -> "Table":
if isinstance(constraint, ForeignKeyConstraint):
- return constraint.parent
+ table = constraint.parent
+ assert table is not None
+ return table
else:
return constraint.table
return list(constraint.columns)
-def _reflect_table(inspector, table, include_cols):
+def _reflect_table(
+ inspector: "Inspector", table: "Table", include_cols: None
+) -> None:
if sqla_14:
return inspector.reflect_table(table, None)
else:
)
-def _fk_is_self_referential(constraint):
- spec = constraint.elements[0]._get_colspec()
+def _fk_is_self_referential(constraint: "ForeignKeyConstraint") -> bool:
+ spec = constraint.elements[0]._get_colspec() # type: ignore[attr-defined]
tokens = spec.split(".")
tokens.pop(-1) # colname
tablekey = ".".join(tokens)
+ assert constraint.parent is not None
return tablekey == constraint.parent.key
-def _is_type_bound(constraint):
+def _is_type_bound(constraint: "Constraint") -> bool:
# this deals with SQLAlchemy #3260, don't copy CHECK constraints
# that will be generated by the type.
# new feature added for #3260
- return constraint._type_bound
+ return constraint._type_bound # type: ignore[attr-defined]
def _find_columns(clause):
return cols
-def _remove_column_from_collection(collection, column):
+def _remove_column_from_collection(
+ collection: "ColumnCollection", column: Union["Column", "ColumnClause"]
+) -> None:
"""remove a column from a ColumnCollection."""
# workaround for older SQLAlchemy, remove the
# same object that's present
+ assert column.key is not None
to_remove = collection[column.key]
collection.remove(to_remove)
-def _textual_index_column(table, text_):
+def _textual_index_column(
+ table: "Table", text_: Union[str, "TextClause", "ColumnElement"]
+) -> Union["ColumnElement", "Column"]:
"""a workaround for the Index construct's severe lack of flexibility"""
if isinstance(text_, compat.string_types):
c = Column(text_, sqltypes.NULLTYPE)
raise ValueError("String or text() construct expected")
-def _copy_expression(expression, target_table):
+def _copy_expression(expression: _CE, target_table: "Table") -> _CE:
def replace(col):
if (
isinstance(col, Column)
__visit_name__ = "_textual_idx_element"
- def __init__(self, table, text):
+ def __init__(self, table: "Table", text: "TextClause") -> None:
self.table = table
self.text = text
self.key = text.text
@compiles(_textual_index_element)
-def _render_textual_index_column(element, compiler, **kw):
+def _render_textual_index_column(
+ element: _textual_index_element, compiler: "SQLCompiler", **kw
+) -> str:
return compiler.process(element.text, **kw)
-class _literal_bindparam(_BindParamClause):
+class _literal_bindparam(BindParameter):
pass
@compiles(_literal_bindparam)
-def _render_literal_bindparam(element, compiler, **kw):
+def _render_literal_bindparam(
+ element: _literal_bindparam, compiler: "SQLCompiler", **kw
+) -> str:
return compiler.render_literal_bindparam(element, **kw)
return [getattr(exp, "name", None) for exp in _get_index_expressions(idx)]
-def _column_kwargs(col):
+def _column_kwargs(col: "Column") -> Mapping:
if sqla_13:
return col.kwargs
else:
return {}
-def _get_constraint_final_name(constraint, dialect):
+def _get_constraint_final_name(
+ constraint: Union["Index", "Constraint"], dialect: Optional["Dialect"]
+) -> Optional[str]:
if constraint.name is None:
return None
- elif sqla_14:
+ assert dialect is not None
+ if sqla_14:
# for SQLAlchemy 1.4 we would like to have the option to expand
# the use of "deferred" names for constraints as well as to have
# some flexibility with "None" name and similar; make use of new
if hasattr(constraint.name, "quote"):
# might be quoted_name, might be truncated_name, keep it the
# same
- quoted_name_cls = type(constraint.name)
+ quoted_name_cls: type = type(constraint.name)
else:
quoted_name_cls = quoted_name
if isinstance(constraint, schema.Index):
# name should not be quoted.
- return dialect.ddl_compiler(dialect, None)._prepared_index_name(
+ d = dialect.ddl_compiler(dialect, None)
+ return d._prepared_index_name( # type: ignore[attr-defined]
constraint
)
else:
return dialect.identifier_preparer.format_constraint(constraint)
-def _constraint_is_named(constraint, dialect):
+def _constraint_is_named(
+ constraint: Union["Constraint", "Index"], dialect: Optional["Dialect"]
+) -> bool:
if sqla_14:
if constraint.name is None:
return False
+ assert dialect is not None
name = dialect.identifier_preparer.format_constraint(
constraint, _alembic_quote=False
)
return constraint.name is not None
-def _is_mariadb(mysql_dialect):
+def _is_mariadb(mysql_dialect: "Dialect") -> bool:
if sqla_14:
- return mysql_dialect.is_mariadb
+ return mysql_dialect.is_mariadb # type: ignore[attr-defined]
else:
- return mysql_dialect.server_version_info and mysql_dialect._is_mariadb
+ return bool(
+ mysql_dialect.server_version_info
+ and mysql_dialect._is_mariadb # type: ignore[attr-defined]
+ )
def _mariadb_normalized_version_info(mysql_dialect):
return mysql_dialect._mariadb_normalized_version_info
-def _insert_inline(table):
+def _insert_inline(table: Union["TableClause", "Table"]) -> "Insert":
if sqla_14:
return table.insert().inline()
else:
else:
from sqlalchemy import create_engine
- def create_mock_engine(url, executor):
+ def create_mock_engine(url, executor, **kw): # type: ignore[misc]
return create_engine(
"postgresql://", strategy="mock", executor=executor
)
- def _select(*columns):
- return sql.select(list(columns))
+ def _select(*columns, **kw) -> "Select":
+ return sql.select(list(columns), **kw)
--- /dev/null
+.. change::
+ :tags: feature, general
+
+ pep-484 type annotations have been added throughout the library. This
+ should be helpful in providing Mypy and IDE support, however there is not
+ full support for Alembic's dynamically modified "op" namespace as of yet; a
+ future release will likely modify the approach used for importing this
+ namespace to be better compatible with pep-484 capabilities.
\ No newline at end of file
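As a rough illustration, the annotations allow a type checker to validate
runtime API use such as the following sketch (assuming an existing
SQLAlchemy ``Engine`` named ``engine``)::

    from alembic.runtime.migration import MigrationContext

    with engine.connect() as conn:
        context = MigrationContext.configure(conn)
        current = context.get_current_revision()  # Optional[str]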
application-import-names = alembic,tests
per-file-ignores =
**/__init__.py:F401
+max-line-length = 79
[sqla_testing]
requirement_cls=tests.requirements:DefaultRequirements
addopts= --tb native -v -r sfxX -p no:warnings -p no:logging --maxfail=25
python_files=tests/test_*.py
+[mypy]
+show_error_codes = True
+allow_redefinition = True
+[mypy-mako.*]
+ignore_missing_imports = True
+
+[mypy-sqlalchemy.testing.*]
+ignore_missing_imports = True
from alembic.script.revision import RevisionError
from alembic.script.revision import RevisionMap
from alembic.testing import assert_raises_message
-from alembic.testing import config
from alembic.testing import eq_
+from alembic.testing import expect_raises_message
from alembic.testing.fixtures import TestBase
from . import _large_map
class APITest(TestBase):
- @config.requirements.python3
def test_invalid_datatype(self):
map_ = RevisionMap(
lambda: [
Revision("c", ("b",)),
]
)
- assert_raises_message(
+ with expect_raises_message(
RevisionError,
"revision identifier b'12345' is not a string; "
"ensure database driver settings are correct",
- map_.get_revisions,
- b"12345",
- )
+ ):
+ map_.get_revisions(b"12345")
- assert_raises_message(
+ with expect_raises_message(
RevisionError,
"revision identifier b'12345' is not a string; "
"ensure database driver settings are correct",
- map_.get_revision,
- b"12345",
- )
+ ):
+ map_.get_revision(b"12345")
- assert_raises_message(
+ with expect_raises_message(
RevisionError,
r"revision identifier \(b'12345',\) is not a string; "
"ensure database driver settings are correct",
- map_.get_revision,
- (b"12345",),
- )
+ ):
+ map_.get_revision((b"12345",))
map_.get_revision(("a",))
map_.get_revision("a")
c1 = map_.get_revision("c1")
c2 = map_.get_revision("c2")
d = map_.get_revision("d")
- eq_(map_.filter_for_lineage([c1, c2, d], "c1branch@head"), [c1, c2, d])
+ eq_(map_.filter_for_lineage([c1, c2, d], "c1branch@head"), (c1, c2, d))
def test_filter_for_lineage_heads(self):
eq_(
self.map.filter_for_lineage([self.map.get_revision("f")], "heads"),
- [self.map.get_revision("f")],
+ (self.map.get_revision("f"),),
)
def setUp(self):
)
def test_get_base_revisions_labeled(self):
- eq_(self.map._get_base_revisions("somelongername@base"), ["a"])
+ eq_(self.map._get_base_revisions("somelongername@base"), ("a",))
def test_get_current_named_rev(self):
eq_(self.map.get_revision("ebranch@head"), self.map.get_revision("f"))
def test_get_base_revisions(self):
- eq_(self.map._get_base_revisions("base"), ["a", "d"])
+ eq_(self.map._get_base_revisions("base"), ("a", "d"))
def test_iterate_head_to_named_base(self):
self._assert_iteration(
{oracle,mssql}: python reap_dbs.py db_idents.txt
+[testenv:mypy]
+basepython = python3
+deps=
+ mypy
+ sqlalchemy>=1.4.0
+ sqlalchemy2-stubs
+ mako
+ types-pkg-resources
+ types-python-dateutil
# is imported in alembic/testing and mypy complains if it's not installed.
+ pytest
+commands = mypy ./alembic/ --exclude alembic/templates
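# the environment above is invoked locally with "tox -e mypy"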
+
[testenv:pep8]
basepython = python3
deps=