]> git.ipfire.org Git - thirdparty/strongswan.git/blob - scripts/tls_test.c
Update copyright headers after acquisition by secunet
[thirdparty/strongswan.git] / scripts / tls_test.c
1 /*
2 * Copyright (C) 2020 Pascal Knecht
3 * Copyright (C) 2020 Tobias Brunner
4 * Copyright (C) 2010 Martin Willi
5 *
6 * Copyright (C) secunet Security Networks AG
7 *
8 * This program is free software; you can redistribute it and/or modify it
9 * under the terms of the GNU General Public License as published by the
10 * Free Software Foundation; either version 2 of the License, or (at your
11 * option) any later version. See <http://www.fsf.org/copyleft/gpl.txt>.
12 *
13 * This program is distributed in the hope that it will be useful, but
14 * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
15 * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
16 * for more details.
17 */
18
19 #include <unistd.h>
20 #include <stdio.h>
21 #include <sys/types.h>
22 #include <sys/socket.h>
23 #include <getopt.h>
24 #include <errno.h>
25 #include <string.h>
26
27 #include <library.h>
28 #include <utils/debug.h>
29 #include <tls_socket.h>
30 #include <networking/host.h>
31 #include <credentials/sets/mem_cred.h>
32
/**
 * Print usage information.
 *
 * @param out		stream to write the help text to
 * @param cmd		program name shown in the synopsis lines (argv[0])
 */
static void usage(FILE *out, char *cmd)
{
	fprintf(out, "usage:\n");
	/* the private key option takes a file name, as listed under options */
	fprintf(out, " %s --connect <address> --port <port> [--key <file>] [--cert <file>] [--cacert <file>]+ [--times <n>]\n", cmd);
	fprintf(out, " %s --listen <address> --port <port> --key <file> --cert <file> [--cacert <file>]+ [--auth-optional] [--times <n>]\n", cmd);
	fprintf(out, "\n");
	fprintf(out, "options:\n");
	fprintf(out, " --help print help and exit\n");
	fprintf(out, " --connect <address> connect to a server on dns name or ip address\n");
	fprintf(out, " --listen <address> listen on dns name or ip address\n");
	fprintf(out, " --port <port> specify the port to use\n");
	fprintf(out, " --cert <file> certificate to authenticate itself\n");
	fprintf(out, " --key <file> private key to authenticate itself\n");
	fprintf(out, " --cacert <file> certificate to verify other peer\n");
	fprintf(out, " --identity <id> optional remote identity to enforce\n");
	fprintf(out, " --auth-optional don't enforce client authentication\n");
	fprintf(out, " --times <n> specify the amount of repeated connection establishments\n");
	fprintf(out, " --ipv4 use IPv4\n");
	fprintf(out, " --ipv6 use IPv6\n");
	fprintf(out, " --min-version <version> specify the minimum TLS version, supported versions:\n");
	fprintf(out, " 1.0 (default), 1.1, 1.2 and 1.3\n");
	fprintf(out, " --max-version <version> specify the maximum TLS version, supported versions:\n");
	fprintf(out, " 1.0, 1.1, 1.2 and 1.3 (default)\n");
	fprintf(out, " --version <version> set one specific TLS version to use, supported versions:\n");
	fprintf(out, " 1.0, 1.1, 1.2 and 1.3\n");
	fprintf(out, " --debug <debug level> set debug level, default is 1\n");
}
63
64 /**
65 * Check, as client, if we have a client certificate with private key
66 */
67 static identification_t *find_client_id()
68 {
69 identification_t *client = NULL, *keyid;
70 enumerator_t *enumerator;
71 certificate_t *cert;
72 public_key_t *pubkey;
73 private_key_t *privkey;
74 chunk_t chunk;
75
76 enumerator = lib->credmgr->create_cert_enumerator(lib->credmgr,
77 CERT_X509, KEY_ANY, NULL, FALSE);
78 while (enumerator->enumerate(enumerator, &cert))
79 {
80 pubkey = cert->get_public_key(cert);
81 if (pubkey)
82 {
83 if (pubkey->get_fingerprint(pubkey, KEYID_PUBKEY_SHA1, &chunk))
84 {
85 keyid = identification_create_from_encoding(ID_KEY_ID, chunk);
86 privkey = lib->credmgr->get_private(lib->credmgr,
87 pubkey->get_type(pubkey), keyid, NULL);
88 keyid->destroy(keyid);
89 if (privkey)
90 {
91 client = cert->get_subject(cert);
92 client = client->clone(client);
93 privkey->destroy(privkey);
94 }
95 }
96 pubkey->destroy(pubkey);
97 }
98 if (client)
99 {
100 break;
101 }
102 }
103 enumerator->destroy(enumerator);
104
105 return client;
106 }
107
108 /**
109 * Client routine
110 */
111 static int run_client(host_t *host, identification_t *server,
112 identification_t *client, int times, tls_cache_t *cache,
113 tls_version_t min_version, tls_version_t max_version,
114 tls_flag_t flags)
115 {
116 tls_socket_t *tls;
117 int fd, res;
118
119 while (times == -1 || times-- > 0)
120 {
121 DBG2(DBG_TLS, "connecting to %#H", host);
122 fd = socket(host->get_family(host), SOCK_STREAM, 0);
123 if (fd == -1)
124 {
125 DBG1(DBG_TLS, "opening socket failed: %s", strerror(errno));
126 return 1;
127 }
128 if (connect(fd, host->get_sockaddr(host),
129 *host->get_sockaddr_len(host)) == -1)
130 {
131 DBG1(DBG_TLS, "connecting to %#H failed: %s", host, strerror(errno));
132 close(fd);
133 return 1;
134 }
135 tls = tls_socket_create(FALSE, server, client, fd, cache, min_version,
136 max_version, flags);
137 if (!tls)
138 {
139 close(fd);
140 return 1;
141 }
142 res = tls->splice(tls, 0, 1) ? 0 : 1;
143 tls->destroy(tls);
144 close(fd);
145 if (res)
146 {
147 break;
148 }
149 }
150 return res;
151 }
152
153 /**
154 * Server routine
155 */
156 static int serve(host_t *host, identification_t *server, identification_t *client,
157 int times, tls_cache_t *cache, tls_version_t min_version,
158 tls_version_t max_version, tls_flag_t flags)
159 {
160 tls_socket_t *tls;
161 int fd, cfd;
162
163 fd = socket(AF_INET, SOCK_STREAM, 0);
164 if (fd == -1)
165 {
166 DBG1(DBG_TLS, "opening socket failed: %s", strerror(errno));
167 return 1;
168 }
169 if (bind(fd, host->get_sockaddr(host),
170 *host->get_sockaddr_len(host)) == -1)
171 {
172 DBG1(DBG_TLS, "binding to %#H failed: %s", host, strerror(errno));
173 close(fd);
174 return 1;
175 }
176 if (listen(fd, 1) == -1)
177 {
178 DBG1(DBG_TLS, "listen to %#H failed: %m", host, strerror(errno));
179 close(fd);
180 return 1;
181 }
182
183 while (times == -1 || times-- > 0)
184 {
185 cfd = accept(fd, host->get_sockaddr(host), host->get_sockaddr_len(host));
186 if (cfd == -1)
187 {
188 DBG1(DBG_TLS, "accept failed: %s", strerror(errno));
189 close(fd);
190 return 1;
191 }
192 DBG1(DBG_TLS, "%#H connected", host);
193
194 tls = tls_socket_create(TRUE, server, client, cfd, cache, min_version,
195 max_version, flags);
196 if (!tls)
197 {
198 close(fd);
199 return 1;
200 }
201 tls->splice(tls, 0, 1);
202 DBG1(DBG_TLS, "%#H disconnected", host);
203 tls->destroy(tls);
204 }
205 close(fd);
206
207 return 0;
208 }
209
210 /**
211 * In-Memory credential set
212 */
213 static mem_cred_t *creds;
214
215 /**
216 * Load certificate from file
217 */
218 static bool load_certificate(char *filename)
219 {
220 certificate_t *cert;
221
222 cert = lib->creds->create(lib->creds, CRED_CERTIFICATE, CERT_X509,
223 BUILD_FROM_FILE, filename, BUILD_END);
224 if (!cert)
225 {
226 DBG1(DBG_TLS, "loading certificate from '%s' failed", filename);
227 return FALSE;
228 }
229 creds->add_cert(creds, TRUE, cert);
230 return TRUE;
231 }
232
233 /**
234 * Load private key from file
235 */
236 static bool load_key(char *filename)
237 {
238 private_key_t *key;
239
240 key = lib->creds->create(lib->creds, CRED_PRIVATE_KEY, KEY_ANY,
241 BUILD_FROM_FILE, filename, BUILD_END);
242 if (!key)
243 {
244 DBG1(DBG_TLS, "loading key from '%s' failed", filename);
245 return FALSE;
246 }
247 creds->add_key(creds, key);
248 return TRUE;
249 }
250
251 /**
252 * TLS debug level
253 */
254 static level_t tls_level = 1;
255
256 static void dbg_tls(debug_t group, level_t level, char *fmt, ...)
257 {
258 if ((group == DBG_TLS && level <= tls_level) || level <= 1)
259 {
260 va_list args;
261
262 va_start(args, fmt);
263 vfprintf(stderr, fmt, args);
264 fprintf(stderr, "\n");
265 va_end(args);
266 }
267 }
268
269 /**
270 * Cleanup
271 */
272 static void cleanup()
273 {
274 lib->credmgr->remove_set(lib->credmgr, &creds->set);
275 creds->destroy(creds);
276 library_deinit();
277 }
278
279 /**
280 * Initialize library
281 */
282 static void init()
283 {
284 char *plugins;
285
286 library_init(NULL, "tls_test");
287
288 dbg = dbg_tls;
289
290 plugins = getenv("PLUGINS") ?: PLUGINS;
291 lib->plugins->load(lib->plugins, plugins);
292
293 creds = mem_cred_create();
294 lib->credmgr->add_set(lib->credmgr, &creds->set);
295
296 atexit(cleanup);
297 }
298
299 int main(int argc, char *argv[])
300 {
301 char *address = NULL;
302 bool listen = FALSE;
303 int port = 0, times = -1, res, family = AF_UNSPEC;
304 identification_t *server, *client = NULL, *identity = NULL;
305 tls_version_t min_version = TLS_SUPPORTED_MIN, max_version = TLS_SUPPORTED_MAX;
306 tls_flag_t flags = TLS_FLAG_ENCRYPTION_OPTIONAL;
307 tls_cache_t *cache;
308 host_t *host;
309
310 init();
311
312 while (TRUE)
313 {
314 struct option long_opts[] = {
315 {"help", no_argument, NULL, 'h' },
316 {"connect", required_argument, NULL, 'c' },
317 {"listen", required_argument, NULL, 'l' },
318 {"port", required_argument, NULL, 'p' },
319 {"cert", required_argument, NULL, 'x' },
320 {"key", required_argument, NULL, 'k' },
321 {"cacert", required_argument, NULL, 'f' },
322 {"times", required_argument, NULL, 't' },
323 {"ipv4", no_argument, NULL, '4' },
324 {"ipv6", no_argument, NULL, '6' },
325 {"min-version", required_argument, NULL, 'm' },
326 {"max-version", required_argument, NULL, 'M' },
327 {"version", required_argument, NULL, 'v' },
328 {"auth-optional", no_argument, NULL, 'n' },
329 {"identity", required_argument, NULL, 'i' },
330 {"debug", required_argument, NULL, 'd' },
331 {0,0,0,0 }
332 };
333 switch (getopt_long(argc, argv, "", long_opts, NULL))
334 {
335 case EOF:
336 break;
337 case 'h':
338 usage(stdout, argv[0]);
339 return 0;
340 case 'x':
341 if (!load_certificate(optarg))
342 {
343 return 1;
344 }
345 continue;
346 case 'k':
347 if (!load_key(optarg))
348 {
349 return 1;
350 }
351 continue;
352 case 'f':
353 if (!load_certificate(optarg))
354 {
355 return 1;
356 }
357 client = identification_create_from_encoding(ID_ANY, chunk_empty);
358 continue;
359 case 'i':
360 identity = identification_create_from_string(optarg);
361 if (!identity)
362 {
363 return 1;
364 }
365 continue;
366 case 'l':
367 listen = TRUE;
368 /* fall */
369 case 'c':
370 if (address)
371 {
372 usage(stderr, argv[0]);
373 return 1;
374 }
375 address = optarg;
376 continue;
377 case 'p':
378 port = atoi(optarg);
379 continue;
380 case 't':
381 times = atoi(optarg);
382 continue;
383 case 'd':
384 tls_level = atoi(optarg);
385 continue;
386 case '4':
387 family = AF_INET;
388 continue;
389 case '6':
390 family = AF_INET6;
391 continue;
392 case 'm':
393 if (!enum_from_name(tls_numeric_version_names, optarg,
394 &min_version))
395 {
396 fprintf(stderr, "unknown minimum TLS version: %s\n", optarg);
397 return 1;
398 }
399 continue;
400 case 'M':
401 if (!enum_from_name(tls_numeric_version_names, optarg,
402 &max_version))
403 {
404 fprintf(stderr, "unknown maximum TLS version: %s\n", optarg);
405 return 1;
406 }
407 continue;
408 case 'v':
409 if (!enum_from_name(tls_numeric_version_names, optarg,
410 &min_version))
411 {
412 fprintf(stderr, "unknown TLS version: %s\n", optarg);
413 return 1;
414 }
415 max_version = min_version;
416 continue;
417 case 'n':
418 flags |= TLS_FLAG_CLIENT_AUTH_OPTIONAL;
419 continue;
420 default:
421 usage(stderr, argv[0]);
422 return 1;
423 }
424 break;
425 }
426 if (!port || !address)
427 {
428 usage(stderr, argv[0]);
429 return 1;
430 }
431 host = host_create_from_dns(address, family, port);
432 if (!host)
433 {
434 DBG1(DBG_TLS, "resolving hostname %s failed", address);
435 return 1;
436 }
437 server = identification_create_from_string(address);
438 cache = tls_cache_create(100, 30);
439 if (listen)
440 {
441 res = serve(host, server, identity ?: client, times, cache, min_version,
442 max_version, flags);
443 }
444 else
445 {
446 DESTROY_IF(client);
447 client = find_client_id();
448 res = run_client(host, identity ?: server, client, times, cache, min_version,
449 max_version, flags);
450 DESTROY_IF(client);
451 }
452 cache->destroy(cache);
453 host->destroy(host);
454 server->destroy(server);
455 DESTROY_IF(identity);
456 return res;
457 }