]> git.ipfire.org Git - thirdparty/pdns.git/blob - regression-tests.recursor-dnssec/test_RPZ.py
Merge pull request #7249 from Habbie/gtar
[thirdparty/pdns.git] / regression-tests.recursor-dnssec / test_RPZ.py
1 import dns
2 import json
3 import os
4 import requests
5 import socket
6 import struct
7 import sys
8 import threading
9 import time
10
11 from recursortests import RecursorTest
12
class RPZServer(object):
    """Fake RPZ primary serving the 'zone.rpz.' zone over TCP on 127.0.0.1.

    The zone content is scripted per serial:
    - serial 1 is delivered via AXFR;
    - serials 2, 3, 4 and 6 are delivered as IXFR deltas;
    - serial 5 is delivered as a full zone in answer to an IXFR query.

    Tests advance the zone with moveToSerial() and observe what was actually
    sent with getCurrentSerial().
    """

    def __init__(self, port):
        """Start listening on 127.0.0.1:`port` in a background daemon thread."""
        # Serial of the last transfer actually sent to a client.
        self._currentSerial = 0
        # Serial the next transfer will bring the client to.
        self._targetSerial = 1
        self._serverPort = port
        listener = threading.Thread(name='RPZ Listener', target=self._listener, args=[])
        # Daemon thread: it must not keep the interpreter alive at exit.
        listener.daemon = True
        listener.start()

    def getCurrentSerial(self):
        """Return the serial of the last transfer successfully sent."""
        return self._currentSerial

    def moveToSerial(self, newSerial):
        """Ask the server to serve `newSerial` on the next transfer.

        Returns False if `newSerial` is already the current serial, True once
        the target has been updated.

        Raises AssertionError unless `newSerial` is exactly current + 1.
        """
        if newSerial == self._currentSerial:
            return False

        if newSerial != self._currentSerial + 1:
            raise AssertionError("Asking the RPZ server to server serial %d, already serving %d" % (newSerial, self._currentSerial))
        self._targetSerial = newSerial
        return True

    @staticmethod
    def _soa(serial):
        """Return the zone.rpz. SOA RRset carrying `serial`."""
        return dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA,
                                   'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % serial)

    @staticmethod
    def _trigger(name):
        """Return an A RRset for the RPZ trigger `name`, pointing at 192.0.2.1."""
        return dns.rrset.from_text(name, 60, dns.rdataclass.IN, dns.rdatatype.A, '192.0.2.1')

    @classmethod
    def _fullZone(cls, serial, names):
        """Return a full-zone record sequence: SOA, one trigger per name, SOA."""
        return [cls._soa(serial)] + [cls._trigger(name) for name in names] + [cls._soa(serial)]

    @classmethod
    def _ixfrDelta(cls, oldSerial, newSerial, removed, added):
        """Return an IXFR record sequence moving from `oldSerial` to `newSerial`.

        The layout is: SOA(new), SOA(old), removed triggers, SOA(new),
        added triggers, SOA(new).
        """
        records = [cls._soa(newSerial), cls._soa(oldSerial)]
        records.extend(cls._trigger(name) for name in removed)
        records.append(cls._soa(newSerial))
        records.extend(cls._trigger(name) for name in added)
        records.append(cls._soa(newSerial))
        return records

    def _getAnswer(self, message):
        """Build the transfer response for the AXFR/IXFR query `message`.

        Returns a (serial, response) tuple. When the query is unexpected,
        response is None and serial is the unchanged current serial.
        """
        response = dns.message.make_response(message)
        records = []
        qtype = message.question[0].rdtype

        if qtype == dns.rdatatype.AXFR:
            # Only the very first transfer is allowed to be an AXFR.
            if self._currentSerial != 0:
                print('Received an AXFR query but IXFR expected because the current serial is %d' % (self._currentSerial))
                return (self._currentSerial, None)

            newSerial = self._targetSerial
            records = self._fullZone(newSerial, ['a.example.zone.rpz.'])

        elif qtype == dns.rdatatype.IXFR:
            # The client's current serial is carried in the authority-section SOA.
            if not message.authority:
                print('Received an IXFR query without a SOA in the authority section')
                return (self._currentSerial, None)
            oldSerial = message.authority[0][0].serial

            if oldSerial != self._currentSerial:
                print('Received an IXFR query with an unexpected serial %d, expected %d' % (oldSerial, self._currentSerial))
                return (self._currentSerial, None)

            newSerial = self._targetSerial
            if newSerial == 2:
                # no deletion, add b
                records = self._ixfrDelta(oldSerial, newSerial, [], ['b.example.zone.rpz.'])
            elif newSerial == 3:
                # delete a, no addition
                records = self._ixfrDelta(oldSerial, newSerial, ['a.example.zone.rpz.'], [])
            elif newSerial == 4:
                # delete b, add c
                records = self._ixfrDelta(oldSerial, newSerial, ['b.example.zone.rpz.'], ['c.example.zone.rpz.'])
            elif newSerial == 5:
                # this one is a bit special, we are answering with a full AXFR
                records = self._fullZone(newSerial, ['d.example.zone.rpz.'])
            elif newSerial == 6:
                # back to IXFR: delete d, add e
                records = self._ixfrDelta(oldSerial, newSerial, ['d.example.zone.rpz.'], ['e.example.zone.rpz.'])
            # Any other target serial is answered with an empty answer section.

        else:
            # The connection handler filters other qtypes out; be explicit anyway.
            return (self._currentSerial, None)

        response.answer = records
        return (newSerial, response)

    @staticmethod
    def _recvExactly(conn, count):
        """Read exactly `count` bytes from `conn`.

        Returns None if the peer closes the connection first. recv() on a
        stream socket may return fewer bytes than requested, hence the loop.
        """
        chunks = []
        remaining = count
        while remaining > 0:
            chunk = conn.recv(remaining)
            if not chunk:
                return None
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    def _connectionHandler(self, conn):
        """Serve a single length-prefixed transfer query on `conn`, then close it."""
        try:
            header = self._recvExactly(conn, 2)
            if not header:
                return
            (datalen,) = struct.unpack("!H", header)
            data = self._recvExactly(conn, datalen)
            if not data:
                return

            message = dns.message.from_wire(data)
            if len(message.question) != 1:
                print('Invalid RPZ query, qdcount is %d' % (len(message.question)))
                return
            if message.question[0].rdtype not in [dns.rdatatype.AXFR, dns.rdatatype.IXFR]:
                print('Invalid RPZ query, qtype is %d' % (message.question[0].rdtype))
                return
            (serial, answer) = self._getAnswer(message)
            if answer is None:
                print('Unable to get a response for %s %d' % (message.question[0].name, message.question[0].rdtype))
                return

            wire = answer.to_wire()
            conn.sendall(struct.pack("!H", len(wire)) + wire)
            # Only record the new serial once the answer has been sent.
            self._currentSerial = serial
        finally:
            conn.close()

    def _listener(self):
        """Accept connections, serving each one on its own daemon thread."""
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        try:
            sock.bind(("127.0.0.1", self._serverPort))
        except socket.error as e:
            print("Error binding in the RPZ listener: %s" % str(e))
            sock.close()
            # NOTE: this runs in the listener thread, so sys.exit() only ends
            # this thread; it does not terminate the test process.
            sys.exit(1)

        sock.listen(100)
        try:
            while True:
                try:
                    (conn, _) = sock.accept()
                except socket.error as e:
                    # Stop serving instead of retrying accept() forever.
                    print('Error in RPZ socket: %s' % str(e))
                    break
                thread = threading.Thread(name='RPZ Connection Handler',
                                          target=self._connectionHandler,
                                          args=[conn])
                thread.daemon = True
                thread.start()
        finally:
            sock.close()
161
# TCP port of the fake RPZ primary; it is also interpolated into the
# recursor's Lua configuration as the second rpzMaster address.
rpzServerPort = 4250
# Created at import time, so the listener thread is already running before
# any test class starts the recursor.
rpzServer = RPZServer(port=rpzServerPort)
164
class RPZRecursorTest(RecursorTest):
    """
    This test makes sure that we correctly update RPZ zones via AXFR then IXFR
    """

    # rpzServerPort is a module-level global; reading it here needs no
    # `global` declaration.
    _lua_config_file = """
-- The first server is a bogus one, to test that we correctly fail over to the second one
rpzMaster({'127.0.0.1:9999', '127.0.0.1:%d'}, 'zone.rpz.', { refresh=1 })
""" % (rpzServerPort)
    _wsPort = 8042
    _wsTimeout = 2
    _wsPassword = 'secretpassword'
    _apiKey = 'secretapikey'
    _confdir = 'RPZ'
    _lua_dns_script_file = """

function prerpz(dq)
-- disable the RPZ policy named 'zone.rpz' for AD=1 queries
if dq:getDH():getAD() then
dq:discardPolicy('zone.rpz.')
end
return false
end
"""

    _config_template = """
auth-zones=example=configs/%s/example.zone
webserver=yes
webserver-port=%d
webserver-address=127.0.0.1
webserver-password=%s
api-key=%s
""" % (_confdir, _wsPort, _wsPassword, _apiKey)
    # Number of transfers observed by waitUntilCorrectSerialIsLoaded(); compared
    # against the recursor's 'transfers_success' statistic.
    _xfrDone = 0

    @classmethod
    def generateRecursorConfig(cls, confdir):
        """Write the 'example.' auth zone (every name at 192.0.2.42), then the base config."""
        authzonepath = os.path.join(confdir, 'example.zone')
        with open(authzonepath, 'w') as authzone:
            authzone.write("""$ORIGIN example.
@ 3600 IN SOA {soa}
a 3600 IN A 192.0.2.42
b 3600 IN A 192.0.2.42
c 3600 IN A 192.0.2.42
d 3600 IN A 192.0.2.42
e 3600 IN A 192.0.2.42
""".format(soa=cls._SOA))
        super(RPZRecursorTest, cls).generateRecursorConfig(confdir)

    @classmethod
    def setUpClass(cls):
        """Start the responders, generate the configuration and launch the recursor."""
        cls.setUpSockets()
        cls.startResponders()

        confdir = os.path.join('configs', cls._confdir)
        cls.createConfigDir(confdir)

        cls.generateRecursorConfig(confdir)
        cls.startRecursor(confdir, cls._recursorPort)

    @classmethod
    def tearDownClass(cls):
        """Stop the recursor."""
        cls.tearDownRecursor()

    def checkBlocked(self, name, shouldBeBlocked=True, adQuery=False):
        """Query `name` for A and check which address comes back.

        A blocked name must resolve to the RPZ policy address 192.0.2.1. An
        unblocked name must resolve to its auth-zone address 192.0.2.42.
        `adQuery` sets the AD bit, which the Lua prerpz hook uses to discard
        the zone.rpz. policy.
        """
        query = dns.message.make_query(name, 'A', want_dnssec=True)
        query.flags |= dns.flags.CD
        if adQuery:
            query.flags |= dns.flags.AD
        res = self.sendUDPQuery(query)
        address = '192.0.2.1' if shouldBeBlocked else '192.0.2.42'
        expected = dns.rrset.from_text(name, 0, dns.rdataclass.IN, 'A', address)

        self.assertRRsetInAnswer(res, expected)

    def checkNotBlocked(self, name, adQuery=False):
        """Assert that `name` resolves to its unfiltered auth-zone address."""
        self.checkBlocked(name, False, adQuery)

    def waitUntilCorrectSerialIsLoaded(self, serial, timeout=5):
        """Move the RPZ server to `serial`, then poll once per second until it has sent it.

        Polls up to `timeout` times. Raises AssertionError if the server's
        serial overshoots `serial` or is not reached in time.
        """
        rpzServer.moveToSerial(serial)

        # Initialised before the loop so the final error message also works
        # when timeout <= 0.
        currentSerial = rpzServer.getCurrentSerial()
        attempts = 0
        while attempts < timeout:
            currentSerial = rpzServer.getCurrentSerial()
            if currentSerial > serial:
                raise AssertionError("Expected serial %d, got %d" % (serial, currentSerial))
            if currentSerial == serial:
                self._xfrDone = self._xfrDone + 1
                return

            attempts = attempts + 1
            time.sleep(1)

        raise AssertionError("Waited %d seconds for the serial to be updated to %d but the serial is still %d" % (timeout, serial, currentSerial))

    def checkRPZStats(self, serial, recordsCount, fullXFRCount, totalXFRCount):
        """Compare the recursor's zone.rpz. statistics (web API) to the expected values."""
        headers = {'x-api-key': self._apiKey}
        url = 'http://127.0.0.1:' + str(self._wsPort) + '/api/v1/servers/localhost/rpzstatistics'
        r = requests.get(url, headers=headers, timeout=self._wsTimeout)
        self.assertTrue(r)
        self.assertEqual(r.status_code, 200)
        content = r.json()
        self.assertTrue(content)
        self.assertIn('zone.rpz.', content)
        zone = content['zone.rpz.']
        for key in ['last_update', 'records', 'serial', 'transfers_failed', 'transfers_full', 'transfers_success']:
            self.assertIn(key, zone)

        self.assertEqual(zone['serial'], serial)
        self.assertEqual(zone['records'], recordsCount)
        self.assertEqual(zone['transfers_full'], fullXFRCount)
        self.assertEqual(zone['transfers_success'], totalXFRCount)

    def testRPZ(self):
        """Walk the RPZ zone through serials 1 to 6 and check filtering after each step."""
        # first zone, only a should be blocked
        self.waitUntilCorrectSerialIsLoaded(1)
        self.checkRPZStats(1, 1, 1, self._xfrDone)
        self.checkBlocked('a.example.')
        self.checkNotBlocked('b.example.')
        self.checkNotBlocked('c.example.')

        # second zone, a and b should be blocked
        self.waitUntilCorrectSerialIsLoaded(2)
        self.checkRPZStats(2, 2, 1, self._xfrDone)
        self.checkBlocked('a.example.')
        self.checkBlocked('b.example.')
        self.checkNotBlocked('c.example.')

        # third zone, only b should be blocked
        self.waitUntilCorrectSerialIsLoaded(3)
        self.checkRPZStats(3, 1, 1, self._xfrDone)
        self.checkNotBlocked('a.example.')
        self.checkBlocked('b.example.')
        self.checkNotBlocked('c.example.')

        # fourth zone, only c should be blocked
        self.waitUntilCorrectSerialIsLoaded(4)
        self.checkRPZStats(4, 1, 1, self._xfrDone)
        self.checkNotBlocked('a.example.')
        self.checkNotBlocked('b.example.')
        self.checkBlocked('c.example.')

        # fifth zone, we should get a full AXFR this time, and only d should be blocked
        self.waitUntilCorrectSerialIsLoaded(5)
        self.checkRPZStats(5, 1, 2, self._xfrDone)
        self.checkNotBlocked('a.example.')
        self.checkNotBlocked('b.example.')
        self.checkNotBlocked('c.example.')
        self.checkBlocked('d.example.')

        # sixth zone, only e should be blocked
        self.waitUntilCorrectSerialIsLoaded(6)
        self.checkRPZStats(6, 1, 2, self._xfrDone)
        self.checkNotBlocked('a.example.')
        self.checkNotBlocked('b.example.')
        self.checkNotBlocked('c.example.')
        self.checkNotBlocked('d.example.')
        self.checkBlocked('e.example.')
        # check that the policy is disabled for AD=1 queries
        self.checkNotBlocked('e.example.', True)
328 self.checkBlocked('e.example.')
329 # check that the policy is disabled for AD=1 queries
330 self.checkNotBlocked('e.example.', True)