git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
author    Mike Bayer <mike_mp@zzzcomputing.com>
          Sat, 22 Mar 2014 22:45:39 +0000 (18:45 -0400)
committer Mike Bayer <mike_mp@zzzcomputing.com>
          Sat, 22 Mar 2014 22:45:39 +0000 (18:45 -0400)

- A major improvement to the mechanics by which the :class:`.Engine`
recycles the connection pool when a "disconnect" condition is detected;
instead of discarding the pool and explicitly closing out connections,
the pool is retained and a "generational" timestamp is updated to
reflect the current time, thereby causing all existing connections
to be recycled when they are next checked out.  This greatly simplifies
the recycle process, removes the need for "waking up" connect attempts
waiting on the old pool, and eliminates the race condition in which many
immediately-discarded "pool" objects could be created during the
recycle operation. fixes #2985
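
The pattern is simple enough to illustrate standalone.  The sketch below is
hypothetical (``MiniPool`` and ``_Record`` are invented names, and all
locking is omitted); it is not SQLAlchemy's implementation, only the
"generational timestamp" idea the message describes:

    import time

    class _Record(object):
        """Pairs a DBAPI connection with its creation time."""
        def __init__(self, creator):
            self.connection = creator()
            self.starttime = time.time()

    class MiniPool(object):
        def __init__(self, creator):
            self._creator = creator
            self._invalidate_time = 0   # the "generational" timestamp
            self._idle = []

        def invalidate(self):
            # On a detected disconnect: don't discard the pool or chase
            # down its connections; just advance the generation marker.
            self._invalidate_time = time.time()

        def checkout(self):
            while self._idle:
                rec = self._idle.pop()
                if rec.starttime < self._invalidate_time:
                    rec.connection.close()   # stale generation; recycle
                    continue
                return rec
            return _Record(self._creator)    # fresh, current generation

        def checkin(self, rec):
            self._idle.append(rec)

The diff below threads this same idea through the real ``Pool``,
``_ConnectionRecord`` and ``Queue`` classes.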

doc/build/changelog/changelog_09.rst
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/pool.py
lib/sqlalchemy/util/queue.py
test/engine/test_pool.py
test/engine/test_reconnect.py

index 2e4e5f0983096c0a0fa8dbcab6615f956168ebdf..c4fa76e493a5cb5877b2f6f64724d6c99482688a 100644 (file)
 .. changelog::
     :version: 0.9.4
 
+    .. change::
+        :tags: bug, engine
+        :tickets: 2985
+
+        A major improvement to the mechanics by which the :class:`.Engine`
+        recycles the connection pool when a "disconnect" condition is detected;
+        instead of discarding the pool and explicitly closing out connections,
+        the pool is retained and a "generational" timestamp is updated to
+        reflect the current time, thereby causing all existing connections
+        to be recycled when they are next checked out.  This greatly simplifies
+        the recycle process, removes the need for "waking up" connect attempts
+        waiting on the old pool, and eliminates the race condition in which many
+        immediately-discarded "pool" objects could be created during the
+        recycle operation.
+
     .. change::
         :tags: bug, oracle
         :tickets: 2987
index d3024640bb44575d5d668072db19e03db18d2b3a..2cad2a09498185c56ae49846b870e9916709214c 100644 (file)
@@ -1091,9 +1091,7 @@ class Connection(Connectable):
                 del self._is_disconnect
                 dbapi_conn_wrapper = self.connection
                 self.invalidate(e)
-                if not hasattr(dbapi_conn_wrapper, '_pool') or \
-                        dbapi_conn_wrapper._pool is self.engine.pool:
-                    self.engine.dispose()
+                self.engine.pool._invalidate(dbapi_conn_wrapper)
             if self.should_close_with_result:
                 self.close()
 
@@ -1503,7 +1501,7 @@ class Engine(Connectable, log.Identified):
         the engine are not affected.
 
         """
-        self.pool = self.pool._replace()
+        self.pool.dispose()
 
     def _execute_default(self, default):
         with self.contextual_connect() as conn:
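
Two changes in this file: a detected disconnect now calls the pool's private
``_invalidate()`` method instead of disposing the engine, and
``Engine.dispose()`` itself now disposes the pool in place rather than
replacing it via ``_replace()``.  The observable effect is sketched below;
``engine``, ``conn``, ``select`` and ``exc`` are assumed from a standard
SQLAlchemy 0.9 setup whose database has just gone away:

    p_before = engine.pool

    try:
        conn.execute(select([1]))        # disconnect detected; raises
    except exc.DBAPIError as err:
        assert err.connection_invalidated

    # Before this commit the handler disposed the engine and the pool
    # object was replaced.  Now the same pool survives; only its
    # generational timestamp has moved.
    assert engine.pool is p_before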
index 473b665c8395885eae0cc5cd04831b3e9d98931e..4a07e785696f3520085a40999762215510eb5bb1 100644 (file)
@@ -528,7 +528,6 @@ class LazyLoader(AbstractRelationshipLoader):
     def _emit_lazyload(self, strategy_options, session, state, ident_key, passive):
         q = session.query(self.mapper)._adapt_all_clauses()
 
-
         if self.parent_property.secondary is not None:
             q = q.select_from(self.mapper, self.parent_property.secondary)
 
index 59c1e614a2f8dc12b4a22a26c45bc198868e4336..7fc4fc6591f63126d0a63fa19770dc55c0c5d3e6 100644 (file)
@@ -210,6 +210,7 @@ class Pool(log.Identified):
         self._threadconns = threading.local()
         self._creator = creator
         self._recycle = recycle
+        self._invalidate_time = 0
         self._use_threadlocal = use_threadlocal
         if reset_on_return in ('rollback', True, reset_rollback):
             self._reset_on_return = reset_rollback
@@ -276,6 +277,22 @@ class Pool(log.Identified):
 
         return _ConnectionRecord(self)
 
+    def _invalidate(self, connection):
+        """Mark all connections established within the generation
+        of the given connection as invalidated.
+
+        If this pool's last invalidate time precedes the creation time of
+        the given connection, the timestamp is updated to the current time.
+        Otherwise, no action is performed.
+
+        Connections with a start time prior to this pool's invalidation
+        time will be recycled upon next checkout.
+        """
+        rec = getattr(connection, "_connection_record", None)
+        if not rec or self._invalidate_time < rec.starttime:
+            self._invalidate_time = time.time()
+
+
     def recreate(self):
         """Return a new :class:`.Pool`, of the same class as this one
         and configured with identical creation arguments.
@@ -301,17 +318,6 @@ class Pool(log.Identified):
 
         raise NotImplementedError()
 
-    def _replace(self):
-        """Dispose + recreate this pool.
-
-        Subclasses may employ special logic to
-        move threads waiting on this pool to the
-        new one.
-
-        """
-        self.dispose()
-        return self.recreate()
-
     def connect(self):
         """Return a DBAPI connection from the pool.
 
@@ -483,6 +489,7 @@ class _ConnectionRecord(object):
         self.connection = None
 
     def get_connection(self):
+        recycle = False
         if self.connection is None:
             self.connection = self.__connect()
             self.info.clear()
@@ -493,6 +500,15 @@ class _ConnectionRecord(object):
             self.__pool.logger.info(
                     "Connection %r exceeded timeout; recycling",
                     self.connection)
+            recycle = True
+        elif self.__pool._invalidate_time > self.starttime:
+            self.__pool.logger.info(
+                    "Connection %r invalidated due to pool invalidation; recycling",
+                    self.connection
+                    )
+            recycle = True
+
+        if recycle:
             self.__close()
             self.connection = self.__connect()
             self.info.clear()
@@ -911,8 +927,6 @@ class QueuePool(Pool):
         try:
             wait = use_overflow and self._overflow >= self._max_overflow
             return self._pool.get(wait, self._timeout)
-        except sqla_queue.SAAbort as aborted:
-            return aborted.context._do_get()
         except sqla_queue.Empty:
             if use_overflow and self._overflow >= self._max_overflow:
                 if not wait:
@@ -974,12 +988,6 @@ class QueuePool(Pool):
         self._overflow = 0 - self.size()
         self.logger.info("Pool disposed. %s", self.status())
 
-    def _replace(self):
-        self.dispose()
-        np = self.recreate()
-        self._pool.abort(np)
-        return np
-
     def status(self):
         return "Pool size: %d  Connections in pool: %d "\
                 "Current Overflow: %d Current Checked out "\
index 82ff55a5d3fb4bdbebf55b4999073a55390d2bbb..c98aa7fdaba311d1edb30f5d1f729306e8854f78 100644 (file)
@@ -15,11 +15,6 @@ rare cases be invoked within the ``get()`` method of the Queue itself,
 producing a ``put()`` inside the ``get()`` and therefore a reentrant
 condition.
 
-An additional change includes a special "abort" method which can be used
-to immediately raise a special exception for threads that are blocking
-on get().  This is to accommodate a rare race condition that can occur
-within QueuePool.
-
 """
 
 from collections import deque
@@ -27,7 +22,7 @@ from time import time as _time
 from .compat import threading
 
 
-__all__ = ['Empty', 'Full', 'Queue', 'SAAbort']
+__all__ = ['Empty', 'Full', 'Queue']
 
 
 class Empty(Exception):
@@ -42,12 +37,6 @@ class Full(Exception):
     pass
 
 
-class SAAbort(Exception):
-    "Special SQLA exception to abort waiting"
-    def __init__(self, context):
-        self.context = context
-
-
 class Queue:
     def __init__(self, maxsize=0):
         """Initialize a queue object with a given maximum size.
@@ -68,8 +57,6 @@ class Queue:
         # a thread waiting to put is notified then.
         self.not_full = threading.Condition(self.mutex)
 
-        # when this is set, SAAbort is raised within get().
-        self._sqla_abort_context = False
 
     def qsize(self):
         """Return the approximate size of the queue (not reliable!)."""
@@ -158,13 +145,7 @@ class Queue:
                     raise Empty
             elif timeout is None:
                 while self._empty():
-                    # wait for only half a second, then
-                    # loop around, so that we can see a change in
-                    # _sqla_abort_context in case we missed the notify_all()
-                    # called by abort()
-                    self.not_empty.wait(.5)
-                    if self._sqla_abort_context:
-                        raise SAAbort(self._sqla_abort_context)
+                    self.not_empty.wait()
             else:
                 if timeout < 0:
                     raise ValueError("'timeout' must be a positive number")
@@ -174,30 +155,12 @@ class Queue:
                     if remaining <= 0.0:
                         raise Empty
                     self.not_empty.wait(remaining)
-                    if self._sqla_abort_context:
-                        raise SAAbort(self._sqla_abort_context)
             item = self._get()
             self.not_full.notify()
             return item
         finally:
             self.not_empty.release()
 
-    def abort(self, context):
-        """Issue an 'abort', will force any thread waiting on get()
-        to stop waiting and raise SAAbort.
-
-        """
-        self._sqla_abort_context = context
-        if not self.not_full.acquire(False):
-            return
-        try:
-            # note that this is now optional
-            # as the waiters in get() both loop around
-            # to check the _sqla_abort_context flag periodically
-            self.not_empty.notify_all()
-        finally:
-            self.not_full.release()
-
     def get_nowait(self):
         """Remove and return an item from the queue without blocking.
 
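
With ``abort()`` gone, a thread blocked in ``get()`` no longer needs the
half-second polling loop: the pool object is never swapped out from under
it, so the ordinary ``notify()`` issued when a connection is returned is
always enough to wake it.  What remains non-standard about this Queue is
only what the module docstring describes, the ``RLock``-based condition
that lets ``put()`` run from code already inside ``get()``.  A minimal
illustration of why a reentrant mutex is required (hypothetical, outside
SQLAlchemy):

    import threading

    # With a plain Lock, a thread already holding the queue mutex inside
    # get() would deadlock as soon as it called put(), because put()
    # acquires the same mutex again.  An RLock allows the reentry.
    mutex = threading.RLock()

    with mutex:        # held while executing inside get()
        with mutex:    # re-acquired by a nested put(); a plain
            pass       # threading.Lock() would deadlock here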
index fc6f3dceaa97307c5b837a9d9c271d58491c76e6..bbab0a7c3d8e7ff411379dd5a16299ca85810340 100644 (file)
@@ -7,7 +7,7 @@ from sqlalchemy.testing.util import gc_collect, lazy_gc
 from sqlalchemy.testing import eq_, assert_raises, is_not_
 from sqlalchemy.testing.engines import testing_engine
 from sqlalchemy.testing import fixtures
-
+import random
 from sqlalchemy.testing.mock import Mock, call
 
 join_timeout = 10
@@ -1069,7 +1069,8 @@ class QueuePoolTest(PoolTestBase):
                 # inside the queue, before we invalidate the other
                 # two conns
                 time.sleep(.2)
-                p2 = p._replace()
+                p._invalidate(c2)
+                c2.invalidate()
 
                 for t in threads:
                     t.join(join_timeout)
@@ -1079,19 +1080,18 @@ class QueuePoolTest(PoolTestBase):
     @testing.requires.threading_with_mock
     def test_notify_waiters(self):
         dbapi = MockDBAPI()
+
         canary = []
-        def creator1():
+        def creator():
             canary.append(1)
             return dbapi.connect()
-        def creator2():
-            canary.append(2)
-            return dbapi.connect()
-        p1 = pool.QueuePool(creator=creator1,
+        p1 = pool.QueuePool(creator=creator,
                            pool_size=1, timeout=None,
                            max_overflow=0)
-        p2 = pool.NullPool(creator=creator2)
+        #p2 = pool.NullPool(creator=creator2)
         def waiter(p):
             conn = p.connect()
+            canary.append(2)
             time.sleep(.5)
             conn.close()
 
@@ -1104,12 +1104,14 @@ class QueuePoolTest(PoolTestBase):
             threads.append(t)
         time.sleep(.5)
         eq_(canary, [1])
-        p1._pool.abort(p2)
+
+        c1.invalidate()
+        p1._invalidate(c1)
 
         for t in threads:
             t.join(join_timeout)
 
-        eq_(canary, [1, 2, 2, 2, 2, 2])
+        eq_(canary, [1, 1, 2, 2, 2, 2, 2])
 
     def test_dispose_closes_pooled(self):
         dbapi = MockDBAPI()
@@ -1251,6 +1253,21 @@ class QueuePoolTest(PoolTestBase):
         c3 = p.connect()
         assert id(c3.connection) != c_id
 
+    def test_recycle_on_invalidate(self):
+        p = self._queuepool_fixture(pool_size=1,
+                           max_overflow=0)
+        c1 = p.connect()
+        c_id = id(c1.connection)
+        c1.close()
+        c2 = p.connect()
+        assert id(c2.connection) == c_id
+
+        p._invalidate(c2)
+        c2.close()
+        time.sleep(.5)
+        c3 = p.connect()
+        assert id(c3.connection) != c_id
+
     def _assert_cleanup_on_pooled_reconnect(self, dbapi, p):
         # p is QueuePool with size=1, max_overflow=2,
         # and one connection in the pool that will need to
@@ -1290,6 +1307,72 @@ class QueuePoolTest(PoolTestBase):
         time.sleep(1)
         self._assert_cleanup_on_pooled_reconnect(dbapi, p)
 
+    def test_recycle_pool_no_race(self):
+        def slow_close():
+            slow_closing_connection._slow_close()
+            time.sleep(.5)
+
+        slow_closing_connection = Mock()
+        slow_closing_connection.connect.return_value.close = slow_close
+
+        class Error(Exception):
+            pass
+
+        dialect = Mock()
+        dialect.is_disconnect = lambda *arg, **kw: True
+        dialect.dbapi.Error = Error
+
+        pools = []
+        class TrackQueuePool(pool.QueuePool):
+            def __init__(self, *arg, **kw):
+                pools.append(self)
+                super(TrackQueuePool, self).__init__(*arg, **kw)
+
+        def creator():
+            return slow_closing_connection.connect()
+        p1 = TrackQueuePool(creator=creator, pool_size=20)
+
+        from sqlalchemy import create_engine
+        eng = create_engine("postgresql://", pool=p1, _initialize=False)
+        eng.dialect = dialect
+
+        # 15 total connections
+        conns = [eng.connect() for i in range(15)]
+
+        # return 7 of them (conns[3:10]) to the pool
+        for conn in conns[3:10]:
+            conn.close()
+
+        def attempt(conn):
+            time.sleep(random.random())
+            try:
+                conn._handle_dbapi_exception(Error(), "statement", {}, Mock(), Mock())
+            except tsa.exc.DBAPIError:
+                pass
+
+        # run an error + invalidate operation on all 15 connections (8 checked out, 7 in the pool)
+        threads = []
+        for conn in conns:
+            t = threading.Thread(target=attempt, args=(conn, ))
+            t.start()
+            threads.append(t)
+
+        for t in threads:
+            t.join()
+
+        # return all 15 connections to the pool
+        for conn in conns:
+            conn.close()
+
+        # re-open 15 total connections
+        conns = [eng.connect() for i in range(15)]
+
+        # 15 connections have been fully closed due to invalidate
+        assert slow_closing_connection._slow_close.call_count == 15
+
+        # 15 initial connections + 15 reconnections
+        assert slow_closing_connection.connect.call_count == 30
+        assert len(pools) <= 2, len(pools)
 
     def test_invalidate(self):
         p = self._queuepool_fixture(pool_size=1, max_overflow=0)
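
A distinction worth noting while reading these tests: ``invalidate()`` on a
checked-out connection discards only that connection's DBAPI connection,
while the private ``Pool._invalidate()`` additionally advances the
generational timestamp so that every connection of the same era is recycled
at its next checkout.  Roughly, for any pool ``p``:

    c = p.connect()

    c.invalidate()      # closes and discards c's own DBAPI connection

    p._invalidate(c)    # additionally: every pooled connection created
                        # before this moment is lazily replaced at its
                        # next checkout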
index ba336a1bfac6359e37fb9859fdca3dca18f28210..a3ad9c548aa4516a96ea24708c365a1f28e06dd7 100644 (file)
@@ -146,16 +146,20 @@ class MockReconnectTest(fixtures.TestBase):
        # close shouldn't break
 
         conn.close()
-        is_not_(self.db.pool, db_pool)
-
-        # ensure all connections closed (pool was recycled)
 
+        # ensure one connection closed...
         eq_(
             [c.close.mock_calls for c in self.dbapi.connections],
-            [[call()], [call()]]
+            [[call()], []]
         )
 
         conn = self.db.connect()
+
+        eq_(
+            [c.close.mock_calls for c in self.dbapi.connections],
+            [[call()], [call()], []]
+        )
+
         conn.execute(select([1]))
         conn.close()
 
@@ -534,8 +538,6 @@ class RealReconnectTest(fixtures.TestBase):
         # invalidate() also doesn't screw up
         assert_raises(exc.DBAPIError, engine.connect)
 
-        # pool was recreated
-        assert engine.pool is not p1
 
     def test_null_pool(self):
         engine = \