git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
Adjust for immediate dependencies that are still ancestors
author Mike Bayer <mike_mp@zzzcomputing.com>
Tue, 2 Feb 2021 05:18:48 +0000 (00:18 -0500)
committer Mike Bayer <mike_mp@zzzcomputing.com>
Tue, 2 Feb 2021 16:17:01 +0000 (11:17 -0500)
Fixed bug in the versioning model where a downgrade across a revision that
depends on another branch, when an ancestor of that revision also depends on
that same branch, would produce an erroneous state in the alembic_version
table, making upgrades impossible without manually repairing the table.

Change-Id: Ic755f8c3f78845bb80f4f5b3d172caf77cb51132
Fixes: #789
alembic/runtime/migration.py
alembic/script/revision.py
docs/build/unreleased/789.rst [new file with mode: 0644]
tests/test_revision.py
tests/test_version_traversal.py

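diff --git a/alembic/runtime/migration.py b/alembic/runtime/migration.py
index bfd73c8c184c57cedab8ee048a143bd6f8dbe151..e8a36e9de2623924b68898e836e9a9f41d25eabd 100644 (file)
--- a/alembic/runtime/migration.py
+++ b/alembic/runtime/migration.py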
@@ -980,7 +980,7 @@ class RevisionStep(MigrationStep):
     @property
     def from_revisions(self):
         if self.is_upgrade:
-            return self.revision._all_down_revisions
+            return self.revision._normalized_down_revisions
         else:
             return (self.revision.revision,)
 
@@ -996,7 +996,7 @@ class RevisionStep(MigrationStep):
         if self.is_upgrade:
             return (self.revision.revision,)
         else:
-            return self.revision._all_down_revisions
+            return self.revision._normalized_down_revisions
 
     @property
     def to_revisions_no_deps(self):
@@ -1007,7 +1007,7 @@ class RevisionStep(MigrationStep):
 
     @property
     def _has_scalar_down_revision(self):
-        return len(self.revision._all_down_revisions) == 1
+        return len(self.revision._normalized_down_revisions) == 1
 
     def should_delete_branch(self, heads):
         """A delete is when we are a. in a downgrade and b.
@@ -1021,7 +1021,7 @@ class RevisionStep(MigrationStep):
         if self.revision.revision not in heads:
             return False
 
-        downrevs = self.revision._all_down_revisions
+        downrevs = self.revision._normalized_down_revisions
 
         if not downrevs:
             # is a base
@@ -1082,7 +1082,7 @@ class RevisionStep(MigrationStep):
         if not self.is_upgrade:
             return False
 
-        downrevs = self.revision._all_down_revisions
+        downrevs = self.revision._normalized_down_revisions
 
         if not downrevs:
             # is a base
@@ -1101,7 +1101,7 @@ class RevisionStep(MigrationStep):
         if not self.is_upgrade:
             return False
 
-        downrevs = self.revision._all_down_revisions
+        downrevs = self.revision._normalized_down_revisions
 
         if len(downrevs) > 1 and len(heads.intersection(downrevs)) > 1:
             return True
@@ -1112,7 +1112,7 @@ class RevisionStep(MigrationStep):
         if not self.is_downgrade:
             return False
 
-        downrevs = self.revision._all_down_revisions
+        downrevs = self.revision._normalized_down_revisions
 
         if self.revision.revision in heads and len(downrevs) > 1:
             return True
@@ -1121,13 +1121,15 @@ class RevisionStep(MigrationStep):
 
     def update_version_num(self, heads):
         if not self._has_scalar_down_revision:
-            downrev = heads.intersection(self.revision._all_down_revisions)
+            downrev = heads.intersection(
+                self.revision._normalized_down_revisions
+            )
             assert (
                 len(downrev) == 1
             ), "Can't do an UPDATE because downrevision is ambiguous"
             down_revision = list(downrev)[0]
         else:
-            down_revision = self.revision._all_down_revisions[0]
+            down_revision = self.revision._normalized_down_revisions[0]
 
         if self.is_upgrade:
             return down_revision, self.revision.revision
@@ -1147,7 +1149,7 @@ class RevisionStep(MigrationStep):
         return MigrationInfo(
             revision_map=self.revision_map,
             up_revisions=self.revision.revision,
-            down_revisions=self.revision._all_down_revisions,
+            down_revisions=self.revision._normalized_down_revisions,
             is_upgrade=self.is_upgrade,
             is_stamp=False,
         )
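diff --git a/alembic/script/revision.py b/alembic/script/revision.py
index 75249b4a6530ef931f2925e69e28b71aa209834d..d131912b6113372e95b53f2014b4bef4ea720529 100644 (file)
--- a/alembic/script/revision.py
+++ b/alembic/script/revision.py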
@@ -151,8 +151,10 @@ class RevisionMap(object):
         _real_bases = ()
 
         has_branch_labels = set()
-        has_depends_on = set()
+        all_revisions = set()
+
         for revision in self._generator():
+            all_revisions.add(revision)
 
             if revision.revision in map_:
                 util.warn(
@@ -161,8 +163,7 @@ class RevisionMap(object):
             map_[revision.revision] = revision
             if revision.branch_labels:
                 has_branch_labels.add(revision)
-            if revision.dependencies:
-                has_depends_on.add(revision)
+
             heads.add(revision)
             _real_heads.add(revision)
             if revision.is_base:
@@ -173,11 +174,11 @@ class RevisionMap(object):
         # add the branch_labels to the map_.  We'll need these
         # to resolve the dependencies.
         rev_map = map_.copy()
-        for revision in has_branch_labels:
-            self._map_branch_labels(revision, map_)
+        self._map_branch_labels(has_branch_labels, map_)
 
-        for revision in has_depends_on:
-            self._add_depends_on(revision, map_)
+        # resolve dependency names from branch labels and symbolic
+        # names
+        self._add_depends_on(all_revisions, map_)
 
         for rev in map_.values():
             for downrev in rev._all_down_revisions:
@@ -192,40 +193,11 @@ class RevisionMap(object):
                     heads.discard(down_revision)
                 _real_heads.discard(down_revision)
 
-        if rev_map:
-            if not heads or not bases:
-                raise CycleDetected(rev_map.keys())
-            total_space = {
-                rev.revision
-                for rev in self._iterate_related_revisions(
-                    lambda r: r._versioned_down_revisions, heads, map_=rev_map
-                )
-            }.intersection(
-                rev.revision
-                for rev in self._iterate_related_revisions(
-                    lambda r: r.nextrev, bases, map_=rev_map
-                )
-            )
-            deleted_revs = set(rev_map.keys()) - total_space
-            if deleted_revs:
-                raise CycleDetected(sorted(deleted_revs))
-
-            if not _real_heads or not _real_bases:
-                raise DependencyCycleDetected(rev_map.keys())
-            total_space = {
-                rev.revision
-                for rev in self._iterate_related_revisions(
-                    lambda r: r._all_down_revisions, _real_heads, map_=rev_map
-                )
-            }.intersection(
-                rev.revision
-                for rev in self._iterate_related_revisions(
-                    lambda r: r._all_nextrev, _real_bases, map_=rev_map
-                )
-            )
-            deleted_revs = set(rev_map.keys()) - total_space
-            if deleted_revs:
-                raise DependencyCycleDetected(sorted(deleted_revs))
+        # once the map has downrevisions populated, the dependencies
+        # can be further refined to include only those which are not
+        # already ancestors
+        self._normalize_depends_on(all_revisions, map_)
+        self._detect_cycles(rev_map, heads, bases, _real_heads, _real_bases)
 
         map_[None] = map_[()] = None
         self.heads = tuple(rev.revision for rev in heads)
@@ -233,53 +205,140 @@ class RevisionMap(object):
         self.bases = tuple(rev.revision for rev in bases)
         self._real_bases = tuple(rev.revision for rev in _real_bases)
 
-        for revision in has_branch_labels:
-            self._add_branches(revision, map_, map_branch_labels=False)
+        self._add_branches(has_branch_labels, map_)
         return map_
 
-    def _map_branch_labels(self, revision, map_):
-        if revision.branch_labels:
-            for branch_label in revision._orig_branch_labels:
-                if branch_label in map_:
-                    raise RevisionError(
-                        "Branch name '%s' in revision %s already "
-                        "used by revision %s"
-                        % (
-                            branch_label,
-                            revision.revision,
-                            map_[branch_label].revision,
+    def _detect_cycles(self, rev_map, heads, bases, _real_heads, _real_bases):
+        if not rev_map:
+            return
+        if not heads or not bases:
+            raise CycleDetected(rev_map.keys())
+        total_space = {
+            rev.revision
+            for rev in self._iterate_related_revisions(
+                lambda r: r._versioned_down_revisions, heads, map_=rev_map
+            )
+        }.intersection(
+            rev.revision
+            for rev in self._iterate_related_revisions(
+                lambda r: r.nextrev, bases, map_=rev_map
+            )
+        )
+        deleted_revs = set(rev_map.keys()) - total_space
+        if deleted_revs:
+            raise CycleDetected(sorted(deleted_revs))
+
+        if not _real_heads or not _real_bases:
+            raise DependencyCycleDetected(rev_map.keys())
+        total_space = {
+            rev.revision
+            for rev in self._iterate_related_revisions(
+                lambda r: r._all_down_revisions, _real_heads, map_=rev_map
+            )
+        }.intersection(
+            rev.revision
+            for rev in self._iterate_related_revisions(
+                lambda r: r._all_nextrev, _real_bases, map_=rev_map
+            )
+        )
+        deleted_revs = set(rev_map.keys()) - total_space
+        if deleted_revs:
+            raise DependencyCycleDetected(sorted(deleted_revs))
+
+    def _map_branch_labels(self, revisions, map_):
+        for revision in revisions:
+            if revision.branch_labels:
+                for branch_label in revision._orig_branch_labels:
+                    if branch_label in map_:
+                        raise RevisionError(
+                            "Branch name '%s' in revision %s already "
+                            "used by revision %s"
+                            % (
+                                branch_label,
+                                revision.revision,
+                                map_[branch_label].revision,
+                            )
                         )
-                    )
-                map_[branch_label] = revision
+                    map_[branch_label] = revision
+
+    def _add_branches(self, revisions, map_):
+        for revision in revisions:
+            if revision.branch_labels:
+                revision.branch_labels.update(revision.branch_labels)
+                for node in self._get_descendant_nodes(
+                    [revision], map_, include_dependencies=False
+                ):
+                    node.branch_labels.update(revision.branch_labels)
+
+                parent = node
+                while (
+                    parent
+                    and not parent._is_real_branch_point
+                    and not parent.is_merge_point
+                ):
+
+                    parent.branch_labels.update(revision.branch_labels)
+                    if parent.down_revision:
+                        parent = map_[parent.down_revision]
+                    else:
+                        break
 
-    def _add_branches(self, revision, map_, map_branch_labels=True):
-        if map_branch_labels:
-            self._map_branch_labels(revision, map_)
+    def _add_depends_on(self, revisions, map_):
+        """Resolve the 'dependencies' for each revision in a collection
+        in terms of actual revision ids, as opposed to branch labels or other
+        symbolic names.
 
-        if revision.branch_labels:
-            revision.branch_labels.update(revision.branch_labels)
-            for node in self._get_descendant_nodes(
-                [revision], map_, include_dependencies=False
-            ):
-                node.branch_labels.update(revision.branch_labels)
+        The collection is then assigned to the _resolved_dependencies
+        attribute on each revision object.
 
-            parent = node
-            while (
-                parent
-                and not parent._is_real_branch_point
-                and not parent.is_merge_point
-            ):
+        """
 
-                parent.branch_labels.update(revision.branch_labels)
-                if parent.down_revision:
-                    parent = map_[parent.down_revision]
-                else:
-                    break
+        for revision in revisions:
+            if revision.dependencies:
+                deps = [
+                    map_[dep] for dep in util.to_tuple(revision.dependencies)
+                ]
+                revision._resolved_dependencies = tuple(
+                    [d.revision for d in deps]
+                )
+            else:
+                revision._resolved_dependencies = ()
+
+    def _normalize_depends_on(self, revisions, map_):
+        """Create a collection of "dependencies" that omits dependencies
+        that are already ancestor nodes for each revision in a given
+        collection.
+
+        This builds upon the _resolved_dependencies collection created in the
+        _add_depends_on() method, looking in the fully populated revision map
+        for ancestors, and omitting them as the _resolved_dependencies
+        collection as it is copied to a new collection. The new collection is
+        then assigned to the _normalized_resolved_dependencies attribute on
+        each revision object.
+
+        The collection is then used to determine the immediate "down revision"
+        identifiers for this revision.
+
+        """
+
+        for revision in revisions:
+            if revision._resolved_dependencies:
+                normalized_resolved = set(revision._resolved_dependencies)
+                for rev in self._get_ancestor_nodes(
+                    [revision], include_dependencies=False, map_=map_
+                ):
+                    if rev is revision:
+                        continue
+                    elif rev._resolved_dependencies:
+                        normalized_resolved.difference_update(
+                            rev._resolved_dependencies
+                        )
 
-    def _add_depends_on(self, revision, map_):
-        if revision.dependencies:
-            deps = [map_[dep] for dep in util.to_tuple(revision.dependencies)]
-            revision._resolved_dependencies = tuple([d.revision for d in deps])
+                revision._normalized_resolved_dependencies = tuple(
+                    normalized_resolved
+                )
+            else:
+                revision._normalized_resolved_dependencies = ()
 
     def add_revision(self, revision, _replace=False):
         """add a single revision to an existing map.
@@ -297,13 +356,17 @@ class RevisionMap(object):
             raise Exception("revision %s not in map" % revision.revision)
 
         map_[revision.revision] = revision
-        self._add_branches(revision, map_)
-        self._add_depends_on(revision, map_)
+
+        revisions = [revision]
+        self._add_branches(revisions, map_)
+        self._map_branch_labels(revisions, map_)
+        self._add_depends_on(revisions, map_)
 
         if revision.is_base:
             self.bases += (revision.revision,)
         if revision._is_real_base:
             self._real_bases += (revision.revision,)
+
         for downrev in revision._all_down_revisions:
             if downrev not in map_:
                 util.warn(
@@ -311,6 +374,9 @@ class RevisionMap(object):
                     % (downrev, revision)
                 )
             map_[downrev].add_nextrev(revision)
+
+        self._normalize_depends_on(revisions, map_)
+
         if revision._is_real_head:
             self._real_heads = tuple(
                 head
@@ -773,7 +839,7 @@ class RevisionMap(object):
         if include_dependencies:
 
             def fn(rev):
-                return rev._all_down_revisions
+                return rev._normalized_down_revisions
 
         else:
 
@@ -866,7 +932,15 @@ class RevisionMap(object):
                 lower_ancestors
             ).difference(lower_descendants)
             for rev in candidate_lowers:
-                for downrev in rev._all_down_revisions:
+                # note: the use of _normalized_down_revisions as opposed
+                # to _all_down_revisions repairs
+                # an issue related to looking at a revision in isolation
+                # when updating the alembic_version table (issue #789).
+                # however, while it seems likely that using
+                # _normalized_down_revisions within traversal is more correct
+                # than _all_down_revisions, we don't yet have any case to
+                # show that it actually makes a difference.
+                for downrev in rev._normalized_down_revisions:
                     if self._revision_map[downrev] in candidate_lowers:
                         break
                 else:
@@ -971,7 +1045,7 @@ class RevisionMap(object):
                 todo.extendleft(
                     [
                         self._revision_map[downrev]
-                        for downrev in reversed(rev._all_down_revisions)
+                        for downrev in reversed(rev._normalized_down_revisions)
                         if self._revision_map[downrev] not in branch_todo
                         and downrev in total_space
                     ]
@@ -1048,7 +1122,6 @@ class Revision(object):
         self.revision = revision
         self.down_revision = tuple_rev_as_scalar(down_revision)
         self.dependencies = tuple_rev_as_scalar(dependencies)
-        self._resolved_dependencies = ()
         self._orig_branch_labels = util.to_tuple(branch_labels, default=())
         self.branch_labels = set(self._orig_branch_labels)
 
@@ -1072,6 +1145,17 @@ class Revision(object):
             + self._resolved_dependencies
         )
 
+    @property
+    def _normalized_down_revisions(self):
+        """return immediate down revisions for a rev, omitting dependencies
+        that are still dependencies of ancestors.
+
+        """
+        return (
+            util.to_tuple(self.down_revision, default=())
+            + self._normalized_resolved_dependencies
+        )
+
     @property
     def _versioned_down_revisions(self):
         return util.to_tuple(self.down_revision, default=())
diff --git a/docs/build/unreleased/789.rst b/docs/build/unreleased/789.rst
new file mode 100644 (file)
index 0000000..8fb5e10
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, versioning
+    :tickets: 789
+
+    Fixed bug in versioning model where a downgrade across a revision with a
+    dependency on another branch, yet an ancestor is also dependent on that
+    branch, would produce an erroneous state in the alembic_version table,
+    making upgrades impossible without manually repairing the table.
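diff --git a/tests/test_revision.py b/tests/test_revision.py
index 3a6e9562a45c3517b5def1aacac4cc3ed331499c..749ded3a36699204c51367b1b98211f2f1a9bcc9 100644 (file)
--- a/tests/test_revision.py
+++ b/tests/test_revision.py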
@@ -1244,6 +1244,9 @@ class DepResolutionFailedTest(DownIterateTest):
         self.map._revision_map["fake"] = self.map._revision_map["a2"]
         self.map._revision_map["b1"].dependencies = "fake"
         self.map._revision_map["b1"]._resolved_dependencies = ("fake",)
+        self.map._revision_map["b1"]._normalized_resolved_dependencies = (
+            "fake",
+        )
 
     def test_failure_message(self):
         iter_ = self.map.iterate_revisions("c1", "base1")
@@ -1502,3 +1505,90 @@ class GraphWithCycleTest(InvalidRevisionMapTest):
         self._assert_raises_revision_map_dep_cycle(
             map_, ["a", "b", "c", "d", "e"]
         )
+
+
+class NormalizedDownRevTest(DownIterateTest):
+    def setUp(self):
+        self.map = RevisionMap(
+            lambda: [
+                Revision("a1", ()),
+                Revision("a2", "a1"),
+                Revision("a3", "a2"),
+                Revision("b1", ()),
+                Revision("b2", "b1", dependencies="a3"),
+                Revision("b3", "b2"),
+                Revision("b4", "b3", dependencies="a3"),
+                Revision("b5", "b4"),
+            ]
+        )
+
+    def test_normalized_down_revisions(self):
+        b4 = self.map.get_revision("b4")
+
+        eq_(b4._all_down_revisions, ("b3", "a3"))
+
+        # "a3" is not included because ancestor b2 is also dependent
+        eq_(b4._normalized_down_revisions, ("b3",))
+
+    def test_branch_traversal(self):
+        self._assert_iteration(
+            "b4",
+            "b1@base",
+            ["b4", "b3", "b2", "b1"],
+            select_for_downgrade=True,
+        )
+
+    def test_all_traversal(self):
+        self._assert_iteration(
+            "heads",
+            "base",
+            ["b5", "b4", "b3", "b2", "b1", "a3", "a2", "a1"],
+            select_for_downgrade=True,
+        )
+
+    def test_partial_traversal(self):
+        self._assert_iteration(
+            "heads",
+            "a2",
+            ["b5", "b4", "b3", "b2", "a3", "a2"],
+            select_for_downgrade=True,
+        )
+
+    def test_partial_traversal_implicit_base_one(self):
+        self._assert_iteration(
+            "heads",
+            "a2",
+            ["b5", "b4", "b3", "b2", "b1", "a3", "a2"],
+            select_for_downgrade=True,
+            implicit_base=True,
+        )
+
+    def test_partial_traversal_implicit_base_two(self):
+        self._assert_iteration(
+            "b5",
+            ("b1",),
+            ["b5", "b4", "b3", "b2", "b1", "a3", "a2", "a1"],
+            implicit_base=True,
+        )
+
+    def test_partial_traversal_implicit_base_three(self):
+        map_ = RevisionMap(
+            lambda: [
+                Revision("c1", ()),
+                Revision("a1", ()),
+                Revision("a2", "a1", dependencies="c1"),
+                Revision("a3", "a2", dependencies="c1"),
+                Revision("b1", ()),
+                Revision("b2", "b1", dependencies="a3"),
+                Revision("b3", "b2"),
+                Revision("b4", "b3", dependencies="a3"),
+                Revision("b5", "b4"),
+            ]
+        )
+        self._assert_iteration(
+            "b5",
+            ("b1",),
+            ["b5", "b4", "b3", "b2", "b1", "a3", "a2", "a1", "c1"],
+            implicit_base=True,
+            map_=map_,
+        )
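diff --git a/tests/test_version_traversal.py b/tests/test_version_traversal.py
index dc5683f9ad8a419ba12ecada6232da86deef9116..915c9fd2a24626efc1caab8fdcda0c43b6e29c82 100644 (file)
--- a/tests/test_version_traversal.py
+++ b/tests/test_version_traversal.py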
@@ -771,6 +771,7 @@ class DependsOnBranchTestTwo(MigrationTest):
                     self.amerge.revision,
                     self.b1.revision,
                     self.cmerge.revision,
+                    # b2 isn't here, but d1 is, which implies b2. OK!
                     self.d1.revision,
                 ]
             ),
@@ -789,11 +790,11 @@ class DependsOnBranchTestTwo(MigrationTest):
             "d1@base",
             heads,
             [self.down_(self.d1)],
-            # b2 has to be INSERTed, because it was implied by d1
             set(
                 [
                     self.amerge.revision,
                     self.b1.revision,
+                    # b2 has to be INSERTed, because it was implied by d1
                     self.b2.revision,
                     self.cmerge.revision,
                 ]
@@ -878,6 +879,43 @@ class DependsOnBranchTestThree(MigrationTest):
         )
 
 
+class DependsOnBranchTestFour(MigrationTest):
+    @classmethod
+    def setup_class(cls):
+        """
+        test issue #789
+        """
+        cls.env = env = staging_env()
+        cls.a1 = env.generate_revision("a1", "->a1", head="base")
+        cls.a2 = env.generate_revision("a2", "->a2")
+        cls.a3 = env.generate_revision("a3", "->a3")
+
+        cls.b1 = env.generate_revision("b1", "->b1", head="base")
+        cls.b2 = env.generate_revision(
+            "b2", "->b2", head="b1", depends_on="a3"
+        )
+        cls.b3 = env.generate_revision("b3", "->b3", head="b2")
+        cls.b4 = env.generate_revision(
+            "b4", "->b4", head="b3", depends_on="a3"
+        )
+
+    @classmethod
+    def teardown_class(cls):
+        clear_staging_env()
+
+    def test_dependencies_are_normalized(self):
+
+        heads = [self.b4.revision]
+
+        self._assert_downgrade(
+            self.b3.revision,
+            heads,
+            [self.down_(self.b4)],
+            # a3 isn't here, because b3 still implies a3
+            set([self.b3.revision]),
+        )
+
+
 class DependsOnBranchLabelTest(MigrationTest):
     @classmethod
     def setup_class(cls):