]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Add a missing stack_context.wrap in SSLIOStream.connect.
authorBen Darnell <ben@bendarnell.com>
Thu, 16 May 2013 01:44:34 +0000 (21:44 -0400)
committerBen Darnell <ben@bendarnell.com>
Thu, 16 May 2013 01:44:34 +0000 (21:44 -0400)
Run some of simple_httpclient test in both HTTP and HTTPS modes, which
would have detected this bug.

Closes #787.

tornado/iostream.py
tornado/test/simple_httpclient_test.py

index 60126e81b71e06775e5c3ef3216108e1d6c887f6..425d3299db1a791168d70e11db33e14cff88eef1 100644 (file)
@@ -872,7 +872,7 @@ class SSLIOStream(IOStream):
     def connect(self, address, callback=None, server_hostname=None):
         # Save the user's callback and run it after the ssl handshake
         # has completed.
-        self._ssl_connect_callback = callback
+        self._ssl_connect_callback = stack_context.wrap(callback)
         self._server_hostname = server_hostname
         super(SSLIOStream, self).connect(address, callback=None)
 
index 8f028e2401a8988dceab391b00e4f26b1e76170e..5a0d9b1bd0420b0f0590817e5908d42bb1e71ed5 100644 (file)
@@ -17,7 +17,7 @@ from tornado.log import gen_log
 from tornado.simple_httpclient import SimpleAsyncHTTPClient, _DEFAULT_CA_CERTS
 from tornado.test.httpclient_test import ChunkHandler, CountdownHandler, HelloWorldHandler
 from tornado.test import httpclient_test
-from tornado.testing import AsyncHTTPTestCase, AsyncTestCase, bind_unused_port, ExpectLog
+from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, bind_unused_port, ExpectLog
 from tornado.test.util import unittest, skipOnTravis
 from tornado.web import RequestHandler, Application, asynchronous, url
 
@@ -93,11 +93,7 @@ class HostEchoHandler(RequestHandler):
         self.write(self.request.headers["Host"])
 
 
-class SimpleHTTPClientTestCase(AsyncHTTPTestCase):
-    def setUp(self):
-        super(SimpleHTTPClientTestCase, self).setUp()
-        self.http_client = SimpleAsyncHTTPClient(self.io_loop)
-
+class SimpleHTTPClientTestMixin(object):
     def get_app(self):
         # callable objects to finish pending /trigger requests
         self.triggers = collections.deque()
@@ -131,8 +127,7 @@ class SimpleHTTPClientTestCase(AsyncHTTPTestCase):
                         SimpleAsyncHTTPClient(io_loop2))
 
     def test_connection_limit(self):
-        client = SimpleAsyncHTTPClient(self.io_loop, max_clients=2,
-                                       force_instance=True)
+        client = self.create_client(max_clients=2)
         self.assertEqual(client.max_clients, 2)
         seen = []
         # Send 4 requests.  Two can be sent immediately, while the others
@@ -160,8 +155,7 @@ class SimpleHTTPClientTestCase(AsyncHTTPTestCase):
 
     def test_redirect_connection_limit(self):
         # following redirects should not consume additional connections
-        client = SimpleAsyncHTTPClient(self.io_loop, max_clients=1,
-                                       force_instance=True)
+        client = self.create_client(max_clients=1)
         client.fetch(self.get_url('/countdown/3'), self.stop,
                      max_redirects=3)
         response = self.wait()
@@ -307,6 +301,27 @@ class SimpleHTTPClientTestCase(AsyncHTTPTestCase):
                             response.error)
 
 
+class SimpleHTTPClientTestCase(SimpleHTTPClientTestMixin, AsyncHTTPTestCase):
+    def setUp(self):
+        super(SimpleHTTPClientTestCase, self).setUp()
+        self.http_client = self.create_client()
+
+    def create_client(self, **kwargs):
+        return SimpleAsyncHTTPClient(self.io_loop, force_instance=True,
+                                     **kwargs)
+
+
+class SimpleHTTPSClientTestCase(SimpleHTTPClientTestMixin, AsyncHTTPSTestCase):
+    def setUp(self):
+        super(SimpleHTTPSClientTestCase, self).setUp()
+        self.http_client = self.create_client()
+
+    def create_client(self, **kwargs):
+        return SimpleAsyncHTTPClient(self.io_loop, force_instance=True,
+                                     defaults=dict(validate_cert=False),
+                                     **kwargs)
+
+
 class CreateAsyncHTTPClientTestCase(AsyncTestCase):
     def setUp(self):
         super(CreateAsyncHTTPClientTestCase, self).setUp()