]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Fix bug when max_clients kwarg is passed to AsyncHTTPClient.configure.
authorBen Darnell <ben@bendarnell.com>
Mon, 7 May 2012 01:10:13 +0000 (18:10 -0700)
committerBen Darnell <ben@bendarnell.com>
Mon, 7 May 2012 01:26:54 +0000 (18:26 -0700)
Closes #493.

tornado/httpclient.py
tornado/test/simple_httpclient_test.py

index 89f0057ad99ee664c0f4e98a8ee99621a217e15b..0fcc943f9d0712f7c9bd9d4585f867aabb1ec6c8 100644 (file)
@@ -124,6 +124,8 @@ class AsyncHTTPClient(object):
     _impl_class = None
     _impl_kwargs = None
 
+    _DEFAULT_MAX_CLIENTS = 10
+
     @classmethod
     def _async_clients(cls):
         assert cls is not AsyncHTTPClient, "should only be called on subclasses"
@@ -131,7 +133,7 @@ class AsyncHTTPClient(object):
             cls._async_client_dict = weakref.WeakKeyDictionary()
         return cls._async_client_dict
 
-    def __new__(cls, io_loop=None, max_clients=10, force_instance=False,
+    def __new__(cls, io_loop=None, max_clients=None, force_instance=False,
                 **kwargs):
         io_loop = io_loop or IOLoop.instance()
         if cls is AsyncHTTPClient:
@@ -149,7 +151,13 @@ class AsyncHTTPClient(object):
             if cls._impl_kwargs:
                 args.update(cls._impl_kwargs)
             args.update(kwargs)
-            instance.initialize(io_loop, max_clients, **args)
+            if max_clients is not None:
+                # max_clients is special because it may be passed
+                # positionally instead of by keyword
+                args["max_clients"] = max_clients
+            elif "max_clients" not in args:
+                args["max_clients"] = AsyncHTTPClient._DEFAULT_MAX_CLIENTS
+            instance.initialize(io_loop, **args)
             if not force_instance:
                 impl._async_clients()[io_loop] = instance
             return instance
@@ -204,6 +212,15 @@ class AsyncHTTPClient(object):
         AsyncHTTPClient._impl_class = impl
         AsyncHTTPClient._impl_kwargs = kwargs
 
+    @staticmethod
+    def _save_configuration():
+        return (AsyncHTTPClient._impl_class, AsyncHTTPClient._impl_kwargs)
+
+    @staticmethod
+    def _restore_configuration(saved):
+        AsyncHTTPClient._impl_class = saved[0]
+        AsyncHTTPClient._impl_kwargs = saved[1]
+
 
 class HTTPRequest(object):
     """HTTP client request object."""
index bbfd57b1e991687bf7ff8cc872c6e1470aa8f8a8..4a48eb0eebb8c91c4e944824273910e3db63da93 100644 (file)
@@ -1,16 +1,18 @@
 from __future__ import absolute_import, division, with_statement
 
 import collections
+from contextlib import closing
 import gzip
 import logging
 import re
 import socket
 
+from tornado.httpclient import AsyncHTTPClient
 from tornado.httputil import HTTPHeaders
 from tornado.ioloop import IOLoop
 from tornado.simple_httpclient import SimpleAsyncHTTPClient, _DEFAULT_CA_CERTS
 from tornado.test.httpclient_test import HTTPClientCommonTestCase, ChunkHandler, CountdownHandler, HelloWorldHandler
-from tornado.testing import AsyncHTTPTestCase, LogTrapTestCase
+from tornado.testing import AsyncHTTPTestCase, AsyncTestCase, LogTrapTestCase
 from tornado.util import b
 from tornado.web import RequestHandler, Application, asynchronous, url
 
@@ -263,3 +265,40 @@ class SimpleHTTPClientTestCase(AsyncHTTPTestCase, LogTrapTestCase):
         self.http_client.fetch(url, self.stop)
         response = self.wait()
         self.assertTrue(host_re.match(response.body), response.body)
+
+
+class CreateAsyncHTTPClientTestCase(AsyncTestCase, LogTrapTestCase):
+    def setUp(self):
+        super(CreateAsyncHTTPClientTestCase, self).setUp()
+        self.saved = AsyncHTTPClient._save_configuration()
+
+    def tearDown(self):
+        AsyncHTTPClient._restore_configuration(self.saved)
+        super(CreateAsyncHTTPClientTestCase, self).tearDown()
+
+    def test_max_clients(self):
+        # The max_clients argument is tricky because it was originally
+        # allowed to be passed positionally; newer arguments are keyword-only.
+        AsyncHTTPClient.configure(SimpleAsyncHTTPClient)
+        with closing(AsyncHTTPClient(
+                self.io_loop, force_instance=True)) as client:
+            self.assertEqual(client.max_clients, 10)
+        with closing(AsyncHTTPClient(
+                self.io_loop, 11, force_instance=True)) as client:
+            self.assertEqual(client.max_clients, 11)
+        with closing(AsyncHTTPClient(
+                self.io_loop, max_clients=11, force_instance=True)) as client:
+            self.assertEqual(client.max_clients, 11)
+
+        # Now configure max_clients statically and try overriding it
+        # with each way max_clients can be passed
+        AsyncHTTPClient.configure(SimpleAsyncHTTPClient, max_clients=12)
+        with closing(AsyncHTTPClient(
+                self.io_loop, force_instance=True)) as client:
+            self.assertEqual(client.max_clients, 12)
+        with closing(AsyncHTTPClient(
+                self.io_loop, max_clients=13, force_instance=True)) as client:
+            self.assertEqual(client.max_clients, 13)
+        with closing(AsyncHTTPClient(
+                self.io_loop, max_clients=14, force_instance=True)) as client:
+            self.assertEqual(client.max_clients, 14)