From: Otto Moerbeek Date: Thu, 20 Jul 2023 10:48:36 +0000 (+0200) Subject: rec: Backport 13021 to rec-4.8.x: fix setting of policy tags X-Git-Tag: rec-4.8.5~4^2~3 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=6a7cb5935dc46f5941c977814da838747584fb9c;p=thirdparty%2Fpdns.git rec: Backport 13021 to rec-4.8.x: fix setting of policy tags Backport of #13021 --- diff --git a/pdns/pdns_recursor.cc b/pdns/pdns_recursor.cc index 6e8ac8219c..858dc28b9d 100644 --- a/pdns/pdns_recursor.cc +++ b/pdns/pdns_recursor.cc @@ -835,6 +835,20 @@ static bool isEnabledForUDRs(const std::shared_ptr(reinterpret_cast(p)); @@ -1603,9 +1617,10 @@ void startDoResolve(void* p) pbMessage.setAppliedPolicyHit(appliedPolicy.d_hit); pbMessage.setAppliedPolicyKind(appliedPolicy.d_kind); } - pbMessage.addPolicyTags(dc->d_policyTags); pbMessage.setInBytes(packet.size()); pbMessage.setValidationState(sr.getValidationState()); + // See if we want to store the policyTags into the PC + addPolicyTagsToPBMessageIfNeeded(*dc, pbMessage); // Take s snap of the current protobuf buffer state to store in the PC pbDataForCache = boost::make_optional(RecursorPacketCache::PBData{ @@ -1693,6 +1708,7 @@ void startDoResolve(void* p) pbMessage.setDeviceId(dq.deviceId); pbMessage.setDeviceName(dq.deviceName); pbMessage.setToPort(dc->d_destination.getPort()); + pbMessage.addPolicyTags(dc->d_gettagPolicyTags); for (const auto& m : dq.meta) { pbMessage.setMeta(m.first, m.second.stringVal, m.second.intVal); @@ -2132,7 +2148,7 @@ static string* doProcessUDPQuestion(const std::string& question, const ComboAddr eventTrace.add(RecEventTrace::AnswerSent); if (t_protobufServers.servers && logResponse && !(luaconfsLocal->protobufExportConfig.taggedOnly && pbData && !pbData->d_tagged)) { - protobufLogResponse(dh, luaconfsLocal, pbData, tv, false, source, destination, mappedSource, ednssubnet, uniqueId, requestorId, deviceId, deviceName, meta, eventTrace); + protobufLogResponse(dh, luaconfsLocal, pbData, tv, false, source, destination, mappedSource, ednssubnet, uniqueId, requestorId, deviceId, deviceName, meta, eventTrace, policyTags); } if (eventTrace.enabled() && SyncRes::s_event_trace_enabled & SyncRes::event_trace_to_log) { diff --git a/pdns/recursordist/rec-main.cc b/pdns/recursordist/rec-main.cc index 343454cd56..bb6b742b89 100644 --- a/pdns/recursordist/rec-main.cc +++ b/pdns/recursordist/rec-main.cc @@ -518,7 +518,8 @@ void protobufLogResponse(const struct dnsheader* dh, LocalStateHolder& meta, - const RecEventTrace& eventTrace) + const RecEventTrace& eventTrace, + const std::unordered_set& policyTags) { pdns::ProtoZero::RecMessage pbMessage(pbData ? pbData->d_message : "", pbData ? pbData->d_response : "", 64, 10); // The extra bytes we are going to add // Normally we take the immutable string from the cache and append a few values, but if it's not there (can this happen?) @@ -538,12 +539,14 @@ void protobufLogResponse(const struct dnsheader* dh, LocalStateHolderprotobufExportConfig.logMappedFrom) { + pbMessage.setSocketFamily(source.sin4.sin_family); Netmask requestorNM(source, source.sin4.sin_family == AF_INET ? luaconfsLocal->protobufMaskV4 : luaconfsLocal->protobufMaskV6); auto requestor = requestorNM.getMaskedNetwork(); pbMessage.setFrom(requestor); pbMessage.setFromPort(source.getPort()); } else { + pbMessage.setSocketFamily(mappedSource.sin4.sin_family); Netmask requestorNM(mappedSource, mappedSource.sin4.sin_family == AF_INET ? luaconfsLocal->protobufMaskV4 : luaconfsLocal->protobufMaskV6); auto requestor = requestorNM.getMaskedNetwork(); pbMessage.setFrom(requestor); @@ -571,6 +574,8 @@ void protobufLogResponse(const struct dnsheader* dh, LocalStateHolder&& policyTags, shared_ptr luaContext, LuaContext::LuaObject&& data, std::vector&& records) : - d_mdp(true, query), d_now(now), d_query(query), d_policyTags(std::move(policyTags)), d_records(std::move(records)), d_luaContext(luaContext), d_data(std::move(data)) + d_mdp(true, query), d_now(now), d_query(query), d_policyTags(std::move(policyTags)), d_gettagPolicyTags(d_policyTags), d_records(std::move(records)), d_luaContext(luaContext), d_data(std::move(data)) { } @@ -125,6 +125,7 @@ struct DNSComboWriter }; std::string d_query; std::unordered_set d_policyTags; + const std::unordered_set d_gettagPolicyTags; std::string d_routingTag; std::vector d_records; @@ -537,7 +538,8 @@ void protobufLogResponse(const struct dnsheader* dh, LocalStateHolder& meta, - const RecEventTrace& eventTrace); + const RecEventTrace& eventTrace, + const std::unordered_set& policyTags); void requestWipeCaches(const DNSName& canon); void startDoResolve(void* p); bool expectProxyProtocol(const ComboAddress& from); diff --git a/pdns/recursordist/rec-tcp.cc b/pdns/recursordist/rec-tcp.cc index 160fe98572..9b4b17801d 100644 --- a/pdns/recursordist/rec-tcp.cc +++ b/pdns/recursordist/rec-tcp.cc @@ -563,7 +563,7 @@ static void handleRunningTCPQuestion(int fd, FDMultiplexer::funcparam_t& var) { 0, 0 }; - protobufLogResponse(dh, luaconfsLocal, pbData, tv, true, dc->d_source, dc->d_destination, dc->d_mappedSource, dc->d_ednssubnet, dc->d_uuid, dc->d_requestorId, dc->d_deviceId, dc->d_deviceName, dc->d_meta, dc->d_eventTrace); + protobufLogResponse(dh, luaconfsLocal, pbData, tv, true, dc->d_source, dc->d_destination, dc->d_mappedSource, dc->d_ednssubnet, dc->d_uuid, dc->d_requestorId, dc->d_deviceId, dc->d_deviceName, dc->d_meta, dc->d_eventTrace, dc->d_policyTags); } if (dc->d_eventTrace.enabled() && SyncRes::s_event_trace_enabled & SyncRes::event_trace_to_log) { diff --git a/regression-tests.recursor-dnssec/test_Protobuf.py b/regression-tests.recursor-dnssec/test_Protobuf.py index 5c6fd0b451..79d57e4561 100644 --- a/regression-tests.recursor-dnssec/test_Protobuf.py +++ b/regression-tests.recursor-dnssec/test_Protobuf.py @@ -94,19 +94,19 @@ class TestRecursorProtobuf(RecursorTest): def getFirstProtobufMessage(self, retries=1, waitTime=1): msg = None - print("in getFirstProtobufMessage") + #print("in getFirstProtobufMessage") for param in protobufServersParameters: - print(param.port) + #print(param.port) failed = 0 while param.queue.empty: - print(failed) - print(retries) + #print(failed) + #print(retries) if failed >= retries: break failed = failed + 1 - print("waiting") + #print("waiting") time.sleep(waitTime) self.assertFalse(param.queue.empty()) @@ -118,7 +118,7 @@ class TestRecursorProtobuf(RecursorTest): if oldmsg is not None: self.assertEqual(msg, oldmsg) - print(msg) + #print(msg) return msg def emptyProtoBufQueue(self): @@ -232,17 +232,17 @@ class TestRecursorProtobuf(RecursorTest): self.assertEqual(msg.response.appliedPolicyKind, kind) def checkProtobufTags(self, msg, tags): - print(tags) - print('---') - print(msg.response.tags) + #print(tags) + #print('---') + #print(msg.response.tags) self.assertEqual(len(msg.response.tags), len(tags)) for tag in msg.response.tags: self.assertTrue(tag in tags) def checkProtobufMetas(self, msg, metas): - print(metas) - print('---') - print(msg.meta) + #print(metas) + #print('---') + #print(msg.meta) self.assertEqual(len(msg.meta), len(metas)) for m in msg.meta: self.assertTrue(m.HasField('key')) @@ -277,7 +277,7 @@ class TestRecursorProtobuf(RecursorTest): self.assertEqual(msg.response.rcode, 65536) def checkProtobufIdentity(self, msg, requestorId, deviceId, deviceName): - print(msg) + #print(msg) self.assertTrue((requestorId == '') == (not msg.HasField('requestorId'))) self.assertTrue((deviceId == b'') == (not msg.HasField('deviceId'))) self.assertTrue((deviceName == '') == (not msg.HasField('deviceName'))) @@ -957,9 +957,85 @@ auth-zones=example=configs/%s/example.zone""" % _confdir self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 15) self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '192.0.2.84') tags = [ self._tag_from_gettag ] + self._tags + #print(msg) self.checkProtobufTags(msg, tags) self.checkNoRemainingMessage() + # Again to check PC case + res = self.sendUDPQuery(query) + self.assertRRsetInAnswer(res, expected) + + # check the protobuf messages corresponding to the UDP query and answer + msg = self.getFirstProtobufMessage() + self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.UDP, query, dns.rdataclass.IN, dns.rdatatype.A, name) + self.checkProtobufTags(msg, [ self._tag_from_gettag ]) + # then the response + msg = self.getFirstProtobufMessage() + self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, res) + self.assertEqual(len(msg.response.rrs), 1) + rr = msg.response.rrs[0] + # we have max-cache-ttl set to 15 + self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 15) + self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '192.0.2.84') + tags = [ self._tag_from_gettag ] + self._tags + self.checkProtobufTags(msg, tags) + self.checkNoRemainingMessage() + +class ProtobufTagCacheTest(TestRecursorProtobuf): + """ + This test makes sure that we correctly cache tags (actually not cache them) + """ + + _confdir = 'ProtobufTagCache' + _config_template = """ +auth-zones=example=configs/%s/example.zone""" % _confdir + _lua_config_file = """ + protobufServer({"127.0.0.1:%d", "127.0.0.1:%d"}, { logQueries=false, logResponses=true } ) + """ % (protobufServersParameters[0].port, protobufServersParameters[1].port) + _lua_dns_script_file = """ + function gettag(remote, ednssubnet, localip, qname, qtype, ednsoptions, tcp) + if qname:equal('tagged.example.') then + return 0, { '' .. math.random() } + end + return 0 + end + """ + + def testTagged(self): + name = 'tagged.example.' + expected = dns.rrset.from_text(name, 0, dns.rdataclass.IN, 'A', '192.0.2.84') + query = dns.message.make_query(name, 'A', want_dnssec=True) + query.flags |= dns.flags.CD + res = self.sendUDPQuery(query) + self.assertRRsetInAnswer(res, expected) + + msg = self.getFirstProtobufMessage() + self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, res) + self.assertEqual(len(msg.response.rrs), 1) + rr = msg.response.rrs[0] + # we have max-cache-ttl set to 15 + self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 15) + self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '192.0.2.84') + self.checkNoRemainingMessage() + self.assertEqual(len(msg.response.tags), 1) + ts1 = msg.response.tags[0] + + # Again to check PC case + res = self.sendUDPQuery(query) + self.assertRRsetInAnswer(res, expected) + + msg = self.getFirstProtobufMessage() + self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, res) + self.assertEqual(len(msg.response.rrs), 1) + rr = msg.response.rrs[0] + # we have max-cache-ttl set to 15 + self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 15) + self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '192.0.2.84') + self.checkNoRemainingMessage() + self.assertEqual(len(msg.response.tags), 1) + ts2 = msg.response.tags[0] + self.assertNotEqual(ts1, ts2) + class ProtobufSelectedFromLuaTest(TestRecursorProtobuf): """ This test makes sure that we correctly export queries and responses but only if they have been selected from Lua.