From: Ben Darnell Date: Sun, 8 Mar 2015 00:43:57 +0000 (-0500) Subject: Add SSLIOStream.wait_for_handshake. X-Git-Tag: v4.2.0b1~78 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=dfb752a7d77047633fdc9ec61c4bacccc6a366cb;p=thirdparty%2Ftornado.git Add SSLIOStream.wait_for_handshake. This allows server-side applications to wait for the handshake to complete in order to verify client certificates or use NPN/ALPN. Fix a discrepancy between the callback and Future modes of SSLIOStream.connect; now they both wait for the handshake to complete. --- diff --git a/tornado/iostream.py b/tornado/iostream.py index 089d413b7..65e7b1714 100644 --- a/tornado/iostream.py +++ b/tornado/iostream.py @@ -169,6 +169,11 @@ class BaseIOStream(object): self._close_callback = None self._connect_callback = None self._connect_future = None + # _ssl_connect_future should be defined in SSLIOStream + # but it's here so we can clean it up in maybe_run_close_callback. + # TODO: refactor that so subclasses can add additional futures + # to be cancelled. + self._ssl_connect_future = None self._connecting = False self._state = None self._pending_callbacks = 0 @@ -437,6 +442,9 @@ class BaseIOStream(object): if self._connect_future is not None: futures.append(self._connect_future) self._connect_future = None + if self._ssl_connect_future is not None: + futures.append(self._ssl_connect_future) + self._ssl_connect_future = None for future in futures: if self._is_connreset(self.error): # Treat connection resets as closed connections so @@ -1270,10 +1278,17 @@ class SSLIOStream(IOStream): if not self._verify_cert(self.socket.getpeercert()): self.close() return - if self._ssl_connect_callback is not None: - callback = self._ssl_connect_callback - self._ssl_connect_callback = None - self._run_callback(callback) + self._run_ssl_connect_callback() + + def _run_ssl_connect_callback(self): + if self._ssl_connect_callback is not None: + callback = self._ssl_connect_callback + self._ssl_connect_callback = None + self._run_callback(callback) + if self._ssl_connect_future is not None: + future = self._ssl_connect_future + self._ssl_connect_future = None + future.set_result(self) def _verify_cert(self, peercert): """Returns True if peercert is valid according to the configured @@ -1315,14 +1330,11 @@ class SSLIOStream(IOStream): super(SSLIOStream, self)._handle_write() def connect(self, address, callback=None, server_hostname=None): - # Save the user's callback and run it after the ssl handshake - # has completed. - self._ssl_connect_callback = stack_context.wrap(callback) self._server_hostname = server_hostname - # Note: Since we don't pass our callback argument along to - # super.connect(), this will always return a Future. - # This is harmless, but a bit less efficient than it could be. - return super(SSLIOStream, self).connect(address, callback=None) + # Pass a dummy callback to super.connect(), which is slightly + # more efficient than letting it return a Future we ignore. + super(SSLIOStream, self).connect(address, callback=lambda: None) + return self.wait_for_handshake(callback) def _handle_connect(self): # Call the superclass method to check for errors. @@ -1347,6 +1359,37 @@ class SSLIOStream(IOStream): do_handshake_on_connect=False) self._add_io_state(old_state) + def wait_for_handshake(self, callback=None): + """Wait for the initial SSL handshake to complete. + + If a ``callback`` is given, it will be called with no + arguments once the handshake is complete; otherwise this + method returns a `.Future` which will resolve to the + stream itself after the handshake is complete. + + Once the handshake is complete, information such as + the peer's certificate and NPN/ALPN selections may be + accessed on ``self.socket``. + + This method is intended for use on server-side streams + or after using `IOStream.start_tls`; it should not be used + with `IOStream.connect` (which already waits for the + handshake to complete). It may only be called once per stream. + + .. versionadded:: 4.2 + """ + if (self._ssl_connect_callback is not None or + self._ssl_connect_future is not None): + raise RuntimeError("Already waiting") + if callback is not None: + self._ssl_connect_callback = stack_context.wrap(callback) + future = None + else: + future = self._ssl_connect_future = TracebackFuture() + if not self._ssl_accepting: + self._run_ssl_connect_callback() + return future + def write_to_fd(self, data): try: return self.socket.send(data) diff --git a/tornado/test/iostream_test.py b/tornado/test/iostream_test.py index c1662885f..7c57324d3 100644 --- a/tornado/test/iostream_test.py +++ b/tornado/test/iostream_test.py @@ -7,6 +7,7 @@ from tornado.httputil import HTTPHeaders from tornado.log import gen_log, app_log from tornado.netutil import ssl_wrap_socket from tornado.stack_context import NullContext +from tornado.tcpserver import TCPServer from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, bind_unused_port, ExpectLog, gen_test from tornado.test.util import unittest, skipIfNonUnix, refusing_port from tornado.web import RequestHandler, Application @@ -907,6 +908,98 @@ class TestIOStreamStartTLS(AsyncTestCase): yield server_future +class WaitForHandshakeTest(AsyncTestCase): + @gen.coroutine + def connect_to_server(self, server_cls): + server = client = None + try: + sock, port = bind_unused_port() + server = server_cls(ssl_options=_server_ssl_options()) + server.add_socket(sock) + + client = SSLIOStream(socket.socket(), + ssl_options=dict(cert_reqs=ssl.CERT_NONE)) + yield client.connect(('127.0.0.1', port)) + self.assertIsNotNone(client.socket.cipher()) + finally: + if server is not None: + server.stop() + if client is not None: + client.close() + + @gen_test + def test_wait_for_handshake_callback(self): + test = self + handshake_future = Future() + + class TestServer(TCPServer): + def handle_stream(self, stream, address): + # The handshake has not yet completed. + test.assertIsNone(stream.socket.cipher()) + self.stream = stream + stream.wait_for_handshake(self.handshake_done) + + def handshake_done(self): + # Now the handshake is done and ssl information is available. + test.assertIsNotNone(self.stream.socket.cipher()) + handshake_future.set_result(None) + + yield self.connect_to_server(TestServer) + yield handshake_future + + @gen_test + def test_wait_for_handshake_future(self): + test = self + handshake_future = Future() + + class TestServer(TCPServer): + def handle_stream(self, stream, address): + test.assertIsNone(stream.socket.cipher()) + test.io_loop.spawn_callback(self.handle_connection, stream) + + @gen.coroutine + def handle_connection(self, stream): + yield stream.wait_for_handshake() + handshake_future.set_result(None) + + yield self.connect_to_server(TestServer) + yield handshake_future + + @gen_test + def test_wait_for_handshake_already_waiting_error(self): + test = self + handshake_future = Future() + + class TestServer(TCPServer): + def handle_stream(self, stream, address): + stream.wait_for_handshake(self.handshake_done) + test.assertRaises(RuntimeError, stream.wait_for_handshake) + + def handshake_done(self): + handshake_future.set_result(None) + + yield self.connect_to_server(TestServer) + yield handshake_future + + @gen_test + def test_wait_for_handshake_already_connected(self): + handshake_future = Future() + + class TestServer(TCPServer): + def handle_stream(self, stream, address): + self.stream = stream + stream.wait_for_handshake(self.handshake_done) + + def handshake_done(self): + self.stream.wait_for_handshake(self.handshake2_done) + + def handshake2_done(self): + handshake_future.set_result(None) + + yield self.connect_to_server(TestServer) + yield handshake_future + + @skipIfNonUnix class TestPipeIOStream(AsyncTestCase): def test_pipe_iostream(self):