]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Implement IOStream.start_tls to convert an IOStream to an SSLIOStream.
authorBen Darnell <ben@bendarnell.com>
Sat, 24 May 2014 20:38:37 +0000 (16:38 -0400)
committerBen Darnell <ben@bendarnell.com>
Sat, 24 May 2014 20:38:37 +0000 (16:38 -0400)
tornado/iostream.py
tornado/test/iostream_test.py

index 2836cbfc00837d97c2c67880fe37278f74c2e549..df362b78355b99c857ddff1770b1c0a54cdbc5ab 100644 (file)
@@ -969,6 +969,71 @@ class IOStream(BaseIOStream):
         self._add_io_state(self.io_loop.WRITE)
         return future
 
+    def start_tls(self, server_side, ssl_options=None, server_hostname=None):
+        """Convert this `IOStream` to an `SSLIOStream`.
+
+        This enables protocols that begin in clear-text mode and
+        switch to SSL after some initial negotiation (such as the
+        ``STARTTLS`` extension to SMTP and IMAP).
+
+        This method cannot be used if there are outstanding reads
+        or writes on the stream, or if there is any data in the
+        IOStream's buffer (data in the operating system's socket
+        buffer is allowed).  This means it must generally be used
+        immediately after reading or writing the last clear-text
+        data.  It can also be used immediately after connecting,
+        before any reads or writes.
+
+        The ``ssl_options`` argument may be either a dictionary
+        of options or an `ssl.SSLContext`.  If a ``server_hostname``
+        is given, it will be used for certificate verification
+        (as configured in the ``ssl_options``).
+
+        This method returns a `.Future` whose result is the new
+        `SSLIOStream`.  After this method has been called,
+        any other operation on the original stream is undefined.
+
+        If a close callback is defined on this stream, it will be
+        transferred to the new stream.
+
+        .. versionadded:: 3.3
+        """
+        if (self._read_callback or self._read_future or
+            self._write_callback or self._write_future or
+            self._connect_callback or self._connect_future or
+            self._pending_callbacks or self._closed or
+            self._read_buffer or self._write_buffer):
+            raise ValueError("IOStream is not idle; cannot convert to SSL")
+        if ssl_options is None:
+            ssl_options = {}
+
+        socket = self.socket
+        self.io_loop.remove_handler(socket)
+        self.socket = None
+        socket = ssl_wrap_socket(socket, ssl_options, server_side=server_side,
+                                 do_handshake_on_connect=False)
+        orig_close_callback = self._close_callback
+        self._close_callback = None
+
+        future = TracebackFuture()
+        ssl_stream = SSLIOStream(socket, ssl_options=ssl_options,
+                                 io_loop=self.io_loop)
+        # Wrap the original close callback so we can fail our Future as well.
+        # If we had an "unwrap" counterpart to this method we would need
+        # to restore the original callback after our Future resolves
+        # so that repeated wrap/unwrap calls don't build up layers.
+        def close_callback():
+            if not future.done():
+                future.set_exception(ssl_stream.error or StreamClosedError())
+            if orig_close_callback is not None:
+                orig_close_callback()
+        ssl_stream.set_close_callback(close_callback)
+        ssl_stream._ssl_connect_callback = lambda: future.set_result(ssl_stream)
+        ssl_stream.max_buffer_size = self.max_buffer_size
+        ssl_stream.read_chunk_size = self.read_chunk_size
+        return future
+
+
     def _handle_connect(self):
         err = self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
         if err != 0:
index ac91cbd4e601c1ff2f6e1bc5141e0adbf698b135..b3b3e82d6974304fbbb91f4384361dff44318c64 100644 (file)
@@ -1,6 +1,7 @@
 from __future__ import absolute_import, division, print_function, with_statement
+from tornado.concurrent import Future
+from tornado import gen
 from tornado import netutil
-from tornado.ioloop import IOLoop
 from tornado.iostream import IOStream, SSLIOStream, PipeIOStream, StreamClosedError
 from tornado.httputil import HTTPHeaders
 from tornado.log import gen_log, app_log
@@ -9,6 +10,7 @@ from tornado.stack_context import NullContext
 from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, bind_unused_port, ExpectLog, gen_test
 from tornado.test.util import unittest, skipIfNonUnix
 from tornado.web import RequestHandler, Application
+import certifi
 import errno
 import logging
 import os
@@ -17,6 +19,11 @@ import socket
 import ssl
 import sys
 
+def _server_ssl_options():
+    return dict(
+        certfile=os.path.join(os.path.dirname(__file__), 'test.crt'),
+        keyfile=os.path.join(os.path.dirname(__file__), 'test.key'),
+    )
 
 class HelloHandler(RequestHandler):
     def get(self):
@@ -732,14 +739,10 @@ class TestIOStream(TestIOStreamMixin, AsyncTestCase):
 
 class TestIOStreamSSL(TestIOStreamMixin, AsyncTestCase):
     def _make_server_iostream(self, connection, **kwargs):
-        ssl_options = dict(
-            certfile=os.path.join(os.path.dirname(__file__), 'test.crt'),
-            keyfile=os.path.join(os.path.dirname(__file__), 'test.key'),
-        )
         connection = ssl.wrap_socket(connection,
                                      server_side=True,
                                      do_handshake_on_connect=False,
-                                     **ssl_options)
+                                     **_server_ssl_options())
         return SSLIOStream(connection, io_loop=self.io_loop, **kwargs)
 
     def _make_client_iostream(self, connection, **kwargs):
@@ -767,6 +770,91 @@ class TestIOStreamSSLContext(TestIOStreamMixin, AsyncTestCase):
                            ssl_options=context, **kwargs)
 
 
+class TestIOStreamStartTLS(AsyncTestCase):
+    def setUp(self):
+        try:
+            super(TestIOStreamStartTLS, self).setUp()
+            self.listener, self.port = bind_unused_port()
+            self.server_stream = None
+            self.server_accepted = Future()
+            netutil.add_accept_handler(self.listener, self.accept)
+            self.client_stream = IOStream(socket.socket())
+            self.io_loop.add_future(self.client_stream.connect(
+                ('127.0.0.1', self.port)), self.stop)
+            self.wait()
+            self.io_loop.add_future(self.server_accepted, self.stop)
+            self.wait()
+        except Exception as e:
+            print(e)
+            raise
+
+    def tearDown(self):
+        if self.server_stream is not None:
+            self.server_stream.close()
+        if self.client_stream is not None:
+            self.client_stream.close()
+        self.listener.close()
+        super(TestIOStreamStartTLS, self).tearDown()
+
+    def accept(self, connection, address):
+        if self.server_stream is not None:
+            self.fail("should only get one connection")
+        self.server_stream = IOStream(connection)
+        self.server_accepted.set_result(None)
+
+    @gen.coroutine
+    def client_send_line(self, line):
+        self.client_stream.write(line)
+        recv_line = yield self.server_stream.read_until(b"\r\n")
+        self.assertEqual(line, recv_line)
+
+    @gen.coroutine
+    def server_send_line(self, line):
+        self.server_stream.write(line)
+        recv_line = yield self.client_stream.read_until(b"\r\n")
+        self.assertEqual(line, recv_line)
+
+    def client_start_tls(self, ssl_options=None):
+        client_stream = self.client_stream
+        self.client_stream = None
+        return client_stream.start_tls(False, ssl_options)
+
+    def server_start_tls(self, ssl_options=None):
+        server_stream = self.server_stream
+        self.server_stream = None
+        return server_stream.start_tls(True, ssl_options)
+
+    @gen_test
+    def test_start_tls_smtp(self):
+        # This flow is simplified from RFC 3207 section 5.
+        # We don't really need all of this, but it helps to make sure
+        # that after realistic back-and-forth traffic the buffers end up
+        # in a sane state.
+        yield self.server_send_line(b"220 mail.example.com ready\r\n")
+        yield self.client_send_line(b"EHLO mail.example.com\r\n")
+        yield self.server_send_line(b"250-mail.example.com welcome\r\n")
+        yield self.server_send_line(b"250 STARTTLS\r\n")
+        yield self.client_send_line(b"STARTTLS\r\n")
+        yield self.server_send_line(b"220 Go ahead\r\n")
+        client_future = self.client_start_tls()
+        server_future = self.server_start_tls(_server_ssl_options())
+        self.client_stream = yield client_future
+        self.server_stream = yield server_future
+        self.assertTrue(isinstance(self.client_stream, SSLIOStream))
+        self.assertTrue(isinstance(self.server_stream, SSLIOStream))
+        yield self.client_send_line(b"EHLO mail.example.com\r\n")
+        yield self.server_send_line(b"250 mail.example.com welcome\r\n")
+
+    @gen_test
+    def test_handshake_fail(self):
+        self.server_start_tls(_server_ssl_options())
+        client_future = self.client_start_tls(
+            dict(cert_reqs=ssl.CERT_REQUIRED, ca_certs=certifi.where()))
+        with ExpectLog(gen_log, "SSL Error"):
+            with self.assertRaises(ssl.SSLError):
+                yield client_future
+
+
 @skipIfNonUnix
 class TestPipeIOStream(AsyncTestCase):
     def test_pipe_iostream(self):