]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
- break out the concept of "down revision" into two pieces:
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 22 Nov 2014 16:49:20 +0000 (11:49 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 22 Nov 2014 16:49:20 +0000 (11:49 -0500)
down_revision and "dependencies".  For migration traversal, the downrevs
we care about are the union of these two sets.  however for location of nodes
and branch labeling, we look only at down_revsion.  this works really well
and allows us to have mutually-dependent trees that can easily be itererated
independently of each other.  docs are needed

alembic/migration.py
alembic/revision.py
alembic/script.py
tests/test_revision.py

index 7d854bab37566808a9fda360cbcd1dfd3d6dfd7e..eefb1d1ea513aafd714a347c32db879b4e65e6d7 100644 (file)
@@ -553,7 +553,7 @@ class RevisionStep(MigrationStep):
     @property
     def from_revisions(self):
         if self.is_upgrade:
-            return self.revision._down_revision_tuple
+            return self.revision._all_down_revisions
         else:
             return (self.revision.revision, )
 
@@ -562,11 +562,11 @@ class RevisionStep(MigrationStep):
         if self.is_upgrade:
             return (self.revision.revision, )
         else:
-            return self.revision._down_revision_tuple
+            return self.revision._all_down_revisions
 
     @property
     def _has_scalar_down_revision(self):
-        return len(self.revision._down_revision_tuple) == 1
+        return len(self.revision._all_down_revisions) == 1
 
     def should_delete_branch(self, heads):
         if not self.is_downgrade:
@@ -575,7 +575,7 @@ class RevisionStep(MigrationStep):
         if self.revision.revision not in heads:
             return False
 
-        downrevs = self.revision._down_revision_tuple
+        downrevs = self.revision._all_down_revisions
         if not downrevs:
             # is a base
             return True
@@ -587,7 +587,7 @@ class RevisionStep(MigrationStep):
 
             descendants = set(
                 r.revision for r in self.revision_map._get_descendant_nodes(
-                    self.revision_map.get_revisions(downrev.nextrev),
+                    self.revision_map.get_revisions(downrev._all_nextrev),
                     check=False
                 )
             )
@@ -606,7 +606,7 @@ class RevisionStep(MigrationStep):
 
             # TODO: this doesn't work; make sure tests are here to ensure
             # this fails
-            #if len(downrev.nextrev.intersection(heads).difference(
+            #if len(downrev._all_nextrev.intersection(heads).difference(
             #        [self.revision.revision])):
 
                 return True
@@ -662,7 +662,7 @@ class RevisionStep(MigrationStep):
         if not self.is_upgrade:
             return False
 
-        downrevs = self.revision._down_revision_tuple
+        downrevs = self.revision._all_down_revisions
 
         if not downrevs:
             # is a base
@@ -680,7 +680,7 @@ class RevisionStep(MigrationStep):
         if not self.is_upgrade:
             return False
 
-        downrevs = self.revision._down_revision_tuple
+        downrevs = self.revision._all_down_revisions
 
         if len(downrevs) > 1 and \
                 len(heads.intersection(downrevs)) > 1:
@@ -692,7 +692,7 @@ class RevisionStep(MigrationStep):
         if not self.is_downgrade:
             return False
 
-        downrevs = self.revision._down_revision_tuple
+        downrevs = self.revision._all_down_revisions
 
         if self.revision.revision in heads and len(downrevs) > 1:
             return True
@@ -701,12 +701,12 @@ class RevisionStep(MigrationStep):
 
     def update_version_num(self, heads):
         if not self._has_scalar_down_revision:
-            downrev = heads.intersection(self.revision._down_revision_tuple)
+            downrev = heads.intersection(self.revision._all_down_revisions)
             assert len(downrev) == 1, \
                 "Can't do an UPDATE because downrevision is ambiguous"
             down_revision = list(downrev)[0]
         else:
-            down_revision = self.revision.down_revision
+            down_revision = self.revision._all_down_revisions[0]
 
         if self.is_upgrade:
             return down_revision, self.revision.revision
index 071452f200e3ed4cdd41afd02272c0c57f26fe65..0ab52ab6c70d0ba2a1d3185851fcb1780f52c50d 100644 (file)
@@ -81,6 +81,16 @@ class RevisionMap(object):
         self._revision_map
         return self.bases
 
+    @util.memoized_property
+    def _real_bases(self):
+        """All "real" base revisions as strings.
+
+        :return: a tuple of string revision numbers.
+
+        """
+        self._revision_map
+        return self._real_bases
+
     @util.memoized_property
     def _revision_map(self):
         """memoized attribute, initializes the revision map from the
@@ -91,6 +101,7 @@ class RevisionMap(object):
 
         heads = sqlautil.OrderedSet()
         self.bases = ()
+        self._real_bases = ()
 
         has_branch_labels = set()
         for revision in self._generator():
@@ -104,14 +115,16 @@ class RevisionMap(object):
             heads.add(revision.revision)
             if revision.is_base:
                 self.bases += (revision.revision, )
+            if revision._is_real_base:
+                self._real_bases += (revision.revision, )
 
         for rev in map_.values():
-            for downrev in rev._down_revision_tuple:
+            for downrev in rev._all_down_revisions:
                 if downrev not in map_:
                     util.warn("Revision %s referenced from %s is not present"
-                              % (rev.down_revision, rev))
+                              % (downrev, rev))
                 down_revision = map_[downrev]
-                down_revision.add_nextrev(rev.revision)
+                down_revision.add_nextrev(rev)
                 heads.discard(downrev)
 
         map_[None] = map_[()] = None
@@ -133,12 +146,13 @@ class RevisionMap(object):
                     )
                 map_[branch_label] = revision
             revision.branch_labels.update(revision.branch_labels)
-            for node in self._get_descendant_nodes([revision], map_):
+            for node in self._get_descendant_nodes(
+                    [revision], map_, include_dependencies=False):
                 node.branch_labels.update(revision.branch_labels)
 
             parent = node
             while parent and \
-                    not parent.is_branch_point and not parent.is_merge_point:
+                    not parent._is_real_branch_point and not parent.is_merge_point:
 
                 parent.branch_labels.update(revision.branch_labels)
                 if parent.down_revision:
@@ -164,18 +178,20 @@ class RevisionMap(object):
         self._add_branches(revision, map_)
         if revision.is_base:
             self.bases += (revision.revision, )
-        for downrev in revision._down_revision_tuple:
+        if revision._is_real_base:
+            self._real_bases += (revision.revision, )
+        for downrev in revision._all_down_revisions:
             if downrev not in map_:
                 util.warn(
                     "Revision %s referenced from %s is not present"
-                    % (revision.down_revision, revision)
+                    % (downrev, revision)
                 )
-            map_[downrev].add_nextrev(revision.revision)
-        if revision.is_head:
+            map_[downrev].add_nextrev(revision)
+        if revision._is_real_head:
             self.heads = tuple(
                 head for head in self.heads
                 if head not in
-                set(revision._down_revision_tuple).union([revision.revision])
+                set(revision._all_down_revisions).union([revision.revision])
             ) + (revision.revision,)
 
     def get_current_head(self, branch_label=None):
@@ -334,8 +350,12 @@ class RevisionMap(object):
         ]
 
         return bool(
-            set(self._get_descendant_nodes([target]))
-            .union(self._get_ancestor_nodes([target]))
+            set(self._get_descendant_nodes([target],
+                include_dependencies=False
+                ))
+            .union(self._get_ancestor_nodes([target],
+                   include_dependencies=False
+                   ))
             .intersection(test_against_revs)
         )
 
@@ -424,16 +444,28 @@ class RevisionMap(object):
             return self._iterate_revisions(
                 upper, lower, inclusive=inclusive, implicit_base=implicit_base)
 
-    def _get_descendant_nodes(self, targets, map_=None, check=False):
+    def _get_descendant_nodes(
+            self, targets, map_=None, check=False, include_dependencies=True):
+
+        if include_dependencies:
+            fn = lambda rev: rev._all_nextrev
+        else:
+            fn = lambda rev: rev.nextrev
+
         return self._iterate_related_revisions(
-            lambda rev: rev.nextrev,
-            targets, map_=map_, check=check
+            fn, targets, map_=map_, check=check
         )
 
-    def _get_ancestor_nodes(self, targets, map_=None, check=False):
+    def _get_ancestor_nodes(
+            self, targets, map_=None, check=False, include_dependencies=True):
+
+        if include_dependencies:
+            fn = lambda rev: rev._all_down_revisions
+        else:
+            fn = lambda rev: rev._versioned_down_revisions
+
         return self._iterate_related_revisions(
-            lambda rev: rev._down_revision_tuple,
-            targets, map_=map_, check=check
+            fn, targets, map_=map_, check=check
         )
 
     def _iterate_related_revisions(self, fn, targets, map_, check=False):
@@ -494,17 +526,17 @@ class RevisionMap(object):
                 difference(lower_ancestors).\
                 difference(lower_descendants)
             for rev in candidate_lowers:
-                for downrev in rev._down_revision_tuple:
+                for downrev in rev._all_down_revisions:
                     if self._revision_map[downrev] in candidate_lowers:
                         break
                 else:
                     base_lowers.add(rev)
             lowers = base_lowers.union(requested_lowers)
         elif implicit_base:
-            base_lowers = set(self.get_revisions(self.bases))
+            base_lowers = set(self.get_revisions(self._real_bases))
             lowers = base_lowers.union(requested_lowers)
         elif not requested_lowers:
-            lowers = set(self.get_revisions(self.bases))
+            lowers = set(self.get_revisions(self._real_bases))
         else:
             lowers = requested_lowers
 
@@ -523,12 +555,12 @@ class RevisionMap(object):
         branch_todo = set(
             rev for rev in
             (self._revision_map[rev] for rev in total_space)
-            if rev.is_branch_point and
-            len(total_space.intersection(rev.nextrev)) > 1
+            if rev._is_real_branch_point and
+            len(total_space.intersection(rev._all_nextrev)) > 1
         )
 
         # it's not possible for any "uppers" to be in branch_todo,
-        # because the .nextrev of those nodes is not in total_space
+        # because the ._all_nextrev of those nodes is not in total_space
         #assert not branch_todo.intersection(uppers)
 
         todo = collections.deque(
@@ -541,11 +573,17 @@ class RevisionMap(object):
             # descendants left in the queue
             if not todo:
                 todo.extendleft(
-                    rev for rev in branch_todo
-                    if not rev.nextrev.intersection(total_space)
+                    sorted(
+                        (
+                            rev for rev in branch_todo
+                            if not rev._all_nextrev.intersection(total_space)
+                        ),
+                        # favor "revisioned" branch points before
+                        # dependent ones
+                        key=lambda rev: 0 if rev.is_branch_point else 1
+                    )
                 )
                 branch_todo.difference_update(todo)
-
             # iterate nodes that are in the immediate todo
             while todo:
                 rev = todo.popleft()
@@ -555,7 +593,7 @@ class RevisionMap(object):
                 # don't consume any actual branch nodes
                 todo.extendleft([
                     self._revision_map[downrev]
-                    for downrev in reversed(rev._down_revision_tuple)
+                    for downrev in reversed(rev._all_down_revisions)
                     if self._revision_map[downrev] not in branch_todo
                     and downrev in total_space])
 
@@ -577,28 +615,56 @@ class Revision(object):
 
     """
     nextrev = frozenset()
+    """following revisions, based on down_revision only."""
+
+    _all_nextrev = frozenset()
 
     revision = None
     """The string revision number."""
 
     down_revision = None
-    """The ``down_revision`` identifier(s) within the migration script."""
+    """The ``down_revision`` identifier(s) within the migration script.
+
+    Note that the total set of "down" revisions is
+    down_revision + dependencies.
+
+    """
+
+    dependencies = None
+    """Additional revisions which this revision is dependent on.
+
+    From a migration standpoint, these dependencies are added to the
+    down_revision to form the full iteration.  However, the separation
+    of down_revision from "dependencies" is to assist in navigating
+    a history that contains many branches, typically a multi-root scenario.
+
+    """
 
     branch_labels = None
     """Optional string/tuple of symbolic names to apply to this
     revision's branch"""
 
-    def __init__(self, revision, down_revision, branch_labels=None):
+    def __init__(
+            self, revision, down_revision,
+            dependencies=None, branch_labels=None):
         self.revision = revision
         self.down_revision = tuple_rev_as_scalar(down_revision)
+        self.dependencies = tuple_rev_as_scalar(dependencies)
         self._orig_branch_labels = util.to_tuple(branch_labels, default=())
         self.branch_labels = set(self._orig_branch_labels)
 
-    def add_nextrev(self, rev):
-        self.nextrev = self.nextrev.union([rev])
+    def add_nextrev(self, revision):
+        self._all_nextrev = self._all_nextrev.union([revision.revision])
+        if self.revision in revision._versioned_down_revisions:
+            self.nextrev = self.nextrev.union([revision.revision])
+
+    @property
+    def _all_down_revisions(self):
+        return util.to_tuple(self.down_revision, default=()) + \
+            util.to_tuple(self.dependencies, default=())
 
     @property
-    def _down_revision_tuple(self):
+    def _versioned_down_revisions(self):
         return util.to_tuple(self.down_revision, default=())
 
     @property
@@ -612,12 +678,23 @@ class Revision(object):
         """
         return not bool(self.nextrev)
 
+    @property
+    def _is_real_head(self):
+        return not bool(self._all_nextrev)
+
     @property
     def is_base(self):
         """Return True if this :class:`.Revision` is a 'base' revision."""
 
         return self.down_revision is None
 
+    @property
+    def _is_real_base(self):
+        """Return True if this :class:`.Revision` is a "real" base revision,
+        e.g. that it has no dependencies either."""
+
+        return self.down_revision is None and self.dependencies is None
+
     @property
     def is_branch_point(self):
         """Return True if this :class:`.Script` is a branch point.
@@ -630,11 +707,19 @@ class Revision(object):
         """
         return len(self.nextrev) > 1
 
+    @property
+    def _is_real_branch_point(self):
+        """Return True if this :class:`.Script` is a 'real' branch point,
+        taking into account dependencies as well.
+
+        """
+        return len(self._all_nextrev) > 1
+
     @property
     def is_merge_point(self):
         """Return True if this :class:`.Script` is a merge point."""
 
-        return len(self._down_revision_tuple) > 1
+        return len(self._versioned_down_revisions) > 1
 
 
 def tuple_rev_as_scalar(rev):
index 1835605ff949fb04fad2953c52ff03981659b92d..21476521b628e46a4830a6da24b60989b724d130 100644 (file)
@@ -483,7 +483,10 @@ class Script(revision.Revision):
             rev_id,
             module.down_revision,
             branch_labels=util.to_tuple(
-                getattr(module, 'branch_labels', None), default=()))
+                getattr(module, 'branch_labels', None), default=()),
+            dependencies=util.to_tuple(
+                getattr(module, 'depends_on', None), default=())
+        )
 
     module = None
     """The Python module representing the actual script itself."""
@@ -522,6 +525,10 @@ class Script(revision.Revision):
         else:
             entry += "Parent: %s\n" % (self._format_down_revision(), )
 
+        if self.dependencies:
+            entry += "Depends on: %s\n" % (
+                util.format_as_comma(self.dependencies))
+
         if self.is_branch_point:
             entry += "Branches into: %s\n" % (
                 util.format_as_comma(self.nextrev))
@@ -554,8 +561,15 @@ class Script(revision.Revision):
             include_parents=False, tree_indicators=True):
         text = self.revision
         if include_parents:
-            text = "%s -> %s" % (
-                self._format_down_revision(), text)
+            if self.dependencies:
+                text = "%s (%s) -> %s" % (
+                    self._format_down_revision(),
+                    util.format_as_comma(self.dependencies),
+                    text
+                )
+            else:
+                text = "%s -> %s" % (
+                    self._format_down_revision(), text)
         if include_branches and self.branch_labels:
             text += " (%s)" % util.format_as_comma(self.branch_labels)
         if tree_indicators:
@@ -584,7 +598,7 @@ class Script(revision.Revision):
         if not self.down_revision:
             return "<base>"
         else:
-            return util.format_as_comma(self._down_revision_tuple)
+            return util.format_as_comma(self._versioned_down_revisions)
 
     @classmethod
     def _from_path(cls, scriptdir, path):
index 10fae91e1e1ab183966f9fc702cfc9049f1d60cd..e9e8935c1b7546d1fb3d28fb4fc8c2dc5eb91560 100644 (file)
@@ -715,3 +715,160 @@ class MultipleBaseTest(DownIterateTest):
             ['b3', 'a3', 'b2', 'a2', 'base2'],
             inclusive=False, implicit_base=True
         )
+
+
+class MultipleBaseCrossDependencyTestOne(DownIterateTest):
+    def setUp(self):
+        self.map = RevisionMap(
+            lambda: [
+                Revision('base1', (), branch_labels='b_1'),
+                Revision('a1a', ('base1',)),
+                Revision('a1b', ('base1',)),
+                Revision('b1a', ('a1a',)),
+                Revision('b1b', ('a1b', ), dependencies='a3'),
+
+                Revision('base2', (), branch_labels='b_2'),
+                Revision('a2', ('base2',)),
+                Revision('b2', ('a2',)),
+                Revision('c2', ('b2', ), dependencies='a3'),
+                Revision('d2', ('c2',)),
+
+                Revision('base3', (), branch_labels='b_3'),
+                Revision('a3', ('base3',)),
+                Revision('b3', ('a3',)),
+            ]
+        )
+
+    def test_what_are_the_heads(self):
+        eq_(self.map.heads, ("b1a", "b1b", "d2", "b3"))
+
+    def test_heads_to_base(self):
+        self._assert_iteration(
+            "heads", "base",
+            [
+
+                'b1a', 'a1a', 'b1b', 'a1b', 'd2', 'c2', 'b2', 'a2', 'base2',
+                'b3', 'a3', 'base3',
+                'base1'
+            ]
+        )
+
+    def test_we_need_head2(self):
+        # the 2 branch relies on the 3 branch
+        self._assert_iteration(
+            "b_2@head", "base",
+            ['d2', 'c2', 'b2', 'a2', 'base2', 'a3', 'base3']
+        )
+
+    def test_we_need_head3(self):
+        # the 3 branch can be upgraded alone.
+        self._assert_iteration(
+            "b_3@head", "base",
+            ['b3', 'a3', 'base3']
+        )
+
+    def test_we_need_head1(self):
+        # the 1 branch relies on the 3 branch
+        self._assert_iteration(
+            "b1b@head", "base",
+            ['b1b', 'a1b', 'base1', 'a3', 'base3']
+        )
+
+    def test_we_need_base2(self):
+        # consider a downgrade to b_2@base - we
+        # want to run through all the "2"s alone, and we're done.
+        self._assert_iteration(
+            "heads", "b_2@base",
+            ['d2', 'c2', 'b2', 'a2', 'base2']
+        )
+
+    def test_we_need_base3(self):
+        # consider a downgrade to b_3@base - due to the a3 dependency, we
+        # need to downgrade everything dependent on a3
+        # as well, which means b1b and c2.  Then we can downgrade
+        # the 3s.
+        self._assert_iteration(
+            "heads", "b_3@base",
+            ['b1b', 'd2', 'c2', 'b3', 'a3', 'base3']
+        )
+
+
+class MultipleBaseCrossDependencyTestTwo(DownIterateTest):
+    def setUp(self):
+        self.map = RevisionMap(
+            lambda: [
+                Revision('base1', (), branch_labels='b_1'),
+                Revision('a1', 'base1'),
+                Revision('b1', 'a1'),
+                Revision('c1', 'b1'),
+
+                Revision('base2', (), dependencies='base1', branch_labels='b_2'),
+                Revision('a2', 'base2'),
+                Revision('b2', 'a2'),
+                Revision('c2', 'b2'),
+                Revision('d2', 'c2'),
+
+                Revision('base3', (), branch_labels='b_3'),
+                Revision('a3', 'base3'),
+                Revision('b3', 'a3'),
+                Revision('c3', 'b3', dependencies='b2'),
+                Revision('d3', 'c3'),
+            ]
+        )
+
+    def test_what_are_the_heads(self):
+        eq_(self.map.heads, ("c1", "d2", "d3"))
+
+    def test_heads_to_base(self):
+        self._assert_iteration(
+            "heads", "base",
+            [
+                'c1', 'b1', 'a1',
+                'd2', 'c2',
+                'd3', 'c3', 'b3', 'a3', 'base3',
+                'b2', 'a2', 'base2',
+                'base1'
+            ]
+        )
+
+    def test_we_need_head2(self):
+        self._assert_iteration(
+            "b_2@head", "base",
+            ['d2', 'c2', 'b2', 'a2', 'base2', 'base1']
+        )
+
+    def test_we_need_head3(self):
+        self._assert_iteration(
+            "b_3@head", "base",
+            ['d3', 'c3', 'b3', 'a3', 'base3', 'b2', 'a2', 'base2', 'base1']
+        )
+
+    def test_we_need_head1(self):
+        self._assert_iteration(
+            "b_1@head", "base",
+            ['c1', 'b1', 'a1', 'base1']
+        )
+
+    def test_we_need_base1(self):
+        self._assert_iteration(
+            "heads", "b_1@base",
+            [
+                'c1', 'b1', 'a1',
+                'd2', 'c2',
+                'd3', 'c3', 'b2', 'a2', 'base2',
+                'base1'
+            ]
+        )
+
+    def test_we_need_base2(self):
+        self._assert_iteration(
+            "heads", "b_2@base",
+            ['d2', 'c2', 'd3', 'c3', 'b2', 'a2', 'base2']
+        )
+
+    def test_we_need_base3(self):
+        self._assert_iteration(
+            "heads", "b_3@base",
+            ['d3', 'c3', 'b3', 'a3', 'base3']
+        )
+