]> git.ipfire.org Git - thirdparty/knot-resolver.git/commitdiff
daemon: PROXYv2 header processing
authorOto Šťáva <oto.stava@nic.cz>
Mon, 13 Dec 2021 14:34:36 +0000 (15:34 +0100)
committerVladimír Čunát <vladimir.cunat@nic.cz>
Tue, 22 Feb 2022 10:52:11 +0000 (10:52 +0000)
17 files changed:
daemon/io.c
daemon/lua/kres-gen-30.lua
daemon/lua/kres-gen-31.lua
daemon/meson.build
daemon/network.c
daemon/network.h
daemon/proxyv2.c [new file with mode: 0644]
daemon/proxyv2.h [new file with mode: 0644]
daemon/session.c
daemon/session.h
daemon/udp_queue.c
daemon/worker.c
daemon/worker.h
lib/resolve.h
lib/test_utils.c
lib/utils.c
lib/utils.h

index 7a4d7eb28b507e1b927ecfc869089c863bb97180..e1e8e8af48b0011fff55b25db0bba29b2340538a 100644 (file)
@@ -17,6 +17,7 @@
 #endif
 
 #include "daemon/network.h"
+#include "daemon/proxyv2.h"
 #include "daemon/worker.h"
 #include "daemon/tls.h"
 #include "daemon/http.h"
@@ -68,26 +69,71 @@ static void handle_getbuf(uv_handle_t* handle, size_t suggested_size, uv_buf_t*
 }
 
 void udp_recv(uv_udp_t *handle, ssize_t nread, const uv_buf_t *buf,
-       const struct sockaddr *addr, unsigned flags)
+       const struct sockaddr *comm_addr, unsigned flags)
 {
        struct session *s = handle->data;
-       if (session_flags(s)->closing || nread <= 0 || addr->sa_family == AF_UNSPEC)
+       if (session_flags(s)->closing || nread <= 0 || comm_addr->sa_family == AF_UNSPEC)
                return;
 
        if (session_flags(s)->outgoing) {
                const struct sockaddr *peer = session_get_peer(s);
                if (kr_fails_assert(peer->sa_family != AF_UNSPEC))
                        return;
-               if (kr_sockaddr_cmp(peer, addr) != 0) {
+               if (kr_sockaddr_cmp(peer, comm_addr) != 0) {
                        kr_log_debug(IO, "<= ignoring UDP from unexpected address '%s'\n",
-                                       kr_straddr(addr));
+                                       kr_straddr(comm_addr));
                        return;
                }
        }
-       ssize_t consumed = session_wirebuf_consume(s, (const uint8_t *)buf->base,
-                                                  nread);
-       kr_assert(consumed == nread);
-       session_wirebuf_process(s, addr);
+
+       const uint8_t *data = (const uint8_t *)buf->base;
+       ssize_t data_len = nread;
+       const struct sockaddr *src_addr = comm_addr;
+       const struct sockaddr *dst_addr = NULL;
+       if (!session_flags(s)->outgoing && proxy_header_present(data, data_len)) {
+               if (!proxy_allowed(&the_worker->engine->net, comm_addr)) {
+                       kr_log_debug(IO, "<= ignoring PROXYv2 UDP from disallowed address '%s'\n",
+                                       kr_straddr(comm_addr));
+                       return;
+               }
+
+               struct proxy_result proxy;
+               ssize_t trimmed = proxy_process_header(&proxy, s, data, data_len);
+               if (trimmed == KNOT_EMALF) {
+                       if (kr_log_is_debug(IO, NULL)) {
+                               kr_log_debug(IO, "<= ignoring malformed PROXYv2 UDP "
+                                               "from address '%s'\n",
+                                               kr_straddr(comm_addr));
+                       }
+                       return;
+               } else if (trimmed < 0) {
+                       if (kr_log_is_debug(IO, NULL)) {
+                               kr_log_debug(IO, "<= error processing PROXYv2 UDP "
+                                               "from address '%s', ignoring\n",
+                                               kr_straddr(comm_addr));
+                       }
+                       return;
+               }
+
+               if (proxy.command == PROXY2_CMD_PROXY && proxy.family != AF_UNSPEC) {
+                       src_addr = &proxy.src_addr.ip;
+                       dst_addr = &proxy.dst_addr.ip;
+
+                       if (kr_log_is_debug(IO, NULL)) {
+                               kr_log_debug(IO, "<= UDP query from '%s'\n",
+                                               kr_straddr(src_addr));
+                               kr_log_debug(IO, "<= proxied through '%s'\n",
+                                               kr_straddr(comm_addr));
+                       }
+               }
+               data = session_wirebuf_get_free_start(s);
+               data_len = nread - trimmed;
+       }
+
+       ssize_t consumed = session_wirebuf_consume(s, data, data_len);
+       kr_assert(consumed == data_len);
+
+       session_wirebuf_process(s, src_addr, comm_addr, dst_addr);
        session_wirebuf_discard(s);
        mp_flush(the_worker->pkt_pool.ctx);
 }
@@ -292,17 +338,68 @@ static void tcp_recv(uv_stream_t *handle, ssize_t nread, const uv_buf_t *buf)
                return;
        }
 
-       ssize_t consumed = 0;
        const uint8_t *data = (const uint8_t *)buf->base;
        ssize_t data_len = nread;
+       const struct sockaddr *src_addr = session_get_peer(s);
+       const struct sockaddr *dst_addr = NULL;
+       if (!session_flags(s)->outgoing && !session_flags(s)->no_proxy &&
+                       proxy_header_present(data, data_len)) {
+               if (!proxy_allowed(&the_worker->engine->net, src_addr)) {
+                       if (kr_log_is_debug(IO, NULL)) {
+                               kr_log_debug(IO, "<= connection to '%s': PROXYv2 not allowed "
+                                               "for this peer, close\n",
+                                               kr_straddr(src_addr));
+                       }
+                       worker_end_tcp(s);
+                       return;
+               }
+
+               struct proxy_result proxy;
+               ssize_t trimmed = proxy_process_header(&proxy, s, data, data_len);
+               if (trimmed < 0) {
+                       if (kr_log_is_debug(IO, NULL)) {
+                               if (trimmed == KNOT_EMALF) {
+                                       kr_log_debug(IO, "<= connection to '%s': "
+                                                       "malformed PROXYv2 header, close\n",
+                                                       kr_straddr(src_addr));
+                               } else {
+                                       kr_log_debug(IO, "<= connection to '%s': "
+                                                       "error processing PROXYv2 header, close\n",
+                                                       kr_straddr(src_addr));
+                               }
+                       }
+                       worker_end_tcp(s);
+                       return;
+               } else if (trimmed == 0) {
+                       return;
+               }
+
+               if (proxy.command != PROXY2_CMD_LOCAL && proxy.family != AF_UNSPEC) {
+                       src_addr = &proxy.src_addr.ip;
+                       dst_addr = &proxy.dst_addr.ip;
+
+                       if (kr_log_is_debug(IO, NULL)) {
+                               kr_log_debug(IO, "<= TCP stream from '%s'\n",
+                                               kr_straddr(src_addr));
+                               kr_log_debug(IO, "<= proxied through '%s'\n",
+                                               kr_straddr(session_get_peer(s)));
+                       }
+               }
+
+               data = session_wirebuf_get_free_start(s);
+               data_len = nread - trimmed;
+       }
+
+       session_flags(s)->no_proxy = true;
+
+       ssize_t consumed = 0;
        if (session_flags(s)->has_tls) {
                /* buf->base points to start of the tls receive buffer.
                   Decode data free space in session wire buffer. */
-               consumed = tls_process_input_data(s, (const uint8_t *)buf->base, nread);
+               consumed = tls_process_input_data(s, data, data_len);
                if (consumed < 0) {
                        if (kr_log_is_debug(IO, NULL)) {
-                               struct sockaddr *peer = session_get_peer(s);
-                               char *peer_str = kr_straddr(peer);
+                               char *peer_str = kr_straddr(src_addr);
                                kr_log_debug(IO, "=> connection to '%s': "
                                               "error processing TLS data, close\n",
                                               peer_str ? peer_str : "");
@@ -320,8 +417,7 @@ static void tcp_recv(uv_stream_t *handle, ssize_t nread, const uv_buf_t *buf)
                consumed = http_process_input_data(s, data, data_len);
                if (consumed < 0) {
                        if (kr_log_is_debug(IO, NULL)) {
-                               struct sockaddr *peer = session_get_peer(s);
-                               char *peer_str = kr_straddr(peer);
+                               char *peer_str = kr_straddr(src_addr);
                                kr_log_debug(IO, "=> connection to '%s': "
                                       "error processing HTTP data, close\n",
                                       peer_str ? peer_str : "");
@@ -341,7 +437,7 @@ static void tcp_recv(uv_stream_t *handle, ssize_t nread, const uv_buf_t *buf)
        consumed = session_wirebuf_consume(s, data, data_len);
        kr_assert(consumed == data_len);
 
-       int ret = session_wirebuf_process(s, session_get_peer(s));
+       int ret = session_wirebuf_process(s, src_addr, session_get_peer(s), dst_addr);
        if (ret < 0) {
                /* An error has occurred, close the session. */
                worker_end_tcp(s);
@@ -828,6 +924,7 @@ static void xdp_rx(uv_poll_t* handle, int status, int events)
                        ret = kr_error(ENOMEM);
                } else {
                        ret = worker_submit(xhd->session,
+                                       (const struct sockaddr *)&msg->ip_from,
                                        (const struct sockaddr *)&msg->ip_from,
                                        (const struct sockaddr *)&msg->ip_to,
                                        msg->eth_from, msg->eth_to, kpkt);
index e9ff550b5e98ecf927031ad8b387c97878899dda..a34102e8c5adba6b01399174800383397eff1cea 100644 (file)
@@ -205,6 +205,7 @@ struct kr_request {
        struct kr_query *current_query;
        struct {
                const struct sockaddr *addr;
+               const struct sockaddr *comm_addr;
                const struct sockaddr *dst_addr;
                const knot_pkt_t *packet;
                struct kr_request_qsource_flags flags;
index 7ebd5337559b324d88a141095086a8bac5c1493b..0a87bcc7fc5ff90092146c90362de505bd03ccb3 100644 (file)
@@ -205,6 +205,7 @@ struct kr_request {
        struct kr_query *current_query;
        struct {
                const struct sockaddr *addr;
+               const struct sockaddr *comm_addr;
                const struct sockaddr *dst_addr;
                const knot_pkt_t *packet;
                struct kr_request_qsource_flags flags;
index 2bf3db01ace895571b23fe599eb0546aa623266e..ba94d95a31c7b1fd109c6aa207a4b9eb130ecabe 100644 (file)
@@ -13,6 +13,7 @@ kresd_src = files([
   'io.c',
   'main.c',
   'network.c',
+  'proxyv2.c',
   'session.c',
   'tls.c',
   'tls_ephemeral_credentials.c',
index e018a02e510a60338d72f122853d179126a66b90..d6e2712b37d955ee7e858e71d4c9c79f6ccfb773 100644 (file)
@@ -24,6 +24,8 @@ void network_init(struct network *net, uv_loop_t *loop, int tcp_backlog)
                net->loop = loop;
                net->endpoints = map_make(NULL);
                net->endpoint_kinds = trie_create(NULL);
+               net->proxy_addrs4 = trie_create(NULL);
+               net->proxy_addrs6 = trie_create(NULL);
                net->tls_client_params = NULL;
                net->tls_session_ticket_ctx = /* unsync. random, by default */
                tls_session_ticket_ctx_create(loop, NULL, 0);
@@ -203,12 +205,27 @@ void network_close_force(struct network *net)
        }
 }
 
+/** Frees all the `struct net_proxy_data` in the specified trie. */
+void network_proxy_free_addr_data(trie_t* trie)
+{
+       trie_it_t *it;
+       for (it = trie_it_begin(trie); !trie_it_finished(it); trie_it_next(it)) {
+               struct net_proxy_data *data = *trie_it_val(it);
+               free(data);
+       }
+       trie_it_free(it);
+}
+
 void network_deinit(struct network *net)
 {
        if (net != NULL) {
                network_close_force(net);
                trie_apply(net->endpoint_kinds, kind_unregister, the_worker->engine->L);
                trie_free(net->endpoint_kinds);
+               network_proxy_free_addr_data(net->proxy_addrs4);
+               trie_free(net->proxy_addrs4);
+               network_proxy_free_addr_data(net->proxy_addrs6);
+               trie_free(net->proxy_addrs6);
 
                tls_credentials_free(net->tls_credentials);
                tls_client_params_free(net->tls_client_params);
@@ -506,6 +523,79 @@ int network_listen(struct network *net, const char *addr, uint16_t port,
        return ret;
 }
 
+int network_proxy_allow(struct network *net, const char* addr)
+{
+       if (kr_fails_assert(net != NULL && addr != NULL))
+               return kr_error(EINVAL);
+
+       int family = kr_straddr_family(addr);
+       if (family < 0) {
+               kr_log_error(NETWORK, "Wrong address format for proxy_allowed: %s\n",
+                               addr);
+               return kr_error(EINVAL);
+       } else if (family == AF_UNIX) {
+               kr_log_error(NETWORK, "Unix sockets not supported for proxy_allowed: %s\n",
+                               addr);
+               return kr_error(EINVAL);
+       }
+
+       union kr_in_addr ia;
+       int netmask = kr_straddr_subnet(&ia, addr);
+       if (netmask < 0) {
+               kr_log_error(NETWORK, "Wrong netmask format for proxy_allowed: %s\n", addr);
+               return kr_error(EINVAL);
+       } else if (netmask == 0) {
+               kr_log_error(NETWORK, "Zero netmask not allowed proxy_allowed: %s\n", addr);
+               return kr_error(EINVAL);
+       }
+
+       size_t addr_length;
+       trie_t *trie;
+       switch (family) {
+       case AF_INET:
+               addr_length = sizeof(ia.ip4);
+               trie = net->proxy_addrs4;
+               break;
+       case AF_INET6:
+               addr_length = sizeof(ia.ip6);
+               trie = net->proxy_addrs6;
+               break;
+       default:
+               kr_assert(false);
+               return kr_error(EINVAL);
+       }
+
+       kr_bitmask((unsigned char *) &ia, addr_length, netmask);
+       trie_val_t *val = trie_get_ins(trie, (char *) &ia, addr_length);
+       if (!val)
+               return kr_error(ENOMEM);
+
+       struct net_proxy_data *data = *val;
+       if (!data) { /* Allocate data if the entry is new in the trie */
+               *val = malloc(sizeof(struct net_proxy_data));
+               data = *val;
+               data->netmask = 0;
+       }
+
+       if (data->netmask == 0) {
+               memcpy(&data->addr, &ia, addr_length);
+               data->netmask = netmask;
+       } else if (data->netmask > netmask) {
+               /* A more relaxed netmask configured - replace it */
+               data->netmask = netmask;
+       }
+
+       return kr_ok();
+}
+
+void network_proxy_reset(struct network *net)
+{
+       network_proxy_free_addr_data(net->proxy_addrs4);
+       trie_clear(net->proxy_addrs4);
+       network_proxy_free_addr_data(net->proxy_addrs6);
+       trie_clear(net->proxy_addrs6);
+}
+
 int network_close(struct network *net, const char *addr, int port)
 {
        endpoint_array_t *ep_array = map_get(&net->endpoints, addr);
index e9e6799700403a3dea95eff41a07767254d35b52..cbba8a8d61f4571dff4ec60d7b755090a124d7c6 100644 (file)
@@ -68,6 +68,12 @@ struct net_tcp_param {
        uint64_t tls_handshake_timeout;
 };
 
+/** Information about an address that is allowed to use PROXYv2. */
+struct net_proxy_data {
+       union kr_in_addr addr;
+       uint8_t netmask;   /**< Number of bits to be matched */
+};
+
 struct network {
        uv_loop_t *loop;
 
@@ -83,6 +89,11 @@ struct network {
        /** See network_engage_endpoints() */
        bool missing_kind_is_error;
 
+       /** IPv4 addresses and networks allowed to use the PROXYv2 protocol */
+       trie_t *proxy_addrs4;
+       /** IPv6 addresses and networks allowed to use the PROXYv2 protocol */
+       trie_t *proxy_addrs6;
+
        struct tls_credentials *tls_credentials;
        tls_client_params_t *tls_client_params; /**< Use tls_client_params_*() functions. */
        struct tls_session_ticket_ctx *tls_session_ticket_ctx;
@@ -105,6 +116,17 @@ void network_deinit(struct network *net);
 int network_listen(struct network *net, const char *addr, uint16_t port,
                   int16_t nic_queue, endpoint_flags_t flags);
 
+/** Allow the specified address to send the PROXYv2 header.
+ * \note the address may be specified with a netmask
+ */
+int network_proxy_allow(struct network *net, const char* addr);
+
+/** Reset all addresses allowed to send the PROXYv2 header. No addresses will
+ * be allowed to send PROXYv2 headers from the point of calling this function
+ * until re-allowed via network_proxy_allow again.
+ */
+void network_proxy_reset(struct network *net);
+
 /** Start listening on an open file-descriptor.
  * \note flags.sock_type isn't meaningful here.
  * \note ownership of flags.* is taken on success.  TODO: non-success?
diff --git a/daemon/proxyv2.c b/daemon/proxyv2.c
new file mode 100644 (file)
index 0000000..061f795
--- /dev/null
@@ -0,0 +1,184 @@
+/*  Copyright (C) 2014-2020 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
+ *  SPDX-License-Identifier: GPL-3.0-or-later
+ */
+
+#include "daemon/proxyv2.h"
+
+#include "lib/generic/trie.h"
+
+const char PROXY2_SIGNATURE[12] = {
+       0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A
+};
+
+
+/** Gets protocol version from the specified PROXYv2 header. */
+static inline unsigned char proxy2_header_version(const struct proxy2_header* h)
+{
+       return (h->version_command & 0xF0) >> 4;
+}
+
+/** Gets command from the specified PROXYv2 header. */
+static inline enum proxy2_command proxy2_header_command(const struct proxy2_header *h)
+{
+       return h->version_command & 0x0F;
+}
+
+/** Gets address family from the specified PROXYv2 header. */
+static inline enum proxy2_family proxy2_header_family(const struct proxy2_header *h)
+{
+       return (h->family_protocol & 0xF0) >> 4;
+}
+
+/** Gets transport protocol from the specified PROXYv2 header. */
+static inline enum proxy2_family proxy2_header_protocol(const struct proxy2_header *h)
+{
+       return h->family_protocol & 0x0F;
+}
+
+static inline union proxy2_address *proxy2_get_address(const struct proxy2_header *h)
+{
+       return (union proxy2_address *) ((uint8_t *) h + sizeof(struct proxy2_header));
+}
+
+
+bool proxy_allowed(const struct network *net, const struct sockaddr *saddr)
+{
+       union kr_in_addr addr;
+       trie_t *trie;
+       size_t addr_size;
+       switch (saddr->sa_family) {
+       case AF_INET:
+               trie = net->proxy_addrs4;
+               addr_size = sizeof(addr.ip4);
+               addr.ip4 = ((struct sockaddr_in *) saddr)->sin_addr;
+               break;
+       case AF_INET6:
+               trie = net->proxy_addrs6;
+               addr_size = sizeof(addr.ip6);
+               addr.ip6 = ((struct sockaddr_in6 *) saddr)->sin6_addr;
+               break;
+       default:
+               kr_assert(false); // Only IPv4 and IPv6 proxy addresses supported
+               return false;
+       }
+
+       trie_val_t *val;
+       int ret = trie_get_leq(trie, (char *) &addr, addr_size, &val);
+       if (ret != kr_ok() && ret != 1)
+               return false;
+
+       kr_assert(val);
+       const struct net_proxy_data *found = *val;
+       kr_assert(found);
+       return kr_bitcmp((char *) &addr, (char *) &found->addr, found->netmask) == 0;
+}
+
+ssize_t proxy_process_header(struct proxy_result *out, struct session *s,
+               const void *buf, const ssize_t nread)
+{
+       if (!buf)
+               return kr_error(EINVAL);
+
+       const struct proxy2_header *hdr = (struct proxy2_header *) buf;
+
+       uint64_t addr_length = ntohs(hdr->length);
+       ssize_t hdr_len = sizeof(struct proxy2_header) + addr_length;
+
+       /* PROXYv2 requires the header to be received all at once */
+       if (nread < hdr_len) {
+               return kr_error(KNOT_EMALF);
+       }
+
+       unsigned char version = proxy2_header_version(hdr);
+       if (version != 2) {
+               /* Version MUST be 2 for PROXYv2 protocol */
+               return kr_error(KNOT_EMALF);
+       }
+
+       enum proxy2_command command = proxy2_header_command(hdr);
+       if (command == PROXY2_CMD_LOCAL) {
+               /* Addresses for LOCAL are to be discarded */
+               *out = (struct proxy_result) { .command = PROXY2_CMD_LOCAL };
+               goto fill_wirebuf;
+       }
+
+       if (command != PROXY2_CMD_PROXY) {
+               /* PROXYv2 prohibits values other than LOCAL and PROXY */
+               return kr_error(KNOT_EMALF);
+       }
+
+       *out = (struct proxy_result) { .command = PROXY2_CMD_PROXY };
+
+       /* Parse flags */
+       enum proxy2_family family = proxy2_header_family(hdr);
+       switch(family) {
+       case PROXY2_AF_UNSPEC:
+       case PROXY2_AF_UNIX: /* UNIX is unsupported, fall back to UNSPEC */
+               out->family = AF_UNSPEC;
+               break;
+       case PROXY2_AF_INET:
+               out->family = AF_INET;
+               break;
+       case PROXY2_AF_INET6:
+               out->family = AF_INET6;
+               break;
+       default: /* PROXYv2 prohibits other values */
+               return kr_error(KNOT_EMALF);
+       }
+
+       enum proxy2_family protocol = proxy2_header_protocol(hdr);
+       switch (protocol) {
+       case PROXY2_PROTOCOL_DGRAM:
+               out->protocol = SOCK_DGRAM;
+               break;
+       case PROXY2_PROTOCOL_STREAM:
+               out->protocol = SOCK_STREAM;
+               break;
+       default: /* PROXYv2 prohibits other values */
+               return kr_error(KNOT_EMALF);
+       }
+
+       /* Parse addresses */
+       union proxy2_address* addr = proxy2_get_address(hdr);
+       switch(out->family) {
+       case AF_INET:
+               if (addr_length < sizeof(addr->ipv4_addr))
+                       return kr_error(KNOT_EMALF);
+
+               out->src_addr.ip4 = (struct sockaddr_in) {
+                       .sin_family = AF_INET,
+                       .sin_addr = { .s_addr = addr->ipv4_addr.src_addr },
+                       .sin_port = addr->ipv4_addr.src_port,
+               };
+               out->dst_addr.ip4 = (struct sockaddr_in) {
+                       .sin_family = AF_INET,
+                       .sin_addr = { .s_addr = addr->ipv4_addr.dst_addr },
+                       .sin_port = addr->ipv4_addr.dst_port,
+               };
+               break;
+       case AF_INET6:
+               if (addr_length < sizeof(addr->ipv6_addr))
+                       return kr_error(KNOT_EMALF);
+
+               out->src_addr.ip6 = (struct sockaddr_in6) {
+                       .sin6_family = AF_INET6,
+                       .sin6_port = addr->ipv6_addr.src_port
+               };
+               memcpy(
+                               &out->src_addr.ip6.sin6_addr.s6_addr,
+                               &addr->ipv6_addr.src_addr,
+                               sizeof(out->src_addr.ip6.sin6_addr.s6_addr));
+               out->dst_addr.ip6 = (struct sockaddr_in6) {
+                       .sin6_family = AF_INET6,
+                       .sin6_port = addr->ipv6_addr.dst_port
+               };
+               memcpy(
+                               &out->dst_addr.ip6.sin6_addr.s6_addr,
+                               &addr->ipv6_addr.dst_addr,
+                               sizeof(out->dst_addr.ip6.sin6_addr.s6_addr));
+               break;
+       }
+
+fill_wirebuf:
+       return session_wirebuf_trim(s, hdr_len);
+}
diff --git a/daemon/proxyv2.h b/daemon/proxyv2.h
new file mode 100644 (file)
index 0000000..f608ecb
--- /dev/null
@@ -0,0 +1,90 @@
+/*  Copyright (C) 2014-2020 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
+ *  SPDX-License-Identifier: GPL-3.0-or-later
+ */
+
+#pragma once
+
+#include <stdint.h>
+
+#include "daemon/session.h"
+#include "daemon/network.h"
+#include "lib/utils.h"
+
+extern const char PROXY2_SIGNATURE[12];
+
+#define PROXY2_MIN_SIZE 16
+#define PROXY2_IP6_ADDR_SIZE 16
+#define PROXY2_UNIX_ADDR_SIZE 108
+
+enum proxy2_command {
+       PROXY2_CMD_LOCAL = 0x0,
+       PROXY2_CMD_PROXY = 0x1
+};
+
+enum proxy2_family {
+       PROXY2_AF_UNSPEC = 0x0,
+       PROXY2_AF_INET   = 0x1,
+       PROXY2_AF_INET6  = 0x2,
+       PROXY2_AF_UNIX   = 0x3
+};
+
+enum proxy2_protocol {
+       PROXY2_PROTOCOL_UNSPEC = 0x0,
+       PROXY2_PROTOCOL_STREAM = 0x1,
+       PROXY2_PROTOCOL_DGRAM  = 0x2
+};
+
+/** PROXYv2 protocol header section */
+struct proxy2_header {
+       uint8_t signature[sizeof(PROXY2_SIGNATURE)];
+       uint8_t version_command;
+       uint8_t family_protocol;
+       uint16_t length; /**< Length of the address section */
+};
+
+/** PROXYv2 protocol address section */
+union proxy2_address {
+       struct {
+               uint32_t src_addr;
+               uint32_t dst_addr;
+               uint16_t src_port;
+               uint16_t dst_port;
+       } ipv4_addr;
+       struct {
+               uint8_t src_addr[PROXY2_IP6_ADDR_SIZE];
+               uint8_t dst_addr[PROXY2_IP6_ADDR_SIZE];
+               uint16_t src_port;
+               uint16_t dst_port;
+       } ipv6_addr;
+       struct {
+               uint8_t src_addr[PROXY2_UNIX_ADDR_SIZE];
+               uint8_t dst_addr[PROXY2_UNIX_ADDR_SIZE];
+       } unix_addr;
+};
+
+/** Parsed result of the PROXY protocol */
+struct proxy_result {
+       enum proxy2_command command;  /**< Proxy command - PROXY or LOCAL. */
+       int family;                   /**< Address family from netinet library (e.g. AF_INET6). */
+       int protocol;                 /**< Protocol type from socket library (e.g. SOCK_STREAM). */
+       union kr_sockaddr src_addr;   /**< Parsed source address and port. */
+       union kr_sockaddr dst_addr;   /**< Parsed destination address and port. */
+};
+
+/** Checks for a PROXY protocol version 2 signature in the specified buffer. */
+static inline bool proxy_header_present(const void* buf, const ssize_t nread)
+{
+       return nread >= PROXY2_MIN_SIZE &&
+               memcmp(buf, PROXY2_SIGNATURE, sizeof(PROXY2_SIGNATURE)) == 0;
+}
+
+/** Checks whether the use of PROXYv2 protocol is allowed for the specified
+ * address. */
+bool proxy_allowed(const struct network *net, const struct sockaddr *saddr);
+
+/** Parses the PROXYv2 header from buf of size nread and writes the result into
+ * out. The rest of the buffer is moved to free bytes of the specified session's
+ * wire buffer. The function assumes that the PROXYv2 signature is present
+ * and has been already checked by the caller (like `udp_recv` or `tcp_recv`). */
+ssize_t proxy_process_header(struct proxy_result *out, struct session *s,
+               const void *buf, ssize_t nread);
index 781db47046ec4a69c62f5184977657e28d87686f..0c1bbac2f70a0c8121756e3fb30cfb33425de74d 100644 (file)
@@ -534,25 +534,30 @@ int session_timer_stop(struct session *session)
 
 ssize_t session_wirebuf_consume(struct session *session, const uint8_t *data, ssize_t len)
 {
-       if (data != &session->wire_buf[session->wire_buf_end_idx]) {
-               /* shouldn't happen */
+       if (kr_fails_assert(data == &session->wire_buf[session->wire_buf_end_idx]))
                return kr_error(EINVAL);
-       }
-
-       if (len < 0) {
-               /* shouldn't happen */
+       if (kr_fails_assert(len >= 0))
                return kr_error(EINVAL);
-       }
-
-       if (session->wire_buf_end_idx + len > session->wire_buf_size) {
-               /* shouldn't happen */
+       if (kr_fails_assert(session->wire_buf_end_idx + len <= session->wire_buf_size))
                return kr_error(EINVAL);
-       }
 
        session->wire_buf_end_idx += len;
        return len;
 }
 
+ssize_t session_wirebuf_trim(struct session *session, ssize_t len)
+{
+       if (kr_fails_assert(len >= 0))
+               return kr_error(EINVAL);
+       if (kr_fails_assert(session->wire_buf_start_idx + len <= session->wire_buf_size))
+               return kr_error(EINVAL);
+
+       session->wire_buf_start_idx += len;
+       if (session->wire_buf_start_idx > session->wire_buf_end_idx)
+               session->wire_buf_end_idx = session->wire_buf_start_idx;
+       return len;
+}
+
 knot_pkt_t *session_produce_packet(struct session *session, knot_mm_t *mm)
 {
        session->sflags.wirebuf_error = false;
@@ -744,7 +749,9 @@ void session_unpoison(struct session *session)
        kr_asan_unpoison(session, sizeof(*session));
 }
 
-int session_wirebuf_process(struct session *session, const struct sockaddr *peer)
+int session_wirebuf_process(
+               struct session *session, const struct sockaddr *src_addr,
+               const struct sockaddr *comm_addr, const struct sockaddr *dst_addr)
 {
        int ret = 0;
        if (session->wire_buf_start_idx == session->wire_buf_end_idx)
@@ -759,7 +766,7 @@ int session_wirebuf_process(struct session *session, const struct sockaddr *peer
               (ret < max_iterations)) {
                if (kr_fails_assert(!session_wirebuf_error(session)))
                        return -1;
-               int res = worker_submit(session, peer, NULL, NULL, NULL, pkt);
+               int res = worker_submit(session, src_addr, comm_addr, dst_addr, NULL, NULL, pkt);
                /* Errors from worker_submit() are intentionally *not* handled in order to
                 * ensure the entire wire buffer is processed. */
                if (res == kr_ok())
index b7e93b2245caa7220e01bfa05c3c0d7204014212..abe1ddd1f47c8818d837f5a084a3944cad3ac5c6 100644 (file)
@@ -20,6 +20,9 @@ struct session_flags {
        bool has_tls : 1;       /**< True: given session uses TLS. */
        bool has_http : 1;      /**< True: given session uses HTTP. */
        bool connected : 1;     /**< True: TCP connection is established. */
+       bool no_proxy : 1;      /**< True: TCP has gotten some data - PROXYv2 header
+                                * disallowed. Proxy headers are only expected at
+                                * the very start of a stream. */
        bool closing : 1;       /**< True: session close sequence is in progress. */
        bool wirebuf_error : 1; /**< True: last operation with wirebuf ended up with an error. */
 };
@@ -128,10 +131,15 @@ size_t session_wirebuf_get_free_size(struct session *session);
 void session_wirebuf_discard(struct session *session);
 /** Move all data to the beginning of the buffer. */
 void session_wirebuf_compress(struct session *session);
-int session_wirebuf_process(struct session *session, const struct sockaddr *peer);
+int session_wirebuf_process(
+               struct session *session, const struct sockaddr *src_addr,
+               const struct sockaddr *comm_addr, const struct sockaddr *dst_addr);
 ssize_t session_wirebuf_consume(struct session *session,
                                const uint8_t *data, ssize_t len);
-
+/** Trims `len` bytes from the start of the session's wire buffer.
+ * If this operation makes the buffer's end appear before the start, it gets
+ * nudged to the same position as the start. */
+ssize_t session_wirebuf_trim(struct session *session, ssize_t len);
 /** poison session structure with ASAN. */
 void session_poison(struct session *session);
 /** unpoison session structure with ASAN. */
index 04225bcfb73e78e59898d81a52631af7394e09b2..b466f2f3263be925a5c7ecfcbf1fa3ce7edc46ce 100644 (file)
@@ -120,7 +120,7 @@ void udp_queue_push(int fd, struct kr_request *req, struct qr_task *task)
        udp_queue_t *const q = state.udp_queues[fd];
 
        /* Append to the queue */
-       struct sockaddr *sa = (struct sockaddr *)/*const-cast*/req->qsource.addr;
+       struct sockaddr *sa = (struct sockaddr *)/*const-cast*/req->qsource.comm_addr;
        q->msgvec[q->len].msg_hdr.msg_name = sa;
        q->msgvec[q->len].msg_hdr.msg_namelen = kr_sockaddr_len(sa);
        q->items[q->len].task = task;
index 5e74c3b2ed568e32e55f0f51c3fccab457613690..0f9983d6064a096b04ca2f8d92663f7d5996ef79 100644 (file)
@@ -65,6 +65,8 @@ struct request_ctx
                struct session *session;
                /** Requestor's address; separate because of UDP session "sharing". */
                union kr_sockaddr addr;
+               /** Request communication address; if not from a proxy, same as addr. */
+               union kr_sockaddr comm_addr;
                /** Local address.  For AF_XDP we couldn't use session's,
                 * as the address might be different every time. */
                union kr_sockaddr dst_addr;
@@ -289,7 +291,7 @@ static uint8_t *alloc_wire_cb(struct kr_request *req, uint16_t *maxlen)
                return NULL;
        xdp_handle_data_t *xhd = handle->data;
        knot_xdp_msg_t out;
-       bool ipv6 = ctx->source.addr.ip.sa_family == AF_INET6;
+       bool ipv6 = ctx->source.comm_addr.ip.sa_family == AF_INET6;
        int ret = knot_xdp_send_alloc(xhd->socket,
                        #if KNOT_VERSION_HEX >= 0x030100
                                        ipv6 ? KNOT_XDP_MSG_IPV6 : 0, &out);
@@ -353,6 +355,7 @@ static inline bool is_tcp_waiting(struct sockaddr *address) {
 static struct request_ctx *request_create(struct worker_ctx *worker,
                                          struct session *session,
                                          const struct sockaddr *addr,
+                                         const struct sockaddr *comm_addr,
                                          const struct sockaddr *dst_addr,
                                          const uint8_t *eth_from,
                                          const uint8_t *eth_to,
@@ -425,6 +428,10 @@ static struct request_ctx *request_create(struct worker_ctx *worker,
                /* We need to store a copy of peer address. */
                memcpy(&ctx->source.addr.ip, addr, kr_sockaddr_len(addr));
                req->qsource.addr = &ctx->source.addr.ip;
+               if (!comm_addr)
+                       comm_addr = addr;
+               memcpy(&ctx->source.comm_addr.ip, comm_addr, kr_sockaddr_len(comm_addr));
+               req->qsource.comm_addr = &ctx->source.comm_addr.ip;
                if (!dst_addr) /* We wouldn't have to copy in this case, but for consistency. */
                        dst_addr = session_get_sockname(session);
                memcpy(&ctx->source.dst_addr.ip, dst_addr, kr_sockaddr_len(dst_addr));
@@ -1377,7 +1384,7 @@ static int xdp_push(struct qr_task *task, const uv_handle_t *src_handle)
 
        knot_xdp_msg_t msg;
        const struct sockaddr *ip_from = &ctx->source.dst_addr.ip;
-       const struct sockaddr *ip_to   = &ctx->source.addr.ip;
+       const struct sockaddr *ip_to   = &ctx->source.comm_addr.ip;
        memcpy(&msg.ip_from, ip_from, kr_sockaddr_len(ip_from));
        memcpy(&msg.ip_to,   ip_to,   kr_sockaddr_len(ip_to));
        msg.payload.iov_base = ctx->req.answer->wire;
@@ -1443,7 +1450,7 @@ static int qr_task_finalize(struct qr_task *task, int state)
                else
                        kr_assert(false);
        } else {
-               ret = qr_task_send(task, source_session, &ctx->source.addr.ip, ctx->req.answer);
+               ret = qr_task_send(task, source_session, &ctx->source.comm_addr.ip, ctx->req.answer);
        }
 
        if (ret != kr_ok()) {
@@ -1796,7 +1803,8 @@ static int parse_packet(knot_pkt_t *query)
 }
 
 int worker_submit(struct session *session,
-                 const struct sockaddr *peer, const struct sockaddr *dst_addr,
+                 const struct sockaddr *src_addr, const struct sockaddr *comm_addr,
+                 const struct sockaddr *dst_addr,
                  const uint8_t *eth_from, const uint8_t *eth_to, knot_pkt_t *pkt)
 {
        if (!session || !pkt)
@@ -1841,7 +1849,7 @@ int worker_submit(struct session *session,
        const struct sockaddr *addr = NULL;
        if (!is_outgoing) { /* request from a client */
                struct request_ctx *ctx =
-                       request_create(the_worker, session, peer, dst_addr,
+                       request_create(the_worker, session, src_addr, comm_addr, dst_addr,
                                        eth_from, eth_to, knot_wire_get_id(pkt->wire));
                if (http_ctx)
                        queue_pop(http_ctx->streams);
@@ -1873,7 +1881,7 @@ int worker_submit(struct session *session,
                }
                if (kr_fails_assert(!session_flags(session)->closing))
                        return kr_error(EINVAL);
-               addr = peer;
+               addr = src_addr;
                /* Note receive time for RTT calculation */
                task->recv_time = kr_now();
        }
@@ -2081,7 +2089,7 @@ struct qr_task *worker_resolve_start(knot_pkt_t *query, struct kr_qflags options
                return NULL;
 
 
-       struct request_ctx *ctx = request_create(worker, NULL, NULL, NULL, NULL, NULL,
+       struct request_ctx *ctx = request_create(worker, NULL, NULL, NULL, NULL, NULL, NULL,
                                                 worker->next_request_uid);
        if (!ctx)
                return NULL;
index 2469d5ae1b3f0260fe35fb3d21d7ef5a1aaf1863..5543ab70586613876b28a759da1cbc993a757b8b 100644 (file)
@@ -31,14 +31,17 @@ void worker_deinit(void);
 /**
  * Process an incoming packet (query from a client or answer from upstream).
  *
- * @param session  session the packet came from, or NULL (not from network)
- * @param peer     address the packet came from, or NULL (not from network)
- * @param eth_*    MAC addresses or NULL (they're useful for XDP)
- * @param pkt      the packet, or NULL (an error from the transport layer)
+ * @param session     session the packet came from, or NULL (not from network)
+ * @param src_addr    original address the packet came from, or NULL (not from network)
+ * @param comm_addr   actual address the packet came from, or NULL (then the same as src_addr).
+ *     May be different from peer if the packet went through a proxy with PROXYv2 enabled.
+ * @param eth_*       MAC addresses or NULL (they're useful for XDP)
+ * @param pkt         the packet, or NULL (an error from the transport layer)
  * @return 0 or an error code
  */
 int worker_submit(struct session *session,
-                 const struct sockaddr *peer, const struct sockaddr *dst_addr,
+                 const struct sockaddr *src_addr, const struct sockaddr *comm_addr,
+                 const struct sockaddr *dst_addr,
                  const uint8_t *eth_from, const uint8_t *eth_to, knot_pkt_t *pkt);
 
 /**
index 7d6cdd19bfb6fb64809903f73bddca4bdd609304..437812ef5ec83f377a561737010e1d57f555ec9f 100644 (file)
@@ -212,6 +212,9 @@ struct kr_request {
        struct {
                /** Address that originated the request. NULL for internal origin. */
                const struct sockaddr *addr;
+               /** Address that communicated the request (e.g. a proxy). Same as
+                * addr if no proxy is used. */
+               const struct sockaddr *comm_addr;
                /** Address that accepted the request.  NULL for internal origin.
                 * Beware: in case of UDP on wildcard address it will be wildcard;
                 * closely related: issue #173. */
index 14843a8c1e08b3342b25c95de351f0f64157557b..a4b846faa5ae49c92e7da3de5b42019afa1c15e9 100644 (file)
@@ -67,6 +67,41 @@ static void test_straddr(void **state)
        assert_int_not_equal(test_bitcmp(ip6_sub, ip6_out, 4), 0);
 }
 
+static inline int assert_bitmask(const char *addr, const char *exp_masked)
+{
+       unsigned char addr_buf[16];
+       unsigned char exp_masked_buf[16];
+
+       int bits = kr_straddr_subnet(addr_buf, addr);
+       size_t addr_len = (kr_straddr_family(addr) == AF_INET6) ? 16 : 4;
+       int exp_masked_bits = kr_straddr_subnet(exp_masked_buf, exp_masked);
+       size_t exp_masked_len = (kr_straddr_family(exp_masked) == AF_INET6) ? 16 : 4;
+
+       /* sanity checks */
+       assert_true(bits >= 0);
+       assert_int_equal(addr_len, exp_masked_len);
+       assert_int_equal(exp_masked_bits, exp_masked_len * 8);
+
+       kr_bitmask(addr_buf, addr_len, bits);
+       return memcmp(addr_buf, exp_masked_buf, addr_len);
+}
+
+static void test_bitmask(void **state)
+{
+       assert_int_equal(assert_bitmask("10.0.1.5/32", "10.0.1.5"), 0);
+       assert_int_equal(assert_bitmask("10.0.1.5", "10.0.1.5"), 0);
+       assert_int_equal(assert_bitmask("10.0.1.5/24", "10.0.1.0"), 0);
+       assert_int_equal(assert_bitmask("128.30.1.16/16", "128.30.0.0"), 0);
+       assert_int_equal(assert_bitmask("255.255.255.255/20", "255.255.240.0"), 0);
+       assert_int_equal(assert_bitmask("255.255.255.255/22", "255.255.252.0"), 0);
+       assert_int_equal(assert_bitmask("192.168.0.1/0", "0.0.0.0"), 0);
+       assert_int_equal(assert_bitmask("7caa::/4", "7000::"), 0);
+       assert_int_equal(assert_bitmask("dead:beef::/16", "dead::"), 0);
+       assert_int_equal(assert_bitmask("dead:beef::/20", "dead:b000::"), 0);
+       assert_int_equal(assert_bitmask("dead:beef::/0", "::"), 0);
+       assert_int_equal(assert_bitmask("64aa:22fa:1378:aaaa:bbbb::/36", "64aa:22fa:1000::"), 0);
+}
+
 static void test_strptime_diff(void **state)
 {
        char *format = "%Y-%m-%dT%H:%M:%S";
@@ -104,7 +139,8 @@ int main(void)
        const UnitTest tests[] = {
                unit_test(test_strcatdup),
                unit_test(test_straddr),
-               unit_test(test_strptime_diff),
+               unit_test(test_bitmask),
+               unit_test(test_strptime_diff)
        };
 
        return run_tests(tests);
index 4b6ff55894a180ed8a330ebc781efa72bfb16a62..9ba845d2ffd44963b9e49e74f86373b5f28272d3 100644 (file)
@@ -552,6 +552,22 @@ int kr_bitcmp(const char *a, const char *b, int bits)
        return ret;
 }
 
+void kr_bitmask(unsigned char *a, size_t a_len, int bits)
+{
+       if (bits < 0 || !a || !a_len) {
+               return;
+       }
+
+       size_t i = bits / 8;
+       const size_t mid_bits = 8 - (bits % 8);
+       const unsigned char mask = 0xFF << mid_bits;
+       if (i < a_len)
+               a[i] &= mask;
+
+       for (++i; i < a_len; ++i)
+               a[i] = 0;
+}
+
 int kr_rrkey(char *key, uint16_t class, const knot_dname_t *owner,
             uint16_t type, uint16_t additional)
 {
index 89000ff063c41a77f85d1949d8e383e54b964379..47e13807cee3cfdd5676e8eb7719b8b375ea7edb 100644 (file)
@@ -240,6 +240,12 @@ union kr_sockaddr {
        struct sockaddr_in6 ip6;
 };
 
+/** Simple storage for IPx addresses. */
+union kr_in_addr {
+       struct in_addr ip4;
+       struct in6_addr ip6;
+};
+
 /** Address bytes for given family. */
 KR_EXPORT KR_PURE
 const char *kr_inaddr(const struct sockaddr *addr);
@@ -337,6 +343,12 @@ int kr_straddr_join(const char *addr, uint16_t port, char *buf, size_t *buflen);
 KR_EXPORT KR_PURE
 int kr_bitcmp(const char *a, const char *b, int bits);
 
+/** Masks bits. The specified number of bits in `a` from the left (network order)
+ * will remain their original value, while the rest will be set to zero.
+ * This is useful for storing network addresses in a trie. */
+KR_EXPORT
+void kr_bitmask(unsigned char *a, size_t a_len, int bits);
+
 /** @internal RR map flags. */
 static const uint8_t KEY_FLAG_RRSIG = 0x02;
 static inline uint8_t KEY_FLAG_RANK(const char *key)