From: Mike Bayer Date: Wed, 9 Nov 2011 01:48:40 +0000 (-0800) Subject: - tests for SQL script X-Git-Tag: rel_0_1_0~65 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=2da7f004e49167b3d94c9fe031d61d968ce85c92;p=thirdparty%2Fsqlalchemy%2Falembic.git - tests for SQL script - link create/drop of version table in SQL mode to the "none" revision - get downgrades on SQL script to work --- diff --git a/alembic/context.py b/alembic/context.py index 5ba5ea1c..33f35793 100644 --- a/alembic/context.py +++ b/alembic/context.py @@ -4,6 +4,7 @@ from sqlalchemy import MetaData, Table, Column, String, literal_column, \ from sqlalchemy import schema, create_engine from sqlalchemy.ext.compiler import compiles from sqlalchemy.sql.expression import _BindParamClause +import sys import logging base = util.importlater("alembic.ddl", "base") @@ -30,10 +31,15 @@ class DefaultContext(object): transactional_ddl = False as_sql = False - def __init__(self, connection, fn, as_sql=False): - self.connection = connection + def __init__(self, connection, fn, as_sql=False, output_buffer=sys.stdout): + if as_sql: + self.connection = self._stdout_connection(connection) + assert self.connection is not None + else: + self.connection = connection self._migrations_fn = fn self.as_sql = as_sql + self.output_buffer = output_buffer def _current_rev(self): if self.as_sql: @@ -48,7 +54,6 @@ class DefaultContext(object): def _update_current_rev(self, old, new): if old == new: return - if new is None: self._exec(_version.delete()) elif old is None: @@ -67,14 +72,16 @@ class DefaultContext(object): else "non-transactional") if self.as_sql and self.transactional_ddl: - print "BEGIN;\n" - - if self.as_sql: - # TODO: coverage, --sql with just one rev == error - current_rev = prev_rev = rev = None - else: - current_rev = prev_rev = rev = self._current_rev() - for change, rev in self._migrations_fn(current_rev): + self.static_output("BEGIN;\n") + + current_rev = False + for change, prev_rev, rev in self._migrations_fn( + self._current_rev() + if not self.as_sql else None): + if current_rev is False: + current_rev = prev_rev + if self.as_sql and not current_rev: + _version.create(self.connection) log.info("Running %s %s -> %s", change.__name__, prev_rev, rev) change(**kw) if not self.transactional_ddl: @@ -84,8 +91,11 @@ class DefaultContext(object): if self.transactional_ddl: self._update_current_rev(current_rev, rev) + if self.as_sql and not rev: + _version.drop(self.connection) + if self.as_sql and self.transactional_ddl: - print "COMMIT;\n" + self.static_output("COMMIT;\n") def _exec(self, construct, *args, **kw): if isinstance(construct, basestring): @@ -94,9 +104,9 @@ class DefaultContext(object): if args or kw: # TODO: coverage raise Exception("Execution arguments not allowed with as_sql") - print unicode( + self.static_output(unicode( construct.compile(dialect=self.dialect) - ).replace("\t", " ") + ";" + ).replace("\t", " ") + ";") else: self.connection.execute(construct, *args, **kw) @@ -104,15 +114,17 @@ class DefaultContext(object): def dialect(self): return self.connection.dialect + def static_output(self, text): + self.output_buffer.write(text + "\n") + def execute(self, sql): self._exec(sql) - @util.memoized_property - def _stdout_connection(self): + def _stdout_connection(self, connection): def dump(construct, *multiparams, **params): self._exec(construct) - return create_engine(self.connection.engine.url, + return create_engine(connection.engine.url, strategy="mock", executor=dump) @property @@ -125,10 +137,7 @@ class DefaultContext(object): return results and is only appropriate for DDL. """ - if self.as_sql: - return self._stdout_connection - else: - return self.connection + return self.connection def alter_column(self, table_name, column_name, nullable=None, @@ -185,6 +194,8 @@ class _literal_bindparam(_BindParamClause): def _render_literal_bindparam(element, compiler, **kw): return compiler.render_literal_bindparam(element, **kw) +_context_opts = {} + def opts(cfg, **kw): """Set up options that will be used by the :func:`.configure_connection` function. @@ -192,8 +203,8 @@ def opts(cfg, **kw): This basically sets some global variables. """ - global _context_opts, config - _context_opts = kw + global config + _context_opts.update(kw) config = cfg def configure_connection(connection): diff --git a/alembic/script.py b/alembic/script.py index a8ac9a04..7848ebfc 100644 --- a/alembic/script.py +++ b/alembic/script.py @@ -76,9 +76,8 @@ class ScriptDirectory(object): revs = self._revs(*reversed(destination.split(':', 2))) else: revs = self._revs(destination, current_rev) - return [ - (script.module.upgrade, script.revision) for script in + (script.module.upgrade, script.down_revision, script.revision) for script in reversed(list(revs)) ] @@ -89,12 +88,12 @@ class ScriptDirectory(object): if destination is not None and ':' in destination: if not range_ok: raise util.CommandError("Range revision not allowed") - revs = self._revs(*reversed(destination.split(':', 2))) + revs = self._revs(*destination.split(':', 2)) else: revs = self._revs(current_rev, destination) return [ - (script.module.downgrade, script.down_revision) for script in + (script.module.downgrade, script.revision, script.down_revision) for script in revs ] diff --git a/tests/__init__.py b/tests/__init__.py index 17788f66..1fc72721 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -7,6 +7,7 @@ from alembic import context import re from alembic.context import _context_impls from alembic import ddl +import StringIO staging_directory = os.path.join(os.path.dirname(__file__), 'scratch') @@ -30,6 +31,20 @@ def assert_compiled(element, assert_string, dialect=None): assert_string.replace("\n", "").replace("\t", "") ) +def capture_context_buffer(): + buf = StringIO.StringIO() + + class capture(object): + def __enter__(self): + context._context_opts['output_buffer'] = buf + return buf + + def __exit__(self, *arg, **kw): + print buf.getvalue() + context._context_opts.pop('output_buffer', None) + + return capture() + def eq_(a, b, msg=None): """Assert a == b, with repr messaging on failure.""" assert a == b, msg or "%r != %r" % (a, b) diff --git a/tests/test_revision_paths.py b/tests/test_revision_paths.py index 7feb4445..1320a33c 100644 --- a/tests/test_revision_paths.py +++ b/tests/test_revision_paths.py @@ -21,17 +21,17 @@ def test_upgrade_path(): eq_( env.upgrade_from(False, e.revision, c.revision), [ - (d.module.upgrade, d.revision), - (e.module.upgrade, e.revision), + (d.module.upgrade, c.revision, d.revision), + (e.module.upgrade, d.revision, e.revision), ] ) eq_( env.upgrade_from(False, c.revision, None), [ - (a.module.upgrade, a.revision), - (b.module.upgrade, b.revision), - (c.module.upgrade, c.revision), + (a.module.upgrade, None, a.revision), + (b.module.upgrade, a.revision, b.revision), + (c.module.upgrade, b.revision, c.revision), ] ) @@ -40,16 +40,16 @@ def test_downgrade_path(): eq_( env.downgrade_to(False, c.revision, e.revision), [ - (e.module.downgrade, e.down_revision), - (d.module.downgrade, d.down_revision), + (e.module.downgrade, e.revision, e.down_revision), + (d.module.downgrade, d.revision, d.down_revision), ] ) eq_( env.downgrade_to(False, None, c.revision), [ - (c.module.downgrade, c.down_revision), - (b.module.downgrade, b.down_revision), - (a.module.downgrade, a.down_revision), + (c.module.downgrade, c.revision, c.down_revision), + (b.module.downgrade, b.revision, b.down_revision), + (a.module.downgrade, a.revision, a.down_revision), ] ) diff --git a/tests/test_sql_script.py b/tests/test_sql_script.py new file mode 100644 index 00000000..1971b6e6 --- /dev/null +++ b/tests/test_sql_script.py @@ -0,0 +1,98 @@ +from tests import clear_staging_env, staging_env, _sqlite_testing_config, sqlite_db, eq_, ne_, capture_context_buffer +from alembic import command, util +from alembic.script import ScriptDirectory + +def setup(): + global cfg, env + env = staging_env() + cfg = _sqlite_testing_config() + + global a, b, c + a = util.rev_id() + b = util.rev_id() + c = util.rev_id() + + script = ScriptDirectory.from_config(cfg) + script.generate_rev(a, None) + script.write(a, """ +down_revision = None + +from alembic.op import * + +def upgrade(): + execute("CREATE STEP 1") + +def downgrade(): + execute("DROP STEP 1") + +""") + + script.generate_rev(b, None) + script.write(b, """ +down_revision = '%s' + +from alembic.op import * + +def upgrade(): + execute("CREATE STEP 2") + +def downgrade(): + execute("DROP STEP 2") + +""" % a) + + script.generate_rev(c, None) + script.write(c, """ +down_revision = '%s' + +from alembic.op import * + +def upgrade(): + execute("CREATE STEP 3") + +def downgrade(): + execute("DROP STEP 3") + +""" % b) + +def teardown(): + clear_staging_env() + +def test_version_from_none_insert(): + with capture_context_buffer() as buf: + command.upgrade(cfg, a, sql=True) + assert "CREATE TABLE alembic_version" in buf.getvalue() + assert "INSERT INTO alembic_version" in buf.getvalue() + assert "CREATE STEP 1" in buf.getvalue() + assert "CREATE STEP 2" not in buf.getvalue() + assert "CREATE STEP 3" not in buf.getvalue() + +def test_version_from_middle_update(): + with capture_context_buffer() as buf: + command.upgrade(cfg, "%s:%s" % (b, c), sql=True) + assert "CREATE TABLE alembic_version" not in buf.getvalue() + assert "UPDATE alembic_version" in buf.getvalue() + assert "CREATE STEP 1" not in buf.getvalue() + assert "CREATE STEP 2" not in buf.getvalue() + assert "CREATE STEP 3" in buf.getvalue() + +def test_version_to_none(): + with capture_context_buffer() as buf: + command.downgrade(cfg, "%s:base" % c, sql=True) + assert "CREATE TABLE alembic_version" not in buf.getvalue() + assert "INSERT INTO alembic_version" not in buf.getvalue() + assert "DROP TABLE alembic_version" in buf.getvalue() + assert "DROP STEP 3" in buf.getvalue() + assert "DROP STEP 2" in buf.getvalue() + assert "DROP STEP 1" in buf.getvalue() + +def test_version_to_middle(): + with capture_context_buffer() as buf: + command.downgrade(cfg, "%s:%s" % (c, a), sql=True) + assert "CREATE TABLE alembic_version" not in buf.getvalue() + assert "INSERT INTO alembic_version" not in buf.getvalue() + assert "DROP TABLE alembic_version" not in buf.getvalue() + assert "DROP STEP 3" in buf.getvalue() + assert "DROP STEP 2" in buf.getvalue() + assert "DROP STEP 1" not in buf.getvalue() +