]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Cancel timeouts if Queue.get or put are resolved first. 1380/head
authorA. Jesse Jiryu Davis <jesse@mongodb.com>
Mon, 16 Mar 2015 01:09:12 +0000 (21:09 -0400)
committerA. Jesse Jiryu Davis <jesse@mongodb.com>
Mon, 16 Mar 2015 01:09:12 +0000 (21:09 -0400)
tornado/queues.py
tornado/test/queues_test.py

index 1e4a31aba112a7dc8aab217be907e0ed34915099..5986ccc96891d2236d64c81d8afa3ba40856435c 100644 (file)
@@ -33,6 +33,16 @@ class QueueFull(Exception):
     pass
 
 
+def _set_timeout(future, timeout):
+    if timeout:
+        def on_timeout():
+            future.set_exception(gen.TimeoutError())
+        io_loop = ioloop.IOLoop.current()
+        timeout_handle = io_loop.add_timeout(timeout, on_timeout)
+        future.add_done_callback(
+            lambda _: io_loop.remove_timeout(timeout_handle))
+
+
 class Queue(object):
     """Coordinate producer and consumer coroutines.
 
@@ -82,10 +92,7 @@ class Queue(object):
         except QueueFull:
             future = Future()
             self._putters.append((item, future))
-            if timeout:
-                def on_timeout():
-                    future.set_exception(gen.TimeoutError())
-                ioloop.IOLoop.current().add_timeout(timeout, on_timeout)
+            _set_timeout(future, timeout)
             return future
         else:
             return gen._null_future
@@ -117,10 +124,7 @@ class Queue(object):
             future.set_result(self.get_nowait())
         except QueueEmpty:
             self._getters.append(future)
-            if timeout:
-                def on_timeout():
-                    future.set_exception(gen.TimeoutError())
-                ioloop.IOLoop.current().add_timeout(timeout, on_timeout)
+            _set_timeout(future, timeout)
         return future
 
     def get_nowait(self):
index 34b611315c9431a51840a0c1419769a5c834823f..ac2118332d8b93add8fe33bd8049f6e377f5f892 100644 (file)
@@ -116,6 +116,14 @@ class QueueGetTest(AsyncTestCase):
         q.put_nowait(0)
         self.assertEqual(0, (yield get))
 
+    @gen_test
+    def test_get_timeout_preempted(self):
+        q = queues.Queue()
+        get = q.get(timeout=timedelta(seconds=0.01))
+        q.put(0)
+        yield gen.sleep(0.02)
+        self.assertEqual(0, (yield get))
+
     @gen_test
     def test_get_clears_timed_out_putters(self):
         q = queues.Queue(1)
@@ -208,6 +216,15 @@ class QueuePutTest(AsyncTestCase):
         # Final get() unblocked this putter.
         yield put
 
+    @gen_test
+    def test_put_timeout_preempted(self):
+        q = queues.Queue(1)
+        q.put_nowait(0)
+        put = q.put(1, timeout=timedelta(seconds=0.01))
+        q.get()
+        yield gen.sleep(0.02)
+        yield put  # No TimeoutError.
+
     @gen_test
     def test_put_clears_timed_out_putters(self):
         q = queues.Queue(1)