"""Common DNSSEC-related functions and constants."""
+import enum
import hashlib
import io
import struct
total += ((total >> 16) & 0xffff)
return total & 0xffff
+class DSDigest(enum.IntEnum):
+ """DNSSEC Delgation Signer Digest Algorithm"""
+
+ SHA1 = 1
+ SHA256 = 2
+ SHA384 = 4
+
def make_ds(name, key, algorithm, origin=None):
"""Create a DS record for a DNSSEC key.
*key*, a ``dns.rdtypes.ANY.DNSKEY.DNSKEY``, the key the DS is about.
- *algorithm*, a ``str`` specifying the hash algorithm.
+ *algorithm*, a ``str`` or ``int`` specifying the hash algorithm.
The currently supported hashes are "SHA1", "SHA256", and "SHA384". Case
does not matter for these strings.
Returns a ``dns.rdtypes.ANY.DS.DS``
"""
- if algorithm.upper() == 'SHA1':
- dsalg = 1
+ try:
+ if isinstance(algorithm, str):
+ algorithm = DSDigest[algorithm.upper()]
+ except Exception:
+ raise UnsupportedAlgorithm('unsupported algorithm "%s"' % algorithm)
+
+ if algorithm == DSDigest.SHA1:
dshash = hashlib.sha1()
- elif algorithm.upper() == 'SHA256':
- dsalg = 2
+ elif algorithm == DSDigest.SHA256:
dshash = hashlib.sha256()
- elif algorithm.upper() == 'SHA384':
- dsalg = 4
+ elif algorithm == DSDigest.SHA384:
dshash = hashlib.sha384()
else:
raise UnsupportedAlgorithm('unsupported algorithm "%s"' % algorithm)
dshash.update(_to_rdata(key, origin))
digest = dshash.digest()
- dsrdata = struct.pack("!HBB", key_id(key), key.algorithm, dsalg) + digest
+ dsrdata = struct.pack("!HBB", key_id(key), key.algorithm, algorithm) + \
+ digest
return dns.rdata.from_wire(dns.rdataclass.IN, dns.rdatatype.DS, dsrdata, 0,
len(dsrdata))
raise ValidationFailure("no RRSIGs validated")
+class NSEC3Hash(enum.IntEnum):
+ """NSEC3 hash algorithm"""
+
+ SHA1 = 1
+
+
def nsec3_hash(domain, salt, iterations, algorithm):
"""
Calculate the NSEC3 hash, according to
*iterations*, an ``int``, the number of iterations.
- *algorithm*, an ``int``, the hash algorithm. The only defined algorithm
- is SHA1.
+ *algorithm*, a ``str`` or ``int``, the hash algorithm.
+ The only defined algorithm is SHA1.
Returns a ``str``, the encoded NSEC3 hash.
"""
"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567", "0123456789ABCDEFGHIJKLMNOPQRSTUV"
)
- if algorithm != 1:
+ try:
+ if isinstance(algorithm, str):
+ algorithm = NSEC3Hash[algorithm.upper()]
+ except Exception:
+ raise ValueError("Wrong hash algorithm (only SHA1 is supported)")
+
+ if algorithm != NSEC3Hash.SHA1:
raise ValueError("Wrong hash algorithm (only SHA1 is supported)")
salt_encoded = salt
class DNSSECMakeDSTestCase(unittest.TestCase):
def testMakeExampleSHA1DS(self): # type: () -> None
- ds = dns.dnssec.make_ds(abs_example, example_sep_key, 'SHA1')
- self.assertEqual(ds, example_ds_sha1)
+ for algorithm in ('SHA1', 'sha1', dns.dnssec.DSDigest.SHA1):
+ ds = dns.dnssec.make_ds(abs_example, example_sep_key, algorithm)
+ self.assertEqual(ds, example_ds_sha1)
def testMakeExampleSHA256DS(self): # type: () -> None
- ds = dns.dnssec.make_ds(abs_example, example_sep_key, 'SHA256')
- self.assertEqual(ds, example_ds_sha256)
+ for algorithm in ('SHA256', 'sha256', dns.dnssec.DSDigest.SHA256):
+ ds = dns.dnssec.make_ds(abs_example, example_sep_key, algorithm)
+ self.assertEqual(ds, example_ds_sha256)
def testMakeExampleSHA384DS(self): # type: () -> None
- ds = dns.dnssec.make_ds(abs_example, example_sep_key, 'SHA384')
- self.assertEqual(ds, example_ds_sha384)
+ for algorithm in ('SHA384', 'sha384', dns.dnssec.DSDigest.SHA384):
+ ds = dns.dnssec.make_ds(abs_example, example_sep_key, algorithm)
+ self.assertEqual(ds, example_ds_sha384)
def testMakeSHA256DS(self): # type: () -> None
ds = dns.dnssec.make_ds(abs_dnspython_org, sep_key, 'SHA256')
self.assertEqual(ds, good_ds)
+ def testInvalidAlgorithm(self): # type: () -> None
+ for algorithm in (10, 'shax'):
+ with self.assertRaises(dns.dnssec.UnsupportedAlgorithm):
+ ds = dns.dnssec.make_ds(abs_example, example_sep_key, algorithm)
if __name__ == '__main__':
unittest.main()
1,
),
("*.test-domain.dev", None, 45, "505k9g118d9sofnjhh54rr8fadgpa0ct", 1),
+ (
+ "example",
+ "aabbccdd",
+ 12,
+ "0p9mhaveqvm6t7vbl5lop2u3t2rp3tom",
+ dnssec.NSEC3Hash.SHA1
+ ),
+ ("example", "aabbccdd", 12, "0p9mhaveqvm6t7vbl5lop2u3t2rp3tom", "SHA1"),
+ ("example", "aabbccdd", 12, "0p9mhaveqvm6t7vbl5lop2u3t2rp3tom", "sha1")
]
def test_hash_function(self):
with self.assertRaises(ValueError):
hash = dnssec.nsec3_hash(data[0], data[1], data[2], data[4])
+ def test_hash_invalid_algorithm(self):
+ data = (
+ "example.com",
+ "9F1AB450CF71D",
+ 0,
+ "qfo2sv6jaej4cm11a3npoorfrckdao2c",
+ 1,
+ )
+ with self.assertRaises(ValueError):
+ dnssec.nsec3_hash(data[0], data[1], data[2], 10)
+ with self.assertRaises(ValueError):
+ dnssec.nsec3_hash(data[0], data[1], data[2], "foo")
+
+
+
if __name__ == "__main__":
unittest.main()