]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
- move towards sqlalchemy test base. autogenerate tests so far
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 13 Sep 2014 19:38:53 +0000 (15:38 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 13 Sep 2014 19:38:53 +0000 (15:38 -0400)
16 files changed:
.gitignore
alembic/testing/__init__.py [new file with mode: 0644]
alembic/testing/assertions.py [new file with mode: 0644]
alembic/testing/env.py [new file with mode: 0644]
alembic/testing/fixtures.py [new file with mode: 0644]
alembic/testing/mock.py [new file with mode: 0644]
alembic/testing/requirements.py [new file with mode: 0644]
docs/build/changelog.rst
setup.cfg
setup.py
test.cfg [deleted file]
tests/__init__.py
tests/conftest.py [new file with mode: 0755]
tests/requirements.py [new file with mode: 0644]
tests/test_autogen_indexes.py
tests/test_autogenerate.py

index a252ee8eeb0beb8a3d0b7fa247cdeb75ffc998d9..08756184e9dba597484e36a9cd49e48d3370430b 100644 (file)
@@ -8,4 +8,5 @@ alembic.ini
 .venv
 *.egg-info
 .coverage
+coverage.xml
 .tox
diff --git a/alembic/testing/__init__.py b/alembic/testing/__init__.py
new file mode 100644 (file)
index 0000000..33e14a2
--- /dev/null
@@ -0,0 +1,4 @@
+from .fixtures import TestBase
+from .assertions import eq_, ne_
+
+from sqlalchemy.testing import config
diff --git a/alembic/testing/assertions.py b/alembic/testing/assertions.py
new file mode 100644 (file)
index 0000000..89f32cd
--- /dev/null
@@ -0,0 +1,40 @@
+import re
+from sqlalchemy.engine import default
+from sqlalchemy.testing.assertions import eq_, ne_, is_, assert_raises_message
+from alembic.compat import text_type
+
+
+def eq_ignore_whitespace(a, b, msg=None):
+    a = re.sub(r'^\s+?|\n', "", a)
+    a = re.sub(r' {2,}', " ", a)
+    b = re.sub(r'^\s+?|\n', "", b)
+    b = re.sub(r' {2,}', " ", b)
+    assert a == b, msg or "%r != %r" % (a, b)
+
+
+def assert_compiled(element, assert_string, dialect=None):
+    dialect = _get_dialect(dialect)
+    eq_(
+        text_type(element.compile(dialect=dialect)).
+        replace("\n", "").replace("\t", ""),
+        assert_string.replace("\n", "").replace("\t", "")
+    )
+
+
+_dialects = {}
+
+
+def _get_dialect(name):
+    if name is None or name == 'default':
+        return default.DefaultDialect()
+    else:
+        try:
+            return _dialects[name]
+        except KeyError:
+            dialect_mod = getattr(
+                __import__('sqlalchemy.dialects.%s' % name).dialects, name)
+            _dialects[name] = d = dialect_mod.dialect()
+            if name == 'postgresql':
+                d.implicit_returning = True
+            return d
+
diff --git a/alembic/testing/env.py b/alembic/testing/env.py
new file mode 100644 (file)
index 0000000..493bde2
--- /dev/null
@@ -0,0 +1,248 @@
+#!coding: utf-8
+
+import io
+import os
+import re
+import shutil
+import textwrap
+
+from alembic.compat import u
+from alembic.script import Script, ScriptDirectory
+from alembic import util
+
+staging_directory = 'scratch'
+files_directory = 'files'
+
+
+def staging_env(create=True, template="generic", sourceless=False):
+    from alembic import command, script
+    cfg = _testing_config()
+    if create:
+        path = os.path.join(staging_directory, 'scripts')
+        if os.path.exists(path):
+            shutil.rmtree(path)
+        command.init(cfg, path)
+        if sourceless:
+            try:
+                # do an import so that a .pyc/.pyo is generated.
+                util.load_python_file(path, 'env.py')
+            except AttributeError:
+                # we don't have the migration context set up yet
+                # so running the .env py throws this exception.
+                # theoretically we could be using py_compiler here to
+                # generate .pyc/.pyo without importing but not really
+                # worth it.
+                pass
+            make_sourceless(os.path.join(path, "env.py"))
+
+    sc = script.ScriptDirectory.from_config(cfg)
+    return sc
+
+
+def clear_staging_env():
+    shutil.rmtree(staging_directory, True)
+
+
+def script_file_fixture(txt):
+    dir_ = os.path.join(staging_directory, 'scripts')
+    path = os.path.join(dir_, "script.py.mako")
+    with open(path, 'w') as f:
+        f.write(txt)
+
+
+def env_file_fixture(txt):
+    dir_ = os.path.join(staging_directory, 'scripts')
+    txt = """
+from alembic import context
+
+config = context.config
+""" + txt
+
+    path = os.path.join(dir_, "env.py")
+    pyc_path = util.pyc_file_from_path(path)
+    if os.access(pyc_path, os.F_OK):
+        os.unlink(pyc_path)
+
+    with open(path, 'w') as f:
+        f.write(txt)
+
+
+def _sqlite_testing_config(sourceless=False):
+    dir_ = os.path.join(staging_directory, 'scripts')
+    return _write_config_file("""
+[alembic]
+script_location = %s
+sqlalchemy.url = sqlite:///%s/foo.db
+sourceless = %s
+
+[loggers]
+keys = root
+
+[handlers]
+keys = console
+
+[logger_root]
+level = WARN
+handlers = console
+qualname =
+
+[handler_console]
+class = StreamHandler
+args = (sys.stderr,)
+level = NOTSET
+formatter = generic
+
+[formatters]
+keys = generic
+
+[formatter_generic]
+format = %%(levelname)-5.5s [%%(name)s] %%(message)s
+datefmt = %%H:%%M:%%S
+    """ % (dir_, dir_, "true" if sourceless else "false"))
+
+
+def _no_sql_testing_config(dialect="postgresql", directives=""):
+    """use a postgresql url with no host so that
+    connections guaranteed to fail"""
+    dir_ = os.path.join(staging_directory, 'scripts')
+    return _write_config_file("""
+[alembic]
+script_location = %s
+sqlalchemy.url = %s://
+%s
+
+[loggers]
+keys = root
+
+[handlers]
+keys = console
+
+[logger_root]
+level = WARN
+handlers = console
+qualname =
+
+[handler_console]
+class = StreamHandler
+args = (sys.stderr,)
+level = NOTSET
+formatter = generic
+
+[formatters]
+keys = generic
+
+[formatter_generic]
+format = %%(levelname)-5.5s [%%(name)s] %%(message)s
+datefmt = %%H:%%M:%%S
+
+""" % (dir_, dialect, directives))
+
+
+def _write_config_file(text):
+    cfg = _testing_config()
+    with open(cfg.config_file_name, 'w') as f:
+        f.write(text)
+    return cfg
+
+
+def _testing_config():
+    from alembic.config import Config
+    if not os.access(staging_directory, os.F_OK):
+        os.mkdir(staging_directory)
+    return Config(os.path.join(staging_directory, 'test_alembic.ini'))
+
+
+def write_script(
+        scriptdir, rev_id, content, encoding='ascii', sourceless=False):
+    old = scriptdir._revision_map[rev_id]
+    path = old.path
+
+    content = textwrap.dedent(content)
+    if encoding:
+        content = content.encode(encoding)
+    with open(path, 'wb') as fp:
+        fp.write(content)
+    pyc_path = util.pyc_file_from_path(path)
+    if os.access(pyc_path, os.F_OK):
+        os.unlink(pyc_path)
+    script = Script._from_path(scriptdir, path)
+    old = scriptdir._revision_map[script.revision]
+    if old.down_revision != script.down_revision:
+        raise Exception("Can't change down_revision "
+                        "on a refresh operation.")
+    scriptdir._revision_map[script.revision] = script
+    script.nextrev = old.nextrev
+
+    if sourceless:
+        make_sourceless(path)
+
+
+def make_sourceless(path):
+    # note that if -O is set, you'd see pyo files here,
+    # the pyc util function looks at sys.flags.optimize to handle this
+    pyc_path = util.pyc_file_from_path(path)
+    assert os.access(pyc_path, os.F_OK)
+
+    # look for a non-pep3147 path here.
+    # if not present, need to copy from __pycache__
+    simple_pyc_path = util.simple_pyc_file_from_path(path)
+
+    if not os.access(simple_pyc_path, os.F_OK):
+        shutil.copyfile(pyc_path, simple_pyc_path)
+    os.unlink(path)
+
+
+def three_rev_fixture(cfg):
+    a = util.rev_id()
+    b = util.rev_id()
+    c = util.rev_id()
+
+    script = ScriptDirectory.from_config(cfg)
+    script.generate_revision(a, "revision a", refresh=True)
+    write_script(script, a, """\
+"Rev A"
+revision = '%s'
+down_revision = None
+
+from alembic import op
+
+def upgrade():
+    op.execute("CREATE STEP 1")
+
+def downgrade():
+    op.execute("DROP STEP 1")
+
+""" % a)
+
+    script.generate_revision(b, "revision b", refresh=True)
+    write_script(script, b, u("""# coding: utf-8
+"Rev B, méil"
+revision = '%s'
+down_revision = '%s'
+
+from alembic import op
+
+def upgrade():
+    op.execute("CREATE STEP 2")
+
+def downgrade():
+    op.execute("DROP STEP 2")
+
+""") % (b, a), encoding="utf-8")
+
+    script.generate_revision(c, "revision c", refresh=True)
+    write_script(script, c, """\
+"Rev C"
+revision = '%s'
+down_revision = '%s'
+
+from alembic import op
+
+def upgrade():
+    op.execute("CREATE STEP 3")
+
+def downgrade():
+    op.execute("DROP STEP 3")
+
+""" % (c, b))
+    return a, b, c
diff --git a/alembic/testing/fixtures.py b/alembic/testing/fixtures.py
new file mode 100644 (file)
index 0000000..2b78639
--- /dev/null
@@ -0,0 +1,149 @@
+# coding: utf-8
+import io
+import os
+import re
+import shutil
+import textwrap
+
+from nose import SkipTest
+from sqlalchemy.engine import default
+from sqlalchemy import create_engine, text, MetaData
+from sqlalchemy.exc import SQLAlchemyError
+from sqlalchemy.util import decorator
+
+import alembic
+from alembic.compat import configparser
+from alembic import util
+from alembic.compat import string_types, text_type, u, py33
+from alembic.migration import MigrationContext
+from alembic.environment import EnvironmentContext
+from alembic.operations import Operations
+from alembic.script import ScriptDirectory, Script
+from alembic.ddl.impl import _impls
+from contextlib import contextmanager
+
+from sqlalchemy.testing.fixtures import TestBase
+from .assertions import _get_dialect, eq_
+
+testing_config = configparser.ConfigParser()
+testing_config.read(['test.cfg'])
+
+
+def capture_db():
+    buf = []
+
+    def dump(sql, *multiparams, **params):
+        buf.append(str(sql.compile(dialect=engine.dialect)))
+    engine = create_engine("postgresql://", strategy="mock", executor=dump)
+    return engine, buf
+
+_engs = {}
+
+
+@decorator
+def requires_08(fn, *arg, **kw):
+    if not util.sqla_08:
+        raise SkipTest("SQLAlchemy 0.8.0b2 or greater required")
+    return fn(*arg, **kw)
+
+
+@decorator
+def requires_09(fn, *arg, **kw):
+    if not util.sqla_09:
+        raise SkipTest("SQLAlchemy 0.9 or greater required")
+    return fn(*arg, **kw)
+
+
+@decorator
+def requires_092(fn, *arg, **kw):
+    if not util.sqla_092:
+        raise SkipTest("SQLAlchemy 0.9.2 or greater required")
+    return fn(*arg, **kw)
+
+
+@decorator
+def requires_094(fn, *arg, **kw):
+    if not util.sqla_094:
+        raise SkipTest("SQLAlchemy 0.9.4 or greater required")
+    return fn(*arg, **kw)
+
+
+@contextmanager
+def capture_context_buffer(**kw):
+    if kw.pop('bytes_io', False):
+        buf = io.BytesIO()
+    else:
+        buf = io.StringIO()
+
+    kw.update({
+        'dialect_name': "sqlite",
+        'output_buffer': buf
+    })
+    conf = EnvironmentContext.configure
+
+    def configure(*arg, **opt):
+        opt.update(**kw)
+        return conf(*arg, **opt)
+
+    with mock.patch.object(EnvironmentContext, "configure", configure):
+        yield buf
+
+
+def op_fixture(dialect='default', as_sql=False, naming_convention=None):
+    impl = _impls[dialect]
+
+    class Impl(impl):
+
+        def __init__(self, dialect, as_sql):
+            self.assertion = []
+            self.dialect = dialect
+            self.as_sql = as_sql
+            # TODO: this might need to
+            # be more like a real connection
+            # as tests get more involved
+            self.connection = None
+
+        def _exec(self, construct, *args, **kw):
+            if isinstance(construct, string_types):
+                construct = text(construct)
+            assert construct.supports_execution
+            sql = text_type(construct.compile(dialect=self.dialect))
+            sql = re.sub(r'[\n\t]', '', sql)
+            self.assertion.append(
+                sql
+            )
+
+    opts = {}
+    if naming_convention:
+        if not util.sqla_092:
+            raise SkipTest(
+                "naming_convention feature requires "
+                "sqla 0.9.2 or greater")
+        opts['target_metadata'] = MetaData(naming_convention=naming_convention)
+
+    class ctx(MigrationContext):
+
+        def __init__(self, dialect='default', as_sql=False):
+            self.dialect = _get_dialect(dialect)
+            self.impl = Impl(self.dialect, as_sql)
+            self.opts = opts
+            self.as_sql = as_sql
+
+        def assert_(self, *sql):
+            # TODO: make this more flexible about
+            # whitespace and such
+            eq_(self.impl.assertion, list(sql))
+
+        def assert_contains(self, sql):
+            for stmt in self.impl.assertion:
+                if sql in stmt:
+                    return
+            else:
+                assert False, "Could not locate fragment %r in %r" % (
+                    sql,
+                    self.impl.assertion
+                )
+    context = ctx(dialect, as_sql)
+    alembic.op._proxy = Operations(context)
+    return context
+
diff --git a/alembic/testing/mock.py b/alembic/testing/mock.py
new file mode 100644 (file)
index 0000000..f8162f8
--- /dev/null
@@ -0,0 +1,2 @@
+from sqlalchemy.testing import mock
+from sqlalchemy.testing.mock import Mock, call, patch
diff --git a/alembic/testing/requirements.py b/alembic/testing/requirements.py
new file mode 100644 (file)
index 0000000..2d2b678
--- /dev/null
@@ -0,0 +1,12 @@
+from sqlalchemy.testing.requirements import Requirements
+from sqlalchemy.testing import exclusions
+
+
+class SuiteRequirements(Requirements):
+    @property
+    def schemas(self):
+        """Target database must support external schemas, and have one
+        named 'test_schema'."""
+
+        return exclusions.open()
+
index 6f23270cb367c5f4f17103c20df1124570670cad..3207e688a440418b765a11a7268a1877b662d226 100644 (file)
@@ -2,6 +2,14 @@
 ==========
 Changelog
 ==========
+.. changelog::
+    :version: 0.7.0
+
+    .. change::
+      :tags: change
+
+      Minimum SQLAlchemy version is now 0.8.4.
+
 .. changelog::
     :version: 0.6.7
     :released: September 9, 2014
index 0052af881d4838b591d3bf9f1b3c44e7ae330813..ab40b6b82bcf4700b6a4777b01bef148f3d774dd 100644 (file)
--- a/setup.cfg
+++ b/setup.cfg
@@ -11,6 +11,22 @@ upload-dir = docs/build/output/html
 sign = 1
 identity = C4DAFEE1
 
+
+[sqla_testing]
+requirement_cls=tests.requirements:DefaultRequirements
+profile_file=tests/profiles.txt
+
+
+[db]
+default=sqlite:///:memory:
+sqlite=sqlite:///:memory:
+sqlite_file=sqlite:///querytest.db
+postgresql=postgresql://scott:tiger@127.0.0.1:5432/test
+mysql=mysql+mysqlconnector://scott:tiger@127.0.0.1:3306/test
+mssql=mssql+pyodbc://scott:tiger@ms_2008
+oracle=oracle://scott:tiger@127.0.0.1:1521
+oracle8=oracle://scott:tiger@127.0.0.1:1521/?use_ansi=0
+
 [alembic]
 
 
index cf9542e8cbdf7b493ebb48bfd61a2b4575b9d8ea..3918034cb6aa0583cec12ef49e800db8c17dc872 100644 (file)
--- a/setup.py
+++ b/setup.py
@@ -11,7 +11,7 @@ v.close()
 readme = os.path.join(os.path.dirname(__file__), 'README.rst')
 
 requires = [
-    'SQLAlchemy>=0.7.3',
+    'SQLAlchemy>=0.8.4',
     'Mako',
 ]
 
diff --git a/test.cfg b/test.cfg
deleted file mode 100644 (file)
index 748e980..0000000
--- a/test.cfg
+++ /dev/null
@@ -1,9 +0,0 @@
-[db]
-postgresql = postgresql://scott:tiger@localhost/test
-mysql = mysql://scott:tiger@localhost/test
-mssql = mssql+pyodbc://scott:tiger@ms_2005/
-oracle=oracle://scott:tiger@172.16.248.129/xe
-sybase=sybase+pyodbc://scott:tiger7@sybase/
-firebird=firebird://scott:tiger@localhost/foo.gdb?type_conv=300
-oursql=mysql+oursql://scott:tiger@localhost/test
-pymssql=mssql+pymssql://scott:tiger@ms_2005/
index 3e54640781d381fb557c5166278166c9a98e427b..587f83dec63953cd2efb90400d30c19090d2fe74 100644 (file)
@@ -1,480 +1,2 @@
-# coding: utf-8
-import io
-import os
-import re
-import shutil
-import textwrap
+from alembic.testing.fixtures import *
 
-from nose import SkipTest
-from sqlalchemy.engine import default
-from sqlalchemy import create_engine, text, MetaData
-from sqlalchemy.exc import SQLAlchemyError
-from sqlalchemy.util import decorator
-
-import alembic
-from alembic.compat import configparser
-from alembic import util
-from alembic.compat import string_types, text_type, u, py33
-from alembic.migration import MigrationContext
-from alembic.environment import EnvironmentContext
-from alembic.operations import Operations
-from alembic.script import ScriptDirectory, Script
-from alembic.ddl.impl import _impls
-from contextlib import contextmanager
-
-staging_directory = os.path.join(os.path.dirname(__file__), 'scratch')
-files_directory = os.path.join(os.path.dirname(__file__), 'files')
-
-testing_config = configparser.ConfigParser()
-testing_config.read(['test.cfg'])
-
-if py33:
-    from unittest.mock import Mock, call, patch
-    from unittest import mock
-else:
-    try:
-        from mock import Mock, call, patch
-        import mock
-    except ImportError:
-        raise ImportError(
-            "Alembic's test suite requires the "
-            "'mock' library as of 0.6.1.")
-
-
-def sqlite_db():
-    # sqlite caches table pragma info
-    # per connection, so create a new
-    # engine for each assertion
-    dir_ = os.path.join(staging_directory, 'scripts')
-    return create_engine('sqlite:///%s/foo.db' % dir_)
-
-
-def capture_db():
-    buf = []
-
-    def dump(sql, *multiparams, **params):
-        buf.append(str(sql.compile(dialect=engine.dialect)))
-    engine = create_engine("postgresql://", strategy="mock", executor=dump)
-    return engine, buf
-
-_engs = {}
-
-
-def db_for_dialect(name):
-    if name in _engs:
-        return _engs[name]
-    else:
-        try:
-            cfg = testing_config.get("db", name)
-        except configparser.NoOptionError:
-            raise SkipTest("No dialect %r in test.cfg" % name)
-        try:
-            eng = create_engine(cfg, echo='debug')
-        except ImportError as er1:
-            raise SkipTest("Can't import DBAPI: %s" % er1)
-        try:
-            eng.connect()
-        except SQLAlchemyError as er2:
-            raise SkipTest("Can't connect to database: %s" % er2)
-        _engs[name] = eng
-        return eng
-
-
-@decorator
-def requires_08(fn, *arg, **kw):
-    if not util.sqla_08:
-        raise SkipTest("SQLAlchemy 0.8.0b2 or greater required")
-    return fn(*arg, **kw)
-
-
-@decorator
-def requires_09(fn, *arg, **kw):
-    if not util.sqla_09:
-        raise SkipTest("SQLAlchemy 0.9 or greater required")
-    return fn(*arg, **kw)
-
-
-@decorator
-def requires_092(fn, *arg, **kw):
-    if not util.sqla_092:
-        raise SkipTest("SQLAlchemy 0.9.2 or greater required")
-    return fn(*arg, **kw)
-
-
-@decorator
-def requires_094(fn, *arg, **kw):
-    if not util.sqla_094:
-        raise SkipTest("SQLAlchemy 0.9.4 or greater required")
-    return fn(*arg, **kw)
-
-_dialects = {}
-
-
-def _get_dialect(name):
-    if name is None or name == 'default':
-        return default.DefaultDialect()
-    else:
-        try:
-            return _dialects[name]
-        except KeyError:
-            dialect_mod = getattr(
-                __import__('sqlalchemy.dialects.%s' % name).dialects, name)
-            _dialects[name] = d = dialect_mod.dialect()
-            if name == 'postgresql':
-                d.implicit_returning = True
-            return d
-
-
-def assert_compiled(element, assert_string, dialect=None):
-    dialect = _get_dialect(dialect)
-    eq_(
-        text_type(element.compile(dialect=dialect)).
-        replace("\n", "").replace("\t", ""),
-        assert_string.replace("\n", "").replace("\t", "")
-    )
-
-
-@contextmanager
-def capture_context_buffer(**kw):
-    if kw.pop('bytes_io', False):
-        buf = io.BytesIO()
-    else:
-        buf = io.StringIO()
-
-    kw.update({
-        'dialect_name': "sqlite",
-        'output_buffer': buf
-    })
-    conf = EnvironmentContext.configure
-
-    def configure(*arg, **opt):
-        opt.update(**kw)
-        return conf(*arg, **opt)
-
-    with mock.patch.object(EnvironmentContext, "configure", configure):
-        yield buf
-
-
-def eq_ignore_whitespace(a, b, msg=None):
-    a = re.sub(r'^\s+?|\n', "", a)
-    a = re.sub(r' {2,}', " ", a)
-    b = re.sub(r'^\s+?|\n', "", b)
-    b = re.sub(r' {2,}', " ", b)
-    assert a == b, msg or "%r != %r" % (a, b)
-
-
-def eq_(a, b, msg=None):
-    """Assert a == b, with repr messaging on failure."""
-    assert a == b, msg or "%r != %r" % (a, b)
-
-
-def ne_(a, b, msg=None):
-    """Assert a != b, with repr messaging on failure."""
-    assert a != b, msg or "%r == %r" % (a, b)
-
-
-def is_(a, b, msg=None):
-    """Assert a is b, with repr messaging on failure."""
-    assert a is b, msg or "%r is not %r" % (a, b)
-
-
-def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
-    try:
-        callable_(*args, **kwargs)
-        assert False, "Callable did not raise an exception"
-    except except_cls as e:
-        assert re.search(msg, str(e)), "%r !~ %s" % (msg, e)
-        print(text_type(e))
-
-
-def op_fixture(dialect='default', as_sql=False, naming_convention=None):
-    impl = _impls[dialect]
-
-    class Impl(impl):
-
-        def __init__(self, dialect, as_sql):
-            self.assertion = []
-            self.dialect = dialect
-            self.as_sql = as_sql
-            # TODO: this might need to
-            # be more like a real connection
-            # as tests get more involved
-            self.connection = None
-
-        def _exec(self, construct, *args, **kw):
-            if isinstance(construct, string_types):
-                construct = text(construct)
-            assert construct.supports_execution
-            sql = text_type(construct.compile(dialect=self.dialect))
-            sql = re.sub(r'[\n\t]', '', sql)
-            self.assertion.append(
-                sql
-            )
-
-    opts = {}
-    if naming_convention:
-        if not util.sqla_092:
-            raise SkipTest(
-                "naming_convention feature requires "
-                "sqla 0.9.2 or greater")
-        opts['target_metadata'] = MetaData(naming_convention=naming_convention)
-
-    class ctx(MigrationContext):
-
-        def __init__(self, dialect='default', as_sql=False):
-            self.dialect = _get_dialect(dialect)
-            self.impl = Impl(self.dialect, as_sql)
-            self.opts = opts
-            self.as_sql = as_sql
-
-        def assert_(self, *sql):
-            # TODO: make this more flexible about
-            # whitespace and such
-            eq_(self.impl.assertion, list(sql))
-
-        def assert_contains(self, sql):
-            for stmt in self.impl.assertion:
-                if sql in stmt:
-                    return
-            else:
-                assert False, "Could not locate fragment %r in %r" % (
-                    sql,
-                    self.impl.assertion
-                )
-    context = ctx(dialect, as_sql)
-    alembic.op._proxy = Operations(context)
-    return context
-
-
-def script_file_fixture(txt):
-    dir_ = os.path.join(staging_directory, 'scripts')
-    path = os.path.join(dir_, "script.py.mako")
-    with open(path, 'w') as f:
-        f.write(txt)
-
-
-def env_file_fixture(txt):
-    dir_ = os.path.join(staging_directory, 'scripts')
-    txt = """
-from alembic import context
-
-config = context.config
-""" + txt
-
-    path = os.path.join(dir_, "env.py")
-    pyc_path = util.pyc_file_from_path(path)
-    if os.access(pyc_path, os.F_OK):
-        os.unlink(pyc_path)
-
-    with open(path, 'w') as f:
-        f.write(txt)
-
-
-def _sqlite_testing_config(sourceless=False):
-    dir_ = os.path.join(staging_directory, 'scripts')
-    return _write_config_file("""
-[alembic]
-script_location = %s
-sqlalchemy.url = sqlite:///%s/foo.db
-sourceless = %s
-
-[loggers]
-keys = root
-
-[handlers]
-keys = console
-
-[logger_root]
-level = WARN
-handlers = console
-qualname =
-
-[handler_console]
-class = StreamHandler
-args = (sys.stderr,)
-level = NOTSET
-formatter = generic
-
-[formatters]
-keys = generic
-
-[formatter_generic]
-format = %%(levelname)-5.5s [%%(name)s] %%(message)s
-datefmt = %%H:%%M:%%S
-    """ % (dir_, dir_, "true" if sourceless else "false"))
-
-
-def _no_sql_testing_config(dialect="postgresql", directives=""):
-    """use a postgresql url with no host so that
-    connections guaranteed to fail"""
-    dir_ = os.path.join(staging_directory, 'scripts')
-    return _write_config_file("""
-[alembic]
-script_location = %s
-sqlalchemy.url = %s://
-%s
-
-[loggers]
-keys = root
-
-[handlers]
-keys = console
-
-[logger_root]
-level = WARN
-handlers = console
-qualname =
-
-[handler_console]
-class = StreamHandler
-args = (sys.stderr,)
-level = NOTSET
-formatter = generic
-
-[formatters]
-keys = generic
-
-[formatter_generic]
-format = %%(levelname)-5.5s [%%(name)s] %%(message)s
-datefmt = %%H:%%M:%%S
-
-""" % (dir_, dialect, directives))
-
-
-def _write_config_file(text):
-    cfg = _testing_config()
-    with open(cfg.config_file_name, 'w') as f:
-        f.write(text)
-    return cfg
-
-
-def _testing_config():
-    from alembic.config import Config
-    if not os.access(staging_directory, os.F_OK):
-        os.mkdir(staging_directory)
-    return Config(os.path.join(staging_directory, 'test_alembic.ini'))
-
-
-def staging_env(create=True, template="generic", sourceless=False):
-    from alembic import command, script
-    cfg = _testing_config()
-    if create:
-        path = os.path.join(staging_directory, 'scripts')
-        if os.path.exists(path):
-            shutil.rmtree(path)
-        command.init(cfg, path)
-        if sourceless:
-            try:
-                # do an import so that a .pyc/.pyo is generated.
-                util.load_python_file(path, 'env.py')
-            except AttributeError:
-                # we don't have the migration context set up yet
-                # so running the .env py throws this exception.
-                # theoretically we could be using py_compiler here to
-                # generate .pyc/.pyo without importing but not really
-                # worth it.
-                pass
-            make_sourceless(os.path.join(path, "env.py"))
-
-    sc = script.ScriptDirectory.from_config(cfg)
-    return sc
-
-
-def clear_staging_env():
-    shutil.rmtree(staging_directory, True)
-
-
-def write_script(
-        scriptdir, rev_id, content, encoding='ascii', sourceless=False):
-    old = scriptdir._revision_map[rev_id]
-    path = old.path
-
-    content = textwrap.dedent(content)
-    if encoding:
-        content = content.encode(encoding)
-    with open(path, 'wb') as fp:
-        fp.write(content)
-    pyc_path = util.pyc_file_from_path(path)
-    if os.access(pyc_path, os.F_OK):
-        os.unlink(pyc_path)
-    script = Script._from_path(scriptdir, path)
-    old = scriptdir._revision_map[script.revision]
-    if old.down_revision != script.down_revision:
-        raise Exception("Can't change down_revision "
-                        "on a refresh operation.")
-    scriptdir._revision_map[script.revision] = script
-    script.nextrev = old.nextrev
-
-    if sourceless:
-        make_sourceless(path)
-
-
-def make_sourceless(path):
-    # note that if -O is set, you'd see pyo files here,
-    # the pyc util function looks at sys.flags.optimize to handle this
-    pyc_path = util.pyc_file_from_path(path)
-    assert os.access(pyc_path, os.F_OK)
-
-    # look for a non-pep3147 path here.
-    # if not present, need to copy from __pycache__
-    simple_pyc_path = util.simple_pyc_file_from_path(path)
-
-    if not os.access(simple_pyc_path, os.F_OK):
-        shutil.copyfile(pyc_path, simple_pyc_path)
-    os.unlink(path)
-
-
-def three_rev_fixture(cfg):
-    a = util.rev_id()
-    b = util.rev_id()
-    c = util.rev_id()
-
-    script = ScriptDirectory.from_config(cfg)
-    script.generate_revision(a, "revision a", refresh=True)
-    write_script(script, a, """\
-"Rev A"
-revision = '%s'
-down_revision = None
-
-from alembic import op
-
-def upgrade():
-    op.execute("CREATE STEP 1")
-
-def downgrade():
-    op.execute("DROP STEP 1")
-
-""" % a)
-
-    script.generate_revision(b, "revision b", refresh=True)
-    write_script(script, b, u("""# coding: utf-8
-"Rev B, méil"
-revision = '%s'
-down_revision = '%s'
-
-from alembic import op
-
-def upgrade():
-    op.execute("CREATE STEP 2")
-
-def downgrade():
-    op.execute("DROP STEP 2")
-
-""") % (b, a), encoding="utf-8")
-
-    script.generate_revision(c, "revision c", refresh=True)
-    write_script(script, c, """\
-"Rev C"
-revision = '%s'
-down_revision = '%s'
-
-from alembic import op
-
-def upgrade():
-    op.execute("CREATE STEP 3")
-
-def downgrade():
-    op.execute("DROP STEP 3")
-
-""" % (c, b))
-    return a, b, c
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100755 (executable)
index 0000000..1dd4423
--- /dev/null
@@ -0,0 +1,15 @@
+#!/usr/bin/env python
+"""
+pytest plugin script.
+
+This script is an extension to py.test which
+installs SQLAlchemy's testing plugin into the local environment.
+
+"""
+import sys
+
+from os import path
+for pth in ['../lib']:
+    sys.path.insert(0, path.join(path.dirname(path.abspath(__file__)), pth))
+
+from sqlalchemy.testing.plugin.pytestplugin import *
diff --git a/tests/requirements.py b/tests/requirements.py
new file mode 100644 (file)
index 0000000..b158dff
--- /dev/null
@@ -0,0 +1,14 @@
+from alembic.testing.requirements import SuiteRequirements
+from sqlalchemy.testing import exclusions
+
+
+class DefaultRequirements(SuiteRequirements):
+    @property
+    def schemas(self):
+        """Target database must support external schemas, and have one
+        named 'test_schema'."""
+
+        return exclusions.skip_if([
+            "sqlite",
+            "firebird"
+        ], "no schema support")
index 26a148ecdfdbff3012ff5785b09ed46e5d5bf8d8..934fc432658894aea978e7a24078f977371c3c34 100644 (file)
@@ -1,5 +1,6 @@
 import sys
-from unittest import TestCase
+from alembic.testing import TestBase
+from alembic.testing import config
 
 from sqlalchemy import MetaData, Column, Table, Integer, String, Text, \
     Numeric, DATETIME, INTEGER, \
@@ -8,14 +9,26 @@ from sqlalchemy import MetaData, Column, Table, Integer, String, Text, \
     PrimaryKeyConstraint, Index, func, ForeignKeyConstraint,\
     ForeignKey
 from sqlalchemy.schema import AddConstraint
-from . import sqlite_db, eq_, db_for_dialect
+from sqlalchemy.testing import engines
+from alembic.testing import eq_
+from alembic.testing.env import staging_env
 
 py3k = sys.version_info >= (3, )
 
 from .test_autogenerate import AutogenFixtureTest
 
 
-class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestCase):
+class NoUqReflection(object):
+    def setUp(self):
+        staging_env()
+        self.bind = eng = engines.testing_engine()
+
+        def unimpl(*arg, **kw):
+            raise NotImplementedError()
+        eng.dialect.get_unique_constraints = unimpl
+
+
+class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase):
     reports_unique_constraints = True
 
     def test_index_flag_becomes_named_unique_constraint(self):
@@ -435,10 +448,7 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestCase):
 
 class PGUniqueIndexTest(AutogenerateUniqueIndexTest):
     reports_unnamed_constraints = True
-
-    @classmethod
-    def _get_bind(cls):
-        return db_for_dialect('postgresql')
+    __only_on__ = "postgresql"
 
     def test_idx_added_schema(self):
         m1 = MetaData()
@@ -502,6 +512,7 @@ class PGUniqueIndexTest(AutogenerateUniqueIndexTest):
 
 class MySQLUniqueIndexTest(AutogenerateUniqueIndexTest):
     reports_unnamed_constraints = True
+    __only_on__ = 'mysql'
 
     def test_removed_idx_index_named_as_column(self):
         try:
@@ -512,22 +523,10 @@ class MySQLUniqueIndexTest(AutogenerateUniqueIndexTest):
         else:
             assert False, "unexpected success"
 
-    @classmethod
-    def _get_bind(cls):
-        return db_for_dialect('mysql')
 
-
-class NoUqReflectionIndexTest(AutogenerateUniqueIndexTest):
+class NoUqReflectionIndexTest(NoUqReflection, AutogenerateUniqueIndexTest):
     reports_unique_constraints = False
-
-    @classmethod
-    def _get_bind(cls):
-        eng = sqlite_db()
-
-        def unimpl(*arg, **kw):
-            raise NotImplementedError()
-        eng.dialect.get_unique_constraints = unimpl
-        return eng
+    __only_on__ = 'sqlite'
 
     def test_unique_not_reported(self):
         m1 = MetaData()
@@ -596,9 +595,11 @@ class NoUqReportsIndAsUqTest(NoUqReflectionIndexTest):
 
     """
 
+    __only_on__ = 'sqlite'
+
     @classmethod
     def _get_bind(cls):
-        eng = sqlite_db()
+        eng = config.db
 
         _get_unique_constraints = eng.dialect.get_unique_constraints
         _get_indexes = eng.dialect.get_indexes
index b8c4b2a36e3082a9545748ae32687db7bb02f8c0..8733150c8b794de6c6fee861099ad6a33a5a38ad 100644 (file)
@@ -1,7 +1,5 @@
 import re
 import sys
-from unittest import TestCase
-from . import Mock
 
 from sqlalchemy import MetaData, Column, Table, Integer, String, Text, \
     Numeric, CHAR, ForeignKey, DATETIME, INTEGER, \
@@ -13,8 +11,11 @@ from sqlalchemy.engine.reflection import Inspector
 
 from alembic import autogenerate
 from alembic.migration import MigrationContext
-from . import staging_env, sqlite_db, clear_staging_env, eq_, \
-    db_for_dialect
+from alembic.testing import TestBase
+from alembic.testing import config
+from alembic.testing.mock import Mock
+from alembic.testing.env import staging_env, clear_staging_env
+from alembic.testing import eq_
 
 py3k = sys.version_info >= (3, )
 
@@ -39,10 +40,11 @@ def new_table(table, parent):
 
 
 class AutogenTest(object):
+    __only_on__ = 'sqlite'
 
     @classmethod
     def _get_bind(cls):
-        return sqlite_db()
+        return config.db
 
     @classmethod
     def setup_class(cls):
@@ -120,24 +122,16 @@ class AutogenFixtureTest(object):
 
     def setUp(self):
         staging_env()
-        self.bind = self._get_bind()
+        self.bind = config.db
 
     def tearDown(self):
         if hasattr(self, 'metadata'):
             self.metadata.drop_all(self.bind)
         clear_staging_env()
 
-    @classmethod
-    def _get_bind(cls):
-        return sqlite_db()
-
-
-class AutogenCrossSchemaTest(AutogenTest, TestCase):
 
-    @classmethod
-    def _get_bind(cls):
-        cls.test_schema_name = "test_schema"
-        return db_for_dialect('postgresql')
+class AutogenCrossSchemaTest(AutogenTest, TestBase):
+    __only_on__ = 'postgresql'
 
     @classmethod
     def _get_db_schema(cls):
@@ -147,14 +141,14 @@ class AutogenCrossSchemaTest(AutogenTest, TestCase):
               )
         Table('t2', m,
               Column('y', Integer),
-              schema=cls.test_schema_name
+              schema=config.test_schema
               )
         Table('t6', m,
               Column('u', Integer)
               )
         Table('t7', m,
               Column('v', Integer),
-              schema=cls.test_schema_name
+              schema=config.test_schema
               )
 
         return m
@@ -167,14 +161,14 @@ class AutogenCrossSchemaTest(AutogenTest, TestCase):
               )
         Table('t4', m,
               Column('z', Integer),
-              schema=cls.test_schema_name
+              schema=config.test_schema
               )
         Table('t6', m,
               Column('u', Integer)
               )
         Table('t7', m,
               Column('v', Integer),
-              schema=cls.test_schema_name
+              schema=config.test_schema
               )
         return m
 
@@ -212,7 +206,7 @@ class AutogenCrossSchemaTest(AutogenTest, TestCase):
                                           include_schemas=True
                                           )
         eq_(diffs[0][0], "add_table")
-        eq_(diffs[0][1].schema, self.test_schema_name)
+        eq_(diffs[0][1].schema, config.test_schema)
 
     def test_default_schema_omitted_downgrade(self):
         metadata = self.m2
@@ -248,15 +242,11 @@ class AutogenCrossSchemaTest(AutogenTest, TestCase):
                                           include_schemas=True
                                           )
         eq_(diffs[0][0], "remove_table")
-        eq_(diffs[0][1].schema, self.test_schema_name)
+        eq_(diffs[0][1].schema, config.test_schema)
 
 
-class AutogenDefaultSchemaTest(AutogenFixtureTest, TestCase):
-
-    @classmethod
-    def _get_bind(cls):
-        cls.test_schema_name = "test_schema"
-        return db_for_dialect('postgresql')
+class AutogenDefaultSchemaTest(AutogenFixtureTest, TestBase):
+    __only_on__ = 'postgresql'
 
     def test_uses_explcit_schema_in_default_one(self):
 
@@ -377,7 +367,7 @@ class ModelOne(object):
         return m
 
 
-class AutogenerateDiffTest(ModelOne, AutogenTest, TestCase):
+class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase):
 
     def test_diffs(self):
         """test generation of diff rules"""
@@ -640,13 +630,10 @@ nullable=True))
         )
 
 
-class AutogenerateDiffTestWSchema(ModelOne, AutogenTest, TestCase):
+class AutogenerateDiffTestWSchema(ModelOne, AutogenTest, TestBase):
+    __only_on__ = 'postgresql'
     schema = "test_schema"
 
-    @classmethod
-    def _get_bind(cls):
-        return db_for_dialect('postgresql')
-
     def test_diffs(self):
         """test generation of diff rules"""
 
@@ -801,7 +788,7 @@ name='extra_uid_fkey'),
     ### end Alembic commands ###""" % {"schema": self.schema})
 
 
-class AutogenerateCustomCompareTypeTest(AutogenTest, TestCase):
+class AutogenerateCustomCompareTypeTest(AutogenTest, TestBase):
 
     @classmethod
     def _get_db_schema(cls):
@@ -868,7 +855,7 @@ class AutogenerateCustomCompareTypeTest(AutogenTest, TestCase):
         eq_(diffs[1][0][0], 'modify_type')
 
 
-class AutogenKeyTest(AutogenTest, TestCase):
+class AutogenKeyTest(AutogenTest, TestBase):
 
     @classmethod
     def _get_db_schema(cls):
@@ -913,7 +900,7 @@ class AutogenKeyTest(AutogenTest, TestCase):
         eq_(diffs[1][3].key, "otherkey")
 
 
-class AutogenerateDiffOrderTest(AutogenTest, TestCase):
+class AutogenerateDiffOrderTest(AutogenTest, TestBase):
 
     @classmethod
     def _get_db_schema(cls):
@@ -952,7 +939,7 @@ class AutogenerateDiffOrderTest(AutogenTest, TestCase):
         eq_(diffs[1][1].name, "child")
 
 
-class CompareMetadataTest(ModelOne, AutogenTest, TestCase):
+class CompareMetadataTest(ModelOne, AutogenTest, TestBase):
 
     def test_compare_metadata(self):
         metadata = self.m2
@@ -1065,13 +1052,10 @@ class CompareMetadataTest(ModelOne, AutogenTest, TestCase):
         eq_(diffs[2][1][6], True)
 
 
-class PGCompareMetaData(ModelOne, AutogenTest, TestCase):
+class PGCompareMetaData(ModelOne, AutogenTest, TestBase):
+    __only_on__ = 'postgresql'
     schema = "test_schema"
 
-    @classmethod
-    def _get_bind(cls):
-        return db_for_dialect('postgresql')
-
     def test_compare_metadata_schema(self):
         metadata = self.m2