]> git.ipfire.org Git - people/ms/strongswan.git/blob - src/libtls/tls_socket.c
tls-socket: Handle sending fatal errors better
[people/ms/strongswan.git] / src / libtls / tls_socket.c
1 /*
2 * Copyright (C) 2010 Martin Willi
3 * Copyright (C) 2010 revosec AG
4 *
5 * This program is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License as published by the
7 * Free Software Foundation; either version 2 of the License, or (at your
8 * option) any later version. See <http://www.fsf.org/copyleft/gpl.txt>.
9 *
10 * This program is distributed in the hope that it will be useful, but
11 * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
12 * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
13 * for more details.
14 */
15
16 #include "tls_socket.h"
17
18 #include <unistd.h>
19 #include <errno.h>
20
21 #include <utils/debug.h>
22 #include <threading/thread.h>
23
/**
 * Buffer size for plain side I/O
 */
#define PLAIN_BUF_SIZE TLS_MAX_FRAGMENT_LEN

/**
 * Buffer size for encrypted side I/O: a maximum plaintext fragment plus 2048
 * bytes of headroom (presumably for TLS record expansion, which RFC 5246
 * bounds to 2^14 + 2048 bytes of ciphertext).
 *
 * The expansion is parenthesized so the macro stays correct when used inside
 * larger expressions (e.g. 2 * CRYPTO_BUF_SIZE).
 */
#define CRYPTO_BUF_SIZE (TLS_MAX_FRAGMENT_LEN + 2048)
33
typedef struct private_tls_socket_t private_tls_socket_t;
typedef struct private_tls_application_t private_tls_application_t;

/**
 * Application layer plugged into the TLS stack. It moves plaintext between
 * the TLS stack and the read_()/write_() callers of the socket.
 *
 * NOTE(review): "application" must remain the first member. A pointer to it
 * is handed to tls_create(), and the process()/build() implementations
 * receive it back as a private_tls_application_t (presumably via the METHOD()
 * macro). Do not reorder.
 */
struct private_tls_application_t {

	/**
	 * Implements tls_application layer.
	 */
	tls_application_t application;

	/**
	 * Plaintext to send; set by write_() to the caller's buffer and handed
	 * to the TLS stack as a whole by build().
	 */
	chunk_t out;

	/**
	 * Number of bytes of out handed to the TLS stack (0 or out.len).
	 */
	size_t out_done;

	/**
	 * Buffer receiving decrypted data; set by read_() to the caller's buffer.
	 */
	chunk_t in;

	/**
	 * Number of bytes stored to in. exchange() sets this to (size_t)-1 to
	 * signal that a read would block.
	 */
	size_t in_done;

	/**
	 * Received data that did not fit into in; returned by later read_() calls.
	 */
	chunk_t cache;

	/**
	 * Bytes of cache already returned to read_() callers.
	 */
	size_t cache_done;

	/**
	 * Set by destroy(). Makes process() and build() return SUCCESS without
	 * handling application data, so the TLS connection can be closed.
	 */
	bool close;
};
79
/**
 * Private data of an tls_socket_t object.
 *
 * NOTE(review): "public" must remain the first member.
 * tls_socket_create() returns &this->public, and the METHOD() implementations
 * use it as a private_tls_socket_t.
 */
struct private_tls_socket_t {

	/**
	 * Public tls_socket_t interface.
	 */
	tls_socket_t public;

	/**
	 * TLS application implementation
	 */
	private_tls_application_t app;

	/**
	 * TLS stack
	 */
	tls_t *tls;

	/**
	 * Underlying OS socket (not closed by destroy())
	 */
	int fd;

	/**
	 * Whether the socket returned EOF. This is also set after a fatal close
	 * while data was being read; read_() returns 0 once it is set.
	 */
	bool eof;
};
110
111 METHOD(tls_application_t, process, status_t,
112 private_tls_application_t *this, bio_reader_t *reader)
113 {
114 chunk_t data;
115 size_t len;
116
117 if (this->close)
118 {
119 return SUCCESS;
120 }
121 len = min(reader->remaining(reader), this->in.len - this->in_done);
122 if (len)
123 { /* copy to read buffer as much as fits in */
124 if (!reader->read_data(reader, len, &data))
125 {
126 return FAILED;
127 }
128 memcpy(this->in.ptr + this->in_done, data.ptr, data.len);
129 this->in_done += data.len;
130 }
131 else
132 { /* read buffer is full, cache for next read */
133 if (!reader->read_data(reader, reader->remaining(reader), &data))
134 {
135 return FAILED;
136 }
137 this->cache = chunk_cat("mc", this->cache, data);
138 }
139 return NEED_MORE;
140 }
141
142 METHOD(tls_application_t, build, status_t,
143 private_tls_application_t *this, bio_writer_t *writer)
144 {
145 if (this->close)
146 {
147 return SUCCESS;
148 }
149 if (this->out.len > this->out_done)
150 {
151 writer->write_data(writer, this->out);
152 this->out_done = this->out.len;
153 return NEED_MORE;
154 }
155 return INVALID_STATE;
156 }
157
158 /**
159 * TLS data exchange loop
160 */
161 static bool exchange(private_tls_socket_t *this, bool wr, bool block)
162 {
163 char buf[CRYPTO_BUF_SIZE], *pos;
164 ssize_t in, out;
165 size_t len;
166 int flags;
167
168 while (TRUE)
169 {
170 while (TRUE)
171 {
172 len = sizeof(buf);
173 switch (this->tls->build(this->tls, buf, &len, NULL))
174 {
175 case NEED_MORE:
176 case ALREADY_DONE:
177 pos = buf;
178 while (len)
179 {
180 out = write(this->fd, pos, len);
181 if (out == -1)
182 {
183 DBG1(DBG_TLS, "TLS crypto write error: %s",
184 strerror(errno));
185 return FALSE;
186 }
187 len -= out;
188 pos += out;
189 }
190 continue;
191 case INVALID_STATE:
192 break;
193 case SUCCESS:
194 return TRUE;
195 default:
196 if (!wr && this->app.in_done > 0)
197 { /* return data after proper termination via fatal close
198 * notify to which we responded with one */
199 this->eof = TRUE;
200 return TRUE;
201 }
202 return FALSE;
203 }
204 break;
205 }
206 if (wr)
207 {
208 if (this->app.out_done == this->app.out.len)
209 { /* all data written */
210 return TRUE;
211 }
212 }
213 else
214 {
215 if (this->app.in_done == this->app.in.len)
216 { /* buffer fully received */
217 return TRUE;
218 }
219 }
220
221 flags = 0;
222 if (this->app.out_done == this->app.out.len)
223 {
224 if (!block || this->app.in_done)
225 {
226 flags |= MSG_DONTWAIT;
227 }
228 }
229 in = recv(this->fd, buf, sizeof(buf), flags);
230 if (in < 0)
231 {
232 if (errno == EAGAIN || errno == EWOULDBLOCK)
233 {
234 if (this->app.in_done == 0)
235 {
236 /* reading, nothing got yet, and call would block */
237 errno = EWOULDBLOCK;
238 this->app.in_done = -1;
239 }
240 return TRUE;
241 }
242 return FALSE;
243 }
244 if (in == 0)
245 { /* EOF */
246 this->eof = TRUE;
247 return TRUE;
248 }
249 switch (this->tls->process(this->tls, buf, in))
250 {
251 case NEED_MORE:
252 break;
253 case SUCCESS:
254 return TRUE;
255 default:
256 return FALSE;
257 }
258 }
259 }
260
261 METHOD(tls_socket_t, read_, ssize_t,
262 private_tls_socket_t *this, void *buf, size_t len, bool block)
263 {
264 if (this->app.cache.len)
265 {
266 size_t cache;
267
268 cache = min(len, this->app.cache.len - this->app.cache_done);
269 memcpy(buf, this->app.cache.ptr + this->app.cache_done, cache);
270
271 this->app.cache_done += cache;
272 if (this->app.cache_done == this->app.cache.len)
273 {
274 chunk_free(&this->app.cache);
275 this->app.cache_done = 0;
276 }
277 return cache;
278 }
279 if (this->eof)
280 {
281 return 0;
282 }
283 this->app.in.ptr = buf;
284 this->app.in.len = len;
285 this->app.in_done = 0;
286 if (exchange(this, FALSE, block))
287 {
288 if (!this->app.in_done && !this->eof)
289 {
290 errno = EWOULDBLOCK;
291 return -1;
292 }
293 return this->app.in_done;
294 }
295 return -1;
296 }
297
298 METHOD(tls_socket_t, write_, ssize_t,
299 private_tls_socket_t *this, void *buf, size_t len)
300 {
301 this->app.out.ptr = buf;
302 this->app.out.len = len;
303 this->app.out_done = 0;
304 if (exchange(this, TRUE, FALSE))
305 {
306 return this->app.out_done;
307 }
308 return -1;
309 }
310
311 METHOD(tls_socket_t, splice, bool,
312 private_tls_socket_t *this, int rfd, int wfd)
313 {
314 char buf[PLAIN_BUF_SIZE], *pos;
315 ssize_t in, out;
316 bool old, crypto_eof = FALSE;
317 struct pollfd pfd[] = {
318 { .fd = this->fd, .events = POLLIN, },
319 { .fd = rfd, .events = POLLIN, },
320 };
321
322 while (!this->eof && !crypto_eof)
323 {
324 old = thread_cancelability(TRUE);
325 in = poll(pfd, countof(pfd), -1);
326 thread_cancelability(old);
327 if (in == -1)
328 {
329 DBG1(DBG_TLS, "TLS select error: %s", strerror(errno));
330 return FALSE;
331 }
332 while (!this->eof && pfd[0].revents & (POLLIN | POLLHUP | POLLNVAL))
333 {
334 in = read_(this, buf, sizeof(buf), FALSE);
335 switch (in)
336 {
337 case -1:
338 if (errno != EWOULDBLOCK)
339 {
340 DBG1(DBG_TLS, "TLS read error: %s", strerror(errno));
341 return FALSE;
342 }
343 break;
344 default:
345 pos = buf;
346 while (in)
347 {
348 out = write(wfd, pos, in);
349 if (out == -1)
350 {
351 DBG1(DBG_TLS, "TLS plain write error: %s",
352 strerror(errno));
353 return FALSE;
354 }
355 in -= out;
356 pos += out;
357 }
358 continue;
359 }
360 break;
361 }
362 if (!crypto_eof && pfd[1].revents & (POLLIN | POLLHUP | POLLNVAL))
363 {
364 in = read(rfd, buf, sizeof(buf));
365 switch (in)
366 {
367 case 0:
368 crypto_eof = TRUE;
369 break;
370 case -1:
371 DBG1(DBG_TLS, "TLS plain read error: %s", strerror(errno));
372 return FALSE;
373 default:
374 pos = buf;
375 while (in)
376 {
377 out = write_(this, pos, in);
378 if (out == -1)
379 {
380 DBG1(DBG_TLS, "TLS write error");
381 return FALSE;
382 }
383 in -= out;
384 pos += out;
385 }
386 break;
387 }
388 }
389 }
390 return TRUE;
391 }
392
393 METHOD(tls_socket_t, get_fd, int,
394 private_tls_socket_t *this)
395 {
396 return this->fd;
397 }
398
399 METHOD(tls_socket_t, get_server_id, identification_t*,
400 private_tls_socket_t *this)
401 {
402 return this->tls->get_server_id(this->tls);
403 }
404
405 METHOD(tls_socket_t, get_peer_id, identification_t*,
406 private_tls_socket_t *this)
407 {
408 return this->tls->get_peer_id(this->tls);
409 }
410
411 METHOD(tls_socket_t, destroy, void,
412 private_tls_socket_t *this)
413 {
414 /* send a TLS close notify if not done yet */
415 this->app.close = TRUE;
416 write_(this, NULL, 0);
417 free(this->app.cache.ptr);
418 this->tls->destroy(this->tls);
419 free(this);
420 }
421
422 /**
423 * See header
424 */
425 tls_socket_t *tls_socket_create(bool is_server, identification_t *server,
426 identification_t *peer, int fd,
427 tls_cache_t *cache, tls_version_t min_version,
428 tls_version_t max_version, tls_flag_t flags)
429 {
430 private_tls_socket_t *this;
431
432 INIT(this,
433 .public = {
434 .read = _read_,
435 .write = _write_,
436 .splice = _splice,
437 .get_fd = _get_fd,
438 .get_server_id = _get_server_id,
439 .get_peer_id = _get_peer_id,
440 .destroy = _destroy,
441 },
442 .app = {
443 .application = {
444 .build = _build,
445 .process = _process,
446 .destroy = (void*)nop,
447 },
448 },
449 .fd = fd,
450 );
451
452 this->tls = tls_create(is_server, server, peer, TLS_PURPOSE_GENERIC,
453 &this->app.application, cache, flags);
454 if (!this->tls ||
455 !this->tls->set_version(this->tls, min_version, max_version))
456 {
457 free(this);
458 return NULL;
459 }
460 return &this->public;
461 }