From: Mike Bayer Date: Thu, 19 Sep 2019 03:16:56 +0000 (-0400) Subject: Add multiple revision support to stamp, support purge X-Git-Tag: rel_1_2_0~7^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=fdcc5bf973fec758eaf67cf8e8ad48985caffb65;p=thirdparty%2Fsqlalchemy%2Falembic.git Add multiple revision support to stamp, support purge Added new ``--purge`` flag to the ``alembic stamp`` command, which will unconditionally erase the version table before stamping anything. This is useful for development where non-existent version identifiers might be left within the table. Additionally, ``alembic.stamp`` now supports a list of revision identifiers, which are intended to allow setting up muliple heads at once. Overall handling of version identifiers within the ``alembic.stamp`` command has been improved with many new tests and use cases added. Fixes: #473 Change-Id: If06501b69afae9956df3d0bcd739063fb8042a02 --- diff --git a/alembic/command.py b/alembic/command.py index c02db62a..fde74d95 100644 --- a/alembic/command.py +++ b/alembic/command.py @@ -496,7 +496,7 @@ def current(config, verbose=False, head_only=False): script.run_env() -def stamp(config, revision, sql=False, tag=None): +def stamp(config, revisions, sql=False, tag=None, purge=False): """'stamp' the revision table with the given revision; don't run any migrations. @@ -510,27 +510,45 @@ def stamp(config, revision, sql=False, tag=None): ``env.py`` scripts via the :class:`.EnvironmentContext.get_tag_argument` method. + :param purge: delete all entries in the version table before stamping. + + .. versionadded:: 1.2 + """ script = ScriptDirectory.from_config(config) - starting_rev = None - if ":" in revision: - if not sql: - raise util.CommandError("Range revision not allowed") - starting_rev, revision = revision.split(":", 2) + if sql: + destination_revs = [] + starting_rev = None + for revision in util.to_list(revisions): + if ":" in revision: + srev, revision = revision.split(":", 2) + + if starting_rev != srev: + if starting_rev is None: + starting_rev = srev + else: + raise util.CommandError( + "Stamp operation with --sql only supports a " + "single starting revision at a time" + ) + destination_revs.append(revision) + else: + destination_revs = util.to_list(revisions) def do_stamp(rev, context): - return script._stamp_revs(revision, rev) + return script._stamp_revs(util.to_tuple(destination_revs), rev) with EnvironmentContext( config, script, fn=do_stamp, as_sql=sql, - destination_rev=revision, - starting_rev=starting_rev, + starting_rev=starting_rev if sql else None, + destination_rev=util.to_tuple(destination_revs), tag=tag, + purge=purge, ): script.run_env() diff --git a/alembic/config.py b/alembic/config.py index 745ca8b3..6ae94164 100644 --- a/alembic/config.py +++ b/alembic/config.py @@ -422,6 +422,14 @@ class CommandLine(object): help="Indicate the current revision", ), ), + "purge": ( + "--purge", + dict( + action="store_true", + help="Unconditionally erase the version table " + "before stamping", + ), + ), } positional_help = { "directory": "location of scripts directory", diff --git a/alembic/runtime/migration.py b/alembic/runtime/migration.py index 31bc727b..0a1e646d 100644 --- a/alembic/runtime/migration.py +++ b/alembic/runtime/migration.py @@ -109,6 +109,8 @@ class MigrationContext(object): self._migrations_fn = opts.get("fn") self.as_sql = as_sql + self.purge = opts.get("purge", False) + if "output_encoding" in opts: self.output_buffer = EncodedIO( opts.get("output_buffer") or sys.stdout, @@ -415,10 +417,12 @@ class MigrationContext(object): if start_from_rev == "base": start_from_rev = None elif start_from_rev is not None and self.script: - start_from_rev = self.script.get_revision( - start_from_rev - ).revision + start_from_rev = [ + self.script.get_revision(sfr).revision + for sfr in util.to_list(start_from_rev) + if sfr not in (None, "base") + ] return util.to_tuple(start_from_rev, default=()) else: if self._start_from_rev: @@ -432,8 +436,10 @@ class MigrationContext(object): row[0] for row in self.connection.execute(self._version.select()) ) - def _ensure_version_table(self): + def _ensure_version_table(self, purge=False): self._version.create(self.connection, checkfirst=True) + if purge: + self.connection.execute(self._version.delete()) def _has_version_table(self): return self.connection.dialect.has_table( @@ -481,9 +487,16 @@ class MigrationContext(object): """ self.impl.start_migrations() - heads = self.get_current_heads() - if not self.as_sql and not heads: - self._ensure_version_table() + if self.purge: + if self.as_sql: + raise util.CommandError("Can't use --purge with --sql mode") + self._ensure_version_table(purge=True) + heads = () + else: + heads = self.get_current_heads() + + if not self.as_sql and not heads: + self._ensure_version_table() head_maintainer = HeadMaintainer(self, heads) @@ -1182,10 +1195,17 @@ class StampStep(MigrationStep): ) def should_delete_branch(self, heads): + # TODO: we probably need to look for self.to_ inside of heads, + # in a similar manner as should_create_branch, however we have + # no tests for this yet (stamp downgrades w/ branches) return self.is_downgrade and self.branch_move def should_create_branch(self, heads): - return self.is_upgrade and self.branch_move + return ( + self.is_upgrade + and (self.branch_move or set(self.from_).difference(heads)) + and set(self.to_).difference(heads) + ) def should_merge_branches(self, heads): return len(self.from_) > 1 diff --git a/alembic/script/base.py b/alembic/script/base.py index b386deac..e62a76be 100644 --- a/alembic/script/base.py +++ b/alembic/script/base.py @@ -389,16 +389,25 @@ class ScriptDirectory(object): heads = self.get_revisions(heads) - # filter for lineage will resolve things like - # branchname@base, version@base, etc. - filtered_heads = self.revision_map.filter_for_lineage( - heads, revision, include_dependencies=True - ) - steps = [] + if not revision: + revision = "base" + + filtered_heads = [] + for rev in util.to_tuple(revision): + if rev: + filtered_heads.extend( + self.revision_map.filter_for_lineage( + heads, rev, include_dependencies=True + ) + ) + filtered_heads = util.unique_list(filtered_heads) + dests = self.get_revisions(revision) or [None] + for dest in dests: + if dest is None: # dest is 'base'. Return a "delete branch" migration # for all applicable heads. @@ -461,6 +470,7 @@ class ScriptDirectory(object): ) steps.append(step) continue + return steps def run_env(self): diff --git a/alembic/testing/env.py b/alembic/testing/env.py index f3267f13..01e78f85 100644 --- a/alembic/testing/env.py +++ b/alembic/testing/env.py @@ -290,7 +290,7 @@ def three_rev_fixture(cfg): c = util.rev_id() script = ScriptDirectory.from_config(cfg) - script.generate_revision(a, "revision a", refresh=True) + script.generate_revision(a, "revision a", refresh=True, head="base") write_script( script, a, @@ -313,7 +313,7 @@ def downgrade(): % a, ) - script.generate_revision(b, "revision b", refresh=True) + script.generate_revision(b, "revision b", refresh=True, head=a) write_script( script, b, @@ -338,7 +338,7 @@ def downgrade(): encoding="utf-8", ) - script.generate_revision(c, "revision c", refresh=True) + script.generate_revision(c, "revision c", refresh=True, head=b) write_script( script, c, @@ -366,6 +366,9 @@ def downgrade(): def multi_heads_fixture(cfg, a, b, c): """Create a multiple head fixture from the three-revs fixture""" + # a->b->c + # -> d -> e + # -> f d = util.rev_id() e = util.rev_id() f = util.rev_id() diff --git a/alembic/testing/fixtures.py b/alembic/testing/fixtures.py index 86cdfef8..b3990a5b 100644 --- a/alembic/testing/fixtures.py +++ b/alembic/testing/fixtures.py @@ -19,6 +19,7 @@ from .assertions import _get_dialect from ..environment import EnvironmentContext from ..migration import MigrationContext from ..operations import Operations +from ..util import compat from ..util.compat import configparser from ..util.compat import string_types from ..util.compat import text_type @@ -28,13 +29,13 @@ testing_config = configparser.ConfigParser() testing_config.read(["test.cfg"]) -def capture_db(): +def capture_db(dialect="postgresql://"): buf = [] def dump(sql, *multiparams, **params): buf.append(str(sql.compile(dialect=engine.dialect))) - engine = create_mock_engine("postgresql://", dump) + engine = create_mock_engine(dialect, dump) return engine, buf @@ -59,6 +60,32 @@ def capture_context_buffer(**kw): yield buf +@contextmanager +def capture_engine_context_buffer(**kw): + from .env import _sqlite_file_db + from sqlalchemy import event + + buf = compat.StringIO() + + eng = _sqlite_file_db() + + conn = eng.connect() + + @event.listens_for(conn, "before_cursor_execute") + def bce(conn, cursor, statement, parameters, context, executemany): + buf.write(statement + "\n") + + kw.update({"connection": conn}) + conf = EnvironmentContext.configure + + def configure(*arg, **opt): + opt.update(**kw) + return conf(*arg, **opt) + + with mock.patch.object(EnvironmentContext, "configure", configure): + yield buf + + def op_fixture( dialect="default", as_sql=False, diff --git a/alembic/util/__init__.py b/alembic/util/__init__.py index fbe88e3a..0ed9a377 100644 --- a/alembic/util/__init__.py +++ b/alembic/util/__init__.py @@ -9,6 +9,7 @@ from .langhelpers import ModuleClsProxy # noqa from .langhelpers import rev_id # noqa from .langhelpers import to_list # noqa from .langhelpers import to_tuple # noqa +from .langhelpers import unique_list # noqa from .messaging import err # noqa from .messaging import format_as_comma # noqa from .messaging import msg # noqa diff --git a/docs/build/unreleased/473.rst b/docs/build/unreleased/473.rst new file mode 100644 index 00000000..c5a5d536 --- /dev/null +++ b/docs/build/unreleased/473.rst @@ -0,0 +1,12 @@ +.. change:: + :tags: feature, command + :tickets: 473 + + Added new ``--purge`` flag to the ``alembic stamp`` command, which will + unconditionally erase the version table before stamping anything. This is + useful for development where non-existent version identifiers might be left + within the table. Additionally, ``alembic.stamp`` now supports a list of + revision identifiers, which are intended to allow setting up muliple heads + at once. Overall handling of version identifiers within the + ``alembic.stamp`` command has been improved with many new tests and + use cases added. diff --git a/tests/test_command.py b/tests/test_command.py index 746471be..e1be6787 100644 --- a/tests/test_command.py +++ b/tests/test_command.py @@ -20,10 +20,12 @@ from alembic.testing.env import _sqlite_file_db from alembic.testing.env import _sqlite_testing_config from alembic.testing.env import clear_staging_env from alembic.testing.env import env_file_fixture +from alembic.testing.env import multi_heads_fixture from alembic.testing.env import staging_env from alembic.testing.env import three_rev_fixture from alembic.testing.env import write_script from alembic.testing.fixtures import capture_context_buffer +from alembic.testing.fixtures import capture_engine_context_buffer from alembic.testing.fixtures import TestBase @@ -502,6 +504,179 @@ finally: command.revision(self.cfg, sql=True) +class _StampTest(object): + def _assert_sql(self, emitted_sql, origin, destinations): + ins_expr = ( + r"INSERT INTO alembic_version \(version_num\) " + r"VALUES \('(.+)'\)" + ) + expected = [ins_expr for elem in destinations] + if origin: + expected[0] = ( + "UPDATE alembic_version SET version_num='(.+)' WHERE " + "alembic_version.version_num = '%s'" % (origin,) + ) + for line in emitted_sql.split("\n"): + if not expected: + assert not re.match( + ins_expr, line + ), "additional inserts were emitted" + else: + m = re.match(expected[0], line) + if m: + destinations.remove(m.group(1)) + expected.pop(0) + + assert not expected, "lines remain" + + +class StampMultipleRootsTest(TestBase, _StampTest): + def setUp(self): + self.env = staging_env() + # self.cfg = cfg = _no_sql_testing_config() + self.cfg = cfg = _sqlite_testing_config() + # cfg.set_main_option("dialect_name", "sqlite") + # cfg.remove_main_option("url") + + self.a1, self.b1, self.c1 = three_rev_fixture(cfg) + self.a2, self.b2, self.c2 = three_rev_fixture(cfg) + + def tearDown(self): + clear_staging_env() + + def test_sql_stamp_heads(self): + with capture_context_buffer() as buf: + command.stamp(self.cfg, ["heads"], sql=True) + + self._assert_sql(buf.getvalue(), None, {self.c1, self.c2}) + + def test_sql_stamp_single_head(self): + with capture_context_buffer() as buf: + command.stamp(self.cfg, ["%s@head" % self.c1], sql=True) + + self._assert_sql(buf.getvalue(), None, {self.c1}) + + +class StampMultipleHeadsTest(TestBase, _StampTest): + def setUp(self): + self.env = staging_env() + # self.cfg = cfg = _no_sql_testing_config() + self.cfg = cfg = _sqlite_testing_config() + # cfg.set_main_option("dialect_name", "sqlite") + # cfg.remove_main_option("url") + + self.a, self.b, self.c = three_rev_fixture(cfg) + self.d, self.e, self.f = multi_heads_fixture( + cfg, self.a, self.b, self.c + ) + + def tearDown(self): + clear_staging_env() + + def test_sql_stamp_heads(self): + with capture_context_buffer() as buf: + command.stamp(self.cfg, ["heads"], sql=True) + + self._assert_sql(buf.getvalue(), None, {self.c, self.e, self.f}) + + def test_sql_stamp_multi_rev_nonsensical(self): + with capture_context_buffer() as buf: + command.stamp(self.cfg, [self.a, self.e, self.f], sql=True) + # TODO: this shouldn't be possible, because e/f require b as a + # dependency + self._assert_sql(buf.getvalue(), None, {self.a, self.e, self.f}) + + def test_sql_stamp_multi_rev_from_multi_base_nonsensical(self): + with capture_context_buffer() as buf: + command.stamp( + self.cfg, + ["base:%s" % self.a, "base:%s" % self.e, "base:%s" % self.f], + sql=True, + ) + + # TODO: this shouldn't be possible, because e/f require b as a + # dependency + self._assert_sql(buf.getvalue(), None, {self.a, self.e, self.f}) + + def test_online_stamp_multi_rev_nonsensical(self): + with capture_engine_context_buffer() as buf: + command.stamp(self.cfg, [self.a, self.e, self.f]) + + # TODO: this shouldn't be possible, because e/f require b as a + # dependency + self._assert_sql(buf.getvalue(), None, {self.a, self.e, self.f}) + + def test_online_stamp_multi_rev_from_real_ancestor(self): + command.stamp(self.cfg, [self.a]) + with capture_engine_context_buffer() as buf: + command.stamp(self.cfg, [self.e, self.f]) + + self._assert_sql(buf.getvalue(), self.a, {self.e, self.f}) + + def test_online_stamp_version_already_there(self): + command.stamp(self.cfg, [self.c, self.e]) + with capture_engine_context_buffer() as buf: + command.stamp(self.cfg, [self.c, self.e]) + self._assert_sql(buf.getvalue(), None, {}) + + def test_sql_stamp_multi_rev_from_multi_start(self): + with capture_context_buffer() as buf: + command.stamp( + self.cfg, + [ + "%s:%s" % (self.b, self.c), + "%s:%s" % (self.b, self.e), + "%s:%s" % (self.b, self.f), + ], + sql=True, + ) + + self._assert_sql(buf.getvalue(), self.b, {self.c, self.e, self.f}) + + def test_sql_stamp_heads_symbolic(self): + with capture_context_buffer() as buf: + command.stamp(self.cfg, ["%s:heads" % self.b], sql=True) + + self._assert_sql(buf.getvalue(), self.b, {self.c, self.e, self.f}) + + def test_sql_stamp_different_multi_start(self): + assert_raises_message( + util.CommandError, + "Stamp operation with --sql only supports a single " + "starting revision at a time", + command.stamp, + self.cfg, + ["%s:%s" % (self.b, self.c), "%s:%s" % (self.a, self.e)], + sql=True, + ) + + def test_stamp_purge(self): + command.stamp(self.cfg, [self.a]) + + eng = _sqlite_file_db() + with eng.connect() as conn: + result = conn.execute( + "update alembic_version set version_num='fake'" + ) + eq_(result.rowcount, 1) + + with capture_engine_context_buffer() as buf: + command.stamp(self.cfg, [self.a, self.e, self.f], purge=True) + + self._assert_sql(buf.getvalue(), None, {self.a, self.e, self.f}) + + def test_stamp_purge_no_sql(self): + assert_raises_message( + util.CommandError, + "Can't use --purge with --sql mode", + command.stamp, + self.cfg, + [self.c], + sql=True, + purge=True, + ) + + class UpgradeDowngradeStampTest(TestBase): def setUp(self): self.env = staging_env() @@ -643,6 +818,14 @@ down_revision = '%s' self.bind.scalar("select version_num from alembic_version"), self.a ) + def test_stamp_version_already_there(self): + command.stamp(self.cfg, self.b) + command.stamp(self.cfg, self.b) + + eq_( + self.bind.scalar("select version_num from alembic_version"), self.b + ) + class EditTest(TestBase): @classmethod