From: Remi Gacogne Date: Tue, 24 Dec 2024 14:10:46 +0000 (+0100) Subject: dnsdist: Move rules to dnsdist-rules-factory.hh X-Git-Tag: dnsdist-2.0.0-alpha1~160^2~36 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=4458146872f44d7dd343cb1bd0ddc9bd336ab799;p=thirdparty%2Fpdns.git dnsdist: Move rules to dnsdist-rules-factory.hh --- diff --git a/pdns/dnsdistdist/Makefile.am b/pdns/dnsdistdist/Makefile.am index 8a881c8b43..6c17d416e1 100644 --- a/pdns/dnsdistdist/Makefile.am +++ b/pdns/dnsdistdist/Makefile.am @@ -203,6 +203,7 @@ dnsdist_SOURCES = \ dnsdist-resolver.cc dnsdist-resolver.hh \ dnsdist-rings.cc dnsdist-rings.hh \ dnsdist-rule-chains.cc dnsdist-rule-chains.hh \ + dnsdist-rules-factory.hh \ dnsdist-rules.cc dnsdist-rules.hh \ dnsdist-secpoll.cc dnsdist-secpoll.hh \ dnsdist-self-answers.cc dnsdist-self-answers.hh \ @@ -312,6 +313,7 @@ testrunner_SOURCES = \ dnsdist-resolver.cc dnsdist-resolver.hh \ dnsdist-rings.cc dnsdist-rings.hh \ dnsdist-rule-chains.cc dnsdist-rule-chains.hh \ + dnsdist-rules-factory.hh \ dnsdist-rules.cc dnsdist-rules.hh \ dnsdist-self-answers.cc dnsdist-self-answers.hh \ dnsdist-session-cache.cc dnsdist-session-cache.hh \ diff --git a/pdns/dnsdistdist/dnsdist-lua-rules.cc b/pdns/dnsdistdist/dnsdist-lua-rules.cc index 10937c03fc..dd86ec473c 100644 --- a/pdns/dnsdistdist/dnsdist-lua-rules.cc +++ b/pdns/dnsdistdist/dnsdist-lua-rules.cc @@ -21,7 +21,7 @@ */ #include "dnsdist.hh" #include "dnsdist-lua.hh" -#include "dnsdist-rules.hh" +#include "dnsdist-rules-factory.hh" #include "dnsdist-rule-chains.hh" #include "dns_random.hh" diff --git a/pdns/dnsdistdist/dnsdist-rules-factory.hh b/pdns/dnsdistdist/dnsdist-rules-factory.hh new file mode 100644 index 0000000000..470cbda7ad --- /dev/null +++ b/pdns/dnsdistdist/dnsdist-rules-factory.hh @@ -0,0 +1,1415 @@ +/* + * This file is part of PowerDNS or dnsdist. + * Copyright -- PowerDNS.COM B.V. and its contributors + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of version 2 of the GNU General Public License as + * published by the Free Software Foundation. + * + * In addition, for the avoidance of any doubt, permission is granted to + * link this program with OpenSSL and to (re)distribute the binaries + * produced as the result of such linking. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ +#pragma once + +#include "dnsdist-rules.hh" + +#include "cachecleaner.hh" +#include "dnsdist-ecs.hh" +#include "dnsdist-kvs.hh" +#include "dnsdist-lua.hh" +#include "dnsdist-lua-ffi.hh" +#include "dolog.hh" +#include "dnsparser.hh" +#include "dns_random.hh" + +#include +#include +#include +#include + +class MaxQPSIPRule : public DNSRule +{ +public: + MaxQPSIPRule(unsigned int qps, unsigned int burst, unsigned int ipv4trunc=32, unsigned int ipv6trunc=64, unsigned int expiration=300, unsigned int cleanupDelay=60, unsigned int scanFraction=10, size_t shardsCount=10): + d_shards(shardsCount), d_qps(qps), d_burst(burst), d_ipv4trunc(ipv4trunc), d_ipv6trunc(ipv6trunc), d_cleanupDelay(cleanupDelay), d_expiration(expiration), d_scanFraction(scanFraction) + { + d_cleaningUp.clear(); + gettime(&d_lastCleanup, true); + } + + void clear() + { + for (auto& shard : d_shards) { + shard.lock()->clear(); + } + } + + size_t cleanup(const struct timespec& cutOff, size_t* scannedCount=nullptr) const + { + size_t removed = 0; + if (scannedCount != nullptr) { + *scannedCount = 0; + } + + for (auto& shard : d_shards) { + auto limits = shard.lock(); + const size_t toLook = std::round((1.0 * limits->size()) / d_scanFraction)+ 1; + size_t lookedAt = 0; + + auto& sequence = limits->get(); + for (auto entry = sequence.begin(); entry != sequence.end() && lookedAt < toLook; lookedAt++) { + if (entry->d_limiter.seenSince(cutOff)) { + /* entries are ordered from least recently seen to more recently + seen, as soon as we see one that has not expired yet, we are + done */ + lookedAt++; + break; + } + + entry = sequence.erase(entry); + removed++; + } + + if (scannedCount != nullptr) { + *scannedCount += lookedAt; + } + } + + return removed; + } + + void cleanupIfNeeded(const struct timespec& now) const + { + if (d_cleanupDelay > 0) { + struct timespec cutOff = d_lastCleanup; + cutOff.tv_sec += d_cleanupDelay; + + if (cutOff < now) { + try { + if (d_cleaningUp.test_and_set()) { + return; + } + + d_lastCleanup = now; + /* the QPS Limiter doesn't use realtime, be careful! */ + gettime(&cutOff, false); + cutOff.tv_sec -= d_expiration; + + cleanup(cutOff); + d_cleaningUp.clear(); + } + catch (...) { + d_cleaningUp.clear(); + throw; + } + } + } + } + + bool matches(const DNSQuestion* dq) const override + { + cleanupIfNeeded(dq->getQueryRealTime()); + + ComboAddress zeroport(dq->ids.origRemote); + zeroport.sin4.sin_port=0; + zeroport.truncate(zeroport.sin4.sin_family == AF_INET ? d_ipv4trunc : d_ipv6trunc); + auto hash = ComboAddress::addressOnlyHash()(zeroport); + auto& shard = d_shards[hash % d_shards.size()]; + { + auto limits = shard.lock(); + auto iter = limits->find(zeroport); + if (iter == limits->end()) { + Entry e(zeroport, QPSLimiter(d_qps, d_burst)); + iter = limits->insert(e).first; + } + + moveCacheItemToBack(*limits, iter); + return !iter->d_limiter.check(d_qps, d_burst); + } + } + + string toString() const override + { + return "IP (/"+std::to_string(d_ipv4trunc)+", /"+std::to_string(d_ipv6trunc)+") match for QPS over " + std::to_string(d_qps) + " burst "+ std::to_string(d_burst); + } + + size_t getEntriesCount() const + { + size_t count = 0; + for (auto& shard : d_shards) { + count += shard.lock()->size(); + } + return count; + } + + size_t getNumberOfShards() const + { + return d_shards.size(); + } + +private: + struct HashedTag {}; + struct SequencedTag {}; + struct Entry + { + Entry(const ComboAddress& addr, BasicQPSLimiter&& limiter): d_limiter(limiter), d_addr(addr) + { + } + mutable BasicQPSLimiter d_limiter; + ComboAddress d_addr; + }; + + typedef multi_index_container< + Entry, + indexed_by < + hashed_unique, member, ComboAddress::addressOnlyHash >, + sequenced > + > + > qpsContainer_t; + + mutable std::vector> d_shards; + mutable struct timespec d_lastCleanup; + const unsigned int d_qps, d_burst, d_ipv4trunc, d_ipv6trunc, d_cleanupDelay, d_expiration; + const unsigned int d_scanFraction{10}; + mutable std::atomic_flag d_cleaningUp; +}; + +class MaxQPSRule : public DNSRule +{ +public: + MaxQPSRule(unsigned int qps) + : d_qps(qps, qps) + {} + + MaxQPSRule(unsigned int qps, unsigned int burst) + : d_qps(qps, burst) + {} + + + bool matches(const DNSQuestion* qd) const override + { + return d_qps.check(); + } + + string toString() const override + { + return "Max " + std::to_string(d_qps.getRate()) + " qps"; + } + + +private: + mutable QPSLimiter d_qps; +}; + +class NMGRule : public DNSRule +{ +public: + NMGRule(const NetmaskGroup& nmg) : d_nmg(nmg) {} +protected: + NetmaskGroup d_nmg; +}; + +class NetmaskGroupRule : public NMGRule +{ +public: + NetmaskGroupRule(const NetmaskGroup& nmg, bool src, bool quiet = false) : NMGRule(nmg) + { + d_src = src; + d_quiet = quiet; + } + bool matches(const DNSQuestion* dq) const override + { + if(!d_src) { + return d_nmg.match(dq->ids.origDest); + } + return d_nmg.match(dq->ids.origRemote); + } + + string toString() const override + { + string ret = "Src: "; + if(!d_src) { + ret = "Dst: "; + } + if (d_quiet) { + return ret + "in-group"; + } + return ret + d_nmg.toString(); + } +private: + bool d_src; + bool d_quiet; +}; + +class TimedIPSetRule : public DNSRule, boost::noncopyable +{ +private: + struct IPv6 { + IPv6(const ComboAddress& ca) + { + static_assert(sizeof(*this)==16, "IPv6 struct has wrong size"); + memcpy((char*)this, ca.sin6.sin6_addr.s6_addr, 16); + } + bool operator==(const IPv6& rhs) const + { + return a==rhs.a && b==rhs.b; + } + uint64_t a, b; + }; + +public: + TimedIPSetRule() + { + } + ~TimedIPSetRule() + { + } + bool matches(const DNSQuestion* dq) const override + { + if (dq->ids.origRemote.sin4.sin_family == AF_INET) { + auto ip4s = d_ip4s.read_lock(); + auto fnd = ip4s->find(dq->ids.origRemote.sin4.sin_addr.s_addr); + if (fnd == ip4s->end()) { + return false; + } + return time(nullptr) < fnd->second; + } else { + auto ip6s = d_ip6s.read_lock(); + auto fnd = ip6s->find({dq->ids.origRemote}); + if (fnd == ip6s->end()) { + return false; + } + return time(nullptr) < fnd->second; + } + } + + void add(const ComboAddress& ca, time_t ttd) + { + // think twice before adding templates here + if (ca.sin4.sin_family == AF_INET) { + auto res = d_ip4s.write_lock()->insert({ca.sin4.sin_addr.s_addr, ttd}); + if (!res.second && (time_t)res.first->second < ttd) { + res.first->second = (uint32_t)ttd; + } + } + else { + auto res = d_ip6s.write_lock()->insert({{ca}, ttd}); + if (!res.second && (time_t)res.first->second < ttd) { + // coverity[store_truncates_time_t] + res.first->second = (uint32_t)ttd; + } + } + } + + void remove(const ComboAddress& ca) + { + if (ca.sin4.sin_family == AF_INET) { + d_ip4s.write_lock()->erase(ca.sin4.sin_addr.s_addr); + } + else { + d_ip6s.write_lock()->erase({ca}); + } + } + + void clear() + { + d_ip4s.write_lock()->clear(); + d_ip6s.write_lock()->clear(); + } + + void cleanup() + { + time_t now = time(nullptr); + { + auto ip4s = d_ip4s.write_lock(); + for (auto iter = ip4s->begin(); iter != ip4s->end(); ) { + if (iter->second < now) { + iter = ip4s->erase(iter); + } + else { + ++iter; + } + } + } + + { + auto ip6s = d_ip6s.write_lock(); + for (auto iter = ip6s->begin(); iter != ip6s->end(); ) { + if (iter->second < now) { + iter = ip6s->erase(iter); + } + else { + ++iter; + } + } + + } + + } + + string toString() const override + { + time_t now = time(nullptr); + uint64_t count = 0; + + for (const auto& ip : *(d_ip4s.read_lock())) { + if (now < ip.second) { + ++count; + } + } + + for (const auto& ip : *(d_ip6s.read_lock())) { + if (now < ip.second) { + ++count; + } + } + + return "Src: "+std::to_string(count)+" ips"; + } +private: + struct IPv6Hash + { + std::size_t operator()(const IPv6& ip) const + { + auto ah=std::hash{}(ip.a); + auto bh=std::hash{}(ip.b); + return ah & (bh<<1); + } + }; + mutable SharedLockGuarded> d_ip6s; + mutable SharedLockGuarded> d_ip4s; +}; + + +class AllRule : public DNSRule +{ +public: + AllRule() {} + bool matches(const DNSQuestion* dq) const override + { + return true; + } + + string toString() const override + { + return "All"; + } + +}; + + +class DNSSECRule : public DNSRule +{ +public: + DNSSECRule() + { + + } + bool matches(const DNSQuestion* dq) const override + { + return dq->getHeader()->cd || (dnsdist::getEDNSZ(*dq) & EDNS_HEADER_FLAG_DO); // turns out dig sets ad by default.. + } + + string toString() const override + { + return "DNSSEC"; + } +}; + +class AndRule : public DNSRule +{ +public: + AndRule(const std::vector > >& rules) + { + for (const auto& r : rules) { + d_rules.push_back(r.second); + } + } + + bool matches(const DNSQuestion* dq) const override + { + for (const auto& rule : d_rules) { + if (!rule->matches(dq)) { + return false; + } + } + return true; + } + + string toString() const override + { + string ret; + for (const auto& rule : d_rules) { + if (!ret.empty()) { + ret+= " && "; + } + ret += "("+ rule->toString()+")"; + } + return ret; + } +private: + std::vector > d_rules; +}; + + +class OrRule : public DNSRule +{ +public: + OrRule(const std::vector > >& rules) + { + for (const auto& r : rules) { + d_rules.push_back(r.second); + } + } + + bool matches(const DNSQuestion* dq) const override + { + for (const auto& rule: d_rules) { + if (rule->matches(dq)) { + return true; + } + } + return false; + } + + string toString() const override + { + string ret; + for (const auto& rule : d_rules) { + if (!ret.empty()) { + ret+= " || "; + } + ret += "("+ rule->toString()+")"; + } + return ret; + } +private: + std::vector > d_rules; +}; + + +class RegexRule : public DNSRule +{ +public: + RegexRule(const std::string& regex) : d_regex(regex), d_visual(regex) + { + + } + bool matches(const DNSQuestion* dq) const override + { + return d_regex.match(dq->ids.qname.toStringNoDot()); + } + + string toString() const override + { + return "Regex: "+d_visual; + } +private: + Regex d_regex; + string d_visual; +}; + +#ifdef HAVE_RE2 +#include +class RE2Rule : public DNSRule +{ +public: + RE2Rule(const std::string& re2) : d_re2(re2, RE2::Latin1), d_visual(re2) + { + + } + bool matches(const DNSQuestion* dq) const override + { + return RE2::FullMatch(dq->ids.qname.toStringNoDot(), d_re2); + } + + string toString() const override + { + return "RE2 match: "+d_visual; + } +private: + RE2 d_re2; + string d_visual; +}; +#endif + +#ifdef HAVE_DNS_OVER_HTTPS +class HTTPHeaderRule : public DNSRule +{ +public: + HTTPHeaderRule(const std::string& header, const std::string& regex); + bool matches(const DNSQuestion* dq) const override; + string toString() const override; +private: + string d_header; + Regex d_regex; + string d_visual; +}; + +class HTTPPathRule : public DNSRule +{ +public: + HTTPPathRule(std::string path); + bool matches(const DNSQuestion* dq) const override; + string toString() const override; +private: + string d_path; +}; + +class HTTPPathRegexRule : public DNSRule +{ +public: + HTTPPathRegexRule(const std::string& regex); + bool matches(const DNSQuestion* dq) const override; + string toString() const override; +private: + Regex d_regex; + std::string d_visual; +}; +#endif + +class SNIRule : public DNSRule +{ +public: + SNIRule(const std::string& name) : d_sni(name) + { + } + bool matches(const DNSQuestion* dq) const override + { + return dq->sni == d_sni; + } + string toString() const override + { + return "SNI == " + d_sni; + } +private: + std::string d_sni; +}; + +class SuffixMatchNodeRule : public DNSRule +{ +public: + SuffixMatchNodeRule(const SuffixMatchNode& smn, bool quiet=false) : d_smn(smn), d_quiet(quiet) + { + } + bool matches(const DNSQuestion* dq) const override + { + return d_smn.check(dq->ids.qname); + } + string toString() const override + { + if(d_quiet) + return "qname==in-set"; + else + return "qname in "+d_smn.toString(); + } +private: + SuffixMatchNode d_smn; + bool d_quiet; +}; + +class QNameRule : public DNSRule +{ +public: + QNameRule(const DNSName& qname) : d_qname(qname) + { + } + + bool matches(const DNSQuestion* dq) const override + { + return d_qname==dq->ids.qname; + } + string toString() const override + { + return "qname=="+d_qname.toString(); + } +private: + DNSName d_qname; +}; + +class QNameSetRule : public DNSRule { +public: + QNameSetRule(const DNSNameSet& names) : qname_idx(names) {} + + bool matches(const DNSQuestion* dq) const override { + return qname_idx.find(dq->ids.qname) != qname_idx.end(); + } + + string toString() const override { + std::stringstream ss; + ss << "qname in DNSNameSet(" << qname_idx.size() << " FQDNs)"; + return ss.str(); + } +private: + DNSNameSet qname_idx; +}; + +class QTypeRule : public DNSRule +{ +public: + QTypeRule(uint16_t qtype) : d_qtype(qtype) + { + } + bool matches(const DNSQuestion* dq) const override + { + return d_qtype == dq->ids.qtype; + } + string toString() const override + { + QType qt(d_qtype); + return "qtype=="+qt.toString(); + } +private: + uint16_t d_qtype; +}; + +class QClassRule : public DNSRule +{ +public: + QClassRule(uint16_t qclass) : d_qclass(qclass) + { + } + bool matches(const DNSQuestion* dq) const override + { + return d_qclass == dq->ids.qclass; + } + string toString() const override + { + return "qclass=="+std::to_string(d_qclass); + } +private: + uint16_t d_qclass; +}; + +class OpcodeRule : public DNSRule +{ +public: + OpcodeRule(uint8_t opcode) : d_opcode(opcode) + { + } + bool matches(const DNSQuestion* dq) const override + { + return d_opcode == dq->getHeader()->opcode; + } + string toString() const override + { + return "opcode=="+std::to_string(d_opcode); + } +private: + uint8_t d_opcode; +}; + +class DSTPortRule : public DNSRule +{ +public: + DSTPortRule(uint16_t port) : d_port(port) + { + } + bool matches(const DNSQuestion* dq) const override + { + return htons(d_port) == dq->ids.origDest.sin4.sin_port; + } + string toString() const override + { + return "dst port=="+std::to_string(d_port); + } +private: + uint16_t d_port; +}; + +class TCPRule : public DNSRule +{ +public: + TCPRule(bool tcp): d_tcp(tcp) + { + } + bool matches(const DNSQuestion* dq) const override + { + return dq->overTCP() == d_tcp; + } + string toString() const override + { + return (d_tcp ? "TCP" : "UDP"); + } +private: + bool d_tcp; +}; + + +class NotRule : public DNSRule +{ +public: + NotRule(const std::shared_ptr& rule): d_rule(rule) + { + } + bool matches(const DNSQuestion* dq) const override + { + return !d_rule->matches(dq); + } + string toString() const override + { + return "!("+ d_rule->toString()+")"; + } +private: + std::shared_ptr d_rule; +}; + +class RecordsCountRule : public DNSRule +{ +public: + RecordsCountRule(uint8_t section, uint16_t minCount, uint16_t maxCount): d_minCount(minCount), d_maxCount(maxCount), d_section(section) + { + } + bool matches(const DNSQuestion* dq) const override + { + uint16_t count = 0; + switch(d_section) { + case 0: + count = ntohs(dq->getHeader()->qdcount); + break; + case 1: + count = ntohs(dq->getHeader()->ancount); + break; + case 2: + count = ntohs(dq->getHeader()->nscount); + break; + case 3: + count = ntohs(dq->getHeader()->arcount); + break; + } + return count >= d_minCount && count <= d_maxCount; + } + string toString() const override + { + string section; + switch(d_section) { + case 0: + section = "QD"; + break; + case 1: + section = "AN"; + break; + case 2: + section = "NS"; + break; + case 3: + section = "AR"; + break; + } + return std::to_string(d_minCount) + " <= records in " + section + " <= "+ std::to_string(d_maxCount); + } +private: + uint16_t d_minCount; + uint16_t d_maxCount; + uint8_t d_section; +}; + +class RecordsTypeCountRule : public DNSRule +{ +public: + RecordsTypeCountRule(uint8_t section, uint16_t type, uint16_t minCount, uint16_t maxCount): d_type(type), d_minCount(minCount), d_maxCount(maxCount), d_section(section) + { + } + bool matches(const DNSQuestion* dq) const override + { + uint16_t count = 0; + switch(d_section) { + case 0: + count = ntohs(dq->getHeader()->qdcount); + break; + case 1: + count = ntohs(dq->getHeader()->ancount); + break; + case 2: + count = ntohs(dq->getHeader()->nscount); + break; + case 3: + count = ntohs(dq->getHeader()->arcount); + break; + } + if (count < d_minCount) { + return false; + } + count = getRecordsOfTypeCount(reinterpret_cast(dq->getData().data()), dq->getData().size(), d_section, d_type); + return count >= d_minCount && count <= d_maxCount; + } + string toString() const override + { + string section; + switch(d_section) { + case 0: + section = "QD"; + break; + case 1: + section = "AN"; + break; + case 2: + section = "NS"; + break; + case 3: + section = "AR"; + break; + } + return std::to_string(d_minCount) + " <= " + QType(d_type).toString() + " records in " + section + " <= "+ std::to_string(d_maxCount); + } +private: + uint16_t d_type; + uint16_t d_minCount; + uint16_t d_maxCount; + uint8_t d_section; +}; + +class TrailingDataRule : public DNSRule +{ +public: + TrailingDataRule() + { + } + bool matches(const DNSQuestion* dq) const override + { + uint16_t length = getDNSPacketLength(reinterpret_cast(dq->getData().data()), dq->getData().size()); + return length < dq->getData().size(); + } + string toString() const override + { + return "trailing data"; + } +}; + +class QNameLabelsCountRule : public DNSRule +{ +public: + QNameLabelsCountRule(unsigned int minLabelsCount, unsigned int maxLabelsCount): d_min(minLabelsCount), d_max(maxLabelsCount) + { + } + bool matches(const DNSQuestion* dq) const override + { + unsigned int count = dq->ids.qname.countLabels(); + return count < d_min || count > d_max; + } + string toString() const override + { + return "labels count < " + std::to_string(d_min) + " || labels count > " + std::to_string(d_max); + } +private: + unsigned int d_min; + unsigned int d_max; +}; + +class QNameWireLengthRule : public DNSRule +{ +public: + QNameWireLengthRule(size_t min, size_t max): d_min(min), d_max(max) + { + } + bool matches(const DNSQuestion* dq) const override + { + size_t const wirelength = dq->ids.qname.wirelength(); + return wirelength < d_min || wirelength > d_max; + } + string toString() const override + { + return "wire length < " + std::to_string(d_min) + " || wire length > " + std::to_string(d_max); + } +private: + size_t d_min; + size_t d_max; +}; + +class RCodeRule : public DNSRule +{ +public: + RCodeRule(uint8_t rcode) : d_rcode(rcode) + { + } + bool matches(const DNSQuestion* dq) const override + { + return d_rcode == dq->getHeader()->rcode; + } + string toString() const override + { + return "rcode=="+RCode::to_s(d_rcode); + } +private: + uint8_t d_rcode; +}; + +class ERCodeRule : public DNSRule +{ +public: + ERCodeRule(uint8_t rcode) : d_rcode(rcode & 0xF), d_extrcode(rcode >> 4) + { + } + bool matches(const DNSQuestion* dq) const override + { + // avoid parsing EDNS OPT RR when not needed. + if (d_rcode != dq->getHeader()->rcode) { + return false; + } + + EDNS0Record edns0; + if (!getEDNS0Record(dq->getData(), edns0)) { + return false; + } + + return d_extrcode == edns0.extRCode; + } + string toString() const override + { + return "ercode=="+ERCode::to_s(d_rcode | (d_extrcode << 4)); + } +private: + uint8_t d_rcode; // plain DNS Rcode + uint8_t d_extrcode; // upper bits in EDNS0 record +}; + +class EDNSVersionRule : public DNSRule +{ +public: + EDNSVersionRule(uint8_t version) : d_version(version) + { + } + bool matches(const DNSQuestion* dq) const override + { + EDNS0Record edns0; + if (!getEDNS0Record(dq->getData(), edns0)) { + return false; + } + + return d_version < edns0.version; + } + string toString() const override + { + return "ednsversion>"+std::to_string(d_version); + } +private: + uint8_t d_version; +}; + +class EDNSOptionRule : public DNSRule +{ +public: + EDNSOptionRule(uint16_t optcode) : d_optcode(optcode) + { + } + bool matches(const DNSQuestion* dq) const override + { + uint16_t optStart; + size_t optLen = 0; + bool last = false; + int res = locateEDNSOptRR(dq->getData(), &optStart, &optLen, &last); + if (res != 0) { + // no EDNS OPT RR + return false; + } + + if (optLen < optRecordMinimumSize) { + return false; + } + + if (optStart < dq->getData().size() && dq->getData().at(optStart) != 0) { + // OPT RR Name != '.' + return false; + } + + return isEDNSOptionInOpt(dq->getData(), optStart, optLen, d_optcode); + } + string toString() const override + { + return "ednsoptcode=="+std::to_string(d_optcode); + } +private: + uint16_t d_optcode; +}; + +class RDRule : public DNSRule +{ +public: + RDRule() + { + } + bool matches(const DNSQuestion* dq) const override + { + return dq->getHeader()->rd == 1; + } + string toString() const override + { + return "rd==1"; + } +}; + +class ProbaRule : public DNSRule +{ +public: + ProbaRule(double proba) : d_proba(proba) + { + } + bool matches(const DNSQuestion* dq) const override + { + if(d_proba == 1.0) + return true; + double rnd = 1.0*dns_random_uint32() / UINT32_MAX; + return rnd > (1.0 - d_proba); + } + string toString() const override + { + return "match with prob. " + (boost::format("%0.2f") % d_proba).str(); + } +private: + double d_proba; +}; + +class TagRule : public DNSRule +{ +public: + TagRule(const std::string& tag, boost::optional value) : d_value(std::move(value)), d_tag(tag) + { + } + bool matches(const DNSQuestion* dq) const override + { + if (!dq->ids.qTag) { + return false; + } + + const auto it = dq->ids.qTag->find(d_tag); + if (it == dq->ids.qTag->cend()) { + return false; + } + + if (!d_value) { + return true; + } + + return it->second == *d_value; + } + + string toString() const override + { + return "tag '" + d_tag + "' is set" + (d_value ? (" to '" + *d_value + "'") : ""); + } + +private: + boost::optional d_value; + std::string d_tag; +}; + +class PoolAvailableRule : public DNSRule +{ +public: + PoolAvailableRule(const std::string& poolname) : d_poolname(poolname) + { + } + + bool matches(const DNSQuestion* dq) const override + { + return (getPool(d_poolname)->countServers(true) > 0); + } + + string toString() const override + { + return "pool '" + d_poolname + "' is available"; + } +private: + std::string d_poolname; +}; + +class PoolOutstandingRule : public DNSRule +{ +public: + PoolOutstandingRule(const std::string& poolname, const size_t limit) : d_poolname(poolname), d_limit(limit) + { + } + + bool matches(const DNSQuestion* dq) const override + { + return (getPool(d_poolname)->poolLoad()) > d_limit; + } + + string toString() const override + { + return "pool '" + d_poolname + "' outstanding > " + std::to_string(d_limit); + } +private: + std::string d_poolname; + size_t d_limit; +}; + +class KeyValueStoreLookupRule: public DNSRule +{ +public: + KeyValueStoreLookupRule(std::shared_ptr& kvs, std::shared_ptr& lookupKey): d_kvs(kvs), d_key(lookupKey) + { + } + + bool matches(const DNSQuestion* dq) const override + { + std::vector keys = d_key->getKeys(*dq); + for (const auto& key : keys) { + if (d_kvs->keyExists(key) == true) { + return true; + } + } + + return false; + } + + string toString() const override + { + return "lookup key-value store based on '" + d_key->toString() + "'"; + } + +private: + std::shared_ptr d_kvs; + std::shared_ptr d_key; +}; + +class KeyValueStoreRangeLookupRule: public DNSRule +{ +public: + KeyValueStoreRangeLookupRule(std::shared_ptr& kvs, std::shared_ptr& lookupKey): d_kvs(kvs), d_key(lookupKey) + { + } + + bool matches(const DNSQuestion* dq) const override + { + std::vector keys = d_key->getKeys(*dq); + for (const auto& key : keys) { + std::string value; + if (d_kvs->getRangeValue(key, value) == true) { + return true; + } + } + + return false; + } + + string toString() const override + { + return "range-based lookup key-value store based on '" + d_key->toString() + "'"; + } + +private: + std::shared_ptr d_kvs; + std::shared_ptr d_key; +}; + +class LuaRule : public DNSRule +{ +public: + typedef std::function func_t; + LuaRule(const func_t& func): d_func(func) + {} + + bool matches(const DNSQuestion* dq) const override + { + try { + auto lock = g_lua.lock(); + return d_func(dq); + } catch (const std::exception &e) { + warnlog("LuaRule failed inside Lua: %s", e.what()); + } catch (...) { + warnlog("LuaRule failed inside Lua: [unknown exception]"); + } + return false; + } + + string toString() const override + { + return "Lua script"; + } +private: + func_t d_func; +}; + +class LuaFFIRule : public DNSRule +{ +public: + typedef std::function func_t; + LuaFFIRule(const func_t& func): d_func(func) + {} + + bool matches(const DNSQuestion* dq) const override + { + dnsdist_ffi_dnsquestion_t dqffi(const_cast(dq)); + try { + auto lock = g_lua.lock(); + return d_func(&dqffi); + } catch (const std::exception &e) { + warnlog("LuaFFIRule failed inside Lua: %s", e.what()); + } catch (...) { + warnlog("LuaFFIRule failed inside Lua: [unknown exception]"); + } + return false; + } + + string toString() const override + { + return "Lua FFI script"; + } +private: + func_t d_func; +}; + +class LuaFFIPerThreadRule : public DNSRule +{ +public: + typedef std::function func_t; + + LuaFFIPerThreadRule(const std::string& code): d_functionCode(code), d_functionID(s_functionsCounter++) + { + } + + bool matches(const DNSQuestion* dq) const override + { + try { + auto& state = t_perThreadStates[d_functionID]; + if (!state.d_initialized) { + setupLuaFFIPerThreadContext(state.d_luaContext); + /* mark the state as initialized first so if there is a syntax error + we only try to execute the code once */ + state.d_initialized = true; + state.d_func = state.d_luaContext.executeCode(d_functionCode); + } + + if (!state.d_func) { + /* the function was not properly initialized */ + return false; + } + + dnsdist_ffi_dnsquestion_t dqffi(const_cast(dq)); + return state.d_func(&dqffi); + } + catch (const std::exception &e) { + warnlog("LuaFFIPerthreadRule failed inside Lua: %s", e.what()); + } + catch (...) { + warnlog("LuaFFIPerThreadRule failed inside Lua: [unknown exception]"); + } + return false; + } + + string toString() const override + { + return "Lua FFI per-thread script"; + } +private: + struct PerThreadState + { + LuaContext d_luaContext; + func_t d_func; + bool d_initialized{false}; + }; + + static std::atomic s_functionsCounter; + static thread_local std::map t_perThreadStates; + const std::string d_functionCode; + const uint64_t d_functionID; +}; + +class ProxyProtocolValueRule : public DNSRule +{ +public: + ProxyProtocolValueRule(uint8_t type, boost::optional value): d_value(std::move(value)), d_type(type) + { + } + + bool matches(const DNSQuestion* dq) const override + { + if (!dq->proxyProtocolValues) { + return false; + } + + for (const auto& entry : *dq->proxyProtocolValues) { + if (entry.type == d_type && (!d_value || entry.content == *d_value)) { + return true; + } + } + + return false; + } + + string toString() const override + { + if (d_value) { + return "proxy protocol value of type " + std::to_string(d_type) + " matches"; + } + return "proxy protocol value of type " + std::to_string(d_type) + " is present"; + } + +private: + boost::optional d_value; + uint8_t d_type; +}; + +class PayloadSizeRule : public DNSRule +{ + enum class Comparisons : uint8_t { equal, greater, greaterOrEqual, smaller, smallerOrEqual }; +public: + PayloadSizeRule(const std::string& comparison, uint16_t size): d_size(size) + { + if (comparison == "equal") { + d_comparison = Comparisons::equal; + } + else if (comparison == "greater") { + d_comparison = Comparisons::greater; + } + else if (comparison == "greaterOrEqual") { + d_comparison = Comparisons::greaterOrEqual; + } + else if (comparison == "smaller") { + d_comparison = Comparisons::smaller; + } + else if (comparison == "smallerOrEqual") { + d_comparison = Comparisons::smallerOrEqual; + } + else { + throw std::runtime_error("Unsupported comparison '" + comparison + "'"); + } + } + + bool matches(const DNSQuestion* dq) const override + { + const auto size = dq->getData().size(); + + switch (d_comparison) { + case Comparisons::equal: + return size == d_size; + case Comparisons::greater: + return size > d_size; + case Comparisons::greaterOrEqual: + return size >= d_size; + case Comparisons::smaller: + return size < d_size; + case Comparisons::smallerOrEqual: + return size <= d_size; + default: + return false; + } + } + + string toString() const override + { + static const std::array comparisonStr{ + "equal to" , + "greater than", + "equal to or greater than", + "smaller than", + "equal to or smaller than" + }; + return "payload size is " + comparisonStr.at(static_cast(d_comparison)) + " " + std::to_string(d_size); + } + +private: + uint16_t d_size; + Comparisons d_comparison; +}; diff --git a/pdns/dnsdistdist/dnsdist-rules.cc b/pdns/dnsdistdist/dnsdist-rules.cc index b2688b80dd..fe716a7d1d 100644 --- a/pdns/dnsdistdist/dnsdist-rules.cc +++ b/pdns/dnsdistdist/dnsdist-rules.cc @@ -19,7 +19,7 @@ * along with this program; if not, write to the Free Software * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. */ -#include "dnsdist-rules.hh" +#include "dnsdist-rules-factory.hh" std::atomic LuaFFIPerThreadRule::s_functionsCounter = 0; thread_local std::map LuaFFIPerThreadRule::t_perThreadStates; diff --git a/pdns/dnsdistdist/dnsdist-rules.hh b/pdns/dnsdistdist/dnsdist-rules.hh index 2d8c4a8415..aca662b962 100644 --- a/pdns/dnsdistdist/dnsdist-rules.hh +++ b/pdns/dnsdistdist/dnsdist-rules.hh @@ -21,20 +21,10 @@ */ #pragma once -#include -#include -#include -#include +#include -#include "cachecleaner.hh" #include "dnsdist.hh" -#include "dnsdist-ecs.hh" -#include "dnsdist-kvs.hh" -#include "dnsdist-lua.hh" -#include "dnsdist-lua-ffi.hh" -#include "dolog.hh" -#include "dnsparser.hh" -#include "dns_random.hh" +#include "stat_t.hh" class DNSRule { @@ -44,1383 +34,6 @@ public: } virtual bool matches(const DNSQuestion* dq) const = 0; virtual string toString() const = 0; - mutable stat_t d_matches{0}; -}; - - -class MaxQPSIPRule : public DNSRule -{ -public: - MaxQPSIPRule(unsigned int qps, unsigned int burst, unsigned int ipv4trunc=32, unsigned int ipv6trunc=64, unsigned int expiration=300, unsigned int cleanupDelay=60, unsigned int scanFraction=10, size_t shardsCount=10): - d_shards(shardsCount), d_qps(qps), d_burst(burst), d_ipv4trunc(ipv4trunc), d_ipv6trunc(ipv6trunc), d_cleanupDelay(cleanupDelay), d_expiration(expiration), d_scanFraction(scanFraction) - { - d_cleaningUp.clear(); - gettime(&d_lastCleanup, true); - } - - void clear() - { - for (auto& shard : d_shards) { - shard.lock()->clear(); - } - } - - size_t cleanup(const struct timespec& cutOff, size_t* scannedCount=nullptr) const - { - size_t removed = 0; - if (scannedCount != nullptr) { - *scannedCount = 0; - } - - for (auto& shard : d_shards) { - auto limits = shard.lock(); - const size_t toLook = std::round((1.0 * limits->size()) / d_scanFraction)+ 1; - size_t lookedAt = 0; - - auto& sequence = limits->get(); - for (auto entry = sequence.begin(); entry != sequence.end() && lookedAt < toLook; lookedAt++) { - if (entry->d_limiter.seenSince(cutOff)) { - /* entries are ordered from least recently seen to more recently - seen, as soon as we see one that has not expired yet, we are - done */ - lookedAt++; - break; - } - - entry = sequence.erase(entry); - removed++; - } - - if (scannedCount != nullptr) { - *scannedCount += lookedAt; - } - } - - return removed; - } - - void cleanupIfNeeded(const struct timespec& now) const - { - if (d_cleanupDelay > 0) { - struct timespec cutOff = d_lastCleanup; - cutOff.tv_sec += d_cleanupDelay; - - if (cutOff < now) { - try { - if (d_cleaningUp.test_and_set()) { - return; - } - - d_lastCleanup = now; - /* the QPS Limiter doesn't use realtime, be careful! */ - gettime(&cutOff, false); - cutOff.tv_sec -= d_expiration; - - cleanup(cutOff); - d_cleaningUp.clear(); - } - catch (...) { - d_cleaningUp.clear(); - throw; - } - } - } - } - - bool matches(const DNSQuestion* dq) const override - { - cleanupIfNeeded(dq->getQueryRealTime()); - - ComboAddress zeroport(dq->ids.origRemote); - zeroport.sin4.sin_port=0; - zeroport.truncate(zeroport.sin4.sin_family == AF_INET ? d_ipv4trunc : d_ipv6trunc); - auto hash = ComboAddress::addressOnlyHash()(zeroport); - auto& shard = d_shards[hash % d_shards.size()]; - { - auto limits = shard.lock(); - auto iter = limits->find(zeroport); - if (iter == limits->end()) { - Entry e(zeroport, QPSLimiter(d_qps, d_burst)); - iter = limits->insert(e).first; - } - - moveCacheItemToBack(*limits, iter); - return !iter->d_limiter.check(d_qps, d_burst); - } - } - - string toString() const override - { - return "IP (/"+std::to_string(d_ipv4trunc)+", /"+std::to_string(d_ipv6trunc)+") match for QPS over " + std::to_string(d_qps) + " burst "+ std::to_string(d_burst); - } - - size_t getEntriesCount() const - { - size_t count = 0; - for (auto& shard : d_shards) { - count += shard.lock()->size(); - } - return count; - } - - size_t getNumberOfShards() const - { - return d_shards.size(); - } - -private: - struct HashedTag {}; - struct SequencedTag {}; - struct Entry - { - Entry(const ComboAddress& addr, BasicQPSLimiter&& limiter): d_limiter(limiter), d_addr(addr) - { - } - mutable BasicQPSLimiter d_limiter; - ComboAddress d_addr; - }; - - typedef multi_index_container< - Entry, - indexed_by < - hashed_unique, member, ComboAddress::addressOnlyHash >, - sequenced > - > - > qpsContainer_t; - - mutable std::vector> d_shards; - mutable struct timespec d_lastCleanup; - const unsigned int d_qps, d_burst, d_ipv4trunc, d_ipv6trunc, d_cleanupDelay, d_expiration; - const unsigned int d_scanFraction{10}; - mutable std::atomic_flag d_cleaningUp; -}; - -class MaxQPSRule : public DNSRule -{ -public: - MaxQPSRule(unsigned int qps) - : d_qps(qps, qps) - {} - - MaxQPSRule(unsigned int qps, unsigned int burst) - : d_qps(qps, burst) - {} - - - bool matches(const DNSQuestion* qd) const override - { - return d_qps.check(); - } - - string toString() const override - { - return "Max " + std::to_string(d_qps.getRate()) + " qps"; - } - - -private: - mutable QPSLimiter d_qps; -}; - -class NMGRule : public DNSRule -{ -public: - NMGRule(const NetmaskGroup& nmg) : d_nmg(nmg) {} -protected: - NetmaskGroup d_nmg; -}; - -class NetmaskGroupRule : public NMGRule -{ -public: - NetmaskGroupRule(const NetmaskGroup& nmg, bool src, bool quiet = false) : NMGRule(nmg) - { - d_src = src; - d_quiet = quiet; - } - bool matches(const DNSQuestion* dq) const override - { - if(!d_src) { - return d_nmg.match(dq->ids.origDest); - } - return d_nmg.match(dq->ids.origRemote); - } - - string toString() const override - { - string ret = "Src: "; - if(!d_src) { - ret = "Dst: "; - } - if (d_quiet) { - return ret + "in-group"; - } - return ret + d_nmg.toString(); - } -private: - bool d_src; - bool d_quiet; -}; - -class TimedIPSetRule : public DNSRule, boost::noncopyable -{ -private: - struct IPv6 { - IPv6(const ComboAddress& ca) - { - static_assert(sizeof(*this)==16, "IPv6 struct has wrong size"); - memcpy((char*)this, ca.sin6.sin6_addr.s6_addr, 16); - } - bool operator==(const IPv6& rhs) const - { - return a==rhs.a && b==rhs.b; - } - uint64_t a, b; - }; - -public: - TimedIPSetRule() - { - } - ~TimedIPSetRule() - { - } - bool matches(const DNSQuestion* dq) const override - { - if (dq->ids.origRemote.sin4.sin_family == AF_INET) { - auto ip4s = d_ip4s.read_lock(); - auto fnd = ip4s->find(dq->ids.origRemote.sin4.sin_addr.s_addr); - if (fnd == ip4s->end()) { - return false; - } - return time(nullptr) < fnd->second; - } else { - auto ip6s = d_ip6s.read_lock(); - auto fnd = ip6s->find({dq->ids.origRemote}); - if (fnd == ip6s->end()) { - return false; - } - return time(nullptr) < fnd->second; - } - } - - void add(const ComboAddress& ca, time_t ttd) - { - // think twice before adding templates here - if (ca.sin4.sin_family == AF_INET) { - auto res = d_ip4s.write_lock()->insert({ca.sin4.sin_addr.s_addr, ttd}); - if (!res.second && (time_t)res.first->second < ttd) { - res.first->second = (uint32_t)ttd; - } - } - else { - auto res = d_ip6s.write_lock()->insert({{ca}, ttd}); - if (!res.second && (time_t)res.first->second < ttd) { - // coverity[store_truncates_time_t] - res.first->second = (uint32_t)ttd; - } - } - } - - void remove(const ComboAddress& ca) - { - if (ca.sin4.sin_family == AF_INET) { - d_ip4s.write_lock()->erase(ca.sin4.sin_addr.s_addr); - } - else { - d_ip6s.write_lock()->erase({ca}); - } - } - - void clear() - { - d_ip4s.write_lock()->clear(); - d_ip6s.write_lock()->clear(); - } - - void cleanup() - { - time_t now = time(nullptr); - { - auto ip4s = d_ip4s.write_lock(); - for (auto iter = ip4s->begin(); iter != ip4s->end(); ) { - if (iter->second < now) { - iter = ip4s->erase(iter); - } - else { - ++iter; - } - } - } - - { - auto ip6s = d_ip6s.write_lock(); - for (auto iter = ip6s->begin(); iter != ip6s->end(); ) { - if (iter->second < now) { - iter = ip6s->erase(iter); - } - else { - ++iter; - } - } - - } - - } - - string toString() const override - { - time_t now = time(nullptr); - uint64_t count = 0; - - for (const auto& ip : *(d_ip4s.read_lock())) { - if (now < ip.second) { - ++count; - } - } - - for (const auto& ip : *(d_ip6s.read_lock())) { - if (now < ip.second) { - ++count; - } - } - - return "Src: "+std::to_string(count)+" ips"; - } -private: - struct IPv6Hash - { - std::size_t operator()(const IPv6& ip) const - { - auto ah=std::hash{}(ip.a); - auto bh=std::hash{}(ip.b); - return ah & (bh<<1); - } - }; - mutable SharedLockGuarded> d_ip6s; - mutable SharedLockGuarded> d_ip4s; -}; - - -class AllRule : public DNSRule -{ -public: - AllRule() {} - bool matches(const DNSQuestion* dq) const override - { - return true; - } - - string toString() const override - { - return "All"; - } - -}; - - -class DNSSECRule : public DNSRule -{ -public: - DNSSECRule() - { - - } - bool matches(const DNSQuestion* dq) const override - { - return dq->getHeader()->cd || (dnsdist::getEDNSZ(*dq) & EDNS_HEADER_FLAG_DO); // turns out dig sets ad by default.. - } - - string toString() const override - { - return "DNSSEC"; - } -}; - -class AndRule : public DNSRule -{ -public: - AndRule(const std::vector > >& rules) - { - for (const auto& r : rules) { - d_rules.push_back(r.second); - } - } - - bool matches(const DNSQuestion* dq) const override - { - for (const auto& rule : d_rules) { - if (!rule->matches(dq)) { - return false; - } - } - return true; - } - - string toString() const override - { - string ret; - for (const auto& rule : d_rules) { - if (!ret.empty()) { - ret+= " && "; - } - ret += "("+ rule->toString()+")"; - } - return ret; - } -private: - std::vector > d_rules; -}; - - -class OrRule : public DNSRule -{ -public: - OrRule(const std::vector > >& rules) - { - for (const auto& r : rules) { - d_rules.push_back(r.second); - } - } - - bool matches(const DNSQuestion* dq) const override - { - for (const auto& rule: d_rules) { - if (rule->matches(dq)) { - return true; - } - } - return false; - } - - string toString() const override - { - string ret; - for (const auto& rule : d_rules) { - if (!ret.empty()) { - ret+= " || "; - } - ret += "("+ rule->toString()+")"; - } - return ret; - } -private: - std::vector > d_rules; -}; - - -class RegexRule : public DNSRule -{ -public: - RegexRule(const std::string& regex) : d_regex(regex), d_visual(regex) - { - - } - bool matches(const DNSQuestion* dq) const override - { - return d_regex.match(dq->ids.qname.toStringNoDot()); - } - - string toString() const override - { - return "Regex: "+d_visual; - } -private: - Regex d_regex; - string d_visual; -}; - -#ifdef HAVE_RE2 -#include -class RE2Rule : public DNSRule -{ -public: - RE2Rule(const std::string& re2) : d_re2(re2, RE2::Latin1), d_visual(re2) - { - - } - bool matches(const DNSQuestion* dq) const override - { - return RE2::FullMatch(dq->ids.qname.toStringNoDot(), d_re2); - } - - string toString() const override - { - return "RE2 match: "+d_visual; - } -private: - RE2 d_re2; - string d_visual; -}; -#endif - -#ifdef HAVE_DNS_OVER_HTTPS -class HTTPHeaderRule : public DNSRule -{ -public: - HTTPHeaderRule(const std::string& header, const std::string& regex); - bool matches(const DNSQuestion* dq) const override; - string toString() const override; -private: - string d_header; - Regex d_regex; - string d_visual; -}; - -class HTTPPathRule : public DNSRule -{ -public: - HTTPPathRule(std::string path); - bool matches(const DNSQuestion* dq) const override; - string toString() const override; -private: - string d_path; -}; - -class HTTPPathRegexRule : public DNSRule -{ -public: - HTTPPathRegexRule(const std::string& regex); - bool matches(const DNSQuestion* dq) const override; - string toString() const override; -private: - Regex d_regex; - std::string d_visual; -}; -#endif - -class SNIRule : public DNSRule -{ -public: - SNIRule(const std::string& name) : d_sni(name) - { - } - bool matches(const DNSQuestion* dq) const override - { - return dq->sni == d_sni; - } - string toString() const override - { - return "SNI == " + d_sni; - } -private: - std::string d_sni; -}; - -class SuffixMatchNodeRule : public DNSRule -{ -public: - SuffixMatchNodeRule(const SuffixMatchNode& smn, bool quiet=false) : d_smn(smn), d_quiet(quiet) - { - } - bool matches(const DNSQuestion* dq) const override - { - return d_smn.check(dq->ids.qname); - } - string toString() const override - { - if(d_quiet) - return "qname==in-set"; - else - return "qname in "+d_smn.toString(); - } -private: - SuffixMatchNode d_smn; - bool d_quiet; -}; - -class QNameRule : public DNSRule -{ -public: - QNameRule(const DNSName& qname) : d_qname(qname) - { - } - - bool matches(const DNSQuestion* dq) const override - { - return d_qname==dq->ids.qname; - } - string toString() const override - { - return "qname=="+d_qname.toString(); - } -private: - DNSName d_qname; -}; - -class QNameSetRule : public DNSRule { -public: - QNameSetRule(const DNSNameSet& names) : qname_idx(names) {} - - bool matches(const DNSQuestion* dq) const override { - return qname_idx.find(dq->ids.qname) != qname_idx.end(); - } - - string toString() const override { - std::stringstream ss; - ss << "qname in DNSNameSet(" << qname_idx.size() << " FQDNs)"; - return ss.str(); - } -private: - DNSNameSet qname_idx; -}; - -class QTypeRule : public DNSRule -{ -public: - QTypeRule(uint16_t qtype) : d_qtype(qtype) - { - } - bool matches(const DNSQuestion* dq) const override - { - return d_qtype == dq->ids.qtype; - } - string toString() const override - { - QType qt(d_qtype); - return "qtype=="+qt.toString(); - } -private: - uint16_t d_qtype; -}; - -class QClassRule : public DNSRule -{ -public: - QClassRule(uint16_t qclass) : d_qclass(qclass) - { - } - bool matches(const DNSQuestion* dq) const override - { - return d_qclass == dq->ids.qclass; - } - string toString() const override - { - return "qclass=="+std::to_string(d_qclass); - } -private: - uint16_t d_qclass; -}; - -class OpcodeRule : public DNSRule -{ -public: - OpcodeRule(uint8_t opcode) : d_opcode(opcode) - { - } - bool matches(const DNSQuestion* dq) const override - { - return d_opcode == dq->getHeader()->opcode; - } - string toString() const override - { - return "opcode=="+std::to_string(d_opcode); - } -private: - uint8_t d_opcode; -}; - -class DSTPortRule : public DNSRule -{ -public: - DSTPortRule(uint16_t port) : d_port(port) - { - } - bool matches(const DNSQuestion* dq) const override - { - return htons(d_port) == dq->ids.origDest.sin4.sin_port; - } - string toString() const override - { - return "dst port=="+std::to_string(d_port); - } -private: - uint16_t d_port; -}; - -class TCPRule : public DNSRule -{ -public: - TCPRule(bool tcp): d_tcp(tcp) - { - } - bool matches(const DNSQuestion* dq) const override - { - return dq->overTCP() == d_tcp; - } - string toString() const override - { - return (d_tcp ? "TCP" : "UDP"); - } -private: - bool d_tcp; -}; - - -class NotRule : public DNSRule -{ -public: - NotRule(const std::shared_ptr& rule): d_rule(rule) - { - } - bool matches(const DNSQuestion* dq) const override - { - return !d_rule->matches(dq); - } - string toString() const override - { - return "!("+ d_rule->toString()+")"; - } -private: - std::shared_ptr d_rule; -}; - -class RecordsCountRule : public DNSRule -{ -public: - RecordsCountRule(uint8_t section, uint16_t minCount, uint16_t maxCount): d_minCount(minCount), d_maxCount(maxCount), d_section(section) - { - } - bool matches(const DNSQuestion* dq) const override - { - uint16_t count = 0; - switch(d_section) { - case 0: - count = ntohs(dq->getHeader()->qdcount); - break; - case 1: - count = ntohs(dq->getHeader()->ancount); - break; - case 2: - count = ntohs(dq->getHeader()->nscount); - break; - case 3: - count = ntohs(dq->getHeader()->arcount); - break; - } - return count >= d_minCount && count <= d_maxCount; - } - string toString() const override - { - string section; - switch(d_section) { - case 0: - section = "QD"; - break; - case 1: - section = "AN"; - break; - case 2: - section = "NS"; - break; - case 3: - section = "AR"; - break; - } - return std::to_string(d_minCount) + " <= records in " + section + " <= "+ std::to_string(d_maxCount); - } -private: - uint16_t d_minCount; - uint16_t d_maxCount; - uint8_t d_section; -}; - -class RecordsTypeCountRule : public DNSRule -{ -public: - RecordsTypeCountRule(uint8_t section, uint16_t type, uint16_t minCount, uint16_t maxCount): d_type(type), d_minCount(minCount), d_maxCount(maxCount), d_section(section) - { - } - bool matches(const DNSQuestion* dq) const override - { - uint16_t count = 0; - switch(d_section) { - case 0: - count = ntohs(dq->getHeader()->qdcount); - break; - case 1: - count = ntohs(dq->getHeader()->ancount); - break; - case 2: - count = ntohs(dq->getHeader()->nscount); - break; - case 3: - count = ntohs(dq->getHeader()->arcount); - break; - } - if (count < d_minCount) { - return false; - } - count = getRecordsOfTypeCount(reinterpret_cast(dq->getData().data()), dq->getData().size(), d_section, d_type); - return count >= d_minCount && count <= d_maxCount; - } - string toString() const override - { - string section; - switch(d_section) { - case 0: - section = "QD"; - break; - case 1: - section = "AN"; - break; - case 2: - section = "NS"; - break; - case 3: - section = "AR"; - break; - } - return std::to_string(d_minCount) + " <= " + QType(d_type).toString() + " records in " + section + " <= "+ std::to_string(d_maxCount); - } -private: - uint16_t d_type; - uint16_t d_minCount; - uint16_t d_maxCount; - uint8_t d_section; -}; - -class TrailingDataRule : public DNSRule -{ -public: - TrailingDataRule() - { - } - bool matches(const DNSQuestion* dq) const override - { - uint16_t length = getDNSPacketLength(reinterpret_cast(dq->getData().data()), dq->getData().size()); - return length < dq->getData().size(); - } - string toString() const override - { - return "trailing data"; - } -}; - -class QNameLabelsCountRule : public DNSRule -{ -public: - QNameLabelsCountRule(unsigned int minLabelsCount, unsigned int maxLabelsCount): d_min(minLabelsCount), d_max(maxLabelsCount) - { - } - bool matches(const DNSQuestion* dq) const override - { - unsigned int count = dq->ids.qname.countLabels(); - return count < d_min || count > d_max; - } - string toString() const override - { - return "labels count < " + std::to_string(d_min) + " || labels count > " + std::to_string(d_max); - } -private: - unsigned int d_min; - unsigned int d_max; -}; - -class QNameWireLengthRule : public DNSRule -{ -public: - QNameWireLengthRule(size_t min, size_t max): d_min(min), d_max(max) - { - } - bool matches(const DNSQuestion* dq) const override - { - size_t const wirelength = dq->ids.qname.wirelength(); - return wirelength < d_min || wirelength > d_max; - } - string toString() const override - { - return "wire length < " + std::to_string(d_min) + " || wire length > " + std::to_string(d_max); - } -private: - size_t d_min; - size_t d_max; -}; - -class RCodeRule : public DNSRule -{ -public: - RCodeRule(uint8_t rcode) : d_rcode(rcode) - { - } - bool matches(const DNSQuestion* dq) const override - { - return d_rcode == dq->getHeader()->rcode; - } - string toString() const override - { - return "rcode=="+RCode::to_s(d_rcode); - } -private: - uint8_t d_rcode; -}; - -class ERCodeRule : public DNSRule -{ -public: - ERCodeRule(uint8_t rcode) : d_rcode(rcode & 0xF), d_extrcode(rcode >> 4) - { - } - bool matches(const DNSQuestion* dq) const override - { - // avoid parsing EDNS OPT RR when not needed. - if (d_rcode != dq->getHeader()->rcode) { - return false; - } - - EDNS0Record edns0; - if (!getEDNS0Record(dq->getData(), edns0)) { - return false; - } - - return d_extrcode == edns0.extRCode; - } - string toString() const override - { - return "ercode=="+ERCode::to_s(d_rcode | (d_extrcode << 4)); - } -private: - uint8_t d_rcode; // plain DNS Rcode - uint8_t d_extrcode; // upper bits in EDNS0 record -}; - -class EDNSVersionRule : public DNSRule -{ -public: - EDNSVersionRule(uint8_t version) : d_version(version) - { - } - bool matches(const DNSQuestion* dq) const override - { - EDNS0Record edns0; - if (!getEDNS0Record(dq->getData(), edns0)) { - return false; - } - - return d_version < edns0.version; - } - string toString() const override - { - return "ednsversion>"+std::to_string(d_version); - } -private: - uint8_t d_version; -}; - -class EDNSOptionRule : public DNSRule -{ -public: - EDNSOptionRule(uint16_t optcode) : d_optcode(optcode) - { - } - bool matches(const DNSQuestion* dq) const override - { - uint16_t optStart; - size_t optLen = 0; - bool last = false; - int res = locateEDNSOptRR(dq->getData(), &optStart, &optLen, &last); - if (res != 0) { - // no EDNS OPT RR - return false; - } - - if (optLen < optRecordMinimumSize) { - return false; - } - - if (optStart < dq->getData().size() && dq->getData().at(optStart) != 0) { - // OPT RR Name != '.' - return false; - } - - return isEDNSOptionInOpt(dq->getData(), optStart, optLen, d_optcode); - } - string toString() const override - { - return "ednsoptcode=="+std::to_string(d_optcode); - } -private: - uint16_t d_optcode; -}; - -class RDRule : public DNSRule -{ -public: - RDRule() - { - } - bool matches(const DNSQuestion* dq) const override - { - return dq->getHeader()->rd == 1; - } - string toString() const override - { - return "rd==1"; - } -}; - -class ProbaRule : public DNSRule -{ -public: - ProbaRule(double proba) : d_proba(proba) - { - } - bool matches(const DNSQuestion* dq) const override - { - if(d_proba == 1.0) - return true; - double rnd = 1.0*dns_random_uint32() / UINT32_MAX; - return rnd > (1.0 - d_proba); - } - string toString() const override - { - return "match with prob. " + (boost::format("%0.2f") % d_proba).str(); - } -private: - double d_proba; -}; - -class TagRule : public DNSRule -{ -public: - TagRule(const std::string& tag, boost::optional value) : d_value(std::move(value)), d_tag(tag) - { - } - bool matches(const DNSQuestion* dq) const override - { - if (!dq->ids.qTag) { - return false; - } - - const auto it = dq->ids.qTag->find(d_tag); - if (it == dq->ids.qTag->cend()) { - return false; - } - - if (!d_value) { - return true; - } - - return it->second == *d_value; - } - - string toString() const override - { - return "tag '" + d_tag + "' is set" + (d_value ? (" to '" + *d_value + "'") : ""); - } - -private: - boost::optional d_value; - std::string d_tag; -}; - -class PoolAvailableRule : public DNSRule -{ -public: - PoolAvailableRule(const std::string& poolname) : d_poolname(poolname) - { - } - - bool matches(const DNSQuestion* dq) const override - { - return (getPool(d_poolname)->countServers(true) > 0); - } - - string toString() const override - { - return "pool '" + d_poolname + "' is available"; - } -private: - std::string d_poolname; -}; - -class PoolOutstandingRule : public DNSRule -{ -public: - PoolOutstandingRule(const std::string& poolname, const size_t limit) : d_poolname(poolname), d_limit(limit) - { - } - - bool matches(const DNSQuestion* dq) const override - { - return (getPool(d_poolname)->poolLoad()) > d_limit; - } - - string toString() const override - { - return "pool '" + d_poolname + "' outstanding > " + std::to_string(d_limit); - } -private: - std::string d_poolname; - size_t d_limit; -}; - -class KeyValueStoreLookupRule: public DNSRule -{ -public: - KeyValueStoreLookupRule(std::shared_ptr& kvs, std::shared_ptr& lookupKey): d_kvs(kvs), d_key(lookupKey) - { - } - - bool matches(const DNSQuestion* dq) const override - { - std::vector keys = d_key->getKeys(*dq); - for (const auto& key : keys) { - if (d_kvs->keyExists(key) == true) { - return true; - } - } - - return false; - } - - string toString() const override - { - return "lookup key-value store based on '" + d_key->toString() + "'"; - } - -private: - std::shared_ptr d_kvs; - std::shared_ptr d_key; -}; - -class KeyValueStoreRangeLookupRule: public DNSRule -{ -public: - KeyValueStoreRangeLookupRule(std::shared_ptr& kvs, std::shared_ptr& lookupKey): d_kvs(kvs), d_key(lookupKey) - { - } - - bool matches(const DNSQuestion* dq) const override - { - std::vector keys = d_key->getKeys(*dq); - for (const auto& key : keys) { - std::string value; - if (d_kvs->getRangeValue(key, value) == true) { - return true; - } - } - - return false; - } - - string toString() const override - { - return "range-based lookup key-value store based on '" + d_key->toString() + "'"; - } - -private: - std::shared_ptr d_kvs; - std::shared_ptr d_key; -}; - -class LuaRule : public DNSRule -{ -public: - typedef std::function func_t; - LuaRule(const func_t& func): d_func(func) - {} - - bool matches(const DNSQuestion* dq) const override - { - try { - auto lock = g_lua.lock(); - return d_func(dq); - } catch (const std::exception &e) { - warnlog("LuaRule failed inside Lua: %s", e.what()); - } catch (...) { - warnlog("LuaRule failed inside Lua: [unknown exception]"); - } - return false; - } - - string toString() const override - { - return "Lua script"; - } -private: - func_t d_func; -}; - -class LuaFFIRule : public DNSRule -{ -public: - typedef std::function func_t; - LuaFFIRule(const func_t& func): d_func(func) - {} - - bool matches(const DNSQuestion* dq) const override - { - dnsdist_ffi_dnsquestion_t dqffi(const_cast(dq)); - try { - auto lock = g_lua.lock(); - return d_func(&dqffi); - } catch (const std::exception &e) { - warnlog("LuaFFIRule failed inside Lua: %s", e.what()); - } catch (...) { - warnlog("LuaFFIRule failed inside Lua: [unknown exception]"); - } - return false; - } - - string toString() const override - { - return "Lua FFI script"; - } -private: - func_t d_func; -}; - -class LuaFFIPerThreadRule : public DNSRule -{ -public: - typedef std::function func_t; - - LuaFFIPerThreadRule(const std::string& code): d_functionCode(code), d_functionID(s_functionsCounter++) - { - } - - bool matches(const DNSQuestion* dq) const override - { - try { - auto& state = t_perThreadStates[d_functionID]; - if (!state.d_initialized) { - setupLuaFFIPerThreadContext(state.d_luaContext); - /* mark the state as initialized first so if there is a syntax error - we only try to execute the code once */ - state.d_initialized = true; - state.d_func = state.d_luaContext.executeCode(d_functionCode); - } - - if (!state.d_func) { - /* the function was not properly initialized */ - return false; - } - - dnsdist_ffi_dnsquestion_t dqffi(const_cast(dq)); - return state.d_func(&dqffi); - } - catch (const std::exception &e) { - warnlog("LuaFFIPerthreadRule failed inside Lua: %s", e.what()); - } - catch (...) { - warnlog("LuaFFIPerThreadRule failed inside Lua: [unknown exception]"); - } - return false; - } - - string toString() const override - { - return "Lua FFI per-thread script"; - } -private: - struct PerThreadState - { - LuaContext d_luaContext; - func_t d_func; - bool d_initialized{false}; - }; - - static std::atomic s_functionsCounter; - static thread_local std::map t_perThreadStates; - const std::string d_functionCode; - const uint64_t d_functionID; -}; - -class ProxyProtocolValueRule : public DNSRule -{ -public: - ProxyProtocolValueRule(uint8_t type, boost::optional value): d_value(std::move(value)), d_type(type) - { - } - - bool matches(const DNSQuestion* dq) const override - { - if (!dq->proxyProtocolValues) { - return false; - } - - for (const auto& entry : *dq->proxyProtocolValues) { - if (entry.type == d_type && (!d_value || entry.content == *d_value)) { - return true; - } - } - - return false; - } - - string toString() const override - { - if (d_value) { - return "proxy protocol value of type " + std::to_string(d_type) + " matches"; - } - return "proxy protocol value of type " + std::to_string(d_type) + " is present"; - } - -private: - boost::optional d_value; - uint8_t d_type; -}; - -class PayloadSizeRule : public DNSRule -{ - enum class Comparisons : uint8_t { equal, greater, greaterOrEqual, smaller, smallerOrEqual }; -public: - PayloadSizeRule(const std::string& comparison, uint16_t size): d_size(size) - { - if (comparison == "equal") { - d_comparison = Comparisons::equal; - } - else if (comparison == "greater") { - d_comparison = Comparisons::greater; - } - else if (comparison == "greaterOrEqual") { - d_comparison = Comparisons::greaterOrEqual; - } - else if (comparison == "smaller") { - d_comparison = Comparisons::smaller; - } - else if (comparison == "smallerOrEqual") { - d_comparison = Comparisons::smallerOrEqual; - } - else { - throw std::runtime_error("Unsupported comparison '" + comparison + "'"); - } - } - bool matches(const DNSQuestion* dq) const override - { - const auto size = dq->getData().size(); - - switch (d_comparison) { - case Comparisons::equal: - return size == d_size; - case Comparisons::greater: - return size > d_size; - case Comparisons::greaterOrEqual: - return size >= d_size; - case Comparisons::smaller: - return size < d_size; - case Comparisons::smallerOrEqual: - return size <= d_size; - default: - return false; - } - } - - string toString() const override - { - static const std::array comparisonStr{ - "equal to" , - "greater than", - "equal to or greater than", - "smaller than", - "equal to or smaller than" - }; - return "payload size is " + comparisonStr.at(static_cast(d_comparison)) + " " + std::to_string(d_size); - } - -private: - uint16_t d_size; - Comparisons d_comparison; + mutable stat_t d_matches{0}; }; diff --git a/pdns/dnsdistdist/test-dnsdistrules_cc.cc b/pdns/dnsdistdist/test-dnsdistrules_cc.cc index 7457fb3bb1..8c91e3699e 100644 --- a/pdns/dnsdistdist/test-dnsdistrules_cc.cc +++ b/pdns/dnsdistdist/test-dnsdistrules_cc.cc @@ -8,7 +8,7 @@ #include #include -#include "dnsdist-rules.hh" +#include "dnsdist-rules-factory.hh" void checkParameterBound(const std::string& parameter, uint64_t value, size_t max) {