]> git.ipfire.org Git - thirdparty/strongswan.git/commitdiff
Use a more POSIXy tls_socket interface with more flexibility.
authorMartin Willi <martin@revosec.ch>
Thu, 10 Jan 2013 15:20:06 +0000 (16:20 +0100)
committerMartin Willi <martin@revosec.ch>
Tue, 15 Jan 2013 16:43:05 +0000 (17:43 +0100)
If an unsufficient read buffer is provided, application data gets cached
for subsequent read() calls.

src/libtls/tls_socket.c
src/libtls/tls_socket.h

index 75b714e308469f3b9d3032b0a01a2428761bbaee..131bada9641d7335a78111f7850d183087f346b3 100644 (file)
@@ -42,14 +42,34 @@ struct private_tls_application_t {
        tls_application_t application;
 
        /**
-        * Chunk of data to send
+        * Output buffer to write to
         */
        chunk_t out;
 
        /**
-        * Chunk of data received
+        * Number of bytes written to out
+        */
+       size_t out_done;
+
+       /**
+        * Input buffer to read to
         */
        chunk_t in;
+
+       /**
+        * Number of bytes read to in
+        */
+       size_t in_done;
+
+       /**
+        * Cached input data
+        */
+       chunk_t cache;
+
+       /**
+        * Bytes cosnumed in cache
+        */
+       size_t cache_done;
 };
 
 /**
@@ -82,22 +102,37 @@ METHOD(tls_application_t, process, status_t,
        private_tls_application_t *this, bio_reader_t *reader)
 {
        chunk_t data;
+       size_t len;
 
-       if (!reader->read_data(reader, reader->remaining(reader), &data))
-       {
-               return FAILED;
+       len = min(reader->remaining(reader), this->in.len - this->in_done);
+       if (len)
+       {       /* copy to read buffer as much as fits in */
+               if (!reader->read_data(reader, len, &data))
+               {
+                       return FAILED;
+               }
+
+               memcpy(this->in.ptr + this->in_done, data.ptr, data.len);
+               this->in_done += data.len;
+       }
+       else
+       {       /* read buffer is full, cache for next read */
+               if (!reader->read_data(reader, reader->remaining(reader), &data))
+               {
+                       return FAILED;
+               }
+               this->cache = chunk_cat("mc", this->cache, data);
        }
-       this->in = chunk_cat("mc", this->in, data);
        return NEED_MORE;
 }
 
 METHOD(tls_application_t, build, status_t,
        private_tls_application_t *this, bio_writer_t *writer)
 {
-       if (this->out.len)
+       if (this->out.len > this->out_done)
        {
                writer->write_data(writer, this->out);
-               this->out = chunk_empty;
+               this->out_done = this->out.len;
                return NEED_MORE;
        }
        return INVALID_STATE;
@@ -106,7 +141,7 @@ METHOD(tls_application_t, build, status_t,
 /**
  * TLS data exchange loop
  */
-static bool exchange(private_tls_socket_t *this, bool wr)
+static bool exchange(private_tls_socket_t *this, bool wr, bool block)
 {
        char buf[CRYPTO_BUF_SIZE], *pos;
        ssize_t len, out;
@@ -144,27 +179,38 @@ static bool exchange(private_tls_socket_t *this, bool wr)
                }
                if (wr)
                {
-                       if (this->app.out.len == 0)
+                       if (this->app.out_done == this->app.out.len)
                        {       /* all data written */
                                return TRUE;
                        }
                }
                else
                {
-                       if (this->app.in.len)
-                       {       /* some data received */
-                               return TRUE;
-                       }
-                       if (round > 0)
-                       {       /* did some handshaking, return empty chunk to not block */
+                       if (this->app.in_done == this->app.in.len)
+                       {       /* buffer fully received */
                                return TRUE;
                        }
                }
-               len = read(this->fd, buf, sizeof(buf));
-               if (len <= 0)
+               len = recv(this->fd, buf, sizeof(buf),
+                                  !block || this->app.in_done || round ? MSG_DONTWAIT : 0);
+               if (len < 0)
                {
+                       if (errno == EAGAIN || errno == EWOULDBLOCK)
+                       {
+                               if (this->app.in_done == 0)
+                               {
+                                       /* reading, nothing got yet, and call would block */
+                                       errno = EWOULDBLOCK;
+                                       this->app.in_done = -1;
+                               }
+                               return TRUE;
+                       }
                        return FALSE;
                }
+               if (len == 0)
+               {       /* EOF */
+                       return TRUE;
+               }
                if (this->tls->process(this->tls, buf, len) != NEED_MORE)
                {
                        return FALSE;
@@ -172,27 +218,45 @@ static bool exchange(private_tls_socket_t *this, bool wr)
        }
 }
 
-METHOD(tls_socket_t, read_, bool,
-       private_tls_socket_t *this, chunk_t *buf)
+METHOD(tls_socket_t, read_, ssize_t,
+       private_tls_socket_t *this, void *buf, size_t len, bool block)
 {
-       if (exchange(this, FALSE))
+       if (this->app.cache.len)
        {
-               *buf = this->app.in;
-               this->app.in = chunk_empty;
-               return TRUE;
+               size_t cache;
+
+               cache = min(len, this->app.cache.len - this->app.cache_done);
+               memcpy(buf, this->app.cache.ptr + this->app.cache_done, cache);
+
+               this->app.cache_done += cache;
+               if (this->app.cache_done == this->app.cache.len)
+               {
+                       chunk_free(&this->app.cache);
+                       this->app.cache_done = 0;
+               }
+               return cache;
        }
-       return FALSE;
+       this->app.in.ptr = buf;
+       this->app.in.len = len;
+       this->app.in_done = 0;
+       if (exchange(this, FALSE, block))
+       {
+               return this->app.in_done;
+       }
+       return -1;
 }
 
-METHOD(tls_socket_t, write_, bool,
-       private_tls_socket_t *this, chunk_t buf)
+METHOD(tls_socket_t, write_, ssize_t,
+       private_tls_socket_t *this, void *buf, size_t len)
 {
-       this->app.out = buf;
-       if (exchange(this, TRUE))
+       this->app.out.ptr = buf;
+       this->app.out.len = len;
+       this->app.out_done = 0;
+       if (exchange(this, TRUE, FALSE))
        {
-               return TRUE;
+               return this->app.out_done;
        }
-       return FALSE;
+       return -1;
 }
 
 METHOD(tls_socket_t, splice, bool,
@@ -200,68 +264,85 @@ METHOD(tls_socket_t, splice, bool,
 {
        char buf[PLAIN_BUF_SIZE], *pos;
        fd_set set;
-       chunk_t data;
-       ssize_t len;
-       bool old;
+       ssize_t in, out;
+       bool old, plain_eof = FALSE, crypto_eof = FALSE;
 
-       while (TRUE)
+       while (!plain_eof && !crypto_eof)
        {
                FD_ZERO(&set);
                FD_SET(rfd, &set);
                FD_SET(this->fd, &set);
 
                old = thread_cancelability(TRUE);
-               len = select(max(rfd, this->fd) + 1, &set, NULL, NULL, NULL);
+               in = select(max(rfd, this->fd) + 1, &set, NULL, NULL, NULL);
                thread_cancelability(old);
-               if (len == -1)
+               if (in == -1)
                {
                        DBG1(DBG_TLS, "TLS select error: %s", strerror(errno));
                        return FALSE;
                }
-               if (FD_ISSET(this->fd, &set))
+               while (!plain_eof && FD_ISSET(this->fd, &set))
                {
-                       if (!read_(this, &data))
+                       in = read_(this, buf, sizeof(buf), FALSE);
+                       switch (in)
                        {
-                               DBG2(DBG_TLS, "TLS read error/disconnect");
-                               return TRUE;
-                       }
-                       pos = data.ptr;
-                       while (data.len)
-                       {
-                               len = write(wfd, pos, data.len);
-                               if (len == -1)
-                               {
-                                       free(data.ptr);
-                                       DBG1(DBG_TLS, "TLS plain write error: %s", strerror(errno));
-                                       return FALSE;
-                               }
-                               data.len -= len;
-                               pos += len;
+                               case 0:
+                                       plain_eof = TRUE;
+                                       break;
+                               case -1:
+                                       if (errno != EWOULDBLOCK)
+                                       {
+                                               DBG1(DBG_TLS, "TLS read error: %s", strerror(errno));
+                                               return FALSE;
+                                       }
+                                       break;
+                               default:
+                                       pos = buf;
+                                       while (in)
+                                       {
+                                               out = write(wfd, pos, in);
+                                               if (out == -1)
+                                               {
+                                                       DBG1(DBG_TLS, "TLS plain write error: %s",
+                                                                strerror(errno));
+                                                       return FALSE;
+                                               }
+                                               in -= out;
+                                               pos += out;
+                                       }
+                                       continue;
                        }
-                       free(data.ptr);
+                       break;
                }
-               if (FD_ISSET(rfd, &set))
+               if (!crypto_eof && FD_ISSET(rfd, &set))
                {
-                       len = read(rfd, buf, sizeof(buf));
-                       if (len > 0)
+                       in = read(rfd, buf, sizeof(buf));
+                       switch (in)
                        {
-                               if (!write_(this, chunk_create(buf, len)))
-                               {
-                                       DBG1(DBG_TLS, "TLS write error");
-                                       return FALSE;
-                               }
-                       }
-                       else
-                       {
-                               if (len < 0)
-                               {
+                               case 0:
+                                       crypto_eof = TRUE;
+                                       break;
+                               case -1:
                                        DBG1(DBG_TLS, "TLS plain read error: %s", strerror(errno));
                                        return FALSE;
-                               }
-                               return TRUE;
+                               default:
+                                       pos = buf;
+                                       while (in)
+                                       {
+                                               out = write_(this, pos, in);
+                                               if (out == -1)
+                                               {
+                                                       DBG1(DBG_TLS, "TLS write error");
+                                                       return FALSE;
+                                               }
+                                               in -= out;
+                                               pos += out;
+                                       }
+                                       break;
                        }
                }
        }
+       return TRUE;
 }
 
 METHOD(tls_socket_t, get_fd, int,
@@ -273,8 +354,8 @@ METHOD(tls_socket_t, get_fd, int,
 METHOD(tls_socket_t, destroy, void,
        private_tls_socket_t *this)
 {
+       free(this->app.cache.ptr);
        this->tls->destroy(this->tls);
-       free(this->app.in.ptr);
        free(this);
 }
 
index edd05fd29c2015c6c19caf31bd8dcee44093d0a6..4ddddc19e5ec1fe8903f4150123bc48aba8ab870 100644 (file)
@@ -35,24 +35,27 @@ typedef struct tls_socket_t tls_socket_t;
 struct tls_socket_t {
 
        /**
-        * Read data from secured socket, return allocated chunk.
+        * Read data from secured socket.
         *
         * This call is blocking, you may use select() on the underlying socket to
-        * wait for data. If the there was non-application data available, the
-        * read function can return an empty chunk.
+        * wait for data. If "block" is FALSE and no application data is available,
+        * the function returns -1 and sets errno to EWOULDBLOCK.
         *
-        * @param data          pointer to allocate received data
-        * @return                      TRUE if data received successfully
+        * @param buf           buffer to write received data to
+        * @param len           size of buffer
+        * @param block         TRUE to block this call, FALSE to fail if it would block
+        * @return                      number of bytes read, 0 on EOF, -1 on error
         */
-       bool (*read)(tls_socket_t *this, chunk_t *data);
+       ssize_t (*read)(tls_socket_t *this, void *buf, size_t len, bool block);
 
        /**
-        * Write a chunk of data over the secured socket.
+        * Write data over the secured socket.
         *
-        * @param data          data to send
-        * @return                      TRUE if data sent successfully
+        * @param buf           data to send
+        * @param len           number of bytes to write from buf
+        * @return                      number of bytes written, -1 on error
         */
-       bool (*write)(tls_socket_t *this, chunk_t data);
+       ssize_t (*write)(tls_socket_t *this, void *buf, size_t len);
 
        /**
         * Read/write plain data from file descriptor.