]> git.ipfire.org Git - thirdparty/pdns.git/blob - regression-tests.auth-py/authtests.py
Merge pull request #6525 from rgacogne/calidns-max-qps
[thirdparty/pdns.git] / regression-tests.auth-py / authtests.py
1 #!/usr/bin/env python2
2
3 import errno
4 import shutil
5 import os
6 import socket
7 import struct
8 import subprocess
9 import sys
10 import time
11 import unittest
12 import dns
13 import dns.message
14
15 from pprint import pprint
16
class AuthTest(unittest.TestCase):
    """
    Base class for the authoritative-server regression tests.

    setUpClass() writes a configuration (pdns.conf, named.conf, zone files and
    optional DNSSEC keys) into configs/<_confdir>, starts pdns_server listening
    on <_PREFIX>.1:<_authPort> and opens a UDP socket to it. tearDownClass()
    stops the server again. Subclasses customise the setup by overriding the
    class attributes below and the startResponders/tearDownResponders hooks.
    The class also provides DNS-specific assertion helpers.
    """

    # Name of the configuration directory, created below 'configs/'
    _confdir = 'auth'
    # Port pdns_server is told to listen on (used for both UDP and TCP queries)
    _authPort = 5300

    # Not referenced by this base class; presumably used by subclasses
    _root_DS = "63149 13 1 a59da3f5c1b97fcd5fa2b3b2b0ac91d38a60d33a"

    # The default SOA for zones in the authoritative servers
    _SOA = "ns1.example.net. hostmaster.example.net. 1 3600 1800 1209600 300"

    # The definitions of the zones on the authoritative servers. The key is the
    # zone name ('ROOT' stands for the root zone '.') and the value is the zone
    # file content, in which these placeholders are replaced:
    #   - {soa}    => value of _SOA
    #   - {prefix} => value of _PREFIX
    _zones = {
        'example.org': """
example.org.                 3600 IN SOA  {soa}
example.org.                 3600 IN NS   ns1.example.org.
example.org.                 3600 IN NS   ns2.example.org.
ns1.example.org.             3600 IN A    {prefix}.10
ns2.example.org.             3600 IN A    {prefix}.11
""",
    }

    # Private-key file contents, keyed by zone name. A zone listed here is
    # signed with this key (imported as an active KSK) by secureZone().
    _zone_keys = {
        'example.org': """
Private-key-format: v1.2
Algorithm: 13 (ECDSAP256SHA256)
PrivateKey: Lt0v0Gol3pRUFM7fDdcy0IWN0O/MnEmVPA+VylL8Y4U=
""",
    }

    # Command used to launch the server; $PDNS must point to the pdns_server binary
    _auth_cmd = ['authbind',
                 os.environ['PDNS']]
    # Environment passed to pdns_server (an empty dict means an empty environment)
    _auth_env = {}
    # Running server processes keyed by listen address. This class-level dict is
    # mutated, so it is shared by every subclass.
    _auths = {}

    # First three octets of the IPv4 addresses used by the tests
    _PREFIX = os.environ['PREFIX']

    @classmethod
    def createConfigDir(cls, confdir):
        """(Re)create an empty configuration directory.

        An existing directory at confdir is removed first; missing parent
        directories are created as needed."""
        try:
            shutil.rmtree(confdir)
        except OSError as e:
            # A directory that does not exist yet is fine; anything else is not.
            if e.errno != errno.ENOENT:
                raise
        os.makedirs(confdir, 0o755)

    @classmethod
    def generateAuthZone(cls, confdir, zonename, zonecontent):
        """Write <confdir>/<zonename>.zone with {prefix} and {soa} substituted."""
        zonePath = os.path.join(confdir, '%s.zone' % zonename)
        with open(zonePath, 'w') as zonefile:
            zonefile.write(zonecontent.format(prefix=cls._PREFIX, soa=cls._SOA))

    @classmethod
    def generateAuthNamedConf(cls, confdir, zones):
        """Write <confdir>/named.conf declaring every zone as a master zone.

        @param zones: iterable of zone names; 'ROOT' is written as zone '.'
                      but still reads its data from ROOT.zone"""
        with open(os.path.join(confdir, 'named.conf'), 'w') as namedconf:
            namedconf.write("""
options {
    directory "%s";
};""" % confdir)
            for zonename in zones:
                zone = '.' if zonename == 'ROOT' else zonename

                namedconf.write("""
zone "%s" {
    type master;
    file "%s.zone";
};""" % (zone, zonename))

    @classmethod
    def _runPdnsutil(cls, confdir, *args):
        """Run `$PDNSUTIL --config-dir=<confdir> <args...>`, echoing the command.

        If pdnsutil fails, its combined stdout/stderr is printed before the
        subprocess.CalledProcessError is re-raised."""
        pdnsutilCmd = [os.environ['PDNSUTIL'], '--config-dir=%s' % confdir]
        pdnsutilCmd.extend(args)

        print(' '.join(pdnsutilCmd))
        try:
            subprocess.check_output(pdnsutilCmd, stderr=subprocess.STDOUT)
        except subprocess.CalledProcessError as e:
            print(e.output)
            raise

    @classmethod
    def generateAuthConfig(cls, confdir):
        """Write <confdir>/pdns.conf and create the bind backend's DNSSEC database."""
        bind_dnssec_db = os.path.join(confdir, 'bind-dnssec.sqlite3')

        with open(os.path.join(confdir, 'pdns.conf'), 'w') as pdnsconf:
            pdnsconf.write("""
module-dir=../regression-tests/modules
launch=bind geoip
daemon=no
local-ipv6=
bind-config={confdir}/named.conf
bind-dnssec-db={bind_dnssec_db}
socket-dir={confdir}
cache-ttl=0
negquery-cache-ttl=0
query-cache-ttl=0
log-dns-queries=yes
log-dns-details=yes
loglevel=9
geoip-database-files=../modules/geoipbackend/regression-tests/GeoLiteCity.mmdb
edns-subnet-processing=yes
distributor-threads=1""".format(confdir=confdir,
                                bind_dnssec_db=bind_dnssec_db))

        cls._runPdnsutil(confdir, 'create-bind-db', bind_dnssec_db)

    @classmethod
    def secureZone(cls, confdir, zonename, key=None):
        """Enable DNSSEC for a zone ('ROOT' means the root zone '.').

        Without a key this runs `pdnsutil secure-zone`. With a key (the text
        of a private-key file) the key is written to <confdir>/dnssec.key and
        imported with `pdnsutil import-zone-key ... active ksk`."""
        zone = '.' if zonename == 'ROOT' else zonename
        if not key:
            cls._runPdnsutil(confdir, 'secure-zone', zone)
            return

        keyfile = os.path.join(confdir, 'dnssec.key')
        with open(keyfile, 'w') as fdKeyfile:
            fdKeyfile.write(key)

        cls._runPdnsutil(confdir, 'import-zone-key', zone, keyfile, 'active', 'ksk')

    @classmethod
    def generateAllAuthConfig(cls, confdir):
        """Generate pdns.conf, named.conf, zone files and keys for cls._zones.

        Nothing is generated when cls._zones is empty."""
        if not cls._zones:
            return

        cls.generateAuthConfig(confdir)
        cls.generateAuthNamedConf(confdir, cls._zones.keys())

        for zonename, zonecontent in cls._zones.items():
            cls.generateAuthZone(confdir, zonename, zonecontent)
            key = cls._zone_keys.get(zonename)
            if key:
                cls.secureZone(confdir, zonename, key)

    @classmethod
    def startAuth(cls, confdir, ipaddress):
        """Launch pdns_server on ipaddress:_authPort, logging to <confdir>/pdns.log.

        If the server is no longer running two seconds after launch, its log is
        printed and the whole test process exits with a non-zero status."""
        print("Launching pdns_server..")
        authcmd = list(cls._auth_cmd)
        authcmd.extend([
            '--config-dir=%s' % confdir,
            '--local-address=%s' % ipaddress,
            '--local-port=%s' % cls._authPort,
            '--loglevel=9',
            '--enable-lua-record',
        ])
        print(' '.join(authcmd))

        logFile = os.path.join(confdir, 'pdns.log')
        with open(logFile, 'w') as fdLog:
            cls._auths[ipaddress] = subprocess.Popen(authcmd, close_fds=True,
                                                     stdout=fdLog, stderr=fdLog,
                                                     env=cls._auth_env)

        time.sleep(2)

        auth = cls._auths[ipaddress]
        if auth.poll() is not None:
            try:
                auth.kill()
            except OSError as e:
                if e.errno != errno.ESRCH:
                    raise
            with open(logFile, 'r') as fdLog:
                print(fdLog.read())
            # A server that exited during startup is always a failure, even if
            # it exited with status 0, which would otherwise report success.
            sys.exit(auth.returncode or 1)

    @classmethod
    def setUpSockets(cls):
        """Create the shared UDP socket, connected to the server at <_PREFIX>.1."""
        print("Setting up UDP socket..")
        cls._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        cls._sock.settimeout(2.0)
        cls._sock.connect((cls._PREFIX + ".1", cls._authPort))

    @classmethod
    def startResponders(cls):
        """Hook for subclasses that need extra responders; does nothing here."""
        pass

    @classmethod
    def setUpClass(cls):
        """Open the UDP socket, start responders, generate config, start the server."""
        cls.setUpSockets()

        cls.startResponders()

        confdir = os.path.join('configs', cls._confdir)
        cls.createConfigDir(confdir)

        cls.generateAllAuthConfig(confdir)
        cls.startAuth(confdir, cls._PREFIX + ".1")

        print("Launching tests..")

    @classmethod
    def tearDownClass(cls):
        """Stop the server(s), then the responders, and close the UDP socket."""
        cls.tearDownAuth()
        cls.tearDownResponders()

        sock = getattr(cls, '_sock', None)
        if sock is not None:
            sock.close()

    @classmethod
    def tearDownResponders(cls):
        """Hook for subclasses to stop what startResponders started; does nothing here."""
        pass

    @classmethod
    def tearDownAuth(cls):
        """Terminate every server in cls._auths, escalating to kill if needed."""
        if 'PDNSRECURSOR_FAST_TESTS' in os.environ:
            delay = 0.1
        else:
            delay = 1.0

        for auth in cls._auths.values():
            try:
                auth.terminate()
                if auth.poll() is None:
                    time.sleep(delay)
                if auth.poll() is None:
                    auth.kill()
                auth.wait()
            except OSError as e:
                # The process may already be gone.
                if e.errno != errno.ESRCH:
                    raise

    @classmethod
    def sendUDPQuery(cls, query, timeout=2.0, decode=True, fwparams=None):
        """Send query over the shared UDP socket and return the response.

        @param timeout: per-call socket timeout; a falsy value keeps the
                        socket's current timeout
        @param decode: if False, return the raw response bytes
        @param fwparams: extra keyword arguments for dns.message.from_wire
        @return: the response (decoded or raw), or None on timeout/no data"""
        if fwparams is None:
            fwparams = {}

        previousTimeout = cls._sock.gettimeout()
        if timeout:
            cls._sock.settimeout(timeout)

        try:
            cls._sock.send(query.to_wire())
            data = cls._sock.recv(4096)
        except socket.timeout:
            data = None
        finally:
            # Restore the previous timeout rather than leaving the shared
            # socket in blocking mode for later callers.
            if timeout:
                cls._sock.settimeout(previousTimeout)

        if not data:
            return None
        if not decode:
            return data
        return dns.message.from_wire(data, **fwparams)

    @staticmethod
    def _recvExactly(sock, count):
        """Read up to count bytes from sock, stopping early only at EOF."""
        chunks = []
        remaining = count
        while remaining > 0:
            chunk = sock.recv(remaining)
            if not chunk:
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    @classmethod
    def sendTCPQuery(cls, query, timeout=2.0):
        """Send query over a fresh TCP connection to <_PREFIX>.1:_authPort.

        Connection errors propagate to the caller. Timeouts and network errors
        during the exchange are printed and reported as a None result.

        @return: the decoded dns.message.Message, or None if nothing was read"""
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            if timeout:
                sock.settimeout(timeout)

            # The server only listens on <_PREFIX>.1 (see setUpClass), like the UDP socket.
            sock.connect((cls._PREFIX + ".1", cls._authPort))

            try:
                wire = query.to_wire()
                # DNS over TCP: two-byte big-endian length prefix, then the message.
                sock.sendall(struct.pack("!H", len(wire)) + wire)
                lengthPrefix = cls._recvExactly(sock, 2)
                if len(lengthPrefix) == 2:
                    (datalen,) = struct.unpack("!H", lengthPrefix)
                    data = cls._recvExactly(sock, datalen)
                else:
                    data = None
            except socket.timeout as e:
                print("Timeout: %s" % (str(e)))
                data = None
            except socket.error as e:
                print("Network error: %s" % (str(e)))
                data = None
        finally:
            sock.close()

        if not data:
            return None
        return dns.message.from_wire(data)

    def setUp(self):
        # This function is called before every test
        return

    ## Functions for comparisons
    def assertMessageHasFlags(self, msg, flags, ednsflags=None):
        """Asserts that msg has exactly the flags from flags set, and at least the
        EDNS flags from ednsflags

        @param msg: the dns.message.Message to check
        @param flags: a list of strings with flag mnemonics (like ['RD', 'RA'])
        @param ednsflags: a list of strings with edns-flag mnemonics (like ['DO'])"""
        if ednsflags is None:
            ednsflags = []

        if not isinstance(msg, dns.message.Message):
            raise TypeError("msg is not a dns.message.Message")

        if not isinstance(flags, list) or not all(isinstance(elem, str) for elem in flags):
            raise TypeError("flags is not a list of strings")

        if not isinstance(ednsflags, list) or not all(isinstance(elem, str) for elem in ednsflags):
            raise TypeError("ednsflags is not a list of strings")

        msgFlags = dns.flags.to_text(msg.flags).split()
        missingFlags = [flag for flag in flags if flag not in msgFlags]
        # Header flags must match exactly: any flag set but not expected fails.
        extraFlags = [flag for flag in msgFlags if flag not in flags]

        msgEdnsFlags = dns.flags.edns_to_text(msg.ednsflags).split()
        missingEdnsFlags = [ednsflag for ednsflag in ednsflags if ednsflag not in msgEdnsFlags]

        if missingFlags or missingEdnsFlags or extraFlags:
            raise AssertionError("Expected flags '%s' (EDNS: '%s'), found '%s' (EDNS: '%s') in query %s" %
                                 (' '.join(flags), ' '.join(ednsflags),
                                  ' '.join(msgFlags), ' '.join(msgEdnsFlags),
                                  msg.question[0]))

    def assertMessageIsAuthenticated(self, msg):
        """Asserts that the message has the AD bit set

        @param msg: the dns.message.Message to check"""

        if not isinstance(msg, dns.message.Message):
            raise TypeError("msg is not a dns.message.Message")

        msgFlags = dns.flags.to_text(msg.flags).split()
        self.assertTrue('AD' in msgFlags, "No AD flag found in the message for %s" % msg.question[0].name)

    def assertRRsetInAnswer(self, msg, rrset):
        """Asserts the rrset (without comparing TTL) exists in the
        answer section of msg

        @param msg: the dns.message.Message to check
        @param rrset: a dns.rrset.RRset object"""

        ret = ''
        if not isinstance(msg, dns.message.Message):
            raise TypeError("msg is not a dns.message.Message")

        if not isinstance(rrset, dns.rrset.RRset):
            raise TypeError("rrset is not a dns.rrset.RRset")

        found = False
        for ans in msg.answer:
            ret += "%s\n" % ans.to_text()
            if ans.match(rrset.name, rrset.rdclass, rrset.rdtype, 0, None):
                self.assertEqual(ans, rrset, "'%s' != '%s'" % (ans.to_text(), rrset.to_text()))
                found = True

        if not found:
            raise AssertionError("RRset not found in answer\n\n%s" % ret)

    def assertAnyRRsetInAnswer(self, msg, rrsets):
        """Asserts that any of the supplied rrsets exists (without comparing TTL)
        in the answer section of msg

        @param msg: the dns.message.Message to check
        @param rrsets: an array of dns.rrset.RRset object"""

        if not isinstance(msg, dns.message.Message):
            raise TypeError("msg is not a dns.message.Message")

        for rrset in rrsets:
            if not isinstance(rrset, dns.rrset.RRset):
                raise TypeError("rrset is not a dns.rrset.RRset")

        found = any(ans == rrset
                    for rrset in rrsets
                    for ans in msg.answer
                    if ans.match(rrset.name, rrset.rdclass, rrset.rdtype, 0, None))

        if not found:
            raise AssertionError("RRset not found in answer\n%s" %
                                 "\n".join([ans.to_text() for ans in msg.answer]))

    def assertMatchingRRSIGInAnswer(self, msg, coveredRRset, keys=None):
        """Looks for coveredRRset in the answer section and checks whether there is an RRSIG RRset
        that covers that RRset. If keys is not None, this function will also try to
        validate the RRset against the RRSIG

        @param msg: The dns.message.Message to check
        @param coveredRRset: The RRSet to check for
        @param keys: a dictionary keyed by dns.name.Name with node or rdataset values to use for validation"""

        if not isinstance(msg, dns.message.Message):
            raise TypeError("msg is not a dns.message.Message")

        if not isinstance(coveredRRset, dns.rrset.RRset):
            raise TypeError("coveredRRset is not a dns.rrset.RRset")

        msgRRsigRRSet = None
        msgRRSet = None

        ret = ''
        for ans in msg.answer:
            ret += ans.to_text() + "\n"

            if ans.match(coveredRRset.name, coveredRRset.rdclass, coveredRRset.rdtype, 0, None):
                msgRRSet = ans
            if ans.match(coveredRRset.name, dns.rdataclass.IN, dns.rdatatype.RRSIG, coveredRRset.rdtype, None):
                msgRRsigRRSet = ans
            if msgRRSet and msgRRsigRRSet:
                break

        if not msgRRSet:
            raise AssertionError("RRset for '%s' not found in answer" % msg.question[0].to_text())

        if not msgRRsigRRSet:
            raise AssertionError("No RRSIGs found in answer for %s:\nFull answer:\n%s" % (msg.question[0].to_text(), ret))

        if keys:
            # dns.dnssec is not loaded by dns.message. A plain `import dns.dnssec`
            # here would make `dns` a local name for this whole function, so
            # bind only the submodule.
            from dns import dnssec
            try:
                dnssec.validate(msgRRSet, msgRRsigRRSet.to_rdataset(), keys)
            except dnssec.ValidationFailure as e:
                raise AssertionError("Signature validation failed for %s:\n%s" % (msg.question[0].to_text(), e))

    def assertNoRRSIGsInAnswer(self, msg):
        """Checks that there are _no_ RRSIGs in the answer section of msg"""

        if not isinstance(msg, dns.message.Message):
            raise TypeError("msg is not a dns.message.Message")

        ret = ""
        for ans in msg.answer:
            if ans.rdtype == dns.rdatatype.RRSIG:
                ret += ans.name.to_text() + "\n"

        if ret:
            raise AssertionError("RRSIG found in answers for:\n%s" % ret)

    def assertAnswerEmpty(self, msg):
        """Asserts that the answer section of msg is empty."""
        self.assertTrue(len(msg.answer) == 0, "Data found in the answer section for %s:\n%s" % (msg.question[0].to_text(), '\n'.join([i.to_text() for i in msg.answer])))

    def assertAnswerNotEmpty(self, msg):
        """Asserts that the answer section of msg holds at least one RRset."""
        self.assertTrue(len(msg.answer) > 0, "Answer is empty")

    def assertRcodeEqual(self, msg, rcode):
        """Asserts that msg has the given rcode (an int, or a mnemonic such as 'NXDOMAIN')."""
        if not isinstance(msg, dns.message.Message):
            raise TypeError("msg is not a dns.message.Message but a %s" % type(msg))

        if not isinstance(rcode, int):
            if isinstance(rcode, str):
                rcode = dns.rcode.from_text(rcode)
            else:
                raise TypeError("rcode is neither a str nor int")

        if msg.rcode() != rcode:
            # Public API; unlike the private _by_value table it also copes with
            # rcodes that have no mnemonic.
            msgRcode = dns.rcode.to_text(msg.rcode())
            wantedRcode = dns.rcode.to_text(rcode)

            raise AssertionError("Rcode for %s is %s, expected %s." % (msg.question[0].to_text(), msgRcode, wantedRcode))

    def assertAuthorityHasSOA(self, msg):
        """Asserts that the authority section of msg contains an SOA RRset."""
        if not isinstance(msg, dns.message.Message):
            raise TypeError("msg is not a dns.message.Message but a %s" % type(msg))

        found = any(rrset.rdtype == dns.rdatatype.SOA for rrset in msg.authority)

        if not found:
            raise AssertionError("No SOA record found in the authority section:\n%s" % msg.to_text())
521
522 if not found:
523 raise AssertionError("No SOA record found in the authority section:\n%s" % msg.to_text())