3 from __future__
import print_function
16 from pprint
import pprint
17 from eqdnsmessage
import AssertEqualDNSMessageMixin
19 class AuthTest(AssertEqualDNSMessageMixin
, unittest
.TestCase
):
21 Setup auth required for the tests
# Base pdns.conf contents written by generateAuthConfig(); str.format() is applied to it
# with confdir, prefix and bind_dnssec_db.
29 _config_template_default
= """
30 module-dir=../regression-tests/modules
32 bind-config={confdir}/named.conf
33 bind-dnssec-db={bind_dnssec_db}
41 distributor-threads=1"""
# NOTE(review): looks like DS rdata (key tag 63149, algorithm 13, digest type 1, digest),
# presumably for the test root zone's key; it is not referenced in the visible code — verify.
45 _root_DS
= "63149 13 1 a59da3f5c1b97fcd5fa2b3b2b0ac91d38a60d33a"
47 # The default SOA rdata for zones on the authoritative servers (substituted for {soa})
48 _SOA
= "ns1.example.net. hostmaster.example.net. 1 3600 1800 1209600 300"
50 # The definitions of the zones on the authoritative servers: the key is the
51 # zone name and the value is the zone file content. Several placeholders are replaced:
52 # - {soa} => value of _SOA
53 # - {prefix} => value of _PREFIX
56 example.org. 3600 IN SOA {soa}
57 example.org. 3600 IN NS ns1.example.org.
58 example.org. 3600 IN NS ns2.example.org.
59 ns1.example.org. 3600 IN A {prefix}.10
60 ns2.example.org. 3600 IN A {prefix}.11
66 Private-key-format: v1.2
67 Algorithm: 13 (ECDSAP256SHA256)
68 PrivateKey: Lt0v0Gol3pRUFM7fDdcy0IWN0O/MnEmVPA+VylL8Y4U=
# Base command line for the auth server; startAuth() copies this list and appends
# --config-dir, --local-address, --local-port and further options before launching it.
72 _auth_cmd
= ['authbind',
# First three octets of the test network's IPv4 addresses (addresses are built as
# _PREFIX + ".1", "{prefix}.10", ...), read from the PREFIX environment variable; a
# missing PREFIX raises KeyError when the class body is executed.
77 _PREFIX
= os
.environ
['PREFIX']
def createConfigDir(cls, confdir):
    """(Re)create an empty configuration directory.

    Any existing confdir is removed recursively first; then a fresh directory
    is created with mode 0o755 (subject to the process umask).

    @param confdir: path of the directory to (re)create
    @raise OSError: if removal fails for any reason other than confdir not
                    existing, or if the directory cannot be created"""
    try:
        shutil.rmtree(confdir)
    except FileNotFoundError:
        # Nothing to remove on a first run; any other OSError is a real problem
        # and propagates (same as checking errno != ENOENT and re-raising)
        pass
    os.mkdir(confdir, 0o755)
def generateAuthZone(cls, confdir, zonename, zonecontent):
    """Write the zone file for one authoritative zone.

    The file is <confdir>/<zonename>.zone. zonecontent is expanded with
    str.format(): {prefix} becomes cls._PREFIX and {soa} becomes cls._SOA.

    @param confdir: directory to write the zone file into
    @param zonename: zone name, used to build the file name
    @param zonecontent: zone file template text"""
    zonefile_path = os.path.join(confdir, '%s.zone' % zonename)
    rendered = zonecontent.format(prefix=cls._PREFIX, soa=cls._SOA)
    with open(zonefile_path, 'w') as zonefile:
        zonefile.write(rendered)
# Write <confdir>/named.conf (the file the bind-config setting points at) with one
# stanza per name in zones.
# NOTE(review): this copy of the method is incomplete — the leading part of the file
# and the per-zone stanza template (a triple-quoted string ending in '};' below) are
# not visible, so the exact named.conf layout cannot be confirmed from here.
95 def generateAuthNamedConf(cls
, confdir
, zones
):
96 with
open(os
.path
.join(confdir
, 'named.conf'), 'w') as namedconf
:
# The pseudo zone name 'ROOT' stands for the DNS root '.'; both the mapped name (zone)
# and the original name (zonename) are passed to the stanza template, presumably so the
# stanza can refer to the '<zonename>.zone' file written by generateAuthZone — verify.
101 for zonename
in zones
:
102 zone
= '.' if zonename
== 'ROOT' else zonename
108 };""" % (zone
, zonename
))
# Write <confdir>/pdns.conf, then run pdnsutil (path taken from the PDNSUTIL
# environment variable) against that configuration directory.
# pdns.conf = _config_template_default formatted with confdir, prefix and
# bind_dnssec_db, followed by _config_template %-formatted with the values of the
# class attributes named in _config_params (in that order).
# Raises AssertionError (command, exit status, combined stdout/stderr) if pdnsutil
# exits non-zero; a missing PDNSUTIL variable raises KeyError.
# NOTE(review): the pdnsutil arguments after --config-dir are not visible in this
# copy; presumably they create the bind DNSSEC database at bind_dnssec_db — verify.
111 def generateAuthConfig(cls
, confdir
):
# SQLite file referenced by the bind-dnssec-db setting of the default template
112 bind_dnssec_db
= os
.path
.join(confdir
, 'bind-dnssec.sqlite3')
# Values for the %-placeholders in _config_template, looked up by attribute name
114 params
= tuple([getattr(cls
, param
) for param
in cls
._config
_params
])
116 with
open(os
.path
.join(confdir
, 'pdns.conf'), 'w') as pdnsconf
:
117 pdnsconf
.write(cls
._config
_template
_default
.format(
118 confdir
=confdir
, prefix
=cls
._PREFIX
,
119 bind_dnssec_db
=bind_dnssec_db
))
# Test-specific settings are appended after the defaults
120 pdnsconf
.write(cls
._config
_template
% params
)
122 pdnsutilCmd
= [os
.environ
['PDNSUTIL'],
123 '--config-dir=%s' % confdir
,
127 print(' '.join(pdnsutilCmd
))
# stderr is merged into stdout so the failure message below carries all output
129 subprocess
.check_output(pdnsutilCmd
, stderr
=subprocess
.STDOUT
)
130 except subprocess
.CalledProcessError
as e
:
# NOTE(review): on Python 3 e.output is bytes, so it is rendered as b'...' here.
131 raise AssertionError('%s failed (%d): %s' % (pdnsutilCmd
, e
.returncode
, e
.output
))
# Enable DNSSEC for zonename through pdnsutil (path from the PDNSUTIL environment
# variable), optionally using the supplied key material.
# Raises AssertionError (command, exit status, combined output) if pdnsutil fails.
# NOTE(review): this copy is incomplete — the branching between the two command lines
# below and the pdnsutil subcommands/arguments are not visible; presumably key=None and
# a supplied key take different paths — verify.
134 def secureZone(cls
, confdir
, zonename
, key
=None):
# The pseudo zone name 'ROOT' stands for the DNS root '.'
135 zone
= '.' if zonename
== 'ROOT' else zonename
# First pdnsutil command line (remaining arguments not visible in this copy)
137 pdnsutilCmd
= [os
.environ
['PDNSUTIL'],
138 '--config-dir=%s' % confdir
,
# <confdir>/dnssec.key is (re)written — presumably with the contents of key; the write
# itself is not visible in this copy
142 keyfile
= os
.path
.join(confdir
, 'dnssec.key')
143 with
open(keyfile
, 'w') as fdKeyfile
:
# Second pdnsutil command line, presumably used when key is supplied — verify
146 pdnsutilCmd
= [os
.environ
['PDNSUTIL'],
147 '--config-dir=%s' % confdir
,
# The chosen command is echoed and executed with stderr merged into stdout
154 print(' '.join(pdnsutilCmd
))
156 subprocess
.check_output(pdnsutilCmd
, stderr
=subprocess
.STDOUT
)
157 except subprocess
.CalledProcessError
as e
:
158 raise AssertionError('%s failed (%d): %s' % (pdnsutilCmd
, e
.returncode
, e
.output
))
def generateAllAuthConfig(cls, confdir):
    """Generate every file the auth server needs inside confdir.

    Writes pdns.conf (generateAuthConfig), a named.conf listing every zone
    in cls._zones (generateAuthNamedConf) and one zone file per zone
    (generateAuthZone). Zones with a truthy entry in cls._zone_keys are then
    passed to secureZone together with that key.

    @param confdir: configuration directory; it must already exist because
                    the files are opened inside it"""
    cls.generateAuthConfig(confdir)
    cls.generateAuthNamedConf(confdir, cls._zones.keys())

    for zonename, zonecontent in cls._zones.items():
        cls.generateAuthZone(confdir, zonename, zonecontent)
        # Look the key up once; zones without a (truthy) key are left unsecured
        key = cls._zone_keys.get(zonename)
        if key:
            cls.secureZone(confdir, zonename, key)
# Launch the auth server (base command from _auth_cmd) with its configuration in
# confdir, listening on ipaddress port cls._authPort, with verbose logging, Lua
# records enabled and a 1-second Lua health-check interval. The Popen handle is
# stored in cls._auths[ipaddress].
# NOTE(review): the rest of the Popen call and a few statements after it are not
# visible in this copy.
174 def startAuth(cls
, confdir
, ipaddress
):
176 print("Launching pdns_server..")
# Copy so the class-level _auth_cmd list is not mutated by the appends below
177 authcmd
= list(cls
._auth
_cmd
)
178 authcmd
.append('--config-dir=%s' % confdir
)
179 authcmd
.append('--local-address=%s' % ipaddress
)
180 authcmd
.append('--local-port=%s' % cls
._authPort
)
181 authcmd
.append('--loglevel=9')
182 authcmd
.append('--enable-lua-records')
183 authcmd
.append('--lua-health-checks-interval=1')
184 print(' '.join(authcmd
))
# Server stdout and stderr both go to <confdir>/pdns.log
185 logFile
= os
.path
.join(confdir
, 'pdns.log')
186 with
open(logFile
, 'w') as fdLog
:
187 cls
._auths
[ipaddress
] = subprocess
.Popen(authcmd
, close_fds
=True,
188 stdout
=fdLog
, stderr
=fdLog
,
# If the server has already exited: kill() is attempted (an ESRCH error — no such
# process — is tolerated; presumably other errors are re-raised), the log is reopened
# (presumably to print it — not visible) and the whole run exits with the server's
# return code.
# NOTE(review): in CPython 3, Popen.kill() does nothing once poll() has reaped the
# process, so this kill and its ESRCH handling look redundant — confirm.
193 if cls
._auths
[ipaddress
].poll() is not None:
195 cls
._auths
[ipaddress
].kill()
197 if e
.errno
!= errno
.ESRCH
:
199 with
open(logFile
, 'r') as fdLog
:
201 sys
.exit(cls
._auths
[ipaddress
].returncode
)
def setUpSockets(cls):
    """Create the UDP client socket used by sendUDPQuery().

    The socket is stored as cls._sock, gets a 2 second timeout and is
    connect()ed to <_PREFIX>.1 port cls._authPort, so plain send()/recv()
    can be used afterwards."""
    print("Setting up UDP socket..")
    cls._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    cls._sock.settimeout(2.0)
    # For UDP, connect() sends nothing; it only fixes the default peer address
    cls._sock.connect((cls._PREFIX + ".1", cls._authPort))
# Hook for starting additional responders; its body is not visible in this copy
# (presumably a no-op here that subclasses override — verify).
211 def startResponders(cls
):
# NOTE(review): the statements below belong to a method whose def line is missing from
# this copy (presumably setUpClass). They start the responders, recreate
# configs/<_confdir>, generate all auth configuration, start the auth server on
# _PREFIX + ".1" and print a banner.
218 cls
.startResponders()
220 confdir
= os
.path
.join('configs', cls
._confdir
)
221 cls
.createConfigDir(confdir
)
223 cls
.generateAllAuthConfig(confdir
)
224 cls
.startAuth(confdir
, cls
._PREFIX
+ ".1")
226 print("Launching tests..")
# Class-level teardown that stops the responders.
# NOTE(review): tearDownClass is defined a second time right after tearDownResponders;
# in a single class body the later definition replaces this one, so this version
# (which calls tearDownResponders) would never run — confirm and merge the two.
229 def tearDownClass(cls
):
231 cls
.tearDownResponders()
# Counterpart of startResponders(); its body is not visible in this copy (presumably a
# no-op hook for subclasses — verify).
234 def tearDownResponders(cls
):
# Second definition of tearDownClass; its body is not visible in this copy (presumably it
# calls tearDownAuth — verify).
# NOTE(review): this definition replaces the earlier tearDownClass, so tearDownResponders
# is not called from class teardown unless this body also calls it — confirm.
238 def tearDownClass(cls
):
# Stop every auth server process recorded in cls._auths.
# NOTE(review): most of this method is missing from this copy. What is visible:
# - the PDNSRECURSOR_FAST_TESTS environment variable selects some different behaviour
#   (branch body not visible);
# - each process is checked with poll() twice — presumably terminate, wait, then kill if
#   it is still running;
# - errors other than ESRCH (process already gone) are presumably re-raised — verify.
242 def tearDownAuth(cls
):
243 if 'PDNSRECURSOR_FAST_TESTS' in os
.environ
:
# Only the Popen handles are used; the address keys are ignored
248 for _
, auth
in cls
._auths
.items():
251 if auth
.poll() is None:
253 if auth
.poll() is None:
257 if e
.errno
!= errno
.ESRCH
:
# Send query (a DNS message object with to_wire()) over the class UDP socket cls._sock
# and read a single datagram of up to 4096 bytes in response.
# fwparams are passed as keyword arguments to dns.message.from_wire when parsing.
# NOTE(review): how decode and the timeout case affect the return value is not visible
# in this copy. The mutable default fwparams=dict() is only unpacked in the visible
# code, so sharing it between calls is harmless as long as nothing mutates it.
261 def sendUDPQuery(cls
, query
, timeout
=2.0, decode
=True, fwparams
=dict()):
# Per-call timeout (presumably only applied when timeout is truthy — guard not visible)
263 cls
._sock
.settimeout(timeout
)
266 cls
._sock
.send(query
.to_wire())
267 data
= cls
._sock
.recv(4096)
# Handler body not visible (presumably leaves data as None — verify)
268 except socket
.timeout
:
# Socket goes back to blocking mode afterwards (presumably in a finally clause)
272 cls
._sock
.settimeout(None)
278 message
= dns
.message
.from_wire(data
, **fwparams
)
# Send query over a new TCP connection to 127.0.0.1:cls._recursorPort, using DNS-over-TCP
# framing (a 2-byte big-endian length prefix, struct "!H"), and parse the reply with
# dns.message.from_wire.
# NOTE(review): another sendTCPQuery (targeting _authPort) is defined later in this
# class and replaces this one, so this definition is dead code; _recursorPort is also not
# among the visible class attributes — confirm and remove or rename.
# NOTE(review): parts of the body are not visible in this copy (the guard around
# settimeout, sending wire itself, reading the 2-byte length into data, closing the
# socket and the return).
282 def sendTCPQuery(cls
, query
, timeout
=2.0):
283 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
285 sock
.settimeout(timeout
)
287 sock
.connect(("127.0.0.1", cls
._recursorPort
))
290 wire
= query
.to_wire()
# NOTE(review): send() may write only part of the buffer; sendall() is safer on TCP.
291 sock
.send(struct
.pack("!H", len(wire
)))
# NOTE(review): recv(datalen) may return fewer than datalen bytes on TCP; a loop
# reading until datalen bytes have arrived would be more robust.
295 (datalen
,) = struct
.unpack("!H", data
)
296 data
= sock
.recv(datalen
)
# Timeouts and other socket errors are printed, not raised
297 except socket
.timeout
as e
:
298 print("Timeout: %s" % (str(e
)))
300 except socket
.error
as e
:
301 print("Network error: %s" % (str(e
)))
308 message
= dns
.message
.from_wire(data
)
# Send query over a new TCP connection to 127.0.0.1:cls._authPort, using DNS-over-TCP
# framing (a 2-byte big-endian length prefix, struct "!H"), and parse the reply with
# dns.message.from_wire.
# This is the later of two sendTCPQuery definitions in the class, so it is the one that
# takes effect.
# NOTE(review): the auth server is started on _PREFIX + ".1" and the UDP socket targets
# that address, but this connects to 127.0.0.1 — only equivalent if PREFIX is "127.0.0";
# confirm.
# NOTE(review): parts of the body are not visible in this copy (the guard around
# settimeout, sending wire itself, reading the 2-byte length into data, closing the
# socket and the return).
313 def sendTCPQuery(cls
, query
, timeout
=2.0):
314 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
316 sock
.settimeout(timeout
)
318 sock
.connect(("127.0.0.1", cls
._authPort
))
321 wire
= query
.to_wire()
# NOTE(review): send() may write only part of the buffer; sendall() is safer on TCP.
322 sock
.send(struct
.pack("!H", len(wire
)))
# NOTE(review): recv(datalen) may return fewer than datalen bytes on TCP; a loop
# reading until datalen bytes have arrived would be more robust.
326 (datalen
,) = struct
.unpack("!H", data
)
327 data
= sock
.recv(datalen
)
# Timeouts and other socket errors are printed, not raised
328 except socket
.timeout
as e
:
329 print("Timeout: %s" % (str(e
)))
331 except socket
.error
as e
:
332 print("Network error: %s" % (str(e
)))
339 message
= dns
.message
.from_wire(data
)
# NOTE(review): the def line of this method (presumably setUp) is not visible in this copy.
343 # This function is called before every test
344 super(AuthTest
, self
).setUp()
## Functions for comparisons
def assertMessageHasFlags(self, msg, flags, ednsflags=None):
    """Asserts that msg has exactly the header flags in flags and at least the EDNS flags in ednsflags

    Header flags must match exactly: a flag that is set on msg but not listed
    in flags fails the assertion, just like a listed flag that is missing.
    EDNS flags are only checked for presence.

    @param msg: the dns.message.Message to check
    @param flags: a list of strings with flag mnemonics (like ['RD', 'RA'])
    @param ednsflags: a list of strings with edns-flag mnemonics (like ['DO']);
                      None (the default) means no EDNS flags are required
    @raise TypeError: if msg is not a dns.message.Message, or flags/ednsflags
                      are not lists of strings
    @raise AssertionError: if the flags do not match"""

    # None instead of a shared mutable [] default; an empty list is equivalent
    if ednsflags is None:
        ednsflags = []

    if not isinstance(msg, dns.message.Message):
        raise TypeError("msg is not a dns.message.Message")

    if not isinstance(flags, list) or not all(isinstance(elem, str) for elem in flags):
        raise TypeError("flags is not a list of strings")

    if not isinstance(ednsflags, list) or not all(isinstance(elem, str) for elem in ednsflags):
        raise TypeError("ednsflags is not a list of strings")

    msgFlags = dns.flags.to_text(msg.flags).split()
    missingFlags = [flag for flag in flags if flag not in msgFlags]
    # Compare by name rather than by count, so duplicate entries in flags
    # cannot hide an unexpected flag set on the message
    unexpectedFlags = [flag for flag in msgFlags if flag not in flags]

    msgEdnsFlags = dns.flags.edns_to_text(msg.ednsflags).split()
    missingEdnsFlags = [ednsflag for ednsflag in ednsflags if ednsflag not in msgEdnsFlags]

    if missingFlags or missingEdnsFlags or unexpectedFlags:
        raise AssertionError("Expected flags '%s' (EDNS: '%s'), found '%s' (EDNS: '%s') in query %s" %
                             (' '.join(flags), ' '.join(ednsflags),
                              ' '.join(msgFlags), ' '.join(msgEdnsFlags),
                              msg.question[0].to_text()))
def assertMessageIsAuthenticated(self, msg):
    """Asserts that the message has the AD bit set

    @param msg: the dns.message.Message to check
    @raise TypeError: if msg is not a dns.message.Message
    @raise AssertionError: (via self.fail) if the AD flag is not set"""

    if not isinstance(msg, dns.message.Message):
        raise TypeError("msg is not a dns.message.Message")

    # Test the bit itself instead of searching for 'AD' in the flags text, and
    # only touch msg.question when a failure message is actually needed
    if not msg.flags & dns.flags.AD:
        self.fail("No AD flag found in the message for %s" % msg.question[0].name)
394 def assertRRsetInAnswer(self
, msg
, rrset
):
395 """Asserts the rrset (without comparing TTL) exists in the
396 answer section of msg
398 @param msg: the dns.message.Message to check
399 @param rrset: a dns.rrset.RRset object"""
402 if not isinstance(msg
, dns
.message
.Message
):
403 raise TypeError("msg is not a dns.message.Message")
405 if not isinstance(rrset
, dns
.rrset
.RRset
):
406 raise TypeError("rrset is not a dns.rrset.RRset")
# Every answer RRset's text is collected in ret for the failure message. An RRset with
# rrset's name, class and type must then compare equal to rrset; the docstring relies on
# RRset equality ignoring TTL.
# NOTE(review): ret's initialisation and the bookkeeping that decides whether the final
# raise runs (presumably a found flag) are not visible in this copy.
409 for ans
in msg
.answer
:
410 ret
+= "%s\n" % ans
.to_text()
411 if ans
.match(rrset
.name
, rrset
.rdclass
, rrset
.rdtype
, 0, None):
412 self
.assertEqual(ans
, rrset
, "'%s' != '%s'" % (ans
.to_text(), rrset
.to_text()))
# Reached when no answer RRset matched (presumably guarded by a found check)
416 raise AssertionError("RRset not found in answer\n\n%s" % ret
)
def sortRRsets(self, rrsets):
    """Sorts RRsets in a more useful way than dnspython's default behaviour

    RRsets are ordered by owner name first and by record type second.

    @param rrsets: an iterable of dns.rrset.RRset objects (anything with
                   mutually comparable .name and .rdtype attributes works)
    @return: a new sorted list; the input is left untouched"""
    return sorted(rrsets, key=lambda rrset: (rrset.name, rrset.rdtype))
425 def assertAnyRRsetInAnswer(self
, msg
, rrsets
):
426 """Asserts that any of the supplied rrsets exists (without comparing TTL)
427 in the answer section of msg
429 @param msg: the dns.message.Message to check
430 @param rrsets: an array of dns.rrset.RRset object"""
432 if not isinstance(msg
, dns
.message
.Message
):
433 raise TypeError("msg is not a dns.message.Message")
# NOTE(review): the loop over rrsets that binds rrset, and what the matching branch below
# does (presumably an equality check that sets a found flag), are not visible in this copy.
437 if not isinstance(rrset
, dns
.rrset
.RRset
):
438 raise TypeError("rrset is not a dns.rrset.RRset")
439 for ans
in msg
.answer
:
440 if ans
.match(rrset
.name
, rrset
.rdclass
, rrset
.rdtype
, 0, None):
# Failure message lists every answer RRset as text (presumably guarded by a found check)
445 raise AssertionError("RRset not found in answer\n%s" %
446 "\n".join(([ans
.to_text() for ans
in msg
.answer
])))
448 def assertMatchingRRSIGInAnswer(self
, msg
, coveredRRset
, keys
=None):
449 """Looks for coveredRRset in the answer section and if there is an RRSIG RRset
450 that covers that RRset. If keys is not None, this function will also try to
451 validate the RRset against the RRSIG
453 @param msg: The dns.message.Message to check
454 @param coveredRRset: The RRSet to check for
455 @param keys: a dictionary keyed by dns.name.Name with node or rdataset values to use for validation"""
457 if not isinstance(msg
, dns
.message
.Message
):
458 raise TypeError("msg is not a dns.message.Message")
460 if not isinstance(coveredRRset
, dns
.rrset
.RRset
):
461 raise TypeError("coveredRRset is not a dns.rrset.RRset")
# NOTE(review): the initialisation of ret, msgRRSet and msgRRsigRRSet, and the
# assignments inside the two branches below, are not visible in this copy; presumably
# each branch records the matching answer RRset.
467 for ans
in msg
.answer
:
468 ret
+= ans
.to_text() + "\n"
# The covered RRset itself: same name, class and type as coveredRRset
470 if ans
.match(coveredRRset
.name
, coveredRRset
.rdclass
, coveredRRset
.rdtype
, 0, None):
# An RRSIG RRset at the same name covering coveredRRset's type; note the class is fixed
# to IN here rather than taken from coveredRRset
472 if ans
.match(coveredRRset
.name
, dns
.rdataclass
.IN
, dns
.rdatatype
.RRSIG
, coveredRRset
.rdtype
, None):
# Presumably stops scanning once both have been found (statement not visible)
474 if msgRRSet
and msgRRsigRRSet
:
# Presumably raised when the covered RRset was not found (guard not visible)
478 raise AssertionError("RRset for '%s' not found in answer" % msg
.question
[0].to_text())
480 if not msgRRsigRRSet
:
481 raise AssertionError("No RRSIGs found in answer for %s:\nFull answer:\n%s" % (msg
.question
[0].to_text(), ret
))
# Validation, which per the docstring only happens when keys is given (the guard and the
# try are not visible); the RRSIG RRset is converted to an rdataset for
# dns.dnssec.validate, and a ValidationFailure becomes an AssertionError.
485 dns
.dnssec
.validate(msgRRSet
, msgRRsigRRSet
.to_rdataset(), keys
)
486 except dns
.dnssec
.ValidationFailure
as e
:
487 raise AssertionError("Signature validation failed for %s:\n%s" % (msg
.question
[0].to_text(), e
))
def assertNoRRSIGsInAnswer(self, msg):
    """Checks if there are _no_ RRSIGs in the answer section of msg

    @param msg: the dns.message.Message to check
    @raise TypeError: if msg is not a dns.message.Message
    @raise AssertionError: naming the owner of every RRSIG RRset found"""

    if not isinstance(msg, dns.message.Message):
        raise TypeError("msg is not a dns.message.Message")

    # One line per offending RRset, in answer-section order
    ret = ''.join(ans.name.to_text() + "\n"
                  for ans in msg.answer
                  if ans.rdtype == dns.rdatatype.RRSIG)
    if ret:
        raise AssertionError("RRSIG found in answers for:\n%s" % ret)
def assertAnswerEmpty(self, msg):
    """Asserts that the answer section of msg is empty.

    On failure the message names the first question and lists every RRset
    found in the answer section.

    @param msg: the DNS message to check (needs .answer and, on failure,
                .question)"""
    if msg.answer:
        # Build the failure text only when needed; this also avoids an
        # IndexError for question-less messages whose answer is empty
        self.fail("Data found in the answer section for %s:\n%s"
                  % (msg.question[0].to_text(), '\n'.join(i.to_text() for i in msg.answer)))
def assertAnswerNotEmpty(self, msg):
    """Asserts that the answer section of msg holds at least one RRset.

    @param msg: the DNS message to check (needs an .answer list)"""
    self.assertTrue(len(msg.answer) > 0, "Answer is empty")
def assertRcodeEqual(self, msg, rcode):
    """Asserts that msg carries the expected response code.

    @param msg: the dns.message.Message to check
    @param rcode: the expected rcode, either as an int (e.g. dns.rcode.NXDOMAIN)
                  or as its text mnemonic (e.g. 'NXDOMAIN')
    @raise TypeError: if msg is not a dns.message.Message, or rcode is neither
                      an int nor a str
    @raise AssertionError: if the rcodes differ; both are shown as mnemonics"""
    if not isinstance(msg, dns.message.Message):
        raise TypeError("msg is not a dns.message.Message but a %s" % type(msg))

    if not isinstance(rcode, int):
        if isinstance(rcode, str):
            rcode = dns.rcode.from_text(rcode)
        else:
            raise TypeError("rcode is neither a str nor int")

    if msg.rcode() != rcode:
        # Use the public dns.rcode.to_text() instead of the private _by_value
        # table, which dnspython 2.x no longer provides (and whose lookup raised
        # KeyError for rcodes without a mnemonic)
        msgRcode = dns.rcode.to_text(msg.rcode())
        wantedRcode = dns.rcode.to_text(rcode)

        raise AssertionError("Rcode for %s is %s, expected %s." % (msg.question[0].to_text(), msgRcode, wantedRcode))
# Asserts that the authority section of msg contains at least one SOA RRset; on failure
# the whole message text is included in the AssertionError.
# Raises TypeError if msg is not a dns.message.Message.
# NOTE(review): the lines that record a match and guard the final raise (presumably a
# found flag and a break) are not visible in this copy.
525 def assertAuthorityHasSOA(self
, msg
):
526 if not isinstance(msg
, dns
.message
.Message
):
527 raise TypeError("msg is not a dns.message.Message but a %s" % type(msg
))
530 for rrset
in msg
.authority
:
531 if rrset
.rdtype
== dns
.rdatatype
.SOA
:
536 raise AssertionError("No SOA record found in the authority section:\n%s" % msg
.to_text())