git.ipfire.org Git - thirdparty/pdns.git/blame - regression-tests.dnsdist/test_Protobuf.py
dnsdist tests: make py3k compatible and pick py3k if available
[thirdparty/pdns.git] / regression-tests.dnsdist / test_Protobuf.py
CommitLineData
1d0bd88a 1#!/usr/bin/env python
1d0bd88a
RG
2import threading
3import socket
4import struct
5import sys
6import time
b4f23783 7from dnsdisttests import DNSDistTest, Queue
1d0bd88a
RG
8
9import dns
10import dnsmessage_pb2
11
1d0bd88a 12
class TestProtobuf(DNSDistTest):
    """
    Protobuf export tests.

    dnsdist is configured (see _config_template) to send a protobuf message
    for every query (RemoteLogAction) and every response
    (RemoteLogResponseAction) to a TCP listener run by this class.  Lua
    callbacks add tags to both messages; for names under
    lua.protobuf.tests.powerdns.com. they also truncate the requestor address
    and turn the query message into a synthetic response.
    """
    _protobufServerPort = 4242
    # Raw protobuf payloads received by ProtobufListener, consumed in order
    # by getFirstProtobufMessage().
    _protobufQueue = Queue()
    _protobufCounter = 0
    # Presumably substituted, in this order, into the two '%s' placeholders
    # of _config_template by DNSDistTest (backend port, then logger port).
    _config_params = ['_testServerPort', '_protobufServerPort']
    # NOTE: this is a regular (non-raw) Python string, so any Lua escape
    # sequence must have its backslash doubled ('\\127'); a single backslash
    # would be consumed by Python as an octal escape ('\127' == 'W') before
    # dnsdist ever sees the Lua code.
    _config_template = """
    luasmn = newSuffixMatchNode()
    luasmn:add(newDNSName('lua.protobuf.tests.powerdns.com.'))

    function alterProtobufResponse(dq, protobuf)
      if luasmn:check(dq.qname) then
        requestor = newCA(dq.remoteaddr:toString())  -- called by testLuaProtobuf()
        if requestor:isIPv4() then
          requestor:truncate(24)
        else
          requestor:truncate(56)
        end
        protobuf:setRequestor(requestor)

        local tableTags = {}
        table.insert(tableTags, "TestLabel1,TestData1")
        table.insert(tableTags, "TestLabel2,TestData2")

        protobuf:setTagArray(tableTags)

        protobuf:setTag('TestLabel3,TestData3')

        protobuf:setTag("Response,456")

      else

        local tableTags = {}  -- called by testProtobuf()
        table.insert(tableTags, "TestLabel1,TestData1")
        table.insert(tableTags, "TestLabel2,TestData2")
        protobuf:setTagArray(tableTags)

        protobuf:setTag('TestLabel3,TestData3')

        protobuf:setTag("Response,456")

      end
    end

    function alterProtobufQuery(dq, protobuf)

      if luasmn:check(dq.qname) then
        requestor = newCA(dq.remoteaddr:toString())  -- called by testLuaProtobuf()
        if requestor:isIPv4() then
          requestor:truncate(24)
        else
          requestor:truncate(56)
        end
        protobuf:setRequestor(requestor)

        local tableTags = {}
        tableTags = dq:getTagArray()  -- get table from DNSQuery

        local tablePB = {}
        for k, v in pairs( tableTags) do
          table.insert(tablePB, k .. "," .. v)
        end

        protobuf:setTagArray(tablePB)  -- store table in protobuf
        protobuf:setTag("Query,123")  -- add another tag entry in protobuf

        protobuf:setResponseCode(dnsdist.NXDOMAIN)  -- set protobuf response code to be NXDOMAIN

        local strReqName = dq.qname:toString()  -- get request dns name

        protobuf:setProtobufResponseType()  -- set protobuf to look like a response and not a query, with 0 default time

        blobData = '\\127' .. '\\000' .. '\\000' .. '\\001'  -- 127.0.0.1, note: lua 5.1 can only embed decimal not hex

        protobuf:addResponseRR(strReqName, 1, 1, 123, blobData)  -- add a RR to the protobuf

        protobuf:setBytes(65)  -- set the size of the query to confirm in checkProtobufBase

      else

        local tableTags = {}  -- called by testProtobuf()
        table.insert(tableTags, "TestLabel1,TestData1")
        table.insert(tableTags, "TestLabel2,TestData2")

        protobuf:setTagArray(tableTags)
        protobuf:setTag('TestLabel3,TestData3')
        protobuf:setTag("Query,123")

      end
    end

    function alterLuaFirst(dq)  -- called when dnsdist receives new request
      local tt = {}
      tt["TestLabel1"] = "TestData1"
      tt["TestLabel2"] = "TestData2"

      dq:setTagArray(tt)

      dq:setTag("TestLabel3","TestData3")
      return DNSAction.None, ""  -- continue to the next rule
    end

    newServer{address="127.0.0.1:%s", useClientSubnet=true}
    rl = newRemoteLogger('127.0.0.1:%s')

    addAction(AllRule(), LuaAction(alterLuaFirst))  -- Add tags to DNSQuery first

    addAction(AllRule(), RemoteLogAction(rl, alterProtobufQuery))  -- Send protobuf message before lookup

    addResponseAction(AllRule(), RemoteLogResponseAction(rl, alterProtobufResponse, true))  -- Send protobuf message after lookup

    """

    @staticmethod
    def _recvExactly(conn, count):
        """
        Read exactly `count` bytes from the stream socket `conn`.

        socket.recv() may legitimately return fewer bytes than requested on
        a TCP socket, so keep reading until the full amount has arrived.
        Returns the bytes read, or None if the peer closed the connection
        before `count` bytes were received.
        """
        chunks = []
        remaining = count
        while remaining > 0:
            chunk = conn.recv(remaining)
            if not chunk:
                return None
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    @classmethod
    def ProtobufListener(cls, port):
        """
        Accept TCP connections on 127.0.0.1:`port` and push every received
        protobuf payload onto cls._protobufQueue.

        Each message is framed as a 2-byte big-endian length followed by
        that many bytes of payload.  A connection is dropped as soon as a
        frame cannot be read completely.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        try:
            sock.bind(("127.0.0.1", port))
        except socket.error as e:
            print("Error binding in the protobuf listener: %s" % str(e))
            # NOTE(review): this runs in a thread (see startResponders), where
            # sys.exit() only ends the thread; tests then fail on an empty queue.
            sys.exit(1)

        try:
            sock.listen(100)
            while True:
                (conn, _) = sock.accept()
                while True:
                    header = cls._recvExactly(conn, 2)
                    if not header:
                        break
                    (datalen,) = struct.unpack("!H", header)
                    data = cls._recvExactly(conn, datalen)
                    if not data:
                        break

                    cls._protobufQueue.put(data, True, timeout=2.0)

                conn.close()
        finally:
            sock.close()

    @classmethod
    def startResponders(cls):
        """
        Start the UDP and TCP backend responders and the protobuf listener,
        each in its own daemon thread.
        """
        cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue])
        cls._UDPResponder.daemon = True
        cls._UDPResponder.start()

        cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue])
        cls._TCPResponder.daemon = True
        cls._TCPResponder.start()

        cls._protobufListener = threading.Thread(name='Protobuf Listener', target=cls.ProtobufListener, args=[cls._protobufServerPort])
        cls._protobufListener.daemon = True
        cls._protobufListener.start()

    def getFirstProtobufMessage(self):
        """
        Pop the oldest received payload and parse it as a PBDNSMessage.

        Fails the test if no message has been received yet.
        """
        self.assertFalse(self._protobufQueue.empty())
        data = self._protobufQueue.get(False)
        self.assertTrue(data)
        msg = dnsmessage_pb2.PBDNSMessage()
        msg.ParseFromString(data)
        return msg

    def checkProtobufBase(self, msg, protocol, query, initiator, normalQueryResponse=True):
        """
        Check the fields common to protobuf queries and responses.

        msg: parsed PBDNSMessage.
        protocol: expected socketProtocol (PBDNSMessage.UDP / TCP).
        query: the dns.message the protobuf describes; its id (and, when
            normalQueryResponse is True, its wire length) must match.
        initiator: expected textual IPv4 address in the 'from' field.
        normalQueryResponse: set to False to skip the inBytes length check,
            e.g. when Lua rewrote the message and set the size itself.
        """
        self.assertTrue(msg)
        self.assertTrue(msg.HasField('timeSec'))
        self.assertTrue(msg.HasField('socketFamily'))
        self.assertEqual(msg.socketFamily, dnsmessage_pb2.PBDNSMessage.INET)
        self.assertTrue(msg.HasField('from'))
        # 'from' is a Python keyword, hence getattr()
        fromvalue = getattr(msg, 'from')
        self.assertEqual(socket.inet_ntop(socket.AF_INET, fromvalue), initiator)
        self.assertTrue(msg.HasField('socketProtocol'))
        self.assertEqual(msg.socketProtocol, protocol)
        self.assertTrue(msg.HasField('messageId'))
        self.assertTrue(msg.HasField('id'))
        self.assertEqual(msg.id, query.id)
        self.assertTrue(msg.HasField('inBytes'))
        if normalQueryResponse:
            # compare inBytes with length of query/response
            self.assertEqual(msg.inBytes, len(query.to_wire()))
        # dnsdist doesn't set the existing EDNS Subnet for now,
        # although it might be set from Lua
        # self.assertTrue(msg.HasField('originalRequestorSubnet'))
        # self.assertEqual(len(msg.originalRequestorSubnet), 4)
        # self.assertEqual(socket.inet_ntop(socket.AF_INET, msg.originalRequestorSubnet), '127.0.0.1')

    def checkProtobufQuery(self, msg, protocol, query, qclass, qtype, qname, initiator='127.0.0.1'):
        """
        Check a protobuf message of type DNSQueryType: base fields, the
        responder address and the question section (class, type, name).
        """
        self.assertEqual(msg.type, dnsmessage_pb2.PBDNSMessage.DNSQueryType)
        self.checkProtobufBase(msg, protocol, query, initiator)
        # dnsdist doesn't fill the responder field for responses
        # because it doesn't keep the information around.
        self.assertTrue(msg.HasField('to'))
        self.assertEqual(socket.inet_ntop(socket.AF_INET, msg.to), '127.0.0.1')
        self.assertTrue(msg.HasField('question'))
        self.assertTrue(msg.question.HasField('qClass'))
        self.assertEqual(msg.question.qClass, qclass)
        self.assertTrue(msg.question.HasField('qType'))
        self.assertEqual(msg.question.qType, qtype)
        self.assertTrue(msg.question.HasField('qName'))
        self.assertEqual(msg.question.qName, qname)

    def checkProtobufTags(self, tags, expectedTags):
        """
        Check that `tags` and `expectedTags` contain the same set of values
        (order and duplicates are ignored).
        """
        self.assertEqual(set(tags), set(expectedTags), "Protobuf tags don't match")

    def checkProtobufQueryConvertedToResponse(self, msg, protocol, response, initiator='127.0.0.0'):
        """
        Check a query message that the Lua callback turned into a response.

        inBytes is not compared because Lua overrides it with setBytes().
        """
        self.assertEqual(msg.type, dnsmessage_pb2.PBDNSMessage.DNSResponseType)
        # skip comparing inBytes (size of the query) with the length of the generated response
        self.checkProtobufBase(msg, protocol, response, initiator, False)
        self.assertTrue(msg.HasField('response'))
        self.assertTrue(msg.response.HasField('queryTimeSec'))

    def checkProtobufResponse(self, msg, protocol, response, initiator='127.0.0.1'):
        """Check a protobuf message of type DNSResponseType and its base fields."""
        self.assertEqual(msg.type, dnsmessage_pb2.PBDNSMessage.DNSResponseType)
        self.checkProtobufBase(msg, protocol, response, initiator)
        self.assertTrue(msg.HasField('response'))
        self.assertTrue(msg.response.HasField('queryTimeSec'))

    def checkProtobufResponseRecord(self, record, rclass, rtype, rname, rttl):
        """
        Check class, type, name and TTL of one response RR; only the
        presence of rdata is checked here.
        """
        self.assertTrue(record.HasField('class'))
        # 'class' is a Python keyword, hence getattr()
        self.assertEqual(getattr(record, 'class'), rclass)
        self.assertTrue(record.HasField('type'))
        self.assertEqual(record.type, rtype)
        self.assertTrue(record.HasField('name'))
        self.assertEqual(record.name, rname)
        self.assertTrue(record.HasField('ttl'))
        self.assertEqual(record.ttl, rttl)
        self.assertTrue(record.HasField('rdata'))

    def testProtobuf(self):
        """
        Protobuf: Send data to a protobuf server
        """
        name = 'query.protobuf.tests.powerdns.com.'

        target = 'target.protobuf.tests.powerdns.com.'
        query = dns.message.make_query(name, 'A', 'IN')
        response = dns.message.make_response(query)

        rrset = dns.rrset.from_text(name,
                                    3600,
                                    dns.rdataclass.IN,
                                    dns.rdatatype.CNAME,
                                    target)
        response.answer.append(rrset)

        rrset = dns.rrset.from_text(target,
                                    3600,
                                    dns.rdataclass.IN,
                                    dns.rdatatype.A,
                                    '127.0.0.1')
        response.answer.append(rrset)

        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
        self.assertTrue(receivedQuery)
        self.assertTrue(receivedResponse)
        receivedQuery.id = query.id
        self.assertEqual(query, receivedQuery)
        self.assertEqual(response, receivedResponse)

        # let the protobuf messages the time to get there
        time.sleep(1)

        # check the protobuf message corresponding to the UDP query
        msg = self.getFirstProtobufMessage()

        self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.UDP, query, dns.rdataclass.IN, dns.rdatatype.A, name)
        self.checkProtobufTags(msg.response.tags, [u"TestLabel1,TestData1", u"TestLabel2,TestData2", u"TestLabel3,TestData3", u"Query,123"])

        # check the protobuf message corresponding to the UDP response
        msg = self.getFirstProtobufMessage()
        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, response)
        self.checkProtobufTags(msg.response.tags, [u"TestLabel1,TestData1", u"TestLabel2,TestData2", u"TestLabel3,TestData3", u"Response,456"])
        self.assertEqual(len(msg.response.rrs), 2)
        rr = msg.response.rrs[0]
        self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.CNAME, name, 3600)
        # rdata is a protobuf bytes field: decode before comparing with a str (py3)
        self.assertEqual(rr.rdata.decode('utf-8'), target)
        rr = msg.response.rrs[1]
        self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, target, 3600)
        self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1')

        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
        self.assertTrue(receivedQuery)
        self.assertTrue(receivedResponse)
        receivedQuery.id = query.id
        self.assertEqual(query, receivedQuery)
        self.assertEqual(response, receivedResponse)

        # let the protobuf messages the time to get there
        time.sleep(1)

        # check the protobuf message corresponding to the TCP query
        msg = self.getFirstProtobufMessage()

        self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.TCP, query, dns.rdataclass.IN, dns.rdatatype.A, name)
        self.checkProtobufTags(msg.response.tags, [u"TestLabel1,TestData1", u"TestLabel2,TestData2", u"TestLabel3,TestData3", u"Query,123"])

        # check the protobuf message corresponding to the TCP response
        msg = self.getFirstProtobufMessage()
        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, response)
        self.checkProtobufTags(msg.response.tags, [u"TestLabel1,TestData1", u"TestLabel2,TestData2", u"TestLabel3,TestData3", u"Response,456"])
        self.assertEqual(len(msg.response.rrs), 2)
        rr = msg.response.rrs[0]
        self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.CNAME, name, 3600)
        self.assertEqual(rr.rdata.decode('utf-8'), target)
        rr = msg.response.rrs[1]
        self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, target, 3600)
        self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1')

    def testLuaProtobuf(self):
        """
        Protobuf: Check that the Lua callback rewrote the initiator
        """
        name = 'lua.protobuf.tests.powerdns.com.'
        query = dns.message.make_query(name, 'A', 'IN')
        response = dns.message.make_response(query)
        rrset = dns.rrset.from_text(name,
                                    3600,
                                    dns.rdataclass.IN,
                                    dns.rdatatype.A,
                                    '127.0.0.1')
        response.answer.append(rrset)

        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)

        self.assertTrue(receivedQuery)
        self.assertTrue(receivedResponse)
        receivedQuery.id = query.id
        self.assertEqual(query, receivedQuery)
        self.assertEqual(response, receivedResponse)

        # let the protobuf messages the time to get there
        time.sleep(1)

        # check the protobuf message corresponding to the UDP query
        msg = self.getFirstProtobufMessage()

        # the Lua callback truncated the requestor to a /24, hence 127.0.0.0
        self.checkProtobufQueryConvertedToResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, response, '127.0.0.0')
        self.checkProtobufTags(msg.response.tags, [u"TestLabel1,TestData1", u"TestLabel2,TestData2", u"TestLabel3,TestData3", u"Query,123"])

        # check the protobuf message corresponding to the UDP response
        msg = self.getFirstProtobufMessage()
        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, response, '127.0.0.0')
        self.checkProtobufTags(msg.response.tags, [u"TestLabel1,TestData1", u"TestLabel2,TestData2", u"TestLabel3,TestData3", u"Response,456"])
        self.assertEqual(len(msg.response.rrs), 1)
        for rr in msg.response.rrs:
            self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 3600)
            self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1')

        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
        self.assertTrue(receivedQuery)
        self.assertTrue(receivedResponse)
        receivedQuery.id = query.id
        self.assertEqual(query, receivedQuery)
        self.assertEqual(response, receivedResponse)

        # let the protobuf messages the time to get there
        time.sleep(1)

        # check the protobuf message corresponding to the TCP query
        msg = self.getFirstProtobufMessage()
        self.checkProtobufQueryConvertedToResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, response, '127.0.0.0')
        self.checkProtobufTags(msg.response.tags, [u"TestLabel1,TestData1", u"TestLabel2,TestData2", u"TestLabel3,TestData3", u"Query,123"])

        # check the protobuf message corresponding to the TCP response
        msg = self.getFirstProtobufMessage()
        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, response, '127.0.0.0')
        self.checkProtobufTags(msg.response.tags, [u"TestLabel1,TestData1", u"TestLabel2,TestData2", u"TestLabel3,TestData3", u"Response,456"])
        self.assertEqual(len(msg.response.rrs), 1)
        for rr in msg.response.rrs:
            self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 3600)
            self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1')