]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
auth: lua-records, add support for pickchashed function
authorCharles-Henri Bruyand <charles-henri.bruyand@open-xchange.com>
Tue, 30 Jan 2024 15:15:20 +0000 (16:15 +0100)
committerCharles-Henri Bruyand <charles-henri.bruyand@open-xchange.com>
Fri, 9 Feb 2024 10:33:32 +0000 (11:33 +0100)
docs/lua-records/functions.rst
pdns/lua-record.cc
regression-tests.auth-py/test_LuaRecords.py

index e9528a58e4c1681ccab6b778fc1627b97d54501b..e63e9c7a112cca62515b33c6cc4cd72c341a48d5 100644 (file)
@@ -207,6 +207,27 @@ Record creation functions
 
   This function also works for CNAME or TXT records.
 
+.. function:: pickchashed(values)
+
+  Based on the hash of ``bestwho``, returns a string from the list
+  supplied, as weighted by the various ``weight`` parameters and distributed consistently.
+  Performs no uptime checking.
+
+  :param values: table of weight, string (such as IPv4 or IPv6 address).
+
+  This function works almost like :func:`pickwhashed` while bringing the following properties:
+  - reordering the list of entries won't affect the distribution
+  - updating the weight of an entry will only affect a part of the distribution
+  - because of the previous properties, the CPU and memory cost is a bit higher than :func:`pickwhashed`
+
+  An example::
+
+    mydomain.example.com    IN    LUA    A ("pickchashed({                             "
+                                            "        {15,  "192.0.2.1"},               "
+                                            "        {100, "198.51.100.5"}             "
+                                            "})                                        ")
+
+
 .. function:: pickwhashed(values)
 
   Based on the hash of ``bestwho``, returns a string from the list
@@ -271,12 +292,12 @@ Reverse DNS functions
 
 .. function:: createReverse(format, [exceptions])
 
-  Used for generating default hostnames from IPv4 wildcard reverse DNS records, e.g. ``*.0.0.127.in-addr.arpa`` 
-  
+  Used for generating default hostnames from IPv4 wildcard reverse DNS records, e.g. ``*.0.0.127.in-addr.arpa``
+
   See :func:`createReverse6` for IPv6 records (ip6.arpa)
 
   See :func:`createForward` for creating the A records on a wildcard record such as ``*.static.example.com``
-  
+
   Returns a formatted hostname based on the format string passed.
 
   :param format: A hostname string to format, for example ``%1%.%2%.%3%.%4%.static.example.com``.
@@ -297,13 +318,13 @@ Reverse DNS functions
       - ``%6`` would be ``7f00000f`` (127 is 7f, and 15 is 0f in hexadecimal)
 
   Example records::
-  
+
     *.0.0.127.in-addr.arpa IN    LUA    PTR "createReverse('%1%.%2%.%3%.%4%.static.example.com')"
     *.1.0.127.in-addr.arpa IN    LUA    PTR "createReverse('%5%.static.example.com')"
     *.2.0.127.in-addr.arpa IN    LUA    PTR "createReverse('%6%.static.example.com')"
+
   When queried::
-  
+
     # -x is syntactic sugar to request the PTR record for an IPv4/v6 address such as 127.0.0.5
     # Equivalent to dig PTR 5.0.0.127.in-addr.arpa
     $ dig +short -x 127.0.0.5 @ns1.example.com
@@ -314,44 +335,44 @@ Reverse DNS functions
     7f000205.static.example.com.
 
 .. function:: createForward()
-  
+
   Used to generate the reverse DNS domains made from :func:`createReverse`
-  
+
   Generates an A record for a dotted or hexadecimal IPv4 domain (e.g. 127.0.0.1.static.example.com)
-  
+
   It does not take any parameters, it simply interprets the zone record to find the IP address.
-  
+
   An example record for zone ``static.example.com``::
-    
+
     *.static.example.com    IN    LUA    A "createForward()"
-  
+
   This function supports the forward dotted format (``127.0.0.1.static.example.com``), and the hex format, when prefixed by two ignored characters (``ip40414243.static.example.com``)
-  
+
   When queried::
-  
+
     $ dig +short A 127.0.0.5.static.example.com @ns1.example.com
     127.0.0.5
-  
+
   Since 4.8.0: the hex format can be prefixed by any number of characters (within DNS label length limits), including zero characters (so no prefix).
 
 .. function:: createReverse6(format[, exceptions])
 
   Used for generating default hostnames from IPv6 wildcard reverse DNS records, e.g. ``*.1.0.0.2.ip6.arpa``
-  
+
   **For simplicity purposes, only small sections of IPv6 rDNS domains are used in most parts of this guide,**
   **as a full ip6.arpa record is around 80 characters long**
-  
+
   See :func:`createReverse` for IPv4 records (in-addr.arpa)
 
   See :func:`createForward6` for creating the AAAA records on a wildcard record such as ``*.static.example.com``
-  
+
   Returns a formatted hostname based on the format string passed.
 
   :param format: A hostname string to format, for example ``%33%.static6.example.com``.
   :param exceptions: An optional table of overrides. For example ``{['2001:db8::1'] = 'example.example.com.'}`` would, when generating a name for IP ``2001:db8::1``, return ``example.example.com`` instead of something like ``2001--db8.example.com``.
 
   Formatting options:
-   
+
   - ``%1%`` to ``%32%`` are individual characters (nibbles)
       - **Example PTR record query:** ``a.0.0.0.1.0.0.2.ip6.arpa``
       - ``%1%`` = 2
@@ -364,40 +385,40 @@ Reverse DNS functions
       - ``%34%`` - returns ``2001`` (chunk 1)
       - ``%35%`` - returns ``000a`` (chunk 2)
       - ``%41%`` - returns ``0123`` (chunk 8)
-  
+
   Example records::
-  
+
     *.1.0.0.2.ip6.arpa IN    LUA    PTR "createReverse6('%33%.static6.example.com')"
     *.2.0.0.2.ip6.arpa IN    LUA    PTR "createReverse6('%34%.%35%.static6.example.com')"
+
   When queried::
-  
+
     # -x is syntactic sugar to request the PTR record for an IPv4/v6 address such as 2001::1
     # Equivalent to dig PTR 1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.b.0.0.0.a.0.0.0.1.0.0.2.ip6.arpa
     # readable version:     1.0.0.0 .0.0.0.0 .0.0.0.0 .0.0.0.0 .0.0.0.0 .b.0.0.0 .a.0.0.0 .1.0.0.2 .ip6.arpa
-    
+
     $ dig +short -x 2001:a:b::1 @ns1.example.com
     2001-a-b--1.static6.example.com.
-    
+
     $ dig +short -x 2002:a:b::1 @ns1.example.com
     2002.000a.static6.example.com
 
 .. function:: createForward6()
-  
+
   Used to generate the reverse DNS domains made from :func:`createReverse6`
-  
+
   Generates an AAAA record for a dashed compressed IPv6 domain (e.g. ``2001-a-b--1.static6.example.com``)
-  
+
   It does not take any parameters, it simply interprets the zone record to find the IP address.
-  
+
   An example record for zone ``static.example.com``::
-    
+
     *.static6.example.com    IN    LUA    AAAA "createForward6()"
-  
+
   This function supports the dashed compressed format (i.e. ``2001-a-b--1.static6.example.com``), and the dot-split uncompressed format (``2001.db8.6.5.4.3.2.1.static6.example.com``)
-  
+
   When queried::
-  
+
     $ dig +short AAAA 2001-a-b--1.static6.example.com @ns1.example.com
     2001:a:b::1
 
index 90b862c8a1fe048680db936a049cbca25cdf9e1c..106e90df301212311baee59bce10ee3714b42c73 100644 (file)
@@ -1,10 +1,12 @@
 #include <thread>
 #include <future>
 #include <boost/format.hpp>
+#include <boost/uuid/string_generator.hpp>
 #include <utility>
 #include <algorithm>
 #include <random>
 #include "qtype.hh"
+#include <tuple>
 #include "version.hh"
 #include "ext/luawrapper/include/LuaContext.hpp"
 #include "lock.hh"
@@ -359,7 +361,7 @@ static T pickWeightedRandom(const vector< pair<int, T> >& items)
 }
 
 template <typename T>
-static T pickWeightedHashed(const ComboAddress& bestwho, vector< pair<int, T> >& items)
+static T pickWeightedHashed(const ComboAddress& bestwho, const vector< pair<int, T> >& items)
 {
   if (items.empty()) {
     throw std::invalid_argument("The items list cannot be empty");
@@ -651,6 +653,103 @@ typedef struct AuthLuaRecordContext
 
 static thread_local unique_ptr<lua_record_ctx_t> s_lua_record_ctx;
 
+/*
+ *  Holds computed hashes for a given entry
+ */
+struct EntryHashesHolder {
+  std::atomic<size_t> weight;
+  std::string entry;
+  SharedLockGuarded<std::vector<unsigned int>> hashes;
+
+  EntryHashesHolder(size_t weight_, std::string entry_): weight(weight_), entry(std::move(entry_)) {
+  }
+
+  bool hashesComputed() {
+    return weight == hashes.read_lock()->size();
+  }
+  void hash() {
+    auto locked = hashes.write_lock();
+    locked->clear();
+    locked->reserve(weight);
+    size_t count = 0;
+    while (count < weight) {
+      auto value = boost::str(boost::format("%s-%d") % entry % count);
+      auto whash = burtle(reinterpret_cast<const unsigned char*>(value.c_str()), value.size(), 0);
+      locked->push_back(whash);
+      ++count;
+    }
+    std::sort(locked->begin(), locked->end());
+  }
+};
+
+static std::map<
+  std::tuple<int, std::string, std::string>, // zoneid qname entry
+  std::shared_ptr<EntryHashesHolder> // entry w/ corresponding hashes
+  >
+s_zone_hashes;
+
+static std::vector<std::shared_ptr<EntryHashesHolder>> getCHashedEntries(const int zoneId, const std::string& queryName, const std::vector<std::pair<int, std::string>>& items)
+{
+  std::vector<std::shared_ptr<EntryHashesHolder>> result{};
+
+  for (const auto& [weight, entry]: items) {
+    auto key = std::make_tuple(zoneId, queryName, entry);
+    if (s_zone_hashes.count(key) == 0) {
+      s_zone_hashes[key] = std::make_shared<EntryHashesHolder>(weight, entry);
+    } else {
+      s_zone_hashes.at(key)->weight = weight;
+    }
+    result.push_back(s_zone_hashes.at(key));
+  }
+
+  return result;
+}
+
+static std::string pickConsistentWeightedHashed(const ComboAddress& bestwho, const std::vector<std::pair<int, std::string>>& items)
+{
+  const auto& zoneId = s_lua_record_ctx->zoneid;
+  const auto queryName = s_lua_record_ctx->qname.toString();
+  unsigned int sel = std::numeric_limits<unsigned int>::max();
+  unsigned int min = std::numeric_limits<unsigned int>::max();
+
+  boost::optional<std::string> ret;
+  boost::optional<std::string> first;
+
+  auto entries = getCHashedEntries(zoneId, queryName, items);
+
+  ComboAddress::addressOnlyHash addrOnlyHash;
+  auto qhash = addrOnlyHash(bestwho);
+  for (const auto& entry : entries) {
+    if (!entry->hashesComputed()) {
+      entry->hash();
+    }
+    {
+      const auto hashes = entry->hashes.read_lock();
+      if (hashes->size() > 0) {
+        if (min > *(hashes->begin())) {
+          min = *(hashes->begin());
+          first = entry->entry;
+        }
+
+        auto hash_it = std::lower_bound(hashes->begin(), hashes->end(), qhash);
+        if (hash_it != hashes->end()) {
+          if (*hash_it < sel) {
+            sel = *hash_it;
+            ret = entry->entry;
+          }
+        }
+      }
+    }
+  }
+  if (ret != boost::none) {
+    return *ret;
+  }
+  if (first != boost::none) {
+    return *first;
+  }
+  return std::string();
+}
+
 static vector<string> genericIfUp(const boost::variant<iplist_t, ipunitlist_t>& ips, boost::optional<opts_t> options, const std::function<bool(const ComboAddress&, const opts_t&)>& upcheckf, uint16_t port = 0)
 {
   vector<vector<ComboAddress> > candidates;
@@ -1025,8 +1124,9 @@ static void setupLuaRecords(LuaContext& lua) // NOLINT(readability-function-cogn
       vector< pair<int, string> > items;
 
       items.reserve(ips.size());
-      for(auto& i : ips)
+      for (auto& i : ips) {
         items.emplace_back(atoi(i.second[1].c_str()), i.second[2]);
+      }
 
       return pickWeightedHashed<string>(s_lua_record_ctx->bestwho, items);
     });
@@ -1047,6 +1147,21 @@ static void setupLuaRecords(LuaContext& lua) // NOLINT(readability-function-cogn
 
       return pickWeightedNameHashed<string>(s_lua_record_ctx->qname, items);
     });
+  /*
+   * Based on the hash of `bestwho`, returns an IP address from the list
+   * supplied, as weighted by the various `weight` parameters and distributed consistently
+   * @example pickchashed({ {15, '1.2.3.4'}, {50, '5.4.3.2'} })
+   */
+  lua.writeFunction("pickchashed", [](std::unordered_map<int, wiplist_t > ips) {
+    vector< pair<int, string> > items;
+
+    items.reserve(ips.size());
+    for (auto& i : ips) {
+      items.emplace_back(atoi(i.second[1].c_str()), i.second[2]);
+    }
+
+    return pickConsistentWeightedHashed(s_lua_record_ctx->bestwho, items);
+  });
 
   lua.writeFunction("pickclosest", [](const iplist_t& ips) {
       vector<ComboAddress> conv = convComboAddressList(ips);
index d41a29aee9cff84f1512786db145e08b6a5876fd..75fe646a5b8998708a229a233719944a973d1818 100644 (file)
@@ -70,6 +70,8 @@ hashed-txt.example.org.      3600 IN LUA  TXT   "pickhashed({{ 'bob', 'alice' }}
 whashed.example.org.         3600 IN LUA  A     "pickwhashed({{ {{15, '1.2.3.4'}}, {{42, '4.3.2.1'}} }})"
 *.namehashed.example.org.    3600 IN LUA  A     "picknamehashed({{ {{15, '1.2.3.4'}}, {{42, '4.3.2.1'}} }})"
 whashed-txt.example.org.     3600 IN LUA  TXT   "pickwhashed({{ {{15, 'bob'}}, {{42, 'alice'}} }})"
+chashed.example.org.         3600 IN LUA  A     "pickchashed({{ {{15, '1.2.3.4'}}, {{42, '4.3.2.1'}} }})"
+chashed-txt.example.org.     3600 IN LUA  TXT   "pickchashed({{ {{15, 'joh'}}, {{42, 'do'}} }})"
 rand.example.org.            3600 IN LUA  A     "pickrandom({{'{prefix}.101', '{prefix}.102'}})"
 rand-txt.example.org.        3600 IN LUA  TXT   "pickrandom({{ 'bob', 'alice' }})"
 randn-txt.example.org.       3600 IN LUA  TXT   "pickrandomsample( 2, {{ 'bob', 'alice', 'john' }} )"
@@ -777,24 +779,23 @@ createforward6.example.org.                 3600 IN NS   ns2.example.org.
         self.assertRcodeEqual(res, dns.rcode.SERVFAIL)
         self.assertAnswerEmpty(res)
 
-
-    def testWHashed(self):
+    def testCWHashed(self):
         """
-        Basic pickwhashed() test with a set of A records
+        Basic pickwhashed() and pickchashed() test with a set of A records
         As the `bestwho` is hashed, we should always get the same answer
         """
-        expected = [dns.rrset.from_text('whashed.example.org.', 0, dns.rdataclass.IN, 'A', '1.2.3.4'),
-                    dns.rrset.from_text('whashed.example.org.', 0, dns.rdataclass.IN, 'A', '4.3.2.1')]
-        query = dns.message.make_query('whashed.example.org', 'A')
-
-        first = self.sendUDPQuery(query)
-        self.assertRcodeEqual(first, dns.rcode.NOERROR)
-        self.assertAnyRRsetInAnswer(first, expected)
-        for _ in range(5):
-            res = self.sendUDPQuery(query)
-            self.assertRcodeEqual(res, dns.rcode.NOERROR)
-            self.assertRRsetInAnswer(res, first.answer[0])
+        for qname in ['whashed.example.org.', 'chashed.example.org.']:
+            expected = [dns.rrset.from_text(qname, 0, dns.rdataclass.IN, 'A', '1.2.3.4'),
+                        dns.rrset.from_text(qname, 0, dns.rdataclass.IN, 'A', '4.3.2.1')]
+            query = dns.message.make_query(qname, 'A')
 
+            first = self.sendUDPQuery(query)
+            self.assertRcodeEqual(first, dns.rcode.NOERROR)
+            self.assertAnyRRsetInAnswer(first, expected)
+            for _ in range(5):
+                res = self.sendUDPQuery(query)
+                self.assertRcodeEqual(res, dns.rcode.NOERROR)
+                self.assertRRsetInAnswer(res, first.answer[0])
 
     def testNamehashed(self):
         """
@@ -821,23 +822,23 @@ createforward6.example.org.                 3600 IN NS   ns2.example.org.
             self.assertRcodeEqual(res, dns.rcode.NOERROR)
             self.assertRRsetInAnswer(res, query['expected'])
 
-
-    def testWHashedTxt(self):
+    def testCWHashedTxt(self):
         """
         Basic pickwhashed() test with a set of TXT records
         As the `bestwho` is hashed, we should always get the same answer
         """
-        expected = [dns.rrset.from_text('whashed-txt.example.org.', 0, dns.rdataclass.IN, 'TXT', 'bob'),
-                    dns.rrset.from_text('whashed-txt.example.org.', 0, dns.rdataclass.IN, 'TXT', 'alice')]
-        query = dns.message.make_query('whashed-txt.example.org', 'TXT')
+        for qname in ['whashed-txt.example.org.', 'chashed-txt.example.org.']:
+            expected = [dns.rrset.from_text(qname, 0, dns.rdataclass.IN, 'TXT', 'bob'),
+                        dns.rrset.from_text(qname, 0, dns.rdataclass.IN, 'TXT', 'alice')]
+            query = dns.message.make_query(qname,'TXT')
 
-        first = self.sendUDPQuery(query)
-        self.assertRcodeEqual(first, dns.rcode.NOERROR)
-        self.assertAnyRRsetInAnswer(first, expected)
-        for _ in range(5):
-            res = self.sendUDPQuery(query)
-            self.assertRcodeEqual(res, dns.rcode.NOERROR)
-            self.assertRRsetInAnswer(res, first.answer[0])
+            first = self.sendUDPQuery(query)
+            self.assertRcodeEqual(first, dns.rcode.NOERROR)
+            self.assertAnyRRsetInAnswer(first, expected)
+            for _ in range(5):
+                res = self.sendUDPQuery(query)
+                self.assertRcodeEqual(res, dns.rcode.NOERROR)
+                self.assertRRsetInAnswer(res, first.answer[0])
 
     def testHashed(self):
         """