]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
added a testing decorator that mops up wayward connections
authorJason Kirtland <jek@discorporate.us>
Fri, 17 Aug 2007 01:24:46 +0000 (01:24 +0000)
committerJason Kirtland <jek@discorporate.us>
Fri, 17 Aug 2007 01:24:46 +0000 (01:24 +0000)
test/testlib/__init__.py
test/testlib/engines.py

index 046f8f9b407139923858363eef4c6d10e57b9589..d30e4fc28738c9e0f65988fdd117271fc689fb09 100644 (file)
@@ -8,9 +8,10 @@ from testlib.schema import Table, Column
 import testlib.testing as testing
 from testlib.testing import PersistTest, AssertMixin, ORMTest, SQLCompileTest
 import testlib.profiling as profiling
-import testlib.engines
+import testlib.engines as engines
 
 
 __all__ = ('testing',
            'Table', 'Column',
-           'PersistTest', 'AssertMixin', 'ORMTest', 'SQLCompileTest', 'profiling')
+           'PersistTest', 'AssertMixin', 'ORMTest', 'SQLCompileTest',
+           'profiling', 'engines')
index addef3f9eeea2fcf8e1e76a9ec5bb002545020ff..a66b336e948fa8ac4b4f7878cc24bc222c5b3c1b 100644 (file)
@@ -1,6 +1,58 @@
+import sys, weakref
 from testlib import config
 
 
+class ConnectionKiller(object):
+    def __init__(self):
+        self.record_refs = []
+        
+    def connect(self, dbapi_con, con_record):
+        self.record_refs.append(weakref.ref(con_record))
+
+    def _apply_all(self, methods):
+        for ref in self.record_refs:
+            rec = ref()
+            if rec is not None and rec.connection is not None:
+                try:
+                    for name in methods:
+                        getattr(rec.connection, name)()
+                except (SystemExit, KeyboardInterrupt):
+                    raise
+                except Exception, e:
+                    # fixme
+                    sys.stderr.write("\n" + str(e) + "\n")
+        del self.record_refs[:]
+
+    def rollback_all(self):
+        self._apply_all(('rollback',))
+    def close_all(self):
+        self._apply_all(('rollback','close'))
+
+testing_reaper = ConnectionKiller()
+
+def rollback_open_connections(fn):
+    """Decorator that rolls back all open connections after fn execution."""
+
+    def decorated(*args, **kw):
+        try:
+            fn(*args, **kw)
+        finally:
+            testing_reaper.rollback_all()
+    decorated.__name__ = fn.__name__
+    return decorated
+
+def close_open_connections(fn):
+    """Decorator that closes all connections after fn execution."""
+
+    def decorated(*args, **kw):
+        try:
+            fn(*args, **kw)
+        finally:
+            testing_reaper.close_all()
+    decorated.__name__ = fn.__name__
+    return decorated
+
+
 def testing_engine(url=None, options=None):
     """Produce an engine configured by --options with optional overrides."""
     
@@ -10,6 +62,9 @@ def testing_engine(url=None, options=None):
     url = url or config.db_url
     options = options or config.db_opts
 
+    listeners = options.setdefault('listeners', [])
+    listeners.append(testing_reaper)
+
     engine = create_engine(url, **options)
 
     create_context = engine.dialect.create_execution_context