]> git.ipfire.org Git - thirdparty/knot-resolver.git/commitdiff
daemon: TLS over outbound TCP connection - use hostname authentication
authorgrid <grigorii.demidov@nic.cz>
Tue, 5 Dec 2017 14:15:44 +0000 (15:15 +0100)
committerPetr Špaček <petr.spacek@nic.cz>
Mon, 8 Jan 2018 11:01:00 +0000 (12:01 +0100)
daemon/bindings.c
daemon/io.c
daemon/network.c
daemon/tls.c
daemon/tls.h
daemon/worker.c
daemon/worker.h
modules/policy/policy.lua

index 7fcc9a39d848cc4776e33f55ecc6bbcb935516c2..88b4e74e1af410096378a2a1315b768a9b5b741b 100644 (file)
@@ -425,6 +425,13 @@ static int print_tls_param(const char *key, void *val, void *data)
                lua_settable(L, -3);
        }
        lua_setfield(L, -2, "ca files");
+       lua_newtable(L);
+       for (size_t i = 0; i < entry->hostnames.len; ++i) {
+               lua_pushnumber(L, i + 1);
+               lua_pushstring(L, entry->hostnames.at[i]);
+               lua_settable(L, -3);
+       }
+       lua_setfield(L, -2, "hostnames");
        lua_setfield(L, -2, key);
 
        return 0;
@@ -465,16 +472,24 @@ static int net_tls_client(lua_State *L)
        }
 
        const char *full_addr = NULL;
-       const char *ca_file = NULL;
-       const char *pin = NULL;
+       bool pin_exists = false;
+       bool ca_file_exists = false;
+       printf("%i\n", lua_gettop(L));
        if ((lua_gettop(L) == 1) && lua_isstring(L, 1)) {
                full_addr = lua_tostring(L, 1);
-       } else if ((lua_gettop(L) == 3) && lua_isstring(L, 1) && lua_isstring(L, 2) && lua_isstring(L, 3)) {
+       } else if ((lua_gettop(L) == 2) && lua_isstring(L, 1) && lua_istable(L, 2)) {
+               full_addr = lua_tostring(L, 1);
+               pin_exists = true;
+       } else if ((lua_gettop(L) == 3) && lua_isstring(L, 1) && lua_istable(L, 2)) {
                full_addr = lua_tostring(L, 1);
-               ca_file = lua_tostring(L, 2);
-               pin = lua_tostring(L, 3);
+               ca_file_exists = true;
+       } else if ((lua_gettop(L) == 4) && lua_isstring(L, 1) &&
+                   lua_istable(L, 2) && lua_istable(L, 3)) {
+               full_addr = lua_tostring(L, 1);
+               pin_exists = true;
+               ca_file_exists = true;
        } else {
-               format_error(L, "net.tls_client either takes one parameter (\"address\") either takes three ones: (\"address\", \"ca_file\", \"pin\")");
+               format_error(L, "net.tls_client takes one parameter (\"address\"), two parameters (\"address\",\"pin\"), three parameters (\"address\", \"ca_file\", \"hostname\") or four ones: (\"address\", \"pin\", \"ca_file\", \"hostname\")");
                lua_error(L);
        }
 
@@ -486,13 +501,76 @@ static int net_tls_client(lua_State *L)
        }
 
        if (port == 0) {
-               port = 53;
+               port = 853;
        }
 
-       int r = tls_client_params_set(&net->tls_client_params, addr, port, ca_file, pin);
-       if (r != 0) {
-               lua_pushstring(L, strerror(ENOMEM));
-               lua_error(L);
+       if (!pin_exists && !ca_file_exists) {
+               int r = tls_client_params_set(&net->tls_client_params,
+                                             addr, port, NULL, NULL, NULL);
+               if (r != 0) {
+                       lua_pushstring(L, strerror(ENOMEM));
+                       lua_error(L);
+               }
+
+               lua_pushboolean(L, true);
+               return 1;
+       }
+
+       if (pin_exists) {
+               /* iterate over table with pins
+                * http://www.lua.org/manual/5.1/manual.html#lua_next */
+               lua_pushnil(L); /* first key */
+               while (lua_next(L, 2)) {  /* pin table is in stack at index 2 */
+                       /* pin now at index -1, key at index -2*/
+                       const char *pin = lua_tostring(L, -1);
+                       int r = tls_client_params_set(&net->tls_client_params,
+                                                     addr, port, NULL, NULL, pin);
+                       if (r != 0) {
+                               lua_pushstring(L, strerror(ENOMEM));
+                               lua_error(L);
+                       }
+                       lua_pop(L, 1);
+               }
+       }
+
+       int ca_table_index = 2;
+       int hostname_table_index = 3;
+       if (ca_file_exists) {
+               if (pin_exists) {
+                       ca_table_index = 3;
+                       hostname_table_index = 4;
+               }
+       } else {
+               lua_pushboolean(L, true);
+               return 1;
+       }
+
+       /* iterate over ca filenames */
+       lua_pushnil(L);
+       while (lua_next(L, ca_table_index)) {
+               const char *ca_file = lua_tostring(L, -1);
+               int r = tls_client_params_set(&net->tls_client_params,
+                                             addr, port, ca_file, NULL, NULL);
+               if (r != 0) {
+                       lua_pushstring(L, strerror(ENOMEM));
+                       lua_error(L);
+               }
+               /* removes 'value'; keeps 'key' for next iteration */
+               lua_pop(L, 1);
+       }
+
+       /* iterate over hostnames */
+       lua_pushnil(L);
+       while (lua_next(L, hostname_table_index)) {
+               const char *hostname = lua_tostring(L, -1);
+               int r = tls_client_params_set(&net->tls_client_params,
+                                             addr, port, NULL, hostname, NULL);
+               if (r != 0) {
+                       lua_pushstring(L, strerror(ENOMEM));
+                       lua_error(L);
+               }
+               /* removes 'value'; keeps 'key' for next iteration */
+               lua_pop(L, 1);
        }
 
        lua_pushboolean(L, true);
index 22185ea3adf87d3354b2ddcaabc5173a4f634d99..a69c187a16cdbe7cd20cd24a836850629f5ce625 100644 (file)
@@ -110,7 +110,7 @@ static void session_release(struct worker_ctx *worker, uv_handle_t *handle)
 
 static uv_stream_t *handle_alloc(uv_loop_t *loop)
 {
-       uv_stream_t *handle = calloc(1, sizeof(union uv_handles));
+       uv_stream_t *handle = calloc(1, sizeof(uv_handles_t));
        if (!handle) {
                return NULL;
        }
index 040d4ed9c6cb94783b4293f0ea173061319853df..ff8004baea13ba1a967b8686b13fc7a982af35a2 100644 (file)
@@ -140,7 +140,7 @@ static int open_endpoint(struct network *net, struct endpoint *ep, struct sockad
 {
        int ret = 0;
        if (flags & NET_UDP) {
-               ep->udp = malloc(sizeof(union uv_handles));
+               ep->udp = malloc(sizeof(uv_handles_t));
                if (!ep->udp) {
                        return kr_error(ENOMEM);
                }
@@ -153,7 +153,7 @@ static int open_endpoint(struct network *net, struct endpoint *ep, struct sockad
                ep->flags |= NET_UDP;
        }
        if (flags & NET_TCP) {
-               ep->tcp = malloc(sizeof(union uv_handles));
+               ep->tcp = malloc(sizeof(uv_handles_t));
                if (!ep->tcp) {
                        return kr_error(ENOMEM);
                }
@@ -185,7 +185,7 @@ static int open_endpoint_fd(struct network *net, struct endpoint *ep, int fd, in
                if (ep->udp) {
                        return kr_error(EEXIST);
                }
-               ep->udp = malloc(sizeof(union uv_handles));// malloc(sizeof(*ep->udp));
+               ep->udp = malloc(sizeof(uv_handles_t));// malloc(sizeof(*ep->udp));
                if (!ep->udp) {
                        return kr_error(ENOMEM);
                }
@@ -201,7 +201,7 @@ static int open_endpoint_fd(struct network *net, struct endpoint *ep, int fd, in
                if (ep->tcp) {
                        return kr_error(EEXIST);
                }
-               ep->tcp = malloc(sizeof(union uv_handles));
+               ep->tcp = malloc(sizeof(uv_handles_t));
                if (!ep->tcp) {
                        return kr_error(ENOMEM);
                }
index d58e66f023d2b268818eff3e7474904bf1403c97..fab79d325e4ba7dc5c89dcfc408872d362b8d7c3 100644 (file)
@@ -551,7 +551,7 @@ static int client_paramlist_entry_clear(const char *k, void *v, void *baton)
 
 int tls_client_params_set(map_t *tls_client_paramlist,
                          const char *addr, uint16_t port,
-                         const char *ca_file, const char *pin)
+                         const char *ca_file, const char *hostname, const char *pin)
 {
        if (!tls_client_paramlist || !addr) {
                return kr_error(EINVAL);
@@ -587,7 +587,7 @@ int tls_client_params_set(map_t *tls_client_paramlist,
                bool already_exists = false;
                for (size_t i = 0; i < entry->ca_files.len; ++i) {
                        if (strcmp(entry->ca_files.at[i], ca_file) == 0) {
-                               kr_log_error("[tls client] error: ca file for address %s already was set, ignoring\n", key);
+                               kr_log_error("[tls client] error: ca file '%s'for address '%s' already was set, ignoring\n", ca_file, key);
                                already_exists = true;
                                break;
                        }
@@ -612,10 +612,30 @@ int tls_client_params_set(map_t *tls_client_paramlist,
                }
        }
 
+       if ((ret == kr_ok()) && hostname && hostname[0] != 0) {
+               bool already_exists = false;
+               for (size_t i = 0; i < entry->hostnames.len; ++i) {
+                       if (strcmp(entry->hostnames.at[i], hostname) == 0) {
+                               kr_log_error("[tls client] error: hostname '%s' for address '%s' already was set, ignoring\n", hostname, key);
+                               already_exists = true;
+                               break;
+                       }
+               }
+               if (!already_exists) {
+                       const char *value = strdup(hostname);
+                       if (!value) {
+                               ret = kr_error(ENOMEM);
+                       } else if (array_push(entry->hostnames, value) < 0) {
+                               free ((void *)value);
+                               ret = kr_error(ENOMEM);
+                       }
+               }
+       }
+
        if ((ret == kr_ok()) && pin && pin[0] != 0) {
                for (size_t i = 0; i < entry->pins.len; ++i) {
                        if (strcmp(entry->pins.at[i], pin) == 0) {
-                               kr_log_error("[tls client] warning: pin for address %s already was set, ignoring\n", key);
+                               kr_log_error("[tls client] warning: pin '%s' for address '%s' already was set, ignoring\n", pin, key);
                                return kr_ok();
                        }
                }
@@ -659,7 +679,7 @@ static int client_verify_certificate(gnutls_session_t tls_session)
        struct tls_client_ctx_t *ctx = gnutls_session_get_ptr(tls_session);
        assert(ctx->params != NULL);
 
-       if (ctx->params->pins.len == 0 && ctx->params->ca_files.len) {
+       if (ctx->params->pins.len == 0 && ctx->params->ca_files.len == 0) {
                return GNUTLS_E_SUCCESS;
        }
 
@@ -716,26 +736,38 @@ static int client_verify_certificate(gnutls_session_t tls_session)
                }
        }
 
+       /* pins were set, but no one was not matched */
+       kr_log_error("[tls_client] certificate PIN check failed\n");
+
 skip_pins:
 
        if (ctx->params->ca_files.len == 0) {
-               DEBUG_MSG("[tls_client] skipping certificate verification\n");
-               return GNUTLS_E_SUCCESS;
+               DEBUG_MSG("[tls_client] empty CA files list\n");
+               return GNUTLS_E_CERTIFICATE_ERROR;
        }
 
-       gnutls_typed_vdata_st data[2] = {
-               { .type = GNUTLS_DT_KEY_PURPOSE_OID,
-                 .data = (void *)GNUTLS_KP_TLS_WWW_SERVER }
-       };
-       size_t data_count = 1;
-       unsigned int status;
-       int ret = gnutls_certificate_verify_peers(ctx->tls_session, data, data_count, &status);
-       if (ret != GNUTLS_E_SUCCESS) {
-               kr_log_error("[tls_client] failed to verify peer certificate\n");
+       if (ctx->params->hostnames.len == 0) {
+               DEBUG_MSG("[tls_client] empty hostname list\n");
                return GNUTLS_E_CERTIFICATE_ERROR;
        }
 
-       return GNUTLS_E_SUCCESS;
+       for (size_t i = 0; i < ctx->params->hostnames.len; ++i) {
+               gnutls_typed_vdata_st data[2] = {
+                       { .type = GNUTLS_DT_KEY_PURPOSE_OID,
+                         .data = (void *)GNUTLS_KP_TLS_WWW_SERVER },
+                       { .type = GNUTLS_DT_DNS_HOSTNAME,
+                         .data = (void *)ctx->params->hostnames.at[i] }
+               };
+               size_t data_count = 2;
+               unsigned int status;
+               int ret = gnutls_certificate_verify_peers(ctx->tls_session, data, data_count, &status);
+               if ((ret == GNUTLS_E_SUCCESS) && (status == 0)) {
+                       return GNUTLS_E_SUCCESS;
+               }
+       }
+
+       kr_log_error("[tls_client] failed to verify peer certificate\n");
+       return GNUTLS_E_CERTIFICATE_ERROR;
 }
 
 static ssize_t kres_gnutls_client_push(gnutls_transport_ptr_t h, const void *buf, size_t len)
index 9b46c80214082c30b0a6581aeb73a4e3a85c6d2b..20d1efc521e1036d99bf4bc6690c4477906a6608 100644 (file)
@@ -38,6 +38,7 @@ struct tls_credentials {
 
 struct tls_client_paramlist_entry {
        array_t(const char *) ca_files;
+       array_t(const char *) hostnames;
        array_t(const char *) pins;
        gnutls_certificate_credentials_t credentials;
 };
@@ -87,7 +88,7 @@ struct tls_credentials * tls_get_ephemeral_credentials(struct engine *engine);
 /*! Set TLS authentication parameters for given address. */
 int tls_client_params_set(map_t *tls_client_paramlist,
                          const char *addr, uint16_t port,
-                         const char *ca_file, const char *pin);
+                         const char *ca_file, const char *hostname, const char *pin);
 
 /*! Free TLS authentication parameters. */
 int tls_client_params_free(map_t *tls_client_paramlist);
@@ -113,4 +114,4 @@ tls_client_hs_state_t tls_client_get_hs_state(const struct tls_client_ctx_t *ctx
 int tls_client_set_hs_state(struct tls_client_ctx_t *ctx, tls_client_hs_state_t state);
 
 int tls_client_ctx_set_params(struct tls_client_ctx_t *ctx,
-                             const struct tls_client_paramlist_entry *entry);
\ No newline at end of file
+                             const struct tls_client_paramlist_entry *entry);
index 614ad006a48b4de05ed21cc3798163e2f44efb26..9d7abdda0f91c6b087d087268a7f7c8a3310c9cb 100644 (file)
@@ -133,7 +133,7 @@ static inline void *iohandle_borrow(struct worker_ctx *worker)
 {
        void *h = NULL;
 
-       const size_t size = sizeof(union uv_handles);
+       const size_t size = sizeof(uv_handles_t);
        if (worker->pool_iohandles.len > 0) {
                h = array_tail(worker->pool_iohandles);
                array_pop(worker->pool_iohandles);
@@ -149,7 +149,7 @@ static inline void iohandle_release(struct worker_ctx *worker, void *h)
 {
        assert(h);
 
-       const size_t size = sizeof(union uv_handles);
+       const size_t size = sizeof(uv_handles_t);
        if (worker->pool_iohandles.len < MP_FREELIST_SIZE) {
                array_push(worker->pool_iohandles, h);
                kr_asan_poison(h, size);
@@ -172,7 +172,7 @@ static inline void *iorequest_borrow(struct worker_ctx *worker)
 {
        void *r = NULL;
 
-       const size_t size = sizeof(union uv_reqs);
+       const size_t size = sizeof(uv_reqs_t);
        if (worker->pool_ioreqs.len > 0) {
                r = array_tail(worker->pool_ioreqs);
                array_pop(worker->pool_ioreqs);
@@ -188,7 +188,7 @@ static inline void iorequest_release(struct worker_ctx *worker, void *r)
 {
        assert(r);
 
-       const size_t size = sizeof(union uv_reqs);
+       const size_t size = sizeof(uv_reqs_t);
        if (worker->pool_ioreqs.len < MP_FREELIST_SIZE) {
                array_push(worker->pool_ioreqs, r);
                kr_asan_poison(r, size);
@@ -2231,8 +2231,8 @@ static int worker_reserve(struct worker_ctx *worker, size_t ring_maxlen)
 void worker_reclaim(struct worker_ctx *worker)
 {
        reclaim_freelist(worker->pool_mp, struct mempool, mp_delete);
-       reclaim_freelist(worker->pool_ioreqs, union uv_reqs, free);
-       reclaim_freelist(worker->pool_iohandles, union uv_handles, free);
+       reclaim_freelist(worker->pool_ioreqs, uv_reqs_t, free);
+       reclaim_freelist(worker->pool_iohandles, uv_handles_t, free);
        reclaim_freelist(worker->pool_sessions, struct session, session_free);
        mp_delete(worker->pkt_pool.ctx);
        worker->pkt_pool.ctx = NULL;
index c19ad477a6d6886db3bc3061fa172e10663fb766..231b239759acd3edeccf21a13c8f595dd249b278 100644 (file)
@@ -147,8 +147,9 @@ struct worker_ctx {
        knot_mm_t pkt_pool;
 };
 
-/* @internal Union of derivatives from libuv uv_handle_t for freelist.
- * These have session as their `handle->data` and own it. */
+/* @internal Union of some libuv handles for freelist.
+ * These have session as their `handle->data` and own it.
+ * Subset of uv_any_handle. */
 union uv_handles {
        uv_handle_t   handle;
        uv_stream_t   stream;
@@ -156,9 +157,11 @@ union uv_handles {
        uv_tcp_t      tcp;
        uv_timer_t    timer;
 };
+typedef union uv_any_handle uv_handles_t;
 
 /* @internal Union of derivatives from uv_req_t libuv request handles for freelist.
- * These have only a reference to the task they're operating on. */
+ * These have only a reference to the task they're operating on.
+ * Subset of uv_any_req. */
 union uv_reqs {
        uv_req_t      req;
        uv_shutdown_t sdown;
@@ -166,6 +169,7 @@ union uv_reqs {
        uv_connect_t  connect;
        uv_udp_send_t send;
 };
+typedef union uv_reqs uv_reqs_t;
 
 /** @endcond */
 
index 9fa135e32463a749eb6c493a90fdbe5d6d31b98b..631c037c7e241508bef45766b7172a7a01d9fbca 100644 (file)
@@ -121,15 +121,65 @@ local function forward(target)
 end
 
 -- Forward request and all subrequests to upstream over TCP; validate answers
-local function tcp_forward(target)
-       local list = {}
-       if type(target) == 'table' then
-               for _, v in pairs(target) do
-                       table.insert(list, addr2sock(v))
-                       assert(#list <= 4, 'at most 4 TCP_FORWARD targets are supported')
+local function tls_forward(target)
+       local sockaddr_list = {}
+       local addr_list = {}
+       local ca_files = {}
+       local hostnames = {}
+       local pins = {}
+       if type(target) ~= 'table' then
+               assert(false, 'wrong TLS_FORWARD target')
+       end
+       for _, upstream_list_entry in pairs(target) do
+               upstream_addr = upstream_list_entry[1]
+               if type(upstream_addr) ~= 'string' then
+                       assert(false, 'bad IP address in TLS_FORWARD target')
+               end
+               table.insert(sockaddr_list, addr2sock(upstream_addr))
+               table.insert(addr_list, upstream_addr)
+               ca_file = upstream_list_entry['ca_file']
+               if ca_file ~= nil then
+                       hostname = upstream_list_entry['hostname']
+                       if hostname == nil then
+                               assert(false, 'hostname(s) is absent in TLS_FORWARD target')
+                       end
+                       ca_files_local = {}
+                       if type(ca_file) == 'table' then
+                               for _, v in pairs(ca_file) do
+                                       table.insert(ca_files_local, v)
+                               end
+                       else
+                               table.insert(ca_files_local, ca_file)
+                       end
+                       hostnames_local = {}
+                       if type(hostname) == 'table' then
+                               for _, v in pairs(hostname) do
+                                       table.insert(hostnames_local, v)
+                               end
+                       else
+                               table.insert(hostnames_local, hostname)
+                       end
+                       if next(ca_files_local) then
+                               ca_files[upstream_addr] = ca_files_local
+                       end
+                       if next(hostnames_local) then
+                               hostnames[upstream_addr] = hostnames_local
+                       end
+               end
+               pin = upstream_list_entry['pin']
+               pins_local = {}
+               if pin ~= nil then
+                       if type(pin) == 'table' then
+                               for _, v in pairs(pin) do
+                                       table.insert(pins_local, v)
+                               end
+                       else
+                               table.insert(pins_local, pin)
+                       end
+               end
+               if next(pins_local) then
+                       pins[upstream_addr] = pins_local
                end
-       else
-               table.insert(list, addr2sock(target))
        end
        return function(state, req)
                local qry = req:current()
@@ -141,7 +191,18 @@ local function tcp_forward(target)
                qry.flags.AWAIT_CUT = true
                req.options.TCP = true
                qry.flags.TCP = true
-               set_nslist(qry, list)
+               set_nslist(qry, sockaddr_list)
+               for _, v in pairs(addr_list) do
+                       if (pins[v] == nil and ca_files[v] == nil) then
+                               net.tls_client(v)
+                       elseif (pins[v] ~= nil and ca_files[v] == nil) then
+                               net.tls_client(v, pins[v])
+                       elseif (pins[v] == nil and ca_files[v] ~= nil) then
+                               net.tls_client(v, ca_files[v], hostnames[v])
+                       else
+                               net.tls_client(v, pins[v], ca_files[v], hostnames[v])
+                       end
+               end
                return state
        end
 end
@@ -262,7 +323,7 @@ end
 local policy = {
        -- Policies
        PASS = 1, DENY = 2, DROP = 3, TC = 4, QTRACE = 5,
-       FORWARD = forward, TCP_FORWARD = tcp_forward,
+       FORWARD = forward, TLS_FORWARD = tls_forward,
        STUB = stub, REROUTE = reroute, MIRROR = mirror, FLAGS = flags,
        -- Special values
        ANY = 0,