]> git.ipfire.org Git - thirdparty/pdns.git/blob - regression-tests.recursor-dnssec/test_RPZ.py
Merge pull request #7026 from jsoref/configure-enable-with
[thirdparty/pdns.git] / regression-tests.recursor-dnssec / test_RPZ.py
1 import dns
2 import json
3 import os
4 import requests
5 import socket
6 import struct
7 import sys
8 import threading
9 import time
10
11 from recursortests import RecursorTest
12
class RPZServer(object):
    """Minimal RPZ master used by the regression test.

    A daemon thread listens on 127.0.0.1:<port> over TCP and answers AXFR and
    IXFR queries with canned contents for 'zone.rpz.'.  The served serial only
    moves forward one step at a time: moveToSerial() sets the target, and the
    current serial is updated once a transfer for that target has been sent.
    """

    # IXFR contents per target serial: (labels removed, labels added).
    # Serial 5 is absent on purpose, it is answered with a full zone instead.
    _IXFR_DELTAS = {
        2: ([], ['b']),
        3: (['a'], []),
        4: (['b'], ['c']),
        6: (['d'], ['e']),
    }

    def __init__(self, port):
        """Remember `port` and start the listener thread."""
        self._currentSerial = 0
        self._targetSerial = 1
        self._serverPort = port
        listener = threading.Thread(name='RPZ Listener', target=self._listener, args=[])
        # the 'daemon' attribute replaces the deprecated setDaemon()
        listener.daemon = True
        listener.start()

    def getCurrentSerial(self):
        """Return the serial of the last transfer that was sent."""
        return self._currentSerial

    def moveToSerial(self, newSerial):
        """Ask the server to serve `newSerial` on the next transfer.

        Returns False if `newSerial` is already the current serial, True
        otherwise.  Raises AssertionError unless `newSerial` is exactly the
        current serial plus one.
        """
        if newSerial == self._currentSerial:
            return False

        if newSerial != self._currentSerial + 1:
            raise AssertionError("Asking the RPZ server to server serial %d, already serving %d" % (newSerial, self._currentSerial))
        self._targetSerial = newSerial
        return True

    @staticmethod
    def _soa(serial):
        """Build the SOA rrset of 'zone.rpz.' for `serial`."""
        return dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % serial)

    @staticmethod
    def _blocked(label):
        """Build the A rrset for '<label>.example.zone.rpz.'."""
        return dns.rrset.from_text('%s.example.zone.rpz.' % label, 60, dns.rdataclass.IN, dns.rdatatype.A, '192.0.2.1')

    def _fullZone(self, serial, labels):
        """Return a full-zone answer: SOA, one A rrset per label, SOA."""
        records = [self._soa(serial)]
        records.extend(self._blocked(label) for label in labels)
        records.append(self._soa(serial))
        return records

    def _incrementalZone(self, oldSerial, newSerial, removed, added):
        """Return an incremental answer going from `oldSerial` to `newSerial`.

        Layout: new SOA, old SOA, removed rrsets, new SOA, added rrsets, new SOA.
        """
        records = [self._soa(newSerial), self._soa(oldSerial)]
        records.extend(self._blocked(label) for label in removed)
        records.append(self._soa(newSerial))
        records.extend(self._blocked(label) for label in added)
        records.append(self._soa(newSerial))
        return records

    def _getAnswer(self, message):
        """Build the response to an AXFR or IXFR query.

        Returns a (serial, response) tuple.  On success `serial` is the serial
        that has been served and `response` the DNS message to send.  When the
        query cannot be answered, `response` is None and `serial` is the
        unchanged current serial.
        """
        response = dns.message.make_response(message)
        records = []
        qtype = message.question[0].rdtype

        if qtype == dns.rdatatype.AXFR:
            if self._currentSerial != 0:
                print('Received an AXFR query but IXFR expected because the current serial is %d' % (self._currentSerial))
                # the caller unpacks (serial, answer), so the missing answer goes second
                return (self._currentSerial, None)

            newSerial = self._targetSerial
            records = self._fullZone(newSerial, ['a'])

        elif qtype == dns.rdatatype.IXFR:
            # the serial known to the client is carried by the SOA in the authority section
            oldSerial = message.authority[0][0].serial

            if oldSerial != self._currentSerial:
                print('Received an IXFR query with an unexpected serial %d, expected %d' % (oldSerial, self._currentSerial))
                return (self._currentSerial, None)

            newSerial = self._targetSerial
            if newSerial == 5:
                # this one is a bit special, we are answering with a full AXFR
                records = self._fullZone(newSerial, ['d'])
            elif newSerial in self._IXFR_DELTAS:
                (removed, added) = self._IXFR_DELTAS[newSerial]
                records = self._incrementalZone(oldSerial, newSerial, removed, added)
            # any other target serial is answered with an empty answer section

        else:
            # not reachable from _connectionHandler, which filters on the query type
            return (self._currentSerial, None)

        response.answer = records
        return (newSerial, response)

    @staticmethod
    def _recvExactly(conn, size):
        """Read exactly `size` bytes from `conn`.

        recv() may return fewer bytes than requested, so keep reading until
        the whole amount has arrived.  Returns None if the peer closes the
        connection before that.
        """
        chunks = []
        remaining = size
        while remaining > 0:
            chunk = conn.recv(remaining)
            if not chunk:
                return None
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    def _connectionHandler(self, conn):
        """Serve a single length-prefixed query on `conn`, then close it."""
        try:
            # every message is preceded by its length as a 16-bit big-endian integer
            header = self._recvExactly(conn, 2)
            if header is None:
                return
            (datalen,) = struct.unpack("!H", header)
            data = self._recvExactly(conn, datalen)
            if not data:
                return

            message = dns.message.from_wire(data)
            if len(message.question) != 1:
                print('Invalid RPZ query, qdcount is %d' % (len(message.question)))
                return
            if message.question[0].rdtype not in [dns.rdatatype.AXFR, dns.rdatatype.IXFR]:
                print('Invalid RPZ query, qtype is %d' % (message.question[0].rdtype))
                return
            (serial, answer) = self._getAnswer(message)
            if answer is None:
                print('Unable to get a response for %s %d' % (message.question[0].name, message.question[0].rdtype))
                return

            wire = answer.to_wire()
            # sendall() retries until everything is written, send() may stop early
            conn.sendall(struct.pack("!H", len(wire)) + wire)
            # only advance the serial once the transfer has actually been sent
            self._currentSerial = serial
        finally:
            conn.close()

    def _listener(self):
        """Accept TCP connections and hand each one to its own daemon thread."""
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        try:
            sock.bind(("127.0.0.1", self._serverPort))
        except socket.error as e:
            print("Error binding in the RPZ listener: %s" % str(e))
            # raised from the listener thread, so this only ends that thread
            sys.exit(1)

        sock.listen(100)
        while True:
            try:
                (conn, _) = sock.accept()
                thread = threading.Thread(name='RPZ Connection Handler',
                                          target=self._connectionHandler,
                                          args=[conn])
                thread.daemon = True
                thread.start()

            except socket.error as e:
                print('Error in RPZ socket: %s' % str(e))
                sock.close()
                # a closed socket cannot accept anymore, stop instead of looping on the error
                return
161
# TCP port on 127.0.0.1 the test RPZ server listens on; it is also
# substituted into the recursor's Lua configuration below.
rpzServerPort = 4250
# Instantiated when the module is loaded: the constructor starts the listener thread.
rpzServer = RPZServer(rpzServerPort)
164
class RPZRecursorTest(RecursorTest):
    """
    This test makes sure that we correctly update RPZ zones via AXFR then IXFR
    """

    _lua_config_file = """
    -- The first server is a bogus one, to test that we correctly fail over to the second one
    rpzMaster({'127.0.0.1:9999', '127.0.0.1:%d'}, 'zone.rpz.', { refresh=1 })
    """ % (rpzServerPort)
    _wsPort = 8042
    _wsTimeout = 2
    _wsPassword = 'secretpassword'
    _apiKey = 'secretapikey'
    _confdir = 'RPZ'
    _config_template = """
auth-zones=example=configs/%s/example.zone
webserver=yes
webserver-port=%d
webserver-address=127.0.0.1
webserver-password=%s
api-key=%s
""" % (_confdir, _wsPort, _wsPassword, _apiKey)
    # number of transfers that waitUntilCorrectSerialIsLoaded() has seen complete
    _xfrDone = 0

    @classmethod
    def generateRecursorConfig(cls, confdir):
        """Write the 'example' zone file into `confdir`, then let the base class generate the rest."""
        authzonepath = os.path.join(confdir, 'example.zone')
        with open(authzonepath, 'w') as authzone:
            # without RPZ filtering every name resolves to 192.0.2.42
            authzone.write("""$ORIGIN example.
@ 3600 IN SOA {soa}
a 3600 IN A 192.0.2.42
b 3600 IN A 192.0.2.42
c 3600 IN A 192.0.2.42
d 3600 IN A 192.0.2.42
e 3600 IN A 192.0.2.42
""".format(soa=cls._SOA))
        super(RPZRecursorTest, cls).generateRecursorConfig(confdir)

    @classmethod
    def setUpClass(cls):
        """Set up sockets and responders, then generate the configuration and start the recursor."""
        cls.setUpSockets()
        cls.startResponders()

        confdir = os.path.join('configs', cls._confdir)
        cls.createConfigDir(confdir)

        cls.generateRecursorConfig(confdir)
        cls.startRecursor(confdir, cls._recursorPort)

    @classmethod
    def tearDownClass(cls):
        """Stop the recursor started in setUpClass()."""
        cls.tearDownRecursor()

    def checkBlocked(self, name, shouldBeBlocked=True):
        """Query `name` for A over UDP and check which address is returned.

        A blocked name must resolve to 192.0.2.1 (the address served by the
        RPZ zone), any other name to 192.0.2.42 (the address in the 'example'
        zone file).
        """
        query = dns.message.make_query(name, 'A', want_dnssec=True)
        query.flags |= dns.flags.CD
        res = self.sendUDPQuery(query)
        address = '192.0.2.1' if shouldBeBlocked else '192.0.2.42'
        expected = dns.rrset.from_text(name, 0, dns.rdataclass.IN, 'A', address)

        self.assertRRsetInAnswer(res, expected)

    def checkNotBlocked(self, name):
        """Check that `name` is not rewritten by the RPZ zone."""
        self.checkBlocked(name, False)

    def waitUntilCorrectSerialIsLoaded(self, serial, timeout=5):
        """Move the RPZ server to `serial` and wait until it has been transferred.

        The server's current serial is polled once per second, at most
        `timeout` times.  Raises AssertionError if the server goes past
        `serial` or if `serial` is not reached in time.
        """
        rpzServer.moveToSerial(serial)

        # set up front so that the final error message works even if the loop never runs
        currentSerial = rpzServer.getCurrentSerial()
        attempts = 0
        while attempts < timeout:
            currentSerial = rpzServer.getCurrentSerial()
            if currentSerial > serial:
                raise AssertionError("Expected serial %d, got %d" % (serial, currentSerial))
            if currentSerial == serial:
                self._xfrDone = self._xfrDone + 1
                return

            attempts = attempts + 1
            time.sleep(1)

        raise AssertionError("Waited %d seconds for the serial to be updated to %d but the serial is still %d" % (timeout, serial, currentSerial))

    def checkRPZStats(self, serial, recordsCount, fullXFRCount, totalXFRCount):
        """Fetch the RPZ statistics from the recursor's web API and compare them with the expected values."""
        headers = {'x-api-key': self._apiKey}
        url = 'http://127.0.0.1:' + str(self._wsPort) + '/api/v1/servers/localhost/rpzstatistics'
        r = requests.get(url, headers=headers, timeout=self._wsTimeout)
        self.assertTrue(r)
        # assertEqual rather than assertEquals, a deprecated alias that is gone in Python 3.12
        self.assertEqual(r.status_code, 200)
        self.assertTrue(r.json())
        content = r.json()
        self.assertIn('zone.rpz.', content)
        zone = content['zone.rpz.']
        for key in ['last_update', 'records', 'serial', 'transfers_failed', 'transfers_full', 'transfers_success']:
            self.assertIn(key, zone)

        self.assertEqual(zone['serial'], serial)
        self.assertEqual(zone['records'], recordsCount)
        self.assertEqual(zone['transfers_full'], fullXFRCount)
        self.assertEqual(zone['transfers_success'], totalXFRCount)

    def testRPZ(self):
        """Walk the zone through serials 1 to 6, checking statistics and blocked names at each step."""
        # first zone, only a should be blocked
        self.waitUntilCorrectSerialIsLoaded(1)
        self.checkRPZStats(1, 1, 1, self._xfrDone)
        self.checkBlocked('a.example.')
        self.checkNotBlocked('b.example.')
        self.checkNotBlocked('c.example.')

        # second zone, a and b should be blocked
        self.waitUntilCorrectSerialIsLoaded(2)
        self.checkRPZStats(2, 2, 1, self._xfrDone)
        self.checkBlocked('a.example.')
        self.checkBlocked('b.example.')
        self.checkNotBlocked('c.example.')

        # third zone, only b should be blocked
        self.waitUntilCorrectSerialIsLoaded(3)
        self.checkRPZStats(3, 1, 1, self._xfrDone)
        self.checkNotBlocked('a.example.')
        self.checkBlocked('b.example.')
        self.checkNotBlocked('c.example.')

        # fourth zone, only c should be blocked
        self.waitUntilCorrectSerialIsLoaded(4)
        self.checkRPZStats(4, 1, 1, self._xfrDone)
        self.checkNotBlocked('a.example.')
        self.checkNotBlocked('b.example.')
        self.checkBlocked('c.example.')

        # fifth zone, we should get a full AXFR this time, and only d should be blocked
        self.waitUntilCorrectSerialIsLoaded(5)
        self.checkRPZStats(5, 1, 2, self._xfrDone)
        self.checkNotBlocked('a.example.')
        self.checkNotBlocked('b.example.')
        self.checkNotBlocked('c.example.')
        self.checkBlocked('d.example.')

        # sixth zone, only e should be blocked
        self.waitUntilCorrectSerialIsLoaded(6)
        self.checkRPZStats(6, 1, 2, self._xfrDone)
        self.checkNotBlocked('a.example.')
        self.checkNotBlocked('b.example.')
        self.checkNotBlocked('c.example.')
        self.checkNotBlocked('d.example.')
        self.checkBlocked('e.example.')