]>
Commit | Line | Data |
---|---|---|
8bed4b38 OM |
1 | import dns |
2 | import json | |
3 | import os | |
4 | import requests | |
5 | import socket | |
6 | import struct | |
7 | import sys | |
8 | import threading | |
9 | import time | |
10 | ||
11 | from recursortests import RecursorTest | |
12 | ||
class BadRPZServer(object):
    """Deliberately misbehaving RPZ primary for ``zone.rpz.`` used by the incomplete-XFR tests.

    Listens on TCP 127.0.0.1:<port> and answers a single length-prefixed
    AXFR/IXFR query per connection. Which serial is served is driven by
    :meth:`moveToSerial`. The IXFR answers for serials 2 and 3 are
    intentionally truncated so that the recursor has to reject them.
    """

    def __init__(self, port):
        # Serial of the last answer actually written to a client.
        self._currentSerial = 0
        # Serial that the next transfer should move the client to.
        self._targetSerial = 1
        self._serverPort = port
        # Thread.setDaemon() is deprecated; pass daemon= to the constructor instead.
        listener = threading.Thread(name='RPZ Listener', target=self._listener, args=[], daemon=True)
        listener.start()

    def getCurrentSerial(self):
        """Return the serial of the last answer successfully sent to a client."""
        return self._currentSerial

    def moveToSerial(self, newSerial):
        """Make subsequent transfers serve *newSerial*.

        Returns False (and changes nothing) if *newSerial* is already the
        current serial, True otherwise. Non-sequential jumps are accepted.
        """
        if newSerial == self._currentSerial:
            return False

        self._targetSerial = newSerial
        return True

    @staticmethod
    def _makeSOA(serial):
        """Return the ``zone.rpz.`` SOA RRset carrying *serial*."""
        return dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA,
                                   'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % serial)

    def _getAnswer(self, message):
        """Build the transfer response for *message*.

        Returns a ``(serial, response)`` tuple. ``serial`` is the serial the
        client will have after receiving ``response``. ``response`` is None when
        the query must not be answered; ``serial`` is then the unchanged
        current serial.
        """
        response = dns.message.make_response(message)
        records = []
        newSerial = self._targetSerial
        rdtype = message.question[0].rdtype

        if rdtype == dns.rdatatype.AXFR:
            if self._currentSerial != 0:
                print('Received an AXFR query but IXFR expected because the current serial is %d' % (self._currentSerial))
                # The caller unpacks (serial, answer): the answer must be the None element.
                return (self._currentSerial, None)

            records = [
                self._makeSOA(newSerial),
                dns.rrset.from_text('a.example.zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.A, '192.0.2.1'),
                self._makeSOA(newSerial),
            ]

        elif rdtype == dns.rdatatype.IXFR:
            # The client's current serial comes from the SOA in the authority section.
            oldSerial = message.authority[0][0].serial

            if newSerial == 2:
                # Incomplete IXFR: the addition of b.example is not followed by the closing SOA.
                records = [
                    self._makeSOA(newSerial),
                    self._makeSOA(oldSerial),
                    # no deletion
                    self._makeSOA(newSerial),
                    dns.rrset.from_text('b.example.zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.A, '192.0.2.1'),
                ]
            elif newSerial == 3:
                # A leading SOA followed by plain records and no trailing SOA
                # (the test treats this as an incomplete AXFR-style answer).
                records = [
                    self._makeSOA(newSerial),
                    dns.rrset.from_text('a.example.zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.A, '192.0.2.1'),
                ]
            # Any other target serial yields an empty answer section.

        response.answer = records
        return (newSerial, response)

    @staticmethod
    def _recvExactly(conn, length):
        """Read exactly *length* bytes from *conn*.

        TCP may deliver fewer bytes than requested per recv() call, so keep
        reading. Returns None if the peer closes the connection first.
        """
        chunks = []
        remaining = length
        while remaining > 0:
            chunk = conn.recv(remaining)
            if not chunk:
                return None
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    def _connectionHandler(self, conn):
        """Serve one length-prefixed DNS transfer query on *conn*, then close it."""
        try:
            header = self._recvExactly(conn, 2)
            if not header:
                return
            (datalen,) = struct.unpack("!H", header)
            data = self._recvExactly(conn, datalen)
            if not data:
                return

            message = dns.message.from_wire(data)
            if len(message.question) != 1:
                print('Invalid RPZ query, qdcount is %d' % (len(message.question)), file=sys.stderr)
                return
            question = message.question[0]
            if question.rdtype not in (dns.rdatatype.AXFR, dns.rdatatype.IXFR):
                print('Invalid RPZ query, qtype is %d' % (question.rdtype), file=sys.stderr)
                return
            (serial, answer) = self._getAnswer(message)
            if answer is None:
                print('Unable to get a response for %s %d' % (question.name, question.rdtype), file=sys.stderr)
                return

            wire = answer.to_wire()
            # sendall() guarantees the whole length-prefixed message is written.
            conn.sendall(struct.pack("!H", len(wire)) + wire)
            self._currentSerial = serial
        finally:
            conn.close()

    def _listener(self):
        """Accept connections forever, handling each one in its own daemon thread."""
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        try:
            sock.bind(("127.0.0.1", self._serverPort))
        except socket.error as e:
            print("Error binding in the RPZ listener: %s" % str(e), file=sys.stderr)
            sock.close()
            # NOTE: in a non-main thread this only terminates the listener thread.
            sys.exit(1)

        sock.listen(100)
        while True:
            try:
                (conn, _) = sock.accept()
            except socket.error as e:
                print('Error in RPZ socket: %s' % str(e), file=sys.stderr)
                sock.close()
                # The listening socket is unusable now; looping would spin on accept() forever.
                break

            thread = threading.Thread(name='RPZ Connection Handler',
                                      target=self._connectionHandler,
                                      args=[conn],
                                      daemon=True)
            thread.start()
126 | ||
class RPZIncompleteRecursorTest(RecursorTest):
    """Shared settings and helpers for the incomplete RPZ transfer tests."""

    _wsPort = 8042
    _wsTimeout = 2
    _wsPassword = 'secretpassword'
    _apiKey = 'secretapikey'
    _confdir = 'RPZIncomplete'
    _auth_zones = {
        '8': {'threads': 1,
              'zones': ['ROOT']},
        '10': {'threads': 1,
               'zones': ['example']},
    }

    _config_template = """
auth-zones=example=configs/%s/example.zone
webserver=yes
webserver-port=%d
webserver-address=127.0.0.1
webserver-password=%s
api-key=%s
log-rpz-changes=yes
""" % (_confdir, _wsPort, _wsPassword, _apiKey)

    def checkRPZStats(self, serial, recordsCount, fullXFRCount, totalXFRCount, failedXFRCount):
        """Fetch the recursor's RPZ statistics and check the ``zone.rpz.`` counters.

        :param serial: expected loaded serial
        :param recordsCount: expected number of records in the zone
        :param fullXFRCount: expected number of full (AXFR) transfers
        :param totalXFRCount: expected number of successful transfers
        :param failedXFRCount: expected number of failed transfers
        """
        url = 'http://127.0.0.1:%s/api/v1/servers/localhost/rpzstatistics' % self._wsPort
        response = requests.get(url, headers={'x-api-key': self._apiKey}, timeout=self._wsTimeout)
        self.assertTrue(response)
        self.assertEqual(response.status_code, 200)

        stats = response.json()
        self.assertTrue(stats)
        self.assertIn('zone.rpz.', stats)
        zone = stats['zone.rpz.']
        for field in ('last_update', 'records', 'serial', 'transfers_failed', 'transfers_full', 'transfers_success'):
            self.assertIn(field, zone)

        expectations = (
            ('serial', serial),
            ('records', recordsCount),
            ('transfers_full', fullXFRCount),
            ('transfers_success', totalXFRCount),
            ('transfers_failed', failedXFRCount),
        )
        for field, wanted in expectations:
            self.assertEqual(zone[field], wanted)
168 | ||
# One shared bad RPZ server for the whole module; constructing it starts its
# listener thread at import time.
badrpzServerPort = 4251
badrpzServer = BadRPZServer(port=badrpzServerPort)
171 | ||
class RPZXFRIncompleteRecursorTest(RPZIncompleteRecursorTest):
    """
    This test makes sure that we correctly detect incomplete RPZ zones via AXFR then IXFR
    """

    # Reading the module-level port at class-body time needs no `global` statement.
    _lua_config_file = """
    -- The first server is a bogus one, to test that we correctly fail over to the second one
    rpzMaster({'127.0.0.1:9999', '127.0.0.1:%d'}, 'zone.rpz.', { refresh=1 })
    """ % (badrpzServerPort)
    _confdir = 'RPZXFRIncomplete'
    _wsPort = 8042
    _wsTimeout = 2
    _wsPassword = 'secretpassword'
    _apiKey = 'secretapikey'
    _config_template = """
auth-zones=example=configs/%s/example.zone
webserver=yes
webserver-port=%d
webserver-address=127.0.0.1
webserver-password=%s
api-key=%s
""" % (_confdir, _wsPort, _wsPassword, _apiKey)

    @classmethod
    def generateRecursorConfig(cls, confdir):
        """Write the ``example.`` auth zone into *confdir*, then run the base config generation."""
        authzonepath = os.path.join(confdir, 'example.zone')
        with open(authzonepath, 'w') as authzone:
            authzone.write("""$ORIGIN example.
@ 3600 IN SOA {soa}
a 3600 IN A 192.0.2.42
b 3600 IN A 192.0.2.42
c 3600 IN A 192.0.2.42
d 3600 IN A 192.0.2.42
e 3600 IN A 192.0.2.42
""".format(soa=cls._SOA))
        # Zero-argument super() follows the normal MRO; the previous
        # super(RPZIncompleteRecursorTest, cls) skipped the direct base class.
        super().generateRecursorConfig(confdir)

    def waitUntilCorrectSerialIsLoaded(self, serial, timeout=5):
        """Ask the bad RPZ server to serve *serial* and wait until it has been transferred.

        Polls the server once per second, for up to *timeout* seconds, with a final
        check after the last sleep.

        :raises AssertionError: if the served serial overshoots *serial*, or the
            serial has not been reached after *timeout* seconds.
        """
        badrpzServer.moveToSerial(serial)

        attempts = max(timeout, 0)
        for attempt in range(attempts + 1):
            # Always assigned at least once, so the final error message can use it.
            currentSerial = badrpzServer.getCurrentSerial()
            if currentSerial > serial:
                raise AssertionError("Expected serial %d, got %d" % (serial, currentSerial))
            if currentSerial == serial:
                return
            if attempt < attempts:
                time.sleep(1)

        raise AssertionError("Waited %d seconds for the serial to be updated to %d but the serial is still %d" % (timeout, serial, currentSerial))

    def testRPZ(self):
        self.waitForTCPSocket("127.0.0.1", self._wsPort)
        # First zone
        self.waitUntilCorrectSerialIsLoaded(1)
        self.checkRPZStats(1, 1, 1, 1, 1) # failure count includes a port 9999 attempt

        # second zone, should fail, incomplete IXFR
        self.waitUntilCorrectSerialIsLoaded(2)
        self.checkRPZStats(1, 1, 1, 1, 3)

        # third zone, should fail, incomplete AXFR
        self.waitUntilCorrectSerialIsLoaded(3)
        self.checkRPZStats(1, 1, 1, 1, 5)
241 |