import testlib.testing as testing
from testlib.testing import PersistTest, AssertMixin, ORMTest, SQLCompileTest
import testlib.profiling as profiling
-import testlib.engines
+import testlib.engines as engines
__all__ = ('testing',
'Table', 'Column',
- 'PersistTest', 'AssertMixin', 'ORMTest', 'SQLCompileTest', 'profiling')
+ 'PersistTest', 'AssertMixin', 'ORMTest', 'SQLCompileTest',
+ 'profiling', 'engines')
+import sys, weakref
from testlib import config
+class ConnectionKiller(object):
+ def __init__(self):
+ self.record_refs = []
+
+ def connect(self, dbapi_con, con_record):
+ self.record_refs.append(weakref.ref(con_record))
+
+ def _apply_all(self, methods):
+ for ref in self.record_refs:
+ rec = ref()
+ if rec is not None and rec.connection is not None:
+ try:
+ for name in methods:
+ getattr(rec.connection, name)()
+ except (SystemExit, KeyboardInterrupt):
+ raise
+ except Exception, e:
+ # fixme
+ sys.stderr.write("\n" + str(e) + "\n")
+ del self.record_refs[:]
+
+ def rollback_all(self):
+ self._apply_all(('rollback',))
+ def close_all(self):
+ self._apply_all(('rollback','close'))
+
+testing_reaper = ConnectionKiller()
+
+def rollback_open_connections(fn):
+ """Decorator that rolls back all open connections after fn execution."""
+
+ def decorated(*args, **kw):
+ try:
+ fn(*args, **kw)
+ finally:
+ testing_reaper.rollback_all()
+ decorated.__name__ = fn.__name__
+ return decorated
+
+def close_open_connections(fn):
+ """Decorator that closes all connections after fn execution."""
+
+ def decorated(*args, **kw):
+ try:
+ fn(*args, **kw)
+ finally:
+ testing_reaper.close_all()
+ decorated.__name__ = fn.__name__
+ return decorated
+
+
def testing_engine(url=None, options=None):
"""Produce an engine configured by --options with optional overrides."""
url = url or config.db_url
options = options or config.db_opts
+ listeners = options.setdefault('listeners', [])
+ listeners.append(testing_reaper)
+
engine = create_engine(url, **options)
create_context = engine.dialect.create_execution_context