]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
Add multiple revision support to stamp, support purge
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 19 Sep 2019 03:16:56 +0000 (23:16 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 19 Sep 2019 17:57:49 +0000 (13:57 -0400)
Added new ``--purge`` flag to the ``alembic stamp`` command, which will
unconditionally erase the version table before stamping anything.  This is
useful for development where non-existent version identifiers might be left
within the table.  Additionally, ``alembic.stamp`` now supports a list of
revision identifiers, which are intended to allow setting up muliple heads
at once.  Overall handling of version identifiers within the
``alembic.stamp`` command has been improved with many new tests and
use cases added.

Fixes: #473
Change-Id: If06501b69afae9956df3d0bcd739063fb8042a02

alembic/command.py
alembic/config.py
alembic/runtime/migration.py
alembic/script/base.py
alembic/testing/env.py
alembic/testing/fixtures.py
alembic/util/__init__.py
docs/build/unreleased/473.rst [new file with mode: 0644]
tests/test_command.py

index c02db62a8d7d0111040ad6b09f9bc918238f0fa1..fde74d95aaceb84fd931f7d017be3a32d4d8a4f2 100644 (file)
@@ -496,7 +496,7 @@ def current(config, verbose=False, head_only=False):
         script.run_env()
 
 
-def stamp(config, revision, sql=False, tag=None):
+def stamp(config, revisions, sql=False, tag=None, purge=False):
     """'stamp' the revision table with the given revision; don't
     run any migrations.
 
@@ -510,27 +510,45 @@ def stamp(config, revision, sql=False, tag=None):
      ``env.py`` scripts via the :class:`.EnvironmentContext.get_tag_argument`
      method.
 
+    :param purge: delete all entries in the version table before stamping.
+
+     .. versionadded:: 1.2
+
     """
 
     script = ScriptDirectory.from_config(config)
 
-    starting_rev = None
-    if ":" in revision:
-        if not sql:
-            raise util.CommandError("Range revision not allowed")
-        starting_rev, revision = revision.split(":", 2)
+    if sql:
+        destination_revs = []
+        starting_rev = None
+        for revision in util.to_list(revisions):
+            if ":" in revision:
+                srev, revision = revision.split(":", 2)
+
+                if starting_rev != srev:
+                    if starting_rev is None:
+                        starting_rev = srev
+                    else:
+                        raise util.CommandError(
+                            "Stamp operation with --sql only supports a "
+                            "single starting revision at a time"
+                        )
+            destination_revs.append(revision)
+    else:
+        destination_revs = util.to_list(revisions)
 
     def do_stamp(rev, context):
-        return script._stamp_revs(revision, rev)
+        return script._stamp_revs(util.to_tuple(destination_revs), rev)
 
     with EnvironmentContext(
         config,
         script,
         fn=do_stamp,
         as_sql=sql,
-        destination_rev=revision,
-        starting_rev=starting_rev,
+        starting_rev=starting_rev if sql else None,
+        destination_rev=util.to_tuple(destination_revs),
         tag=tag,
+        purge=purge,
     ):
         script.run_env()
 
index 745ca8b3155dba7b4973a77633dadb364e998948..6ae941646355d64225dd6513cb593f8398531346 100644 (file)
@@ -422,6 +422,14 @@ class CommandLine(object):
                         help="Indicate the current revision",
                     ),
                 ),
+                "purge": (
+                    "--purge",
+                    dict(
+                        action="store_true",
+                        help="Unconditionally erase the version table "
+                        "before stamping",
+                    ),
+                ),
             }
             positional_help = {
                 "directory": "location of scripts directory",
index 31bc727b70d25afeab379a94af54b080df69cc82..0a1e646d321c026a0335ed3e54f5a7c64ecd1a6e 100644 (file)
@@ -109,6 +109,8 @@ class MigrationContext(object):
         self._migrations_fn = opts.get("fn")
         self.as_sql = as_sql
 
+        self.purge = opts.get("purge", False)
+
         if "output_encoding" in opts:
             self.output_buffer = EncodedIO(
                 opts.get("output_buffer") or sys.stdout,
@@ -415,10 +417,12 @@ class MigrationContext(object):
             if start_from_rev == "base":
                 start_from_rev = None
             elif start_from_rev is not None and self.script:
-                start_from_rev = self.script.get_revision(
-                    start_from_rev
-                ).revision
 
+                start_from_rev = [
+                    self.script.get_revision(sfr).revision
+                    for sfr in util.to_list(start_from_rev)
+                    if sfr not in (None, "base")
+                ]
             return util.to_tuple(start_from_rev, default=())
         else:
             if self._start_from_rev:
@@ -432,8 +436,10 @@ class MigrationContext(object):
             row[0] for row in self.connection.execute(self._version.select())
         )
 
-    def _ensure_version_table(self):
+    def _ensure_version_table(self, purge=False):
         self._version.create(self.connection, checkfirst=True)
+        if purge:
+            self.connection.execute(self._version.delete())
 
     def _has_version_table(self):
         return self.connection.dialect.has_table(
@@ -481,9 +487,16 @@ class MigrationContext(object):
         """
         self.impl.start_migrations()
 
-        heads = self.get_current_heads()
-        if not self.as_sql and not heads:
-            self._ensure_version_table()
+        if self.purge:
+            if self.as_sql:
+                raise util.CommandError("Can't use --purge with --sql mode")
+            self._ensure_version_table(purge=True)
+            heads = ()
+        else:
+            heads = self.get_current_heads()
+
+            if not self.as_sql and not heads:
+                self._ensure_version_table()
 
         head_maintainer = HeadMaintainer(self, heads)
 
@@ -1182,10 +1195,17 @@ class StampStep(MigrationStep):
         )
 
     def should_delete_branch(self, heads):
+        # TODO: we probably need to look for self.to_ inside of heads,
+        # in a similar manner as should_create_branch, however we have
+        # no tests for this yet (stamp downgrades w/ branches)
         return self.is_downgrade and self.branch_move
 
     def should_create_branch(self, heads):
-        return self.is_upgrade and self.branch_move
+        return (
+            self.is_upgrade
+            and (self.branch_move or set(self.from_).difference(heads))
+            and set(self.to_).difference(heads)
+        )
 
     def should_merge_branches(self, heads):
         return len(self.from_) > 1
index b386deac8288f99929c255526b9316a3750ce8d6..e62a76bec9ac7c9d5b192f0810f1e98dfe86aba3 100644 (file)
@@ -389,16 +389,25 @@ class ScriptDirectory(object):
 
             heads = self.get_revisions(heads)
 
-            # filter for lineage will resolve things like
-            # branchname@base, version@base, etc.
-            filtered_heads = self.revision_map.filter_for_lineage(
-                heads, revision, include_dependencies=True
-            )
-
             steps = []
 
+            if not revision:
+                revision = "base"
+
+            filtered_heads = []
+            for rev in util.to_tuple(revision):
+                if rev:
+                    filtered_heads.extend(
+                        self.revision_map.filter_for_lineage(
+                            heads, rev, include_dependencies=True
+                        )
+                    )
+            filtered_heads = util.unique_list(filtered_heads)
+
             dests = self.get_revisions(revision) or [None]
+
             for dest in dests:
+
                 if dest is None:
                     # dest is 'base'.  Return a "delete branch" migration
                     # for all applicable heads.
@@ -461,6 +470,7 @@ class ScriptDirectory(object):
                     )
                     steps.append(step)
                     continue
+
             return steps
 
     def run_env(self):
index f3267f13f4b9e23d362aebfe67ebfb544115f00a..01e78f8564aa579da3dd30b7f9e35be4043e1a57 100644 (file)
@@ -290,7 +290,7 @@ def three_rev_fixture(cfg):
     c = util.rev_id()
 
     script = ScriptDirectory.from_config(cfg)
-    script.generate_revision(a, "revision a", refresh=True)
+    script.generate_revision(a, "revision a", refresh=True, head="base")
     write_script(
         script,
         a,
@@ -313,7 +313,7 @@ def downgrade():
         % a,
     )
 
-    script.generate_revision(b, "revision b", refresh=True)
+    script.generate_revision(b, "revision b", refresh=True, head=a)
     write_script(
         script,
         b,
@@ -338,7 +338,7 @@ def downgrade():
         encoding="utf-8",
     )
 
-    script.generate_revision(c, "revision c", refresh=True)
+    script.generate_revision(c, "revision c", refresh=True, head=b)
     write_script(
         script,
         c,
@@ -366,6 +366,9 @@ def downgrade():
 def multi_heads_fixture(cfg, a, b, c):
     """Create a multiple head fixture from the three-revs fixture"""
 
+    # a->b->c
+    #     -> d -> e
+    #     -> f
     d = util.rev_id()
     e = util.rev_id()
     f = util.rev_id()
index 86cdfef81baf2bba817810612f91a0c6152f9b9c..b3990a5b03c363a0fb07cc5bfc7fcf134ddbfab7 100644 (file)
@@ -19,6 +19,7 @@ from .assertions import _get_dialect
 from ..environment import EnvironmentContext
 from ..migration import MigrationContext
 from ..operations import Operations
+from ..util import compat
 from ..util.compat import configparser
 from ..util.compat import string_types
 from ..util.compat import text_type
@@ -28,13 +29,13 @@ testing_config = configparser.ConfigParser()
 testing_config.read(["test.cfg"])
 
 
-def capture_db():
+def capture_db(dialect="postgresql://"):
     buf = []
 
     def dump(sql, *multiparams, **params):
         buf.append(str(sql.compile(dialect=engine.dialect)))
 
-    engine = create_mock_engine("postgresql://", dump)
+    engine = create_mock_engine(dialect, dump)
     return engine, buf
 
 
@@ -59,6 +60,32 @@ def capture_context_buffer(**kw):
         yield buf
 
 
+@contextmanager
+def capture_engine_context_buffer(**kw):
+    from .env import _sqlite_file_db
+    from sqlalchemy import event
+
+    buf = compat.StringIO()
+
+    eng = _sqlite_file_db()
+
+    conn = eng.connect()
+
+    @event.listens_for(conn, "before_cursor_execute")
+    def bce(conn, cursor, statement, parameters, context, executemany):
+        buf.write(statement + "\n")
+
+    kw.update({"connection": conn})
+    conf = EnvironmentContext.configure
+
+    def configure(*arg, **opt):
+        opt.update(**kw)
+        return conf(*arg, **opt)
+
+    with mock.patch.object(EnvironmentContext, "configure", configure):
+        yield buf
+
+
 def op_fixture(
     dialect="default",
     as_sql=False,
index fbe88e3a8b5d239c1e306ca8cbd08f03d10a7634..0ed9a377fc5fee6dffa17685e04438b0d4324c8f 100644 (file)
@@ -9,6 +9,7 @@ from .langhelpers import ModuleClsProxy  # noqa
 from .langhelpers import rev_id  # noqa
 from .langhelpers import to_list  # noqa
 from .langhelpers import to_tuple  # noqa
+from .langhelpers import unique_list  # noqa
 from .messaging import err  # noqa
 from .messaging import format_as_comma  # noqa
 from .messaging import msg  # noqa
diff --git a/docs/build/unreleased/473.rst b/docs/build/unreleased/473.rst
new file mode 100644 (file)
index 0000000..c5a5d53
--- /dev/null
@@ -0,0 +1,12 @@
+.. change::
+    :tags: feature, command
+    :tickets: 473
+
+    Added new ``--purge`` flag to the ``alembic stamp`` command, which will
+    unconditionally erase the version table before stamping anything.  This is
+    useful for development where non-existent version identifiers might be left
+    within the table.  Additionally, ``alembic.stamp`` now supports a list of
+    revision identifiers, which are intended to allow setting up muliple heads
+    at once.  Overall handling of version identifiers within the
+    ``alembic.stamp`` command has been improved with many new tests and
+    use cases added.
index 746471be446b7b4b1da6652357e5de2bd00d1a85..e1be67875d3165b6b12a1c428c0806c602fdc915 100644 (file)
@@ -20,10 +20,12 @@ from alembic.testing.env import _sqlite_file_db
 from alembic.testing.env import _sqlite_testing_config
 from alembic.testing.env import clear_staging_env
 from alembic.testing.env import env_file_fixture
+from alembic.testing.env import multi_heads_fixture
 from alembic.testing.env import staging_env
 from alembic.testing.env import three_rev_fixture
 from alembic.testing.env import write_script
 from alembic.testing.fixtures import capture_context_buffer
+from alembic.testing.fixtures import capture_engine_context_buffer
 from alembic.testing.fixtures import TestBase
 
 
@@ -502,6 +504,179 @@ finally:
         command.revision(self.cfg, sql=True)
 
 
+class _StampTest(object):
+    def _assert_sql(self, emitted_sql, origin, destinations):
+        ins_expr = (
+            r"INSERT INTO alembic_version \(version_num\) "
+            r"VALUES \('(.+)'\)"
+        )
+        expected = [ins_expr for elem in destinations]
+        if origin:
+            expected[0] = (
+                "UPDATE alembic_version SET version_num='(.+)' WHERE "
+                "alembic_version.version_num = '%s'" % (origin,)
+            )
+        for line in emitted_sql.split("\n"):
+            if not expected:
+                assert not re.match(
+                    ins_expr, line
+                ), "additional inserts were emitted"
+            else:
+                m = re.match(expected[0], line)
+                if m:
+                    destinations.remove(m.group(1))
+                    expected.pop(0)
+
+        assert not expected, "lines remain"
+
+
+class StampMultipleRootsTest(TestBase, _StampTest):
+    def setUp(self):
+        self.env = staging_env()
+        # self.cfg = cfg = _no_sql_testing_config()
+        self.cfg = cfg = _sqlite_testing_config()
+        # cfg.set_main_option("dialect_name", "sqlite")
+        # cfg.remove_main_option("url")
+
+        self.a1, self.b1, self.c1 = three_rev_fixture(cfg)
+        self.a2, self.b2, self.c2 = three_rev_fixture(cfg)
+
+    def tearDown(self):
+        clear_staging_env()
+
+    def test_sql_stamp_heads(self):
+        with capture_context_buffer() as buf:
+            command.stamp(self.cfg, ["heads"], sql=True)
+
+        self._assert_sql(buf.getvalue(), None, {self.c1, self.c2})
+
+    def test_sql_stamp_single_head(self):
+        with capture_context_buffer() as buf:
+            command.stamp(self.cfg, ["%s@head" % self.c1], sql=True)
+
+        self._assert_sql(buf.getvalue(), None, {self.c1})
+
+
+class StampMultipleHeadsTest(TestBase, _StampTest):
+    def setUp(self):
+        self.env = staging_env()
+        # self.cfg = cfg = _no_sql_testing_config()
+        self.cfg = cfg = _sqlite_testing_config()
+        # cfg.set_main_option("dialect_name", "sqlite")
+        # cfg.remove_main_option("url")
+
+        self.a, self.b, self.c = three_rev_fixture(cfg)
+        self.d, self.e, self.f = multi_heads_fixture(
+            cfg, self.a, self.b, self.c
+        )
+
+    def tearDown(self):
+        clear_staging_env()
+
+    def test_sql_stamp_heads(self):
+        with capture_context_buffer() as buf:
+            command.stamp(self.cfg, ["heads"], sql=True)
+
+        self._assert_sql(buf.getvalue(), None, {self.c, self.e, self.f})
+
+    def test_sql_stamp_multi_rev_nonsensical(self):
+        with capture_context_buffer() as buf:
+            command.stamp(self.cfg, [self.a, self.e, self.f], sql=True)
+        # TODO: this shouldn't be possible, because e/f require b as a
+        # dependency
+        self._assert_sql(buf.getvalue(), None, {self.a, self.e, self.f})
+
+    def test_sql_stamp_multi_rev_from_multi_base_nonsensical(self):
+        with capture_context_buffer() as buf:
+            command.stamp(
+                self.cfg,
+                ["base:%s" % self.a, "base:%s" % self.e, "base:%s" % self.f],
+                sql=True,
+            )
+
+        # TODO: this shouldn't be possible, because e/f require b as a
+        # dependency
+        self._assert_sql(buf.getvalue(), None, {self.a, self.e, self.f})
+
+    def test_online_stamp_multi_rev_nonsensical(self):
+        with capture_engine_context_buffer() as buf:
+            command.stamp(self.cfg, [self.a, self.e, self.f])
+
+        # TODO: this shouldn't be possible, because e/f require b as a
+        # dependency
+        self._assert_sql(buf.getvalue(), None, {self.a, self.e, self.f})
+
+    def test_online_stamp_multi_rev_from_real_ancestor(self):
+        command.stamp(self.cfg, [self.a])
+        with capture_engine_context_buffer() as buf:
+            command.stamp(self.cfg, [self.e, self.f])
+
+        self._assert_sql(buf.getvalue(), self.a, {self.e, self.f})
+
+    def test_online_stamp_version_already_there(self):
+        command.stamp(self.cfg, [self.c, self.e])
+        with capture_engine_context_buffer() as buf:
+            command.stamp(self.cfg, [self.c, self.e])
+        self._assert_sql(buf.getvalue(), None, {})
+
+    def test_sql_stamp_multi_rev_from_multi_start(self):
+        with capture_context_buffer() as buf:
+            command.stamp(
+                self.cfg,
+                [
+                    "%s:%s" % (self.b, self.c),
+                    "%s:%s" % (self.b, self.e),
+                    "%s:%s" % (self.b, self.f),
+                ],
+                sql=True,
+            )
+
+        self._assert_sql(buf.getvalue(), self.b, {self.c, self.e, self.f})
+
+    def test_sql_stamp_heads_symbolic(self):
+        with capture_context_buffer() as buf:
+            command.stamp(self.cfg, ["%s:heads" % self.b], sql=True)
+
+        self._assert_sql(buf.getvalue(), self.b, {self.c, self.e, self.f})
+
+    def test_sql_stamp_different_multi_start(self):
+        assert_raises_message(
+            util.CommandError,
+            "Stamp operation with --sql only supports a single "
+            "starting revision at a time",
+            command.stamp,
+            self.cfg,
+            ["%s:%s" % (self.b, self.c), "%s:%s" % (self.a, self.e)],
+            sql=True,
+        )
+
+    def test_stamp_purge(self):
+        command.stamp(self.cfg, [self.a])
+
+        eng = _sqlite_file_db()
+        with eng.connect() as conn:
+            result = conn.execute(
+                "update alembic_version set version_num='fake'"
+            )
+            eq_(result.rowcount, 1)
+
+        with capture_engine_context_buffer() as buf:
+            command.stamp(self.cfg, [self.a, self.e, self.f], purge=True)
+
+        self._assert_sql(buf.getvalue(), None, {self.a, self.e, self.f})
+
+    def test_stamp_purge_no_sql(self):
+        assert_raises_message(
+            util.CommandError,
+            "Can't use --purge with --sql mode",
+            command.stamp,
+            self.cfg,
+            [self.c],
+            sql=True,
+            purge=True,
+        )
+
+
 class UpgradeDowngradeStampTest(TestBase):
     def setUp(self):
         self.env = staging_env()
@@ -643,6 +818,14 @@ down_revision = '%s'
             self.bind.scalar("select version_num from alembic_version"), self.a
         )
 
+    def test_stamp_version_already_there(self):
+        command.stamp(self.cfg, self.b)
+        command.stamp(self.cfg, self.b)
+
+        eq_(
+            self.bind.scalar("select version_num from alembic_version"), self.b
+        )
+
 
 class EditTest(TestBase):
     @classmethod