]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
- tests for SQL script
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 9 Nov 2011 01:48:40 +0000 (17:48 -0800)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 9 Nov 2011 01:48:40 +0000 (17:48 -0800)
- link create/drop of version table in SQL mode to the "none" revision
- get downgrades on SQL script to work

alembic/context.py
alembic/script.py
tests/__init__.py
tests/test_revision_paths.py
tests/test_sql_script.py [new file with mode: 0644]

index 5ba5ea1cfdf31f954951729438ff15e4236b5d05..33f357939af68231a6e1eb6b1b7eae2d33d1d428 100644 (file)
@@ -4,6 +4,7 @@ from sqlalchemy import MetaData, Table, Column, String, literal_column, \
 from sqlalchemy import schema, create_engine
 from sqlalchemy.ext.compiler import compiles
 from sqlalchemy.sql.expression import _BindParamClause
+import sys
 
 import logging
 base = util.importlater("alembic.ddl", "base")
@@ -30,10 +31,15 @@ class DefaultContext(object):
     transactional_ddl = False
     as_sql = False
 
-    def __init__(self, connection, fn, as_sql=False):
-        self.connection = connection
+    def __init__(self, connection, fn, as_sql=False, output_buffer=sys.stdout):
+        if as_sql:
+            self.connection = self._stdout_connection(connection)
+            assert self.connection is not None
+        else:
+            self.connection = connection
         self._migrations_fn = fn
         self.as_sql = as_sql
+        self.output_buffer = output_buffer
 
     def _current_rev(self):
         if self.as_sql:
@@ -48,7 +54,6 @@ class DefaultContext(object):
     def _update_current_rev(self, old, new):
         if old == new:
             return
-
         if new is None:
             self._exec(_version.delete())
         elif old is None:
@@ -67,14 +72,16 @@ class DefaultContext(object):
                         else "non-transactional")
 
         if self.as_sql and self.transactional_ddl:
-            print "BEGIN;\n"
-
-        if self.as_sql:
-            # TODO: coverage, --sql with just one rev == error
-            current_rev = prev_rev = rev = None
-        else:
-            current_rev = prev_rev = rev = self._current_rev()
-        for change, rev in self._migrations_fn(current_rev):
+            self.static_output("BEGIN;\n")
+
+        current_rev = False
+        for change, prev_rev, rev in self._migrations_fn(
+                                        self._current_rev() 
+                                        if not self.as_sql else None):
+            if current_rev is False:
+                current_rev = prev_rev
+                if self.as_sql and not current_rev:
+                    _version.create(self.connection)
             log.info("Running %s %s -> %s", change.__name__, prev_rev, rev)
             change(**kw)
             if not self.transactional_ddl:
@@ -84,8 +91,11 @@ class DefaultContext(object):
         if self.transactional_ddl:
             self._update_current_rev(current_rev, rev)
 
+        if self.as_sql and not rev:
+            _version.drop(self.connection)
+
         if self.as_sql and self.transactional_ddl:
-            print "COMMIT;\n"
+            self.static_output("COMMIT;\n")
 
     def _exec(self, construct, *args, **kw):
         if isinstance(construct, basestring):
@@ -94,9 +104,9 @@ class DefaultContext(object):
             if args or kw:
                 # TODO: coverage
                 raise Exception("Execution arguments not allowed with as_sql")
-            print unicode(
+            self.static_output(unicode(
                     construct.compile(dialect=self.dialect)
-                    ).replace("\t", "    ") + ";"
+                    ).replace("\t", "    ") + ";")
         else:
             self.connection.execute(construct, *args, **kw)
 
@@ -104,15 +114,17 @@ class DefaultContext(object):
     def dialect(self):
         return self.connection.dialect
 
+    def static_output(self, text):
+        self.output_buffer.write(text + "\n")
+
     def execute(self, sql):
         self._exec(sql)
 
-    @util.memoized_property
-    def _stdout_connection(self):
+    def _stdout_connection(self, connection):
         def dump(construct, *multiparams, **params):
             self._exec(construct)
 
-        return create_engine(self.connection.engine.url, 
+        return create_engine(connection.engine.url, 
                         strategy="mock", executor=dump)
 
     @property
@@ -125,10 +137,7 @@ class DefaultContext(object):
         return results and is only appropriate for DDL.
         
         """
-        if self.as_sql:
-            return self._stdout_connection
-        else:
-            return self.connection
+        return self.connection
 
     def alter_column(self, table_name, column_name, 
                         nullable=None,
@@ -185,6 +194,8 @@ class _literal_bindparam(_BindParamClause):
 def _render_literal_bindparam(element, compiler, **kw):
     return compiler.render_literal_bindparam(element, **kw)
 
+_context_opts = {}
+
 def opts(cfg, **kw):
     """Set up options that will be used by the :func:`.configure_connection`
     function.
@@ -192,8 +203,8 @@ def opts(cfg, **kw):
     This basically sets some global variables.
     
     """
-    global _context_opts, config
-    _context_opts = kw
+    global config
+    _context_opts.update(kw)
     config = cfg
 
 def configure_connection(connection):
index a8ac9a0436043a882a2143fe51c81c22aca6b740..7848ebfc7befe62e943cc8c6cd166c0181acce60 100644 (file)
@@ -76,9 +76,8 @@ class ScriptDirectory(object):
             revs = self._revs(*reversed(destination.split(':', 2)))
         else:
             revs = self._revs(destination, current_rev)
-
         return [
-            (script.module.upgrade, script.revision) for script in 
+            (script.module.upgrade, script.down_revision, script.revision) for script in 
             reversed(list(revs))
             ]
 
@@ -89,12 +88,12 @@ class ScriptDirectory(object):
         if destination is not None and ':' in destination:
             if not range_ok:
                 raise util.CommandError("Range revision not allowed")
-            revs = self._revs(*reversed(destination.split(':', 2)))
+            revs = self._revs(*destination.split(':', 2))
         else:
             revs = self._revs(current_rev, destination)
 
         return [
-            (script.module.downgrade, script.down_revision) for script in 
+            (script.module.downgrade, script.revision, script.down_revision) for script in 
             revs
             ]
 
index 17788f66122b5f2a3b26ccb13757c97c7d8e8228..1fc72721caeb3c322c37df0279a6083320731fce 100644 (file)
@@ -7,6 +7,7 @@ from alembic import context
 import re
 from alembic.context import _context_impls
 from alembic import ddl
+import StringIO
 
 staging_directory = os.path.join(os.path.dirname(__file__), 'scratch')
 
@@ -30,6 +31,20 @@ def assert_compiled(element, assert_string, dialect=None):
         assert_string.replace("\n", "").replace("\t", "")
     )
 
+def capture_context_buffer():
+    buf = StringIO.StringIO()
+
+    class capture(object):
+        def __enter__(self):
+            context._context_opts['output_buffer'] = buf
+            return buf
+
+        def __exit__(self, *arg, **kw):
+            print buf.getvalue()
+            context._context_opts.pop('output_buffer', None)
+
+    return capture()
+
 def eq_(a, b, msg=None):
     """Assert a == b, with repr messaging on failure."""
     assert a == b, msg or "%r != %r" % (a, b)
index 7feb44452dc427132196e97f21c7f19423788bb7..1320a33c4e76dba023b5f0af3c8dc5f7a7ddb66c 100644 (file)
@@ -21,17 +21,17 @@ def test_upgrade_path():
     eq_(
         env.upgrade_from(False, e.revision, c.revision),
         [
-            (d.module.upgrade, d.revision),
-            (e.module.upgrade, e.revision),
+            (d.module.upgrade, c.revision, d.revision),
+            (e.module.upgrade, d.revision, e.revision),
         ]
     )
 
     eq_(
         env.upgrade_from(False, c.revision, None),
         [
-            (a.module.upgrade, a.revision),
-            (b.module.upgrade, b.revision),
-            (c.module.upgrade, c.revision),
+            (a.module.upgrade, None, a.revision),
+            (b.module.upgrade, a.revision, b.revision),
+            (c.module.upgrade, b.revision, c.revision),
         ]
     )
 
@@ -40,16 +40,16 @@ def test_downgrade_path():
     eq_(
         env.downgrade_to(False, c.revision, e.revision),
         [
-            (e.module.downgrade, e.down_revision),
-            (d.module.downgrade, d.down_revision),
+            (e.module.downgrade, e.revision, e.down_revision),
+            (d.module.downgrade, d.revision, d.down_revision),
         ]
     )
 
     eq_(
         env.downgrade_to(False, None, c.revision),
         [
-            (c.module.downgrade, c.down_revision),
-            (b.module.downgrade, b.down_revision),
-            (a.module.downgrade, a.down_revision),
+            (c.module.downgrade, c.revision, c.down_revision),
+            (b.module.downgrade, b.revision, b.down_revision),
+            (a.module.downgrade, a.revision, a.down_revision),
         ]
     )
diff --git a/tests/test_sql_script.py b/tests/test_sql_script.py
new file mode 100644 (file)
index 0000000..1971b6e
--- /dev/null
@@ -0,0 +1,98 @@
+from tests import clear_staging_env, staging_env, _sqlite_testing_config, sqlite_db, eq_, ne_, capture_context_buffer
+from alembic import command, util
+from alembic.script import ScriptDirectory
+
+def setup():
+    global cfg, env
+    env = staging_env()
+    cfg = _sqlite_testing_config()
+
+    global a, b, c
+    a = util.rev_id()
+    b = util.rev_id()
+    c = util.rev_id()
+
+    script = ScriptDirectory.from_config(cfg)
+    script.generate_rev(a, None)
+    script.write(a, """
+down_revision = None
+
+from alembic.op import *
+
+def upgrade():
+    execute("CREATE STEP 1")
+
+def downgrade():
+    execute("DROP STEP 1")
+
+""")
+
+    script.generate_rev(b, None)
+    script.write(b, """
+down_revision = '%s'
+
+from alembic.op import *
+
+def upgrade():
+    execute("CREATE STEP 2")
+
+def downgrade():
+    execute("DROP STEP 2")
+
+""" % a)
+
+    script.generate_rev(c, None)
+    script.write(c, """
+down_revision = '%s'
+
+from alembic.op import *
+
+def upgrade():
+    execute("CREATE STEP 3")
+
+def downgrade():
+    execute("DROP STEP 3")
+
+""" % b)
+
+def teardown():
+    clear_staging_env()
+
+def test_version_from_none_insert():
+    with capture_context_buffer() as buf:
+        command.upgrade(cfg, a, sql=True)
+    assert "CREATE TABLE alembic_version" in buf.getvalue()
+    assert "INSERT INTO alembic_version" in buf.getvalue()
+    assert "CREATE STEP 1" in buf.getvalue()
+    assert "CREATE STEP 2" not in buf.getvalue()
+    assert "CREATE STEP 3" not in buf.getvalue()
+
+def test_version_from_middle_update():
+    with capture_context_buffer() as buf:
+        command.upgrade(cfg, "%s:%s" % (b, c), sql=True)
+    assert "CREATE TABLE alembic_version" not in buf.getvalue()
+    assert "UPDATE alembic_version" in buf.getvalue()
+    assert "CREATE STEP 1" not in buf.getvalue()
+    assert "CREATE STEP 2" not in buf.getvalue()
+    assert "CREATE STEP 3" in buf.getvalue()
+
+def test_version_to_none():
+    with capture_context_buffer() as buf:
+        command.downgrade(cfg, "%s:base" % c, sql=True)
+    assert "CREATE TABLE alembic_version" not in buf.getvalue()
+    assert "INSERT INTO alembic_version" not in buf.getvalue()
+    assert "DROP TABLE alembic_version" in buf.getvalue()
+    assert "DROP STEP 3" in buf.getvalue()
+    assert "DROP STEP 2" in buf.getvalue()
+    assert "DROP STEP 1" in buf.getvalue()
+
+def test_version_to_middle():
+    with capture_context_buffer() as buf:
+        command.downgrade(cfg, "%s:%s" % (c, a), sql=True)
+    assert "CREATE TABLE alembic_version" not in buf.getvalue()
+    assert "INSERT INTO alembic_version" not in buf.getvalue()
+    assert "DROP TABLE alembic_version" not in buf.getvalue()
+    assert "DROP STEP 3" in buf.getvalue()
+    assert "DROP STEP 2" in buf.getvalue()
+    assert "DROP STEP 1" not in buf.getvalue()
+