]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Allow calls to SSLIOStream.write while the connection is in progress.
authorBen Darnell <ben@bendarnell.com>
Sun, 19 Aug 2012 19:16:15 +0000 (12:16 -0700)
committerBen Darnell <ben@bendarnell.com>
Sun, 19 Aug 2012 19:16:15 +0000 (12:16 -0700)
Skip fast-path writes while connecting, and rework the interaction
between base class and subclass to avoid the possibility of doubly-wrapped
sockets.

Closes #587.

tornado/iostream.py
tornado/test/iostream_test.py

index 896cf31e0199b98b4a72eebd2abceaeca48832b3..e73df0deeeb9adbe82e2e5f82ffbda366876a3ae 100644 (file)
@@ -207,10 +207,11 @@ class IOStream(object):
             else:
                 self._write_buffer.append(data)
         self._write_callback = stack_context.wrap(callback)
-        self._handle_write()
-        if self._write_buffer:
-            self._add_io_state(self.io_loop.WRITE)
-        self._maybe_add_error_listener()
+        if not self._connecting:
+            self._handle_write()
+            if self._write_buffer:
+                self._add_io_state(self.io_loop.WRITE)
+            self._maybe_add_error_listener()
 
     def set_close_callback(self, callback):
         """Call the given callback when the stream is closed."""
@@ -626,6 +627,7 @@ class SSLIOStream(IOStream):
         self._ssl_accepting = True
         self._handshake_reading = False
         self._handshake_writing = False
+        self._ssl_connect_callback = None
 
     def reading(self):
         return self._handshake_reading or super(SSLIOStream, self).reading()
@@ -663,7 +665,11 @@ class SSLIOStream(IOStream):
                 return self.close()
         else:
             self._ssl_accepting = False
-            super(SSLIOStream, self)._handle_connect()
+            if self._ssl_connect_callback is not None:
+                callback = self._ssl_connect_callback
+                self._ssl_connect_callback = None
+                self._run_callback(callback)
+
 
     def _handle_read(self):
         if self._ssl_accepting:
@@ -677,14 +683,23 @@ class SSLIOStream(IOStream):
             return
         super(SSLIOStream, self)._handle_write()
 
+    def connect(self, address, callback=None):
+        # Save the user's callback and run it after the ssl handshake
+        # has completed.
+        self._ssl_connect_callback = callback
+        super(SSLIOStream, self).connect(address, callback=None)
+
     def _handle_connect(self):
+        # When the connection is complete, wrap the socket for SSL
+        # traffic.  Note that we do this by overriding _handle_connect
+        # instead of by passing a callback to super().connect because
+        # user callbacks are enqueued asynchronously on the IOLoop,
+        # but since _handle_events calls _handle_connect immediately
+        # followed by _handle_write we need this to be synchronous.
         self.socket = ssl.wrap_socket(self.socket,
                                       do_handshake_on_connect=False,
                                       **self._ssl_options)
-        # Don't call the superclass's _handle_connect (which is responsible
-        # for telling the application that the connection is complete)
-        # until we've completed the SSL handshake (so certificates are
-        # available, etc).
+        super(SSLIOStream, self)._handle_connect()
 
     def _read_from_socket(self):
         if self._ssl_accepting:
index 0f267b854f9ce26b773d50ea32ae6cde741a1489..e58d0f77ecaf35ee47d134691f6378d99aae9011 100644 (file)
@@ -6,6 +6,7 @@ from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase
 from tornado.util import b
 from tornado.web import RequestHandler, Application
 import errno
+import logging
 import os
 import platform
 import socket
@@ -75,6 +76,36 @@ class TestIOStreamWebMixin(object):
 
         self.stream.close()
 
+    def test_write_while_connecting(self):
+        stream = self._make_client_iostream()
+        connected = [False]
+        def connected_callback():
+            connected[0] = True
+            self.stop()
+        stream.connect(("localhost", self.get_http_port()),
+                       callback=connected_callback)
+        # unlike the previous tests, try to write before the connection
+        # is complete.
+        written = [False]
+        def write_callback():
+            written[0] = True
+            self.stop()
+        stream.write(b("GET / HTTP/1.0\r\nConnection: close\r\n\r\n"),
+                     callback=write_callback)
+        self.assertTrue(not connected[0])
+        # by the time the write has flushed, the connection callback has
+        # also run
+        try:
+            self.wait(lambda: connected[0] and written[0])
+        finally:
+            logging.info((connected, written))
+
+        stream.read_until_close(self.stop)
+        data = self.wait()
+        self.assertTrue(data.endswith(b("Hello")))
+
+        stream.close()
+
 
 class TestIOStreamMixin(object):
     def _make_server_iostream(self, connection, **kwargs):