]> git.ipfire.org Git - thirdparty/pdns.git/blob - regression-tests.auth-py/authtests.py
Merge pull request #7056 from mind04/rootkey
[thirdparty/pdns.git] / regression-tests.auth-py / authtests.py
1 #!/usr/bin/env python2
2
3 from __future__ import print_function
4 import errno
5 import shutil
6 import os
7 import socket
8 import struct
9 import subprocess
10 import sys
11 import time
12 import unittest
13 import dns
14 import dns.message
15
16 from pprint import pprint
17
class AuthTest(unittest.TestCase):
    """
    Base class for the authoritative server tests.

    setUpClass() writes a fresh configuration under configs/<_confdir>,
    opens a UDP socket and launches pdns_server on _PREFIX.1:_authPort;
    tearDownClass() stops the server and any responders again. Subclasses
    customise _zones/_zone_keys (and optionally startResponders and
    tearDownResponders) and use the query helpers and assertions below.
    """

    # Configuration directory name; setUpClass recreates configs/<_confdir>
    _confdir = 'auth'
    # Port passed to the server as --local-port and used by the query helpers
    _authPort = 5300

    # Root-zone DS rdata; not referenced by this class itself, presumably used
    # by subclasses -- TODO confirm
    _root_DS = "63149 13 1 a59da3f5c1b97fcd5fa2b3b2b0ac91d38a60d33a"

    # The default SOA for zones in the authoritative servers
    _SOA = "ns1.example.net. hostmaster.example.net. 1 3600 1800 1209600 300"

    # The definitions of the zones on the authoritative servers, the key is the
    # zonename and the value is the zonefile content. Several strings are replaced:
    # - {soa} => value of _SOA
    # - {prefix} => value of _PREFIX
    # The key 'ROOT' declares the root zone '.' (its file is ROOT.zone).
    _zones = {
        'example.org': """
example.org. 3600 IN SOA {soa}
example.org. 3600 IN NS ns1.example.org.
example.org. 3600 IN NS ns2.example.org.
ns1.example.org. 3600 IN A {prefix}.10
ns2.example.org. 3600 IN A {prefix}.11
""",
    }

    # Private keys imported as an active KSK for the zone with the same key in
    # _zones; zones without a (non-empty) entry here are not secured.
    _zone_keys = {
        'example.org': """
Private-key-format: v1.2
Algorithm: 13 (ECDSAP256SHA256)
PrivateKey: Lt0v0Gol3pRUFM7fDdcy0IWN0O/MnEmVPA+VylL8Y4U=
""",
    }

    # Command prefix used to launch the server: authbind, then the binary
    # named by $PDNS
    _auth_cmd = ['authbind',
                 os.environ['PDNS']]
    # Environment given to the server process (empty: nothing is inherited)
    _auth_env = {}
    # Running server processes keyed by listen address. This is a class
    # attribute, so the same dict is shared by all subclasses that do not
    # override it.
    _auths = {}

    # Address prefix: addresses are built as '<_PREFIX>.N' and the server
    # listens on _PREFIX.1
    _PREFIX = os.environ['PREFIX']


    @classmethod
    def createConfigDir(cls, confdir):
        """Recreate confdir as an empty directory with mode 0755.

        A confdir that does not exist yet is fine; any other error while
        removing the old tree is propagated.
        """
        try:
            shutil.rmtree(confdir)
        except OSError as e:
            if e.errno != errno.ENOENT:
                raise
        os.mkdir(confdir, 0o755)

    @classmethod
    def generateAuthZone(cls, confdir, zonename, zonecontent):
        """Write '<confdir>/<zonename>.zone' with {prefix} and {soa} in
        zonecontent replaced by _PREFIX and _SOA."""
        with open(os.path.join(confdir, '%s.zone' % zonename), 'w') as zonefile:
            zonefile.write(zonecontent.format(prefix=cls._PREFIX, soa=cls._SOA))

    @classmethod
    def generateAuthNamedConf(cls, confdir, zones):
        """Write '<confdir>/named.conf' for the bind backend.

        Every name in zones becomes a 'type master' zone read from
        '<name>.zone'; the name 'ROOT' is declared as the root zone '.'.
        """
        with open(os.path.join(confdir, 'named.conf'), 'w') as namedconf:
            namedconf.write("""
options {
directory "%s";
};""" % confdir)
            for zonename in zones:
                zone = '.' if zonename == 'ROOT' else zonename

                namedconf.write("""
zone "%s" {
type master;
file "%s.zone";
};""" % (zone, zonename))

    @classmethod
    def _runPdnsutil(cls, confdir, args):
        """Run '$PDNSUTIL --config-dir=<confdir> <args...>', echoing the command.

        @raises AssertionError: if pdnsutil exits with a non-zero status"""
        pdnsutilCmd = [os.environ['PDNSUTIL'],
                       '--config-dir=%s' % confdir] + list(args)

        print(' '.join(pdnsutilCmd))
        try:
            subprocess.check_output(pdnsutilCmd, stderr=subprocess.STDOUT)
        except subprocess.CalledProcessError as e:
            raise AssertionError('%s failed (%d): %s' % (pdnsutilCmd, e.returncode, e.output))

    @classmethod
    def generateAuthConfig(cls, confdir):
        """Write '<confdir>/pdns.conf' and create the bind DNSSEC database.

        The configuration launches the bind and geoip backends, sets all
        cache TTLs to 0 and enables query logging at loglevel 9.

        @raises AssertionError: if 'pdnsutil create-bind-db' fails"""
        bind_dnssec_db = os.path.join(confdir, 'bind-dnssec.sqlite3')

        with open(os.path.join(confdir, 'pdns.conf'), 'w') as pdnsconf:
            pdnsconf.write("""
module-dir=../regression-tests/modules
launch=bind geoip
daemon=no
local-ipv6=
bind-config={confdir}/named.conf
bind-dnssec-db={bind_dnssec_db}
socket-dir={confdir}
cache-ttl=0
negquery-cache-ttl=0
query-cache-ttl=0
log-dns-queries=yes
log-dns-details=yes
loglevel=9
geoip-database-files=../modules/geoipbackend/regression-tests/GeoLiteCity.mmdb
edns-subnet-processing=yes
expand-alias=yes
resolver={prefix}.1:5301
any-to-tcp=no
distributor-threads=1""".format(confdir=confdir, prefix=cls._PREFIX,
                                bind_dnssec_db=bind_dnssec_db))

        cls._runPdnsutil(confdir, ['create-bind-db', bind_dnssec_db])

    @classmethod
    def secureZone(cls, confdir, zonename, key=None):
        """Secure a zone with pdnsutil.

        @param confdir: configuration directory passed to pdnsutil
        @param zonename: zone to secure ('ROOT' designates the root zone '.')
        @param key: optional private key text; when given it is written to
                    '<confdir>/dnssec.key' and imported as an active KSK,
                    otherwise 'pdnsutil secure-zone' is run for the zone
        @raises AssertionError: if pdnsutil exits with a non-zero status"""
        zone = '.' if zonename == 'ROOT' else zonename
        if not key:
            args = ['secure-zone', zone]
        else:
            keyfile = os.path.join(confdir, 'dnssec.key')
            with open(keyfile, 'w') as fdKeyfile:
                fdKeyfile.write(key)
            args = ['import-zone-key', zone, keyfile, 'active', 'ksk']

        cls._runPdnsutil(confdir, args)

    @classmethod
    def generateAllAuthConfig(cls, confdir):
        """Write all server configuration into confdir.

        Does nothing when _zones is empty; otherwise writes pdns.conf and
        named.conf, then one zone file per entry in _zones, securing every
        zone that has a non-empty key in _zone_keys."""
        if not cls._zones:
            return

        cls.generateAuthConfig(confdir)
        cls.generateAuthNamedConf(confdir, cls._zones.keys())

        for zonename, zonecontent in cls._zones.items():
            cls.generateAuthZone(confdir, zonename, zonecontent)
            key = cls._zone_keys.get(zonename)
            if key:
                cls.secureZone(confdir, zonename, key)

    @classmethod
    def startAuth(cls, confdir, ipaddress):
        """Launch pdns_server for confdir on ipaddress:_authPort.

        The process is stored in _auths[ipaddress] and its output goes to
        '<confdir>/pdns.log'. If the server has already exited two seconds
        after launch, the log is printed and the test run is aborted via
        sys.exit() with the server's exit status."""
        print("Launching pdns_server..")
        authcmd = list(cls._auth_cmd)
        authcmd.append('--config-dir=%s' % confdir)
        authcmd.append('--local-address=%s' % ipaddress)
        authcmd.append('--local-port=%s' % cls._authPort)
        authcmd.append('--loglevel=9')
        authcmd.append('--enable-lua-records')
        print(' '.join(authcmd))

        logFile = os.path.join(confdir, 'pdns.log')
        with open(logFile, 'w') as fdLog:
            cls._auths[ipaddress] = subprocess.Popen(authcmd, close_fds=True,
                                                     stdout=fdLog, stderr=fdLog,
                                                     env=cls._auth_env)

        # Give the server time to start, or to fail to start.
        time.sleep(2)

        if cls._auths[ipaddress].poll() is not None:
            try:
                cls._auths[ipaddress].kill()
            except OSError as e:
                if e.errno != errno.ESRCH:
                    raise
            with open(logFile, 'r') as fdLog:
                print(fdLog.read())
            sys.exit(cls._auths[ipaddress].returncode)

    @classmethod
    def setUpSockets(cls):
        """Create the UDP socket used by sendUDPQuery, connected to
        _PREFIX.1:_authPort with a 2 second timeout."""
        print("Setting up UDP socket..")
        cls._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        cls._sock.settimeout(2.0)
        cls._sock.connect((cls._PREFIX + ".1", cls._authPort))

    @classmethod
    def startResponders(cls):
        """Hook for subclasses that need extra responders; none in the base class."""
        pass

    @classmethod
    def setUpClass(cls):
        """Open the UDP socket, start responders, regenerate
        configs/<_confdir> and launch the server on _PREFIX.1."""
        cls.setUpSockets()

        cls.startResponders()

        confdir = os.path.join('configs', cls._confdir)
        cls.createConfigDir(confdir)

        cls.generateAllAuthConfig(confdir)
        cls.startAuth(confdir, cls._PREFIX + ".1")

        print("Launching tests..")

    @classmethod
    def tearDownClass(cls):
        """Stop the server(s), then any responders started for this class."""
        # Defined only once: a second tearDownClass used to override this one,
        # which meant tearDownResponders() was never called.
        cls.tearDownAuth()
        cls.tearDownResponders()

    @classmethod
    def tearDownResponders(cls):
        """Hook for subclasses that started responders; none in the base class."""
        pass

    @classmethod
    def tearDownAuth(cls):
        """Stop every server process recorded in _auths.

        Each process is sent terminate(); one still running after a grace
        period (0.1s when PDNSRECURSOR_FAST_TESTS is set in the environment,
        1s otherwise) is killed and reaped. Processes that are already gone
        (ESRCH) are ignored."""
        if 'PDNSRECURSOR_FAST_TESTS' in os.environ:
            delay = 0.1
        else:
            delay = 1.0

        for auth in cls._auths.values():
            try:
                auth.terminate()
                if auth.poll() is None:
                    time.sleep(delay)
                if auth.poll() is None:
                    auth.kill()
                    auth.wait()
            except OSError as e:
                if e.errno != errno.ESRCH:
                    raise

    @classmethod
    def sendUDPQuery(cls, query, timeout=2.0, decode=True, fwparams=None):
        """Send query over the shared UDP socket and wait for one response.

        @param query: the dns.message.Message to send
        @param timeout: seconds to wait for the response; a falsy value keeps
                        the socket's current timeout
        @param decode: when False, return the raw response bytes
        @param fwparams: extra keyword arguments for dns.message.from_wire
        @return: the decoded response (or raw bytes when decode is False),
                 or None on timeout or an empty response"""
        # None instead of a shared mutable default argument
        if fwparams is None:
            fwparams = {}

        if timeout:
            previousTimeout = cls._sock.gettimeout()
            cls._sock.settimeout(timeout)

        try:
            cls._sock.send(query.to_wire())
            data = cls._sock.recv(4096)
        except socket.timeout:
            data = None
        finally:
            if timeout:
                # Restore the previous timeout (2s from setUpSockets) rather
                # than leaving the shared socket blocking for later callers.
                cls._sock.settimeout(previousTimeout)

        if not data:
            return None
        if not decode:
            return data
        return dns.message.from_wire(data, **fwparams)

    @staticmethod
    def _recvExactly(sock, length):
        """Read exactly length bytes from sock; fewer only if the peer closes."""
        chunks = []
        remaining = length
        while remaining > 0:
            chunk = sock.recv(remaining)
            if not chunk:
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    @classmethod
    def sendTCPQuery(cls, query, timeout=2.0):
        """Send query to the server over a new TCP connection.

        Connects to _PREFIX.1:_authPort (where setUpClass starts the server),
        sends the query with its two-byte length prefix and reads one
        length-prefixed response.

        @param query: the dns.message.Message to send
        @param timeout: socket timeout in seconds; a falsy value leaves the
                        socket's default timeout in place
        @return: the decoded dns.message.Message, or None on timeout, network
                 error during the exchange, or an incomplete length prefix"""
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            if timeout:
                sock.settimeout(timeout)

            # Connection errors still propagate to the caller; the outer
            # finally only guarantees that the socket is closed.
            sock.connect((cls._PREFIX + ".1", cls._authPort))

            try:
                wire = query.to_wire()
                # sendall/_recvExactly: plain send/recv may transfer partial data on TCP
                sock.sendall(struct.pack("!H", len(wire)) + wire)
                data = cls._recvExactly(sock, 2)
                if len(data) == 2:
                    (datalen,) = struct.unpack("!H", data)
                    data = cls._recvExactly(sock, datalen)
                else:
                    data = None
            except socket.timeout as e:
                print("Timeout: %s" % (str(e)))
                data = None
            except socket.error as e:
                print("Network error: %s" % (str(e)))
                data = None
        finally:
            sock.close()

        message = None
        if data:
            message = dns.message.from_wire(data)
        return message

    def setUp(self):
        """Called before every test; nothing to prepare."""
        return

    ## Functions for comparisons
    def assertMessageHasFlags(self, msg, flags, ednsflags=None):
        """Asserts that msg has exactly the header flags in flags and at least
        the EDNS flags in ednsflags

        @param msg: the dns.message.Message to check
        @param flags: a list of strings with flag mnemonics (like ['RD', 'RA']);
                      all must be set and no other header flag may be set
        @param ednsflags: a list of strings with edns-flag mnemonics (like ['DO']);
                          all must be set, extra EDNS flags are tolerated.
                          Defaults to no required EDNS flags.
        @raises TypeError: if msg is not a Message or a flag list is not a list of strings
        @raises AssertionError: on a missing flag or an unexpected header flag"""
        # None instead of a shared mutable default argument
        if ednsflags is None:
            ednsflags = []

        if not isinstance(msg, dns.message.Message):
            raise TypeError("msg is not a dns.message.Message")

        if isinstance(flags, list):
            for elem in flags:
                if not isinstance(elem, str):
                    raise TypeError("flags is not a list of strings")
        else:
            raise TypeError("flags is not a list of strings")

        if isinstance(ednsflags, list):
            for elem in ednsflags:
                if not isinstance(elem, str):
                    raise TypeError("ednsflags is not a list of strings")
        else:
            raise TypeError("ednsflags is not a list of strings")

        msgFlags = dns.flags.to_text(msg.flags).split()
        missingFlags = [flag for flag in flags if flag not in msgFlags]

        msgEdnsFlags = dns.flags.edns_to_text(msg.ednsflags).split()
        missingEdnsFlags = [ednsflag for ednsflag in ednsflags if ednsflag not in msgEdnsFlags]

        # More header flags than requested means an unexpected flag is set
        if len(missingFlags) or len(missingEdnsFlags) or len(msgFlags) > len(flags):
            raise AssertionError("Expected flags '%s' (EDNS: '%s'), found '%s' (EDNS: '%s') in query %s" %
                                 (' '.join(flags), ' '.join(ednsflags),
                                  ' '.join(msgFlags), ' '.join(msgEdnsFlags),
                                  msg.question[0]))

    def assertMessageIsAuthenticated(self, msg):
        """Asserts that the message has the AD bit set

        @param msg: the dns.message.Message to check
        @raises TypeError: if msg is not a dns.message.Message"""

        if not isinstance(msg, dns.message.Message):
            raise TypeError("msg is not a dns.message.Message")

        msgFlags = dns.flags.to_text(msg.flags)
        self.assertTrue('AD' in msgFlags, "No AD flag found in the message for %s" % msg.question[0].name)

    def assertRRsetInAnswer(self, msg, rrset):
        """Asserts the rrset (without comparing TTL) exists in the
        answer section of msg

        @param msg: the dns.message.Message to check
        @param rrset: a dns.rrset.RRset object
        @raises TypeError: if msg or rrset has the wrong type
        @raises AssertionError: if no answer RRset has rrset's owner, class and
                                type, or a matching one has different records"""

        ret = ''
        if not isinstance(msg, dns.message.Message):
            raise TypeError("msg is not a dns.message.Message")

        if not isinstance(rrset, dns.rrset.RRset):
            raise TypeError("rrset is not a dns.rrset.RRset")

        found = False
        for ans in msg.answer:
            ret += "%s\n" % ans.to_text()
            if ans.match(rrset.name, rrset.rdclass, rrset.rdtype, 0, None):
                self.assertEqual(ans, rrset, "'%s' != '%s'" % (ans.to_text(), rrset.to_text()))
                found = True

        if not found:
            raise AssertionError("RRset not found in answer\n\n%s" % ret)

    def sortRRsets(self, rrsets):
        """Sorts RRsets by owner name, then record type (more useful than
        dnspython's default ordering)

        @param rrsets: an iterable of dns.rrset.RRset objects
        @return: a new sorted list"""

        return sorted(rrsets, key=lambda rrset: (rrset.name, rrset.rdtype))

    def assertAnyRRsetInAnswer(self, msg, rrsets):
        """Asserts that any of the supplied rrsets exists (without comparing TTL)
        in the answer section of msg

        @param msg: the dns.message.Message to check
        @param rrsets: an array of dns.rrset.RRset object
        @raises TypeError: if msg or any element of rrsets has the wrong type"""

        if not isinstance(msg, dns.message.Message):
            raise TypeError("msg is not a dns.message.Message")

        found = False
        for rrset in rrsets:
            # Every element is type-checked, even after a match was found
            if not isinstance(rrset, dns.rrset.RRset):
                raise TypeError("rrset is not a dns.rrset.RRset")
            for ans in msg.answer:
                if ans.match(rrset.name, rrset.rdclass, rrset.rdtype, 0, None):
                    if ans == rrset:
                        found = True

        if not found:
            raise AssertionError("RRset not found in answer\n%s" %
                                 "\n".join(([ans.to_text() for ans in msg.answer])))

    def assertMatchingRRSIGInAnswer(self, msg, coveredRRset, keys=None):
        """Looks for coveredRRset in the answer section and checks that there is
        an RRSIG RRset covering it. If keys is not None, this function will also
        try to validate the RRset against the RRSIG

        @param msg: The dns.message.Message to check
        @param coveredRRset: The RRSet to check for
        @param keys: a dictionary keyed by dns.name.Name with node or rdataset values to use for validation
        @raises TypeError: if msg or coveredRRset has the wrong type
        @raises AssertionError: if the RRset or its RRSIG is missing, or validation fails"""

        if not isinstance(msg, dns.message.Message):
            raise TypeError("msg is not a dns.message.Message")

        if not isinstance(coveredRRset, dns.rrset.RRset):
            raise TypeError("coveredRRset is not a dns.rrset.RRset")

        msgRRsigRRSet = None
        msgRRSet = None

        ret = ''
        for ans in msg.answer:
            ret += ans.to_text() + "\n"

            if ans.match(coveredRRset.name, coveredRRset.rdclass, coveredRRset.rdtype, 0, None):
                msgRRSet = ans
            # The RRSIG carries the class of the RRset it covers (previously
            # hard-coded to IN, identical for IN data).
            if ans.match(coveredRRset.name, coveredRRset.rdclass, dns.rdatatype.RRSIG, coveredRRset.rdtype, None):
                msgRRsigRRSet = ans
            if msgRRSet and msgRRsigRRSet:
                break

        if not msgRRSet:
            raise AssertionError("RRset for '%s' not found in answer" % msg.question[0].to_text())

        if not msgRRsigRRSet:
            raise AssertionError("No RRSIGs found in answer for %s:\nFull answer:\n%s" % (msg.question[0].to_text(), ret))

        if keys:
            # dns.dnssec is not imported at the top of this file; import it
            # here instead of relying on another module having loaded it.
            # (A from-import avoids shadowing the global 'dns' name.)
            from dns import dnssec
            try:
                dnssec.validate(msgRRSet, msgRRsigRRSet.to_rdataset(), keys)
            except dnssec.ValidationFailure as e:
                raise AssertionError("Signature validation failed for %s:\n%s" % (msg.question[0].to_text(), e))

    def assertNoRRSIGsInAnswer(self, msg):
        """Checks that there are _no_ RRSIGs in the answer section of msg

        @raises AssertionError: listing the owner names of any RRSIG RRsets found"""

        if not isinstance(msg, dns.message.Message):
            raise TypeError("msg is not a dns.message.Message")

        ret = ""
        for ans in msg.answer:
            if ans.rdtype == dns.rdatatype.RRSIG:
                ret += ans.name.to_text() + "\n"

        if len(ret):
            raise AssertionError("RRSIG found in answers for:\n%s" % ret)

    def assertAnswerEmpty(self, msg):
        """Asserts that the answer section of msg is empty"""
        self.assertTrue(len(msg.answer) == 0, "Data found in the the answer section for %s:\n%s" % (msg.question[0].to_text(), '\n'.join([i.to_text() for i in msg.answer])))

    def assertAnswerNotEmpty(self, msg):
        """Asserts that the answer section of msg holds at least one RRset"""
        self.assertTrue(len(msg.answer) > 0, "Answer is empty")

    def assertRcodeEqual(self, msg, rcode):
        """Asserts that msg has the expected response code

        @param msg: the dns.message.Message to check
        @param rcode: expected rcode, as an int or a mnemonic string (like 'NXDOMAIN')
        @raises TypeError: if msg is not a Message or rcode is neither str nor int
        @raises AssertionError: if the rcodes differ"""
        if not isinstance(msg, dns.message.Message):
            raise TypeError("msg is not a dns.message.Message but a %s" % type(msg))

        if not isinstance(rcode, int):
            if isinstance(rcode, str):
                rcode = dns.rcode.from_text(rcode)
            else:
                raise TypeError("rcode is neither a str nor int")

        if msg.rcode() != rcode:
            # Public API instead of the private dns.rcode._by_value table,
            # whose availability depends on the dnspython version -- it also
            # handles codes that have no mnemonic.
            msgRcode = dns.rcode.to_text(msg.rcode())
            wantedRcode = dns.rcode.to_text(rcode)

            raise AssertionError("Rcode for %s is %s, expected %s." % (msg.question[0].to_text(), msgRcode, wantedRcode))

    def assertAuthorityHasSOA(self, msg):
        """Asserts that the authority section of msg contains a SOA RRset"""
        if not isinstance(msg, dns.message.Message):
            raise TypeError("msg is not a dns.message.Message but a %s" % type(msg))

        found = False
        for rrset in msg.authority:
            if rrset.rdtype == dns.rdatatype.SOA:
                found = True
                break

        if not found:
            raise AssertionError("No SOA record found in the authority section:\n%s" % msg.to_text())