]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
Extend LUA records
authorrage4 <office@gbshouse.com>
Thu, 28 Apr 2022 18:57:30 +0000 (20:57 +0200)
committerPiotr Ginalski <office@gbshouse.com>
Fri, 17 Jun 2022 08:41:59 +0000 (10:41 +0200)
.github/actions/spell-check/expect.txt
docs/lua-records/functions.rst
docs/lua-records/index.rst
modules/geoipbackend/regression-tests/GeoLiteCity.mmdb
modules/geoipbackend/regression-tests/write-mmdb.pl
pdns/lua-record.cc
regression-tests.auth-py/test_LuaRecords.py

index 36ab04741342a87fc3a84ca13c12e047f1f8ce43..958e0d99b22e1c50ee57a84760f7419a9a87b10b 100644 (file)
@@ -1243,6 +1243,8 @@ phonedph
 php
 pickclosest
 pickrandom
+pickrandomsample
+pickhashed
 pickwhashed
 pickwrandom
 pid
index ce42be063706096d27b89da61a1e6aa3d38165ef..0347b02faa14ff08bbf5e5a892d56ade0817e02c 100644 (file)
@@ -117,11 +117,30 @@ Record creation functions
   The 404s will cause the first group of IPs to get marked as down, after which the URL in the second group is tested.
   The third IP will get marked up assuming ``https://example.net/`` responds with HTTP response code 200.
 
-.. function:: pickrandom(addresses)
+.. function:: pickrandom(values)
 
-  Returns a random IP address from the list supplied.
+  Returns a random value from the list supplied.
 
-  :param addresses: A list of strings with the possible IP addresses.
+  :param values: A list of strings such as IPv4 or IPv6 address.
+
+  This function also works for CNAME or TXT records.
+
+.. function:: pickrandomsample(number, values)
+
+  Returns N random values from the list supplied.
+
+  :param number: Number of values to return
+  :param values: A list of strings such as IPv4 or IPv6 address.
+
+  This function also works for CNAME or TXT records.
+
+.. function:: pickhashed(values)
+
+  Based on the hash of ``bestwho``, returns a random value from the list supplied.
+
+  :param values: A list of strings such as IPv4 or IPv6 address.
+
+  This function also works for CNAME or TXT records.
 
 .. function:: pickclosest(addresses)
 
@@ -159,6 +178,14 @@ Record creation functions
 
   Performs no uptime checking.
 
+.. function:: all(values)
+
+  Returns all values.
+
+  :param values: A list of strings such as IPv4 or IPv6 address.
+
+  This function also works for CNAME or TXT records.
+
 .. function:: view(pairs)
 
   Shorthand function to implement 'views' for all record types.
@@ -177,18 +204,20 @@ Record creation functions
 
   This function also works for CNAME or TXT records.
 
-.. function:: pickwhashed(weightparams)
+.. function:: pickwhashed(values)
 
-  Based on the hash of ``bestwho``, returns an IP address from the list
+  Based on the hash of ``bestwho``, returns a string from the list
   supplied, as weighted by the various ``weight`` parameters.
   Performs no uptime checking.
 
-  :param weightparams: table of weight, IP addresses.
+  :param values: table of weight, string (such as IPv4 or IPv6 address).
 
   Because of the hash, the same client keeps getting the same answer, but
   given sufficient clients, the load is still spread according to the weight
   factors.
 
+  This function also works for CNAME or TXT records.
+
   An example::
 
     mydomain.example.com    IN    LUA    A ("pickwhashed({                             "
@@ -197,15 +226,17 @@ Record creation functions
                                             "})                                        ")
 
 
-.. function:: pickwrandom(weightparams)
+.. function:: pickwrandom(values)
 
-  Returns a random IP address from the list supplied, as weighted by the
+  Returns a random string from the list supplied, as weighted by the
   various ``weight`` parameters. Performs no uptime checking.
 
-  :param weightparams: table of weight, IP addresses.
+  :param values: table of weight, string (such as IPv4 or IPv6 address).
 
   See :func:`pickwhashed` for an example.
 
+  This function also works for CNAME or TXT records.
+
 Reverse DNS functions
 ~~~~~~~~~~~~~~~~~~~~~
 
@@ -378,8 +409,25 @@ Helper functions
   :param string country: A country code like "NL"
   :param [string] countries: A list of country codes
 
-.. function:: continent(continent)
-              continent(continents)
+.. function:: countryCode()
+
+  Returns two letter ISO country code based ``bestwho`` IP address, as described in :doc:`../backends/geoip`.
+  If the two letter ISO country code is unknown "--" will be returned.
+
+.. function:: region()
+
+  Returns true if the ``bestwho`` IP address of the client is within the
+  two letter ISO region code passed, as described in :doc:`../backends/geoip`.
+
+  :param string region: A region code like "CA"
+  :param [string] regions: A list of regions codes
+
+.. function:: regionCode()
+
+  Returns two letter ISO region code based ``bestwho`` IP address, as described in :doc:`../backends/geoip`.
+  If the two letter ISO region code is unknown "--" will be returned.
+
+.. function:: continent()
 
   Returns true if the ``bestwho`` IP address of the client is within the
   continent passed, as described in :doc:`../backends/geoip`.
@@ -387,6 +435,11 @@ Helper functions
   :param string continent: A continent code like "EU"
   :param [string] continents: A list of continent codes
 
+.. function:: continentCode()
+
+  Returns two letter ISO continent code based ``bestwho`` IP address, as described in :doc:`../backends/geoip`.
+  If the two letter ISO continent code is unknown "--" will be returned.
+
 .. function:: netmask(netmasks)
 
   Returns true if ``bestwho`` is within any of the listed subnets.
index dad201f082c6adf8fbaa5c66a1bf4d017b601ac4..37f46e9cc3eaf7e6e083b28b42e3838b3f0b7c6c 100644 (file)
@@ -68,6 +68,12 @@ addresses.
 
 This will pick from the viable IP addresses the one deemed closest to the user.
 
+LUA records can also contain more complex code, for example::
+
+    www    IN    LUA    A    ";if countryCode('US') then return {'192.0.2.1','192.0.2.2','198.51.100.1'} else return '192.0.2.2' end"
+
+As you can see you can return both single string value or array of strings. 
+
 Using LUA Records with Generic SQL backends
 -------------------------------------------
 
index 2ef58894c49f23805e26086c78c5da6b43d74736..fc6cc78260e5c1edf28e13347f532179178dd2f3 100644 (file)
Binary files a/modules/geoipbackend/regression-tests/GeoLiteCity.mmdb and b/modules/geoipbackend/regression-tests/GeoLiteCity.mmdb differ
index 0b90fbedc6c5983ab42bc755d630b50343c4d530..8a98d90d637ae6035e2ebb1365a96361afab7ea6 100644 (file)
@@ -49,6 +49,7 @@ $tree->insert_network(
       'location' => { "latitude" => 47.913000, "longitude" => -122.304200, accuracy_radius => 1 },
       'autonomous_system_number' => 3320,
       'autonomous_system_organization' => "Test Networks",
+         'subdivisions' => [{ "geoname_id" => 5332921, "iso_code" => "CA", "names" => { "en" => "California" } }]
     }
 );
 
index 3571facf89b613c8ac6d79fd4037f65d8912f221..fa8de9e612fc1156c325211756d3bf42bec715a6 100644 (file)
@@ -2,6 +2,8 @@
 #include <future>
 #include <boost/format.hpp>
 #include <utility>
+#include <algorithm>
+#include <random>
 #include "version.hh"
 #include "ext/luawrapper/include/LuaContext.hpp"
 #include "lock.hh"
@@ -309,61 +311,165 @@ static std::string getGeo(const std::string& ip, GeoIPInterface::GeoIPQueryAttri
     return g_getGeo(ip, (int)qa);
 }
 
-static ComboAddress pickrandom(const vector<ComboAddress>& ips)
+static string pickRandomString(const vector<string>& items)
 {
-  if (ips.empty()) {
-    throw std::invalid_argument("The IP list cannot be empty");
+  if (items.empty()) {
+    throw std::invalid_argument("The items list cannot be empty");
   }
-  return ips[dns_random(ips.size())];
+  return items[dns_random(items.size())];
 }
 
-static ComboAddress hashed(const ComboAddress& who, const vector<ComboAddress>& ips)
+static ComboAddress pickRandomComboAddress(const vector<ComboAddress>& items)
 {
-  if (ips.empty()) {
-    throw std::invalid_argument("The IP list cannot be empty");
+  if (items.empty()) {
+    throw std::invalid_argument("The items list cannot be empty");
+  }
+  return items[dns_random(items.size())];
+}
+
+static string pickHashedString(const ComboAddress& who, const vector<string>& items)
+{
+  if (items.empty()) {
+    throw std::invalid_argument("The items list cannot be empty");
   }
   ComboAddress::addressOnlyHash aoh;
-  return ips[aoh(who) % ips.size()];
+  return items[aoh(who) % items.size()];
 }
 
+static ComboAddress pickHashedComboAddress(const ComboAddress& who, const vector<ComboAddress>& items)
+{
+  if (items.empty()) {
+    throw std::invalid_argument("The items list cannot be empty");
+  }
+  ComboAddress::addressOnlyHash aoh;
+  return items[aoh(who) % items.size()];
+}
 
-static ComboAddress pickwrandom(const vector<pair<int,ComboAddress> >& wips)
+static string pickWeightedRandomString(const vector< pair<int, string> >& items)
 {
-  if (wips.empty()) {
-    throw std::invalid_argument("The IP list cannot be empty");
+  if (items.empty()) {
+    throw std::invalid_argument("The items list cannot be empty");
   }
   int sum=0;
-  vector<pair<int, ComboAddress> > pick;
-  for(auto& i : wips) {
+  vector< pair<int, string> > pick;
+  pick.reserve(items.size());
+
+  for(auto& i : items) {
     sum += i.first;
     pick.emplace_back(sum, i.second);
   }
+  
+  if (sum == 0) {
+    /* we should not have any weight of zero, but better safe than sorry */
+    return std::string();
+  }
+  
   int r = dns_random(sum);
   auto p = upper_bound(pick.begin(), pick.end(), r, [](int rarg, const decltype(pick)::value_type& a) { return rarg < a.first; });
   return p->second;
 }
 
-static ComboAddress pickwhashed(const ComboAddress& bestwho, vector<pair<int,ComboAddress> >& wips)
+static ComboAddress pickWeightedRandomComboAddress(const vector< pair<int, ComboAddress> >& items)
 {
-  if (wips.empty()) {
+  if (items.empty()) {
+    throw std::invalid_argument("The items list cannot be empty");
+  }
+  int sum=0;
+  vector< pair<int, ComboAddress> > pick;
+  pick.reserve(items.size());
+
+  for(auto& i : items) {
+    sum += i.first;
+    pick.emplace_back(sum, ComboAddress(i.second));
+  }
+  
+  if (sum == 0) {
+    /* we should not have any weight of zero, but better safe than sorry */
     return ComboAddress();
   }
+  
+  int r = dns_random(sum);
+  auto p = upper_bound(pick.begin(), pick.end(), r, [](int rarg, const decltype(pick)::value_type& a) { return rarg < a.first; });
+  return p->second;
+}
+
+static string pickWeightedHashedString(const ComboAddress& bestwho, vector< pair<int, string> >& items)
+{
+  if (items.empty()) {
+    throw std::invalid_argument("The items list cannot be empty");
+  }
   int sum=0;
-  vector<pair<int, ComboAddress> > pick;
-  for(auto& i : wips) {
+  vector< pair<int, string> > pick;
+  pick.reserve(items.size());
+
+  for(auto& i : items) {
     sum += i.first;
     pick.push_back({sum, i.second});
   }
+  
+  if (sum == 0) {
+    /* we should not have any weight of zero, but better safe than sorry */
+    return std::string();
+  }
+
+  ComboAddress::addressOnlyHash aoh;
+  int r = aoh(bestwho) % sum;
+  auto p = upper_bound(pick.begin(), pick.end(), r, [](int rarg, const decltype(pick)::value_type& a) { return rarg < a.first; });
+  return p->second;
+}
+
+static ComboAddress pickWeightedHashedComboAddress(const ComboAddress& bestwho, vector< pair<int, ComboAddress> >& items)
+{
+  if (items.empty()) {
+    throw std::invalid_argument("The items list cannot be empty");
+  }
+  int sum=0;
+  vector< pair<int, ComboAddress> > pick;
+  pick.reserve(items.size());
+  
+  for(auto& i : items) {
+    sum += i.first;
+    pick.push_back({sum, ComboAddress(i.second)});
+  }
+  
   if (sum == 0) {
     /* we should not have any weight of zero, but better safe than sorry */
     return ComboAddress();
   }
+
   ComboAddress::addressOnlyHash aoh;
   int r = aoh(bestwho) % sum;
   auto p = upper_bound(pick.begin(), pick.end(), r, [](int rarg, const decltype(pick)::value_type& a) { return rarg < a.first; });
   return p->second;
 }
 
+static vector<string> pickRandomStrings(int n, const vector<string>& items) 
+{
+  if (items.empty()) {
+    throw std::invalid_argument("The items list cannot be empty");
+  }
+  
+  vector<string> pick;
+  pick.reserve(items.size());
+  
+  for(auto& item : items) {
+    pick.push_back(item);
+  }
+  
+  int count = std::min(std::max<size_t>(0, n), items.size());
+
+  if (count == 0) {
+    return vector<string>();
+  }  
+
+  auto rdev = std::random_device {}; 
+  auto reng = std::default_random_engine { rdev() };
+  std::shuffle(pick.begin(), pick.end(), reng);
+  
+  vector<string> result = {pick.begin(), pick.begin() + count};
+  return result;
+}
+
 static bool getLatLon(const std::string& ip, double& lat, double& lon)
 {
   string inp = getGeo(ip, GeoIPInterface::Location);
@@ -431,7 +537,7 @@ static ComboAddress pickclosest(const ComboAddress& bestwho, const vector<ComboA
   if (wips.empty()) {
     throw std::invalid_argument("The IP list cannot be empty");
   }
-  map<double,vector<ComboAddress> > ranked;
+  map<double, vector<ComboAddress> > ranked;
   double wlat=0, wlon=0;
   getLatLon(bestwho.toString(), wlat, wlon);
   //        cout<<"bestwho "<<wlat<<", "<<wlon<<endl;
@@ -484,52 +590,77 @@ static vector<ComboAddress> useSelector(const std::string &selector, const Combo
   if(selector=="all")
     return candidates;
   else if(selector=="random")
-    ret.emplace_back(pickrandom(candidates));
+    ret.emplace_back(pickRandomComboAddress(candidates));
   else if(selector=="pickclosest")
     ret.emplace_back(pickclosest(bestwho, candidates));
   else if(selector=="hashed")
-    ret.emplace_back(hashed(bestwho, candidates));
+    ret.emplace_back(pickHashedComboAddress(bestwho, candidates));
   else {
     g_log<<Logger::Warning<<"LUA Record called with unknown selector '"<<selector<<"'"<<endl;
-    ret.emplace_back(pickrandom(candidates));
+    ret.emplace_back(pickRandomComboAddress(candidates));
   }
 
   return ret;
 }
 
-static vector<string> convIpListToString(const vector<ComboAddress> &comboAddresses)
+static vector<string> convComboAddressListToString(const vector<ComboAddress>& items)
 {
-  vector<string> ret;
+  vector<string> result;
+  result.reserve(items.size());
 
-  ret.reserve(comboAddresses.size());
-  for (const auto& c : comboAddresses) {
-    ret.emplace_back(c.toString());
+  for (const auto& item : items) {
+    result.emplace_back(item.toString());
   }
 
-  return ret;
+  return result;
 }
 
-static vector<ComboAddress> convIplist(const iplist_t& src)
+static vector<ComboAddress> convComboAddressList(const iplist_t& items)
 {
-  vector<ComboAddress> ret;
+  vector<ComboAddress> result;
+  result.reserve(items.size());
 
-  for(const auto& ip : src) {
-    ret.emplace_back(ip.second);
+  for(const auto& item : items) {
+    result.emplace_back(ComboAddress(item.second));
   }
 
-  return ret;
+  return result;
 }
 
-static vector<pair<int, ComboAddress> > convWIplist(const std::unordered_map<int, wiplist_t >& src)
+static vector<string> convStringList(const iplist_t& items)
 {
-  vector<pair<int,ComboAddress> > ret;
+  vector<string> result;
+  result.reserve(items.size());
 
-  ret.reserve(src.size());
-  for(const auto& i : src) {
-    ret.emplace_back(atoi(i.second.at(1).c_str()), ComboAddress(i.second.at(2)));
+  for(const auto& item : items) {
+    result.emplace_back(item.second);
   }
 
-  return ret;
+  return result;
+}
+
+static vector< pair<int, ComboAddress> > convIntComboAddressList(const std::unordered_map<int, wiplist_t >& items)
+{
+  vector< pair<int,ComboAddress> > result;
+  result.reserve(items.size());
+  
+  for(const auto& item : items) {
+    result.emplace_back(atoi(item.second.at(1).c_str()), ComboAddress(item.second.at(2)));
+  }
+
+  return result;
+}
+
+static vector< pair<int, string> > convIntStringPairList(const std::unordered_map<int, wiplist_t >& items)
+{
+  vector<pair<int,string> > result;
+  result.reserve(items.size());
+
+  for(const auto& item : items) {
+    result.emplace_back(atoi(item.second.at(1).c_str()), item.second.at(2));
+  }
+
+  return result;
 }
 
 static thread_local unique_ptr<AuthLua4> s_LUA;
@@ -792,7 +923,7 @@ static void setupLuaRecords()
       }
 
       vector<ComboAddress> res = useSelector(selector, s_lua_record_ctx->bestwho, candidates);
-      return convIpListToString(res);
+      return convComboAddressListToString(res);
     });
 
   lua.writeFunction("ifurlextup", [](const vector<pair<int, opts_t> >& ipurls, boost::optional<opts_t> options) {
@@ -819,13 +950,13 @@ static void setupLuaRecords()
         }
         if(!available.empty()) {
           vector<ComboAddress> res = useSelector(getOptionValue(options, "selector", "random"), s_lua_record_ctx->bestwho, available);
-          return convIpListToString(res);
+          return convComboAddressListToString(res);
         }
       }
 
       // All units down, apply backupSelector on all candidates
       vector<ComboAddress> res = useSelector(getOptionValue(options, "backupSelector", "random"), s_lua_record_ctx->bestwho, candidates);
-      return convIpListToString(res);
+      return convComboAddressListToString(res);
     });
 
   lua.writeFunction("ifurlup", [](const std::string& url,
@@ -836,12 +967,12 @@ static void setupLuaRecords()
       if(options)
         opts = *options;
       if(auto simple = boost::get<iplist_t>(&ips)) {
-        vector<ComboAddress> unit = convIplist(*simple);
+        vector<ComboAddress> unit = convComboAddressList(*simple);
         candidates.push_back(unit);
       } else {
         auto units = boost::get<ipunitlist_t>(ips);
         for(const auto& u : units) {
-          vector<ComboAddress> unit = convIplist(u.second);
+          vector<ComboAddress> unit = convComboAddressList(u.second);
           candidates.push_back(unit);
         }
       }
@@ -855,7 +986,7 @@ static void setupLuaRecords()
         }
         if(!available.empty()) {
           vector<ComboAddress> res = useSelector(getOptionValue(options, "selector", "random"), s_lua_record_ctx->bestwho, available);
-          return convIpListToString(res);
+          return convComboAddressListToString(res);
         }
       }
 
@@ -866,28 +997,34 @@ static void setupLuaRecords()
       }
 
       vector<ComboAddress> res = useSelector(getOptionValue(options, "backupSelector", "random"), s_lua_record_ctx->bestwho, ret);
-      return convIpListToString(res);
+      return convComboAddressListToString(res);
     });
   /*
    * Returns a random IP address from the supplied list
    * @example pickrandom({ '1.2.3.4', '5.4.3.2' })"
    */
   lua.writeFunction("pickrandom", [](const iplist_t& ips) {
-      vector<ComboAddress> conv = convIplist(ips);
-
-      return pickrandom(conv).toString();
+      vector<string> items = convStringList(ips);
+      return pickRandomString(items);
     });
 
+  lua.writeFunction("pickrandomsample", [](int n, const iplist_t& ips) {
+      vector<string> items = convStringList(ips);
+         return pickRandomStrings(n, items);
+    });
 
+  lua.writeFunction("pickhashed", [](const iplist_t& ips) {
+      vector<string> items = convStringList(ips);
+      return pickHashedString(s_lua_record_ctx->bestwho, items);
+    });
   /*
    * Returns a random IP address from the supplied list, as weighted by the
    * various ``weight`` parameters
    * @example pickwrandom({ {100, '1.2.3.4'}, {50, '5.4.3.2'}, {1, '192.168.1.0'} })
    */
   lua.writeFunction("pickwrandom", [](std::unordered_map<int, wiplist_t> ips) {
-      vector<pair<int,ComboAddress> > conv = convWIplist(ips);
-
-      return pickwrandom(conv).toString();
+      vector< pair<int, string> > items = convIntStringPairList(ips);
+      return pickWeightedRandomString(items);
     });
 
   /*
@@ -896,18 +1033,18 @@ static void setupLuaRecords()
    * @example pickwhashed({ {15, '1.2.3.4'}, {50, '5.4.3.2'} })
    */
   lua.writeFunction("pickwhashed", [](std::unordered_map<int, wiplist_t > ips) {
-      vector<pair<int,ComboAddress> > conv;
+      vector< pair<int, string> > items;
 
-      conv.reserve(ips.size());
+      items.reserve(ips.size());
       for(auto& i : ips)
-        conv.emplace_back(atoi(i.second[1].c_str()), ComboAddress(i.second[2]));
+        items.emplace_back(atoi(i.second[1].c_str()), i.second[2]);
 
-      return pickwhashed(s_lua_record_ctx->bestwho, conv).toString();
+      return pickWeightedHashedString(s_lua_record_ctx->bestwho, items);
     });
 
 
   lua.writeFunction("pickclosest", [](const iplist_t& ips) {
-      vector<ComboAddress > conv = convIplist(ips);
+      vector<ComboAddress> conv = convComboAddressList(ips);
 
       return pickclosest(s_lua_record_ctx->bestwho, conv).toString();
 
@@ -926,25 +1063,57 @@ static void setupLuaRecords()
   });
 
   typedef const boost::variant<string,vector<pair<int,string> > > combovar_t;
+
+  lua.writeFunction("asnum", [](const combovar_t& asns) {
+      string res=getGeo(s_lua_record_ctx->bestwho.toString(), GeoIPInterface::ASn);
+      return doCompare(asns, res, [](const std::string& a, const std::string& b) {
+          return !strcasecmp(a.c_str(), b.c_str());
+        });
+    });
   lua.writeFunction("continent", [](const combovar_t& continent) {
      string res=getGeo(s_lua_record_ctx->bestwho.toString(), GeoIPInterface::Continent);
       return doCompare(continent, res, [](const std::string& a, const std::string& b) {
           return !strcasecmp(a.c_str(), b.c_str());
         });
     });
-  lua.writeFunction("asnum", [](const combovar_t& asns) {
-      string res=getGeo(s_lua_record_ctx->bestwho.toString(), GeoIPInterface::ASn);
-      return doCompare(asns, res, [](const std::string& a, const std::string& b) {
+  lua.writeFunction("continentCode", []() {
+      string unknown("unknown");
+      string res = getGeo(s_lua_record_ctx->bestwho.toString(), GeoIPInterface::Continent);
+      if ( res == unknown ) {
+       return std::string("--");
+      }
+      return res;
+    });
+  lua.writeFunction("country", [](const combovar_t& var) {
+      string res = getGeo(s_lua_record_ctx->bestwho.toString(), GeoIPInterface::Country2);
+      return doCompare(var, res, [](const std::string& a, const std::string& b) {
           return !strcasecmp(a.c_str(), b.c_str());
         });
+
     });
-  lua.writeFunction("country", [](const combovar_t& var) {
+  lua.writeFunction("countryCode", []() {
+      string unknown("unknown");
       string res = getGeo(s_lua_record_ctx->bestwho.toString(), GeoIPInterface::Country2);
+      if ( res == unknown ) {
+       return std::string("--");
+      }
+      return res;
+    });
+  lua.writeFunction("region", [](const combovar_t& var) {
+      string res = getGeo(s_lua_record_ctx->bestwho.toString(), GeoIPInterface::Region);
       return doCompare(var, res, [](const std::string& a, const std::string& b) {
           return !strcasecmp(a.c_str(), b.c_str());
         });
 
     });
+  lua.writeFunction("regionCode", []() {
+      string unknown("unknown");
+      string res = getGeo(s_lua_record_ctx->bestwho.toString(), GeoIPInterface::Region);
+      if ( res == unknown ) {
+       return std::string("--");
+      }
+      return res;
+    });
   lua.writeFunction("netmask", [](const iplist_t& ips) {
       for(const auto& i :ips) {
         Netmask nm(i.second);
@@ -978,9 +1147,20 @@ static void setupLuaRecords()
         }
       }
       return std::string();
-    }
-    );
+    });
 
+  lua.writeFunction("all", [](const vector< pair<int,string> >& ips) {
+      vector<string> result;
+         result.reserve(ips.size());
+         
+      for(const auto& ip : ips) {
+          result.emplace_back(ip.second);
+      }
+      if(result.empty()) {
+        throw std::invalid_argument("The IP list cannot be empty");
+      }
+      return result;
+    });
 
   lua.writeFunction("include", [&lua](string record) {
       DNSName rec;
index 409d3951efee1f6134a10cbcb3876388a8480b90..34fc00296ee5a1493ab99ded455c9095d651b3f1 100644 (file)
@@ -59,14 +59,22 @@ some.ifportup                3600 IN LUA  A     "ifportup(8080, {{'192.168.42.21
 none.ifportup                3600 IN LUA  A     "ifportup(8080, {{'192.168.42.21', '192.168.21.42'}})"
 all.noneup.ifportup          3600 IN LUA  A     "ifportup(8080, {{'192.168.42.21', '192.168.21.42'}}, {{ backupSelector='all' }})"
 
+hashed.example.org.          3600 IN LUA  A     "pickhashed({{ '1.2.3.4', '4.3.2.1' }})"
+hashed-v6.example.org.       3600 IN LUA  AAAA  "pickhashed({{ '2001:db8:a0b:12f0::1', 'fe80::2a1:9bff:fe9b:f268' }})"
+hashed-txt.example.org.      3600 IN LUA  TXT   "pickhashed({{ 'bob', 'alice' }})"
 whashed.example.org.         3600 IN LUA  A     "pickwhashed({{ {{15, '1.2.3.4'}}, {{42, '4.3.2.1'}} }})"
+whashed-txt.example.org.     3600 IN LUA  TXT   "pickwhashed({{ {{15, 'bob'}}, {{42, 'alice'}} }})"
 rand.example.org.            3600 IN LUA  A     "pickrandom({{'{prefix}.101', '{prefix}.102'}})"
+rand-txt.example.org.        3600 IN LUA  TXT   "pickrandom({{ 'bob', 'alice' }})"
+randn-txt.example.org.       3600 IN LUA  TXT   "pickrandomsample( 2, {{ 'bob', 'alice', 'john' }} )"
 v6-bogus.rand.example.org.   3600 IN LUA  AAAA  "pickrandom({{'{prefix}.101', '{prefix}.102'}})"
-v6.rand.example.org.         3600 IN LUA  AAAA  "pickrandom({{'2001:db8:a0b:12f0::1', 'fe80::2a1:9bff:fe9b:f268'}})"
-closest.geo                  3600 IN LUA  A     "pickclosest({{'1.1.1.2','1.2.3.4'}})"
+v6.rand.example.org.         3600 IN LUA  AAAA  "pickrandom({{ '2001:db8:a0b:12f0::1', 'fe80::2a1:9bff:fe9b:f268' }})"
+closest.geo                  3600 IN LUA  A     "pickclosest({{ '1.1.1.2', '1.2.3.4' }})"
 empty.rand.example.org.      3600 IN LUA  A     "pickrandom()"
 timeout.example.org.         3600 IN LUA  A     "; local i = 0 ;  while i < 1000 do pickrandom() ; i = i + 1 end return '1.2.3.4'"
 wrand.example.org.           3600 IN LUA  A     "pickwrandom({{ {{30, '{prefix}.102'}}, {{15, '{prefix}.103'}} }})"
+wrand-txt.example.org.       3600 IN LUA  TXT   "pickwrandom({{ {{30, 'bob'}}, {{15, 'alice'}} }})"
+all.example.org.             3600 IN LUA  A     "all({{'1.2.3.4','4.3.2.1'}})"
 
 config    IN    LUA    LUA ("settings={{stringmatch='Programming in Lua'}} "
                             "EUWips={{'{prefix}.101','{prefix}.102'}}      "
@@ -92,8 +100,12 @@ nl           IN    LUA    A   ( ";include('config')
                                 "return ifportup(8081, NLips) ")
 latlon.geo      IN LUA    TXT "latlon()"
 continent.geo   IN LUA    TXT ";if(continent('NA')) then return 'true' else return 'false' end"
+continent-code.geo   IN LUA    TXT ";return continentCode()"
 asnum.geo       IN LUA    TXT ";if(asnum('4242')) then return 'true' else return 'false' end"
 country.geo     IN LUA    TXT ";if(country('US')) then return 'true' else return 'false' end"
+country-code.geo     IN LUA    TXT ";return countryCode()"
+region.geo      IN LUA    TXT ";if(region('CA')) then return 'true' else return 'false' end"
+region-code.geo      IN LUA    TXT ";return regionCode()"
 latlonloc.geo   IN LUA    TXT "latlonloc()"
 
 true.netmask     IN LUA   TXT   ( ";if(netmask({{ '{prefix}.0/24' }})) "
@@ -183,6 +195,18 @@ createforward6.example.org.                 3600 IN NS   ns2.example.org.
         self.assertRcodeEqual(res, dns.rcode.NOERROR)
         self.assertAnyRRsetInAnswer(res, expected)
 
+    def testPickRandomTxt(self):
+        """
+        Basic pickrandom() test with a set of TXT records
+        """
+        expected = [dns.rrset.from_text('rand-txt.example.org.', 0, dns.rdataclass.IN, 'TXT', 'bob'),
+                    dns.rrset.from_text('rand-txt.example.org.', 0, dns.rdataclass.IN, 'TXT', 'alice')]
+        query = dns.message.make_query('rand-txt.example.org', 'TXT')
+
+        res = self.sendUDPQuery(query)
+        self.assertRcodeEqual(res, dns.rcode.NOERROR)
+        self.assertAnyRRsetInAnswer(res, expected)
+
     def testBogusV6PickRandom(self):
         """
         Test a bogus AAAA pickrandom() record  with a set of v4 addr
@@ -215,6 +239,22 @@ createforward6.example.org.                 3600 IN NS   ns2.example.org.
         res = self.sendUDPQuery(query)
         self.assertRcodeEqual(res, dns.rcode.SERVFAIL)
 
+    def testPickRandomSampleTxt(self):
+        """
+        Basic pickrandomsample() test with a set of TXT records
+        """
+        expected = [dns.rrset.from_text('randn-txt.example.org.', 0, dns.rdataclass.IN, 'TXT', 'bob', 'alice'),
+                    dns.rrset.from_text('randn-txt.example.org.', 0, dns.rdataclass.IN, 'TXT', 'bob', 'john'),
+                    dns.rrset.from_text('randn-txt.example.org.', 0, dns.rdataclass.IN, 'TXT', 'alice', 'bob'),
+                    dns.rrset.from_text('randn-txt.example.org.', 0, dns.rdataclass.IN, 'TXT', 'alice', 'john'),
+                    dns.rrset.from_text('randn-txt.example.org.', 0, dns.rdataclass.IN, 'TXT', 'john', 'bob'),
+                    dns.rrset.from_text('randn-txt.example.org.', 0, dns.rdataclass.IN, 'TXT', 'john', 'alice')]
+        query = dns.message.make_query('randn-txt.example.org', 'TXT')
+
+        res = self.sendUDPQuery(query)
+        self.assertRcodeEqual(res, dns.rcode.NOERROR)
+        self.assertIn(res.answer[0], expected)
+
     def testWRandom(self):
         """
         Basic pickwrandom() test with a set of A records
@@ -229,6 +269,18 @@ createforward6.example.org.                 3600 IN NS   ns2.example.org.
         self.assertRcodeEqual(res, dns.rcode.NOERROR)
         self.assertAnyRRsetInAnswer(res, expected)
 
+    def testWRandomTxt(self):
+        """
+        Basic pickwrandom() test with a set of TXT records
+        """
+        expected = [dns.rrset.from_text('wrand-txt.example.org.', 0, dns.rdataclass.IN, 'TXT', 'bob'),
+                    dns.rrset.from_text('wrand-txt.example.org.', 0, dns.rdataclass.IN, 'TXT', 'alice')]
+        query = dns.message.make_query('wrand-txt.example.org', 'TXT')
+
+        res = self.sendUDPQuery(query)
+        self.assertRcodeEqual(res, dns.rcode.NOERROR)
+        self.assertAnyRRsetInAnswer(res, expected)
+
     def testIfportup(self):
         """
         Basic ifportup() test
@@ -475,6 +527,63 @@ createforward6.example.org.                 3600 IN NS   ns2.example.org.
             self.assertRcodeEqual(res, dns.rcode.NOERROR)
             self.assertRRsetInAnswer(res, expected)
 
+    def testCountryCode(self):
+        """
+        Basic countryCode() test
+        """
+        queries = [
+            ('1.1.1.0', 24,  '"au"'),
+            ('1.2.3.0', 24,  '"us"'),
+            ('17.1.0.0', 16, '"--"')
+        ]
+        name = 'country-code.geo.example.org.'
+        for (subnet, mask, txt) in queries:
+            ecso = clientsubnetoption.ClientSubnetOption(subnet, mask)
+            query = dns.message.make_query(name, 'TXT', 'IN', use_edns=True, payload=4096, options=[ecso])
+            expected = dns.rrset.from_text(name, 0, dns.rdataclass.IN, 'TXT', txt)
+
+            res = self.sendUDPQuery(query)
+            self.assertRcodeEqual(res, dns.rcode.NOERROR)
+            self.assertRRsetInAnswer(res, expected)
+
+    def testRegion(self):
+        """
+        Basic region() test
+        """
+        queries = [
+            ('1.1.1.0', 24,  '"false"'),
+            ('1.2.3.0', 24,  '"true"'),
+            ('17.1.0.0', 16, '"false"')
+        ]
+        name = 'region.geo.example.org.'
+        for (subnet, mask, txt) in queries:
+            ecso = clientsubnetoption.ClientSubnetOption(subnet, mask)
+            query = dns.message.make_query(name, 'TXT', 'IN', use_edns=True, payload=4096, options=[ecso])
+            expected = dns.rrset.from_text(name, 0, dns.rdataclass.IN, 'TXT', txt)
+
+            res = self.sendUDPQuery(query)
+            self.assertRcodeEqual(res, dns.rcode.NOERROR)
+            self.assertRRsetInAnswer(res, expected)
+
+    def testRegionCode(self):
+        """
+        Basic regionCode() test
+        """
+        queries = [
+            ('1.1.1.0', 24,  '"--"'),
+            ('1.2.3.0', 24,  '"ca"'),
+            ('17.1.0.0', 16, '"--"')
+        ]
+        name = 'region-code.geo.example.org.'
+        for (subnet, mask, txt) in queries:
+            ecso = clientsubnetoption.ClientSubnetOption(subnet, mask)
+            query = dns.message.make_query(name, 'TXT', 'IN', use_edns=True, payload=4096, options=[ecso])
+            expected = dns.rrset.from_text(name, 0, dns.rdataclass.IN, 'TXT', txt)
+
+            res = self.sendUDPQuery(query)
+            self.assertRcodeEqual(res, dns.rcode.NOERROR)
+            self.assertRRsetInAnswer(res, expected)
+
     def testContinent(self):
         """
         Basic continent() test
@@ -494,6 +603,25 @@ createforward6.example.org.                 3600 IN NS   ns2.example.org.
             self.assertRcodeEqual(res, dns.rcode.NOERROR)
             self.assertRRsetInAnswer(res, expected)
 
+    def testContinentCode(self):
+        """
+        Basic continentCode() test
+        """
+        queries = [
+            ('1.1.1.0', 24,  '"oc"'),
+            ('1.2.3.0', 24,  '"na"'),
+            ('17.1.0.0', 16, '"--"')
+        ]
+        name = 'continent-code.geo.example.org.'
+        for (subnet, mask, txt) in queries:
+            ecso = clientsubnetoption.ClientSubnetOption(subnet, mask)
+            query = dns.message.make_query(name, 'TXT', 'IN', use_edns=True, payload=4096, options=[ecso])
+            expected = dns.rrset.from_text(name, 0, dns.rdataclass.IN, 'TXT', txt)
+
+            res = self.sendUDPQuery(query)
+            self.assertRcodeEqual(res, dns.rcode.NOERROR)
+            self.assertRRsetInAnswer(res, expected)
+
     def testClosest(self):
         """
         Basic pickclosest() test
@@ -513,6 +641,17 @@ createforward6.example.org.                 3600 IN NS   ns2.example.org.
             self.assertRcodeEqual(res, dns.rcode.NOERROR)
             self.assertRRsetInAnswer(res, expected)
 
+    def testAll(self):
+        """
+        Basic all() test
+        """
+        expected = [dns.rrset.from_text('all.example.org.', 0, dns.rdataclass.IN, dns.rdatatype.A, '1.2.3.4', '4.3.2.1')]
+        query = dns.message.make_query('all.example.org.', 'A')
+
+        res = self.sendUDPQuery(query)
+        self.assertRcodeEqual(res, dns.rcode.NOERROR)
+        self.assertEqual(res.answer, expected)
+
     def testNetmask(self):
         """
         Basic netmask() test
@@ -586,6 +725,74 @@ createforward6.example.org.                 3600 IN NS   ns2.example.org.
             self.assertRcodeEqual(res, dns.rcode.NOERROR)
             self.assertRRsetInAnswer(res, first.answer[0])
 
+    def testWHashedTxt(self):
+        """
+        Basic pickwhashed() test with a set of TXT records
+        As the `bestwho` is hashed, we should always get the same answer
+        """
+        expected = [dns.rrset.from_text('whashed-txt.example.org.', 0, dns.rdataclass.IN, 'TXT', 'bob'),
+                    dns.rrset.from_text('whashed-txt.example.org.', 0, dns.rdataclass.IN, 'TXT', 'alice')]
+        query = dns.message.make_query('whashed-txt.example.org', 'TXT')
+
+        first = self.sendUDPQuery(query)
+        self.assertRcodeEqual(first, dns.rcode.NOERROR)
+        self.assertAnyRRsetInAnswer(first, expected)
+        for _ in range(5):
+            res = self.sendUDPQuery(query)
+            self.assertRcodeEqual(res, dns.rcode.NOERROR)
+            self.assertRRsetInAnswer(res, first.answer[0])
+
+    def testHashed(self):
+        """
+        Basic pickhashed() test with a set of A records
+        As the `bestwho` is hashed, we should always get the same answer
+        """
+        expected = [dns.rrset.from_text('hashed.example.org.', 0, dns.rdataclass.IN, 'A', '1.2.3.4'),
+                    dns.rrset.from_text('hashed.example.org.', 0, dns.rdataclass.IN, 'A', '4.3.2.1')]
+        query = dns.message.make_query('hashed.example.org', 'A')
+
+        first = self.sendUDPQuery(query)
+        self.assertRcodeEqual(first, dns.rcode.NOERROR)
+        self.assertAnyRRsetInAnswer(first, expected)
+        for _ in range(5):
+            res = self.sendUDPQuery(query)
+            self.assertRcodeEqual(res, dns.rcode.NOERROR)
+            self.assertRRsetInAnswer(res, first.answer[0])
+
+    def testHashedV6(self):
+        """
+        Basic pickhashed() test with a set of AAAA records
+        As the `bestwho` is hashed, we should always get the same answer
+        """
+        expected = [dns.rrset.from_text('hashed-v6.example.org.', 0, dns.rdataclass.IN, 'AAAA', '2001:db8:a0b:12f0::1'),
+                    dns.rrset.from_text('hashed-v6.example.org.', 0, dns.rdataclass.IN, 'AAAA', 'fe80::2a1:9bff:fe9b:f268')]
+        query = dns.message.make_query('hashed-v6.example.org', 'AAAA')
+
+        first = self.sendUDPQuery(query)
+        self.assertRcodeEqual(first, dns.rcode.NOERROR)
+        self.assertAnyRRsetInAnswer(first, expected)
+        for _ in range(5):
+            res = self.sendUDPQuery(query)
+            self.assertRcodeEqual(res, dns.rcode.NOERROR)
+            self.assertRRsetInAnswer(res, first.answer[0])
+
+    def testHashedTXT(self):
+        """
+        Basic pickhashed() test with a set of TXT records
+        As the `bestwho` is hashed, we should always get the same answer
+        """
+        expected = [dns.rrset.from_text('hashed-txt.example.org.', 0, dns.rdataclass.IN, 'TXT', 'bob'),
+                    dns.rrset.from_text('hashed-txt.example.org.', 0, dns.rdataclass.IN, 'TXT', 'alice')]
+        query = dns.message.make_query('hashed-txt.example.org', 'TXT')
+
+        first = self.sendUDPQuery(query)
+        self.assertRcodeEqual(first, dns.rcode.NOERROR)
+        self.assertAnyRRsetInAnswer(first, expected)
+        for _ in range(5):
+            res = self.sendUDPQuery(query)
+            self.assertRcodeEqual(res, dns.rcode.NOERROR)
+            self.assertRRsetInAnswer(res, first.answer[0])
+
     def testTimeout(self):
         """
         Test if LUA scripts are aborted if script execution takes too long