From: Otto Moerbeek Date: Wed, 12 Jul 2023 12:58:53 +0000 (+0200) Subject: Set the pb policy tags in the right places X-Git-Tag: rec-5.0.0-alpha1~99^2~3 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=c837140e1d39b9cec75ba75ed23487724d9a3a51;p=thirdparty%2Fpdns.git Set the pb policy tags in the right places --- diff --git a/pdns/recursordist/pdns_recursor.cc b/pdns/recursordist/pdns_recursor.cc index 927cf06dd8..1b982ab1f1 100644 --- a/pdns/recursordist/pdns_recursor.cc +++ b/pdns/recursordist/pdns_recursor.cc @@ -1679,7 +1679,6 @@ void startDoResolve(void* arg) // NOLINT(readability-function-cognitive-complexi pbMessage.setAppliedPolicyHit(appliedPolicy.d_hit); pbMessage.setAppliedPolicyKind(appliedPolicy.d_kind); } - pbMessage.addPolicyTags(comboWriter->d_policyTags); pbMessage.setInBytes(packet.size()); pbMessage.setValidationState(resolver.getValidationState()); @@ -1775,6 +1774,7 @@ void startDoResolve(void* arg) // NOLINT(readability-function-cognitive-complexi pbMessage.setDeviceId(dnsQuestion.deviceId); pbMessage.setDeviceName(dnsQuestion.deviceName); pbMessage.setToPort(comboWriter->d_destination.getPort()); + pbMessage.addPolicyTags(comboWriter->d_policyTags); for (const auto& metaValue : dnsQuestion.meta) { pbMessage.setMeta(metaValue.first, metaValue.second.stringVal, metaValue.second.intVal); @@ -2243,7 +2243,7 @@ static string* doProcessUDPQuestion(const std::string& question, const ComboAddr eventTrace.add(RecEventTrace::AnswerSent); if (t_protobufServers.servers && logResponse && (!luaconfsLocal->protobufExportConfig.taggedOnly || !pbData || pbData->d_tagged)) { - protobufLogResponse(dnsheader, luaconfsLocal, pbData, tval, false, source, destination, mappedSource, ednssubnet, uniqueId, requestorId, deviceId, deviceName, meta, eventTrace); + protobufLogResponse(dnsheader, luaconfsLocal, pbData, tval, false, source, destination, mappedSource, ednssubnet, uniqueId, requestorId, deviceId, deviceName, meta, eventTrace, policyTags); } if (eventTrace.enabled() && (SyncRes::s_event_trace_enabled & SyncRes::event_trace_to_log) != 0) { diff --git a/pdns/recursordist/rec-main.cc b/pdns/recursordist/rec-main.cc index 2d46bbe944..36b8517013 100644 --- a/pdns/recursordist/rec-main.cc +++ b/pdns/recursordist/rec-main.cc @@ -493,7 +493,7 @@ void protobufLogQuery(LocalStateHolder& luaconfsLocal, const boo msg.setRequestorId(requestorId); msg.setDeviceId(deviceId); msg.setDeviceName(deviceName); - + if (!policyTags.empty()) { msg.addPolicyTags(policyTags); } @@ -526,7 +526,8 @@ void protobufLogResponse(const struct dnsheader* header, LocalStateHolder& meta, - const RecEventTrace& eventTrace) + const RecEventTrace& eventTrace, + const std::unordered_set& policyTags) { pdns::ProtoZero::RecMessage pbMessage(pbData ? pbData->d_message : "", pbData ? pbData->d_response : "", 64, 10); // The extra bytes we are going to add // Normally we take the immutable string from the cache and append a few values, but if it's not there (can this happen?) @@ -581,6 +582,8 @@ void protobufLogResponse(const struct dnsheader* header, LocalStateHolder& meta, - const RecEventTrace& eventTrace); + const RecEventTrace& eventTrace, + const std::unordered_set& policyTags); void requestWipeCaches(const DNSName& canon); void startDoResolve(void*); bool expectProxyProtocol(const ComboAddress& from); diff --git a/pdns/recursordist/rec-tcp.cc b/pdns/recursordist/rec-tcp.cc index 70c3f9595a..4b43326e05 100644 --- a/pdns/recursordist/rec-tcp.cc +++ b/pdns/recursordist/rec-tcp.cc @@ -577,7 +577,7 @@ static void handleRunningTCPQuestion(int fileDesc, FDMultiplexer::funcparam_t& v { 0, 0 }; - protobufLogResponse(dnsheader, luaconfsLocal, pbData, tval, true, comboWriter->d_source, comboWriter->d_destination, comboWriter->d_mappedSource, comboWriter->d_ednssubnet, comboWriter->d_uuid, comboWriter->d_requestorId, comboWriter->d_deviceId, comboWriter->d_deviceName, comboWriter->d_meta, comboWriter->d_eventTrace); + protobufLogResponse(dnsheader, luaconfsLocal, pbData, tval, true, comboWriter->d_source, comboWriter->d_destination, comboWriter->d_mappedSource, comboWriter->d_ednssubnet, comboWriter->d_uuid, comboWriter->d_requestorId, comboWriter->d_deviceId, comboWriter->d_deviceName, comboWriter->d_meta, comboWriter->d_eventTrace, comboWriter->d_policyTags); } if (comboWriter->d_eventTrace.enabled() && (SyncRes::s_event_trace_enabled & SyncRes::event_trace_to_log) != 0) { diff --git a/regression-tests.recursor-dnssec/test_Protobuf.py b/regression-tests.recursor-dnssec/test_Protobuf.py index 698f5109d8..5143a5b805 100644 --- a/regression-tests.recursor-dnssec/test_Protobuf.py +++ b/regression-tests.recursor-dnssec/test_Protobuf.py @@ -94,19 +94,19 @@ class TestRecursorProtobuf(RecursorTest): def getFirstProtobufMessage(self, retries=1, waitTime=1): msg = None - print("in getFirstProtobufMessage") + #print("in getFirstProtobufMessage") for param in protobufServersParameters: print(param.port) failed = 0 while param.queue.empty: - print(failed) - print(retries) + #print(failed) + #print(retries) if failed >= retries: break failed = failed + 1 - print("waiting") + #print("waiting") time.sleep(waitTime) self.assertFalse(param.queue.empty()) @@ -118,7 +118,7 @@ class TestRecursorProtobuf(RecursorTest): if oldmsg is not None: self.assertEqual(msg, oldmsg) - print(msg) + #print(msg) return msg def emptyProtoBufQueue(self): @@ -232,17 +232,17 @@ class TestRecursorProtobuf(RecursorTest): self.assertEqual(msg.response.appliedPolicyKind, kind) def checkProtobufTags(self, msg, tags): - print(tags) - print('---') - print(msg.response.tags) + #print(tags) + #print('---') + #print(msg.response.tags) self.assertEqual(len(msg.response.tags), len(tags)) for tag in msg.response.tags: self.assertTrue(tag in tags) def checkProtobufMetas(self, msg, metas): - print(metas) - print('---') - print(msg.meta) + #print(metas) + #print('---') + #print(msg.meta) self.assertEqual(len(msg.meta), len(metas)) for m in msg.meta: self.assertTrue(m.HasField('key')) @@ -277,7 +277,7 @@ class TestRecursorProtobuf(RecursorTest): self.assertEqual(msg.response.rcode, 65536) def checkProtobufIdentity(self, msg, requestorId, deviceId, deviceName): - print(msg) + #print(msg) self.assertTrue((requestorId == '') == (not msg.HasField('requestorId'))) self.assertTrue((deviceId == b'') == (not msg.HasField('deviceId'))) self.assertTrue((deviceName == '') == (not msg.HasField('deviceName')))