git.ipfire.org Git - thirdparty/pdns.git/blob - regression-tests.recursor-dnssec/test_RPZ.py
Merge pull request #7335 from jsoref/issue-5140
[thirdparty/pdns.git] / regression-tests.recursor-dnssec / test_RPZ.py
1 import dns
2 import json
3 import os
4 import requests
5 import socket
6 import struct
7 import sys
8 import threading
9 import time
10
11 from recursortests import RecursorTest
12
class RPZServer(object):
    """Minimal threaded AXFR/IXFR master serving a scripted 'zone.rpz.' zone.

    The zone content for each serial (1 via AXFR, 2 through 7 via IXFR) is
    hard-coded in _getAnswer(). Tests advance the serial to serve one step
    at a time with moveToSerial() and observe the client catching up through
    getCurrentSerial().
    """

    def __init__(self, port):
        """Start listening on 127.0.0.1:port in a background daemon thread.

        :param port: TCP port on which transfer queries are accepted.
        """
        # Serial of the last zone version fully sent to a client.
        self._currentSerial = 0
        # Serial that the next transfer should bring the client to.
        self._targetSerial = 1
        self._serverPort = port
        listener = threading.Thread(name='RPZ Listener', target=self._listener)
        # Daemon thread, so it does not keep the test process alive.
        listener.daemon = True
        listener.start()

    def getCurrentSerial(self):
        """Return the serial of the last zone version successfully transferred."""
        return self._currentSerial

    def moveToSerial(self, newSerial):
        """Make the next transfer serve newSerial.

        :param newSerial: must equal the current serial (no-op) or current + 1.
        :returns: False if newSerial is already being served, True otherwise.
        :raises AssertionError: if newSerial skips ahead or goes backwards.
        """
        if newSerial == self._currentSerial:
            return False

        if newSerial != self._currentSerial + 1:
            raise AssertionError("Asking the RPZ server to serve serial %d, already serving %d" % (newSerial, self._currentSerial))
        self._targetSerial = newSerial
        return True

    @staticmethod
    def _soa(serial):
        """Return the 'zone.rpz.' SOA RRset carrying the given serial."""
        return dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA,
                                   'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % serial)

    @staticmethod
    def _rr(name, rdtype, *rdatas):
        """Return a TTL-60, class IN RRset for name holding the given rdatas."""
        return dns.rrset.from_text(name, 60, dns.rdataclass.IN, rdtype, *rdatas)

    def _getAnswer(self, message):
        """Build the transfer response for a single-question AXFR or IXFR query.

        :param message: parsed dns.message query.
        :returns: (newSerial, response) on success, or
                  (self._currentSerial, None) when the query does not match
                  the server's state (the serial is then left unchanged).
        """
        soa = self._soa
        rr = self._rr

        response = dns.message.make_response(message)
        records = []
        qtype = message.question[0].rdtype

        if qtype == dns.rdatatype.AXFR:
            # Only the very first transfer is expected to be a full AXFR.
            if self._currentSerial != 0:
                print('Received an AXFR query but IXFR expected because the current serial is %d' % (self._currentSerial))
                return (self._currentSerial, None)

            newSerial = self._targetSerial
            records = [
                soa(newSerial),
                rr('a.example.zone.rpz.', dns.rdatatype.A, '192.0.2.1'),
                soa(newSerial),
            ]

        elif qtype == dns.rdatatype.IXFR:
            # The client's current serial is carried by the SOA in the authority section.
            if not message.authority:
                print('Received an IXFR query without a SOA in the authority section')
                return (self._currentSerial, None)
            oldSerial = message.authority[0][0].serial

            if oldSerial != self._currentSerial:
                print('Received an IXFR query with an unexpected serial %d, expected %d' % (oldSerial, self._currentSerial))
                return (self._currentSerial, None)

            # IXFR layout: new SOA, then old SOA + deletions, new SOA + additions,
            # and a closing new SOA.
            newSerial = self._targetSerial
            if newSerial == 2:
                records = [
                    soa(newSerial),
                    soa(oldSerial),
                    # no deletion
                    soa(newSerial),
                    rr('b.example.zone.rpz.', dns.rdatatype.A, '192.0.2.1'),
                    soa(newSerial),
                ]
            elif newSerial == 3:
                records = [
                    soa(newSerial),
                    soa(oldSerial),
                    rr('a.example.zone.rpz.', dns.rdatatype.A, '192.0.2.1'),
                    soa(newSerial),
                    # no addition
                    soa(newSerial),
                ]
            elif newSerial == 4:
                records = [
                    soa(newSerial),
                    soa(oldSerial),
                    rr('b.example.zone.rpz.', dns.rdatatype.A, '192.0.2.1'),
                    soa(newSerial),
                    rr('c.example.zone.rpz.', dns.rdatatype.A, '192.0.2.1'),
                    soa(newSerial),
                ]
            elif newSerial == 5:
                # this one is a bit special, we are answering with a full AXFR
                records = [
                    soa(newSerial),
                    rr('d.example.zone.rpz.', dns.rdatatype.A, '192.0.2.1'),
                    rr('tc.example.zone.rpz.', dns.rdatatype.CNAME, 'rpz-tcp-only.'),
                    rr('drop.example.zone.rpz.', dns.rdatatype.CNAME, 'rpz-drop.'),
                    soa(newSerial),
                ]
            elif newSerial == 6:
                # back to IXFR
                records = [
                    soa(newSerial),
                    soa(oldSerial),
                    rr('d.example.zone.rpz.', dns.rdatatype.A, '192.0.2.1'),
                    rr('tc.example.zone.rpz.', dns.rdatatype.CNAME, 'rpz-tcp-only.'),
                    rr('drop.example.zone.rpz.', dns.rdatatype.CNAME, 'rpz-drop.'),
                    soa(newSerial),
                    rr('e.example.zone.rpz.', dns.rdatatype.A, '192.0.2.1', '192.0.2.2'),
                    rr('e.example.zone.rpz.', dns.rdatatype.MX, '10 mx.example.'),
                    rr('f.example.zone.rpz.', dns.rdatatype.CNAME, 'e.example.'),
                    soa(newSerial),
                ]
            elif newSerial == 7:
                records = [
                    soa(newSerial),
                    soa(oldSerial),
                    rr('e.example.zone.rpz.', dns.rdatatype.A, '192.0.2.1', '192.0.2.2'),
                    soa(newSerial),
                    rr('e.example.zone.rpz.', dns.rdatatype.A, '192.0.2.2'),
                    rr('tc.example.zone.rpz.', dns.rdatatype.CNAME, 'rpz-tcp-only.'),
                    rr('drop.example.zone.rpz.', dns.rdatatype.CNAME, 'rpz-drop.'),
                    soa(newSerial),
                ]

        else:
            # Only AXFR and IXFR are served; _connectionHandler filters other qtypes first.
            return (self._currentSerial, None)

        response.answer = records
        return (newSerial, response)

    @staticmethod
    def _recvExactly(conn, length):
        """Read exactly `length` bytes from conn.

        socket.recv() may return fewer bytes than requested, so keep reading
        until everything has arrived.

        :returns: the bytes read, or None if the peer closed the connection first.
        """
        chunks = []
        remaining = length
        while remaining > 0:
            chunk = conn.recv(remaining)
            if not chunk:
                return None
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    def _connectionHandler(self, conn):
        """Answer one length-prefixed AXFR/IXFR query on conn, then close it.

        Malformed or unexpected queries are logged and the connection is
        dropped without a response. The served serial only advances once a
        full response has been written.
        """
        try:
            header = self._recvExactly(conn, 2)
            if not header:
                return
            (datalen,) = struct.unpack("!H", header)
            data = self._recvExactly(conn, datalen)
            if not data:
                return

            message = dns.message.from_wire(data)
            if len(message.question) != 1:
                print('Invalid RPZ query, qdcount is %d' % (len(message.question)))
                return
            qtype = message.question[0].rdtype
            if qtype not in (dns.rdatatype.AXFR, dns.rdatatype.IXFR):
                print('Invalid RPZ query, qtype is %d' % (qtype))
                return
            (serial, answer) = self._getAnswer(message)
            if answer is None:
                print('Unable to get a response for %s %d' % (message.question[0].name, qtype))
                return

            wire = answer.to_wire()
            # sendall() guarantees the whole length-prefixed frame is written.
            conn.sendall(struct.pack("!H", len(wire)) + wire)
            self._currentSerial = serial
        finally:
            conn.close()

    def _listener(self):
        """Accept connections forever, serving each one in its own daemon thread."""
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        try:
            sock.bind(("127.0.0.1", self._serverPort))
        except socket.error as e:
            print("Error binding in the RPZ listener: %s" % str(e))
            sock.close()
            # NOTE: in a non-main thread this only ends the listener thread, not the process.
            sys.exit(1)

        sock.listen(100)
        try:
            while True:
                try:
                    (conn, _) = sock.accept()
                    thread = threading.Thread(name='RPZ Connection Handler',
                                              target=self._connectionHandler,
                                              args=[conn])
                    thread.daemon = True
                    thread.start()

                except socket.error as e:
                    print('Error in RPZ socket: %s' % str(e))
        finally:
            sock.close()
175 except socket.error as e:
176 print('Error in RPZ socket: %s' % str(e))
177 sock.close()
178
179 rpzServerPort = 4250
180 rpzServer = RPZServer(rpzServerPort)
181
182 class RPZRecursorTest(RecursorTest):
183 """
184 This test makes sure that we correctly update RPZ zones via AXFR then IXFR
185 """
186
187 global rpzServerPort
188 _lua_config_file = """
189 -- The first server is a bogus one, to test that we correctly fail over to the second one
190 rpzMaster({'127.0.0.1:9999', '127.0.0.1:%d'}, 'zone.rpz.', { refresh=1 })
191 """ % (rpzServerPort)
192 _wsPort = 8042
193 _wsTimeout = 2
194 _wsPassword = 'secretpassword'
195 _apiKey = 'secretapikey'
196 _confdir = 'RPZ'
197 _lua_dns_script_file = """
198
199 function prerpz(dq)
200 -- disable the RPZ policy named 'zone.rpz' for AD=1 queries
201 if dq:getDH():getAD() then
202 dq:discardPolicy('zone.rpz.')
203 end
204 return false
205 end
206 """
207
208 _config_template = """
209 auth-zones=example=configs/%s/example.zone
210 webserver=yes
211 webserver-port=%d
212 webserver-address=127.0.0.1
213 webserver-password=%s
214 api-key=%s
215 """ % (_confdir, _wsPort, _wsPassword, _apiKey)
216 _xfrDone = 0
217
218 @classmethod
219 def generateRecursorConfig(cls, confdir):
220 authzonepath = os.path.join(confdir, 'example.zone')
221 with open(authzonepath, 'w') as authzone:
222 authzone.write("""$ORIGIN example.
223 @ 3600 IN SOA {soa}
224 a 3600 IN A 192.0.2.42
225 b 3600 IN A 192.0.2.42
226 c 3600 IN A 192.0.2.42
227 d 3600 IN A 192.0.2.42
228 e 3600 IN A 192.0.2.42
229 """.format(soa=cls._SOA))
230 super(RPZRecursorTest, cls).generateRecursorConfig(confdir)
231
232 @classmethod
233 def setUpClass(cls):
234
235 cls.setUpSockets()
236 cls.startResponders()
237
238 confdir = os.path.join('configs', cls._confdir)
239 cls.createConfigDir(confdir)
240
241 cls.generateRecursorConfig(confdir)
242 cls.startRecursor(confdir, cls._recursorPort)
243
244 @classmethod
245 def tearDownClass(cls):
246 cls.tearDownRecursor()
247
248 def checkBlocked(self, name, shouldBeBlocked=True, adQuery=False):
249 query = dns.message.make_query(name, 'A', want_dnssec=True)
250 query.flags |= dns.flags.CD
251 if adQuery:
252 query.flags |= dns.flags.AD
253
254 for method in ("sendUDPQuery", "sendTCPQuery"):
255 sender = getattr(self, method)
256 res = sender(query)
257 self.assertRcodeEqual(res, dns.rcode.NOERROR)
258 if shouldBeBlocked:
259 expected = dns.rrset.from_text(name, 0, dns.rdataclass.IN, 'A', '192.0.2.1')
260 else:
261 expected = dns.rrset.from_text(name, 0, dns.rdataclass.IN, 'A', '192.0.2.42')
262
263 self.assertRRsetInAnswer(res, expected)
264
265 def checkNotBlocked(self, name, adQuery=False):
266 self.checkBlocked(name, False, adQuery)
267
268 def checkCustom(self, qname, qtype, expected):
269 query = dns.message.make_query(qname, qtype, want_dnssec=True)
270 query.flags |= dns.flags.CD
271 for method in ("sendUDPQuery", "sendTCPQuery"):
272 sender = getattr(self, method)
273 res = sender(query)
274 self.assertRcodeEqual(res, dns.rcode.NOERROR)
275 self.assertRRsetInAnswer(res, expected)
276
277 def checkNoData(self, qname, qtype):
278 query = dns.message.make_query(qname, qtype, want_dnssec=True)
279 query.flags |= dns.flags.CD
280 for method in ("sendUDPQuery", "sendTCPQuery"):
281 sender = getattr(self, method)
282 res = sender(query)
283 self.assertRcodeEqual(res, dns.rcode.NOERROR)
284 self.assertEqual(len(res.answer), 0)
285
286 def checkTruncated(self, qname, qtype='A'):
287 query = dns.message.make_query(qname, qtype, want_dnssec=True)
288 query.flags |= dns.flags.CD
289 res = self.sendUDPQuery(query)
290 self.assertRcodeEqual(res, dns.rcode.NOERROR)
291 self.assertMessageHasFlags(res, ['QR', 'RA', 'RD', 'CD', 'TC'])
292 self.assertEqual(len(res.answer), 0)
293 self.assertEqual(len(res.authority), 0)
294 self.assertEqual(len(res.additional), 0)
295
296 res = self.sendTCPQuery(query)
297 self.assertRcodeEqual(res, dns.rcode.NXDOMAIN)
298 self.assertMessageHasFlags(res, ['QR', 'RA', 'RD', 'CD'])
299 self.assertEqual(len(res.answer), 0)
300 self.assertEqual(len(res.authority), 1)
301 self.assertEqual(len(res.additional), 0)
302
303 def checkDropped(self, qname, qtype='A'):
304 query = dns.message.make_query(qname, qtype, want_dnssec=True)
305 query.flags |= dns.flags.CD
306 for method in ("sendUDPQuery", "sendTCPQuery"):
307 sender = getattr(self, method)
308 res = sender(query)
309 self.assertEqual(res, None)
310
311 def waitUntilCorrectSerialIsLoaded(self, serial, timeout=5):
312 global rpzServer
313
314 rpzServer.moveToSerial(serial)
315
316 attempts = 0
317 while attempts < timeout:
318 currentSerial = rpzServer.getCurrentSerial()
319 if currentSerial > serial:
320 raise AssertionError("Expected serial %d, got %d" % (serial, currentSerial))
321 if currentSerial == serial:
322 self._xfrDone = self._xfrDone + 1
323 return
324
325 attempts = attempts + 1
326 time.sleep(1)
327
328 raise AssertionError("Waited %d seconds for the serial to be updated to %d but the serial is still %d" % (timeout, serial, currentSerial))
329
330 def checkRPZStats(self, serial, recordsCount, fullXFRCount, totalXFRCount):
331 headers = {'x-api-key': self._apiKey}
332 url = 'http://127.0.0.1:' + str(self._wsPort) + '/api/v1/servers/localhost/rpzstatistics'
333 r = requests.get(url, headers=headers, timeout=self._wsTimeout)
334 self.assertTrue(r)
335 self.assertEquals(r.status_code, 200)
336 self.assertTrue(r.json())
337 content = r.json()
338 self.assertIn('zone.rpz.', content)
339 zone = content['zone.rpz.']
340 for key in ['last_update', 'records', 'serial', 'transfers_failed', 'transfers_full', 'transfers_success']:
341 self.assertIn(key, zone)
342
343 self.assertEquals(zone['serial'], serial)
344 self.assertEquals(zone['records'], recordsCount)
345 self.assertEquals(zone['transfers_full'], fullXFRCount)
346 self.assertEquals(zone['transfers_success'], totalXFRCount)
347
348 def testRPZ(self):
349 # first zone, only a should be blocked
350 self.waitUntilCorrectSerialIsLoaded(1)
351 self.checkRPZStats(1, 1, 1, self._xfrDone)
352 self.checkBlocked('a.example.')
353 self.checkNotBlocked('b.example.')
354 self.checkNotBlocked('c.example.')
355
356 # second zone, a and b should be blocked
357 self.waitUntilCorrectSerialIsLoaded(2)
358 self.checkRPZStats(2, 2, 1, self._xfrDone)
359 self.checkBlocked('a.example.')
360 self.checkBlocked('b.example.')
361 self.checkNotBlocked('c.example.')
362
363 # third zone, only b should be blocked
364 self.waitUntilCorrectSerialIsLoaded(3)
365 self.checkRPZStats(3, 1, 1, self._xfrDone)
366 self.checkNotBlocked('a.example.')
367 self.checkBlocked('b.example.')
368 self.checkNotBlocked('c.example.')
369
370 # fourth zone, only c should be blocked
371 self.waitUntilCorrectSerialIsLoaded(4)
372 self.checkRPZStats(4, 1, 1, self._xfrDone)
373 self.checkNotBlocked('a.example.')
374 self.checkNotBlocked('b.example.')
375 self.checkBlocked('c.example.')
376
377 # fifth zone, we should get a full AXFR this time, and only d should be blocked
378 self.waitUntilCorrectSerialIsLoaded(5)
379 self.checkRPZStats(5, 3, 2, self._xfrDone)
380 self.checkNotBlocked('a.example.')
381 self.checkNotBlocked('b.example.')
382 self.checkNotBlocked('c.example.')
383 self.checkBlocked('d.example.')
384
385 # sixth zone, only e should be blocked, f is a local data record
386 self.waitUntilCorrectSerialIsLoaded(6)
387 self.checkRPZStats(6, 2, 2, self._xfrDone)
388 self.checkNotBlocked('a.example.')
389 self.checkNotBlocked('b.example.')
390 self.checkNotBlocked('c.example.')
391 self.checkNotBlocked('d.example.')
392 self.checkCustom('e.example.', 'A', dns.rrset.from_text('e.example.', 0, dns.rdataclass.IN, 'A', '192.0.2.1', '192.0.2.2'))
393 self.checkCustom('e.example.', 'MX', dns.rrset.from_text('e.example.', 0, dns.rdataclass.IN, 'MX', '10 mx.example.'))
394 self.checkNoData('e.example.', 'AAAA')
395 self.checkCustom('f.example.', 'A', dns.rrset.from_text('f.example.', 0, dns.rdataclass.IN, 'CNAME', 'e.example.'))
396
397 # seventh zone, e should only have one A
398 self.waitUntilCorrectSerialIsLoaded(7)
399 self.checkRPZStats(7, 4, 2, self._xfrDone)
400 self.checkNotBlocked('a.example.')
401 self.checkNotBlocked('b.example.')
402 self.checkNotBlocked('c.example.')
403 self.checkNotBlocked('d.example.')
404 self.checkCustom('e.example.', 'A', dns.rrset.from_text('e.example.', 0, dns.rdataclass.IN, 'A', '192.0.2.2'))
405 self.checkCustom('e.example.', 'MX', dns.rrset.from_text('e.example.', 0, dns.rdataclass.IN, 'MX', '10 mx.example.'))
406 self.checkNoData('e.example.', 'AAAA')
407 self.checkCustom('f.example.', 'A', dns.rrset.from_text('f.example.', 0, dns.rdataclass.IN, 'CNAME', 'e.example.'))
408 # check that the policy is disabled for AD=1 queries
409 self.checkNotBlocked('e.example.', True)
410 # check non-custom policies
411 self.checkTruncated('tc.example.')
412 self.checkDropped('drop.example.')