From: Ben Darnell Date: Sat, 24 May 2014 20:38:37 +0000 (-0400) Subject: Implement IOStream.start_tls to convert an IOStream to an SSLIOStream. X-Git-Tag: v4.0.0b1~52 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=db7953106f247d4ee4c4942277fa4d9c6f1c42c9;p=thirdparty%2Ftornado.git Implement IOStream.start_tls to convert an IOStream to an SSLIOStream. --- diff --git a/tornado/iostream.py b/tornado/iostream.py index 2836cbfc0..df362b783 100644 --- a/tornado/iostream.py +++ b/tornado/iostream.py @@ -969,6 +969,71 @@ class IOStream(BaseIOStream): self._add_io_state(self.io_loop.WRITE) return future + def start_tls(self, server_side, ssl_options=None, server_hostname=None): + """Convert this `IOStream` to an `SSLIOStream`. + + This enables protocols that begin in clear-text mode and + switch to SSL after some initial negotiation (such as the + ``STARTTLS`` extension to SMTP and IMAP). + + This method cannot be used if there are outstanding reads + or writes on the stream, or if there is any data in the + IOStream's buffer (data in the operating system's socket + buffer is allowed). This means it must generally be used + immediately after reading or writing the last clear-text + data. It can also be used immediately after connecting, + before any reads or writes. + + The ``ssl_options`` argument may be either a dictionary + of options or an `ssl.SSLContext`. If a ``server_hostname`` + is given, it will be used for certificate verification + (as configured in the ``ssl_options``). + + This method returns a `.Future` whose result is the new + `SSLIOStream`. After this method has been called, + any other operation on the original stream is undefined. + + If a close callback is defined on this stream, it will be + transferred to the new stream. + + .. versionadded:: 3.3 + """ + if (self._read_callback or self._read_future or + self._write_callback or self._write_future or + self._connect_callback or self._connect_future or + self._pending_callbacks or self._closed or + self._read_buffer or self._write_buffer): + raise ValueError("IOStream is not idle; cannot convert to SSL") + if ssl_options is None: + ssl_options = {} + + socket = self.socket + self.io_loop.remove_handler(socket) + self.socket = None + socket = ssl_wrap_socket(socket, ssl_options, server_side=server_side, + do_handshake_on_connect=False) + orig_close_callback = self._close_callback + self._close_callback = None + + future = TracebackFuture() + ssl_stream = SSLIOStream(socket, ssl_options=ssl_options, + io_loop=self.io_loop) + # Wrap the original close callback so we can fail our Future as well. + # If we had an "unwrap" counterpart to this method we would need + # to restore the original callback after our Future resolves + # so that repeated wrap/unwrap calls don't build up layers. + def close_callback(): + if not future.done(): + future.set_exception(ssl_stream.error or StreamClosedError()) + if orig_close_callback is not None: + orig_close_callback() + ssl_stream.set_close_callback(close_callback) + ssl_stream._ssl_connect_callback = lambda: future.set_result(ssl_stream) + ssl_stream.max_buffer_size = self.max_buffer_size + ssl_stream.read_chunk_size = self.read_chunk_size + return future + + def _handle_connect(self): err = self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) if err != 0: diff --git a/tornado/test/iostream_test.py b/tornado/test/iostream_test.py index ac91cbd4e..b3b3e82d6 100644 --- a/tornado/test/iostream_test.py +++ b/tornado/test/iostream_test.py @@ -1,6 +1,7 @@ from __future__ import absolute_import, division, print_function, with_statement +from tornado.concurrent import Future +from tornado import gen from tornado import netutil -from tornado.ioloop import IOLoop from tornado.iostream import IOStream, SSLIOStream, PipeIOStream, StreamClosedError from tornado.httputil import HTTPHeaders from tornado.log import gen_log, app_log @@ -9,6 +10,7 @@ from tornado.stack_context import NullContext from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, bind_unused_port, ExpectLog, gen_test from tornado.test.util import unittest, skipIfNonUnix from tornado.web import RequestHandler, Application +import certifi import errno import logging import os @@ -17,6 +19,11 @@ import socket import ssl import sys +def _server_ssl_options(): + return dict( + certfile=os.path.join(os.path.dirname(__file__), 'test.crt'), + keyfile=os.path.join(os.path.dirname(__file__), 'test.key'), + ) class HelloHandler(RequestHandler): def get(self): @@ -732,14 +739,10 @@ class TestIOStream(TestIOStreamMixin, AsyncTestCase): class TestIOStreamSSL(TestIOStreamMixin, AsyncTestCase): def _make_server_iostream(self, connection, **kwargs): - ssl_options = dict( - certfile=os.path.join(os.path.dirname(__file__), 'test.crt'), - keyfile=os.path.join(os.path.dirname(__file__), 'test.key'), - ) connection = ssl.wrap_socket(connection, server_side=True, do_handshake_on_connect=False, - **ssl_options) + **_server_ssl_options()) return SSLIOStream(connection, io_loop=self.io_loop, **kwargs) def _make_client_iostream(self, connection, **kwargs): @@ -767,6 +770,91 @@ class TestIOStreamSSLContext(TestIOStreamMixin, AsyncTestCase): ssl_options=context, **kwargs) +class TestIOStreamStartTLS(AsyncTestCase): + def setUp(self): + try: + super(TestIOStreamStartTLS, self).setUp() + self.listener, self.port = bind_unused_port() + self.server_stream = None + self.server_accepted = Future() + netutil.add_accept_handler(self.listener, self.accept) + self.client_stream = IOStream(socket.socket()) + self.io_loop.add_future(self.client_stream.connect( + ('127.0.0.1', self.port)), self.stop) + self.wait() + self.io_loop.add_future(self.server_accepted, self.stop) + self.wait() + except Exception as e: + print(e) + raise + + def tearDown(self): + if self.server_stream is not None: + self.server_stream.close() + if self.client_stream is not None: + self.client_stream.close() + self.listener.close() + super(TestIOStreamStartTLS, self).tearDown() + + def accept(self, connection, address): + if self.server_stream is not None: + self.fail("should only get one connection") + self.server_stream = IOStream(connection) + self.server_accepted.set_result(None) + + @gen.coroutine + def client_send_line(self, line): + self.client_stream.write(line) + recv_line = yield self.server_stream.read_until(b"\r\n") + self.assertEqual(line, recv_line) + + @gen.coroutine + def server_send_line(self, line): + self.server_stream.write(line) + recv_line = yield self.client_stream.read_until(b"\r\n") + self.assertEqual(line, recv_line) + + def client_start_tls(self, ssl_options=None): + client_stream = self.client_stream + self.client_stream = None + return client_stream.start_tls(False, ssl_options) + + def server_start_tls(self, ssl_options=None): + server_stream = self.server_stream + self.server_stream = None + return server_stream.start_tls(True, ssl_options) + + @gen_test + def test_start_tls_smtp(self): + # This flow is simplified from RFC 3207 section 5. + # We don't really need all of this, but it helps to make sure + # that after realistic back-and-forth traffic the buffers end up + # in a sane state. + yield self.server_send_line(b"220 mail.example.com ready\r\n") + yield self.client_send_line(b"EHLO mail.example.com\r\n") + yield self.server_send_line(b"250-mail.example.com welcome\r\n") + yield self.server_send_line(b"250 STARTTLS\r\n") + yield self.client_send_line(b"STARTTLS\r\n") + yield self.server_send_line(b"220 Go ahead\r\n") + client_future = self.client_start_tls() + server_future = self.server_start_tls(_server_ssl_options()) + self.client_stream = yield client_future + self.server_stream = yield server_future + self.assertTrue(isinstance(self.client_stream, SSLIOStream)) + self.assertTrue(isinstance(self.server_stream, SSLIOStream)) + yield self.client_send_line(b"EHLO mail.example.com\r\n") + yield self.server_send_line(b"250 mail.example.com welcome\r\n") + + @gen_test + def test_handshake_fail(self): + self.server_start_tls(_server_ssl_options()) + client_future = self.client_start_tls( + dict(cert_reqs=ssl.CERT_REQUIRED, ca_certs=certifi.where())) + with ExpectLog(gen_log, "SSL Error"): + with self.assertRaises(ssl.SSLError): + yield client_future + + @skipIfNonUnix class TestPipeIOStream(AsyncTestCase): def test_pipe_iostream(self):