]> git.ipfire.org Git - thirdparty/pdns.git/blob - pdns/dnsdist-lua-rules.cc
rm*Rule: rename num to id
[thirdparty/pdns.git] / pdns / dnsdist-lua-rules.cc
1 /*
2 * This file is part of PowerDNS or dnsdist.
3 * Copyright -- PowerDNS.COM B.V. and its contributors
4 *
5 * This program is free software; you can redistribute it and/or modify
6 * it under the terms of version 2 of the GNU General Public License as
7 * published by the Free Software Foundation.
8 *
9 * In addition, for the avoidance of any doubt, permission is granted to
10 * link this program with OpenSSL and to (re)distribute the binaries
11 * produced as the result of such linking.
12 *
13 * This program is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with this program; if not, write to the Free Software
20 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
21 */
22 #include "dnsdist.hh"
23 #include "dnsdist-ecs.hh"
24 #include "dnsdist-lua.hh"
25
26 #include "dnsparser.hh"
27
28 class MaxQPSIPRule : public DNSRule
29 {
30 public:
31 MaxQPSIPRule(unsigned int qps, unsigned int burst, unsigned int ipv4trunc=32, unsigned int ipv6trunc=64) :
32 d_qps(qps), d_burst(burst), d_ipv4trunc(ipv4trunc), d_ipv6trunc(ipv6trunc)
33 {
34 pthread_rwlock_init(&d_lock, 0);
35 }
36
37 bool matches(const DNSQuestion* dq) const override
38 {
39 ComboAddress zeroport(*dq->remote);
40 zeroport.sin4.sin_port=0;
41 zeroport.truncate(zeroport.sin4.sin_family == AF_INET ? d_ipv4trunc : d_ipv6trunc);
42 {
43 ReadLock r(&d_lock);
44 const auto iter = d_limits.find(zeroport);
45 if (iter != d_limits.end()) {
46 return !iter->second.check();
47 }
48 }
49 {
50 WriteLock w(&d_lock);
51 auto iter = d_limits.find(zeroport);
52 if(iter == d_limits.end()) {
53 iter=d_limits.insert({zeroport,QPSLimiter(d_qps, d_burst)}).first;
54 }
55 return !iter->second.check();
56 }
57 }
58
59 string toString() const override
60 {
61 return "IP (/"+std::to_string(d_ipv4trunc)+", /"+std::to_string(d_ipv6trunc)+") match for QPS over " + std::to_string(d_qps) + " burst "+ std::to_string(d_burst);
62 }
63
64
65 private:
66 mutable pthread_rwlock_t d_lock;
67 mutable std::map<ComboAddress, QPSLimiter> d_limits;
68 unsigned int d_qps, d_burst, d_ipv4trunc, d_ipv6trunc;
69
70 };
71
72 class MaxQPSRule : public DNSRule
73 {
74 public:
75 MaxQPSRule(unsigned int qps)
76 : d_qps(qps, qps)
77 {}
78
79 MaxQPSRule(unsigned int qps, unsigned int burst)
80 : d_qps(qps, burst)
81 {}
82
83
84 bool matches(const DNSQuestion* qd) const override
85 {
86 return d_qps.check();
87 }
88
89 string toString() const override
90 {
91 return "Max " + std::to_string(d_qps.getRate()) + " qps";
92 }
93
94
95 private:
96 mutable QPSLimiter d_qps;
97 };
98
99 class NMGRule : public DNSRule
100 {
101 public:
102 NMGRule(const NetmaskGroup& nmg) : d_nmg(nmg) {}
103 protected:
104 NetmaskGroup d_nmg;
105 };
106
107 class NetmaskGroupRule : public NMGRule
108 {
109 public:
110 NetmaskGroupRule(const NetmaskGroup& nmg, bool src) : NMGRule(nmg)
111 {
112 d_src = src;
113 }
114 bool matches(const DNSQuestion* dq) const override
115 {
116 if(!d_src) {
117 return d_nmg.match(*dq->local);
118 }
119 return d_nmg.match(*dq->remote);
120 }
121
122 string toString() const override
123 {
124 if(!d_src) {
125 return "Dst: "+d_nmg.toString();
126 }
127 return "Src: "+d_nmg.toString();
128 }
129 private:
130 bool d_src;
131 };
132
133 class TimedIPSetRule : public DNSRule, boost::noncopyable
134 {
135 private:
136 struct IPv6 {
137 IPv6(const ComboAddress& ca)
138 {
139 static_assert(sizeof(*this)==16, "IPv6 struct has wrong size");
140 memcpy((char*)this, ca.sin6.sin6_addr.s6_addr, 16);
141 }
142 bool operator==(const IPv6& rhs) const
143 {
144 return a==rhs.a && b==rhs.b;
145 }
146 uint64_t a, b;
147 };
148
149 public:
150 TimedIPSetRule()
151 {
152 pthread_rwlock_init(&d_lock4, 0);
153 pthread_rwlock_init(&d_lock6, 0);
154 }
155 bool matches(const DNSQuestion* dq) const override
156 {
157 if(dq->remote->sin4.sin_family == AF_INET) {
158 ReadLock rl(&d_lock4);
159 auto fnd = d_ip4s.find(dq->remote->sin4.sin_addr.s_addr);
160 if(fnd == d_ip4s.end()) {
161 return false;
162 }
163 return time(0) < fnd->second;
164 } else {
165 ReadLock rl(&d_lock6);
166 auto fnd = d_ip6s.find({*dq->remote});
167 if(fnd == d_ip6s.end()) {
168 return false;
169 }
170 return time(0) < fnd->second;
171 }
172 }
173
174 void add(const ComboAddress& ca, time_t ttd)
175 {
176 // think twice before adding templates here
177 if(ca.sin4.sin_family == AF_INET) {
178 WriteLock rl(&d_lock4);
179 auto res=d_ip4s.insert({ca.sin4.sin_addr.s_addr, ttd});
180 if(!res.second && (time_t)res.first->second < ttd)
181 res.first->second = (uint32_t)ttd;
182 }
183 else {
184 WriteLock rl(&d_lock6);
185 auto res=d_ip6s.insert({{ca}, ttd});
186 if(!res.second && (time_t)res.first->second < ttd)
187 res.first->second = (uint32_t)ttd;
188 }
189 }
190
191 void remove(const ComboAddress& ca)
192 {
193 if(ca.sin4.sin_family == AF_INET) {
194 WriteLock rl(&d_lock4);
195 d_ip4s.erase(ca.sin4.sin_addr.s_addr);
196 }
197 else {
198 WriteLock rl(&d_lock6);
199 d_ip6s.erase({ca});
200 }
201 }
202
203 void clear()
204 {
205 {
206 WriteLock rl(&d_lock4);
207 d_ip4s.clear();
208 }
209 WriteLock rl(&d_lock6);
210 d_ip6s.clear();
211 }
212
213 void cleanup()
214 {
215 time_t now=time(0);
216 {
217 WriteLock rl(&d_lock4);
218
219 for(auto iter = d_ip4s.begin(); iter != d_ip4s.end(); ) {
220 if(iter->second < now)
221 iter=d_ip4s.erase(iter);
222 else
223 ++iter;
224 }
225
226 }
227
228 {
229 WriteLock rl(&d_lock6);
230
231 for(auto iter = d_ip6s.begin(); iter != d_ip6s.end(); ) {
232 if(iter->second < now)
233 iter=d_ip6s.erase(iter);
234 else
235 ++iter;
236 }
237
238 }
239
240 }
241
242 string toString() const override
243 {
244 time_t now=time(0);
245 uint64_t count = 0;
246 {
247 ReadLock rl(&d_lock4);
248 for(const auto& ip : d_ip4s)
249 if(now < ip.second)
250 ++count;
251 }
252 {
253 ReadLock rl(&d_lock6);
254 for(const auto& ip : d_ip6s)
255 if(now < ip.second)
256 ++count;
257 }
258
259 return "Src: "+std::to_string(count)+" ips";
260 }
261 private:
262 struct IPv6Hash
263 {
264 std::size_t operator()(const IPv6& ip) const
265 {
266 auto ah=std::hash<uint64_t>{}(ip.a);
267 auto bh=std::hash<uint64_t>{}(ip.b);
268 return ah & (bh<<1);
269 }
270 };
271 std::unordered_map<IPv6, time_t, IPv6Hash> d_ip6s;
272 std::unordered_map<uint32_t, time_t> d_ip4s;
273 mutable pthread_rwlock_t d_lock4;
274 mutable pthread_rwlock_t d_lock6;
275 };
276
277
278 class AllRule : public DNSRule
279 {
280 public:
281 AllRule() {}
282 bool matches(const DNSQuestion* dq) const override
283 {
284 return true;
285 }
286
287 string toString() const override
288 {
289 return "All";
290 }
291
292 };
293
294
295 class DNSSECRule : public DNSRule
296 {
297 public:
298 DNSSECRule()
299 {
300
301 }
302 bool matches(const DNSQuestion* dq) const override
303 {
304 return dq->dh->cd || (getEDNSZ((const char*)dq->dh, dq->len) & EDNS_HEADER_FLAG_DO); // turns out dig sets ad by default..
305 }
306
307 string toString() const override
308 {
309 return "DNSSEC";
310 }
311 };
312
313 class AndRule : public DNSRule
314 {
315 public:
316 AndRule(const vector<pair<int, shared_ptr<DNSRule> > >& rules)
317 {
318 for(const auto& r : rules)
319 d_rules.push_back(r.second);
320 }
321
322 bool matches(const DNSQuestion* dq) const override
323 {
324 auto iter = d_rules.begin();
325 for(; iter != d_rules.end(); ++iter)
326 if(!(*iter)->matches(dq))
327 break;
328 return iter == d_rules.end();
329 }
330
331 string toString() const override
332 {
333 string ret;
334 for(const auto& rule : d_rules) {
335 if(!ret.empty())
336 ret+= " && ";
337 ret += "("+ rule->toString()+")";
338 }
339 return ret;
340 }
341 private:
342
343 vector<std::shared_ptr<DNSRule> > d_rules;
344
345 };
346
347
348 class OrRule : public DNSRule
349 {
350 public:
351 OrRule(const vector<pair<int, shared_ptr<DNSRule> > >& rules)
352 {
353 for(const auto& r : rules)
354 d_rules.push_back(r.second);
355 }
356
357 bool matches(const DNSQuestion* dq) const override
358 {
359 auto iter = d_rules.begin();
360 for(; iter != d_rules.end(); ++iter)
361 if((*iter)->matches(dq))
362 return true;
363 return false;
364 }
365
366 string toString() const override
367 {
368 string ret;
369 for(const auto& rule : d_rules) {
370 if(!ret.empty())
371 ret+= " || ";
372 ret += "("+ rule->toString()+")";
373 }
374 return ret;
375 }
376 private:
377
378 vector<std::shared_ptr<DNSRule> > d_rules;
379
380 };
381
382
383 class RegexRule : public DNSRule
384 {
385 public:
386 RegexRule(const std::string& regex) : d_regex(regex), d_visual(regex)
387 {
388
389 }
390 bool matches(const DNSQuestion* dq) const override
391 {
392 return d_regex.match(dq->qname->toStringNoDot());
393 }
394
395 string toString() const override
396 {
397 return "Regex: "+d_visual;
398 }
399 private:
400 Regex d_regex;
401 string d_visual;
402 };
403
#ifdef HAVE_RE2
#include <re2/re2.h>
// Matches when the RE2 pattern (Latin-1 mode) fully matches the query name without its trailing dot.
class RE2Rule : public DNSRule
{
public:
  RE2Rule(const std::string& re2) :
    d_re2(re2, RE2::Latin1),
    d_visual(re2)
  {
  }

  bool matches(const DNSQuestion* dq) const override
  {
    return RE2::FullMatch(dq->qname->toStringNoDot(), d_re2);
  }

  string toString() const override
  {
    return "RE2 match: " + d_visual;
  }

private:
  RE2 d_re2;
  // Original pattern text, kept for display only.
  string d_visual;
};
#endif
427
428
429 class SuffixMatchNodeRule : public DNSRule
430 {
431 public:
432 SuffixMatchNodeRule(const SuffixMatchNode& smn, bool quiet=false) : d_smn(smn), d_quiet(quiet)
433 {
434 }
435 bool matches(const DNSQuestion* dq) const override
436 {
437 return d_smn.check(*dq->qname);
438 }
439 string toString() const override
440 {
441 if(d_quiet)
442 return "qname==in-set";
443 else
444 return "qname in "+d_smn.toString();
445 }
446 private:
447 SuffixMatchNode d_smn;
448 bool d_quiet;
449 };
450
451 class QNameRule : public DNSRule
452 {
453 public:
454 QNameRule(const DNSName& qname) : d_qname(qname)
455 {
456 }
457 bool matches(const DNSQuestion* dq) const override
458 {
459 return d_qname==*dq->qname;
460 }
461 string toString() const override
462 {
463 return "qname=="+d_qname.toString();
464 }
465 private:
466 DNSName d_qname;
467 };
468
469
470 class QTypeRule : public DNSRule
471 {
472 public:
473 QTypeRule(uint16_t qtype) : d_qtype(qtype)
474 {
475 }
476 bool matches(const DNSQuestion* dq) const override
477 {
478 return d_qtype == dq->qtype;
479 }
480 string toString() const override
481 {
482 QType qt(d_qtype);
483 return "qtype=="+qt.getName();
484 }
485 private:
486 uint16_t d_qtype;
487 };
488
489 class QClassRule : public DNSRule
490 {
491 public:
492 QClassRule(uint16_t qclass) : d_qclass(qclass)
493 {
494 }
495 bool matches(const DNSQuestion* dq) const override
496 {
497 return d_qclass == dq->qclass;
498 }
499 string toString() const override
500 {
501 return "qclass=="+std::to_string(d_qclass);
502 }
503 private:
504 uint16_t d_qclass;
505 };
506
507 class OpcodeRule : public DNSRule
508 {
509 public:
510 OpcodeRule(uint8_t opcode) : d_opcode(opcode)
511 {
512 }
513 bool matches(const DNSQuestion* dq) const override
514 {
515 return d_opcode == dq->dh->opcode;
516 }
517 string toString() const override
518 {
519 return "opcode=="+std::to_string(d_opcode);
520 }
521 private:
522 uint8_t d_opcode;
523 };
524
525 class TCPRule : public DNSRule
526 {
527 public:
528 TCPRule(bool tcp): d_tcp(tcp)
529 {
530 }
531 bool matches(const DNSQuestion* dq) const override
532 {
533 return dq->tcp == d_tcp;
534 }
535 string toString() const override
536 {
537 return (d_tcp ? "TCP" : "UDP");
538 }
539 private:
540 bool d_tcp;
541 };
542
543
544 class NotRule : public DNSRule
545 {
546 public:
547 NotRule(shared_ptr<DNSRule>& rule): d_rule(rule)
548 {
549 }
550 bool matches(const DNSQuestion* dq) const override
551 {
552 return !d_rule->matches(dq);
553 }
554 string toString() const override
555 {
556 return "!("+ d_rule->toString()+")";
557 }
558 private:
559 shared_ptr<DNSRule> d_rule;
560 };
561
562 class RecordsCountRule : public DNSRule
563 {
564 public:
565 RecordsCountRule(uint8_t section, uint16_t minCount, uint16_t maxCount): d_minCount(minCount), d_maxCount(maxCount), d_section(section)
566 {
567 }
568 bool matches(const DNSQuestion* dq) const override
569 {
570 uint16_t count = 0;
571 switch(d_section) {
572 case 0:
573 count = ntohs(dq->dh->qdcount);
574 break;
575 case 1:
576 count = ntohs(dq->dh->ancount);
577 break;
578 case 2:
579 count = ntohs(dq->dh->nscount);
580 break;
581 case 3:
582 count = ntohs(dq->dh->arcount);
583 break;
584 }
585 return count >= d_minCount && count <= d_maxCount;
586 }
587 string toString() const override
588 {
589 string section;
590 switch(d_section) {
591 case 0:
592 section = "QD";
593 break;
594 case 1:
595 section = "AN";
596 break;
597 case 2:
598 section = "NS";
599 break;
600 case 3:
601 section = "AR";
602 break;
603 }
604 return std::to_string(d_minCount) + " <= records in " + section + " <= "+ std::to_string(d_maxCount);
605 }
606 private:
607 uint16_t d_minCount;
608 uint16_t d_maxCount;
609 uint8_t d_section;
610 };
611
612 class RecordsTypeCountRule : public DNSRule
613 {
614 public:
615 RecordsTypeCountRule(uint8_t section, uint16_t type, uint16_t minCount, uint16_t maxCount): d_type(type), d_minCount(minCount), d_maxCount(maxCount), d_section(section)
616 {
617 }
618 bool matches(const DNSQuestion* dq) const override
619 {
620 uint16_t count = 0;
621 switch(d_section) {
622 case 0:
623 count = ntohs(dq->dh->qdcount);
624 break;
625 case 1:
626 count = ntohs(dq->dh->ancount);
627 break;
628 case 2:
629 count = ntohs(dq->dh->nscount);
630 break;
631 case 3:
632 count = ntohs(dq->dh->arcount);
633 break;
634 }
635 if (count < d_minCount) {
636 return false;
637 }
638 count = getRecordsOfTypeCount(reinterpret_cast<const char*>(dq->dh), dq->len, d_section, d_type);
639 return count >= d_minCount && count <= d_maxCount;
640 }
641 string toString() const override
642 {
643 string section;
644 switch(d_section) {
645 case 0:
646 section = "QD";
647 break;
648 case 1:
649 section = "AN";
650 break;
651 case 2:
652 section = "NS";
653 break;
654 case 3:
655 section = "AR";
656 break;
657 }
658 return std::to_string(d_minCount) + " <= " + QType(d_type).getName() + " records in " + section + " <= "+ std::to_string(d_maxCount);
659 }
660 private:
661 uint16_t d_type;
662 uint16_t d_minCount;
663 uint16_t d_maxCount;
664 uint8_t d_section;
665 };
666
667 class TrailingDataRule : public DNSRule
668 {
669 public:
670 TrailingDataRule()
671 {
672 }
673 bool matches(const DNSQuestion* dq) const override
674 {
675 uint16_t length = getDNSPacketLength(reinterpret_cast<const char*>(dq->dh), dq->len);
676 return length < dq->len;
677 }
678 string toString() const override
679 {
680 return "trailing data";
681 }
682 };
683
684 class QNameLabelsCountRule : public DNSRule
685 {
686 public:
687 QNameLabelsCountRule(unsigned int minLabelsCount, unsigned int maxLabelsCount): d_min(minLabelsCount), d_max(maxLabelsCount)
688 {
689 }
690 bool matches(const DNSQuestion* dq) const override
691 {
692 unsigned int count = dq->qname->countLabels();
693 return count < d_min || count > d_max;
694 }
695 string toString() const override
696 {
697 return "labels count < " + std::to_string(d_min) + " || labels count > " + std::to_string(d_max);
698 }
699 private:
700 unsigned int d_min;
701 unsigned int d_max;
702 };
703
704 class QNameWireLengthRule : public DNSRule
705 {
706 public:
707 QNameWireLengthRule(size_t min, size_t max): d_min(min), d_max(max)
708 {
709 }
710 bool matches(const DNSQuestion* dq) const override
711 {
712 size_t const wirelength = dq->qname->wirelength();
713 return wirelength < d_min || wirelength > d_max;
714 }
715 string toString() const override
716 {
717 return "wire length < " + std::to_string(d_min) + " || wire length > " + std::to_string(d_max);
718 }
719 private:
720 size_t d_min;
721 size_t d_max;
722 };
723
724 class RCodeRule : public DNSRule
725 {
726 public:
727 RCodeRule(uint8_t rcode) : d_rcode(rcode)
728 {
729 }
730 bool matches(const DNSQuestion* dq) const override
731 {
732 return d_rcode == dq->dh->rcode;
733 }
734 string toString() const override
735 {
736 return "rcode=="+RCode::to_s(d_rcode);
737 }
738 private:
739 uint8_t d_rcode;
740 };
741
742 class ERCodeRule : public DNSRule
743 {
744 public:
745 ERCodeRule(uint8_t rcode) : d_rcode(rcode & 0xF), d_extrcode(rcode >> 4)
746 {
747 }
748 bool matches(const DNSQuestion* dq) const override
749 {
750 // avoid parsing EDNS OPT RR when not needed.
751 if (d_rcode != dq->dh->rcode) {
752 return false;
753 }
754
755 char * optStart = NULL;
756 size_t optLen = 0;
757 bool last = false;
758 int res = locateEDNSOptRR(const_cast<char*>(reinterpret_cast<const char*>(dq->dh)), dq->len, &optStart, &optLen, &last);
759 if (res != 0) {
760 // no EDNS OPT RR
761 return d_extrcode == 0;
762 }
763
764 // root label (1), type (2), class (2), ttl (4) + rdlen (2)
765 if (optLen < 11) {
766 return false;
767 }
768
769 if (*optStart != 0) {
770 // OPT RR Name != '.'
771 return false;
772 }
773 EDNS0Record edns0;
774 static_assert(sizeof(EDNS0Record) == sizeof(uint32_t), "sizeof(EDNS0Record) must match sizeof(uint32_t) AKA RR TTL size");
775 // copy out 4-byte "ttl" (really the EDNS0 record), after root label (1) + type (2) + class (2).
776 memcpy(&edns0, optStart + 5, sizeof edns0);
777
778 return d_extrcode == edns0.extRCode;
779 }
780 string toString() const override
781 {
782 return "ercode=="+ERCode::to_s(d_rcode | (d_extrcode << 4));
783 }
784 private:
785 uint8_t d_rcode; // plain DNS Rcode
786 uint8_t d_extrcode; // upper bits in EDNS0 record
787 };
788
789 class RDRule : public DNSRule
790 {
791 public:
792 RDRule()
793 {
794 }
795 bool matches(const DNSQuestion* dq) const override
796 {
797 return dq->dh->rd == 1;
798 }
799 string toString() const override
800 {
801 return "rd==1";
802 }
803 };
804
805 class ProbaRule : public DNSRule
806 {
807 public:
808 ProbaRule(double proba) : d_proba(proba)
809 {
810 }
811 bool matches(const DNSQuestion* dq) const override
812 {
813 if(d_proba == 1.0)
814 return true;
815 double rnd = 1.0*random() / RAND_MAX;
816 return rnd > (1.0 - d_proba);
817 }
818 string toString() const override
819 {
820 return "match with prob. " + (boost::format("%0.2f") % d_proba).str();
821 }
822 private:
823 double d_proba;
824 };
825
826 class TagRule : public DNSRule
827 {
828 public:
829 TagRule(std::string tag, boost::optional<std::string> value) : d_value(value), d_tag(tag)
830 {
831 }
832 bool matches(const DNSQuestion* dq) const override
833 {
834 if (dq->qTag == nullptr) {
835 return false;
836 }
837
838 const auto got = dq->qTag->tagData.find(d_tag);
839 if (got == dq->qTag->tagData.cend()) {
840 return false;
841 }
842
843 if (!d_value) {
844 return true;
845 }
846
847 return got->second == *d_value;
848 }
849
850 string toString() const override
851 {
852 return "tag '" + d_tag + "' is set" + (d_value ? (" to '" + *d_value + "'") : "");
853 }
854
855 private:
856 boost::optional<std::string> d_value;
857 std::string d_tag;
858 };
859
860 std::shared_ptr<DNSRule> makeRule(const luadnsrule_t& var)
861 {
862 if (var.type() == typeid(std::shared_ptr<DNSRule>))
863 return *boost::get<std::shared_ptr<DNSRule>>(&var);
864
865 SuffixMatchNode smn;
866 NetmaskGroup nmg;
867 auto add=[&](string src) {
868 try {
869 nmg.addMask(src); // need to try mask first, all masks are domain names!
870 } catch(...) {
871 smn.add(DNSName(src));
872 }
873 };
874
875 if (var.type() == typeid(string))
876 add(*boost::get<string>(&var));
877
878 else if (var.type() == typeid(vector<pair<int, string>>))
879 for(const auto& a : *boost::get<vector<pair<int, string>>>(&var))
880 add(a.second);
881
882 else if (var.type() == typeid(DNSName))
883 smn.add(*boost::get<DNSName>(&var));
884
885 else if (var.type() == typeid(vector<pair<int, DNSName>>))
886 for(const auto& a : *boost::get<vector<pair<int, DNSName>>>(&var))
887 smn.add(a.second);
888
889 if(nmg.empty())
890 return std::make_shared<SuffixMatchNodeRule>(smn);
891 else
892 return std::make_shared<NetmaskGroupRule>(nmg, true);
893 }
894
895 static boost::uuids::uuid makeRuleID(std::string& id)
896 {
897 if (id.empty()) {
898 return t_uuidGenerator();
899 }
900
901 boost::uuids::string_generator gen;
902 return gen(id);
903 }
904
905 void parseRuleParams(boost::optional<luaruleparams_t> params, boost::uuids::uuid& uuid)
906 {
907 string uuidStr;
908
909 if (params) {
910 if (params->count("uuid")) {
911 uuidStr = boost::get<std::string>((*params)["uuid"]);
912 }
913 }
914
915 uuid = makeRuleID(uuidStr);
916 }
917
918 void setupLuaRules()
919 {
920 g_lua.writeFunction("makeRule", makeRule);
921
922 g_lua.registerFunction<string(std::shared_ptr<DNSRule>::*)()>("toString", [](const std::shared_ptr<DNSRule>& rule) { return rule->toString(); });
923
924 g_lua.writeFunction("showResponseRules", [](boost::optional<bool> showUUIDs) {
925 setLuaNoSideEffect();
926 int num=0;
927 if (showUUIDs.get_value_or(false)) {
928 boost::format fmt("%-3d %-38s %9d %-50s %s\n");
929 g_outputBuffer += (fmt % "#" % "UUID" % "Matches" % "Rule" % "Action").str();
930 for(const auto& lim : g_resprulactions.getCopy()) {
931 string name = lim.d_rule->toString();
932 g_outputBuffer += (fmt % num % boost::uuids::to_string(lim.d_id) % lim.d_rule->d_matches % name % lim.d_action->toString()).str();
933 ++num;
934 }
935 }
936 else {
937 boost::format fmt("%-3d %9d %-50s %s\n");
938 g_outputBuffer += (fmt % "#" % "Matches" % "Rule" % "Action").str();
939 for(const auto& lim : g_resprulactions.getCopy()) {
940 string name = lim.d_rule->toString();
941 g_outputBuffer += (fmt % num % lim.d_rule->d_matches % name % lim.d_action->toString()).str();
942 ++num;
943 }
944 }
945 });
946
947 g_lua.writeFunction("rmResponseRule", [](boost::variant<unsigned int, std::string> id) {
948 setLuaSideEffect();
949 auto rules = g_resprulactions.getCopy();
950 if (auto str = boost::get<std::string>(&id)) {
951 boost::uuids::string_generator gen;
952 const auto uuid = gen(*str);
953 rules.erase(std::remove_if(rules.begin(),
954 rules.end(),
955 [uuid](const DNSDistResponseRuleAction& a) { return a.d_id == uuid; }),
956 rules.end());
957 }
958 else if (auto pos = boost::get<unsigned int>(&id)) {
959 if (*pos >= rules.size()) {
960 g_outputBuffer = "Error: attempt to delete non-existing rule\n";
961 return;
962 }
963 rules.erase(rules.begin()+*pos);
964 }
965 g_resprulactions.setState(rules);
966 });
967
968 g_lua.writeFunction("topResponseRule", []() {
969 setLuaSideEffect();
970 auto rules = g_resprulactions.getCopy();
971 if(rules.empty())
972 return;
973 auto subject = *rules.rbegin();
974 rules.erase(std::prev(rules.end()));
975 rules.insert(rules.begin(), subject);
976 g_resprulactions.setState(rules);
977 });
978
979 g_lua.writeFunction("mvResponseRule", [](unsigned int from, unsigned int to) {
980 setLuaSideEffect();
981 auto rules = g_resprulactions.getCopy();
982 if(from >= rules.size() || to > rules.size()) {
983 g_outputBuffer = "Error: attempt to move rules from/to invalid index\n";
984 return;
985 }
986 auto subject = rules[from];
987 rules.erase(rules.begin()+from);
988 if(to == rules.size())
989 rules.push_back(subject);
990 else {
991 if(from < to)
992 --to;
993 rules.insert(rules.begin()+to, subject);
994 }
995 g_resprulactions.setState(rules);
996 });
997
998 g_lua.writeFunction("showCacheHitResponseRules", [](boost::optional<bool> showUUIDs) {
999 setLuaNoSideEffect();
1000 int num=0;
1001 if (showUUIDs.get_value_or(false)) {
1002 boost::format fmt("%-3d %-38s %9d %-50s %s\n");
1003 g_outputBuffer += (fmt % "#" % "UUID" % "Matches" % "Rule" % "Action").str();
1004 for(const auto& lim : g_cachehitresprulactions.getCopy()) {
1005 string name = lim.d_rule->toString();
1006 g_outputBuffer += (fmt % num % boost::uuids::to_string(lim.d_id) % lim.d_rule->d_matches % name % lim.d_action->toString()).str();
1007 ++num;
1008 }
1009 }
1010 else {
1011 boost::format fmt("%-3d %9d %-50s %s\n");
1012 g_outputBuffer += (fmt % "#" % "Matches" % "Rule" % "Action").str();
1013 for(const auto& lim : g_cachehitresprulactions.getCopy()) {
1014 string name = lim.d_rule->toString();
1015 g_outputBuffer += (fmt % num % lim.d_rule->d_matches % name % lim.d_action->toString()).str();
1016 ++num;
1017 }
1018 }
1019 });
1020
1021 g_lua.writeFunction("rmCacheHitResponseRule", [](boost::variant<unsigned int, std::string> id) {
1022 setLuaSideEffect();
1023 auto rules = g_cachehitresprulactions.getCopy();
1024 if (auto str = boost::get<std::string>(&id)) {
1025 boost::uuids::string_generator gen;
1026 const auto uuid = gen(*str);
1027 rules.erase(std::remove_if(rules.begin(),
1028 rules.end(),
1029 [uuid](const DNSDistResponseRuleAction& a) { return a.d_id == uuid; }),
1030 rules.end());
1031 }
1032 else if (auto pos = boost::get<unsigned int>(&id)) {
1033 if (*pos >= rules.size()) {
1034 g_outputBuffer = "Error: attempt to delete non-existing rule\n";
1035 return;
1036 }
1037 rules.erase(rules.begin()+*pos);
1038 }
1039 g_cachehitresprulactions.setState(rules);
1040 });
1041
1042 g_lua.writeFunction("topCacheHitResponseRule", []() {
1043 setLuaSideEffect();
1044 auto rules = g_cachehitresprulactions.getCopy();
1045 if(rules.empty())
1046 return;
1047 auto subject = *rules.rbegin();
1048 rules.erase(std::prev(rules.end()));
1049 rules.insert(rules.begin(), subject);
1050 g_cachehitresprulactions.setState(rules);
1051 });
1052
1053 g_lua.writeFunction("mvCacheHitResponseRule", [](unsigned int from, unsigned int to) {
1054 setLuaSideEffect();
1055 auto rules = g_cachehitresprulactions.getCopy();
1056 if(from >= rules.size() || to > rules.size()) {
1057 g_outputBuffer = "Error: attempt to move rules from/to invalid index\n";
1058 return;
1059 }
1060 auto subject = rules[from];
1061 rules.erase(rules.begin()+from);
1062 if(to == rules.size())
1063 rules.push_back(subject);
1064 else {
1065 if(from < to)
1066 --to;
1067 rules.insert(rules.begin()+to, subject);
1068 }
1069 g_cachehitresprulactions.setState(rules);
1070 });
1071
1072 g_lua.writeFunction("rmRule", [](boost::variant<unsigned int, std::string> id) {
1073 setLuaSideEffect();
1074 auto rules = g_rulactions.getCopy();
1075 if (auto str = boost::get<std::string>(&id)) {
1076 boost::uuids::string_generator gen;
1077 const auto uuid = gen(*str);
1078 rules.erase(std::remove_if(rules.begin(),
1079 rules.end(),
1080 [uuid](const DNSDistRuleAction& a) { return a.d_id == uuid; }),
1081 rules.end());
1082 }
1083 else if (auto pos = boost::get<unsigned int>(&id)) {
1084 if (*pos >= rules.size()) {
1085 g_outputBuffer = "Error: attempt to delete non-existing rule\n";
1086 return;
1087 }
1088 rules.erase(rules.begin()+*pos);
1089 }
1090 g_rulactions.setState(rules);
1091 });
1092
1093 g_lua.writeFunction("topRule", []() {
1094 setLuaSideEffect();
1095 auto rules = g_rulactions.getCopy();
1096 if(rules.empty())
1097 return;
1098 auto subject = *rules.rbegin();
1099 rules.erase(std::prev(rules.end()));
1100 rules.insert(rules.begin(), subject);
1101 g_rulactions.setState(rules);
1102 });
1103
1104 g_lua.writeFunction("mvRule", [](unsigned int from, unsigned int to) {
1105 setLuaSideEffect();
1106 auto rules = g_rulactions.getCopy();
1107 if(from >= rules.size() || to > rules.size()) {
1108 g_outputBuffer = "Error: attempt to move rules from/to invalid index\n";
1109 return;
1110 }
1111
1112 auto subject = rules[from];
1113 rules.erase(rules.begin()+from);
1114 if(to == rules.size())
1115 rules.push_back(subject);
1116 else {
1117 if(from < to)
1118 --to;
1119 rules.insert(rules.begin()+to, subject);
1120 }
1121 g_rulactions.setState(rules);
1122 });
1123
1124 g_lua.writeFunction("clearRules", []() {
1125 setLuaSideEffect();
1126 g_rulactions.modify([](decltype(g_rulactions)::value_type& rulactions) {
1127 rulactions.clear();
1128 });
1129 });
1130
1131 g_lua.writeFunction("setRules", [](std::vector<DNSDistRuleAction>& newruleactions) {
1132 setLuaSideEffect();
1133 g_rulactions.modify([newruleactions](decltype(g_rulactions)::value_type& gruleactions) {
1134 gruleactions.clear();
1135 for (const auto& newruleaction : newruleactions) {
1136 if (newruleaction.d_action) {
1137 auto rule=makeRule(newruleaction.d_rule);
1138 gruleactions.push_back({rule, newruleaction.d_action, newruleaction.d_id});
1139 }
1140 }
1141 });
1142 });
1143
1144 g_lua.writeFunction("MaxQPSIPRule", [](unsigned int qps, boost::optional<int> ipv4trunc, boost::optional<int> ipv6trunc, boost::optional<int> burst) {
1145 return std::shared_ptr<DNSRule>(new MaxQPSIPRule(qps, burst.get_value_or(qps), ipv4trunc.get_value_or(32), ipv6trunc.get_value_or(64)));
1146 });
1147
1148 g_lua.writeFunction("MaxQPSRule", [](unsigned int qps, boost::optional<int> burst) {
1149 if(!burst)
1150 return std::shared_ptr<DNSRule>(new MaxQPSRule(qps));
1151 else
1152 return std::shared_ptr<DNSRule>(new MaxQPSRule(qps, *burst));
1153 });
1154
1155 g_lua.writeFunction("RegexRule", [](const std::string& str) {
1156 return std::shared_ptr<DNSRule>(new RegexRule(str));
1157 });
1158
1159 #ifdef HAVE_RE2
1160 g_lua.writeFunction("RE2Rule", [](const std::string& str) {
1161 return std::shared_ptr<DNSRule>(new RE2Rule(str));
1162 });
1163 #endif
1164
1165 g_lua.writeFunction("SuffixMatchNodeRule", [](const SuffixMatchNode& smn, boost::optional<bool> quiet) {
1166 return std::shared_ptr<DNSRule>(new SuffixMatchNodeRule(smn, quiet ? *quiet : false));
1167 });
1168
1169 g_lua.writeFunction("NetmaskGroupRule", [](const NetmaskGroup& nmg, boost::optional<bool> src) {
1170 return std::shared_ptr<DNSRule>(new NetmaskGroupRule(nmg, src ? *src : true));
1171 });
1172
1173 g_lua.writeFunction("benchRule", [](std::shared_ptr<DNSRule> rule, boost::optional<int> times_, boost::optional<string> suffix_) {
1174 setLuaNoSideEffect();
1175 int times = times_.get_value_or(100000);
1176 DNSName suffix(suffix_.get_value_or("powerdns.com"));
1177 struct item {
1178 vector<uint8_t> packet;
1179 ComboAddress rem;
1180 DNSName qname;
1181 uint16_t qtype, qclass;
1182 };
1183 vector<item> items;
1184 items.reserve(1000);
1185 for(int n=0; n < 1000; ++n) {
1186 struct item i;
1187 i.qname=DNSName(std::to_string(random()));
1188 i.qname += suffix;
1189 i.qtype = random() % 0xff;
1190 i.qclass = 1;
1191 i.rem=ComboAddress("127.0.0.1");
1192 i.rem.sin4.sin_addr.s_addr = random();
1193 DNSPacketWriter pw(i.packet, i.qname, i.qtype);
1194 items.push_back(i);
1195 }
1196
1197 int matches=0;
1198 ComboAddress dummy("127.0.0.1");
1199 DTime dt;
1200 dt.set();
1201 for(int n=0; n < times; ++n) {
1202 const item& i = items[n % items.size()];
1203 DNSQuestion dq(&i.qname, i.qtype, i.qclass, &i.rem, &i.rem, (struct dnsheader*)&i.packet[0], i.packet.size(), i.packet.size(), false);
1204 if(rule->matches(&dq))
1205 matches++;
1206 }
1207 double udiff=dt.udiff();
1208 g_outputBuffer=(boost::format("Had %d matches out of %d, %.1f qps, in %.1f usec\n") % matches % times % (1000000*(1.0*times/udiff)) % udiff).str();
1209
1210 });
1211
1212 g_lua.writeFunction("AllRule", []() {
1213 return std::shared_ptr<DNSRule>(new AllRule());
1214 });
1215
1216 g_lua.writeFunction("ProbaRule", [](double proba) {
1217 return std::shared_ptr<DNSRule>(new ProbaRule(proba));
1218 });
1219
1220 g_lua.writeFunction("QNameRule", [](const std::string& qname) {
1221 return std::shared_ptr<DNSRule>(new QNameRule(DNSName(qname)));
1222 });
1223
1224 g_lua.writeFunction("QTypeRule", [](boost::variant<int, std::string> str) {
1225 uint16_t qtype;
1226 if(auto dir = boost::get<int>(&str)) {
1227 qtype = *dir;
1228 }
1229 else {
1230 string val=boost::get<string>(str);
1231 qtype = QType::chartocode(val.c_str());
1232 if(!qtype)
1233 throw std::runtime_error("Unable to convert '"+val+"' to a DNS type");
1234 }
1235 return std::shared_ptr<DNSRule>(new QTypeRule(qtype));
1236 });
1237
1238 g_lua.writeFunction("QClassRule", [](int c) {
1239 return std::shared_ptr<DNSRule>(new QClassRule(c));
1240 });
1241
1242 g_lua.writeFunction("OpcodeRule", [](uint8_t code) {
1243 return std::shared_ptr<DNSRule>(new OpcodeRule(code));
1244 });
1245
1246 g_lua.writeFunction("AndRule", [](vector<pair<int, std::shared_ptr<DNSRule> > >a) {
1247 return std::shared_ptr<DNSRule>(new AndRule(a));
1248 });
1249
1250 g_lua.writeFunction("OrRule", [](vector<pair<int, std::shared_ptr<DNSRule> > >a) {
1251 return std::shared_ptr<DNSRule>(new OrRule(a));
1252 });
1253
1254 g_lua.writeFunction("TCPRule", [](bool tcp) {
1255 return std::shared_ptr<DNSRule>(new TCPRule(tcp));
1256 });
1257
1258 g_lua.writeFunction("DNSSECRule", []() {
1259 return std::shared_ptr<DNSRule>(new DNSSECRule());
1260 });
1261
1262 g_lua.writeFunction("NotRule", [](std::shared_ptr<DNSRule>rule) {
1263 return std::shared_ptr<DNSRule>(new NotRule(rule));
1264 });
1265
1266 g_lua.writeFunction("RecordsCountRule", [](uint8_t section, uint16_t minCount, uint16_t maxCount) {
1267 return std::shared_ptr<DNSRule>(new RecordsCountRule(section, minCount, maxCount));
1268 });
1269
1270 g_lua.writeFunction("RecordsTypeCountRule", [](uint8_t section, uint16_t type, uint16_t minCount, uint16_t maxCount) {
1271 return std::shared_ptr<DNSRule>(new RecordsTypeCountRule(section, type, minCount, maxCount));
1272 });
1273
1274 g_lua.writeFunction("TrailingDataRule", []() {
1275 return std::shared_ptr<DNSRule>(new TrailingDataRule());
1276 });
1277
1278 g_lua.writeFunction("QNameLabelsCountRule", [](unsigned int minLabelsCount, unsigned int maxLabelsCount) {
1279 return std::shared_ptr<DNSRule>(new QNameLabelsCountRule(minLabelsCount, maxLabelsCount));
1280 });
1281
1282 g_lua.writeFunction("QNameWireLengthRule", [](size_t min, size_t max) {
1283 return std::shared_ptr<DNSRule>(new QNameWireLengthRule(min, max));
1284 });
1285
1286 g_lua.writeFunction("RCodeRule", [](uint8_t rcode) {
1287 return std::shared_ptr<DNSRule>(new RCodeRule(rcode));
1288 });
1289
1290 g_lua.writeFunction("ERCodeRule", [](uint8_t rcode) {
1291 return std::shared_ptr<DNSRule>(new ERCodeRule(rcode));
1292 });
1293
1294 g_lua.writeFunction("showRules", [](boost::optional<bool> showUUIDs) {
1295 setLuaNoSideEffect();
1296 int num=0;
1297 if (showUUIDs.get_value_or(false)) {
1298 boost::format fmt("%-3d %-38s %9d %-56s %s\n");
1299 g_outputBuffer += (fmt % "#" % "UUID" % "Matches" % "Rule" % "Action").str();
1300 for(const auto& lim : g_rulactions.getCopy()) {
1301 string name = lim.d_rule->toString();
1302 g_outputBuffer += (fmt % num % boost::uuids::to_string(lim.d_id) % lim.d_rule->d_matches % name % lim.d_action->toString()).str();
1303 ++num;
1304 }
1305 }
1306 else {
1307 boost::format fmt("%-3d %9d %-50s %s\n");
1308 g_outputBuffer += (fmt % "#" % "Matches" % "Rule" % "Action").str();
1309 for(const auto& lim : g_rulactions.getCopy()) {
1310 string name = lim.d_rule->toString();
1311 g_outputBuffer += (fmt % num % lim.d_rule->d_matches % name % lim.d_action->toString()).str();
1312 ++num;
1313 }
1314 }
1315 });
1316
1317 g_lua.writeFunction("RDRule", []() {
1318 return std::shared_ptr<DNSRule>(new RDRule());
1319 });
1320
1321 g_lua.writeFunction("TagRule", [](std::string tag, boost::optional<std::string> value) {
1322 return std::shared_ptr<DNSRule>(new TagRule(tag, value));
1323 });
1324
1325 g_lua.writeFunction("TimedIPSetRule", []() {
1326 return std::shared_ptr<TimedIPSetRule>(new TimedIPSetRule());
1327 });
1328
1329 g_lua.registerFunction<void(std::shared_ptr<TimedIPSetRule>::*)()>("clear", [](std::shared_ptr<TimedIPSetRule> tisr) {
1330 tisr->clear();
1331 });
1332
1333 g_lua.registerFunction<void(std::shared_ptr<TimedIPSetRule>::*)()>("cleanup", [](std::shared_ptr<TimedIPSetRule> tisr) {
1334 tisr->cleanup();
1335 });
1336
1337 g_lua.registerFunction<void(std::shared_ptr<TimedIPSetRule>::*)(const ComboAddress& ca, int t)>("add", [](std::shared_ptr<TimedIPSetRule> tisr, const ComboAddress& ca, int t) {
1338 tisr->add(ca, time(0)+t);
1339 });
1340
1341 g_lua.registerFunction<std::shared_ptr<DNSRule>(std::shared_ptr<TimedIPSetRule>::*)()>("slice", [](std::shared_ptr<TimedIPSetRule> tisr) {
1342 return std::dynamic_pointer_cast<DNSRule>(tisr);
1343 });
1344 }