From: Jeff Dairiki Date: Thu, 3 May 2012 20:40:43 +0000 (-0700) Subject: Make version table name configurable. X-Git-Tag: rel_0_3_3~3 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=6ed983fc075450d874557e81feb0237d7a28a222;p=thirdparty%2Fsqlalchemy%2Falembic.git Make version table name configurable. --- diff --git a/CHANGES b/CHANGES index 4a20a5e3..e49a77f5 100644 --- a/CHANGES +++ b/CHANGES @@ -6,6 +6,10 @@ MySQL needs the constraint type in order to emit a DROP CONSTRAINT. #44 +- [feature] Added version_table argument to + EnvironmentContext.configure(), allowing for the + configuration of the version table name. #34 + 0.3.2 ===== - [feature] Basic support for Oracle added, diff --git a/alembic/environment.py b/alembic/environment.py index ab7a7d4e..4aeb8116 100644 --- a/alembic/environment.py +++ b/alembic/environment.py @@ -273,7 +273,9 @@ class EnvironmentContext(object): when using ``--sql`` mode. :param tag: a string tag for usage by custom ``env.py`` scripts. Set via the ``--tag`` option, can be overridden here. - + :param version_table: The name of the Alembic version table. + The default is ``'alembic_version'``. + Parameters specific to the autogenerate feature, when ``alembic revision`` is run with the ``--autogenerate`` feature: diff --git a/alembic/migration.py b/alembic/migration.py index 80e53b25..1a65cd22 100644 --- a/alembic/migration.py +++ b/alembic/migration.py @@ -9,11 +9,6 @@ from sqlalchemy.engine import url as sqla_url import logging log = logging.getLogger(__name__) -_meta = MetaData() -_version = Table('alembic_version', _meta, - Column('version_num', String(32), nullable=False) - ) - class MigrationContext(object): """Represent the database state made available to a migration script. @@ -81,6 +76,11 @@ class MigrationContext(object): 'compare_server_default', False) + version_table = opts.get('version_table', 'alembic_version') + self._version = Table( + version_table, MetaData(), + Column('version_num', String(32), nullable=False)) + self._start_from_rev = opts.get("starting_rev") self.impl = ddl.DefaultImpl.get_by_dialect(dialect)( dialect, self.connection, self.as_sql, @@ -152,8 +152,8 @@ class MigrationContext(object): raise util.CommandError( "Can't specify current_rev to context " "when using a database connection") - _version.create(self.connection, checkfirst=True) - return self.connection.scalar(_version.select()) + self._version.create(self.connection, checkfirst=True) + return self.connection.scalar(self._version.select()) _current_rev = get_current_revision """The 0.2 method name, for backwards compat.""" @@ -162,13 +162,13 @@ class MigrationContext(object): if old == new: return if new is None: - self.impl._exec(_version.delete()) + self.impl._exec(self._version.delete()) elif old is None: - self.impl._exec(_version.insert(). + self.impl._exec(self._version.insert(). values(version_num=literal_column("'%s'" % new)) ) else: - self.impl._exec(_version.update(). + self.impl._exec(self._version.update(). values(version_num=literal_column("'%s'" % new)) ) @@ -201,7 +201,7 @@ class MigrationContext(object): if current_rev is False: current_rev = prev_rev if self.as_sql and not current_rev: - _version.create(self.connection) + self._version.create(self.connection) log.info("Running %s %s -> %s", change.__name__, prev_rev, rev) if self.as_sql: self.impl.static_output( @@ -218,7 +218,7 @@ class MigrationContext(object): self._update_current_rev(current_rev, rev) if self.as_sql and not rev: - _version.drop(self.connection) + self._version.drop(self.connection) def execute(self, sql): """Execute a SQL construct or string statement. diff --git a/tests/test_version_table.py b/tests/test_version_table.py new file mode 100644 index 00000000..0343cb1c --- /dev/null +++ b/tests/test_version_table.py @@ -0,0 +1,87 @@ +import unittest + +from sqlalchemy import Table, MetaData, Column, String, create_engine +from sqlalchemy.engine.reflection import Inspector + +from alembic.util import CommandError + +version_table = Table('version_table', MetaData(), + Column('version_num', String(32), nullable=False)) + +class TestMigrationContext(unittest.TestCase): + _bind = [] + + @property + def bind(self): + if not self._bind: + engine = create_engine('sqlite:///', echo=True) + self._bind.append(engine) + return self._bind[0] + + def setUp(self): + self.connection = self.bind.connect() + self.transaction = self.connection.begin() + + def tearDown(self): + version_table.drop(self.connection, checkfirst=True) + self.transaction.rollback() + + def make_one(self, **kwargs): + from alembic.migration import MigrationContext + return MigrationContext.configure(**kwargs) + + def get_revision(self): + result = self.connection.execute(version_table.select()) + rows = result.fetchall() + if len(rows) == 0: + return None + self.assertEqual(len(rows), 1) + return rows[0]['version_num'] + + def test_config_default_version_table_name(self): + context = self.make_one(dialect_name='sqlite') + self.assertEqual(context._version.name, 'alembic_version') + + def test_config_explicit_version_table_name(self): + context = self.make_one(dialect_name='sqlite', + opts={'version_table': 'explicit'}) + self.assertEqual(context._version.name, 'explicit') + + def test_get_current_revision_creates_version_table(self): + context = self.make_one(connection=self.connection, + opts={'version_table': 'version_table'}) + self.assertEqual(context.get_current_revision(), None) + insp = Inspector(self.connection) + self.assertTrue('version_table' in insp.get_table_names()) + + def test_get_current_revision(self): + context = self.make_one(connection=self.connection, + opts={'version_table': 'version_table'}) + version_table.create(self.connection) + self.assertEqual(context.get_current_revision(), None) + self.connection.execute( + version_table.insert().values(version_num='revid')) + self.assertEqual(context.get_current_revision(), 'revid') + + def test_get_current_revision_error_if_starting_rev_given_online(self): + context = self.make_one(connection=self.connection, + opts={'starting_rev': 'boo'}) + self.assertRaises(CommandError, context.get_current_revision) + + def test_get_current_revision_offline(self): + context = self.make_one(dialect_name='sqlite', + opts={'starting_rev': 'startrev', + 'as_sql': True}) + self.assertEqual(context.get_current_revision(), 'startrev') + + def test__update_current_rev(self): + version_table.create(self.connection) + context = self.make_one(connection=self.connection, + opts={'version_table': 'version_table'}) + + context._update_current_rev(None, 'a') + self.assertEqual(self.get_revision(), 'a') + context._update_current_rev('a', 'b') + self.assertEqual(self.get_revision(), 'b') + context._update_current_rev('b', None) + self.assertEqual(self.get_revision(), None)