]> git.ipfire.org Git - thirdparty/pdns.git/blob - pdns/dnsrulactions.hh
rec: Don't account chained queries more than once
[thirdparty/pdns.git] / pdns / dnsrulactions.hh
1 /*
2 * This file is part of PowerDNS or dnsdist.
3 * Copyright -- PowerDNS.COM B.V. and its contributors
4 *
5 * This program is free software; you can redistribute it and/or modify
6 * it under the terms of version 2 of the GNU General Public License as
7 * published by the Free Software Foundation.
8 *
9 * In addition, for the avoidance of any doubt, permission is granted to
10 * link this program with OpenSSL and to (re)distribute the binaries
11 * produced as the result of such linking.
12 *
13 * This program is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with this program; if not, write to the Free Software
20 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
21 */
22 #include "dnsdist.hh"
23 #include "dnsdist-ecs.hh"
24 #include "dnsname.hh"
25 #include "dolog.hh"
26 #include "ednsoptions.hh"
27 #include "lock.hh"
28 #include "remote_logger.hh"
29 #include "dnsdist-protobuf.hh"
30 #include "dnsparser.hh"
31
32 class MaxQPSIPRule : public DNSRule
33 {
34 public:
35 MaxQPSIPRule(unsigned int qps, unsigned int burst, unsigned int ipv4trunc=32, unsigned int ipv6trunc=64) :
36 d_qps(qps), d_burst(burst), d_ipv4trunc(ipv4trunc), d_ipv6trunc(ipv6trunc)
37 {
38 pthread_rwlock_init(&d_lock, 0);
39 }
40
41 bool matches(const DNSQuestion* dq) const override
42 {
43 ComboAddress zeroport(*dq->remote);
44 zeroport.sin4.sin_port=0;
45 zeroport.truncate(zeroport.sin4.sin_family == AF_INET ? d_ipv4trunc : d_ipv6trunc);
46 {
47 ReadLock r(&d_lock);
48 const auto iter = d_limits.find(zeroport);
49 if (iter != d_limits.end()) {
50 return !iter->second.check();
51 }
52 }
53 {
54 WriteLock w(&d_lock);
55 auto iter = d_limits.find(zeroport);
56 if(iter == d_limits.end()) {
57 iter=d_limits.insert({zeroport,QPSLimiter(d_qps, d_burst)}).first;
58 }
59 return !iter->second.check();
60 }
61 }
62
63 string toString() const override
64 {
65 return "IP (/"+std::to_string(d_ipv4trunc)+", /"+std::to_string(d_ipv6trunc)+") match for QPS over " + std::to_string(d_qps) + " burst "+ std::to_string(d_burst);
66 }
67
68
69 private:
70 mutable pthread_rwlock_t d_lock;
71 mutable std::map<ComboAddress, QPSLimiter> d_limits;
72 unsigned int d_qps, d_burst, d_ipv4trunc, d_ipv6trunc;
73
74 };
75
76 class MaxQPSRule : public DNSRule
77 {
78 public:
79 MaxQPSRule(unsigned int qps)
80 : d_qps(qps, qps)
81 {}
82
83 MaxQPSRule(unsigned int qps, unsigned int burst)
84 : d_qps(qps, burst)
85 {}
86
87
88 bool matches(const DNSQuestion* qd) const override
89 {
90 return d_qps.check();
91 }
92
93 string toString() const override
94 {
95 return "Max " + std::to_string(d_qps.getRate()) + " qps";
96 }
97
98
99 private:
100 mutable QPSLimiter d_qps;
101 };
102
103 class NMGRule : public DNSRule
104 {
105 public:
106 NMGRule(const NetmaskGroup& nmg) : d_nmg(nmg) {}
107 protected:
108 NetmaskGroup d_nmg;
109 };
110
111 class NetmaskGroupRule : public NMGRule
112 {
113 public:
114 NetmaskGroupRule(const NetmaskGroup& nmg, bool src) : NMGRule(nmg)
115 {
116 d_src = src;
117 }
118 bool matches(const DNSQuestion* dq) const override
119 {
120 if(!d_src) {
121 return d_nmg.match(*dq->local);
122 }
123 return d_nmg.match(*dq->remote);
124 }
125
126 string toString() const override
127 {
128 if(!d_src) {
129 return "Dst: "+d_nmg.toString();
130 }
131 return "Src: "+d_nmg.toString();
132 }
133 private:
134 bool d_src;
135 };
136
137 class TimedIPSetRule : public DNSRule, boost::noncopyable
138 {
139 private:
140 struct IPv6 {
141 IPv6(const ComboAddress& ca)
142 {
143 static_assert(sizeof(*this)==16, "IPv6 struct has wrong size");
144 memcpy((char*)this, ca.sin6.sin6_addr.s6_addr, 16);
145 }
146 bool operator==(const IPv6& rhs) const
147 {
148 return a==rhs.a && b==rhs.b;
149 }
150 uint64_t a, b;
151 };
152
153 public:
154 TimedIPSetRule()
155 {
156 pthread_rwlock_init(&d_lock4, 0);
157 pthread_rwlock_init(&d_lock6, 0);
158 }
159 bool matches(const DNSQuestion* dq) const override
160 {
161 if(dq->remote->sin4.sin_family == AF_INET) {
162 ReadLock rl(&d_lock4);
163 auto fnd = d_ip4s.find(dq->remote->sin4.sin_addr.s_addr);
164 if(fnd == d_ip4s.end()) {
165 return false;
166 }
167 return time(0) < fnd->second;
168 } else {
169 ReadLock rl(&d_lock6);
170 auto fnd = d_ip6s.find({*dq->remote});
171 if(fnd == d_ip6s.end()) {
172 return false;
173 }
174 return time(0) < fnd->second;
175 }
176 }
177
178 void add(const ComboAddress& ca, time_t ttd)
179 {
180 // think twice before adding templates here
181 if(ca.sin4.sin_family == AF_INET) {
182 WriteLock rl(&d_lock4);
183 auto res=d_ip4s.insert({ca.sin4.sin_addr.s_addr, ttd});
184 if(!res.second && (time_t)res.first->second < ttd)
185 res.first->second = (uint32_t)ttd;
186 }
187 else {
188 WriteLock rl(&d_lock6);
189 auto res=d_ip6s.insert({{ca}, ttd});
190 if(!res.second && (time_t)res.first->second < ttd)
191 res.first->second = (uint32_t)ttd;
192 }
193 }
194
195 void remove(const ComboAddress& ca)
196 {
197 if(ca.sin4.sin_family == AF_INET) {
198 WriteLock rl(&d_lock4);
199 d_ip4s.erase(ca.sin4.sin_addr.s_addr);
200 }
201 else {
202 WriteLock rl(&d_lock6);
203 d_ip6s.erase({ca});
204 }
205 }
206
207 void clear()
208 {
209 {
210 WriteLock rl(&d_lock4);
211 d_ip4s.clear();
212 }
213 WriteLock rl(&d_lock6);
214 d_ip6s.clear();
215 }
216
217 void cleanup()
218 {
219 time_t now=time(0);
220 {
221 WriteLock rl(&d_lock4);
222
223 for(auto iter = d_ip4s.begin(); iter != d_ip4s.end(); ) {
224 if(iter->second < now)
225 iter=d_ip4s.erase(iter);
226 else
227 ++iter;
228 }
229
230 }
231
232 {
233 WriteLock rl(&d_lock6);
234
235 for(auto iter = d_ip6s.begin(); iter != d_ip6s.end(); ) {
236 if(iter->second < now)
237 iter=d_ip6s.erase(iter);
238 else
239 ++iter;
240 }
241
242 }
243
244 }
245
246 string toString() const override
247 {
248 time_t now=time(0);
249 uint64_t count = 0;
250 {
251 ReadLock rl(&d_lock4);
252 for(const auto& ip : d_ip4s)
253 if(now < ip.second)
254 ++count;
255 }
256 {
257 ReadLock rl(&d_lock6);
258 for(const auto& ip : d_ip6s)
259 if(now < ip.second)
260 ++count;
261 }
262
263 return "Src: "+std::to_string(count)+" ips";
264 }
265 private:
266 struct IPv6Hash
267 {
268 std::size_t operator()(const IPv6& ip) const
269 {
270 auto ah=std::hash<uint64_t>{}(ip.a);
271 auto bh=std::hash<uint64_t>{}(ip.b);
272 return ah & (bh<<1);
273 }
274 };
275 std::unordered_map<IPv6, time_t, IPv6Hash> d_ip6s;
276 std::unordered_map<uint32_t, time_t> d_ip4s;
277 mutable pthread_rwlock_t d_lock4;
278 mutable pthread_rwlock_t d_lock6;
279 };
280
281
282 class AllRule : public DNSRule
283 {
284 public:
285 AllRule() {}
286 bool matches(const DNSQuestion* dq) const override
287 {
288 return true;
289 }
290
291 string toString() const override
292 {
293 return "All";
294 }
295
296 };
297
298
299 class DNSSECRule : public DNSRule
300 {
301 public:
302 DNSSECRule()
303 {
304
305 }
306 bool matches(const DNSQuestion* dq) const override
307 {
308 return dq->dh->cd || (getEDNSZ((const char*)dq->dh, dq->len) & EDNS_HEADER_FLAG_DO); // turns out dig sets ad by default..
309 }
310
311 string toString() const override
312 {
313 return "DNSSEC";
314 }
315 };
316
317 class AndRule : public DNSRule
318 {
319 public:
320 AndRule(const vector<pair<int, shared_ptr<DNSRule> > >& rules)
321 {
322 for(const auto& r : rules)
323 d_rules.push_back(r.second);
324 }
325
326 bool matches(const DNSQuestion* dq) const override
327 {
328 auto iter = d_rules.begin();
329 for(; iter != d_rules.end(); ++iter)
330 if(!(*iter)->matches(dq))
331 break;
332 return iter == d_rules.end();
333 }
334
335 string toString() const override
336 {
337 string ret;
338 for(const auto& rule : d_rules) {
339 if(!ret.empty())
340 ret+= " && ";
341 ret += "("+ rule->toString()+")";
342 }
343 return ret;
344 }
345 private:
346
347 vector<std::shared_ptr<DNSRule> > d_rules;
348
349 };
350
351
352 class OrRule : public DNSRule
353 {
354 public:
355 OrRule(const vector<pair<int, shared_ptr<DNSRule> > >& rules)
356 {
357 for(const auto& r : rules)
358 d_rules.push_back(r.second);
359 }
360
361 bool matches(const DNSQuestion* dq) const override
362 {
363 auto iter = d_rules.begin();
364 for(; iter != d_rules.end(); ++iter)
365 if((*iter)->matches(dq))
366 return true;
367 return false;
368 }
369
370 string toString() const override
371 {
372 string ret;
373 for(const auto& rule : d_rules) {
374 if(!ret.empty())
375 ret+= " || ";
376 ret += "("+ rule->toString()+")";
377 }
378 return ret;
379 }
380 private:
381
382 vector<std::shared_ptr<DNSRule> > d_rules;
383
384 };
385
386
387 class RegexRule : public DNSRule
388 {
389 public:
390 RegexRule(const std::string& regex) : d_regex(regex), d_visual(regex)
391 {
392
393 }
394 bool matches(const DNSQuestion* dq) const override
395 {
396 return d_regex.match(dq->qname->toStringNoDot());
397 }
398
399 string toString() const override
400 {
401 return "Regex: "+d_visual;
402 }
403 private:
404 Regex d_regex;
405 string d_visual;
406 };
407
#ifdef HAVE_RE2
#include <re2/re2.h>
/**
 * Matches when the entire query name (without its trailing dot) matches the
 * RE2 expression, which is compiled with the Latin-1 encoding option.
 */
class RE2Rule : public DNSRule
{
public:
  RE2Rule(const std::string& re2) : d_re2(re2, RE2::Latin1), d_visual(re2)
  {
  }
  bool matches(const DNSQuestion* dq) const override
  {
    const std::string name = dq->qname->toStringNoDot();
    return RE2::FullMatch(name, d_re2);
  }

  string toString() const override
  {
    return std::string("RE2 match: ") + d_visual;
  }
private:
  RE2 d_re2;
  string d_visual; // original pattern text, kept for display
};
#endif
431
432
433 class SuffixMatchNodeRule : public DNSRule
434 {
435 public:
436 SuffixMatchNodeRule(const SuffixMatchNode& smn, bool quiet=false) : d_smn(smn), d_quiet(quiet)
437 {
438 }
439 bool matches(const DNSQuestion* dq) const override
440 {
441 return d_smn.check(*dq->qname);
442 }
443 string toString() const override
444 {
445 if(d_quiet)
446 return "qname==in-set";
447 else
448 return "qname in "+d_smn.toString();
449 }
450 private:
451 SuffixMatchNode d_smn;
452 bool d_quiet;
453 };
454
455 class QNameRule : public DNSRule
456 {
457 public:
458 QNameRule(const DNSName& qname) : d_qname(qname)
459 {
460 }
461 bool matches(const DNSQuestion* dq) const override
462 {
463 return d_qname==*dq->qname;
464 }
465 string toString() const override
466 {
467 return "qname=="+d_qname.toString();
468 }
469 private:
470 DNSName d_qname;
471 };
472
473
474 class QTypeRule : public DNSRule
475 {
476 public:
477 QTypeRule(uint16_t qtype) : d_qtype(qtype)
478 {
479 }
480 bool matches(const DNSQuestion* dq) const override
481 {
482 return d_qtype == dq->qtype;
483 }
484 string toString() const override
485 {
486 QType qt(d_qtype);
487 return "qtype=="+qt.getName();
488 }
489 private:
490 uint16_t d_qtype;
491 };
492
493 class QClassRule : public DNSRule
494 {
495 public:
496 QClassRule(uint16_t qclass) : d_qclass(qclass)
497 {
498 }
499 bool matches(const DNSQuestion* dq) const override
500 {
501 return d_qclass == dq->qclass;
502 }
503 string toString() const override
504 {
505 return "qclass=="+std::to_string(d_qclass);
506 }
507 private:
508 uint16_t d_qclass;
509 };
510
511 class OpcodeRule : public DNSRule
512 {
513 public:
514 OpcodeRule(uint8_t opcode) : d_opcode(opcode)
515 {
516 }
517 bool matches(const DNSQuestion* dq) const override
518 {
519 return d_opcode == dq->dh->opcode;
520 }
521 string toString() const override
522 {
523 return "opcode=="+std::to_string(d_opcode);
524 }
525 private:
526 uint8_t d_opcode;
527 };
528
529 class TCPRule : public DNSRule
530 {
531 public:
532 TCPRule(bool tcp): d_tcp(tcp)
533 {
534 }
535 bool matches(const DNSQuestion* dq) const override
536 {
537 return dq->tcp == d_tcp;
538 }
539 string toString() const override
540 {
541 return (d_tcp ? "TCP" : "UDP");
542 }
543 private:
544 bool d_tcp;
545 };
546
547
548 class NotRule : public DNSRule
549 {
550 public:
551 NotRule(shared_ptr<DNSRule>& rule): d_rule(rule)
552 {
553 }
554 bool matches(const DNSQuestion* dq) const override
555 {
556 return !d_rule->matches(dq);
557 }
558 string toString() const override
559 {
560 return "!("+ d_rule->toString()+")";
561 }
562 private:
563 shared_ptr<DNSRule> d_rule;
564 };
565
566 class RecordsCountRule : public DNSRule
567 {
568 public:
569 RecordsCountRule(uint8_t section, uint16_t minCount, uint16_t maxCount): d_minCount(minCount), d_maxCount(maxCount), d_section(section)
570 {
571 }
572 bool matches(const DNSQuestion* dq) const override
573 {
574 uint16_t count = 0;
575 switch(d_section) {
576 case 0:
577 count = ntohs(dq->dh->qdcount);
578 break;
579 case 1:
580 count = ntohs(dq->dh->ancount);
581 break;
582 case 2:
583 count = ntohs(dq->dh->nscount);
584 break;
585 case 3:
586 count = ntohs(dq->dh->arcount);
587 break;
588 }
589 return count >= d_minCount && count <= d_maxCount;
590 }
591 string toString() const override
592 {
593 string section;
594 switch(d_section) {
595 case 0:
596 section = "QD";
597 break;
598 case 1:
599 section = "AN";
600 break;
601 case 2:
602 section = "NS";
603 break;
604 case 3:
605 section = "AR";
606 break;
607 }
608 return std::to_string(d_minCount) + " <= records in " + section + " <= "+ std::to_string(d_maxCount);
609 }
610 private:
611 uint16_t d_minCount;
612 uint16_t d_maxCount;
613 uint8_t d_section;
614 };
615
616 class RecordsTypeCountRule : public DNSRule
617 {
618 public:
619 RecordsTypeCountRule(uint8_t section, uint16_t type, uint16_t minCount, uint16_t maxCount): d_type(type), d_minCount(minCount), d_maxCount(maxCount), d_section(section)
620 {
621 }
622 bool matches(const DNSQuestion* dq) const override
623 {
624 uint16_t count = 0;
625 switch(d_section) {
626 case 0:
627 count = ntohs(dq->dh->qdcount);
628 break;
629 case 1:
630 count = ntohs(dq->dh->ancount);
631 break;
632 case 2:
633 count = ntohs(dq->dh->nscount);
634 break;
635 case 3:
636 count = ntohs(dq->dh->arcount);
637 break;
638 }
639 if (count < d_minCount) {
640 return false;
641 }
642 count = getRecordsOfTypeCount(reinterpret_cast<const char*>(dq->dh), dq->len, d_section, d_type);
643 return count >= d_minCount && count <= d_maxCount;
644 }
645 string toString() const override
646 {
647 string section;
648 switch(d_section) {
649 case 0:
650 section = "QD";
651 break;
652 case 1:
653 section = "AN";
654 break;
655 case 2:
656 section = "NS";
657 break;
658 case 3:
659 section = "AR";
660 break;
661 }
662 return std::to_string(d_minCount) + " <= " + QType(d_type).getName() + " records in " + section + " <= "+ std::to_string(d_maxCount);
663 }
664 private:
665 uint16_t d_type;
666 uint16_t d_minCount;
667 uint16_t d_maxCount;
668 uint8_t d_section;
669 };
670
671 class TrailingDataRule : public DNSRule
672 {
673 public:
674 TrailingDataRule()
675 {
676 }
677 bool matches(const DNSQuestion* dq) const override
678 {
679 uint16_t length = getDNSPacketLength(reinterpret_cast<const char*>(dq->dh), dq->len);
680 return length < dq->len;
681 }
682 string toString() const override
683 {
684 return "trailing data";
685 }
686 };
687
688 class QNameLabelsCountRule : public DNSRule
689 {
690 public:
691 QNameLabelsCountRule(unsigned int minLabelsCount, unsigned int maxLabelsCount): d_min(minLabelsCount), d_max(maxLabelsCount)
692 {
693 }
694 bool matches(const DNSQuestion* dq) const override
695 {
696 unsigned int count = dq->qname->countLabels();
697 return count < d_min || count > d_max;
698 }
699 string toString() const override
700 {
701 return "labels count < " + std::to_string(d_min) + " || labels count > " + std::to_string(d_max);
702 }
703 private:
704 unsigned int d_min;
705 unsigned int d_max;
706 };
707
708 class QNameWireLengthRule : public DNSRule
709 {
710 public:
711 QNameWireLengthRule(size_t min, size_t max): d_min(min), d_max(max)
712 {
713 }
714 bool matches(const DNSQuestion* dq) const override
715 {
716 size_t const wirelength = dq->qname->wirelength();
717 return wirelength < d_min || wirelength > d_max;
718 }
719 string toString() const override
720 {
721 return "wire length < " + std::to_string(d_min) + " || wire length > " + std::to_string(d_max);
722 }
723 private:
724 size_t d_min;
725 size_t d_max;
726 };
727
728 class RCodeRule : public DNSRule
729 {
730 public:
731 RCodeRule(int rcode) : d_rcode(rcode)
732 {
733 }
734 bool matches(const DNSQuestion* dq) const override
735 {
736 return d_rcode == dq->dh->rcode;
737 }
738 string toString() const override
739 {
740 return "rcode=="+RCode::to_s(d_rcode);
741 }
742 private:
743 int d_rcode;
744 };
745
746 class RDRule : public DNSRule
747 {
748 public:
749 RDRule()
750 {
751 }
752 bool matches(const DNSQuestion* dq) const override
753 {
754 return dq->dh->rd == 1;
755 }
756 string toString() const override
757 {
758 return "rd==1";
759 }
760 };
761
762
763 class ProbaRule : public DNSRule
764 {
765 public:
766 ProbaRule(double proba) : d_proba(proba)
767 {
768 }
769 bool matches(const DNSQuestion* dq) const override;
770 string toString() const override;
771 double d_proba;
772 };
773
774
775 class DropAction : public DNSAction
776 {
777 public:
778 DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
779 {
780 return Action::Drop;
781 }
782 string toString() const override
783 {
784 return "drop";
785 }
786 };
787
788 class AllowAction : public DNSAction
789 {
790 public:
791 DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
792 {
793 return Action::Allow;
794 }
795 string toString() const override
796 {
797 return "allow";
798 }
799 };
800
801
802 class QPSAction : public DNSAction
803 {
804 public:
805 QPSAction(int limit) : d_qps(limit, limit)
806 {}
807 DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
808 {
809 if(d_qps.check())
810 return Action::None;
811 else
812 return Action::Drop;
813 }
814 string toString() const override
815 {
816 return "qps limit to "+std::to_string(d_qps.getRate());
817 }
818 private:
819 QPSLimiter d_qps;
820 };
821
822 class DelayAction : public DNSAction
823 {
824 public:
825 DelayAction(int msec) : d_msec(msec)
826 {}
827 DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
828 {
829 *ruleresult=std::to_string(d_msec);
830 return Action::Delay;
831 }
832 string toString() const override
833 {
834 return "delay by "+std::to_string(d_msec)+ " msec";
835 }
836 private:
837 int d_msec;
838 };
839
840
841 class TeeAction : public DNSAction
842 {
843 public:
844 TeeAction(const ComboAddress& ca, bool addECS=false);
845 ~TeeAction() override;
846 DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override;
847 string toString() const override;
848 std::unordered_map<string, double> getStats() const override;
849
850 private:
851 ComboAddress d_remote;
852 std::thread d_worker;
853 void worker();
854
855 int d_fd;
856 mutable std::atomic<unsigned long> d_senderrors{0};
857 unsigned long d_recverrors{0};
858 mutable std::atomic<unsigned long> d_queries{0};
859 unsigned long d_responses{0};
860 unsigned long d_nxdomains{0};
861 unsigned long d_servfails{0};
862 unsigned long d_refuseds{0};
863 unsigned long d_formerrs{0};
864 unsigned long d_notimps{0};
865 unsigned long d_noerrors{0};
866 mutable unsigned long d_tcpdrops{0};
867 unsigned long d_otherrcode{0};
868 std::atomic<bool> d_pleaseQuit{false};
869 bool d_addECS{false};
870 };
871
872 class PoolAction : public DNSAction
873 {
874 public:
875 PoolAction(const std::string& pool) : d_pool(pool) {}
876 DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
877 {
878 *ruleresult=d_pool;
879 return Action::Pool;
880 }
881 string toString() const override
882 {
883 return "to pool "+d_pool;
884 }
885
886 private:
887 string d_pool;
888 };
889
890
891 class QPSPoolAction : public DNSAction
892 {
893 public:
894 QPSPoolAction(unsigned int limit, const std::string& pool) : d_qps(limit, limit), d_pool(pool) {}
895 DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
896 {
897 if(d_qps.check()) {
898 *ruleresult=d_pool;
899 return Action::Pool;
900 }
901 else
902 return Action::None;
903 }
904 string toString() const override
905 {
906 return "max " +std::to_string(d_qps.getRate())+" to pool "+d_pool;
907 }
908
909 private:
910 QPSLimiter d_qps;
911 string d_pool;
912 };
913
914 class RCodeAction : public DNSAction
915 {
916 public:
917 RCodeAction(int rcode) : d_rcode(rcode) {}
918 DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
919 {
920 dq->dh->rcode = d_rcode;
921 dq->dh->qr = true; // for good measure
922 return Action::HeaderModify;
923 }
924 string toString() const override
925 {
926 return "set rcode "+std::to_string(d_rcode);
927 }
928
929 private:
930 int d_rcode;
931 };
932
933 class TCAction : public DNSAction
934 {
935 public:
936 DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
937 {
938 return Action::Truncate;
939 }
940 string toString() const override
941 {
942 return "tc=1 answer";
943 }
944 };
945
946 class SpoofAction : public DNSAction
947 {
948 public:
949 SpoofAction(const vector<ComboAddress>& addrs) : d_addrs(addrs)
950 {
951 }
952
953 SpoofAction(const string& cname): d_cname(cname) { }
954
955 DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
956 {
957 uint16_t qtype = dq->qtype;
958 // do we even have a response?
959 if(d_cname.empty() && !std::count_if(d_addrs.begin(), d_addrs.end(), [qtype](const ComboAddress& a)
960 {
961 return (qtype == QType::ANY || ((a.sin4.sin_family == AF_INET && qtype == QType::A) ||
962 (a.sin4.sin_family == AF_INET6 && qtype == QType::AAAA)));
963 }))
964 return Action::None;
965
966 vector<ComboAddress> addrs;
967 unsigned int totrdatalen=0;
968 if (!d_cname.empty()) {
969 qtype = QType::CNAME;
970 totrdatalen += d_cname.toDNSString().size();
971 } else {
972 for(const auto& addr : d_addrs) {
973 if(qtype != QType::ANY && ((addr.sin4.sin_family == AF_INET && qtype != QType::A) ||
974 (addr.sin4.sin_family == AF_INET6 && qtype != QType::AAAA)))
975 continue;
976 totrdatalen += addr.sin4.sin_family == AF_INET ? sizeof(addr.sin4.sin_addr.s_addr) : sizeof(addr.sin6.sin6_addr.s6_addr);
977 addrs.push_back(addr);
978 }
979 }
980
981 if(addrs.size() > 1)
982 random_shuffle(addrs.begin(), addrs.end());
983
984 unsigned int consumed=0;
985 DNSName ignore((char*)dq->dh, dq->len, sizeof(dnsheader), false, 0, 0, &consumed);
986
987 if (dq->size < (sizeof(dnsheader) + consumed + 4 + ((d_cname.empty() ? 0 : 1) + addrs.size())*12 /* recordstart */ + totrdatalen)) {
988 return Action::None;
989 }
990
991 dq->len = sizeof(dnsheader) + consumed + 4; // there goes your EDNS
992 char* dest = ((char*)dq->dh) + dq->len;
993
994 dq->dh->qr = true; // for good measure
995 dq->dh->ra = dq->dh->rd; // for good measure
996 dq->dh->ad = false;
997 dq->dh->ancount = 0;
998 dq->dh->arcount = 0; // for now, forget about your EDNS, we're marching over it
999
1000 if(qtype == QType::CNAME) {
1001 string wireData = d_cname.toDNSString(); // Note! This doesn't do compression!
1002 const unsigned char recordstart[]={0xc0, 0x0c, // compressed name
1003 0, (unsigned char) qtype,
1004 0, QClass::IN, // IN
1005 0, 0, 0, 60, // TTL
1006 0, (unsigned char)wireData.length()};
1007 static_assert(sizeof(recordstart) == 12, "sizeof(recordstart) must be equal to 12, otherwise the above check is invalid");
1008
1009 memcpy(dest, recordstart, sizeof(recordstart));
1010 dest += sizeof(recordstart);
1011 memcpy(dest, wireData.c_str(), wireData.length());
1012 dq->len += wireData.length() + sizeof(recordstart);
1013 dq->dh->ancount++;
1014 }
1015 else {
1016 for(const auto& addr : addrs) {
1017 unsigned char rdatalen = addr.sin4.sin_family == AF_INET ? sizeof(addr.sin4.sin_addr.s_addr) : sizeof(addr.sin6.sin6_addr.s6_addr);
1018 const unsigned char recordstart[]={0xc0, 0x0c, // compressed name
1019 0, (unsigned char) (addr.sin4.sin_family == AF_INET ? QType::A : QType::AAAA),
1020 0, QClass::IN, // IN
1021 0, 0, 0, 60, // TTL
1022 0, rdatalen};
1023 static_assert(sizeof(recordstart) == 12, "sizeof(recordstart) must be equal to 12, otherwise the above check is invalid");
1024
1025 memcpy(dest, recordstart, sizeof(recordstart));
1026 dest += sizeof(recordstart);
1027
1028 memcpy(dest,
1029 addr.sin4.sin_family == AF_INET ? (void*)&addr.sin4.sin_addr.s_addr : (void*)&addr.sin6.sin6_addr.s6_addr,
1030 rdatalen);
1031 dest += rdatalen;
1032 dq->len += rdatalen + sizeof(recordstart);
1033 dq->dh->ancount++;
1034 }
1035 }
1036
1037 dq->dh->ancount = htons(dq->dh->ancount);
1038
1039 return Action::HeaderModify;
1040 }
1041
1042 string toString() const override
1043 {
1044 string ret = "spoof in ";
1045 if(!d_cname.empty()) {
1046 ret+=d_cname.toString()+ " ";
1047 } else {
1048 for(const auto& a : d_addrs)
1049 ret += a.toString()+" ";
1050 }
1051 return ret;
1052 }
1053 private:
1054 std::vector<ComboAddress> d_addrs;
1055 DNSName d_cname;
1056 };
1057
1058 class MacAddrAction : public DNSAction
1059 {
1060 public:
1061 MacAddrAction(uint16_t code) : d_code(code)
1062 {}
1063 DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
1064 {
1065 if(dq->dh->arcount)
1066 return Action::None;
1067
1068 string mac = getMACAddress(*dq->remote);
1069 if(mac.empty())
1070 return Action::None;
1071
1072 string optRData;
1073 generateEDNSOption(d_code, mac, optRData);
1074
1075 string res;
1076 generateOptRR(optRData, res);
1077
1078 if ((dq->size - dq->len) < res.length())
1079 return Action::None;
1080
1081 dq->dh->arcount = htons(1);
1082 char* dest = ((char*)dq->dh) + dq->len;
1083 memcpy(dest, res.c_str(), res.length());
1084 dq->len += res.length();
1085
1086 return Action::None;
1087 }
1088 string toString() const override
1089 {
1090 return "add EDNS MAC (code="+std::to_string(d_code)+")";
1091 }
1092 private:
1093 uint16_t d_code{3};
1094 };
1095
1096 class NoRecurseAction : public DNSAction
1097 {
1098 public:
1099 DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
1100 {
1101 dq->dh->rd = false;
1102 return Action::None;
1103 }
1104 string toString() const override
1105 {
1106 return "set rd=0";
1107 }
1108 };
1109
1110 class LogAction : public DNSAction, public boost::noncopyable
1111 {
1112 public:
1113 LogAction() : d_fp(0)
1114 {
1115 }
1116 LogAction(const std::string& str, bool binary=true, bool append=false, bool buffered=true) : d_fname(str), d_binary(binary)
1117 {
1118 if(str.empty())
1119 return;
1120 if(append)
1121 d_fp = fopen(str.c_str(), "a+");
1122 else
1123 d_fp = fopen(str.c_str(), "w");
1124 if(!d_fp)
1125 throw std::runtime_error("Unable to open file '"+str+"' for logging: "+string(strerror(errno)));
1126 if(!buffered)
1127 setbuf(d_fp, 0);
1128 }
1129 ~LogAction() override
1130 {
1131 if(d_fp)
1132 fclose(d_fp);
1133 }
1134 DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
1135 {
1136 if(!d_fp) {
1137 vinfolog("Packet from %s for %s %s with id %d", dq->remote->toStringWithPort(), dq->qname->toString(), QType(dq->qtype).getName(), dq->dh->id);
1138 }
1139 else {
1140 if(d_binary) {
1141 string out = dq->qname->toDNSString();
1142 fwrite(out.c_str(), 1, out.size(), d_fp);
1143 fwrite((void*)&dq->qtype, 1, 2, d_fp);
1144 }
1145 else {
1146 fprintf(d_fp, "Packet from %s for %s %s with id %d\n", dq->remote->toStringWithPort().c_str(), dq->qname->toString().c_str(), QType(dq->qtype).getName().c_str(), dq->dh->id);
1147 }
1148 }
1149 return Action::None;
1150 }
1151 string toString() const override
1152 {
1153 if (!d_fname.empty()) {
1154 return "log to " + d_fname;
1155 }
1156 return "log";
1157 }
1158 private:
1159 string d_fname;
1160 FILE* d_fp{0};
1161 bool d_binary{true};
1162 };
1163
1164
1165 class DisableValidationAction : public DNSAction
1166 {
1167 public:
1168 DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
1169 {
1170 dq->dh->cd = true;
1171 return Action::None;
1172 }
1173 string toString() const override
1174 {
1175 return "set cd=1";
1176 }
1177 };
1178
1179 class SkipCacheAction : public DNSAction
1180 {
1181 public:
1182 DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
1183 {
1184 dq->skipCache = true;
1185 return Action::None;
1186 }
1187 string toString() const override
1188 {
1189 return "skip cache";
1190 }
1191 };
1192
1193 class ECSPrefixLengthAction : public DNSAction
1194 {
1195 public:
1196 ECSPrefixLengthAction(uint16_t v4Length, uint16_t v6Length) : d_v4PrefixLength(v4Length), d_v6PrefixLength(v6Length)
1197 {
1198 }
1199 DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
1200 {
1201 dq->ecsPrefixLength = dq->remote->sin4.sin_family == AF_INET ? d_v4PrefixLength : d_v6PrefixLength;
1202 return Action::None;
1203 }
1204 string toString() const override
1205 {
1206 return "set ECS prefix length to " + std::to_string(d_v4PrefixLength) + "/" + std::to_string(d_v6PrefixLength);
1207 }
1208 private:
1209 uint16_t d_v4PrefixLength;
1210 uint16_t d_v6PrefixLength;
1211 };
1212
1213 class ECSOverrideAction : public DNSAction
1214 {
1215 public:
1216 ECSOverrideAction(bool ecsOverride) : d_ecsOverride(ecsOverride)
1217 {
1218 }
1219 DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
1220 {
1221 dq->ecsOverride = d_ecsOverride;
1222 return Action::None;
1223 }
1224 string toString() const override
1225 {
1226 return "set ECS override to " + std::to_string(d_ecsOverride);
1227 }
1228 private:
1229 bool d_ecsOverride;
1230 };
1231
1232
1233 class DisableECSAction : public DNSAction
1234 {
1235 public:
1236 DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
1237 {
1238 dq->useECS = false;
1239 return Action::None;
1240 }
1241 string toString() const override
1242 {
1243 return "disable ECS";
1244 }
1245 };
1246
1247 class RemoteLogAction : public DNSAction, public boost::noncopyable
1248 {
1249 public:
1250 RemoteLogAction(std::shared_ptr<RemoteLogger> logger, boost::optional<std::function<void(const DNSQuestion&, DNSDistProtoBufMessage*)> > alterFunc): d_logger(logger), d_alterFunc(alterFunc)
1251 {
1252 }
1253 DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
1254 {
1255 #ifdef HAVE_PROTOBUF
1256 if (!dq->uniqueId) {
1257 dq->uniqueId = t_uuidGenerator();
1258 }
1259
1260 DNSDistProtoBufMessage message(*dq);
1261 {
1262 if (d_alterFunc) {
1263 std::lock_guard<std::mutex> lock(g_luamutex);
1264 (*d_alterFunc)(*dq, &message);
1265 }
1266 }
1267 std::string data;
1268 message.serialize(data);
1269 d_logger->queueData(data);
1270 #endif /* HAVE_PROTOBUF */
1271 return Action::None;
1272 }
1273 string toString() const override
1274 {
1275 return "remote log to " + (d_logger ? d_logger->toString() : "");
1276 }
1277 private:
1278 std::shared_ptr<RemoteLogger> d_logger;
1279 boost::optional<std::function<void(const DNSQuestion&, DNSDistProtoBufMessage*)> > d_alterFunc;
1280 };
1281
1282 class SNMPTrapAction : public DNSAction
1283 {
1284 public:
1285 SNMPTrapAction(const std::string& reason): d_reason(reason)
1286 {
1287 }
1288 DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
1289 {
1290 if (g_snmpAgent && g_snmpTrapsEnabled) {
1291 g_snmpAgent->sendDNSTrap(*dq, d_reason);
1292 }
1293
1294 return Action::None;
1295 }
1296 string toString() const override
1297 {
1298 return "send SNMP trap";
1299 }
1300 private:
1301 std::string d_reason;
1302 };
1303
1304 class RemoteLogResponseAction : public DNSResponseAction, public boost::noncopyable
1305 {
1306 public:
1307 RemoteLogResponseAction(std::shared_ptr<RemoteLogger> logger, boost::optional<std::function<void(const DNSResponse&, DNSDistProtoBufMessage*)> > alterFunc, bool includeCNAME): d_logger(logger), d_alterFunc(alterFunc), d_includeCNAME(includeCNAME)
1308 {
1309 }
1310 DNSResponseAction::Action operator()(DNSResponse* dr, string* ruleresult) const override
1311 {
1312 #ifdef HAVE_PROTOBUF
1313 if (!dr->uniqueId) {
1314 dr->uniqueId = t_uuidGenerator();
1315 }
1316
1317 DNSDistProtoBufMessage message(*dr, d_includeCNAME);
1318 {
1319 if (d_alterFunc) {
1320 std::lock_guard<std::mutex> lock(g_luamutex);
1321 (*d_alterFunc)(*dr, &message);
1322 }
1323 }
1324 std::string data;
1325 message.serialize(data);
1326 d_logger->queueData(data);
1327 #endif /* HAVE_PROTOBUF */
1328 return Action::None;
1329 }
1330 string toString() const override
1331 {
1332 return "remote log response to " + (d_logger ? d_logger->toString() : "");
1333 }
1334 private:
1335 std::shared_ptr<RemoteLogger> d_logger;
1336 boost::optional<std::function<void(const DNSResponse&, DNSDistProtoBufMessage*)> > d_alterFunc;
1337 bool d_includeCNAME;
1338 };
1339
1340 class DropResponseAction : public DNSResponseAction
1341 {
1342 public:
1343 DNSResponseAction::Action operator()(DNSResponse* dr, string* ruleresult) const override
1344 {
1345 return Action::Drop;
1346 }
1347 string toString() const override
1348 {
1349 return "drop";
1350 }
1351 };
1352
1353 class AllowResponseAction : public DNSResponseAction
1354 {
1355 public:
1356 DNSResponseAction::Action operator()(DNSResponse* dr, string* ruleresult) const override
1357 {
1358 return Action::Allow;
1359 }
1360 string toString() const override
1361 {
1362 return "allow";
1363 }
1364 };
1365
1366 class DelayResponseAction : public DNSResponseAction
1367 {
1368 public:
1369 DelayResponseAction(int msec) : d_msec(msec)
1370 {}
1371 DNSResponseAction::Action operator()(DNSResponse* dr, string* ruleresult) const override
1372 {
1373 *ruleresult=std::to_string(d_msec);
1374 return Action::Delay;
1375 }
1376 string toString() const override
1377 {
1378 return "delay by "+std::to_string(d_msec)+ " msec";
1379 }
1380 private:
1381 int d_msec;
1382 };
1383
1384 class SNMPTrapResponseAction : public DNSResponseAction
1385 {
1386 public:
1387 SNMPTrapResponseAction(const std::string& reason): d_reason(reason)
1388 {
1389 }
1390 DNSResponseAction::Action operator()(DNSResponse* dr, string* ruleresult) const override
1391 {
1392 if (g_snmpAgent && g_snmpTrapsEnabled) {
1393 g_snmpAgent->sendDNSTrap(*dr, d_reason);
1394 }
1395
1396 return Action::None;
1397 }
1398 string toString() const override
1399 {
1400 return "send SNMP trap";
1401 }
1402 private:
1403 std::string d_reason;
1404 };