]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Fix XSK between dnsdist and its backends
authorRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 11 Jan 2024 15:24:38 +0000 (16:24 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Tue, 23 Jan 2024 11:54:21 +0000 (12:54 +0100)
contrib/xdp-filter.ebpf.src
contrib/xdp.h
contrib/xdp.py
pdns/dnsdist-lua-bindings.cc
pdns/dnsdist-lua.cc
pdns/dnsdist.cc
pdns/dnsdist.hh
pdns/dnsdistdist/dnsdist-backend.cc
pdns/dnsdistdist/test-dnsdistlbpolicies_cc.cc
pdns/xsk.cc
pdns/xsk.hh

index 786d9d75a355f1a7fbf9a32e0063a72bbada11e3..ec0d07145d262d0e1081c2cf9d3344b13bae023e 100644 (file)
@@ -35,6 +35,8 @@ BPF_TABLE_PINNED7("lpm_trie", struct CIDR6, struct map_value, cidr6filter, 1024,
   __attribute__((section("maps/xskmap:" _pinned))) struct _name##_table_t _name = {.max_entries = (_max_entries)}
 
 BPF_XSKMAP_PIN(xsk_map, 16, "/sys/fs/bpf/dnsdist/xskmap");
+BPF_TABLE_PINNED("hash", struct IPv4AndPort, bool, xskDestinationsV4, 1024, "/sys/fs/bpf/dnsdist/xsk-destinations-v4");
+BPF_TABLE_PINNED("hash", struct IPv6AndPort, bool, xskDestinationsV6, 1024, "/sys/fs/bpf/dnsdist/xsk-destinations-v6");
 #endif /* UseXsk */
 
 #define COMPARE_PORT(x, p) ((x) == bpf_htons(p))
@@ -159,9 +161,19 @@ static inline enum xdp_action parseIPV4(struct xdp_md* ctx, struct cursor* c)
     if (!(udp = parse_udphdr(c))) {
       return XDP_PASS;
     }
+#ifdef UseXsk
+    struct IPv4AndPort v4Dest;
+    memset(&v4Dest, 0, sizeof(v4Dest));
+    v4Dest.port = udp->dest;
+    v4Dest.addr = ipv4->daddr;
+    if (!xskDestinationsV4.lookup(&v4Dest)) {
+      return XDP_PASS;
+    }
+#else /* UseXsk */
     if (!IN_DNS_PORT_SET(udp->dest)) {
       return XDP_PASS;
     }
+#endif /* UseXsk */
     if (!(dns = parse_dnshdr(c))) {
       return XDP_DROP;
     }
@@ -253,10 +265,20 @@ static inline enum xdp_action parseIPV6(struct xdp_md* ctx, struct cursor* c)
     if (!(udp = parse_udphdr(c))) {
       return XDP_PASS;
     }
+#ifdef UseXsk
+    struct IPv6AndPort v6Dest;
+    memset(&v6Dest, 0, sizeof(v6Dest));
+    v6Dest.port = udp->dest;
+    memcpy(&v6Dest.addr, &ipv6->daddr, sizeof(v6Dest.addr));
+    if (!xskDestinationsV6.lookup(&v6Dest)) {
+      return XDP_PASS;
+    }
+#else /* UseXsk */
     if (!IN_DNS_PORT_SET(udp->dest)) {
       return XDP_PASS;
     }
-    if (!(dns = parse_dnshdr(c))) {
+#endif /* UseXsk */
+  if (!(dns = parse_dnshdr(c))) {
       return XDP_DROP;
     }
     break;
index 87fef3a776cd7ac95d631f953b15c08fd8b210e9..0d63fcfd963dd4840583b551f6b3475c35542cb9 100644 (file)
@@ -108,6 +108,18 @@ struct CIDR6
   struct in6_addr addr;
 };
 
+struct IPv4AndPort
+{
+  uint32_t addr;
+  uint16_t port;
+};
+
+struct IPv6AndPort
+{
+  struct in6_addr addr;
+  uint16_t port;
+};
+
 /*
  * Store the matching counter and the associated action for a blocked element
  */
@@ -128,7 +140,7 @@ static inline void cursor_init(struct cursor *c, struct xdp_md *ctx)
   c->pos = (void *)(long)ctx->data;
 }
 
-/* 
+/*
  * Header parser functions
  * Copyright 2020, NLnet Labs, All rights reserved.
  */
@@ -180,4 +192,4 @@ static inline struct ethhdr *parse_eth(struct cursor *c, uint16_t *eth_proto)
   return eth;
 }
 
-#endif 
+#endif
index 1b9187007ff22db4ea1c4efed5eec154eae6066c..67ad96b917f7c03304671a119670a71fcf51ce7b 100644 (file)
@@ -46,6 +46,9 @@ cidr4filter = xdp.get_table("cidr4filter")
 cidr6filter = xdp.get_table("cidr6filter")
 qnamefilter = xdp.get_table("qnamefilter")
 
+if useXsk:
+  xskDestinations = xdp.get_table("xskDestinationsV4")
+
 for ip in blocked_ipv4:
   print(f"Blocking {ip}")
   key = v4filter.Key(int(netaddr.IPAddress(ip[0]).value))
@@ -106,7 +109,7 @@ for qname in blocked_qnames:
 
 print("Filter is ready")
 try:
-  xdp.trace_print() 
+  xdp.trace_print()
 except KeyboardInterrupt:
   pass
 
index 5a6d8f4e95d87fa18fcbecdae0712ca714288b86..45ee564cf64afd580e5db95e256a02debe4985e6 100644 (file)
@@ -754,11 +754,8 @@ void setupLuaBindings(LuaContext& luaCtx, bool client, bool configCheck)
     else {
       throw std::runtime_error("xskMapPath field is required!");
     }
-    if (opts.count("pool") == 1) {
-      poolName = boost::get<std::string>(opts.at("pool"));
-    }
     extern std::vector<std::shared_ptr<XskSocket>> g_xsk;
-    auto socket = std::make_shared<XskSocket>(frameNums, ifName, queue_id, path, poolName);
+    auto socket = std::make_shared<XskSocket>(frameNums, ifName, queue_id, path);
     g_xsk.push_back(socket);
     return socket;
   });
index c4ee51962812c14d6a753f10a8ac8bc766076140..793f226df697d8cdc6a5915042adf8502b161ee1 100644 (file)
@@ -636,13 +636,22 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
 #ifdef HAVE_XSK
                          std::shared_ptr<XskSocket> xskSocket;
                          if (getOptionalValue<std::shared_ptr<XskSocket>>(vars, "xskSocket", xskSocket) > 0) {
+                           if (g_configurationDone) {
+                             throw std::runtime_error("Adding a server with xsk at runtime is not supported");
+                           }
                            ret->registerXsk(xskSocket);
                            std::string mac;
-                           if (getOptionalValue<std::string>(vars, "MACAddr", mac) != 1) {
-                             throw runtime_error("field MACAddr is required!");
+                           if (getOptionalValue<std::string>(vars, "MACAddr", mac) > 0) {
+                             auto* addr = &ret->d_config.destMACAddr[0];
+                             sscanf(mac.c_str(), "%hhx:%hhx:%hhx:%hhx:%hhx:%hhx", addr, addr + 1, addr + 2, addr + 3, addr + 4, addr + 5);
+                           }
+                           else {
+                             mac = getMACAddress(ret->d_config.remote);
+                             if (mac.size() != ret->d_config.destMACAddr.size()) {
+                               throw runtime_error("Field 'MACAddr' is not set on 'newServer' directive for '" + ret->d_config.remote.toStringWithPort() + "' and cannot be retriever from the system either!");
+                             }
+                             memcpy(ret->d_config.destMACAddr.data(), mac.data(), ret->d_config.destMACAddr.size());
                            }
-                           auto* addr = &ret->d_config.destMACAddr[0];
-                           sscanf(mac.c_str(), "%hhx:%hhx:%hhx:%hhx:%hhx:%hhx", addr, addr + 1, addr + 2, addr + 3, addr + 4, addr + 5);
                          }
 #endif /* HAVE_XSK */
                          if (autoUpgrade && ret->getProtocol() != dnsdist::Protocol::DoT && ret->getProtocol() != dnsdist::Protocol::DoH) {
@@ -783,7 +792,8 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
       if (socket) {
         udpCS->xskInfo = XskWorker::create();
         udpCS->xskInfo->sharedEmptyFrameOffset = socket->sharedEmptyFrameOffset;
-        socket->addWorker(udpCS->xskInfo, loc);
+        socket->addWorker(udpCS->xskInfo);
+        socket->addWorkerRoute(udpCS->xskInfo, loc);
       }
 #endif /* HAVE_XSK */
       g_frontends.push_back(std::move(udpCS));
@@ -835,7 +845,8 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
       if (socket) {
         udpCS->xskInfo = XskWorker::create();
         udpCS->xskInfo->sharedEmptyFrameOffset = socket->sharedEmptyFrameOffset;
-        socket->addWorker(udpCS->xskInfo, loc);
+        socket->addWorker(udpCS->xskInfo);
+        socket->addWorkerRoute(udpCS->xskInfo, loc);
       }
 #endif /* HAVE_XSK */
       g_frontends.push_back(std::move(udpCS));
index a4dcd3f1bd189a4de1dc1a4b3a516d6cb726f269..cd71ac6b5c3d76ff8bc09d74790b0289896820b8 100644 (file)
@@ -35,7 +35,6 @@
 
 #ifdef HAVE_XSK
 #include <sys/poll.h>
-#include <sys/timerfd.h>
 #endif /* HAVE_XSK */
 
 #ifdef HAVE_LIBEDIT
@@ -815,25 +814,6 @@ static bool processResponderPacket(std::shared_ptr<DownstreamState>& dss, Packet
 #ifdef HAVE_XSK
 namespace dnsdist::xsk
 {
-static void doHealthCheck(std::shared_ptr<DownstreamState>& dss, std::unordered_map<uint16_t, std::shared_ptr<HealthCheckData>>& map, bool initial = false)
-{
-  auto& xskInfo = dss->xskInfo;
-  std::shared_ptr<HealthCheckData> data;
-  auto packet = getHealthCheckPacket(dss, nullptr, data);
-  data->d_initial = initial;
-  setHealthCheckTime(dss, data);
-  auto xskPacket = xskInfo->getEmptyFrame();
-  if (!xskPacket) {
-    return;
-  }
-  xskPacket->setAddr(dss->d_config.sourceAddr, dss->d_config.sourceMACAddr, dss->d_config.remote, dss->d_config.destMACAddr);
-  xskPacket->setPayload(packet);
-  xskPacket->rewrite();
-  xskInfo->pushToSendQueue(std::move(*xskPacket));
-  const auto queryId = data->d_queryID;
-  map[queryId] = std::move(data);
-}
-
 void responderThread(std::shared_ptr<DownstreamState> dss)
 {
   if (dss->xskInfo == nullptr) {
@@ -846,16 +826,6 @@ void responderThread(std::shared_ptr<DownstreamState> dss)
     auto localCacheInsertedRespRuleActions = g_cacheInsertedRespRuleActions.getLocal();
     auto xskInfo = dss->xskInfo;
     auto pollfds = getPollFdsForWorker(*xskInfo);
-    std::unordered_map<uint16_t, std::shared_ptr<HealthCheckData>> healthCheckMap;
-    dnsdist::xsk::doHealthCheck(dss, healthCheckMap, true);
-    itimerspec tm;
-    tm.it_value.tv_sec = dss->d_config.checkTimeout / 1000;
-    tm.it_value.tv_nsec = (dss->d_config.checkTimeout % 1000) * 1000000;
-    tm.it_interval = tm.it_value;
-    auto res = timerfd_settime(pollfds[1].fd, 0, &tm, nullptr);
-    if (res) {
-      throw std::runtime_error("timerfd_settime failed:" + stringerror(errno));
-    }
     const auto xskFd = xskInfo->workerWaker.getHandle();
     while (!dss->isStopped()) {
       poll(pollfds.data(), pollfds.size(), -1);
@@ -881,14 +851,6 @@ void responderThread(std::shared_ptr<DownstreamState> dss)
             }
           }
           if (!ids) {
-            // this has to go before we can refactor the duplicated response handling code
-            auto iter = healthCheckMap.find(queryId);
-            if (iter != healthCheckMap.end()) {
-              auto data = std::move(iter->second);
-              healthCheckMap.erase(iter);
-              packet.cloneIntoPacketBuffer(data->d_buffer);
-              data->d_ds->submitHealthCheckResult(data->d_initial, handleResponse(data));
-            }
             xskInfo->markAsFree(std::move(packet));
             return;
           }
@@ -902,7 +864,6 @@ void responderThread(std::shared_ptr<DownstreamState> dss)
             vinfolog("XSK packet pushed to queue because processResponderPacket failed");
             return;
           }
-          vinfolog("XSK packet - processResponderPacket OK");
           if (response.size() > packet.getCapacity()) {
             /* fallback to sending the packet via normal socket */
             sendUDPResponse(ids->cs->udpFD, response, ids->delayMsec, ids->hopLocal, ids->hopRemote);
@@ -910,9 +871,7 @@ void responderThread(std::shared_ptr<DownstreamState> dss)
             xskInfo->markAsFree(std::move(packet));
             return;
           }
-          //vinfolog("XSK packet - set header");
           packet.setHeader(ids->xskPacketHeader);
-          //vinfolog("XSK packet - set payload");
           if (!packet.setPayload(response)) {
             vinfolog("Unable to set payload !");
           }
@@ -920,42 +879,11 @@ void responderThread(std::shared_ptr<DownstreamState> dss)
             vinfolog("XSK packet - adding delay");
             packet.addDelay(ids->delayMsec);
           }
-          //vinfolog("XSK packet - update packet");
           packet.updatePacket();
-          //vinfolog("XSK packet pushed to send queue");
           xskInfo->pushToSendQueue(std::move(packet));
         });
         xskInfo->cleanSocketNotification();
       }
-      if (pollfds[1].revents & POLLIN) {
-        timeval now;
-        gettimeofday(&now, nullptr);
-        for (auto i = healthCheckMap.begin(); i != healthCheckMap.end();) {
-          auto& ttd = i->second->d_ttd;
-          if (ttd < now) {
-            dss->submitHealthCheckResult(i->second->d_initial, false);
-            i = healthCheckMap.erase(i);
-          }
-          else {
-            ++i;
-          }
-        }
-        needNotify = true;
-        dss->updateStatisticsInfo();
-        dss->handleUDPTimeouts();
-        if (dss->d_nextCheck <= 1) {
-          dss->d_nextCheck = dss->d_config.checkInterval;
-          if (dss->d_config.availability == DownstreamState::Availability::Auto) {
-            doHealthCheck(dss, healthCheckMap);
-          }
-        }
-        else {
-          --dss->d_nextCheck;
-        }
-
-        uint64_t tmp;
-        res = read(pollfds[1].fd, &tmp, sizeof(tmp));
-      }
       if (needNotify) {
         xskInfo->notifyXskSocket();
       }
@@ -987,7 +915,7 @@ static bool isXskQueryAcceptable(const XskPacket& packet, ClientState& cs, Local
   return true;
 }
 
-void XskRouter(std::shared_ptr<XskSocket> xsk)
+static void XskRouter(std::shared_ptr<XskSocket> xsk)
 {
   setThreadName("dnsdist/XskRouter");
   uint32_t failed;
@@ -996,8 +924,6 @@ void XskRouter(std::shared_ptr<XskSocket> xsk)
   const auto& fds = xsk->getDescriptors();
   // list of workers that need to be notified
   std::set<int> needNotify;
-  const auto& xskWakerIdx = xsk->getWorkers().get<0>();
-  const auto& destIdx = xsk->getWorkers().get<1>();
   while (true) {
     try {
       auto ready = xsk->wait(-1);
@@ -1007,13 +933,13 @@ void XskRouter(std::shared_ptr<XskSocket> xsk)
         dnsdist::metrics::g_stats.nonCompliantQueries += failed;
         for (auto &packet : packets) {
           const auto dest = packet.getToAddr();
-          auto res = destIdx.find(dest);
-          if (res == destIdx.end()) {
+          auto worker = xsk->getWorkerByDestination(dest);
+          if (!worker) {
             xsk->markAsFree(std::move(packet));
             continue;
           }
-          res->worker->pushToProcessingQueue(std::move(packet));
-          needNotify.insert(res->workerWaker);
+          worker->pushToProcessingQueue(std::move(packet));
+          needNotify.insert(worker->workerWaker.getHandle());
         }
         for (auto i : needNotify) {
           uint64_t x = 1;
@@ -1031,7 +957,7 @@ void XskRouter(std::shared_ptr<XskSocket> xsk)
       for (size_t fdIndex = 1; fdIndex < fds.size() && ready > 0; fdIndex++) {
         if (fds.at(fdIndex).revents & POLLIN) {
           ready--;
-          auto& info = xskWakerIdx.find(fds.at(fdIndex).fd)->worker;
+          auto& info = xsk->getWorkerByDescriptor(fds.at(fdIndex).fd);
 #if defined(__SANITIZE_THREAD__)
           info->outgoingPacketsQueue.lock()->consume_all([&](XskPacket& packet) {
 #else
@@ -1131,19 +1057,14 @@ void responderThread(std::shared_ptr<DownstreamState> dss)
 
         if (processResponderPacket(dss, response, *localRespRuleActions, *localCacheInsertedRespRuleActions, std::move(*ids)) && ids->isXSK() && ids->cs->xskInfo) {
 #ifdef HAVE_XSK
-          //vinfolog("processResponderPacket OK");
           auto& xskInfo = ids->cs->xskInfo;
           auto xskPacket = xskInfo->getEmptyFrame();
           if (!xskPacket) {
             continue;
           }
-          //vinfolog("XSK setHeader");
           xskPacket->setHeader(ids->xskPacketHeader);
-          //vinfolog("XSK payload");
           xskPacket->setPayload(response);
-          //vinfolog("XSK update packet");
           xskPacket->updatePacket();
-          //vinfolog("XSK pushed to send queue");
           xskInfo->pushToSendQueue(std::move(*xskPacket));
           xskInfo->notifyXskSocket();
 #endif /* HAVE_XSK */
@@ -1701,11 +1622,6 @@ ProcessQueryResult processQueryAfterRules(DNSQuestion& dq, LocalHolders& holders
       ++dq.ids.cs->responses;
       return ProcessQueryResult::SendAnswer;
     }
-#ifdef HAVE_XSK
-    if (dq.ids.cs->xskInfo) {
-      dq.ids.poolName = dq.ids.cs->xskInfo->poolName;
-    }
-#endif /* HAVE_XSK */
     std::shared_ptr<ServerPool> serverPool = getPool(*holders.pools, dq.ids.poolName);
     std::shared_ptr<ServerPolicy> poolPolicy = serverPool->policy;
     dq.ids.packetCache = serverPool->packetCache;
@@ -2222,6 +2138,8 @@ static bool ProcessXskQuery(ClientState& cs, LocalHolders& holders, XskPacket& p
       if (dq.ids.delayMsec > 0) {
         packet.addDelay(dq.ids.delayMsec);
       }
+      const auto dh = dq.getHeader();
+      handleResponseSent(ids.qname, ids.qtype, 0., remote, ComboAddress(), query.size(), *dh, dnsdist::Protocol::DoUDP, dnsdist::Protocol::DoUDP, false);
       return true;
     }
 
@@ -2247,6 +2165,7 @@ static bool ProcessXskQuery(ClientState& cs, LocalHolders& holders, XskPacket& p
       return false;
     }
 
+#ifdef HAVE_XSK
     if (!ss->xskInfo) {
       assignOutgoingUDPQueryToBackend(ss, dh->id, dq, query, true);
       return false;
@@ -2255,11 +2174,16 @@ static bool ProcessXskQuery(ClientState& cs, LocalHolders& holders, XskPacket& p
       int fd = ss->xskInfo->workerWaker;
       ids.backendFD = fd;
       assignOutgoingUDPQueryToBackend(ss, dh->id, dq, query, false);
-      packet.setAddr(ss->d_config.sourceAddr,ss->d_config.sourceMACAddr, ss->d_config.remote,ss->d_config.destMACAddr);
+      auto sourceAddr = ss->pickSourceAddressForSending();
+      packet.setAddr(sourceAddr, ss->d_config.sourceMACAddr, ss->d_config.remote, ss->d_config.destMACAddr);
       packet.setPayload(query);
       packet.rewrite();
       return true;
     }
+#else /* HAVE_XSK */
+    assignOutgoingUDPQueryToBackend(ss, dh->id, dq, query, true);
+    return false;
+#endif /* HAVE_XSK */
   }
   catch (const std::exception& e) {
     vinfolog("Got an error in UDP question thread while parsing a query from %s, id %d: %s", remote.toStringWithPort(), queryId, e.what());
@@ -2638,11 +2562,6 @@ static void healthChecksThread()
 
     std::unique_ptr<FDMultiplexer> mplexer{nullptr};
     for (auto& dss : *states) {
-#ifdef HAVE_XSK
-      if (dss->xskInfo) {
-        continue;
-      }
-#endif /* HAVE_XSK */
       dss->updateStatisticsInfo();
 
       dss->handleUDPTimeouts();
@@ -3387,6 +3306,8 @@ static void startFrontends()
   for (auto& clientState : g_frontends) {
 #ifdef HAVE_XSK
     if (clientState->xskInfo) {
+      XskSocket::addDestinationAddress(clientState->local);
+
       std::thread xskCT(dnsdist::xsk::xskClientThread, clientState.get());
       if (!clientState->cpus.empty()) {
         mapThreadToCPUList(xskCT.native_handle(), clientState->cpus);
@@ -3493,6 +3414,10 @@ int main(int argc, char** argv)
     dnsdist::initRandom();
     g_hashperturb = dnsdist::getRandomValue(0xffffffff);
 
+#ifdef HAVE_XSK
+    XskSocket::clearDestinationAddresses();
+#endif /* HAVE_XSK */
+
     ComboAddress clientAddress = ComboAddress();
     g_cmdLine.config=SYSCONFDIR "/dnsdist.conf";
 
@@ -3655,11 +3580,6 @@ int main(int argc, char** argv)
       auto states = g_dstates.getCopy(); // it is a copy, but the internal shared_ptrs are the real deal
       auto mplexer = std::unique_ptr<FDMultiplexer>(FDMultiplexer::getMultiplexerSilent(states.size()));
       for (auto& dss : states) {
-#ifdef HAVE_XSK
-        if (dss->xskInfo) {
-          continue;
-        }
-#endif /* HAVE_XSK */
 
         if (dss->d_config.availability == DownstreamState::Availability::Auto || dss->d_config.availability == DownstreamState::Availability::Lazy) {
           if (dss->d_config.availability == DownstreamState::Availability::Auto) {
index 0081c492a625066c872c96aac3adafbe3a484211..4cec0d03043cda85abc2922aa9aac9f85b99f008 100644 (file)
@@ -819,7 +819,10 @@ public:
   std::vector<int> sockets;
   StopWatch sw;
   QPSLimiter qps;
+#ifdef HAVE_XSK
   std::shared_ptr<XskWorker> xskInfo{nullptr};
+  std::shared_ptr<XskSocket> d_xskSocket{nullptr};
+#endif
   std::atomic<uint64_t> idOffset{0};
   size_t socketsOffset{0};
   double latencyUsec{0.0};
@@ -834,10 +837,17 @@ private:
   void handleUDPTimeout(IDState& ids);
   void updateNextLazyHealthCheck(LazyHealthCheckStats& stats, bool checkScheduled, std::optional<time_t> currentTime = std::nullopt);
   void connectUDPSockets();
+#ifdef HAVE_XSK
+  void addXSKDestination(int fd);
+  void removeXSKDestination(int fd);
+#endif /* HAVE_XSK */
 
   std::thread tid;
   std::mutex connectLock;
   std::condition_variable d_connectedWait;
+#ifdef HAVE_XSK
+  SharedLockGuarded<std::vector<ComboAddress>> d_socketSourceAddresses;
+#endif
   std::atomic_flag threadStarted;
   uint8_t consecutiveSuccessfulChecks{0};
   bool d_stopped{false};
@@ -979,16 +989,8 @@ public:
   std::optional<InternalQueryState> getState(uint16_t id);
 
 #ifdef HAVE_XSK
-  void registerXsk(std::shared_ptr<XskSocket>& xsk)
-  {
-    xskInfo = XskWorker::create();
-    if (d_config.sourceAddr.sin4.sin_family == 0) {
-      throw runtime_error("invalid source addr");
-    }
-    xsk->addWorker(xskInfo, d_config.sourceAddr);
-    d_config.sourceMACAddr = xsk->source;
-    xskInfo->sharedEmptyFrameOffset = xsk->sharedEmptyFrameOffset;
-  }
+  void registerXsk(std::shared_ptr<XskSocket>& xsk);
+  [[nodiscard]] ComboAddress pickSourceAddressForSending();
 #endif /* HAVE_XSK */
 
   dnsdist::Protocol getProtocol() const
@@ -1194,3 +1196,10 @@ ssize_t udpClientSendRequestToBackend(const std::shared_ptr<DownstreamState>& ss
 bool sendUDPResponse(int origFD, const PacketBuffer& response, const int delayMsec, const ComboAddress& origDest, const ComboAddress& origRemote);
 void handleResponseSent(const DNSName& qname, const QType& qtype, double udiff, const ComboAddress& client, const ComboAddress& backend, unsigned int size, const dnsheader& cleartextDH, dnsdist::Protocol outgoingProtocol, dnsdist::Protocol incomingProtocol, bool fromBackend);
 void handleResponseSent(const InternalQueryState& ids, double udiff, const ComboAddress& client, const ComboAddress& backend, unsigned int size, const dnsheader& cleartextDH, dnsdist::Protocol outgoingProtocol, bool fromBackend);
+
+#ifdef HAVE_XSK
+namespace dnsdist::xsk
+{
+void responderThread(std::shared_ptr<DownstreamState> dss);
+}
+#endif /* HAVE_XSK */
index bd7592545a49324d2e681f89bddbcd042ea3cc35..02cfea5c7d3036e8851c19bd5c8b6e509a64ec04 100644 (file)
@@ -19,7 +19,7 @@
  * along with this program; if not, write to the Free Software
  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
  */
-
+#include "config.h"
 #include "dnsdist.hh"
 #include "dnsdist-backoff.hh"
 #include "dnsdist-metrics.hh"
@@ -28,6 +28,7 @@
 #include "dnsdist-rings.hh"
 #include "dnsdist-tcp.hh"
 #include "dolog.hh"
+#include "xsk.hh"
 
 bool DownstreamState::passCrossProtocolQuery(std::unique_ptr<CrossProtocolQuery>&& cpq)
 {
@@ -39,6 +40,36 @@ bool DownstreamState::passCrossProtocolQuery(std::unique_ptr<CrossProtocolQuery>
   return g_tcpclientthreads && g_tcpclientthreads->passCrossProtocolQueryToThread(std::move(cpq));
 }
 
+#ifdef HAVE_XSK
+void DownstreamState::addXSKDestination(int fd)
+{
+  auto socklen = d_config.remote.getSocklen();
+  ComboAddress local;
+  if (getsockname(fd, reinterpret_cast<sockaddr*>(&local), &socklen)) {
+    return;
+  }
+
+  {
+    auto addresses = d_socketSourceAddresses.write_lock();
+    addresses->push_back(local);
+  }
+  XskSocket::addDestinationAddress(local);
+  d_xskSocket->addWorkerRoute(xskInfo, local);
+}
+
+void DownstreamState::removeXSKDestination(int fd)
+{
+  auto socklen = d_config.remote.getSocklen();
+  ComboAddress local;
+  if (getsockname(fd, reinterpret_cast<sockaddr*>(&local), &socklen)) {
+    return;
+  }
+
+  XskSocket::removeDestinationAddress(local);
+  d_xskSocket->removeWorkerRoute(local);
+}
+#endif /* HAVE_XSK */
+
 bool DownstreamState::reconnect(bool initialAttempt)
 {
   std::unique_lock<std::mutex> tl(connectLock, std::try_to_lock);
@@ -52,11 +83,23 @@ bool DownstreamState::reconnect(bool initialAttempt)
   }
 
   connected = false;
+#ifdef HAVE_XSK
+  if (xskInfo != nullptr) {
+    auto addresses = d_socketSourceAddresses.write_lock();
+    addresses->clear();
+  }
+#endif /* HAVE_XSK */
+
   for (auto& fd : sockets) {
     if (fd != -1) {
       if (sockets.size() > 1) {
         (*mplexer.lock())->removeReadFD(fd);
       }
+#ifdef HAVE_XSK
+      if (xskInfo != nullptr) {
+        removeXSKDestination(fd);
+      }
+#endif /* HAVE_XSK */
       /* shutdown() is needed to wake up recv() in the responderThread */
       shutdown(fd, SHUT_RDWR);
       close(fd);
@@ -87,6 +130,11 @@ bool DownstreamState::reconnect(bool initialAttempt)
       if (sockets.size() > 1) {
         (*mplexer.lock())->addReadFD(fd, [](int, boost::any) {});
       }
+#ifdef HAVE_XSK
+      if (xskInfo != nullptr) {
+        addXSKDestination(fd);
+      }
+#endif /* HAVE_XSK */
       connected = true;
     }
     catch (const std::runtime_error& error) {
@@ -100,8 +148,19 @@ bool DownstreamState::reconnect(bool initialAttempt)
 
   /* if at least one (re-)connection failed, close all sockets */
   if (!connected) {
+#ifdef HAVE_XSK
+    if (xskInfo != nullptr) {
+      auto addresses = d_socketSourceAddresses.write_lock();
+      addresses->clear();
+    }
+#endif /* HAVE_XSK */
     for (auto& fd : sockets) {
       if (fd != -1) {
+#ifdef HAVE_XSK
+        if (xskInfo != nullptr) {
+          removeXSKDestination(fd);
+        }
+#endif /* HAVE_XSK */
         if (sockets.size() > 1) {
           try {
             (*mplexer.lock())->removeReadFD(fd);
@@ -268,7 +327,16 @@ DownstreamState::DownstreamState(DownstreamState::Config&& config, std::shared_p
 void DownstreamState::start()
 {
   if (connected && !threadStarted.test_and_set()) {
+#ifdef HAVE_XSK
+    if (xskInfo != nullptr) {
+      tid = std::thread(dnsdist::xsk::responderThread, shared_from_this());
+    }
+    else {
+      tid = std::thread(responderThread, shared_from_this());
+    }
+#else
     tid = std::thread(responderThread, shared_from_this());
+#endif /* HAVE_XSK */
 
     if (!d_config.d_cpus.empty()) {
       mapThreadToCPUList(tid.native_handle(), d_config.d_cpus);
@@ -797,6 +865,46 @@ void DownstreamState::submitHealthCheckResult(bool initial, bool newResult)
   }
 }
 
+#ifdef HAVE_XSK
+[[nodiscard]] ComboAddress DownstreamState::pickSourceAddressForSending()
+{
+  if (!connected) {
+    waitUntilConnected();
+  }
+
+  auto addresses = d_socketSourceAddresses.read_lock();
+  auto numberOfAddresses = addresses->size();
+  if (numberOfAddresses == 0) {
+    throw std::runtime_error("No source address available for sending XSK data to backend " + getNameWithAddr());
+  }
+  size_t idx = dnsdist::getRandomValue(numberOfAddresses);
+  return (*addresses)[idx % numberOfAddresses];
+}
+
+void DownstreamState::registerXsk(std::shared_ptr<XskSocket>& xsk)
+{
+  d_xskSocket = xsk;
+
+  if (d_config.sourceAddr.sin4.sin_family == 0 || (IsAnyAddress(d_config.sourceAddr))) {
+    const auto& ifName = xsk->getInterfaceName();
+    auto addresses = getListOfAddressesOfNetworkInterface(ifName);
+    if (addresses.empty()) {
+      throw std::runtime_error("Unable to get source address from interface " + ifName);
+    }
+
+    if (addresses.size() > 1) {
+      warnlog("More than one address configured on interface %s, picking the first one (%s) for XSK. Set the 'source' parameter on 'newServer' if you want to use a different address.", ifName, addresses.at(0).toString());
+    }
+    d_config.sourceAddr = addresses.at(0);
+  }
+  xskInfo = XskWorker::create();
+  xsk->addWorker(xskInfo);
+  reconnect(false);
+  d_config.sourceMACAddr = xsk->getSourceMACAddress();
+  xskInfo->sharedEmptyFrameOffset = xsk->sharedEmptyFrameOffset;
+}
+#endif /* HAVE_XSK */
+
 size_t ServerPool::countServers(bool upOnly)
 {
   std::shared_ptr<const ServerPolicy::NumberedServerVector> servers = nullptr;
index 813eeac44247ff246e3d888a77c811f549cf5b54..e05c0b56b4af626e38c926317ff6e0681cbb13a8 100644 (file)
@@ -74,6 +74,13 @@ void responderThread(std::shared_ptr<DownstreamState> dss)
 {
 }
 
+namespace dnsdist::xsk
+{
+void responderThread(std::shared_ptr<DownstreamState> dss)
+{
+}
+}
+
 string g_outputBuffer;
 std::atomic<bool> g_configurationDone{false};
 
@@ -181,7 +188,7 @@ BOOST_AUTO_TEST_CASE(test_firstAvailableWithOrderAndQPS) {
   servers.push_back({ 2, std::make_shared<DownstreamState>(ComboAddress("192.0.2.2:53")) });
   /* Second server has a higher order, so most queries should be routed to the first (remember that
      we need to keep them ordered!).
-     However the first server has a QPS limit at 10 qps, so any query above that should be routed 
+     However the first server has a QPS limit at 10 qps, so any query above that should be routed
      to the second server. */
   servers.at(0).second->d_config.order = 1;
   servers.at(1).second->d_config.order = 2;
index 371587da5254a6629d2240eb85ca17c56a699182..15d9f14ec5f40a0f6a6948d9a3166b9ff33b85cb 100644 (file)
@@ -104,8 +104,8 @@ int XskSocket::firstTimeout()
   return res;
 }
 
-XskSocket::XskSocket(size_t frameNum_, const std::string& ifName_, uint32_t queue_id, const std::string& xskMapPath, const std::string& poolName_) :
-  frameNum(frameNum_), ifName(ifName_), poolName(poolName_), socket(nullptr, xsk_socket__delete), sharedEmptyFrameOffset(std::make_shared<LockGuarded<vector<uint64_t>>>())
+XskSocket::XskSocket(size_t frameNum_, const std::string& ifName_, uint32_t queue_id, const std::string& xskMapPath) :
+  frameNum(frameNum_), ifName(ifName_), socket(nullptr, xsk_socket__delete), sharedEmptyFrameOffset(std::make_shared<LockGuarded<vector<uint64_t>>>())
 {
   if (!isPowOfTwo(frameNum_) || !isPowOfTwo(frameSize)
       || !isPowOfTwo(fqCapacity) || !isPowOfTwo(cqCapacity) || !isPowOfTwo(rxCapacity) || !isPowOfTwo(txCapacity)) {
@@ -175,6 +175,113 @@ XskSocket::XskSocket(size_t frameNum_, const std::string& ifName_, uint32_t queu
   }
 }
 
+// see xdp.h in contrib/
+
+struct IPv4AndPort
+{
+  uint32_t addr;
+  uint16_t port;
+};
+struct IPv6AndPort
+{
+  struct in6_addr addr;
+  uint16_t port;
+};
+
+static void clearDestinationMap(bool v6)
+{
+  const std::string mapPath = !v6 ? "/sys/fs/bpf/dnsdist/xsk-destinations-v4" : "/sys/fs/bpf/dnsdist/xsk-destinations-v6";
+
+  const auto destMapFd = FDWrapper(bpf_obj_get(mapPath.c_str()));
+  if (destMapFd.getHandle() < 0) {
+    throw std::runtime_error("Error getting the XSK destination addresses map path '" + mapPath + "'");
+  }
+
+  if (!v6) {
+    IPv4AndPort prevKey{};
+    IPv4AndPort key{};
+    while (bpf_map_get_next_key(destMapFd.getHandle(), &prevKey, &key) == 0) {
+      bpf_map_delete_elem(destMapFd.getHandle(), &key);
+      prevKey = key;
+    }
+  }
+  else {
+    IPv6AndPort prevKey{};
+    IPv6AndPort key{};
+    while (bpf_map_get_next_key(destMapFd.getHandle(), &prevKey, &key) == 0) {
+      bpf_map_delete_elem(destMapFd.getHandle(), &key);
+      prevKey = key;
+    }
+  }
+}
+
+void XskSocket::clearDestinationAddresses()
+{
+  clearDestinationMap(false);
+  clearDestinationMap(true);
+}
+
+void XskSocket::addDestinationAddress(const ComboAddress& destination)
+{
+  // see xdp.h in contrib/
+
+  const std::string mapPath = destination.isIPv4() ? "/sys/fs/bpf/dnsdist/xsk-destinations-v4" : "/sys/fs/bpf/dnsdist/xsk-destinations-v6";
+  //if (!s_destinationAddressesMap) {
+  //  throw std::runtime_error("The path of the XSK (AF_XDP) destination addresses map has not been set! Please consider using setXSKDestinationAddressesMapPath().");
+  //}
+
+  const auto destMapFd = FDWrapper(bpf_obj_get(mapPath.c_str()));
+  if (destMapFd.getHandle() < 0) {
+    throw std::runtime_error("Error getting the XSK destination addresses map path '" + mapPath + "'");
+  }
+
+  bool value = true;
+  if (destination.isIPv4()) {
+    IPv4AndPort key{};
+    key.addr = destination.sin4.sin_addr.s_addr;
+    key.port = destination.sin4.sin_port;
+    auto ret = bpf_map_update_elem(destMapFd.getHandle(), &key, &value, 0);
+    if (ret) {
+      throw std::runtime_error("Error inserting into xsk_map '" + mapPath + "': " + std::to_string(ret));
+    }
+  }
+  else {
+    IPv6AndPort key{};
+    key.addr = destination.sin6.sin6_addr;
+    key.port = destination.sin6.sin6_port;
+    auto ret = bpf_map_update_elem(destMapFd.getHandle(), &key, &value, 0);
+    if (ret) {
+      throw std::runtime_error("Error inserting into XSK destination addresses map '" + mapPath + "': " + std::to_string(ret));
+    }
+  }
+}
+
+void XskSocket::removeDestinationAddress(const ComboAddress& destination)
+{
+  const std::string mapPath = destination.isIPv4() ? "/sys/fs/bpf/dnsdist/xsk-destinations-v4" : "/sys/fs/bpf/dnsdist/xsk-destinations-v6";
+  //if (!s_destinationAddressesMap) {
+  //  throw std::runtime_error("The path of the XSK (AF_XDP) destination addresses map has not been set! Please consider using setXSKDestinationAddressesMapPath().");
+  //}
+
+  const auto destMapFd = FDWrapper(bpf_obj_get(mapPath.c_str()));
+  if (destMapFd.getHandle() < 0) {
+    throw std::runtime_error("Error getting the XSK destination addresses map path '" + mapPath + "'");
+  }
+
+  if (destination.isIPv4()) {
+    IPv4AndPort key{};
+    key.addr = destination.sin4.sin_addr.s_addr;
+    key.port = destination.sin4.sin_port;
+    bpf_map_delete_elem(destMapFd.getHandle(), &key);
+  }
+  else {
+    IPv6AndPort key{};
+    key.addr = destination.sin6.sin6_addr;
+    key.port = destination.sin6.sin6_port;
+    bpf_map_delete_elem(destMapFd.getHandle(), &key);
+  }
+}
+
 void XskSocket::fillFq(uint32_t fillSize) noexcept
 {
   {
@@ -412,14 +519,17 @@ XskSocket::XskUmem::~XskUmem()
 [[nodiscard]] ethhdr XskPacket::getEthernetHeader() const noexcept
 {
   ethhdr ethHeader{};
-  assert(frameLength >= sizeof(ethHeader));
-  memcpy(&ethHeader, frame, sizeof(ethHeader));
+  if (frameLength >= sizeof(ethHeader)) {
+    memcpy(&ethHeader, frame, sizeof(ethHeader));
+  }
   return ethHeader;
 }
 
 void XskPacket::setEthernetHeader(const ethhdr& ethHeader) noexcept
 {
-  assert(frameLength >= sizeof(ethHeader));
+  if (frameLength < sizeof(ethHeader)) {
+    frameLength = sizeof(ethHeader);
+  }
   memcpy(frame, &ethHeader, sizeof(ethHeader));
 }
 
@@ -631,8 +741,8 @@ bool XskPacket::isIPV6() const noexcept
   return v6;
 }
 
-XskPacket::XskPacket(uint8_t* frame_, size_t dataSize, size_t frameSize) :
-  frame(frame_), frameLength(dataSize), frameSize(frameSize - XDP_PACKET_HEADROOM)
+XskPacket::XskPacket(uint8_t* frame_, size_t dataSize, size_t frameSize_) :
+  frame(frame_), frameLength(dataSize), frameSize(frameSize_ - XDP_PACKET_HEADROOM)
 {
 }
 
@@ -757,7 +867,7 @@ void XskPacket::rewrite() noexcept
     ipHeader.protocol = IPPROTO_UDP;
     udpHeader.source = from.sin4.sin_port;
     udpHeader.dest = to.sin4.sin_port;
-    udpHeader.len = htons(getDataSize());
+    udpHeader.len = htons(getDataSize() + sizeof(udpHeader));
     udpHeader.check = 0;
     /* needed to get the correct checksum */
     setIPv4Header(ipHeader);
@@ -963,32 +1073,27 @@ std::shared_ptr<XskWorker> XskWorker::create()
   return std::make_shared<XskWorker>();
 }
 
-void XskSocket::addWorker(std::shared_ptr<XskWorker> s, const ComboAddress& dest)
+void XskSocket::addWorker(std::shared_ptr<XskWorker> worker)
 {
-  extern std::atomic<bool> g_configurationDone;
-  if (g_configurationDone) {
-    throw runtime_error("Adding a server with xsk at runtime is not supported");
-  }
-  s->poolName = poolName;
-  const auto socketWaker = s->xskSocketWaker.getHandle();
-  const auto workerWaker = s->workerWaker.getHandle();
-  const auto& socketWakerIdx = workers.get<0>();
-  if (socketWakerIdx.contains(socketWaker)) {
-    throw runtime_error("Server already exist");
-  }
-  s->umemBufBase = umem.bufBase;
-  workers.insert(XskRouteInfo{
-    .worker = std::move(s),
-    .dest = dest,
-    .xskSocketWaker = socketWaker,
-    .workerWaker = workerWaker,
-  });
+  const auto socketWaker = worker->xskSocketWaker.getHandle();
+  worker->umemBufBase = umem.bufBase;
+  d_workers.insert({socketWaker, std::move(worker)});
   fds.push_back(pollfd{
     .fd = socketWaker,
     .events = POLLIN,
     .revents = 0});
 };
 
+void XskSocket::addWorkerRoute(const std::shared_ptr<XskWorker>& worker, const ComboAddress& dest)
+{
+  d_workerRoutes.lock()->insert({dest, worker});
+}
+
+void XskSocket::removeWorkerRoute(const ComboAddress& dest)
+{
+  d_workerRoutes.lock()->erase(dest);
+}
+
 uint64_t XskWorker::frameOffset(const XskPacket& packet) const noexcept
 {
   return packet.getFrameOffsetFrom(umemBufBase);
index dc9f285751b2a09b129f30992f220dd32856f148..702855a4739f32e8b38829a012a6e5102f1db339 100644 (file)
@@ -71,13 +71,6 @@ using XskPacketPtr = std::unique_ptr<XskPacket>;
 
 class XskSocket
 {
-  struct XskRouteInfo
-  {
-    std::shared_ptr<XskWorker> worker;
-    ComboAddress dest;
-    int xskSocketWaker;
-    int workerWaker;
-  };
   struct XskUmem
   {
     xsk_umem* umem{nullptr};
@@ -87,12 +80,11 @@ class XskSocket
     ~XskUmem();
     XskUmem() = default;
   };
-  using WorkerContainer = boost::multi_index_container<
-    XskRouteInfo,
-    boost::multi_index::indexed_by<
-      boost::multi_index::hashed_unique<boost::multi_index::member<XskRouteInfo, int, &XskRouteInfo::xskSocketWaker>>,
-      boost::multi_index::hashed_unique<boost::multi_index::member<XskRouteInfo, ComboAddress, &XskRouteInfo::dest>, ComboAddress::addressPortOnlyHash>>>;
-  WorkerContainer workers;
+  using WorkerContainer = std::unordered_map<int, std::shared_ptr<XskWorker>>;
+  WorkerContainer d_workers;
+  using WorkerRoutesMap = std::unordered_map<ComboAddress, std::shared_ptr<XskWorker>, ComboAddress::addressPortOnlyHash>;
+  // it might be better to move to a StateHolder for performance
+  LockGuarded<WorkerRoutesMap> d_workerRoutes;
   // number of frames to keep in sharedEmptyFrameOffset
   static constexpr size_t holdThreshold = 256;
   // number of frames to insert into the fill queue
@@ -102,8 +94,8 @@ class XskSocket
   const size_t frameNum;
   // responses that have been delayed
   std::priority_queue<XskPacket> waitForDelay;
+  MACAddr source;
   const std::string ifName;
-  const std::string poolName;
   // AF_XDP socket then worker waker sockets
   vector<pollfd> fds;
   // list of frames, aka (indexes of) umem entries that can be reused to fill fq,
@@ -135,14 +127,16 @@ class XskSocket
   void getMACFromIfName();
 
 public:
+  static void clearDestinationAddresses();
+  static void addDestinationAddress(const ComboAddress& destination);
+  static void removeDestinationAddress(const ComboAddress& destination);
   static constexpr size_t getFrameSize()
   {
     return frameSize;
   }
   // list of free umem entries that can be reused
   std::shared_ptr<LockGuarded<vector<uint64_t>>> sharedEmptyFrameOffset;
-  XskSocket(size_t frameNum, const std::string& ifName, uint32_t queue_id, const std::string& xskMapPath, const std::string& poolName_);
-  MACAddr source;
+  XskSocket(size_t frameNum, const std::string& ifName, uint32_t queue_id, const std::string& xskMapPath);
   [[nodiscard]] int xskFd() const noexcept;
   // wait until one event has occurred
   [[nodiscard]] int wait(int timeout);
@@ -150,17 +144,36 @@ public:
   void send(std::vector<XskPacket>& packets);
   // look at incoming packets in rx, return them if parsing succeeeded
   [[nodiscard]] std::vector<XskPacket> recv(uint32_t recvSizeMax, uint32_t* failedCount);
-  void addWorker(std::shared_ptr<XskWorker> s, const ComboAddress& dest);
+  void addWorker(std::shared_ptr<XskWorker> s);
+  void addWorkerRoute(const std::shared_ptr<XskWorker>& worker, const ComboAddress& dest);
+  void removeWorkerRoute(const ComboAddress& dest);
   [[nodiscard]] std::string getMetrics() const;
   void markAsFree(XskPacket&& packet);
-  [[nodiscard]] WorkerContainer& getWorkers()
+  [[nodiscard]] const std::shared_ptr<XskWorker>& getWorkerByDescriptor(int desc) const
+  {
+    return d_workers.at(desc);
+  }
+  [[nodiscard]] std::shared_ptr<XskWorker> getWorkerByDestination(const ComboAddress& destination)
   {
-    return workers;
+    auto routes = d_workerRoutes.lock();
+    auto workerIt = routes->find(destination);
+    if (workerIt == routes->end()) {
+      return nullptr;
+    }
+    return workerIt->second;
   }
   [[nodiscard]] const std::vector<pollfd>& getDescriptors() const
   {
     return fds;
   }
+  [[nodiscard]] MACAddr getSourceMACAddress() const
+  {
+    return source;
+  }
+  [[nodiscard]] const std::string& getInterfaceName() const
+  {
+    return ifName;
+  }
   // pick ups available frames from uniqueEmptyFrameOffset
   // insert entries from uniqueEmptyFrameOffset into fq
   void fillFq(uint32_t fillSize = fillThreshold) noexcept;
@@ -291,7 +304,6 @@ public:
   std::shared_ptr<LockGuarded<vector<uint64_t>>> sharedEmptyFrameOffset;
   // list of frames that we own, used to generate new packets (health-check)
   vector<uint64_t> uniqueEmptyFrameOffset;
-  std::string poolName;
   const size_t frameSize{XskSocket::getFrameSize()};
   FDWrapper workerWaker;
   FDWrapper xskSocketWaker;