]> git.ipfire.org Git - thirdparty/pdns.git/blob - regression-tests.auth-py/authtests.py
Merge pull request #7212 from Habbie/jdnssec-0.14
[thirdparty/pdns.git] / regression-tests.auth-py / authtests.py
1 #!/usr/bin/env python2
2
from __future__ import print_function

import errno
import os
import shutil
import socket
import struct
import subprocess
import sys
import time
import unittest
from pprint import pprint

import dns
import dns.dnssec
import dns.flags
import dns.message
import dns.rcode
import dns.rdataclass
import dns.rdatatype
import dns.rrset
17
class AuthTest(unittest.TestCase):
    """
    Setup auth required for the tests

    setUpClass() opens a UDP socket towards <_PREFIX>.1:<_authPort>, writes a
    configuration directory below 'configs' (pdns.conf, named.conf and one
    zone file per entry in _zones) and launches one pdns_server process.
    tearDownClass() stops the server, runs the responder hook and closes the
    socket. The assert* methods compare dns.message.Message objects against
    expectations.
    """

    # Sub-directory of 'configs' that receives the generated files
    _confdir = 'auth'
    # Port passed to pdns_server as --local-port, used for UDP and TCP queries
    _authPort = 5300

    # NOTE(review): not referenced inside this class; presumably consumed by
    # subclasses -- confirm before removing.
    _root_DS = "63149 13 1 a59da3f5c1b97fcd5fa2b3b2b0ac91d38a60d33a"

    # The default SOA for zones in the authoritative servers
    _SOA = "ns1.example.net. hostmaster.example.net. 1 3600 1800 1209600 300"

    # The definitions of the zones on the authoritative servers, the key is the
    # zonename and the value is the zonefile content. several strings are replaced:
    # - {soa} => value of _SOA
    # - {prefix} => value of _PREFIX
    _zones = {
        'example.org': """
example.org. 3600 IN SOA {soa}
example.org. 3600 IN NS ns1.example.org.
example.org. 3600 IN NS ns2.example.org.
ns1.example.org. 3600 IN A {prefix}.10
ns2.example.org. 3600 IN A {prefix}.11
        """,
    }

    # Private keys to import per zone (same keys as _zones); zones without an
    # entry here are left unsigned by generateAllAuthConfig()
    _zone_keys = {
        'example.org': """
Private-key-format: v1.2
Algorithm: 13 (ECDSAP256SHA256)
PrivateKey: Lt0v0Gol3pRUFM7fDdcy0IWN0O/MnEmVPA+VylL8Y4U=
        """,
    }

    # Command prefix used to launch the server; $PDNS must be set when this
    # module is imported
    _auth_cmd = ['authbind',
                 os.environ['PDNS']]
    # Environment handed to the server process (empty: nothing is inherited)
    _auth_env = {}
    # Running server processes, keyed by the address they listen on
    _auths = {}

    # First three octets of the addresses used by the tests; $PREFIX must be
    # set when this module is imported
    _PREFIX = os.environ['PREFIX']

    @classmethod
    def createConfigDir(cls, confdir):
        """Create confdir as an empty directory

        An existing directory of that name is removed first; a missing one is
        not an error. Missing parent directories are created as well.

        @param confdir: path of the directory to (re)create"""
        try:
            shutil.rmtree(confdir)
        except OSError as e:
            if e.errno != errno.ENOENT:
                raise
        os.makedirs(confdir, 0o755)

    @classmethod
    def generateAuthZone(cls, confdir, zonename, zonecontent):
        """Write <confdir>/<zonename>.zone

        @param confdir: directory to write the zone file to
        @param zonename: name of the zone, used as the file name stem
        @param zonecontent: zonefile template; {prefix} and {soa} are replaced
                            with _PREFIX and _SOA"""
        with open(os.path.join(confdir, '%s.zone' % zonename), 'w') as zonefile:
            zonefile.write(zonecontent.format(prefix=cls._PREFIX, soa=cls._SOA))

    @classmethod
    def generateAuthNamedConf(cls, confdir, zones):
        """Write <confdir>/named.conf with one master zone per entry of zones

        @param confdir: directory to write named.conf to
        @param zones: iterable of zone names; the special name 'ROOT' is
                      written as zone "." backed by the file ROOT.zone"""
        with open(os.path.join(confdir, 'named.conf'), 'w') as namedconf:
            namedconf.write("""
options {
    directory "%s";
};""" % confdir)
            for zonename in zones:
                zone = '.' if zonename == 'ROOT' else zonename

                namedconf.write("""
zone "%s" {
    type master;
    file "%s.zone";
};""" % (zone, zonename))

    @classmethod
    def _runPdnsutil(cls, confdir, *args):
        """Run $PDNSUTIL --config-dir=<confdir> <args...>

        The command line is printed before it is run. When the command exits
        non-zero its combined stdout/stderr is printed and the
        subprocess.CalledProcessError is re-raised.

        @param confdir: value for --config-dir
        @param args: remaining command line arguments (strings)"""
        pdnsutilCmd = [os.environ['PDNSUTIL'], '--config-dir=%s' % confdir]
        pdnsutilCmd.extend(args)

        print(' '.join(pdnsutilCmd))
        try:
            subprocess.check_output(pdnsutilCmd, stderr=subprocess.STDOUT)
        except subprocess.CalledProcessError as e:
            print(e.output)
            raise

    @classmethod
    def generateAuthConfig(cls, confdir):
        """Write <confdir>/pdns.conf and create the bind DNSSEC database

        @param confdir: directory that receives pdns.conf and
                        bind-dnssec.sqlite3"""
        bind_dnssec_db = os.path.join(confdir, 'bind-dnssec.sqlite3')

        with open(os.path.join(confdir, 'pdns.conf'), 'w') as pdnsconf:
            pdnsconf.write("""
module-dir=../regression-tests/modules
launch=bind geoip
daemon=no
local-ipv6=
bind-config={confdir}/named.conf
bind-dnssec-db={bind_dnssec_db}
socket-dir={confdir}
cache-ttl=0
negquery-cache-ttl=0
query-cache-ttl=0
log-dns-queries=yes
log-dns-details=yes
loglevel=9
geoip-database-files=../modules/geoipbackend/regression-tests/GeoLiteCity.mmdb
edns-subnet-processing=yes
expand-alias=yes
resolver={prefix}.1:5301
any-to-tcp=no
distributor-threads=1""".format(confdir=confdir, prefix=cls._PREFIX,
                                bind_dnssec_db=bind_dnssec_db))

        cls._runPdnsutil(confdir, 'create-bind-db', bind_dnssec_db)

    @classmethod
    def secureZone(cls, confdir, zonename, key=None):
        """Enable DNSSEC for a zone through pdnsutil

        @param confdir: configuration directory of the server
        @param zonename: zone to secure; 'ROOT' stands for '.'
        @param key: private key text to import as an active KSK. When empty or
                    None, 'pdnsutil secure-zone' is run instead"""
        zone = '.' if zonename == 'ROOT' else zonename
        if not key:
            cls._runPdnsutil(confdir, 'secure-zone', zone)
            return

        # pdnsutil reads the key from a file, so write it out first
        keyfile = os.path.join(confdir, 'dnssec.key')
        with open(keyfile, 'w') as fdKeyfile:
            fdKeyfile.write(key)

        cls._runPdnsutil(confdir, 'import-zone-key', zone, keyfile,
                         'active', 'ksk')

    @classmethod
    def generateAllAuthConfig(cls, confdir):
        """Generate every configuration file the server needs

        Nothing is written when _zones is empty.

        @param confdir: directory to write the files to"""
        if not cls._zones:
            return

        cls.generateAuthConfig(confdir)
        cls.generateAuthNamedConf(confdir, cls._zones.keys())

        for zonename, zonecontent in cls._zones.items():
            cls.generateAuthZone(confdir, zonename, zonecontent)
            key = cls._zone_keys.get(zonename)
            if key:
                cls.secureZone(confdir, zonename, key)

    @classmethod
    def startAuth(cls, confdir, ipaddress):
        """Launch pdns_server and remember the process in _auths[ipaddress]

        Output of the server goes to <confdir>/pdns.log. If the process has
        already exited two seconds after the launch, the log is printed and
        the test run is ended with the server's exit code.

        @param confdir: configuration directory of the server
        @param ipaddress: address the server listens on"""
        print("Launching pdns_server..")
        authcmd = list(cls._auth_cmd)
        authcmd.append('--config-dir=%s' % confdir)
        authcmd.append('--local-address=%s' % ipaddress)
        authcmd.append('--local-port=%s' % cls._authPort)
        authcmd.append('--loglevel=9')
        authcmd.append('--enable-lua-record')
        print(' '.join(authcmd))

        logFile = os.path.join(confdir, 'pdns.log')
        with open(logFile, 'w') as fdLog:
            cls._auths[ipaddress] = subprocess.Popen(authcmd, close_fds=True,
                                                     stdout=fdLog, stderr=fdLog,
                                                     env=cls._auth_env)

        # Give the server time to start up (or to fail)
        time.sleep(2)

        if cls._auths[ipaddress].poll() is not None:
            try:
                cls._auths[ipaddress].kill()
            except OSError as e:
                if e.errno != errno.ESRCH:
                    raise
            with open(logFile, 'r') as fdLog:
                print(fdLog.read())
            sys.exit(cls._auths[ipaddress].returncode)

    @classmethod
    def setUpSockets(cls):
        """Open the UDP socket (cls._sock) used by sendUDPQuery()"""
        print("Setting up UDP socket..")
        cls._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        cls._sock.settimeout(2.0)
        cls._sock.connect((cls._PREFIX + ".1", cls._authPort))

    @classmethod
    def startResponders(cls):
        """Hook called by setUpClass() before the server is started; does
        nothing here"""
        pass

    @classmethod
    def setUpClass(cls):
        """Open the sockets, start the responders, write the configuration
        and start the server"""
        cls.setUpSockets()

        cls.startResponders()

        confdir = os.path.join('configs', cls._confdir)
        cls.createConfigDir(confdir)

        cls.generateAllAuthConfig(confdir)
        cls.startAuth(confdir, cls._PREFIX + ".1")

        print("Launching tests..")

    @classmethod
    def tearDownClass(cls):
        """Stop the server, run the responder hook and close the UDP socket"""
        try:
            cls.tearDownAuth()
        finally:
            # Counterpart of the startResponders() call in setUpClass(); run
            # it even when stopping the server failed
            try:
                cls.tearDownResponders()
            finally:
                sock = getattr(cls, '_sock', None)
                if sock is not None:
                    sock.close()

    @classmethod
    def tearDownResponders(cls):
        """Hook called by tearDownClass(); does nothing here"""
        pass

    @classmethod
    def tearDownAuth(cls):
        """Stop every server process recorded in _auths

        Each process is sent a terminate request; if it is still running it
        gets a grace period (0.1s when PDNSRECURSOR_FAST_TESTS is set in the
        environment, 1s otherwise) before it is killed."""
        delay = 0.1 if 'PDNSRECURSOR_FAST_TESTS' in os.environ else 1.0

        # Iterate over a snapshot, entries are dropped while we go
        for ipaddress, auth in list(cls._auths.items()):
            try:
                auth.terminate()
                if auth.poll() is None:
                    time.sleep(delay)
                    if auth.poll() is None:
                        auth.kill()
                    auth.wait()
            except OSError as e:
                # The process being gone already is fine
                if e.errno != errno.ESRCH:
                    raise
            finally:
                # Forget the process so a later teardown does not signal a
                # PID that no longer belongs to it
                cls._auths.pop(ipaddress, None)

    @classmethod
    def sendUDPQuery(cls, query, timeout=2.0, decode=True, fwparams=None):
        """Send query over the shared UDP socket and return the response

        @param query: the dns.message.Message to send
        @param timeout: seconds to wait for the response; a false value leaves
                        the socket timeout untouched
        @param decode: when False, return the raw response bytes
        @param fwparams: dict of extra keyword arguments for
                         dns.message.from_wire()
        @return: the decoded message (or raw bytes), None when no response
                 arrived in time"""
        if timeout:
            cls._sock.settimeout(timeout)

        try:
            cls._sock.send(query.to_wire())
            data = cls._sock.recv(4096)
        except socket.timeout:
            data = None
        finally:
            if timeout:
                cls._sock.settimeout(None)

        if not data:
            return None
        if not decode:
            return data
        return dns.message.from_wire(data, **(fwparams or {}))

    @staticmethod
    def _recvExact(sock, length):
        """Read exactly length bytes from sock

        @return: the bytes read, or None when the peer closed the connection
                 before length bytes arrived"""
        chunks = []
        remaining = length
        while remaining > 0:
            chunk = sock.recv(remaining)
            if not chunk:
                return None
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    @classmethod
    def sendTCPQuery(cls, query, timeout=2.0):
        """Send query over a fresh TCP connection and return the response

        The connection goes to <_PREFIX>.1:<_authPort>, the address the server
        is started on by setUpClass(). Errors while connecting are raised to
        the caller; timeouts and network errors after that are printed and
        result in None.

        @param query: the dns.message.Message to send
        @param timeout: socket timeout in seconds; a false value means no
                        timeout is set
        @return: the decoded dns.message.Message, or None when no complete
                 response was received"""
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        data = None
        try:
            if timeout:
                sock.settimeout(timeout)

            sock.connect((cls._PREFIX + ".1", cls._authPort))

            try:
                wire = query.to_wire()
                # DNS over TCP: 2-byte big-endian length, then the message
                sock.sendall(struct.pack("!H", len(wire)) + wire)
                header = cls._recvExact(sock, 2)
                if header:
                    (datalen,) = struct.unpack("!H", header)
                    # The response may arrive in several segments
                    data = cls._recvExact(sock, datalen)
            except socket.timeout as e:
                print("Timeout: %s" % (str(e)))
                data = None
            except socket.error as e:
                print("Network error: %s" % (str(e)))
                data = None
        finally:
            sock.close()

        message = None
        if data:
            message = dns.message.from_wire(data)
        return message

    def setUp(self):
        # This function is called before every test
        return

    ## Functions for comparisons
    @staticmethod
    def _requireListOfStrings(value, name):
        """Raise TypeError unless value is a list containing only str objects

        @param value: the object to check
        @param name: name used for value in the error message"""
        if not isinstance(value, list):
            raise TypeError("%s is not a list of strings" % name)
        for elem in value:
            if not isinstance(elem, str):
                raise TypeError("%s is not a list of strings" % name)

    def assertMessageHasFlags(self, msg, flags, ednsflags=None):
        """Asserts that msg has all the flags from flags and ednsflags set,
        and that it has no more header flags set than flags lists

        @param msg: the dns.message.Message to check
        @param flags: a list of strings with flag mnemonics (like ['RD', 'RA'])
        @param ednsflags: a list of strings with edns-flag mnemonics (like ['DO'])"""

        if ednsflags is None:
            ednsflags = []

        if not isinstance(msg, dns.message.Message):
            raise TypeError("msg is not a dns.message.Message")

        self._requireListOfStrings(flags, "flags")
        self._requireListOfStrings(ednsflags, "ednsflags")

        msgFlags = dns.flags.to_text(msg.flags).split()
        missingFlags = [flag for flag in flags if flag not in msgFlags]

        msgEdnsFlags = dns.flags.edns_to_text(msg.ednsflags).split()
        missingEdnsFlags = [ednsflag for ednsflag in ednsflags if ednsflag not in msgEdnsFlags]

        # Extra EDNS flags are tolerated, extra header flags are not
        if missingFlags or missingEdnsFlags or len(msgFlags) > len(flags):
            raise AssertionError("Expected flags '%s' (EDNS: '%s'), found '%s' (EDNS: '%s') in query %s" %
                                 (' '.join(flags), ' '.join(ednsflags),
                                  ' '.join(msgFlags), ' '.join(msgEdnsFlags),
                                  msg.question[0]))

    def assertMessageIsAuthenticated(self, msg):
        """Asserts that the message has the AD bit set

        @param msg: the dns.message.Message to check"""

        if not isinstance(msg, dns.message.Message):
            raise TypeError("msg is not a dns.message.Message")

        # Compare whole mnemonics, not substrings of the flag text
        msgFlags = dns.flags.to_text(msg.flags).split()
        self.assertTrue('AD' in msgFlags, "No AD flag found in the message for %s" % msg.question[0].name)

    def assertRRsetInAnswer(self, msg, rrset):
        """Asserts the rrset (without comparing TTL) exists in the
        answer section of msg

        @param msg: the dns.message.Message to check
        @param rrset: a dns.rrset.RRset object"""

        if not isinstance(msg, dns.message.Message):
            raise TypeError("msg is not a dns.message.Message")

        if not isinstance(rrset, dns.rrset.RRset):
            raise TypeError("rrset is not a dns.rrset.RRset")

        found = False
        seen = []
        for ans in msg.answer:
            seen.append("%s\n" % ans.to_text())
            if ans.match(rrset.name, rrset.rdclass, rrset.rdtype, 0, None):
                self.assertEqual(ans, rrset, "'%s' != '%s'" % (ans.to_text(), rrset.to_text()))
                found = True

        if not found:
            raise AssertionError("RRset not found in answer\n\n%s" % ''.join(seen))

    def assertAnyRRsetInAnswer(self, msg, rrsets):
        """Asserts that any of the supplied rrsets exists (without comparing TTL)
        in the answer section of msg

        @param msg: the dns.message.Message to check
        @param rrsets: an array of dns.rrset.RRset object"""

        if not isinstance(msg, dns.message.Message):
            raise TypeError("msg is not a dns.message.Message")

        for rrset in rrsets:
            if not isinstance(rrset, dns.rrset.RRset):
                raise TypeError("rrset is not a dns.rrset.RRset")

        found = any(ans.match(rrset.name, rrset.rdclass, rrset.rdtype, 0, None) and ans == rrset
                    for rrset in rrsets
                    for ans in msg.answer)

        if not found:
            raise AssertionError("RRset not found in answer\n%s" %
                                 "\n".join([ans.to_text() for ans in msg.answer]))

    def assertMatchingRRSIGInAnswer(self, msg, coveredRRset, keys=None):
        """Looks for coveredRRset in the answer section and if there is an RRSIG RRset
        that covers that RRset. If keys is not None, this function will also try to
        validate the RRset against the RRSIG

        @param msg: The dns.message.Message to check
        @param coveredRRset: The RRSet to check for
        @param keys: a dictionary keyed by dns.name.Name with node or rdataset values to use for validation"""

        if not isinstance(msg, dns.message.Message):
            raise TypeError("msg is not a dns.message.Message")

        if not isinstance(coveredRRset, dns.rrset.RRset):
            raise TypeError("coveredRRset is not a dns.rrset.RRset")

        msgRRsigRRSet = None
        msgRRSet = None

        seen = []
        for ans in msg.answer:
            seen.append(ans.to_text() + "\n")

            if ans.match(coveredRRset.name, coveredRRset.rdclass, coveredRRset.rdtype, 0, None):
                msgRRSet = ans
            if ans.match(coveredRRset.name, dns.rdataclass.IN, dns.rdatatype.RRSIG, coveredRRset.rdtype, None):
                msgRRsigRRSet = ans
            if msgRRSet and msgRRsigRRSet:
                break

        if not msgRRSet:
            raise AssertionError("RRset for '%s' not found in answer" % msg.question[0].to_text())

        if not msgRRsigRRSet:
            raise AssertionError("No RRSIGs found in answer for %s:\nFull answer:\n%s" %
                                 (msg.question[0].to_text(), ''.join(seen)))

        if keys:
            try:
                dns.dnssec.validate(msgRRSet, msgRRsigRRSet.to_rdataset(), keys)
            except dns.dnssec.ValidationFailure as e:
                raise AssertionError("Signature validation failed for %s:\n%s" % (msg.question[0].to_text(), e))

    def assertNoRRSIGsInAnswer(self, msg):
        """Checks if there are _no_ RRSIGs in the answer section of msg

        @param msg: the dns.message.Message to check"""

        if not isinstance(msg, dns.message.Message):
            raise TypeError("msg is not a dns.message.Message")

        signedNames = [ans.name.to_text() + "\n"
                       for ans in msg.answer
                       if ans.rdtype == dns.rdatatype.RRSIG]

        if signedNames:
            raise AssertionError("RRSIG found in answers for:\n%s" % ''.join(signedNames))

    def assertAnswerEmpty(self, msg):
        """Asserts that the answer section of msg is empty

        @param msg: the dns.message.Message to check"""
        # Only build the failure text when it is needed
        if len(msg.answer) != 0:
            raise AssertionError("Data found in the answer section for %s:\n%s" %
                                 (msg.question[0].to_text(),
                                  '\n'.join([i.to_text() for i in msg.answer])))

    def assertAnswerNotEmpty(self, msg):
        """Asserts that the answer section of msg holds at least one RRset

        @param msg: the dns.message.Message to check"""
        self.assertTrue(len(msg.answer) > 0, "Answer is empty")

    def assertRcodeEqual(self, msg, rcode):
        """Asserts that the rcode of msg equals rcode

        @param msg: the dns.message.Message to check
        @param rcode: the expected rcode, as an int or as a mnemonic string
                      (like 'NXDOMAIN')"""
        if not isinstance(msg, dns.message.Message):
            raise TypeError("msg is not a dns.message.Message but a %s" % type(msg))

        if not isinstance(rcode, int):
            if isinstance(rcode, str):
                rcode = dns.rcode.from_text(rcode)
            else:
                raise TypeError("rcode is neither a str nor int")

        if msg.rcode() != rcode:
            # Use the public conversion, which also copes with unknown values
            msgRcode = dns.rcode.to_text(msg.rcode())
            wantedRcode = dns.rcode.to_text(rcode)

            raise AssertionError("Rcode for %s is %s, expected %s." % (msg.question[0].to_text(), msgRcode, wantedRcode))

    def assertAuthorityHasSOA(self, msg):
        """Asserts that the authority section of msg contains a SOA RRset

        @param msg: the dns.message.Message to check"""
        if not isinstance(msg, dns.message.Message):
            raise TypeError("msg is not a dns.message.Message but a %s" % type(msg))

        if not any(rrset.rdtype == dns.rdatatype.SOA for rrset in msg.authority):
            raise AssertionError("No SOA record found in the authority section:\n%s" % msg.to_text())