From: Y7n05h Date: Wed, 17 Aug 2022 14:18:11 +0000 (+0800) Subject: dnsdist: add AF_XDP support for udp X-Git-Tag: dnsdist-1.9.0-rc1^2~50 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=a61dd3f3b6fed8cc1233b101cdb538593f850aeb;p=thirdparty%2Fpdns.git dnsdist: add AF_XDP support for udp Signed-off-by: Y7n05h --- diff --git a/contrib/xdp-filter.ebpf.src b/contrib/xdp-filter.ebpf.src index 3ead6e151b..8577b08c08 100644 --- a/contrib/xdp-filter.ebpf.src +++ b/contrib/xdp-filter.ebpf.src @@ -17,6 +17,24 @@ BPF_TABLE_PINNED("prog", int, int, progsarray, 2, "/sys/fs/bpf/dnsdist/progs"); BPF_TABLE_PINNED7("lpm_trie", struct CIDR4, struct map_value, cidr4filter, 1024, "/sys/fs/bpf/dnsdist/cidr4", BPF_F_NO_PREALLOC); BPF_TABLE_PINNED7("lpm_trie", struct CIDR6, struct map_value, cidr6filter, 1024, "/sys/fs/bpf/dnsdist/cidr6", BPF_F_NO_PREALLOC); +#ifdef UseXsk +#define BPF_XSKMAP_PIN(_name, _max_entries, _pinned) \ + struct _name##_table_t \ + { \ + u32 key; \ + int leaf; \ + int* (*lookup)(int*); \ + /* xdp_act = map.redirect_map(index, flag) */ \ + u64 (*redirect_map)(int, int); \ + u32 max_entries; \ + }; \ + __attribute__((section("maps/xskmap:" _pinned))) struct _name##_table_t _name = {.max_entries = (_max_entries)} + +BPF_XSKMAP_PIN(xsk_map, 16, "/sys/fs/bpf/dnsdist/xskmap"); +#endif /* UseXsk */ + +#define COMPARE_PORT(x, p) ((x) == bpf_htons(p)) + /* * Recalculate the checksum * Copyright 2020, NLnet Labs, All rights reserved. @@ -39,7 +57,7 @@ static inline void update_checksum(uint16_t *csum, uint16_t old_val, uint16_t ne * Set the TC bit and swap UDP ports * Copyright 2020, NLnet Labs, All rights reserved. */ -static inline enum dns_action set_tc_bit(struct udphdr *udp, struct dnshdr *dns) +static inline void set_tc_bit(struct udphdr* udp, struct dnshdr* dns) { uint16_t old_val = dns->flags.as_value; @@ -49,14 +67,12 @@ static inline enum dns_action set_tc_bit(struct udphdr *udp, struct dnshdr *dns) dns->flags.as_bits_and_pieces.tc = 1; // change the UDP destination to the source - udp->dest = udp->source; - udp->source = bpf_htons(DNS_PORT); + uint16_t tmp = udp->dest; + udp->dest = udp->source; + udp->source = tmp; // calculate and write the new checksum update_checksum(&udp->check, old_val, dns->flags.as_value); - - // bounce - return TC; } /* @@ -65,41 +81,43 @@ static inline enum dns_action set_tc_bit(struct udphdr *udp, struct dnshdr *dns) * TC if (modified) message needs to be replied * DROP if message needs to be blocke */ -static inline enum dns_action check_qname(struct cursor *c) +static inline struct map_value* check_qname(struct cursor* c) { struct dns_qname qkey = {0}; uint8_t qname_byte; uint16_t qtype; int length = 0; - for(int i = 0; i<255; i++) { - if (bpf_probe_read_kernel(&qname_byte, sizeof(qname_byte), c->pos)) { - return PASS; - } - c->pos += 1; - if (length == 0) { - if (qname_byte == 0 || qname_byte > 63 ) { - break; + for (int i = 0; i < 255; i++) { + if (bpf_probe_read_kernel(&qname_byte, sizeof(qname_byte), c->pos)) { + return NULL; + } + c->pos += 1; + if (length == 0) { + if (qname_byte == 0 || qname_byte > 63) { + break; } length += qname_byte; - } else { + } + else { length--; } - if (qname_byte >= 'A' && qname_byte <= 'Z') { - qkey.qname[i] = qname_byte + ('a' - 'A'); - } else { - qkey.qname[i] = qname_byte; - } + if (qname_byte >= 'A' && qname_byte <= 'Z') { + qkey.qname[i] = qname_byte + ('a' - 'A'); + } + else { + qkey.qname[i] = qname_byte; + } } // if the last read qbyte is not 0 incorrect QName format), return PASS if 
(qname_byte != 0) { - return PASS; + return NULL; } // get QType - if(bpf_probe_read_kernel(&qtype, sizeof(qtype), c->pos)) { - return PASS; + if (bpf_probe_read_kernel(&qtype, sizeof(qtype), c->pos)) { + return NULL; } struct map_value* value; @@ -108,125 +126,184 @@ static inline enum dns_action check_qname(struct cursor *c) qkey.qtype = bpf_htons(qtype); value = qnamefilter.lookup(&qkey); if (value) { - __sync_fetch_and_add(&value->counter, 1); - return value->action; + return value; } // check with Qtype 255 (*) qkey.qtype = 255; - value = qnamefilter.lookup(&qkey); - if (value) { - __sync_fetch_and_add(&value->counter, 1); - return value->action; - } - - return PASS; + return qnamefilter.lookup(&qkey); } /* * Parse IPv4 DNS mesage. - * Returns PASS if message needs to go through (i.e. pass) - * TC if (modified) message needs to be replied - * DROP if message needs to be blocked + * Returns XDP_PASS if message needs to go through (i.e. pass) + * XDP_REDIRECT if message needs to be redirected (for AF_XDP, which needs to be translated to the caller into XDP_PASS outside of the AF_XDP) + * XDP_TX if (modified) message needs to be replied + * XDP_DROP if message needs to be blocked */ -static inline enum dns_action udp_dns_reply_v4(struct cursor *c, struct CIDR4 *key) +static inline enum xdp_action parseIPV4(struct xdp_md* ctx, struct cursor* c) { - struct udphdr *udp; - struct dnshdr *dns; + struct iphdr* ipv4; + struct udphdr* udp = NULL; + struct dnshdr* dns = NULL; + if (!(ipv4 = parse_iphdr(c))) { + return XDP_PASS; + } + switch (ipv4->protocol) { + case IPPROTO_UDP: { + if (!(udp = parse_udphdr(c))) { + return XDP_PASS; + } + if (!IN_DNS_PORT_SET(udp->dest)) { + return XDP_PASS; + } + if (!(dns = parse_dnshdr(c))) { + return XDP_DROP; + } + break; + } - if (!(udp = parse_udphdr(c)) || udp->dest != bpf_htons(DNS_PORT)) { - return PASS; +#ifdef UseXsk + case IPPROTO_TCP: { + struct tcphdr* tcp; + if (!(tcp = parse_tcphdr(c))) { + return XDP_PASS; + } + if (!IN_DNS_PORT_SET(tcp->dest)) { + return XDP_PASS; + } + } +#endif /* UseXsk */ + + default: + return XDP_PASS; } - // check that we have a DNS packet - if (!(dns = parse_dnshdr(c))) { - return PASS; - } + struct CIDR4 key; + key.addr = bpf_htonl(ipv4->saddr); // if the address is blocked, perform the corresponding action - struct map_value* value = v4filter.lookup(&key->addr); + struct map_value* value = v4filter.lookup(&key.addr); if (value) { - __sync_fetch_and_add(&value->counter, 1); - if (value->action == TC) { - return set_tc_bit(udp, dns); - } else { - return value->action; - } + goto res; + } + + key.cidr = 32; + key.addr = bpf_htonl(key.addr); + value = cidr4filter.lookup(&key); + if (value) { + goto res; } - key->cidr = 32; - key->addr = bpf_htonl(key->addr); - value = cidr4filter.lookup(key); + if (dns) { + value = check_qname(c); + } if (value) { + res: __sync_fetch_and_add(&value->counter, 1); - if (value->action == TC) { - return set_tc_bit(udp, dns); + if (value->action == TC && udp && dns) { + set_tc_bit(udp, dns); + // swap src/dest IP addresses + uint32_t swap_ipv4 = ipv4->daddr; + ipv4->daddr = ipv4->saddr; + ipv4->saddr = swap_ipv4; + + progsarray.call(ctx, 1); + return XDP_TX; } - else { - return value->action; + + if (value->action == DROP) { + progsarray.call(ctx, 0); + return XDP_DROP; } } - enum dns_action action = check_qname(c); - if (action == TC) { - return set_tc_bit(udp, dns); - } - return action; + return XDP_REDIRECT; } /* * Parse IPv6 DNS mesage. - * Returns PASS if message needs to go through (i.e. 
pass) - * TC if (modified) message needs to be replied - * DROP if message needs to be blocked + * Returns XDP_PASS if message needs to go through (i.e. pass) + * XDP_REDIRECT if message needs to be redirected (for AF_XDP, which needs to be translated to the caller into XDP_PASS outside of the AF_XDP) + * XDP_TX if (modified) message needs to be replied + * XDP_DROP if message needs to be blocked */ -static inline enum dns_action udp_dns_reply_v6(struct cursor *c, struct CIDR6* key) +static inline enum xdp_action parseIPV6(struct xdp_md* ctx, struct cursor* c) { - struct udphdr *udp; - struct dnshdr *dns; + struct ipv6hdr* ipv6; + struct udphdr* udp = NULL; + struct dnshdr* dns = NULL; + if (!(ipv6 = parse_ipv6hdr(c))) { + return XDP_PASS; + } + switch (ipv6->nexthdr) { + case IPPROTO_UDP: { + if (!(udp = parse_udphdr(c))) { + return XDP_PASS; + } + if (!IN_DNS_PORT_SET(udp->dest)) { + return XDP_PASS; + } + if (!(dns = parse_dnshdr(c))) { + return XDP_DROP; + } + break; + } - - if (!(udp = parse_udphdr(c)) || udp->dest != bpf_htons(DNS_PORT)) { - return PASS; +#ifdef UseXsk + case IPPROTO_TCP: { + struct tcphdr* tcp; + if (!(tcp = parse_tcphdr(c))) { + return XDP_PASS; + } + if (!IN_DNS_PORT_SET(tcp->dest)) { + return XDP_PASS; + } } +#endif /* UseXsk */ - // check that we have a DNS packet - ; - if (!(dns = parse_dnshdr(c))) { - return PASS; + default: + return XDP_PASS; } + struct CIDR6 key; + key.addr = ipv6->saddr; + // if the address is blocked, perform the corresponding action - struct map_value* value = v6filter.lookup(&key->addr); + struct map_value* value = v6filter.lookup(&key.addr); + if (value) { + goto res; + } + key.cidr = 128; + value = cidr6filter.lookup(&key); if (value) { - __sync_fetch_and_add(&value->counter, 1); - if (value->action == TC) { - return set_tc_bit(udp, dns); - } else { - return value->action; - } + goto res; } - key->cidr = 128; - value = cidr6filter.lookup(key); + if (dns) { + value = check_qname(c); + } if (value) { + res: __sync_fetch_and_add(&value->counter, 1); - if (value->action == TC) { - return set_tc_bit(udp, dns); + if (value->action == TC && udp && dns) { + set_tc_bit(udp, dns); + // swap src/dest IP addresses + struct in6_addr swap_ipv6 = ipv6->daddr; + ipv6->daddr = ipv6->saddr; + ipv6->saddr = swap_ipv6; + progsarray.call(ctx, 1); + return XDP_TX; } - else { - return value->action; + if (value->action == DROP) { + progsarray.call(ctx, 0); + return XDP_DROP; } } - - enum dns_action action = check_qname(c); - if (action == TC) { - return set_tc_bit(udp, dns); - } - return action; + return XDP_REDIRECT; } int xdp_dns_filter(struct xdp_md* ctx) @@ -235,9 +312,7 @@ int xdp_dns_filter(struct xdp_md* ctx) struct cursor c; struct ethhdr *eth; uint16_t eth_proto; - struct iphdr *ipv4; - struct ipv6hdr *ipv6; - int r = 0; + enum xdp_action r; // initialise the cursor cursor_init(&c, ctx); @@ -245,67 +320,36 @@ int xdp_dns_filter(struct xdp_md* ctx) // pass the packet if it is not an ethernet one if ((eth = parse_eth(&c, ð_proto))) { // IPv4 packets - if (eth_proto == bpf_htons(ETH_P_IP)) - { - if (!(ipv4 = parse_iphdr(&c)) || bpf_htons(ipv4->protocol != IPPROTO_UDP)) { - return XDP_PASS; - } - - struct CIDR4 key; - key.addr = bpf_htonl(ipv4->saddr); - // if TC bit must not be set, apply the action - if ((r = udp_dns_reply_v4(&c, &key)) != TC) { - if (r == DROP) { - progsarray.call(ctx, 0); - return XDP_DROP; - } - return XDP_PASS; - } - - // swap src/dest IP addresses - uint32_t swap_ipv4 = ipv4->daddr; - ipv4->daddr = ipv4->saddr; - ipv4->saddr = 
swap_ipv4; + if (eth_proto == bpf_htons(ETH_P_IP)) { + r = parseIPV4(ctx, &c); + goto res; } // IPv6 packets else if (eth_proto == bpf_htons(ETH_P_IPV6)) { - if (!(ipv6 = parse_ipv6hdr(&c)) || bpf_htons(ipv6->nexthdr != IPPROTO_UDP)) { - return XDP_PASS; - } - struct CIDR6 key; - key.addr = ipv6->saddr; - - // if TC bit must not be set, apply the action - if ((r = udp_dns_reply_v6(&c, &key)) != TC) { - if (r == DROP) { - progsarray.call(ctx, 0); - return XDP_DROP; - } - return XDP_PASS; - } - - // swap src/dest IP addresses - struct in6_addr swap_ipv6 = ipv6->daddr; - ipv6->daddr = ipv6->saddr; - ipv6->saddr = swap_ipv6; + r = parseIPV6(ctx, &c); + goto res; } // pass all non-IP packets - else { - return XDP_PASS; - } + return XDP_PASS; } - else { + return XDP_PASS; +res: + switch (r) { + case XDP_REDIRECT: +#ifdef UseXsk + return xsk_map.redirect_map(ctx->rx_queue_index, 0); +#else return XDP_PASS; +#endif /* UseXsk */ + case XDP_TX: { // swap MAC addresses + uint8_t swap_eth[ETH_ALEN]; + memcpy(swap_eth, eth->h_dest, ETH_ALEN); + memcpy(eth->h_dest, eth->h_source, ETH_ALEN); + memcpy(eth->h_source, swap_eth, ETH_ALEN); + // bounce the request + return XDP_TX; + } + default: + return r; } - - // swap MAC addresses - uint8_t swap_eth[ETH_ALEN]; - memcpy(swap_eth, eth->h_dest, ETH_ALEN); - memcpy(eth->h_dest, eth->h_source, ETH_ALEN); - memcpy(eth->h_source, swap_eth, ETH_ALEN); - - progsarray.call(ctx, 1); - - // bounce the request - return XDP_TX; } diff --git a/contrib/xdp.py b/contrib/xdp.py index 6384c3f8b8..bd96ddb1ce 100644 --- a/contrib/xdp.py +++ b/contrib/xdp.py @@ -27,7 +27,15 @@ blocked_cidr6 = [("2001:db8::1/128", TC_ACTION)] blocked_qnames = [("localhost", "A", DROP_ACTION), ("test.com", "*", TC_ACTION)] # Main -xdp = BPF(src_file="xdp-filter.ebpf.src") +useXsk = True +Ports = [53] +cflag = [] +if useXsk: + cflag.append("-DUseXsk") +IN_DNS_PORT_SET = "||".join("COMPARE_PORT((x),"+str(i)+")" for i in Ports) +cflag.append(r"-DIN_DNS_PORT_SET(x)=(" + IN_DNS_PORT_SET + r")") + +xdp = BPF(src_file="xdp-filter.ebpf.src", cflags=cflag) fn = xdp.load_func("xdp_dns_filter", BPF.XDP) xdp.attach_xdp(DEV, fn, 0) diff --git a/ext/libbpf/libbpf.h b/ext/libbpf/libbpf.h index 2fc7281909..f429545a0b 100644 --- a/ext/libbpf/libbpf.h +++ b/ext/libbpf/libbpf.h @@ -8,19 +8,6 @@ extern "C" { struct bpf_insn; -int bpf_create_map(enum bpf_map_type map_type, int key_size, int value_size, - int max_entries, int map_flags); -int bpf_update_elem(int fd, void *key, void *value, unsigned long long flags); -int bpf_lookup_elem(int fd, void *key, void *value); -int bpf_delete_elem(int fd, void *key); -int bpf_get_next_key(int fd, void *key, void *next_key); - -int bpf_prog_load(enum bpf_prog_type prog_type, - const struct bpf_insn *insns, int insn_len, - const char *license, int kern_version); - -int bpf_obj_pin(int fd, const char *pathname); -int bpf_obj_get(const char *pathname); #define LOG_BUF_SIZE 65536 extern char bpf_log_buf[LOG_BUF_SIZE]; diff --git a/pdns/bpf-filter.cc b/pdns/bpf-filter.cc index ec6bd05c55..19343955c8 100644 --- a/pdns/bpf-filter.cc +++ b/pdns/bpf-filter.cc @@ -33,6 +33,20 @@ #include "misc.hh" +int bpf_create_map(enum bpf_map_type map_type, int key_size, int value_size, + int max_entries, int map_flags); +int bpf_update_elem(int fd, void *key, void *value, unsigned long long flags); +int bpf_lookup_elem(int fd, void *key, void *value); +int bpf_delete_elem(int fd, void *key); +int bpf_get_next_key(int fd, void *key, void *next_key); + +int bpf_prog_load(enum bpf_prog_type 
prog_type, + const struct bpf_insn *insns, int insn_len, + const char *license, int kern_version); + +int bpf_obj_pin(int fd, const char *pathname); +int bpf_obj_get(const char *pathname); + static __u64 ptr_to_u64(void *ptr) { return (__u64) (unsigned long) ptr; diff --git a/pdns/dnsdist-idstate.hh b/pdns/dnsdist-idstate.hh index 73d5f6e5e3..49248a08c7 100644 --- a/pdns/dnsdist-idstate.hh +++ b/pdns/dnsdist-idstate.hh @@ -134,6 +134,7 @@ struct InternalQueryState std::unique_ptr d_packet{nullptr}; // Initial packet, so we can restart the query from the response path if needed // 8 std::unique_ptr d_protoBufData{nullptr}; std::unique_ptr d_extendedError{nullptr}; + std::unique_ptr xskPacketHeader; // 8 boost::optional tempFailureTTL{boost::none}; // 8 ClientState* cs{nullptr}; // 8 std::unique_ptr du; // 8 diff --git a/pdns/dnsdist-lua-bindings.cc b/pdns/dnsdist-lua-bindings.cc index 79ba4ec57b..130a71153d 100644 --- a/pdns/dnsdist-lua-bindings.cc +++ b/pdns/dnsdist-lua-bindings.cc @@ -28,6 +28,7 @@ #include "dnsdist-svc.hh" #include "dolog.hh" +#include "xsk.hh" // NOLINTNEXTLINE(readability-function-cognitive-complexity): this function declares Lua bindings, even with a good refactoring it will likely blow up the threshold void setupLuaBindings(LuaContext& luaCtx, bool client, bool configCheck) @@ -715,7 +716,53 @@ void setupLuaBindings(LuaContext& luaCtx, bool client, bool configCheck) } }); #endif /* HAVE_EBPF */ - +#ifdef HAVE_XSK + using xskopt_t = LuaAssociativeTable>; + luaCtx.writeFunction("newXsk", [client](xskopt_t opts) { + if (g_configurationDone) { + throw std::runtime_error("newXsk() only can be used at configuration time!"); + } + if (client) { + return std::shared_ptr(nullptr); + } + uint32_t queue_id; + uint32_t frameNums; + std::string ifName; + std::string path; + std::string poolName; + if (opts.count("NIC_queue_id") == 1) { + queue_id = boost::get(opts.at("NIC_queue_id")); + } + else { + throw std::runtime_error("NIC_queue_id field is required!"); + } + if (opts.count("frameNums") == 1) { + frameNums = boost::get(opts.at("frameNums")); + } + else { + throw std::runtime_error("frameNums field is required!"); + } + if (opts.count("ifName") == 1) { + ifName = boost::get(opts.at("ifName")); + } + else { + throw std::runtime_error("ifName field is required!"); + } + if (opts.count("xskMapPath") == 1) { + path = boost::get(opts.at("xskMapPath")); + } + else { + throw std::runtime_error("xskMapPath field is required!"); + } + if (opts.count("pool") == 1) { + poolName = boost::get(opts.at("pool")); + } + extern std::vector> g_xsk; + auto socket = std::make_shared(frameNums, ifName, queue_id, path, poolName); + g_xsk.push_back(socket); + return socket; + }); +#endif /* HAVE_XSK */ /* EDNSOptionView */ luaCtx.registerFunction("count", [](const EDNSOptionView& option) { return option.values.size(); diff --git a/pdns/dnsdist-lua.cc b/pdns/dnsdist-lua.cc index 54b7109b19..ac6e6e84e8 100644 --- a/pdns/dnsdist-lua.cc +++ b/pdns/dnsdist-lua.cc @@ -21,6 +21,7 @@ */ #include +#include #include #include #include @@ -45,6 +46,7 @@ #include "dnsdist-ecs.hh" #include "dnsdist-healthchecks.hh" #include "dnsdist-lua.hh" +#include "xsk.hh" #ifdef LUAJIT_VERSION #include "dnsdist-lua-ffi.hh" #endif /* LUAJIT_VERSION */ @@ -110,7 +112,7 @@ void resetLuaSideEffect() g_noLuaSideEffect = boost::logic::indeterminate; } -using localbind_t = LuaAssociativeTable, LuaArray, LuaAssociativeTable>>; +using localbind_t = LuaAssociativeTable, LuaArray, LuaAssociativeTable, std::shared_ptr>>; static void 
parseLocalBindVars(boost::optional& vars, bool& reusePort, int& tcpFastOpenQueueSize, std::string& interface, std::set& cpus, int& tcpListenQueueSize, uint64_t& maxInFlightQueriesPerConnection, uint64_t& tcpMaxConcurrentConnections, bool& enableProxyProtocol) { @@ -131,6 +133,16 @@ static void parseLocalBindVars(boost::optional& vars, bool& reusePo } } } +#ifdef HAVE_XSK +static void parseXskVars(boost::optional& vars, std::shared_ptr& socket) +{ + if (!vars) { + return; + } + + getOptionalValue>(vars, "xskSocket", socket); +} +#endif /* HAVE_XSK */ #if defined(HAVE_DNS_OVER_TLS) || defined(HAVE_DNS_OVER_HTTPS) || defined(HAVE_DNS_OVER_QUIC) static bool loadTLSCertificateAndKeys(const std::string& context, std::vector& pairs, const boost::variant, LuaArray, LuaArray>>& certFiles, const LuaTypeOrArrayOf& keyFiles) @@ -298,7 +310,7 @@ static bool checkConfigurationTime(const std::string& name) // NOLINTNEXTLINE(readability-function-cognitive-complexity): this function declares Lua bindings, even with a good refactoring it will likely blow up the threshold static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) { - typedef LuaAssociativeTable, DownstreamState::checkfunc_t>> newserver_t; + typedef LuaAssociativeTable, std::shared_ptr, DownstreamState::checkfunc_t>> newserver_t; luaCtx.writeFunction("inClientStartup", [client]() { return client && !g_configurationDone; }); @@ -621,7 +633,18 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) if (!(client || configCheck)) { infolog("Added downstream server %s", ret->d_config.remote.toStringWithPort()); } - +#ifdef HAVE_XSK + std::shared_ptr xskSocket; + if (getOptionalValue>(vars, "xskSocket", xskSocket) > 0) { + ret->registerXsk(xskSocket); + std::string mac; + if (getOptionalValue(vars, "MACAddr", mac) != 1) { + throw runtime_error("field MACAddr is required!"); + } + auto* addr = &ret->d_config.destMACAddr[0]; + sscanf(mac.c_str(), "%hhx:%hhx:%hhx:%hhx:%hhx:%hhx", addr, addr + 1, addr + 2, addr + 3, addr + 4, addr + 5); + } +#endif /* HAVE_XSK */ if (autoUpgrade && ret->getProtocol() != dnsdist::Protocol::DoT && ret->getProtocol() != dnsdist::Protocol::DoH) { dnsdist::ServiceDiscovery::addUpgradeableServer(ret, upgradeInterval, upgradePool, upgradeDoHKey, keepAfterUpgrade); } @@ -744,7 +767,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) } // only works pre-startup, so no sync necessary - g_frontends.push_back(std::make_unique(loc, false, reusePort, tcpFastOpenQueueSize, interface, cpus, enableProxyProtocol)); + auto udpCS = std::make_unique(loc, false, reusePort, tcpFastOpenQueueSize, interface, cpus, enableProxyProtocol); auto tcpCS = std::make_unique(loc, true, reusePort, tcpFastOpenQueueSize, interface, cpus, enableProxyProtocol); if (tcpListenQueueSize > 0) { tcpCS->tcpListenQueueSize = tcpListenQueueSize; @@ -756,6 +779,18 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) tcpCS->d_tcpConcurrentConnectionsLimit = tcpMaxConcurrentConnections; } +#ifdef HAVE_XSK + std::shared_ptr socket; + parseXskVars(vars, socket); + if (socket) { + udpCS->xskInfo = XskWorker::create(); + udpCS->xskInfo->sharedEmptyFrameOffset = socket->sharedEmptyFrameOffset; + socket->addWorker(udpCS->xskInfo, loc, false); + // tcpCS->xskInfo=XskWorker::create(); + // TODO: socket->addWorker(tcpCS->xskInfo, loc, true); + } +#endif /* HAVE_XSK */ + g_frontends.push_back(std::move(udpCS)); g_frontends.push_back(std::move(tcpCS)); } catch (const 
std::exception& e) { @@ -786,7 +821,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) try { ComboAddress loc(addr, 53); // only works pre-startup, so no sync necessary - g_frontends.push_back(std::make_unique(loc, false, reusePort, tcpFastOpenQueueSize, interface, cpus, enableProxyProtocol)); + auto udpCS = std::make_unique(loc, false, reusePort, tcpFastOpenQueueSize, interface, cpus, enableProxyProtocol); auto tcpCS = std::make_unique(loc, true, reusePort, tcpFastOpenQueueSize, interface, cpus, enableProxyProtocol); if (tcpListenQueueSize > 0) { tcpCS->tcpListenQueueSize = tcpListenQueueSize; @@ -797,6 +832,18 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) if (tcpMaxConcurrentConnections > 0) { tcpCS->d_tcpConcurrentConnectionsLimit = tcpMaxConcurrentConnections; } +#ifdef HAVE_XSK + std::shared_ptr socket; + parseXskVars(vars, socket); + if (socket) { + udpCS->xskInfo = XskWorker::create(); + udpCS->xskInfo->sharedEmptyFrameOffset = socket->sharedEmptyFrameOffset; + socket->addWorker(udpCS->xskInfo, loc, false); + // TODO tcpCS->xskInfo=XskWorker::create(); + // TODO socket->addWorker(tcpCS->xskInfo, loc, true); + } +#endif /* HAVE_XSK */ + g_frontends.push_back(std::move(udpCS)); g_frontends.push_back(std::move(tcpCS)); } catch (std::exception& e) { diff --git a/pdns/dnsdist-tcp.cc b/pdns/dnsdist-tcp.cc index c15a14484d..3baf478ea2 100644 --- a/pdns/dnsdist-tcp.cc +++ b/pdns/dnsdist-tcp.cc @@ -495,9 +495,8 @@ void IncomingTCPConnectionState::handleResponse(const struct timeval& now, TCPRe if (!response.isAsync()) { try { auto& ids = response.d_idstate; - unsigned int qnameWireLength{0}; std::shared_ptr backend = response.d_ds ? response.d_ds : (response.d_connection ? response.d_connection->getDS() : nullptr); - if (backend == nullptr || !responseContentMatches(response.d_buffer, ids.qname, ids.qtype, ids.qclass, backend, qnameWireLength)) { + if (backend == nullptr || !responseContentMatches(response.d_buffer, ids.qname, ids.qtype, ids.qclass, backend)) { state->terminateClientConnection(); return; } diff --git a/pdns/dnsdist.cc b/pdns/dnsdist.cc index 3beef62b12..304cd0f99b 100644 --- a/pdns/dnsdist.cc +++ b/pdns/dnsdist.cc @@ -29,9 +29,15 @@ #include #include #include +#include #include #include +#ifdef HAVE_XSK +#include +#include +#endif /* HAVE_XSK */ + #ifdef HAVE_LIBEDIT #if defined (__OpenBSD__) || defined(__NetBSD__) // If this is not undeffed, __attribute__ wil be redefined by /usr/include/readline/rlstdc.h @@ -111,6 +117,7 @@ std::vector> g_dohlocals; std::vector> g_doqlocals; std::vector> g_doh3locals; std::vector> g_dnsCryptLocals; +std::vector> g_xsk; shared_ptr g_defaultBPFFilter{nullptr}; std::vector > g_dynBPFFilters; @@ -332,7 +339,7 @@ static void doLatencyStats(dnsdist::Protocol protocol, double udiff) } } -bool responseContentMatches(const PacketBuffer& response, const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const std::shared_ptr& remote, unsigned int& qnameWireLength) +bool responseContentMatches(const PacketBuffer& response, const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const std::shared_ptr& remote) { if (response.size() < sizeof(dnsheader)) { return false; @@ -363,7 +370,7 @@ bool responseContentMatches(const PacketBuffer& response, const DNSName& qname, uint16_t rqtype, rqclass; DNSName rqname; try { - rqname = DNSName(reinterpret_cast(response.data()), response.size(), sizeof(dnsheader), false, &rqtype, &rqclass, &qnameWireLength); + rqname = 
DNSName(reinterpret_cast(response.data()), response.size(), sizeof(dnsheader), false, &rqtype, &rqclass); } catch (const std::exception& e) { if (remote && response.size() > 0 && static_cast(response.size()) > sizeof(dnsheader)) { @@ -743,9 +750,7 @@ static void handleResponseForUDPClient(InternalQueryState& ids, PacketBuffer& re } bool muted = true; - if (ids.cs && !ids.cs->muted) { - ComboAddress empty; - empty.sin4.sin_family = 0; + if (ids.cs && !ids.cs->muted && !ids.xskPacketHeader) { sendUDPResponse(ids.cs->udpFD, response, dr.ids.delayMsec, ids.hopLocal, ids.hopRemote); muted = false; } @@ -766,6 +771,61 @@ static void handleResponseForUDPClient(InternalQueryState& ids, PacketBuffer& re } } +#ifdef HAVE_XSK +static void XskHealthCheck(std::shared_ptr& dss, std::unordered_map>& map, bool initial = false) +{ + auto& xskInfo = dss->xskInfo; + std::shared_ptr data; + auto packet = getHealthCheckPacket(dss, nullptr, data); + data->d_initial = initial; + setHealthCheckTime(dss, data); + auto* frame = xskInfo->getEmptyframe(); + auto *xskPacket = new XskPacket(frame, 0, xskInfo->frameSize); + xskPacket->setAddr(dss->d_config.sourceAddr, dss->d_config.sourceMACAddr, dss->d_config.remote, dss->d_config.destMACAddr); + xskPacket->setPayload(packet); + xskPacket->rewrite(); + xskInfo->sq.push(xskPacket); + const auto queryId = data->d_queryID; + map[queryId] = std::move(data); +} +#endif /* HAVE_XSK */ + +static bool processResponderPacket(std::shared_ptr& dss, PacketBuffer& response, const std::vector& localRespRuleActions, const std::vector& cacheInsertedRespRuleActions, InternalQueryState&& ids) +{ + + const dnsheader_aligned dh(response.data()); + auto queryId = dh->id; + + if (!responseContentMatches(response, ids.qname, ids.qtype, ids.qclass, dss)) { + dss->restoreState(queryId, std::move(ids)); + return false; + } + + auto du = std::move(ids.du); + dnsdist::PacketMangling::editDNSHeaderFromPacket(response, [&ids](dnsheader& header) { + header.id = ids.origID; + return true; + }); + ++dss->responses; + + double udiff = ids.queryRealTime.udiff(); + // do that _before_ the processing, otherwise it's not fair to the backend + dss->latencyUsec = (127.0 * dss->latencyUsec / 128.0) + udiff / 128.0; + dss->reportResponse(dh->rcode); + + /* don't call processResponse for DOH */ + if (du) { +#ifdef HAVE_DNS_OVER_HTTPS + // DoH query, we cannot touch du after that + DOHUnitInterface::handleUDPResponse(std::move(du), std::move(response), std::move(ids), dss); +#endif + return false; + } + + handleResponseForUDPClient(ids, response, localRespRuleActions, cacheInsertedRespRuleActions, dss, false, false); + return true; +} + // listens on a dedicated socket, lobs answers from downstream servers to original requestors void responderThread(std::shared_ptr dss) { @@ -773,6 +833,103 @@ void responderThread(std::shared_ptr dss) setThreadName("dnsdist/respond"); auto localRespRuleActions = g_respruleactions.getLocal(); auto localCacheInsertedRespRuleActions = g_cacheInsertedRespRuleActions.getLocal(); +#ifdef HAVE_XSK + if (dss->xskInfo) { + auto xskInfo = dss->xskInfo; + auto pollfds = getPollFdsForWorker(*xskInfo); + std::unordered_map> healthCheckMap; + XskHealthCheck(dss, healthCheckMap, true); + itimerspec tm; + tm.it_value.tv_sec = dss->d_config.checkTimeout / 1000; + tm.it_value.tv_nsec = (dss->d_config.checkTimeout % 1000) * 1000000; + tm.it_interval = tm.it_value; + auto res = timerfd_settime(pollfds[1].fd, 0, &tm, nullptr); + if (res) { + throw std::runtime_error("timerfd_settime failed:" + 
stringerror(errno)); + } + const auto xskFd = xskInfo->workerWaker.getHandle(); + while (!dss->isStopped()) { + poll(pollfds.data(), pollfds.size(), -1); + bool needNotify = false; + if (pollfds[0].revents & POLLIN) { + needNotify = true; + xskInfo->cq.consume_all([&](XskPacket* packet) { + if (packet->dataLen() < sizeof(dnsheader)) { + xskInfo->sq.push(packet); + return; + } + const auto* dh = reinterpret_cast(packet->payloadData()); + const auto queryId = dh->id; + auto ids = dss->getState(queryId); + if (ids) { + if (xskFd != ids->backendFD || !ids->xskPacketHeader) { + dss->restoreState(queryId, std::move(*ids)); + ids = std::nullopt; + } + } + if (!ids) { + // this has to go before we can refactor the duplicated response handling code + auto iter = healthCheckMap.find(queryId); + if (iter != healthCheckMap.end()) { + auto data = std::move(iter->second); + healthCheckMap.erase(iter); + packet->cloneIntoPacketBuffer(data->d_buffer); + data->d_ds->submitHealthCheckResult(data->d_initial, handleResponse(data)); + } + xskInfo->sq.push(packet); + return; + } + auto response = packet->clonePacketBuffer(); + if (!processResponderPacket(dss, response, *localRespRuleActions, *localCacheInsertedRespRuleActions, std::move(*ids))) { + xskInfo->sq.push(packet); + return; + } + packet->setHeader(*ids->xskPacketHeader); + packet->setPayload(response); + if (ids->delayMsec > 0) { + packet->addDelay(ids->delayMsec); + } + packet->updatePacket(); + xskInfo->sq.push(packet); + }); + xskInfo->cleanSocketNotification(); + } + if (pollfds[1].revents & POLLIN) { + timeval now; + gettimeofday(&now, nullptr); + for (auto i = healthCheckMap.begin(); i != healthCheckMap.end();) { + auto& ttd = i->second->d_ttd; + if (ttd < now) { + dss->submitHealthCheckResult(i->second->d_initial, false); + i = healthCheckMap.erase(i); + } + else { + ++i; + } + } + needNotify = true; + dss->updateStatisticsInfo(); + dss->handleUDPTimeouts(); + if (dss->d_nextCheck <= 1) { + dss->d_nextCheck = dss->d_config.checkInterval; + if (dss->d_config.availability == DownstreamState::Availability::Auto) { + XskHealthCheck(dss, healthCheckMap); + } + } + else { + --dss->d_nextCheck; + } + + uint64_t tmp; + res = read(pollfds[1].fd, &tmp, sizeof(tmp)); + } + if (needNotify) { + xskInfo->notifyXskSocket(); + } + } + } + else { +#endif /* HAVE_XSK */ const size_t initialBufferSize = getInitialUDPPacketBufferSize(false); /* allocate one more byte so we can detect truncation */ PacketBuffer response(initialBufferSize + 1); @@ -805,7 +962,7 @@ void responderThread(std::shared_ptr dss) for (const auto& fd : sockets) { /* allocate one more byte so we can detect truncation */ - // NOLINTNEXTLINE(bugprone-use-after-move): resizing a vector has no preconditions so it is valid to do so after moving it + // NOLINTNEXTLINE(bugprone-use-after-move): resizing a vector has no preconditions so it is valid to do so after moving it response.resize(initialBufferSize + 1); ssize_t got = recv(fd, response.data(), response.size(), 0); @@ -826,41 +983,32 @@ void responderThread(std::shared_ptr dss) continue; } - unsigned int qnameWireLength = 0; - if (fd != ids->backendFD || !responseContentMatches(response, ids->qname, ids->qtype, ids->qclass, dss, qnameWireLength)) { + if (fd != ids->backendFD) { dss->restoreState(queryId, std::move(*ids)); continue; } - auto du = std::move(ids->du); - - dnsdist::PacketMangling::editDNSHeaderFromPacket(response, [&ids](dnsheader& header) { - header.id = ids->origID; - return true; - }); - ++dss->responses; - - double udiff = 
ids->queryRealTime.udiff(); - // do that _before_ the processing, otherwise it's not fair to the backend - dss->latencyUsec = (127.0 * dss->latencyUsec / 128.0) + udiff / 128.0; - dss->reportResponse(dh->rcode); - - /* don't call processResponse for DOH */ - if (du) { -#ifdef HAVE_DNS_OVER_HTTPS - // DoH query, we cannot touch du after that - DOHUnitInterface::handleUDPResponse(std::move(du), std::move(response), std::move(*ids), dss); -#endif - continue; + if (processResponderPacket(dss, response, *localRespRuleActions, *localCacheInsertedRespRuleActions, std::move(*ids)) && ids->xskPacketHeader && ids->cs->xskInfo) { +#ifdef HAVE_XSK + auto& xskInfo = ids->cs->xskInfo; + auto* frame = xskInfo->getEmptyframe(); + auto xskPacket = std::make_unique(frame, 0, xskInfo->frameSize); + xskPacket->setHeader(*ids->xskPacketHeader); + xskPacket->setPayload(response); + xskPacket->updatePacket(); + xskInfo->sq.push(xskPacket.release()); + xskInfo->notifyXskSocket(); +#endif /* HAVE_XSK */ } - - handleResponseForUDPClient(*ids, response, *localRespRuleActions, *localCacheInsertedRespRuleActions, dss, false, false); } } catch (const std::exception& e) { vinfolog("Got an error in UDP responder thread while parsing a response from %s, id %d: %s", dss->d_config.remote.toStringWithPort(), queryId, e.what()); } } +#ifdef HAVE_XSK + } +#endif /* HAVE_XSK */ } catch (const std::exception& e) { errlog("UDP responder thread died because of exception: %s", e.what()); @@ -1280,6 +1428,23 @@ static bool isUDPQueryAcceptable(ClientState& cs, LocalHolders& holders, const s return true; } +#ifdef HAVE_XSK +static bool isXskQueryAcceptable(const XskPacket& packet, ClientState& cs, LocalHolders& holders, bool& expectProxyProtocol) noexcept +{ + const auto& from = packet.getFromAddr(); + expectProxyProtocol = expectProxyProtocolFrom(from); + if (!holders.acl->match(from) && !expectProxyProtocol) { + vinfolog("Query from %s dropped because of ACL", from.toStringWithPort()); + ++dnsdist::metrics::g_stats.aclDrops; + return false; + } + cs.queries++; + ++dnsdist::metrics::g_stats.queries; + + return true; +} +#endif /* HAVE_XSK */ + bool checkDNSCryptQuery(const ClientState& cs, PacketBuffer& query, std::unique_ptr& dnsCryptQuery, time_t now, bool tcp) { if (cs.dnscryptCtx) { @@ -1408,7 +1573,11 @@ ProcessQueryResult processQueryAfterRules(DNSQuestion& dq, LocalHolders& holders ++dq.ids.cs->responses; return ProcessQueryResult::SendAnswer; } - +#ifdef HAVE_XSK + if (dq.ids.cs->xskInfo) { + dq.ids.poolName = dq.ids.cs->xskInfo->poolName; + } +#endif /* HAVE_XSK */ std::shared_ptr serverPool = getPool(*holders.pools, dq.ids.poolName); std::shared_ptr poolPolicy = serverPool->policy; dq.ids.packetCache = serverPool->packetCache; @@ -1649,7 +1818,7 @@ ProcessQueryResult processQuery(DNSQuestion& dq, LocalHolders& holders, std::sha return ProcessQueryResult::Drop; } -bool assignOutgoingUDPQueryToBackend(std::shared_ptr& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer& query) +bool assignOutgoingUDPQueryToBackend(std::shared_ptr& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer& query, bool actuallySend) { bool doh = dq.ids.du != nullptr; @@ -1670,7 +1839,9 @@ bool assignOutgoingUDPQueryToBackend(std::shared_ptr& ds, uint1 try { int fd = ds->pickSocketForSending(); - dq.ids.backendFD = fd; + if (actuallySend) { + dq.ids.backendFD = fd; + } dq.ids.origID = queryID; dq.ids.forwardedOverUDP = true; @@ -1682,6 +1853,10 @@ bool assignOutgoingUDPQueryToBackend(std::shared_ptr& ds, uint1 /* set the correct ID */ 
memcpy(&query.at(proxyProtocolPayloadSize), &idOffset, sizeof(idOffset)); + if (!actuallySend) { + return true; + } + /* you can't touch ids or du after this line, unless the call returned a non-negative value, because it might already have been freed */ ssize_t ret = udpClientSendRequestToBackend(ds, fd, query); @@ -1839,6 +2014,127 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct } } +#ifdef HAVE_XSK +static void ProcessXskQuery(ClientState& cs, LocalHolders& holders, XskPacket& packet) +{ + uint16_t queryId = 0; + const auto& remote = packet.getFromAddr(); + const auto& dest = packet.getToAddr(); + InternalQueryState ids; + ids.cs = &cs; + ids.origRemote = remote; + ids.hopRemote = remote; + ids.origDest = dest; + ids.hopLocal = dest; + ids.protocol = dnsdist::Protocol::DoUDP; + ids.xskPacketHeader = packet.cloneHeadertoPacketBuffer(); + + try { + bool expectProxyProtocol = false; + if (!isXskQueryAcceptable(packet, cs, holders, expectProxyProtocol)) { + return; + } + + auto query = packet.clonePacketBuffer(); + std::vector proxyProtocolValues; + if (expectProxyProtocol && !handleProxyProtocol(remote, false, *holders.acl, query, ids.origRemote, ids.origDest, proxyProtocolValues)) { + return; + } + + ids.queryRealTime.start(); + + auto dnsCryptResponse = checkDNSCryptQuery(cs, query, ids.dnsCryptQuery, ids.queryRealTime.d_start.tv_sec, false); + if (dnsCryptResponse) { + packet.setPayload(query); + return; + } + + { + /* this pointer will be invalidated the second the buffer is resized, don't hold onto it! */ + dnsheader_aligned dnsHeader(query.data()); + queryId = ntohs(dnsHeader->id); + + if (!checkQueryHeaders(dnsHeader.get(), cs)) { + return; + } + + if (dnsHeader->qdcount == 0) { + dnsdist::PacketMangling::editDNSHeaderFromPacket(query, [](dnsheader& header) { + header.rcode = RCode::NotImp; + header.qr = true; + return true; + }); + packet.setPayload(query); + return; + } + } + + ids.qname = DNSName(reinterpret_cast(query.data()), query.size(), sizeof(dnsheader), false, &ids.qtype, &ids.qclass); + if (ids.origDest.sin4.sin_family == 0) { + ids.origDest = cs.local; + } + if (ids.dnsCryptQuery) { + ids.protocol = dnsdist::Protocol::DNSCryptUDP; + } + DNSQuestion dq(ids, query); + if (!proxyProtocolValues.empty()) { + dq.proxyProtocolValues = make_unique>(std::move(proxyProtocolValues)); + } + std::shared_ptr ss{nullptr}; + auto result = processQuery(dq, holders, ss); + + if (result == ProcessQueryResult::Drop) { + return; + } + + if (result == ProcessQueryResult::SendAnswer) { + packet.setPayload(query); + if (dq.ids.delayMsec > 0) { + packet.addDelay(dq.ids.delayMsec); + } + return; + } + + if (result != ProcessQueryResult::PassToBackend || ss == nullptr) { + return; + } + + // the buffer might have been invalidated by now (resized) + const auto dh = dq.getHeader(); + if (ss->isTCPOnly()) { + std::string proxyProtocolPayload; + /* we need to do this _before_ creating the cross protocol query because + after that the buffer will have been moved */ + if (ss->d_config.useProxyProtocol) { + proxyProtocolPayload = getProxyProtocolPayload(dq); + } + + ids.origID = dh->id; + auto cpq = std::make_unique(std::move(query), std::move(ids), ss); + cpq->query.d_proxyProtocolPayload = std::move(proxyProtocolPayload); + + ss->passCrossProtocolQuery(std::move(cpq)); + return; + } + + if (!ss->xskInfo) { + assignOutgoingUDPQueryToBackend(ss, dh->id, dq, query, true); + } + else { + int fd = ss->xskInfo->workerWaker; + ids.backendFD = fd; + 
assignOutgoingUDPQueryToBackend(ss, dh->id, dq, query, false); + packet.setAddr(ss->d_config.sourceAddr,ss->d_config.sourceMACAddr, ss->d_config.remote,ss->d_config.destMACAddr); + packet.setPayload(query); + packet.rewrite(); + } + } + catch (const std::exception& e) { + vinfolog("Got an error in UDP question thread while parsing a query from %s, id %d: %s", remote.toStringWithPort(), queryId, e.what()); + } +} +#endif /* HAVE_XSK */ + #ifndef DISABLE_RECVMMSG #if defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) static void MultipleMessagesUDPClientThread(ClientState* cs, LocalHolders& holders) @@ -1931,6 +2227,27 @@ static void MultipleMessagesUDPClientThread(ClientState* cs, LocalHolders& holde #endif /* defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) */ #endif /* DISABLE_RECVMMSG */ +#ifdef HAVE_XSK +static void xskClientThread(ClientState* cs) +{ + setThreadName("dnsdist/xskClient"); + auto xskInfo = cs->xskInfo; + LocalHolders holders; + + for (;;) { + while (!xskInfo->cq.read_available()) { + xskInfo->waitForXskSocket(); + } + xskInfo->cq.consume_all([&](XskPacket* packet) { + ProcessXskQuery(*cs, holders, *packet); + packet->updatePacket(); + xskInfo->sq.push(packet); + }); + xskInfo->notifyXskSocket(); + } +} +#endif /* HAVE_XSK */ + // listens to incoming queries, sends out to downstream servers, noting the intended return path static void udpClientThread(std::vector states) { @@ -2177,11 +2494,12 @@ static void healthChecksThread() std::unique_ptr mplexer{nullptr}; for (auto& dss : *states) { - auto delta = dss->sw.udiffAndSet()/1000000.0; - dss->queryLoad.store(1.0*(dss->queries.load() - dss->prev.queries.load())/delta); - dss->dropRate.store(1.0*(dss->reuseds.load() - dss->prev.reuseds.load())/delta); - dss->prev.queries.store(dss->queries.load()); - dss->prev.reuseds.store(dss->reuseds.load()); +#ifdef HAVE_XSK + if (dss->xskInfo) { + continue; + } +#endif /* HAVE_XSK */ + dss->updateStatisticsInfo(); dss->handleUDPTimeouts(); @@ -2909,13 +3227,35 @@ static void initFrontends() } } +#ifdef HAVE_XSK +void XskRouter(std::shared_ptr xsk); +#endif /* HAVE_XSK */ + namespace dnsdist { static void startFrontends() { +#ifdef HAVE_XSK + for (auto& xskContext : g_xsk) { + std::thread xskThread(XskRouter, std::move(xskContext)); + xskThread.detach(); + } +#endif /* HAVE_XSK */ + std::vector tcpStates; std::vector udpStates; for (auto& clientState : g_frontends) { +#ifdef HAVE_XSK + if (clientState->xskInfo) { + std::thread xskCT(xskClientThread, clientState.get()); + if (!clientState->cpus.empty()) { + mapThreadToCPUList(xskCT.native_handle(), clientState->cpus); + } + xskCT.detach(); + continue; + } +#endif /* HAVE_XSK */ + if (clientState->dohFrontend != nullptr && clientState->dohFrontend->d_library == "h2o") { #ifdef HAVE_DNS_OVER_HTTPS #ifdef HAVE_LIBH2OEVLOOP @@ -3175,6 +3515,12 @@ int main(int argc, char** argv) auto states = g_dstates.getCopy(); // it is a copy, but the internal shared_ptrs are the real deal auto mplexer = std::unique_ptr(FDMultiplexer::getMultiplexerSilent(states.size())); for (auto& dss : states) { +#ifdef HAVE_XSK + if (dss->xskInfo) { + continue; + } +#endif /* HAVE_XSK */ + if (dss->d_config.availability == DownstreamState::Availability::Auto || dss->d_config.availability == DownstreamState::Availability::Lazy) { if (dss->d_config.availability == DownstreamState::Availability::Auto) { dss->d_nextCheck = dss->d_config.checkInterval; @@ -3270,3 +3616,72 @@ int main(int argc, char** argv) #endif } } + 
+#ifdef HAVE_XSK +void XskRouter(std::shared_ptr xsk) +{ + setThreadName("dnsdist/XskRouter"); + uint32_t failed; + // packets to be submitted for sending + vector fillInTx; + const auto size = xsk->fds.size(); + // list of workers that need to be notified + std::set needNotify; + const auto& xskWakerIdx = xsk->workers.get<0>(); + const auto& destIdx = xsk->workers.get<1>(); + while (true) { + auto ready = xsk->wait(-1); + // descriptor 0 gets incoming AF_XDP packets + if (xsk->fds[0].revents & POLLIN) { + auto packets = xsk->recv(64, &failed); + dnsdist::metrics::g_stats.nonCompliantQueries += failed; + for (auto &packet : packets) { + const auto dest = packet->getToAddr(); + auto res = destIdx.find(dest); + if (res == destIdx.end()) { + xsk->uniqueEmptyFrameOffset.push_back(xsk->frameOffset(*packet)); + continue; + } + res->worker->cq.push(packet.release()); + needNotify.insert(res->workerWaker); + } + for (auto i : needNotify) { + uint64_t x = 1; + auto written = write(i, &x, sizeof(x)); + if (written != sizeof(x)) { + // oh, well, the worker is clearly overloaded + // but there is nothing we can do about it, + // and hopefully the queue will be processed eventually + } + } + needNotify.clear(); + ready--; + } + const auto backup = ready; + for (size_t i = 1; i < size && ready > 0; i++) { + if (xsk->fds[i].revents & POLLIN) { + ready--; + auto& info = xskWakerIdx.find(xsk->fds[i].fd)->worker; + info->sq.consume_all([&](XskPacket* x) { + if (!(x->getFlags() & XskPacket::UPDATE)) { + xsk->uniqueEmptyFrameOffset.push_back(xsk->frameOffset(*x)); + return; + } + auto ptr = std::unique_ptr(x); + if (x->getFlags() & XskPacket::DELAY) { + xsk->waitForDelay.push(std::move(ptr)); + return; + } + fillInTx.push_back(std::move(ptr)); + }); + info->cleanWorkerNotification(); + } + } + xsk->pickUpReadyPacket(fillInTx); + xsk->recycle(64); + xsk->fillFq(); + xsk->send(fillInTx); + ready = backup; + } +} +#endif /* HAVE_XSK */ diff --git a/pdns/dnsdist.hh b/pdns/dnsdist.hh index 9d5d06f592..34d4600dac 100644 --- a/pdns/dnsdist.hh +++ b/pdns/dnsdist.hh @@ -56,6 +56,7 @@ #include "uuid-utils.hh" #include "proxy-protocol.hh" #include "stat_t.hh" +#include "xsk.hh" uint64_t uptimeOfProcess(const std::string& str); @@ -511,6 +512,7 @@ struct ClientState std::shared_ptr doqFrontend{nullptr}; std::shared_ptr doh3Frontend{nullptr}; std::shared_ptr d_filter{nullptr}; + std::shared_ptr xskInfo{nullptr}; size_t d_maxInFlightQueriesPerConn{1}; size_t d_tcpConcurrentConnectionsLimit{0}; int udpFD{-1}; @@ -702,6 +704,8 @@ struct DownstreamState: public std::enable_shared_from_this std::string d_dohPath; std::string name; std::string nameWithAddr; + MACAddr sourceMACAddr; + MACAddr destMACAddr; size_t d_numberOfSockets{1}; size_t d_maxInFlightQueriesPerConn{1}; size_t d_tcpConcurrentConnectionsLimit{0}; @@ -815,6 +819,7 @@ public: std::vector sockets; StopWatch sw; QPSLimiter qps; + std::shared_ptr xskInfo{nullptr}; std::atomic idOffset{0}; size_t socketsOffset{0}; double latencyUsec{0.0}; @@ -837,7 +842,14 @@ private: uint8_t consecutiveSuccessfulChecks{0}; bool d_stopped{false}; public: - + void updateStatisticsInfo() + { + auto delta = sw.udiffAndSet() / 1000000.0; + queryLoad.store(1.0 * (queries.load() - prev.queries.load()) / delta); + dropRate.store(1.0 * (reuseds.load() - prev.reuseds.load()) / delta); + prev.queries.store(queries.load()); + prev.reuseds.store(reuseds.load()); + } void start(); bool isUp() const @@ -966,6 +978,19 @@ public: void restoreState(uint16_t id, InternalQueryState&&); std::optional 
getState(uint16_t id); +#ifdef HAVE_XSK + void registerXsk(std::shared_ptr& xsk) + { + xskInfo = XskWorker::create(); + if (d_config.sourceAddr.sin4.sin_family == 0) { + throw runtime_error("invalid source addr"); + } + xsk->addWorker(xskInfo, d_config.sourceAddr, getProtocol() != dnsdist::Protocol::DoUDP); + memcpy(d_config.sourceMACAddr, xsk->source, sizeof(MACAddr)); + xskInfo->sharedEmptyFrameOffset = xsk->sharedEmptyFrameOffset; + } +#endif /* HAVE_XSK */ + dnsdist::Protocol getProtocol() const { if (isDoH()) { @@ -1138,7 +1163,7 @@ void setLuaSideEffect(); // set to report a side effect, cancelling all _no_ s bool getLuaNoSideEffect(); // set if there were only explicit declarations of _no_ side effect void resetLuaSideEffect(); // reset to indeterminate state -bool responseContentMatches(const PacketBuffer& response, const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const std::shared_ptr& remote, unsigned int& qnameWireLength); +bool responseContentMatches(const PacketBuffer& response, const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const std::shared_ptr& remote); bool checkQueryHeaders(const struct dnsheader* dh, ClientState& cs); @@ -1163,7 +1188,7 @@ bool processResponse(PacketBuffer& response, const std::vector& cacheInsertedRespRuleActions, DNSResponse& dr, bool muted); -bool assignOutgoingUDPQueryToBackend(std::shared_ptr& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer& query); +bool assignOutgoingUDPQueryToBackend(std::shared_ptr& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer& query, bool actuallySend = true); ssize_t udpClientSendRequestToBackend(const std::shared_ptr& ss, const int sd, const PacketBuffer& request, bool healthCheck = false); bool sendUDPResponse(int origFD, const PacketBuffer& response, const int delayMsec, const ComboAddress& origDest, const ComboAddress& origRemote); diff --git a/pdns/dnsdistdist/Makefile.am b/pdns/dnsdistdist/Makefile.am index 9ed171375a..c95629daac 100644 --- a/pdns/dnsdistdist/Makefile.am +++ b/pdns/dnsdistdist/Makefile.am @@ -254,7 +254,8 @@ dnsdist_SOURCES = \ tcpiohandler.cc tcpiohandler.hh \ threadname.hh threadname.cc \ uuid-utils.hh uuid-utils.cc \ - xpf.cc xpf.hh + xpf.cc xpf.hh \ + xsk.cc xsk.hh testrunner_SOURCES = \ base64.hh \ @@ -361,7 +362,8 @@ testrunner_SOURCES = \ testrunner.cc \ threadname.hh threadname.cc \ uuid-utils.hh uuid-utils.cc \ - xpf.cc xpf.hh + xpf.cc xpf.hh \ + xsk.cc xsk.hh dnsdist_LDFLAGS = \ $(AM_LDFLAGS) \ @@ -411,6 +413,13 @@ if HAVE_LIBSSL dnsdist_LDADD += $(LIBSSL_LIBS) endif +if HAVE_XSK +dnsdist_LDADD += -lbpf +dnsdist_LDADD += -lxdp +testrunner_LDADD += -lbpf +testrunner_LDADD += -lxdp +endif + if HAVE_LIBCRYPTO dnsdist_LDADD += $(LIBCRYPTO_LDFLAGS) $(LIBCRYPTO_LIBS) testrunner_LDADD += $(LIBCRYPTO_LDFLAGS) $(LIBCRYPTO_LIBS) diff --git a/pdns/dnsdistdist/configure.ac b/pdns/dnsdistdist/configure.ac index d9429c93f6..d9f6c719dd 100644 --- a/pdns/dnsdistdist/configure.ac +++ b/pdns/dnsdistdist/configure.ac @@ -40,6 +40,7 @@ PDNS_ENABLE_FUZZ_TARGETS PDNS_WITH_RE2 DNSDIST_ENABLE_DNSCRYPT PDNS_WITH_EBPF +PDNS_WITH_XSK PDNS_WITH_NET_SNMP PDNS_WITH_LIBCAP diff --git a/pdns/dnsdistdist/dnsdist-backend.cc b/pdns/dnsdistdist/dnsdist-backend.cc index 8c3eefc239..bd7592545a 100644 --- a/pdns/dnsdistdist/dnsdist-backend.cc +++ b/pdns/dnsdistdist/dnsdist-backend.cc @@ -714,9 +714,11 @@ void DownstreamState::submitHealthCheckResult(bool initial, bool newResult) setUpStatus(newResult); if (newResult == false) { currentCheckFailures++; - auto stats = 
d_lazyHealthCheckStats.lock(); - stats->d_status = LazyHealthCheckStats::LazyStatus::Failed; - updateNextLazyHealthCheck(*stats, false); + if (d_config.availability == DownstreamState::Availability::Lazy) { + auto stats = d_lazyHealthCheckStats.lock(); + stats->d_status = LazyHealthCheckStats::LazyStatus::Failed; + updateNextLazyHealthCheck(*stats, false); + } } return; } diff --git a/pdns/dnsdistdist/dnsdist-healthchecks.cc b/pdns/dnsdistdist/dnsdist-healthchecks.cc index 36805573e7..d60c9dc584 100644 --- a/pdns/dnsdistdist/dnsdist-healthchecks.cc +++ b/pdns/dnsdistdist/dnsdist-healthchecks.cc @@ -31,40 +31,7 @@ bool g_verboseHealthChecks{false}; -struct HealthCheckData -{ - enum class TCPState : uint8_t - { - WritingQuery, - ReadingResponseSize, - ReadingResponse - }; - - HealthCheckData(FDMultiplexer& mplexer, std::shared_ptr downstream, DNSName&& checkName, uint16_t checkType, uint16_t checkClass, uint16_t queryID) : - d_ds(std::move(downstream)), d_mplexer(mplexer), d_udpSocket(-1), d_checkName(std::move(checkName)), d_checkType(checkType), d_checkClass(checkClass), d_queryID(queryID) - { - } - - const std::shared_ptr d_ds; - FDMultiplexer& d_mplexer; - std::unique_ptr d_tcpHandler{nullptr}; - std::unique_ptr d_ioState{nullptr}; - PacketBuffer d_buffer; - Socket d_udpSocket; - DNSName d_checkName; - struct timeval d_ttd - { - 0, 0 - }; - size_t d_bufferPos{0}; - uint16_t d_checkType; - uint16_t d_checkClass; - uint16_t d_queryID; - TCPState d_tcpState{TCPState::WritingQuery}; - bool d_initial{false}; -}; - -static bool handleResponse(std::shared_ptr& data) +bool handleResponse(std::shared_ptr& data) { const auto& downstream = data->d_ds; try { @@ -207,7 +174,7 @@ static void healthCheckUDPCallback(int descriptor, FDMultiplexer::funcparam_t& p } ++data->d_ds->d_healthCheckMetrics.d_networkErrors; data->d_ds->submitHealthCheckResult(data->d_initial, false); - data->d_mplexer.removeReadFD(descriptor); + data->d_mplexer->removeReadFD(descriptor); return; } } while (got < 0); @@ -224,7 +191,7 @@ static void healthCheckUDPCallback(int descriptor, FDMultiplexer::funcparam_t& p return; } - data->d_mplexer.removeReadFD(descriptor); + data->d_mplexer->removeReadFD(descriptor); data->d_ds->submitHealthCheckResult(data->d_initial, handleResponse(data)); } @@ -300,36 +267,51 @@ static void healthCheckTCPCallback(int descriptor, FDMultiplexer::funcparam_t& p } } -bool queueHealthCheck(std::unique_ptr& mplexer, const std::shared_ptr& downstream, bool initialCheck) +PacketBuffer getHealthCheckPacket(const std::shared_ptr& downstream, FDMultiplexer* mplexer, std::shared_ptr& data) { - try { - uint16_t queryID = dnsdist::getRandomDNSID(); - DNSName checkName = downstream->d_config.checkName; - uint16_t checkType = downstream->d_config.checkType.getCode(); - uint16_t checkClass = downstream->d_config.checkClass; - dnsheader checkHeader{}; - memset(&checkHeader, 0, sizeof(checkHeader)); - - checkHeader.qdcount = htons(1); - checkHeader.id = queryID; - - checkHeader.rd = true; - if (downstream->d_config.setCD) { - checkHeader.cd = true; - } + uint16_t queryID = dnsdist::getRandomDNSID(); + DNSName checkName = downstream->d_config.checkName; + uint16_t checkType = downstream->d_config.checkType.getCode(); + uint16_t checkClass = downstream->d_config.checkClass; + dnsheader checkHeader{}; + memset(&checkHeader, 0, sizeof(checkHeader)); + + checkHeader.qdcount = htons(1); + checkHeader.id = queryID; + + checkHeader.rd = true; + if (downstream->d_config.setCD) { + checkHeader.cd = true; + } - if 
(downstream->d_config.checkFunction) { - auto lock = g_lua.lock(); - auto ret = downstream->d_config.checkFunction(checkName, checkType, checkClass, &checkHeader); - checkName = std::get<0>(ret); - checkType = std::get<1>(ret); - checkClass = std::get<2>(ret); - } + if (downstream->d_config.checkFunction) { + auto lock = g_lua.lock(); + auto ret = downstream->d_config.checkFunction(checkName, checkType, checkClass, &checkHeader); + checkName = std::get<0>(ret); + checkType = std::get<1>(ret); + checkClass = std::get<2>(ret); + } + PacketBuffer packet; + GenericDNSPacketWriter dpw(packet, checkName, checkType, checkClass); + dnsheader* requestHeader = dpw.getHeader(); + *requestHeader = checkHeader; + data = std::make_shared(mplexer, downstream, std::move(checkName), checkType, checkClass, queryID); + return packet; +} - PacketBuffer packet; - GenericDNSPacketWriter dpw(packet, checkName, checkType, checkClass); - dnsheader* requestHeader = dpw.getHeader(); - *requestHeader = checkHeader; +void setHealthCheckTime(const std::shared_ptr& downstream, const std::shared_ptr& data) +{ + gettimeofday(&data->d_ttd, nullptr); + data->d_ttd.tv_sec += static_castd_ttd.tv_sec)>(downstream->d_config.checkTimeout / 1000); /* ms to seconds */ + data->d_ttd.tv_usec += static_castd_ttd.tv_usec)>((downstream->d_config.checkTimeout % 1000) * 1000); /* remaining ms to us */ + normalizeTV(data->d_ttd); +} + +bool queueHealthCheck(std::unique_ptr& mplexer, const std::shared_ptr& downstream, bool initialCheck) +{ + try { + std::shared_ptr data; + PacketBuffer packet = getHealthCheckPacket(downstream, mplexer.get(), data); /* we need to compute that _before_ adding the proxy protocol payload */ uint16_t packetSize = packet.size(); @@ -368,13 +350,9 @@ bool queueHealthCheck(std::unique_ptr& mplexer, const std::shared sock.bind(downstream->d_config.sourceAddr, false); } - auto data = std::make_shared(*mplexer, downstream, std::move(checkName), checkType, checkClass, queryID); data->d_initial = initialCheck; - gettimeofday(&data->d_ttd, nullptr); - data->d_ttd.tv_sec += static_castd_ttd.tv_sec)>(downstream->d_config.checkTimeout / 1000); /* ms to seconds */ - data->d_ttd.tv_usec += static_castd_ttd.tv_usec)>((downstream->d_config.checkTimeout % 1000) * 1000); /* remaining ms to us */ - normalizeTV(data->d_ttd); + setHealthCheckTime(downstream, data); if (!downstream->doHealthcheckOverTCP()) { sock.connect(downstream->d_config.remote); @@ -383,7 +361,7 @@ bool queueHealthCheck(std::unique_ptr& mplexer, const std::shared if (sent < 0) { int ret = errno; if (g_verboseHealthChecks) { - infolog("Error while sending a health check query (ID %d) to backend %s: %d", queryID, downstream->getNameWithAddr(), ret); + infolog("Error while sending a health check query (ID %d) to backend %s: %d", data->d_queryID, downstream->getNameWithAddr(), ret); } return false; } diff --git a/pdns/dnsdistdist/dnsdist-healthchecks.hh b/pdns/dnsdistdist/dnsdist-healthchecks.hh index e9da6c66de..4f1940643e 100644 --- a/pdns/dnsdistdist/dnsdist-healthchecks.hh +++ b/pdns/dnsdistdist/dnsdist-healthchecks.hh @@ -24,8 +24,47 @@ #include "dnsdist.hh" #include "mplexer.hh" #include "sstuff.hh" +#include "tcpiohandler-mplexer.hh" extern bool g_verboseHealthChecks; bool queueHealthCheck(std::unique_ptr& mplexer, const std::shared_ptr& downstream, bool initial = false); void handleQueuedHealthChecks(FDMultiplexer& mplexer, bool initial = false); + +struct HealthCheckData +{ + enum class TCPState : uint8_t + { + WritingQuery, + ReadingResponseSize, + 
ReadingResponse + }; + + HealthCheckData(FDMultiplexer* mplexer, std::shared_ptr downstream, DNSName&& checkName, uint16_t checkType, uint16_t checkClass, uint16_t queryID) : + d_ds(std::move(downstream)), d_mplexer(mplexer), d_udpSocket(-1), d_checkName(std::move(checkName)), d_checkType(checkType), d_checkClass(checkClass), d_queryID(queryID) + { + } + + const std::shared_ptr d_ds; + FDMultiplexer* d_mplexer{nullptr}; + std::unique_ptr d_tcpHandler{nullptr}; + std::unique_ptr d_ioState{nullptr}; + PacketBuffer d_buffer; + Socket d_udpSocket; + DNSName d_checkName; + struct timeval d_ttd + { + 0, 0 + }; + size_t d_bufferPos{0}; + uint16_t d_checkType; + uint16_t d_checkClass; + uint16_t d_queryID; + TCPState d_tcpState{TCPState::WritingQuery}; + bool d_initial{false}; +}; + +PacketBuffer getHealthCheckPacket(const std::shared_ptr& ds, FDMultiplexer* mplexer, std::shared_ptr& data); +void setHealthCheckTime(const std::shared_ptr& ds, const std::shared_ptr& data); +bool handleResponse(std::shared_ptr& data); + diff --git a/pdns/dnsdistdist/m4/pdns_with_xsk.m4 b/pdns/dnsdistdist/m4/pdns_with_xsk.m4 new file mode 100644 index 0000000000..b45c9f30f1 --- /dev/null +++ b/pdns/dnsdistdist/m4/pdns_with_xsk.m4 @@ -0,0 +1,22 @@ +AC_DEFUN([PDNS_WITH_XSK],[ + AC_MSG_CHECKING([if we have xsk support]) + AC_ARG_WITH([xsk], + AS_HELP_STRING([--with-xsk],[enable xsk support @<:@default=auto@:>@]), + [with_xsk=$withval], + [with_xsk=auto], + ) + AC_MSG_RESULT([$with_xsk]) + + AS_IF([test "x$with_xsk" != "xno"], [ + AS_IF([test "x$with_xsk" = "xyes" -o "x$with_xsk" = "xauto"], [ + AC_CHECK_HEADERS([xdp/xsk.h], xsk_headers=yes, xsk_headers=no) + ]) + ]) + AS_IF([test "x$with_xsk" = "xyes"], [ + AS_IF([test x"$xsk_headers" = "no"], [ + AC_MSG_ERROR([XSK support requested but required libxdp were not found]) + ]) + ]) + AS_IF([test x"$xsk_headers" = "xyes" ], [ AC_DEFINE([HAVE_XSK], [1], [Define if using eBPF.]) ]) + AM_CONDITIONAL([HAVE_XSK], [test x"$xsk_headers" = "xyes" ]) +]) diff --git a/pdns/dnsdistdist/test-dnsdisttcp_cc.cc b/pdns/dnsdistdist/test-dnsdisttcp_cc.cc index a4c887aeef..fd37dda319 100644 --- a/pdns/dnsdistdist/test-dnsdisttcp_cc.cc +++ b/pdns/dnsdistdist/test-dnsdisttcp_cc.cc @@ -76,7 +76,7 @@ ProcessQueryResult processQuery(DNSQuestion& dq, LocalHolders& holders, std::sha return ProcessQueryResult::Drop; } -bool responseContentMatches(const PacketBuffer& response, const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const std::shared_ptr& remote, unsigned int& qnameWireLength) +bool responseContentMatches(const PacketBuffer& response, const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const std::shared_ptr& remote) { return true; } diff --git a/pdns/dnsdistdist/xsk.cc b/pdns/dnsdistdist/xsk.cc new file mode 120000 index 0000000000..3258a93cb1 --- /dev/null +++ b/pdns/dnsdistdist/xsk.cc @@ -0,0 +1 @@ +../xsk.cc \ No newline at end of file diff --git a/pdns/dnsdistdist/xsk.hh b/pdns/dnsdistdist/xsk.hh new file mode 120000 index 0000000000..4b1bba7374 --- /dev/null +++ b/pdns/dnsdistdist/xsk.hh @@ -0,0 +1 @@ +../xsk.hh \ No newline at end of file diff --git a/pdns/iputils.hh b/pdns/iputils.hh index 4ef0b8f764..7d8b2e4c2d 100644 --- a/pdns/iputils.hh +++ b/pdns/iputils.hh @@ -83,6 +83,7 @@ #undef IP_PKTINFO #endif +using MACAddr = uint8_t[6]; union ComboAddress { struct sockaddr_in sin4; struct sockaddr_in6 sin6; @@ -123,6 +124,24 @@ union ComboAddress { return rhs.operator<(*this); } + struct addressPortOnlyHash + { + uint32_t operator()(const ComboAddress& ca) 
const + { + const unsigned char* start = nullptr; + if (ca.sin4.sin_family == AF_INET) { + start = reinterpret_cast(&ca.sin4.sin_addr.s_addr); + auto tmp = burtle(start, 4, 0); + return burtle(reinterpret_cast(&ca.sin4.sin_port), 2, tmp); + } + { + start = reinterpret_cast(&ca.sin6.sin6_addr.s6_addr); + auto tmp = burtle(start, 16, 0); + return burtle(reinterpret_cast(&ca.sin6.sin6_port), 2, tmp); + } + } + }; + struct addressOnlyHash { uint32_t operator()(const ComboAddress& ca) const @@ -347,11 +366,14 @@ union ComboAddress { void truncate(unsigned int bits) noexcept; - uint16_t getPort() const + uint16_t getNetworkOrderPort() const noexcept { - return ntohs(sin4.sin_port); + return sin4.sin_port; + } + uint16_t getPort() const noexcept + { + return ntohs(getNetworkOrderPort()); } - void setPort(uint16_t port) { sin4.sin_port = htons(port); diff --git a/pdns/test-dnsdist_cc.cc b/pdns/test-dnsdist_cc.cc index 56f31d2348..c29fdabba5 100644 --- a/pdns/test-dnsdist_cc.cc +++ b/pdns/test-dnsdist_cc.cc @@ -57,7 +57,7 @@ bool sendUDPResponse(int origFD, const PacketBuffer& response, const int delayMs return false; } -bool assignOutgoingUDPQueryToBackend(std::shared_ptr& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer& query) +bool assignOutgoingUDPQueryToBackend(std::shared_ptr& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer& query, bool) { return true; } @@ -74,6 +74,8 @@ bool DNSDistSNMPAgent::sendBackendStatusChangeTrap(DownstreamState const&) return false; } +std::vector> g_xsk; + BOOST_AUTO_TEST_SUITE(test_dnsdist_cc) static const uint16_t ECSSourcePrefixV4 = 24; diff --git a/pdns/xsk.cc b/pdns/xsk.cc new file mode 100644 index 0000000000..b7659317e5 --- /dev/null +++ b/pdns/xsk.cc @@ -0,0 +1,828 @@ +/* + * This file is part of PowerDNS or dnsdist. + * Copyright -- PowerDNS.COM B.V. and its contributors + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of version 2 of the GNU General Public License as + * published by the Free Software Foundation. + * + * In addition, for the avoidance of any doubt, permission is granted to + * link this program with OpenSSL and to (re)distribute the binaries + * produced as the result of such linking. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 
+ */ +#include "gettime.hh" +#include "xsk.hh" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef HAVE_XSK +#include +#include +extern "C" +{ +#include +} + +constexpr bool XskSocket::isPowOfTwo(uint32_t value) noexcept +{ + return value != 0 && (value & (value - 1)) == 0; +} +int XskSocket::firstTimeout() +{ + if (waitForDelay.empty()) { + return -1; + } + timespec now; + gettime(&now); + const auto& firstTime = waitForDelay.top()->sendTime; + const auto res = timeDifference(now, firstTime); + if (res <= 0) { + return 0; + } + return res; +} +XskSocket::XskSocket(size_t frameNum_, const std::string& ifName_, uint32_t queue_id, const std::string& xskMapPath, const std::string& poolName_) : + frameNum(frameNum_), queueId(queue_id), ifName(ifName_), poolName(poolName_), socket(nullptr, xsk_socket__delete), sharedEmptyFrameOffset(std::make_shared>>()) +{ + if (!isPowOfTwo(frameNum_) || !isPowOfTwo(frameSize) + || !isPowOfTwo(fqCapacity) || !isPowOfTwo(cqCapacity) || !isPowOfTwo(rxCapacity) || !isPowOfTwo(txCapacity)) { + throw std::runtime_error("The number of frame , the size of frame and the capacity of rings must is a pow of 2"); + } + getMACFromIfName(); + + memset(&cq, 0, sizeof(cq)); + memset(&fq, 0, sizeof(fq)); + memset(&tx, 0, sizeof(tx)); + memset(&rx, 0, sizeof(rx)); + xsk_umem_config umemCfg; + umemCfg.fill_size = fqCapacity; + umemCfg.comp_size = cqCapacity; + umemCfg.frame_size = frameSize; + umemCfg.frame_headroom = XSK_UMEM__DEFAULT_FRAME_HEADROOM; + umemCfg.flags = 0; + umem.umemInit(frameNum_ * frameSize, &cq, &fq, &umemCfg); + { + xsk_socket_config socketCfg; + socketCfg.rx_size = rxCapacity; + socketCfg.tx_size = txCapacity; + socketCfg.bind_flags = XDP_USE_NEED_WAKEUP; + socketCfg.xdp_flags = XDP_FLAGS_SKB_MODE; + socketCfg.libxdp_flags = XSK_LIBBPF_FLAGS__INHIBIT_PROG_LOAD; + xsk_socket* tmp = nullptr; + auto ret = xsk_socket__create(&tmp, ifName.c_str(), queue_id, umem.umem, &rx, &tx, &socketCfg); + if (ret != 0) { + throw std::runtime_error("Error creating a xsk socket of if_name" + ifName + stringerror(ret)); + } + socket = std::unique_ptr(tmp, xsk_socket__delete); + } + for (uint64_t i = 0; i < frameNum; i++) { + uniqueEmptyFrameOffset.push_back(i * frameSize + XDP_PACKET_HEADROOM); + } + fillFq(fqCapacity); + const auto xskfd = xskFd(); + fds.push_back(pollfd{ + .fd = xskfd, + .events = POLLIN, + .revents = 0}); + const auto xskMapFd = FDWrapper(bpf_obj_get(xskMapPath.c_str())); + if (xskMapFd.getHandle() < 0) { + throw std::runtime_error("Error get BPF map from path"); + } + auto ret = bpf_map_update_elem(xskMapFd.getHandle(), &queue_id, &xskfd, 0); + if (ret) { + throw std::runtime_error("Error insert into xsk_map"); + } +} +void XskSocket::fillFq(uint32_t fillSize) noexcept +{ + { + auto frames = sharedEmptyFrameOffset->lock(); + if (frames->size() < holdThreshold) { + const auto moveSize = std::min(holdThreshold - frames->size(), uniqueEmptyFrameOffset.size()); + if (moveSize > 0) { + frames->insert(frames->end(), std::make_move_iterator(uniqueEmptyFrameOffset.end() - moveSize), std::make_move_iterator(uniqueEmptyFrameOffset.end())); + } + } + } + if (uniqueEmptyFrameOffset.size() < fillSize) { + return; + } + uint32_t idx; + if (xsk_ring_prod__reserve(&fq, fillSize, &idx) != fillSize) { + return; + } + for (uint32_t i = 0; i < fillSize; i++) { + 
*xsk_ring_prod__fill_addr(&fq, idx++) = uniqueEmptyFrameOffset.back(); + uniqueEmptyFrameOffset.pop_back(); + } + xsk_ring_prod__submit(&fq, idx); +} +int XskSocket::wait(int timeout) +{ + return poll(fds.data(), fds.size(), static_cast(std::min(static_cast(timeout), static_cast(firstTimeout())))); +} +[[nodiscard]] uint64_t XskSocket::frameOffset(const XskPacket& packet) const noexcept +{ + return reinterpret_cast(packet.frame) - reinterpret_cast(umem.bufBase); +} + +int XskSocket::xskFd() const noexcept { return xsk_socket__fd(socket.get()); } + +void XskSocket::send(std::vector& packets) +{ + const auto packetSize = packets.size(); + if (packetSize == 0) { + return; + } + uint32_t idx; + if (xsk_ring_prod__reserve(&tx, packetSize, &idx) != packets.size()) { + return; + } + + for (const auto& i : packets) { + *xsk_ring_prod__tx_desc(&tx, idx++) = { + .addr = frameOffset(*i), + .len = i->FrameLen(), + .options = 0}; + } + xsk_ring_prod__submit(&tx, packetSize); + packets.clear(); +} +std::vector XskSocket::recv(uint32_t recvSizeMax, uint32_t* failedCount) +{ + uint32_t idx; + std::vector res; + const auto recvSize = xsk_ring_cons__peek(&rx, recvSizeMax, &idx); + if (recvSize <= 0) { + return res; + } + const auto baseAddr = reinterpret_cast(umem.bufBase); + uint32_t count = 0; + for (uint32_t i = 0; i < recvSize; i++) { + const auto* desc = xsk_ring_cons__rx_desc(&rx, idx++); + auto ptr = std::make_unique(reinterpret_cast(desc->addr + baseAddr), desc->len, frameSize); + if (!ptr->parse()) { + ++count; + uniqueEmptyFrameOffset.push_back(frameOffset(*ptr)); + } + else { + res.push_back(std::move(ptr)); + } + } + xsk_ring_cons__release(&rx, recvSize); + if (failedCount) { + *failedCount = count; + } + return res; +} +void XskSocket::pickUpReadyPacket(std::vector& packets) +{ + timespec now; + gettime(&now); + while (!waitForDelay.empty() && timeDifference(now, waitForDelay.top()->sendTime) <= 0) { + auto& top = const_cast(waitForDelay.top()); + packets.push_back(std::move(top)); + waitForDelay.pop(); + } +} +void XskSocket::recycle(size_t size) noexcept +{ + uint32_t idx; + const auto completeSize = xsk_ring_cons__peek(&cq, size, &idx); + if (completeSize <= 0) { + return; + } + for (uint32_t i = 0; i < completeSize; ++i) { + uniqueEmptyFrameOffset.push_back(*xsk_ring_cons__comp_addr(&cq, idx++)); + } + xsk_ring_cons__release(&cq, completeSize); +} + +void XskSocket::XskUmem::umemInit(size_t memSize, xsk_ring_cons* cq, xsk_ring_prod* fq, xsk_umem_config* config) +{ + size = memSize; + bufBase = static_cast(mmap(nullptr, size, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0)); + if (bufBase == MAP_FAILED) { + throw std::runtime_error("mmap failed"); + } + auto ret = xsk_umem__create(&umem, bufBase, size, fq, cq, config); + if (ret != 0) { + munmap(bufBase, size); + throw std::runtime_error("Error creating a umem of size" + std::to_string(size) + stringerror(ret)); + } +} + +XskSocket::XskUmem::~XskUmem() +{ + if (umem) { + xsk_umem__delete(umem); + } + if (bufBase) { + munmap(bufBase, size); + } +} + +bool XskPacket::parse() +{ + // payloadEnd must bigger than payload + sizeof(ethhdr) + sizoef(iphdr) + sizeof(udphdr) + auto* eth = reinterpret_cast(frame); + uint8_t l4Protocol; + if (eth->h_proto == htons(ETH_P_IP)) { + auto* ip = reinterpret_cast(eth + 1); + if (ip->ihl != static_cast(sizeof(iphdr) >> 2)) { + // ip->ihl*4 != sizeof(iphdr) + // ip options is not supported now! + return false; + } + // check ip.check == ipv4Checksum() is not needed! 
+ // We check it in BPF program + from = makeComboAddressFromRaw(4, reinterpret_cast(&ip->saddr), sizeof(ip->saddr)); + to = makeComboAddressFromRaw(4, reinterpret_cast(&ip->daddr), sizeof(ip->daddr)); + l4Protocol = ip->protocol; + l4Header = reinterpret_cast(ip + 1); + payloadEnd = std::min(reinterpret_cast(ip) + ntohs(ip->tot_len), payloadEnd); + } + else if (eth->h_proto == htons(ETH_P_IPV6)) { + auto* ipv6 = reinterpret_cast(eth + 1); + l4Header = reinterpret_cast(ipv6 + 1); + if (l4Header >= payloadEnd) { + return false; + } + from = makeComboAddressFromRaw(6, reinterpret_cast(&ipv6->saddr), sizeof(ipv6->saddr)); + to = makeComboAddressFromRaw(6, reinterpret_cast(&ipv6->daddr), sizeof(ipv6->daddr)); + l4Protocol = ipv6->nexthdr; + payloadEnd = std::min(l4Header + ntohs(ipv6->payload_len), payloadEnd); + } + else { + return false; + } + if (l4Protocol == IPPROTO_UDP) { + // check udp.check == ipv4Checksum() is not needed! + // We check it in BPF program + auto* udp = reinterpret_cast(l4Header); + payload = l4Header + sizeof(udphdr); + // Because of XskPacket::setHeader + // payload = payloadEnd should be allow + if (payload > payloadEnd) { + return false; + } + payloadEnd = std::min(l4Header + ntohs(udp->len), payloadEnd); + from.setPort(ntohs(udp->source)); + to.setPort(ntohs(udp->dest)); + return true; + } + if (l4Protocol == IPPROTO_TCP) { + // check tcp.check == ipv4Checksum() is not needed! + // We check it in BPF program + auto* tcp = reinterpret_cast(l4Header); + if (tcp->doff != static_cast(sizeof(tcphdr) >> 2)) { + // tcp is not supported now! + return false; + } + payload = l4Header + sizeof(tcphdr); + // + if (payload > payloadEnd) { + return false; + } + from.setPort(ntohs(tcp->source)); + to.setPort(ntohs(tcp->dest)); + flags |= TCP; + return true; + } + // ipv6 extension header is not supported now! 
+ return false; +} + +uint32_t XskPacket::dataLen() const noexcept +{ + return payloadEnd - payload; +} +uint32_t XskPacket::FrameLen() const noexcept +{ + return payloadEnd - frame; +} +size_t XskPacket::capacity() const noexcept +{ + return frameEnd - payloadEnd; +} + +void XskPacket::changeDirectAndUpdateChecksum() noexcept +{ + auto* eth = reinterpret_cast(frame); + { + uint8_t tmp[ETH_ALEN]; + static_assert(sizeof(tmp) == sizeof(eth->h_dest), "Size Error"); + static_assert(sizeof(tmp) == sizeof(eth->h_source), "Size Error"); + memcpy(tmp, eth->h_dest, sizeof(tmp)); + memcpy(eth->h_dest, eth->h_source, sizeof(tmp)); + memcpy(eth->h_source, tmp, sizeof(tmp)); + } + if (eth->h_proto == htons(ETH_P_IPV6)) { + // IPV6 + auto* ipv6 = reinterpret_cast(eth + 1); + std::swap(ipv6->daddr, ipv6->saddr); + if (ipv6->nexthdr == IPPROTO_UDP) { + // UDP + auto* udp = reinterpret_cast(ipv6 + 1); + std::swap(udp->dest, udp->source); + udp->len = htons(payloadEnd - reinterpret_cast(udp)); + udp->check = 0; + udp->check = tcp_udp_v6_checksum(); + } + else { + // TCP + auto* tcp = reinterpret_cast(ipv6 + 1); + std::swap(tcp->dest, tcp->source); + // TODO + } + rewriteIpv6Header(ipv6); + } + else { + // IPV4 + auto* ipv4 = reinterpret_cast(eth + 1); + std::swap(ipv4->daddr, ipv4->saddr); + if (ipv4->protocol == IPPROTO_UDP) { + // UDP + auto* udp = reinterpret_cast(ipv4 + 1); + std::swap(udp->dest, udp->source); + udp->len = htons(payloadEnd - reinterpret_cast(udp)); + udp->check = 0; + udp->check = tcp_udp_v4_checksum(); + } + else { + // TCP + auto* tcp = reinterpret_cast(ipv4 + 1); + std::swap(tcp->dest, tcp->source); + // TODO + } + rewriteIpv4Header(ipv4); + } +} +void XskPacket::rewriteIpv4Header(void* ipv4header) noexcept +{ + auto* ipv4 = static_cast(ipv4header); + ipv4->version = 4; + ipv4->ihl = sizeof(iphdr) / 4; + ipv4->tos = 0; + ipv4->tot_len = htons(payloadEnd - reinterpret_cast(ipv4)); + ipv4->id = 0; + ipv4->frag_off = 0; + ipv4->ttl = DefaultTTL; + ipv4->check = 0; + ipv4->check = ipv4Checksum(); +} +void XskPacket::rewriteIpv6Header(void* ipv6header) noexcept +{ + auto* ipv6 = static_cast(ipv6header); + ipv6->version = 6; + ipv6->priority = 0; + ipv6->payload_len = htons(payloadEnd - reinterpret_cast(ipv6 + 1)); + ipv6->hop_limit = DefaultTTL; + memset(&ipv6->flow_lbl, 0, sizeof(ipv6->flow_lbl)); +} + +bool XskPacket::isIPV6() const noexcept +{ + const auto* eth = reinterpret_cast(frame); + return eth->h_proto == htons(ETH_P_IPV6); +} +XskPacket::XskPacket(void* frame_, size_t dataSize, size_t frameSize) : + frame(static_cast(frame_)), payloadEnd(static_cast(frame) + dataSize), frameEnd(static_cast(frame) + frameSize - XDP_PACKET_HEADROOM) +{ +} +PacketBuffer XskPacket::clonePacketBuffer() const +{ + const auto size = dataLen(); + PacketBuffer tmp(size); + memcpy(tmp.data(), payload, size); + return tmp; +} +void XskPacket::cloneIntoPacketBuffer(PacketBuffer& buffer) const +{ + const auto size = dataLen(); + buffer.resize(size); + memcpy(buffer.data(), payload, size); +} +bool XskPacket::setPayload(const PacketBuffer& buf) +{ + const auto bufSize = buf.size(); + if (bufSize == 0 || bufSize > capacity()) { + return false; + } + flags |= UPDATE; + memcpy(payload, buf.data(), bufSize); + payloadEnd = payload + bufSize; + return true; +} +void XskPacket::addDelay(const int relativeMilliseconds) noexcept +{ + gettime(&sendTime); + sendTime.tv_nsec += static_cast(relativeMilliseconds) * 1000000L; + sendTime.tv_sec += sendTime.tv_nsec / 1000000000L; + sendTime.tv_nsec %= 1000000000L; +} +bool 
operator<(const XskPacketPtr& s1, const XskPacketPtr& s2) noexcept +{ + return s1->sendTime < s2->sendTime; +} +const ComboAddress& XskPacket::getFromAddr() const noexcept +{ + return from; +} +const ComboAddress& XskPacket::getToAddr() const noexcept +{ + return to; +} +void XskWorker::notify(int fd) +{ + uint64_t value = 1; + ssize_t res = 0; + while ((res = write(fd, &value, sizeof(value))) == EINTR) { + } + if (res != sizeof(value)) { + throw runtime_error("Unable Wake Up XskSocket Failed"); + } +} +XskWorker::XskWorker() : + workerWaker(createEventfd()), xskSocketWaker(createEventfd()) +{ +} +void* XskPacket::payloadData() +{ + return reinterpret_cast(payload); +} +const void* XskPacket::payloadData() const +{ + return reinterpret_cast(payload); +} +void XskPacket::setAddr(const ComboAddress& from_, MACAddr fromMAC, const ComboAddress& to_, MACAddr toMAC, bool tcp) noexcept +{ + auto* eth = reinterpret_cast(frame); + memcpy(eth->h_dest, &toMAC[0], sizeof(MACAddr)); + memcpy(eth->h_source, &fromMAC[0], sizeof(MACAddr)); + to = to_; + from = from_; + l4Header = frame + sizeof(ethhdr) + (to.isIPv4() ? sizeof(iphdr) : sizeof(ipv6hdr)); + if (tcp) { + flags = TCP; + payload = l4Header + sizeof(tcphdr); + } + else { + flags = 0; + payload = l4Header + sizeof(udphdr); + } +} +void XskPacket::rewrite() noexcept +{ + flags |= REWRITE; + auto* eth = reinterpret_cast(frame); + if (to.isIPv4()) { + eth->h_proto = htons(ETH_P_IP); + auto* ipv4 = reinterpret_cast(eth + 1); + + ipv4->daddr = to.sin4.sin_addr.s_addr; + ipv4->saddr = from.sin4.sin_addr.s_addr; + if (flags & XskPacket::TCP) { + auto* tcp = reinterpret_cast(ipv4 + 1); + ipv4->protocol = IPPROTO_TCP; + tcp->source = from.sin4.sin_port; + tcp->dest = to.sin4.sin_port; + // TODO + } + else { + auto* udp = reinterpret_cast(ipv4 + 1); + ipv4->protocol = IPPROTO_UDP; + udp->source = from.sin4.sin_port; + udp->dest = to.sin4.sin_port; + udp->len = htons(payloadEnd - reinterpret_cast(udp)); + udp->check = 0; + udp->check = tcp_udp_v4_checksum(); + } + rewriteIpv4Header(ipv4); + } + else { + auto* ipv6 = reinterpret_cast(eth + 1); + memcpy(&ipv6->daddr, &to.sin6.sin6_addr, sizeof(ipv6->daddr)); + memcpy(&ipv6->saddr, &from.sin6.sin6_addr, sizeof(ipv6->saddr)); + if (flags & XskPacket::TCP) { + auto* tcp = reinterpret_cast(ipv6 + 1); + ipv6->nexthdr = IPPROTO_TCP; + tcp->source = from.sin6.sin6_port; + tcp->dest = to.sin6.sin6_port; + // TODO + } + else { + auto* udp = reinterpret_cast(ipv6 + 1); + ipv6->nexthdr = IPPROTO_UDP; + udp->source = from.sin6.sin6_port; + udp->dest = to.sin6.sin6_port; + udp->len = htons(payloadEnd - reinterpret_cast(udp)); + udp->check = 0; + udp->check = tcp_udp_v6_checksum(); + } + } +} + +[[nodiscard]] __be16 XskPacket::ipv4Checksum() const noexcept +{ + auto* ip = reinterpret_cast(frame + sizeof(ethhdr)); + return ip_checksum_fold(ip_checksum_partial(ip, sizeof(iphdr), 0)); +} +[[nodiscard]] __be16 XskPacket::tcp_udp_v4_checksum() const noexcept +{ + const auto* ip = reinterpret_cast(frame + sizeof(ethhdr)); + // ip options is not supported !!! 
+ const auto l4Length = static_cast(payloadEnd - l4Header); + auto sum = tcp_udp_v4_header_checksum_partial(ip->saddr, ip->daddr, ip->protocol, l4Length); + sum = ip_checksum_partial(l4Header, l4Length, sum); + return ip_checksum_fold(sum); +} +[[nodiscard]] __be16 XskPacket::tcp_udp_v6_checksum() const noexcept +{ + const auto* ipv6 = reinterpret_cast(frame + sizeof(ethhdr)); + const auto l4Length = static_cast(payloadEnd - l4Header); + uint64_t sum = tcp_udp_v6_header_checksum_partial(&ipv6->saddr, &ipv6->daddr, ipv6->nexthdr, l4Length); + sum = ip_checksum_partial(l4Header, l4Length, sum); + return ip_checksum_fold(sum); +} + +#ifndef __packed +#define __packed __attribute__((packed)) +#endif +[[nodiscard]] uint64_t XskPacket::ip_checksum_partial(const void* p, size_t len, uint64_t sum) noexcept +{ + /* Main loop: 32 bits at a time. + * We take advantage of intel's ability to do unaligned memory + * accesses with minimal additional cost. Other architectures + * probably want to be more careful here. + */ + const uint32_t* p32 = (const uint32_t*)(p); + for (; len >= sizeof(*p32); len -= sizeof(*p32)) + sum += *p32++; + + /* Handle un-32bit-aligned trailing bytes */ + const uint16_t* p16 = (const uint16_t*)(p32); + if (len >= 2) { + sum += *p16++; + len -= sizeof(*p16); + } + if (len > 0) { + const uint8_t* p8 = (const uint8_t*)(p16); + sum += ntohs(*p8 << 8); /* RFC says pad last byte */ + } + + return sum; +} +[[nodiscard]] __be16 XskPacket::ip_checksum_fold(uint64_t sum) noexcept +{ + while (sum & ~0xffffffffULL) + sum = (sum >> 32) + (sum & 0xffffffffULL); + while (sum & 0xffff0000ULL) + sum = (sum >> 16) + (sum & 0xffffULL); + + return ~sum; +} +[[nodiscard]] uint64_t XskPacket::tcp_udp_v4_header_checksum_partial(__be32 src_ip, __be32 dst_ip, uint8_t protocol, uint16_t len) noexcept +{ + struct header + { + __be32 src_ip; + __be32 dst_ip; + __uint8_t mbz; + __uint8_t protocol; + __be16 length; + }; + /* The IPv4 pseudo-header is defined in RFC 793, Section 3.1. */ + struct ipv4_pseudo_header_t + { + /* We use a union here to avoid aliasing issues with gcc -O2 */ + union + { + header __packed fields; + uint32_t words[3]; + }; + }; + struct ipv4_pseudo_header_t pseudo_header; + assert(sizeof(pseudo_header) == 12); + + /* Fill in the pseudo-header. */ + pseudo_header.fields.src_ip = src_ip; + pseudo_header.fields.dst_ip = dst_ip; + pseudo_header.fields.mbz = 0; + pseudo_header.fields.protocol = protocol; + pseudo_header.fields.length = htons(len); + return ip_checksum_partial(&pseudo_header, sizeof(pseudo_header), 0); +} +[[nodiscard]] uint64_t XskPacket::tcp_udp_v6_header_checksum_partial(const struct in6_addr* src_ip, const struct in6_addr* dst_ip, uint8_t protocol, uint32_t len) noexcept +{ + struct header + { + struct in6_addr src_ip; + struct in6_addr dst_ip; + __be32 length; + __uint8_t mbz[3]; + __uint8_t next_header; + }; + /* The IPv6 pseudo-header is defined in RFC 2460, Section 8.1. */ + struct ipv6_pseudo_header_t + { + /* We use a union here to avoid aliasing issues with gcc -O2 */ + union + { + header __packed fields; + uint32_t words[10]; + }; + }; + struct ipv6_pseudo_header_t pseudo_header; + assert(sizeof(pseudo_header) == 40); + + /* Fill in the pseudo-header. 
*/ + pseudo_header.fields.src_ip = *src_ip; + pseudo_header.fields.dst_ip = *dst_ip; + pseudo_header.fields.length = htonl(len); + memset(pseudo_header.fields.mbz, 0, sizeof(pseudo_header.fields.mbz)); + pseudo_header.fields.next_header = protocol; + return ip_checksum_partial(&pseudo_header, sizeof(pseudo_header), 0); +} +void XskPacket::setHeader(const PacketBuffer& buf) noexcept +{ + memcpy(frame, buf.data(), buf.size()); + payloadEnd = frame + buf.size(); + flags = 0; + const auto res = parse(); + assert(res); +} +std::unique_ptr XskPacket::cloneHeadertoPacketBuffer() const +{ + const auto size = payload - frame; + auto tmp = std::make_unique(size); + memcpy(tmp->data(), frame, size); + return tmp; +} +int XskWorker::createEventfd() +{ + auto fd = ::eventfd(0, EFD_CLOEXEC); + if (fd < 0) { + throw runtime_error("Unable create eventfd"); + } + return fd; +} +void XskWorker::waitForXskSocket() noexcept +{ + uint64_t x = read(workerWaker, &x, sizeof(x)); +} +void XskWorker::notifyXskSocket() noexcept +{ + notify(xskSocketWaker); +} + +std::shared_ptr XskWorker::create() +{ + return std::make_shared(); +} +void XskSocket::addWorker(std::shared_ptr s, const ComboAddress& dest, bool isTCP) +{ + extern std::atomic g_configurationDone; + if (g_configurationDone) { + throw runtime_error("Adding a server with xsk at runtime is not supported"); + } + s->poolName = poolName; + const auto socketWaker = s->xskSocketWaker.getHandle(); + const auto workerWaker = s->workerWaker.getHandle(); + const auto& socketWakerIdx = workers.get<0>(); + if (socketWakerIdx.contains(socketWaker)) { + throw runtime_error("Server already exist"); + } + s->umemBufBase = umem.bufBase; + workers.insert(XskRouteInfo{ + .worker = std::move(s), + .dest = dest, + .xskSocketWaker = socketWaker, + .workerWaker = workerWaker, + }); + fds.push_back(pollfd{ + .fd = socketWaker, + .events = POLLIN, + .revents = 0}); +}; +uint64_t XskWorker::frameOffset(const XskPacket& s) const noexcept +{ + return s.frame - umemBufBase; +} +void XskWorker::notifyWorker() noexcept +{ + notify(workerWaker); +} +void XskSocket::getMACFromIfName() +{ + ifreq ifr; + auto fd = ::socket(AF_INET, SOCK_DGRAM, 0); + strncpy(ifr.ifr_name, ifName.c_str(), ifName.length() + 1); + if (ioctl(fd, SIOCGIFHWADDR, &ifr) < 0) { + throw runtime_error("Error get MAC addr"); + } + memcpy(source, ifr.ifr_hwaddr.sa_data, sizeof(source)); + close(fd); +} +[[nodiscard]] int XskSocket::timeDifference(const timespec& t1, const timespec& t2) noexcept +{ + const auto res = t1.tv_sec * 1000 + t1.tv_nsec / 1000000L - (t2.tv_sec * 1000 + t2.tv_nsec / 1000000L); + return static_cast(res); +} +void XskWorker::cleanWorkerNotification() noexcept +{ + uint64_t x = read(xskSocketWaker, &x, sizeof(x)); +} +void XskWorker::cleanSocketNotification() noexcept +{ + uint64_t x = read(workerWaker, &x, sizeof(x)); +} +std::vector getPollFdsForWorker(XskWorker& info) +{ + std::vector fds; + int timerfd = timerfd_create(CLOCK_MONOTONIC, TFD_CLOEXEC); + if (timerfd < 0) { + throw std::runtime_error("create_timerfd failed"); + } + fds.push_back(pollfd{ + .fd = info.workerWaker, + .events = POLLIN, + .revents = 0, + }); + fds.push_back(pollfd{ + .fd = timerfd, + .events = POLLIN, + .revents = 0, + }); + return fds; +} +void XskWorker::fillUniqueEmptyOffset() +{ + auto frames = sharedEmptyFrameOffset->lock(); + const auto moveSize = std::min(static_cast(32), frames->size()); + if (moveSize > 0) { + uniqueEmptyFrameOffset.insert(uniqueEmptyFrameOffset.end(), std::make_move_iterator(frames->end() - 
moveSize), std::make_move_iterator(frames->end())); + } +} +void* XskWorker::getEmptyframe() +{ + if (!uniqueEmptyFrameOffset.empty()) { + auto offset = uniqueEmptyFrameOffset.back(); + uniqueEmptyFrameOffset.pop_back(); + return offset + umemBufBase; + } + fillUniqueEmptyOffset(); + if (!uniqueEmptyFrameOffset.empty()) { + auto offset = uniqueEmptyFrameOffset.back(); + uniqueEmptyFrameOffset.pop_back(); + return offset + umemBufBase; + } + return nullptr; +} +uint32_t XskPacket::getFlags() const noexcept +{ + return flags; +} +void XskPacket::updatePacket() noexcept +{ + if (!(flags & UPDATE)) { + return; + } + if (!(flags & REWRITE)) { + changeDirectAndUpdateChecksum(); + } +} +#endif /* HAVE_XSK */ diff --git a/pdns/xsk.hh b/pdns/xsk.hh new file mode 100644 index 0000000000..83c957dc32 --- /dev/null +++ b/pdns/xsk.hh @@ -0,0 +1,243 @@ +/* + * This file is part of PowerDNS or dnsdist. + * Copyright -- PowerDNS.COM B.V. and its contributors + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of version 2 of the GNU General Public License as + * published by the Free Software Foundation. + * + * In addition, for the avoidance of any doubt, permission is granted to + * link this program with OpenSSL and to (re)distribute the binaries + * produced as the result of such linking. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +#pragma once +#include "iputils.hh" +#include "misc.hh" +#include "noinitvector.hh" +#include "lock.hh" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef HAVE_XSK +#include +#endif /* HAVE_XSK */ + +class XskPacket; +class XskWorker; +class XskSocket; + +#ifdef HAVE_XSK +using XskPacketPtr = std::unique_ptr; + +// We use an XskSocket to manage an AF_XDP Socket corresponding to a NIC queue. +// The XDP program running in the kernel redirects the data to the XskSocket in userspace. +// XskSocket routes packets to multiple worker threads registered on XskSocket via XskSocket::addWorker based on the destination port number of the packet. +// The kernel and the worker thread holding XskWorker will wake up the XskSocket through XskFd and the Eventfd corresponding to each worker thread, respectively. 
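//
// Illustrative sketch of that receive/route/send cycle (not part of this
// commit; the interface name "eth0", queue id 0, pool name, batch size and
// timeout below are placeholders, only the pinned map path comes from the
// XDP program in contrib/xdp-filter.ebpf.src, and error handling is omitted):
//
//   XskSocket xsk(4096, "eth0", 0, "/sys/fs/bpf/dnsdist/xskmap", "pool");
//   for (;;) {
//     xsk.wait(1000 /* ms */);
//     uint32_t failed = 0;
//     auto packets = xsk.recv(64, &failed);  // frames redirected to us by the XDP program
//     for (auto& packet : packets) {
//       // ... hand the packet to the XskWorker registered (via addWorker())
//       // for packet->getToAddr(), which writes the response into the frame
//       // and calls packet->updatePacket() ...
//     }
//     xsk.send(packets);                     // queue the processed frames for TX
//   }
//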
+class XskSocket +{ + struct XskRouteInfo + { + std::shared_ptr worker; + ComboAddress dest; + int xskSocketWaker; + int workerWaker; + }; + struct XskUmem + { + xsk_umem* umem{nullptr}; + uint8_t* bufBase{nullptr}; + size_t size; + void umemInit(size_t memSize, xsk_ring_cons* cq, xsk_ring_prod* fq, xsk_umem_config* config); + ~XskUmem(); + XskUmem() = default; + }; + boost::multi_index_container< + XskRouteInfo, + boost::multi_index::indexed_by< + boost::multi_index::hashed_unique>, + boost::multi_index::hashed_unique, ComboAddress::addressPortOnlyHash>>> + workers; + static constexpr size_t holdThreshold = 256; + static constexpr size_t fillThreshold = 128; + static constexpr size_t frameSize = 2048; + const size_t frameNum; + const uint32_t queueId; + std::priority_queue waitForDelay; + const std::string ifName; + const std::string poolName; + vector fds; + vector uniqueEmptyFrameOffset; + xsk_ring_cons cq; + xsk_ring_cons rx; + xsk_ring_prod fq; + xsk_ring_prod tx; + std::unique_ptr socket; + XskUmem umem; + bpf_object* prog; + + static constexpr uint32_t fqCapacity = XSK_RING_PROD__DEFAULT_NUM_DESCS * 4; + static constexpr uint32_t cqCapacity = XSK_RING_CONS__DEFAULT_NUM_DESCS * 4; + static constexpr uint32_t rxCapacity = XSK_RING_CONS__DEFAULT_NUM_DESCS * 2; + static constexpr uint32_t txCapacity = XSK_RING_PROD__DEFAULT_NUM_DESCS * 2; + + constexpr static bool isPowOfTwo(uint32_t value) noexcept; + [[nodiscard]] static int timeDifference(const timespec& t1, const timespec& t2) noexcept; + friend void XskRouter(std::shared_ptr xsk); + + [[nodiscard]] uint64_t frameOffset(const XskPacket& packet) const noexcept; + int firstTimeout(); + void fillFq(uint32_t fillSize = fillThreshold) noexcept; + void recycle(size_t size) noexcept; + void getMACFromIfName(); + void pickUpReadyPacket(std::vector& packets); + +public: + std::shared_ptr>> sharedEmptyFrameOffset; + XskSocket(size_t frameNum, const std::string& ifName, uint32_t queue_id, const std::string& xskMapPath, const std::string& poolName_); + MACAddr source; + [[nodiscard]] int xskFd() const noexcept; + int wait(int timeout); + void send(std::vector& packets); + std::vector recv(uint32_t recvSizeMax, uint32_t* failedCount); + void addWorker(std::shared_ptr s, const ComboAddress& dest, bool isTCP); +}; +class XskPacket +{ +public: + enum Flags : uint32_t + { + TCP = 1 << 0, + UPDATE = 1 << 1, + DELAY = 1 << 3, + REWRITE = 1 << 4 + }; + +private: + ComboAddress from; + ComboAddress to; + timespec sendTime; + uint8_t* frame; + uint8_t* l4Header; + uint8_t* payload; + uint8_t* payloadEnd; + uint8_t* frameEnd; + uint32_t flags{0}; + + friend XskSocket; + friend XskWorker; + friend bool operator<(const XskPacketPtr& s1, const XskPacketPtr& s2) noexcept; + + constexpr static uint8_t DefaultTTL = 64; + bool parse(); + void changeDirectAndUpdateChecksum() noexcept; + + // You must set ipHeader.check = 0 before call this method + [[nodiscard]] __be16 ipv4Checksum() const noexcept; + // You must set l4Header.check = 0 before call this method + // ip options is not supported + [[nodiscard]] __be16 tcp_udp_v4_checksum() const noexcept; + // You must set l4Header.check = 0 before call this method + [[nodiscard]] __be16 tcp_udp_v6_checksum() const noexcept; + [[nodiscard]] static uint64_t ip_checksum_partial(const void* p, size_t len, uint64_t sum) noexcept; + [[nodiscard]] static __be16 ip_checksum_fold(uint64_t sum) noexcept; + [[nodiscard]] static uint64_t tcp_udp_v4_header_checksum_partial(__be32 src_ip, __be32 dst_ip, uint8_t protocol, uint16_t 
len) noexcept; + [[nodiscard]] static uint64_t tcp_udp_v6_header_checksum_partial(const struct in6_addr* src_ip, const struct in6_addr* dst_ip, uint8_t protocol, uint32_t len) noexcept; + void rewriteIpv4Header(void* ipv4header) noexcept; + void rewriteIpv6Header(void* ipv6header) noexcept; + +public: + [[nodiscard]] const ComboAddress& getFromAddr() const noexcept; + [[nodiscard]] const ComboAddress& getToAddr() const noexcept; + [[nodiscard]] const void* payloadData() const; + [[nodiscard]] bool isIPV6() const noexcept; + [[nodiscard]] size_t capacity() const noexcept; + [[nodiscard]] uint32_t dataLen() const noexcept; + [[nodiscard]] uint32_t FrameLen() const noexcept; + [[nodiscard]] PacketBuffer clonePacketBuffer() const; + void cloneIntoPacketBuffer(PacketBuffer& buffer) const; + [[nodiscard]] std::unique_ptr cloneHeadertoPacketBuffer() const; + [[nodiscard]] void* payloadData(); + void setAddr(const ComboAddress& from_, MACAddr fromMAC, const ComboAddress& to_, MACAddr toMAC, bool tcp = false) noexcept; + bool setPayload(const PacketBuffer& buf); + void rewrite() noexcept; + void setHeader(const PacketBuffer& buf) noexcept; + XskPacket() = default; + XskPacket(void* frame, size_t dataSize, size_t frameSize); + void addDelay(int relativeMilliseconds) noexcept; + void updatePacket() noexcept; + [[nodiscard]] uint32_t getFlags() const noexcept; +}; +bool operator<(const XskPacketPtr& s1, const XskPacketPtr& s2) noexcept; + +// XskWorker obtains XskPackets of specific ports in the NIC from XskSocket through cq. +// After finishing processing the packet, XskWorker puts the packet into sq so that XskSocket decides whether to send it through the network card according to XskPacket::flags. +// XskWorker wakes up XskSocket via xskSocketWaker after putting the packets in sq. +class XskWorker +{ + using XskPacketRing = boost::lockfree::spsc_queue>; + +public: + uint8_t* umemBufBase; + std::shared_ptr>> sharedEmptyFrameOffset; + vector uniqueEmptyFrameOffset; + XskPacketRing cq; + XskPacketRing sq; + std::string poolName; + size_t frameSize; + FDWrapper workerWaker; + FDWrapper xskSocketWaker; + + XskWorker(); + static int createEventfd(); + static void notify(int fd); + static std::shared_ptr create(); + void notifyWorker() noexcept; + void notifyXskSocket() noexcept; + void waitForXskSocket() noexcept; + void cleanWorkerNotification() noexcept; + void cleanSocketNotification() noexcept; + [[nodiscard]] uint64_t frameOffset(const XskPacket& s) const noexcept; + void fillUniqueEmptyOffset(); + void* getEmptyframe(); +}; +std::vector getPollFdsForWorker(XskWorker& info); +#else +class XskSocket +{ +}; +class XskPacket +{ +}; +class XskWorker +{ +}; + +#endif /* HAVE_XSK */
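
To make the intended wiring easier to follow, here is a minimal, hypothetical setup sketch built only from the interfaces declared above. It is not part of the patch: the interface name, queue id, pool name and listen address are placeholders, the worker's packet loop is reduced to comments, and the real dnsdist integration lives elsewhere in this commit.

#include "xsk.hh"

#ifdef HAVE_XSK
// Hypothetical configuration-time wiring: one XskSocket per NIC queue,
// one XskWorker per listening address.
void exampleXskWiring()
{
  auto xsk = std::make_shared<XskSocket>(4096, "eth0", 0,
                                         "/sys/fs/bpf/dnsdist/xskmap",
                                         "examplePool");
  auto worker = XskWorker::create();
  // Must happen before the configuration is finalized; packets whose
  // destination is 192.0.2.1:53 are then routed to this worker.
  xsk->addWorker(worker, ComboAddress("192.0.2.1:53"), false /* isTCP */);

  // A worker thread would then poll getPollFdsForWorker(*worker), pop
  // incoming XskPackets from worker->cq, write the response into each one
  // (XskPacket::setPayload() followed by XskPacket::updatePacket()), push it
  // to worker->sq and call worker->notifyXskSocket() so the socket thread
  // picks it up and transmits it. Both xsk and worker must of course outlive
  // this scope in a real setup.
}
#endif /* HAVE_XSK */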