"""Provide the 'autogenerate' feature which can produce migration operations
automatically."""
+from __future__ import annotations
+
import contextlib
from typing import Any
from .. import util
from ..operations import ops
if TYPE_CHECKING:
from sqlalchemy.engine import Connection
from sqlalchemy.engine import Dialect
branch_labels=migration_script.branch_label,
version_path=migration_script.version_path,
depends_on=migration_script.depends_on,
- **template_args
+ **template_args,
)
def run_autogenerate(
+from __future__ import annotations
+
import contextlib
import logging
import re
params["name"],
*[conn_table.c[cname] for cname in params["column_names"]],
unique=params["unique"],
- _table=conn_table
+ _table=conn_table,
)
if "duplicates_constraint" in params:
ix.info["duplicates_constraint"] = params["duplicates_constraint"]
) -> "UniqueConstraint":
uq = sa_schema.UniqueConstraint(
*[conn_table.c[cname] for cname in params["column_names"]],
- name=params["name"]
+ name=params["name"],
)
if "duplicates_index" in params:
uq.info["duplicates_index"] = params["duplicates_index"]
if isinstance(fk, sa_schema.ForeignKeyConstraint)
)
- conn_fks = [
+ conn_fks_list = [
fk
for fk in inspector.get_foreign_keys(tname, schema=schema)
if autogen_context.run_name_filters(
)
]
- backend_reflects_fk_options = bool(conn_fks and "options" in conn_fks[0])
+ backend_reflects_fk_options = bool(
+ conn_fks_list and "options" in conn_fks_list[0]
+ )
- conn_fks = set(_make_foreign_key(const, conn_table) for const in conn_fks)
+ conn_fks = set(
+ _make_foreign_key(const, conn_table) for const in conn_fks_list
+ )
# give the dialect a chance to correct the FKs to match more
# closely
conn_fks, metadata_fks
)
- metadata_fks = set(
+ metadata_fks_sig = set(
_fk_constraint_sig(fk, include_options=backend_reflects_fk_options)
for fk in metadata_fks
)
- conn_fks = set(
+ conn_fks_sig = set(
_fk_constraint_sig(fk, include_options=backend_reflects_fk_options)
for fk in conn_fks
)
- conn_fks_by_sig = dict((c.sig, c) for c in conn_fks)
- metadata_fks_by_sig = dict((c.sig, c) for c in metadata_fks)
+ conn_fks_by_sig = dict((c.sig, c) for c in conn_fks_sig)
+ metadata_fks_by_sig = dict((c.sig, c) for c in metadata_fks_sig)
metadata_fks_by_name = dict(
- (c.name, c) for c in metadata_fks if c.name is not None
+ (c.name, c) for c in metadata_fks_sig if c.name is not None
)
conn_fks_by_name = dict(
- (c.name, c) for c in conn_fks if c.name is not None
+ (c.name, c) for c in conn_fks_sig if c.name is not None
)
def _add_fk(obj, compare_to):
+from __future__ import annotations
+
from collections import OrderedDict
from io import StringIO
import re
+from __future__ import annotations
+
from typing import Any
from typing import Callable
from typing import Iterator
from typing import List
+from typing import Optional
from typing import Type
from typing import TYPE_CHECKING
from typing import Union
_traverse = util.Dispatcher()
- _chained = None
+ _chained: Optional[Rewriter] = None
def __init__(self) -> None:
self.dispatch = util.Dispatcher()
- def chain(self, other: "Rewriter") -> "Rewriter":
+ def chain(self, other: Rewriter) -> Rewriter:
"""Produce a "chain" of this :class:`.Rewriter` to another.
This allows two rewriters to operate serially on a stream,
+from __future__ import annotations
+
import os
from typing import Callable
-from typing import cast
from typing import List
from typing import Optional
from typing import TYPE_CHECKING
for file_ in os.listdir(template_dir):
file_path = os.path.join(template_dir, file_)
if file_ == "alembic.ini.mako":
- config_file = os.path.abspath(cast(str, config.config_file_name))
- if os.access(cast(str, config_file), os.F_OK):
+ assert config.config_file_name is not None
+ config_file = os.path.abspath(config.config_file_name)
+ if os.access(config_file, os.F_OK):
util.msg("File %s already exists, skipping" % config_file)
else:
script._generate_template(
refresh=True,
head=revisions,
branch_labels=branch_label,
- **template_args # type:ignore[arg-type]
+ **template_args, # type:ignore[arg-type]
)
"No revision files indicated by symbol '%s'" % rev
)
for sc in revs:
+ assert sc
util.open_in_editor(sc.path)
+from __future__ import annotations
+
from argparse import ArgumentParser
from argparse import Namespace
from configparser import ConfigParser
fn(
config,
*[getattr(options, k, None) for k in positional],
- **dict((k, getattr(options, k, None)) for k in kwarg)
+ **dict((k, getattr(options, k, None)) for k in kwarg),
)
except util.CommandError as e:
if options.raiseerr:
# ### this file stubs are generated by tools/write_pyi.py - do not edit ###
# ### imports are manually managed
+from __future__ import annotations
from typing import Callable
from typing import ContextManager
config: Config
def configure(
- connection: Optional["Connection"] = None,
+ connection: Optional[Connection] = None,
url: Optional[str] = None,
dialect_name: Optional[str] = None,
dialect_opts: Optional[dict] = None,
tag: Optional[str] = None,
template_args: Optional[dict] = None,
render_as_batch: bool = False,
- target_metadata: Optional["MetaData"] = None,
+ target_metadata: Optional[MetaData] = None,
include_name: Optional[Callable] = None,
include_object: Optional[Callable] = None,
include_schemas: bool = False,
sqlalchemy_module_prefix: str = "sa.",
user_module_prefix: Optional[str] = None,
on_version_apply: Optional[Callable] = None,
- **kw
+ **kw,
) -> None:
"""Configure a :class:`.MigrationContext` within this
:class:`.EnvironmentContext` which will provide database
"""
-def get_context() -> "MigrationContext":
+def get_context() -> MigrationContext:
"""Return the current :class:`.MigrationContext` object.
If :meth:`.EnvironmentContext.configure` has not been
+from __future__ import annotations
+
import functools
from typing import Optional
from typing import TYPE_CHECKING
name: str,
column_name: str,
default: Optional[_ServerDefault],
- **kw
+ **kw,
) -> None:
super(ColumnDefault, self).__init__(name, column_name, **kw)
self.default = default
column_name: str,
default: Optional["Identity"],
impl: "DefaultImpl",
- **kw
+ **kw,
) -> None:
super(IdentityColumnDefault, self).__init__(name, column_name, **kw)
self.default = default
+from __future__ import annotations
+
from collections import namedtuple
import re
from typing import Any
existing_server_default: Optional["_ServerDefault"] = None,
existing_nullable: Optional[bool] = None,
existing_autoincrement: Optional[bool] = None,
- **kw: Any
+ **kw: Any,
) -> None:
if autoincrement is not None or existing_autoincrement is not None:
util.warn(
existing_server_default=existing_server_default,
existing_nullable=existing_nullable,
existing_comment=existing_comment,
- **kw
+ **kw,
)
)
if type_ is not None:
table_name: str,
column: "Column",
schema: Optional[str] = None,
- **kw
+ **kw,
) -> None:
self._exec(base.DropColumn(table_name, column, schema=schema))
+from __future__ import annotations
+
from typing import Any
from typing import List
from typing import Optional
existing_type: Optional["TypeEngine"] = None,
existing_server_default: Optional["_ServerDefault"] = None,
existing_nullable: Optional[bool] = None,
- **kw: Any
+ **kw: Any,
) -> None:
if nullable is not None:
schema=schema,
existing_type=existing_type,
existing_nullable=existing_nullable,
- **kw
+ **kw,
)
if server_default is not False and used_default is False:
table_name: str,
column: "Column",
schema: Optional[str] = None,
- **kw
+ **kw,
) -> None:
drop_default = kw.pop("mssql_drop_default", False)
if drop_default:
+from __future__ import annotations
+
import re
from typing import Any
from typing import Optional
existing_autoincrement: Optional[bool] = None,
comment: Optional[Union[str, "Literal[False]"]] = False,
existing_comment: Optional[str] = None,
- **kw: Any
+ **kw: Any,
) -> None:
if sqla_compat._server_default_is_identity(
server_default, existing_server_default
existing_nullable=existing_nullable,
server_default=server_default,
existing_server_default=existing_server_default,
- **kw
+ **kw,
)
if name is not None or self._is_mysql_allowed_functional_default(
type_ if type_ is not None else existing_type, server_default
+from __future__ import annotations
+
from typing import Any
from typing import Optional
from typing import TYPE_CHECKING
+from __future__ import annotations
+
import logging
import re
from typing import Any
existing_server_default: Optional["_ServerDefault"] = None,
existing_nullable: Optional[bool] = None,
existing_autoincrement: Optional[bool] = None,
- **kw: Any
+ **kw: Any,
) -> None:
using = kw.pop("postgresql_using", None)
existing_server_default=existing_server_default,
existing_nullable=existing_nullable,
existing_autoincrement=existing_autoincrement,
- **kw
+ **kw,
)
def autogen_column_reflect(self, inspector, table, column_info):
where: Optional[Union["BinaryExpression", str]] = None,
schema: Optional[str] = None,
_orig_constraint: Optional["ExcludeConstraint"] = None,
- **kw
+ **kw,
) -> None:
self.constraint_name = constraint_name
self.table_name = table_name
*self.elements,
name=self.constraint_name,
where=self.where,
- **self.kw
+ **self.kw,
)
for (
expr,
constraint_name: str,
table_name: str,
*elements: Any,
- **kw: Any
+ **kw: Any,
) -> Optional["Table"]:
"""Issue an alter to create an EXCLUDE constraint using the
current migration context.
+from __future__ import annotations
+
import re
from typing import Any
from typing import Dict
### end imports ###
def add_column(
- table_name: str, column: "Column", schema: Optional[str] = None
-) -> Optional["Table"]:
+ table_name: str, column: Column, schema: Optional[str] = None
+) -> Optional[Table]:
"""Issue an "add column" instruction using the current
migration context.
comment: Union[str, bool, None] = False,
server_default: Any = False,
new_column_name: Optional[str] = None,
- type_: Union["TypeEngine", Type["TypeEngine"], None] = None,
- existing_type: Union["TypeEngine", Type["TypeEngine"], None] = None,
+ type_: Union[TypeEngine, Type[TypeEngine], None] = None,
+ existing_type: Union[TypeEngine, Type[TypeEngine], None] = None,
existing_server_default: Union[
- str, bool, "Identity", "Computed", None
+ str, bool, Identity, Computed, None
] = False,
existing_nullable: Optional[bool] = None,
existing_comment: Optional[str] = None,
schema: Optional[str] = None,
**kw
-) -> Optional["Table"]:
+) -> Optional[Table]:
"""Issue an "alter column" instruction using the
current migration context.
"""
def bulk_insert(
- table: Union["Table", "TableClause"],
+ table: Union[Table, TableClause],
rows: List[dict],
multiinsert: bool = True,
) -> None:
def create_check_constraint(
constraint_name: Optional[str],
table_name: str,
- condition: Union[str, "BinaryExpression"],
+ condition: Union[str, BinaryExpression],
schema: Optional[str] = None,
**kw
-) -> Optional["Table"]:
+) -> Optional[Table]:
"""Issue a "create check constraint" instruction using the
current migration context.
def create_exclude_constraint(
constraint_name: str, table_name: str, *elements: Any, **kw: Any
-) -> Optional["Table"]:
+) -> Optional[Table]:
"""Issue an alter to create an EXCLUDE constraint using the
current migration context.
source_schema: Optional[str] = None,
referent_schema: Optional[str] = None,
**dialect_kw
-) -> Optional["Table"]:
+) -> Optional[Table]:
"""Issue a "create foreign key" instruction using the
current migration context.
def create_index(
index_name: str,
table_name: str,
- columns: Sequence[Union[str, "TextClause", "Function"]],
+ columns: Sequence[Union[str, TextClause, Function]],
schema: Optional[str] = None,
unique: bool = False,
**kw
-) -> Optional["Table"]:
+) -> Optional[Table]:
"""Issue a "create index" instruction using the current
migration context.
table_name: str,
columns: List[str],
schema: Optional[str] = None,
-) -> Optional["Table"]:
+) -> Optional[Table]:
"""Issue a "create primary key" instruction using the current
migration context.
"""
-def create_table(table_name: str, *columns, **kw) -> Optional["Table"]:
+def create_table(table_name: str, *columns, **kw) -> Optional[Table]:
"""Issue a "create table" instruction using the current migration
context.
comment: Optional[str],
existing_comment: None = None,
schema: Optional[str] = None,
-) -> Optional["Table"]:
+) -> Optional[Table]:
"""Emit a COMMENT ON operation to set the comment for a table.
.. versionadded:: 1.0.6
def drop_column(
table_name: str, column_name: str, schema: Optional[str] = None, **kw
-) -> Optional["Table"]:
+) -> Optional[Table]:
"""Issue a "drop column" instruction using the current
migration context.
table_name: str,
type_: Optional[str] = None,
schema: Optional[str] = None,
-) -> Optional["Table"]:
+) -> Optional[Table]:
"""Drop a constraint of the given name, typically via DROP CONSTRAINT.
:param constraint_name: name of the constraint.
table_name: Optional[str] = None,
schema: Optional[str] = None,
**kw
-) -> Optional["Table"]:
+) -> Optional[Table]:
"""Issue a "drop index" instruction using the current
migration context.
table_name: str,
existing_comment: Optional[str] = None,
schema: Optional[str] = None,
-) -> Optional["Table"]:
+) -> Optional[Table]:
"""Issue a "drop table comment" operation to
remove an existing comment set on a table.
"""
def execute(
- sqltext: Union[str, "TextClause", "Update"], execution_options: None = None
-) -> Optional["Table"]:
+ sqltext: Union[str, TextClause, Update], execution_options: None = None
+) -> Optional[Table]:
"""Execute the given SQL using the current migration context.
The given SQL can be a plain string, e.g.::
:meth:`sqlalchemy.engine.Connection.execution_options`.
"""
-def f(name: str) -> "conv":
+def f(name: str) -> conv:
"""Indicate a string name that has already had a naming convention
applied to it.
"""
-def get_bind() -> "Connection":
+def get_bind() -> Connection:
"""Return the current 'bind'.
Under normal circumstances, this is the
"""
-def invoke(operation: "MigrateOperation") -> Any:
+def invoke(operation: MigrateOperation) -> Any:
"""Given a :class:`.MigrateOperation`, invoke it in terms of
this :class:`.Operations` instance.
def rename_table(
old_table_name: str, new_table_name: str, schema: Optional[str] = None
-) -> Optional["Table"]:
+) -> Optional[Table]:
"""Emit an ALTER TABLE to rename a table.
:param old_table_name: old name.
+from __future__ import annotations
+
from contextlib import contextmanager
import re
import textwrap
+from __future__ import annotations
+
from typing import Any
-from typing import cast
from typing import Dict
from typing import List
from typing import Optional
schema=self.schema,
autoload_with=self.operations.get_bind(),
*self.reflect_args,
- **self.reflect_kwargs
+ **self.reflect_kwargs,
)
reflected = True
m,
*(list(self.columns.values()) + list(self.table_args)),
schema=schema,
- **self.table_kwargs
+ **self.table_kwargs,
)
for const in (
index.name,
unique=index.unique,
*[self.new_table.c[col] for col in index.columns.keys()],
- **index.kwargs
+ **index.kwargs,
)
)
return idx
for elem in constraint.elements
]
],
- schema=referent_schema
+ schema=referent_schema,
)
def _create(self, op_impl: "DefaultImpl") -> None:
type_: Optional["TypeEngine"] = None,
autoincrement: None = None,
comment: Union[str, "Literal[False]"] = False,
- **kw
+ **kw,
) -> None:
existing = self.columns[column_name]
existing_transfer: Dict[str, Any] = self.column_transfers[column_name]
column: "Column",
insert_before: Optional[str] = None,
insert_after: Optional[str] = None,
- **kw
+ **kw,
) -> None:
self._setup_dependencies_for_add_column(
column.name, insert_before, insert_after
if col_const.name == const.name:
self.columns[col.name].constraints.remove(col_const)
else:
- const = self.named_constraints.pop(cast(str, const.name))
+ assert const.name
+ const = self.named_constraints.pop(const.name)
except KeyError:
if _is_type_bound(const):
# type-bound constraints are only included in the new
+from __future__ import annotations
+
from abc import abstractmethod
import re
from typing import Any
table_name: str,
columns: Sequence[str],
schema: Optional[str] = None,
- **kw
+ **kw,
) -> None:
self.constraint_name = constraint_name
self.table_name = table_name
table_name: str,
columns: Sequence[str],
schema: Optional[str] = None,
- **kw
+ **kw,
) -> None:
self.constraint_name = constraint_name
self.table_name = table_name
table_name: str,
columns: Sequence[str],
schema: Optional[str] = None,
- **kw
+ **kw,
) -> Any:
"""Issue a "create unique constraint" instruction using the
current migration context.
operations: "BatchOperations",
constraint_name: str,
columns: Sequence[str],
- **kw
+ **kw,
) -> Any:
"""Issue a "create unique constraint" instruction using the
current batch migration context.
referent_table: str,
local_cols: List[str],
remote_cols: List[str],
- **kw
+ **kw,
) -> None:
self.constraint_name = constraint_name
self.source_table = source_table
match: Optional[str] = None,
source_schema: Optional[str] = None,
referent_schema: Optional[str] = None,
- **dialect_kw
+ **dialect_kw,
) -> Optional["Table"]:
"""Issue a "create foreign key" instruction using the
current migration context.
deferrable: None = None,
initially: None = None,
match: None = None,
- **dialect_kw
+ **dialect_kw,
) -> None:
"""Issue a "create foreign key" instruction using the
current batch migration context.
table_name: str,
condition: Union[str, "TextClause", "ColumnElement[Any]"],
schema: Optional[str] = None,
- **kw
+ **kw,
) -> None:
self.constraint_name = constraint_name
self.table_name = table_name
return cls(
ck_constraint.name,
constraint_table.name,
- cast(
- "Union[TextClause, ColumnElement[Any]]", ck_constraint.sqltext
- ),
+ cast("ColumnElement[Any]", ck_constraint.sqltext),
schema=constraint_table.schema,
**ck_constraint.dialect_kwargs,
)
table_name: str,
condition: Union[str, "BinaryExpression"],
schema: Optional[str] = None,
- **kw
+ **kw,
) -> Optional["Table"]:
"""Issue a "create check constraint" instruction using the
current migration context.
operations: "BatchOperations",
constraint_name: str,
condition: "TextClause",
- **kw
+ **kw,
) -> Optional["Table"]:
"""Issue a "create check constraint" instruction using the
current batch migration context.
columns: Sequence[Union[str, "TextClause", "ColumnElement[Any]"]],
schema: Optional[str] = None,
unique: bool = False,
- **kw
+ **kw,
) -> None:
self.index_name = index_name
self.table_name = table_name
columns: Sequence[Union[str, "TextClause", "Function"]],
schema: Optional[str] = None,
unique: bool = False,
- **kw
+ **kw,
) -> Optional["Table"]:
r"""Issue a "create index" instruction using the current
migration context.
operations: "BatchOperations",
index_name: str,
columns: List[str],
- **kw
+ **kw,
) -> Optional["Table"]:
"""Issue a "create index" instruction using the
current batch migration context.
table_name: Optional[str] = None,
schema: Optional[str] = None,
_reverse: Optional["CreateIndexOp"] = None,
- **kw
+ **kw,
) -> None:
self.index_name = index_name
self.table_name = table_name
index_name: str,
table_name: Optional[str] = None,
schema: Optional[str] = None,
- **kw
+ **kw,
) -> Optional["Table"]:
r"""Issue a "drop index" instruction using the current
migration context.
schema: Optional[str] = None,
_namespace_metadata: Optional["MetaData"] = None,
_constraints_included: bool = False,
- **kw
+ **kw,
) -> None:
self.table_name = table_name
self.columns = columns
operations: "Operations",
table_name: str,
schema: Optional[str] = None,
- **kw: Any
+ **kw: Any,
) -> None:
r"""Issue a "drop table" instruction using the current
migration context.
modify_server_default: Any = False,
modify_name: Optional[str] = None,
modify_type: Optional[Any] = None,
- **kw
+ **kw,
) -> None:
super(AlterColumnOp, self).__init__(table_name, schema=schema)
self.column_name = column_name
existing_nullable: Optional[bool] = None,
existing_comment: Optional[str] = None,
schema: Optional[str] = None,
- **kw
+ **kw,
) -> Optional["Table"]:
r"""Issue an "alter column" instruction using the
current migration context.
existing_comment: None = None,
insert_before: None = None,
insert_after: None = None,
- **kw
+ **kw,
) -> Optional["Table"]:
"""Issue an "alter column" instruction using the current
batch migration context.
table_name: str,
column: "Column",
schema: Optional[str] = None,
- **kw
+ **kw,
) -> None:
super(AddColumnOp, self).__init__(table_name, schema=schema)
self.column = column
column_name: str,
schema: Optional[str] = None,
_reverse: Optional["AddColumnOp"] = None,
- **kw
+ **kw,
) -> None:
super(DropColumnOp, self).__init__(table_name, schema=schema)
self.column_name = column_name
table_name: str,
column_name: str,
schema: Optional[str] = None,
- **kw
+ **kw,
) -> Optional["Table"]:
"""Issue a "drop column" instruction using the current
migration context.
+from __future__ import annotations
+
from typing import Any
from typing import Dict
from typing import List
table_name: str,
cols: Sequence[str],
schema: Optional[str] = None,
- **dialect_kw
+ **dialect_kw,
) -> "PrimaryKeyConstraint":
m = self.metadata()
columns = [sa_schema.Column(n, NULLTYPE) for n in cols]
referent_schema: Optional[str] = None,
initially: Optional[str] = None,
match: Optional[str] = None,
- **dialect_kw
+ **dialect_kw,
) -> "ForeignKeyConstraint":
m = self.metadata()
if source == referent and source_schema == referent_schema:
referent,
m,
*[sa_schema.Column(n, NULLTYPE) for n in remote_cols],
- schema=referent_schema
+ schema=referent_schema,
)
t1 = sa_schema.Table(
source,
m,
*[sa_schema.Column(n, NULLTYPE) for n in t1_cols],
- schema=source_schema
+ schema=source_schema,
)
tname = (
ondelete=ondelete,
deferrable=deferrable,
initially=initially,
- **dialect_kw
+ **dialect_kw,
)
t1.append_constraint(f)
source: str,
local_cols: Sequence[str],
schema: Optional[str] = None,
- **kw
+ **kw,
) -> "UniqueConstraint":
t = sa_schema.Table(
source,
self.metadata(),
*[sa_schema.Column(n, NULLTYPE) for n in local_cols],
- schema=schema
+ schema=schema,
)
kw["name"] = name
uq = sa_schema.UniqueConstraint(*[t.c[n] for n in local_cols], **kw)
source: str,
condition: Union[str, "TextClause", "ColumnElement[Any]"],
schema: Optional[str] = None,
- **kw
+ **kw,
) -> Union["CheckConstraint"]:
t = sa_schema.Table(
source,
table_name: str,
type_: Optional[str],
schema: Optional[str] = None,
- **kw
+ **kw,
) -> Any:
t = self.table(table_name, schema=schema)
types: Dict[Optional[str], Any] = {
tablename: Optional[str],
columns: Sequence[Union[str, "TextClause", "ColumnElement[Any]"]],
schema: Optional[str] = None,
- **kw
+ **kw,
) -> "Index":
t = sa_schema.Table(
tablename or "no_table",
idx = sa_schema.Index(
name,
*[util.sqla_compat._textual_index_column(t, n) for n in columns],
- **kw
+ **kw,
)
return idx
+from __future__ import annotations
+
from typing import Callable
from typing import ContextManager
from typing import Dict
sqlalchemy_module_prefix: str = "sa.",
user_module_prefix: Optional[str] = None,
on_version_apply: Optional[Callable] = None,
- **kw
+ **kw,
) -> None:
"""Configure a :class:`.MigrationContext` within this
:class:`.EnvironmentContext` which will provide database
+from __future__ import annotations
+
from contextlib import contextmanager
import logging
import sys
from ..config import Config
from ..script.base import Script
from ..script.base import ScriptDirectory
+ from ..script.revision import _RevisionOrBase
from ..script.revision import Revision
from ..script.revision import RevisionMap
elif start_from_rev is not None and self.script:
start_from_rev = [
- self.script.get_revision(sfr).revision
+ cast("Script", self.script.get_revision(sfr)).revision
for sfr in util.to_list(start_from_rev)
if sfr not in (None, "base")
]
"""
- is_upgrade: bool = None # type:ignore[assignment]
+ is_upgrade: bool
"""True/False: indicates whether this operation ascends or descends the
version tree."""
- is_stamp: bool = None # type:ignore[assignment]
+ is_stamp: bool
"""True/False: indicates whether this operation is a stamp (i.e. whether
it results in any actual database operations)."""
- up_revision_id: Optional[str] = None
+ up_revision_id: Optional[str]
"""Version string corresponding to :attr:`.Revision.revision`.
In the case of a stamp operation, it is advised to use the
"""
- up_revision_ids: Tuple[str, ...] = None # type:ignore[assignment]
+ up_revision_ids: Tuple[str, ...]
"""Tuple of version strings corresponding to :attr:`.Revision.revision`.
- In the majority of cases, this tuple will be a single value, synonomous
+ In the majority of cases, this tuple will be a single value, synonymous
with the scalar value of :attr:`.MigrationInfo.up_revision_id`.
It can be multiple revision identifiers only in the case of an
``alembic stamp`` operation which is moving downwards from multiple
"""
- down_revision_ids: Tuple[str, ...] = None # type:ignore[assignment]
+ down_revision_ids: Tuple[str, ...]
"""Tuple of strings representing the base revisions of this migration step.
If empty, this represents a root revision; otherwise, the first item
from dependencies.
"""
- revision_map: "RevisionMap" = None # type:ignore[assignment]
+ revision_map: "RevisionMap"
"""The revision map inside of which this operation occurs."""
def __init__(
)
@property
- def up_revision(self) -> "Revision":
+ def up_revision(self) -> Optional[Revision]:
"""Get :attr:`~.MigrationInfo.up_revision_id` as
a :class:`.Revision`.
return self.revision_map.get_revision(self.up_revision_id)
@property
- def up_revisions(self) -> Tuple["Revision", ...]:
+ def up_revisions(self) -> Tuple[Optional[_RevisionOrBase], ...]:
"""Get :attr:`~.MigrationInfo.up_revision_ids` as a
:class:`.Revision`."""
return self.revision_map.get_revisions(self.up_revision_ids)
@property
- def down_revisions(self) -> Tuple["Revision", ...]:
+ def down_revisions(self) -> Tuple[Optional[_RevisionOrBase], ...]:
"""Get :attr:`~.MigrationInfo.down_revision_ids` as a tuple of
:class:`Revisions <.Revision>`."""
return self.revision_map.get_revisions(self.down_revision_ids)
@property
- def source_revisions(self) -> Tuple["Revision", ...]:
+ def source_revisions(self) -> Tuple[Optional[_RevisionOrBase], ...]:
"""Get :attr:`~MigrationInfo.source_revision_ids` as a tuple of
:class:`Revisions <.Revision>`."""
return self.revision_map.get_revisions(self.source_revision_ids)
@property
- def destination_revisions(self) -> Tuple["Revision", ...]:
+ def destination_revisions(self) -> Tuple[Optional[_RevisionOrBase], ...]:
"""Get :attr:`~MigrationInfo.destination_revision_ids` as a tuple of
:class:`Revisions <.Revision>`."""
return self.revision_map.get_revisions(self.destination_revision_ids)
)
@property
- def doc(self):
+ def doc(self) -> str:
return self.revision.doc
@property
self.migration_fn = self.stamp_revision
self.revision_map = revision_map
- doc = None
+ doc: None = None
def stamp_revision(self, **kw) -> None:
return None
+from __future__ import annotations
+
from contextlib import contextmanager
import datetime
import os
from . import write_hooks
from .. import util
from ..runtime import migration
+from ..util import not_none
if TYPE_CHECKING:
from ..config import Config
from ..runtime.migration import RevisionStep
from ..runtime.migration import StampStep
+ from ..script.revision import Revision
try:
from dateutil import tz
else:
return (os.path.abspath(os.path.join(self.dir, "versions")),)
- def _load_revisions(self) -> Iterator["Script"]:
+ def _load_revisions(self) -> Iterator[Script]:
if self.version_locations:
paths = [
vers
yield script
@classmethod
- def from_config(cls, config: "Config") -> "ScriptDirectory":
+ def from_config(cls, config: Config) -> ScriptDirectory:
"""Produce a new :class:`.ScriptDirectory` given a :class:`.Config`
instance.
raise util.CommandError(
"No 'script_location' key " "found in configuration."
)
- truncate_slug_length = cast(
- Optional[int], config.get_main_option("truncate_slug_length")
- )
- if truncate_slug_length is not None:
- truncate_slug_length = int(truncate_slug_length)
+ truncate_slug_length: Optional[int]
+ tsl = config.get_main_option("truncate_slug_length")
+ if tsl is not None:
+ truncate_slug_length = int(tsl)
+ else:
+ truncate_slug_length = None
- version_locations = config.get_main_option("version_locations")
- if version_locations:
+ version_locations_str = config.get_main_option("version_locations")
+ version_locations: Optional[List[str]]
+ if version_locations_str:
version_path_separator = config.get_main_option(
"version_path_separator"
)
}
try:
- split_char = split_on_path[version_path_separator]
+ split_char: Optional[str] = split_on_path[
+ version_path_separator
+ ]
except KeyError as ke:
raise ValueError(
"'%s' is not a valid value for "
else:
if split_char is None:
# legacy behaviour for backwards compatibility
- vl = _split_on_space_comma.split(
- cast(str, version_locations)
+ version_locations = _split_on_space_comma.split(
+ version_locations_str
)
- version_locations: List[str] = vl # type: ignore[no-redef]
else:
- vl = [
- x
- for x in cast(str, version_locations).split(split_char)
- if x
+ version_locations = [
+ x for x in version_locations_str.split(split_char) if x
]
- version_locations: List[str] = vl # type: ignore[no-redef]
+ else:
+ version_locations = None
prepend_sys_path = config.get_main_option("prepend_sys_path")
if prepend_sys_path:
truncate_slug_length=truncate_slug_length,
sourceless=config.get_main_option("sourceless") == "true",
output_encoding=config.get_main_option("output_encoding", "utf-8"),
- version_locations=cast("Optional[List[str]]", version_locations),
+ version_locations=version_locations,
timezone=config.get_main_option("timezone"),
hook_config=config.get_section("post_write_hooks", {}),
)
def walk_revisions(
self, base: str = "base", head: str = "heads"
- ) -> Iterator["Script"]:
+ ) -> Iterator[Script]:
"""Iterate through all revisions.
:param base: the base revision, or "base" to start from the
):
yield cast(Script, rev)
- def get_revisions(self, id_: _RevIdType) -> Tuple["Script", ...]:
+ def get_revisions(self, id_: _RevIdType) -> Tuple[Optional[Script], ...]:
"""Return the :class:`.Script` instance with the given rev identifier,
symbolic name, or sequence of identifiers.
"""
with self._catch_revision_errors():
return cast(
- "Tuple[Script, ...]", self.revision_map.get_revisions(id_)
+ Tuple[Optional[Script], ...],
+ self.revision_map.get_revisions(id_),
)
- def get_all_current(self, id_: Tuple[str, ...]) -> Set["Script"]:
+ def get_all_current(self, id_: Tuple[str, ...]) -> Set[Optional[Script]]:
with self._catch_revision_errors():
top_revs = cast(
- "Set[Script]",
+ Set[Optional[Script]],
set(self.revision_map.get_revisions(id_)),
)
top_revs.update(
cast(
- "Iterator[Script]",
+ Iterator[Script],
self.revision_map._get_ancestor_nodes(
list(top_revs), include_dependencies=True
),
top_revs = self.revision_map._filter_into_branch_heads(top_revs)
return top_revs
- def get_revision(self, id_: str) -> "Script":
+ def get_revision(self, id_: str) -> Optional[Script]:
"""Return the :class:`.Script` instance with the given rev id.
.. seealso::
"""
with self._catch_revision_errors():
- return cast(Script, self.revision_map.get_revision(id_))
+ return cast(Optional[Script], self.revision_map.get_revision(id_))
def as_revision_number(
self, id_: Optional[str]
else:
return rev[0]
- def iterate_revisions(self, upper, lower):
+ def iterate_revisions(
+ self,
+ upper: Union[str, Tuple[str, ...], None],
+ lower: Union[str, Tuple[str, ...], None],
+ **kw: Any,
+ ) -> Iterator[Script]:
"""Iterate through script revisions, starting at the given
upper revision identifier and ending at the lower.
:meth:`.RevisionMap.iterate_revisions`
"""
- return self.revision_map.iterate_revisions(upper, lower)
+ return cast(
+ Iterator[Script],
+ self.revision_map.iterate_revisions(upper, lower, **kw),
+ )
- def get_current_head(self):
+ def get_current_head(self) -> Optional[str]:
"""Return the current head revision.
If the script directory has multiple heads
def _upgrade_revs(
self, destination: str, current_rev: str
- ) -> List["RevisionStep"]:
+ ) -> List[RevisionStep]:
with self._catch_revision_errors(
ancestor="Destination %(end)s is not a valid upgrade "
"target from current head(s)",
end=destination,
):
- revs = self.revision_map.iterate_revisions(
+ revs = self.iterate_revisions(
destination, current_rev, implicit_base=True
)
return [
migration.MigrationStep.upgrade_from_script(
- self.revision_map, cast(Script, script)
+ self.revision_map, script
)
for script in reversed(list(revs))
]
def _downgrade_revs(
self, destination: str, current_rev: Optional[str]
- ) -> List["RevisionStep"]:
+ ) -> List[RevisionStep]:
with self._catch_revision_errors(
ancestor="Destination %(end)s is not a valid downgrade "
"target from current head(s)",
end=destination,
):
- revs = self.revision_map.iterate_revisions(
+ revs = self.iterate_revisions(
current_rev, destination, select_for_downgrade=True
)
return [
migration.MigrationStep.downgrade_from_script(
- self.revision_map, cast(Script, script)
+ self.revision_map, script
)
for script in revs
]
if not revision:
revision = "base"
- filtered_heads: List["Script"] = []
+ filtered_heads: List[Script] = []
for rev in util.to_tuple(revision):
if rev:
filtered_heads.extend(
self.revision_map.filter_for_lineage(
- heads_revs, rev, include_dependencies=True
+ cast(Sequence[Script], heads_revs),
+ rev,
+ include_dependencies=True,
)
)
filtered_heads = util.unique_list(filtered_heads)
src,
dest,
self.output_encoding,
- **kw
+ **kw,
)
def _copy_file(self, src: str, dest: str) -> None:
branch_labels: Optional[str] = None,
version_path: Optional[str] = None,
depends_on: Optional[_RevIdType] = None,
- **kw: Any
- ) -> Optional["Script"]:
+ **kw: Any,
+ ) -> Optional[Script]:
"""Generate a new revision file.
This runs the ``script.py.mako`` template, given
"or perform a merge."
)
):
- heads = self.revision_map.get_revisions(head)
+ heads = cast(
+ Tuple[Optional["Revision"], ...],
+ self.revision_map.get_revisions(head),
+ )
+ for h in heads:
+ assert h != "base"
if len(set(heads)) != len(heads):
raise util.CommandError("Duplicate head revisions specified")
% head_.revision
)
+ resolved_depends_on: Optional[List[str]]
if depends_on:
with self._catch_revision_errors():
- depends_on = [
+ resolved_depends_on = [
dep
if dep in rev.branch_labels # maintain branch labels
else rev.revision # resolve partial revision identifiers
for rev, dep in [
- (self.revision_map.get_revision(dep), dep)
+ (not_none(self.revision_map.get_revision(dep)), dep)
for dep in util.to_list(depends_on)
]
]
+ else:
+ resolved_depends_on = None
self._generate_template(
os.path.join(self.dir, "script.py.mako"),
tuple(h.revision if h is not None else None for h in heads)
),
branch_labels=util.to_tuple(branch_labels),
- depends_on=revision.tuple_rev_as_scalar(
- cast("Optional[List[str]]", depends_on)
- ),
+ depends_on=revision.tuple_rev_as_scalar(resolved_depends_on),
create_date=create_date,
comma=util.format_as_comma,
message=message if message is not None else ("empty message"),
- **kw
+ **kw,
)
post_write_hooks = self.hook_config
),
)
- module: ModuleType = None # type: ignore[assignment]
+ module: ModuleType
"""The Python module representing the actual script itself."""
- path: str = None # type: ignore[assignment]
+ path: str
"""Filesystem path of the script."""
- _db_current_indicator = None
+ _db_current_indicator: Optional[bool] = None
"""Utility variable which when set will cause string output to indicate
this is a "current" version in some database"""
@classmethod
def _from_path(
cls, scriptdir: ScriptDirectory, path: str
- ) -> Optional["Script"]:
+ ) -> Optional[Script]:
dir_, filename = os.path.split(path)
return cls._from_filename(scriptdir, dir_, filename)
@classmethod
def _from_filename(
cls, scriptdir: ScriptDirectory, dir_: str, filename: str
- ) -> Optional["Script"]:
+ ) -> Optional[Script]:
if scriptdir.sourceless:
py_match = _sourceless_rev_file.match(filename)
else:
+from __future__ import annotations
+
import collections
import re
from typing import Any
from typing import Iterator
from typing import List
from typing import Optional
+from typing import overload
from typing import Sequence
from typing import Set
from typing import Tuple
from sqlalchemy import util as sqlautil
from .. import util
+from ..util import not_none
if TYPE_CHECKING:
from typing import Literal
"Revision %s referenced from %s is not present"
% (downrev, revision)
)
- cast("Revision", map_[downrev]).add_nextrev(revision)
+ not_none(map_[downrev]).add_nextrev(revision)
self._normalize_depends_on(revisions, map_)
return self.filter_for_lineage(self.bases, identifier)
def get_revisions(
- self, id_: Union[str, Collection[str], None]
- ) -> Tuple["Revision", ...]:
+ self, id_: Union[str, Collection[Optional[str]], None]
+ ) -> Tuple[Optional[_RevisionOrBase], ...]:
"""Return the :class:`.Revision` instances with the given rev id
or identifiers.
select_heads = tuple(
head
for head in select_heads
- if branch_label in head.branch_labels
+ if branch_label
+ in is_revision(head).branch_labels
)
return tuple(
self._walk(head, steps=rint)
for rev_id in resolved_id
)
- def get_revision(self, id_: Optional[str]) -> "Revision":
+ def get_revision(self, id_: Optional[str]) -> Optional[Revision]:
"""Return the :class:`.Revision` instance with the given rev id.
If a symbolic name such as "head" or "base" is given, resolves
resolved_id, branch_label = self._resolve_revision_number(id_)
if len(resolved_id) > 1:
raise MultipleHeads(resolved_id, id_)
- elif resolved_id:
- resolved_id = resolved_id[0] # type:ignore[assignment]
- return self._revision_for_ident(cast(str, resolved_id), branch_label)
+ resolved: Union[str, Tuple[()]] = resolved_id[0] if resolved_id else ()
+ return self._revision_for_ident(resolved, branch_label)
- def _resolve_branch(self, branch_label: str) -> "Revision":
+ def _resolve_branch(self, branch_label: str) -> Optional[Revision]:
try:
branch_rev = self._revision_map[branch_label]
except KeyError:
else:
return nonbranch_rev
else:
- return cast("Revision", branch_rev)
+ return branch_rev
def _revision_for_ident(
- self, resolved_id: str, check_branch: Optional[str] = None
- ) -> "Revision":
- branch_rev: Optional["Revision"]
+ self,
+ resolved_id: Union[str, Tuple[()]],
+ check_branch: Optional[str] = None,
+ ) -> Optional[Revision]:
+ branch_rev: Optional[Revision]
if check_branch:
branch_rev = self._resolve_branch(check_branch)
else:
branch_rev = None
- revision: Union["Revision", "Literal[False]"]
+ revision: Union[Optional[Revision], "Literal[False]"]
try:
- revision = cast("Revision", self._revision_map[resolved_id])
+ revision = self._revision_map[resolved_id]
except KeyError:
# break out to avoid misleading py3k stack traces
revision = False
revs: Sequence[str]
if revision is False:
+ assert resolved_id
# do a partial lookup
revs = [
x
resolved_id,
)
else:
- revision = cast("Revision", self._revision_map[revs[0]])
+ revision = self._revision_map[revs[0]]
- revision = cast("Revision", revision)
if check_branch and revision is not None:
assert branch_rev is not None
+ assert resolved_id
if not self._shares_lineage(
revision.revision, branch_rev.revision
):
return revision
def _filter_into_branch_heads(
- self, targets: Set["Script"]
- ) -> Set["Script"]:
+ self, targets: Set[Optional[Script]]
+ ) -> Set[Optional[Script]]:
targets = set(targets)
for rev in list(targets):
+ assert rev
if targets.intersection(
self._get_descendant_nodes([rev], include_dependencies=False)
).difference([rev]):
if not test_against_revs:
return True
if not isinstance(target, Revision):
- target = self._revision_for_ident(target)
+ resolved_target = not_none(self._revision_for_ident(target))
+ else:
+ resolved_target = target
- test_against_revs = [
+ resolved_test_against_revs = [
self._revision_for_ident(test_against_rev)
if not isinstance(test_against_rev, Revision)
else test_against_rev
return bool(
set(
self._get_descendant_nodes(
- [target], include_dependencies=include_dependencies
+ [resolved_target],
+ include_dependencies=include_dependencies,
)
)
.union(
self._get_ancestor_nodes(
- [target], include_dependencies=include_dependencies
+ [resolved_target],
+ include_dependencies=include_dependencies,
)
)
- .intersection(test_against_revs)
+ .intersection(resolved_test_against_revs)
)
def _resolve_revision_number(
inclusive: bool = False,
assert_relative_length: bool = True,
select_for_downgrade: bool = False,
- ) -> Iterator["Revision"]:
+ ) -> Iterator[Revision]:
"""Iterate through script revisions, starting at the given
upper revision identifier and ending at the lower.
)
for node in self._topological_sort(revisions, heads):
- yield self.get_revision(node)
+ yield not_none(self.get_revision(node))
def _get_descendant_nodes(
self,
- targets: Collection["Revision"],
+ targets: Collection[Revision],
map_: Optional[_RevisionMapType] = None,
check: bool = False,
omit_immediate_dependencies: bool = False,
def _get_ancestor_nodes(
self,
- targets: Collection["Revision"],
+ targets: Collection[Optional[_RevisionOrBase]],
map_: Optional[_RevisionMapType] = None,
check: bool = False,
include_dependencies: bool = True,
- ) -> Iterator["Revision"]:
+ ) -> Iterator[Revision]:
if include_dependencies:
def _iterate_related_revisions(
self,
fn: Callable,
- targets: Collection["Revision"],
+ targets: Collection[Optional[_RevisionOrBase]],
map_: Optional[_RevisionMapType],
check: bool = False,
- ) -> Iterator["Revision"]:
+ ) -> Iterator[Revision]:
if map_ is None:
map_ = self._revision_map
seen = set()
- todo: Deque["Revision"] = collections.deque()
- for target in targets:
-
+ todo: Deque[Revision] = collections.deque()
+ for target_for in targets:
+ target = is_revision(target_for)
todo.append(target)
if check:
per_target = set()
def _topological_sort(
self,
- revisions: Collection["Revision"],
+ revisions: Collection[Revision],
heads: Any,
) -> List[str]:
"""Yield revision ids of a collection of Revision objects in
def _walk(
self,
- start: Optional[Union[str, "Revision"]],
+ start: Optional[Union[str, Revision]],
steps: int,
branch_label: Optional[str] = None,
no_overwalk: bool = True,
- ) -> "Revision":
+ ) -> Optional[_RevisionOrBase]:
"""
Walk the requested number of :steps up (steps > 0) or down (steps < 0)
the revision tree.
else:
initial = start
- children: Sequence[_RevisionOrBase]
+ children: Sequence[Optional[_RevisionOrBase]]
for _ in range(abs(steps)):
if steps > 0:
+ assert initial != "base"
# Walk up
- children = [
- rev
+ walk_up = [
+ is_revision(rev)
for rev in self.get_revisions(
- self.bases
- if initial is None
- else cast("Revision", initial).nextrev
+ self.bases if initial is None else initial.nextrev
)
]
if branch_label:
- children = self.filter_for_lineage(children, branch_label)
+ children = self.filter_for_lineage(walk_up, branch_label)
+ else:
+ children = walk_up
else:
# Walk down
if initial == "base":
else initial.down_revision
)
if not children:
- children = cast("Tuple[Literal['base']]", ("base",))
+ children = ("base",)
if not children:
# This will return an invalid result if no_overwalk, otherwise
# further steps will stay where we are.
ret = None if no_overwalk else initial
- return ret # type:ignore[return-value]
+ return ret
elif len(children) > 1:
raise RevisionError("Ambiguous walk")
initial = children[0]
- return cast("Revision", initial)
+ return initial
def _parse_downgrade_target(
self,
current_revisions: _RevisionIdentifierType,
target: _RevisionIdentifierType,
assert_relative_length: bool,
- ) -> Tuple["Revision", ...]:
+ ) -> Tuple[Optional[_RevisionOrBase], ...]:
"""
Parse upgrade command syntax :target to retrieve the target revision
and given the :current_revisons stamp of the database.
# No relative destination, target is absolute.
return self.get_revisions(target)
- current_revisions = util.to_tuple(current_revisions)
+ current_revisions_tup: Union[str, Collection[Optional[str]], None]
+ current_revisions_tup = util.to_tuple(current_revisions)
branch_label, symbol, relative_str = match.groups()
relative = int(relative_str)
if relative > 0:
if symbol is None:
- if not current_revisions:
- current_revisions = (None,)
+ if not current_revisions_tup:
+ current_revisions_tup = (None,)
# Try to filter to a single target (avoid ambiguous branches).
- start_revs = current_revisions
+ start_revs = current_revisions_tup
if branch_label:
start_revs = self.filter_for_lineage(
- self.get_revisions(current_revisions), branch_label
+ self.get_revisions(current_revisions_tup), branch_label
)
if not start_revs:
# The requested branch is not a head, so we need to
# backtrack to find a branchpoint.
active_on_branch = self.filter_for_lineage(
self._get_ancestor_nodes(
- self.get_revisions(current_revisions)
+ self.get_revisions(current_revisions_tup)
),
branch_label,
)
target_revision = None
assert target_revision is None or isinstance(target_revision, Revision)
+ roots: List[Revision]
# Find candidates to drop.
if target_revision is None:
# Downgrading back to base: find all tree roots.
roots = [target_revision]
else:
# Downgrading to fixed target: find all direct children.
- roots = list(self.get_revisions(target_revision.nextrev))
+ roots = [
+ is_revision(rev)
+ for rev in self.get_revisions(target_revision.nextrev)
+ ]
if branch_label and len(roots) > 1:
# Need to filter roots.
}
# Intersection gives the root revisions we are trying to
# rollback with the downgrade.
- roots = list(
- self.get_revisions(
+ roots = [
+ is_revision(rev)
+ for rev in self.get_revisions(
{rev.revision for rev in roots}.intersection(ancestors)
)
- )
+ ]
# Ensure we didn't throw everything away when filtering branches.
if len(roots) == 0:
inclusive: bool,
implicit_base: bool,
assert_relative_length: bool,
- ) -> Tuple[Set["Revision"], Tuple[Optional[_RevisionOrBase]]]:
+ ) -> Tuple[Set[Revision], Tuple[Optional[_RevisionOrBase]]]:
"""
Compute the set of required revisions specified by :upper, and the
current set of active revisions specified by :lower. Find the
of the current/lower revisions. Dependencies from branches with
different bases will not be included.
"""
- targets: Collection["Revision"] = self._parse_upgrade_target(
- current_revisions=lower,
- target=upper,
- assert_relative_length=assert_relative_length,
- )
+ targets: Collection[Revision] = [
+ is_revision(rev)
+ for rev in self._parse_upgrade_target(
+ current_revisions=lower,
+ target=upper,
+ assert_relative_length=assert_relative_length,
+ )
+ ]
# assert type(targets) is tuple, "targets should be a tuple"
target=lower,
assert_relative_length=assert_relative_length,
)
+ assert rev
if rev == "base":
current_revisions = tuple()
lower = None
# Include the lower revision (=current_revisions?) in the iteration
if inclusive:
- needs.update(self.get_revisions(lower))
+ needs.update(is_revision(rev) for rev in self.get_revisions(lower))
# By default, base is implicit as we want all dependencies returned.
# Base is also implicit if lower = base
# implicit_base=False -> only return direct downstreams of
# current_revisions
if current_revisions and not implicit_base:
lower_descendents = self._get_descendant_nodes(
- current_revisions, check=True, include_dependencies=False
+ [is_revision(rev) for rev in current_revisions],
+ check=True,
+ include_dependencies=False,
)
needs.intersection_update(lower_descendents)
args.append("branch_labels=%r" % (self.branch_labels,))
return "%s(%s)" % (self.__class__.__name__, ", ".join(args))
- def add_nextrev(self, revision: "Revision") -> None:
+ def add_nextrev(self, revision: Revision) -> None:
self._all_nextrev = self._all_nextrev.union([revision.revision])
if self.revision in revision._versioned_down_revisions:
self.nextrev = self.nextrev.union([revision.revision])
return len(self._versioned_down_revisions) > 1
+@overload
def tuple_rev_as_scalar(
rev: Optional[Sequence[str]],
) -> Optional[Union[str, Sequence[str]]]:
+ ...
+
+
+@overload
+def tuple_rev_as_scalar(
+ rev: Optional[Sequence[Optional[str]]],
+) -> Optional[Union[Optional[str], Sequence[Optional[str]]]]:
+ ...
+
+
+def tuple_rev_as_scalar(rev):
if not rev:
return None
elif len(rev) == 1:
return rev[0]
else:
return rev
+
+
+def is_revision(rev: Any) -> Revision:
+ assert isinstance(rev, Revision)
+ return rev
+from __future__ import annotations
+
import shlex
import subprocess
import sys
REVISION_SCRIPT_TOKEN = "REVISION_SCRIPT_FILENAME"
-_registry = {}
+_registry: dict = {}
def register(name: str) -> Callable:
+from __future__ import annotations
+
import contextlib
import re
import sys
class _ErrorContainer:
- error = None
+ error: Any = None
@contextlib.contextmanager
# coding: utf-8
+from __future__ import annotations
+
import configparser
from contextlib import contextmanager
import io
+from __future__ import annotations
+
from typing import Any
from typing import Dict
+from typing import Set
from sqlalchemy import CHAR
from sqlalchemy import CheckConstraint
from ...testing.env import clear_staging_env
from ...testing.env import staging_env
-names_in_this_test = set()
+names_in_this_test: Set[Any] = set()
@event.listens_for(Table, "after_parent_attach")
return True
-_default_object_filters = _default_include_object
+_default_object_filters: Any = _default_include_object
-_default_name_filters = None
+_default_name_filters: Any = None
class ModelOne:
__requires__ = ("unique_constraint_reflection",)
- schema = None
+ schema: Any = None
@classmethod
def _get_db_schema(cls):
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
+from __future__ import annotations
import re
import types
for d in combinations
],
id_="i" + ("a" * len(keys)),
- argnames=",".join(keys)
+ argnames=",".join(keys),
)
from .langhelpers import immutabledict
from .langhelpers import memoized_property
from .langhelpers import ModuleClsProxy
+from .langhelpers import not_none
from .langhelpers import rev_id
from .langhelpers import to_list
from .langhelpers import to_tuple
+from __future__ import annotations
+
import io
import os
import sys
from importlib import metadata as importlib_metadata
from importlib.metadata import EntryPoint
else:
- import importlib_resources # type:ignore[no-redef] # noqa
- import importlib_metadata # type:ignore[no-redef] # noqa
+ import importlib_resources # type:ignore # noqa
+ import importlib_metadata # type:ignore # noqa
from importlib_metadata import EntryPoint # type:ignore # noqa
+from __future__ import annotations
+
import os
from os.path import exists
from os.path import join
+from __future__ import annotations
+
import collections
from collections.abc import Iterable
import textwrap
else:
d._registry.update(self._registry)
return d
+
+
+def not_none(value: Optional[_T]) -> _T:
+ assert value is not None
+ return value
+from __future__ import annotations
+
from collections.abc import Iterable
import logging
import sys
+from __future__ import annotations
+
import atexit
from contextlib import ExitStack
import importlib
+from __future__ import annotations
+
import contextlib
import re
from typing import Iterator
sqla_13 = _vers >= (1, 3)
sqla_14 = _vers >= (1, 4)
sqla_14_26 = _vers >= (1, 4, 26)
+sqla_2 = _vers >= (1, 5)
if sqla_14:
[tool.black]
line-length = 79
+
+[tool.mypy]
+
+exclude = [
+ 'alembic/template',
+ 'alembic.testing.*',
+]
+
+[[tool.mypy.overrides]]
+module = [
+ 'mako.*',
+ 'sqlalchemy.testing.*'
+]
+ignore_missing_imports = true
version = exclusions.only_if(
lambda _: compat.py39, "python 3.9 is required"
)
- return imports + version
+
+ sqlalchemy = exclusions.only_if(
+ lambda _: sqla_compat.sqla_2, "sqlalchemy 2 is required"
+ )
+
+ return imports + version + sqlalchemy
import sys
from tempfile import NamedTemporaryFile
import textwrap
+import typing
from mako.pygen import PythonPrinter
from alembic.script.write_hooks import console_scripts
from alembic.util.compat import inspect_formatargspec
from alembic.util.compat import inspect_getfullargspec
+ from alembic.operations import ops
+ import sqlalchemy as sa
IGNORE_ITEMS = {
"op": {"context", "create_module_class_proxy"},
"requires_connection",
},
}
+TRIM_MODULE = [
+ "alembic.runtime.migration.",
+ "alembic.operations.ops.",
+ "sqlalchemy.engine.base.",
+ "sqlalchemy.sql.schema.",
+ "sqlalchemy.sql.selectable.",
+ "sqlalchemy.sql.elements.",
+ "sqlalchemy.sql.type_api.",
+ "sqlalchemy.sql.functions.",
+ "sqlalchemy.sql.dml.",
+]
def generate_pyi_for_proxy(
printer.writeline("### end imports ###")
buf.write("\n\n")
+ module = sys.modules[cls.__module__]
+ env = {
+ **sa.__dict__,
+ **sa.types.__dict__,
+ **ops.__dict__,
+ **module.__dict__,
+ }
+
for name in dir(cls):
if name.startswith("_") or name in ignore_items:
continue
- meth = getattr(cls, name)
+ meth = getattr(cls, name, None)
if callable(meth):
- _generate_stub_for_meth(cls, name, printer)
+ _generate_stub_for_meth(cls, name, printer, env)
else:
- _generate_stub_for_attr(cls, name, printer)
+ _generate_stub_for_attr(cls, name, printer, env)
printer.close()
)
-def _generate_stub_for_attr(cls, name, printer):
- type_ = cls.__annotations__.get(name, "Any")
+def _generate_stub_for_attr(cls, name, printer, env):
+    try:
+        annotations = typing.get_type_hints(cls, env)
+    except NameError:
+        annotations = cls.__annotations__
+ type_ = annotations.get(name, "Any")
+ if isinstance(type_, str) and type_[0] in "'\"":
+ type_ = type_[1:-1]
printer.writeline(f"{name}: {type_}")
-def _generate_stub_for_meth(cls, name, printer):
+def _generate_stub_for_meth(cls, name, printer, env):
fn = getattr(cls, name)
while hasattr(fn, "__wrapped__"):
fn = fn.__wrapped__
spec = inspect_getfullargspec(fn)
+    try:
+        annotations = typing.get_type_hints(fn, env)
+        spec.annotations.update(annotations)
+    except NameError:
+        pass
name_args = spec[0]
assert name_args[0:1] == ["self"] or name_args[0:1] == ["cls"]
else:
retval = annotation.__module__ + "." + annotation.__qualname__
else:
- retval = repr(annotation)
+ retval = annotation
+
+ for trim in TRIM_MODULE:
+ retval = retval.replace(trim, "")
retval = re.sub(
r'ForwardRef\(([\'"].+?[\'"])\)', lambda m: m.group(1), retval
retval = re.sub("NoneType", "None", retval)
return retval
- argspec = inspect_formatargspec(*spec, formatannotation=_formatannotation)
+ argspec = inspect_formatargspec(
+ *spec,
+ formatannotation=_formatannotation,
+ formatreturns=lambda val: f"-> {_formatannotation(val)}",
+ )
func_text = textwrap.dedent(
"""\