]> git.ipfire.org Git - thirdparty/pdns.git/blob - regression-tests.dnsdist/test_Trailing.py
dnsdist: enable doh3 in our CI
[thirdparty/pdns.git] / regression-tests.dnsdist / test_Trailing.py
1 #!/usr/bin/env python
2 import threading
3 import dns
4 from dnsdisttests import DNSDistTest, pickAvailablePort
5
class TestTrailingDataToBackend(DNSDistTest):
    """Trailing data added by dnsdist's Lua rules and forwarded to the backend.

    The responders started by this suite answer REFUSED to queries that carry
    trailing data, so the tests expect REFUSED whenever dnsdist forwarded a
    query with trailing data attached.
    """

    # this test suite uses a different responder port
    # because, contrary to the other ones, its
    # responders allow trailing data and we don't want
    # to mix things up.
    _testServerPort = pickAvailablePort()
    _verboseMode = True
    _config_template = """
    newServer{address="127.0.0.1:%s"}

    function replaceTrailingData(dq)
      local success = dq:setTrailingData("ABC")
      if not success then
        return DNSAction.ServFail, ""
      end
      return DNSAction.None, ""
    end
    addAction("added.trailing.tests.powerdns.com.", LuaAction(replaceTrailingData))

    function fillBuffer(dq)
      local available = 4096 - dq.len
      if dq.tcp then
        available = 65535 - dq.len
      end
      local tail = string.rep("A", available)
      local success = dq:setTrailingData(tail)
      if not success then
        return DNSAction.ServFail, ""
      end
      return DNSAction.None, ""
    end
    addAction("max.trailing.tests.powerdns.com.", LuaAction(fillBuffer))

    function exceedBuffer(dq)
      local available = 4096 - dq.len
      if dq.tcp then
        available = 65535 - dq.len
      end
      local tail = string.rep("A", available + 1)
      local success = dq:setTrailingData(tail)
      if not success then
        return DNSAction.ServFail, ""
      end
      return DNSAction.None, ""
    end
    addAction("limited.trailing.tests.powerdns.com.", LuaAction(exceedBuffer))
    """

    @classmethod
    def startResponders(cls):
        """Start UDP and TCP responders that answer REFUSED to queries with trailing data."""
        print("Launching responders..")

        # Respond REFUSED to queries with trailing data.
        cls._UDPResponder = threading.Thread(name='UDP Responder', target=cls.UDPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue, dns.rcode.REFUSED])
        cls._UDPResponder.daemon = True
        cls._UDPResponder.start()

        # Respond REFUSED to queries with trailing data.
        cls._TCPResponder = threading.Thread(name='TCP Responder', target=cls.TCPResponder, args=[cls._testServerPort, cls._toResponderQueue, cls._fromResponderQueue, dns.rcode.REFUSED])
        cls._TCPResponder.daemon = True
        cls._TCPResponder.start()

    @staticmethod
    def _queryAndResponse(name):
        """Return ``(query, response)``: an A/IN query for ``name`` and a
        response to it carrying a single 127.0.0.1 A record (TTL 3600)."""
        query = dns.message.make_query(name, 'A', 'IN')
        response = dns.message.make_response(query)
        rrset = dns.rrset.from_text(name,
                                    3600,
                                    dns.rdataclass.IN,
                                    dns.rdatatype.A,
                                    '127.0.0.1')
        response.answer.append(rrset)
        return query, response

    @staticmethod
    def _rcodeResponse(query, rcode):
        """Return an answer-less response to ``query`` with its rcode set to ``rcode``."""
        expectedResponse = dns.message.make_response(query)
        expectedResponse.set_rcode(rcode)
        return expectedResponse

    def _assertForwarded(self, payload, query, response, expectedResponse, **sendArgs):
        """Send ``payload`` over UDP then TCP and check both legs.

        ``payload`` and ``response`` are passed to ``sendUDPQuery`` /
        ``sendTCPQuery`` together with any extra ``sendArgs`` (e.g.
        ``rawQuery=True``). The query seen by the backend must equal ``query``
        (ignoring the message ID, which dnsdist may rewrite), and the client
        must receive ``expectedResponse``.
        """
        for method in ("sendUDPQuery", "sendTCPQuery"):
            sender = getattr(self, method)
            (receivedQuery, receivedResponse) = sender(payload, response, **sendArgs)
            self.assertTrue(receivedQuery)
            self.assertTrue(receivedResponse)
            receivedQuery.id = query.id
            self.assertEqual(receivedQuery, query)
            self.assertEqual(receivedResponse, expectedResponse)

    def testTrailingPassthrough(self):
        """
        Trailing data: Pass through

        """
        name = 'passthrough.trailing.tests.powerdns.com.'
        query, response = self._queryAndResponse(name)
        expectedResponse = self._rcodeResponse(query, dns.rcode.REFUSED)

        # Client-supplied trailing data must reach the backend untouched.
        raw = query.to_wire() + b'A' * 20
        self._assertForwarded(raw, query, response, expectedResponse, rawQuery=True)

    def testTrailingCapacity(self):
        """
        Trailing data: Fill buffer

        """
        name = 'max.trailing.tests.powerdns.com.'
        query, response = self._queryAndResponse(name)
        expectedResponse = self._rcodeResponse(query, dns.rcode.REFUSED)

        # fillBuffer() pads the query up to the buffer limit; that must succeed.
        self._assertForwarded(query, query, response, expectedResponse)

    def testTrailingLimited(self):
        """
        Trailing data: Reject buffer overflows

        """
        name = 'limited.trailing.tests.powerdns.com.'
        query, response = self._queryAndResponse(name)
        expectedResponse = self._rcodeResponse(query, dns.rcode.SERVFAIL)

        # exceedBuffer() asks for one byte too many; setTrailingData() fails
        # and the Lua rule answers SERVFAIL, so nothing reaches the backend.
        for method in ("sendUDPQuery", "sendTCPQuery"):
            sender = getattr(self, method)
            (_, receivedResponse) = sender(query, response, useQueue=False)
            self.assertTrue(receivedResponse)
            self.assertEqual(receivedResponse, expectedResponse)

    def testTrailingAdded(self):
        """
        Trailing data: Add

        """
        name = 'added.trailing.tests.powerdns.com.'
        query, response = self._queryAndResponse(name)
        expectedResponse = self._rcodeResponse(query, dns.rcode.REFUSED)

        # dnsdist appends "ABC", so the backend refuses the forwarded query.
        self._assertForwarded(query, query, response, expectedResponse)
171
class TestTrailingDataToDnsdist(DNSDistTest):
    """Trailing data sent by the client and inspected/rewritten by dnsdist's Lua rules."""

    _verboseMode = True
    # NOTE: reportTrailingHex builds its format string as "\x2502X" (Lua sees
    # \x25, i.e. '%', so the format is "%02X"), presumably to keep a literal
    # '%' out of this template, which itself contains a %s placeholder.
    _config_template = """
    newServer{address="127.0.0.1:%s"}

    addAction(AndRule({QNameRule("dropped.trailing.tests.powerdns.com."), TrailingDataRule()}), DropAction())

    function removeTrailingData(dq)
      local success = dq:setTrailingData("")
      if not success then
        print("Trailing removal failed")
        return DNSAction.ServFail, ""
      end
      return DNSAction.None, ""
    end
    addAction("removed.trailing.tests.powerdns.com.", LuaAction(removeTrailingData))

    function reportTrailingData(dq)
      local tail = dq:getTrailingData()
      return DNSAction.Spoof, "-" .. tail .. ".echoed.trailing.tests.powerdns.com."
    end
    addAction("echoed.trailing.tests.powerdns.com.", LuaAction(reportTrailingData))

    function replaceTrailingData(dq)
      local success = dq:setTrailingData("ABC")
      if not success then
        return DNSAction.ServFail, ""
      end
      return DNSAction.None, ""
    end
    addAction("replaced.trailing.tests.powerdns.com.", LuaAction(replaceTrailingData))
    addAction("replaced.trailing.tests.powerdns.com.", LuaAction(reportTrailingData))

    function reportTrailingHex(dq)
      local tail = dq:getTrailingData()
      local hex = string.gsub(tail, ".", function(ch)
        return string.sub(string.format("\\x2502X", string.byte(ch)), -2)
      end)
      return DNSAction.Spoof, "-0x" .. hex .. ".echoed-hex.trailing.tests.powerdns.com."
    end
    addAction("echoed-hex.trailing.tests.powerdns.com.", LuaAction(reportTrailingHex))

    function replaceTrailingData_unsafe(dq)
      local success = dq:setTrailingData("\\xB0\\x00\\xDE\\xADB\\xF0\\x9F\\x91\\xBB\\xC3\\xBE")
      if not success then
        return DNSAction.ServFail, ""
      end
      return DNSAction.None, ""
    end
    addAction("replaced-unsafe.trailing.tests.powerdns.com.", LuaAction(replaceTrailingData_unsafe))
    addAction("replaced-unsafe.trailing.tests.powerdns.com.", LuaAction(reportTrailingHex))
    """

    @staticmethod
    def _queryAndResponse(name):
        """Return ``(query, response)``: an A/IN query for ``name`` and a
        response to it carrying a single 127.0.0.1 A record (TTL 3600)."""
        query = dns.message.make_query(name, 'A', 'IN')
        response = dns.message.make_response(query)
        rrset = dns.rrset.from_text(name,
                                    3600,
                                    dns.rdataclass.IN,
                                    dns.rdatatype.A,
                                    '127.0.0.1')
        response.answer.append(rrset)
        return query, response

    @staticmethod
    def _cnameResponse(query, name, target):
        """Return a response to ``query`` whose answer is ``name`` CNAME ``target`` (TTL 60)."""
        expectedResponse = dns.message.make_response(query)
        rrset = dns.rrset.from_text(name,
                                    60,
                                    dns.rdataclass.IN,
                                    dns.rdatatype.CNAME,
                                    target)
        expectedResponse.answer.append(rrset)
        return expectedResponse

    def _assertSpoofed(self, raw, expectedResponse):
        """Send the wire-format ``raw`` over UDP then TCP, expecting dnsdist
        itself to answer with ``expectedResponse``.

        No backend response is queued (``response=None``, ``useQueue=False``).
        Header flags are not compared: they are copied from the received
        answer into ``expectedResponse`` before the comparison.
        """
        for method in ("sendUDPQuery", "sendTCPQuery"):
            sender = getattr(self, method)
            (_, receivedResponse) = sender(raw, response=None, rawQuery=True, useQueue=False)
            self.assertTrue(receivedResponse)
            expectedResponse.flags = receivedResponse.flags
            self.assertEqual(receivedResponse, expectedResponse)

    def testTrailingDropped(self):
        """
        Trailing data: Drop query

        """
        name = 'dropped.trailing.tests.powerdns.com.'
        query, response = self._queryAndResponse(name)
        raw = query.to_wire() + b'A' * 20

        for method in ("sendUDPQuery", "sendTCPQuery"):
            sender = getattr(self, method)

            # Verify that queries with no trailing data make it through.
            (receivedQuery, receivedResponse) = sender(query, response)
            self.assertTrue(receivedQuery)
            self.assertTrue(receivedResponse)
            receivedQuery.id = query.id
            self.assertEqual(query, receivedQuery)
            self.assertEqual(response, receivedResponse)

            # Verify that queries with trailing data don't make it through.
            (_, receivedResponse) = sender(raw, response, rawQuery=True, useQueue=False)
            self.assertEqual(receivedResponse, None)

    def testTrailingRemoved(self):
        """
        Trailing data: Remove

        """
        name = 'removed.trailing.tests.powerdns.com.'
        query, response = self._queryAndResponse(name)
        raw = query.to_wire() + b'A' * 20

        for method in ("sendUDPQuery", "sendTCPQuery"):
            sender = getattr(self, method)
            (receivedQuery, receivedResponse) = sender(raw, response, rawQuery=True)
            self.assertTrue(receivedQuery)
            self.assertTrue(receivedResponse)
            # The backend must see the query with the trailing bytes stripped.
            receivedQuery.id = query.id
            self.assertEqual(receivedQuery, query)
            self.assertEqual(receivedResponse, response)

    def testTrailingRead(self):
        """
        Trailing data: Echo

        """
        name = 'echoed.trailing.tests.powerdns.com.'
        query = dns.message.make_query(name, 'A', 'IN')
        expectedResponse = self._cnameResponse(query, name, '-TrailingData.echoed.trailing.tests.powerdns.com.')

        raw = query.to_wire() + b'TrailingData'
        self._assertSpoofed(raw, expectedResponse)

    def testTrailingReplaced(self):
        """
        Trailing data: Replace

        """
        name = 'replaced.trailing.tests.powerdns.com.'
        query = dns.message.make_query(name, 'A', 'IN')
        # The original trailing data is replaced with "ABC" before being echoed.
        expectedResponse = self._cnameResponse(query, name, '-ABC.echoed.trailing.tests.powerdns.com.')

        raw = query.to_wire() + b'TrailingData'
        self._assertSpoofed(raw, expectedResponse)

    def testTrailingReadUnsafe(self):
        """
        Trailing data: Echo as hex

        """
        name = 'echoed-hex.trailing.tests.powerdns.com.'
        query = dns.message.make_query(name, 'A', 'IN')
        expectedResponse = self._cnameResponse(query, name, '-0x0000DEAD.echoed-hex.trailing.tests.powerdns.com.')

        # NUL bytes in the trailing data must survive the round trip.
        raw = query.to_wire() + b'\x00\x00\xDE\xAD'
        self._assertSpoofed(raw, expectedResponse)

    def testTrailingReplacedUnsafe(self):
        """
        Trailing data: Replace with null and/or non-ASCII bytes

        """
        name = 'replaced-unsafe.trailing.tests.powerdns.com.'
        query = dns.message.make_query(name, 'A', 'IN')
        # Bytes set by replaceTrailingData_unsafe: B0 00 DE AD 'B'(42) F0 9F 91 BB C3 BE.
        expectedResponse = self._cnameResponse(query, name, '-0xB000DEAD42F09F91BBC3BE.echoed-hex.trailing.tests.powerdns.com.')

        raw = query.to_wire() + b'TrailingData'
        self._assertSpoofed(raw, expectedResponse)