]> git.ipfire.org Git - thirdparty/pdns.git/blob - pdns/dnsrulactions.hh
Merge pull request #4692 from cmouse/ssql-unique-ptr
[thirdparty/pdns.git] / pdns / dnsrulactions.hh
1 /*
2 * This file is part of PowerDNS or dnsdist.
3 * Copyright -- PowerDNS.COM B.V. and its contributors
4 *
5 * This program is free software; you can redistribute it and/or modify
6 * it under the terms of version 2 of the GNU General Public License as
7 * published by the Free Software Foundation.
8 *
9 * In addition, for the avoidance of any doubt, permission is granted to
10 * link this program with OpenSSL and to (re)distribute the binaries
11 * produced as the result of such linking.
12 *
13 * This program is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with this program; if not, write to the Free Software
20 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
21 */
22 #include "dnsdist.hh"
23 #include "dnsdist-ecs.hh"
24 #include "dnsname.hh"
25 #include "dolog.hh"
26 #include "ednsoptions.hh"
27 #include "lock.hh"
28 #include "remote_logger.hh"
29 #include "dnsdist-protobuf.hh"
30 #include "dnsparser.hh"
31
32 class MaxQPSIPRule : public DNSRule
33 {
34 public:
35 MaxQPSIPRule(unsigned int qps, unsigned int ipv4trunc=32, unsigned int ipv6trunc=64) :
36 d_qps(qps), d_ipv4trunc(ipv4trunc), d_ipv6trunc(ipv6trunc)
37 {
38 pthread_rwlock_init(&d_lock, 0);
39 }
40
41 bool matches(const DNSQuestion* dq) const override
42 {
43 ComboAddress zeroport(*dq->remote);
44 zeroport.sin4.sin_port=0;
45 zeroport.truncate(zeroport.sin4.sin_family == AF_INET ? d_ipv4trunc : d_ipv6trunc);
46 {
47 ReadLock r(&d_lock);
48 const auto iter = d_limits.find(zeroport);
49 if (iter != d_limits.end()) {
50 return !iter->second.check();
51 }
52 }
53 {
54 WriteLock w(&d_lock);
55 auto iter = d_limits.find(zeroport);
56 if(iter == d_limits.end()) {
57 iter=d_limits.insert({zeroport,QPSLimiter(d_qps, d_qps)}).first;
58 }
59 return !iter->second.check();
60 }
61 }
62
63 string toString() const override
64 {
65 return "IP (/"+std::to_string(d_ipv4trunc)+", /"+std::to_string(d_ipv6trunc)+") match for QPS over " + std::to_string(d_qps);
66 }
67
68
69 private:
70 mutable pthread_rwlock_t d_lock;
71 mutable std::map<ComboAddress, QPSLimiter> d_limits;
72 unsigned int d_qps, d_ipv4trunc, d_ipv6trunc;
73
74 };
75
76 class MaxQPSRule : public DNSRule
77 {
78 public:
79 MaxQPSRule(unsigned int qps)
80 : d_qps(qps, qps)
81 {}
82
83 MaxQPSRule(unsigned int qps, unsigned int burst)
84 : d_qps(qps, burst)
85 {}
86
87
88 bool matches(const DNSQuestion* qd) const override
89 {
90 return d_qps.check();
91 }
92
93 string toString() const override
94 {
95 return "Max " + std::to_string(d_qps.getRate()) + " qps";
96 }
97
98
99 private:
100 mutable QPSLimiter d_qps;
101 };
102
103 class NMGRule : public DNSRule
104 {
105 public:
106 NMGRule(const NetmaskGroup& nmg) : d_nmg(nmg) {}
107 protected:
108 NetmaskGroup d_nmg;
109 };
110
111 class NetmaskGroupRule : public NMGRule
112 {
113 public:
114 NetmaskGroupRule(const NetmaskGroup& nmg, bool src) : NMGRule(nmg)
115 {
116 d_src = src;
117 }
118 bool matches(const DNSQuestion* dq) const override
119 {
120 if(!d_src) {
121 return d_nmg.match(*dq->local);
122 }
123 return d_nmg.match(*dq->remote);
124 }
125
126 string toString() const override
127 {
128 if(!d_src) {
129 return "Dst: "+d_nmg.toString();
130 }
131 return "Src: "+d_nmg.toString();
132 }
133 private:
134 bool d_src;
135 };
136
137 class TimedIPSetRule : public DNSRule, boost::noncopyable
138 {
139 private:
140 struct IPv6 {
141 IPv6(const ComboAddress& ca)
142 {
143 static_assert(sizeof(*this)==16, "IPv6 struct has wrong size");
144 memcpy((char*)this, ca.sin6.sin6_addr.s6_addr, 16);
145 }
146 bool operator==(const IPv6& rhs) const
147 {
148 return a==rhs.a && b==rhs.b;
149 }
150 uint64_t a, b;
151 };
152
153 public:
154 TimedIPSetRule()
155 {
156 pthread_rwlock_init(&d_lock4, 0);
157 pthread_rwlock_init(&d_lock6, 0);
158 }
159 bool matches(const DNSQuestion* dq) const override
160 {
161 if(dq->remote->sin4.sin_family == AF_INET) {
162 ReadLock rl(&d_lock4);
163 auto fnd = d_ip4s.find(dq->remote->sin4.sin_addr.s_addr);
164 if(fnd == d_ip4s.end()) {
165 return false;
166 }
167 return time(0) < fnd->second;
168 } else {
169 ReadLock rl(&d_lock6);
170 auto fnd = d_ip6s.find({*dq->remote});
171 if(fnd == d_ip6s.end()) {
172 return false;
173 }
174 return time(0) < fnd->second;
175 }
176 }
177
178 void add(const ComboAddress& ca, time_t ttd)
179 {
180 // think twice before adding templates here
181 if(ca.sin4.sin_family == AF_INET) {
182 WriteLock rl(&d_lock4);
183 auto res=d_ip4s.insert({ca.sin4.sin_addr.s_addr, ttd});
184 if(!res.second && (time_t)res.first->second < ttd)
185 res.first->second = (uint32_t)ttd;
186 }
187 else {
188 WriteLock rl(&d_lock6);
189 auto res=d_ip6s.insert({{ca}, ttd});
190 if(!res.second && (time_t)res.first->second < ttd)
191 res.first->second = (uint32_t)ttd;
192 }
193 }
194
195 void remove(const ComboAddress& ca)
196 {
197 if(ca.sin4.sin_family == AF_INET) {
198 WriteLock rl(&d_lock4);
199 d_ip4s.erase(ca.sin4.sin_addr.s_addr);
200 }
201 else {
202 WriteLock rl(&d_lock6);
203 d_ip6s.erase({ca});
204 }
205 }
206
207 void clear()
208 {
209 {
210 WriteLock rl(&d_lock4);
211 d_ip4s.clear();
212 }
213 WriteLock rl(&d_lock6);
214 d_ip6s.clear();
215 }
216
217 void cleanup()
218 {
219 time_t now=time(0);
220 {
221 WriteLock rl(&d_lock4);
222
223 for(auto iter = d_ip4s.begin(); iter != d_ip4s.end(); ) {
224 if(iter->second < now)
225 iter=d_ip4s.erase(iter);
226 else
227 ++iter;
228 }
229
230 }
231
232 {
233 WriteLock rl(&d_lock6);
234
235 for(auto iter = d_ip6s.begin(); iter != d_ip6s.end(); ) {
236 if(iter->second < now)
237 iter=d_ip6s.erase(iter);
238 else
239 ++iter;
240 }
241
242 }
243
244 }
245
246 string toString() const override
247 {
248 time_t now=time(0);
249 uint64_t count = 0;
250 {
251 ReadLock rl(&d_lock4);
252 for(const auto& ip : d_ip4s)
253 if(now < ip.second)
254 ++count;
255 }
256 {
257 ReadLock rl(&d_lock6);
258 for(const auto& ip : d_ip6s)
259 if(now < ip.second)
260 ++count;
261 }
262
263 return "Src: "+std::to_string(count)+" ips";
264 }
265 private:
266 struct IPv6Hash
267 {
268 std::size_t operator()(const IPv6& ip) const
269 {
270 auto ah=std::hash<uint64_t>{}(ip.a);
271 auto bh=std::hash<uint64_t>{}(ip.b);
272 return ah & (bh<<1);
273 }
274 };
275 std::unordered_map<IPv6, uint32_t, IPv6Hash> d_ip6s;
276 std::unordered_map<uint32_t, uint32_t> d_ip4s;
277 mutable pthread_rwlock_t d_lock4;
278 mutable pthread_rwlock_t d_lock6;
279 };
280
281
282 class AllRule : public DNSRule
283 {
284 public:
285 AllRule() {}
286 bool matches(const DNSQuestion* dq) const override
287 {
288 return true;
289 }
290
291 string toString() const override
292 {
293 return "All";
294 }
295
296 };
297
298
299 class DNSSECRule : public DNSRule
300 {
301 public:
302 DNSSECRule()
303 {
304
305 }
306 bool matches(const DNSQuestion* dq) const override
307 {
308 return dq->dh->cd || (getEDNSZ((const char*)dq->dh, dq->len) & EDNS_HEADER_FLAG_DO); // turns out dig sets ad by default..
309 }
310
311 string toString() const override
312 {
313 return "DNSSEC";
314 }
315 };
316
317 class AndRule : public DNSRule
318 {
319 public:
320 AndRule(const vector<pair<int, shared_ptr<DNSRule> > >& rules)
321 {
322 for(const auto& r : rules)
323 d_rules.push_back(r.second);
324 }
325
326 bool matches(const DNSQuestion* dq) const override
327 {
328 auto iter = d_rules.begin();
329 for(; iter != d_rules.end(); ++iter)
330 if(!(*iter)->matches(dq))
331 break;
332 return iter == d_rules.end();
333 }
334
335 string toString() const override
336 {
337 string ret;
338 for(const auto& rule : d_rules) {
339 if(!ret.empty())
340 ret+= " && ";
341 ret += "("+ rule->toString()+")";
342 }
343 return ret;
344 }
345 private:
346
347 vector<std::shared_ptr<DNSRule> > d_rules;
348
349 };
350
351
352 class OrRule : public DNSRule
353 {
354 public:
355 OrRule(const vector<pair<int, shared_ptr<DNSRule> > >& rules)
356 {
357 for(const auto& r : rules)
358 d_rules.push_back(r.second);
359 }
360
361 bool matches(const DNSQuestion* dq) const override
362 {
363 auto iter = d_rules.begin();
364 for(; iter != d_rules.end(); ++iter)
365 if((*iter)->matches(dq))
366 return true;
367 return false;
368 }
369
370 string toString() const override
371 {
372 string ret;
373 for(const auto& rule : d_rules) {
374 if(!ret.empty())
375 ret+= " || ";
376 ret += "("+ rule->toString()+")";
377 }
378 return ret;
379 }
380 private:
381
382 vector<std::shared_ptr<DNSRule> > d_rules;
383
384 };
385
386
387 class RegexRule : public DNSRule
388 {
389 public:
390 RegexRule(const std::string& regex) : d_regex(regex), d_visual(regex)
391 {
392
393 }
394 bool matches(const DNSQuestion* dq) const override
395 {
396 return d_regex.match(dq->qname->toStringNoDot());
397 }
398
399 string toString() const override
400 {
401 return "Regex: "+d_visual;
402 }
403 private:
404 Regex d_regex;
405 string d_visual;
406 };
407
408 #ifdef HAVE_RE2
409 #include <re2/re2.h>
410 class RE2Rule : public DNSRule
411 {
412 public:
413 RE2Rule(const std::string& re2) : d_re2(re2, RE2::Latin1), d_visual(re2)
414 {
415
416 }
417 bool matches(const DNSQuestion* dq) const override
418 {
419 return RE2::FullMatch(dq->qname->toStringNoDot(), d_re2);
420 }
421
422 string toString() const override
423 {
424 return "RE2 match: "+d_visual;
425 }
426 private:
427 RE2 d_re2;
428 string d_visual;
429 };
430 #endif
431
432
433 class SuffixMatchNodeRule : public DNSRule
434 {
435 public:
436 SuffixMatchNodeRule(const SuffixMatchNode& smn, bool quiet=false) : d_smn(smn), d_quiet(quiet)
437 {
438 }
439 bool matches(const DNSQuestion* dq) const override
440 {
441 return d_smn.check(*dq->qname);
442 }
443 string toString() const override
444 {
445 if(d_quiet)
446 return "qname==in-set";
447 else
448 return "qname in "+d_smn.toString();
449 }
450 private:
451 SuffixMatchNode d_smn;
452 bool d_quiet;
453 };
454
455 class QNameRule : public DNSRule
456 {
457 public:
458 QNameRule(const DNSName& qname) : d_qname(qname)
459 {
460 }
461 bool matches(const DNSQuestion* dq) const override
462 {
463 return d_qname==*dq->qname;
464 }
465 string toString() const override
466 {
467 return "qname=="+d_qname.toString();
468 }
469 private:
470 DNSName d_qname;
471 };
472
473
474 class QTypeRule : public DNSRule
475 {
476 public:
477 QTypeRule(uint16_t qtype) : d_qtype(qtype)
478 {
479 }
480 bool matches(const DNSQuestion* dq) const override
481 {
482 return d_qtype == dq->qtype;
483 }
484 string toString() const override
485 {
486 QType qt(d_qtype);
487 return "qtype=="+qt.getName();
488 }
489 private:
490 uint16_t d_qtype;
491 };
492
493 class QClassRule : public DNSRule
494 {
495 public:
496 QClassRule(uint16_t qclass) : d_qclass(qclass)
497 {
498 }
499 bool matches(const DNSQuestion* dq) const override
500 {
501 return d_qclass == dq->qclass;
502 }
503 string toString() const override
504 {
505 return "qclass=="+std::to_string(d_qclass);
506 }
507 private:
508 uint16_t d_qclass;
509 };
510
511 class OpcodeRule : public DNSRule
512 {
513 public:
514 OpcodeRule(uint8_t opcode) : d_opcode(opcode)
515 {
516 }
517 bool matches(const DNSQuestion* dq) const override
518 {
519 return d_opcode == dq->dh->opcode;
520 }
521 string toString() const override
522 {
523 return "opcode=="+std::to_string(d_opcode);
524 }
525 private:
526 uint8_t d_opcode;
527 };
528
529 class TCPRule : public DNSRule
530 {
531 public:
532 TCPRule(bool tcp): d_tcp(tcp)
533 {
534 }
535 bool matches(const DNSQuestion* dq) const override
536 {
537 return dq->tcp == d_tcp;
538 }
539 string toString() const override
540 {
541 return (d_tcp ? "TCP" : "UDP");
542 }
543 private:
544 bool d_tcp;
545 };
546
547
548 class NotRule : public DNSRule
549 {
550 public:
551 NotRule(shared_ptr<DNSRule>& rule): d_rule(rule)
552 {
553 }
554 bool matches(const DNSQuestion* dq) const override
555 {
556 return !d_rule->matches(dq);
557 }
558 string toString() const override
559 {
560 return "!("+ d_rule->toString()+")";
561 }
562 private:
563 shared_ptr<DNSRule> d_rule;
564 };
565
566 class RecordsCountRule : public DNSRule
567 {
568 public:
569 RecordsCountRule(uint8_t section, uint16_t minCount, uint16_t maxCount): d_minCount(minCount), d_maxCount(maxCount), d_section(section)
570 {
571 }
572 bool matches(const DNSQuestion* dq) const override
573 {
574 uint16_t count = 0;
575 switch(d_section) {
576 case 0:
577 count = ntohs(dq->dh->qdcount);
578 break;
579 case 1:
580 count = ntohs(dq->dh->ancount);
581 break;
582 case 2:
583 count = ntohs(dq->dh->nscount);
584 break;
585 case 3:
586 count = ntohs(dq->dh->arcount);
587 break;
588 }
589 return count >= d_minCount && count <= d_maxCount;
590 }
591 string toString() const override
592 {
593 string section;
594 switch(d_section) {
595 case 0:
596 section = "QD";
597 break;
598 case 1:
599 section = "AN";
600 break;
601 case 2:
602 section = "NS";
603 break;
604 case 3:
605 section = "AR";
606 break;
607 }
608 return std::to_string(d_minCount) + " <= records in " + section + " <= "+ std::to_string(d_maxCount);
609 }
610 private:
611 uint16_t d_minCount;
612 uint16_t d_maxCount;
613 uint8_t d_section;
614 };
615
616 class RecordsTypeCountRule : public DNSRule
617 {
618 public:
619 RecordsTypeCountRule(uint8_t section, uint16_t type, uint16_t minCount, uint16_t maxCount): d_type(type), d_minCount(minCount), d_maxCount(maxCount), d_section(section)
620 {
621 }
622 bool matches(const DNSQuestion* dq) const override
623 {
624 uint16_t count = 0;
625 switch(d_section) {
626 case 0:
627 count = ntohs(dq->dh->qdcount);
628 break;
629 case 1:
630 count = ntohs(dq->dh->ancount);
631 break;
632 case 2:
633 count = ntohs(dq->dh->nscount);
634 break;
635 case 3:
636 count = ntohs(dq->dh->arcount);
637 break;
638 }
639 if (count < d_minCount) {
640 return false;
641 }
642 count = getRecordsOfTypeCount(reinterpret_cast<const char*>(dq->dh), dq->len, d_section, d_type);
643 return count >= d_minCount && count <= d_maxCount;
644 }
645 string toString() const override
646 {
647 string section;
648 switch(d_section) {
649 case 0:
650 section = "QD";
651 break;
652 case 1:
653 section = "AN";
654 break;
655 case 2:
656 section = "NS";
657 break;
658 case 3:
659 section = "AR";
660 break;
661 }
662 return std::to_string(d_minCount) + " <= " + QType(d_type).getName() + " records in " + section + " <= "+ std::to_string(d_maxCount);
663 }
664 private:
665 uint16_t d_type;
666 uint16_t d_minCount;
667 uint16_t d_maxCount;
668 uint8_t d_section;
669 };
670
671 class TrailingDataRule : public DNSRule
672 {
673 public:
674 TrailingDataRule()
675 {
676 }
677 bool matches(const DNSQuestion* dq) const override
678 {
679 uint16_t length = getDNSPacketLength(reinterpret_cast<const char*>(dq->dh), dq->len);
680 return length < dq->len;
681 }
682 string toString() const override
683 {
684 return "trailing data";
685 }
686 };
687
688 class QNameLabelsCountRule : public DNSRule
689 {
690 public:
691 QNameLabelsCountRule(unsigned int minLabelsCount, unsigned int maxLabelsCount): d_min(minLabelsCount), d_max(maxLabelsCount)
692 {
693 }
694 bool matches(const DNSQuestion* dq) const override
695 {
696 unsigned int count = dq->qname->countLabels();
697 return count < d_min || count > d_max;
698 }
699 string toString() const override
700 {
701 return "labels count < " + std::to_string(d_min) + " || labels count > " + std::to_string(d_max);
702 }
703 private:
704 unsigned int d_min;
705 unsigned int d_max;
706 };
707
708 class QNameWireLengthRule : public DNSRule
709 {
710 public:
711 QNameWireLengthRule(size_t min, size_t max): d_min(min), d_max(max)
712 {
713 }
714 bool matches(const DNSQuestion* dq) const override
715 {
716 size_t const wirelength = dq->qname->wirelength();
717 return wirelength < d_min || wirelength > d_max;
718 }
719 string toString() const override
720 {
721 return "wire length < " + std::to_string(d_min) + " || wire length > " + std::to_string(d_max);
722 }
723 private:
724 size_t d_min;
725 size_t d_max;
726 };
727
728 class RCodeRule : public DNSRule
729 {
730 public:
731 RCodeRule(int rcode) : d_rcode(rcode)
732 {
733 }
734 bool matches(const DNSQuestion* dq) const override
735 {
736 return d_rcode == dq->dh->rcode;
737 }
738 string toString() const override
739 {
740 return "rcode=="+RCode::to_s(d_rcode);
741 }
742 private:
743 int d_rcode;
744 };
745
746 class RDRule : public DNSRule
747 {
748 public:
749 RDRule()
750 {
751 }
752 bool matches(const DNSQuestion* dq) const override
753 {
754 return dq->dh->rd == 1;
755 }
756 string toString() const override
757 {
758 return "rd==1";
759 }
760 };
761
762
763 class DropAction : public DNSAction
764 {
765 public:
766 DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
767 {
768 return Action::Drop;
769 }
770 string toString() const override
771 {
772 return "drop";
773 }
774 };
775
776 class AllowAction : public DNSAction
777 {
778 public:
779 DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
780 {
781 return Action::Allow;
782 }
783 string toString() const override
784 {
785 return "allow";
786 }
787 };
788
789
790 class QPSAction : public DNSAction
791 {
792 public:
793 QPSAction(int limit) : d_qps(limit, limit)
794 {}
795 DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
796 {
797 if(d_qps.check())
798 return Action::None;
799 else
800 return Action::Drop;
801 }
802 string toString() const override
803 {
804 return "qps limit to "+std::to_string(d_qps.getRate());
805 }
806 private:
807 QPSLimiter d_qps;
808 };
809
810 class DelayAction : public DNSAction
811 {
812 public:
813 DelayAction(int msec) : d_msec(msec)
814 {}
815 DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
816 {
817 *ruleresult=std::to_string(d_msec);
818 return Action::Delay;
819 }
820 string toString() const override
821 {
822 return "delay by "+std::to_string(d_msec)+ " msec";
823 }
824 private:
825 int d_msec;
826 };
827
828
829 class TeeAction : public DNSAction
830 {
831 public:
832 TeeAction(const ComboAddress& ca, bool addECS=false);
833 ~TeeAction() override;
834 DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override;
835 string toString() const override;
836 std::unordered_map<string, double> getStats() const override;
837 private:
838 ComboAddress d_remote;
839 std::thread d_worker;
840 void worker();
841
842 int d_fd;
843 mutable std::atomic<unsigned long> d_senderrors{0};
844 unsigned long d_recverrors{0};
845 mutable std::atomic<unsigned long> d_queries{0};
846 unsigned long d_responses{0};
847 unsigned long d_nxdomains{0};
848 unsigned long d_servfails{0};
849 unsigned long d_refuseds{0};
850 unsigned long d_formerrs{0};
851 unsigned long d_notimps{0};
852 unsigned long d_noerrors{0};
853 mutable unsigned long d_tcpdrops{0};
854 unsigned long d_otherrcode{0};
855 std::atomic<bool> d_pleaseQuit{false};
856 bool d_addECS{false};
857 };
858
859
860
861 class PoolAction : public DNSAction
862 {
863 public:
864 PoolAction(const std::string& pool) : d_pool(pool) {}
865 DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
866 {
867 *ruleresult=d_pool;
868 return Action::Pool;
869 }
870 string toString() const override
871 {
872 return "to pool "+d_pool;
873 }
874
875 private:
876 string d_pool;
877 };
878
879
880 class QPSPoolAction : public DNSAction
881 {
882 public:
883 QPSPoolAction(unsigned int limit, const std::string& pool) : d_qps(limit, limit), d_pool(pool) {}
884 DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
885 {
886 if(d_qps.check()) {
887 *ruleresult=d_pool;
888 return Action::Pool;
889 }
890 else
891 return Action::None;
892 }
893 string toString() const override
894 {
895 return "max " +std::to_string(d_qps.getRate())+" to pool "+d_pool;
896 }
897
898 private:
899 QPSLimiter d_qps;
900 string d_pool;
901 };
902
903 class RCodeAction : public DNSAction
904 {
905 public:
906 RCodeAction(int rcode) : d_rcode(rcode) {}
907 DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
908 {
909 dq->dh->rcode = d_rcode;
910 dq->dh->qr = true; // for good measure
911 return Action::HeaderModify;
912 }
913 string toString() const override
914 {
915 return "set rcode "+std::to_string(d_rcode);
916 }
917
918 private:
919 int d_rcode;
920 };
921
922 class TCAction : public DNSAction
923 {
924 public:
925 DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
926 {
927 return Action::Truncate;
928 }
929 string toString() const override
930 {
931 return "tc=1 answer";
932 }
933 };
934
935 class SpoofAction : public DNSAction
936 {
937 public:
938 SpoofAction(const vector<ComboAddress>& addrs) : d_addrs(addrs)
939 {
940 }
941
942 SpoofAction(const string& cname): d_cname(cname) { }
943
944 DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
945 {
946 uint16_t qtype = dq->qtype;
947 // do we even have a response?
948 if(d_cname.empty() && !std::count_if(d_addrs.begin(), d_addrs.end(), [qtype](const ComboAddress& a)
949 {
950 return (qtype == QType::ANY || ((a.sin4.sin_family == AF_INET && qtype == QType::A) ||
951 (a.sin4.sin_family == AF_INET6 && qtype == QType::AAAA)));
952 }))
953 return Action::None;
954
955 vector<ComboAddress> addrs;
956 unsigned int totrdatalen=0;
957 if (!d_cname.empty()) {
958 qtype = QType::CNAME;
959 totrdatalen += d_cname.toDNSString().size();
960 } else {
961 for(const auto& addr : d_addrs) {
962 if(qtype != QType::ANY && ((addr.sin4.sin_family == AF_INET && qtype != QType::A) ||
963 (addr.sin4.sin_family == AF_INET6 && qtype != QType::AAAA)))
964 continue;
965 totrdatalen += addr.sin4.sin_family == AF_INET ? sizeof(addr.sin4.sin_addr.s_addr) : sizeof(addr.sin6.sin6_addr.s6_addr);
966 addrs.push_back(addr);
967 }
968 }
969
970 if(addrs.size() > 1)
971 random_shuffle(addrs.begin(), addrs.end());
972
973 unsigned int consumed=0;
974 DNSName ignore((char*)dq->dh, dq->len, sizeof(dnsheader), false, 0, 0, &consumed);
975
976 if (dq->size < (sizeof(dnsheader) + consumed + 4 + ((d_cname.empty() ? 0 : 1) + addrs.size())*12 /* recordstart */ + totrdatalen)) {
977 return Action::None;
978 }
979
980 dq->len = sizeof(dnsheader) + consumed + 4; // there goes your EDNS
981 char* dest = ((char*)dq->dh) + dq->len;
982
983 dq->dh->qr = true; // for good measure
984 dq->dh->ra = dq->dh->rd; // for good measure
985 dq->dh->ad = false;
986 dq->dh->ancount = 0;
987 dq->dh->arcount = 0; // for now, forget about your EDNS, we're marching over it
988
989 if(qtype == QType::CNAME) {
990 string wireData = d_cname.toDNSString(); // Note! This doesn't do compression!
991 const unsigned char recordstart[]={0xc0, 0x0c, // compressed name
992 0, (unsigned char) qtype,
993 0, QClass::IN, // IN
994 0, 0, 0, 60, // TTL
995 0, (unsigned char)wireData.length()};
996 static_assert(sizeof(recordstart) == 12, "sizeof(recordstart) must be equal to 12, otherwise the above check is invalid");
997
998 memcpy(dest, recordstart, sizeof(recordstart));
999 dest += sizeof(recordstart);
1000 memcpy(dest, wireData.c_str(), wireData.length());
1001 dq->len += wireData.length() + sizeof(recordstart);
1002 dq->dh->ancount++;
1003 }
1004 else {
1005 for(const auto& addr : addrs) {
1006 unsigned char rdatalen = addr.sin4.sin_family == AF_INET ? sizeof(addr.sin4.sin_addr.s_addr) : sizeof(addr.sin6.sin6_addr.s6_addr);
1007 const unsigned char recordstart[]={0xc0, 0x0c, // compressed name
1008 0, (unsigned char) (addr.sin4.sin_family == AF_INET ? QType::A : QType::AAAA),
1009 0, QClass::IN, // IN
1010 0, 0, 0, 60, // TTL
1011 0, rdatalen};
1012 static_assert(sizeof(recordstart) == 12, "sizeof(recordstart) must be equal to 12, otherwise the above check is invalid");
1013
1014 memcpy(dest, recordstart, sizeof(recordstart));
1015 dest += sizeof(recordstart);
1016
1017 memcpy(dest,
1018 addr.sin4.sin_family == AF_INET ? (void*)&addr.sin4.sin_addr.s_addr : (void*)&addr.sin6.sin6_addr.s6_addr,
1019 rdatalen);
1020 dest += rdatalen;
1021 dq->len += rdatalen + sizeof(recordstart);
1022 dq->dh->ancount++;
1023 }
1024 }
1025
1026 dq->dh->ancount = htons(dq->dh->ancount);
1027
1028 return Action::HeaderModify;
1029 }
1030
1031 string toString() const override
1032 {
1033 string ret = "spoof in ";
1034 if(!d_cname.empty()) {
1035 ret+=d_cname.toString()+ " ";
1036 } else {
1037 for(const auto& a : d_addrs)
1038 ret += a.toString()+" ";
1039 }
1040 return ret;
1041 }
1042 private:
1043 std::vector<ComboAddress> d_addrs;
1044 DNSName d_cname;
1045 };
1046
1047 class MacAddrAction : public DNSAction
1048 {
1049 public:
1050 MacAddrAction(uint16_t code) : d_code(code)
1051 {}
1052 DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
1053 {
1054 if(dq->dh->arcount)
1055 return Action::None;
1056
1057 string mac = getMACAddress(*dq->remote);
1058 if(mac.empty())
1059 return Action::None;
1060
1061 string optRData;
1062 generateEDNSOption(d_code, mac, optRData);
1063
1064 string res;
1065 generateOptRR(optRData, res);
1066
1067 if ((dq->size - dq->len) < res.length())
1068 return Action::None;
1069
1070 dq->dh->arcount = htons(1);
1071 char* dest = ((char*)dq->dh) + dq->len;
1072 memcpy(dest, res.c_str(), res.length());
1073 dq->len += res.length();
1074
1075 return Action::None;
1076 }
1077 string toString() const override
1078 {
1079 return "add EDNS MAC (code="+std::to_string(d_code)+")";
1080 }
1081 private:
1082 uint16_t d_code{3};
1083 };
1084
1085 class NoRecurseAction : public DNSAction
1086 {
1087 public:
1088 DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
1089 {
1090 dq->dh->rd = false;
1091 return Action::None;
1092 }
1093 string toString() const override
1094 {
1095 return "set rd=0";
1096 }
1097 };
1098
1099 class LogAction : public DNSAction, public boost::noncopyable
1100 {
1101 public:
1102 LogAction() : d_fp(0)
1103 {
1104 }
1105 LogAction(const std::string& str, bool binary=true, bool append=false, bool buffered=true) : d_fname(str), d_binary(binary)
1106 {
1107 if(str.empty())
1108 return;
1109 if(append)
1110 d_fp = fopen(str.c_str(), "a+");
1111 else
1112 d_fp = fopen(str.c_str(), "w");
1113 if(!d_fp)
1114 throw std::runtime_error("Unable to open file '"+str+"' for logging: "+string(strerror(errno)));
1115 if(!buffered)
1116 setbuf(d_fp, 0);
1117 }
1118 ~LogAction() override
1119 {
1120 if(d_fp)
1121 fclose(d_fp);
1122 }
1123 DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
1124 {
1125 if(!d_fp) {
1126 vinfolog("Packet from %s for %s %s with id %d", dq->remote->toStringWithPort(), dq->qname->toString(), QType(dq->qtype).getName(), dq->dh->id);
1127 }
1128 else {
1129 if(d_binary) {
1130 string out = dq->qname->toDNSString();
1131 fwrite(out.c_str(), 1, out.size(), d_fp);
1132 fwrite((void*)&dq->qtype, 1, 2, d_fp);
1133 }
1134 else {
1135 fprintf(d_fp, "Packet from %s for %s %s with id %d\n", dq->remote->toStringWithPort().c_str(), dq->qname->toString().c_str(), QType(dq->qtype).getName().c_str(), dq->dh->id);
1136 }
1137 }
1138 return Action::None;
1139 }
1140 string toString() const override
1141 {
1142 if (!d_fname.empty()) {
1143 return "log to " + d_fname;
1144 }
1145 return "log";
1146 }
1147 private:
1148 string d_fname;
1149 FILE* d_fp{0};
1150 bool d_binary{true};
1151 };
1152
1153
1154 class DisableValidationAction : public DNSAction
1155 {
1156 public:
1157 DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
1158 {
1159 dq->dh->cd = true;
1160 return Action::None;
1161 }
1162 string toString() const override
1163 {
1164 return "set cd=1";
1165 }
1166 };
1167
1168 class SkipCacheAction : public DNSAction
1169 {
1170 public:
1171 DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
1172 {
1173 dq->skipCache = true;
1174 return Action::None;
1175 }
1176 string toString() const override
1177 {
1178 return "skip cache";
1179 }
1180 };
1181
1182 class ECSPrefixLengthAction : public DNSAction
1183 {
1184 public:
1185 ECSPrefixLengthAction(uint16_t v4Length, uint16_t v6Length) : d_v4PrefixLength(v4Length), d_v6PrefixLength(v6Length)
1186 {
1187 }
1188 DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
1189 {
1190 dq->ecsPrefixLength = dq->remote->sin4.sin_family == AF_INET ? d_v4PrefixLength : d_v6PrefixLength;
1191 return Action::None;
1192 }
1193 string toString() const override
1194 {
1195 return "set ECS prefix length to " + std::to_string(d_v4PrefixLength) + "/" + std::to_string(d_v6PrefixLength);
1196 }
1197 private:
1198 uint16_t d_v4PrefixLength;
1199 uint16_t d_v6PrefixLength;
1200 };
1201
1202 class ECSOverrideAction : public DNSAction
1203 {
1204 public:
1205 ECSOverrideAction(bool ecsOverride) : d_ecsOverride(ecsOverride)
1206 {
1207 }
1208 DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
1209 {
1210 dq->ecsOverride = d_ecsOverride;
1211 return Action::None;
1212 }
1213 string toString() const override
1214 {
1215 return "set ECS override to " + std::to_string(d_ecsOverride);
1216 }
1217 private:
1218 bool d_ecsOverride;
1219 };
1220
1221
1222 class DisableECSAction : public DNSAction
1223 {
1224 public:
1225 DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
1226 {
1227 dq->useECS = false;
1228 return Action::None;
1229 }
1230 string toString() const override
1231 {
1232 return "disable ECS";
1233 }
1234 };
1235
1236 class RemoteLogAction : public DNSAction, public boost::noncopyable
1237 {
1238 public:
1239 RemoteLogAction(std::shared_ptr<RemoteLogger> logger, boost::optional<std::function<void(const DNSQuestion&, DNSDistProtoBufMessage*)> > alterFunc): d_logger(logger), d_alterFunc(alterFunc)
1240 {
1241 }
1242 DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
1243 {
1244 #ifdef HAVE_PROTOBUF
1245 DNSDistProtoBufMessage message(*dq);
1246 {
1247 if (d_alterFunc) {
1248 std::lock_guard<std::mutex> lock(g_luamutex);
1249 (*d_alterFunc)(*dq, &message);
1250 }
1251 }
1252 std::string data;
1253 message.serialize(data);
1254 d_logger->queueData(data);
1255 #endif /* HAVE_PROTOBUF */
1256 return Action::None;
1257 }
1258 string toString() const override
1259 {
1260 return "remote log to " + d_logger->toString();
1261 }
1262 private:
1263 std::shared_ptr<RemoteLogger> d_logger;
1264 boost::optional<std::function<void(const DNSQuestion&, DNSDistProtoBufMessage*)> > d_alterFunc;
1265 };
1266
1267 class SNMPTrapAction : public DNSAction
1268 {
1269 public:
1270 SNMPTrapAction(const std::string& reason): d_reason(reason)
1271 {
1272 }
1273 DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
1274 {
1275 if (g_snmpAgent && g_snmpTrapsEnabled) {
1276 g_snmpAgent->sendDNSTrap(*dq, d_reason);
1277 }
1278
1279 return Action::None;
1280 }
1281 string toString() const override
1282 {
1283 return "send SNMP trap";
1284 }
1285 private:
1286 std::string d_reason;
1287 };
1288
1289 class RemoteLogResponseAction : public DNSResponseAction, public boost::noncopyable
1290 {
1291 public:
1292 RemoteLogResponseAction(std::shared_ptr<RemoteLogger> logger, boost::optional<std::function<void(const DNSResponse&, DNSDistProtoBufMessage*)> > alterFunc, bool includeCNAME): d_logger(logger), d_alterFunc(alterFunc), d_includeCNAME(includeCNAME)
1293 {
1294 }
1295 DNSResponseAction::Action operator()(DNSResponse* dr, string* ruleresult) const override
1296 {
1297 #ifdef HAVE_PROTOBUF
1298 DNSDistProtoBufMessage message(*dr, d_includeCNAME);
1299 {
1300 if (d_alterFunc) {
1301 std::lock_guard<std::mutex> lock(g_luamutex);
1302 (*d_alterFunc)(*dr, &message);
1303 }
1304 }
1305 std::string data;
1306 message.serialize(data);
1307 d_logger->queueData(data);
1308 #endif /* HAVE_PROTOBUF */
1309 return Action::None;
1310 }
1311 string toString() const override
1312 {
1313 return "remote log response to " + d_logger->toString();
1314 }
1315 private:
1316 std::shared_ptr<RemoteLogger> d_logger;
1317 boost::optional<std::function<void(const DNSResponse&, DNSDistProtoBufMessage*)> > d_alterFunc;
1318 bool d_includeCNAME;
1319 };
1320
1321 class DropResponseAction : public DNSResponseAction
1322 {
1323 public:
1324 DNSResponseAction::Action operator()(DNSResponse* dr, string* ruleresult) const override
1325 {
1326 return Action::Drop;
1327 }
1328 string toString() const override
1329 {
1330 return "drop";
1331 }
1332 };
1333
1334 class AllowResponseAction : public DNSResponseAction
1335 {
1336 public:
1337 DNSResponseAction::Action operator()(DNSResponse* dr, string* ruleresult) const override
1338 {
1339 return Action::Allow;
1340 }
1341 string toString() const override
1342 {
1343 return "allow";
1344 }
1345 };
1346
1347 class DelayResponseAction : public DNSResponseAction
1348 {
1349 public:
1350 DelayResponseAction(int msec) : d_msec(msec)
1351 {}
1352 DNSResponseAction::Action operator()(DNSResponse* dr, string* ruleresult) const override
1353 {
1354 *ruleresult=std::to_string(d_msec);
1355 return Action::Delay;
1356 }
1357 string toString() const override
1358 {
1359 return "delay by "+std::to_string(d_msec)+ " msec";
1360 }
1361 private:
1362 int d_msec;
1363 };
1364
1365 class SNMPTrapResponseAction : public DNSResponseAction
1366 {
1367 public:
1368 SNMPTrapResponseAction(const std::string& reason): d_reason(reason)
1369 {
1370 }
1371 DNSResponseAction::Action operator()(DNSResponse* dr, string* ruleresult) const override
1372 {
1373 if (g_snmpAgent && g_snmpTrapsEnabled) {
1374 g_snmpAgent->sendDNSTrap(*dr, d_reason);
1375 }
1376
1377 return Action::None;
1378 }
1379 string toString() const override
1380 {
1381 return "send SNMP trap";
1382 }
1383 private:
1384 std::string d_reason;
1385 };