]> git.ipfire.org Git - thirdparty/gnutls.git/commitdiff
tests/pskself2: extend with RSA-PSK support
authorAlexander Sosedkin <asosedkin@redhat.com>
Tue, 21 Apr 2026 17:02:43 +0000 (19:02 +0200)
committerAlexander Sosedkin <asosedkin@redhat.com>
Wed, 29 Apr 2026 13:35:03 +0000 (15:35 +0200)
Signed-off-by: Alexander Sosedkin <asosedkin@redhat.com>
tests/pskself2.c

index c0e507d8fa82bbdc62d34120f0a064d5aecfbff2..6543d36b3401a80d91f1bf8fc43359c15b7db961 100644 (file)
@@ -27,6 +27,7 @@
 #include "config.h"
 #endif
 
+#include <stdbool.h>
 #include <stdio.h>
 #include <stdlib.h>
 
@@ -51,6 +52,7 @@ int main(int argc, char **argv)
 
 #include "utils.h"
 #include "extras/hex.h"
+#include "cert-common.h"
 
 /* A very basic TLS client, with PSK authentication.
  */
@@ -65,12 +67,13 @@ static void tls_log_func(int level, const char *str)
 #define MAX_BUF 1024
 #define MSG "Hello TLS"
 
-static void client(int sd, const char *prio, unsigned exp_hint)
+static void client(int sd, const char *prio, bool exp_hint, bool rsa)
 {
        int ret, ii;
        gnutls_session_t session;
        char buffer[MAX_BUF + 1];
        gnutls_psk_client_credentials_t pskcred;
+       gnutls_certificate_credentials_t xcred = NULL;
        /* Need to enable anonymous KX specifically. */
        const gnutls_datum_t key = { (void *)"DEADBEEF", 8 };
        gnutls_datum_t user;
@@ -112,6 +115,11 @@ static void client(int sd, const char *prio, unsigned exp_hint)
         */
        gnutls_credentials_set(session, GNUTLS_CRD_PSK, pskcred);
 
+       if (rsa) {
+               gnutls_certificate_allocate_credentials(&xcred);
+               gnutls_credentials_set(session, GNUTLS_CRD_CERTIFICATE, xcred);
+       }
+
        gnutls_transport_set_int(session, sd);
 
        /* Perform the TLS handshake
@@ -167,6 +175,8 @@ end:
 
        gnutls_free(user.data);
        gnutls_psk_free_client_credentials(pskcred);
+       if (xcred)
+               gnutls_certificate_free_credentials(xcred);
 
        gnutls_global_deinit();
 }
@@ -194,9 +204,10 @@ static int pskfunc(gnutls_session_t session, const gnutls_datum_t *username,
        return 0;
 }
 
-static void server(int sd, const char *prio)
+static void server(int sd, const char *prio, bool rsa)
 {
        gnutls_psk_server_credentials_t server_pskcred;
+       gnutls_certificate_credentials_t serverx509cred = NULL;
        int ret;
        gnutls_session_t session;
        gnutls_datum_t psk_username;
@@ -216,6 +227,13 @@ static void server(int sd, const char *prio)
        gnutls_psk_set_server_credentials_hint(server_pskcred, "hint");
        gnutls_psk_set_server_credentials_function2(server_pskcred, pskfunc);
 
+       if (rsa) {
+               gnutls_certificate_allocate_credentials(&serverx509cred);
+               gnutls_certificate_set_x509_key_mem(serverx509cred,
+                                                   &server_cert, &server_key,
+                                                   GNUTLS_X509_FMT_PEM);
+       }
+
        gnutls_init(&session, GNUTLS_SERVER);
 
        /* avoid calling all the priority functions, since the defaults
@@ -224,6 +242,9 @@ static void server(int sd, const char *prio)
        gnutls_priority_set_direct(session, prio, NULL);
 
        gnutls_credentials_set(session, GNUTLS_CRD_PSK, server_pskcred);
+       if (serverx509cred)
+               gnutls_credentials_set(session, GNUTLS_CRD_CERTIFICATE,
+                                      serverx509cred);
 
        gnutls_transport_set_int(session, sd);
        ret = gnutls_handshake(session);
@@ -280,6 +301,8 @@ static void server(int sd, const char *prio)
        gnutls_deinit(session);
 
        gnutls_psk_free_server_credentials(server_pskcred);
+       if (serverx509cred)
+               gnutls_certificate_free_credentials(serverx509cred);
 
        gnutls_global_deinit();
 
@@ -287,7 +310,7 @@ static void server(int sd, const char *prio)
                success("server: finished\n");
 }
 
-static void run_test(const char *prio, unsigned exp_hint)
+static void run_test(const char *prio, bool exp_hint, bool rsa)
 {
        pid_t child;
        int err;
@@ -313,42 +336,46 @@ static void run_test(const char *prio, unsigned exp_hint)
                int status;
                /* parent */
                close(sockets[1]);
-               server(sockets[0], prio);
+               server(sockets[0], prio, rsa);
                wait(&status);
                check_wait_status(status);
        } else {
                close(sockets[0]);
-               client(sockets[1], prio, exp_hint);
+               client(sockets[1], prio, exp_hint, rsa);
                exit(0);
        }
 }
 
 void doit(void)
 {
-       run_test("NORMAL:-VERS-ALL:+VERS-TLS1.2:-KX-ALL:+PSK", 1);
-       run_test("NORMAL:-VERS-ALL:+VERS-TLS1.2:-KX-ALL:+ECDHE-PSK", 1);
-       run_test("NORMAL:-VERS-ALL:+VERS-TLS1.2:-KX-ALL:+DHE-PSK", 1);
-
-       run_test("NORMAL:-VERS-ALL:+VERS-TLS1.2:+PSK", 0);
-       run_test(
-               "NORMAL:-VERS-ALL:+VERS-TLS1.2:-GROUP-ALL:+GROUP-FFDHE2048:+DHE-PSK",
-               0);
-       run_test(
-               "NORMAL:-VERS-ALL:+VERS-TLS1.2:-GROUP-ALL:+GROUP-SECP256R1:+ECDHE-PSK",
-               0);
-       run_test("NORMAL:-VERS-ALL:+VERS-TLS1.3:+PSK", 0);
-       run_test(
-               "NORMAL:-VERS-ALL:+VERS-TLS1.3:-GROUP-ALL:+GROUP-FFDHE2048:+DHE-PSK",
-               0);
-       run_test(
-               "NORMAL:-VERS-ALL:+VERS-TLS1.3:-GROUP-ALL:+GROUP-SECP256R1:+ECDHE-PSK",
-               0);
+       run_test("NORMAL:-VERS-ALL:+VERS-TLS1.2:-KX-ALL:+PSK", true, false);
+       run_test("NORMAL:-VERS-ALL:+VERS-TLS1.2:-KX-ALL:+ECDHE-PSK", true,
+                false);
+       run_test("NORMAL:-VERS-ALL:+VERS-TLS1.2:-KX-ALL:+DHE-PSK", true, false);
+
+       run_test("NORMAL:-VERS-ALL:+VERS-TLS1.2:+PSK", false, false);
+       run_test("NORMAL:-VERS-ALL:+VERS-TLS1.2:"
+                "-GROUP-ALL:+GROUP-FFDHE2048:+DHE-PSK",
+                false, false);
+       run_test("NORMAL:-VERS-ALL:+VERS-TLS1.2:"
+                "-GROUP-ALL:+GROUP-SECP256R1:+ECDHE-PSK",
+                false, false);
+       run_test("NORMAL:-VERS-ALL:+VERS-TLS1.3:+PSK", false, false);
+       run_test("NORMAL:-VERS-ALL:+VERS-TLS1.3:"
+                "-GROUP-ALL:+GROUP-FFDHE2048:+DHE-PSK",
+                false, false);
+       run_test("NORMAL:-VERS-ALL:+VERS-TLS1.3:"
+                "-GROUP-ALL:+GROUP-SECP256R1:+ECDHE-PSK",
+                false, false);
        /* the following should work once we support PSK without DH */
-       run_test("NORMAL:-VERS-ALL:+VERS-TLS1.3:-GROUP-ALL:+PSK", 0);
+       run_test("NORMAL:-VERS-ALL:+VERS-TLS1.3:-GROUP-ALL:+PSK", false, false);
+
+       run_test("NORMAL:-KX-ALL:+PSK", false, false);
+       run_test("NORMAL:-KX-ALL:+ECDHE-PSK", false, false);
+       run_test("NORMAL:-KX-ALL:+DHE-PSK", false, false);
 
-       run_test("NORMAL:-KX-ALL:+PSK", 0);
-       run_test("NORMAL:-KX-ALL:+ECDHE-PSK", 0);
-       run_test("NORMAL:-KX-ALL:+DHE-PSK", 0);
+       /* RSA-PSK */
+       run_test("NORMAL:-VERS-ALL:+VERS-TLS1.2:-KX-ALL:+RSA-PSK", false, true);
 }
 
 #endif /* _WIN32 */