]> git.ipfire.org Git - thirdparty/pdns.git/blob - regression-tests.dnsdist/test_Protobuf.py
Merge pull request #11417 from nils-wisiol/benchmark-2048bit
[thirdparty/pdns.git] / regression-tests.dnsdist / test_Protobuf.py
1 #!/usr/bin/env python
2 import threading
3 import socket
4 import struct
5 import sys
6 import time
7 from dnsdisttests import DNSDistTest, Queue
8
9 import dns
10 import dnsmessage_pb2
11
class DNSDistProtobufTest(DNSDistTest):
    """
    Common scaffolding for the dnsdist protobuf (remote logger) tests.

    Besides the UDP and TCP test responders, it runs a TCP listener on
    _protobufServerPort that pushes every protobuf message received from
    dnsdist onto _protobufQueue, and it provides assertion helpers for the
    decoded dnsmessage_pb2.PBDNSMessage objects.
    """
    _protobufServerPort = 4242
    _protobufQueue = Queue()
    _protobufServerID = 'dnsdist-server-1'
    _protobufCounter = 0

    @staticmethod
    def _recvExactly(conn, count):
        """
        Read exactly `count` bytes from the stream socket `conn`.

        recv() on a TCP socket may legitimately return fewer bytes than
        requested, so keep reading until the buffer is complete.
        Returns the bytes read, or None if the peer closed the connection
        before `count` bytes arrived.
        """
        chunks = []
        remaining = count
        while remaining > 0:
            chunk = conn.recv(remaining)
            if not chunk:
                return None
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    @classmethod
    def ProtobufListener(cls, port):
        """
        Background thread body: accept connections on 127.0.0.1:`port` and
        push each received protobuf message (raw bytes) onto
        cls._protobufQueue.

        Every message on the stream is preceded by a 2-byte big-endian
        length. The thread registers itself in cls._backgroundThreads and,
        on each 1-second accept() timeout, exits once its entry there is no
        longer truthy.
        """
        threadID = threading.get_native_id()
        cls._backgroundThreads[threadID] = True
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
            try:
                sock.bind(("127.0.0.1", port))
            except socket.error as e:
                print("Error binding in the protbuf listener: %s" % str(e))
                sys.exit(1)

            sock.listen(100)
            sock.settimeout(1.0)
            while True:
                try:
                    (conn, _) = sock.accept()
                except socket.timeout:
                    # Nobody connected: check whether we were asked to stop.
                    if not cls._backgroundThreads.get(threadID, False):
                        cls._backgroundThreads.pop(threadID, None)
                        break
                    continue

                with conn:
                    while True:
                        header = cls._recvExactly(conn, 2)
                        if not header:
                            break
                        (datalen,) = struct.unpack("!H", header)
                        data = cls._recvExactly(conn, datalen)
                        if not data:
                            break

                        cls._protobufQueue.put(data, True, timeout=2.0)

    @classmethod
    def startResponders(cls):
        """
        Start the UDP and TCP DNS responders and the protobuf listener as
        daemon threads.
        """
        cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue], daemon=True)
        cls._UDPResponder.start()

        cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue], daemon=True)
        cls._TCPResponder.start()

        cls._protobufListener = threading.Thread(name='Protobuf Listener', target=cls.ProtobufListener, args=[cls._protobufServerPort], daemon=True)
        cls._protobufListener.start()

    def getFirstProtobufMessage(self):
        """
        Take the next message off the protobuf queue without waiting and
        return it parsed as a dnsmessage_pb2.PBDNSMessage.

        Fails the test if the queue is empty or the payload is empty.
        """
        self.assertFalse(self._protobufQueue.empty())
        data = self._protobufQueue.get(False)
        self.assertTrue(data)
        msg = dnsmessage_pb2.PBDNSMessage()
        msg.ParseFromString(data)
        return msg

    def checkProtobufBase(self, msg, protocol, query, initiator, normalQueryResponse=True):
        """
        Assertions shared by query and response protobuf messages.

        Checks the timestamp, the IPv4 source address (`initiator`), the
        transport `protocol`, the message ID, the DNS ID (must equal
        `query.id`), the byte count and the server identity. When
        `normalQueryResponse` is True, `inBytes` must also equal the wire
        length of `query`.
        """
        self.assertTrue(msg)
        self.assertTrue(msg.HasField('timeSec'))
        self.assertTrue(msg.HasField('socketFamily'))
        self.assertEqual(msg.socketFamily, dnsmessage_pb2.PBDNSMessage.INET)
        self.assertTrue(msg.HasField('from'))
        # 'from' is a Python keyword, hence getattr()
        fromvalue = getattr(msg, 'from')
        self.assertEqual(socket.inet_ntop(socket.AF_INET, fromvalue), initiator)
        self.assertTrue(msg.HasField('socketProtocol'))
        self.assertEqual(msg.socketProtocol, protocol)
        self.assertTrue(msg.HasField('messageId'))
        self.assertTrue(msg.HasField('id'))
        self.assertEqual(msg.id, query.id)
        self.assertTrue(msg.HasField('inBytes'))
        self.assertTrue(msg.HasField('serverIdentity'))
        self.assertEqual(msg.serverIdentity, self._protobufServerID.encode('utf-8'))

        if normalQueryResponse:
            # compare inBytes with length of query/response
            self.assertEqual(msg.inBytes, len(query.to_wire()))
        # dnsdist doesn't set originalRequestorSubnet from the existing EDNS
        # Client Subnet for now (although it might be set from Lua), so it
        # is intentionally not checked here.

    def checkProtobufQuery(self, msg, protocol, query, qclass, qtype, qname, initiator='127.0.0.1'):
        """
        Assert that `msg` is a query message for (`qname`, `qtype`,
        `qclass`) received over `protocol` from `initiator`.
        """
        self.assertEqual(msg.type, dnsmessage_pb2.PBDNSMessage.DNSQueryType)
        self.checkProtobufBase(msg, protocol, query, initiator)
        # The responder ('to') field is only checked for queries: dnsdist
        # doesn't fill it for responses because it doesn't keep the
        # information around.
        self.assertTrue(msg.HasField('to'))
        self.assertEqual(socket.inet_ntop(socket.AF_INET, msg.to), '127.0.0.1')
        self.assertTrue(msg.HasField('question'))
        self.assertTrue(msg.question.HasField('qClass'))
        self.assertEqual(msg.question.qClass, qclass)
        self.assertTrue(msg.question.HasField('qType'))
        self.assertEqual(msg.question.qType, qtype)
        self.assertTrue(msg.question.HasField('qName'))
        self.assertEqual(msg.question.qName, qname)

    def checkProtobufTags(self, tags, expectedTags):
        """
        Assert that `tags` and `expectedTags` hold the same tags, ignoring
        order and duplicates.
        """
        self.assertEqual(set(tags), set(expectedTags), "Protobuf tags don't match")

    def checkProtobufQueryConvertedToResponse(self, msg, protocol, response, initiator='127.0.0.0'):
        """
        Assert that `msg` is a query message that a Lua callback turned
        into a response-type message.
        """
        self.assertEqual(msg.type, dnsmessage_pb2.PBDNSMessage.DNSResponseType)
        # skip comparing inBytes (size of the query) with the length of the generated response
        self.checkProtobufBase(msg, protocol, response, initiator, False)
        self.assertTrue(msg.HasField('response'))
        self.assertTrue(msg.response.HasField('queryTimeSec'))

    def checkProtobufResponse(self, msg, protocol, response, initiator='127.0.0.1'):
        """
        Assert that `msg` is a response message for `response`, received
        over `protocol` for a query from `initiator`.
        """
        self.assertEqual(msg.type, dnsmessage_pb2.PBDNSMessage.DNSResponseType)
        self.checkProtobufBase(msg, protocol, response, initiator)
        self.assertTrue(msg.HasField('response'))
        self.assertTrue(msg.response.HasField('queryTimeSec'))

    def checkProtobufResponseRecord(self, record, rclass, rtype, rname, rttl):
        """
        Assert that protobuf resource `record` has the given class, type,
        name and TTL, and carries rdata (the rdata itself is not compared).
        """
        self.assertTrue(record.HasField('class'))
        # 'class' is a Python keyword, hence getattr()
        self.assertEqual(getattr(record, 'class'), rclass)
        self.assertTrue(record.HasField('type'))
        self.assertEqual(record.type, rtype)
        self.assertTrue(record.HasField('name'))
        self.assertEqual(record.name, rname)
        self.assertTrue(record.HasField('ttl'))
        self.assertEqual(record.ttl, rttl)
        self.assertTrue(record.HasField('rdata'))
148
class TestProtobuf(DNSDistProtobufTest):
    """
    Protobuf export driven by Lua callbacks.

    alterLuaFirst() attaches three tags to every query. For names under
    lua.protobuf.tests.powerdns.com. the protobuf callbacks truncate the
    requestor address (/24 for IPv4, /56 otherwise) and the query callback
    turns the query message into a response-type message; for every other
    name they only attach a fixed set of tags.
    """
    # _protobufServerID is listed twice: the template uses it for both the
    # query and the response remote-log actions.
    _config_params = ['_testServerPort', '_protobufServerPort', '_protobufServerID', '_protobufServerID']
    # NOTE: the Lua decimal escapes in blobData are written as '\\127' etc.
    # so that Python hands them to Lua verbatim. In this non-raw Python
    # string a bare '\127' would be consumed by Python as an octal escape
    # ('W', i.e. 87), and '\000' would become a raw NUL byte in the config,
    # giving 87.0.0.2 instead of the intended 127.0.0.2.
    _config_template = """
    luasmn = newSuffixMatchNode()
    luasmn:add(newDNSName('lua.protobuf.tests.powerdns.com.'))

    function alterProtobufResponse(dq, protobuf)
      if luasmn:check(dq.qname) then
        requestor = newCA(tostring(dq.remoteaddr)) -- called by testLuaProtobuf()
        if requestor:isIPv4() then
          requestor:truncate(24)
        else
          requestor:truncate(56)
        end
        protobuf:setRequestor(requestor)

        local tableTags = {}
        table.insert(tableTags, "TestLabel1,TestData1")
        table.insert(tableTags, "TestLabel2,TestData2")

        protobuf:setTagArray(tableTags)

        protobuf:setTag('TestLabel3,TestData3')

        protobuf:setTag("Response,456")

      else

        local tableTags = {} -- called by testProtobuf()
        table.insert(tableTags, "TestLabel1,TestData1")
        table.insert(tableTags, "TestLabel2,TestData2")
        protobuf:setTagArray(tableTags)

        protobuf:setTag('TestLabel3,TestData3')

        protobuf:setTag("Response,456")

      end
    end

    function alterProtobufQuery(dq, protobuf)

      if luasmn:check(dq.qname) then
        requestor = newCA(tostring(dq.remoteaddr)) -- called by testLuaProtobuf()
        if requestor:isIPv4() then
          requestor:truncate(24)
        else
          requestor:truncate(56)
        end
        protobuf:setRequestor(requestor)

        local tableTags = {}
        tableTags = dq:getTagArray() -- get table from DNSQuery

        local tablePB = {}
        for k, v in pairs( tableTags) do
          table.insert(tablePB, k .. "," .. v)
        end

        protobuf:setTagArray(tablePB) -- store table in protobuf
        protobuf:setTag("Query,123") -- add another tag entry in protobuf

        protobuf:setResponseCode(DNSRCode.NXDOMAIN) -- set protobuf response code to be NXDOMAIN

        local strReqName = tostring(dq.qname) -- get request dns name

        protobuf:setProtobufResponseType() -- set protobuf to look like a response and not a query, with 0 default time

        blobData = '\\127' .. '\\000' .. '\\000' .. '\\002' -- 127.0.0.2, note: lua 5.1 can only embed decimal not hex

        protobuf:addResponseRR(strReqName, 1, 1, 123, blobData) -- add a RR to the protobuf

        protobuf:setBytes(65) -- set the size of the query to confirm in checkProtobufBase

      else

        local tableTags = {} -- called by testProtobuf()
        table.insert(tableTags, "TestLabel1,TestData1")
        table.insert(tableTags, "TestLabel2,TestData2")

        protobuf:setTagArray(tableTags)
        protobuf:setTag('TestLabel3,TestData3')
        protobuf:setTag("Query,123")

      end
    end

    function alterLuaFirst(dq) -- called when dnsdist receives new request
      local tt = {}
      tt["TestLabel1"] = "TestData1"
      tt["TestLabel2"] = "TestData2"

      dq:setTagArray(tt)

      dq:setTag("TestLabel3","TestData3")
      return DNSAction.None, "" -- continue to the next rule
    end

    newServer{address="127.0.0.1:%s", useClientSubnet=true}
    rl = newRemoteLogger('127.0.0.1:%s')

    addAction(AllRule(), LuaAction(alterLuaFirst)) -- Add tags to DNSQuery first

    addAction(AllRule(), RemoteLogAction(rl, alterProtobufQuery, {serverID='%s'})) -- Send protobuf message before lookup

    addResponseAction(AllRule(), RemoteLogResponseAction(rl, alterProtobufResponse, true, {serverID='%s'})) -- Send protobuf message after lookup

    """

    # Tags the Lua callbacks above attach to query / response messages.
    _expectedQueryTags = ["TestLabel1,TestData1", "TestLabel2,TestData2", "TestLabel3,TestData3", "Query,123"]
    _expectedResponseTags = ["TestLabel1,TestData1", "TestLabel2,TestData2", "TestLabel3,TestData3", "Response,456"]

    def _transports(self):
        """
        Return the (send method, expected protobuf socketProtocol) pairs
        exercised by each test: UDP first, then TCP.
        """
        return ((self.sendUDPQuery, dnsmessage_pb2.PBDNSMessage.UDP),
                (self.sendTCPQuery, dnsmessage_pb2.PBDNSMessage.TCP))

    def _exchangeAndWait(self, sendQuery, query, response):
        """
        Send `query` with `sendQuery` (passing `response` as the backend's
        answer). Check that the query and the response went through
        unchanged apart from the query ID, then give the protobuf messages
        time to reach the listener.
        """
        (receivedQuery, receivedResponse) = sendQuery(query, response)
        self.assertTrue(receivedQuery)
        self.assertTrue(receivedResponse)
        receivedQuery.id = query.id
        self.assertEqual(query, receivedQuery)
        self.assertEqual(response, receivedResponse)

        # let the protobuf messages the time to get there
        time.sleep(1)

    def testProtobuf(self):
        """
        Protobuf: Send data to a protobuf server
        """
        name = 'query.protobuf.tests.powerdns.com.'

        target = 'target.protobuf.tests.powerdns.com.'
        query = dns.message.make_query(name, 'A', 'IN')
        response = dns.message.make_response(query)

        rrset = dns.rrset.from_text(name,
                                    3600,
                                    dns.rdataclass.IN,
                                    dns.rdatatype.CNAME,
                                    target)
        response.answer.append(rrset)

        rrset = dns.rrset.from_text(target,
                                    3600,
                                    dns.rdataclass.IN,
                                    dns.rdatatype.A,
                                    '127.0.0.1')
        response.answer.append(rrset)

        for sendQuery, protocol in self._transports():
            self._exchangeAndWait(sendQuery, query, response)

            # check the protobuf message corresponding to the query
            msg = self.getFirstProtobufMessage()
            self.checkProtobufQuery(msg, protocol, query, dns.rdataclass.IN, dns.rdatatype.A, name)
            self.checkProtobufTags(msg.response.tags, self._expectedQueryTags)

            # check the protobuf message corresponding to the response
            msg = self.getFirstProtobufMessage()
            self.checkProtobufResponse(msg, protocol, response)
            self.checkProtobufTags(msg.response.tags, self._expectedResponseTags)
            self.assertEqual(len(msg.response.rrs), 2)
            rr = msg.response.rrs[0]
            self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.CNAME, name, 3600)
            self.assertEqual(rr.rdata.decode('utf-8'), target)
            rr = msg.response.rrs[1]
            self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, target, 3600)
            self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1')

    def testLuaProtobuf(self):
        """
        Protobuf: Check that the Lua callback rewrote the initiator
        """
        name = 'lua.protobuf.tests.powerdns.com.'
        query = dns.message.make_query(name, 'A', 'IN')
        response = dns.message.make_response(query)
        rrset = dns.rrset.from_text(name,
                                    3600,
                                    dns.rdataclass.IN,
                                    dns.rdatatype.A,
                                    '127.0.0.1')
        response.answer.append(rrset)

        # The Lua callbacks truncate 127.0.0.1 to its /24, hence 127.0.0.0.
        for sendQuery, protocol in self._transports():
            self._exchangeAndWait(sendQuery, query, response)

            # check the protobuf message corresponding to the query
            msg = self.getFirstProtobufMessage()
            self.checkProtobufQueryConvertedToResponse(msg, protocol, response, '127.0.0.0')
            self.checkProtobufTags(msg.response.tags, self._expectedQueryTags)

            # check the protobuf message corresponding to the response
            msg = self.getFirstProtobufMessage()
            self.checkProtobufResponse(msg, protocol, response, '127.0.0.0')
            self.checkProtobufTags(msg.response.tags, self._expectedResponseTags)
            self.assertEqual(len(msg.response.rrs), 1)
            for rr in msg.response.rrs:
                self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 3600)
                self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1')
404
class TestProtobufIPCipher(DNSDistProtobufTest):
    """
    Protobuf export with the client address pseudonymized by ipcipher
    (ipEncryptKey) on both the query and the response remote-log actions.
    """
    # _protobufServerID appears twice: once per remote-log action.
    _config_params = ['_testServerPort', '_protobufServerPort', '_protobufServerID', '_protobufServerID']
    _config_template = """
    newServer{address="127.0.0.1:%s", useClientSubnet=true}
    key = makeIPCipherKey("some 16-byte key")
    rl = newRemoteLogger('127.0.0.1:%s')
    addAction(AllRule(), RemoteLogAction(rl, nil, {serverID='%s', ipEncryptKey=key})) -- Send protobuf message before lookup
    addResponseAction(AllRule(), RemoteLogResponseAction(rl, nil, true, {serverID='%s', ipEncryptKey=key})) -- Send protobuf message after lookup

    """

    def testProtobuf(self):
        """
        Protobuf: Send data to a protobuf server, with pseudonymization
        """
        # 108.41.239.98 is 127.0.0.1 pseudonymized with ipcipher and the current key
        pseudonym = '108.41.239.98'
        name = 'query.protobuf-ipcipher.tests.powerdns.com.'
        target = 'target.protobuf-ipcipher.tests.powerdns.com.'

        query = dns.message.make_query(name, 'A', 'IN')
        response = dns.message.make_response(query)
        response.answer.append(dns.rrset.from_text(name, 3600, dns.rdataclass.IN, dns.rdatatype.CNAME, target))
        response.answer.append(dns.rrset.from_text(target, 3600, dns.rdataclass.IN, dns.rdatatype.A, '127.0.0.1'))

        # Same exchange and checks over UDP first, then TCP.
        exchanges = [
            (self.sendUDPQuery, dnsmessage_pb2.PBDNSMessage.UDP),
            (self.sendTCPQuery, dnsmessage_pb2.PBDNSMessage.TCP),
        ]
        for sender, transport in exchanges:
            gotQuery, gotResponse = sender(query, response)
            self.assertTrue(gotQuery)
            self.assertTrue(gotResponse)
            gotQuery.id = query.id
            self.assertEqual(query, gotQuery)
            self.assertEqual(response, gotResponse)

            # let the protobuf messages the time to get there
            time.sleep(1)

            # first message: the query, logged with the pseudonymized source
            queryMsg = self.getFirstProtobufMessage()
            self.checkProtobufQuery(queryMsg, transport, query, dns.rdataclass.IN, dns.rdatatype.A, name, pseudonym)

            # second message: the response, also with the pseudonymized source
            responseMsg = self.getFirstProtobufMessage()
            self.checkProtobufResponse(responseMsg, transport, response, pseudonym)

            records = responseMsg.response.rrs
            self.assertEqual(len(records), 2)
            cnameRecord, addressRecord = records[0], records[1]
            self.checkProtobufResponseRecord(cnameRecord, dns.rdataclass.IN, dns.rdatatype.CNAME, name, 3600)
            self.assertEqual(cnameRecord.rdata.decode('ascii'), target)
            self.checkProtobufResponseRecord(addressRecord, dns.rdataclass.IN, dns.rdatatype.A, target, 3600)
            self.assertEqual(socket.inet_ntop(socket.AF_INET, addressRecord.rdata), '127.0.0.1')