--- /dev/null
+# plugin/noseplugin.py
+# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+"""NOTE: copied/adapted from SQLAlchemy master for backwards compatibility;
+ this should be removable when Alembic targets SQLAlchemy 0.9.4.
+"""
+
+"""Enhance nose with extra options and behaviors for running SQLAlchemy tests.
+
+
+"""
+
+import os
+import sys
+
+from nose.plugins import Plugin
+fixtures = None
+
+py3k = sys.version_info >= (3, 0)
+# no package imports yet! this prevents us from tripping coverage
+# too soon.
+path = os.path.join(os.path.dirname(__file__), "plugin_base.py")
+if sys.version_info >= (3, 3):
+ from importlib import machinery
+ plugin_base = machinery.SourceFileLoader(
+ "plugin_base", path).load_module()
+else:
+ import imp
+ plugin_base = imp.load_source("plugin_base", path)
+
+
+class NoseSQLAlchemy(Plugin):
+ enabled = True
+
+ name = 'sqla_testing'
+ score = 100
+
+ def options(self, parser, env=os.environ):
+ Plugin.options(self, parser, env)
+ opt = parser.add_option
+
+ def make_option(name, **kw):
+ callback_ = kw.pop("callback", None)
+ if callback_:
+ def wrap_(option, opt_str, value, parser):
+ callback_(opt_str, value, parser)
+ kw["callback"] = wrap_
+ opt(name, **kw)
+
+ plugin_base.setup_options(make_option)
+ plugin_base.read_config()
+
+ def configure(self, options, conf):
+ super(NoseSQLAlchemy, self).configure(options, conf)
+ plugin_base.pre_begin(options)
+
+ plugin_base.set_coverage_flag(options.enable_plugin_coverage)
+
+ global fixtures
+ from sqlalchemy.testing import fixtures
+
+ def begin(self):
+ plugin_base.post_begin()
+
+ def describeTest(self, test):
+ return ""
+
+ def wantFunction(self, fn):
+ return False
+
+ def wantMethod(self, fn):
+ if py3k:
+ cls = fn.__self__.cls
+ else:
+ cls = fn.im_class
+ print "METH:", fn, "CLS:", cls
+ return plugin_base.want_method(cls, fn)
+
+ def wantClass(self, cls):
+ return plugin_base.want_class(cls)
+
+ def beforeTest(self, test):
+ plugin_base.before_test(test,
+ test.test.cls.__module__,
+ test.test.cls, test.test.method.__name__)
+
+ def afterTest(self, test):
+ plugin_base.after_test(test)
+
+ def startContext(self, ctx):
+ if not isinstance(ctx, type) \
+ or not issubclass(ctx, fixtures.TestBase):
+ return
+ plugin_base.start_test_class(ctx)
+
+ def stopContext(self, ctx):
+ if not isinstance(ctx, type) \
+ or not issubclass(ctx, fixtures.TestBase):
+ return
+ plugin_base.stop_test_class(ctx)
--- /dev/null
+# plugin/plugin_base.py
+# Copyright (C) 2005-2014 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+"""Testing extensions.
+
+this module is designed to work as a testing-framework-agnostic library,
+so that we can continue to support nose and also begin adding new
+functionality via py.test.
+
+NOTE: copied/adapted from SQLAlchemy master for backwards compatibility;
+ this should be removable when Alembic targets SQLAlchemy 0.9.4.
+
+
+"""
+
+from __future__ import absolute_import
+try:
+    # unittest has a SkipTest also but pytest doesn't
+ # honor it unless nose is imported too...
+ from nose import SkipTest
+except ImportError:
+ from _pytest.runner import Skipped as SkipTest
+
+import sys
+import re
+
+py3k = sys.version_info >= (3, 0)
+
+if py3k:
+ import configparser
+else:
+ import ConfigParser as configparser
+
+FOLLOWER_IDENT = None
+
+# late imports
+fixtures = None
+engines = None
+exclusions = None
+warnings = None
+assertions = None
+requirements = None
+config = None
+util = None
+file_config = None
+
+
+logging = None
+db_opts = {}
+include_tags = set()
+exclude_tags = set()
+options = None
+
+
+def setup_options(make_option):
+    """Register every test-suite command line option via ``make_option``.
+
+    ``make_option`` is a framework-supplied adapter (optparse for nose,
+    argparse for py.test) accepting an optparse-like signature.
+    """
+    make_option("--log-info", action="callback", type="string", callback=_log,
+                help="turn on info logging for <LOG> (multiple OK)")
+    make_option("--log-debug", action="callback",
+                type="string", callback=_log,
+                help="turn on debug logging for <LOG> (multiple OK)")
+    make_option("--db", action="append", type="string", dest="db",
+                help="Use prefab database uri. Multiple OK, "
+                "first one is run by default.")
+    make_option('--dbs', action='callback', callback=_list_dbs,
+                help="List available prefab dbs")
+    make_option("--dburi", action="append", type="string", dest="dburi",
+                help="Database uri. Multiple OK, "
+                "first one is run by default.")
+    make_option("--dropfirst", action="store_true", dest="dropfirst",
+                help="Drop all tables in the target database first")
+    make_option("--backend-only", action="store_true", dest="backend_only",
+                help="Run only tests marked with __backend__")
+    make_option("--mockpool", action="store_true", dest="mockpool",
+                help="Use mock pool (asserts only one connection used)")
+    make_option("--low-connections", action="store_true",
+                dest="low_connections",
+                help="Use a low number of distinct connections - "
+                "i.e. for Oracle TNS")
+    make_option("--reversetop", action="store_true",
+                dest="reversetop", default=False,
+                help="Use a random-ordering set implementation in the ORM "
+                "(helps reveal dependency issues)")
+    make_option("--requirements", action="callback", type="string",
+                callback=_requirements_opt,
+                help="requirements class for testing, overrides setup.cfg")
+    make_option("--with-cdecimal", action="store_true",
+                dest="cdecimal", default=False,
+                help="Monkeypatch the cdecimal library into Python 'decimal' "
+                "for all tests")
+    make_option("--include-tag", action="callback", callback=_include_tag,
+                type="string",
+                help="Include tests with tag <tag>")
+    make_option("--exclude-tag", action="callback", callback=_exclude_tag,
+                type="string",
+                help="Exclude tests with tag <tag>")
+    make_option("--serverside", action="store_true",
+                help="Turn on server side cursors for PG")
+    make_option("--mysql-engine", action="store",
+                dest="mysql_engine", default=None,
+                help="Use the specified MySQL storage engine for all tables, "
+                "default is a db-default/InnoDB combo.")
+
+
+def configure_follower(follower_ident):
+    """Configure required state for a follower.
+
+    This runs in the follower process; it records the ident used to
+    derive per-follower database names.
+
+    """
+    global FOLLOWER_IDENT
+    FOLLOWER_IDENT = follower_ident
+
+
+def memoize_important_follower_config(dict_):
+    """Store important configuration we will need to send to a follower.
+
+    This runs in the parent process after normal config is set up.
+
+    This is necessary as py.test seems to not be using forking, so we
+    start with nothing in memory, *but* it isn't running our argparse
+    callables, so we have to just copy all of that over.
+
+    """
+    # copied verbatim into the xdist slaveinput dict; see the matching
+    # restore_important_follower_config()
+    dict_['memoized_config'] = {
+        'db_opts': db_opts,
+        'include_tags': include_tags,
+        'exclude_tags': exclude_tags
+    }
+
+
+def restore_important_follower_config(dict_):
+    """Restore important configuration needed by a follower.
+
+    This runs in the follower process, merging the parent's memoized
+    state into this module's globals.
+
+    """
+    global db_opts, include_tags, exclude_tags
+    db_opts.update(dict_['memoized_config']['db_opts'])
+    include_tags.update(dict_['memoized_config']['include_tags'])
+    exclude_tags.update(dict_['memoized_config']['exclude_tags'])
+
+
+def read_config():
+    """Load setup.cfg / test.cfg into the module-level ``file_config``."""
+    global file_config
+    file_config = configparser.ConfigParser()
+    # missing files are silently skipped by ConfigParser.read()
+    file_config.read(['setup.cfg', 'test.cfg'])
+
+
+def pre_begin(opt):
+    """things to set up early, before coverage might be setup.
+
+    Runs every hook registered with the ``@pre`` decorator.
+    """
+    global options
+    options = opt
+    for fn in pre_configure:
+        fn(options, file_config)
+
+
+def set_coverage_flag(value):
+    # stash whether a coverage plugin is active on the shared options object
+    options.has_coverage = value
+
+
+def post_begin():
+    """things to set up later, once we know coverage is running.
+
+    Runs every hook registered with the ``@post`` decorator, then performs
+    the deferred package imports.
+    """
+    # Lazy setup of other options (post coverage)
+    for fn in post_configure:
+        fn(options, file_config)
+
+    # late imports, has to happen after config as well
+    # as nose plugins like coverage
+    # NOTE(review): 'assertions', 'profiling' and 'testing' are declared
+    # global here but never imported below — presumably leftovers from the
+    # SQLAlchemy original; confirm they are unused before relying on them.
+    global util, fixtures, engines, exclusions, \
+        assertions, warnings, profiling,\
+        config, testing
+    from alembic.testing import config, warnings, exclusions, engines, fixtures
+    from sqlalchemy import util
+    warnings.setup_filters()
+
+def _log(opt_str, value, parser):
+    """Option callback: enable INFO/DEBUG logging for logger ``value``."""
+    # the stdlib logging module is imported lazily into the module-level
+    # 'logging' name so that plain runs never touch it
+    global logging
+    if not logging:
+        import logging
+        logging.basicConfig()
+
+    # level chosen from the option name itself (--log-info / --log-debug)
+    if opt_str.endswith('-info'):
+        logging.getLogger(value).setLevel(logging.INFO)
+    elif opt_str.endswith('-debug'):
+        logging.getLogger(value).setLevel(logging.DEBUG)
+
+
+def _list_dbs(*args):
+    """Option callback for --dbs: print the [db] section and exit."""
+    print("Available --db options (use --dburi to override)")
+    for macro in sorted(file_config.options('db')):
+        print("%20s\t%s" % (macro, file_config.get('db', macro)))
+    sys.exit(0)
+
+
+def _requirements_opt(opt_str, value, parser):
+    """Option callback for --requirements: load 'module:Class' spec."""
+    _setup_requirements(value)
+
+
+def _exclude_tag(opt_str, value, parser):
+    # tags are normalized to underscore form to match __tags__ attributes
+    exclude_tags.add(value.replace('-', '_'))
+
+
+def _include_tag(opt_str, value, parser):
+    # tags are normalized to underscore form to match __tags__ attributes
+    include_tags.add(value.replace('-', '_'))
+
+# hook registries populated by the @pre / @post decorators below
+pre_configure = []
+post_configure = []
+
+
+def pre(fn):
+    """Decorator: register *fn* to run during pre_begin()."""
+    pre_configure.append(fn)
+    return fn
+
+
+def post(fn):
+    """Decorator: register *fn* to run during post_begin()."""
+    post_configure.append(fn)
+    return fn
+
+
+@pre
+def _setup_options(opt, file_config):
+    # publish the parsed options object module-wide
+    global options
+    options = opt
+
+
+@pre
+def _server_side_cursors(options, file_config):
+    # --serverside: passed through to engine creation via db_opts
+    if options.serverside:
+        db_opts['server_side_cursors'] = True
+
+
+@pre
+def _monkeypatch_cdecimal(options, file_config):
+    # --with-cdecimal: replace stdlib 'decimal' before anything imports it
+    if options.cdecimal:
+        import cdecimal
+        sys.modules['decimal'] = cdecimal
+
+
+@post
+def _engine_uri(options, file_config):
+    """Resolve --dburi/--db options into registered engine Configs.
+
+    Falls back to the [db] 'default' entry when nothing was specified;
+    the first resulting config becomes current.
+    """
+    from alembic.testing import config
+    from alembic.testing.plugin import provision
+
+    if options.dburi:
+        db_urls = list(options.dburi)
+    else:
+        db_urls = []
+
+    if options.db:
+        # each --db value may itself be a comma/space separated list of
+        # [db] section keys from setup.cfg / test.cfg
+        for db_token in options.db:
+            for db in re.split(r'[,\s]+', db_token):
+                if db not in file_config.options('db'):
+                    raise RuntimeError(
+                        "Unknown URI specifier '%s'.  "
+                        "Specify --dbs for known uris."
+                        % db)
+                else:
+                    db_urls.append(file_config.get('db', db))
+
+    if not db_urls:
+        db_urls.append(file_config.get('db', 'default'))
+
+    for db_url in db_urls:
+        cfg = provision.setup_config(
+            db_url, db_opts, options, file_config, FOLLOWER_IDENT)
+
+        # first registered config becomes the current one
+        if not config._current:
+            cfg.set_as_current(cfg)
+
+
+@post
+def _engine_pool(options, file_config):
+    # --mockpool: AssertionPool raises if more than one connection is used
+    if options.mockpool:
+        from sqlalchemy import pool
+        db_opts['poolclass'] = pool.AssertionPool
+
+
+@post
+def _requirements(options, file_config):
+
+    # setup.cfg [sqla_testing] requirement_cls; a --requirements option
+    # already processed wins (see _setup_requirements' early return)
+    requirement_cls = file_config.get('sqla_testing', "requirement_cls")
+    _setup_requirements(requirement_cls)
+
+
+def _setup_requirements(argument):
+    """Instantiate the requirements class from a 'module:Class' string.
+
+    No-op if config.requirements was already set (e.g. via --requirements).
+    """
+    from alembic.testing import config
+
+    if config.requirements is not None:
+        return
+
+    modname, clsname = argument.split(":")
+
+    # importlib.import_module() only introduced in 2.7, a little
+    # late
+    mod = __import__(modname)
+    for component in modname.split(".")[1:]:
+        mod = getattr(mod, component)
+    req_cls = getattr(mod, clsname)
+
+    config.requirements = req_cls()
+
+
+@post
+def _prep_testing_database(options, file_config):
+    """--dropfirst: empty every configured database before the run.
+
+    Views are dropped before tables (they may depend on them), tables in
+    reverse foreign-key order, then PG enum types; the 'test_schema'
+    schema is also cleared where the requirements enable it.
+    """
+    from alembic.testing import config
+    from alembic.testing.exclusions import against
+    from sqlalchemy import schema
+    from alembic import util
+
+    # SQLAlchemy < 0.8 has no top-level inspect(); emulate it
+    if util.sqla_08:
+        from sqlalchemy import inspect
+    else:
+        from sqlalchemy.engine.reflection import Inspector
+        inspect = Inspector.from_engine
+
+    if options.dropfirst:
+        for cfg in config.Config.all_configs():
+            e = cfg.db
+            inspector = inspect(e)
+            try:
+                view_names = inspector.get_view_names()
+            except NotImplementedError:
+                # dialect cannot reflect views; nothing to drop
+                pass
+            else:
+                for vname in view_names:
+                    e.execute(schema._DropView(
+                        schema.Table(vname, schema.MetaData())
+                    ))
+
+            if config.requirements.schemas.enabled_for_config(cfg):
+                try:
+                    view_names = inspector.get_view_names(
+                        schema="test_schema")
+                except NotImplementedError:
+                    pass
+                else:
+                    for vname in view_names:
+                        e.execute(schema._DropView(
+                            schema.Table(vname, schema.MetaData(),
+                                         schema="test_schema")
+                        ))
+
+            # reversed FK order so dependent tables go first
+            for tname in reversed(inspector.get_table_names(
+                    order_by="foreign_key")):
+                e.execute(schema.DropTable(
+                    schema.Table(tname, schema.MetaData())
+                ))
+
+            if config.requirements.schemas.enabled_for_config(cfg):
+                for tname in reversed(inspector.get_table_names(
+                        order_by="foreign_key", schema="test_schema")):
+                    e.execute(schema.DropTable(
+                        schema.Table(tname, schema.MetaData(),
+                                     schema="test_schema")
+                    ))
+
+            if against(cfg, "postgresql"):
+                # enum types survive table drops on PG; remove them too
+                from sqlalchemy.dialects import postgresql
+                for enum in inspector.get_enums("*"):
+                    e.execute(postgresql.DropEnumType(
+                        postgresql.ENUM(
+                            name=enum['name'],
+                            schema=enum['schema'])))
+
+
+
+
+@post
+def _reverse_topological(options, file_config):
+    # --reversetop: randomized set ordering flushes out ORM dependency bugs
+    if options.reversetop:
+        from sqlalchemy.orm.util import randomize_unitofwork
+        randomize_unitofwork()
+
+
+@post
+def _post_setup_options(opt, file_config):
+    # publish options/file_config on alembic.testing.config.
+    # NOTE(review): assigns the module-level 'options' (set in pre_begin),
+    # not the 'opt' parameter — same object in practice; confirm.
+    from alembic.testing import config
+    config.options = options
+    config.file_config = file_config
+
+
+
+def want_class(cls):
+    """Return True if *cls* should be collected as a test class.
+
+    Requires a fixtures.TestBase subclass whose name does not start with
+    an underscore; under --backend-only the class must set __backend__.
+    """
+    if not issubclass(cls, fixtures.TestBase):
+        return False
+    elif cls.__name__.startswith('_'):
+        return False
+    elif config.options.backend_only and not getattr(cls, '__backend__',
+                                                     False):
+        return False
+    else:
+        return True
+
+
+def want_method(cls, fn):
+    """Return True if method *fn* of *cls* should be collected.
+
+    Requires a 'test_' name and a module; then applies the include/exclude
+    tag sets against the class-level __tags__ and any per-function
+    _sa_exclusion_extend markers.
+    """
+    if not fn.__name__.startswith("test_"):
+        return False
+    elif fn.__module__ is None:
+        return False
+    elif include_tags:
+        # with include tags, either the class tags or the function's own
+        # exclusion extension must match
+        return (
+            hasattr(cls, '__tags__') and
+            exclusions.tags(cls.__tags__).include_test(
+                include_tags, exclude_tags)
+        ) or (
+            hasattr(fn, '_sa_exclusion_extend') and
+            fn._sa_exclusion_extend.include_test(
+                include_tags, exclude_tags)
+        )
+    elif exclude_tags and hasattr(cls, '__tags__'):
+        return exclusions.tags(cls.__tags__).include_test(
+            include_tags, exclude_tags)
+    elif exclude_tags and hasattr(fn, '_sa_exclusion_extend'):
+        return fn._sa_exclusion_extend.include_test(include_tags, exclude_tags)
+    else:
+        return True
+
+
+def generate_sub_tests(cls, module):
+    """Yield per-backend subclasses of *cls* (or *cls* itself).
+
+    For __backend__ classes, a subclass pinned via __only_on__ is created
+    and installed on *module* for each eligible database config.
+    """
+    if getattr(cls, '__backend__', False):
+        for cfg in _possible_configs_for_cls(cls):
+            name = "%s_%s_%s" % (cls.__name__, cfg.db.name, cfg.db.driver)
+            subcls = type(
+                name,
+                (cls, ),
+                {
+                    "__only_on__": ("%s+%s" % (cfg.db.name, cfg.db.driver)),
+                }
+            )
+            # must live on the module so the test framework can find it
+            setattr(module, name, subcls)
+            yield subcls
+    else:
+        yield cls
+
+
+def start_test_class(cls):
+    """Class-level setup: apply skips, then push any per-class engine."""
+    _do_skips(cls)
+    _setup_engine(cls)
+
+
+def stop_test_class(cls):
+    """Class-level teardown: restore the default engine config."""
+    #from sqlalchemy import inspect
+    #assert not inspect(testing.db).get_table_names()
+    _restore_engine()
+
+
+def _restore_engine():
+    # pop any engine/config pushed during _setup_engine/_do_skips
+    config._current.reset()
+
+
+def _setup_engine(cls):
+    # honor a per-class __engine_options__ dict by pushing a fresh engine
+    if getattr(cls, '__engine_options__', None):
+        eng = engines.testing_engine(options=cls.__engine_options__)
+        config._current.push_engine(eng)
+
+
+def before_test(test, test_module_name, test_class, test_name):
+    # per-test setup hook; intentionally a no-op in this adapted copy
+    pass
+
+
+def after_test(test):
+    # per-test teardown hook; intentionally a no-op in this adapted copy
+    pass
+
+
+def _possible_configs_for_cls(cls, reasons=None):
+    """Return the set of database configs on which *cls* may run.
+
+    Filters all registered configs by __unsupported_on__, __only_on__,
+    __requires__ (collecting skip reasons into *reasons* if given) and
+    __prefer_requires__ (a soft filter applied only if something remains).
+    """
+    all_configs = set(config.Config.all_configs())
+
+    if cls.__unsupported_on__:
+        spec = exclusions.db_spec(*cls.__unsupported_on__)
+        for config_obj in list(all_configs):
+            if spec(config_obj):
+                all_configs.remove(config_obj)
+
+    if getattr(cls, '__only_on__', None):
+        spec = exclusions.db_spec(*util.to_list(cls.__only_on__))
+        for config_obj in list(all_configs):
+            if not spec(config_obj):
+                all_configs.remove(config_obj)
+
+    if hasattr(cls, '__requires__'):
+        requirements = config.requirements
+        for config_obj in list(all_configs):
+            for requirement in cls.__requires__:
+                check = getattr(requirements, requirement)
+
+                skip_reasons = check.matching_config_reasons(config_obj)
+                if skip_reasons:
+                    all_configs.remove(config_obj)
+                    if reasons is not None:
+                        reasons.extend(skip_reasons)
+                    break
+
+    if hasattr(cls, '__prefer_requires__'):
+        # soft preference: drop non-preferred configs only when at least
+        # one preferred config would remain
+        non_preferred = set()
+        requirements = config.requirements
+        for config_obj in list(all_configs):
+            for requirement in cls.__prefer_requires__:
+                check = getattr(requirements, requirement)
+
+                if not check.enabled_for_config(config_obj):
+                    non_preferred.add(config_obj)
+        if all_configs.difference(non_preferred):
+            all_configs.difference_update(non_preferred)
+
+    return all_configs
+
+
+def _do_skips(cls):
+    """Raise SkipTest if *cls* cannot run, else select a config for it.
+
+    Applies __skip_if__ callables, then the config filters from
+    _possible_configs_for_cls(); honors __prefer_backends__ as a soft
+    preference before making one surviving config current.
+    """
+    reasons = []
+    all_configs = _possible_configs_for_cls(cls, reasons)
+
+    if getattr(cls, '__skip_if__', False):
+        for c in getattr(cls, '__skip_if__'):
+            if c():
+                raise SkipTest("'%s' skipped by %s" % (
+                    cls.__name__, c.__name__)
+                )
+
+    if not all_configs:
+        if getattr(cls, '__backend__', False):
+            msg = "'%s' unsupported for implementation '%s'" % (
+                cls.__name__, cls.__only_on__)
+        else:
+            msg = "'%s' unsupported on any DB implementation %s%s" % (
+                cls.__name__,
+                ", ".join(
+                    "'%s(%s)+%s'" % (
+                        config_obj.db.name,
+                        ".".join(
+                            str(dig) for dig in
+                            config_obj.db.dialect.server_version_info),
+                        config_obj.db.driver
+                    )
+                    for config_obj in config.Config.all_configs()
+                ),
+                ", ".join(reasons)
+            )
+        raise SkipTest(msg)
+    elif hasattr(cls, '__prefer_backends__'):
+        # soft preference, same pattern as __prefer_requires__
+        non_preferred = set()
+        spec = exclusions.db_spec(*util.to_list(cls.__prefer_backends__))
+        for config_obj in all_configs:
+            if not spec(config_obj):
+                non_preferred.add(config_obj)
+        if all_configs.difference(non_preferred):
+            all_configs.difference_update(non_preferred)
+
+    if config._current not in all_configs:
+        _setup_config(all_configs.pop(), cls)
+
+
+def _setup_config(config_obj, ctx):
+    # make config_obj current for the duration of ctx's tests
+    config._current.push(config_obj)
--- /dev/null
+"""NOTE: copied/adapted from SQLAlchemy master for backwards compatibility;
+ this should be removable when Alembic targets SQLAlchemy 0.9.4.
+"""
+from sqlalchemy.engine import url as sa_url
+from sqlalchemy import text
+from alembic import compat
+from alembic.testing import config, engines
+from alembic.testing.compat import get_url_backend_name
+
+
+class register(object):
+    """Per-backend dispatch registry used as a decorator.
+
+    ``@register.init`` turns a function into a dispatcher whose default
+    ('*') implementation is that function; ``@<dispatcher>.for_db(name)``
+    adds a backend-specific implementation. Calling the dispatcher routes
+    on the backend name derived from the config/URL first argument.
+    """
+    def __init__(self):
+        self.fns = {}
+
+    @classmethod
+    def init(cls, fn):
+        return register().for_db("*")(fn)
+
+    def for_db(self, dbname):
+        def decorate(fn):
+            self.fns[dbname] = fn
+            return self
+        return decorate
+
+    def __call__(self, cfg, *arg):
+        # cfg may be a URL string, a URL object, or a Config with a .db
+        if isinstance(cfg, compat.string_types):
+            url = sa_url.make_url(cfg)
+        elif isinstance(cfg, sa_url.URL):
+            url = cfg
+        else:
+            url = cfg.db.url
+        backend = get_url_backend_name(url)
+        if backend in self.fns:
+            return self.fns[backend](cfg, *arg)
+        else:
+            return self.fns['*'](cfg, *arg)
+
+
+def create_follower_db(follower_ident):
+    """Create the per-follower database on each distinct db host."""
+
+    for cfg in _configs_for_db_operation():
+        _create_db(cfg, cfg.db, follower_ident)
+
+
+def configure_follower(follower_ident):
+    """Apply follower-specific settings to every registered config."""
+    for cfg in config.Config.all_configs():
+        _configure_follower(cfg, follower_ident)
+
+
+def setup_config(db_url, db_opts, options, file_config, follower_ident):
+    """Build, sanity-check and register a Config for *db_url*.
+
+    With a follower_ident the URL is first rewritten to the follower's
+    database; a throwaway connection verifies the engine works.
+    """
+    if follower_ident:
+        db_url = _follower_url_from_main(db_url, follower_ident)
+    eng = engines.testing_engine(db_url, db_opts)
+    # fail fast if the database is unreachable
+    eng.connect().close()
+    cfg = config.Config.register(eng, db_opts, options, file_config)
+    if follower_ident:
+        _configure_follower(cfg, follower_ident)
+    return cfg
+
+
+def drop_follower_db(follower_ident):
+    """Drop the per-follower database on each distinct db host."""
+    for cfg in _configs_for_db_operation():
+        _drop_db(cfg, cfg.db, follower_ident)
+
+
+def _configs_for_db_operation():
+    """Yield one config per distinct (backend, user, host, database).
+
+    Pools are disposed before and after so create/drop DDL does not fight
+    with lingering connections.
+    """
+    hosts = set()
+
+    for cfg in config.Config.all_configs():
+        cfg.db.dispose()
+
+    for cfg in config.Config.all_configs():
+        url = cfg.db.url
+        backend = get_url_backend_name(url)
+        host_conf = (
+            backend,
+            url.username, url.host, url.database)
+
+        # only yield one config per physical database target
+        if host_conf not in hosts:
+            yield cfg
+            hosts.add(host_conf)
+
+    for cfg in config.Config.all_configs():
+        cfg.db.dispose()
+
+
+@register.init
+def _create_db(cfg, eng, ident):
+    # default implementation: backends must register a for_db() override
+    raise NotImplementedError("no DB creation routine for cfg: %s" % eng.url)
+
+
+@register.init
+def _drop_db(cfg, eng, ident):
+    # default implementation: backends must register a for_db() override
+    raise NotImplementedError("no DB drop routine for cfg: %s" % eng.url)
+
+
+@register.init
+def _configure_follower(cfg, ident):
+    # default: no follower-specific configuration needed
+    pass
+
+
+@register.init
+def _follower_url_from_main(url, ident):
+    """Default: point the URL at a database named after the follower ident."""
+    url = sa_url.make_url(url)
+    # NOTE(review): mutates the URL object in place; presumably fine on the
+    # SQLAlchemy versions this targets (URL was mutable pre-1.4) — confirm.
+    url.database = ident
+    return url
+
+
+@_follower_url_from_main.for_db("sqlite")
+def _sqlite_follower_url_from_main(url, ident):
+    """SQLite: keep memory/empty URLs; otherwise use a per-ident file."""
+    url = sa_url.make_url(url)
+    if not url.database or url.database == ':memory:':
+        return url
+    else:
+        return sa_url.make_url("sqlite:///%s.db" % ident)
+
+
+@_create_db.for_db("postgresql")
+def _pg_create_db(cfg, eng, ident):
+    # CREATE DATABASE cannot run in a transaction; use autocommit
+    with eng.connect().execution_options(
+            isolation_level="AUTOCOMMIT") as conn:
+        try:
+            _pg_drop_db(cfg, conn, ident)
+        except:
+            # best-effort pre-drop; the db may simply not exist yet.
+            # NOTE(review): bare except is vendored as-is from SQLAlchemy.
+            pass
+        currentdb = conn.scalar("select current_database()")
+        conn.execute("CREATE DATABASE %s TEMPLATE %s" % (ident, currentdb))
+
+
+@_create_db.for_db("mysql")
+def _mysql_create_db(cfg, eng, ident):
+    with eng.connect() as conn:
+        try:
+            _mysql_drop_db(cfg, conn, ident)
+        except:
+            # best-effort pre-drop; the dbs may simply not exist yet.
+            # NOTE(review): bare except is vendored as-is from SQLAlchemy.
+            pass
+        # main db plus the two extra schemas the suite's schema tests use
+        conn.execute("CREATE DATABASE %s" % ident)
+        conn.execute("CREATE DATABASE %s_test_schema" % ident)
+        conn.execute("CREATE DATABASE %s_test_schema_2" % ident)
+
+
+@_configure_follower.for_db("mysql")
+def _mysql_configure_follower(config, ident):
+    # point the config's schema-test names at this follower's databases
+    config.test_schema = "%s_test_schema" % ident
+    config.test_schema_2 = "%s_test_schema_2" % ident
+
+
+@_create_db.for_db("sqlite")
+def _sqlite_create_db(cfg, eng, ident):
+    # sqlite files spring into existence on connect; nothing to create
+    pass
+
+
+@_drop_db.for_db("postgresql")
+def _pg_drop_db(cfg, eng, ident):
+    # DROP DATABASE cannot run in a transaction; use autocommit
+    with eng.connect().execution_options(
+            isolation_level="AUTOCOMMIT") as conn:
+        # kick our own stray sessions off the target db so the drop succeeds
+        conn.execute(
+            text(
+                "select pg_terminate_backend(pid) from pg_stat_activity "
+                "where usename=current_user and pid != pg_backend_pid() "
+                "and datname=:dname"
+            ), dname=ident)
+        conn.execute("DROP DATABASE %s" % ident)
+
+
+@_drop_db.for_db("sqlite")
+def _sqlite_drop_db(cfg, eng, ident):
+    # the follower's .db file is deliberately left behind (see below)
+    pass
+    #os.remove("%s.db" % ident)
+
+
+@_drop_db.for_db("mysql")
+def _mysql_drop_db(cfg, eng, ident):
+    # best-effort: each database may or may not exist.
+    # NOTE(review): bare excepts are vendored as-is from SQLAlchemy.
+    with eng.connect() as conn:
+        try:
+            conn.execute("DROP DATABASE %s_test_schema" % ident)
+        except:
+            pass
+        try:
+            conn.execute("DROP DATABASE %s_test_schema_2" % ident)
+        except:
+            pass
+        try:
+            conn.execute("DROP DATABASE %s" % ident)
+        except:
+            pass
--- /dev/null
+"""NOTE: copied/adapted from SQLAlchemy master for backwards compatibility;
+ this should be removable when Alembic targets SQLAlchemy 0.9.4.
+"""
+import pytest
+import argparse
+import inspect
+from . import plugin_base
+import collections
+import itertools
+
+try:
+ import xdist
+ has_xdist = True
+except ImportError:
+ has_xdist = False
+
+
+def pytest_addoption(parser):
+    """py.test hook: register plugin_base's options on the 'sqlalchemy' group.
+
+    plugin_base speaks optparse-style callbacks; adapt them to argparse
+    Actions for py.test's parser.
+    """
+    group = parser.getgroup("sqlalchemy")
+
+    def make_option(name, **kw):
+        callback_ = kw.pop("callback", None)
+        if callback_:
+            class CallableAction(argparse.Action):
+                def __call__(self, parser, namespace,
+                             values, option_string=None):
+                    callback_(option_string, values, parser)
+            kw["action"] = CallableAction
+
+        group.addoption(name, **kw)
+
+    plugin_base.setup_options(make_option)
+    plugin_base.read_config()
+
+
+def pytest_configure(config):
+    """py.test hook: run plugin_base's pre/post setup phases.
+
+    In an xdist follower ("slave") process, first restore the memoized
+    parent config and apply the follower ident.
+    """
+    if hasattr(config, "slaveinput"):
+        plugin_base.restore_important_follower_config(config.slaveinput)
+        plugin_base.configure_follower(
+            config.slaveinput["follower_ident"]
+        )
+
+    plugin_base.pre_begin(config.option)
+
+    plugin_base.set_coverage_flag(bool(getattr(config.option,
+                                               "cov_source", False)))
+
+    plugin_base.post_begin()
+
+# xdist-only hooks: give each worker its own numbered follower database
+if has_xdist:
+    _follower_count = itertools.count(1)
+
+    def pytest_configure_node(node):
+        # the master for each node fills slaveinput dictionary
+        # which pytest-xdist will transfer to the subprocess
+
+        plugin_base.memoize_important_follower_config(node.slaveinput)
+
+        node.slaveinput["follower_ident"] = "test_%s" % next(_follower_count)
+        from . import provision
+        provision.create_follower_db(node.slaveinput["follower_ident"])
+
+    def pytest_testnodedown(node, error):
+        # drop the follower's databases when its worker exits
+        from . import provision
+        provision.drop_follower_db(node.slaveinput["follower_ident"])
+
+
+def pytest_collection_modifyitems(session, config, items):
+    """py.test hook: expand __backend__ classes into per-database items.
+
+    Replaces each item of an expanded class with the items collected from
+    its generated per-backend subclasses, then sorts deterministically.
+    """
+    # look for all those classes that specify __backend__ and
+    # expand them out into per-database test cases.
+
+    # this is much easier to do within pytest_pycollect_makeitem, however
+    # pytest is iterating through cls.__dict__ as makeitem is
+    # called which causes a "dictionary changed size" error on py3k.
+    # I'd submit a pullreq for them to turn it into a list first, but
+    # it's to suit the rather odd use case here which is that we are adding
+    # new classes to a module on the fly.
+
+    rebuilt_items = collections.defaultdict(list)
+    # only keep method items (those parented by a pytest.Instance)
+    items[:] = [
+        item for item in
+        items if isinstance(item.parent, pytest.Instance)]
+    test_classes = set(item.parent for item in items)
+    for test_class in test_classes:
+        for sub_cls in plugin_base.generate_sub_tests(
+                test_class.cls, test_class.parent.module):
+            if sub_cls is not test_class.cls:
+                list_ = rebuilt_items[test_class.cls]
+
+                for inst in pytest.Class(
+                        sub_cls.__name__,
+                        parent=test_class.parent.parent).collect():
+                    list_.extend(inst.collect())
+
+    newitems = []
+    for item in items:
+        if item.parent.cls in rebuilt_items:
+            # substitute the expanded items exactly once per class
+            newitems.extend(rebuilt_items[item.parent.cls])
+            rebuilt_items[item.parent.cls][:] = []
+        else:
+            newitems.append(item)
+
+    # seems like the functions attached to a test class aren't sorted already?
+    # is that true and why's that? (when using unittest, they're sorted)
+    items[:] = sorted(newitems, key=lambda item: (
+        item.parent.parent.parent.name,
+        item.parent.parent.name,
+        item.name
+    ))
+
+
+def pytest_pycollect_makeitem(collector, name, obj):
+    """py.test hook: route collection decisions through plugin_base."""
+    if inspect.isclass(obj) and plugin_base.want_class(obj):
+        return pytest.Class(name, parent=collector)
+    elif inspect.isfunction(obj) and \
+            isinstance(collector, pytest.Instance) and \
+            plugin_base.want_method(collector.cls, obj):
+        return pytest.Function(name, parent=collector)
+    else:
+        # returning [] collects nothing for this name
+        return []
+
+_current_class = None
+
+def pytest_runtest_setup(item):
+    """py.test hook: emulate class-level setup around collected items.
+
+    Tracks the current test class in the module-level _current_class so
+    class_setup runs once per class, with class_teardown registered as a
+    finalizer on the class node.
+    """
+    # here we seem to get called only based on what we collected
+    # in pytest_collection_modifyitems. So to do class-based stuff
+    # we have to tear that out.
+    global _current_class
+
+    if not isinstance(item, pytest.Function):
+        return
+
+    # ... so we're doing a little dance here to figure it out...
+    if _current_class is None:
+        class_setup(item.parent.parent)
+        _current_class = item.parent.parent
+
+        # this is needed for the class-level, to ensure that the
+        # teardown runs after the class is completed with its own
+        # class-level teardown...
+        def finalize():
+            global _current_class
+            class_teardown(item.parent.parent)
+            _current_class = None
+        item.parent.parent.addfinalizer(finalize)
+
+    test_setup(item)
+
+
+def pytest_runtest_teardown(item):
+    # ...but this works better as the hook here rather than
+    # using a finalizer, as the finalizer seems to get in the way
+    # of the test reporting failures correctly (you get a bunch of
+    # py.test assertion stuff instead)
+    test_teardown(item)
+
+
+def test_setup(item):
+    # per-test setup, delegated to plugin_base
+    plugin_base.before_test(item, item.parent.module.__name__,
+                            item.parent.cls, item.name)
+
+
+def test_teardown(item):
+    # per-test teardown, delegated to plugin_base
+    plugin_base.after_test(item)
+
+
+def class_setup(item):
+    # per-class setup (skips + engine push), delegated to plugin_base
+    plugin_base.start_test_class(item.cls)
+
+
+def class_teardown(item):
+    # per-class teardown (engine restore), delegated to plugin_base
+    plugin_base.stop_test_class(item.cls)