]> git.ipfire.org Git - thirdparty/knot-resolver.git/commitdiff
daemon: rework handling of TLS authentication params
authorVladimír Čunát <vladimir.cunat@nic.cz>
Thu, 31 Jan 2019 16:00:22 +0000 (17:00 +0100)
committerPetr Špaček <petr.spacek@nic.cz>
Fri, 22 Feb 2019 14:47:31 +0000 (15:47 +0100)
It's mainly about the way we parse and validate them.

Almost all of the parts of validation that were being done
in modules/policy/policy.lua and daemon/tls.c got moved
to daemon/bindings/net.c, so it's easier to follow that.
Also more checks are being done now, e.g. contents of .pin_sha256
and .hostname strings.

14 files changed:
daemon/bindings/impl.c
daemon/bindings/impl.h
daemon/bindings/net.c
daemon/network.c
daemon/network.h
daemon/tls.c
daemon/tls.h
daemon/worker.c
lib/generic/trie.h
lib/utils.c
lib/utils.h
modules/policy/README.rst
modules/policy/policy.lua
modules/policy/policy.test.lua

index b4671a07fa9b02fa404f277bc7cf8c3240c428ed..2db2895d2ffa3499385941bd40121d4f4f559ff1 100644 (file)
 #include <lauxlib.h>
 #include <string.h>
 
+
+const char * lua_table_checkindices(lua_State *L, const char *keys[])
+{
+       /* Iterate over table at the top of the stack.
+        * http://www.lua.org/manual/5.1/manual.html#lua_next */
+       for (lua_pushnil(L); lua_next(L, -2); lua_pop(L, 1)) {
+               lua_pop(L, 1); /* we don't need the value */
+               /* We need to copy the key, as _tostring() confuses _next().
+                * https://www.lua.org/manual/5.1/manual.html#lua_tolstring */
+               lua_pushvalue(L, -1);
+               const char *key = lua_tostring(L, -1);
+               if (!key)
+                       return "<NON-STRING_INDEX>";
+               for (const char **k = keys; ; ++k) {
+                       if (*k == NULL)
+                               return key;
+                       if (strcmp(*k, key) == 0)
+                               break;
+               }
+       }
+       return NULL;
+}
+
+
 /* Each of these just creates the correspondingly named lua table of functions. */
 int kr_bindings_cache   (lua_State *L); /* ./cache.c   */
 int kr_bindings_event   (lua_State *L); /* ./event.c   */
index d5dfff3148f6fd2c36ebcf5319bdcd075047972a..af54ddc0bae573261b5b92f7489ecd2eea9d12fa 100644 (file)
 #define STR(s) STRINGIFY_TOKEN(s)
 #define STRINGIFY_TOKEN(s) #s
 
+
+/** Check lua table at the top of the stack for allowed keys.
+ * \param keys NULL-terminated array of 0-terminated strings
+ * \return NULL if passed or the offending string (pushed on top of lua stack)
+ * \note Future work: if non-NULL is returned, there's extra stuff on the lua stack.
+ * \note Brute-force complexity: table length * summed length of keys.
+ */
+const char * lua_table_checkindices(lua_State *L, const char *keys[]);
+
+/** If the value at the top of the stack isn't a table, make it a single-element list. */
+static inline void lua_listify(lua_State *L)
+{
+       if (lua_istable(L, -1))
+               return;
+       lua_createtable(L, 1, 0);
+       lua_insert(L, lua_gettop(L) - 1); /* swap the top two stack elements */
+       lua_pushinteger(L, 1);
+       lua_insert(L, lua_gettop(L) - 1); /* swap the top two stack elements */
+       lua_settable(L, -3);
+}
+
+
 /** Throw a formatted lua error.
  *
  * The message will get prefixed by "ERROR: " and supplemented by stack trace.
index 28bad79ecfe47745c86f46340edf41a8294a3f5f..f8cbc32019557c38a270319cf9d2352cbba51a6b 100644 (file)
 
 #include "daemon/bindings/impl.h"
 
+#include "contrib/base64.h"
 #include "daemon/network.h"
 #include "daemon/tls.h"
 #include "daemon/worker.h"
 
+#include <stdlib.h>
+
 /** Append 'addr = {port = int, udp = bool, tcp = bool}' */
 static int net_list_add(const char *key, void *val, void *ext)
 {
@@ -198,7 +201,7 @@ static int net_bufsize(lua_State *L)
        struct engine *engine = engine_luaget(L);
        knot_rrset_t *opt_rr = engine->resolver.opt_rr;
        if (!lua_isnumber(L, 1)) {
-               lua_pushnumber(L, knot_edns_get_payload(opt_rr));
+               lua_pushinteger(L, knot_edns_get_payload(opt_rr));
                return 1;
        }
        int bufsize = lua_tointeger(L, 1);
@@ -216,14 +219,14 @@ static int net_pipeline(lua_State *L)
                return 0;
        }
        if (!lua_isnumber(L, 1)) {
-               lua_pushnumber(L, worker->tcp_pipeline_max);
+               lua_pushinteger(L, worker->tcp_pipeline_max);
                return 1;
        }
        int len = lua_tointeger(L, 1);
        if (len < 0 || len > UINT16_MAX)
                lua_error_p(L, "tcp_pipeline must be within <0, " STR(UINT16_MAX) ">");
        worker->tcp_pipeline_max = len;
-       lua_pushnumber(L, len);
+       lua_pushinteger(L, len);
        return 1;
 }
 
@@ -262,203 +265,325 @@ static int net_tls(lua_State *L)
        return 1;
 }
 
-static int print_tls_param(const char *key, void *val, void *data)
+/** Return a lua table with TLS authentication parameters.
+ * The format is the same as passed to policy.TLS_FORWARD();
+ * more precisely, it's in a compatible canonical form. */
+static int tls_params2lua(lua_State *L, trie_t *params)
 {
-       if (!val) {
-               return 0;
-       }
+       lua_newtable(L);
+       if (!params) /* Allowed special case. */
+               return 1;
+       trie_it_t *it;
+       size_t list_index = 0;
+       for (it = trie_it_begin(params); !trie_it_finished(it); trie_it_next(it)) {
+               /* Prepare table for the current address
+                * and its index in the returned list. */
+               lua_pushinteger(L, ++list_index);
+               lua_createtable(L, 0, 2);
+
+               /* Get the "addr#port" string... */
+               size_t ia_len;
+               const char *key = trie_it_key(it, &ia_len);
+               int af = AF_UNSPEC;
+               if (ia_len == 2 + sizeof(struct in_addr)) {
+                       af = AF_INET;
+               } else if (ia_len == 2 + sizeof(struct in6_addr)) {
+                       af = AF_INET6;
+               }
+               if (!key || af == AF_UNSPEC) {
+                       assert(false);
+                       lua_error_p(L, "internal error: bad IP address");
+               }
+               uint16_t port;
+               memcpy(&port, key, sizeof(port));
+               port = ntohs(port);
+               const char *ia = key + sizeof(port);
+               char str[INET6_ADDRSTRLEN + 1 + 5 + 1];
+               size_t len = sizeof(str);
+               if (kr_ntop_str(af, ia, port, str, &len) != kr_ok()) {
+                       assert(false);
+                       lua_error_p(L, "internal error: bad IP address conversion");
+               }
+               /* ...and push it as [1]. */
+               lua_pushinteger(L, 1);
+               lua_pushlstring(L, str, len - 1 /* len includes '\0' */);
+               lua_settable(L, -3);
 
-       struct tls_client_paramlist_entry *entry = (struct tls_client_paramlist_entry *)val;
+               const tls_client_param_t *e = *trie_it_val(it);
+               if (!e)
+                       lua_error_p(L, "internal problem - NULL entry for %s", str);
 
-       lua_State *L = (lua_State *)data;
+               /* .hostname = */
+               if (e->hostname) {
+                       lua_pushstring(L, e->hostname);
+                       lua_setfield(L, -2, "hostname");
+               }
 
-       lua_createtable(L, 0, 3);
+               /* .ca_files = */
+               if (e->ca_files.len) {
+                       lua_createtable(L, e->ca_files.len, 0);
+                       for (size_t i = 0; i < e->ca_files.len; ++i) {
+                               lua_pushinteger(L, i + 1);
+                               lua_pushstring(L, e->ca_files.at[i]);
+                               lua_settable(L, -3);
+                       }
+                       lua_setfield(L, -2, "ca_files");
+               }
 
-       lua_createtable(L, entry->pins.len, 0);
-       for (size_t i = 0; i < entry->pins.len; ++i) {
-               lua_pushnumber(L, i + 1);
-               lua_pushstring(L, entry->pins.at[i]);
-               lua_settable(L, -3);
-       }
-       lua_setfield(L, -2, "pins");
+               /* .pin_sha256 = ... ; keep sane indentation via goto. */
+               if (!e->pins.len) goto no_pins;
+               lua_createtable(L, e->pins.len, 0);
+               for (size_t i = 0; i < e->pins.len; ++i) {
+                       uint8_t pin_base64[TLS_SHA256_BASE64_BUFLEN];
+                       int err = base64_encode(e->pins.at[i], TLS_SHA256_RAW_LEN,
+                                               pin_base64, sizeof(pin_base64));
+                       if (err < 0) {
+                               assert(false);
+                               lua_error_p(L,
+                                       "internal problem when converting pin_sha256: %s",
+                                       kr_strerror(err));
+                       }
+                       lua_pushinteger(L, i + 1);
+                       lua_pushlstring(L, (const char *)pin_base64, err);
+                               /* pin_base64 isn't 0-terminated     ^^^ */
+                       lua_settable(L, -3);
+               }
+               lua_setfield(L, -2, "pin_sha256");
 
-       lua_createtable(L, entry->ca_files.len, 0);
-       for (size_t i = 0; i < entry->ca_files.len; ++i) {
-               lua_pushnumber(L, i + 1);
-               lua_pushstring(L, entry->ca_files.at[i]);
+       no_pins:/* .insecure = */
+               if (e->insecure) {
+                       lua_pushboolean(L, true);
+                       lua_setfield(L, -2, "insecure");
+               }
+               /* Now the whole table is pushed atop the returned list. */
                lua_settable(L, -3);
        }
-       lua_setfield(L, -2, "ca_files");
-
-       if (entry->hostname) {
-               lua_pushstring(L, entry->hostname);
-               lua_setfield(L, -2, "hostname");
-       }
-
-       lua_setfield(L, -2, key);
-
-       return 0;
+       trie_it_free(it);
+       return 1;
 }
 
-static int print_tls_client_params(lua_State *L)
+static inline int cmp_sha256(const void *p1, const void *p2)
 {
-       struct engine *engine = engine_luaget(L);
-       if (!engine) {
-               return 0;
-       }
-       struct network *net = &engine->net;
-       if (!net) {
-               return 0;
-       }
-       lua_newtable(L);
-       map_walk(&net->tls_client_params, print_tls_param, (void *)L);
-       return 1;
+       return memcmp(*(char * const *)p1, *(char * const *)p2, TLS_SHA256_RAW_LEN);
 }
-
-
 static int net_tls_client(lua_State *L)
 {
-       struct engine *engine = engine_luaget(L);
-       if (!engine) {
-               return 0;
-       }
-       struct network *net = &engine->net;
-       if (!net) {
-               return 0;
-       }
-
-       /* Only return current credentials. */
-       if (lua_gettop(L) == 0) {
-               return print_tls_client_params(L);
-       }
-
-       const char *full_addr = NULL;
-       bool pin_exists = false;
-       bool hostname_exists = false;
-       if ((lua_gettop(L) == 1) && lua_isstring(L, 1)) {
-               full_addr = lua_tostring(L, 1);
-       } else if ((lua_gettop(L) == 2) && lua_isstring(L, 1) && lua_istable(L, 2)) {
-               full_addr = lua_tostring(L, 1);
-               pin_exists = true;
-       } else if ((lua_gettop(L) == 3) && lua_isstring(L, 1) && lua_istable(L, 2)) {
-               full_addr = lua_tostring(L, 1);
-               hostname_exists = true;
-       } else if ((lua_gettop(L) == 4) && lua_isstring(L, 1) &&
-                   lua_istable(L, 2) && lua_istable(L, 3)) {
-               full_addr = lua_tostring(L, 1);
-               pin_exists = true;
-               hostname_exists = true;
-       } else {
-               lua_error_p(L,
-                       "net.tls_client takes one parameter (\"address\"),"
-                       " two parameters (\"address\",\"pin\"),"
-                       " three parameters (\"address\", \"ca_file\", \"hostname\")"
-                       " or four ones: (\"address\", \"pin\", \"ca_file\", \"hostname\")");
-       }
-
-       char buf[INET6_ADDRSTRLEN + 1];
-       uint16_t port = 853;
-       const char *addr = kr_straddr_split(full_addr, buf, &port);
-       if (!addr)
-               lua_error_p(L, "invalid IP address");
-
-       if (!pin_exists && !hostname_exists) {
-               int r = tls_client_params_set(&net->tls_client_params,
-                                             addr, port, NULL,
-                                             TLS_CLIENT_PARAM_NONE);
-               lua_error_maybe(L, r);
-               lua_pushboolean(L, true);
-               return 1;
+       /* TODO idea: allow starting the lua table with *multiple* IP targets,
+        * meaning the authentication config should be applied to each.
+        */
+       struct network *net = &engine_luaget(L)->net;
+       if (lua_gettop(L) == 0)
+               return tls_params2lua(L, net->tls_client_params);
+       /* Various basic sanity-checking. */
+       if (lua_gettop(L) != 1 || !lua_istable(L, 1))
+               lua_error_maybe(L, EINVAL);
+       {
+               const char *bad_key = lua_table_checkindices(L, (const char *[])
+                       { "1", "hostname", "ca_file", "pin_sha256", "insecure", NULL });
+               if (bad_key)
+                       lua_error_p(L, "found unexpected key '%s'", bad_key);
+       }
+
+       /**** Phase 1: get the parameter into a C struct, incl. parse of CA files,
+        *       regardless of the address-pair having an entry already. */
+
+       tls_client_param_t *e = tls_client_param_new();
+       if (!e)
+               lua_error_p(L, "out of memory or something like that :-/");
+       /* Shortcut for cleanup actions needed from now on. */
+       #define ERROR(...) do { \
+               free(e); \
+               lua_error_p(L, __VA_ARGS__); \
+       } while (false)
+
+       /* .hostname - always accepted. */
+       lua_getfield(L, 1, "hostname");
+       if (!lua_isnil(L, -1)) {
+               const char *hn_str = lua_tostring(L, -1);
+               /* Convert to lower-case dname and back, for checking etc. */
+               knot_dname_t dname[KNOT_DNAME_MAXLEN];
+               if (!hn_str || !knot_dname_from_str(dname, hn_str, sizeof(dname)))
+                       ERROR("invalid hostname");
+               knot_dname_to_lower(dname);
+               char *h = knot_dname_to_str_alloc(dname);
+               if (!h)
+                       ERROR("%s", kr_strerror(ENOMEM));
+               /* Strip the final dot produced by knot_dname_*() */
+               h[strlen(h) - 1] = '\0';
+               e->hostname = h;
        }
+       lua_pop(L, 1);
 
-       if (pin_exists) {
-               /* iterate over table with pins
+       /* .ca_file - it can be a list of paths, contrary to the name. */
+       bool has_ca_file = false;
+       lua_getfield(L, 1, "ca_file");
+       if (!lua_isnil(L, -1)) {
+               if (!e->hostname)
+                       ERROR("missing hostname but specifying ca_file");
+               lua_listify(L);
+               array_init(e->ca_files); /*< placate apparently confused scan-build */
+               if (array_reserve(e->ca_files, lua_objlen(L, -1)) != 0) /*< optim. */
+                       ERROR("%s", kr_strerror(ENOMEM));
+               /* Iterate over table at the top of the stack.
                 * http://www.lua.org/manual/5.1/manual.html#lua_next */
-               lua_pushnil(L); /* first key */
-               while (lua_next(L, 2)) {  /* pin table is in stack at index 2 */
-                       /* pin now at index -1, key at index -2*/
-                       const char *pin = lua_tostring(L, -1);
-                       int r = tls_client_params_set(&net->tls_client_params,
-                                                     addr, port, pin,
-                                                     TLS_CLIENT_PARAM_PIN);
-                       lua_error_maybe(L, r);
-                       lua_pop(L, 1);
+               for (lua_pushnil(L); lua_next(L, -2); lua_pop(L, 1)) {
+                       has_ca_file = true; /* deferred here so that {} -> false */
+                       const char *ca_file = lua_tostring(L, -1);
+                       if (!ca_file)
+                               ERROR("ca_file contains a non-string");
+                       /* Let gnutls process it immediately, so garbage gets detected. */
+                       int ret = gnutls_certificate_set_x509_trust_file(
+                                       e->credentials, ca_file, GNUTLS_X509_FMT_PEM);
+                       if (ret < 0) {
+                               ERROR("failed to import certificate file '%s': %s - %s\n",
+                                       ca_file, gnutls_strerror_name(ret),
+                                       gnutls_strerror(ret));
+                       } else {
+                               kr_log_verbose(
+                                       "[tls_client] imported %d certs from file '%s'\n",
+                                       ret, ca_file);
+                       }
+
+                       ca_file = strdup(ca_file);
+                       if (!ca_file || array_push(e->ca_files, ca_file) < 0)
+                               ERROR("%s", kr_strerror(ENOMEM));
+               }
+               /* Sort the strings for easier comparison later. */
+               if (e->ca_files.len) {
+                       qsort(&e->ca_files.at[0], e->ca_files.len,
+                               sizeof(e->ca_files.at[0]), strcmp_p);
                }
        }
+       lua_pop(L, 1);
 
-       int ca_table_index = 2;
-       int hostname_table_index = 3;
-       if (hostname_exists) {
-               if (pin_exists) {
-                       ca_table_index = 3;
-                       hostname_table_index = 4;
+       /* .pin_sha256 */
+       lua_getfield(L, 1, "pin_sha256");
+       if (!lua_isnil(L, -1)) {
+               if (has_ca_file)
+                       ERROR("mixing pin_sha256 with ca_file is not supported");
+               lua_listify(L);
+               array_init(e->pins); /*< placate apparently confused scan-build */
+               if (array_reserve(e->pins, lua_objlen(L, -1)) != 0) /*< optim. */
+                       ERROR("%s", kr_strerror(ENOMEM));
+               /* Iterate over table at the top of the stack. */
+               for (lua_pushnil(L); lua_next(L, -2); lua_pop(L, 1)) {
+                       const char *pin = lua_tostring(L, -1);
+                       if (!pin)
+                               ERROR("pin_sha256 is not a string");
+                       uint8_t *pin_raw = malloc(TLS_SHA256_RAW_LEN);
+                       /* Push the string early to simplify error processing. */
+                       if (!pin_raw || array_push(e->pins, pin_raw) < 0) {
+                               assert(false);
+                               free(pin_raw);
+                               ERROR("%s", kr_strerror(ENOMEM));
+                       }
+                       int ret = base64_decode((const uint8_t *)pin, strlen(pin),
+                                               pin_raw, TLS_SHA256_RAW_LEN + 8);
+                       if (ret < 0) {
+                               ERROR("not a valid pin_sha256: '%s' (length %d), %s\n",
+                                       pin, (int)strlen(pin), knot_strerror(ret));
+                       } else if (ret != TLS_SHA256_RAW_LEN) {
+                               ERROR("not a valid pin_sha256: '%s', "
+                                               "raw length %d instead of "
+                                               STR(TLS_SHA256_RAW_LEN)"\n",
+                                       pin, ret);
+                       }
+               }
+               /* Sort the raw strings for easier comparison later. */
+               if (e->pins.len) {
+                       qsort(&e->pins.at[0], e->pins.len,
+                               sizeof(e->pins.at[0]), cmp_sha256);
                }
-       } else {
-               lua_pushboolean(L, true);
-               return 1;
        }
+       lua_pop(L, 1);
 
-       /* iterate over hostnames,
-        * it must be done before iterating over ca filenames */
-       lua_pushnil(L);
-       while (lua_next(L, hostname_table_index)) {
-               const char *hostname = lua_tostring(L, -1);
-               int r = tls_client_params_set(&net->tls_client_params,
-                                             addr, port, hostname,
-                                             TLS_CLIENT_PARAM_HOSTNAME);
-               lua_error_maybe(L, r);
-               /* removes 'value'; keeps 'key' for next iteration */
-               lua_pop(L, 1);
+       /* .insecure */
+       lua_getfield(L, 1, "insecure");
+       if (lua_isnil(L, -1)) {
+               if (!e->hostname && !e->pins.len)
+                       ERROR("no way to authenticate and not set as insecure");
+       } else if (lua_isboolean(L, -1) && lua_toboolean(L, -1)) {
+               e->insecure = true;
+               if (has_ca_file || e->pins.len)
+                       ERROR("set as insecure but provided authentication config");
+       } else {
+               ERROR("incorrect value in the 'insecure' field");
        }
+       lua_pop(L, 1);
 
-       /* iterate over ca filenames */
-       lua_pushnil(L);
-       size_t num_of_ca_files = 0;
-       while (lua_next(L, ca_table_index)) {
-               const char *ca_file = lua_tostring(L, -1);
-               int r = tls_client_params_set(&net->tls_client_params,
-                                             addr, port, ca_file,
-                                             TLS_CLIENT_PARAM_CA);
-               lua_error_maybe(L, r);
-               num_of_ca_files += 1;
-               /* removes 'value'; keeps 'key' for next iteration */
-               lua_pop(L, 1);
+       /* Init CAs from system trust store, if needed. */
+       if (!e->insecure && !e->pins.len && !has_ca_file) {
+               int ret = gnutls_certificate_set_x509_system_trust(e->credentials);
+               if (ret <= 0) {
+                       ERROR("failed to use system CA certificate store: %s",
+                               ret ? gnutls_strerror(ret) : kr_strerror(ENOENT));
+               } else {
+                       kr_log_verbose(
+                               "[tls_client] imported %d certs from system store\n",
+                               ret);
+               }
        }
+       #undef ERROR
 
-       if (num_of_ca_files == 0) {
-               /* No ca files were explicitly configured, so use system CA */
-               int r = tls_client_params_set(&net->tls_client_params,
-                                             addr, port, NULL,
-                                             TLS_CLIENT_PARAM_CA);
-               lua_error_maybe(L, r);
-       }
+       /**** Phase 2: deal with the C authentication "table". */
+       /* Parse address and port. */
+       lua_pushinteger(L, 1);
+       lua_gettable(L, 1);
+       const char *addr_str = lua_tostring(L, -1);
+       if (!addr_str)
+               lua_error_p(L, "address is not a string");
+       char buf[INET6_ADDRSTRLEN + 1];
+       uint16_t port = 853;
+       addr_str = kr_straddr_split(addr_str, buf, &port);
+       /* Add e into the C map, saving the original into e0. */
+       const struct sockaddr *addr = kr_straddr_socket(addr_str, port);
+       if (!addr)
+               lua_error_p(L, "address '%s' could not be converted", addr_str);
+       tls_client_param_t **e0p = tls_client_param_getptr(
+                       &net->tls_client_params, addr, true);
+       free_const(addr);
+       if (!e0p)
+               lua_error_p(L, "internal error when extending tls_client_params map");
+       tls_client_param_t *e0 = *e0p;
+       *e0p = e;
+       /* If there was no original entry, it's easy! */
+       if (!e0)
+               return 0;
 
-       lua_pushboolean(L, true);
-       return 1;
+       /* Check for equality (e vs. e0), and print a warning if not equal.*/
+       const bool ok_h = (!e->hostname && !e0->hostname)
+               || (e->hostname && e0->hostname && strcmp(e->hostname, e0->hostname) == 0);
+       bool ok_ca = e->ca_files.len == e0->ca_files.len;
+       for (int i = 0; ok_ca && i < e->ca_files.len; ++i)
+               ok_ca = strcmp(e->ca_files.at[i], e0->ca_files.at[i]) == 0;
+       bool ok_pins = e->pins.len == e0->pins.len;
+       for (int i = 0; ok_pins && i < e->pins.len; ++i)
+               ok_ca = memcmp(e->pins.at[i], e0->pins.at[i], TLS_SHA256_RAW_LEN) == 0;
+       if (!(ok_h && ok_ca && ok_pins && e->insecure == e0->insecure)) {
+               kr_log_info("[tls_client] "
+                       "warning: re-defining TLS authentication parameters for %s\n",
+                       addr_str);
+       }
+       tls_client_param_unref(e0);
+       return 0;
 }
 
-static int net_tls_client_clear(lua_State *L)
+int net_tls_client_clear(lua_State *L)
 {
-       struct engine *engine = engine_luaget(L);
-       if (!engine)
-               return 0;
-
-       struct network *net = &engine->net;
-       if (!net)
-               return 0;
-
+       /* One parameter: address -> convert it to a struct sockaddr. */
        if (lua_gettop(L) != 1 || !lua_isstring(L, 1))
                lua_error_p(L, "net.tls_client_clear() requires one parameter (\"address\")");
-
-       const char *full_addr = lua_tostring(L, 1);
-
+       const char *addr_str = lua_tostring(L, 1);
        char buf[INET6_ADDRSTRLEN + 1];
        uint16_t port = 853;
-       const char *addr = kr_straddr_split(full_addr, buf, &port);
+       addr_str = kr_straddr_split(addr_str, buf, &port);
+       const struct sockaddr *addr = kr_straddr_socket(addr_str, port);
        if (!addr)
                lua_error_p(L, "invalid IP address");
-
-       int r = tls_client_params_clear(&net->tls_client_params, addr, port);
+       /* Do the actual removal. */
+       struct network *net = &engine_luaget(L)->net;
+       int r = tls_client_param_remove(net->tls_client_params, addr);
+       free_const(addr);
        lua_error_maybe(L, r);
        lua_pushboolean(L, true);
        return 1;
@@ -644,7 +769,7 @@ static int net_update_timeout(lua_State *L, uint64_t *timeout, const char *name)
 {
        /* Only return current idle timeout. */
        if (lua_gettop(L) == 0) {
-               lua_pushnumber(L, *timeout);
+               lua_pushinteger(L, *timeout);
                return 1;
        }
 
index 22273f034c6b0cf862cab38122077754439e3528..ed94f7d47de90af1e2949680a189589415dc0f40 100644 (file)
@@ -51,7 +51,7 @@ void network_init(struct network *net, uv_loop_t *loop, int tcp_backlog)
        if (net != NULL) {
                net->loop = loop;
                net->endpoints = map_make(NULL);
-               net->tls_client_params = map_make(NULL);
+               net->tls_client_params = NULL;
                net->tls_session_ticket_ctx = /* unsync. random, by default */
                tls_session_ticket_ctx_create(loop, NULL, 0);
                net->tcp.in_idle_timeout = 10000;
@@ -112,10 +112,11 @@ void network_deinit(struct network *net)
                map_walk(&net->endpoints, free_key, 0);
                map_clear(&net->endpoints);
                tls_credentials_free(net->tls_credentials);
-               tls_client_params_free(&net->tls_client_params);
-               net->tls_credentials = NULL;
+               tls_client_params_free(net->tls_client_params);
                tls_session_ticket_ctx_destroy(net->tls_session_ticket_ctx);
-               net->tcp.in_idle_timeout = 0;
+               #ifndef NDEBUG
+                       memset(net, 0, sizeof(*net));
+               #endif
        }
 }
 
index ffaaab9ea6dc46be200ed158463e0e057aba4311..1e80d09a5a3a05bcbcd98e9eba08193cc821248b 100644 (file)
 
 #pragma once
 
-#include <uv.h>
-#include <stdbool.h>
+#include "daemon/tls.h"
 
 #include "lib/generic/array.h"
 #include "lib/generic/map.h"
 
+#include <uv.h>
+#include <stdbool.h>
+
+
 struct engine;
 
 enum endpoint_flag {
@@ -47,12 +50,11 @@ struct net_tcp_param {
        uint64_t tls_handshake_timeout;
 };
 
-struct tls_session_ticket_ctx;
 struct network {
        uv_loop_t *loop;
        map_t endpoints;
        struct tls_credentials *tls_credentials;
-       map_t tls_client_params; /**< Use tls_client_params_*() functions. */
+       tls_client_params_t *tls_client_params; /**< Use tls_client_params_*() functions. */
        struct tls_session_ticket_ctx *tls_session_ticket_ctx;
        struct net_tcp_param tcp;
        int tcp_backlog;
index 6b66c50cfceb8d0c9af7ffea50e0ceee73618c01..75fcd6d47a94c00a4dc1a7e5208816de35ad9a78 100644 (file)
@@ -537,8 +537,7 @@ ssize_t tls_process_input_data(struct session *s, const uint8_t *buf, ssize_t nr
        return submitted;
 }
 
-#if GNUTLS_VERSION_NUMBER >= GNUTLS_PIN_MIN_VERSION
-
+#if TLS_CAN_USE_PINS
 /*
   DNS-over-TLS Out of band key-pinned authentication profile uses the
   same form of pins as HPKP:
@@ -550,29 +549,45 @@ ssize_t tls_process_input_data(struct session *s, const uint8_t *buf, ssize_t nr
 */
 #define PINLEN  ((((32) * 8 + 4)/6) + 3 + 1)
 
-/* out must be at least PINLEN octets long */
-static int get_oob_key_pin(gnutls_x509_crt_t crt, char *outchar, ssize_t outchar_len)
+/* Compute pin_sha256 for the certificate.
+ * It may be in raw format - just TLS_SHA256_RAW_LEN bytes without termination,
+ * or it may be a base64 0-terminated string requiring up to
+ * TLS_SHA256_BASE64_BUFLEN bytes.
+ * \return error code */
+static int get_oob_key_pin(gnutls_x509_crt_t crt, char *outchar, ssize_t outchar_len, bool raw)
 {
-       int err;
+       if (raw && outchar_len < TLS_SHA256_RAW_LEN) {
+               assert(false);
+               return kr_error(ENOSPC);
+               /* With !raw we have check inside base64_encode. */
+       }
        gnutls_pubkey_t key;
-       gnutls_datum_t datum = { .size = 0 };
+       int err = gnutls_pubkey_init(&key);
+       if (err != GNUTLS_E_SUCCESS) return err;
 
-       if ((err = gnutls_pubkey_init(&key)) != GNUTLS_E_SUCCESS) {
-               return err;
-       }
+       gnutls_datum_t datum = { .data = NULL, .size = 0 };
+       err = gnutls_pubkey_import_x509(key, crt, 0);
+       if (err != GNUTLS_E_SUCCESS) goto leave;
 
-       if ((err = gnutls_pubkey_import_x509(key, crt, 0)) != GNUTLS_E_SUCCESS) {
-               goto leave;
-       } else {
-               if ((err = gnutls_pubkey_export2(key, GNUTLS_X509_FMT_DER, &datum)) != GNUTLS_E_SUCCESS) {
+       err = gnutls_pubkey_export2(key, GNUTLS_X509_FMT_DER, &datum);
+       if (err != GNUTLS_E_SUCCESS) goto leave;
+
+       {
+               char raw_pin[TLS_SHA256_RAW_LEN]; /* TMP buffer if raw == false */
+               err = gnutls_hash_fast(GNUTLS_DIG_SHA256, datum.data, datum.size,
+                                       (raw ? outchar : raw_pin));
+               if (err != GNUTLS_E_SUCCESS || raw/*success*/)
                        goto leave;
-               } else {
-                       uint8_t raw_pin[32];
-                       if ((err = gnutls_hash_fast(GNUTLS_DIG_SHA256, datum.data, datum.size, raw_pin)) != GNUTLS_E_SUCCESS) {
-                               goto leave;
-                       } else {
-                               base64_encode(raw_pin, sizeof(raw_pin), (uint8_t *)outchar, outchar_len);
-                       }
+               /* Convert to non-raw. */
+               err = base64_encode((uint8_t *)raw_pin, sizeof(raw_pin),
+                                   (uint8_t *)outchar, outchar_len);
+               if (err >= 0 && err < outchar_len) {
+                       err = GNUTLS_E_SUCCESS;
+                       outchar[err] = '\0'; /* base64_decode() doesn't do it */
+               } else if (err >= 0) {
+                       assert(false);
+                       err = kr_error(ENOSPC); /* base64 fits but '\0' doesn't */
+                       outchar[outchar_len - 1] = '\0';
                }
        }
 leave:
@@ -584,23 +599,27 @@ leave:
 void tls_credentials_log_pins(struct tls_credentials *tls_credentials)
 {
        for (int index = 0;; index++) {
-               int err;
                gnutls_x509_crt_t *certs = NULL;
                unsigned int cert_count = 0;
-
-               if ((err = gnutls_certificate_get_x509_crt(tls_credentials->credentials, index, &certs, &cert_count)) != GNUTLS_E_SUCCESS) {
+               int err = gnutls_certificate_get_x509_crt(tls_credentials->credentials,
+                                                       index, &certs, &cert_count);
+               if (err != GNUTLS_E_SUCCESS) {
                        if (err != GNUTLS_E_REQUESTED_DATA_NOT_AVAILABLE) {
-                               kr_log_error("[tls] could not get X.509 certificates (%d) %s\n", err, gnutls_strerror_name(err));
+                               kr_log_error("[tls] could not get X.509 certificates (%d) %s\n",
+                                               err, gnutls_strerror_name(err));
                        }
                        return;
                }
 
                for (int i = 0; i < cert_count; i++) {
-                       char pin[PINLEN] = { 0 };
-                       if ((err = get_oob_key_pin(certs[i], pin, sizeof(pin))) != GNUTLS_E_SUCCESS) {
-                               kr_log_error("[tls] could not calculate RFC 7858 OOB key-pin from cert %d (%d) %s\n", i, err, gnutls_strerror_name(err));
+                       char pin[TLS_SHA256_BASE64_BUFLEN] = { 0 };
+                       err = get_oob_key_pin(certs[i], pin, sizeof(pin), false);
+                       if (err != GNUTLS_E_SUCCESS) {
+                               kr_log_error("[tls] could not calculate RFC 7858 OOB key-pin from cert %d (%d) %s\n",
+                                               i, err, gnutls_strerror_name(err));
                        } else {
-                               kr_log_info("[tls] RFC 7858 OOB key-pin (%d): pin-sha256=\"%s\"\n", i, pin);
+                               kr_log_info("[tls] RFC 7858 OOB key-pin (%d): pin-sha256=\"%s\"\n",
+                                               i, pin);
                        }
                        gnutls_x509_crt_deinit(certs[i]);
                }
@@ -757,8 +776,13 @@ void tls_credentials_free(struct tls_credentials *tls_credentials) {
        free(tls_credentials);
 }
 
-static int client_paramlist_entry_free(struct tls_client_paramlist_entry *entry)
+void tls_client_param_unref(tls_client_param_t *entry)
 {
+       if (!entry) return;
+       assert(entry->refs); /* Well, we'd only leak memory. */
+       --(entry->refs);
+       if (entry->refs) return;
+
        DEBUG_MSG("freeing TLS parameters %p\n", (void *)entry);
 
        for (int i = 0; i < entry->ca_files.len; ++i) {
@@ -782,231 +806,96 @@ static int client_paramlist_entry_free(struct tls_client_paramlist_entry *entry)
        }
 
        free(entry);
-
-       return 0;
 }
-
-static void client_paramlist_entry_ref(struct tls_client_paramlist_entry *entry)
-{
-       if (entry != NULL) {
-               entry->refs += 1;
-       }
-}
-
-static void client_paramlist_entry_unref(struct tls_client_paramlist_entry *entry)
+static int param_free(void **param, void *null)
 {
-       if (entry != NULL) {
-               assert(entry->refs > 0);
-               entry->refs -= 1;
-
-               /* Last reference frees the object */
-               if (entry->refs == 0) {
-                       client_paramlist_entry_free(entry);
-               }
-       }
+       assert(param && *param);
+       tls_client_param_unref(*param);
+       return 0;
 }
-
-static int client_paramlist_entry_clear(const char *k, void *v, void *baton)
+void tls_client_params_free(tls_client_params_t *params)
 {
-       struct tls_client_paramlist_entry *entry = (struct tls_client_paramlist_entry *)v;
-       return client_paramlist_entry_free(entry);
+       if (!params) return;
+       trie_apply(params, param_free, NULL);
+       trie_free(params);
 }
 
-struct tls_client_paramlist_entry *tls_client_try_upgrade(map_t *tls_client_paramlist,
-                         const struct sockaddr *addr)
+tls_client_param_t * tls_client_param_new()
 {
-       /* Opportunistic upgrade from port 53 -> 853 */
-       if (kr_inaddr_port(addr) != KR_DNS_PORT) {
+       tls_client_param_t *e = calloc(1, sizeof(*e));
+       if (!e) {
+               assert(!ENOMEM);
                return NULL;
        }
-
-       static char key[INET6_ADDRSTRLEN + 6];
-       size_t keylen = sizeof(key);
-       if (kr_inaddr_str(addr, key, &keylen) != 0) {
+       /* Note: those array_t don't need further initialization. */
+       e->refs = 1;
+       int ret = gnutls_certificate_allocate_credentials(&e->credentials);
+       if (ret != GNUTLS_E_SUCCESS) {
+               kr_log_error("[tls_client] error: gnutls_certificate_allocate_credentials() fails (%s)\n",
+                            gnutls_strerror_name(ret));
+               free(e);
                return NULL;
        }
-
-       /* Rewrite 053 -> 853 */
-       memcpy(key + keylen - 4, "853", 3);
-
-       return map_get(tls_client_paramlist, key);
+       gnutls_certificate_set_verify_function(e->credentials, client_verify_certificate);
+       return e;
 }
 
-int tls_client_params_clear(map_t *tls_client_paramlist, const char *addr, uint16_t port)
+static bool construct_key(const union inaddr *addr, uint32_t *len, char *key)
 {
-       if (!tls_client_paramlist || !addr) {
-               return kr_error(EINVAL);
+       switch (addr->ip.sa_family) {
+       case AF_INET:
+               memcpy(key, &addr->ip4.sin_port, sizeof(addr->ip4.sin_port));
+               memcpy(key + sizeof(addr->ip4.sin_port), &addr->ip4.sin_addr,
+                       sizeof(addr->ip4.sin_addr));
+               *len = sizeof(addr->ip4.sin_port) + sizeof(addr->ip4.sin_addr);
+               return true;
+       case AF_INET6:
+               memcpy(key, &addr->ip6.sin6_port, sizeof(addr->ip6.sin6_port));
+               memcpy(key + sizeof(addr->ip6.sin6_port), &addr->ip6.sin6_addr,
+                       sizeof(addr->ip6.sin6_addr));
+               *len = sizeof(addr->ip6.sin6_port) + sizeof(addr->ip6.sin6_addr);
+               return true;
+       default:
+               assert(!EINVAL);
+               return false;
        }
-
-       /* Parameters are OK */
-
-       char key[INET6_ADDRSTRLEN + 6];
-       size_t keylen = sizeof(key);
-       if (kr_straddr_join(addr, port, key, &keylen) != kr_ok()) {
-               return kr_error(EINVAL);
-       }
-
-       struct tls_client_paramlist_entry *entry = map_get(tls_client_paramlist, key);
-       if (entry != NULL) {
-               client_paramlist_entry_clear(NULL, (void *)entry, NULL);
-               map_del(tls_client_paramlist, key);
-       }
-
-       return kr_ok();
 }
-
-int tls_client_params_set(map_t *tls_client_paramlist,
-                         const char *addr, uint16_t port,
-                         const char *param, tls_client_param_t param_type)
+tls_client_param_t ** tls_client_param_getptr(tls_client_params_t **params,
+                               const struct sockaddr *addr, bool do_insert)
 {
-       if (!tls_client_paramlist || !addr) {
-               return kr_error(EINVAL);
-       }
-
-       /* TLS_CLIENT_PARAM_CA can be empty */
-       if (param_type == TLS_CLIENT_PARAM_HOSTNAME ||
-           param_type == TLS_CLIENT_PARAM_PIN) {
-               if (param == NULL || param[0] == 0) {
-                       return kr_error(EINVAL);
-               }
-       }
-
-       /* Parameters are OK */
-
-       char key[INET6_ADDRSTRLEN + 6];
-       size_t keylen = sizeof(key);
-       if (kr_straddr_join(addr, port, key, &keylen) != kr_ok()) {
-               kr_log_error("[tls_client] warning: '%s' is not a valid ip address\n", addr);
-               return kr_error(EINVAL);
-       }
-
-       bool is_first_entry = false;
-       struct tls_client_paramlist_entry *entry = map_get(tls_client_paramlist, key);
-       if (entry == NULL) {
-               entry = calloc(1, sizeof(struct tls_client_paramlist_entry));
-               if (entry == NULL) {
-                       return kr_error(ENOMEM);
-               }
-               is_first_entry  = true;
-               int ret = gnutls_certificate_allocate_credentials(&entry->credentials);
-               if (ret != GNUTLS_E_SUCCESS) {
-                       free(entry);
-                       kr_log_error("[tls_client] error: gnutls_certificate_allocate_credentials() fails (%s)\n",
-                                    gnutls_strerror_name(ret));
-                       return kr_error(ENOMEM);
-               }
-               gnutls_certificate_set_verify_function(entry->credentials, client_verify_certificate);
-               client_paramlist_entry_ref(entry);
-       }
-
-       int ret = kr_ok();
-
-       if (param_type == TLS_CLIENT_PARAM_HOSTNAME) {
-               if (entry->hostname && strcasecmp(entry->hostname, param)) {
-                       kr_log_error("[tls_client] error: hostname collision for address"
-                                       " '%s': '%s' '%s'\n",
-                                       key, entry->hostname, param);
-                       return kr_error(EINVAL);
-               }
-               if (!entry->hostname) {
-                       entry->hostname = strdup(param);
-                       if (!entry->hostname) {
-                               return kr_error(ENOMEM);
-                       }
-               }
-       } else if (param_type == TLS_CLIENT_PARAM_CA) {
-               /* Import ca files only when hostname is already set */
-               if (!entry->hostname) {
-                       return kr_error(ENOENT);
-               }
-               const char *ca_file = param;
-               bool already_exists = false;
-               for (size_t i = 0; i < entry->ca_files.len; ++i) {
-                       const char *imported_ca = entry->ca_files.at[i];
-                       if (imported_ca[0] == 0 && (ca_file == NULL || ca_file[0] == 0)) {
-                               kr_log_error("[tls_client] error: system ca for address '%s' already was set, ignoring\n", key);
-                               already_exists = true;
-                               break;
-                       } else if (strcmp(imported_ca, ca_file) == 0) {
-                               kr_log_error("[tls_client] error: ca file '%s' for address '%s' already was set, ignoring\n", ca_file, key);
-                               already_exists = true;
-                               break;
-                       }
-               }
-               if (!already_exists) {
-                       const char *value = strdup(ca_file != NULL ? ca_file : "");
-                       if (!value) {
-                               ret = kr_error(ENOMEM);
-                       } else if (array_push(entry->ca_files, value) < 0) {
-                               free ((void *)value);
-                               ret = kr_error(ENOMEM);
-                       } else if (value[0] == 0) {
-                               int res = gnutls_certificate_set_x509_system_trust (entry->credentials);
-                               if (res <= 0) {
-                                       kr_log_error("[tls_client] failed to import certs from system store (%s)\n",
-                                                    gnutls_strerror_name(res));
-                                       /* value will be freed at cleanup */
-                                       ret = kr_error(EINVAL);
-                               } else {
-                                       kr_log_verbose("[tls_client] imported %d certs from system store\n", res);
-                               }
-                       } else {
-                               int res = gnutls_certificate_set_x509_trust_file(entry->credentials, value,
-                                                                                GNUTLS_X509_FMT_PEM);
-                               if (res <= 0) {
-                                       kr_log_error("[tls_client] failed to import certificate file '%s' (%s)\n",
-                                                    value, gnutls_strerror_name(res));
-                                       /* value will be freed at cleanup */
-                                       ret = kr_error(EINVAL);
-                               } else {
-                                       kr_log_verbose("[tls_client] imported %d certs from file '%s'\n",
-                                                       res, value);
-
-                               }
-                       }
-               }
-       } else if (param_type == TLS_CLIENT_PARAM_PIN) {
-               const char *pin = param;
-               for (size_t i = 0; i < entry->pins.len; ++i) {
-                       if (strcmp(entry->pins.at[i], pin) == 0) {
-                               kr_log_error("[tls_client] warning: pin '%s' for address '%s' already was set, ignoring\n", pin, key);
-                               return kr_ok();
-                       }
-               }
-               const void *value = strdup(pin);
-               if (!value) {
-                       ret = kr_error(ENOMEM);
-               } else if (array_push(entry->pins, value) < 0) {
-                       free ((void *)value);
-                       ret = kr_error(ENOMEM);
-               }
-       } else {
-               assert(param_type == TLS_CLIENT_PARAM_NONE);
-       }
-
-       if ((ret == kr_ok()) && is_first_entry) {
-               bool fail = (map_set(tls_client_paramlist, key, entry) != 0);
-               if (fail) {
-                       ret = kr_error(ENOMEM);
+       assert(params && addr);
+       /* We accept NULL for empty map; ensure the map exists if needed. */
+       if (!*params) {
+               if (!do_insert) return NULL;
+               *params = trie_create(NULL);
+               if (!*params) {
+                       assert(!ENOMEM);
+                       return NULL;
                }
        }
-
-       if ((ret != kr_ok()) && is_first_entry) {
-               client_paramlist_entry_unref(entry);
-       }
-
-       return ret;
+       /* Construct the key. */
+       const union inaddr *ia = (const union inaddr *)addr;
+       char key[sizeof(ia->ip6.sin6_port) + sizeof(ia->ip6.sin6_addr)];
+       uint32_t len;
+       if (!construct_key(ia, &len, key))
+               return NULL;
+       /* Get the entry. */
+       return (tls_client_param_t **)
+               (do_insert ? trie_get_ins : trie_get_try)(*params, key, len);
 }
 
-int tls_client_params_free(map_t *tls_client_paramlist)
+int tls_client_param_remove(tls_client_params_t *params, const struct sockaddr *addr)
 {
-       if (!tls_client_paramlist) {
+       const union inaddr *ia = (const union inaddr *)addr;
+       char key[sizeof(ia->ip6.sin6_port) + sizeof(ia->ip6.sin6_addr)];
+       uint32_t len;
+       if (!construct_key(ia, &len, key))
                return kr_error(EINVAL);
-       }
-
-       map_walk(tls_client_paramlist, client_paramlist_entry_clear, NULL);
-       map_clear(tls_client_paramlist);
-
+       trie_val_t param_ptr;
+       int ret = trie_del(params, key, len, &param_ptr);
+       if (ret)
+               return kr_error(ret);
+       tls_client_param_unref(param_ptr);
        return kr_ok();
 }
 
@@ -1015,7 +904,7 @@ static int client_verify_certificate(gnutls_session_t tls_session)
        struct tls_client_ctx_t *ctx = gnutls_session_get_ptr(tls_session);
        assert(ctx->params != NULL);
 
-       if (ctx->params->pins.len == 0 && ctx->params->ca_files.len == 0) {
+       if (ctx->params->insecure) {
                return GNUTLS_E_SUCCESS;
        }
 
@@ -1033,9 +922,9 @@ static int client_verify_certificate(gnutls_session_t tls_session)
                return GNUTLS_E_CERTIFICATE_ERROR;
        }
 
-#if GNUTLS_VERSION_NUMBER >= GNUTLS_PIN_MIN_VERSION
+#if TLS_CAN_USE_PINS
        if (ctx->params->pins.len == 0) {
-               DEBUG_MSG("[tls_client] skipping certificate PIN check\n");
+               DEBUG_MSG("[tls_client] configured to authenticate via CA\n");
                goto skip_pins;
        }
 
@@ -1052,32 +941,46 @@ static int client_verify_certificate(gnutls_session_t tls_session)
                        return ret;
                }
 
-               char cert_pin[PINLEN] = { 0 };
-               ret = get_oob_key_pin(cert, cert_pin, sizeof(cert_pin));
-
+       #ifdef DEBUG
+               if (VERBOSE_STATUS) {
+                       char pin_base64[TLS_SHA256_BASE64_BUFLEN];
+                       /* DEBUG: additionally compute and print the base64 pin.
+                        * Not very efficient, but that's OK for DEBUG. */
+                       ret = get_oob_key_pin(cert, pin_base64, sizeof(pin_base64), false);
+                       if (ret == GNUTLS_E_SUCCESS) {
+                               DEBUG_MSG("[tls_client] received pin: %s\n", pin_base64);
+                       } else {
+                               DEBUG_MSG("[tls_client] failed to convert received pin\n");
+                               /* Now we hope that `ret` below can't differ. */
+                       }
+               }
+       #endif
+               char cert_pin[TLS_SHA256_RAW_LEN];
+               /* Get raw pin and compare. */
+               ret = get_oob_key_pin(cert, cert_pin, sizeof(cert_pin), true);
                gnutls_x509_crt_deinit(cert);
-
                if (ret != GNUTLS_E_SUCCESS) {
                        return ret;
                }
-
-               DEBUG_MSG("[tls_client] received pin  : %s\n", cert_pin);
                for (size_t j = 0; j < ctx->params->pins.len; ++j) {
-                       const char *pin = ctx->params->pins.at[j];
-                       bool match = (strcmp(cert_pin, pin) == 0);
-                       DEBUG_MSG("[tls_client] configured pin: %s matches? %s\n",
-                                 pin, match ? "yes" : "no");
-                       if (match) {
-                               return GNUTLS_E_SUCCESS;
-                       }
+                       const uint8_t *pin = ctx->params->pins.at[j];
+                       if (memcmp(cert_pin, pin, TLS_SHA256_RAW_LEN) != 0)
+                               continue; /* mismatch */
+                       DEBUG_MSG("[tls_client] matched a configured pin no. %zd\n", j);
+                       return GNUTLS_E_SUCCESS;
                }
+               DEBUG_MSG("[tls_client] none of %zd configured pin(s) matched\n",
+                               ctx->params->pins.len);
        }
 
-       /* pins were set, but no one was not matched */
-       kr_log_error("[tls_client] certificate PIN check failed\n");
-#else
+       kr_log_error("[tls_client] no pin matched: %d pins * %d certificates\n",
+                       (int)ctx->params->pins.len, cert_list_size);
+       return GNUTLS_E_CERTIFICATE_ERROR;
+
+#else /* TLS_CAN_USE_PINS */
        if (ctx->params->pins.len != 0) {
-               kr_log_error("[tls_client] newer gnutls is required to use PIN check\n");
+               kr_log_error("[tls_client] internal inconsistency: TLS_CAN_USE_PINS\n");
+               assert(false);
                return GNUTLS_E_CERTIFICATE_ERROR;
        }
        goto skip_pins;
@@ -1085,13 +988,9 @@ static int client_verify_certificate(gnutls_session_t tls_session)
 
 skip_pins:
 
-       if (ctx->params->ca_files.len == 0) {
-               DEBUG_MSG("[tls_client] empty CA files list\n");
-               return GNUTLS_E_CERTIFICATE_ERROR;
-       }
-
        if (!ctx->params->hostname) {
-               DEBUG_MSG("[tls_client] no hostname set\n");
+               kr_log_error("[tls_client] internal config inconsistency: no hostname set\n");
+               assert(false);
                return GNUTLS_E_CERTIFICATE_ERROR;
        }
 
@@ -1123,7 +1022,7 @@ skip_pins:
        return GNUTLS_E_CERTIFICATE_ERROR;
 }
 
-struct tls_client_ctx_t *tls_client_ctx_new(struct tls_client_paramlist_entry *entry,
+struct tls_client_ctx_t *tls_client_ctx_new(tls_client_param_t *entry,
                                            struct worker_ctx *worker)
 {
        struct tls_client_ctx_t *ctx = calloc(1, sizeof (struct tls_client_ctx_t));
@@ -1149,7 +1048,7 @@ struct tls_client_ctx_t *tls_client_ctx_new(struct tls_client_paramlist_entry *e
 
        /* Must take a reference on parameters as the credentials are owned by it
         * and must not be freed while the session is active. */
-       client_paramlist_entry_ref(entry);
+       ++(entry->refs);
        ctx->params = entry;
 
        ret = gnutls_credentials_set(ctx->c.tls_session, GNUTLS_CRD_CERTIFICATE,
@@ -1157,6 +1056,9 @@ struct tls_client_ctx_t *tls_client_ctx_new(struct tls_client_paramlist_entry *e
        if (ret == GNUTLS_E_SUCCESS && entry->hostname) {
                ret = gnutls_server_name_set(ctx->c.tls_session, GNUTLS_NAME_DNS,
                                        entry->hostname, strlen(entry->hostname));
+               kr_log_verbose("[tls_client] set hostname, ret = %d\n", ret);
+       } else if (!entry->hostname) {
+               kr_log_verbose("[tls_client] no hostname\n");
        }
        if (ret != GNUTLS_E_SUCCESS) {
                tls_client_ctx_free(ctx);
@@ -1184,7 +1086,7 @@ void tls_client_ctx_free(struct tls_client_ctx_t *ctx)
        }
 
        /* Must decrease the refcount for referenced parameters */
-       client_paramlist_entry_unref(ctx->params);
+       tls_client_param_unref(ctx->params);
 
        free (ctx);
 }
@@ -1223,7 +1125,7 @@ int tls_client_connect_start(struct tls_client_ctx_t *client_ctx,
        ctx->handshake_state = TLS_HS_IN_PROGRESS;
        ctx->session = session;
 
-       struct tls_client_paramlist_entry *tls_params = client_ctx->params;
+       tls_client_param_t *tls_params = client_ctx->params;
        if (tls_params->session_data.data != NULL) {
                gnutls_session_set_data(ctx->tls_session, tls_params->session_data.data,
                                        tls_params->session_data.size);
index 0b500aa083085b6e3a35ff4d2ce2038666fff7a7..aa37df313b29c26709d260e5b28ad6f646e2fac1 100644 (file)
@@ -21,7 +21,8 @@
 #include <libknot/packet/pkt.h>
 #include "lib/defines.h"
 #include "lib/generic/array.h"
-#include "lib/generic/map.h"
+#include "lib/generic/trie.h"
+#include "lib/utils.h"
 
 #define MAX_TLS_PADDING KR_EDNS_PAYLOAD
 #define TLS_MAX_UNCORK_RETRIES 100
@@ -56,17 +57,59 @@ struct tls_credentials {
        char *ephemeral_servicename;
 };
 
-struct tls_client_paramlist_entry {
-       array_t(const char *) ca_files;
-       const char *hostname; /**< Server name for SNI and certificate check. */
-       array_t(const char *) pins;
-       gnutls_certificate_credentials_t credentials;
-       gnutls_datum_t session_data;
-       uint32_t refs;
-};
+
+#define TLS_SHA256_RAW_LEN 32 /* gnutls_hash_get_len(GNUTLS_DIG_SHA256) */
+/** Required buffer length for pin_sha256, including the zero terminator. */
+#define TLS_SHA256_BASE64_BUFLEN (((TLS_SHA256_RAW_LEN * 8 + 4) / 6) + 3 + 1)
+
+#if GNUTLS_VERSION_NUMBER >= 0x030400
+       #define TLS_CAN_USE_PINS 1
+#else
+       #define TLS_CAN_USE_PINS 0
+#endif
+
+
+/** TLS authentication parameters for a single address-port pair. */
+typedef struct {
+       uint32_t refs; /**< Reference count; consider TLS sessions in progress. */
+       bool insecure; /**< Use no authentication. */
+       const char *hostname; /**< Server name for SNI and certificate check, lowercased.  */
+       array_t(const char *) ca_files; /**< Paths to certificate files; not really used. */
+       array_t(const uint8_t *) pins; /**< Certificate pins as raw unterminated strings.*/
+       gnutls_certificate_credentials_t credentials; /**< CA creds. in gnutls format.  */
+       gnutls_datum_t session_data; /**< Session-resumption data gets stored here.    */
+} tls_client_param_t;
+/** Holds configuration for TLS authentication for all potential servers.
+ * Special case: NULL pointer also means empty. */
+typedef trie_t tls_client_params_t;
+
+/** Get a pointer-to-pointer to TLS auth params.
+ * If it didn't exist, it returns NULL (if !do_insert) or pointer to NULL. */
+tls_client_param_t ** tls_client_param_getptr(tls_client_params_t **params,
+                               const struct sockaddr *addr, bool do_insert);
+
+/** Get a pointer to TLS auth params or NULL. */
+static inline tls_client_param_t *
+       tls_client_param_get(tls_client_params_t *params, const struct sockaddr *addr)
+{
+       tls_client_param_t **pe = tls_client_param_getptr(&params, addr, false);
+       return pe ? *pe : NULL;
+}
+
+/** Allocate and initialize the structure (with ->ref = 1). */
+tls_client_param_t * tls_client_param_new();
+/** Reference-counted free(); any inside data is freed alongside. */
+void tls_client_param_unref(tls_client_param_t *entry);
+
+int tls_client_param_remove(tls_client_params_t *params, const struct sockaddr *addr);
+/** Free TLS authentication parameters. */
+void tls_client_params_free(tls_client_params_t *params);
+
 
 struct worker_ctx;
 struct qr_task;
+struct network;
+struct engine;
 
 typedef enum tls_client_hs_state {
        TLS_HS_NOT_STARTED = 0,
@@ -78,12 +121,6 @@ typedef enum tls_client_hs_state {
 
 typedef int (*tls_handshake_cb) (struct session *session, int status);
 
-typedef enum tls_client_param {
-       TLS_CLIENT_PARAM_NONE = 0,
-       TLS_CLIENT_PARAM_PIN,
-       TLS_CLIENT_PARAM_HOSTNAME,
-       TLS_CLIENT_PARAM_CA,
-} tls_client_param_t;
 
 struct tls_common_ctx {
        bool client_side;
@@ -117,7 +154,7 @@ struct tls_client_ctx_t {
         * this field must be always at first position
         */
        struct tls_common_ctx c;
-       struct tls_client_paramlist_entry *params;
+       tls_client_param_t *params; /**< It's reference-counted. */
 };
 
 /*! Create an empty TLS context in query context */
@@ -162,28 +199,9 @@ tls_hs_state_t tls_get_hs_state(const struct tls_common_ctx *ctx);
 /*! Set TLS handshake state. */
 int tls_set_hs_state(struct tls_common_ctx *ctx, tls_hs_state_t state);
 
-/*! Find TLS parameters for given address. Attempt opportunistic upgrade for port 53 to 853,
- *  if the address is configured with a working DoT on port 853.
- */
-struct tls_client_paramlist_entry *tls_client_try_upgrade(map_t *tls_client_paramlist,
-                         const struct sockaddr *addr);
-
-/*! Clear (remove) TLS parameters for given address. */
-int tls_client_params_clear(map_t *tls_client_paramlist, const char *addr, uint16_t port);
-
-/*! Set TLS authentication parameters for given address.
- * Note: hostname must be set before ca files,
- *       otherwise ca files will not be imported at all.
- */
-int tls_client_params_set(map_t *tls_client_paramlist,
-                         const char *addr, uint16_t port,
-                         const char *param, tls_client_param_t param_type);
-
-/*! Free TLS authentication parameters. */
-int tls_client_params_free(map_t *tls_client_paramlist);
 
 /*! Allocate new client TLS context */
-struct tls_client_ctx_t *tls_client_ctx_new(struct tls_client_paramlist_entry *entry,
+struct tls_client_ctx_t *tls_client_ctx_new(tls_client_param_t *entry,
                                            struct worker_ctx *worker);
 
 /*! Free client TLS context */
index 74ed644c86744272fdac0115b6b51eae0130609c..0c90014346e71e8b17a8afe178a936d0e78a43ab 100644 (file)
@@ -719,7 +719,7 @@ static int session_tls_hs_cb(struct session *session, int status)
 
        /* handshake was completed successfully */
        struct tls_client_ctx_t *tls_client_ctx = session_tls_get_client_ctx(session);
-       struct tls_client_paramlist_entry *tls_params = tls_client_ctx->params;
+       tls_client_param_t *tls_params = tls_client_ctx->params;
        gnutls_session_t tls_session = tls_client_ctx->c.tls_session;
        if (gnutls_session_is_resumed(tls_session) != 0) {
                kr_log_verbose("[tls_client] TLS session has resumed\n");
@@ -1299,11 +1299,9 @@ static int tcp_task_make_connection(struct qr_task *task, const struct sockaddr
        struct worker_ctx *worker = ctx->worker;
 
        /* Check if there must be TLS */
-       struct engine *engine = worker->engine;
-       struct network *net = &engine->net;
-       const char *key = tcpsess_key(addr);
        struct tls_client_ctx_t *tls_ctx = NULL;
-       struct tls_client_paramlist_entry *entry = map_get(&net->tls_client_params, key);
+       struct network *net = &worker->engine->net;
+       tls_client_param_t *entry = tls_client_param_get(net->tls_client_params, addr);
        if (entry) {
                /* Address is configured to be used with TLS.
                 * We need to allocate auxiliary data structure. */
@@ -1334,7 +1332,7 @@ static int tcp_task_make_connection(struct qr_task *task, const struct sockaddr
 
        /* Add address to the waiting list.
         * Now it "is waiting to be connected to." */
-       int ret = worker_add_tcp_waiting(ctx->worker, addr, session);
+       int ret = worker_add_tcp_waiting(worker, addr, session);
        if (ret < 0) {
                free(conn);
                session_close(session);
@@ -1350,7 +1348,7 @@ static int tcp_task_make_connection(struct qr_task *task, const struct sockaddr
        ret = session_timer_start(session, on_tcp_connect_timeout,
                                  KR_CONN_RTT_MAX, 0);
        if (ret != 0) {
-               worker_del_tcp_waiting(ctx->worker, addr);
+               worker_del_tcp_waiting(worker, addr);
                free(conn);
                session_close(session);
                return kr_error(EINVAL);
@@ -1366,7 +1364,7 @@ static int tcp_task_make_connection(struct qr_task *task, const struct sockaddr
        ret = uv_tcp_connect(conn, (uv_tcp_t *)client, addr , on_connect);
        if (ret != 0) {
                session_timer_stop(session);
-               worker_del_tcp_waiting(ctx->worker, addr);
+               worker_del_tcp_waiting(worker, addr);
                free(conn);
                session_close(session);
                unsigned score = qry->flags.FORWARD || qry->flags.STUB ? KR_NS_FWD_DEAD : KR_NS_DEAD;
@@ -1386,7 +1384,7 @@ static int tcp_task_make_connection(struct qr_task *task, const struct sockaddr
        ret = session_waitinglist_push(session, task);
        if (ret < 0) {
                session_timer_stop(session);
-               worker_del_tcp_waiting(ctx->worker, addr);
+               worker_del_tcp_waiting(worker, addr);
                free(conn);
                session_close(session);
                return kr_error(EINVAL);
@@ -1509,16 +1507,18 @@ static int qr_task_step(struct qr_task *task,
        if (task->addrlist_count > 0 && kr_inaddr_port(task->addrlist) == KR_DNS_PORT) {
                /* TODO if there are multiple addresses (task->addrlist_count > 1)
                 * check all of them. */
-               struct engine *engine = worker->engine;
-               struct network *net = &engine->net;
-               struct tls_client_paramlist_entry *tls_entry =
-                       tls_client_try_upgrade(&net->tls_client_params, task->addrlist);
-               if (tls_entry != NULL) {
-                       kr_inaddr_set_port(task->addrlist, KR_DNS_TLS_PORT);
+               struct network *net = &worker->engine->net;
+               kr_inaddr_set_port(task->addrlist, KR_DNS_TLS_PORT);
+               tls_client_param_t *tls_entry =
+                       tls_client_param_get(net->tls_client_params, task->addrlist);
+               if (tls_entry) {
                        packet_source = NULL;
                        sock_type = SOCK_STREAM;
                        /* TODO in this case in tcp_task_make_connection() will be performed
                         * redundant map_get() call. */
+               } else {
+                       /* The function is fairly cheap, so we just change there and back. */
+                       kr_inaddr_set_port(task->addrlist, KR_DNS_PORT);
                }
        }
 
index 0550e95a23addd95da0c9db4fd84c207b16a89a0..72b0c096db0a601695ce01afc3412e02cbfb83b5 100644 (file)
@@ -92,6 +92,7 @@ int trie_get_leq(trie_t *tbl, const char *key, uint32_t len, trie_val_t **val);
  * \param d Parameter passed as the second argument to f().
  * \return First nonzero from f() or zero (i.e. KNOT_EOK).
  */
+KR_EXPORT
 int trie_apply(trie_t *tbl, int (*f)(trie_val_t *, void *), void *d);
 
 /*!
index 2625b6191afb67435f9fb7edac34e187a5d1c0cf..e29e9093eb685cdbad825d40b903ceedc2e92224 100644 (file)
@@ -393,11 +393,20 @@ void kr_inaddr_set_port(struct sockaddr *addr, uint16_t port)
 
 int kr_inaddr_str(const struct sockaddr *addr, char *buf, size_t *buflen)
 {
-       if (!addr || !buf || !buflen) {
+       if (!addr) {
+               return kr_error(EINVAL);
+       }
+       return kr_ntop_str(addr->sa_family, kr_inaddr(addr), kr_inaddr_port(addr),
+                          buf, buflen);
+}
+
+int kr_ntop_str(int family, const void *src, uint16_t port, char *buf, size_t *buflen)
+{
+       if (!src || !buf || !buflen) {
                return kr_error(EINVAL);
        }
 
-       if (!inet_ntop(addr->sa_family, kr_inaddr(addr), buf, *buflen)) {
+       if (!inet_ntop(family, src, buf, *buflen)) {
                return kr_error(errno);
        }
        const int len = strlen(buf);
@@ -408,7 +417,7 @@ int kr_inaddr_str(const struct sockaddr *addr, char *buf, size_t *buflen)
        }
        *buflen = len_need;
        buf[len] = '#';
-       u16tostr((uint8_t *)&buf[len + 1], kr_inaddr_port(addr));
+       u16tostr((uint8_t *)&buf[len + 1], port);
        buf[len_need - 1] = 0;
        return kr_ok();
 }
index f36b629736d448733864c60e37ff5996f61e92ba..3e0de9f4d2da6404ffec3616be860fa83559cae0 100644 (file)
@@ -145,6 +145,13 @@ static inline void mm_ctx_init(knot_mm_t *mm)
 }
 /* @endcond */
 
+/** A strcmp() variant directly usable for qsort() on an array of strings. */
+static inline int strcmp_p(const void *p1, const void *p2)
+{
+       return strcmp(*(char * const *)p1, *(char * const *)p2);
+}
+
+
 /** Return time difference in miliseconds.
   * @note based on the _BSD_SOURCE timersub() macro */
 static inline long time_diff(struct timeval *begin, struct timeval *end) {
@@ -295,6 +302,26 @@ void kr_inaddr_set_port(struct sockaddr *addr, uint16_t port);
 KR_EXPORT
 int kr_inaddr_str(const struct sockaddr *addr, char *buf, size_t *buflen);
 
+/** Write string representation for given address as "<addr>#<port>".
+ * It's the same as kr_inaddr_str(), but the input address is input in native format
+ * like for inet_ntop() (4 or 16 bytes) and port must be separate parameter.  */
+KR_EXPORT
+int kr_ntop_str(int family, const void *src, uint16_t port, char *buf, size_t *buflen);
+
+/** @internal Create string representation addr#port.
+ *  @return pointer to static string
+ */
+static inline char *kr_straddr(const struct sockaddr *addr)
+{
+       assert(addr != NULL);
+       /* We are the sinle-threaded application */
+       static char str[INET6_ADDRSTRLEN + 1 + 5 + 1];
+       size_t len = sizeof(str);
+       int ret = kr_inaddr_str(addr, str, &len);
+       return ret != kr_ok() || len == 0 ? NULL : str;
+}
+
+
 /** Return address type for string. */
 KR_EXPORT KR_PURE
 int kr_straddr_family(const char *addr);
@@ -418,19 +445,6 @@ static inline uint16_t kr_rrset_type_maysig(const knot_rrset_t *rr)
        return type;
 }
 
-/** @internal Return string representation of addr.
- *  @note return pointer to static string
- */
-static inline char *kr_straddr(const struct sockaddr *addr)
-{
-       assert(addr != NULL);
-       /* We are the sinle-threaded application */
-       static char str[INET6_ADDRSTRLEN + 1 + 5 + 1];
-       size_t len = sizeof(str);
-       int ret = kr_inaddr_str(addr, str, &len);
-       return ret != kr_ok() || len == 0 ? NULL : str;
-}
-
 /** The current time in monotonic milliseconds.
  *
  * \note it may be outdated in case of long callbacks; see uv_now().
index 3b33e03be186c004ea7690a842ce610455063042..b48e1d25d45a2dccdd2a64f5461fbbb1df0e96f9 100644 (file)
@@ -48,13 +48,13 @@ Following actions stop the policy matching on the query, i.e. other rules are no
 * ``DROP`` - terminate query resolution and return SERVFAIL to the requestor
 * ``REFUSE`` - terminate query resolution and return REFUSED to the requestor
 * ``TC`` - set TC=1 if the request came through UDP, forcing client to retry with TCP
-* ``FORWARD(ip)`` - resolve a query via forwarding to an IP while validating and caching locally;
-* ``TLS_FORWARD({{ip, authentication}})`` - resolve a query via TLS connection forwarding to an IP while validating and caching locally;
-  the parameter can be a single IP (string) or a lua list of up to four IPs.
+* ``FORWARD(ip)`` - resolve a query via forwarding to an IP while validating and caching locally
+* ``TLS_FORWARD({{ip, authentication}})`` - resolve a query via TLS connection forwarding to an IP while validating and caching locally
 * ``STUB(ip)`` - similar to ``FORWARD(ip)`` but *without* attempting DNSSEC validation.
   Each request may be either answered from cache or simply sent to one of the IPs with proxying back the answer.
 * ``REROUTE({{subnet,target}, ...})`` - reroute addresses in response matching given subnet to given target, e.g. ``{'192.0.2.0/24', '127.0.0.0'}`` will rewrite '192.0.2.55' to '127.0.0.55', see :ref:`renumber module <mod-renumber>` for more information.
 
+``FORWARD``, ``TLS_FORWARD`` and ``STUB`` support up to four IP addresses "in a single call".
 
 **Chain actions**
 
@@ -90,9 +90,16 @@ Traditional PKI authentication requires server to present certificate with speci
         policy.TLS_FORWARD({
                 {'2001:DB8::d0c', hostname='res.example.com'}})
 
-- `hostname` must exactly match hostname in server's certificate, i.e. in most cases it must not contain trailing dot (`res.example.com`).
-- System CA certificate store will be used if no `ca_file` option is specified.
-- Optional `ca_file` option can specify path to CA certificate (or certificate bundle) in `PEM format`_.
+- ``hostname`` must be a valid domain name matching server's certificate.  It will also be sent to the server as SNI_.
+- ``ca_file`` optionally contains a path to a CA certificate (or certificate bundle) in `PEM format`_.
+  If you omit that, the system CA certificate store will be used instead (usually sufficient).
+  A list of paths is also accepted, but all of them must be valid PEMs.
+
+Key-pinned authentication
+~~~~~~~~~~~~~~~~~~~~~~~~~
+Instead of CAs, you can specify hashes of accepted certificates in ``pin_sha256``.
+They are in the usual format -- base64 from sha256.
+You may still specify ``hostname`` if you want SNI_ to be sent.
 
 TLS Examples
 ~~~~~~~~~~~~
@@ -283,3 +290,4 @@ Most properties (actions, filters) are described above.
 .. _`Transport Layer Security`: https://en.wikipedia.org/wiki/Transport_Layer_Security
 .. _`DNS Privacy Project`: https://dnsprivacy.org/
 .. _`IETF draft dprive-dtls-and-tls-profiles`: https://tools.ietf.org/html/draft-ietf-dprive-dtls-and-tls-profiles
+.. _SNI: https://en.wikipedia.org/wiki/Server_Name_Indication
index 993550fa2254e5dc91af902760e0a231efe0209c..be6a69ba90c58520b5dd88ede46d774a1291b59c 100644 (file)
@@ -144,114 +144,35 @@ function policy.FORWARD(target)
        end
 end
 
--- object must be non-empty string or non-empty table of non-empty strings
-local function is_nonempty_string_or_table(object)
-       if type(object) == 'string' then
-               return #object ~= 0
-       elseif type(object) ~= 'table' or not next(object) then
-               return false
-       end
-       for _, val in pairs(object) do
-               if type(val) ~= 'string' or #val == 0 then
-                       return false
-               end
-       end
-       return true
-end
-
-local function insert_from_string_or_table(source, destination)
-       if type(source) == 'table' then
-               for _, v in pairs(source) do
-                       table.insert(destination, v)
-               end
-       else
-               table.insert(destination, source)
-       end
-end
-
--- Check for allowed authentication types and return type for the current target
-local function tls_forward_target_authtype(idx, target)
-       if (target.pin_sha256 and not (target.ca_file or target.hostname or target.insecure)) then
-               if not is_nonempty_string_or_table(target.pin_sha256) then
-                       error('TLS_FORWARD target authentication is invalid at position '
-                             .. idx .. '; pin_sha256 must be string or list of strings')
-               end
-               return 'pin_sha256'
-       elseif (target.insecure and not (target.ca_file or target.hostname or target.pin_sha256)) then
-               return 'insecure'
-       elseif (target.hostname and not (target.insecure or target.pin_sha256)) then
-               if not (is_nonempty_string_or_table(target.hostname)) then
-                       error('TLS_FORWARD target authentication is invalid at position '
-                             .. idx .. '; hostname must be string or list of strings')
-               end
-               -- if target.ca_file is empty, system CA will be used
-               return 'cert'
-       else
-               error('TLS_FORWARD authentication options at position ' .. idx
-                     .. ' are invalid; specify one of: pin_sha256 / hostname [+ca_file] / insecure')
-       end
-end
-
-local function tls_forward_target_check_syntax(idx, list_entry)
-       if type(list_entry) ~= 'table' then
-               error('TLS_FORWARD target must be a non-empty table (found '
-                     .. type(list_entry) .. ' at position ' .. idx .. ')')
-       end
-       if type(list_entry[1]) ~= 'string' then
-               error('TLS_FORWARD target must start with an IP address (found '
-                     .. type(list_entry[1]) .. ' at the beginning of target position ' .. idx .. ')')
-       end
-end
-
 -- Forward request and all subrequests to upstream over TLS; validate answers
-function policy.TLS_FORWARD(target)
-       local sockaddr_c_list = {}
-       local sockaddr_config = {}  -- items: { string_addr=<addr string>, auth_type=<auth type> }
-       local ca_files = {}
-       local hostnames = {}
-       local pins = {}
-       if type(target) ~= 'table' or #target < 1 then
+function policy.TLS_FORWARD(targets)
+       if type(targets) ~= 'table' or #targets < 1 then
                error('TLS_FORWARD argument must be a non-empty table')
+       elseif #targets > 4 then
+               error('TLS_FORWARD supports at most four targets (in a single call)')
        end
-       for idx, upstream_list_entry in pairs(target) do
-               tls_forward_target_check_syntax(idx, upstream_list_entry)
-               local auth_type = tls_forward_target_authtype(idx, upstream_list_entry)
-               local string_addr = upstream_list_entry[1]
-               local sockaddr_c = addr2sock(string_addr, 853)
-               local sockaddr_lua = ffi.string(sockaddr_c, ffi.C.kr_sockaddr_len(sockaddr_c))
-               if sockaddr_config[sockaddr_lua] then
-                       error('TLS_FORWARD configuration cannot declare two configs for IP address ' .. string_addr)
-               end
-               table.insert(sockaddr_c_list, sockaddr_c)
-               sockaddr_config[sockaddr_lua] = {string_addr=string_addr, auth_type=auth_type}
-               if auth_type == 'cert' then
-                       ca_files[sockaddr_lua] = {}
-                       hostnames[sockaddr_lua] = {}
-                       insert_from_string_or_table(upstream_list_entry.ca_file, ca_files[sockaddr_lua])
-                       insert_from_string_or_table(upstream_list_entry.hostname, hostnames[sockaddr_lua])
-               elseif auth_type == 'pin_sha256' then
-                       pins[sockaddr_lua] = {}
-                       insert_from_string_or_table(upstream_list_entry.pin_sha256, pins[sockaddr_lua])
-               elseif auth_type ~= 'insecure' then
-                       -- insecure does nothing, user does not want authentication
-                       assert(false, 'unsupported auth_type')
+
+       local sockaddr_c_set = {}
+       local nslist = {} -- to persist in closure of the returned function
+       for idx, target in pairs(targets) do
+               if type(target) ~= 'table' or type(target[1]) ~= 'string' then
+                       error('TLS_FORWARD argument number %1 must be a table starting with an address',
+                                       idx)
                end
-       end
+               -- Note: some functions have checks with error() calls inside.
+               local sockaddr_c = addr2sock(target[1], 853)
 
-       -- Update the global table of authentication data only if all checks above passed
-       for sockaddr_lua, config in pairs(sockaddr_config) do
-               assert(#config.string_addr > 0)
-               if config.auth_type == 'insecure' then
-                       net.tls_client(config.string_addr)
-               elseif config.auth_type == 'pin_sha256' then
-                       assert(#pins[sockaddr_lua] > 0)
-                       net.tls_client(config.string_addr, pins[sockaddr_lua])
-               elseif config.auth_type == 'cert' then
-                       assert(#hostnames[sockaddr_lua] > 0) -- but we don't support > 1 anymore
-                       net.tls_client(config.string_addr, ca_files[sockaddr_lua], hostnames[sockaddr_lua])
+               -- Refuse repeated addresses in the same set.
+               local sockaddr_lua = ffi.string(sockaddr_c, ffi.C.kr_sockaddr_len(sockaddr_c))
+               if sockaddr_c_set[sockaddr_lua] then
+                       error('TLS_FORWARD configuration cannot declare two configs for IP address '
+                                       .. target[1])
                else
-                       assert(false, 'unsupported auth_type')
+                       sockaddr_c_set[sockaddr_lua] = true;
                end
+
+               table.insert(nslist, sockaddr_c)
+               net.tls_client(target)
        end
 
        return function(state, req)
@@ -264,7 +185,7 @@ function policy.TLS_FORWARD(target)
                qry.flags.AWAIT_CUT = true
                req.options.TCP = true
                qry.flags.TCP = true
-               set_nslist(qry, sockaddr_c_list)
+               set_nslist(qry, nslist)
                return state
        end
 end
index 8780caa1e1dcbf4352db0feb7c41149337249225..99b46e3813063b92803a236d6013a81eb78f82ae 100644 (file)
@@ -10,40 +10,49 @@ local function test_tls_forward()
        boom(policy.TLS_FORWARD, {{{bleble=''}}}, 'TLS_FORWARD with invalid parameters in table')
 
        boom(policy.TLS_FORWARD, {{'1'}}, 'TLS_FORWARD with invalid IP address')
-       -- boom(policy.TLS_FORWARD, {{{'::1', bleble=''}}}, 'TLS_FORWARD with valid IP and invalid parameters')
+       boom(policy.TLS_FORWARD, {{{'::1', bleble=''}}}, 'TLS_FORWARD with valid IP and invalid parameters')
        boom(policy.TLS_FORWARD, {{{'127.0.0.1'}}}, 'TLS_FORWARD with missing auth parameters')
 
        ok(policy.TLS_FORWARD({{'127.0.0.1', insecure=true}}), 'TLS_FORWARD with no authentication')
        boom(policy.TLS_FORWARD, {{{'100:dead::', insecure=true},
                                   {'100:DEAD:0::', insecure=true}
                           }}, 'TLS_FORWARD with duplicate IP addresses is not allowed')
-       ok(policy.TLS_FORWARD({{'100:dead::', insecure=true},
-                              {'100:dead::@443', insecure=true}
+       ok(policy.TLS_FORWARD({{'100:dead::2', insecure=true},
+                              {'100:dead::2@443', insecure=true}
                           }), 'TLS_FORWARD with duplicate IP addresses but different ports is allowed')
-       ok(policy.TLS_FORWARD({{'100:dead::', insecure=true},
-                              {'100:beef::', insecure=true}
+       ok(policy.TLS_FORWARD({{'100:dead::3', insecure=true},
+                              {'100:beef::3', insecure=true}
                           }), 'TLS_FORWARD with different IPv6 addresses is allowed')
        ok(policy.TLS_FORWARD({{'127.0.0.1', insecure=true},
                               {'127.0.0.2', insecure=true}
                           }), 'TLS_FORWARD with different IPv4 addresses is allowed')
 
        boom(policy.TLS_FORWARD, {{{'::1', pin_sha256=''}}}, 'TLS_FORWARD with empty pin_sha256')
-       -- boom(policy.TLS_FORWARD, {{{'::1', pin_sha256='č'}}}, 'TLS_FORWARD with bad pin_sha256')
+       boom(policy.TLS_FORWARD, {{{'::1', pin_sha256='č'}}}, 'TLS_FORWARD with bad pin_sha256')
+       boom(policy.TLS_FORWARD, {{{'::1', pin_sha256='d161VN6aMSSdRN/TSDP6HZOHdaqcIvISlyFB9xLbGg='}}},
+               'TLS_FORWARD with bad pin_sha256 (short base64)')
+       boom(policy.TLS_FORWARD, {{{'::1', pin_sha256='bbd161VN6aMSSdRN/TSDP6HZOHdaqcIvISlyFB9xLbGg='}}},
+               'TLS_FORWARD with bad pin_sha256 (long base64)')
        ok(policy.TLS_FORWARD({
-                       {'::1', pin_sha256='ZTNiMGM0NDI5OGZjMWMxNDlhZmJmNGM4OTk2ZmI5MjQyN2FlNDFlNDY0OWI5MzRjYTQ5NTk5MWI3ODUyYjg1NQ=='}
+                       {'::1', pin_sha256='g1PpXsxqPchz2tH6w9kcvVXqzQ0QclhInFP2+VWOqic='}
                }), 'TLS_FORWARD with base64 pin_sha256')
        ok(policy.TLS_FORWARD({
                {'::1', pin_sha256={
-                       'ZTNiMGM0NDI5OGZjMWMxNDlhZmJmNGM4OTk2ZmI5MjQyN2FlNDFlNDY0OWI5MzRjYTQ5NTk5MWI3ODUyYjg1NQ==',
-                       'MTcwYWUzMGNjZDlmYmE2MzBhZjhjZGE2ODQxZTAwYzZiNjU3OWNlYzc3NmQ0MTllNzAyZTIwYzY5YzQ4OGZmOA=='
-               }}}), 'TLS_FORWARD with table of pins')
+                       'ev1xcdU++dY9BlcX0QoKeaUftvXQvNIz/PCss1Z/3ek=',
+                       'SgnqTFcvYduWX7+VUnlNFT1gwSNvQdZakH7blChIRbM=',
+                       'bd161VN6aMSSdRN/TSDP6HZOHdaqcIvISlyFB9xLbGg=',
+               }}}), 'TLS_FORWARD with a table of pins')
 
        -- ok(policy.TLS_FORWARD({{'::1', hostname='test.', ca_file='/tmp/ca.crt'}}), 'TLS_FORWARD with hostname + CA cert')
-       ok(policy.TLS_FORWARD({{'::1', hostname='test.'}}), 'TLS_FORWARD with just hostname (use system CA store)')
-       boom(policy.TLS_FORWARD, {{{'::1', ca_file='/tmp/ca.crt'}}}, 'TLS_FORWARD with just CA cert')
-       boom(policy.TLS_FORWARD, {{{'::1', hostname='', ca_file='/tmp/ca.crt'}}}, 'TLS_FORWARD with empty hostname + CA cert')
-       boom(policy.TLS_FORWARD, {{{'::1', hostname='test.', ca_file='/dev/a_file_which_surely_does_NOT_exist!'}}},
-               'TLS_FORWARD with hostname + unreadable CA cert')
+       ok(policy.TLS_FORWARD({{'::1', hostname='test.'}}),
+               'TLS_FORWARD with just hostname (use system CA store)')
+       boom(policy.TLS_FORWARD, {{{'::1', ca_file='/tmp/ca.crt'}}},
+               'TLS_FORWARD with just CA cert')
+       boom(policy.TLS_FORWARD, {{{'::1', hostname='', ca_file='/tmp/ca.crt'}}},
+               'TLS_FORWARD with empty hostname + CA cert')
+       boom(policy.TLS_FORWARD, {
+                       {{'::1', hostname='test.', ca_file='/dev/a_file_which_surely_does_NOT_exist!'}}
+               }, 'TLS_FORWARD with hostname + unreadable CA cert')
 
 end