import logging
log = logging.getLogger(__name__)
-_meta = MetaData()
-_version = Table('alembic_version', _meta,
- Column('version_num', String(32), nullable=False)
- )
-
class MigrationContext(object):
"""Represent the database state made available to a migration
script.
'compare_server_default',
False)
+ version_table = opts.get('version_table', 'alembic_version')
+ self._version = Table(
+ version_table, MetaData(),
+ Column('version_num', String(32), nullable=False))
+
self._start_from_rev = opts.get("starting_rev")
self.impl = ddl.DefaultImpl.get_by_dialect(dialect)(
dialect, self.connection, self.as_sql,
raise util.CommandError(
"Can't specify current_rev to context "
"when using a database connection")
- _version.create(self.connection, checkfirst=True)
- return self.connection.scalar(_version.select())
+ self._version.create(self.connection, checkfirst=True)
+ return self.connection.scalar(self._version.select())
_current_rev = get_current_revision
"""The 0.2 method name, for backwards compat."""
if old == new:
return
if new is None:
- self.impl._exec(_version.delete())
+ self.impl._exec(self._version.delete())
elif old is None:
- self.impl._exec(_version.insert().
+ self.impl._exec(self._version.insert().
values(version_num=literal_column("'%s'" % new))
)
else:
- self.impl._exec(_version.update().
+ self.impl._exec(self._version.update().
values(version_num=literal_column("'%s'" % new))
)
if current_rev is False:
current_rev = prev_rev
if self.as_sql and not current_rev:
- _version.create(self.connection)
+ self._version.create(self.connection)
log.info("Running %s %s -> %s", change.__name__, prev_rev, rev)
if self.as_sql:
self.impl.static_output(
self._update_current_rev(current_rev, rev)
if self.as_sql and not rev:
- _version.drop(self.connection)
+ self._version.drop(self.connection)
def execute(self, sql):
"""Execute a SQL construct or string statement.
--- /dev/null
+import unittest
+
+from sqlalchemy import Table, MetaData, Column, String, create_engine
+from sqlalchemy.engine.reflection import Inspector
+
+from alembic.util import CommandError
+
+version_table = Table('version_table', MetaData(),
+ Column('version_num', String(32), nullable=False))
+
+class TestMigrationContext(unittest.TestCase):
+ _bind = []
+
+ @property
+ def bind(self):
+ if not self._bind:
+ engine = create_engine('sqlite:///', echo=True)
+ self._bind.append(engine)
+ return self._bind[0]
+
+ def setUp(self):
+ self.connection = self.bind.connect()
+ self.transaction = self.connection.begin()
+
+ def tearDown(self):
+ version_table.drop(self.connection, checkfirst=True)
+ self.transaction.rollback()
+
+ def make_one(self, **kwargs):
+ from alembic.migration import MigrationContext
+ return MigrationContext.configure(**kwargs)
+
+ def get_revision(self):
+ result = self.connection.execute(version_table.select())
+ rows = result.fetchall()
+ if len(rows) == 0:
+ return None
+ self.assertEqual(len(rows), 1)
+ return rows[0]['version_num']
+
+ def test_config_default_version_table_name(self):
+ context = self.make_one(dialect_name='sqlite')
+ self.assertEqual(context._version.name, 'alembic_version')
+
+ def test_config_explicit_version_table_name(self):
+ context = self.make_one(dialect_name='sqlite',
+ opts={'version_table': 'explicit'})
+ self.assertEqual(context._version.name, 'explicit')
+
+ def test_get_current_revision_creates_version_table(self):
+ context = self.make_one(connection=self.connection,
+ opts={'version_table': 'version_table'})
+ self.assertEqual(context.get_current_revision(), None)
+ insp = Inspector(self.connection)
+ self.assertTrue('version_table' in insp.get_table_names())
+
+ def test_get_current_revision(self):
+ context = self.make_one(connection=self.connection,
+ opts={'version_table': 'version_table'})
+ version_table.create(self.connection)
+ self.assertEqual(context.get_current_revision(), None)
+ self.connection.execute(
+ version_table.insert().values(version_num='revid'))
+ self.assertEqual(context.get_current_revision(), 'revid')
+
+ def test_get_current_revision_error_if_starting_rev_given_online(self):
+ context = self.make_one(connection=self.connection,
+ opts={'starting_rev': 'boo'})
+ self.assertRaises(CommandError, context.get_current_revision)
+
+ def test_get_current_revision_offline(self):
+ context = self.make_one(dialect_name='sqlite',
+ opts={'starting_rev': 'startrev',
+ 'as_sql': True})
+ self.assertEqual(context.get_current_revision(), 'startrev')
+
+ def test__update_current_rev(self):
+ version_table.create(self.connection)
+ context = self.make_one(connection=self.connection,
+ opts={'version_table': 'version_table'})
+
+ context._update_current_rev(None, 'a')
+ self.assertEqual(self.get_revision(), 'a')
+ context._update_current_rev('a', 'b')
+ self.assertEqual(self.get_revision(), 'b')
+ context._update_current_rev('b', None)
+ self.assertEqual(self.get_revision(), None)