]> git.ipfire.org Git - thirdparty/openssl.git/blobdiff - test/helpers/quictestlib.c
Copyright year updates
[thirdparty/openssl.git] / test / helpers / quictestlib.c
index 017ba54b5bf6a6178eaec27d1d59d4683d6519ba..f0955559dcac251fdd172771b1f5699502fbea9e 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 2022 The OpenSSL Project Authors. All Rights Reserved.
+ * Copyright 2022-2023 The OpenSSL Project Authors. All Rights Reserved.
  *
  * Licensed under the Apache License 2.0 (the "License").  You may not use
  * this file except in compliance with the License.  You can obtain a copy
@@ -8,17 +8,24 @@
  */
 
 #include <assert.h>
+#include <openssl/configuration.h>
 #include <openssl/bio.h>
 #include "quictestlib.h"
+#include "ssltestlib.h"
 #include "../testutil.h"
+#if defined(OPENSSL_THREADS) && !defined(CRYPTO_TDEBUG)
+# include "../threadstest.h"
+#endif
+#include "internal/quic_ssl.h"
 #include "internal/quic_wire_pkt.h"
 #include "internal/quic_record_tx.h"
 #include "internal/quic_error.h"
 #include "internal/packet.h"
+#include "internal/tsan_assist.h"
 
 #define GROWTH_ALLOWANCE 1024
 
-struct ossl_quic_fault {
+struct qtest_fault {
     QUIC_TSERVER *qtserv;
 
     /* Plain packet mutations */
@@ -26,9 +33,9 @@ struct ossl_quic_fault {
     QUIC_PKT_HDR pplainhdr;
     /* iovec for the plaintext packet data buffer */
     OSSL_QTX_IOVEC pplainio;
-    /* Allocted size of the plaintext packet data buffer */
+    /* Allocated size of the plaintext packet data buffer */
     size_t pplainbuf_alloc;
-    ossl_quic_fault_on_packet_plain_cb pplaincb;
+    qtest_fault_on_packet_plain_cb pplaincb;
     void *pplaincbarg;
 
     /* Handshake message mutations */
@@ -38,24 +45,38 @@ struct ossl_quic_fault {
     size_t handbufalloc;
     /* Actual length of the handshake message */
     size_t handbuflen;
-    ossl_quic_fault_on_handshake_cb handshakecb;
+    qtest_fault_on_handshake_cb handshakecb;
     void *handshakecbarg;
-    ossl_quic_fault_on_enc_ext_cb encextcb;
+    qtest_fault_on_enc_ext_cb encextcb;
     void *encextcbarg;
 
     /* Cipher packet mutations */
-    ossl_quic_fault_on_packet_cipher_cb pciphercb;
+    qtest_fault_on_packet_cipher_cb pciphercb;
     void *pciphercbarg;
+
+    /* Datagram mutations */
+    qtest_fault_on_datagram_cb datagramcb;
+    void *datagramcbarg;
+    /* The currently processed message */
+    BIO_MSG msg;
+    /* Allocated size of msg data buffer */
+    size_t msgalloc;
 };
 
 static void packet_plain_finish(void *arg);
 static void handshake_finish(void *arg);
 
-static BIO_METHOD *get_bio_method(void);
+static OSSL_TIME fake_now;
 
-int qtest_create_quic_objects(SSL_CTX *clientctx, char *certfile, char *keyfile,
-                              QUIC_TSERVER **qtserv, SSL **cssl,
-                              OSSL_QUIC_FAULT **fault)
+static OSSL_TIME fake_now_cb(void *arg)
+{
+    return fake_now;
+}
+
+int qtest_create_quic_objects(OSSL_LIB_CTX *libctx, SSL_CTX *clientctx,
+                              SSL_CTX *serverctx, char *certfile, char *keyfile,
+                              int flags, QUIC_TSERVER **qtserv, SSL **cssl,
+                              QTEST_FAULT **fault)
 {
     /* ALPN value as recognised by QUIC_TSERVER */
     unsigned char alpn[] = { 8, 'o', 's', 's', 'l', 't', 'e', 's', 't' };
@@ -67,35 +88,65 @@ int qtest_create_quic_objects(SSL_CTX *clientctx, char *certfile, char *keyfile,
     *qtserv = NULL;
     if (fault != NULL)
         *fault = NULL;
-    *cssl = SSL_new(clientctx);
-    if (!TEST_ptr(*cssl))
-        return 0;
 
-    if (!TEST_true(SSL_set_blocking_mode(*cssl, 0)))
-        goto err;
+    if (*cssl == NULL) {
+        *cssl = SSL_new(clientctx);
+        if (!TEST_ptr(*cssl))
+            return 0;
+    }
 
     /* SSL_set_alpn_protos returns 0 for success! */
     if (!TEST_false(SSL_set_alpn_protos(*cssl, alpn, sizeof(alpn))))
         goto err;
 
-    if (!TEST_true(BIO_new_bio_dgram_pair(&cbio, 0, &sbio, 0)))
+    if (!TEST_ptr(peeraddr = BIO_ADDR_new()))
         goto err;
 
-    if (!TEST_true(BIO_dgram_set_caps(cbio, BIO_DGRAM_CAP_HANDLES_DST_ADDR))
-            || !TEST_true(BIO_dgram_set_caps(sbio, BIO_DGRAM_CAP_HANDLES_DST_ADDR)))
+    if ((flags & QTEST_FLAG_BLOCK) != 0) {
+#if !defined(OPENSSL_NO_POSIX_IO)
+        int cfd, sfd;
+
+        /*
+         * For blocking mode we need to create actual sockets rather than doing
+         * everything in memory
+         */
+        if (!TEST_true(create_test_sockets(&cfd, &sfd, SOCK_DGRAM, peeraddr)))
+            goto err;
+        cbio = BIO_new_dgram(cfd, 1);
+        if (!TEST_ptr(cbio)) {
+            close(cfd);
+            close(sfd);
+            goto err;
+        }
+        sbio = BIO_new_dgram(sfd, 1);
+        if (!TEST_ptr(sbio)) {
+            close(sfd);
+            goto err;
+        }
+#else
         goto err;
+#endif
+    } else {
+        if (!TEST_true(BIO_new_bio_dgram_pair(&cbio, 0, &sbio, 0)))
+            goto err;
 
-    SSL_set_bio(*cssl, cbio, cbio);
+        if (!TEST_true(BIO_dgram_set_caps(cbio, BIO_DGRAM_CAP_HANDLES_DST_ADDR))
+                || !TEST_true(BIO_dgram_set_caps(sbio, BIO_DGRAM_CAP_HANDLES_DST_ADDR)))
+            goto err;
 
-    if (!TEST_ptr(peeraddr = BIO_ADDR_new()))
-        goto err;
+        /* Dummy server address */
+        if (!TEST_true(BIO_ADDR_rawmake(peeraddr, AF_INET, &ina, sizeof(ina),
+                                        htons(0))))
+            goto err;
+    }
+
+    SSL_set_bio(*cssl, cbio, cbio);
 
-    /* Dummy server address */
-    if (!TEST_true(BIO_ADDR_rawmake(peeraddr, AF_INET, &ina, sizeof(ina),
-                                    htons(0))))
+    if (!TEST_true(SSL_set_blocking_mode(*cssl,
+                                         (flags & QTEST_FLAG_BLOCK) != 0 ? 1 : 0)))
         goto err;
 
-    if (!TEST_true(SSL_set_initial_peer_addr(*cssl, peeraddr)))
+    if (!TEST_true(SSL_set1_initial_peer_addr(*cssl, peeraddr)))
         goto err;
 
     if (fault != NULL) {
@@ -104,7 +155,7 @@ int qtest_create_quic_objects(SSL_CTX *clientctx, char *certfile, char *keyfile,
             goto err;
     }
 
-    fisbio = BIO_new(get_bio_method());
+    fisbio = BIO_new(qtest_get_bio_method());
     if (!TEST_ptr(fisbio))
         goto err;
 
@@ -113,8 +164,20 @@ int qtest_create_quic_objects(SSL_CTX *clientctx, char *certfile, char *keyfile,
     if (!TEST_ptr(BIO_push(fisbio, sbio)))
         goto err;
 
+    tserver_args.libctx = libctx;
     tserver_args.net_rbio = sbio;
     tserver_args.net_wbio = fisbio;
+    tserver_args.alpn = NULL;
+    if (serverctx != NULL && !TEST_true(SSL_CTX_up_ref(serverctx)))
+        goto err;
+    tserver_args.ctx = serverctx;
+    if ((flags & QTEST_FLAG_FAKE_TIME) != 0) {
+        fake_now = ossl_time_zero();
+        /* zero time can have a special meaning, bump it */
+        qtest_add_time(1);
+        tserver_args.now_cb = fake_now_cb;
+        (void)ossl_quic_conn_set_override_now_cb(*cssl, fake_now_cb, NULL);
+    }
 
     if (!TEST_ptr(*qtserv = ossl_quic_tserver_new(&tserver_args, certfile,
                                                   keyfile)))
@@ -131,11 +194,13 @@ int qtest_create_quic_objects(SSL_CTX *clientctx, char *certfile, char *keyfile,
 
     return 1;
  err:
+    SSL_CTX_free(tserver_args.ctx);
     BIO_ADDR_free(peeraddr);
     BIO_free(cbio);
     BIO_free(fisbio);
     BIO_free(sbio);
     SSL_free(*cssl);
+    *cssl = NULL;
     ossl_quic_tserver_free(*qtserv);
     if (fault != NULL)
         OPENSSL_free(*fault);
@@ -143,39 +208,131 @@ int qtest_create_quic_objects(SSL_CTX *clientctx, char *certfile, char *keyfile,
     return 0;
 }
 
+void qtest_add_time(uint64_t millis)
+{
+    fake_now = ossl_time_add(fake_now, ossl_ms2time(millis));
+}
+
+QTEST_FAULT *qtest_create_injector(QUIC_TSERVER *ts)
+{
+    QTEST_FAULT *f;
+
+    f = OPENSSL_zalloc(sizeof(*f));
+    if (f == NULL)
+        return NULL;
+
+    f->qtserv = ts;
+    return f;
+
+}
+
+int qtest_supports_blocking(void)
+{
+#if !defined(OPENSSL_NO_POSIX_IO) && defined(OPENSSL_THREADS) && !defined(CRYPTO_TDEBUG)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
 #define MAXLOOPS    1000
 
-int qtest_create_quic_connection(QUIC_TSERVER *qtserv, SSL *clientssl)
+#if defined(OPENSSL_THREADS) && !defined(CRYPTO_TDEBUG)
+static int globserverret = 0;
+static TSAN_QUALIFIER int abortserverthread = 0;
+static QUIC_TSERVER *globtserv;
+static const thread_t thread_zero;
+
+static void run_server_thread(void)
+{
+    /*
+     * This will operate in a busy loop because the server does not block,
+     * but should be acceptable because it is local and we expect this to be
+     * fast
+     */
+    globserverret = qtest_create_quic_connection(globtserv, NULL);
+}
+#endif
+
+int qtest_create_quic_connection_ex(QUIC_TSERVER *qtserv, SSL *clientssl,
+                                    int wanterr)
 {
-    int retc = -1, rets = 0, err, abortctr = 0, ret = 0;
+    int retc = -1, rets = 0, abortctr = 0, ret = 0;
     int clienterr = 0, servererr = 0;
+#if defined(OPENSSL_THREADS) && !defined(CRYPTO_TDEBUG)
+    /*
+     * Pointless initialisation to avoid bogus compiler warnings about using
+     * t uninitialised
+     */
+    thread_t t = thread_zero;
+
+    if (clientssl != NULL)
+        abortserverthread = 0;
+#endif
+
+    if (!TEST_ptr(qtserv)) {
+        goto err;
+    } else if (clientssl == NULL) {
+        retc = 1;
+    } else if (SSL_get_blocking_mode(clientssl) > 0) {
+#if defined(OPENSSL_THREADS) && !defined(CRYPTO_TDEBUG)
+        /*
+         * clientssl is blocking. We will need a thread to complete the
+         * connection
+         */
+        globtserv = qtserv;
+        if (!TEST_true(run_thread(&t, run_server_thread)))
+            goto err;
+
+        qtserv = NULL;
+        rets = 1;
+#else
+        TEST_error("No thread support in this build");
+        goto err;
+#endif
+    }
 
     do {
-        err = SSL_ERROR_WANT_WRITE;
-        while (!clienterr && retc <= 0 && err == SSL_ERROR_WANT_WRITE) {
+        if (!clienterr && retc <= 0) {
+            int err;
+
             retc = SSL_connect(clientssl);
-            if (retc <= 0)
+            if (retc <= 0) {
                 err = SSL_get_error(clientssl, retc);
-        }
 
-        if (!clienterr && retc <= 0 && err != SSL_ERROR_WANT_READ) {
-            TEST_info("SSL_connect() failed %d, %d", retc, err);
-            TEST_openssl_errors();
-            clienterr = 1;
+                if (err == wanterr) {
+                    retc = 1;
+#if defined(OPENSSL_THREADS) && !defined(CRYPTO_TDEBUG)
+                    if (qtserv == NULL && rets > 0)
+                        tsan_store(&abortserverthread, 1);
+                    else
+#endif
+                        rets = 1;
+                } else {
+                    if (err != SSL_ERROR_WANT_READ
+                            && err != SSL_ERROR_WANT_WRITE) {
+                        TEST_info("SSL_connect() failed %d, %d", retc, err);
+                        TEST_openssl_errors();
+                        clienterr = 1;
+                    }
+                }
+            }
         }
 
         /*
          * We're cheating. We don't take any notice of SSL_get_tick_timeout()
-         * and tick everytime around the loop anyway. This is inefficient. We
+         * and tick every time around the loop anyway. This is inefficient. We
          * can get away with it in test code because we control both ends of
          * the communications and don't expect network delays. This shouldn't
          * be done in a real application.
          */
         if (!clienterr && retc <= 0)
-            SSL_tick(clientssl);
+            SSL_handle_events(clientssl);
+
         if (!servererr && rets <= 0) {
+            qtest_add_time(1);
             ossl_quic_tserver_tick(qtserv);
-            servererr = ossl_quic_tserver_is_term_any(qtserv, NULL);
+            servererr = ossl_quic_tserver_is_term_any(qtserv);
             if (!servererr)
                 rets = ossl_quic_tserver_is_handshake_confirmed(qtserv);
         }
@@ -183,11 +340,26 @@ int qtest_create_quic_connection(QUIC_TSERVER *qtserv, SSL *clientssl)
         if (clienterr && servererr)
             goto err;
 
-        if (++abortctr == MAXLOOPS) {
+        if (clientssl != NULL && ++abortctr == MAXLOOPS) {
             TEST_info("No progress made");
             goto err;
         }
-    } while ((retc <= 0 && !clienterr) || (rets <= 0 && !servererr));
+    } while ((retc <= 0 && !clienterr)
+             || (rets <= 0 && !servererr
+#if defined(OPENSSL_THREADS) && !defined(CRYPTO_TDEBUG)
+                 && !tsan_load(&abortserverthread)
+#endif
+                ));
+
+    if (qtserv == NULL && rets > 0) {
+#if defined(OPENSSL_THREADS) && !defined(CRYPTO_TDEBUG)
+        if (!TEST_true(wait_for_thread(t)) || !TEST_true(globserverret))
+            goto err;
+#else
+        TEST_error("Should not happen");
+        goto err;
+#endif
+    }
 
     if (!clienterr && !servererr)
         ret = 1;
@@ -195,25 +367,117 @@ int qtest_create_quic_connection(QUIC_TSERVER *qtserv, SSL *clientssl)
     return ret;
 }
 
-int qtest_check_server_protocol_err(QUIC_TSERVER *qtserv)
+int qtest_create_quic_connection(QUIC_TSERVER *qtserv, SSL *clientssl)
 {
-    QUIC_TERMINATE_CAUSE cause;
+    return qtest_create_quic_connection_ex(qtserv, clientssl, SSL_ERROR_NONE);
+}
+
+#if defined(OPENSSL_THREADS) && !defined(CRYPTO_TDEBUG)
+static TSAN_QUALIFIER int shutdowndone;
+
+static void run_server_shutdown_thread(void)
+{
+    /*
+     * This will operate in a busy loop because the server does not block,
+     * but should be acceptable because it is local and we expect this to be
+     * fast
+     */
+    do {
+        ossl_quic_tserver_tick(globtserv);
+    } while(!tsan_load(&shutdowndone));
+}
+#endif
+
+int qtest_shutdown(QUIC_TSERVER *qtserv, SSL *clientssl)
+{
+    int tickserver = 1;
+    int ret = 0;
+#if defined(OPENSSL_THREADS) && !defined(CRYPTO_TDEBUG)
+    /*
+     * Pointless initialisation to avoid bogus compiler warnings about using
+     * t uninitialised
+     */
+    thread_t t = thread_zero;
+#endif
+
+    if (SSL_get_blocking_mode(clientssl) > 0) {
+#if defined(OPENSSL_THREADS) && !defined(CRYPTO_TDEBUG)
+        /*
+         * clientssl is blocking. We will need a thread to complete the
+         * connection
+         */
+        globtserv = qtserv;
+        shutdowndone = 0;
+        if (!TEST_true(run_thread(&t, run_server_shutdown_thread)))
+            return 0;
+
+        tickserver = 0;
+#else
+        TEST_error("No thread support in this build");
+        return 0;
+#endif
+    }
+
+    /* Busy loop in non-blocking mode. It should be quick because its local */
+    for (;;) {
+        int rc = SSL_shutdown(clientssl);
+
+        if (rc == 1) {
+            ret = 1;
+            break;
+        }
+
+        if (rc < 0)
+            break;
+
+        if (tickserver)
+            ossl_quic_tserver_tick(qtserv);
+    }
+
+#if defined(OPENSSL_THREADS) && !defined(CRYPTO_TDEBUG)
+    tsan_store(&shutdowndone, 1);
+    if (!tickserver) {
+        if (!TEST_true(wait_for_thread(t)))
+            ret = 0;
+    }
+#endif
+
+    return ret;
+}
+
+int qtest_check_server_transport_err(QUIC_TSERVER *qtserv, uint64_t code)
+{
+    const QUIC_TERMINATE_CAUSE *cause;
 
     ossl_quic_tserver_tick(qtserv);
 
     /*
-     * Check that the server has received the protocol violation error
-     * connection close from the client
+     * Check that the server has closed with the specified code from the client
      */
-    if (!TEST_true(ossl_quic_tserver_is_term_any(qtserv, &cause))
-            || !TEST_true(cause.remote)
-            || !TEST_uint64_t_eq(cause.error_code, QUIC_ERR_PROTOCOL_VIOLATION))
+    if (!TEST_true(ossl_quic_tserver_is_term_any(qtserv)))
+        return 0;
+
+    cause = ossl_quic_tserver_get_terminate_cause(qtserv);
+    if  (!TEST_ptr(cause)
+            || !TEST_true(cause->remote)
+            || !TEST_false(cause->app)
+            || !TEST_uint64_t_eq(cause->error_code, code))
         return 0;
 
     return 1;
 }
 
-void ossl_quic_fault_free(OSSL_QUIC_FAULT *fault)
+int qtest_check_server_protocol_err(QUIC_TSERVER *qtserv)
+{
+    return qtest_check_server_transport_err(qtserv, QUIC_ERR_PROTOCOL_VIOLATION);
+}
+
+int qtest_check_server_frame_encoding_err(QUIC_TSERVER *qtserv)
+{
+    return qtest_check_server_transport_err(qtserv, QUIC_ERR_FRAME_ENCODING_ERROR);
+}
+
+void qtest_fault_free(QTEST_FAULT *fault)
 {
     if (fault == NULL)
         return;
@@ -231,7 +495,7 @@ static int packet_plain_mutate(const QUIC_PKT_HDR *hdrin,
                                size_t *numout,
                                void *arg)
 {
-    OSSL_QUIC_FAULT *fault = arg;
+    QTEST_FAULT *fault = arg;
     size_t i, bufsz = 0;
     unsigned char *cur;
 
@@ -278,7 +542,7 @@ static int packet_plain_mutate(const QUIC_PKT_HDR *hdrin,
 
 static void packet_plain_finish(void *arg)
 {
-    OSSL_QUIC_FAULT *fault = arg;
+    QTEST_FAULT *fault = arg;
 
     /* Cast below is safe because we allocated the buffer */
     OPENSSL_free((unsigned char *)fault->pplainio.buf);
@@ -287,9 +551,9 @@ static void packet_plain_finish(void *arg)
     fault->pplainio.buf = NULL;
 }
 
-int ossl_quic_fault_set_packet_plain_listener(OSSL_QUIC_FAULT *fault,
-                                              ossl_quic_fault_on_packet_plain_cb pplaincb,
-                                              void *pplaincbarg)
+int qtest_fault_set_packet_plain_listener(QTEST_FAULT *fault,
+                                          qtest_fault_on_packet_plain_cb pplaincb,
+                                          void *pplaincbarg)
 {
     fault->pplaincb = pplaincb;
     fault->pplaincbarg = pplaincbarg;
@@ -301,7 +565,7 @@ int ossl_quic_fault_set_packet_plain_listener(OSSL_QUIC_FAULT *fault,
 }
 
 /* To be called from a packet_plain_listener callback */
-int ossl_quic_fault_resize_plain_packet(OSSL_QUIC_FAULT *fault, size_t newlen)
+int qtest_fault_resize_plain_packet(QTEST_FAULT *fault, size_t newlen)
 {
     unsigned char *buf;
     size_t oldlen = fault->pplainio.buf_len;
@@ -332,11 +596,43 @@ int ossl_quic_fault_resize_plain_packet(OSSL_QUIC_FAULT *fault, size_t newlen)
     return 1;
 }
 
+/*
+ * Prepend frame data into a packet. To be called from a packet_plain_listener
+ * callback
+ */
+int qtest_fault_prepend_frame(QTEST_FAULT *fault, const unsigned char *frame,
+                              size_t frame_len)
+{
+    unsigned char *buf;
+    size_t old_len;
+
+    /*
+     * Alloc'd size should always be non-zero, so if this fails we've been
+     * incorrectly called
+     */
+    if (fault->pplainbuf_alloc == 0)
+        return 0;
+
+    /* Cast below is safe because we allocated the buffer */
+    buf = (unsigned char *)fault->pplainio.buf;
+    old_len = fault->pplainio.buf_len;
+
+    /* Extend the size of the packet by the size of the new frame */
+    if (!TEST_true(qtest_fault_resize_plain_packet(fault,
+                                                   old_len + frame_len)))
+        return 0;
+
+    memmove(buf + frame_len, buf, old_len);
+    memcpy(buf, frame, frame_len);
+
+    return 1;
+}
+
 static int handshake_mutate(const unsigned char *msgin, size_t msginlen,
                             unsigned char **msgout, size_t *msgoutlen,
                             void *arg)
 {
-    OSSL_QUIC_FAULT *fault = arg;
+    QTEST_FAULT *fault = arg;
     unsigned char *buf;
     unsigned long payloadlen;
     unsigned int msgtype;
@@ -361,7 +657,7 @@ static int handshake_mutate(const unsigned char *msgin, size_t msginlen,
     switch (msgtype) {
     case SSL3_MT_ENCRYPTED_EXTENSIONS:
     {
-        OSSL_QF_ENCRYPTED_EXTENSIONS ee;
+        QTEST_ENCRYPTED_EXTENSIONS ee;
 
         if (fault->encextcb == NULL)
             break;
@@ -394,15 +690,15 @@ static int handshake_mutate(const unsigned char *msgin, size_t msginlen,
 
 static void handshake_finish(void *arg)
 {
-    OSSL_QUIC_FAULT *fault = arg;
+    QTEST_FAULT *fault = arg;
 
     OPENSSL_free(fault->handbuf);
     fault->handbuf = NULL;
 }
 
-int ossl_quic_fault_set_handshake_listener(OSSL_QUIC_FAULT *fault,
-                                           ossl_quic_fault_on_handshake_cb handshakecb,
-                                           void *handshakecbarg)
+int qtest_fault_set_handshake_listener(QTEST_FAULT *fault,
+                                       qtest_fault_on_handshake_cb handshakecb,
+                                       void *handshakecbarg)
 {
     fault->handshakecb = handshakecb;
     fault->handshakecbarg = handshakecbarg;
@@ -413,9 +709,9 @@ int ossl_quic_fault_set_handshake_listener(OSSL_QUIC_FAULT *fault,
                                                    fault);
 }
 
-int ossl_quic_fault_set_hand_enc_ext_listener(OSSL_QUIC_FAULT *fault,
-                                              ossl_quic_fault_on_enc_ext_cb encextcb,
-                                              void *encextcbarg)
+int qtest_fault_set_hand_enc_ext_listener(QTEST_FAULT *fault,
+                                          qtest_fault_on_enc_ext_cb encextcb,
+                                          void *encextcbarg)
 {
     fault->encextcb = encextcb;
     fault->encextcbarg = encextcbarg;
@@ -427,7 +723,7 @@ int ossl_quic_fault_set_hand_enc_ext_listener(OSSL_QUIC_FAULT *fault,
 }
 
 /* To be called from a handshake_listener callback */
-int ossl_quic_fault_resize_handshake(OSSL_QUIC_FAULT *fault, size_t newlen)
+int qtest_fault_resize_handshake(QTEST_FAULT *fault, size_t newlen)
 {
     unsigned char *buf;
     size_t oldlen = fault->handbuflen;
@@ -456,10 +752,10 @@ int ossl_quic_fault_resize_handshake(OSSL_QUIC_FAULT *fault, size_t newlen)
 }
 
 /* To be called from message specific listener callbacks */
-int ossl_quic_fault_resize_message(OSSL_QUIC_FAULT *fault, size_t newlen)
+int qtest_fault_resize_message(QTEST_FAULT *fault, size_t newlen)
 {
     /* First resize the underlying message */
-    if (!ossl_quic_fault_resize_handshake(fault, newlen + SSL3_HM_HEADER_LENGTH))
+    if (!qtest_fault_resize_handshake(fault, newlen + SSL3_HM_HEADER_LENGTH))
         return 0;
 
     /* Fixup the handshake message header */
@@ -470,9 +766,9 @@ int ossl_quic_fault_resize_message(OSSL_QUIC_FAULT *fault, size_t newlen)
     return 1;
 }
 
-int ossl_quic_fault_delete_extension(OSSL_QUIC_FAULT *fault,
-                                     unsigned int exttype, unsigned char *ext,
-                                     size_t *extlen)
+int qtest_fault_delete_extension(QTEST_FAULT *fault,
+                                 unsigned int exttype, unsigned char *ext,
+                                 size_t *extlen)
 {
     PACKET pkt, sub, subext;
     unsigned int type;
@@ -527,7 +823,7 @@ int ossl_quic_fault_delete_extension(OSSL_QUIC_FAULT *fault,
     if ((size_t)(end - start) + SSL3_HM_HEADER_LENGTH > msglen)
         return 0; /* Should not happen */
     msglen -= (end - start) + SSL3_HM_HEADER_LENGTH;
-    if (!ossl_quic_fault_resize_message(fault, msglen))
+    if (!qtest_fault_resize_message(fault, msglen))
         return 0;
 
     return 1;
@@ -543,21 +839,20 @@ static int pcipher_sendmmsg(BIO *b, BIO_MSG *msg, size_t stride,
                             size_t num_msg, uint64_t flags,
                             size_t *num_processed)
 {
-    OSSL_QUIC_FAULT *fault;
+    QTEST_FAULT *fault;
     BIO *next = BIO_next(b);
     ossl_ssize_t ret = 0;
-    BIO_MSG m;
     size_t i = 0, tmpnump;
     QUIC_PKT_HDR hdr;
     PACKET pkt;
-
-    m.data = NULL;
+    unsigned char *tmpdata;
 
     if (next == NULL)
         return 0;
 
     fault = BIO_get_data(b);
-    if (fault == NULL || fault->pciphercb == NULL)
+    if (fault == NULL
+            || (fault->pciphercb == NULL && fault->datagramcb == NULL))
         return BIO_sendmmsg(next, msg, stride, num_msg, flags, num_processed);
 
     if (num_msg == 0) {
@@ -566,45 +861,70 @@ static int pcipher_sendmmsg(BIO *b, BIO_MSG *msg, size_t stride,
     }
 
     for (i = 0; i < num_msg; ++i) {
-        m = BIO_MSG_N(msg, stride, i);
+        fault->msg = BIO_MSG_N(msg, stride, i);
 
         /* Take a copy of the data so that callbacks can modify it */
-        m.data = OPENSSL_memdup(m.data, m.data_len);
-        if (m.data == NULL)
-            return 0;
-
-        if (!PACKET_buf_init(&pkt, m.data, m.data_len))
+        tmpdata = OPENSSL_malloc(fault->msg.data_len + GROWTH_ALLOWANCE);
+        if (tmpdata == NULL)
             return 0;
+        memcpy(tmpdata, fault->msg.data, fault->msg.data_len);
+        fault->msg.data = tmpdata;
+        fault->msgalloc = fault->msg.data_len + GROWTH_ALLOWANCE;
+
+        if (fault->pciphercb != NULL) {
+            if (!PACKET_buf_init(&pkt, fault->msg.data, fault->msg.data_len))
+                return 0;
+
+            do {
+                if (!ossl_quic_wire_decode_pkt_hdr(&pkt,
+                        /*
+                         * TODO(QUIC SERVER):
+                         * Needs to be set to the actual short header CID length
+                         * when testing the server implementation.
+                         */
+                        0,
+                        1,
+                        0, &hdr, NULL))
+                    goto out;
+
+                /*
+                 * hdr.data is const - but its our buffer so casting away the
+                 * const is safe
+                 */
+                if (!fault->pciphercb(fault, &hdr, (unsigned char *)hdr.data,
+                                    hdr.len, fault->pciphercbarg))
+                    goto out;
+
+                /*
+                 * At the moment modifications to hdr by the callback
+                 * are ignored. We might need to rewrite the QUIC header to
+                 * enable tests to change this. We also don't yet have a
+                 * mechanism for the callback to change the encrypted data
+                 * length. It's not clear if that's needed or not.
+                 */
+            } while (PACKET_remaining(&pkt) > 0);
+        }
 
-        do {
-            if (!ossl_quic_wire_decode_pkt_hdr(&pkt,
-                    0/* TODO(QUIC): Not sure how this should be set*/, 1, &hdr,
-                    NULL))
-                goto out;
-
-            /* TODO(QUIC): Resolve const issue here */
-            if (!fault->pciphercb(fault, &hdr, (unsigned char *)hdr.data,
-                                  hdr.len, fault->pciphercbarg))
-                goto out;
-        } while (PACKET_remaining(&pkt) > 0);
+        if (fault->datagramcb != NULL
+                && !fault->datagramcb(fault, &fault->msg, stride,
+                                      fault->datagramcbarg))
+            goto out;
 
-        if (!BIO_sendmmsg(next, &m, stride, 1, flags, &tmpnump)) {
+        if (!BIO_sendmmsg(next, &fault->msg, stride, 1, flags, &tmpnump)) {
             *num_processed = i;
             goto out;
         }
 
-        OPENSSL_free(m.data);
-        m.data = NULL;
+        OPENSSL_free(fault->msg.data);
+        fault->msg.data = NULL;
+        fault->msgalloc = 0;
     }
 
     *num_processed = i;
-    ret = 1;
 out:
-    if (i > 0)
-        ret = 1;
-    else
-        ret = 0;
-    OPENSSL_free(m.data);
+    ret = i > 0;
+    OPENSSL_free(fault->msg.data);
+    fault->msg.data = NULL;
     return ret;
 }
 
@@ -618,7 +938,7 @@ static long pcipher_ctrl(BIO *b, int cmd, long larg, void *parg)
     return BIO_ctrl(next, cmd, larg, parg);
 }
 
-static BIO_METHOD *get_bio_method(void)
+BIO_METHOD *qtest_get_bio_method(void)
 {
     BIO_METHOD *tmp;
 
@@ -641,12 +961,37 @@ static BIO_METHOD *get_bio_method(void)
     return pcipherbiometh;
 }
 
-int ossl_quic_fault_set_packet_cipher_listener(OSSL_QUIC_FAULT *fault,
-                                               ossl_quic_fault_on_packet_cipher_cb pciphercb,
-                                               void *pciphercbarg)
+int qtest_fault_set_packet_cipher_listener(QTEST_FAULT *fault,
+                                           qtest_fault_on_packet_cipher_cb pciphercb,
+                                           void *pciphercbarg)
 {
     fault->pciphercb = pciphercb;
     fault->pciphercbarg = pciphercbarg;
 
     return 1;
-}
\ No newline at end of file
+}
+
+int qtest_fault_set_datagram_listener(QTEST_FAULT *fault,
+                                      qtest_fault_on_datagram_cb datagramcb,
+                                      void *datagramcbarg)
+{
+    fault->datagramcb = datagramcb;
+    fault->datagramcbarg = datagramcbarg;
+
+    return 1;
+}
+
+/* To be called from a datagram_listener callback */
+int qtest_fault_resize_datagram(QTEST_FAULT *fault, size_t newlen)
+{
+    if (newlen > fault->msgalloc)
+            return 0;
+
+    if (newlen > fault->msg.data_len)
+        memset((unsigned char *)fault->msg.data + fault->msg.data_len, 0,
+                newlen - fault->msg.data_len);
+
+    fault->msg.data_len = newlen;
+
+    return 1;
+}