from sqlalchemy import queue as Queue
try:
- import dummy_threading
import thread, threading
except:
import dummy_thread as thread
self._overflow = 0 - pool_size
self._max_overflow = max_overflow
self._timeout = timeout
- self._overflow_lock = max_overflow > 0 and threading.Lock() or dummy_threading.Lock()
+ self._overflow_lock = max_overflow > 0 and threading.Lock() or None
def recreate(self):
self.log("Pool recreating")
try:
self._pool.put(conn, False)
except Queue.Full:
- self._overflow_lock.acquire()
- self._overflow -= 1
- self._overflow_lock.release()
+ if not self._overflow_lock:
+ self._overflow -= 1
+ else:
+ self._overflow_lock.acquire()
+ self._overflow -= 1
+ self._overflow_lock.release()
def do_get(self):
try:
return self._pool.get(self._max_overflow > -1 and self._overflow >= self._max_overflow, self._timeout)
except Queue.Empty:
- self._overflow_lock.acquire()
+ if self._overflow_lock:
+ self._overflow_lock.acquire()
try:
if self._max_overflow > -1 and self._overflow >= self._max_overflow:
raise exceptions.TimeoutError("QueuePool limit of size %d overflow %d reached, connection timed out" % (self.size(), self.overflow()))
con = self.create_connection()
self._overflow += 1
finally:
- self._overflow_lock.release()
+ if self._overflow_lock:
+ self._overflow_lock.release()
return con
def dispose(self):
import testbase
from testbase import PersistTest
import unittest, sys, os, time
+import threading
import sqlalchemy.pool as pool
import sqlalchemy.exceptions as exceptions
assert False
except exceptions.TimeoutError, e:
assert int(time.time() - now) == 2
+
+ def _test_overflow(self, thread_count, max_overflow):
+ p = pool.QueuePool(creator=lambda: mock_dbapi.connect('foo.db'),
+ pool_size=3, timeout=2,
+ max_overflow=max_overflow)
+ peaks = []
+ def whammy():
+ for i in range(10):
+ try:
+ con = p.connect()
+ peaks.append(p.overflow())
+ con.close()
+ del con
+ except exceptions.TimeoutError:
+ pass
+ threads = []
+ for i in xrange(thread_count):
+ th = threading.Thread(target=whammy)
+ th.start()
+ threads.append(th)
+ for th in threads:
+ th.join()
+
+ self.assert_(max(peaks) <= max_overflow)
+
+ def test_no_overflow(self):
+ self._test_overflow(20, 0)
+
+ def test_max_overflow(self):
+ self._test_overflow(20, 5)
def test_mixed_close(self):
p = pool.QueuePool(creator = lambda: mock_dbapi.connect('foo.db'), pool_size = 3, max_overflow = -1, use_threadlocal = True)