]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Move rules to dnsdist-rules-factory.hh
authorRemi Gacogne <remi.gacogne@powerdns.com>
Tue, 24 Dec 2024 14:10:46 +0000 (15:10 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 16 Jan 2025 08:50:25 +0000 (09:50 +0100)
pdns/dnsdistdist/Makefile.am
pdns/dnsdistdist/dnsdist-lua-rules.cc
pdns/dnsdistdist/dnsdist-rules-factory.hh [new file with mode: 0644]
pdns/dnsdistdist/dnsdist-rules.cc
pdns/dnsdistdist/dnsdist-rules.hh
pdns/dnsdistdist/test-dnsdistrules_cc.cc

index 8a881c8b43e5d830bd2c2fb5dd7067ca71437625..6c17d416e19b3ee547360352cd8a8b1c50f2d1dd 100644 (file)
@@ -203,6 +203,7 @@ dnsdist_SOURCES = \
        dnsdist-resolver.cc dnsdist-resolver.hh \
        dnsdist-rings.cc dnsdist-rings.hh \
        dnsdist-rule-chains.cc dnsdist-rule-chains.hh \
+       dnsdist-rules-factory.hh \
        dnsdist-rules.cc dnsdist-rules.hh \
        dnsdist-secpoll.cc dnsdist-secpoll.hh \
        dnsdist-self-answers.cc dnsdist-self-answers.hh \
@@ -312,6 +313,7 @@ testrunner_SOURCES = \
        dnsdist-resolver.cc dnsdist-resolver.hh \
        dnsdist-rings.cc dnsdist-rings.hh \
        dnsdist-rule-chains.cc dnsdist-rule-chains.hh \
+       dnsdist-rules-factory.hh \
        dnsdist-rules.cc dnsdist-rules.hh \
        dnsdist-self-answers.cc dnsdist-self-answers.hh \
        dnsdist-session-cache.cc dnsdist-session-cache.hh \
index 10937c03fcdf29853298713d42c6d85623871214..dd86ec473cd5e4f35c81d220f8698ec7cf114b23 100644 (file)
@@ -21,7 +21,7 @@
  */
 #include "dnsdist.hh"
 #include "dnsdist-lua.hh"
-#include "dnsdist-rules.hh"
+#include "dnsdist-rules-factory.hh"
 #include "dnsdist-rule-chains.hh"
 #include "dns_random.hh"
 
diff --git a/pdns/dnsdistdist/dnsdist-rules-factory.hh b/pdns/dnsdistdist/dnsdist-rules-factory.hh
new file mode 100644 (file)
index 0000000..470cbda
--- /dev/null
@@ -0,0 +1,1415 @@
+/*
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+#pragma once
+
+#include "dnsdist-rules.hh"
+
+#include "cachecleaner.hh"
+#include "dnsdist-ecs.hh"
+#include "dnsdist-kvs.hh"
+#include "dnsdist-lua.hh"
+#include "dnsdist-lua-ffi.hh"
+#include "dolog.hh"
+#include "dnsparser.hh"
+#include "dns_random.hh"
+
+#include <boost/multi_index_container.hpp>
+#include <boost/multi_index/ordered_index.hpp>
+#include <boost/multi_index/sequenced_index.hpp>
+#include <boost/multi_index/key_extractors.hpp>
+
+class MaxQPSIPRule : public DNSRule
+{
+public:
+  MaxQPSIPRule(unsigned int qps, unsigned int burst, unsigned int ipv4trunc=32, unsigned int ipv6trunc=64, unsigned int expiration=300, unsigned int cleanupDelay=60, unsigned int scanFraction=10, size_t shardsCount=10):
+    d_shards(shardsCount), d_qps(qps), d_burst(burst), d_ipv4trunc(ipv4trunc), d_ipv6trunc(ipv6trunc), d_cleanupDelay(cleanupDelay), d_expiration(expiration), d_scanFraction(scanFraction)
+  {
+    d_cleaningUp.clear();
+    gettime(&d_lastCleanup, true);
+  }
+
+  void clear()
+  {
+    for (auto& shard : d_shards) {
+      shard.lock()->clear();
+    }
+  }
+
+  size_t cleanup(const struct timespec& cutOff, size_t* scannedCount=nullptr) const
+  {
+    size_t removed = 0;
+    if (scannedCount != nullptr) {
+      *scannedCount = 0;
+    }
+
+    for (auto& shard : d_shards) {
+      auto limits = shard.lock();
+      const size_t toLook = std::round((1.0 * limits->size()) / d_scanFraction)+ 1;
+      size_t lookedAt = 0;
+
+      auto& sequence = limits->get<SequencedTag>();
+      for (auto entry = sequence.begin(); entry != sequence.end() && lookedAt < toLook; lookedAt++) {
+        if (entry->d_limiter.seenSince(cutOff)) {
+          /* entries are ordered from least recently seen to more recently
+             seen, as soon as we see one that has not expired yet, we are
+             done */
+          lookedAt++;
+          break;
+        }
+
+        entry = sequence.erase(entry);
+        removed++;
+      }
+
+      if (scannedCount != nullptr) {
+        *scannedCount += lookedAt;
+      }
+    }
+
+    return removed;
+  }
+
+  void cleanupIfNeeded(const struct timespec& now) const
+  {
+    if (d_cleanupDelay > 0) {
+      struct timespec cutOff = d_lastCleanup;
+      cutOff.tv_sec += d_cleanupDelay;
+
+      if (cutOff < now) {
+        try {
+          if (d_cleaningUp.test_and_set()) {
+            return;
+          }
+
+          d_lastCleanup = now;
+          /* the QPS Limiter doesn't use realtime, be careful! */
+          gettime(&cutOff, false);
+          cutOff.tv_sec -= d_expiration;
+
+          cleanup(cutOff);
+          d_cleaningUp.clear();
+        }
+        catch (...) {
+          d_cleaningUp.clear();
+          throw;
+        }
+      }
+    }
+  }
+
+  bool matches(const DNSQuestion* dq) const override
+  {
+    cleanupIfNeeded(dq->getQueryRealTime());
+
+    ComboAddress zeroport(dq->ids.origRemote);
+    zeroport.sin4.sin_port=0;
+    zeroport.truncate(zeroport.sin4.sin_family == AF_INET ? d_ipv4trunc : d_ipv6trunc);
+    auto hash = ComboAddress::addressOnlyHash()(zeroport);
+    auto& shard = d_shards[hash % d_shards.size()];
+    {
+      auto limits = shard.lock();
+      auto iter = limits->find(zeroport);
+      if (iter == limits->end()) {
+        Entry e(zeroport, QPSLimiter(d_qps, d_burst));
+        iter = limits->insert(e).first;
+      }
+
+      moveCacheItemToBack<SequencedTag>(*limits, iter);
+      return !iter->d_limiter.check(d_qps, d_burst);
+    }
+  }
+
+  string toString() const override
+  {
+    return "IP (/"+std::to_string(d_ipv4trunc)+", /"+std::to_string(d_ipv6trunc)+") match for QPS over " + std::to_string(d_qps) + " burst "+ std::to_string(d_burst);
+  }
+
+  size_t getEntriesCount() const
+  {
+    size_t count = 0;
+    for (auto& shard : d_shards) {
+      count += shard.lock()->size();
+    }
+    return count;
+  }
+
+  size_t getNumberOfShards() const
+  {
+    return d_shards.size();
+  }
+
+private:
+  struct HashedTag {};
+  struct SequencedTag {};
+  struct Entry
+  {
+    Entry(const ComboAddress& addr, BasicQPSLimiter&& limiter): d_limiter(limiter), d_addr(addr)
+    {
+    }
+    mutable BasicQPSLimiter d_limiter;
+    ComboAddress d_addr;
+  };
+
+  typedef multi_index_container<
+    Entry,
+    indexed_by <
+      hashed_unique<tag<HashedTag>, member<Entry,ComboAddress,&Entry::d_addr>, ComboAddress::addressOnlyHash >,
+      sequenced<tag<SequencedTag> >
+      >
+  > qpsContainer_t;
+
+  mutable std::vector<LockGuarded<qpsContainer_t>> d_shards;
+  mutable struct timespec d_lastCleanup;
+  const unsigned int d_qps, d_burst, d_ipv4trunc, d_ipv6trunc, d_cleanupDelay, d_expiration;
+  const unsigned int d_scanFraction{10};
+  mutable std::atomic_flag d_cleaningUp;
+};
+
+class MaxQPSRule : public DNSRule
+{
+public:
+  MaxQPSRule(unsigned int qps)
+   : d_qps(qps, qps)
+  {}
+
+  MaxQPSRule(unsigned int qps, unsigned int burst)
+   : d_qps(qps, burst)
+  {}
+
+
+  bool matches(const DNSQuestion* qd) const override
+  {
+    return d_qps.check();
+  }
+
+  string toString() const override
+  {
+    return "Max " + std::to_string(d_qps.getRate()) + " qps";
+  }
+
+
+private:
+  mutable QPSLimiter d_qps;
+};
+
+class NMGRule : public DNSRule
+{
+public:
+  NMGRule(const NetmaskGroup& nmg) : d_nmg(nmg) {}
+protected:
+  NetmaskGroup d_nmg;
+};
+
+class NetmaskGroupRule : public NMGRule
+{
+public:
+  NetmaskGroupRule(const NetmaskGroup& nmg, bool src, bool quiet = false) : NMGRule(nmg)
+  {
+      d_src = src;
+      d_quiet = quiet;
+  }
+  bool matches(const DNSQuestion* dq) const override
+  {
+    if(!d_src) {
+        return d_nmg.match(dq->ids.origDest);
+    }
+    return d_nmg.match(dq->ids.origRemote);
+  }
+
+  string toString() const override
+  {
+    string ret = "Src: ";
+    if(!d_src) {
+        ret = "Dst: ";
+    }
+    if (d_quiet) {
+      return ret + "in-group";
+    }
+    return ret + d_nmg.toString();
+  }
+private:
+  bool d_src;
+  bool d_quiet;
+};
+
+class TimedIPSetRule : public DNSRule, boost::noncopyable
+{
+private:
+  struct IPv6 {
+    IPv6(const ComboAddress& ca)
+    {
+      static_assert(sizeof(*this)==16, "IPv6 struct has wrong size");
+      memcpy((char*)this, ca.sin6.sin6_addr.s6_addr, 16);
+    }
+    bool operator==(const IPv6& rhs) const
+    {
+      return a==rhs.a && b==rhs.b;
+    }
+    uint64_t a, b;
+  };
+
+public:
+  TimedIPSetRule()
+  {
+  }
+  ~TimedIPSetRule()
+  {
+  }
+  bool matches(const DNSQuestion* dq) const override
+  {
+    if (dq->ids.origRemote.sin4.sin_family == AF_INET) {
+      auto ip4s = d_ip4s.read_lock();
+      auto fnd = ip4s->find(dq->ids.origRemote.sin4.sin_addr.s_addr);
+      if (fnd == ip4s->end()) {
+        return false;
+      }
+      return time(nullptr) < fnd->second;
+    } else {
+      auto ip6s = d_ip6s.read_lock();
+      auto fnd = ip6s->find({dq->ids.origRemote});
+      if (fnd == ip6s->end()) {
+        return false;
+      }
+      return time(nullptr) < fnd->second;
+    }
+  }
+
+  void add(const ComboAddress& ca, time_t ttd)
+  {
+    // think twice before adding templates here
+    if (ca.sin4.sin_family == AF_INET) {
+      auto res = d_ip4s.write_lock()->insert({ca.sin4.sin_addr.s_addr, ttd});
+      if (!res.second && (time_t)res.first->second < ttd) {
+        res.first->second = (uint32_t)ttd;
+      }
+    }
+    else {
+      auto res = d_ip6s.write_lock()->insert({{ca}, ttd});
+      if (!res.second && (time_t)res.first->second < ttd) {
+        // coverity[store_truncates_time_t]
+        res.first->second = (uint32_t)ttd;
+      }
+    }
+  }
+
+  void remove(const ComboAddress& ca)
+  {
+    if (ca.sin4.sin_family == AF_INET) {
+      d_ip4s.write_lock()->erase(ca.sin4.sin_addr.s_addr);
+    }
+    else {
+      d_ip6s.write_lock()->erase({ca});
+    }
+  }
+
+  void clear()
+  {
+    d_ip4s.write_lock()->clear();
+    d_ip6s.write_lock()->clear();
+  }
+
+  void cleanup()
+  {
+    time_t now = time(nullptr);
+    {
+      auto ip4s = d_ip4s.write_lock();
+      for (auto iter = ip4s->begin(); iter != ip4s->end(); ) {
+       if (iter->second < now) {
+         iter = ip4s->erase(iter);
+        }
+       else {
+         ++iter;
+        }
+      }
+    }
+
+    {
+      auto ip6s = d_ip6s.write_lock();
+      for (auto iter = ip6s->begin(); iter != ip6s->end(); ) {
+       if (iter->second < now) {
+         iter = ip6s->erase(iter);
+        }
+       else {
+         ++iter;
+        }
+      }
+
+    }
+
+  }
+
+  string toString() const override
+  {
+    time_t now = time(nullptr);
+    uint64_t count = 0;
+
+    for (const auto& ip : *(d_ip4s.read_lock())) {
+      if (now < ip.second) {
+        ++count;
+      }
+    }
+
+    for (const auto& ip : *(d_ip6s.read_lock())) {
+      if (now < ip.second) {
+        ++count;
+      }
+    }
+
+    return "Src: "+std::to_string(count)+" ips";
+  }
+private:
+  struct IPv6Hash
+  {
+    std::size_t operator()(const IPv6& ip) const
+    {
+      auto ah=std::hash<uint64_t>{}(ip.a);
+      auto bh=std::hash<uint64_t>{}(ip.b);
+      return ah & (bh<<1);
+    }
+  };
+  mutable SharedLockGuarded<std::unordered_map<IPv6, time_t, IPv6Hash>> d_ip6s;
+  mutable SharedLockGuarded<std::unordered_map<uint32_t, time_t>> d_ip4s;
+};
+
+
+class AllRule : public DNSRule
+{
+public:
+  AllRule() {}
+  bool matches(const DNSQuestion* dq) const override
+  {
+    return true;
+  }
+
+  string toString() const override
+  {
+    return "All";
+  }
+
+};
+
+
+class DNSSECRule : public DNSRule
+{
+public:
+  DNSSECRule()
+  {
+
+  }
+  bool matches(const DNSQuestion* dq) const override
+  {
+    return dq->getHeader()->cd || (dnsdist::getEDNSZ(*dq) & EDNS_HEADER_FLAG_DO);    // turns out dig sets ad by default..
+  }
+
+  string toString() const override
+  {
+    return "DNSSEC";
+  }
+};
+
+class AndRule : public DNSRule
+{
+public:
+  AndRule(const std::vector<pair<int, std::shared_ptr<DNSRule> > >& rules)
+  {
+    for (const auto& r : rules) {
+      d_rules.push_back(r.second);
+    }
+  }
+
+  bool matches(const DNSQuestion* dq) const override
+  {
+    for (const auto& rule : d_rules) {
+      if (!rule->matches(dq)) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  string toString() const override
+  {
+    string ret;
+    for (const auto& rule : d_rules) {
+      if (!ret.empty()) {
+        ret+= " && ";
+      }
+      ret += "("+ rule->toString()+")";
+    }
+    return ret;
+  }
+private:
+  std::vector<std::shared_ptr<DNSRule> > d_rules;
+};
+
+
+class OrRule : public DNSRule
+{
+public:
+  OrRule(const std::vector<pair<int, std::shared_ptr<DNSRule> > >& rules)
+  {
+    for (const auto& r : rules) {
+      d_rules.push_back(r.second);
+    }
+  }
+
+  bool matches(const DNSQuestion* dq) const override
+  {
+    for (const auto& rule: d_rules) {
+      if (rule->matches(dq)) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  string toString() const override
+  {
+    string ret;
+    for (const auto& rule : d_rules) {
+      if (!ret.empty()) {
+        ret+= " || ";
+      }
+      ret += "("+ rule->toString()+")";
+    }
+    return ret;
+  }
+private:
+  std::vector<std::shared_ptr<DNSRule> > d_rules;
+};
+
+
+class RegexRule : public DNSRule
+{
+public:
+  RegexRule(const std::string& regex) : d_regex(regex), d_visual(regex)
+  {
+
+  }
+  bool matches(const DNSQuestion* dq) const override
+  {
+    return d_regex.match(dq->ids.qname.toStringNoDot());
+  }
+
+  string toString() const override
+  {
+    return "Regex: "+d_visual;
+  }
+private:
+  Regex d_regex;
+  string d_visual;
+};
+
+#ifdef HAVE_RE2
+#include <re2/re2.h>
+class RE2Rule : public DNSRule
+{
+public:
+  RE2Rule(const std::string& re2) : d_re2(re2, RE2::Latin1), d_visual(re2)
+  {
+
+  }
+  bool matches(const DNSQuestion* dq) const override
+  {
+    return RE2::FullMatch(dq->ids.qname.toStringNoDot(), d_re2);
+  }
+
+  string toString() const override
+  {
+    return "RE2 match: "+d_visual;
+  }
+private:
+  RE2 d_re2;
+  string d_visual;
+};
+#endif
+
+#ifdef HAVE_DNS_OVER_HTTPS
+class HTTPHeaderRule : public DNSRule
+{
+public:
+  HTTPHeaderRule(const std::string& header, const std::string& regex);
+  bool matches(const DNSQuestion* dq) const override;
+  string toString() const override;
+private:
+  string d_header;
+  Regex d_regex;
+  string d_visual;
+};
+
+class HTTPPathRule : public DNSRule
+{
+public:
+  HTTPPathRule(std::string path);
+  bool matches(const DNSQuestion* dq) const override;
+  string toString() const override;
+private:
+  string d_path;
+};
+
+class HTTPPathRegexRule : public DNSRule
+{
+public:
+  HTTPPathRegexRule(const std::string& regex);
+  bool matches(const DNSQuestion* dq) const override;
+  string toString() const override;
+private:
+  Regex d_regex;
+  std::string d_visual;
+};
+#endif
+
+class SNIRule : public DNSRule
+{
+public:
+  SNIRule(const std::string& name) : d_sni(name)
+  {
+  }
+  bool matches(const DNSQuestion* dq) const override
+  {
+    return dq->sni == d_sni;
+  }
+  string toString() const override
+  {
+    return "SNI == " + d_sni;
+  }
+private:
+  std::string d_sni;
+};
+
+class SuffixMatchNodeRule : public DNSRule
+{
+public:
+  SuffixMatchNodeRule(const SuffixMatchNode& smn, bool quiet=false) : d_smn(smn), d_quiet(quiet)
+  {
+  }
+  bool matches(const DNSQuestion* dq) const override
+  {
+    return d_smn.check(dq->ids.qname);
+  }
+  string toString() const override
+  {
+    if(d_quiet)
+      return "qname==in-set";
+    else
+      return "qname in "+d_smn.toString();
+  }
+private:
+  SuffixMatchNode d_smn;
+  bool d_quiet;
+};
+
+class QNameRule : public DNSRule
+{
+public:
+  QNameRule(const DNSName& qname) : d_qname(qname)
+  {
+  }
+
+  bool matches(const DNSQuestion* dq) const override
+  {
+    return d_qname==dq->ids.qname;
+  }
+  string toString() const override
+  {
+    return "qname=="+d_qname.toString();
+  }
+private:
+  DNSName d_qname;
+};
+
+class QNameSetRule : public DNSRule {
+public:
+    QNameSetRule(const DNSNameSet& names) : qname_idx(names) {}
+
+    bool matches(const DNSQuestion* dq) const override {
+        return qname_idx.find(dq->ids.qname) != qname_idx.end();
+    }
+
+    string toString() const override {
+        std::stringstream ss;
+        ss << "qname in DNSNameSet(" << qname_idx.size() << " FQDNs)";
+        return ss.str();
+    }
+private:
+    DNSNameSet qname_idx;
+};
+
+class QTypeRule : public DNSRule
+{
+public:
+  QTypeRule(uint16_t qtype) : d_qtype(qtype)
+  {
+  }
+  bool matches(const DNSQuestion* dq) const override
+  {
+    return d_qtype == dq->ids.qtype;
+  }
+  string toString() const override
+  {
+    QType qt(d_qtype);
+    return "qtype=="+qt.toString();
+  }
+private:
+  uint16_t d_qtype;
+};
+
+class QClassRule : public DNSRule
+{
+public:
+  QClassRule(uint16_t qclass) : d_qclass(qclass)
+  {
+  }
+  bool matches(const DNSQuestion* dq) const override
+  {
+    return d_qclass == dq->ids.qclass;
+  }
+  string toString() const override
+  {
+    return "qclass=="+std::to_string(d_qclass);
+  }
+private:
+  uint16_t d_qclass;
+};
+
+class OpcodeRule : public DNSRule
+{
+public:
+  OpcodeRule(uint8_t opcode) : d_opcode(opcode)
+  {
+  }
+  bool matches(const DNSQuestion* dq) const override
+  {
+    return d_opcode == dq->getHeader()->opcode;
+  }
+  string toString() const override
+  {
+    return "opcode=="+std::to_string(d_opcode);
+  }
+private:
+  uint8_t d_opcode;
+};
+
+class DSTPortRule : public DNSRule
+{
+public:
+  DSTPortRule(uint16_t port) : d_port(port)
+  {
+  }
+  bool matches(const DNSQuestion* dq) const override
+  {
+    return htons(d_port) == dq->ids.origDest.sin4.sin_port;
+  }
+  string toString() const override
+  {
+    return "dst port=="+std::to_string(d_port);
+  }
+private:
+  uint16_t d_port;
+};
+
+class TCPRule : public DNSRule
+{
+public:
+  TCPRule(bool tcp): d_tcp(tcp)
+  {
+  }
+  bool matches(const DNSQuestion* dq) const override
+  {
+    return dq->overTCP() == d_tcp;
+  }
+  string toString() const override
+  {
+    return (d_tcp ? "TCP" : "UDP");
+  }
+private:
+  bool d_tcp;
+};
+
+
+class NotRule : public DNSRule
+{
+public:
+  NotRule(const std::shared_ptr<DNSRule>& rule): d_rule(rule)
+  {
+  }
+  bool matches(const DNSQuestion* dq) const override
+  {
+    return !d_rule->matches(dq);
+  }
+  string toString() const override
+  {
+    return "!("+ d_rule->toString()+")";
+  }
+private:
+  std::shared_ptr<DNSRule> d_rule;
+};
+
+class RecordsCountRule : public DNSRule
+{
+public:
+  RecordsCountRule(uint8_t section, uint16_t minCount, uint16_t maxCount): d_minCount(minCount), d_maxCount(maxCount), d_section(section)
+  {
+  }
+  bool matches(const DNSQuestion* dq) const override
+  {
+    uint16_t count = 0;
+    switch(d_section) {
+    case 0:
+      count = ntohs(dq->getHeader()->qdcount);
+      break;
+    case 1:
+      count = ntohs(dq->getHeader()->ancount);
+      break;
+    case 2:
+      count = ntohs(dq->getHeader()->nscount);
+      break;
+    case 3:
+      count = ntohs(dq->getHeader()->arcount);
+      break;
+    }
+    return count >= d_minCount && count <= d_maxCount;
+  }
+  string toString() const override
+  {
+    string section;
+    switch(d_section) {
+    case 0:
+      section = "QD";
+      break;
+    case 1:
+      section = "AN";
+      break;
+    case 2:
+      section = "NS";
+      break;
+    case 3:
+      section = "AR";
+      break;
+    }
+    return std::to_string(d_minCount) + " <= records in " + section + " <= "+ std::to_string(d_maxCount);
+  }
+private:
+  uint16_t d_minCount;
+  uint16_t d_maxCount;
+  uint8_t d_section;
+};
+
+class RecordsTypeCountRule : public DNSRule
+{
+public:
+  RecordsTypeCountRule(uint8_t section, uint16_t type, uint16_t minCount, uint16_t maxCount): d_type(type), d_minCount(minCount), d_maxCount(maxCount), d_section(section)
+  {
+  }
+  bool matches(const DNSQuestion* dq) const override
+  {
+    uint16_t count = 0;
+    switch(d_section) {
+    case 0:
+      count = ntohs(dq->getHeader()->qdcount);
+      break;
+    case 1:
+      count = ntohs(dq->getHeader()->ancount);
+      break;
+    case 2:
+      count = ntohs(dq->getHeader()->nscount);
+      break;
+    case 3:
+      count = ntohs(dq->getHeader()->arcount);
+      break;
+    }
+    if (count < d_minCount) {
+      return false;
+    }
+    count = getRecordsOfTypeCount(reinterpret_cast<const char*>(dq->getData().data()), dq->getData().size(), d_section, d_type);
+    return count >= d_minCount && count <= d_maxCount;
+  }
+  string toString() const override
+  {
+    string section;
+    switch(d_section) {
+    case 0:
+      section = "QD";
+      break;
+    case 1:
+      section = "AN";
+      break;
+    case 2:
+      section = "NS";
+      break;
+    case 3:
+      section = "AR";
+      break;
+    }
+    return std::to_string(d_minCount) + " <= " + QType(d_type).toString() + " records in " + section + " <= "+ std::to_string(d_maxCount);
+  }
+private:
+  uint16_t d_type;
+  uint16_t d_minCount;
+  uint16_t d_maxCount;
+  uint8_t d_section;
+};
+
+class TrailingDataRule : public DNSRule
+{
+public:
+  TrailingDataRule()
+  {
+  }
+  bool matches(const DNSQuestion* dq) const override
+  {
+    uint16_t length = getDNSPacketLength(reinterpret_cast<const char*>(dq->getData().data()), dq->getData().size());
+    return length < dq->getData().size();
+  }
+  string toString() const override
+  {
+    return "trailing data";
+  }
+};
+
+class QNameLabelsCountRule : public DNSRule
+{
+public:
+  QNameLabelsCountRule(unsigned int minLabelsCount, unsigned int maxLabelsCount): d_min(minLabelsCount), d_max(maxLabelsCount)
+  {
+  }
+  bool matches(const DNSQuestion* dq) const override
+  {
+    unsigned int count = dq->ids.qname.countLabels();
+    return count < d_min || count > d_max;
+  }
+  string toString() const override
+  {
+    return "labels count < " + std::to_string(d_min) + " || labels count > " + std::to_string(d_max);
+  }
+private:
+  unsigned int d_min;
+  unsigned int d_max;
+};
+
+class QNameWireLengthRule : public DNSRule
+{
+public:
+  QNameWireLengthRule(size_t min, size_t max): d_min(min), d_max(max)
+  {
+  }
+  bool matches(const DNSQuestion* dq) const override
+  {
+    size_t const wirelength = dq->ids.qname.wirelength();
+    return wirelength < d_min || wirelength > d_max;
+  }
+  string toString() const override
+  {
+    return "wire length < " + std::to_string(d_min) + " || wire length > " + std::to_string(d_max);
+  }
+private:
+  size_t d_min;
+  size_t d_max;
+};
+
+class RCodeRule : public DNSRule
+{
+public:
+  RCodeRule(uint8_t rcode) : d_rcode(rcode)
+  {
+  }
+  bool matches(const DNSQuestion* dq) const override
+  {
+    return d_rcode == dq->getHeader()->rcode;
+  }
+  string toString() const override
+  {
+    return "rcode=="+RCode::to_s(d_rcode);
+  }
+private:
+  uint8_t d_rcode;
+};
+
+class ERCodeRule : public DNSRule
+{
+public:
+  ERCodeRule(uint8_t rcode) : d_rcode(rcode & 0xF), d_extrcode(rcode >> 4)
+  {
+  }
+  bool matches(const DNSQuestion* dq) const override
+  {
+    // avoid parsing EDNS OPT RR when not needed.
+    if (d_rcode != dq->getHeader()->rcode) {
+      return false;
+    }
+
+    EDNS0Record edns0;
+    if (!getEDNS0Record(dq->getData(), edns0)) {
+      return false;
+    }
+
+    return d_extrcode == edns0.extRCode;
+  }
+  string toString() const override
+  {
+    return "ercode=="+ERCode::to_s(d_rcode | (d_extrcode << 4));
+  }
+private:
+  uint8_t d_rcode;     // plain DNS Rcode
+  uint8_t d_extrcode;  // upper bits in EDNS0 record
+};
+
+class EDNSVersionRule : public DNSRule
+{
+public:
+  EDNSVersionRule(uint8_t version) : d_version(version)
+  {
+  }
+  bool matches(const DNSQuestion* dq) const override
+  {
+    EDNS0Record edns0;
+    if (!getEDNS0Record(dq->getData(), edns0)) {
+      return false;
+    }
+
+    return d_version < edns0.version;
+  }
+  string toString() const override
+  {
+    return "ednsversion>"+std::to_string(d_version);
+  }
+private:
+  uint8_t d_version;
+};
+
+class EDNSOptionRule : public DNSRule
+{
+public:
+  EDNSOptionRule(uint16_t optcode) : d_optcode(optcode)
+  {
+  }
+  bool matches(const DNSQuestion* dq) const override
+  {
+    uint16_t optStart;
+    size_t optLen = 0;
+    bool last = false;
+    int res = locateEDNSOptRR(dq->getData(), &optStart, &optLen, &last);
+    if (res != 0) {
+      // no EDNS OPT RR
+      return false;
+    }
+
+    if (optLen < optRecordMinimumSize) {
+      return false;
+    }
+
+    if (optStart < dq->getData().size() && dq->getData().at(optStart) != 0) {
+      // OPT RR Name != '.'
+      return false;
+    }
+
+    return isEDNSOptionInOpt(dq->getData(), optStart, optLen, d_optcode);
+  }
+  string toString() const override
+  {
+    return "ednsoptcode=="+std::to_string(d_optcode);
+  }
+private:
+  uint16_t d_optcode;
+};
+
+class RDRule : public DNSRule
+{
+public:
+  RDRule()
+  {
+  }
+  bool matches(const DNSQuestion* dq) const override
+  {
+    return dq->getHeader()->rd == 1;
+  }
+  string toString() const override
+  {
+    return "rd==1";
+  }
+};
+
+class ProbaRule : public DNSRule
+{
+public:
+  ProbaRule(double proba) : d_proba(proba)
+  {
+  }
+  bool matches(const DNSQuestion* dq) const override
+  {
+    if(d_proba == 1.0)
+      return true;
+    double rnd = 1.0*dns_random_uint32() / UINT32_MAX;
+    return rnd > (1.0 - d_proba);
+  }
+  string toString() const override
+  {
+    return "match with prob. " + (boost::format("%0.2f") % d_proba).str();
+  }
+private:
+  double d_proba;
+};
+
+class TagRule : public DNSRule
+{
+public:
+  TagRule(const std::string& tag, boost::optional<std::string> value) : d_value(std::move(value)), d_tag(tag)
+  {
+  }
+  bool matches(const DNSQuestion* dq) const override
+  {
+    if (!dq->ids.qTag) {
+      return false;
+    }
+
+    const auto it = dq->ids.qTag->find(d_tag);
+    if (it == dq->ids.qTag->cend()) {
+      return false;
+    }
+
+    if (!d_value) {
+      return true;
+    }
+
+    return it->second == *d_value;
+  }
+
+  string toString() const override
+  {
+    return "tag '" + d_tag + "' is set" + (d_value ? (" to '" + *d_value + "'") : "");
+  }
+
+private:
+  boost::optional<std::string> d_value;
+  std::string d_tag;
+};
+
+class PoolAvailableRule : public DNSRule
+{
+public:
+  PoolAvailableRule(const std::string& poolname) : d_poolname(poolname)
+  {
+  }
+
+  bool matches(const DNSQuestion* dq) const override
+  {
+    return (getPool(d_poolname)->countServers(true) > 0);
+  }
+
+  string toString() const override
+  {
+    return "pool '" + d_poolname + "' is available";
+  }
+private:
+  std::string d_poolname;
+};
+
+class PoolOutstandingRule : public DNSRule
+{
+public:
+  PoolOutstandingRule(const std::string& poolname, const size_t limit) : d_poolname(poolname), d_limit(limit)
+  {
+  }
+
+  bool matches(const DNSQuestion* dq) const override
+  {
+    return (getPool(d_poolname)->poolLoad()) > d_limit;
+  }
+
+  string toString() const override
+  {
+    return "pool '" + d_poolname + "' outstanding > " + std::to_string(d_limit);
+  }
+private:
+  std::string d_poolname;
+  size_t d_limit;
+};
+
+class KeyValueStoreLookupRule: public DNSRule
+{
+public:
+  KeyValueStoreLookupRule(std::shared_ptr<KeyValueStore>& kvs, std::shared_ptr<KeyValueLookupKey>& lookupKey): d_kvs(kvs), d_key(lookupKey)
+  {
+  }
+
+  bool matches(const DNSQuestion* dq) const override
+  {
+    std::vector<std::string> keys = d_key->getKeys(*dq);
+    for (const auto& key : keys) {
+      if (d_kvs->keyExists(key) == true) {
+        return true;
+      }
+    }
+
+    return false;
+  }
+
+  string toString() const override
+  {
+    return "lookup key-value store based on '" + d_key->toString() + "'";
+  }
+
+private:
+  std::shared_ptr<KeyValueStore> d_kvs;
+  std::shared_ptr<KeyValueLookupKey> d_key;
+};
+
+class KeyValueStoreRangeLookupRule: public DNSRule
+{
+public:
+  KeyValueStoreRangeLookupRule(std::shared_ptr<KeyValueStore>& kvs, std::shared_ptr<KeyValueLookupKey>& lookupKey): d_kvs(kvs), d_key(lookupKey)
+  {
+  }
+
+  bool matches(const DNSQuestion* dq) const override
+  {
+    std::vector<std::string> keys = d_key->getKeys(*dq);
+    for (const auto& key : keys) {
+      std::string value;
+      if (d_kvs->getRangeValue(key, value) == true) {
+        return true;
+      }
+    }
+
+    return false;
+  }
+
+  string toString() const override
+  {
+    return "range-based lookup key-value store based on '" + d_key->toString() + "'";
+  }
+
+private:
+  std::shared_ptr<KeyValueStore> d_kvs;
+  std::shared_ptr<KeyValueLookupKey> d_key;
+};
+
+class LuaRule : public DNSRule
+{
+public:
+  typedef std::function<bool(const DNSQuestion* dq)> func_t;
+  LuaRule(const func_t& func): d_func(func)
+  {}
+
+  bool matches(const DNSQuestion* dq) const override
+  {
+    try {
+      auto lock = g_lua.lock();
+      return d_func(dq);
+    } catch (const std::exception &e) {
+      warnlog("LuaRule failed inside Lua: %s", e.what());
+    } catch (...) {
+      warnlog("LuaRule failed inside Lua: [unknown exception]");
+    }
+    return false;
+  }
+
+  string toString() const override
+  {
+    return "Lua script";
+  }
+private:
+  func_t d_func;
+};
+
+class LuaFFIRule : public DNSRule
+{
+public:
+  typedef std::function<bool(dnsdist_ffi_dnsquestion_t* dq)> func_t;
+  LuaFFIRule(const func_t& func): d_func(func)
+  {}
+
+  bool matches(const DNSQuestion* dq) const override
+  {
+    dnsdist_ffi_dnsquestion_t dqffi(const_cast<DNSQuestion*>(dq));
+    try {
+      auto lock = g_lua.lock();
+      return d_func(&dqffi);
+    } catch (const std::exception &e) {
+      warnlog("LuaFFIRule failed inside Lua: %s", e.what());
+    } catch (...) {
+      warnlog("LuaFFIRule failed inside Lua: [unknown exception]");
+    }
+    return false;
+  }
+
+  string toString() const override
+  {
+    return "Lua FFI script";
+  }
+private:
+  func_t d_func;
+};
+
+class LuaFFIPerThreadRule : public DNSRule
+{
+public:
+  typedef std::function<bool(dnsdist_ffi_dnsquestion_t* dq)> func_t;
+
+  LuaFFIPerThreadRule(const std::string& code): d_functionCode(code), d_functionID(s_functionsCounter++)
+  {
+  }
+
+  bool matches(const DNSQuestion* dq) const override
+  {
+    try {
+      auto& state = t_perThreadStates[d_functionID];
+      if (!state.d_initialized) {
+        setupLuaFFIPerThreadContext(state.d_luaContext);
+        /* mark the state as initialized first so if there is a syntax error
+           we only try to execute the code once */
+        state.d_initialized = true;
+        state.d_func = state.d_luaContext.executeCode<func_t>(d_functionCode);
+      }
+
+      if (!state.d_func) {
+        /* the function was not properly initialized */
+        return false;
+      }
+
+      dnsdist_ffi_dnsquestion_t dqffi(const_cast<DNSQuestion*>(dq));
+      return state.d_func(&dqffi);
+    }
+    catch (const std::exception &e) {
+      warnlog("LuaFFIPerthreadRule failed inside Lua: %s", e.what());
+    }
+    catch (...) {
+      warnlog("LuaFFIPerThreadRule failed inside Lua: [unknown exception]");
+    }
+    return false;
+  }
+
+  string toString() const override
+  {
+    return "Lua FFI per-thread script";
+  }
+private:
+  struct PerThreadState
+  {
+    LuaContext d_luaContext;
+    func_t d_func;
+    bool d_initialized{false};
+  };
+
+  static std::atomic<uint64_t> s_functionsCounter;
+  static thread_local std::map<uint64_t, PerThreadState> t_perThreadStates;
+  const std::string d_functionCode;
+  const uint64_t d_functionID;
+};
+
+class ProxyProtocolValueRule : public DNSRule
+{
+public:
+  ProxyProtocolValueRule(uint8_t type, boost::optional<std::string> value): d_value(std::move(value)), d_type(type)
+  {
+  }
+
+  bool matches(const DNSQuestion* dq) const override
+  {
+    if (!dq->proxyProtocolValues) {
+      return false;
+    }
+
+    for (const auto& entry : *dq->proxyProtocolValues) {
+      if (entry.type == d_type && (!d_value || entry.content == *d_value)) {
+        return true;
+      }
+    }
+
+    return false;
+  }
+
+  string toString() const override
+  {
+    if (d_value) {
+      return "proxy protocol value of type " + std::to_string(d_type) + " matches";
+    }
+    return "proxy protocol value of type " + std::to_string(d_type) + " is present";
+  }
+
+private:
+  boost::optional<std::string> d_value;
+  uint8_t d_type;
+};
+
+class PayloadSizeRule : public DNSRule
+{
+  enum class Comparisons : uint8_t { equal, greater, greaterOrEqual, smaller, smallerOrEqual };
+public:
+  PayloadSizeRule(const std::string& comparison, uint16_t size): d_size(size)
+  {
+    if (comparison == "equal") {
+      d_comparison = Comparisons::equal;
+    }
+    else if (comparison == "greater") {
+      d_comparison = Comparisons::greater;
+    }
+    else if (comparison == "greaterOrEqual") {
+      d_comparison = Comparisons::greaterOrEqual;
+    }
+    else if (comparison == "smaller") {
+      d_comparison = Comparisons::smaller;
+    }
+    else if (comparison == "smallerOrEqual") {
+      d_comparison = Comparisons::smallerOrEqual;
+    }
+    else {
+      throw std::runtime_error("Unsupported comparison '" + comparison + "'");
+    }
+  }
+
+  bool matches(const DNSQuestion* dq) const override
+  {
+    const auto size = dq->getData().size();
+
+    switch (d_comparison) {
+    case Comparisons::equal:
+      return size == d_size;
+    case Comparisons::greater:
+      return size > d_size;
+    case Comparisons::greaterOrEqual:
+      return size >= d_size;
+    case Comparisons::smaller:
+      return size < d_size;
+    case Comparisons::smallerOrEqual:
+      return size <= d_size;
+    default:
+      return false;
+    }
+  }
+
+  string toString() const override
+  {
+    static const std::array<const std::string, 5> comparisonStr{
+      "equal to" ,
+      "greater than",
+      "equal to or greater than",
+      "smaller than",
+      "equal to or smaller than"
+    };
+    return "payload size is " + comparisonStr.at(static_cast<size_t>(d_comparison)) + " " + std::to_string(d_size);
+  }
+
+private:
+  uint16_t d_size;
+  Comparisons d_comparison;
+};
index b2688b80ddb2ceb15e354a75b64cbe3dd1f7cfe6..fe716a7d1dd2aa06a5e2831246c307e0d0ec735b 100644 (file)
@@ -19,7 +19,7 @@
  * along with this program; if not, write to the Free Software
  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
  */
-#include "dnsdist-rules.hh"
+#include "dnsdist-rules-factory.hh"
 
 std::atomic<uint64_t> LuaFFIPerThreadRule::s_functionsCounter = 0;
 thread_local std::map<uint64_t, LuaFFIPerThreadRule::PerThreadState> LuaFFIPerThreadRule::t_perThreadStates;
index 2d8c4a84154946136152545a3f3835ca34d3bad5..aca662b9624e127ddec2fc32965d9c78ce0832a9 100644 (file)
  */
 #pragma once
 
-#include <boost/multi_index_container.hpp>
-#include <boost/multi_index/ordered_index.hpp>
-#include <boost/multi_index/sequenced_index.hpp>
-#include <boost/multi_index/key_extractors.hpp>
+#include <string>
 
-#include "cachecleaner.hh"
 #include "dnsdist.hh"
-#include "dnsdist-ecs.hh"
-#include "dnsdist-kvs.hh"
-#include "dnsdist-lua.hh"
-#include "dnsdist-lua-ffi.hh"
-#include "dolog.hh"
-#include "dnsparser.hh"
-#include "dns_random.hh"
+#include "stat_t.hh"
 
 class DNSRule
 {
@@ -44,1383 +34,6 @@ public:
   }
   virtual bool matches(const DNSQuestion* dq) const = 0;
   virtual string toString() const = 0;
-  mutable stat_t d_matches{0};
-};
-
-
-class MaxQPSIPRule : public DNSRule
-{
-public:
-  MaxQPSIPRule(unsigned int qps, unsigned int burst, unsigned int ipv4trunc=32, unsigned int ipv6trunc=64, unsigned int expiration=300, unsigned int cleanupDelay=60, unsigned int scanFraction=10, size_t shardsCount=10):
-    d_shards(shardsCount), d_qps(qps), d_burst(burst), d_ipv4trunc(ipv4trunc), d_ipv6trunc(ipv6trunc), d_cleanupDelay(cleanupDelay), d_expiration(expiration), d_scanFraction(scanFraction)
-  {
-    d_cleaningUp.clear();
-    gettime(&d_lastCleanup, true);
-  }
-
-  void clear()
-  {
-    for (auto& shard : d_shards) {
-      shard.lock()->clear();
-    }
-  }
-
-  size_t cleanup(const struct timespec& cutOff, size_t* scannedCount=nullptr) const
-  {
-    size_t removed = 0;
-    if (scannedCount != nullptr) {
-      *scannedCount = 0;
-    }
-
-    for (auto& shard : d_shards) {
-      auto limits = shard.lock();
-      const size_t toLook = std::round((1.0 * limits->size()) / d_scanFraction)+ 1;
-      size_t lookedAt = 0;
-
-      auto& sequence = limits->get<SequencedTag>();
-      for (auto entry = sequence.begin(); entry != sequence.end() && lookedAt < toLook; lookedAt++) {
-        if (entry->d_limiter.seenSince(cutOff)) {
-          /* entries are ordered from least recently seen to more recently
-             seen, as soon as we see one that has not expired yet, we are
-             done */
-          lookedAt++;
-          break;
-        }
-
-        entry = sequence.erase(entry);
-        removed++;
-      }
-
-      if (scannedCount != nullptr) {
-        *scannedCount += lookedAt;
-      }
-    }
-
-    return removed;
-  }
-
-  void cleanupIfNeeded(const struct timespec& now) const
-  {
-    if (d_cleanupDelay > 0) {
-      struct timespec cutOff = d_lastCleanup;
-      cutOff.tv_sec += d_cleanupDelay;
-
-      if (cutOff < now) {
-        try {
-          if (d_cleaningUp.test_and_set()) {
-            return;
-          }
-
-          d_lastCleanup = now;
-          /* the QPS Limiter doesn't use realtime, be careful! */
-          gettime(&cutOff, false);
-          cutOff.tv_sec -= d_expiration;
-
-          cleanup(cutOff);
-          d_cleaningUp.clear();
-        }
-        catch (...) {
-          d_cleaningUp.clear();
-          throw;
-        }
-      }
-    }
-  }
-
-  bool matches(const DNSQuestion* dq) const override
-  {
-    cleanupIfNeeded(dq->getQueryRealTime());
-
-    ComboAddress zeroport(dq->ids.origRemote);
-    zeroport.sin4.sin_port=0;
-    zeroport.truncate(zeroport.sin4.sin_family == AF_INET ? d_ipv4trunc : d_ipv6trunc);
-    auto hash = ComboAddress::addressOnlyHash()(zeroport);
-    auto& shard = d_shards[hash % d_shards.size()];
-    {
-      auto limits = shard.lock();
-      auto iter = limits->find(zeroport);
-      if (iter == limits->end()) {
-        Entry e(zeroport, QPSLimiter(d_qps, d_burst));
-        iter = limits->insert(e).first;
-      }
-
-      moveCacheItemToBack<SequencedTag>(*limits, iter);
-      return !iter->d_limiter.check(d_qps, d_burst);
-    }
-  }
-
-  string toString() const override
-  {
-    return "IP (/"+std::to_string(d_ipv4trunc)+", /"+std::to_string(d_ipv6trunc)+") match for QPS over " + std::to_string(d_qps) + " burst "+ std::to_string(d_burst);
-  }
-
-  size_t getEntriesCount() const
-  {
-    size_t count = 0;
-    for (auto& shard : d_shards) {
-      count += shard.lock()->size();
-    }
-    return count;
-  }
-
-  size_t getNumberOfShards() const
-  {
-    return d_shards.size();
-  }
-
-private:
-  struct HashedTag {};
-  struct SequencedTag {};
-  struct Entry
-  {
-    Entry(const ComboAddress& addr, BasicQPSLimiter&& limiter): d_limiter(limiter), d_addr(addr)
-    {
-    }
-    mutable BasicQPSLimiter d_limiter;
-    ComboAddress d_addr;
-  };
-
-  typedef multi_index_container<
-    Entry,
-    indexed_by <
-      hashed_unique<tag<HashedTag>, member<Entry,ComboAddress,&Entry::d_addr>, ComboAddress::addressOnlyHash >,
-      sequenced<tag<SequencedTag> >
-      >
-  > qpsContainer_t;
-
-  mutable std::vector<LockGuarded<qpsContainer_t>> d_shards;
-  mutable struct timespec d_lastCleanup;
-  const unsigned int d_qps, d_burst, d_ipv4trunc, d_ipv6trunc, d_cleanupDelay, d_expiration;
-  const unsigned int d_scanFraction{10};
-  mutable std::atomic_flag d_cleaningUp;
-};
-
-class MaxQPSRule : public DNSRule
-{
-public:
-  MaxQPSRule(unsigned int qps)
-   : d_qps(qps, qps)
-  {}
-
-  MaxQPSRule(unsigned int qps, unsigned int burst)
-   : d_qps(qps, burst)
-  {}
-
-
-  bool matches(const DNSQuestion* qd) const override
-  {
-    return d_qps.check();
-  }
-
-  string toString() const override
-  {
-    return "Max " + std::to_string(d_qps.getRate()) + " qps";
-  }
-
-
-private:
-  mutable QPSLimiter d_qps;
-};
-
-class NMGRule : public DNSRule
-{
-public:
-  NMGRule(const NetmaskGroup& nmg) : d_nmg(nmg) {}
-protected:
-  NetmaskGroup d_nmg;
-};
-
-class NetmaskGroupRule : public NMGRule
-{
-public:
-  NetmaskGroupRule(const NetmaskGroup& nmg, bool src, bool quiet = false) : NMGRule(nmg)
-  {
-      d_src = src;
-      d_quiet = quiet;
-  }
-  bool matches(const DNSQuestion* dq) const override
-  {
-    if(!d_src) {
-        return d_nmg.match(dq->ids.origDest);
-    }
-    return d_nmg.match(dq->ids.origRemote);
-  }
-
-  string toString() const override
-  {
-    string ret = "Src: ";
-    if(!d_src) {
-        ret = "Dst: ";
-    }
-    if (d_quiet) {
-      return ret + "in-group";
-    }
-    return ret + d_nmg.toString();
-  }
-private:
-  bool d_src;
-  bool d_quiet;
-};
-
-class TimedIPSetRule : public DNSRule, boost::noncopyable
-{
-private:
-  struct IPv6 {
-    IPv6(const ComboAddress& ca)
-    {
-      static_assert(sizeof(*this)==16, "IPv6 struct has wrong size");
-      memcpy((char*)this, ca.sin6.sin6_addr.s6_addr, 16);
-    }
-    bool operator==(const IPv6& rhs) const
-    {
-      return a==rhs.a && b==rhs.b;
-    }
-    uint64_t a, b;
-  };
-
-public:
-  TimedIPSetRule()
-  {
-  }
-  ~TimedIPSetRule()
-  {
-  }
-  bool matches(const DNSQuestion* dq) const override
-  {
-    if (dq->ids.origRemote.sin4.sin_family == AF_INET) {
-      auto ip4s = d_ip4s.read_lock();
-      auto fnd = ip4s->find(dq->ids.origRemote.sin4.sin_addr.s_addr);
-      if (fnd == ip4s->end()) {
-        return false;
-      }
-      return time(nullptr) < fnd->second;
-    } else {
-      auto ip6s = d_ip6s.read_lock();
-      auto fnd = ip6s->find({dq->ids.origRemote});
-      if (fnd == ip6s->end()) {
-        return false;
-      }
-      return time(nullptr) < fnd->second;
-    }
-  }
-
-  void add(const ComboAddress& ca, time_t ttd)
-  {
-    // think twice before adding templates here
-    if (ca.sin4.sin_family == AF_INET) {
-      auto res = d_ip4s.write_lock()->insert({ca.sin4.sin_addr.s_addr, ttd});
-      if (!res.second && (time_t)res.first->second < ttd) {
-        res.first->second = (uint32_t)ttd;
-      }
-    }
-    else {
-      auto res = d_ip6s.write_lock()->insert({{ca}, ttd});
-      if (!res.second && (time_t)res.first->second < ttd) {
-        // coverity[store_truncates_time_t]
-        res.first->second = (uint32_t)ttd;
-      }
-    }
-  }
-
-  void remove(const ComboAddress& ca)
-  {
-    if (ca.sin4.sin_family == AF_INET) {
-      d_ip4s.write_lock()->erase(ca.sin4.sin_addr.s_addr);
-    }
-    else {
-      d_ip6s.write_lock()->erase({ca});
-    }
-  }
-
-  void clear()
-  {
-    d_ip4s.write_lock()->clear();
-    d_ip6s.write_lock()->clear();
-  }
-
-  void cleanup()
-  {
-    time_t now = time(nullptr);
-    {
-      auto ip4s = d_ip4s.write_lock();
-      for (auto iter = ip4s->begin(); iter != ip4s->end(); ) {
-       if (iter->second < now) {
-         iter = ip4s->erase(iter);
-        }
-       else {
-         ++iter;
-        }
-      }
-    }
-
-    {
-      auto ip6s = d_ip6s.write_lock();
-      for (auto iter = ip6s->begin(); iter != ip6s->end(); ) {
-       if (iter->second < now) {
-         iter = ip6s->erase(iter);
-        }
-       else {
-         ++iter;
-        }
-      }
-
-    }
-
-  }
-
-  string toString() const override
-  {
-    time_t now = time(nullptr);
-    uint64_t count = 0;
-
-    for (const auto& ip : *(d_ip4s.read_lock())) {
-      if (now < ip.second) {
-        ++count;
-      }
-    }
-
-    for (const auto& ip : *(d_ip6s.read_lock())) {
-      if (now < ip.second) {
-        ++count;
-      }
-    }
-
-    return "Src: "+std::to_string(count)+" ips";
-  }
-private:
-  struct IPv6Hash
-  {
-    std::size_t operator()(const IPv6& ip) const
-    {
-      auto ah=std::hash<uint64_t>{}(ip.a);
-      auto bh=std::hash<uint64_t>{}(ip.b);
-      return ah & (bh<<1);
-    }
-  };
-  mutable SharedLockGuarded<std::unordered_map<IPv6, time_t, IPv6Hash>> d_ip6s;
-  mutable SharedLockGuarded<std::unordered_map<uint32_t, time_t>> d_ip4s;
-};
-
-
-class AllRule : public DNSRule
-{
-public:
-  AllRule() {}
-  bool matches(const DNSQuestion* dq) const override
-  {
-    return true;
-  }
-
-  string toString() const override
-  {
-    return "All";
-  }
-
-};
-
-
-class DNSSECRule : public DNSRule
-{
-public:
-  DNSSECRule()
-  {
-
-  }
-  bool matches(const DNSQuestion* dq) const override
-  {
-    return dq->getHeader()->cd || (dnsdist::getEDNSZ(*dq) & EDNS_HEADER_FLAG_DO);    // turns out dig sets ad by default..
-  }
-
-  string toString() const override
-  {
-    return "DNSSEC";
-  }
-};
-
-class AndRule : public DNSRule
-{
-public:
-  AndRule(const std::vector<pair<int, std::shared_ptr<DNSRule> > >& rules)
-  {
-    for (const auto& r : rules) {
-      d_rules.push_back(r.second);
-    }
-  }
-
-  bool matches(const DNSQuestion* dq) const override
-  {
-    for (const auto& rule : d_rules) {
-      if (!rule->matches(dq)) {
-        return false;
-      }
-    }
-    return true;
-  }
-
-  string toString() const override
-  {
-    string ret;
-    for (const auto& rule : d_rules) {
-      if (!ret.empty()) {
-        ret+= " && ";
-      }
-      ret += "("+ rule->toString()+")";
-    }
-    return ret;
-  }
-private:
-  std::vector<std::shared_ptr<DNSRule> > d_rules;
-};
-
-
-class OrRule : public DNSRule
-{
-public:
-  OrRule(const std::vector<pair<int, std::shared_ptr<DNSRule> > >& rules)
-  {
-    for (const auto& r : rules) {
-      d_rules.push_back(r.second);
-    }
-  }
-
-  bool matches(const DNSQuestion* dq) const override
-  {
-    for (const auto& rule: d_rules) {
-      if (rule->matches(dq)) {
-        return true;
-      }
-    }
-    return false;
-  }
-
-  string toString() const override
-  {
-    string ret;
-    for (const auto& rule : d_rules) {
-      if (!ret.empty()) {
-        ret+= " || ";
-      }
-      ret += "("+ rule->toString()+")";
-    }
-    return ret;
-  }
-private:
-  std::vector<std::shared_ptr<DNSRule> > d_rules;
-};
-
-
-class RegexRule : public DNSRule
-{
-public:
-  RegexRule(const std::string& regex) : d_regex(regex), d_visual(regex)
-  {
-
-  }
-  bool matches(const DNSQuestion* dq) const override
-  {
-    return d_regex.match(dq->ids.qname.toStringNoDot());
-  }
-
-  string toString() const override
-  {
-    return "Regex: "+d_visual;
-  }
-private:
-  Regex d_regex;
-  string d_visual;
-};
-
-#ifdef HAVE_RE2
-#include <re2/re2.h>
-class RE2Rule : public DNSRule
-{
-public:
-  RE2Rule(const std::string& re2) : d_re2(re2, RE2::Latin1), d_visual(re2)
-  {
-
-  }
-  bool matches(const DNSQuestion* dq) const override
-  {
-    return RE2::FullMatch(dq->ids.qname.toStringNoDot(), d_re2);
-  }
-
-  string toString() const override
-  {
-    return "RE2 match: "+d_visual;
-  }
-private:
-  RE2 d_re2;
-  string d_visual;
-};
-#endif
-
-#ifdef HAVE_DNS_OVER_HTTPS
-class HTTPHeaderRule : public DNSRule
-{
-public:
-  HTTPHeaderRule(const std::string& header, const std::string& regex);
-  bool matches(const DNSQuestion* dq) const override;
-  string toString() const override;
-private:
-  string d_header;
-  Regex d_regex;
-  string d_visual;
-};
-
-class HTTPPathRule : public DNSRule
-{
-public:
-  HTTPPathRule(std::string path);
-  bool matches(const DNSQuestion* dq) const override;
-  string toString() const override;
-private:
-  string d_path;
-};
-
-class HTTPPathRegexRule : public DNSRule
-{
-public:
-  HTTPPathRegexRule(const std::string& regex);
-  bool matches(const DNSQuestion* dq) const override;
-  string toString() const override;
-private:
-  Regex d_regex;
-  std::string d_visual;
-};
-#endif
-
-class SNIRule : public DNSRule
-{
-public:
-  SNIRule(const std::string& name) : d_sni(name)
-  {
-  }
-  bool matches(const DNSQuestion* dq) const override
-  {
-    return dq->sni == d_sni;
-  }
-  string toString() const override
-  {
-    return "SNI == " + d_sni;
-  }
-private:
-  std::string d_sni;
-};
-
-class SuffixMatchNodeRule : public DNSRule
-{
-public:
-  SuffixMatchNodeRule(const SuffixMatchNode& smn, bool quiet=false) : d_smn(smn), d_quiet(quiet)
-  {
-  }
-  bool matches(const DNSQuestion* dq) const override
-  {
-    return d_smn.check(dq->ids.qname);
-  }
-  string toString() const override
-  {
-    if(d_quiet)
-      return "qname==in-set";
-    else
-      return "qname in "+d_smn.toString();
-  }
-private:
-  SuffixMatchNode d_smn;
-  bool d_quiet;
-};
-
-class QNameRule : public DNSRule
-{
-public:
-  QNameRule(const DNSName& qname) : d_qname(qname)
-  {
-  }
-
-  bool matches(const DNSQuestion* dq) const override
-  {
-    return d_qname==dq->ids.qname;
-  }
-  string toString() const override
-  {
-    return "qname=="+d_qname.toString();
-  }
-private:
-  DNSName d_qname;
-};
-
-class QNameSetRule : public DNSRule {
-public:
-    QNameSetRule(const DNSNameSet& names) : qname_idx(names) {}
-
-    bool matches(const DNSQuestion* dq) const override {
-        return qname_idx.find(dq->ids.qname) != qname_idx.end();
-    }
-
-    string toString() const override {
-        std::stringstream ss;
-        ss << "qname in DNSNameSet(" << qname_idx.size() << " FQDNs)";
-        return ss.str();
-    }
-private:
-    DNSNameSet qname_idx;
-};
-
-class QTypeRule : public DNSRule
-{
-public:
-  QTypeRule(uint16_t qtype) : d_qtype(qtype)
-  {
-  }
-  bool matches(const DNSQuestion* dq) const override
-  {
-    return d_qtype == dq->ids.qtype;
-  }
-  string toString() const override
-  {
-    QType qt(d_qtype);
-    return "qtype=="+qt.toString();
-  }
-private:
-  uint16_t d_qtype;
-};
-
-class QClassRule : public DNSRule
-{
-public:
-  QClassRule(uint16_t qclass) : d_qclass(qclass)
-  {
-  }
-  bool matches(const DNSQuestion* dq) const override
-  {
-    return d_qclass == dq->ids.qclass;
-  }
-  string toString() const override
-  {
-    return "qclass=="+std::to_string(d_qclass);
-  }
-private:
-  uint16_t d_qclass;
-};
-
-class OpcodeRule : public DNSRule
-{
-public:
-  OpcodeRule(uint8_t opcode) : d_opcode(opcode)
-  {
-  }
-  bool matches(const DNSQuestion* dq) const override
-  {
-    return d_opcode == dq->getHeader()->opcode;
-  }
-  string toString() const override
-  {
-    return "opcode=="+std::to_string(d_opcode);
-  }
-private:
-  uint8_t d_opcode;
-};
-
-class DSTPortRule : public DNSRule
-{
-public:
-  DSTPortRule(uint16_t port) : d_port(port)
-  {
-  }
-  bool matches(const DNSQuestion* dq) const override
-  {
-    return htons(d_port) == dq->ids.origDest.sin4.sin_port;
-  }
-  string toString() const override
-  {
-    return "dst port=="+std::to_string(d_port);
-  }
-private:
-  uint16_t d_port;
-};
-
-class TCPRule : public DNSRule
-{
-public:
-  TCPRule(bool tcp): d_tcp(tcp)
-  {
-  }
-  bool matches(const DNSQuestion* dq) const override
-  {
-    return dq->overTCP() == d_tcp;
-  }
-  string toString() const override
-  {
-    return (d_tcp ? "TCP" : "UDP");
-  }
-private:
-  bool d_tcp;
-};
-
-
-class NotRule : public DNSRule
-{
-public:
-  NotRule(const std::shared_ptr<DNSRule>& rule): d_rule(rule)
-  {
-  }
-  bool matches(const DNSQuestion* dq) const override
-  {
-    return !d_rule->matches(dq);
-  }
-  string toString() const override
-  {
-    return "!("+ d_rule->toString()+")";
-  }
-private:
-  std::shared_ptr<DNSRule> d_rule;
-};
-
-class RecordsCountRule : public DNSRule
-{
-public:
-  RecordsCountRule(uint8_t section, uint16_t minCount, uint16_t maxCount): d_minCount(minCount), d_maxCount(maxCount), d_section(section)
-  {
-  }
-  bool matches(const DNSQuestion* dq) const override
-  {
-    uint16_t count = 0;
-    switch(d_section) {
-    case 0:
-      count = ntohs(dq->getHeader()->qdcount);
-      break;
-    case 1:
-      count = ntohs(dq->getHeader()->ancount);
-      break;
-    case 2:
-      count = ntohs(dq->getHeader()->nscount);
-      break;
-    case 3:
-      count = ntohs(dq->getHeader()->arcount);
-      break;
-    }
-    return count >= d_minCount && count <= d_maxCount;
-  }
-  string toString() const override
-  {
-    string section;
-    switch(d_section) {
-    case 0:
-      section = "QD";
-      break;
-    case 1:
-      section = "AN";
-      break;
-    case 2:
-      section = "NS";
-      break;
-    case 3:
-      section = "AR";
-      break;
-    }
-    return std::to_string(d_minCount) + " <= records in " + section + " <= "+ std::to_string(d_maxCount);
-  }
-private:
-  uint16_t d_minCount;
-  uint16_t d_maxCount;
-  uint8_t d_section;
-};
-
-class RecordsTypeCountRule : public DNSRule
-{
-public:
-  RecordsTypeCountRule(uint8_t section, uint16_t type, uint16_t minCount, uint16_t maxCount): d_type(type), d_minCount(minCount), d_maxCount(maxCount), d_section(section)
-  {
-  }
-  bool matches(const DNSQuestion* dq) const override
-  {
-    uint16_t count = 0;
-    switch(d_section) {
-    case 0:
-      count = ntohs(dq->getHeader()->qdcount);
-      break;
-    case 1:
-      count = ntohs(dq->getHeader()->ancount);
-      break;
-    case 2:
-      count = ntohs(dq->getHeader()->nscount);
-      break;
-    case 3:
-      count = ntohs(dq->getHeader()->arcount);
-      break;
-    }
-    if (count < d_minCount) {
-      return false;
-    }
-    count = getRecordsOfTypeCount(reinterpret_cast<const char*>(dq->getData().data()), dq->getData().size(), d_section, d_type);
-    return count >= d_minCount && count <= d_maxCount;
-  }
-  string toString() const override
-  {
-    string section;
-    switch(d_section) {
-    case 0:
-      section = "QD";
-      break;
-    case 1:
-      section = "AN";
-      break;
-    case 2:
-      section = "NS";
-      break;
-    case 3:
-      section = "AR";
-      break;
-    }
-    return std::to_string(d_minCount) + " <= " + QType(d_type).toString() + " records in " + section + " <= "+ std::to_string(d_maxCount);
-  }
-private:
-  uint16_t d_type;
-  uint16_t d_minCount;
-  uint16_t d_maxCount;
-  uint8_t d_section;
-};
-
-class TrailingDataRule : public DNSRule
-{
-public:
-  TrailingDataRule()
-  {
-  }
-  bool matches(const DNSQuestion* dq) const override
-  {
-    uint16_t length = getDNSPacketLength(reinterpret_cast<const char*>(dq->getData().data()), dq->getData().size());
-    return length < dq->getData().size();
-  }
-  string toString() const override
-  {
-    return "trailing data";
-  }
-};
-
-class QNameLabelsCountRule : public DNSRule
-{
-public:
-  QNameLabelsCountRule(unsigned int minLabelsCount, unsigned int maxLabelsCount): d_min(minLabelsCount), d_max(maxLabelsCount)
-  {
-  }
-  bool matches(const DNSQuestion* dq) const override
-  {
-    unsigned int count = dq->ids.qname.countLabels();
-    return count < d_min || count > d_max;
-  }
-  string toString() const override
-  {
-    return "labels count < " + std::to_string(d_min) + " || labels count > " + std::to_string(d_max);
-  }
-private:
-  unsigned int d_min;
-  unsigned int d_max;
-};
-
-class QNameWireLengthRule : public DNSRule
-{
-public:
-  QNameWireLengthRule(size_t min, size_t max): d_min(min), d_max(max)
-  {
-  }
-  bool matches(const DNSQuestion* dq) const override
-  {
-    size_t const wirelength = dq->ids.qname.wirelength();
-    return wirelength < d_min || wirelength > d_max;
-  }
-  string toString() const override
-  {
-    return "wire length < " + std::to_string(d_min) + " || wire length > " + std::to_string(d_max);
-  }
-private:
-  size_t d_min;
-  size_t d_max;
-};
-
-class RCodeRule : public DNSRule
-{
-public:
-  RCodeRule(uint8_t rcode) : d_rcode(rcode)
-  {
-  }
-  bool matches(const DNSQuestion* dq) const override
-  {
-    return d_rcode == dq->getHeader()->rcode;
-  }
-  string toString() const override
-  {
-    return "rcode=="+RCode::to_s(d_rcode);
-  }
-private:
-  uint8_t d_rcode;
-};
-
-class ERCodeRule : public DNSRule
-{
-public:
-  ERCodeRule(uint8_t rcode) : d_rcode(rcode & 0xF), d_extrcode(rcode >> 4)
-  {
-  }
-  bool matches(const DNSQuestion* dq) const override
-  {
-    // avoid parsing EDNS OPT RR when not needed.
-    if (d_rcode != dq->getHeader()->rcode) {
-      return false;
-    }
-
-    EDNS0Record edns0;
-    if (!getEDNS0Record(dq->getData(), edns0)) {
-      return false;
-    }
-
-    return d_extrcode == edns0.extRCode;
-  }
-  string toString() const override
-  {
-    return "ercode=="+ERCode::to_s(d_rcode | (d_extrcode << 4));
-  }
-private:
-  uint8_t d_rcode;     // plain DNS Rcode
-  uint8_t d_extrcode;  // upper bits in EDNS0 record
-};
-
-class EDNSVersionRule : public DNSRule
-{
-public:
-  EDNSVersionRule(uint8_t version) : d_version(version)
-  {
-  }
-  bool matches(const DNSQuestion* dq) const override
-  {
-    EDNS0Record edns0;
-    if (!getEDNS0Record(dq->getData(), edns0)) {
-      return false;
-    }
-
-    return d_version < edns0.version;
-  }
-  string toString() const override
-  {
-    return "ednsversion>"+std::to_string(d_version);
-  }
-private:
-  uint8_t d_version;
-};
-
-class EDNSOptionRule : public DNSRule
-{
-public:
-  EDNSOptionRule(uint16_t optcode) : d_optcode(optcode)
-  {
-  }
-  bool matches(const DNSQuestion* dq) const override
-  {
-    uint16_t optStart;
-    size_t optLen = 0;
-    bool last = false;
-    int res = locateEDNSOptRR(dq->getData(), &optStart, &optLen, &last);
-    if (res != 0) {
-      // no EDNS OPT RR
-      return false;
-    }
-
-    if (optLen < optRecordMinimumSize) {
-      return false;
-    }
-
-    if (optStart < dq->getData().size() && dq->getData().at(optStart) != 0) {
-      // OPT RR Name != '.'
-      return false;
-    }
-
-    return isEDNSOptionInOpt(dq->getData(), optStart, optLen, d_optcode);
-  }
-  string toString() const override
-  {
-    return "ednsoptcode=="+std::to_string(d_optcode);
-  }
-private:
-  uint16_t d_optcode;
-};
-
-class RDRule : public DNSRule
-{
-public:
-  RDRule()
-  {
-  }
-  bool matches(const DNSQuestion* dq) const override
-  {
-    return dq->getHeader()->rd == 1;
-  }
-  string toString() const override
-  {
-    return "rd==1";
-  }
-};
-
-class ProbaRule : public DNSRule
-{
-public:
-  ProbaRule(double proba) : d_proba(proba)
-  {
-  }
-  bool matches(const DNSQuestion* dq) const override
-  {
-    if(d_proba == 1.0)
-      return true;
-    double rnd = 1.0*dns_random_uint32() / UINT32_MAX;
-    return rnd > (1.0 - d_proba);
-  }
-  string toString() const override
-  {
-    return "match with prob. " + (boost::format("%0.2f") % d_proba).str();
-  }
-private:
-  double d_proba;
-};
-
-class TagRule : public DNSRule
-{
-public:
-  TagRule(const std::string& tag, boost::optional<std::string> value) : d_value(std::move(value)), d_tag(tag)
-  {
-  }
-  bool matches(const DNSQuestion* dq) const override
-  {
-    if (!dq->ids.qTag) {
-      return false;
-    }
-
-    const auto it = dq->ids.qTag->find(d_tag);
-    if (it == dq->ids.qTag->cend()) {
-      return false;
-    }
-
-    if (!d_value) {
-      return true;
-    }
-
-    return it->second == *d_value;
-  }
-
-  string toString() const override
-  {
-    return "tag '" + d_tag + "' is set" + (d_value ? (" to '" + *d_value + "'") : "");
-  }
-
-private:
-  boost::optional<std::string> d_value;
-  std::string d_tag;
-};
-
-class PoolAvailableRule : public DNSRule
-{
-public:
-  PoolAvailableRule(const std::string& poolname) : d_poolname(poolname)
-  {
-  }
-
-  bool matches(const DNSQuestion* dq) const override
-  {
-    return (getPool(d_poolname)->countServers(true) > 0);
-  }
-
-  string toString() const override
-  {
-    return "pool '" + d_poolname + "' is available";
-  }
-private:
-  std::string d_poolname;
-};
-
-class PoolOutstandingRule : public DNSRule
-{
-public:
-  PoolOutstandingRule(const std::string& poolname, const size_t limit) : d_poolname(poolname), d_limit(limit)
-  {
-  }
-
-  bool matches(const DNSQuestion* dq) const override
-  {
-    return (getPool(d_poolname)->poolLoad()) > d_limit;
-  }
-
-  string toString() const override
-  {
-    return "pool '" + d_poolname + "' outstanding > " + std::to_string(d_limit);
-  }
-private:
-  std::string d_poolname;
-  size_t d_limit;
-};
-
-class KeyValueStoreLookupRule: public DNSRule
-{
-public:
-  KeyValueStoreLookupRule(std::shared_ptr<KeyValueStore>& kvs, std::shared_ptr<KeyValueLookupKey>& lookupKey): d_kvs(kvs), d_key(lookupKey)
-  {
-  }
-
-  bool matches(const DNSQuestion* dq) const override
-  {
-    std::vector<std::string> keys = d_key->getKeys(*dq);
-    for (const auto& key : keys) {
-      if (d_kvs->keyExists(key) == true) {
-        return true;
-      }
-    }
-
-    return false;
-  }
-
-  string toString() const override
-  {
-    return "lookup key-value store based on '" + d_key->toString() + "'";
-  }
-
-private:
-  std::shared_ptr<KeyValueStore> d_kvs;
-  std::shared_ptr<KeyValueLookupKey> d_key;
-};
-
-class KeyValueStoreRangeLookupRule: public DNSRule
-{
-public:
-  KeyValueStoreRangeLookupRule(std::shared_ptr<KeyValueStore>& kvs, std::shared_ptr<KeyValueLookupKey>& lookupKey): d_kvs(kvs), d_key(lookupKey)
-  {
-  }
-
-  bool matches(const DNSQuestion* dq) const override
-  {
-    std::vector<std::string> keys = d_key->getKeys(*dq);
-    for (const auto& key : keys) {
-      std::string value;
-      if (d_kvs->getRangeValue(key, value) == true) {
-        return true;
-      }
-    }
-
-    return false;
-  }
-
-  string toString() const override
-  {
-    return "range-based lookup key-value store based on '" + d_key->toString() + "'";
-  }
-
-private:
-  std::shared_ptr<KeyValueStore> d_kvs;
-  std::shared_ptr<KeyValueLookupKey> d_key;
-};
-
-class LuaRule : public DNSRule
-{
-public:
-  typedef std::function<bool(const DNSQuestion* dq)> func_t;
-  LuaRule(const func_t& func): d_func(func)
-  {}
-
-  bool matches(const DNSQuestion* dq) const override
-  {
-    try {
-      auto lock = g_lua.lock();
-      return d_func(dq);
-    } catch (const std::exception &e) {
-      warnlog("LuaRule failed inside Lua: %s", e.what());
-    } catch (...) {
-      warnlog("LuaRule failed inside Lua: [unknown exception]");
-    }
-    return false;
-  }
-
-  string toString() const override
-  {
-    return "Lua script";
-  }
-private:
-  func_t d_func;
-};
-
-class LuaFFIRule : public DNSRule
-{
-public:
-  typedef std::function<bool(dnsdist_ffi_dnsquestion_t* dq)> func_t;
-  LuaFFIRule(const func_t& func): d_func(func)
-  {}
-
-  bool matches(const DNSQuestion* dq) const override
-  {
-    dnsdist_ffi_dnsquestion_t dqffi(const_cast<DNSQuestion*>(dq));
-    try {
-      auto lock = g_lua.lock();
-      return d_func(&dqffi);
-    } catch (const std::exception &e) {
-      warnlog("LuaFFIRule failed inside Lua: %s", e.what());
-    } catch (...) {
-      warnlog("LuaFFIRule failed inside Lua: [unknown exception]");
-    }
-    return false;
-  }
-
-  string toString() const override
-  {
-    return "Lua FFI script";
-  }
-private:
-  func_t d_func;
-};
-
-class LuaFFIPerThreadRule : public DNSRule
-{
-public:
-  typedef std::function<bool(dnsdist_ffi_dnsquestion_t* dq)> func_t;
-
-  LuaFFIPerThreadRule(const std::string& code): d_functionCode(code), d_functionID(s_functionsCounter++)
-  {
-  }
-
-  bool matches(const DNSQuestion* dq) const override
-  {
-    try {
-      auto& state = t_perThreadStates[d_functionID];
-      if (!state.d_initialized) {
-        setupLuaFFIPerThreadContext(state.d_luaContext);
-        /* mark the state as initialized first so if there is a syntax error
-           we only try to execute the code once */
-        state.d_initialized = true;
-        state.d_func = state.d_luaContext.executeCode<func_t>(d_functionCode);
-      }
-
-      if (!state.d_func) {
-        /* the function was not properly initialized */
-        return false;
-      }
-
-      dnsdist_ffi_dnsquestion_t dqffi(const_cast<DNSQuestion*>(dq));
-      return state.d_func(&dqffi);
-    }
-    catch (const std::exception &e) {
-      warnlog("LuaFFIPerthreadRule failed inside Lua: %s", e.what());
-    }
-    catch (...) {
-      warnlog("LuaFFIPerThreadRule failed inside Lua: [unknown exception]");
-    }
-    return false;
-  }
-
-  string toString() const override
-  {
-    return "Lua FFI per-thread script";
-  }
-private:
-  struct PerThreadState
-  {
-    LuaContext d_luaContext;
-    func_t d_func;
-    bool d_initialized{false};
-  };
-
-  static std::atomic<uint64_t> s_functionsCounter;
-  static thread_local std::map<uint64_t, PerThreadState> t_perThreadStates;
-  const std::string d_functionCode;
-  const uint64_t d_functionID;
-};
-
-class ProxyProtocolValueRule : public DNSRule
-{
-public:
-  ProxyProtocolValueRule(uint8_t type, boost::optional<std::string> value): d_value(std::move(value)), d_type(type)
-  {
-  }
-
-  bool matches(const DNSQuestion* dq) const override
-  {
-    if (!dq->proxyProtocolValues) {
-      return false;
-    }
-
-    for (const auto& entry : *dq->proxyProtocolValues) {
-      if (entry.type == d_type && (!d_value || entry.content == *d_value)) {
-        return true;
-      }
-    }
-
-    return false;
-  }
-
-  string toString() const override
-  {
-    if (d_value) {
-      return "proxy protocol value of type " + std::to_string(d_type) + " matches";
-    }
-    return "proxy protocol value of type " + std::to_string(d_type) + " is present";
-  }
-
-private:
-  boost::optional<std::string> d_value;
-  uint8_t d_type;
-};
-
-class PayloadSizeRule : public DNSRule
-{
-  enum class Comparisons : uint8_t { equal, greater, greaterOrEqual, smaller, smallerOrEqual };
-public:
-  PayloadSizeRule(const std::string& comparison, uint16_t size): d_size(size)
-  {
-    if (comparison == "equal") {
-      d_comparison = Comparisons::equal;
-    }
-    else if (comparison == "greater") {
-      d_comparison = Comparisons::greater;
-    }
-    else if (comparison == "greaterOrEqual") {
-      d_comparison = Comparisons::greaterOrEqual;
-    }
-    else if (comparison == "smaller") {
-      d_comparison = Comparisons::smaller;
-    }
-    else if (comparison == "smallerOrEqual") {
-      d_comparison = Comparisons::smallerOrEqual;
-    }
-    else {
-      throw std::runtime_error("Unsupported comparison '" + comparison + "'");
-    }
-  }
 
-  bool matches(const DNSQuestion* dq) const override
-  {
-    const auto size = dq->getData().size();
-
-    switch (d_comparison) {
-    case Comparisons::equal:
-      return size == d_size;
-    case Comparisons::greater:
-      return size > d_size;
-    case Comparisons::greaterOrEqual:
-      return size >= d_size;
-    case Comparisons::smaller:
-      return size < d_size;
-    case Comparisons::smallerOrEqual:
-      return size <= d_size;
-    default:
-      return false;
-    }
-  }
-
-  string toString() const override
-  {
-    static const std::array<const std::string, 5> comparisonStr{
-      "equal to" ,
-      "greater than",
-      "equal to or greater than",
-      "smaller than",
-      "equal to or smaller than"
-    };
-    return "payload size is " + comparisonStr.at(static_cast<size_t>(d_comparison)) + " " + std::to_string(d_size);
-  }
-
-private:
-  uint16_t d_size;
-  Comparisons d_comparison;
+  mutable stat_t d_matches{0};
 };
index 7457fb3bb16f1005d60d72574537803f586855dc..8c91e3699e187ddbe9276c1b9390b0f2f6168253 100644 (file)
@@ -8,7 +8,7 @@
 #include <thread>
 #include <boost/test/unit_test.hpp>
 
-#include "dnsdist-rules.hh"
+#include "dnsdist-rules-factory.hh"
 
 void checkParameterBound(const std::string& parameter, uint64_t value, size_t max)
 {