]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
- add plugin directory
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 14 Sep 2014 15:45:04 +0000 (11:45 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 14 Sep 2014 15:45:04 +0000 (11:45 -0400)
alembic/testing/plugin/__init__.py [new file with mode: 0644]
alembic/testing/plugin/noseplugin.py [new file with mode: 0644]
alembic/testing/plugin/plugin_base.py [new file with mode: 0644]
alembic/testing/plugin/provision.py [new file with mode: 0644]
alembic/testing/plugin/pytestplugin.py [new file with mode: 0644]

diff --git a/alembic/testing/plugin/__init__.py b/alembic/testing/plugin/__init__.py
new file mode 100644 (file)
index 0000000..3e00e40
--- /dev/null
@@ -0,0 +1,3 @@
+"""NOTE:  copied/adapted from SQLAlchemy master for backwards compatibility;
+   this should be removable when Alembic targets SQLAlchemy 0.9.4.
+"""
diff --git a/alembic/testing/plugin/noseplugin.py b/alembic/testing/plugin/noseplugin.py
new file mode 100644 (file)
index 0000000..2b37dc9
--- /dev/null
@@ -0,0 +1,104 @@
+# plugin/noseplugin.py
+# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+"""NOTE:  copied/adapted from SQLAlchemy master for backwards compatibility;
+   this should be removable when Alembic targets SQLAlchemy 0.9.4.
+"""
+
+"""Enhance nose with extra options and behaviors for running SQLAlchemy tests.
+
+
+"""
+
+import os
+import sys
+
+from nose.plugins import Plugin
+fixtures = None
+
+py3k = sys.version_info >= (3, 0)
+# no package imports yet!  this prevents us from tripping coverage
+# too soon.
+path = os.path.join(os.path.dirname(__file__), "plugin_base.py")
+if sys.version_info >= (3, 3):
+    from importlib import machinery
+    plugin_base = machinery.SourceFileLoader(
+        "plugin_base", path).load_module()
+else:
+    import imp
+    plugin_base = imp.load_source("plugin_base", path)
+
+
+class NoseSQLAlchemy(Plugin):
+    enabled = True
+
+    name = 'sqla_testing'
+    score = 100
+
+    def options(self, parser, env=os.environ):
+        Plugin.options(self, parser, env)
+        opt = parser.add_option
+
+        def make_option(name, **kw):
+            callback_ = kw.pop("callback", None)
+            if callback_:
+                def wrap_(option, opt_str, value, parser):
+                    callback_(opt_str, value, parser)
+                kw["callback"] = wrap_
+            opt(name, **kw)
+
+        plugin_base.setup_options(make_option)
+        plugin_base.read_config()
+
+    def configure(self, options, conf):
+        super(NoseSQLAlchemy, self).configure(options, conf)
+        plugin_base.pre_begin(options)
+
+        plugin_base.set_coverage_flag(options.enable_plugin_coverage)
+
+        global fixtures
+        from sqlalchemy.testing import fixtures
+
+    def begin(self):
+        plugin_base.post_begin()
+
+    def describeTest(self, test):
+        return ""
+
+    def wantFunction(self, fn):
+        return False
+
+    def wantMethod(self, fn):
+        if py3k:
+            cls = fn.__self__.cls
+        else:
+            cls = fn.im_class
+        print "METH:", fn, "CLS:", cls
+        return plugin_base.want_method(cls, fn)
+
+    def wantClass(self, cls):
+        return plugin_base.want_class(cls)
+
+    def beforeTest(self, test):
+        plugin_base.before_test(test,
+                                test.test.cls.__module__,
+                                test.test.cls, test.test.method.__name__)
+
+    def afterTest(self, test):
+        plugin_base.after_test(test)
+
+    def startContext(self, ctx):
+        if not isinstance(ctx, type) \
+                or not issubclass(ctx, fixtures.TestBase):
+            return
+        plugin_base.start_test_class(ctx)
+
+    def stopContext(self, ctx):
+        if not isinstance(ctx, type) \
+                or not issubclass(ctx, fixtures.TestBase):
+            return
+        plugin_base.stop_test_class(ctx)
diff --git a/alembic/testing/plugin/plugin_base.py b/alembic/testing/plugin/plugin_base.py
new file mode 100644 (file)
index 0000000..577134d
--- /dev/null
@@ -0,0 +1,557 @@
+# plugin/plugin_base.py
+# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+"""Testing extensions.
+
+this module is designed to work as a testing-framework-agnostic library,
+so that we can continue to support nose and also begin adding new
+functionality via py.test.
+
+NOTE:  copied/adapted from SQLAlchemy master for backwards compatibility;
+   this should be removable when Alembic targets SQLAlchemy 0.9.4.
+
+
+"""
+
+from __future__ import absolute_import
+try:
+    # unitttest has a SkipTest also but pytest doesn't
+    # honor it unless nose is imported too...
+    from nose import SkipTest
+except ImportError:
+    from _pytest.runner import Skipped as SkipTest
+
+import sys
+import re
+
+py3k = sys.version_info >= (3, 0)
+
+if py3k:
+    import configparser
+else:
+    import ConfigParser as configparser
+
+FOLLOWER_IDENT = None
+
+# late imports
+fixtures = None
+engines = None
+exclusions = None
+warnings = None
+assertions = None
+requirements = None
+config = None
+util = None
+file_config = None
+
+
+logging = None
+db_opts = {}
+include_tags = set()
+exclude_tags = set()
+options = None
+
+
+def setup_options(make_option):
+    make_option("--log-info", action="callback", type="string", callback=_log,
+                help="turn on info logging for <LOG> (multiple OK)")
+    make_option("--log-debug", action="callback",
+                type="string", callback=_log,
+                help="turn on debug logging for <LOG> (multiple OK)")
+    make_option("--db", action="append", type="string", dest="db",
+                help="Use prefab database uri. Multiple OK, "
+                "first one is run by default.")
+    make_option('--dbs', action='callback', callback=_list_dbs,
+                help="List available prefab dbs")
+    make_option("--dburi", action="append", type="string", dest="dburi",
+                help="Database uri.  Multiple OK, "
+                "first one is run by default.")
+    make_option("--dropfirst", action="store_true", dest="dropfirst",
+                help="Drop all tables in the target database first")
+    make_option("--backend-only", action="store_true", dest="backend_only",
+                help="Run only tests marked with __backend__")
+    make_option("--mockpool", action="store_true", dest="mockpool",
+                help="Use mock pool (asserts only one connection used)")
+    make_option("--low-connections", action="store_true",
+                dest="low_connections",
+                help="Use a low number of distinct connections - "
+                "i.e. for Oracle TNS")
+    make_option("--reversetop", action="store_true",
+                dest="reversetop", default=False,
+                help="Use a random-ordering set implementation in the ORM "
+                "(helps reveal dependency issues)")
+    make_option("--requirements", action="callback", type="string",
+                callback=_requirements_opt,
+                help="requirements class for testing, overrides setup.cfg")
+    make_option("--with-cdecimal", action="store_true",
+                dest="cdecimal", default=False,
+                help="Monkeypatch the cdecimal library into Python 'decimal' "
+                "for all tests")
+    make_option("--include-tag", action="callback", callback=_include_tag,
+                type="string",
+                help="Include tests with tag <tag>")
+    make_option("--exclude-tag", action="callback", callback=_exclude_tag,
+                type="string",
+                help="Exclude tests with tag <tag>")
+    make_option("--serverside", action="store_true",
+                help="Turn on server side cursors for PG")
+    make_option("--mysql-engine", action="store",
+                dest="mysql_engine", default=None,
+                help="Use the specified MySQL storage engine for all tables, "
+                "default is a db-default/InnoDB combo.")
+
+
+def configure_follower(follower_ident):
+    """Configure required state for a follower.
+
+    This invokes in the parent process and typically includes
+    database creation.
+
+    """
+    global FOLLOWER_IDENT
+    FOLLOWER_IDENT = follower_ident
+
+
+def memoize_important_follower_config(dict_):
+    """Store important configuration we will need to send to a follower.
+
+    This invokes in the parent process after normal config is set up.
+
+    This is necessary as py.test seems to not be using forking, so we
+    start with nothing in memory, *but* it isn't running our argparse
+    callables, so we have to just copy all of that over.
+
+    """
+    dict_['memoized_config'] = {
+        'db_opts': db_opts,
+        'include_tags': include_tags,
+        'exclude_tags': exclude_tags
+    }
+
+
+def restore_important_follower_config(dict_):
+    """Restore important configuration needed by a follower.
+
+    This invokes in the follower process.
+
+    """
+    global db_opts, include_tags, exclude_tags
+    db_opts.update(dict_['memoized_config']['db_opts'])
+    include_tags.update(dict_['memoized_config']['include_tags'])
+    exclude_tags.update(dict_['memoized_config']['exclude_tags'])
+
+
+def read_config():
+    global file_config
+    file_config = configparser.ConfigParser()
+    file_config.read(['setup.cfg', 'test.cfg'])
+
+
+def pre_begin(opt):
+    """things to set up early, before coverage might be setup."""
+    global options
+    options = opt
+    for fn in pre_configure:
+        fn(options, file_config)
+
+
+def set_coverage_flag(value):
+    options.has_coverage = value
+
+
+def post_begin():
+    """things to set up later, once we know coverage is running."""
+    # Lazy setup of other options (post coverage)
+    for fn in post_configure:
+        fn(options, file_config)
+
+    # late imports, has to happen after config as well
+    # as nose plugins like coverage
+    global util, fixtures, engines, exclusions, \
+        assertions, warnings, profiling,\
+        config, testing
+    from alembic.testing import config, warnings, exclusions, engines, fixtures
+    from sqlalchemy import util
+    warnings.setup_filters()
+
+def _log(opt_str, value, parser):
+    global logging
+    if not logging:
+        import logging
+        logging.basicConfig()
+
+    if opt_str.endswith('-info'):
+        logging.getLogger(value).setLevel(logging.INFO)
+    elif opt_str.endswith('-debug'):
+        logging.getLogger(value).setLevel(logging.DEBUG)
+
+
+def _list_dbs(*args):
+    print("Available --db options (use --dburi to override)")
+    for macro in sorted(file_config.options('db')):
+        print("%20s\t%s" % (macro, file_config.get('db', macro)))
+    sys.exit(0)
+
+
+def _requirements_opt(opt_str, value, parser):
+    _setup_requirements(value)
+
+
+def _exclude_tag(opt_str, value, parser):
+    exclude_tags.add(value.replace('-', '_'))
+
+
+def _include_tag(opt_str, value, parser):
+    include_tags.add(value.replace('-', '_'))
+
+pre_configure = []
+post_configure = []
+
+
+def pre(fn):
+    pre_configure.append(fn)
+    return fn
+
+
+def post(fn):
+    post_configure.append(fn)
+    return fn
+
+
+@pre
+def _setup_options(opt, file_config):
+    global options
+    options = opt
+
+
+@pre
+def _server_side_cursors(options, file_config):
+    if options.serverside:
+        db_opts['server_side_cursors'] = True
+
+
+@pre
+def _monkeypatch_cdecimal(options, file_config):
+    if options.cdecimal:
+        import cdecimal
+        sys.modules['decimal'] = cdecimal
+
+
+@post
+def _engine_uri(options, file_config):
+    from alembic.testing import config
+    from alembic.testing.plugin import provision
+
+    if options.dburi:
+        db_urls = list(options.dburi)
+    else:
+        db_urls = []
+
+    if options.db:
+        for db_token in options.db:
+            for db in re.split(r'[,\s]+', db_token):
+                if db not in file_config.options('db'):
+                    raise RuntimeError(
+                        "Unknown URI specifier '%s'.  "
+                        "Specify --dbs for known uris."
+                        % db)
+                else:
+                    db_urls.append(file_config.get('db', db))
+
+    if not db_urls:
+        db_urls.append(file_config.get('db', 'default'))
+
+    for db_url in db_urls:
+        cfg = provision.setup_config(
+            db_url, db_opts, options, file_config, FOLLOWER_IDENT)
+
+        if not config._current:
+            cfg.set_as_current(cfg)
+
+
+@post
+def _engine_pool(options, file_config):
+    if options.mockpool:
+        from sqlalchemy import pool
+        db_opts['poolclass'] = pool.AssertionPool
+
+
+@post
+def _requirements(options, file_config):
+
+    requirement_cls = file_config.get('sqla_testing', "requirement_cls")
+    _setup_requirements(requirement_cls)
+
+
+def _setup_requirements(argument):
+    from alembic.testing import config
+
+    if config.requirements is not None:
+        return
+
+    modname, clsname = argument.split(":")
+
+    # importlib.import_module() only introduced in 2.7, a little
+    # late
+    mod = __import__(modname)
+    for component in modname.split(".")[1:]:
+        mod = getattr(mod, component)
+    req_cls = getattr(mod, clsname)
+
+    config.requirements = req_cls()
+
+
+@post
+def _prep_testing_database(options, file_config):
+    from alembic.testing import config
+    from alembic.testing.exclusions import against
+    from sqlalchemy import schema
+    from alembic import util
+
+    if util.sqla_08:
+        from sqlalchemy import inspect
+    else:
+        from sqlalchemy.engine.reflection import Inspector
+        inspect = Inspector.from_engine
+
+    if options.dropfirst:
+        for cfg in config.Config.all_configs():
+            e = cfg.db
+            inspector = inspect(e)
+            try:
+                view_names = inspector.get_view_names()
+            except NotImplementedError:
+                pass
+            else:
+                for vname in view_names:
+                    e.execute(schema._DropView(
+                        schema.Table(vname, schema.MetaData())
+                    ))
+
+            if config.requirements.schemas.enabled_for_config(cfg):
+                try:
+                    view_names = inspector.get_view_names(
+                        schema="test_schema")
+                except NotImplementedError:
+                    pass
+                else:
+                    for vname in view_names:
+                        e.execute(schema._DropView(
+                            schema.Table(vname, schema.MetaData(),
+                                         schema="test_schema")
+                        ))
+
+            for tname in reversed(inspector.get_table_names(
+                    order_by="foreign_key")):
+                e.execute(schema.DropTable(
+                    schema.Table(tname, schema.MetaData())
+                ))
+
+            if config.requirements.schemas.enabled_for_config(cfg):
+                for tname in reversed(inspector.get_table_names(
+                        order_by="foreign_key", schema="test_schema")):
+                    e.execute(schema.DropTable(
+                        schema.Table(tname, schema.MetaData(),
+                                     schema="test_schema")
+                    ))
+
+            if against(cfg, "postgresql"):
+                from sqlalchemy.dialects import postgresql
+                for enum in inspector.get_enums("*"):
+                    e.execute(postgresql.DropEnumType(
+                        postgresql.ENUM(
+                            name=enum['name'],
+                            schema=enum['schema'])))
+
+
+
+
+@post
+def _reverse_topological(options, file_config):
+    if options.reversetop:
+        from sqlalchemy.orm.util import randomize_unitofwork
+        randomize_unitofwork()
+
+
+@post
+def _post_setup_options(opt, file_config):
+    from alembic.testing import config
+    config.options = options
+    config.file_config = file_config
+
+
+
+def want_class(cls):
+    if not issubclass(cls, fixtures.TestBase):
+        return False
+    elif cls.__name__.startswith('_'):
+        return False
+    elif config.options.backend_only and not getattr(cls, '__backend__',
+                                                     False):
+        return False
+    else:
+        return True
+
+
+def want_method(cls, fn):
+    if not fn.__name__.startswith("test_"):
+        return False
+    elif fn.__module__ is None:
+        return False
+    elif include_tags:
+        return (
+            hasattr(cls, '__tags__') and
+            exclusions.tags(cls.__tags__).include_test(
+                include_tags, exclude_tags)
+        ) or (
+            hasattr(fn, '_sa_exclusion_extend') and
+            fn._sa_exclusion_extend.include_test(
+                include_tags, exclude_tags)
+        )
+    elif exclude_tags and hasattr(cls, '__tags__'):
+        return exclusions.tags(cls.__tags__).include_test(
+            include_tags, exclude_tags)
+    elif exclude_tags and hasattr(fn, '_sa_exclusion_extend'):
+        return fn._sa_exclusion_extend.include_test(include_tags, exclude_tags)
+    else:
+        return True
+
+
+def generate_sub_tests(cls, module):
+    if getattr(cls, '__backend__', False):
+        for cfg in _possible_configs_for_cls(cls):
+            name = "%s_%s_%s" % (cls.__name__, cfg.db.name, cfg.db.driver)
+            subcls = type(
+                name,
+                (cls, ),
+                {
+                    "__only_on__": ("%s+%s" % (cfg.db.name, cfg.db.driver)),
+                }
+            )
+            setattr(module, name, subcls)
+            yield subcls
+    else:
+        yield cls
+
+
+def start_test_class(cls):
+    _do_skips(cls)
+    _setup_engine(cls)
+
+
+def stop_test_class(cls):
+    #from sqlalchemy import inspect
+    #assert not inspect(testing.db).get_table_names()
+    _restore_engine()
+
+
+def _restore_engine():
+    config._current.reset()
+
+
+def _setup_engine(cls):
+    if getattr(cls, '__engine_options__', None):
+        eng = engines.testing_engine(options=cls.__engine_options__)
+        config._current.push_engine(eng)
+
+
+def before_test(test, test_module_name, test_class, test_name):
+    pass
+
+
+def after_test(test):
+    pass
+
+
+def _possible_configs_for_cls(cls, reasons=None):
+    all_configs = set(config.Config.all_configs())
+
+    if cls.__unsupported_on__:
+        spec = exclusions.db_spec(*cls.__unsupported_on__)
+        for config_obj in list(all_configs):
+            if spec(config_obj):
+                all_configs.remove(config_obj)
+
+    if getattr(cls, '__only_on__', None):
+        spec = exclusions.db_spec(*util.to_list(cls.__only_on__))
+        for config_obj in list(all_configs):
+            if not spec(config_obj):
+                all_configs.remove(config_obj)
+
+    if hasattr(cls, '__requires__'):
+        requirements = config.requirements
+        for config_obj in list(all_configs):
+            for requirement in cls.__requires__:
+                check = getattr(requirements, requirement)
+
+                skip_reasons = check.matching_config_reasons(config_obj)
+                if skip_reasons:
+                    all_configs.remove(config_obj)
+                    if reasons is not None:
+                        reasons.extend(skip_reasons)
+                    break
+
+    if hasattr(cls, '__prefer_requires__'):
+        non_preferred = set()
+        requirements = config.requirements
+        for config_obj in list(all_configs):
+            for requirement in cls.__prefer_requires__:
+                check = getattr(requirements, requirement)
+
+                if not check.enabled_for_config(config_obj):
+                    non_preferred.add(config_obj)
+        if all_configs.difference(non_preferred):
+            all_configs.difference_update(non_preferred)
+
+    return all_configs
+
+
+def _do_skips(cls):
+    reasons = []
+    all_configs = _possible_configs_for_cls(cls, reasons)
+
+    if getattr(cls, '__skip_if__', False):
+        for c in getattr(cls, '__skip_if__'):
+            if c():
+                raise SkipTest("'%s' skipped by %s" % (
+                    cls.__name__, c.__name__)
+                )
+
+    if not all_configs:
+        if getattr(cls, '__backend__', False):
+            msg = "'%s' unsupported for implementation '%s'" % (
+                cls.__name__, cls.__only_on__)
+        else:
+            msg = "'%s' unsupported on any DB implementation %s%s" % (
+                cls.__name__,
+                ", ".join(
+                    "'%s(%s)+%s'" % (
+                        config_obj.db.name,
+                        ".".join(
+                            str(dig) for dig in
+                            config_obj.db.dialect.server_version_info),
+                        config_obj.db.driver
+                    )
+                  for config_obj in config.Config.all_configs()
+                ),
+                ", ".join(reasons)
+            )
+        raise SkipTest(msg)
+    elif hasattr(cls, '__prefer_backends__'):
+        non_preferred = set()
+        spec = exclusions.db_spec(*util.to_list(cls.__prefer_backends__))
+        for config_obj in all_configs:
+            if not spec(config_obj):
+                non_preferred.add(config_obj)
+        if all_configs.difference(non_preferred):
+            all_configs.difference_update(non_preferred)
+
+    if config._current not in all_configs:
+        _setup_config(all_configs.pop(), cls)
+
+
+def _setup_config(config_obj, ctx):
+    config._current.push(config_obj)
diff --git a/alembic/testing/plugin/provision.py b/alembic/testing/plugin/provision.py
new file mode 100644 (file)
index 0000000..d0edbef
--- /dev/null
@@ -0,0 +1,186 @@
+"""NOTE:  copied/adapted from SQLAlchemy master for backwards compatibility;
+   this should be removable when Alembic targets SQLAlchemy 0.9.4.
+"""
+from sqlalchemy.engine import url as sa_url
+from sqlalchemy import text
+from alembic import compat
+from alembic.testing import config, engines
+from alembic.testing.compat import get_url_backend_name
+
+
+class register(object):
+    def __init__(self):
+        self.fns = {}
+
+    @classmethod
+    def init(cls, fn):
+        return register().for_db("*")(fn)
+
+    def for_db(self, dbname):
+        def decorate(fn):
+            self.fns[dbname] = fn
+            return self
+        return decorate
+
+    def __call__(self, cfg, *arg):
+        if isinstance(cfg, compat.string_types):
+            url = sa_url.make_url(cfg)
+        elif isinstance(cfg, sa_url.URL):
+            url = cfg
+        else:
+            url = cfg.db.url
+        backend = get_url_backend_name(url)
+        if backend in self.fns:
+            return self.fns[backend](cfg, *arg)
+        else:
+            return self.fns['*'](cfg, *arg)
+
+
+def create_follower_db(follower_ident):
+
+    for cfg in _configs_for_db_operation():
+        _create_db(cfg, cfg.db, follower_ident)
+
+
+def configure_follower(follower_ident):
+    for cfg in config.Config.all_configs():
+        _configure_follower(cfg, follower_ident)
+
+
+def setup_config(db_url, db_opts, options, file_config, follower_ident):
+    if follower_ident:
+        db_url = _follower_url_from_main(db_url, follower_ident)
+    eng = engines.testing_engine(db_url, db_opts)
+    eng.connect().close()
+    cfg = config.Config.register(eng, db_opts, options, file_config)
+    if follower_ident:
+        _configure_follower(cfg, follower_ident)
+    return cfg
+
+
+def drop_follower_db(follower_ident):
+    for cfg in _configs_for_db_operation():
+        _drop_db(cfg, cfg.db, follower_ident)
+
+
+def _configs_for_db_operation():
+    hosts = set()
+
+    for cfg in config.Config.all_configs():
+        cfg.db.dispose()
+
+    for cfg in config.Config.all_configs():
+        url = cfg.db.url
+        backend = get_url_backend_name(url)
+        host_conf = (
+            backend,
+            url.username, url.host, url.database)
+
+        if host_conf not in hosts:
+            yield cfg
+            hosts.add(host_conf)
+
+    for cfg in config.Config.all_configs():
+        cfg.db.dispose()
+
+
+@register.init
+def _create_db(cfg, eng, ident):
+    raise NotImplementedError("no DB creation routine for cfg: %s" % eng.url)
+
+
+@register.init
+def _drop_db(cfg, eng, ident):
+    raise NotImplementedError("no DB drop routine for cfg: %s" % eng.url)
+
+
+@register.init
+def _configure_follower(cfg, ident):
+    pass
+
+
+@register.init
+def _follower_url_from_main(url, ident):
+    url = sa_url.make_url(url)
+    url.database = ident
+    return url
+
+
+@_follower_url_from_main.for_db("sqlite")
+def _sqlite_follower_url_from_main(url, ident):
+    url = sa_url.make_url(url)
+    if not url.database or url.database == ':memory:':
+        return url
+    else:
+        return sa_url.make_url("sqlite:///%s.db" % ident)
+
+
+@_create_db.for_db("postgresql")
+def _pg_create_db(cfg, eng, ident):
+    with eng.connect().execution_options(
+            isolation_level="AUTOCOMMIT") as conn:
+        try:
+            _pg_drop_db(cfg, conn, ident)
+        except:
+            pass
+        currentdb = conn.scalar("select current_database()")
+        conn.execute("CREATE DATABASE %s TEMPLATE %s" % (ident, currentdb))
+
+
+@_create_db.for_db("mysql")
+def _mysql_create_db(cfg, eng, ident):
+    with eng.connect() as conn:
+        try:
+            _mysql_drop_db(cfg, conn, ident)
+        except:
+            pass
+        conn.execute("CREATE DATABASE %s" % ident)
+        conn.execute("CREATE DATABASE %s_test_schema" % ident)
+        conn.execute("CREATE DATABASE %s_test_schema_2" % ident)
+
+
+@_configure_follower.for_db("mysql")
+def _mysql_configure_follower(config, ident):
+    config.test_schema = "%s_test_schema" % ident
+    config.test_schema_2 = "%s_test_schema_2" % ident
+
+
+@_create_db.for_db("sqlite")
+def _sqlite_create_db(cfg, eng, ident):
+    pass
+
+
+@_drop_db.for_db("postgresql")
+def _pg_drop_db(cfg, eng, ident):
+    with eng.connect().execution_options(
+            isolation_level="AUTOCOMMIT") as conn:
+        conn.execute(
+            text(
+                "select pg_terminate_backend(pid) from pg_stat_activity "
+                "where usename=current_user and pid != pg_backend_pid() "
+                "and datname=:dname"
+            ), dname=ident)
+        conn.execute("DROP DATABASE %s" % ident)
+
+
+@_drop_db.for_db("sqlite")
+def _sqlite_drop_db(cfg, eng, ident):
+    pass
+    #os.remove("%s.db" % ident)
+
+
+@_drop_db.for_db("mysql")
+def _mysql_drop_db(cfg, eng, ident):
+    with eng.connect() as conn:
+        try:
+            conn.execute("DROP DATABASE %s_test_schema" % ident)
+        except:
+            pass
+        try:
+            conn.execute("DROP DATABASE %s_test_schema_2" % ident)
+        except:
+            pass
+        try:
+            conn.execute("DROP DATABASE %s" % ident)
+        except:
+            pass
diff --git a/alembic/testing/plugin/pytestplugin.py b/alembic/testing/plugin/pytestplugin.py
new file mode 100644 (file)
index 0000000..fa13db8
--- /dev/null
@@ -0,0 +1,172 @@
+"""NOTE:  copied/adapted from SQLAlchemy master for backwards compatibility;
+   this should be removable when Alembic targets SQLAlchemy 0.9.4.
+"""
+import pytest
+import argparse
+import inspect
+from . import plugin_base
+import collections
+import itertools
+
+try:
+    import xdist
+    has_xdist = True
+except ImportError:
+    has_xdist = False
+
+
+def pytest_addoption(parser):
+    group = parser.getgroup("sqlalchemy")
+
+    def make_option(name, **kw):
+        callback_ = kw.pop("callback", None)
+        if callback_:
+            class CallableAction(argparse.Action):
+                def __call__(self, parser, namespace,
+                             values, option_string=None):
+                    callback_(option_string, values, parser)
+            kw["action"] = CallableAction
+
+        group.addoption(name, **kw)
+
+    plugin_base.setup_options(make_option)
+    plugin_base.read_config()
+
+
+def pytest_configure(config):
+    if hasattr(config, "slaveinput"):
+        plugin_base.restore_important_follower_config(config.slaveinput)
+        plugin_base.configure_follower(
+            config.slaveinput["follower_ident"]
+        )
+
+    plugin_base.pre_begin(config.option)
+
+    plugin_base.set_coverage_flag(bool(getattr(config.option,
+                                               "cov_source", False)))
+
+    plugin_base.post_begin()
+
+if has_xdist:
+    _follower_count = itertools.count(1)
+
+    def pytest_configure_node(node):
+        # the master for each node fills slaveinput dictionary
+        # which pytest-xdist will transfer to the subprocess
+
+        plugin_base.memoize_important_follower_config(node.slaveinput)
+
+        node.slaveinput["follower_ident"] = "test_%s" % next(_follower_count)
+        from . import provision
+        provision.create_follower_db(node.slaveinput["follower_ident"])
+
+    def pytest_testnodedown(node, error):
+        from . import provision
+        provision.drop_follower_db(node.slaveinput["follower_ident"])
+
+
+def pytest_collection_modifyitems(session, config, items):
+    # look for all those classes that specify __backend__ and
+    # expand them out into per-database test cases.
+
+    # this is much easier to do within pytest_pycollect_makeitem, however
+    # pytest is iterating through cls.__dict__ as makeitem is
+    # called which causes a "dictionary changed size" error on py3k.
+    # I'd submit a pullreq for them to turn it into a list first, but
+    # it's to suit the rather odd use case here which is that we are adding
+    # new classes to a module on the fly.
+
+    rebuilt_items = collections.defaultdict(list)
+    items[:] = [
+        item for item in
+        items if isinstance(item.parent, pytest.Instance)]
+    test_classes = set(item.parent for item in items)
+    for test_class in test_classes:
+        for sub_cls in plugin_base.generate_sub_tests(
+                test_class.cls, test_class.parent.module):
+            if sub_cls is not test_class.cls:
+                list_ = rebuilt_items[test_class.cls]
+
+                for inst in pytest.Class(
+                        sub_cls.__name__,
+                        parent=test_class.parent.parent).collect():
+                    list_.extend(inst.collect())
+
+    newitems = []
+    for item in items:
+        if item.parent.cls in rebuilt_items:
+            newitems.extend(rebuilt_items[item.parent.cls])
+            rebuilt_items[item.parent.cls][:] = []
+        else:
+            newitems.append(item)
+
+    # seems like the functions attached to a test class aren't sorted already?
+    # is that true and why's that? (when using unittest, they're sorted)
+    items[:] = sorted(newitems, key=lambda item: (
+        item.parent.parent.parent.name,
+        item.parent.parent.name,
+        item.name
+    ))
+
+
+def pytest_pycollect_makeitem(collector, name, obj):
+    if inspect.isclass(obj) and plugin_base.want_class(obj):
+        return pytest.Class(name, parent=collector)
+    elif inspect.isfunction(obj) and \
+            isinstance(collector, pytest.Instance) and \
+            plugin_base.want_method(collector.cls, obj):
+        return pytest.Function(name, parent=collector)
+    else:
+        return []
+
+_current_class = None
+
+def pytest_runtest_setup(item):
+    # here we seem to get called only based on what we collected
+    # in pytest_collection_modifyitems.   So to do class-based stuff
+    # we have to tear that out.
+    global _current_class
+
+    if not isinstance(item, pytest.Function):
+        return
+
+    # ... so we're doing a little dance here to figure it out...
+    if _current_class is None:
+        class_setup(item.parent.parent)
+        _current_class = item.parent.parent
+
+        # this is needed for the class-level, to ensure that the
+        # teardown runs after the class is completed with its own
+        # class-level teardown...
+        def finalize():
+            global _current_class
+            class_teardown(item.parent.parent)
+            _current_class = None
+        item.parent.parent.addfinalizer(finalize)
+
+    test_setup(item)
+
+
+def pytest_runtest_teardown(item):
+    # ...but this works better as the hook here rather than
+    # using a finalizer, as the finalizer seems to get in the way
+    # of the test reporting failures correctly (you get a bunch of
+    # py.test assertion stuff instead)
+    test_teardown(item)
+
+
+def test_setup(item):
+    plugin_base.before_test(item, item.parent.module.__name__,
+                            item.parent.cls, item.name)
+
+
+def test_teardown(item):
+    plugin_base.after_test(item)
+
+
+def class_setup(item):
+    plugin_base.start_test_class(item.cls)
+
+
+def class_teardown(item):
+    plugin_base.stop_test_class(item.cls)