]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Add support for range-based lookups into a Key-Value store
authorRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 24 Jun 2021 16:07:00 +0000 (18:07 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Tue, 20 Jul 2021 09:32:20 +0000 (11:32 +0200)
This feature allows doing a range-based lookup (mostly useful for IP addresses), assuming that:
- there is a key for the last element of the range (2001:0db8:ffff:ffff:ffff:ffff:ffff:ffff for 2001:db8::/32)
which contains the first element of the range (2001:0db8:0000:0000:0000:0000:0000:0000) followed by any data in the value
- AND there is no overlapping ranges in the database !!

This requires that the underlying store supports ordered keys, which is true for LMDB but not for CDB, for example.

pdns/dnsdist-console.cc
pdns/dnsdist-lua-actions.cc
pdns/dnsdist-lua-rules.cc
pdns/dnsdistdist/dnsdist-kvs.cc
pdns/dnsdistdist/dnsdist-kvs.hh
pdns/dnsdistdist/dnsdist-lua-bindings-kvs.cc
pdns/dnsdistdist/dnsdist-rules.hh
pdns/dnsdistdist/docs/reference/kvs.rst
pdns/dnsdistdist/docs/rules-actions.rst
pdns/dnsdistdist/test-dnsdistkvs_cc.cc
regression-tests.dnsdist/test_LMDB.py

index 06c69e1b8a9ae6b9d5425c0a6785495d56795fa5..50679039d73a3f86305716f239d7ee8f9a21cd1d 100644 (file)
@@ -479,11 +479,13 @@ const std::vector<ConsoleKeyword> g_consoleKeywords{
   { "inClientStartup", true, "", "returns true during console client parsing of configuration" },
   { "includeDirectory", true, "path", "include configuration files from `path`" },
   { "KeyValueLookupKeyQName", true, "[wireFormat]", "Return a new KeyValueLookupKey object that, when passed to KeyValueStoreLookupAction or KeyValueStoreLookupRule, will return the qname of the query, either in wire format (default) or in plain text if 'wireFormat' is false" },
-  { "KeyValueLookupKeySourceIP", true, "[v4Mask [, v6Mask]]", "Return a new KeyValueLookupKey object that, when passed to KeyValueStoreLookupAction or KeyValueStoreLookupRule, will return the (possibly bitmasked) source IP of the client in network byte-order." },
+  { "KeyValueLookupKeySourceIP", true, "[v4Mask [, v6Mask [, includePort]]]", "Return a new KeyValueLookupKey object that, when passed to KeyValueStoreLookupAction or KeyValueStoreLookupRule, will return the (possibly bitmasked) source IP of the client in network byte-order." },
   { "KeyValueLookupKeySuffix", true, "[minLabels [,wireFormat]]", "Return a new KeyValueLookupKey object that, when passed to KeyValueStoreLookupAction or KeyValueStoreLookupRule, will return a vector of keys based on the labels of the qname in DNS wire format or plain text" },
   { "KeyValueLookupKeyTag", true, "tag", "Return a new KeyValueLookupKey object that, when passed to KeyValueStoreLookupAction or KeyValueStoreLookupRule, will return the value of the corresponding tag for this query, if it exists" },
   { "KeyValueStoreLookupAction", true, "kvs, lookupKey, destinationTag", "does a lookup into the key value store referenced by 'kvs' using the key returned by 'lookupKey', and storing the result if any into the tag named 'destinationTag'" },
+  { "KeyValueStoreRangeLookupAction", true, "kvs, lookupKey, destinationTag", "does a range-based lookup into the key value store referenced by 'kvs' using the key returned by 'lookupKey', and storing the result if any into the tag named 'destinationTag'" },
   { "KeyValueStoreLookupRule", true, "kvs, lookupKey", "matches queries if the key is found in the specified Key Value store" },
+  { "KeyValueStoreRangeLookupRule", true, "kvs, lookupKey", "matches queries if the key is found in the specified Key Value store" },
   { "leastOutstanding", false, "", "Send traffic to downstream server with least outstanding queries, with the lowest 'order', and within that the lowest recent latency"},
   { "LogAction", true, "[filename], [binary], [append], [buffered]", "Log a line for each query, to the specified file if any, to the console (require verbose) otherwise. When logging to a file, the `binary` optional parameter specifies whether we log in binary form (default) or in textual form, the `append` optional parameter specifies whether we open the file for appending or truncate each time (default), and the `buffered` optional parameter specifies whether writes to the file are buffered (default) or not." },
   { "LogResponseAction", true, "[filename], [append], [buffered]", "Log a line for each response, to the specified file if any, to the console (require verbose) otherwise. The `append` optional parameter specifies whether we open the file for appending or truncate each time (default), and the `buffered` optional parameter specifies whether writes to the file are buffered (default) or not." },
index 453b5c06fc6ef01c319b63763b54c15ed2d96fca..95a4d539131e8c6ef31b7cb9f349283737838011 100644 (file)
@@ -1672,6 +1672,44 @@ private:
   std::string d_tag;
 };
 
+class KeyValueStoreRangeLookupAction : public DNSAction
+{
+public:
+  // this action does not stop the processing
+  KeyValueStoreRangeLookupAction(std::shared_ptr<KeyValueStore>& kvs, std::shared_ptr<KeyValueLookupKey>& lookupKey, const std::string& destinationTag): d_kvs(kvs), d_key(lookupKey), d_tag(destinationTag)
+  {
+  }
+
+  DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override
+  {
+    std::vector<std::string> keys = d_key->getKeys(*dq);
+    std::string result;
+    for (const auto& key : keys) {
+      if (d_kvs->getRangeValue(key, result) == true) {
+        break;
+      }
+    }
+
+    if (!dq->qTag) {
+      dq->qTag = std::make_shared<QTag>();
+    }
+
+    dq->qTag->insert({d_tag, std::move(result)});
+
+    return Action::None;
+  }
+
+  std::string toString() const override
+  {
+    return "do a range-based lookup in key-value store based on '" + d_key->toString() + "' and set the result in tag '" + d_tag + "'";
+  }
+
+private:
+  std::shared_ptr<KeyValueStore> d_kvs;
+  std::shared_ptr<KeyValueLookupKey> d_key;
+  std::string d_tag;
+};
+
 class NegativeAndSOAAction: public DNSAction
 {
 public:
@@ -2226,6 +2264,10 @@ void setupLuaActions(LuaContext& luaCtx)
       return std::shared_ptr<DNSAction>(new KeyValueStoreLookupAction(kvs, lookupKey, destinationTag));
     });
 
+  luaCtx.writeFunction("KeyValueStoreRangeLookupAction", [](std::shared_ptr<KeyValueStore>& kvs, std::shared_ptr<KeyValueLookupKey>& lookupKey, const std::string& destinationTag) {
+      return std::shared_ptr<DNSAction>(new KeyValueStoreRangeLookupAction(kvs, lookupKey, destinationTag));
+    });
+
   luaCtx.writeFunction("NegativeAndSOAAction", [](bool nxd, const std::string& zone, uint32_t ttl, const std::string& mname, const std::string& rname, uint32_t serial, uint32_t refresh, uint32_t retry, uint32_t expire, uint32_t minimum, boost::optional<responseParams_t> vars) {
       auto ret = std::shared_ptr<DNSAction>(new NegativeAndSOAAction(nxd, DNSName(zone), ttl, DNSName(mname), DNSName(rname), serial, refresh, retry, expire, minimum));
       auto action = std::dynamic_pointer_cast<NegativeAndSOAAction>(ret);
index 2153b63a31e81a929b80f2f82e63fd555275d05e..9c9bec0afae774277de1bb1f86a19ca6084d881f 100644 (file)
@@ -595,6 +595,10 @@ void setupLuaRules(LuaContext& luaCtx)
       return std::shared_ptr<DNSRule>(new KeyValueStoreLookupRule(kvs, lookupKey));
     });
 
+  luaCtx.writeFunction("KeyValueStoreRangeLookupRule", [](std::shared_ptr<KeyValueStore>& kvs, std::shared_ptr<KeyValueLookupKey>& lookupKey) {
+      return std::shared_ptr<DNSRule>(new KeyValueStoreRangeLookupRule(kvs, lookupKey));
+    });
+
   luaCtx.writeFunction("LuaRule", [](LuaRule::func_t func) {
       return std::shared_ptr<DNSRule>(new LuaRule(func));
     });
index 35621dd2198a7a060ed8b77ad96bf29c9ae87c36..552936fd84687c6ed18416386f9b7a0ad71c836a 100644 (file)
@@ -30,15 +30,24 @@ std::vector<std::string> KeyValueLookupKeySourceIP::getKeys(const ComboAddress&
   std::vector<std::string> result;
   ComboAddress truncated(addr);
 
+  std::string key;
   if (truncated.isIPv4()) {
     truncated.truncate(d_v4Mask);
-    result.emplace_back(reinterpret_cast<const char*>(&truncated.sin4.sin_addr.s_addr), sizeof(truncated.sin4.sin_addr.s_addr));
+    key.reserve(sizeof(truncated.sin4.sin_addr.s_addr) + (d_includePort ? sizeof(truncated.sin4.sin_port) : 0));
+    key.append(reinterpret_cast<const char*>(&truncated.sin4.sin_addr.s_addr), sizeof(truncated.sin4.sin_addr.s_addr));
   }
   else if (truncated.isIPv6()) {
     truncated.truncate(d_v6Mask);
-    result.emplace_back(reinterpret_cast<const char*>(&truncated.sin6.sin6_addr.s6_addr), sizeof(truncated.sin6.sin6_addr.s6_addr));
+    key.reserve(sizeof(truncated.sin6.sin6_addr.s6_addr) + (d_includePort ? sizeof(truncated.sin4.sin_port) : 0));
+    key.append(reinterpret_cast<const char*>(&truncated.sin6.sin6_addr.s6_addr), sizeof(truncated.sin6.sin6_addr.s6_addr));
   }
 
+  if (d_includePort) {
+    key.append(reinterpret_cast<const char*>(&truncated.sin4.sin_port), sizeof(truncated.sin4.sin_port));
+  }
+
+  result.push_back(std::move(key));
+
   return result;
 }
 
@@ -112,6 +121,51 @@ bool LMDBKVStore::keyExists(const std::string& key)
   return false;
 }
 
+bool LMDBKVStore::getRangeValue(const std::string& key, std::string& value)
+{
+  try {
+    auto transaction = d_env.getROTransaction();
+    auto cursor = transaction->getROCursor(d_dbi);
+    MDBOutVal actualKey;
+    MDBOutVal result;
+    // for range-based lookups, we expect the data in LMDB
+    // to be stored with the last value of the range as key
+    // and the first value of the range as data, sometimes
+    // followed by any other content we don't care about
+
+    // retrieve the first key greater or equal to our key
+    int rc = cursor.lower_bound(MDBInVal(key), actualKey, result);
+
+    if (rc == 0) {
+      auto last = actualKey.get<std::string>();
+      if (last.size() != key.size() || key > last) {
+        return false;
+      }
+
+      value = result.get<std::string>();
+      if (value.size() < key.size()) {
+        return false;
+      }
+
+      // take the first part of the data, which should be
+      // the first address of the range
+      auto first = value.substr(0, key.size());
+      if (first.size() != key.size() || key < first) {
+        return false;
+      }
+
+      return true;
+    }
+    else if (rc == MDB_NOTFOUND) {
+      return false;
+    }
+  }
+  catch(const std::exception& e) {
+    vinfolog("Error while looking up a range from LMDB file '%s', database '%s': %s", d_fname, d_dbName, e.what());
+  }
+  return false;
+}
+
 #endif /* HAVE_LMDB */
 
 #ifdef HAVE_CDB
index 37ac34cab3ffa85dcff7c9b8464f5e1162e20602..f9f9304415388dcaf54bc4dac8ff67fdf717ced0 100644 (file)
@@ -36,7 +36,7 @@ public:
 class KeyValueLookupKeySourceIP: public KeyValueLookupKey
 {
 public:
-  KeyValueLookupKeySourceIP(uint8_t v4Mask, uint8_t v6Mask): d_v4Mask(v4Mask), d_v6Mask(v6Mask)
+  KeyValueLookupKeySourceIP(uint8_t v4Mask, uint8_t v6Mask, bool includePort): d_v4Mask(v4Mask), d_v6Mask(v6Mask), d_includePort(includePort)
   {
   }
 
@@ -49,11 +49,12 @@ public:
 
   std::string toString() const override
   {
-    return "source IP (masked to " + std::to_string(d_v4Mask) + " (v4) / " + std::to_string(d_v6Mask) + " (v6) bits)";
+    return "source IP (masked to " + std::to_string(d_v4Mask) + " (v4) / " + std::to_string(d_v6Mask) + " (v6) bits)" + (d_includePort ? " including the port" : "");
   }
 private:
   uint8_t d_v4Mask;
   uint8_t d_v6Mask;
+  bool d_includePort;
 };
 
 class KeyValueLookupKeyQName: public KeyValueLookupKey
@@ -152,6 +153,15 @@ public:
 
   virtual bool keyExists(const std::string& key) = 0;
   virtual bool getValue(const std::string& key, std::string& value) = 0;
+  // do a range-based lookup (mostly useful for IP addresses), assuming that:
+  // there is a key for the last element of the range (2001:0db8:ffff:ffff:ffff:ffff:ffff:ffff for 2001:db8::/32)
+  // which contains the first element of the range (2001:0db8:0000:0000:0000:0000:0000:0000) followed by any data in the value
+  // AND there is no overlapping ranges in the database !!
+  // This requires that the underlying store supports ordered keys, which is true for LMDB but not for CDB, for example.
+  virtual bool getRangeValue(const std::string& key, std::string& value)
+  {
+    throw std::runtime_error("range-based lookups are not implemented for this Key-Value Store");
+  }
   virtual bool reload()
   {
     return false;
@@ -171,6 +181,7 @@ public:
 
   bool keyExists(const std::string& key) override;
   bool getValue(const std::string& key, std::string& value) override;
+  bool getRangeValue(const std::string& key, std::string& value) override;
 
 private:
   MDBEnv d_env;
index fd7b827e00782d774f5836315f0e493e5bacd4be..b1e0c0fc0f0b09b774bfc35e507fefafc5d47e87 100644 (file)
@@ -26,8 +26,8 @@
 void setupLuaBindingsKVS(LuaContext& luaCtx, bool client)
 {
   /* Key Value Store objects */
-  luaCtx.writeFunction("KeyValueLookupKeySourceIP", [](boost::optional<uint8_t> v4Mask, boost::optional<uint8_t> v6Mask) {
-    return std::shared_ptr<KeyValueLookupKey>(new KeyValueLookupKeySourceIP(v4Mask.get_value_or(32), v6Mask.get_value_or(128)));
+  luaCtx.writeFunction("KeyValueLookupKeySourceIP", [](boost::optional<uint8_t> v4Mask, boost::optional<uint8_t> v6Mask, boost::optional<bool> includePort) {
+    return std::shared_ptr<KeyValueLookupKey>(new KeyValueLookupKeySourceIP(v4Mask.get_value_or(32), v6Mask.get_value_or(128), includePort.get_value_or(false)));
   });
   luaCtx.writeFunction("KeyValueLookupKeyQName", [](boost::optional<bool> wireFormat) {
     return std::shared_ptr<KeyValueLookupKey>(new KeyValueLookupKeyQName(wireFormat ? *wireFormat : true));
@@ -65,7 +65,7 @@ void setupLuaBindingsKVS(LuaContext& luaCtx, bool client)
 
     if (keyVar.type() == typeid(ComboAddress)) {
       const auto ca = boost::get<ComboAddress>(&keyVar);
-      KeyValueLookupKeySourceIP lookup(32, 128);
+      KeyValueLookupKeySourceIP lookup(32, 128, false);
       for (const auto& key : lookup.getKeys(*ca)) {
         if (kvs->getValue(key, result)) {
           return result;
index be35f8027dd63d9bee6e098125bd38b80941e49f..0b42866341afaa37efc0804f22bed02a235f56dd 100644 (file)
@@ -1121,6 +1121,36 @@ private:
   std::shared_ptr<KeyValueLookupKey> d_key;
 };
 
+class KeyValueStoreRangeLookupRule: public DNSRule
+{
+public:
+  KeyValueStoreRangeLookupRule(std::shared_ptr<KeyValueStore>& kvs, std::shared_ptr<KeyValueLookupKey>& lookupKey): d_kvs(kvs), d_key(lookupKey)
+  {
+  }
+
+  bool matches(const DNSQuestion* dq) const override
+  {
+    std::vector<std::string> keys = d_key->getKeys(*dq);
+    for (const auto& key : keys) {
+      std::string value;
+      if (d_kvs->getRangeValue(key, value) == true) {
+        return true;
+      }
+    }
+
+    return false;
+  }
+
+  string toString() const override
+  {
+    return "range-based lookup key-value store based on '" + d_key->toString() + "'";
+  }
+
+private:
+  std::shared_ptr<KeyValueStore> d_kvs;
+  std::shared_ptr<KeyValueLookupKey> d_key;
+};
+
 class LuaRule : public DNSRule
 {
 public:
index dbdfaacdec7cbdd455d8eefb9d8cee275fcb5c01..b7395a2eabada5f42f86582ef3b65a20eb7e887d 100644 (file)
@@ -86,10 +86,14 @@ If the value found in the LMDB database for the key '\\8powerdns\\3com\\0' was '
   .. versionchanged:: 1.5.0
     Optional parameters ``v4mask`` and ``v6mask`` added.
 
+  .. versionchanged:: 1.7.0
+    Optional parameter ``includePort`` added.
+
   Return a new KeyValueLookupKey object that, when passed to :func:`KeyValueStoreLookupAction` or :func:`KeyValueStoreLookupRule`, will return the source IP of the client in network byte-order.
 
   :param int v4mask: Mask applied to IPv4 addresses. Default is 32 (the whole address)
   :param int v6mask: Mask applied to IPv6 addresses. Default is 128 (the whole address)
+  :param int includePort: Whether to append the port (in network byte-order) after the address. Default is false
 
 .. function:: KeyValueLookupKeySuffix([minLabels [, wireFormat]]) -> KeyValueLookupKey
 
index 835d99631540e0bb45ef79dd7d1d40e0f39ecf09..124307e1a4ff3ecf2f8fba0257bd3deb72d2c189 100644 (file)
@@ -498,6 +498,18 @@ These ``DNSRule``\ s be one of the following items:
   :param KeyValueStore kvs: The key value store to query
   :param KeyValueLookupKey lookupKey: The key to use for the lookup
 
+.. function:: KeyValueStoreRangeLookupRule(kvs, lookupKey)
+
+  .. versionadded:: 1.7.0
+
+  Does a range-based lookup into the key value store referenced by 'kvs' using the key returned by 'lookupKey' and returns true if there is a range covering that key.
+
+  This assumes that there is a key for the last element of the range (for example 2001:0db8:ffff:ffff:ffff:ffff:ffff:ffff for 2001:db8::/32) which contains the first element of the range (2001:0db8:0000:0000:0000:0000:0000:0000) (optionally followed by any data), as value and that there is no overlapping ranges in the database.
+  This requires that the underlying store supports ordered keys, which is true for LMDB but not for CDB.
+
+  :param KeyValueStore kvs: The key value store to query
+  :param KeyValueLookupKey lookupKey: The key to use for the lookup
+
 .. function:: LuaFFIPerThreadRule(function)
 
   .. versionadded:: 1.7.0
@@ -952,6 +964,21 @@ The following actions exist.
   :param KeyValueLookupKey lookupKey: The key to use for the lookup
   :param string destinationTag: The name of the tag to store the result into
 
+.. function:: KeyValueStoreRangeLookupAction(kvs, lookupKey, destinationTag)
+
+  .. versionadded:: 1.7.0
+
+  Does a range-based lookup into the key value store referenced by 'kvs' using the key returned by 'lookupKey',
+  and storing the result if any into the tag named 'destinationTag'.
+  This assumes that there is a key for the last element of the range (for example 2001:0db8:ffff:ffff:ffff:ffff:ffff:ffff for 2001:db8::/32) which contains the first element of the range (2001:0db8:0000:0000:0000:0000:0000:0000) (optionally followed by any data), as value and that there is no overlapping ranges in the database.
+  This requires that the underlying store supports ordered keys, which is true for LMDB but not for CDB.
+
+  Subsequent rules are processed after this action.
+
+  :param KeyValueStore kvs: The key value store to query
+  :param KeyValueLookupKey lookupKey: The key to use for the lookup
+  :param string destinationTag: The name of the tag to store the result into
+
 .. function:: LogAction([filename[, binary[, append[, buffered[, verboseOnly[, includeTimestamp]]]]]])
 
   .. versionchanged:: 1.4.0
index 27e86cae35112b1886872038a2079008b4f6426a..316fd782062937d23be4aa057e3cbab0aa6e4fa4 100644 (file)
@@ -14,7 +14,7 @@ static void doKVSChecks(std::unique_ptr<KeyValueStore>& kvs, const ComboAddress&
 {
   /* source IP */
   {
-    auto lookupKey = make_unique<KeyValueLookupKeySourceIP>(32, 128);
+    auto lookupKey = make_unique<KeyValueLookupKeySourceIP>(32, 128, false);
     std::string value;
     /* local address is not in the db, remote is */
     BOOST_CHECK_EQUAL(kvs->getValue(std::string(reinterpret_cast<const char*>(&lc.sin4.sin_addr.s_addr), sizeof(lc.sin4.sin_addr.s_addr)), value), false);
@@ -32,7 +32,7 @@ static void doKVSChecks(std::unique_ptr<KeyValueStore>& kvs, const ComboAddress&
 
   /* masked source IP */
   {
-    auto lookupKey = make_unique<KeyValueLookupKeySourceIP>(25, 65);
+    auto lookupKey = make_unique<KeyValueLookupKeySourceIP>(25, 65, false);
 
     auto keys = lookupKey->getKeys(v4ToMask);
     BOOST_CHECK_EQUAL(keys.size(), 1U);
@@ -51,6 +51,21 @@ static void doKVSChecks(std::unique_ptr<KeyValueStore>& kvs, const ComboAddress&
     }
   }
 
+  /* source IP + port */
+  {
+    auto lookupKey = make_unique<KeyValueLookupKeySourceIP>(32, 128, true);
+    std::string value;
+    BOOST_CHECK(kvs->keyExists(std::string(reinterpret_cast<const char*>(&rem.sin4.sin_addr.s_addr), sizeof(rem.sin4.sin_addr.s_addr)) + std::string(reinterpret_cast<const char*>(&rem.sin4.sin_port), sizeof(rem.sin4.sin_port))));
+
+    auto keys = lookupKey->getKeys(dq);
+    BOOST_CHECK_EQUAL(keys.size(), 1U);
+    for (const auto& key : keys) {
+      value.clear();
+      BOOST_CHECK_EQUAL(kvs->getValue(key, value), true);
+      BOOST_CHECK_EQUAL(value, "this is the value for the remote addr + port");
+    }
+  }
+
   const DNSName subdomain = DNSName("sub") + *dq.qname;
   const DNSName notPDNS("not-powerdns.com.");
 
@@ -220,6 +235,64 @@ static void doKVSChecks(std::unique_ptr<KeyValueStore>& kvs, const ComboAddress&
     BOOST_CHECK_EQUAL(value, "this is the value for the qname");
   }
 }
+
+#if defined(HAVE_LMDB)
+static void doKVSRangeChecks(std::unique_ptr<KeyValueStore>& kvs)
+{
+  {
+    /* do a range-based lookup */
+    const ComboAddress first("2001:0db8:0000:0000:0000:0000:0000:0000");
+    const ComboAddress inside("2001:0db8:7fff:ffff:ffff:ffff:ffff:ffff");
+    const ComboAddress last("2001:0db8:ffff:ffff:ffff:ffff:ffff:ffff");
+    const ComboAddress notInRange1("2001:0db7:ffff:ffff:ffff:ffff:ffff:ffff");
+    const ComboAddress notInRange2("2001:0db9:0000:0000:0000:0000:0000:0000");
+    const std::string expectedValue = std::string(reinterpret_cast<const char*>(&first.sin6.sin6_addr.s6_addr), sizeof(first.sin6.sin6_addr.s6_addr)) + std::string("any other data");
+
+    auto check = [expectedValue, &kvs](const ComboAddress& key, bool shouldBeFound) {
+      // cerr<<"Checking "<<key.toString()<<", should "<<(shouldBeFound ? "" : "NOT ")<<"be found"<<endl;
+      auto lookupKey = std::string(reinterpret_cast<const char*>(&key.sin6.sin6_addr.s6_addr), sizeof(key.sin6.sin6_addr.s6_addr));
+      std::string value;
+      BOOST_CHECK_EQUAL(kvs->getRangeValue(lookupKey, value), shouldBeFound);
+      if (shouldBeFound) {
+        BOOST_CHECK_EQUAL(value, expectedValue);
+      }
+    };
+
+    check(first, true);
+    check(last, true);
+    check(inside, true);
+    check(notInRange1, false);
+    check(notInRange2, false);
+  }
+
+  {
+    const ComboAddress first("192.0.2.1:0");
+    const ComboAddress inside("192.0.2.1:42");
+    const ComboAddress last("192.0.2.1:16383");
+    const ComboAddress notInRange1("192.0.2.0:65535");
+    const ComboAddress notInRange2("192.0.2.1:16384");
+    const std::string expectedValue = std::string(reinterpret_cast<const char*>(&first.sin4.sin_addr.s_addr), sizeof(first.sin4.sin_addr.s_addr)) + std::string(reinterpret_cast<const char*>(&first.sin4.sin_port), sizeof(first.sin4.sin_port)) + std::string("any other data");
+
+    auto check = [expectedValue, &kvs](const ComboAddress& key, bool shouldBeFound) {
+      // cerr<<"Checking "<<key.toStringWithPort()<<", should "<<(shouldBeFound ? "" : "NOT ")<<"be found"<<endl;
+      auto lookupKey = std::string(reinterpret_cast<const char*>(&key.sin4.sin_addr.s_addr), sizeof(key.sin4.sin_addr.s_addr)) + std::string(reinterpret_cast<const char*>(&key.sin4.sin_port), sizeof(key.sin4.sin_port));
+      std::string value;
+      BOOST_CHECK_EQUAL(kvs->getRangeValue(lookupKey, value), shouldBeFound);
+      if (shouldBeFound) {
+        BOOST_CHECK_EQUAL(value, expectedValue);
+      }
+    };
+
+    check(first, true);
+    check(last, true);
+    check(inside, true);
+    check(notInRange1, false);
+    check(notInRange2, false);
+  }
+
+}
+#endif // defined(HAVE_LMDB)
+
 #endif // defined(HAVE_LMDB) || defined(HAVE_CDB)
 
 BOOST_AUTO_TEST_SUITE(dnsdistkvs_cc)
@@ -247,21 +320,46 @@ BOOST_AUTO_TEST_CASE(test_LMDB) {
   v4Masked.truncate(25);
   v6Masked.truncate(65);
 
+  const ComboAddress firstRangeAddr6("2001:0db8:0000:0000:0000:0000:0000:0000");
+  const ComboAddress lastRangeAddr6("2001:0db8:ffff:ffff:ffff:ffff:ffff:ffff");
+  const ComboAddress firstRangeAddr4("192.0.2.1:0");
+  const ComboAddress lastRangeAddr4("192.0.2.1:16383");
+
   const string dbPath("/tmp/test_lmdb.XXXXXX");
   {
     MDBEnv env(dbPath.c_str(), MDB_NOSUBDIR, 0600);
     auto transaction = env.getRWTransaction();
     auto dbi = transaction->openDB("db-name", MDB_CREATE);
     transaction->put(dbi, MDBInVal(std::string(reinterpret_cast<const char*>(&rem.sin4.sin_addr.s_addr), sizeof(rem.sin4.sin_addr.s_addr))), MDBInVal("this is the value for the remote addr"));
+    transaction->put(dbi, MDBInVal(std::string(reinterpret_cast<const char*>(&rem.sin4.sin_addr.s_addr), sizeof(rem.sin4.sin_addr.s_addr)) + std::string(reinterpret_cast<const char*>(&rem.sin4.sin_port), sizeof(rem.sin4.sin_port))), MDBInVal("this is the value for the remote addr + port"));
     transaction->put(dbi, MDBInVal(std::string(reinterpret_cast<const char*>(&v4Masked.sin4.sin_addr.s_addr), sizeof(v4Masked.sin4.sin_addr.s_addr))), MDBInVal("this is the value for the masked v4 addr"));
     transaction->put(dbi, MDBInVal(std::string(reinterpret_cast<const char*>(&v6Masked.sin6.sin6_addr.s6_addr), sizeof(v6Masked.sin6.sin6_addr.s6_addr))), MDBInVal("this is the value for the masked v6 addr"));
     transaction->put(dbi, MDBInVal(qname.toDNSStringLC()), MDBInVal("this is the value for the qname"));
     transaction->put(dbi, MDBInVal(plaintextDomain.toStringRootDot()), MDBInVal("this is the value for the plaintext domain"));
+
+    transaction->commit();
+  }
+
+  {
+    MDBEnv env(dbPath.c_str(), MDB_NOSUBDIR, 0600);
+    auto transaction = env.getRWTransaction();
+    auto dbi = transaction->openDB("range-db-name", MDB_CREATE);
+    /* range-based lookups */
+    std::string value = std::string(reinterpret_cast<const char*>(&firstRangeAddr6.sin6.sin6_addr.s6_addr), sizeof(firstRangeAddr6.sin6.sin6_addr.s6_addr)) + std::string("any other data");
+    transaction->put(dbi, MDBInVal(std::string(reinterpret_cast<const char*>(&lastRangeAddr6.sin6.sin6_addr.s6_addr), sizeof(lastRangeAddr6.sin6.sin6_addr.s6_addr))), MDBInVal(value));
+
+    value = std::string(reinterpret_cast<const char*>(&firstRangeAddr4.sin4.sin_addr.s_addr), sizeof(firstRangeAddr4.sin4.sin_addr.s_addr)) + std::string(reinterpret_cast<const char*>(&firstRangeAddr4.sin4.sin_port), sizeof(firstRangeAddr4.sin4.sin_port)) + std::string("any other data");
+    transaction->put(dbi, MDBInVal(std::string(reinterpret_cast<const char*>(&lastRangeAddr4.sin4.sin_addr.s_addr), sizeof(lastRangeAddr4.sin4.sin_addr.s_addr)) + std::string(reinterpret_cast<const char*>(&lastRangeAddr4.sin4.sin_port), sizeof(lastRangeAddr4.sin4.sin_port))), MDBInVal(value));
+
     transaction->commit();
   }
 
   auto lmdb = std::unique_ptr<KeyValueStore>(new LMDBKVStore(dbPath, "db-name"));
   doKVSChecks(lmdb, lc, rem, dq, plaintextDomain);
+  lmdb.reset();
+
+  lmdb = std::unique_ptr<KeyValueStore>(new LMDBKVStore(dbPath, "range-db-name"));
+  doKVSRangeChecks(lmdb);
   /*
   std::string value;
   DTime dt;
@@ -308,6 +406,7 @@ BOOST_AUTO_TEST_CASE(test_CDB) {
     BOOST_REQUIRE(fd >= 0);
     CDBWriter writer(fd);
     BOOST_REQUIRE(writer.addEntry(std::string(reinterpret_cast<const char*>(&rem.sin4.sin_addr.s_addr), sizeof(rem.sin4.sin_addr.s_addr)), "this is the value for the remote addr"));
+    BOOST_REQUIRE(writer.addEntry(std::string(reinterpret_cast<const char*>(&rem.sin4.sin_addr.s_addr), sizeof(rem.sin4.sin_addr.s_addr)) + std::string(reinterpret_cast<const char*>(&rem.sin4.sin_port), sizeof(rem.sin4.sin_port)), "this is the value for the remote addr + port"));
     BOOST_REQUIRE(writer.addEntry(std::string(reinterpret_cast<const char*>(&v4Masked.sin4.sin_addr.s_addr), sizeof(v4Masked.sin4.sin_addr.s_addr)), "this is the value for the masked v4 addr"));
     BOOST_REQUIRE(writer.addEntry(std::string(reinterpret_cast<const char*>(&v6Masked.sin6.sin6_addr.s6_addr), sizeof(v6Masked.sin6.sin6_addr.s6_addr)), "this is the value for the masked v6 addr"));
     BOOST_REQUIRE(writer.addEntry(qname.toDNSStringLC(), "this is the value for the qname"));
index 0fb1143e826d57136d0e033a495a9ec9b03d8b05..5e308d7d32bceb737829a713e38647520f602604 100644 (file)
@@ -4,6 +4,8 @@ import dns
 import lmdb
 import os
 import socket
+import struct
+
 from dnsdisttests import DNSDistTest
 
 @unittest.skipIf('SKIP_LMDB_TESTS' in os.environ, 'LMDB tests are disabled')
@@ -192,3 +194,119 @@ class TestLMDB(DNSDistTest):
             self.assertFalse(receivedQuery)
             self.assertTrue(receivedResponse)
             self.assertEqual(expectedResponse, receivedResponse)
+
+class TestLMDBIPInRange(DNSDistTest):
+
+    _lmdbFileName = '/tmp/test-lmdb-range-1-db'
+    _lmdbDBName = 'db-name'
+    _config_template = """
+    newServer{address="127.0.0.1:%d"}
+
+    kvs = newLMDBKVStore('%s', '%s')
+
+    -- KVS range lookups follow
+    -- does a range lookup in the LMDB database using the source IP as key
+    addAction(KeyValueStoreRangeLookupRule(kvs, KeyValueLookupKeySourceIP(32, 128, true)), SpoofAction('5.6.7.8'))
+
+    -- otherwise, spoof a different response
+    addAction(AllRule(), SpoofAction('9.9.9.9'))
+    """
+    _config_params = ['_testServerPort', '_lmdbFileName', '_lmdbDBName']
+
+    @classmethod
+    def setUpLMDB(cls):
+        env = lmdb.open(cls._lmdbFileName, map_size=1014*1024, max_dbs=1024, subdir=False)
+        db = env.open_db(key=cls._lmdbDBName.encode())
+        with env.begin(db=db, write=True) as txn:
+            txn.put(socket.inet_aton('127.255.255.255') + struct.pack("!H", 255), socket.inet_aton('127.0.0.0') + struct.pack("!H", 0) + b'this is the value of the source address tag')
+
+    @classmethod
+    def setUpClass(cls):
+
+        cls.setUpLMDB()
+        cls.startResponders()
+        cls.startDNSDist()
+        cls.setUpSockets()
+
+        print("Launching tests..")
+
+    def testLMDBSource(self):
+        """
+        LMDB range: Match on source address
+        """
+        name = 'source-ip.lmdb-range.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        # dnsdist set RA = RD for spoofed responses
+        query.flags &= ~dns.flags.RD
+        expectedResponse = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '5.6.7.8')
+        expectedResponse.answer.append(rrset)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertFalse(receivedQuery)
+            self.assertTrue(receivedResponse)
+            self.assertEqual(expectedResponse, receivedResponse)
+
+class TestLMDBIPNotInRange(DNSDistTest):
+
+    _lmdbFileName = '/tmp/test-lmdb-range-2-db'
+    _lmdbDBName = 'db-name'
+    _config_template = """
+    newServer{address="127.0.0.1:%d"}
+
+    kvs = newLMDBKVStore('%s', '%s')
+
+    -- KVS range lookups follow
+    -- does a range lookup in the LMDB database using the source IP as key
+    addAction(KeyValueStoreRangeLookupRule(kvs, KeyValueLookupKeySourceIP(32, 128, true)), SpoofAction('5.6.7.8'))
+
+    -- otherwise, spoof a different response
+    addAction(AllRule(), SpoofAction('9.9.9.9'))
+    """
+    _config_params = ['_testServerPort', '_lmdbFileName', '_lmdbDBName']
+
+    @classmethod
+    def setUpLMDB(cls):
+        env = lmdb.open(cls._lmdbFileName, map_size=1014*1024, max_dbs=1024, subdir=False)
+        db = env.open_db(key=cls._lmdbDBName.encode())
+        with env.begin(db=db, write=True) as txn:
+            txn.put(socket.inet_aton('127.0.0.0') + struct.pack("!H", 255), socket.inet_aton('127.0.0.0') + struct.pack("!H", 0) + b'this is the value of the source address tag')
+
+    @classmethod
+    def setUpClass(cls):
+
+        cls.setUpLMDB()
+        cls.startResponders()
+        cls.startDNSDist()
+        cls.setUpSockets()
+
+        print("Launching tests..")
+
+    def testLMDBSource(self):
+        """
+        LMDB not in range: Match on source address
+        """
+        name = 'source-ip.lmdb-not-in-range.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        # dnsdist set RA = RD for spoofed responses
+        query.flags &= ~dns.flags.RD
+        expectedResponse = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '9.9.9.9')
+        expectedResponse.answer.append(rrset)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertFalse(receivedQuery)
+            self.assertTrue(receivedResponse)
+            self.assertEqual(expectedResponse, receivedResponse)