bool handleTimeoutResponseRules(const std::vector<dnsdist::rules::ResponseRuleAction>& rules, InternalQueryState& ids, const std::shared_ptr<DownstreamState>& d_ds, const std::shared_ptr<TCPQuerySender>& sender)
{
- PacketBuffer empty;
- DNSResponse dnsResponse(ids, empty, d_ds);
+ /* let's be nice and restore the original DNS header as well as we can with what we have */
+ PacketBuffer payload(sizeof(dnsheader));
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(payload, [&ids](dnsheader& header) {
+ header.id = ids.origID;
+ header.qdcount = htons(1);
+ restoreFlags(&header, ids.origFlags);
+ return true;
+ });
+ DNSResponse dnsResponse(ids, payload, d_ds);
auto protocol = dnsResponse.getProtocol();
vinfolog("Handling timeout response rules for incoming protocol = %s", protocol.toString());
# TCP query
msg = self.getFirstProtobufMessage()
self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.TCP, query, dns.rdataclass.IN, dns.rdatatype.A, name)
+
+class TestTimeoutResponseRuleProtobuf(DNSDistProtobufTest):
+
+ _yaml_config_template = """---
+binds:
+ - listen_address: "127.0.0.1:%d"
+ protocol: Do53
+
+backends:
+ - address: "127.0.0.1:%d"
+ protocol: Do53
+ health_checks:
+ mode: "up"
+
+remote_logging:
+ protobuf_loggers:
+ - name: "my-logger"
+ address: "127.0.0.1:%d"
+ timeout: 1
+
+timeout_response_rules:
+ - name: "my-rule"
+ selector:
+ type: "All"
+ action:
+ type: "RemoteLog"
+ logger_name: "my-logger"
+ server_id: "%s"
+"""
+ _dnsDistPort = pickAvailablePort()
+ _testServerPortNotListening = pickAvailablePort()
+ _testServerPort = pickAvailablePort()
+ _yaml_config_params = ['_dnsDistPort', '_testServerPortNotListening', '_protobufServerPort', '_protobufServerID']
+ _config_params = []
+
+ def testProtobuf(self):
+ """
+ Yaml: Remote logging via protobuf of timeouts
+ """
+ name = 'remote-logging-timeout.protobuf.yaml.test.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN')
+ query.flags &= ~dns.flags.RD
+ response = dns.message.make_response(query)
+ rrset = dns.rrset.from_text(name,
+ 60,
+ dns.rdataclass.IN,
+ dns.rdatatype.A,
+ '127.0.0.1')
+
+ response.answer.append(rrset)
+
+ for method in ("sendUDPQuery", "sendTCPQuery"):
+ sender = getattr(self, method)
+ (receivedQuery, receivedResponse) = sender(query, response=None, useQueue=False, timeout=2)
+ self.assertEqual(receivedQuery, None)
+ self.assertEqual(receivedResponse, None)
+
+
+ # the UDP timeout usually takes longer to be detected
+ # than the TCP one
+ gotUDP = False
+ gotTCP = False
+ waited = 0
+ while (not gotUDP or not gotTCP) and waited <= 4:
+ if self._protobufQueue.empty():
+ # let the protobuf messages the time to get there
+ time.sleep(1)
+ waited += 1
+ if self._protobufQueue.empty():
+ continue
+
+ # check the protobuf message
+ msg = self.getFirstProtobufMessage()
+ if msg.socketProtocol == dnsmessage_pb2.PBDNSMessage.UDP:
+ gotUDP = True
+ protocol = dnsmessage_pb2.PBDNSMessage.UDP
+ else:
+ gotTCP = True
+ protocol = dnsmessage_pb2.PBDNSMessage.TCP
+
+ self.assertEqual(msg.type, dnsmessage_pb2.PBDNSMessage.DNSResponseType)
+ self.checkProtobufBase(msg, protocol, response, '127.0.0.1', normalQueryResponse=False, v6=False)
+ self.assertTrue(msg.HasField('response'))
+ self.assertTrue(msg.response.HasField('queryTimeSec'))
+ self.assertTrue(msg.HasField('question'))
+ self.assertTrue(msg.question.HasField('qClass'))
+ self.assertTrue(msg.question.HasField('qType'))
+ self.assertTrue(msg.question.HasField('qName'))
+ self.assertEqual(msg.question.qName, name)
+
+ self.assertTrue(gotUDP)
+ self.assertTrue(gotTCP)