]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
DNSSEC Algorithm Refactor (#944)
authorJakob Schlyter <jakob@kirei.se>
Sun, 25 Jun 2023 21:01:00 +0000 (23:01 +0200)
committerGitHub <noreply@github.com>
Sun, 25 Jun 2023 21:01:00 +0000 (14:01 -0700)
* Split DNSSEC algorithms into separate classes with a registration mechanism.
* Add DNSSEC private algorithm support.

dns/dnssec.py
dns/dnssecalgs/__init__.py [new file with mode: 0644]
dns/dnssecalgs/base.py [new file with mode: 0644]
dns/dnssecalgs/cryptography.py [new file with mode: 0644]
dns/dnssecalgs/dsa.py [new file with mode: 0644]
dns/dnssecalgs/ecdsa.py [new file with mode: 0644]
dns/dnssecalgs/eddsa.py [new file with mode: 0644]
dns/dnssecalgs/rsa.py [new file with mode: 0644]
dns/exception.py
tests/test_dnssec.py
tests/test_dnssecalgs.py [new file with mode: 0644]

index 55fd7b57d4f98fee07337e22ff13a60ce9869e4d..d9b8d98d7358240cf63bbcadad3f60f4e9b21a65 100644 (file)
 """Common DNSSEC-related functions and constants."""
 
 
-from typing import Any, cast, Callable, Dict, List, Optional, Set, Tuple, Union
+from typing import cast, Callable, Dict, List, Optional, Set, Tuple, Union
 
 import contextlib
 import functools
 import hashlib
-import math
 import struct
 import time
 import base64
@@ -41,6 +40,12 @@ import dns.rdataclass
 import dns.rrset
 import dns.transaction
 import dns.zone
+from dns.exception import (  # pylint: disable=W0611
+    AlgorithmKeyMismatch,
+    DeniedByPolicy,
+    UnsupportedAlgorithm,
+    ValidationFailure,
+)
 from dns.rdtypes.ANY.CDNSKEY import CDNSKEY
 from dns.rdtypes.ANY.CDS import CDS
 from dns.rdtypes.ANY.DNSKEY import DNSKEY
@@ -51,23 +56,8 @@ from dns.rdtypes.ANY.RRSIG import RRSIG, sigtime_to_posixtime
 from dns.rdtypes.dnskeybase import Flag
 
 
-class UnsupportedAlgorithm(dns.exception.DNSException):
-    """The DNSSEC algorithm is not supported."""
-
-
-class AlgorithmKeyMismatch(UnsupportedAlgorithm):
-    """The DNSSEC algorithm is not supported for the given key type."""
-
-
-class ValidationFailure(dns.exception.DNSException):
-    """The DNSSEC signature is invalid."""
-
-
-class DeniedByPolicy(dns.exception.DNSException):
-    """Denied by DNSSEC policy."""
-
-
 PublicKey = Union[
+    "GenericPublicKey",
     "rsa.RSAPublicKey",
     "ec.EllipticCurvePublicKey",
     "ed25519.Ed25519PublicKey",
@@ -75,6 +65,7 @@ PublicKey = Union[
 ]
 
 PrivateKey = Union[
+    "GenericPrivateKey",
     "rsa.RSAPrivateKey",
     "ec.EllipticCurvePrivateKey",
     "ed25519.Ed25519PrivateKey",
@@ -321,109 +312,6 @@ def _find_candidate_keys(
     ]
 
 
-def _is_rsa(algorithm: int) -> bool:
-    return algorithm in (
-        Algorithm.RSAMD5,
-        Algorithm.RSASHA1,
-        Algorithm.RSASHA1NSEC3SHA1,
-        Algorithm.RSASHA256,
-        Algorithm.RSASHA512,
-    )
-
-
-def _is_dsa(algorithm: int) -> bool:
-    return algorithm in (Algorithm.DSA, Algorithm.DSANSEC3SHA1)
-
-
-def _is_ecdsa(algorithm: int) -> bool:
-    return algorithm in (Algorithm.ECDSAP256SHA256, Algorithm.ECDSAP384SHA384)
-
-
-def _is_eddsa(algorithm: int) -> bool:
-    return algorithm in (Algorithm.ED25519, Algorithm.ED448)
-
-
-def _is_gost(algorithm: int) -> bool:
-    return algorithm == Algorithm.ECCGOST
-
-
-def _is_md5(algorithm: int) -> bool:
-    return algorithm == Algorithm.RSAMD5
-
-
-def _is_sha1(algorithm: int) -> bool:
-    return algorithm in (
-        Algorithm.DSA,
-        Algorithm.RSASHA1,
-        Algorithm.DSANSEC3SHA1,
-        Algorithm.RSASHA1NSEC3SHA1,
-    )
-
-
-def _is_sha256(algorithm: int) -> bool:
-    return algorithm in (Algorithm.RSASHA256, Algorithm.ECDSAP256SHA256)
-
-
-def _is_sha384(algorithm: int) -> bool:
-    return algorithm == Algorithm.ECDSAP384SHA384
-
-
-def _is_sha512(algorithm: int) -> bool:
-    return algorithm == Algorithm.RSASHA512
-
-
-def _ensure_algorithm_key_combination(algorithm: int, key: PublicKey) -> None:
-    """Ensure algorithm is valid for key type, throwing an exception on
-    mismatch."""
-    if isinstance(key, rsa.RSAPublicKey):
-        if _is_rsa(algorithm):
-            return
-        raise AlgorithmKeyMismatch('algorithm "%s" not valid for RSA key' % algorithm)
-    if isinstance(key, dsa.DSAPublicKey):
-        if _is_dsa(algorithm):
-            return
-        raise AlgorithmKeyMismatch('algorithm "%s" not valid for DSA key' % algorithm)
-    if isinstance(key, ec.EllipticCurvePublicKey):
-        if _is_ecdsa(algorithm):
-            return
-        raise AlgorithmKeyMismatch('algorithm "%s" not valid for ECDSA key' % algorithm)
-    if isinstance(key, ed25519.Ed25519PublicKey):
-        if algorithm == Algorithm.ED25519:
-            return
-        raise AlgorithmKeyMismatch(
-            'algorithm "%s" not valid for ED25519 key' % algorithm
-        )
-    if isinstance(key, ed448.Ed448PublicKey):
-        if algorithm == Algorithm.ED448:
-            return
-        raise AlgorithmKeyMismatch('algorithm "%s" not valid for ED448 key' % algorithm)
-
-    raise TypeError("unsupported key type")
-
-
-def _make_hash(algorithm: int) -> Any:
-    if _is_md5(algorithm):
-        return hashes.MD5()
-    if _is_sha1(algorithm):
-        return hashes.SHA1()
-    if _is_sha256(algorithm):
-        return hashes.SHA256()
-    if _is_sha384(algorithm):
-        return hashes.SHA384()
-    if _is_sha512(algorithm):
-        return hashes.SHA512()
-    if algorithm == Algorithm.ED25519:
-        return hashes.SHA512()
-    if algorithm == Algorithm.ED448:
-        return hashes.SHAKE256(114)
-
-    raise ValidationFailure("unknown hash for algorithm %u" % algorithm)
-
-
-def _bytes_to_long(b: bytes) -> int:
-    return int.from_bytes(b, "big")
-
-
 def _get_rrname_rdataset(
     rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rdataset]],
 ) -> Tuple[dns.name.Name, dns.rdataset.Rdataset]:
@@ -433,85 +321,13 @@ def _get_rrname_rdataset(
         return rrset.name, rrset
 
 
-def _validate_signature(sig: bytes, data: bytes, key: DNSKEY, chosen_hash: Any) -> None:
-    keyptr: bytes
-    if _is_rsa(key.algorithm):
-        # we ignore because mypy is confused and thinks key.key is a str for unknown
-        # reasons.
-        keyptr = key.key
-        (bytes_,) = struct.unpack("!B", keyptr[0:1])
-        keyptr = keyptr[1:]
-        if bytes_ == 0:
-            (bytes_,) = struct.unpack("!H", keyptr[0:2])
-            keyptr = keyptr[2:]
-        rsa_e = keyptr[0:bytes_]
-        rsa_n = keyptr[bytes_:]
-        try:
-            rsa_public_key = rsa.RSAPublicNumbers(
-                _bytes_to_long(rsa_e), _bytes_to_long(rsa_n)
-            ).public_key(default_backend())
-        except ValueError:
-            raise ValidationFailure("invalid public key")
-        rsa_public_key.verify(sig, data, padding.PKCS1v15(), chosen_hash)
-    elif _is_dsa(key.algorithm):
-        keyptr = key.key
-        (t,) = struct.unpack("!B", keyptr[0:1])
-        keyptr = keyptr[1:]
-        octets = 64 + t * 8
-        dsa_q = keyptr[0:20]
-        keyptr = keyptr[20:]
-        dsa_p = keyptr[0:octets]
-        keyptr = keyptr[octets:]
-        dsa_g = keyptr[0:octets]
-        keyptr = keyptr[octets:]
-        dsa_y = keyptr[0:octets]
-        try:
-            dsa_public_key = dsa.DSAPublicNumbers(  # type: ignore
-                _bytes_to_long(dsa_y),
-                dsa.DSAParameterNumbers(
-                    _bytes_to_long(dsa_p), _bytes_to_long(dsa_q), _bytes_to_long(dsa_g)
-                ),
-            ).public_key(default_backend())
-        except ValueError:
-            raise ValidationFailure("invalid public key")
-        dsa_public_key.verify(sig, data, chosen_hash)
-    elif _is_ecdsa(key.algorithm):
-        keyptr = key.key
-        curve: Any
-        if key.algorithm == Algorithm.ECDSAP256SHA256:
-            curve = ec.SECP256R1()
-            octets = 32
-        else:
-            curve = ec.SECP384R1()
-            octets = 48
-        ecdsa_x = keyptr[0:octets]
-        ecdsa_y = keyptr[octets : octets * 2]
-        try:
-            ecdsa_public_key = ec.EllipticCurvePublicNumbers(
-                curve=curve, x=_bytes_to_long(ecdsa_x), y=_bytes_to_long(ecdsa_y)
-            ).public_key(default_backend())
-        except ValueError:
-            raise ValidationFailure("invalid public key")
-        ecdsa_public_key.verify(sig, data, ec.ECDSA(chosen_hash))
-    elif _is_eddsa(key.algorithm):
-        keyptr = key.key
-        loader: Any
-        if key.algorithm == Algorithm.ED25519:
-            loader = ed25519.Ed25519PublicKey
-        else:
-            loader = ed448.Ed448PublicKey
-        try:
-            eddsa_public_key = loader.from_public_bytes(keyptr)
-        except ValueError:
-            raise ValidationFailure("invalid public key")
-        eddsa_public_key.verify(sig, data)
-    elif _is_gost(key.algorithm):
-        raise UnsupportedAlgorithm(
-            'algorithm "%s" not supported by dnspython'
-            % algorithm_to_text(key.algorithm)
-        )
-    else:
-        raise ValidationFailure("unknown algorithm %u" % key.algorithm)
+def _validate_signature(sig: bytes, data: bytes, key: DNSKEY) -> None:
+    public_cls = get_algorithm_cls_from_dnskey(key).public_cls
+    try:
+        public_key = public_cls.from_dnskey(key)
+    except ValueError:
+        raise ValidationFailure("invalid public key")
+    public_key.verify(sig, data)
 
 
 def _validate_rrsig(
@@ -568,29 +384,13 @@ def _validate_rrsig(
     if rrsig.inception > now:
         raise ValidationFailure("not yet valid")
 
-    if _is_dsa(rrsig.algorithm):
-        sig_r = rrsig.signature[1:21]
-        sig_s = rrsig.signature[21:]
-        sig = utils.encode_dss_signature(_bytes_to_long(sig_r), _bytes_to_long(sig_s))
-    elif _is_ecdsa(rrsig.algorithm):
-        if rrsig.algorithm == Algorithm.ECDSAP256SHA256:
-            octets = 32
-        else:
-            octets = 48
-        sig_r = rrsig.signature[0:octets]
-        sig_s = rrsig.signature[octets:]
-        sig = utils.encode_dss_signature(_bytes_to_long(sig_r), _bytes_to_long(sig_s))
-    else:
-        sig = rrsig.signature
-
     data = _make_rrsig_signature_data(rrset, rrsig, origin)
-    chosen_hash = _make_hash(rrsig.algorithm)
 
     for candidate_key in candidate_keys:
         if not policy.ok_to_validate(candidate_key):
             continue
         try:
-            _validate_signature(sig, data, candidate_key, chosen_hash)
+            _validate_signature(rrsig.signature, data, candidate_key)
             return
         except (InvalidSignature, ValidationFailure):
             # this happens on an individual validation failure
@@ -778,62 +578,17 @@ def _sign(
     )
 
     data = dns.dnssec._make_rrsig_signature_data(rrset, rrsig_template, origin)
-    chosen_hash = _make_hash(rrsig_template.algorithm)
-    signature = None
-
-    if isinstance(private_key, rsa.RSAPrivateKey):
-        if not _is_rsa(dnskey.algorithm):
-            raise ValueError("Invalid DNSKEY algorithm for RSA key")
-        signature = private_key.sign(data, padding.PKCS1v15(), chosen_hash)
-        if verify:
-            private_key.public_key().verify(
-                signature, data, padding.PKCS1v15(), chosen_hash
-            )
-    elif isinstance(private_key, dsa.DSAPrivateKey):
-        if not _is_dsa(dnskey.algorithm):
-            raise ValueError("Invalid DNSKEY algorithm for DSA key")
-        public_dsa_key = private_key.public_key()
-        if public_dsa_key.key_size > 1024:
-            raise ValueError("DSA key size overflow")
-        der_signature = private_key.sign(data, chosen_hash)
-        if verify:
-            public_dsa_key.verify(der_signature, data, chosen_hash)
-        dsa_r, dsa_s = utils.decode_dss_signature(der_signature)
-        dsa_t = (public_dsa_key.key_size // 8 - 64) // 8
-        octets = 20
-        signature = (
-            struct.pack("!B", dsa_t)
-            + int.to_bytes(dsa_r, length=octets, byteorder="big")
-            + int.to_bytes(dsa_s, length=octets, byteorder="big")
-        )
-    elif isinstance(private_key, ec.EllipticCurvePrivateKey):
-        if not _is_ecdsa(dnskey.algorithm):
-            raise ValueError("Invalid DNSKEY algorithm for EC key")
-        der_signature = private_key.sign(data, ec.ECDSA(chosen_hash))
-        if verify:
-            private_key.public_key().verify(der_signature, data, ec.ECDSA(chosen_hash))
-        if dnskey.algorithm == Algorithm.ECDSAP256SHA256:
-            octets = 32
-        else:
-            octets = 48
-        dsa_r, dsa_s = utils.decode_dss_signature(der_signature)
-        signature = int.to_bytes(dsa_r, length=octets, byteorder="big") + int.to_bytes(
-            dsa_s, length=octets, byteorder="big"
-        )
-    elif isinstance(private_key, ed25519.Ed25519PrivateKey):
-        if dnskey.algorithm != Algorithm.ED25519:
-            raise ValueError("Invalid DNSKEY algorithm for ED25519 key")
-        signature = private_key.sign(data)
-        if verify:
-            private_key.public_key().verify(signature, data)
-    elif isinstance(private_key, ed448.Ed448PrivateKey):
-        if dnskey.algorithm != Algorithm.ED448:
-            raise ValueError("Invalid DNSKEY algorithm for ED448 key")
-        signature = private_key.sign(data)
-        if verify:
-            private_key.public_key().verify(signature, data)
+
+    if isinstance(private_key, GenericPrivateKey):
+        signing_key = private_key
     else:
-        raise TypeError("Unsupported key algorithm")
+        try:
+            private_cls = get_algorithm_cls_from_dnskey(dnskey)
+            signing_key = private_cls(key=private_key)
+        except UnsupportedAlgorithm:
+            raise TypeError("Unsupported key algorithm")
+
+    signature = signing_key.sign(data, verify)
 
     return cast(RRSIG, rrsig_template.replace(signature=signature))
 
@@ -911,9 +666,8 @@ def _make_dnskey(
 ) -> DNSKEY:
     """Convert a public key to DNSKEY Rdata
 
-    *public_key*, the public key to convert, a
-    ``cryptography.hazmat.primitives.asymmetric`` public key class applicable
-    for DNSSEC.
+    *public_key*, a ``PublicKey`` (``GenericPublicKey`` or
+    ``cryptography.hazmat.primitives.asymmetric``) to convert.
 
     *algorithm*, a ``str`` or ``int`` specifying the DNSKEY algorithm.
 
@@ -929,72 +683,13 @@ def _make_dnskey(
     Return DNSKEY ``Rdata``.
     """
 
-    def encode_rsa_public_key(public_key: "rsa.RSAPublicKey") -> bytes:
-        """Encode a public key per RFC 3110, section 2."""
-        pn = public_key.public_numbers()
-        _exp_len = math.ceil(int.bit_length(pn.e) / 8)
-        exp = int.to_bytes(pn.e, length=_exp_len, byteorder="big")
-        if _exp_len > 255:
-            exp_header = b"\0" + struct.pack("!H", _exp_len)
-        else:
-            exp_header = struct.pack("!B", _exp_len)
-        if pn.n.bit_length() < 512 or pn.n.bit_length() > 4096:
-            raise ValueError("unsupported RSA key length")
-        return exp_header + exp + pn.n.to_bytes((pn.n.bit_length() + 7) // 8, "big")
-
-    def encode_dsa_public_key(public_key: "dsa.DSAPublicKey") -> bytes:
-        """Encode a public key per RFC 2536, section 2."""
-        pn = public_key.public_numbers()
-        dsa_t = (public_key.key_size // 8 - 64) // 8
-        if dsa_t > 8:
-            raise ValueError("unsupported DSA key size")
-        octets = 64 + dsa_t * 8
-        res = struct.pack("!B", dsa_t)
-        res += pn.parameter_numbers.q.to_bytes(20, "big")
-        res += pn.parameter_numbers.p.to_bytes(octets, "big")
-        res += pn.parameter_numbers.g.to_bytes(octets, "big")
-        res += pn.y.to_bytes(octets, "big")
-        return res
-
-    def encode_ecdsa_public_key(public_key: "ec.EllipticCurvePublicKey") -> bytes:
-        """Encode a public key per RFC 6605, section 4."""
-        pn = public_key.public_numbers()
-        if isinstance(public_key.curve, ec.SECP256R1):
-            return pn.x.to_bytes(32, "big") + pn.y.to_bytes(32, "big")
-        elif isinstance(public_key.curve, ec.SECP384R1):
-            return pn.x.to_bytes(48, "big") + pn.y.to_bytes(48, "big")
-        else:
-            raise ValueError("unsupported ECDSA curve")
-
     algorithm = Algorithm.make(algorithm)
 
-    _ensure_algorithm_key_combination(algorithm, public_key)
-
-    if isinstance(public_key, rsa.RSAPublicKey):
-        key_bytes = encode_rsa_public_key(public_key)
-    elif isinstance(public_key, dsa.DSAPublicKey):
-        key_bytes = encode_dsa_public_key(public_key)
-    elif isinstance(public_key, ec.EllipticCurvePublicKey):
-        key_bytes = encode_ecdsa_public_key(public_key)
-    elif isinstance(public_key, ed25519.Ed25519PublicKey):
-        key_bytes = public_key.public_bytes(
-            encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw
-        )
-    elif isinstance(public_key, ed448.Ed448PublicKey):
-        key_bytes = public_key.public_bytes(
-            encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw
-        )
+    if isinstance(public_key, GenericPublicKey):
+        return public_key.to_dnskey(flags=flags, protocol=protocol)
     else:
-        raise TypeError("unsupported key algorithm")
-
-    return DNSKEY(
-        rdclass=dns.rdataclass.IN,
-        rdtype=dns.rdatatype.DNSKEY,
-        flags=flags,
-        protocol=protocol,
-        algorithm=algorithm,
-        key=key_bytes,
-    )
+        public_cls = get_algorithm_cls(algorithm).public_cls
+        return public_cls(key=public_key).to_dnskey(flags=flags, protocol=protocol)
 
 
 def _make_cdnskey(
@@ -1476,15 +1171,18 @@ def _need_pyca(*args, **kwargs):
 
 try:
     from cryptography.exceptions import InvalidSignature
-    from cryptography.hazmat.backends import default_backend
-    from cryptography.hazmat.primitives import hashes, serialization
-    from cryptography.hazmat.primitives.asymmetric import padding
-    from cryptography.hazmat.primitives.asymmetric import utils
-    from cryptography.hazmat.primitives.asymmetric import dsa
-    from cryptography.hazmat.primitives.asymmetric import ec
-    from cryptography.hazmat.primitives.asymmetric import ed25519
-    from cryptography.hazmat.primitives.asymmetric import ed448
-    from cryptography.hazmat.primitives.asymmetric import rsa
+    from cryptography.hazmat.primitives.asymmetric import dsa  # pylint: disable=W0611
+    from cryptography.hazmat.primitives.asymmetric import ec  # pylint: disable=W0611
+    from cryptography.hazmat.primitives.asymmetric import (  # pylint: disable=W0611
+        ed25519,
+    )
+    from cryptography.hazmat.primitives.asymmetric import ed448  # pylint: disable=W0611
+    from cryptography.hazmat.primitives.asymmetric import rsa  # pylint: disable=W0611
+    from dns.dnssecalgs import (  # pylint: disable=C0412
+        get_algorithm_cls,
+        get_algorithm_cls_from_dnskey,
+    )
+    from dns.dnssecalgs.base import GenericPrivateKey, GenericPublicKey
 except ImportError:  # pragma: no cover
     validate = _need_pyca
     validate_rrsig = _need_pyca
diff --git a/dns/dnssecalgs/__init__.py b/dns/dnssecalgs/__init__.py
new file mode 100644 (file)
index 0000000..d4c89cd
--- /dev/null
@@ -0,0 +1,111 @@
+from typing import Dict, Optional, Tuple, Type, Union
+
+import dns.name
+from dns.dnssecalgs.base import GenericPrivateKey
+from dns.dnssecalgs.dsa import PrivateDSA, PrivateDSANSEC3SHA1
+from dns.dnssecalgs.ecdsa import PrivateECDSAP256SHA256, PrivateECDSAP384SHA384
+from dns.dnssecalgs.eddsa import PrivateED448, PrivateED25519
+from dns.dnssecalgs.rsa import (
+    PrivateRSAMD5,
+    PrivateRSASHA1,
+    PrivateRSASHA1NSEC3SHA1,
+    PrivateRSASHA256,
+    PrivateRSASHA512,
+)
+from dns.dnssectypes import Algorithm
+from dns.exception import UnsupportedAlgorithm
+from dns.rdtypes.ANY.DNSKEY import DNSKEY
+
+
+AlgorithmPrefix = Optional[Union[bytes, dns.name.Name]]
+
+algorithms: Dict[Tuple[Algorithm, AlgorithmPrefix], Type[GenericPrivateKey]] = {
+    (Algorithm.RSAMD5, None): PrivateRSAMD5,
+    (Algorithm.DSA, None): PrivateDSA,
+    (Algorithm.RSASHA1, None): PrivateRSASHA1,
+    (Algorithm.DSANSEC3SHA1, None): PrivateDSANSEC3SHA1,
+    (Algorithm.RSASHA1NSEC3SHA1, None): PrivateRSASHA1NSEC3SHA1,
+    (Algorithm.RSASHA256, None): PrivateRSASHA256,
+    (Algorithm.RSASHA512, None): PrivateRSASHA512,
+    (Algorithm.ECDSAP256SHA256, None): PrivateECDSAP256SHA256,
+    (Algorithm.ECDSAP384SHA384, None): PrivateECDSAP384SHA384,
+    (Algorithm.ED25519, None): PrivateED25519,
+    (Algorithm.ED448, None): PrivateED448,
+}
+
+
+def get_algorithm_cls(
+    algorithm: Union[int, str], prefix: AlgorithmPrefix = None
+) -> Type[GenericPrivateKey]:
+    """Get Private Key class from Algorithm.
+
+    *algorithm*, a ``str`` or ``int`` specifying the DNSKEY algorithm.
+
+    Raises ``UnsupportedAlgorithm`` if the algorithm is unknown.
+
+    Returns a ``dns.dnssecalgs.GenericPrivateKey``
+    """
+    algorithm = Algorithm.make(algorithm)
+    cls = algorithms.get((algorithm, prefix))
+    if cls:
+        return cls
+    raise UnsupportedAlgorithm(
+        'algorithm "%s" not supported by dnspython' % Algorithm.to_text(algorithm)
+    )
+
+
+def get_algorithm_cls_from_dnskey(dnskey: DNSKEY) -> Type[GenericPrivateKey]:
+    """Get Private Key class from DNSKEY.
+
+    *dnskey*, a ``DNSKEY`` to get Algorithm class for.
+
+    Raises ``UnsupportedAlgorithm`` if the algorithm is unknown.
+
+    Returns a ``dns.dnssecalgs.GenericPrivateKey``
+    """
+    prefix: AlgorithmPrefix = None
+    if dnskey.algorithm == Algorithm.PRIVATEDNS:
+        prefix, _ = dns.name.from_wire(dnskey.key, 0)
+    elif dnskey.algorithm == Algorithm.PRIVATEOID:
+        length = int(dnskey.key[0])
+        prefix = dnskey.key[0 : length + 1]
+    return get_algorithm_cls(dnskey.algorithm, prefix)
+
+
+def register_algorithm_cls(
+    algorithm: Union[int, str],
+    algorithm_cls: Type[GenericPrivateKey],
+    name: Optional[Union[dns.name.Name, str]] = None,
+    oid: Optional[bytes] = None,
+) -> None:
+    """Register Algorithm Private Key class.
+
+    *algorithm*, a ``str`` or ``int`` specifying the DNSKEY algorithm.
+
+    *algorithm_cls*: A `GenericPrivateKey` class.
+
+    *name*, an optional ``dns.name.Name`` or ``str``, for for PRIVATEDNS algorithms.
+
+    *oid*: an optional BER-encoded `bytes` for PRIVATEOID algorithms.
+
+    Raises ``ValueError`` if a name or oid is specified incorrectly.
+    """
+    if not issubclass(algorithm_cls, GenericPrivateKey):
+        raise TypeError("Invalid algorithm class")
+    algorithm = Algorithm.make(algorithm)
+    prefix: AlgorithmPrefix = None
+    if algorithm == Algorithm.PRIVATEDNS:
+        if name is None:
+            raise ValueError("Name required for PRIVATEDNS algorithms")
+        if isinstance(name, str):
+            name = dns.name.from_text(name)
+        prefix = name
+    elif algorithm == Algorithm.PRIVATEOID:
+        if oid is None:
+            raise ValueError("OID required for PRIVATEOID algorithms")
+        prefix = bytes([len(oid)]) + oid
+    elif name:
+        raise ValueError("Name only supported for PRIVATEDNS algorithm")
+    elif oid:
+        raise ValueError("OID only supported for PRIVATEOID algorithm")
+    algorithms[(algorithm, prefix)] = algorithm_cls
diff --git a/dns/dnssecalgs/base.py b/dns/dnssecalgs/base.py
new file mode 100644 (file)
index 0000000..c70b081
--- /dev/null
@@ -0,0 +1,84 @@
+from abc import ABC, abstractmethod
+from typing import Optional, Type, Any
+
+import dns.rdataclass
+import dns.rdatatype
+from dns.dnssectypes import Algorithm
+from dns.exception import AlgorithmKeyMismatch
+from dns.rdtypes.ANY.DNSKEY import DNSKEY
+from dns.rdtypes.dnskeybase import Flag
+
+
+class GenericPublicKey(ABC):
+    algorithm: Algorithm
+
+    @abstractmethod
+    def __init__(self, key: Any) -> None:
+        pass
+
+    @abstractmethod
+    def verify(self, signature: bytes, data: bytes) -> None:
+        """Verify signed DNSSEC data"""
+
+    @abstractmethod
+    def encode_key_bytes(self) -> bytes:
+        """Encode key as bytes for DNSKEY"""
+
+    @classmethod
+    def _ensure_algorithm_key_combination(cls, key: DNSKEY) -> None:
+        if key.algorithm != cls.algorithm:
+            raise AlgorithmKeyMismatch
+
+    def to_dnskey(self, flags: int = Flag.ZONE, protocol: int = 3) -> DNSKEY:
+        """Return public key as DNSKEY"""
+        return DNSKEY(
+            rdclass=dns.rdataclass.IN,
+            rdtype=dns.rdatatype.DNSKEY,
+            flags=flags,
+            protocol=protocol,
+            algorithm=self.algorithm,
+            key=self.encode_key_bytes(),
+        )
+
+    @classmethod
+    @abstractmethod
+    def from_dnskey(cls, key: DNSKEY) -> "GenericPublicKey":
+        """Create public key from DNSKEY"""
+
+    @classmethod
+    @abstractmethod
+    def from_pem(cls, public_pem: bytes) -> "GenericPublicKey":
+        """Create public key from PEM-encoded SubjectPublicKeyInfo as specified
+        in RFC 5280"""
+
+    @abstractmethod
+    def to_pem(self) -> bytes:
+        """Return public-key as PEM-encoded SubjectPublicKeyInfo as specified
+        in RFC 5280"""
+
+
+class GenericPrivateKey(ABC):
+    public_cls: Type[GenericPublicKey]
+
+    @abstractmethod
+    def __init__(self, key: Any) -> None:
+        pass
+
+    @abstractmethod
+    def sign(self, data: bytes, verify: bool = False) -> bytes:
+        """Sign DNSSEC data"""
+
+    @abstractmethod
+    def public_key(self) -> "GenericPublicKey":
+        """Return public key instance"""
+
+    @classmethod
+    @abstractmethod
+    def from_pem(
+        cls, private_pem: bytes, password: Optional[bytes] = None
+    ) -> "GenericPrivateKey":
+        """Create private key from PEM-encoded PKCS#8"""
+
+    @abstractmethod
+    def to_pem(self, password: Optional[bytes] = None) -> bytes:
+        """Return private key as PEM-encoded PKCS#8"""
diff --git a/dns/dnssecalgs/cryptography.py b/dns/dnssecalgs/cryptography.py
new file mode 100644 (file)
index 0000000..b5bcd2e
--- /dev/null
@@ -0,0 +1,64 @@
+from typing import Any, Optional, Type
+
+from cryptography.hazmat.primitives import serialization
+
+from dns.dnssecalgs.base import GenericPrivateKey, GenericPublicKey
+from dns.exception import AlgorithmKeyMismatch
+
+
+class CryptographyPublicKey(GenericPublicKey):
+    key: Any = None
+    key_cls: Any = None
+
+    def __init__(self, key: Any) -> None:
+        if self.key_cls is None:
+            raise TypeError("Undefined private key class")
+        if not isinstance(key, self.key_cls):
+            raise AlgorithmKeyMismatch
+        self.key = key
+
+    @classmethod
+    def from_pem(cls, public_pem: bytes) -> "GenericPublicKey":
+        key = serialization.load_pem_public_key(public_pem)
+        return cls(key=key)
+
+    def to_pem(self) -> bytes:
+        return self.key.public_bytes(
+            encoding=serialization.Encoding.PEM,
+            format=serialization.PublicFormat.SubjectPublicKeyInfo,
+        )
+
+
+class CryptographyPrivateKey(GenericPrivateKey):
+    key: Any = None
+    key_cls: Any = None
+    public_cls: Type[CryptographyPublicKey]
+
+    def __init__(self, key: Any) -> None:
+        if self.key_cls is None:
+            raise TypeError("Undefined private key class")
+        if not isinstance(key, self.key_cls):
+            raise AlgorithmKeyMismatch
+        self.key = key
+
+    def public_key(self) -> "CryptographyPublicKey":
+        return self.public_cls(key=self.key.public_key())
+
+    @classmethod
+    def from_pem(
+        cls, private_pem: bytes, password: Optional[bytes] = None
+    ) -> "GenericPrivateKey":
+        key = serialization.load_pem_private_key(private_pem, password=password)
+        return cls(key=key)
+
+    def to_pem(self, password: Optional[bytes] = None) -> bytes:
+        encryption_algorithm: serialization.KeySerializationEncryption
+        if password:
+            encryption_algorithm = serialization.BestAvailableEncryption(password)
+        else:
+            encryption_algorithm = serialization.NoEncryption()
+        return self.key.private_bytes(
+            encoding=serialization.Encoding.PEM,
+            format=serialization.PrivateFormat.PKCS8,
+            encryption_algorithm=encryption_algorithm,
+        )
diff --git a/dns/dnssecalgs/dsa.py b/dns/dnssecalgs/dsa.py
new file mode 100644 (file)
index 0000000..0fe4690
--- /dev/null
@@ -0,0 +1,101 @@
+import struct
+
+from cryptography.hazmat.backends import default_backend
+from cryptography.hazmat.primitives import hashes
+from cryptography.hazmat.primitives.asymmetric import dsa, utils
+
+from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey
+from dns.dnssectypes import Algorithm
+from dns.rdtypes.ANY.DNSKEY import DNSKEY
+
+
+class PublicDSA(CryptographyPublicKey):
+    key: dsa.DSAPublicKey
+    key_cls = dsa.DSAPublicKey
+    algorithm = Algorithm.DSA
+    chosen_hash = hashes.SHA1()
+
+    def verify(self, signature: bytes, data: bytes) -> None:
+        sig_r = signature[1:21]
+        sig_s = signature[21:]
+        sig = utils.encode_dss_signature(
+            int.from_bytes(sig_r, "big"), int.from_bytes(sig_s, "big")
+        )
+        self.key.verify(sig, data, self.chosen_hash)
+
+    def encode_key_bytes(self) -> bytes:
+        """Encode a public key per RFC 2536, section 2."""
+        pn = self.key.public_numbers()
+        dsa_t = (self.key.key_size // 8 - 64) // 8
+        if dsa_t > 8:
+            raise ValueError("unsupported DSA key size")
+        octets = 64 + dsa_t * 8
+        res = struct.pack("!B", dsa_t)
+        res += pn.parameter_numbers.q.to_bytes(20, "big")
+        res += pn.parameter_numbers.p.to_bytes(octets, "big")
+        res += pn.parameter_numbers.g.to_bytes(octets, "big")
+        res += pn.y.to_bytes(octets, "big")
+        return res
+
+    @classmethod
+    def from_dnskey(cls, key: DNSKEY) -> "PublicDSA":
+        cls._ensure_algorithm_key_combination(key)
+        keyptr = key.key
+        (t,) = struct.unpack("!B", keyptr[0:1])
+        keyptr = keyptr[1:]
+        octets = 64 + t * 8
+        dsa_q = keyptr[0:20]
+        keyptr = keyptr[20:]
+        dsa_p = keyptr[0:octets]
+        keyptr = keyptr[octets:]
+        dsa_g = keyptr[0:octets]
+        keyptr = keyptr[octets:]
+        dsa_y = keyptr[0:octets]
+        return cls(
+            key=dsa.DSAPublicNumbers(  # type: ignore
+                int.from_bytes(dsa_y, "big"),
+                dsa.DSAParameterNumbers(
+                    int.from_bytes(dsa_p, "big"),
+                    int.from_bytes(dsa_q, "big"),
+                    int.from_bytes(dsa_g, "big"),
+                ),
+            ).public_key(default_backend()),
+        )
+
+
+class PrivateDSA(CryptographyPrivateKey):
+    key: dsa.DSAPrivateKey
+    key_cls = dsa.DSAPrivateKey
+    public_cls = PublicDSA
+
+    def sign(self, data: bytes, verify: bool = False) -> bytes:
+        """Sign using a private key per RFC 2536, section 3."""
+        public_dsa_key = self.key.public_key()
+        if public_dsa_key.key_size > 1024:
+            raise ValueError("DSA key size overflow")
+        der_signature = self.key.sign(data, self.public_cls.chosen_hash)
+        dsa_r, dsa_s = utils.decode_dss_signature(der_signature)
+        dsa_t = (public_dsa_key.key_size // 8 - 64) // 8
+        octets = 20
+        signature = (
+            struct.pack("!B", dsa_t)
+            + int.to_bytes(dsa_r, length=octets, byteorder="big")
+            + int.to_bytes(dsa_s, length=octets, byteorder="big")
+        )
+        if verify:
+            self.public_key().verify(signature, data)
+        return signature
+
+    @classmethod
+    def generate(cls, key_size: int) -> "PrivateDSA":
+        return cls(
+            key=dsa.generate_private_key(key_size=key_size),
+        )
+
+
+class PublicDSANSEC3SHA1(PublicDSA):
+    algorithm = Algorithm.DSANSEC3SHA1
+
+
+class PrivateDSANSEC3SHA1(PrivateDSA):
+    public_cls = PublicDSANSEC3SHA1
diff --git a/dns/dnssecalgs/ecdsa.py b/dns/dnssecalgs/ecdsa.py
new file mode 100644 (file)
index 0000000..a31d79f
--- /dev/null
@@ -0,0 +1,89 @@
+from cryptography.hazmat.backends import default_backend
+from cryptography.hazmat.primitives import hashes
+from cryptography.hazmat.primitives.asymmetric import ec, utils
+
+from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey
+from dns.dnssectypes import Algorithm
+from dns.rdtypes.ANY.DNSKEY import DNSKEY
+
+
+class PublicECDSA(CryptographyPublicKey):
+    key: ec.EllipticCurvePublicKey
+    key_cls = ec.EllipticCurvePublicKey
+    algorithm: Algorithm
+    chosen_hash: hashes.HashAlgorithm
+    curve: ec.EllipticCurve
+    octets: int
+
+    def verify(self, signature: bytes, data: bytes) -> None:
+        sig_r = signature[0 : self.octets]
+        sig_s = signature[self.octets :]
+        sig = utils.encode_dss_signature(
+            int.from_bytes(sig_r, "big"), int.from_bytes(sig_s, "big")
+        )
+        self.key.verify(sig, data, ec.ECDSA(self.chosen_hash))
+
+    def encode_key_bytes(self) -> bytes:
+        """Encode a public key per RFC 6605, section 4."""
+        pn = self.key.public_numbers()
+        return pn.x.to_bytes(self.octets, "big") + pn.y.to_bytes(self.octets, "big")
+
+    @classmethod
+    def from_dnskey(cls, key: DNSKEY) -> "PublicECDSA":
+        cls._ensure_algorithm_key_combination(key)
+        ecdsa_x = key.key[0 : cls.octets]
+        ecdsa_y = key.key[cls.octets : cls.octets * 2]
+        return cls(
+            key=ec.EllipticCurvePublicNumbers(
+                curve=cls.curve,
+                x=int.from_bytes(ecdsa_x, "big"),
+                y=int.from_bytes(ecdsa_y, "big"),
+            ).public_key(default_backend()),
+        )
+
+
+class PrivateECDSA(CryptographyPrivateKey):
+    key: ec.EllipticCurvePrivateKey
+    key_cls = ec.EllipticCurvePrivateKey
+    public_cls = PublicECDSA
+
+    def sign(self, data: bytes, verify: bool = False) -> bytes:
+        """Sign using a private key per RFC 6605, section 4."""
+        der_signature = self.key.sign(data, ec.ECDSA(self.public_cls.chosen_hash))
+        dsa_r, dsa_s = utils.decode_dss_signature(der_signature)
+        signature = int.to_bytes(
+            dsa_r, length=self.public_cls.octets, byteorder="big"
+        ) + int.to_bytes(dsa_s, length=self.public_cls.octets, byteorder="big")
+        if verify:
+            self.public_key().verify(signature, data)
+        return signature
+
+    @classmethod
+    def generate(cls) -> "PrivateECDSA":
+        return cls(
+            key=ec.generate_private_key(
+                curve=cls.public_cls.curve, backend=default_backend()
+            ),
+        )
+
+
+class PublicECDSAP256SHA256(PublicECDSA):
+    algorithm = Algorithm.ECDSAP256SHA256
+    chosen_hash = hashes.SHA256()
+    curve = ec.SECP256R1()
+    octets = 32
+
+
+class PrivateECDSAP256SHA256(PrivateECDSA):
+    public_cls = PublicECDSAP256SHA256
+
+
+class PublicECDSAP384SHA384(PublicECDSA):
+    algorithm = Algorithm.ECDSAP384SHA384
+    chosen_hash = hashes.SHA384()
+    curve = ec.SECP384R1()
+    octets = 48
+
+
+class PrivateECDSAP384SHA384(PrivateECDSA):
+    public_cls = PublicECDSAP384SHA384
diff --git a/dns/dnssecalgs/eddsa.py b/dns/dnssecalgs/eddsa.py
new file mode 100644 (file)
index 0000000..7050534
--- /dev/null
@@ -0,0 +1,65 @@
+from typing import Type
+
+from cryptography.hazmat.primitives import serialization
+from cryptography.hazmat.primitives.asymmetric import ed448, ed25519
+
+from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey
+from dns.dnssectypes import Algorithm
+from dns.rdtypes.ANY.DNSKEY import DNSKEY
+
+
+class PublicEDDSA(CryptographyPublicKey):
+    def verify(self, signature: bytes, data: bytes) -> None:
+        self.key.verify(signature, data)
+
+    def encode_key_bytes(self) -> bytes:
+        """Encode a public key per RFC 8080, section 3."""
+        return self.key.public_bytes(
+            encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw
+        )
+
+    @classmethod
+    def from_dnskey(cls, key: DNSKEY) -> "PublicEDDSA":
+        cls._ensure_algorithm_key_combination(key)
+        return cls(
+            key=cls.key_cls.from_public_bytes(key.key),
+        )
+
+
+class PrivateEDDSA(CryptographyPrivateKey):
+    public_cls: Type[PublicEDDSA]
+
+    def sign(self, data: bytes, verify: bool = False) -> bytes:
+        """Sign using a private key per RFC 8080, section 4."""
+        signature = self.key.sign(data)
+        if verify:
+            self.public_key().verify(signature, data)
+        return signature
+
+    @classmethod
+    def generate(cls) -> "PrivateEDDSA":
+        return cls(key=cls.key_cls.generate())
+
+
+class PublicED25519(PublicEDDSA):
+    key: ed25519.Ed25519PublicKey
+    key_cls = ed25519.Ed25519PublicKey
+    algorithm = Algorithm.ED25519
+
+
+class PrivateED25519(PrivateEDDSA):
+    key: ed25519.Ed25519PrivateKey
+    key_cls = ed25519.Ed25519PrivateKey
+    public_cls = PublicED25519
+
+
+class PublicED448(PublicEDDSA):
+    key: ed448.Ed448PublicKey
+    key_cls = ed448.Ed448PublicKey
+    algorithm = Algorithm.ED448
+
+
+class PrivateED448(PrivateEDDSA):
+    key: ed448.Ed448PrivateKey
+    key_cls = ed448.Ed448PrivateKey
+    public_cls = PublicED448
diff --git a/dns/dnssecalgs/rsa.py b/dns/dnssecalgs/rsa.py
new file mode 100644 (file)
index 0000000..e95dcf1
--- /dev/null
@@ -0,0 +1,119 @@
+import math
+import struct
+
+from cryptography.hazmat.backends import default_backend
+from cryptography.hazmat.primitives import hashes
+from cryptography.hazmat.primitives.asymmetric import padding, rsa
+
+from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey
+from dns.dnssectypes import Algorithm
+from dns.rdtypes.ANY.DNSKEY import DNSKEY
+
+
+class PublicRSA(CryptographyPublicKey):
+    key: rsa.RSAPublicKey
+    key_cls = rsa.RSAPublicKey
+    algorithm: Algorithm
+    chosen_hash: hashes.HashAlgorithm
+
+    def verify(self, signature: bytes, data: bytes) -> None:
+        self.key.verify(signature, data, padding.PKCS1v15(), self.chosen_hash)
+
+    def encode_key_bytes(self) -> bytes:
+        """Encode a public key per RFC 3110, section 2."""
+        pn = self.key.public_numbers()
+        _exp_len = math.ceil(int.bit_length(pn.e) / 8)
+        exp = int.to_bytes(pn.e, length=_exp_len, byteorder="big")
+        if _exp_len > 255:
+            exp_header = b"\0" + struct.pack("!H", _exp_len)
+        else:
+            exp_header = struct.pack("!B", _exp_len)
+        if pn.n.bit_length() < 512 or pn.n.bit_length() > 4096:
+            raise ValueError("unsupported RSA key length")
+        return exp_header + exp + pn.n.to_bytes((pn.n.bit_length() + 7) // 8, "big")
+
+    @classmethod
+    def from_dnskey(cls, key: DNSKEY) -> "PublicRSA":
+        cls._ensure_algorithm_key_combination(key)
+        keyptr = key.key
+        (bytes_,) = struct.unpack("!B", keyptr[0:1])
+        keyptr = keyptr[1:]
+        if bytes_ == 0:
+            (bytes_,) = struct.unpack("!H", keyptr[0:2])
+            keyptr = keyptr[2:]
+        rsa_e = keyptr[0:bytes_]
+        rsa_n = keyptr[bytes_:]
+        return cls(
+            key=rsa.RSAPublicNumbers(
+                int.from_bytes(rsa_e, "big"), int.from_bytes(rsa_n, "big")
+            ).public_key(default_backend())
+        )
+
+
+class PrivateRSA(CryptographyPrivateKey):
+    key: rsa.RSAPrivateKey
+    key_cls = rsa.RSAPrivateKey
+    public_cls = PublicRSA
+    default_public_exponent = 65537
+
+    def sign(self, data: bytes, verify: bool = False) -> bytes:
+        """Sign using a private key per RFC 3110, section 3."""
+        signature = self.key.sign(data, padding.PKCS1v15(), self.public_cls.chosen_hash)
+        if verify:
+            self.public_key().verify(signature, data)
+        return signature
+
+    @classmethod
+    def generate(cls, key_size: int) -> "PrivateRSA":
+        return cls(
+            key=rsa.generate_private_key(
+                public_exponent=cls.default_public_exponent,
+                key_size=key_size,
+                backend=default_backend(),
+            )
+        )
+
+
+class PublicRSAMD5(PublicRSA):
+    algorithm = Algorithm.RSAMD5
+    chosen_hash = hashes.MD5()
+
+
+class PrivateRSAMD5(PrivateRSA):
+    public_cls = PublicRSAMD5
+
+
+class PublicRSASHA1(PublicRSA):
+    algorithm = Algorithm.RSASHA1
+    chosen_hash = hashes.SHA1()
+
+
+class PrivateRSASHA1(PrivateRSA):
+    public_cls = PublicRSASHA1
+
+
+class PublicRSASHA1NSEC3SHA1(PublicRSA):
+    algorithm = Algorithm.RSASHA1NSEC3SHA1
+    chosen_hash = hashes.SHA1()
+
+
+class PrivateRSASHA1NSEC3SHA1(PrivateRSA):
+    public_cls = PublicRSASHA1NSEC3SHA1
+
+
+class PublicRSASHA256(PublicRSA):
+    algorithm = Algorithm.RSASHA256
+    chosen_hash = hashes.SHA256()
+
+
+class PrivateRSASHA256(PrivateRSA):
+    public_cls = PublicRSASHA256
+
+
+class PublicRSASHA512(PublicRSA):
+    algorithm = Algorithm.RSASHA512
+    chosen_hash = hashes.SHA512()
+
+
+class PrivateRSASHA512(PrivateRSA):
+    public_cls = PublicRSASHA512
index 4b1481d1ce411d440e7ec7382f156603fb452864..6982373de2a872057ca1fda3a2a752ff8d566355 100644 (file)
@@ -140,6 +140,22 @@ class Timeout(DNSException):
         super().__init__(*args, **kwargs)
 
 
+class UnsupportedAlgorithm(DNSException):
+    """The DNSSEC algorithm is not supported."""
+
+
+class AlgorithmKeyMismatch(UnsupportedAlgorithm):
+    """The DNSSEC algorithm is not supported for the given key type."""
+
+
+class ValidationFailure(DNSException):
+    """The DNSSEC signature is invalid."""
+
+
+class DeniedByPolicy(DNSException):
+    """Denied by DNSSEC policy."""
+
+
 class ExceptionWrapper:
     def __init__(self, exception_class):
         self.exception_class = exception_class
index 7177dcdabbd6f6587ce8956aab140545a4bd6494..8248206385ffc88012dbac129fc6a816fff90862 100644 (file)
@@ -923,6 +923,7 @@ class DNSSECValidatorTestCase(unittest.TestCase):
             )
 
 
+@unittest.skipUnless(dns.dnssec._have_pyca, "Python Cryptography cannot be imported")
 class DNSSECMiscTestCase(unittest.TestCase):
     def testDigestToBig(self):
         with self.assertRaises(ValueError):
@@ -932,13 +933,6 @@ class DNSSECMiscTestCase(unittest.TestCase):
         with self.assertRaises(ValueError):
             dns.dnssec.NSEC3Hash.make(256)
 
-    def testIsNotGOST(self):
-        self.assertTrue(dns.dnssec._is_gost(dns.dnssec.Algorithm.ECCGOST))
-
-    def testUnknownHash(self):
-        with self.assertRaises(dns.dnssec.ValidationFailure):
-            dns.dnssec._make_hash(100)
-
     def testToTimestamp(self):
         REFERENCE_TIMESTAMP = 441812220
 
@@ -1039,6 +1033,7 @@ class DNSSECMiscTestCase(unittest.TestCase):
         self.assertEqual(zone1.to_text(), zone2.to_text())
 
 
+@unittest.skipUnless(dns.dnssec._have_pyca, "Python Cryptography cannot be imported")
 class DNSSECMakeDSTestCase(unittest.TestCase):
     def testMnemonicParser(self):
         good_ds_mnemonic = dns.rdata.from_text(
@@ -1269,10 +1264,10 @@ class DNSSECMakeDNSKEYTestCase(unittest.TestCase):
             key_size=1024,
             backend=default_backend(),
         )
-        with self.assertRaises(dns.dnssec.AlgorithmKeyMismatch):
+        with self.assertRaises(dns.exception.AlgorithmKeyMismatch):
             dns.dnssec.make_dnskey(key.public_key(), dns.dnssec.Algorithm.ED448)
 
-        with self.assertRaises(TypeError):
+        with self.assertRaises(dns.exception.AlgorithmKeyMismatch):
             dns.dnssec.make_dnskey("xyzzy", dns.dnssec.Algorithm.ED448)
 
         key = dsa.generate_private_key(2048)
diff --git a/tests/test_dnssecalgs.py b/tests/test_dnssecalgs.py
new file mode 100644 (file)
index 0000000..8f6f9bd
--- /dev/null
@@ -0,0 +1,304 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+import os
+import unittest
+
+import dns.dnssec
+import dns.exception
+from dns.dnssectypes import Algorithm
+from dns.rdtypes.ANY.DNSKEY import DNSKEY
+
+try:
+    from dns.dnssecalgs import (
+        get_algorithm_cls,
+        get_algorithm_cls_from_dnskey,
+        register_algorithm_cls,
+    )
+    from dns.dnssecalgs.dsa import PrivateDSA, PrivateDSANSEC3SHA1
+    from dns.dnssecalgs.ecdsa import PrivateECDSAP256SHA256, PrivateECDSAP384SHA384
+    from dns.dnssecalgs.eddsa import PrivateED448, PrivateED25519, PublicED25519
+    from dns.dnssecalgs.rsa import (
+        PrivateRSAMD5,
+        PrivateRSASHA1,
+        PrivateRSASHA1NSEC3SHA1,
+        PrivateRSASHA256,
+        PrivateRSASHA512,
+    )
+except ImportError:
+    pass  # Cryptography ImportError already handled in dns.dnssec
+
+
+@unittest.skipUnless(dns.dnssec._have_pyca, "Python Cryptography cannot be imported")
+class DNSSECAlgorithm(unittest.TestCase):
+    def _test_dnssec_alg(self, private_cls, key_size=None):
+        public_cls = private_cls.public_cls
+
+        private_key = (
+            private_cls.generate(key_size) if key_size else private_cls.generate()
+        )
+
+        # sign random data
+        data = os.urandom(1024)
+        signature = private_key.sign(data, verify=True)
+
+        # validate signature using public key
+        public_key = private_key.public_key()
+        public_key.verify(signature, data)
+
+        # create DNSKEY
+        dnskey = public_key.to_dnskey()
+        dnskey2 = public_cls.from_dnskey(dnskey).to_dnskey()
+        self.assertEqual(dnskey, dnskey2)
+
+        # test cryptography keys
+        _ = private_cls(key=private_key.key)
+        _ = public_cls(key=public_key.key)
+
+        # to/from PEM
+        password = b"mekmitasdigoat"
+        private_pem = private_key.to_pem()
+        private_pem_encrypted = private_key.to_pem(password=password)
+        public_pem = public_key.to_pem()
+        _ = private_cls.from_pem(private_pem)
+        _ = private_cls.from_pem(private_pem_encrypted, password)
+        _ = public_cls.from_pem(public_pem)
+
+    def test_rsa(self):
+        self._test_dnssec_alg(PrivateRSAMD5, 2048)
+        self._test_dnssec_alg(PrivateRSASHA1, 2048)
+        self._test_dnssec_alg(PrivateRSASHA1NSEC3SHA1, 2048)
+        self._test_dnssec_alg(PrivateRSASHA256, 2048)
+        self._test_dnssec_alg(PrivateRSASHA512, 2048)
+
+    def test_dsa(self):
+        self._test_dnssec_alg(PrivateDSA, 1024)
+        self._test_dnssec_alg(PrivateDSANSEC3SHA1, 1024)
+        with self.assertRaises(ValueError):
+            k = PrivateDSA.generate(2048)
+            k.sign(b"hello")
+
+    def test_ecdsa(self):
+        self._test_dnssec_alg(PrivateECDSAP256SHA256)
+        self._test_dnssec_alg(PrivateECDSAP384SHA384)
+
+    def test_eddsa(self):
+        self._test_dnssec_alg(PrivateED25519)
+        self._test_dnssec_alg(PrivateED448)
+
+    def test_algorithm_mismatch(self):
+        private_key_ed448 = PrivateED448.generate()
+        dnskey_ed448 = private_key_ed448.public_key().to_dnskey()
+        with self.assertRaises(dns.exception.AlgorithmKeyMismatch):
+            PublicED25519.from_dnskey(dnskey_ed448)
+
+
+@unittest.skipUnless(dns.dnssec._have_pyca, "Python Cryptography cannot be imported")
+class DNSSECAlgorithmPrivateAlgorithm(unittest.TestCase):
+    def test_private(self):
+        class PublicExampleAlgorithm(PublicED25519):
+            algorithm = Algorithm.PRIVATEDNS
+            name = dns.name.from_text("algorithm.example.com")
+
+            def encode_key_bytes(self) -> bytes:
+                return self.name.to_wire() + super().encode_key_bytes()
+
+            @classmethod
+            def from_dnskey(cls, key: DNSKEY) -> "PublicEDDSA":
+                return cls(
+                    key=cls.key_cls.from_public_bytes(
+                        key.key[len(cls.name.to_wire()) :]
+                    ),
+                )
+
+        class PrivateExampleAlgorithm(PrivateED25519):
+            public_cls = PublicExampleAlgorithm
+
+        register_algorithm_cls(
+            algorithm=Algorithm.PRIVATEDNS,
+            algorithm_cls=PrivateExampleAlgorithm,
+            name=PublicExampleAlgorithm.name,
+        )
+
+        private_key = PrivateExampleAlgorithm.generate()
+        public_key = private_key.public_key()
+
+        name = dns.name.from_text("example.com")
+        rdataset = dns.rdataset.from_text_list("in", "a", 30, ["10.0.0.1", "10.0.0.2"])
+        rrset = (name, rdataset)
+        ttl = 60
+        lifetime = 3600
+        rrname = rrset[0]
+        signer = rrname
+        dnskey = dns.dnssec.make_dnskey(
+            public_key=public_key, algorithm=Algorithm.PRIVATEDNS
+        )
+        dnskey_rrset = dns.rrset.from_rdata(signer, ttl, dnskey)
+
+        rrsig = dns.dnssec.sign(
+            rrset=rrset,
+            private_key=private_key,
+            dnskey=dnskey,
+            lifetime=lifetime,
+            signer=signer,
+            verify=True,
+            policy=None,
+        )
+
+        keys = {signer: dnskey_rrset}
+        rrsigset = dns.rrset.from_rdata(rrname, ttl, rrsig)
+        dns.dnssec.validate(rrset=rrset, rrsigset=rrsigset, keys=keys, policy=None)
+
+    def test_register(self):
+        register_algorithm_cls(
+            algorithm=Algorithm.PRIVATEDNS,
+            algorithm_cls=PrivateED25519,
+            name="ed25519.example.com",
+        )
+        register_algorithm_cls(
+            algorithm=Algorithm.PRIVATEOID,
+            algorithm_cls=PrivateED448,
+            oid=bytes([1, 2, 3, 4]),
+        )
+        register_algorithm_cls(
+            algorithm=251,
+            algorithm_cls=PrivateED25519,
+        )
+
+        with self.assertRaises(TypeError):
+            register_algorithm_cls(algorithm=251, algorithm_cls=str, name="example.com")
+
+        with self.assertRaises(ValueError):
+            register_algorithm_cls(
+                algorithm=251, algorithm_cls=PrivateED25519, name="example.com"
+            )
+
+        with self.assertRaises(ValueError):
+            register_algorithm_cls(
+                algorithm=251, algorithm_cls=PrivateED25519, oid=bytes([1, 2, 3, 4])
+            )
+
+        with self.assertRaises(ValueError):
+            register_algorithm_cls(
+                algorithm=Algorithm.PRIVATEDNS,
+                algorithm_cls=PrivateED25519,
+                oid=bytes([1, 2, 3, 4]),
+            )
+
+        with self.assertRaises(ValueError):
+            register_algorithm_cls(
+                algorithm=Algorithm.PRIVATEOID,
+                algorithm_cls=PrivateED25519,
+                name="example.com",
+            )
+
+        dnskey_251 = DNSKEY(
+            "IN",
+            "DNSKEY",
+            256,
+            3,
+            251,
+            b"hello",
+        )
+        dnskey_dns = DNSKEY(
+            "IN",
+            "DNSKEY",
+            256,
+            3,
+            Algorithm.PRIVATEDNS,
+            dns.name.from_text("ed25519.example.com").to_wire() + b"hello",
+        )
+        dnskey_dns_unknown = DNSKEY(
+            "IN",
+            "DNSKEY",
+            256,
+            3,
+            Algorithm.PRIVATEDNS,
+            dns.name.from_text("unknown.example.com").to_wire() + b"hello",
+        )
+        dnskey_oid = DNSKEY(
+            "IN",
+            "DNSKEY",
+            256,
+            3,
+            Algorithm.PRIVATEOID,
+            bytes([4, 1, 2, 3, 4]) + b"hello",
+        )
+        dnskey_oid_unknown = DNSKEY(
+            "IN",
+            "DNSKEY",
+            256,
+            3,
+            Algorithm.PRIVATEOID,
+            bytes([4, 42, 42, 42, 42]) + b"hello",
+        )
+
+        with self.assertRaises(dns.exception.UnsupportedAlgorithm):
+            _ = get_algorithm_cls(250)
+
+        algorithm_cls = get_algorithm_cls(251)
+        self.assertEqual(algorithm_cls, PrivateED25519)
+
+        algorithm_cls = get_algorithm_cls_from_dnskey(dnskey_251)
+        self.assertEqual(algorithm_cls, PrivateED25519)
+
+        algorithm_cls = get_algorithm_cls_from_dnskey(dnskey_dns)
+        self.assertEqual(algorithm_cls, PrivateED25519)
+
+        with self.assertRaises(dns.exception.UnsupportedAlgorithm):
+            _ = get_algorithm_cls_from_dnskey(dnskey_dns_unknown)
+
+        algorithm_cls = get_algorithm_cls_from_dnskey(dnskey_oid)
+        self.assertEqual(algorithm_cls, PrivateED448)
+
+        with self.assertRaises(dns.exception.UnsupportedAlgorithm):
+            _ = get_algorithm_cls_from_dnskey(dnskey_oid_unknown)
+
+    def test_register_canonical_lookup(self):
+        register_algorithm_cls(
+            algorithm=Algorithm.PRIVATEDNS,
+            algorithm_cls=PrivateED25519,
+            name="testing1234.example.com",
+        )
+
+        dnskey_dns = DNSKEY(
+            "IN",
+            "DNSKEY",
+            256,
+            3,
+            Algorithm.PRIVATEDNS,
+            dns.name.from_text("TESTING1234.EXAMPLE.COM").to_wire() + b"hello",
+        )
+
+        algorithm_cls = get_algorithm_cls_from_dnskey(dnskey_dns)
+        self.assertEqual(algorithm_cls, PrivateED25519)
+
+    def test_register_private_without_prefix(self):
+        with self.assertRaises(ValueError):
+            register_algorithm_cls(
+                algorithm=Algorithm.PRIVATEDNS,
+                algorithm_cls=PrivateED25519,
+            )
+        with self.assertRaises(ValueError):
+            register_algorithm_cls(
+                algorithm=Algorithm.PRIVATEOID,
+                algorithm_cls=PrivateED25519,
+            )
+
+
+if __name__ == "__main__":
+    unittest.main()