]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Add port range support for dynamic blocks
authorRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 4 Oct 2021 16:00:23 +0000 (18:00 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 21 Oct 2021 08:08:41 +0000 (10:08 +0200)
.github/actions/spell-check/expect.txt
pdns/dnsdist-dynblocks.hh
pdns/dnsdist-lua-inspection.cc
pdns/dnsdist-lua.cc
pdns/dnsdist-lua.hh
pdns/dnsdist.cc
pdns/dnsdist.hh
pdns/dnsdistdist/dnsdist-dynblocks.cc
pdns/dnsdistdist/docs/reference/config.rst
pdns/dnsdistdist/test-dnsdistdynblocks_hh.cc
pdns/iputils.hh

index 9110acdcde421f0e60cb9e2838724af2cce8ba81..451c21768e9736f41f565410ef668eca2bdc45bb 100644 (file)
@@ -217,6 +217,7 @@ certusage
 cfea
 CFLAGS
 cgi
+CGNAT
 changelog
 changeme
 changeset
index c8a7aaa8c60ddd31b0c6ff120578d389d89265ef..2a817e13bfe799854a695db178e806825beea5c8 100644 (file)
@@ -212,7 +212,7 @@ private:
     double d_warningRatio{0.0};
   };
 
-  typedef std::unordered_map<Netmask, Counts, Netmask::hash> counts_t;
+  typedef std::unordered_map<AddressAndPortRange, Counts, AddressAndPortRange::hash> counts_t;
 
 public:
   DynBlockRulesGroup()
@@ -262,10 +262,11 @@ public:
     d_smtVisitorFFI = visitor;
   }
 
-  void setMasks(uint8_t v4, uint8_t v6)
+  void setMasks(uint8_t v4, uint8_t v6, uint8_t port)
   {
     d_v4Mask = v4;
     d_v6Mask = v6;
+    d_portMask = port;
   }
 
   void apply()
@@ -336,15 +337,15 @@ private:
 
   bool checkIfQueryTypeMatches(const Rings::Query& query);
   bool checkIfResponseCodeMatches(const Rings::Response& response);
-  void addOrRefreshBlock(boost::optional<NetmaskTree<DynBlock> >& blocks, const struct timespec& now, const Netmask& requestor, const DynBlockRule& rule, bool& updated, bool warning);
+  void addOrRefreshBlock(boost::optional<NetmaskTree<DynBlock, AddressAndPortRange> >& blocks, const struct timespec& now, const AddressAndPortRange& requestor, const DynBlockRule& rule, bool& updated, bool warning);
   void addOrRefreshBlockSMT(SuffixMatchTree<DynBlock>& blocks, const struct timespec& now, const DNSName& name, const DynBlockRule& rule, bool& updated);
 
-  void addBlock(boost::optional<NetmaskTree<DynBlock> >& blocks, const struct timespec& now, const Netmask& requestor, const DynBlockRule& rule, bool& updated)
+  void addBlock(boost::optional<NetmaskTree<DynBlock, AddressAndPortRange> >& blocks, const struct timespec& now, const AddressAndPortRange& requestor, const DynBlockRule& rule, bool& updated)
   {
     addOrRefreshBlock(blocks, now, requestor, rule, updated, false);
   }
 
-  void handleWarning(boost::optional<NetmaskTree<DynBlock> >& blocks, const struct timespec& now, const Netmask& requestor, const DynBlockRule& rule, bool& updated)
+  void handleWarning(boost::optional<NetmaskTree<DynBlock, AddressAndPortRange> >& blocks, const struct timespec& now, const AddressAndPortRange& requestor, const DynBlockRule& rule, bool& updated)
   {
     addOrRefreshBlock(blocks, now, requestor, rule, updated, true);
   }
@@ -384,6 +385,7 @@ private:
   dnsdist_ffi_stat_node_visitor_t d_smtVisitorFFI;
   uint8_t d_v6Mask{128};
   uint8_t d_v4Mask{32};
+  uint8_t d_portMask{0};
   bool d_beQuiet{false};
 };
 
@@ -393,11 +395,11 @@ public:
   static void run();
 
   /* return the (cached) number of hits per second for the top offenders, averaged over 60s */
-  static std::map<std::string, std::list<std::pair<Netmask, unsigned int>>> getHitsForTopNetmasks();
+  static std::map<std::string, std::list<std::pair<AddressAndPortRange, unsigned int>>> getHitsForTopNetmasks();
   static std::map<std::string, std::list<std::pair<DNSName, unsigned int>>> getHitsForTopSuffixes();
 
   /* get the the top offenders based on the current value of the counters */
-  static std::map<std::string, std::list<std::pair<Netmask, unsigned int>>> getTopNetmasks(size_t topN);
+  static std::map<std::string, std::list<std::pair<AddressAndPortRange, unsigned int>>> getTopNetmasks(size_t topN);
   static std::map<std::string, std::list<std::pair<DNSName, unsigned int>>> getTopSuffixes(size_t topN);
   static void purgeExpired(const struct timespec& now);
 
@@ -409,13 +411,13 @@ private:
 
   struct MetricsSnapshot
   {
-    std::map<std::string, std::list<std::pair<Netmask, unsigned int>>> nmgData;
+    std::map<std::string, std::list<std::pair<AddressAndPortRange, unsigned int>>> nmgData;
     std::map<std::string, std::list<std::pair<DNSName, unsigned int>>> smtData;
   };
 
   struct Tops
   {
-    std::map<std::string, std::list<std::pair<Netmask, unsigned int>>> topNMGsByReason;
+    std::map<std::string, std::list<std::pair<AddressAndPortRange, unsigned int>>> topNMGsByReason;
     std::map<std::string, std::list<std::pair<DNSName, unsigned int>>> topSMTsByReason;
   };
 
@@ -425,110 +427,3 @@ private:
   static std::list<MetricsSnapshot> s_metricsData;
   static size_t s_topN;
 };
-
-class AddressAndPortRange
-{
-public:
-  AddressAndPortRange(): d_addrMask(0), d_portMask(0)
-  {
-    d_addr.sin4.sin_family = 0; // disable this doing anything useful
-    d_addr.sin4.sin_port = 0; // this guarantees d_network compares identical
-  }
-
-  AddressAndPortRange(ComboAddress ca, uint8_t addrMask, uint8_t portMask): d_addr(std::move(ca)), d_addrMask(addrMask), d_portMask(portMask)
-  {
-    cerr<<"creating a address and port range "<<ca.toStringWithPort()<<"/"<<std::to_string(addrMask)<<" "<<std::to_string(portMask)<<endl;
-    if (d_addrMask < d_addr.getBits()) {
-      uint16_t port = d_addr.getPort();
-      d_addr = Netmask(d_addr, d_addrMask).getMaskedNetwork();
-      d_addr.setPort(port);
-    }
-  }
-
-  uint8_t getFullBits() const
-  {
-    cerr<<"in getFullbits returning "<<(d_addr.getBits() + 16)<<endl;
-    return d_addr.getBits() + 16;
-  }
-
-  uint8_t getBits() const
-  {
-    if (d_addrMask < d_addr.getBits()) {
-      cerr<<"in getBits returning "<<std::to_string(d_addrMask)<<endl;
-      return d_addrMask;
-    }
-
-    cerr<<"in getBits returning "<<std::to_string(d_addr.getBits() + d_portMask)<<endl;
-    return d_addr.getBits() + d_portMask;
-  }
-
-  /** Get the value of the bit at the provided bit index. When the index >= 0,
-      the index is relative to the LSB starting at index zero. When the index < 0,
-      the index is relative to the MSB starting at index -1 and counting down.
-  */
-  bool getBit(int index) const
-  {
-    cerr<<"in getBit "<<index<<" for "<<d_addr.toStringWithPort()<<endl;
-    if (index >= getFullBits()) {
-      cerr<<"flse 1"<<endl;
-      return false;
-    }
-    if (index < 0) {
-      index = getFullBits() + index;
-      cerr<<"normalized index to "<<index<<endl;
-    }
-
-    if (index < 16) {
-      /* we are into the port bits */
-      uint16_t port = d_addr.getPort();
-      cerr<<"return (2) "<<((port & (1U<<index)) != 0x0000)<<endl;
-      return ((port & (1U<<index)) != 0x0000);
-    }
-
-    index -= 16;
-
-    cerr<<"return (d_addr) "<<d_addr.getBit(index)<<endl;
-    return d_addr.getBit(index);
-  }
-
-  bool isIPv4() const
-  {
-    return d_addr.isIPv4();
-  }
-
-  bool isIPv6() const
-  {
-    return d_addr.isIPv6();
-  }
-
-  AddressAndPortRange getNormalized() const
-  {
-    cerr<<"in getNormalized"<<endl;
-    return AddressAndPortRange(d_addr, d_addrMask, d_portMask);
-  }
-
-  AddressAndPortRange getSuper(uint8_t bits) const
-  {
-    cerr<<"in getSuper("<<std::to_string(bits)<<")"<<endl;
-    if (bits <= d_addrMask) {
-      return AddressAndPortRange(d_addr, bits, 0);
-    }
-    if (bits <= d_addrMask + d_portMask) {
-      return AddressAndPortRange(d_addr, d_addrMask, d_portMask - (bits - d_addrMask));
-    }
-
-    return AddressAndPortRange(d_addr, d_addrMask, d_portMask);
-  }
-
-  const ComboAddress& getNetwork() const
-  {
-    cerr<<"in getNetwork"<<endl;
-    return d_addr;
-  }
-
-private:
-  ComboAddress d_addr;
-  uint8_t d_addrMask;
-  uint8_t d_portMask;
-};
-
index dc210c9e229245c48aac2ca21298f5b82b489bcc..0677bf684d9944372b340ea9c2f2576db4c2bb92 100644 (file)
@@ -796,9 +796,12 @@ void setupLuaInspection(LuaContext& luaCtx)
         group->setQTypeRate(qtype, rate, warningRate ? *warningRate : 0, seconds, reason, blockDuration, action ? *action : DNSAction::Action::None);
       }
     });
-  luaCtx.registerFunction<void(std::shared_ptr<DynBlockRulesGroup>::*)(uint8_t, uint8_t)>("setMasks", [](std::shared_ptr<DynBlockRulesGroup>& group, uint8_t v4, uint8_t v6) {
+  luaCtx.registerFunction<void(std::shared_ptr<DynBlockRulesGroup>::*)(uint8_t, uint8_t, uint8_t)>("setMasks", [](std::shared_ptr<DynBlockRulesGroup>& group, uint8_t v4, uint8_t v6, uint8_t port) {
       if (group) {
-        group->setMasks(v4, v6);
+        if (port > 0 && v4 != 32) {
+          throw std::runtime_error("Setting a non-zero port mask for Dynamic Blocks while only considering parts of IPv4 addresses does not make sense");
+        }
+        group->setMasks(v4, v6, port);
       }
     });
   luaCtx.registerFunction<void(std::shared_ptr<DynBlockRulesGroup>::*)(boost::variant<std::string, std::vector<std::pair<int, std::string>>, NetmaskGroup>)>("excludeRange", [](std::shared_ptr<DynBlockRulesGroup>& group, boost::variant<std::string, std::vector<std::pair<int, std::string>>, NetmaskGroup> ranges) {
index 74cf79ae5ba9bdf4cd7e0e9a2f46fbea76867340..7e4b0079ca6952ee0485bc26bc106cae838f427a 100644 (file)
@@ -1392,26 +1392,33 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck)
                           auto slow = g_dynblockNMG.getCopy();
                           struct timespec until, now;
                           gettime(&now);
-                          until=now;
+                          until = now;
                            int actualSeconds = seconds ? *seconds : 10;
                           until.tv_sec += actualSeconds;
-                          for(const auto& capair : m) {
+                          for (const auto& capair : m) {
                             unsigned int count = 0;
-                             auto got = slow.lookup(Netmask(capair.first));
-                             bool expired=false;
-                            if(got) {
-                              if(until < got->second.until) // had a longer policy
+                             AddressAndPortRange requestor(capair.first, capair.first.isIPv4() ? 32 : 128, 0);
+                             auto got = slow.lookup(requestor);
+                             bool expired = false;
+                            if (got) {
+                              if (until < got->second.until) {
+                                 // had a longer policy
                                 continue;
-                              if(now < got->second.until) // only inherit count on fresh query we are extending
-                                count=got->second.blocks;
-                               else
-                                 expired=true;
+                               }
+                              if (now < got->second.until) {
+                                 // only inherit count on fresh query we are extending
+                                count = got->second.blocks;
+                               }
+                               else {
+                                 expired = true;
+                               }
                             }
-                            DynBlock db{msg,until,DNSName(),(action ? *action : DNSAction::Action::None)};
-                            db.blocks=count;
-                             if(!got || expired)
+                            DynBlock db{msg, until, DNSName(), (action ? *action : DNSAction::Action::None)};
+                            db.blocks = count;
+                             if (!got || expired) {
                                warnlog("Inserting dynamic block for %s for %d seconds: %s", capair.first.toString(), actualSeconds, msg);
-                            slow.insert(Netmask(capair.first)).second=db;
+                             }
+                            slow.insert(requestor).second = db;
                           }
                           g_dynblockNMG.setState(slow);
                         });
index 28cd00da73969290a3724db26408ccb1b7a9a30f..b8230ebad60b0693256277a9e409a3b2ad37f2ce 100644 (file)
@@ -92,7 +92,7 @@ std::shared_ptr<DNSRule> makeRule(const luadnsrule_t& var);
 typedef std::unordered_map<std::string, boost::variant<std::string> > luaruleparams_t;
 void parseRuleParams(boost::optional<luaruleparams_t> params, boost::uuids::uuid& uuid, std::string& name, uint64_t& creationOrder);
 
-typedef NetmaskTree<DynBlock> nmts_t;
+typedef NetmaskTree<DynBlock, AddressAndPortRange> nmts_t;
 
 vector<std::function<void(void)>> setupLua(LuaContext& luaCtx, bool client, bool configCheck, const std::string& config);
 void setupLuaActions(LuaContext& luaCtx);
index a82247f4c254410594a5018b0fa2f3f4fc8c5801..10064e9bc6097ee18d1e130021cba11bdca09538 100644 (file)
@@ -132,7 +132,7 @@ Rings g_rings;
 QueryCount g_qcount;
 
 GlobalStateHolder<servers_t> g_dstates;
-GlobalStateHolder<NetmaskTree<DynBlock>> g_dynblockNMG;
+GlobalStateHolder<NetmaskTree<DynBlock, AddressAndPortRange>> g_dynblockNMG;
 GlobalStateHolder<SuffixMatchTree<DynBlock>> g_dynblockSMT;
 DNSAction::Action g_dynBlockAction = DNSAction::Action::Drop;
 int g_udpTimeout{2};
@@ -861,13 +861,13 @@ static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, const stru
     }
   }
 
-  if(auto got = holders.dynNMGBlock->lookup(*dq.remote)) {
+  if (auto got = holders.dynNMGBlock->lookup(AddressAndPortRange(*dq.remote, dq.remote->isIPv4() ? 32 : 128, 16))) {
     auto updateBlockStats = [&got]() {
       ++g_stats.dynBlocked;
       got->second.blocks++;
     };
 
-    if(now < got->second.until) {
+    if (now < got->second.until) {
       DNSAction::Action action = got->second.action;
       if (action == DNSAction::Action::None) {
         action = g_dynBlockAction;
@@ -921,13 +921,13 @@ static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, const stru
     }
   }
 
-  if(auto got = holders.dynSMTBlock->lookup(*dq.qname)) {
+  if (auto got = holders.dynSMTBlock->lookup(*dq.qname)) {
     auto updateBlockStats = [&got]() {
       ++g_stats.dynBlocked;
       got->blocks++;
     };
 
-    if(now < got->until) {
+    if (now < got->until) {
       DNSAction::Action action = got->action;
       if (action == DNSAction::Action::None) {
         action = g_dynBlockAction;
index 08d44ee8496776cab478b3fa02d5309260131d5b..6d54b94943f2b383bdf22d25cf57dae9a265bc0b 100644 (file)
@@ -323,7 +323,7 @@ struct DynBlock
   bool bpf{false};
 };
 
-extern GlobalStateHolder<NetmaskTree<DynBlock>> g_dynblockNMG;
+extern GlobalStateHolder<NetmaskTree<DynBlock, AddressAndPortRange>> g_dynblockNMG;
 
 extern vector<pair<struct timeval, std::string> > g_confDelta;
 
@@ -1025,7 +1025,7 @@ struct LocalHolders
   LocalStateHolder<vector<DNSDistResponseRuleAction> > cacheHitRespRuleactions;
   LocalStateHolder<vector<DNSDistResponseRuleAction> > selfAnsweredRespRuleactions;
   LocalStateHolder<servers_t> servers;
-  LocalStateHolder<NetmaskTree<DynBlock> > dynNMGBlock;
+  LocalStateHolder<NetmaskTree<DynBlock, AddressAndPortRange> > dynNMGBlock;
   LocalStateHolder<SuffixMatchTree<DynBlock> > dynSMTBlock;
   LocalStateHolder<pools_t> pools;
 };
index 2b73743b9ddb631d01e004cac33cd2e6008663e0..a1230f1017b2b3dfcc2ef81199597febb6b4a649 100644 (file)
@@ -23,7 +23,7 @@ void DynBlockRulesGroup::apply(const struct timespec& now)
     return;
   }
 
-  boost::optional<NetmaskTree<DynBlock> > blocks;
+  boost::optional<NetmaskTree<DynBlock, AddressAndPortRange> > blocks;
   bool updated = false;
 
   for (const auto& entry : counts) {
@@ -174,9 +174,9 @@ bool DynBlockRulesGroup::checkIfResponseCodeMatches(const Rings::Response& respo
   return false;
 }
 
-void DynBlockRulesGroup::addOrRefreshBlock(boost::optional<NetmaskTree<DynBlock> >& blocks, const struct timespec& now, const Netmask& requestor, const DynBlockRule& rule, bool& updated, bool warning)
+void DynBlockRulesGroup::addOrRefreshBlock(boost::optional<NetmaskTree<DynBlock, AddressAndPortRange> >& blocks, const struct timespec& now, const AddressAndPortRange& requestor, const DynBlockRule& rule, bool& updated, bool warning)
 {
-  if (d_excludedSubnets.match(requestor.getMaskedNetwork())) {
+  if (d_excludedSubnets.match(requestor.getNetwork())) {
     /* do not add a block for excluded subnets */
     return;
   }
@@ -187,7 +187,7 @@ void DynBlockRulesGroup::addOrRefreshBlock(boost::optional<NetmaskTree<DynBlock>
   struct timespec until = now;
   until.tv_sec += rule.d_blockDuration;
   unsigned int count = 0;
-  const auto& got = blocks->lookup(requestor.getMaskedNetwork());
+  const auto& got = blocks->lookup(requestor.getNetwork());
   bool expired = false;
   bool wasWarning = false;
   bool bpf = false;
@@ -226,7 +226,7 @@ void DynBlockRulesGroup::addOrRefreshBlock(boost::optional<NetmaskTree<DynBlock>
     if (db.action == DNSAction::Action::Drop && g_defaultBPFFilter &&
         ((requestor.isIPv4() && requestor.getBits() == 32) || (requestor.isIPv6() && requestor.getBits() == 128))) {
       try {
-        g_defaultBPFFilter->block(requestor.getMaskedNetwork());
+        g_defaultBPFFilter->block(requestor.getNetwork());
         bpf = true;
       }
       catch (const std::exception& e) {
@@ -315,7 +315,7 @@ void DynBlockRulesGroup::processQueryRules(counts_t& counts, const struct timesp
       bool typeRuleMatches = checkIfQueryTypeMatches(c);
 
       if (qRateMatches || typeRuleMatches) {
-        auto& entry = counts[Netmask(c.requestor, c.requestor.isIPv4() ? d_v4Mask : d_v6Mask)];
+        auto& entry = counts[AddressAndPortRange(c.requestor, c.requestor.isIPv4() ? d_v4Mask : d_v6Mask, d_portMask)];
         if (qRateMatches) {
           ++entry.queries;
         }
@@ -374,7 +374,7 @@ void DynBlockRulesGroup::processResponseRules(counts_t& counts, StatNode& root,
         continue;
       }
 
-      auto& entry = counts[Netmask(c.requestor, c.requestor.isIPv4() ? d_v4Mask : d_v6Mask)];
+      auto& entry = counts[AddressAndPortRange(c.requestor, c.requestor.isIPv4() ? d_v4Mask : d_v6Mask, d_portMask)];
       ++entry.responses;
 
       bool respRateMatches = d_respRateRule.matches(c.when);
@@ -401,7 +401,7 @@ void DynBlockMaintenance::purgeExpired(const struct timespec& now)
 {
   {
     auto blocks = g_dynblockNMG.getLocal();
-    std::vector<Netmask> toRemove;
+    std::vector<AddressAndPortRange> toRemove;
     for (const auto& entry : *blocks) {
       if (!(now < entry.second.until)) {
         toRemove.push_back(entry.first);
@@ -442,9 +442,9 @@ void DynBlockMaintenance::purgeExpired(const struct timespec& now)
   }
 }
 
-std::map<std::string, std::list<std::pair<Netmask, unsigned int>>> DynBlockMaintenance::getTopNetmasks(size_t topN)
+std::map<std::string, std::list<std::pair<AddressAndPortRange, unsigned int>>> DynBlockMaintenance::getTopNetmasks(size_t topN)
 {
-  std::map<std::string, std::list<std::pair<Netmask, unsigned int>>> results;
+  std::map<std::string, std::list<std::pair<AddressAndPortRange, unsigned int>>> results;
   if (topN == 0) {
     return results;
   }
@@ -465,7 +465,7 @@ std::map<std::string, std::list<std::pair<Netmask, unsigned int>>> DynBlockMaint
         topsForReason.pop_front();
       }
 
-      topsForReason.insert(std::lower_bound(topsForReason.begin(), topsForReason.end(), newEntry, [](const std::pair<Netmask, unsigned int>& a, const std::pair<Netmask, unsigned int>& b) {
+      topsForReason.insert(std::lower_bound(topsForReason.begin(), topsForReason.end(), newEntry, [](const std::pair<AddressAndPortRange, unsigned int>& a, const std::pair<AddressAndPortRange, unsigned int>& b) {
         return a.second < b.second;
       }),
         newEntry);
@@ -535,7 +535,7 @@ void DynBlockMaintenance::generateMetrics()
   }
 
   /* do NMG */
-  std::map<std::string, std::map<Netmask, DynBlockEntryStat>> nm;
+  std::map<std::string, std::map<AddressAndPortRange, DynBlockEntryStat>> nm;
   for (const auto& reason : s_metricsData.front().nmgData) {
     auto& reasonStat = nm[reason.first];
 
@@ -573,19 +573,19 @@ void DynBlockMaintenance::generateMetrics()
   }
 
   /* now we need to get the top N entries (for each "reason") based on our counters (sum of the last N entries) */
-  std::map<std::string, std::list<std::pair<Netmask, unsigned int>>> topNMGs;
+  std::map<std::string, std::list<std::pair<AddressAndPortRange, unsigned int>>> topNMGs;
   {
     for (const auto& reason : nm) {
       auto& topsForReason = topNMGs[reason.first];
       for (const auto& entry : reason.second) {
         if (topsForReason.size() < s_topN || topsForReason.front().second < entry.second.sum) {
           /* Note that this is a gauge, so we need to divide by the number of elapsed seconds */
-          auto newEntry = std::pair<Netmask, unsigned int>(entry.first, std::round(entry.second.sum / 60.0));
+          auto newEntry = std::pair<AddressAndPortRange, unsigned int>(entry.first, std::round(entry.second.sum / 60.0));
           if (topsForReason.size() >= s_topN) {
             topsForReason.pop_front();
           }
 
-          topsForReason.insert(std::lower_bound(topsForReason.begin(), topsForReason.end(), newEntry, [](const std::pair<Netmask, unsigned int>& a, const std::pair<Netmask, unsigned int>& b) {
+          topsForReason.insert(std::lower_bound(topsForReason.begin(), topsForReason.end(), newEntry, [](const std::pair<AddressAndPortRange, unsigned int>& a, const std::pair<AddressAndPortRange, unsigned int>& b) {
             return a.second < b.second;
           }),
             newEntry);
@@ -722,7 +722,7 @@ void DynBlockMaintenance::run()
   }
 }
 
-std::map<std::string, std::list<std::pair<Netmask, unsigned int>>> DynBlockMaintenance::getHitsForTopNetmasks()
+std::map<std::string, std::list<std::pair<AddressAndPortRange, unsigned int>>> DynBlockMaintenance::getHitsForTopNetmasks()
 {
   return s_tops.lock()->topNMGsByReason;
 }
index 73557c48913548960a279d844af7ebfbe74c02d4..27152c0d8c447b6a26dbb4cb2934b44d3e56d11d 100644 (file)
@@ -1258,16 +1258,20 @@ faster than the existing rules.
 
   Represents a group of dynamic block rules.
 
-  .. method:: DynBlockRulesGroup:setMasks(v4, v6)
+  .. method:: DynBlockRulesGroup:setMasks(v4, v6, port)
 
     .. versionadded:: 1.7.0
 
     Set the number of bits to keep in the IP address when inserting a block. The default is 32 for IPv4 and 128 for IPv6, meaning
     that only the exact address is blocked, but in some scenarios it might make sense to block a whole /64 IPv6 range instead of a
     single address, for example.
+    It is also possible to take the IPv4 UDP and TCP ports into account, for CGNAT deployments, by setting the number of bits of the port
+    to consider. For example passing 2 as the last parameter, which only makes sense if the previous parameters are respectively 32
+    and 128, will split a given IP address into four port ranges: 0-16383, 16384-32767, 32768-49151 and 49152-65535.
 
-    :param int v4: Number of bits of to keep for IPv4 addresses. Default is 32
-    :param int v6: Number of bits of to keep for IPv6 addresses. Default is 128
+    :param int v4: Number of bits to keep for IPv4 addresses. Default is 32
+    :param int v6: Number of bits to keep for IPv6 addresses. Default is 128
+    :param int port: Number of bits of port to consider over IPv4. Default is 0 meaning that the port is not taken into account
 
   .. method:: DynBlockRulesGroup:setQueryRate(rate, seconds, reason, blockingTime [, action [, warningRate]])
 
index ed8227fc6918f792f43658a3c2c17d1a6b9767e3..3193e6692b8558023058b8b995aeb10e34e5e95d 100644 (file)
@@ -9,7 +9,7 @@
 #include "dnsdist-rings.hh"
 
 Rings g_rings;
-GlobalStateHolder<NetmaskTree<DynBlock>> g_dynblockNMG;
+GlobalStateHolder<NetmaskTree<DynBlock, AddressAndPortRange>> g_dynblockNMG;
 GlobalStateHolder<SuffixMatchTree<DynBlock>> g_dynblockSMT;
 shared_ptr<BPFFilter> g_defaultBPFFilter{nullptr};
 
@@ -29,7 +29,7 @@ BOOST_AUTO_TEST_CASE(test_DynBlockRulesGroup_QueryRate) {
   unsigned int responseTime = 0;
   struct timespec now;
   gettime(&now);
-  NetmaskTree<DynBlock> emptyNMG;
+  NetmaskTree<DynBlock, AddressAndPortRange> emptyNMG;
 
   size_t numberOfSeconds = 10;
   size_t blockDuration = 60;
@@ -151,6 +151,204 @@ BOOST_AUTO_TEST_CASE(test_DynBlockRulesGroup_QueryRate) {
   }
 }
 
+BOOST_AUTO_TEST_CASE(test_DynBlockRulesGroup_QueryRate_RangeV6) {
+  /* Check that we correctly group IPv6 addresses from the same /64 subnet into the same
+     dynamic block entry, if instructed to do so */
+  dnsheader dh;
+  memset(&dh, 0, sizeof(dh));
+  DNSName qname("rings.powerdns.com.");
+  ComboAddress requestor1("2001:db8::1");
+  ComboAddress backend("2001:0db8:ffff:ffff:ffff:ffff:ffff:ffff");
+  uint16_t qtype = QType::AAAA;
+  uint16_t size = 42;
+  dnsdist::Protocol protocol = dnsdist::Protocol::DoUDP;
+  dnsdist::Protocol outgoingProtocol = dnsdist::Protocol::DoUDP;
+  unsigned int responseTime = 0;
+  struct timespec now;
+  gettime(&now);
+  NetmaskTree<DynBlock, AddressAndPortRange> emptyNMG;
+
+  size_t numberOfSeconds = 10;
+  size_t blockDuration = 60;
+  const auto action = DNSAction::Action::Drop;
+  const std::string reason = "Exceeded query rate";
+
+  DynBlockRulesGroup dbrg;
+  dbrg.setQuiet(true);
+  dbrg.setMasks(32, 64, 0);
+
+  /* block above 50 qps for numberOfSeconds seconds, no warning */
+  dbrg.setQueryRate(50, 0, numberOfSeconds, reason, blockDuration, action);
+
+  {
+    /* insert 45 qps from a given client in the last 10s
+       this should not trigger the rule */
+    size_t numberOfQueries = 45 * numberOfSeconds;
+    g_rings.clear();
+    BOOST_CHECK_EQUAL(g_rings.getNumberOfQueryEntries(), 0U);
+    g_dynblockNMG.setState(emptyNMG);
+
+    for (size_t idx = 0; idx < numberOfQueries; idx++) {
+      g_rings.insertQuery(now, requestor1, qname, qtype, size, dh, protocol);
+      /* we do not care about the response during that test, but we want to make sure
+         these do not interfere with the computation */
+      g_rings.insertResponse(now, requestor1, qname, qtype, responseTime, size, dh, backend, outgoingProtocol);
+    }
+    BOOST_CHECK_EQUAL(g_rings.getNumberOfResponseEntries(), numberOfQueries);
+    BOOST_CHECK_EQUAL(g_rings.getNumberOfQueryEntries(), numberOfQueries);
+
+    dbrg.apply(now);
+    BOOST_CHECK_EQUAL(g_dynblockNMG.getLocal()->size(), 0U);
+    BOOST_CHECK(g_dynblockNMG.getLocal()->lookup(AddressAndPortRange(requestor1, 128, 16)) == nullptr);
+  }
+
+  {
+    /* insert just above 50 qps from several clients in the same /64 IPv6 range in the last 10s,
+       this should trigger the rule this time */
+    size_t numberOfQueries = (50 * numberOfSeconds) + 1;
+    g_rings.clear();
+    BOOST_CHECK_EQUAL(g_rings.getNumberOfQueryEntries(), 0U);
+    g_dynblockNMG.setState(emptyNMG);
+
+    for (size_t idx = 0; idx < numberOfQueries; idx++) {
+      ComboAddress requestor("2001:db8::" + std::to_string(idx));
+      g_rings.insertQuery(now, requestor, qname, qtype, size, dh, protocol);
+      g_rings.insertResponse(now, requestor, qname, qtype, responseTime, size, dh, backend, outgoingProtocol);
+    }
+    BOOST_CHECK_EQUAL(g_rings.getNumberOfQueryEntries(), numberOfQueries);
+
+    dbrg.apply(now);
+    BOOST_CHECK_EQUAL(g_dynblockNMG.getLocal()->size(), 1U);
+
+    {
+      /* beginning of the range should be blocked */
+      const auto& block = g_dynblockNMG.getLocal()->lookup(AddressAndPortRange(requestor1, 128, 16))->second;
+      BOOST_CHECK_EQUAL(block.reason, reason);
+      BOOST_CHECK_EQUAL(static_cast<size_t>(block.until.tv_sec), now.tv_sec + blockDuration);
+      BOOST_CHECK(block.domain.empty());
+      BOOST_CHECK(block.action == action);
+      BOOST_CHECK_EQUAL(block.blocks, 0U);
+      BOOST_CHECK_EQUAL(block.warning, false);
+    }
+
+    {
+      /* end of the range should be blocked as well */
+      ComboAddress end("2001:0db8:0000:0000:ffff:ffff:ffff:ffff");
+      const auto& block = g_dynblockNMG.getLocal()->lookup(AddressAndPortRange(end, 128, 16))->second;
+      BOOST_CHECK_EQUAL(block.reason, reason);
+      BOOST_CHECK_EQUAL(static_cast<size_t>(block.until.tv_sec), now.tv_sec + blockDuration);
+      BOOST_CHECK(block.domain.empty());
+      BOOST_CHECK(block.action == action);
+      BOOST_CHECK_EQUAL(block.blocks, 0U);
+      BOOST_CHECK_EQUAL(block.warning, false);
+    }
+
+    {
+      /* outside of the range should NOT */
+      ComboAddress out("2001:0db8:0000:0001::0");
+      BOOST_CHECK(g_dynblockNMG.getLocal()->lookup(AddressAndPortRange(out, 128, 16)) == nullptr);
+    }
+  }
+}
+
+BOOST_AUTO_TEST_CASE(test_DynBlockRulesGroup_QueryRate_V4Ports) {
+  /* Check that we correctly split IPv4 addresses based on port ranges, when instructed to do so */
+  dnsheader dh;
+  memset(&dh, 0, sizeof(dh));
+  DNSName qname("rings.powerdns.com.");
+  ComboAddress requestor1("192.0.2.1:42");
+  ComboAddress backend("192.0.2.254");
+  uint16_t qtype = QType::AAAA;
+  uint16_t size = 42;
+  unsigned int responseTime = 0;
+  dnsdist::Protocol protocol = dnsdist::Protocol::DoUDP;
+  dnsdist::Protocol outgoingProtocol = dnsdist::Protocol::DoUDP;
+  struct timespec now;
+  gettime(&now);
+  NetmaskTree<DynBlock, AddressAndPortRange> emptyNMG;
+
+  size_t numberOfSeconds = 10;
+  size_t blockDuration = 60;
+  const auto action = DNSAction::Action::Drop;
+  const std::string reason = "Exceeded query rate";
+
+  DynBlockRulesGroup dbrg;
+  dbrg.setQuiet(true);
+  /* split v4 by ports using a  /2 (0 - 16383, 16384 - 32767, 32768 - 49151, 49152 - 65535) */
+  dbrg.setMasks(32, 128, 2);
+
+  /* block above 50 qps for numberOfSeconds seconds, no warning */
+  dbrg.setQueryRate(50, 0, numberOfSeconds, reason, blockDuration, action);
+
+  {
+    /* insert 45 qps from a given client in the last 10s
+       this should not trigger the rule */
+    size_t numberOfQueries = 45 * numberOfSeconds;
+    g_rings.clear();
+    BOOST_CHECK_EQUAL(g_rings.getNumberOfQueryEntries(), 0U);
+    g_dynblockNMG.setState(emptyNMG);
+
+    for (size_t idx = 0; idx < numberOfQueries; idx++) {
+      g_rings.insertQuery(now, requestor1, qname, qtype, size, dh, protocol);
+      /* we do not care about the response during that test, but we want to make sure
+         these do not interfere with the computation */
+      g_rings.insertResponse(now, requestor1, qname, qtype, responseTime, size, dh, backend, outgoingProtocol);
+    }
+    BOOST_CHECK_EQUAL(g_rings.getNumberOfResponseEntries(), numberOfQueries);
+    BOOST_CHECK_EQUAL(g_rings.getNumberOfQueryEntries(), numberOfQueries);
+
+    dbrg.apply(now);
+    BOOST_CHECK_EQUAL(g_dynblockNMG.getLocal()->size(), 0U);
+    BOOST_CHECK(g_dynblockNMG.getLocal()->lookup(AddressAndPortRange(requestor1, 128, 16)) == nullptr);
+  }
+
+  {
+    /* insert just above 50 qps from several clients in the same IPv4 port range in the last 10s,
+       this should trigger the rule this time */
+    size_t numberOfQueries = (50 * numberOfSeconds) + 1;
+    g_rings.clear();
+    BOOST_CHECK_EQUAL(g_rings.getNumberOfQueryEntries(), 0U);
+    g_dynblockNMG.setState(emptyNMG);
+
+    for (size_t idx = 0; idx < numberOfQueries; idx++) {
+      ComboAddress requestor("192.0.2.1:" + std::to_string(idx));
+      g_rings.insertQuery(now, requestor, qname, qtype, size, dh, protocol);
+      g_rings.insertResponse(now, requestor, qname, qtype, responseTime, size, dh, backend, outgoingProtocol);
+    }
+    BOOST_CHECK_EQUAL(g_rings.getNumberOfQueryEntries(), numberOfQueries);
+
+    dbrg.apply(now);
+    BOOST_CHECK_EQUAL(g_dynblockNMG.getLocal()->size(), 1U);
+
+    {
+      /* beginning of the port range should be blocked */
+      const auto& block = g_dynblockNMG.getLocal()->lookup(AddressAndPortRange(ComboAddress("192.0.2.1:0"), 32, 16))->second;
+      BOOST_CHECK_EQUAL(block.reason, reason);
+      BOOST_CHECK_EQUAL(static_cast<size_t>(block.until.tv_sec), now.tv_sec + blockDuration);
+      BOOST_CHECK(block.domain.empty());
+      BOOST_CHECK(block.action == action);
+      BOOST_CHECK_EQUAL(block.blocks, 0U);
+      BOOST_CHECK_EQUAL(block.warning, false);
+    }
+
+    {
+      /* end of the range should be blocked as well */
+      const auto& block = g_dynblockNMG.getLocal()->lookup(AddressAndPortRange(ComboAddress("192.0.2.1:16383"), 32, 16))->second;
+      BOOST_CHECK_EQUAL(block.reason, reason);
+      BOOST_CHECK_EQUAL(static_cast<size_t>(block.until.tv_sec), now.tv_sec + blockDuration);
+      BOOST_CHECK(block.domain.empty());
+      BOOST_CHECK(block.action == action);
+      BOOST_CHECK_EQUAL(block.blocks, 0U);
+      BOOST_CHECK_EQUAL(block.warning, false);
+    }
+
+    {
+      /* outside of the range should not */
+      BOOST_CHECK(g_dynblockNMG.getLocal()->lookup(AddressAndPortRange(ComboAddress("192.0.2.1:16384"), 32, 16)) == nullptr);
+    }
+  }
+}
+
 BOOST_AUTO_TEST_CASE(test_DynBlockRulesGroup_QueryRate_responses) {
   /* check that the responses are not accounted as queries when a
      rcode rate rule is defined (sounds very specific but actually happened) */
@@ -167,7 +365,7 @@ BOOST_AUTO_TEST_CASE(test_DynBlockRulesGroup_QueryRate_responses) {
   unsigned int responseTime = 0;
   struct timespec now;
   gettime(&now);
-  NetmaskTree<DynBlock> emptyNMG;
+  NetmaskTree<DynBlock, AddressAndPortRange> emptyNMG;
 
   /* 100k entries, one shard */
   g_rings.setCapacity(1000000, 1);
@@ -224,7 +422,7 @@ BOOST_AUTO_TEST_CASE(test_DynBlockRulesGroup_QTypeRate) {
   dnsdist::Protocol protocol = dnsdist::Protocol::DoUDP;
   struct timespec now;
   gettime(&now);
-  NetmaskTree<DynBlock> emptyNMG;
+  NetmaskTree<DynBlock, AddressAndPortRange> emptyNMG;
 
   size_t numberOfSeconds = 10;
   size_t blockDuration = 60;
@@ -313,7 +511,7 @@ BOOST_AUTO_TEST_CASE(test_DynBlockRulesGroup_RCodeRate) {
   unsigned int responseTime = 100 * 1000; /* 100ms */
   struct timespec now;
   gettime(&now);
-  NetmaskTree<DynBlock> emptyNMG;
+  NetmaskTree<DynBlock, AddressAndPortRange> emptyNMG;
 
   size_t numberOfSeconds = 10;
   size_t blockDuration = 60;
@@ -405,7 +603,7 @@ BOOST_AUTO_TEST_CASE(test_DynBlockRulesGroup_RCodeRatio) {
   unsigned int responseTime = 100 * 1000; /* 100ms */
   struct timespec now;
   gettime(&now);
-  NetmaskTree<DynBlock> emptyNMG;
+  NetmaskTree<DynBlock, AddressAndPortRange> emptyNMG;
 
   time_t numberOfSeconds = 10;
   unsigned int blockDuration = 60;
@@ -523,7 +721,7 @@ BOOST_AUTO_TEST_CASE(test_DynBlockRulesGroup_ResponseByteRate) {
   unsigned int responseTime = 100 * 1000; /* 100ms */
   struct timespec now;
   gettime(&now);
-  NetmaskTree<DynBlock> emptyNMG;
+  NetmaskTree<DynBlock, AddressAndPortRange> emptyNMG;
 
   size_t numberOfSeconds = 10;
   size_t blockDuration = 60;
@@ -594,7 +792,7 @@ BOOST_AUTO_TEST_CASE(test_DynBlockRulesGroup_Warning) {
   dnsdist::Protocol protocol = dnsdist::Protocol::DoUDP;
   struct timespec now;
   gettime(&now);
-  NetmaskTree<DynBlock> emptyNMG;
+  NetmaskTree<DynBlock, AddressAndPortRange> emptyNMG;
 
   size_t numberOfSeconds = 10;
   size_t blockDuration = 60;
@@ -753,7 +951,7 @@ BOOST_AUTO_TEST_CASE(test_DynBlockRulesGroup_Ranges) {
   dnsdist::Protocol protocol = dnsdist::Protocol::DoUDP;
   struct timespec now;
   gettime(&now);
-  NetmaskTree<DynBlock> emptyNMG;
+  NetmaskTree<DynBlock, AddressAndPortRange> emptyNMG;
 
   size_t numberOfSeconds = 10;
   size_t blockDuration = 60;
@@ -809,7 +1007,7 @@ BOOST_AUTO_TEST_CASE(test_DynBlockRulesMetricsCache_GetTopN) {
   dnsdist::Protocol outgoingProtocol = dnsdist::Protocol::DoUDP;
   struct timespec now;
   gettime(&now);
-  NetmaskTree<DynBlock> emptyNMG;
+  NetmaskTree<DynBlock, AddressAndPortRange> emptyNMG;
   SuffixMatchTree<DynBlock> emptySMT;
 
   size_t numberOfSeconds = 10;
@@ -1080,4 +1278,168 @@ BOOST_AUTO_TEST_CASE(test_DynBlockRulesMetricsCache_GetTopN) {
 #endif
 }
 
+BOOST_AUTO_TEST_CASE(test_NetmaskTree) {
+  NetmaskTree<int, AddressAndPortRange> nmt;
+  BOOST_CHECK_EQUAL(nmt.empty(), true);
+  BOOST_CHECK_EQUAL(nmt.size(), 0U);
+  nmt.insert(AddressAndPortRange(ComboAddress("130.161.252.0"), 24, 0)).second = 0;
+  BOOST_CHECK_EQUAL(nmt.empty(), false);
+  BOOST_CHECK_EQUAL(nmt.size(), 1U);
+  nmt.insert(AddressAndPortRange(ComboAddress("130.161.0.0"), 16, 0)).second = 1;
+  BOOST_CHECK_EQUAL(nmt.size(), 2U);
+  nmt.insert(AddressAndPortRange(ComboAddress("130.0.0.0"), 8, 0)).second = 2;
+  BOOST_CHECK_EQUAL(nmt.size(), 3U);
+
+  BOOST_CHECK_EQUAL(nmt.lookup(ComboAddress("213.244.168.210")), nullptr);
+  auto found = nmt.lookup(ComboAddress("130.161.252.29"));
+  BOOST_REQUIRE(found);
+  BOOST_CHECK_EQUAL(found->second, 0);
+  found = nmt.lookup(ComboAddress("130.161.180.1"));
+  BOOST_CHECK(found);
+  BOOST_CHECK_EQUAL(found->second, 1);
+
+  BOOST_CHECK_EQUAL(nmt.lookup(ComboAddress("130.255.255.255"))->second, 2);
+  BOOST_CHECK_EQUAL(nmt.lookup(ComboAddress("130.161.252.255"))->second, 0);
+  BOOST_CHECK_EQUAL(nmt.lookup(ComboAddress("130.161.253.255"))->second, 1);
+  BOOST_CHECK_EQUAL(nmt.lookup(AddressAndPortRange(ComboAddress("130.255.255.255"), 32, 16))->second, 2);
+  BOOST_CHECK_EQUAL(nmt.lookup(AddressAndPortRange(ComboAddress("130.161.252.255"), 32, 16))->second, 0);
+  BOOST_CHECK_EQUAL(nmt.lookup(AddressAndPortRange(ComboAddress("130.161.253.255"), 32, 16))->second, 1);
+
+  found = nmt.lookup(ComboAddress("130.145.180.1"));
+  BOOST_CHECK(found);
+  BOOST_CHECK_EQUAL(found->second, 2);
+
+  nmt.insert(AddressAndPortRange(ComboAddress("0.0.0.0"), 0, 0)).second = 3;
+  BOOST_CHECK_EQUAL(nmt.size(), 4U);
+  nmt.insert(AddressAndPortRange(ComboAddress("0.0.0.0"), 7, 0)).second = 4;
+  BOOST_CHECK_EQUAL(nmt.size(), 5U);
+  nmt.insert(AddressAndPortRange(ComboAddress("0.0.0.0"), 15, 0)).second = 5;
+  BOOST_CHECK_EQUAL(nmt.size(), 6U);
+  BOOST_CHECK_EQUAL(nmt.lookup(AddressAndPortRange(ComboAddress("0.0.0.0"), 0, 0))->second, 3);
+  BOOST_CHECK_EQUAL(nmt.lookup(AddressAndPortRange(ComboAddress("0.0.0.0"), 7, 0))->second, 4);
+  BOOST_CHECK_EQUAL(nmt.lookup(AddressAndPortRange(ComboAddress("0.0.0.0"), 15, 0))->second, 5);
+  BOOST_CHECK_EQUAL(nmt.lookup(AddressAndPortRange(ComboAddress("0.0.0.0"), 32, 0))->second, 5);
+
+  nmt.clear();
+  BOOST_CHECK_EQUAL(nmt.empty(), true);
+  BOOST_CHECK_EQUAL(nmt.size(), 0U);
+  BOOST_CHECK(!nmt.lookup(ComboAddress("130.161.180.1")));
+
+  nmt.insert(AddressAndPortRange(ComboAddress("::1"), 128, 0)).second = 1;
+  BOOST_CHECK_EQUAL(nmt.empty(), false);
+  BOOST_CHECK_EQUAL(nmt.size(), 1U);
+  nmt.insert(AddressAndPortRange(ComboAddress("::"), 0, 0)).second = 0;
+  BOOST_CHECK_EQUAL(nmt.size(), 2U);
+  nmt.insert(AddressAndPortRange(ComboAddress("fe80::"), 16, 0)).second = 2;
+  BOOST_CHECK_EQUAL(nmt.size(), 3U);
+  BOOST_CHECK_EQUAL(nmt.lookup(ComboAddress("130.161.253.255")), nullptr);
+  BOOST_CHECK_EQUAL(nmt.lookup(ComboAddress("::2"))->second, 0);
+  BOOST_CHECK_EQUAL(nmt.lookup(ComboAddress("::ffff"))->second, 0);
+  BOOST_CHECK_EQUAL(nmt.lookup(ComboAddress("::1"))->second, 1);
+  BOOST_CHECK_EQUAL(nmt.lookup(ComboAddress("fe80::1"))->second, 2);
+}
+
+BOOST_AUTO_TEST_CASE(test_NetmaskTreePort) {
+  {
+    /* exact port matching */
+    NetmaskTree<int, AddressAndPortRange> nmt;
+    BOOST_CHECK_EQUAL(nmt.empty(), true);
+    BOOST_CHECK_EQUAL(nmt.size(), 0U);
+    nmt.insert(AddressAndPortRange(ComboAddress("130.161.252.42:65534"), 32, 16)).second = 0;
+    BOOST_CHECK_EQUAL(nmt.empty(), false);
+    BOOST_CHECK_EQUAL(nmt.size(), 1U);
+
+    BOOST_CHECK_EQUAL(nmt.lookup(AddressAndPortRange(ComboAddress("213.244.168.210"), 32, 16)), nullptr);
+
+    auto found = nmt.lookup(AddressAndPortRange(ComboAddress("130.161.252.42:65534"), 32, 16));
+    BOOST_CHECK(found != nullptr);
+    BOOST_CHECK_EQUAL(nmt.lookup(AddressAndPortRange(ComboAddress("130.161.252.42:65533"), 32, 16)), nullptr);
+    BOOST_CHECK_EQUAL(nmt.lookup(AddressAndPortRange(ComboAddress("130.161.252.42:65535"), 32, 16)), nullptr);
+  }
+
+  {
+    /* /15 port matching */
+    NetmaskTree<int, AddressAndPortRange> nmt;
+    BOOST_CHECK_EQUAL(nmt.empty(), true);
+    BOOST_CHECK_EQUAL(nmt.size(), 0U);
+    nmt.insert(AddressAndPortRange(ComboAddress("130.161.252.42:0"), 32, 15)).second = 0;
+    BOOST_CHECK_EQUAL(nmt.empty(), false);
+    BOOST_CHECK_EQUAL(nmt.size(), 1U);
+
+    BOOST_CHECK_EQUAL(nmt.lookup(AddressAndPortRange(ComboAddress("213.244.168.210"), 32, 16)), nullptr);
+
+    auto found = nmt.lookup(AddressAndPortRange(ComboAddress("130.161.252.42:0"), 32, 16));
+    BOOST_CHECK(found != nullptr);
+
+    found = nmt.lookup(AddressAndPortRange(ComboAddress("130.161.252.42:1"), 32, 16));
+    BOOST_CHECK(found != nullptr);
+
+    /* everything else should be a miss */
+    for (size_t idx = 2; idx <= 65535; idx++) {
+      BOOST_CHECK_EQUAL(nmt.lookup(AddressAndPortRange(ComboAddress("130.161.252.42:" + std::to_string(idx)), 32, 16)), nullptr);
+    }
+
+    nmt.clear();
+    BOOST_CHECK_EQUAL(nmt.empty(), true);
+    BOOST_CHECK_EQUAL(nmt.size(), 0U);
+    nmt.insert(AddressAndPortRange(ComboAddress("130.161.252.42:65535"), 32, 15)).second = 0;
+    BOOST_CHECK_EQUAL(nmt.empty(), false);
+    BOOST_CHECK_EQUAL(nmt.size(), 1U);
+
+    BOOST_CHECK_EQUAL(nmt.lookup(AddressAndPortRange(ComboAddress("213.244.168.210"), 32, 16)), nullptr);
+
+    /* everything else should be a miss */
+    for (size_t idx = 0; idx <= 65533; idx++) {
+      BOOST_CHECK_EQUAL(nmt.lookup(AddressAndPortRange(ComboAddress("130.161.252.42:" + std::to_string(idx)), 32, 16)), nullptr);
+    }
+    found = nmt.lookup(AddressAndPortRange(ComboAddress("130.161.252.42:65534"), 32, 16));
+    BOOST_CHECK(found != nullptr);
+    found = nmt.lookup(AddressAndPortRange(ComboAddress("130.161.252.42:65535"), 32, 16));
+    BOOST_CHECK(found != nullptr);
+  }
+
+  {
+    /* /1 port matching */
+    NetmaskTree<int, AddressAndPortRange> nmt;
+    BOOST_CHECK_EQUAL(nmt.empty(), true);
+    BOOST_CHECK_EQUAL(nmt.size(), 0U);
+    nmt.insert(AddressAndPortRange(ComboAddress("130.161.252.42:0"), 32, 1)).second = 0;
+    BOOST_CHECK_EQUAL(nmt.empty(), false);
+    BOOST_CHECK_EQUAL(nmt.size(), 1U);
+
+    BOOST_CHECK_EQUAL(nmt.lookup(AddressAndPortRange(ComboAddress("213.244.168.210"), 32, 16)), nullptr);
+
+    for (size_t idx = 0; idx <= 32767; idx++) {
+      auto found = nmt.lookup(AddressAndPortRange(ComboAddress("130.161.252.42:" + std::to_string(idx)), 32, 16));
+      BOOST_CHECK(found != nullptr);
+    }
+
+    /* everything else should be a miss */
+    for (size_t idx = 32768; idx <= 65535; idx++) {
+      BOOST_CHECK_EQUAL(nmt.lookup(AddressAndPortRange(ComboAddress("130.161.252.42:" + std::to_string(idx)), 32, 16)), nullptr);
+    }
+  }
+
+  {
+    /* Check that the port matching does not apply to IPv6, where it does not make sense */
+
+    /* /1 port matching */
+    NetmaskTree<int, AddressAndPortRange> nmt;
+    BOOST_CHECK_EQUAL(nmt.empty(), true);
+    BOOST_CHECK_EQUAL(nmt.size(), 0U);
+    nmt.insert(AddressAndPortRange(ComboAddress("[2001:db8::1]:0"), 128, 1)).second = 0;
+    BOOST_CHECK_EQUAL(nmt.empty(), false);
+    BOOST_CHECK_EQUAL(nmt.size(), 1U);
+
+    /* different IP, no match */
+    BOOST_CHECK_EQUAL(nmt.lookup(AddressAndPortRange(ComboAddress("[2001:db8::2]:0"), 128, 16)), nullptr);
+
+    /* all ports should match */
+    for (size_t idx = 1; idx <= 65535; idx++) {
+      auto found = nmt.lookup(AddressAndPortRange(ComboAddress("[2001:db8::1]:" + std::to_string(idx)), 128, 16));
+      BOOST_CHECK(found != nullptr);
+    }
+  }
+}
+
 BOOST_AUTO_TEST_SUITE_END()
index 2ec43220f3eb2cfdc6e7074a9e5cf418cee6690f..9da0ea07ddcf419aaa8cf710a4c34eedcbbc0b99 100644 (file)
@@ -129,15 +129,15 @@ union ComboAddress {
   {
     uint32_t operator()(const ComboAddress& ca) const
     {
-      const unsigned char* start;
-      int len;
-      if(ca.sin4.sin_family == AF_INET) {
-        start =(const unsigned char*)&ca.sin4.sin_addr.s_addr;
-        len=4;
+      const unsigned char* start = nullptr;
+      uint32_t len = 0;
+      if (ca.sin4.sin_family == AF_INET) {
+        start = reinterpret_cast<const unsigned char*>(&ca.sin4.sin_addr.s_addr);
+        len = 4;
       }
       else {
-        start =(const unsigned char*)&ca.sin6.sin6_addr.s6_addr;
-        len=16;
+        start = reinterpret_cast<const unsigned char*>(&ca.sin6.sin6_addr.s6_addr);
+        len = 16;
       }
       return burtle(start, len, 0);
     }
@@ -632,14 +632,9 @@ public:
   }
 
   //! Get the total number of address bits for this netmask (either 32 or 128 depending on IP version)
-  uint8_t getAddressBits() const
-  {
-    return d_network.getBits();
-  }
-
   uint8_t getFullBits() const
   {
-    return getAddressBits();
+    return d_network.getBits();
   }
 
   /** Get the value of the bit at the provided bit index. When the index >= 0,
@@ -664,14 +659,6 @@ public:
     return d_network.getBit(bit);
   }
 
-  struct hash
-  {
-    uint32_t operator()(const Netmask& nm) const
-    {
-      ComboAddress::addressOnlyHash hashOp;
-      return hashOp(nm.d_network);
-    }
-  };
 private:
   ComboAddress d_network;
   uint32_t d_mask;
@@ -1140,7 +1127,57 @@ public:
 
   //<! Returns "best match" for key_type, which might not be value
   const node_type* lookup(const key_type& value) const {
-    return lookup(value.getNetwork(), value.getBits());
+    TreeNode *node = nullptr;
+
+    uint8_t max_bits = value.getBits();
+
+    if (value.isIPv4())
+      node = d_root->left.get();
+    else if (value.isIPv6())
+      node = d_root->right.get();
+    else
+      throw NetmaskException("invalid address family");
+    if (node == nullptr) return nullptr;
+
+    node_type *ret = nullptr;
+
+    int bits = 0;
+    for(; bits < max_bits; bits++) {
+      bool vall = value.getBit(-1-bits);
+      if (bits >= node->d_bits) {
+        // the end of the current node is reached; continue with the next
+        // (we keep track of last assigned node)
+        if (node->assigned && bits == node->node.first.getBits())
+          ret = &node->node;
+        if (vall) {
+          if (!node->right)
+            break;
+          node = node->right.get();
+        } else {
+          if (!node->left)
+            break;
+          node = node->left.get();
+        }
+        continue;
+      }
+      if (bits >= node->node.first.getBits()) {
+        // the matching branch ends here
+        break;
+      }
+      bool valr = node->node.first.getBit(-1-bits);
+      if (vall != valr) {
+        // the branch matches just upto this point, yet continues in a different
+        // direction
+        break;
+      }
+    }
+    // needed if we did not find one in loop
+    if (node->assigned && bits == node->node.first.getBits())
+      ret = &node->node;
+
+    // this can be nullptr.
+    return ret;
+
   }
 
   //<! Perform best match lookup for value, using at most max_bits
@@ -1443,6 +1480,185 @@ public:
   {}
 };
 
+class AddressAndPortRange
+{
+public:
+  AddressAndPortRange(): d_addrMask(0), d_portMask(0)
+  {
+    d_addr.sin4.sin_family = 0; // disable this doing anything useful
+    d_addr.sin4.sin_port = 0; // this guarantees d_network compares identical
+  }
+
+  AddressAndPortRange(ComboAddress ca, uint8_t addrMask, uint8_t portMask): d_addr(std::move(ca)), d_addrMask(addrMask), d_portMask(portMask)
+  {
+    if (!d_addr.isIPv4()) {
+      d_portMask = 0;
+    }
+
+    uint16_t port = d_addr.getPort();
+    if (d_portMask < 16) {
+      uint16_t mask = ~(0xFFFF >> d_portMask);
+      port = port & mask;
+    }
+
+    if (d_addrMask < d_addr.getBits()) {
+      if (d_portMask > 0) {
+        throw std::runtime_error("Trying to create a AddressAndPortRange with a reduced address mask (" + std::to_string(d_addrMask) + ") and a port range (" + std::to_string(d_portMask) + ")");
+      }
+      d_addr = Netmask(d_addr, d_addrMask).getMaskedNetwork();
+    }
+    d_addr.setPort(port);
+  }
+
+  uint8_t getFullBits() const
+  {
+    return d_addr.getBits() + 16;
+  }
+
+  uint8_t getBits() const
+  {
+    if (d_addrMask < d_addr.getBits()) {
+      return d_addrMask;
+    }
+
+    return d_addr.getBits() + d_portMask;
+  }
+
+  /** Get the value of the bit at the provided bit index. When the index >= 0,
+      the index is relative to the LSB starting at index zero. When the index < 0,
+      the index is relative to the MSB starting at index -1 and counting down.
+  */
+  bool getBit(int index) const
+  {
+    if (index >= getFullBits()) {
+      return false;
+    }
+    if (index < 0) {
+      index = getFullBits() + index;
+    }
+
+    if (index < 16) {
+      /* we are into the port bits */
+      uint16_t port = d_addr.getPort();
+      return ((port & (1U<<index)) != 0x0000);
+    }
+
+    index -= 16;
+
+    return d_addr.getBit(index);
+  }
+
+  bool isIPv4() const
+  {
+    return d_addr.isIPv4();
+  }
+
+  bool isIPv6() const
+  {
+    return d_addr.isIPv6();
+  }
+
+  AddressAndPortRange getNormalized() const
+  {
+    return AddressAndPortRange(d_addr, d_addrMask, d_portMask);
+  }
+
+  AddressAndPortRange getSuper(uint8_t bits) const
+  {
+    if (bits <= d_addrMask) {
+      return AddressAndPortRange(d_addr, bits, 0);
+    }
+    if (bits <= d_addrMask + d_portMask) {
+      return AddressAndPortRange(d_addr, d_addrMask, d_portMask - (bits - d_addrMask));
+    }
+
+    return AddressAndPortRange(d_addr, d_addrMask, d_portMask);
+  }
+
+  const ComboAddress& getNetwork() const
+  {
+    return d_addr;
+  }
+
+  string toString() const
+  {
+    if (d_addrMask < d_addr.getBits() || d_portMask == 0) {
+      return d_addr.toStringNoInterface() + "/" + std::to_string(d_addrMask);
+    }
+    return d_addr.toStringNoInterface() + ":" + std::to_string(d_addr.getPort()) + "/" + std::to_string(d_portMask);
+  }
+
+  bool empty() const
+  {
+    return d_addr.sin4.sin_family == 0;
+  }
+
+  bool operator==(const AddressAndPortRange& rhs) const
+  {
+    return tie(d_addr, d_addrMask, d_portMask) == tie(rhs.d_addr, rhs.d_addrMask, rhs.d_portMask);
+  }
+
+  bool operator<(const AddressAndPortRange& rhs) const
+  {
+    if (empty() && !rhs.empty()) {
+      return false;
+    }
+
+    if (!empty() && rhs.empty()) {
+      return true;
+    }
+
+    if (d_addrMask > rhs.d_addrMask) {
+      return true;
+    }
+
+    if (d_addrMask < rhs.d_addrMask) {
+      return false;
+    }
+
+    if (d_addr < rhs.d_addr) {
+      return true;
+    }
+
+    if (d_addr > rhs.d_addr) {
+      return false;
+    }
+
+    if (d_portMask > rhs.d_portMask) {
+      return true;
+    }
+
+    if (d_portMask < rhs.d_portMask) {
+      return false;
+    }
+
+    return d_addr.getPort() < rhs.d_addr.getPort();
+  }
+
+  bool operator>(const AddressAndPortRange& rhs) const
+  {
+    return rhs.operator<(*this);
+  }
+
+  struct hash
+  {
+    uint32_t operator()(const AddressAndPortRange& apr) const
+    {
+      ComboAddress::addressOnlyHash hashOp;
+      uint16_t port = apr.d_addr.getPort();
+      /* it's fine to hash the whole address and port because the non-relevant parts have
+         been masked to 0 */
+      return burtle(reinterpret_cast<const unsigned char*>(&port), sizeof(port), hashOp(apr.d_addr));
+    }
+  };
+
+private:
+  ComboAddress d_addr;
+  uint8_t d_addrMask;
+  /* only used for v4 addresses */
+  uint8_t d_portMask;
+};
+
 int SSocket(int family, int type, int flags);
 int SConnect(int sockfd, const ComboAddress& remote);
 /* tries to connect to remote for a maximum of timeout seconds.
@@ -1476,4 +1692,3 @@ bool isTCPSocketUsable(int sock);
 
 extern template class NetmaskTree<bool>;
 ComboAddress parseIPAndPort(const std::string& input, uint16_t port);
-