from sqlalchemy import schema, create_engine
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.expression import _BindParamClause
+import sys
import logging
base = util.importlater("alembic.ddl", "base")
transactional_ddl = False
as_sql = False
- def __init__(self, connection, fn, as_sql=False):
- self.connection = connection
+ def __init__(self, connection, fn, as_sql=False, output_buffer=sys.stdout):
+ if as_sql:
+ self.connection = self._stdout_connection(connection)
+ assert self.connection is not None
+ else:
+ self.connection = connection
self._migrations_fn = fn
self.as_sql = as_sql
+ self.output_buffer = output_buffer
def _current_rev(self):
if self.as_sql:
def _update_current_rev(self, old, new):
if old == new:
return
-
if new is None:
self._exec(_version.delete())
elif old is None:
else "non-transactional")
if self.as_sql and self.transactional_ddl:
- print "BEGIN;\n"
-
- if self.as_sql:
- # TODO: coverage, --sql with just one rev == error
- current_rev = prev_rev = rev = None
- else:
- current_rev = prev_rev = rev = self._current_rev()
- for change, rev in self._migrations_fn(current_rev):
+ self.static_output("BEGIN;\n")
+
+ current_rev = False
+ for change, prev_rev, rev in self._migrations_fn(
+ self._current_rev()
+ if not self.as_sql else None):
+ if current_rev is False:
+ current_rev = prev_rev
+ if self.as_sql and not current_rev:
+ _version.create(self.connection)
log.info("Running %s %s -> %s", change.__name__, prev_rev, rev)
change(**kw)
if not self.transactional_ddl:
if self.transactional_ddl:
self._update_current_rev(current_rev, rev)
+ if self.as_sql and not rev:
+ _version.drop(self.connection)
+
if self.as_sql and self.transactional_ddl:
- print "COMMIT;\n"
+ self.static_output("COMMIT;\n")
def _exec(self, construct, *args, **kw):
if isinstance(construct, basestring):
if args or kw:
# TODO: coverage
raise Exception("Execution arguments not allowed with as_sql")
- print unicode(
+ self.static_output(unicode(
construct.compile(dialect=self.dialect)
- ).replace("\t", " ") + ";"
+ ).replace("\t", " ") + ";")
else:
self.connection.execute(construct, *args, **kw)
def dialect(self):
return self.connection.dialect
+ def static_output(self, text):
+ self.output_buffer.write(text + "\n")
+
def execute(self, sql):
self._exec(sql)
- @util.memoized_property
- def _stdout_connection(self):
+ def _stdout_connection(self, connection):
def dump(construct, *multiparams, **params):
self._exec(construct)
- return create_engine(self.connection.engine.url,
+ return create_engine(connection.engine.url,
strategy="mock", executor=dump)
@property
return results and is only appropriate for DDL.
"""
- if self.as_sql:
- return self._stdout_connection
- else:
- return self.connection
+ return self.connection
def alter_column(self, table_name, column_name,
nullable=None,
def _render_literal_bindparam(element, compiler, **kw):
return compiler.render_literal_bindparam(element, **kw)
+_context_opts = {}
+
def opts(cfg, **kw):
"""Set up options that will be used by the :func:`.configure_connection`
function.
This basically sets some global variables.
"""
- global _context_opts, config
- _context_opts = kw
+ global config
+ _context_opts.update(kw)
config = cfg
def configure_connection(connection):
revs = self._revs(*reversed(destination.split(':', 2)))
else:
revs = self._revs(destination, current_rev)
-
return [
- (script.module.upgrade, script.revision) for script in
+ (script.module.upgrade, script.down_revision, script.revision) for script in
reversed(list(revs))
]
if destination is not None and ':' in destination:
if not range_ok:
raise util.CommandError("Range revision not allowed")
- revs = self._revs(*reversed(destination.split(':', 2)))
+ revs = self._revs(*destination.split(':', 2))
else:
revs = self._revs(current_rev, destination)
return [
- (script.module.downgrade, script.down_revision) for script in
+ (script.module.downgrade, script.revision, script.down_revision) for script in
revs
]
import re
from alembic.context import _context_impls
from alembic import ddl
+import StringIO
staging_directory = os.path.join(os.path.dirname(__file__), 'scratch')
assert_string.replace("\n", "").replace("\t", "")
)
+def capture_context_buffer():
+ buf = StringIO.StringIO()
+
+ class capture(object):
+ def __enter__(self):
+ context._context_opts['output_buffer'] = buf
+ return buf
+
+ def __exit__(self, *arg, **kw):
+ print buf.getvalue()
+ context._context_opts.pop('output_buffer', None)
+
+ return capture()
+
def eq_(a, b, msg=None):
"""Assert a == b, with repr messaging on failure."""
assert a == b, msg or "%r != %r" % (a, b)
eq_(
env.upgrade_from(False, e.revision, c.revision),
[
- (d.module.upgrade, d.revision),
- (e.module.upgrade, e.revision),
+ (d.module.upgrade, c.revision, d.revision),
+ (e.module.upgrade, d.revision, e.revision),
]
)
eq_(
env.upgrade_from(False, c.revision, None),
[
- (a.module.upgrade, a.revision),
- (b.module.upgrade, b.revision),
- (c.module.upgrade, c.revision),
+ (a.module.upgrade, None, a.revision),
+ (b.module.upgrade, a.revision, b.revision),
+ (c.module.upgrade, b.revision, c.revision),
]
)
eq_(
env.downgrade_to(False, c.revision, e.revision),
[
- (e.module.downgrade, e.down_revision),
- (d.module.downgrade, d.down_revision),
+ (e.module.downgrade, e.revision, e.down_revision),
+ (d.module.downgrade, d.revision, d.down_revision),
]
)
eq_(
env.downgrade_to(False, None, c.revision),
[
- (c.module.downgrade, c.down_revision),
- (b.module.downgrade, b.down_revision),
- (a.module.downgrade, a.down_revision),
+ (c.module.downgrade, c.revision, c.down_revision),
+ (b.module.downgrade, b.revision, b.down_revision),
+ (a.module.downgrade, a.revision, a.down_revision),
]
)
--- /dev/null
+from tests import clear_staging_env, staging_env, _sqlite_testing_config, sqlite_db, eq_, ne_, capture_context_buffer
+from alembic import command, util
+from alembic.script import ScriptDirectory
+
+def setup():
+ global cfg, env
+ env = staging_env()
+ cfg = _sqlite_testing_config()
+
+ global a, b, c
+ a = util.rev_id()
+ b = util.rev_id()
+ c = util.rev_id()
+
+ script = ScriptDirectory.from_config(cfg)
+ script.generate_rev(a, None)
+ script.write(a, """
+down_revision = None
+
+from alembic.op import *
+
+def upgrade():
+ execute("CREATE STEP 1")
+
+def downgrade():
+ execute("DROP STEP 1")
+
+""")
+
+ script.generate_rev(b, None)
+ script.write(b, """
+down_revision = '%s'
+
+from alembic.op import *
+
+def upgrade():
+ execute("CREATE STEP 2")
+
+def downgrade():
+ execute("DROP STEP 2")
+
+""" % a)
+
+ script.generate_rev(c, None)
+ script.write(c, """
+down_revision = '%s'
+
+from alembic.op import *
+
+def upgrade():
+ execute("CREATE STEP 3")
+
+def downgrade():
+ execute("DROP STEP 3")
+
+""" % b)
+
+def teardown():
+ clear_staging_env()
+
+def test_version_from_none_insert():
+ with capture_context_buffer() as buf:
+ command.upgrade(cfg, a, sql=True)
+ assert "CREATE TABLE alembic_version" in buf.getvalue()
+ assert "INSERT INTO alembic_version" in buf.getvalue()
+ assert "CREATE STEP 1" in buf.getvalue()
+ assert "CREATE STEP 2" not in buf.getvalue()
+ assert "CREATE STEP 3" not in buf.getvalue()
+
+def test_version_from_middle_update():
+ with capture_context_buffer() as buf:
+ command.upgrade(cfg, "%s:%s" % (b, c), sql=True)
+ assert "CREATE TABLE alembic_version" not in buf.getvalue()
+ assert "UPDATE alembic_version" in buf.getvalue()
+ assert "CREATE STEP 1" not in buf.getvalue()
+ assert "CREATE STEP 2" not in buf.getvalue()
+ assert "CREATE STEP 3" in buf.getvalue()
+
+def test_version_to_none():
+ with capture_context_buffer() as buf:
+ command.downgrade(cfg, "%s:base" % c, sql=True)
+ assert "CREATE TABLE alembic_version" not in buf.getvalue()
+ assert "INSERT INTO alembic_version" not in buf.getvalue()
+ assert "DROP TABLE alembic_version" in buf.getvalue()
+ assert "DROP STEP 3" in buf.getvalue()
+ assert "DROP STEP 2" in buf.getvalue()
+ assert "DROP STEP 1" in buf.getvalue()
+
+def test_version_to_middle():
+ with capture_context_buffer() as buf:
+ command.downgrade(cfg, "%s:%s" % (c, a), sql=True)
+ assert "CREATE TABLE alembic_version" not in buf.getvalue()
+ assert "INSERT INTO alembic_version" not in buf.getvalue()
+ assert "DROP TABLE alembic_version" not in buf.getvalue()
+ assert "DROP STEP 3" in buf.getvalue()
+ assert "DROP STEP 2" in buf.getvalue()
+ assert "DROP STEP 1" not in buf.getvalue()
+