]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Add a close method to Resolver and use it where necessary.
authorBen Darnell <ben@bendarnell.com>
Thu, 16 May 2013 02:47:35 +0000 (22:47 -0400)
committerBen Darnell <ben@bendarnell.com>
Thu, 16 May 2013 02:47:35 +0000 (22:47 -0400)
This fixes a thread leak when running the test suite with
ThreadedResolver.

tornado/concurrent.py
tornado/netutil.py
tornado/simple_httpclient.py
tornado/test/httpclient_test.py
tornado/test/httpserver_test.py
tornado/test/simple_httpclient_test.py
tornado/websocket.py

index 15a039ca1a8d4ec7ae9a98eae075858e4a93f935..8a4f22878cfe2cd2a38dec2758a1a17264156c9b 100644 (file)
@@ -140,6 +140,9 @@ class DummyExecutor(object):
             future.set_exc_info(sys.exc_info())
         return future
 
+    def shutdown(self, wait=True):
+        pass
+
 dummy_executor = DummyExecutor()
 
 
index c1ba01b2bd11a7d6d84b20608a677410c7cb42e0..098f8bf8b209bb759070e303b1f00962d30450d1 100644 (file)
@@ -207,12 +207,20 @@ class Resolver(Configurable):
         """
         raise NotImplementedError()
 
+    def close(self):
+        """Closes the `Resolver`, freeing any resources used."""
+        pass
+
 
 class ExecutorResolver(Resolver):
     def initialize(self, io_loop=None, executor=None):
         self.io_loop = io_loop or IOLoop.current()
         self.executor = executor or dummy_executor
 
+    def close(self):
+        self.executor.shutdown()
+        self.executor = None
+
     @run_on_executor
     def resolve(self, host, port, family=socket.AF_UNSPEC):
         # On Solaris, getaddrinfo fails if the given port is not found
@@ -267,6 +275,9 @@ class OverrideResolver(Resolver):
         self.resolver = resolver
         self.mapping = mapping
 
+    def close(self):
+        self.resolver.close()
+
     def resolve(self, host, port, *args, **kwargs):
         if (host, port) in self.mapping:
             host, port = self.mapping[(host, port)]
index 31b5a73d83996abb37cf1de3eb95ee4334170452..16874e0366efd32beaa2adc00948a28d7fd694a2 100644 (file)
@@ -73,11 +73,21 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient):
         self.queue = collections.deque()
         self.active = {}
         self.max_buffer_size = max_buffer_size
-        self.resolver = resolver or Resolver(io_loop=io_loop)
+        if resolver:
+            self.resolver = resolver
+            self.own_resolver = False
+        else:
+            self.resolver = Resolver(io_loop=io_loop)
+            self.own_resolver = True
         if hostname_mapping is not None:
             self.resolver = OverrideResolver(resolver=self.resolver,
                                              mapping=hostname_mapping)
 
+    def close(self):
+        super(SimpleAsyncHTTPClient, self).close()
+        if self.own_resolver:
+            self.resolver.close()
+
     def fetch_impl(self, request, callback):
         self.queue.append((request, callback))
         self._process_queue()
index 2ce93c646d511e55a5d7c46c92737a5de2497161..62645e737a0a409c7d03edeff57f232ef9ec01f1 100644 (file)
@@ -431,6 +431,7 @@ class SyncHTTPClientTest(unittest.TestCase):
     def tearDown(self):
         self.server_ioloop.add_callback(self.server_ioloop.stop)
         self.server_thread.join()
+        self.http_client.close()
         self.server_ioloop.close(all_fds=True)
 
     def get_url(self, path):
index 6f53c3af62d53f617202fec0923e9991bfe6431a..a4079812ef81826090d9809ca45f222c68d994e4 100644 (file)
@@ -14,6 +14,7 @@ from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase
 from tornado.test.util import unittest
 from tornado.util import u, bytes_type
 from tornado.web import Application, RequestHandler, asynchronous
+from contextlib import closing
 import datetime
 import os
 import shutil
@@ -183,22 +184,23 @@ class HTTPConnectionTest(AsyncHTTPTestCase):
         return Application(self.get_handlers())
 
     def raw_fetch(self, headers, body):
-        client = SimpleAsyncHTTPClient(self.io_loop)
-        conn = RawRequestHTTPConnection(
-            self.io_loop, client,
-            httpclient._RequestProxy(
-                httpclient.HTTPRequest(self.get_url("/")),
-                dict(httpclient.HTTPRequest._DEFAULTS)),
-            None, self.stop,
-            1024 * 1024, Resolver(io_loop=self.io_loop))
-        conn.set_request(
-            b"\r\n".join(headers +
-                         [utf8("Content-Length: %d\r\n" % len(body))]) +
-            b"\r\n" + body)
-        response = self.wait()
-        client.close()
-        response.rethrow()
-        return response
+        with closing(Resolver(io_loop=self.io_loop)) as resolver:
+            with closing(SimpleAsyncHTTPClient(self.io_loop,
+                                               resolver=resolver)) as client:
+                conn = RawRequestHTTPConnection(
+                    self.io_loop, client,
+                    httpclient._RequestProxy(
+                        httpclient.HTTPRequest(self.get_url("/")),
+                        dict(httpclient.HTTPRequest._DEFAULTS)),
+                    None, self.stop,
+                    1024 * 1024, resolver)
+                conn.set_request(
+                    b"\r\n".join(headers +
+                                 [utf8("Content-Length: %d\r\n" % len(body))]) +
+                    b"\r\n" + body)
+                response = self.wait()
+                response.rethrow()
+                return response
 
     def test_multipart_form(self):
         # Encodings here are tricky:  Headers are latin1, bodies can be
index 5a0d9b1bd0420b0f0590817e5908d42bb1e71ed5..a2e57fb45ea6f41a82e97e4e0520dbc546d56c00 100644 (file)
@@ -127,39 +127,39 @@ class SimpleHTTPClientTestMixin(object):
                         SimpleAsyncHTTPClient(io_loop2))
 
     def test_connection_limit(self):
-        client = self.create_client(max_clients=2)
-        self.assertEqual(client.max_clients, 2)
-        seen = []
-        # Send 4 requests.  Two can be sent immediately, while the others
-        # will be queued
-        for i in range(4):
-            client.fetch(self.get_url("/trigger"),
-                         lambda response, i=i: (seen.append(i), self.stop()))
-        self.wait(condition=lambda: len(self.triggers) == 2)
-        self.assertEqual(len(client.queue), 2)
-
-        # Finish the first two requests and let the next two through
-        self.triggers.popleft()()
-        self.triggers.popleft()()
-        self.wait(condition=lambda: (len(self.triggers) == 2 and
-                                     len(seen) == 2))
-        self.assertEqual(set(seen), set([0, 1]))
-        self.assertEqual(len(client.queue), 0)
-
-        # Finish all the pending requests
-        self.triggers.popleft()()
-        self.triggers.popleft()()
-        self.wait(condition=lambda: len(seen) == 4)
-        self.assertEqual(set(seen), set([0, 1, 2, 3]))
-        self.assertEqual(len(self.triggers), 0)
+        with closing(self.create_client(max_clients=2)) as client:
+            self.assertEqual(client.max_clients, 2)
+            seen = []
+            # Send 4 requests.  Two can be sent immediately, while the others
+            # will be queued
+            for i in range(4):
+                client.fetch(self.get_url("/trigger"),
+                             lambda response, i=i: (seen.append(i), self.stop()))
+            self.wait(condition=lambda: len(self.triggers) == 2)
+            self.assertEqual(len(client.queue), 2)
+
+            # Finish the first two requests and let the next two through
+            self.triggers.popleft()()
+            self.triggers.popleft()()
+            self.wait(condition=lambda: (len(self.triggers) == 2 and
+                                         len(seen) == 2))
+            self.assertEqual(set(seen), set([0, 1]))
+            self.assertEqual(len(client.queue), 0)
+
+            # Finish all the pending requests
+            self.triggers.popleft()()
+            self.triggers.popleft()()
+            self.wait(condition=lambda: len(seen) == 4)
+            self.assertEqual(set(seen), set([0, 1, 2, 3]))
+            self.assertEqual(len(self.triggers), 0)
 
     def test_redirect_connection_limit(self):
         # following redirects should not consume additional connections
-        client = self.create_client(max_clients=1)
-        client.fetch(self.get_url('/countdown/3'), self.stop,
-                     max_redirects=3)
-        response = self.wait()
-        response.rethrow()
+        with closing(self.create_client(max_clients=1)) as client:
+            client.fetch(self.get_url('/countdown/3'), self.stop,
+                         max_redirects=3)
+            response = self.wait()
+            response.rethrow()
 
     def test_default_certificates_exist(self):
         open(_DEFAULT_CA_CERTS).close()
index 4a62882d45ddbf9f832488e58264ba21b963f75f..54f73ecfcd604c62bbacf892e82da21668df108c 100644 (file)
@@ -758,12 +758,14 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
             'Sec-WebSocket-Version': '13',
         })
 
+        self.resolver = Resolver(io_loop=io_loop)
         super(WebSocketClientConnection, self).__init__(
             io_loop, None, request, lambda: None, self._on_http_response,
-            104857600, Resolver(io_loop=io_loop))
+            104857600, self.resolver)
 
     def _on_close(self):
         self.on_message(None)
+        self.resolver.close()
 
     def _on_http_response(self, response):
         if not self.connect_future.done():