]> git.ipfire.org Git - thirdparty/pdns.git/blob - pdns/dnsdistdist/dnsdist-rules.hh
dnsdist: Set a correct EDNS OPT RR for self-generated answers
[thirdparty/pdns.git] / pdns / dnsdistdist / dnsdist-rules.hh
1 /*
2 * This file is part of PowerDNS or dnsdist.
3 * Copyright -- PowerDNS.COM B.V. and its contributors
4 *
5 * This program is free software; you can redistribute it and/or modify
6 * it under the terms of version 2 of the GNU General Public License as
7 * published by the Free Software Foundation.
8 *
9 * In addition, for the avoidance of any doubt, permission is granted to
10 * link this program with OpenSSL and to (re)distribute the binaries
11 * produced as the result of such linking.
12 *
13 * This program is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with this program; if not, write to the Free Software
20 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
21 */
22 #pragma once
23
24 #include "cachecleaner.hh"
25 #include "dnsdist.hh"
26 #include "dnsdist-ecs.hh"
27 #include "dnsparser.hh"
28
29 class MaxQPSIPRule : public DNSRule
30 {
31 public:
32 MaxQPSIPRule(unsigned int qps, unsigned int burst, unsigned int ipv4trunc=32, unsigned int ipv6trunc=64, unsigned int expiration=300, unsigned int cleanupDelay=60, unsigned int scanFraction=10):
33 d_qps(qps), d_burst(burst), d_ipv4trunc(ipv4trunc), d_ipv6trunc(ipv6trunc), d_cleanupDelay(cleanupDelay), d_expiration(expiration), d_scanFraction(scanFraction)
34 {
35 gettime(&d_lastCleanup, true);
36 }
37
38 void clear()
39 {
40 std::lock_guard<std::mutex> lock(d_lock);
41 d_limits.clear();
42 }
43
44 size_t cleanup(const struct timespec& cutOff, size_t* scannedCount=nullptr) const
45 {
46 std::lock_guard<std::mutex> lock(d_lock);
47 size_t toLook = d_limits.size() / d_scanFraction + 1;
48 size_t lookedAt = 0;
49
50 size_t removed = 0;
51 auto& sequence = d_limits.get<SequencedTag>();
52 for (auto entry = sequence.begin(); entry != sequence.end() && lookedAt < toLook; lookedAt++) {
53 if (entry->d_limiter.seenSince(cutOff)) {
54 /* entries are ordered from least recently seen to more recently
55 seen, as soon as we see one that has not expired yet, we are
56 done */
57 lookedAt++;
58 break;
59 }
60
61 entry = sequence.erase(entry);
62 removed++;
63 }
64
65 if (scannedCount != nullptr) {
66 *scannedCount = lookedAt;
67 }
68
69 return removed;
70 }
71
72 void cleanupIfNeeded(const struct timespec& now) const
73 {
74 if (d_cleanupDelay > 0) {
75 struct timespec cutOff = d_lastCleanup;
76 cutOff.tv_sec += d_cleanupDelay;
77
78 if (cutOff < now) {
79 /* the QPS Limiter doesn't use realtime, be careful! */
80 gettime(&cutOff, false);
81 cutOff.tv_sec -= d_expiration;
82
83 cleanup(cutOff);
84
85 d_lastCleanup = now;
86 }
87 }
88 }
89
90 bool matches(const DNSQuestion* dq) const override
91 {
92 cleanupIfNeeded(*dq->queryTime);
93
94 ComboAddress zeroport(*dq->remote);
95 zeroport.sin4.sin_port=0;
96 zeroport.truncate(zeroport.sin4.sin_family == AF_INET ? d_ipv4trunc : d_ipv6trunc);
97 {
98 std::lock_guard<std::mutex> lock(d_lock);
99 auto iter = d_limits.find(zeroport);
100 if (iter == d_limits.end()) {
101 Entry e(zeroport, QPSLimiter(d_qps, d_burst));
102 iter = d_limits.insert(e).first;
103 }
104
105 moveCacheItemToBack(d_limits, iter);
106 return !iter->d_limiter.check(d_qps, d_burst);
107 }
108 }
109
110 string toString() const override
111 {
112 return "IP (/"+std::to_string(d_ipv4trunc)+", /"+std::to_string(d_ipv6trunc)+") match for QPS over " + std::to_string(d_qps) + " burst "+ std::to_string(d_burst);
113 }
114
115 size_t getEntriesCount() const
116 {
117 std::lock_guard<std::mutex> lock(d_lock);
118 return d_limits.size();
119 }
120
121 private:
122 struct OrderedTag {};
123 struct SequencedTag {};
124 struct Entry
125 {
126 Entry(const ComboAddress& addr, BasicQPSLimiter&& limiter): d_limiter(limiter), d_addr(addr)
127 {
128 }
129 mutable BasicQPSLimiter d_limiter;
130 ComboAddress d_addr;
131 };
132
133 typedef multi_index_container<
134 Entry,
135 indexed_by <
136 ordered_unique<tag<OrderedTag>, member<Entry,ComboAddress,&Entry::d_addr>, ComboAddress::addressOnlyLessThan >,
137 sequenced<tag<SequencedTag> >
138 >
139 > qpsContainer_t;
140
141 mutable std::mutex d_lock;
142 mutable qpsContainer_t d_limits;
143 mutable struct timespec d_lastCleanup;
144 unsigned int d_qps, d_burst, d_ipv4trunc, d_ipv6trunc, d_cleanupDelay, d_expiration;
145 unsigned int d_scanFraction{10};
146 };
147
148 class MaxQPSRule : public DNSRule
149 {
150 public:
151 MaxQPSRule(unsigned int qps)
152 : d_qps(qps, qps)
153 {}
154
155 MaxQPSRule(unsigned int qps, unsigned int burst)
156 : d_qps(qps, burst)
157 {}
158
159
160 bool matches(const DNSQuestion* qd) const override
161 {
162 return d_qps.check();
163 }
164
165 string toString() const override
166 {
167 return "Max " + std::to_string(d_qps.getRate()) + " qps";
168 }
169
170
171 private:
172 mutable QPSLimiter d_qps;
173 };
174
175 class NMGRule : public DNSRule
176 {
177 public:
178 NMGRule(const NetmaskGroup& nmg) : d_nmg(nmg) {}
179 protected:
180 NetmaskGroup d_nmg;
181 };
182
183 class NetmaskGroupRule : public NMGRule
184 {
185 public:
186 NetmaskGroupRule(const NetmaskGroup& nmg, bool src) : NMGRule(nmg)
187 {
188 d_src = src;
189 }
190 bool matches(const DNSQuestion* dq) const override
191 {
192 if(!d_src) {
193 return d_nmg.match(*dq->local);
194 }
195 return d_nmg.match(*dq->remote);
196 }
197
198 string toString() const override
199 {
200 if(!d_src) {
201 return "Dst: "+d_nmg.toString();
202 }
203 return "Src: "+d_nmg.toString();
204 }
205 private:
206 bool d_src;
207 };
208
209 class TimedIPSetRule : public DNSRule, boost::noncopyable
210 {
211 private:
212 struct IPv6 {
213 IPv6(const ComboAddress& ca)
214 {
215 static_assert(sizeof(*this)==16, "IPv6 struct has wrong size");
216 memcpy((char*)this, ca.sin6.sin6_addr.s6_addr, 16);
217 }
218 bool operator==(const IPv6& rhs) const
219 {
220 return a==rhs.a && b==rhs.b;
221 }
222 uint64_t a, b;
223 };
224
225 public:
226 TimedIPSetRule()
227 {
228 pthread_rwlock_init(&d_lock4, 0);
229 pthread_rwlock_init(&d_lock6, 0);
230 }
231 bool matches(const DNSQuestion* dq) const override
232 {
233 if(dq->remote->sin4.sin_family == AF_INET) {
234 ReadLock rl(&d_lock4);
235 auto fnd = d_ip4s.find(dq->remote->sin4.sin_addr.s_addr);
236 if(fnd == d_ip4s.end()) {
237 return false;
238 }
239 return time(0) < fnd->second;
240 } else {
241 ReadLock rl(&d_lock6);
242 auto fnd = d_ip6s.find({*dq->remote});
243 if(fnd == d_ip6s.end()) {
244 return false;
245 }
246 return time(0) < fnd->second;
247 }
248 }
249
250 void add(const ComboAddress& ca, time_t ttd)
251 {
252 // think twice before adding templates here
253 if(ca.sin4.sin_family == AF_INET) {
254 WriteLock rl(&d_lock4);
255 auto res=d_ip4s.insert({ca.sin4.sin_addr.s_addr, ttd});
256 if(!res.second && (time_t)res.first->second < ttd)
257 res.first->second = (uint32_t)ttd;
258 }
259 else {
260 WriteLock rl(&d_lock6);
261 auto res=d_ip6s.insert({{ca}, ttd});
262 if(!res.second && (time_t)res.first->second < ttd)
263 res.first->second = (uint32_t)ttd;
264 }
265 }
266
267 void remove(const ComboAddress& ca)
268 {
269 if(ca.sin4.sin_family == AF_INET) {
270 WriteLock rl(&d_lock4);
271 d_ip4s.erase(ca.sin4.sin_addr.s_addr);
272 }
273 else {
274 WriteLock rl(&d_lock6);
275 d_ip6s.erase({ca});
276 }
277 }
278
279 void clear()
280 {
281 {
282 WriteLock rl(&d_lock4);
283 d_ip4s.clear();
284 }
285 WriteLock rl(&d_lock6);
286 d_ip6s.clear();
287 }
288
289 void cleanup()
290 {
291 time_t now=time(0);
292 {
293 WriteLock rl(&d_lock4);
294
295 for(auto iter = d_ip4s.begin(); iter != d_ip4s.end(); ) {
296 if(iter->second < now)
297 iter=d_ip4s.erase(iter);
298 else
299 ++iter;
300 }
301
302 }
303
304 {
305 WriteLock rl(&d_lock6);
306
307 for(auto iter = d_ip6s.begin(); iter != d_ip6s.end(); ) {
308 if(iter->second < now)
309 iter=d_ip6s.erase(iter);
310 else
311 ++iter;
312 }
313
314 }
315
316 }
317
318 string toString() const override
319 {
320 time_t now=time(0);
321 uint64_t count = 0;
322 {
323 ReadLock rl(&d_lock4);
324 for(const auto& ip : d_ip4s)
325 if(now < ip.second)
326 ++count;
327 }
328 {
329 ReadLock rl(&d_lock6);
330 for(const auto& ip : d_ip6s)
331 if(now < ip.second)
332 ++count;
333 }
334
335 return "Src: "+std::to_string(count)+" ips";
336 }
337 private:
338 struct IPv6Hash
339 {
340 std::size_t operator()(const IPv6& ip) const
341 {
342 auto ah=std::hash<uint64_t>{}(ip.a);
343 auto bh=std::hash<uint64_t>{}(ip.b);
344 return ah & (bh<<1);
345 }
346 };
347 std::unordered_map<IPv6, time_t, IPv6Hash> d_ip6s;
348 std::unordered_map<uint32_t, time_t> d_ip4s;
349 mutable pthread_rwlock_t d_lock4;
350 mutable pthread_rwlock_t d_lock6;
351 };
352
353
354 class AllRule : public DNSRule
355 {
356 public:
357 AllRule() {}
358 bool matches(const DNSQuestion* dq) const override
359 {
360 return true;
361 }
362
363 string toString() const override
364 {
365 return "All";
366 }
367
368 };
369
370
371 class DNSSECRule : public DNSRule
372 {
373 public:
374 DNSSECRule()
375 {
376
377 }
378 bool matches(const DNSQuestion* dq) const override
379 {
380 return dq->dh->cd || (getEDNSZ(*dq) & EDNS_HEADER_FLAG_DO); // turns out dig sets ad by default..
381 }
382
383 string toString() const override
384 {
385 return "DNSSEC";
386 }
387 };
388
389 class AndRule : public DNSRule
390 {
391 public:
392 AndRule(const vector<pair<int, shared_ptr<DNSRule> > >& rules)
393 {
394 for(const auto& r : rules)
395 d_rules.push_back(r.second);
396 }
397
398 bool matches(const DNSQuestion* dq) const override
399 {
400 auto iter = d_rules.begin();
401 for(; iter != d_rules.end(); ++iter)
402 if(!(*iter)->matches(dq))
403 break;
404 return iter == d_rules.end();
405 }
406
407 string toString() const override
408 {
409 string ret;
410 for(const auto& rule : d_rules) {
411 if(!ret.empty())
412 ret+= " && ";
413 ret += "("+ rule->toString()+")";
414 }
415 return ret;
416 }
417 private:
418
419 vector<std::shared_ptr<DNSRule> > d_rules;
420
421 };
422
423
424 class OrRule : public DNSRule
425 {
426 public:
427 OrRule(const vector<pair<int, shared_ptr<DNSRule> > >& rules)
428 {
429 for(const auto& r : rules)
430 d_rules.push_back(r.second);
431 }
432
433 bool matches(const DNSQuestion* dq) const override
434 {
435 auto iter = d_rules.begin();
436 for(; iter != d_rules.end(); ++iter)
437 if((*iter)->matches(dq))
438 return true;
439 return false;
440 }
441
442 string toString() const override
443 {
444 string ret;
445 for(const auto& rule : d_rules) {
446 if(!ret.empty())
447 ret+= " || ";
448 ret += "("+ rule->toString()+")";
449 }
450 return ret;
451 }
452 private:
453
454 vector<std::shared_ptr<DNSRule> > d_rules;
455
456 };
457
458
459 class RegexRule : public DNSRule
460 {
461 public:
462 RegexRule(const std::string& regex) : d_regex(regex), d_visual(regex)
463 {
464
465 }
466 bool matches(const DNSQuestion* dq) const override
467 {
468 return d_regex.match(dq->qname->toStringNoDot());
469 }
470
471 string toString() const override
472 {
473 return "Regex: "+d_visual;
474 }
475 private:
476 Regex d_regex;
477 string d_visual;
478 };
479
#ifdef HAVE_RE2
#include <re2/re2.h>
/* Matches when the RE2 pattern (compiled in Latin-1 mode) fully matches the query name, as rendered by toStringNoDot(). */
class RE2Rule : public DNSRule
{
public:
  RE2Rule(const std::string& re2) : d_re2(re2, RE2::Latin1), d_visual(re2)
  {
  }

  bool matches(const DNSQuestion* dq) const override
  {
    const std::string name = dq->qname->toStringNoDot();
    return RE2::FullMatch(name, d_re2);
  }

  string toString() const override
  {
    return "RE2 match: " + d_visual;
  }

private:
  RE2 d_re2;
  string d_visual;
};
#endif
503
504
505 class SuffixMatchNodeRule : public DNSRule
506 {
507 public:
508 SuffixMatchNodeRule(const SuffixMatchNode& smn, bool quiet=false) : d_smn(smn), d_quiet(quiet)
509 {
510 }
511 bool matches(const DNSQuestion* dq) const override
512 {
513 return d_smn.check(*dq->qname);
514 }
515 string toString() const override
516 {
517 if(d_quiet)
518 return "qname==in-set";
519 else
520 return "qname in "+d_smn.toString();
521 }
522 private:
523 SuffixMatchNode d_smn;
524 bool d_quiet;
525 };
526
527 class QNameRule : public DNSRule
528 {
529 public:
530 QNameRule(const DNSName& qname) : d_qname(qname)
531 {
532 }
533 bool matches(const DNSQuestion* dq) const override
534 {
535 return d_qname==*dq->qname;
536 }
537 string toString() const override
538 {
539 return "qname=="+d_qname.toString();
540 }
541 private:
542 DNSName d_qname;
543 };
544
545
546 class QTypeRule : public DNSRule
547 {
548 public:
549 QTypeRule(uint16_t qtype) : d_qtype(qtype)
550 {
551 }
552 bool matches(const DNSQuestion* dq) const override
553 {
554 return d_qtype == dq->qtype;
555 }
556 string toString() const override
557 {
558 QType qt(d_qtype);
559 return "qtype=="+qt.getName();
560 }
561 private:
562 uint16_t d_qtype;
563 };
564
565 class QClassRule : public DNSRule
566 {
567 public:
568 QClassRule(uint16_t qclass) : d_qclass(qclass)
569 {
570 }
571 bool matches(const DNSQuestion* dq) const override
572 {
573 return d_qclass == dq->qclass;
574 }
575 string toString() const override
576 {
577 return "qclass=="+std::to_string(d_qclass);
578 }
579 private:
580 uint16_t d_qclass;
581 };
582
583 class OpcodeRule : public DNSRule
584 {
585 public:
586 OpcodeRule(uint8_t opcode) : d_opcode(opcode)
587 {
588 }
589 bool matches(const DNSQuestion* dq) const override
590 {
591 return d_opcode == dq->dh->opcode;
592 }
593 string toString() const override
594 {
595 return "opcode=="+std::to_string(d_opcode);
596 }
597 private:
598 uint8_t d_opcode;
599 };
600
601 class DSTPortRule : public DNSRule
602 {
603 public:
604 DSTPortRule(uint16_t port) : d_port(port)
605 {
606 }
607 bool matches(const DNSQuestion* dq) const override
608 {
609 return htons(d_port) == dq->local->sin4.sin_port;
610 }
611 string toString() const override
612 {
613 return "dst port=="+std::to_string(d_port);
614 }
615 private:
616 uint16_t d_port;
617 };
618
619 class TCPRule : public DNSRule
620 {
621 public:
622 TCPRule(bool tcp): d_tcp(tcp)
623 {
624 }
625 bool matches(const DNSQuestion* dq) const override
626 {
627 return dq->tcp == d_tcp;
628 }
629 string toString() const override
630 {
631 return (d_tcp ? "TCP" : "UDP");
632 }
633 private:
634 bool d_tcp;
635 };
636
637
638 class NotRule : public DNSRule
639 {
640 public:
641 NotRule(shared_ptr<DNSRule>& rule): d_rule(rule)
642 {
643 }
644 bool matches(const DNSQuestion* dq) const override
645 {
646 return !d_rule->matches(dq);
647 }
648 string toString() const override
649 {
650 return "!("+ d_rule->toString()+")";
651 }
652 private:
653 shared_ptr<DNSRule> d_rule;
654 };
655
656 class RecordsCountRule : public DNSRule
657 {
658 public:
659 RecordsCountRule(uint8_t section, uint16_t minCount, uint16_t maxCount): d_minCount(minCount), d_maxCount(maxCount), d_section(section)
660 {
661 }
662 bool matches(const DNSQuestion* dq) const override
663 {
664 uint16_t count = 0;
665 switch(d_section) {
666 case 0:
667 count = ntohs(dq->dh->qdcount);
668 break;
669 case 1:
670 count = ntohs(dq->dh->ancount);
671 break;
672 case 2:
673 count = ntohs(dq->dh->nscount);
674 break;
675 case 3:
676 count = ntohs(dq->dh->arcount);
677 break;
678 }
679 return count >= d_minCount && count <= d_maxCount;
680 }
681 string toString() const override
682 {
683 string section;
684 switch(d_section) {
685 case 0:
686 section = "QD";
687 break;
688 case 1:
689 section = "AN";
690 break;
691 case 2:
692 section = "NS";
693 break;
694 case 3:
695 section = "AR";
696 break;
697 }
698 return std::to_string(d_minCount) + " <= records in " + section + " <= "+ std::to_string(d_maxCount);
699 }
700 private:
701 uint16_t d_minCount;
702 uint16_t d_maxCount;
703 uint8_t d_section;
704 };
705
706 class RecordsTypeCountRule : public DNSRule
707 {
708 public:
709 RecordsTypeCountRule(uint8_t section, uint16_t type, uint16_t minCount, uint16_t maxCount): d_type(type), d_minCount(minCount), d_maxCount(maxCount), d_section(section)
710 {
711 }
712 bool matches(const DNSQuestion* dq) const override
713 {
714 uint16_t count = 0;
715 switch(d_section) {
716 case 0:
717 count = ntohs(dq->dh->qdcount);
718 break;
719 case 1:
720 count = ntohs(dq->dh->ancount);
721 break;
722 case 2:
723 count = ntohs(dq->dh->nscount);
724 break;
725 case 3:
726 count = ntohs(dq->dh->arcount);
727 break;
728 }
729 if (count < d_minCount) {
730 return false;
731 }
732 count = getRecordsOfTypeCount(reinterpret_cast<const char*>(dq->dh), dq->len, d_section, d_type);
733 return count >= d_minCount && count <= d_maxCount;
734 }
735 string toString() const override
736 {
737 string section;
738 switch(d_section) {
739 case 0:
740 section = "QD";
741 break;
742 case 1:
743 section = "AN";
744 break;
745 case 2:
746 section = "NS";
747 break;
748 case 3:
749 section = "AR";
750 break;
751 }
752 return std::to_string(d_minCount) + " <= " + QType(d_type).getName() + " records in " + section + " <= "+ std::to_string(d_maxCount);
753 }
754 private:
755 uint16_t d_type;
756 uint16_t d_minCount;
757 uint16_t d_maxCount;
758 uint8_t d_section;
759 };
760
761 class TrailingDataRule : public DNSRule
762 {
763 public:
764 TrailingDataRule()
765 {
766 }
767 bool matches(const DNSQuestion* dq) const override
768 {
769 uint16_t length = getDNSPacketLength(reinterpret_cast<const char*>(dq->dh), dq->len);
770 return length < dq->len;
771 }
772 string toString() const override
773 {
774 return "trailing data";
775 }
776 };
777
778 class QNameLabelsCountRule : public DNSRule
779 {
780 public:
781 QNameLabelsCountRule(unsigned int minLabelsCount, unsigned int maxLabelsCount): d_min(minLabelsCount), d_max(maxLabelsCount)
782 {
783 }
784 bool matches(const DNSQuestion* dq) const override
785 {
786 unsigned int count = dq->qname->countLabels();
787 return count < d_min || count > d_max;
788 }
789 string toString() const override
790 {
791 return "labels count < " + std::to_string(d_min) + " || labels count > " + std::to_string(d_max);
792 }
793 private:
794 unsigned int d_min;
795 unsigned int d_max;
796 };
797
798 class QNameWireLengthRule : public DNSRule
799 {
800 public:
801 QNameWireLengthRule(size_t min, size_t max): d_min(min), d_max(max)
802 {
803 }
804 bool matches(const DNSQuestion* dq) const override
805 {
806 size_t const wirelength = dq->qname->wirelength();
807 return wirelength < d_min || wirelength > d_max;
808 }
809 string toString() const override
810 {
811 return "wire length < " + std::to_string(d_min) + " || wire length > " + std::to_string(d_max);
812 }
813 private:
814 size_t d_min;
815 size_t d_max;
816 };
817
818 class RCodeRule : public DNSRule
819 {
820 public:
821 RCodeRule(uint8_t rcode) : d_rcode(rcode)
822 {
823 }
824 bool matches(const DNSQuestion* dq) const override
825 {
826 return d_rcode == dq->dh->rcode;
827 }
828 string toString() const override
829 {
830 return "rcode=="+RCode::to_s(d_rcode);
831 }
832 private:
833 uint8_t d_rcode;
834 };
835
836 class ERCodeRule : public DNSRule
837 {
838 public:
839 ERCodeRule(uint8_t rcode) : d_rcode(rcode & 0xF), d_extrcode(rcode >> 4)
840 {
841 }
842 bool matches(const DNSQuestion* dq) const override
843 {
844 // avoid parsing EDNS OPT RR when not needed.
845 if (d_rcode != dq->dh->rcode) {
846 return false;
847 }
848
849 uint16_t optStart;
850 size_t optLen = 0;
851 bool last = false;
852 const char * packet = reinterpret_cast<const char*>(dq->dh);
853 std::string packetStr(packet, dq->len);
854 int res = locateEDNSOptRR(packetStr, &optStart, &optLen, &last);
855 if (res != 0) {
856 // no EDNS OPT RR
857 return d_extrcode == 0;
858 }
859
860 // root label (1), type (2), class (2), ttl (4) + rdlen (2)
861 if (optLen < 11) {
862 return false;
863 }
864
865 if (optStart < dq->len && packet[optStart] != 0) {
866 // OPT RR Name != '.'
867 return false;
868 }
869 EDNS0Record edns0;
870 static_assert(sizeof(EDNS0Record) == sizeof(uint32_t), "sizeof(EDNS0Record) must match sizeof(uint32_t) AKA RR TTL size");
871 // copy out 4-byte "ttl" (really the EDNS0 record), after root label (1) + type (2) + class (2).
872 memcpy(&edns0, packet + optStart + 5, sizeof edns0);
873
874 return d_extrcode == edns0.extRCode;
875 }
876 string toString() const override
877 {
878 return "ercode=="+ERCode::to_s(d_rcode | (d_extrcode << 4));
879 }
880 private:
881 uint8_t d_rcode; // plain DNS Rcode
882 uint8_t d_extrcode; // upper bits in EDNS0 record
883 };
884
885 class EDNSOptionRule : public DNSRule
886 {
887 public:
888 EDNSOptionRule(uint16_t optcode) : d_optcode(optcode)
889 {
890 }
891 bool matches(const DNSQuestion* dq) const override
892 {
893 uint16_t optStart;
894 size_t optLen = 0;
895 bool last = false;
896 const char * packet = reinterpret_cast<const char*>(dq->dh);
897 std::string packetStr(packet, dq->len);
898 int res = locateEDNSOptRR(packetStr, &optStart, &optLen, &last);
899 if (res != 0) {
900 // no EDNS OPT RR
901 return false;
902 }
903
904 // root label (1), type (2), class (2), ttl (4) + rdlen (2)
905 if (optLen < 11) {
906 return false;
907 }
908
909 if (optStart < dq->len && packetStr.at(optStart) != 0) {
910 // OPT RR Name != '.'
911 return false;
912 }
913
914 return isEDNSOptionInOpt(packetStr, optStart, optLen, d_optcode);
915 }
916 string toString() const override
917 {
918 return "ednsoptcode=="+std::to_string(d_optcode);
919 }
920 private:
921 uint16_t d_optcode;
922 };
923
924 class RDRule : public DNSRule
925 {
926 public:
927 RDRule()
928 {
929 }
930 bool matches(const DNSQuestion* dq) const override
931 {
932 return dq->dh->rd == 1;
933 }
934 string toString() const override
935 {
936 return "rd==1";
937 }
938 };
939
940 class ProbaRule : public DNSRule
941 {
942 public:
943 ProbaRule(double proba) : d_proba(proba)
944 {
945 }
946 bool matches(const DNSQuestion* dq) const override
947 {
948 if(d_proba == 1.0)
949 return true;
950 double rnd = 1.0*random() / RAND_MAX;
951 return rnd > (1.0 - d_proba);
952 }
953 string toString() const override
954 {
955 return "match with prob. " + (boost::format("%0.2f") % d_proba).str();
956 }
957 private:
958 double d_proba;
959 };
960
961 class TagRule : public DNSRule
962 {
963 public:
964 TagRule(const std::string& tag, boost::optional<std::string> value) : d_value(value), d_tag(tag)
965 {
966 }
967 bool matches(const DNSQuestion* dq) const override
968 {
969 if (!dq->qTag) {
970 return false;
971 }
972
973 const auto it = dq->qTag->find(d_tag);
974 if (it == dq->qTag->cend()) {
975 return false;
976 }
977
978 if (!d_value) {
979 return true;
980 }
981
982 return it->second == *d_value;
983 }
984
985 string toString() const override
986 {
987 return "tag '" + d_tag + "' is set" + (d_value ? (" to '" + *d_value + "'") : "");
988 }
989
990 private:
991 boost::optional<std::string> d_value;
992 std::string d_tag;
993 };