]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
rec: fix TCP case for cached policy tags
authorOtto Moerbeek <otto.moerbeek@open-xchange.com>
Mon, 17 Jun 2024 12:58:01 +0000 (14:58 +0200)
committerOtto Moerbeek <otto.moerbeek@open-xchange.com>
Tue, 18 Jun 2024 09:29:42 +0000 (11:29 +0200)
(cherry picked from commit a7f8db9e9259dfe08e47959a6613f80b971ea535)

pdns/recursordist/rec-main.hh
pdns/recursordist/rec-tcp.cc
regression-tests.recursor-dnssec/test_Protobuf.py

index 9fc1d96bc8f0be2808a89688f2b116ab16ed2c90..6c7910df5929a4bd36eb700aaee796afd903ae38 100644 (file)
@@ -125,7 +125,7 @@ struct DNSComboWriter
   };
   std::string d_query;
   std::unordered_set<std::string> d_policyTags;
-  const std::unordered_set<std::string> d_gettagPolicyTags;
+  std::unordered_set<std::string> d_gettagPolicyTags;
   std::string d_routingTag;
   std::vector<DNSRecord> d_records;
 
index 7f7c4bef5b9440f7041286b1940ec9a5fb34e38e..e74c6f9e2185b790c51340d292ab9e1fd3a79148 100644 (file)
@@ -327,16 +327,23 @@ static void doProcessTCPQuestion(std::unique_ptr<DNSComboWriter>& comboWriter, s
       if (t_pdl) {
         try {
           if (t_pdl->d_gettag_ffi) {
-            RecursorLua4::FFIParams params(qname, qtype, comboWriter->d_destination, comboWriter->d_source, comboWriter->d_ednssubnet.source, comboWriter->d_data, comboWriter->d_policyTags, comboWriter->d_records, ednsOptions, comboWriter->d_proxyProtocolValues, requestorId, deviceId, deviceName, comboWriter->d_routingTag, comboWriter->d_rcode, comboWriter->d_ttlCap, comboWriter->d_variable, true, logQuery, comboWriter->d_logResponse, comboWriter->d_followCNAMERecords, comboWriter->d_extendedErrorCode, comboWriter->d_extendedErrorExtra, comboWriter->d_responsePaddingDisabled, comboWriter->d_meta);
+            RecursorLua4::FFIParams params(qname, qtype, comboWriter->d_destination, comboWriter->d_source, comboWriter->d_ednssubnet.source, comboWriter->d_data, comboWriter->d_gettagPolicyTags, comboWriter->d_records, ednsOptions, comboWriter->d_proxyProtocolValues, requestorId, deviceId, deviceName, comboWriter->d_routingTag, comboWriter->d_rcode, comboWriter->d_ttlCap, comboWriter->d_variable, true, logQuery, comboWriter->d_logResponse, comboWriter->d_followCNAMERecords, comboWriter->d_extendedErrorCode, comboWriter->d_extendedErrorExtra, comboWriter->d_responsePaddingDisabled, comboWriter->d_meta);
             comboWriter->d_eventTrace.add(RecEventTrace::LuaGetTagFFI);
             comboWriter->d_tag = t_pdl->gettag_ffi(params);
             comboWriter->d_eventTrace.add(RecEventTrace::LuaGetTagFFI, comboWriter->d_tag, false);
           }
           else if (t_pdl->d_gettag) {
             comboWriter->d_eventTrace.add(RecEventTrace::LuaGetTag);
-            comboWriter->d_tag = t_pdl->gettag(comboWriter->d_source, comboWriter->d_ednssubnet.source, comboWriter->d_destination, qname, qtype, &comboWriter->d_policyTags, comboWriter->d_data, ednsOptions, true, requestorId, deviceId, deviceName, comboWriter->d_routingTag, comboWriter->d_proxyProtocolValues);
+            comboWriter->d_tag = t_pdl->gettag(comboWriter->d_source, comboWriter->d_ednssubnet.source, comboWriter->d_destination, qname, qtype, &comboWriter->d_gettagPolicyTags, comboWriter->d_data, ednsOptions, true, requestorId, deviceId, deviceName, comboWriter->d_routingTag, comboWriter->d_proxyProtocolValues);
             comboWriter->d_eventTrace.add(RecEventTrace::LuaGetTag, comboWriter->d_tag, false);
           }
+          // Copy d_gettagPolicyTags to d_policyTags, so other Lua hooks see them and can add their
+          // own. Before storing into the packetcache, the tags in d_gettagPolicyTags will be
+          // cleared by addPolicyTagsToPBMessageIfNeeded() so they do *not* end up in the PC. When an
+          // Protobuf message is constructed, one part comes from the PC (including the tags
+          // set by non-gettag hooks), and the tags in d_gettagPolicyTags will be added by the code
+          // constructing the PB message.
+          comboWriter->d_policyTags = comboWriter->d_gettagPolicyTags;
         }
         catch (const std::exception& e) {
           if (g_logCommonErrors) {
index 953a9ce20ead2a57aa4516deb8597f2b72ee935a..2af5114d9f239835f62e586a0bbb2a69c418a60d 100644 (file)
@@ -302,6 +302,7 @@ class TestRecursorProtobuf(RecursorTest):
 @ 3600 IN SOA {soa}
 a 3600 IN A 192.0.2.42
 tagged 3600 IN A 192.0.2.84
+taggedtcp 3600 IN A 192.0.2.87
 meta 3600 IN A 192.0.2.85
 query-selected 3600 IN A 192.0.2.84
 answer-selected 3600 IN A 192.0.2.84
@@ -994,7 +995,7 @@ auth-zones=example=configs/%s/example.zone""" % _confdir
     """ % (protobufServersParameters[0].port, protobufServersParameters[1].port)
     _lua_dns_script_file = """
     function gettag(remote, ednssubnet, localip, qname, qtype, ednsoptions, tcp)
-      if qname:equal('tagged.example.') then
+      if qname:equal('tagged.example.') or qname:equal('taggedtcp.example.') then
         return 0, { '' .. math.random() }
       end
       return 0
@@ -1036,6 +1037,145 @@ auth-zones=example=configs/%s/example.zone""" % _confdir
         ts2 = msg.response.tags[0]
         self.assertNotEqual(ts1, ts2)
 
+    def testTaggedTCP(self):
+        name = 'taggedtcp.example.'
+        expected = dns.rrset.from_text(name, 0, dns.rdataclass.IN, 'A', '192.0.2.87')
+        query = dns.message.make_query(name, 'A', want_dnssec=True)
+        query.flags |= dns.flags.CD
+        res = self.sendTCPQuery(query)
+        self.assertRRsetInAnswer(res, expected)
+
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, res)
+        self.assertEqual(len(msg.response.rrs), 1)
+        rr = msg.response.rrs[0]
+        # we have max-cache-ttl set to 15
+        self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 15)
+        self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '192.0.2.87')
+        self.checkNoRemainingMessage()
+        print(msg.response)
+        self.assertEqual(len(msg.response.tags), 1)
+        ts1 = msg.response.tags[0]
+
+        # Again to check PC case
+        res = self.sendTCPQuery(query)
+        self.assertRRsetInAnswer(res, expected)
+
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, res)
+        print(msg.response)
+        self.assertEqual(len(msg.response.rrs), 1)
+        rr = msg.response.rrs[0]
+        # time may have passed, so do not check TTL
+        self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 15, checkTTL=False)
+        self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '192.0.2.87')
+        self.checkNoRemainingMessage()
+        self.assertEqual(len(msg.response.tags), 1)
+        ts2 = msg.response.tags[0]
+        self.assertNotEqual(ts1, ts2)
+
+class ProtobufTagCacheFFITest(TestRecursorProtobuf):
+    """
+    This test makes sure that we correctly cache tags (actually not cache them) for the FFI case
+    """
+
+    _confdir = 'ProtobufTagCacheFFI'
+    _config_template = """
+auth-zones=example=configs/%s/example.zone""" % _confdir
+    _lua_config_file = """
+    protobufServer({"127.0.0.1:%d", "127.0.0.1:%d"}, { logQueries=false, logResponses=true } )
+    """ % (protobufServersParameters[0].port, protobufServersParameters[1].port)
+    _lua_dns_script_file = """
+    local ffi = require("ffi")
+
+    ffi.cdef[[
+      typedef struct pdns_ffi_param pdns_ffi_param_t;
+
+      const char* pdns_ffi_param_get_qname(pdns_ffi_param_t* ref);
+      void pdns_ffi_param_add_policytag(pdns_ffi_param_t* ref, const char* name);
+    ]]
+
+    function gettag_ffi(obj)
+      qname = ffi.string(ffi.C.pdns_ffi_param_get_qname(obj))
+      if qname == 'tagged.example' or qname == 'taggedtcp.example' then
+        ffi.C.pdns_ffi_param_add_policytag(obj, '' .. math.random())
+      end
+      return 0
+    end
+    """
+
+    def testTagged(self):
+        name = 'tagged.example.'
+        expected = dns.rrset.from_text(name, 0, dns.rdataclass.IN, 'A', '192.0.2.84')
+        query = dns.message.make_query(name, 'A', want_dnssec=True)
+        query.flags |= dns.flags.CD
+        res = self.sendUDPQuery(query)
+        self.assertRRsetInAnswer(res, expected)
+
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, res)
+        self.assertEqual(len(msg.response.rrs), 1)
+        rr = msg.response.rrs[0]
+        # we have max-cache-ttl set to 15
+        self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 15)
+        self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '192.0.2.84')
+        self.checkNoRemainingMessage()
+        self.assertEqual(len(msg.response.tags), 1)
+        ts1 = msg.response.tags[0]
+
+        # Again to check PC case
+        res = self.sendUDPQuery(query)
+        self.assertRRsetInAnswer(res, expected)
+
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, res)
+        self.assertEqual(len(msg.response.rrs), 1)
+        rr = msg.response.rrs[0]
+        # time may have passed, so do not check TTL
+        self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 15, checkTTL=False)
+        self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '192.0.2.84')
+        self.checkNoRemainingMessage()
+        self.assertEqual(len(msg.response.tags), 1)
+        ts2 = msg.response.tags[0]
+        self.assertNotEqual(ts1, ts2)
+
+    def testTaggedTCP(self):
+        name = 'taggedtcp.example.'
+        expected = dns.rrset.from_text(name, 0, dns.rdataclass.IN, 'A', '192.0.2.87')
+        query = dns.message.make_query(name, 'A', want_dnssec=True)
+        query.flags |= dns.flags.CD
+        res = self.sendTCPQuery(query)
+        self.assertRRsetInAnswer(res, expected)
+
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, res)
+        self.assertEqual(len(msg.response.rrs), 1)
+        rr = msg.response.rrs[0]
+        # we have max-cache-ttl set to 15
+        self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 15)
+        self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '192.0.2.87')
+        self.checkNoRemainingMessage()
+        print(msg.response)
+        self.assertEqual(len(msg.response.tags), 1)
+        ts1 = msg.response.tags[0]
+
+        # Again to check PC case
+        res = self.sendTCPQuery(query)
+        self.assertRRsetInAnswer(res, expected)
+
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, res)
+        print(msg.response)
+        self.assertEqual(len(msg.response.rrs), 1)
+        rr = msg.response.rrs[0]
+        # time may have passed, so do not check TTL
+        self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 15, checkTTL=False)
+        self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '192.0.2.87')
+        self.checkNoRemainingMessage()
+        self.assertEqual(len(msg.response.tags), 1)
+        ts2 = msg.response.tags[0]
+        self.assertNotEqual(ts1, ts2)
+
 class ProtobufSelectedFromLuaTest(TestRecursorProtobuf):
     """
     This test makes sure that we correctly export queries and responses but only if they have been selected from Lua.