]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- added a mutex to QueuePool's "overflow" calculation to prevent a race
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 30 Jun 2007 01:14:15 +0000 (01:14 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 30 Jun 2007 01:14:15 +0000 (01:14 +0000)
condition that can bypass max_overflow; merged from 0.4 branch r2736-2738.
also made the locking logic simpler, tried to get the test to create a
failure on OSX (not successful)

CHANGES
lib/sqlalchemy/pool.py
test/engine/pool.py

diff --git a/CHANGES b/CHANGES
index 220a6fe3174ecfd2de2ead34a9012793fc15aa84..37f24deb2fc09c354cb5bd1acd5fe96678dbb436 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -32,6 +32,8 @@
     - MetaData and all SchemaItems are safe to use with pickle.  slow
       table reflections can be dumped into a pickled file to be reused later.
       Just reconnect the engine to the metadata after unpickling. [ticket:619]
+    - added a mutex to QueuePool's "overflow" calculation to prevent a race
+      condition that can bypass max_overflow
     - fixed grouping of compound selects to give correct results. will break
       on sqlite in some cases, but those cases were producing incorrect
       results anyway, sqlite doesn't support grouped compound selects
index a77e979abd6c73600d8fa3892bc78df7e29ecee3..0377708e9322d851a8244775a58aaa5d91f3610a 100644 (file)
@@ -23,9 +23,10 @@ from sqlalchemy import exceptions, logging
 from sqlalchemy import queue as Queue
 
 try:
-    import thread
+    import thread, threading
 except:
     import dummy_thread as thread
+    import dummy_threading as threading
 
 proxies = {}
 
@@ -469,6 +470,7 @@ class QueuePool(Pool):
         self._overflow = 0 - pool_size
         self._max_overflow = max_overflow
         self._timeout = timeout
+        self._overflow_lock = threading.Lock()
 
     def recreate(self):
         self.log("Pool recreating")
@@ -478,16 +480,24 @@ class QueuePool(Pool):
         try:
             self._pool.put(conn, False)
         except Queue.Full:
-            self._overflow -= 1
+            self._overflow_lock.acquire()
+            try:
+                self._overflow -= 1
+            finally:
+                self._overflow_lock.release()
 
     def do_get(self):
         try:
             return self._pool.get(self._max_overflow > -1 and self._overflow >= self._max_overflow, self._timeout)
         except Queue.Empty:
-            if self._max_overflow > -1 and self._overflow >= self._max_overflow:
-                raise exceptions.TimeoutError("QueuePool limit of size %d overflow %d reached, connection timed out" % (self.size(), self.overflow()))
-            con = self.create_connection()
-            self._overflow += 1
+            self._overflow_lock.acquire()
+            try:
+                if self._max_overflow > -1 and self._overflow >= self._max_overflow:
+                    raise exceptions.TimeoutError("QueuePool limit of size %d overflow %d reached, connection timed out" % (self.size(), self.overflow()))
+                con = self.create_connection()
+                self._overflow += 1
+            finally:
+                self._overflow_lock.release()
             return con
 
     def dispose(self):
index 315470e98b34c634cc57db21a094e610293c0bd6..17a0b369a89188d1e0081db959dfc108b6200248 100644 (file)
@@ -1,6 +1,7 @@
 import testbase
 from testbase import PersistTest
 import unittest, sys, os, time
+import threading, thread
 
 import sqlalchemy.pool as pool
 import sqlalchemy.exceptions as exceptions
@@ -126,6 +127,38 @@ class PoolTest(PersistTest):
             assert False
         except exceptions.TimeoutError, e:
             assert int(time.time() - now) == 2
+
+    def _test_overflow(self, thread_count, max_overflow):
+        # I can't really get this to fail on OS X.  Linux?  Windows?
+        p = pool.QueuePool(creator=lambda: mock_dbapi.connect('foo.db'),
+                           pool_size=3, timeout=2,
+                           max_overflow=max_overflow)
+        peaks = []
+        def whammy():
+            for i in range(10):
+                try:
+                    con = p.connect()
+                    peaks.append(p.overflow())
+                    time.sleep(.005)
+                    con.close()
+                    del con
+                except exceptions.TimeoutError:
+                    pass
+        threads = []
+        for i in xrange(thread_count):
+            th = threading.Thread(target=whammy)
+            th.start()
+            threads.append(th)
+        for th in threads:
+            th.join()
+
+        self.assert_(max(peaks) <= max_overflow)
+
+    def test_no_overflow(self):
+        self._test_overflow(40, 0)
+
+    def test_max_overflow(self):
+        self._test_overflow(40, 5)
         
     def test_mixed_close(self):
         p = pool.QueuePool(creator = lambda: mock_dbapi.connect('foo.db'), pool_size = 3, max_overflow = -1, use_threadlocal = True)