]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: make sure setting tags will overwrite any existing value 10767/head
authorCharles-Henri Bruyand <charles-henri.bruyand@open-xchange.com>
Mon, 27 Sep 2021 16:32:43 +0000 (18:32 +0200)
committerCharles-Henri Bruyand <charles-henri.bruyand@open-xchange.com>
Mon, 27 Sep 2021 16:32:43 +0000 (18:32 +0200)
pdns/dnsdist-lua-actions.cc
pdns/dnsdist-lua-bindings-dnsquestion.cc
pdns/dnsdist.hh
pdns/dnsdistdist/dnsdist-lua-ffi.cc
pdns/dnsdistdist/docs/reference/dq.rst
pdns/dnsdistdist/docs/rules-actions.rst
regression-tests.dnsdist/test_Tags.py

index e8f94a4c735d5541880894ecb36d2914d8961bd2..b6e83949fec1a2292c9df806fdbc7f5ba038f17f 100644 (file)
@@ -1456,11 +1456,7 @@ public:
   }
   DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override
   {
-    if (!dq->qTag) {
-      dq->qTag = std::make_shared<QTag>();
-    }
-
-    dq->qTag->insert({d_tag, d_value});
+    dq->setTag(d_tag, d_value);
 
     return Action::None;
   }
@@ -1636,11 +1632,7 @@ public:
   }
   DNSResponseAction::Action operator()(DNSResponse* dr, std::string* ruleresult) const override
   {
-    if (!dr->qTag) {
-      dr->qTag = std::make_shared<QTag>();
-    }
-
-    dr->qTag->insert({d_tag, d_value});
+    dr->setTag(d_tag, d_value);
 
     return Action::None;
   }
@@ -1740,11 +1732,7 @@ public:
       }
     }
 
-    if (!dq->qTag) {
-      dq->qTag = std::make_shared<QTag>();
-    }
-
-    dq->qTag->insert({d_tag, std::move(result)});
+    dq->setTag(d_tag, std::move(result));
 
     return Action::None;
   }
@@ -1778,11 +1766,7 @@ public:
       }
     }
 
-    if (!dq->qTag) {
-      dq->qTag = std::make_shared<QTag>();
-    }
-
-    dq->qTag->insert({d_tag, std::move(result)});
+    dq->setTag(d_tag, std::move(result));
 
     return Action::None;
   }
index 3ddfa4415728828912b052af9c4d1f4e467d344b..c7ecf3a04df88a5d72efadab9f816872b8a034fb 100644 (file)
@@ -86,18 +86,11 @@ void setupLuaBindingsDNSQuestion(LuaContext& luaCtx)
     });
 
   luaCtx.registerFunction<void(DNSQuestion::*)(std::string, std::string)>("setTag", [](DNSQuestion& dq, const std::string& strLabel, const std::string& strValue) {
-      if(dq.qTag == nullptr) {
-        dq.qTag = std::make_shared<QTag>();
-      }
-      dq.qTag->insert({strLabel, strValue});
+      dq.setTag(strLabel, strValue);
     });
   luaCtx.registerFunction<void(DNSQuestion::*)(vector<pair<string, string>>)>("setTagArray", [](DNSQuestion& dq, const vector<pair<string, string>>&tags) {
-      if (!dq.qTag) {
-        dq.qTag = std::make_shared<QTag>();
-      }
-
       for (const auto& tag : tags) {
-        dq.qTag->insert({tag.first, tag.second});
+        dq.setTag(tag.first, tag.second);
       }
     });
   luaCtx.registerFunction<string(DNSQuestion::*)(std::string)const>("getTag", [](const DNSQuestion& dq, const std::string& strLabel) {
@@ -215,19 +208,12 @@ void setupLuaBindingsDNSQuestion(LuaContext& luaCtx)
     });
 
   luaCtx.registerFunction<void(DNSResponse::*)(std::string, std::string)>("setTag", [](DNSResponse& dr, const std::string& strLabel, const std::string& strValue) {
-      if(dr.qTag == nullptr) {
-        dr.qTag = std::make_shared<QTag>();
-      }
-      dr.qTag->insert({strLabel, strValue});
+      dr.setTag(strLabel, strValue);
     });
 
   luaCtx.registerFunction<void(DNSResponse::*)(vector<pair<string, string>>)>("setTagArray", [](DNSResponse& dr, const vector<pair<string, string>>&tags) {
-      if (!dr.qTag) {
-        dr.qTag = std::make_shared<QTag>();
-      }
-
       for (const auto& tag : tags) {
-        dr.qTag->insert({tag.first, tag.second});
+        dr.setTag(tag.first, tag.second);
       }
     });
   luaCtx.registerFunction<string(DNSResponse::*)(std::string)const>("getTag", [](const DNSResponse& dr, const std::string& strLabel) {
index 17dd54fa26400e42632b790efa81044144ecb566..2f4d6480fc666620328ebbdde81015946259360d 100644 (file)
@@ -121,6 +121,20 @@ struct DNSQuestion
     return !(protocol == dnsdist::Protocol::DoUDP || protocol == dnsdist::Protocol::DNSCryptUDP);
   }
 
+  void setTag(const std::string& key, std::string&& value) {
+    if (!qTag) {
+      qTag = std::make_shared<QTag>();
+    }
+    qTag->insert_or_assign(key, std::move(value));
+  }
+
+  void setTag(const std::string& key, const std::string& value) {
+    if (!qTag) {
+      qTag = std::make_shared<QTag>();
+    }
+    qTag->insert_or_assign(key, value);
+  }
+
 protected:
   PacketBuffer& data;
 
index 67be9c8afa3e775179153341b0f8683f683d6a00..25e36c218f40eff4030e6ddc0fa8f6430d4b3b69 100644 (file)
@@ -409,11 +409,7 @@ void dnsdist_ffi_dnsquestion_unset_temp_failure_ttl(dnsdist_ffi_dnsquestion_t* d
 
 void dnsdist_ffi_dnsquestion_set_tag(dnsdist_ffi_dnsquestion_t* dq, const char* label, const char* value)
 {
-  if (!dq->dq->qTag) {
-    dq->dq->qTag = std::make_shared<QTag>();
-  }
-
-  dq->dq->qTag->insert({label, value});
+  dq->dq->setTag(label, value);
 }
 
 size_t dnsdist_ffi_dnsquestion_get_trailing_data(dnsdist_ffi_dnsquestion_t* dq, const char** out)
index 89b04a5cedebaad2051d76d6e1b90992f630cc0d..309fa9a135e9845571e09171080f631aced41097 100644 (file)
@@ -236,16 +236,20 @@ This state can be modified from the various hooks.
 
   .. method:: DNSQuestion:setTag(key, value)
 
-    Set a tag into the DNSQuestion object.
-    This function will not overwrite an existing tag. If the tag already exists it will keep its original value.
+    .. versionchanged:: 1.7.0
+      Prior to 1.7.0 calling :func:`DNSQuestion:setTag` would not overwrite an existing tag value if already set.
+
+    Set a tag into the DNSQuestion object. Overwrites the value if any already exists.
   
     :param string key: The tag's key
     :param string value: The tag's value
 
   .. method:: DNSQuestion:setTagArray(tags)
 
-    Set an array of tags into the DNSQuestion object.
-    This function will not overwrite an existing tag. If the tag already exists it will keep its original value.
+    .. versionchanged:: 1.7.0
+      Prior to 1.7.0 calling :func:`DNSQuestion:setTagArray` would not overwrite existing tag values if already set.
+
+    Set an array of tags into the DNSQuestion object. Overwrites the values if any already exist.
   
     :param table tags: A table of tags, using strings as keys and values
 
index 2de3cc0fcd8040068be8bcc97d2e473887361c55..3493d7f67c79bd5293f763fd1cde471eff607286 100644 (file)
@@ -1364,8 +1364,11 @@ The following actions exist.
 
   .. versionadded:: 1.6.0
 
+  .. versionchanged:: 1.7.0
+    Prior to 1.7.0 :func:`SetTagAction` would not overwrite an existing tag value if already set.
+
   Associate a tag named ``name`` with a value of ``value`` to this query, that will be passed on to the response.
-  This function will not overwrite an existing tag. If the tag already exists it will keep its original value.
+  This function will overwrite any existing tag value.
   Subsequent rules are processed after this action.
   Note that this function was called :func:`TagAction` before 1.6.0.
 
@@ -1376,8 +1379,11 @@ The following actions exist.
 
   .. versionadded:: 1.6.0
 
+  .. versionchanged:: 1.7.0
+    Prior to 1.7.0 :func:`SetTagResponseAction` would not overwrite an existing tag value if already set.
+
   Associate a tag named ``name`` with a value of ``value`` to this response.
-  This function will not overwrite an existing tag. If the tag already exists it will keep its original value.
+  This function will overwrite any existing tag value.
   Subsequent rules are processed after this action.
   Note that this function was called :func:`TagResponseAction` before 1.6.0.
 
index a6f7bce0a69f4b889d7d8efb85a3353db6d8dcea..1e5553a0f03d66c0fb49e0a5f7b7da4ab0d3b598 100644 (file)
@@ -187,3 +187,128 @@ class TestTags(DNSDistTest):
             receivedQuery.id = query.id
             self.assertEqual(query, receivedQuery)
             self.assertEqual(expectedResponse, receivedResponse)
+
+class TestSetTagAction(DNSDistTest):
+
+    _config_template = """
+    newServer{address="127.0.0.1:%s"}
+
+    addAction(AllRule(), SetTagAction("dns", "value1"))
+    addAction("tag-me-dns-2.tags.tests.powerdns.com.", SetTagAction("dns", "value2"))
+
+    addAction(TagRule("dns", "value1"), SpoofAction("1.2.3.50"))
+    addAction(TagRule("dns", "value2"), SpoofAction("1.2.3.4"))
+
+    """
+
+    def testSetTagDefault(self):
+
+        """
+        Tag: Test setTag overwrites existing value
+        """
+        name = 'tag-me-dns-1.tags.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        # dnsdist set RA = RD for spoofed responses
+        query.flags &= ~dns.flags.RD
+        expectedResponse = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '1.2.3.50')
+        expectedResponse.answer.append(rrset)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertTrue(receivedResponse)
+            self.assertEqual(expectedResponse, receivedResponse)
+
+    def testSetTagOverwritten(self):
+
+        """
+        Tag: Test setTag overwrites existing value
+        """
+        name = 'tag-me-dns-2.tags.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        # dnsdist set RA = RD for spoofed responses
+        query.flags &= ~dns.flags.RD
+        expectedResponse = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '1.2.3.4')
+        expectedResponse.answer.append(rrset)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertTrue(receivedResponse)
+            self.assertEqual(expectedResponse, receivedResponse)
+
+class TestSetTag(DNSDistTest):
+
+    _config_template = """
+    newServer{address="127.0.0.1:%s"}
+
+    function dqset(dq)
+      dq:setTag("dns", "value1")
+      if tostring(dq.qname) == 'tag-me-dns-2.tags.tests.powerdns.com.' then
+        dq:setTag("dns", "value2")
+      end
+      return DNSAction.None, ""
+    end
+
+    addAction(AllRule(), LuaAction(dqset))
+
+    addAction(TagRule("dns", "value1"), SpoofAction("1.2.3.50"))
+    addAction(TagRule("dns", "value2"), SpoofAction("1.2.3.4"))
+
+    """
+
+    def testSetTagDefault(self):
+
+        """
+        Tag: Test setTag overwrites existing value
+        """
+        name = 'tag-me-dns-1.tags.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        # dnsdist set RA = RD for spoofed responses
+        query.flags &= ~dns.flags.RD
+        expectedResponse = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '1.2.3.50')
+        expectedResponse.answer.append(rrset)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertTrue(receivedResponse)
+            self.assertEqual(expectedResponse, receivedResponse)
+
+    def testSetTagOverwritten(self):
+
+        """
+        Tag: Test setTag overwrites existing value
+        """
+        name = 'tag-me-dns-2.tags.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        # dnsdist set RA = RD for spoofed responses
+        query.flags &= ~dns.flags.RD
+        expectedResponse = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '1.2.3.4')
+        expectedResponse.answer.append(rrset)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (_, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertTrue(receivedResponse)
+            self.assertEqual(expectedResponse, receivedResponse)