script.run_env()
-def stamp(config, revision, sql=False, tag=None):
+def stamp(config, revisions, sql=False, tag=None, purge=False):
"""'stamp' the revision table with the given revision; don't
run any migrations.
``env.py`` scripts via the :class:`.EnvironmentContext.get_tag_argument`
method.
+ :param purge: delete all entries in the version table before stamping.
+
+ .. versionadded:: 1.2
+
"""
script = ScriptDirectory.from_config(config)
- starting_rev = None
- if ":" in revision:
- if not sql:
- raise util.CommandError("Range revision not allowed")
- starting_rev, revision = revision.split(":", 2)
+ if sql:
+ destination_revs = []
+ starting_rev = None
+ for revision in util.to_list(revisions):
+ if ":" in revision:
+ srev, revision = revision.split(":", 2)
+
+ if starting_rev != srev:
+ if starting_rev is None:
+ starting_rev = srev
+ else:
+ raise util.CommandError(
+ "Stamp operation with --sql only supports a "
+ "single starting revision at a time"
+ )
+ destination_revs.append(revision)
+ else:
+ destination_revs = util.to_list(revisions)
def do_stamp(rev, context):
- return script._stamp_revs(revision, rev)
+ return script._stamp_revs(util.to_tuple(destination_revs), rev)
with EnvironmentContext(
config,
script,
fn=do_stamp,
as_sql=sql,
- destination_rev=revision,
- starting_rev=starting_rev,
+ starting_rev=starting_rev if sql else None,
+ destination_rev=util.to_tuple(destination_revs),
tag=tag,
+ purge=purge,
):
script.run_env()
self._migrations_fn = opts.get("fn")
self.as_sql = as_sql
+ self.purge = opts.get("purge", False)
+
if "output_encoding" in opts:
self.output_buffer = EncodedIO(
opts.get("output_buffer") or sys.stdout,
if start_from_rev == "base":
start_from_rev = None
elif start_from_rev is not None and self.script:
- start_from_rev = self.script.get_revision(
- start_from_rev
- ).revision
+ start_from_rev = [
+ self.script.get_revision(sfr).revision
+ for sfr in util.to_list(start_from_rev)
+ if sfr not in (None, "base")
+ ]
return util.to_tuple(start_from_rev, default=())
else:
if self._start_from_rev:
row[0] for row in self.connection.execute(self._version.select())
)
- def _ensure_version_table(self):
+ def _ensure_version_table(self, purge=False):
self._version.create(self.connection, checkfirst=True)
+ if purge:
+ self.connection.execute(self._version.delete())
def _has_version_table(self):
return self.connection.dialect.has_table(
"""
self.impl.start_migrations()
- heads = self.get_current_heads()
- if not self.as_sql and not heads:
- self._ensure_version_table()
+ if self.purge:
+ if self.as_sql:
+ raise util.CommandError("Can't use --purge with --sql mode")
+ self._ensure_version_table(purge=True)
+ heads = ()
+ else:
+ heads = self.get_current_heads()
+
+ if not self.as_sql and not heads:
+ self._ensure_version_table()
head_maintainer = HeadMaintainer(self, heads)
)
def should_delete_branch(self, heads):
+ # TODO: we probably need to look for self.to_ inside of heads,
+ # in a similar manner as should_create_branch, however we have
+ # no tests for this yet (stamp downgrades w/ branches)
return self.is_downgrade and self.branch_move
def should_create_branch(self, heads):
- return self.is_upgrade and self.branch_move
+ return (
+ self.is_upgrade
+ and (self.branch_move or set(self.from_).difference(heads))
+ and set(self.to_).difference(heads)
+ )
def should_merge_branches(self, heads):
return len(self.from_) > 1
from ..environment import EnvironmentContext
from ..migration import MigrationContext
from ..operations import Operations
+from ..util import compat
from ..util.compat import configparser
from ..util.compat import string_types
from ..util.compat import text_type
testing_config.read(["test.cfg"])
-def capture_db():
+def capture_db(dialect="postgresql://"):
buf = []
def dump(sql, *multiparams, **params):
buf.append(str(sql.compile(dialect=engine.dialect)))
- engine = create_mock_engine("postgresql://", dump)
+ engine = create_mock_engine(dialect, dump)
return engine, buf
yield buf
+@contextmanager
+def capture_engine_context_buffer(**kw):
+ from .env import _sqlite_file_db
+ from sqlalchemy import event
+
+ buf = compat.StringIO()
+
+ eng = _sqlite_file_db()
+
+ conn = eng.connect()
+
+ @event.listens_for(conn, "before_cursor_execute")
+ def bce(conn, cursor, statement, parameters, context, executemany):
+ buf.write(statement + "\n")
+
+ kw.update({"connection": conn})
+ conf = EnvironmentContext.configure
+
+ def configure(*arg, **opt):
+ opt.update(**kw)
+ return conf(*arg, **opt)
+
+ with mock.patch.object(EnvironmentContext, "configure", configure):
+ yield buf
+
+
def op_fixture(
dialect="default",
as_sql=False,
from alembic.testing.env import _sqlite_testing_config
from alembic.testing.env import clear_staging_env
from alembic.testing.env import env_file_fixture
+from alembic.testing.env import multi_heads_fixture
from alembic.testing.env import staging_env
from alembic.testing.env import three_rev_fixture
from alembic.testing.env import write_script
from alembic.testing.fixtures import capture_context_buffer
+from alembic.testing.fixtures import capture_engine_context_buffer
from alembic.testing.fixtures import TestBase
command.revision(self.cfg, sql=True)
+class _StampTest(object):
+ def _assert_sql(self, emitted_sql, origin, destinations):
+ ins_expr = (
+ r"INSERT INTO alembic_version \(version_num\) "
+ r"VALUES \('(.+)'\)"
+ )
+ expected = [ins_expr for elem in destinations]
+ if origin:
+ expected[0] = (
+ "UPDATE alembic_version SET version_num='(.+)' WHERE "
+ "alembic_version.version_num = '%s'" % (origin,)
+ )
+ for line in emitted_sql.split("\n"):
+ if not expected:
+ assert not re.match(
+ ins_expr, line
+ ), "additional inserts were emitted"
+ else:
+ m = re.match(expected[0], line)
+ if m:
+ destinations.remove(m.group(1))
+ expected.pop(0)
+
+ assert not expected, "lines remain"
+
+
+class StampMultipleRootsTest(TestBase, _StampTest):
+ def setUp(self):
+ self.env = staging_env()
+ # self.cfg = cfg = _no_sql_testing_config()
+ self.cfg = cfg = _sqlite_testing_config()
+ # cfg.set_main_option("dialect_name", "sqlite")
+ # cfg.remove_main_option("url")
+
+ self.a1, self.b1, self.c1 = three_rev_fixture(cfg)
+ self.a2, self.b2, self.c2 = three_rev_fixture(cfg)
+
+ def tearDown(self):
+ clear_staging_env()
+
+ def test_sql_stamp_heads(self):
+ with capture_context_buffer() as buf:
+ command.stamp(self.cfg, ["heads"], sql=True)
+
+ self._assert_sql(buf.getvalue(), None, {self.c1, self.c2})
+
+ def test_sql_stamp_single_head(self):
+ with capture_context_buffer() as buf:
+ command.stamp(self.cfg, ["%s@head" % self.c1], sql=True)
+
+ self._assert_sql(buf.getvalue(), None, {self.c1})
+
+
+class StampMultipleHeadsTest(TestBase, _StampTest):
+ def setUp(self):
+ self.env = staging_env()
+ # self.cfg = cfg = _no_sql_testing_config()
+ self.cfg = cfg = _sqlite_testing_config()
+ # cfg.set_main_option("dialect_name", "sqlite")
+ # cfg.remove_main_option("url")
+
+ self.a, self.b, self.c = three_rev_fixture(cfg)
+ self.d, self.e, self.f = multi_heads_fixture(
+ cfg, self.a, self.b, self.c
+ )
+
+ def tearDown(self):
+ clear_staging_env()
+
+ def test_sql_stamp_heads(self):
+ with capture_context_buffer() as buf:
+ command.stamp(self.cfg, ["heads"], sql=True)
+
+ self._assert_sql(buf.getvalue(), None, {self.c, self.e, self.f})
+
+ def test_sql_stamp_multi_rev_nonsensical(self):
+ with capture_context_buffer() as buf:
+ command.stamp(self.cfg, [self.a, self.e, self.f], sql=True)
+ # TODO: this shouldn't be possible, because e/f require b as a
+ # dependency
+ self._assert_sql(buf.getvalue(), None, {self.a, self.e, self.f})
+
+ def test_sql_stamp_multi_rev_from_multi_base_nonsensical(self):
+ with capture_context_buffer() as buf:
+ command.stamp(
+ self.cfg,
+ ["base:%s" % self.a, "base:%s" % self.e, "base:%s" % self.f],
+ sql=True,
+ )
+
+ # TODO: this shouldn't be possible, because e/f require b as a
+ # dependency
+ self._assert_sql(buf.getvalue(), None, {self.a, self.e, self.f})
+
+ def test_online_stamp_multi_rev_nonsensical(self):
+ with capture_engine_context_buffer() as buf:
+ command.stamp(self.cfg, [self.a, self.e, self.f])
+
+ # TODO: this shouldn't be possible, because e/f require b as a
+ # dependency
+ self._assert_sql(buf.getvalue(), None, {self.a, self.e, self.f})
+
+ def test_online_stamp_multi_rev_from_real_ancestor(self):
+ command.stamp(self.cfg, [self.a])
+ with capture_engine_context_buffer() as buf:
+ command.stamp(self.cfg, [self.e, self.f])
+
+ self._assert_sql(buf.getvalue(), self.a, {self.e, self.f})
+
+ def test_online_stamp_version_already_there(self):
+ command.stamp(self.cfg, [self.c, self.e])
+ with capture_engine_context_buffer() as buf:
+ command.stamp(self.cfg, [self.c, self.e])
+ self._assert_sql(buf.getvalue(), None, {})
+
+ def test_sql_stamp_multi_rev_from_multi_start(self):
+ with capture_context_buffer() as buf:
+ command.stamp(
+ self.cfg,
+ [
+ "%s:%s" % (self.b, self.c),
+ "%s:%s" % (self.b, self.e),
+ "%s:%s" % (self.b, self.f),
+ ],
+ sql=True,
+ )
+
+ self._assert_sql(buf.getvalue(), self.b, {self.c, self.e, self.f})
+
+ def test_sql_stamp_heads_symbolic(self):
+ with capture_context_buffer() as buf:
+ command.stamp(self.cfg, ["%s:heads" % self.b], sql=True)
+
+ self._assert_sql(buf.getvalue(), self.b, {self.c, self.e, self.f})
+
+ def test_sql_stamp_different_multi_start(self):
+ assert_raises_message(
+ util.CommandError,
+ "Stamp operation with --sql only supports a single "
+ "starting revision at a time",
+ command.stamp,
+ self.cfg,
+ ["%s:%s" % (self.b, self.c), "%s:%s" % (self.a, self.e)],
+ sql=True,
+ )
+
+ def test_stamp_purge(self):
+ command.stamp(self.cfg, [self.a])
+
+ eng = _sqlite_file_db()
+ with eng.connect() as conn:
+ result = conn.execute(
+ "update alembic_version set version_num='fake'"
+ )
+ eq_(result.rowcount, 1)
+
+ with capture_engine_context_buffer() as buf:
+ command.stamp(self.cfg, [self.a, self.e, self.f], purge=True)
+
+ self._assert_sql(buf.getvalue(), None, {self.a, self.e, self.f})
+
+ def test_stamp_purge_no_sql(self):
+ assert_raises_message(
+ util.CommandError,
+ "Can't use --purge with --sql mode",
+ command.stamp,
+ self.cfg,
+ [self.c],
+ sql=True,
+ purge=True,
+ )
+
+
class UpgradeDowngradeStampTest(TestBase):
def setUp(self):
self.env = staging_env()
self.bind.scalar("select version_num from alembic_version"), self.a
)
+ def test_stamp_version_already_there(self):
+ command.stamp(self.cfg, self.b)
+ command.stamp(self.cfg, self.b)
+
+ eq_(
+ self.bind.scalar("select version_num from alembic_version"), self.b
+ )
+
class EditTest(TestBase):
@classmethod