]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
Detect loop(s)/cycle(s) in a revision graph 758/head
authorKoichiro Den <den@valinux.co.jp>
Sun, 22 Nov 2020 11:15:12 +0000 (20:15 +0900)
committerKoichiro Den <den@valinux.co.jp>
Tue, 1 Dec 2020 01:23:31 +0000 (10:23 +0900)
Fixes: #757
alembic/script/revision.py
tests/test_revision.py

index 22481a0872e1652b9e1f85dda627e0a6c1bf9dc1..683d3227e8ee439f6989112a479955ea41dcd565 100644 (file)
@@ -40,6 +40,38 @@ class ResolutionError(RevisionError):
         self.argument = argument
 
 
+class CycleDetected(RevisionError):
+    kind = "Cycle"
+
+    def __init__(self, revisions):
+        self.revisions = revisions
+        super(CycleDetected, self).__init__(
+            "%s is detected in revisions (%s)"
+            % (self.kind, ", ".join(revisions))
+        )
+
+
+class DependencyCycleDetected(CycleDetected):
+    kind = "Dependency cycle"
+
+    def __init__(self, revisions):
+        super(DependencyCycleDetected, self).__init__(revisions)
+
+
+class LoopDetected(CycleDetected):
+    kind = "Self-loop"
+
+    def __init__(self, revision):
+        super(LoopDetected, self).__init__([revision])
+
+
+class DependencyLoopDetected(DependencyCycleDetected, LoopDetected):
+    kind = "Dependency self-loop"
+
+    def __init__(self, revision):
+        super(DependencyLoopDetected, self).__init__(revision)
+
+
 class RevisionMap(object):
     """Maintains a map of :class:`.Revision` objects.
 
@@ -115,8 +147,8 @@ class RevisionMap(object):
 
         heads = sqlautil.OrderedSet()
         _real_heads = sqlautil.OrderedSet()
-        self.bases = ()
-        self._real_bases = ()
+        bases = ()
+        _real_bases = ()
 
         has_branch_labels = set()
         has_depends_on = set()
@@ -131,15 +163,16 @@ class RevisionMap(object):
                 has_branch_labels.add(revision)
             if revision.dependencies:
                 has_depends_on.add(revision)
-            heads.add(revision.revision)
-            _real_heads.add(revision.revision)
+            heads.add(revision)
+            _real_heads.add(revision)
             if revision.is_base:
-                self.bases += (revision.revision,)
+                bases += (revision,)
             if revision._is_real_base:
-                self._real_bases += (revision.revision,)
+                _real_bases += (revision,)
 
         # add the branch_labels to the map_.  We'll need these
         # to resolve the dependencies.
+        rev_map = map_.copy()
         for revision in has_branch_labels:
             self._map_branch_labels(revision, map_)
 
@@ -156,12 +189,49 @@ class RevisionMap(object):
                 down_revision = map_[downrev]
                 down_revision.add_nextrev(rev)
                 if downrev in rev._versioned_down_revisions:
-                    heads.discard(downrev)
-                _real_heads.discard(downrev)
+                    heads.discard(down_revision)
+                _real_heads.discard(down_revision)
+
+        if rev_map:
+            if not heads or not bases:
+                raise CycleDetected(rev_map.keys())
+            total_space = {
+                rev.revision
+                for rev in self._iterate_related_revisions(
+                    lambda r: r._versioned_down_revisions, heads, map_=rev_map
+                )
+            }.intersection(
+                rev.revision
+                for rev in self._iterate_related_revisions(
+                    lambda r: r.nextrev, bases, map_=rev_map
+                )
+            )
+            deleted_revs = set(rev_map.keys()) - total_space
+            if deleted_revs:
+                raise CycleDetected(sorted(deleted_revs))
+
+            if not _real_heads or not _real_bases:
+                raise DependencyCycleDetected(rev_map.keys())
+            total_space = {
+                rev.revision
+                for rev in self._iterate_related_revisions(
+                    lambda r: r._all_down_revisions, _real_heads, map_=rev_map
+                )
+            }.intersection(
+                rev.revision
+                for rev in self._iterate_related_revisions(
+                    lambda r: r._all_nextrev, _real_bases, map_=rev_map
+                )
+            )
+            deleted_revs = set(rev_map.keys()) - total_space
+            if deleted_revs:
+                raise DependencyCycleDetected(sorted(deleted_revs))
 
         map_[None] = map_[()] = None
-        self.heads = tuple(heads)
-        self._real_heads = tuple(_real_heads)
+        self.heads = tuple(rev.revision for rev in heads)
+        self._real_heads = tuple(rev.revision for rev in _real_heads)
+        self.bases = tuple(rev.revision for rev in bases)
+        self._real_bases = tuple(rev.revision for rev in _real_bases)
 
         for revision in has_branch_labels:
             self._add_branches(revision, map_, map_branch_labels=False)
@@ -964,6 +1034,11 @@ class Revision(object):
     def __init__(
         self, revision, down_revision, dependencies=None, branch_labels=None
     ):
+        if down_revision and revision in down_revision:
+            raise LoopDetected(revision)
+        elif dependencies is not None and revision in dependencies:
+            raise DependencyLoopDetected(revision)
+
         self.verify_rev_id(revision)
         self.revision = revision
         self.down_revision = tuple_rev_as_scalar(down_revision)
index bf433f51d50ba3eaf173cb4af42c9fb53b98c57e..767baa6e60d6788d412a92c911b947dd941f5f59 100644 (file)
@@ -1,3 +1,7 @@
+from alembic.script.revision import CycleDetected
+from alembic.script.revision import DependencyCycleDetected
+from alembic.script.revision import DependencyLoopDetected
+from alembic.script.revision import LoopDetected
 from alembic.script.revision import MultipleHeads
 from alembic.script.revision import Revision
 from alembic.script.revision import RevisionError
@@ -125,6 +129,39 @@ class APITest(TestBase):
             "heads",
         )
 
+    def test_get_revisions_head_multiple(self):
+        map_ = RevisionMap(
+            lambda: [
+                Revision("a", ()),
+                Revision("b", ("a",)),
+                Revision("c1", ("b",)),
+                Revision("c2", ("b",)),
+            ]
+        )
+        assert_raises_message(
+            MultipleHeads,
+            "Multiple heads are present",
+            map_.get_revisions,
+            "head",
+        )
+
+    def test_get_revisions_heads_multiple(self):
+        map_ = RevisionMap(
+            lambda: [
+                Revision("a", ()),
+                Revision("b", ("a",)),
+                Revision("c1", ("b",)),
+                Revision("c2", ("b",)),
+            ]
+        )
+        eq_(
+            map_.get_revisions("heads"),
+            (
+                map_._revision_map["c1"],
+                map_._revision_map["c2"],
+            ),
+        )
+
     def test_get_revision_base_multiple(self):
         map_ = RevisionMap(
             lambda: [
@@ -1213,3 +1250,239 @@ class DepResolutionFailedTest(DownIterateTest):
         assert_raises_message(
             RevisionError, "Dependency resolution failed;", list, iter_
         )
+
+
+class InvalidRevisionMapTest(TestBase):
+    def _assert_raises_revision_map(self, map_, except_cls, msg):
+        assert_raises_message(except_cls, msg, lambda: map_._revision_map)
+
+    def _assert_raises_revision_map_loop(self, map_, revision):
+        self._assert_raises_revision_map(
+            map_,
+            LoopDetected,
+            r"^Self-loop is detected in revisions \(%s\)$" % revision,
+        )
+
+    def _assert_raises_revision_map_dep_loop(self, map_, revision):
+        self._assert_raises_revision_map(
+            map_,
+            DependencyLoopDetected,
+            r"^Dependency self-loop is detected in revisions \(%s\)$"
+            % revision,
+        )
+
+    def _assert_raises_revision_map_cycle(self, map_, revisions):
+        self._assert_raises_revision_map(
+            map_,
+            CycleDetected,
+            r"^Cycle is detected in revisions \(\(%s\)\(, \)?\)+$"
+            % "|".join(revisions),
+        )
+
+    def _assert_raises_revision_map_dep_cycle(self, map_, revisions):
+        self._assert_raises_revision_map(
+            map_,
+            DependencyCycleDetected,
+            r"^Dependency cycle is detected in revisions \(\(%s\)\(, \)?\)+$"
+            % "|".join(revisions),
+        )
+
+
+class GraphWithLoopTest(InvalidRevisionMapTest):
+    def test_revision_map_solitary_loop(self):
+        map_ = RevisionMap(
+            lambda: [
+                Revision("a", "a"),
+            ]
+        )
+        self._assert_raises_revision_map_loop(map_, "a")
+
+    def test_revision_map_base_loop(self):
+        map_ = RevisionMap(
+            lambda: [
+                Revision("a", "a"),
+                Revision("b", "a"),
+                Revision("c", "b"),
+            ]
+        )
+        self._assert_raises_revision_map_loop(map_, "a")
+
+    def test_revision_map_head_loop(self):
+        map_ = RevisionMap(
+            lambda: [
+                Revision("a", ()),
+                Revision("b", "a"),
+                Revision("c", ("b", "c")),
+            ]
+        )
+        self._assert_raises_revision_map_loop(map_, "c")
+
+    def test_revision_map_branch_point_loop(self):
+        map_ = RevisionMap(
+            lambda: [
+                Revision("a", ()),
+                Revision("b", ("a", "b")),
+                Revision("c1", "b"),
+                Revision("c2", "b"),
+            ]
+        )
+        self._assert_raises_revision_map_loop(map_, "b")
+
+    def test_revision_map_merge_point_loop(self):
+        map_ = RevisionMap(
+            lambda: [
+                Revision("a", ()),
+                Revision("b1", "a"),
+                Revision("b2", "a"),
+                Revision("c", ("b1", "b2", "c")),
+            ]
+        )
+        self._assert_raises_revision_map_loop(map_, "c")
+
+    def test_revision_map_solitary_dependency_loop(self):
+        map_ = RevisionMap(
+            lambda: [
+                Revision("a", (), dependencies="a"),
+            ]
+        )
+        self._assert_raises_revision_map_dep_loop(map_, "a")
+
+    def test_revision_map_base_dependency_loop(self):
+        map_ = RevisionMap(
+            lambda: [
+                Revision("a", (), dependencies="a"),
+                Revision("b", "a"),
+                Revision("c", "b"),
+            ]
+        )
+        self._assert_raises_revision_map_dep_loop(map_, "a")
+
+    def test_revision_map_head_dep_loop(self):
+        map_ = RevisionMap(
+            lambda: [
+                Revision("a", ()),
+                Revision("b", "a"),
+                Revision("c", "b", dependencies="c"),
+            ]
+        )
+        self._assert_raises_revision_map_dep_loop(map_, "c")
+
+    def test_revision_map_branch_point_dep_loop(self):
+        map_ = RevisionMap(
+            lambda: [
+                Revision("a", ()),
+                Revision("b", "a", dependencies="b"),
+                Revision("c1", "b"),
+                Revision("c2", "b"),
+            ]
+        )
+        self._assert_raises_revision_map_dep_loop(map_, "b")
+
+    def test_revision_map_merge_point_dep_loop(self):
+        map_ = RevisionMap(
+            lambda: [
+                Revision("a", ()),
+                Revision("b1", "a"),
+                Revision("b2", "a"),
+                Revision("c", ("b1", "b2"), dependencies="c"),
+            ]
+        )
+        self._assert_raises_revision_map_dep_loop(map_, "c")
+
+
+class GraphWithCycleTest(InvalidRevisionMapTest):
+    def test_revision_map_simple_cycle(self):
+        map_ = RevisionMap(
+            lambda: [
+                Revision("a", "c"),
+                Revision("b", "a"),
+                Revision("c", "b"),
+            ]
+        )
+        self._assert_raises_revision_map_cycle(map_, ["a", "b", "c"])
+
+    def test_revision_map_extra_simple_cycle(self):
+        map_ = RevisionMap(
+            lambda: [
+                Revision("a", "c"),
+                Revision("b", "a"),
+                Revision("c", "b"),
+                Revision("d", ()),
+                Revision("e", "d"),
+            ]
+        )
+        self._assert_raises_revision_map_cycle(map_, ["a", "b", "c"])
+
+    def test_revision_map_lower_simple_cycle(self):
+        map_ = RevisionMap(
+            lambda: [
+                Revision("a", "c"),
+                Revision("b", "a"),
+                Revision("c", "b"),
+                Revision("d", "c"),
+                Revision("e", "d"),
+            ]
+        )
+        self._assert_raises_revision_map_cycle(map_, ["a", "b", "c", "d", "e"])
+
+    def test_revision_map_upper_simple_cycle(self):
+        map_ = RevisionMap(
+            lambda: [
+                Revision("a", ()),
+                Revision("b", "a"),
+                Revision("c", ("b", "e")),
+                Revision("d", "c"),
+                Revision("e", "d"),
+            ]
+        )
+        self._assert_raises_revision_map_cycle(map_, ["a", "b", "c", "d", "e"])
+
+    def test_revision_map_simple_dep_cycle(self):
+        map_ = RevisionMap(
+            lambda: [
+                Revision("a", (), dependencies="c"),
+                Revision("b", "a"),
+                Revision("c", "b"),
+            ]
+        )
+        self._assert_raises_revision_map_dep_cycle(map_, ["a", "b", "c"])
+
+    def test_revision_map_extra_simple_dep_cycle(self):
+        map_ = RevisionMap(
+            lambda: [
+                Revision("a", (), dependencies="c"),
+                Revision("b", "a"),
+                Revision("c", "b"),
+                Revision("d", ()),
+                Revision("e", "d"),
+            ]
+        )
+        self._assert_raises_revision_map_dep_cycle(map_, ["a", "b", "c"])
+
+    def test_revision_map_lower_simple_dep_cycle(self):
+        map_ = RevisionMap(
+            lambda: [
+                Revision("a", (), dependencies="c"),
+                Revision("b", "a"),
+                Revision("c", "b"),
+                Revision("d", "c"),
+                Revision("e", "d"),
+            ]
+        )
+        self._assert_raises_revision_map_dep_cycle(
+            map_, ["a", "b", "c", "d", "e"]
+        )
+
+    def test_revision_map_upper_simple_dep_cycle(self):
+        map_ = RevisionMap(
+            lambda: [
+                Revision("a", ()),
+                Revision("b", "a"),
+                Revision("c", "b", dependencies="e"),
+                Revision("d", "c"),
+                Revision("e", "d"),
+            ]
+        )
+        self._assert_raises_revision_map_dep_cycle(
+            map_, ["a", "b", "c", "d", "e"]
+        )