]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- replace most explicitly-named test objects called "Mock..." with
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 30 Jun 2013 22:35:12 +0000 (18:35 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 30 Jun 2013 22:41:35 +0000 (18:41 -0400)
actual mock objects from the mock library.  I'd like to use mock
for new tests so we might as well use it in obvious places.
- use unittest.mock in py3.3
- changelog
- add a note to README.unittests
- add tests_require in setup.py
- have tests import from sqlalchemy.testing.mock
- apply usage of mock to one of the event tests.  we can be using
this approach all over the place.

16 files changed:
README.unittests.rst
doc/build/changelog/changelog_08.rst
lib/sqlalchemy/testing/__init__.py
lib/sqlalchemy/testing/mock.py [new file with mode: 0644]
lib/sqlalchemy/util/__init__.py
lib/sqlalchemy/util/compat.py
setup.py
test/aaa_profiling/test_resultset.py
test/base/test_events.py
test/dialect/postgresql/test_dialect.py
test/dialect/test_mxodbc.py
test/engine/test_ddlemit.py
test/engine/test_execute.py
test/engine/test_parseconnect.py
test/engine/test_pool.py
test/engine/test_reconnect.py

index ae71898543df2df71128fe46316858b242a95512..7d052cfd7a87ee42f6442ad59216306371c300dc 100644 (file)
@@ -7,12 +7,18 @@ module.  If running on Python 2.4, pysqlite must be installed.
 
 Unit tests are run using nose.  Nose is available at::
 
-    http://pypi.python.org/pypi/nose/
+    https://pypi.python.org/pypi/nose/
 
 SQLAlchemy implements a nose plugin that must be present when tests are run.
 This plugin is invoked when the test runner script provided with
 SQLAlchemy is used.
 
+The test suite as of version 0.8.2 also requires the mock library.  While
+mock is part of the Python standard library as of 3.3, previous versions
+will need to have it installed, and is available at::
+
+    https://pypi.python.org/pypi/mock
+
 **NOTE:** - the nose plugin is no longer installed by setuptools as of
 version 0.7 !  Use "python setup.py test" or "./sqla_nose.py".
 
index c0e430ad66fd8e7539aa7565d8350f68b880d47e..d83d0618eb68dfdecb92d64396abfdd7bbb59dce 100644 (file)
@@ -6,6 +6,15 @@
 .. changelog::
     :version: 0.8.2
 
+    .. change::
+        :tags: requirements
+
+        The Python `mock <https://pypi.python.org/pypi/mock>`_ library
+        is now required in order to run the unit test suite.  While part
+        of the standard library as of Python 3.3, previous Python installations
+        will need to install this in order to run unit tests or to
+        use the ``sqlalchemy.testing`` package for external dialects.
+
     .. change::
         :tags: bug, orm
         :tickets: 2750
index e571a50458f5e8f5b12ded06f7ee96b8abfb6cd6..61e99936796aa52d4b73a0efa8894f619d584ef8 100644 (file)
@@ -18,3 +18,5 @@ from .util import run_as_contextmanager, rowset, fail, provide_metadata, adict
 crashes = skip
 
 from .config import db, requirements as requires
+
+from . import mock
diff --git a/lib/sqlalchemy/testing/mock.py b/lib/sqlalchemy/testing/mock.py
new file mode 100644 (file)
index 0000000..6509623
--- /dev/null
@@ -0,0 +1,15 @@
+"""Import stub for mock library.
+"""
+from __future__ import absolute_import
+from ..util import py33
+
+if py33:
+    from unittest.mock import MagicMock, Mock, call
+else:
+    try:
+        from mock import MagicMock, Mock, call
+    except ImportError:
+        raise ImportError(
+                "SQLAlchemy's test suite requires the "
+                "'mock' library as of 0.8.2.")
+
index 4ea778a7704838b04aa59ee66ea22efc5a64750c..c2c5bddb85143e414a9048c8225808eccdccd4e0 100644 (file)
@@ -5,7 +5,8 @@
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
 from .compat import callable, cmp, reduce,  \
-    threading, py3k, py2k, py3k_warning, jython, pypy, cpython, win32, set_types, \
+    threading, py3k, py33, py2k, py3k_warning, jython, pypy, cpython, win32, \
+    set_types, \
     pickle, dottedgetter, parse_qsl, namedtuple, next, WeakSet, reraise, \
     raise_from_cause, u, b, ue, string_types, text_type, int_types
 
index 1ea4be917a62b82f011348809d17e5bf36492e78..24d78fa701b494eeb75283e96085a7b9ab64c43b 100644 (file)
@@ -13,6 +13,7 @@ try:
 except ImportError:
     import dummy_threading as threading
 
+py33 = sys.version_info >= (3, 3)
 py32 = sys.version_info >= (3, 2)
 py3k_warning = getattr(sys, 'py3kwarning', False) or sys.version_info >= (3, 0)
 py3k = sys.version_info >= (3, 0)
index 2950a12d40f68d03c4aa22ad97ca93e8a3675c72..0cdbb40d77c538e44431513e3a694018a0e83062 100644 (file)
--- a/setup.py
+++ b/setup.py
@@ -131,7 +131,7 @@ def run_setup(with_cext):
           license="MIT License",
           cmdclass=cmdclass,
 
-          tests_require=['nose >= 0.11'],
+          tests_require=['nose >= 0.11', 'mock'],
           test_suite="sqla_nose",
           long_description=readme,
           classifiers=[
index 0146d1b08627836007e56fabe623ba6bcf4cd700..4d92e604ab9283569215f1441e931500364dc555 100644 (file)
@@ -2,6 +2,9 @@ from sqlalchemy import *
 from sqlalchemy.testing import fixtures, AssertsExecutionResults, profiling
 from sqlalchemy import testing
 from sqlalchemy.testing import eq_
+from sqlalchemy.engine.result import RowProxy
+import sys
+
 NUM_FIELDS = 10
 NUM_RECORDS = 1000
 
@@ -79,7 +82,6 @@ class RowProxyTest(fixtures.TestBase):
     __requires__ = 'cpython',
 
     def _rowproxy_fixture(self, keys, processors, row):
-        from sqlalchemy.engine.result import RowProxy
         class MockMeta(object):
             def __init__(self):
                 pass
@@ -95,7 +97,6 @@ class RowProxyTest(fixtures.TestBase):
         return RowProxy(metadata, row, processors, keymap)
 
     def _test_getitem_value_refcounts(self, seq_factory):
-        import sys
         col1, col2 = object(), object()
         def proc1(value):
             return value
index 4efb30aba26f07f05224bc5eeb5953b9109bf2fe..20bfa62ff936b13d6801997c1beb52784465a409 100644 (file)
@@ -5,6 +5,8 @@ from sqlalchemy.testing import eq_, assert_raises, assert_raises_message, \
 from sqlalchemy import event, exc
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing.util import gc_collect
+from sqlalchemy.testing.mock import Mock, call
+
 
 class EventsTest(fixtures.TestBase):
     """Test class- and instance-level event registration."""
@@ -190,7 +192,7 @@ class ClsLevelListenTest(fixtures.TestBase):
     def test_lis_subcalss_lis(self):
         @event.listens_for(self.TargetOne, "event_one")
         def handler1(x, y):
-            print 'handler1'
+            pass
 
         class SubTarget(self.TargetOne):
             pass
@@ -207,7 +209,7 @@ class ClsLevelListenTest(fixtures.TestBase):
     def test_lis_multisub_lis(self):
         @event.listens_for(self.TargetOne, "event_one")
         def handler1(x, y):
-            print 'handler1'
+            pass
 
         class SubTarget(self.TargetOne):
             pass
@@ -411,12 +413,8 @@ class ListenOverrideTest(fixtures.TestBase):
         event._remove_dispatcher(self.Target.__dict__['dispatch'].events)
 
     def test_listen_override(self):
-        result = []
-        def listen_one(x):
-            result.append(x)
-
-        def listen_two(x, y):
-            result.append((x, y))
+        listen_one = Mock()
+        listen_two = Mock()
 
         event.listen(self.Target, "event_one", listen_one, add=True)
         event.listen(self.Target, "event_one", listen_two)
@@ -425,10 +423,13 @@ class ListenOverrideTest(fixtures.TestBase):
         t1.dispatch.event_one(5, 7)
         t1.dispatch.event_one(10, 5)
 
-        eq_(result,
-            [
-                12, (5, 7), 15, (10, 5)
-            ]
+        eq_(
+            listen_one.mock_calls,
+            [call(12), call(15)]
+        )
+        eq_(
+            listen_two.mock_calls,
+            [call(5, 7), call(10, 5)]
         )
 
 class PropagateTest(fixtures.TestBase):
@@ -446,12 +447,8 @@ class PropagateTest(fixtures.TestBase):
 
 
     def test_propagate(self):
-        result = []
-        def listen_one(target, arg):
-            result.append((target, arg))
-
-        def listen_two(target, arg):
-            result.append((target, arg))
+        listen_one = Mock()
+        listen_two = Mock()
 
         t1 = self.Target()
 
@@ -464,7 +461,15 @@ class PropagateTest(fixtures.TestBase):
 
         t2.dispatch.event_one(t2, 1)
         t2.dispatch.event_two(t2, 2)
-        eq_(result, [(t2, 1)])
+
+        eq_(
+            listen_one.mock_calls,
+            [call(t2, 1)]
+        )
+        eq_(
+            listen_two.mock_calls,
+            []
+        )
 
 class JoinTest(fixtures.TestBase):
     def setUp(self):
@@ -497,12 +502,6 @@ class JoinTest(fixtures.TestBase):
             if 'dispatch' in cls.__dict__:
                 event._remove_dispatcher(cls.__dict__['dispatch'].events)
 
-    def _listener(self):
-        canary = []
-        def listen(target, arg):
-            canary.append((target, arg))
-        return listen, canary
-
     def test_neither(self):
         element = self.TargetFactory().create()
         element.run_event(1)
@@ -510,22 +509,22 @@ class JoinTest(fixtures.TestBase):
         element.run_event(3)
 
     def test_parent_class_only(self):
-        _listener, canary = self._listener()
+        l1 = Mock()
 
-        event.listen(self.TargetFactory, "event_one", _listener)
+        event.listen(self.TargetFactory, "event_one", l1)
 
         element = self.TargetFactory().create()
         element.run_event(1)
         element.run_event(2)
         element.run_event(3)
         eq_(
-            canary,
-            [(element, 1), (element, 2), (element, 3)]
+            l1.mock_calls,
+            [call(element, 1), call(element, 2), call(element, 3)]
         )
 
     def test_parent_class_child_class(self):
-        l1, c1 = self._listener()
-        l2, c2 = self._listener()
+        l1 = Mock()
+        l2 = Mock()
 
         event.listen(self.TargetFactory, "event_one", l1)
         event.listen(self.TargetElement, "event_one", l2)
@@ -535,17 +534,17 @@ class JoinTest(fixtures.TestBase):
         element.run_event(2)
         element.run_event(3)
         eq_(
-            c1,
-            [(element, 1), (element, 2), (element, 3)]
+            l1.mock_calls,
+            [call(element, 1), call(element, 2), call(element, 3)]
         )
         eq_(
-            c2,
-            [(element, 1), (element, 2), (element, 3)]
+            l2.mock_calls,
+            [call(element, 1), call(element, 2), call(element, 3)]
         )
 
     def test_parent_class_child_instance_apply_after(self):
-        l1, c1 = self._listener()
-        l2, c2 = self._listener()
+        l1 = Mock()
+        l2 = Mock()
 
         event.listen(self.TargetFactory, "event_one", l1)
         element = self.TargetFactory().create()
@@ -557,17 +556,17 @@ class JoinTest(fixtures.TestBase):
         element.run_event(3)
 
         eq_(
-            c1,
-            [(element, 1), (element, 2), (element, 3)]
+            l1.mock_calls,
+            [call(element, 1), call(element, 2), call(element, 3)]
         )
         eq_(
-            c2,
-            [(element, 2), (element, 3)]
+            l2.mock_calls,
+            [call(element, 2), call(element, 3)]
         )
 
     def test_parent_class_child_instance_apply_before(self):
-        l1, c1 = self._listener()
-        l2, c2 = self._listener()
+        l1 = Mock()
+        l2 = Mock()
 
         event.listen(self.TargetFactory, "event_one", l1)
         element = self.TargetFactory().create()
@@ -579,17 +578,17 @@ class JoinTest(fixtures.TestBase):
         element.run_event(3)
 
         eq_(
-            c1,
-            [(element, 1), (element, 2), (element, 3)]
+            l1.mock_calls,
+            [call(element, 1), call(element, 2), call(element, 3)]
         )
         eq_(
-            c2,
-            [(element, 1), (element, 2), (element, 3)]
+            l2.mock_calls,
+            [call(element, 1), call(element, 2), call(element, 3)]
         )
 
     def test_parent_instance_child_class_apply_before(self):
-        l1, c1 = self._listener()
-        l2, c2 = self._listener()
+        l1 = Mock()
+        l2 = Mock()
 
         event.listen(self.TargetElement, "event_one", l2)
 
@@ -603,17 +602,18 @@ class JoinTest(fixtures.TestBase):
         element.run_event(3)
 
         eq_(
-            c1,
-            [(element, 1), (element, 2), (element, 3)]
+            l1.mock_calls,
+            [call(element, 1), call(element, 2), call(element, 3)]
         )
         eq_(
-            c2,
-            [(element, 1), (element, 2), (element, 3)]
+            l2.mock_calls,
+            [call(element, 1), call(element, 2), call(element, 3)]
         )
 
+
     def test_parent_instance_child_class_apply_after(self):
-        l1, c1 = self._listener()
-        l2, c2 = self._listener()
+        l1 = Mock()
+        l2 = Mock()
 
         event.listen(self.TargetElement, "event_one", l2)
 
@@ -632,18 +632,16 @@ class JoinTest(fixtures.TestBase):
         # this can be changed to be "live" at the cost
         # of performance.
         eq_(
-            c1,
-            []
-            #(element, 2), (element, 3)]
+            l1.mock_calls, []
         )
         eq_(
-            c2,
-            [(element, 1), (element, 2), (element, 3)]
+            l2.mock_calls,
+            [call(element, 1), call(element, 2), call(element, 3)]
         )
 
     def test_parent_instance_child_instance_apply_before(self):
-        l1, c1 = self._listener()
-        l2, c2 = self._listener()
+        l1 = Mock()
+        l2 = Mock()
         factory = self.TargetFactory()
 
         event.listen(factory, "event_one", l1)
@@ -656,16 +654,16 @@ class JoinTest(fixtures.TestBase):
         element.run_event(3)
 
         eq_(
-            c1,
-            [(element, 1), (element, 2), (element, 3)]
+            l1.mock_calls,
+            [call(element, 1), call(element, 2), call(element, 3)]
         )
         eq_(
-            c2,
-            [(element, 1), (element, 2), (element, 3)]
+            l2.mock_calls,
+            [call(element, 1), call(element, 2), call(element, 3)]
         )
 
     def test_parent_events_child_no_events(self):
-        l1, c1 = self._listener()
+        l1 = Mock()
         factory = self.TargetFactory()
 
         event.listen(self.TargetElement, "event_one", l1)
@@ -676,6 +674,6 @@ class JoinTest(fixtures.TestBase):
         element.run_event(3)
 
         eq_(
-            c1,
-            [(element, 1), (element, 2), (element, 3)]
+            l1.mock_calls,
+            [call(element, 1), call(element, 2), call(element, 3)]
         )
index 86ce91dc9546e709261d96f022ea2c0e42ada048..1fc239cb748a22e4b1bdf7f9ff0408932b476523 100644 (file)
@@ -16,6 +16,7 @@ from sqlalchemy import exc, schema
 from sqlalchemy.dialects.postgresql import base as postgresql
 import logging
 import logging.handlers
+from sqlalchemy.testing.mock import Mock
 
 class MiscTest(fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL):
 
@@ -37,18 +38,12 @@ class MiscTest(fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL):
                       'The JDBC driver handles the version parsing')
     def test_version_parsing(self):
 
-
-        class MockConn(object):
-
-            def __init__(self, res):
-                self.res = res
-
-            def execute(self, str):
-                return self
-
-            def scalar(self):
-                return self.res
-
+        def mock_conn(res):
+            return Mock(
+                    execute=Mock(
+                            return_value=Mock(scalar=Mock(return_value=res))
+                        )
+                    )
 
         for string, version in \
             [('PostgreSQL 8.3.8 on i686-redhat-linux-gnu, compiled by '
@@ -59,7 +54,7 @@ class MiscTest(fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL):
              ('EnterpriseDB 9.1.2.2 on x86_64-unknown-linux-gnu, '
              'compiled by gcc (GCC) 4.1.2 20080704 (Red Hat 4.1.2-50), '
              '64-bit', (9, 1, 2))]:
-            eq_(testing.db.dialect._get_server_version_info(MockConn(string)),
+            eq_(testing.db.dialect._get_server_version_info(mock_conn(string)),
                 version)
 
     @testing.only_on('postgresql+psycopg2', 'psycopg2-specific feature')
index 32cad4168125f33ff9eac9ed37b3e580f7eb8ad7..e46de91494844c2adba68efeeeb1121d69fa5325 100644 (file)
@@ -2,75 +2,48 @@ from sqlalchemy import *
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import engines
 from sqlalchemy.testing import fixtures
-
-# TODO: we should probably build mock bases for
-# these to share with test_reconnect, test_parseconnect
-class MockDBAPI(object):
-    paramstyle = 'qmark'
-    def __init__(self):
-        self.log = []
-    def connect(self, *args, **kwargs):
-        return MockConnection(self)
-
-class MockConnection(object):
-    def __init__(self, parent):
-        self.parent = parent
-    def cursor(self):
-        return MockCursor(self)
-    def close(self):
-        pass
-    def rollback(self):
-        pass
-    def commit(self):
-        pass
-
-class MockCursor(object):
-    description = None
-    rowcount = None
-    def __init__(self, parent):
-        self.parent = parent
-    def execute(self, *args, **kwargs):
-        if kwargs.get('direct', False):
-            self.executedirect()
-        else:
-            self.parent.parent.log.append('execute')
-    def executedirect(self, *args, **kwargs):
-        self.parent.parent.log.append('executedirect')
-    def close(self):
-        pass
+from sqlalchemy.testing.mock import Mock
+
+def mock_dbapi():
+    return Mock(paramstyle='qmark',
+                connect=Mock(
+                        return_value=Mock(
+                            cursor=Mock(
+                                return_value=Mock(
+                                                description=None,
+                                                rowcount=None)
+                                )
+                        )
+                )
+            )
 
 class MxODBCTest(fixtures.TestBase):
 
     def test_native_odbc_execute(self):
         t1 = Table('t1', MetaData(), Column('c1', Integer))
-        dbapi = MockDBAPI()
+        dbapi = mock_dbapi()
+
         engine = engines.testing_engine('mssql+mxodbc://localhost',
                 options={'module': dbapi, '_initialize': False})
         conn = engine.connect()
 
         # crud: uses execute
-
         conn.execute(t1.insert().values(c1='foo'))
         conn.execute(t1.delete().where(t1.c.c1 == 'foo'))
-        conn.execute(t1.update().where(t1.c.c1 == 'foo').values(c1='bar'
-                     ))
+        conn.execute(t1.update().where(t1.c.c1 == 'foo').values(c1='bar'))
 
         # select: uses executedirect
-
         conn.execute(t1.select())
 
         # manual flagging
-
         conn.execution_options(native_odbc_execute=True).\
                 execute(t1.select())
         conn.execution_options(native_odbc_execute=False).\
-                execute(t1.insert().values(c1='foo'
-                ))
-        eq_(dbapi.log, [
-            'executedirect',
-            'executedirect',
-            'executedirect',
-            'executedirect',
-            'execute',
-            'executedirect',
-            ])
+                execute(t1.insert().values(c1='foo'))
+
+        eq_(
+            [c[2] for c in
+            dbapi.connect.return_value.cursor.return_value.execute.mock_calls],
+            [{'direct': True}, {'direct': True}, {'direct': True},
+                {'direct': True}, {'direct': False}, {'direct': True}]
+        )
index 3dbd5756ad0cb98ae81b607d344921fe4a98fb15..040b741b4d1967f650acdee55abcc11e438f581e 100644 (file)
@@ -3,28 +3,19 @@ from sqlalchemy.engine.ddl import SchemaGenerator, SchemaDropper
 from sqlalchemy.engine import default
 from sqlalchemy import MetaData, Table, Column, Integer, Sequence
 from sqlalchemy import schema
+from sqlalchemy.testing.mock import Mock
 
 class EmitDDLTest(fixtures.TestBase):
     def _mock_connection(self, item_exists):
-        _canary = []
+        def has_item(connection, name, schema):
+            return item_exists(name)
 
-        class MockDialect(default.DefaultDialect):
-            supports_sequences = True
-
-            def has_table(self, connection, name, schema):
-                return item_exists(name)
-
-            def has_sequence(self, connection, name, schema):
-                return item_exists(name)
-
-        class MockConnection(object):
-            dialect = MockDialect()
-            canary = _canary
-
-            def execute(self, item):
-                _canary.append(item)
-
-        return MockConnection()
+        return Mock(dialect=Mock(
+                    supports_sequences=True,
+                    has_table=Mock(side_effect=has_item),
+                    has_sequence=Mock(side_effect=has_item)
+                )
+                )
 
     def _mock_create_fixture(self, checkfirst, tables,
                     item_exists=lambda item: False):
@@ -176,7 +167,8 @@ class EmitDDLTest(fixtures.TestBase):
 
     def _assert_ddl(self, ddl_cls, elements, generator, argument):
         generator.traverse_single(argument)
-        for c in generator.connection.canary:
+        for call_ in generator.connection.execute.mock_calls:
+            c = call_[1][0]
             assert isinstance(c, ddl_cls)
             assert c.element in elements, "element %r was not expected"\
                              % c.element
index 203d7bd71fd372abd8f6b1ca8734f80073e0fc44..cbe53d07b9462c4033065b2cc4c7ed6075f057c4 100644 (file)
@@ -18,7 +18,7 @@ from sqlalchemy.dialects.oracle.zxjdbc import ReturningParam
 from sqlalchemy.engine import result as _result, default
 from sqlalchemy.engine.base import Connection, Engine
 from sqlalchemy.testing import fixtures
-import StringIO
+from sqlalchemy.testing.mock import Mock, call
 
 users, metadata, users_autoinc = None, None, None
 class ExecuteTest(fixtures.TestBase):
@@ -455,20 +455,22 @@ class ConvenienceExecuteTest(fixtures.TablesTest):
 
     def test_transaction_engine_ctx_begin_fails(self):
         engine = engines.testing_engine()
-        class MockConnection(Connection):
-            closed = False
-            def begin(self):
-                raise Exception("boom")
-
-            def close(self):
-                MockConnection.closed = True
-        engine._connection_cls = MockConnection
-        fn = self._trans_fn()
+
+        mock_connection = Mock(
+            return_value=Mock(
+                        begin=Mock(side_effect=Exception("boom"))
+                    )
+        )
+        engine._connection_cls = mock_connection
         assert_raises(
             Exception,
             engine.begin
         )
-        assert MockConnection.closed
+
+        eq_(
+            mock_connection.return_value.close.mock_calls,
+            [call()]
+        )
 
     def test_transaction_engine_ctx_rollback(self):
         fn = self._trans_rollback_fn()
index a00a942cb4d359d552cd3d56d58cdafab666ae86..437cccbe45c991fb5a5b090be86cf72dc0a4c147 100644 (file)
@@ -8,6 +8,8 @@ from sqlalchemy.engine.default import DefaultDialect
 import sqlalchemy as tsa
 from sqlalchemy.testing import fixtures
 from sqlalchemy import testing
+from sqlalchemy.testing.mock import Mock
+
 
 class ParseConnectTest(fixtures.TestBase):
     def test_rfc1738(self):
@@ -251,20 +253,17 @@ pool_timeout=10
         every backend.
 
         """
-        # pretend pysqlite throws the
-        # "Cannot operate on a closed database." error
-        # on connect.   IRL we'd be getting Oracle's "shutdown in progress"
 
         e = create_engine('sqlite://')
         sqlite3 = e.dialect.dbapi
-        class ThrowOnConnect(MockDBAPI):
-            dbapi = sqlite3
-            Error = sqlite3.Error
-            ProgrammingError = sqlite3.ProgrammingError
-            def connect(self, *args, **kw):
-                raise sqlite3.ProgrammingError("Cannot operate on a closed database.")
+
+        dbapi = MockDBAPI()
+        dbapi.Error = sqlite3.Error,
+        dbapi.ProgrammingError = sqlite3.ProgrammingError
+        dbapi.connect = Mock(side_effect=sqlite3.ProgrammingError(
+                                    "Cannot operate on a closed database."))
         try:
-            create_engine('sqlite://', module=ThrowOnConnect()).connect()
+            create_engine('sqlite://', module=dbapi).connect()
             assert False
         except tsa.exc.DBAPIError, de:
             assert de.connection_invalidated
@@ -355,36 +354,23 @@ class MockDialect(DefaultDialect):
     def dbapi(cls, **kw):
         return MockDBAPI()
 
-class MockDBAPI(object):
-    version_info = sqlite_version_info = 99, 9, 9
-    sqlite_version = '99.9.9'
-
-    def __init__(self, **kwargs):
-        self.kwargs = kwargs
-        self.paramstyle = 'named'
-
-    def connect(self, *args, **kwargs):
-        for k in self.kwargs:
+def MockDBAPI(**assert_kwargs):
+    connection = Mock(get_server_version_info=Mock(return_value='5.0'))
+    def connect(*args, **kwargs):
+        for k in assert_kwargs:
             assert k in kwargs, 'key %s not present in dictionary' % k
-            assert kwargs[k] == self.kwargs[k], \
-                'value %s does not match %s' % (kwargs[k],
-                    self.kwargs[k])
-        return MockConnection()
-
-
-class MockConnection(object):
-    def get_server_info(self):
-        return '5.0'
-
-    def close(self):
-        pass
-
-    def cursor(self):
-        return MockCursor()
-
-class MockCursor(object):
-    def close(self):
-        pass
+            eq_(
+                kwargs[k], assert_kwargs[k]
+            )
+        return connection
+
+    return Mock(
+                sqlite_version_info=(99, 9, 9,),
+                version_info=(99, 9, 9,),
+                sqlite_version='99.9.9',
+                paramstyle='named',
+                connect=Mock(side_effect=connect)
+            )
 
 mock_dbapi = MockDBAPI()
 mock_sqlite_dbapi = msd = MockDBAPI()
index ae02417f95597d12dbc9d5b499ddaf571edd68f1..a3c7edf5f32a3cff5378579958280693bdf88900 100644 (file)
@@ -4,37 +4,21 @@ from sqlalchemy import pool, select, event
 import sqlalchemy as tsa
 from sqlalchemy import testing
 from sqlalchemy.testing.util import gc_collect, lazy_gc
-from sqlalchemy.testing import eq_, assert_raises
+from sqlalchemy.testing import eq_, assert_raises, is_not_
 from sqlalchemy.testing.engines import testing_engine
 from sqlalchemy.testing import fixtures
 
-mcid = 1
-class MockDBAPI(object):
-    throw_error = False
-    def connect(self, *args, **kwargs):
-        if self.throw_error:
-            raise Exception("couldnt connect !")
-        delay = kwargs.pop('delay', 0)
-        if delay:
-            time.sleep(delay)
-        return MockConnection()
-class MockConnection(object):
-    closed = False
-    def __init__(self):
-        global mcid
-        self.id = mcid
-        mcid += 1
-    def close(self):
-        self.closed = True
-    def rollback(self):
-        pass
-    def cursor(self):
-        return MockCursor()
-class MockCursor(object):
-    def execute(self, *args, **kw):
-        pass
-    def close(self):
-        pass
+from sqlalchemy.testing.mock import Mock, call
+
+def MockDBAPI():
+    def cursor():
+        while True:
+            yield Mock()
+    def connect():
+        while True:
+            yield Mock(cursor=Mock(side_effect=cursor()))
+
+    return Mock(connect=Mock(side_effect=connect()))
 
 class PoolTestBase(fixtures.TestBase):
     def setup(self):
@@ -71,11 +55,9 @@ class PoolTest(PoolTestBase):
         assert c4 is not c5
 
     def test_manager_with_key(self):
-        class NoKws(object):
-            def connect(self, arg):
-                return MockConnection()
 
-        manager = pool.manage(NoKws(), use_threadlocal=True)
+        dbapi = MockDBAPI()
+        manager = pool.manage(dbapi, use_threadlocal=True)
 
         c1 = manager.connect('foo.db', sa_pool_key="a")
         c2 = manager.connect('foo.db', sa_pool_key="b")
@@ -83,9 +65,14 @@ class PoolTest(PoolTestBase):
 
         assert c1.cursor() is not None
         assert c1 is not c2
-        assert c1 is  c3
-
+        assert c1 is c3
 
+        eq_(dbapi.connect.mock_calls,
+            [
+                call("foo.db"),
+                call("foo.db"),
+            ]
+        )
 
 
     def test_bad_args(self):
@@ -127,7 +114,7 @@ class PoolTest(PoolTestBase):
             p = cls(creator=mock_dbapi.connect)
             conn = p.connect()
             conn.close()
-            mock_dbapi.throw_error = True
+            mock_dbapi.connect.side_effect = Exception("error!")
             p.dispose()
             p.recreate()
 
@@ -211,9 +198,9 @@ class PoolTest(PoolTestBase):
         self.assert_('foo2' in c.info)
 
         c2 = p.connect()
-        self.assert_(c.connection is not c2.connection)
-        self.assert_(not c2.info)
-        self.assert_('foo2' in c.info)
+        is_not_(c.connection, c2.connection)
+        assert not c2.info
+        assert 'foo2' in c.info
 
 
 class PoolDialectTest(PoolTestBase):
@@ -945,19 +932,24 @@ class QueuePoolTest(PoolTestBase):
 
     def test_dispose_closes_pooled(self):
         dbapi = MockDBAPI()
-        def creator():
-            return dbapi.connect()
 
-        p = pool.QueuePool(creator=creator,
+        p = pool.QueuePool(creator=dbapi.connect,
                            pool_size=2, timeout=None,
                            max_overflow=0)
         c1 = p.connect()
         c2 = p.connect()
-        conns = [c1.connection, c2.connection]
+        c1_con = c1.connection
+        c2_con = c2.connection
+
         c1.close()
-        eq_([c.closed for c in conns], [False, False])
+
+        eq_(c1_con.close.call_count, 0)
+        eq_(c2_con.close.call_count, 0)
+
         p.dispose()
-        eq_([c.closed for c in conns], [True, False])
+
+        eq_(c1_con.close.call_count, 1)
+        eq_(c2_con.close.call_count, 0)
 
         # currently, if a ConnectionFairy is closed
         # after the pool has been disposed, there's no
@@ -965,11 +957,12 @@ class QueuePoolTest(PoolTestBase):
         # immediately - it just gets returned to the
         # pool normally...
         c2.close()
-        eq_([c.closed for c in conns], [True, False])
+        eq_(c1_con.close.call_count, 1)
+        eq_(c2_con.close.call_count, 0)
 
         # ...and that's the one we'll get back next.
         c3 = p.connect()
-        assert c3.connection is conns[1]
+        assert c3.connection is c2_con
 
     def test_no_overflow(self):
         self._test_overflow(40, 0)
@@ -1009,13 +1002,22 @@ class QueuePoolTest(PoolTestBase):
             strong_refs.add(c.connection)
             return c
 
-        for j in xrange(5):
-            conns = [_conn() for i in xrange(4)]
+        for j in range(5):
+            # open 4 conns at a time.  each time this
+            # will yield two pooled connections + two
+            # overflow connections.
+            conns = [_conn() for i in range(4)]
             for c in conns:
                 c.close()
 
-        still_opened = len([c for c in strong_refs if not c.closed])
-        eq_(still_opened, 2)
+        # doing that for a total of 5 times yields
+        # ten overflow connections closed plus the
+        # two pooled connections unclosed.
+
+        eq_(
+            set([c.close.call_count for c in strong_refs]),
+            set([1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0])
+        )
 
     @testing.requires.predictable_gc
     def test_weakref_kaboom(self):
@@ -1108,18 +1110,30 @@ class QueuePoolTest(PoolTestBase):
         dbapi, p = self._queuepool_dbapi_fixture(pool_size=1, max_overflow=0)
         c1 = p.connect()
         c1.detach()
-        c_id = c1.connection.id
-        c2 = p.connect()
-        assert c2.connection.id != c1.connection.id
-        dbapi.raise_error = True
-        c2.invalidate()
-        c2 = None
         c2 = p.connect()
-        assert c2.connection.id != c1.connection.id
-        con = c1.connection
-        assert not con.closed
+        eq_(dbapi.connect.mock_calls, [call("foo.db"), call("foo.db")])
+
+        c1_con = c1.connection
+        assert c1_con is not None
+        eq_(c1_con.close.call_count, 0)
         c1.close()
-        assert con.closed
+        eq_(c1_con.close.call_count, 1)
+
+    def test_detach_via_invalidate(self):
+        dbapi, p = self._queuepool_dbapi_fixture(pool_size=1, max_overflow=0)
+
+        c1 = p.connect()
+        c1_con = c1.connection
+        c1.invalidate()
+        assert c1.connection is None
+        eq_(c1_con.close.call_count, 1)
+
+        c2 = p.connect()
+        assert c2.connection is not c1_con
+        c2_con = c2.connection
+
+        c2.close()
+        eq_(c2_con.close.call_count, 0)
 
     def test_threadfairy(self):
         p = self._queuepool_fixture(pool_size=3, max_overflow=-1, use_threadlocal=True)
@@ -1141,8 +1155,13 @@ class SingletonThreadPoolTest(PoolTestBase):
         been called."""
 
         dbapi = MockDBAPI()
-        p = pool.SingletonThreadPool(creator=dbapi.connect,
-                pool_size=3)
+
+        lock = threading.Lock()
+        def creator():
+            # the mock iterator isn't threadsafe...
+            with lock:
+                return dbapi.connect()
+        p = pool.SingletonThreadPool(creator=creator, pool_size=3)
 
         if strong_refs:
             sr = set()
@@ -1172,7 +1191,7 @@ class SingletonThreadPoolTest(PoolTestBase):
         assert len(p._all_conns) == 3
 
         if strong_refs:
-            still_opened = len([c for c in sr if not c.closed])
+            still_opened = len([c for c in sr if not c.close.call_count])
             eq_(still_opened, 3)
 
 class AssertionPoolTest(PoolTestBase):
@@ -1198,17 +1217,19 @@ class NullPoolTest(PoolTestBase):
         dbapi = MockDBAPI()
         p = pool.NullPool(creator=lambda: dbapi.connect('foo.db'))
         c1 = p.connect()
-        c_id = c1.connection.id
+
         c1.close()
         c1 = None
 
         c1 = p.connect()
-        dbapi.raise_error = True
         c1.invalidate()
         c1 = None
 
         c1 = p.connect()
-        assert c1.connection.id != c_id
+        dbapi.connect.assert_has_calls([
+                            call('foo.db'),
+                            call('foo.db')],
+                            any_order=True)
 
 
 class StaticPoolTest(PoolTestBase):
index 567647f20851b2683da7a42e47ae4da0d9ab9074..ce8d03a2039d3634b002606d87ca2b79eede662e 100644 (file)
@@ -1,7 +1,6 @@
 from sqlalchemy.testing import eq_, assert_raises, assert_raises_message
 import time
-import weakref
-from sqlalchemy import select, MetaData, Integer, String, pool, create_engine
+from sqlalchemy import select, MetaData, Integer, String, create_engine, pool
 from sqlalchemy.testing.schema import Table, Column
 import sqlalchemy as tsa
 from sqlalchemy import testing
@@ -10,6 +9,8 @@ from sqlalchemy.testing.util import gc_collect
 from sqlalchemy import exc
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing.engines import testing_engine
+from sqlalchemy.testing import is_not_
+from sqlalchemy.testing.mock import Mock, call
 
 class MockError(Exception):
     pass
@@ -17,93 +18,103 @@ class MockError(Exception):
 class MockDisconnect(MockError):
     pass
 
-class MockDBAPI(object):
-    def __init__(self):
-        self.paramstyle = 'named'
-        self.connections = weakref.WeakKeyDictionary()
-    def connect(self, *args, **kwargs):
-        return MockConnection(self)
-    def shutdown(self, explode='execute'):
-        for c in self.connections:
-            c.explode = explode
-    Error = MockError
-
-class MockConnection(object):
-    def __init__(self, dbapi):
-        dbapi.connections[self] = True
-        self.explode = ""
-    def rollback(self):
-        if self.explode == 'rollback':
+def mock_connection():
+    def mock_cursor():
+        def execute(*args, **kwargs):
+            if conn.explode == 'execute':
+                raise MockDisconnect("Lost the DB connection on execute")
+            elif conn.explode in ('execute_no_disconnect', ):
+                raise MockError(
+                    "something broke on execute but we didn't lose the connection")
+            elif conn.explode in ('rollback', 'rollback_no_disconnect'):
+                raise MockError(
+                    "something broke on execute but we didn't lose the connection")
+            elif args and "SELECT" in args[0]:
+                cursor.description = [('foo', None, None, None, None, None)]
+            else:
+                return
+
+        def close():
+            cursor.fetchall = cursor.fetchone = \
+                Mock(side_effect=MockError("cursor closed"))
+        cursor = Mock(
+                    execute=Mock(side_effect=execute),
+                    close=Mock(side_effect=close)
+                )
+        return cursor
+
+    def cursor():
+        while True:
+            yield mock_cursor()
+
+    def rollback():
+        if conn.explode == 'rollback':
             raise MockDisconnect("Lost the DB connection on rollback")
-        if self.explode == 'rollback_no_disconnect':
+        if conn.explode == 'rollback_no_disconnect':
             raise MockError(
                 "something broke on rollback but we didn't lose the connection")
         else:
             return
-    def commit(self):
-        pass
-    def cursor(self):
-        return MockCursor(self)
-    def close(self):
-        pass
-
-class MockCursor(object):
-    def __init__(self, parent):
-        self.explode = parent.explode
-        self.description = ()
-        self.closed = False
-    def execute(self, *args, **kwargs):
-        if self.explode == 'execute':
-            raise MockDisconnect("Lost the DB connection on execute")
-        elif self.explode in ('execute_no_disconnect', ):
-            raise MockError(
-                "something broke on execute but we didn't lose the connection")
-        elif self.explode in ('rollback', 'rollback_no_disconnect'):
-            raise MockError(
-                "something broke on execute but we didn't lose the connection")
-        elif args and "select" in args[0]:
-            self.description = [('foo', None, None, None, None, None)]
-        else:
-            return
-    def fetchall(self):
-        if self.closed:
-            raise MockError("cursor closed")
-        return []
-    def fetchone(self):
-        if self.closed:
-            raise MockError("cursor closed")
-        return None
-    def close(self):
-        self.closed = True
-
-db, dbapi = None, None
+
+    conn = Mock(
+                rollback=Mock(side_effect=rollback),
+                cursor=Mock(side_effect=cursor())
+            )
+    return conn
+
+def MockDBAPI():
+    connections = []
+    def connect():
+        while True:
+            conn = mock_connection()
+            connections.append(conn)
+            yield conn
+
+    def shutdown(explode='execute'):
+        for c in connections:
+            c.explode = explode
+
+    def dispose():
+        for c in connections:
+            c.explode = None
+        connections[:] = []
+
+    return Mock(
+                connect=Mock(side_effect=connect()),
+                shutdown=Mock(side_effect=shutdown),
+                dispose=Mock(side_effect=dispose),
+                paramstyle='named',
+                connections=connections,
+                Error=MockError
+            )
+
+
 class MockReconnectTest(fixtures.TestBase):
     def setup(self):
-        global db, dbapi
-        dbapi = MockDBAPI()
+        self.dbapi = MockDBAPI()
 
-        # note - using straight create_engine here
-        # since we are testing gc
-        db = create_engine(
+        self.db = testing_engine(
                     'postgresql://foo:bar@localhost/test',
-                    module=dbapi, _initialize=False)
+                    options=dict(module=self.dbapi, _initialize=False))
 
+        self.mock_connect = call(host='localhost', password='bar',
+                                user='foo', database='test')
         # monkeypatch disconnect checker
-        db.dialect.is_disconnect = lambda e, conn, cursor: isinstance(e, MockDisconnect)
+        self.db.dialect.is_disconnect = lambda e, conn, cursor: isinstance(e, MockDisconnect)
 
     def teardown(self):
-        db.dispose()
+        self.dbapi.dispose()
 
     def test_reconnect(self):
         """test that an 'is_disconnect' condition will invalidate the
         connection, and additionally dispose the previous connection
         pool and recreate."""
 
-        pid = id(db.pool)
+        db_pool = self.db.pool
 
         # make a connection
 
-        conn = db.connect()
+        conn = self.db.connect()
 
         # connection works
 
@@ -112,21 +123,20 @@ class MockReconnectTest(fixtures.TestBase):
         # create a second connection within the pool, which we'll ensure
         # also goes away
 
-        conn2 = db.connect()
+        conn2 = self.db.connect()
         conn2.close()
 
         # two connections opened total now
 
-        assert len(dbapi.connections) == 2
+        assert len(self.dbapi.connections) == 2
 
         # set it to fail
 
-        dbapi.shutdown()
-        try:
-            conn.execute(select([1]))
-            assert False
-        except tsa.exc.DBAPIError:
-            pass
+        self.dbapi.shutdown()
+        assert_raises(
+            tsa.exc.DBAPIError,
+            conn.execute, select([1])
+        )
 
         # assert was invalidated
 
@@ -136,31 +146,38 @@ class MockReconnectTest(fixtures.TestBase):
         # close shouldnt break
 
         conn.close()
-        assert id(db.pool) != pid
+        is_not_(self.db.pool, db_pool)
 
         # ensure all connections closed (pool was recycled)
 
-        gc_collect()
-        assert len(dbapi.connections) == 0
-        conn = db.connect()
+        eq_(
+            [c.close.mock_calls for c in self.dbapi.connections],
+            [[call()], [call()]]
+        )
+
+        conn = self.db.connect()
         conn.execute(select([1]))
         conn.close()
-        assert len(dbapi.connections) == 1
+
+        eq_(
+            [c.close.mock_calls for c in self.dbapi.connections],
+            [[call()], [call()], []]
+        )
 
     def test_invalidate_trans(self):
-        conn = db.connect()
+        conn = self.db.connect()
         trans = conn.begin()
-        dbapi.shutdown()
-        try:
-            conn.execute(select([1]))
-            assert False
-        except tsa.exc.DBAPIError:
-            pass
+        self.dbapi.shutdown()
 
-        # assert was invalidated
+        assert_raises(
+            tsa.exc.DBAPIError,
+            conn.execute, select([1])
+        )
 
-        gc_collect()
-        assert len(dbapi.connections) == 0
+        eq_(
+            [c.close.mock_calls for c in self.dbapi.connections],
+            [[call()]]
+        )
         assert not conn.closed
         assert conn.invalidated
         assert trans.is_active
@@ -170,28 +187,35 @@ class MockReconnectTest(fixtures.TestBase):
             conn.execute, select([1])
         )
         assert trans.is_active
-        try:
-            trans.commit()
-            assert False
-        except tsa.exc.InvalidRequestError, e:
-            assert str(e) \
-                == "Can't reconnect until invalid transaction is "\
-                "rolled back"
+
+        assert_raises_message(
+            tsa.exc.InvalidRequestError,
+            "Can't reconnect until invalid transaction is "
+                "rolled back",
+            trans.commit
+        )
+
         assert trans.is_active
         trans.rollback()
         assert not trans.is_active
         conn.execute(select([1]))
         assert not conn.invalidated
-        assert len(dbapi.connections) == 1
+        eq_(
+            [c.close.mock_calls for c in self.dbapi.connections],
+            [[call()], []]
+        )
 
     def test_conn_reusable(self):
-        conn = db.connect()
+        conn = self.db.connect()
 
         conn.execute(select([1]))
 
-        assert len(dbapi.connections) == 1
+        eq_(
+            self.dbapi.connect.mock_calls,
+            [self.mock_connect]
+        )
 
-        dbapi.shutdown()
+        self.dbapi.shutdown()
 
         assert_raises(
             tsa.exc.DBAPIError,
@@ -201,19 +225,24 @@ class MockReconnectTest(fixtures.TestBase):
         assert not conn.closed
         assert conn.invalidated
 
-        # ensure all connections closed (pool was recycled)
-        gc_collect()
-        assert len(dbapi.connections) == 0
+        eq_(
+            [c.close.mock_calls for c in self.dbapi.connections],
+            [[call()]]
+        )
 
         # test reconnects
         conn.execute(select([1]))
         assert not conn.invalidated
-        assert len(dbapi.connections) == 1
+
+        eq_(
+            [c.close.mock_calls for c in self.dbapi.connections],
+            [[call()], []]
+        )
 
     def test_invalidated_close(self):
-        conn = db.connect()
+        conn = self.db.connect()
 
-        dbapi.shutdown()
+        self.dbapi.shutdown()
 
         assert_raises(
             tsa.exc.DBAPIError,
@@ -230,9 +259,9 @@ class MockReconnectTest(fixtures.TestBase):
         )
 
     def test_noreconnect_execute_plus_closewresult(self):
-        conn = db.connect(close_with_result=True)
+        conn = self.db.connect(close_with_result=True)
 
-        dbapi.shutdown("execute_no_disconnect")
+        self.dbapi.shutdown("execute_no_disconnect")
 
         # raises error
         assert_raises_message(
@@ -245,9 +274,9 @@ class MockReconnectTest(fixtures.TestBase):
         assert not conn.invalidated
 
     def test_noreconnect_rollback_plus_closewresult(self):
-        conn = db.connect(close_with_result=True)
+        conn = self.db.connect(close_with_result=True)
 
-        dbapi.shutdown("rollback_no_disconnect")
+        self.dbapi.shutdown("rollback_no_disconnect")
 
         # raises error
         assert_raises_message(
@@ -266,13 +295,13 @@ class MockReconnectTest(fixtures.TestBase):
         )
 
     def test_reconnect_on_reentrant(self):
-        conn = db.connect()
+        conn = self.db.connect()
 
         conn.execute(select([1]))
 
-        assert len(dbapi.connections) == 1
+        assert len(self.dbapi.connections) == 1
 
-        dbapi.shutdown("rollback")
+        self.dbapi.shutdown("rollback")
 
         # raises error
         assert_raises_message(
@@ -285,9 +314,9 @@ class MockReconnectTest(fixtures.TestBase):
         assert conn.invalidated
 
     def test_reconnect_on_reentrant_plus_closewresult(self):
-        conn = db.connect(close_with_result=True)
+        conn = self.db.connect(close_with_result=True)
 
-        dbapi.shutdown("rollback")
+        self.dbapi.shutdown("rollback")
 
         # raises error
         assert_raises_message(
@@ -306,10 +335,11 @@ class MockReconnectTest(fixtures.TestBase):
         )
 
     def test_check_disconnect_no_cursor(self):
-        conn = db.connect()
-        result = conn.execute("select 1")
+        conn = self.db.connect()
+        result = conn.execute(select([1]))
         result.cursor.close()
         conn.close()
+
         assert_raises_message(
             tsa.exc.DBAPIError,
             "cursor closed",
@@ -319,60 +349,59 @@ class MockReconnectTest(fixtures.TestBase):
 class CursorErrTest(fixtures.TestBase):
 
     def setup(self):
-        global db, dbapi
-
-        class MDBAPI(MockDBAPI):
-            def connect(self, *args, **kwargs):
-                return MConn(self)
-
-        class MConn(MockConnection):
-            def cursor(self):
-                return MCursor(self)
+        def MockDBAPI():
+            def cursor():
+                while True:
+                    yield Mock(
+                        description=[],
+                        close=Mock(side_effect=Exception("explode")))
+            def connect():
+                while True:
+                    yield Mock(cursor=Mock(side_effect=cursor()))
+
+            return Mock(connect=Mock(side_effect=connect()))
 
-        class MCursor(MockCursor):
-            def close(self):
-                raise Exception("explode")
-
-        dbapi = MDBAPI()
-
-        db = testing_engine(
+        dbapi = MockDBAPI()
+        self.db = testing_engine(
                     'postgresql://foo:bar@localhost/test',
                     options=dict(module=dbapi, _initialize=False))
 
     def test_cursor_explode(self):
-        conn = db.connect()
+        conn = self.db.connect()
         result = conn.execute("select foo")
         result.close()
         conn.close()
 
     def teardown(self):
-        db.dispose()
+        self.db.dispose()
+
+
+def _assert_invalidated(fn, *args):
+    try:
+        fn(*args)
+        assert False
+    except tsa.exc.DBAPIError as e:
+        if not e.connection_invalidated:
+            raise
 
-engine = None
 class RealReconnectTest(fixtures.TestBase):
     def setup(self):
-        global engine
-        engine = engines.reconnecting_engine()
+        self.engine = engines.reconnecting_engine()
 
     def teardown(self):
-        engine.dispose()
+        self.engine.dispose()
 
     @testing.fails_on('+informixdb',
                       "Wrong error thrown, fix in informixdb?")
     def test_reconnect(self):
-        conn = engine.connect()
+        conn = self.engine.connect()
 
         eq_(conn.execute(select([1])).scalar(), 1)
         assert not conn.closed
 
-        engine.test_shutdown()
+        self.engine.test_shutdown()
 
-        try:
-            conn.execute(select([1]))
-            assert False
-        except tsa.exc.DBAPIError, e:
-            if not e.connection_invalidated:
-                raise
+        _assert_invalidated(conn.execute, select([1]))
 
         assert not conn.closed
         assert conn.invalidated
@@ -382,13 +411,9 @@ class RealReconnectTest(fixtures.TestBase):
         assert not conn.invalidated
 
         # one more time
-        engine.test_shutdown()
-        try:
-            conn.execute(select([1]))
-            assert False
-        except tsa.exc.DBAPIError, e:
-            if not e.connection_invalidated:
-                raise
+        self.engine.test_shutdown()
+        _assert_invalidated(conn.execute, select([1]))
+
         assert conn.invalidated
         eq_(conn.execute(select([1])).scalar(), 1)
         assert not conn.invalidated
@@ -396,30 +421,22 @@ class RealReconnectTest(fixtures.TestBase):
         conn.close()
 
     def test_multiple_invalidate(self):
-        c1 = engine.connect()
-        c2 = engine.connect()
+        c1 = self.engine.connect()
+        c2 = self.engine.connect()
 
         eq_(c1.execute(select([1])).scalar(), 1)
 
-        p1 = engine.pool
-        engine.test_shutdown()
+        p1 = self.engine.pool
+        self.engine.test_shutdown()
 
-        try:
-            c1.execute(select([1]))
-            assert False
-        except tsa.exc.DBAPIError, e:
-            assert e.connection_invalidated
+        _assert_invalidated(c1.execute, select([1]))
 
-        p2 = engine.pool
+        p2 = self.engine.pool
 
-        try:
-            c2.execute(select([1]))
-            assert False
-        except tsa.exc.DBAPIError, e:
-            assert e.connection_invalidated
+        _assert_invalidated(c2.execute, select([1]))
 
         # pool isn't replaced
-        assert engine.pool is p2
+        assert self.engine.pool is p2
 
 
     def test_ensure_is_disconnect_gets_connection(self):
@@ -430,37 +447,37 @@ class RealReconnectTest(fixtures.TestBase):
             # though MySQLdb we get a non-working cursor.
             # assert cursor is None
 
-        engine.dialect.is_disconnect = is_disconnect
-        conn = engine.connect()
-        engine.test_shutdown()
+        self.engine.dialect.is_disconnect = is_disconnect
+        conn = self.engine.connect()
+        self.engine.test_shutdown()
         assert_raises(
             tsa.exc.DBAPIError,
             conn.execute, select([1])
         )
 
     def test_rollback_on_invalid_plain(self):
-        conn = engine.connect()
+        conn = self.engine.connect()
         trans = conn.begin()
         conn.invalidate()
         trans.rollback()
 
     @testing.requires.two_phase_transactions
     def test_rollback_on_invalid_twophase(self):
-        conn = engine.connect()
+        conn = self.engine.connect()
         trans = conn.begin_twophase()
         conn.invalidate()
         trans.rollback()
 
     @testing.requires.savepoints
     def test_rollback_on_invalid_savepoint(self):
-        conn = engine.connect()
+        conn = self.engine.connect()
         trans = conn.begin()
         trans2 = conn.begin_nested()
         conn.invalidate()
         trans2.rollback()
 
     def test_invalidate_twice(self):
-        conn = engine.connect()
+        conn = self.engine.connect()
         conn.invalidate()
         conn.invalidate()
 
@@ -500,12 +517,7 @@ class RealReconnectTest(fixtures.TestBase):
         eq_(conn.execute(select([1])).scalar(), 1)
         assert not conn.closed
         engine.test_shutdown()
-        try:
-            conn.execute(select([1]))
-            assert False
-        except tsa.exc.DBAPIError, e:
-            if not e.connection_invalidated:
-                raise
+        _assert_invalidated(conn.execute, select([1]))
         assert not conn.closed
         assert conn.invalidated
         eq_(conn.execute(select([1])).scalar(), 1)
@@ -514,37 +526,27 @@ class RealReconnectTest(fixtures.TestBase):
     @testing.fails_on('+informixdb',
                       "Wrong error thrown, fix in informixdb?")
     def test_close(self):
-        conn = engine.connect()
+        conn = self.engine.connect()
         eq_(conn.execute(select([1])).scalar(), 1)
         assert not conn.closed
 
-        engine.test_shutdown()
+        self.engine.test_shutdown()
 
-        try:
-            conn.execute(select([1]))
-            assert False
-        except tsa.exc.DBAPIError, e:
-            if not e.connection_invalidated:
-                raise
+        _assert_invalidated(conn.execute, select([1]))
 
         conn.close()
-        conn = engine.connect()
+        conn = self.engine.connect()
         eq_(conn.execute(select([1])).scalar(), 1)
 
     @testing.fails_on('+informixdb',
                       "Wrong error thrown, fix in informixdb?")
     def test_with_transaction(self):
-        conn = engine.connect()
+        conn = self.engine.connect()
         trans = conn.begin()
         eq_(conn.execute(select([1])).scalar(), 1)
         assert not conn.closed
-        engine.test_shutdown()
-        try:
-            conn.execute(select([1]))
-            assert False
-        except tsa.exc.DBAPIError, e:
-            if not e.connection_invalidated:
-                raise
+        self.engine.test_shutdown()
+        _assert_invalidated(conn.execute, select([1]))
         assert not conn.closed
         assert conn.invalidated
         assert trans.is_active
@@ -555,13 +557,11 @@ class RealReconnectTest(fixtures.TestBase):
             conn.execute, select([1])
         )
         assert trans.is_active
-        try:
-            trans.commit()
-            assert False
-        except tsa.exc.InvalidRequestError, e:
-            assert str(e) \
-                == "Can't reconnect until invalid transaction is "\
-                "rolled back"
+        assert_raises_message(
+            tsa.exc.InvalidRequestError,
+            "Can't reconnect until invalid transaction is rolled back",
+            trans.commit
+        )
         assert trans.is_active
         trans.rollback()
         assert not trans.is_active
@@ -599,23 +599,21 @@ class RecycleTest(fixtures.TestBase):
             eq_(conn.execute(select([1])).scalar(), 1)
             conn.close()
 
-meta, table, engine = None, None, None
 class InvalidateDuringResultTest(fixtures.TestBase):
     def setup(self):
-        global meta, table, engine
-        engine = engines.reconnecting_engine()
-        meta = MetaData(engine)
-        table = Table('sometable', meta,
+        self.engine = engines.reconnecting_engine()
+        self.meta = MetaData(self.engine)
+        table = Table('sometable', self.meta,
             Column('id', Integer, primary_key=True),
             Column('name', String(50)))
-        meta.create_all()
+        self.meta.create_all()
         table.insert().execute(
-            [{'id':i, 'name':'row %d' % i} for i in range(1, 100)]
+            [{'id': i, 'name': 'row %d' % i} for i in range(1, 100)]
         )
 
     def teardown(self):
-        meta.drop_all()
-        engine.dispose()
+        self.meta.drop_all()
+        self.engine.dispose()
 
     @testing.fails_if([
                     '+mysqlconnector', '+mysqldb',
@@ -625,16 +623,11 @@ class InvalidateDuringResultTest(fixtures.TestBase):
     @testing.fails_on('+informixdb',
                       "Wrong error thrown, fix in informixdb?")
     def test_invalidate_on_results(self):
-        conn = engine.connect()
+        conn = self.engine.connect()
         result = conn.execute('select * from sometable')
         for x in xrange(20):
             result.fetchone()
-        engine.test_shutdown()
-        try:
-            print 'ghost result: %r' % result.fetchone()
-            assert False
-        except tsa.exc.DBAPIError, e:
-            if not e.connection_invalidated:
-                raise
+        self.engine.test_shutdown()
+        _assert_invalidated(result.fetchone)
         assert conn.invalidated