From: Jakob Schlyter Date: Sun, 25 Jun 2023 21:01:00 +0000 (+0200) Subject: DNSSEC Algorithm Refactor (#944) X-Git-Tag: v2.4.0rc1~8 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=7d3cde8cb58ae60bf95d888f7a850773ce8af4cf;p=thirdparty%2Fdnspython.git DNSSEC Algorithm Refactor (#944) * Split DNSSEC algorithms into separate classes with a registration mechanism. * Add DNSSEC private algorithm support. --- diff --git a/dns/dnssec.py b/dns/dnssec.py index 55fd7b57..d9b8d98d 100644 --- a/dns/dnssec.py +++ b/dns/dnssec.py @@ -18,12 +18,11 @@ """Common DNSSEC-related functions and constants.""" -from typing import Any, cast, Callable, Dict, List, Optional, Set, Tuple, Union +from typing import cast, Callable, Dict, List, Optional, Set, Tuple, Union import contextlib import functools import hashlib -import math import struct import time import base64 @@ -41,6 +40,12 @@ import dns.rdataclass import dns.rrset import dns.transaction import dns.zone +from dns.exception import ( # pylint: disable=W0611 + AlgorithmKeyMismatch, + DeniedByPolicy, + UnsupportedAlgorithm, + ValidationFailure, +) from dns.rdtypes.ANY.CDNSKEY import CDNSKEY from dns.rdtypes.ANY.CDS import CDS from dns.rdtypes.ANY.DNSKEY import DNSKEY @@ -51,23 +56,8 @@ from dns.rdtypes.ANY.RRSIG import RRSIG, sigtime_to_posixtime from dns.rdtypes.dnskeybase import Flag -class UnsupportedAlgorithm(dns.exception.DNSException): - """The DNSSEC algorithm is not supported.""" - - -class AlgorithmKeyMismatch(UnsupportedAlgorithm): - """The DNSSEC algorithm is not supported for the given key type.""" - - -class ValidationFailure(dns.exception.DNSException): - """The DNSSEC signature is invalid.""" - - -class DeniedByPolicy(dns.exception.DNSException): - """Denied by DNSSEC policy.""" - - PublicKey = Union[ + "GenericPublicKey", "rsa.RSAPublicKey", "ec.EllipticCurvePublicKey", "ed25519.Ed25519PublicKey", @@ -75,6 +65,7 @@ PublicKey = Union[ ] PrivateKey = Union[ + "GenericPrivateKey", "rsa.RSAPrivateKey", "ec.EllipticCurvePrivateKey", "ed25519.Ed25519PrivateKey", @@ -321,109 +312,6 @@ def _find_candidate_keys( ] -def _is_rsa(algorithm: int) -> bool: - return algorithm in ( - Algorithm.RSAMD5, - Algorithm.RSASHA1, - Algorithm.RSASHA1NSEC3SHA1, - Algorithm.RSASHA256, - Algorithm.RSASHA512, - ) - - -def _is_dsa(algorithm: int) -> bool: - return algorithm in (Algorithm.DSA, Algorithm.DSANSEC3SHA1) - - -def _is_ecdsa(algorithm: int) -> bool: - return algorithm in (Algorithm.ECDSAP256SHA256, Algorithm.ECDSAP384SHA384) - - -def _is_eddsa(algorithm: int) -> bool: - return algorithm in (Algorithm.ED25519, Algorithm.ED448) - - -def _is_gost(algorithm: int) -> bool: - return algorithm == Algorithm.ECCGOST - - -def _is_md5(algorithm: int) -> bool: - return algorithm == Algorithm.RSAMD5 - - -def _is_sha1(algorithm: int) -> bool: - return algorithm in ( - Algorithm.DSA, - Algorithm.RSASHA1, - Algorithm.DSANSEC3SHA1, - Algorithm.RSASHA1NSEC3SHA1, - ) - - -def _is_sha256(algorithm: int) -> bool: - return algorithm in (Algorithm.RSASHA256, Algorithm.ECDSAP256SHA256) - - -def _is_sha384(algorithm: int) -> bool: - return algorithm == Algorithm.ECDSAP384SHA384 - - -def _is_sha512(algorithm: int) -> bool: - return algorithm == Algorithm.RSASHA512 - - -def _ensure_algorithm_key_combination(algorithm: int, key: PublicKey) -> None: - """Ensure algorithm is valid for key type, throwing an exception on - mismatch.""" - if isinstance(key, rsa.RSAPublicKey): - if _is_rsa(algorithm): - return - raise AlgorithmKeyMismatch('algorithm "%s" not valid for RSA key' % algorithm) - if isinstance(key, dsa.DSAPublicKey): - if _is_dsa(algorithm): - return - raise AlgorithmKeyMismatch('algorithm "%s" not valid for DSA key' % algorithm) - if isinstance(key, ec.EllipticCurvePublicKey): - if _is_ecdsa(algorithm): - return - raise AlgorithmKeyMismatch('algorithm "%s" not valid for ECDSA key' % algorithm) - if isinstance(key, ed25519.Ed25519PublicKey): - if algorithm == Algorithm.ED25519: - return - raise AlgorithmKeyMismatch( - 'algorithm "%s" not valid for ED25519 key' % algorithm - ) - if isinstance(key, ed448.Ed448PublicKey): - if algorithm == Algorithm.ED448: - return - raise AlgorithmKeyMismatch('algorithm "%s" not valid for ED448 key' % algorithm) - - raise TypeError("unsupported key type") - - -def _make_hash(algorithm: int) -> Any: - if _is_md5(algorithm): - return hashes.MD5() - if _is_sha1(algorithm): - return hashes.SHA1() - if _is_sha256(algorithm): - return hashes.SHA256() - if _is_sha384(algorithm): - return hashes.SHA384() - if _is_sha512(algorithm): - return hashes.SHA512() - if algorithm == Algorithm.ED25519: - return hashes.SHA512() - if algorithm == Algorithm.ED448: - return hashes.SHAKE256(114) - - raise ValidationFailure("unknown hash for algorithm %u" % algorithm) - - -def _bytes_to_long(b: bytes) -> int: - return int.from_bytes(b, "big") - - def _get_rrname_rdataset( rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rdataset]], ) -> Tuple[dns.name.Name, dns.rdataset.Rdataset]: @@ -433,85 +321,13 @@ def _get_rrname_rdataset( return rrset.name, rrset -def _validate_signature(sig: bytes, data: bytes, key: DNSKEY, chosen_hash: Any) -> None: - keyptr: bytes - if _is_rsa(key.algorithm): - # we ignore because mypy is confused and thinks key.key is a str for unknown - # reasons. - keyptr = key.key - (bytes_,) = struct.unpack("!B", keyptr[0:1]) - keyptr = keyptr[1:] - if bytes_ == 0: - (bytes_,) = struct.unpack("!H", keyptr[0:2]) - keyptr = keyptr[2:] - rsa_e = keyptr[0:bytes_] - rsa_n = keyptr[bytes_:] - try: - rsa_public_key = rsa.RSAPublicNumbers( - _bytes_to_long(rsa_e), _bytes_to_long(rsa_n) - ).public_key(default_backend()) - except ValueError: - raise ValidationFailure("invalid public key") - rsa_public_key.verify(sig, data, padding.PKCS1v15(), chosen_hash) - elif _is_dsa(key.algorithm): - keyptr = key.key - (t,) = struct.unpack("!B", keyptr[0:1]) - keyptr = keyptr[1:] - octets = 64 + t * 8 - dsa_q = keyptr[0:20] - keyptr = keyptr[20:] - dsa_p = keyptr[0:octets] - keyptr = keyptr[octets:] - dsa_g = keyptr[0:octets] - keyptr = keyptr[octets:] - dsa_y = keyptr[0:octets] - try: - dsa_public_key = dsa.DSAPublicNumbers( # type: ignore - _bytes_to_long(dsa_y), - dsa.DSAParameterNumbers( - _bytes_to_long(dsa_p), _bytes_to_long(dsa_q), _bytes_to_long(dsa_g) - ), - ).public_key(default_backend()) - except ValueError: - raise ValidationFailure("invalid public key") - dsa_public_key.verify(sig, data, chosen_hash) - elif _is_ecdsa(key.algorithm): - keyptr = key.key - curve: Any - if key.algorithm == Algorithm.ECDSAP256SHA256: - curve = ec.SECP256R1() - octets = 32 - else: - curve = ec.SECP384R1() - octets = 48 - ecdsa_x = keyptr[0:octets] - ecdsa_y = keyptr[octets : octets * 2] - try: - ecdsa_public_key = ec.EllipticCurvePublicNumbers( - curve=curve, x=_bytes_to_long(ecdsa_x), y=_bytes_to_long(ecdsa_y) - ).public_key(default_backend()) - except ValueError: - raise ValidationFailure("invalid public key") - ecdsa_public_key.verify(sig, data, ec.ECDSA(chosen_hash)) - elif _is_eddsa(key.algorithm): - keyptr = key.key - loader: Any - if key.algorithm == Algorithm.ED25519: - loader = ed25519.Ed25519PublicKey - else: - loader = ed448.Ed448PublicKey - try: - eddsa_public_key = loader.from_public_bytes(keyptr) - except ValueError: - raise ValidationFailure("invalid public key") - eddsa_public_key.verify(sig, data) - elif _is_gost(key.algorithm): - raise UnsupportedAlgorithm( - 'algorithm "%s" not supported by dnspython' - % algorithm_to_text(key.algorithm) - ) - else: - raise ValidationFailure("unknown algorithm %u" % key.algorithm) +def _validate_signature(sig: bytes, data: bytes, key: DNSKEY) -> None: + public_cls = get_algorithm_cls_from_dnskey(key).public_cls + try: + public_key = public_cls.from_dnskey(key) + except ValueError: + raise ValidationFailure("invalid public key") + public_key.verify(sig, data) def _validate_rrsig( @@ -568,29 +384,13 @@ def _validate_rrsig( if rrsig.inception > now: raise ValidationFailure("not yet valid") - if _is_dsa(rrsig.algorithm): - sig_r = rrsig.signature[1:21] - sig_s = rrsig.signature[21:] - sig = utils.encode_dss_signature(_bytes_to_long(sig_r), _bytes_to_long(sig_s)) - elif _is_ecdsa(rrsig.algorithm): - if rrsig.algorithm == Algorithm.ECDSAP256SHA256: - octets = 32 - else: - octets = 48 - sig_r = rrsig.signature[0:octets] - sig_s = rrsig.signature[octets:] - sig = utils.encode_dss_signature(_bytes_to_long(sig_r), _bytes_to_long(sig_s)) - else: - sig = rrsig.signature - data = _make_rrsig_signature_data(rrset, rrsig, origin) - chosen_hash = _make_hash(rrsig.algorithm) for candidate_key in candidate_keys: if not policy.ok_to_validate(candidate_key): continue try: - _validate_signature(sig, data, candidate_key, chosen_hash) + _validate_signature(rrsig.signature, data, candidate_key) return except (InvalidSignature, ValidationFailure): # this happens on an individual validation failure @@ -778,62 +578,17 @@ def _sign( ) data = dns.dnssec._make_rrsig_signature_data(rrset, rrsig_template, origin) - chosen_hash = _make_hash(rrsig_template.algorithm) - signature = None - - if isinstance(private_key, rsa.RSAPrivateKey): - if not _is_rsa(dnskey.algorithm): - raise ValueError("Invalid DNSKEY algorithm for RSA key") - signature = private_key.sign(data, padding.PKCS1v15(), chosen_hash) - if verify: - private_key.public_key().verify( - signature, data, padding.PKCS1v15(), chosen_hash - ) - elif isinstance(private_key, dsa.DSAPrivateKey): - if not _is_dsa(dnskey.algorithm): - raise ValueError("Invalid DNSKEY algorithm for DSA key") - public_dsa_key = private_key.public_key() - if public_dsa_key.key_size > 1024: - raise ValueError("DSA key size overflow") - der_signature = private_key.sign(data, chosen_hash) - if verify: - public_dsa_key.verify(der_signature, data, chosen_hash) - dsa_r, dsa_s = utils.decode_dss_signature(der_signature) - dsa_t = (public_dsa_key.key_size // 8 - 64) // 8 - octets = 20 - signature = ( - struct.pack("!B", dsa_t) - + int.to_bytes(dsa_r, length=octets, byteorder="big") - + int.to_bytes(dsa_s, length=octets, byteorder="big") - ) - elif isinstance(private_key, ec.EllipticCurvePrivateKey): - if not _is_ecdsa(dnskey.algorithm): - raise ValueError("Invalid DNSKEY algorithm for EC key") - der_signature = private_key.sign(data, ec.ECDSA(chosen_hash)) - if verify: - private_key.public_key().verify(der_signature, data, ec.ECDSA(chosen_hash)) - if dnskey.algorithm == Algorithm.ECDSAP256SHA256: - octets = 32 - else: - octets = 48 - dsa_r, dsa_s = utils.decode_dss_signature(der_signature) - signature = int.to_bytes(dsa_r, length=octets, byteorder="big") + int.to_bytes( - dsa_s, length=octets, byteorder="big" - ) - elif isinstance(private_key, ed25519.Ed25519PrivateKey): - if dnskey.algorithm != Algorithm.ED25519: - raise ValueError("Invalid DNSKEY algorithm for ED25519 key") - signature = private_key.sign(data) - if verify: - private_key.public_key().verify(signature, data) - elif isinstance(private_key, ed448.Ed448PrivateKey): - if dnskey.algorithm != Algorithm.ED448: - raise ValueError("Invalid DNSKEY algorithm for ED448 key") - signature = private_key.sign(data) - if verify: - private_key.public_key().verify(signature, data) + + if isinstance(private_key, GenericPrivateKey): + signing_key = private_key else: - raise TypeError("Unsupported key algorithm") + try: + private_cls = get_algorithm_cls_from_dnskey(dnskey) + signing_key = private_cls(key=private_key) + except UnsupportedAlgorithm: + raise TypeError("Unsupported key algorithm") + + signature = signing_key.sign(data, verify) return cast(RRSIG, rrsig_template.replace(signature=signature)) @@ -911,9 +666,8 @@ def _make_dnskey( ) -> DNSKEY: """Convert a public key to DNSKEY Rdata - *public_key*, the public key to convert, a - ``cryptography.hazmat.primitives.asymmetric`` public key class applicable - for DNSSEC. + *public_key*, a ``PublicKey`` (``GenericPublicKey`` or + ``cryptography.hazmat.primitives.asymmetric``) to convert. *algorithm*, a ``str`` or ``int`` specifying the DNSKEY algorithm. @@ -929,72 +683,13 @@ def _make_dnskey( Return DNSKEY ``Rdata``. """ - def encode_rsa_public_key(public_key: "rsa.RSAPublicKey") -> bytes: - """Encode a public key per RFC 3110, section 2.""" - pn = public_key.public_numbers() - _exp_len = math.ceil(int.bit_length(pn.e) / 8) - exp = int.to_bytes(pn.e, length=_exp_len, byteorder="big") - if _exp_len > 255: - exp_header = b"\0" + struct.pack("!H", _exp_len) - else: - exp_header = struct.pack("!B", _exp_len) - if pn.n.bit_length() < 512 or pn.n.bit_length() > 4096: - raise ValueError("unsupported RSA key length") - return exp_header + exp + pn.n.to_bytes((pn.n.bit_length() + 7) // 8, "big") - - def encode_dsa_public_key(public_key: "dsa.DSAPublicKey") -> bytes: - """Encode a public key per RFC 2536, section 2.""" - pn = public_key.public_numbers() - dsa_t = (public_key.key_size // 8 - 64) // 8 - if dsa_t > 8: - raise ValueError("unsupported DSA key size") - octets = 64 + dsa_t * 8 - res = struct.pack("!B", dsa_t) - res += pn.parameter_numbers.q.to_bytes(20, "big") - res += pn.parameter_numbers.p.to_bytes(octets, "big") - res += pn.parameter_numbers.g.to_bytes(octets, "big") - res += pn.y.to_bytes(octets, "big") - return res - - def encode_ecdsa_public_key(public_key: "ec.EllipticCurvePublicKey") -> bytes: - """Encode a public key per RFC 6605, section 4.""" - pn = public_key.public_numbers() - if isinstance(public_key.curve, ec.SECP256R1): - return pn.x.to_bytes(32, "big") + pn.y.to_bytes(32, "big") - elif isinstance(public_key.curve, ec.SECP384R1): - return pn.x.to_bytes(48, "big") + pn.y.to_bytes(48, "big") - else: - raise ValueError("unsupported ECDSA curve") - algorithm = Algorithm.make(algorithm) - _ensure_algorithm_key_combination(algorithm, public_key) - - if isinstance(public_key, rsa.RSAPublicKey): - key_bytes = encode_rsa_public_key(public_key) - elif isinstance(public_key, dsa.DSAPublicKey): - key_bytes = encode_dsa_public_key(public_key) - elif isinstance(public_key, ec.EllipticCurvePublicKey): - key_bytes = encode_ecdsa_public_key(public_key) - elif isinstance(public_key, ed25519.Ed25519PublicKey): - key_bytes = public_key.public_bytes( - encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw - ) - elif isinstance(public_key, ed448.Ed448PublicKey): - key_bytes = public_key.public_bytes( - encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw - ) + if isinstance(public_key, GenericPublicKey): + return public_key.to_dnskey(flags=flags, protocol=protocol) else: - raise TypeError("unsupported key algorithm") - - return DNSKEY( - rdclass=dns.rdataclass.IN, - rdtype=dns.rdatatype.DNSKEY, - flags=flags, - protocol=protocol, - algorithm=algorithm, - key=key_bytes, - ) + public_cls = get_algorithm_cls(algorithm).public_cls + return public_cls(key=public_key).to_dnskey(flags=flags, protocol=protocol) def _make_cdnskey( @@ -1476,15 +1171,18 @@ def _need_pyca(*args, **kwargs): try: from cryptography.exceptions import InvalidSignature - from cryptography.hazmat.backends import default_backend - from cryptography.hazmat.primitives import hashes, serialization - from cryptography.hazmat.primitives.asymmetric import padding - from cryptography.hazmat.primitives.asymmetric import utils - from cryptography.hazmat.primitives.asymmetric import dsa - from cryptography.hazmat.primitives.asymmetric import ec - from cryptography.hazmat.primitives.asymmetric import ed25519 - from cryptography.hazmat.primitives.asymmetric import ed448 - from cryptography.hazmat.primitives.asymmetric import rsa + from cryptography.hazmat.primitives.asymmetric import dsa # pylint: disable=W0611 + from cryptography.hazmat.primitives.asymmetric import ec # pylint: disable=W0611 + from cryptography.hazmat.primitives.asymmetric import ( # pylint: disable=W0611 + ed25519, + ) + from cryptography.hazmat.primitives.asymmetric import ed448 # pylint: disable=W0611 + from cryptography.hazmat.primitives.asymmetric import rsa # pylint: disable=W0611 + from dns.dnssecalgs import ( # pylint: disable=C0412 + get_algorithm_cls, + get_algorithm_cls_from_dnskey, + ) + from dns.dnssecalgs.base import GenericPrivateKey, GenericPublicKey except ImportError: # pragma: no cover validate = _need_pyca validate_rrsig = _need_pyca diff --git a/dns/dnssecalgs/__init__.py b/dns/dnssecalgs/__init__.py new file mode 100644 index 00000000..d4c89cd6 --- /dev/null +++ b/dns/dnssecalgs/__init__.py @@ -0,0 +1,111 @@ +from typing import Dict, Optional, Tuple, Type, Union + +import dns.name +from dns.dnssecalgs.base import GenericPrivateKey +from dns.dnssecalgs.dsa import PrivateDSA, PrivateDSANSEC3SHA1 +from dns.dnssecalgs.ecdsa import PrivateECDSAP256SHA256, PrivateECDSAP384SHA384 +from dns.dnssecalgs.eddsa import PrivateED448, PrivateED25519 +from dns.dnssecalgs.rsa import ( + PrivateRSAMD5, + PrivateRSASHA1, + PrivateRSASHA1NSEC3SHA1, + PrivateRSASHA256, + PrivateRSASHA512, +) +from dns.dnssectypes import Algorithm +from dns.exception import UnsupportedAlgorithm +from dns.rdtypes.ANY.DNSKEY import DNSKEY + + +AlgorithmPrefix = Optional[Union[bytes, dns.name.Name]] + +algorithms: Dict[Tuple[Algorithm, AlgorithmPrefix], Type[GenericPrivateKey]] = { + (Algorithm.RSAMD5, None): PrivateRSAMD5, + (Algorithm.DSA, None): PrivateDSA, + (Algorithm.RSASHA1, None): PrivateRSASHA1, + (Algorithm.DSANSEC3SHA1, None): PrivateDSANSEC3SHA1, + (Algorithm.RSASHA1NSEC3SHA1, None): PrivateRSASHA1NSEC3SHA1, + (Algorithm.RSASHA256, None): PrivateRSASHA256, + (Algorithm.RSASHA512, None): PrivateRSASHA512, + (Algorithm.ECDSAP256SHA256, None): PrivateECDSAP256SHA256, + (Algorithm.ECDSAP384SHA384, None): PrivateECDSAP384SHA384, + (Algorithm.ED25519, None): PrivateED25519, + (Algorithm.ED448, None): PrivateED448, +} + + +def get_algorithm_cls( + algorithm: Union[int, str], prefix: AlgorithmPrefix = None +) -> Type[GenericPrivateKey]: + """Get Private Key class from Algorithm. + + *algorithm*, a ``str`` or ``int`` specifying the DNSKEY algorithm. + + Raises ``UnsupportedAlgorithm`` if the algorithm is unknown. + + Returns a ``dns.dnssecalgs.GenericPrivateKey`` + """ + algorithm = Algorithm.make(algorithm) + cls = algorithms.get((algorithm, prefix)) + if cls: + return cls + raise UnsupportedAlgorithm( + 'algorithm "%s" not supported by dnspython' % Algorithm.to_text(algorithm) + ) + + +def get_algorithm_cls_from_dnskey(dnskey: DNSKEY) -> Type[GenericPrivateKey]: + """Get Private Key class from DNSKEY. + + *dnskey*, a ``DNSKEY`` to get Algorithm class for. + + Raises ``UnsupportedAlgorithm`` if the algorithm is unknown. + + Returns a ``dns.dnssecalgs.GenericPrivateKey`` + """ + prefix: AlgorithmPrefix = None + if dnskey.algorithm == Algorithm.PRIVATEDNS: + prefix, _ = dns.name.from_wire(dnskey.key, 0) + elif dnskey.algorithm == Algorithm.PRIVATEOID: + length = int(dnskey.key[0]) + prefix = dnskey.key[0 : length + 1] + return get_algorithm_cls(dnskey.algorithm, prefix) + + +def register_algorithm_cls( + algorithm: Union[int, str], + algorithm_cls: Type[GenericPrivateKey], + name: Optional[Union[dns.name.Name, str]] = None, + oid: Optional[bytes] = None, +) -> None: + """Register Algorithm Private Key class. + + *algorithm*, a ``str`` or ``int`` specifying the DNSKEY algorithm. + + *algorithm_cls*: A `GenericPrivateKey` class. + + *name*, an optional ``dns.name.Name`` or ``str``, for for PRIVATEDNS algorithms. + + *oid*: an optional BER-encoded `bytes` for PRIVATEOID algorithms. + + Raises ``ValueError`` if a name or oid is specified incorrectly. + """ + if not issubclass(algorithm_cls, GenericPrivateKey): + raise TypeError("Invalid algorithm class") + algorithm = Algorithm.make(algorithm) + prefix: AlgorithmPrefix = None + if algorithm == Algorithm.PRIVATEDNS: + if name is None: + raise ValueError("Name required for PRIVATEDNS algorithms") + if isinstance(name, str): + name = dns.name.from_text(name) + prefix = name + elif algorithm == Algorithm.PRIVATEOID: + if oid is None: + raise ValueError("OID required for PRIVATEOID algorithms") + prefix = bytes([len(oid)]) + oid + elif name: + raise ValueError("Name only supported for PRIVATEDNS algorithm") + elif oid: + raise ValueError("OID only supported for PRIVATEOID algorithm") + algorithms[(algorithm, prefix)] = algorithm_cls diff --git a/dns/dnssecalgs/base.py b/dns/dnssecalgs/base.py new file mode 100644 index 00000000..c70b0812 --- /dev/null +++ b/dns/dnssecalgs/base.py @@ -0,0 +1,84 @@ +from abc import ABC, abstractmethod +from typing import Optional, Type, Any + +import dns.rdataclass +import dns.rdatatype +from dns.dnssectypes import Algorithm +from dns.exception import AlgorithmKeyMismatch +from dns.rdtypes.ANY.DNSKEY import DNSKEY +from dns.rdtypes.dnskeybase import Flag + + +class GenericPublicKey(ABC): + algorithm: Algorithm + + @abstractmethod + def __init__(self, key: Any) -> None: + pass + + @abstractmethod + def verify(self, signature: bytes, data: bytes) -> None: + """Verify signed DNSSEC data""" + + @abstractmethod + def encode_key_bytes(self) -> bytes: + """Encode key as bytes for DNSKEY""" + + @classmethod + def _ensure_algorithm_key_combination(cls, key: DNSKEY) -> None: + if key.algorithm != cls.algorithm: + raise AlgorithmKeyMismatch + + def to_dnskey(self, flags: int = Flag.ZONE, protocol: int = 3) -> DNSKEY: + """Return public key as DNSKEY""" + return DNSKEY( + rdclass=dns.rdataclass.IN, + rdtype=dns.rdatatype.DNSKEY, + flags=flags, + protocol=protocol, + algorithm=self.algorithm, + key=self.encode_key_bytes(), + ) + + @classmethod + @abstractmethod + def from_dnskey(cls, key: DNSKEY) -> "GenericPublicKey": + """Create public key from DNSKEY""" + + @classmethod + @abstractmethod + def from_pem(cls, public_pem: bytes) -> "GenericPublicKey": + """Create public key from PEM-encoded SubjectPublicKeyInfo as specified + in RFC 5280""" + + @abstractmethod + def to_pem(self) -> bytes: + """Return public-key as PEM-encoded SubjectPublicKeyInfo as specified + in RFC 5280""" + + +class GenericPrivateKey(ABC): + public_cls: Type[GenericPublicKey] + + @abstractmethod + def __init__(self, key: Any) -> None: + pass + + @abstractmethod + def sign(self, data: bytes, verify: bool = False) -> bytes: + """Sign DNSSEC data""" + + @abstractmethod + def public_key(self) -> "GenericPublicKey": + """Return public key instance""" + + @classmethod + @abstractmethod + def from_pem( + cls, private_pem: bytes, password: Optional[bytes] = None + ) -> "GenericPrivateKey": + """Create private key from PEM-encoded PKCS#8""" + + @abstractmethod + def to_pem(self, password: Optional[bytes] = None) -> bytes: + """Return private key as PEM-encoded PKCS#8""" diff --git a/dns/dnssecalgs/cryptography.py b/dns/dnssecalgs/cryptography.py new file mode 100644 index 00000000..b5bcd2ef --- /dev/null +++ b/dns/dnssecalgs/cryptography.py @@ -0,0 +1,64 @@ +from typing import Any, Optional, Type + +from cryptography.hazmat.primitives import serialization + +from dns.dnssecalgs.base import GenericPrivateKey, GenericPublicKey +from dns.exception import AlgorithmKeyMismatch + + +class CryptographyPublicKey(GenericPublicKey): + key: Any = None + key_cls: Any = None + + def __init__(self, key: Any) -> None: + if self.key_cls is None: + raise TypeError("Undefined private key class") + if not isinstance(key, self.key_cls): + raise AlgorithmKeyMismatch + self.key = key + + @classmethod + def from_pem(cls, public_pem: bytes) -> "GenericPublicKey": + key = serialization.load_pem_public_key(public_pem) + return cls(key=key) + + def to_pem(self) -> bytes: + return self.key.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + + +class CryptographyPrivateKey(GenericPrivateKey): + key: Any = None + key_cls: Any = None + public_cls: Type[CryptographyPublicKey] + + def __init__(self, key: Any) -> None: + if self.key_cls is None: + raise TypeError("Undefined private key class") + if not isinstance(key, self.key_cls): + raise AlgorithmKeyMismatch + self.key = key + + def public_key(self) -> "CryptographyPublicKey": + return self.public_cls(key=self.key.public_key()) + + @classmethod + def from_pem( + cls, private_pem: bytes, password: Optional[bytes] = None + ) -> "GenericPrivateKey": + key = serialization.load_pem_private_key(private_pem, password=password) + return cls(key=key) + + def to_pem(self, password: Optional[bytes] = None) -> bytes: + encryption_algorithm: serialization.KeySerializationEncryption + if password: + encryption_algorithm = serialization.BestAvailableEncryption(password) + else: + encryption_algorithm = serialization.NoEncryption() + return self.key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=encryption_algorithm, + ) diff --git a/dns/dnssecalgs/dsa.py b/dns/dnssecalgs/dsa.py new file mode 100644 index 00000000..0fe4690d --- /dev/null +++ b/dns/dnssecalgs/dsa.py @@ -0,0 +1,101 @@ +import struct + +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.asymmetric import dsa, utils + +from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey +from dns.dnssectypes import Algorithm +from dns.rdtypes.ANY.DNSKEY import DNSKEY + + +class PublicDSA(CryptographyPublicKey): + key: dsa.DSAPublicKey + key_cls = dsa.DSAPublicKey + algorithm = Algorithm.DSA + chosen_hash = hashes.SHA1() + + def verify(self, signature: bytes, data: bytes) -> None: + sig_r = signature[1:21] + sig_s = signature[21:] + sig = utils.encode_dss_signature( + int.from_bytes(sig_r, "big"), int.from_bytes(sig_s, "big") + ) + self.key.verify(sig, data, self.chosen_hash) + + def encode_key_bytes(self) -> bytes: + """Encode a public key per RFC 2536, section 2.""" + pn = self.key.public_numbers() + dsa_t = (self.key.key_size // 8 - 64) // 8 + if dsa_t > 8: + raise ValueError("unsupported DSA key size") + octets = 64 + dsa_t * 8 + res = struct.pack("!B", dsa_t) + res += pn.parameter_numbers.q.to_bytes(20, "big") + res += pn.parameter_numbers.p.to_bytes(octets, "big") + res += pn.parameter_numbers.g.to_bytes(octets, "big") + res += pn.y.to_bytes(octets, "big") + return res + + @classmethod + def from_dnskey(cls, key: DNSKEY) -> "PublicDSA": + cls._ensure_algorithm_key_combination(key) + keyptr = key.key + (t,) = struct.unpack("!B", keyptr[0:1]) + keyptr = keyptr[1:] + octets = 64 + t * 8 + dsa_q = keyptr[0:20] + keyptr = keyptr[20:] + dsa_p = keyptr[0:octets] + keyptr = keyptr[octets:] + dsa_g = keyptr[0:octets] + keyptr = keyptr[octets:] + dsa_y = keyptr[0:octets] + return cls( + key=dsa.DSAPublicNumbers( # type: ignore + int.from_bytes(dsa_y, "big"), + dsa.DSAParameterNumbers( + int.from_bytes(dsa_p, "big"), + int.from_bytes(dsa_q, "big"), + int.from_bytes(dsa_g, "big"), + ), + ).public_key(default_backend()), + ) + + +class PrivateDSA(CryptographyPrivateKey): + key: dsa.DSAPrivateKey + key_cls = dsa.DSAPrivateKey + public_cls = PublicDSA + + def sign(self, data: bytes, verify: bool = False) -> bytes: + """Sign using a private key per RFC 2536, section 3.""" + public_dsa_key = self.key.public_key() + if public_dsa_key.key_size > 1024: + raise ValueError("DSA key size overflow") + der_signature = self.key.sign(data, self.public_cls.chosen_hash) + dsa_r, dsa_s = utils.decode_dss_signature(der_signature) + dsa_t = (public_dsa_key.key_size // 8 - 64) // 8 + octets = 20 + signature = ( + struct.pack("!B", dsa_t) + + int.to_bytes(dsa_r, length=octets, byteorder="big") + + int.to_bytes(dsa_s, length=octets, byteorder="big") + ) + if verify: + self.public_key().verify(signature, data) + return signature + + @classmethod + def generate(cls, key_size: int) -> "PrivateDSA": + return cls( + key=dsa.generate_private_key(key_size=key_size), + ) + + +class PublicDSANSEC3SHA1(PublicDSA): + algorithm = Algorithm.DSANSEC3SHA1 + + +class PrivateDSANSEC3SHA1(PrivateDSA): + public_cls = PublicDSANSEC3SHA1 diff --git a/dns/dnssecalgs/ecdsa.py b/dns/dnssecalgs/ecdsa.py new file mode 100644 index 00000000..a31d79f2 --- /dev/null +++ b/dns/dnssecalgs/ecdsa.py @@ -0,0 +1,89 @@ +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.asymmetric import ec, utils + +from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey +from dns.dnssectypes import Algorithm +from dns.rdtypes.ANY.DNSKEY import DNSKEY + + +class PublicECDSA(CryptographyPublicKey): + key: ec.EllipticCurvePublicKey + key_cls = ec.EllipticCurvePublicKey + algorithm: Algorithm + chosen_hash: hashes.HashAlgorithm + curve: ec.EllipticCurve + octets: int + + def verify(self, signature: bytes, data: bytes) -> None: + sig_r = signature[0 : self.octets] + sig_s = signature[self.octets :] + sig = utils.encode_dss_signature( + int.from_bytes(sig_r, "big"), int.from_bytes(sig_s, "big") + ) + self.key.verify(sig, data, ec.ECDSA(self.chosen_hash)) + + def encode_key_bytes(self) -> bytes: + """Encode a public key per RFC 6605, section 4.""" + pn = self.key.public_numbers() + return pn.x.to_bytes(self.octets, "big") + pn.y.to_bytes(self.octets, "big") + + @classmethod + def from_dnskey(cls, key: DNSKEY) -> "PublicECDSA": + cls._ensure_algorithm_key_combination(key) + ecdsa_x = key.key[0 : cls.octets] + ecdsa_y = key.key[cls.octets : cls.octets * 2] + return cls( + key=ec.EllipticCurvePublicNumbers( + curve=cls.curve, + x=int.from_bytes(ecdsa_x, "big"), + y=int.from_bytes(ecdsa_y, "big"), + ).public_key(default_backend()), + ) + + +class PrivateECDSA(CryptographyPrivateKey): + key: ec.EllipticCurvePrivateKey + key_cls = ec.EllipticCurvePrivateKey + public_cls = PublicECDSA + + def sign(self, data: bytes, verify: bool = False) -> bytes: + """Sign using a private key per RFC 6605, section 4.""" + der_signature = self.key.sign(data, ec.ECDSA(self.public_cls.chosen_hash)) + dsa_r, dsa_s = utils.decode_dss_signature(der_signature) + signature = int.to_bytes( + dsa_r, length=self.public_cls.octets, byteorder="big" + ) + int.to_bytes(dsa_s, length=self.public_cls.octets, byteorder="big") + if verify: + self.public_key().verify(signature, data) + return signature + + @classmethod + def generate(cls) -> "PrivateECDSA": + return cls( + key=ec.generate_private_key( + curve=cls.public_cls.curve, backend=default_backend() + ), + ) + + +class PublicECDSAP256SHA256(PublicECDSA): + algorithm = Algorithm.ECDSAP256SHA256 + chosen_hash = hashes.SHA256() + curve = ec.SECP256R1() + octets = 32 + + +class PrivateECDSAP256SHA256(PrivateECDSA): + public_cls = PublicECDSAP256SHA256 + + +class PublicECDSAP384SHA384(PublicECDSA): + algorithm = Algorithm.ECDSAP384SHA384 + chosen_hash = hashes.SHA384() + curve = ec.SECP384R1() + octets = 48 + + +class PrivateECDSAP384SHA384(PrivateECDSA): + public_cls = PublicECDSAP384SHA384 diff --git a/dns/dnssecalgs/eddsa.py b/dns/dnssecalgs/eddsa.py new file mode 100644 index 00000000..70505342 --- /dev/null +++ b/dns/dnssecalgs/eddsa.py @@ -0,0 +1,65 @@ +from typing import Type + +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import ed448, ed25519 + +from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey +from dns.dnssectypes import Algorithm +from dns.rdtypes.ANY.DNSKEY import DNSKEY + + +class PublicEDDSA(CryptographyPublicKey): + def verify(self, signature: bytes, data: bytes) -> None: + self.key.verify(signature, data) + + def encode_key_bytes(self) -> bytes: + """Encode a public key per RFC 8080, section 3.""" + return self.key.public_bytes( + encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw + ) + + @classmethod + def from_dnskey(cls, key: DNSKEY) -> "PublicEDDSA": + cls._ensure_algorithm_key_combination(key) + return cls( + key=cls.key_cls.from_public_bytes(key.key), + ) + + +class PrivateEDDSA(CryptographyPrivateKey): + public_cls: Type[PublicEDDSA] + + def sign(self, data: bytes, verify: bool = False) -> bytes: + """Sign using a private key per RFC 8080, section 4.""" + signature = self.key.sign(data) + if verify: + self.public_key().verify(signature, data) + return signature + + @classmethod + def generate(cls) -> "PrivateEDDSA": + return cls(key=cls.key_cls.generate()) + + +class PublicED25519(PublicEDDSA): + key: ed25519.Ed25519PublicKey + key_cls = ed25519.Ed25519PublicKey + algorithm = Algorithm.ED25519 + + +class PrivateED25519(PrivateEDDSA): + key: ed25519.Ed25519PrivateKey + key_cls = ed25519.Ed25519PrivateKey + public_cls = PublicED25519 + + +class PublicED448(PublicEDDSA): + key: ed448.Ed448PublicKey + key_cls = ed448.Ed448PublicKey + algorithm = Algorithm.ED448 + + +class PrivateED448(PrivateEDDSA): + key: ed448.Ed448PrivateKey + key_cls = ed448.Ed448PrivateKey + public_cls = PublicED448 diff --git a/dns/dnssecalgs/rsa.py b/dns/dnssecalgs/rsa.py new file mode 100644 index 00000000..e95dcf1d --- /dev/null +++ b/dns/dnssecalgs/rsa.py @@ -0,0 +1,119 @@ +import math +import struct + +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.asymmetric import padding, rsa + +from dns.dnssecalgs.cryptography import CryptographyPrivateKey, CryptographyPublicKey +from dns.dnssectypes import Algorithm +from dns.rdtypes.ANY.DNSKEY import DNSKEY + + +class PublicRSA(CryptographyPublicKey): + key: rsa.RSAPublicKey + key_cls = rsa.RSAPublicKey + algorithm: Algorithm + chosen_hash: hashes.HashAlgorithm + + def verify(self, signature: bytes, data: bytes) -> None: + self.key.verify(signature, data, padding.PKCS1v15(), self.chosen_hash) + + def encode_key_bytes(self) -> bytes: + """Encode a public key per RFC 3110, section 2.""" + pn = self.key.public_numbers() + _exp_len = math.ceil(int.bit_length(pn.e) / 8) + exp = int.to_bytes(pn.e, length=_exp_len, byteorder="big") + if _exp_len > 255: + exp_header = b"\0" + struct.pack("!H", _exp_len) + else: + exp_header = struct.pack("!B", _exp_len) + if pn.n.bit_length() < 512 or pn.n.bit_length() > 4096: + raise ValueError("unsupported RSA key length") + return exp_header + exp + pn.n.to_bytes((pn.n.bit_length() + 7) // 8, "big") + + @classmethod + def from_dnskey(cls, key: DNSKEY) -> "PublicRSA": + cls._ensure_algorithm_key_combination(key) + keyptr = key.key + (bytes_,) = struct.unpack("!B", keyptr[0:1]) + keyptr = keyptr[1:] + if bytes_ == 0: + (bytes_,) = struct.unpack("!H", keyptr[0:2]) + keyptr = keyptr[2:] + rsa_e = keyptr[0:bytes_] + rsa_n = keyptr[bytes_:] + return cls( + key=rsa.RSAPublicNumbers( + int.from_bytes(rsa_e, "big"), int.from_bytes(rsa_n, "big") + ).public_key(default_backend()) + ) + + +class PrivateRSA(CryptographyPrivateKey): + key: rsa.RSAPrivateKey + key_cls = rsa.RSAPrivateKey + public_cls = PublicRSA + default_public_exponent = 65537 + + def sign(self, data: bytes, verify: bool = False) -> bytes: + """Sign using a private key per RFC 3110, section 3.""" + signature = self.key.sign(data, padding.PKCS1v15(), self.public_cls.chosen_hash) + if verify: + self.public_key().verify(signature, data) + return signature + + @classmethod + def generate(cls, key_size: int) -> "PrivateRSA": + return cls( + key=rsa.generate_private_key( + public_exponent=cls.default_public_exponent, + key_size=key_size, + backend=default_backend(), + ) + ) + + +class PublicRSAMD5(PublicRSA): + algorithm = Algorithm.RSAMD5 + chosen_hash = hashes.MD5() + + +class PrivateRSAMD5(PrivateRSA): + public_cls = PublicRSAMD5 + + +class PublicRSASHA1(PublicRSA): + algorithm = Algorithm.RSASHA1 + chosen_hash = hashes.SHA1() + + +class PrivateRSASHA1(PrivateRSA): + public_cls = PublicRSASHA1 + + +class PublicRSASHA1NSEC3SHA1(PublicRSA): + algorithm = Algorithm.RSASHA1NSEC3SHA1 + chosen_hash = hashes.SHA1() + + +class PrivateRSASHA1NSEC3SHA1(PrivateRSA): + public_cls = PublicRSASHA1NSEC3SHA1 + + +class PublicRSASHA256(PublicRSA): + algorithm = Algorithm.RSASHA256 + chosen_hash = hashes.SHA256() + + +class PrivateRSASHA256(PrivateRSA): + public_cls = PublicRSASHA256 + + +class PublicRSASHA512(PublicRSA): + algorithm = Algorithm.RSASHA512 + chosen_hash = hashes.SHA512() + + +class PrivateRSASHA512(PrivateRSA): + public_cls = PublicRSASHA512 diff --git a/dns/exception.py b/dns/exception.py index 4b1481d1..6982373d 100644 --- a/dns/exception.py +++ b/dns/exception.py @@ -140,6 +140,22 @@ class Timeout(DNSException): super().__init__(*args, **kwargs) +class UnsupportedAlgorithm(DNSException): + """The DNSSEC algorithm is not supported.""" + + +class AlgorithmKeyMismatch(UnsupportedAlgorithm): + """The DNSSEC algorithm is not supported for the given key type.""" + + +class ValidationFailure(DNSException): + """The DNSSEC signature is invalid.""" + + +class DeniedByPolicy(DNSException): + """Denied by DNSSEC policy.""" + + class ExceptionWrapper: def __init__(self, exception_class): self.exception_class = exception_class diff --git a/tests/test_dnssec.py b/tests/test_dnssec.py index 7177dcda..82482063 100644 --- a/tests/test_dnssec.py +++ b/tests/test_dnssec.py @@ -923,6 +923,7 @@ class DNSSECValidatorTestCase(unittest.TestCase): ) +@unittest.skipUnless(dns.dnssec._have_pyca, "Python Cryptography cannot be imported") class DNSSECMiscTestCase(unittest.TestCase): def testDigestToBig(self): with self.assertRaises(ValueError): @@ -932,13 +933,6 @@ class DNSSECMiscTestCase(unittest.TestCase): with self.assertRaises(ValueError): dns.dnssec.NSEC3Hash.make(256) - def testIsNotGOST(self): - self.assertTrue(dns.dnssec._is_gost(dns.dnssec.Algorithm.ECCGOST)) - - def testUnknownHash(self): - with self.assertRaises(dns.dnssec.ValidationFailure): - dns.dnssec._make_hash(100) - def testToTimestamp(self): REFERENCE_TIMESTAMP = 441812220 @@ -1039,6 +1033,7 @@ class DNSSECMiscTestCase(unittest.TestCase): self.assertEqual(zone1.to_text(), zone2.to_text()) +@unittest.skipUnless(dns.dnssec._have_pyca, "Python Cryptography cannot be imported") class DNSSECMakeDSTestCase(unittest.TestCase): def testMnemonicParser(self): good_ds_mnemonic = dns.rdata.from_text( @@ -1269,10 +1264,10 @@ class DNSSECMakeDNSKEYTestCase(unittest.TestCase): key_size=1024, backend=default_backend(), ) - with self.assertRaises(dns.dnssec.AlgorithmKeyMismatch): + with self.assertRaises(dns.exception.AlgorithmKeyMismatch): dns.dnssec.make_dnskey(key.public_key(), dns.dnssec.Algorithm.ED448) - with self.assertRaises(TypeError): + with self.assertRaises(dns.exception.AlgorithmKeyMismatch): dns.dnssec.make_dnskey("xyzzy", dns.dnssec.Algorithm.ED448) key = dsa.generate_private_key(2048) diff --git a/tests/test_dnssecalgs.py b/tests/test_dnssecalgs.py new file mode 100644 index 00000000..8f6f9bd7 --- /dev/null +++ b/tests/test_dnssecalgs.py @@ -0,0 +1,304 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import os +import unittest + +import dns.dnssec +import dns.exception +from dns.dnssectypes import Algorithm +from dns.rdtypes.ANY.DNSKEY import DNSKEY + +try: + from dns.dnssecalgs import ( + get_algorithm_cls, + get_algorithm_cls_from_dnskey, + register_algorithm_cls, + ) + from dns.dnssecalgs.dsa import PrivateDSA, PrivateDSANSEC3SHA1 + from dns.dnssecalgs.ecdsa import PrivateECDSAP256SHA256, PrivateECDSAP384SHA384 + from dns.dnssecalgs.eddsa import PrivateED448, PrivateED25519, PublicED25519 + from dns.dnssecalgs.rsa import ( + PrivateRSAMD5, + PrivateRSASHA1, + PrivateRSASHA1NSEC3SHA1, + PrivateRSASHA256, + PrivateRSASHA512, + ) +except ImportError: + pass # Cryptography ImportError already handled in dns.dnssec + + +@unittest.skipUnless(dns.dnssec._have_pyca, "Python Cryptography cannot be imported") +class DNSSECAlgorithm(unittest.TestCase): + def _test_dnssec_alg(self, private_cls, key_size=None): + public_cls = private_cls.public_cls + + private_key = ( + private_cls.generate(key_size) if key_size else private_cls.generate() + ) + + # sign random data + data = os.urandom(1024) + signature = private_key.sign(data, verify=True) + + # validate signature using public key + public_key = private_key.public_key() + public_key.verify(signature, data) + + # create DNSKEY + dnskey = public_key.to_dnskey() + dnskey2 = public_cls.from_dnskey(dnskey).to_dnskey() + self.assertEqual(dnskey, dnskey2) + + # test cryptography keys + _ = private_cls(key=private_key.key) + _ = public_cls(key=public_key.key) + + # to/from PEM + password = b"mekmitasdigoat" + private_pem = private_key.to_pem() + private_pem_encrypted = private_key.to_pem(password=password) + public_pem = public_key.to_pem() + _ = private_cls.from_pem(private_pem) + _ = private_cls.from_pem(private_pem_encrypted, password) + _ = public_cls.from_pem(public_pem) + + def test_rsa(self): + self._test_dnssec_alg(PrivateRSAMD5, 2048) + self._test_dnssec_alg(PrivateRSASHA1, 2048) + self._test_dnssec_alg(PrivateRSASHA1NSEC3SHA1, 2048) + self._test_dnssec_alg(PrivateRSASHA256, 2048) + self._test_dnssec_alg(PrivateRSASHA512, 2048) + + def test_dsa(self): + self._test_dnssec_alg(PrivateDSA, 1024) + self._test_dnssec_alg(PrivateDSANSEC3SHA1, 1024) + with self.assertRaises(ValueError): + k = PrivateDSA.generate(2048) + k.sign(b"hello") + + def test_ecdsa(self): + self._test_dnssec_alg(PrivateECDSAP256SHA256) + self._test_dnssec_alg(PrivateECDSAP384SHA384) + + def test_eddsa(self): + self._test_dnssec_alg(PrivateED25519) + self._test_dnssec_alg(PrivateED448) + + def test_algorithm_mismatch(self): + private_key_ed448 = PrivateED448.generate() + dnskey_ed448 = private_key_ed448.public_key().to_dnskey() + with self.assertRaises(dns.exception.AlgorithmKeyMismatch): + PublicED25519.from_dnskey(dnskey_ed448) + + +@unittest.skipUnless(dns.dnssec._have_pyca, "Python Cryptography cannot be imported") +class DNSSECAlgorithmPrivateAlgorithm(unittest.TestCase): + def test_private(self): + class PublicExampleAlgorithm(PublicED25519): + algorithm = Algorithm.PRIVATEDNS + name = dns.name.from_text("algorithm.example.com") + + def encode_key_bytes(self) -> bytes: + return self.name.to_wire() + super().encode_key_bytes() + + @classmethod + def from_dnskey(cls, key: DNSKEY) -> "PublicEDDSA": + return cls( + key=cls.key_cls.from_public_bytes( + key.key[len(cls.name.to_wire()) :] + ), + ) + + class PrivateExampleAlgorithm(PrivateED25519): + public_cls = PublicExampleAlgorithm + + register_algorithm_cls( + algorithm=Algorithm.PRIVATEDNS, + algorithm_cls=PrivateExampleAlgorithm, + name=PublicExampleAlgorithm.name, + ) + + private_key = PrivateExampleAlgorithm.generate() + public_key = private_key.public_key() + + name = dns.name.from_text("example.com") + rdataset = dns.rdataset.from_text_list("in", "a", 30, ["10.0.0.1", "10.0.0.2"]) + rrset = (name, rdataset) + ttl = 60 + lifetime = 3600 + rrname = rrset[0] + signer = rrname + dnskey = dns.dnssec.make_dnskey( + public_key=public_key, algorithm=Algorithm.PRIVATEDNS + ) + dnskey_rrset = dns.rrset.from_rdata(signer, ttl, dnskey) + + rrsig = dns.dnssec.sign( + rrset=rrset, + private_key=private_key, + dnskey=dnskey, + lifetime=lifetime, + signer=signer, + verify=True, + policy=None, + ) + + keys = {signer: dnskey_rrset} + rrsigset = dns.rrset.from_rdata(rrname, ttl, rrsig) + dns.dnssec.validate(rrset=rrset, rrsigset=rrsigset, keys=keys, policy=None) + + def test_register(self): + register_algorithm_cls( + algorithm=Algorithm.PRIVATEDNS, + algorithm_cls=PrivateED25519, + name="ed25519.example.com", + ) + register_algorithm_cls( + algorithm=Algorithm.PRIVATEOID, + algorithm_cls=PrivateED448, + oid=bytes([1, 2, 3, 4]), + ) + register_algorithm_cls( + algorithm=251, + algorithm_cls=PrivateED25519, + ) + + with self.assertRaises(TypeError): + register_algorithm_cls(algorithm=251, algorithm_cls=str, name="example.com") + + with self.assertRaises(ValueError): + register_algorithm_cls( + algorithm=251, algorithm_cls=PrivateED25519, name="example.com" + ) + + with self.assertRaises(ValueError): + register_algorithm_cls( + algorithm=251, algorithm_cls=PrivateED25519, oid=bytes([1, 2, 3, 4]) + ) + + with self.assertRaises(ValueError): + register_algorithm_cls( + algorithm=Algorithm.PRIVATEDNS, + algorithm_cls=PrivateED25519, + oid=bytes([1, 2, 3, 4]), + ) + + with self.assertRaises(ValueError): + register_algorithm_cls( + algorithm=Algorithm.PRIVATEOID, + algorithm_cls=PrivateED25519, + name="example.com", + ) + + dnskey_251 = DNSKEY( + "IN", + "DNSKEY", + 256, + 3, + 251, + b"hello", + ) + dnskey_dns = DNSKEY( + "IN", + "DNSKEY", + 256, + 3, + Algorithm.PRIVATEDNS, + dns.name.from_text("ed25519.example.com").to_wire() + b"hello", + ) + dnskey_dns_unknown = DNSKEY( + "IN", + "DNSKEY", + 256, + 3, + Algorithm.PRIVATEDNS, + dns.name.from_text("unknown.example.com").to_wire() + b"hello", + ) + dnskey_oid = DNSKEY( + "IN", + "DNSKEY", + 256, + 3, + Algorithm.PRIVATEOID, + bytes([4, 1, 2, 3, 4]) + b"hello", + ) + dnskey_oid_unknown = DNSKEY( + "IN", + "DNSKEY", + 256, + 3, + Algorithm.PRIVATEOID, + bytes([4, 42, 42, 42, 42]) + b"hello", + ) + + with self.assertRaises(dns.exception.UnsupportedAlgorithm): + _ = get_algorithm_cls(250) + + algorithm_cls = get_algorithm_cls(251) + self.assertEqual(algorithm_cls, PrivateED25519) + + algorithm_cls = get_algorithm_cls_from_dnskey(dnskey_251) + self.assertEqual(algorithm_cls, PrivateED25519) + + algorithm_cls = get_algorithm_cls_from_dnskey(dnskey_dns) + self.assertEqual(algorithm_cls, PrivateED25519) + + with self.assertRaises(dns.exception.UnsupportedAlgorithm): + _ = get_algorithm_cls_from_dnskey(dnskey_dns_unknown) + + algorithm_cls = get_algorithm_cls_from_dnskey(dnskey_oid) + self.assertEqual(algorithm_cls, PrivateED448) + + with self.assertRaises(dns.exception.UnsupportedAlgorithm): + _ = get_algorithm_cls_from_dnskey(dnskey_oid_unknown) + + def test_register_canonical_lookup(self): + register_algorithm_cls( + algorithm=Algorithm.PRIVATEDNS, + algorithm_cls=PrivateED25519, + name="testing1234.example.com", + ) + + dnskey_dns = DNSKEY( + "IN", + "DNSKEY", + 256, + 3, + Algorithm.PRIVATEDNS, + dns.name.from_text("TESTING1234.EXAMPLE.COM").to_wire() + b"hello", + ) + + algorithm_cls = get_algorithm_cls_from_dnskey(dnskey_dns) + self.assertEqual(algorithm_cls, PrivateED25519) + + def test_register_private_without_prefix(self): + with self.assertRaises(ValueError): + register_algorithm_cls( + algorithm=Algorithm.PRIVATEDNS, + algorithm_cls=PrivateED25519, + ) + with self.assertRaises(ValueError): + register_algorithm_cls( + algorithm=Algorithm.PRIVATEOID, + algorithm_cls=PrivateED25519, + ) + + +if __name__ == "__main__": + unittest.main()