From: Brian Wellington Date: Fri, 15 May 2020 23:47:35 +0000 (-0700) Subject: Improve consistency in DNSSEC code. X-Git-Tag: v2.0.0rc1~205^2~1 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=ac861662a9f6ed163449281c5600a3ed42930786;p=thirdparty%2Fdnspython.git Improve consistency in DNSSEC code. The make_ds method took its algorithm as a string, and the nsec3_hash method took an algorithm as an int. Change both of them to accept either, and add enums for both sets of algorithms. --- diff --git a/dns/dnssec.py b/dns/dnssec.py index b1febd51..36694ffd 100644 --- a/dns/dnssec.py +++ b/dns/dnssec.py @@ -17,6 +17,7 @@ """Common DNSSEC-related functions and constants.""" +import enum import hashlib import io import struct @@ -157,6 +158,13 @@ def key_id(key): total += ((total >> 16) & 0xffff) return total & 0xffff +class DSDigest(enum.IntEnum): + """DNSSEC Delgation Signer Digest Algorithm""" + + SHA1 = 1 + SHA256 = 2 + SHA384 = 4 + def make_ds(name, key, algorithm, origin=None): """Create a DS record for a DNSSEC key. @@ -165,7 +173,7 @@ def make_ds(name, key, algorithm, origin=None): *key*, a ``dns.rdtypes.ANY.DNSKEY.DNSKEY``, the key the DS is about. - *algorithm*, a ``str`` specifying the hash algorithm. + *algorithm*, a ``str`` or ``int`` specifying the hash algorithm. The currently supported hashes are "SHA1", "SHA256", and "SHA384". Case does not matter for these strings. @@ -177,14 +185,17 @@ def make_ds(name, key, algorithm, origin=None): Returns a ``dns.rdtypes.ANY.DS.DS`` """ - if algorithm.upper() == 'SHA1': - dsalg = 1 + try: + if isinstance(algorithm, str): + algorithm = DSDigest[algorithm.upper()] + except Exception: + raise UnsupportedAlgorithm('unsupported algorithm "%s"' % algorithm) + + if algorithm == DSDigest.SHA1: dshash = hashlib.sha1() - elif algorithm.upper() == 'SHA256': - dsalg = 2 + elif algorithm == DSDigest.SHA256: dshash = hashlib.sha256() - elif algorithm.upper() == 'SHA384': - dsalg = 4 + elif algorithm == DSDigest.SHA384: dshash = hashlib.sha384() else: raise UnsupportedAlgorithm('unsupported algorithm "%s"' % algorithm) @@ -195,7 +206,8 @@ def make_ds(name, key, algorithm, origin=None): dshash.update(_to_rdata(key, origin)) digest = dshash.digest() - dsrdata = struct.pack("!HBB", key_id(key), key.algorithm, dsalg) + digest + dsrdata = struct.pack("!HBB", key_id(key), key.algorithm, algorithm) + \ + digest return dns.rdata.from_wire(dns.rdataclass.IN, dns.rdatatype.DS, dsrdata, 0, len(dsrdata)) @@ -524,6 +536,12 @@ def _validate(rrset, rrsigset, keys, origin=None, now=None): raise ValidationFailure("no RRSIGs validated") +class NSEC3Hash(enum.IntEnum): + """NSEC3 hash algorithm""" + + SHA1 = 1 + + def nsec3_hash(domain, salt, iterations, algorithm): """ Calculate the NSEC3 hash, according to @@ -536,8 +554,8 @@ def nsec3_hash(domain, salt, iterations, algorithm): *iterations*, an ``int``, the number of iterations. - *algorithm*, an ``int``, the hash algorithm. The only defined algorithm - is SHA1. + *algorithm*, a ``str`` or ``int``, the hash algorithm. + The only defined algorithm is SHA1. Returns a ``str``, the encoded NSEC3 hash. """ @@ -546,7 +564,13 @@ def nsec3_hash(domain, salt, iterations, algorithm): "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567", "0123456789ABCDEFGHIJKLMNOPQRSTUV" ) - if algorithm != 1: + try: + if isinstance(algorithm, str): + algorithm = NSEC3Hash[algorithm.upper()] + except Exception: + raise ValueError("Wrong hash algorithm (only SHA1 is supported)") + + if algorithm != NSEC3Hash.SHA1: raise ValueError("Wrong hash algorithm (only SHA1 is supported)") salt_encoded = salt diff --git a/tests/test_dnssec.py b/tests/test_dnssec.py index f38637d8..c10becd1 100644 --- a/tests/test_dnssec.py +++ b/tests/test_dnssec.py @@ -281,21 +281,28 @@ class DNSSECValidatorTestCase(unittest.TestCase): class DNSSECMakeDSTestCase(unittest.TestCase): def testMakeExampleSHA1DS(self): # type: () -> None - ds = dns.dnssec.make_ds(abs_example, example_sep_key, 'SHA1') - self.assertEqual(ds, example_ds_sha1) + for algorithm in ('SHA1', 'sha1', dns.dnssec.DSDigest.SHA1): + ds = dns.dnssec.make_ds(abs_example, example_sep_key, algorithm) + self.assertEqual(ds, example_ds_sha1) def testMakeExampleSHA256DS(self): # type: () -> None - ds = dns.dnssec.make_ds(abs_example, example_sep_key, 'SHA256') - self.assertEqual(ds, example_ds_sha256) + for algorithm in ('SHA256', 'sha256', dns.dnssec.DSDigest.SHA256): + ds = dns.dnssec.make_ds(abs_example, example_sep_key, algorithm) + self.assertEqual(ds, example_ds_sha256) def testMakeExampleSHA384DS(self): # type: () -> None - ds = dns.dnssec.make_ds(abs_example, example_sep_key, 'SHA384') - self.assertEqual(ds, example_ds_sha384) + for algorithm in ('SHA384', 'sha384', dns.dnssec.DSDigest.SHA384): + ds = dns.dnssec.make_ds(abs_example, example_sep_key, algorithm) + self.assertEqual(ds, example_ds_sha384) def testMakeSHA256DS(self): # type: () -> None ds = dns.dnssec.make_ds(abs_dnspython_org, sep_key, 'SHA256') self.assertEqual(ds, good_ds) + def testInvalidAlgorithm(self): # type: () -> None + for algorithm in (10, 'shax'): + with self.assertRaises(dns.dnssec.UnsupportedAlgorithm): + ds = dns.dnssec.make_ds(abs_example, example_sep_key, algorithm) if __name__ == '__main__': unittest.main() diff --git a/tests/test_nsec3_hash.py b/tests/test_nsec3_hash.py index 6f18240d..0fd085cf 100644 --- a/tests/test_nsec3_hash.py +++ b/tests/test_nsec3_hash.py @@ -49,6 +49,15 @@ class NSEC3Hash(unittest.TestCase): 1, ), ("*.test-domain.dev", None, 45, "505k9g118d9sofnjhh54rr8fadgpa0ct", 1), + ( + "example", + "aabbccdd", + 12, + "0p9mhaveqvm6t7vbl5lop2u3t2rp3tom", + dnssec.NSEC3Hash.SHA1 + ), + ("example", "aabbccdd", 12, "0p9mhaveqvm6t7vbl5lop2u3t2rp3tom", "SHA1"), + ("example", "aabbccdd", 12, "0p9mhaveqvm6t7vbl5lop2u3t2rp3tom", "sha1") ] def test_hash_function(self): @@ -67,6 +76,21 @@ class NSEC3Hash(unittest.TestCase): with self.assertRaises(ValueError): hash = dnssec.nsec3_hash(data[0], data[1], data[2], data[4]) + def test_hash_invalid_algorithm(self): + data = ( + "example.com", + "9F1AB450CF71D", + 0, + "qfo2sv6jaej4cm11a3npoorfrckdao2c", + 1, + ) + with self.assertRaises(ValueError): + dnssec.nsec3_hash(data[0], data[1], data[2], 10) + with self.assertRaises(ValueError): + dnssec.nsec3_hash(data[0], data[1], data[2], "foo") + + + if __name__ == "__main__": unittest.main()