3 from __future__
import print_function
16 from pprint
import pprint
17 from eqdnsmessage
import AssertEqualDNSMessageMixin
19 class AuthTest(AssertEqualDNSMessageMixin
, unittest
.TestCase
):
21 Setup auth required for the tests
29 _config_template_default
= """
30 module-dir=../regression-tests/modules
32 bind-config={confdir}/named.conf
33 bind-dnssec-db={bind_dnssec_db}
41 distributor-threads=1"""
45 _root_DS
= "63149 13 1 a59da3f5c1b97fcd5fa2b3b2b0ac91d38a60d33a"
47 # The default SOA for zones in the authoritative servers
48 _SOA
= "ns1.example.net. hostmaster.example.net. 1 3600 1800 1209600 300"
50 # The definitions of the zones on the authoritative servers: the key is the
51 # zonename and the value is the zonefile content. Several placeholders are replaced:
52 # - {soa} => value of _SOA
53 # - {prefix} => value of _PREFIX
56 example.org. 3600 IN SOA {soa}
57 example.org. 3600 IN NS ns1.example.org.
58 example.org. 3600 IN NS ns2.example.org.
59 ns1.example.org. 3600 IN A {prefix}.10
60 ns2.example.org. 3600 IN A {prefix}.11
66 Private-key-format: v1.2
67 Algorithm: 13 (ECDSAP256SHA256)
68 PrivateKey: Lt0v0Gol3pRUFM7fDdcy0IWN0O/MnEmVPA+VylL8Y4U=
72 _auth_cmd
= ['authbind',
77 _PREFIX
= os
.environ
['PREFIX']
81 def createConfigDir(cls
, confdir
):
83 shutil
.rmtree(confdir
)
85 if e
.errno
!= errno
.ENOENT
:
87 os
.mkdir(confdir
, 0o755)
90 def generateAuthZone(cls
, confdir
, zonename
, zonecontent
):
91 with
open(os
.path
.join(confdir
, '%s.zone' % zonename
), 'w') as zonefile
:
92 zonefile
.write(zonecontent
.format(prefix
=cls
._PREFIX
, soa
=cls
._SOA
))
95 def generateAuthNamedConf(cls
, confdir
, zones
):
96 with
open(os
.path
.join(confdir
, 'named.conf'), 'w') as namedconf
:
101 for zonename
in zones
:
102 zone
= '.' if zonename
== 'ROOT' else zonename
108 };""" % (zone
, zonename
))
111 def generateAuthConfig(cls
, confdir
):
112 bind_dnssec_db
= os
.path
.join(confdir
, 'bind-dnssec.sqlite3')
114 params
= tuple([getattr(cls
, param
) for param
in cls
._config
_params
])
116 with
open(os
.path
.join(confdir
, 'pdns.conf'), 'w') as pdnsconf
:
117 pdnsconf
.write(cls
._config
_template
_default
.format(
118 confdir
=confdir
, prefix
=cls
._PREFIX
,
119 bind_dnssec_db
=bind_dnssec_db
))
120 pdnsconf
.write(cls
._config
_template
% params
)
122 os
.system("sqlite3 ./configs/auth/powerdns.sqlite < ../modules/gsqlite3backend/schema.sqlite3.sql")
124 pdnsutilCmd
= [os
.environ
['PDNSUTIL'],
125 '--config-dir=%s' % confdir
,
129 print(' '.join(pdnsutilCmd
))
131 subprocess
.check_output(pdnsutilCmd
, stderr
=subprocess
.STDOUT
)
132 except subprocess
.CalledProcessError
as e
:
133 raise AssertionError('%s failed (%d): %s' % (pdnsutilCmd
, e
.returncode
, e
.output
))
136 def secureZone(cls
, confdir
, zonename
, key
=None):
137 zone
= '.' if zonename
== 'ROOT' else zonename
139 pdnsutilCmd
= [os
.environ
['PDNSUTIL'],
140 '--config-dir=%s' % confdir
,
144 keyfile
= os
.path
.join(confdir
, 'dnssec.key')
145 with
open(keyfile
, 'w') as fdKeyfile
:
148 pdnsutilCmd
= [os
.environ
['PDNSUTIL'],
149 '--config-dir=%s' % confdir
,
156 print(' '.join(pdnsutilCmd
))
158 subprocess
.check_output(pdnsutilCmd
, stderr
=subprocess
.STDOUT
)
159 except subprocess
.CalledProcessError
as e
:
160 raise AssertionError('%s failed (%d): %s' % (pdnsutilCmd
, e
.returncode
, e
.output
))
163 def generateAllAuthConfig(cls
, confdir
):
164 cls
.generateAuthConfig(confdir
)
165 cls
.generateAuthNamedConf(confdir
, cls
._zones
.keys())
167 for zonename
, zonecontent
in cls
._zones
.items():
168 cls
.generateAuthZone(confdir
,
171 if cls
._zone
_keys
.get(zonename
, None):
172 cls
.secureZone(confdir
, zonename
, cls
._zone
_keys
.get(zonename
))
175 def startAuth(cls
, confdir
, ipaddress
):
177 print("Launching pdns_server..")
178 authcmd
= list(cls
._auth
_cmd
)
179 authcmd
.append('--config-dir=%s' % confdir
)
180 authcmd
.append('--local-address=%s' % ipaddress
)
181 authcmd
.append('--local-port=%s' % cls
._authPort
)
182 authcmd
.append('--loglevel=9')
183 authcmd
.append('--enable-lua-records')
184 authcmd
.append('--lua-health-checks-interval=1')
185 print(' '.join(authcmd
))
186 logFile
= os
.path
.join(confdir
, 'pdns.log')
187 with
open(logFile
, 'w') as fdLog
:
188 cls
._auths
[ipaddress
] = subprocess
.Popen(authcmd
, close_fds
=True,
189 stdout
=fdLog
, stderr
=fdLog
,
194 if cls
._auths
[ipaddress
].poll() is not None:
196 cls
._auths
[ipaddress
].kill()
198 if e
.errno
!= errno
.ESRCH
:
200 with
open(logFile
, 'r') as fdLog
:
202 sys
.exit(cls
._auths
[ipaddress
].returncode
)
def setUpSockets(cls):
    """Create the class-wide UDP socket that sendUDPQuery talks through.

    The socket is stored on cls._sock, given a 2 second timeout and
    connected to the first address of the test prefix (<_PREFIX>.1)
    on cls._authPort.

    @param cls: the test class; must provide _PREFIX and _authPort"""
    print("Setting up UDP socket..")
    udpSock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    cls._sock = udpSock
    udpSock.settimeout(2.0)
    # UDP "connect" only fixes the default peer; no packet is sent here.
    peer = (cls._PREFIX + ".1", cls._authPort)
    udpSock.connect(peer)
212 def startResponders(cls
):
219 cls
.startResponders()
221 confdir
= os
.path
.join('configs', cls
._confdir
)
222 cls
.createConfigDir(confdir
)
224 cls
.generateAllAuthConfig(confdir
)
225 cls
.startAuth(confdir
, cls
._PREFIX
+ ".1")
227 print("Launching tests..")
230 def tearDownClass(cls
):
232 cls
.tearDownResponders()
235 def tearDownResponders(cls
):
239 def tearDownClass(cls
):
243 def tearDownAuth(cls
):
244 if 'PDNSRECURSOR_FAST_TESTS' in os
.environ
:
249 for _
, auth
in cls
._auths
.items():
252 if auth
.poll() is None:
254 if auth
.poll() is None:
258 if e
.errno
!= errno
.ESRCH
:
262 def sendUDPQuery(cls
, query
, timeout
=2.0, decode
=True, fwparams
=dict()):
264 cls
._sock
.settimeout(timeout
)
267 cls
._sock
.send(query
.to_wire())
268 data
= cls
._sock
.recv(4096)
269 except socket
.timeout
:
273 cls
._sock
.settimeout(None)
279 message
= dns
.message
.from_wire(data
, **fwparams
)
283 def sendTCPQuery(cls
, query
, timeout
=2.0):
284 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
286 sock
.settimeout(timeout
)
288 sock
.connect(("127.0.0.1", cls
._authPort
))
291 wire
= query
.to_wire()
292 sock
.send(struct
.pack("!H", len(wire
)))
296 (datalen
,) = struct
.unpack("!H", data
)
297 data
= sock
.recv(datalen
)
298 except socket
.timeout
as e
:
299 print("Timeout: %s" % (str(e
)))
301 except socket
.error
as e
:
302 print("Network error: %s" % (str(e
)))
309 message
= dns
.message
.from_wire(data
)
313 def sendTCPQueryMultiResponse(cls
, query
, timeout
=2.0, count
=1):
314 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
316 sock
.settimeout(timeout
)
318 sock
.connect(("127.0.0.1", cls
._authPort
))
321 wire
= query
.to_wire()
322 sock
.send(struct
.pack("!H", len(wire
)))
324 except socket
.timeout
as e
:
325 raise Exception("Timeout: %s" % (str(e
)))
326 except socket
.error
as e
:
327 raise Exception("Network error: %s" % (str(e
)))
330 for i
in range(count
):
333 print("got data", repr(data
))
335 (datalen
,) = struct
.unpack("!H", data
)
336 data
= sock
.recv(datalen
)
337 messages
.append(dns
.message
.from_wire(data
))
340 except socket
.timeout
as e
:
341 raise Exception("Timeout: %s" % (str(e
)))
342 except socket
.error
as e
:
343 raise Exception("Network error: %s" % (str(e
)))
348 # This function is called before every tests
349 super(AuthTest
, self
).setUp()
351 ## Functions for comparisons
352 def assertMessageHasFlags(self
, msg
, flags
, ednsflags
=[]):
353 """Asserts that msg has all the flags from flags set
355 @param msg: the dns.message.Message to check
356 @param flags: a list of strings with flag mnemonics (like ['RD', 'RA'])
357 @param ednsflags: a list of strings with edns-flag mnemonics (like ['DO'])"""
359 if not isinstance(msg
, dns
.message
.Message
):
360 raise TypeError("msg is not a dns.message.Message")
362 if isinstance(flags
, list):
364 if not isinstance(elem
, str):
365 raise TypeError("flags is not a list of strings")
367 raise TypeError("flags is not a list of strings")
369 if isinstance(ednsflags
, list):
370 for elem
in ednsflags
:
371 if not isinstance(elem
, str):
372 raise TypeError("ednsflags is not a list of strings")
374 raise TypeError("ednsflags is not a list of strings")
376 msgFlags
= dns
.flags
.to_text(msg
.flags
).split()
377 missingFlags
= [flag
for flag
in flags
if flag
not in msgFlags
]
379 msgEdnsFlags
= dns
.flags
.edns_to_text(msg
.ednsflags
).split()
380 missingEdnsFlags
= [ednsflag
for ednsflag
in ednsflags
if ednsflag
not in msgEdnsFlags
]
382 if len(missingFlags
) or len(missingEdnsFlags
) or len(msgFlags
) > len(flags
):
383 raise AssertionError("Expected flags '%s' (EDNS: '%s'), found '%s' (EDNS: '%s') in query %s" %
384 (' '.join(flags
), ' '.join(ednsflags
),
385 ' '.join(msgFlags
), ' '.join(msgEdnsFlags
),
def assertMessageIsAuthenticated(self, msg):
    """Asserts that the message has the AD bit set

    @param msg: the dns.message.Message to check
    @raise TypeError: if msg is not a dns.message.Message
    @raise AssertionError: (via self.fail) if the AD flag is not set"""

    if not isinstance(msg, dns.message.Message):
        raise TypeError("msg is not a dns.message.Message")

    # Compare whole flag mnemonics (as assertMessageHasFlags does) instead of
    # doing a substring search on the space-separated flags text.
    msgFlags = dns.flags.to_text(msg.flags).split()
    if 'AD' not in msgFlags:
        # Only build the diagnostic on failure: indexing msg.question eagerly
        # would raise IndexError for a message without a question section
        # even when the AD flag is present.
        qname = msg.question[0].name if msg.question else '(no question)'
        self.fail("No AD flag found in the message for %s" % qname)
399 def assertRRsetInAnswer(self
, msg
, rrset
):
400 """Asserts the rrset (without comparing TTL) exists in the
401 answer section of msg
403 @param msg: the dns.message.Message to check
404 @param rrset: a dns.rrset.RRset object"""
407 if not isinstance(msg
, dns
.message
.Message
):
408 raise TypeError("msg is not a dns.message.Message")
410 if not isinstance(rrset
, dns
.rrset
.RRset
):
411 raise TypeError("rrset is not a dns.rrset.RRset")
414 for ans
in msg
.answer
:
415 ret
+= "%s\n" % ans
.to_text()
416 if ans
.match(rrset
.name
, rrset
.rdclass
, rrset
.rdtype
, 0, None):
417 self
.assertEqual(ans
, rrset
, "'%s' != '%s'" % (ans
.to_text(), rrset
.to_text()))
421 raise AssertionError("RRset not found in answer\n\n%s" % ret
)
def sortRRsets(self, rrsets):
    """Return the RRsets ordered by owner name, then by record type

    This gives a more useful ordering than dnspython's default behaviour.
    The sort is stable and the input is left untouched.

    @param rrsets: an iterable of dns.rrset.RRset objects
    @return: a new, sorted list"""

    def byNameThenType(item):
        return (item.name, item.rdtype)

    ordered = list(rrsets)
    ordered.sort(key=byNameThenType)
    return ordered
430 def assertAnyRRsetInAnswer(self
, msg
, rrsets
):
431 """Asserts that any of the supplied rrsets exists (without comparing TTL)
432 in the answer section of msg
434 @param msg: the dns.message.Message to check
435 @param rrsets: an array of dns.rrset.RRset object"""
437 if not isinstance(msg
, dns
.message
.Message
):
438 raise TypeError("msg is not a dns.message.Message")
442 if not isinstance(rrset
, dns
.rrset
.RRset
):
443 raise TypeError("rrset is not a dns.rrset.RRset")
444 for ans
in msg
.answer
:
445 if ans
.match(rrset
.name
, rrset
.rdclass
, rrset
.rdtype
, 0, None):
450 raise AssertionError("RRset not found in answer\n%s" %
451 "\n".join(([ans
.to_text() for ans
in msg
.answer
])))
453 def assertMatchingRRSIGInAnswer(self
, msg
, coveredRRset
, keys
=None):
454 """Looks for coveredRRset in the answer section and if there is an RRSIG RRset
455 that covers that RRset. If keys is not None, this function will also try to
456 validate the RRset against the RRSIG
458 @param msg: The dns.message.Message to check
459 @param coveredRRset: The RRSet to check for
460 @param keys: a dictionary keyed by dns.name.Name with node or rdataset values to use for validation"""
462 if not isinstance(msg
, dns
.message
.Message
):
463 raise TypeError("msg is not a dns.message.Message")
465 if not isinstance(coveredRRset
, dns
.rrset
.RRset
):
466 raise TypeError("coveredRRset is not a dns.rrset.RRset")
472 for ans
in msg
.answer
:
473 ret
+= ans
.to_text() + "\n"
475 if ans
.match(coveredRRset
.name
, coveredRRset
.rdclass
, coveredRRset
.rdtype
, 0, None):
477 if ans
.match(coveredRRset
.name
, dns
.rdataclass
.IN
, dns
.rdatatype
.RRSIG
, coveredRRset
.rdtype
, None):
479 if msgRRSet
and msgRRsigRRSet
:
483 raise AssertionError("RRset for '%s' not found in answer" % msg
.question
[0].to_text())
485 if not msgRRsigRRSet
:
486 raise AssertionError("No RRSIGs found in answer for %s:\nFull answer:\n%s" % (msg
.question
[0].to_text(), ret
))
490 dns
.dnssec
.validate(msgRRSet
, msgRRsigRRSet
.to_rdataset(), keys
)
491 except dns
.dnssec
.ValidationFailure
as e
:
492 raise AssertionError("Signature validation failed for %s:\n%s" % (msg
.question
[0].to_text(), e
))
494 def assertNoRRSIGsInAnswer(self
, msg
):
495 """Checks if there are _no_ RRSIGs in the answer section of msg"""
497 if not isinstance(msg
, dns
.message
.Message
):
498 raise TypeError("msg is not a dns.message.Message")
501 for ans
in msg
.answer
:
502 if ans
.rdtype
== dns
.rdatatype
.RRSIG
:
503 ret
+= ans
.name
.to_text() + "\n"
506 raise AssertionError("RRSIG found in answers for:\n%s" % ret
)
def assertAnswerEmpty(self, msg):
    """Asserts that the answer section of msg is empty

    @param msg: the dns.message.Message to check
    @raise AssertionError: (via self.fail) if the answer section has data;
                           the message lists every answer RRset found"""

    if not msg.answer:
        return

    # Build the diagnostic only on failure: formatting it up front would
    # raise IndexError for a message without a question section even when
    # the answer section is empty.
    qname = msg.question[0].to_text() if msg.question else '(no question)'
    self.fail("Data found in the answer section for %s:\n%s" %
              (qname, '\n'.join(rrset.to_text() for rrset in msg.answer)))
def assertAnswerNotEmpty(self, msg):
    """Asserts that the answer section of msg is not empty

    @param msg: the message whose answer section is checked (presumably a
                dns.message.Message; its type is not checked here)"""

    hasAnswer = len(msg.answer) > 0
    self.assertTrue(hasAnswer, "Answer is empty")
514 def assertRcodeEqual(self
, msg
, rcode
):
515 if not isinstance(msg
, dns
.message
.Message
):
516 raise TypeError("msg is not a dns.message.Message but a %s" % type(msg
))
518 if not isinstance(rcode
, int):
519 if isinstance(rcode
, str):
520 rcode
= dns
.rcode
.from_text(rcode
)
522 raise TypeError("rcode is neither a str nor int")
524 if msg
.rcode() != rcode
:
525 msgRcode
= dns
.rcode
._by
_value
[msg
.rcode()]
526 wantedRcode
= dns
.rcode
._by
_value
[rcode
]
528 raise AssertionError("Rcode for %s is %s, expected %s." % (msg
.question
[0].to_text(), msgRcode
, wantedRcode
))
530 def assertAuthorityHasSOA(self
, msg
):
531 if not isinstance(msg
, dns
.message
.Message
):
532 raise TypeError("msg is not a dns.message.Message but a %s" % type(msg
))
535 for rrset
in msg
.authority
:
536 if rrset
.rdtype
== dns
.rdatatype
.SOA
:
541 raise AssertionError("No SOA record found in the authority section:\n%s" % msg
.to_text())