]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
On py32+, methods that take ssl_options now also accept SSLContext objects.
authorBen Darnell <ben@bendarnell.com>
Mon, 21 Jan 2013 02:28:04 +0000 (21:28 -0500)
committerBen Darnell <ben@bendarnell.com>
Mon, 21 Jan 2013 03:32:52 +0000 (22:32 -0500)
This is necessary to support SNI and NPN.

tornado/httpserver.py
tornado/iostream.py
tornado/netutil.py
tornado/tcpserver.py
tornado/test/httpserver_test.py
tornado/test/iostream_test.py
website/sphinx/releases/next.rst

index 83babaac8706a098d65c20a0294308325c843e8c..3f7bf41d5879df0bc8459e387731547d6366e336 100644 (file)
@@ -95,7 +95,8 @@ class HTTPServer(TCPServer):
     `HTTPServer` can serve SSL traffic with Python 2.6+ and OpenSSL.
     To make this server serve SSL traffic, send the ssl_options dictionary
     argument with the arguments required for the `ssl.wrap_socket` method,
-    including "certfile" and "keyfile"::
+    including "certfile" and "keyfile".  In Python 3.2+ you can pass
+    an `ssl.SSLContext` object instead of a dict::
 
        HTTPServer(applicaton, ssl_options={
            "certfile": os.path.join(data_dir, "mydomain.crt"),
index d98136ea40e57b2e6294e32e104773069e04f64b..43c52d74c97312f791dc386e1580b262d45db224 100644 (file)
@@ -37,6 +37,7 @@ import re
 
 from tornado import ioloop
 from tornado.log import gen_log, app_log
+from tornado.netutil import ssl_wrap_socket
 from tornado import stack_context
 from tornado.util import bytes_type
 
@@ -711,8 +712,9 @@ class SSLIOStream(IOStream):
     def __init__(self, *args, **kwargs):
         """Creates an SSLIOStream.
 
-        If a dictionary is provided as keyword argument ssl_options,
-        it will be used as additional keyword arguments to ssl.wrap_socket.
+        The ``ssl_options`` keyword argument may either be a dictionary
+        of keywords arguments for `ssl.wrap_socket`, or an `ssl.SSLContext`
+        object.
         """
         self._ssl_options = kwargs.pop('ssl_options', {})
         super(SSLIOStream, self).__init__(*args, **kwargs)
@@ -787,9 +789,8 @@ class SSLIOStream(IOStream):
         # user callbacks are enqueued asynchronously on the IOLoop,
         # but since _handle_events calls _handle_connect immediately
         # followed by _handle_write we need this to be synchronous.
-        self.socket = ssl.wrap_socket(self.socket,
-                                      do_handshake_on_connect=False,
-                                      **self._ssl_options)
+        self.socket = ssl_wrap_socket(self.socket, self._ssl_options,
+                                      do_handshake_on_connect=False)
         super(SSLIOStream, self)._handle_connect()
 
     def read_from_fd(self):
index 4dc82d25094455e5502ed1af45f790f9c0f5d2a4..fbdef2e38f90da1f16deb6375aea9c5c20051f7b 100644 (file)
@@ -21,6 +21,7 @@ from __future__ import absolute_import, division, print_function, with_statement
 import errno
 import os
 import socket
+import ssl
 import stat
 
 from tornado.concurrent import dummy_executor, run_on_executor
@@ -140,3 +141,53 @@ class Resolver(object):
     @run_on_executor
     def getaddrinfo(self, *args, **kwargs):
         return socket.getaddrinfo(*args, **kwargs)
+
+
+# These are the keyword arguments to ssl.wrap_socket that must be translated
+# to their SSLContext equivalents (the other arguments are still passed
+# to SSLContext.wrap_socket).
+_SSL_CONTEXT_KEYWORDS = frozenset(['ssl_version', 'certfile', 'keyfile',
+                                   'cert_reqs', 'ca_certs', 'ciphers'])
+
+def ssl_options_to_context(ssl_options):
+    """Try to Convert an ssl_options dictionary to an SSLContext object.
+
+    The ``ssl_options`` dictionary contains keywords to be passed to
+    `ssl.wrap_sockets`.  In Python 3.2+, `ssl.SSLContext` objects can
+    be used instead.  This function converts the dict form to its
+    `SSLContext` equivalent, and may be used when a component which
+    accepts both forms needs to upgrade to the `SSLContext` version
+    to use features like SNI or NPN.
+    """
+    if isinstance(ssl_options, dict):
+        assert all(k in _SSL_CONTEXT_KEYWORDS for k in ssl_options), ssl_options
+    if (not hasattr(ssl, 'SSLContext') or
+        isinstance(ssl_options, ssl.SSLContext)):
+        return ssl_options
+    context = ssl.SSLContext(
+        ssl_options.get('ssl_version', ssl.PROTOCOL_SSLv23))
+    if 'certfile' in ssl_options:
+        context.load_cert_chain(ssl_options['certfile'], ssl_options.get('keyfile', None))
+    if 'cert_reqs' in ssl_options:
+        context.verify_mode = ssl_options['cert_reqs']
+    if 'ca_certs' in ssl_options:
+        context.load_verify_locations(ssl_options['ca_certs'])
+    if 'ciphers' in ssl_options:
+        context.set_ciphers(ssl_options['ciphers'])
+    return context
+
+
+def ssl_wrap_socket(socket, ssl_options, **kwargs):
+    """Returns an `ssl.SSLSocket` wrapping the given socket.
+
+    ``ssl_options`` may be either a dictionary (as accepted by
+    `ssl_options_to_context) or an `ssl.SSLContext` object.
+    Additional keyword arguments are passed to `wrap_socket`
+    (either the `SSLContext` method or the `ssl` module function
+    as appropriate).
+    """
+    context = ssl_options_to_context(ssl_options)
+    if hasattr(ssl, 'SSLContext') and isinstance(context, ssl.SSLContext):
+        return context.wrap_socket(socket, **kwargs)
+    else:
+        return ssl.wrap_socket(socket, **dict(context, **kwargs))
index af30aa25673cb377046d4bb42550fd1f570d8046..52ed70b1d733dcfa1db79c64b3a673aab9aefbe8 100644 (file)
@@ -25,7 +25,7 @@ import ssl
 from tornado.log import app_log
 from tornado.ioloop import IOLoop
 from tornado.iostream import IOStream, SSLIOStream
-from tornado.netutil import bind_sockets, add_accept_handler
+from tornado.netutil import bind_sockets, add_accept_handler, ssl_wrap_socket
 from tornado import process
 
 class TCPServer(object):
@@ -89,7 +89,7 @@ class TCPServer(object):
         # connect. This doesn't verify that the keys are legitimate, but
         # the SSL module doesn't do that until there is a connected socket
         # which seems like too much work
-        if self.ssl_options is not None:
+        if self.ssl_options is not None and isinstance(self.ssl_options, dict):
             # Only certfile is required: it can contain both keys
             if 'certfile' not in self.ssl_options:
                 raise KeyError('missing key "certfile" in ssl_options')
@@ -206,10 +206,10 @@ class TCPServer(object):
         if self.ssl_options is not None:
             assert ssl, "Python 2.6+ and OpenSSL required for SSL"
             try:
-                connection = ssl.wrap_socket(connection,
+                connection = ssl_wrap_socket(connection,
+                                             self.ssl_options,
                                              server_side=True,
-                                             do_handshake_on_connect=False,
-                                             **self.ssl_options)
+                                             do_handshake_on_connect=False)
             except ssl.SSLError as err:
                 if err.args[0] == ssl.SSL_ERROR_EOF:
                     return connection.close()
index 6038dc290119fd6608d7474cef6ac4bace939a6c..814d54e14d37a56d33d666c213b0ef4152f95b7b 100644 (file)
@@ -8,6 +8,7 @@ from tornado.httpserver import HTTPServer
 from tornado.httputil import HTTPHeaders
 from tornado.iostream import IOStream
 from tornado.log import gen_log
+from tornado.netutil import ssl_options_to_context
 from tornado.simple_httpclient import SimpleAsyncHTTPClient
 from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, ExpectLog
 from tornado.test.util import unittest
@@ -114,6 +115,15 @@ class TLSv1Test(BaseSSLTest, SSLTestMixin):
         return ssl.PROTOCOL_TLSv1
 
 
+@unittest.skipIf(not hasattr(ssl, 'SSLContext'), 'ssl.SSLContext not present')
+class SSLContextTest(BaseSSLTest, SSLTestMixin):
+    def get_ssl_options(self):
+        context = ssl_options_to_context(
+            AsyncHTTPSTestCase.get_ssl_options(self))
+        assert isinstance(context, ssl.SSLContext)
+        return context
+
+
 class BadSSLOptionsTest(unittest.TestCase):
     def test_missing_arguments(self):
         application = Application()
index a7793ec0abe7bb21621b7808b8877b0127f65bc8..6d9f878346dc71783e52a29954ec0124db4566f8 100644 (file)
@@ -3,6 +3,7 @@ from tornado import netutil
 from tornado.ioloop import IOLoop
 from tornado.iostream import IOStream, SSLIOStream, PipeIOStream
 from tornado.log import gen_log, app_log
+from tornado.netutil import ssl_wrap_socket
 from tornado.stack_context import NullContext
 from tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, bind_unused_port, ExpectLog
 from tornado.test.util import unittest, skipIfNonUnix
@@ -467,6 +468,27 @@ class TestIOStreamSSL(TestIOStreamMixin, AsyncTestCase):
         return SSLIOStream(connection, io_loop=self.io_loop, **kwargs)
 
 
+# This will run some tests that are basically redundant but it's the
+# simplest way to make sure that it works to pass an SSLContext
+# instead of an ssl_options dict to the SSLIOStream constructor.
+@unittest.skipIf(not hasattr(ssl, 'SSLContext'), 'ssl.SSLContext not present')
+class TestIOStreamSSLContext(TestIOStreamMixin, AsyncTestCase):
+    def _make_server_iostream(self, connection, **kwargs):
+        context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+        context.load_cert_chain(
+            os.path.join(os.path.dirname(__file__), 'test.crt'),
+            os.path.join(os.path.dirname(__file__), 'test.key'))
+        connection = ssl_wrap_socket(connection, context,
+                                     server_side=True,
+                                     do_handshake_on_connect=False)
+        return SSLIOStream(connection, io_loop=self.io_loop, **kwargs)
+
+    def _make_client_iostream(self, connection, **kwargs):
+        context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+        return SSLIOStream(connection, io_loop=self.io_loop,
+                           ssl_options=context, **kwargs)
+
+
 @skipIfNonUnix
 class TestPipeIOStream(AsyncTestCase):
     def test_pipe_iostream(self):
index b15e97e2a31eca3854c9896d18b1eaa58c3fd239..7cd9bc18bb372c616247445629815c1e3e38793b 100644 (file)
@@ -216,3 +216,6 @@ In progress
   method that should be used instead of reaching into its ``stream``
   attribute.
 * `tornado.netutil.TCPServer` has moved to its own module, `tornado.tcpserver`.
+* On python 3.2+, methods that take an ``ssl_options`` argument (on
+  `SSLIOStream`, `TCPServer`, and `HTTPServer`) now accept either a
+  dictionary of options or an `ssl.SSLContext` object.