]>
Commit | Line | Data |
---|---|---|
1bc56192 CHB |
1 | #!/usr/bin/env python2 |
2 | ||
7a0ea291 | 3 | from __future__ import print_function |
1bc56192 CHB |
4 | import errno |
5 | import shutil | |
6 | import os | |
7 | import socket | |
8 | import struct | |
9 | import subprocess | |
10 | import sys | |
11 | import time | |
12 | import unittest | |
13 | import dns | |
14 | import dns.message | |
15 | ||
16 | from pprint import pprint | |
b7f9136f | 17 | from eqdnsmessage import AssertEqualDNSMessageMixin |
1bc56192 | 18 | |
class AuthTest(AssertEqualDNSMessageMixin, unittest.TestCase):
    """
    Base class for tests against a PowerDNS authoritative server.

    setUpClass() opens a UDP socket connected to <_PREFIX>.1:_authPort, calls
    the startResponders() hook, writes a fresh configuration directory under
    configs/<_confdir> (pdns.conf, a bind-backend named.conf, one zone file per
    entry in _zones, optional imported keys from _zone_keys) and launches
    pdns_server on <_PREFIX>.1:_authPort. tearDownClass() stops the server and
    calls the tearDownResponders() hook.
    """

    # Directory (below configs/) holding the generated configuration.
    _confdir = 'auth'
    # Port the server is told to listen on; used for UDP and TCP queries.
    _authPort = 5300

    # Names of class attributes whose values are %-interpolated, in order,
    # into _config_template.
    _config_params = []

    # Base pdns.conf content; {confdir} and {bind_dnssec_db} are filled in by
    # generateAuthConfig().
    _config_template_default = """
module-dir=../regression-tests/modules
daemon=no
bind-config={confdir}/named.conf
bind-dnssec-db={bind_dnssec_db}
socket-dir={confdir}
cache-ttl=0
negquery-cache-ttl=0
query-cache-ttl=0
log-dns-queries=yes
log-dns-details=yes
loglevel=9
distributor-threads=1"""

    # Extra pdns.conf content written after _config_template_default.
    _config_template = ""

    _root_DS = "63149 13 1 a59da3f5c1b97fcd5fa2b3b2b0ac91d38a60d33a"

    # The default SOA for zones in the authoritative servers
    _SOA = "ns1.example.net. hostmaster.example.net. 1 3600 1800 1209600 300"

    # The definitions of the zones on the authoritative servers, the key is the
    # zonename and the value is the zonefile content. Several strings are replaced:
    # - {soa} => value of _SOA
    # - {prefix} => value of _PREFIX
    # A key of 'ROOT' is served as the root zone '.'.
    _zones = {
        'example.org': """
example.org.                 3600 IN SOA  {soa}
example.org.                 3600 IN NS   ns1.example.org.
example.org.                 3600 IN NS   ns2.example.org.
ns1.example.org.             3600 IN A    {prefix}.10
ns2.example.org.             3600 IN A    {prefix}.11
""",
    }

    # Key material per zone (same keys as _zones); each entry is imported with
    # 'pdnsutil import-zone-key <zone> <file> active ksk'.
    _zone_keys = {
        'example.org': """
Private-key-format: v1.2
Algorithm: 13 (ECDSAP256SHA256)
PrivateKey: Lt0v0Gol3pRUFM7fDdcy0IWN0O/MnEmVPA+VylL8Y4U=
""",
    }

    _auth_cmd = ['authbind',
                 os.environ['PDNS']]
    # Environment passed to the server process (replaces the runner's environment).
    _auth_env = {}
    # Running server processes, keyed by the address they were told to listen on.
    _auths = {}

    _PREFIX = os.environ['PREFIX']

    @classmethod
    def createConfigDir(cls, confdir):
        """(Re)create confdir as an empty directory with mode 0755."""
        try:
            shutil.rmtree(confdir)
        except OSError as e:
            # A directory that does not exist yet is fine; anything else is not.
            if e.errno != errno.ENOENT:
                raise
        os.mkdir(confdir, 0o755)

    @classmethod
    def generateAuthZone(cls, confdir, zonename, zonecontent):
        """Write <confdir>/<zonename>.zone, substituting {prefix} and {soa}."""
        with open(os.path.join(confdir, '%s.zone' % zonename), 'w') as zonefile:
            zonefile.write(zonecontent.format(prefix=cls._PREFIX, soa=cls._SOA))

    @classmethod
    def generateAuthNamedConf(cls, confdir, zones):
        """Write <confdir>/named.conf declaring every zone in zones as master.

        The zone name 'ROOT' is declared as '.', while its file stays ROOT.zone."""
        with open(os.path.join(confdir, 'named.conf'), 'w') as namedconf:
            namedconf.write("""
options {
    directory "%s";
};""" % confdir)
            for zonename in zones:
                zone = '.' if zonename == 'ROOT' else zonename

                namedconf.write("""
        zone "%s" {
            type master;
            file "%s.zone";
        };""" % (zone, zonename))

    @classmethod
    def _runPdnsutil(cls, pdnsutilCmd):
        """Echo and run a pdnsutil command line.

        @param pdnsutilCmd: the command as a list of arguments
        @raises AssertionError: when the command exits non-zero; the message
                                carries the exit code and combined stdout/stderr"""
        print(' '.join(pdnsutilCmd))
        try:
            subprocess.check_output(pdnsutilCmd, stderr=subprocess.STDOUT)
        except subprocess.CalledProcessError as e:
            raise AssertionError('%s failed (%d): %s' % (pdnsutilCmd, e.returncode, e.output))

    @classmethod
    def generateAuthConfig(cls, confdir):
        """Write <confdir>/pdns.conf and create the bind-backend DNSSEC database
        at <confdir>/bind-dnssec.sqlite3 with 'pdnsutil create-bind-db'."""
        bind_dnssec_db = os.path.join(confdir, 'bind-dnssec.sqlite3')

        params = tuple(getattr(cls, param) for param in cls._config_params)

        with open(os.path.join(confdir, 'pdns.conf'), 'w') as pdnsconf:
            pdnsconf.write(cls._config_template_default.format(
                confdir=confdir, prefix=cls._PREFIX,
                bind_dnssec_db=bind_dnssec_db))
            pdnsconf.write(cls._config_template % params)

        cls._runPdnsutil([os.environ['PDNSUTIL'],
                          '--config-dir=%s' % confdir,
                          'create-bind-db',
                          bind_dnssec_db])

    @classmethod
    def secureZone(cls, confdir, zonename, key=None):
        """Enable DNSSEC for a zone via pdnsutil.

        @param confdir: configuration directory passed as --config-dir
        @param zonename: zone name; 'ROOT' means the root zone '.'
        @param key: when falsy, 'pdnsutil secure-zone' is run; otherwise the key
                    text is written to <confdir>/dnssec.key and imported as an
                    active KSK with 'pdnsutil import-zone-key'"""
        zone = '.' if zonename == 'ROOT' else zonename
        pdnsutilCmd = [os.environ['PDNSUTIL'],
                       '--config-dir=%s' % confdir]
        if not key:
            pdnsutilCmd += ['secure-zone', zone]
        else:
            keyfile = os.path.join(confdir, 'dnssec.key')
            with open(keyfile, 'w') as fdKeyfile:
                fdKeyfile.write(key)
            pdnsutilCmd += ['import-zone-key', zone, keyfile, 'active', 'ksk']

        cls._runPdnsutil(pdnsutilCmd)

    @classmethod
    def generateAllAuthConfig(cls, confdir):
        """Generate pdns.conf, named.conf, every zone file and the zone keys.

        Does nothing when _zones is empty."""
        if not cls._zones:
            return

        cls.generateAuthConfig(confdir)
        cls.generateAuthNamedConf(confdir, cls._zones.keys())

        for zonename, zonecontent in cls._zones.items():
            cls.generateAuthZone(confdir, zonename, zonecontent)
            key = cls._zone_keys.get(zonename)
            if key:
                cls.secureZone(confdir, zonename, key)

    @classmethod
    def startAuth(cls, confdir, ipaddress):
        """Launch the server on ipaddress:_authPort, logging to <confdir>/pdns.log.

        Waits 2 seconds; if the process has already exited by then, its log is
        printed and the whole test run exits with the process' return code."""
        print("Launching pdns_server..")
        authcmd = list(cls._auth_cmd)
        authcmd.append('--config-dir=%s' % confdir)
        authcmd.append('--local-address=%s' % ipaddress)
        authcmd.append('--local-port=%s' % cls._authPort)
        authcmd.append('--loglevel=9')
        authcmd.append('--enable-lua-records')
        print(' '.join(authcmd))

        logFile = os.path.join(confdir, 'pdns.log')
        with open(logFile, 'w') as fdLog:
            cls._auths[ipaddress] = subprocess.Popen(authcmd, close_fds=True,
                                                     stdout=fdLog, stderr=fdLog,
                                                     env=cls._auth_env)

        # Give the server time to start up (or to fail doing so).
        time.sleep(2)

        auth = cls._auths[ipaddress]
        if auth.poll() is not None:
            try:
                auth.kill()
            except OSError as e:
                if e.errno != errno.ESRCH:
                    raise
            with open(logFile, 'r') as fdLog:
                print(fdLog.read())
            sys.exit(auth.returncode)

    @classmethod
    def setUpSockets(cls):
        """Create cls._sock: a UDP socket connected to <_PREFIX>.1:_authPort
        with a 2 second timeout, used by sendUDPQuery()."""
        print("Setting up UDP socket..")
        cls._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        cls._sock.settimeout(2.0)
        cls._sock.connect((cls._PREFIX + ".1", cls._authPort))

    @classmethod
    def startResponders(cls):
        """Hook for subclasses that need extra daemons; called by setUpClass()."""
        pass

    @classmethod
    def setUpClass(cls):
        cls.setUpSockets()

        cls.startResponders()

        confdir = os.path.join('configs', cls._confdir)
        cls.createConfigDir(confdir)

        cls.generateAllAuthConfig(confdir)
        cls.startAuth(confdir, cls._PREFIX + ".1")

        print("Launching tests..")

    @classmethod
    def tearDownClass(cls):
        # Previously defined twice; the second definition silently dropped
        # the tearDownResponders() call.
        cls.tearDownAuth()
        cls.tearDownResponders()

    @classmethod
    def tearDownResponders(cls):
        """Hook matching startResponders(); called by tearDownClass()."""
        pass

    @classmethod
    def tearDownAuth(cls):
        """Terminate every server in _auths, killing any still alive after a
        grace period (0.1s when PDNSRECURSOR_FAST_TESTS is set, 1s otherwise)."""
        if 'PDNSRECURSOR_FAST_TESTS' in os.environ:
            delay = 0.1
        else:
            delay = 1.0

        for auth in cls._auths.values():
            try:
                auth.terminate()
                if auth.poll() is None:
                    time.sleep(delay)
                if auth.poll() is None:
                    auth.kill()
                auth.wait()
            except OSError as e:
                # The process may already be gone.
                if e.errno != errno.ESRCH:
                    raise

    @classmethod
    def sendUDPQuery(cls, query, timeout=2.0, decode=True, fwparams=None):
        """Send query over the class UDP socket and wait for one response.

        @param query: the dns.message.Message to send
        @param timeout: seconds to wait for this call only; a falsy value keeps
                        the socket's current timeout
        @param decode: when False, return the raw response bytes
        @param fwparams: extra keyword arguments for dns.message.from_wire
        @return: the response (decoded unless decode is False), or None when
                 nothing arrived in time"""
        if fwparams is None:
            fwparams = {}

        previousTimeout = cls._sock.gettimeout()
        if timeout:
            cls._sock.settimeout(timeout)

        try:
            cls._sock.send(query.to_wire())
            # 65535 covers any UDP payload, so large EDNS answers are not cut short.
            data = cls._sock.recv(65535)
        except socket.timeout:
            data = None
        finally:
            if timeout:
                # Restore the previous timeout instead of leaving the socket
                # blocking forever for later calls.
                cls._sock.settimeout(previousTimeout)

        if not data:
            return None
        if not decode:
            return data
        return dns.message.from_wire(data, **fwparams)

    @staticmethod
    def _recvExactly(sock, length):
        """Read length bytes from sock, looping over partial reads.

        Returns fewer than length bytes only if the peer closes the connection."""
        chunks = []
        remaining = length
        while remaining > 0:
            chunk = sock.recv(remaining)
            if not chunk:
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    @classmethod
    def sendTCPQuery(cls, query, timeout=2.0):
        """Send query over a fresh TCP connection to <_PREFIX>.1:_authPort.

        @param query: the dns.message.Message to send
        @param timeout: socket timeout in seconds; falsy means blocking
        @return: the decoded response, or None on timeout, network error or
                 when the connection closes before a length prefix arrives
        @raises socket.error: when the connection itself cannot be set up"""
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            if timeout:
                sock.settimeout(timeout)

            # Same address the server was started on and UDP queries go to.
            sock.connect((cls._PREFIX + ".1", cls._authPort))

            try:
                wire = query.to_wire()
                # DNS over TCP: each message is preceded by a 2-byte length.
                sock.sendall(struct.pack("!H", len(wire)) + wire)
                data = None
                header = cls._recvExactly(sock, 2)
                if len(header) == 2:
                    (datalen,) = struct.unpack("!H", header)
                    data = cls._recvExactly(sock, datalen)
            except socket.timeout as e:
                print("Timeout: %s" % (str(e)))
                data = None
            except socket.error as e:
                print("Network error: %s" % (str(e)))
                data = None
        finally:
            sock.close()

        message = None
        if data:
            message = dns.message.from_wire(data)
        return message

    def setUp(self):
        # This function is called before every tests
        super(AuthTest, self).setUp()

    ## Functions for comparisons
    def assertMessageHasFlags(self, msg, flags, ednsflags=None):
        """Asserts that msg's header flags are exactly flags and that its EDNS
        flags include every entry of ednsflags

        @param msg: the dns.message.Message to check
        @param flags: a list of strings with flag mnemonics (like ['RD', 'RA'])
        @param ednsflags: a list of strings with edns-flag mnemonics (like ['DO']);
                          None means no EDNS flags are required"""

        if ednsflags is None:
            ednsflags = []

        if not isinstance(msg, dns.message.Message):
            raise TypeError("msg is not a dns.message.Message")

        if not isinstance(flags, list) or not all(isinstance(elem, str) for elem in flags):
            raise TypeError("flags is not a list of strings")

        if not isinstance(ednsflags, list) or not all(isinstance(elem, str) for elem in ednsflags):
            raise TypeError("ednsflags is not a list of strings")

        msgFlags = dns.flags.to_text(msg.flags).split()
        missingFlags = [flag for flag in flags if flag not in msgFlags]

        msgEdnsFlags = dns.flags.edns_to_text(msg.ednsflags).split()
        missingEdnsFlags = [ednsflag for ednsflag in ednsflags if ednsflag not in msgEdnsFlags]

        # Extra header flags also fail; extra EDNS flags do not.
        if missingFlags or missingEdnsFlags or len(msgFlags) > len(flags):
            raise AssertionError("Expected flags '%s' (EDNS: '%s'), found '%s' (EDNS: '%s') in query %s" %
                                 (' '.join(flags), ' '.join(ednsflags),
                                  ' '.join(msgFlags), ' '.join(msgEdnsFlags),
                                  msg.question[0]))

    def assertMessageIsAuthenticated(self, msg):
        """Asserts that the message has the AD bit set

        @param msg: the dns.message.Message to check"""

        if not isinstance(msg, dns.message.Message):
            raise TypeError("msg is not a dns.message.Message")

        msgFlags = dns.flags.to_text(msg.flags).split()
        self.assertTrue('AD' in msgFlags, "No AD flag found in the message for %s" % msg.question[0].name)

    def assertRRsetInAnswer(self, msg, rrset):
        """Asserts the rrset (without comparing TTL) exists in the
        answer section of msg

        @param msg: the dns.message.Message to check
        @param rrset: a dns.rrset.RRset object"""

        if not isinstance(msg, dns.message.Message):
            raise TypeError("msg is not a dns.message.Message")

        if not isinstance(rrset, dns.rrset.RRset):
            raise TypeError("rrset is not a dns.rrset.RRset")

        ret = ''
        found = False
        for ans in msg.answer:
            ret += "%s\n" % ans.to_text()
            if ans.match(rrset.name, rrset.rdclass, rrset.rdtype, 0, None):
                self.assertEqual(ans, rrset, "'%s' != '%s'" % (ans.to_text(), rrset.to_text()))
                found = True

        if not found:
            raise AssertionError("RRset not found in answer\n\n%s" % ret)

    def sortRRsets(self, rrsets):
        """Sorts RRsets in a more useful way than dnspython's default behaviour

        @param rrsets: an array of dns.rrset.RRset objects
        @return: a new list ordered by (owner name, rdtype)"""

        return sorted(rrsets, key=lambda rrset: (rrset.name, rrset.rdtype))

    def assertAnyRRsetInAnswer(self, msg, rrsets):
        """Asserts that any of the supplied rrsets exists (without comparing TTL)
        in the answer section of msg

        @param msg: the dns.message.Message to check
        @param rrsets: an array of dns.rrset.RRset object"""

        if not isinstance(msg, dns.message.Message):
            raise TypeError("msg is not a dns.message.Message")

        found = False
        for rrset in rrsets:
            # Every element is type-checked, even after a match was found.
            if not isinstance(rrset, dns.rrset.RRset):
                raise TypeError("rrset is not a dns.rrset.RRset")
            if any(ans.match(rrset.name, rrset.rdclass, rrset.rdtype, 0, None) and ans == rrset
                   for ans in msg.answer):
                found = True

        if not found:
            raise AssertionError("RRset not found in answer\n%s" %
                                 "\n".join([ans.to_text() for ans in msg.answer]))

    def assertMatchingRRSIGInAnswer(self, msg, coveredRRset, keys=None):
        """Looks for coveredRRset in the answer section and if there is an RRSIG RRset
        that covers that RRset. If keys is not None, this function will also try to
        validate the RRset against the RRSIG

        @param msg: The dns.message.Message to check
        @param coveredRRset: The RRSet to check for
        @param keys: a dictionary keyed by dns.name.Name with node or rdataset values to use for validation"""

        if not isinstance(msg, dns.message.Message):
            raise TypeError("msg is not a dns.message.Message")

        if not isinstance(coveredRRset, dns.rrset.RRset):
            raise TypeError("coveredRRset is not a dns.rrset.RRset")

        msgRRsigRRSet = None
        msgRRSet = None

        ret = ''
        for ans in msg.answer:
            ret += ans.to_text() + "\n"

            if ans.match(coveredRRset.name, coveredRRset.rdclass, coveredRRset.rdtype, 0, None):
                msgRRSet = ans
            if ans.match(coveredRRset.name, dns.rdataclass.IN, dns.rdatatype.RRSIG, coveredRRset.rdtype, None):
                msgRRsigRRSet = ans
            if msgRRSet and msgRRsigRRSet:
                break

        if not msgRRSet:
            raise AssertionError("RRset for '%s' not found in answer" % msg.question[0].to_text())

        if not msgRRsigRRSet:
            raise AssertionError("No RRSIGs found in answer for %s:\nFull answer:\n%s" % (msg.question[0].to_text(), ret))

        if keys:
            # dns.dnssec is not guaranteed to be loaded by 'import dns.message'.
            import dns.dnssec
            try:
                dns.dnssec.validate(msgRRSet, msgRRsigRRSet.to_rdataset(), keys)
            except dns.dnssec.ValidationFailure as e:
                raise AssertionError("Signature validation failed for %s:\n%s" % (msg.question[0].to_text(), e))

    def assertNoRRSIGsInAnswer(self, msg):
        """Checks if there are _no_ RRSIGs in the answer section of msg"""

        if not isinstance(msg, dns.message.Message):
            raise TypeError("msg is not a dns.message.Message")

        ret = "".join(ans.name.to_text() + "\n"
                      for ans in msg.answer if ans.rdtype == dns.rdatatype.RRSIG)

        if ret:
            raise AssertionError("RRSIG found in answers for:\n%s" % ret)

    def assertAnswerEmpty(self, msg):
        """Asserts that the answer section of msg is empty."""
        self.assertTrue(len(msg.answer) == 0, "Data found in the answer section for %s:\n%s" % (msg.question[0].to_text(), '\n'.join([i.to_text() for i in msg.answer])))

    def assertAnswerNotEmpty(self, msg):
        """Asserts that the answer section of msg holds at least one RRset."""
        self.assertTrue(len(msg.answer) > 0, "Answer is empty")

    def assertRcodeEqual(self, msg, rcode):
        """Asserts that msg's rcode equals rcode.

        @param msg: the dns.message.Message to check
        @param rcode: the expected rcode, as an int or a mnemonic like 'NXDOMAIN'"""
        if not isinstance(msg, dns.message.Message):
            raise TypeError("msg is not a dns.message.Message but a %s" % type(msg))

        if not isinstance(rcode, int):
            if isinstance(rcode, str):
                rcode = dns.rcode.from_text(rcode)
            else:
                raise TypeError("rcode is neither a str nor int")

        if msg.rcode() != rcode:
            # Public API (works on dnspython 1.x and 2.x, and for unknown values),
            # unlike the private dns.rcode._by_value table.
            msgRcode = dns.rcode.to_text(msg.rcode())
            wantedRcode = dns.rcode.to_text(rcode)

            raise AssertionError("Rcode for %s is %s, expected %s." % (msg.question[0].to_text(), msgRcode, wantedRcode))

    def assertAuthorityHasSOA(self, msg):
        """Asserts that the authority section of msg contains an SOA RRset."""
        if not isinstance(msg, dns.message.Message):
            raise TypeError("msg is not a dns.message.Message but a %s" % type(msg))

        if not any(rrset.rdtype == dns.rdatatype.SOA for rrset in msg.authority):
            raise AssertionError("No SOA record found in the authority section:\n%s" % msg.to_text())