]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
fix(auth): check LUA record weights are > 0 16948/head
authorPieter Lexis <pieter.lexis@powerdns.com>
Thu, 5 Mar 2026 10:39:40 +0000 (11:39 +0100)
committerPieter Lexis <pieter.lexis@powerdns.com>
Thu, 5 Mar 2026 11:26:26 +0000 (12:26 +0100)
pdns/lua-record.cc
regression-tests.auth-py/test_LuaRecords.py

index 0acb2c30f7c56c60275e84d78a3adf41e6343a4a..bb7d1ed539f4a4179e3bb81bb788b244fcb5d297 100644 (file)
@@ -422,13 +422,13 @@ static T pickHashed(const ComboAddress& who, const vector<T>& items)
 }
 
 template <typename T>
-static T pickWeightedRandom(const vector< pair<int, T> >& items)
+static T pickWeightedRandom(const vector< pair<unsigned int, T> >& items)
 {
   if (items.empty()) {
     throw std::invalid_argument("The items list cannot be empty");
   }
-  int sum=0;
-  vector< pair<int, T> > pick;
+  unsigned int sum=0;
+  vector< pair<unsigned int, T> > pick;
   pick.reserve(items.size());
 
   for(auto& i : items) {
@@ -441,18 +441,18 @@ static T pickWeightedRandom(const vector< pair<int, T> >& items)
   }
 
   int r = dns_random(sum);
-  auto p = upper_bound(pick.begin(), pick.end(), r, [](int rarg, const typename decltype(pick)::value_type& a) { return rarg < a.first; });
+  auto p = upper_bound(pick.begin(), pick.end(), r, [](unsigned int rarg, const typename decltype(pick)::value_type& a) { return rarg < a.first; });
   return p->second;
 }
 
 template <typename T>
-static T pickWeightedHashed(const ComboAddress& bestwho, const vector< pair<int, T> >& items)
+static T pickWeightedHashed(const ComboAddress& bestwho, const vector< pair<unsigned int, T> >& items)
 {
   if (items.empty()) {
     throw std::invalid_argument("The items list cannot be empty");
   }
-  int sum=0;
-  vector< pair<int, T> > pick;
+  unsigned int sum=0;
+  vector< pair<unsigned int, T> > pick;
   pick.reserve(items.size());
 
   for(auto& i : items) {
@@ -466,18 +466,18 @@ static T pickWeightedHashed(const ComboAddress& bestwho, const vector< pair<int,
 
   ComboAddress::addressOnlyHash aoh;
   int r = aoh(bestwho) % sum;
-  auto p = upper_bound(pick.begin(), pick.end(), r, [](int rarg, const typename decltype(pick)::value_type& a) { return rarg < a.first; });
+  auto p = upper_bound(pick.begin(), pick.end(), r, [](unsigned int rarg, const typename decltype(pick)::value_type& a) { return rarg < a.first; });
   return p->second;
 }
 
 template <typename T>
-static T pickWeightedNameHashed(const DNSName& dnsname, vector< pair<int, T> >& items)
+static T pickWeightedNameHashed(const DNSName& dnsname, vector< pair<unsigned int, T> >& items)
 {
   if (items.empty()) {
     throw std::invalid_argument("The items list cannot be empty");
   }
   size_t sum=0;
-  vector< pair<int, T> > pick;
+  vector< pair<unsigned int, T> > pick;
   pick.reserve(items.size());
 
   for(auto& i : items) {
@@ -490,7 +490,7 @@ static T pickWeightedNameHashed(const DNSName& dnsname, vector< pair<int, T> >&
   }
 
   size_t r = dnsname.hash() % sum;
-  auto p = upper_bound(pick.begin(), pick.end(), r, [](int rarg, const typename decltype(pick)::value_type& a) { return rarg < a.first; });
+  auto p = upper_bound(pick.begin(), pick.end(), r, [](unsigned int rarg, const typename decltype(pick)::value_type& a) { return rarg < a.first; });
   return p->second;
 }
 
@@ -716,13 +716,13 @@ static vector<string> convStringList(const iplist_t& items)
   return result;
 }
 
-static vector< pair<int, string> > convIntStringPairList(const std::unordered_map<int, wiplist_t >& items)
+static vector< pair<unsigned int, string> > convIntStringPairList(const std::unordered_map<int, wiplist_t >& items)
 {
-  vector<pair<int,string> > result;
+  vector<pair<unsigned int,string> > result;
   result.reserve(items.size());
 
   for(const auto& item : items) {
-    result.emplace_back(atoi(item.second.at(1).c_str()), item.second.at(2));
+    result.emplace_back(pdns::checked_stoi_nonzero<unsigned int>(item.second.at(1).c_str()), item.second.at(2));
   }
 
   return result;
@@ -812,7 +812,7 @@ static void cleanZoneHashes()
   }
 }
 
-static std::vector<std::shared_ptr<EntryHashesHolder>> getCHashedEntries(const domainid_t zoneId, const std::string& queryName, const std::vector<std::pair<int, std::string>>& items)
+static std::vector<std::shared_ptr<EntryHashesHolder>> getCHashedEntries(const domainid_t zoneId, const std::string& queryName, const std::vector<std::pair<unsigned int, std::string>>& items)
 {
   std::vector<std::shared_ptr<EntryHashesHolder>> result{};
   std::map<zone_hashes_key_t, std::shared_ptr<EntryHashesHolder>> newEntries{};
@@ -844,7 +844,7 @@ static std::vector<std::shared_ptr<EntryHashesHolder>> getCHashedEntries(const d
   return result;
 }
 
-static std::string pickConsistentWeightedHashed(const ComboAddress& bestwho, const std::vector<std::pair<int, std::string>>& items)
+static std::string pickConsistentWeightedHashed(const ComboAddress& bestwho, const std::vector<std::pair<unsigned int, std::string>>& items)
 {
   const auto& zoneId = s_lua_record_ctx->zone_record.domain_id;
   const auto queryName = s_lua_record_ctx->qname.toString();
@@ -1319,7 +1319,7 @@ static string lua_pickrandom(const iplist_t& ips)
  */
 static string lua_pickselfweighted(const std::string& url, const iplist_t& ips, boost::optional<opts_t> options)
 {
-  vector< pair<int, ComboAddress> > items;
+  vector< pair<unsigned int, ComboAddress> > items;
   opts_t opts;
   if(options) {
     opts = *options;
@@ -1364,7 +1364,7 @@ static string lua_pickhashed(const iplist_t& ips)
  */
 static string lua_pickwrandom(const std::unordered_map<int, wiplist_t>& ips)
 {
-  vector< pair<int, string> > items = convIntStringPairList(ips);
+  vector< pair<unsigned int, string> > items = convIntStringPairList(ips);
   return pickWeightedRandom<string>(items);
 }
 
@@ -1375,11 +1375,11 @@ static string lua_pickwrandom(const std::unordered_map<int, wiplist_t>& ips)
  */
 static string lua_pickwhashed(std::unordered_map<int, wiplist_t> ips)
 {
-  vector< pair<int, string> > items;
+  vector< pair<unsigned int, string> > items;
 
   items.reserve(ips.size());
   for (auto& entry : ips) {
-    items.emplace_back(atoi(entry.second[1].c_str()), entry.second[2]);
+    items.emplace_back(pdns::checked_stoi_nonzero<unsigned int>(entry.second[1].c_str()), entry.second[2]);
   }
 
   return pickWeightedHashed<string>(s_lua_record_ctx->bestwho, items);
@@ -1392,11 +1392,11 @@ static string lua_pickwhashed(std::unordered_map<int, wiplist_t> ips)
  */
 static string lua_picknamehashed(std::unordered_map<int, wiplist_t> ips)
 {
-  vector< pair<int, string> > items;
+  vector< pair<unsigned int, string> > items;
 
   items.reserve(ips.size());
   for (auto& address : ips) {
-    items.emplace_back(atoi(address.second[1].c_str()), address.second[2]);
+    items.emplace_back(pdns::checked_stoi_nonzero<unsigned int>(address.second[1].c_str()), address.second[2]);
   }
 
   return pickWeightedNameHashed<string>(s_lua_record_ctx->qname, items);
@@ -1409,11 +1409,11 @@ static string lua_picknamehashed(std::unordered_map<int, wiplist_t> ips)
  */
 static string lua_pickchashed(const std::unordered_map<int, wiplist_t>& ips)
 {
-  std::vector<std::pair<int, std::string>> items;
+  std::vector<std::pair<unsigned int, std::string>> items;
 
   items.reserve(ips.size());
   for (const auto& entry : ips) {
-    items.emplace_back(atoi(entry.second.at(1).c_str()), entry.second.at(2));
+    items.emplace_back(pdns::checked_stoi_nonzero<unsigned int>(entry.second.at(1).c_str()), entry.second.at(2));
   }
 
   return pickConsistentWeightedHashed(s_lua_record_ctx->bestwho, items);
index f0c7053d41175b50dd0ec761eb6d9f5857972cd9..9088ff4e17e1faacbf1293a8a321b38b5b91ef8b 100644 (file)
@@ -71,6 +71,8 @@ hashed.example.org.          3600 IN LUA  A     "pickhashed({{ '1.2.3.4', '4.3.2
 hashed-v6.example.org.       3600 IN LUA  AAAA  "pickhashed({{ '2001:db8:a0b:12f0::1', 'fe80::2a1:9bff:fe9b:f268' }})"
 hashed-txt.example.org.      3600 IN LUA  TXT   "pickhashed({{ 'bob', 'alice' }})"
 whashed.example.org.         3600 IN LUA  A     "pickwhashed({{ {{15, '1.2.3.4'}}, {{42, '4.3.2.1'}} }})"
+whashedzero.example.org.     3600 IN LUA  A     "pickwhashed({{ {{15, '1.2.3.4'}}, {{0, '4.3.2.1'}} }})"
+whashednegative.example.org. 3600 IN LUA  A     "pickwhashed({{ {{15, '1.2.3.4'}}, {{-3, '4.3.2.1'}} }})"
 *.namehashed.example.org.    3600 IN LUA  A     "picknamehashed({{ {{15, '1.2.3.4'}}, {{42, '4.3.2.1'}} }})"
 whashed-txt.example.org.     3600 IN LUA  TXT   "pickwhashed({{ {{15, 'bob'}}, {{42, 'alice'}} }})"
 chashed.example.org.         3600 IN LUA  A     "pickchashed({{ {{15, '1.2.3.4'}}, {{42, '4.3.2.1'}} }})"
@@ -978,6 +980,26 @@ class TestLuaRecords(BaseLuaTest):
             self.assertRcodeEqual(res, dns.rcode.NOERROR)
             self.assertRRsetInAnswer(res, first.answer[0])
 
+    def testWHashedZero(self):
+        """
+        Test that pickwhashed() does not accept zero weights
+        """
+
+        query = dns.message.make_query('whashedzero.example.org', 'A')
+
+        response = self.sendUDPQuery(query)
+        self.assertRcodeEqual(response, dns.rcode.SERVFAIL)
+
+    def testWHashedNegative(self):
+        """
+        Test that pickwhashed() does not accept negative weights
+        """
+
+        query = dns.message.make_query('whashednegative.example.org', 'A')
+
+        response = self.sendUDPQuery(query)
+        self.assertRcodeEqual(response, dns.rcode.SERVFAIL)
+
     def testTimeout(self):
         """
         Test if LUA scripts are aborted if script execution takes too long