]> git.ipfire.org Git - thirdparty/openssl.git/blob - test/helpers/quictestlib.c
Implement the QUIC Fault injector support for TLS handshake messages
[thirdparty/openssl.git] / test / helpers / quictestlib.c
1 /*
2 * Copyright 2022 The OpenSSL Project Authors. All Rights Reserved.
3 *
4 * Licensed under the Apache License 2.0 (the "License"). You may not use
5 * this file except in compliance with the License. You can obtain a copy
6 * in the file LICENSE in the source distribution or at
7 * https://www.openssl.org/source/license.html
8 */
9
10 #include <assert.h>
11 #include "quictestlib.h"
12 #include "../testutil.h"
13 #include "internal/quic_wire_pkt.h"
14 #include "internal/quic_record_tx.h"
15 #include "internal/packet.h"
16
17 #define GROWTH_ALLOWANCE 1024
18
19 struct ossl_quic_fault {
20 QUIC_TSERVER *qtserv;
21
22 /* Plain packet mutations */
23 /* Header for the plaintext packet */
24 QUIC_PKT_HDR pplainhdr;
25 /* iovec for the plaintext packet data buffer */
26 OSSL_QTX_IOVEC pplainio;
27 /* Allocted size of the plaintext packet data buffer */
28 size_t pplainbuf_alloc;
29 ossl_quic_fault_on_packet_plain_cb pplaincb;
30 void *pplaincbarg;
31
32 /* Handshake message mutations */
33 /* Handshake message buffer */
34 unsigned char *handbuf;
35 /* Allocated size of the handshake message buffer */
36 size_t handbufalloc;
37 /* Actual length of the handshake message */
38 size_t handbuflen;
39 ossl_quic_fault_on_handshake_cb handshakecb;
40 void *handshakecbarg;
41 ossl_quic_fault_on_enc_ext_cb encextcb;
42 void *encextcbarg;
43 };
44
45 static void packet_plain_finish(void *arg);
46 static void handshake_finish(void *arg);
47
48 int qtest_create_quic_objects(SSL_CTX *clientctx, char *certfile, char *keyfile,
49 QUIC_TSERVER **qtserv, SSL **cssl,
50 OSSL_QUIC_FAULT **fault)
51 {
52 /* ALPN value as recognised by QUIC_TSERVER */
53 unsigned char alpn[] = { 8, 'o', 's', 's', 'l', 't', 'e', 's', 't' };
54 QUIC_TSERVER_ARGS tserver_args = {0};
55 BIO *bio1 = NULL, *bio2 = NULL;
56 BIO_ADDR *peeraddr = NULL;
57 struct in_addr ina = {0};
58
59 *qtserv = NULL;
60 if (fault != NULL)
61 *fault = NULL;
62 *cssl = SSL_new(clientctx);
63 if (!TEST_ptr(*cssl))
64 return 0;
65
66 if (!TEST_true(SSL_set_blocking_mode(*cssl, 0)))
67 goto err;
68
69 /* SSL_set_alpn_protos returns 0 for success! */
70 if (!TEST_false(SSL_set_alpn_protos(*cssl, alpn, sizeof(alpn))))
71 goto err;
72
73 if (!TEST_true(BIO_new_bio_dgram_pair(&bio1, 0, &bio2, 0)))
74 goto err;
75
76 if (!TEST_true(BIO_dgram_set_caps(bio1, BIO_DGRAM_CAP_HANDLES_DST_ADDR))
77 || !TEST_true(BIO_dgram_set_caps(bio2, BIO_DGRAM_CAP_HANDLES_DST_ADDR)))
78 goto err;
79
80 SSL_set_bio(*cssl, bio1, bio1);
81
82 if (!TEST_ptr(peeraddr = BIO_ADDR_new()))
83 goto err;
84
85 /* Dummy server address */
86 if (!TEST_true(BIO_ADDR_rawmake(peeraddr, AF_INET, &ina, sizeof(ina),
87 htons(0))))
88 goto err;
89
90 if (!TEST_true(SSL_set_initial_peer_addr(*cssl, peeraddr)))
91 goto err;
92
93 /* 2 refs are passed for bio2 */
94 if (!BIO_up_ref(bio2))
95 goto err;
96 tserver_args.net_rbio = bio2;
97 tserver_args.net_wbio = bio2;
98
99 if (!TEST_ptr(*qtserv = ossl_quic_tserver_new(&tserver_args, certfile,
100 keyfile))) {
101 /* We hold 2 refs to bio2 at the moment */
102 BIO_free(bio2);
103 goto err;
104 }
105 /* Ownership of bio2 is now held by *qtserv */
106 bio2 = NULL;
107
108 if (fault != NULL) {
109 *fault = OPENSSL_zalloc(sizeof(**fault));
110 if (*fault == NULL)
111 goto err;
112
113 (*fault)->qtserv = *qtserv;
114 }
115
116 BIO_ADDR_free(peeraddr);
117
118 return 1;
119 err:
120 BIO_ADDR_free(peeraddr);
121 BIO_free(bio1);
122 BIO_free(bio2);
123 SSL_free(*cssl);
124 ossl_quic_tserver_free(*qtserv);
125 if (fault != NULL)
126 OPENSSL_free(*fault);
127
128 return 0;
129 }
130
131 #define MAXLOOPS 1000
132
133 int qtest_create_quic_connection(QUIC_TSERVER *qtserv, SSL *clientssl)
134 {
135 int retc = -1, rets = 0, err, abortctr = 0, ret = 0;
136 int clienterr = 0, servererr = 0;
137
138 do {
139 err = SSL_ERROR_WANT_WRITE;
140 while (!clienterr && retc <= 0 && err == SSL_ERROR_WANT_WRITE) {
141 retc = SSL_connect(clientssl);
142 if (retc <= 0)
143 err = SSL_get_error(clientssl, retc);
144 }
145
146 if (!clienterr && retc <= 0 && err != SSL_ERROR_WANT_READ) {
147 TEST_info("SSL_connect() failed %d, %d", retc, err);
148 TEST_openssl_errors();
149 clienterr = 1;
150 }
151
152 /*
153 * We're cheating. We don't take any notice of SSL_get_tick_timeout()
154 * and tick everytime around the loop anyway. This is inefficient. We
155 * can get away with it in test code because we control both ends of
156 * the communications and don't expect network delays. This shouldn't
157 * be done in a real application.
158 */
159 if (!clienterr)
160 SSL_tick(clientssl);
161 if (!servererr) {
162 ossl_quic_tserver_tick(qtserv);
163 servererr = ossl_quic_tserver_is_term_any(qtserv, NULL);
164 if (!servererr && !rets)
165 rets = ossl_quic_tserver_is_connected(qtserv);
166 }
167
168 if (clienterr && servererr)
169 goto err;
170
171 if (++abortctr == MAXLOOPS) {
172 TEST_info("No progress made");
173 goto err;
174 }
175 } while (retc <=0 || rets <= 0);
176
177 ret = 1;
178 err:
179 return ret;
180 }
181
182 void ossl_quic_fault_free(OSSL_QUIC_FAULT *fault)
183 {
184 if (fault == NULL)
185 return;
186
187 packet_plain_finish(fault);
188 handshake_finish(fault);
189
190 OPENSSL_free(fault);
191 }
192
193 static int packet_plain_mutate(const QUIC_PKT_HDR *hdrin,
194 const OSSL_QTX_IOVEC *iovecin, size_t numin,
195 QUIC_PKT_HDR **hdrout,
196 const OSSL_QTX_IOVEC **iovecout,
197 size_t *numout,
198 void *arg)
199 {
200 OSSL_QUIC_FAULT *fault = arg;
201 size_t i, bufsz = 0;
202 unsigned char *cur;
203
204 /* Coalesce our data into a single buffer */
205
206 /* First calculate required buffer size */
207 for (i = 0; i < numin; i++)
208 bufsz += iovecin[i].buf_len;
209
210 fault->pplainio.buf_len = bufsz;
211
212 /* Add an allowance for possible growth */
213 bufsz += GROWTH_ALLOWANCE;
214
215 fault->pplainio.buf = cur = OPENSSL_malloc(bufsz);
216 if (cur == NULL) {
217 fault->pplainio.buf_len = 0;
218 return 0;
219 }
220
221 fault->pplainbuf_alloc = bufsz;
222
223 /* Copy in the data from the input buffers */
224 for (i = 0; i < numin; i++) {
225 memcpy(cur, iovecin[i].buf, iovecin[i].buf_len);
226 cur += iovecin[i].buf_len;
227 }
228
229 fault->pplainhdr = *hdrin;
230
231 /* Cast below is safe because we allocated the buffer */
232 if (fault->pplaincb != NULL
233 && !fault->pplaincb(fault, &fault->pplainhdr,
234 (unsigned char *)fault->pplainio.buf,
235 fault->pplainio.buf_len, fault->pplaincbarg))
236 return 0;
237
238 *hdrout = &fault->pplainhdr;
239 *iovecout = &fault->pplainio;
240 *numout = 1;
241
242 return 1;
243 }
244
245 static void packet_plain_finish(void *arg)
246 {
247 OSSL_QUIC_FAULT *fault = arg;
248
249 /* Cast below is safe because we allocated the buffer */
250 OPENSSL_free((unsigned char *)fault->pplainio.buf);
251 fault->pplainio.buf_len = 0;
252 fault->pplainbuf_alloc = 0;
253 fault->pplainio.buf = NULL;
254 }
255
256 int ossl_quic_fault_set_packet_plain_listener(OSSL_QUIC_FAULT *fault,
257 ossl_quic_fault_on_packet_plain_cb pplaincb,
258 void *pplaincbarg)
259 {
260 fault->pplaincb = pplaincb;
261 fault->pplaincbarg = pplaincbarg;
262
263 return ossl_quic_tserver_set_plain_packet_mutator(fault->qtserv,
264 packet_plain_mutate,
265 packet_plain_finish,
266 fault);
267 }
268
269 /* To be called from a packet_plain_listener callback */
270 int ossl_quic_fault_resize_plain_packet(OSSL_QUIC_FAULT *fault, size_t newlen)
271 {
272 unsigned char *buf;
273 size_t oldlen = fault->pplainio.buf_len;
274
275 /*
276 * Alloc'd size should always be non-zero, so if this fails we've been
277 * incorrectly called
278 */
279 if (fault->pplainbuf_alloc == 0)
280 return 0;
281
282 if (newlen > fault->pplainbuf_alloc) {
283 /* This exceeds our growth allowance. Fail */
284 return 0;
285 }
286
287 /* Cast below is safe because we allocated the buffer */
288 buf = (unsigned char *)fault->pplainio.buf;
289
290 if (newlen > oldlen) {
291 /* Extend packet with 0 bytes */
292 memset(buf + oldlen, 0, newlen - oldlen);
293 } /* else we're truncating or staying the same */
294
295 fault->pplainio.buf_len = newlen;
296 fault->pplainhdr.len = newlen;
297
298 return 1;
299 }
300
301 static int handshake_mutate(const unsigned char *msgin, size_t msginlen,
302 unsigned char **msgout, size_t *msgoutlen,
303 void *arg)
304 {
305 OSSL_QUIC_FAULT *fault = arg;
306 unsigned char *buf;
307 unsigned long payloadlen;
308 unsigned int msgtype;
309 PACKET pkt;
310
311 buf = OPENSSL_malloc(msginlen + GROWTH_ALLOWANCE);
312 if (buf == NULL)
313 return 0;
314
315 fault->handbuf = buf;
316 fault->handbuflen = msginlen;
317 fault->handbufalloc = msginlen + GROWTH_ALLOWANCE;
318 memcpy(buf, msgin, msginlen);
319
320 if (!PACKET_buf_init(&pkt, buf, msginlen)
321 || !PACKET_get_1(&pkt, &msgtype)
322 || !PACKET_get_net_3(&pkt, &payloadlen)
323 || PACKET_remaining(&pkt) != payloadlen)
324 return 0;
325
326 /* Parse specific message types */
327 switch (msgtype) {
328 case SSL3_MT_ENCRYPTED_EXTENSIONS:
329 {
330 OSSL_QF_ENCRYPTED_EXTENSIONS ee;
331
332 if (fault->encextcb == NULL)
333 break;
334
335 /*
336 * The EncryptedExtensions message is very simple. It just has an
337 * extensions block in it and nothing else.
338 */
339 ee.extensions = (unsigned char *)PACKET_data(&pkt);
340 ee.extensionslen = payloadlen;
341 if (!fault->encextcb(fault, &ee, payloadlen, fault->encextcbarg))
342 return 0;
343 }
344
345 default:
346 /* No specific handlers for these message types yet */
347 break;
348 }
349
350 if (fault->handshakecb != NULL
351 && !fault->handshakecb(fault, buf, fault->handbuflen,
352 fault->handshakecbarg))
353 return 0;
354
355 *msgout = buf;
356 *msgoutlen = fault->handbuflen;
357
358 return 1;
359 }
360
361 static void handshake_finish(void *arg)
362 {
363 OSSL_QUIC_FAULT *fault = arg;
364
365 OPENSSL_free(fault->handbuf);
366 fault->handbuf = NULL;
367 }
368
369 int ossl_quic_fault_set_handshake_listener(OSSL_QUIC_FAULT *fault,
370 ossl_quic_fault_on_handshake_cb handshakecb,
371 void *handshakecbarg)
372 {
373 fault->handshakecb = handshakecb;
374 fault->handshakecbarg = handshakecbarg;
375
376 return ossl_quic_tserver_set_handshake_mutator(fault->qtserv,
377 handshake_mutate,
378 handshake_finish,
379 fault);
380 }
381
382 int ossl_quic_fault_set_hand_enc_ext_listener(OSSL_QUIC_FAULT *fault,
383 ossl_quic_fault_on_enc_ext_cb encextcb,
384 void *encextcbarg)
385 {
386 fault->encextcb = encextcb;
387 fault->encextcbarg = encextcbarg;
388
389 return ossl_quic_tserver_set_handshake_mutator(fault->qtserv,
390 handshake_mutate,
391 handshake_finish,
392 fault);
393 }
394
395 /* To be called from a handshake_listener callback */
396 int ossl_quic_fault_resize_handshake(OSSL_QUIC_FAULT *fault, size_t newlen)
397 {
398 unsigned char *buf;
399 size_t oldlen = fault->handbuflen;
400
401 /*
402 * Alloc'd size should always be non-zero, so if this fails we've been
403 * incorrectly called
404 */
405 if (fault->handbufalloc == 0)
406 return 0;
407
408 if (newlen > fault->handbufalloc) {
409 /* This exceeds our growth allowance. Fail */
410 return 0;
411 }
412
413 buf = (unsigned char *)fault->handbuf;
414
415 if (newlen > oldlen) {
416 /* Extend packet with 0 bytes */
417 memset(buf + oldlen, 0, newlen - oldlen);
418 } /* else we're truncating or staying the same */
419
420 fault->handbuflen = newlen;
421 return 1;
422 }
423
424 /* To be called from message specific listener callbacks */
425 int ossl_quic_fault_resize_message(OSSL_QUIC_FAULT *fault, size_t newlen)
426 {
427 /* First resize the underlying message */
428 if (!ossl_quic_fault_resize_handshake(fault, newlen + SSL3_HM_HEADER_LENGTH))
429 return 0;
430
431 /* Fixup the handshake message header */
432 fault->handbuf[1] = (unsigned char)((newlen >> 16) & 0xff);
433 fault->handbuf[2] = (unsigned char)((newlen >> 8) & 0xff);
434 fault->handbuf[3] = (unsigned char)((newlen ) & 0xff);
435
436 return 1;
437 }
438
439 int ossl_quic_fault_delete_extension(OSSL_QUIC_FAULT *fault,
440 unsigned int exttype, unsigned char *ext,
441 size_t *extlen, size_t *msglen)
442 {
443 PACKET pkt, sub, subext;
444 unsigned int type;
445 const unsigned char *start, *end;
446 size_t newlen;
447
448 if (!PACKET_buf_init(&pkt, ext, *extlen))
449 return 0;
450
451 /* Extension block starts with 2 bytes for extension block length */
452 if (!PACKET_as_length_prefixed_2(&pkt, &sub))
453 return 0;
454
455 do {
456 start = PACKET_data(&sub);
457 if (!PACKET_get_net_2(&sub, &type)
458 || !PACKET_as_length_prefixed_2(&sub, &subext))
459 return 0;
460 } while (type != exttype);
461
462 /* Found it */
463 end = PACKET_data(&sub);
464
465 /*
466 * If we're not the last extension we need to move the rest earlier. The
467 * cast below is safe because we own the underlying buffer and we're no
468 * longer making PACKET calls.
469 */
470 if (end < ext + *extlen)
471 memmove((unsigned char *)start, end, end - start);
472
473 /*
474 * Calculate new extensions payload length =
475 * Original length
476 * - 2 extension block length bytes
477 * - length of removed extension
478 */
479 newlen = *extlen - 2 - (end - start);
480
481 /* Fixup the length bytes for the extension block */
482 ext[0] = (unsigned char)((newlen >> 8) & 0xff);
483 ext[1] = (unsigned char)((newlen ) & 0xff);
484
485 /*
486 * Length of the whole extension block is the new payload length plus the
487 * 2 bytes for the length
488 */
489 *extlen = newlen + 2;
490
491 /* We can now resize the message */
492 *msglen -= (end - start);
493 if (!ossl_quic_fault_resize_message(fault, *msglen))
494 return 0;
495
496 return 1;
497 }