]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
Make version table name configurable.
authorJeff Dairiki <dairiki@dairiki.org>
Thu, 3 May 2012 20:40:43 +0000 (13:40 -0700)
committerJeff Dairiki <dairiki@dairiki.org>
Thu, 3 May 2012 20:40:43 +0000 (13:40 -0700)
CHANGES
alembic/environment.py
alembic/migration.py
tests/test_version_table.py [new file with mode: 0644]

diff --git a/CHANGES b/CHANGES
index 4a20a5e3a3a260a46158ada44575c6e91cf65721..e49a77f565ea9a68d623fd17bc988d767129473d 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -6,6 +6,10 @@
   MySQL needs the constraint type
   in order to emit a DROP CONSTRAINT. #44
 
+- [feature] Added version_table argument to
+  EnvironmentContext.configure(), allowing for the
+  configuration of the version table name. #34
+
 0.3.2
 =====
 - [feature] Basic support for Oracle added, 
index ab7a7d4ec694d0b6846a9f2eae5dada642fca1a9..4aeb8116a155d520ed967f90d84d27d6f838adf2 100644 (file)
@@ -273,7 +273,9 @@ class EnvironmentContext(object):
          when using ``--sql`` mode.
         :param tag: a string tag for usage by custom ``env.py`` scripts.  
          Set via the ``--tag`` option, can be overridden here.
-     
+        :param version_table: The name of the Alembic version table.
+         The default is ``'alembic_version'``.
+
         Parameters specific to the autogenerate feature, when 
         ``alembic revision`` is run with the ``--autogenerate`` feature:
     
index 80e53b2564c2e42153c0c52e3b7e3a35f6ea3a23..1a65cd22253de529f0d3cf4af549a14f8e5085d9 100644 (file)
@@ -9,11 +9,6 @@ from sqlalchemy.engine import url as sqla_url
 import logging
 log = logging.getLogger(__name__)
 
-_meta = MetaData()
-_version = Table('alembic_version', _meta, 
-                Column('version_num', String(32), nullable=False)
-            )
-
 class MigrationContext(object):
     """Represent the database state made available to a migration 
     script.
@@ -81,6 +76,11 @@ class MigrationContext(object):
                                             'compare_server_default', 
                                             False)
 
+        version_table = opts.get('version_table', 'alembic_version')
+        self._version = Table(
+            version_table, MetaData(),
+            Column('version_num', String(32), nullable=False))
+
         self._start_from_rev = opts.get("starting_rev")
         self.impl = ddl.DefaultImpl.get_by_dialect(dialect)(
                             dialect, self.connection, self.as_sql,
@@ -152,8 +152,8 @@ class MigrationContext(object):
                 raise util.CommandError(
                     "Can't specify current_rev to context "
                     "when using a database connection")
-            _version.create(self.connection, checkfirst=True)
-        return self.connection.scalar(_version.select())
+            self._version.create(self.connection, checkfirst=True)
+        return self.connection.scalar(self._version.select())
 
     _current_rev = get_current_revision
     """The 0.2 method name, for backwards compat."""
@@ -162,13 +162,13 @@ class MigrationContext(object):
         if old == new:
             return
         if new is None:
-            self.impl._exec(_version.delete())
+            self.impl._exec(self._version.delete())
         elif old is None:
-            self.impl._exec(_version.insert().
+            self.impl._exec(self._version.insert().
                         values(version_num=literal_column("'%s'" % new))
                     )
         else:
-            self.impl._exec(_version.update().
+            self.impl._exec(self._version.update().
                         values(version_num=literal_column("'%s'" % new))
                     )
 
@@ -201,7 +201,7 @@ class MigrationContext(object):
             if current_rev is False:
                 current_rev = prev_rev
                 if self.as_sql and not current_rev:
-                    _version.create(self.connection)
+                    self._version.create(self.connection)
             log.info("Running %s %s -> %s", change.__name__, prev_rev, rev)
             if self.as_sql:
                 self.impl.static_output(
@@ -218,7 +218,7 @@ class MigrationContext(object):
                 self._update_current_rev(current_rev, rev)
 
             if self.as_sql and not rev:
-                _version.drop(self.connection)
+                self._version.drop(self.connection)
 
     def execute(self, sql):
         """Execute a SQL construct or string statement.
diff --git a/tests/test_version_table.py b/tests/test_version_table.py
new file mode 100644 (file)
index 0000000..0343cb1
--- /dev/null
@@ -0,0 +1,87 @@
+import unittest
+
+from sqlalchemy import Table, MetaData, Column, String, create_engine
+from sqlalchemy.engine.reflection import Inspector
+
+from alembic.util import CommandError
+
+version_table = Table('version_table', MetaData(),
+                      Column('version_num', String(32), nullable=False))
+
+class TestMigrationContext(unittest.TestCase):
+    _bind = []
+
+    @property
+    def bind(self):
+        if not self._bind:
+            engine = create_engine('sqlite:///', echo=True)
+            self._bind.append(engine)
+        return self._bind[0]
+
+    def setUp(self):
+        self.connection = self.bind.connect()
+        self.transaction = self.connection.begin()
+
+    def tearDown(self):
+        version_table.drop(self.connection, checkfirst=True)
+        self.transaction.rollback()
+
+    def make_one(self, **kwargs):
+        from alembic.migration import MigrationContext
+        return MigrationContext.configure(**kwargs)
+
+    def get_revision(self):
+        result = self.connection.execute(version_table.select())
+        rows = result.fetchall()
+        if len(rows) == 0:
+            return None
+        self.assertEqual(len(rows), 1)
+        return rows[0]['version_num']
+
+    def test_config_default_version_table_name(self):
+        context = self.make_one(dialect_name='sqlite')
+        self.assertEqual(context._version.name, 'alembic_version')
+
+    def test_config_explicit_version_table_name(self):
+        context = self.make_one(dialect_name='sqlite',
+                                opts={'version_table': 'explicit'})
+        self.assertEqual(context._version.name, 'explicit')
+
+    def test_get_current_revision_creates_version_table(self):
+        context = self.make_one(connection=self.connection,
+                                opts={'version_table': 'version_table'})
+        self.assertEqual(context.get_current_revision(), None)
+        insp = Inspector(self.connection)
+        self.assertTrue('version_table' in insp.get_table_names())
+
+    def test_get_current_revision(self):
+        context = self.make_one(connection=self.connection,
+                                opts={'version_table': 'version_table'})
+        version_table.create(self.connection)
+        self.assertEqual(context.get_current_revision(), None)
+        self.connection.execute(
+            version_table.insert().values(version_num='revid'))
+        self.assertEqual(context.get_current_revision(), 'revid')
+
+    def test_get_current_revision_error_if_starting_rev_given_online(self):
+        context = self.make_one(connection=self.connection,
+                                opts={'starting_rev': 'boo'})
+        self.assertRaises(CommandError, context.get_current_revision)
+
+    def test_get_current_revision_offline(self):
+        context = self.make_one(dialect_name='sqlite',
+                                opts={'starting_rev': 'startrev',
+                                      'as_sql': True})
+        self.assertEqual(context.get_current_revision(), 'startrev')
+
+    def test__update_current_rev(self):
+        version_table.create(self.connection)
+        context = self.make_one(connection=self.connection,
+                                opts={'version_table': 'version_table'})
+
+        context._update_current_rev(None, 'a')
+        self.assertEqual(self.get_revision(), 'a')
+        context._update_current_rev('a', 'b')
+        self.assertEqual(self.get_revision(), 'b')
+        context._update_current_rev('b', None)
+        self.assertEqual(self.get_revision(), None)