]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
upgrade, downgrade motion
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 29 Apr 2010 22:00:44 +0000 (18:00 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 29 Apr 2010 22:00:44 +0000 (18:00 -0400)
15 files changed:
alembic/__init__.py
alembic/command.py
alembic/context.py
alembic/op.py
alembic/script.py
alembic/util.py
templates/generic/env.py
templates/generic/script.py.mako
templates/multidb/env.py
templates/multidb/script.py.mako
templates/pylons/env.py
templates/pylons/script.py.mako
tests/test_migrations.py [deleted file]
tests/test_revision_create.py [moved from tests/test_revisions.py with 87% similarity]
tests/test_revision_paths.py [new file with mode: 0644]

index 568b0006e5d82ac60390bfdb589271685b854255..c8c90904c54f425d45ae9273784d1e4576772882 100644 (file)
@@ -13,22 +13,48 @@ def main(argv):
     # and derives everything from callables ?
     # we're inventing here a bit.
     
-    commands = dict([
-                (fn.__name__, fn) for fn in 
-                [getattr(command, n) for n in dir(command)]
-                if inspect.isfunction(fn) and 
-                    fn.__name__[0] != '_' and 
-                    fn.__module__ == 'alembic.command'
-                ])
+    commands = {}
+    for fn in [getattr(command, n) for n in dir(command)]:
+        if inspect.isfunction(fn) and \
+            fn.__name__[0] != '_' and \
+            fn.__module__ == 'alembic.command':
+            
+            spec = inspect.getargspec(fn)
+            if spec[3]:
+                positional = spec[0][1:-len(spec[3])]
+                kwarg = spec[0][-len(spec[3]):]
+            else:
+                positional = spec[0][1:]
+                kwarg = []
+            
+            commands[fn.__name__] = {
+                'name':fn.__name__,
+                'fn':fn,
+                'positional':positional,
+                'kwargs':kwarg
+            }
+
+    def format_cmd(cmd):
+        return "%s %s" % (
+            cmd['name'], 
+            " ".join(["<%s>" % p for p in cmd['positional']])
+        )
     
+    def format_opt(cmd, padding=32):
+        opt = format_cmd(cmd)
+        return "  " + opt + \
+                ((padding - len(opt)) * " ") + cmd['fn'].__doc__
+        
     parser = OptionParser(
                 "usage: %prog [options] <command> [command arguments]\n\n"
                 "Available Commands:\n" +
                 "\n".join(sorted([
-                    util.format_opt(fn.__name__.replace('_', '-'), fn.__doc__)
-                    for fn in commands.values()
-                ]))
+                    format_opt(cmd)
+                    for cmd in commands.values()
+                ])) +
+                "\n\n<revision> is a hex revision id or 'head'"
                 )
+                
     parser.add_option("-c", "--config", 
                         type="string", 
                         default="alembic.ini", 
@@ -53,28 +79,19 @@ def main(argv):
     except KeyError:
         util.err("no such command %r" % cmd)
         
-    spec = inspect.getargspec(cmd_fn)
-    if spec[3]:
-        positional = spec[0][1:-len(spec[3])]
-        kwarg = spec[0][-len(spec[3]):]
-    else:
-        positional = spec[0][1:]
-        kwarg = []
-        
     kw = dict(
         (k, getattr(cmd_line_options, k)) 
-        for k in kwarg
+        for k in cmd_fn['kwargs']
     )
         
-    if len(cmd_line_args) != len(positional):
-        util.err("Usage: %s %s [options] %s" % (
+    if len(cmd_line_args) != len(cmd_fn['positional']):
+        util.err("Usage: %s %s [options]" % (
                         os.path.basename(argv[0]), 
-                        cmd, 
-                        " ".join(["<%s>" % p for p in positional])
+                        format_cmd(cmd_fn)
                     ))
 
     cfg = config.Config(cmd_line_options.config)
-    cmd_fn(cfg, *cmd_line_args, **kw)
+    cmd_fn['fn'](cfg, *cmd_line_args, **kw)
 
 
 
index 1e6ea0169d4d9f67c4382b09a8ff717f8c615c4f..a19d3824b8dbf33cb823dd778ae94285dc6e78fb 100644 (file)
@@ -65,19 +65,20 @@ def revision(config, message=None):
     script = ScriptDirectory.from_config(config)
     script.generate_rev(util.rev_id(), message)
     
-def upgrade(config):
-    """Upgrade to the latest version."""
+def upgrade(config, revision):
+    """Upgrade to a later version."""
 
     script = ScriptDirectory.from_config(config)
-    context._migration_fn = script.upgrade_from
+    context._migration_fn = functools.partial(script.upgrade_from, revision)
     context.config = config
     script.run_env()
     
 def revert(config, revision):
-    """Revert to a specific previous version."""
+    """Revert to a previous version."""
     
     script = ScriptDirectory.from_config(config)
     context._migration_fn = functools.partial(script.downgrade_to, revision)
+    context.config = config
     script.run_env()
 
 def history(config):
index 8565695e4ff8b6a251186b1b4b9919e004cf2c5f..c0380350ce1597ba86d59faf99d088dbd6f2b9b6 100644 (file)
@@ -29,19 +29,24 @@ class DefaultContext(object):
         return self.connection.scalar(_version.select())
     
     def _update_current_rev(self, old, new):
-        if old is None:
+        if new is None:
+            self.connection.execute(_version.delete())
+        elif old is None:
             self.connection.execute(_version.insert(), {'version_num':new})
         else:
             self.connection.execute(_version.update(), {'version_num':new})
             
     def run_migrations(self, **kw):
         current_rev = self._current_rev()
+        rev = -1
         for change, rev in self._migrations_fn(current_rev):
-            change.execute(**kw)
-        self._update_current_rev(current_rev, rev)
+            print "-> %s" % (rev, )
+            change(**kw)
+        if rev != -1:
+            self._update_current_rev(current_rev, rev)
         
     def _exec(self, construct):
-        pass
+        self.connection.execute(construct)
     
     def execute(self, sql):
         self._exec(sql)
@@ -70,4 +75,6 @@ def configure_connection(connection):
     
 def run_migrations(**kw):
     _context.run_migrations(**kw)
-    
\ No newline at end of file
+
+def get_context():
+    return _context
\ No newline at end of file
index c7b5d10043691f7940fe1ac71a8c0f88322eaac2..2e1abd17519e7f512679bbe640e0d77b212c75db 100644 (file)
@@ -1,8 +1,9 @@
 from alembic import util
+from alembic.context import get_context
 from sqlalchemy.types import NULLTYPE
 from sqlalchemy import schema
 
-__all__ = ['alter_column', 'create_foreign_key', 'create_unique_constraint']
+__all__ = ['alter_column', 'create_foreign_key', 'create_unique_constraint', 'execute']
 
 def alter_column(table_name, column_name, 
                     nullable=util.NO_VALUE,
@@ -41,14 +42,14 @@ def _unique_constraint(name, source, local_cols):
     return schema.UniqueConstraint(*t.c, name=name)
     
 def create_foreign_key(name, source, referent, local_cols, remote_cols):
-    context.add_constraint(
+    get_context().add_constraint(
                 _foreign_key_constraint(source, referent, local_cols, remote_cols)
             )
 
 def create_unique_constraint(name, source, local_cols):
-    context.add_constraint(
+    get_context().add_constraint(
                 _unique_constraint(name, source, local_cols)
             )
 
 def execute(sql):
-    context.execute(sql)
\ No newline at end of file
+    get_context().execute(sql)
\ No newline at end of file
index 2f33ffb800ccfc90c34fcfe59dc2a9aa449c8a97..e6a62213c41a3b67e76eefe04c99222313244b5b 100644 (file)
@@ -18,21 +18,37 @@ class ScriptDirectory(object):
                         "scripts folder." % dir)
         
     @classmethod
-    def from_config(cls, options):
+    def from_config(cls, config):
         return ScriptDirectory(
-                    options.get_main_option('script_location'))
-
-    def upgrade_from(self, current_rev):
-        head = self._current_head()
-        script = self._revision_map[head]
-        scripts = []
-        while script.upgrade != current_rev:
-            scripts.append((script.module.upgrade, script.upgrade))
-            script = self._revision_map[script.downgrade]
-        return reversed(scripts)
+                    config.get_main_option('script_location'))
+    
+    def _get_rev(self, id_):
+        if id_ == 'head':
+            return self._current_head()
+        elif id_ == 'base':
+            return None
+        else:
+            return id_
+            
+    def _revs(self, upper, lower):
+        lower = self._revision_map[self._get_rev(lower)]
+        upper = self._revision_map[self._get_rev(upper)]
+        script = upper
+        while script != lower:
+            yield script
+            script = self._revision_map[script.down_revision]
+        
+    def upgrade_from(self, destination, current_rev):
+        return [
+            (script.module.upgrade, script.revision) for script in 
+            reversed(list(self._revs(destination, current_rev)))
+            ]
         
     def downgrade_to(self, destination, current_rev):
-        return []
+        return [
+            (script.module.downgrade, script.down_revision) for script in 
+            self._revs(current_rev, destination)
+            ]
         
     def run_env(self):
         util.load_python_file(self.dir, 'env.py')
@@ -44,18 +60,19 @@ class ScriptDirectory(object):
             script = Script.from_path(self.versions, file_)
             if script is None:
                 continue
-            if script.upgrade in map_:
-                util.warn("Revision %s is present more than once" % script.upgrade)
-            map_[script.upgrade] = script
+            if script.revision in map_:
+                util.warn("Revision %s is present more than once" % script.revision)
+            map_[script.revision] = script
         for rev in map_.values():
-            if rev.downgrade is None:
+            if rev.down_revision is None:
                 continue
-            if rev.downgrade not in map_:
+            if rev.down_revision not in map_:
                 util.warn("Revision %s referenced from %s is not present"
-                            % (rev.downgrade, rev))
-                rev.downgrade = None
+                            % (rev.down_revision, rev))
+                rev.down_revision = None
             else:
-                map_[rev.downgrade].nextrev = rev.upgrade
+                map_[rev.down_revision].nextrev = rev.revision
+        map_[None] = None
         return map_
     
     def _current_head(self):
@@ -71,16 +88,16 @@ class ScriptDirectory(object):
         # TODO: keep map sorted chronologically
         heads = []
         for script in self._revision_map.values():
-            if script.nextrev is None:
-                heads.append(script.upgrade)
+            if script and script.nextrev is None:
+                heads.append(script.revision)
         return heads
     
     def _get_origin(self):
         # TODO: keep map sorted chronologically
         
         for script in self._revision_map.values():
-            if script.downgrade is None \
-                and script.upgrade in self._revision_map:
+            if script.down_revision is None \
+                and script.revision in self._revision_map:
                 return script
         else:
             return None
@@ -109,9 +126,9 @@ class ScriptDirectory(object):
             message=message if message is not None else ("Alembic revision %s" % revid)
         )
         script = Script.from_path(self.versions, filename)
-        self._revision_map[script.upgrade] = script
-        if script.downgrade:
-            self._revision_map[script.downgrade].nextrev = script.upgrade
+        self._revision_map[script.revision] = script
+        if script.down_revision:
+            self._revision_map[script.down_revision].nextrev = script.revision
         return script
         
 class Script(object):
@@ -119,12 +136,12 @@ class Script(object):
     
     def __init__(self, module, rev_id):
         self.module = module
-        self.upgrade = rev_id
-        self.downgrade = getattr(module, 'down_revision', None)
+        self.revision = rev_id
+        self.down_revision = getattr(module, 'down_revision', None)
     
     def __str__(self):
-        return "revision %s" % self.upgrade
-        
+        return "revision %s" % self.revision
+    
     @classmethod
     def from_path(cls, dir_, filename):
         m = _rev_file.match(filename)
index 3e297043fcd06580d017c846aa934ba8ac24cf8f..5751e204d4b016f534b99839d37378bcf3310499 100644 (file)
@@ -25,10 +25,6 @@ def template_to_file(template_file, dest, **kw):
     f.close()
 
 
-def format_opt(opt, hlp, padding=22):
-    return "  " + opt + \
-        ((padding - len(opt)) * " ") + hlp
-
 def status(_statmsg, fn, *arg, **kw):
     msg(_statmsg + "...", False)
     try:
index cf025bf612670a46b15bcf1209332125458f0d08..60073764135e94c594de93425ac897d7c987394e 100644 (file)
@@ -15,3 +15,4 @@ try:
     trans.commit()
 except:
     trans.rollback()
+    raise
\ No newline at end of file
index 9e89fa603f8321afd0b7e6f9912bf5269bcfb497..ece48fa989bf541434a33c9050b00f6242d473a7 100644 (file)
@@ -8,9 +8,5 @@ from alembic.op import *
 def upgrade():
     pass
 
-% if down_revision:
 def downgrade():
     pass
-% else:
-# this is the origin node, no downgrade !
-% endif
index 4e50ddf3cf5db286a15a7500283d7a0c8fc2a07a..bde257d5a100913addde953847faba396bfafa5c 100644 (file)
@@ -35,3 +35,4 @@ try:
 except:
     for rec in engines.values():
         rec['transaction'].rollback()
+    raise
\ No newline at end of file
index d485a58e499b14950507a4a72292119391bbca3d..b572a4b348000920565e4f193f04afaeb34bde17 100644 (file)
@@ -8,12 +8,8 @@ from alembic.op import *
 def upgrade(engine):
     eval("upgrade_%s" % engine.name)()
 
-% if down_revision:
 def downgrade(engine):
     eval("upgrade_%s" % engine.name)()
-% else:
-# this is the origin node, no downgrade !
-% endif
 
 
 % for engine in ["engine1", "engine2"]:
index a20327b62835e5aac7abb694ff2bfa2e9d4e8a7f..d733022197485ea9dc5b04c65847b2e60dfda770 100644 (file)
@@ -24,3 +24,4 @@ try:
     trans.commit()
 except:
     trans.rollback()
+    raise
\ No newline at end of file
index 9e89fa603f8321afd0b7e6f9912bf5269bcfb497..ece48fa989bf541434a33c9050b00f6242d473a7 100644 (file)
@@ -8,9 +8,5 @@ from alembic.op import *
 def upgrade():
     pass
 
-% if down_revision:
 def downgrade():
     pass
-% else:
-# this is the origin node, no downgrade !
-% endif
diff --git a/tests/test_migrations.py b/tests/test_migrations.py
deleted file mode 100644 (file)
index 1956bb9..0000000
+++ /dev/null
@@ -1,28 +0,0 @@
-from tests import clear_staging_env, staging_env, eq_, ne_
-from alembic import util
-
-
-def setup():
-    global env
-    env = staging_env()
-    global a, b, c, d, e
-    a = env.generate_rev(util.rev_id(), None)
-    b = env.generate_rev(util.rev_id(), None)
-    c = env.generate_rev(util.rev_id(), None)
-    d = env.generate_rev(util.rev_id(), None)
-    e = env.generate_rev(util.rev_id(), None)
-    
-def teardown():
-    clear_staging_env()
-
-
-def test_upgrade_path():
-    
-    eq_(
-        list(env.upgrade_from(c.upgrade)),
-        [
-            (d.module.upgrade, d.upgrade),
-            (e.module.upgrade, e.upgrade),
-        ]
-    )
-    
\ No newline at end of file
similarity index 87%
rename from tests/test_revisions.py
rename to tests/test_revision_create.py
index 30fc9c6725813af8a1d4e9fdfa695dc7ae9fffe2..0f176f2bc884aacbe546faa7fe0488e475003934 100644 (file)
@@ -21,16 +21,16 @@ def test_003_heads():
 def test_004_rev():
     script = env.generate_rev(abc, "this is a message")
     eq_(script.module.__doc__,"this is a message")
-    eq_(script.upgrade, abc)
-    eq_(script.downgrade, None)
+    eq_(script.revision, abc)
+    eq_(script.down_revision, None)
     assert os.access(os.path.join(env.dir, 'versions', '%s.py' % abc), os.F_OK)
     assert callable(script.module.upgrade)
     eq_(env._get_heads(), [abc])
     
 def test_005_nextrev():
     script = env.generate_rev(def_, "this is the next rev")
-    eq_(script.upgrade, def_)
-    eq_(script.downgrade, abc)
+    eq_(script.revision, def_)
+    eq_(script.down_revision, abc)
     eq_(env._revision_map[abc].nextrev, def_)
     assert script.module.down_revision == abc
     assert callable(script.module.upgrade)
@@ -45,8 +45,8 @@ def test_006_from_clean_env():
     abc_rev = env._revision_map[abc]
     def_rev = env._revision_map[def_]
     eq_(abc_rev.nextrev, def_)
-    eq_(abc_rev.upgrade, abc)
-    eq_(def_rev.downgrade, abc)
+    eq_(abc_rev.revision, abc)
+    eq_(def_rev.down_revision, abc)
     eq_(env._get_heads(), [def_])
     
 def setup():
diff --git a/tests/test_revision_paths.py b/tests/test_revision_paths.py
new file mode 100644 (file)
index 0000000..c891dab
--- /dev/null
@@ -0,0 +1,55 @@
+from tests import clear_staging_env, staging_env, eq_, ne_
+from alembic import util
+
+
+def setup():
+    global env
+    env = staging_env()
+    global a, b, c, d, e
+    a = env.generate_rev(util.rev_id(), None)
+    b = env.generate_rev(util.rev_id(), None)
+    c = env.generate_rev(util.rev_id(), None)
+    d = env.generate_rev(util.rev_id(), None)
+    e = env.generate_rev(util.rev_id(), None)
+    
+def teardown():
+    clear_staging_env()
+
+
+def test_upgrade_path():
+    
+    eq_(
+        env.upgrade_from(e.revision, c.revision),
+        [
+            (d.module.upgrade, d.revision),
+            (e.module.upgrade, e.revision),
+        ]
+    )
+
+    eq_(
+        env.upgrade_from(c.revision, None),
+        [
+            (a.module.upgrade, a.revision),
+            (b.module.upgrade, b.revision),
+            (c.module.upgrade, c.revision),
+        ]
+    )
+    
+def test_downgrade_path():
+
+    eq_(
+        env.downgrade_to(c.revision, e.revision),
+        [
+            (e.module.downgrade, e.down_revision),
+            (d.module.downgrade, d.down_revision),
+        ]
+    )
+
+    eq_(
+        env.downgrade_to(None, c.revision),
+        [
+            (c.module.downgrade, c.down_revision),
+            (b.module.downgrade, b.down_revision),
+            (a.module.downgrade, a.down_revision),
+        ]
+    )