]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-63284: Add support for TLS-PSK (pre-shared key) to the ssl module (#103181)
authorGrant Ramsay <grant.ramsay@hotmail.com>
Mon, 27 Nov 2023 04:01:44 +0000 (17:01 +1300)
committerGitHub <noreply@github.com>
Mon, 27 Nov 2023 04:01:44 +0000 (04:01 +0000)
Add support for TLS-PSK (pre-shared key) to the ssl module.

---------

Co-authored-by: Oleg Iarygin <oleg@arhadthedev.net>
Co-authored-by: Gregory P. Smith <greg@krypto.org>
Doc/library/ssl.rst
Include/internal/pycore_global_objects_fini_generated.h
Include/internal/pycore_global_strings.h
Include/internal/pycore_runtime_init_generated.h
Include/internal/pycore_unicodeobject_generated.h
Lib/test/test_ssl.py
Misc/ACKS
Misc/NEWS.d/next/Library/2023-11-27-12-41-23.gh-issue-63284.q2Qi9q.rst [new file with mode: 0644]
Modules/_ssl.c
Modules/clinic/_ssl.c.h

index 21b38ae62fe02f8fa93c5c2b34fb7de9804b8347..206294528e0016222702a7083298ef6d8ad1e852 100644 (file)
@@ -2006,6 +2006,94 @@ to speed up repeated connections from the same clients.
          >>> ssl.create_default_context().verify_mode  # doctest: +SKIP
          <VerifyMode.CERT_REQUIRED: 2>
 
+.. method:: SSLContext.set_psk_client_callback(callback)
+
+   Enables TLS-PSK (pre-shared key) authentication on a client-side connection.
+
+   In general, certificate based authentication should be preferred over this method.
+
+   The parameter ``callback`` is a callable object with the signature:
+   ``def callback(hint: str | None) -> tuple[str | None, bytes]``.
+   The ``hint`` parameter is an optional identity hint sent by the server.
+   The return value is a tuple in the form (client-identity, psk).
+   Client-identity is an optional string which may be used by the server to
+   select a corresponding PSK for the client. The string must be less than or
+   equal to ``256`` octets when UTF-8 encoded. PSK is a
+   :term:`bytes-like object` representing the pre-shared key. Return a zero
+   length PSK to reject the connection.
+
+   Setting ``callback`` to :const:`None` removes any existing callback.
+
+   .. note::
+      When using TLS 1.3:
+
+      - the ``hint`` parameter is always :const:`None`.
+      - client-identity must be a non-empty string.
+
+   Example usage::
+
+      context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+      context.check_hostname = False
+      context.verify_mode = ssl.CERT_NONE
+      context.maximum_version = ssl.TLSVersion.TLSv1_2
+      context.set_ciphers('PSK')
+
+      # A simple lambda:
+      psk = bytes.fromhex('c0ffee')
+      context.set_psk_client_callback(lambda hint: (None, psk))
+
+      # A table using the hint from the server:
+      psk_table = { 'ServerId_1': bytes.fromhex('c0ffee'),
+                    'ServerId_2': bytes.fromhex('facade')
+      }
+      def callback(hint):
+          return 'ClientId_1', psk_table.get(hint, b'')
+      context.set_psk_client_callback(callback)
+
+   .. versionadded:: 3.13
+
+.. method:: SSLContext.set_psk_server_callback(callback, identity_hint=None)
+
+   Enables TLS-PSK (pre-shared key) authentication on a server-side connection.
+
+   In general, certificate based authentication should be preferred over this method.
+
+   The parameter ``callback`` is a callable object with the signature:
+   ``def callback(identity: str | None) -> bytes``.
+   The ``identity`` parameter is an optional identity sent by the client which can
+   be used to select a corresponding PSK.
+   The return value is a :term:`bytes-like object` representing the pre-shared key.
+   Return a zero length PSK to reject the connection.
+
+   Setting ``callback`` to :const:`None` removes any existing callback.
+
+   The parameter ``identity_hint`` is an optional identity hint string sent to
+   the client. The string must be less than or equal to ``256`` octets when
+   UTF-8 encoded.
+
+   .. note::
+      When using TLS 1.3 the ``identity_hint`` parameter is not sent to the client.
+
+   Example usage::
+
+      context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+      context.maximum_version = ssl.TLSVersion.TLSv1_2
+      context.set_ciphers('PSK')
+
+      # A simple lambda:
+      psk = bytes.fromhex('c0ffee')
+      context.set_psk_server_callback(lambda identity: psk)
+
+      # A table using the identity of the client:
+      psk_table = { 'ClientId_1': bytes.fromhex('c0ffee'),
+                    'ClientId_2': bytes.fromhex('facade')
+      }
+      def callback(identity):
+          return psk_table.get(identity, b'')
+      context.set_psk_server_callback(callback, 'ServerId_1')
+
+   .. versionadded:: 3.13
+
 .. index:: single: certificates
 
 .. index:: single: X509 certificate
index 0808076f44de3101b7bddfe8114d83aa134eb294..89ec8cbbbcd649022feefe579e04c86021f7cae3 100644 (file)
@@ -826,6 +826,7 @@ _PyStaticObjects_CheckRefcnt(PyInterpreterState *interp) {
     _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(call));
     _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(call_exception_handler));
     _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(call_soon));
+    _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(callback));
     _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(cancel));
     _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(capath));
     _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(category));
@@ -971,6 +972,7 @@ _PyStaticObjects_CheckRefcnt(PyInterpreterState *interp) {
     _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(hook));
     _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(id));
     _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(ident));
+    _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(identity_hint));
     _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(ignore));
     _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(imag));
     _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(importlib));
index 8d22a9ba261010c50e4c8c9a75b2c37ae83a22e3..62c3ee3ae2a0bd8f53239be7740ce5776a994601 100644 (file)
@@ -315,6 +315,7 @@ struct _Py_global_strings {
         STRUCT_FOR_ID(call)
         STRUCT_FOR_ID(call_exception_handler)
         STRUCT_FOR_ID(call_soon)
+        STRUCT_FOR_ID(callback)
         STRUCT_FOR_ID(cancel)
         STRUCT_FOR_ID(capath)
         STRUCT_FOR_ID(category)
@@ -460,6 +461,7 @@ struct _Py_global_strings {
         STRUCT_FOR_ID(hook)
         STRUCT_FOR_ID(id)
         STRUCT_FOR_ID(ident)
+        STRUCT_FOR_ID(identity_hint)
         STRUCT_FOR_ID(ignore)
         STRUCT_FOR_ID(imag)
         STRUCT_FOR_ID(importlib)
index d41a7478db663fd62f869c8cf07ec85b0955b6eb..1defa39f816e78297ded144d0be88b026b636868 100644 (file)
@@ -824,6 +824,7 @@ extern "C" {
     INIT_ID(call), \
     INIT_ID(call_exception_handler), \
     INIT_ID(call_soon), \
+    INIT_ID(callback), \
     INIT_ID(cancel), \
     INIT_ID(capath), \
     INIT_ID(category), \
@@ -969,6 +970,7 @@ extern "C" {
     INIT_ID(hook), \
     INIT_ID(id), \
     INIT_ID(ident), \
+    INIT_ID(identity_hint), \
     INIT_ID(ignore), \
     INIT_ID(imag), \
     INIT_ID(importlib), \
index 0c02e902b308e31fd2fb5a8ae1eea18c2bf2d002..be9baa3eebecfcf0c8b206fac1af5908ba6280f0 100644 (file)
@@ -786,6 +786,9 @@ _PyUnicode_InitStaticStrings(PyInterpreterState *interp) {
     string = &_Py_ID(call_soon);
     assert(_PyUnicode_CheckConsistency(string, 1));
     _PyUnicode_InternInPlace(interp, &string);
+    string = &_Py_ID(callback);
+    assert(_PyUnicode_CheckConsistency(string, 1));
+    _PyUnicode_InternInPlace(interp, &string);
     string = &_Py_ID(cancel);
     assert(_PyUnicode_CheckConsistency(string, 1));
     _PyUnicode_InternInPlace(interp, &string);
@@ -1221,6 +1224,9 @@ _PyUnicode_InitStaticStrings(PyInterpreterState *interp) {
     string = &_Py_ID(ident);
     assert(_PyUnicode_CheckConsistency(string, 1));
     _PyUnicode_InternInPlace(interp, &string);
+    string = &_Py_ID(identity_hint);
+    assert(_PyUnicode_CheckConsistency(string, 1));
+    _PyUnicode_InternInPlace(interp, &string);
     string = &_Py_ID(ignore);
     assert(_PyUnicode_CheckConsistency(string, 1));
     _PyUnicode_InternInPlace(interp, &string);
index d8ae7b75e1815031ac90c982386738a83b743584..9ade595ef8ae7ee32ec886913d8d3ac6cc242216 100644 (file)
@@ -4236,6 +4236,105 @@ class ThreadedTests(unittest.TestCase):
                 self.assertEqual(str(e.exception),
                                  'Session refers to a different SSLContext.')
 
+    @requires_tls_version('TLSv1_2')
+    def test_psk(self):
+        psk = bytes.fromhex('deadbeef')
+
+        client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+        client_context.check_hostname = False
+        client_context.verify_mode = ssl.CERT_NONE
+        client_context.maximum_version = ssl.TLSVersion.TLSv1_2
+        client_context.set_ciphers('PSK')
+        client_context.set_psk_client_callback(lambda hint: (None, psk))
+
+        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+        server_context.maximum_version = ssl.TLSVersion.TLSv1_2
+        server_context.set_ciphers('PSK')
+        server_context.set_psk_server_callback(lambda identity: psk)
+
+        # correct PSK should connect
+        server = ThreadedEchoServer(context=server_context)
+        with server:
+            with client_context.wrap_socket(socket.socket()) as s:
+                s.connect((HOST, server.port))
+
+        # incorrect PSK should fail
+        incorrect_psk = bytes.fromhex('cafebabe')
+        client_context.set_psk_client_callback(lambda hint: (None, incorrect_psk))
+        server = ThreadedEchoServer(context=server_context)
+        with server:
+            with client_context.wrap_socket(socket.socket()) as s:
+                with self.assertRaises(ssl.SSLError):
+                    s.connect((HOST, server.port))
+
+        # identity_hint and client_identity should be sent to the other side
+        identity_hint = 'identity-hint'
+        client_identity = 'client-identity'
+
+        def client_callback(hint):
+            self.assertEqual(hint, identity_hint)
+            return client_identity, psk
+
+        def server_callback(identity):
+            self.assertEqual(identity, client_identity)
+            return psk
+
+        client_context.set_psk_client_callback(client_callback)
+        server_context.set_psk_server_callback(server_callback, identity_hint)
+        server = ThreadedEchoServer(context=server_context)
+        with server:
+            with client_context.wrap_socket(socket.socket()) as s:
+                s.connect((HOST, server.port))
+
+        # adding client callback to server or vice versa raises an exception
+        with self.assertRaisesRegex(ssl.SSLError, 'Cannot add PSK server callback'):
+            client_context.set_psk_server_callback(server_callback, identity_hint)
+        with self.assertRaisesRegex(ssl.SSLError, 'Cannot add PSK client callback'):
+            server_context.set_psk_client_callback(client_callback)
+
+        # test with UTF-8 identities
+        identity_hint = '身份暗示'  # Translation: "Identity hint"
+        client_identity = '客户身份'  # Translation: "Customer identity"
+
+        client_context.set_psk_client_callback(client_callback)
+        server_context.set_psk_server_callback(server_callback, identity_hint)
+        server = ThreadedEchoServer(context=server_context)
+        with server:
+            with client_context.wrap_socket(socket.socket()) as s:
+                s.connect((HOST, server.port))
+
+    @requires_tls_version('TLSv1_3')
+    def test_psk_tls1_3(self):
+        psk = bytes.fromhex('deadbeef')
+        identity_hint = 'identity-hint'
+        client_identity = 'client-identity'
+
+        def client_callback(hint):
+            # identity_hint is not sent to the client in TLS 1.3
+            self.assertIsNone(hint)
+            return client_identity, psk
+
+        def server_callback(identity):
+            self.assertEqual(identity, client_identity)
+            return psk
+
+        client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+        client_context.check_hostname = False
+        client_context.verify_mode = ssl.CERT_NONE
+        client_context.minimum_version = ssl.TLSVersion.TLSv1_3
+        client_context.set_ciphers('PSK')
+        client_context.set_psk_client_callback(client_callback)
+
+        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+        server_context.minimum_version = ssl.TLSVersion.TLSv1_3
+        server_context.set_ciphers('PSK')
+        server_context.set_psk_server_callback(server_callback, identity_hint)
+
+        server = ThreadedEchoServer(context=server_context)
+        with server:
+            with client_context.wrap_socket(socket.socket()) as s:
+                s.connect((HOST, server.port))
+
 
 @unittest.skipUnless(has_tls_version('TLSv1_3'), "Test needs TLS 1.3")
 class TestPostHandshakeAuth(unittest.TestCase):
index 6d3a4e3fdb8fe7215831902945450012bea27765..5fe3a177a262925c0b8e37ce4d050fbfbf5d1aaa 100644 (file)
--- a/Misc/ACKS
+++ b/Misc/ACKS
@@ -1482,6 +1482,7 @@ Ajith Ramachandran
 Dhushyanth Ramasamy
 Ashwin Ramaswami
 Jeff Ramnani
+Grant Ramsay
 Bayard Randel
 Varpu Rantala
 Brodie Rao
diff --git a/Misc/NEWS.d/next/Library/2023-11-27-12-41-23.gh-issue-63284.q2Qi9q.rst b/Misc/NEWS.d/next/Library/2023-11-27-12-41-23.gh-issue-63284.q2Qi9q.rst
new file mode 100644 (file)
index 0000000..abb57dc
--- /dev/null
@@ -0,0 +1 @@
+Added support for TLS-PSK (pre-shared key) mode to the :mod:`ssl` module.
index 7bc30cb3529d4742f6e4655e500eab8f9d0c09b4..707e7ad9543acb679d9296b0d250040163677975 100644 (file)
@@ -301,6 +301,8 @@ typedef struct {
     BIO *keylog_bio;
     /* Cached module state, also used in SSLSocket and SSLSession code. */
     _sslmodulestate *state;
+    PyObject *psk_client_callback;
+    PyObject *psk_server_callback;
 } PySSLContext;
 
 typedef struct {
@@ -3123,6 +3125,8 @@ _ssl__SSLContext_impl(PyTypeObject *type, int proto_version)
     self->alpn_protocols = NULL;
     self->set_sni_cb = NULL;
     self->state = get_ssl_state(module);
+    self->psk_client_callback = NULL;
+    self->psk_server_callback = NULL;
 
     /* Don't check host name by default */
     if (proto_version == PY_SSL_VERSION_TLS_CLIENT) {
@@ -3235,6 +3239,8 @@ context_clear(PySSLContext *self)
     Py_CLEAR(self->set_sni_cb);
     Py_CLEAR(self->msg_cb);
     Py_CLEAR(self->keylog_filename);
+    Py_CLEAR(self->psk_client_callback);
+    Py_CLEAR(self->psk_server_callback);
     if (self->keylog_bio != NULL) {
         PySSL_BEGIN_ALLOW_THREADS
         BIO_free_all(self->keylog_bio);
@@ -4662,6 +4668,222 @@ _ssl__SSLContext_get_ca_certs_impl(PySSLContext *self, int binary_form)
     return NULL;
 }
 
+static unsigned int psk_client_callback(SSL *s,
+                                        const char *hint,
+                                        char *identity,
+                                        unsigned int max_identity_len,
+                                        unsigned char *psk,
+                                        unsigned int max_psk_len)
+{
+    PyGILState_STATE gstate = PyGILState_Ensure();
+    PyObject *callback = NULL;
+
+    PySSLSocket *ssl = SSL_get_app_data(s);
+    if (ssl == NULL || ssl->ctx == NULL) {
+        goto error;
+    }
+    callback = ssl->ctx->psk_client_callback;
+    if (callback == NULL) {
+        goto error;
+    }
+
+    PyObject *hint_str = (hint != NULL && hint[0] != '\0') ?
+            PyUnicode_DecodeUTF8(hint, strlen(hint), "strict") :
+            Py_NewRef(Py_None);
+    if (hint_str == NULL) {
+        /* The remote side has sent an invalid UTF-8 string
+         * (breaking the standard), drop the connection without
+         * raising a decode exception. */
+        PyErr_Clear();
+        goto error;
+    }
+    PyObject *result = PyObject_CallFunctionObjArgs(callback, hint_str, NULL);
+    Py_DECREF(hint_str);
+
+    if (result == NULL) {
+        goto error;
+    }
+
+    const char *psk_;
+    const char *identity_;
+    Py_ssize_t psk_len_;
+    Py_ssize_t identity_len_ = 0;
+    if (!PyArg_ParseTuple(result, "z#y#", &identity_, &identity_len_, &psk_, &psk_len_)) {
+        Py_DECREF(result);
+        goto error;
+    }
+
+    if (identity_len_ + 1 > max_identity_len || psk_len_ > max_psk_len) {
+        Py_DECREF(result);
+        goto error;
+    }
+    memcpy(psk, psk_, psk_len_);
+    if (identity_ != NULL) {
+        memcpy(identity, identity_, identity_len_);
+    }
+    identity[identity_len_] = 0;
+
+    Py_DECREF(result);
+
+    PyGILState_Release(gstate);
+    return (unsigned int)psk_len_;
+
+error:
+    if (PyErr_Occurred()) {
+        PyErr_WriteUnraisable(callback);
+    }
+    PyGILState_Release(gstate);
+    return 0;
+}
+
+/*[clinic input]
+_ssl._SSLContext.set_psk_client_callback
+    callback: object
+
+[clinic start generated code]*/
+
+static PyObject *
+_ssl__SSLContext_set_psk_client_callback_impl(PySSLContext *self,
+                                              PyObject *callback)
+/*[clinic end generated code: output=0aba86f6ed75119e input=7627bae0e5ee7635]*/
+{
+    if (self->protocol == PY_SSL_VERSION_TLS_SERVER) {
+        _setSSLError(get_state_ctx(self),
+                     "Cannot add PSK client callback to a "
+                     "PROTOCOL_TLS_SERVER context", 0, __FILE__, __LINE__);
+        return NULL;
+    }
+
+    SSL_psk_client_cb_func ssl_callback;
+    if (callback == Py_None) {
+        callback = NULL;
+        // Delete the existing callback
+        ssl_callback = NULL;
+    } else {
+        if (!PyCallable_Check(callback)) {
+            PyErr_SetString(PyExc_TypeError, "callback must be callable");
+            return NULL;
+        }
+        ssl_callback = psk_client_callback;
+    }
+
+    Py_XDECREF(self->psk_client_callback);
+    Py_XINCREF(callback);
+
+    self->psk_client_callback = callback;
+    SSL_CTX_set_psk_client_callback(self->ctx, ssl_callback);
+
+    Py_RETURN_NONE;
+}
+
+static unsigned int psk_server_callback(SSL *s,
+                                        const char *identity,
+                                        unsigned char *psk,
+                                        unsigned int max_psk_len)
+{
+    PyGILState_STATE gstate = PyGILState_Ensure();
+    PyObject *callback = NULL;
+
+    PySSLSocket *ssl = SSL_get_app_data(s);
+    if (ssl == NULL || ssl->ctx == NULL) {
+        goto error;
+    }
+    callback = ssl->ctx->psk_server_callback;
+    if (callback == NULL) {
+        goto error;
+    }
+
+    PyObject *identity_str = (identity != NULL && identity[0] != '\0') ?
+            PyUnicode_DecodeUTF8(identity, strlen(identity), "strict") :
+            Py_NewRef(Py_None);
+    if (identity_str == NULL) {
+        /* The remote side has sent an invalid UTF-8 string
+         * (breaking the standard), drop the connection without
+         * raising a decode exception. */
+        PyErr_Clear();
+        goto error;
+    }
+    PyObject *result = PyObject_CallFunctionObjArgs(callback, identity_str, NULL);
+    Py_DECREF(identity_str);
+
+    if (result == NULL) {
+        goto error;
+    }
+
+    char *psk_;
+    Py_ssize_t psk_len_;
+    if (PyBytes_AsStringAndSize(result, &psk_, &psk_len_) < 0) {
+        Py_DECREF(result);
+        goto error;
+    }
+
+    if (psk_len_ > max_psk_len) {
+        Py_DECREF(result);
+        goto error;
+    }
+    memcpy(psk, psk_, psk_len_);
+
+    Py_DECREF(result);
+
+    PyGILState_Release(gstate);
+    return (unsigned int)psk_len_;
+
+error:
+    if (PyErr_Occurred()) {
+        PyErr_WriteUnraisable(callback);
+    }
+    PyGILState_Release(gstate);
+    return 0;
+}
+
+/*[clinic input]
+_ssl._SSLContext.set_psk_server_callback
+    callback: object
+    identity_hint: str(accept={str, NoneType}) = None
+
+[clinic start generated code]*/
+
+static PyObject *
+_ssl__SSLContext_set_psk_server_callback_impl(PySSLContext *self,
+                                              PyObject *callback,
+                                              const char *identity_hint)
+/*[clinic end generated code: output=1f4d6a4e09a92b03 input=65d4b6022aa85ea3]*/
+{
+    if (self->protocol == PY_SSL_VERSION_TLS_CLIENT) {
+        _setSSLError(get_state_ctx(self),
+                     "Cannot add PSK server callback to a "
+                     "PROTOCOL_TLS_CLIENT context", 0, __FILE__, __LINE__);
+        return NULL;
+    }
+
+    SSL_psk_server_cb_func ssl_callback;
+    if (callback == Py_None) {
+        callback = NULL;
+        // Delete the existing callback and hint
+        ssl_callback = NULL;
+        identity_hint = NULL;
+    } else {
+        if (!PyCallable_Check(callback)) {
+            PyErr_SetString(PyExc_TypeError, "callback must be callable");
+            return NULL;
+        }
+        ssl_callback = psk_server_callback;
+    }
+
+    if (SSL_CTX_use_psk_identity_hint(self->ctx, identity_hint) != 1) {
+        PyErr_SetString(PyExc_ValueError, "failed to set identity hint");
+        return NULL;
+    }
+
+    Py_XDECREF(self->psk_server_callback);
+    Py_XINCREF(callback);
+
+    self->psk_server_callback = callback;
+    SSL_CTX_set_psk_server_callback(self->ctx, ssl_callback);
+
+    Py_RETURN_NONE;
+}
+
 
 static PyGetSetDef context_getsetlist[] = {
     {"check_hostname", (getter) get_check_hostname,
@@ -4716,6 +4938,8 @@ static struct PyMethodDef context_methods[] = {
     _SSL__SSLCONTEXT_CERT_STORE_STATS_METHODDEF
     _SSL__SSLCONTEXT_GET_CA_CERTS_METHODDEF
     _SSL__SSLCONTEXT_GET_CIPHERS_METHODDEF
+    _SSL__SSLCONTEXT_SET_PSK_CLIENT_CALLBACK_METHODDEF
+    _SSL__SSLCONTEXT_SET_PSK_SERVER_CALLBACK_METHODDEF
     {NULL, NULL}        /* sentinel */
 };
 
index 88401b0490a1bbf973b8e094dda772853f96a058..19c0f619b92f45c4d2968ac8837bffc8c1b839d6 100644 (file)
@@ -1014,6 +1014,141 @@ exit:
     return return_value;
 }
 
+PyDoc_STRVAR(_ssl__SSLContext_set_psk_client_callback__doc__,
+"set_psk_client_callback($self, /, callback)\n"
+"--\n"
+"\n");
+
+#define _SSL__SSLCONTEXT_SET_PSK_CLIENT_CALLBACK_METHODDEF    \
+    {"set_psk_client_callback", _PyCFunction_CAST(_ssl__SSLContext_set_psk_client_callback), METH_FASTCALL|METH_KEYWORDS, _ssl__SSLContext_set_psk_client_callback__doc__},
+
+static PyObject *
+_ssl__SSLContext_set_psk_client_callback_impl(PySSLContext *self,
+                                              PyObject *callback);
+
+static PyObject *
+_ssl__SSLContext_set_psk_client_callback(PySSLContext *self, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames)
+{
+    PyObject *return_value = NULL;
+    #if defined(Py_BUILD_CORE) && !defined(Py_BUILD_CORE_MODULE)
+
+    #define NUM_KEYWORDS 1
+    static struct {
+        PyGC_Head _this_is_not_used;
+        PyObject_VAR_HEAD
+        PyObject *ob_item[NUM_KEYWORDS];
+    } _kwtuple = {
+        .ob_base = PyVarObject_HEAD_INIT(&PyTuple_Type, NUM_KEYWORDS)
+        .ob_item = { &_Py_ID(callback), },
+    };
+    #undef NUM_KEYWORDS
+    #define KWTUPLE (&_kwtuple.ob_base.ob_base)
+
+    #else  // !Py_BUILD_CORE
+    #  define KWTUPLE NULL
+    #endif  // !Py_BUILD_CORE
+
+    static const char * const _keywords[] = {"callback", NULL};
+    static _PyArg_Parser _parser = {
+        .keywords = _keywords,
+        .fname = "set_psk_client_callback",
+        .kwtuple = KWTUPLE,
+    };
+    #undef KWTUPLE
+    PyObject *argsbuf[1];
+    PyObject *callback;
+
+    args = _PyArg_UnpackKeywords(args, nargs, NULL, kwnames, &_parser, 1, 1, 0, argsbuf);
+    if (!args) {
+        goto exit;
+    }
+    callback = args[0];
+    return_value = _ssl__SSLContext_set_psk_client_callback_impl(self, callback);
+
+exit:
+    return return_value;
+}
+
+PyDoc_STRVAR(_ssl__SSLContext_set_psk_server_callback__doc__,
+"set_psk_server_callback($self, /, callback, identity_hint=None)\n"
+"--\n"
+"\n");
+
+#define _SSL__SSLCONTEXT_SET_PSK_SERVER_CALLBACK_METHODDEF    \
+    {"set_psk_server_callback", _PyCFunction_CAST(_ssl__SSLContext_set_psk_server_callback), METH_FASTCALL|METH_KEYWORDS, _ssl__SSLContext_set_psk_server_callback__doc__},
+
+static PyObject *
+_ssl__SSLContext_set_psk_server_callback_impl(PySSLContext *self,
+                                              PyObject *callback,
+                                              const char *identity_hint);
+
+static PyObject *
+_ssl__SSLContext_set_psk_server_callback(PySSLContext *self, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames)
+{
+    PyObject *return_value = NULL;
+    #if defined(Py_BUILD_CORE) && !defined(Py_BUILD_CORE_MODULE)
+
+    #define NUM_KEYWORDS 2
+    static struct {
+        PyGC_Head _this_is_not_used;
+        PyObject_VAR_HEAD
+        PyObject *ob_item[NUM_KEYWORDS];
+    } _kwtuple = {
+        .ob_base = PyVarObject_HEAD_INIT(&PyTuple_Type, NUM_KEYWORDS)
+        .ob_item = { &_Py_ID(callback), &_Py_ID(identity_hint), },
+    };
+    #undef NUM_KEYWORDS
+    #define KWTUPLE (&_kwtuple.ob_base.ob_base)
+
+    #else  // !Py_BUILD_CORE
+    #  define KWTUPLE NULL
+    #endif  // !Py_BUILD_CORE
+
+    static const char * const _keywords[] = {"callback", "identity_hint", NULL};
+    static _PyArg_Parser _parser = {
+        .keywords = _keywords,
+        .fname = "set_psk_server_callback",
+        .kwtuple = KWTUPLE,
+    };
+    #undef KWTUPLE
+    PyObject *argsbuf[2];
+    Py_ssize_t noptargs = nargs + (kwnames ? PyTuple_GET_SIZE(kwnames) : 0) - 1;
+    PyObject *callback;
+    const char *identity_hint = NULL;
+
+    args = _PyArg_UnpackKeywords(args, nargs, NULL, kwnames, &_parser, 1, 2, 0, argsbuf);
+    if (!args) {
+        goto exit;
+    }
+    callback = args[0];
+    if (!noptargs) {
+        goto skip_optional_pos;
+    }
+    if (args[1] == Py_None) {
+        identity_hint = NULL;
+    }
+    else if (PyUnicode_Check(args[1])) {
+        Py_ssize_t identity_hint_length;
+        identity_hint = PyUnicode_AsUTF8AndSize(args[1], &identity_hint_length);
+        if (identity_hint == NULL) {
+            goto exit;
+        }
+        if (strlen(identity_hint) != (size_t)identity_hint_length) {
+            PyErr_SetString(PyExc_ValueError, "embedded null character");
+            goto exit;
+        }
+    }
+    else {
+        _PyArg_BadArgument("set_psk_server_callback", "argument 'identity_hint'", "str or None", args[1]);
+        goto exit;
+    }
+skip_optional_pos:
+    return_value = _ssl__SSLContext_set_psk_server_callback_impl(self, callback, identity_hint);
+
+exit:
+    return return_value;
+}
+
 static PyObject *
 _ssl_MemoryBIO_impl(PyTypeObject *type);
 
@@ -1527,4 +1662,4 @@ exit:
 #ifndef _SSL_ENUM_CRLS_METHODDEF
     #define _SSL_ENUM_CRLS_METHODDEF
 #endif /* !defined(_SSL_ENUM_CRLS_METHODDEF) */
-/*[clinic end generated code: output=aa6b0a898b6077fe input=a9049054013a1b77]*/
+/*[clinic end generated code: output=6342ea0062ab16c7 input=a9049054013a1b77]*/