]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Add server_hostname to IOStream.connect, for SNI and cert verification
authorBen Darnell <ben@bendarnell.com>
Mon, 21 Jan 2013 03:22:02 +0000 (22:22 -0500)
committerBen Darnell <ben@bendarnell.com>
Mon, 21 Jan 2013 03:33:03 +0000 (22:33 -0500)
SSL hostname verification now happens in SSLIOStream instead of
simple_httpclient (and supporting code has moved to netutil).

tornado/iostream.py
tornado/netutil.py
tornado/simple_httpclient.py
website/sphinx/releases/next.rst

index 43c52d74c97312f791dc386e1580b262d45db224..86cd68a89c59ad1eff2a88ee27a6c6068b4fef03 100644 (file)
@@ -37,7 +37,7 @@ import re
 
 from tornado import ioloop
 from tornado.log import gen_log, app_log
-from tornado.netutil import ssl_wrap_socket
+from tornado.netutil import ssl_wrap_socket, ssl_match_hostname, SSLCertificateError
 from tornado import stack_context
 from tornado.util import bytes_type
 
@@ -645,7 +645,7 @@ class IOStream(BaseIOStream):
     def write_to_fd(self, data):
         return self.socket.send(data)
 
-    def connect(self, address, callback=None):
+    def connect(self, address, callback=None, server_hostname=None):
         """Connects the socket to a remote address without blocking.
 
         May only be called if the socket passed to the constructor was
@@ -654,6 +654,11 @@ class IOStream(BaseIOStream):
         If callback is specified, it will be called when the
         connection is completed.
 
+        If specified, the ``server_hostname`` parameter will be used
+        in SSL connections for certificate validation (if requested in
+        the ``ssl_options``) and SNI (if supported; requires
+        Python 3.2+).
+
         Note that it is safe to call IOStream.write while the
         connection is pending, in which case the data will be written
         as soon as the connection is ready.  Calling IOStream read
@@ -722,6 +727,7 @@ class SSLIOStream(IOStream):
         self._handshake_reading = False
         self._handshake_writing = False
         self._ssl_connect_callback = None
+        self._server_hostname = None
 
     def reading(self):
         return self._handshake_reading or super(SSLIOStream, self).reading()
@@ -759,11 +765,41 @@ class SSLIOStream(IOStream):
                 return self.close(exc_info=True)
         else:
             self._ssl_accepting = False
+            if not self._verify_cert(self.socket.getpeercert()):
+                self.close()
+                return
             if self._ssl_connect_callback is not None:
                 callback = self._ssl_connect_callback
                 self._ssl_connect_callback = None
                 self._run_callback(callback)
 
+    def _verify_cert(self, peercert):
+        """Returns True if peercert is valid according to the configured
+        validation mode and hostname.
+
+        The ssl handshake already tested the certificate for a valid
+        CA signature; the only thing that remains is to check
+        the hostname.
+        """
+        if isinstance(self._ssl_options, dict):
+            verify_mode = self._ssl_options.get('cert_reqs', ssl.CERT_NONE)
+        elif isinstance(self._ssl_options, ssl.SSLContext):
+            verify_mode = self._ssl_options.verify_mode
+        assert verify_mode in (ssl.CERT_NONE, ssl.CERT_REQUIRED, ssl.CERT_OPTIONAL)
+        if verify_mode == ssl.CERT_NONE or self._server_hostname is None:
+            return True
+        cert = self.socket.getpeercert()
+        if cert is None and verify_mode == ssl.CERT_REQUIRED:
+            gen_log.warning("No SSL certificate given")
+            return False
+        try:
+            ssl_match_hostname(peercert, self._server_hostname)
+        except SSLCertificateError:
+            gen_log.warning("Invalid SSL certificate", exc_info=True)
+            return False
+        else:
+            return True
+
     def _handle_read(self):
         if self._ssl_accepting:
             self._do_ssl_handshake()
@@ -776,10 +812,11 @@ class SSLIOStream(IOStream):
             return
         super(SSLIOStream, self)._handle_write()
 
-    def connect(self, address, callback=None):
+    def connect(self, address, callback=None, server_hostname=None):
         # Save the user's callback and run it after the ssl handshake
         # has completed.
         self._ssl_connect_callback = callback
+        self._server_hostname = server_hostname
         super(SSLIOStream, self).connect(address, callback=None)
 
     def _handle_connect(self):
@@ -790,6 +827,7 @@ class SSLIOStream(IOStream):
         # but since _handle_events calls _handle_connect immediately
         # followed by _handle_write we need this to be synchronous.
         self.socket = ssl_wrap_socket(self.socket, self._ssl_options,
+                                      server_hostname=self._server_hostname,
                                       do_handshake_on_connect=False)
         super(SSLIOStream, self)._handle_connect()
 
index fbdef2e38f90da1f16deb6375aea9c5c20051f7b..a3e9568686969248768d7e0f204c18feaf988859 100644 (file)
@@ -20,6 +20,7 @@ from __future__ import absolute_import, division, print_function, with_statement
 
 import errno
 import os
+import re
 import socket
 import ssl
 import stat
@@ -177,7 +178,7 @@ def ssl_options_to_context(ssl_options):
     return context
 
 
-def ssl_wrap_socket(socket, ssl_options, **kwargs):
+def ssl_wrap_socket(socket, ssl_options, server_hostname=None, **kwargs):
     """Returns an `ssl.SSLSocket` wrapping the given socket.
 
     ``ssl_options`` may be either a dictionary (as accepted by
@@ -188,6 +189,77 @@ def ssl_wrap_socket(socket, ssl_options, **kwargs):
     """
     context = ssl_options_to_context(ssl_options)
     if hasattr(ssl, 'SSLContext') and isinstance(context, ssl.SSLContext):
-        return context.wrap_socket(socket, **kwargs)
+        if server_hostname is not None and getattr(ssl, 'HAS_SNI'):
+            # Python doesn't have server-side SNI support so we can't
+            # really unittest this, but it can be manually tested with
+            # python3.2 -m tornado.httpclient https://sni.velox.ch
+            return context.wrap_socket(socket, server_hostname=server_hostname,
+                                       **kwargs)
+        else:
+            return context.wrap_socket(socket, **kwargs)
     else:
         return ssl.wrap_socket(socket, **dict(context, **kwargs))
+
+if hasattr(ssl, 'match_hostname'):  # python 3.2+
+    ssl_match_hostname = ssl.match_hostname
+    SSLCertificateError = ssl.CertificateError
+else:
+    # match_hostname was added to the standard library ssl module in python 3.2.
+    # The following code was backported for older releases and copied from
+    # https://bitbucket.org/brandon/backports.ssl_match_hostname
+    class SSLCertificateError(ValueError):
+        pass
+
+
+    def _dnsname_to_pat(dn):
+        pats = []
+        for frag in dn.split(r'.'):
+            if frag == '*':
+                # When '*' is a fragment by itself, it matches a non-empty dotless
+                # fragment.
+                pats.append('[^.]+')
+            else:
+                # Otherwise, '*' matches any dotless fragment.
+                frag = re.escape(frag)
+                pats.append(frag.replace(r'\*', '[^.]*'))
+        return re.compile(r'\A' + r'\.'.join(pats) + r'\Z', re.IGNORECASE)
+
+
+    def ssl_match_hostname(cert, hostname):
+        """Verify that *cert* (in decoded format as returned by
+        SSLSocket.getpeercert()) matches the *hostname*.  RFC 2818 rules
+        are mostly followed, but IP addresses are not accepted for *hostname*.
+
+        CertificateError is raised on failure. On success, the function
+        returns nothing.
+        """
+        if not cert:
+            raise ValueError("empty or no certificate")
+        dnsnames = []
+        san = cert.get('subjectAltName', ())
+        for key, value in san:
+            if key == 'DNS':
+                if _dnsname_to_pat(value).match(hostname):
+                    return
+                dnsnames.append(value)
+        if not san:
+            # The subject is only checked when subjectAltName is empty
+            for sub in cert.get('subject', ()):
+                for key, value in sub:
+                    # XXX according to RFC 2818, the most specific Common Name
+                    # must be used.
+                    if key == 'commonName':
+                        if _dnsname_to_pat(value).match(hostname):
+                            return
+                        dnsnames.append(value)
+        if len(dnsnames) > 1:
+            raise SSLCertificateError("hostname %r "
+                                      "doesn't match either of %s"
+                                      % (hostname, ', '.join(map(repr, dnsnames))))
+        elif len(dnsnames) == 1:
+            raise SSLCertificateError("hostname %r "
+                                      "doesn't match %r"
+                                      % (hostname, dnsnames[0]))
+        else:
+            raise SSLCertificateError("no appropriate commonName or "
+                                      "subjectAltName fields were found")
index 65b6a6b3b8cc1910fa4b27cb6252c1829c6b2608..7827a7bfa06b07ae2d2f529193e026117c9be63c 100644 (file)
@@ -212,7 +212,10 @@ class _HTTPConnection(object):
                 self.start_time + timeout,
                 stack_context.wrap(self._on_timeout))
         self.stream.set_close_callback(self._on_close)
-        self.stream.connect(sockaddr, self._on_connect)
+        # ipv6 addresses are broken (in self.parsed.hostname) until
+        # 2.7, here is correctly parsed value calculated in __init__
+        self.stream.connect(sockaddr, self._on_connect,
+                            server_hostname=self.parsed_hostname)
 
     def _on_timeout(self):
         self._timeout = None
@@ -227,14 +230,6 @@ class _HTTPConnection(object):
             self._timeout = self.io_loop.add_timeout(
                 self.start_time + self.request.request_timeout,
                 stack_context.wrap(self._on_timeout))
-        if (self.request.validate_cert and
-                isinstance(self.stream, SSLIOStream)):
-            match_hostname(self.stream.socket.getpeercert(),
-                           # ipv6 addresses are broken (in
-                           # self.parsed.hostname) until 2.7, here is
-                           # correctly parsed value calculated in
-                           # __init__
-                           self.parsed_hostname)
         if (self.request.method not in self._SUPPORTED_METHODS and
                 not self.request.allow_nonstandard_methods):
             raise KeyError("unknown method %s" % self.request.method)
@@ -481,66 +476,6 @@ class _HTTPConnection(object):
         self.stream.read_until(b"\r\n", self._on_chunk_length)
 
 
-# match_hostname was added to the standard library ssl module in python 3.2.
-# The following code was backported for older releases and copied from
-# https://bitbucket.org/brandon/backports.ssl_match_hostname
-class CertificateError(ValueError):
-    pass
-
-
-def _dnsname_to_pat(dn):
-    pats = []
-    for frag in dn.split(r'.'):
-        if frag == '*':
-            # When '*' is a fragment by itself, it matches a non-empty dotless
-            # fragment.
-            pats.append('[^.]+')
-        else:
-            # Otherwise, '*' matches any dotless fragment.
-            frag = re.escape(frag)
-            pats.append(frag.replace(r'\*', '[^.]*'))
-    return re.compile(r'\A' + r'\.'.join(pats) + r'\Z', re.IGNORECASE)
-
-
-def match_hostname(cert, hostname):
-    """Verify that *cert* (in decoded format as returned by
-    SSLSocket.getpeercert()) matches the *hostname*.  RFC 2818 rules
-    are mostly followed, but IP addresses are not accepted for *hostname*.
-
-    CertificateError is raised on failure. On success, the function
-    returns nothing.
-    """
-    if not cert:
-        raise ValueError("empty or no certificate")
-    dnsnames = []
-    san = cert.get('subjectAltName', ())
-    for key, value in san:
-        if key == 'DNS':
-            if _dnsname_to_pat(value).match(hostname):
-                return
-            dnsnames.append(value)
-    if not san:
-        # The subject is only checked when subjectAltName is empty
-        for sub in cert.get('subject', ()):
-            for key, value in sub:
-                # XXX according to RFC 2818, the most specific Common Name
-                # must be used.
-                if key == 'commonName':
-                    if _dnsname_to_pat(value).match(hostname):
-                        return
-                    dnsnames.append(value)
-    if len(dnsnames) > 1:
-        raise CertificateError("hostname %r "
-                               "doesn't match either of %s"
-                               % (hostname, ', '.join(map(repr, dnsnames))))
-    elif len(dnsnames) == 1:
-        raise CertificateError("hostname %r "
-                               "doesn't match %r"
-                               % (hostname, dnsnames[0]))
-    else:
-        raise CertificateError("no appropriate commonName or "
-                               "subjectAltName fields were found")
-
 if __name__ == "__main__":
     AsyncHTTPClient.configure(SimpleAsyncHTTPClient)
     main()
index 7cd9bc18bb372c616247445629815c1e3e38793b..aef8dae6731362bc06b871fa8d54176f1c668312 100644 (file)
@@ -219,3 +219,7 @@ In progress
 * On python 3.2+, methods that take an ``ssl_options`` argument (on
   `SSLIOStream`, `TCPServer`, and `HTTPServer`) now accept either a
   dictionary of options or an `ssl.SSLContext` object.
+* `IOStream.connect` now has an optional ``server_hostname`` argument
+  which will be used for SSL certificate validation when applicable.
+  Additionally, when supported (on Python 3.2+), this hostname
+  will be sent via SNI (and this is supported by `tornado.simple_httpclient`)