raise NotImplementedError()
def log(self, msg):
- self.logger.write(msg)
+ self._logger.write(msg)
class ConnectionFairy(object):
def __init__(self, pool, connection=None):
"""Maintains one connection per each thread, never moving to another thread. this is
used for SQLite and other databases with a similar restriction."""
def __init__(self, creator, **params):
- params['use_threadlocal'] = False
Pool.__init__(self, **params)
self._conns = {}
self._creator = creator
def status(self):
- return "SingletonThreadPool size: %d" % len(self._conns)
-
- def unique_connection(self):
- return ConnectionFairy(self, self._creator())
+ return "SingletonThreadPool thread:%d size: %d" % (thread.get_ident(), len(self._conns))
def do_return_conn(self, conn):
- pass
+ if self._conns.get(thread.get_ident(), None) is None:
+ self._conns[thread.get_ident()] = conn
+
def do_return_invalid(self):
try:
del self._conns[thread.get_ident()]
def do_get(self):
try:
- return self._conns[thread.get_ident()]
+ c = self._conns[thread.get_ident()]
+ if c is None:
+ return self._creator()
except KeyError:
- return self._conns.setdefault(thread.get_ident(), self._creator())
+ c = self._creator()
+ self._conns[thread.get_ident()] = None
+ return c
class QueuePool(Pool):
"""uses Queue.Queue to maintain a fixed-size list of connections."""
self.assert_(status(p) == (3, 1, 0, 2))
c2 = None
self.assert_(status(p) == (3, 2, 0, 1))
-
+
+ def testthreadlocal(self):
+ for p in (
+ pool.QueuePool(creator = lambda: sqlite.connect('foo.db'), pool_size = 3, max_overflow = -1, use_threadlocal = True, echo = False),
+ pool.SingletonThreadPool(creator = lambda: sqlite.connect('foo.db'), use_threadlocal = True)
+ ):
+ c1 = p.connect()
+ c2 = p.connect()
+ self.assert_(c1 is c2)
+ c3 = p.unique_connection()
+ self.assert_(c3 is not c1)
+ c2 = None
+ c2 = p.connect()
+ self.assert_(c1 is c2)
+ self.assert_(c3 is not c1)
+
def tearDown(self):
pool.clear_managers()
for file in ('foo.db', 'bar.db'):