From f6357fb07dc4190a346a785580f3c22f4c2cb0f0 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Fri, 19 Jun 2009 22:15:23 +0000 Subject: [PATCH] jython pool tests pass 100% [ticket:1444] --- lib/sqlalchemy/pool.py | 21 +- lib/sqlalchemy/test/noseplugin.py | 1 + lib/sqlalchemy/test/testing.py | 17 +- lib/sqlalchemy/test/util.py | 11 +- test/engine/test_pool.py | 540 ++++++++++++++++-------------- 5 files changed, 319 insertions(+), 271 deletions(-) diff --git a/lib/sqlalchemy/pool.py b/lib/sqlalchemy/pool.py index 4173a78786..4604de34bb 100644 --- a/lib/sqlalchemy/pool.py +++ b/lib/sqlalchemy/pool.py @@ -276,8 +276,11 @@ class _ConnectionRecord(object): def _finalize_fairy(connection, connection_record, pool, ref=None): - if ref is not None and (connection_record.backref is not ref or isinstance(pool, AssertionPool)): + _refs.discard(connection_record) + + if ref is not None and (connection_record.fairy is not ref or isinstance(pool, AssertionPool)): return + if connection is not None: try: if pool._reset_on_return: @@ -291,7 +294,7 @@ def _finalize_fairy(connection, connection_record, pool, ref=None): if isinstance(e, (SystemExit, KeyboardInterrupt)): raise if connection_record is not None: - connection_record.backref = None + connection_record.fairy = None if pool._should_log_info: pool.log("Connection %r being returned to pool" % connection) if pool._on_checkin: @@ -299,6 +302,8 @@ def _finalize_fairy(connection, connection_record, pool, ref=None): l.checkin(connection, connection_record) pool.return_conn(connection_record) +_refs = set() + class _ConnectionFairy(object): """Proxies a DB-API connection and provides return-on-dereference support.""" @@ -310,7 +315,8 @@ class _ConnectionFairy(object): try: rec = self._connection_record = pool.get() conn = self.connection = self._connection_record.get_connection() - self._connection_record.backref = weakref.ref(self, lambda ref:_finalize_fairy(conn, rec, pool, ref)) + rec.fairy = weakref.ref(self, lambda ref:_finalize_fairy(conn, rec, pool, ref)) + _refs.add(rec) except: self.connection = None # helps with endless __getattr__ loops later on self._connection_record = None @@ -409,8 +415,9 @@ class _ConnectionFairy(object): """ if self._connection_record is not None: + _refs.remove(self._connection_record) + self._connection_record.fairy = None self._connection_record.connection = None - self._connection_record.backref = None self._pool.do_return_conn(self._connection_record) self._detached_info = \ self._connection_record.info.copy() @@ -508,10 +515,8 @@ class SingletonThreadPool(Pool): del self._conn.current def cleanup(self): - for conn in list(self._all_conns): - self._all_conns.discard(conn) - if len(self._all_conns) <= self.size: - return + while len(self._all_conns) > self.size: + self._all_conns.pop() def status(self): return "SingletonThreadPool id:%d size: %d" % (id(self), len(self._all_conns)) diff --git a/lib/sqlalchemy/test/noseplugin.py b/lib/sqlalchemy/test/noseplugin.py index ddf4b32e4c..dbad80409c 100644 --- a/lib/sqlalchemy/test/noseplugin.py +++ b/lib/sqlalchemy/test/noseplugin.py @@ -149,6 +149,7 @@ class NoseSQLAlchemy(Plugin): def afterTest(self, test): testing.resetwarnings() + testing.global_cleanup_assertions() #def handleError(self, test, err): #pass diff --git a/lib/sqlalchemy/test/testing.py b/lib/sqlalchemy/test/testing.py index da164a67a3..39dab81b5a 100644 --- a/lib/sqlalchemy/test/testing.py +++ b/lib/sqlalchemy/test/testing.py @@ -8,11 +8,11 @@ import types import warnings from cStringIO import StringIO -from sqlalchemy.test import config, assertsql +from sqlalchemy.test import config, assertsql, util as testutil from sqlalchemy.util import function_named from engines import drop_all_tables -from sqlalchemy import exc as sa_exc, util, types as sqltypes, schema +from sqlalchemy import exc as sa_exc, util, types as sqltypes, schema, pool _ops = { '<': operator.lt, '>': operator.gt, @@ -413,6 +413,19 @@ def resetwarnings(): if sys.version_info < (2, 4): warnings.filterwarnings('ignore', category=FutureWarning) +def global_cleanup_assertions(): + """Check things that have to be finalized at the end of a test suite. + + Hardcoded at the moment, a modular system can be built here + to support things like PG prepared transactions, tables all + dropped, etc. + + """ + + testutil.lazy_gc() + assert not pool._refs + + def against(*queries): """Boolean predicate, compares to testing database configuration. diff --git a/lib/sqlalchemy/test/util.py b/lib/sqlalchemy/test/util.py index 0ae10390f7..60b0a4ef81 100644 --- a/lib/sqlalchemy/test/util.py +++ b/lib/sqlalchemy/test/util.py @@ -1,17 +1,24 @@ from sqlalchemy.util import jython, function_named import gc +import time if jython: def gc_collect(*args): + """aggressive gc.collect for tests.""" gc.collect() time.sleep(0.1) gc.collect() gc.collect() return 0 + + # "lazy" gc, for VM's that don't GC on refcount == 0 + lazy_gc = gc_collect + else: + # assume CPython - straight gc.collect, lazy_gc() is a pass gc_collect = gc.collect - - + def lazy_gc(): + pass diff --git a/test/engine/test_pool.py b/test/engine/test_pool.py index ad409450e5..0e2a9ae1ac 100644 --- a/test/engine/test_pool.py +++ b/test/engine/test_pool.py @@ -2,7 +2,7 @@ import threading, time from sqlalchemy import pool, interfaces, create_engine import sqlalchemy as tsa from sqlalchemy.test import TestBase, testing -from sqlalchemy.test.util import gc_collect +from sqlalchemy.test.util import gc_collect, lazy_gc mcid = 1 @@ -52,7 +52,6 @@ class PoolTest(PoolTestBase): connection2 = manager.connect('foo.db') connection3 = manager.connect('bar.db') - print "connection " + repr(connection) self.assert_(connection.cursor() is not None) self.assert_(connection is connection2) self.assert_(connection2 is not connection3) @@ -71,8 +70,6 @@ class PoolTest(PoolTestBase): connection = manager.connect('foo.db') connection2 = manager.connect('foo.db') - print "connection " + repr(connection) - self.assert_(connection.cursor() is not None) self.assert_(connection is not connection2) @@ -104,7 +101,8 @@ class PoolTest(PoolTestBase): c2.close() else: c2 = None - + lazy_gc() + if useclose: c1 = p.connect() c2 = p.connect() @@ -118,6 +116,8 @@ class PoolTest(PoolTestBase): # extra tests with QueuePool to ensure connections get __del__()ed when dereferenced if isinstance(p, pool.QueuePool): + lazy_gc() + self.assert_(p.checkedout() == 0) c1 = p.connect() c2 = p.connect() @@ -127,6 +127,7 @@ class PoolTest(PoolTestBase): else: c2 = None c1 = None + lazy_gc() self.assert_(p.checkedout() == 0) def test_properties(self): @@ -255,7 +256,6 @@ class PoolTest(PoolTestBase): assert_listeners(p, 5, 2, 2, 2, 2) del p - print "----" snoop = ListenAll() p = _pool(listeners=[snoop]) assert_listeners(p, 1, 1, 1, 1, 1) @@ -276,6 +276,7 @@ class PoolTest(PoolTestBase): snoop.assert_in(cc, False, False, True, False) snoop.assert_total(0, 0, 1, 0) del c, cc + lazy_gc() snoop.assert_total(0, 0, 1, 1) p.dispose() @@ -299,11 +300,13 @@ class PoolTest(PoolTestBase): c.close() snoop.assert_total(1, 0, 1, 1) del c + lazy_gc() snoop.assert_total(1, 0, 1, 1) c = p.connect() snoop.assert_total(2, 0, 2, 1) c.close() del c + lazy_gc() snoop.assert_total(2, 0, 2, 2) # detached @@ -394,260 +397,279 @@ class PoolTest(PoolTestBase): class QueuePoolTest(PoolTestBase): - def testqueuepool_del(self): - self._do_testqueuepool(useclose=False) - - def testqueuepool_close(self): - self._do_testqueuepool(useclose=True) - - def _do_testqueuepool(self, useclose=False): - p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = False) - - def status(pool): - tup = (pool.size(), pool.checkedin(), pool.overflow(), pool.checkedout()) - print "Pool size: %d Connections in pool: %d Current Overflow: %d Current Checked out connections: %d" % tup - return tup - - c1 = p.connect() - self.assert_(status(p) == (3,0,-2,1)) - c2 = p.connect() - self.assert_(status(p) == (3,0,-1,2)) - c3 = p.connect() - self.assert_(status(p) == (3,0,0,3)) - c4 = p.connect() - self.assert_(status(p) == (3,0,1,4)) - c5 = p.connect() - self.assert_(status(p) == (3,0,2,5)) - c6 = p.connect() - self.assert_(status(p) == (3,0,3,6)) - if useclose: - c4.close() - c3.close() - c2.close() - else: - c4 = c3 = c2 = None - self.assert_(status(p) == (3,3,3,3)) - if useclose: - c1.close() - c5.close() - c6.close() - else: - c1 = c5 = c6 = None - self.assert_(status(p) == (3,3,0,0)) - c1 = p.connect() - c2 = p.connect() - self.assert_(status(p) == (3, 1, 0, 2), status(p)) - if useclose: - c2.close() - else: - c2 = None - self.assert_(status(p) == (3, 2, 0, 1)) - - def test_timeout(self): - p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = 0, use_threadlocal = False, timeout=2) - c1 = p.connect() - c2 = p.connect() - c3 = p.connect() - now = time.time() - try: - c4 = p.connect() - assert False - except tsa.exc.TimeoutError, e: - assert int(time.time() - now) == 2 - - def test_timeout_race(self): - # test a race condition where the initial connecting threads all race - # to queue.Empty, then block on the mutex. each thread consumes a - # connection as they go in. when the limit is reached, the remaining - # threads go in, and get TimeoutError; even though they never got to - # wait for the timeout on queue.get(). the fix involves checking the - # timeout again within the mutex, and if so, unlocking and throwing - # them back to the start of do_get() - p = pool.QueuePool(creator = lambda: mock_dbapi.connect(delay=.05), pool_size = 2, max_overflow = 1, use_threadlocal = False, timeout=3) - timeouts = [] - def checkout(): - for x in xrange(1): - now = time.time() - try: - c1 = p.connect() - except tsa.exc.TimeoutError, e: - timeouts.append(int(time.time()) - now) - continue - time.sleep(4) - c1.close() - - threads = [] - for i in xrange(10): - th = threading.Thread(target=checkout) - th.start() - threads.append(th) - for th in threads: - th.join() - - print timeouts - assert len(timeouts) > 0 - for t in timeouts: - assert abs(t - 3) < 1, "Not all timeouts were 3 seconds: " + repr(timeouts) - - def _test_overflow(self, thread_count, max_overflow): - def creator(): - time.sleep(.05) - return mock_dbapi.connect() - - p = pool.QueuePool(creator=creator, - pool_size=3, timeout=2, - max_overflow=max_overflow) - peaks = [] - def whammy(): - for i in range(10): - try: - con = p.connect() - time.sleep(.005) - peaks.append(p.overflow()) - con.close() - del con - except tsa.exc.TimeoutError: - pass - threads = [] - for i in xrange(thread_count): - th = threading.Thread(target=whammy) - th.start() - threads.append(th) - for th in threads: - th.join() - - self.assert_(max(peaks) <= max_overflow) - - def test_no_overflow(self): - self._test_overflow(40, 0) - - def test_max_overflow(self): - self._test_overflow(40, 5) - - def test_mixed_close(self): - p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True) - c1 = p.connect() - c2 = p.connect() - assert c1 is c2 - c1.close() - c2 = None - assert p.checkedout() == 1 - c1 = None - assert p.checkedout() == 0 - - def test_weakref_kaboom(self): - p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True) - c1 = p.connect() - c2 = p.connect() - c1.close() - c2 = None - del c1 - del c2 - gc_collect() - assert p.checkedout() == 0 - c3 = p.connect() - assert c3 is not None - - def test_trick_the_counter(self): - """this is a "flaw" in the connection pool; since threadlocal uses a single ConnectionFairy per thread - with an open/close counter, you can fool the counter into giving you a ConnectionFairy with an - ambiguous counter. i.e. its not true reference counting.""" - p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True) - c1 = p.connect() - c2 = p.connect() - assert c1 is c2 - c1.close() - c2 = p.connect() - c2.close() - self.assert_(p.checkedout() != 0) - - c2.close() - self.assert_(p.checkedout() == 0) - - def test_recycle(self): - p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 1, max_overflow = 0, use_threadlocal = False, recycle=3) - - c1 = p.connect() - c_id = id(c1.connection) - c1.close() - c2 = p.connect() - assert id(c2.connection) == c_id - c2.close() - time.sleep(4) - c3= p.connect() - assert id(c3.connection) != c_id - - def test_invalidate(self): - dbapi = MockDBAPI() - p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False) - c1 = p.connect() - c_id = c1.connection.id - c1.close(); c1=None - c1 = p.connect() - assert c1.connection.id == c_id - c1.invalidate() - c1 = None - - c1 = p.connect() - assert c1.connection.id != c_id - - def test_recreate(self): - dbapi = MockDBAPI() - p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False) - p2 = p.recreate() - assert p2.size() == 1 - assert p2._use_threadlocal is False - assert p2._max_overflow == 0 - - def test_reconnect(self): - """tests reconnect operations at the pool level. SA's engine/dialect includes another - layer of reconnect support for 'database was lost' errors.""" + def testqueuepool_del(self): + self._do_testqueuepool(useclose=False) + + def testqueuepool_close(self): + self._do_testqueuepool(useclose=True) + + def _do_testqueuepool(self, useclose=False): + p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = False) + + def status(pool): + tup = (pool.size(), pool.checkedin(), pool.overflow(), pool.checkedout()) + print "Pool size: %d Connections in pool: %d Current Overflow: %d Current Checked out connections: %d" % tup + return tup + + c1 = p.connect() + self.assert_(status(p) == (3,0,-2,1)) + c2 = p.connect() + self.assert_(status(p) == (3,0,-1,2)) + c3 = p.connect() + self.assert_(status(p) == (3,0,0,3)) + c4 = p.connect() + self.assert_(status(p) == (3,0,1,4)) + c5 = p.connect() + self.assert_(status(p) == (3,0,2,5)) + c6 = p.connect() + self.assert_(status(p) == (3,0,3,6)) + if useclose: + c4.close() + c3.close() + c2.close() + else: + c4 = c3 = c2 = None + lazy_gc() + + self.assert_(status(p) == (3,3,3,3)) + if useclose: + c1.close() + c5.close() + c6.close() + else: + c1 = c5 = c6 = None + lazy_gc() + + self.assert_(status(p) == (3,3,0,0)) + + c1 = p.connect() + c2 = p.connect() + self.assert_(status(p) == (3, 1, 0, 2), status(p)) + if useclose: + c2.close() + else: + c2 = None + lazy_gc() + + self.assert_(status(p) == (3, 2, 0, 1)) + + c1.close() - dbapi = MockDBAPI() - p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False) - c1 = p.connect() - c_id = c1.connection.id - c1.close(); c1=None - - c1 = p.connect() - assert c1.connection.id == c_id - dbapi.raise_error = True - c1.invalidate() - c1 = None - - c1 = p.connect() - assert c1.connection.id != c_id - - def test_detach(self): - dbapi = MockDBAPI() - p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False) - - c1 = p.connect() - c1.detach() - c_id = c1.connection.id - - c2 = p.connect() - assert c2.connection.id != c1.connection.id - dbapi.raise_error = True - - c2.invalidate() - c2 = None - - c2 = p.connect() - assert c2.connection.id != c1.connection.id - - con = c1.connection - - assert not con.closed - c1.close() - assert con.closed - - def test_threadfairy(self): - p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True) - c1 = p.connect() - c1.close() - c2 = p.connect() - assert c2.connection is not None + lazy_gc() + assert not pool._refs + + def test_timeout(self): + p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = 0, use_threadlocal = False, timeout=2) + c1 = p.connect() + c2 = p.connect() + c3 = p.connect() + now = time.time() + try: + c4 = p.connect() + assert False + except tsa.exc.TimeoutError, e: + assert int(time.time() - now) == 2 + + def test_timeout_race(self): + # test a race condition where the initial connecting threads all race + # to queue.Empty, then block on the mutex. each thread consumes a + # connection as they go in. when the limit is reached, the remaining + # threads go in, and get TimeoutError; even though they never got to + # wait for the timeout on queue.get(). the fix involves checking the + # timeout again within the mutex, and if so, unlocking and throwing + # them back to the start of do_get() + p = pool.QueuePool(creator = lambda: mock_dbapi.connect(delay=.05), pool_size = 2, max_overflow = 1, use_threadlocal = False, timeout=3) + timeouts = [] + def checkout(): + for x in xrange(1): + now = time.time() + try: + c1 = p.connect() + except tsa.exc.TimeoutError, e: + timeouts.append(int(time.time()) - now) + continue + time.sleep(4) + c1.close() + + threads = [] + for i in xrange(10): + th = threading.Thread(target=checkout) + th.start() + threads.append(th) + for th in threads: + th.join() + + print timeouts + assert len(timeouts) > 0 + for t in timeouts: + assert abs(t - 3) < 1, "Not all timeouts were 3 seconds: " + repr(timeouts) + + def _test_overflow(self, thread_count, max_overflow): + def creator(): + time.sleep(.05) + return mock_dbapi.connect() + + p = pool.QueuePool(creator=creator, + pool_size=3, timeout=2, + max_overflow=max_overflow) + peaks = [] + def whammy(): + for i in range(10): + try: + con = p.connect() + time.sleep(.005) + peaks.append(p.overflow()) + con.close() + del con + except tsa.exc.TimeoutError: + pass + threads = [] + for i in xrange(thread_count): + th = threading.Thread(target=whammy) + th.start() + threads.append(th) + for th in threads: + th.join() + + self.assert_(max(peaks) <= max_overflow) + + lazy_gc() + assert not pool._refs + + def test_no_overflow(self): + self._test_overflow(40, 0) + + def test_max_overflow(self): + self._test_overflow(40, 5) + + def test_mixed_close(self): + p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True) + c1 = p.connect() + c2 = p.connect() + assert c1 is c2 + c1.close() + c2 = None + assert p.checkedout() == 1 + c1 = None + lazy_gc() + assert p.checkedout() == 0 + + lazy_gc() + assert not pool._refs + + def test_weakref_kaboom(self): + p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True) + c1 = p.connect() + c2 = p.connect() + c1.close() + c2 = None + del c1 + del c2 + gc_collect() + assert p.checkedout() == 0 + c3 = p.connect() + assert c3 is not None + + def test_trick_the_counter(self): + """this is a "flaw" in the connection pool; since threadlocal uses a single ConnectionFairy per thread + with an open/close counter, you can fool the counter into giving you a ConnectionFairy with an + ambiguous counter. i.e. its not true reference counting.""" + p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True) + c1 = p.connect() + c2 = p.connect() + assert c1 is c2 + c1.close() + c2 = p.connect() + c2.close() + self.assert_(p.checkedout() != 0) + + c2.close() + self.assert_(p.checkedout() == 0) + + def test_recycle(self): + p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 1, max_overflow = 0, use_threadlocal = False, recycle=3) + + c1 = p.connect() + c_id = id(c1.connection) + c1.close() + c2 = p.connect() + assert id(c2.connection) == c_id + c2.close() + time.sleep(4) + c3= p.connect() + assert id(c3.connection) != c_id + + def test_invalidate(self): + dbapi = MockDBAPI() + p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False) + c1 = p.connect() + c_id = c1.connection.id + c1.close(); c1=None + c1 = p.connect() + assert c1.connection.id == c_id + c1.invalidate() + c1 = None + + c1 = p.connect() + assert c1.connection.id != c_id + + def test_recreate(self): + dbapi = MockDBAPI() + p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False) + p2 = p.recreate() + assert p2.size() == 1 + assert p2._use_threadlocal is False + assert p2._max_overflow == 0 + + def test_reconnect(self): + """tests reconnect operations at the pool level. SA's engine/dialect includes another + layer of reconnect support for 'database was lost' errors.""" + + dbapi = MockDBAPI() + p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False) + c1 = p.connect() + c_id = c1.connection.id + c1.close(); c1=None + + c1 = p.connect() + assert c1.connection.id == c_id + dbapi.raise_error = True + c1.invalidate() + c1 = None + + c1 = p.connect() + assert c1.connection.id != c_id + + def test_detach(self): + dbapi = MockDBAPI() + p = pool.QueuePool(creator = lambda: dbapi.connect('foo.db'), pool_size = 1, max_overflow = 0, use_threadlocal = False) + + c1 = p.connect() + c1.detach() + c_id = c1.connection.id + + c2 = p.connect() + assert c2.connection.id != c1.connection.id + dbapi.raise_error = True + + c2.invalidate() + c2 = None + + c2 = p.connect() + assert c2.connection.id != c1.connection.id + + con = c1.connection + + assert not con.closed + c1.close() + assert con.closed + + def test_threadfairy(self): + p = pool.QueuePool(creator = mock_dbapi.connect, pool_size = 3, max_overflow = -1, use_threadlocal = True) + c1 = p.connect() + c1.close() + c2 = p.connect() + assert c2.connection is not None class SingletonThreadPoolTest(PoolTestBase): def test_cleanup(self): -- 2.47.3