]> git.ipfire.org Git - thirdparty/pdns.git/blob - regression-tests.auth-py/authtests.py
Merge pull request #8331 from mind04/pdns-lmdb-cleanup
[thirdparty/pdns.git] / regression-tests.auth-py / authtests.py
1 #!/usr/bin/env python2
2
3 from __future__ import print_function
4 import errno
5 import shutil
6 import os
7 import socket
8 import struct
9 import subprocess
10 import sys
11 import time
12 import unittest
13 import dns
14 import dns.message
15
16 from pprint import pprint
17 from eqdnsmessage import AssertEqualDNSMessageMixin
18
19 class AuthTest(AssertEqualDNSMessageMixin, unittest.TestCase):
20 """
21 Setup auth required for the tests
22 """
23
24 _confdir = 'auth'
25 _authPort = 5300
26
27 _config_params = []
28
29 _config_template_default = """
30 module-dir=../regression-tests/modules
31 daemon=no
32 bind-config={confdir}/named.conf
33 bind-dnssec-db={bind_dnssec_db}
34 socket-dir={confdir}
35 cache-ttl=0
36 negquery-cache-ttl=0
37 query-cache-ttl=0
38 log-dns-queries=yes
39 log-dns-details=yes
40 loglevel=9
41 distributor-threads=1"""
42
43 _config_template = ""
44
45 _root_DS = "63149 13 1 a59da3f5c1b97fcd5fa2b3b2b0ac91d38a60d33a"
46
47 # The default SOA for zones in the authoritative servers
48 _SOA = "ns1.example.net. hostmaster.example.net. 1 3600 1800 1209600 300"
49
50 # The definitions of the zones on the authoritative servers, the key is the
51 # zonename and the value is the zonefile content. several strings are replaced:
52 # - {soa} => value of _SOA
53 # - {prefix} value of _PREFIX
54 _zones = {
55 'example.org': """
56 example.org. 3600 IN SOA {soa}
57 example.org. 3600 IN NS ns1.example.org.
58 example.org. 3600 IN NS ns2.example.org.
59 ns1.example.org. 3600 IN A {prefix}.10
60 ns2.example.org. 3600 IN A {prefix}.11
61 """,
62 }
63
64 _zone_keys = {
65 'example.org': """
66 Private-key-format: v1.2
67 Algorithm: 13 (ECDSAP256SHA256)
68 PrivateKey: Lt0v0Gol3pRUFM7fDdcy0IWN0O/MnEmVPA+VylL8Y4U=
69 """,
70 }
71
72 _auth_cmd = ['authbind',
73 os.environ['PDNS']]
74 _auth_env = {}
75 _auths = {}
76
77 _PREFIX = os.environ['PREFIX']
78
79
80 @classmethod
81 def createConfigDir(cls, confdir):
82 try:
83 shutil.rmtree(confdir)
84 except OSError as e:
85 if e.errno != errno.ENOENT:
86 raise
87 os.mkdir(confdir, 0o755)
88
89 @classmethod
90 def generateAuthZone(cls, confdir, zonename, zonecontent):
91 with open(os.path.join(confdir, '%s.zone' % zonename), 'w') as zonefile:
92 zonefile.write(zonecontent.format(prefix=cls._PREFIX, soa=cls._SOA))
93
94 @classmethod
95 def generateAuthNamedConf(cls, confdir, zones):
96 with open(os.path.join(confdir, 'named.conf'), 'w') as namedconf:
97 namedconf.write("""
98 options {
99 directory "%s";
100 };""" % confdir)
101 for zonename in zones:
102 zone = '.' if zonename == 'ROOT' else zonename
103
104 namedconf.write("""
105 zone "%s" {
106 type master;
107 file "%s.zone";
108 };""" % (zone, zonename))
109
110 @classmethod
111 def generateAuthConfig(cls, confdir):
112 bind_dnssec_db = os.path.join(confdir, 'bind-dnssec.sqlite3')
113
114 params = tuple([getattr(cls, param) for param in cls._config_params])
115
116 with open(os.path.join(confdir, 'pdns.conf'), 'w') as pdnsconf:
117 pdnsconf.write(cls._config_template_default.format(
118 confdir=confdir, prefix=cls._PREFIX,
119 bind_dnssec_db=bind_dnssec_db))
120 pdnsconf.write(cls._config_template % params)
121
122 pdnsutilCmd = [os.environ['PDNSUTIL'],
123 '--config-dir=%s' % confdir,
124 'create-bind-db',
125 bind_dnssec_db]
126
127 print(' '.join(pdnsutilCmd))
128 try:
129 subprocess.check_output(pdnsutilCmd, stderr=subprocess.STDOUT)
130 except subprocess.CalledProcessError as e:
131 raise AssertionError('%s failed (%d): %s' % (pdnsutilCmd, e.returncode, e.output))
132
133 @classmethod
134 def secureZone(cls, confdir, zonename, key=None):
135 zone = '.' if zonename == 'ROOT' else zonename
136 if not key:
137 pdnsutilCmd = [os.environ['PDNSUTIL'],
138 '--config-dir=%s' % confdir,
139 'secure-zone',
140 zone]
141 else:
142 keyfile = os.path.join(confdir, 'dnssec.key')
143 with open(keyfile, 'w') as fdKeyfile:
144 fdKeyfile.write(key)
145
146 pdnsutilCmd = [os.environ['PDNSUTIL'],
147 '--config-dir=%s' % confdir,
148 'import-zone-key',
149 zone,
150 keyfile,
151 'active',
152 'ksk']
153
154 print(' '.join(pdnsutilCmd))
155 try:
156 subprocess.check_output(pdnsutilCmd, stderr=subprocess.STDOUT)
157 except subprocess.CalledProcessError as e:
158 raise AssertionError('%s failed (%d): %s' % (pdnsutilCmd, e.returncode, e.output))
159
160 @classmethod
161 def generateAllAuthConfig(cls, confdir):
162 if cls._zones:
163 cls.generateAuthConfig(confdir)
164 cls.generateAuthNamedConf(confdir, cls._zones.keys())
165
166 for zonename, zonecontent in cls._zones.items():
167 cls.generateAuthZone(confdir,
168 zonename,
169 zonecontent)
170 if cls._zone_keys.get(zonename, None):
171 cls.secureZone(confdir, zonename, cls._zone_keys.get(zonename))
172
173 @classmethod
174 def startAuth(cls, confdir, ipaddress):
175
176 print("Launching pdns_server..")
177 authcmd = list(cls._auth_cmd)
178 authcmd.append('--config-dir=%s' % confdir)
179 authcmd.append('--local-address=%s' % ipaddress)
180 authcmd.append('--local-port=%s' % cls._authPort)
181 authcmd.append('--loglevel=9')
182 authcmd.append('--enable-lua-records')
183 authcmd.append('--lua-health-checks-interval=1')
184 print(' '.join(authcmd))
185 logFile = os.path.join(confdir, 'pdns.log')
186 with open(logFile, 'w') as fdLog:
187 cls._auths[ipaddress] = subprocess.Popen(authcmd, close_fds=True,
188 stdout=fdLog, stderr=fdLog,
189 env=cls._auth_env)
190
191 time.sleep(2)
192
193 if cls._auths[ipaddress].poll() is not None:
194 try:
195 cls._auths[ipaddress].kill()
196 except OSError as e:
197 if e.errno != errno.ESRCH:
198 raise
199 with open(logFile, 'r') as fdLog:
200 print(fdLog.read())
201 sys.exit(cls._auths[ipaddress].returncode)
202
203 @classmethod
204 def setUpSockets(cls):
205 print("Setting up UDP socket..")
206 cls._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
207 cls._sock.settimeout(2.0)
208 cls._sock.connect((cls._PREFIX + ".1", cls._authPort))
209
210 @classmethod
211 def startResponders(cls):
212 pass
213
214 @classmethod
215 def setUpClass(cls):
216 cls.setUpSockets()
217
218 cls.startResponders()
219
220 confdir = os.path.join('configs', cls._confdir)
221 cls.createConfigDir(confdir)
222
223 cls.generateAllAuthConfig(confdir)
224 cls.startAuth(confdir, cls._PREFIX + ".1")
225
226 print("Launching tests..")
227
228 @classmethod
229 def tearDownClass(cls):
230 cls.tearDownAuth()
231 cls.tearDownResponders()
232
233 @classmethod
234 def tearDownResponders(cls):
235 pass
236
237 @classmethod
238 def tearDownClass(cls):
239 cls.tearDownAuth()
240
241 @classmethod
242 def tearDownAuth(cls):
243 if 'PDNSRECURSOR_FAST_TESTS' in os.environ:
244 delay = 0.1
245 else:
246 delay = 1.0
247
248 for _, auth in cls._auths.items():
249 try:
250 auth.terminate()
251 if auth.poll() is None:
252 time.sleep(delay)
253 if auth.poll() is None:
254 auth.kill()
255 auth.wait()
256 except OSError as e:
257 if e.errno != errno.ESRCH:
258 raise
259
260 @classmethod
261 def sendUDPQuery(cls, query, timeout=2.0, decode=True, fwparams=dict()):
262 if timeout:
263 cls._sock.settimeout(timeout)
264
265 try:
266 cls._sock.send(query.to_wire())
267 data = cls._sock.recv(4096)
268 except socket.timeout:
269 data = None
270 finally:
271 if timeout:
272 cls._sock.settimeout(None)
273
274 message = None
275 if data:
276 if not decode:
277 return data
278 message = dns.message.from_wire(data, **fwparams)
279 return message
280
281 @classmethod
282 def sendTCPQuery(cls, query, timeout=2.0):
283 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
284 if timeout:
285 sock.settimeout(timeout)
286
287 sock.connect(("127.0.0.1", cls._recursorPort))
288
289 try:
290 wire = query.to_wire()
291 sock.send(struct.pack("!H", len(wire)))
292 sock.send(wire)
293 data = sock.recv(2)
294 if data:
295 (datalen,) = struct.unpack("!H", data)
296 data = sock.recv(datalen)
297 except socket.timeout as e:
298 print("Timeout: %s" % (str(e)))
299 data = None
300 except socket.error as e:
301 print("Network error: %s" % (str(e)))
302 data = None
303 finally:
304 sock.close()
305
306 message = None
307 if data:
308 message = dns.message.from_wire(data)
309 return message
310
311
312 @classmethod
313 def sendTCPQuery(cls, query, timeout=2.0):
314 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
315 if timeout:
316 sock.settimeout(timeout)
317
318 sock.connect(("127.0.0.1", cls._authPort))
319
320 try:
321 wire = query.to_wire()
322 sock.send(struct.pack("!H", len(wire)))
323 sock.send(wire)
324 data = sock.recv(2)
325 if data:
326 (datalen,) = struct.unpack("!H", data)
327 data = sock.recv(datalen)
328 except socket.timeout as e:
329 print("Timeout: %s" % (str(e)))
330 data = None
331 except socket.error as e:
332 print("Network error: %s" % (str(e)))
333 data = None
334 finally:
335 sock.close()
336
337 message = None
338 if data:
339 message = dns.message.from_wire(data)
340 return message
341
342 def setUp(self):
343 # This function is called before every tests
344 super(AuthTest, self).setUp()
345
346 ## Functions for comparisons
347 def assertMessageHasFlags(self, msg, flags, ednsflags=[]):
348 """Asserts that msg has all the flags from flags set
349
350 @param msg: the dns.message.Message to check
351 @param flags: a list of strings with flag mnemonics (like ['RD', 'RA'])
352 @param ednsflags: a list of strings with edns-flag mnemonics (like ['DO'])"""
353
354 if not isinstance(msg, dns.message.Message):
355 raise TypeError("msg is not a dns.message.Message")
356
357 if isinstance(flags, list):
358 for elem in flags:
359 if not isinstance(elem, str):
360 raise TypeError("flags is not a list of strings")
361 else:
362 raise TypeError("flags is not a list of strings")
363
364 if isinstance(ednsflags, list):
365 for elem in ednsflags:
366 if not isinstance(elem, str):
367 raise TypeError("ednsflags is not a list of strings")
368 else:
369 raise TypeError("ednsflags is not a list of strings")
370
371 msgFlags = dns.flags.to_text(msg.flags).split()
372 missingFlags = [flag for flag in flags if flag not in msgFlags]
373
374 msgEdnsFlags = dns.flags.edns_to_text(msg.ednsflags).split()
375 missingEdnsFlags = [ednsflag for ednsflag in ednsflags if ednsflag not in msgEdnsFlags]
376
377 if len(missingFlags) or len(missingEdnsFlags) or len(msgFlags) > len(flags):
378 raise AssertionError("Expected flags '%s' (EDNS: '%s'), found '%s' (EDNS: '%s') in query %s" %
379 (' '.join(flags), ' '.join(ednsflags),
380 ' '.join(msgFlags), ' '.join(msgEdnsFlags),
381 msg.question[0]))
382
383 def assertMessageIsAuthenticated(self, msg):
384 """Asserts that the message has the AD bit set
385
386 @param msg: the dns.message.Message to check"""
387
388 if not isinstance(msg, dns.message.Message):
389 raise TypeError("msg is not a dns.message.Message")
390
391 msgFlags = dns.flags.to_text(msg.flags)
392 self.assertTrue('AD' in msgFlags, "No AD flag found in the message for %s" % msg.question[0].name)
393
394 def assertRRsetInAnswer(self, msg, rrset):
395 """Asserts the rrset (without comparing TTL) exists in the
396 answer section of msg
397
398 @param msg: the dns.message.Message to check
399 @param rrset: a dns.rrset.RRset object"""
400
401 ret = ''
402 if not isinstance(msg, dns.message.Message):
403 raise TypeError("msg is not a dns.message.Message")
404
405 if not isinstance(rrset, dns.rrset.RRset):
406 raise TypeError("rrset is not a dns.rrset.RRset")
407
408 found = False
409 for ans in msg.answer:
410 ret += "%s\n" % ans.to_text()
411 if ans.match(rrset.name, rrset.rdclass, rrset.rdtype, 0, None):
412 self.assertEqual(ans, rrset, "'%s' != '%s'" % (ans.to_text(), rrset.to_text()))
413 found = True
414
415 if not found :
416 raise AssertionError("RRset not found in answer\n\n%s" % ret)
417
418 def sortRRsets(self, rrsets):
419 """Sorts RRsets in a more useful way than dnspython's default behaviour
420
421 @param rrsets: an array of dns.rrset.RRset objects"""
422
423 return sorted(rrsets, key=lambda rrset: (rrset.name, rrset.rdtype))
424
425 def assertAnyRRsetInAnswer(self, msg, rrsets):
426 """Asserts that any of the supplied rrsets exists (without comparing TTL)
427 in the answer section of msg
428
429 @param msg: the dns.message.Message to check
430 @param rrsets: an array of dns.rrset.RRset object"""
431
432 if not isinstance(msg, dns.message.Message):
433 raise TypeError("msg is not a dns.message.Message")
434
435 found = False
436 for rrset in rrsets:
437 if not isinstance(rrset, dns.rrset.RRset):
438 raise TypeError("rrset is not a dns.rrset.RRset")
439 for ans in msg.answer:
440 if ans.match(rrset.name, rrset.rdclass, rrset.rdtype, 0, None):
441 if ans == rrset:
442 found = True
443
444 if not found:
445 raise AssertionError("RRset not found in answer\n%s" %
446 "\n".join(([ans.to_text() for ans in msg.answer])))
447
448 def assertMatchingRRSIGInAnswer(self, msg, coveredRRset, keys=None):
449 """Looks for coveredRRset in the answer section and if there is an RRSIG RRset
450 that covers that RRset. If keys is not None, this function will also try to
451 validate the RRset against the RRSIG
452
453 @param msg: The dns.message.Message to check
454 @param coveredRRset: The RRSet to check for
455 @param keys: a dictionary keyed by dns.name.Name with node or rdataset values to use for validation"""
456
457 if not isinstance(msg, dns.message.Message):
458 raise TypeError("msg is not a dns.message.Message")
459
460 if not isinstance(coveredRRset, dns.rrset.RRset):
461 raise TypeError("coveredRRset is not a dns.rrset.RRset")
462
463 msgRRsigRRSet = None
464 msgRRSet = None
465
466 ret = ''
467 for ans in msg.answer:
468 ret += ans.to_text() + "\n"
469
470 if ans.match(coveredRRset.name, coveredRRset.rdclass, coveredRRset.rdtype, 0, None):
471 msgRRSet = ans
472 if ans.match(coveredRRset.name, dns.rdataclass.IN, dns.rdatatype.RRSIG, coveredRRset.rdtype, None):
473 msgRRsigRRSet = ans
474 if msgRRSet and msgRRsigRRSet:
475 break
476
477 if not msgRRSet:
478 raise AssertionError("RRset for '%s' not found in answer" % msg.question[0].to_text())
479
480 if not msgRRsigRRSet:
481 raise AssertionError("No RRSIGs found in answer for %s:\nFull answer:\n%s" % (msg.question[0].to_text(), ret))
482
483 if keys:
484 try:
485 dns.dnssec.validate(msgRRSet, msgRRsigRRSet.to_rdataset(), keys)
486 except dns.dnssec.ValidationFailure as e:
487 raise AssertionError("Signature validation failed for %s:\n%s" % (msg.question[0].to_text(), e))
488
489 def assertNoRRSIGsInAnswer(self, msg):
490 """Checks if there are _no_ RRSIGs in the answer section of msg"""
491
492 if not isinstance(msg, dns.message.Message):
493 raise TypeError("msg is not a dns.message.Message")
494
495 ret = ""
496 for ans in msg.answer:
497 if ans.rdtype == dns.rdatatype.RRSIG:
498 ret += ans.name.to_text() + "\n"
499
500 if len(ret):
501 raise AssertionError("RRSIG found in answers for:\n%s" % ret)
502
503 def assertAnswerEmpty(self, msg):
504 self.assertTrue(len(msg.answer) == 0, "Data found in the the answer section for %s:\n%s" % (msg.question[0].to_text(), '\n'.join([i.to_text() for i in msg.answer])))
505
506 def assertAnswerNotEmpty(self, msg):
507 self.assertTrue(len(msg.answer) > 0, "Answer is empty")
508
509 def assertRcodeEqual(self, msg, rcode):
510 if not isinstance(msg, dns.message.Message):
511 raise TypeError("msg is not a dns.message.Message but a %s" % type(msg))
512
513 if not isinstance(rcode, int):
514 if isinstance(rcode, str):
515 rcode = dns.rcode.from_text(rcode)
516 else:
517 raise TypeError("rcode is neither a str nor int")
518
519 if msg.rcode() != rcode:
520 msgRcode = dns.rcode._by_value[msg.rcode()]
521 wantedRcode = dns.rcode._by_value[rcode]
522
523 raise AssertionError("Rcode for %s is %s, expected %s." % (msg.question[0].to_text(), msgRcode, wantedRcode))
524
525 def assertAuthorityHasSOA(self, msg):
526 if not isinstance(msg, dns.message.Message):
527 raise TypeError("msg is not a dns.message.Message but a %s" % type(msg))
528
529 found = False
530 for rrset in msg.authority:
531 if rrset.rdtype == dns.rdatatype.SOA:
532 found = True
533 break
534
535 if not found:
536 raise AssertionError("No SOA record found in the authority section:\n%s" % msg.to_text())