]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: add AF_XDP support for udp
authorY7n05h <Y7n05h@protonmail.com>
Wed, 17 Aug 2022 14:18:11 +0000 (22:18 +0800)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Tue, 23 Jan 2024 11:54:09 +0000 (12:54 +0100)
Signed-off-by: Y7n05h <Y7n05h@protonmail.com>
23 files changed:
contrib/xdp-filter.ebpf.src
contrib/xdp.py
ext/libbpf/libbpf.h
pdns/bpf-filter.cc
pdns/dnsdist-idstate.hh
pdns/dnsdist-lua-bindings.cc
pdns/dnsdist-lua.cc
pdns/dnsdist-tcp.cc
pdns/dnsdist.cc
pdns/dnsdist.hh
pdns/dnsdistdist/Makefile.am
pdns/dnsdistdist/configure.ac
pdns/dnsdistdist/dnsdist-backend.cc
pdns/dnsdistdist/dnsdist-healthchecks.cc
pdns/dnsdistdist/dnsdist-healthchecks.hh
pdns/dnsdistdist/m4/pdns_with_xsk.m4 [new file with mode: 0644]
pdns/dnsdistdist/test-dnsdisttcp_cc.cc
pdns/dnsdistdist/xsk.cc [new symlink]
pdns/dnsdistdist/xsk.hh [new symlink]
pdns/iputils.hh
pdns/test-dnsdist_cc.cc
pdns/xsk.cc [new file with mode: 0644]
pdns/xsk.hh [new file with mode: 0644]

index 3ead6e151b96d1bc8109f5bfb8ff096eb579bce7..8577b08c08e14fe61f58c0eacf20c9fdcc273b3f 100644 (file)
@@ -17,6 +17,24 @@ BPF_TABLE_PINNED("prog", int, int, progsarray, 2, "/sys/fs/bpf/dnsdist/progs");
 BPF_TABLE_PINNED7("lpm_trie", struct CIDR4, struct map_value, cidr4filter, 1024, "/sys/fs/bpf/dnsdist/cidr4", BPF_F_NO_PREALLOC);
 BPF_TABLE_PINNED7("lpm_trie", struct CIDR6, struct map_value, cidr6filter, 1024, "/sys/fs/bpf/dnsdist/cidr6", BPF_F_NO_PREALLOC);
 
+#ifdef UseXsk
+#define BPF_XSKMAP_PIN(_name, _max_entries, _pinned) \
+  struct _name##_table_t                             \
+  {                                                  \
+    u32 key;                                         \
+    int leaf;                                        \
+    int* (*lookup)(int*);                            \
+    /* xdp_act = map.redirect_map(index, flag) */    \
+    u64 (*redirect_map)(int, int);                   \
+    u32 max_entries;                                 \
+  };                                                 \
+  __attribute__((section("maps/xskmap:" _pinned))) struct _name##_table_t _name = {.max_entries = (_max_entries)}
+
+BPF_XSKMAP_PIN(xsk_map, 16, "/sys/fs/bpf/dnsdist/xskmap");
+#endif /* UseXsk */
+
+#define COMPARE_PORT(x, p) ((x) == bpf_htons(p))
+
 /*
  * Recalculate the checksum
  * Copyright 2020, NLnet Labs, All rights reserved.
@@ -39,7 +57,7 @@ static inline void update_checksum(uint16_t *csum, uint16_t old_val, uint16_t ne
  * Set the TC bit and swap UDP ports
  * Copyright 2020, NLnet Labs, All rights reserved.
  */
-static inline enum dns_action set_tc_bit(struct udphdr *udp, struct dnshdr *dns)
+static inline void set_tc_bit(struct udphdr* udp, struct dnshdr* dns)
 {
   uint16_t old_val = dns->flags.as_value;
 
@@ -49,14 +67,12 @@ static inline enum dns_action set_tc_bit(struct udphdr *udp, struct dnshdr *dns)
   dns->flags.as_bits_and_pieces.tc = 1;
 
   // change the UDP destination to the source
-  udp->dest   = udp->source;
-  udp->source = bpf_htons(DNS_PORT);
+  uint16_t tmp = udp->dest;
+  udp->dest = udp->source;
+  udp->source = tmp;
 
   // calculate and write the new checksum
   update_checksum(&udp->check, old_val, dns->flags.as_value);
-
-  // bounce
-  return TC;
 }
 
 /*
@@ -65,41 +81,43 @@ static inline enum dns_action set_tc_bit(struct udphdr *udp, struct dnshdr *dns)
  *         TC if (modified) message needs to be replied
  *         DROP if message needs to be blocke
  */
-static inline enum dns_action check_qname(struct cursor *c)
+static inline struct map_value* check_qname(struct cursor* c)
 {
   struct dns_qname qkey = {0};
   uint8_t qname_byte;
   uint16_t qtype;
   int length = 0;
 
-  for(int i = 0; i<255; i++) {
-       if (bpf_probe_read_kernel(&qname_byte, sizeof(qname_byte), c->pos)) {
-               return PASS;
-       }
-       c->pos += 1;
-       if (length == 0) {
-      if (qname_byte == 0 || qname_byte > 63 ) {
-                 break;
+  for (int i = 0; i < 255; i++) {
+    if (bpf_probe_read_kernel(&qname_byte, sizeof(qname_byte), c->pos)) {
+      return NULL;
+    }
+    c->pos += 1;
+    if (length == 0) {
+      if (qname_byte == 0 || qname_byte > 63) {
+        break;
       }
       length += qname_byte;
-       } else {
+    }
+    else {
       length--;
     }
-       if (qname_byte >= 'A' && qname_byte <= 'Z') {
-               qkey.qname[i] = qname_byte + ('a' - 'A');
-       } else {
-               qkey.qname[i] = qname_byte;
-       }
+    if (qname_byte >= 'A' && qname_byte <= 'Z') {
+      qkey.qname[i] = qname_byte + ('a' - 'A');
+    }
+    else {
+      qkey.qname[i] = qname_byte;
+    }
   }
 
   // if the last read qbyte is not 0 incorrect QName format), return PASS
   if (qname_byte != 0) {
-       return PASS;
+    return NULL;
   }
 
   // get QType
-  if(bpf_probe_read_kernel(&qtype, sizeof(qtype), c->pos)) {
-       return PASS;
+  if (bpf_probe_read_kernel(&qtype, sizeof(qtype), c->pos)) {
+    return NULL;
   }
 
   struct map_value* value;
@@ -108,125 +126,184 @@ static inline enum dns_action check_qname(struct cursor *c)
   qkey.qtype = bpf_htons(qtype);
   value = qnamefilter.lookup(&qkey);
   if (value) {
-    __sync_fetch_and_add(&value->counter, 1);
-       return value->action;
+    return value;
   }
 
   // check with Qtype 255 (*)
   qkey.qtype = 255;
 
-  value = qnamefilter.lookup(&qkey);
-  if (value) {
-    __sync_fetch_and_add(&value->counter, 1);
-       return value->action;
-  }
-
-  return PASS;
+  return qnamefilter.lookup(&qkey);
 }
 
 /*
  * Parse IPv4 DNS mesage.
- * Returns PASS if message needs to go through (i.e. pass)
- *         TC if (modified) message needs to be replied
- *         DROP if message needs to be blocked
+ * Returns XDP_PASS if message needs to go through (i.e. pass)
+ *         XDP_REDIRECT if message needs to be redirected (for AF_XDP, which needs to be translated to the caller into XDP_PASS outside of the AF_XDP)
+ *         XDP_TX if (modified) message needs to be replied
+ *         XDP_DROP if message needs to be blocked
  */
-static inline enum dns_action udp_dns_reply_v4(struct cursor *c, struct CIDR4 *key)
+static inline enum xdp_action parseIPV4(struct xdp_md* ctx, struct cursor* c)
 {
-  struct udphdr  *udp;
-  struct dnshdr  *dns;
+  struct iphdr* ipv4;
+  struct udphdr* udp = NULL;
+  struct dnshdr* dns = NULL;
+  if (!(ipv4 = parse_iphdr(c))) {
+    return XDP_PASS;
+  }
+  switch (ipv4->protocol) {
+  case IPPROTO_UDP: {
+    if (!(udp = parse_udphdr(c))) {
+      return XDP_PASS;
+    }
+    if (!IN_DNS_PORT_SET(udp->dest)) {
+      return XDP_PASS;
+    }
+    if (!(dns = parse_dnshdr(c))) {
+      return XDP_DROP;
+    }
+    break;
+  }
 
-  if (!(udp = parse_udphdr(c)) || udp->dest != bpf_htons(DNS_PORT)) {
-       return PASS;
+#ifdef UseXsk
+  case IPPROTO_TCP: {
+    struct tcphdr* tcp;
+    if (!(tcp = parse_tcphdr(c))) {
+      return XDP_PASS;
+    }
+    if (!IN_DNS_PORT_SET(tcp->dest)) {
+      return XDP_PASS;
+    }
+  }
+#endif /* UseXsk */
+
+  default:
+    return XDP_PASS;
   }
 
-  // check that we have a DNS packet
-  if (!(dns = parse_dnshdr(c))) {
-       return PASS;
-  }    
+  struct CIDR4 key;
+  key.addr = bpf_htonl(ipv4->saddr);
 
   // if the address is blocked, perform the corresponding action
-  struct map_value* value = v4filter.lookup(&key->addr);
+  struct map_value* value = v4filter.lookup(&key.addr);
 
   if (value) {
-    __sync_fetch_and_add(&value->counter, 1);
-    if (value->action == TC) {
-         return set_tc_bit(udp, dns);
-    } else {
-      return value->action;
-    }
+    goto res;
+  }
+
+  key.cidr = 32;
+  key.addr = bpf_htonl(key.addr);
+  value = cidr4filter.lookup(&key);
+  if (value) {
+    goto res;
   }
 
-  key->cidr = 32;
-  key->addr = bpf_htonl(key->addr);
-  value = cidr4filter.lookup(key);
+  if (dns) {
+    value = check_qname(c);
+  }
   if (value) {
+  res:
     __sync_fetch_and_add(&value->counter, 1);
-    if (value->action == TC) {
-      return set_tc_bit(udp, dns);
+    if (value->action == TC && udp && dns) {
+      set_tc_bit(udp, dns);
+      // swap src/dest IP addresses
+      uint32_t swap_ipv4 = ipv4->daddr;
+      ipv4->daddr = ipv4->saddr;
+      ipv4->saddr = swap_ipv4;
+
+      progsarray.call(ctx, 1);
+      return XDP_TX;
     }
-    else {
-      return value->action;
+
+    if (value->action == DROP) {
+      progsarray.call(ctx, 0);
+      return XDP_DROP;
     }
   }
 
-  enum dns_action action = check_qname(c);
-  if (action == TC) {
-    return set_tc_bit(udp, dns);
-  }
-  return action;
+  return XDP_REDIRECT;
 }
 
 /*
  * Parse IPv6 DNS mesage.
- * Returns PASS if message needs to go through (i.e. pass)
- *         TC if (modified) message needs to be replied
- *         DROP if message needs to be blocked
+ * Returns XDP_PASS if message needs to go through (i.e. pass)
+ *         XDP_REDIRECT if message needs to be redirected (for AF_XDP, which needs to be translated to the caller into XDP_PASS outside of the AF_XDP)
+ *         XDP_TX if (modified) message needs to be replied
+ *         XDP_DROP if message needs to be blocked
  */
-static inline enum dns_action udp_dns_reply_v6(struct cursor *c, struct CIDR6* key)
+static inline enum xdp_action parseIPV6(struct xdp_md* ctx, struct cursor* c)
 {
-   struct udphdr  *udp;
-   struct dnshdr  *dns;
+  struct ipv6hdr* ipv6;
+  struct udphdr* udp = NULL;
+  struct dnshdr* dns = NULL;
+  if (!(ipv6 = parse_ipv6hdr(c))) {
+    return XDP_PASS;
+  }
+  switch (ipv6->nexthdr) {
+  case IPPROTO_UDP: {
+    if (!(udp = parse_udphdr(c))) {
+      return XDP_PASS;
+    }
+    if (!IN_DNS_PORT_SET(udp->dest)) {
+      return XDP_PASS;
+    }
+    if (!(dns = parse_dnshdr(c))) {
+      return XDP_DROP;
+    }
+    break;
+  }
 
-  
-  if (!(udp = parse_udphdr(c)) || udp->dest != bpf_htons(DNS_PORT)) {
-       return PASS;
+#ifdef UseXsk
+  case IPPROTO_TCP: {
+    struct tcphdr* tcp;
+    if (!(tcp = parse_tcphdr(c))) {
+      return XDP_PASS;
+    }
+    if (!IN_DNS_PORT_SET(tcp->dest)) {
+      return XDP_PASS;
+    }
   }
+#endif /* UseXsk */
 
-  // check that we have a DNS packet
-  ;
-  if (!(dns = parse_dnshdr(c))) {
-       return PASS;
+  default:
+    return XDP_PASS;
   }
 
+  struct CIDR6 key;
+  key.addr = ipv6->saddr;
+
   // if the address is blocked, perform the corresponding action
-  struct map_value* value = v6filter.lookup(&key->addr);
+  struct map_value* value = v6filter.lookup(&key.addr);
+  if (value) {
+    goto res;
+  }
 
+  key.cidr = 128;
+  value = cidr6filter.lookup(&key);
   if (value) {
-    __sync_fetch_and_add(&value->counter, 1);
-    if (value->action == TC) {
-         return set_tc_bit(udp, dns);
-    } else {
-      return value->action;
-    }
+    goto res;
   }
 
-  key->cidr = 128;
-  value = cidr6filter.lookup(key);
+  if (dns) {
+    value = check_qname(c);
+  }
   if (value) {
+  res:
     __sync_fetch_and_add(&value->counter, 1);
-    if (value->action == TC) {
-      return set_tc_bit(udp, dns);
+    if (value->action == TC && udp && dns) {
+      set_tc_bit(udp, dns);
+      // swap src/dest IP addresses
+      struct in6_addr swap_ipv6 = ipv6->daddr;
+      ipv6->daddr = ipv6->saddr;
+      ipv6->saddr = swap_ipv6;
+      progsarray.call(ctx, 1);
+      return XDP_TX;
     }
-    else {
-      return value->action;
+    if (value->action == DROP) {
+      progsarray.call(ctx, 0);
+      return XDP_DROP;
     }
   }
-
-  enum dns_action action = check_qname(c);
-  if (action == TC) {
-    return set_tc_bit(udp, dns);
-  }
-  return action;
+  return XDP_REDIRECT;
 }
 
 int xdp_dns_filter(struct xdp_md* ctx)
@@ -235,9 +312,7 @@ int xdp_dns_filter(struct xdp_md* ctx)
   struct cursor   c;
   struct ethhdr  *eth;
   uint16_t        eth_proto;
-  struct iphdr   *ipv4;
-  struct ipv6hdr *ipv6;
-  int            r = 0;
+  enum xdp_action r;
 
   // initialise the cursor
   cursor_init(&c, ctx);
@@ -245,67 +320,36 @@ int xdp_dns_filter(struct xdp_md* ctx)
   // pass the packet if it is not an ethernet one
   if ((eth = parse_eth(&c, &eth_proto))) {
     // IPv4 packets
-    if (eth_proto == bpf_htons(ETH_P_IP))
-    {
-      if (!(ipv4 = parse_iphdr(&c)) || bpf_htons(ipv4->protocol != IPPROTO_UDP)) {
-        return XDP_PASS;
-      }
-
-      struct CIDR4 key;
-      key.addr = bpf_htonl(ipv4->saddr);
-      // if TC bit must not be set, apply the action
-      if ((r = udp_dns_reply_v4(&c, &key)) != TC) {
-        if (r == DROP) {
-          progsarray.call(ctx, 0);
-          return XDP_DROP;
-        }
-        return XDP_PASS;
-      }
-
-      // swap src/dest IP addresses
-      uint32_t swap_ipv4 = ipv4->daddr;
-      ipv4->daddr = ipv4->saddr;
-      ipv4->saddr = swap_ipv4;
+    if (eth_proto == bpf_htons(ETH_P_IP)) {
+      r = parseIPV4(ctx, &c);
+      goto res;
     }
     // IPv6 packets
     else if (eth_proto == bpf_htons(ETH_P_IPV6)) {
-      if (!(ipv6 = parse_ipv6hdr(&c)) || bpf_htons(ipv6->nexthdr != IPPROTO_UDP)) {
-        return XDP_PASS;
-      }
-      struct CIDR6 key;
-      key.addr = ipv6->saddr;
-
-      // if TC bit must not be set, apply the action
-      if ((r = udp_dns_reply_v6(&c, &key)) != TC) {
-        if (r == DROP) {
-          progsarray.call(ctx, 0);
-          return XDP_DROP;
-        }
-        return XDP_PASS;
-      }
-
-      // swap src/dest IP addresses
-      struct in6_addr swap_ipv6 = ipv6->daddr;
-      ipv6->daddr = ipv6->saddr;
-      ipv6->saddr = swap_ipv6;
+      r = parseIPV6(ctx, &c);
+      goto res;
     }
     // pass all non-IP packets
-    else {
-      return XDP_PASS;
-    }
+    return XDP_PASS;
   }
-  else {
+  return XDP_PASS;
+res:
+  switch (r) {
+  case XDP_REDIRECT:
+#ifdef UseXsk
+    return xsk_map.redirect_map(ctx->rx_queue_index, 0);
+#else
     return XDP_PASS;
+#endif /* UseXsk */
+  case XDP_TX: { // swap MAC addresses
+    uint8_t swap_eth[ETH_ALEN];
+    memcpy(swap_eth, eth->h_dest, ETH_ALEN);
+    memcpy(eth->h_dest, eth->h_source, ETH_ALEN);
+    memcpy(eth->h_source, swap_eth, ETH_ALEN);
+    // bounce the request
+    return XDP_TX;
+  }
+  default:
+    return r;
   }
-
-  // swap MAC addresses
-  uint8_t swap_eth[ETH_ALEN];
-  memcpy(swap_eth, eth->h_dest, ETH_ALEN);
-  memcpy(eth->h_dest, eth->h_source, ETH_ALEN);
-  memcpy(eth->h_source, swap_eth, ETH_ALEN);
-
-  progsarray.call(ctx, 1);
-
-  // bounce the request
-  return XDP_TX;
 }
index 6384c3f8b85da367957396951689734f3a5f3efb..bd96ddb1ce58756fb552551a5bfa1c458ec012fe 100644 (file)
@@ -27,7 +27,15 @@ blocked_cidr6 = [("2001:db8::1/128", TC_ACTION)]
 blocked_qnames = [("localhost", "A", DROP_ACTION), ("test.com", "*", TC_ACTION)]
 
 # Main
-xdp = BPF(src_file="xdp-filter.ebpf.src")
+useXsk = True
+Ports = [53]
+cflag = []
+if useXsk:
+  cflag.append("-DUseXsk")
+IN_DNS_PORT_SET = "||".join("COMPARE_PORT((x),"+str(i)+")" for i in Ports)
+cflag.append(r"-DIN_DNS_PORT_SET(x)=(" + IN_DNS_PORT_SET + r")")
+
+xdp = BPF(src_file="xdp-filter.ebpf.src", cflags=cflag)
 
 fn = xdp.load_func("xdp_dns_filter", BPF.XDP)
 xdp.attach_xdp(DEV, fn, 0)
index 2fc728190966e1277ea3328305fb2cd2b10d2e0d..f429545a0beb60391fe963bdbed761cecca81231 100644 (file)
@@ -8,19 +8,6 @@ extern "C" {
   
 struct bpf_insn;
 
-int bpf_create_map(enum bpf_map_type map_type, int key_size, int value_size,
-                   int max_entries, int map_flags);
-int bpf_update_elem(int fd, void *key, void *value, unsigned long long flags);
-int bpf_lookup_elem(int fd, void *key, void *value);
-int bpf_delete_elem(int fd, void *key);
-int bpf_get_next_key(int fd, void *key, void *next_key);
-
-int bpf_prog_load(enum bpf_prog_type prog_type,
-                 const struct bpf_insn *insns, int insn_len,
-                 const char *license, int kern_version);
-
-int bpf_obj_pin(int fd, const char *pathname);
-int bpf_obj_get(const char *pathname);
 
 #define LOG_BUF_SIZE 65536
 extern char bpf_log_buf[LOG_BUF_SIZE];
index ec6bd05c5528cbd2517350fce1cdf324bf092de5..19343955c81fb0677243418d7c3b0859e93e7ea1 100644 (file)
 
 #include "misc.hh"
 
+int bpf_create_map(enum bpf_map_type map_type, int key_size, int value_size,
+                   int max_entries, int map_flags);
+int bpf_update_elem(int fd, void *key, void *value, unsigned long long flags);
+int bpf_lookup_elem(int fd, void *key, void *value);
+int bpf_delete_elem(int fd, void *key);
+int bpf_get_next_key(int fd, void *key, void *next_key);
+
+int bpf_prog_load(enum bpf_prog_type prog_type,
+                 const struct bpf_insn *insns, int insn_len,
+                 const char *license, int kern_version);
+
+int bpf_obj_pin(int fd, const char *pathname);
+int bpf_obj_get(const char *pathname);
+
 static __u64 ptr_to_u64(void *ptr)
 {
   return (__u64) (unsigned long) ptr;
index 73d5f6e5e32feea16e808c42e9b54b3b8343960c..49248a08c7865abf11a7f088d6bf60b9e5ddad1f 100644 (file)
@@ -134,6 +134,7 @@ struct InternalQueryState
   std::unique_ptr<PacketBuffer> d_packet{nullptr}; // Initial packet, so we can restart the query from the response path if needed // 8
   std::unique_ptr<ProtoBufData> d_protoBufData{nullptr};
   std::unique_ptr<EDNSExtendedError> d_extendedError{nullptr};
+  std::unique_ptr<PacketBuffer> xskPacketHeader; // 8
   boost::optional<uint32_t> tempFailureTTL{boost::none}; // 8
   ClientState* cs{nullptr}; // 8
   std::unique_ptr<DOHUnitInterface> du; // 8
index 79ba4ec57b33e20b278610fa147476917cb2a8eb..130a71153ddf9ef0e608ba2d7bde5b19ca4a67b4 100644 (file)
@@ -28,6 +28,7 @@
 #include "dnsdist-svc.hh"
 
 #include "dolog.hh"
+#include "xsk.hh"
 
 // NOLINTNEXTLINE(readability-function-cognitive-complexity): this function declares Lua bindings, even with a good refactoring it will likely blow up the threshold
 void setupLuaBindings(LuaContext& luaCtx, bool client, bool configCheck)
@@ -715,7 +716,53 @@ void setupLuaBindings(LuaContext& luaCtx, bool client, bool configCheck)
       }
     });
 #endif /* HAVE_EBPF */
-
+#ifdef HAVE_XSK
+  using xskopt_t = LuaAssociativeTable<boost::variant<uint32_t, std::string>>;
+  luaCtx.writeFunction("newXsk", [client](xskopt_t opts) {
+    if (g_configurationDone) {
+      throw std::runtime_error("newXsk() only can be used at configuration time!");
+    }
+    if (client) {
+      return std::shared_ptr<XskSocket>(nullptr);
+    }
+    uint32_t queue_id;
+    uint32_t frameNums;
+    std::string ifName;
+    std::string path;
+    std::string poolName;
+    if (opts.count("NIC_queue_id") == 1) {
+      queue_id = boost::get<uint32_t>(opts.at("NIC_queue_id"));
+    }
+    else {
+      throw std::runtime_error("NIC_queue_id field is required!");
+    }
+    if (opts.count("frameNums") == 1) {
+      frameNums = boost::get<uint32_t>(opts.at("frameNums"));
+    }
+    else {
+      throw std::runtime_error("frameNums field is required!");
+    }
+    if (opts.count("ifName") == 1) {
+      ifName = boost::get<std::string>(opts.at("ifName"));
+    }
+    else {
+      throw std::runtime_error("ifName field is required!");
+    }
+    if (opts.count("xskMapPath") == 1) {
+      path = boost::get<std::string>(opts.at("xskMapPath"));
+    }
+    else {
+      throw std::runtime_error("xskMapPath field is required!");
+    }
+    if (opts.count("pool") == 1) {
+      poolName = boost::get<std::string>(opts.at("pool"));
+    }
+    extern std::vector<std::shared_ptr<XskSocket>> g_xsk;
+    auto socket = std::make_shared<XskSocket>(frameNums, ifName, queue_id, path, poolName);
+    g_xsk.push_back(socket); 
+    return socket;
+  });
+#endif /* HAVE_XSK */
   /* EDNSOptionView */
   luaCtx.registerFunction<size_t(EDNSOptionView::*)()const>("count", [](const EDNSOptionView& option) {
       return option.values.size();
index 54b7109b1968413a0df507e7862c933b61ec66f3..ac6e6e84e8d5ce4f62dc7cbcd55281c9bc7d5c99 100644 (file)
@@ -21,6 +21,7 @@
  */
 
 #include <cstdint>
+#include <cstdio>
 #include <dirent.h>
 #include <fstream>
 #include <cinttypes>
@@ -45,6 +46,7 @@
 #include "dnsdist-ecs.hh"
 #include "dnsdist-healthchecks.hh"
 #include "dnsdist-lua.hh"
+#include "xsk.hh"
 #ifdef LUAJIT_VERSION
 #include "dnsdist-lua-ffi.hh"
 #endif /* LUAJIT_VERSION */
@@ -110,7 +112,7 @@ void resetLuaSideEffect()
   g_noLuaSideEffect = boost::logic::indeterminate;
 }
 
-using localbind_t = LuaAssociativeTable<boost::variant<bool, int, std::string, LuaArray<int>, LuaArray<std::string>, LuaAssociativeTable<std::string>>>;
+using localbind_t = LuaAssociativeTable<boost::variant<bool, int, std::string, LuaArray<int>, LuaArray<std::string>, LuaAssociativeTable<std::string>, std::shared_ptr<XskSocket>>>;
 
 static void parseLocalBindVars(boost::optional<localbind_t>& vars, bool& reusePort, int& tcpFastOpenQueueSize, std::string& interface, std::set<int>& cpus, int& tcpListenQueueSize, uint64_t& maxInFlightQueriesPerConnection, uint64_t& tcpMaxConcurrentConnections, bool& enableProxyProtocol)
 {
@@ -131,6 +133,16 @@ static void parseLocalBindVars(boost::optional<localbind_t>& vars, bool& reusePo
     }
   }
 }
+#ifdef HAVE_XSK
+static void parseXskVars(boost::optional<localbind_t>& vars, std::shared_ptr<XskSocket>& socket)
+{
+  if (!vars) {
+    return;
+  }
+
+  getOptionalValue<std::shared_ptr<XskSocket>>(vars, "xskSocket", socket);
+}
+#endif /* HAVE_XSK */
 
 #if defined(HAVE_DNS_OVER_TLS) || defined(HAVE_DNS_OVER_HTTPS) || defined(HAVE_DNS_OVER_QUIC)
 static bool loadTLSCertificateAndKeys(const std::string& context, std::vector<TLSCertKeyPair>& pairs, const boost::variant<std::string, std::shared_ptr<TLSCertKeyPair>, LuaArray<std::string>, LuaArray<std::shared_ptr<TLSCertKeyPair>>>& certFiles, const LuaTypeOrArrayOf<std::string>& keyFiles)
@@ -298,7 +310,7 @@ static bool checkConfigurationTime(const std::string& name)
 // NOLINTNEXTLINE(readability-function-cognitive-complexity): this function declares Lua bindings, even with a good refactoring it will likely blow up the threshold
 static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
 {
-  typedef LuaAssociativeTable<boost::variant<bool, std::string, LuaArray<std::string>, DownstreamState::checkfunc_t>> newserver_t;
+  typedef LuaAssociativeTable<boost::variant<bool, std::string, LuaArray<std::string>, std::shared_ptr<XskSocket>, DownstreamState::checkfunc_t>> newserver_t;
   luaCtx.writeFunction("inClientStartup", [client]() {
     return client && !g_configurationDone;
   });
@@ -621,7 +633,18 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
                          if (!(client || configCheck)) {
                            infolog("Added downstream server %s", ret->d_config.remote.toStringWithPort());
                          }
-
+#ifdef HAVE_XSK
+                         std::shared_ptr<XskSocket> xskSocket;
+                         if (getOptionalValue<std::shared_ptr<XskSocket>>(vars, "xskSocket", xskSocket) > 0) {
+                           ret->registerXsk(xskSocket);
+                           std::string mac;
+                           if (getOptionalValue<std::string>(vars, "MACAddr", mac) != 1) {
+                             throw runtime_error("field MACAddr is required!");
+                           }
+                           auto* addr = &ret->d_config.destMACAddr[0];
+                           sscanf(mac.c_str(), "%hhx:%hhx:%hhx:%hhx:%hhx:%hhx", addr, addr + 1, addr + 2, addr + 3, addr + 4, addr + 5);
+                         }
+#endif /* HAVE_XSK */
                          if (autoUpgrade && ret->getProtocol() != dnsdist::Protocol::DoT && ret->getProtocol() != dnsdist::Protocol::DoH) {
                            dnsdist::ServiceDiscovery::addUpgradeableServer(ret, upgradeInterval, upgradePool, upgradeDoHKey, keepAfterUpgrade);
                          }
@@ -744,7 +767,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
       }
 
       // only works pre-startup, so no sync necessary
-      g_frontends.push_back(std::make_unique<ClientState>(loc, false, reusePort, tcpFastOpenQueueSize, interface, cpus, enableProxyProtocol));
+      auto udpCS = std::make_unique<ClientState>(loc, false, reusePort, tcpFastOpenQueueSize, interface, cpus, enableProxyProtocol);
       auto tcpCS = std::make_unique<ClientState>(loc, true, reusePort, tcpFastOpenQueueSize, interface, cpus, enableProxyProtocol);
       if (tcpListenQueueSize > 0) {
         tcpCS->tcpListenQueueSize = tcpListenQueueSize;
@@ -756,6 +779,18 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
         tcpCS->d_tcpConcurrentConnectionsLimit = tcpMaxConcurrentConnections;
       }
 
+#ifdef HAVE_XSK
+      std::shared_ptr<XskSocket> socket;
+      parseXskVars(vars, socket);
+      if (socket) {
+        udpCS->xskInfo = XskWorker::create();
+        udpCS->xskInfo->sharedEmptyFrameOffset = socket->sharedEmptyFrameOffset;
+        socket->addWorker(udpCS->xskInfo, loc, false);
+        // tcpCS->xskInfo=XskWorker::create();
+        // TODO: socket->addWorker(tcpCS->xskInfo, loc, true);
+      }
+#endif /* HAVE_XSK */
+      g_frontends.push_back(std::move(udpCS));
       g_frontends.push_back(std::move(tcpCS));
     }
     catch (const std::exception& e) {
@@ -786,7 +821,7 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
     try {
       ComboAddress loc(addr, 53);
       // only works pre-startup, so no sync necessary
-      g_frontends.push_back(std::make_unique<ClientState>(loc, false, reusePort, tcpFastOpenQueueSize, interface, cpus, enableProxyProtocol));
+      auto udpCS = std::make_unique<ClientState>(loc, false, reusePort, tcpFastOpenQueueSize, interface, cpus, enableProxyProtocol);
       auto tcpCS = std::make_unique<ClientState>(loc, true, reusePort, tcpFastOpenQueueSize, interface, cpus, enableProxyProtocol);
       if (tcpListenQueueSize > 0) {
         tcpCS->tcpListenQueueSize = tcpListenQueueSize;
@@ -797,6 +832,18 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
       if (tcpMaxConcurrentConnections > 0) {
         tcpCS->d_tcpConcurrentConnectionsLimit = tcpMaxConcurrentConnections;
       }
+#ifdef HAVE_XSK
+      std::shared_ptr<XskSocket> socket;
+      parseXskVars(vars, socket);
+      if (socket) {
+        udpCS->xskInfo = XskWorker::create();
+        udpCS->xskInfo->sharedEmptyFrameOffset = socket->sharedEmptyFrameOffset;
+        socket->addWorker(udpCS->xskInfo, loc, false);
+        // TODO tcpCS->xskInfo=XskWorker::create();
+        // TODO socket->addWorker(tcpCS->xskInfo, loc, true);
+      }
+#endif /* HAVE_XSK */
+      g_frontends.push_back(std::move(udpCS));
       g_frontends.push_back(std::move(tcpCS));
     }
     catch (std::exception& e) {
index c15a14484db56b1e5afcd022dc24af4cf531a27b..3baf478ea22b124d02813d5348e6045e6037313e 100644 (file)
@@ -495,9 +495,8 @@ void IncomingTCPConnectionState::handleResponse(const struct timeval& now, TCPRe
   if (!response.isAsync()) {
     try {
       auto& ids = response.d_idstate;
-      unsigned int qnameWireLength{0};
       std::shared_ptr<DownstreamState> backend = response.d_ds ? response.d_ds : (response.d_connection ? response.d_connection->getDS() : nullptr);
-      if (backend == nullptr || !responseContentMatches(response.d_buffer, ids.qname, ids.qtype, ids.qclass, backend, qnameWireLength)) {
+      if (backend == nullptr || !responseContentMatches(response.d_buffer, ids.qname, ids.qtype, ids.qclass, backend)) {
         state->terminateClientConnection();
         return;
       }
index 3beef62b1234ff67c635a9ccd34b1fa1ed81b330..304cd0f99b695e42d07450ed4e072cf35f584fd7 100644 (file)
 #include <limits>
 #include <netinet/tcp.h>
 #include <pwd.h>
+#include <set>
 #include <sys/resource.h>
 #include <unistd.h>
 
+#ifdef HAVE_XSK
+#include <sys/poll.h>
+#include <sys/timerfd.h>
+#endif /* HAVE_XSK */
+
 #ifdef HAVE_LIBEDIT
 #if defined (__OpenBSD__) || defined(__NetBSD__)
 // If this is not undeffed, __attribute__ wil be redefined by /usr/include/readline/rlstdc.h
@@ -111,6 +117,7 @@ std::vector<std::shared_ptr<DOHFrontend>> g_dohlocals;
 std::vector<std::shared_ptr<DOQFrontend>> g_doqlocals;
 std::vector<std::shared_ptr<DOH3Frontend>> g_doh3locals;
 std::vector<std::shared_ptr<DNSCryptContext>> g_dnsCryptLocals;
+std::vector<std::shared_ptr<XskSocket>> g_xsk;
 
 shared_ptr<BPFFilter> g_defaultBPFFilter{nullptr};
 std::vector<std::shared_ptr<DynBPFFilter> > g_dynBPFFilters;
@@ -332,7 +339,7 @@ static void doLatencyStats(dnsdist::Protocol protocol, double udiff)
   }
 }
 
-bool responseContentMatches(const PacketBuffer& response, const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const std::shared_ptr<DownstreamState>& remote, unsigned int& qnameWireLength)
+bool responseContentMatches(const PacketBuffer& response, const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const std::shared_ptr<DownstreamState>& remote)
 {
   if (response.size() < sizeof(dnsheader)) {
     return false;
@@ -363,7 +370,7 @@ bool responseContentMatches(const PacketBuffer& response, const DNSName& qname,
   uint16_t rqtype, rqclass;
   DNSName rqname;
   try {
-    rqname = DNSName(reinterpret_cast<const char*>(response.data()), response.size(), sizeof(dnsheader), false, &rqtype, &rqclass, &qnameWireLength);
+    rqname = DNSName(reinterpret_cast<const char*>(response.data()), response.size(), sizeof(dnsheader), false, &rqtype, &rqclass);
   }
   catch (const std::exception& e) {
     if (remote && response.size() > 0 && static_cast<size_t>(response.size()) > sizeof(dnsheader)) {
@@ -743,9 +750,7 @@ static void handleResponseForUDPClient(InternalQueryState& ids, PacketBuffer& re
   }
 
   bool muted = true;
-  if (ids.cs && !ids.cs->muted) {
-    ComboAddress empty;
-    empty.sin4.sin_family = 0;
+  if (ids.cs && !ids.cs->muted && !ids.xskPacketHeader) {
     sendUDPResponse(ids.cs->udpFD, response, dr.ids.delayMsec, ids.hopLocal, ids.hopRemote);
     muted = false;
   }
@@ -766,6 +771,61 @@ static void handleResponseForUDPClient(InternalQueryState& ids, PacketBuffer& re
   }
 }
 
+#ifdef HAVE_XSK
+static void XskHealthCheck(std::shared_ptr<DownstreamState>& dss, std::unordered_map<uint16_t, std::shared_ptr<HealthCheckData>>& map, bool initial = false)
+{
+  auto& xskInfo = dss->xskInfo;
+  std::shared_ptr<HealthCheckData> data;
+  auto packet = getHealthCheckPacket(dss, nullptr, data);
+  data->d_initial = initial;
+  setHealthCheckTime(dss, data);
+  auto* frame = xskInfo->getEmptyframe();
+  auto *xskPacket = new XskPacket(frame, 0, xskInfo->frameSize);
+  xskPacket->setAddr(dss->d_config.sourceAddr, dss->d_config.sourceMACAddr, dss->d_config.remote, dss->d_config.destMACAddr);
+  xskPacket->setPayload(packet);
+  xskPacket->rewrite();
+  xskInfo->sq.push(xskPacket);
+  const auto queryId = data->d_queryID;
+  map[queryId] = std::move(data);
+}
+#endif /* HAVE_XSK */
+
+static bool processResponderPacket(std::shared_ptr<DownstreamState>& dss, PacketBuffer& response, const std::vector<DNSDistResponseRuleAction>& localRespRuleActions, const std::vector<DNSDistResponseRuleAction>& cacheInsertedRespRuleActions, InternalQueryState&& ids)
+{
+
+  const dnsheader_aligned dh(response.data());
+  auto queryId = dh->id;
+
+  if (!responseContentMatches(response, ids.qname, ids.qtype, ids.qclass, dss)) {
+    dss->restoreState(queryId, std::move(ids));
+    return false;
+  }
+
+  auto du = std::move(ids.du);
+  dnsdist::PacketMangling::editDNSHeaderFromPacket(response, [&ids](dnsheader& header) {
+    header.id = ids.origID;
+    return true;
+  });
+  ++dss->responses;
+
+  double udiff = ids.queryRealTime.udiff();
+  // do that _before_ the processing, otherwise it's not fair to the backend
+  dss->latencyUsec = (127.0 * dss->latencyUsec / 128.0) + udiff / 128.0;
+  dss->reportResponse(dh->rcode);
+
+  /* don't call processResponse for DOH */
+  if (du) {
+#ifdef HAVE_DNS_OVER_HTTPS
+    // DoH query, we cannot touch du after that
+    DOHUnitInterface::handleUDPResponse(std::move(du), std::move(response), std::move(ids), dss);
+#endif
+    return false;
+  }
+
+  handleResponseForUDPClient(ids, response, localRespRuleActions, cacheInsertedRespRuleActions, dss, false, false);
+  return true;
+}
+
 // listens on a dedicated socket, lobs answers from downstream servers to original requestors
 void responderThread(std::shared_ptr<DownstreamState> dss)
 {
@@ -773,6 +833,103 @@ void responderThread(std::shared_ptr<DownstreamState> dss)
   setThreadName("dnsdist/respond");
   auto localRespRuleActions = g_respruleactions.getLocal();
   auto localCacheInsertedRespRuleActions = g_cacheInsertedRespRuleActions.getLocal();
+#ifdef HAVE_XSK
+    if (dss->xskInfo) {
+      auto xskInfo = dss->xskInfo;
+      auto pollfds = getPollFdsForWorker(*xskInfo);
+      std::unordered_map<uint16_t, std::shared_ptr<HealthCheckData>> healthCheckMap;
+      XskHealthCheck(dss, healthCheckMap, true);
+      itimerspec tm;
+      tm.it_value.tv_sec = dss->d_config.checkTimeout / 1000;
+      tm.it_value.tv_nsec = (dss->d_config.checkTimeout % 1000) * 1000000;
+      tm.it_interval = tm.it_value;
+      auto res = timerfd_settime(pollfds[1].fd, 0, &tm, nullptr);
+      if (res) {
+        throw std::runtime_error("timerfd_settime failed:" + stringerror(errno));
+      }
+      const auto xskFd = xskInfo->workerWaker.getHandle();
+      while (!dss->isStopped()) {
+        poll(pollfds.data(), pollfds.size(), -1);
+        bool needNotify = false;
+        if (pollfds[0].revents & POLLIN) {
+          needNotify = true;
+          xskInfo->cq.consume_all([&](XskPacket* packet) {
+            if (packet->dataLen() < sizeof(dnsheader)) {
+              xskInfo->sq.push(packet);
+              return;
+            }
+            const auto* dh = reinterpret_cast<const struct dnsheader*>(packet->payloadData());
+            const auto queryId = dh->id;
+            auto ids = dss->getState(queryId);
+            if (ids) {
+              if (xskFd != ids->backendFD || !ids->xskPacketHeader) {
+                dss->restoreState(queryId, std::move(*ids));
+                ids = std::nullopt;
+              }
+            }
+            if (!ids) {
+              // this has to go before we can refactor the duplicated response handling code
+              auto iter = healthCheckMap.find(queryId);
+              if (iter != healthCheckMap.end()) {
+                auto data = std::move(iter->second);
+                healthCheckMap.erase(iter);
+                packet->cloneIntoPacketBuffer(data->d_buffer);
+                data->d_ds->submitHealthCheckResult(data->d_initial, handleResponse(data));
+              }
+              xskInfo->sq.push(packet);
+              return;
+            }
+            auto response = packet->clonePacketBuffer();
+            if (!processResponderPacket(dss, response, *localRespRuleActions, *localCacheInsertedRespRuleActions, std::move(*ids))) {
+              xskInfo->sq.push(packet);
+              return;
+            }
+            packet->setHeader(*ids->xskPacketHeader);
+            packet->setPayload(response);
+            if (ids->delayMsec > 0) {
+              packet->addDelay(ids->delayMsec);
+            }
+            packet->updatePacket();
+            xskInfo->sq.push(packet);
+          });
+          xskInfo->cleanSocketNotification();
+        }
+        if (pollfds[1].revents & POLLIN) {
+          timeval now;
+          gettimeofday(&now, nullptr);
+          for (auto i = healthCheckMap.begin(); i != healthCheckMap.end();) {
+            auto& ttd = i->second->d_ttd;
+            if (ttd < now) {
+              dss->submitHealthCheckResult(i->second->d_initial, false);
+              i = healthCheckMap.erase(i);
+            }
+            else {
+              ++i;
+            }
+          }
+          needNotify = true;
+          dss->updateStatisticsInfo();
+          dss->handleUDPTimeouts();
+          if (dss->d_nextCheck <= 1) {
+            dss->d_nextCheck = dss->d_config.checkInterval;
+            if (dss->d_config.availability == DownstreamState::Availability::Auto) {
+              XskHealthCheck(dss, healthCheckMap);
+            }
+          }
+          else {
+            --dss->d_nextCheck;
+          }
+
+          uint64_t tmp;
+          res = read(pollfds[1].fd, &tmp, sizeof(tmp));
+        }
+        if (needNotify) {
+          xskInfo->notifyXskSocket();
+        }
+      }
+    }
+    else {
+#endif /* HAVE_XSK */
   const size_t initialBufferSize = getInitialUDPPacketBufferSize(false);
   /* allocate one more byte so we can detect truncation */
   PacketBuffer response(initialBufferSize + 1);
@@ -805,7 +962,7 @@ void responderThread(std::shared_ptr<DownstreamState> dss)
 
       for (const auto& fd : sockets) {
         /* allocate one more byte so we can detect truncation */
-        // NOLINTNEXTLINE(bugprone-use-after-move): resizing a vector has no preconditions so it is valid to do so after moving it 
+        // NOLINTNEXTLINE(bugprone-use-after-move): resizing a vector has no preconditions so it is valid to do so after moving it
         response.resize(initialBufferSize + 1);
         ssize_t got = recv(fd, response.data(), response.size(), 0);
 
@@ -826,41 +983,32 @@ void responderThread(std::shared_ptr<DownstreamState> dss)
           continue;
         }
 
-        unsigned int qnameWireLength = 0;
-        if (fd != ids->backendFD || !responseContentMatches(response, ids->qname, ids->qtype, ids->qclass, dss, qnameWireLength)) {
+        if (fd != ids->backendFD) {
           dss->restoreState(queryId, std::move(*ids));
           continue;
         }
 
-        auto du = std::move(ids->du);
-
-        dnsdist::PacketMangling::editDNSHeaderFromPacket(response, [&ids](dnsheader& header) {
-          header.id = ids->origID;
-          return true;
-        });
-        ++dss->responses;
-
-        double udiff = ids->queryRealTime.udiff();
-        // do that _before_ the processing, otherwise it's not fair to the backend
-        dss->latencyUsec = (127.0 * dss->latencyUsec / 128.0) + udiff / 128.0;
-        dss->reportResponse(dh->rcode);
-
-        /* don't call processResponse for DOH */
-        if (du) {
-#ifdef HAVE_DNS_OVER_HTTPS
-          // DoH query, we cannot touch du after that
-          DOHUnitInterface::handleUDPResponse(std::move(du), std::move(response), std::move(*ids), dss);
-#endif
-          continue;
+        if (processResponderPacket(dss, response, *localRespRuleActions, *localCacheInsertedRespRuleActions, std::move(*ids)) && ids->xskPacketHeader && ids->cs->xskInfo) {
+#ifdef HAVE_XSK
+          auto& xskInfo = ids->cs->xskInfo;
+          auto* frame = xskInfo->getEmptyframe();
+          auto xskPacket = std::make_unique<XskPacket>(frame, 0, xskInfo->frameSize);
+          xskPacket->setHeader(*ids->xskPacketHeader);
+          xskPacket->setPayload(response);
+          xskPacket->updatePacket();
+          xskInfo->sq.push(xskPacket.release());
+          xskInfo->notifyXskSocket();
+#endif /* HAVE_XSK */
         }
-
-        handleResponseForUDPClient(*ids, response, *localRespRuleActions, *localCacheInsertedRespRuleActions, dss, false, false);
       }
     }
     catch (const std::exception& e) {
       vinfolog("Got an error in UDP responder thread while parsing a response from %s, id %d: %s", dss->d_config.remote.toStringWithPort(), queryId, e.what());
     }
   }
+#ifdef HAVE_XSK
+  }
+#endif /* HAVE_XSK */
 }
 catch (const std::exception& e) {
   errlog("UDP responder thread died because of exception: %s", e.what());
@@ -1280,6 +1428,23 @@ static bool isUDPQueryAcceptable(ClientState& cs, LocalHolders& holders, const s
   return true;
 }
 
+#ifdef HAVE_XSK
+static bool isXskQueryAcceptable(const XskPacket& packet, ClientState& cs, LocalHolders& holders, bool& expectProxyProtocol) noexcept
+{
+  const auto& from = packet.getFromAddr();
+  expectProxyProtocol = expectProxyProtocolFrom(from);
+  if (!holders.acl->match(from) && !expectProxyProtocol) {
+    vinfolog("Query from %s dropped because of ACL", from.toStringWithPort());
+    ++dnsdist::metrics::g_stats.aclDrops;
+    return false;
+  }
+  cs.queries++;
+  ++dnsdist::metrics::g_stats.queries;
+
+  return true;
+}
+#endif /* HAVE_XSK */
+
 bool checkDNSCryptQuery(const ClientState& cs, PacketBuffer& query, std::unique_ptr<DNSCryptQuery>& dnsCryptQuery, time_t now, bool tcp)
 {
   if (cs.dnscryptCtx) {
@@ -1408,7 +1573,11 @@ ProcessQueryResult processQueryAfterRules(DNSQuestion& dq, LocalHolders& holders
       ++dq.ids.cs->responses;
       return ProcessQueryResult::SendAnswer;
     }
-
+#ifdef HAVE_XSK
+    if (dq.ids.cs->xskInfo) {
+      dq.ids.poolName = dq.ids.cs->xskInfo->poolName;
+    }
+#endif /* HAVE_XSK */
     std::shared_ptr<ServerPool> serverPool = getPool(*holders.pools, dq.ids.poolName);
     std::shared_ptr<ServerPolicy> poolPolicy = serverPool->policy;
     dq.ids.packetCache = serverPool->packetCache;
@@ -1649,7 +1818,7 @@ ProcessQueryResult processQuery(DNSQuestion& dq, LocalHolders& holders, std::sha
   return ProcessQueryResult::Drop;
 }
 
-bool assignOutgoingUDPQueryToBackend(std::shared_ptr<DownstreamState>& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer& query)
+bool assignOutgoingUDPQueryToBackend(std::shared_ptr<DownstreamState>& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer& query, bool actuallySend)
 {
   bool doh = dq.ids.du != nullptr;
 
@@ -1670,7 +1839,9 @@ bool assignOutgoingUDPQueryToBackend(std::shared_ptr<DownstreamState>& ds, uint1
 
   try {
     int fd = ds->pickSocketForSending();
-    dq.ids.backendFD = fd;
+    if (actuallySend) {
+      dq.ids.backendFD = fd;
+    }
     dq.ids.origID = queryID;
     dq.ids.forwardedOverUDP = true;
 
@@ -1682,6 +1853,10 @@ bool assignOutgoingUDPQueryToBackend(std::shared_ptr<DownstreamState>& ds, uint1
     /* set the correct ID */
     memcpy(&query.at(proxyProtocolPayloadSize), &idOffset, sizeof(idOffset));
 
+    if (!actuallySend) {
+      return true;
+    }
+
     /* you can't touch ids or du after this line, unless the call returned a non-negative value,
        because it might already have been freed */
     ssize_t ret = udpClientSendRequestToBackend(ds, fd, query);
@@ -1839,6 +2014,127 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct
   }
 }
 
+#ifdef HAVE_XSK
+static void ProcessXskQuery(ClientState& cs, LocalHolders& holders, XskPacket& packet)
+{
+  uint16_t queryId = 0;
+  const auto& remote = packet.getFromAddr();
+  const auto& dest = packet.getToAddr();
+  InternalQueryState ids;
+  ids.cs = &cs;
+  ids.origRemote = remote;
+  ids.hopRemote = remote;
+  ids.origDest = dest;
+  ids.hopLocal = dest;
+  ids.protocol = dnsdist::Protocol::DoUDP;
+  ids.xskPacketHeader = packet.cloneHeadertoPacketBuffer();
+
+  try {
+    bool expectProxyProtocol = false;
+    if (!isXskQueryAcceptable(packet, cs, holders, expectProxyProtocol)) {
+      return;
+    }
+
+    auto query = packet.clonePacketBuffer();
+    std::vector<ProxyProtocolValue> proxyProtocolValues;
+    if (expectProxyProtocol && !handleProxyProtocol(remote, false, *holders.acl, query, ids.origRemote, ids.origDest, proxyProtocolValues)) {
+      return;
+    }
+
+    ids.queryRealTime.start();
+
+    auto dnsCryptResponse = checkDNSCryptQuery(cs, query, ids.dnsCryptQuery, ids.queryRealTime.d_start.tv_sec, false);
+    if (dnsCryptResponse) {
+      packet.setPayload(query);
+      return;
+    }
+
+    {
+      /* this pointer will be invalidated the second the buffer is resized, don't hold onto it! */
+      dnsheader_aligned dnsHeader(query.data());
+      queryId = ntohs(dnsHeader->id);
+
+      if (!checkQueryHeaders(dnsHeader.get(), cs)) {
+        return;
+      }
+
+      if (dnsHeader->qdcount == 0) {
+        dnsdist::PacketMangling::editDNSHeaderFromPacket(query, [](dnsheader& header) {
+          header.rcode = RCode::NotImp;
+          header.qr = true;
+          return true;
+        });
+        packet.setPayload(query);
+        return;
+      }
+    }
+
+    ids.qname = DNSName(reinterpret_cast<const char*>(query.data()), query.size(), sizeof(dnsheader), false, &ids.qtype, &ids.qclass);
+    if (ids.origDest.sin4.sin_family == 0) {
+      ids.origDest = cs.local;
+    }
+    if (ids.dnsCryptQuery) {
+      ids.protocol = dnsdist::Protocol::DNSCryptUDP;
+    }
+    DNSQuestion dq(ids, query);
+    if (!proxyProtocolValues.empty()) {
+      dq.proxyProtocolValues = make_unique<std::vector<ProxyProtocolValue>>(std::move(proxyProtocolValues));
+    }
+    std::shared_ptr<DownstreamState> ss{nullptr};
+    auto result = processQuery(dq, holders, ss);
+
+    if (result == ProcessQueryResult::Drop) {
+      return;
+    }
+
+    if (result == ProcessQueryResult::SendAnswer) {
+      packet.setPayload(query);
+      if (dq.ids.delayMsec > 0) {
+        packet.addDelay(dq.ids.delayMsec);
+      }
+      return;
+    }
+
+    if (result != ProcessQueryResult::PassToBackend || ss == nullptr) {
+      return;
+    }
+
+    // the buffer might have been invalidated by now (resized)
+    const auto dh = dq.getHeader();
+    if (ss->isTCPOnly()) {
+      std::string proxyProtocolPayload;
+      /* we need to do this _before_ creating the cross protocol query because
+         after that the buffer will have been moved */
+      if (ss->d_config.useProxyProtocol) {
+        proxyProtocolPayload = getProxyProtocolPayload(dq);
+      }
+
+      ids.origID = dh->id;
+      auto cpq = std::make_unique<UDPCrossProtocolQuery>(std::move(query), std::move(ids), ss);
+      cpq->query.d_proxyProtocolPayload = std::move(proxyProtocolPayload);
+
+      ss->passCrossProtocolQuery(std::move(cpq));
+      return;
+    }
+
+    if (!ss->xskInfo) {
+      assignOutgoingUDPQueryToBackend(ss, dh->id, dq, query, true);
+    }
+    else {
+      int fd = ss->xskInfo->workerWaker;
+      ids.backendFD = fd;
+      assignOutgoingUDPQueryToBackend(ss, dh->id, dq, query, false);
+      packet.setAddr(ss->d_config.sourceAddr,ss->d_config.sourceMACAddr, ss->d_config.remote,ss->d_config.destMACAddr);
+      packet.setPayload(query);
+      packet.rewrite();
+    }
+  }
+  catch (const std::exception& e) {
+    vinfolog("Got an error in UDP question thread while parsing a query from %s, id %d: %s", remote.toStringWithPort(), queryId, e.what());
+  }
+}
+#endif /* HAVE_XSK */
+
 #ifndef DISABLE_RECVMMSG
 #if defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE)
 static void MultipleMessagesUDPClientThread(ClientState* cs, LocalHolders& holders)
@@ -1931,6 +2227,27 @@ static void MultipleMessagesUDPClientThread(ClientState* cs, LocalHolders& holde
 #endif /* defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) */
 #endif /* DISABLE_RECVMMSG */
 
+#ifdef HAVE_XSK
+static void xskClientThread(ClientState* cs)
+{
+  setThreadName("dnsdist/xskClient");
+  auto xskInfo = cs->xskInfo;
+  LocalHolders holders;
+
+  for (;;) {
+    while (!xskInfo->cq.read_available()) {
+      xskInfo->waitForXskSocket();
+    }
+    xskInfo->cq.consume_all([&](XskPacket* packet) {
+      ProcessXskQuery(*cs, holders, *packet);
+      packet->updatePacket();
+      xskInfo->sq.push(packet);
+    });
+    xskInfo->notifyXskSocket();
+  }
+}
+#endif /* HAVE_XSK */
+
 // listens to incoming queries, sends out to downstream servers, noting the intended return path
 static void udpClientThread(std::vector<ClientState*> states)
 {
@@ -2177,11 +2494,12 @@ static void healthChecksThread()
 
     std::unique_ptr<FDMultiplexer> mplexer{nullptr};
     for (auto& dss : *states) {
-      auto delta = dss->sw.udiffAndSet()/1000000.0;
-      dss->queryLoad.store(1.0*(dss->queries.load() - dss->prev.queries.load())/delta);
-      dss->dropRate.store(1.0*(dss->reuseds.load() - dss->prev.reuseds.load())/delta);
-      dss->prev.queries.store(dss->queries.load());
-      dss->prev.reuseds.store(dss->reuseds.load());
+#ifdef HAVE_XSK
+      if (dss->xskInfo) {
+        continue;
+      }
+#endif /* HAVE_XSK */
+      dss->updateStatisticsInfo();
 
       dss->handleUDPTimeouts();
 
@@ -2909,13 +3227,35 @@ static void initFrontends()
   }
 }
 
+#ifdef HAVE_XSK
+void XskRouter(std::shared_ptr<XskSocket> xsk);
+#endif /* HAVE_XSK */
+
 namespace dnsdist
 {
 static void startFrontends()
 {
+#ifdef HAVE_XSK
+  for (auto& xskContext : g_xsk) {
+    std::thread xskThread(XskRouter, std::move(xskContext));
+    xskThread.detach();
+  }
+#endif /* HAVE_XSK */
+
   std::vector<ClientState*> tcpStates;
   std::vector<ClientState*> udpStates;
   for (auto& clientState : g_frontends) {
+#ifdef HAVE_XSK
+    if (clientState->xskInfo) {
+      std::thread xskCT(xskClientThread, clientState.get());
+      if (!clientState->cpus.empty()) {
+        mapThreadToCPUList(xskCT.native_handle(), clientState->cpus);
+      }
+      xskCT.detach();
+      continue;
+    }
+#endif /* HAVE_XSK */
+
     if (clientState->dohFrontend != nullptr && clientState->dohFrontend->d_library == "h2o") {
 #ifdef HAVE_DNS_OVER_HTTPS
 #ifdef HAVE_LIBH2OEVLOOP
@@ -3175,6 +3515,12 @@ int main(int argc, char** argv)
       auto states = g_dstates.getCopy(); // it is a copy, but the internal shared_ptrs are the real deal
       auto mplexer = std::unique_ptr<FDMultiplexer>(FDMultiplexer::getMultiplexerSilent(states.size()));
       for (auto& dss : states) {
+#ifdef HAVE_XSK
+        if (dss->xskInfo) {
+          continue;
+        }
+#endif /* HAVE_XSK */
+
         if (dss->d_config.availability == DownstreamState::Availability::Auto || dss->d_config.availability == DownstreamState::Availability::Lazy) {
           if (dss->d_config.availability == DownstreamState::Availability::Auto) {
             dss->d_nextCheck = dss->d_config.checkInterval;
@@ -3270,3 +3616,72 @@ int main(int argc, char** argv)
 #endif
   }
 }
+
+#ifdef HAVE_XSK
+void XskRouter(std::shared_ptr<XskSocket> xsk)
+{
+  setThreadName("dnsdist/XskRouter");
+  uint32_t failed;
+  // packets to be submitted for sending
+  vector<XskPacketPtr> fillInTx;
+  const auto size = xsk->fds.size();
+  // list of workers that need to be notified
+  std::set<int> needNotify;
+  const auto& xskWakerIdx = xsk->workers.get<0>();
+  const auto& destIdx = xsk->workers.get<1>();
+  while (true) {
+    auto ready = xsk->wait(-1);
+    // descriptor 0 gets incoming AF_XDP packets
+    if (xsk->fds[0].revents & POLLIN) {
+      auto packets = xsk->recv(64, &failed);
+      dnsdist::metrics::g_stats.nonCompliantQueries += failed;
+      for (auto &packet : packets) {
+        const auto dest = packet->getToAddr();
+        auto res = destIdx.find(dest);
+        if (res == destIdx.end()) {
+          xsk->uniqueEmptyFrameOffset.push_back(xsk->frameOffset(*packet));
+          continue;
+        }
+        res->worker->cq.push(packet.release());
+        needNotify.insert(res->workerWaker);
+      }
+      for (auto i : needNotify) {
+        uint64_t x = 1;
+        auto written = write(i, &x, sizeof(x));
+        if (written != sizeof(x)) {
+          // oh, well, the worker is clearly overloaded
+          // but there is nothing we can do about it,
+          // and hopefully the queue will be processed eventually
+        }
+      }
+      needNotify.clear();
+      ready--;
+    }
+    const auto backup = ready;
+    for (size_t i = 1; i < size && ready > 0; i++) {
+      if (xsk->fds[i].revents & POLLIN) {
+        ready--;
+        auto& info = xskWakerIdx.find(xsk->fds[i].fd)->worker;
+        info->sq.consume_all([&](XskPacket* x) {
+          if (!(x->getFlags() & XskPacket::UPDATE)) {
+            xsk->uniqueEmptyFrameOffset.push_back(xsk->frameOffset(*x));
+            return;
+          }
+          auto ptr = std::unique_ptr<XskPacket>(x);
+          if (x->getFlags() & XskPacket::DELAY) {
+            xsk->waitForDelay.push(std::move(ptr));
+            return;
+          }
+          fillInTx.push_back(std::move(ptr));
+        });
+        info->cleanWorkerNotification();
+      }
+    }
+    xsk->pickUpReadyPacket(fillInTx);
+    xsk->recycle(64);
+    xsk->fillFq();
+    xsk->send(fillInTx);
+    ready = backup;
+  }
+}
+#endif /* HAVE_XSK */
index 9d5d06f59296c69f280fb0af28e641e5bcfd5857..34d4600dac1d0d91dcc66939ccdb8ad9147d4802 100644 (file)
@@ -56,6 +56,7 @@
 #include "uuid-utils.hh"
 #include "proxy-protocol.hh"
 #include "stat_t.hh"
+#include "xsk.hh"
 
 uint64_t uptimeOfProcess(const std::string& str);
 
@@ -511,6 +512,7 @@ struct ClientState
   std::shared_ptr<DOQFrontend> doqFrontend{nullptr};
   std::shared_ptr<DOH3Frontend> doh3Frontend{nullptr};
   std::shared_ptr<BPFFilter> d_filter{nullptr};
+  std::shared_ptr<XskWorker> xskInfo{nullptr};
   size_t d_maxInFlightQueriesPerConn{1};
   size_t d_tcpConcurrentConnectionsLimit{0};
   int udpFD{-1};
@@ -702,6 +704,8 @@ struct DownstreamState: public std::enable_shared_from_this<DownstreamState>
     std::string d_dohPath;
     std::string name;
     std::string nameWithAddr;
+    MACAddr sourceMACAddr;
+    MACAddr destMACAddr;
     size_t d_numberOfSockets{1};
     size_t d_maxInFlightQueriesPerConn{1};
     size_t d_tcpConcurrentConnectionsLimit{0};
@@ -815,6 +819,7 @@ public:
   std::vector<int> sockets;
   StopWatch sw;
   QPSLimiter qps;
+  std::shared_ptr<XskWorker> xskInfo{nullptr};
   std::atomic<uint64_t> idOffset{0};
   size_t socketsOffset{0};
   double latencyUsec{0.0};
@@ -837,7 +842,14 @@ private:
   uint8_t consecutiveSuccessfulChecks{0};
   bool d_stopped{false};
 public:
-
+  void updateStatisticsInfo()
+  {
+    auto delta = sw.udiffAndSet() / 1000000.0;
+    queryLoad.store(1.0 * (queries.load() - prev.queries.load()) / delta);
+    dropRate.store(1.0 * (reuseds.load() - prev.reuseds.load()) / delta);
+    prev.queries.store(queries.load());
+    prev.reuseds.store(reuseds.load());
+  }
   void start();
 
   bool isUp() const
@@ -966,6 +978,19 @@ public:
   void restoreState(uint16_t id, InternalQueryState&&);
   std::optional<InternalQueryState> getState(uint16_t id);
 
+#ifdef HAVE_XSK
+  void registerXsk(std::shared_ptr<XskSocket>& xsk)
+  {
+    xskInfo = XskWorker::create();
+    if (d_config.sourceAddr.sin4.sin_family == 0) {
+      throw runtime_error("invalid source addr");
+    }
+    xsk->addWorker(xskInfo, d_config.sourceAddr, getProtocol() != dnsdist::Protocol::DoUDP);
+    memcpy(d_config.sourceMACAddr, xsk->source, sizeof(MACAddr));
+    xskInfo->sharedEmptyFrameOffset = xsk->sharedEmptyFrameOffset;
+  }
+#endif /* HAVE_XSK */
+
   dnsdist::Protocol getProtocol() const
   {
     if (isDoH()) {
@@ -1138,7 +1163,7 @@ void setLuaSideEffect();   // set to report a side effect, cancelling all _no_ s
 bool getLuaNoSideEffect(); // set if there were only explicit declarations of _no_ side effect
 void resetLuaSideEffect(); // reset to indeterminate state
 
-bool responseContentMatches(const PacketBuffer& response, const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const std::shared_ptr<DownstreamState>& remote, unsigned int& qnameWireLength);
+bool responseContentMatches(const PacketBuffer& response, const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const std::shared_ptr<DownstreamState>& remote);
 
 bool checkQueryHeaders(const struct dnsheader* dh, ClientState& cs);
 
@@ -1163,7 +1188,7 @@ bool processResponse(PacketBuffer& response, const std::vector<DNSDistResponseRu
 bool processRulesResult(const DNSAction::Action& action, DNSQuestion& dq, std::string& ruleresult, bool& drop);
 bool processResponseAfterRules(PacketBuffer& response, const std::vector<DNSDistResponseRuleAction>& cacheInsertedRespRuleActions, DNSResponse& dr, bool muted);
 
-bool assignOutgoingUDPQueryToBackend(std::shared_ptr<DownstreamState>& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer& query);
+bool assignOutgoingUDPQueryToBackend(std::shared_ptr<DownstreamState>& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer& query, bool actuallySend = true);
 
 ssize_t udpClientSendRequestToBackend(const std::shared_ptr<DownstreamState>& ss, const int sd, const PacketBuffer& request, bool healthCheck = false);
 bool sendUDPResponse(int origFD, const PacketBuffer& response, const int delayMsec, const ComboAddress& origDest, const ComboAddress& origRemote);
index 9ed171375a9ee1eb0d3d16910afba6815bdc6e62..c95629daac9797c3727630ca1be79c29600620ed 100644 (file)
@@ -254,7 +254,8 @@ dnsdist_SOURCES = \
        tcpiohandler.cc tcpiohandler.hh \
        threadname.hh threadname.cc \
        uuid-utils.hh uuid-utils.cc \
-       xpf.cc xpf.hh
+       xpf.cc xpf.hh \
+       xsk.cc xsk.hh
 
 testrunner_SOURCES = \
        base64.hh \
@@ -361,7 +362,8 @@ testrunner_SOURCES = \
        testrunner.cc \
        threadname.hh threadname.cc \
        uuid-utils.hh uuid-utils.cc \
-       xpf.cc xpf.hh
+       xpf.cc xpf.hh \
+       xsk.cc xsk.hh
 
 dnsdist_LDFLAGS = \
        $(AM_LDFLAGS) \
@@ -411,6 +413,13 @@ if HAVE_LIBSSL
 dnsdist_LDADD += $(LIBSSL_LIBS)
 endif
 
+if HAVE_XSK
+dnsdist_LDADD += -lbpf
+dnsdist_LDADD += -lxdp
+testrunner_LDADD += -lbpf
+testrunner_LDADD += -lxdp
+endif
+
 if HAVE_LIBCRYPTO
 dnsdist_LDADD += $(LIBCRYPTO_LDFLAGS) $(LIBCRYPTO_LIBS)
 testrunner_LDADD += $(LIBCRYPTO_LDFLAGS) $(LIBCRYPTO_LIBS)
index d9429c93f61ff2f393650f3dc7f2a4e6a9b2d7a4..d9f6c719ddc392b33d5567a051cab371d8d3c7e7 100644 (file)
@@ -40,6 +40,7 @@ PDNS_ENABLE_FUZZ_TARGETS
 PDNS_WITH_RE2
 DNSDIST_ENABLE_DNSCRYPT
 PDNS_WITH_EBPF
+PDNS_WITH_XSK
 PDNS_WITH_NET_SNMP
 PDNS_WITH_LIBCAP
 
index 8c3eefc2399e1a25f2aa582ec5957ea9d82e6e3d..bd7592545a49324d2e681f89bddbcd042ea3cc35 100644 (file)
@@ -714,9 +714,11 @@ void DownstreamState::submitHealthCheckResult(bool initial, bool newResult)
     setUpStatus(newResult);
     if (newResult == false) {
       currentCheckFailures++;
-      auto stats = d_lazyHealthCheckStats.lock();
-      stats->d_status = LazyHealthCheckStats::LazyStatus::Failed;
-      updateNextLazyHealthCheck(*stats, false);
+      if (d_config.availability == DownstreamState::Availability::Lazy) {
+        auto stats = d_lazyHealthCheckStats.lock();
+        stats->d_status = LazyHealthCheckStats::LazyStatus::Failed;
+        updateNextLazyHealthCheck(*stats, false);
+      }
     }
     return;
   }
index 36805573e7842624ecefe7ecc223e7270f92250a..d60c9dc5841020c42c3f954ec43b8d722d3c77c3 100644 (file)
 
 bool g_verboseHealthChecks{false};
 
-struct HealthCheckData
-{
-  enum class TCPState : uint8_t
-  {
-    WritingQuery,
-    ReadingResponseSize,
-    ReadingResponse
-  };
-
-  HealthCheckData(FDMultiplexer& mplexer, std::shared_ptr<DownstreamState> downstream, DNSName&& checkName, uint16_t checkType, uint16_t checkClass, uint16_t queryID) :
-    d_ds(std::move(downstream)), d_mplexer(mplexer), d_udpSocket(-1), d_checkName(std::move(checkName)), d_checkType(checkType), d_checkClass(checkClass), d_queryID(queryID)
-  {
-  }
-
-  const std::shared_ptr<DownstreamState> d_ds;
-  FDMultiplexer& d_mplexer;
-  std::unique_ptr<TCPIOHandler> d_tcpHandler{nullptr};
-  std::unique_ptr<IOStateHandler> d_ioState{nullptr};
-  PacketBuffer d_buffer;
-  Socket d_udpSocket;
-  DNSName d_checkName;
-  struct timeval d_ttd
-  {
-    0, 0
-  };
-  size_t d_bufferPos{0};
-  uint16_t d_checkType;
-  uint16_t d_checkClass;
-  uint16_t d_queryID;
-  TCPState d_tcpState{TCPState::WritingQuery};
-  bool d_initial{false};
-};
-
-static bool handleResponse(std::shared_ptr<HealthCheckData>& data)
+bool handleResponse(std::shared_ptr<HealthCheckData>& data)
 {
   const auto& downstream = data->d_ds;
   try {
@@ -207,7 +174,7 @@ static void healthCheckUDPCallback(int descriptor, FDMultiplexer::funcparam_t& p
       }
       ++data->d_ds->d_healthCheckMetrics.d_networkErrors;
       data->d_ds->submitHealthCheckResult(data->d_initial, false);
-      data->d_mplexer.removeReadFD(descriptor);
+      data->d_mplexer->removeReadFD(descriptor);
       return;
     }
   } while (got < 0);
@@ -224,7 +191,7 @@ static void healthCheckUDPCallback(int descriptor, FDMultiplexer::funcparam_t& p
     return;
   }
 
-  data->d_mplexer.removeReadFD(descriptor);
+  data->d_mplexer->removeReadFD(descriptor);
   data->d_ds->submitHealthCheckResult(data->d_initial, handleResponse(data));
 }
 
@@ -300,36 +267,51 @@ static void healthCheckTCPCallback(int descriptor, FDMultiplexer::funcparam_t& p
   }
 }
 
-bool queueHealthCheck(std::unique_ptr<FDMultiplexer>& mplexer, const std::shared_ptr<DownstreamState>& downstream, bool initialCheck)
+PacketBuffer getHealthCheckPacket(const std::shared_ptr<DownstreamState>& downstream, FDMultiplexer* mplexer, std::shared_ptr<HealthCheckData>& data)
 {
-  try {
-    uint16_t queryID = dnsdist::getRandomDNSID();
-    DNSName checkName = downstream->d_config.checkName;
-    uint16_t checkType = downstream->d_config.checkType.getCode();
-    uint16_t checkClass = downstream->d_config.checkClass;
-    dnsheader checkHeader{};
-    memset(&checkHeader, 0, sizeof(checkHeader));
-
-    checkHeader.qdcount = htons(1);
-    checkHeader.id = queryID;
-
-    checkHeader.rd = true;
-    if (downstream->d_config.setCD) {
-      checkHeader.cd = true;
-    }
+  uint16_t queryID = dnsdist::getRandomDNSID();
+  DNSName checkName = downstream->d_config.checkName;
+  uint16_t checkType = downstream->d_config.checkType.getCode();
+  uint16_t checkClass = downstream->d_config.checkClass;
+  dnsheader checkHeader{};
+  memset(&checkHeader, 0, sizeof(checkHeader));
+
+  checkHeader.qdcount = htons(1);
+  checkHeader.id = queryID;
+
+  checkHeader.rd = true;
+  if (downstream->d_config.setCD) {
+    checkHeader.cd = true;
+  }
 
-    if (downstream->d_config.checkFunction) {
-      auto lock = g_lua.lock();
-      auto ret = downstream->d_config.checkFunction(checkName, checkType, checkClass, &checkHeader);
-      checkName = std::get<0>(ret);
-      checkType = std::get<1>(ret);
-      checkClass = std::get<2>(ret);
-    }
+  if (downstream->d_config.checkFunction) {
+    auto lock = g_lua.lock();
+    auto ret = downstream->d_config.checkFunction(checkName, checkType, checkClass, &checkHeader);
+    checkName = std::get<0>(ret);
+    checkType = std::get<1>(ret);
+    checkClass = std::get<2>(ret);
+  }
+  PacketBuffer packet;
+  GenericDNSPacketWriter<PacketBuffer> dpw(packet, checkName, checkType, checkClass);
+  dnsheader* requestHeader = dpw.getHeader();
+  *requestHeader = checkHeader;
+  data = std::make_shared<HealthCheckData>(mplexer, downstream, std::move(checkName), checkType, checkClass, queryID);
+  return packet;
+}
 
-    PacketBuffer packet;
-    GenericDNSPacketWriter<PacketBuffer> dpw(packet, checkName, checkType, checkClass);
-    dnsheader* requestHeader = dpw.getHeader();
-    *requestHeader = checkHeader;
+void setHealthCheckTime(const std::shared_ptr<DownstreamState>& downstream, const std::shared_ptr<HealthCheckData>& data)
+{
+  gettimeofday(&data->d_ttd, nullptr);
+  data->d_ttd.tv_sec += static_cast<decltype(data->d_ttd.tv_sec)>(downstream->d_config.checkTimeout / 1000); /* ms to seconds */
+  data->d_ttd.tv_usec += static_cast<decltype(data->d_ttd.tv_usec)>((downstream->d_config.checkTimeout % 1000) * 1000); /* remaining ms to us */
+  normalizeTV(data->d_ttd);
+}
+
+bool queueHealthCheck(std::unique_ptr<FDMultiplexer>& mplexer, const std::shared_ptr<DownstreamState>& downstream, bool initialCheck)
+{
+  try {
+    std::shared_ptr<HealthCheckData> data;
+    PacketBuffer packet = getHealthCheckPacket(downstream, mplexer.get(), data);
 
     /* we need to compute that _before_ adding the proxy protocol payload */
     uint16_t packetSize = packet.size();
@@ -368,13 +350,9 @@ bool queueHealthCheck(std::unique_ptr<FDMultiplexer>& mplexer, const std::shared
       sock.bind(downstream->d_config.sourceAddr, false);
     }
 
-    auto data = std::make_shared<HealthCheckData>(*mplexer, downstream, std::move(checkName), checkType, checkClass, queryID);
     data->d_initial = initialCheck;
 
-    gettimeofday(&data->d_ttd, nullptr);
-    data->d_ttd.tv_sec += static_cast<decltype(data->d_ttd.tv_sec)>(downstream->d_config.checkTimeout / 1000); /* ms to seconds */
-    data->d_ttd.tv_usec += static_cast<decltype(data->d_ttd.tv_usec)>((downstream->d_config.checkTimeout % 1000) * 1000); /* remaining ms to us */
-    normalizeTV(data->d_ttd);
+    setHealthCheckTime(downstream, data);
 
     if (!downstream->doHealthcheckOverTCP()) {
       sock.connect(downstream->d_config.remote);
@@ -383,7 +361,7 @@ bool queueHealthCheck(std::unique_ptr<FDMultiplexer>& mplexer, const std::shared
       if (sent < 0) {
         int ret = errno;
         if (g_verboseHealthChecks) {
-          infolog("Error while sending a health check query (ID %d) to backend %s: %d", queryID, downstream->getNameWithAddr(), ret);
+          infolog("Error while sending a health check query (ID %d) to backend %s: %d", data->d_queryID, downstream->getNameWithAddr(), ret);
         }
         return false;
       }
index e9da6c66de8b0eb56baadb0c0fc57e78461822d0..4f1940643e33da74c5a209e775fbc5952bf9a9b7 100644 (file)
 #include "dnsdist.hh"
 #include "mplexer.hh"
 #include "sstuff.hh"
+#include "tcpiohandler-mplexer.hh"
 
 extern bool g_verboseHealthChecks;
 
 bool queueHealthCheck(std::unique_ptr<FDMultiplexer>& mplexer, const std::shared_ptr<DownstreamState>& downstream, bool initial = false);
 void handleQueuedHealthChecks(FDMultiplexer& mplexer, bool initial = false);
+
+struct HealthCheckData
+{
+  enum class TCPState : uint8_t
+  {
+    WritingQuery,
+    ReadingResponseSize,
+    ReadingResponse
+  };
+
+  HealthCheckData(FDMultiplexer* mplexer, std::shared_ptr<DownstreamState> downstream, DNSName&& checkName, uint16_t checkType, uint16_t checkClass, uint16_t queryID) :
+    d_ds(std::move(downstream)), d_mplexer(mplexer), d_udpSocket(-1), d_checkName(std::move(checkName)), d_checkType(checkType), d_checkClass(checkClass), d_queryID(queryID)
+  {
+  }
+
+  const std::shared_ptr<DownstreamState> d_ds;
+  FDMultiplexer* d_mplexer{nullptr};
+  std::unique_ptr<TCPIOHandler> d_tcpHandler{nullptr};
+  std::unique_ptr<IOStateHandler> d_ioState{nullptr};
+  PacketBuffer d_buffer;
+  Socket d_udpSocket;
+  DNSName d_checkName;
+  struct timeval d_ttd
+  {
+    0, 0
+  };
+  size_t d_bufferPos{0};
+  uint16_t d_checkType;
+  uint16_t d_checkClass;
+  uint16_t d_queryID;
+  TCPState d_tcpState{TCPState::WritingQuery};
+  bool d_initial{false};
+};
+
+PacketBuffer getHealthCheckPacket(const std::shared_ptr<DownstreamState>& ds, FDMultiplexer* mplexer, std::shared_ptr<HealthCheckData>& data);
+void setHealthCheckTime(const std::shared_ptr<DownstreamState>& ds, const std::shared_ptr<HealthCheckData>& data);
+bool handleResponse(std::shared_ptr<HealthCheckData>& data);
+
diff --git a/pdns/dnsdistdist/m4/pdns_with_xsk.m4 b/pdns/dnsdistdist/m4/pdns_with_xsk.m4
new file mode 100644 (file)
index 0000000..b45c9f3
--- /dev/null
@@ -0,0 +1,22 @@
+AC_DEFUN([PDNS_WITH_XSK],[
+  AC_MSG_CHECKING([if we have xsk support])
+  AC_ARG_WITH([xsk],
+    AS_HELP_STRING([--with-xsk],[enable xsk support @<:@default=auto@:>@]),
+    [with_xsk=$withval],
+    [with_xsk=auto],
+  )
+  AC_MSG_RESULT([$with_xsk])
+
+  AS_IF([test "x$with_xsk" != "xno"], [
+    AS_IF([test "x$with_xsk" = "xyes" -o "x$with_xsk" = "xauto"], [
+      AC_CHECK_HEADERS([xdp/xsk.h], xsk_headers=yes, xsk_headers=no)
+    ])
+  ])
+  AS_IF([test "x$with_xsk" = "xyes"], [
+    AS_IF([test x"$xsk_headers" = "no"], [
+      AC_MSG_ERROR([XSK support requested but required libxdp were not found])
+    ])
+  ])
+  AS_IF([test x"$xsk_headers" = "xyes" ], [ AC_DEFINE([HAVE_XSK], [1], [Define if using eBPF.]) ])
+  AM_CONDITIONAL([HAVE_XSK], [test x"$xsk_headers" = "xyes" ])
+])
index a4c887aeef4db5718032a5fc709f76fc40efbea2..fd37dda3196d035b350fbeee8dbd904b76235c72 100644 (file)
@@ -76,7 +76,7 @@ ProcessQueryResult processQuery(DNSQuestion& dq, LocalHolders& holders, std::sha
   return ProcessQueryResult::Drop;
 }
 
-bool responseContentMatches(const PacketBuffer& response, const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const std::shared_ptr<DownstreamState>& remote, unsigned int& qnameWireLength)
+bool responseContentMatches(const PacketBuffer& response, const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const std::shared_ptr<DownstreamState>& remote)
 {
   return true;
 }
diff --git a/pdns/dnsdistdist/xsk.cc b/pdns/dnsdistdist/xsk.cc
new file mode 120000 (symlink)
index 0000000..3258a93
--- /dev/null
@@ -0,0 +1 @@
+../xsk.cc
\ No newline at end of file
diff --git a/pdns/dnsdistdist/xsk.hh b/pdns/dnsdistdist/xsk.hh
new file mode 120000 (symlink)
index 0000000..4b1bba7
--- /dev/null
@@ -0,0 +1 @@
+../xsk.hh
\ No newline at end of file
index 4ef0b8f764a3924b74e8d2f2c2da94c3c8f3bd3a..7d8b2e4c2d6ab27b152cef74604f45da93da7b12 100644 (file)
@@ -83,6 +83,7 @@
 #undef IP_PKTINFO
 #endif
 
+using MACAddr = uint8_t[6];
 union ComboAddress {
   struct sockaddr_in sin4;
   struct sockaddr_in6 sin6;
@@ -123,6 +124,24 @@ union ComboAddress {
     return rhs.operator<(*this);
   }
 
+  struct addressPortOnlyHash
+  {
+    uint32_t operator()(const ComboAddress& ca) const
+    {
+      const unsigned char* start = nullptr;
+      if (ca.sin4.sin_family == AF_INET) {
+        start = reinterpret_cast<const unsigned char*>(&ca.sin4.sin_addr.s_addr);
+        auto tmp = burtle(start, 4, 0);
+        return burtle(reinterpret_cast<const uint8_t*>(&ca.sin4.sin_port), 2, tmp);
+      }
+      {
+        start = reinterpret_cast<const unsigned char*>(&ca.sin6.sin6_addr.s6_addr);
+        auto tmp = burtle(start, 16, 0);
+        return burtle(reinterpret_cast<const unsigned char*>(&ca.sin6.sin6_port), 2, tmp);
+      }
+    }
+  };
+
   struct addressOnlyHash
   {
     uint32_t operator()(const ComboAddress& ca) const
@@ -347,11 +366,14 @@ union ComboAddress {
 
   void truncate(unsigned int bits) noexcept;
 
-  uint16_t getPort() const
+  uint16_t getNetworkOrderPort() const noexcept
   {
-    return ntohs(sin4.sin_port);
+    return sin4.sin_port;
+  }
+  uint16_t getPort() const noexcept
+  {
+    return ntohs(getNetworkOrderPort());
   }
-
   void setPort(uint16_t port)
   {
     sin4.sin_port = htons(port);
index 56f31d2348b1ba07936ccbaee1e153a452e216ca..c29fdabba5e9c61618e8ee0fcea6b2f330587f4e 100644 (file)
@@ -57,7 +57,7 @@ bool sendUDPResponse(int origFD, const PacketBuffer& response, const int delayMs
   return false;
 }
 
-bool assignOutgoingUDPQueryToBackend(std::shared_ptr<DownstreamState>& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer& query)
+bool assignOutgoingUDPQueryToBackend(std::shared_ptr<DownstreamState>& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer& query, bool)
 {
   return true;
 }
@@ -74,6 +74,8 @@ bool DNSDistSNMPAgent::sendBackendStatusChangeTrap(DownstreamState const&)
   return false;
 }
 
+std::vector<std::shared_ptr<XskSocket>> g_xsk;
+
 BOOST_AUTO_TEST_SUITE(test_dnsdist_cc)
 
 static const uint16_t ECSSourcePrefixV4 = 24;
diff --git a/pdns/xsk.cc b/pdns/xsk.cc
new file mode 100644 (file)
index 0000000..b765931
--- /dev/null
@@ -0,0 +1,828 @@
+/*
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+#include "gettime.hh"
+#include "xsk.hh"
+
+#include <algorithm>
+#include <cassert>
+#include <cstdint>
+#include <cstring>
+#include <fcntl.h>
+#include <iterator>
+#include <linux/bpf.h>
+#include <linux/if_ether.h>
+#include <linux/if_link.h>
+#include <linux/if_xdp.h>
+#include <linux/ip.h>
+#include <linux/ipv6.h>
+#include <linux/tcp.h>
+#include <linux/udp.h>
+#include <net/if.h>
+#include <netinet/in.h>
+#include <poll.h>
+#include <stdexcept>
+#include <sys/eventfd.h>
+#include <sys/ioctl.h>
+#include <sys/mman.h>
+#include <sys/socket.h>
+#include <sys/timerfd.h>
+#include <unistd.h>
+#include <vector>
+
+#ifdef HAVE_XSK
+#include <bpf/bpf.h>
+#include <bpf/libbpf.h>
+extern "C"
+{
+#include <xdp/libxdp.h>
+}
+
+constexpr bool XskSocket::isPowOfTwo(uint32_t value) noexcept
+{
+  return value != 0 && (value & (value - 1)) == 0;
+}
+int XskSocket::firstTimeout()
+{
+  if (waitForDelay.empty()) {
+    return -1;
+  }
+  timespec now;
+  gettime(&now);
+  const auto& firstTime = waitForDelay.top()->sendTime;
+  const auto res = timeDifference(now, firstTime);
+  if (res <= 0) {
+    return 0;
+  }
+  return res;
+}
+XskSocket::XskSocket(size_t frameNum_, const std::string& ifName_, uint32_t queue_id, const std::string& xskMapPath, const std::string& poolName_) :
+  frameNum(frameNum_), queueId(queue_id), ifName(ifName_), poolName(poolName_), socket(nullptr, xsk_socket__delete), sharedEmptyFrameOffset(std::make_shared<LockGuarded<vector<uint64_t>>>())
+{
+  if (!isPowOfTwo(frameNum_) || !isPowOfTwo(frameSize)
+      || !isPowOfTwo(fqCapacity) || !isPowOfTwo(cqCapacity) || !isPowOfTwo(rxCapacity) || !isPowOfTwo(txCapacity)) {
+    throw std::runtime_error("The number of frame , the size of frame and the capacity of rings must is a pow of 2");
+  }
+  getMACFromIfName();
+
+  memset(&cq, 0, sizeof(cq));
+  memset(&fq, 0, sizeof(fq));
+  memset(&tx, 0, sizeof(tx));
+  memset(&rx, 0, sizeof(rx));
+  xsk_umem_config umemCfg;
+  umemCfg.fill_size = fqCapacity;
+  umemCfg.comp_size = cqCapacity;
+  umemCfg.frame_size = frameSize;
+  umemCfg.frame_headroom = XSK_UMEM__DEFAULT_FRAME_HEADROOM;
+  umemCfg.flags = 0;
+  umem.umemInit(frameNum_ * frameSize, &cq, &fq, &umemCfg);
+  {
+    xsk_socket_config socketCfg;
+    socketCfg.rx_size = rxCapacity;
+    socketCfg.tx_size = txCapacity;
+    socketCfg.bind_flags = XDP_USE_NEED_WAKEUP;
+    socketCfg.xdp_flags = XDP_FLAGS_SKB_MODE;
+    socketCfg.libxdp_flags = XSK_LIBBPF_FLAGS__INHIBIT_PROG_LOAD;
+    xsk_socket* tmp = nullptr;
+    auto ret = xsk_socket__create(&tmp, ifName.c_str(), queue_id, umem.umem, &rx, &tx, &socketCfg);
+    if (ret != 0) {
+      throw std::runtime_error("Error creating a xsk socket of if_name" + ifName + stringerror(ret));
+    }
+    socket = std::unique_ptr<xsk_socket, void (*)(xsk_socket*)>(tmp, xsk_socket__delete);
+  }
+  for (uint64_t i = 0; i < frameNum; i++) {
+    uniqueEmptyFrameOffset.push_back(i * frameSize + XDP_PACKET_HEADROOM);
+  }
+  fillFq(fqCapacity);
+  const auto xskfd = xskFd();
+  fds.push_back(pollfd{
+    .fd = xskfd,
+    .events = POLLIN,
+    .revents = 0});
+  const auto xskMapFd = FDWrapper(bpf_obj_get(xskMapPath.c_str()));
+  if (xskMapFd.getHandle() < 0) {
+    throw std::runtime_error("Error get BPF map from path");
+  }
+  auto ret = bpf_map_update_elem(xskMapFd.getHandle(), &queue_id, &xskfd, 0);
+  if (ret) {
+    throw std::runtime_error("Error insert into xsk_map");
+  }
+}
+void XskSocket::fillFq(uint32_t fillSize) noexcept
+{
+  {
+    auto frames = sharedEmptyFrameOffset->lock();
+    if (frames->size() < holdThreshold) {
+      const auto moveSize = std::min(holdThreshold - frames->size(), uniqueEmptyFrameOffset.size());
+      if (moveSize > 0) {
+        frames->insert(frames->end(), std::make_move_iterator(uniqueEmptyFrameOffset.end() - moveSize), std::make_move_iterator(uniqueEmptyFrameOffset.end()));
+      }
+    }
+  }
+  if (uniqueEmptyFrameOffset.size() < fillSize) {
+    return;
+  }
+  uint32_t idx;
+  if (xsk_ring_prod__reserve(&fq, fillSize, &idx) != fillSize) {
+    return;
+  }
+  for (uint32_t i = 0; i < fillSize; i++) {
+    *xsk_ring_prod__fill_addr(&fq, idx++) = uniqueEmptyFrameOffset.back();
+    uniqueEmptyFrameOffset.pop_back();
+  }
+  xsk_ring_prod__submit(&fq, idx);
+}
+int XskSocket::wait(int timeout)
+{
+  return poll(fds.data(), fds.size(), static_cast<int>(std::min(static_cast<uint32_t>(timeout), static_cast<uint32_t>(firstTimeout()))));
+}
+[[nodiscard]] uint64_t XskSocket::frameOffset(const XskPacket& packet) const noexcept
+{
+  return reinterpret_cast<uint64_t>(packet.frame) - reinterpret_cast<uint64_t>(umem.bufBase);
+}
+
+int XskSocket::xskFd() const noexcept { return xsk_socket__fd(socket.get()); }
+
+void XskSocket::send(std::vector<XskPacketPtr>& packets)
+{
+  const auto packetSize = packets.size();
+  if (packetSize == 0) {
+    return;
+  }
+  uint32_t idx;
+  if (xsk_ring_prod__reserve(&tx, packetSize, &idx) != packets.size()) {
+    return;
+  }
+
+  for (const auto& i : packets) {
+    *xsk_ring_prod__tx_desc(&tx, idx++) = {
+      .addr = frameOffset(*i),
+      .len = i->FrameLen(),
+      .options = 0};
+  }
+  xsk_ring_prod__submit(&tx, packetSize);
+  packets.clear();
+}
+std::vector<XskPacketPtr> XskSocket::recv(uint32_t recvSizeMax, uint32_t* failedCount)
+{
+  uint32_t idx;
+  std::vector<XskPacketPtr> res;
+  const auto recvSize = xsk_ring_cons__peek(&rx, recvSizeMax, &idx);
+  if (recvSize <= 0) {
+    return res;
+  }
+  const auto baseAddr = reinterpret_cast<uint64_t>(umem.bufBase);
+  uint32_t count = 0;
+  for (uint32_t i = 0; i < recvSize; i++) {
+    const auto* desc = xsk_ring_cons__rx_desc(&rx, idx++);
+    auto ptr = std::make_unique<XskPacket>(reinterpret_cast<void*>(desc->addr + baseAddr), desc->len, frameSize);
+    if (!ptr->parse()) {
+      ++count;
+      uniqueEmptyFrameOffset.push_back(frameOffset(*ptr));
+    }
+    else {
+      res.push_back(std::move(ptr));
+    }
+  }
+  xsk_ring_cons__release(&rx, recvSize);
+  if (failedCount) {
+    *failedCount = count;
+  }
+  return res;
+}
+void XskSocket::pickUpReadyPacket(std::vector<XskPacketPtr>& packets)
+{
+  timespec now;
+  gettime(&now);
+  while (!waitForDelay.empty() && timeDifference(now, waitForDelay.top()->sendTime) <= 0) {
+    auto& top = const_cast<XskPacketPtr&>(waitForDelay.top());
+    packets.push_back(std::move(top));
+    waitForDelay.pop();
+  }
+}
+void XskSocket::recycle(size_t size) noexcept
+{
+  uint32_t idx;
+  const auto completeSize = xsk_ring_cons__peek(&cq, size, &idx);
+  if (completeSize <= 0) {
+    return;
+  }
+  for (uint32_t i = 0; i < completeSize; ++i) {
+    uniqueEmptyFrameOffset.push_back(*xsk_ring_cons__comp_addr(&cq, idx++));
+  }
+  xsk_ring_cons__release(&cq, completeSize);
+}
+
+void XskSocket::XskUmem::umemInit(size_t memSize, xsk_ring_cons* cq, xsk_ring_prod* fq, xsk_umem_config* config)
+{
+  size = memSize;
+  bufBase = static_cast<uint8_t*>(mmap(nullptr, size, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0));
+  if (bufBase == MAP_FAILED) {
+    throw std::runtime_error("mmap failed");
+  }
+  auto ret = xsk_umem__create(&umem, bufBase, size, fq, cq, config);
+  if (ret != 0) {
+    munmap(bufBase, size);
+    throw std::runtime_error("Error creating a umem of size" + std::to_string(size) + stringerror(ret));
+  }
+}
+
+XskSocket::XskUmem::~XskUmem()
+{
+  if (umem) {
+    xsk_umem__delete(umem);
+  }
+  if (bufBase) {
+    munmap(bufBase, size);
+  }
+}
+
+bool XskPacket::parse()
+{
+  // payloadEnd must bigger than payload + sizeof(ethhdr) + sizoef(iphdr) + sizeof(udphdr)
+  auto* eth = reinterpret_cast<ethhdr*>(frame);
+  uint8_t l4Protocol;
+  if (eth->h_proto == htons(ETH_P_IP)) {
+    auto* ip = reinterpret_cast<iphdr*>(eth + 1);
+    if (ip->ihl != static_cast<uint8_t>(sizeof(iphdr) >> 2)) {
+      // ip->ihl*4 != sizeof(iphdr)
+      // ip options is not supported now!
+      return false;
+    }
+    // check ip.check == ipv4Checksum() is not needed!
+    // We check it in BPF program
+    from = makeComboAddressFromRaw(4, reinterpret_cast<const char*>(&ip->saddr), sizeof(ip->saddr));
+    to = makeComboAddressFromRaw(4, reinterpret_cast<const char*>(&ip->daddr), sizeof(ip->daddr));
+    l4Protocol = ip->protocol;
+    l4Header = reinterpret_cast<uint8_t*>(ip + 1);
+    payloadEnd = std::min(reinterpret_cast<uint8_t*>(ip) + ntohs(ip->tot_len), payloadEnd);
+  }
+  else if (eth->h_proto == htons(ETH_P_IPV6)) {
+    auto* ipv6 = reinterpret_cast<ipv6hdr*>(eth + 1);
+    l4Header = reinterpret_cast<uint8_t*>(ipv6 + 1);
+    if (l4Header >= payloadEnd) {
+      return false;
+    }
+    from = makeComboAddressFromRaw(6, reinterpret_cast<const char*>(&ipv6->saddr), sizeof(ipv6->saddr));
+    to = makeComboAddressFromRaw(6, reinterpret_cast<const char*>(&ipv6->daddr), sizeof(ipv6->daddr));
+    l4Protocol = ipv6->nexthdr;
+    payloadEnd = std::min(l4Header + ntohs(ipv6->payload_len), payloadEnd);
+  }
+  else {
+    return false;
+  }
+  if (l4Protocol == IPPROTO_UDP) {
+    // check udp.check == ipv4Checksum() is not needed!
+    // We check it in BPF program
+    auto* udp = reinterpret_cast<udphdr*>(l4Header);
+    payload = l4Header + sizeof(udphdr);
+    // Because of XskPacket::setHeader
+    // payload = payloadEnd should be allow
+    if (payload > payloadEnd) {
+      return false;
+    }
+    payloadEnd = std::min(l4Header + ntohs(udp->len), payloadEnd);
+    from.setPort(ntohs(udp->source));
+    to.setPort(ntohs(udp->dest));
+    return true;
+  }
+  if (l4Protocol == IPPROTO_TCP) {
+    // check tcp.check == ipv4Checksum() is not needed!
+    // We check it in BPF program
+    auto* tcp = reinterpret_cast<tcphdr*>(l4Header);
+    if (tcp->doff != static_cast<uint32_t>(sizeof(tcphdr) >> 2)) {
+      // tcp is not supported now!
+      return false;
+    }
+    payload = l4Header + sizeof(tcphdr);
+    //
+    if (payload > payloadEnd) {
+      return false;
+    }
+    from.setPort(ntohs(tcp->source));
+    to.setPort(ntohs(tcp->dest));
+    flags |= TCP;
+    return true;
+  }
+  // ipv6 extension header is not supported now!
+  return false;
+}
+
+uint32_t XskPacket::dataLen() const noexcept
+{
+  return payloadEnd - payload;
+}
+uint32_t XskPacket::FrameLen() const noexcept
+{
+  return payloadEnd - frame;
+}
+size_t XskPacket::capacity() const noexcept
+{
+  return frameEnd - payloadEnd;
+}
+
+void XskPacket::changeDirectAndUpdateChecksum() noexcept
+{
+  auto* eth = reinterpret_cast<ethhdr*>(frame);
+  {
+    uint8_t tmp[ETH_ALEN];
+    static_assert(sizeof(tmp) == sizeof(eth->h_dest), "Size Error");
+    static_assert(sizeof(tmp) == sizeof(eth->h_source), "Size Error");
+    memcpy(tmp, eth->h_dest, sizeof(tmp));
+    memcpy(eth->h_dest, eth->h_source, sizeof(tmp));
+    memcpy(eth->h_source, tmp, sizeof(tmp));
+  }
+  if (eth->h_proto == htons(ETH_P_IPV6)) {
+    // IPV6
+    auto* ipv6 = reinterpret_cast<ipv6hdr*>(eth + 1);
+    std::swap(ipv6->daddr, ipv6->saddr);
+    if (ipv6->nexthdr == IPPROTO_UDP) {
+      // UDP
+      auto* udp = reinterpret_cast<udphdr*>(ipv6 + 1);
+      std::swap(udp->dest, udp->source);
+      udp->len = htons(payloadEnd - reinterpret_cast<uint8_t*>(udp));
+      udp->check = 0;
+      udp->check = tcp_udp_v6_checksum();
+    }
+    else {
+      // TCP
+      auto* tcp = reinterpret_cast<tcphdr*>(ipv6 + 1);
+      std::swap(tcp->dest, tcp->source);
+      // TODO
+    }
+    rewriteIpv6Header(ipv6);
+  }
+  else {
+    // IPV4
+    auto* ipv4 = reinterpret_cast<iphdr*>(eth + 1);
+    std::swap(ipv4->daddr, ipv4->saddr);
+    if (ipv4->protocol == IPPROTO_UDP) {
+      // UDP
+      auto* udp = reinterpret_cast<udphdr*>(ipv4 + 1);
+      std::swap(udp->dest, udp->source);
+      udp->len = htons(payloadEnd - reinterpret_cast<uint8_t*>(udp));
+      udp->check = 0;
+      udp->check = tcp_udp_v4_checksum();
+    }
+    else {
+      // TCP
+      auto* tcp = reinterpret_cast<tcphdr*>(ipv4 + 1);
+      std::swap(tcp->dest, tcp->source);
+      // TODO
+    }
+    rewriteIpv4Header(ipv4);
+  }
+}
+void XskPacket::rewriteIpv4Header(void* ipv4header) noexcept
+{
+  auto* ipv4 = static_cast<iphdr*>(ipv4header);
+  ipv4->version = 4;
+  ipv4->ihl = sizeof(iphdr) / 4;
+  ipv4->tos = 0;
+  ipv4->tot_len = htons(payloadEnd - reinterpret_cast<uint8_t*>(ipv4));
+  ipv4->id = 0;
+  ipv4->frag_off = 0;
+  ipv4->ttl = DefaultTTL;
+  ipv4->check = 0;
+  ipv4->check = ipv4Checksum();
+}
+void XskPacket::rewriteIpv6Header(void* ipv6header) noexcept
+{
+  auto* ipv6 = static_cast<ipv6hdr*>(ipv6header);
+  ipv6->version = 6;
+  ipv6->priority = 0;
+  ipv6->payload_len = htons(payloadEnd - reinterpret_cast<uint8_t*>(ipv6 + 1));
+  ipv6->hop_limit = DefaultTTL;
+  memset(&ipv6->flow_lbl, 0, sizeof(ipv6->flow_lbl));
+}
+
+bool XskPacket::isIPV6() const noexcept
+{
+  const auto* eth = reinterpret_cast<ethhdr*>(frame);
+  return eth->h_proto == htons(ETH_P_IPV6);
+}
+XskPacket::XskPacket(void* frame_, size_t dataSize, size_t frameSize) :
+  frame(static_cast<uint8_t*>(frame_)), payloadEnd(static_cast<uint8_t*>(frame) + dataSize), frameEnd(static_cast<uint8_t*>(frame) + frameSize - XDP_PACKET_HEADROOM)
+{
+}
+PacketBuffer XskPacket::clonePacketBuffer() const
+{
+  const auto size = dataLen();
+  PacketBuffer tmp(size);
+  memcpy(tmp.data(), payload, size);
+  return tmp;
+}
+void XskPacket::cloneIntoPacketBuffer(PacketBuffer& buffer) const
+{
+  const auto size = dataLen();
+  buffer.resize(size);
+  memcpy(buffer.data(), payload, size);
+}
+bool XskPacket::setPayload(const PacketBuffer& buf)
+{
+  const auto bufSize = buf.size();
+  if (bufSize == 0 || bufSize > capacity()) {
+    return false;
+  }
+  flags |= UPDATE;
+  memcpy(payload, buf.data(), bufSize);
+  payloadEnd = payload + bufSize;
+  return true;
+}
+void XskPacket::addDelay(const int relativeMilliseconds) noexcept
+{
+  gettime(&sendTime);
+  sendTime.tv_nsec += static_cast<uint64_t>(relativeMilliseconds) * 1000000L;
+  sendTime.tv_sec += sendTime.tv_nsec / 1000000000L;
+  sendTime.tv_nsec %= 1000000000L;
+}
+bool operator<(const XskPacketPtr& s1, const XskPacketPtr& s2) noexcept
+{
+  return s1->sendTime < s2->sendTime;
+}
+const ComboAddress& XskPacket::getFromAddr() const noexcept
+{
+  return from;
+}
+const ComboAddress& XskPacket::getToAddr() const noexcept
+{
+  return to;
+}
+void XskWorker::notify(int fd)
+{
+  uint64_t value = 1;
+  ssize_t res = 0;
+  while ((res = write(fd, &value, sizeof(value))) == EINTR) {
+  }
+  if (res != sizeof(value)) {
+    throw runtime_error("Unable Wake Up XskSocket Failed");
+  }
+}
+XskWorker::XskWorker() :
+  workerWaker(createEventfd()), xskSocketWaker(createEventfd())
+{
+}
+void* XskPacket::payloadData()
+{
+  return reinterpret_cast<void*>(payload);
+}
+const void* XskPacket::payloadData() const
+{
+  return reinterpret_cast<const void*>(payload);
+}
+void XskPacket::setAddr(const ComboAddress& from_, MACAddr fromMAC, const ComboAddress& to_, MACAddr toMAC, bool tcp) noexcept
+{
+  auto* eth = reinterpret_cast<ethhdr*>(frame);
+  memcpy(eth->h_dest, &toMAC[0], sizeof(MACAddr));
+  memcpy(eth->h_source, &fromMAC[0], sizeof(MACAddr));
+  to = to_;
+  from = from_;
+  l4Header = frame + sizeof(ethhdr) + (to.isIPv4() ? sizeof(iphdr) : sizeof(ipv6hdr));
+  if (tcp) {
+    flags = TCP;
+    payload = l4Header + sizeof(tcphdr);
+  }
+  else {
+    flags = 0;
+    payload = l4Header + sizeof(udphdr);
+  }
+}
+void XskPacket::rewrite() noexcept
+{
+  flags |= REWRITE;
+  auto* eth = reinterpret_cast<ethhdr*>(frame);
+  if (to.isIPv4()) {
+    eth->h_proto = htons(ETH_P_IP);
+    auto* ipv4 = reinterpret_cast<iphdr*>(eth + 1);
+
+    ipv4->daddr = to.sin4.sin_addr.s_addr;
+    ipv4->saddr = from.sin4.sin_addr.s_addr;
+    if (flags & XskPacket::TCP) {
+      auto* tcp = reinterpret_cast<tcphdr*>(ipv4 + 1);
+      ipv4->protocol = IPPROTO_TCP;
+      tcp->source = from.sin4.sin_port;
+      tcp->dest = to.sin4.sin_port;
+      // TODO
+    }
+    else {
+      auto* udp = reinterpret_cast<udphdr*>(ipv4 + 1);
+      ipv4->protocol = IPPROTO_UDP;
+      udp->source = from.sin4.sin_port;
+      udp->dest = to.sin4.sin_port;
+      udp->len = htons(payloadEnd - reinterpret_cast<uint8_t*>(udp));
+      udp->check = 0;
+      udp->check = tcp_udp_v4_checksum();
+    }
+    rewriteIpv4Header(ipv4);
+  }
+  else {
+    auto* ipv6 = reinterpret_cast<ipv6hdr*>(eth + 1);
+    memcpy(&ipv6->daddr, &to.sin6.sin6_addr, sizeof(ipv6->daddr));
+    memcpy(&ipv6->saddr, &from.sin6.sin6_addr, sizeof(ipv6->saddr));
+    if (flags & XskPacket::TCP) {
+      auto* tcp = reinterpret_cast<tcphdr*>(ipv6 + 1);
+      ipv6->nexthdr = IPPROTO_TCP;
+      tcp->source = from.sin6.sin6_port;
+      tcp->dest = to.sin6.sin6_port;
+      // TODO
+    }
+    else {
+      auto* udp = reinterpret_cast<udphdr*>(ipv6 + 1);
+      ipv6->nexthdr = IPPROTO_UDP;
+      udp->source = from.sin6.sin6_port;
+      udp->dest = to.sin6.sin6_port;
+      udp->len = htons(payloadEnd - reinterpret_cast<uint8_t*>(udp));
+      udp->check = 0;
+      udp->check = tcp_udp_v6_checksum();
+    }
+  }
+}
+
+[[nodiscard]] __be16 XskPacket::ipv4Checksum() const noexcept
+{
+  auto* ip = reinterpret_cast<iphdr*>(frame + sizeof(ethhdr));
+  return ip_checksum_fold(ip_checksum_partial(ip, sizeof(iphdr), 0));
+}
+[[nodiscard]] __be16 XskPacket::tcp_udp_v4_checksum() const noexcept
+{
+  const auto* ip = reinterpret_cast<iphdr*>(frame + sizeof(ethhdr));
+  // ip options is not supported !!!
+  const auto l4Length = static_cast<uint16_t>(payloadEnd - l4Header);
+  auto sum = tcp_udp_v4_header_checksum_partial(ip->saddr, ip->daddr, ip->protocol, l4Length);
+  sum = ip_checksum_partial(l4Header, l4Length, sum);
+  return ip_checksum_fold(sum);
+}
+[[nodiscard]] __be16 XskPacket::tcp_udp_v6_checksum() const noexcept
+{
+  const auto* ipv6 = reinterpret_cast<ipv6hdr*>(frame + sizeof(ethhdr));
+  const auto l4Length = static_cast<uint16_t>(payloadEnd - l4Header);
+  uint64_t sum = tcp_udp_v6_header_checksum_partial(&ipv6->saddr, &ipv6->daddr, ipv6->nexthdr, l4Length);
+  sum = ip_checksum_partial(l4Header, l4Length, sum);
+  return ip_checksum_fold(sum);
+}
+
+#ifndef __packed
+#define __packed __attribute__((packed))
+#endif
+[[nodiscard]] uint64_t XskPacket::ip_checksum_partial(const void* p, size_t len, uint64_t sum) noexcept
+{
+  /* Main loop: 32 bits at a time.
+   * We take advantage of intel's ability to do unaligned memory
+   * accesses with minimal additional cost. Other architectures
+   * probably want to be more careful here.
+   */
+  const uint32_t* p32 = (const uint32_t*)(p);
+  for (; len >= sizeof(*p32); len -= sizeof(*p32))
+    sum += *p32++;
+
+  /* Handle un-32bit-aligned trailing bytes */
+  const uint16_t* p16 = (const uint16_t*)(p32);
+  if (len >= 2) {
+    sum += *p16++;
+    len -= sizeof(*p16);
+  }
+  if (len > 0) {
+    const uint8_t* p8 = (const uint8_t*)(p16);
+    sum += ntohs(*p8 << 8); /* RFC says pad last byte */
+  }
+
+  return sum;
+}
+[[nodiscard]] __be16 XskPacket::ip_checksum_fold(uint64_t sum) noexcept
+{
+  while (sum & ~0xffffffffULL)
+    sum = (sum >> 32) + (sum & 0xffffffffULL);
+  while (sum & 0xffff0000ULL)
+    sum = (sum >> 16) + (sum & 0xffffULL);
+
+  return ~sum;
+}
+[[nodiscard]] uint64_t XskPacket::tcp_udp_v4_header_checksum_partial(__be32 src_ip, __be32 dst_ip, uint8_t protocol, uint16_t len) noexcept
+{
+  struct header
+  {
+    __be32 src_ip;
+    __be32 dst_ip;
+    __uint8_t mbz;
+    __uint8_t protocol;
+    __be16 length;
+  };
+  /* The IPv4 pseudo-header is defined in RFC 793, Section 3.1. */
+  struct ipv4_pseudo_header_t
+  {
+    /* We use a union here to avoid aliasing issues with gcc -O2 */
+    union
+    {
+      header __packed fields;
+      uint32_t words[3];
+    };
+  };
+  struct ipv4_pseudo_header_t pseudo_header;
+  assert(sizeof(pseudo_header) == 12);
+
+  /* Fill in the pseudo-header. */
+  pseudo_header.fields.src_ip = src_ip;
+  pseudo_header.fields.dst_ip = dst_ip;
+  pseudo_header.fields.mbz = 0;
+  pseudo_header.fields.protocol = protocol;
+  pseudo_header.fields.length = htons(len);
+  return ip_checksum_partial(&pseudo_header, sizeof(pseudo_header), 0);
+}
+[[nodiscard]] uint64_t XskPacket::tcp_udp_v6_header_checksum_partial(const struct in6_addr* src_ip, const struct in6_addr* dst_ip, uint8_t protocol, uint32_t len) noexcept
+{
+  struct header
+  {
+    struct in6_addr src_ip;
+    struct in6_addr dst_ip;
+    __be32 length;
+    __uint8_t mbz[3];
+    __uint8_t next_header;
+  };
+  /* The IPv6 pseudo-header is defined in RFC 2460, Section 8.1. */
+  struct ipv6_pseudo_header_t
+  {
+    /* We use a union here to avoid aliasing issues with gcc -O2 */
+    union
+    {
+      header __packed fields;
+      uint32_t words[10];
+    };
+  };
+  struct ipv6_pseudo_header_t pseudo_header;
+  assert(sizeof(pseudo_header) == 40);
+
+  /* Fill in the pseudo-header. */
+  pseudo_header.fields.src_ip = *src_ip;
+  pseudo_header.fields.dst_ip = *dst_ip;
+  pseudo_header.fields.length = htonl(len);
+  memset(pseudo_header.fields.mbz, 0, sizeof(pseudo_header.fields.mbz));
+  pseudo_header.fields.next_header = protocol;
+  return ip_checksum_partial(&pseudo_header, sizeof(pseudo_header), 0);
+}
+void XskPacket::setHeader(const PacketBuffer& buf) noexcept
+{
+  memcpy(frame, buf.data(), buf.size());
+  payloadEnd = frame + buf.size();
+  flags = 0;
+  const auto res = parse();
+  assert(res);
+}
+std::unique_ptr<PacketBuffer> XskPacket::cloneHeadertoPacketBuffer() const
+{
+  const auto size = payload - frame;
+  auto tmp = std::make_unique<PacketBuffer>(size);
+  memcpy(tmp->data(), frame, size);
+  return tmp;
+}
+int XskWorker::createEventfd()
+{
+  auto fd = ::eventfd(0, EFD_CLOEXEC);
+  if (fd < 0) {
+    throw runtime_error("Unable create eventfd");
+  }
+  return fd;
+}
+void XskWorker::waitForXskSocket() noexcept
+{
+  uint64_t x = read(workerWaker, &x, sizeof(x));
+}
+void XskWorker::notifyXskSocket() noexcept
+{
+  notify(xskSocketWaker);
+}
+
+std::shared_ptr<XskWorker> XskWorker::create()
+{
+  return std::make_shared<XskWorker>();
+}
+void XskSocket::addWorker(std::shared_ptr<XskWorker> s, const ComboAddress& dest, bool isTCP)
+{
+  extern std::atomic<bool> g_configurationDone;
+  if (g_configurationDone) {
+    throw runtime_error("Adding a server with xsk at runtime is not supported");
+  }
+  s->poolName = poolName;
+  const auto socketWaker = s->xskSocketWaker.getHandle();
+  const auto workerWaker = s->workerWaker.getHandle();
+  const auto& socketWakerIdx = workers.get<0>();
+  if (socketWakerIdx.contains(socketWaker)) {
+    throw runtime_error("Server already exist");
+  }
+  s->umemBufBase = umem.bufBase;
+  workers.insert(XskRouteInfo{
+    .worker = std::move(s),
+    .dest = dest,
+    .xskSocketWaker = socketWaker,
+    .workerWaker = workerWaker,
+  });
+  fds.push_back(pollfd{
+    .fd = socketWaker,
+    .events = POLLIN,
+    .revents = 0});
+};
+uint64_t XskWorker::frameOffset(const XskPacket& s) const noexcept
+{
+  return s.frame - umemBufBase;
+}
+void XskWorker::notifyWorker() noexcept
+{
+  notify(workerWaker);
+}
+void XskSocket::getMACFromIfName()
+{
+  ifreq ifr;
+  auto fd = ::socket(AF_INET, SOCK_DGRAM, 0);
+  strncpy(ifr.ifr_name, ifName.c_str(), ifName.length() + 1);
+  if (ioctl(fd, SIOCGIFHWADDR, &ifr) < 0) {
+    throw runtime_error("Error get MAC addr");
+  }
+  memcpy(source, ifr.ifr_hwaddr.sa_data, sizeof(source));
+  close(fd);
+}
+[[nodiscard]] int XskSocket::timeDifference(const timespec& t1, const timespec& t2) noexcept
+{
+  const auto res = t1.tv_sec * 1000 + t1.tv_nsec / 1000000L - (t2.tv_sec * 1000 + t2.tv_nsec / 1000000L);
+  return static_cast<int>(res);
+}
+void XskWorker::cleanWorkerNotification() noexcept
+{
+  uint64_t x = read(xskSocketWaker, &x, sizeof(x));
+}
+void XskWorker::cleanSocketNotification() noexcept
+{
+  uint64_t x = read(workerWaker, &x, sizeof(x));
+}
+std::vector<pollfd> getPollFdsForWorker(XskWorker& info)
+{
+  std::vector<pollfd> fds;
+  int timerfd = timerfd_create(CLOCK_MONOTONIC, TFD_CLOEXEC);
+  if (timerfd < 0) {
+    throw std::runtime_error("create_timerfd failed");
+  }
+  fds.push_back(pollfd{
+    .fd = info.workerWaker,
+    .events = POLLIN,
+    .revents = 0,
+  });
+  fds.push_back(pollfd{
+    .fd = timerfd,
+    .events = POLLIN,
+    .revents = 0,
+  });
+  return fds;
+}
+void XskWorker::fillUniqueEmptyOffset()
+{
+  auto frames = sharedEmptyFrameOffset->lock();
+  const auto moveSize = std::min(static_cast<size_t>(32), frames->size());
+  if (moveSize > 0) {
+    uniqueEmptyFrameOffset.insert(uniqueEmptyFrameOffset.end(), std::make_move_iterator(frames->end() - moveSize), std::make_move_iterator(frames->end()));
+  }
+}
+void* XskWorker::getEmptyframe()
+{
+  if (!uniqueEmptyFrameOffset.empty()) {
+    auto offset = uniqueEmptyFrameOffset.back();
+    uniqueEmptyFrameOffset.pop_back();
+    return offset + umemBufBase;
+  }
+  fillUniqueEmptyOffset();
+  if (!uniqueEmptyFrameOffset.empty()) {
+    auto offset = uniqueEmptyFrameOffset.back();
+    uniqueEmptyFrameOffset.pop_back();
+    return offset + umemBufBase;
+  }
+  return nullptr;
+}
+uint32_t XskPacket::getFlags() const noexcept
+{
+  return flags;
+}
+void XskPacket::updatePacket() noexcept
+{
+  if (!(flags & UPDATE)) {
+    return;
+  }
+  if (!(flags & REWRITE)) {
+    changeDirectAndUpdateChecksum();
+  }
+}
+#endif /* HAVE_XSK */
diff --git a/pdns/xsk.hh b/pdns/xsk.hh
new file mode 100644 (file)
index 0000000..83c957d
--- /dev/null
@@ -0,0 +1,243 @@
+/*
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+
+#pragma once
+#include "iputils.hh"
+#include "misc.hh"
+#include "noinitvector.hh"
+#include "lock.hh"
+
+#include <array>
+#include <bits/types/struct_timespec.h>
+#include <boost/lockfree/spsc_queue.hpp>
+#include <boost/multi_index/hashed_index.hpp>
+#include <boost/multi_index/indexed_by.hpp>
+#include <boost/multi_index_container.hpp>
+#include <boost/multi_index/member.hpp>
+#include <cstdint>
+#include <cstring>
+#include <linux/types.h>
+#include <memory>
+#include <queue>
+#include <stdexcept>
+#include <string>
+#include <sys/poll.h>
+#include <sys/types.h>
+#include <unistd.h>
+#include <unordered_map>
+#include <vector>
+
+#ifdef HAVE_XSK
+#include <xdp/xsk.h>
+#endif /* HAVE_XSK */
+
+class XskPacket;
+class XskWorker;
+class XskSocket;
+
+#ifdef HAVE_XSK
+using XskPacketPtr = std::unique_ptr<XskPacket>;
+
+// We use an XskSocket to manage an AF_XDP Socket corresponding to a NIC queue.
+// The XDP program running in the kernel redirects the data to the XskSocket in userspace.
+// XskSocket routes packets to multiple worker threads registered on XskSocket via XskSocket::addWorker based on the destination port number of the packet.
+// The kernel and the worker thread holding XskWorker will wake up the XskSocket through XskFd and the Eventfd corresponding to each worker thread, respectively.
+class XskSocket
+{
+  struct XskRouteInfo
+  {
+    std::shared_ptr<XskWorker> worker;
+    ComboAddress dest;
+    int xskSocketWaker;
+    int workerWaker;
+  };
+  struct XskUmem
+  {
+    xsk_umem* umem{nullptr};
+    uint8_t* bufBase{nullptr};
+    size_t size;
+    void umemInit(size_t memSize, xsk_ring_cons* cq, xsk_ring_prod* fq, xsk_umem_config* config);
+    ~XskUmem();
+    XskUmem() = default;
+  };
+  boost::multi_index_container<
+    XskRouteInfo,
+    boost::multi_index::indexed_by<
+      boost::multi_index::hashed_unique<boost::multi_index::member<XskRouteInfo, int, &XskRouteInfo::xskSocketWaker>>,
+      boost::multi_index::hashed_unique<boost::multi_index::member<XskRouteInfo, ComboAddress, &XskRouteInfo::dest>, ComboAddress::addressPortOnlyHash>>>
+    workers;
+  static constexpr size_t holdThreshold = 256;
+  static constexpr size_t fillThreshold = 128;
+  static constexpr size_t frameSize = 2048;
+  const size_t frameNum;
+  const uint32_t queueId;
+  std::priority_queue<XskPacketPtr> waitForDelay;
+  const std::string ifName;
+  const std::string poolName;
+  vector<pollfd> fds;
+  vector<uint64_t> uniqueEmptyFrameOffset;
+  xsk_ring_cons cq;
+  xsk_ring_cons rx;
+  xsk_ring_prod fq;
+  xsk_ring_prod tx;
+  std::unique_ptr<xsk_socket, void (*)(xsk_socket*)> socket;
+  XskUmem umem;
+  bpf_object* prog;
+
+  static constexpr uint32_t fqCapacity = XSK_RING_PROD__DEFAULT_NUM_DESCS * 4;
+  static constexpr uint32_t cqCapacity = XSK_RING_CONS__DEFAULT_NUM_DESCS * 4;
+  static constexpr uint32_t rxCapacity = XSK_RING_CONS__DEFAULT_NUM_DESCS * 2;
+  static constexpr uint32_t txCapacity = XSK_RING_PROD__DEFAULT_NUM_DESCS * 2;
+
+  constexpr static bool isPowOfTwo(uint32_t value) noexcept;
+  [[nodiscard]] static int timeDifference(const timespec& t1, const timespec& t2) noexcept;
+  friend void XskRouter(std::shared_ptr<XskSocket> xsk);
+
+  [[nodiscard]] uint64_t frameOffset(const XskPacket& packet) const noexcept;
+  int firstTimeout();
+  void fillFq(uint32_t fillSize = fillThreshold) noexcept;
+  void recycle(size_t size) noexcept;
+  void getMACFromIfName();
+  void pickUpReadyPacket(std::vector<XskPacketPtr>& packets);
+
+public:
+  std::shared_ptr<LockGuarded<vector<uint64_t>>> sharedEmptyFrameOffset;
+  XskSocket(size_t frameNum, const std::string& ifName, uint32_t queue_id, const std::string& xskMapPath, const std::string& poolName_);
+  MACAddr source;
+  [[nodiscard]] int xskFd() const noexcept;
+  int wait(int timeout);
+  void send(std::vector<XskPacketPtr>& packets);
+  std::vector<XskPacketPtr> recv(uint32_t recvSizeMax, uint32_t* failedCount);
+  void addWorker(std::shared_ptr<XskWorker> s, const ComboAddress& dest, bool isTCP);
+};
+class XskPacket
+{
+public:
+  enum Flags : uint32_t
+  {
+    TCP = 1 << 0,
+    UPDATE = 1 << 1,
+    DELAY = 1 << 3,
+    REWRITE = 1 << 4
+  };
+
+private:
+  ComboAddress from;
+  ComboAddress to;
+  timespec sendTime;
+  uint8_t* frame;
+  uint8_t* l4Header;
+  uint8_t* payload;
+  uint8_t* payloadEnd;
+  uint8_t* frameEnd;
+  uint32_t flags{0};
+
+  friend XskSocket;
+  friend XskWorker;
+  friend bool operator<(const XskPacketPtr& s1, const XskPacketPtr& s2) noexcept;
+
+  constexpr static uint8_t DefaultTTL = 64;
+  bool parse();
+  void changeDirectAndUpdateChecksum() noexcept;
+
+  // You must set ipHeader.check = 0 before call this method
+  [[nodiscard]] __be16 ipv4Checksum() const noexcept;
+  // You must set l4Header.check = 0 before call this method
+  // ip options is not supported
+  [[nodiscard]] __be16 tcp_udp_v4_checksum() const noexcept;
+  // You must set l4Header.check = 0 before call this method
+  [[nodiscard]] __be16 tcp_udp_v6_checksum() const noexcept;
+  [[nodiscard]] static uint64_t ip_checksum_partial(const void* p, size_t len, uint64_t sum) noexcept;
+  [[nodiscard]] static __be16 ip_checksum_fold(uint64_t sum) noexcept;
+  [[nodiscard]] static uint64_t tcp_udp_v4_header_checksum_partial(__be32 src_ip, __be32 dst_ip, uint8_t protocol, uint16_t len) noexcept;
+  [[nodiscard]] static uint64_t tcp_udp_v6_header_checksum_partial(const struct in6_addr* src_ip, const struct in6_addr* dst_ip, uint8_t protocol, uint32_t len) noexcept;
+  void rewriteIpv4Header(void* ipv4header) noexcept;
+  void rewriteIpv6Header(void* ipv6header) noexcept;
+
+public:
+  [[nodiscard]] const ComboAddress& getFromAddr() const noexcept;
+  [[nodiscard]] const ComboAddress& getToAddr() const noexcept;
+  [[nodiscard]] const void* payloadData() const;
+  [[nodiscard]] bool isIPV6() const noexcept;
+  [[nodiscard]] size_t capacity() const noexcept;
+  [[nodiscard]] uint32_t dataLen() const noexcept;
+  [[nodiscard]] uint32_t FrameLen() const noexcept;
+  [[nodiscard]] PacketBuffer clonePacketBuffer() const;
+  void cloneIntoPacketBuffer(PacketBuffer& buffer) const;
+  [[nodiscard]] std::unique_ptr<PacketBuffer> cloneHeadertoPacketBuffer() const;
+  [[nodiscard]] void* payloadData();
+  void setAddr(const ComboAddress& from_, MACAddr fromMAC, const ComboAddress& to_, MACAddr toMAC, bool tcp = false) noexcept;
+  bool setPayload(const PacketBuffer& buf);
+  void rewrite() noexcept;
+  void setHeader(const PacketBuffer& buf) noexcept;
+  XskPacket() = default;
+  XskPacket(void* frame, size_t dataSize, size_t frameSize);
+  void addDelay(int relativeMilliseconds) noexcept;
+  void updatePacket() noexcept;
+  [[nodiscard]] uint32_t getFlags() const noexcept;
+};
+bool operator<(const XskPacketPtr& s1, const XskPacketPtr& s2) noexcept;
+
+// XskWorker obtains XskPackets of specific ports in the NIC from XskSocket through cq.
+// After finishing processing the packet, XskWorker puts the packet into sq so that XskSocket decides whether to send it through the network card according to XskPacket::flags.
+// XskWorker wakes up XskSocket via xskSocketWaker after putting the packets in sq.
+class XskWorker
+{
+  using XskPacketRing = boost::lockfree::spsc_queue<XskPacket*, boost::lockfree::capacity<512>>;
+
+public:
+  uint8_t* umemBufBase;
+  std::shared_ptr<LockGuarded<vector<uint64_t>>> sharedEmptyFrameOffset;
+  vector<uint64_t> uniqueEmptyFrameOffset;
+  XskPacketRing cq;
+  XskPacketRing sq;
+  std::string poolName;
+  size_t frameSize;
+  FDWrapper workerWaker;
+  FDWrapper xskSocketWaker;
+
+  XskWorker();
+  static int createEventfd();
+  static void notify(int fd);
+  static std::shared_ptr<XskWorker> create();
+  void notifyWorker() noexcept;
+  void notifyXskSocket() noexcept;
+  void waitForXskSocket() noexcept;
+  void cleanWorkerNotification() noexcept;
+  void cleanSocketNotification() noexcept;
+  [[nodiscard]] uint64_t frameOffset(const XskPacket& s) const noexcept;
+  void fillUniqueEmptyOffset();
+  void* getEmptyframe();
+};
+std::vector<pollfd> getPollFdsForWorker(XskWorker& info);
+#else
+class XskSocket
+{
+};
+class XskPacket
+{
+};
+class XskWorker
+{
+};
+
+#endif /* HAVE_XSK */