]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
Issue #28085: Add PROTOCOL_TLS_CLIENT and PROTOCOL_TLS_SERVER for SSLContext
authorChristian Heimes <christian@python.org>
Sun, 11 Sep 2016 22:01:11 +0000 (00:01 +0200)
committerChristian Heimes <christian@python.org>
Sun, 11 Sep 2016 22:01:11 +0000 (00:01 +0200)
Doc/library/ssl.rst
Lib/ssl.py
Lib/test/test_ssl.py
Modules/_ssl.c

index e942f44ae1f0487a3a4d10bdf6d546ac399f161a..d68b8d035beae385e24689405fba2d292dbe6f80 100644 (file)
@@ -610,6 +610,22 @@ Constants
 
    .. versionadded:: 3.6
 
+.. data:: PROTOCOL_TLS_CLIENT
+
+   Auto-negotiate the the highest protocol version like :data:`PROTOCOL_SSLv23`,
+   but only support client-side :class:`SSLSocket` connections. The protocol
+   enables :data:`CERT_REQUIRED` and :attr:`~SSLContext.check_hostname` by
+   default.
+
+   .. versionadded:: 3.6
+
+.. data:: PROTOCOL_TLS_SERVER
+
+   Auto-negotiate the the highest protocol version like :data:`PROTOCOL_SSLv23`,
+   but only support server-side :class:`SSLSocket` connections.
+
+   .. versionadded:: 3.6
+
 .. data:: PROTOCOL_SSLv23
 
    Alias for data:`PROTOCOL_TLS`.
@@ -2235,18 +2251,20 @@ Protocol versions
 
 SSL versions 2 and 3 are considered insecure and are therefore dangerous to
 use.  If you want maximum compatibility between clients and servers, it is
-recommended to use :const:`PROTOCOL_TLS` as the protocol version and then
-disable SSLv2 and SSLv3 explicitly using the :data:`SSLContext.options`
-attribute::
+recommended to use :const:`PROTOCOL_TLS_CLIENT` or
+:const:`PROTOCOL_TLS_SERVER` as the protocol version. SSLv2 and SSLv3 are
+disabled by default.
+
+   client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+   client_context.options |= ssl.OP_NO_TLSv1
+   client_context.options |= ssl.OP_NO_TLSv1_1
 
-   context = ssl.SSLContext(ssl.PROTOCOL_TLS)
-   context.options |= ssl.OP_NO_SSLv2
-   context.options |= ssl.OP_NO_SSLv3
-   context.options |= ssl.OP_NO_TLSv1
-   context.options |= ssl.OP_NO_TLSv1_1
 
 The SSL context created above will only allow TLSv1.2 and later (if
-supported by your system) connections.
+supported by your system) connections to a server. :const:`PROTOCOL_TLS_CLIENT`
+implies certificate validation and hostname checks by default. You have to
+load certificates into the context.
+
 
 Cipher selection
 ''''''''''''''''
@@ -2257,8 +2275,9 @@ enabled when negotiating a SSL session is possible through the
 ssl module disables certain weak ciphers by default, but you may want
 to further restrict the cipher choice. Be sure to read OpenSSL's documentation
 about the `cipher list format <https://www.openssl.org/docs/apps/ciphers.html#CIPHER-LIST-FORMAT>`_.
-If you want to check which ciphers are enabled by a given cipher list, use the
-``openssl ciphers`` command on your system.
+If you want to check which ciphers are enabled by a given cipher list, use
+:meth:`SSLContext.get_ciphers` or the ``openssl ciphers`` command on your
+system.
 
 Multi-processing
 ^^^^^^^^^^^^^^^^
index df5e98efc7bc405dcd633c639def00088d4dba03..8ad4a339a933d726b5027899b4d3cf1923c233a7 100644 (file)
@@ -52,6 +52,8 @@ PROTOCOL_SSLv2
 PROTOCOL_SSLv3
 PROTOCOL_SSLv23
 PROTOCOL_TLS
+PROTOCOL_TLS_CLIENT
+PROTOCOL_TLS_SERVER
 PROTOCOL_TLSv1
 PROTOCOL_TLSv1_1
 PROTOCOL_TLSv1_2
index 61744ae95ad8613833ac8d57fe6a86edca657aa0..557b6dec5b501b1695b66914c85c3c0b184b9d12 100644 (file)
@@ -1342,6 +1342,17 @@ class ContextTests(unittest.TestCase):
         ctx.check_hostname = False
         self.assertFalse(ctx.check_hostname)
 
+    def test_context_client_server(self):
+        # PROTOCOL_TLS_CLIENT has sane defaults
+        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+        self.assertTrue(ctx.check_hostname)
+        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
+
+        # PROTOCOL_TLS_SERVER has different but also sane defaults
+        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+        self.assertFalse(ctx.check_hostname)
+        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
+
 
 class SSLErrorTests(unittest.TestCase):
 
@@ -2280,12 +2291,33 @@ if _have_threads:
             if support.verbose:
                 sys.stdout.write("\n")
             for protocol in PROTOCOLS:
+                if protocol in {ssl.PROTOCOL_TLS_CLIENT, ssl.PROTOCOL_TLS_SERVER}:
+                    continue
                 with self.subTest(protocol=ssl._PROTOCOL_NAMES[protocol]):
                     context = ssl.SSLContext(protocol)
                     context.load_cert_chain(CERTFILE)
                     server_params_test(context, context,
                                        chatty=True, connectionchatty=True)
 
+            client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+            client_context.load_verify_locations(SIGNING_CA)
+            server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+            # server_context.load_verify_locations(SIGNING_CA)
+            server_context.load_cert_chain(SIGNED_CERTFILE2)
+
+            with self.subTest(client='PROTOCOL_TLS_CLIENT', server='PROTOCOL_TLS_SERVER'):
+                server_params_test(client_context=client_context,
+                                   server_context=server_context,
+                                   chatty=True, connectionchatty=True,
+                                   sni_name='fakehostname')
+
+            with self.subTest(client='PROTOCOL_TLS_SERVER', server='PROTOCOL_TLS_CLIENT'):
+                with self.assertRaises(ssl.SSLError):
+                    server_params_test(client_context=server_context,
+                                       server_context=client_context,
+                                       chatty=True, connectionchatty=True,
+                                       sni_name='fakehostname')
+
         def test_getpeercert(self):
             if support.verbose:
                 sys.stdout.write("\n")
index 4d8e7e7a39d1ba2b56a24e121b6bc29e12530f42..736fc1d81046cd3d7dff33facbaad92c893ed81f 100644 (file)
@@ -140,6 +140,8 @@ struct py_ssl_library_code {
 #endif
 
 #define TLS_method SSLv23_method
+#define TLS_client_method SSLv23_client_method
+#define TLS_server_method SSLv23_server_method
 
 static int X509_NAME_ENTRY_set(const X509_NAME_ENTRY *ne)
 {
@@ -233,14 +235,16 @@ enum py_ssl_cert_requirements {
 enum py_ssl_version {
     PY_SSL_VERSION_SSL2,
     PY_SSL_VERSION_SSL3=1,
-    PY_SSL_VERSION_TLS,
+    PY_SSL_VERSION_TLS, /* SSLv23 */
 #if HAVE_TLSv1_2
     PY_SSL_VERSION_TLS1,
     PY_SSL_VERSION_TLS1_1,
-    PY_SSL_VERSION_TLS1_2
+    PY_SSL_VERSION_TLS1_2,
 #else
-    PY_SSL_VERSION_TLS1
+    PY_SSL_VERSION_TLS1,
 #endif
+    PY_SSL_VERSION_TLS_CLIENT=0x10,
+    PY_SSL_VERSION_TLS_SERVER,
 };
 
 #ifdef WITH_THREAD
@@ -2566,6 +2570,33 @@ static PyTypeObject PySSLSocket_Type = {
  * _SSLContext objects
  */
 
+static int
+_set_verify_mode(SSL_CTX *ctx, enum py_ssl_cert_requirements n)
+{
+    int mode;
+    int (*verify_cb)(int, X509_STORE_CTX *) = NULL;
+
+    switch(n) {
+    case PY_SSL_CERT_NONE:
+        mode = SSL_VERIFY_NONE;
+        break;
+    case PY_SSL_CERT_OPTIONAL:
+        mode = SSL_VERIFY_PEER;
+        break;
+    case PY_SSL_CERT_REQUIRED:
+        mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT;
+        break;
+    default:
+         PyErr_SetString(PyExc_ValueError,
+                        "invalid value for verify_mode");
+        return -1;
+    }
+    /* keep current verify cb */
+    verify_cb = SSL_CTX_get_verify_callback(ctx);
+    SSL_CTX_set_verify(ctx, mode, verify_cb);
+    return 0;
+}
+
 /*[clinic input]
 @classmethod
 _ssl._SSLContext.__new__
@@ -2602,8 +2633,12 @@ _ssl__SSLContext_impl(PyTypeObject *type, int proto_version)
     else if (proto_version == PY_SSL_VERSION_SSL2)
         ctx = SSL_CTX_new(SSLv2_method());
 #endif
-    else if (proto_version == PY_SSL_VERSION_TLS)
+    else if (proto_version == PY_SSL_VERSION_TLS) /* SSLv23 */
         ctx = SSL_CTX_new(TLS_method());
+    else if (proto_version == PY_SSL_VERSION_TLS_CLIENT)
+        ctx = SSL_CTX_new(TLS_client_method());
+    else if (proto_version == PY_SSL_VERSION_TLS_SERVER)
+        ctx = SSL_CTX_new(TLS_server_method());
     else
         proto_version = -1;
     PySSL_END_ALLOW_THREADS
@@ -2636,9 +2671,20 @@ _ssl__SSLContext_impl(PyTypeObject *type, int proto_version)
     self->set_hostname = NULL;
 #endif
     /* Don't check host name by default */
-    self->check_hostname = 0;
+    if (proto_version == PY_SSL_VERSION_TLS_CLIENT) {
+        self->check_hostname = 1;
+        if (_set_verify_mode(self->ctx, PY_SSL_CERT_REQUIRED) == -1) {
+            Py_DECREF(self);
+            return NULL;
+        }
+    } else {
+        self->check_hostname = 0;
+        if (_set_verify_mode(self->ctx, PY_SSL_CERT_NONE) == -1) {
+            Py_DECREF(self);
+            return NULL;
+        }
+    }
     /* Defaults */
-    SSL_CTX_set_verify(self->ctx, SSL_VERIFY_NONE, NULL);
     options = SSL_OP_ALL & ~SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS;
     if (proto_version != PY_SSL_VERSION_SSL2)
         options |= SSL_OP_NO_SSLv2;
@@ -2982,28 +3028,16 @@ get_verify_mode(PySSLContext *self, void *c)
 static int
 set_verify_mode(PySSLContext *self, PyObject *arg, void *c)
 {
-    int n, mode;
+    int n;
     if (!PyArg_Parse(arg, "i", &n))
         return -1;
-    if (n == PY_SSL_CERT_NONE)
-        mode = SSL_VERIFY_NONE;
-    else if (n == PY_SSL_CERT_OPTIONAL)
-        mode = SSL_VERIFY_PEER;
-    else if (n == PY_SSL_CERT_REQUIRED)
-        mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT;
-    else {
-        PyErr_SetString(PyExc_ValueError,
-                        "invalid value for verify_mode");
-        return -1;
-    }
-    if (mode == SSL_VERIFY_NONE && self->check_hostname) {
+    if (n == PY_SSL_CERT_NONE && self->check_hostname) {
         PyErr_SetString(PyExc_ValueError,
                         "Cannot set verify_mode to CERT_NONE when "
                         "check_hostname is enabled.");
         return -1;
     }
-    SSL_CTX_set_verify(self->ctx, mode, NULL);
-    return 0;
+    return _set_verify_mode(self->ctx, n);
 }
 
 static PyObject *
@@ -5313,6 +5347,10 @@ PyInit__ssl(void)
                             PY_SSL_VERSION_TLS);
     PyModule_AddIntConstant(m, "PROTOCOL_TLS",
                             PY_SSL_VERSION_TLS);
+    PyModule_AddIntConstant(m, "PROTOCOL_TLS_CLIENT",
+                            PY_SSL_VERSION_TLS_CLIENT);
+    PyModule_AddIntConstant(m, "PROTOCOL_TLS_SERVER",
+                            PY_SSL_VERSION_TLS_SERVER);
     PyModule_AddIntConstant(m, "PROTOCOL_TLSv1",
                             PY_SSL_VERSION_TLS1);
 #if HAVE_TLSv1_2