]> git.ipfire.org Git - thirdparty/pdns.git/blob - regression-tests.dnsdist/test_Protobuf.py
Merge pull request #6952 from phonedph1/root-nx
[thirdparty/pdns.git] / regression-tests.dnsdist / test_Protobuf.py
1 #!/usr/bin/env python
2 import threading
3 import socket
4 import struct
5 import sys
6 import time
7 from dnsdisttests import DNSDistTest, Queue
8
9 import dns
10 import dnsmessage_pb2
11
12
class TestProtobuf(DNSDistTest):
    """
    Regression tests for dnsdist's protobuf remote logging.

    The dnsdist configuration below logs every query (RemoteLogAction) and
    every response (RemoteLogResponseAction) to a remote logger pointing at a
    TCP listener run by this class. Lua hooks decorate those messages with
    tags and, for names under lua.protobuf.tests.powerdns.com., rewrite the
    requestor and turn the query message into a synthetic response.
    """
    _protobufServerPort = 4242
    _protobufQueue = Queue()
    _protobufServerID = 'dnsdist-server-1'
    _protobufCounter = 0
    # Presumably substituted, in order, into the '%s' placeholders of
    # _config_template by DNSDistTest: backend port, remote logger port, then
    # the serverID of the query action and of the response action.
    _config_params = ['_testServerPort', '_protobufServerPort', '_protobufServerID', '_protobufServerID']
    # NOTE: this is a regular (non-raw) Python string. Lua escape sequences
    # therefore need their backslash doubled: a bare '\127' would be decoded
    # by Python as the octal escape chr(87) ('W') before Lua ever saw it.
    _config_template = """
    luasmn = newSuffixMatchNode()
    luasmn:add(newDNSName('lua.protobuf.tests.powerdns.com.'))

    function alterProtobufResponse(dq, protobuf)
      if luasmn:check(dq.qname) then
        requestor = newCA(dq.remoteaddr:toString())    -- called by testLuaProtobuf()
        if requestor:isIPv4() then
          requestor:truncate(24)
        else
          requestor:truncate(56)
        end
        protobuf:setRequestor(requestor)

        local tableTags = {}
        table.insert(tableTags, "TestLabel1,TestData1")
        table.insert(tableTags, "TestLabel2,TestData2")

        protobuf:setTagArray(tableTags)

        protobuf:setTag('TestLabel3,TestData3')

        protobuf:setTag("Response,456")

      else

        local tableTags = {}    -- called by testProtobuf()
        table.insert(tableTags, "TestLabel1,TestData1")
        table.insert(tableTags, "TestLabel2,TestData2")
        protobuf:setTagArray(tableTags)

        protobuf:setTag('TestLabel3,TestData3')

        protobuf:setTag("Response,456")

      end
    end

    function alterProtobufQuery(dq, protobuf)

      if luasmn:check(dq.qname) then
        requestor = newCA(dq.remoteaddr:toString())    -- called by testLuaProtobuf()
        if requestor:isIPv4() then
          requestor:truncate(24)
        else
          requestor:truncate(56)
        end
        protobuf:setRequestor(requestor)

        local tableTags = {}
        tableTags = dq:getTagArray()    -- get table from DNSQuery

        local tablePB = {}
        for k, v in pairs( tableTags) do
          table.insert(tablePB, k .. "," .. v)
        end

        protobuf:setTagArray(tablePB)    -- store table in protobuf
        protobuf:setTag("Query,123")    -- add another tag entry in protobuf

        protobuf:setResponseCode(dnsdist.NXDOMAIN)    -- set protobuf response code to be NXDOMAIN

        local strReqName = dq.qname:toString()    -- get request dns name

        protobuf:setProtobufResponseType()    -- set protobuf to look like a response and not a query, with 0 default time

        blobData = '\\127' .. '\\000' .. '\\000' .. '\\001'    -- 127.0.0.1, note: lua 5.1 can only embed decimal not hex

        protobuf:addResponseRR(strReqName, 1, 1, 123, blobData)    -- add a RR to the protobuf

        protobuf:setBytes(65)    -- set the size of the query to confirm in checkProtobufBase

      else

        local tableTags = {}    -- called by testProtobuf()
        table.insert(tableTags, "TestLabel1,TestData1")
        table.insert(tableTags, "TestLabel2,TestData2")

        protobuf:setTagArray(tableTags)
        protobuf:setTag('TestLabel3,TestData3')
        protobuf:setTag("Query,123")

      end
    end

    function alterLuaFirst(dq)    -- called when dnsdist receives new request
      local tt = {}
      tt["TestLabel1"] = "TestData1"
      tt["TestLabel2"] = "TestData2"

      dq:setTagArray(tt)

      dq:setTag("TestLabel3","TestData3")
      return DNSAction.None, ""    -- continue to the next rule
    end

    newServer{address="127.0.0.1:%s", useClientSubnet=true}
    rl = newRemoteLogger('127.0.0.1:%s')

    addAction(AllRule(), LuaAction(alterLuaFirst))    -- Add tags to DNSQuery first

    addAction(AllRule(), RemoteLogAction(rl, alterProtobufQuery, {serverID='%s'}))    -- Send protobuf message before lookup

    addResponseAction(AllRule(), RemoteLogResponseAction(rl, alterProtobufResponse, true, {serverID='%s'}))    -- Send protobuf message after lookup

    """

    @staticmethod
    def _recvExactly(conn, length):
        """
        Read exactly `length` bytes from the connected socket `conn`.

        TCP is a byte stream, so a single recv() may return fewer bytes than
        requested. This loops until `length` bytes have been collected.

        Returns the bytes read, or None if the peer closed the connection
        before `length` bytes arrived. A `length` of 0 yields b''.
        """
        chunks = []
        remaining = length
        while remaining > 0:
            chunk = conn.recv(remaining)
            if not chunk:
                return None
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    @classmethod
    def ProtobufListener(cls, port):
        """
        Serve 127.0.0.1:`port` over TCP and enqueue every received protobuf
        payload on cls._protobufQueue.

        Each message on the wire is a 2-byte big-endian length followed by
        that many payload bytes. Connections are handled one at a time, and
        the accept loop never returns. If the port cannot be bound, an error
        is printed and sys.exit(1) is called. When run as a thread target,
        as in startResponders, that ends only the listener thread.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        try:
            sock.bind(("127.0.0.1", port))
        except socket.error as e:
            print("Error binding in the protobuf listener: %s" % str(e))
            sys.exit(1)

        sock.listen(100)
        try:
            while True:
                (conn, _) = sock.accept()
                try:
                    while True:
                        # Read the full header and payload. A short read would
                        # otherwise crash struct.unpack or desynchronise the
                        # length-prefixed framing for every later message.
                        header = cls._recvExactly(conn, 2)
                        if not header:
                            break
                        (datalen,) = struct.unpack("!H", header)
                        data = cls._recvExactly(conn, datalen)
                        if not data:
                            break

                        cls._protobufQueue.put(data, True, timeout=2.0)
                finally:
                    conn.close()
        finally:
            sock.close()

    @classmethod
    def startResponders(cls):
        """
        Start the UDP and TCP DNS responders plus the protobuf listener,
        each in its own daemon thread.
        """
        cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue])
        cls._UDPResponder.daemon = True
        cls._UDPResponder.start()

        cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue])
        cls._TCPResponder.daemon = True
        cls._TCPResponder.start()

        cls._protobufListener = threading.Thread(name='Protobuf Listener', target=cls.ProtobufListener, args=[cls._protobufServerPort])
        cls._protobufListener.daemon = True
        cls._protobufListener.start()

    def getFirstProtobufMessage(self):
        """
        Pop the oldest payload from the protobuf queue and parse it.

        Fails the test if the queue is empty or the payload is empty.
        Returns a dnsmessage_pb2.PBDNSMessage.
        """
        self.assertFalse(self._protobufQueue.empty())
        data = self._protobufQueue.get(False)
        self.assertTrue(data)
        msg = dnsmessage_pb2.PBDNSMessage()
        msg.ParseFromString(data)
        return msg

    def checkProtobufBase(self, msg, protocol, query, initiator, normalQueryResponse=True):
        """
        Assert the fields common to query and response protobuf messages.

        Checks the following:
        - the IPv4 socket family, the initiator address ('from') and the
          transport protocol;
        - the DNS id against `query.id`;
        - the server identity.

        When `normalQueryResponse` is true, 'inBytes' must also equal the
        wire length of `query`.
        """
        self.assertTrue(msg)
        self.assertTrue(msg.HasField('timeSec'))
        self.assertTrue(msg.HasField('socketFamily'))
        self.assertEqual(msg.socketFamily, dnsmessage_pb2.PBDNSMessage.INET)
        self.assertTrue(msg.HasField('from'))
        # 'from' is a Python keyword, so the field must be read with getattr().
        fromvalue = getattr(msg, 'from')
        self.assertEqual(socket.inet_ntop(socket.AF_INET, fromvalue), initiator)
        self.assertTrue(msg.HasField('socketProtocol'))
        self.assertEqual(msg.socketProtocol, protocol)
        self.assertTrue(msg.HasField('messageId'))
        self.assertTrue(msg.HasField('id'))
        self.assertEqual(msg.id, query.id)
        self.assertTrue(msg.HasField('inBytes'))
        self.assertTrue(msg.HasField('serverIdentity'))
        self.assertEqual(msg.serverIdentity, self._protobufServerID)

        if normalQueryResponse:
            # compare inBytes with length of query/response
            self.assertEqual(msg.inBytes, len(query.to_wire()))
        # dnsdist doesn't set the existing EDNS Subnet for now,
        # although it might be set from Lua
        # self.assertTrue(msg.HasField('originalRequestorSubnet'))
        # self.assertEqual(len(msg.originalRequestorSubnet), 4)
        # self.assertEqual(socket.inet_ntop(socket.AF_INET, msg.originalRequestorSubnet), '127.0.0.1')

    def checkProtobufQuery(self, msg, protocol, query, qclass, qtype, qname, initiator='127.0.0.1'):
        """
        Assert that `msg` is a query message for (`qname`, `qclass`, `qtype`)
        with the common fields checked by checkProtobufBase.
        """
        self.assertEqual(msg.type, dnsmessage_pb2.PBDNSMessage.DNSQueryType)
        self.checkProtobufBase(msg, protocol, query, initiator)
        # dnsdist doesn't fill the responder field for responses
        # because it doesn't keep the information around.
        self.assertTrue(msg.HasField('to'))
        self.assertEqual(socket.inet_ntop(socket.AF_INET, msg.to), '127.0.0.1')
        self.assertTrue(msg.HasField('question'))
        self.assertTrue(msg.question.HasField('qClass'))
        self.assertEqual(msg.question.qClass, qclass)
        self.assertTrue(msg.question.HasField('qType'))
        # Compare qType, not qClass: both are 1 for IN/A, which used to hide
        # the mistake.
        self.assertEqual(msg.question.qType, qtype)
        self.assertTrue(msg.question.HasField('qName'))
        self.assertEqual(msg.question.qName, qname)

    def checkProtobufTags(self, tags, expectedTags):
        """
        Assert that `tags` and `expectedTags` hold the same set of strings.

        Order and duplicates are ignored, as in a set comparison.
        """
        self.assertEqual(set(tags), set(expectedTags), "Protobuf tags don't match")

    def checkProtobufQueryConvertedToResponse(self, msg, protocol, response, initiator='127.0.0.0'):
        """
        Assert that `msg` is a query message that the Lua hook turned into a
        response (setProtobufResponseType in alterProtobufQuery).
        """
        self.assertEqual(msg.type, dnsmessage_pb2.PBDNSMessage.DNSResponseType)
        # skip comparing inBytes (size of the query) with the length of the generated response
        self.checkProtobufBase(msg, protocol, response, initiator, False)
        self.assertTrue(msg.HasField('response'))
        self.assertTrue(msg.response.HasField('queryTimeSec'))

    def checkProtobufResponse(self, msg, protocol, response, initiator='127.0.0.1'):
        """
        Assert that `msg` is a response message matching `response`, with the
        common fields checked by checkProtobufBase.
        """
        self.assertEqual(msg.type, dnsmessage_pb2.PBDNSMessage.DNSResponseType)
        self.checkProtobufBase(msg, protocol, response, initiator)
        self.assertTrue(msg.HasField('response'))
        self.assertTrue(msg.response.HasField('queryTimeSec'))

    def checkProtobufResponseRecord(self, record, rclass, rtype, rname, rttl):
        """
        Assert the class, type, name and TTL of one protobuf response record,
        and that it carries rdata.
        """
        self.assertTrue(record.HasField('class'))
        # 'class' is a Python keyword, so the field must be read with getattr().
        self.assertEqual(getattr(record, 'class'), rclass)
        self.assertTrue(record.HasField('type'))
        self.assertEqual(record.type, rtype)
        self.assertTrue(record.HasField('name'))
        self.assertEqual(record.name, rname)
        self.assertTrue(record.HasField('ttl'))
        self.assertEqual(record.ttl, rttl)
        self.assertTrue(record.HasField('rdata'))

    def testProtobuf(self):
        """
        Protobuf: Send data to a protobuf server
        """
        name = 'query.protobuf.tests.powerdns.com.'

        target = 'target.protobuf.tests.powerdns.com.'
        query = dns.message.make_query(name, 'A', 'IN')
        response = dns.message.make_response(query)

        rrset = dns.rrset.from_text(name,
                                    3600,
                                    dns.rdataclass.IN,
                                    dns.rdatatype.CNAME,
                                    target)
        response.answer.append(rrset)

        rrset = dns.rrset.from_text(target,
                                    3600,
                                    dns.rdataclass.IN,
                                    dns.rdatatype.A,
                                    '127.0.0.1')
        response.answer.append(rrset)

        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
        self.assertTrue(receivedQuery)
        self.assertTrue(receivedResponse)
        receivedQuery.id = query.id
        self.assertEqual(query, receivedQuery)
        self.assertEqual(response, receivedResponse)

        # give the protobuf messages time to arrive
        time.sleep(1)

        # check the protobuf message corresponding to the UDP query
        msg = self.getFirstProtobufMessage()

        self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.UDP, query, dns.rdataclass.IN, dns.rdatatype.A, name)
        self.checkProtobufTags(msg.response.tags, [u"TestLabel1,TestData1", u"TestLabel2,TestData2", u"TestLabel3,TestData3", u"Query,123"])

        # check the protobuf message corresponding to the UDP response
        msg = self.getFirstProtobufMessage()
        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, response)
        self.checkProtobufTags(msg.response.tags, [u"TestLabel1,TestData1", u"TestLabel2,TestData2", u"TestLabel3,TestData3", u"Response,456"])
        self.assertEqual(len(msg.response.rrs), 2)
        rr = msg.response.rrs[0]
        self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.CNAME, name, 3600)
        self.assertEqual(rr.rdata, target)
        rr = msg.response.rrs[1]
        self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, target, 3600)
        self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1')

        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
        self.assertTrue(receivedQuery)
        self.assertTrue(receivedResponse)
        receivedQuery.id = query.id
        self.assertEqual(query, receivedQuery)
        self.assertEqual(response, receivedResponse)

        # give the protobuf messages time to arrive
        time.sleep(1)

        # check the protobuf message corresponding to the TCP query
        msg = self.getFirstProtobufMessage()

        self.checkProtobufQuery(msg, dnsmessage_pb2.PBDNSMessage.TCP, query, dns.rdataclass.IN, dns.rdatatype.A, name)
        self.checkProtobufTags(msg.response.tags, [u"TestLabel1,TestData1", u"TestLabel2,TestData2", u"TestLabel3,TestData3", u"Query,123"])

        # check the protobuf message corresponding to the TCP response
        msg = self.getFirstProtobufMessage()
        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, response)
        self.checkProtobufTags(msg.response.tags, [u"TestLabel1,TestData1", u"TestLabel2,TestData2", u"TestLabel3,TestData3", u"Response,456"])
        self.assertEqual(len(msg.response.rrs), 2)
        rr = msg.response.rrs[0]
        self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.CNAME, name, 3600)
        self.assertEqual(rr.rdata, target)
        rr = msg.response.rrs[1]
        self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, target, 3600)
        self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1')

    def testLuaProtobuf(self):

        """
        Protobuf: Check that the Lua callback rewrote the initiator
        """
        name = 'lua.protobuf.tests.powerdns.com.'
        query = dns.message.make_query(name, 'A', 'IN')
        response = dns.message.make_response(query)
        rrset = dns.rrset.from_text(name,
                                    3600,
                                    dns.rdataclass.IN,
                                    dns.rdatatype.A,
                                    '127.0.0.1')
        response.answer.append(rrset)

        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)

        self.assertTrue(receivedQuery)
        self.assertTrue(receivedResponse)
        receivedQuery.id = query.id
        self.assertEqual(query, receivedQuery)
        self.assertEqual(response, receivedResponse)

        # give the protobuf messages time to arrive
        time.sleep(1)

        # check the protobuf message corresponding to the UDP query
        msg = self.getFirstProtobufMessage()

        self.checkProtobufQueryConvertedToResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, response, '127.0.0.0')
        self.checkProtobufTags(msg.response.tags, [u"TestLabel1,TestData1", u"TestLabel2,TestData2", u"TestLabel3,TestData3", u"Query,123"])

        # check the protobuf message corresponding to the UDP response
        msg = self.getFirstProtobufMessage()
        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, response, '127.0.0.0')
        self.checkProtobufTags(msg.response.tags, [u"TestLabel1,TestData1", u"TestLabel2,TestData2", u"TestLabel3,TestData3", u"Response,456"])
        self.assertEqual(len(msg.response.rrs), 1)
        for rr in msg.response.rrs:
            self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 3600)
            self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1')

        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
        self.assertTrue(receivedQuery)
        self.assertTrue(receivedResponse)
        receivedQuery.id = query.id
        self.assertEqual(query, receivedQuery)
        self.assertEqual(response, receivedResponse)

        # give the protobuf messages time to arrive
        time.sleep(1)

        # check the protobuf message corresponding to the TCP query
        msg = self.getFirstProtobufMessage()
        self.checkProtobufQueryConvertedToResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, response, '127.0.0.0')
        self.checkProtobufTags(msg.response.tags, [u"TestLabel1,TestData1", u"TestLabel2,TestData2", u"TestLabel3,TestData3", u"Query,123"])

        # check the protobuf message corresponding to the TCP response
        msg = self.getFirstProtobufMessage()
        self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, response, '127.0.0.0')
        self.checkProtobufTags(msg.response.tags, [u"TestLabel1,TestData1", u"TestLabel2,TestData2", u"TestLabel3,TestData3", u"Response,456"])
        self.assertEqual(len(msg.response.rrs), 1)
        for rr in msg.response.rrs:
            self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 3600)
            self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1')
390 for rr in msg.response.rrs:
391 self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 3600)
392 self.assertEquals(socket.inet_ntop(socket.AF_INET, rr.rdata), '127.0.0.1')