]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
rec: Backport 13021 to rec-4.8.x: fix setting of policy tags
authorOtto Moerbeek <otto.moerbeek@open-xchange.com>
Thu, 20 Jul 2023 10:48:36 +0000 (12:48 +0200)
committerOtto Moerbeek <otto.moerbeek@open-xchange.com>
Thu, 20 Jul 2023 10:58:09 +0000 (12:58 +0200)
Backport of #13021

pdns/pdns_recursor.cc
pdns/recursordist/rec-main.cc
pdns/recursordist/rec-main.hh
pdns/recursordist/rec-tcp.cc
regression-tests.recursor-dnssec/test_Protobuf.py

index 6e8ac8219c11e8f80acfd8ffdc2357a0c804ec6f..858dc28b9d9975661f93a4649f90ffe909dcfd80 100644 (file)
@@ -835,6 +835,20 @@ static bool isEnabledForUDRs(const std::shared_ptr<std::vector<std::unique_ptr<F
 }
 #endif // HAVE_FSTRM
 
+static void addPolicyTagsToPBMessageIfNeeded(DNSComboWriter& comboWriter, pdns::ProtoZero::RecMessage& pbMessage)
+{
+  /* we do _not_ want to store policy tags set by the gettag hook into the packet cache,
+     since the call to gettag for subsequent queries could yield the same PC tag but different policy tags */
+  if (!comboWriter.d_gettagPolicyTags.empty()) {
+    for (const auto& tag : comboWriter.d_gettagPolicyTags) {
+      comboWriter.d_policyTags.erase(tag);
+    }
+  }
+  if (!comboWriter.d_policyTags.empty()) {
+    pbMessage.addPolicyTags(comboWriter.d_policyTags);
+  }
+}
+
 void startDoResolve(void* p)
 {
   auto dc = std::unique_ptr<DNSComboWriter>(reinterpret_cast<DNSComboWriter*>(p));
@@ -1603,9 +1617,10 @@ void startDoResolve(void* p)
         pbMessage.setAppliedPolicyHit(appliedPolicy.d_hit);
         pbMessage.setAppliedPolicyKind(appliedPolicy.d_kind);
       }
-      pbMessage.addPolicyTags(dc->d_policyTags);
       pbMessage.setInBytes(packet.size());
       pbMessage.setValidationState(sr.getValidationState());
+      // See if we want to store the policyTags into the PC
+      addPolicyTagsToPBMessageIfNeeded(*dc, pbMessage);
 
       // Take s snap of the current protobuf buffer state to store in the PC
       pbDataForCache = boost::make_optional(RecursorPacketCache::PBData{
@@ -1693,6 +1708,7 @@ void startDoResolve(void* p)
       pbMessage.setDeviceId(dq.deviceId);
       pbMessage.setDeviceName(dq.deviceName);
       pbMessage.setToPort(dc->d_destination.getPort());
+      pbMessage.addPolicyTags(dc->d_gettagPolicyTags);
 
       for (const auto& m : dq.meta) {
         pbMessage.setMeta(m.first, m.second.stringVal, m.second.intVal);
@@ -2132,7 +2148,7 @@ static string* doProcessUDPQuestion(const std::string& question, const ComboAddr
         eventTrace.add(RecEventTrace::AnswerSent);
 
         if (t_protobufServers.servers && logResponse && !(luaconfsLocal->protobufExportConfig.taggedOnly && pbData && !pbData->d_tagged)) {
-          protobufLogResponse(dh, luaconfsLocal, pbData, tv, false, source, destination, mappedSource, ednssubnet, uniqueId, requestorId, deviceId, deviceName, meta, eventTrace);
+          protobufLogResponse(dh, luaconfsLocal, pbData, tv, false, source, destination, mappedSource, ednssubnet, uniqueId, requestorId, deviceId, deviceName, meta, eventTrace, policyTags);
         }
 
         if (eventTrace.enabled() && SyncRes::s_event_trace_enabled & SyncRes::event_trace_to_log) {
index 343454cd56ffc976d85d9cba96d6f0fb0087103c..bb6b742b890de023cd318e021bfe9801249665aa 100644 (file)
@@ -518,7 +518,8 @@ void protobufLogResponse(const struct dnsheader* dh, LocalStateHolder<LuaConfigI
                          const EDNSSubnetOpts& ednssubnet,
                          const boost::uuids::uuid& uniqueId, const string& requestorId, const string& deviceId,
                          const string& deviceName, const std::map<std::string, RecursorLua4::MetaValue>& meta,
-                         const RecEventTrace& eventTrace)
+                         const RecEventTrace& eventTrace,
+                         const std::unordered_set<std::string>& policyTags)
 {
   pdns::ProtoZero::RecMessage pbMessage(pbData ? pbData->d_message : "", pbData ? pbData->d_response : "", 64, 10); // The extra bytes we are going to add
   // Normally we take the immutable string from the cache and append a few values, but if it's not there (can this happen?)
@@ -538,12 +539,14 @@ void protobufLogResponse(const struct dnsheader* dh, LocalStateHolder<LuaConfigI
 
   // In message part
   if (!luaconfsLocal->protobufExportConfig.logMappedFrom) {
+    pbMessage.setSocketFamily(source.sin4.sin_family);
     Netmask requestorNM(source, source.sin4.sin_family == AF_INET ? luaconfsLocal->protobufMaskV4 : luaconfsLocal->protobufMaskV6);
     auto requestor = requestorNM.getMaskedNetwork();
     pbMessage.setFrom(requestor);
     pbMessage.setFromPort(source.getPort());
   }
   else {
+    pbMessage.setSocketFamily(mappedSource.sin4.sin_family);
     Netmask requestorNM(mappedSource, mappedSource.sin4.sin_family == AF_INET ? luaconfsLocal->protobufMaskV4 : luaconfsLocal->protobufMaskV6);
     auto requestor = requestorNM.getMaskedNetwork();
     pbMessage.setFrom(requestor);
@@ -571,6 +574,8 @@ void protobufLogResponse(const struct dnsheader* dh, LocalStateHolder<LuaConfigI
   if (eventTrace.enabled() && SyncRes::s_event_trace_enabled & SyncRes::event_trace_to_pb) {
     pbMessage.addEvents(eventTrace);
   }
+  pbMessage.addPolicyTags(policyTags);
+
   protobufLogResponse(pbMessage);
 }
 
index 6c231f7fbe783661b37fe02df2a069e457952c1c..6464b5982b0382181b98038ac7a5dbf6e12249ba 100644 (file)
@@ -59,7 +59,7 @@ struct DNSComboWriter
   }
 
   DNSComboWriter(const std::string& query, const struct timeval& now, std::unordered_set<std::string>&& policyTags, shared_ptr<RecursorLua4> luaContext, LuaContext::LuaObject&& data, std::vector<DNSRecord>&& records) :
-    d_mdp(true, query), d_now(now), d_query(query), d_policyTags(std::move(policyTags)), d_records(std::move(records)), d_luaContext(luaContext), d_data(std::move(data))
+    d_mdp(true, query), d_now(now), d_query(query), d_policyTags(std::move(policyTags)), d_gettagPolicyTags(d_policyTags), d_records(std::move(records)), d_luaContext(luaContext), d_data(std::move(data))
   {
   }
 
@@ -125,6 +125,7 @@ struct DNSComboWriter
   };
   std::string d_query;
   std::unordered_set<std::string> d_policyTags;
+  const std::unordered_set<std::string> d_gettagPolicyTags;
   std::string d_routingTag;
   std::vector<DNSRecord> d_records;
 
@@ -537,7 +538,8 @@ void protobufLogResponse(const struct dnsheader* dh, LocalStateHolder<LuaConfigI
                          const ComboAddress& mappedSource, const EDNSSubnetOpts& ednssubnet,
                          const boost::uuids::uuid& uniqueId, const string& requestorId, const string& deviceId,
                          const string& deviceName, const std::map<std::string, RecursorLua4::MetaValue>& meta,
-                         const RecEventTrace& eventTrace);
+                         const RecEventTrace& eventTrace,
+                         const std::unordered_set<std::string>& policyTags);
 void requestWipeCaches(const DNSName& canon);
 void startDoResolve(void* p);
 bool expectProxyProtocol(const ComboAddress& from);
index 160fe9857228795797327f3c76ea6eae6cef4bfb..9b4b17801d3975e9169e73e16809b74662f2378c 100644 (file)
@@ -563,7 +563,7 @@ static void handleRunningTCPQuestion(int fd, FDMultiplexer::funcparam_t& var)
               {
                 0, 0
               };
-              protobufLogResponse(dh, luaconfsLocal, pbData, tv, true, dc->d_source, dc->d_destination, dc->d_mappedSource, dc->d_ednssubnet, dc->d_uuid, dc->d_requestorId, dc->d_deviceId, dc->d_deviceName, dc->d_meta, dc->d_eventTrace);
+              protobufLogResponse(dh, luaconfsLocal, pbData, tv, true, dc->d_source, dc->d_destination, dc->d_mappedSource, dc->d_ednssubnet, dc->d_uuid, dc->d_requestorId, dc->d_deviceId, dc->d_deviceName, dc->d_meta, dc->d_eventTrace, dc->d_policyTags);
             }
 
             if (dc->d_eventTrace.enabled() && SyncRes::s_event_trace_enabled & SyncRes::event_trace_to_log) {
index 5c6fd0b451608554fd13b96a1d543eaf1ba1683a..79d57e45618dd58363b011488521af5bacfb3af0 100644 (file)
@@ -94,19 +94,19 @@ class TestRecursorProtobuf(RecursorTest):
     def getFirstProtobufMessage(self, retries=1, waitTime=1):
         msg = None
 
-        print("in getFirstProtobufMessage")
+        #print("in getFirstProtobufMessage")
         for param in protobufServersParameters:
-          print(param.port)
+          #print(param.port)
           failed = 0
 
           while param.queue.empty:
-            print(failed)
-            print(retries)
+            #print(failed)
+            #print(retries)
             if failed >= retries:
               break
 
             failed = failed + 1
-            print("waiting")
+            #print("waiting")
             time.sleep(waitTime)
 
           self.assertFalse(param.queue.empty())
@@ -118,7 +118,7 @@ class TestRecursorProtobuf(RecursorTest):
           if oldmsg is not None:
             self.assertEqual(msg, oldmsg)
 
-        print(msg)
+        #print(msg)
         return msg
 
     def emptyProtoBufQueue(self):
@@ -232,17 +232,17 @@ class TestRecursorProtobuf(RecursorTest):
         self.assertEqual(msg.response.appliedPolicyKind, kind)
 
     def checkProtobufTags(self, msg, tags):
-        print(tags)
-        print('---')
-        print(msg.response.tags)
+        #print(tags)
+        #print('---')
+        #print(msg.response.tags)
         self.assertEqual(len(msg.response.tags), len(tags))
         for tag in msg.response.tags:
             self.assertTrue(tag in tags)
 
     def checkProtobufMetas(self, msg, metas):
-        print(metas)
-        print('---')
-        print(msg.meta)
+        #print(metas)
+        #print('---')
+        #print(msg.meta)
         self.assertEqual(len(msg.meta), len(metas))
         for m in msg.meta:
             self.assertTrue(m.HasField('key'))
@@ -277,7 +277,7 @@ class TestRecursorProtobuf(RecursorTest):
         self.assertEqual(msg.response.rcode, 65536)
 
     def checkProtobufIdentity(self, msg, requestorId, deviceId, deviceName):
-        print(msg)
+        #print(msg)
         self.assertTrue((requestorId == '') == (not msg.HasField('requestorId')))
         self.assertTrue((deviceId == b'') == (not msg.HasField('deviceId')))
         self.assertTrue((deviceName == '') == (not msg.HasField('deviceName')))
@@ -957,9 +957,85 @@ auth-zones=example=configs/%s/example.zone""" % _confdir
         self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 15)
         self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '192.0.2.84')
         tags = [ self._tag_from_gettag ] + self._tags
+        #print(msg)
         self.checkProtobufTags(msg, tags)
         self.checkNoRemainingMessage()
 
+        # Again to check PC case
+        res = self.sendUDPQuery(query)
+        self.assertRRsetInAnswer(res, expected)
+
+        # check the protobuf messages corresponding to the UDP query and answer
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.UDP, query, dns.rdataclass.IN, dns.rdatatype.A, name)
+        self.checkProtobufTags(msg, [ self._tag_from_gettag ])
+        # then the response
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, res)
+        self.assertEqual(len(msg.response.rrs), 1)
+        rr = msg.response.rrs[0]
+        # we have max-cache-ttl set to 15
+        self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 15)
+        self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '192.0.2.84')
+        tags = [ self._tag_from_gettag ] + self._tags
+        self.checkProtobufTags(msg, tags)
+        self.checkNoRemainingMessage()
+
+class ProtobufTagCacheTest(TestRecursorProtobuf):
+    """
+    This test makes sure that we correctly cache tags (actually not cache them)
+    """
+
+    _confdir = 'ProtobufTagCache'
+    _config_template = """
+auth-zones=example=configs/%s/example.zone""" % _confdir
+    _lua_config_file = """
+    protobufServer({"127.0.0.1:%d", "127.0.0.1:%d"}, { logQueries=false, logResponses=true } )
+    """ % (protobufServersParameters[0].port, protobufServersParameters[1].port)
+    _lua_dns_script_file = """
+    function gettag(remote, ednssubnet, localip, qname, qtype, ednsoptions, tcp)
+      if qname:equal('tagged.example.') then
+        return 0, { '' .. math.random() }
+      end
+      return 0
+    end
+    """
+
+    def testTagged(self):
+        name = 'tagged.example.'
+        expected = dns.rrset.from_text(name, 0, dns.rdataclass.IN, 'A', '192.0.2.84')
+        query = dns.message.make_query(name, 'A', want_dnssec=True)
+        query.flags |= dns.flags.CD
+        res = self.sendUDPQuery(query)
+        self.assertRRsetInAnswer(res, expected)
+
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, res)
+        self.assertEqual(len(msg.response.rrs), 1)
+        rr = msg.response.rrs[0]
+        # we have max-cache-ttl set to 15
+        self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 15)
+        self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '192.0.2.84')
+        self.checkNoRemainingMessage()
+        self.assertEqual(len(msg.response.tags), 1)
+        ts1 = msg.response.tags[0]
+
+        # Again to check PC case
+        res = self.sendUDPQuery(query)
+        self.assertRRsetInAnswer(res, expected)
+
+        msg = self.getFirstProtobufMessage()
+        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, res)
+        self.assertEqual(len(msg.response.rrs), 1)
+        rr = msg.response.rrs[0]
+        # we have max-cache-ttl set to 15
+        self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 15)
+        self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '192.0.2.84')
+        self.checkNoRemainingMessage()
+        self.assertEqual(len(msg.response.tags), 1)
+        ts2 = msg.response.tags[0]
+        self.assertNotEqual(ts1, ts2)
+
 class ProtobufSelectedFromLuaTest(TestRecursorProtobuf):
     """
     This test makes sure that we correctly export queries and responses but only if they have been selected from Lua.