From 5628a2270ec66e334a109c3859bb74a86b08a9fb Mon Sep 17 00:00:00 2001 From: Federico Caselli Date: Thu, 31 Aug 2023 23:25:04 +0200 Subject: [PATCH] Improve typings Misc changes to improve the typing of alembic: - Improve typing of the revision parameter in various command functions. - Properly type the :paramref:`.Operations.create_check_constraint.condition` parameter of :meth:`.Operations.create_check_constraint` to accept boolean expressions. Fixes: #930 Fixes: #1266 Change-Id: I9e8249bbd34f9f0b388b79e75b76e75f8347d8ee --- alembic/autogenerate/api.py | 7 ++-- alembic/command.py | 13 +++--- alembic/ddl/postgresql.py | 8 ++-- alembic/op.pyi | 4 +- alembic/operations/base.py | 6 +-- alembic/operations/ops.py | 5 +-- alembic/runtime/migration.py | 4 +- alembic/script/base.py | 7 ++-- alembic/script/revision.py | 51 ++++++++++++++---------- docs/build/unreleased/improve_typing.rst | 13 ++++++ 10 files changed, 69 insertions(+), 49 deletions(-) create mode 100644 docs/build/unreleased/improve_typing.rst diff --git a/alembic/autogenerate/api.py b/alembic/autogenerate/api.py index 064bca9f..9b7a97d4 100644 --- a/alembic/autogenerate/api.py +++ b/alembic/autogenerate/api.py @@ -38,6 +38,7 @@ if TYPE_CHECKING: from ..runtime.migration import MigrationContext from ..script.base import Script from ..script.base import ScriptDirectory + from ..script.revision import _GetRevArg def compare_metadata(context: MigrationContext, metadata: MetaData) -> Any: @@ -555,18 +556,18 @@ class RevisionContext: ) def run_autogenerate( - self, rev: tuple, migration_context: MigrationContext + self, rev: _GetRevArg, migration_context: MigrationContext ) -> None: self._run_environment(rev, migration_context, True) def run_no_autogenerate( - self, rev: tuple, migration_context: MigrationContext + self, rev: _GetRevArg, migration_context: MigrationContext ) -> None: self._run_environment(rev, migration_context, False) def _run_environment( self, - rev: tuple, + rev: _GetRevArg, migration_context: MigrationContext, autogenerate: bool, ) -> None: diff --git a/alembic/command.py b/alembic/command.py index be0e0fd9..bd59d42a 100644 --- a/alembic/command.py +++ b/alembic/command.py @@ -14,6 +14,7 @@ from .script import ScriptDirectory if TYPE_CHECKING: from alembic.config import Config from alembic.script.base import Script + from alembic.script.revision import _RevIdType from .runtime.environment import ProcessRevisionDirectiveFn @@ -124,7 +125,7 @@ def revision( sql: bool = False, head: str = "head", splice: bool = False, - branch_label: Optional[str] = None, + branch_label: Optional[_RevIdType] = None, version_path: Optional[str] = None, rev_id: Optional[str] = None, depends_on: Optional[str] = None, @@ -244,9 +245,7 @@ def revision( return scripts -def check( - config: "Config", -) -> None: +def check(config: "Config") -> None: """Check if revision command with autogenerate has pending upgrade ops. :param config: a :class:`.Config` object. @@ -302,9 +301,9 @@ def check( def merge( config: Config, - revisions: str, + revisions: _RevIdType, message: Optional[str] = None, - branch_label: Optional[str] = None, + branch_label: Optional[_RevIdType] = None, rev_id: Optional[str] = None, ) -> Optional[Script]: """Merge two revisions together. Creates a new migration file. @@ -623,7 +622,7 @@ def current(config: Config, verbose: bool = False) -> None: def stamp( config: Config, - revision: str, + revision: _RevIdType, sql: bool = False, tag: Optional[str] = None, purge: bool = False, diff --git a/alembic/ddl/postgresql.py b/alembic/ddl/postgresql.py index b63938ac..a74012fb 100644 --- a/alembic/ddl/postgresql.py +++ b/alembic/ddl/postgresql.py @@ -57,8 +57,8 @@ if TYPE_CHECKING: from sqlalchemy.dialects.postgresql.hstore import HSTORE from sqlalchemy.dialects.postgresql.json import JSON from sqlalchemy.dialects.postgresql.json import JSONB - from sqlalchemy.sql.elements import BinaryExpression from sqlalchemy.sql.elements import ClauseElement + from sqlalchemy.sql.elements import ColumnElement from sqlalchemy.sql.elements import quoted_name from sqlalchemy.sql.schema import MetaData from sqlalchemy.sql.schema import Table @@ -513,7 +513,7 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp): Sequence[Tuple[str, str]], Sequence[Tuple[ColumnClause[Any], str]], ], - where: Optional[Union[BinaryExpression, str]] = None, + where: Optional[Union[ColumnElement[bool], str]] = None, schema: Optional[str] = None, _orig_constraint: Optional[ExcludeConstraint] = None, **kw, @@ -538,9 +538,7 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp): (expr, op) for expr, name, op in constraint._render_exprs # type:ignore[attr-defined] # noqa ], - where=cast( - "Optional[Union[BinaryExpression, str]]", constraint.where - ), + where=cast("ColumnElement[bool] | None", constraint.where), schema=constraint_table.schema, _orig_constraint=constraint, deferrable=constraint.deferrable, diff --git a/alembic/op.pyi b/alembic/op.pyi index 6e143e47..2bb16d81 100644 --- a/alembic/op.pyi +++ b/alembic/op.pyi @@ -24,7 +24,7 @@ from sqlalchemy.sql.expression import Update if TYPE_CHECKING: from sqlalchemy.engine import Connection - from sqlalchemy.sql.elements import BinaryExpression + from sqlalchemy.sql.elements import ColumnElement from sqlalchemy.sql.elements import conv from sqlalchemy.sql.elements import TextClause from sqlalchemy.sql.functions import Function @@ -481,7 +481,7 @@ def bulk_insert( def create_check_constraint( constraint_name: Optional[str], table_name: str, - condition: Union[str, BinaryExpression, TextClause], + condition: Union[str, ColumnElement[bool], TextClause], *, schema: Optional[str] = None, **kw: Any, diff --git a/alembic/operations/base.py b/alembic/operations/base.py index 8df6efd1..8b74dfd7 100644 --- a/alembic/operations/base.py +++ b/alembic/operations/base.py @@ -35,7 +35,7 @@ if TYPE_CHECKING: from sqlalchemy import Table from sqlalchemy.engine import Connection - from sqlalchemy.sql.expression import BinaryExpression + from sqlalchemy.sql.expression import ColumnElement from sqlalchemy.sql.expression import TableClause from sqlalchemy.sql.expression import TextClause from sqlalchemy.sql.expression import Update @@ -861,7 +861,7 @@ class Operations(AbstractOperations): self, constraint_name: Optional[str], table_name: str, - condition: Union[str, BinaryExpression, TextClause], + condition: Union[str, ColumnElement[bool], TextClause], *, schema: Optional[str] = None, **kw: Any, @@ -1635,7 +1635,7 @@ class BatchOperations(AbstractOperations): def create_check_constraint( self, constraint_name: str, - condition: Union[str, BinaryExpression, TextClause], + condition: Union[str, ColumnElement[bool], TextClause], **kw: Any, ) -> None: """Issue a "create check constraint" instruction using the diff --git a/alembic/operations/ops.py b/alembic/operations/ops.py index 68c44eb6..5f0b5632 100644 --- a/alembic/operations/ops.py +++ b/alembic/operations/ops.py @@ -30,7 +30,6 @@ if TYPE_CHECKING: from sqlalchemy.sql.dml import Insert from sqlalchemy.sql.dml import Update - from sqlalchemy.sql.elements import BinaryExpression from sqlalchemy.sql.elements import ColumnElement from sqlalchemy.sql.elements import conv from sqlalchemy.sql.elements import quoted_name @@ -788,7 +787,7 @@ class CreateCheckConstraintOp(AddConstraintOp): operations: Operations, constraint_name: Optional[str], table_name: str, - condition: Union[str, BinaryExpression, TextClause], + condition: Union[str, ColumnElement[bool], TextClause], *, schema: Optional[str] = None, **kw: Any, @@ -841,7 +840,7 @@ class CreateCheckConstraintOp(AddConstraintOp): cls, operations: BatchOperations, constraint_name: str, - condition: Union[str, BinaryExpression, TextClause], + condition: Union[str, ColumnElement[bool], TextClause], **kw: Any, ) -> None: """Issue a "create check constraint" instruction using the diff --git a/alembic/runtime/migration.py b/alembic/runtime/migration.py index 50518ffe..c9374c22 100644 --- a/alembic/runtime/migration.py +++ b/alembic/runtime/migration.py @@ -1157,7 +1157,7 @@ class RevisionStep(MigrationStep): self.to_revisions[0], ) - def _unmerge_to_revisions(self, heads: Collection[str]) -> Tuple[str, ...]: + def _unmerge_to_revisions(self, heads: Set[str]) -> Tuple[str, ...]: other_heads = set(heads).difference([self.revision.revision]) if other_heads: ancestors = { @@ -1171,7 +1171,7 @@ class RevisionStep(MigrationStep): return self.to_revisions def unmerge_branch_idents( - self, heads: Collection[str] + self, heads: Set[str] ) -> Tuple[str, str, Tuple[str, ...]]: to_revisions = self._unmerge_to_revisions(heads) diff --git a/alembic/script/base.py b/alembic/script/base.py index 9894b4c3..d0f9abbd 100644 --- a/alembic/script/base.py +++ b/alembic/script/base.py @@ -26,6 +26,7 @@ from ..runtime import migration from ..util import not_none if TYPE_CHECKING: + from .revision import _GetRevArg from .revision import _RevIdType from .revision import Revision from ..config import Config @@ -296,7 +297,7 @@ class ScriptDirectory: ): yield cast(Script, rev) - def get_revisions(self, id_: _RevIdType) -> Tuple[Optional[Script], ...]: + def get_revisions(self, id_: _GetRevArg) -> Tuple[Optional[Script], ...]: """Return the :class:`.Script` instance with the given rev identifier, symbolic name, or sequence of identifiers. @@ -630,8 +631,7 @@ class ScriptDirectory: self, revid: str, message: Optional[str], - head: Optional[str] = None, - refresh: bool = False, + head: Optional[_RevIdType] = None, splice: Optional[bool] = False, branch_labels: Optional[_RevIdType] = None, version_path: Optional[str] = None, @@ -653,7 +653,6 @@ class ScriptDirectory: :param splice: if True, allow the "head" version to not be an actual head; otherwise, the selected head must be a head (e.g. endpoint) revision. - :param refresh: deprecated. """ if head is None: diff --git a/alembic/script/revision.py b/alembic/script/revision.py index fe9ff616..6c18e7ae 100644 --- a/alembic/script/revision.py +++ b/alembic/script/revision.py @@ -29,13 +29,25 @@ from ..util import not_none if TYPE_CHECKING: from typing import Literal -_RevIdType = Union[str, Sequence[str]] +_RevIdType = Union[str, List[str], Tuple[str, ...]] +_GetRevArg = Union[ + str, + List[Optional[str]], + Tuple[Optional[str], ...], + FrozenSet[Optional[str]], + Set[Optional[str]], + List[str], + Tuple[str, ...], + FrozenSet[str], + Set[str], +] _RevisionIdentifierType = Union[str, Tuple[str, ...], None] _RevisionOrStr = Union["Revision", str] _RevisionOrBase = Union["Revision", "Literal['base']"] _InterimRevisionMapType = Dict[str, "Revision"] _RevisionMapType = Dict[Union[None, str, Tuple[()]], Optional["Revision"]] -_T = TypeVar("_T", bound=Union[str, "Revision"]) +_T = TypeVar("_T") +_TR = TypeVar("_TR", bound=Optional[_RevisionOrStr]) _relative_destination = re.compile(r"(?:(.+?)@)?(\w+)?((?:\+|-)\d+)") _revision_illegal_chars = ["@", "-", "+"] @@ -501,7 +513,7 @@ class RevisionMap: return self.filter_for_lineage(self.bases, identifier) def get_revisions( - self, id_: Union[str, Collection[Optional[str]], None] + self, id_: Optional[_GetRevArg] ) -> Tuple[Optional[_RevisionOrBase], ...]: """Return the :class:`.Revision` instances with the given rev id or identifiers. @@ -523,9 +535,7 @@ class RevisionMap: if isinstance(id_, (list, tuple, set, frozenset)): return sum([self.get_revisions(id_elem) for id_elem in id_], ()) else: - resolved_id, branch_label = self._resolve_revision_number( - id_ # type:ignore [arg-type] - ) + resolved_id, branch_label = self._resolve_revision_number(id_) if len(resolved_id) == 1: try: rint = int(resolved_id[0]) @@ -590,7 +600,7 @@ class RevisionMap: def _revision_for_ident( self, - resolved_id: Union[str, Tuple[()]], + resolved_id: Union[str, Tuple[()], None], check_branch: Optional[str] = None, ) -> Optional[Revision]: branch_rev: Optional[Revision] @@ -669,10 +679,10 @@ class RevisionMap: def filter_for_lineage( self, - targets: Iterable[_T], + targets: Iterable[_TR], check_against: Optional[str], include_dependencies: bool = False, - ) -> Tuple[_T, ...]: + ) -> Tuple[_TR, ...]: id_, branch_label = self._resolve_revision_number(check_against) shares = [] @@ -691,7 +701,7 @@ class RevisionMap: def _shares_lineage( self, - target: _RevisionOrStr, + target: Optional[_RevisionOrStr], test_against_revs: Sequence[_RevisionOrStr], include_dependencies: bool = False, ) -> bool: @@ -1211,7 +1221,7 @@ class RevisionMap: # No relative destination, target is absolute. return self.get_revisions(target) - current_revisions_tup: Union[str, Collection[Optional[str]], None] + current_revisions_tup: Union[str, Tuple[Optional[str], ...], None] current_revisions_tup = util.to_tuple(current_revisions) branch_label, symbol, relative_str = match.groups() @@ -1224,7 +1234,8 @@ class RevisionMap: start_revs = current_revisions_tup if branch_label: start_revs = self.filter_for_lineage( - self.get_revisions(current_revisions_tup), branch_label + self.get_revisions(current_revisions_tup), # type: ignore[arg-type] # noqa: E501 + branch_label, ) if not start_revs: # The requested branch is not a head, so we need to @@ -1577,8 +1588,8 @@ class Revision: self.verify_rev_id(revision) self.revision = revision - self.down_revision = tuple_rev_as_scalar(down_revision) - self.dependencies = tuple_rev_as_scalar(dependencies) + self.down_revision = tuple_rev_as_scalar(util.to_tuple(down_revision)) + self.dependencies = tuple_rev_as_scalar(util.to_tuple(dependencies)) self._orig_branch_labels = util.to_tuple(branch_labels, default=()) self.branch_labels = set(self._orig_branch_labels) @@ -1676,20 +1687,20 @@ class Revision: @overload -def tuple_rev_as_scalar( - rev: Optional[Sequence[str]], -) -> Optional[Union[str, Sequence[str]]]: +def tuple_rev_as_scalar(rev: None) -> None: ... @overload def tuple_rev_as_scalar( - rev: Optional[Sequence[Optional[str]]], -) -> Optional[Union[Optional[str], Sequence[Optional[str]]]]: + rev: Union[Tuple[_T, ...], List[_T]] +) -> Union[_T, Tuple[_T, ...], List[_T]]: ... -def tuple_rev_as_scalar(rev): +def tuple_rev_as_scalar( + rev: Optional[Sequence[_T]], +) -> Union[_T, Sequence[_T], None]: if not rev: return None elif len(rev) == 1: diff --git a/docs/build/unreleased/improve_typing.rst b/docs/build/unreleased/improve_typing.rst new file mode 100644 index 00000000..43647fe6 --- /dev/null +++ b/docs/build/unreleased/improve_typing.rst @@ -0,0 +1,13 @@ +.. change:: + :tags: typing + :tickets: 930 + + Improve typing of the revision parameter in various command functions. + +.. change:: + :tags: typing, bug + :tickets: 1266 + + Properly type the :paramref:`.Operations.create_check_constraint.condition` + parameter of :meth:`.Operations.create_check_constraint` to accept boolean + expressions. -- 2.47.2