]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Fix HTTP client selection as used in curl_httpclient_test.
authorBen Darnell <ben@bendarnell.com>
Tue, 19 Jul 2011 05:34:19 +0000 (22:34 -0700)
committerBen Darnell <ben@bendarnell.com>
Tue, 19 Jul 2011 05:34:19 +0000 (22:34 -0700)
AsyncHTTPClient.configure() was working, but it didn't work to instantiate
the client directly like the unit tests were using.

tornado/httpclient.py
tornado/test/curl_httpclient_test.py
tornado/test/simple_httpclient_test.py

index 56d727317ba5ec7f5cb6c27552b81371d866e7e7..b8f97c6295c35fe78d995b661dfec90a1ab28988 100644 (file)
@@ -111,23 +111,29 @@ class AsyncHTTPClient(object):
     are deprecated.  The implementation subclass as well as arguments to
     its constructor can be set with the static method configure()
     """
-    _async_clients = weakref.WeakKeyDictionary()
     _impl_class = None
     _impl_kwargs = None
 
+    @classmethod
+    def _async_clients(cls):
+        assert cls is not AsyncHTTPClient, "should only be called on subclasses"
+        if not hasattr(cls, '_async_client_dict'):
+            cls._async_client_dict = weakref.WeakKeyDictionary()
+        return cls._async_client_dict
+
     def __new__(cls, io_loop=None, max_clients=10, force_instance=False, 
                 **kwargs):
         io_loop = io_loop or IOLoop.instance()
-        if io_loop in cls._async_clients and not force_instance:
-            return cls._async_clients[io_loop]
+        if cls is AsyncHTTPClient:
+            if cls._impl_class is None:
+                from tornado.simple_httpclient import SimpleAsyncHTTPClient
+                AsyncHTTPClient._impl_class = SimpleAsyncHTTPClient
+            impl = AsyncHTTPClient._impl_class
+        else:
+            impl = cls
+        if io_loop in impl._async_clients() and not force_instance:
+            return impl._async_clients()[io_loop]
         else:
-            if cls is AsyncHTTPClient:
-                if cls._impl_class is None:
-                    from tornado.simple_httpclient import SimpleAsyncHTTPClient
-                    AsyncHTTPClient._impl_class = SimpleAsyncHTTPClient
-                impl = cls._impl_class
-            else:
-                impl = cls
             instance = super(AsyncHTTPClient, cls).__new__(impl)
             args = {}
             if cls._impl_kwargs:
@@ -135,7 +141,7 @@ class AsyncHTTPClient(object):
             args.update(kwargs)
             instance.initialize(io_loop, max_clients, **args)
             if not force_instance:
-                cls._async_clients[io_loop] = instance
+                impl._async_clients()[io_loop] = instance
             return instance
 
     def close(self):
@@ -144,8 +150,8 @@ class AsyncHTTPClient(object):
         create and destroy http clients.  No other methods may be called
         on the AsyncHTTPClient after close().
         """
-        if self._async_clients.get(self.io_loop) is self:
-            del self._async_clients[self.io_loop]
+        if self._async_clients().get(self.io_loop) is self:
+            del self._async_clients()[self.io_loop]
 
     def fetch(self, request, callback, **kwargs):
         """Executes a request, calling callback with an `HTTPResponse`.
index 2fb4e2d871819b3c8343ec0beb98ef31b9d2ee82..afa56f8f752556f2ebe68e21d013f3edace4f558 100644 (file)
@@ -10,7 +10,10 @@ if pycurl is not None:
 
 class CurlHTTPClientCommonTestCase(HTTPClientCommonTestCase):
     def get_http_client(self):
-        return CurlAsyncHTTPClient(io_loop=self.io_loop)
+        client = CurlAsyncHTTPClient(io_loop=self.io_loop)
+        # make sure AsyncHTTPClient magic doesn't give us the wrong class
+        self.assertTrue(isinstance(client, CurlAsyncHTTPClient))
+        return client
 
 # Remove the base class from our namespace so the unittest module doesn't
 # try to run it again.
index b7e3ffcdafa192d5c9b2eaf89e510f94f5eabadf..bf25b1b41fd614f48d4cf73e11ba55ddbcb3685b 100644 (file)
@@ -9,8 +9,10 @@ from tornado.web import RequestHandler, Application, asynchronous, url
 
 class SimpleHTTPClientCommonTestCase(HTTPClientCommonTestCase):
     def get_http_client(self):
-        return SimpleAsyncHTTPClient(io_loop=self.io_loop,
-                                     force_instance=True)
+        client = SimpleAsyncHTTPClient(io_loop=self.io_loop,
+                                       force_instance=True)
+        self.assertTrue(isinstance(client, SimpleAsyncHTTPClient))
+        return client
 
 # Remove the base class from our namespace so the unittest module doesn't
 # try to run it again.