]> git.ipfire.org Git - thirdparty/pdns.git/blob - regression-tests.auth-py/authtests.py
Merge pull request #6962 from rgacogne/nmt-faster-removal
[thirdparty/pdns.git] / regression-tests.auth-py / authtests.py
1 #!/usr/bin/env python2
2
3 import errno
4 import shutil
5 import os
6 import socket
7 import struct
8 import subprocess
9 import sys
10 import time
11 import unittest
12 import dns
13 import dns.message
14
15 from pprint import pprint
16
class AuthTest(unittest.TestCase):
    """
    Setup auth required for the tests

    Base test case that generates a configuration directory, starts a
    pdns_server process on it and offers query/assertion helpers.
    The PDNS, PDNSUTIL and PREFIX environment variables are read by this
    class; PDNS and PREFIX are read while the class body is evaluated.
    """

    # Sub-directory of 'configs' that receives the generated configuration
    _confdir = 'auth'
    # Port handed to pdns_server (--local-port) and used by the query helpers
    _authPort = 5300

    # NOTE(review): not referenced by the code visible in this file
    _root_DS = "63149 13 1 a59da3f5c1b97fcd5fa2b3b2b0ac91d38a60d33a"

    # The default SOA for zones in the authoritative servers
    _SOA = "ns1.example.net. hostmaster.example.net. 1 3600 1800 1209600 300"

    # The definitions of the zones on the authoritative servers, the key is the
    # zonename and the value is the zonefile content. several strings are replaced:
    # - {soa} => value of _SOA
    # - {prefix} value of _PREFIX
    _zones = {
        'example.org': """
example.org. 3600 IN SOA {soa}
example.org. 3600 IN NS ns1.example.org.
example.org. 3600 IN NS ns2.example.org.
ns1.example.org. 3600 IN A {prefix}.10
ns2.example.org. 3600 IN A {prefix}.11
""",
    }

    # Private keys, keyed by zonename. A zone listed here has its key imported
    # with 'pdnsutil import-zone-key' when the configuration is generated.
    _zone_keys = {
        'example.org': """
Private-key-format: v1.2
Algorithm: 13 (ECDSAP256SHA256)
PrivateKey: Lt0v0Gol3pRUFM7fDdcy0IWN0O/MnEmVPA+VylL8Y4U=
""",
    }

    # Base command line of the server; startAuth() appends the options
    _auth_cmd = ['authbind',
                 os.environ['PDNS']]
    # Environment handed to the server process
    _auth_env = {}
    # Running server processes, keyed by the IP address they listen on
    _auths = {}

    # Address prefix; helpers append a final part such as ".1" to it
    _PREFIX = os.environ['PREFIX']
58
59
@classmethod
def createConfigDir(cls, confdir):
    """Create a fresh, empty configuration directory

    An existing directory tree at confdir is removed first.

    @param confdir: path of the directory to (re)create
    @raise OSError: when removal fails for a reason other than the
                    directory being absent, or when creation fails"""
    try:
        shutil.rmtree(confdir)
    except OSError as e:
        # A missing directory is fine, everything else is a real problem
        if e.errno != errno.ENOENT:
            raise
    # '0o755' is the octal spelling accepted by both Python 2.6+ and 3;
    # the bare '0755' form is a syntax error on Python 3
    os.mkdir(confdir, 0o755)
68
@classmethod
def generateAuthZone(cls, confdir, zonename, zonecontent):
    """Write the zonefile '<zonename>.zone' into confdir

    @param confdir: directory receiving the zonefile
    @param zonename: name of the zone, used for the file name
    @param zonecontent: zonefile template; {prefix} and {soa} are replaced
                        by the values of _PREFIX and _SOA"""
    zonepath = os.path.join(confdir, '%s.zone' % zonename)
    with open(zonepath, 'w') as out:
        rendered = zonecontent.format(prefix=cls._PREFIX, soa=cls._SOA)
        out.write(rendered)
73
@classmethod
def generateAuthNamedConf(cls, confdir, zones):
    """Write 'named.conf' into confdir, declaring every zone as a master

    @param confdir: directory receiving named.conf; also written as the
                    'directory' option
    @param zones: iterable of zonenames; the name 'ROOT' declares the
                  zone "." while its file remains 'ROOT.zone'"""
    confpath = os.path.join(confdir, 'named.conf')
    with open(confpath, 'w') as out:
        out.write("""
options {
directory "%s";
};""" % confdir)
        for name in zones:
            if name == 'ROOT':
                origin = '.'
            else:
                origin = name

            out.write("""
zone "%s" {
type master;
file "%s.zone";
};""" % (origin, name))
89
@classmethod
def generateAuthConfig(cls, confdir):
    """Write 'pdns.conf' into confdir and create the bind DNSSEC database

    The database is created by running the program named by the PDNSUTIL
    environment variable with 'create-bind-db'.

    @param confdir: directory receiving pdns.conf and the database
    @raise subprocess.CalledProcessError: when pdnsutil fails; its output
                                          is printed before re-raising"""
    bind_dnssec_db = os.path.join(confdir, 'bind-dnssec.sqlite3')

    with open(os.path.join(confdir, 'pdns.conf'), 'w') as pdnsconf:
        pdnsconf.write("""
module-dir=../regression-tests/modules
launch=bind geoip
daemon=no
local-ipv6=
bind-config={confdir}/named.conf
bind-dnssec-db={bind_dnssec_db}
socket-dir={confdir}
cache-ttl=0
negquery-cache-ttl=0
query-cache-ttl=0
log-dns-queries=yes
log-dns-details=yes
loglevel=9
geoip-database-files=../modules/geoipbackend/regression-tests/GeoLiteCity.mmdb
edns-subnet-processing=yes
expand-alias=yes
resolver={prefix}.1:5301
any-to-tcp=no
distributor-threads=1""".format(confdir=confdir, prefix=cls._PREFIX,
                                bind_dnssec_db=bind_dnssec_db))

    pdnsutilCmd = [os.environ['PDNSUTIL'],
                   '--config-dir=%s' % confdir,
                   'create-bind-db',
                   bind_dnssec_db]

    # print() with a single argument behaves the same on Python 2 and 3
    # and matches the print() calls used elsewhere in this file
    print(' '.join(pdnsutilCmd))
    try:
        subprocess.check_output(pdnsutilCmd, stderr=subprocess.STDOUT)
    except subprocess.CalledProcessError as e:
        print(e.output)
        raise
128
@classmethod
def secureZone(cls, confdir, zonename, key=None):
    """Enable DNSSEC for a zone through pdnsutil

    Without a key 'pdnsutil secure-zone' is run. With a key, the key is
    written to 'dnssec.key' in confdir and imported with
    'pdnsutil import-zone-key <zone> <keyfile> active ksk'.

    @param confdir: configuration directory handed to pdnsutil
    @param zonename: name of the zone; 'ROOT' stands for the zone "."
    @param key: optional private key file content
    @raise subprocess.CalledProcessError: when pdnsutil fails; its output
                                          is printed before re-raising"""
    zone = '.' if zonename == 'ROOT' else zonename
    if not key:
        pdnsutilCmd = [os.environ['PDNSUTIL'],
                       '--config-dir=%s' % confdir,
                       'secure-zone',
                       zone]
    else:
        keyfile = os.path.join(confdir, 'dnssec.key')
        with open(keyfile, 'w') as fdKeyfile:
            fdKeyfile.write(key)

        pdnsutilCmd = [os.environ['PDNSUTIL'],
                       '--config-dir=%s' % confdir,
                       'import-zone-key',
                       zone,
                       keyfile,
                       'active',
                       'ksk']

    # print() with a single argument behaves the same on Python 2 and 3
    # and matches the print() calls used elsewhere in this file
    print(' '.join(pdnsutilCmd))
    try:
        subprocess.check_output(pdnsutilCmd, stderr=subprocess.STDOUT)
    except subprocess.CalledProcessError as e:
        print(e.output)
        raise
156
@classmethod
def generateAllAuthConfig(cls, confdir):
    """Generate the complete server configuration in confdir

    Nothing is generated when _zones is empty. Otherwise pdns.conf,
    named.conf and one zonefile per entry of _zones are written, and every
    zone that has an entry in _zone_keys gets that key imported.

    @param confdir: directory receiving the configuration"""
    if not cls._zones:
        return

    cls.generateAuthConfig(confdir)
    cls.generateAuthNamedConf(confdir, cls._zones.keys())

    for name, content in cls._zones.items():
        cls.generateAuthZone(confdir, name, content)
        if not cls._zone_keys.get(name, None):
            continue
        cls.secureZone(confdir, name, cls._zone_keys.get(name))
169
@classmethod
def startAuth(cls, confdir, ipaddress):
    """Start a pdns_server process and register it in _auths

    The server output goes to 'pdns.log' in confdir. When the process is
    no longer running two seconds after its start, the log is printed and
    the interpreter exits with the return code of the server.

    @param confdir: configuration directory of the server
    @param ipaddress: address to listen on, also the key used in _auths"""

    print("Launching pdns_server..")
    authcmd = list(cls._auth_cmd)
    authcmd.append('--config-dir=%s' % confdir)
    authcmd.append('--local-address=%s' % ipaddress)
    authcmd.append('--local-port=%s' % cls._authPort)
    authcmd.append('--loglevel=9')
    authcmd.append('--enable-lua-record')
    print(' '.join(authcmd))

    logFile = os.path.join(confdir, 'pdns.log')
    with open(logFile, 'w') as fdLog:
        cls._auths[ipaddress] = subprocess.Popen(authcmd, close_fds=True,
                                                 stdout=fdLog, stderr=fdLog,
                                                 env=cls._auth_env)

    # Leave the server some time to come up, or to fail
    time.sleep(2)

    if cls._auths[ipaddress].poll() is not None:
        # The server is gone already: show why and give up
        try:
            cls._auths[ipaddress].kill()
        except OSError as e:
            if e.errno != errno.ESRCH:
                raise
        with open(logFile, 'r') as fdLog:
            # print() with a single argument behaves the same on
            # Python 2 and 3
            print(fdLog.read())
        sys.exit(cls._auths[ipaddress].returncode)
199
@classmethod
def setUpSockets(cls):
    """Create the UDP socket used by sendUDPQuery and store it in _sock

    The socket gets a timeout of two seconds and is connected to the
    address _PREFIX + ".1" on port _authPort."""
    print("Setting up UDP socket..")
    udp = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    cls._sock = udp
    udp.settimeout(2.0)
    server = (cls._PREFIX + ".1", cls._authPort)
    udp.connect(server)
206
@classmethod
def startResponders(cls):
    """Hook called by setUpClass before the configuration is generated

    Does nothing here."""
    pass
210
@classmethod
def setUpClass(cls):
    """Prepare everything the tests of the class need

    In order: the UDP socket is set up, the responders are started, the
    configuration directory 'configs/<_confdir>' is recreated and filled,
    and the server is started on the address _PREFIX + ".1"."""
    cls.setUpSockets()
    cls.startResponders()

    authdir = os.path.join('configs', cls._confdir)
    cls.createConfigDir(authdir)
    cls.generateAllAuthConfig(authdir)

    listenAddress = cls._PREFIX + ".1"
    cls.startAuth(authdir, listenAddress)

    print("Launching tests..")
224
@classmethod
def tearDownClass(cls):
    """Stop the server processes, then the responders

    NOTE(review): tearDownClass is defined a second time further down in
    this class. The later definition is the one that ends up bound to the
    name, so this one is never called."""
    cls.tearDownAuth()
    cls.tearDownResponders()
229
@classmethod
def tearDownResponders(cls):
    """Hook for stopping what startResponders started

    Does nothing here."""
    pass
233
@classmethod
def tearDownClass(cls):
    """Stop the server processes, then the responders

    This definition replaces the tearDownClass defined earlier in the
    class body, so it has to run the responder teardown itself; otherwise
    tearDownResponders is never called."""
    cls.tearDownAuth()
    cls.tearDownResponders()
237
@classmethod
def tearDownAuth(cls):
    """Stop every server process registered in _auths

    Each process is asked to terminate. A process that is still running
    gets a grace period (0.1 seconds when PDNSRECURSOR_FAST_TESTS is set
    in the environment, one second otherwise), is killed when it has not
    exited by then, and is waited for.

    @raise OSError: for any error other than the process being gone"""
    delay = 0.1 if 'PDNSRECURSOR_FAST_TESTS' in os.environ else 1.0

    for auth in cls._auths.values():
        try:
            auth.terminate()
            if auth.poll() is not None:
                # Exited right away, nothing left to do for this one
                continue
            time.sleep(delay)
            if auth.poll() is None:
                auth.kill()
            auth.wait()
        except OSError as e:
            # ESRCH: the process does not exist (anymore)
            if e.errno != errno.ESRCH:
                raise
256
@classmethod
def sendUDPQuery(cls, query, timeout=2.0, decode=True, fwparams=None):
    """Send a query over the UDP socket _sock and return the response

    @param query: the message to send; its to_wire() output is sent
    @param timeout: socket timeout in seconds for this exchange. When
                    set, the socket is put back to blocking mode
                    (no timeout) afterwards
    @param decode: when False, the raw response bytes are returned
    @param fwparams: optional dictionary of extra keyword arguments for
                     dns.message.from_wire
    @return: the parsed response, the raw bytes when decode is False, or
             None when no data was received before the timeout"""
    # A fresh dictionary per call instead of one shared default object
    if fwparams is None:
        fwparams = {}

    if timeout:
        cls._sock.settimeout(timeout)

    try:
        cls._sock.send(query.to_wire())
        data = cls._sock.recv(4096)
    except socket.timeout:
        data = None
    finally:
        if timeout:
            cls._sock.settimeout(None)

    if not data:
        return None
    if not decode:
        return data
    return dns.message.from_wire(data, **fwparams)
277
@classmethod
def sendTCPQuery(cls, query, timeout=2.0):
    """Send a query over a new TCP connection and return the response

    NOTE(review): sendTCPQuery is defined a second time further down in
    this class. The later definition is the one that ends up bound to the
    name, so this one is never called.

    @param query: the message to send; its to_wire() output is sent
                  behind a two byte length prefix
    @param timeout: socket timeout in seconds, no timeout when false
    @return: the parsed response, or None after a timeout, a network
             error or when no data was received"""
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    if timeout:
        sock.settimeout(timeout)

    # This class defines _authPort; it has no _recursorPort attribute
    sock.connect(("127.0.0.1", cls._authPort))

    try:
        wire = query.to_wire()
        sock.send(struct.pack("!H", len(wire)))
        sock.send(wire)
        data = sock.recv(2)
        if data:
            (datalen,) = struct.unpack("!H", data)
            data = sock.recv(datalen)
    except socket.timeout as e:
        print("Timeout: %s" % (str(e)))
        data = None
    except socket.error as e:
        print("Network error: %s" % (str(e)))
        data = None
    finally:
        sock.close()

    message = None
    if data:
        message = dns.message.from_wire(data)
    return message
307
308
@classmethod
def sendTCPQuery(cls, query, timeout=2.0):
    """Send a query over a new TCP connection and return the response

    The connection goes to 127.0.0.1 on port _authPort.

    @param query: the message to send; its to_wire() output is sent
                  behind a two byte length prefix
    @param timeout: socket timeout in seconds, no timeout when false
    @return: the parsed response, or None after a timeout, a network
             error during the exchange, or when the peer closed the
             connection before a complete length prefix arrived
    @raise socket.error: when the connection cannot be established"""

    def recvExactly(conn, wanted):
        # TCP is a byte stream: a single recv() may hand back fewer bytes
        # than requested, so keep reading until done or until EOF
        chunks = []
        remaining = wanted
        while remaining > 0:
            chunk = conn.recv(remaining)
            if not chunk:
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        if timeout:
            sock.settimeout(timeout)

        # Connection failures still reach the caller, but the socket is
        # now closed on that path too
        sock.connect(("127.0.0.1", cls._authPort))

        try:
            wire = query.to_wire()
            # sendall() retries until everything is written
            sock.sendall(struct.pack("!H", len(wire)) + wire)
            data = recvExactly(sock, 2)
            if len(data) == 2:
                (datalen,) = struct.unpack("!H", data)
                data = recvExactly(sock, datalen)
            else:
                data = None
        except socket.timeout as e:
            print("Timeout: %s" % (str(e)))
            data = None
        except socket.error as e:
            print("Network error: %s" % (str(e)))
            data = None
    finally:
        sock.close()

    message = None
    if data:
        message = dns.message.from_wire(data)
    return message
338
def setUp(self):
    """Per-test preparation; does nothing here"""
    # This function is called before every test
    return
342
343 ## Functions for comparisons
def assertMessageHasFlags(self, msg, flags, ednsflags=None):
    """Asserts that msg has all the flags from flags set

    The assertion also fails when msg carries more header flags than the
    number of entries in flags. Extra EDNS flags are not checked.

    @param msg: the dns.message.Message to check
    @param flags: a list of strings with flag mnemonics (like ['RD', 'RA'])
    @param ednsflags: a list of strings with edns-flag mnemonics (like ['DO']),
                      no EDNS flags are required when left out
    @raise TypeError: when an argument is of the wrong type
    @raise AssertionError: when the flags do not match"""

    # A fresh list per call instead of one shared default object
    if ednsflags is None:
        ednsflags = []

    if not isinstance(msg, dns.message.Message):
        raise TypeError("msg is not a dns.message.Message")

    for label, value in (("flags", flags), ("ednsflags", ednsflags)):
        if not isinstance(value, list):
            raise TypeError("%s is not a list of strings" % label)
        for elem in value:
            if not isinstance(elem, str):
                raise TypeError("%s is not a list of strings" % label)

    msgFlags = dns.flags.to_text(msg.flags).split()
    missingFlags = [flag for flag in flags if flag not in msgFlags]

    msgEdnsFlags = dns.flags.edns_to_text(msg.ednsflags).split()
    missingEdnsFlags = [ednsflag for ednsflag in ednsflags if ednsflag not in msgEdnsFlags]

    if missingFlags or missingEdnsFlags or len(msgFlags) > len(flags):
        raise AssertionError("Expected flags '%s' (EDNS: '%s'), found '%s' (EDNS: '%s') in query %s" %
                             (' '.join(flags), ' '.join(ednsflags),
                              ' '.join(msgFlags), ' '.join(msgEdnsFlags),
                              msg.question[0]))
379
def assertMessageIsAuthenticated(self, msg):
    """Asserts that 'AD' shows up in the textual form of the flags of msg

    @param msg: the dns.message.Message to check
    @raise TypeError: when msg is not a dns.message.Message"""

    if not isinstance(msg, dns.message.Message):
        raise TypeError("msg is not a dns.message.Message")

    flagText = dns.flags.to_text(msg.flags)
    hasAD = 'AD' in flagText
    failureText = "No AD flag found in the message for %s" % msg.question[0].name
    self.assertTrue(hasAD, failureText)
390
def assertRRsetInAnswer(self, msg, rrset):
    """Asserts that the answer section of msg holds an RRset with the name,
    class and type of rrset, and that every such RRset is equal to rrset

    @param msg: the dns.message.Message to check
    @param rrset: a dns.rrset.RRset object
    @raise TypeError: when an argument is of the wrong type
    @raise AssertionError: when no RRset of that name, class and type is
                           in the answer section, or one differs"""

    if not isinstance(msg, dns.message.Message):
        raise TypeError("msg is not a dns.message.Message")

    if not isinstance(rrset, dns.rrset.RRset):
        raise TypeError("rrset is not a dns.rrset.RRset")

    # Text of the answers seen, kept for the failure message
    seen = []
    matched = False
    for candidate in msg.answer:
        seen.append("%s\n" % candidate.to_text())
        if not candidate.match(rrset.name, rrset.rdclass, rrset.rdtype, 0, None):
            continue
        self.assertEqual(candidate, rrset, "'%s' != '%s'" % (candidate.to_text(), rrset.to_text()))
        matched = True

    if matched:
        return
    raise AssertionError("RRset not found in answer\n\n%s" % ''.join(seen))
414
def assertAnyRRsetInAnswer(self, msg, rrsets):
    """Asserts that at least one of the supplied rrsets is equal to an
    RRset in the answer section of msg

    @param msg: the dns.message.Message to check
    @param rrsets: an array of dns.rrset.RRset object
    @raise TypeError: when msg or any entry of rrsets is of the wrong type
    @raise AssertionError: when none of the rrsets is in the answer"""

    if not isinstance(msg, dns.message.Message):
        raise TypeError("msg is not a dns.message.Message")

    # Every entry is type-checked, also the ones behind a match
    for wanted in rrsets:
        if not isinstance(wanted, dns.rrset.RRset):
            raise TypeError("rrset is not a dns.rrset.RRset")

    found = any(ans.match(wanted.name, wanted.rdclass, wanted.rdtype, 0, None) and ans == wanted
                for wanted in rrsets
                for ans in msg.answer)

    if found:
        return
    raise AssertionError("RRset not found in answer\n%s" %
                         "\n".join(([ans.to_text() for ans in msg.answer])))
437
def assertMatchingRRSIGInAnswer(self, msg, coveredRRset, keys=None):
    """Looks for coveredRRset in the answer section and if there is an RRSIG RRset
    that covers that RRset. If keys is not None, this function will also try to
    validate the RRset against the RRSIG

    @param msg: The dns.message.Message to check
    @param coveredRRset: The RRSet to check for
    @param keys: a dictionary keyed by dns.name.Name with node or rdataset values to use for validation
    @raise TypeError: when msg or coveredRRset is of the wrong type
    @raise AssertionError: when the RRset or its RRSIG is missing from the
                           answer section, or when validation fails"""

    if not isinstance(msg, dns.message.Message):
        raise TypeError("msg is not a dns.message.Message")

    if not isinstance(coveredRRset, dns.rrset.RRset):
        raise TypeError("coveredRRset is not a dns.rrset.RRset")

    msgRRsigRRSet = None
    msgRRSet = None

    ret = ''
    for ans in msg.answer:
        ret += ans.to_text() + "\n"

        if ans.match(coveredRRset.name, coveredRRset.rdclass, coveredRRset.rdtype, 0, None):
            msgRRSet = ans
        if ans.match(coveredRRset.name, dns.rdataclass.IN, dns.rdatatype.RRSIG, coveredRRset.rdtype, None):
            msgRRsigRRSet = ans
        # Compare against None: 'found' must not depend on the RRset
        # being non-empty
        if msgRRSet is not None and msgRRsigRRSet is not None:
            break

    if msgRRSet is None:
        raise AssertionError("RRset for '%s' not found in answer" % msg.question[0].to_text())

    if msgRRsigRRSet is None:
        raise AssertionError("No RRSIGs found in answer for %s:\nFull answer:\n%s" % (msg.question[0].to_text(), ret))

    if keys:
        # The file only imports 'dns' and 'dns.message'; load the dnssec
        # submodule here so dns.dnssec is certain to be available
        import dns.dnssec
        try:
            dns.dnssec.validate(msgRRSet, msgRRsigRRSet.to_rdataset(), keys)
        except dns.dnssec.ValidationFailure as e:
            raise AssertionError("Signature validation failed for %s:\n%s" % (msg.question[0].to_text(), e))
478
def assertNoRRSIGsInAnswer(self, msg):
    """Asserts that the answer section of msg holds no RRset of type RRSIG

    @param msg: the dns.message.Message to check
    @raise TypeError: when msg is not a dns.message.Message
    @raise AssertionError: listing the owner names of the RRSIGs found"""

    if not isinstance(msg, dns.message.Message):
        raise TypeError("msg is not a dns.message.Message")

    signedNames = [ans.name.to_text() + "\n"
                   for ans in msg.answer
                   if ans.rdtype == dns.rdatatype.RRSIG]

    if signedNames:
        raise AssertionError("RRSIG found in answers for:\n%s" % ''.join(signedNames))
492
def assertAnswerEmpty(self, msg):
    """Asserts that the answer section of msg is empty

    @param msg: the message to check; its answer list and, on failure,
                its first question are read"""
    if len(msg.answer) == 0:
        return
    # The failure text is only built when needed, so a passing check no
    # longer touches msg.question
    self.fail("Data found in the answer section for %s:\n%s" % (
        msg.question[0].to_text(),
        '\n'.join([i.to_text() for i in msg.answer])))
495
def assertAnswerNotEmpty(self, msg):
    """Asserts that the answer section of msg has at least one entry

    @param msg: the message to check; only its answer list is read"""
    hasAnswers = len(msg.answer) > 0
    self.assertTrue(hasAnswers, "Answer is empty")
498
def assertRcodeEqual(self, msg, rcode):
    """Asserts that the rcode of msg is rcode

    @param msg: the dns.message.Message to check
    @param rcode: the expected rcode, as an int or as text (a str that
                  dns.rcode.from_text understands)
    @raise TypeError: when an argument is of the wrong type
    @raise AssertionError: when the rcodes differ"""
    if not isinstance(msg, dns.message.Message):
        raise TypeError("msg is not a dns.message.Message but a %s" % type(msg))

    if not isinstance(rcode, int):
        if isinstance(rcode, str):
            rcode = dns.rcode.from_text(rcode)
        else:
            raise TypeError("rcode is neither a str nor int")

    if msg.rcode() != rcode:
        # dns.rcode.to_text is the public way to get at the mnemonic,
        # instead of reaching into the private dns.rcode._by_value table
        msgRcode = dns.rcode.to_text(msg.rcode())
        wantedRcode = dns.rcode.to_text(rcode)

        raise AssertionError("Rcode for %s is %s, expected %s." % (msg.question[0].to_text(), msgRcode, wantedRcode))
514
def assertAuthorityHasSOA(self, msg):
    """Asserts that the authority section of msg holds an RRset of type SOA

    @param msg: the dns.message.Message to check
    @raise TypeError: when msg is not a dns.message.Message
    @raise AssertionError: when no SOA is in the authority section"""
    if not isinstance(msg, dns.message.Message):
        raise TypeError("msg is not a dns.message.Message but a %s" % type(msg))

    if any(rrset.rdtype == dns.rdatatype.SOA for rrset in msg.authority):
        return

    raise AssertionError("No SOA record found in the authority section:\n%s" % msg.to_text())