]> git.ipfire.org Git - thirdparty/pdns.git/blame - pdns/dnsdist-lua-rules.cc
rm*Rule: rename num to id
[thirdparty/pdns.git] / pdns / dnsdist-lua-rules.cc
CommitLineData
6bb38cd6
RG
/*
 * This file is part of PowerDNS or dnsdist.
 * Copyright -- PowerDNS.COM B.V. and its contributors
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of version 2 of the GNU General Public License as
 * published by the Free Software Foundation.
 *
 * In addition, for the avoidance of any doubt, permission is granted to
 * link this program with OpenSSL and to (re)distribute the binaries
 * produced as the result of such linking.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 */
22#include "dnsdist.hh"
d83feb68 23#include "dnsdist-ecs.hh"
6bb38cd6
RG
24#include "dnsdist-lua.hh"
25
26#include "dnsparser.hh"
27
28class MaxQPSIPRule : public DNSRule
29{
30public:
31 MaxQPSIPRule(unsigned int qps, unsigned int burst, unsigned int ipv4trunc=32, unsigned int ipv6trunc=64) :
32 d_qps(qps), d_burst(burst), d_ipv4trunc(ipv4trunc), d_ipv6trunc(ipv6trunc)
33 {
34 pthread_rwlock_init(&d_lock, 0);
35 }
36
37 bool matches(const DNSQuestion* dq) const override
38 {
39 ComboAddress zeroport(*dq->remote);
40 zeroport.sin4.sin_port=0;
41 zeroport.truncate(zeroport.sin4.sin_family == AF_INET ? d_ipv4trunc : d_ipv6trunc);
42 {
43 ReadLock r(&d_lock);
44 const auto iter = d_limits.find(zeroport);
45 if (iter != d_limits.end()) {
46 return !iter->second.check();
47 }
48 }
49 {
50 WriteLock w(&d_lock);
51 auto iter = d_limits.find(zeroport);
52 if(iter == d_limits.end()) {
53 iter=d_limits.insert({zeroport,QPSLimiter(d_qps, d_burst)}).first;
54 }
55 return !iter->second.check();
56 }
57 }
58
59 string toString() const override
60 {
61 return "IP (/"+std::to_string(d_ipv4trunc)+", /"+std::to_string(d_ipv6trunc)+") match for QPS over " + std::to_string(d_qps) + " burst "+ std::to_string(d_burst);
62 }
63
64
65private:
66 mutable pthread_rwlock_t d_lock;
67 mutable std::map<ComboAddress, QPSLimiter> d_limits;
68 unsigned int d_qps, d_burst, d_ipv4trunc, d_ipv6trunc;
69
70};
71
72class MaxQPSRule : public DNSRule
73{
74public:
75 MaxQPSRule(unsigned int qps)
76 : d_qps(qps, qps)
77 {}
78
79 MaxQPSRule(unsigned int qps, unsigned int burst)
80 : d_qps(qps, burst)
81 {}
82
83
84 bool matches(const DNSQuestion* qd) const override
85 {
86 return d_qps.check();
87 }
88
89 string toString() const override
90 {
91 return "Max " + std::to_string(d_qps.getRate()) + " qps";
92 }
93
94
95private:
96 mutable QPSLimiter d_qps;
97};
98
99class NMGRule : public DNSRule
100{
101public:
102 NMGRule(const NetmaskGroup& nmg) : d_nmg(nmg) {}
103protected:
104 NetmaskGroup d_nmg;
105};
106
107class NetmaskGroupRule : public NMGRule
108{
109public:
110 NetmaskGroupRule(const NetmaskGroup& nmg, bool src) : NMGRule(nmg)
111 {
112 d_src = src;
113 }
114 bool matches(const DNSQuestion* dq) const override
115 {
116 if(!d_src) {
117 return d_nmg.match(*dq->local);
118 }
119 return d_nmg.match(*dq->remote);
120 }
121
122 string toString() const override
123 {
124 if(!d_src) {
125 return "Dst: "+d_nmg.toString();
126 }
127 return "Src: "+d_nmg.toString();
128 }
129private:
130 bool d_src;
131};
132
133class TimedIPSetRule : public DNSRule, boost::noncopyable
134{
135private:
136 struct IPv6 {
137 IPv6(const ComboAddress& ca)
138 {
139 static_assert(sizeof(*this)==16, "IPv6 struct has wrong size");
140 memcpy((char*)this, ca.sin6.sin6_addr.s6_addr, 16);
141 }
142 bool operator==(const IPv6& rhs) const
143 {
144 return a==rhs.a && b==rhs.b;
145 }
146 uint64_t a, b;
147 };
148
149public:
150 TimedIPSetRule()
151 {
152 pthread_rwlock_init(&d_lock4, 0);
153 pthread_rwlock_init(&d_lock6, 0);
154 }
155 bool matches(const DNSQuestion* dq) const override
156 {
157 if(dq->remote->sin4.sin_family == AF_INET) {
158 ReadLock rl(&d_lock4);
159 auto fnd = d_ip4s.find(dq->remote->sin4.sin_addr.s_addr);
160 if(fnd == d_ip4s.end()) {
161 return false;
162 }
163 return time(0) < fnd->second;
164 } else {
165 ReadLock rl(&d_lock6);
166 auto fnd = d_ip6s.find({*dq->remote});
167 if(fnd == d_ip6s.end()) {
168 return false;
169 }
170 return time(0) < fnd->second;
171 }
172 }
173
174 void add(const ComboAddress& ca, time_t ttd)
175 {
176 // think twice before adding templates here
177 if(ca.sin4.sin_family == AF_INET) {
178 WriteLock rl(&d_lock4);
179 auto res=d_ip4s.insert({ca.sin4.sin_addr.s_addr, ttd});
180 if(!res.second && (time_t)res.first->second < ttd)
181 res.first->second = (uint32_t)ttd;
182 }
183 else {
184 WriteLock rl(&d_lock6);
185 auto res=d_ip6s.insert({{ca}, ttd});
186 if(!res.second && (time_t)res.first->second < ttd)
187 res.first->second = (uint32_t)ttd;
188 }
189 }
190
191 void remove(const ComboAddress& ca)
192 {
193 if(ca.sin4.sin_family == AF_INET) {
194 WriteLock rl(&d_lock4);
195 d_ip4s.erase(ca.sin4.sin_addr.s_addr);
196 }
197 else {
198 WriteLock rl(&d_lock6);
199 d_ip6s.erase({ca});
200 }
201 }
202
203 void clear()
204 {
205 {
206 WriteLock rl(&d_lock4);
207 d_ip4s.clear();
208 }
209 WriteLock rl(&d_lock6);
210 d_ip6s.clear();
211 }
212
213 void cleanup()
214 {
215 time_t now=time(0);
216 {
217 WriteLock rl(&d_lock4);
218
219 for(auto iter = d_ip4s.begin(); iter != d_ip4s.end(); ) {
220 if(iter->second < now)
221 iter=d_ip4s.erase(iter);
222 else
223 ++iter;
224 }
225
226 }
227
228 {
229 WriteLock rl(&d_lock6);
230
231 for(auto iter = d_ip6s.begin(); iter != d_ip6s.end(); ) {
232 if(iter->second < now)
233 iter=d_ip6s.erase(iter);
234 else
235 ++iter;
236 }
237
238 }
239
240 }
241
242 string toString() const override
243 {
244 time_t now=time(0);
245 uint64_t count = 0;
246 {
247 ReadLock rl(&d_lock4);
248 for(const auto& ip : d_ip4s)
249 if(now < ip.second)
250 ++count;
251 }
252 {
253 ReadLock rl(&d_lock6);
254 for(const auto& ip : d_ip6s)
255 if(now < ip.second)
256 ++count;
257 }
258
259 return "Src: "+std::to_string(count)+" ips";
260 }
261private:
262 struct IPv6Hash
263 {
264 std::size_t operator()(const IPv6& ip) const
265 {
266 auto ah=std::hash<uint64_t>{}(ip.a);
267 auto bh=std::hash<uint64_t>{}(ip.b);
268 return ah & (bh<<1);
269 }
270 };
271 std::unordered_map<IPv6, time_t, IPv6Hash> d_ip6s;
272 std::unordered_map<uint32_t, time_t> d_ip4s;
273 mutable pthread_rwlock_t d_lock4;
274 mutable pthread_rwlock_t d_lock6;
275};
276
277
278class AllRule : public DNSRule
279{
280public:
281 AllRule() {}
282 bool matches(const DNSQuestion* dq) const override
283 {
284 return true;
285 }
286
287 string toString() const override
288 {
289 return "All";
290 }
291
292};
293
294
295class DNSSECRule : public DNSRule
296{
297public:
298 DNSSECRule()
299 {
300
301 }
302 bool matches(const DNSQuestion* dq) const override
303 {
304 return dq->dh->cd || (getEDNSZ((const char*)dq->dh, dq->len) & EDNS_HEADER_FLAG_DO); // turns out dig sets ad by default..
305 }
306
307 string toString() const override
308 {
309 return "DNSSEC";
310 }
311};
312
313class AndRule : public DNSRule
314{
315public:
316 AndRule(const vector<pair<int, shared_ptr<DNSRule> > >& rules)
317 {
318 for(const auto& r : rules)
319 d_rules.push_back(r.second);
320 }
321
322 bool matches(const DNSQuestion* dq) const override
323 {
324 auto iter = d_rules.begin();
325 for(; iter != d_rules.end(); ++iter)
326 if(!(*iter)->matches(dq))
327 break;
328 return iter == d_rules.end();
329 }
330
331 string toString() const override
332 {
333 string ret;
334 for(const auto& rule : d_rules) {
335 if(!ret.empty())
336 ret+= " && ";
337 ret += "("+ rule->toString()+")";
338 }
339 return ret;
340 }
341private:
342
343 vector<std::shared_ptr<DNSRule> > d_rules;
344
345};
346
347
348class OrRule : public DNSRule
349{
350public:
351 OrRule(const vector<pair<int, shared_ptr<DNSRule> > >& rules)
352 {
353 for(const auto& r : rules)
354 d_rules.push_back(r.second);
355 }
356
357 bool matches(const DNSQuestion* dq) const override
358 {
359 auto iter = d_rules.begin();
360 for(; iter != d_rules.end(); ++iter)
361 if((*iter)->matches(dq))
362 return true;
363 return false;
364 }
365
366 string toString() const override
367 {
368 string ret;
369 for(const auto& rule : d_rules) {
370 if(!ret.empty())
371 ret+= " || ";
372 ret += "("+ rule->toString()+")";
373 }
374 return ret;
375 }
376private:
377
378 vector<std::shared_ptr<DNSRule> > d_rules;
379
380};
381
382
383class RegexRule : public DNSRule
384{
385public:
386 RegexRule(const std::string& regex) : d_regex(regex), d_visual(regex)
387 {
388
389 }
390 bool matches(const DNSQuestion* dq) const override
391 {
392 return d_regex.match(dq->qname->toStringNoDot());
393 }
394
395 string toString() const override
396 {
397 return "Regex: "+d_visual;
398 }
399private:
400 Regex d_regex;
401 string d_visual;
402};
403
#ifdef HAVE_RE2
#include <re2/re2.h>
// Matches when the whole qname (rendered without trailing dot) is a full
// RE2 match for the expression, compiled with RE2::Latin1 encoding.
class RE2Rule : public DNSRule
{
public:
  RE2Rule(const std::string& re2) : d_re2(re2, RE2::Latin1), d_visual(re2)
  {
  }

  bool matches(const DNSQuestion* dq) const override
  {
    const std::string name = dq->qname->toStringNoDot();
    return RE2::FullMatch(name, d_re2);
  }

  string toString() const override
  {
    return string("RE2 match: ") + d_visual;
  }

private:
  RE2 d_re2;
  string d_visual;
};
#endif
427
428
429class SuffixMatchNodeRule : public DNSRule
430{
431public:
432 SuffixMatchNodeRule(const SuffixMatchNode& smn, bool quiet=false) : d_smn(smn), d_quiet(quiet)
433 {
434 }
435 bool matches(const DNSQuestion* dq) const override
436 {
437 return d_smn.check(*dq->qname);
438 }
439 string toString() const override
440 {
441 if(d_quiet)
442 return "qname==in-set";
443 else
444 return "qname in "+d_smn.toString();
445 }
446private:
447 SuffixMatchNode d_smn;
448 bool d_quiet;
449};
450
451class QNameRule : public DNSRule
452{
453public:
454 QNameRule(const DNSName& qname) : d_qname(qname)
455 {
456 }
457 bool matches(const DNSQuestion* dq) const override
458 {
459 return d_qname==*dq->qname;
460 }
461 string toString() const override
462 {
463 return "qname=="+d_qname.toString();
464 }
465private:
466 DNSName d_qname;
467};
468
469
470class QTypeRule : public DNSRule
471{
472public:
473 QTypeRule(uint16_t qtype) : d_qtype(qtype)
474 {
475 }
476 bool matches(const DNSQuestion* dq) const override
477 {
478 return d_qtype == dq->qtype;
479 }
480 string toString() const override
481 {
482 QType qt(d_qtype);
483 return "qtype=="+qt.getName();
484 }
485private:
486 uint16_t d_qtype;
487};
488
489class QClassRule : public DNSRule
490{
491public:
492 QClassRule(uint16_t qclass) : d_qclass(qclass)
493 {
494 }
495 bool matches(const DNSQuestion* dq) const override
496 {
497 return d_qclass == dq->qclass;
498 }
499 string toString() const override
500 {
501 return "qclass=="+std::to_string(d_qclass);
502 }
503private:
504 uint16_t d_qclass;
505};
506
507class OpcodeRule : public DNSRule
508{
509public:
510 OpcodeRule(uint8_t opcode) : d_opcode(opcode)
511 {
512 }
513 bool matches(const DNSQuestion* dq) const override
514 {
515 return d_opcode == dq->dh->opcode;
516 }
517 string toString() const override
518 {
519 return "opcode=="+std::to_string(d_opcode);
520 }
521private:
522 uint8_t d_opcode;
523};
524
525class TCPRule : public DNSRule
526{
527public:
528 TCPRule(bool tcp): d_tcp(tcp)
529 {
530 }
531 bool matches(const DNSQuestion* dq) const override
532 {
533 return dq->tcp == d_tcp;
534 }
535 string toString() const override
536 {
537 return (d_tcp ? "TCP" : "UDP");
538 }
539private:
540 bool d_tcp;
541};
542
543
544class NotRule : public DNSRule
545{
546public:
547 NotRule(shared_ptr<DNSRule>& rule): d_rule(rule)
548 {
549 }
550 bool matches(const DNSQuestion* dq) const override
551 {
552 return !d_rule->matches(dq);
553 }
554 string toString() const override
555 {
556 return "!("+ d_rule->toString()+")";
557 }
558private:
559 shared_ptr<DNSRule> d_rule;
560};
561
562class RecordsCountRule : public DNSRule
563{
564public:
565 RecordsCountRule(uint8_t section, uint16_t minCount, uint16_t maxCount): d_minCount(minCount), d_maxCount(maxCount), d_section(section)
566 {
567 }
568 bool matches(const DNSQuestion* dq) const override
569 {
570 uint16_t count = 0;
571 switch(d_section) {
572 case 0:
573 count = ntohs(dq->dh->qdcount);
574 break;
575 case 1:
576 count = ntohs(dq->dh->ancount);
577 break;
578 case 2:
579 count = ntohs(dq->dh->nscount);
580 break;
581 case 3:
582 count = ntohs(dq->dh->arcount);
583 break;
584 }
585 return count >= d_minCount && count <= d_maxCount;
586 }
587 string toString() const override
588 {
589 string section;
590 switch(d_section) {
591 case 0:
592 section = "QD";
593 break;
594 case 1:
595 section = "AN";
596 break;
597 case 2:
598 section = "NS";
599 break;
600 case 3:
601 section = "AR";
602 break;
603 }
604 return std::to_string(d_minCount) + " <= records in " + section + " <= "+ std::to_string(d_maxCount);
605 }
606private:
607 uint16_t d_minCount;
608 uint16_t d_maxCount;
609 uint8_t d_section;
610};
611
612class RecordsTypeCountRule : public DNSRule
613{
614public:
615 RecordsTypeCountRule(uint8_t section, uint16_t type, uint16_t minCount, uint16_t maxCount): d_type(type), d_minCount(minCount), d_maxCount(maxCount), d_section(section)
616 {
617 }
618 bool matches(const DNSQuestion* dq) const override
619 {
620 uint16_t count = 0;
621 switch(d_section) {
622 case 0:
623 count = ntohs(dq->dh->qdcount);
624 break;
625 case 1:
626 count = ntohs(dq->dh->ancount);
627 break;
628 case 2:
629 count = ntohs(dq->dh->nscount);
630 break;
631 case 3:
632 count = ntohs(dq->dh->arcount);
633 break;
634 }
635 if (count < d_minCount) {
636 return false;
637 }
638 count = getRecordsOfTypeCount(reinterpret_cast<const char*>(dq->dh), dq->len, d_section, d_type);
639 return count >= d_minCount && count <= d_maxCount;
640 }
641 string toString() const override
642 {
643 string section;
644 switch(d_section) {
645 case 0:
646 section = "QD";
647 break;
648 case 1:
649 section = "AN";
650 break;
651 case 2:
652 section = "NS";
653 break;
654 case 3:
655 section = "AR";
656 break;
657 }
658 return std::to_string(d_minCount) + " <= " + QType(d_type).getName() + " records in " + section + " <= "+ std::to_string(d_maxCount);
659 }
660private:
661 uint16_t d_type;
662 uint16_t d_minCount;
663 uint16_t d_maxCount;
664 uint8_t d_section;
665};
666
667class TrailingDataRule : public DNSRule
668{
669public:
670 TrailingDataRule()
671 {
672 }
673 bool matches(const DNSQuestion* dq) const override
674 {
675 uint16_t length = getDNSPacketLength(reinterpret_cast<const char*>(dq->dh), dq->len);
676 return length < dq->len;
677 }
678 string toString() const override
679 {
680 return "trailing data";
681 }
682};
683
684class QNameLabelsCountRule : public DNSRule
685{
686public:
687 QNameLabelsCountRule(unsigned int minLabelsCount, unsigned int maxLabelsCount): d_min(minLabelsCount), d_max(maxLabelsCount)
688 {
689 }
690 bool matches(const DNSQuestion* dq) const override
691 {
692 unsigned int count = dq->qname->countLabels();
693 return count < d_min || count > d_max;
694 }
695 string toString() const override
696 {
697 return "labels count < " + std::to_string(d_min) + " || labels count > " + std::to_string(d_max);
698 }
699private:
700 unsigned int d_min;
701 unsigned int d_max;
702};
703
704class QNameWireLengthRule : public DNSRule
705{
706public:
707 QNameWireLengthRule(size_t min, size_t max): d_min(min), d_max(max)
708 {
709 }
710 bool matches(const DNSQuestion* dq) const override
711 {
712 size_t const wirelength = dq->qname->wirelength();
713 return wirelength < d_min || wirelength > d_max;
714 }
715 string toString() const override
716 {
717 return "wire length < " + std::to_string(d_min) + " || wire length > " + std::to_string(d_max);
718 }
719private:
720 size_t d_min;
721 size_t d_max;
722};
723
724class RCodeRule : public DNSRule
725{
726public:
f6007449 727 RCodeRule(uint8_t rcode) : d_rcode(rcode)
6bb38cd6
RG
728 {
729 }
730 bool matches(const DNSQuestion* dq) const override
731 {
732 return d_rcode == dq->dh->rcode;
733 }
734 string toString() const override
735 {
736 return "rcode=="+RCode::to_s(d_rcode);
737 }
738private:
f6007449 739 uint8_t d_rcode;
6bb38cd6
RG
740};
741
d83feb68
CH
742class ERCodeRule : public DNSRule
743{
744public:
f6007449 745 ERCodeRule(uint8_t rcode) : d_rcode(rcode & 0xF), d_extrcode(rcode >> 4)
d83feb68
CH
746 {
747 }
748 bool matches(const DNSQuestion* dq) const override
749 {
750 // avoid parsing EDNS OPT RR when not needed.
751 if (d_rcode != dq->dh->rcode) {
752 return false;
753 }
754
755 char * optStart = NULL;
756 size_t optLen = 0;
757 bool last = false;
85db7e3c 758 int res = locateEDNSOptRR(const_cast<char*>(reinterpret_cast<const char*>(dq->dh)), dq->len, &optStart, &optLen, &last);
d83feb68
CH
759 if (res != 0) {
760 // no EDNS OPT RR
761 return d_extrcode == 0;
762 }
763
f865c0b0 764 // root label (1), type (2), class (2), ttl (4) + rdlen (2)
d83feb68
CH
765 if (optLen < 11) {
766 return false;
767 }
768
769 if (*optStart != 0) {
770 // OPT RR Name != '.'
771 return false;
772 }
773 EDNS0Record edns0;
774 static_assert(sizeof(EDNS0Record) == sizeof(uint32_t), "sizeof(EDNS0Record) must match sizeof(uint32_t) AKA RR TTL size");
f865c0b0 775 // copy out 4-byte "ttl" (really the EDNS0 record), after root label (1) + type (2) + class (2).
d83feb68
CH
776 memcpy(&edns0, optStart + 5, sizeof edns0);
777
778 return d_extrcode == edns0.extRCode;
779 }
780 string toString() const override
781 {
782 return "ercode=="+ERCode::to_s(d_rcode | (d_extrcode << 4));
783 }
784private:
f6007449
CH
785 uint8_t d_rcode; // plain DNS Rcode
786 uint8_t d_extrcode; // upper bits in EDNS0 record
d83feb68
CH
787};
788
6bb38cd6
RG
789class RDRule : public DNSRule
790{
791public:
792 RDRule()
793 {
794 }
795 bool matches(const DNSQuestion* dq) const override
796 {
797 return dq->dh->rd == 1;
798 }
799 string toString() const override
800 {
801 return "rd==1";
802 }
803};
804
805class ProbaRule : public DNSRule
806{
807public:
808 ProbaRule(double proba) : d_proba(proba)
809 {
810 }
811 bool matches(const DNSQuestion* dq) const override
812 {
813 if(d_proba == 1.0)
814 return true;
815 double rnd = 1.0*random() / RAND_MAX;
816 return rnd > (1.0 - d_proba);
817 }
818 string toString() const override
819 {
820 return "match with prob. " + (boost::format("%0.2f") % d_proba).str();
821 }
822private:
823 double d_proba;
824};
825
826class TagRule : public DNSRule
827{
828public:
829 TagRule(std::string tag, boost::optional<std::string> value) : d_value(value), d_tag(tag)
830 {
831 }
832 bool matches(const DNSQuestion* dq) const override
833 {
834 if (dq->qTag == nullptr) {
835 return false;
836 }
837
838 const auto got = dq->qTag->tagData.find(d_tag);
839 if (got == dq->qTag->tagData.cend()) {
840 return false;
841 }
842
843 if (!d_value) {
844 return true;
845 }
846
847 return got->second == *d_value;
848 }
849
850 string toString() const override
851 {
852 return "tag '" + d_tag + "' is set" + (d_value ? (" to '" + *d_value + "'") : "");
853 }
854
855private:
856 boost::optional<std::string> d_value;
857 std::string d_tag;
858};
859
860std::shared_ptr<DNSRule> makeRule(const luadnsrule_t& var)
861{
862 if (var.type() == typeid(std::shared_ptr<DNSRule>))
863 return *boost::get<std::shared_ptr<DNSRule>>(&var);
864
865 SuffixMatchNode smn;
866 NetmaskGroup nmg;
867 auto add=[&](string src) {
868 try {
869 nmg.addMask(src); // need to try mask first, all masks are domain names!
870 } catch(...) {
871 smn.add(DNSName(src));
872 }
873 };
874
875 if (var.type() == typeid(string))
876 add(*boost::get<string>(&var));
877
878 else if (var.type() == typeid(vector<pair<int, string>>))
879 for(const auto& a : *boost::get<vector<pair<int, string>>>(&var))
880 add(a.second);
881
882 else if (var.type() == typeid(DNSName))
883 smn.add(*boost::get<DNSName>(&var));
884
885 else if (var.type() == typeid(vector<pair<int, DNSName>>))
886 for(const auto& a : *boost::get<vector<pair<int, DNSName>>>(&var))
887 smn.add(a.second);
888
889 if(nmg.empty())
890 return std::make_shared<SuffixMatchNodeRule>(smn);
891 else
892 return std::make_shared<NetmaskGroupRule>(nmg, true);
893}
894
6a510eff 895static boost::uuids::uuid makeRuleID(std::string& id)
4d5959e6
RG
896{
897 if (id.empty()) {
898 return t_uuidGenerator();
899 }
900
901 boost::uuids::string_generator gen;
902 return gen(id);
903}
904
905void parseRuleParams(boost::optional<luaruleparams_t> params, boost::uuids::uuid& uuid)
906{
907 string uuidStr;
908
909 if (params) {
910 if (params->count("uuid")) {
911 uuidStr = boost::get<std::string>((*params)["uuid"]);
912 }
913 }
914
6a510eff 915 uuid = makeRuleID(uuidStr);
4d5959e6
RG
916}
917
6bb38cd6
RG
918void setupLuaRules()
919{
920 g_lua.writeFunction("makeRule", makeRule);
921
922 g_lua.registerFunction<string(std::shared_ptr<DNSRule>::*)()>("toString", [](const std::shared_ptr<DNSRule>& rule) { return rule->toString(); });
923
4d5959e6 924 g_lua.writeFunction("showResponseRules", [](boost::optional<bool> showUUIDs) {
6bb38cd6 925 setLuaNoSideEffect();
6bb38cd6 926 int num=0;
4d5959e6
RG
927 if (showUUIDs.get_value_or(false)) {
928 boost::format fmt("%-3d %-38s %9d %-50s %s\n");
929 g_outputBuffer += (fmt % "#" % "UUID" % "Matches" % "Rule" % "Action").str();
930 for(const auto& lim : g_resprulactions.getCopy()) {
931 string name = lim.d_rule->toString();
932 g_outputBuffer += (fmt % num % boost::uuids::to_string(lim.d_id) % lim.d_rule->d_matches % name % lim.d_action->toString()).str();
933 ++num;
934 }
935 }
936 else {
937 boost::format fmt("%-3d %9d %-50s %s\n");
938 g_outputBuffer += (fmt % "#" % "Matches" % "Rule" % "Action").str();
939 for(const auto& lim : g_resprulactions.getCopy()) {
940 string name = lim.d_rule->toString();
941 g_outputBuffer += (fmt % num % lim.d_rule->d_matches % name % lim.d_action->toString()).str();
942 ++num;
943 }
6bb38cd6
RG
944 }
945 });
946
7762339e 947 g_lua.writeFunction("rmResponseRule", [](boost::variant<unsigned int, std::string> id) {
6bb38cd6
RG
948 setLuaSideEffect();
949 auto rules = g_resprulactions.getCopy();
7762339e 950 if (auto str = boost::get<std::string>(&id)) {
4d5959e6
RG
951 boost::uuids::string_generator gen;
952 const auto uuid = gen(*str);
953 rules.erase(std::remove_if(rules.begin(),
954 rules.end(),
955 [uuid](const DNSDistResponseRuleAction& a) { return a.d_id == uuid; }),
956 rules.end());
957 }
7762339e 958 else if (auto pos = boost::get<unsigned int>(&id)) {
4d5959e6
RG
959 if (*pos >= rules.size()) {
960 g_outputBuffer = "Error: attempt to delete non-existing rule\n";
961 return;
962 }
963 rules.erase(rules.begin()+*pos);
6bb38cd6 964 }
6bb38cd6
RG
965 g_resprulactions.setState(rules);
966 });
967
968 g_lua.writeFunction("topResponseRule", []() {
969 setLuaSideEffect();
970 auto rules = g_resprulactions.getCopy();
971 if(rules.empty())
972 return;
973 auto subject = *rules.rbegin();
974 rules.erase(std::prev(rules.end()));
975 rules.insert(rules.begin(), subject);
976 g_resprulactions.setState(rules);
977 });
978
979 g_lua.writeFunction("mvResponseRule", [](unsigned int from, unsigned int to) {
980 setLuaSideEffect();
981 auto rules = g_resprulactions.getCopy();
982 if(from >= rules.size() || to > rules.size()) {
983 g_outputBuffer = "Error: attempt to move rules from/to invalid index\n";
984 return;
985 }
986 auto subject = rules[from];
987 rules.erase(rules.begin()+from);
988 if(to == rules.size())
989 rules.push_back(subject);
990 else {
991 if(from < to)
992 --to;
993 rules.insert(rules.begin()+to, subject);
994 }
995 g_resprulactions.setState(rules);
996 });
997
4d5959e6 998 g_lua.writeFunction("showCacheHitResponseRules", [](boost::optional<bool> showUUIDs) {
6bb38cd6 999 setLuaNoSideEffect();
6bb38cd6 1000 int num=0;
4d5959e6
RG
1001 if (showUUIDs.get_value_or(false)) {
1002 boost::format fmt("%-3d %-38s %9d %-50s %s\n");
1003 g_outputBuffer += (fmt % "#" % "UUID" % "Matches" % "Rule" % "Action").str();
1004 for(const auto& lim : g_cachehitresprulactions.getCopy()) {
1005 string name = lim.d_rule->toString();
1006 g_outputBuffer += (fmt % num % boost::uuids::to_string(lim.d_id) % lim.d_rule->d_matches % name % lim.d_action->toString()).str();
1007 ++num;
1008 }
1009 }
1010 else {
1011 boost::format fmt("%-3d %9d %-50s %s\n");
1012 g_outputBuffer += (fmt % "#" % "Matches" % "Rule" % "Action").str();
1013 for(const auto& lim : g_cachehitresprulactions.getCopy()) {
1014 string name = lim.d_rule->toString();
1015 g_outputBuffer += (fmt % num % lim.d_rule->d_matches % name % lim.d_action->toString()).str();
1016 ++num;
1017 }
6bb38cd6
RG
1018 }
1019 });
1020
7762339e 1021 g_lua.writeFunction("rmCacheHitResponseRule", [](boost::variant<unsigned int, std::string> id) {
6bb38cd6
RG
1022 setLuaSideEffect();
1023 auto rules = g_cachehitresprulactions.getCopy();
7762339e 1024 if (auto str = boost::get<std::string>(&id)) {
4d5959e6
RG
1025 boost::uuids::string_generator gen;
1026 const auto uuid = gen(*str);
1027 rules.erase(std::remove_if(rules.begin(),
1028 rules.end(),
1029 [uuid](const DNSDistResponseRuleAction& a) { return a.d_id == uuid; }),
1030 rules.end());
1031 }
7762339e 1032 else if (auto pos = boost::get<unsigned int>(&id)) {
4d5959e6
RG
1033 if (*pos >= rules.size()) {
1034 g_outputBuffer = "Error: attempt to delete non-existing rule\n";
1035 return;
1036 }
1037 rules.erase(rules.begin()+*pos);
6bb38cd6 1038 }
6bb38cd6
RG
1039 g_cachehitresprulactions.setState(rules);
1040 });
1041
1042 g_lua.writeFunction("topCacheHitResponseRule", []() {
1043 setLuaSideEffect();
1044 auto rules = g_cachehitresprulactions.getCopy();
1045 if(rules.empty())
1046 return;
1047 auto subject = *rules.rbegin();
1048 rules.erase(std::prev(rules.end()));
1049 rules.insert(rules.begin(), subject);
1050 g_cachehitresprulactions.setState(rules);
1051 });
1052
1053 g_lua.writeFunction("mvCacheHitResponseRule", [](unsigned int from, unsigned int to) {
1054 setLuaSideEffect();
1055 auto rules = g_cachehitresprulactions.getCopy();
1056 if(from >= rules.size() || to > rules.size()) {
1057 g_outputBuffer = "Error: attempt to move rules from/to invalid index\n";
1058 return;
1059 }
1060 auto subject = rules[from];
1061 rules.erase(rules.begin()+from);
1062 if(to == rules.size())
1063 rules.push_back(subject);
1064 else {
1065 if(from < to)
1066 --to;
1067 rules.insert(rules.begin()+to, subject);
1068 }
1069 g_cachehitresprulactions.setState(rules);
1070 });
1071
7762339e 1072 g_lua.writeFunction("rmRule", [](boost::variant<unsigned int, std::string> id) {
6bb38cd6
RG
1073 setLuaSideEffect();
1074 auto rules = g_rulactions.getCopy();
7762339e 1075 if (auto str = boost::get<std::string>(&id)) {
4d5959e6
RG
1076 boost::uuids::string_generator gen;
1077 const auto uuid = gen(*str);
1078 rules.erase(std::remove_if(rules.begin(),
1079 rules.end(),
1080 [uuid](const DNSDistRuleAction& a) { return a.d_id == uuid; }),
1081 rules.end());
1082 }
7762339e 1083 else if (auto pos = boost::get<unsigned int>(&id)) {
4d5959e6
RG
1084 if (*pos >= rules.size()) {
1085 g_outputBuffer = "Error: attempt to delete non-existing rule\n";
1086 return;
1087 }
1088 rules.erase(rules.begin()+*pos);
6bb38cd6 1089 }
6bb38cd6
RG
1090 g_rulactions.setState(rules);
1091 });
1092
1093 g_lua.writeFunction("topRule", []() {
1094 setLuaSideEffect();
1095 auto rules = g_rulactions.getCopy();
1096 if(rules.empty())
1097 return;
1098 auto subject = *rules.rbegin();
1099 rules.erase(std::prev(rules.end()));
1100 rules.insert(rules.begin(), subject);
1101 g_rulactions.setState(rules);
1102 });
1103
1104 g_lua.writeFunction("mvRule", [](unsigned int from, unsigned int to) {
1105 setLuaSideEffect();
1106 auto rules = g_rulactions.getCopy();
1107 if(from >= rules.size() || to > rules.size()) {
1108 g_outputBuffer = "Error: attempt to move rules from/to invalid index\n";
1109 return;
1110 }
1111
1112 auto subject = rules[from];
1113 rules.erase(rules.begin()+from);
1114 if(to == rules.size())
1115 rules.push_back(subject);
1116 else {
1117 if(from < to)
1118 --to;
1119 rules.insert(rules.begin()+to, subject);
1120 }
1121 g_rulactions.setState(rules);
1122 });
1123
1124 g_lua.writeFunction("clearRules", []() {
1125 setLuaSideEffect();
1126 g_rulactions.modify([](decltype(g_rulactions)::value_type& rulactions) {
1127 rulactions.clear();
1128 });
1129 });
1130
4d5959e6 1131 g_lua.writeFunction("setRules", [](std::vector<DNSDistRuleAction>& newruleactions) {
6bb38cd6
RG
1132 setLuaSideEffect();
1133 g_rulactions.modify([newruleactions](decltype(g_rulactions)::value_type& gruleactions) {
1134 gruleactions.clear();
1135 for (const auto& newruleaction : newruleactions) {
4d5959e6
RG
1136 if (newruleaction.d_action) {
1137 auto rule=makeRule(newruleaction.d_rule);
1138 gruleactions.push_back({rule, newruleaction.d_action, newruleaction.d_id});
6bb38cd6
RG
1139 }
1140 }
1141 });
1142 });
1143
1144 g_lua.writeFunction("MaxQPSIPRule", [](unsigned int qps, boost::optional<int> ipv4trunc, boost::optional<int> ipv6trunc, boost::optional<int> burst) {
1145 return std::shared_ptr<DNSRule>(new MaxQPSIPRule(qps, burst.get_value_or(qps), ipv4trunc.get_value_or(32), ipv6trunc.get_value_or(64)));
1146 });
1147
1148 g_lua.writeFunction("MaxQPSRule", [](unsigned int qps, boost::optional<int> burst) {
1149 if(!burst)
1150 return std::shared_ptr<DNSRule>(new MaxQPSRule(qps));
1151 else
1152 return std::shared_ptr<DNSRule>(new MaxQPSRule(qps, *burst));
1153 });
1154
1155 g_lua.writeFunction("RegexRule", [](const std::string& str) {
1156 return std::shared_ptr<DNSRule>(new RegexRule(str));
1157 });
1158
1159#ifdef HAVE_RE2
1160 g_lua.writeFunction("RE2Rule", [](const std::string& str) {
1161 return std::shared_ptr<DNSRule>(new RE2Rule(str));
1162 });
1163#endif
1164
1165 g_lua.writeFunction("SuffixMatchNodeRule", [](const SuffixMatchNode& smn, boost::optional<bool> quiet) {
1166 return std::shared_ptr<DNSRule>(new SuffixMatchNodeRule(smn, quiet ? *quiet : false));
1167 });
1168
1169 g_lua.writeFunction("NetmaskGroupRule", [](const NetmaskGroup& nmg, boost::optional<bool> src) {
1170 return std::shared_ptr<DNSRule>(new NetmaskGroupRule(nmg, src ? *src : true));
1171 });
1172
1173 g_lua.writeFunction("benchRule", [](std::shared_ptr<DNSRule> rule, boost::optional<int> times_, boost::optional<string> suffix_) {
1174 setLuaNoSideEffect();
1175 int times = times_.get_value_or(100000);
1176 DNSName suffix(suffix_.get_value_or("powerdns.com"));
1177 struct item {
1178 vector<uint8_t> packet;
1179 ComboAddress rem;
1180 DNSName qname;
1181 uint16_t qtype, qclass;
1182 };
1183 vector<item> items;
1184 items.reserve(1000);
1185 for(int n=0; n < 1000; ++n) {
1186 struct item i;
1187 i.qname=DNSName(std::to_string(random()));
1188 i.qname += suffix;
1189 i.qtype = random() % 0xff;
1190 i.qclass = 1;
1191 i.rem=ComboAddress("127.0.0.1");
1192 i.rem.sin4.sin_addr.s_addr = random();
1193 DNSPacketWriter pw(i.packet, i.qname, i.qtype);
1194 items.push_back(i);
1195 }
1196
1197 int matches=0;
1198 ComboAddress dummy("127.0.0.1");
1199 DTime dt;
1200 dt.set();
1201 for(int n=0; n < times; ++n) {
1202 const item& i = items[n % items.size()];
1203 DNSQuestion dq(&i.qname, i.qtype, i.qclass, &i.rem, &i.rem, (struct dnsheader*)&i.packet[0], i.packet.size(), i.packet.size(), false);
1204 if(rule->matches(&dq))
1205 matches++;
1206 }
1207 double udiff=dt.udiff();
1208 g_outputBuffer=(boost::format("Had %d matches out of %d, %.1f qps, in %.1f usec\n") % matches % times % (1000000*(1.0*times/udiff)) % udiff).str();
1209
1210 });
1211
1212 g_lua.writeFunction("AllRule", []() {
1213 return std::shared_ptr<DNSRule>(new AllRule());
1214 });
1215
1216 g_lua.writeFunction("ProbaRule", [](double proba) {
1217 return std::shared_ptr<DNSRule>(new ProbaRule(proba));
1218 });
1219
1220 g_lua.writeFunction("QNameRule", [](const std::string& qname) {
1221 return std::shared_ptr<DNSRule>(new QNameRule(DNSName(qname)));
1222 });
1223
1224 g_lua.writeFunction("QTypeRule", [](boost::variant<int, std::string> str) {
1225 uint16_t qtype;
1226 if(auto dir = boost::get<int>(&str)) {
1227 qtype = *dir;
1228 }
1229 else {
1230 string val=boost::get<string>(str);
1231 qtype = QType::chartocode(val.c_str());
1232 if(!qtype)
1233 throw std::runtime_error("Unable to convert '"+val+"' to a DNS type");
1234 }
1235 return std::shared_ptr<DNSRule>(new QTypeRule(qtype));
1236 });
1237
1238 g_lua.writeFunction("QClassRule", [](int c) {
1239 return std::shared_ptr<DNSRule>(new QClassRule(c));
1240 });
1241
1242 g_lua.writeFunction("OpcodeRule", [](uint8_t code) {
1243 return std::shared_ptr<DNSRule>(new OpcodeRule(code));
1244 });
1245
1246 g_lua.writeFunction("AndRule", [](vector<pair<int, std::shared_ptr<DNSRule> > >a) {
1247 return std::shared_ptr<DNSRule>(new AndRule(a));
1248 });
1249
1250 g_lua.writeFunction("OrRule", [](vector<pair<int, std::shared_ptr<DNSRule> > >a) {
1251 return std::shared_ptr<DNSRule>(new OrRule(a));
1252 });
1253
1254 g_lua.writeFunction("TCPRule", [](bool tcp) {
1255 return std::shared_ptr<DNSRule>(new TCPRule(tcp));
1256 });
1257
1258 g_lua.writeFunction("DNSSECRule", []() {
1259 return std::shared_ptr<DNSRule>(new DNSSECRule());
1260 });
1261
1262 g_lua.writeFunction("NotRule", [](std::shared_ptr<DNSRule>rule) {
1263 return std::shared_ptr<DNSRule>(new NotRule(rule));
1264 });
1265
1266 g_lua.writeFunction("RecordsCountRule", [](uint8_t section, uint16_t minCount, uint16_t maxCount) {
1267 return std::shared_ptr<DNSRule>(new RecordsCountRule(section, minCount, maxCount));
1268 });
1269
1270 g_lua.writeFunction("RecordsTypeCountRule", [](uint8_t section, uint16_t type, uint16_t minCount, uint16_t maxCount) {
1271 return std::shared_ptr<DNSRule>(new RecordsTypeCountRule(section, type, minCount, maxCount));
1272 });
1273
1274 g_lua.writeFunction("TrailingDataRule", []() {
1275 return std::shared_ptr<DNSRule>(new TrailingDataRule());
1276 });
1277
1278 g_lua.writeFunction("QNameLabelsCountRule", [](unsigned int minLabelsCount, unsigned int maxLabelsCount) {
1279 return std::shared_ptr<DNSRule>(new QNameLabelsCountRule(minLabelsCount, maxLabelsCount));
1280 });
1281
1282 g_lua.writeFunction("QNameWireLengthRule", [](size_t min, size_t max) {
1283 return std::shared_ptr<DNSRule>(new QNameWireLengthRule(min, max));
1284 });
1285
f6007449 1286 g_lua.writeFunction("RCodeRule", [](uint8_t rcode) {
6bb38cd6
RG
1287 return std::shared_ptr<DNSRule>(new RCodeRule(rcode));
1288 });
1289
f6007449 1290 g_lua.writeFunction("ERCodeRule", [](uint8_t rcode) {
d83feb68
CH
1291 return std::shared_ptr<DNSRule>(new ERCodeRule(rcode));
1292 });
1293
4d5959e6
RG
1294 g_lua.writeFunction("showRules", [](boost::optional<bool> showUUIDs) {
1295 setLuaNoSideEffect();
1296 int num=0;
1297 if (showUUIDs.get_value_or(false)) {
1298 boost::format fmt("%-3d %-38s %9d %-56s %s\n");
1299 g_outputBuffer += (fmt % "#" % "UUID" % "Matches" % "Rule" % "Action").str();
1300 for(const auto& lim : g_rulactions.getCopy()) {
1301 string name = lim.d_rule->toString();
1302 g_outputBuffer += (fmt % num % boost::uuids::to_string(lim.d_id) % lim.d_rule->d_matches % name % lim.d_action->toString()).str();
1303 ++num;
1304 }
1305 }
1306 else {
1307 boost::format fmt("%-3d %9d %-50s %s\n");
1308 g_outputBuffer += (fmt % "#" % "Matches" % "Rule" % "Action").str();
1309 for(const auto& lim : g_rulactions.getCopy()) {
1310 string name = lim.d_rule->toString();
1311 g_outputBuffer += (fmt % num % lim.d_rule->d_matches % name % lim.d_action->toString()).str();
1312 ++num;
1313 }
6bb38cd6
RG
1314 }
1315 });
1316
1317 g_lua.writeFunction("RDRule", []() {
1318 return std::shared_ptr<DNSRule>(new RDRule());
1319 });
1320
1321 g_lua.writeFunction("TagRule", [](std::string tag, boost::optional<std::string> value) {
1322 return std::shared_ptr<DNSRule>(new TagRule(tag, value));
1323 });
1324
1325 g_lua.writeFunction("TimedIPSetRule", []() {
1326 return std::shared_ptr<TimedIPSetRule>(new TimedIPSetRule());
1327 });
1328
1329 g_lua.registerFunction<void(std::shared_ptr<TimedIPSetRule>::*)()>("clear", [](std::shared_ptr<TimedIPSetRule> tisr) {
1330 tisr->clear();
1331 });
1332
1333 g_lua.registerFunction<void(std::shared_ptr<TimedIPSetRule>::*)()>("cleanup", [](std::shared_ptr<TimedIPSetRule> tisr) {
1334 tisr->cleanup();
1335 });
1336
1337 g_lua.registerFunction<void(std::shared_ptr<TimedIPSetRule>::*)(const ComboAddress& ca, int t)>("add", [](std::shared_ptr<TimedIPSetRule> tisr, const ComboAddress& ca, int t) {
1338 tisr->add(ca, time(0)+t);
1339 });
1340
1341 g_lua.registerFunction<std::shared_ptr<DNSRule>(std::shared_ptr<TimedIPSetRule>::*)()>("slice", [](std::shared_ptr<TimedIPSetRule> tisr) {
1342 return std::dynamic_pointer_cast<DNSRule>(tisr);
1343 });
1344}