from ..runtime.migration import MigrationContext
from ..script.base import Script
from ..script.base import ScriptDirectory
+ from ..script.revision import _GetRevArg
def compare_metadata(context: MigrationContext, metadata: MetaData) -> Any:
)
def run_autogenerate(
- self, rev: tuple, migration_context: MigrationContext
+ self, rev: _GetRevArg, migration_context: MigrationContext
) -> None:
self._run_environment(rev, migration_context, True)
def run_no_autogenerate(
- self, rev: tuple, migration_context: MigrationContext
+ self, rev: _GetRevArg, migration_context: MigrationContext
) -> None:
self._run_environment(rev, migration_context, False)
def _run_environment(
self,
- rev: tuple,
+ rev: _GetRevArg,
migration_context: MigrationContext,
autogenerate: bool,
) -> None:
if TYPE_CHECKING:
from alembic.config import Config
from alembic.script.base import Script
+ from alembic.script.revision import _RevIdType
from .runtime.environment import ProcessRevisionDirectiveFn
sql: bool = False,
head: str = "head",
splice: bool = False,
- branch_label: Optional[str] = None,
+ branch_label: Optional[_RevIdType] = None,
version_path: Optional[str] = None,
rev_id: Optional[str] = None,
depends_on: Optional[str] = None,
return scripts
-def check(
- config: "Config",
-) -> None:
+def check(config: "Config") -> None:
"""Check if revision command with autogenerate has pending upgrade ops.
:param config: a :class:`.Config` object.
def merge(
config: Config,
- revisions: str,
+ revisions: _RevIdType,
message: Optional[str] = None,
- branch_label: Optional[str] = None,
+ branch_label: Optional[_RevIdType] = None,
rev_id: Optional[str] = None,
) -> Optional[Script]:
"""Merge two revisions together. Creates a new migration file.
def stamp(
config: Config,
- revision: str,
+ revision: _RevIdType,
sql: bool = False,
tag: Optional[str] = None,
purge: bool = False,
from sqlalchemy.dialects.postgresql.hstore import HSTORE
from sqlalchemy.dialects.postgresql.json import JSON
from sqlalchemy.dialects.postgresql.json import JSONB
- from sqlalchemy.sql.elements import BinaryExpression
from sqlalchemy.sql.elements import ClauseElement
+ from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.sql.elements import quoted_name
from sqlalchemy.sql.schema import MetaData
from sqlalchemy.sql.schema import Table
Sequence[Tuple[str, str]],
Sequence[Tuple[ColumnClause[Any], str]],
],
- where: Optional[Union[BinaryExpression, str]] = None,
+ where: Optional[Union[ColumnElement[bool], str]] = None,
schema: Optional[str] = None,
_orig_constraint: Optional[ExcludeConstraint] = None,
**kw,
(expr, op)
for expr, name, op in constraint._render_exprs # type:ignore[attr-defined] # noqa
],
- where=cast(
- "Optional[Union[BinaryExpression, str]]", constraint.where
- ),
+ where=cast("ColumnElement[bool] | None", constraint.where),
schema=constraint_table.schema,
_orig_constraint=constraint,
deferrable=constraint.deferrable,
if TYPE_CHECKING:
from sqlalchemy.engine import Connection
- from sqlalchemy.sql.elements import BinaryExpression
+ from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.sql.elements import conv
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.functions import Function
def create_check_constraint(
constraint_name: Optional[str],
table_name: str,
- condition: Union[str, BinaryExpression, TextClause],
+ condition: Union[str, ColumnElement[bool], TextClause],
*,
schema: Optional[str] = None,
**kw: Any,
from sqlalchemy import Table
from sqlalchemy.engine import Connection
- from sqlalchemy.sql.expression import BinaryExpression
+ from sqlalchemy.sql.expression import ColumnElement
from sqlalchemy.sql.expression import TableClause
from sqlalchemy.sql.expression import TextClause
from sqlalchemy.sql.expression import Update
self,
constraint_name: Optional[str],
table_name: str,
- condition: Union[str, BinaryExpression, TextClause],
+ condition: Union[str, ColumnElement[bool], TextClause],
*,
schema: Optional[str] = None,
**kw: Any,
def create_check_constraint(
self,
constraint_name: str,
- condition: Union[str, BinaryExpression, TextClause],
+ condition: Union[str, ColumnElement[bool], TextClause],
**kw: Any,
) -> None:
"""Issue a "create check constraint" instruction using the
from sqlalchemy.sql.dml import Insert
from sqlalchemy.sql.dml import Update
- from sqlalchemy.sql.elements import BinaryExpression
from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.sql.elements import conv
from sqlalchemy.sql.elements import quoted_name
operations: Operations,
constraint_name: Optional[str],
table_name: str,
- condition: Union[str, BinaryExpression, TextClause],
+ condition: Union[str, ColumnElement[bool], TextClause],
*,
schema: Optional[str] = None,
**kw: Any,
cls,
operations: BatchOperations,
constraint_name: str,
- condition: Union[str, BinaryExpression, TextClause],
+ condition: Union[str, ColumnElement[bool], TextClause],
**kw: Any,
) -> None:
"""Issue a "create check constraint" instruction using the
self.to_revisions[0],
)
- def _unmerge_to_revisions(self, heads: Collection[str]) -> Tuple[str, ...]:
+ def _unmerge_to_revisions(self, heads: Set[str]) -> Tuple[str, ...]:
other_heads = set(heads).difference([self.revision.revision])
if other_heads:
ancestors = {
return self.to_revisions
def unmerge_branch_idents(
- self, heads: Collection[str]
+ self, heads: Set[str]
) -> Tuple[str, str, Tuple[str, ...]]:
to_revisions = self._unmerge_to_revisions(heads)
from ..util import not_none
if TYPE_CHECKING:
+ from .revision import _GetRevArg
from .revision import _RevIdType
from .revision import Revision
from ..config import Config
):
yield cast(Script, rev)
- def get_revisions(self, id_: _RevIdType) -> Tuple[Optional[Script], ...]:
+ def get_revisions(self, id_: _GetRevArg) -> Tuple[Optional[Script], ...]:
"""Return the :class:`.Script` instance with the given rev identifier,
symbolic name, or sequence of identifiers.
self,
revid: str,
message: Optional[str],
- head: Optional[str] = None,
- refresh: bool = False,
+ head: Optional[_RevIdType] = None,
splice: Optional[bool] = False,
branch_labels: Optional[_RevIdType] = None,
version_path: Optional[str] = None,
:param splice: if True, allow the "head" version to not be an
actual head; otherwise, the selected head must be a head
(e.g. endpoint) revision.
- :param refresh: deprecated.
"""
if head is None:
if TYPE_CHECKING:
from typing import Literal
-_RevIdType = Union[str, Sequence[str]]
+_RevIdType = Union[str, List[str], Tuple[str, ...]]
+_GetRevArg = Union[
+ str,
+ List[Optional[str]],
+ Tuple[Optional[str], ...],
+ FrozenSet[Optional[str]],
+ Set[Optional[str]],
+ List[str],
+ Tuple[str, ...],
+ FrozenSet[str],
+ Set[str],
+]
_RevisionIdentifierType = Union[str, Tuple[str, ...], None]
_RevisionOrStr = Union["Revision", str]
_RevisionOrBase = Union["Revision", "Literal['base']"]
_InterimRevisionMapType = Dict[str, "Revision"]
_RevisionMapType = Dict[Union[None, str, Tuple[()]], Optional["Revision"]]
-_T = TypeVar("_T", bound=Union[str, "Revision"])
+_T = TypeVar("_T")
+_TR = TypeVar("_TR", bound=Optional[_RevisionOrStr])
_relative_destination = re.compile(r"(?:(.+?)@)?(\w+)?((?:\+|-)\d+)")
_revision_illegal_chars = ["@", "-", "+"]
return self.filter_for_lineage(self.bases, identifier)
def get_revisions(
- self, id_: Union[str, Collection[Optional[str]], None]
+ self, id_: Optional[_GetRevArg]
) -> Tuple[Optional[_RevisionOrBase], ...]:
"""Return the :class:`.Revision` instances with the given rev id
or identifiers.
if isinstance(id_, (list, tuple, set, frozenset)):
return sum([self.get_revisions(id_elem) for id_elem in id_], ())
else:
- resolved_id, branch_label = self._resolve_revision_number(
- id_ # type:ignore [arg-type]
- )
+ resolved_id, branch_label = self._resolve_revision_number(id_)
if len(resolved_id) == 1:
try:
rint = int(resolved_id[0])
def _revision_for_ident(
self,
- resolved_id: Union[str, Tuple[()]],
+ resolved_id: Union[str, Tuple[()], None],
check_branch: Optional[str] = None,
) -> Optional[Revision]:
branch_rev: Optional[Revision]
def filter_for_lineage(
self,
- targets: Iterable[_T],
+ targets: Iterable[_TR],
check_against: Optional[str],
include_dependencies: bool = False,
- ) -> Tuple[_T, ...]:
+ ) -> Tuple[_TR, ...]:
id_, branch_label = self._resolve_revision_number(check_against)
shares = []
def _shares_lineage(
self,
- target: _RevisionOrStr,
+ target: Optional[_RevisionOrStr],
test_against_revs: Sequence[_RevisionOrStr],
include_dependencies: bool = False,
) -> bool:
# No relative destination, target is absolute.
return self.get_revisions(target)
- current_revisions_tup: Union[str, Collection[Optional[str]], None]
+ current_revisions_tup: Union[str, Tuple[Optional[str], ...], None]
current_revisions_tup = util.to_tuple(current_revisions)
branch_label, symbol, relative_str = match.groups()
start_revs = current_revisions_tup
if branch_label:
start_revs = self.filter_for_lineage(
- self.get_revisions(current_revisions_tup), branch_label
+ self.get_revisions(current_revisions_tup), # type: ignore[arg-type] # noqa: E501
+ branch_label,
)
if not start_revs:
# The requested branch is not a head, so we need to
self.verify_rev_id(revision)
self.revision = revision
- self.down_revision = tuple_rev_as_scalar(down_revision)
- self.dependencies = tuple_rev_as_scalar(dependencies)
+ self.down_revision = tuple_rev_as_scalar(util.to_tuple(down_revision))
+ self.dependencies = tuple_rev_as_scalar(util.to_tuple(dependencies))
self._orig_branch_labels = util.to_tuple(branch_labels, default=())
self.branch_labels = set(self._orig_branch_labels)
@overload
-def tuple_rev_as_scalar(
- rev: Optional[Sequence[str]],
-) -> Optional[Union[str, Sequence[str]]]:
+def tuple_rev_as_scalar(rev: None) -> None:
...
@overload
def tuple_rev_as_scalar(
- rev: Optional[Sequence[Optional[str]]],
-) -> Optional[Union[Optional[str], Sequence[Optional[str]]]]:
+ rev: Union[Tuple[_T, ...], List[_T]]
+) -> Union[_T, Tuple[_T, ...], List[_T]]:
...
-def tuple_rev_as_scalar(rev):
+def tuple_rev_as_scalar(
+ rev: Optional[Sequence[_T]],
+) -> Union[_T, Sequence[_T], None]:
if not rev:
return None
elif len(rev) == 1:
--- /dev/null
+.. change::
+ :tags: typing
+ :tickets: 930
+
+ Improve typing of the revision parameter in various command functions.
+
+.. change::
+ :tags: typing, bug
+ :tickets: 1266
+
+ Properly type the :paramref:`.Operations.create_check_constraint.condition`
+ parameter of :meth:`.Operations.create_check_constraint` to accept boolean
+ expressions.