]>
git.ipfire.org Git - thirdparty/pdns.git/blob - regression-tests.dnsdist/test_Protobuf.py
7 from dnsdisttests
import DNSDistTest
, Queue
# Test case checking the protobuf messages that dnsdist sends to a remote
# logger (see newRemoteLogger / RemoteLogAction in _config_template).
13 class TestProtobuf(DNSDistTest
):
# Port passed to ProtobufListener by startResponders(); the listener binds
# 127.0.0.1 on it.
14 _protobufServerPort
= 4242
# ProtobufListener puts each received payload here; getFirstProtobufMessage()
# takes them out again.
15 _protobufQueue
= Queue()
# Expected value of msg.serverIdentity (asserted in checkProtobufBase).
16 _protobufServerID
= 'dnsdist-server-1'
# Attribute names listed in the same order as the four %s placeholders of
# _config_template -- presumably substituted by the DNSDistTest base class;
# TODO confirm.
18 _config_params
= ['_testServerPort', '_protobufServerPort', '_protobufServerID', '_protobufServerID']
# dnsdist Lua configuration used by these tests:
#   - alterLuaFirst tags every incoming query (TestLabel1..3),
#   - alterProtobufQuery / alterProtobufResponse adjust the protobuf message
#     before it is sent by RemoteLogAction / RemoteLogResponseAction,
#   - names under lua.protobuf.tests.powerdns.com. (luasmn) get their
#     requestor truncated and, for queries, are made to look like responses.
# NOTE(review): the embedded original line numbers skip values (e.g. 27 -> 29,
# 31 -> 34, 122 -> 127), so Lua `else`/`end` lines and the closing triple quote
# appear to be missing from this extract -- restore them from upstream before use.
19 _config_template
= """
20 luasmn = newSuffixMatchNode()
21 luasmn:add(newDNSName('lua.protobuf.tests.powerdns.com.'))
23 function alterProtobufResponse(dq, protobuf)
24 if luasmn:check(dq.qname) then
25 requestor = newCA(dq.remoteaddr:toString()) -- called by testLuaProtobuf()
26 if requestor:isIPv4() then
27 requestor:truncate(24)
29 requestor:truncate(56)
31 protobuf:setRequestor(requestor)
34 table.insert(tableTags, "TestLabel1,TestData1")
35 table.insert(tableTags, "TestLabel2,TestData2")
37 protobuf:setTagArray(tableTags)
39 protobuf:setTag('TestLabel3,TestData3')
41 protobuf:setTag("Response,456")
45 local tableTags = {} -- called by testProtobuf()
46 table.insert(tableTags, "TestLabel1,TestData1")
47 table.insert(tableTags, "TestLabel2,TestData2")
48 protobuf:setTagArray(tableTags)
50 protobuf:setTag('TestLabel3,TestData3')
52 protobuf:setTag("Response,456")
57 function alterProtobufQuery(dq, protobuf)
59 if luasmn:check(dq.qname) then
60 requestor = newCA(dq.remoteaddr:toString()) -- called by testLuaProtobuf()
61 if requestor:isIPv4() then
62 requestor:truncate(24)
64 requestor:truncate(56)
66 protobuf:setRequestor(requestor)
69 tableTags = dq:getTagArray() -- get table from DNSQuery
72 for k, v in pairs( tableTags) do
73 table.insert(tablePB, k .. "," .. v)
76 protobuf:setTagArray(tablePB) -- store table in protobuf
77 protobuf:setTag("Query,123") -- add another tag entry in protobuf
79 protobuf:setResponseCode(dnsdist.NXDOMAIN) -- set protobuf response code to be NXDOMAIN
81 local strReqName = dq.qname:toString() -- get request dns name
83 protobuf:setProtobufResponseType() -- set protobuf to look like a response and not a query, with 0 default time
85 blobData = '\127' .. '\000' .. '\000' .. '\001' -- 127.0.0.1, note: lua 5.1 can only embed decimal not hex
87 protobuf:addResponseRR(strReqName, 1, 1, 123, blobData) -- add a RR to the protobuf
89 protobuf:setBytes(65) -- set the size of the query to confirm in checkProtobufBase
93 local tableTags = {} -- called by testProtobuf()
94 table.insert(tableTags, "TestLabel1,TestData1")
95 table.insert(tableTags, "TestLabel2,TestData2")
97 protobuf:setTagArray(tableTags)
98 protobuf:setTag('TestLabel3,TestData3')
99 protobuf:setTag("Query,123")
104 function alterLuaFirst(dq) -- called when dnsdist receives new request
106 tt["TestLabel1"] = "TestData1"
107 tt["TestLabel2"] = "TestData2"
111 dq:setTag("TestLabel3","TestData3")
112 return DNSAction.None, "" -- continue to the next rule
115 newServer{address="127.0.0.1:%s", useClientSubnet=true}
116 rl = newRemoteLogger('127.0.0.1:%s')
118 addAction(AllRule(), LuaAction(alterLuaFirst)) -- Add tags to DNSQuery first
120 addAction(AllRule(), RemoteLogAction(rl, alterProtobufQuery, {serverID='%s'})) -- Send protobuf message before lookup
122 addResponseAction(AllRule(), RemoteLogResponseAction(rl, alterProtobufResponse, true, {serverID='%s'})) -- Send protobuf message after lookup
# Thread target (started by startResponders) that receives protobuf messages:
# it opens a TCP socket with SO_REUSEPORT, binds it to 127.0.0.1:<port>,
# accepts a connection, reads a 2-byte network-order length ("!H") followed by
# that many bytes, and puts each payload on cls._protobufQueue (blocking put,
# 2 s timeout). A bind failure is reported with print().
# NOTE(review): the embedded original line numbers skip 130, 134-137, 139-143,
# 146-148 and 150-153, so the `try:`, the listen/loop statements, the initial
# length read and the clean-up code are not visible in this extract; what
# happens after a bind error cannot be determined from here.
127 def ProtobufListener(cls
, port
):
128 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
129 sock
.setsockopt(socket
.SOL_SOCKET
, socket
.SO_REUSEPORT
, 1)
131 sock
.bind(("127.0.0.1", port
))
132 except socket
.error
as e
:
133 print("Error binding in the protbuf listener: %s" % str(e
))
138 (conn
, _
) = sock
.accept()
144 (datalen
,) = struct
.unpack("!H", data
)
145 data
= conn
.recv(datalen
)
149 cls
._protobufQueue
.put(data
, True, timeout
=2.0)
def startResponders(cls):
    """Start the background threads the tests rely on.

    Three threads are created, stored on the class and started:

    * ``cls._UDPResponder`` running ``cls.UDPResponder``
    * ``cls._TCPResponder`` running ``cls.TCPResponder``
    * ``cls._protobufListener`` running ``cls.ProtobufListener``

    The two responders receive ``cls._testServerPort`` and the to/from
    responder queues; the listener receives ``cls._protobufServerPort``.
    """
    cls._UDPResponder = threading.Thread(
        name='UDP Responder',
        target=cls.UDPResponder,
        args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue])
    # Use the `daemon` attribute: Thread.setDaemon() is deprecated.
    # Daemon threads do not keep the interpreter alive at exit.
    cls._UDPResponder.daemon = True
    cls._UDPResponder.start()

    cls._TCPResponder = threading.Thread(
        name='TCP Responder',
        target=cls.TCPResponder,
        args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue])
    cls._TCPResponder.daemon = True
    cls._TCPResponder.start()

    cls._protobufListener = threading.Thread(
        name='Protobuf Listener',
        target=cls.ProtobufListener,
        args=[cls._protobufServerPort])
    cls._protobufListener.daemon = True
    cls._protobufListener.start()
def getFirstProtobufMessage(self):
    """Take the oldest payload off the protobuf queue and return it parsed.

    The test fails if the queue is empty or if the payload taken from it
    is empty.

    Returns:
        The ``dnsmessage_pb2.PBDNSMessage`` parsed from the payload.
    """
    self.assertFalse(self._protobufQueue.empty())
    # Non-blocking get: emptiness was asserted just above.
    data = self._protobufQueue.get(False)
    self.assertTrue(data)
    msg = dnsmessage_pb2.PBDNSMessage()
    msg.ParseFromString(data)
    # Callers assign the result and inspect it, so the message must be returned.
    return msg
def checkProtobufBase(self, msg, protocol, query, initiator, normalQueryResponse=True):
    """Check the fields shared by query and response protobuf messages.

    Args:
        msg: parsed protobuf message to inspect.
        protocol: expected value of ``msg.socketProtocol``.
        query: DNS message whose ``id`` must equal ``msg.id``; its wire
            length is also compared when ``normalQueryResponse`` is true.
        initiator: expected textual IPv4 address of the ``from`` field.
        normalQueryResponse: when true, ``msg.inBytes`` must equal
            ``len(query.to_wire())``; pass False to skip that comparison.
    """
    self.assertTrue(msg.HasField('timeSec'))
    self.assertTrue(msg.HasField('socketFamily'))
    self.assertEqual(msg.socketFamily, dnsmessage_pb2.PBDNSMessage.INET)
    self.assertTrue(msg.HasField('from'))
    # `from` is a Python keyword, so the field has to be read with getattr().
    fromvalue = getattr(msg, 'from')
    self.assertEqual(socket.inet_ntop(socket.AF_INET, fromvalue), initiator)
    self.assertTrue(msg.HasField('socketProtocol'))
    self.assertEqual(msg.socketProtocol, protocol)
    self.assertTrue(msg.HasField('messageId'))
    self.assertTrue(msg.HasField('id'))
    self.assertEqual(msg.id, query.id)
    self.assertTrue(msg.HasField('inBytes'))
    self.assertTrue(msg.HasField('serverIdentity'))
    self.assertEqual(msg.serverIdentity, self._protobufServerID)

    if normalQueryResponse:
        # compare inBytes with length of query/response
        self.assertEqual(msg.inBytes, len(query.to_wire()))

    # dnsdist doesn't set the existing EDNS Subnet for now, although it
    # might be set from Lua, so originalRequestorSubnet is not checked here.
def checkProtobufQuery(self, msg, protocol, query, qclass, qtype, qname, initiator='127.0.0.1'):
    """Check a protobuf message describing a DNS query.

    Args:
        msg: parsed protobuf message to inspect.
        protocol: expected value of ``msg.socketProtocol``.
        query: the DNS query that was sent (passed to checkProtobufBase).
        qclass: expected ``msg.question.qClass``.
        qtype: expected ``msg.question.qType``.
        qname: expected ``msg.question.qName``.
        initiator: expected textual IPv4 address of the ``from`` field.
    """
    self.assertEqual(msg.type, dnsmessage_pb2.PBDNSMessage.DNSQueryType)
    self.checkProtobufBase(msg, protocol, query, initiator)
    # dnsdist doesn't fill the responder field for responses
    # because it doesn't keep the information around.
    self.assertTrue(msg.HasField('to'))
    self.assertEqual(socket.inet_ntop(socket.AF_INET, msg.to), '127.0.0.1')
    self.assertTrue(msg.HasField('question'))
    self.assertTrue(msg.question.HasField('qClass'))
    self.assertEqual(msg.question.qClass, qclass)
    self.assertTrue(msg.question.HasField('qType'))
    # Compare the type field itself; comparing qClass against qtype only
    # passed by coincidence when class and type share the same number.
    self.assertEqual(msg.question.qType, qtype)
    self.assertTrue(msg.question.HasField('qName'))
    self.assertEqual(msg.question.qName, qname)
def checkProtobufTags(self, tags, expectedTags):
    """Assert that ``tags`` and ``expectedTags`` hold the same set of tags.

    Order and duplicates are ignored. On failure the message lists the
    tags present in only one of the two collections.

    Args:
        tags: tags found in the protobuf message.
        expectedTags: tags the test expects.
    """
    # Symmetric difference: only tags present on exactly one side remain.
    mismatched = set(tags) ^ set(expectedTags)
    self.assertEqual(len(mismatched), 0,
                     "Protobuf tags don't match: %s" % sorted(mismatched))
def checkProtobufQueryConvertedToResponse(self, msg, protocol, response, initiator='127.0.0.0'):
    """Check a query protobuf message that was turned into a response.

    Args:
        msg: parsed protobuf message to inspect.
        protocol: expected value of ``msg.socketProtocol``.
        response: DNS message whose ``id`` must match ``msg.id``.
        initiator: expected textual IPv4 address of the ``from`` field.
    """
    self.assertEqual(msg.type, dnsmessage_pb2.PBDNSMessage.DNSResponseType)
    # skip comparing inBytes (size of the query) with the length of the generated response
    self.checkProtobufBase(msg, protocol, response, initiator, False)
    self.assertTrue(msg.HasField('response'))
    self.assertTrue(msg.response.HasField('queryTimeSec'))
def checkProtobufResponse(self, msg, protocol, response, initiator='127.0.0.1'):
    """Check a protobuf message describing a DNS response.

    Args:
        msg: parsed protobuf message to inspect.
        protocol: expected value of ``msg.socketProtocol``.
        response: the DNS response (passed to checkProtobufBase, which
            compares its id and wire length with the message).
        initiator: expected textual IPv4 address of the ``from`` field.
    """
    self.assertEqual(msg.type, dnsmessage_pb2.PBDNSMessage.DNSResponseType)
    self.checkProtobufBase(msg, protocol, response, initiator)
    self.assertTrue(msg.HasField('response'))
    self.assertTrue(msg.response.HasField('queryTimeSec'))
def checkProtobufResponseRecord(self, record, rclass, rtype, rname, rttl):
    """Check one resource record of a protobuf response.

    The record must carry the ``class``, ``type``, ``name``, ``ttl`` and
    ``rdata`` fields; the first four must equal the expected values. The
    content of ``rdata`` is left for the caller to check.

    Args:
        record: resource record taken from ``msg.response.rrs``.
        rclass: expected record class.
        rtype: expected record type.
        rname: expected owner name.
        rttl: expected TTL.
    """
    self.assertTrue(record.HasField('class'))
    # `class` is a Python keyword, so the field has to be read with getattr().
    self.assertEqual(getattr(record, 'class'), rclass)
    self.assertTrue(record.HasField('type'))
    self.assertEqual(record.type, rtype)
    self.assertTrue(record.HasField('name'))
    self.assertEqual(record.name, rname)
    self.assertTrue(record.HasField('ttl'))
    self.assertEqual(record.ttl, rttl)
    self.assertTrue(record.HasField('rdata'))
# Sends the same A query for query.protobuf.tests.powerdns.com. over UDP and
# then over TCP. For each transport it checks that the query and response came
# through unchanged, then takes two protobuf messages off the queue: the query
# message (checkProtobufQuery plus the "Query,123" tag set) and the response
# message (checkProtobufResponse plus the "Response,456" tag set and two RRs:
# a CNAME from name to target and an A record for target with 127.0.0.1).
# NOTE(review): the embedded original line numbers skip values, so the
# docstring delimiters (248/250), the remaining arguments of both
# dns.rrset.from_text() calls (258-261, 265-268) and whatever followed the
# "let the protobuf messages the time to get there" comments (279, 307) are
# not visible in this extract -- restore them from upstream before use.
247 def testProtobuf(self
):
249 Protobuf: Send data to a protobuf server
251 name
= 'query.protobuf.tests.powerdns.com.'
253 target
= 'target.protobuf.tests.powerdns.com.'
254 query
= dns
.message
.make_query(name
, 'A', 'IN')
255 response
= dns
.message
.make_response(query
)
257 rrset
= dns
.rrset
.from_text(name
,
262 response
.answer
.append(rrset
)
264 rrset
= dns
.rrset
.from_text(target
,
269 response
.answer
.append(rrset
)
271 (receivedQuery
, receivedResponse
) = self
.sendUDPQuery(query
, response
)
272 self
.assertTrue(receivedQuery
)
273 self
.assertTrue(receivedResponse
)
274 receivedQuery
.id = query
.id
275 self
.assertEquals(query
, receivedQuery
)
276 self
.assertEquals(response
, receivedResponse
)
278 # let the protobuf messages the time to get there
281 # check the protobuf message corresponding to the UDP query
282 msg
= self
.getFirstProtobufMessage()
284 self
.checkProtobufQuery(msg
, dnsmessage_pb2
.PBDNSMessage
.UDP
, query
, dns
.rdataclass
.IN
, dns
.rdatatype
.A
, name
)
285 self
.checkProtobufTags(msg
.response
.tags
, [u
"TestLabel1,TestData1", u
"TestLabel2,TestData2", u
"TestLabel3,TestData3", u
"Query,123"])
287 # check the protobuf message corresponding to the UDP response
288 msg
= self
.getFirstProtobufMessage()
289 self
.checkProtobufResponse(msg
, dnsmessage_pb2
.PBDNSMessage
.UDP
, response
)
290 self
.checkProtobufTags(msg
.response
.tags
, [ u
"TestLabel1,TestData1", u
"TestLabel2,TestData2", u
"TestLabel3,TestData3", u
"Response,456"])
291 self
.assertEquals(len(msg
.response
.rrs
), 2)
292 rr
= msg
.response
.rrs
[0]
293 self
.checkProtobufResponseRecord(rr
, dns
.rdataclass
.IN
, dns
.rdatatype
.CNAME
, name
, 3600)
294 self
.assertEquals(rr
.rdata
, target
)
295 rr
= msg
.response
.rrs
[1]
296 self
.checkProtobufResponseRecord(rr
, dns
.rdataclass
.IN
, dns
.rdatatype
.A
, target
, 3600)
297 self
.assertEquals(socket
.inet_ntop(socket
.AF_INET
, rr
.rdata
), '127.0.0.1')
299 (receivedQuery
, receivedResponse
) = self
.sendTCPQuery(query
, response
)
300 self
.assertTrue(receivedQuery
)
301 self
.assertTrue(receivedResponse
)
302 receivedQuery
.id = query
.id
303 self
.assertEquals(query
, receivedQuery
)
304 self
.assertEquals(response
, receivedResponse
)
306 # let the protobuf messages the time to get there
309 # check the protobuf message corresponding to the TCP query
310 msg
= self
.getFirstProtobufMessage()
312 self
.checkProtobufQuery(msg
, dnsmessage_pb2
.PBDNSMessage
.TCP
, query
, dns
.rdataclass
.IN
, dns
.rdatatype
.A
, name
)
313 self
.checkProtobufTags(msg
.response
.tags
, [u
"TestLabel1,TestData1", u
"TestLabel2,TestData2", u
"TestLabel3,TestData3", u
"Query,123"])
315 # check the protobuf message corresponding to the TCP response
316 msg
= self
.getFirstProtobufMessage()
317 self
.checkProtobufResponse(msg
, dnsmessage_pb2
.PBDNSMessage
.TCP
, response
)
318 self
.checkProtobufTags(msg
.response
.tags
, [ u
"TestLabel1,TestData1", u
"TestLabel2,TestData2", u
"TestLabel3,TestData3", u
"Response,456"])
319 self
.assertEquals(len(msg
.response
.rrs
), 2)
320 rr
= msg
.response
.rrs
[0]
321 self
.checkProtobufResponseRecord(rr
, dns
.rdataclass
.IN
, dns
.rdatatype
.CNAME
, name
, 3600)
322 self
.assertEquals(rr
.rdata
, target
)
323 rr
= msg
.response
.rrs
[1]
324 self
.checkProtobufResponseRecord(rr
, dns
.rdataclass
.IN
, dns
.rdatatype
.A
, target
, 3600)
325 self
.assertEquals(socket
.inet_ntop(socket
.AF_INET
, rr
.rdata
), '127.0.0.1')
# Sends an A query for lua.protobuf.tests.powerdns.com. (the name added to
# luasmn in _config_template) over UDP and then over TCP. For each transport
# it expects the query protobuf message to be of response type
# (checkProtobufQueryConvertedToResponse) and both messages to report the
# initiator as 127.0.0.0, i.e. the requestor truncated to /24 by the Lua
# callbacks. The response message must carry exactly one A record for the
# queried name with rdata 127.0.0.1.
# NOTE(review): the embedded original line numbers skip values, so the
# docstring delimiters, the remaining arguments of dns.rrset.from_text()
# (336-339) and whatever followed the "let the protobuf messages the time to
# get there" comments (353, 378) are not visible in this extract -- restore
# them from upstream before use.
327 def testLuaProtobuf(self
):
330 Protobuf: Check that the Lua callback rewrote the initiator
332 name
= 'lua.protobuf.tests.powerdns.com.'
333 query
= dns
.message
.make_query(name
, 'A', 'IN')
334 response
= dns
.message
.make_response(query
)
335 rrset
= dns
.rrset
.from_text(name
,
340 response
.answer
.append(rrset
)
343 (receivedQuery
, receivedResponse
) = self
.sendUDPQuery(query
, response
)
345 self
.assertTrue(receivedQuery
)
346 self
.assertTrue(receivedResponse
)
347 receivedQuery
.id = query
.id
348 self
.assertEquals(query
, receivedQuery
)
349 self
.assertEquals(response
, receivedResponse
)
352 # let the protobuf messages the time to get there
355 # check the protobuf message corresponding to the UDP query
356 msg
= self
.getFirstProtobufMessage()
358 self
.checkProtobufQueryConvertedToResponse(msg
, dnsmessage_pb2
.PBDNSMessage
.UDP
, response
, '127.0.0.0')
359 self
.checkProtobufTags(msg
.response
.tags
, [ u
"TestLabel1,TestData1", u
"TestLabel2,TestData2", u
"TestLabel3,TestData3", u
"Query,123"])
361 # check the protobuf message corresponding to the UDP response
362 msg
= self
.getFirstProtobufMessage()
363 self
.checkProtobufResponse(msg
, dnsmessage_pb2
.PBDNSMessage
.UDP
, response
, '127.0.0.0')
364 self
.checkProtobufTags(msg
.response
.tags
, [ u
"TestLabel1,TestData1", u
"TestLabel2,TestData2", u
"TestLabel3,TestData3", u
"Response,456"])
365 self
.assertEquals(len(msg
.response
.rrs
), 1)
366 for rr
in msg
.response
.rrs
:
367 self
.checkProtobufResponseRecord(rr
, dns
.rdataclass
.IN
, dns
.rdatatype
.A
, name
, 3600)
368 self
.assertEquals(socket
.inet_ntop(socket
.AF_INET
, rr
.rdata
), '127.0.0.1')
370 (receivedQuery
, receivedResponse
) = self
.sendTCPQuery(query
, response
)
371 self
.assertTrue(receivedQuery
)
372 self
.assertTrue(receivedResponse
)
373 receivedQuery
.id = query
.id
374 self
.assertEquals(query
, receivedQuery
)
375 self
.assertEquals(response
, receivedResponse
)
377 # let the protobuf messages the time to get there
380 # check the protobuf message corresponding to the TCP query
381 msg
= self
.getFirstProtobufMessage()
382 self
.checkProtobufQueryConvertedToResponse(msg
, dnsmessage_pb2
.PBDNSMessage
.TCP
, response
, '127.0.0.0')
383 self
.checkProtobufTags(msg
.response
.tags
, [ u
"TestLabel1,TestData1", u
"TestLabel2,TestData2", u
"TestLabel3,TestData3", u
"Query,123"])
385 # check the protobuf message corresponding to the TCP response
386 msg
= self
.getFirstProtobufMessage()
387 self
.checkProtobufResponse(msg
, dnsmessage_pb2
.PBDNSMessage
.TCP
, response
, '127.0.0.0')
388 self
.checkProtobufTags(msg
.response
.tags
, [ u
"TestLabel1,TestData1", u
"TestLabel2,TestData2", u
"TestLabel3,TestData3", u
"Response,456"])
389 self
.assertEquals(len(msg
.response
.rrs
), 1)
390 for rr
in msg
.response
.rrs
:
391 self
.checkProtobufResponseRecord(rr
, dns
.rdataclass
.IN
, dns
.rdatatype
.A
, name
, 3600)
392 self
.assertEquals(socket
.inet_ntop(socket
.AF_INET
, rr
.rdata
), '127.0.0.1')