git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Clean up timed-out waiters in Semaphore. 1366/head
author A. Jesse Jiryu Davis <jesse@mongodb.com>
Tue, 3 Mar 2015 17:14:57 +0000 (12:14 -0500)
committer A. Jesse Jiryu Davis <jesse@mongodb.com>
Tue, 3 Mar 2015 17:31:15 +0000 (12:31 -0500)
tornado/locks.py
tornado/test/locks_test.py

index ea2e89d9b0deff1729b0d871f1a5a5a60b3fac7f..4d2ab9b38331408464fac2e94ff1f3edce3a73d5 100644 (file)
@@ -22,7 +22,29 @@ from tornado import gen, ioloop
 from tornado.concurrent import Future
 
 
-class Condition(object):
+class _TimeoutGarbageCollector(object):
+    """Base class for objects that periodically clean up timed-out waiters.
+
+    Avoids memory leak in a common pattern like:
+
+        while True:
+            yield condition.wait(short_timeout)
+            print('looping....')
+    """
+    def __init__(self):
+        self._waiters = collections.deque()  # Futures.
+        self._timeouts = 0
+
+    def _garbage_collect(self):
+        # Occasionally clear timed-out waiters.
+        self._timeouts += 1
+        if self._timeouts > 100:
+            self._timeouts = 0
+            self._waiters = collections.deque(
+                w for w in self._waiters if not w.done())
+
+
+class Condition(_TimeoutGarbageCollector):
     """A condition allows one or more coroutines to wait until notified.
 
     Like a standard `threading.Condition`, but does not need an underlying lock
@@ -30,9 +52,8 @@ class Condition(object):
     """
 
     def __init__(self):
+        super(Condition, self).__init__()
         self.io_loop = ioloop.IOLoop.current()
-        self._waiters = collections.deque()  # Futures.
-        self._timeouts = 0
 
     def __repr__(self):
         result = '<%s' % (self.__class__.__name__, )
@@ -71,15 +92,6 @@ class Condition(object):
         """Wake all waiters."""
         self.notify(len(self._waiters))
 
-    def _garbage_collect(self):
-        # Occasionally clear timed-out waiters, if many coroutines wait with a
-        # timeout but notify is called rarely.
-        self._timeouts += 1
-        if self._timeouts > 100:
-            self._timeouts = 0
-            self._waiters = collections.deque(
-                w for w in self._waiters if not w.done())
-
 
 class Event(object):
     """An event blocks coroutines until its internal flag is set to True.
@@ -143,7 +155,7 @@ class _ReleasingContextManager(object):
         self._obj.release()
 
 
-class Semaphore(object):
+class Semaphore(_TimeoutGarbageCollector):
     """A lock that can be acquired a fixed number of times before blocking.
 
     A Semaphore manages a counter representing the number of `.release` calls
@@ -164,11 +176,11 @@ class Semaphore(object):
     ...    # Now the semaphore is released.
     """
     def __init__(self, value=1):
+        super(Semaphore, self).__init__()
         if value < 0:
             raise ValueError('semaphore initial value must be >= 0')
 
         self._value = value
-        self._waiters = collections.deque()
 
     def __repr__(self):
         res = super(Semaphore, self).__repr__()
@@ -201,22 +213,18 @@ class Semaphore(object):
         Block if the counter is zero and wait for a `.release`. The Future
         raises `.TimeoutError` after the deadline.
         """
+        waiter = Future()
         if self._value > 0:
             self._value -= 1
-            future = Future()
-            future.set_result(_ReleasingContextManager(self))
+            waiter.set_result(_ReleasingContextManager(self))
         else:
-            waiter = Future()
             self._waiters.append(waiter)
             if timeout:
-                future = gen.with_timeout(timeout, waiter,
-                                          quiet_exceptions=gen.TimeoutError)
-
-                # Set waiter's exception after the deadline.
-                gen.chain_future(future, waiter)
-            else:
-                future = waiter
-        return future
+                def on_timeout():
+                    waiter.set_exception(gen.TimeoutError())
+                    self._garbage_collect()
+                ioloop.IOLoop.current().add_timeout(timeout, on_timeout)
+        return waiter
 
     def __enter__(self):
         raise RuntimeError(
index aabb8040349c58b2debf89ce4e41caa558f435d2..8eaa4236f4a84c86d246cb8bfdd8502a3211b218 100644 (file)
@@ -274,6 +274,28 @@ class SemaphoreTest(AsyncTestCase):
         self.assertTrue(sem.acquire().done())
         self.assertFalse(sem.acquire().done())
 
+    @gen_test
+    def test_garbage_collection(self):
+        # Test that timed-out waiters are occasionally cleaned from the queue.
+        sem = locks.Semaphore(value=0)
+        futures = [sem.acquire(timedelta(seconds=0.01)) for _ in range(101)]
+
+        future = sem.acquire()
+        self.assertEqual(102, len(sem._waiters))
+
+        # Let first 101 waiters time out, triggering a collection.
+        yield gen.sleep(0.02)
+        self.assertEqual(1, len(sem._waiters))
+
+        # Final waiter is still active.
+        self.assertFalse(future.done())
+        sem.release()
+        self.assertTrue(future.done())
+
+        # Prevent "Future exception was never retrieved" messages.
+        for future in futures:
+            self.assertRaises(TimeoutError, future.result)
+
 
 class SemaphoreContextManagerTest(AsyncTestCase):
     @gen_test