]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Improve consistency in DNSSEC code.
authorBrian Wellington <bwelling@xbill.org>
Fri, 15 May 2020 23:47:35 +0000 (16:47 -0700)
committerBrian Wellington <bwelling@xbill.org>
Fri, 15 May 2020 23:47:35 +0000 (16:47 -0700)
The make_ds method took its algorithm as a string, and the nsec3_hash
method took an algorithm as an int.  Change both of them to accept
either, and add enums for both sets of algorithms.

dns/dnssec.py
tests/test_dnssec.py
tests/test_nsec3_hash.py

index b1febd512214661d8ccee9c89ca8e7f4bd74454e..36694ffd3903f646ab145a3267d18d2a360ab30f 100644 (file)
@@ -17,6 +17,7 @@
 
 """Common DNSSEC-related functions and constants."""
 
+import enum
 import hashlib
 import io
 import struct
@@ -157,6 +158,13 @@ def key_id(key):
         total += ((total >> 16) & 0xffff)
         return total & 0xffff
 
+class DSDigest(enum.IntEnum):
+    """DNSSEC Delgation Signer Digest Algorithm"""
+
+    SHA1 = 1
+    SHA256 = 2
+    SHA384 = 4
+
 
 def make_ds(name, key, algorithm, origin=None):
     """Create a DS record for a DNSSEC key.
@@ -165,7 +173,7 @@ def make_ds(name, key, algorithm, origin=None):
 
     *key*, a ``dns.rdtypes.ANY.DNSKEY.DNSKEY``, the key the DS is about.
 
-    *algorithm*, a ``str`` specifying the hash algorithm.
+    *algorithm*, a ``str`` or ``int`` specifying the hash algorithm.
     The currently supported hashes are "SHA1", "SHA256", and "SHA384". Case
     does not matter for these strings.
 
@@ -177,14 +185,17 @@ def make_ds(name, key, algorithm, origin=None):
     Returns a ``dns.rdtypes.ANY.DS.DS``
     """
 
-    if algorithm.upper() == 'SHA1':
-        dsalg = 1
+    try:
+        if isinstance(algorithm, str):
+            algorithm = DSDigest[algorithm.upper()]
+    except Exception:
+        raise UnsupportedAlgorithm('unsupported algorithm "%s"' % algorithm)
+
+    if algorithm == DSDigest.SHA1:
         dshash = hashlib.sha1()
-    elif algorithm.upper() == 'SHA256':
-        dsalg = 2
+    elif algorithm == DSDigest.SHA256:
         dshash = hashlib.sha256()
-    elif algorithm.upper() == 'SHA384':
-        dsalg = 4
+    elif algorithm == DSDigest.SHA384:
         dshash = hashlib.sha384()
     else:
         raise UnsupportedAlgorithm('unsupported algorithm "%s"' % algorithm)
@@ -195,7 +206,8 @@ def make_ds(name, key, algorithm, origin=None):
     dshash.update(_to_rdata(key, origin))
     digest = dshash.digest()
 
-    dsrdata = struct.pack("!HBB", key_id(key), key.algorithm, dsalg) + digest
+    dsrdata = struct.pack("!HBB", key_id(key), key.algorithm, algorithm) + \
+            digest
     return dns.rdata.from_wire(dns.rdataclass.IN, dns.rdatatype.DS, dsrdata, 0,
                                len(dsrdata))
 
@@ -524,6 +536,12 @@ def _validate(rrset, rrsigset, keys, origin=None, now=None):
     raise ValidationFailure("no RRSIGs validated")
 
 
+class NSEC3Hash(enum.IntEnum):
+    """NSEC3 hash algorithm"""
+
+    SHA1 = 1
+
+
 def nsec3_hash(domain, salt, iterations, algorithm):
     """
     Calculate the NSEC3 hash, according to
@@ -536,8 +554,8 @@ def nsec3_hash(domain, salt, iterations, algorithm):
 
     *iterations*, an ``int``, the number of iterations.
 
-    *algorithm*, an ``int``, the hash algorithm.  The only defined algorithm
-    is SHA1.
+    *algorithm*, a ``str`` or ``int``, the hash algorithm.
+    The only defined algorithm is SHA1.
 
     Returns a ``str``, the encoded NSEC3 hash.
     """
@@ -546,7 +564,13 @@ def nsec3_hash(domain, salt, iterations, algorithm):
         "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567", "0123456789ABCDEFGHIJKLMNOPQRSTUV"
     )
 
-    if algorithm != 1:
+    try:
+        if isinstance(algorithm, str):
+            algorithm = NSEC3Hash[algorithm.upper()]
+    except Exception:
+        raise ValueError("Wrong hash algorithm (only SHA1 is supported)")
+
+    if algorithm != NSEC3Hash.SHA1:
         raise ValueError("Wrong hash algorithm (only SHA1 is supported)")
 
     salt_encoded = salt
index f38637d8a0f4a9bad66254749091cf21568822eb..c10becd178e55057e159160e4288b2e684886fee 100644 (file)
@@ -281,21 +281,28 @@ class DNSSECValidatorTestCase(unittest.TestCase):
 class DNSSECMakeDSTestCase(unittest.TestCase):
 
     def testMakeExampleSHA1DS(self):  # type: () -> None
-        ds = dns.dnssec.make_ds(abs_example, example_sep_key, 'SHA1')
-        self.assertEqual(ds, example_ds_sha1)
+        for algorithm in ('SHA1', 'sha1', dns.dnssec.DSDigest.SHA1):
+            ds = dns.dnssec.make_ds(abs_example, example_sep_key, algorithm)
+            self.assertEqual(ds, example_ds_sha1)
 
     def testMakeExampleSHA256DS(self):  # type: () -> None
-        ds = dns.dnssec.make_ds(abs_example, example_sep_key, 'SHA256')
-        self.assertEqual(ds, example_ds_sha256)
+        for algorithm in ('SHA256', 'sha256', dns.dnssec.DSDigest.SHA256):
+            ds = dns.dnssec.make_ds(abs_example, example_sep_key, algorithm)
+            self.assertEqual(ds, example_ds_sha256)
 
     def testMakeExampleSHA384DS(self):  # type: () -> None
-        ds = dns.dnssec.make_ds(abs_example, example_sep_key, 'SHA384')
-        self.assertEqual(ds, example_ds_sha384)
+        for algorithm in ('SHA384', 'sha384', dns.dnssec.DSDigest.SHA384):
+            ds = dns.dnssec.make_ds(abs_example, example_sep_key, algorithm)
+            self.assertEqual(ds, example_ds_sha384)
 
     def testMakeSHA256DS(self):  # type: () -> None
         ds = dns.dnssec.make_ds(abs_dnspython_org, sep_key, 'SHA256')
         self.assertEqual(ds, good_ds)
 
+    def testInvalidAlgorithm(self):  # type: () -> None
+        for algorithm in (10, 'shax'):
+            with self.assertRaises(dns.dnssec.UnsupportedAlgorithm):
+                ds = dns.dnssec.make_ds(abs_example, example_sep_key, algorithm)
 
 if __name__ == '__main__':
     unittest.main()
index 6f18240def75332acff25ed26864bda7ea58fa3a..0fd085cf1f6cf0d795e10f72f7caed2b7d26220f 100644 (file)
@@ -49,6 +49,15 @@ class NSEC3Hash(unittest.TestCase):
             1,
         ),
         ("*.test-domain.dev", None, 45, "505k9g118d9sofnjhh54rr8fadgpa0ct", 1),
+        (
+            "example",
+            "aabbccdd",
+            12,
+            "0p9mhaveqvm6t7vbl5lop2u3t2rp3tom",
+            dnssec.NSEC3Hash.SHA1
+        ),
+        ("example", "aabbccdd", 12, "0p9mhaveqvm6t7vbl5lop2u3t2rp3tom", "SHA1"),
+        ("example", "aabbccdd", 12, "0p9mhaveqvm6t7vbl5lop2u3t2rp3tom", "sha1")
     ]
 
     def test_hash_function(self):
@@ -67,6 +76,21 @@ class NSEC3Hash(unittest.TestCase):
         with self.assertRaises(ValueError):
             hash = dnssec.nsec3_hash(data[0], data[1], data[2], data[4])
 
+    def test_hash_invalid_algorithm(self):
+        data = (
+            "example.com",
+            "9F1AB450CF71D",
+            0,
+            "qfo2sv6jaej4cm11a3npoorfrckdao2c",
+            1,
+        )
+        with self.assertRaises(ValueError):
+            dnssec.nsec3_hash(data[0], data[1], data[2], 10)
+        with self.assertRaises(ValueError):
+            dnssec.nsec3_hash(data[0], data[1], data[2], "foo")
+
+
+
 
 if __name__ == "__main__":
     unittest.main()