]> git.ipfire.org Git - thirdparty/pdns.git/blob - regression-tests.dnsdist/test_Trailing.py
Merge remote-tracking branch 'origin/master' into fix-boost-random-header
[thirdparty/pdns.git] / regression-tests.dnsdist / test_Trailing.py
1 #!/usr/bin/env python
2 import threading
3 import dns
4 from dnsdisttests import DNSDistTest
5
class TestTrailingDataToBackend(DNSDistTest):
    """Regression tests for trailing data travelling from dnsdist to the backend.

    The responders of this suite are started with ``dns.rcode.REFUSED`` as an
    extra argument; per the comments in ``startResponders`` they answer
    REFUSED to queries that carry trailing data.  A REFUSED answer in a test
    therefore indicates that trailing bytes reached the backend.
    NOTE(review): the responder implementation lives in DNSDistTest and is
    not visible here -- confirm the exact semantics of that argument.
    """

    # This test suite uses a different responder port
    # because, contrary to the other ones, its
    # responders allow trailing data and we don't want
    # to mix things up.
    _testServerPort = 5360

    # dnsdist configuration; the single %s is the backend address port.
    # Each Lua action is bound to one query name used by one test below.
    _config_template = """
    newServer{address="127.0.0.1:%s"}

    function replaceTrailingData(dq)
        local success = dq:setTrailingData("ABC")
        if not success then
            return DNSAction.ServFail, ""
        end
        return DNSAction.None, ""
    end
    addLuaAction("added.trailing.tests.powerdns.com.", replaceTrailingData)

    function fillBuffer(dq)
        local available = dq.size - dq.len
        local tail = string.rep("A", available)
        local success = dq:setTrailingData(tail)
        if not success then
            return DNSAction.ServFail, ""
        end
        return DNSAction.None, ""
    end
    addLuaAction("max.trailing.tests.powerdns.com.", fillBuffer)

    function exceedBuffer(dq)
        local available = dq.size - dq.len
        local tail = string.rep("A", available + 1)
        local success = dq:setTrailingData(tail)
        if not success then
            return DNSAction.ServFail, ""
        end
        return DNSAction.None, ""
    end
    addLuaAction("limited.trailing.tests.powerdns.com.", exceedBuffer)
    """

    @classmethod
    def startResponders(cls):
        """Start the UDP and TCP responder threads as daemon threads.

        Both threads receive the test server port, the two responder queues
        and ``dns.rcode.REFUSED``.
        """
        print("Launching responders..")

        # Respond REFUSED to queries with trailing data.
        responderArgs = (cls._testServerPort,
                         cls._toResponderQueue,
                         cls._fromResponderQueue,
                         dns.rcode.REFUSED)

        cls._UDPResponder = threading.Thread(name='UDP Responder',
                                             target=cls.UDPResponder,
                                             args=responderArgs)
        # The 'daemon' attribute replaces the deprecated setDaemon() call.
        cls._UDPResponder.daemon = True
        cls._UDPResponder.start()

        cls._TCPResponder = threading.Thread(name='TCP Responder',
                                             target=cls.TCPResponder,
                                             args=responderArgs)
        cls._TCPResponder.daemon = True
        cls._TCPResponder.start()

    def testTrailingPassthrough(self):
        """
        Trailing data: Pass through

        Sends a raw query with 20 extra bytes appended and expects the
        REFUSED answer over both UDP and TCP.
        """
        name = 'passthrough.trailing.tests.powerdns.com.'
        query = dns.message.make_query(name, 'A', 'IN')
        response = dns.message.make_response(query)
        rrset = dns.rrset.from_text(name,
                                    3600,
                                    dns.rdataclass.IN,
                                    dns.rdatatype.A,
                                    '127.0.0.1')
        response.answer.append(rrset)
        expectedResponse = dns.message.make_response(query)
        expectedResponse.set_rcode(dns.rcode.REFUSED)

        # Wire-format query followed by 20 bytes of trailing data.
        raw = query.to_wire()
        raw = raw + b'A' * 20

        for method in ("sendUDPQuery", "sendTCPQuery"):
            sender = getattr(self, method)
            (receivedQuery, receivedResponse) = sender(raw, response, rawQuery=True)
            self.assertTrue(receivedQuery)
            self.assertTrue(receivedResponse)
            # Align the IDs so that the comparison only covers the content.
            receivedQuery.id = query.id
            self.assertEqual(receivedQuery, query)
            self.assertEqual(receivedResponse, expectedResponse)

    def testTrailingCapacity(self):
        """
        Trailing data: Fill buffer

        The ``fillBuffer`` Lua action appends exactly ``dq.size - dq.len``
        bytes; the query is expected to be answered with REFUSED.
        """
        name = 'max.trailing.tests.powerdns.com.'
        query = dns.message.make_query(name, 'A', 'IN')
        response = dns.message.make_response(query)
        rrset = dns.rrset.from_text(name,
                                    3600,
                                    dns.rdataclass.IN,
                                    dns.rdatatype.A,
                                    '127.0.0.1')
        response.answer.append(rrset)
        expectedResponse = dns.message.make_response(query)
        expectedResponse.set_rcode(dns.rcode.REFUSED)

        for method in ("sendUDPQuery", "sendTCPQuery"):
            sender = getattr(self, method)
            (receivedQuery, receivedResponse) = sender(query, response)
            self.assertTrue(receivedQuery)
            self.assertTrue(receivedResponse)
            # Align the IDs so that the comparison only covers the content.
            receivedQuery.id = query.id
            self.assertEqual(receivedQuery, query)
            self.assertEqual(receivedResponse, expectedResponse)

    def testTrailingLimited(self):
        """
        Trailing data: Reject buffer overflows

        The ``exceedBuffer`` Lua action tries to append one byte more than
        ``dq.size - dq.len``; the expected answer is SERVFAIL, which is what
        the action returns when ``setTrailingData`` reports failure.
        """
        name = 'limited.trailing.tests.powerdns.com.'
        query = dns.message.make_query(name, 'A', 'IN')
        response = dns.message.make_response(query)
        rrset = dns.rrset.from_text(name,
                                    3600,
                                    dns.rdataclass.IN,
                                    dns.rdatatype.A,
                                    '127.0.0.1')
        response.answer.append(rrset)
        expectedResponse = dns.message.make_response(query)
        expectedResponse.set_rcode(dns.rcode.SERVFAIL)

        for method in ("sendUDPQuery", "sendTCPQuery"):
            sender = getattr(self, method)
            # Only the response is checked here; the received query is ignored.
            (_, receivedResponse) = sender(query, response)
            self.assertTrue(receivedResponse)
            self.assertEqual(receivedResponse, expectedResponse)

    def testTrailingAdded(self):
        """
        Trailing data: Add

        The ``replaceTrailingData`` Lua action sets the trailing data to
        "ABC" on a query sent without any; REFUSED is expected back.
        """
        name = 'added.trailing.tests.powerdns.com.'
        query = dns.message.make_query(name, 'A', 'IN')
        response = dns.message.make_response(query)
        rrset = dns.rrset.from_text(name,
                                    3600,
                                    dns.rdataclass.IN,
                                    dns.rdatatype.A,
                                    '127.0.0.1')
        response.answer.append(rrset)
        expectedResponse = dns.message.make_response(query)
        expectedResponse.set_rcode(dns.rcode.REFUSED)

        for method in ("sendUDPQuery", "sendTCPQuery"):
            sender = getattr(self, method)
            (receivedQuery, receivedResponse) = sender(query, response)
            self.assertTrue(receivedQuery)
            self.assertTrue(receivedResponse)
            # Align the IDs so that the comparison only covers the content.
            receivedQuery.id = query.id
            self.assertEqual(receivedQuery, query)
            self.assertEqual(receivedResponse, expectedResponse)
172
class TestTrailingDataToDnsdist(DNSDistTest):
    """Regression tests for trailing data received by dnsdist itself.

    Covers dropping queries that carry trailing data, stripping it, reading
    it back (as text and as hex) and replacing it from Lua actions.
    """

    # dnsdist configuration; the single %s is the backend address port.
    # The "\\x25" sequences reach Lua as "\x25" rather than a literal percent
    # sign -- presumably because this template is %-formatted (TODO confirm).
    _config_template = """
    newServer{address="127.0.0.1:%s"}

    addAction(AndRule({QNameRule("dropped.trailing.tests.powerdns.com."), TrailingDataRule()}), DropAction())

    function removeTrailingData(dq)
        local success = dq:setTrailingData("")
        if not success then
            return DNSAction.ServFail, ""
        end
        return DNSAction.None, ""
    end
    addLuaAction("removed.trailing.tests.powerdns.com.", removeTrailingData)

    function reportTrailingData(dq)
        local tail = dq:getTrailingData()
        return DNSAction.Spoof, "-" .. tail .. ".echoed.trailing.tests.powerdns.com."
    end
    addLuaAction("echoed.trailing.tests.powerdns.com.", reportTrailingData)

    function replaceTrailingData(dq)
        local success = dq:setTrailingData("ABC")
        if not success then
            return DNSAction.ServFail, ""
        end
        return DNSAction.None, ""
    end
    addLuaAction("replaced.trailing.tests.powerdns.com.", replaceTrailingData)
    addLuaAction("replaced.trailing.tests.powerdns.com.", reportTrailingData)

    function reportTrailingHex(dq)
        local tail = dq:getTrailingData()
        local hex = string.gsub(tail, ".", function(ch)
            return string.sub(string.format("\\x2502X", string.byte(ch)), -2)
        end)
        return DNSAction.Spoof, "-0x" .. hex .. ".echoed-hex.trailing.tests.powerdns.com."
    end
    addLuaAction("echoed-hex.trailing.tests.powerdns.com.", reportTrailingHex)

    function replaceTrailingData_unsafe(dq)
        local success = dq:setTrailingData("\\xB0\\x00\\xDE\\xADB\\xF0\\x9F\\x91\\xBB\\xC3\\xBE")
        if not success then
            return DNSAction.ServFail, ""
        end
        return DNSAction.None, ""
    end
    addLuaAction("replaced-unsafe.trailing.tests.powerdns.com.", replaceTrailingData_unsafe)
    addLuaAction("replaced-unsafe.trailing.tests.powerdns.com.", reportTrailingHex)
    """

    def testTrailingDropped(self):
        """
        Trailing data: Drop query

        A plain query must be answered normally, while the same query with
        20 trailing bytes must get no response at all.
        """
        name = 'dropped.trailing.tests.powerdns.com.'
        query = dns.message.make_query(name, 'A', 'IN')
        response = dns.message.make_response(query)
        rrset = dns.rrset.from_text(name,
                                    3600,
                                    dns.rdataclass.IN,
                                    dns.rdatatype.A,
                                    '127.0.0.1')
        response.answer.append(rrset)

        # Wire-format query followed by 20 bytes of trailing data.
        raw = query.to_wire()
        raw = raw + b'A' * 20

        for method in ("sendUDPQuery", "sendTCPQuery"):
            sender = getattr(self, method)

            # Verify that queries with no trailing data make it through.
            (receivedQuery, receivedResponse) = sender(query, response)
            self.assertTrue(receivedQuery)
            self.assertTrue(receivedResponse)
            # Align the IDs so that the comparison only covers the content.
            receivedQuery.id = query.id
            self.assertEqual(query, receivedQuery)
            self.assertEqual(response, receivedResponse)

            # Verify that queries with trailing data don't make it through.
            (_, receivedResponse) = sender(raw, response, rawQuery=True)
            self.assertIsNone(receivedResponse)

    def testTrailingRemoved(self):
        """
        Trailing data: Remove

        The ``removeTrailingData`` Lua action sets the trailing data to an
        empty string; the regular answer is expected back.
        """
        name = 'removed.trailing.tests.powerdns.com.'
        query = dns.message.make_query(name, 'A', 'IN')
        response = dns.message.make_response(query)
        rrset = dns.rrset.from_text(name,
                                    3600,
                                    dns.rdataclass.IN,
                                    dns.rdatatype.A,
                                    '127.0.0.1')
        response.answer.append(rrset)

        # Wire-format query followed by 20 bytes of trailing data.
        raw = query.to_wire()
        raw = raw + b'A' * 20

        for method in ("sendUDPQuery", "sendTCPQuery"):
            sender = getattr(self, method)
            (receivedQuery, receivedResponse) = sender(raw, response, rawQuery=True)
            self.assertTrue(receivedQuery)
            self.assertTrue(receivedResponse)
            # Align the IDs so that the comparison only covers the content.
            receivedQuery.id = query.id
            self.assertEqual(receivedQuery, query)
            self.assertEqual(receivedResponse, response)

    def testTrailingRead(self):
        """
        Trailing data: Echo

        The ``reportTrailingData`` Lua action spoofs a CNAME whose first
        label is "-" followed by the trailing data of the query.
        """
        name = 'echoed.trailing.tests.powerdns.com.'
        query = dns.message.make_query(name, 'A', 'IN')
        # Canned backend answer; it differs from the expected spoofed answer.
        response = dns.message.make_response(query)
        response.set_rcode(dns.rcode.SERVFAIL)
        expectedResponse = dns.message.make_response(query)
        rrset = dns.rrset.from_text(name,
                                    60,
                                    dns.rdataclass.IN,
                                    dns.rdatatype.CNAME,
                                    '-TrailingData.echoed.trailing.tests.powerdns.com.')
        expectedResponse.answer.append(rrset)

        raw = query.to_wire()
        raw = raw + b'TrailingData'

        for method in ("sendUDPQuery", "sendTCPQuery"):
            sender = getattr(self, method)
            (_, receivedResponse) = sender(raw, response, rawQuery=True)
            self.assertTrue(receivedResponse)
            # Copy the flags so that they do not take part in the comparison.
            expectedResponse.flags = receivedResponse.flags
            self.assertEqual(receivedResponse, expectedResponse)

    def testTrailingReplaced(self):
        """
        Trailing data: Replace

        ``replaceTrailingData`` sets the trailing data to "ABC" and
        ``reportTrailingData`` then echoes it in a spoofed CNAME.
        """
        name = 'replaced.trailing.tests.powerdns.com.'
        query = dns.message.make_query(name, 'A', 'IN')
        # Canned backend answer; it differs from the expected spoofed answer.
        response = dns.message.make_response(query)
        response.set_rcode(dns.rcode.SERVFAIL)
        expectedResponse = dns.message.make_response(query)
        rrset = dns.rrset.from_text(name,
                                    60,
                                    dns.rdataclass.IN,
                                    dns.rdatatype.CNAME,
                                    '-ABC.echoed.trailing.tests.powerdns.com.')
        expectedResponse.answer.append(rrset)

        raw = query.to_wire()
        raw = raw + b'TrailingData'

        for method in ("sendUDPQuery", "sendTCPQuery"):
            sender = getattr(self, method)
            (_, receivedResponse) = sender(raw, response, rawQuery=True)
            self.assertTrue(receivedResponse)
            # Copy the flags so that they do not take part in the comparison.
            expectedResponse.flags = receivedResponse.flags
            self.assertEqual(receivedResponse, expectedResponse)

    def testTrailingReadUnsafe(self):
        """
        Trailing data: Echo as hex

        The trailing bytes 00 00 DE AD are expected back hex-encoded in the
        first label of a spoofed CNAME (``reportTrailingHex`` Lua action).
        """
        name = 'echoed-hex.trailing.tests.powerdns.com.'
        query = dns.message.make_query(name, 'A', 'IN')
        # Canned backend answer; it differs from the expected spoofed answer.
        response = dns.message.make_response(query)
        response.set_rcode(dns.rcode.SERVFAIL)
        expectedResponse = dns.message.make_response(query)
        rrset = dns.rrset.from_text(name,
                                    60,
                                    dns.rdataclass.IN,
                                    dns.rdatatype.CNAME,
                                    '-0x0000DEAD.echoed-hex.trailing.tests.powerdns.com.')
        expectedResponse.answer.append(rrset)

        raw = query.to_wire()
        raw = raw + b'\x00\x00\xDE\xAD'

        for method in ("sendUDPQuery", "sendTCPQuery"):
            sender = getattr(self, method)
            (_, receivedResponse) = sender(raw, response, rawQuery=True)
            self.assertTrue(receivedResponse)
            # Copy the flags so that they do not take part in the comparison.
            expectedResponse.flags = receivedResponse.flags
            self.assertEqual(receivedResponse, expectedResponse)

    def testTrailingReplacedUnsafe(self):
        """
        Trailing data: Replace with null and/or non-ASCII bytes

        ``replaceTrailingData_unsafe`` installs a byte string containing a
        NUL and non-ASCII bytes; ``reportTrailingHex`` echoes it as hex.
        """
        name = 'replaced-unsafe.trailing.tests.powerdns.com.'
        query = dns.message.make_query(name, 'A', 'IN')
        # Canned backend answer; it differs from the expected spoofed answer.
        response = dns.message.make_response(query)
        response.set_rcode(dns.rcode.SERVFAIL)
        expectedResponse = dns.message.make_response(query)
        rrset = dns.rrset.from_text(name,
                                    60,
                                    dns.rdataclass.IN,
                                    dns.rdatatype.CNAME,
                                    '-0xB000DEAD42F09F91BBC3BE.echoed-hex.trailing.tests.powerdns.com.')
        expectedResponse.answer.append(rrset)

        raw = query.to_wire()
        raw = raw + b'TrailingData'

        for method in ("sendUDPQuery", "sendTCPQuery"):
            sender = getattr(self, method)
            (_, receivedResponse) = sender(raw, response, rawQuery=True)
            self.assertTrue(receivedResponse)
            # Copy the flags so that they do not take part in the comparison.
            expectedResponse.flags = receivedResponse.flags
            self.assertEqual(receivedResponse, expectedResponse)