From: Ben Darnell Date: Sun, 19 Aug 2012 19:16:15 +0000 (-0700) Subject: Allow calls to SSLIOStream.write while the connection is in progress. X-Git-Tag: v2.4.0~22 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=302c5032883e9e1bfd84ad9f4895f308015a2d6c;p=thirdparty%2Ftornado.git Allow calls to SSLIOStream.write while the connection is in progress. Skip fast-path writes while connecting, and rework the interaction between base class and subclass to avoid the possibility of doubly-wrapped sockets. Closes #587. --- diff --git a/tornado/iostream.py b/tornado/iostream.py index 896cf31e0..e73df0dee 100644 --- a/tornado/iostream.py +++ b/tornado/iostream.py @@ -207,10 +207,11 @@ class IOStream(object): else: self._write_buffer.append(data) self._write_callback = stack_context.wrap(callback) - self._handle_write() - if self._write_buffer: - self._add_io_state(self.io_loop.WRITE) - self._maybe_add_error_listener() + if not self._connecting: + self._handle_write() + if self._write_buffer: + self._add_io_state(self.io_loop.WRITE) + self._maybe_add_error_listener() def set_close_callback(self, callback): """Call the given callback when the stream is closed.""" @@ -626,6 +627,7 @@ class SSLIOStream(IOStream): self._ssl_accepting = True self._handshake_reading = False self._handshake_writing = False + self._ssl_connect_callback = None def reading(self): return self._handshake_reading or super(SSLIOStream, self).reading() @@ -663,7 +665,11 @@ class SSLIOStream(IOStream): return self.close() else: self._ssl_accepting = False - super(SSLIOStream, self)._handle_connect() + if self._ssl_connect_callback is not None: + callback = self._ssl_connect_callback + self._ssl_connect_callback = None + self._run_callback(callback) + def _handle_read(self): if self._ssl_accepting: @@ -677,14 +683,23 @@ class SSLIOStream(IOStream): return super(SSLIOStream, self)._handle_write() + def connect(self, address, callback=None): + # Save the user's callback and run it after the ssl handshake + # has completed. + self._ssl_connect_callback = callback + super(SSLIOStream, self).connect(address, callback=None) + def _handle_connect(self): + # When the connection is complete, wrap the socket for SSL + # traffic. Note that we do this by overriding _handle_connect + # instead of by passing a callback to super().connect because + # user callbacks are enqueued asynchronously on the IOLoop, + # but since _handle_events calls _handle_connect immediately + # followed by _handle_write we need this to be synchronous. self.socket = ssl.wrap_socket(self.socket, do_handshake_on_connect=False, **self._ssl_options) - # Don't call the superclass's _handle_connect (which is responsible - # for telling the application that the connection is complete) - # until we've completed the SSL handshake (so certificates are - # available, etc). + super(SSLIOStream, self)._handle_connect() def _read_from_socket(self): if self._ssl_accepting: diff --git a/tornado/test/iostream_test.py b/tornado/test/iostream_test.py index 0f267b854..e58d0f77e 100644 --- a/tornado/test/iostream_test.py +++ b/tornado/test/iostream_test.py @@ -6,6 +6,7 @@ from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase from tornado.util import b from tornado.web import RequestHandler, Application import errno +import logging import os import platform import socket @@ -75,6 +76,36 @@ class TestIOStreamWebMixin(object): self.stream.close() + def test_write_while_connecting(self): + stream = self._make_client_iostream() + connected = [False] + def connected_callback(): + connected[0] = True + self.stop() + stream.connect(("localhost", self.get_http_port()), + callback=connected_callback) + # unlike the previous tests, try to write before the connection + # is complete. + written = [False] + def write_callback(): + written[0] = True + self.stop() + stream.write(b("GET / HTTP/1.0\r\nConnection: close\r\n\r\n"), + callback=write_callback) + self.assertTrue(not connected[0]) + # by the time the write has flushed, the connection callback has + # also run + try: + self.wait(lambda: connected[0] and written[0]) + finally: + logging.info((connected, written)) + + stream.read_until_close(self.stop) + data = self.wait() + self.assertTrue(data.endswith(b("Hello"))) + + stream.close() + class TestIOStreamMixin(object): def _make_server_iostream(self, connection, **kwargs):