]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
rec: fix TCP case for cached policy tags
authorOtto Moerbeek <otto.moerbeek@open-xchange.com>
Mon, 17 Jun 2024 12:58:01 +0000 (14:58 +0200)
committerOtto Moerbeek <otto.moerbeek@open-xchange.com>
Tue, 18 Jun 2024 07:46:49 +0000 (09:46 +0200)
pdns/recursordist/rec-main.hh
pdns/recursordist/rec-tcp.cc
regression-tests.recursor-dnssec/test_Protobuf.py

index c69e4de74a199d780584fba52c71cae5b187b346..b97ec06237bd212c0b71543847c3246ae4b94f45 100644 (file)
@@ -125,7 +125,7 @@ struct DNSComboWriter
   };
   std::string d_query;
   std::unordered_set<std::string> d_policyTags;
-  const std::unordered_set<std::string> d_gettagPolicyTags;
+  std::unordered_set<std::string> d_gettagPolicyTags;
   std::string d_routingTag;
   std::vector<DNSRecord> d_records;
 
index 386dd5060a8ee73b225f8f3fc710e30e1a68093a..a1af4a10889e9084a3c381c42e15a32cdbf70961 100644 (file)
@@ -327,16 +327,23 @@ static void doProcessTCPQuestion(std::unique_ptr<DNSComboWriter>& comboWriter, s
       if (t_pdl) {
         try {
           if (t_pdl->hasGettagFFIFunc()) {
-            RecursorLua4::FFIParams params(qname, qtype, comboWriter->d_local, comboWriter->d_remote, comboWriter->d_destination, comboWriter->d_source, comboWriter->d_ednssubnet.source, comboWriter->d_data, comboWriter->d_policyTags, comboWriter->d_records, ednsOptions, comboWriter->d_proxyProtocolValues, requestorId, deviceId, deviceName, comboWriter->d_routingTag, comboWriter->d_rcode, comboWriter->d_ttlCap, comboWriter->d_variable, true, logQuery, comboWriter->d_logResponse, comboWriter->d_followCNAMERecords, comboWriter->d_extendedErrorCode, comboWriter->d_extendedErrorExtra, comboWriter->d_responsePaddingDisabled, comboWriter->d_meta);
+            RecursorLua4::FFIParams params(qname, qtype, comboWriter->d_local, comboWriter->d_remote, comboWriter->d_destination, comboWriter->d_source, comboWriter->d_ednssubnet.source, comboWriter->d_data, comboWriter->d_gettagPolicyTags, comboWriter->d_records, ednsOptions, comboWriter->d_proxyProtocolValues, requestorId, deviceId, deviceName, comboWriter->d_routingTag, comboWriter->d_rcode, comboWriter->d_ttlCap, comboWriter->d_variable, true, logQuery, comboWriter->d_logResponse, comboWriter->d_followCNAMERecords, comboWriter->d_extendedErrorCode, comboWriter->d_extendedErrorExtra, comboWriter->d_responsePaddingDisabled, comboWriter->d_meta);
             comboWriter->d_eventTrace.add(RecEventTrace::LuaGetTagFFI);
             comboWriter->d_tag = t_pdl->gettag_ffi(params);
             comboWriter->d_eventTrace.add(RecEventTrace::LuaGetTagFFI, comboWriter->d_tag, false);
           }
           else if (t_pdl->hasGettagFunc()) {
             comboWriter->d_eventTrace.add(RecEventTrace::LuaGetTag);
-            comboWriter->d_tag = t_pdl->gettag(comboWriter->d_source, comboWriter->d_ednssubnet.source, comboWriter->d_destination, qname, qtype, &comboWriter->d_policyTags, comboWriter->d_data, ednsOptions, true, requestorId, deviceId, deviceName, comboWriter->d_routingTag, comboWriter->d_proxyProtocolValues);
+            comboWriter->d_tag = t_pdl->gettag(comboWriter->d_source, comboWriter->d_ednssubnet.source, comboWriter->d_destination, qname, qtype, &comboWriter->d_gettagPolicyTags, comboWriter->d_data, ednsOptions, true, requestorId, deviceId, deviceName, comboWriter->d_routingTag, comboWriter->d_proxyProtocolValues);
             comboWriter->d_eventTrace.add(RecEventTrace::LuaGetTag, comboWriter->d_tag, false);
           }
+          // Copy d_gettagPolicyTags to d_policyTags, so other Lua hooks see them and can add their
+          // own. Before storing into the packetcache, the tags in d_gettagPolicyTags will be
+          // cleared by addPolicyTagsToPBMessageIfNeeded() so they do *not* end up in the PC. When an
+          // Protobuf message is constructed, one part comes from the PC (including the tags
+          // set by non-gettag hooks), and the tags in d_gettagPolicyTags will be added by the code
+          // constructing the PB message.
+          comboWriter->d_policyTags = comboWriter->d_gettagPolicyTags;
         }
         catch (const std::exception& e) {
           if (g_logCommonErrors) {
index b83d9c95a07367c563e0db371913210d4d6933a9..d9b9d851999fafa23153eb6b2a039d46ad8d6be1 100644 (file)
@@ -305,6 +305,7 @@ class TestRecursorProtobuf(RecursorTest):
 @ 3600 IN SOA {soa}
 a 3600 IN A 192.0.2.42
 tagged 3600 IN A 192.0.2.84
+taggedtcp 3600 IN A 192.0.2.87
 meta 3600 IN A 192.0.2.85
 query-selected 3600 IN A 192.0.2.84
 answer-selected 3600 IN A 192.0.2.84
@@ -997,7 +998,7 @@ auth-zones=example=configs/%s/example.zone""" % _confdir
     """ % (protobufServersParameters[0].port, protobufServersParameters[1].port)
     _lua_dns_script_file = """
     function gettag(remote, ednssubnet, localip, qname, qtype, ednsoptions, tcp)
-      if qname:equal('tagged.example.') then
+      if qname:equal('tagged.example.') or qname:equal('taggedtcp.example.') then
         return 0, { '' .. math.random() }
       end
       return 0
@@ -1039,6 +1040,145 @@ auth-zones=example=configs/%s/example.zone""" % _confdir
         ts2 = msg.response.tags[0]
         self.assertNotEqual(ts1, ts2)
 
+    def testTaggedTCP(self):
+        name = 'taggedtcp.example.'
+        expected = dns.rrset.from_text(name, 0, dns.rdataclass.IN, 'A', '192.0.2.87')
+        query = dns.message.make_query(name, 'A', want_dnssec=True)
+        query.flags |= dns.flags.CD
+        res = self.sendTCPQuery(query)
+        self.assertRRsetInAnswer(res, expected)
+
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, res)
+        self.assertEqual(len(msg.response.rrs), 1)
+        rr = msg.response.rrs[0]
+        # we have max-cache-ttl set to 15
+        self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 15)
+        self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '192.0.2.87')
+        self.checkNoRemainingMessage()
+        print(msg.response)
+        self.assertEqual(len(msg.response.tags), 1)
+        ts1 = msg.response.tags[0]
+
+        # Again to check PC case
+        res = self.sendTCPQuery(query)
+        self.assertRRsetInAnswer(res, expected)
+
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, res)
+        print(msg.response)
+        self.assertEqual(len(msg.response.rrs), 1)
+        rr = msg.response.rrs[0]
+        # time may have passed, so do not check TTL
+        self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 15, checkTTL=False)
+        self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '192.0.2.87')
+        self.checkNoRemainingMessage()
+        self.assertEqual(len(msg.response.tags), 1)
+        ts2 = msg.response.tags[0]
+        self.assertNotEqual(ts1, ts2)
+
+class ProtobufTagCacheFFITest(TestRecursorProtobuf):
+    """
+    This test makes sure that we correctly cache tags (actually not cache them) for the FFI case
+    """
+
+    _confdir = 'ProtobufTagCacheFFI'
+    _config_template = """
+auth-zones=example=configs/%s/example.zone""" % _confdir
+    _lua_config_file = """
+    protobufServer({"127.0.0.1:%d", "127.0.0.1:%d"}, { logQueries=false, logResponses=true } )
+    """ % (protobufServersParameters[0].port, protobufServersParameters[1].port)
+    _lua_dns_script_file = """
+    local ffi = require("ffi")
+
+    ffi.cdef[[
+      typedef struct pdns_ffi_param pdns_ffi_param_t;
+
+      const char* pdns_ffi_param_get_qname(pdns_ffi_param_t* ref);
+      void pdns_ffi_param_add_policytag(pdns_ffi_param_t* ref, const char* name);
+    ]]
+
+    function gettag_ffi(obj)
+      qname = ffi.string(ffi.C.pdns_ffi_param_get_qname(obj))
+      if qname == 'tagged.example' or qname == 'taggedtcp.example' then
+        ffi.C.pdns_ffi_param_add_policytag(obj, '' .. math.random())
+      end
+      return 0
+    end
+    """
+
+    def testTagged(self):
+        name = 'tagged.example.'
+        expected = dns.rrset.from_text(name, 0, dns.rdataclass.IN, 'A', '192.0.2.84')
+        query = dns.message.make_query(name, 'A', want_dnssec=True)
+        query.flags |= dns.flags.CD
+        res = self.sendUDPQuery(query)
+        self.assertRRsetInAnswer(res, expected)
+
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, res)
+        self.assertEqual(len(msg.response.rrs), 1)
+        rr = msg.response.rrs[0]
+        # we have max-cache-ttl set to 15
+        self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 15)
+        self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '192.0.2.84')
+        self.checkNoRemainingMessage()
+        self.assertEqual(len(msg.response.tags), 1)
+        ts1 = msg.response.tags[0]
+
+        # Again to check PC case
+        res = self.sendUDPQuery(query)
+        self.assertRRsetInAnswer(res, expected)
+
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, res)
+        self.assertEqual(len(msg.response.rrs), 1)
+        rr = msg.response.rrs[0]
+        # time may have passed, so do not check TTL
+        self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 15, checkTTL=False)
+        self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '192.0.2.84')
+        self.checkNoRemainingMessage()
+        self.assertEqual(len(msg.response.tags), 1)
+        ts2 = msg.response.tags[0]
+        self.assertNotEqual(ts1, ts2)
+
+    def testTaggedTCP(self):
+        name = 'taggedtcp.example.'
+        expected = dns.rrset.from_text(name, 0, dns.rdataclass.IN, 'A', '192.0.2.87')
+        query = dns.message.make_query(name, 'A', want_dnssec=True)
+        query.flags |= dns.flags.CD
+        res = self.sendTCPQuery(query)
+        self.assertRRsetInAnswer(res, expected)
+
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, res)
+        self.assertEqual(len(msg.response.rrs), 1)
+        rr = msg.response.rrs[0]
+        # we have max-cache-ttl set to 15
+        self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 15)
+        self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '192.0.2.87')
+        self.checkNoRemainingMessage()
+        print(msg.response)
+        self.assertEqual(len(msg.response.tags), 1)
+        ts1 = msg.response.tags[0]
+
+        # Again to check PC case
+        res = self.sendTCPQuery(query)
+        self.assertRRsetInAnswer(res, expected)
+
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, res)
+        print(msg.response)
+        self.assertEqual(len(msg.response.rrs), 1)
+        rr = msg.response.rrs[0]
+        # time may have passed, so do not check TTL
+        self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 15, checkTTL=False)
+        self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '192.0.2.87')
+        self.checkNoRemainingMessage()
+        self.assertEqual(len(msg.response.tags), 1)
+        ts2 = msg.response.tags[0]
+        self.assertNotEqual(ts1, ts2)
+
 class ProtobufSelectedFromLuaTest(TestRecursorProtobuf):
     """
     This test makes sure that we correctly export queries and responses but only if they have been selected from Lua.