]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Add SSLIOStream.wait_for_handshake.
authorBen Darnell <ben@bendarnell.com>
Sun, 8 Mar 2015 00:43:57 +0000 (19:43 -0500)
committerBen Darnell <ben@bendarnell.com>
Sun, 8 Mar 2015 00:45:02 +0000 (19:45 -0500)
This allows server-side applications to wait for the handshake to
complete in order to verify client certificates or use NPN/ALPN.

Fix a discrepancy between the callback and Future modes of
SSLIOStream.connect; now they both wait for the handshake to
complete.

tornado/iostream.py
tornado/test/iostream_test.py

index 089d413b76776062bdbd1fed1dcf6b67f338a7da..65e7b1714f27338cd2e487b7c02a37a5a94cbd1e 100644 (file)
@@ -169,6 +169,11 @@ class BaseIOStream(object):
         self._close_callback = None
         self._connect_callback = None
         self._connect_future = None
+        # _ssl_connect_future should be defined in SSLIOStream
+        # but it's here so we can clean it up in maybe_run_close_callback.
+        # TODO: refactor that so subclasses can add additional futures
+        # to be cancelled.
+        self._ssl_connect_future = None
         self._connecting = False
         self._state = None
         self._pending_callbacks = 0
@@ -437,6 +442,9 @@ class BaseIOStream(object):
             if self._connect_future is not None:
                 futures.append(self._connect_future)
                 self._connect_future = None
+            if self._ssl_connect_future is not None:
+                futures.append(self._ssl_connect_future)
+                self._ssl_connect_future = None
             for future in futures:
                 if self._is_connreset(self.error):
                     # Treat connection resets as closed connections so
@@ -1270,10 +1278,17 @@ class SSLIOStream(IOStream):
             if not self._verify_cert(self.socket.getpeercert()):
                 self.close()
                 return
-            if self._ssl_connect_callback is not None:
-                callback = self._ssl_connect_callback
-                self._ssl_connect_callback = None
-                self._run_callback(callback)
+            self._run_ssl_connect_callback()
+
+    def _run_ssl_connect_callback(self):
+        if self._ssl_connect_callback is not None:
+            callback = self._ssl_connect_callback
+            self._ssl_connect_callback = None
+            self._run_callback(callback)
+        if self._ssl_connect_future is not None:
+            future = self._ssl_connect_future
+            self._ssl_connect_future = None
+            future.set_result(self)
 
     def _verify_cert(self, peercert):
         """Returns True if peercert is valid according to the configured
@@ -1315,14 +1330,11 @@ class SSLIOStream(IOStream):
         super(SSLIOStream, self)._handle_write()
 
     def connect(self, address, callback=None, server_hostname=None):
-        # Save the user's callback and run it after the ssl handshake
-        # has completed.
-        self._ssl_connect_callback = stack_context.wrap(callback)
         self._server_hostname = server_hostname
-        # Note: Since we don't pass our callback argument along to
-        # super.connect(), this will always return a Future.
-        # This is harmless, but a bit less efficient than it could be.
-        return super(SSLIOStream, self).connect(address, callback=None)
+        # Pass a dummy callback to super.connect(), which is slightly
+        # more efficient than letting it return a Future we ignore.
+        super(SSLIOStream, self).connect(address, callback=lambda: None)
+        return self.wait_for_handshake(callback)
 
     def _handle_connect(self):
         # Call the superclass method to check for errors.
@@ -1347,6 +1359,37 @@ class SSLIOStream(IOStream):
                                       do_handshake_on_connect=False)
         self._add_io_state(old_state)
 
+    def wait_for_handshake(self, callback=None):
+        """Wait for the initial SSL handshake to complete.
+
+        If a ``callback`` is given, it will be called with no
+        arguments once the handshake is complete; otherwise this
+        method returns a `.Future` which will resolve to the
+        stream itself after the handshake is complete.
+
+        Once the handshake is complete, information such as
+        the peer's certificate and NPN/ALPN selections may be
+        accessed on ``self.socket``.
+
+        This method is intended for use on server-side streams
+        or after using `IOStream.start_tls`; it should not be used
+        with `IOStream.connect` (which already waits for the
+        handshake to complete). It may only be called once per stream.
+
+        .. versionadded:: 4.2
+        """
+        if (self._ssl_connect_callback is not None or
+                self._ssl_connect_future is not None):
+            raise RuntimeError("Already waiting")
+        if callback is not None:
+            self._ssl_connect_callback = stack_context.wrap(callback)
+            future = None
+        else:
+            future = self._ssl_connect_future = TracebackFuture()
+        if not self._ssl_accepting:
+            self._run_ssl_connect_callback()
+        return future
+
     def write_to_fd(self, data):
         try:
             return self.socket.send(data)
index c1662885f63562cb56ec5b313045fdc890d642e6..7c57324d3926e1544c680b799fa3f17b6792d66f 100644 (file)
@@ -7,6 +7,7 @@ from tornado.httputil import HTTPHeaders
 from tornado.log import gen_log, app_log
 from tornado.netutil import ssl_wrap_socket
 from tornado.stack_context import NullContext
+from tornado.tcpserver import TCPServer
 from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, bind_unused_port, ExpectLog, gen_test
 from tornado.test.util import unittest, skipIfNonUnix, refusing_port
 from tornado.web import RequestHandler, Application
@@ -907,6 +908,98 @@ class TestIOStreamStartTLS(AsyncTestCase):
             yield server_future
 
 
+class WaitForHandshakeTest(AsyncTestCase):
+    @gen.coroutine
+    def connect_to_server(self, server_cls):
+        server = client = None
+        try:
+            sock, port = bind_unused_port()
+            server = server_cls(ssl_options=_server_ssl_options())
+            server.add_socket(sock)
+
+            client = SSLIOStream(socket.socket(),
+                                 ssl_options=dict(cert_reqs=ssl.CERT_NONE))
+            yield client.connect(('127.0.0.1', port))
+            self.assertIsNotNone(client.socket.cipher())
+        finally:
+            if server is not None:
+                server.stop()
+            if client is not None:
+                client.close()
+
+    @gen_test
+    def test_wait_for_handshake_callback(self):
+        test = self
+        handshake_future = Future()
+
+        class TestServer(TCPServer):
+            def handle_stream(self, stream, address):
+                # The handshake has not yet completed.
+                test.assertIsNone(stream.socket.cipher())
+                self.stream = stream
+                stream.wait_for_handshake(self.handshake_done)
+
+            def handshake_done(self):
+                # Now the handshake is done and ssl information is available.
+                test.assertIsNotNone(self.stream.socket.cipher())
+                handshake_future.set_result(None)
+
+        yield self.connect_to_server(TestServer)
+        yield handshake_future
+
+    @gen_test
+    def test_wait_for_handshake_future(self):
+        test = self
+        handshake_future = Future()
+
+        class TestServer(TCPServer):
+            def handle_stream(self, stream, address):
+                test.assertIsNone(stream.socket.cipher())
+                test.io_loop.spawn_callback(self.handle_connection, stream)
+
+            @gen.coroutine
+            def handle_connection(self, stream):
+                yield stream.wait_for_handshake()
+                handshake_future.set_result(None)
+
+        yield self.connect_to_server(TestServer)
+        yield handshake_future
+
+    @gen_test
+    def test_wait_for_handshake_already_waiting_error(self):
+        test = self
+        handshake_future = Future()
+
+        class TestServer(TCPServer):
+            def handle_stream(self, stream, address):
+                stream.wait_for_handshake(self.handshake_done)
+                test.assertRaises(RuntimeError, stream.wait_for_handshake)
+
+            def handshake_done(self):
+                handshake_future.set_result(None)
+
+        yield self.connect_to_server(TestServer)
+        yield handshake_future
+
+    @gen_test
+    def test_wait_for_handshake_already_connected(self):
+        handshake_future = Future()
+
+        class TestServer(TCPServer):
+            def handle_stream(self, stream, address):
+                self.stream = stream
+                stream.wait_for_handshake(self.handshake_done)
+
+            def handshake_done(self):
+                self.stream.wait_for_handshake(self.handshake2_done)
+
+            def handshake2_done(self):
+                handshake_future.set_result(None)
+
+        yield self.connect_to_server(TestServer)
+        yield handshake_future
+
+
 @skipIfNonUnix
 class TestPipeIOStream(AsyncTestCase):
     def test_pipe_iostream(self):