]> git.ipfire.org Git - thirdparty/knot-resolver.git/commitdiff
pytests: add rehandshake test
authorTomas Krizek <tomas.krizek@nic.cz>
Tue, 27 Nov 2018 14:54:12 +0000 (15:54 +0100)
committerTomas Krizek <tomas.krizek@nic.cz>
Tue, 4 Dec 2018 16:13:42 +0000 (17:13 +0100)
12 files changed:
.gitlab-ci.yml
tests/pytests/kresd.py
tests/pytests/rehandshake/Makefile [new file with mode: 0644]
tests/pytests/rehandshake/array.h [new file with mode: 0644]
tests/pytests/rehandshake/tcp-proxy.c [new file with mode: 0644]
tests/pytests/rehandshake/tcp-proxy.h [new file with mode: 0644]
tests/pytests/rehandshake/tcproxy.c [new file with mode: 0644]
tests/pytests/rehandshake/tls-proxy.c [new file with mode: 0644]
tests/pytests/rehandshake/tls-proxy.h [new file with mode: 0644]
tests/pytests/rehandshake/tlsproxy.c [new file with mode: 0644]
tests/pytests/templates/kresd.conf.j2
tests/pytests/test_rehandshake.py [new file with mode: 0644]

index 39005be18cf7e58b5c9cdedab5bb31a850112a54..58e972931ee5839e5187ce5b6d5a9eb0190a0fd3 100644 (file)
@@ -287,6 +287,9 @@ pytests:run:
   except:
     - master
   script:
+    - pushd tests/pytests/rehandshake
+    - make all
+    - popd
     - PATH="$PREFIX/sbin:$PATH" ./ci/pytests/run.sh &> pytests.log.txt
   after_script:
     - tail -1 pytests.log.txt
index c1414e86e390120f066c7c9c6b30d5acae757099..717383cecd228879c3b85723bc9cfc20103928d1 100644 (file)
@@ -30,7 +30,7 @@ def create_file_from_template(template_path, dest, data):
         fh.write(rendered_template)
 
 
-Forward = namedtuple('Forward', ['proto', 'ip', 'port'])
+Forward = namedtuple('Forward', ['proto', 'ip', 'port', 'hostname', 'ca_file'])
 
 
 class Kresd(ContextDecorator):
@@ -223,9 +223,11 @@ KRESD_LOG_IO_CLOSE = re.compile(r'^\[io\].*closed by peer.*')
 
 
 @contextmanager
-def make_kresd(workdir, certname=None, ip='127.0.0.1', ip6='::1', forward=None, hints=None):
-    port = make_port(ip, ip6)
-    tls_port = make_port(ip, ip6)
+def make_kresd(
+        workdir, certname=None, ip='127.0.0.1', ip6='::1', forward=None, hints=None,
+        port=None, tls_port=None):
+    port = make_port(ip, ip6) if port is None else port
+    tls_port = make_port(ip, ip6) if tls_port is None else tls_port
     with Kresd(workdir, port, tls_port, ip, ip6, certname, forward=forward, hints=hints) as kresd:
         yield kresd
         with open(kresd.logfile_path) as log:  # display partial log for debugging
diff --git a/tests/pytests/rehandshake/Makefile b/tests/pytests/rehandshake/Makefile
new file mode 100644 (file)
index 0000000..170b89e
--- /dev/null
@@ -0,0 +1,28 @@
+CC=gcc
+CFLAGS_TLS=-DDEBUG -ggdb3 -O0 -lgnutls -luv
+CFLAGS_TCP=-DDEBUG -ggdb3 -O0 -luv
+
+all: tcproxy tlsproxy
+
+tlsproxy: tls-proxy.o tlsproxy.o
+       $(CC) tls-proxy.o tlsproxy.o -o tlsproxy $(CFLAGS_TLS)
+
+tls-proxy.o: tls-proxy.c tls-proxy.h array.h
+       $(CC) -c -o $@ $< $(CFLAGS_TLS)
+
+tlsproxy.o: tlsproxy.c tls-proxy.h
+       $(CC) -c -o $@ $< $(CFLAGS_TLS)
+
+tcproxy: tcp-proxy.o tcproxy.o
+       $(CC) tcp-proxy.o tcproxy.o -o tcproxy $(CFLAGS_TCP)
+
+tcp-proxy.o: tcp-proxy.c tcp-proxy.h array.h
+       $(CC) -c -o $@ $< $(CFLAGS_TCP)
+
+tcproxy.o: tcproxy.c tcp-proxy.h
+       $(CC) -c -o $@ $< $(CFLAGS_TCP)
+
+clean:
+       rm -f tcp-proxy.o tcproxy.o tcproxy tls-proxy.o tlsproxy.o tlsproxy
+
+.PHONY: all clean
diff --git a/tests/pytests/rehandshake/array.h b/tests/pytests/rehandshake/array.h
new file mode 100644 (file)
index 0000000..ece4dd1
--- /dev/null
@@ -0,0 +1,166 @@
+/*  Copyright (C) 2015-2017 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
+
+    This program is free software: you can redistribute it and/or modify
+    it under the terms of the GNU General Public License as published by
+    the Free Software Foundation, either version 3 of the License, or
+    (at your option) any later version.
+
+    This program is distributed in the hope that it will be useful,
+    but WITHOUT ANY WARRANTY; without even the implied warranty of
+    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+    GNU General Public License for more details.
+
+    You should have received a copy of the GNU General Public License
+    along with this program.  If not, see <https://www.gnu.org/licenses/>.
+ */
+
+/**
+ *
+ * @file array.h
+ * @brief A set of simple macros to make working with dynamic arrays easier.
+ *
+ * @note The C has no generics, so it is implemented mostly using macros.
+ * Be aware of that, as direct usage of the macros in the evaluating macros
+ * may lead to different expectations:
+ *
+ * @code{.c}
+ *     MIN(array_push(arr, val), other)
+ * @endcode
+ *
+ * May evaluate the code twice, leading to unexpected behaviour.
+ * This is a price to pay for the absence of proper generics.
+ *
+ * # Example usage:
+ *
+ * @code{.c}
+ *      array_t(const char*) arr;
+ *      array_init(arr);
+ *
+ *      // Reserve memory in advance
+ *      if (array_reserve(arr, 2) < 0) {
+ *          return ENOMEM;
+ *      }
+ *
+ *      // Already reserved, cannot fail
+ *      array_push(arr, "princess");
+ *      array_push(arr, "leia");
+ *
+ *      // Not reserved, may fail
+ *      if (array_push(arr, "han") < 0) {
+ *          return ENOMEM;
+ *      }
+ *
+ *      // It does not hide what it really is
+ *      for (size_t i = 0; i < arr.len; ++i) {
+ *          printf("%s\n", arr.at[i]);
+ *      }
+ *
+ *      // Random delete
+ *      array_del(arr, 0);
+ * @endcode
+ * \addtogroup generics
+ * @{
+ */
+
+#pragma once
+#include <stdlib.h>
+
+/** Simplified Qt containers growth strategy. */
+static inline size_t array_next_count(size_t want)
+{
+       if (want < 2048) {
+               return (want < 20) ? want + 4 : want * 2;
+       } else {
+               return want + 2048;
+       }
+}
+
+/** @internal Incremental memory reservation */
+static inline int array_std_reserve(void *baton, char **mem, size_t elm_size, size_t want, size_t *have)
+{
+       if (*have >= want) {
+               return 0;
+       }
+       /* Simplified Qt containers growth strategy */
+       size_t next_size = array_next_count(want);
+       void *mem_new = realloc(*mem, next_size * elm_size);
+       if (mem_new != NULL) {
+               *mem = mem_new;
+               *have = next_size;
+               return 0;
+       }
+       return -1;
+}
+
+/** @internal Wrapper for stdlib free. */
+static inline void array_std_free(void *baton, void *p)
+{
+       free(p);
+}
+
+/** Declare an array structure. */
+#define array_t(type) struct {type * at; size_t len; size_t cap; }
+
+/** Zero-initialize the array. */
+#define array_init(array) ((array).at = NULL, (array).len = (array).cap = 0)
+
+/** Free and zero-initialize the array (plain malloc/free). */
+#define array_clear(array) \
+       array_clear_mm(array, array_std_free, NULL)
+
+/** Make the array empty and free pointed-to memory.
+ * Mempool usage: pass mm_free and a knot_mm_t* . */
+#define array_clear_mm(array, free, baton) \
+       (free)((baton), (array).at), array_init(array)
+
+/** Reserve capacity for at least n elements.
+ * @return 0 if success, <0 on failure */
+#define array_reserve(array, n) \
+       array_reserve_mm(array, n, array_std_reserve, NULL)
+
+/** Reserve capacity for at least n elements.
+ * Mempool usage: pass kr_memreserve and a knot_mm_t* .
+ * @return 0 if success, <0 on failure */
+#define array_reserve_mm(array, n, reserve, baton) \
+       (reserve)((baton), (char **) &(array).at, sizeof((array).at[0]), (n), &(array).cap)
+
+/**
+ * Push value at the end of the array, resize it if necessary.
+ * Mempool usage: pass kr_memreserve and a knot_mm_t* .
+ * @note May fail if the capacity is not reserved.
+ * @return element index on success, <0 on failure
+ */
+#define array_push_mm(array, val, reserve, baton) \
+       (int)((array).len < (array).cap ? ((array).at[(array).len] = val, (array).len++) \
+               : (array_reserve_mm(array, ((array).cap + 1), reserve, baton) < 0 ? -1 \
+                       : ((array).at[(array).len] = val, (array).len++)))
+
+/**
+ * Push value at the end of the array, resize it if necessary (plain malloc/free).
+ * @note May fail if the capacity is not reserved.
+ * @return element index on success, <0 on failure
+ */
+#define array_push(array, val) \
+       array_push_mm(array, val, array_std_reserve, NULL)
+
+/**
+ * Pop value from the end of the array.
+ */
+#define array_pop(array) \
+       (array).len -= 1
+
+/**
+ * Remove value at given index.
+ * @return 0 on success, <0 on failure
+ */
+#define array_del(array, i) \
+       (int)((i) < (array).len ? ((array).len -= 1,(array).at[i] = (array).at[(array).len], 0) : -1)
+
+/**
+ * Return last element of the array.
+ * @warning Undefined if the array is empty.
+ */
+#define array_tail(array) \
+    (array).at[(array).len - 1]
+
+/** @} */
diff --git a/tests/pytests/rehandshake/tcp-proxy.c b/tests/pytests/rehandshake/tcp-proxy.c
new file mode 100644 (file)
index 0000000..ba7198b
--- /dev/null
@@ -0,0 +1,336 @@
+#include <assert.h>
+#include <stdio.h>
+#include <unistd.h>
+#include <string.h>
+#include <stdlib.h>
+#include <stdbool.h>
+#include <uv.h>
+#include "array.h"
+
+struct buf {
+       char buf[16 * 1024];
+       size_t size;
+};
+
+enum peer_state {
+       STATE_NOT_CONNECTED,
+       STATE_LISTENING,
+       STATE_CONNECTED,
+       STATE_CONNECT_IN_PROGRESS,
+       STATE_CLOSING_IN_PROGRESS
+};
+
+struct proxy_ctx {
+       uv_loop_t *loop;
+       uv_tcp_t server;
+       uv_tcp_t client;
+       uv_tcp_t upstream;
+       struct sockaddr_storage server_addr;
+       struct sockaddr_storage upstream_addr;
+       
+       int server_state;
+       int client_state;
+       int upstream_state;
+
+       array_t(struct buf *) buffer_pool;
+       array_t(struct buf *) upstream_pending;
+};
+
+static void read_from_upstream_cb(uv_stream_t *upstream, ssize_t nread, const uv_buf_t *buf);
+static void read_from_client_cb(uv_stream_t *client, ssize_t nread, const uv_buf_t *buf);
+
+static struct buf *borrow_io_buffer(struct proxy_ctx *proxy)
+{
+       struct buf *buf = NULL;
+       if (proxy->buffer_pool.len > 0) {
+               buf = array_tail(proxy->buffer_pool);
+               array_pop(proxy->buffer_pool);
+       } else {
+               buf = calloc(1, sizeof (struct buf));
+       }
+       return buf;
+}
+
+static void release_io_buffer(struct proxy_ctx *proxy, struct buf *buf)
+{
+       if (!buf) {
+               return;
+       }
+
+       if (proxy->buffer_pool.len < 1000) {
+               buf->size = 0;
+               array_push(proxy->buffer_pool, buf);
+       } else {
+               free(buf);
+       }
+}
+
+static void push_to_upstream_pending(struct proxy_ctx *proxy, const char *buf, size_t size)
+{
+       while (size > 0) {
+               struct buf *b = borrow_io_buffer(proxy);
+               b->size = size <= sizeof(b->buf) ? size : sizeof(b->buf);
+               memcpy(b->buf, buf, b->size);
+               array_push(proxy->upstream_pending, b);
+               size -= b->size;
+       }
+}
+
+static struct buf *get_first_upstream_pending(struct proxy_ctx *proxy)
+{
+       struct buf *buf = NULL;
+       if (proxy->upstream_pending.len > 0) {
+               buf = proxy->upstream_pending.at[0];
+       }
+       return buf;
+}
+
+static void remove_first_upstream_pending(struct proxy_ctx *proxy)
+{
+       for (int i = 1; i < proxy->upstream_pending.len; ++i) {
+               proxy->upstream_pending.at[i - 1] = proxy->upstream_pending.at[i];
+       }
+       if (proxy->upstream_pending.len > 0) {
+               proxy->upstream_pending.len -= 1;
+       }
+}
+
+static void clear_upstream_pending(struct proxy_ctx *proxy)
+{
+       for (int i = 1; i < proxy->upstream_pending.len; ++i) {
+               struct buf *b = proxy->upstream_pending.at[i];
+               release_io_buffer(proxy, b);
+       }
+       proxy->upstream_pending.len = 0;
+}
+
+static void clear_buffer_pool(struct proxy_ctx *proxy)
+{
+       for (int i = 1; i < proxy->buffer_pool.len; ++i) {
+               struct buf *b = proxy->buffer_pool.at[i];
+               free(b);
+       }
+       proxy->buffer_pool.len = 0;
+}
+
+static void alloc_uv_buffer(uv_handle_t *handle, size_t suggested_size, uv_buf_t *buf)
+{
+       buf->base = (char*)malloc(suggested_size);
+       buf->len = suggested_size;
+}
+
+static void on_client_close(uv_handle_t *handle)
+{
+       struct proxy_ctx *proxy = (struct proxy_ctx *)handle->loop->data;
+       proxy->client_state = STATE_NOT_CONNECTED;
+}
+
+static void on_upstream_close(uv_handle_t *handle)
+{
+       struct proxy_ctx *proxy = (struct proxy_ctx *)handle->loop->data;
+       proxy->upstream_state = STATE_NOT_CONNECTED;
+}
+
+static void write_to_client_cb(uv_write_t *req, int status)
+{
+       struct proxy_ctx *proxy = (struct proxy_ctx *)req->handle->loop->data;
+       free(req);
+       if (status) {
+               fprintf(stderr, "error writing to client: %s\n", uv_strerror(status));
+               clear_upstream_pending(proxy);
+               proxy->client_state = STATE_CLOSING_IN_PROGRESS;
+               uv_close((uv_handle_t*)&proxy->client, on_client_close);
+       }
+}
+
+static void write_to_upstream_cb(uv_write_t *req, int status)
+{
+       struct proxy_ctx *proxy = (struct proxy_ctx *)req->handle->loop->data;
+       free(req);
+       if (status) {
+               fprintf(stderr, "error writing to upstream: %s\n", uv_strerror(status));
+               clear_upstream_pending(proxy);
+               proxy->upstream_state = STATE_CLOSING_IN_PROGRESS;
+               uv_close((uv_handle_t*)&proxy->upstream, on_upstream_close);
+               return;
+       }
+       if (proxy->upstream_pending.len > 0) {
+               struct buf *buf = get_first_upstream_pending(proxy);
+               remove_first_upstream_pending(proxy);
+               release_io_buffer(proxy, buf);
+               if (proxy->upstream_state == STATE_CONNECTED &&
+                   proxy->upstream_pending.len > 0) {
+                       buf = get_first_upstream_pending(proxy);
+                       /* TODO avoid allocation */
+                       uv_write_t *req = (uv_write_t *) malloc(sizeof(uv_write_t));
+                       uv_buf_t wrbuf = uv_buf_init(buf->buf, buf->size);
+                       uv_write(req, (uv_stream_t *)&proxy->upstream, &wrbuf, 1, write_to_upstream_cb);
+               }
+       }
+}
+
+static void on_client_connection(uv_stream_t *server, int status)
+{
+       if (status < 0) {
+               fprintf(stderr, "incoming connection error: %s\n", uv_strerror(status));
+               return;
+       }
+
+       fprintf(stdout, "incoming connection\n");
+
+       struct proxy_ctx *proxy = (struct proxy_ctx *)server->loop->data;
+       if (proxy->client_state != STATE_NOT_CONNECTED) {
+               fprintf(stderr, "client already connected, ignoring\n");
+               return;
+       }
+
+       uv_tcp_init(proxy->loop, &proxy->client);
+       proxy->client_state = STATE_CONNECTED;
+       if (uv_accept(server, (uv_stream_t*)&proxy->client) == 0) {
+               uv_read_start((uv_stream_t*)&proxy->client, alloc_uv_buffer, read_from_client_cb);
+       } else {
+               proxy->client_state = STATE_CLOSING_IN_PROGRESS;
+               uv_close((uv_handle_t*)&proxy->client, on_client_close);
+       }
+}
+
+static void on_connect_to_upstream(uv_connect_t *req, int status)
+{
+       struct proxy_ctx *proxy = (struct proxy_ctx *)req->handle->loop->data;
+       free(req);
+       if (status < 0) {
+               fprintf(stderr, "error connecting to upstream: %s\n", uv_strerror(status));
+               clear_upstream_pending(proxy);
+               proxy->upstream_state = STATE_CLOSING_IN_PROGRESS;
+               uv_close((uv_handle_t*)&proxy->upstream, on_upstream_close);
+               return;
+       }
+
+       proxy->upstream_state = STATE_CONNECTED;
+       uv_read_start((uv_stream_t*)&proxy->upstream, alloc_uv_buffer, read_from_upstream_cb);
+       if (proxy->upstream_pending.len > 0) {
+               struct buf *buf = get_first_upstream_pending(proxy);
+               /* TODO avoid allocation */
+               uv_write_t *wreq = (uv_write_t *) malloc(sizeof(uv_write_t));
+               uv_buf_t wrbuf = uv_buf_init(buf->buf, buf->size);
+               uv_write(wreq, (uv_stream_t *)&proxy->upstream, &wrbuf, 1, write_to_upstream_cb);
+       }
+}
+
+static void read_from_client_cb(uv_stream_t *client, ssize_t nread, const uv_buf_t *buf)
+{
+       if (nread == 0) {
+               return;
+       }
+       struct proxy_ctx *proxy = (struct proxy_ctx *)client->loop->data;
+       if (nread < 0) {
+               if (nread != UV_EOF) {
+                       fprintf(stderr, "error reading from client: %s\n", uv_err_name(nread));
+               }
+               if (proxy->client_state == STATE_CONNECTED) {
+                       proxy->client_state = STATE_CLOSING_IN_PROGRESS;
+                       uv_close((uv_handle_t*) client, on_client_close);
+               }
+               return;
+       }
+       if (proxy->upstream_state == STATE_CONNECTED) {
+               if (proxy->upstream_pending.len > 0) {
+                       push_to_upstream_pending(proxy, buf->base, nread);
+               } else {
+                       /* TODO avoid allocation */
+                       uv_write_t *req = (uv_write_t *) malloc(sizeof(uv_write_t));
+                       uv_buf_t wrbuf = uv_buf_init(buf->base, nread);
+                       uv_write(req, (uv_stream_t *)&proxy->upstream, &wrbuf, 1, write_to_upstream_cb);
+               }
+       } else if (proxy->upstream_state == STATE_NOT_CONNECTED) {
+               /* TODO avoid allocation */
+               uv_tcp_init(proxy->loop, &proxy->upstream);     
+               uv_connect_t *conn = (uv_connect_t *) malloc(sizeof(uv_connect_t));
+               proxy->upstream_state = STATE_CONNECT_IN_PROGRESS;
+               uv_tcp_connect(conn, &proxy->upstream, (struct sockaddr *)&proxy->upstream_addr,
+                              on_connect_to_upstream);
+               push_to_upstream_pending(proxy, buf->base, nread);
+       } else if (proxy->upstream_state == STATE_CONNECT_IN_PROGRESS) {
+               push_to_upstream_pending(proxy, buf->base, nread);
+       }
+}
+
+static void read_from_upstream_cb(uv_stream_t *upstream, ssize_t nread, const uv_buf_t *buf)
+{
+       if (nread == 0) {
+               return;
+       }
+       struct proxy_ctx *proxy = (struct proxy_ctx *)upstream->loop->data;
+       if (nread < 0) {
+               if (nread != UV_EOF) {
+                       fprintf(stderr, "error reading from upstream: %s\n", uv_err_name(nread));
+               }
+               clear_upstream_pending(proxy);
+               if (proxy->upstream_state == STATE_CONNECTED) {
+                       proxy->upstream_state = STATE_CLOSING_IN_PROGRESS;
+                       uv_close((uv_handle_t*)&proxy->upstream, on_upstream_close);
+               }
+               return;
+       }
+       if (proxy->client_state == STATE_CONNECTED) {
+               /* TODO Avoid allocation */
+               uv_write_t *req = (uv_write_t *) malloc(sizeof(uv_write_t));
+               uv_buf_t wrbuf = uv_buf_init(buf->base, nread);
+               uv_write(req, (uv_stream_t *)&proxy->client, &wrbuf, 1, write_to_client_cb);
+       }
+}
+
+struct proxy_ctx *proxy_allocate()
+{
+       return malloc(sizeof(struct proxy_ctx));
+}
+
+int proxy_init(struct proxy_ctx *proxy,
+              const char *server_addr, int server_port,
+              const char *upstream_addr, int upstream_port)
+{
+       proxy->loop = uv_default_loop();
+       uv_tcp_init(proxy->loop, &proxy->server);
+       int res = uv_ip4_addr(server_addr, server_port, (struct sockaddr_in *)&proxy->server_addr);
+       if (res != 0) {
+               return res;
+       }
+       res = uv_ip4_addr(upstream_addr, upstream_port, (struct sockaddr_in *)&proxy->upstream_addr);
+       if (res != 0) {
+               return res;
+       }
+       array_init(proxy->buffer_pool);
+       array_init(proxy->upstream_pending);
+       proxy->server_state = STATE_NOT_CONNECTED;
+       proxy->client_state = STATE_NOT_CONNECTED;
+       proxy->upstream_state = STATE_NOT_CONNECTED;
+
+       proxy->loop->data = proxy;
+       return 0;
+}
+
+void proxy_free(struct proxy_ctx *proxy)
+{
+       if (!proxy) {
+               return;
+       }
+       clear_upstream_pending(proxy);
+       clear_buffer_pool(proxy);
+       /* TODO correctly close all the uv_tcp_t */
+       free(proxy);
+}
+
+int proxy_start_listen(struct proxy_ctx *proxy)
+{      
+       uv_tcp_bind(&proxy->server, (const struct sockaddr*)&proxy->server_addr, 0);
+       int ret = uv_listen((uv_stream_t*)&proxy->server, 128, on_client_connection);
+       if (ret == 0) {
+               proxy->server_state = STATE_LISTENING;
+       }
+       return ret;
+}
+
+int proxy_run(struct proxy_ctx *proxy)
+{
+       return uv_run(proxy->loop, UV_RUN_DEFAULT);
+}
diff --git a/tests/pytests/rehandshake/tcp-proxy.h b/tests/pytests/rehandshake/tcp-proxy.h
new file mode 100644 (file)
index 0000000..668a65f
--- /dev/null
@@ -0,0 +1,12 @@
+#pragma once
+
+struct proxy_ctx;
+
+struct proxy_ctx *proxy_allocate();
+void proxy_free(struct proxy_ctx *proxy);
+int proxy_init(struct proxy_ctx *proxy,
+              const char *server_addr, int server_port,
+              const char *upstream_addr, int upstream_port);
+int proxy_start_listen(struct proxy_ctx *proxy);
+int proxy_run(struct proxy_ctx *proxy);
+
diff --git a/tests/pytests/rehandshake/tcproxy.c b/tests/pytests/rehandshake/tcproxy.c
new file mode 100644 (file)
index 0000000..87a6b4c
--- /dev/null
@@ -0,0 +1,25 @@
+#include <stdio.h>
+#include "tcp-proxy.h"
+
+int main()
+{
+       struct proxy_ctx *proxy = proxy_allocate();
+       if (!proxy) {
+               fprintf(stderr, "can't allocate proxy structure\n");
+               return 1;
+       }
+       int res = proxy_init(proxy, "127.0.0.1", 54000, "127.0.0.1", 53001);
+       if (res) {
+               fprintf(stderr, "can't initialize proxy by given addresses\n");
+               return res;
+       }
+       res = proxy_start_listen(proxy);
+       if (res) {
+               fprintf(stderr, "error starting listen, error code: %i\n", res);
+               return res;
+       }
+       res = proxy_run(proxy);
+       proxy_free(proxy);
+       return res;
+}
+
diff --git a/tests/pytests/rehandshake/tls-proxy.c b/tests/pytests/rehandshake/tls-proxy.c
new file mode 100644 (file)
index 0000000..5f731fd
--- /dev/null
@@ -0,0 +1,830 @@
+#include <assert.h>
+#include <stdio.h>
+#include <unistd.h>
+#include <string.h>
+#include <stdlib.h>
+#include <stdbool.h>
+#include <gnutls/gnutls.h>
+#include <uv.h>
+#include "array.h"
+
+#define TLS_MAX_SEND_RETRIES 100
+
+struct buf {
+       char buf[16 * 1024];
+       size_t size;
+};
+
+enum peer_state {
+       STATE_NOT_CONNECTED,
+       STATE_LISTENING,
+       STATE_CONNECTED,
+       STATE_CONNECT_IN_PROGRESS,
+       STATE_CLOSING_IN_PROGRESS
+};
+
+enum handshake_state {
+       TLS_HS_NOT_STARTED = 0,
+       TLS_HS_IN_PROGRESS,
+       TLS_HS_DONE,
+       TLS_HS_CLOSING,
+       TLS_HS_LAST
+};
+
+struct tls_ctx {
+       gnutls_session_t session;
+       int handshake_state;
+       gnutls_certificate_credentials_t credentials;
+        gnutls_priority_t priority_cache;
+       /* for reading from the network */
+       const uint8_t *buf;
+       ssize_t nread;
+       ssize_t consumed;
+       uint8_t recv_buf[4096];
+};
+
+struct tls_proxy_ctx {
+       uv_loop_t *loop;
+       uv_tcp_t server;
+       uv_tcp_t client;
+       uv_tcp_t upstream;
+       struct sockaddr_storage server_addr;
+       struct sockaddr_storage upstream_addr;
+       struct sockaddr_storage client_addr;
+       
+       int server_state;
+       int client_state;
+       int upstream_state;
+
+       array_t(struct buf *) buffer_pool;
+       array_t(struct buf *) upstream_pending;
+       array_t(struct buf *) client_pending;
+       
+       char io_buf[0xFFFF];
+       struct tls_ctx tls;
+};
+
+static void read_from_upstream_cb(uv_stream_t *upstream, ssize_t nread, const uv_buf_t *buf);
+static void read_from_client_cb(uv_stream_t *client, ssize_t nread, const uv_buf_t *buf);
+static ssize_t proxy_gnutls_pull(gnutls_transport_ptr_t h, void *buf, size_t len);
+static ssize_t proxy_gnutls_push(gnutls_transport_ptr_t h, const void *buf, size_t len);
+static int tls_process_from_upstream(struct tls_proxy_ctx *proxy, const uint8_t *buf, ssize_t nread);
+static int tls_process_from_client(struct tls_proxy_ctx *proxy, const uint8_t *buf, ssize_t nread);
+static int write_to_upstream_pending(struct tls_proxy_ctx *proxy);
+static int write_to_client_pending(struct tls_proxy_ctx *proxy);
+
+
+static int gnutls_references = 0;
+
+const void *ip_addr(const struct sockaddr *addr)
+{
+       if (!addr) {
+               return NULL;
+       }
+       switch (addr->sa_family) {
+       case AF_INET:  return (const void *)&(((const struct sockaddr_in *)addr)->sin_addr);
+       case AF_INET6: return (const void *)&(((const struct sockaddr_in6 *)addr)->sin6_addr);
+       default:       return NULL;
+       }
+}
+
+uint16_t ip_addr_port(const struct sockaddr *addr)
+{
+       if (!addr) {
+               return 0;
+       }
+       switch (addr->sa_family) {
+       case AF_INET:  return ntohs(((const struct sockaddr_in *)addr)->sin_port);
+       case AF_INET6: return ntohs(((const struct sockaddr_in6 *)addr)->sin6_port);
+       default:       return 0;
+       }
+}
+
+static int ip_addr_str(const struct sockaddr *addr, char *buf, size_t *buflen)
+{
+       int ret = 0;
+       if (!addr || !buf || !buflen) {
+               return EINVAL;
+       }
+
+       char str[INET6_ADDRSTRLEN + 6];
+       if (!inet_ntop(addr->sa_family, ip_addr(addr), str, sizeof(str))) {
+               return errno;
+       }
+       int len = strlen(str);
+       str[len] = '#';
+       snprintf(&str[len + 1], 6, "%uh", ip_addr_port(addr));
+       len += 6;
+       str[len] = 0;
+       if (len >= *buflen) {
+               ret = ENOSPC;
+       } else {
+               memcpy(buf, str, len + 1);
+       }
+       *buflen = len;
+       return ret;
+}
+
+static inline char *ip_straddr(const struct sockaddr *addr)
+{
+       assert(addr != NULL);
+       /* We are the sinle-threaded application */
+       static char str[INET6_ADDRSTRLEN + 6];
+       size_t len = sizeof(str);
+       int ret = ip_addr_str(addr, str, &len);
+       return ret != 0 || len == 0 ? NULL : str;
+}
+
+static struct buf *borrow_io_buffer(struct tls_proxy_ctx *proxy)
+{
+       struct buf *buf = NULL;
+       if (proxy->buffer_pool.len > 0) {
+               buf = array_tail(proxy->buffer_pool);
+               array_pop(proxy->buffer_pool);
+       } else {
+               buf = calloc(1, sizeof (struct buf));
+       }
+       return buf;
+}
+
+static void release_io_buffer(struct tls_proxy_ctx *proxy, struct buf *buf)
+{
+       if (!buf) {
+               return;
+       }
+
+       if (proxy->buffer_pool.len < 1000) {
+               buf->size = 0;
+               array_push(proxy->buffer_pool, buf);
+       } else {
+               free(buf);
+       }
+}
+
+static struct buf *get_first_upstream_pending(struct tls_proxy_ctx *proxy)
+{
+       struct buf *buf = NULL;
+       if (proxy->upstream_pending.len > 0) {
+               buf = proxy->upstream_pending.at[0];
+       }
+       return buf;
+}
+
+static struct buf *get_first_client_pending(struct tls_proxy_ctx *proxy)
+{
+       struct buf *buf = NULL;
+       if (proxy->client_pending.len > 0) {
+               buf = proxy->client_pending.at[0];
+       }
+       return buf;
+}
+
+static void remove_first_upstream_pending(struct tls_proxy_ctx *proxy)
+{
+       for (int i = 1; i < proxy->upstream_pending.len; ++i) {
+               proxy->upstream_pending.at[i - 1] = proxy->upstream_pending.at[i];
+       }
+       if (proxy->upstream_pending.len > 0) {
+               proxy->upstream_pending.len -= 1;
+       }
+}
+
+static void remove_first_client_pending(struct tls_proxy_ctx *proxy)
+{
+       for (int i = 1; i < proxy->client_pending.len; ++i) {
+               proxy->client_pending.at[i - 1] = proxy->client_pending.at[i];
+       }
+       if (proxy->client_pending.len > 0) {
+               proxy->client_pending.len -= 1;
+       }
+}
+
+static void clear_upstream_pending(struct tls_proxy_ctx *proxy)
+{
+       for (int i = 0; i < proxy->upstream_pending.len; ++i) {
+               struct buf *b = proxy->upstream_pending.at[i];
+               release_io_buffer(proxy, b);
+       }
+       proxy->upstream_pending.len = 0;
+}
+
+static void clear_client_pending(struct tls_proxy_ctx *proxy)
+{
+       for (int i = 0; i < proxy->client_pending.len; ++i) {
+               struct buf *b = proxy->client_pending.at[i];
+               release_io_buffer(proxy, b);
+       }
+       proxy->client_pending.len = 0;
+}
+
+static void clear_buffer_pool(struct tls_proxy_ctx *proxy)
+{
+       for (int i = 0; i < proxy->buffer_pool.len; ++i) {
+               struct buf *b = proxy->buffer_pool.at[i];
+               free(b);
+       }
+       proxy->buffer_pool.len = 0;
+}
+
+static void alloc_uv_buffer(uv_handle_t *handle, size_t suggested_size, uv_buf_t *buf)
+{
+       struct tls_proxy_ctx *proxy = (struct tls_proxy_ctx *)handle->loop->data;
+       buf->base = proxy->io_buf;
+       buf->len = sizeof(proxy->io_buf);
+}
+
+static void on_client_close(uv_handle_t *handle)
+{
+       struct tls_proxy_ctx *proxy = (struct tls_proxy_ctx *)handle->loop->data;
+       gnutls_deinit(proxy->tls.session);
+       proxy->tls.handshake_state = TLS_HS_NOT_STARTED;
+       proxy->client_state = STATE_NOT_CONNECTED;
+}
+
+static void on_dummmy_client_close(uv_handle_t *handle)
+{
+       free(handle);
+}
+
+static void on_upstream_close(uv_handle_t *handle)
+{
+       struct tls_proxy_ctx *proxy = (struct tls_proxy_ctx *)handle->loop->data;
+       proxy->upstream_state = STATE_NOT_CONNECTED;
+}
+
+static void write_to_client_cb(uv_write_t *req, int status)
+{
+       struct tls_proxy_ctx *proxy = (struct tls_proxy_ctx *)req->handle->loop->data;
+       free(req);
+       if (status) {
+               fprintf(stderr, "error writing to client: %s\n", uv_strerror(status));
+               clear_client_pending(proxy);
+               clear_upstream_pending(proxy);
+               if (proxy->client_state == STATE_CONNECTED) {
+                       proxy->client_state = STATE_CLOSING_IN_PROGRESS;
+                       uv_close((uv_handle_t*)&proxy->client, on_client_close);
+                       return;
+               }
+       }
+       if (proxy->client_pending.len > 0) {
+               struct buf *buf = get_first_client_pending(proxy);
+               fprintf(stdout, "successfully wrote %zi bytes to client, pending len is %zd\n",
+                       buf->size, proxy->client_pending.len);
+               remove_first_client_pending(proxy);
+               release_io_buffer(proxy, buf);
+               if (proxy->client_state == STATE_CONNECTED &&
+                   proxy->tls.handshake_state == TLS_HS_DONE) {
+                       write_to_client_pending(proxy);
+               }
+       }
+}
+
+static void write_to_upstream_cb(uv_write_t *req, int status)
+{
+       struct tls_proxy_ctx *proxy = (struct tls_proxy_ctx *)req->handle->loop->data;
+       free(req);
+       if (status) {
+               fprintf(stderr, "error writing to upstream: %s\n", uv_strerror(status));
+               clear_upstream_pending(proxy);
+               proxy->upstream_state = STATE_CLOSING_IN_PROGRESS;
+               uv_close((uv_handle_t*)&proxy->upstream, on_upstream_close);
+               return;
+       }
+       fprintf(stdout, "successfully wrote to upstream, pending len is %zd\n", proxy->upstream_pending.len);
+       if (proxy->upstream_pending.len > 0) {
+               struct buf *buf = get_first_upstream_pending(proxy);
+               remove_first_upstream_pending(proxy);
+               release_io_buffer(proxy, buf);
+               if (proxy->upstream_state == STATE_CONNECTED &&
+                   proxy->upstream_pending.len > 0) {
+                       write_to_upstream_pending(proxy);
+               }
+       }
+}
+
+static void on_client_connection(uv_stream_t *server, int status)
+{
+       if (status < 0) {
+               fprintf(stderr, "incoming connection error: %s\n", uv_strerror(status));
+               return;
+       }
+
+       int err = 0;
+       int ret = 0;
+       struct tls_proxy_ctx *proxy = (struct tls_proxy_ctx *)server->loop->data;
+       if (proxy->client_state != STATE_NOT_CONNECTED) {
+               fprintf(stderr, "incoming connection");
+               uv_tcp_t *dummy_client = malloc(sizeof(uv_tcp_t));
+               uv_tcp_init(proxy->loop, dummy_client);
+               err = uv_accept(server, (uv_stream_t*)dummy_client);
+               if (err == 0) {
+                       struct sockaddr dummy_addr;
+                       int dummy_addr_len = sizeof(dummy_addr);
+                       ret = uv_tcp_getpeername(dummy_client,
+                                                &dummy_addr,
+                                                &dummy_addr_len);
+                       if (ret == 0) {
+                               fprintf(stderr, " from %s", ip_straddr(&dummy_addr));
+                       }
+                       uv_close((uv_handle_t *)dummy_client, on_dummmy_client_close);
+               } else {
+                       on_dummmy_client_close((uv_handle_t *)dummy_client);
+               }
+               fprintf(stderr, " - client already connected, rejecting\n");
+               return;
+       }
+
+       uv_tcp_init(proxy->loop, &proxy->client);
+       uv_tcp_nodelay((uv_tcp_t *)&proxy->client, 1);
+       proxy->client_state = STATE_CONNECTED;
+       err = uv_accept(server, (uv_stream_t*)&proxy->client);
+       if (err != 0) {
+               fprintf(stderr, "incoming connection - uv_accept() failed: (%d) %s\n",
+                            err, uv_strerror(err));
+               return;
+       }
+
+       struct sockaddr *addr = (struct sockaddr *)&(proxy->client_addr);
+       int addr_len = sizeof(proxy->client_addr);
+       ret = uv_tcp_getpeername(&proxy->client, addr, &addr_len);
+       if (ret || addr->sa_family == AF_UNSPEC) {
+               proxy->client_state = STATE_CLOSING_IN_PROGRESS;
+               uv_close((uv_handle_t*)&proxy->client, on_client_close);
+               fprintf(stderr, "incoming connection - uv_tcp_getpeername() failed: (%d) %s\n",
+                            err, uv_strerror(err));
+               return;
+       }
+       
+       fprintf(stdout, "incoming connection from %s\n", ip_straddr(addr));
+       
+       uv_read_start((uv_stream_t*)&proxy->client, alloc_uv_buffer, read_from_client_cb);
+
+       const char *errpos = NULL;
+       struct tls_ctx *tls = &proxy->tls;
+       assert (tls->handshake_state == TLS_HS_NOT_STARTED);
+       err = gnutls_init(&tls->session, GNUTLS_SERVER | GNUTLS_NONBLOCK);
+       if (err != GNUTLS_E_SUCCESS) {
+               fprintf(stderr, "gnutls_init() failed: (%d) %s\n",
+                            err, gnutls_strerror_name(err));
+       }
+        err = gnutls_priority_set(tls->session, tls->priority_cache);
+       if (err != GNUTLS_E_SUCCESS) {
+               fprintf(stderr, "gnutls_priority_set() failed: (%d) %s\n",
+                            err, gnutls_strerror_name(err));
+       }
+       err = gnutls_credentials_set(tls->session, GNUTLS_CRD_CERTIFICATE, tls->credentials);
+       if (err != GNUTLS_E_SUCCESS) {
+               fprintf(stderr, "gnutls_credentials_set() failed: (%d) %s\n",
+                            err, gnutls_strerror_name(err));
+       }
+       gnutls_certificate_server_set_request(tls->session, GNUTLS_CERT_IGNORE);
+       gnutls_handshake_set_timeout(tls->session, GNUTLS_DEFAULT_HANDSHAKE_TIMEOUT);
+
+       gnutls_transport_set_pull_function(tls->session, proxy_gnutls_pull);
+       gnutls_transport_set_push_function(tls->session, proxy_gnutls_push);
+       gnutls_transport_set_ptr(tls->session, proxy);
+
+       tls->handshake_state = TLS_HS_IN_PROGRESS;
+}
+
+static void on_connect_to_upstream(uv_connect_t *req, int status)
+{
+       struct tls_proxy_ctx *proxy = (struct tls_proxy_ctx *)req->handle->loop->data;
+       free(req);
+       if (status < 0) {
+               fprintf(stderr, "error connecting to upstream (%s): %s\n",
+                       ip_straddr((struct sockaddr *)&proxy->upstream_addr),
+                       uv_strerror(status));
+               clear_upstream_pending(proxy);
+               proxy->upstream_state = STATE_CLOSING_IN_PROGRESS;
+               uv_close((uv_handle_t*)&proxy->upstream, on_upstream_close);
+               return;
+       }
+       fprintf(stdout, "connected to %s\n", ip_straddr((struct sockaddr *)&proxy->upstream_addr));
+
+       proxy->upstream_state = STATE_CONNECTED;
+       uv_read_start((uv_stream_t*)&proxy->upstream, alloc_uv_buffer, read_from_upstream_cb);
+       if (proxy->upstream_pending.len > 0) {
+               write_to_upstream_pending(proxy);
+       }
+}
+
+static void read_from_client_cb(uv_stream_t *client, ssize_t nread, const uv_buf_t *buf)
+{
+       fprintf(stdout, "reading %zd bytes from client\n", nread);
+       if (nread == 0) {
+               return;
+       }
+       struct tls_proxy_ctx *proxy = (struct tls_proxy_ctx *)client->loop->data;
+       if (nread < 0) {
+               if (nread != UV_EOF) {
+                       fprintf(stderr, "error reading from client: %s\n", uv_err_name(nread));
+               } else {
+                       fprintf(stdout, "client has closed the connection\n");
+               }
+               if (proxy->client_state == STATE_CONNECTED) {
+                       proxy->client_state = STATE_CLOSING_IN_PROGRESS;
+                       uv_close((uv_handle_t*) client, on_client_close);
+               }
+               return;
+       }
+
+       int res = tls_process_from_client(proxy, buf->base, nread);
+       if (res < 0) {
+               if (proxy->client_state == STATE_CONNECTED) {
+                       proxy->client_state = STATE_CLOSING_IN_PROGRESS;
+                       uv_close((uv_handle_t*) client, on_client_close);
+               }
+       }
+}
+
+static void read_from_upstream_cb(uv_stream_t *upstream, ssize_t nread, const uv_buf_t *buf)
+{
+       fprintf(stdout, "reading %zd bytes from upstream\n", nread);
+       if (nread == 0) {
+               return;
+       }
+       struct tls_proxy_ctx *proxy = (struct tls_proxy_ctx *)upstream->loop->data;
+       if (nread < 0) {
+               if (nread != UV_EOF) {
+                       fprintf(stderr, "error reading from upstream: %s\n", uv_err_name(nread));
+               } else {
+                       fprintf(stdout, "upstream has closed the connection\n");
+               }
+               clear_upstream_pending(proxy);
+               if (proxy->upstream_state == STATE_CONNECTED) {
+                       proxy->upstream_state = STATE_CLOSING_IN_PROGRESS;
+                       uv_close((uv_handle_t*)&proxy->upstream, on_upstream_close);
+               }
+               return;
+       }
+       int res = tls_process_from_upstream(proxy, buf->base, nread);
+       if (res < 0) {
+               fprintf(stderr, "error sending tls data to client\n");
+               if (proxy->client_state == STATE_CONNECTED) {
+                       proxy->client_state = STATE_CLOSING_IN_PROGRESS;
+                       uv_close((uv_handle_t*)&proxy->client, on_client_close);
+               }
+       }
+}
+
+static void push_to_upstream_pending(struct tls_proxy_ctx *proxy, const char *buf, size_t size)
+{
+       while (size > 0) {
+               struct buf *b = borrow_io_buffer(proxy);
+               b->size = size <= sizeof(b->buf) ? size : sizeof(b->buf);
+               memcpy(b->buf, buf, b->size);
+               array_push(proxy->upstream_pending, b);
+               size -= b->size;
+       }
+}
+
+static void push_to_client_pending(struct tls_proxy_ctx *proxy, const char *buf, size_t size)
+{
+       while (size > 0) {
+               struct buf *b = borrow_io_buffer(proxy);
+               b->size = size <= sizeof(b->buf) ? size : sizeof(b->buf);
+               memcpy(b->buf, buf, b->size);
+               array_push(proxy->client_pending, b);
+               size -= b->size;
+       }
+}
+
+static int write_to_upstream_pending(struct tls_proxy_ctx *proxy)
+{
+       struct buf *buf = get_first_upstream_pending(proxy);
+       /* TODO avoid allocation */
+       uv_write_t *req = (uv_write_t *) malloc(sizeof(uv_write_t));
+       uv_buf_t wrbuf = uv_buf_init(buf->buf, buf->size);
+       fprintf(stdout, "writing %zd bytes to upstream\n", buf->size);
+       return uv_write(req, (uv_stream_t *)&proxy->upstream, &wrbuf, 1, write_to_upstream_cb);
+}
+
+static ssize_t proxy_gnutls_pull(gnutls_transport_ptr_t h, void *buf, size_t len)
+{
+       struct tls_proxy_ctx *proxy = (struct tls_proxy_ctx *)h;
+       struct tls_ctx *t = &proxy->tls;
+
+       fprintf(stdout, "\t gnutls: pulling %zd bytes from client\n", len);
+
+       if (t->nread <= t->consumed) {
+               errno = EAGAIN;
+               fprintf(stdout, "\t gnutls: return EAGAIN\n");
+               return -1;
+       }
+
+       ssize_t avail = t->nread - t->consumed;
+       ssize_t transfer = (avail <= len ? avail : len);
+       memcpy(buf, t->buf + t->consumed, transfer);
+       t->consumed += transfer;
+       return transfer;
+}
+
+ssize_t proxy_gnutls_push(gnutls_transport_ptr_t h, const void *buf, size_t len)
+{
+       struct tls_proxy_ctx *proxy = (struct tls_proxy_ctx *)h;
+       struct tls_ctx *t = &proxy->tls;
+       fprintf(stdout, "\t gnutls: writing %zd bytes to client\n", len);
+
+       ssize_t ret = -1;
+       const size_t req_size_aligned = ((sizeof(uv_write_t) / 16) + 1) * 16;
+       char *common_buf = malloc(req_size_aligned + len);
+       uv_write_t *req = (uv_write_t *) common_buf;
+       char *data = common_buf + req_size_aligned;
+       const uv_buf_t uv_buf[1] = {
+               { data, len }
+       };
+       memcpy(data, buf, len);
+       int res = uv_write(req, (uv_stream_t *)&proxy->client, uv_buf, 1, write_to_client_cb);
+       if (res == 0) {
+               ret = len;
+       } else {
+               free(common_buf);
+               errno = EIO;
+       }
+       return ret;
+}
+
+static int write_to_client_pending(struct tls_proxy_ctx *proxy)
+{
+       if (proxy->client_pending.len == 0) {
+               return 0;
+       }
+
+       struct buf *buf = get_first_client_pending(proxy);
+       uv_buf_t wrbuf = uv_buf_init(buf->buf, buf->size);
+       fprintf(stdout, "writing %zd bytes to client\n", buf->size);
+
+       gnutls_session_t tls_session = proxy->tls.session;
+       assert(proxy->tls.handshake_state != TLS_HS_IN_PROGRESS);
+       assert(gnutls_record_check_corked(tls_session) == 0);
+
+       char *data = buf->buf;
+       size_t len = buf->size;
+
+       ssize_t count = 0;
+       ssize_t submitted = len;
+       ssize_t retries = 0;
+       do {
+               count = gnutls_record_send(tls_session, data, len);
+               if (count < 0) {
+                       if (gnutls_error_is_fatal(count)) {
+                               fprintf(stderr, "gnutls_record_send failed: %s (%zd)\n",
+                                       gnutls_strerror_name(count), count);
+                               return -1;
+                       }
+                       if (++retries > TLS_MAX_SEND_RETRIES) {
+                               fprintf(stderr, "gnutls_record_send: too many sequential non-fatal errors (%zd), last error is: %s (%zd)\n",
+                                       retries, gnutls_strerror_name(count), count);
+                               return -1;
+                       }
+               } else if (count != 0) {
+                       data += count;
+                       len -= count;
+                       retries = 0;
+               } else {
+                       if (++retries < TLS_MAX_SEND_RETRIES) {
+                               continue;
+                       }
+                       fprintf(stderr, "gnutls_record_send: too many retries (%zd)\n",
+                               retries);
+                       fprintf(stderr, "tls_push_to_client didn't send all data(%zd of %zd)\n",
+                               len, submitted);
+                       return -1;
+               }
+       } while (len > 0);
+
+       fprintf(stdout, "submitted %zd bytes to client\n", submitted);
+       assert (gnutls_safe_renegotiation_status(tls_session) != 0);
+       assert (gnutls_rehandshake(tls_session) == GNUTLS_E_SUCCESS);
+       fprintf(stdout, "rehandshake started\n");
+       return submitted;
+}
+
+static int tls_process_from_upstream(struct tls_proxy_ctx *proxy, const uint8_t *buf, ssize_t len)
+{
+       gnutls_session_t tls_session = proxy->tls.session;
+
+       fprintf(stdout, "pushing %zd bytes to client\n", len);
+
+       assert(gnutls_record_check_corked(tls_session) == 0);
+       ssize_t submitted = 0;
+       if (proxy->client_state != STATE_CONNECTED) {
+               return submitted;
+       }
+
+       push_to_client_pending(proxy, buf, len);
+       submitted = len;
+       if (proxy->tls.handshake_state == TLS_HS_DONE) {
+               if (proxy->client_pending.len == 1) {
+                       int ret = write_to_client_pending(proxy);
+                       if (ret < 0) {
+                               submitted = -1;
+                       }
+               }
+       }
+
+       return submitted;
+}
+
+int tls_process_handshake(struct tls_proxy_ctx *proxy)
+{
+       struct tls_ctx *tls = &proxy->tls;
+       int ret = 1;
+       while (tls->handshake_state == TLS_HS_IN_PROGRESS) {
+               fprintf(stdout, "TLS handshake in progress...\n");
+               int err = gnutls_handshake(tls->session);
+               if (err == GNUTLS_E_SUCCESS) {
+                       tls->handshake_state = TLS_HS_DONE;
+                       fprintf(stdout, "TLS handshake has completed\n");
+                       ret = 1;
+                       if (proxy->client_pending.len != 0) {
+                               write_to_client_pending(proxy);
+                       }
+               } else if (gnutls_error_is_fatal(err)) {
+                       fprintf(stderr, "gnutls_handshake failed: %s (%d)\n",
+                               gnutls_strerror_name(err), err);
+                       ret = -1;
+                       break;
+               } else {
+                       fprintf(stderr, "gnutls_handshake nonfatal error: %s (%d)\n",
+                               gnutls_strerror_name(err), err);
+                       ret = 0;
+                       break;
+               }
+       }
+       return ret;
+}
+
+int tls_process_from_client(struct tls_proxy_ctx *proxy, const uint8_t *buf, ssize_t nread)
+{
+       struct tls_ctx *tls = &proxy->tls;
+
+       tls->buf = buf;
+       tls->nread = nread >= 0 ? nread : 0;
+       tls->consumed = 0;
+
+       fprintf(stdout, "tls_process: reading %zd bytes from client\n", nread);
+
+       int ret = tls_process_handshake(proxy);
+       if (ret <= 0) {
+               return ret;
+       }
+
+       int submitted = 0;
+       while (true) {
+               ssize_t count = 0;
+               count = gnutls_record_recv(tls->session, tls->recv_buf, sizeof(tls->recv_buf));
+               if (count == GNUTLS_E_AGAIN) {
+                       break;    /* No data available */
+               } else if (count == GNUTLS_E_INTERRUPTED) {
+                       continue; /* Try reading again */
+               } else if (count == GNUTLS_E_REHANDSHAKE) {
+                       tls->handshake_state = TLS_HS_IN_PROGRESS;
+                       ret = tls_process_handshake(proxy);
+                       if (ret <= 0) {
+                               return ret;
+                       }
+                       continue;
+               } else if (count < 0) {
+                       fprintf(stderr, "gnutls_record_recv failed: %s (%zd)\n",
+                               gnutls_strerror_name(count), count);
+                       return -1;
+               } else if (count == 0) {
+                       break;
+               }
+               if (proxy->upstream_state == STATE_CONNECTED) {
+                       bool upstream_pending_is_empty = (proxy->upstream_pending.len == 0);
+                       push_to_upstream_pending(proxy, tls->recv_buf, count);
+                       if (upstream_pending_is_empty) {
+                               write_to_upstream_pending(proxy);
+                       }
+               } else if (proxy->upstream_state == STATE_NOT_CONNECTED) {
+                       /* TODO avoid allocation */
+                       uv_tcp_init(proxy->loop, &proxy->upstream);     
+                       uv_connect_t *conn = (uv_connect_t *) malloc(sizeof(uv_connect_t));
+                       proxy->upstream_state = STATE_CONNECT_IN_PROGRESS;
+                       fprintf(stdout, "connecting to %s\n",
+                               ip_straddr((struct sockaddr *)&proxy->upstream_addr));
+                       uv_tcp_connect(conn, &proxy->upstream, (struct sockaddr *)&proxy->upstream_addr,
+                                      on_connect_to_upstream);
+                       push_to_upstream_pending(proxy, tls->recv_buf, count);
+               } else if (proxy->upstream_state == STATE_CONNECT_IN_PROGRESS) {
+                       push_to_upstream_pending(proxy, tls->recv_buf, count);
+               }
+               submitted += count;
+       }
+       return submitted;
+}
+
+struct tls_proxy_ctx *tls_proxy_allocate()
+{
+       return malloc(sizeof(struct tls_proxy_ctx));
+}
+
+int tls_proxy_init(struct tls_proxy_ctx *proxy,
+                  const char *server_addr, int server_port,
+                  const char *upstream_addr, int upstream_port,
+                  const char *cert_file, const char *key_file)
+{      
+       proxy->loop = uv_default_loop();
+       uv_tcp_init(proxy->loop, &proxy->server);
+       int res = uv_ip4_addr(server_addr, server_port, (struct sockaddr_in *)&proxy->server_addr);
+       if (res != 0) {
+               fprintf(stderr, "uv_ip4_addr failed with string '%s'\n", server_addr);
+               return -1;
+       }
+       res = uv_ip4_addr(upstream_addr, upstream_port, (struct sockaddr_in *)&proxy->upstream_addr);
+       if (res != 0) {
+               fprintf(stderr, "uv_ip4_addr failed with string '%s'\n", upstream_addr);
+               return -1;
+       }
+       array_init(proxy->buffer_pool);
+       array_init(proxy->upstream_pending);
+       array_init(proxy->client_pending);
+       proxy->server_state = STATE_NOT_CONNECTED;
+       proxy->client_state = STATE_NOT_CONNECTED;
+       proxy->upstream_state = STATE_NOT_CONNECTED;
+
+       proxy->loop->data = proxy;
+       
+       int err = 0;
+       if (gnutls_references == 0) {
+               err = gnutls_global_init();
+               if (err != GNUTLS_E_SUCCESS) {
+                       fprintf(stderr, "gnutls_global_init() failed: (%d) %s\n",
+                                    err, gnutls_strerror_name(err));
+                       return -1;
+               }
+       }
+       gnutls_references += 1;
+       
+       err = gnutls_certificate_allocate_credentials(&proxy->tls.credentials);
+       if (err != GNUTLS_E_SUCCESS) {
+               fprintf(stderr, "gnutls_certificate_allocate_credentials() failed: (%d) %s\n",
+                            err, gnutls_strerror_name(err));
+               return -1;
+       }
+
+       err = gnutls_certificate_set_x509_system_trust(proxy->tls.credentials);
+       if (err <= 0) {
+               fprintf(stderr, "gnutls_certificate_set_x509_system_trust() failed: (%d) %s\n",
+                            err, gnutls_strerror_name(err));
+               return -1;
+       }
+
+       if (cert_file && key_file) {
+               err = gnutls_certificate_set_x509_key_file(proxy->tls.credentials,
+                                                    cert_file, key_file, GNUTLS_X509_FMT_PEM);
+               if (err != GNUTLS_E_SUCCESS) {
+                       fprintf(stderr, "gnutls_certificate_set_x509_key_file() failed: (%d) %s\n",
+                                    err, gnutls_strerror_name(err));
+                       return -1;
+               }
+       }
+
+       err = gnutls_priority_init(&proxy->tls.priority_cache, NULL, NULL);
+       if (err != GNUTLS_E_SUCCESS) {
+               fprintf(stderr, "gnutls_priority_init() failed: (%d) %s\n",
+                            err, gnutls_strerror_name(err));
+               return -1;
+       }
+
+       
+       proxy->tls.handshake_state = TLS_HS_NOT_STARTED;
+       return 0;
+}
+
+void tls_proxy_free(struct tls_proxy_ctx *proxy)
+{
+       if (!proxy) {
+               return;
+       }
+       clear_upstream_pending(proxy);
+       clear_client_pending(proxy);
+       clear_buffer_pool(proxy);
+       gnutls_certificate_free_credentials(proxy->tls.credentials);
+        gnutls_priority_deinit(proxy->tls.priority_cache);
+       /* TODO correctly close all the uv_tcp_t */
+       free(proxy);
+       
+       gnutls_references -= 1;
+       if (gnutls_references == 0) {
+               gnutls_global_deinit();
+       }
+}
+
+int tls_proxy_start_listen(struct tls_proxy_ctx *proxy)
+{      
+       uv_tcp_bind(&proxy->server, (const struct sockaddr*)&proxy->server_addr, 0);
+       int ret = uv_listen((uv_stream_t*)&proxy->server, 128, on_client_connection);
+       if (ret == 0) {
+               proxy->server_state = STATE_LISTENING;
+       }
+       return ret;
+}
+
+int tls_proxy_run(struct tls_proxy_ctx *proxy)
+{
+       return uv_run(proxy->loop, UV_RUN_DEFAULT);
+}
diff --git a/tests/pytests/rehandshake/tls-proxy.h b/tests/pytests/rehandshake/tls-proxy.h
new file mode 100644 (file)
index 0000000..1204eda
--- /dev/null
@@ -0,0 +1,14 @@
+#pragma once
+
+struct tls_proxy_ctx;
+
+struct tls_proxy_ctx *tls_proxy_allocate();
+void tls_proxy_free(struct tls_proxy_ctx *proxy);
+int tls_proxy_init(struct tls_proxy_ctx *proxy,
+                  const char *server_addr, int server_port,
+                  const char *upstream_addr, int upstream_port,
+                  const char *cert_file, const char *key_file);
+int tls_proxy_start_listen(struct tls_proxy_ctx *proxy);
+int tls_proxy_run(struct tls_proxy_ctx *proxy);
+
+
diff --git a/tests/pytests/rehandshake/tlsproxy.c b/tests/pytests/rehandshake/tlsproxy.c
new file mode 100644 (file)
index 0000000..0c074f1
--- /dev/null
@@ -0,0 +1,31 @@
+#include <stdio.h>
+#include "tls-proxy.h"
+#include <gnutls/gnutls.h>
+
+int main()
+{
+       struct tls_proxy_ctx *proxy = tls_proxy_allocate();
+       if (!proxy) {
+               fprintf(stderr, "can't allocate tls_proxy structure\n");
+               return 1;
+       }
+       int res = tls_proxy_init(proxy,
+                                "127.0.0.1", 53921, /* Address to listen */
+                                "127.0.0.1", 53910, /* Upstream address */
+                                "../certs/tt.cert.pem",
+                                "../certs/tt.key.pem");
+       if (res) {
+               fprintf(stderr, "can't initialize tls_proxy structure\n");
+               return res;
+       }
+       res = tls_proxy_start_listen(proxy);
+       if (res) {
+               fprintf(stderr, "error starting listen, error code: %i\n", res);
+               return res;
+       }
+       fprintf(stdout, "started...\n");
+       res = tls_proxy_run(proxy);
+       tls_proxy_free(proxy);
+       return res;
+}
+
index 22461479055d7bdd8ec34458e34f55b826f26f66..5475aebca8d485ccba2f310c28c37772f464efe4 100644 (file)
@@ -32,7 +32,7 @@ hints['{{ name }}'] = '{{ ip }}'
 policy.add(policy.all(
     {% if kresd.forward.proto == 'tls' %}
     policy.TLS_FORWARD({
-        {"{{ kresd.forward.ip }}@{{ kresd.forward.port }}", insecure=true}})
+        {"{{ kresd.forward.ip }}@{{ kresd.forward.port }}", hostname='{{ kresd.forward.hostname}}', ca_file='{{ kresd.forward.ca_file }}'}})
     {% endif %}
 ))
 {% endif %}
diff --git a/tests/pytests/test_rehandshake.py b/tests/pytests/test_rehandshake.py
new file mode 100644 (file)
index 0000000..d6faf48
--- /dev/null
@@ -0,0 +1,69 @@
+"""TLS rehandshake test
+
+Test utilizes rehandshake/tls-proxy, which forwards queries to configured
+resolver, but when it sends the response back to the query source, it
+performs a rehandshake after every byte sent.
+
+It is expected the answer will be received by the source kresd instance
+and sent back to the client (this test).
+
+Make sure to run `make all` in `rehandshake/` to compile the proxy.
+"""
+
+import os
+import subprocess
+import time
+
+import pytest
+
+from kresd import CERTS_DIR, Forward, make_kresd, PYTESTS_DIR
+import utils
+
+
+REHANDSHAKE_PROXY = os.path.join(PYTESTS_DIR, 'rehandshake', 'tlsproxy')
+
+
+@pytest.mark.skipif(not os.path.exists(REHANDSHAKE_PROXY),
+                    reason="tlsproxy not found (did you compile it?)")
+def test_rehandshake(tmpdir):
+    def resolve_hint(sock, qname):
+        buff, msgid = utils.get_msgbuff(qname)
+        sock.sendall(buff)
+        answer = utils.receive_parse_answer(sock)
+        assert answer.id == msgid
+        assert answer.answer[0][0].address == '127.0.0.1'
+
+    hints = {
+        '0.foo.': '127.0.0.1',
+        '1.foo.': '127.0.0.1',
+        '2.foo.': '127.0.0.1',
+        '3.foo.': '127.0.0.1',
+    }
+    # run forward target instance
+    workdir = os.path.join(str(tmpdir), 'kresd_fwd_target')
+    os.makedirs(workdir)
+
+    with make_kresd(workdir, hints=hints, port=53910) as kresd_fwd_target:
+        sock = kresd_fwd_target.ip_tls_socket()
+        resolve_hint(sock, '0.foo.')
+
+        # run proxy
+        cwd, cmd = os.path.split(REHANDSHAKE_PROXY)
+        cmd = './' + cmd
+        ca_file = os.path.join(CERTS_DIR, 'tt.cert.pem')
+        try:
+            proxy = subprocess.Popen(
+                [cmd], cwd=cwd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+
+            # run test kresd instance
+            workdir2 = os.path.join(str(tmpdir), 'kresd')
+            os.makedirs(workdir2)
+            forward = Forward(proto='tls', ip='127.0.0.1', port=53921,
+                              hostname='transport-test-server.com', ca_file=ca_file)
+            with make_kresd(workdir2, forward=forward) as kresd:
+                sock2 = kresd.ip_tcp_socket()
+                for hint in hints:
+                    resolve_hint(sock2, hint)
+                    time.sleep(1)
+        finally:
+            proxy.terminate()