]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
Improve typings
authorFederico Caselli <cfederico87@gmail.com>
Thu, 31 Aug 2023 21:25:04 +0000 (23:25 +0200)
committerFederico Caselli <cfederico87@gmail.com>
Thu, 7 Sep 2023 18:12:07 +0000 (20:12 +0200)
Misc changes to improve the typing of alembic:
- Improve typing of the revision parameter in various command functions.
- Properly type the :paramref:`.Operations.create_check_constraint.condition`
  parameter of :meth:`.Operations.create_check_constraint` to accept boolean
  expressions.

Fixes: #930
Fixes: #1266
Change-Id: I9e8249bbd34f9f0b388b79e75b76e75f8347d8ee

alembic/autogenerate/api.py
alembic/command.py
alembic/ddl/postgresql.py
alembic/op.pyi
alembic/operations/base.py
alembic/operations/ops.py
alembic/runtime/migration.py
alembic/script/base.py
alembic/script/revision.py
docs/build/unreleased/improve_typing.rst [new file with mode: 0644]

index 064bca9fdeb7903ff20193fa65af136431bd56c6..9b7a97d4bdf755032cc4f3af21c9483821ba89d1 100644 (file)
@@ -38,6 +38,7 @@ if TYPE_CHECKING:
     from ..runtime.migration import MigrationContext
     from ..script.base import Script
     from ..script.base import ScriptDirectory
+    from ..script.revision import _GetRevArg
 
 
 def compare_metadata(context: MigrationContext, metadata: MetaData) -> Any:
@@ -555,18 +556,18 @@ class RevisionContext:
         )
 
     def run_autogenerate(
-        self, rev: tuple, migration_context: MigrationContext
+        self, rev: _GetRevArg, migration_context: MigrationContext
     ) -> None:
         self._run_environment(rev, migration_context, True)
 
     def run_no_autogenerate(
-        self, rev: tuple, migration_context: MigrationContext
+        self, rev: _GetRevArg, migration_context: MigrationContext
     ) -> None:
         self._run_environment(rev, migration_context, False)
 
     def _run_environment(
         self,
-        rev: tuple,
+        rev: _GetRevArg,
         migration_context: MigrationContext,
         autogenerate: bool,
     ) -> None:
index be0e0fd90f72a89f23464c172a0bc3d80b1b4a1b..bd59d42a665092566d870712cecccf2ed5a7533b 100644 (file)
@@ -14,6 +14,7 @@ from .script import ScriptDirectory
 if TYPE_CHECKING:
     from alembic.config import Config
     from alembic.script.base import Script
+    from alembic.script.revision import _RevIdType
     from .runtime.environment import ProcessRevisionDirectiveFn
 
 
@@ -124,7 +125,7 @@ def revision(
     sql: bool = False,
     head: str = "head",
     splice: bool = False,
-    branch_label: Optional[str] = None,
+    branch_label: Optional[_RevIdType] = None,
     version_path: Optional[str] = None,
     rev_id: Optional[str] = None,
     depends_on: Optional[str] = None,
@@ -244,9 +245,7 @@ def revision(
         return scripts
 
 
-def check(
-    config: "Config",
-) -> None:
+def check(config: "Config") -> None:
     """Check if revision command with autogenerate has pending upgrade ops.
 
     :param config: a :class:`.Config` object.
@@ -302,9 +301,9 @@ def check(
 
 def merge(
     config: Config,
-    revisions: str,
+    revisions: _RevIdType,
     message: Optional[str] = None,
-    branch_label: Optional[str] = None,
+    branch_label: Optional[_RevIdType] = None,
     rev_id: Optional[str] = None,
 ) -> Optional[Script]:
     """Merge two revisions together.  Creates a new migration file.
@@ -623,7 +622,7 @@ def current(config: Config, verbose: bool = False) -> None:
 
 def stamp(
     config: Config,
-    revision: str,
+    revision: _RevIdType,
     sql: bool = False,
     tag: Optional[str] = None,
     purge: bool = False,
index b63938ac4617161a8c75a1912b64df88d476331b..a74012fbd9b706ba342e7e000e1c2cbf0af819d7 100644 (file)
@@ -57,8 +57,8 @@ if TYPE_CHECKING:
     from sqlalchemy.dialects.postgresql.hstore import HSTORE
     from sqlalchemy.dialects.postgresql.json import JSON
     from sqlalchemy.dialects.postgresql.json import JSONB
-    from sqlalchemy.sql.elements import BinaryExpression
     from sqlalchemy.sql.elements import ClauseElement
+    from sqlalchemy.sql.elements import ColumnElement
     from sqlalchemy.sql.elements import quoted_name
     from sqlalchemy.sql.schema import MetaData
     from sqlalchemy.sql.schema import Table
@@ -513,7 +513,7 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
             Sequence[Tuple[str, str]],
             Sequence[Tuple[ColumnClause[Any], str]],
         ],
-        where: Optional[Union[BinaryExpression, str]] = None,
+        where: Optional[Union[ColumnElement[bool], str]] = None,
         schema: Optional[str] = None,
         _orig_constraint: Optional[ExcludeConstraint] = None,
         **kw,
@@ -538,9 +538,7 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):
                 (expr, op)
                 for expr, name, op in constraint._render_exprs  # type:ignore[attr-defined] # noqa
             ],
-            where=cast(
-                "Optional[Union[BinaryExpression, str]]", constraint.where
-            ),
+            where=cast("ColumnElement[bool] | None", constraint.where),
             schema=constraint_table.schema,
             _orig_constraint=constraint,
             deferrable=constraint.deferrable,
index 6e143e478bc4724ffeaebf9ffa9835504a37471d..2bb16d813a327d74e5e913599e8c376218424bdf 100644 (file)
@@ -24,7 +24,7 @@ from sqlalchemy.sql.expression import Update
 
 if TYPE_CHECKING:
     from sqlalchemy.engine import Connection
-    from sqlalchemy.sql.elements import BinaryExpression
+    from sqlalchemy.sql.elements import ColumnElement
     from sqlalchemy.sql.elements import conv
     from sqlalchemy.sql.elements import TextClause
     from sqlalchemy.sql.functions import Function
@@ -481,7 +481,7 @@ def bulk_insert(
 def create_check_constraint(
     constraint_name: Optional[str],
     table_name: str,
-    condition: Union[str, BinaryExpression, TextClause],
+    condition: Union[str, ColumnElement[bool], TextClause],
     *,
     schema: Optional[str] = None,
     **kw: Any,
index 8df6efd1862398cd77f3371ba5a3d1be31240468..8b74dfd7e0f001cc02b2a34fdb6b54594a6bc81c 100644 (file)
@@ -35,7 +35,7 @@ if TYPE_CHECKING:
 
     from sqlalchemy import Table
     from sqlalchemy.engine import Connection
-    from sqlalchemy.sql.expression import BinaryExpression
+    from sqlalchemy.sql.expression import ColumnElement
     from sqlalchemy.sql.expression import TableClause
     from sqlalchemy.sql.expression import TextClause
     from sqlalchemy.sql.expression import Update
@@ -861,7 +861,7 @@ class Operations(AbstractOperations):
             self,
             constraint_name: Optional[str],
             table_name: str,
-            condition: Union[str, BinaryExpression, TextClause],
+            condition: Union[str, ColumnElement[bool], TextClause],
             *,
             schema: Optional[str] = None,
             **kw: Any,
@@ -1635,7 +1635,7 @@ class BatchOperations(AbstractOperations):
         def create_check_constraint(
             self,
             constraint_name: str,
-            condition: Union[str, BinaryExpression, TextClause],
+            condition: Union[str, ColumnElement[bool], TextClause],
             **kw: Any,
         ) -> None:
             """Issue a "create check constraint" instruction using the
index 68c44eb6ab47ee7c361a6941c1afccd21fabbb80..5f0b563289ce1ca028dea6532670d91853292d07 100644 (file)
@@ -30,7 +30,6 @@ if TYPE_CHECKING:
 
     from sqlalchemy.sql.dml import Insert
     from sqlalchemy.sql.dml import Update
-    from sqlalchemy.sql.elements import BinaryExpression
     from sqlalchemy.sql.elements import ColumnElement
     from sqlalchemy.sql.elements import conv
     from sqlalchemy.sql.elements import quoted_name
@@ -788,7 +787,7 @@ class CreateCheckConstraintOp(AddConstraintOp):
         operations: Operations,
         constraint_name: Optional[str],
         table_name: str,
-        condition: Union[str, BinaryExpression, TextClause],
+        condition: Union[str, ColumnElement[bool], TextClause],
         *,
         schema: Optional[str] = None,
         **kw: Any,
@@ -841,7 +840,7 @@ class CreateCheckConstraintOp(AddConstraintOp):
         cls,
         operations: BatchOperations,
         constraint_name: str,
-        condition: Union[str, BinaryExpression, TextClause],
+        condition: Union[str, ColumnElement[bool], TextClause],
         **kw: Any,
     ) -> None:
         """Issue a "create check constraint" instruction using the
index 50518ffee2717ea6426f2013aec707211404081f..c9374c227bd252ff2e981e2cadd80a3c7ac28bbb 100644 (file)
@@ -1157,7 +1157,7 @@ class RevisionStep(MigrationStep):
             self.to_revisions[0],
         )
 
-    def _unmerge_to_revisions(self, heads: Collection[str]) -> Tuple[str, ...]:
+    def _unmerge_to_revisions(self, heads: Set[str]) -> Tuple[str, ...]:
         other_heads = set(heads).difference([self.revision.revision])
         if other_heads:
             ancestors = {
@@ -1171,7 +1171,7 @@ class RevisionStep(MigrationStep):
             return self.to_revisions
 
     def unmerge_branch_idents(
-        self, heads: Collection[str]
+        self, heads: Set[str]
     ) -> Tuple[str, str, Tuple[str, ...]]:
         to_revisions = self._unmerge_to_revisions(heads)
 
index 9894b4c3318c7d02857d0eee6b5abfbd5f119a00..d0f9abbde4be082cfd25ae2a2a418c04181997d7 100644 (file)
@@ -26,6 +26,7 @@ from ..runtime import migration
 from ..util import not_none
 
 if TYPE_CHECKING:
+    from .revision import _GetRevArg
     from .revision import _RevIdType
     from .revision import Revision
     from ..config import Config
@@ -296,7 +297,7 @@ class ScriptDirectory:
             ):
                 yield cast(Script, rev)
 
-    def get_revisions(self, id_: _RevIdType) -> Tuple[Optional[Script], ...]:
+    def get_revisions(self, id_: _GetRevArg) -> Tuple[Optional[Script], ...]:
         """Return the :class:`.Script` instance with the given rev identifier,
         symbolic name, or sequence of identifiers.
 
@@ -630,8 +631,7 @@ class ScriptDirectory:
         self,
         revid: str,
         message: Optional[str],
-        head: Optional[str] = None,
-        refresh: bool = False,
+        head: Optional[_RevIdType] = None,
         splice: Optional[bool] = False,
         branch_labels: Optional[_RevIdType] = None,
         version_path: Optional[str] = None,
@@ -653,7 +653,6 @@ class ScriptDirectory:
         :param splice: if True, allow the "head" version to not be an
          actual head; otherwise, the selected head must be a head
          (e.g. endpoint) revision.
-        :param refresh: deprecated.
 
         """
         if head is None:
index fe9ff616d6258a7714048606d9923e5f6879afe5..6c18e7aed842ab359c24f074d6f930b23b9164ff 100644 (file)
@@ -29,13 +29,25 @@ from ..util import not_none
 if TYPE_CHECKING:
     from typing import Literal
 
-_RevIdType = Union[str, Sequence[str]]
+_RevIdType = Union[str, List[str], Tuple[str, ...]]
+_GetRevArg = Union[
+    str,
+    List[Optional[str]],
+    Tuple[Optional[str], ...],
+    FrozenSet[Optional[str]],
+    Set[Optional[str]],
+    List[str],
+    Tuple[str, ...],
+    FrozenSet[str],
+    Set[str],
+]
 _RevisionIdentifierType = Union[str, Tuple[str, ...], None]
 _RevisionOrStr = Union["Revision", str]
 _RevisionOrBase = Union["Revision", "Literal['base']"]
 _InterimRevisionMapType = Dict[str, "Revision"]
 _RevisionMapType = Dict[Union[None, str, Tuple[()]], Optional["Revision"]]
-_T = TypeVar("_T", bound=Union[str, "Revision"])
+_T = TypeVar("_T")
+_TR = TypeVar("_TR", bound=Optional[_RevisionOrStr])
 
 _relative_destination = re.compile(r"(?:(.+?)@)?(\w+)?((?:\+|-)\d+)")
 _revision_illegal_chars = ["@", "-", "+"]
@@ -501,7 +513,7 @@ class RevisionMap:
         return self.filter_for_lineage(self.bases, identifier)
 
     def get_revisions(
-        self, id_: Union[str, Collection[Optional[str]], None]
+        self, id_: Optional[_GetRevArg]
     ) -> Tuple[Optional[_RevisionOrBase], ...]:
         """Return the :class:`.Revision` instances with the given rev id
         or identifiers.
@@ -523,9 +535,7 @@ class RevisionMap:
         if isinstance(id_, (list, tuple, set, frozenset)):
             return sum([self.get_revisions(id_elem) for id_elem in id_], ())
         else:
-            resolved_id, branch_label = self._resolve_revision_number(
-                id_  # type:ignore [arg-type]
-            )
+            resolved_id, branch_label = self._resolve_revision_number(id_)
             if len(resolved_id) == 1:
                 try:
                     rint = int(resolved_id[0])
@@ -590,7 +600,7 @@ class RevisionMap:
 
     def _revision_for_ident(
         self,
-        resolved_id: Union[str, Tuple[()]],
+        resolved_id: Union[str, Tuple[()], None],
         check_branch: Optional[str] = None,
     ) -> Optional[Revision]:
         branch_rev: Optional[Revision]
@@ -669,10 +679,10 @@ class RevisionMap:
 
     def filter_for_lineage(
         self,
-        targets: Iterable[_T],
+        targets: Iterable[_TR],
         check_against: Optional[str],
         include_dependencies: bool = False,
-    ) -> Tuple[_T, ...]:
+    ) -> Tuple[_TR, ...]:
         id_, branch_label = self._resolve_revision_number(check_against)
 
         shares = []
@@ -691,7 +701,7 @@ class RevisionMap:
 
     def _shares_lineage(
         self,
-        target: _RevisionOrStr,
+        target: Optional[_RevisionOrStr],
         test_against_revs: Sequence[_RevisionOrStr],
         include_dependencies: bool = False,
     ) -> bool:
@@ -1211,7 +1221,7 @@ class RevisionMap:
             # No relative destination, target is absolute.
             return self.get_revisions(target)
 
-        current_revisions_tup: Union[str, Collection[Optional[str]], None]
+        current_revisions_tup: Union[str, Tuple[Optional[str], ...], None]
         current_revisions_tup = util.to_tuple(current_revisions)
 
         branch_label, symbol, relative_str = match.groups()
@@ -1224,7 +1234,8 @@ class RevisionMap:
                 start_revs = current_revisions_tup
                 if branch_label:
                     start_revs = self.filter_for_lineage(
-                        self.get_revisions(current_revisions_tup), branch_label
+                        self.get_revisions(current_revisions_tup),  # type: ignore[arg-type] # noqa: E501
+                        branch_label,
                     )
                     if not start_revs:
                         # The requested branch is not a head, so we need to
@@ -1577,8 +1588,8 @@ class Revision:
 
         self.verify_rev_id(revision)
         self.revision = revision
-        self.down_revision = tuple_rev_as_scalar(down_revision)
-        self.dependencies = tuple_rev_as_scalar(dependencies)
+        self.down_revision = tuple_rev_as_scalar(util.to_tuple(down_revision))
+        self.dependencies = tuple_rev_as_scalar(util.to_tuple(dependencies))
         self._orig_branch_labels = util.to_tuple(branch_labels, default=())
         self.branch_labels = set(self._orig_branch_labels)
 
@@ -1676,20 +1687,20 @@ class Revision:
 
 
 @overload
-def tuple_rev_as_scalar(
-    rev: Optional[Sequence[str]],
-) -> Optional[Union[str, Sequence[str]]]:
+def tuple_rev_as_scalar(rev: None) -> None:
     ...
 
 
 @overload
 def tuple_rev_as_scalar(
-    rev: Optional[Sequence[Optional[str]]],
-) -> Optional[Union[Optional[str], Sequence[Optional[str]]]]:
+    rev: Union[Tuple[_T, ...], List[_T]]
+) -> Union[_T, Tuple[_T, ...], List[_T]]:
     ...
 
 
-def tuple_rev_as_scalar(rev):
+def tuple_rev_as_scalar(
+    rev: Optional[Sequence[_T]],
+) -> Union[_T, Sequence[_T], None]:
     if not rev:
         return None
     elif len(rev) == 1:
diff --git a/docs/build/unreleased/improve_typing.rst b/docs/build/unreleased/improve_typing.rst
new file mode 100644 (file)
index 0000000..43647fe
--- /dev/null
@@ -0,0 +1,13 @@
+.. change::
+    :tags: typing
+    :tickets: 930
+
+    Improve typing of the revision parameter in various command functions.
+
+.. change::
+    :tags: typing, bug
+    :tickets: 1266
+
+    Properly type the :paramref:`.Operations.create_check_constraint.condition`
+    parameter of :meth:`.Operations.create_check_constraint` to accept boolean
+    expressions.