]> git.ipfire.org Git - thirdparty/pdns.git/blame - regression-tests.dnsdist/test_Protobuf.py
Merge pull request #5905 from mind04/302
[thirdparty/pdns.git] / regression-tests.dnsdist / test_Protobuf.py
CommitLineData
1d0bd88a
RG
1#!/usr/bin/env python
2import Queue
3import threading
4import socket
5import struct
6import sys
7import time
8from dnsdisttests import DNSDistTest
9
10import dns
11import dnsmessage_pb2
12
1d0bd88a 13
class TestProtobuf(DNSDistTest):
    """
    Regression tests for the protobuf messages dnsdist sends to a remote logger.

    A local TCP listener thread (ProtobufListener) collects the
    length-prefixed protobuf messages emitted by the RemoteLogAction and
    RemoteLogResponseAction rules configured in _config_template; the tests
    then pop those messages from _protobufQueue and inspect them.
    """

    # TCP port on 127.0.0.1 the protobuf listener thread binds to
    _protobufServerPort = 4242
    # serialized protobuf messages received by ProtobufListener, in arrival order
    _protobufQueue = Queue.Queue()
    # not referenced anywhere in this class
    _protobufCounter = 0
    # presumably the attributes substituted, in order, into the '%s'
    # placeholders of _config_template by DNSDistTest -- verify against the base class
    _config_params = ['_testServerPort', '_protobufServerPort']
    # NOTE: this is a regular (non-raw) Python string, so the backslashes of
    # the Lua decimal escapes used for blobData must be doubled, otherwise
    # Python would turn '\127' into the octal escape for 'W' before Lua sees it.
    _config_template = """
    luasmn = newSuffixMatchNode()
    luasmn:add(newDNSName('lua.protobuf.tests.powerdns.com.'))

    function alterProtobufResponse(dq, protobuf)
      if luasmn:check(dq.qname) then
        requestor = newCA(dq.remoteaddr:toString()) -- called by testLuaProtobuf()
        if requestor:isIPv4() then
          requestor:truncate(24)
        else
          requestor:truncate(56)
        end
        protobuf:setRequestor(requestor)

        local tableTags = {}
        table.insert(tableTags, "TestLabel1,TestData1")
        table.insert(tableTags, "TestLabel2,TestData2")

        protobuf:setTagArray(tableTags)

        protobuf:setTag('TestLabel3,TestData3')

        protobuf:setTag("Response,456")

      else

        local tableTags = {} -- called by testProtobuf()
        table.insert(tableTags, "TestLabel1,TestData1")
        table.insert(tableTags, "TestLabel2,TestData2")
        protobuf:setTagArray(tableTags)

        protobuf:setTag('TestLabel3,TestData3')

        protobuf:setTag("Response,456")

      end
    end

    function alterProtobufQuery(dq, protobuf)

      if luasmn:check(dq.qname) then
        requestor = newCA(dq.remoteaddr:toString()) -- called by testLuaProtobuf()
        if requestor:isIPv4() then
          requestor:truncate(24)
        else
          requestor:truncate(56)
        end
        protobuf:setRequestor(requestor)

        local tableTags = {}
        tableTags = dq:getTagArray() -- get table from DNSQuery

        local tablePB = {}
        for k, v in pairs( tableTags) do
          table.insert(tablePB, k .. "," .. v)
        end

        protobuf:setTagArray(tablePB) -- store table in protobuf
        protobuf:setTag("Query,123") -- add another tag entry in protobuf

        protobuf:setResponseCode(dnsdist.NXDOMAIN) -- set protobuf response code to be NXDOMAIN

        local strReqName = dq.qname:toString() -- get request dns name

        protobuf:setProtobufResponseType() -- set protobuf to look like a response and not a query, with 0 default time

        blobData = '\\127' .. '\\000' .. '\\000' .. '\\001' -- 127.0.0.1, note: lua 5.1 can only embed decimal not hex

        protobuf:addResponseRR(strReqName, 1, 1, 123, blobData) -- add a RR to the protobuf

        protobuf:setBytes(65) -- set the size of the query to confirm in checkProtobufBase

      else

        local tableTags = {} -- called by testProtobuf()
        table.insert(tableTags, "TestLabel1,TestData1")
        table.insert(tableTags, "TestLabel2,TestData2")

        protobuf:setTagArray(tableTags)
        protobuf:setTag('TestLabel3,TestData3')
        protobuf:setTag("Query,123")

      end
    end

    function alterLuaFirst(dq) -- called when dnsdist receives new request
      local tt = {}
      tt["TestLabel1"] = "TestData1"
      tt["TestLabel2"] = "TestData2"

      dq:setTagArray(tt)

      dq:setTag("TestLabel3","TestData3")
      return DNSAction.None, "" -- continue to the next rule
    end

    newServer{address="127.0.0.1:%s", useClientSubnet=true}
    rl = newRemoteLogger('127.0.0.1:%s')

    addLuaAction(AllRule(), alterLuaFirst) -- Add tags to DNSQuery first

    addAction(AllRule(), RemoteLogAction(rl, alterProtobufQuery)) -- Send protobuf message before lookup

    addResponseAction(AllRule(), RemoteLogResponseAction(rl, alterProtobufResponse, true)) -- Send protobuf message after lookup

    """

    # tags the Lua callbacks above attach to every query message
    _expectedQueryTags = [u"TestLabel1,TestData1", u"TestLabel2,TestData2", u"TestLabel3,TestData3", u"Query,123"]
    # tags the Lua callbacks above attach to every response message
    _expectedResponseTags = [u"TestLabel1,TestData1", u"TestLabel2,TestData2", u"TestLabel3,TestData3", u"Response,456"]

    @staticmethod
    def _recvExactly(conn, length):
        """
        Read exactly `length` bytes from the connected socket `conn`.

        A single recv() on a stream socket may return fewer bytes than
        requested, so keep reading until the whole amount has arrived.
        Returns the bytes read, an empty string when `length` is 0, or
        None if the peer closed the connection before `length` bytes
        were received.
        """
        chunks = []
        remaining = length
        while remaining > 0:
            chunk = conn.recv(remaining)
            if not chunk:
                return None
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    @classmethod
    def ProtobufListener(cls, port):
        """
        Accept connections on 127.0.0.1:`port` forever and queue every
        message received into cls._protobufQueue.

        Each message on the wire is a 2-byte big-endian length followed by
        that many bytes of payload. Connections are served one at a time.
        Meant to be run in a daemon thread; never returns normally.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        try:
            sock.bind(("127.0.0.1", port))
        except socket.error as e:
            print("Error binding in the protbuf listener: %s" % str(e))
            sock.close()
            sys.exit(1)

        try:
            sock.listen(100)
            while True:
                (conn, _) = sock.accept()
                try:
                    while True:
                        header = cls._recvExactly(conn, 2)
                        if not header:
                            break
                        (datalen,) = struct.unpack("!H", header)
                        data = cls._recvExactly(conn, datalen)
                        # an empty or truncated payload ends this connection
                        if not data:
                            break

                        cls._protobufQueue.put(data, True, timeout=2.0)
                finally:
                    conn.close()
        finally:
            sock.close()

    @classmethod
    def startResponders(cls):
        """
        Start the UDP and TCP backend responder threads on _testServerPort
        and the protobuf listener thread on _protobufServerPort, all as
        daemon threads.
        """
        cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort])
        cls._UDPResponder.daemon = True
        cls._UDPResponder.start()

        cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort])
        cls._TCPResponder.daemon = True
        cls._TCPResponder.start()

        cls._protobufListener = threading.Thread(name='Protobuf Listener', target=cls.ProtobufListener, args=[cls._protobufServerPort])
        cls._protobufListener.daemon = True
        cls._protobufListener.start()

    def getFirstProtobufMessage(self):
        """
        Pop the oldest queued protobuf payload and return it parsed as a
        PBDNSMessage. Fails the test if the queue is empty.
        """
        self.assertFalse(self._protobufQueue.empty())
        data = self._protobufQueue.get(False)
        self.assertTrue(data)
        msg = dnsmessage_pb2.PBDNSMessage()
        msg.ParseFromString(data)
        return msg

    def checkProtobufBase(self, msg, protocol, query, initiator, normalQueryResponse=True):
        """
        Check the fields common to query and response protobuf messages.

        msg: the parsed PBDNSMessage
        protocol: expected socketProtocol value
        query: DNS message whose id (and, optionally, wire length) must match
        initiator: expected textual IPv4 address of the 'from' field
        normalQueryResponse: when True, also require inBytes to equal the
            wire length of `query`
        """
        self.assertTrue(msg)
        self.assertTrue(msg.HasField('timeSec'))
        self.assertTrue(msg.HasField('socketFamily'))
        self.assertEqual(msg.socketFamily, dnsmessage_pb2.PBDNSMessage.INET)
        self.assertTrue(msg.HasField('from'))
        # 'from' is a Python keyword, so the field has to be read via getattr()
        fromvalue = getattr(msg, 'from')
        self.assertEqual(socket.inet_ntop(socket.AF_INET, fromvalue), initiator)
        self.assertTrue(msg.HasField('socketProtocol'))
        self.assertEqual(msg.socketProtocol, protocol)
        self.assertTrue(msg.HasField('messageId'))
        self.assertTrue(msg.HasField('id'))
        self.assertEqual(msg.id, query.id)
        self.assertTrue(msg.HasField('inBytes'))
        if normalQueryResponse:
            # compare inBytes with length of query/response
            self.assertEqual(msg.inBytes, len(query.to_wire()))
        # dnsdist doesn't set the existing EDNS Subnet for now,
        # although it might be set from Lua
        # self.assertTrue(msg.HasField('originalRequestorSubnet'))
        # self.assertEquals(len(msg.originalRequestorSubnet), 4)
        # self.assertEquals(socket.inet_ntop(socket.AF_INET, msg.originalRequestorSubnet), '127.0.0.1')

    def checkProtobufQuery(self, msg, protocol, query, qclass, qtype, qname, initiator='127.0.0.1'):
        """
        Check that `msg` is a query-type protobuf message matching `query`,
        including the responder address and the question class, type and name.
        """
        self.assertEqual(msg.type, dnsmessage_pb2.PBDNSMessage.DNSQueryType)
        self.checkProtobufBase(msg, protocol, query, initiator)
        # dnsdist doesn't fill the responder field for responses
        # because it doesn't keep the information around.
        self.assertTrue(msg.HasField('to'))
        self.assertEqual(socket.inet_ntop(socket.AF_INET, msg.to), '127.0.0.1')
        self.assertTrue(msg.HasField('question'))
        self.assertTrue(msg.question.HasField('qClass'))
        self.assertEqual(msg.question.qClass, qclass)
        self.assertTrue(msg.question.HasField('qType'))
        # compare the type against qType (not qClass)
        self.assertEqual(msg.question.qType, qtype)
        self.assertTrue(msg.question.HasField('qName'))
        self.assertEqual(msg.question.qName, qname)

    def checkProtobufTags(self, tags, expectedTags):
        """
        Check that `tags` and `expectedTags` hold the same set of values;
        ordering and duplicates are ignored.
        """
        self.assertEqual(set(tags), set(expectedTags), "Protobuf tags don't match")

    def checkProtobufQueryConvertedToResponse(self, msg, protocol, response, initiator='127.0.0.0'):
        """
        Check a message that was emitted for a query but rewritten to a
        response-type message.
        """
        self.assertEqual(msg.type, dnsmessage_pb2.PBDNSMessage.DNSResponseType)
        # skip comparing inBytes (size of the query) with the length of the generated response
        self.checkProtobufBase(msg, protocol, response, initiator, False)
        self.assertTrue(msg.HasField('response'))
        self.assertTrue(msg.response.HasField('queryTimeSec'))

    def checkProtobufResponse(self, msg, protocol, response, initiator='127.0.0.1'):
        """
        Check that `msg` is a response-type protobuf message matching `response`.
        """
        self.assertEqual(msg.type, dnsmessage_pb2.PBDNSMessage.DNSResponseType)
        self.checkProtobufBase(msg, protocol, response, initiator)
        self.assertTrue(msg.HasField('response'))
        self.assertTrue(msg.response.HasField('queryTimeSec'))

    def checkProtobufResponseRecord(self, record, rclass, rtype, rname, rttl):
        """
        Check the class, type, name and TTL of one response record and that
        it carries rdata (the rdata content itself is left to the caller).
        """
        self.assertTrue(record.HasField('class'))
        # 'class' is a Python keyword, so the field has to be read via getattr()
        self.assertEqual(getattr(record, 'class'), rclass)
        self.assertTrue(record.HasField('type'))
        self.assertEqual(record.type, rtype)
        self.assertTrue(record.HasField('name'))
        self.assertEqual(record.name, rname)
        self.assertTrue(record.HasField('ttl'))
        self.assertEqual(record.ttl, rttl)
        self.assertTrue(record.HasField('rdata'))

    def _exchangeAndWait(self, sender, query, response):
        """
        Send `query` with `sender` (self.sendUDPQuery or self.sendTCPQuery),
        check that the query and response came through unchanged, then wait
        for the protobuf messages to reach the listener.
        """
        (receivedQuery, receivedResponse) = sender(query, response)
        self.assertTrue(receivedQuery)
        self.assertTrue(receivedResponse)
        receivedQuery.id = query.id
        self.assertEqual(query, receivedQuery)
        self.assertEqual(response, receivedResponse)

        # let the protobuf messages the time to get there
        time.sleep(1)

    def testProtobuf(self):
        """
        Protobuf: Send data to a protobuf server
        """
        name = 'query.protobuf.tests.powerdns.com.'
        target = 'target.protobuf.tests.powerdns.com.'
        query = dns.message.make_query(name, 'A', 'IN')
        response = dns.message.make_response(query)

        rrset = dns.rrset.from_text(name,
                                    3600,
                                    dns.rdataclass.IN,
                                    dns.rdatatype.CNAME,
                                    target)
        response.answer.append(rrset)

        rrset = dns.rrset.from_text(target,
                                    3600,
                                    dns.rdataclass.IN,
                                    dns.rdatatype.A,
                                    '127.0.0.1')
        response.answer.append(rrset)

        # same scenario over UDP first, then over TCP
        transports = (
            (self.sendUDPQuery, dnsmessage_pb2.PBDNSMessage.UDP),
            (self.sendTCPQuery, dnsmessage_pb2.PBDNSMessage.TCP),
        )
        for sender, protocol in transports:
            self._exchangeAndWait(sender, query, response)

            # check the protobuf message corresponding to the query
            msg = self.getFirstProtobufMessage()
            self.checkProtobufQuery(msg, protocol, query, dns.rdataclass.IN, dns.rdatatype.A, name)
            self.checkProtobufTags(msg.response.tags, self._expectedQueryTags)

            # check the protobuf message corresponding to the response
            msg = self.getFirstProtobufMessage()
            self.checkProtobufResponse(msg, protocol, response)
            self.checkProtobufTags(msg.response.tags, self._expectedResponseTags)
            self.assertEqual(len(msg.response.rrs), 2)
            rr = msg.response.rrs[0]
            self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.CNAME, name, 3600)
            self.assertEqual(rr.rdata, target)
            rr = msg.response.rrs[1]
            self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, target, 3600)
            self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1')

    def testLuaProtobuf(self):
        """
        Protobuf: Check that the Lua callback rewrote the initiator
        """
        name = 'lua.protobuf.tests.powerdns.com.'
        query = dns.message.make_query(name, 'A', 'IN')
        response = dns.message.make_response(query)
        rrset = dns.rrset.from_text(name,
                                    3600,
                                    dns.rdataclass.IN,
                                    dns.rdatatype.A,
                                    '127.0.0.1')
        response.answer.append(rrset)

        # same scenario over UDP first, then over TCP
        transports = (
            (self.sendUDPQuery, dnsmessage_pb2.PBDNSMessage.UDP),
            (self.sendTCPQuery, dnsmessage_pb2.PBDNSMessage.TCP),
        )
        for sender, protocol in transports:
            self._exchangeAndWait(sender, query, response)

            # check the protobuf message corresponding to the query
            msg = self.getFirstProtobufMessage()
            self.checkProtobufQueryConvertedToResponse(msg, protocol, response, '127.0.0.0')
            self.checkProtobufTags(msg.response.tags, self._expectedQueryTags)

            # check the protobuf message corresponding to the response
            msg = self.getFirstProtobufMessage()
            self.checkProtobufResponse(msg, protocol, response, '127.0.0.0')
            self.checkProtobufTags(msg.response.tags, self._expectedResponseTags)
            self.assertEqual(len(msg.response.rrs), 1)
            for rr in msg.response.rrs:
                self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 3600)
                self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1')