BPF_TABLE_PINNED7("lpm_trie", struct CIDR4, struct map_value, cidr4filter, 1024, "/sys/fs/bpf/dnsdist/cidr4", BPF_F_NO_PREALLOC);
BPF_TABLE_PINNED7("lpm_trie", struct CIDR6, struct map_value, cidr6filter, 1024, "/sys/fs/bpf/dnsdist/cidr6", BPF_F_NO_PREALLOC);
+#ifdef UseXsk
+#define BPF_XSKMAP_PIN(_name, _max_entries, _pinned) \
+ struct _name##_table_t \
+ { \
+ u32 key; \
+ int leaf; \
+ int* (*lookup)(int*); \
+ /* xdp_act = map.redirect_map(index, flag) */ \
+ u64 (*redirect_map)(int, int); \
+ u32 max_entries; \
+ }; \
+ __attribute__((section("maps/xskmap:" _pinned))) struct _name##_table_t _name = {.max_entries = (_max_entries)}
+
+BPF_XSKMAP_PIN(xsk_map, 16, "/sys/fs/bpf/dnsdist/xskmap");
+#endif /* UseXsk */
+
+#define COMPARE_PORT(x, p) ((x) == bpf_htons(p))
+
/*
* Recalculate the checksum
* Copyright 2020, NLnet Labs, All rights reserved.
* Set the TC bit and swap UDP ports
* Copyright 2020, NLnet Labs, All rights reserved.
*/
-static inline enum dns_action set_tc_bit(struct udphdr *udp, struct dnshdr *dns)
+static inline void set_tc_bit(struct udphdr* udp, struct dnshdr* dns)
{
uint16_t old_val = dns->flags.as_value;
dns->flags.as_bits_and_pieces.tc = 1;
// change the UDP destination to the source
- udp->dest = udp->source;
- udp->source = bpf_htons(DNS_PORT);
+ uint16_t tmp = udp->dest;
+ udp->dest = udp->source;
+ udp->source = tmp;
// calculate and write the new checksum
update_checksum(&udp->check, old_val, dns->flags.as_value);
-
- // bounce
- return TC;
}
/*
* TC if (modified) message needs to be replied
 * DROP if message needs to be blocked
*/
-static inline enum dns_action check_qname(struct cursor *c)
+static inline struct map_value* check_qname(struct cursor* c)
{
struct dns_qname qkey = {0};
uint8_t qname_byte;
uint16_t qtype;
int length = 0;
- for(int i = 0; i<255; i++) {
- if (bpf_probe_read_kernel(&qname_byte, sizeof(qname_byte), c->pos)) {
- return PASS;
- }
- c->pos += 1;
- if (length == 0) {
- if (qname_byte == 0 || qname_byte > 63 ) {
- break;
+ for (int i = 0; i < 255; i++) {
+ if (bpf_probe_read_kernel(&qname_byte, sizeof(qname_byte), c->pos)) {
+ return NULL;
+ }
+ c->pos += 1;
+ if (length == 0) {
+ if (qname_byte == 0 || qname_byte > 63) {
+ break;
}
length += qname_byte;
- } else {
+ }
+ else {
length--;
}
- if (qname_byte >= 'A' && qname_byte <= 'Z') {
- qkey.qname[i] = qname_byte + ('a' - 'A');
- } else {
- qkey.qname[i] = qname_byte;
- }
+ if (qname_byte >= 'A' && qname_byte <= 'Z') {
+ qkey.qname[i] = qname_byte + ('a' - 'A');
+ }
+ else {
+ qkey.qname[i] = qname_byte;
+ }
}
 // if the last read qname byte is not 0 (incorrect QName format), return NULL (pass)
if (qname_byte != 0) {
- return PASS;
+ return NULL;
}
// get QType
- if(bpf_probe_read_kernel(&qtype, sizeof(qtype), c->pos)) {
- return PASS;
+ if (bpf_probe_read_kernel(&qtype, sizeof(qtype), c->pos)) {
+ return NULL;
}
struct map_value* value;
qkey.qtype = bpf_htons(qtype);
value = qnamefilter.lookup(&qkey);
if (value) {
- __sync_fetch_and_add(&value->counter, 1);
- return value->action;
+ return value;
}
// check with Qtype 255 (*)
qkey.qtype = 255;
- value = qnamefilter.lookup(&qkey);
- if (value) {
- __sync_fetch_and_add(&value->counter, 1);
- return value->action;
- }
-
- return PASS;
+ return qnamefilter.lookup(&qkey);
}
/*
 * Parse IPv4 DNS message.
- * Returns PASS if message needs to go through (i.e. pass)
- * TC if (modified) message needs to be replied
- * DROP if message needs to be blocked
+ * Returns XDP_PASS if message needs to go through (i.e. pass)
+ * XDP_REDIRECT if message needs to be redirected (for AF_XDP, which needs to be translated to the caller into XDP_PASS outside of the AF_XDP)
+ * XDP_TX if (modified) message needs to be replied
+ * XDP_DROP if message needs to be blocked
*/
-static inline enum dns_action udp_dns_reply_v4(struct cursor *c, struct CIDR4 *key)
+static inline enum xdp_action parseIPV4(struct xdp_md* ctx, struct cursor* c)
{
- struct udphdr *udp;
- struct dnshdr *dns;
+ struct iphdr* ipv4;
+ struct udphdr* udp = NULL;
+ struct dnshdr* dns = NULL;
+ if (!(ipv4 = parse_iphdr(c))) {
+ return XDP_PASS;
+ }
+ switch (ipv4->protocol) {
+ case IPPROTO_UDP: {
+ if (!(udp = parse_udphdr(c))) {
+ return XDP_PASS;
+ }
+ if (!IN_DNS_PORT_SET(udp->dest)) {
+ return XDP_PASS;
+ }
+ if (!(dns = parse_dnshdr(c))) {
+ return XDP_DROP;
+ }
+ break;
+ }
- if (!(udp = parse_udphdr(c)) || udp->dest != bpf_htons(DNS_PORT)) {
- return PASS;
+#ifdef UseXsk
+ case IPPROTO_TCP: {
+ struct tcphdr* tcp;
+ if (!(tcp = parse_tcphdr(c))) {
+ return XDP_PASS;
+ }
+ if (!IN_DNS_PORT_SET(tcp->dest)) {
+ return XDP_PASS;
+ }
+ }
+#endif /* UseXsk */
+
+ default:
+ return XDP_PASS;
}
- // check that we have a DNS packet
- if (!(dns = parse_dnshdr(c))) {
- return PASS;
- }
+ struct CIDR4 key;
+ key.addr = bpf_htonl(ipv4->saddr);
// if the address is blocked, perform the corresponding action
- struct map_value* value = v4filter.lookup(&key->addr);
+ struct map_value* value = v4filter.lookup(&key.addr);
if (value) {
- __sync_fetch_and_add(&value->counter, 1);
- if (value->action == TC) {
- return set_tc_bit(udp, dns);
- } else {
- return value->action;
- }
+ goto res;
+ }
+
+ key.cidr = 32;
+ key.addr = bpf_htonl(key.addr);
+ value = cidr4filter.lookup(&key);
+ if (value) {
+ goto res;
}
- key->cidr = 32;
- key->addr = bpf_htonl(key->addr);
- value = cidr4filter.lookup(key);
+ if (dns) {
+ value = check_qname(c);
+ }
if (value) {
+ res:
__sync_fetch_and_add(&value->counter, 1);
- if (value->action == TC) {
- return set_tc_bit(udp, dns);
+ if (value->action == TC && udp && dns) {
+ set_tc_bit(udp, dns);
+ // swap src/dest IP addresses
+ uint32_t swap_ipv4 = ipv4->daddr;
+ ipv4->daddr = ipv4->saddr;
+ ipv4->saddr = swap_ipv4;
+
+ progsarray.call(ctx, 1);
+ return XDP_TX;
}
- else {
- return value->action;
+
+ if (value->action == DROP) {
+ progsarray.call(ctx, 0);
+ return XDP_DROP;
}
}
- enum dns_action action = check_qname(c);
- if (action == TC) {
- return set_tc_bit(udp, dns);
- }
- return action;
+ return XDP_REDIRECT;
}
/*
 * Parse IPv6 DNS message.
- * Returns PASS if message needs to go through (i.e. pass)
- * TC if (modified) message needs to be replied
- * DROP if message needs to be blocked
+ * Returns XDP_PASS if message needs to go through (i.e. pass)
+ * XDP_REDIRECT if message needs to be redirected (for AF_XDP, which needs to be translated to the caller into XDP_PASS outside of the AF_XDP)
+ * XDP_TX if (modified) message needs to be replied
+ * XDP_DROP if message needs to be blocked
*/
-static inline enum dns_action udp_dns_reply_v6(struct cursor *c, struct CIDR6* key)
+static inline enum xdp_action parseIPV6(struct xdp_md* ctx, struct cursor* c)
{
- struct udphdr *udp;
- struct dnshdr *dns;
+ struct ipv6hdr* ipv6;
+ struct udphdr* udp = NULL;
+ struct dnshdr* dns = NULL;
+ if (!(ipv6 = parse_ipv6hdr(c))) {
+ return XDP_PASS;
+ }
+ switch (ipv6->nexthdr) {
+ case IPPROTO_UDP: {
+ if (!(udp = parse_udphdr(c))) {
+ return XDP_PASS;
+ }
+ if (!IN_DNS_PORT_SET(udp->dest)) {
+ return XDP_PASS;
+ }
+ if (!(dns = parse_dnshdr(c))) {
+ return XDP_DROP;
+ }
+ break;
+ }
-
- if (!(udp = parse_udphdr(c)) || udp->dest != bpf_htons(DNS_PORT)) {
- return PASS;
+#ifdef UseXsk
+ case IPPROTO_TCP: {
+ struct tcphdr* tcp;
+ if (!(tcp = parse_tcphdr(c))) {
+ return XDP_PASS;
+ }
+ if (!IN_DNS_PORT_SET(tcp->dest)) {
+ return XDP_PASS;
+ }
}
+#endif /* UseXsk */
- // check that we have a DNS packet
- ;
- if (!(dns = parse_dnshdr(c))) {
- return PASS;
+ default:
+ return XDP_PASS;
}
+ struct CIDR6 key;
+ key.addr = ipv6->saddr;
+
// if the address is blocked, perform the corresponding action
- struct map_value* value = v6filter.lookup(&key->addr);
+ struct map_value* value = v6filter.lookup(&key.addr);
+ if (value) {
+ goto res;
+ }
+ key.cidr = 128;
+ value = cidr6filter.lookup(&key);
if (value) {
- __sync_fetch_and_add(&value->counter, 1);
- if (value->action == TC) {
- return set_tc_bit(udp, dns);
- } else {
- return value->action;
- }
+ goto res;
}
- key->cidr = 128;
- value = cidr6filter.lookup(key);
+ if (dns) {
+ value = check_qname(c);
+ }
if (value) {
+ res:
__sync_fetch_and_add(&value->counter, 1);
- if (value->action == TC) {
- return set_tc_bit(udp, dns);
+ if (value->action == TC && udp && dns) {
+ set_tc_bit(udp, dns);
+ // swap src/dest IP addresses
+ struct in6_addr swap_ipv6 = ipv6->daddr;
+ ipv6->daddr = ipv6->saddr;
+ ipv6->saddr = swap_ipv6;
+ progsarray.call(ctx, 1);
+ return XDP_TX;
}
- else {
- return value->action;
+ if (value->action == DROP) {
+ progsarray.call(ctx, 0);
+ return XDP_DROP;
}
}
-
- enum dns_action action = check_qname(c);
- if (action == TC) {
- return set_tc_bit(udp, dns);
- }
- return action;
+ return XDP_REDIRECT;
}
int xdp_dns_filter(struct xdp_md* ctx)
struct cursor c;
struct ethhdr *eth;
uint16_t eth_proto;
- struct iphdr *ipv4;
- struct ipv6hdr *ipv6;
- int r = 0;
+ enum xdp_action r;
// initialise the cursor
cursor_init(&c, ctx);
// pass the packet if it is not an ethernet one
if ((eth = parse_eth(&c, ð_proto))) {
// IPv4 packets
- if (eth_proto == bpf_htons(ETH_P_IP))
- {
- if (!(ipv4 = parse_iphdr(&c)) || bpf_htons(ipv4->protocol != IPPROTO_UDP)) {
- return XDP_PASS;
- }
-
- struct CIDR4 key;
- key.addr = bpf_htonl(ipv4->saddr);
- // if TC bit must not be set, apply the action
- if ((r = udp_dns_reply_v4(&c, &key)) != TC) {
- if (r == DROP) {
- progsarray.call(ctx, 0);
- return XDP_DROP;
- }
- return XDP_PASS;
- }
-
- // swap src/dest IP addresses
- uint32_t swap_ipv4 = ipv4->daddr;
- ipv4->daddr = ipv4->saddr;
- ipv4->saddr = swap_ipv4;
+ if (eth_proto == bpf_htons(ETH_P_IP)) {
+ r = parseIPV4(ctx, &c);
+ goto res;
}
// IPv6 packets
else if (eth_proto == bpf_htons(ETH_P_IPV6)) {
- if (!(ipv6 = parse_ipv6hdr(&c)) || bpf_htons(ipv6->nexthdr != IPPROTO_UDP)) {
- return XDP_PASS;
- }
- struct CIDR6 key;
- key.addr = ipv6->saddr;
-
- // if TC bit must not be set, apply the action
- if ((r = udp_dns_reply_v6(&c, &key)) != TC) {
- if (r == DROP) {
- progsarray.call(ctx, 0);
- return XDP_DROP;
- }
- return XDP_PASS;
- }
-
- // swap src/dest IP addresses
- struct in6_addr swap_ipv6 = ipv6->daddr;
- ipv6->daddr = ipv6->saddr;
- ipv6->saddr = swap_ipv6;
+ r = parseIPV6(ctx, &c);
+ goto res;
}
// pass all non-IP packets
- else {
- return XDP_PASS;
- }
+ return XDP_PASS;
}
- else {
+ return XDP_PASS;
+res:
+ switch (r) {
+ case XDP_REDIRECT:
+#ifdef UseXsk
+ return xsk_map.redirect_map(ctx->rx_queue_index, 0);
+#else
return XDP_PASS;
+#endif /* UseXsk */
+ case XDP_TX: { // swap MAC addresses
+ uint8_t swap_eth[ETH_ALEN];
+ memcpy(swap_eth, eth->h_dest, ETH_ALEN);
+ memcpy(eth->h_dest, eth->h_source, ETH_ALEN);
+ memcpy(eth->h_source, swap_eth, ETH_ALEN);
+ // bounce the request
+ return XDP_TX;
+ }
+ default:
+ return r;
}
-
- // swap MAC addresses
- uint8_t swap_eth[ETH_ALEN];
- memcpy(swap_eth, eth->h_dest, ETH_ALEN);
- memcpy(eth->h_dest, eth->h_source, ETH_ALEN);
- memcpy(eth->h_source, swap_eth, ETH_ALEN);
-
- progsarray.call(ctx, 1);
-
- // bounce the request
- return XDP_TX;
}
blocked_qnames = [("localhost", "A", DROP_ACTION), ("test.com", "*", TC_ACTION)]
# Main
-xdp = BPF(src_file="xdp-filter.ebpf.src")
+useXsk = True
+Ports = [53]
+cflag = []
+if useXsk:
+ cflag.append("-DUseXsk")
+IN_DNS_PORT_SET = "||".join("COMPARE_PORT((x),"+str(i)+")" for i in Ports)
+cflag.append(r"-DIN_DNS_PORT_SET(x)=(" + IN_DNS_PORT_SET + r")")
+
+xdp = BPF(src_file="xdp-filter.ebpf.src", cflags=cflag)
fn = xdp.load_func("xdp_dns_filter", BPF.XDP)
xdp.attach_xdp(DEV, fn, 0)
struct bpf_insn;
-int bpf_create_map(enum bpf_map_type map_type, int key_size, int value_size,
- int max_entries, int map_flags);
-int bpf_update_elem(int fd, void *key, void *value, unsigned long long flags);
-int bpf_lookup_elem(int fd, void *key, void *value);
-int bpf_delete_elem(int fd, void *key);
-int bpf_get_next_key(int fd, void *key, void *next_key);
-
-int bpf_prog_load(enum bpf_prog_type prog_type,
- const struct bpf_insn *insns, int insn_len,
- const char *license, int kern_version);
-
-int bpf_obj_pin(int fd, const char *pathname);
-int bpf_obj_get(const char *pathname);
#define LOG_BUF_SIZE 65536
extern char bpf_log_buf[LOG_BUF_SIZE];
#include "misc.hh"
+int bpf_create_map(enum bpf_map_type map_type, int key_size, int value_size,
+ int max_entries, int map_flags);
+int bpf_update_elem(int fd, void *key, void *value, unsigned long long flags);
+int bpf_lookup_elem(int fd, void *key, void *value);
+int bpf_delete_elem(int fd, void *key);
+int bpf_get_next_key(int fd, void *key, void *next_key);
+
+int bpf_prog_load(enum bpf_prog_type prog_type,
+ const struct bpf_insn *insns, int insn_len,
+ const char *license, int kern_version);
+
+int bpf_obj_pin(int fd, const char *pathname);
+int bpf_obj_get(const char *pathname);
+
static __u64 ptr_to_u64(void *ptr)
{
return (__u64) (unsigned long) ptr;
std::unique_ptr<PacketBuffer> d_packet{nullptr}; // Initial packet, so we can restart the query from the response path if needed // 8
std::unique_ptr<ProtoBufData> d_protoBufData{nullptr};
std::unique_ptr<EDNSExtendedError> d_extendedError{nullptr};
+ std::unique_ptr<PacketBuffer> xskPacketHeader; // 8
boost::optional<uint32_t> tempFailureTTL{boost::none}; // 8
ClientState* cs{nullptr}; // 8
std::unique_ptr<DOHUnitInterface> du; // 8
#include "dnsdist-svc.hh"
#include "dolog.hh"
+#include "xsk.hh"
// NOLINTNEXTLINE(readability-function-cognitive-complexity): this function declares Lua bindings, even with a good refactoring it will likely blow up the threshold
void setupLuaBindings(LuaContext& luaCtx, bool client, bool configCheck)
}
});
#endif /* HAVE_EBPF */
-
+#ifdef HAVE_XSK
+ using xskopt_t = LuaAssociativeTable<boost::variant<uint32_t, std::string>>;
+ luaCtx.writeFunction("newXsk", [client](xskopt_t opts) {
+ if (g_configurationDone) {
+ throw std::runtime_error("newXsk() only can be used at configuration time!");
+ }
+ if (client) {
+ return std::shared_ptr<XskSocket>(nullptr);
+ }
+ uint32_t queue_id;
+ uint32_t frameNums;
+ std::string ifName;
+ std::string path;
+ std::string poolName;
+ if (opts.count("NIC_queue_id") == 1) {
+ queue_id = boost::get<uint32_t>(opts.at("NIC_queue_id"));
+ }
+ else {
+ throw std::runtime_error("NIC_queue_id field is required!");
+ }
+ if (opts.count("frameNums") == 1) {
+ frameNums = boost::get<uint32_t>(opts.at("frameNums"));
+ }
+ else {
+ throw std::runtime_error("frameNums field is required!");
+ }
+ if (opts.count("ifName") == 1) {
+ ifName = boost::get<std::string>(opts.at("ifName"));
+ }
+ else {
+ throw std::runtime_error("ifName field is required!");
+ }
+ if (opts.count("xskMapPath") == 1) {
+ path = boost::get<std::string>(opts.at("xskMapPath"));
+ }
+ else {
+ throw std::runtime_error("xskMapPath field is required!");
+ }
+ if (opts.count("pool") == 1) {
+ poolName = boost::get<std::string>(opts.at("pool"));
+ }
+ extern std::vector<std::shared_ptr<XskSocket>> g_xsk;
+ auto socket = std::make_shared<XskSocket>(frameNums, ifName, queue_id, path, poolName);
+ g_xsk.push_back(socket);
+ return socket;
+ });
+#endif /* HAVE_XSK */
/* EDNSOptionView */
luaCtx.registerFunction<size_t(EDNSOptionView::*)()const>("count", [](const EDNSOptionView& option) {
return option.values.size();
*/
#include <cstdint>
+#include <cstdio>
#include <dirent.h>
#include <fstream>
#include <cinttypes>
#include "dnsdist-ecs.hh"
#include "dnsdist-healthchecks.hh"
#include "dnsdist-lua.hh"
+#include "xsk.hh"
#ifdef LUAJIT_VERSION
#include "dnsdist-lua-ffi.hh"
#endif /* LUAJIT_VERSION */
g_noLuaSideEffect = boost::logic::indeterminate;
}
-using localbind_t = LuaAssociativeTable<boost::variant<bool, int, std::string, LuaArray<int>, LuaArray<std::string>, LuaAssociativeTable<std::string>>>;
+using localbind_t = LuaAssociativeTable<boost::variant<bool, int, std::string, LuaArray<int>, LuaArray<std::string>, LuaAssociativeTable<std::string>, std::shared_ptr<XskSocket>>>;
static void parseLocalBindVars(boost::optional<localbind_t>& vars, bool& reusePort, int& tcpFastOpenQueueSize, std::string& interface, std::set<int>& cpus, int& tcpListenQueueSize, uint64_t& maxInFlightQueriesPerConnection, uint64_t& tcpMaxConcurrentConnections, bool& enableProxyProtocol)
{
}
}
}
+#ifdef HAVE_XSK
+static void parseXskVars(boost::optional<localbind_t>& vars, std::shared_ptr<XskSocket>& socket)
+{
+ if (!vars) {
+ return;
+ }
+
+ getOptionalValue<std::shared_ptr<XskSocket>>(vars, "xskSocket", socket);
+}
+#endif /* HAVE_XSK */
#if defined(HAVE_DNS_OVER_TLS) || defined(HAVE_DNS_OVER_HTTPS) || defined(HAVE_DNS_OVER_QUIC)
static bool loadTLSCertificateAndKeys(const std::string& context, std::vector<TLSCertKeyPair>& pairs, const boost::variant<std::string, std::shared_ptr<TLSCertKeyPair>, LuaArray<std::string>, LuaArray<std::shared_ptr<TLSCertKeyPair>>>& certFiles, const LuaTypeOrArrayOf<std::string>& keyFiles)
// NOLINTNEXTLINE(readability-function-cognitive-complexity): this function declares Lua bindings, even with a good refactoring it will likely blow up the threshold
static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
{
- typedef LuaAssociativeTable<boost::variant<bool, std::string, LuaArray<std::string>, DownstreamState::checkfunc_t>> newserver_t;
+ typedef LuaAssociativeTable<boost::variant<bool, std::string, LuaArray<std::string>, std::shared_ptr<XskSocket>, DownstreamState::checkfunc_t>> newserver_t;
luaCtx.writeFunction("inClientStartup", [client]() {
return client && !g_configurationDone;
});
if (!(client || configCheck)) {
infolog("Added downstream server %s", ret->d_config.remote.toStringWithPort());
}
-
+#ifdef HAVE_XSK
+ std::shared_ptr<XskSocket> xskSocket;
+ if (getOptionalValue<std::shared_ptr<XskSocket>>(vars, "xskSocket", xskSocket) > 0) {
+ ret->registerXsk(xskSocket);
+ std::string mac;
+ if (getOptionalValue<std::string>(vars, "MACAddr", mac) != 1) {
+ throw runtime_error("field MACAddr is required!");
+ }
+ auto* addr = &ret->d_config.destMACAddr[0];
+ sscanf(mac.c_str(), "%hhx:%hhx:%hhx:%hhx:%hhx:%hhx", addr, addr + 1, addr + 2, addr + 3, addr + 4, addr + 5);
+ }
+#endif /* HAVE_XSK */
if (autoUpgrade && ret->getProtocol() != dnsdist::Protocol::DoT && ret->getProtocol() != dnsdist::Protocol::DoH) {
dnsdist::ServiceDiscovery::addUpgradeableServer(ret, upgradeInterval, upgradePool, upgradeDoHKey, keepAfterUpgrade);
}
}
// only works pre-startup, so no sync necessary
- g_frontends.push_back(std::make_unique<ClientState>(loc, false, reusePort, tcpFastOpenQueueSize, interface, cpus, enableProxyProtocol));
+ auto udpCS = std::make_unique<ClientState>(loc, false, reusePort, tcpFastOpenQueueSize, interface, cpus, enableProxyProtocol);
auto tcpCS = std::make_unique<ClientState>(loc, true, reusePort, tcpFastOpenQueueSize, interface, cpus, enableProxyProtocol);
if (tcpListenQueueSize > 0) {
tcpCS->tcpListenQueueSize = tcpListenQueueSize;
tcpCS->d_tcpConcurrentConnectionsLimit = tcpMaxConcurrentConnections;
}
+#ifdef HAVE_XSK
+ std::shared_ptr<XskSocket> socket;
+ parseXskVars(vars, socket);
+ if (socket) {
+ udpCS->xskInfo = XskWorker::create();
+ udpCS->xskInfo->sharedEmptyFrameOffset = socket->sharedEmptyFrameOffset;
+ socket->addWorker(udpCS->xskInfo, loc, false);
+ // tcpCS->xskInfo=XskWorker::create();
+ // TODO: socket->addWorker(tcpCS->xskInfo, loc, true);
+ }
+#endif /* HAVE_XSK */
+ g_frontends.push_back(std::move(udpCS));
g_frontends.push_back(std::move(tcpCS));
}
catch (const std::exception& e) {
try {
ComboAddress loc(addr, 53);
// only works pre-startup, so no sync necessary
- g_frontends.push_back(std::make_unique<ClientState>(loc, false, reusePort, tcpFastOpenQueueSize, interface, cpus, enableProxyProtocol));
+ auto udpCS = std::make_unique<ClientState>(loc, false, reusePort, tcpFastOpenQueueSize, interface, cpus, enableProxyProtocol);
auto tcpCS = std::make_unique<ClientState>(loc, true, reusePort, tcpFastOpenQueueSize, interface, cpus, enableProxyProtocol);
if (tcpListenQueueSize > 0) {
tcpCS->tcpListenQueueSize = tcpListenQueueSize;
if (tcpMaxConcurrentConnections > 0) {
tcpCS->d_tcpConcurrentConnectionsLimit = tcpMaxConcurrentConnections;
}
+#ifdef HAVE_XSK
+ std::shared_ptr<XskSocket> socket;
+ parseXskVars(vars, socket);
+ if (socket) {
+ udpCS->xskInfo = XskWorker::create();
+ udpCS->xskInfo->sharedEmptyFrameOffset = socket->sharedEmptyFrameOffset;
+ socket->addWorker(udpCS->xskInfo, loc, false);
+ // TODO tcpCS->xskInfo=XskWorker::create();
+ // TODO socket->addWorker(tcpCS->xskInfo, loc, true);
+ }
+#endif /* HAVE_XSK */
+ g_frontends.push_back(std::move(udpCS));
g_frontends.push_back(std::move(tcpCS));
}
catch (std::exception& e) {
if (!response.isAsync()) {
try {
auto& ids = response.d_idstate;
- unsigned int qnameWireLength{0};
std::shared_ptr<DownstreamState> backend = response.d_ds ? response.d_ds : (response.d_connection ? response.d_connection->getDS() : nullptr);
- if (backend == nullptr || !responseContentMatches(response.d_buffer, ids.qname, ids.qtype, ids.qclass, backend, qnameWireLength)) {
+ if (backend == nullptr || !responseContentMatches(response.d_buffer, ids.qname, ids.qtype, ids.qclass, backend)) {
state->terminateClientConnection();
return;
}
#include <limits>
#include <netinet/tcp.h>
#include <pwd.h>
+#include <set>
#include <sys/resource.h>
#include <unistd.h>
+#ifdef HAVE_XSK
+#include <sys/poll.h>
+#include <sys/timerfd.h>
+#endif /* HAVE_XSK */
+
#ifdef HAVE_LIBEDIT
#if defined (__OpenBSD__) || defined(__NetBSD__)
// If this is not undeffed, __attribute__ wil be redefined by /usr/include/readline/rlstdc.h
std::vector<std::shared_ptr<DOQFrontend>> g_doqlocals;
std::vector<std::shared_ptr<DOH3Frontend>> g_doh3locals;
std::vector<std::shared_ptr<DNSCryptContext>> g_dnsCryptLocals;
+std::vector<std::shared_ptr<XskSocket>> g_xsk;
shared_ptr<BPFFilter> g_defaultBPFFilter{nullptr};
std::vector<std::shared_ptr<DynBPFFilter> > g_dynBPFFilters;
}
}
-bool responseContentMatches(const PacketBuffer& response, const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const std::shared_ptr<DownstreamState>& remote, unsigned int& qnameWireLength)
+bool responseContentMatches(const PacketBuffer& response, const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const std::shared_ptr<DownstreamState>& remote)
{
if (response.size() < sizeof(dnsheader)) {
return false;
uint16_t rqtype, rqclass;
DNSName rqname;
try {
- rqname = DNSName(reinterpret_cast<const char*>(response.data()), response.size(), sizeof(dnsheader), false, &rqtype, &rqclass, &qnameWireLength);
+ rqname = DNSName(reinterpret_cast<const char*>(response.data()), response.size(), sizeof(dnsheader), false, &rqtype, &rqclass);
}
catch (const std::exception& e) {
if (remote && response.size() > 0 && static_cast<size_t>(response.size()) > sizeof(dnsheader)) {
}
bool muted = true;
- if (ids.cs && !ids.cs->muted) {
- ComboAddress empty;
- empty.sin4.sin_family = 0;
+ if (ids.cs && !ids.cs->muted && !ids.xskPacketHeader) {
sendUDPResponse(ids.cs->udpFD, response, dr.ids.delayMsec, ids.hopLocal, ids.hopRemote);
muted = false;
}
}
}
+#ifdef HAVE_XSK
+static void XskHealthCheck(std::shared_ptr<DownstreamState>& dss, std::unordered_map<uint16_t, std::shared_ptr<HealthCheckData>>& map, bool initial = false)
+{
+ auto& xskInfo = dss->xskInfo;
+ std::shared_ptr<HealthCheckData> data;
+ auto packet = getHealthCheckPacket(dss, nullptr, data);
+ data->d_initial = initial;
+ setHealthCheckTime(dss, data);
+ auto* frame = xskInfo->getEmptyframe();
+ auto *xskPacket = new XskPacket(frame, 0, xskInfo->frameSize);
+ xskPacket->setAddr(dss->d_config.sourceAddr, dss->d_config.sourceMACAddr, dss->d_config.remote, dss->d_config.destMACAddr);
+ xskPacket->setPayload(packet);
+ xskPacket->rewrite();
+ xskInfo->sq.push(xskPacket);
+ const auto queryId = data->d_queryID;
+ map[queryId] = std::move(data);
+}
+#endif /* HAVE_XSK */
+
+static bool processResponderPacket(std::shared_ptr<DownstreamState>& dss, PacketBuffer& response, const std::vector<DNSDistResponseRuleAction>& localRespRuleActions, const std::vector<DNSDistResponseRuleAction>& cacheInsertedRespRuleActions, InternalQueryState&& ids)
+{
+
+ const dnsheader_aligned dh(response.data());
+ auto queryId = dh->id;
+
+ if (!responseContentMatches(response, ids.qname, ids.qtype, ids.qclass, dss)) {
+ dss->restoreState(queryId, std::move(ids));
+ return false;
+ }
+
+ auto du = std::move(ids.du);
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(response, [&ids](dnsheader& header) {
+ header.id = ids.origID;
+ return true;
+ });
+ ++dss->responses;
+
+ double udiff = ids.queryRealTime.udiff();
+ // do that _before_ the processing, otherwise it's not fair to the backend
+ dss->latencyUsec = (127.0 * dss->latencyUsec / 128.0) + udiff / 128.0;
+ dss->reportResponse(dh->rcode);
+
+ /* don't call processResponse for DOH */
+ if (du) {
+#ifdef HAVE_DNS_OVER_HTTPS
+ // DoH query, we cannot touch du after that
+ DOHUnitInterface::handleUDPResponse(std::move(du), std::move(response), std::move(ids), dss);
+#endif
+ return false;
+ }
+
+ handleResponseForUDPClient(ids, response, localRespRuleActions, cacheInsertedRespRuleActions, dss, false, false);
+ return true;
+}
+
// listens on a dedicated socket, lobs answers from downstream servers to original requestors
void responderThread(std::shared_ptr<DownstreamState> dss)
{
setThreadName("dnsdist/respond");
auto localRespRuleActions = g_respruleactions.getLocal();
auto localCacheInsertedRespRuleActions = g_cacheInsertedRespRuleActions.getLocal();
+#ifdef HAVE_XSK
+ if (dss->xskInfo) {
+ auto xskInfo = dss->xskInfo;
+ auto pollfds = getPollFdsForWorker(*xskInfo);
+ std::unordered_map<uint16_t, std::shared_ptr<HealthCheckData>> healthCheckMap;
+ XskHealthCheck(dss, healthCheckMap, true);
+ itimerspec tm;
+ tm.it_value.tv_sec = dss->d_config.checkTimeout / 1000;
+ tm.it_value.tv_nsec = (dss->d_config.checkTimeout % 1000) * 1000000;
+ tm.it_interval = tm.it_value;
+ auto res = timerfd_settime(pollfds[1].fd, 0, &tm, nullptr);
+ if (res) {
+ throw std::runtime_error("timerfd_settime failed:" + stringerror(errno));
+ }
+ const auto xskFd = xskInfo->workerWaker.getHandle();
+ while (!dss->isStopped()) {
+ poll(pollfds.data(), pollfds.size(), -1);
+ bool needNotify = false;
+ if (pollfds[0].revents & POLLIN) {
+ needNotify = true;
+ xskInfo->cq.consume_all([&](XskPacket* packet) {
+ if (packet->dataLen() < sizeof(dnsheader)) {
+ xskInfo->sq.push(packet);
+ return;
+ }
+ const auto* dh = reinterpret_cast<const struct dnsheader*>(packet->payloadData());
+ const auto queryId = dh->id;
+ auto ids = dss->getState(queryId);
+ if (ids) {
+ if (xskFd != ids->backendFD || !ids->xskPacketHeader) {
+ dss->restoreState(queryId, std::move(*ids));
+ ids = std::nullopt;
+ }
+ }
+ if (!ids) {
+ // this has to go before we can refactor the duplicated response handling code
+ auto iter = healthCheckMap.find(queryId);
+ if (iter != healthCheckMap.end()) {
+ auto data = std::move(iter->second);
+ healthCheckMap.erase(iter);
+ packet->cloneIntoPacketBuffer(data->d_buffer);
+ data->d_ds->submitHealthCheckResult(data->d_initial, handleResponse(data));
+ }
+ xskInfo->sq.push(packet);
+ return;
+ }
+ auto response = packet->clonePacketBuffer();
+ if (!processResponderPacket(dss, response, *localRespRuleActions, *localCacheInsertedRespRuleActions, std::move(*ids))) {
+ xskInfo->sq.push(packet);
+ return;
+ }
+ packet->setHeader(*ids->xskPacketHeader);
+ packet->setPayload(response);
+ if (ids->delayMsec > 0) {
+ packet->addDelay(ids->delayMsec);
+ }
+ packet->updatePacket();
+ xskInfo->sq.push(packet);
+ });
+ xskInfo->cleanSocketNotification();
+ }
+ if (pollfds[1].revents & POLLIN) {
+ timeval now;
+ gettimeofday(&now, nullptr);
+ for (auto i = healthCheckMap.begin(); i != healthCheckMap.end();) {
+ auto& ttd = i->second->d_ttd;
+ if (ttd < now) {
+ dss->submitHealthCheckResult(i->second->d_initial, false);
+ i = healthCheckMap.erase(i);
+ }
+ else {
+ ++i;
+ }
+ }
+ needNotify = true;
+ dss->updateStatisticsInfo();
+ dss->handleUDPTimeouts();
+ if (dss->d_nextCheck <= 1) {
+ dss->d_nextCheck = dss->d_config.checkInterval;
+ if (dss->d_config.availability == DownstreamState::Availability::Auto) {
+ XskHealthCheck(dss, healthCheckMap);
+ }
+ }
+ else {
+ --dss->d_nextCheck;
+ }
+
+ uint64_t tmp;
+ res = read(pollfds[1].fd, &tmp, sizeof(tmp));
+ }
+ if (needNotify) {
+ xskInfo->notifyXskSocket();
+ }
+ }
+ }
+ else {
+#endif /* HAVE_XSK */
const size_t initialBufferSize = getInitialUDPPacketBufferSize(false);
/* allocate one more byte so we can detect truncation */
PacketBuffer response(initialBufferSize + 1);
for (const auto& fd : sockets) {
/* allocate one more byte so we can detect truncation */
- // NOLINTNEXTLINE(bugprone-use-after-move): resizing a vector has no preconditions so it is valid to do so after moving it
+ // NOLINTNEXTLINE(bugprone-use-after-move): resizing a vector has no preconditions so it is valid to do so after moving it
response.resize(initialBufferSize + 1);
ssize_t got = recv(fd, response.data(), response.size(), 0);
continue;
}
- unsigned int qnameWireLength = 0;
- if (fd != ids->backendFD || !responseContentMatches(response, ids->qname, ids->qtype, ids->qclass, dss, qnameWireLength)) {
+ if (fd != ids->backendFD) {
dss->restoreState(queryId, std::move(*ids));
continue;
}
- auto du = std::move(ids->du);
-
- dnsdist::PacketMangling::editDNSHeaderFromPacket(response, [&ids](dnsheader& header) {
- header.id = ids->origID;
- return true;
- });
- ++dss->responses;
-
- double udiff = ids->queryRealTime.udiff();
- // do that _before_ the processing, otherwise it's not fair to the backend
- dss->latencyUsec = (127.0 * dss->latencyUsec / 128.0) + udiff / 128.0;
- dss->reportResponse(dh->rcode);
-
- /* don't call processResponse for DOH */
- if (du) {
-#ifdef HAVE_DNS_OVER_HTTPS
- // DoH query, we cannot touch du after that
- DOHUnitInterface::handleUDPResponse(std::move(du), std::move(response), std::move(*ids), dss);
-#endif
- continue;
+ if (processResponderPacket(dss, response, *localRespRuleActions, *localCacheInsertedRespRuleActions, std::move(*ids)) && ids->xskPacketHeader && ids->cs->xskInfo) {
+#ifdef HAVE_XSK
+ auto& xskInfo = ids->cs->xskInfo;
+ auto* frame = xskInfo->getEmptyframe();
+ auto xskPacket = std::make_unique<XskPacket>(frame, 0, xskInfo->frameSize);
+ xskPacket->setHeader(*ids->xskPacketHeader);
+ xskPacket->setPayload(response);
+ xskPacket->updatePacket();
+ xskInfo->sq.push(xskPacket.release());
+ xskInfo->notifyXskSocket();
+#endif /* HAVE_XSK */
}
-
- handleResponseForUDPClient(*ids, response, *localRespRuleActions, *localCacheInsertedRespRuleActions, dss, false, false);
}
}
catch (const std::exception& e) {
vinfolog("Got an error in UDP responder thread while parsing a response from %s, id %d: %s", dss->d_config.remote.toStringWithPort(), queryId, e.what());
}
}
+#ifdef HAVE_XSK
+ }
+#endif /* HAVE_XSK */
}
catch (const std::exception& e) {
errlog("UDP responder thread died because of exception: %s", e.what());
return true;
}
+#ifdef HAVE_XSK
+static bool isXskQueryAcceptable(const XskPacket& packet, ClientState& cs, LocalHolders& holders, bool& expectProxyProtocol) noexcept
+{
+ const auto& from = packet.getFromAddr();
+ expectProxyProtocol = expectProxyProtocolFrom(from);
+ if (!holders.acl->match(from) && !expectProxyProtocol) {
+ vinfolog("Query from %s dropped because of ACL", from.toStringWithPort());
+ ++dnsdist::metrics::g_stats.aclDrops;
+ return false;
+ }
+ cs.queries++;
+ ++dnsdist::metrics::g_stats.queries;
+
+ return true;
+}
+#endif /* HAVE_XSK */
+
bool checkDNSCryptQuery(const ClientState& cs, PacketBuffer& query, std::unique_ptr<DNSCryptQuery>& dnsCryptQuery, time_t now, bool tcp)
{
if (cs.dnscryptCtx) {
++dq.ids.cs->responses;
return ProcessQueryResult::SendAnswer;
}
-
+#ifdef HAVE_XSK
+ if (dq.ids.cs->xskInfo) {
+ dq.ids.poolName = dq.ids.cs->xskInfo->poolName;
+ }
+#endif /* HAVE_XSK */
std::shared_ptr<ServerPool> serverPool = getPool(*holders.pools, dq.ids.poolName);
std::shared_ptr<ServerPolicy> poolPolicy = serverPool->policy;
dq.ids.packetCache = serverPool->packetCache;
return ProcessQueryResult::Drop;
}
-bool assignOutgoingUDPQueryToBackend(std::shared_ptr<DownstreamState>& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer& query)
+bool assignOutgoingUDPQueryToBackend(std::shared_ptr<DownstreamState>& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer& query, bool actuallySend)
{
bool doh = dq.ids.du != nullptr;
try {
int fd = ds->pickSocketForSending();
- dq.ids.backendFD = fd;
+ if (actuallySend) {
+ dq.ids.backendFD = fd;
+ }
dq.ids.origID = queryID;
dq.ids.forwardedOverUDP = true;
/* set the correct ID */
memcpy(&query.at(proxyProtocolPayloadSize), &idOffset, sizeof(idOffset));
+ if (!actuallySend) {
+ return true;
+ }
+
/* you can't touch ids or du after this line, unless the call returned a non-negative value,
because it might already have been freed */
ssize_t ret = udpClientSendRequestToBackend(ds, fd, query);
}
}
+#ifdef HAVE_XSK
+/* Process one query received over an AF_XDP (XSK) socket.
+ Mirrors the regular UDP query path: ACL / proxy-protocol / DNSCrypt checks,
+ header sanity checks, then the rules engine. A self-answered response
+ (DNSCrypt cert, NotImp, rule answer, cache hit) is written back into the
+ XSK frame via packet.setPayload(); returning without touching the packet
+ means the frame is dropped by the caller. */
+static void ProcessXskQuery(ClientState& cs, LocalHolders& holders, XskPacket& packet)
+{
+ uint16_t queryId = 0;
+ const auto& remote = packet.getFromAddr();
+ const auto& dest = packet.getToAddr();
+ InternalQueryState ids;
+ ids.cs = &cs;
+ ids.origRemote = remote;
+ ids.hopRemote = remote;
+ ids.origDest = dest;
+ ids.hopLocal = dest;
+ ids.protocol = dnsdist::Protocol::DoUDP;
+ // keep a copy of the L2/L3/L4 header so the response frame can be rebuilt later
+ ids.xskPacketHeader = packet.cloneHeadertoPacketBuffer();
+
+ try {
+ bool expectProxyProtocol = false;
+ if (!isXskQueryAcceptable(packet, cs, holders, expectProxyProtocol)) {
+ return;
+ }
+
+ auto query = packet.clonePacketBuffer();
+ std::vector<ProxyProtocolValue> proxyProtocolValues;
+ if (expectProxyProtocol && !handleProxyProtocol(remote, false, *holders.acl, query, ids.origRemote, ids.origDest, proxyProtocolValues)) {
+ return;
+ }
+
+ ids.queryRealTime.start();
+
+ // a non-empty response here means a DNSCrypt certificate exchange: answer in place
+ auto dnsCryptResponse = checkDNSCryptQuery(cs, query, ids.dnsCryptQuery, ids.queryRealTime.d_start.tv_sec, false);
+ if (dnsCryptResponse) {
+ packet.setPayload(query);
+ return;
+ }
+
+ {
+ /* this pointer will be invalidated the second the buffer is resized, don't hold onto it! */
+ dnsheader_aligned dnsHeader(query.data());
+ queryId = ntohs(dnsHeader->id);
+
+ if (!checkQueryHeaders(dnsHeader.get(), cs)) {
+ return;
+ }
+
+ // no question section: answer NotImp directly from the frame
+ if (dnsHeader->qdcount == 0) {
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(query, [](dnsheader& header) {
+ header.rcode = RCode::NotImp;
+ header.qr = true;
+ return true;
+ });
+ packet.setPayload(query);
+ return;
+ }
+ }
+
+ ids.qname = DNSName(reinterpret_cast<const char*>(query.data()), query.size(), sizeof(dnsheader), false, &ids.qtype, &ids.qclass);
+ if (ids.origDest.sin4.sin_family == 0) {
+ ids.origDest = cs.local;
+ }
+ if (ids.dnsCryptQuery) {
+ ids.protocol = dnsdist::Protocol::DNSCryptUDP;
+ }
+ DNSQuestion dq(ids, query);
+ if (!proxyProtocolValues.empty()) {
+ dq.proxyProtocolValues = make_unique<std::vector<ProxyProtocolValue>>(std::move(proxyProtocolValues));
+ }
+ std::shared_ptr<DownstreamState> ss{nullptr};
+ auto result = processQuery(dq, holders, ss);
+
+ if (result == ProcessQueryResult::Drop) {
+ return;
+ }
+
+ if (result == ProcessQueryResult::SendAnswer) {
+ packet.setPayload(query);
+ if (dq.ids.delayMsec > 0) {
+ packet.addDelay(dq.ids.delayMsec);
+ }
+ return;
+ }
+
+ if (result != ProcessQueryResult::PassToBackend || ss == nullptr) {
+ return;
+ }
+
+ // the buffer might have been invalidated by now (resized)
+ const auto dh = dq.getHeader();
+ if (ss->isTCPOnly()) {
+ std::string proxyProtocolPayload;
+ /* we need to do this _before_ creating the cross protocol query because
+ after that the buffer will have been moved */
+ if (ss->d_config.useProxyProtocol) {
+ proxyProtocolPayload = getProxyProtocolPayload(dq);
+ }
+
+ ids.origID = dh->id;
+ auto cpq = std::make_unique<UDPCrossProtocolQuery>(std::move(query), std::move(ids), ss);
+ cpq->query.d_proxyProtocolPayload = std::move(proxyProtocolPayload);
+
+ ss->passCrossProtocolQuery(std::move(cpq));
+ return;
+ }
+
+ /* two forwarding paths: a regular-UDP backend gets the query over a
+ kernel socket (actuallySend=true), while an XSK-enabled backend only
+ registers the in-flight state (actuallySend=false) and the frame itself
+ is re-addressed and sent out through the AF_XDP TX ring */
+ if (!ss->xskInfo) {
+ assignOutgoingUDPQueryToBackend(ss, dh->id, dq, query, true);
+ }
+ else {
+ int fd = ss->xskInfo->workerWaker;
+ ids.backendFD = fd;
+ assignOutgoingUDPQueryToBackend(ss, dh->id, dq, query, false);
+ packet.setAddr(ss->d_config.sourceAddr,ss->d_config.sourceMACAddr, ss->d_config.remote,ss->d_config.destMACAddr);
+ packet.setPayload(query);
+ packet.rewrite();
+ }
+ }
+ catch (const std::exception& e) {
+ vinfolog("Got an error in UDP question thread while parsing a query from %s, id %d: %s", remote.toStringWithPort(), queryId, e.what());
+ }
+}
+#endif /* HAVE_XSK */
+
#ifndef DISABLE_RECVMMSG
#if defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE)
static void MultipleMessagesUDPClientThread(ClientState* cs, LocalHolders& holders)
#endif /* defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) */
#endif /* DISABLE_RECVMMSG */
+#ifdef HAVE_XSK
+/* Per-frontend worker thread for AF_XDP (XSK) frontends: blocks until the
+ router signals that packets are available in the completion queue (cq),
+ processes each query, and pushes the (possibly rewritten) frame onto the
+ send queue (sq) before waking the XSK router once per batch. */
+static void xskClientThread(ClientState* cs)
+{
+ setThreadName("dnsdist/xskClient");
+ auto xskInfo = cs->xskInfo;
+ LocalHolders holders;
+
+ for (;;) {
+ // sleep on the worker eventfd until the router queues something for us
+ while (!xskInfo->cq.read_available()) {
+ xskInfo->waitForXskSocket();
+ }
+ xskInfo->cq.consume_all([&](XskPacket* packet) {
+ ProcessXskQuery(*cs, holders, *packet);
+ // updatePacket() marks the frame with the UPDATE/DELAY flags the router checks
+ packet->updatePacket();
+ xskInfo->sq.push(packet);
+ });
+ xskInfo->notifyXskSocket();
+ }
+}
+#endif /* HAVE_XSK */
+
// listens to incoming queries, sends out to downstream servers, noting the intended return path
static void udpClientThread(std::vector<ClientState*> states)
{
std::unique_ptr<FDMultiplexer> mplexer{nullptr};
for (auto& dss : *states) {
- auto delta = dss->sw.udiffAndSet()/1000000.0;
- dss->queryLoad.store(1.0*(dss->queries.load() - dss->prev.queries.load())/delta);
- dss->dropRate.store(1.0*(dss->reuseds.load() - dss->prev.reuseds.load())/delta);
- dss->prev.queries.store(dss->queries.load());
- dss->prev.reuseds.store(dss->reuseds.load());
+#ifdef HAVE_XSK
+ if (dss->xskInfo) {
+ continue;
+ }
+#endif /* HAVE_XSK */
+ dss->updateStatisticsInfo();
dss->handleUDPTimeouts();
}
}
+#ifdef HAVE_XSK
+void XskRouter(std::shared_ptr<XskSocket> xsk);
+#endif /* HAVE_XSK */
+
namespace dnsdist
{
static void startFrontends()
{
+#ifdef HAVE_XSK
+ for (auto& xskContext : g_xsk) {
+ std::thread xskThread(XskRouter, std::move(xskContext));
+ xskThread.detach();
+ }
+#endif /* HAVE_XSK */
+
std::vector<ClientState*> tcpStates;
std::vector<ClientState*> udpStates;
for (auto& clientState : g_frontends) {
+#ifdef HAVE_XSK
+ if (clientState->xskInfo) {
+ std::thread xskCT(xskClientThread, clientState.get());
+ if (!clientState->cpus.empty()) {
+ mapThreadToCPUList(xskCT.native_handle(), clientState->cpus);
+ }
+ xskCT.detach();
+ continue;
+ }
+#endif /* HAVE_XSK */
+
if (clientState->dohFrontend != nullptr && clientState->dohFrontend->d_library == "h2o") {
#ifdef HAVE_DNS_OVER_HTTPS
#ifdef HAVE_LIBH2OEVLOOP
auto states = g_dstates.getCopy(); // it is a copy, but the internal shared_ptrs are the real deal
auto mplexer = std::unique_ptr<FDMultiplexer>(FDMultiplexer::getMultiplexerSilent(states.size()));
for (auto& dss : states) {
+#ifdef HAVE_XSK
+ if (dss->xskInfo) {
+ continue;
+ }
+#endif /* HAVE_XSK */
+
if (dss->d_config.availability == DownstreamState::Availability::Auto || dss->d_config.availability == DownstreamState::Availability::Lazy) {
if (dss->d_config.availability == DownstreamState::Availability::Auto) {
dss->d_nextCheck = dss->d_config.checkInterval;
#endif
}
}
+
+#ifdef HAVE_XSK
+/* Event loop owning one AF_XDP socket: fds[0] is the XSK socket itself
+ (incoming frames), the remaining fds are the worker-waker eventfds.
+ Incoming frames are dispatched to the worker registered for their
+ destination address; frames coming back from workers are either sent,
+ delayed, or recycled into the empty-frame pool. */
+void XskRouter(std::shared_ptr<XskSocket> xsk)
+{
+ setThreadName("dnsdist/XskRouter");
+ uint32_t failed;
+ // packets to be submitted for sending
+ vector<XskPacketPtr> fillInTx;
+ const auto size = xsk->fds.size();
+ // list of workers that need to be notified
+ std::set<int> needNotify;
+ const auto& xskWakerIdx = xsk->workers.get<0>();
+ const auto& destIdx = xsk->workers.get<1>();
+ while (true) {
+ auto ready = xsk->wait(-1);
+ // descriptor 0 gets incoming AF_XDP packets
+ if (xsk->fds[0].revents & POLLIN) {
+ // frames that failed to parse are counted as non-compliant queries
+ auto packets = xsk->recv(64, &failed);
+ dnsdist::metrics::g_stats.nonCompliantQueries += failed;
+ for (auto &packet : packets) {
+ const auto dest = packet->getToAddr();
+ auto res = destIdx.find(dest);
+ if (res == destIdx.end()) {
+ // no worker listens on this address: recycle the frame
+ xsk->uniqueEmptyFrameOffset.push_back(xsk->frameOffset(*packet));
+ continue;
+ }
+ res->worker->cq.push(packet.release());
+ needNotify.insert(res->workerWaker);
+ }
+ for (auto i : needNotify) {
+ uint64_t x = 1;
+ auto written = write(i, &x, sizeof(x));
+ if (written != sizeof(x)) {
+ // oh, well, the worker is clearly overloaded
+ // but there is nothing we can do about it,
+ // and hopefully the queue will be processed eventually
+ }
+ }
+ needNotify.clear();
+ ready--;
+ }
+ const auto backup = ready;
+ for (size_t i = 1; i < size && ready > 0; i++) {
+ if (xsk->fds[i].revents & POLLIN) {
+ ready--;
+ // NOTE(review): assumes every polled fd is registered in the worker
+ // index -- find() result is dereferenced without an end() check; confirm
+ auto& info = xskWakerIdx.find(xsk->fds[i].fd)->worker;
+ info->sq.consume_all([&](XskPacket* x) {
+ if (!(x->getFlags() & XskPacket::UPDATE)) {
+ // untouched frame: nothing to send, recycle it
+ xsk->uniqueEmptyFrameOffset.push_back(xsk->frameOffset(*x));
+ return;
+ }
+ auto ptr = std::unique_ptr<XskPacket>(x);
+ if (x->getFlags() & XskPacket::DELAY) {
+ xsk->waitForDelay.push(std::move(ptr));
+ return;
+ }
+ fillInTx.push_back(std::move(ptr));
+ });
+ info->cleanWorkerNotification();
+ }
+ }
+ // move delayed packets whose deadline passed into the TX batch, then flush
+ xsk->pickUpReadyPacket(fillInTx);
+ xsk->recycle(64);
+ xsk->fillFq();
+ xsk->send(fillInTx);
+ // NOTE(review): this restore looks dead -- `ready` is re-assigned from
+ // wait() at the top of the loop before it is read again; confirm intent
+ ready = backup;
+ }
+}
+#endif /* HAVE_XSK */
#include "uuid-utils.hh"
#include "proxy-protocol.hh"
#include "stat_t.hh"
+#include "xsk.hh"
uint64_t uptimeOfProcess(const std::string& str);
std::shared_ptr<DOQFrontend> doqFrontend{nullptr};
std::shared_ptr<DOH3Frontend> doh3Frontend{nullptr};
std::shared_ptr<BPFFilter> d_filter{nullptr};
+ std::shared_ptr<XskWorker> xskInfo{nullptr};
size_t d_maxInFlightQueriesPerConn{1};
size_t d_tcpConcurrentConnectionsLimit{0};
int udpFD{-1};
std::string d_dohPath;
std::string name;
std::string nameWithAddr;
+ MACAddr sourceMACAddr;
+ MACAddr destMACAddr;
size_t d_numberOfSockets{1};
size_t d_maxInFlightQueriesPerConn{1};
size_t d_tcpConcurrentConnectionsLimit{0};
std::vector<int> sockets;
StopWatch sw;
QPSLimiter qps;
+ std::shared_ptr<XskWorker> xskInfo{nullptr};
std::atomic<uint64_t> idOffset{0};
size_t socketsOffset{0};
double latencyUsec{0.0};
uint8_t consecutiveSuccessfulChecks{0};
bool d_stopped{false};
public:
-
+ /* Refresh the derived per-backend load metrics: query rate and drop
+ (reuseds) rate over the interval since the previous call, then snapshot
+ the counters for the next interval.
+ NOTE(review): if called twice with (near) zero elapsed time, delta can be
+ 0 and the divisions produce inf/NaN -- presumably callers space this out;
+ confirm. */
+ void updateStatisticsInfo()
+ {
+ auto delta = sw.udiffAndSet() / 1000000.0; // microseconds -> seconds
+ queryLoad.store(1.0 * (queries.load() - prev.queries.load()) / delta);
+ dropRate.store(1.0 * (reuseds.load() - prev.reuseds.load()) / delta);
+ prev.queries.store(queries.load());
+ prev.reuseds.store(reuseds.load());
+ }
void start();
bool isUp() const
void restoreState(uint16_t id, InternalQueryState&&);
std::optional<InternalQueryState> getState(uint16_t id);
+#ifdef HAVE_XSK
+ /* Attach this backend to an AF_XDP (XSK) socket: create the per-backend
+ XSK worker, register it with the socket keyed on our source address,
+ adopt the interface MAC address and share the socket's empty-frame pool.
+ Requires d_config.sourceAddr to be set (needed to route response frames
+ back to us); throws otherwise. */
+ void registerXsk(std::shared_ptr<XskSocket>& xsk)
+ {
+ xskInfo = XskWorker::create();
+ if (d_config.sourceAddr.sin4.sin_family == 0) {
+ throw runtime_error("invalid source addr");
+ }
+ // non-UDP protocols register as "TCP" with the socket's worker index
+ xsk->addWorker(xskInfo, d_config.sourceAddr, getProtocol() != dnsdist::Protocol::DoUDP);
+ memcpy(d_config.sourceMACAddr, xsk->source, sizeof(MACAddr));
+ xskInfo->sharedEmptyFrameOffset = xsk->sharedEmptyFrameOffset;
+ }
+#endif /* HAVE_XSK */
+
dnsdist::Protocol getProtocol() const
{
if (isDoH()) {
bool getLuaNoSideEffect(); // set if there were only explicit declarations of _no_ side effect
void resetLuaSideEffect(); // reset to indeterminate state
-bool responseContentMatches(const PacketBuffer& response, const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const std::shared_ptr<DownstreamState>& remote, unsigned int& qnameWireLength);
+bool responseContentMatches(const PacketBuffer& response, const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const std::shared_ptr<DownstreamState>& remote);
bool checkQueryHeaders(const struct dnsheader* dh, ClientState& cs);
bool processRulesResult(const DNSAction::Action& action, DNSQuestion& dq, std::string& ruleresult, bool& drop);
bool processResponseAfterRules(PacketBuffer& response, const std::vector<DNSDistResponseRuleAction>& cacheInsertedRespRuleActions, DNSResponse& dr, bool muted);
-bool assignOutgoingUDPQueryToBackend(std::shared_ptr<DownstreamState>& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer& query);
+bool assignOutgoingUDPQueryToBackend(std::shared_ptr<DownstreamState>& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer& query, bool actuallySend = true);
ssize_t udpClientSendRequestToBackend(const std::shared_ptr<DownstreamState>& ss, const int sd, const PacketBuffer& request, bool healthCheck = false);
bool sendUDPResponse(int origFD, const PacketBuffer& response, const int delayMsec, const ComboAddress& origDest, const ComboAddress& origRemote);
tcpiohandler.cc tcpiohandler.hh \
threadname.hh threadname.cc \
uuid-utils.hh uuid-utils.cc \
- xpf.cc xpf.hh
+ xpf.cc xpf.hh \
+ xsk.cc xsk.hh
testrunner_SOURCES = \
base64.hh \
testrunner.cc \
threadname.hh threadname.cc \
uuid-utils.hh uuid-utils.cc \
- xpf.cc xpf.hh
+ xpf.cc xpf.hh \
+ xsk.cc xsk.hh
dnsdist_LDFLAGS = \
$(AM_LDFLAGS) \
dnsdist_LDADD += $(LIBSSL_LIBS)
endif
+if HAVE_XSK
+dnsdist_LDADD += -lbpf
+dnsdist_LDADD += -lxdp
+testrunner_LDADD += -lbpf
+testrunner_LDADD += -lxdp
+endif
+
if HAVE_LIBCRYPTO
dnsdist_LDADD += $(LIBCRYPTO_LDFLAGS) $(LIBCRYPTO_LIBS)
testrunner_LDADD += $(LIBCRYPTO_LDFLAGS) $(LIBCRYPTO_LIBS)
PDNS_WITH_RE2
DNSDIST_ENABLE_DNSCRYPT
PDNS_WITH_EBPF
+PDNS_WITH_XSK
PDNS_WITH_NET_SNMP
PDNS_WITH_LIBCAP
setUpStatus(newResult);
if (newResult == false) {
currentCheckFailures++;
- auto stats = d_lazyHealthCheckStats.lock();
- stats->d_status = LazyHealthCheckStats::LazyStatus::Failed;
- updateNextLazyHealthCheck(*stats, false);
+ if (d_config.availability == DownstreamState::Availability::Lazy) {
+ auto stats = d_lazyHealthCheckStats.lock();
+ stats->d_status = LazyHealthCheckStats::LazyStatus::Failed;
+ updateNextLazyHealthCheck(*stats, false);
+ }
}
return;
}
bool g_verboseHealthChecks{false};
-struct HealthCheckData
-{
- enum class TCPState : uint8_t
- {
- WritingQuery,
- ReadingResponseSize,
- ReadingResponse
- };
-
- HealthCheckData(FDMultiplexer& mplexer, std::shared_ptr<DownstreamState> downstream, DNSName&& checkName, uint16_t checkType, uint16_t checkClass, uint16_t queryID) :
- d_ds(std::move(downstream)), d_mplexer(mplexer), d_udpSocket(-1), d_checkName(std::move(checkName)), d_checkType(checkType), d_checkClass(checkClass), d_queryID(queryID)
- {
- }
-
- const std::shared_ptr<DownstreamState> d_ds;
- FDMultiplexer& d_mplexer;
- std::unique_ptr<TCPIOHandler> d_tcpHandler{nullptr};
- std::unique_ptr<IOStateHandler> d_ioState{nullptr};
- PacketBuffer d_buffer;
- Socket d_udpSocket;
- DNSName d_checkName;
- struct timeval d_ttd
- {
- 0, 0
- };
- size_t d_bufferPos{0};
- uint16_t d_checkType;
- uint16_t d_checkClass;
- uint16_t d_queryID;
- TCPState d_tcpState{TCPState::WritingQuery};
- bool d_initial{false};
-};
-
-static bool handleResponse(std::shared_ptr<HealthCheckData>& data)
+bool handleResponse(std::shared_ptr<HealthCheckData>& data)
{
const auto& downstream = data->d_ds;
try {
}
++data->d_ds->d_healthCheckMetrics.d_networkErrors;
data->d_ds->submitHealthCheckResult(data->d_initial, false);
- data->d_mplexer.removeReadFD(descriptor);
+ data->d_mplexer->removeReadFD(descriptor);
return;
}
} while (got < 0);
return;
}
- data->d_mplexer.removeReadFD(descriptor);
+ data->d_mplexer->removeReadFD(descriptor);
data->d_ds->submitHealthCheckResult(data->d_initial, handleResponse(data));
}
}
}
-bool queueHealthCheck(std::unique_ptr<FDMultiplexer>& mplexer, const std::shared_ptr<DownstreamState>& downstream, bool initialCheck)
+/* Build the wire-format health-check query for this backend, honoring the
+ configured check name/type/class and the optional Lua checkFunction
+ override, and create the matching HealthCheckData (out parameter `data`)
+ carrying the random query ID that the response will be matched against. */
+PacketBuffer getHealthCheckPacket(const std::shared_ptr<DownstreamState>& downstream, FDMultiplexer* mplexer, std::shared_ptr<HealthCheckData>& data)
{
- try {
- uint16_t queryID = dnsdist::getRandomDNSID();
- DNSName checkName = downstream->d_config.checkName;
- uint16_t checkType = downstream->d_config.checkType.getCode();
- uint16_t checkClass = downstream->d_config.checkClass;
- dnsheader checkHeader{};
- memset(&checkHeader, 0, sizeof(checkHeader));
-
- checkHeader.qdcount = htons(1);
- checkHeader.id = queryID;
-
- checkHeader.rd = true;
- if (downstream->d_config.setCD) {
- checkHeader.cd = true;
- }
+ uint16_t queryID = dnsdist::getRandomDNSID();
+ DNSName checkName = downstream->d_config.checkName;
+ uint16_t checkType = downstream->d_config.checkType.getCode();
+ uint16_t checkClass = downstream->d_config.checkClass;
+ dnsheader checkHeader{};
+ memset(&checkHeader, 0, sizeof(checkHeader));
+
+ checkHeader.qdcount = htons(1);
+ checkHeader.id = queryID;
+
+ checkHeader.rd = true;
+ if (downstream->d_config.setCD) {
+ checkHeader.cd = true;
+ }
- if (downstream->d_config.checkFunction) {
- auto lock = g_lua.lock();
- auto ret = downstream->d_config.checkFunction(checkName, checkType, checkClass, &checkHeader);
- checkName = std::get<0>(ret);
- checkType = std::get<1>(ret);
- checkClass = std::get<2>(ret);
- }
+ // the Lua check function may rewrite the qname/qtype/qclass and header
+ if (downstream->d_config.checkFunction) {
+ auto lock = g_lua.lock();
+ auto ret = downstream->d_config.checkFunction(checkName, checkType, checkClass, &checkHeader);
+ checkName = std::get<0>(ret);
+ checkType = std::get<1>(ret);
+ checkClass = std::get<2>(ret);
+ }
+ PacketBuffer packet;
+ GenericDNSPacketWriter<PacketBuffer> dpw(packet, checkName, checkType, checkClass);
+ // overwrite the writer's header with our prepared one (ID, RD, CD bits)
+ dnsheader* requestHeader = dpw.getHeader();
+ *requestHeader = checkHeader;
+ data = std::make_shared<HealthCheckData>(mplexer, downstream, std::move(checkName), checkType, checkClass, queryID);
+ return packet;
+}
- PacketBuffer packet;
- GenericDNSPacketWriter<PacketBuffer> dpw(packet, checkName, checkType, checkClass);
- dnsheader* requestHeader = dpw.getHeader();
- *requestHeader = checkHeader;
+/* Set the health check's time-to-die (d_ttd) to now plus the backend's
+ configured checkTimeout (expressed in milliseconds). */
+void setHealthCheckTime(const std::shared_ptr<DownstreamState>& downstream, const std::shared_ptr<HealthCheckData>& data)
+{
+ gettimeofday(&data->d_ttd, nullptr);
+ data->d_ttd.tv_sec += static_cast<decltype(data->d_ttd.tv_sec)>(downstream->d_config.checkTimeout / 1000); /* ms to seconds */
+ data->d_ttd.tv_usec += static_cast<decltype(data->d_ttd.tv_usec)>((downstream->d_config.checkTimeout % 1000) * 1000); /* remaining ms to us */
+ normalizeTV(data->d_ttd);
+}
+
+bool queueHealthCheck(std::unique_ptr<FDMultiplexer>& mplexer, const std::shared_ptr<DownstreamState>& downstream, bool initialCheck)
+{
+ try {
+ std::shared_ptr<HealthCheckData> data;
+ PacketBuffer packet = getHealthCheckPacket(downstream, mplexer.get(), data);
/* we need to compute that _before_ adding the proxy protocol payload */
uint16_t packetSize = packet.size();
sock.bind(downstream->d_config.sourceAddr, false);
}
- auto data = std::make_shared<HealthCheckData>(*mplexer, downstream, std::move(checkName), checkType, checkClass, queryID);
data->d_initial = initialCheck;
- gettimeofday(&data->d_ttd, nullptr);
- data->d_ttd.tv_sec += static_cast<decltype(data->d_ttd.tv_sec)>(downstream->d_config.checkTimeout / 1000); /* ms to seconds */
- data->d_ttd.tv_usec += static_cast<decltype(data->d_ttd.tv_usec)>((downstream->d_config.checkTimeout % 1000) * 1000); /* remaining ms to us */
- normalizeTV(data->d_ttd);
+ setHealthCheckTime(downstream, data);
if (!downstream->doHealthcheckOverTCP()) {
sock.connect(downstream->d_config.remote);
if (sent < 0) {
int ret = errno;
if (g_verboseHealthChecks) {
- infolog("Error while sending a health check query (ID %d) to backend %s: %d", queryID, downstream->getNameWithAddr(), ret);
+ infolog("Error while sending a health check query (ID %d) to backend %s: %d", data->d_queryID, downstream->getNameWithAddr(), ret);
}
return false;
}
#include "dnsdist.hh"
#include "mplexer.hh"
#include "sstuff.hh"
+#include "tcpiohandler-mplexer.hh"
extern bool g_verboseHealthChecks;
bool queueHealthCheck(std::unique_ptr<FDMultiplexer>& mplexer, const std::shared_ptr<DownstreamState>& downstream, bool initial = false);
void handleQueuedHealthChecks(FDMultiplexer& mplexer, bool initial = false);
+
+/* State of one in-flight health check against a backend, shared between the
+ code that sends the probe and the multiplexer callback that reads the
+ response. Moved to the header so XSK/health-check helpers can build it. */
+struct HealthCheckData
+{
+ enum class TCPState : uint8_t
+ {
+ WritingQuery,
+ ReadingResponseSize,
+ ReadingResponse
+ };
+
+ HealthCheckData(FDMultiplexer* mplexer, std::shared_ptr<DownstreamState> downstream, DNSName&& checkName, uint16_t checkType, uint16_t checkClass, uint16_t queryID) :
+ d_ds(std::move(downstream)), d_mplexer(mplexer), d_udpSocket(-1), d_checkName(std::move(checkName)), d_checkType(checkType), d_checkClass(checkClass), d_queryID(queryID)
+ {
+ }
+
+ const std::shared_ptr<DownstreamState> d_ds;
+ // non-owning; may be nullptr when the check is not multiplexer-driven
+ FDMultiplexer* d_mplexer{nullptr};
+ std::unique_ptr<TCPIOHandler> d_tcpHandler{nullptr};
+ std::unique_ptr<IOStateHandler> d_ioState{nullptr};
+ PacketBuffer d_buffer;
+ Socket d_udpSocket;
+ DNSName d_checkName;
+ // deadline after which the check counts as failed
+ struct timeval d_ttd
+ {
+ 0, 0
+ };
+ size_t d_bufferPos{0};
+ uint16_t d_checkType;
+ uint16_t d_checkClass;
+ // DNS ID the response must carry to be accepted
+ uint16_t d_queryID;
+ TCPState d_tcpState{TCPState::WritingQuery};
+ // true for the initial (startup) check, which is reported differently
+ bool d_initial{false};
+};
+
+PacketBuffer getHealthCheckPacket(const std::shared_ptr<DownstreamState>& ds, FDMultiplexer* mplexer, std::shared_ptr<HealthCheckData>& data);
+void setHealthCheckTime(const std::shared_ptr<DownstreamState>& ds, const std::shared_ptr<HealthCheckData>& data);
+bool handleResponse(std::shared_ptr<HealthCheckData>& data);
+
--- /dev/null
+dnl Detect AF_XDP (XSK) support: --with-xsk=yes/no/auto (default auto).
+dnl Defines HAVE_XSK and the HAVE_XSK automake conditional when the libxdp
+dnl headers (xdp/xsk.h) are available; errors out only when support was
+dnl explicitly requested but the headers are missing.
+AC_DEFUN([PDNS_WITH_XSK],[
+ AC_MSG_CHECKING([if we have xsk support])
+ AC_ARG_WITH([xsk],
+ AS_HELP_STRING([--with-xsk],[enable xsk support @<:@default=auto@:>@]),
+ [with_xsk=$withval],
+ [with_xsk=auto],
+ )
+ AC_MSG_RESULT([$with_xsk])
+
+ AS_IF([test "x$with_xsk" != "xno"], [
+ AS_IF([test "x$with_xsk" = "xyes" -o "x$with_xsk" = "xauto"], [
+ AC_CHECK_HEADERS([xdp/xsk.h], xsk_headers=yes, xsk_headers=no)
+ ])
+ ])
+ AS_IF([test "x$with_xsk" = "xyes"], [
+ dnl fixed: the value is x-prefixed, so it must be compared against "xno"
+ dnl (comparing against "no" could never match, silently skipping the error)
+ AS_IF([test x"$xsk_headers" = "xno"], [
+ AC_MSG_ERROR([XSK support requested but the required libxdp headers were not found])
+ ])
+ ])
+ AS_IF([test x"$xsk_headers" = "xyes" ], [ AC_DEFINE([HAVE_XSK], [1], [Define if using XSK (AF_XDP).]) ])
+ AM_CONDITIONAL([HAVE_XSK], [test x"$xsk_headers" = "xyes" ])
+])
return ProcessQueryResult::Drop;
}
-bool responseContentMatches(const PacketBuffer& response, const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const std::shared_ptr<DownstreamState>& remote, unsigned int& qnameWireLength)
+bool responseContentMatches(const PacketBuffer& response, const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const std::shared_ptr<DownstreamState>& remote)
{
return true;
}
--- /dev/null
+../xsk.cc
\ No newline at end of file
--- /dev/null
+../xsk.hh
\ No newline at end of file
#undef IP_PKTINFO
#endif
+using MACAddr = uint8_t[6];
union ComboAddress {
struct sockaddr_in sin4;
struct sockaddr_in6 sin6;
return rhs.operator<(*this);
}
+ /* Hash functor over address AND port (but not the IPv6 scope id),
+ suitable as the key hash for maps indexed by (address, port). */
+ struct addressPortOnlyHash
+ {
+ uint32_t operator()(const ComboAddress& ca) const
+ {
+ const unsigned char* start = nullptr;
+ if (ca.sin4.sin_family == AF_INET) {
+ start = reinterpret_cast<const unsigned char*>(&ca.sin4.sin_addr.s_addr);
+ auto tmp = burtle(start, 4, 0);
+ // fold the 16-bit port into the address hash
+ return burtle(reinterpret_cast<const uint8_t*>(&ca.sin4.sin_port), 2, tmp);
+ }
+ {
+ // NOTE(review): any non-AF_INET family is hashed as IPv6 -- assumes
+ // only AF_INET/AF_INET6 values ever reach this functor; confirm
+ start = reinterpret_cast<const unsigned char*>(&ca.sin6.sin6_addr.s6_addr);
+ auto tmp = burtle(start, 16, 0);
+ return burtle(reinterpret_cast<const unsigned char*>(&ca.sin6.sin6_port), 2, tmp);
+ }
+ }
+ };
+
struct addressOnlyHash
{
uint32_t operator()(const ComboAddress& ca) const
void truncate(unsigned int bits) noexcept;
- uint16_t getPort() const
+ uint16_t getNetworkOrderPort() const noexcept
{
- return ntohs(sin4.sin_port);
+ return sin4.sin_port;
+ }
+ uint16_t getPort() const noexcept
+ {
+ return ntohs(getNetworkOrderPort());
}
-
void setPort(uint16_t port)
{
sin4.sin_port = htons(port);
return false;
}
-bool assignOutgoingUDPQueryToBackend(std::shared_ptr<DownstreamState>& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer& query)
+bool assignOutgoingUDPQueryToBackend(std::shared_ptr<DownstreamState>& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer& query, bool)
{
return true;
}
return false;
}
+std::vector<std::shared_ptr<XskSocket>> g_xsk;
+
BOOST_AUTO_TEST_SUITE(test_dnsdist_cc)
static const uint16_t ECSSourcePrefixV4 = 24;
--- /dev/null
+/*
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+#include "gettime.hh"
+#include "xsk.hh"
+
+#include <algorithm>
+#include <cassert>
+#include <cstdint>
+#include <cstring>
+#include <fcntl.h>
+#include <iterator>
+#include <linux/bpf.h>
+#include <linux/if_ether.h>
+#include <linux/if_link.h>
+#include <linux/if_xdp.h>
+#include <linux/ip.h>
+#include <linux/ipv6.h>
+#include <linux/tcp.h>
+#include <linux/udp.h>
+#include <net/if.h>
+#include <netinet/in.h>
+#include <poll.h>
+#include <stdexcept>
+#include <sys/eventfd.h>
+#include <sys/ioctl.h>
+#include <sys/mman.h>
+#include <sys/socket.h>
+#include <sys/timerfd.h>
+#include <unistd.h>
+#include <vector>
+
+#ifdef HAVE_XSK
+#include <bpf/bpf.h>
+#include <bpf/libbpf.h>
+extern "C"
+{
+#include <xdp/libxdp.h>
+}
+
+/* True when value is a power of two (and non-zero): required by the
+ AF_XDP ring and umem size constraints. */
+constexpr bool XskSocket::isPowOfTwo(uint32_t value) noexcept
+{
+ return value != 0 && (value & (value - 1)) == 0;
+}
+/* Milliseconds until the earliest delayed packet is due: -1 when no packet
+ is waiting (poll() convention for "no timeout"), 0 when one is already
+ overdue. waitForDelay is a priority queue ordered by send time. */
+int XskSocket::firstTimeout()
+{
+ if (waitForDelay.empty()) {
+ return -1;
+ }
+ timespec now;
+ gettime(&now);
+ const auto& firstTime = waitForDelay.top()->sendTime;
+ const auto res = timeDifference(now, firstTime);
+ if (res <= 0) {
+ return 0;
+ }
+ return res;
+}
+/* Create an AF_XDP socket bound to (ifName_, queue_id): allocate the umem
+ (frameNum_ frames of frameSize bytes), create the four rings, pre-load the
+ fill queue with all frames, and register the socket fd in the pinned BPF
+ XSKMAP at xskMapPath (keyed by queue id) so the XDP program can redirect
+ packets to us. Throws std::runtime_error on any setup failure.
+ Error-message strings cleaned up: fixed grammar and added the failing
+ path/separators so the messages are actually readable in logs. */
+XskSocket::XskSocket(size_t frameNum_, const std::string& ifName_, uint32_t queue_id, const std::string& xskMapPath, const std::string& poolName_) :
+ frameNum(frameNum_), queueId(queue_id), ifName(ifName_), poolName(poolName_), socket(nullptr, xsk_socket__delete), sharedEmptyFrameOffset(std::make_shared<LockGuarded<vector<uint64_t>>>())
+{
+ if (!isPowOfTwo(frameNum_) || !isPowOfTwo(frameSize)
+ || !isPowOfTwo(fqCapacity) || !isPowOfTwo(cqCapacity) || !isPowOfTwo(rxCapacity) || !isPowOfTwo(txCapacity)) {
+ throw std::runtime_error("The number of frames, the frame size and the capacities of the rings must all be powers of 2");
+ }
+ getMACFromIfName();
+
+ memset(&cq, 0, sizeof(cq));
+ memset(&fq, 0, sizeof(fq));
+ memset(&tx, 0, sizeof(tx));
+ memset(&rx, 0, sizeof(rx));
+ xsk_umem_config umemCfg;
+ umemCfg.fill_size = fqCapacity;
+ umemCfg.comp_size = cqCapacity;
+ umemCfg.frame_size = frameSize;
+ umemCfg.frame_headroom = XSK_UMEM__DEFAULT_FRAME_HEADROOM;
+ umemCfg.flags = 0;
+ umem.umemInit(frameNum_ * frameSize, &cq, &fq, &umemCfg);
+ {
+ xsk_socket_config socketCfg;
+ socketCfg.rx_size = rxCapacity;
+ socketCfg.tx_size = txCapacity;
+ socketCfg.bind_flags = XDP_USE_NEED_WAKEUP;
+ socketCfg.xdp_flags = XDP_FLAGS_SKB_MODE;
+ // the XDP program is loaded separately; do not let libxdp load one
+ socketCfg.libxdp_flags = XSK_LIBBPF_FLAGS__INHIBIT_PROG_LOAD;
+ xsk_socket* tmp = nullptr;
+ auto ret = xsk_socket__create(&tmp, ifName.c_str(), queue_id, umem.umem, &rx, &tx, &socketCfg);
+ if (ret != 0) {
+ throw std::runtime_error("Error creating an AF_XDP socket on interface " + ifName + ": " + stringerror(ret));
+ }
+ socket = std::unique_ptr<xsk_socket, void (*)(xsk_socket*)>(tmp, xsk_socket__delete);
+ }
+ // every frame starts out empty; offsets skip the per-frame XDP headroom
+ for (uint64_t i = 0; i < frameNum; i++) {
+ uniqueEmptyFrameOffset.push_back(i * frameSize + XDP_PACKET_HEADROOM);
+ }
+ fillFq(fqCapacity);
+ const auto xskfd = xskFd();
+ fds.push_back(pollfd{
+ .fd = xskfd,
+ .events = POLLIN,
+ .revents = 0});
+ const auto xskMapFd = FDWrapper(bpf_obj_get(xskMapPath.c_str()));
+ if (xskMapFd.getHandle() < 0) {
+ throw std::runtime_error("Error getting the XSK BPF map from path '" + xskMapPath + "'");
+ }
+ auto ret = bpf_map_update_elem(xskMapFd.getHandle(), &queue_id, &xskfd, 0);
+ if (ret) {
+ throw std::runtime_error("Error inserting the socket into the XSK BPF map");
+ }
+}
+/* Refill the AF_XDP fill queue (fq) with up to fillSize empty frames so the
+ kernel can receive into them. First donates frames to the pool shared
+ with the workers when it runs low; then, only when we still hold at least
+ fillSize private frames, reserves fill-ring slots and hands them over.
+ Fix: xsk_ring_prod__submit() takes the NUMBER of produced entries, but the
+ code submitted `idx`, which after the loop is the (cumulative) producer
+ ring index returned by xsk_ring_prod__reserve() plus fillSize -- not a
+ count. Submitting it advances the producer by an arbitrary amount and
+ corrupts the fill ring. Submit fillSize instead. */
+void XskSocket::fillFq(uint32_t fillSize) noexcept
+{
+ {
+ // donate frames to the shared pool if it has dropped below the threshold
+ auto frames = sharedEmptyFrameOffset->lock();
+ if (frames->size() < holdThreshold) {
+ const auto moveSize = std::min(holdThreshold - frames->size(), uniqueEmptyFrameOffset.size());
+ if (moveSize > 0) {
+ frames->insert(frames->end(), std::make_move_iterator(uniqueEmptyFrameOffset.end() - moveSize), std::make_move_iterator(uniqueEmptyFrameOffset.end()));
+ }
+ }
+ }
+ if (uniqueEmptyFrameOffset.size() < fillSize) {
+ return;
+ }
+ uint32_t idx;
+ if (xsk_ring_prod__reserve(&fq, fillSize, &idx) != fillSize) {
+ return;
+ }
+ for (uint32_t i = 0; i < fillSize; i++) {
+ *xsk_ring_prod__fill_addr(&fq, idx++) = uniqueEmptyFrameOffset.back();
+ uniqueEmptyFrameOffset.pop_back();
+ }
+ // submit the count of entries produced, not the ring index
+ xsk_ring_prod__submit(&fq, fillSize);
+}
+/* poll() on all registered fds, capping the caller's timeout by the delay
+ of the earliest queued delayed packet. Both `timeout` and firstTimeout()
+ use -1 for "infinite"; casting to uint32_t maps -1 to UINT32_MAX so min()
+ picks the earlier real deadline, and UINT32_MAX casts back to -1 for poll.
+ NOTE(review): correct for all inputs >= -1, but relies on unsigned
+ wraparound -- fragile if either value could ever be < -1. */
+int XskSocket::wait(int timeout)
+{
+ return poll(fds.data(), fds.size(), static_cast<int>(std::min(static_cast<uint32_t>(timeout), static_cast<uint32_t>(firstTimeout()))));
+}
+/* Offset of the packet's frame inside the umem buffer, as used by the
+ fill/completion rings (the packet's frame pointer lives inside bufBase). */
+[[nodiscard]] uint64_t XskSocket::frameOffset(const XskPacket& packet) const noexcept
+{
+ return reinterpret_cast<uint64_t>(packet.frame) - reinterpret_cast<uint64_t>(umem.bufBase);
+}
+
+/* File descriptor of the underlying AF_XDP socket (pollable, fds[0]). */
+int XskSocket::xskFd() const noexcept { return xsk_socket__fd(socket.get()); }
+
+/* Submit a batch of packets to the TX ring. All-or-nothing: if the ring
+ cannot reserve slots for the whole batch we return without clearing
+ `packets`, so the caller keeps ownership and can retry on the next
+ iteration; on success the descriptors are submitted and the batch is
+ cleared (frames come back to us later via the completion ring). */
+void XskSocket::send(std::vector<XskPacketPtr>& packets)
+{
+ const auto packetSize = packets.size();
+ if (packetSize == 0) {
+ return;
+ }
+ uint32_t idx;
+ if (xsk_ring_prod__reserve(&tx, packetSize, &idx) != packets.size()) {
+ return;
+ }
+
+ for (const auto& i : packets) {
+ *xsk_ring_prod__tx_desc(&tx, idx++) = {
+ .addr = frameOffset(*i),
+ .len = i->FrameLen(),
+ .options = 0};
+ }
+ xsk_ring_prod__submit(&tx, packetSize);
+ packets.clear();
+}
+/* Drain up to recvSizeMax descriptors from the RX ring. Frames that fail
+ to parse as supported L2/L3/L4 packets are recycled straight back into the
+ private empty-frame pool and counted in *failedCount (when non-null);
+ successfully parsed packets are returned to the caller. */
+std::vector<XskPacketPtr> XskSocket::recv(uint32_t recvSizeMax, uint32_t* failedCount)
+{
+ uint32_t idx;
+ std::vector<XskPacketPtr> res;
+ // NOTE: recvSize is unsigned, so `<= 0` is effectively `== 0`
+ const auto recvSize = xsk_ring_cons__peek(&rx, recvSizeMax, &idx);
+ if (recvSize <= 0) {
+ return res;
+ }
+ const auto baseAddr = reinterpret_cast<uint64_t>(umem.bufBase);
+ uint32_t count = 0;
+ for (uint32_t i = 0; i < recvSize; i++) {
+ const auto* desc = xsk_ring_cons__rx_desc(&rx, idx++);
+ auto ptr = std::make_unique<XskPacket>(reinterpret_cast<void*>(desc->addr + baseAddr), desc->len, frameSize);
+ if (!ptr->parse()) {
+ ++count;
+ uniqueEmptyFrameOffset.push_back(frameOffset(*ptr));
+ }
+ else {
+ res.push_back(std::move(ptr));
+ }
+ }
+ xsk_ring_cons__release(&rx, recvSize);
+ if (failedCount) {
+ *failedCount = count;
+ }
+ return res;
+}
+/* Move every delayed packet whose send time has arrived from the delay
+ priority queue into `packets` (the pending TX batch).
+ The const_cast is needed because priority_queue::top() returns a const
+ reference; the element is popped immediately after being moved from. */
+void XskSocket::pickUpReadyPacket(std::vector<XskPacketPtr>& packets)
+{
+ timespec now;
+ gettime(&now);
+ while (!waitForDelay.empty() && timeDifference(now, waitForDelay.top()->sendTime) <= 0) {
+ auto& top = const_cast<XskPacketPtr&>(waitForDelay.top());
+ packets.push_back(std::move(top));
+ waitForDelay.pop();
+ }
+}
+/* Reclaim up to `size` transmitted frames from the completion ring (cq)
+ back into the private empty-frame pool so they can be reused for RX/TX.
+ NOTE: completeSize is unsigned, so `<= 0` is effectively `== 0`. */
+void XskSocket::recycle(size_t size) noexcept
+{
+ uint32_t idx;
+ const auto completeSize = xsk_ring_cons__peek(&cq, size, &idx);
+ if (completeSize <= 0) {
+ return;
+ }
+ for (uint32_t i = 0; i < completeSize; ++i) {
+ uniqueEmptyFrameOffset.push_back(*xsk_ring_cons__comp_addr(&cq, idx++));
+ }
+ xsk_ring_cons__release(&cq, completeSize);
+}
+
+/* Allocate the shared umem buffer (anonymous mmap of memSize bytes) and
+ register it with the kernel, wiring up the completion (cq) and fill (fq)
+ rings. Throws on failure; the mapping is released if registration fails. */
+void XskSocket::XskUmem::umemInit(size_t memSize, xsk_ring_cons* cq, xsk_ring_prod* fq, xsk_umem_config* config)
+{
+ size = memSize;
+ bufBase = static_cast<uint8_t*>(mmap(nullptr, size, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0));
+ if (bufBase == MAP_FAILED) {
+ // NOTE(review): bufBase is left equal to MAP_FAILED here, which the
+ // destructor's null check does not catch -- see ~XskUmem; confirm
+ throw std::runtime_error("mmap failed");
+ }
+ auto ret = xsk_umem__create(&umem, bufBase, size, fq, cq, config);
+ if (ret != 0) {
+ munmap(bufBase, size);
+ throw std::runtime_error("Error creating a umem of size" + std::to_string(size) + stringerror(ret));
+ }
+}
+
+/* Unregister the umem and release the backing mapping.
+ NOTE(review): `if (bufBase)` does not guard against MAP_FAILED
+ ((void*)-1), which umemInit leaves in bufBase when mmap fails -- that
+ path would call munmap(MAP_FAILED, size) (harmless EINVAL, but worth
+ resetting bufBase to nullptr on the mmap failure path; confirm). */
+XskSocket::XskUmem::~XskUmem()
+{
+ if (umem) {
+ xsk_umem__delete(umem);
+ }
+ if (bufBase) {
+ munmap(bufBase, size);
+ }
+}
+
+// Parse the raw Ethernet frame: fill in from/to addresses (with ports),
+// l4Header, payload and a tightened payloadEnd. Returns false for anything
+// we cannot handle (non-IP, IPv4 options, IPv6 extension headers, TCP with
+// options, truncated packets). Relies on the XDP program having already
+// validated checksums and minimum frame length — TODO confirm that contract.
+bool XskPacket::parse()
+{
+ // payloadEnd must bigger than payload + sizeof(ethhdr) + sizoef(iphdr) + sizeof(udphdr)
+ auto* eth = reinterpret_cast<ethhdr*>(frame);
+ uint8_t l4Protocol;
+ if (eth->h_proto == htons(ETH_P_IP)) {
+ auto* ip = reinterpret_cast<iphdr*>(eth + 1);
+ if (ip->ihl != static_cast<uint8_t>(sizeof(iphdr) >> 2)) {
+ // ip->ihl*4 != sizeof(iphdr)
+ // ip options is not supported now!
+ return false;
+ }
+ // check ip.check == ipv4Checksum() is not needed!
+ // We check it in BPF program
+ from = makeComboAddressFromRaw(4, reinterpret_cast<const char*>(&ip->saddr), sizeof(ip->saddr));
+ to = makeComboAddressFromRaw(4, reinterpret_cast<const char*>(&ip->daddr), sizeof(ip->daddr));
+ l4Protocol = ip->protocol;
+ l4Header = reinterpret_cast<uint8_t*>(ip + 1);
+ // clamp payloadEnd to the length the IP header claims
+ payloadEnd = std::min(reinterpret_cast<uint8_t*>(ip) + ntohs(ip->tot_len), payloadEnd);
+ }
+ else if (eth->h_proto == htons(ETH_P_IPV6)) {
+ auto* ipv6 = reinterpret_cast<ipv6hdr*>(eth + 1);
+ l4Header = reinterpret_cast<uint8_t*>(ipv6 + 1);
+ if (l4Header >= payloadEnd) {
+ return false;
+ }
+ from = makeComboAddressFromRaw(6, reinterpret_cast<const char*>(&ipv6->saddr), sizeof(ipv6->saddr));
+ to = makeComboAddressFromRaw(6, reinterpret_cast<const char*>(&ipv6->daddr), sizeof(ipv6->daddr));
+ l4Protocol = ipv6->nexthdr;
+ // IPv6 payload_len counts everything after the fixed header
+ payloadEnd = std::min(l4Header + ntohs(ipv6->payload_len), payloadEnd);
+ }
+ else {
+ return false;
+ }
+ if (l4Protocol == IPPROTO_UDP) {
+ // check udp.check == ipv4Checksum() is not needed!
+ // We check it in BPF program
+ auto* udp = reinterpret_cast<udphdr*>(l4Header);
+ payload = l4Header + sizeof(udphdr);
+ // Because of XskPacket::setHeader
+ // payload = payloadEnd should be allow
+ if (payload > payloadEnd) {
+ return false;
+ }
+ payloadEnd = std::min(l4Header + ntohs(udp->len), payloadEnd);
+ from.setPort(ntohs(udp->source));
+ to.setPort(ntohs(udp->dest));
+ return true;
+ }
+ if (l4Protocol == IPPROTO_TCP) {
+ // check tcp.check == ipv4Checksum() is not needed!
+ // We check it in BPF program
+ auto* tcp = reinterpret_cast<tcphdr*>(l4Header);
+ if (tcp->doff != static_cast<uint32_t>(sizeof(tcphdr) >> 2)) {
+ // tcp is not supported now!
+ return false;
+ }
+ payload = l4Header + sizeof(tcphdr);
+ //
+ if (payload > payloadEnd) {
+ return false;
+ }
+ from.setPort(ntohs(tcp->source));
+ to.setPort(ntohs(tcp->dest));
+ flags |= TCP;
+ return true;
+ }
+ // ipv6 extension header is not supported now!
+ return false;
+}
+
+// Length of the L4 payload (application data) in bytes.
+uint32_t XskPacket::dataLen() const noexcept
+{
+ return payloadEnd - payload;
+}
+// Length of the whole frame, from the Ethernet header to the end of payload.
+uint32_t XskPacket::FrameLen() const noexcept
+{
+ return payloadEnd - frame;
+}
+// Free space left in the umem frame after the current payload end.
+size_t XskPacket::capacity() const noexcept
+{
+ return frameEnd - payloadEnd;
+}
+
+// Turn the packet around in place: swap MAC addresses, swap L3 addresses and
+// L4 ports, refresh length fields and recompute checksums, so the frame can
+// be sent back to its origin. TCP checksum recomputation is still a TODO.
+void XskPacket::changeDirectAndUpdateChecksum() noexcept
+{
+ auto* eth = reinterpret_cast<ethhdr*>(frame);
+ {
+ // swap source and destination MAC addresses
+ uint8_t tmp[ETH_ALEN];
+ static_assert(sizeof(tmp) == sizeof(eth->h_dest), "Size Error");
+ static_assert(sizeof(tmp) == sizeof(eth->h_source), "Size Error");
+ memcpy(tmp, eth->h_dest, sizeof(tmp));
+ memcpy(eth->h_dest, eth->h_source, sizeof(tmp));
+ memcpy(eth->h_source, tmp, sizeof(tmp));
+ }
+ if (eth->h_proto == htons(ETH_P_IPV6)) {
+ // IPV6
+ auto* ipv6 = reinterpret_cast<ipv6hdr*>(eth + 1);
+ std::swap(ipv6->daddr, ipv6->saddr);
+ if (ipv6->nexthdr == IPPROTO_UDP) {
+ // UDP
+ auto* udp = reinterpret_cast<udphdr*>(ipv6 + 1);
+ std::swap(udp->dest, udp->source);
+ udp->len = htons(payloadEnd - reinterpret_cast<uint8_t*>(udp));
+ // checksum field must be zero while the checksum is computed
+ udp->check = 0;
+ udp->check = tcp_udp_v6_checksum();
+ }
+ else {
+ // TCP
+ auto* tcp = reinterpret_cast<tcphdr*>(ipv6 + 1);
+ std::swap(tcp->dest, tcp->source);
+ // TODO
+ }
+ rewriteIpv6Header(ipv6);
+ }
+ else {
+ // IPV4
+ auto* ipv4 = reinterpret_cast<iphdr*>(eth + 1);
+ std::swap(ipv4->daddr, ipv4->saddr);
+ if (ipv4->protocol == IPPROTO_UDP) {
+ // UDP
+ auto* udp = reinterpret_cast<udphdr*>(ipv4 + 1);
+ std::swap(udp->dest, udp->source);
+ udp->len = htons(payloadEnd - reinterpret_cast<uint8_t*>(udp));
+ // checksum field must be zero while the checksum is computed
+ udp->check = 0;
+ udp->check = tcp_udp_v4_checksum();
+ }
+ else {
+ // TCP
+ auto* tcp = reinterpret_cast<tcphdr*>(ipv4 + 1);
+ std::swap(tcp->dest, tcp->source);
+ // TODO
+ }
+ rewriteIpv4Header(ipv4);
+ }
+}
+// Rewrite all mutable IPv4 header fields (no options supported), derive
+// tot_len from the current payloadEnd, and recompute the header checksum.
+void XskPacket::rewriteIpv4Header(void* ipv4header) noexcept
+{
+ auto* ipv4 = static_cast<iphdr*>(ipv4header);
+ ipv4->version = 4;
+ ipv4->ihl = sizeof(iphdr) / 4;
+ ipv4->tos = 0;
+ ipv4->tot_len = htons(payloadEnd - reinterpret_cast<uint8_t*>(ipv4));
+ ipv4->id = 0;
+ ipv4->frag_off = 0;
+ ipv4->ttl = DefaultTTL;
+ // the checksum field must be zero while ipv4Checksum() runs
+ ipv4->check = 0;
+ ipv4->check = ipv4Checksum();
+}
+// Rewrite the fixed IPv6 header; payload_len counts everything after the
+// 40-byte fixed header. IPv6 has no header checksum to recompute.
+void XskPacket::rewriteIpv6Header(void* ipv6header) noexcept
+{
+ auto* ipv6 = static_cast<ipv6hdr*>(ipv6header);
+ ipv6->version = 6;
+ ipv6->priority = 0;
+ ipv6->payload_len = htons(payloadEnd - reinterpret_cast<uint8_t*>(ipv6 + 1));
+ ipv6->hop_limit = DefaultTTL;
+ memset(&ipv6->flow_lbl, 0, sizeof(ipv6->flow_lbl));
+}
+
+// True when the frame's EtherType is IPv6.
+bool XskPacket::isIPV6() const noexcept
+{
+ const auto* eth = reinterpret_cast<ethhdr*>(frame);
+ return eth->h_proto == htons(ETH_P_IPV6);
+}
+// Wrap an umem frame: `dataSize` is the received length, `frameSize` the umem
+// chunk size; the headroom at the end is reserved for XDP_PACKET_HEADROOM.
+XskPacket::XskPacket(void* frame_, size_t dataSize, size_t frameSize) :
+ frame(static_cast<uint8_t*>(frame_)), payloadEnd(static_cast<uint8_t*>(frame) + dataSize), frameEnd(static_cast<uint8_t*>(frame) + frameSize - XDP_PACKET_HEADROOM)
+{
+}
+// Copy the L4 payload into a freshly allocated PacketBuffer.
+PacketBuffer XskPacket::clonePacketBuffer() const
+{
+ const auto size = dataLen();
+ PacketBuffer tmp(size);
+ memcpy(tmp.data(), payload, size);
+ return tmp;
+}
+// Copy the L4 payload into an existing buffer, resizing it to fit.
+void XskPacket::cloneIntoPacketBuffer(PacketBuffer& buffer) const
+{
+ const auto size = dataLen();
+ buffer.resize(size);
+ memcpy(buffer.data(), payload, size);
+}
+// Replace the L4 payload with `buf` and mark the packet as updated.
+// Fails (returns false) when the buffer is empty or exceeds the remaining
+// frame capacity; the packet is left unmodified in that case.
+bool XskPacket::setPayload(const PacketBuffer& buf)
+{
+ const auto bufSize = buf.size();
+ if (bufSize == 0 || bufSize > capacity()) {
+ return false;
+ }
+ flags |= UPDATE;
+ memcpy(payload, buf.data(), bufSize);
+ payloadEnd = payload + bufSize;
+ return true;
+}
+// Schedule the packet to be sent `relativeMilliseconds` from now, keeping
+// tv_nsec normalized into [0, 1e9).
+void XskPacket::addDelay(const int relativeMilliseconds) noexcept
+{
+ gettime(&sendTime);
+ sendTime.tv_nsec += static_cast<uint64_t>(relativeMilliseconds) * 1000000L;
+ sendTime.tv_sec += sendTime.tv_nsec / 1000000000L;
+ sendTime.tv_nsec %= 1000000000L;
+}
+// Ordering used by XskSocket::waitForDelay (a priority_queue of packets
+// scheduled for delayed sending), comparing by scheduled send time.
+bool operator<(const XskPacketPtr& s1, const XskPacketPtr& s2) noexcept
+{
+ return s1->sendTime < s2->sendTime;
+}
+// Source address (with port) as parsed from the frame.
+const ComboAddress& XskPacket::getFromAddr() const noexcept
+{
+ return from;
+}
+// Destination address (with port) as parsed from the frame.
+const ComboAddress& XskPacket::getToAddr() const noexcept
+{
+ return to;
+}
+// Wake up the peer by writing the mandatory 8-byte value to an eventfd.
+// Throws if the write ultimately fails or is short.
+void XskWorker::notify(int fd)
+{
+ uint64_t value = 1;
+ ssize_t res = 0;
+ // write() reports an interrupted call with a -1 return and errno == EINTR,
+ // not by returning EINTR itself — retry on that condition only
+ do {
+ res = write(fd, &value, sizeof(value));
+ } while (res == -1 && errno == EINTR);
+ if (res != sizeof(value)) {
+ throw runtime_error("Unable Wake Up XskSocket Failed");
+ }
+}
+// Create the two eventfds used to wake the worker and the XskSocket thread.
+XskWorker::XskWorker() :
+ workerWaker(createEventfd()), xskSocketWaker(createEventfd())
+{
+}
+// Mutable pointer to the start of the L4 payload inside the umem frame.
+void* XskPacket::payloadData()
+{
+ return reinterpret_cast<void*>(payload);
+}
+// Read-only pointer to the start of the L4 payload inside the umem frame.
+const void* XskPacket::payloadData() const
+{
+ return reinterpret_cast<const void*>(payload);
+}
+// Initialize a packet that is being built from scratch: set MAC addresses
+// and endpoint addresses, and position l4Header/payload for the chosen L3/L4
+// combination. Does not write the IP/L4 headers themselves — rewrite() does.
+void XskPacket::setAddr(const ComboAddress& from_, MACAddr fromMAC, const ComboAddress& to_, MACAddr toMAC, bool tcp) noexcept
+{
+ auto* eth = reinterpret_cast<ethhdr*>(frame);
+ memcpy(eth->h_dest, &toMAC[0], sizeof(MACAddr));
+ memcpy(eth->h_source, &fromMAC[0], sizeof(MACAddr));
+ to = to_;
+ from = from_;
+ // L4 header follows the Ethernet header plus the fixed IPv4/IPv6 header
+ l4Header = frame + sizeof(ethhdr) + (to.isIPv4() ? sizeof(iphdr) : sizeof(ipv6hdr));
+ if (tcp) {
+ flags = TCP;
+ payload = l4Header + sizeof(tcphdr);
+ }
+ else {
+ flags = 0;
+ payload = l4Header + sizeof(udphdr);
+ }
+}
+// Write the full L3/L4 headers for a packet previously prepared with
+// setAddr()/setPayload(), using the stored from/to addresses, and compute
+// the UDP checksums. TCP header completion is still a TODO.
+void XskPacket::rewrite() noexcept
+{
+ flags |= REWRITE;
+ auto* eth = reinterpret_cast<ethhdr*>(frame);
+ if (to.isIPv4()) {
+ eth->h_proto = htons(ETH_P_IP);
+ auto* ipv4 = reinterpret_cast<iphdr*>(eth + 1);
+
+ ipv4->daddr = to.sin4.sin_addr.s_addr;
+ ipv4->saddr = from.sin4.sin_addr.s_addr;
+ if (flags & XskPacket::TCP) {
+ auto* tcp = reinterpret_cast<tcphdr*>(ipv4 + 1);
+ ipv4->protocol = IPPROTO_TCP;
+ // ports in ComboAddress are already in network byte order
+ tcp->source = from.sin4.sin_port;
+ tcp->dest = to.sin4.sin_port;
+ // TODO
+ }
+ else {
+ auto* udp = reinterpret_cast<udphdr*>(ipv4 + 1);
+ ipv4->protocol = IPPROTO_UDP;
+ udp->source = from.sin4.sin_port;
+ udp->dest = to.sin4.sin_port;
+ udp->len = htons(payloadEnd - reinterpret_cast<uint8_t*>(udp));
+ // checksum field must be zero while the checksum is computed
+ udp->check = 0;
+ udp->check = tcp_udp_v4_checksum();
+ }
+ rewriteIpv4Header(ipv4);
+ }
+ else {
+ // NOTE(review): the IPv6 branch never sets eth->h_proto to ETH_P_IPV6 —
+ // presumably the caller guarantees it; confirm against call sites
+ auto* ipv6 = reinterpret_cast<ipv6hdr*>(eth + 1);
+ memcpy(&ipv6->daddr, &to.sin6.sin6_addr, sizeof(ipv6->daddr));
+ memcpy(&ipv6->saddr, &from.sin6.sin6_addr, sizeof(ipv6->saddr));
+ if (flags & XskPacket::TCP) {
+ auto* tcp = reinterpret_cast<tcphdr*>(ipv6 + 1);
+ ipv6->nexthdr = IPPROTO_TCP;
+ tcp->source = from.sin6.sin6_port;
+ tcp->dest = to.sin6.sin6_port;
+ // TODO
+ }
+ else {
+ auto* udp = reinterpret_cast<udphdr*>(ipv6 + 1);
+ ipv6->nexthdr = IPPROTO_UDP;
+ udp->source = from.sin6.sin6_port;
+ udp->dest = to.sin6.sin6_port;
+ udp->len = htons(payloadEnd - reinterpret_cast<uint8_t*>(udp));
+ // checksum field must be zero while the checksum is computed
+ udp->check = 0;
+ udp->check = tcp_udp_v6_checksum();
+ }
+ }
+}
+
+// IPv4 header checksum over the fixed 20-byte header (no options).
+// The caller must zero ip->check before calling.
+[[nodiscard]] __be16 XskPacket::ipv4Checksum() const noexcept
+{
+ auto* ip = reinterpret_cast<iphdr*>(frame + sizeof(ethhdr));
+ return ip_checksum_fold(ip_checksum_partial(ip, sizeof(iphdr), 0));
+}
+// TCP/UDP checksum for IPv4: pseudo-header plus the L4 header and payload.
+// The caller must zero the L4 checksum field before calling.
+[[nodiscard]] __be16 XskPacket::tcp_udp_v4_checksum() const noexcept
+{
+ const auto* ip = reinterpret_cast<iphdr*>(frame + sizeof(ethhdr));
+ // ip options is not supported !!!
+ const auto l4Length = static_cast<uint16_t>(payloadEnd - l4Header);
+ auto sum = tcp_udp_v4_header_checksum_partial(ip->saddr, ip->daddr, ip->protocol, l4Length);
+ sum = ip_checksum_partial(l4Header, l4Length, sum);
+ return ip_checksum_fold(sum);
+}
+// TCP/UDP checksum for IPv6: pseudo-header plus the L4 header and payload.
+// The caller must zero the L4 checksum field before calling.
+[[nodiscard]] __be16 XskPacket::tcp_udp_v6_checksum() const noexcept
+{
+ const auto* ipv6 = reinterpret_cast<ipv6hdr*>(frame + sizeof(ethhdr));
+ const auto l4Length = static_cast<uint16_t>(payloadEnd - l4Header);
+ uint64_t sum = tcp_udp_v6_header_checksum_partial(&ipv6->saddr, &ipv6->daddr, ipv6->nexthdr, l4Length);
+ sum = ip_checksum_partial(l4Header, l4Length, sum);
+ return ip_checksum_fold(sum);
+}
+
+#ifndef __packed
+#define __packed __attribute__((packed))
+#endif
+// Accumulate the Internet checksum (RFC 1071) over `len` bytes into `sum`
+// without folding; combine multiple regions by chaining calls.
+[[nodiscard]] uint64_t XskPacket::ip_checksum_partial(const void* p, size_t len, uint64_t sum) noexcept
+{
+ /* Main loop: 32 bits at a time.
+ * We take advantage of intel's ability to do unaligned memory
+ * accesses with minimal additional cost. Other architectures
+ * probably want to be more careful here.
+ */
+ const uint32_t* p32 = (const uint32_t*)(p);
+ for (; len >= sizeof(*p32); len -= sizeof(*p32))
+ sum += *p32++;
+
+ /* Handle un-32bit-aligned trailing bytes */
+ const uint16_t* p16 = (const uint16_t*)(p32);
+ if (len >= 2) {
+ sum += *p16++;
+ len -= sizeof(*p16);
+ }
+ if (len > 0) {
+ const uint8_t* p8 = (const uint8_t*)(p16);
+ sum += ntohs(*p8 << 8); /* RFC says pad last byte */
+ }
+
+ return sum;
+}
+// Fold a wide checksum accumulator down to 16 bits and return its one's
+// complement, ready to be stored in a header checksum field.
+[[nodiscard]] __be16 XskPacket::ip_checksum_fold(uint64_t sum) noexcept
+{
+ while (sum & ~0xffffffffULL)
+ sum = (sum >> 32) + (sum & 0xffffffffULL);
+ while (sum & 0xffff0000ULL)
+ sum = (sum >> 16) + (sum & 0xffffULL);
+
+ return ~sum;
+}
+// Partial checksum of the IPv4 TCP/UDP pseudo-header; chain the L4 bytes
+// with ip_checksum_partial() and finish with ip_checksum_fold().
+[[nodiscard]] uint64_t XskPacket::tcp_udp_v4_header_checksum_partial(__be32 src_ip, __be32 dst_ip, uint8_t protocol, uint16_t len) noexcept
+{
+ struct header
+ {
+ __be32 src_ip;
+ __be32 dst_ip;
+ __uint8_t mbz;
+ __uint8_t protocol;
+ __be16 length;
+ };
+ /* The IPv4 pseudo-header is defined in RFC 793, Section 3.1. */
+ struct ipv4_pseudo_header_t
+ {
+ /* We use a union here to avoid aliasing issues with gcc -O2 */
+ union
+ {
+ header __packed fields;
+ uint32_t words[3];
+ };
+ };
+ struct ipv4_pseudo_header_t pseudo_header;
+ assert(sizeof(pseudo_header) == 12);
+
+ /* Fill in the pseudo-header. */
+ pseudo_header.fields.src_ip = src_ip;
+ pseudo_header.fields.dst_ip = dst_ip;
+ pseudo_header.fields.mbz = 0;
+ pseudo_header.fields.protocol = protocol;
+ pseudo_header.fields.length = htons(len);
+ return ip_checksum_partial(&pseudo_header, sizeof(pseudo_header), 0);
+}
+// Partial checksum of the IPv6 TCP/UDP pseudo-header (RFC 2460 §8.1);
+// chain the L4 bytes with ip_checksum_partial() and fold at the end.
+[[nodiscard]] uint64_t XskPacket::tcp_udp_v6_header_checksum_partial(const struct in6_addr* src_ip, const struct in6_addr* dst_ip, uint8_t protocol, uint32_t len) noexcept
+{
+ struct header
+ {
+ struct in6_addr src_ip;
+ struct in6_addr dst_ip;
+ __be32 length;
+ __uint8_t mbz[3];
+ __uint8_t next_header;
+ };
+ /* The IPv6 pseudo-header is defined in RFC 2460, Section 8.1. */
+ struct ipv6_pseudo_header_t
+ {
+ /* We use a union here to avoid aliasing issues with gcc -O2 */
+ union
+ {
+ header __packed fields;
+ uint32_t words[10];
+ };
+ };
+ struct ipv6_pseudo_header_t pseudo_header;
+ assert(sizeof(pseudo_header) == 40);
+
+ /* Fill in the pseudo-header. */
+ pseudo_header.fields.src_ip = *src_ip;
+ pseudo_header.fields.dst_ip = *dst_ip;
+ pseudo_header.fields.length = htonl(len);
+ memset(pseudo_header.fields.mbz, 0, sizeof(pseudo_header.fields.mbz));
+ pseudo_header.fields.next_header = protocol;
+ return ip_checksum_partial(&pseudo_header, sizeof(pseudo_header), 0);
+}
+// Copy a previously captured header blob (see cloneHeadertoPacketBuffer)
+// back over the start of the frame and re-parse it to refresh the pointers.
+// NOTE(review): parse() failure is only caught by assert(), which vanishes
+// under NDEBUG and would leave the packet fields stale — worth hardening.
+void XskPacket::setHeader(const PacketBuffer& buf) noexcept
+{
+ memcpy(frame, buf.data(), buf.size());
+ payloadEnd = frame + buf.size();
+ flags = 0;
+ const auto res = parse();
+ assert(res);
+}
+// Copy everything before the L4 payload (Ethernet + IP + L4 headers) into a
+// new buffer, so the headers can be restored later with setHeader().
+std::unique_ptr<PacketBuffer> XskPacket::cloneHeadertoPacketBuffer() const
+{
+ const auto size = payload - frame;
+ auto tmp = std::make_unique<PacketBuffer>(size);
+ memcpy(tmp->data(), frame, size);
+ return tmp;
+}
+// Create a close-on-exec eventfd used for cross-thread wakeups; throws on
+// failure.
+int XskWorker::createEventfd()
+{
+ auto fd = ::eventfd(0, EFD_CLOEXEC);
+ if (fd < 0) {
+ throw runtime_error("Unable create eventfd");
+ }
+ return fd;
+}
+// Block until the XskSocket thread signals us, draining the eventfd counter.
+// The read result is deliberately ignored (best effort wakeup consumption).
+void XskWorker::waitForXskSocket() noexcept
+{
+ uint64_t x = read(workerWaker, &x, sizeof(x));
+}
+// Wake up the XskSocket thread servicing this worker.
+void XskWorker::notifyXskSocket() noexcept
+{
+ notify(xskSocketWaker);
+}
+
+// Factory helper; workers are always shared between threads.
+std::shared_ptr<XskWorker> XskWorker::create()
+{
+ return std::make_shared<XskWorker>();
+}
+// Register a worker that will receive the packets addressed to `dest`.
+// Only allowed during startup: the routing table and poll list are not
+// protected against concurrent modification once dnsdist is running.
+void XskSocket::addWorker(std::shared_ptr<XskWorker> s, const ComboAddress& dest, bool isTCP)
+{
+ extern std::atomic<bool> g_configurationDone;
+ if (g_configurationDone) {
+ throw runtime_error("Adding a server with xsk at runtime is not supported");
+ }
+ s->poolName = poolName;
+ const auto socketWaker = s->xskSocketWaker.getHandle();
+ const auto workerWaker = s->workerWaker.getHandle();
+ const auto& socketWakerIdx = workers.get<0>();
+ if (socketWakerIdx.contains(socketWaker)) {
+ throw runtime_error("Server already exist");
+ }
+ // the worker needs the umem base to translate offsets back into pointers
+ s->umemBufBase = umem.bufBase;
+ workers.insert(XskRouteInfo{
+ .worker = std::move(s),
+ .dest = dest,
+ .xskSocketWaker = socketWaker,
+ .workerWaker = workerWaker,
+ });
+ // poll the worker's waker alongside the XSK fd in the router loop
+ fds.push_back(pollfd{
+ .fd = socketWaker,
+ .events = POLLIN,
+ .revents = 0});
+};
+// Offset of a packet's frame inside the shared umem region.
+uint64_t XskWorker::frameOffset(const XskPacket& s) const noexcept
+{
+ return s.frame - umemBufBase;
+}
+// Wake up the worker thread owning this XskWorker.
+void XskWorker::notifyWorker() noexcept
+{
+ notify(workerWaker);
+}
+// Look up the MAC address of `ifName` via the SIOCGIFHWADDR ioctl and store
+// it in `source`. Throws runtime_error on any failure; never leaks the
+// temporary socket.
+void XskSocket::getMACFromIfName()
+{
+ // zero-initialize: the ioctl only fills in what it knows about
+ ifreq ifr{};
+ auto fd = ::socket(AF_INET, SOCK_DGRAM, 0);
+ if (fd < 0) {
+ throw runtime_error("Error get MAC addr");
+ }
+ // ifr_name is a fixed-size buffer (IFNAMSIZ); reject names that do not
+ // fit instead of writing past the end as the old length+1 copy could
+ if (ifName.size() >= sizeof(ifr.ifr_name)) {
+ close(fd);
+ throw runtime_error("Error get MAC addr");
+ }
+ strncpy(ifr.ifr_name, ifName.c_str(), sizeof(ifr.ifr_name));
+ if (ioctl(fd, SIOCGIFHWADDR, &ifr) < 0) {
+ // close the socket on the error path too, it was leaked before
+ close(fd);
+ throw runtime_error("Error get MAC addr");
+ }
+ memcpy(source, ifr.ifr_hwaddr.sa_data, sizeof(source));
+ close(fd);
+}
+// Difference t1 - t2 in milliseconds, truncated to int.
+// NOTE(review): tv_sec * 1000 can overflow for absolute CLOCK times far in
+// the future on 32-bit time_t — acceptable here since only deltas of nearby
+// timestamps are compared; confirm if reused elsewhere.
+[[nodiscard]] int XskSocket::timeDifference(const timespec& t1, const timespec& t2) noexcept
+{
+ const auto res = t1.tv_sec * 1000 + t1.tv_nsec / 1000000L - (t2.tv_sec * 1000 + t2.tv_nsec / 1000000L);
+ return static_cast<int>(res);
+}
+// Drain the eventfd the worker thread signals us on (result ignored).
+void XskWorker::cleanWorkerNotification() noexcept
+{
+ uint64_t x = read(xskSocketWaker, &x, sizeof(x));
+}
+// Drain the eventfd the XskSocket thread signals us on (result ignored).
+void XskWorker::cleanSocketNotification() noexcept
+{
+ uint64_t x = read(workerWaker, &x, sizeof(x));
+}
+// Build the pollfd set a worker thread waits on: its wakeup eventfd plus a
+// freshly created monotonic timerfd. Ownership of the timerfd passes to the
+// caller inside the returned vector.
+std::vector<pollfd> getPollFdsForWorker(XskWorker& info)
+{
+ std::vector<pollfd> fds;
+ int timerfd = timerfd_create(CLOCK_MONOTONIC, TFD_CLOEXEC);
+ if (timerfd < 0) {
+ throw std::runtime_error("create_timerfd failed");
+ }
+ fds.push_back(pollfd{
+ .fd = info.workerWaker,
+ .events = POLLIN,
+ .revents = 0,
+ });
+ fds.push_back(pollfd{
+ .fd = timerfd,
+ .events = POLLIN,
+ .revents = 0,
+ });
+ return fds;
+}
+// Transfer up to 32 free umem frame offsets from the shared pool into this
+// worker's private list, removing them from the shared pool so no frame is
+// ever owned twice.
+void XskWorker::fillUniqueEmptyOffset()
+{
+ auto frames = sharedEmptyFrameOffset->lock();
+ const auto moveSize = std::min(static_cast<size_t>(32), frames->size());
+ if (moveSize > 0) {
+ uniqueEmptyFrameOffset.insert(uniqueEmptyFrameOffset.end(), std::make_move_iterator(frames->end() - moveSize), std::make_move_iterator(frames->end()));
+ // erase the transferred offsets from the shared vector; without this the
+ // same frames were handed out again on the next call, so two packets
+ // could end up writing into the same umem frame
+ frames->resize(frames->size() - moveSize);
+ }
+}
+// Hand out a pointer to a free umem frame, refilling the private free list
+// from the shared pool once if needed. Returns nullptr when no frame is
+// available anywhere.
+void* XskWorker::getEmptyframe()
+{
+ for (int attempt = 0; attempt < 2; ++attempt) {
+ if (!uniqueEmptyFrameOffset.empty()) {
+ const auto frameOff = uniqueEmptyFrameOffset.back();
+ uniqueEmptyFrameOffset.pop_back();
+ return frameOff + umemBufBase;
+ }
+ // private list exhausted: pull a batch from the shared pool and retry
+ if (attempt == 0) {
+ fillUniqueEmptyOffset();
+ }
+ }
+ return nullptr;
+}
+// Current XskPacket::Flags bitmask (TCP/UPDATE/DELAY/REWRITE).
+uint32_t XskPacket::getFlags() const noexcept
+{
+ return flags;
+}
+// Finalize an in-place response: when the payload was updated but the
+// headers were not rewritten from scratch, flip the packet's direction and
+// refresh length/checksum fields. No-op for untouched packets.
+void XskPacket::updatePacket() noexcept
+{
+ if (!(flags & UPDATE)) {
+ return;
+ }
+ if (!(flags & REWRITE)) {
+ changeDirectAndUpdateChecksum();
+ }
+}
+#endif /* HAVE_XSK */
--- /dev/null
+/*
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+
+#pragma once
+#include "iputils.hh"
+#include "misc.hh"
+#include "noinitvector.hh"
+#include "lock.hh"
+
+#include <array>
+#include <bits/types/struct_timespec.h>
+#include <boost/lockfree/spsc_queue.hpp>
+#include <boost/multi_index/hashed_index.hpp>
+#include <boost/multi_index/indexed_by.hpp>
+#include <boost/multi_index_container.hpp>
+#include <boost/multi_index/member.hpp>
+#include <cstdint>
+#include <cstring>
+#include <linux/types.h>
+#include <memory>
+#include <queue>
+#include <stdexcept>
+#include <string>
+#include <sys/poll.h>
+#include <sys/types.h>
+#include <unistd.h>
+#include <unordered_map>
+#include <vector>
+
+#ifdef HAVE_XSK
+#include <xdp/xsk.h>
+#endif /* HAVE_XSK */
+
+class XskPacket;
+class XskWorker;
+class XskSocket;
+
+#ifdef HAVE_XSK
+using XskPacketPtr = std::unique_ptr<XskPacket>;
+
+// We use an XskSocket to manage an AF_XDP Socket corresponding to a NIC queue.
+// The XDP program running in the kernel redirects the data to the XskSocket in userspace.
+// XskSocket routes packets to multiple worker threads registered on XskSocket via XskSocket::addWorker based on the destination port number of the packet.
+// The kernel and the worker thread holding XskWorker will wake up the XskSocket through XskFd and the Eventfd corresponding to each worker thread, respectively.
+class XskSocket
+{
+ // Routing entry tying a destination address to a worker and its wakers.
+ struct XskRouteInfo
+ {
+ std::shared_ptr<XskWorker> worker;
+ ComboAddress dest;
+ int xskSocketWaker;
+ int workerWaker;
+ };
+ // RAII owner of the umem region shared between kernel and userspace.
+ struct XskUmem
+ {
+ xsk_umem* umem{nullptr};
+ uint8_t* bufBase{nullptr};
+ size_t size;
+ void umemInit(size_t memSize, xsk_ring_cons* cq, xsk_ring_prod* fq, xsk_umem_config* config);
+ ~XskUmem();
+ XskUmem() = default;
+ };
+ // Workers indexed both by their waker fd and by destination address.
+ boost::multi_index_container<
+ XskRouteInfo,
+ boost::multi_index::indexed_by<
+ boost::multi_index::hashed_unique<boost::multi_index::member<XskRouteInfo, int, &XskRouteInfo::xskSocketWaker>>,
+ boost::multi_index::hashed_unique<boost::multi_index::member<XskRouteInfo, ComboAddress, &XskRouteInfo::dest>, ComboAddress::addressPortOnlyHash>>>
+ workers;
+ static constexpr size_t holdThreshold = 256;
+ static constexpr size_t fillThreshold = 128;
+ static constexpr size_t frameSize = 2048;
+ const size_t frameNum;
+ const uint32_t queueId;
+ // Packets waiting for their scheduled send time (see XskPacket::addDelay).
+ std::priority_queue<XskPacketPtr> waitForDelay;
+ const std::string ifName;
+ const std::string poolName;
+ // pollfd list: the XSK fd plus one waker fd per registered worker.
+ vector<pollfd> fds;
+ // Frame offsets currently free and owned by this socket.
+ vector<uint64_t> uniqueEmptyFrameOffset;
+ // AF_XDP rings: completion + rx (consumer side), fill + tx (producer side).
+ xsk_ring_cons cq;
+ xsk_ring_cons rx;
+ xsk_ring_prod fq;
+ xsk_ring_prod tx;
+ std::unique_ptr<xsk_socket, void (*)(xsk_socket*)> socket;
+ XskUmem umem;
+ bpf_object* prog;
+
+ static constexpr uint32_t fqCapacity = XSK_RING_PROD__DEFAULT_NUM_DESCS * 4;
+ static constexpr uint32_t cqCapacity = XSK_RING_CONS__DEFAULT_NUM_DESCS * 4;
+ static constexpr uint32_t rxCapacity = XSK_RING_CONS__DEFAULT_NUM_DESCS * 2;
+ static constexpr uint32_t txCapacity = XSK_RING_PROD__DEFAULT_NUM_DESCS * 2;
+
+ constexpr static bool isPowOfTwo(uint32_t value) noexcept;
+ [[nodiscard]] static int timeDifference(const timespec& t1, const timespec& t2) noexcept;
+ friend void XskRouter(std::shared_ptr<XskSocket> xsk);
+
+ [[nodiscard]] uint64_t frameOffset(const XskPacket& packet) const noexcept;
+ int firstTimeout();
+ void fillFq(uint32_t fillSize = fillThreshold) noexcept;
+ void recycle(size_t size) noexcept;
+ void getMACFromIfName();
+ void pickUpReadyPacket(std::vector<XskPacketPtr>& packets);
+
+public:
+ // Free frame offsets shared with the workers (guarded by a mutex).
+ std::shared_ptr<LockGuarded<vector<uint64_t>>> sharedEmptyFrameOffset;
+ XskSocket(size_t frameNum, const std::string& ifName, uint32_t queue_id, const std::string& xskMapPath, const std::string& poolName_);
+ // MAC address of the bound interface, filled by getMACFromIfName().
+ MACAddr source;
+ [[nodiscard]] int xskFd() const noexcept;
+ int wait(int timeout);
+ void send(std::vector<XskPacketPtr>& packets);
+ std::vector<XskPacketPtr> recv(uint32_t recvSizeMax, uint32_t* failedCount);
+ void addWorker(std::shared_ptr<XskWorker> s, const ComboAddress& dest, bool isTCP);
+};
+// A single packet living inside an umem frame; parses and rewrites the
+// Ethernet/IP/UDP(TCP) headers in place.
+class XskPacket
+{
+public:
+ enum Flags : uint32_t
+ {
+ TCP = 1 << 0,
+ UPDATE = 1 << 1,
+ DELAY = 1 << 3,
+ REWRITE = 1 << 4
+ };
+
+private:
+ ComboAddress from;
+ ComboAddress to;
+ // When the packet should be sent (only meaningful with DELAY).
+ timespec sendTime;
+ // frame <= l4Header <= payload <= payloadEnd <= frameEnd, all into the umem.
+ uint8_t* frame;
+ uint8_t* l4Header;
+ uint8_t* payload;
+ uint8_t* payloadEnd;
+ uint8_t* frameEnd;
+ uint32_t flags{0};
+
+ friend XskSocket;
+ friend XskWorker;
+ friend bool operator<(const XskPacketPtr& s1, const XskPacketPtr& s2) noexcept;
+
+ constexpr static uint8_t DefaultTTL = 64;
+ bool parse();
+ void changeDirectAndUpdateChecksum() noexcept;
+
+ // You must set ipHeader.check = 0 before call this method
+ [[nodiscard]] __be16 ipv4Checksum() const noexcept;
+ // You must set l4Header.check = 0 before call this method
+ // ip options is not supported
+ [[nodiscard]] __be16 tcp_udp_v4_checksum() const noexcept;
+ // You must set l4Header.check = 0 before call this method
+ [[nodiscard]] __be16 tcp_udp_v6_checksum() const noexcept;
+ [[nodiscard]] static uint64_t ip_checksum_partial(const void* p, size_t len, uint64_t sum) noexcept;
+ [[nodiscard]] static __be16 ip_checksum_fold(uint64_t sum) noexcept;
+ [[nodiscard]] static uint64_t tcp_udp_v4_header_checksum_partial(__be32 src_ip, __be32 dst_ip, uint8_t protocol, uint16_t len) noexcept;
+ [[nodiscard]] static uint64_t tcp_udp_v6_header_checksum_partial(const struct in6_addr* src_ip, const struct in6_addr* dst_ip, uint8_t protocol, uint32_t len) noexcept;
+ void rewriteIpv4Header(void* ipv4header) noexcept;
+ void rewriteIpv6Header(void* ipv6header) noexcept;
+
+public:
+ [[nodiscard]] const ComboAddress& getFromAddr() const noexcept;
+ [[nodiscard]] const ComboAddress& getToAddr() const noexcept;
+ [[nodiscard]] const void* payloadData() const;
+ [[nodiscard]] bool isIPV6() const noexcept;
+ [[nodiscard]] size_t capacity() const noexcept;
+ [[nodiscard]] uint32_t dataLen() const noexcept;
+ [[nodiscard]] uint32_t FrameLen() const noexcept;
+ [[nodiscard]] PacketBuffer clonePacketBuffer() const;
+ void cloneIntoPacketBuffer(PacketBuffer& buffer) const;
+ [[nodiscard]] std::unique_ptr<PacketBuffer> cloneHeadertoPacketBuffer() const;
+ [[nodiscard]] void* payloadData();
+ void setAddr(const ComboAddress& from_, MACAddr fromMAC, const ComboAddress& to_, MACAddr toMAC, bool tcp = false) noexcept;
+ bool setPayload(const PacketBuffer& buf);
+ void rewrite() noexcept;
+ void setHeader(const PacketBuffer& buf) noexcept;
+ XskPacket() = default;
+ XskPacket(void* frame, size_t dataSize, size_t frameSize);
+ void addDelay(int relativeMilliseconds) noexcept;
+ void updatePacket() noexcept;
+ [[nodiscard]] uint32_t getFlags() const noexcept;
+};
+bool operator<(const XskPacketPtr& s1, const XskPacketPtr& s2) noexcept;
+
+// XskWorker obtains XskPackets of specific ports in the NIC from XskSocket through cq.
+// After finishing processing the packet, XskWorker puts the packet into sq so that XskSocket decides whether to send it through the network card according to XskPacket::flags.
+// XskWorker wakes up XskSocket via xskSocketWaker after putting the packets in sq.
+class XskWorker
+{
+ // Single-producer/single-consumer queues shared with the XskSocket thread.
+ using XskPacketRing = boost::lockfree::spsc_queue<XskPacket*, boost::lockfree::capacity<512>>;
+
+public:
+ // Base of the umem region, needed to turn frame offsets into pointers.
+ uint8_t* umemBufBase;
+ // Free frame offsets shared with the owning XskSocket.
+ std::shared_ptr<LockGuarded<vector<uint64_t>>> sharedEmptyFrameOffset;
+ // Free frame offsets owned exclusively by this worker.
+ vector<uint64_t> uniqueEmptyFrameOffset;
+ // cq: packets arriving from the NIC; sq: packets to be sent out.
+ XskPacketRing cq;
+ XskPacketRing sq;
+ std::string poolName;
+ size_t frameSize;
+ // Eventfds used to wake the worker thread / the XskSocket thread.
+ FDWrapper workerWaker;
+ FDWrapper xskSocketWaker;
+
+ XskWorker();
+ static int createEventfd();
+ static void notify(int fd);
+ static std::shared_ptr<XskWorker> create();
+ void notifyWorker() noexcept;
+ void notifyXskSocket() noexcept;
+ void waitForXskSocket() noexcept;
+ void cleanWorkerNotification() noexcept;
+ void cleanSocketNotification() noexcept;
+ [[nodiscard]] uint64_t frameOffset(const XskPacket& s) const noexcept;
+ void fillUniqueEmptyOffset();
+ void* getEmptyframe();
+};
+std::vector<pollfd> getPollFdsForWorker(XskWorker& info);
+#else
+// Empty stand-ins so code can name these types when dnsdist is built
+// without AF_XDP (HAVE_XSK undefined).
+class XskSocket
+{
+};
+class XskPacket
+{
+};
+class XskWorker
+{
+};
+
+#endif /* HAVE_XSK */