From: Mike Bayer Date: Wed, 28 Apr 2010 21:47:01 +0000 (-0400) Subject: beginning to lay out migration flow X-Git-Tag: rel_0_1_0~100 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=535be0fc06aff4b7ff9e08a5f257eaa4feca1c89;p=thirdparty%2Fsqlalchemy%2Falembic.git beginning to lay out migration flow --- diff --git a/alembic/command.py b/alembic/command.py index 73e1db86..57c8072e 100644 --- a/alembic/command.py +++ b/alembic/command.py @@ -1,8 +1,7 @@ from alembic.script import ScriptDirectory -from alembic import options, util +from alembic import util import os -import sys -import uuid +import functools def list_templates(config): """List available templates""" @@ -70,22 +69,22 @@ def upgrade(config): """Upgrade to the latest version.""" script = ScriptDirectory.from_config(config) + context._migration_fn = script.upgrade_from + script.run_env() - # ... - -def revert(config): +def revert(config, revision): """Revert to a specific previous version.""" script = ScriptDirectory.from_config(config) - - # ... + context._migration_fn = functools.partial(script.downgrade_to, revision) + script.run_env() def history(config): """List changeset scripts in chronological order.""" script = ScriptDirectory.from_config(config) -def splice(config): +def splice(config, parent, child): """'splice' two branches, creating a new revision file.""" diff --git a/alembic/context.py b/alembic/context.py index a919c07b..a89da743 100644 --- a/alembic/context.py +++ b/alembic/context.py @@ -1,5 +1,6 @@ from alembic.ddl import base from alembic import util +from sqlalchemy import MetaData, Table, Column, String class ContextMeta(type): def __init__(cls, classname, bases, dict_): @@ -9,14 +10,36 @@ class ContextMeta(type): return newtype _context_impls = {} - + +_meta = MetaData() +_version = Table('alembic_version', _meta, + Column('version_num', String(32), nullable=False) + ) + class DefaultContext(object): __metaclass__ = ContextMeta + __dialect__ = 'default' - def __init__(self, options, connection): - self.options = options + def __init__(self, connection, fn): self.connection = connection + self._migrations_fn = fn + + def _current_rev(self): + _version.create(self.connection, checkfirst=True) + return self.connection.scalar(_version.select()) + def _update_current_rev(self, old, new): + if old is None: + self.connection.execute(_version.insert(), {'version_num':new}) + else: + self.connection.execute(_version.update(), {'version_num':new}) + + def run_migrations(self, **kw): + current_rev = self._current_rev() + for change in self._migrations_fn(current_rev): + change.execute() + self._update_current_rev(current_rev, change.upgrade) + def _exec(self, construct): pass @@ -37,4 +60,12 @@ class DefaultContext(object): def add_constraint(self, const): self._exec(schema.AddConstraint(const)) - \ No newline at end of file + +def configure_connection(connection): + global _context + _context = _context_impls[connection.dialect.name](connection, _migration_fn) + +def run_migrations(**kw): + global _context + _context.run_migrations(**kw) + \ No newline at end of file diff --git a/alembic/script.py b/alembic/script.py index d35751b7..9c83af84 100644 --- a/alembic/script.py +++ b/alembic/script.py @@ -22,6 +22,15 @@ class ScriptDirectory(object): return ScriptDirectory( options.get_main_option('script_location')) + def upgrade_from(self, current_rev): + return [] + + def downgrade_to(self, destination, current_rev): + return [] + + def run_env(self): + pass + @util.memoized_property def _revision_map(self): map_ = {} diff --git a/templates/generic/env.py b/templates/generic/env.py index e11c4a6c..b6a389f2 100644 --- a/templates/generic/env.py +++ b/templates/generic/env.py @@ -10,7 +10,7 @@ connection = engine.connect() context.configure_connection(connection) trans = connection.begin() try: - run_migrations() + context.run_migrations() trans.commit() except: trans.rollback() diff --git a/templates/multidb/env.py b/templates/multidb/env.py index 22746cb2..4e50ddf3 100644 --- a/templates/multidb/env.py +++ b/templates/multidb/env.py @@ -24,7 +24,7 @@ for name in re.split(r',\s*', db_names): try: for name, rec in engines.items(): context.configure_connection(rec['connection']) - run_migrations(engine=name) + context.run_migrations(engine=name) if USE_TWOPHASE: for rec in engines.values(): diff --git a/templates/pylons/env.py b/templates/pylons/env.py index 8afa1087..a20327b6 100644 --- a/templates/pylons/env.py +++ b/templates/pylons/env.py @@ -20,7 +20,7 @@ connection = meta.engine.connect() context.configure_connection(connection) trans = connection.begin() try: - run_migrations() + context.run_migrations() trans.commit() except: trans.rollback()