self._add_io_state(self.io_loop.WRITE)
return future
+ def start_tls(self, server_side, ssl_options=None, server_hostname=None):
+ """Convert this `IOStream` to an `SSLIOStream`.
+
+ This enables protocols that begin in clear-text mode and
+ switch to SSL after some initial negotiation (such as the
+ ``STARTTLS`` extension to SMTP and IMAP).
+
+ This method cannot be used if there are outstanding reads
+ or writes on the stream, or if there is any data in the
+ IOStream's buffer (data in the operating system's socket
+ buffer is allowed). This means it must generally be used
+ immediately after reading or writing the last clear-text
+ data. It can also be used immediately after connecting,
+ before any reads or writes.
+
+ The ``ssl_options`` argument may be either a dictionary
+ of options or an `ssl.SSLContext`. If a ``server_hostname``
+ is given, it will be used for certificate verification
+ (as configured in the ``ssl_options``).
+
+ This method returns a `.Future` whose result is the new
+ `SSLIOStream`. After this method has been called,
+ any other operation on the original stream is undefined.
+
+ If a close callback is defined on this stream, it will be
+ transferred to the new stream.
+
+ .. versionadded:: 3.3
+ """
+ if (self._read_callback or self._read_future or
+ self._write_callback or self._write_future or
+ self._connect_callback or self._connect_future or
+ self._pending_callbacks or self._closed or
+ self._read_buffer or self._write_buffer):
+ raise ValueError("IOStream is not idle; cannot convert to SSL")
+ if ssl_options is None:
+ ssl_options = {}
+
+ socket = self.socket
+ self.io_loop.remove_handler(socket)
+ self.socket = None
+ socket = ssl_wrap_socket(socket, ssl_options, server_side=server_side,
+ do_handshake_on_connect=False)
+ orig_close_callback = self._close_callback
+ self._close_callback = None
+
+ future = TracebackFuture()
+ ssl_stream = SSLIOStream(socket, ssl_options=ssl_options,
+ io_loop=self.io_loop)
+ # Wrap the original close callback so we can fail our Future as well.
+ # If we had an "unwrap" counterpart to this method we would need
+ # to restore the original callback after our Future resolves
+ # so that repeated wrap/unwrap calls don't build up layers.
+ def close_callback():
+ if not future.done():
+ future.set_exception(ssl_stream.error or StreamClosedError())
+ if orig_close_callback is not None:
+ orig_close_callback()
+ ssl_stream.set_close_callback(close_callback)
+ ssl_stream._ssl_connect_callback = lambda: future.set_result(ssl_stream)
+ ssl_stream.max_buffer_size = self.max_buffer_size
+ ssl_stream.read_chunk_size = self.read_chunk_size
+ return future
+
+
def _handle_connect(self):
err = self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
if err != 0:
from __future__ import absolute_import, division, print_function, with_statement
+from tornado.concurrent import Future
+from tornado import gen
from tornado import netutil
-from tornado.ioloop import IOLoop
from tornado.iostream import IOStream, SSLIOStream, PipeIOStream, StreamClosedError
from tornado.httputil import HTTPHeaders
from tornado.log import gen_log, app_log
from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, bind_unused_port, ExpectLog, gen_test
from tornado.test.util import unittest, skipIfNonUnix
from tornado.web import RequestHandler, Application
+import certifi
import errno
import logging
import os
import ssl
import sys
+def _server_ssl_options():
+ return dict(
+ certfile=os.path.join(os.path.dirname(__file__), 'test.crt'),
+ keyfile=os.path.join(os.path.dirname(__file__), 'test.key'),
+ )
class HelloHandler(RequestHandler):
def get(self):
class TestIOStreamSSL(TestIOStreamMixin, AsyncTestCase):
def _make_server_iostream(self, connection, **kwargs):
- ssl_options = dict(
- certfile=os.path.join(os.path.dirname(__file__), 'test.crt'),
- keyfile=os.path.join(os.path.dirname(__file__), 'test.key'),
- )
connection = ssl.wrap_socket(connection,
server_side=True,
do_handshake_on_connect=False,
- **ssl_options)
+ **_server_ssl_options())
return SSLIOStream(connection, io_loop=self.io_loop, **kwargs)
def _make_client_iostream(self, connection, **kwargs):
ssl_options=context, **kwargs)
+class TestIOStreamStartTLS(AsyncTestCase):
+ def setUp(self):
+ try:
+ super(TestIOStreamStartTLS, self).setUp()
+ self.listener, self.port = bind_unused_port()
+ self.server_stream = None
+ self.server_accepted = Future()
+ netutil.add_accept_handler(self.listener, self.accept)
+ self.client_stream = IOStream(socket.socket())
+ self.io_loop.add_future(self.client_stream.connect(
+ ('127.0.0.1', self.port)), self.stop)
+ self.wait()
+ self.io_loop.add_future(self.server_accepted, self.stop)
+ self.wait()
+ except Exception as e:
+ print(e)
+ raise
+
+ def tearDown(self):
+ if self.server_stream is not None:
+ self.server_stream.close()
+ if self.client_stream is not None:
+ self.client_stream.close()
+ self.listener.close()
+ super(TestIOStreamStartTLS, self).tearDown()
+
+ def accept(self, connection, address):
+ if self.server_stream is not None:
+ self.fail("should only get one connection")
+ self.server_stream = IOStream(connection)
+ self.server_accepted.set_result(None)
+
+ @gen.coroutine
+ def client_send_line(self, line):
+ self.client_stream.write(line)
+ recv_line = yield self.server_stream.read_until(b"\r\n")
+ self.assertEqual(line, recv_line)
+
+ @gen.coroutine
+ def server_send_line(self, line):
+ self.server_stream.write(line)
+ recv_line = yield self.client_stream.read_until(b"\r\n")
+ self.assertEqual(line, recv_line)
+
+ def client_start_tls(self, ssl_options=None):
+ client_stream = self.client_stream
+ self.client_stream = None
+ return client_stream.start_tls(False, ssl_options)
+
+ def server_start_tls(self, ssl_options=None):
+ server_stream = self.server_stream
+ self.server_stream = None
+ return server_stream.start_tls(True, ssl_options)
+
+ @gen_test
+ def test_start_tls_smtp(self):
+ # This flow is simplified from RFC 3207 section 5.
+ # We don't really need all of this, but it helps to make sure
+ # that after realistic back-and-forth traffic the buffers end up
+ # in a sane state.
+ yield self.server_send_line(b"220 mail.example.com ready\r\n")
+ yield self.client_send_line(b"EHLO mail.example.com\r\n")
+ yield self.server_send_line(b"250-mail.example.com welcome\r\n")
+ yield self.server_send_line(b"250 STARTTLS\r\n")
+ yield self.client_send_line(b"STARTTLS\r\n")
+ yield self.server_send_line(b"220 Go ahead\r\n")
+ client_future = self.client_start_tls()
+ server_future = self.server_start_tls(_server_ssl_options())
+ self.client_stream = yield client_future
+ self.server_stream = yield server_future
+ self.assertTrue(isinstance(self.client_stream, SSLIOStream))
+ self.assertTrue(isinstance(self.server_stream, SSLIOStream))
+ yield self.client_send_line(b"EHLO mail.example.com\r\n")
+ yield self.server_send_line(b"250 mail.example.com welcome\r\n")
+
+ @gen_test
+ def test_handshake_fail(self):
+ self.server_start_tls(_server_ssl_options())
+ client_future = self.client_start_tls(
+ dict(cert_reqs=ssl.CERT_REQUIRED, ca_certs=certifi.where()))
+ with ExpectLog(gen_log, "SSL Error"):
+ with self.assertRaises(ssl.SSLError):
+ yield client_future
+
+
@skipIfNonUnix
class TestPipeIOStream(AsyncTestCase):
def test_pipe_iostream(self):