]> git.ipfire.org Git - thirdparty/pdns.git/blob - regression-tests.auth-py/authtests.py
Merge pull request #6823 from klaus3000/load-ourserial-on-NOTIFY
[thirdparty/pdns.git] / regression-tests.auth-py / authtests.py
1 #!/usr/bin/env python2
2
3 from __future__ import print_function
4 import errno
5 import shutil
6 import os
7 import socket
8 import struct
9 import subprocess
10 import sys
11 import time
12 import unittest
13 import dns
14 import dns.message
15
16 from pprint import pprint
17
class AuthTest(unittest.TestCase):
    """
    Setup auth required for the tests
    """

    _confdir = 'auth'
    _authPort = 5300

    _root_DS = "63149 13 1 a59da3f5c1b97fcd5fa2b3b2b0ac91d38a60d33a"

    # The default SOA for zones in the authoritative servers
    _SOA = "ns1.example.net. hostmaster.example.net. 1 3600 1800 1209600 300"

    # The definitions of the zones on the authoritative servers, the key is the
    # zonename and the value is the zonefile content. several strings are replaced:
    #   - {soa} => value of _SOA
    #   - {prefix} => value of _PREFIX
    _zones = {
        'example.org': """
example.org. 3600 IN SOA {soa}
example.org. 3600 IN NS ns1.example.org.
example.org. 3600 IN NS ns2.example.org.
ns1.example.org. 3600 IN A {prefix}.10
ns2.example.org. 3600 IN A {prefix}.11
""",
    }

    # Private keys keyed by zone name; generateAllAuthConfig() imports each one
    # through secureZone() as an active KSK.
    _zone_keys = {
        'example.org': """
Private-key-format: v1.2
Algorithm: 13 (ECDSAP256SHA256)
PrivateKey: Lt0v0Gol3pRUFM7fDdcy0IWN0O/MnEmVPA+VylL8Y4U=
""",
    }

    # Command prefix used to launch the server; startAuth() appends the
    # config dir, listen address/port and logging options.
    _auth_cmd = ['authbind',
                 os.environ['PDNS']]
    # Environment handed to the server process (empty unless overridden).
    _auth_env = {}
    # Running server processes, keyed by the IP address they were started on.
    _auths = {}

    _PREFIX = os.environ['PREFIX']

    @classmethod
    def createConfigDir(cls, confdir):
        """(Re)create confdir as an empty directory with mode 0755.

        @param confdir: path of the directory; an existing one is removed first"""
        try:
            shutil.rmtree(confdir)
        except OSError as e:
            # A missing directory just means there is nothing to clean up.
            if e.errno != errno.ENOENT:
                raise
        os.mkdir(confdir, 0o755)

    @classmethod
    def generateAuthZone(cls, confdir, zonename, zonecontent):
        """Write confdir/<zonename>.zone, substituting {prefix} and {soa}.

        @param confdir: configuration directory
        @param zonename: zone name, also used as the file name stem
        @param zonecontent: zone file template (see _zones)"""
        with open(os.path.join(confdir, '%s.zone' % zonename), 'w') as zonefile:
            zonefile.write(zonecontent.format(prefix=cls._PREFIX, soa=cls._SOA))

    @classmethod
    def generateAuthNamedConf(cls, confdir, zones):
        """Write confdir/named.conf declaring every zone as a master zone.

        @param confdir: configuration directory (also the named.conf 'directory')
        @param zones: iterable of zone names; 'ROOT' is declared as '.' but
                      still reads its data from ROOT.zone"""
        with open(os.path.join(confdir, 'named.conf'), 'w') as namedconf:
            namedconf.write("""
options {
directory "%s";
};""" % confdir)
            for zonename in zones:
                zone = '.' if zonename == 'ROOT' else zonename

                namedconf.write("""
zone "%s" {
type master;
file "%s.zone";
};""" % (zone, zonename))

    @classmethod
    def generateAuthConfig(cls, confdir):
        """Write confdir/pdns.conf and create the bind DNSSEC sqlite database.

        @param confdir: configuration directory
        @raise subprocess.CalledProcessError: if pdnsutil create-bind-db fails
               (its output is printed first)"""
        bind_dnssec_db = os.path.join(confdir, 'bind-dnssec.sqlite3')

        with open(os.path.join(confdir, 'pdns.conf'), 'w') as pdnsconf:
            pdnsconf.write("""
module-dir=../regression-tests/modules
launch=bind geoip
daemon=no
local-ipv6=
bind-config={confdir}/named.conf
bind-dnssec-db={bind_dnssec_db}
socket-dir={confdir}
cache-ttl=0
negquery-cache-ttl=0
query-cache-ttl=0
log-dns-queries=yes
log-dns-details=yes
loglevel=9
geoip-database-files=../modules/geoipbackend/regression-tests/GeoLiteCity.mmdb
edns-subnet-processing=yes
expand-alias=yes
resolver={prefix}.1:5301
any-to-tcp=no
distributor-threads=1""".format(confdir=confdir, prefix=cls._PREFIX,
                                bind_dnssec_db=bind_dnssec_db))

        pdnsutilCmd = [os.environ['PDNSUTIL'],
                       '--config-dir=%s' % confdir,
                       'create-bind-db',
                       bind_dnssec_db]

        print(' '.join(pdnsutilCmd))
        try:
            subprocess.check_output(pdnsutilCmd, stderr=subprocess.STDOUT)
        except subprocess.CalledProcessError as e:
            print(e.output)
            raise

    @classmethod
    def secureZone(cls, confdir, zonename, key=None):
        """Sign a zone with pdnsutil.

        Without a key, 'secure-zone' lets pdnsutil generate keys. With a key,
        it is written to confdir/dnssec.key and imported as an active KSK.

        @param confdir: configuration directory
        @param zonename: zone to secure ('ROOT' means '.')
        @param key: optional private key text
        @raise subprocess.CalledProcessError: if pdnsutil fails (output printed)"""
        zone = '.' if zonename == 'ROOT' else zonename
        if not key:
            pdnsutilCmd = [os.environ['PDNSUTIL'],
                           '--config-dir=%s' % confdir,
                           'secure-zone',
                           zone]
        else:
            keyfile = os.path.join(confdir, 'dnssec.key')
            with open(keyfile, 'w') as fdKeyfile:
                fdKeyfile.write(key)

            pdnsutilCmd = [os.environ['PDNSUTIL'],
                           '--config-dir=%s' % confdir,
                           'import-zone-key',
                           zone,
                           keyfile,
                           'active',
                           'ksk']

        print(' '.join(pdnsutilCmd))
        try:
            subprocess.check_output(pdnsutilCmd, stderr=subprocess.STDOUT)
        except subprocess.CalledProcessError as e:
            print(e.output)
            raise

    @classmethod
    def generateAllAuthConfig(cls, confdir):
        """Generate pdns.conf, named.conf and all zone files, then secure every
        zone that has an entry in _zone_keys. Does nothing when _zones is empty.

        @param confdir: configuration directory"""
        if cls._zones:
            cls.generateAuthConfig(confdir)
            cls.generateAuthNamedConf(confdir, cls._zones.keys())

            for zonename, zonecontent in cls._zones.items():
                cls.generateAuthZone(confdir,
                                     zonename,
                                     zonecontent)
                if cls._zone_keys.get(zonename, None):
                    cls.secureZone(confdir, zonename, cls._zone_keys.get(zonename))

    @classmethod
    def startAuth(cls, confdir, ipaddress):
        """Launch the authoritative server listening on ipaddress:_authPort.

        Output goes to confdir/pdns.log. If the process has already exited
        after a 2 second grace period, the log is printed and the whole test
        run exits with the server's return code.

        @param confdir: configuration directory
        @param ipaddress: address to listen on; also the key in _auths"""
        print("Launching pdns_server..")
        authcmd = list(cls._auth_cmd)
        authcmd.append('--config-dir=%s' % confdir)
        authcmd.append('--local-address=%s' % ipaddress)
        authcmd.append('--local-port=%s' % cls._authPort)
        authcmd.append('--loglevel=9')
        authcmd.append('--enable-lua-record')
        print(' '.join(authcmd))

        logFile = os.path.join(confdir, 'pdns.log')
        with open(logFile, 'w') as fdLog:
            cls._auths[ipaddress] = subprocess.Popen(authcmd, close_fds=True,
                                                     stdout=fdLog, stderr=fdLog,
                                                     env=cls._auth_env)

        time.sleep(2)

        if cls._auths[ipaddress].poll() is not None:
            try:
                cls._auths[ipaddress].kill()
            except OSError as e:
                if e.errno != errno.ESRCH:
                    raise
            with open(logFile, 'r') as fdLog:
                print(fdLog.read())
            sys.exit(cls._auths[ipaddress].returncode)

    @classmethod
    def setUpSockets(cls):
        """Create the shared UDP socket connected to <_PREFIX>.1:_authPort."""
        print("Setting up UDP socket..")
        cls._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        cls._sock.settimeout(2.0)
        cls._sock.connect((cls._PREFIX + ".1", cls._authPort))

    @classmethod
    def startResponders(cls):
        """Hook for subclasses that need extra responders; no-op here."""
        pass

    @classmethod
    def setUpClass(cls):
        """Create sockets and responders, generate the config, start the server."""
        cls.setUpSockets()

        cls.startResponders()

        confdir = os.path.join('configs', cls._confdir)
        cls.createConfigDir(confdir)

        cls.generateAllAuthConfig(confdir)
        cls.startAuth(confdir, cls._PREFIX + ".1")

        print("Launching tests..")

    @classmethod
    def tearDownClass(cls):
        """Stop the server(s), then let subclasses stop their responders.

        Previously this method was defined twice and the later definition
        (which skipped tearDownResponders) silently won."""
        cls.tearDownAuth()
        cls.tearDownResponders()

    @classmethod
    def tearDownResponders(cls):
        """Hook mirroring startResponders(); no-op here."""
        pass

    @classmethod
    def tearDownAuth(cls):
        """Terminate every started server, escalating to kill() if it is still
        running after a short delay, and reap it."""
        # NOTE(review): this reads the recursor suite's variable name
        # (PDNSRECURSOR_FAST_TESTS) -- confirm it is intended for auth tests.
        if 'PDNSRECURSOR_FAST_TESTS' in os.environ:
            delay = 0.1
        else:
            delay = 1.0

        for _, auth in cls._auths.items():
            try:
                auth.terminate()
                if auth.poll() is None:
                    time.sleep(delay)
                if auth.poll() is None:
                    auth.kill()
                auth.wait()
            except OSError as e:
                if e.errno != errno.ESRCH:
                    raise

    @classmethod
    def sendUDPQuery(cls, query, timeout=2.0, decode=True, fwparams=None):
        """Send query over the shared UDP socket and return the reply.

        @param query: the dns.message.Message to send
        @param timeout: timeout in seconds for this exchange; when truthy the
                        socket is reset to blocking mode (None) afterwards
        @param decode: when False, return the raw reply instead of parsing it
        @param fwparams: optional dict of keyword arguments for
                         dns.message.from_wire
        @return: the parsed message (or raw data), or None on timeout"""
        if fwparams is None:
            fwparams = {}
        if timeout:
            cls._sock.settimeout(timeout)

        try:
            cls._sock.send(query.to_wire())
            data = cls._sock.recv(4096)
        except socket.timeout:
            data = None
        finally:
            if timeout:
                cls._sock.settimeout(None)

        if not data:
            return None
        if not decode:
            return data
        return dns.message.from_wire(data, **fwparams)

    @staticmethod
    def _recvExactly(sock, count):
        """Read count bytes from sock, returning fewer only if the peer closes."""
        chunks = []
        remaining = count
        while remaining > 0:
            chunk = sock.recv(remaining)
            if not chunk:
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    @classmethod
    def sendTCPQuery(cls, query, timeout=2.0):
        """Send query to the server over a fresh TCP connection.

        Connects to <_PREFIX>.1:_authPort, the address the server is started
        on. Connection failures propagate to the caller; timeouts and network
        errors during the exchange are printed and yield None.

        @param query: the dns.message.Message to send
        @param timeout: socket timeout in seconds (falsy: blocking)
        @return: the parsed reply, or None if nothing usable was received"""
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            if timeout:
                sock.settimeout(timeout)

            sock.connect((cls._PREFIX + ".1", cls._authPort))

            try:
                wire = query.to_wire()
                # DNS over TCP frames each message with a 2-byte length prefix;
                # sendall/_recvExactly guard against partial writes and reads.
                sock.sendall(struct.pack("!H", len(wire)) + wire)
                data = cls._recvExactly(sock, 2)
                if len(data) == 2:
                    (datalen,) = struct.unpack("!H", data)
                    data = cls._recvExactly(sock, datalen)
                else:
                    data = None
            except socket.timeout as e:
                print("Timeout: %s" % (str(e)))
                data = None
            except socket.error as e:
                print("Network error: %s" % (str(e)))
                data = None
        finally:
            sock.close()

        message = None
        if data:
            message = dns.message.from_wire(data)
        return message

    def setUp(self):
        # This function is called before every tests
        return

    # Functions for comparisons
    def assertMessageHasFlags(self, msg, flags, ednsflags=None):
        """Asserts that msg has exactly the flags from flags set (extra header
        flags fail the assertion) and at least the EDNS flags from ednsflags

        @param msg: the dns.message.Message to check
        @param flags: a list of strings with flag mnemonics (like ['RD', 'RA'])
        @param ednsflags: a list of strings with edns-flag mnemonics (like ['DO']);
                          None means no EDNS flags are required"""
        if ednsflags is None:
            ednsflags = []

        if not isinstance(msg, dns.message.Message):
            raise TypeError("msg is not a dns.message.Message")

        if isinstance(flags, list):
            for elem in flags:
                if not isinstance(elem, str):
                    raise TypeError("flags is not a list of strings")
        else:
            raise TypeError("flags is not a list of strings")

        if isinstance(ednsflags, list):
            for elem in ednsflags:
                if not isinstance(elem, str):
                    raise TypeError("ednsflags is not a list of strings")
        else:
            raise TypeError("ednsflags is not a list of strings")

        msgFlags = dns.flags.to_text(msg.flags).split()
        missingFlags = [flag for flag in flags if flag not in msgFlags]

        msgEdnsFlags = dns.flags.edns_to_text(msg.ednsflags).split()
        missingEdnsFlags = [ednsflag for ednsflag in ednsflags if ednsflag not in msgEdnsFlags]

        if len(missingFlags) or len(missingEdnsFlags) or len(msgFlags) > len(flags):
            raise AssertionError("Expected flags '%s' (EDNS: '%s'), found '%s' (EDNS: '%s') in query %s" %
                                 (' '.join(flags), ' '.join(ednsflags),
                                  ' '.join(msgFlags), ' '.join(msgEdnsFlags),
                                  msg.question[0]))

    def assertMessageIsAuthenticated(self, msg):
        """Asserts that the message has the AD bit set

        @param msg: the dns.message.Message to check"""
        if not isinstance(msg, dns.message.Message):
            raise TypeError("msg is not a dns.message.Message")

        # Compare whole flag mnemonics, not a substring of the flags text.
        msgFlags = dns.flags.to_text(msg.flags).split()
        self.assertTrue('AD' in msgFlags, "No AD flag found in the message for %s" % msg.question[0].name)

    def assertRRsetInAnswer(self, msg, rrset):
        """Asserts the rrset (without comparing TTL) exists in the
        answer section of msg

        @param msg: the dns.message.Message to check
        @param rrset: a dns.rrset.RRset object"""
        ret = ''
        if not isinstance(msg, dns.message.Message):
            raise TypeError("msg is not a dns.message.Message")

        if not isinstance(rrset, dns.rrset.RRset):
            raise TypeError("rrset is not a dns.rrset.RRset")

        found = False
        for ans in msg.answer:
            ret += "%s\n" % ans.to_text()
            if ans.match(rrset.name, rrset.rdclass, rrset.rdtype, 0, None):
                self.assertEqual(ans, rrset, "'%s' != '%s'" % (ans.to_text(), rrset.to_text()))
                found = True

        if not found:
            raise AssertionError("RRset not found in answer\n\n%s" % ret)

    def sortRRsets(self, rrsets):
        """Sorts RRsets in a more useful way than dnspython's default behaviour

        @param rrsets: an array of dns.rrset.RRset objects
        @return: a new list ordered by (name, rdtype)"""
        return sorted(rrsets, key=lambda rrset: (rrset.name, rrset.rdtype))

    def assertAnyRRsetInAnswer(self, msg, rrsets):
        """Asserts that any of the supplied rrsets exists (without comparing TTL)
        in the answer section of msg

        @param msg: the dns.message.Message to check
        @param rrsets: an array of dns.rrset.RRset object"""
        if not isinstance(msg, dns.message.Message):
            raise TypeError("msg is not a dns.message.Message")

        # Validate every candidate up front, as the full scan did before.
        for rrset in rrsets:
            if not isinstance(rrset, dns.rrset.RRset):
                raise TypeError("rrset is not a dns.rrset.RRset")

        found = any(ans == rrset
                    for rrset in rrsets
                    for ans in msg.answer
                    if ans.match(rrset.name, rrset.rdclass, rrset.rdtype, 0, None))

        if not found:
            raise AssertionError("RRset not found in answer\n%s" %
                                 "\n".join(([ans.to_text() for ans in msg.answer])))

    def assertMatchingRRSIGInAnswer(self, msg, coveredRRset, keys=None):
        """Looks for coveredRRset in the answer section and if there is an RRSIG RRset
        that covers that RRset. If keys is not None, this function will also try to
        validate the RRset against the RRSIG

        @param msg: The dns.message.Message to check
        @param coveredRRset: The RRSet to check for
        @param keys: a dictionary keyed by dns.name.Name with node or rdataset values to use for validation"""
        if not isinstance(msg, dns.message.Message):
            raise TypeError("msg is not a dns.message.Message")

        if not isinstance(coveredRRset, dns.rrset.RRset):
            raise TypeError("coveredRRset is not a dns.rrset.RRset")

        msgRRsigRRSet = None
        msgRRSet = None

        ret = ''
        for ans in msg.answer:
            ret += ans.to_text() + "\n"

            if ans.match(coveredRRset.name, coveredRRset.rdclass, coveredRRset.rdtype, 0, None):
                msgRRSet = ans
            if ans.match(coveredRRset.name, dns.rdataclass.IN, dns.rdatatype.RRSIG, coveredRRset.rdtype, None):
                msgRRsigRRSet = ans
            if msgRRSet and msgRRsigRRSet:
                break

        if not msgRRSet:
            raise AssertionError("RRset for '%s' not found in answer" % msg.question[0].to_text())

        if not msgRRsigRRSet:
            raise AssertionError("No RRSIGs found in answer for %s:\nFull answer:\n%s" % (msg.question[0].to_text(), ret))

        if keys:
            # The file-level imports never load the dns.dnssec submodule. Import it
            # here under its own name: a local "import dns.dnssec" would make
            # "dns" a local variable for this whole function.
            from dns import dnssec
            try:
                dnssec.validate(msgRRSet, msgRRsigRRSet.to_rdataset(), keys)
            except dnssec.ValidationFailure as e:
                raise AssertionError("Signature validation failed for %s:\n%s" % (msg.question[0].to_text(), e))

    def assertNoRRSIGsInAnswer(self, msg):
        """Checks if there are _no_ RRSIGs in the answer section of msg"""
        if not isinstance(msg, dns.message.Message):
            raise TypeError("msg is not a dns.message.Message")

        ret = ""
        for ans in msg.answer:
            if ans.rdtype == dns.rdatatype.RRSIG:
                ret += ans.name.to_text() + "\n"

        if len(ret):
            raise AssertionError("RRSIG found in answers for:\n%s" % ret)

    def assertAnswerEmpty(self, msg):
        """Asserts that the answer section of msg is empty."""
        self.assertTrue(len(msg.answer) == 0, "Data found in the answer section for %s:\n%s" % (msg.question[0].to_text(), '\n'.join([i.to_text() for i in msg.answer])))

    def assertAnswerNotEmpty(self, msg):
        """Asserts that the answer section of msg has at least one RRset."""
        self.assertTrue(len(msg.answer) > 0, "Answer is empty")

    def assertRcodeEqual(self, msg, rcode):
        """Asserts that msg has the given rcode.

        @param msg: the dns.message.Message to check
        @param rcode: expected rcode, as an int or a mnemonic string ('NXDOMAIN')"""
        if not isinstance(msg, dns.message.Message):
            raise TypeError("msg is not a dns.message.Message but a %s" % type(msg))

        if not isinstance(rcode, int):
            if isinstance(rcode, str):
                rcode = dns.rcode.from_text(rcode)
            else:
                raise TypeError("rcode is neither a str nor int")

        if msg.rcode() != rcode:
            # Public API instead of the private dns.rcode._by_value table.
            msgRcode = dns.rcode.to_text(msg.rcode())
            wantedRcode = dns.rcode.to_text(rcode)

            raise AssertionError("Rcode for %s is %s, expected %s." % (msg.question[0].to_text(), msgRcode, wantedRcode))

    def assertAuthorityHasSOA(self, msg):
        """Asserts that the authority section of msg contains an SOA RRset."""
        if not isinstance(msg, dns.message.Message):
            raise TypeError("msg is not a dns.message.Message but a %s" % type(msg))

        found = any(rrset.rdtype == dns.rdatatype.SOA for rrset in msg.authority)

        if not found:
            raise AssertionError("No SOA record found in the authority section:\n%s" % msg.to_text())