From: Remi Gacogne Date: Thu, 8 Sep 2016 12:05:13 +0000 (+0200) Subject: dnsdist: Add regression tests for protobuf alteration via Lua X-Git-Tag: dnsdist-1.1.0-beta2~109^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=8667904b22abbaed717bb6b1c07d4c221978ea5a;p=thirdparty%2Fpdns.git dnsdist: Add regression tests for protobuf alteration via Lua --- diff --git a/regression-tests.dnsdist/test_Protobuf.py b/regression-tests.dnsdist/test_Protobuf.py index c5722ec73b..d4f017dc6b 100644 --- a/regression-tests.dnsdist/test_Protobuf.py +++ b/regression-tests.dnsdist/test_Protobuf.py @@ -17,10 +17,25 @@ class TestProtobuf(DNSDistTest): _protobufCounter = 0 _config_params = ['_testServerPort', '_protobufServerPort'] _config_template = """ + luasmn = newSuffixMatchNode() + luasmn:add(newDNSName('lua.protobuf.tests.powerdns.com.')) + + function alterProtobuf(dq, protobuf) + if luasmn:check(dq.qname) then + requestor = newCA(dq.remoteaddr:toString()) + if requestor:isIPv4() then + requestor:truncate(24) + else + requestor:truncate(56) + end + protobuf:setRequestor(requestor) + end + end + newServer{address="127.0.0.1:%s", useClientSubnet=true} rl = newRemoteLogger('127.0.0.1:%s') - addAction(AllRule(), RemoteLogAction(rl)) - addResponseAction(AllRule(), RemoteLogResponseAction(rl)) + addAction(AllRule(), RemoteLogAction(rl, alterProtobuf)) + addResponseAction(AllRule(), RemoteLogResponseAction(rl, alterProtobuf)) """ @classmethod @@ -72,14 +87,14 @@ class TestProtobuf(DNSDistTest): msg.ParseFromString(data) return msg - def checkProtobufBase(self, msg, protocol, query): + def checkProtobufBase(self, msg, protocol, query, initiator): self.assertTrue(msg) self.assertTrue(msg.HasField('timeSec')) self.assertTrue(msg.HasField('socketFamily')) self.assertEquals(msg.socketFamily, dnsmessage_pb2.PBDNSMessage.INET) self.assertTrue(msg.HasField('from')) fromvalue = getattr(msg, 'from') - self.assertEquals(socket.inet_ntop(socket.AF_INET, fromvalue), '127.0.0.1') + self.assertEquals(socket.inet_ntop(socket.AF_INET, fromvalue), initiator) self.assertTrue(msg.HasField('socketProtocol')) self.assertEquals(msg.socketProtocol, protocol) self.assertTrue(msg.HasField('messageId')) @@ -93,9 +108,9 @@ class TestProtobuf(DNSDistTest): # self.assertEquals(len(msg.originalRequestorSubnet), 4) # self.assertEquals(socket.inet_ntop(socket.AF_INET, msg.originalRequestorSubnet), '127.0.0.1') - def checkProtobufQuery(self, msg, protocol, query, qclass, qtype, qname): + def checkProtobufQuery(self, msg, protocol, query, qclass, qtype, qname, initiator='127.0.0.1'): self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSQueryType) - self.checkProtobufBase(msg, protocol, query) + self.checkProtobufBase(msg, protocol, query, initiator) # dnsdist doesn't fill the responder field for responses # because it doesn't keep the information around. self.assertTrue(msg.HasField('to')) @@ -108,12 +123,23 @@ class TestProtobuf(DNSDistTest): self.assertTrue(msg.question.HasField('qName')) self.assertEquals(msg.question.qName, qname) - def checkProtobufResponse(self, msg, protocol, response): + def checkProtobufResponse(self, msg, protocol, response, initiator='127.0.0.1'): self.assertEquals(msg.type, dnsmessage_pb2.PBDNSMessage.DNSResponseType) - self.checkProtobufBase(msg, protocol, response) + self.checkProtobufBase(msg, protocol, response, initiator) self.assertTrue(msg.HasField('response')) self.assertTrue(msg.response.HasField('queryTimeSec')) + def checkProtobufResponseRecord(self, record, rclass, rtype, rname, rttl): + self.assertTrue(record.HasField('class')) + self.assertEquals(getattr(record, 'class'), rclass) + self.assertTrue(record.HasField('type')) + self.assertEquals(record.type, rtype) + self.assertTrue(record.HasField('name')) + self.assertEquals(record.name, rname) + self.assertTrue(record.HasField('ttl')) + self.assertEquals(record.ttl, rttl) + self.assertTrue(record.HasField('rdata')) + def testProtobuf(self): """ Protobuf: Send data to a protobuf server @@ -147,15 +173,7 @@ class TestProtobuf(DNSDistTest): self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, response) self.assertEquals(len(msg.response.rrs), 1) for rr in msg.response.rrs: - self.assertTrue(rr.HasField('class')) - self.assertEquals(getattr(rr, 'class'), dns.rdataclass.IN) - self.assertTrue(rr.HasField('type')) - self.assertEquals(rr.type, dns.rdatatype.A) - self.assertTrue(rr.HasField('name')) - self.assertEquals(rr.name, name) - self.assertTrue(rr.HasField('ttl')) - self.assertEquals(rr.ttl, 3600) - self.assertTrue(rr.HasField('rdata')) + self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 3600) self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1') (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response) @@ -177,13 +195,63 @@ class TestProtobuf(DNSDistTest): self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, response) self.assertEquals(len(msg.response.rrs), 1) for rr in msg.response.rrs: - self.assertTrue(rr.HasField('class')) - self.assertEquals(getattr(rr, 'class'), dns.rdataclass.IN) - self.assertTrue(rr.HasField('type')) - self.assertEquals(rr.type, dns.rdatatype.A) - self.assertTrue(rr.HasField('name')) - self.assertEquals(rr.name, name) - self.assertTrue(rr.HasField('ttl')) - self.assertEquals(rr.ttl, 3600) - self.assertTrue(rr.HasField('rdata')) + self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 3600) + self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1') + + def testLuaProtobuf(self): + """ + Protobuf: Check that the Lua callback rewrote the initiator + """ + name = 'lua.protobuf.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + response = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 3600, + dns.rdataclass.IN, + dns.rdatatype.A, + '127.0.0.1') + response.answer.append(rrset) + + (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = query.id + self.assertEquals(query, receivedQuery) + self.assertEquals(response, receivedResponse) + + # let the protobuf messages the time to get there + time.sleep(1) + + # check the protobuf message corresponding to the UDP query + msg = self.getFirstProtobufMessage() + self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.UDP, query, dns.rdataclass.IN, dns.rdatatype.A, name, '127.0.0.0') + + # check the protobuf message corresponding to the UDP response + msg = self.getFirstProtobufMessage() + self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, response, '127.0.0.0') + self.assertEquals(len(msg.response.rrs), 1) + for rr in msg.response.rrs: + self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 3600) + self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1') + + (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response) + self.assertTrue(receivedQuery) + self.assertTrue(receivedResponse) + receivedQuery.id = query.id + self.assertEquals(query, receivedQuery) + self.assertEquals(response, receivedResponse) + + # let the protobuf messages the time to get there + time.sleep(1) + + # check the protobuf message corresponding to the TCP query + msg = self.getFirstProtobufMessage() + self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.TCP, query, dns.rdataclass.IN, dns.rdatatype.A, name, '127.0.0.0') + + # check the protobuf message corresponding to the TCP response + msg = self.getFirstProtobufMessage() + self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, response, '127.0.0.0') + self.assertEquals(len(msg.response.rrs), 1) + for rr in msg.response.rrs: + self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 3600) self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1')