]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Implement max_clients limitation and queueing for SimpleAsyncHTTPClient
authorBen Darnell <ben@bendarnell.com>
Tue, 16 Nov 2010 00:53:44 +0000 (16:53 -0800)
committerBen Darnell <ben@bendarnell.com>
Tue, 16 Nov 2010 00:53:44 +0000 (16:53 -0800)
tornado/simple_httpclient.py
tornado/test/simple_httpclient_test.py

index 07d655afb3e2cd378bec625d535e74ba969cb2c9..c1b340627d16e8fa67a74f62356a996bae691d72 100644 (file)
@@ -8,6 +8,7 @@ from tornado.ioloop import IOLoop
 from tornado.iostream import IOStream, SSLIOStream
 from tornado import stack_context
 
+import collections
 import contextlib
 import errno
 import functools
@@ -47,11 +48,13 @@ class SimpleAsyncHTTPClient(object):
     _ASYNC_CLIENTS = weakref.WeakKeyDictionary()
 
     def __new__(cls, io_loop=None, max_clients=10,
-                max_simultaneous_connections=None):
+                max_simultaneous_connections=None,
+                force_instance=False):
         """Creates a SimpleAsyncHTTPClient.
 
         Only a single SimpleAsyncHTTPClient instance exists per IOLoop
         in order to provide limitations on the number of pending connections.
+        force_instance=True may be used to suppress this behavior.
 
         max_clients is the number of concurrent requests that can be in
         progress.  max_simultaneous_connections has no effect and is accepted
@@ -60,13 +63,16 @@ class SimpleAsyncHTTPClient(object):
         and will be ignored when an existing client is reused.
         """
         io_loop = io_loop or IOLoop.instance()
-        if io_loop in cls._ASYNC_CLIENTS:
+        if io_loop in cls._ASYNC_CLIENTS and not force_instance:
             return cls._ASYNC_CLIENTS[io_loop]
         else:
             instance = super(SimpleAsyncHTTPClient, cls).__new__(cls)
             instance.io_loop = io_loop
             instance.max_clients = max_clients
-            cls._ASYNC_CLIENTS[io_loop] = instance
+            instance.queue = collections.deque()
+            instance.active = {}
+            if not force_instance:
+                cls._ASYNC_CLIENTS[io_loop] = instance
             return instance
 
     def close(self):
@@ -78,7 +84,23 @@ class SimpleAsyncHTTPClient(object):
         if not isinstance(request.headers, HTTPHeaders):
             request.headers = HTTPHeaders(request.headers)
         callback = stack_context.wrap(callback)
-        _HTTPConnection(self.io_loop, request, callback)
+        self.queue.append((request, callback))
+        self._process_queue()
+
+    def _process_queue(self):
+        with stack_context.NullContext():
+            while self.queue and len(self.active) < self.max_clients:
+                request, callback = self.queue.popleft()
+                key = object()
+                self.active[key] = (request, callback)
+                _HTTPConnection(self.io_loop, request,
+                                functools.partial(self._on_fetch_complete,
+                                                  key, callback))
+
+    def _on_fetch_complete(self, key, callback, response):
+        del self.active[key]
+        callback(response)
+        self._process_queue()
 
 
 
index 24b67b024f0875475f7fab6f561d512fd7eb1fdf..e8c35277a0271bf40a4fa0acd74ecaa99acae6af 100644 (file)
@@ -1,5 +1,6 @@
 #!/usr/bin/env python
 
+import collections
 import gzip
 import logging
 import socket
@@ -36,14 +37,29 @@ class HangHandler(RequestHandler):
     def get(self):
         pass
 
+class TriggerHandler(RequestHandler):
+    def initialize(self, queue, wake_callback):
+        self.queue = queue
+        self.wake_callback = wake_callback
+
+    @asynchronous
+    def get(self):
+        logging.info("queuing trigger")
+        self.queue.append(self.finish)
+        self.wake_callback()
+
 class SimpleHTTPClientTestCase(AsyncHTTPTestCase, LogTrapTestCase):
     def get_app(self):
+        # callable objects to finish pending /trigger requests
+        self.triggers = collections.deque()
         return Application([
             ("/hello", HelloWorldHandler),
             ("/post", PostHandler),
             ("/chunk", ChunkHandler),
             ("/auth", AuthHandler),
             ("/hang", HangHandler),
+            ("/trigger", TriggerHandler, dict(queue=self.triggers,
+                                              wake_callback=self.stop)),
             ], gzip=True)
 
     def setUp(self):
@@ -128,7 +144,33 @@ class SimpleHTTPClientTestCase(AsyncHTTPTestCase, LogTrapTestCase):
         # Class "constructor" reuses objects on the same IOLoop
         self.assertTrue(SimpleAsyncHTTPClient(self.io_loop) is
                         SimpleAsyncHTTPClient(self.io_loop))
+        # unless force_instance is used
+        self.assertTrue(SimpleAsyncHTTPClient(self.io_loop) is not
+                        SimpleAsyncHTTPClient(self.io_loop,
+                                              force_instance=True))
         # different IOLoops use different objects
         io_loop2 = IOLoop()
         self.assertTrue(SimpleAsyncHTTPClient(self.io_loop) is not
                         SimpleAsyncHTTPClient(io_loop2))
+
+    def test_connection_limit(self):
+        client = SimpleAsyncHTTPClient(self.io_loop, max_clients=2,
+                                       force_instance=True)
+        self.assertEqual(client.max_clients, 2)
+        seen = []
+        # Send 4 requests.  Two can be sent immediately, while the others
+        # will be queued
+        for i in range(4):
+            client.fetch(self.get_url("/trigger"),
+                         lambda response, i=i: (seen.append(i), self.stop()))
+        self.wait(condition=lambda: len(self.triggers) == 2)
+        self.assertEqual(len(client.queue), 2)
+
+        # Finish the first two requests and let the next two through
+        self.triggers.popleft()()
+        self.triggers.popleft()()
+        self.wait(condition=lambda: (len(self.triggers) == 2 and
+                                     len(seen) == 2))
+        self.assertEqual(seen, [0, 1])
+        self.assertEqual(len(client.queue), 0)
+