]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
this is all tests passing with the refactor, which IMHO is
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 24 Jan 2012 18:42:43 +0000 (13:42 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 24 Jan 2012 18:42:43 +0000 (13:42 -0500)
miraculous

12 files changed:
alembic/autogenerate.py
alembic/command.py
alembic/config.py
alembic/environment.py
alembic/migration.py
alembic/operations.py
alembic/script.py
tests/__init__.py
tests/test_autogenerate.py
tests/test_postgresql.py
tests/test_revision_paths.py
tests/test_sql_script.py

index 728d1e22c4e6ceb2181b54b605616b3cfffe617f..d90114a398db73de003f819b8544ca9175f79be5 100644 (file)
@@ -13,7 +13,8 @@ log = logging.getLogger(__name__)
 # top level
 
 
-def produce_migration_diffs(context, opts, template_args, imports):
+def produce_migration_diffs(context, template_args, imports):
+    opts = context.opts
     metadata = opts['target_metadata']
     if metadata is None:
         raise util.CommandError(
@@ -22,7 +23,7 @@ def produce_migration_diffs(context, opts, template_args, imports):
                 "a MetaData object to the context." % (
                     context._script.env_py_location
                 ))
-    connection = get_bind()
+    connection = context.bind
     diffs = []
     autogen_context = {
         'imports':imports,
@@ -308,7 +309,7 @@ def _add_table(table, autogen_context):
         'args':',\n'.join(
             [_render_column(col, autogen_context) for col in table.c] +
             sorted([rcons for rcons in 
-                [_render_constraint(cons) for cons in 
+                [_render_constraint(cons, autogen_context) for cons in 
                     table.constraints]
                 if rcons is not None
             ])
@@ -420,14 +421,14 @@ def _repr_type(prefix, type_, autogen_context):
     else:
         return "%s%r" % (prefix, type_)
 
-def _render_constraint(constraint):
+def _render_constraint(constraint, autogen_context):
     renderer = _constraint_renderers.get(type(constraint), None)
     if renderer:
-        return renderer(constraint)
+        return renderer(constraint, autogen_context)
     else:
         return None
 
-def _render_primary_key(constraint):
+def _render_primary_key(constraint, autogen_context):
     opts = []
     if constraint.name:
         opts.append(("name", repr(constraint.name)))
@@ -439,7 +440,7 @@ def _render_primary_key(constraint):
         ),
     }
 
-def _render_foreign_key(constraint):
+def _render_foreign_key(constraint, autogen_context):
     opts = []
     if constraint.name:
         opts.append(("name", repr(constraint.name)))
index b743f3b55c2aa135f8c542e8794bfe1427b1d056..a8c0fc42682d6e8a118c513043c207f4bd2152ed 100644 (file)
@@ -67,10 +67,10 @@ def revision(config, message=None, autogenerate=False):
     imports = set()
     if autogenerate:
         util.requires_07("autogenerate")
-        def retrieve_migrations(rev):
+        def retrieve_migrations(rev, context):
             if script._get_rev(rev) is not script._get_rev("head"):
                 raise util.CommandError("Target database is not up to date.")
-            autogen.produce_migration_diffs(template_args, imports)
+            autogen.produce_migration_diffs(context, template_args, imports)
             return []
 
         with environment.configure(
@@ -150,7 +150,7 @@ def current(config):
     """Display the current revision for each database."""
 
     script = ScriptDirectory.from_config(config)
-    def display_version(rev):
+    def display_version(rev, context):
         print "Current revision for %s: %s" % (
                             util.obfuscate_url_pw(
                                 context.get_context().connection.engine.url),
@@ -169,15 +169,15 @@ def stamp(config, revision, sql=False, tag=None):
     run any migrations."""
 
     script = ScriptDirectory.from_config(config)
-    def do_stamp(rev):
+    def do_stamp(rev, context):
         if sql:
             current = False
         else:
-            current = context.get_context()._current_rev()
+            current = context._current_rev()
         dest = script._get_rev(revision)
         if dest is not None:
             dest = dest.revision
-        context.get_context()._update_current_rev(current, dest)
+        context._update_current_rev(current, dest)
         return []
     with environment.configure(
         config, 
@@ -186,7 +186,7 @@ def stamp(config, revision, sql=False, tag=None):
         as_sql = sql,
         destination_rev = revision,
         tag = tag
-    ):
+    ) as env:
         script.run_env()
 
 def splice(config, parent, child):
index 22a48749fed0c83b3992f141537f98d5c5269f47..1dc6eb95c268016b477d37e651e53c2eb4fa2bf8 100644 (file)
@@ -107,6 +107,9 @@ class Config(object):
         """
         self.file_config.set(self.config_ini_section, name, value)
 
+    def remove_main_option(self, name):
+        self.file_config.remove_option(self.config_ini_section, name)
+
     def set_section_option(self, section, name, value):
         """Set an option programmatically within the given section.
         
index 125187989c8d3428718ef85b229c9afaff50cb31..53c35594ef0528895ea170f9ceb5a4fa6ec1a059 100644 (file)
@@ -2,17 +2,20 @@ import alembic
 from alembic.operations import Operations
 from alembic.migration import MigrationContext
 from alembic import util
-from sqlalchemy.engine import url as sqla_url
+from contextlib import contextmanager
 
 class EnvironmentContext(object):
     """Represent the state made available to an env.py script."""
 
     _migration_context = None
+    _default_opts = None
 
     def __init__(self, config, script, **kw):
         self.config = config
         self.script = script
         self.context_opts = kw
+        if self._default_opts:
+            self.context_opts.update(self._default_opts)
 
     def __enter__(self):
         """Establish a context which provides a 
@@ -264,18 +267,6 @@ class EnvironmentContext(object):
          one step.
 
         """
-
-        if connection:
-            dialect = connection.dialect
-        elif url:
-            url = sqla_url.make_url(url)
-            dialect = url.get_dialect()()
-        elif dialect_name:
-            url = sqla_url.make_url("%s://" % dialect_name)
-            dialect = url.get_dialect()()
-        else:
-            raise Exception("Connection, url, or dialect_name is required.")
-
         opts = self.context_opts
         if transactional_ddl is not None:
             opts["transactional_ddl"] =  transactional_ddl
@@ -292,19 +283,19 @@ class EnvironmentContext(object):
         opts['downgrade_token'] = downgrade_token
         opts['sqlalchemy_module_prefix'] = sqlalchemy_module_prefix
         opts['alembic_module_prefix'] = alembic_module_prefix
+        if compare_type is not None:
+            opts['compare_type'] = compare_type
+        if compare_server_default is not None:
+            opts['compare_server_default'] = compare_server_default
+        opts['script'] = self.script
         opts.update(kw)
 
-        self._migration_context = MigrationContext(
-                            dialect, self.script, connection, 
-                            opts,
-                            as_sql=opts.get('as_sql', False), 
-                            output_buffer=opts.get("output_buffer"),
-                            transactional_ddl=opts.get("transactional_ddl"),
-                            starting_rev=opts.get("starting_rev"),
-                            compare_type=compare_type,
-                            compare_server_default=compare_server_default,
-                        )
-        alembic.op._proxy = Operations(self._migration_context)
+        self._migration_context = MigrationContext.configure(
+            connection=connection,
+            url=url,
+            dialect_name=dialect_name,
+            opts=opts
+        )
 
     def run_migrations(self, **kw):
         """Run migrations as determined by the current command line configuration
@@ -324,7 +315,8 @@ class EnvironmentContext(object):
         made available via :func:`.configure`.
 
         """
-        self.migration_context.run_migrations(**kw)
+        with Operations.context(self._migration_context):
+            self.migration_context.run_migrations(**kw)
 
     def execute(self, sql):
         """Execute the given SQL using the current change context.
index 69e69304bf842c05ed5ac1d77ceade3a679831a5..733727d21ab749c9a389df93768b2687ed195282 100644 (file)
@@ -4,7 +4,7 @@ from sqlalchemy import MetaData, Table, Column, String, literal_column, \
 from sqlalchemy import create_engine
 from alembic import ddl
 import sys
-from contextlib import contextmanager
+from sqlalchemy.engine import url as sqla_url
 
 import logging
 log = logging.getLogger(__name__)
@@ -21,22 +21,19 @@ class MigrationContext(object):
     Mediates the relationship between an ``env.py`` environment script, 
     a :class:`.ScriptDirectory` instance, and a :class:`.DefaultImpl` instance.
 
-    The :class:`.Context` is available directly via the :func:`.get_context` function,
+    The :class:`.MigrationContext` is available directly via the :func:`.get_context` function,
     though usually it is referenced behind the scenes by the various module level functions
     within the :mod:`alembic.context` module.
 
     """
-    def __init__(self, dialect, script, connection, 
-                        opts,
-                        as_sql=False, 
-                        output_buffer=None,
-                        transactional_ddl=None,
-                        starting_rev=None,
-                        compare_type=False,
-                        compare_server_default=False):
+    def __init__(self, dialect, connection, opts):
+        self.opts = opts
         self.dialect = dialect
-        # TODO: need this ?
-        self.script = script
+        self.script = opts.get('script')
+
+        as_sql=opts.get('as_sql', False)
+        transactional_ddl=opts.get("transactional_ddl")
+
         if as_sql:
             self.connection = self._stdout_connection(connection)
             assert self.connection is not None
@@ -44,12 +41,12 @@ class MigrationContext(object):
             self.connection = connection
         self._migrations_fn = opts.get('fn')
         self.as_sql = as_sql
-        self.output_buffer = output_buffer if output_buffer else sys.stdout
+        self.output_buffer = opts.get("output_buffer", sys.stdout)
 
-        self._user_compare_type = compare_type
-        self._user_compare_server_default = compare_server_default
+        self._user_compare_type = opts.get('compare_type', False)
+        self._user_compare_server_default = opts.get('compare_server_default', False)
 
-        self._start_from_rev = starting_rev
+        self._start_from_rev = opts.get("starting_rev")
         self.impl = ddl.DefaultImpl.get_by_dialect(dialect)(
                             dialect, self.connection, self.as_sql,
                             transactional_ddl,
@@ -63,6 +60,46 @@ class MigrationContext(object):
                         "transactional" if self.impl.transactional_ddl 
                         else "non-transactional")
 
+    @classmethod
+    def configure(cls,
+                connection=None,
+                url=None,
+                dialect_name=None,
+                opts=None,
+    ):
+        """Create a new :class:`.MigrationContext`.
+        
+        This is a factory method usually called
+        by :meth:`.EnvironmentContext.configure`.
+        
+        :param connection: a :class:`~sqlalchemy.engine.base.Connection` to use
+         for SQL execution in "online" mode.  When present, is also used to 
+         determine the type of dialect in use.
+        :param url: a string database url, or a :class:`sqlalchemy.engine.url.URL` object.
+         The type of dialect to be used will be derived from this if ``connection`` is
+         not passed.
+        :param dialect_name: string name of a dialect, such as "postgresql", "mssql", etc.
+         The type of dialect to be used will be derived from this if ``connection``
+         and ``url`` are not passed.
+        :param opts: dictionary of options.  Most other options
+         accepted by :meth:`.EnvironmentContext.configure` are passed via 
+         this dictionary.
+
+        """
+        if connection:
+            dialect = connection.dialect
+        elif url:
+            url = sqla_url.make_url(url)
+            dialect = url.get_dialect()()
+        elif dialect_name:
+            url = sqla_url.make_url("%s://" % dialect_name)
+            dialect = url.get_dialect()()
+        else:
+            raise Exception("Connection, url, or dialect_name is required.")
+
+        return MigrationContext(dialect, connection, opts)
+
+
     def _current_rev(self):
         if self.as_sql:
             return self._start_from_rev
@@ -93,7 +130,8 @@ class MigrationContext(object):
         current_rev = rev = False
         self.impl.start_migrations()
         for change, prev_rev, rev in self._migrations_fn(
-                                        self._current_rev()):
+                                        self._current_rev(),
+                                        self):
             if current_rev is False:
                 current_rev = prev_rev
                 if self.as_sql and not current_rev:
index cc2ef48b93fac27f5a437d4664b95d1dad79b980..a6530cf30d4fd8de739f26298d469cb5c4cc8bc5 100644 (file)
@@ -2,6 +2,8 @@ from alembic import util
 from alembic.ddl import impl
 from sqlalchemy.types import NULLTYPE, Integer
 from sqlalchemy import schema, sql
+from contextlib import contextmanager
+import alembic
 
 __all__ = sorted([
             'alter_column', 
@@ -34,6 +36,14 @@ class Operations(object):
         self.migration_context = migration_context
         self.impl = migration_context.impl
 
+    @classmethod
+    @contextmanager
+    def context(cls, migration_context):
+        op = Operations(migration_context)
+        alembic.op._proxy = op
+        yield op
+        del alembic.op._proxy
+
     def _foreign_key_constraint(self, name, source, referent, local_cols, remote_cols):
         m = schema.MetaData()
         t1 = schema.Table(source, m, 
index 4b7eaf25dd8602eec7943bb9e100700fa0cda81d..7c3bb0fc663600bee9a6d0d636d1913c38aeed67 100644 (file)
@@ -88,14 +88,14 @@ class ScriptDirectory(object):
             if script is None and lower is not None:
                 raise util.CommandError("Couldn't find revision %s" % downrev)
 
-    def upgrade_from(self, destination, current_rev):
+    def upgrade_from(self, destination, current_rev, context):
         revs = self._revs(destination, current_rev)
         return [
             (script.module.upgrade, script.down_revision, script.revision) for script in 
             reversed(list(revs))
             ]
 
-    def downgrade_to(self, destination, current_rev):
+    def downgrade_to(self, destination, current_rev, context):
         revs = self._revs(current_rev, destination)
         return [
             (script.module.downgrade, script.revision, script.down_revision) for script in 
index 328040a4821aae2da30470136e9c2481ad0e9b9b..7e3e4b9f400296143460d8f6b935b7d7730019bf 100644 (file)
@@ -7,6 +7,7 @@ import itertools
 from sqlalchemy import create_engine, text, MetaData
 from alembic import util
 from alembic.migration import MigrationContext
+from alembic.environment import EnvironmentContext
 import re
 import alembic
 from alembic.operations import Operations
@@ -84,17 +85,16 @@ def capture_context_buffer(**kw):
 
     class capture(object):
         def __enter__(self):
-            context.configure(
-                dialect_name="sqlite",
-                output_buffer = buf,
-                **kw
-            )
+            EnvironmentContext._default_opts = {
+                'dialect_name':"sqlite",
+                'output_buffer':buf
+            }
+            EnvironmentContext._default_opts.update(kw)
             return buf
 
         def __exit__(self, *arg, **kwarg):
             print buf.getvalue()
-            for k in kw:
-                context._context_opts.pop(k, None)
+            EnvironmentContext._default_opts = None
 
     return capture()
 
index 0e913cf541ea59310749ccad9075b0f5276bc564..2b6cd3a5735dc6e20159194ec63e9a377b850fe2 100644 (file)
@@ -2,7 +2,8 @@ from sqlalchemy import MetaData, Column, Table, Integer, String, Text, \
     Numeric, CHAR, ForeignKey, DATETIME, TypeDecorator
 from sqlalchemy.types import NULLTYPE
 from sqlalchemy.engine.reflection import Inspector
-from alembic import autogenerate, context
+from alembic import autogenerate
+from alembic.migration import MigrationContext
 from unittest import TestCase
 from tests import staging_env, sqlite_db, clear_staging_env, eq_, \
         eq_ignore_whitespace, requires_07
@@ -76,18 +77,26 @@ class AutogenerateDiffTest(TestCase):
         cls.m1 = _model_one()
         cls.m1.create_all(cls.bind)
         cls.m2 = _model_two()
-        context.configure(
+
+        cls.context = context = MigrationContext.configure(
             connection = cls.bind.connect(),
-            compare_type = True,
-            compare_server_default = True,
-            target_metadata=cls.m2
+            opts = {
+                'compare_type':True,
+                'compare_server_default':True,
+                'target_metadata':cls.m2,
+                'upgrade_token':"upgrades",
+                'downgrade_token':"downgrades",
+                'alembic_module_prefix':'op.',
+                'sqlalchemy_module_prefix':'sa.'
+            }
         )
-        connection = context.get_bind()
+
+        connection = context.bind
         cls.autogen_context = {
             'imports':set(),
             'connection':connection,
             'dialect':connection.dialect,
-            'context':context.get_context()
+            'context':context
             }
 
     @classmethod
@@ -98,7 +107,7 @@ class AutogenerateDiffTest(TestCase):
         """test generation of diff rules"""
 
         metadata = self.m2
-        connection = context.get_bind()
+        connection = self.context.bind
         diffs = []
         autogenerate._produce_net_changes(connection, metadata, diffs, 
                                         self.autogen_context)
@@ -140,14 +149,18 @@ class AutogenerateDiffTest(TestCase):
 
 
     def test_render_nothing(self):
-        context.configure(
+        context = MigrationContext.configure(
             connection = self.bind.connect(),
-            compare_type = True,
-            compare_server_default = True,
-            target_metadata=self.m1
+            opts = {
+                'compare_type' : True,
+                'compare_server_default' : True,
+                'target_metadata' : self.m1,
+                'upgrade_token':"upgrades",
+                'downgrade_token':"downgrades",
+            }
         )
         template_args = {}
-        autogenerate.produce_migration_diffs(template_args, self.autogen_context)
+        autogenerate.produce_migration_diffs(context, template_args, set())
         eq_(re.sub(r"u'", "'", template_args['upgrades']),
 """### commands auto generated by Alembic - please adjust! ###
     pass
@@ -162,7 +175,7 @@ class AutogenerateDiffTest(TestCase):
 
         metadata = self.m2
         template_args = {}
-        autogenerate.produce_migration_diffs(template_args, self.autogen_context)
+        autogenerate.produce_migration_diffs(self.context, template_args, set())
         eq_(re.sub(r"u'", "'", template_args['upgrades']),
 """### commands auto generated by Alembic - please adjust! ###
     op.create_table('item',
@@ -273,8 +286,12 @@ class AutogenRenderTest(TestCase):
     @classmethod
     @requires_07
     def setup_class(cls):
-        context._context_opts['sqlalchemy_module_prefix'] = 'sa.'
-        context._context_opts['alembic_module_prefix'] = 'op.'
+        cls.autogen_context = {
+            'opts':{
+                'sqlalchemy_module_prefix' : 'sa.',
+                'alembic_module_prefix' : 'op.'
+            }
+        }
 
     def test_render_table_upgrade(self):
         m = MetaData()
@@ -285,7 +302,7 @@ class AutogenRenderTest(TestCase):
             Column("amount", Numeric(5, 2)),
         )
         eq_ignore_whitespace(
-            autogenerate._add_table(t, {}),
+            autogenerate._add_table(t, self.autogen_context),
             "op.create_table('test',"
             "sa.Column('id', sa.Integer(), nullable=False),"
             "sa.Column('address_id', sa.Integer(), nullable=True),"
@@ -300,14 +317,14 @@ class AutogenRenderTest(TestCase):
 
     def test_render_drop_table(self):
         eq_(
-            autogenerate._drop_table(Table("sometable", MetaData()), {}),
+            autogenerate._drop_table(Table("sometable", MetaData()), self.autogen_context),
             "op.drop_table('sometable')"
         )
 
     def test_render_add_column(self):
         eq_(
             autogenerate._add_column(
-                    "foo", Column("x", Integer, server_default="5"), {}),
+                    "foo", Column("x", Integer, server_default="5"), self.autogen_context),
             "op.add_column('foo', sa.Column('x', sa.Integer(), "
                 "server_default='5', nullable=True))"
         )
@@ -315,7 +332,7 @@ class AutogenRenderTest(TestCase):
     def test_render_drop_column(self):
         eq_(
             autogenerate._drop_column(
-                    "foo", Column("x", Integer, server_default="5"), {}),
+                    "foo", Column("x", Integer, server_default="5"), self.autogen_context),
 
             "op.drop_column('foo', 'x')"
         )
@@ -324,7 +341,7 @@ class AutogenRenderTest(TestCase):
         eq_ignore_whitespace(
             autogenerate._modify_col(
                         "sometable", "somecolumn", 
-                        {},
+                        self.autogen_context,
                         type_=CHAR(10), existing_type=CHAR(20)),
             "op.alter_column('sometable', 'somecolumn', "
                 "existing_type=sa.CHAR(length=20), type_=sa.CHAR(length=10))"
@@ -334,7 +351,7 @@ class AutogenRenderTest(TestCase):
         eq_ignore_whitespace(
             autogenerate._modify_col(
                         "sometable", "somecolumn", 
-                        {},
+                        self.autogen_context,
                         existing_type=Integer(),
                         nullable=True),
             "op.alter_column('sometable', 'somecolumn', "
@@ -345,7 +362,7 @@ class AutogenRenderTest(TestCase):
         eq_ignore_whitespace(
             autogenerate._modify_col(
                         "sometable", "somecolumn", 
-                        {},
+                        self.autogen_context,
                         existing_type=Integer(),
                         existing_server_default="5",
                         nullable=True),
index 46cd81d613afeca40c46911793bf483cbac46bc7..feb867d8613168ee9f1b7eb05dbf36a326a0d3fc 100644 (file)
@@ -5,7 +5,8 @@ from tests import op_fixture, db_for_dialect, eq_, staging_env, \
 from unittest import TestCase
 from sqlalchemy import DateTime, MetaData, Table, Column, text, Integer, String
 from sqlalchemy.engine.reflection import Inspector
-from alembic import context, command, util
+from alembic import command, util
+from alembic.migration import MigrationContext
 from alembic.script import ScriptDirectory
 
 class PGOfflineEnumTest(TestCase):
@@ -27,37 +28,37 @@ class PGOfflineEnumTest(TestCase):
         self.script.write(self.rid, """
 down_revision = None
 
-from alembic.op import *
+from alembic import op
 from sqlalchemy.dialects.postgresql import ENUM
 from sqlalchemy import Column
 
 def upgrade():
-    create_table("sometable", 
+    op.create_table("sometable", 
         Column("data", ENUM("one", "two", "three", name="pgenum"))
     )
 
 def downgrade():
-    drop_table("sometable")
+    op.drop_table("sometable")
 """)
 
     def _distinct_enum_script(self):
         self.script.write(self.rid, """
 down_revision = None
 
-from alembic.op import *
+from alembic import op
 from sqlalchemy.dialects.postgresql import ENUM
 from sqlalchemy import Column
 
 def upgrade():
     enum = ENUM("one", "two", "three", name="pgenum", create_type=False)
-    enum.create(get_bind(), checkfirst=False)
-    create_table("sometable", 
+    enum.create(op.get_bind(), checkfirst=False)
+    op.create_table("sometable", 
         Column("data", enum)
     )
 
 def downgrade():
-    drop_table("sometable")
-    ENUM(name="pgenum").drop(get_bind(), checkfirst=False)
+    op.drop_table("sometable")
+    ENUM(name="pgenum").drop(op.get_bind(), checkfirst=False)
     
 """)
 
@@ -97,17 +98,19 @@ class PostgresqlDefaultCompareTest(TestCase):
     def setup_class(cls):
         cls.bind = db_for_dialect("postgresql")
         staging_env()
-        context.configure(
+        context = MigrationContext.configure(
             connection = cls.bind.connect(),
-            compare_type = True,
-            compare_server_default = True,
+            opts = {
+                'compare_type':True,
+                'compare_server_default':True
+            }
         )
-        connection = context.get_bind()
+        connection = context.bind
         cls.autogen_context = {
             'imports':set(),
             'connection':connection,
             'dialect':connection.dialect,
-            'context':context.get_context()
+            'context':context
             }
 
     @classmethod
@@ -145,7 +148,7 @@ class PostgresqlDefaultCompareTest(TestCase):
         t1.create(self.bind)
         insp = Inspector.from_engine(self.bind)
         cols = insp.get_columns(t1.name)
-        ctx = context.get_context()
+        ctx = self.autogen_context['context']
         return ctx.impl.compare_server_default(
             cols[0],
             col, 
index b4bff6e045651fe0d2659d842519f35ce247a6e7..fd09a85b3031dffafc0273f54c2f87149ef148ab 100644 (file)
@@ -19,7 +19,7 @@ def teardown():
 def test_upgrade_path():
 
     eq_(
-        env.upgrade_from(e.revision, c.revision),
+        env.upgrade_from(e.revision, c.revision, None),
         [
             (d.module.upgrade, c.revision, d.revision),
             (e.module.upgrade, d.revision, e.revision),
@@ -27,7 +27,7 @@ def test_upgrade_path():
     )
 
     eq_(
-        env.upgrade_from(c.revision, None),
+        env.upgrade_from(c.revision, None, None),
         [
             (a.module.upgrade, None, a.revision),
             (b.module.upgrade, a.revision, b.revision),
@@ -38,7 +38,7 @@ def test_upgrade_path():
 def test_downgrade_path():
 
     eq_(
-        env.downgrade_to(c.revision, e.revision),
+        env.downgrade_to(c.revision, e.revision, None),
         [
             (e.module.downgrade, e.revision, e.down_revision),
             (d.module.downgrade, d.revision, d.down_revision),
@@ -46,7 +46,7 @@ def test_downgrade_path():
     )
 
     eq_(
-        env.downgrade_to(None, c.revision),
+        env.downgrade_to(None, c.revision, None),
         [
             (c.module.downgrade, c.revision, c.down_revision),
             (b.module.downgrade, b.revision, b.down_revision),
index ab86f19bc1b85a04a4ebc40ce41d93d655b26bed..af9b1361ac7e5613c06dfd1153c1df7090fcae6a 100644 (file)
@@ -9,7 +9,8 @@ def setup():
     global cfg, env
     env = staging_env()
     cfg = _no_sql_testing_config()
-
+    cfg.set_main_option('dialect_name', 'sqlite')
+    cfg.remove_main_option('url')
     global a, b, c
     a, b, c = three_rev_fixture(cfg)