git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- pool event tests that don't depend on deprecated listener system,
  attempting "test just one thing" style
- reorganize fixtures to come primarily from the base test class

author    Mike Bayer <mike_mp@zzzcomputing.com>
          Sun, 14 Nov 2010 18:59:03 +0000 (13:59 -0500)
committer Mike Bayer <mike_mp@zzzcomputing.com>
          Sun, 14 Nov 2010 18:59:03 +0000 (13:59 -0500)
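
For readers skimming the diff, a minimal standalone sketch of the "test just one thing" style the new PoolEventsTest cases follow: one pool, one listener appending to a canary list, one assertion. It assumes the listener-first event.listen(fn, 'on_<event>', target) call order used at this revision; a plain assert stands in for the suite's eq_ helper, and the MockConnection stub is only what QueuePool needs for its checkout/checkin bookkeeping.

from sqlalchemy import event, pool

class MockConnection(object):
    # just enough of a DBAPI connection for the pool's checkin to succeed
    def rollback(self):
        pass
    def close(self):
        pass

def test_checkout_event_sketch():
    p = pool.QueuePool(creator=lambda: MockConnection())
    canary = []
    def on_checkout(*arg, **kw):
        canary.append('checkout')
    # listener-first signature, as used throughout this commit's tests
    event.listen(on_checkout, 'on_checkout', p)

    c1 = p.connect()
    assert canary == ['checkout']
    c1.close()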

test/engine/test_pool.py

index 26fef371c016b4c200dfd32afe8b96dd44ff39ea..7f451698a85e03d5c8269626eebf6840b910e6ca 100644
--- a/test/engine/test_pool.py
+++ b/test/engine/test_pool.py
@@ -7,8 +7,7 @@ from sqlalchemy.test.testing import eq_
 
 mcid = 1
 class MockDBAPI(object):
-    def __init__(self):
-        self.throw_error = False
+    throw_error = False
     def connect(self, *args, **kwargs):
         if self.throw_error:
             raise Exception("couldnt connect !")
@@ -17,10 +16,10 @@ class MockDBAPI(object):
             time.sleep(delay)
         return MockConnection()
 class MockConnection(object):
+    closed = False
     def __init__(self):
         global mcid
         self.id = mcid
-        self.closed = False
         mcid += 1
     def close(self):
         self.closed = True
@@ -33,8 +32,6 @@ class MockCursor(object):
         pass
     def close(self):
         pass
-mock_dbapi = MockDBAPI()
-
 
 class PoolTestBase(TestBase):    
     def setup(self):
@@ -44,9 +41,17 @@ class PoolTestBase(TestBase):
     def teardown_class(cls):
        pool.clear_managers()
 
+    def _queuepool_fixture(self, **kw):
+        dbapi = MockDBAPI()
+        return pool.QueuePool(creator=lambda: dbapi.connect('foo.db'), **kw)
+
+    def _queuepool_dbapi_fixture(self, **kw):
+        dbapi = MockDBAPI()
+        return dbapi, pool.QueuePool(creator=lambda: dbapi.connect('foo.db'), **kw)
+     
 class PoolTest(PoolTestBase):
     def testmanager(self):
-        manager = pool.manage(mock_dbapi, use_threadlocal=True)
+        manager = pool.manage(MockDBAPI(), use_threadlocal=True)
 
         connection = manager.connect('foo.db')
         connection2 = manager.connect('foo.db')
@@ -57,7 +62,7 @@ class PoolTest(PoolTestBase):
         self.assert_(connection2 is not connection3)
 
     def testbadargs(self):
-        manager = pool.manage(mock_dbapi)
+        manager = pool.manage(MockDBAPI())
 
         try:
             connection = manager.connect(None)
@@ -65,7 +70,7 @@ class PoolTest(PoolTestBase):
             pass
 
     def testnonthreadlocalmanager(self):
-        manager = pool.manage(mock_dbapi, use_threadlocal = False)
+        manager = pool.manage(MockDBAPI(), use_threadlocal = False)
 
         connection = manager.connect('foo.db')
         connection2 = manager.connect('foo.db')
@@ -101,7 +106,6 @@ class PoolTest(PoolTestBase):
             p.dispose()
             p.recreate()
             
-            
     def testthreadlocal_del(self):
         self._do_testthreadlocal(useclose=False)
 
@@ -109,10 +113,11 @@ class PoolTest(PoolTestBase):
         self._do_testthreadlocal(useclose=True)
 
     def _do_testthreadlocal(self, useclose=False):
-        for p in pool.QueuePool(creator=mock_dbapi.connect,
+        dbapi = MockDBAPI()
+        for p in pool.QueuePool(creator=dbapi.connect,
                                 pool_size=3, max_overflow=-1,
                                 use_threadlocal=True), \
-            pool.SingletonThreadPool(creator=mock_dbapi.connect,
+            pool.SingletonThreadPool(creator=dbapi.connect,
                 use_threadlocal=True):
             c1 = p.connect()
             c2 = p.connect()
@@ -159,9 +164,7 @@ class PoolTest(PoolTestBase):
                 self.assert_(p.checkedout() == 0)
 
     def test_properties(self):
-        dbapi = MockDBAPI()
-        p = pool.QueuePool(creator=lambda: dbapi.connect('foo.db'),
-                           pool_size=1, max_overflow=0, use_threadlocal=False)
+        p = self._queuepool_fixture(pool_size=1, max_overflow=0)
 
         c = p.connect()
         self.assert_(not c.info)
@@ -187,12 +190,195 @@ class PoolTest(PoolTestBase):
         self.assert_(not c2.info)
         self.assert_('foo2' in c.info)
 
+
+
 class PoolEventsTest(PoolTestBase):
+    def _first_connect_event_fixture(self):
+        p = self._queuepool_fixture()
+        canary = []
+        def on_first_connect(*arg, **kw):
+            canary.append('first_connect')
+        
+        event.listen(on_first_connect, 'on_first_connect', p)
+        
+        return p, canary
+
+    def _connect_event_fixture(self):
+        p = self._queuepool_fixture()
+        canary = []
+        def on_connect(*arg, **kw):
+            canary.append('connect')
+        event.listen(on_connect, 'on_connect', p)
+        
+        return p, canary
+
+    def _checkout_event_fixture(self):
+        p = self._queuepool_fixture()
+        canary = []
+        def on_checkout(*arg, **kw):
+            canary.append('checkout')
+        event.listen(on_checkout, 'on_checkout', p)
+        
+        return p, canary
+
+    def _checkin_event_fixture(self):
+        p = self._queuepool_fixture()
+        canary = []
+        def on_checkin(*arg, **kw):
+            canary.append('checkin')
+        event.listen(on_checkin, 'on_checkin', p)
+        
+        return p, canary
+        
+    def test_first_connect_event(self):
+        p, canary = self._first_connect_event_fixture()
+        
+        c1 = p.connect()
+        eq_(canary, ['first_connect'])
+
+    def test_first_connect_event_fires_once(self):
+        p, canary = self._first_connect_event_fixture()
+        
+        c1 = p.connect()
+        c2 = p.connect()
+
+        eq_(canary, ['first_connect'])
+
+    def test_first_connect_on_previously_recreated(self):
+        p, canary = self._first_connect_event_fixture()
+
+        p2 = p.recreate()
+        c1 = p.connect()
+        c2 = p2.connect()
+
+        eq_(canary, ['first_connect', 'first_connect'])
+
+    def test_first_connect_on_subsequently_recreated(self):
+        p, canary = self._first_connect_event_fixture()
+
+        c1 = p.connect()
+        p2 = p.recreate()
+        c2 = p2.connect()
+
+        eq_(canary, ['first_connect', 'first_connect'])
+
+    def test_connect_event(self):
+        p, canary = self._connect_event_fixture()
+        
+        c1 = p.connect()
+        eq_(canary, ['connect'])
+
+    def test_connect_event_fires_subsequent(self):
+        p, canary = self._connect_event_fixture()
+        
+        c1 = p.connect()
+        c2 = p.connect()
+
+        eq_(canary, ['connect', 'connect'])
+
+    def test_connect_on_previously_recreated(self):
+        p, canary = self._connect_event_fixture()
+
+        p2 = p.recreate()
+        
+        c1 = p.connect()
+        c2 = p2.connect()
+
+        eq_(canary, ['connect', 'connect'])
+
+    def test_connect_on_subsequently_recreated(self):
+        p, canary = self._connect_event_fixture()
+
+        c1 = p.connect()
+        p2 = p.recreate()
+        c2 = p2.connect()
+
+        eq_(canary, ['connect', 'connect'])
     
+    def test_checkout_event(self):
+        p, canary = self._checkout_event_fixture()
+        
+        c1 = p.connect()
+        eq_(canary, ['checkout'])
+
+    def test_checkout_event_fires_subsequent(self):
+        p, canary = self._checkout_event_fixture()
+        
+        c1 = p.connect()
+        c2 = p.connect()
+        eq_(canary, ['checkout', 'checkout'])
+
+    def test_checkout_event_on_subsequently_recreated(self):
+        p, canary = self._checkout_event_fixture()
+        
+        c1 = p.connect()
+        p2 = p.recreate()
+        c2 = p2.connect()
+        
+        eq_(canary, ['checkout', 'checkout'])
+        
+    def test_checkin_event(self):
+        p, canary = self._checkin_event_fixture()
+        
+        c1 = p.connect()
+        eq_(canary, [])
+        c1.close()
+        eq_(canary, ['checkin'])
+    
+    def test_checkin_event_gc(self):
+        p, canary = self._checkin_event_fixture()
+        
+        c1 = p.connect()
+        eq_(canary, [])
+        del c1
+        lazy_gc()
+        eq_(canary, ['checkin'])
+
+    def test_checkin_event_on_subsequently_recreated(self):
+        p, canary = self._checkin_event_fixture()
+        
+        c1 = p.connect()
+        p2 = p.recreate()
+        c2 = p2.connect()
+        
+        eq_(canary, [])
+        
+        c1.close()
+        eq_(canary, ['checkin'])
+        
+        c2.close()
+        eq_(canary, ['checkin', 'checkin'])
+        
+    def test_listen_targets(self):
+        canary = []
+        def listen_one(*args):
+            canary.append("listen_one")
+        def listen_two(*args):
+            canary.append("listen_two")
+        def listen_three(*args):
+            canary.append("listen_three")
+        def listen_four(*args):
+            canary.append("listen_four")
+            
+        engine = create_engine(testing.db.url)
+        event.listen(listen_one, 'on_connect', pool.Pool)
+        event.listen(listen_two, 'on_connect', engine.pool)
+        event.listen(listen_three, 'on_connect', engine)
+        event.listen(listen_four, 'on_connect', engine.__class__)
+
+        engine.execute(select([1])).close()
+        eq_(
+            canary, ["listen_one","listen_four", "listen_two","listen_three"]
+        )
+
+    def teardown(self):
+        # TODO: need to get remove() functionality
+        # going
+        pool.Pool.dispatch.clear()
+        
+class DeprecatedPoolListenerTest(PoolTestBase):
     @testing.uses_deprecated(r".*Use event.listen")
     def test_listeners(self):
-        dbapi = MockDBAPI()
-
         class InstrumentingListener(object):
             def __init__(self):
                 if hasattr(self, 'connect'):
@@ -257,10 +443,6 @@ class PoolEventsTest(PoolTestBase):
             def checkin(self, con, record):
                 pass
 
-        def _pool(**kw):
-            return pool.QueuePool(creator=lambda: dbapi.connect('foo.db'),
-                                  use_threadlocal=False, **kw)
-
         def assert_listeners(p, total, conn, fconn, cout, cin):
             for instance in (p, p.recreate()):
                 self.assert_(len(instance.dispatch.on_connect) == conn)
@@ -268,7 +450,7 @@ class PoolEventsTest(PoolTestBase):
                 self.assert_(len(instance.dispatch.on_checkout) == cout)
                 self.assert_(len(instance.dispatch.on_checkin) == cin)
 
-        p = _pool()
+        p = self._queuepool_fixture()
         assert_listeners(p, 0, 0, 0, 0, 0)
 
         p.add_listener(ListenAll())
@@ -288,7 +470,7 @@ class PoolEventsTest(PoolTestBase):
         del p
 
         snoop = ListenAll()
-        p = _pool(listeners=[snoop])
+        p = self._queuepool_fixture(listeners=[snoop])
         assert_listeners(p, 1, 1, 1, 1, 1)
 
         c = p.connect()
@@ -372,8 +554,6 @@ class PoolEventsTest(PoolTestBase):
     
     @testing.uses_deprecated(r".*Use event.listen")
     def test_listeners_callables(self):
-        dbapi = MockDBAPI()
-
         def connect(dbapi_con, con_record):
             counts[0] += 1
         def checkout(dbapi_con, con_record, con_proxy):
@@ -388,9 +568,6 @@ class PoolEventsTest(PoolTestBase):
 
         for cls in (pool.QueuePool, pool.StaticPool):
             counts = [0, 0, 0]
-            def _pool(**kw):
-                return cls(creator=lambda: dbapi.connect('foo.db'),
-                                      use_threadlocal=False, **kw)
 
             def assert_listeners(p, total, conn, cout, cin):
                 for instance in (p, p.recreate()):
@@ -398,7 +575,7 @@ class PoolEventsTest(PoolTestBase):
                     eq_(len(instance.dispatch.on_checkout), cout)
                     eq_(len(instance.dispatch.on_checkin), cin)
 
-            p = _pool()
+            p = self._queuepool_fixture()
             assert_listeners(p, 0, 0, 0, 0)
 
             p.add_listener(i_all)
@@ -414,7 +591,7 @@ class PoolEventsTest(PoolTestBase):
             assert_listeners(p, 4, 1, 1, 1)
             del p
 
-            p = _pool(listeners=[i_all])
+            p = self._queuepool_fixture(listeners=[i_all])
             assert_listeners(p, 1, 1, 1, 1)
 
             c = p.connect()
@@ -428,45 +605,6 @@ class PoolEventsTest(PoolTestBase):
             c.close()
             assert counts == [1, 2, 2]
 
-    def test_listener_after_oninit(self):
-        """Test that listeners are called after OnInit is removed"""
-        
-        called = []
-        def listener(*args):
-            called.append(True)
-        engine = create_engine(testing.db.url)
-        event.listen(listener, 'on_connect', engine.pool)
-        engine.execute(select([1])).close()
-        assert called, "Listener not called on connect"
-
-    def test_targets(self):
-        canary = []
-        def listen_one(*args):
-            canary.append("listen_one")
-        def listen_two(*args):
-            canary.append("listen_two")
-        def listen_three(*args):
-            canary.append("listen_three")
-        def listen_four(*args):
-            canary.append("listen_four")
-            
-        engine = create_engine(testing.db.url)
-        event.listen(listen_one, 'on_connect', pool.Pool)
-        event.listen(listen_two, 'on_connect', engine.pool)
-        event.listen(listen_three, 'on_connect', engine)
-        event.listen(listen_four, 'on_connect', engine.__class__)
-
-        engine.execute(select([1])).close()
-        eq_(
-            canary, ["listen_one","listen_four", "listen_two","listen_three"]
-        )
-
-    def teardown(self):
-        # TODO: need to get remove() functionality
-        # going
-        pool.Pool.dispatch.clear()
-        
-
 class QueuePoolTest(PoolTestBase):
 
     def testqueuepool_del(self):
@@ -476,8 +614,8 @@ class QueuePoolTest(PoolTestBase):
         self._do_testqueuepool(useclose=True)
 
     def _do_testqueuepool(self, useclose=False):
-        p = pool.QueuePool(creator=mock_dbapi.connect, pool_size=3,
-                           max_overflow=-1, use_threadlocal=False)
+        p = self._queuepool_fixture(pool_size=3,
+                           max_overflow=-1)
 
         def status(pool):
             tup = pool.size(), pool.checkedin(), pool.overflow(), \
@@ -528,8 +666,8 @@ class QueuePoolTest(PoolTestBase):
         assert not pool._refs
        
     def test_timeout(self):
-        p = pool.QueuePool(creator=mock_dbapi.connect, pool_size=3,
-                           max_overflow=0, use_threadlocal=False,
+        p = self._queuepool_fixture(pool_size=3,
+                           max_overflow=0, 
                            timeout=2)
         c1 = p.connect()
         c2 = p.connect()
@@ -549,8 +687,9 @@ class QueuePoolTest(PoolTestBase):
         # wait for the timeout on queue.get().  the fix involves checking the
         # timeout again within the mutex, and if so, unlocking and throwing
         # them back to the start of do_get()
+        dbapi = MockDBAPI()
         p = pool.QueuePool(
-                creator = lambda: mock_dbapi.connect(delay=.05), 
+                creator = lambda: dbapi.connect(delay=.05), 
                 pool_size = 2, 
                 max_overflow = 1, use_threadlocal = False, timeout=3)
         timeouts = []
@@ -583,9 +722,10 @@ class QueuePoolTest(PoolTestBase):
     def _test_overflow(self, thread_count, max_overflow):
         gc_collect()
         
+        dbapi = MockDBAPI()
         def creator():
             time.sleep(.05)
-            return mock_dbapi.connect()
+            return dbapi.connect()
  
         p = pool.QueuePool(creator=creator,
                            pool_size=3, timeout=2,
@@ -621,8 +761,7 @@ class QueuePoolTest(PoolTestBase):
         self._test_overflow(40, 5)
  
     def test_mixed_close(self):
-        p = pool.QueuePool(creator=mock_dbapi.connect, pool_size=3,
-                           max_overflow=-1, use_threadlocal=True)
+        p = self._queuepool_fixture(pool_size=3, max_overflow=-1, use_threadlocal=True)
         c1 = p.connect()
         c2 = p.connect()
         assert c1 is c2
@@ -636,7 +775,7 @@ class QueuePoolTest(PoolTestBase):
         assert not pool._refs
 
     def test_weakref_kaboom(self):
-        p = pool.QueuePool(creator=mock_dbapi.connect, pool_size=3,
+        p = self._queuepool_fixture(pool_size=3,
                            max_overflow=-1, use_threadlocal=True)
         c1 = p.connect()
         c2 = p.connect()
@@ -655,8 +794,8 @@ class QueuePoolTest(PoolTestBase):
         counter, you can fool the counter into giving you a
         ConnectionFairy with an ambiguous counter.  i.e. its not true
         reference counting."""
-
-        p = pool.QueuePool(creator=mock_dbapi.connect, pool_size=3,
+        
+        p = self._queuepool_fixture(pool_size=3,
                            max_overflow=-1, use_threadlocal=True)
         c1 = p.connect()
         c2 = p.connect()
@@ -669,8 +808,8 @@ class QueuePoolTest(PoolTestBase):
         self.assert_(p.checkedout() == 0)
  
     def test_recycle(self):
-        p = pool.QueuePool(creator=mock_dbapi.connect, pool_size=1,
-                           max_overflow=0, use_threadlocal=False,
+        p = self._queuepool_fixture(pool_size=1,
+                           max_overflow=0, 
                            recycle=3)
         c1 = p.connect()
         c_id = id(c1.connection)
@@ -683,10 +822,7 @@ class QueuePoolTest(PoolTestBase):
         assert id(c3.connection) != c_id
 
     def test_invalidate(self):
-        dbapi = MockDBAPI()
-        p = pool.QueuePool(creator=lambda : dbapi.connect('foo.db'),
-                           pool_size=1, max_overflow=0,
-                           use_threadlocal=False)
+        p  = self._queuepool_fixture(pool_size=1, max_overflow=0)
         c1 = p.connect()
         c_id = c1.connection.id
         c1.close()
@@ -699,10 +835,7 @@ class QueuePoolTest(PoolTestBase):
         assert c1.connection.id != c_id
 
     def test_recreate(self):
-        dbapi = MockDBAPI()
-        p = pool.QueuePool(creator=lambda : dbapi.connect('foo.db'),
-                           pool_size=1, max_overflow=0,
-                           use_threadlocal=False)
+        p = self._queuepool_fixture(pool_size=1, max_overflow=0)
         p2 = p.recreate()
         assert p2.size() == 1
         assert p2._use_threadlocal is False
@@ -713,10 +846,7 @@ class QueuePoolTest(PoolTestBase):
         engine/dialect includes another layer of reconnect support for
         'database was lost' errors."""
 
-        dbapi = MockDBAPI()
-        p = pool.QueuePool(creator=lambda : dbapi.connect('foo.db'),
-                           pool_size=1, max_overflow=0,
-                           use_threadlocal=False)
+        dbapi, p = self._queuepool_dbapi_fixture(pool_size=1, max_overflow=0)
         c1 = p.connect()
         c_id = c1.connection.id
         c1.close()
@@ -730,10 +860,7 @@ class QueuePoolTest(PoolTestBase):
         assert c1.connection.id != c_id
 
     def test_detach(self):
-        dbapi = MockDBAPI()
-        p = pool.QueuePool(creator=lambda : dbapi.connect('foo.db'),
-                           pool_size=1, max_overflow=0,
-                           use_threadlocal=False)
+        dbapi, p = self._queuepool_dbapi_fixture(pool_size=1, max_overflow=0)
         c1 = p.connect()
         c1.detach()
         c_id = c1.connection.id
@@ -750,8 +877,7 @@ class QueuePoolTest(PoolTestBase):
         assert con.closed
 
     def test_threadfairy(self):
-        p = pool.QueuePool(creator=mock_dbapi.connect, pool_size=3,
-                           max_overflow=-1, use_threadlocal=True)
+        p = self._queuepool_fixture(pool_size=3, max_overflow=-1, use_threadlocal=True)
         c1 = p.connect()
         c1.close()
         c2 = p.connect()
@@ -762,8 +888,9 @@ class SingletonThreadPoolTest(PoolTestBase):
     def test_cleanup(self):
         """test that the pool's connections are OK after cleanup() has
         been called."""
-
-        p = pool.SingletonThreadPool(creator=mock_dbapi.connect,
+        
+        dbapi = MockDBAPI()
+        p = pool.SingletonThreadPool(creator=dbapi.connect,
                 pool_size=3)
 
         def checkout():
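
For convenience, the two fixture helpers this commit moves onto the base class, gathered from the hunks above (abridged, not new code): each test now builds its own MockDBAPI instead of sharing the removed module-level mock_dbapi, so per-test state such as throw_error cannot leak between tests, and _queuepool_dbapi_fixture also hands back the dbapi so a test can toggle it or compare MockConnection ids.

class PoolTestBase(TestBase):

    def _queuepool_fixture(self, **kw):
        # fresh MockDBAPI per test; no shared connection or error state
        dbapi = MockDBAPI()
        return pool.QueuePool(creator=lambda: dbapi.connect('foo.db'), **kw)

    def _queuepool_dbapi_fixture(self, **kw):
        # also return the dbapi so tests like test_detach can inspect
        # connection ids or flip dbapi.throw_error
        dbapi = MockDBAPI()
        return dbapi, pool.QueuePool(creator=lambda: dbapi.connect('foo.db'),
                                     **kw)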