git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff

author    Mike Bayer <mike_mp@zzzcomputing.com>
          Sat, 7 Dec 2013 00:57:19 +0000 (19:57 -0500)
committer Mike Bayer <mike_mp@zzzcomputing.com>
          Sat, 7 Dec 2013 00:59:06 +0000 (19:59 -0500)

The :class:`.QueuePool` has been enhanced to not block new connection
attempts when an existing connection attempt is blocking.  Previously,
the production of new connections was serialized within the block
that monitored overflow; the overflow counter is now altered within
its own critical section outside of the connection process itself.
[ticket:2880]
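
In essence, the commit replaces "hold the overflow lock while connecting" with
"claim an overflow slot under the lock, then connect outside of it."  A minimal
sketch of that locking pattern (simplified and hypothetical; OverflowCounter,
checkout() and make_connection() are illustrative names, not SQLAlchemy's API):

    import threading

    class OverflowCounter(object):
        # Illustration only: the lock guards just the counter, never the
        # (potentially slow) connection attempt itself.
        def __init__(self, max_overflow):
            self._overflow = 0
            self._max_overflow = max_overflow
            self._lock = threading.Lock()

        def inc(self):
            # claim an overflow slot quickly, under the lock
            with self._lock:
                if self._overflow < self._max_overflow:
                    self._overflow += 1
                    return True
                return False

        def dec(self):
            with self._lock:
                self._overflow -= 1

    def checkout(counter, make_connection):
        # make_connection() runs outside the lock, so one hanging connect
        # no longer serializes other threads' connection attempts
        if counter.inc():
            try:
                return make_connection()
            except Exception:
                counter.dec()   # give the slot back on a failed connect
                raise
        raise RuntimeError("overflow limit reached; the real pool retries the queue")

The real _do_get() waits on the underlying queue instead of raising; the sketch
only shows the locking discipline that the diff below introduces via
_inc_overflow() and _dec_overflow().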

doc/build/changelog/changelog_08.rst
lib/sqlalchemy/pool.py
test/engine/test_pool.py

diff --git a/doc/build/changelog/changelog_08.rst b/doc/build/changelog/changelog_08.rst
index c5f882f7f625b89d0fee9e3bccef8fd6c43da526..4c9cc85d9352c71c5cc02b12e506ab80deb3f212 100644
 .. changelog::
     :version: 0.8.4
 
+     .. change::
+        :tags: bug, engine, pool
+        :versions: 0.9.0b2
+        :tickets: 2880
+
+        The :class:`.QueuePool` has been enhanced to not block new connection
+        attempts when an existing connection attempt is blocking.  Previously,
+        the production of new connections was serialized within the block
+        that monitored overflow; the overflow counter is now altered within
+        its own critical section outside of the connection process itself.
+
      .. change::
         :tags: bug, engine, pool
         :versions: 0.9.0b2
diff --git a/lib/sqlalchemy/pool.py b/lib/sqlalchemy/pool.py
index 3c551031949cda3953d1355cbf44ab79ab32820e..dd00e745e2088adff3fe75b19f963d058049c60e 100644
@@ -631,15 +631,6 @@ class SingletonThreadPool(Pool):
         return c
 
 
-class DummyLock(object):
-
-    def acquire(self, wait=True):
-        return True
-
-    def release(self):
-        pass
-
-
 class QueuePool(Pool):
     """A :class:`.Pool` that imposes a limit on the number of open connections.
 
@@ -747,30 +738,25 @@ class QueuePool(Pool):
         self._overflow = 0 - pool_size
         self._max_overflow = max_overflow
         self._timeout = timeout
-        self._overflow_lock = threading.Lock() if self._max_overflow > -1 \
-                                    else DummyLock()
+        self._overflow_lock = threading.Lock()
 
     def _do_return_conn(self, conn):
         try:
             self._pool.put(conn, False)
         except sqla_queue.Full:
+            self._dec_overflow()
             conn.close()
-            self._overflow_lock.acquire()
-            try:
-                self._overflow -= 1
-            finally:
-                self._overflow_lock.release()
 
     def _do_get(self):
+        use_overflow = self._max_overflow > -1
+
         try:
-            wait = self._max_overflow > -1 and \
-                        self._overflow >= self._max_overflow
+            wait = use_overflow and self._overflow >= self._max_overflow
             return self._pool.get(wait, self._timeout)
         except sqla_queue.SAAbort, aborted:
             return aborted.context._do_get()
         except sqla_queue.Empty:
-            if self._max_overflow > -1 and \
-                        self._overflow >= self._max_overflow:
+            if use_overflow and self._overflow >= self._max_overflow:
                 if not wait:
                     return self._do_get()
                 else:
@@ -779,17 +765,33 @@ class QueuePool(Pool):
                             "connection timed out, timeout %d" %
                             (self.size(), self.overflow(), self._timeout))
 
-            self._overflow_lock.acquire()
-            try:
-                if self._max_overflow > -1 and \
-                            self._overflow >= self._max_overflow:
-                    return self._do_get()
-                else:
-                    con = self._create_connection()
-                    self._overflow += 1
-                    return con
-            finally:
-                self._overflow_lock.release()
+            if self._inc_overflow():
+                try:
+                    return self._create_connection()
+                except:
+                    self._dec_overflow()
+                    raise
+            else:
+                return self._do_get()
+
+    def _inc_overflow(self):
+        if self._max_overflow == -1:
+            self._overflow += 1
+            return True
+        with self._overflow_lock:
+            if self._overflow < self._max_overflow:
+                self._overflow += 1
+                return True
+            else:
+                return False
+
+    def _dec_overflow(self):
+        if self._max_overflow == -1:
+            self._overflow -= 1
+            return True
+        with self._overflow_lock:
+            self._overflow -= 1
+            return True
 
     def recreate(self):
         self.logger.info("Pool recreating")
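
For a user-level view of the effect (a hedged sketch, not part of the commit:
the Mock-based stand-in and the creator swapping are illustrative assumptions),
a checkout whose creator hangs no longer delays a concurrent checkout:

    import threading
    import time
    from unittest.mock import Mock   # `from mock import Mock` on the Python 2 of this era
    from sqlalchemy import pool

    dbapi = Mock()                   # tolerant stand-in for a DBAPI module
    state = {"creator": None}

    def slow_creator():
        time.sleep(2)                # simulate a hanging DBAPI connect
        return dbapi.connect()

    def fast_creator():
        return dbapi.connect()

    def creator():
        return state["creator"]()

    p = pool.QueuePool(creator=creator, pool_size=2, max_overflow=3)

    def checkout(name):
        conn = p.connect()
        print(name, "checked out")
        conn.close()

    state["creator"] = slow_creator
    t1 = threading.Thread(target=checkout, args=("slow",))
    t1.start()
    time.sleep(0.2)                  # let the slow connect get underway
    state["creator"] = fast_creator
    t2 = threading.Thread(target=checkout, args=("fast",))
    t2.start()                       # completes without waiting on the slow connect
    t1.join()
    t2.join()

Before this change, the second connect() would have queued up behind the lock
held for the duration of slow_creator(); the new tests below exercise the same
scenario with multiple threads.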
diff --git a/test/engine/test_pool.py b/test/engine/test_pool.py
index e0a9c602488275557735d6e2cbb44f5952c237cf..702ec5560773627ef4135797469e142fa5d25cf9 100644
@@ -875,6 +875,88 @@ class QueuePoolTest(PoolTestBase):
         lazy_gc()
         assert not pool._refs
 
+
+    def test_overflow_reset_on_failed_connect(self):
+        dbapi = Mock()
+
+        def failing_dbapi():
+            time.sleep(2)
+            raise Exception("connection failed")
+
+        creator = dbapi.connect
+        def create():
+            return creator()
+
+        p = pool.QueuePool(creator=create, pool_size=2, max_overflow=3)
+        c1 = p.connect()
+        c2 = p.connect()
+        c3 = p.connect()
+        eq_(p._overflow, 1)
+        creator = failing_dbapi
+        assert_raises(Exception, p.connect)
+        eq_(p._overflow, 1)
+
+    @testing.requires.threading_with_mock
+    def test_hanging_connect_within_overflow(self):
+        """test that a single connect() call which is hanging
+        does not block other connections from proceeding."""
+
+        dbapi = Mock()
+        mutex = threading.Lock()
+
+        def hanging_dbapi():
+            time.sleep(2)
+            with mutex:
+                return dbapi.connect()
+
+        def fast_dbapi():
+            with mutex:
+                return dbapi.connect()
+
+        creator = threading.local()
+
+        def create():
+            return creator.mock_connector()
+
+        def run_test(name, pool, should_hang):
+            if should_hang:
+                creator.mock_connector = hanging_dbapi
+            else:
+                creator.mock_connector = fast_dbapi
+
+            conn = pool.connect()
+            conn.operation(name)
+            time.sleep(1)
+            conn.close()
+
+        p = pool.QueuePool(creator=create, pool_size=2, max_overflow=3)
+
+        threads = [
+            threading.Thread(
+                        target=run_test, args=("success_one", p, False)),
+            threading.Thread(
+                        target=run_test, args=("success_two", p, False)),
+            threading.Thread(
+                        target=run_test, args=("overflow_one", p, True)),
+            threading.Thread(
+                        target=run_test, args=("overflow_two", p, False)),
+            threading.Thread(
+                        target=run_test, args=("overflow_three", p, False))
+        ]
+        for t in threads:
+            t.start()
+            time.sleep(.2)
+
+        for t in threads:
+            t.join(timeout=join_timeout)
+        eq_(
+            dbapi.connect().operation.mock_calls,
+            [call("success_one"), call("success_two"),
+                call("overflow_two"), call("overflow_three"),
+                call("overflow_one")]
+        )
+
+
     @testing.requires.threading_with_mock
     def test_waiters_handled(self):
         """test that threads waiting for connections are