From 4170d08a1f450f782e3b7a5b937b7a9f9870e487 Mon Sep 17 00:00:00 2001 From: Charles-Henri Bruyand Date: Mon, 27 Sep 2021 18:32:43 +0200 Subject: [PATCH] dnsdist: make sure setting tags will overwrite any existing value --- pdns/dnsdist-lua-actions.cc | 24 +---- pdns/dnsdist-lua-bindings-dnsquestion.cc | 22 +--- pdns/dnsdist.hh | 14 +++ pdns/dnsdistdist/dnsdist-lua-ffi.cc | 6 +- pdns/dnsdistdist/docs/reference/dq.rst | 12 ++- pdns/dnsdistdist/docs/rules-actions.rst | 10 +- regression-tests.dnsdist/test_Tags.py | 125 +++++++++++++++++++++++ 7 files changed, 164 insertions(+), 49 deletions(-) diff --git a/pdns/dnsdist-lua-actions.cc b/pdns/dnsdist-lua-actions.cc index e8f94a4c73..b6e83949fe 100644 --- a/pdns/dnsdist-lua-actions.cc +++ b/pdns/dnsdist-lua-actions.cc @@ -1456,11 +1456,7 @@ public: } DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override { - if (!dq->qTag) { - dq->qTag = std::make_shared(); - } - - dq->qTag->insert({d_tag, d_value}); + dq->setTag(d_tag, d_value); return Action::None; } @@ -1636,11 +1632,7 @@ public: } DNSResponseAction::Action operator()(DNSResponse* dr, std::string* ruleresult) const override { - if (!dr->qTag) { - dr->qTag = std::make_shared(); - } - - dr->qTag->insert({d_tag, d_value}); + dr->setTag(d_tag, d_value); return Action::None; } @@ -1740,11 +1732,7 @@ public: } } - if (!dq->qTag) { - dq->qTag = std::make_shared(); - } - - dq->qTag->insert({d_tag, std::move(result)}); + dq->setTag(d_tag, std::move(result)); return Action::None; } @@ -1778,11 +1766,7 @@ public: } } - if (!dq->qTag) { - dq->qTag = std::make_shared(); - } - - dq->qTag->insert({d_tag, std::move(result)}); + dq->setTag(d_tag, std::move(result)); return Action::None; } diff --git a/pdns/dnsdist-lua-bindings-dnsquestion.cc b/pdns/dnsdist-lua-bindings-dnsquestion.cc index 3ddfa44157..c7ecf3a04d 100644 --- a/pdns/dnsdist-lua-bindings-dnsquestion.cc +++ b/pdns/dnsdist-lua-bindings-dnsquestion.cc @@ -86,18 +86,11 @@ void setupLuaBindingsDNSQuestion(LuaContext& luaCtx) }); luaCtx.registerFunction("setTag", [](DNSQuestion& dq, const std::string& strLabel, const std::string& strValue) { - if(dq.qTag == nullptr) { - dq.qTag = std::make_shared(); - } - dq.qTag->insert({strLabel, strValue}); + dq.setTag(strLabel, strValue); }); luaCtx.registerFunction>)>("setTagArray", [](DNSQuestion& dq, const vector>&tags) { - if (!dq.qTag) { - dq.qTag = std::make_shared(); - } - for (const auto& tag : tags) { - dq.qTag->insert({tag.first, tag.second}); + dq.setTag(tag.first, tag.second); } }); luaCtx.registerFunction("getTag", [](const DNSQuestion& dq, const std::string& strLabel) { @@ -215,19 +208,12 @@ void setupLuaBindingsDNSQuestion(LuaContext& luaCtx) }); luaCtx.registerFunction("setTag", [](DNSResponse& dr, const std::string& strLabel, const std::string& strValue) { - if(dr.qTag == nullptr) { - dr.qTag = std::make_shared(); - } - dr.qTag->insert({strLabel, strValue}); + dr.setTag(strLabel, strValue); }); luaCtx.registerFunction>)>("setTagArray", [](DNSResponse& dr, const vector>&tags) { - if (!dr.qTag) { - dr.qTag = std::make_shared(); - } - for (const auto& tag : tags) { - dr.qTag->insert({tag.first, tag.second}); + dr.setTag(tag.first, tag.second); } }); luaCtx.registerFunction("getTag", [](const DNSResponse& dr, const std::string& strLabel) { diff --git a/pdns/dnsdist.hh b/pdns/dnsdist.hh index 17dd54fa26..2f4d6480fc 100644 --- a/pdns/dnsdist.hh +++ b/pdns/dnsdist.hh @@ -121,6 +121,20 @@ struct DNSQuestion return !(protocol == dnsdist::Protocol::DoUDP || protocol == dnsdist::Protocol::DNSCryptUDP); } + void setTag(const std::string& key, std::string&& value) { + if (!qTag) { + qTag = std::make_shared(); + } + qTag->insert_or_assign(key, std::move(value)); + } + + void setTag(const std::string& key, const std::string& value) { + if (!qTag) { + qTag = std::make_shared(); + } + qTag->insert_or_assign(key, value); + } + protected: PacketBuffer& data; diff --git a/pdns/dnsdistdist/dnsdist-lua-ffi.cc b/pdns/dnsdistdist/dnsdist-lua-ffi.cc index 67be9c8afa..25e36c218f 100644 --- a/pdns/dnsdistdist/dnsdist-lua-ffi.cc +++ b/pdns/dnsdistdist/dnsdist-lua-ffi.cc @@ -409,11 +409,7 @@ void dnsdist_ffi_dnsquestion_unset_temp_failure_ttl(dnsdist_ffi_dnsquestion_t* d void dnsdist_ffi_dnsquestion_set_tag(dnsdist_ffi_dnsquestion_t* dq, const char* label, const char* value) { - if (!dq->dq->qTag) { - dq->dq->qTag = std::make_shared(); - } - - dq->dq->qTag->insert({label, value}); + dq->dq->setTag(label, value); } size_t dnsdist_ffi_dnsquestion_get_trailing_data(dnsdist_ffi_dnsquestion_t* dq, const char** out) diff --git a/pdns/dnsdistdist/docs/reference/dq.rst b/pdns/dnsdistdist/docs/reference/dq.rst index 89b04a5ced..309fa9a135 100644 --- a/pdns/dnsdistdist/docs/reference/dq.rst +++ b/pdns/dnsdistdist/docs/reference/dq.rst @@ -236,16 +236,20 @@ This state can be modified from the various hooks. .. method:: DNSQuestion:setTag(key, value) - Set a tag into the DNSQuestion object. - This function will not overwrite an existing tag. If the tag already exists it will keep its original value. + .. versionchanged:: 1.7.0 + Prior to 1.7.0 calling :func:`DNSQuestion:setTag` would not overwrite an existing tag value if already set. + + Set a tag into the DNSQuestion object. Overwrites the value if any already exists. :param string key: The tag's key :param string value: The tag's value .. method:: DNSQuestion:setTagArray(tags) - Set an array of tags into the DNSQuestion object. - This function will not overwrite an existing tag. If the tag already exists it will keep its original value. + .. versionchanged:: 1.7.0 + Prior to 1.7.0 calling :func:`DNSQuestion:setTagArray` would not overwrite existing tag values if already set. + + Set an array of tags into the DNSQuestion object. Overwrites the values if any already exist. :param table tags: A table of tags, using strings as keys and values diff --git a/pdns/dnsdistdist/docs/rules-actions.rst b/pdns/dnsdistdist/docs/rules-actions.rst index 2de3cc0fcd..3493d7f67c 100644 --- a/pdns/dnsdistdist/docs/rules-actions.rst +++ b/pdns/dnsdistdist/docs/rules-actions.rst @@ -1364,8 +1364,11 @@ The following actions exist. .. versionadded:: 1.6.0 + .. versionchanged:: 1.7.0 + Prior to 1.7.0 :func:`SetTagAction` would not overwrite an existing tag value if already set. + Associate a tag named ``name`` with a value of ``value`` to this query, that will be passed on to the response. - This function will not overwrite an existing tag. If the tag already exists it will keep its original value. + This function will overwrite any existing tag value. Subsequent rules are processed after this action. Note that this function was called :func:`TagAction` before 1.6.0. @@ -1376,8 +1379,11 @@ The following actions exist. .. versionadded:: 1.6.0 + .. versionchanged:: 1.7.0 + Prior to 1.7.0 :func:`SetTagResponseAction` would not overwrite an existing tag value if already set. + Associate a tag named ``name`` with a value of ``value`` to this response. - This function will not overwrite an existing tag. If the tag already exists it will keep its original value. + This function will overwrite any existing tag value. Subsequent rules are processed after this action. Note that this function was called :func:`TagResponseAction` before 1.6.0. diff --git a/regression-tests.dnsdist/test_Tags.py b/regression-tests.dnsdist/test_Tags.py index a6f7bce0a6..1e5553a0f0 100644 --- a/regression-tests.dnsdist/test_Tags.py +++ b/regression-tests.dnsdist/test_Tags.py @@ -187,3 +187,128 @@ class TestTags(DNSDistTest): receivedQuery.id = query.id self.assertEqual(query, receivedQuery) self.assertEqual(expectedResponse, receivedResponse) + +class TestSetTagAction(DNSDistTest): + + _config_template = """ + newServer{address="127.0.0.1:%s"} + + addAction(AllRule(), SetTagAction("dns", "value1")) + addAction("tag-me-dns-2.tags.tests.powerdns.com.", SetTagAction("dns", "value2")) + + addAction(TagRule("dns", "value1"), SpoofAction("1.2.3.50")) + addAction(TagRule("dns", "value2"), SpoofAction("1.2.3.4")) + + """ + + def testSetTagDefault(self): + + """ + Tag: Test setTag overwrites existing value + """ + name = 'tag-me-dns-1.tags.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + # dnsdist set RA = RD for spoofed responses + query.flags &= ~dns.flags.RD + expectedResponse = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 60, + dns.rdataclass.IN, + dns.rdatatype.A, + '1.2.3.50') + expectedResponse.answer.append(rrset) + + for method in ("sendUDPQuery", "sendTCPQuery"): + sender = getattr(self, method) + (_, receivedResponse) = sender(query, response=None, useQueue=False) + self.assertTrue(receivedResponse) + self.assertEqual(expectedResponse, receivedResponse) + + def testSetTagOverwritten(self): + + """ + Tag: Test setTag overwrites existing value + """ + name = 'tag-me-dns-2.tags.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + # dnsdist set RA = RD for spoofed responses + query.flags &= ~dns.flags.RD + expectedResponse = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 60, + dns.rdataclass.IN, + dns.rdatatype.A, + '1.2.3.4') + expectedResponse.answer.append(rrset) + + for method in ("sendUDPQuery", "sendTCPQuery"): + sender = getattr(self, method) + (_, receivedResponse) = sender(query, response=None, useQueue=False) + self.assertTrue(receivedResponse) + self.assertEqual(expectedResponse, receivedResponse) + +class TestSetTag(DNSDistTest): + + _config_template = """ + newServer{address="127.0.0.1:%s"} + + function dqset(dq) + dq:setTag("dns", "value1") + if tostring(dq.qname) == 'tag-me-dns-2.tags.tests.powerdns.com.' then + dq:setTag("dns", "value2") + end + return DNSAction.None, "" + end + + addAction(AllRule(), LuaAction(dqset)) + + addAction(TagRule("dns", "value1"), SpoofAction("1.2.3.50")) + addAction(TagRule("dns", "value2"), SpoofAction("1.2.3.4")) + + """ + + def testSetTagDefault(self): + + """ + Tag: Test setTag overwrites existing value + """ + name = 'tag-me-dns-1.tags.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + # dnsdist set RA = RD for spoofed responses + query.flags &= ~dns.flags.RD + expectedResponse = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 60, + dns.rdataclass.IN, + dns.rdatatype.A, + '1.2.3.50') + expectedResponse.answer.append(rrset) + + for method in ("sendUDPQuery", "sendTCPQuery"): + sender = getattr(self, method) + (_, receivedResponse) = sender(query, response=None, useQueue=False) + self.assertTrue(receivedResponse) + self.assertEqual(expectedResponse, receivedResponse) + + def testSetTagOverwritten(self): + + """ + Tag: Test setTag overwrites existing value + """ + name = 'tag-me-dns-2.tags.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + # dnsdist set RA = RD for spoofed responses + query.flags &= ~dns.flags.RD + expectedResponse = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 60, + dns.rdataclass.IN, + dns.rdatatype.A, + '1.2.3.4') + expectedResponse.answer.append(rrset) + + for method in ("sendUDPQuery", "sendTCPQuery"): + sender = getattr(self, method) + (_, receivedResponse) = sender(query, response=None, useQueue=False) + self.assertTrue(receivedResponse) + self.assertEqual(expectedResponse, receivedResponse) -- 2.47.2