]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Add `QNameLabelsCountRule()` and `QNameWireLengthRule()` 4114/head
authorRemi Gacogne <remi.gacogne@powerdns.com>
Tue, 5 Jul 2016 13:39:16 +0000 (15:39 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Tue, 5 Jul 2016 13:39:16 +0000 (15:39 +0200)
* QNameLabelsCountRule(min, max) matches if the qname has less than
min or more than max labels.
* QNameWireLengthRule(min, max) matches if the qname's length on the
wire is less than min or more than max bytes.
* Also add Lua bindings for DNSName's `countLabels()` and `wirelength()`

pdns/README-dnsdist.md
pdns/dnsdist-console.cc
pdns/dnsdist-lua.cc
pdns/dnsrulactions.hh
regression-tests.dnsdist/test_Advanced.py

index b4076558c4942f40e3731e2a25df2df97b7867b8..70173b9752f133012f87e459deac34d6a393a09a 100644 (file)
@@ -338,6 +338,8 @@ Rules have selectors and actions. Current selectors are:
  * Number of entries in a given section (RecordsCountRule)
  * Number of entries of a specific type in a given section (RecordsTypeCountRule)
  * Presence of trailing data (TrailingDataRule)
+ * Number of labels in the qname (QNameLabelsCountRule)
+ * Wire length of the qname (QNameWireLengthRule)
 
 Special rules are:
 
@@ -392,6 +394,8 @@ A DNS rule can be:
  * an OpcodeRule
  * an OrRule
  * a QClassRule
+ * a QNameLabelsCountRule
+ * a QNameWireLengthRule
  * a QTypeRule
  * a RegexRule
  * a RE2Rule
@@ -1220,6 +1224,8 @@ instantiate a server with additional parameters
     * `OrRule()`: matches if at least one of the sub-rules matches
     * `OpcodeRule()`: matches queries with the specified opcode
     * `QClassRule(qclass)`: matches queries with the specified qclass (numeric)
+    * `QNameLabelsCountRule(min, max)`: matches if the qname has less than `min` or more than `max` labels
+    * `QNameWireLengthRule(min, max)`: matches if the qname's length on the wire is less than `min` or more than `max` bytes
     * `QTypeRule(qtype)`: matches queries with the specified qtype
     * `RegexRule(regex)`: matches the query name against the supplied regex
     * `RecordsCountRule(section, minCount, maxCount)`: matches if there is at least `minCount` and at most `maxCount` records in the `section` section
@@ -1340,9 +1346,11 @@ instantiate a server with additional parameters
         * `toStringWithPort()`: alias for `tostringWithPort()`
     * DNSName related:
         * `newDNSName(name)`: make a DNSName based on this .-terminated name
+        * member `countLabels()`: return the number of labels
         * member `isPartOf(dnsname)`: is this dnsname part of that dnsname
         * member `tostring()`: return as a human friendly . terminated string
         * member `toString()`: alias for `tostring()`
+        * member `wirelength()`: return the length on the wire
     * DNSQuestion related:
         * member `dh`: DNSHeader
         * member `len`: the question length
index c909efc09af84f24c35846c5a0462e2693f195be..1b37daac165ea372ae706177f3aad6b4a1640f5d 100644 (file)
@@ -248,7 +248,7 @@ char* my_generator(const char* text, int state)
       "PoolAction(", "printDNSCryptProviderFingerprint(",
       "RegexRule(", "RemoteLogAction(", "RemoteLogResponseAction(", "rmResponseRule(",
       "rmRule(", "rmServer(", "roundrobin",
-      "QTypeRule(",
+      "QNameLabelsCountRule(", "QNameWireLengthRule(", "QTypeRule(",
       "setACL(", "setDNSSECPool(", "setECSOverride(",
       "setECSSourcePrefixV4(", "setECSSourcePrefixV6(", "setKey(", "setLocal(",
       "setMaxTCPClientThreads(", "setMaxTCPQueuedConnections(", "setMaxUDPOutstanding(", "setRules(",
index ebd6e41889d6afbadc2d94e9aa07e9041b82d65c..e55e73dcfe7f364f061918df6b5ab45349995efc 100644 (file)
@@ -879,6 +879,14 @@ vector<std::function<void(void)>> setupLua(bool client, const std::string& confi
       return std::shared_ptr<DNSRule>(new TrailingDataRule());
     });
 
+  g_lua.writeFunction("QNameLabelsCountRule", [](unsigned int minLabelsCount, unsigned int maxLabelsCount) {
+      return std::shared_ptr<DNSRule>(new QNameLabelsCountRule(minLabelsCount, maxLabelsCount));
+    });
+
+  g_lua.writeFunction("QNameWireLengthRule", [](size_t min, size_t max) {
+      return std::shared_ptr<DNSRule>(new QNameWireLengthRule(min, max));
+    });
+
   g_lua.writeFunction("addAction", [](luadnsrule_t var, std::shared_ptr<DNSAction> ea)
                      {
                         setLuaSideEffect();
@@ -1066,6 +1074,8 @@ vector<std::function<void(void)>> setupLua(bool client, const std::string& confi
   g_lua.registerFunction("toStringWithPort", &ComboAddress::toStringWithPort);
   g_lua.registerFunction<uint16_t(ComboAddress::*)()>("getPort", [](const ComboAddress& ca) { return ntohs(ca.sin4.sin_port); } );
   g_lua.registerFunction("isPartOf", &DNSName::isPartOf);
+  g_lua.registerFunction("countLabels", &DNSName::countLabels);
+  g_lua.registerFunction("wirelength", &DNSName::wirelength);
   g_lua.registerFunction<string(DNSName::*)()>("tostring", [](const DNSName&dn ) { return dn.toString(); });
   g_lua.registerFunction<string(DNSName::*)()>("toString", [](const DNSName&dn ) { return dn.toString(); });
   g_lua.writeFunction("newDNSName", [](const std::string& name) { return DNSName(name); });
index 33284d5f24e6956c480d0ef5f5a5fa7f038ff8ca..416d40fba4af95958a5d29435d8481100e0a9415 100644 (file)
@@ -488,6 +488,46 @@ public:
   }
 };
 
+class QNameLabelsCountRule : public DNSRule
+{
+public:
+  QNameLabelsCountRule(unsigned int minLabelsCount, unsigned int maxLabelsCount): d_min(minLabelsCount), d_max(maxLabelsCount)
+  {
+  }
+  bool matches(const DNSQuestion* dq) const override
+  {
+    unsigned int count = dq->qname->countLabels();
+    return count < d_min || count > d_max;
+  }
+  string toString() const override
+  {
+    return "labels count < " + std::to_string(d_min) + " || labels count > " + std::to_string(d_max);
+  }
+private:
+  unsigned int d_min;
+  unsigned int d_max;
+};
+
+class QNameWireLengthRule : public DNSRule
+{
+public:
+  QNameWireLengthRule(size_t min, size_t max): d_min(min), d_max(max)
+  {
+  }
+  bool matches(const DNSQuestion* dq) const override
+  {
+    size_t const wirelength = dq->qname->wirelength();
+    return wirelength < d_min || wirelength > d_max;
+  }
+  string toString() const override
+  {
+    return "wire length < " + std::to_string(d_min) + " || wire length > " + std::to_string(d_max);
+  }
+private:
+  size_t d_min;
+  size_t d_max;
+};
+
 class DropAction : public DNSAction
 {
 public:
index e2542b3bfa5f0a7ac409dcd75d035da6bd7f4013..d98392fd73e202bd6246bc6f45a5a44bd11cd8b5 100644 (file)
@@ -1018,3 +1018,121 @@ class TestAdvancedNMGRule(DNSDistTest):
         (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
         self.assertEquals(receivedResponse, expectedResponse)
 
+class TestAdvancedLabelsCountRule(DNSDistTest):
+
+    _config_template = """
+    addAction(QNameLabelsCountRule(5,6), RCodeAction(dnsdist.REFUSED))
+    newServer{address="127.0.0.1:%s"}
+    """
+
+    def testAdvancedLabelsCountRule(self):
+        """
+        Advanced: QNameLabelsCountRule(5,6)
+        """
+        # 6 labels, we should be fine
+        name = 'ok.labelscount.advanced.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '192.0.2.1')
+        response.answer.append(rrset)
+
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        self.assertEquals(query, receivedQuery)
+        self.assertEquals(response, receivedResponse)
+
+        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        self.assertEquals(query, receivedQuery)
+        self.assertEquals(response, receivedResponse)
+
+        # more than 6 labels, the query should be refused
+        name = 'not.ok.labelscount.advanced.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        expectedResponse = dns.message.make_response(query)
+        expectedResponse.set_rcode(dns.rcode.REFUSED)
+
+        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
+        self.assertEquals(receivedResponse, expectedResponse)
+
+        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
+        self.assertEquals(receivedResponse, expectedResponse)
+
+        # less than 5 labels, the query should be refused
+        name = 'labelscountadvanced.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        expectedResponse = dns.message.make_response(query)
+        expectedResponse.set_rcode(dns.rcode.REFUSED)
+
+        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
+        self.assertEquals(receivedResponse, expectedResponse)
+
+        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
+        self.assertEquals(receivedResponse, expectedResponse)
+
+class TestAdvancedWireLengthRule(DNSDistTest):
+
+    _config_template = """
+    addAction(QNameWireLengthRule(54,56), RCodeAction(dnsdist.REFUSED))
+    newServer{address="127.0.0.1:%s"}
+    """
+
+    def testAdvancedWireLengthRule(self):
+        """
+        Advanced: QNameWireLengthRule(54,56)
+        """
+        name = 'longenough.qnamewirelength.advanced.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '192.0.2.1')
+        response.answer.append(rrset)
+
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        self.assertEquals(query, receivedQuery)
+        self.assertEquals(response, receivedResponse)
+
+        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        self.assertEquals(query, receivedQuery)
+        self.assertEquals(response, receivedResponse)
+
+        # too short, the query should be refused
+        name = 'short.qnamewirelength.advanced.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        expectedResponse = dns.message.make_response(query)
+        expectedResponse.set_rcode(dns.rcode.REFUSED)
+
+        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
+        self.assertEquals(receivedResponse, expectedResponse)
+
+        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
+        self.assertEquals(receivedResponse, expectedResponse)
+
+        # too long, the query should be refused
+        name = 'toolongtobevalid.qnamewirelength.advanced.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        expectedResponse = dns.message.make_response(query)
+        expectedResponse.set_rcode(dns.rcode.REFUSED)
+
+        (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
+        self.assertEquals(receivedResponse, expectedResponse)
+
+        (_, receivedResponse) = self.sendTCPQuery(query, response=None, useQueue=False)
+        self.assertEquals(receivedResponse, expectedResponse)