]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
Fix downgrade with effective head
authorCaselIT <cfederico87@gmail.com>
Mon, 2 May 2022 20:53:39 +0000 (22:53 +0200)
committerCaselIT <cfederico87@gmail.com>
Tue, 3 May 2022 20:42:40 +0000 (22:42 +0200)
Fixed issue where downgrade using a relative revision would
fail in case of multiple branches with a single effectively
head due to interdependencies between revisions.

Fixes: #1026
Change-Id: I79f5595fb9d03124db8039345055571a9134eecd

alembic/script/base.py
alembic/script/revision.py
alembic/util/compat.py
docs/build/unreleased/1026.rst [new file with mode: 0644]
pyproject.toml
tests/test_revision.py

index ccbf86c97b0c452d99e553bc1e1285791fc5a1eb..cd81f375e6f9d5900cd0d003aff321dcc78d6d47 100644 (file)
@@ -298,20 +298,9 @@ class ScriptDirectory:
 
     def get_all_current(self, id_: Tuple[str, ...]) -> Set[Optional[Script]]:
         with self._catch_revision_errors():
-            top_revs = cast(
-                Set[Optional[Script]],
-                set(self.revision_map.get_revisions(id_)),
-            )
-            top_revs.update(
-                cast(
-                    Iterator[Script],
-                    self.revision_map._get_ancestor_nodes(
-                        list(top_revs), include_dependencies=True
-                    ),
-                )
+            return cast(
+                Set[Optional[Script]], self.revision_map._get_all_current(id_)
             )
-            top_revs = self.revision_map._filter_into_branch_heads(top_revs)
-            return top_revs
 
     def get_revision(self, id_: str) -> Optional[Script]:
         """Return the :class:`.Script` instance with the given rev id.
index 335314f9c6ba2657ba6b2f20b431a99043755255..6e25891d47c0a3580669c1f8691688d9b58cebf5 100644 (file)
@@ -29,8 +29,6 @@ from ..util import not_none
 if TYPE_CHECKING:
     from typing import Literal
 
-    from .base import Script
-
 _RevIdType = Union[str, Sequence[str]]
 _RevisionIdentifierType = Union[str, Tuple[str, ...], None]
 _RevisionOrStr = Union["Revision", str]
@@ -660,8 +658,8 @@ class RevisionMap:
         return revision
 
     def _filter_into_branch_heads(
-        self, targets: Set[Optional[Script]]
-    ) -> Set[Optional[Script]]:
+        self, targets: Iterable[Optional[_RevisionOrBase]]
+    ) -> Set[Optional[_RevisionOrBase]]:
         targets = set(targets)
 
         for rev in list(targets):
@@ -811,7 +809,7 @@ class RevisionMap:
 
     def _get_descendant_nodes(
         self,
-        targets: Collection[Revision],
+        targets: Collection[Optional[_RevisionOrBase]],
         map_: Optional[_RevisionMapType] = None,
         check: bool = False,
         omit_immediate_dependencies: bool = False,
@@ -1129,9 +1127,27 @@ class RevisionMap:
                 if relative_revision:
                     # Find target revision relative to current state.
                     if branch_label:
+                        cr_tuple = util.to_tuple(current_revisions)
+                        symbol_list: Sequence[str]
                         symbol_list = self.filter_for_lineage(
-                            util.to_tuple(current_revisions), branch_label
+                            cr_tuple, branch_label
                         )
+                        if not symbol_list:
+                            # check the case where there are multiple branches
+                            # but there is currently a single heads, since all
+                            # other branch heads are dependant of the current
+                            # single heads.
+                            all_current = cast(
+                                Set[Revision], self._get_all_current(cr_tuple)
+                            )
+                            sl_all_current = self.filter_for_lineage(
+                                all_current, branch_label
+                            )
+                            symbol_list = [
+                                r.revision if r else r  # type: ignore[misc]
+                                for r in sl_all_current
+                            ]
+
                         assert len(symbol_list) == 1
                         symbol = symbol_list[0]
                     else:
@@ -1487,6 +1503,16 @@ class RevisionMap:
 
         return needs, tuple(targets)  # type:ignore[return-value]
 
+    def _get_all_current(
+        self, id_: Tuple[str, ...]
+    ) -> Set[Optional[_RevisionOrBase]]:
+        top_revs: Set[Optional[_RevisionOrBase]]
+        top_revs = set(self.get_revisions(id_))
+        top_revs.update(
+            self._get_ancestor_nodes(list(top_revs), include_dependencies=True)
+        )
+        return self._filter_into_branch_heads(top_revs)
+
 
 class Revision:
     """Base class for revisioned objects.
@@ -1545,8 +1571,8 @@ class Revision:
         self,
         revision: str,
         down_revision: Optional[Union[str, Tuple[str, ...]]],
-        dependencies: Optional[Tuple[str, ...]] = None,
-        branch_labels: Optional[Tuple[str, ...]] = None,
+        dependencies: Optional[Union[str, Tuple[str, ...]]] = None,
+        branch_labels: Optional[Union[str, Tuple[str, ...]]] = None,
     ) -> None:
         if down_revision and revision in util.to_tuple(down_revision):
             raise LoopDetected(revision)
index e6a8f6e0acfcd1d1e932fea605f155b9851b63f0..64148b614de89a26fcd70048ed0a70d2132b47c0 100644 (file)
@@ -3,7 +3,7 @@ from __future__ import annotations
 import io
 import os
 import sys
-from typing import Tuple
+from typing import Sequence
 
 from sqlalchemy.util import inspect_getfullargspec  # noqa
 from sqlalchemy.util.compat import inspect_formatargspec  # noqa
@@ -33,7 +33,7 @@ else:
     from importlib_metadata import EntryPoint  # type:ignore # noqa
 
 
-def importlib_metadata_get(group: str) -> Tuple[EntryPoint, ...]:
+def importlib_metadata_get(group: str) -> Sequence[EntryPoint]:
     ep = importlib_metadata.entry_points()
     if hasattr(ep, "select"):
         return ep.select(group=group)  # type:ignore[attr-defined]
diff --git a/docs/build/unreleased/1026.rst b/docs/build/unreleased/1026.rst
new file mode 100644 (file)
index 0000000..b5df749
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: bug, revisioning
+    :tickets: 1026
+
+    Fixed issue where a downgrade using a relative revision would
+    fail in case of multiple branches with a single effectively
+    head due to interdependencies between revisions.
index 721b0db0157a4bda0524440c4090275f156a3f23..e766501faeb0a30de9300a40ce33ec9cc4e64d5e 100644 (file)
@@ -7,6 +7,7 @@ exclude = [
     'alembic/template',
     'alembic.testing.*',
 ]
+show_error_codes = true
 
 [[tool.mypy.overrides]]
 module = [
index 61998bf8fd684a94cafd5c64933dbead3110ae14..0d5bfd54d77d558ba64dfa79a8bb12f8d4da7350 100644 (file)
@@ -239,10 +239,7 @@ class DownIterateTest(TestBase):
             edges, list(reversed(result))
         )
 
-        eq_(
-            result,
-            assertion,
-        )
+        eq_(result, assertion)
 
 
 class DiamondTest(DownIterateTest):
@@ -573,6 +570,45 @@ class MultipleBranchTest(DownIterateTest):
         )
 
 
+class MultipleBranchEffectiveHead(DownIterateTest):
+    def setUp(self):
+        self.map = RevisionMap(
+            lambda: [
+                Revision("y1", None, branch_labels="y"),
+                Revision("x1", None, branch_labels="x"),
+                Revision("y2", "y1", dependencies="x1"),
+                Revision("x2", "x1"),
+            ]
+        )
+
+    def test_other_downgrade(self):
+        self._assert_iteration(
+            ("x2", "y2"),
+            "x@-1",
+            ["x2"],
+            inclusive=False,
+            select_for_downgrade=True,
+        )
+
+    def test_use_all_current(self):
+        self._assert_iteration(
+            ("x1", "y2"),
+            "x@-1",
+            ["y2", "x1"],
+            inclusive=False,
+            select_for_downgrade=True,
+        )
+
+    def test_effective_head(self):
+        self._assert_iteration(
+            "y2",
+            "x@-1",
+            ["y2", "x1"],
+            inclusive=False,
+            select_for_downgrade=True,
+        )
+
+
 class BranchTravellingTest(DownIterateTest):
     """test the order of revs when going along multiple branches.