self._close_callback = None
self._connect_callback = None
self._connect_future = None
+ # _ssl_connect_future should be defined in SSLIOStream
+ # but it's here so we can clean it up in maybe_run_close_callback.
+ # TODO: refactor that so subclasses can add additional futures
+ # to be cancelled.
+ self._ssl_connect_future = None
self._connecting = False
self._state = None
self._pending_callbacks = 0
if self._connect_future is not None:
futures.append(self._connect_future)
self._connect_future = None
+ if self._ssl_connect_future is not None:
+ futures.append(self._ssl_connect_future)
+ self._ssl_connect_future = None
for future in futures:
if self._is_connreset(self.error):
# Treat connection resets as closed connections so
if not self._verify_cert(self.socket.getpeercert()):
self.close()
return
- if self._ssl_connect_callback is not None:
- callback = self._ssl_connect_callback
- self._ssl_connect_callback = None
- self._run_callback(callback)
+ self._run_ssl_connect_callback()
+
+ def _run_ssl_connect_callback(self):
+ if self._ssl_connect_callback is not None:
+ callback = self._ssl_connect_callback
+ self._ssl_connect_callback = None
+ self._run_callback(callback)
+ if self._ssl_connect_future is not None:
+ future = self._ssl_connect_future
+ self._ssl_connect_future = None
+ future.set_result(self)
def _verify_cert(self, peercert):
"""Returns True if peercert is valid according to the configured
super(SSLIOStream, self)._handle_write()
def connect(self, address, callback=None, server_hostname=None):
- # Save the user's callback and run it after the ssl handshake
- # has completed.
- self._ssl_connect_callback = stack_context.wrap(callback)
self._server_hostname = server_hostname
- # Note: Since we don't pass our callback argument along to
- # super.connect(), this will always return a Future.
- # This is harmless, but a bit less efficient than it could be.
- return super(SSLIOStream, self).connect(address, callback=None)
+ # Pass a dummy callback to super.connect(), which is slightly
+ # more efficient than letting it return a Future we ignore.
+ super(SSLIOStream, self).connect(address, callback=lambda: None)
+ return self.wait_for_handshake(callback)
def _handle_connect(self):
# Call the superclass method to check for errors.
do_handshake_on_connect=False)
self._add_io_state(old_state)
+ def wait_for_handshake(self, callback=None):
+ """Wait for the initial SSL handshake to complete.
+
+ If a ``callback`` is given, it will be called with no
+ arguments once the handshake is complete; otherwise this
+ method returns a `.Future` which will resolve to the
+ stream itself after the handshake is complete.
+
+ Once the handshake is complete, information such as
+ the peer's certificate and NPN/ALPN selections may be
+ accessed on ``self.socket``.
+
+ This method is intended for use on server-side streams
+ or after using `IOStream.start_tls`; it should not be used
+ with `IOStream.connect` (which already waits for the
+ handshake to complete). It may only be called once per stream.
+
+ .. versionadded:: 4.2
+ """
+ if (self._ssl_connect_callback is not None or
+ self._ssl_connect_future is not None):
+ raise RuntimeError("Already waiting")
+ if callback is not None:
+ self._ssl_connect_callback = stack_context.wrap(callback)
+ future = None
+ else:
+ future = self._ssl_connect_future = TracebackFuture()
+ if not self._ssl_accepting:
+ self._run_ssl_connect_callback()
+ return future
+
def write_to_fd(self, data):
try:
return self.socket.send(data)
from tornado.log import gen_log, app_log
from tornado.netutil import ssl_wrap_socket
from tornado.stack_context import NullContext
+from tornado.tcpserver import TCPServer
from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, bind_unused_port, ExpectLog, gen_test
from tornado.test.util import unittest, skipIfNonUnix, refusing_port
from tornado.web import RequestHandler, Application
yield server_future
+class WaitForHandshakeTest(AsyncTestCase):
+ @gen.coroutine
+ def connect_to_server(self, server_cls):
+ server = client = None
+ try:
+ sock, port = bind_unused_port()
+ server = server_cls(ssl_options=_server_ssl_options())
+ server.add_socket(sock)
+
+ client = SSLIOStream(socket.socket(),
+ ssl_options=dict(cert_reqs=ssl.CERT_NONE))
+ yield client.connect(('127.0.0.1', port))
+ self.assertIsNotNone(client.socket.cipher())
+ finally:
+ if server is not None:
+ server.stop()
+ if client is not None:
+ client.close()
+
+ @gen_test
+ def test_wait_for_handshake_callback(self):
+ test = self
+ handshake_future = Future()
+
+ class TestServer(TCPServer):
+ def handle_stream(self, stream, address):
+ # The handshake has not yet completed.
+ test.assertIsNone(stream.socket.cipher())
+ self.stream = stream
+ stream.wait_for_handshake(self.handshake_done)
+
+ def handshake_done(self):
+ # Now the handshake is done and ssl information is available.
+ test.assertIsNotNone(self.stream.socket.cipher())
+ handshake_future.set_result(None)
+
+ yield self.connect_to_server(TestServer)
+ yield handshake_future
+
+ @gen_test
+ def test_wait_for_handshake_future(self):
+ test = self
+ handshake_future = Future()
+
+ class TestServer(TCPServer):
+ def handle_stream(self, stream, address):
+ test.assertIsNone(stream.socket.cipher())
+ test.io_loop.spawn_callback(self.handle_connection, stream)
+
+ @gen.coroutine
+ def handle_connection(self, stream):
+ yield stream.wait_for_handshake()
+ handshake_future.set_result(None)
+
+ yield self.connect_to_server(TestServer)
+ yield handshake_future
+
+ @gen_test
+ def test_wait_for_handshake_already_waiting_error(self):
+ test = self
+ handshake_future = Future()
+
+ class TestServer(TCPServer):
+ def handle_stream(self, stream, address):
+ stream.wait_for_handshake(self.handshake_done)
+ test.assertRaises(RuntimeError, stream.wait_for_handshake)
+
+ def handshake_done(self):
+ handshake_future.set_result(None)
+
+ yield self.connect_to_server(TestServer)
+ yield handshake_future
+
+ @gen_test
+ def test_wait_for_handshake_already_connected(self):
+ handshake_future = Future()
+
+ class TestServer(TCPServer):
+ def handle_stream(self, stream, address):
+ self.stream = stream
+ stream.wait_for_handshake(self.handshake_done)
+
+ def handshake_done(self):
+ self.stream.wait_for_handshake(self.handshake2_done)
+
+ def handshake2_done(self):
+ handshake_future.set_result(None)
+
+ yield self.connect_to_server(TestServer)
+ yield handshake_future
+
+
@skipIfNonUnix
class TestPipeIOStream(AsyncTestCase):
def test_pipe_iostream(self):